Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inline non-recursive functions with only one call site #3204

Merged
merged 3 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/Juvix/Compiler/Core/Extra/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,32 @@ isDirectlyRecursive md sym = ufold (\x xs -> or (x : xs)) go (lookupIdentifierNo
go = \case
NIdt Ident {..} -> _identSymbol == sym
_ -> False

-- Returns a map from symbols to their number of occurrences in the given node.
getSymbolsMap :: Module -> Node -> HashMap Symbol Int
getSymbolsMap md = gather go mempty
where
go :: HashMap Symbol Int -> Node -> HashMap Symbol Int
go acc = \case
NTyp TypeConstr {..} -> mapInc _typeConstrSymbol acc
NIdt Ident {..} -> mapInc _identSymbol acc
NCase Case {..} -> mapInc _caseInductive acc
NCtr Constr {..}
| Just ci <- lookupConstructorInfo' md _constrTag ->
mapInc (ci ^. constructorInductive) acc
_ -> acc

mapInc :: Symbol -> HashMap Symbol Int -> HashMap Symbol Int
mapInc k = HashMap.insertWith (+) k 1

getTableSymbolsMap :: InfoTable -> HashMap Symbol Int
getTableSymbolsMap tab =
foldr
(HashMap.unionWith (+))
mempty
(map (getSymbolsMap md) (HashMap.elems $ tab ^. identContext))
where
md = emptyModule {_moduleInfoTable = tab}

getModuleSymbolsMap :: Module -> HashMap Symbol Int
getModuleSymbolsMap = getTableSymbolsMap . computeCombinedInfoTable
20 changes: 13 additions & 7 deletions src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Juvix.Compiler.Core.Transformation.Optimize.Inlining where

import Data.HashMap.Strict qualified as HashMap
import Data.HashSet qualified as HashSet
import Juvix.Compiler.Core.Data.BinderList qualified as BL
import Juvix.Compiler.Core.Data.IdentDependencyInfo
Expand All @@ -16,8 +17,8 @@ isInlineableLambda inlineDepth md bl node = case node of
_ ->
False

convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node
convertNode inlineDepth nonRecSyms md = dmapL go
convertNode :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Node -> Node
convertNode inlineDepth nonRecSyms symOcc md = dmapL go
where
go :: BinderList Binder -> Node -> Node
go bl node = case node of
Expand All @@ -39,7 +40,9 @@ convertNode inlineDepth nonRecSyms md = dmapL go
_
| HashSet.member _identSymbol nonRecSyms
&& length args >= argsNum
&& isInlineableLambda inlineDepth md bl def ->
&& ( HashMap.lookup _identSymbol symOcc == Just 1
|| isInlineableLambda inlineDepth md bl def
) ->
mkApps def args
_ ->
node
Expand All @@ -58,7 +61,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
Just InlineNever -> node
_
| HashSet.member _identSymbol nonRecSyms
&& isImmediate md def ->
&& argsNum == 0
&& ( HashMap.lookup _identSymbol symOcc == Just 1
|| isImmediate md def
) ->
def
| otherwise ->
node
Expand Down Expand Up @@ -98,10 +104,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go
where
(lamsNum, body) = unfoldLambdas' node

inlining' :: Int -> HashSet Symbol -> Module -> Module
inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md
inlining' :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Module
inlining' inliningDepth nonRecSyms symOcc md = mapT (const (convertNode inliningDepth nonRecSyms symOcc md)) md

inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module
inlining md = do
d <- asks (^. optInliningDepth)
return $ inlining' d (nonRecursiveIdents md) md
return $ inlining' d (nonRecursiveIdents md) (getModuleSymbolsMap md) md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Juvix.Compiler.Core.Transformation.Optimize.Phase.Main where

import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Extra.Utils (getTableSymbolsMap)
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
Expand Down Expand Up @@ -39,6 +40,9 @@ optimize' opts@CoreOptions {..} md =
nonRecsReachable :: HashSet Symbol
nonRecsReachable = nonRecursiveReachableIdents' tab

symOcc :: HashMap Symbol Int
symOcc = getTableSymbolsMap tab

doConstantFolding :: Module -> Module
doConstantFolding md' = constantFolding' opts nonRecs' tab' md'
where
Expand All @@ -48,7 +52,7 @@ optimize' opts@CoreOptions {..} md =
| otherwise = nonRecsReachable

doInlining :: Module -> Module
doInlining md' = inlining' _optInliningDepth nonRecs' md'
doInlining md' = inlining' _optInliningDepth nonRecs' symOcc md'
where
nonRecs' =
if
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where

import Juvix.Compiler.Core.Data.IdentDependencyInfo
import Juvix.Compiler.Core.Extra.Utils
import Juvix.Compiler.Core.Options
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding
Expand All @@ -23,8 +24,9 @@ optimize md = do
2
( compose 2 (letFolding' (isInlineableLambda _optInliningDepth))
. lambdaFolding
. inlining' _optInliningDepth nonRecSyms
. inlining' _optInliningDepth nonRecSyms symOcc
)
. letFolding
where
nonRecSyms = nonRecursiveIdents md
symOcc = getModuleSymbolsMap md
1 change: 1 addition & 0 deletions tests/Compilation/positive/test086.juvix
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-- Patterns in definitions
module test086;

import Stdlib.Prelude open;
Expand Down
Loading