Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hoisting of loop-invariant subexpressions #3195

Merged
merged 19 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/Juvix/Compiler/Core/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ data TransformationId
| LetFolding
| LambdaFolding
| LetHoisting
| LoopHoisting
| Inlining
| MandatoryInlining
| FoldTypeSynonyms
Expand All @@ -47,6 +48,7 @@ data TransformationId
| OptPhaseExec
| OptPhaseVampIR
| OptPhaseMain
| OptPhasePreLifting
deriving stock (Data, Bounded, Enum, Show)

data PipelineId
Expand Down Expand Up @@ -77,7 +79,7 @@ toVampIRTransformations =

toStrippedTransformations :: TransformationId -> [TransformationId]
toStrippedTransformations checkId =
combineInfoTablesTransformations ++ [checkId, LambdaLetRecLifting, TopEtaExpand, OptPhaseExec, MoveApps, RemoveTypeArgs, DisambiguateNames]
combineInfoTablesTransformations ++ [checkId, OptPhasePreLifting, LambdaLetRecLifting, TopEtaExpand, OptPhaseExec, MoveApps, RemoveTypeArgs, DisambiguateNames]

instance TransformationId' TransformationId where
transformationText :: TransformationId -> Text
Expand Down Expand Up @@ -109,6 +111,7 @@ instance TransformationId' TransformationId where
LetFolding -> strLetFolding
LambdaFolding -> strLambdaFolding
LetHoisting -> strLetHoisting
LoopHoisting -> strLoopHoisting
Inlining -> strInlining
MandatoryInlining -> strMandatoryInlining
FoldTypeSynonyms -> strFoldTypeSynonyms
Expand All @@ -124,6 +127,7 @@ instance TransformationId' TransformationId where
OptPhaseExec -> strOptPhaseExec
OptPhaseVampIR -> strOptPhaseVampIR
OptPhaseMain -> strOptPhaseMain
OptPhasePreLifting -> strOptPhasePreLifting

instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text
Expand Down
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Core/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import Juvix.Prelude
strLetHoisting :: Text
strLetHoisting = "let-hoisting"

strLoopHoisting :: Text
strLoopHoisting = "loop-hoisting"

strStoredPipeline :: Text
strStoredPipeline = "pipeline-stored"

Expand Down Expand Up @@ -142,3 +145,6 @@ strOptPhaseVampIR = "opt-phase-vampir"

strOptPhaseMain :: Text
strOptPhaseMain = "opt-phase-main"

strOptPhasePreLifting :: Text
strOptPhasePreLifting = "opt-phase-pre-lifting"
6 changes: 6 additions & 0 deletions src/Juvix/Compiler/Core/Extra/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,18 @@ mkLambda i bi b = NLam (Lambda i bi b)
mkLambda' :: Type -> Node -> Node
mkLambda' ty = mkLambda Info.empty (mkBinder' ty)

mkLambda'' :: Binder -> Node -> Node
mkLambda'' = mkLambda Info.empty

mkLambdas :: [Info] -> [Binder] -> Node -> Node
mkLambdas is bs n = foldl' (flip (uncurry mkLambda)) n (reverse (zipExact is bs))

mkLambdas' :: [Type] -> Node -> Node
mkLambdas' tys n = foldl' (flip mkLambda') n (reverse tys)

mkLambdas'' :: [Binder] -> Node -> Node
mkLambdas'' bs n = foldl' (flip mkLambda'') n (reverse bs)

mkLetItem :: Text -> Type -> Node -> LetItem
mkLetItem name ty = LetItem (mkBinder name ty)

Expand Down
64 changes: 64 additions & 0 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -155,18 +155,42 @@ isImmediate md = \case
isImmediate' :: Node -> Bool
isImmediate' = isImmediate emptyModule

isImmediateOrLambda :: Module -> Node -> Bool
isImmediateOrLambda md node = isImmediate md node || isLambda node

-- | True if the argument is fully evaluated first-order data
isDataValue :: Node -> Bool
isDataValue = \case
NCst {} -> True
NCtr Constr {..} -> all isDataValue _constrArgs
_ -> False

isFullyApplied :: Module -> BinderList Binder -> Node -> Bool
isFullyApplied md bl node = case h of
NIdt Ident {..}
| Just ii <- lookupIdentifierInfo' md _identSymbol ->
length args == ii ^. identifierArgsNum
NVar Var {..} ->
case BL.lookupMay _varIndex bl of
Just Binder {..} ->
length args == length (typeArgs _binderType)
Nothing ->
False
_ ->
False
where
(h, args) = unfoldApps' node

isFailNode :: Node -> Bool
isFailNode = \case
NBlt (BuiltinApp {..}) | _builtinAppOp == OpFail -> True
_ -> False

isLambda :: Node -> Bool
isLambda = \case
NLam {} -> True
_ -> False

isTrueConstr :: Node -> Bool
isTrueConstr = \case
NCtr Constr {..} | _constrTag == BuiltinTag TagTrue -> True
Expand Down Expand Up @@ -576,3 +600,43 @@ checkInfoTable tab =
all isClosed (tab ^. identContext)
&& all (isClosed . (^. identifierType)) (tab ^. infoIdentifiers)
&& all (isClosed . (^. constructorType)) (tab ^. infoConstructors)

-- | Checks if the `n`th argument (zero-based) is passed without modification to
-- direct recursive calls.
isArgRecursiveInvariant :: Module -> Symbol -> Int -> Bool
isArgRecursiveInvariant tab sym argNum = run $ execState True $ dmapNRM go body
where
nodeSym = lookupIdentifierNode tab sym
(lams, body) = unfoldLambdas nodeSym
n = length lams

go :: (Member (State Bool) r) => Level -> Node -> Sem r Recur
go lvl node = case node of
NApp {} ->
let (h, args) = unfoldApps' node
in case h of
NIdt Ident {..}
| _identSymbol == sym ->
let b =
argNum < length args
&& case args !! argNum of
NVar Var {..} | _varIndex == lvl + n - argNum - 1 -> True
_ -> False
in do
modify' (&& b)
mapM_ (dmapNRM' (lvl, go)) args
return $ End node
_ -> return $ Recur node
NIdt Ident {..}
| _identSymbol == sym -> do
put False
return $ End node
_ -> return $ Recur node

isDirectlyRecursive :: Module -> Symbol -> Bool
isDirectlyRecursive md sym = ufold (\x xs -> or (x : xs)) go (lookupIdentifierNode md sym)
where
go :: Node -> Bool
go = \case
NIdt Ident {..} -> _identSymbol == sym
_ -> False
11 changes: 7 additions & 4 deletions src/Juvix/Compiler/Core/Info/FreeVarsInfo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ computeFreeVarsInfo' lambdaMultiplier = umap go
fvi =
FreeVarsInfo
. fmap (* lambdaMultiplier)
$ getFreeVars 1 _lambdaBody
$ getFreeVars' 1 _lambdaBody
_ ->
modifyInfo (Info.insert fvi) node
where
Expand All @@ -47,20 +47,23 @@ computeFreeVarsInfo' lambdaMultiplier = umap go
foldr
( \NodeChild {..} acc ->
Map.unionWith (+) acc $
getFreeVars _childBindersNum _childNode
getFreeVars' _childBindersNum _childNode
)
mempty
(children node)

getFreeVars :: Int -> Node -> Map Index Int
getFreeVars bindersNum node =
getFreeVars' :: Int -> Node -> Map Index Int
getFreeVars' bindersNum node =
Map.mapKeysMonotonic (\idx -> idx - bindersNum)
. Map.filterWithKey (\idx _ -> idx >= bindersNum)
$ getFreeVarsInfo node ^. infoFreeVars

getFreeVarsInfo :: Node -> FreeVarsInfo
getFreeVarsInfo = fromJust . Info.lookup kFreeVarsInfo . getInfo

getFreeVars :: Node -> [Index]
getFreeVars = Map.keys . Map.filter (> 0) . (^. infoFreeVars) . getFreeVarsInfo

freeVarOccurrences :: Index -> Node -> Int
freeVarOccurrences idx n = fromMaybe 0 (Map.lookup idx (getFreeVarsInfo n ^. infoFreeVars))

Expand Down
21 changes: 21 additions & 0 deletions src/Juvix/Compiler/Core/Info/VolatilityInfo.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module Juvix.Compiler.Core.Info.VolatilityInfo where

import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info qualified as Info

newtype VolatilityInfo = VolatilityInfo
{ _infoIsVolatile :: Bool
}

instance IsInfo VolatilityInfo

kVolatilityInfo :: Key VolatilityInfo
kVolatilityInfo = Proxy

makeLenses ''VolatilityInfo

isVolatile :: Info -> Bool
isVolatile i =
case Info.lookup kVolatilityInfo i of
Just VolatilityInfo {..} -> _infoIsVolatile
Nothing -> False
4 changes: 4 additions & 0 deletions src/Juvix/Compiler/Core/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ import Juvix.Compiler.Core.Transformation.Optimize.FilterUnreachable (filterUnre
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
import Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting
import Juvix.Compiler.Core.Transformation.Optimize.MandatoryInlining
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Eval qualified as Phase.Eval
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Exec qualified as Phase.Exec
import Juvix.Compiler.Core.Transformation.Optimize.Phase.Main qualified as Phase.Main
import Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting qualified as Phase.PreLifting
import Juvix.Compiler.Core.Transformation.Optimize.Phase.VampIR qualified as Phase.VampIR
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyComparisons (simplifyComparisons)
import Juvix.Compiler.Core.Transformation.Optimize.SimplifyIfs
Expand Down Expand Up @@ -92,6 +94,7 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
LetFolding -> return . letFolding
LambdaFolding -> return . lambdaFolding
LetHoisting -> return . letHoisting
LoopHoisting -> return . loopHoisting
Inlining -> inlining
MandatoryInlining -> return . mandatoryInlining
FoldTypeSynonyms -> return . foldTypeSynonyms
Expand All @@ -107,3 +110,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
OptPhaseExec -> Phase.Exec.optimize
OptPhaseVampIR -> Phase.VampIR.optimize
OptPhaseMain -> Phase.Main.optimize
OptPhasePreLifting -> Phase.PreLifting.optimize
8 changes: 8 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/ComputeTypeInfo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.TypeInfo qualified as Info
import Juvix.Compiler.Core.Transformation.Base

computeNodeType' :: Module -> BinderList Binder -> Node -> Type
computeNodeType' md bl node = rePis argtys' ty'
where
ty = computeNodeType md (mkLambdas'' (reverse (toList bl)) node)
(argtys, ty') = unfoldPi ty
argtys' = drop (length bl) argtys

-- | Computes the type of a closed well-typed node.
computeNodeType :: Module -> Node -> Type
computeNodeType md = Info.getNodeType . computeNodeTypeInfo md

Expand Down
10 changes: 6 additions & 4 deletions src/Juvix/Compiler/Core/Transformation/Optimize/LetFolding.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info.DebugOpsInfo as Info
import Juvix.Compiler.Core.Info.FreeVarsInfo as Info
import Juvix.Compiler.Core.Info.VolatilityInfo qualified as Info
import Juvix.Compiler.Core.Transformation.Base

convertNode :: (Module -> BinderList Binder -> Node -> Bool) -> Module -> Node -> Node
Expand All @@ -25,10 +26,11 @@ convertNode isFoldable md = rmapL go
go :: ([BinderChange] -> Node -> Node) -> BinderList Binder -> Node -> Node
go recur bl = \case
NLet Let {..}
| ( isImmediate md (_letItem ^. letItemValue)
|| Info.freeVarOccurrences 0 _letBody <= 1
|| isFoldable md bl (_letItem ^. letItemValue)
)
| not (Info.isVolatile _letInfo)
&& ( isImmediate md (_letItem ^. letItemValue)
|| Info.freeVarOccurrences 0 _letBody <= 1
|| isFoldable md bl (_letItem ^. letItemValue)
)
&& not (Info.hasDebugOps _letBody) ->
go (recur . (mkBCRemove b val' :)) (BL.cons b bl) _letBody
where
Expand Down
101 changes: 101 additions & 0 deletions src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
module Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting (loopHoisting) where

import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info qualified as Info
import Juvix.Compiler.Core.Info.FreeVarsInfo qualified as Info
import Juvix.Compiler.Core.Info.VolatilityInfo qualified as Info
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.ComputeTypeInfo (computeNodeType')

loopHoisting :: Module -> Module
loopHoisting md = mapT (const (umapL go)) md
where
go :: BinderList Binder -> Node -> Node
go bl node = case node of
NApp {} -> case h of
NIdt Ident {..}
| Just ii <- lookupIdentifierInfo' md _identSymbol,
length args == ii ^. identifierArgsNum ->
goApp bl _identSymbol h 0 args
_ -> node
where
(h, args) = unfoldApps node
_ ->
node

goApp :: BinderList Binder -> Symbol -> Node -> Int -> [(Info, Node)] -> Node
goApp bl sym h argNum args = case args of
[] -> h
(info, arg) : args' -> case arg of
NLam {}
| isHoistable sym argNum ->
goLamApp bl sym info h arg (argNum + 1) args'
_ -> goApp bl sym (mkApp info h arg) (argNum + 1) args'

isHoistable :: Symbol -> Int -> Bool
isHoistable sym argNum =
isArgRecursiveInvariant md sym argNum && isDirectlyRecursive md sym

goLamApp :: BinderList Binder -> Symbol -> Info -> Node -> Node -> Int -> [(Info, Node)] -> Node
goLamApp bl sym info h arg argNum args'
| null subterms = goApp bl sym (mkApp info h arg) argNum args'
| otherwise =
setLetsVolatile n $
mkLets'
(map (\node -> (computeNodeType' md bl node, node)) subterms')
( adjustLetBoundVars
. shift n
$ (mkApps (mkApp info h (reLambdasRev lams body')) args')
)
where
(lams, body) = unfoldLambdasRev arg
bl' = BL.prepend (map (^. lambdaLhsBinder) lams) bl
(subterms, body') = extractMaximalInvariantSubterms (length bl) bl' body
n = length subterms
subterms' = zipWith shift [0 ..] subterms

extractMaximalInvariantSubterms :: Int -> BinderList Binder -> Node -> ([Node], Node)
extractMaximalInvariantSubterms initialBindersNum bl0 body =
first (map (removeInfo Info.kFreeVarsInfo))
. second (removeInfo Info.kFreeVarsInfo)
. run
. runState []
$ dmapLRM' (bl0, extract) (Info.computeFreeVarsInfo body)
where
extract :: (Member (State [Node]) r) => BinderList Binder -> Node -> Sem r Recur
extract bl node
| not (isImmediate md node || isLambda node)
&& isFullyApplied md bl node
&& null boundVars = do
k <- length <$> get @[Node]
modify' ((shift (-n) node) :)
-- This variable is later adjusted to the correct index in `adjustLetBoundVars`
return $ End (mkVar' (-k - 1))
| otherwise =
return $ Recur node
where
boundVars = filter (< n) $ Info.getFreeVars node
n = length bl - initialBindersNum

adjustLetBoundVars :: Node -> Node
adjustLetBoundVars = umapN adjust
where
adjust :: Level -> Node -> Node
adjust n node = case node of
NVar Var {..}
| _varIndex < 0 -> mkVar' (n - _varIndex - 1)
_ -> node

setLetsVolatile :: Int -> Node -> Node
setLetsVolatile n
| n == 0 = id
| otherwise = \case
NLet Let {..} ->
NLet
Let
{ _letInfo = Info.insert (Info.VolatilityInfo True) _letInfo,
_letBody = setLetsVolatile (n - 1) _letBody,
_letItem
}
node -> node
Loading
Loading