Skip to content

Commit

Permalink
Inline non-recursive functions with only one call site (#3204)
Browse files Browse the repository at this point in the history
* Closes #3198
  • Loading branch information
lukaszcz authored Dec 4, 2024
1 parent af9679d commit c79f5e3
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 9 deletions.
29 changes: 29 additions & 0 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,32 @@ isDirectlyRecursive md sym = ufold (\x xs -> or (x : xs)) go (lookupIdentifierNo
go = \case
NIdt Ident {..} -> _identSymbol == sym
_ -> False

-- Returns a map from symbols to their number of occurrences in the given node.
getSymbolsMap :: Module -> Node -> HashMap Symbol Int
getSymbolsMap md = gather go mempty
where
go :: HashMap Symbol Int -> Node -> HashMap Symbol Int
go acc = \case
NTyp TypeConstr {..} -> mapInc _typeConstrSymbol acc
NIdt Ident {..} -> mapInc _identSymbol acc
NCase Case {..} -> mapInc _caseInductive acc
NCtr Constr {..}
| Just ci <- lookupConstructorInfo' md _constrTag ->
mapInc (ci ^. constructorInductive) acc
_ -> acc

mapInc :: Symbol -> HashMap Symbol Int -> HashMap Symbol Int
mapInc k = HashMap.insertWith (+) k 1

getTableSymbolsMap :: InfoTable -> HashMap Symbol Int
getTableSymbolsMap tab =
foldr
(HashMap.unionWith (+))
mempty
(map (getSymbolsMap md) (HashMap.elems $ tab ^. identContext))
where
md = emptyModule {_moduleInfoTable = tab}

getModuleSymbolsMap :: Module -> HashMap Symbol Int
getModuleSymbolsMap = getTableSymbolsMap . computeCombinedInfoTable
20 changes: 13 additions & 7 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Juvix.Compiler.Core.Transformation.Optimize.Inlining where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Data.IdentDependencyInfo
Expand All @@ -16,8 +17,8 @@ isInlineableLambda inlineDepth md bl node = case node of
_ ->
False

convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node
convertNode inlineDepth nonRecSyms md = dmapL go
convertNode :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Node -> Node
convertNode inlineDepth nonRecSyms symOcc md = dmapL go
where
go :: BinderList Binder -> Node -> Node
go bl node = case node of
Expand All @@ -39,7 +40,9 @@ convertNode inlineDepth nonRecSyms md = dmapL go
_
| HashSet.member _identSymbol nonRecSyms
&& length args >= argsNum
&& isInlineableLambda inlineDepth md bl def ->
&& ( HashMap.lookup _identSymbol symOcc == Just 1
|| isInlineableLambda inlineDepth md bl def
) ->
mkApps def args
_ ->
node
Expand All @@ -58,7 +61,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
Just InlineNever -> node
_
| HashSet.member _identSymbol nonRecSyms
&& isImmediate md def ->
&& argsNum == 0
&& ( HashMap.lookup _identSymbol symOcc == Just 1
|| isImmediate md def
) ->
def
| otherwise ->
node
Expand Down Expand Up @@ -98,10 +104,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
where
(lamsNum, body) = unfoldLambdas' node

inlining' :: Int -> HashSet Symbol -> Module -> Module
inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md
inlining' :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Module
inlining' inliningDepth nonRecSyms symOcc md = mapT (const (convertNode inliningDepth nonRecSyms symOcc md)) md

inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
inlining md = do
d <- asks (^. optInliningDepth)
return $ inlining' d (nonRecursiveIdents md) md
return $ inlining' d (nonRecursiveIdents md) (getModuleSymbolsMap md) md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Main where

import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Extra.Utils (getTableSymbolsMap)
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
Expand Down Expand Up @@ -39,6 +40,9 @@ optimize' opts@CoreOptions {..} md =
nonRecsReachable :: HashSet Symbol
nonRecsReachable = nonRecursiveReachableIdents' tab

symOcc :: HashMap Symbol Int
symOcc = getTableSymbolsMap tab

doConstantFolding :: Module -> Module
doConstantFolding md' = constantFolding' opts nonRecs' tab' md'
where
Expand All @@ -48,7 +52,7 @@ optimize' opts@CoreOptions {..} md =
| otherwise = nonRecsReachable

doInlining :: Module -> Module
doInlining md' = inlining' _optInliningDepth nonRecs' md'
doInlining md' = inlining' _optInliningDepth nonRecs' symOcc md'
where
nonRecs' =
if
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where

import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Extra.Utils
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
Expand All @@ -23,8 +24,9 @@ optimize md = do
2
( compose 2 (letFolding' (isInlineableLambda _optInliningDepth))
. lambdaFolding
. inlining' _optInliningDepth nonRecSyms
. inlining' _optInliningDepth nonRecSyms symOcc
)
. letFolding
where
nonRecSyms = nonRecursiveIdents md
symOcc = getModuleSymbolsMap md
1 change: 1 addition & 0 deletions tests/Compilation/positive/test086.juvix
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-- Patterns in definitions
module test086;

import Stdlib.Prelude open;
Expand Down

0 comments on commit c79f5e3

Please sign in to comment.