Skip to content

Commit

Permalink
Automatically abstract interfaces required by annotated lambda binders
Browse files Browse the repository at this point in the history
With this change, we automatically elaborate a type such as
```
(n : Type) ?-> (a : Type) ?-> Table n a -> a
```
into
```
(n : Type) ?-> (a : Type) ?-> Ix n ?=> Table n a -> a
```
(assuming `def Table (n : Type) (_ : Ix n) ?=> (a : Type) : Type = n=>a`)

In the near future this should let us start enforcing the index set
constraints in a backwards-compatible and non-verbose way, since any
type appearing to the left of `=>` will get constrained automatically.

This will be convenient for associated types too, because e.g. mentioning
`TangentSpace a` in a type will automatically add a constraint `Manifold a`
to the type signature.
  • Loading branch information
apaszke committed Oct 27, 2021
1 parent 4402e08 commit bc6c82e
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 41 deletions.
2 changes: 1 addition & 1 deletion src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ buildPi b fArr fTy = do
case typeReduceBlock scope block of
Right piTy -> return piTy
Left _ -> throw CompilerErr $
"Unexpected irreducible decls in pi type: " ++ pprint decls
"Unexpected irreducible decls in pi type: " ++ pprint block

buildAbsAux :: (MonadBuilder m, HasVars a) => Binder -> (Atom -> m (a, b)) -> m (Abs Binder (Nest Decl, a), b)
buildAbsAux b f = do
Expand Down
125 changes: 85 additions & 40 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -191,19 +191,35 @@ checkOrInferRho (WithSrc pos expr) reqTy = do
UPi (UPatAnn (WithSrc pos' pat) ann) arr ty -> do
-- TODO: make sure there's no effect if it's an implicit or table arrow
-- TODO: check leaks
ann' <- checkAnn ann
piTy <- addSrcContext' pos' case pat of
UPatBinder UIgnore -> buildPi (Ignore ann') (const $ mapM checkUEffRow arr)
(const $ checkUType ty)
_ -> withNameHint ("pat" :: Name) $ buildPi b
(\(Var v) -> withBindPat (WithSrc pos' pat) v $ mapM checkUEffRow arr)
(\(Var v) -> withBindPat (WithSrc pos' pat) v $ checkUType ty)
where b = case pat of
-- Note: The binder name becomes part of the type, so we
-- need to keep the same name used in the pattern.
UPatBinder (UBind v) -> Bind (v:>ann')
_ -> Ignore ann'
piTy <- checkAnnAutoInterface ann \wanteds getAnnType -> do
let maybeIntroWanteds = case arr of
TabArrow -> id
_ -> introWanteds wanteds
maybeIntroWanteds $ do
ann' <- getAnnType -- Only attempt to reduce the type once we have wanteds in scope
addSrcContext' pos' $ case pat of
UPatBinder UIgnore -> buildPi (Ignore ann') (const $ mapM checkUEffRow arr)
(const $ checkTailTy ty)
_ -> withNameHint ("pat" :: Name) $ buildPi b
(\(Var v) -> withBindPat (WithSrc pos' pat) v $ mapM checkUEffRow arr)
(\(Var v) -> withBindPat (WithSrc pos' pat) v $ checkTailTy ty)
where b = case pat of
-- Note: The binder name becomes part of the type, so we
-- need to keep the same name used in the pattern.
UPatBinder (UBind v) -> Bind (v:>ann')
_ -> Ignore ann'
matchRequirement piTy
where
introWanteds :: [Type] -> UInferM Type -> UInferM Type
introWanteds [] m = m
introWanteds (h:t) m = buildPi (Ignore h) (const $ return ClassArrow) (const $ introWanteds t m)

checkTailTy :: UType -> UInferM Type
checkTailTy tty = case tty of
-- Continued arguments in a curried function will do interface
-- quantification themselves.
WithSrc _ (UPi _ tarr _) | tarr /= TabArrow -> checkUType tty
_ -> checkUTypeAutoInterface tty introWanteds
UDecl decl body -> do
env <- inferUDecl decl
extInferSubst env $ checkOrInferRho body reqTy
Expand Down Expand Up @@ -455,32 +471,47 @@ checkULam (UPatAnn p ann) body piTy = do
checkSigma body Suggest $ snd $ applyAbs piTy x

checkInstance :: Nest UPatAnnArrow -> Name -> [UType] -> [UMethodDef] -> UInferM Atom
checkInstance (Nest (UPatAnnArrow (UPatAnn p ann) arrow) rest) className params methods = do
case arrow of
ImplicitArrow -> return ()
ClassArrow -> return ()
_ -> throw TypeErr $ "Not a valid arrow for an instance: " ++ pprint arrow
argTy <- checkAnn ann
buildLam (Bind $ patNameHint p :> argTy) (fromUArrow arrow) \(Var v) ->
checkLeaks [v] $ withBindPat p v $ checkInstance rest className params methods
checkInstance Empty className params methods = do
substEnv <- getInferSubst
className' <- case envLookup substEnv className of
Nothing -> return className
Just (Rename className') -> return className'
Just (SubstVal _) -> throw TypeErr $ "Not a valid class: " ++ pprint className
params' <- mapM checkUType params
ClassDef def methodNames <- getClassDef className'
[ClassDictCon superclassTys methodTys] <- return $ applyDataDefParams (snd def) params'
let superclassHoles = fmap (Con . ClassDictHole Nothing) superclassTys
methodsChecked <- mapM (checkMethodDef className' methodTys) methods
let (idxs, methods') = unzip $ sortOn fst $ methodsChecked
forM_ (repeated idxs) \i ->
throw TypeErr $ "Duplicate method: " ++ pprint (methodNames!!i)
forM_ ([0..(length methodTys - 1)] `listDiff` idxs) \i ->
throw TypeErr $ "Missing method: " ++ pprint (methodNames!!i)
return $ DataCon def params' 0 [PairVal (ProdVal superclassHoles)
(ProdVal methods')]
checkInstance arrows className params methods = case arrows of
Nest (UPatAnnArrow (UPatAnn p ann) arrow) rest -> do
case arrow of
ImplicitArrow -> return ()
ClassArrow -> return ()
_ -> throw TypeErr $ "Not a valid arrow for an instance: " ++ pprint arrow
checkAnnAutoInterface ann \wanteds getAnn -> do
introWanteds wanteds do
argTy <- getAnn
buildLam (Bind $ patNameHint p :> argTy) (fromUArrow arrow) \(Var v) ->
checkLeaks [v] $ withBindPat p v $ checkInstance rest className params methods
Empty -> do
substEnv <- getInferSubst
className' <- case envLookup substEnv className of
Nothing -> return className
Just (Rename className') -> return className'
Just (SubstVal _) -> throw TypeErr $ "Not a valid class: " ++ pprint className
(paramWanteds, paramGetters) <- checkParams params
introWanteds paramWanteds $ do
params' <- sequence paramGetters
ClassDef def methodNames <- getClassDef className'
[ClassDictCon superclassTys methodTys] <- return $ applyDataDefParams (snd def) params'
let superclassHoles = fmap (Con . ClassDictHole Nothing) superclassTys
methodsChecked <- mapM (checkMethodDef className' methodTys) methods
let (idxs, methods') = unzip $ sortOn fst $ methodsChecked
forM_ (repeated idxs) \i ->
throw TypeErr $ "Duplicate method: " ++ pprint (methodNames!!i)
forM_ ([0..(length methodTys - 1)] `listDiff` idxs) \i ->
throw TypeErr $ "Missing method: " ++ pprint (methodNames!!i)
return $ DataCon def params' 0 [PairVal (ProdVal superclassHoles)
(ProdVal methods')]
where
checkParams :: [UExpr] -> UInferM ([Type], [UInferM Type])
checkParams [] = return ([], [])
checkParams (e:t) = checkUTypeAutoInterface e \wanteds e' -> do
(tailWanteds, tailGets) <- checkParams t
return (wanteds ++ tailWanteds, e' : tailGets)

introWanteds :: [Type] -> UInferM Type -> UInferM Type
introWanteds [] m = m
introWanteds (h:t) m = buildLam (Ignore h) ClassArrow (const $ introWanteds t m)

checkMethodDef :: ClassDefName -> [Type] -> UMethodDef -> UInferM (Int, Atom)
checkMethodDef className methodTys (UMethodDef ~(UInternalVar v) rhs) = do
Expand Down Expand Up @@ -659,12 +690,25 @@ checkAnn ann = case ann of
Just ty -> checkUType ty
Nothing -> freshType TyKind

checkAnnAutoInterface :: Maybe UType -> ([Type] -> UInferM Type -> UInferM a) -> UInferM a
checkAnnAutoInterface ann cont = case ann of
Just ty -> checkUTypeAutoInterface ty cont
Nothing -> cont [] (freshType TyKind)

checkUType :: UType -> UInferM Type
checkUType ty@(WithSrc ctx _) =
(buildScoped $ withEffects Pure $ checkRho ty TyKind) >>=
typeReduceBlockWithWanteds >>=
typeReductionAsAtom ctx "Failed to reduce type annotation"

checkUTypeAutoInterface :: UType -> ([Type] -> UInferM Type -> UInferM a) -> UInferM a
checkUTypeAutoInterface ty@(WithSrc ctx _) cont =
(buildScoped $ withEffects Pure $ checkRho ty TyKind) >>=
typeReduceBlockWithWanteds >>= \case
(Left block, wanteds) -> cont (snd <$> wanteds) $
typeReductionAsAtom ctx "Failed to reduce type annotation" (Left block, wanteds)
(Right ans, wanteds) -> cont (snd <$> wanteds) (return ans)

-- Delayed unification task. When performing its solve, it tries to perform
-- dictionary synthesis within the inner scope and normalize eafResult such
-- that it contains no free vars from the eafInner scope (while it can refer
Expand Down Expand Up @@ -1275,9 +1319,10 @@ typeReduceBlockWithWanteds block@(Block decls _) = do
-- We're lifting expressions from the inside of the block, so we need to make sure
-- that we're not leaking any internal vars. Note that thanks to laziness the error
-- will not get raised unless we actually consume the wanteds downstream.
let safeWanteds = case null $ foldMap (freeVars . snd) wanteds `envDiff` scope of
let blockDefs = foldMap boundVars decls
let safeWanteds = case null $ foldMap (freeVars . snd) wanteds `envIntersect` blockDefs of
True -> wanteds
False -> error "Not implemented yet!"
False -> error $ "Not implemented yet!"
return $ (,safeWanteds) $ case typeReduceBlock scope block of
Left (_, ans) -> Left $ Block decls $ Atom ans
Right ans -> Right ans
Expand Down
29 changes: 29 additions & 0 deletions tests/typeclass-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,32 @@ def f8 [Ord' (n=>Int)] (x : n=>Int) : Int = eq x
-------------------- Multi-parameter interfaces --------------------

-- TODO!

-------------------- Automatic quantification --------------------

interface X a
x_ a : Int

def MyPairOfXs (a : Type) (_ : X a) ?=> : Type = (a & a)

instance X Int
x_ = 1

-- No automatic quantification needed
def q0 (x : MyPairOfXs Int) : Int = fst x
def q1 [X a] (x : MyPairOfXs a) : a = fst x

-- Should work with implicit quantification
def q2 (x : MyPairOfXs a) : a = fst x

-- Should work with implicit quantification
def q3 (x : MyPairOfXs a) : Int = x_ a

-- We should also implicitly quantify over constraints of the return type
def f4 (x : a) : MyPairOfXs a = (x, x)

-- Check automatic quantification for interfaces
interface AutoQuant a
dummy : Int
instance AutoQuant (MyPairOfXs a)
dummy = 1

0 comments on commit bc6c82e

Please sign in to comment.