Skip to content

Commit

Permalink
improve optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszcz committed Nov 29, 2024
1 parent 80dd864 commit ed93261
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 21 deletions.
15 changes: 11 additions & 4 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,19 @@ isDataValue = \case
NCtr Constr {..} -> all isDataValue _constrArgs
_ -> False

isFullyApplied :: Module -> Node -> Bool
isFullyApplied md node = case h of
isFullyApplied :: Module -> BinderList Binder -> Node -> Bool
isFullyApplied md bl node = case h of
NIdt Ident {..}
| Just ii <- lookupIdentifierInfo' md _identSymbol ->
length args >= ii ^. identifierArgsNum
_ -> False
length args == ii ^. identifierArgsNum
NVar Var {..} ->
case BL.lookupMay _varIndex bl of
Just Binder {..} ->
length args == length (typeArgs _binderType)
Nothing ->
False
_ ->
False
where
(h, args) = unfoldApps' node

Expand Down
34 changes: 21 additions & 13 deletions src/Juvix/Compiler/Core/Transformation/Optimize/LoopHoisting.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting (loopHoisting) where

import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Info qualified as Info
import Juvix.Compiler.Core.Info.FreeVarsInfo qualified as Info
Expand All @@ -13,11 +14,12 @@ loopHoisting md = mapT (const (umapL go)) md
go :: BinderList Binder -> Node -> Node
go bl node = case node of
NApp {} -> case h of
-- TODO: variables
NIdt Ident {..} -> goApp bl _identSymbol h 0 args
NIdt Ident {..}
| Just ii <- lookupIdentifierInfo' md _identSymbol,
length args == ii ^. identifierArgsNum ->
goApp bl _identSymbol h 0 args
_ -> node
where
-- TODO: consider only fully applied
(h, args) = unfoldApps node
_ ->
node
Expand All @@ -27,10 +29,14 @@ loopHoisting md = mapT (const (umapL go)) md
[] -> h
(info, arg) : args' -> case arg of
NLam {}
| isArgRecursiveInvariant md sym argNum && isDirectlyRecursive md sym ->
| isHoistable sym argNum ->
goLamApp bl sym info h arg (argNum + 1) args'
_ -> goApp bl sym (mkApp info h arg) (argNum + 1) args'

isHoistable :: Symbol -> Int -> Bool
isHoistable sym argNum =
isArgRecursiveInvariant md sym argNum && isDirectlyRecursive md sym

goLamApp :: BinderList Binder -> Symbol -> Info -> Node -> Node -> Int -> [(Info, Node)] -> Node
goLamApp bl sym info h arg argNum args'
| null subterms = goApp bl sym (mkApp info h arg) argNum args'
Expand All @@ -44,31 +50,33 @@ loopHoisting md = mapT (const (umapL go)) md
)
where
(lams, body) = unfoldLambdasRev arg
(subterms, body') = extractMaximalInvariantSubterms (length lams) body
bl' = BL.prepend (map (^. lambdaLhsBinder) lams) bl
(subterms, body') = extractMaximalInvariantSubterms (length bl) bl' body
n = length subterms
subterms' = zipWith shift [0 ..] subterms

extractMaximalInvariantSubterms :: Int -> Node -> ([Node], Node)
extractMaximalInvariantSubterms bindersNum body =
extractMaximalInvariantSubterms :: Int -> BinderList Binder -> Node -> ([Node], Node)
extractMaximalInvariantSubterms initialBindersNum bl0 body =
first (map (removeInfo Info.kFreeVarsInfo))
. second (removeInfo Info.kFreeVarsInfo)
. run
. runState []
$ dmapNRM extract (Info.computeFreeVarsInfo body)
$ dmapLRM' (bl0, extract) (Info.computeFreeVarsInfo body)
where
extract :: (Member (State [Node]) r) => Level -> Node -> Sem r Recur
extract n node
extract :: (Member (State [Node]) r) => BinderList Binder -> Node -> Sem r Recur
extract bl node
| not (isImmediate md node || isLambda node)
&& isFullyApplied md node -- TODO: variables
&& isFullyApplied md bl node
&& null boundVars = do
k <- length <$> get @[Node]
modify' ((shift (-(n + bindersNum)) node) :)
modify' ((shift (-n) node) :)
-- This variable is later adjusted to the correct index in `adjustLetBoundVars`
return $ End (mkVar' (-k - 1))
| otherwise =
return $ Recur node
where
boundVars = filter (< n + bindersNum) $ Info.getFreeVars node
boundVars = filter (< n) $ Info.getFreeVars node
n = length bl - initialBindersNum

adjustLetBoundVars :: Node -> Node
adjustLetBoundVars = umapN adjust
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ optimize' :: CoreOptions -> Module -> Module
optimize' opts@CoreOptions {..} md =
filterUnreachable
. compose
(6 * _optOptimizationLevel)
(4 * _optOptimizationLevel)
( doConstantFolding
. doSimplification 1
. specializeArgs
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where

import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
import Juvix.Compiler.Core.Transformation.Optimize.Inlining
import Juvix.Compiler.Core.Transformation.Optimize.LambdaFolding
import Juvix.Compiler.Core.Transformation.Optimize.LetFolding
import Juvix.Compiler.Core.Transformation.Optimize.LoopHoisting

optimize :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
optimize =
withOptimizationLevel 1 $
optimize md = do
CoreOptions {..} <- ask
withOptimizationLevel' md 1 $
return
. loopHoisting
. letFolding
. lambdaFolding
. letFolding
. caseFolding
. compose
2
( compose 2 (letFolding' (isInlineableLambda _optInliningDepth))
. lambdaFolding
. inlining' _optInliningDepth nonRecSyms
)
. letFolding
where
nonRecSyms = nonRecursiveIdents md

0 comments on commit ed93261

Please sign in to comment.