Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
dougalm committed Sep 22, 2023
1 parent 537bb9b commit 766a8bd
Show file tree
Hide file tree
Showing 8 changed files with 322 additions and 309 deletions.
334 changes: 165 additions & 169 deletions lib/prelude.dx

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/lib/AbstractSyntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -499,15 +499,15 @@ expr = propagateSrcE expr' where
_ -> throw SyntaxErr $ "Prefix (" ++ name ++ ") not legal as a bare expression"
where
range :: UExpr VoidS -> UExpr VoidS -> UExpr' VoidS
range rangeName lim = explicitApp rangeName [ns UHole, lim]
range rangeName lim = explicitApp rangeName [lim]
expr' (CPostfix name g) =
case name of
".." -> range "RangeFrom" <$> expr g
"<.." -> range "RangeFromExc" <$> expr g
_ -> throw SyntaxErr $ "Postfix (" ++ name ++ ") not legal as a bare expression"
where
range :: UExpr VoidS -> UExpr VoidS -> UExpr' VoidS
range rangeName lim = explicitApp rangeName [ns UHole, lim]
range rangeName lim = explicitApp rangeName [lim]
expr' (CLambda params body) = do
params' <- explicitBindersOptAnn $ map stripParens params
body' <- block body
Expand Down
3 changes: 2 additions & 1 deletion src/lib/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ instance (Monad m, CtxReader m) => CtxReader (EnvReaderT m o) where
{-# INLINE getErrCtx #-}

instance (Monad m, Catchable m) => Catchable (EnvReaderT m o) where
catchErr (EnvReaderT (ReaderT m)) f = undefined
catchErr (EnvReaderT (ReaderT m)) f = EnvReaderT $ ReaderT \env ->
m env `catchErr` \err -> runReaderT (runEnvReaderT' $ f err) env
{-# INLINE catchErr #-}

-- === Instances for Name monads ===
Expand Down
138 changes: 76 additions & 62 deletions src/lib/Inference.hs

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/lib/PPrint.hs
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ instance Pretty IxMethod where
instance Pretty (SolverBinding n) where
pretty (InfVarBound ty _) = "Inference variable of type:" <+> p ty
pretty (SkolemBound ty ) = "Skolem variable of type:" <+> p ty
pretty (DictBound ty ) = "Dictionary variable of type:" <+> p ty

instance Pretty (Binding c n) where
pretty b = case b of
Expand Down
2 changes: 1 addition & 1 deletion src/lib/QueryType.hs
Original file line number Diff line number Diff line change
Expand Up @@ -425,4 +425,4 @@ checkExtends allowed (EffectRow effs effTail) = do
forM_ (eSetToList effs) \eff -> unless (eff `eSetMember` allowedEffs) $
throw CompilerErr $ "Unexpected effect: " ++ pprint eff ++
"\nAllowed: " ++ pprint allowed

{-# INLINE checkExtends #-}
5 changes: 4 additions & 1 deletion src/lib/Types/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2250,17 +2250,20 @@ instance AlphaEqE LinearizationSpec
instance AlphaHashableE LinearizationSpec

instance GenericE SolverBinding where
type RepE SolverBinding = EitherE2
type RepE SolverBinding = EitherE3
(PairE CType (LiftE InfVarCtx))
CType
CType
fromE = \case
InfVarBound ty ctx -> Case0 (PairE ty (LiftE ctx))
SkolemBound ty -> Case1 ty
DictBound ty -> Case2 ty
{-# INLINE fromE #-}

toE = \case
Case0 (PairE ty (LiftE ct)) -> InfVarBound ty ct
Case1 ty -> SkolemBound ty
Case2 ty -> DictBound ty
_ -> error "impossible"
{-# INLINE toE #-}

Expand Down
144 changes: 71 additions & 73 deletions src/lib/Vectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ newtype TopVectorizeM (i::S) (o::S) (a:: *) = TopVectorizeM
SubstReaderT Name
(ReaderT1 CommuteMap
(ReaderT1 (LiftE Word32)
(StateT1 (LiftE Err) (BuilderT SimpIR FallibleM)))) i o a }
(StateT1 (LiftE [Err]) (BuilderT SimpIR FallibleM)))) i o a }
deriving ( Functor, Applicative, Monad, MonadFail, MonadReader (CommuteMap o)
, MonadState (LiftE Err o), Fallible, ScopeReader, EnvReader
, MonadState (LiftE [Err] o), Fallible, ScopeReader, EnvReader
, EnvExtender, Builder SimpIR, ScopableBuilder SimpIR, Catchable
, SubstReader Name)

vectorizeLoops :: EnvReader m => Word32 -> STopLam n -> m n (STopLam n, Err)
vectorizeLoops :: EnvReader m => Word32 -> STopLam n -> m n (STopLam n, [Err])
vectorizeLoops width (TopLam d ty (LamExpr bsDestB body)) = liftEnvReaderM do
case popNest bsDestB of
Just (PairB bs b) ->
Expand All @@ -102,19 +102,18 @@ vectorizeLoops width (TopLam d ty (LamExpr bsDestB body)) = liftEnvReaderM do
{-# SCC vectorizeLoops #-}

liftTopVectorizeM :: (EnvReader m)
=> Word32 -> TopVectorizeM i i a -> m i (a, Err)
liftTopVectorizeM vectorByteWidth action = undefined
-- liftTopVectorizeM vectorByteWidth action = do
-- fallible <- liftBuilderT $
-- flip runStateT1 mempty $ runReaderT1 (LiftE vectorByteWidth) $
-- runReaderT1 mempty $ runSubstReaderT idSubst $
-- runTopVectorizeM action
-- case runFallibleM fallible of
-- -- The failure case should not occur: vectorization errors should have been
-- -- caught inside `vectorizeLoopsDecls` (and should have been added to the
-- -- `Err` state of the `StateT` instance that is run with `runStateT` above).
-- Failure errs -> error $ pprint errs
-- Success (a, (LiftE errs)) -> return $ (a, errs)
=> Word32 -> TopVectorizeM i i a -> m i (a, [Err])
liftTopVectorizeM vectorByteWidth action = do
fallible <- liftBuilderT $
flip runStateT1 mempty $ runReaderT1 (LiftE vectorByteWidth) $
runReaderT1 mempty $ runSubstReaderT idSubst $
runTopVectorizeM action
case runFallibleM fallible of
-- The failure case should not occur: vectorization errors should have been
-- caught inside `vectorizeLoopsDecls` (and should have been added to the
-- `Err` state of the `StateT` instance that is run with `runStateT` above).
Failure errs -> error $ pprint errs
Success (a, (LiftE errs)) -> return $ (a, errs)

addVectErrCtx :: Fallible m => String -> String -> m a -> m a
addVectErrCtx name payload m =
Expand Down Expand Up @@ -167,63 +166,62 @@ vectorizeLoopsLamExpr (LamExpr bs body) = case bs of
return $ LamExpr (Nest b' bs') body'

vectorizeLoopsExpr :: (Emits o) => SExpr i -> TopVectorizeM i o (SExpr o)
vectorizeLoopsExpr expr = undefined
-- vectorizeLoopsExpr expr = do
-- vectorByteWidth <- askVectorByteWidth
-- narrowestTypeByteWidth <- getNarrowestTypeByteWidth =<< renameM expr
-- let loopWidth = vectorByteWidth `div` narrowestTypeByteWidth
-- case expr of
-- PrimOp (DAMOp (Seq effs dir ixty dest body)) -> do
-- sz <- simplifyIxSize =<< renameM ixty
-- case sz of
-- Just n | n `mod` loopWidth == 0 -> (do
-- safe <- vectorSafeEffect effs
-- if safe
-- then (do
-- Distinct <- getDistinct
-- let vn = n `div` loopWidth
-- body' <- vectorizeSeq loopWidth ixty body
-- dest' <- renameM dest
-- seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body'
-- return $ PrimOp $ DAMOp seqOp)
-- else renameM expr)
-- `catchErr` \errs -> do
-- let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr
-- ctx = mempty { messageCtx = [msg] }
-- errs' = prependCtxToErrs ctx errs
-- modify (<> LiftE errs')
-- recurSeq expr
-- _ -> recurSeq expr
-- PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do
-- item' <- renameM item
-- itemTy <- return $ getType item'
-- lam <- buildEffLam noHint itemTy \hb refb ->
-- extendRenamer (hb' @> atomVarName hb) do
-- extendRenamer (refb' @> atomVarName refb) do
-- vectorizeLoopsBlock body
-- PrimOp . Hof <$> mkTypedHof (RunReader item' lam)
-- PrimOp (Hof (TypedHof (EffTy _ ty)
-- (RunWriter (Just dest) monoid (BinaryLamExpr hb' refb' body)))) -> do
-- dest' <- renameM dest
-- monoid' <- renameM monoid
-- commutativity <- monoidCommutativity monoid'
-- PairTy _ accTy <- renameM ty
-- lam <- buildEffLam noHint accTy \hb refb ->
-- extendRenamer (hb' @> atomVarName hb) do
-- extendRenamer (refb' @> atomVarName refb) do
-- extendCommuteMap (atomVarName hb) commutativity do
-- vectorizeLoopsBlock body
-- PrimOp . Hof <$> mkTypedHof (RunWriter (Just dest') monoid' lam)
-- _ -> renameM expr
-- where
-- recurSeq :: (Emits o) => SExpr i -> TopVectorizeM i o (SExpr o)
-- recurSeq (PrimOp (DAMOp (Seq effs dir ixty dest body))) = do
-- effs' <- renameM effs
-- ixty' <- renameM ixty
-- dest' <- renameM dest
-- body' <- vectorizeLoopsLamExpr body
-- return $ PrimOp $ DAMOp $ Seq effs' dir ixty' dest' body'
-- recurSeq _ = error "Impossible"
vectorizeLoopsExpr expr = do
vectorByteWidth <- askVectorByteWidth
narrowestTypeByteWidth <- getNarrowestTypeByteWidth =<< renameM expr
let loopWidth = vectorByteWidth `div` narrowestTypeByteWidth
case expr of
PrimOp (DAMOp (Seq effs dir ixty dest body)) -> do
sz <- simplifyIxSize =<< renameM ixty
case sz of
Just n | n `mod` loopWidth == 0 -> (do
safe <- vectorSafeEffect effs
if safe
then (do
Distinct <- getDistinct
let vn = n `div` loopWidth
body' <- vectorizeSeq loopWidth ixty body
dest' <- renameM dest
seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body'
return $ PrimOp $ DAMOp seqOp)
else renameM expr)
`catchErr` \err -> do
let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr
ctx = mempty { messageCtx = [msg] }
err' = prependCtxToErr ctx err
modify (\(LiftE errs) -> LiftE (err':errs))
recurSeq expr
_ -> recurSeq expr
PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do
item' <- renameM item
itemTy <- return $ getType item'
lam <- buildEffLam noHint itemTy \hb refb ->
extendRenamer (hb' @> atomVarName hb) do
extendRenamer (refb' @> atomVarName refb) do
vectorizeLoopsBlock body
PrimOp . Hof <$> mkTypedHof (RunReader item' lam)
PrimOp (Hof (TypedHof (EffTy _ ty)
(RunWriter (Just dest) monoid (BinaryLamExpr hb' refb' body)))) -> do
dest' <- renameM dest
monoid' <- renameM monoid
commutativity <- monoidCommutativity monoid'
PairTy _ accTy <- renameM ty
lam <- buildEffLam noHint accTy \hb refb ->
extendRenamer (hb' @> atomVarName hb) do
extendRenamer (refb' @> atomVarName refb) do
extendCommuteMap (atomVarName hb) commutativity do
vectorizeLoopsBlock body
PrimOp . Hof <$> mkTypedHof (RunWriter (Just dest') monoid' lam)
_ -> renameM expr
where
recurSeq :: (Emits o) => SExpr i -> TopVectorizeM i o (SExpr o)
recurSeq (PrimOp (DAMOp (Seq effs dir ixty dest body))) = do
effs' <- renameM effs
ixty' <- renameM ixty
dest' <- renameM dest
body' <- vectorizeLoopsLamExpr body
return $ PrimOp $ DAMOp $ Seq effs' dir ixty' dest' body'
recurSeq _ = error "Impossible"

simplifyIxSize :: (EnvReader m, ScopableBuilder SimpIR m)
=> IxType SimpIR n -> m n (Maybe Word32)
Expand Down

0 comments on commit 766a8bd

Please sign in to comment.