diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 52bd802a1..198a5a225 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -11,6 +11,7 @@ module Builder ( emit, emitHinted, emitOp, emitUnOp, buildPureLam, BuilderT (..), Builder (..), ScopableBuilder (..), + buildScopedAssumeNoDecls, Builder2, BuilderM, ScopableBuilder2, liftBuilderT, buildBlock, withType, absToBlock, app, add, mul, sub, neg, div', iadd, imul, isub, idiv, ilt, ieq, irem, @@ -29,7 +30,7 @@ module Builder ( emitDataDef, emitClassDef, emitInstanceDef, emitDataConName, emitTyConName, emitEffectDef, emitHandlerDef, emitEffectOpDef, buildCase, emitMaybeCase, buildSplitCase, - emitBlock, emitDecls, BuilderEmissions, emitAtomToName, + emitBlock, emitDecls, BuilderEmissions, emitExprToAtom, emitAtomToName, TopBuilder (..), TopBuilderT (..), liftTopBuilderTWith, runTopBuilderT, TopBuilder2, emitBindingDefault, emitSourceMap, emitSynthCandidates, addInstanceSynthCandidate, @@ -113,6 +114,7 @@ emitOp op = Var <$> emit (Op op) emitUnOp :: (Builder r m, Emits n) => UnOp -> Atom r n -> m n (Atom r n) emitUnOp op x = emitOp $ UnOp op x +{-# INLINE emitUnOp #-} emitBlock :: (Builder r m, Emits n) => Block r n -> m n (Atom r n) emitBlock (Block _ decls result) = emitDecls decls result @@ -129,9 +131,24 @@ emitDecls' (Nest (Let b (DeclBinding ann _ expr)) rest) e = do v <- emitDecl (getNameHint b) ann expr' extendSubst (b @> v) $ emitDecls' rest e +emitExprToAtom :: (Builder r m, Emits n) => Expr r n -> m n (Atom r n) +emitExprToAtom (Atom atom) = return atom +emitExprToAtom expr = Var <$> emit expr +{-# INLINE emitExprToAtom #-} + emitAtomToName :: (Builder r m, Emits n) => NameHint -> Atom r n -> m n (AtomName r n) emitAtomToName _ (Var v) = return v emitAtomToName hint x = emitHinted hint (Atom x) +{-# INLINE emitAtomToName #-} + +buildScopedAssumeNoDecls :: (SinkableE e, ScopableBuilder r m) + => (forall l. (Emits l, DExt n l) => m l (e l)) + -> m n (e n) +buildScopedAssumeNoDecls cont = do + buildScoped cont >>= \case + (Abs Empty e) -> return e + _ -> error "Expected no decl emissions" +{-# INLINE buildScopedAssumeNoDecls #-} -- === "Hoisting" top-level builder class === diff --git a/src/lib/Core.hs b/src/lib/Core.hs index 84e6e982c..55b878716 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -486,11 +486,20 @@ substBinders -> (forall o'. DExt o o' => b o o' -> m i' o' a) -> m i o a substBinders b cont = do - ab <- substM $ Abs b $ idSubstFrag b - refreshAbs ab \b' subst -> - extendSubst subst $ cont b' + substBindersFrag b \subst b' -> extendSubst subst $ cont b' {-# INLINE substBinders #-} +substBindersFrag + :: ( SinkableV v, SubstV v v, EnvExtender2 m, FromName v + , SubstReader v m, SubstB v b, SubstV Name v, SubstB Name b, BindsEnv b) + => b i i' + -> (forall o'. DExt o o' => SubstFrag v i i' o' -> b o o' -> m i o' a) + -> m i o a +substBindersFrag b cont = do + ab <- substM $ Abs b $ idSubstFrag b + refreshAbs ab \b' subst -> cont subst b' +{-# INLINE substBindersFrag #-} + withFreshBinder :: (Color c, EnvExtender m, ToBinding binding c) => NameHint -> binding n @@ -664,6 +673,17 @@ a --@ b = Pi <$> nonDepPiType LinArrow a Pure b (==>) :: ScopeReader m => IxType r n -> Type r n -> m n (Type r n) a ==> b = TabPi <$> nonDepTabPiType a b +-- These `fromNary` functions traverse a chain of unary structures (LamExpr, +-- TabLamExpr, PiType, respectively) up to the given maxDepth, and return the +-- discovered binders packed as the nary structure (NaryLamExpr or NaryPiType), +-- including a count of how many binders there were. +-- - If there are no binders, return Nothing. +-- - If there are more than maxDepth binders, only return maxDepth of them, and +-- leave the others in the unary structure. +-- - The `exact` versions only succeed if at least maxDepth binders were +-- present, in which case exactly maxDepth binders are packed into the nary +-- structure. Excess binders, if any, are still left in the unary structures. + -- first argument is the number of args expected fromNaryLamExact :: Int -> Atom r n -> Maybe (NaryLamExpr r n) fromNaryLamExact exactDepth _ | exactDepth <= 0 = error "expected positive number of args" @@ -706,6 +726,19 @@ fromNaryTabLamExact exactDepth lam = do guard $ realDepth == exactDepth return naryLam +fromNaryForExpr :: Int -> Expr r n -> Maybe (Int, NaryLamExpr r n) +fromNaryForExpr maxDepth | maxDepth <= 0 = error "expected positive number of args" +fromNaryForExpr maxDepth = \case + (Hof (For _ _ (Lam (LamExpr (LamBinder b ty _ Pure) body)))) -> + extend <|> (Just $ (1, NaryLamExpr (NonEmptyNest (b:>ty) Empty) Pure body)) + where + extend = do + expr <- exprBlock body + guard $ maxDepth > 1 + (d, NaryLamExpr (NonEmptyNest b2 bs2) effs2 body2) <- fromNaryForExpr (maxDepth - 1) expr + return $ (d + 1, NaryLamExpr (NonEmptyNest (b:>ty) (Nest b2 bs2)) effs2 body2) + _ -> Nothing + -- first argument is the number of args expected fromNaryPiType :: Int -> Type r n -> Maybe (NaryPiType r n) fromNaryPiType n _ | n <= 0 = error "expected positive number of args" diff --git a/src/lib/Name.hs b/src/lib/Name.hs index eaf1843cc..230119c24 100644 --- a/src/lib/Name.hs +++ b/src/lib/Name.hs @@ -786,13 +786,18 @@ instance HoistableB b => HoistableB (NonEmptyNest b) instance SinkableB b => SinkableB (NonEmptyNest b) instance (BindsNames b, SinkableV v, SubstB v b) => SubstB v (NonEmptyNest b) -applySubstPure :: (SubstE v e, SinkableE e, SinkableV v, FromName v, Ext h o, Distinct o) +applySubstFragPure :: (SubstE v e, SinkableE e, SinkableV v, FromName v, Ext h o, Distinct o) => Scope o -> SubstFrag v h i o -> e i -> e o -applySubstPure scope substFrag x = do +applySubstFragPure scope substFrag x = do let fullSubst = sink idSubst <>> substFrag - case tryApplyIdentitySubst fullSubst x of + applySubstPure scope fullSubst x + +applySubstPure :: (SubstE v e, SinkableE e, SinkableV v, FromName v, Distinct o) + => Scope o -> Subst v i o -> e i -> e o +applySubstPure scope subst x = do + case tryApplyIdentitySubst subst x of Just x' -> x' - Nothing -> fmapNames scope (fullSubst !) x + Nothing -> fmapNames scope (subst !) x applySubst :: (ScopeReader m, SubstE v e, SinkableE e, SinkableV v, FromName v) => Ext h o @@ -800,7 +805,7 @@ applySubst :: (ScopeReader m, SubstE v e, SinkableE e, SinkableV v, FromName v) applySubst substFrag x = do Distinct <- getDistinct scope <- unsafeGetScope - return $ applySubstPure scope substFrag x + return $ applySubstFragPure scope substFrag x {-# INLINE applySubst #-} applyAbs :: ( SinkableV v, SinkableE e @@ -3219,6 +3224,12 @@ instance HoistableV v => HoistableE (SubstFrag v i i') where instance SubstV substVal v => SubstE substVal (SubstFrag v i i') where substE env frag = fmapSubstFrag (\_ val -> substE env val) frag +instance SubstV substVal v => SubstE substVal (Subst v i) where + substE env = \case + Subst f frag -> Subst (\n -> substE env (f n)) $ substE env frag + UnsafeMakeIdentitySubst + -> Subst (\n -> substE env (fromName $ unsafeCoerceE n)) emptyInFrag + -- === unsafe coercions === -- Sometimes we need to break the glass. But at least these are slightly safer diff --git a/src/lib/Syntax.hs b/src/lib/Syntax.hs index bcc0c9634..871f0a467 100644 --- a/src/lib/Syntax.hs +++ b/src/lib/Syntax.hs @@ -102,7 +102,7 @@ module Syntax ( IsCUDARequired (..), NaryLamExpr (..), NaryPiType (..), fromNaryLam, fromNaryTabLam, fromNaryTabLamExact, - fromNaryLamExact, fromNaryPiType, + fromNaryLamExact, fromNaryForExpr, fromNaryPiType, NonEmpty (..), nonEmpty, naryLamExprAsAtom, naryPiTypeAsType, WithCNameInterface (..), FunObjCode, FunObjCodeName, IFunBinder (..), diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index d1270ce60..95c64ec19 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -1027,7 +1027,14 @@ pattern AtomicBlock :: Atom r n -> Block r n pattern AtomicBlock atom <- Block _ Empty atom where AtomicBlock atom = Block NoBlockAnn Empty atom +exprBlock :: Block r n -> Maybe (Expr r n) +exprBlock (Block _ (Nest (Let b (DeclBinding _ _ expr)) Empty) (Var n)) + | n == binderName b = Just expr +exprBlock _ = Nothing +{-# INLINE exprBlock #-} + pattern BinaryLamExpr :: LamBinder r n l1 -> LamBinder r l1 l2 -> Block r l2 -> LamExpr r n + pattern BinaryLamExpr b1 b2 body = LamExpr b1 (AtomicBlock (Lam (LamExpr b2 body))) pattern MaybeTy :: Type r n -> Type r n