Skip to content

Commit

Permalink
Implement constant folding for casts to floating-point types.
Browse files Browse the repository at this point in the history
Caveat: Did not check that they agree bit-for-bit with the runtime
behavior, but any discrepancies should be small.

Benefit:

- This ends up constant-folding the cast from the literal zero to the
  relevant floating-point type in the `Add` instance for floats.

- Therefore, that zero ends up being an `Atom` and not an expression
  when we build the `BaseMonoid` for adding floats, and so is inlined
  therein.

- Therefore, that zero is, in effect, rematerialized when doing AD.

- Therefore, an n by m by k tensor of those zeros is not stored by
  linearize of matmul, so the jvp-matmul benchmark runs some 30%
  faster.  Of course, it is still storing two other tensors it
  shouldn't store, but that's for another time.
  • Loading branch information
axch committed Jan 4, 2023
1 parent 92f313c commit ce13b1c
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 6 deletions.
29 changes: 27 additions & 2 deletions src/lib/Optimize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Data.Bits
import Data.List.NonEmpty qualified as NE
import Control.Monad
import Control.Monad.State.Strict
import GHC.Float

import Types.Core
import Types.Primitives
Expand Down Expand Up @@ -86,7 +87,14 @@ unrollTrivialLoops b = liftM fst $ liftGenericTraverserM UTLS $ traverseGenericE
peepholeOp :: PrimOp (Atom SimpIR o) -> EnvReaderM o (Either (SAtom o) (PrimOp (Atom SimpIR o)))
peepholeOp op = case op of
MiscOp (CastOp (BaseTy (Scalar sTy)) (Con (Lit l))) -> return $ case sTy of
-- TODO: Support all casts.
-- TODO: Check that the casts relating to floating-point agree with the
-- runtime behavior. The runtime is given by the `ICastOp` case in
-- ImpToLLVM.hs. We should make sure that the Haskell functions here
-- produce bitwise identical results to those instructions, by adjusting
-- either this or that as called for.
-- TODO: Also implement casts that may have unrepresentable results, i.e.,
-- casting floating-point numbers to smaller floating-point numbers or to
-- fixed-point. Both of these necessarily have a much smaller dynamic range.
Int32Type -> case l of
Int32Lit _ -> lit l
Int64Lit i -> lit $ Int32Lit $ fromIntegral i
Expand Down Expand Up @@ -132,7 +140,24 @@ peepholeOp op = case op of
Float32Lit _ -> noop
Float64Lit _ -> noop
PtrLit _ _ -> noop
_ -> noop
Float32Type -> case l of
Int32Lit i -> lit $ Float32Lit $ fromIntegral (fromIntegral i :: Word32)
Int64Lit i -> lit $ Float32Lit $ fromIntegral i
Word8Lit i -> lit $ Float32Lit $ fromIntegral i
Word32Lit i -> lit $ Float32Lit $ fromIntegral i
Word64Lit i -> lit $ Float32Lit $ fromIntegral i
Float32Lit _ -> lit l
Float64Lit _ -> noop
PtrLit _ _ -> noop
Float64Type -> case l of
Int32Lit i -> lit $ Float64Lit $ fromIntegral (fromIntegral i :: Word32)
Int64Lit i -> lit $ Float64Lit $ fromIntegral i
Word8Lit i -> lit $ Float64Lit $ fromIntegral i
Word32Lit i -> lit $ Float64Lit $ fromIntegral i
Word64Lit i -> lit $ Float64Lit $ fromIntegral i
Float32Lit f -> lit $ Float64Lit $ float2Double f
Float64Lit _ -> lit l
PtrLit _ _ -> noop
-- TODO: Support more unary and binary ops.
BinOp IAdd l r -> return $ case (l, r) of
-- TODO: Shortcut when either side is zero.
Expand Down
6 changes: 2 additions & 4 deletions tests/opt-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -138,20 +138,18 @@ _ = for i:(Fin 256). (n_to_i32 (ordinal i)) + 1

%passes vect
_ = for i:(Fin 256). (n_to_f32 (ordinal i)) + 1
-- CHECK: [[one:v#[0-9]+]]:Float32 = %cast Float32 0x1
-- CHECK: seq {{.*}} {{v#[0-9]+}}:(Fin 16 &
-- CHECK: [[i0:v#[0-9]+]]:<16xFloat32> = vbroadcast
-- CHECK: [[i1:v#[0-9]+]]:<16xFloat32> = viota
-- CHECK: [[i2:v#[0-9]+]]:<16xFloat32> = %fadd [[i0]] [[i1]]
-- CHECK: [[ones:v#[0-9]+]]:<16xFloat32> = vbroadcast [[one]]
-- CHECK: [[ones:v#[0-9]+]]:<16xFloat32> = vbroadcast 1.
-- CHECK: %fadd [[i2]] [[ones]]

%passes vect
_ = for i:(Fin 256). (n_to_f64 (ordinal i)) + 1
-- CHECK: [[one:v#[0-9]+]]:Float64 = %cast Float64 0x1
-- CHECK: seq {{.*}} {{v#[0-9]+}}:(Fin 32 &
-- CHECK: [[i0:v#[0-9]+]]:<8xFloat64> = vbroadcast
-- CHECK: [[i1:v#[0-9]+]]:<8xFloat64> = viota
-- CHECK: [[i2:v#[0-9]+]]:<8xFloat64> = %fadd [[i0]] [[i1]]
-- CHECK: [[ones:v#[0-9]+]]:<8xFloat64> = vbroadcast [[one]]
-- CHECK: [[ones:v#[0-9]+]]:<8xFloat64> = vbroadcast 1.
-- CHECK: %fadd [[i2]] [[ones]]

0 comments on commit ce13b1c

Please sign in to comment.