diff --git a/src/Juvix/Compiler/Core/Extra/Utils.hs b/src/Juvix/Compiler/Core/Extra/Utils.hs index c377000459..ef334ab043 100644 --- a/src/Juvix/Compiler/Core/Extra/Utils.hs +++ b/src/Juvix/Compiler/Core/Extra/Utils.hs @@ -640,3 +640,32 @@ isDirectlyRecursive md sym = ufold (\x xs -> or (x : xs)) go (lookupIdentifierNo go = \case NIdt Ident {..} -> _identSymbol == sym _ -> False + +-- Returns a map from symbols to their number of occurrences in the given node. +getSymbolsMap :: Module -> Node -> HashMap Symbol Int +getSymbolsMap md = gather go mempty + where + go :: HashMap Symbol Int -> Node -> HashMap Symbol Int + go acc = \case + NTyp TypeConstr {..} -> mapInc _typeConstrSymbol acc + NIdt Ident {..} -> mapInc _identSymbol acc + NCase Case {..} -> mapInc _caseInductive acc + NCtr Constr {..} + | Just ci <- lookupConstructorInfo' md _constrTag -> + mapInc (ci ^. constructorInductive) acc + _ -> acc + + mapInc :: Symbol -> HashMap Symbol Int -> HashMap Symbol Int + mapInc k = HashMap.insertWith (+) k 1 + +getTableSymbolsMap :: InfoTable -> HashMap Symbol Int +getTableSymbolsMap tab = + foldr + (HashMap.unionWith (+)) + mempty + (map (getSymbolsMap md) (HashMap.elems $ tab ^. identContext)) + where + md = emptyModule {_moduleInfoTable = tab} + +getModuleSymbolsMap :: Module -> HashMap Symbol Int +getModuleSymbolsMap = getTableSymbolsMap . computeCombinedInfoTable diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs index 32ef13fee1..097c627f1a 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Inlining.hs @@ -1,5 +1,6 @@ module Juvix.Compiler.Core.Transformation.Optimize.Inlining where +import Data.HashMap.Strict qualified as HashMap import Data.HashSet qualified as HashSet import Juvix.Compiler.Core.Data.BinderList qualified as BL import Juvix.Compiler.Core.Data.IdentDependencyInfo @@ -16,8 +17,8 @@ isInlineableLambda inlineDepth md bl node = case node of _ -> False -convertNode :: Int -> HashSet Symbol -> Module -> Node -> Node -convertNode inlineDepth nonRecSyms md = dmapL go +convertNode :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Node -> Node +convertNode inlineDepth nonRecSyms symOcc md = dmapL go where go :: BinderList Binder -> Node -> Node go bl node = case node of @@ -39,7 +40,9 @@ convertNode inlineDepth nonRecSyms md = dmapL go _ | HashSet.member _identSymbol nonRecSyms && length args >= argsNum - && isInlineableLambda inlineDepth md bl def -> + && ( HashMap.lookup _identSymbol symOcc == Just 1 + || isInlineableLambda inlineDepth md bl def + ) -> mkApps def args _ -> node @@ -58,7 +61,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go Just InlineNever -> node _ | HashSet.member _identSymbol nonRecSyms - && isImmediate md def -> + && argsNum == 0 + && ( HashMap.lookup _identSymbol symOcc == Just 1 + || isImmediate md def + ) -> def | otherwise -> node @@ -98,10 +104,10 @@ convertNode inlineDepth nonRecSyms md = dmapL go where (lamsNum, body) = unfoldLambdas' node -inlining' :: Int -> HashSet Symbol -> Module -> Module -inlining' inliningDepth nonRecSyms md = mapT (const (convertNode inliningDepth nonRecSyms md)) md +inlining' :: Int -> HashSet Symbol -> HashMap Symbol Int -> Module -> Module +inlining' inliningDepth nonRecSyms symOcc md = mapT (const (convertNode inliningDepth nonRecSyms symOcc md)) md inlining :: (Member (Reader CoreOptions) r) => Module -> Sem r Module inlining md = do d <- asks (^. optInliningDepth) - return $ inlining' d (nonRecursiveIdents md) md + return $ inlining' d (nonRecursiveIdents md) (getModuleSymbolsMap md) md diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs index 5e5aed33d7..1dcff35ea2 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/Main.hs @@ -1,6 +1,7 @@ module Juvix.Compiler.Core.Transformation.Optimize.Phase.Main where import Juvix.Compiler.Core.Data.IdentDependencyInfo +import Juvix.Compiler.Core.Extra.Utils (getTableSymbolsMap) import Juvix.Compiler.Core.Options import Juvix.Compiler.Core.Transformation.Base import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding @@ -39,6 +40,9 @@ optimize' opts@CoreOptions {..} md = nonRecsReachable :: HashSet Symbol nonRecsReachable = nonRecursiveReachableIdents' tab + symOcc :: HashMap Symbol Int + symOcc = getTableSymbolsMap tab + doConstantFolding :: Module -> Module doConstantFolding md' = constantFolding' opts nonRecs' tab' md' where @@ -48,7 +52,7 @@ optimize' opts@CoreOptions {..} md = | otherwise = nonRecsReachable doInlining :: Module -> Module - doInlining md' = inlining' _optInliningDepth nonRecs' md' + doInlining md' = inlining' _optInliningDepth nonRecs' symOcc md' where nonRecs' = if diff --git a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs index 0e73e1a7db..0fb00e8d28 100644 --- a/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs +++ b/src/Juvix/Compiler/Core/Transformation/Optimize/Phase/PreLifting.hs @@ -1,6 +1,7 @@ module Juvix.Compiler.Core.Transformation.Optimize.Phase.PreLifting where import Juvix.Compiler.Core.Data.IdentDependencyInfo +import Juvix.Compiler.Core.Extra.Utils import Juvix.Compiler.Core.Options import Juvix.Compiler.Core.Transformation.Base import Juvix.Compiler.Core.Transformation.Optimize.CaseFolding @@ -23,8 +24,9 @@ optimize md = do 2 ( compose 2 (letFolding' (isInlineableLambda _optInliningDepth)) . lambdaFolding - . inlining' _optInliningDepth nonRecSyms + . inlining' _optInliningDepth nonRecSyms symOcc ) . letFolding where nonRecSyms = nonRecursiveIdents md + symOcc = getModuleSymbolsMap md diff --git a/tests/Compilation/positive/test086.juvix b/tests/Compilation/positive/test086.juvix index 63094d7e15..ec261fc5cd 100644 --- a/tests/Compilation/positive/test086.juvix +++ b/tests/Compilation/positive/test086.juvix @@ -1,3 +1,4 @@ +-- Patterns in definitions module test086; import Stdlib.Prelude open;