Skip to content

Commit

Permalink
Merge pull request google-research#1272 from google-research/uexpr-block
Browse files Browse the repository at this point in the history
Add blocks to UExpr IR
  • Loading branch information
dougalm authored Apr 8, 2023
2 parents 1ae0689 + 5ab8e8c commit 4d88297
Show file tree
Hide file tree
Showing 11 changed files with 244 additions and 183 deletions.
22 changes: 11 additions & 11 deletions lib/prelude.dx
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,8 @@ instance VSpace(n=>a) given (a|VSpace, n|Ix)

instance VSpace((a, b)) given (a|VSpace, b|VSpace)
def (.*)(s, pair) =
(a, b) = pair
(s .* a, s .* b)
(x, y) = pair
(s .* x, s .* y)

instance VSpace((i:n) => (..i) => a) given (n|Ix, a|VSpace)
def (.*)(s, xs) = for i. s .* xs[i]
Expand Down Expand Up @@ -1552,19 +1552,19 @@ instance Show(())
def show(_) = "()"

instance Show((a, b)) given (a|Show, b|Show)
def show(x) =
(a, b) = x
"(" <> show a <> ", " <> show b <> ")"
def show(tup) =
(x, y) = tup
"(" <> show x <> ", " <> show y <> ")"

instance Show((a, b, c)) given (a|Show, b|Show, c|Show)
def show(x) =
(a, b, c) = x
"(" <> show a <> ", " <> show b <> ", " <> show c <> ")"
def show(tup) =
(x, y, z) = tup
"(" <> show x <> ", " <> show y <> ", " <> show z <> ")"

instance Show((a, b, c, d)) given (a|Show, b|Show, c|Show, d|Show)
def show(x) =
(a, b, c, d) = x
"(" <> show a <> ", " <> show b <> ", " <> show c <> ", " <> show d <> ")"
def show(tup) =
(x, y, z, w) = tup
"(" <> show x <> ", " <> show y <> ", " <> show z <> ", " <> show w <> ")"

'### Parse interface
For types that can be parsed from a `String`.
Expand Down
74 changes: 43 additions & 31 deletions src/lib/AbstractSyntax.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ import Util
parseExpr :: Fallible m => Group -> m (UExpr VoidS)
parseExpr e = liftSyntaxM $ expr e

parseDecl :: Fallible m => CTopDecl -> m (UDecl VoidS VoidS)
parseDecl :: Fallible m => CTopDecl -> m (UTopDecl VoidS VoidS)
parseDecl d = liftSyntaxM $ topDecl d

parseBlock :: Fallible m => CSBlock -> m (UExpr VoidS)
parseBlock :: Fallible m => CSBlock -> m (UBlock VoidS)
parseBlock b = liftSyntaxM $ block b

liftSyntaxM :: Fallible m => SyntaxM a -> m a
Expand All @@ -93,7 +93,7 @@ checkSourceBlockParses = \case
when (ann /= PlainLet) $ fail "Cannot annotate expressions"
void $ expr e
TopDecl d -> void $ topDecl d
Command _ b -> void $ block b
Command _ b -> void $ expr b
DeclareForeign _ _ ty -> void $ expr ty
DeclareCustomLinearization _ _ body -> void $ expr body
Misc _ -> return ()
Expand All @@ -103,9 +103,9 @@ checkSourceBlockParses = \case

type SyntaxM = FallibleM

topDecl :: CTopDecl -> SyntaxM (UDecl VoidS VoidS)
topDecl :: CTopDecl -> SyntaxM (UTopDecl VoidS VoidS)
topDecl = dropSrc topDecl' where
topDecl' (CSDecl ann d) = decl ann (WithSrc Nothing d)
topDecl' (CSDecl ann d) = ULocalDecl <$> decl ann (WithSrc Nothing d)
topDecl' (CData name tyConParams givens constructors) = do
tyConParams' <- aExplicitParams tyConParams
givens' <- toNest <$> fromMaybeM givens [] aGivens
Expand All @@ -131,27 +131,24 @@ topDecl = dropSrc topDecl' where
ty' <- expr ty
return (fromString methodName, ty')
return $ UInterface params' methodTys (fromString name) (toNest methodNames)
topDecl' (CInstanceDecl def) = aInstanceDef def
topDecl' (CEffectDecl _ _) = error "not implemented"
topDecl' (CHandlerDecl _ _ _ _ _ _) = error "not implemented"

uExprAsDecl :: UExpr VoidS -> UDecl VoidS VoidS
uExprAsDecl e = ULet PlainLet (nsB UPatIgnore) Nothing e

decl :: LetAnn -> CSDecl -> SyntaxM (UDecl VoidS VoidS)
decl ann = dropSrc \case
CLet binder body -> do
decl ann = propagateSrcB \case
CLet binder rhs -> do
(p, ty) <- patOptAnn binder
ULet ann p ty <$> block body
ULet ann p ty <$> asExpr <$> block rhs
CBind _ _ -> throw SyntaxErr "Arrow binder syntax <- not permitted at the top level, because the binding would have unbounded scope."
CDefDecl def -> do
(name, lam) <- aDef def
return $ ULet ann (fromString name) Nothing (ns $ ULam lam)
CExpr g -> uExprAsDecl <$> expr g
CInstanceDecl def -> aInstanceDef def
CExpr g -> UExprDecl <$> expr g
CPass -> return UPass

aInstanceDef :: CInstanceDef -> SyntaxM (UDecl VoidS VoidS)
aInstanceDef (CInstanceDef clName args givens (CSBlock methods) instNameAndParams) = do
aInstanceDef :: CInstanceDef -> SyntaxM (UTopDecl VoidS VoidS)
aInstanceDef (CInstanceDef clName args givens methods instNameAndParams) = do
let clName' = fromString clName
args' <- mapM expr args
givens' <- toNest <$> fromMaybeM givens [] aGivens
Expand Down Expand Up @@ -379,19 +376,32 @@ aMethod (WithSrc src d) = Just . WithSrcE src <$> addSrcContext src case d of
return $ UMethodDef (fromString name) rhs'
_ -> throw SyntaxErr "Unexpected method definition. Expected `def` or `x = ...`."

block :: CSBlock -> SyntaxM (UExpr VoidS)
block (CSBlock []) = throw SyntaxErr "Block must end in expression"
block (ExprBlock g) = expr g
block (CSBlock ((WithSrc pos (CBind b rhs)):ds)) = do
asExpr :: UBlock VoidS -> UExpr VoidS
asExpr (WithSrcE src b) = case b of
UBlock Empty e -> e
_ -> WithSrcE src $ UDo $ WithSrcE src b

block :: CSBlock -> SyntaxM (UBlock VoidS)
block (ExprBlock g) = WithSrcE Nothing . UBlock Empty <$> expr g
block (IndentedBlock decls) = do
(decls', result) <- blockDecls decls
return $ WithSrcE Nothing $ UBlock decls' result

blockDecls :: [CSDecl] -> SyntaxM (Nest UDecl VoidS VoidS, UExpr VoidS)
blockDecls [] = error "shouldn't have empty list of decls"
blockDecls [WithSrc src d] = addSrcContext src case d of
CExpr g -> (Empty,) <$> expr g
_ -> throw SyntaxErr "Block must end in expression"
blockDecls (WithSrc pos (CBind b rhs):ds) = do
WithExpl _ b' <- generalBinder DataParam Explicit b
rhs' <- block rhs
body <- block $ CSBlock ds
rhs' <- asExpr <$> block rhs
body <- block $ IndentedBlock ds
let lam = ULam $ ULamExpr (UnaryNest (WithExpl Explicit b')) ExplicitApp Nothing Nothing body
return $ WithSrcE pos $ extendAppRight rhs' (ns lam)
block (CSBlock (d@(WithSrc pos _):ds)) = do
return (Empty, WithSrcE pos $ extendAppRight rhs' (ns lam))
blockDecls (d:ds) = do
d' <- decl PlainLet d
e' <- block $ CSBlock ds
return $ WithSrcE pos $ UDecl $ UDeclExpr d' e'
(ds', e) <- blockDecls ds
return (Nest d' ds', e)

-- === Concrete to abstract syntax of expressions ===

Expand Down Expand Up @@ -424,7 +434,7 @@ expr = propagateSrcE expr' where
resultTy <- expr rhs
return $ UPi $ UPiExpr bs ExplicitApp effs' resultTy
_ -> throw SyntaxErr "Argument types should be in parentheses"
expr' (CDo b) = dropSrcE <$> block b
expr' (CDo b) = UDo <$> block b
-- Binders (e.g., in pi types) should not hit this case
expr' (CBin (WithSrc opSrc op) lhs rhs) =
case op of
Expand Down Expand Up @@ -511,15 +521,15 @@ expr = propagateSrcE expr' where
-- TODO: Can we fetch the source position from the error context, to feed into `buildFor`?
e <- buildFor (0, 0) dir <$> mapM binderOptTy indices <*> block body
if trailingUnit
then return $ UDecl $ UDeclExpr (ULet PlainLet (nsB UPatIgnore) Nothing e) $ ns $ unitExpr
then return $ UDo $ ns $ UBlock (UnaryNest (nsB $ UExprDecl e)) (ns unitExpr)
else return $ dropSrcE e
expr' (CCase scrut alts) = UCase <$> (expr scrut) <*> mapM alternative alts
where alternative (match, body) = UAlt <$> casePat match <*> block body
expr' (CIf p c a) = do
p' <- expr p
c' <- block c
a' <- case a of
Nothing -> return $ ns $ unitExpr
Nothing -> return $ ns $ UBlock Empty $ ns unitExpr
(Just alternative) -> block alternative
return $ UCase p'
[ UAlt (nsB $ UPatCon "True" Empty) c'
Expand All @@ -537,10 +547,12 @@ labelExpr PlainLabel str = ULabel str
-- === Builders ===

-- TODO Does this generalize? Swap list for Nest?
buildFor :: SrcPos -> Direction -> [UOptAnnBinder VoidS VoidS] -> UExpr VoidS -> UExpr VoidS
buildFor :: SrcPos -> Direction -> [UOptAnnBinder VoidS VoidS] -> UBlock VoidS -> UExpr VoidS
buildFor pos dir binders body = case binders of
[] -> body
b:bs -> WithSrcE (Just pos) $ UFor dir $ UForExpr b $ buildFor pos dir bs body
[] -> error "should have nonempty list of binder"
[b] -> WithSrcE (Just pos) $ UFor dir $ UForExpr b body
b:bs -> WithSrcE (Just pos) $ UFor dir $ UForExpr b $
ns $ UBlock Empty $ buildFor pos dir bs body

-- === Helpers ===

Expand Down
Loading

0 comments on commit 4d88297

Please sign in to comment.