Skip to content

Commit

Permalink
Merge pull request google-research#1316 from axch/vectorize-user-inde…
Browse files Browse the repository at this point in the history
…x-sets

Vectorize through user-defined index sets
  • Loading branch information
axch authored Jul 7, 2023
2 parents 3fbcc02 + c8c0ae3 commit 4c47ab0
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 41 deletions.
23 changes: 20 additions & 3 deletions src/lib/ImpToLLVM.hs
Original file line number Diff line number Diff line change
Expand Up @@ -902,16 +902,33 @@ withWidthOfFP x template = case typeOf template of
L.FloatingPointType L.FloatFP -> litVal $ Float32Lit $ realToFrac x
_ -> error $ "Unsupported floating point type: " ++ show (typeOf template)

-- | The memory alignment (in bytes) that we can guarantee when accessing a
-- value of the given `L.Type` in a Dex array.  This is probably better
-- expressed in Dex types, but we would need to plumb them to do it that way.
-- 1-byte alignment should always be safe, but we can promise
-- higher-performance alignments for some types.
--
-- The result is always a power of two, which is what LLVM requires of an
-- explicit `align` on loads and stores.
dexAlignment :: L.Type -> Word32
dexAlignment = \case
  L.IntegerType bits
    -- Whole-byte integers are aligned on their own size, but only when that
    -- size is a power of two (so not, e.g., for a 24-bit integer).
    | bits `mod` 8 == 0, isPowerOfTwo (bits `div` 8) -> bits `div` 8
    | otherwise -> 1
  -- Matched with a record pattern so that this case does not fix the number
  -- of fields of `L.PointerType`.
  L.PointerType {} -> 4
  L.FloatingPointType L.FloatFP -> 4
  L.FloatingPointType L.DoubleFP -> 8
  -- A vector is only promised the alignment of its element type, not of the
  -- whole vector.
  L.VectorType _ eltTy -> dexAlignment eltTy
  _ -> 1
  where
    -- Computed in `Integer` so that the doubling sequence cannot wrap around.
    isPowerOfTwo :: Word32 -> Bool
    isPowerOfTwo n = n' `elem` takeWhile (<= n') (iterate (* 2) 1)
      where n' = toInteger n

-- Emit an LLVM `Store` instruction that writes `x` through `ptr`; the
-- instruction is added with `L.Do`, so no result is named.
-- NOTE(review): this text is a rendered diff, so two `store` equations appear
-- back to back. The first (alignment argument `0`) looks like the line the
-- commit removed; the second, which passes `dexAlignment (typeOf x)` as the
-- alignment, looks like its replacement -- confirm against the repository.
store :: LLVMBuilder m => Operand -> Operand -> m ()
store ptr x = addInstr $ L.Do $ L.Store False ptr x Nothing 0 []
store ptr x = addInstr $ L.Do $ L.Store False ptr x Nothing alignment [] where
alignment = dexAlignment $ typeOf x

-- Emit an LLVM `Load` instruction reading a value of type `pointeeTy` from
-- `ptr`, and return the operand holding the result.
-- The CPP branch chooses between two argument lists for `L.Load`: when
-- llvm-hs is at least 15.0.0 the pointee type is passed explicitly, otherwise
-- it is not.
-- NOTE(review): this text is a rendered diff, so each `emitInstr` call appears
-- twice. The variants passing `0` look like the removed lines and the ones
-- passing `alignment` (bound to `dexAlignment pointeeTy` below) look like
-- their replacements -- confirm against the repository.
load :: LLVMBuilder m => L.Type -> Operand -> m Operand
load pointeeTy ptr =
#if MIN_VERSION_llvm_hs(15,0,0)
emitInstr pointeeTy $ L.Load False pointeeTy ptr Nothing 0 []
emitInstr pointeeTy $ L.Load False pointeeTy ptr Nothing alignment []
#else
emitInstr pointeeTy $ L.Load False ptr Nothing 0 []
emitInstr pointeeTy $ L.Load False ptr Nothing alignment []
#endif
where alignment = dexAlignment pointeeTy

-- | Emit a signed less-than integer comparison (`ICmp` with `SLT`) of the two
-- operands; the result is an `i1`.
ilt :: LLVMBuilder m => Operand -> Operand -> m Operand
ilt lhs rhs = emitInstr i1 (L.ICmp IP.SLT lhs rhs [])
Expand Down
94 changes: 59 additions & 35 deletions src/lib/Vectorize.hs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,13 @@ import Util (allM, zipWithZ)
-- TODO: Local vector values? We might want to pack short and pure for loops into vectors,
-- to support things like float3 etc.
-- How a value behaves across the vectorized dimension, and how it is
-- represented as a result.
-- NOTE(review): this text is a rendered diff. The three constructors with a
-- trailing comment look like the lines the commit removed, and the
-- comment-above forms that follow look like their replacements -- confirm
-- against the repository.
data Stability
= Uniform -- constant across vectorized dimension
| Varying -- varying across vectorized dimension
| Contiguous -- varying, but contiguous across vectorized dimension
-- Constant across vectorized dimension, represented as a scalar
= Uniform
-- Varying across vectorized dimension, represented as a vector
| Varying
-- Varying, but contiguous across vectorized dimension; represented as a
-- scalar carrying the first value
| Contiguous
-- One stability per component of a product value (`vectorizeSeq` pairs
-- `ProdStability [Contiguous, ProdStability [Uniform]]` with a `PairVal`).
| ProdStability [Stability]
deriving (Eq, Show)

Expand Down Expand Up @@ -168,25 +172,27 @@ vectorizeLoopsExpr expr = do
narrowestTypeByteWidth <- getNarrowestTypeByteWidth =<< renameM expr
let loopWidth = vectorByteWidth `div` narrowestTypeByteWidth
case expr of
PrimOp (DAMOp (Seq effs dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal n))) dest body))
| n `mod` loopWidth == 0 -> (do
safe <- vectorSafeEffect effs
if safe
then (do
Distinct <- getDistinct
let vn = n `div` loopWidth
body' <- vectorizeSeq loopWidth body
dest' <- renameM dest
seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body'
return $ PrimOp $ DAMOp seqOp)
else renameM expr)
`catchErr` \errs -> do
let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr
ctx = mempty { messageCtx = [msg] }
errs' = prependCtxToErrs ctx errs
modify (<> LiftE errs')
recurSeq expr
PrimOp (DAMOp (Seq _ _ _ _ _)) -> recurSeq expr
PrimOp (DAMOp (Seq effs dir ixty dest body)) -> do
sz <- simplifyIxSize =<< renameM ixty
case sz of
Just n | n `mod` loopWidth == 0 -> (do
safe <- vectorSafeEffect effs
if safe
then (do
Distinct <- getDistinct
let vn = n `div` loopWidth
body' <- vectorizeSeq loopWidth ixty body
dest' <- renameM dest
seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body'
return $ PrimOp $ DAMOp seqOp)
else renameM expr)
`catchErr` \errs -> do
let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr
ctx = mempty { messageCtx = [msg] }
errs' = prependCtxToErrs ctx errs
modify (<> LiftE errs')
recurSeq expr
_ -> recurSeq expr
PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do
item' <- renameM item
itemTy <- return $ getType item'
Expand Down Expand Up @@ -218,6 +224,15 @@ vectorizeLoopsExpr expr = do
return $ PrimOp $ DAMOp $ Seq effs' dir ixty' dest' body'
recurSeq _ = error "Impossible"

-- | Try to determine the size of an index type as a literal.
--
-- Builds a block that applies the `Size` method of the index type's
-- dictionary, runs `cheapReduce` on that block, and yields `Just` the size
-- only when the reduction produces an `IdxRepVal` literal; every other
-- outcome yields `Nothing`.
simplifyIxSize :: (EnvReader m, ScopableBuilder SimpIR m)
               => IxType SimpIR n -> m n (Maybe Word32)
simplifyIxSize ixty = do
  sizeBlock <- buildBlock $ applyIxMethod (sink $ ixTypeDict ixty) Size []
  reduced <- cheapReduce sizeBlock
  return $ case reduced of
    Just (IdxRepVal size) -> Just size
    _ -> Nothing
{-# INLINE simplifyIxSize #-}

-- Really we should check this by seeing whether there is an instance for a
-- `Commutative` class, or something like that, but for now just pattern-match
-- to detect scalar addition as the only monoid we recognize as commutative.
Expand Down Expand Up @@ -300,22 +315,27 @@ vectorSafeEffect (EffectRow effs NoTail) = allM safe $ eSetToList effs where
Nothing -> error $ "Handle " ++ pprint h ++ " not present in commute map?"
safe _ = return False

-- Vectorize the body of a `Seq` loop by a factor of `loopWidth`.
--
-- Given the loop's unary lambda, this builds a new unary lambda whose binder
-- has type `ProdTy [IdxRepTy, ref']` (the renamed second component of the
-- original binder's product type). Inside it, the incoming ordinal is
-- multiplied by `loopWidth`, the original binder is substituted with a pair
-- tagged `ProdStability [Contiguous, ProdStability [Uniform]]`, and the
-- original body is vectorized with `vectorizeBlock`.
--
-- Calls `error` if the binder type is not a two-element `ProdTy`, or if the
-- lambda is not a `UnaryLamExpr`.
--
-- NOTE(review): this text is a rendered diff, so removed and added lines are
-- interleaved: two type signatures, two heads for each equation, and two
-- `buildUnaryLamExpr` / `extendSubst` lines appear. The variants that take an
-- extra `ixty` argument and convert the ordinal with `UnsafeFromOrdinal` look
-- like the added lines -- confirm against the repository.
vectorizeSeq :: Word32 -> LamExpr SimpIR i -> TopVectorizeM i o (LamExpr SimpIR o)
vectorizeSeq loopWidth (UnaryLamExpr (b:>ty) body) = do
(_, ty') <- case ty of
ProdTy [ixTy, ref] -> do
ixTy' <- renameM ixTy
vectorizeSeq :: Word32 -> IxType SimpIR i -> LamExpr SimpIR i
-> TopVectorizeM i o (LamExpr SimpIR o)
vectorizeSeq loopWidth ixty (UnaryLamExpr (b:>ty) body) = do
newLoopTy <- case ty of
ProdTy [_ixType, ref] -> do
ref' <- renameM ref
return (ixTy', ProdTy [IdxRepTy, ref'])
return $ ProdTy [IdxRepTy, ref']
_ -> error "Unexpected seq binder type"
ixty' <- renameM ixty
liftVectorizeM loopWidth $
buildUnaryLamExpr (getNameHint b) ty' \ci -> do
-- XXX: we're assuming `Fin n` here
buildUnaryLamExpr (getNameHint b) newLoopTy \ci -> do
-- The per-tile loop iterates on `Fin`
(viOrd, dest) <- fromPair $ Var ci
iOrd <- imul viOrd $ IdxRepVal loopWidth
extendSubst (b @> VVal (ProdStability [Contiguous, ProdStability [Uniform]]) (PairVal iOrd dest)) $
-- TODO: It would be nice to cancel this UnsafeFromOrdinal with the
-- Ordinal that will be taken later when indexing, but that should
-- probably be a separate pass.
i <- applyIxMethod (sink $ ixTypeDict ixty') UnsafeFromOrdinal [iOrd]
extendSubst (b @> VVal (ProdStability [Contiguous, ProdStability [Uniform]]) (PairVal i dest)) $
vectorizeBlock body $> UnitVal
vectorizeSeq _ _ = error "expected a unary lambda expression"
vectorizeSeq _ _ _ = error "expected a unary lambda expression"

newtype VectorizeM i o a =
VectorizeM { runVectorizeM ::
Expand Down Expand Up @@ -467,9 +487,13 @@ vectorizePrimOp op = case op of
BinOp opk arg1 arg2 -> do
sx@(VVal vx x) <- vectorizeAtom arg1
sy@(VVal vy y) <- vectorizeAtom arg2
let v = case (vx, vy) of (Uniform, Uniform) -> Uniform; _ -> Varying
x' <- if vx /= v then ensureVarying sx else return x
y' <- if vy /= v then ensureVarying sy else return y
let v = case (opk, vx, vy) of
(_, Uniform, Uniform) -> Uniform
(IAdd, Uniform, Contiguous) -> Contiguous
(IAdd, Contiguous, Uniform) -> Contiguous
_ -> Varying
x' <- if v == Varying then ensureVarying sx else return x
y' <- if v == Varying then ensureVarying sy else return y
VVal v <$> emitOp (BinOp opk x' y')
MiscOp (CastOp tyArg arg) -> do
ty <- vectorizeType tyArg
Expand Down
41 changes: 38 additions & 3 deletions tests/opt-tests.dx
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,13 @@ _ = for i:(Fin 20) j:(Fin 4). ordinal j
"vectorizing int binary op"
-- CHECK-LABEL: vectorizing int binary op
%passes vect
_ = for i:(Fin 256). (n_to_i32 (ordinal i)) + 1
_ = for i:(Fin 256). (n_to_i32 (ordinal i)) * 2
-- CHECK: seq (RawFin 0x10)
-- CHECK: [[i0:v#[0-9]+]]:<16xInt32> = vbroadcast
-- CHECK: [[i1:v#[0-9]+]]:<16xInt32> = viota
-- CHECK: [[i2:v#[0-9]+]]:<16xInt32> = %iadd [[i0]] [[i1]]
-- CHECK: [[ones:v#[0-9]+]]:<16xInt32> = vbroadcast 1
-- CHECK: %iadd [[i2]] [[ones]]
-- CHECK: [[twos:v#[0-9]+]]:<16xInt32> = vbroadcast 2
-- CHECK: %imul [[i2]] [[twos]]

"vectorizing float binary op"
-- CHECK-LABEL: vectorizing float binary op
Expand Down Expand Up @@ -211,3 +211,38 @@ _ = yield_accum (AddMonoid Int32) \result.
-- CHECK: [[mat1:v#[0-9]+]]:<16xInt32> = vbroadcast
-- CHECK: [[prodj:v#[0-9]+]]:<16xInt32> = %imul [[mat1]] [[mat2j]]
-- CHECK: extend [[refj]] [[prodj]]

"vectorizing through the `tile` combinator and its funny index set"
-- CHECK-LABEL: vectorizing through the `tile` combinator and its funny index set

%passes vect
_ = yield_accum (AddMonoid Int32) \result.
tile((Fin 256), 32) \set.
for_ i:set.
ix = inject(i, to=(Fin 256))
result!ix += xs[ix]
-- CHECK: seq (RawFin 0x8)
-- CHECK: seq (RawFin 0x2)
-- CHECK: [[refix:v#[0-9]+]]:(Ref {{v#[0-9]+}} <16xInt32>) = vrefslice
-- CHECK: [[xsix:v#[0-9]+]]:<16xInt32> =
-- CHECK-NEXT: vslice
-- CHECK: extend [[refix]] [[xsix]]

"Non-aligned"
-- CHECK-LABEL: Non-aligned

-- This is a regression test. We are checking that Dex-side
-- vectorization does not end up assuming that arrays are aligned on
-- the size of the vectors, only on the size of the underlying
-- scalars.

non_aligned = for i:(Fin 7). for j:(Fin 257). +0

%passes llvm
_ = yield_accum (AddMonoid Int32) \result.
tile((Fin 257), 32) \set.
for_ i:set.
ix = inject(i, to=(Fin 257))
result!(6@(Fin 7))!ix += non_aligned[6@_][ix]
-- CHECK: load <16 x i32>, <16 x i32>* %"v#{{[0-9]+}}", align 4
-- CHECK: store <16 x i32> %"v#{{[0-9]+}}", <16 x i32>* %"v#{{[0-9]+}}", align 4

0 comments on commit 4c47ab0

Please sign in to comment.