From 34595d138f3948c307eac1d7bd56e5422b438a08 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 6 Sep 2022 11:37:57 +0000 Subject: [PATCH] Remove the binder nest from DataConDef It's mostly redundant with the other fields, so there's little reason to keep it around. --- src/lib/Builder.hs | 5 +++-- src/lib/CheckType.hs | 15 +++++--------- src/lib/Core.hs | 4 ++-- src/lib/Inference.hs | 46 +++++++++++++++++++++++++----------------- src/lib/Interpreter.hs | 2 +- src/lib/PPrint.hs | 4 ++-- src/lib/QueryType.hs | 23 +++++++++++++-------- src/lib/Serialize.hs | 2 +- src/lib/Types/Core.hs | 9 ++++----- 9 files changed, 60 insertions(+), 50 deletions(-) diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index 09cb0bb2f..5a793816a 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -986,7 +986,8 @@ fromNewtypeWrapper ty = do TypeCon _ defName params <- return ty def <- lookupDataDef defName [con] <- instantiateDataDef def params - DataConDef _ (EmptyAbs (Nest (_:>wrappedTy) Empty)) _ _ <- return con + -- Single field constructors are represented by their field + DataConDef _ wrappedTy [_] <- return con return wrappedTy tangentBaseMonoidFor :: Builder m => Type n -> m n (BaseMonoid n) @@ -1048,7 +1049,7 @@ emitInstanceDef instanceDef@(InstanceDef className _ _ _) = do emitDataConName :: (Mut n, TopBuilder m) => DataDefName n -> Int -> Atom n -> m n (Name DataConNameC n) emitDataConName dataDefName conIdx conAtom = do DataDef _ _ dataCons <- lookupDataDef dataDefName - let (DataConDef name _ _ _) = dataCons !! conIdx + let (DataConDef name _ _) = dataCons !! conIdx emitBinding (getNameHint name) $ DataConBinding dataDefName conIdx conAtom zipNest :: (forall ii ii'. a -> b ii ii' -> b' ii ii') diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index 51aecd81f..524e9989f 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -285,7 +285,7 @@ instance HasType Atom where -- Newtypes TypeCon _ defName params | i == 0 -> do def <- lookupDataDef defName - [DataConDef _ _ repTy _] <- checkedInstantiateDataDef def params + [DataConDef _ repTy _] <- checkedInstantiateDataDef def params return repTy TC (Fin _) | i == 0 -> return IdxRepTy StaticRecordTy types | i == 0 -> return $ ProdTy $ toList types @@ -690,7 +690,8 @@ typeCheckPrimOp op = case op of case t' of TypeCon _ dataDefName (DataDefParams [] []) -> do DataDef _ _ dataConDefs <- lookupDataDef dataDefName - forM_ dataConDefs \(DataConDef _ (Abs binders _) _ _) -> checkEmptyNest binders + forM_ dataConDefs \(DataConDef _ _ idxs) -> + unless (null idxs) $ throw TypeErr "Not empty" VariantTy _ -> return () -- TODO: check empty payload SumTy cases -> forM_ cases \cty -> checkAlphaEq cty UnitTy _ -> error $ "Not a sum type: " ++ pprint t' @@ -828,12 +829,6 @@ checkRWSAction rws f = do resultTy' <- liftHoistExcept $ hoist (PairB regionBinder refBinder) resultTy return (resultTy', referentTy') --- Having this as a separate helper function helps with "'b0' is untouchable" errors --- from GADT+monad type inference. -checkEmptyNest :: Fallible m => Nest b n l -> m () -checkEmptyNest Empty = return () -checkEmptyNest _ = throw TypeErr "Not empty" - checkCase :: Typer m => HasType body => Atom i -> [AltP body i] -> Type i -> EffectRow i -> m i o (Type o) checkCase scrut alts resultTy effs = do @@ -850,7 +845,7 @@ checkCaseAltsBinderTys ty = case ty of TypeCon _ defName params -> do def <- lookupDataDef defName cons <- checkedInstantiateDataDef def params - return [repTy | DataConDef _ _ repTy _ <- cons] + return [repTy | DataConDef _ repTy _ <- cons] SumTy types -> return types VariantTy (NoExt types) -> return $ toList types VariantTy _ -> fail "Can't pattern-match partially-known variants" @@ -1261,7 +1256,7 @@ checkDataLike ty = case ty of params' <- substM params def <- lookupDataDef =<< substM defName dataCons <- instantiateDataDef def params' - dropSubst $ forM_ dataCons \(DataConDef _ _ repTy _) -> checkDataLike repTy + dropSubst $ forM_ dataCons \(DataConDef _ repTy _) -> checkDataLike repTy TC con -> case con of BaseType _ -> return () ProdType as -> mapM_ recur as diff --git a/src/lib/Core.hs b/src/lib/Core.hs index 998c5f508..72a8331b9 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -602,9 +602,9 @@ fromNonDepTabType ty = do return (ixTy, resultTy') nonDepDataConTys :: DataConDef n -> Maybe [Type n] -nonDepDataConTys (DataConDef _ (Abs binders UnitE) repTy _) = +nonDepDataConTys (DataConDef _ repTy idxs) = case repTy of - ProdTy tys | nestLength binders == length tys -> Just tys + ProdTy tys | length idxs == length tys -> Just tys _ -> Nothing infixr 1 ?--> diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index f9d6c6a00..64d8ce84c 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -68,11 +68,11 @@ data UDeclInferenceResult e n = inferTopUDecl :: (Mut n, Fallible1 m, TopBuilder m, SinkableE e, SubstE Name e) => UDecl n l -> e l -> m n (UDeclInferenceResult e n) inferTopUDecl (UDataDefDecl def tc dcs) result = do - def' <- liftInfererM $ solveLocal $ inferDataDef def + PairE def' (Abs ddbs' (ListE cbs')) <- liftInfererM $ solveLocal $ inferDataDef def defName <- emitDataDef def' tc' <- emitTyConName defName =<< tyConDefAsAtom defName - dcs' <- forM (iota (nestLength dcs)) \i -> - emitDataConName defName i =<< dataConDefAsAtom defName i + dcs' <- forM (enumerate cbs') \(i, cbs'') -> + emitDataConName defName i =<< dataConDefAsAtom (Abs ddbs' cbs'') defName i let subst = tc @> tc' <.> dcs @@> dcs' UDeclResultDone <$> applySubst subst result inferTopUDecl (UInterface paramBs superclasses methodTys className methodNames) result = do @@ -1376,13 +1376,16 @@ tyConDefAsAtom defName = liftBuilder do buildTyConLam defName PlainArrow \sourceName params -> return $ TypeCon sourceName (sink defName) params -dataConDefAsAtom :: EnvReader m => DataDefName n -> Int -> m n (Atom n) -dataConDefAsAtom defName conIx = liftBuilder do +dataConDefAsAtom :: EnvReader m + => Abs DataDefBinders (EmptyAbs (Nest Binder)) n + -> DataDefName n -> Int -> m n (Atom n) +dataConDefAsAtom bsAbs defName conIx = liftBuilder do buildTyConLam defName ImplicitArrow \_ params -> do defName' <- sinkM defName - def@(DataDef sourceName _ _ ) <- lookupDataDef defName' + def@(DataDef sourceName _ _) <- lookupDataDef defName' conDefs <- instantiateDataDef def params - DataConDef _ (EmptyAbs conArgBinders) conRep _ <- return $ conDefs !! conIx + let DataConDef _ conRep _ = conDefs !! conIx + Abs conArgBinders UnitE <- applyDataConAbs (sink bsAbs) params buildPureNaryLam (EmptyAbs $ binderNestAsPiNest PlainArrow conArgBinders) \conArgs -> do conProd <- buildDataCon (sink conRep) $ Var <$> conArgs return $ Con $ Newtype (sink $ TypeCon sourceName defName' params) $ @@ -1390,7 +1393,7 @@ dataConDefAsAtom defName conIx = liftBuilder do [] -> error "unreachable" [_] -> conProd _ -> SumVal conTys conIx conProd - where conTys = sinkList $ conDefs <&> \(DataConDef _ _ rty _) -> rty + where conTys = sinkList $ conDefs <&> \(DataConDef _ rty _) -> rty buildDataCon :: EnvReader m => Type n -> [Atom n] -> m n (Atom n) buildDataCon repTy topArgs = wrap repTy topArgs @@ -1417,20 +1420,25 @@ binderNestAsPiNest arr = \case Empty -> Empty Nest (b:>ty) rest -> Nest (PiBinder b ty arr) $ binderNestAsPiNest arr rest -inferDataDef :: EmitsInf o => UDataDef i -> InfererM i o (DataDef o) +inferDataDef :: EmitsInf o => UDataDef i + -> InfererM i o (PairE DataDef (Abs DataDefBinders (ListE (EmptyAbs (Nest Binder)))) o) inferDataDef (UDataDef (tyConName, paramBs) clsBs dataCons) = do - Abs paramBs' (Abs clsBs' (ListE dataCons')) <- + Abs paramBs' (Abs clsBs' (ListE dataConsWithBs')) <- withNestedUBinders paramBs id \_ -> do withNestedUBinders clsBs (LamBound . LamBinding ClassArrow) \_ -> ListE <$> mapM inferDataCon dataCons - return $ DataDef tyConName (DataDefBinders paramBs' clsBs') dataCons' + let (dataCons', bs') = unzip $ fromPairE <$> dataConsWithBs' + let ddbs = (DataDefBinders paramBs' clsBs') + return $ PairE (DataDef tyConName ddbs dataCons') + (Abs ddbs $ ListE bs') inferDataCon :: EmitsInf o - => (SourceName, UDataDefTrail i) -> InfererM i o (DataConDef o) + => (SourceName, UDataDefTrail i) + -> InfererM i o (PairE DataConDef (EmptyAbs (Nest Binder)) o) inferDataCon (sourceName, UDataDefTrail argBs) = do argBs' <- checkUBinders (EmptyAbs argBs) let (repTy, projIdxs) = dataConRepTy argBs' - return (DataConDef sourceName argBs' repTy projIdxs) + return $ PairE (DataConDef sourceName repTy projIdxs) argBs' dataConRepTy :: EmptyAbs (Nest Binder) n -> (Type n, [[Int]]) dataConRepTy (Abs topBs UnitE) = case topBs of @@ -1710,9 +1718,9 @@ checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat UPatCon ~(InternalName _ conName) ps -> do (dataDefName, con) <- substM conName >>= getDataCon DataDef sourceName paramBs cons <- lookupDataDef dataDefName - DataConDef _ (EmptyAbs argBs) repTy idxs <- return $ cons !! con - when (nestLength argBs /= nestLength ps) $ throw TypeErr $ - "Unexpected number of pattern binders. Expected " ++ show (nestLength argBs) + DataConDef _ repTy idxs <- return $ cons !! con + when (length idxs /= nestLength ps) $ throw TypeErr $ + "Unexpected number of pattern binders. Expected " ++ show (length idxs) ++ " got " ++ show (nestLength ps) (params, repTy') <- inferParams (Abs paramBs repTy) constrainEq scrutineeTy $ TypeCon sourceName dataDefName params @@ -1785,9 +1793,9 @@ bindLamPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of (dataDefName, _) <- getDataCon =<< substM conName DataDef sourceName paramBs cons <- lookupDataDef dataDefName case cons of - [DataConDef _ (EmptyAbs argBs) _ _] -> do - when (nestLength argBs /= nestLength ps) $ throw TypeErr $ - "Unexpected number of pattern binders. Expected " ++ show (nestLength argBs) + [DataConDef _ _ idxs] -> do + when (length idxs /= nestLength ps) $ throw TypeErr $ + "Unexpected number of pattern binders. Expected " ++ show (length idxs) ++ " got " ++ show (nestLength ps) (params, UnitE) <- inferParams (Abs paramBs UnitE) constrainVarTy v $ TypeCon sourceName dataDefName params diff --git a/src/lib/Interpreter.hs b/src/lib/Interpreter.hs index 43e5ece1b..989c27689 100644 --- a/src/lib/Interpreter.hs +++ b/src/lib/Interpreter.hs @@ -249,7 +249,7 @@ matchUPat (WithSrcB _ pat) x = do (UPatCon _ ps, Con (Newtype (TypeCon _ dataDefName _) _)) -> do DataDef _ _ cons <- lookupDataDef dataDefName case cons of - [DataConDef _ _ _ idxs] -> matchUPats ps [getProjection ix x' | ix <- idxs] + [DataConDef _ _ idxs] -> matchUPats ps [getProjection ix x' | ix <- idxs] _ -> error "Expected a single ADt constructor" (UPatPair (PairB p1 p2), PairVal x1 x2) -> do matchUPat p1 x1 >>= (`followedByFrag` matchUPat p2 x2) diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 5392d405b..960dd5b50 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -488,8 +488,8 @@ instance Pretty (DataDef n) where "data" <+> p name <+> p bs <> prettyLines cons instance Pretty (DataConDef n) where - pretty (DataConDef name bs _ _) = - p name <+> ":" <+> p bs + pretty (DataConDef name repTy _) = + p name <+> ":" <+> p repTy instance Pretty (ClassDef n) where pretty (ClassDef classSourceName methodNames params superclasses methodTys) = diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 3c744a57c..2ab9ddf29 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -11,7 +11,7 @@ module QueryType ( caseAltsBinderTys, depPairLeftTy, extendEffect, getAppType, getTabAppType, getBaseMonoidType, getReferentTy, getMethodIndex, - instantiateDataDef, dataDefRep, + instantiateDataDef, applyDataConAbs, dataDefRep, instantiateNaryPi, instantiateDepPairTy, instantiatePi, instantiateTabPi, litType, lamExprTy, numNaryPiArgs, naryLamExprType, specializedFunType, @@ -80,7 +80,7 @@ caseAltsBinderTys ty = case ty of TypeCon _ defName params -> do def <- lookupDataDef defName cons <- instantiateDataDef def params - return [repTy | DataConDef _ _ repTy _ <- cons] + return [repTy | DataConDef _ repTy _ <- cons] SumTy types -> return types VariantTy (NoExt types) -> return $ toList types VariantTy _ -> fail "Can't pattern-match partially-known variants" @@ -125,17 +125,24 @@ getMethodIndex className methodSourceName = do {-# INLINE getMethodIndex #-} instantiateDataDef :: ScopeReader m => DataDef n -> DataDefParams n -> m n [DataConDef n] -instantiateDataDef (DataDef _ (DataDefBinders bs1 bs2) cons) (DataDefParams xs1 xs2) = do - fromListE <$> applySubst (bs1 @@> (SubstVal <$> xs1) <.> bs2 @@> (SubstVal <$> xs2)) (ListE cons) +instantiateDataDef (DataDef _ bs cons) params = do + fromListE <$> applyDataConAbs (Abs bs $ ListE cons) params {-# INLINE instantiateDataDef #-} +applyDataConAbs :: (SubstE AtomSubstVal e, SinkableE e, ScopeReader m) + => Abs DataDefBinders e n -> DataDefParams n -> m n (e n) +applyDataConAbs (Abs (DataDefBinders bs1 bs2) e) (DataDefParams xs1 xs2) = do + let paramsSubst = bs1 @@> (SubstVal <$> xs1) <.> bs2 @@> (SubstVal <$> xs2) + applySubst paramsSubst e +{-# INLINE applyDataConAbs #-} + -- Returns a representation type (type of an TypeCon-typed Newtype payload) -- given a list of instantiated DataConDefs. dataDefRep :: [DataConDef n] -> Type n dataDefRep = \case [] -> error "unreachable" -- There's no representation for a void type - [DataConDef _ _ ty _] -> ty - tys -> SumTy $ tys <&> \(DataConDef _ _ ty _) -> ty + [DataConDef _ ty _] -> ty + tys -> SumTy $ tys <&> \(DataConDef _ ty _) -> ty instantiateNaryPi :: EnvReader m => NaryPiType n -> [Atom n] -> m n (NaryPiType n) instantiateNaryPi (NaryPiType bs eff resultTy) args = do @@ -211,7 +218,7 @@ projectionIndices ty = case ty of TypeCon _ defName _ -> do DataDef _ _ cons <- lookupDataDef defName case cons of - [DataConDef _ _ _ idxs] -> return idxs + [DataConDef _ _ idxs] -> return idxs _ -> unsupported StaticRecordTy types -> return $ iota (length types) <&> (:[0]) ProdTy tys -> return $ iota (length tys) <&> (:[]) @@ -356,7 +363,7 @@ instance HasType Atom where StaticRecordTy types | i == 0 -> return $ ProdTy $ toList types TypeCon _ defName params | i == 0 -> do def <- lookupDataDef defName - [DataConDef _ _ repTy _] <- instantiateDataDef def params + [DataConDef _ repTy _] <- instantiateDataDef def params return repTy RecordTy _ -> throw CompilerErr "Can't project partially-known records" Var _ -> throw CompilerErr $ "Tried to project value of unreduced type " <> pprint ty diff --git a/src/lib/Serialize.hs b/src/lib/Serialize.hs index 7bf601d88..14103558b 100644 --- a/src/lib/Serialize.hs +++ b/src/lib/Serialize.hs @@ -141,7 +141,7 @@ prettyVal val = case val of prettyData :: (MonadIO1 m, EnvReader m, Fallible1 m) => DataDefName n -> Int -> Atom n -> m n (Doc ann) prettyData dataDefName t rep = do DataDef _ _ dataCons <- lookupDataDef dataDefName - DataConDef conName _ _ idxs <- return $ dataCons !! t + DataConDef conName _ idxs <- return $ dataCons !! t prettyArgs <- forM idxs \ix -> prettyVal $ getProjection (init ix) rep return $ case prettyArgs of [] -> pretty conName diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index f7fa585fc..3358eba43 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -127,11 +127,10 @@ data DataDef n where -- binder name is in UExpr and Env DataDef :: SourceName -> DataDefBinders n l -> [DataConDef l] -> DataDef n --- TODO: The binder nest should be unnecessary. Try to get rid of it. data DataConDef n = -- Name for pretty printing, constructor elements, representation type, -- list of projection indices that recovers elements from the representation. - DataConDef SourceName (EmptyAbs (Nest Binder) n) (Type n) [[Int]] + DataConDef SourceName (Type n) [[Int]] deriving (Show, Generic) data DataDefBinders n l where @@ -877,10 +876,10 @@ instance AlphaEqE DataDef instance AlphaHashableE DataDef instance GenericE DataConDef where - type RepE DataConDef = (LiftE (SourceName, [[Int]])) `PairE` (Abs (Nest Binder) UnitE) `PairE` Type - fromE (DataConDef name ab repTy idxs) = (LiftE (name, idxs)) `PairE` ab `PairE` repTy + type RepE DataConDef = (LiftE (SourceName, [[Int]])) `PairE` Type + fromE (DataConDef name repTy idxs) = (LiftE (name, idxs)) `PairE` repTy {-# INLINE fromE #-} - toE ((LiftE (name, idxs)) `PairE` ab `PairE` repTy) = DataConDef name ab repTy idxs + toE ((LiftE (name, idxs)) `PairE` repTy) = DataConDef name repTy idxs {-# INLINE toE #-} instance SinkableE DataConDef instance HoistableE DataConDef