diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 418b185d8..09cb0bb2f 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -23,7 +23,7 @@ module Builder ( Emits, EmitsEvidence (..), buildPi, buildNonDepPi, buildLam, buildTabLam, buildLamGeneral, buildAbs, buildNaryAbs, buildNaryLam, buildNullaryLam, buildNaryLamExpr, asNaryLam, - buildAlt, buildUnaryAlt, buildUnaryAtomAlt, + buildAlt, buildUnaryAtomAlt, emitDataDef, emitClassDef, emitInstanceDef, emitDataConName, emitTyConName, buildCase, emitMaybeCase, buildSplitCase, emitBlock, emitDecls, BuilderEmissions, emitAtomToName, @@ -796,24 +796,14 @@ buildNaryLamExpr (NaryPiType (NonEmptyNest b bs) eff resultTy) cont = buildAlt :: ScopableBuilder m - => EmptyAbs (Nest Binder) n - -> (forall l. (Distinct l, Emits l, DExt n l) => [AtomName l] -> m l (Atom l)) + => Type n + -> (forall l. (Distinct l, Emits l, DExt n l) => AtomName l -> m l (Atom l)) -> m n (Alt n) -buildAlt bs body = do - buildNaryAbs bs \xs -> do +buildAlt ty body = do + buildAbs noHint ty \x -> buildBlock do Distinct <- getDistinct - xs' <- mapM sinkM xs - body xs' - -buildUnaryAlt - :: ScopableBuilder m - => Type n - -> (forall l. (Emits l, DExt n l) => AtomName l -> m l (Atom l)) - -> m n (Alt n) -buildUnaryAlt ty body = do - bs <- singletonBinderNest noHint ty - buildAlt bs \[v] -> body v + body $ sink x buildUnaryAtomAlt :: ScopableBuilder m @@ -821,8 +811,7 @@ buildUnaryAtomAlt -> (forall l. (Distinct l, DExt n l) => AtomName l -> m l (Atom l)) -> m n (AltP Atom n) buildUnaryAtomAlt ty body = do - bs <- singletonBinderNest noHint ty - buildNaryAbs bs \[v] -> do + buildAbs noHint ty \v -> do Distinct <- getDistinct body v @@ -830,22 +819,22 @@ buildUnaryAtomAlt ty body = do -- out the result type from one of the alts rather than providing it explicitly buildCase :: (Emits n, ScopableBuilder m) => Atom n -> Type n - -> (forall l. (Emits l, DExt n l) => Int -> [Atom l] -> m l (Atom l)) + -> (forall l. (Emits l, DExt n l) => Int -> Atom l -> m l (Atom l)) -> m n (Atom n) buildCase scrut resultTy indexedAltBody = do case trySelectBranch scrut of Just (i, arg) -> do Distinct <- getDistinct - indexedAltBody i [sink arg] + indexedAltBody i $ sink arg Nothing -> do scrutTy <- getType scrut altBinderTys <- caseAltsBinderTys scrutTy (alts, effs) <- unzip <$> forM (enumerate altBinderTys) \(i, bTy) -> do (Abs b' (body `PairE` eff')) <- buildAbs noHint bTy \x -> do - blk <- buildBlock $ indexedAltBody i [Var $ sink x] + blk <- buildBlock $ indexedAltBody i $ Var $ sink x eff <- getEffects blk return $ blk `PairE` eff - return (Abs (Nest b' Empty) body, ignoreHoistFailure $ hoist b' eff') + return (Abs b' body, ignoreHoistFailure $ hoist b' eff') liftM Var $ emit $ Case scrut alts resultTy $ mconcat effs buildSplitCase :: (Emits n, ScopableBuilder m) @@ -855,7 +844,7 @@ buildSplitCase :: (Emits n, ScopableBuilder m) -> m n (Atom n) buildSplitCase tys scrut resultTy match fallback = do split <- emitOp $ VariantSplit tys scrut - buildCase split resultTy \i [v] -> + buildCase split resultTy \i v -> case i of 0 -> match v 1 -> fallback v @@ -1307,7 +1296,7 @@ emitIf :: (Emits n, ScopableBuilder m) -> m n (Atom n) emitIf predicate resultTy trueCase falseCase = do predicate' <- emitOp $ ToEnum (SumTy [UnitTy, UnitTy]) predicate - buildCase predicate' resultTy \i [_] -> + buildCase predicate' resultTy \i _ -> case i of 0 -> falseCase 1 -> trueCase @@ -1319,7 +1308,7 @@ emitMaybeCase :: (Emits n, ScopableBuilder m) -> (forall l. (Emits l, DExt n l) => Atom l -> m l (Atom l)) -> m n (Atom n) emitMaybeCase scrut resultTy nothingCase justCase = do - buildCase scrut resultTy \i [v] -> + buildCase scrut resultTy \i v -> case i of 0 -> nothingCase 1 -> justCase v diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index f59e07e03..51aecd81f 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -858,8 +858,7 @@ checkCaseAltsBinderTys ty = case ty of checkAlt :: (HasType body, Typer m) => Type o -> Type o -> AltP body i -> m i o () -checkAlt resultTyReq bTyReq (Abs bs body) = do - Nest b Empty <- return bs +checkAlt resultTyReq bTyReq (Abs b body) = do bTy <- substM $ binderType b checkAlphaEq bTyReq bTy substBinders b \_ -> do diff --git a/src/lib/GenericTraversal.hs b/src/lib/GenericTraversal.hs index 8f9d38d6a..699ed6b15 100644 --- a/src/lib/GenericTraversal.hs +++ b/src/lib/GenericTraversal.hs @@ -185,14 +185,9 @@ traverseAlt :: GenericTraverser s => Alt i -> GenericTraverserM s i o (Alt o) -traverseAlt (Abs Empty body) = Abs Empty <$> tge body -traverseAlt (Abs (Nest (b:>ty) bs) body) = do +traverseAlt (Abs (b:>ty) body) = do ty' <- tge ty - Abs b' (Abs bs' body') <- - buildAbs (getNameHint b) ty' \v -> do - extendRenamer (b@>v) $ - traverseAlt $ Abs bs body - return $ Abs (Nest b' bs') body' + buildAbs (getNameHint b) ty' \v -> extendRenamer (b@>v) $ tge body -- See Note [Confuse GHC] from Simplify.hs confuseGHC :: EnvReader m => m n (DistinctEvidence n) diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index 08e6fca5c..52e38b727 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -368,8 +368,8 @@ translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of e' <- substM e case trySelectBranch e' of Just (con, arg) -> do - Abs bs body <- return $ alts !! con - extendSubst (bs @@> [SubstVal arg]) $ translateBlock maybeDest body + Abs b body <- return $ alts !! con + extendSubst (b @> SubstVal arg) $ translateBlock maybeDest body Nothing -> case e' of Con (Newtype (VariantTy _) (Con (SumAsProd _ tag xss))) -> go tag xss Con (Newtype (TypeCon _ _ _) (Con (SumAsProd _ tag xss))) -> go tag xss @@ -380,8 +380,8 @@ translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of tag' <- fromScalarAtom tag dest <- allocDest maybeDest =<< substM ty emitSwitch tag' (zip xss alts) $ - \(xs, Abs bs body) -> - void $ extendSubst (bs @@> [SubstVal $ sink xs]) $ + \(xs, Abs b body) -> + void $ extendSubst (b @> SubstVal (sink xs)) $ translateBlock (Just $ sink dest) body destToAtom dest where diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index fcef3ee91..f9d6c6a00 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -1199,14 +1199,14 @@ instance SinkableE IndexedAlt where sinkingProofE = todoSinkableProof buildNthOrderedAlt :: (Emits n, Builder m) - => [IndexedAlt n] -> Type n -> Type n -> Int -> [Atom n] + => [IndexedAlt n] -> Type n -> Type n -> Int -> Atom n -> m n (Atom n) -buildNthOrderedAlt alts scrutTy resultTy i vs = do +buildNthOrderedAlt alts scrutTy resultTy i v = do case lookup (nthCaseAltIdx scrutTy i) [(idx, alt) | IndexedAlt idx alt <- alts] of Nothing -> do resultTy' <- sinkM resultTy emitOp $ ThrowError resultTy' - Just alt -> applyNaryAbs alt (SubstVal <$> vs) >>= emitBlock + Just alt -> applyAbs alt (SubstVal v) >>= emitBlock -- converts from the ordinal index used in the core IR to the more complicated -- `CaseAltIndex` used in the surface IR. @@ -1226,11 +1226,11 @@ buildMonomorphicCase => [IndexedAlt n] -> Atom n -> Type n -> m n (Atom n) buildMonomorphicCase alts scrut resultTy = do scrutTy <- getType scrut - buildCase scrut resultTy \i vs -> do + buildCase scrut resultTy \i v -> do ListE alts' <- sinkM $ ListE alts scrutTy' <- sinkM scrutTy resultTy' <- sinkM resultTy - buildNthOrderedAlt alts' scrutTy' resultTy' i vs + buildNthOrderedAlt alts' scrutTy' resultTy' i v buildSortedCase :: (Fallible1 m, Builder m, Emits n) => Atom n -> [IndexedAlt n] -> Type n @@ -1245,7 +1245,7 @@ buildSortedCase scrut alts resultTy = do -- Single constructor ADTs are not sum types, so elide the case. [_] -> do let [IndexedAlt _ alt] = alts - emitBlock =<< applyNaryAbs alt [SubstVal $ unwrapNewtype scrut] + emitBlock =<< applyAbs alt (SubstVal $ unwrapNewtype scrut) _ -> liftEmitBuilder $ buildMonomorphicCase alts scrut resultTy VariantTy (Ext types tailName) -> do case filter isVariantTailAlt alts of @@ -1273,7 +1273,7 @@ buildSortedCase scrut alts resultTy = do resultTy' <- sinkM resultTy liftEmitBuilder $ buildMonomorphicCase alts' v resultTy') (\v -> do tailAlt' <- sinkM tailAlt - applyNaryAbs tailAlt' [SubstVal v] >>= emitBlock ) + applyAbs tailAlt' (SubstVal v) >>= emitBlock ) _ -> throw TypeErr "Can't specify more than one variant tail pattern." _ -> fail $ "Unexpected case expression type: " <> pprint scrutTy @@ -1716,7 +1716,7 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat ++ " got " ++ show (nestLength ps) (params, repTy') <- inferParams (Abs paramBs repTy) constrainEq scrutineeTy $ TypeCon sourceName dataDefName params - buildUnaryAltInf repTy' \arg -> do + buildAltInf repTy' \arg -> do args <- forM idxs $ init |> NE.nonEmpty |> \case Nothing -> return arg Just idxs' -> emit $ Atom $ ProjectElt idxs' arg @@ -1729,7 +1729,7 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat let patTypes = prevTys <> labeledSingleton label ty let extPatTypes = Ext patTypes $ Just rest constrainEq scrutineeTy $ VariantTy extPatTypes - buildUnaryAltInf ty \x -> + buildAltInf ty \x -> bindLamPat p x cont UPatVariantLift labels p -> do prevTys <- mapM (const $ freshType TyKind) labels @@ -1737,7 +1737,7 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat let extPatTypes = Ext prevTys $ Just rest constrainEq scrutineeTy $ VariantTy extPatTypes let ty = VariantTy $ Ext NoLabeledItems $ Just rest - buildUnaryAltInf ty \x -> + buildAltInf ty \x -> bindLamPat p x cont _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" @@ -2802,31 +2802,16 @@ buildTabPiInf hint ty body = do withAllowedEffects Pure $ body v return $ TabPiType (b:>ty) resultTy -buildUnaryAltInf +buildAltInf :: EmitsInf n => Type n -> (forall l. (EmitsBoth l, Ext n l) => AtomName l -> InfererM i l (Atom l)) -> InfererM i n (Alt n) -buildUnaryAltInf ty body = do - bs <- liftBuilder $ singletonBinderNest noHint ty - buildAltInf bs \[v] -> body v - -buildAltInf - :: EmitsInf n - => EmptyAbs (Nest Binder) n - -> (forall l. (EmitsBoth l, Ext n l) => [AtomName l] -> InfererM i l (Atom l)) - -> InfererM i n (Alt n) -buildAltInf (Abs Empty UnitE) body = - Abs Empty <$> buildBlockInf (body []) -buildAltInf (Abs (Nest (b:>ty) bs) UnitE) body = do - Abs b' (Abs bs' body') <- - buildAbsInf (getNameHint b) ty \v -> do - ab <- sinkM $ Abs b (EmptyAbs bs) - bs' <- applyAbs ab v - buildAltInf bs' \vs -> do - v' <- sinkM v - body $ v' : vs - return $ Abs (Nest b' bs') body' +buildAltInf ty body = do + buildAbsInf noHint ty \v -> + buildBlockInf do + Distinct <- getDistinct + body $ sink v -- === EmitsInf predicate === diff --git a/src/lib/Interpreter.hs b/src/lib/Interpreter.hs index 6725fa911..43e5ece1b 100644 --- a/src/lib/Interpreter.hs +++ b/src/lib/Interpreter.hs @@ -156,8 +156,8 @@ evalExpr expr = case expr of case trySelectBranch e' of Nothing -> error "branch should be chosen at this point" Just (con, arg) -> do - Abs bs body <- return $ alts !! con - extendSubst (bs @@> [SubstVal arg]) $ evalBlock body + Abs b body <- return $ alts !! con + extendSubst (b @> SubstVal arg) $ evalBlock body Hof hof -> case hof of RunIO (Lam (LamExpr b body)) -> extendSubst (b @> SubstVal UnitTy) $ diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index 7de4ee659..654cf5644 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -367,10 +367,10 @@ linearizeExpr expr = case expr of resultTangentType <- tangentType resultTy' resultTyWithTangent <- PairTy <$> substM resultTy <*> tangentFunType resultTangentType - (ans, linLam) <- fromPair =<< buildCase e' resultTyWithTangent \i xs -> do - xs' <- mapM emitAtomToName xs - Abs bs body <- return $ alts !! i - extendSubst (bs @@> xs') $ withTangentFunAsLambda $ linearizeBlock body + (ans, linLam) <- fromPair =<< buildCase e' resultTyWithTangent \i x -> do + x' <- emitAtomToName x + Abs b body <- return $ alts !! i + extendSubst (b @> x') $ withTangentFunAsLambda $ linearizeBlock body return $ WithTangent ans do applyLinToTangents $ sink linLam diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 5bc46bc25..5392d405b 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -177,8 +177,7 @@ prettyPrecCase name e alts effs = atPrec LowestPrec $ effectLine row = hardline <> "case annotated with effects" <+> p row prettyAlt :: PrettyE e => AltP e n -> Doc ann -prettyAlt (Abs bs body) = hsep (map prettyBinderNoAnn bs') <+> "->" <> nest 2 (p body) - where bs' = fromNest bs +prettyAlt (Abs b body) = prettyBinderNoAnn b <+> "->" <> nest 2 (p body) prettyBinderNoAnn :: Binder n l -> Doc ann prettyBinderNoAnn (b:>_) = p b diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index b1e6e572e..271cdd176 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -198,15 +198,15 @@ simplifyExpr expr = confuseGHC >>= \_ -> case expr of resultTy' <- substM resultTy case trySelectBranch e' of Just (i, arg) -> do - Abs bs body <- return $ alts !! i - extendSubst (bs @@> [SubstVal arg]) $ simplifyBlock body + Abs b body <- return $ alts !! i + extendSubst (b @> SubstVal arg) $ simplifyBlock body Nothing -> do isData resultTy' >>= \case True -> do - alts' <- forM alts \(Abs bs body) -> do - bs' <- substM $ EmptyAbs bs - buildNaryAbs bs' \xs -> - extendSubst (bs @@> map Rename xs) $ + alts' <- forM alts \(Abs b body) -> do + bTy' <- substM $ binderType b + buildAbs (getNameHint b) bTy' \x -> + extendSubst (b @> Rename x) $ buildBlock $ simplifyBlock body liftM Var $ emit $ Case e' alts' resultTy' eff' False -> defuncCase e' alts resultTy' @@ -244,9 +244,9 @@ defuncCase scrut alts resultTy = do return $ ignoreHoistFailure $ hoist bs' ty injectAltResult :: EnvReader m => [Type n] -> Int -> Alt n -> m n (Alt n) - injectAltResult sumTys con (Abs bs body) = liftBuilder do - buildAlt (EmptyAbs bs) \vs -> do - originalResult <- emitBlock =<< applySubst (bs@@>vs) body + injectAltResult sumTys con (Abs b body) = liftBuilder do + buildAlt (binderType b) \v -> do + originalResult <- emitBlock =<< applySubst (b@>v) body (dataResult, nonDataResult) <- fromPair originalResult return $ PairVal dataResult $ Con $ SumCon (sinkList sumTys) con nonDataResult @@ -297,9 +297,9 @@ simplifyApp f xs = -- TODO: Don't rebuild the alts here! Factor out Case simplification -- with lazy substitution and call it from here! resultTy <- getAppType ty $ toList xs - alts' <- forM alts \(Abs bs a) -> do - buildAlt (EmptyAbs bs) \vs -> do - a' <- applySubst (bs@@>vs) a + alts' <- forM alts \(Abs b a) -> do + buildAlt (binderType b) \v -> do + a' <- applySubst (b@>v) a naryApp a' (map sink $ toList xs) caseExpr <- caseComputingEffs e alts' resultTy dropSubst $ simplifyExpr caseExpr @@ -457,9 +457,9 @@ simplifyTabApp f xs = -- TODO: Don't rebuild the alts here! Factor out Case simplification -- with lazy substitution and call it from here! resultTy <- getTabAppType ty $ toList xs - alts' <- forM alts \(Abs bs a) -> do - buildAlt (EmptyAbs bs) \vs -> do - a' <- applySubst (bs@@>vs) a + alts' <- forM alts \(Abs b a) -> do + buildAlt (binderType b) \v -> do + a' <- applySubst (b@>v) a naryTabApp a' (map sink $ toList xs) caseExpr <- caseComputingEffs e alts' resultTy dropSubst $ simplifyExpr $ caseExpr @@ -510,14 +510,14 @@ simplifyAtom atom = confuseGHC >>= \_ -> case atom of e' <- simplifyAtom e case trySelectBranch e' of Just (i, arg) -> do - Abs bs body <- return $ alts !! i - extendSubst (bs @@> [SubstVal arg]) $ simplifyAtom body + Abs b body <- return $ alts !! i + extendSubst (b @> SubstVal arg) $ simplifyAtom body Nothing -> do rTy' <- substM rTy - alts' <- forM alts \(Abs bs body) -> do - bs' <- substM $ EmptyAbs bs - buildNaryAbs bs' \xs -> - extendSubst (bs @@> map Rename xs) $ + alts' <- forM alts \(Abs b body) -> do + bTy' <- substM $ binderType b + buildAbs (getNameHint b) bTy' \xs -> + extendSubst (b @> Rename xs) $ simplifyAtom body return $ ACase e' alts' rTy' BoxedRef _ -> error "Should only occur in Imp lowering" @@ -698,7 +698,7 @@ simplifyOp op = case op of let fullLabels = toList $ reflectLabels fullTys let labels = toList $ reflectLabels rightTys -- Emit a case statement (ordered by the arg type) that lifts the type. - buildCase right (VariantTy fullRow) \caseIdx [v] -> do + buildCase right (VariantTy fullRow) \caseIdx v -> do -- TODO: This is slow! Optimize this! We keep searching through lists all the time! let (label, i) = labels !! caseIdx let idx = fromJust $ elemIndex (label, i + length (lookupLabel leftTys label)) fullLabels @@ -715,7 +715,7 @@ simplifyOp op = case op of let fullLabels = toList $ reflectLabels fullTys let leftLabels = toList $ reflectLabels leftTys let rightLabels = toList $ reflectLabels rightTys - buildCase full resTy \caseIdx [v] -> do + buildCase full resTy \caseIdx v -> do let (label, i) = fullLabels !! caseIdx let labelIx labs li = fromJust $ elemIndex li labs let resTys' = sinkList resTys @@ -858,9 +858,9 @@ exceptToMaybeExpr expr = case expr of Case e alts resultTy _ -> do e' <- substM e resultTy' <- substM $ MaybeTy resultTy - buildCase e' resultTy' \i vs -> do - Abs bs body <- return $ alts !! i - extendSubst (bs @@> map SubstVal vs) $ exceptToMaybeBlock body + buildCase e' resultTy' \i v -> do + Abs b body <- return $ alts !! i + extendSubst (b @> SubstVal v) $ exceptToMaybeBlock body Atom x -> do x' <- substM x ty <- getType x' diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 55344f3a9..f8b08bc77 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -185,10 +185,10 @@ transposeExpr expr ct = case expr of True -> notImplemented False -> do e' <- substNonlin e - void $ buildCase e' UnitTy \i vs -> do - vs' <- mapM emitAtomToName vs - Abs bs body <- return $ alts !! i - extendSubst (bs @@> map RenameNonlin vs') do + void $ buildCase e' UnitTy \i v -> do + v' <- emitAtomToName v + Abs b body <- return $ alts !! i + extendSubst (b @> RenameNonlin v') do transposeBlock body (sink ct) return UnitVal diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 03964c93e..f7fa585fc 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -107,9 +107,8 @@ type BaseMonoid n = BaseMonoidP (Atom n) type AtomBinderP = BinderP AtomNameC type Binder = AtomBinderP Type - -- TODO: Alts don't need binder nests anymore --- we have Atom projections! -type AltP (e::E) = Abs (Nest Binder) e :: E -type Alt = AltP Block :: E +type AltP (e::E) = Abs Binder e :: E +type Alt = AltP Block :: E -- The additional invariant enforced by this newtype is that the list should -- never contain empty StaticFields members, nor StaticFields in two consecutive