Skip to content

Commit

Permalink
Remove the binder nest from DataConDef
Browse files Browse the repository at this point in the history
It's mostly redundant with the other fields, so there's little reason to
keep it around.
  • Loading branch information
apaszke committed Sep 7, 2022
1 parent 5627a37 commit 34595d1
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 50 deletions.
5 changes: 3 additions & 2 deletions src/lib/Builder.hs
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,8 @@ fromNewtypeWrapper ty = do
TypeCon _ defName params <- return ty
def <- lookupDataDef defName
[con] <- instantiateDataDef def params
DataConDef _ (EmptyAbs (Nest (_:>wrappedTy) Empty)) _ _ <- return con
-- Single field constructors are represented by their field
DataConDef _ wrappedTy [_] <- return con
return wrappedTy

tangentBaseMonoidFor :: Builder m => Type n -> m n (BaseMonoid n)
Expand Down Expand Up @@ -1048,7 +1049,7 @@ emitInstanceDef instanceDef@(InstanceDef className _ _ _) = do
emitDataConName :: (Mut n, TopBuilder m) => DataDefName n -> Int -> Atom n -> m n (Name DataConNameC n)
emitDataConName dataDefName conIdx conAtom = do
DataDef _ _ dataCons <- lookupDataDef dataDefName
let (DataConDef name _ _ _) = dataCons !! conIdx
let (DataConDef name _ _) = dataCons !! conIdx
emitBinding (getNameHint name) $ DataConBinding dataDefName conIdx conAtom

zipNest :: (forall ii ii'. a -> b ii ii' -> b' ii ii')
Expand Down
15 changes: 5 additions & 10 deletions src/lib/CheckType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ instance HasType Atom where
-- Newtypes
TypeCon _ defName params | i == 0 -> do
def <- lookupDataDef defName
[DataConDef _ _ repTy _] <- checkedInstantiateDataDef def params
[DataConDef _ repTy _] <- checkedInstantiateDataDef def params
return repTy
TC (Fin _) | i == 0 -> return IdxRepTy
StaticRecordTy types | i == 0 -> return $ ProdTy $ toList types
Expand Down Expand Up @@ -690,7 +690,8 @@ typeCheckPrimOp op = case op of
case t' of
TypeCon _ dataDefName (DataDefParams [] []) -> do
DataDef _ _ dataConDefs <- lookupDataDef dataDefName
forM_ dataConDefs \(DataConDef _ (Abs binders _) _ _) -> checkEmptyNest binders
forM_ dataConDefs \(DataConDef _ _ idxs) ->
unless (null idxs) $ throw TypeErr "Not empty"
VariantTy _ -> return () -- TODO: check empty payload
SumTy cases -> forM_ cases \cty -> checkAlphaEq cty UnitTy
_ -> error $ "Not a sum type: " ++ pprint t'
Expand Down Expand Up @@ -828,12 +829,6 @@ checkRWSAction rws f = do
resultTy' <- liftHoistExcept $ hoist (PairB regionBinder refBinder) resultTy
return (resultTy', referentTy')

-- Having this as a separate helper function helps with "'b0' is untouchable" errors
-- from GADT+monad type inference.
checkEmptyNest :: Fallible m => Nest b n l -> m ()
checkEmptyNest Empty = return ()
checkEmptyNest _ = throw TypeErr "Not empty"

checkCase :: Typer m => HasType body
=> Atom i -> [AltP body i] -> Type i -> EffectRow i -> m i o (Type o)
checkCase scrut alts resultTy effs = do
Expand All @@ -850,7 +845,7 @@ checkCaseAltsBinderTys ty = case ty of
TypeCon _ defName params -> do
def <- lookupDataDef defName
cons <- checkedInstantiateDataDef def params
return [repTy | DataConDef _ _ repTy _ <- cons]
return [repTy | DataConDef _ repTy _ <- cons]
SumTy types -> return types
VariantTy (NoExt types) -> return $ toList types
VariantTy _ -> fail "Can't pattern-match partially-known variants"
Expand Down Expand Up @@ -1261,7 +1256,7 @@ checkDataLike ty = case ty of
params' <- substM params
def <- lookupDataDef =<< substM defName
dataCons <- instantiateDataDef def params'
dropSubst $ forM_ dataCons \(DataConDef _ _ repTy _) -> checkDataLike repTy
dropSubst $ forM_ dataCons \(DataConDef _ repTy _) -> checkDataLike repTy
TC con -> case con of
BaseType _ -> return ()
ProdType as -> mapM_ recur as
Expand Down
4 changes: 2 additions & 2 deletions src/lib/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -602,9 +602,9 @@ fromNonDepTabType ty = do
return (ixTy, resultTy')

nonDepDataConTys :: DataConDef n -> Maybe [Type n]
nonDepDataConTys (DataConDef _ (Abs binders UnitE) repTy _) =
nonDepDataConTys (DataConDef _ repTy idxs) =
case repTy of
ProdTy tys | nestLength binders == length tys -> Just tys
ProdTy tys | length idxs == length tys -> Just tys
_ -> Nothing

infixr 1 ?-->
Expand Down
46 changes: 27 additions & 19 deletions src/lib/Inference.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,11 @@ data UDeclInferenceResult e n =
inferTopUDecl :: (Mut n, Fallible1 m, TopBuilder m, SinkableE e, SubstE Name e)
=> UDecl n l -> e l -> m n (UDeclInferenceResult e n)
inferTopUDecl (UDataDefDecl def tc dcs) result = do
def' <- liftInfererM $ solveLocal $ inferDataDef def
PairE def' (Abs ddbs' (ListE cbs')) <- liftInfererM $ solveLocal $ inferDataDef def
defName <- emitDataDef def'
tc' <- emitTyConName defName =<< tyConDefAsAtom defName
dcs' <- forM (iota (nestLength dcs)) \i ->
emitDataConName defName i =<< dataConDefAsAtom defName i
dcs' <- forM (enumerate cbs') \(i, cbs'') ->
emitDataConName defName i =<< dataConDefAsAtom (Abs ddbs' cbs'') defName i
let subst = tc @> tc' <.> dcs @@> dcs'
UDeclResultDone <$> applySubst subst result
inferTopUDecl (UInterface paramBs superclasses methodTys className methodNames) result = do
Expand Down Expand Up @@ -1376,21 +1376,24 @@ tyConDefAsAtom defName = liftBuilder do
buildTyConLam defName PlainArrow \sourceName params ->
return $ TypeCon sourceName (sink defName) params

dataConDefAsAtom :: EnvReader m => DataDefName n -> Int -> m n (Atom n)
dataConDefAsAtom defName conIx = liftBuilder do
dataConDefAsAtom :: EnvReader m
=> Abs DataDefBinders (EmptyAbs (Nest Binder)) n
-> DataDefName n -> Int -> m n (Atom n)
dataConDefAsAtom bsAbs defName conIx = liftBuilder do
buildTyConLam defName ImplicitArrow \_ params -> do
defName' <- sinkM defName
def@(DataDef sourceName _ _ ) <- lookupDataDef defName'
def@(DataDef sourceName _ _) <- lookupDataDef defName'
conDefs <- instantiateDataDef def params
DataConDef _ (EmptyAbs conArgBinders) conRep _ <- return $ conDefs !! conIx
let DataConDef _ conRep _ = conDefs !! conIx
Abs conArgBinders UnitE <- applyDataConAbs (sink bsAbs) params
buildPureNaryLam (EmptyAbs $ binderNestAsPiNest PlainArrow conArgBinders) \conArgs -> do
conProd <- buildDataCon (sink conRep) $ Var <$> conArgs
return $ Con $ Newtype (sink $ TypeCon sourceName defName' params) $
case conDefs of
[] -> error "unreachable"
[_] -> conProd
_ -> SumVal conTys conIx conProd
where conTys = sinkList $ conDefs <&> \(DataConDef _ _ rty _) -> rty
where conTys = sinkList $ conDefs <&> \(DataConDef _ rty _) -> rty

buildDataCon :: EnvReader m => Type n -> [Atom n] -> m n (Atom n)
buildDataCon repTy topArgs = wrap repTy topArgs
Expand All @@ -1417,20 +1420,25 @@ binderNestAsPiNest arr = \case
Empty -> Empty
Nest (b:>ty) rest -> Nest (PiBinder b ty arr) $ binderNestAsPiNest arr rest

inferDataDef :: EmitsInf o => UDataDef i -> InfererM i o (DataDef o)
inferDataDef :: EmitsInf o => UDataDef i
-> InfererM i o (PairE DataDef (Abs DataDefBinders (ListE (EmptyAbs (Nest Binder)))) o)
inferDataDef (UDataDef (tyConName, paramBs) clsBs dataCons) = do
Abs paramBs' (Abs clsBs' (ListE dataCons')) <-
Abs paramBs' (Abs clsBs' (ListE dataConsWithBs')) <-
withNestedUBinders paramBs id \_ -> do
withNestedUBinders clsBs (LamBound . LamBinding ClassArrow) \_ ->
ListE <$> mapM inferDataCon dataCons
return $ DataDef tyConName (DataDefBinders paramBs' clsBs') dataCons'
let (dataCons', bs') = unzip $ fromPairE <$> dataConsWithBs'
let ddbs = (DataDefBinders paramBs' clsBs')
return $ PairE (DataDef tyConName ddbs dataCons')
(Abs ddbs $ ListE bs')

inferDataCon :: EmitsInf o
=> (SourceName, UDataDefTrail i) -> InfererM i o (DataConDef o)
=> (SourceName, UDataDefTrail i)
-> InfererM i o (PairE DataConDef (EmptyAbs (Nest Binder)) o)
inferDataCon (sourceName, UDataDefTrail argBs) = do
argBs' <- checkUBinders (EmptyAbs argBs)
let (repTy, projIdxs) = dataConRepTy argBs'
return (DataConDef sourceName argBs' repTy projIdxs)
return $ PairE (DataConDef sourceName repTy projIdxs) argBs'

dataConRepTy :: EmptyAbs (Nest Binder) n -> (Type n, [[Int]])
dataConRepTy (Abs topBs UnitE) = case topBs of
Expand Down Expand Up @@ -1710,9 +1718,9 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat
UPatCon ~(InternalName _ conName) ps -> do
(dataDefName, con) <- substM conName >>= getDataCon
DataDef sourceName paramBs cons <- lookupDataDef dataDefName
DataConDef _ (EmptyAbs argBs) repTy idxs <- return $ cons !! con
when (nestLength argBs /= nestLength ps) $ throw TypeErr $
"Unexpected number of pattern binders. Expected " ++ show (nestLength argBs)
DataConDef _ repTy idxs <- return $ cons !! con
when (length idxs /= nestLength ps) $ throw TypeErr $
"Unexpected number of pattern binders. Expected " ++ show (length idxs)
++ " got " ++ show (nestLength ps)
(params, repTy') <- inferParams (Abs paramBs repTy)
constrainEq scrutineeTy $ TypeCon sourceName dataDefName params
Expand Down Expand Up @@ -1785,9 +1793,9 @@ bindLamPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of
(dataDefName, _) <- getDataCon =<< substM conName
DataDef sourceName paramBs cons <- lookupDataDef dataDefName
case cons of
[DataConDef _ (EmptyAbs argBs) _ _] -> do
when (nestLength argBs /= nestLength ps) $ throw TypeErr $
"Unexpected number of pattern binders. Expected " ++ show (nestLength argBs)
[DataConDef _ _ idxs] -> do
when (length idxs /= nestLength ps) $ throw TypeErr $
"Unexpected number of pattern binders. Expected " ++ show (length idxs)
++ " got " ++ show (nestLength ps)
(params, UnitE) <- inferParams (Abs paramBs UnitE)
constrainVarTy v $ TypeCon sourceName dataDefName params
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Interpreter.hs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ matchUPat (WithSrcB _ pat) x = do
(UPatCon _ ps, Con (Newtype (TypeCon _ dataDefName _) _)) -> do
DataDef _ _ cons <- lookupDataDef dataDefName
case cons of
[DataConDef _ _ _ idxs] -> matchUPats ps [getProjection ix x' | ix <- idxs]
[DataConDef _ _ idxs] -> matchUPats ps [getProjection ix x' | ix <- idxs]
_ -> error "Expected a single ADt constructor"
(UPatPair (PairB p1 p2), PairVal x1 x2) -> do
matchUPat p1 x1 >>= (`followedByFrag` matchUPat p2 x2)
Expand Down
4 changes: 2 additions & 2 deletions src/lib/PPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,8 @@ instance Pretty (DataDef n) where
"data" <+> p name <+> p bs <> prettyLines cons

instance Pretty (DataConDef n) where
pretty (DataConDef name bs _ _) =
p name <+> ":" <+> p bs
pretty (DataConDef name repTy _) =
p name <+> ":" <+> p repTy

instance Pretty (ClassDef n) where
pretty (ClassDef classSourceName methodNames params superclasses methodTys) =
Expand Down
23 changes: 15 additions & 8 deletions src/lib/QueryType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ module QueryType (
caseAltsBinderTys, depPairLeftTy, extendEffect,
getAppType, getTabAppType, getBaseMonoidType, getReferentTy,
getMethodIndex,
instantiateDataDef, dataDefRep,
instantiateDataDef, applyDataConAbs, dataDefRep,
instantiateNaryPi, instantiateDepPairTy, instantiatePi, instantiateTabPi,
litType, lamExprTy,
numNaryPiArgs, naryLamExprType, specializedFunType,
Expand Down Expand Up @@ -80,7 +80,7 @@ caseAltsBinderTys ty = case ty of
TypeCon _ defName params -> do
def <- lookupDataDef defName
cons <- instantiateDataDef def params
return [repTy | DataConDef _ _ repTy _ <- cons]
return [repTy | DataConDef _ repTy _ <- cons]
SumTy types -> return types
VariantTy (NoExt types) -> return $ toList types
VariantTy _ -> fail "Can't pattern-match partially-known variants"
Expand Down Expand Up @@ -125,17 +125,24 @@ getMethodIndex className methodSourceName = do
{-# INLINE getMethodIndex #-}

instantiateDataDef :: ScopeReader m => DataDef n -> DataDefParams n -> m n [DataConDef n]
instantiateDataDef (DataDef _ (DataDefBinders bs1 bs2) cons) (DataDefParams xs1 xs2) = do
fromListE <$> applySubst (bs1 @@> (SubstVal <$> xs1) <.> bs2 @@> (SubstVal <$> xs2)) (ListE cons)
instantiateDataDef (DataDef _ bs cons) params = do
fromListE <$> applyDataConAbs (Abs bs $ ListE cons) params
{-# INLINE instantiateDataDef #-}

applyDataConAbs :: (SubstE AtomSubstVal e, SinkableE e, ScopeReader m)
=> Abs DataDefBinders e n -> DataDefParams n -> m n (e n)
applyDataConAbs (Abs (DataDefBinders bs1 bs2) e) (DataDefParams xs1 xs2) = do
let paramsSubst = bs1 @@> (SubstVal <$> xs1) <.> bs2 @@> (SubstVal <$> xs2)
applySubst paramsSubst e
{-# INLINE applyDataConAbs #-}

-- Returns a representation type (type of an TypeCon-typed Newtype payload)
-- given a list of instantiated DataConDefs.
dataDefRep :: [DataConDef n] -> Type n
dataDefRep = \case
[] -> error "unreachable" -- There's no representation for a void type
[DataConDef _ _ ty _] -> ty
tys -> SumTy $ tys <&> \(DataConDef _ _ ty _) -> ty
[DataConDef _ ty _] -> ty
tys -> SumTy $ tys <&> \(DataConDef _ ty _) -> ty

instantiateNaryPi :: EnvReader m => NaryPiType n -> [Atom n] -> m n (NaryPiType n)
instantiateNaryPi (NaryPiType bs eff resultTy) args = do
Expand Down Expand Up @@ -211,7 +218,7 @@ projectionIndices ty = case ty of
TypeCon _ defName _ -> do
DataDef _ _ cons <- lookupDataDef defName
case cons of
[DataConDef _ _ _ idxs] -> return idxs
[DataConDef _ _ idxs] -> return idxs
_ -> unsupported
StaticRecordTy types -> return $ iota (length types) <&> (:[0])
ProdTy tys -> return $ iota (length tys) <&> (:[])
Expand Down Expand Up @@ -356,7 +363,7 @@ instance HasType Atom where
StaticRecordTy types | i == 0 -> return $ ProdTy $ toList types
TypeCon _ defName params | i == 0 -> do
def <- lookupDataDef defName
[DataConDef _ _ repTy _] <- instantiateDataDef def params
[DataConDef _ repTy _] <- instantiateDataDef def params
return repTy
RecordTy _ -> throw CompilerErr "Can't project partially-known records"
Var _ -> throw CompilerErr $ "Tried to project value of unreduced type " <> pprint ty
Expand Down
2 changes: 1 addition & 1 deletion src/lib/Serialize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ prettyVal val = case val of
prettyData :: (MonadIO1 m, EnvReader m, Fallible1 m) => DataDefName n -> Int -> Atom n -> m n (Doc ann)
prettyData dataDefName t rep = do
DataDef _ _ dataCons <- lookupDataDef dataDefName
DataConDef conName _ _ idxs <- return $ dataCons !! t
DataConDef conName _ idxs <- return $ dataCons !! t
prettyArgs <- forM idxs \ix -> prettyVal $ getProjection (init ix) rep
return $ case prettyArgs of
[] -> pretty conName
Expand Down
9 changes: 4 additions & 5 deletions src/lib/Types/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,10 @@ data DataDef n where
-- binder name is in UExpr and Env
DataDef :: SourceName -> DataDefBinders n l -> [DataConDef l] -> DataDef n

-- TODO: The binder nest should be unnecessary. Try to get rid of it.
data DataConDef n =
-- Name for pretty printing, constructor elements, representation type,
-- list of projection indices that recovers elements from the representation.
DataConDef SourceName (EmptyAbs (Nest Binder) n) (Type n) [[Int]]
DataConDef SourceName (Type n) [[Int]]
deriving (Show, Generic)

data DataDefBinders n l where
Expand Down Expand Up @@ -877,10 +876,10 @@ instance AlphaEqE DataDef
instance AlphaHashableE DataDef

instance GenericE DataConDef where
type RepE DataConDef = (LiftE (SourceName, [[Int]])) `PairE` (Abs (Nest Binder) UnitE) `PairE` Type
fromE (DataConDef name ab repTy idxs) = (LiftE (name, idxs)) `PairE` ab `PairE` repTy
type RepE DataConDef = (LiftE (SourceName, [[Int]])) `PairE` Type
fromE (DataConDef name repTy idxs) = (LiftE (name, idxs)) `PairE` repTy
{-# INLINE fromE #-}
toE ((LiftE (name, idxs)) `PairE` ab `PairE` repTy) = DataConDef name ab repTy idxs
toE ((LiftE (name, idxs)) `PairE` repTy) = DataConDef name repTy idxs
{-# INLINE toE #-}
instance SinkableE DataConDef
instance HoistableE DataConDef
Expand Down

0 comments on commit 34595d1

Please sign in to comment.