From 245eebb05889219e8b62fbf91851e7a5e38fd46a Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Fri, 29 Nov 2024 19:43:13 +0100 Subject: [PATCH 1/3] Inline non-recursive functions with a single call site --- src/Juvix/Compiler/Core/Extra/Utils.hs | 29 +++++++++++++++++++ .../Core/Transformation/Optimize/Inlining.hs | 19 +++++++----- .../Transformation/Optimize/Phase/Main.hs | 6 +++- 3 files changed, 46 insertions(+), 8 deletions(-) diff --git a/src/Juvix/Compiler/Core/Extra/Utils.hs b/src/Juvix/Compiler/Core/Extra/Utils.hs index c377000459..ef334ab043 100644 --- a/src/Juvix/Compiler/Core/Extra/Utils.hs +++ b/src/Juvix/Compiler/Core/Extra/Utils.hs @@ -640,3 +640,32 @@ isDirectlyRecursive md sym = ufold (\x xs -> or (x : xs)) go (lookupIdentifierNo go = \case NIdt Ident {..} -> _identSymbol == sym _ -> False + +-- Returns a map from symbols to their number of occurrences in the given node. +getSymbolsMap :: Module -> Node -> HashMap Symbol Int +getSymbolsMap md = gather go mempty + where + go :: HashMap Symbol Int -> Node -> HashMap Symbol Int + go acc = \case + NTyp TypeConstr {..} -> mapInc _typeConstrSymbol acc + NIdt Ident {..} -> mapInc _identSymbol acc + NCase Case {..} -> mapInc _caseInductive acc + NCtr Constr {..} + | Just ci <- lookupConstructorInfo' md _constrTag -> + mapInc (ci ^. constructorInductive) acc + _ -> acc + + mapInc :: Symbol -> HashMap Symbol Int -> HashMap Symbol Int + mapInc k = HashMap.insertWith (+) k 1 + +getTableSymbolsMap :: InfoTable -> HashMap Symbol Int +getTableSymbolsMap tab = + foldr + (HashMap.unionWith (+)) + mempty + (map (getSymbolsMap md) (HashMap.elems $ tab ^. identContext)) + where + md = emptyModule {_moduleInfoTable = tab} + +getModuleSymbolsMap :: Module -> HashMap Symbol Int +getModuleSymbolsMap = getTableSymbolsMap . computeCombinedInfoTable diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs index 32ef13fee1..37dde9ef84 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs @@ -1,5 +1,6 @@ module Juvix.Compiler.Core.Transformation.Optimize.Inlining where +import Data.HashMap.Strict qualified as HashMap import Data.HashSet qualified as HashSet import Juvix.Compiler.Core.Data.BinderList qualified as BL import Juvix.Compiler.Core.Data.IdentDependencyInfo @@ -16,8 +17,8 @@ isInlineableLambda inlineDepth md bl node = case node of _ -> False -convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node -convertNode inlineDepth nonRecSyms md = dmapL go +convertNode :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Node -> Node +convertNode inlineDepth nonRecSyms symOcc md = dmapL go where go :: BinderList Binder -> Node -> Node go bl node = case node of @@ -39,7 +40,9 @@ convertNode inlineDepth nonRecSyms md = dmapL go _ | HashSet.member _identSymbol nonRecSyms && length args >= argsNum - && isInlineableLambda inlineDepth md bl def -> + && ( HashMap.lookup _identSymbol symOcc == Just 1 + || isInlineableLambda inlineDepth md bl def + ) -> mkApps def args _ -> node @@ -58,7 +61,9 @@ convertNode inlineDepth nonRecSyms md = dmapL go Just InlineNever -> node _ | HashSet.member _identSymbol nonRecSyms - && isImmediate md def -> + && ( HashMap.lookup _identSymbol symOcc == Just 1 + || isImmediate md def + ) -> def | otherwise -> node @@ -98,10 +103,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go where (lamsNum, body) = unfoldLambdas' node -inlining' :: Int -> HashSet Symbol -> Module -> Module -inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md +inlining' :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Module +inlining' inliningDepth nonRecSyms symOcc md = mapT (const (convertNode inliningDepth nonRecSyms symOcc md)) md inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module inlining md = do d <- asks (^. optInliningDepth) - return $ inlining' d (nonRecursiveIdents md) md + return $ inlining' d (nonRecursiveIdents md) (getModuleSymbolsMap md) md diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs index 5e5aed33d7..1dcff35ea2 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs @@ -1,6 +1,7 @@ module Juvix.Compiler.Core.Transformation.Optimize.Phase.Main where import Juvix.Compiler.Core.Data.IdentDependencyInfo +import Juvix.Compiler.Core.Extra.Utils (getTableSymbolsMap) import Juvix.Compiler.Core.Options import Juvix.Compiler.Core.Transformation.Base import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding @@ -39,6 +40,9 @@ optimize' opts@CoreOptions {..} md = nonRecsReachable :: HashSet Symbol nonRecsReachable = nonRecursiveReachableIdents' tab + symOcc :: HashMap Symbol Int + symOcc = getTableSymbolsMap tab + doConstantFolding :: Module -> Module doConstantFolding md' = constantFolding' opts nonRecs' tab' md' where @@ -48,7 +52,7 @@ optimize' opts@CoreOptions {..} md = | otherwise = nonRecsReachable doInlining :: Module -> Module - doInlining md' = inlining' _optInliningDepth nonRecs' md' + doInlining md' = inlining' _optInliningDepth nonRecs' symOcc md' where nonRecs' = if From fc642a0a8f1513970a463b1197f2e0753b84675a Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Sun, 1 Dec 2024 11:18:43 +0100 Subject: [PATCH 2/3] fix compilation after rebase --- .../Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs index 0e73e1a7db..0fb00e8d28 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs @@ -1,6 +1,7 @@ module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where import Juvix.Compiler.Core.Data.IdentDependencyInfo +import Juvix.Compiler.Core.Extra.Utils import Juvix.Compiler.Core.Options import Juvix.Compiler.Core.Transformation.Base import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding @@ -23,8 +24,9 @@ optimize md = do 2 ( compose 2 (letFolding' (isInlineableLambda _optInliningDepth)) . lambdaFolding - . inlining' _optInliningDepth nonRecSyms + . inlining' _optInliningDepth nonRecSyms symOcc ) . letFolding where nonRecSyms = nonRecursiveIdents md + symOcc = getModuleSymbolsMap md From fedb895ce8a77961413f2e6bd1bbcbeb92f1191f Mon Sep 17 00:00:00 2001 From: Lukasz Czajka Date: Sun, 1 Dec 2024 12:10:04 +0100 Subject: [PATCH 3/3] inline only fully applied --- src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs | 1 + tests/Compilation/positive/test086.juvix | 1 + 2 files changed, 2 insertions(+) diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs index 37dde9ef84..097c627f1a 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs @@ -61,6 +61,7 @@ convertNode inlineDepth nonRecSyms symOcc md = dmapL go Just InlineNever -> node _ | HashSet.member _identSymbol nonRecSyms + && argsNum == 0 && ( HashMap.lookup _identSymbol symOcc == Just 1 || isImmediate md def ) -> diff --git a/tests/Compilation/positive/test086.juvix b/tests/Compilation/positive/test086.juvix index 63094d7e15..ec261fc5cd 100644 --- a/tests/Compilation/positive/test086.juvix +++ b/tests/Compilation/positive/test086.juvix @@ -1,3 +1,4 @@ +-- Patterns in definitions module test086; import Stdlib.Prelude open;