Skip to content

Commit

Permalink
Only keep a single Binder in Alts
Browse files Browse the repository at this point in the history
We no longer allow n-ary bindings in case expressions, so we don't need
to use a nest.
  • Loading branch information
apaszke committed Sep 7, 2022
1 parent 3320ab9 commit 5627a37
Show file tree
Hide file tree
Showing 11 changed files with 76 additions and 110 deletions.
39 changes: 14 additions & 25 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ module Builder (
Emits, EmitsEvidence (..), buildPi, buildNonDepPi,
buildLam, buildTabLam, buildLamGeneral,
buildAbs, buildNaryAbs, buildNaryLam, buildNullaryLam, buildNaryLamExpr, asNaryLam,
buildAlt, buildUnaryAlt, buildUnaryAtomAlt,
buildAlt, buildUnaryAtomAlt,
emitDataDef, emitClassDef, emitInstanceDef, emitDataConName, emitTyConName,
buildCase, emitMaybeCase, buildSplitCase,
emitBlock, emitDecls, BuilderEmissions, emitAtomToName,
Expand Down Expand Up @@ -796,56 +796,45 @@ buildNaryLamExpr (NaryPiType (NonEmptyNest b bs) eff resultTy) cont =

buildAlt
:: ScopableBuilder m
=> EmptyAbs (Nest Binder) n
-> (forall l. (Distinct l, Emits l, DExt n l) => [AtomName l] -> m l (Atom l))
=> Type n
-> (forall l. (Distinct l, Emits l, DExt n l) => AtomName l -> m l (Atom l))
-> m n (Alt n)
buildAlt bs body = do
buildNaryAbs bs \xs -> do
buildAlt ty body = do
buildAbs noHint ty \x ->
buildBlock do
Distinct <- getDistinct
xs' <- mapM sinkM xs
body xs'

buildUnaryAlt
:: ScopableBuilder m
=> Type n
-> (forall l. (Emits l, DExt n l) => AtomName l -> m l (Atom l))
-> m n (Alt n)
buildUnaryAlt ty body = do
bs <- singletonBinderNest noHint ty
buildAlt bs \[v] -> body v
body $ sink x

buildUnaryAtomAlt
:: ScopableBuilder m
=> Type n
-> (forall l. (Distinct l, DExt n l) => AtomName l -> m l (Atom l))
-> m n (AltP Atom n)
buildUnaryAtomAlt ty body = do
bs <- singletonBinderNest noHint ty
buildNaryAbs bs \[v] -> do
buildAbs noHint ty \v -> do
Distinct <- getDistinct
body v

-- TODO: consider a version with nonempty list of alternatives where we figure
-- out the result type from one of the alts rather than providing it explicitly
buildCase :: (Emits n, ScopableBuilder m)
=> Atom n -> Type n
-> (forall l. (Emits l, DExt n l) => Int -> [Atom l] -> m l (Atom l))
-> (forall l. (Emits l, DExt n l) => Int -> Atom l -> m l (Atom l))
-> m n (Atom n)
buildCase scrut resultTy indexedAltBody = do
case trySelectBranch scrut of
Just (i, arg) -> do
Distinct <- getDistinct
indexedAltBody i [sink arg]
indexedAltBody i $ sink arg
Nothing -> do
scrutTy <- getType scrut
altBinderTys <- caseAltsBinderTys scrutTy
(alts, effs) <- unzip <$> forM (enumerate altBinderTys) \(i, bTy) -> do
(Abs b' (body `PairE` eff')) <- buildAbs noHint bTy \x -> do
blk <- buildBlock $ indexedAltBody i [Var $ sink x]
blk <- buildBlock $ indexedAltBody i $ Var $ sink x
eff <- getEffects blk
return $ blk `PairE` eff
return (Abs (Nest b' Empty) body, ignoreHoistFailure $ hoist b' eff')
return (Abs b' body, ignoreHoistFailure $ hoist b' eff')
liftM Var $ emit $ Case scrut alts resultTy $ mconcat effs

buildSplitCase :: (Emits n, ScopableBuilder m)
Expand All @@ -855,7 +844,7 @@ buildSplitCase :: (Emits n, ScopableBuilder m)
-> m n (Atom n)
buildSplitCase tys scrut resultTy match fallback = do
split <- emitOp $ VariantSplit tys scrut
buildCase split resultTy \i [v] ->
buildCase split resultTy \i v ->
case i of
0 -> match v
1 -> fallback v
Expand Down Expand Up @@ -1307,7 +1296,7 @@ emitIf :: (Emits n, ScopableBuilder m)
-> m n (Atom n)
emitIf predicate resultTy trueCase falseCase = do
predicate' <- emitOp $ ToEnum (SumTy [UnitTy, UnitTy]) predicate
buildCase predicate' resultTy \i [_] ->
buildCase predicate' resultTy \i _ ->
case i of
0 -> falseCase
1 -> trueCase
Expand All @@ -1319,7 +1308,7 @@ emitMaybeCase :: (Emits n, ScopableBuilder m)
-> (forall l. (Emits l, DExt n l) => Atom l -> m l (Atom l))
-> m n (Atom n)
emitMaybeCase scrut resultTy nothingCase justCase = do
buildCase scrut resultTy \i [v] ->
buildCase scrut resultTy \i v ->
case i of
0 -> nothingCase
1 -> justCase v
Expand Down
3 changes: 1 addition & 2 deletions src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -858,8 +858,7 @@ checkCaseAltsBinderTys ty = case ty of

checkAlt :: (HasType body, Typer m)
=> Type o -> Type o -> AltP body i -> m i o ()
checkAlt resultTyReq bTyReq (Abs bs body) = do
Nest b Empty <- return bs
checkAlt resultTyReq bTyReq (Abs b body) = do
bTy <- substM $ binderType b
checkAlphaEq bTyReq bTy
substBinders b \_ -> do
Expand Down
9 changes: 2 additions & 7 deletions src/lib/GenericTraversal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,9 @@ traverseAlt
:: GenericTraverser s
=> Alt i
-> GenericTraverserM s i o (Alt o)
traverseAlt (Abs Empty body) = Abs Empty <$> tge body
traverseAlt (Abs (Nest (b:>ty) bs) body) = do
traverseAlt (Abs (b:>ty) body) = do
ty' <- tge ty
Abs b' (Abs bs' body') <-
buildAbs (getNameHint b) ty' \v -> do
extendRenamer (b@>v) $
traverseAlt $ Abs bs body
return $ Abs (Nest b' bs') body'
buildAbs (getNameHint b) ty' \v -> extendRenamer (b@>v) $ tge body

-- See Note [Confuse GHC] from Simplify.hs
confuseGHC :: EnvReader m => m n (DistinctEvidence n)
Expand Down
8 changes: 4 additions & 4 deletions src/lib/Imp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -368,8 +368,8 @@ translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of
e' <- substM e
case trySelectBranch e' of
Just (con, arg) -> do
Abs bs body <- return $ alts !! con
extendSubst (bs @@> [SubstVal arg]) $ translateBlock maybeDest body
Abs b body <- return $ alts !! con
extendSubst (b @> SubstVal arg) $ translateBlock maybeDest body
Nothing -> case e' of
Con (Newtype (VariantTy _) (Con (SumAsProd _ tag xss))) -> go tag xss
Con (Newtype (TypeCon _ _ _) (Con (SumAsProd _ tag xss))) -> go tag xss
Expand All @@ -380,8 +380,8 @@ translateExpr maybeDest expr = confuseGHC >>= \_ -> case expr of
tag' <- fromScalarAtom tag
dest <- allocDest maybeDest =<< substM ty
emitSwitch tag' (zip xss alts) $
\(xs, Abs bs body) ->
void $ extendSubst (bs @@> [SubstVal $ sink xs]) $
\(xs, Abs b body) ->
void $ extendSubst (b @> SubstVal (sink xs)) $
translateBlock (Just $ sink dest) body
destToAtom dest
where
Expand Down
47 changes: 16 additions & 31 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1199,14 +1199,14 @@ instance SinkableE IndexedAlt where
sinkingProofE = todoSinkableProof

buildNthOrderedAlt :: (Emits n, Builder m)
=> [IndexedAlt n] -> Type n -> Type n -> Int -> [Atom n]
=> [IndexedAlt n] -> Type n -> Type n -> Int -> Atom n
-> m n (Atom n)
buildNthOrderedAlt alts scrutTy resultTy i vs = do
buildNthOrderedAlt alts scrutTy resultTy i v = do
case lookup (nthCaseAltIdx scrutTy i) [(idx, alt) | IndexedAlt idx alt <- alts] of
Nothing -> do
resultTy' <- sinkM resultTy
emitOp $ ThrowError resultTy'
Just alt -> applyNaryAbs alt (SubstVal <$> vs) >>= emitBlock
Just alt -> applyAbs alt (SubstVal v) >>= emitBlock

-- converts from the ordinal index used in the core IR to the more complicated
-- `CaseAltIndex` used in the surface IR.
Expand All @@ -1226,11 +1226,11 @@ buildMonomorphicCase
=> [IndexedAlt n] -> Atom n -> Type n -> m n (Atom n)
buildMonomorphicCase alts scrut resultTy = do
scrutTy <- getType scrut
buildCase scrut resultTy \i vs -> do
buildCase scrut resultTy \i v -> do
ListE alts' <- sinkM $ ListE alts
scrutTy' <- sinkM scrutTy
resultTy' <- sinkM resultTy
buildNthOrderedAlt alts' scrutTy' resultTy' i vs
buildNthOrderedAlt alts' scrutTy' resultTy' i v

buildSortedCase :: (Fallible1 m, Builder m, Emits n)
=> Atom n -> [IndexedAlt n] -> Type n
Expand All @@ -1245,7 +1245,7 @@ buildSortedCase scrut alts resultTy = do
-- Single constructor ADTs are not sum types, so elide the case.
[_] -> do
let [IndexedAlt _ alt] = alts
emitBlock =<< applyNaryAbs alt [SubstVal $ unwrapNewtype scrut]
emitBlock =<< applyAbs alt (SubstVal $ unwrapNewtype scrut)
_ -> liftEmitBuilder $ buildMonomorphicCase alts scrut resultTy
VariantTy (Ext types tailName) -> do
case filter isVariantTailAlt alts of
Expand Down Expand Up @@ -1273,7 +1273,7 @@ buildSortedCase scrut alts resultTy = do
resultTy' <- sinkM resultTy
liftEmitBuilder $ buildMonomorphicCase alts' v resultTy')
(\v -> do tailAlt' <- sinkM tailAlt
applyNaryAbs tailAlt' [SubstVal v] >>= emitBlock )
applyAbs tailAlt' (SubstVal v) >>= emitBlock )
_ -> throw TypeErr "Can't specify more than one variant tail pattern."
_ -> fail $ "Unexpected case expression type: " <> pprint scrutTy

Expand Down Expand Up @@ -1716,7 +1716,7 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat
++ " got " ++ show (nestLength ps)
(params, repTy') <- inferParams (Abs paramBs repTy)
constrainEq scrutineeTy $ TypeCon sourceName dataDefName params
buildUnaryAltInf repTy' \arg -> do
buildAltInf repTy' \arg -> do
args <- forM idxs $ init |> NE.nonEmpty |> \case
Nothing -> return arg
Just idxs' -> emit $ Atom $ ProjectElt idxs' arg
Expand All @@ -1729,15 +1729,15 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat
let patTypes = prevTys <> labeledSingleton label ty
let extPatTypes = Ext patTypes $ Just rest
constrainEq scrutineeTy $ VariantTy extPatTypes
buildUnaryAltInf ty \x ->
buildAltInf ty \x ->
bindLamPat p x cont
UPatVariantLift labels p -> do
prevTys <- mapM (const $ freshType TyKind) labels
rest <- freshInferenceName LabeledRowKind
let extPatTypes = Ext prevTys $ Just rest
constrainEq scrutineeTy $ VariantTy extPatTypes
let ty = VariantTy $ Ext NoLabeledItems $ Just rest
buildUnaryAltInf ty \x ->
buildAltInf ty \x ->
bindLamPat p x cont
_ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern"

Expand Down Expand Up @@ -2802,31 +2802,16 @@ buildTabPiInf hint ty body = do
withAllowedEffects Pure $ body v
return $ TabPiType (b:>ty) resultTy

buildUnaryAltInf
buildAltInf
:: EmitsInf n
=> Type n
-> (forall l. (EmitsBoth l, Ext n l) => AtomName l -> InfererM i l (Atom l))
-> InfererM i n (Alt n)
buildUnaryAltInf ty body = do
bs <- liftBuilder $ singletonBinderNest noHint ty
buildAltInf bs \[v] -> body v

buildAltInf
:: EmitsInf n
=> EmptyAbs (Nest Binder) n
-> (forall l. (EmitsBoth l, Ext n l) => [AtomName l] -> InfererM i l (Atom l))
-> InfererM i n (Alt n)
buildAltInf (Abs Empty UnitE) body =
Abs Empty <$> buildBlockInf (body [])
buildAltInf (Abs (Nest (b:>ty) bs) UnitE) body = do
Abs b' (Abs bs' body') <-
buildAbsInf (getNameHint b) ty \v -> do
ab <- sinkM $ Abs b (EmptyAbs bs)
bs' <- applyAbs ab v
buildAltInf bs' \vs -> do
v' <- sinkM v
body $ v' : vs
return $ Abs (Nest b' bs') body'
buildAltInf ty body = do
buildAbsInf noHint ty \v ->
buildBlockInf do
Distinct <- getDistinct
body $ sink v

-- === EmitsInf predicate ===

Expand Down
4 changes: 2 additions & 2 deletions src/lib/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ evalExpr expr = case expr of
case trySelectBranch e' of
Nothing -> error "branch should be chosen at this point"
Just (con, arg) -> do
Abs bs body <- return $ alts !! con
extendSubst (bs @@> [SubstVal arg]) $ evalBlock body
Abs b body <- return $ alts !! con
extendSubst (b @> SubstVal arg) $ evalBlock body
Hof hof -> case hof of
RunIO (Lam (LamExpr b body)) ->
extendSubst (b @> SubstVal UnitTy) $
Expand Down
8 changes: 4 additions & 4 deletions src/lib/Linearize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,10 @@ linearizeExpr expr = case expr of
resultTangentType <- tangentType resultTy'
resultTyWithTangent <- PairTy <$> substM resultTy
<*> tangentFunType resultTangentType
(ans, linLam) <- fromPair =<< buildCase e' resultTyWithTangent \i xs -> do
xs' <- mapM emitAtomToName xs
Abs bs body <- return $ alts !! i
extendSubst (bs @@> xs') $ withTangentFunAsLambda $ linearizeBlock body
(ans, linLam) <- fromPair =<< buildCase e' resultTyWithTangent \i x -> do
x' <- emitAtomToName x
Abs b body <- return $ alts !! i
extendSubst (b @> x') $ withTangentFunAsLambda $ linearizeBlock body
return $ WithTangent ans do
applyLinToTangents $ sink linLam

Expand Down
3 changes: 1 addition & 2 deletions src/lib/PPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,7 @@ prettyPrecCase name e alts effs = atPrec LowestPrec $
effectLine row = hardline <> "case annotated with effects" <+> p row

prettyAlt :: PrettyE e => AltP e n -> Doc ann
prettyAlt (Abs bs body) = hsep (map prettyBinderNoAnn bs') <+> "->" <> nest 2 (p body)
where bs' = fromNest bs
prettyAlt (Abs b body) = prettyBinderNoAnn b <+> "->" <> nest 2 (p body)

prettyBinderNoAnn :: Binder n l -> Doc ann
prettyBinderNoAnn (b:>_) = p b
Expand Down
Loading

0 comments on commit 5627a37

Please sign in to comment.