Skip to content

Commit

Permalink
Test for and fix agreement between compile-time and run-time numerica…
Browse files Browse the repository at this point in the history
…l casting.

There were two problems:

- We used to emit signed LLVM casting instructions regardless of the
  Dex type of the castee, which meant that casting (170 :: Word8) to
  Int32 (or any other upcast) would produce a negative number instead
  of 170.

- GHC between versions 7.8.3 and 9.2.2 (exclusive) rounds toward zero
  instead of to nearest when casting integral to floating-point types
  (https://gitlab.haskell.org/ghc/ghc/-/issues/17231).  This was
  disagreeing with LLVM and widely adopted convention, so I patched it
  at compile time.
  • Loading branch information
axch committed Jan 4, 2023
1 parent ce13b1c commit 287e272
Show file tree
Hide file tree
Showing 8 changed files with 248 additions and 93 deletions.
6 changes: 5 additions & 1 deletion dex.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ library
-- Serialization
, aeson
, store
-- Floating-point pedanticness (correcting for GHC < 9.2.2)
, floating-bits
if flag(live)
build-depends: binary
, blaze-html
Expand Down Expand Up @@ -305,14 +307,16 @@ test-suite spec
main-is: Spec.hs
hs-source-dirs: tests/unit
ghc-options: -Wall
-Wno-unticked-promoted-constructors
build-depends: base
, containers
, hspec
, mtl
, QuickCheck
, text
, dex
other-modules: OccAnalysisSpec
other-modules: ConstantCastingSpec
, OccAnalysisSpec
, OccurrenceSpec
, RawNameSpec
default-language: Haskell2010
Expand Down
52 changes: 34 additions & 18 deletions src/lib/ImpToLLVM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -449,28 +449,44 @@ compileInstr instr = case instr of
let sdt = case dt of L.VectorType _ sbt -> sbt; _ -> dt
case (sxt, sidt) of
-- if upcasting to unsigned int, use zext instruction
(L.IntegerType _, Scalar Word64Type) -> x `zeroExtendTo` dt
(L.IntegerType bits, Scalar Word64Type) | bits < 64 -> x `zeroExtendTo` dt
(L.IntegerType bits, Scalar Word32Type) | bits < 32 -> x `zeroExtendTo` dt
_ -> case (sxt, sdt) of
(L.IntegerType _, L.IntegerType _) -> x `asIntWidth` dt
(L.FloatingPointType fpt, L.FloatingPointType fpt') -> case compare fpt fpt' of
LT -> emitInstr dt $ L.FPExt x dt []
EQ -> return x
GT -> emitInstr dt $ L.FPTrunc x dt []
(L.FloatingPointType _, L.IntegerType _) -> emitInstr dt $ L.FPToSI x dt []
(L.IntegerType _, L.FloatingPointType _) -> emitInstr dt $ L.SIToFP x dt []
_ -> case (getIType ix, sdt) of
-- if upcasting from unsigned int, use zext instruction
(Scalar Word32Type, L.IntegerType bits) | bits > 32 -> x `zeroExtendTo` dt
(Scalar Word8Type, L.IntegerType bits) | bits > 8 -> x `zeroExtendTo` dt
_ -> case (sxt, sdt) of
(L.IntegerType _, L.IntegerType _) -> x `asIntWidth` dt
(L.FloatingPointType fpt, L.FloatingPointType fpt') -> case compare fpt fpt' of
LT -> emitInstr dt $ L.FPExt x dt []
EQ -> return x
GT -> emitInstr dt $ L.FPTrunc x dt []
(L.FloatingPointType _, L.IntegerType _) -> emitInstr dt $ float_to_int x dt []
(L.IntegerType _, L.FloatingPointType _) -> emitInstr dt $ int_to_float x dt []
#if MIN_VERSION_llvm_hs(15,0,0)
-- Pointee casts become no-ops, because LLVM uses opaque pointers
(L.PointerType a , L.PointerType a') | a == a' -> return x
(L.IntegerType 64, ptrTy@(L.PointerType _)) -> emitInstr ptrTy $ L.IntToPtr x ptrTy []
(L.PointerType _ , L.IntegerType 64) -> emitInstr i64 $ L.PtrToInt x i64 []
-- Pointee casts become no-ops, because LLVM uses opaque pointers
(L.PointerType a , L.PointerType a') | a == a' -> return x
(L.IntegerType 64, ptrTy@(L.PointerType _)) -> emitInstr ptrTy $ L.IntToPtr x ptrTy []
(L.PointerType _ , L.IntegerType 64) -> emitInstr i64 $ L.PtrToInt x i64 []
#else
(L.PointerType _ _, L.PointerType eltTy _) -> castLPtr eltTy x
(L.IntegerType 64 , ptrTy@(L.PointerType _ _)) ->
emitInstr ptrTy $ L.IntToPtr x ptrTy []
(L.PointerType _ _, L.IntegerType 64) -> emitInstr i64 $ L.PtrToInt x i64 []
(L.PointerType _ _, L.PointerType eltTy _) -> castLPtr eltTy x
(L.IntegerType 64 , ptrTy@(L.PointerType _ _)) ->
emitInstr ptrTy $ L.IntToPtr x ptrTy []
(L.PointerType _ _, L.IntegerType 64) -> emitInstr i64 $ L.PtrToInt x i64 []
#endif
_ -> error $ "Unsupported cast"
_ -> error $ "Unsupported cast"
where signed ty = case ty of
Scalar Int64Type -> True
Scalar Int32Type -> True
Scalar Word8Type -> False
Scalar Word32Type -> False
Scalar Word64Type -> False
Scalar Float64Type -> True
Scalar Float32Type -> True
Vector _ ty' -> signed (Scalar ty')
PtrType _ -> False
int_to_float = if signed (getIType ix) then L.SIToFP else L.UIToFP
float_to_int = if signed idt then L.FPToSI else L.FPToUI
IBitcastOp idt ix -> (:[]) <$> do
x <- compileExpr ix
let dt = scalarTy idt
Expand Down
171 changes: 99 additions & 72 deletions src/lib/Optimize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,14 @@
module Optimize
( earlyOptimize, optimize
, peepholeOp, hoistLoopInvariantDest, dceDestBlock
, foldCast
) where

import Data.Functor
import Data.Word
import Data.Bits
import Data.Bits.Floating
import Data.List
import Data.List.NonEmpty qualified as NE
import Control.Monad
import Control.Monad.State.Strict
Expand Down Expand Up @@ -86,78 +89,9 @@ unrollTrivialLoops b = liftM fst $ liftGenericTraverserM UTLS $ traverseGenericE

peepholeOp :: PrimOp (Atom SimpIR o) -> EnvReaderM o (Either (SAtom o) (PrimOp (Atom SimpIR o)))
peepholeOp op = case op of
MiscOp (CastOp (BaseTy (Scalar sTy)) (Con (Lit l))) -> return $ case sTy of
-- TODO: Check that the casts relating to floating-point agree with the
-- runtime behavior. The runtime is given by the `ICastOp` case in
-- ImpToLLVM.hs. We should make sure that the Haskell functions here
-- produce bitwise identical results to those instructions, by adjusting
-- either this or that as called for.
-- TODO: Also implement casts that may have unrepresentable results, i.e.,
-- casting floating-point numbers to smaller floating-point numbers or to
-- fixed-point. Both of these necessarily have a much smaller dynamic range.
Int32Type -> case l of
Int32Lit _ -> lit l
Int64Lit i -> lit $ Int32Lit $ fromIntegral i
Word8Lit i -> lit $ Int32Lit $ fromIntegral i
Word32Lit i -> lit $ Int32Lit $ fromIntegral i
Word64Lit i -> lit $ Int32Lit $ fromIntegral i
Float32Lit _ -> noop
Float64Lit _ -> noop
PtrLit _ _ -> noop
Int64Type -> case l of
Int32Lit i -> lit $ Int64Lit $ fromIntegral i
Int64Lit _ -> lit l
Word8Lit i -> lit $ Int64Lit $ fromIntegral i
Word32Lit i -> lit $ Int64Lit $ fromIntegral i
Word64Lit i -> lit $ Int64Lit $ fromIntegral i
Float32Lit _ -> noop
Float64Lit _ -> noop
PtrLit _ _ -> noop
Word8Type -> case l of
Int32Lit i -> lit $ Word8Lit $ fromIntegral i
Int64Lit i -> lit $ Word8Lit $ fromIntegral i
Word8Lit _ -> lit l
Word32Lit i -> lit $ Word8Lit $ fromIntegral i
Word64Lit i -> lit $ Word8Lit $ fromIntegral i
Float32Lit _ -> noop
Float64Lit _ -> noop
PtrLit _ _ -> noop
Word32Type -> case l of
Int32Lit i -> lit $ Word32Lit $ fromIntegral i
Int64Lit i -> lit $ Word32Lit $ fromIntegral i
Word8Lit i -> lit $ Word32Lit $ fromIntegral i
Word32Lit _ -> lit l
Word64Lit i -> lit $ Word32Lit $ fromIntegral i
Float32Lit _ -> noop
Float64Lit _ -> noop
PtrLit _ _ -> noop
Word64Type -> case l of
Int32Lit i -> lit $ Word64Lit $ fromIntegral (fromIntegral i :: Word32)
Int64Lit i -> lit $ Word64Lit $ fromIntegral i
Word8Lit i -> lit $ Word64Lit $ fromIntegral i
Word32Lit i -> lit $ Word64Lit $ fromIntegral i
Word64Lit _ -> lit l
Float32Lit _ -> noop
Float64Lit _ -> noop
PtrLit _ _ -> noop
Float32Type -> case l of
Int32Lit i -> lit $ Float32Lit $ fromIntegral (fromIntegral i :: Word32)
Int64Lit i -> lit $ Float32Lit $ fromIntegral i
Word8Lit i -> lit $ Float32Lit $ fromIntegral i
Word32Lit i -> lit $ Float32Lit $ fromIntegral i
Word64Lit i -> lit $ Float32Lit $ fromIntegral i
Float32Lit _ -> lit l
Float64Lit _ -> noop
PtrLit _ _ -> noop
Float64Type -> case l of
Int32Lit i -> lit $ Float64Lit $ fromIntegral (fromIntegral i :: Word32)
Int64Lit i -> lit $ Float64Lit $ fromIntegral i
Word8Lit i -> lit $ Float64Lit $ fromIntegral i
Word32Lit i -> lit $ Float64Lit $ fromIntegral i
Word64Lit i -> lit $ Float64Lit $ fromIntegral i
Float32Lit f -> lit $ Float64Lit $ float2Double f
Float64Lit _ -> lit l
PtrLit _ _ -> noop
MiscOp (CastOp (BaseTy (Scalar sTy)) (Con (Lit l))) -> return $ case foldCast sTy l of
Just l' -> lit l'
Nothing -> noop
-- TODO: Support more unary and binary ops.
BinOp IAdd l r -> return $ case (l, r) of
-- TODO: Shortcut when either side is zero.
Expand Down Expand Up @@ -203,6 +137,99 @@ peepholeOp op = case op of
LessEqual -> (<=)
GreaterEqual -> (>=)

-- Constant-fold a cast of the literal `l` to the scalar type `sTy`.
--
-- Returns `Just` the converted literal when the cast is evaluated here, and
-- `Nothing` when it is left for the runtime: every cast from a pointer
-- literal, every floating-point to integer cast, and Float64 to Float32.
foldCast :: ScalarBaseType -> LitVal -> Maybe LitVal
foldCast sTy l = case sTy of
  -- TODO: Check that the casts relating to floating-point agree with the
  -- runtime behavior. The runtime is given by the `ICastOp` case in
  -- ImpToLLVM.hs. We should make sure that the Haskell functions here
  -- produce bitwise identical results to those instructions, by adjusting
  -- either this or that as called for.
  -- TODO: Also implement casts that may have unrepresentable results, i.e.,
  -- casting floating-point numbers to smaller floating-point numbers or to
  -- fixed-point. Both of these necessarily have a much smaller dynamic range.
  Int32Type -> case l of
    Int32Lit _   -> Just l
    Int64Lit i   -> Just $ Int32Lit $ fromIntegral i
    Word8Lit i   -> Just $ Int32Lit $ fromIntegral i
    Word32Lit i  -> Just $ Int32Lit $ fromIntegral i
    Word64Lit i  -> Just $ Int32Lit $ fromIntegral i
    Float32Lit _ -> Nothing
    Float64Lit _ -> Nothing
    PtrLit _ _   -> Nothing
  Int64Type -> case l of
    Int32Lit i   -> Just $ Int64Lit $ fromIntegral i
    Int64Lit _   -> Just l
    Word8Lit i   -> Just $ Int64Lit $ fromIntegral i
    Word32Lit i  -> Just $ Int64Lit $ fromIntegral i
    Word64Lit i  -> Just $ Int64Lit $ fromIntegral i
    Float32Lit _ -> Nothing
    Float64Lit _ -> Nothing
    PtrLit _ _   -> Nothing
  Word8Type -> case l of
    Int32Lit i   -> Just $ Word8Lit $ fromIntegral i
    Int64Lit i   -> Just $ Word8Lit $ fromIntegral i
    Word8Lit _   -> Just l
    Word32Lit i  -> Just $ Word8Lit $ fromIntegral i
    Word64Lit i  -> Just $ Word8Lit $ fromIntegral i
    Float32Lit _ -> Nothing
    Float64Lit _ -> Nothing
    PtrLit _ _   -> Nothing
  Word32Type -> case l of
    Int32Lit i   -> Just $ Word32Lit $ fromIntegral i
    Int64Lit i   -> Just $ Word32Lit $ fromIntegral i
    Word8Lit i   -> Just $ Word32Lit $ fromIntegral i
    Word32Lit _  -> Just l
    Word64Lit i  -> Just $ Word32Lit $ fromIntegral i
    Float32Lit _ -> Nothing
    Float64Lit _ -> Nothing
    PtrLit _ _   -> Nothing
  Word64Type -> case l of
    -- Going through Word32 zero-extends the Int32 bit pattern instead of
    -- sign-extending it.
    -- NOTE(review): presumably this mirrors the zero-extension used by the
    -- `ICastOp` case in ImpToLLVM.hs -- confirm.
    Int32Lit i   -> Just $ Word64Lit $ fromIntegral (fromIntegral i :: Word32)
    Int64Lit i   -> Just $ Word64Lit $ fromIntegral i
    Word8Lit i   -> Just $ Word64Lit $ fromIntegral i
    Word32Lit i  -> Just $ Word64Lit $ fromIntegral i
    Word64Lit _  -> Just l
    Float32Lit _ -> Nothing
    Float64Lit _ -> Nothing
    PtrLit _ _   -> Nothing
  Float32Type -> case l of
    -- Every Word8 is exactly representable as a Float32, so only the wider
    -- integer types need the rounding correction.
    Int32Lit i   -> Just $ Float32Lit $ fixUlp i $ fromIntegral i
    Int64Lit i   -> Just $ Float32Lit $ fixUlp i $ fromIntegral i
    Word8Lit i   -> Just $ Float32Lit $ fromIntegral i
    Word32Lit i  -> Just $ Float32Lit $ fixUlp i $ fromIntegral i
    Word64Lit i  -> Just $ Float32Lit $ fixUlp i $ fromIntegral i
    Float32Lit _ -> Just l
    Float64Lit _ -> Nothing
    PtrLit _ _   -> Nothing
  Float64Type -> case l of
    -- Integers of 32 bits or fewer are exactly representable as a Float64,
    -- so only the 64-bit integer types need the rounding correction.
    Int32Lit i   -> Just $ Float64Lit $ fromIntegral i
    Int64Lit i   -> Just $ Float64Lit $ fixUlp i $ fromIntegral i
    Word8Lit i   -> Just $ Float64Lit $ fromIntegral i
    Word32Lit i  -> Just $ Float64Lit $ fromIntegral i
    Word64Lit i  -> Just $ Float64Lit $ fixUlp i $ fromIntegral i
    Float32Lit f -> Just $ Float64Lit $ float2Double f
    Float64Lit _ -> Just l
    PtrLit _ _   -> Nothing
  where
    -- When casting an integer type to a floating-point type of lower precision
    -- (e.g., int32 to float32), GHC between 7.8.3 and 9.2.2 (exclusive) rounds
    -- toward zero, instead of rounding to nearest even like everybody else.
    -- See https://gitlab.haskell.org/ghc/ghc/-/issues/17231.
    --
    -- We patch this by comparing the candidate answer and its two adjacent
    -- floats against the original integer, and keeping whichever is closest.
    --
    -- The distance is computed exactly, in `Rational`.  Measuring it in the
    -- source integer type (by rounding the float back) is not safe: a
    -- neighbouring float can lie outside that type's range (e.g. 2^32 for a
    -- Word32 input), in which case both the conversion back and the
    -- subtraction wrap around and the wrong candidate wins.
    --
    -- Exact ties are broken toward the candidate whose mantissa is even,
    -- i.e. IEEE 754 round-half-to-even.  Adjacent floats always have
    -- mantissas of opposite parity, so this is unambiguous.
    fixUlp orig candidate =
      minimumBy (\ca cb -> key ca `compare` key cb)
                [candidate, nextDown candidate, nextUp candidate]
      where
        exact = toRational orig
        -- Smaller distance wins; on equal distance `False < True` prefers the
        -- even mantissa.
        key cand = ( abs (exact - toRational cand)
                   , odd (fst (decodeFloat cand)) )

peepholeExpr :: SExpr o -> EnvReaderM o (Either (SAtom o) (SExpr o))
peepholeExpr expr = case expr of
PrimOp op -> fmap PrimOp <$> peepholeOp op
Expand Down
4 changes: 2 additions & 2 deletions src/lib/TopLevel.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ module TopLevel (
EvalConfig (..), Topper, TopperM, runTopperM,
evalSourceBlock, evalSourceBlockRepl, OptLevel (..),
evalSourceText, TopStateEx (..), LibPath (..),
evalSourceBlockIO, loadCache, storeCache, clearCache,
evalSourceBlockIO, initTopState, loadCache, storeCache, clearCache,
ensureModuleLoaded, importModule,
loadObject, toCFunction) where
loadObject, toCFunction, evalLLVM) where

import Data.Foldable (toList)
import Data.Functor
Expand Down
1 change: 1 addition & 0 deletions stack-llvm-head.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ extra-deps:
- store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001
- store-core-0.4.4.4@sha256:a19098ca8419ea4f6f387790e942a7a5d0acf62fe1beff7662f098cfb611334c,1430
- th-utilities-0.2.4.1@sha256:b37d23c8bdabd678aee5a36dd4373049d4179e9a85f34eb437e9cd3f04f435ca,1869
- floating-bits-0.3.0.0@sha256:742bcfcbc21b8daffc995990ee2399ab49550e8f4dd0dff1732d18f57a064c83,2442

1 change: 1 addition & 0 deletions stack-macos.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ extra-deps:
- store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001
- store-core-0.4.4.4@sha256:a19098ca8419ea4f6f387790e942a7a5d0acf62fe1beff7662f098cfb611334c,1430
- th-utilities-0.2.4.1@sha256:b37d23c8bdabd678aee5a36dd4373049d4179e9a85f34eb437e9cd3f04f435ca,1869
- floating-bits-0.3.0.0@sha256:742bcfcbc21b8daffc995990ee2399ab49550e8f4dd0dff1732d18f57a064c83,2442

flags:
llvm-hs:
Expand Down
1 change: 1 addition & 0 deletions stack.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ extra-deps:
- store-0.7.8@sha256:0b604101fd5053b6d7d56a4ef4c2addf97f4e08fe8cd06b87ef86f958afef3ae,8001
- store-core-0.4.4.4@sha256:a19098ca8419ea4f6f387790e942a7a5d0acf62fe1beff7662f098cfb611334c,1430
- th-utilities-0.2.4.1@sha256:b37d23c8bdabd678aee5a36dd4373049d4179e9a85f34eb437e9cd3f04f435ca,1869
- floating-bits-0.3.0.0@sha256:742bcfcbc21b8daffc995990ee2399ab49550e8f4dd0dff1732d18f57a064c83,2442

nix:
enable: false
Expand Down
Loading

0 comments on commit 287e272

Please sign in to comment.