From 287e272517738d7d36fcec64d69934124cc251f1 Mon Sep 17 00:00:00 2001 From: Alexey Radul Date: Wed, 4 Jan 2023 11:59:52 -0500 Subject: [PATCH] Test for and fix agreement between compile-time and run-time numerical casting. There were two problems: - We used to emit signed LLVM casting instructions regardless of the Dex type of the castee, which meant that casting (170 :: Word8) to Int32 (or any other upcast) would produce a negative number instead of 170. - GHC between versions 7.8.3 and 9.2.2 (exclusive) rounds toward zero instead of to nearest when casting integral to floating-point types (https://gitlab.haskell.org/ghc/ghc/-/issues/17231). This was disagreeing with LLVM and widely adopted convention, so I patched it at compile time. --- dex.cabal | 6 +- src/lib/ImpToLLVM.hs | 52 +++++---- src/lib/Optimize.hs | 171 +++++++++++++++++------------- src/lib/TopLevel.hs | 4 +- stack-llvm-head.yaml | 1 + stack-macos.yaml | 1 + stack.yaml | 1 + tests/unit/ConstantCastingSpec.hs | 105 ++++++++++++++++++ 8 files changed, 248 insertions(+), 93 deletions(-) create mode 100644 tests/unit/ConstantCastingSpec.hs diff --git a/dex.cabal b/dex.cabal index 0010902f5..708ea234a 100644 --- a/dex.cabal +++ b/dex.cabal @@ -129,6 +129,8 @@ library -- Serialization , aeson , store + -- Floating-point pedanticness (correcting for GHC < 9.2.2) + , floating-bits if flag(live) build-depends: binary , blaze-html @@ -305,6 +307,7 @@ test-suite spec main-is: Spec.hs hs-source-dirs: tests/unit ghc-options: -Wall + -Wno-unticked-promoted-constructors build-depends: base , containers , hspec @@ -312,7 +315,8 @@ test-suite spec , QuickCheck , text , dex - other-modules: OccAnalysisSpec + other-modules: ConstantCastingSpec + , OccAnalysisSpec , OccurrenceSpec , RawNameSpec default-language: Haskell2010 diff --git a/src/lib/ImpToLLVM.hs b/src/lib/ImpToLLVM.hs index 7dece5a43..501657b68 100644 --- a/src/lib/ImpToLLVM.hs +++ b/src/lib/ImpToLLVM.hs @@ -449,28 +449,44 @@ compileInstr instr = case instr 
of let sdt = case dt of L.VectorType _ sbt -> sbt; _ -> dt case (sxt, sidt) of -- if upcasting to unsigned int, use zext instruction - (L.IntegerType _, Scalar Word64Type) -> x `zeroExtendTo` dt + (L.IntegerType bits, Scalar Word64Type) | bits < 64 -> x `zeroExtendTo` dt (L.IntegerType bits, Scalar Word32Type) | bits < 32 -> x `zeroExtendTo` dt - _ -> case (sxt, sdt) of - (L.IntegerType _, L.IntegerType _) -> x `asIntWidth` dt - (L.FloatingPointType fpt, L.FloatingPointType fpt') -> case compare fpt fpt' of - LT -> emitInstr dt $ L.FPExt x dt [] - EQ -> return x - GT -> emitInstr dt $ L.FPTrunc x dt [] - (L.FloatingPointType _, L.IntegerType _) -> emitInstr dt $ L.FPToSI x dt [] - (L.IntegerType _, L.FloatingPointType _) -> emitInstr dt $ L.SIToFP x dt [] + _ -> case (getIType ix, sdt) of + -- if upcasting from unsigned int, use zext instruction + (Scalar Word32Type, L.IntegerType bits) | bits > 32 -> x `zeroExtendTo` dt + (Scalar Word8Type, L.IntegerType bits) | bits > 8 -> x `zeroExtendTo` dt + _ -> case (sxt, sdt) of + (L.IntegerType _, L.IntegerType _) -> x `asIntWidth` dt + (L.FloatingPointType fpt, L.FloatingPointType fpt') -> case compare fpt fpt' of + LT -> emitInstr dt $ L.FPExt x dt [] + EQ -> return x + GT -> emitInstr dt $ L.FPTrunc x dt [] + (L.FloatingPointType _, L.IntegerType _) -> emitInstr dt $ float_to_int x dt [] + (L.IntegerType _, L.FloatingPointType _) -> emitInstr dt $ int_to_float x dt [] #if MIN_VERSION_llvm_hs(15,0,0) - -- Pointee casts become no-ops, because LLVM uses opaque pointers - (L.PointerType a , L.PointerType a') | a == a' -> return x - (L.IntegerType 64, ptrTy@(L.PointerType _)) -> emitInstr ptrTy $ L.IntToPtr x ptrTy [] - (L.PointerType _ , L.IntegerType 64) -> emitInstr i64 $ L.PtrToInt x i64 [] + -- Pointee casts become no-ops, because LLVM uses opaque pointers + (L.PointerType a , L.PointerType a') | a == a' -> return x + (L.IntegerType 64, ptrTy@(L.PointerType _)) -> emitInstr ptrTy $ L.IntToPtr x ptrTy [] + (L.PointerType 
_ , L.IntegerType 64) -> emitInstr i64 $ L.PtrToInt x i64 [] #else - (L.PointerType _ _, L.PointerType eltTy _) -> castLPtr eltTy x - (L.IntegerType 64 , ptrTy@(L.PointerType _ _)) -> - emitInstr ptrTy $ L.IntToPtr x ptrTy [] - (L.PointerType _ _, L.IntegerType 64) -> emitInstr i64 $ L.PtrToInt x i64 [] + (L.PointerType _ _, L.PointerType eltTy _) -> castLPtr eltTy x + (L.IntegerType 64 , ptrTy@(L.PointerType _ _)) -> + emitInstr ptrTy $ L.IntToPtr x ptrTy [] + (L.PointerType _ _, L.IntegerType 64) -> emitInstr i64 $ L.PtrToInt x i64 [] #endif - _ -> error $ "Unsupported cast" + _ -> error $ "Unsupported cast" + where signed ty = case ty of + Scalar Int64Type -> True + Scalar Int32Type -> True + Scalar Word8Type -> False + Scalar Word32Type -> False + Scalar Word64Type -> False + Scalar Float64Type -> True + Scalar Float32Type -> True + Vector _ ty' -> signed (Scalar ty') + PtrType _ -> False + int_to_float = if signed (getIType ix) then L.SIToFP else L.UIToFP + float_to_int = if signed idt then L.FPToSI else L.FPToUI IBitcastOp idt ix -> (:[]) <$> do x <- compileExpr ix let dt = scalarTy idt diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 6c98ec073..6a8fc2ef2 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -9,11 +9,14 @@ module Optimize ( earlyOptimize, optimize , peepholeOp, hoistLoopInvariantDest, dceDestBlock + , foldCast ) where import Data.Functor import Data.Word import Data.Bits +import Data.Bits.Floating +import Data.List import Data.List.NonEmpty qualified as NE import Control.Monad import Control.Monad.State.Strict @@ -86,78 +89,9 @@ unrollTrivialLoops b = liftM fst $ liftGenericTraverserM UTLS $ traverseGenericE peepholeOp :: PrimOp (Atom SimpIR o) -> EnvReaderM o (Either (SAtom o) (PrimOp (Atom SimpIR o))) peepholeOp op = case op of - MiscOp (CastOp (BaseTy (Scalar sTy)) (Con (Lit l))) -> return $ case sTy of - -- TODO: Check that the casts relating to floating-point agree with the - -- runtime behavior. 
The runtime is given by the `ICastOp` case in - -- ImpToLLVM.hs. We should make sure that the Haskell functions here - -- produce bitwise identical results to those instructions, by adjusting - -- either this or that as called for. - -- TODO: Also implement casts that may have unrepresentable results, i.e., - -- casting floating-point numbers to smaller floating-point numbers or to - -- fixed-point. Both of these necessarily have a much smaller dynamic range. - Int32Type -> case l of - Int32Lit _ -> lit l - Int64Lit i -> lit $ Int32Lit $ fromIntegral i - Word8Lit i -> lit $ Int32Lit $ fromIntegral i - Word32Lit i -> lit $ Int32Lit $ fromIntegral i - Word64Lit i -> lit $ Int32Lit $ fromIntegral i - Float32Lit _ -> noop - Float64Lit _ -> noop - PtrLit _ _ -> noop - Int64Type -> case l of - Int32Lit i -> lit $ Int64Lit $ fromIntegral i - Int64Lit _ -> lit l - Word8Lit i -> lit $ Int64Lit $ fromIntegral i - Word32Lit i -> lit $ Int64Lit $ fromIntegral i - Word64Lit i -> lit $ Int64Lit $ fromIntegral i - Float32Lit _ -> noop - Float64Lit _ -> noop - PtrLit _ _ -> noop - Word8Type -> case l of - Int32Lit i -> lit $ Word8Lit $ fromIntegral i - Int64Lit i -> lit $ Word8Lit $ fromIntegral i - Word8Lit _ -> lit l - Word32Lit i -> lit $ Word8Lit $ fromIntegral i - Word64Lit i -> lit $ Word8Lit $ fromIntegral i - Float32Lit _ -> noop - Float64Lit _ -> noop - PtrLit _ _ -> noop - Word32Type -> case l of - Int32Lit i -> lit $ Word32Lit $ fromIntegral i - Int64Lit i -> lit $ Word32Lit $ fromIntegral i - Word8Lit i -> lit $ Word32Lit $ fromIntegral i - Word32Lit _ -> lit l - Word64Lit i -> lit $ Word32Lit $ fromIntegral i - Float32Lit _ -> noop - Float64Lit _ -> noop - PtrLit _ _ -> noop - Word64Type -> case l of - Int32Lit i -> lit $ Word64Lit $ fromIntegral (fromIntegral i :: Word32) - Int64Lit i -> lit $ Word64Lit $ fromIntegral i - Word8Lit i -> lit $ Word64Lit $ fromIntegral i - Word32Lit i -> lit $ Word64Lit $ fromIntegral i - Word64Lit _ -> lit l - Float32Lit _ -> noop - 
Float64Lit _ -> noop - PtrLit _ _ -> noop - Float32Type -> case l of - Int32Lit i -> lit $ Float32Lit $ fromIntegral (fromIntegral i :: Word32) - Int64Lit i -> lit $ Float32Lit $ fromIntegral i - Word8Lit i -> lit $ Float32Lit $ fromIntegral i - Word32Lit i -> lit $ Float32Lit $ fromIntegral i - Word64Lit i -> lit $ Float32Lit $ fromIntegral i - Float32Lit _ -> lit l - Float64Lit _ -> noop - PtrLit _ _ -> noop - Float64Type -> case l of - Int32Lit i -> lit $ Float64Lit $ fromIntegral (fromIntegral i :: Word32) - Int64Lit i -> lit $ Float64Lit $ fromIntegral i - Word8Lit i -> lit $ Float64Lit $ fromIntegral i - Word32Lit i -> lit $ Float64Lit $ fromIntegral i - Word64Lit i -> lit $ Float64Lit $ fromIntegral i - Float32Lit f -> lit $ Float64Lit $ float2Double f - Float64Lit _ -> lit l - PtrLit _ _ -> noop + MiscOp (CastOp (BaseTy (Scalar sTy)) (Con (Lit l))) -> return $ case foldCast sTy l of + Just l' -> lit l' + Nothing -> noop -- TODO: Support more unary and binary ops. BinOp IAdd l r -> return $ case (l, r) of -- TODO: Shortcut when either side is zero. @@ -203,6 +137,99 @@ peepholeOp op = case op of LessEqual -> (<=) GreaterEqual -> (>=) +foldCast :: ScalarBaseType -> LitVal -> Maybe LitVal +foldCast sTy l = case sTy of + -- TODO: Check that the casts relating to floating-point agree with the + -- runtime behavior. The runtime is given by the `ICastOp` case in + -- ImpToLLVM.hs. We should make sure that the Haskell functions here + -- produce bitwise identical results to those instructions, by adjusting + -- either this or that as called for. + -- TODO: Also implement casts that may have unrepresentable results, i.e., + -- casting floating-point numbers to smaller floating-point numbers or to + -- fixed-point. Both of these necessarily have a much smaller dynamic range. 
+ Int32Type -> case l of + Int32Lit _ -> Just l + Int64Lit i -> Just $ Int32Lit $ fromIntegral i + Word8Lit i -> Just $ Int32Lit $ fromIntegral i + Word32Lit i -> Just $ Int32Lit $ fromIntegral i + Word64Lit i -> Just $ Int32Lit $ fromIntegral i + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Int64Type -> case l of + Int32Lit i -> Just $ Int64Lit $ fromIntegral i + Int64Lit _ -> Just l + Word8Lit i -> Just $ Int64Lit $ fromIntegral i + Word32Lit i -> Just $ Int64Lit $ fromIntegral i + Word64Lit i -> Just $ Int64Lit $ fromIntegral i + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Word8Type -> case l of + Int32Lit i -> Just $ Word8Lit $ fromIntegral i + Int64Lit i -> Just $ Word8Lit $ fromIntegral i + Word8Lit _ -> Just l + Word32Lit i -> Just $ Word8Lit $ fromIntegral i + Word64Lit i -> Just $ Word8Lit $ fromIntegral i + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Word32Type -> case l of + Int32Lit i -> Just $ Word32Lit $ fromIntegral i + Int64Lit i -> Just $ Word32Lit $ fromIntegral i + Word8Lit i -> Just $ Word32Lit $ fromIntegral i + Word32Lit _ -> Just l + Word64Lit i -> Just $ Word32Lit $ fromIntegral i + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Word64Type -> case l of + Int32Lit i -> Just $ Word64Lit $ fromIntegral (fromIntegral i :: Word32) + Int64Lit i -> Just $ Word64Lit $ fromIntegral i + Word8Lit i -> Just $ Word64Lit $ fromIntegral i + Word32Lit i -> Just $ Word64Lit $ fromIntegral i + Word64Lit _ -> Just l + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Float32Type -> case l of + Int32Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i + Int64Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i + Word8Lit i -> Just $ Float32Lit $ fromIntegral i + Word32Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i + Word64Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i + Float32Lit _ -> Just l + 
Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Float64Type -> case l of + Int32Lit i -> Just $ Float64Lit $ fromIntegral i + Int64Lit i -> Just $ Float64Lit $ fixUlp i $ fromIntegral i + Word8Lit i -> Just $ Float64Lit $ fromIntegral i + Word32Lit i -> Just $ Float64Lit $ fromIntegral i + Word64Lit i -> Just $ Float64Lit $ fixUlp i $ fromIntegral i + Float32Lit f -> Just $ Float64Lit $ float2Double f + Float64Lit _ -> Just l + PtrLit _ _ -> Nothing + where + -- When casting an integer type to a floating-point type of lower precision + -- (e.g., int32 to float32), GHC between 7.8.3 and 9.2.2 (exclusive) rounds + -- toward zero, instead of rounding to nearest even like everybody else. + -- See https://gitlab.haskell.org/ghc/ghc/-/issues/17231. + -- + -- We patch this by manually checking the two adjacent floats to the + -- candidate answer, and using one of those if the reverse cast is closer + -- to the original input. + -- + -- This rounds to nearest. Empirically (see test suite), it also seems to + -- break ties the same way LLVM does, but I don't have a proof of that. + -- LLVM's tie-breaking may be system-specific? 
+ fixUlp orig candidate = closest [candidate, candidatem1, candidatep1] where + candidatem1 = nextDown candidate + candidatep1 = nextUp candidate + closest items = minimumBy (\ca cb -> err ca `compare` err cb) items + -- Compare in Integer: rounding back into the source type wraps for candidates past its maxBound. + err cand = abs (toInteger orig - round cand) + peepholeExpr :: SExpr o -> EnvReaderM o (Either (SAtom o) (SExpr o)) peepholeExpr expr = case expr of PrimOp op -> fmap PrimOp <$> peepholeOp op diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index 33387c7cb..30ba0abac 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -10,9 +10,9 @@ module TopLevel ( EvalConfig (..), Topper, TopperM, runTopperM, evalSourceBlock, evalSourceBlockRepl, OptLevel (..), evalSourceText, TopStateEx (..), LibPath (..), - evalSourceBlockIO, loadCache, storeCache, clearCache, + evalSourceBlockIO, initTopState, loadCache, storeCache, clearCache, ensureModuleLoaded, importModule, - loadObject, toCFunction) where + loadObject, toCFunction, evalLLVM) where import Data.Foldable (toList) import Data.Functor diff --git a/stack-llvm-head.yaml b/stack-llvm-head.yaml index 8dbb033f9..b89173b4e 100644 --- a/stack-llvm-head.yaml +++ b/stack-llvm-head.yaml @@ -22,4 +22,5 @@ extra-deps: - store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001 - store-core-0.4.4.4@sha256:a19098ca8419ea4f6f387790e942a7a5d0acf62fe1beff7662f098cfb611334c,1430 - th-utilities-0.2.4.1@sha256:b37d23c8bdabd678aee5a36dd4373049d4179e9a85f34eb437e9cd3f04f435ca,1869 + - floating-bits-0.3.0.0@sha256:742bcfcbc21b8daffc995990ee2399ab49550e8f4dd0dff1732d18f57a064c83,2442 diff --git a/stack-macos.yaml b/stack-macos.yaml index a038b0da3..14d9b29f5 100644 --- a/stack-macos.yaml +++ b/stack-macos.yaml @@ -20,6 +20,7 @@ extra-deps: - store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001 - store-core-0.4.4.4@sha256:a19098ca8419ea4f6f387790e942a7a5d0acf62fe1beff7662f098cfb611334c,1430 - 
th-utilities-0.2.4.1@sha256:b37d23c8bdabd678aee5a36dd4373049d4179e9a85f34eb437e9cd3f04f435ca,1869 + - floating-bits-0.3.0.0@sha256:742bcfcbc21b8daffc995990ee2399ab49550e8f4dd0dff1732d18f57a064c83,2442 flags: llvm-hs: diff --git a/stack.yaml b/stack.yaml index 3a485f768..bf9cce344 100644 --- a/stack.yaml +++ b/stack.yaml @@ -20,6 +20,7 @@ extra-deps: - store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001 - store-core-0.4.4.4@sha256:a19098ca8419ea4f6f387790e942a7a5d0acf62fe1beff7662f098cfb611334c,1430 - th-utilities-0.2.4.1@sha256:b37d23c8bdabd678aee5a36dd4373049d4179e9a85f34eb437e9cd3f04f435ca,1869 + - floating-bits-0.3.0.0@sha256:742bcfcbc21b8daffc995990ee2399ab49550e8f4dd0dff1732d18f57a064c83,2442 nix: enable: false diff --git a/tests/unit/ConstantCastingSpec.hs b/tests/unit/ConstantCastingSpec.hs new file mode 100644 index 000000000..688401481 --- /dev/null +++ b/tests/unit/ConstantCastingSpec.hs @@ -0,0 +1,105 @@ +-- Copyright 2022 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +{-# OPTIONS_GHC -Wno-orphans #-} + +module ConstantCastingSpec (spec) where + +import Control.Monad +import Test.Hspec +import Test.Hspec.QuickCheck +import Test.QuickCheck + +import Builder +import Core +import Lower +import Name +import Optimize +import TopLevel +import Types.Core +import Types.Imp +import Types.Primitives +import Util + +castOp :: ScalarBaseType -> LitVal -> PrimOp (SAtom VoidS) +castOp ty x = MiscOp $ CastOp (BaseTy (Scalar ty)) (Con (Lit x)) + +exprToBlock :: EnvReader m => Expr r n -> m n (Block r n) +exprToBlock expr = do + liftBuilder $ buildBlock $ do + v <- emit $ sink expr + return $ Var v + +evalBlock :: (Topper m, Mut n) => SBlock n -> m n (SAtom n) +evalBlock block = lowerFullySequential block >>= evalLLVM + +evalClosedExpr :: SExpr VoidS -> IO LitVal +evalClosedExpr expr = do + let 
cfg = EvalConfig LLVM [LibBuiltinPath] Nothing Nothing Nothing NoOptimize + env <- initTopState + fst <$> runTopperM cfg env do + block <- exprToBlock $ unsafeCoerceE expr + (Var name) <- evalBlock block + (AtomNameBinding (TopDataBound (RepVal _ (Leaf (ILit ans))))) <- lookupEnv name + return ans + +instance Arbitrary LitVal where + arbitrary = oneof + [ Int64Lit <$> arbitrary + , Int32Lit <$> arbitrary + , Word8Lit <$> arbitrary + , Word32Lit <$> arbitrary + , Word64Lit <$> arbitrary + , Float64Lit <$> arbitrary + , Float32Lit <$> arbitrary + ] + +spec :: Spec +spec = do + describe "constant-folding casts" do + let constant_folding_and_runtime_casts_agree ty = + \(x::LitVal) -> case foldCast ty x of + Nothing -> return () + Just folded -> do + let op = castOp ty x + evaled <- evalClosedExpr $ PrimOp op + -- The failure message will list `evaled` as "expected" and `folded` as "got" + folded `shouldBe` evaled + forM_ [Int64Type, Int32Type, Word8Type, Word32Type, Word64Type, Float64Type, Float32Type] \ty -> + -- TODO: We'd really like to run 10,000 or 1,000,000 examples here, but + -- status quo is that each one runs through the LLVM compile-and-run + -- pipeline separately, and is thus incredibly slow. Taking 50 as a + -- compromise between test suite speed and test coverage. + -- I ran this offline with 10,000 before checking in, and it passed. + modifyMaxSuccess (const 50) $ prop ("agrees with runtime when casting to " ++ show ty) $ constant_folding_and_runtime_casts_agree ty + it "agrees with runtime on rounding mode" $ + -- These values are chosen to tickle the difference between different + -- rounding modes when rounding to float32, and specifically between + -- breaking ties to even, to zero, or to +infinity when rounding to + -- nearest. 
+ -- + -- Specifically, these are 32-bit integers that are large enough not to be + -- exactly representable in 32-bit floating-point, whose low-order bits go + -- through every configuration that rounding behavior is sensitive to + -- (i.e., nearest float is larger, nearest float is smaller, exactly + -- between two floats with the previous bit being even, or odd). + forM_ [ Word32Lit 0x3000000 + , Word32Lit 0x3000001 + , Word32Lit 0x3000002 + , Word32Lit 0x3000003 + , Word32Lit 0x3000004 + , Word32Lit 0x3000005 + , Word32Lit 0x3000006 + , Word32Lit 0x3000007 + , Word32Lit 0x3000008 + , Word32Lit 0x3000009 + , Word32Lit 0x300000A + , Word32Lit 0x300000B + , Word32Lit 0x300000C + , Word32Lit 0x300000D + , Word32Lit 0x300000E + ] \val -> + constant_folding_and_runtime_casts_agree Float32Type val