Relax Architecture Overview
Authors (alphabetical): @altanh, @electriclilies, @hypercubestart, @jroesch, @junrushao1994, @mbs-octoml, @mikepapadim, @tkonolige, @tqchen, @YuchenJin, @ZihengJiang
This doc serves as a high-level design overview of the key elements of the relax design. Relax is the codename for the effort to evolve the design of TVM's high-level IR. The main goal of the doc is to provide a concise but clear summary of the key motivations and design points, without getting into low-level, architecture-dependent details or advanced topics.
It is intended as a first doc for understanding relax at an architecture level. Please also refer to the other design docs for details on specific aspects.
Relax has three key goals motivated by our past lessons in ML acceleration.
Specifically, we need to support dynamic shape workloads that are currently out of reach for today's intermediate representations, and to optimize performance for both new and existing workloads.
(Partially) dynamic shape models are ubiquitous in today's machine learning workloads. The dynamism can come from variable input sizes, or simply from missing shape information in the program.
Most machine learning engineers are familiar with the "computational graph" abstraction and its optimizations, which assume that every operation in the graph has no side effects. While such optimizations are useful for the majority of programs, as we start to work with random numbers, state, and weight updates, we also need to be able to represent programs that contain more complex semantics, such as control flow, in-place updates, and side effects.
Additionally, some advanced optimizations require us to work with mutation, such as in-place updates for scatter-gather operations.
We need to find a way to enable most people to write computational graph optimizations, while still being able to represent these advanced semantics.
Right now TVM maintains a clear boundary between abstraction levels: Relay-to-TIR lowering is done in a single-shot translation. However, we are starting to see a strong need to perform optimizations across these layers. For example, the automation decisions in TensorIR should ideally inform fusion and layout decisions at the high level. This need comes up in our TensorCore auto-scheduling applications as well as NPU-related workloads.
```python
import tvm.script
from tvm.script import tir as T, relax as R

@tvm.script.ir_module
class MyIRModule:
    @T.prim_func
    def tir_exp_func(x: T.handle, y: T.handle):  ## <= D2
        X = T.match_buffer(x, (n,), "float32")
        Y = T.match_buffer(y, (n,), "float32")
        for i in T.grid(n):
            Y[i] = T.exp(X[i])

    @R.function
    def relax_func(x: R.Tensor[(n, k), "f32"], w: R.Tensor[_, "f32"]):
        # n, k above are implicitly defined by the signature
        # so we will be able to refer to n, k in the later part of the program
        with R.dataflow():  ### <= D0
            lv0 = R.match_shape(w, (k, m))  ## <= D1
            lv1: R.Tensor[(n, m), "f32"] = R.dot(x, lv0)
            lv2: R.Tensor[(n * m,), "f32"] = R.flatten(lv1)  ## <= D1
            lv3: R.Shape = (n * m,)  ## <= D1
            gv0: R.Tensor[lv3, "f32"] = R.call_tir(lv3, tir_exp_func, [lv2])  ## <= D2
            R.outputs(gv0)
        R.call_packed("custom_inplace_update", gv0)  ## <= D0, D2
        return gv0
```
We can use the above code snippet to demonstrate the key design points of relax. Note that the script syntax is still evolving and is subject to change.
The majority of the `relax_func` code is encapsulated in a `with R.dataflow()` construct. All operations inside the dataflow block are side-effect-free and contain no advanced control flow (such as if-then-else) or nested scopes.
A dataflow block can effectively be viewed as a computational graph embedded in the program. Note that most of the binding variables (lv0, lv1, lv2, lv3) within the dataflow block are "local", which means they are only visible within the block. These variables can be viewed as internal nodes of the computational graph. We can mark a variable as an output (gv0), in which case it will be visible in later parts of the program. These output variables can be viewed as the output nodes of the computational graph.
Note that `R.call_packed("custom_inplace_update", gv0)` is outside of the dataflow block. Everything outside a dataflow block can have side effects, so we cannot perform optimizations such as reordering these bindings according to topological order unless we do more careful analysis. We expect most optimizations to happen at the dataflow-block level. These optimizations can be written by ML engineers who are familiar with the computational graph concept. The ability to isolate and represent the effectful components also provides opportunities for more advanced optimizations in the places that need them.
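To make this concrete, below is a minimal, self-contained sketch in plain Python of the kind of rewrite we expect at the dataflow-block level. The `Binding` class and `fold_exp_log` helper are hypothetical stand-ins for illustration (not the relax pass API); the point is that bindings inside a block are side-effect-free, so a pass can match and replace them purely by data dependence.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Binding:
    """Hypothetical stand-in for a dataflow-block binding `var = op(*args)`."""
    var: str
    op: str
    args: List[str]

def fold_exp_log(block: List[Binding]) -> List[Binding]:
    """Rewrite y = exp(log(x)) into y = identity(x) within one dataflow block."""
    producer = {b.var: b for b in block}
    rewritten = []
    for b in block:
        inner = producer.get(b.args[0]) if b.args else None
        if b.op == "exp" and inner is not None and inner.op == "log":
            # safe because bindings inside a dataflow block have no side effects
            rewritten.append(Binding(b.var, "identity", [inner.args[0]]))
        else:
            rewritten.append(b)
    return rewritten

block = [Binding("lv0", "log", ["x"]), Binding("lv1", "exp", ["lv0"])]
assert fold_exp_log(block)[1] == Binding("lv1", "identity", ["x"])
```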
Shape deduction is essential for dynamic model workloads. Under a dynamic shape setting, we usually need to compute the shapes of the intermediate tensors before running the computation. Additionally, we also need to handle cases where the shape itself is data-dependent (e.g. unique). Finally, most dynamic shape workloads still contain a lot of (partially) static shapes; ideally, we want to take advantage of this static shape information for optimization.
```python
from tvm.script import relax as R

@R.function
def shape_example(x: R.Tensor[(n, 2, 2), "f32"]):
    with R.dataflow():
        # symbolic and static shape deduction
        lv0: R.Tensor[(n, 4), "f32"] = R.reshape(x, (n, 4))
        lv1: R.Tensor[(n * 4,), "f32"] = R.flatten(lv0)
        lv2: R.Shape = (n * 4,)
        # external opaque shape function
        lv3: R.Shape = R.call_packed("myshape_func", lv2)
        lv4: R.Tensor[lv3, "f32"] = R.call_tir(lv3, "custom_func", [lv1])
        # data dependent case
        lv5: R.Tensor[_, "f32"] = R.unique(lv4)
        # re-match shape
        lv6: R.Tensor[(m,), "f32"] = R.match_shape(lv5, (m,))
        gv0: R.Tensor[(m,), "f32"] = R.exp(lv6)
        R.outputs(gv0)
    return gv0
```
The above program covers the typical scenarios of shape deduction (marked in the comments). Importantly, shapes are now part of the computation, along with Tensor values. This reflects the fact that the computation of shapes can happen at runtime.
While the text-format type annotation `lv0: R.Tensor[(n, 4), "f32"]` shows the shape of each value, this is only syntactic sugar: from the IR's point of view, the shape field `(n, 4)` is not part of `lv0.checked_type`. The type of lv0 is `DynTensor(rank=2, dtype="f32")`; the shape is a special value field attached to each Expr. We made this explicit choice to simplify type inference, so that we do not need to get into full dependent typing.
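As a self-contained sketch of this separation (plain Python dataclasses with hypothetical names, not the real relax classes): the checked type carries only rank and dtype, while the possibly-symbolic shape lives in a separate value field.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class DynTensorType:
    """Hypothetical: the type only tracks rank and dtype."""
    rank: int
    dtype: str

@dataclass
class ExprInfo:
    """Hypothetical: the shape is a value field on the Expr, not part of the type."""
    checked_type: DynTensorType
    shape: Optional[Tuple]

# lv0 from the example above: the type knows rank/dtype, the shape field holds (n, 4)
lv0_info = ExprInfo(DynTensorType(rank=2, dtype="f32"), shape=("n", 4))
assert lv0_info.checked_type.rank == 2
```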
There are two key constructs related to symbolic shape computation:
D1a: match_shape
value = match_shape(lhs, pattern)
The match_shape construct takes an lhs value and a pattern (of symbolic integer expressions). It has two overloaded semantics:
- When lhs is a Tensor, it matches `lhs.shape` to the pattern, populates the corresponding symbolic integer variables if they occur in the pattern for the first time, and then returns a new Tensor that is the same as lhs but with the shape field updated to the pattern.
- lhs can also be a Shape that directly matches the pattern. This is useful when we want to isolate out shape functions that do not correspond to any Tensor value.
Examples
```python
from tvm.script import relax as R

@R.function
def shape_example(x: R.Tensor[_, "f32"], y: R.Tensor[_, "f32"]):
    with R.dataflow():
        # the match shape defines n, m because they appear for the first time
        lv0: R.Tensor[(n, m)] = R.match_shape(x, (n, m))
        # the second occurrence of n, m will translate into an assertion
        # that y's shape equals (n, m)
        lv1: R.Tensor[(n, m)] = R.match_shape(y, (n, m))
        # we can also call match_shape on shape expressions
        lv2: R.Shape = R.match_shape(R.shape_of(y), (n, m))
```
D1b: shape construction from tuple of symbolic integers
After we obtain symbolic integers such as n and m, we can recompose them to form an Expr. Any tuple of symbolic integer expressions can be recognized as a Shape value in relax. As a result, `(n, m)` is a shape value.
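As a small illustration using the real `tvm.tir` symbolic integer machinery (how relax wraps the resulting tuple into a Shape value is elided here):

```python
from tvm import tir

n = tir.Var("n", "int64")
m = tir.Var("m", "int64")

# symbolic integers compose with ordinary arithmetic; the resulting tuple of
# PrimExprs is what relax treats as a shape value
flattened = (n * m,)
print(flattened[0])  # a symbolic PrimExpr representing n * m
```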
Ways to do shape propagation
Importantly, because shapes are now values produced during computation, compile-time shape inference can be viewed as compile-time constant folding (or partial evaluation) of the operations that compute shapes. There are a few ways for the program to express shape computations:
- W1: Symbolic shape propagation. A shape can be destructed into symbolic integers (n or m in the above program), and we can then use expressions of symbolic integers (e.g. `n * 4`) to represent shape calculations. Notably, a static shape is just a special case where the symbolic integers are constants. The symbolic integers can then be recomposed to form a shape value (e.g. `(n * 4,)`).
- W2: Opaque shape function calls. We can also implement opaque shape functions (`myshape_func`). These opaque shape functions are useful fallbacks to quickly hack up a runtime shape function.
- W3: For data-dependent shapes (`unique`), we simply defer to a runtime call `f(inputs) -> output` that takes the input Tensor, allocates, and returns the output tensor. We can then fetch the shape of `lv5` from the Tensor value via the `match_shape` construct.
Implications for pass writing
Many of the optimization passes will need to look into the shape information. Now that many shapes can be symbolic, such as `(n, 4)`, the ideal optimization passes will need to generalize a bit to leverage the symbolic information. For example, in the above programs we know that every occurrence of `n` corresponds to the same value. This kind of constraint is very useful. Additionally, thanks to the symbolic integers in the arith module, we can reuse the proving machinery to check equivalence and deduction of symbolic expressions (e.g. `prove(n * 4 == n * 4)`).
Because symbolic integers (`tir.PrimExpr`) constant fold eagerly, when the input has a static shape the results of shape computations are also folded eagerly into constant integers, preserving the properties we need for static-shape-dependent optimizations.
Because we can now represent a mixed symbolic/static shape in a tuple such as `(n, 4)`, we can also take advantage of the static information for additional optimizations.
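The following runnable sketch demonstrates both properties using the existing `tvm.tir`/`tvm.arith` machinery (the relax-level plumbing is omitted):

```python
from tvm import tir, arith

n = tir.Var("n", "int64")
analyzer = arith.Analyzer()

# symbolic reasoning: the arith module can prove equivalences such as
# n * 4 == n * 4 without knowing the concrete value of n
assert analyzer.can_prove(n * 4 == n * 4)

# eager constant folding: with static operands the PrimExpr folds to a constant,
# so static-shape-dependent optimizations keep working unchanged
static = tir.IntImm("int64", 16) * 4
assert isinstance(static, tir.IntImm) and static.value == 64
```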
The final key design decision is to allow the high-level IR to directly interact with and call into lower-level TensorIR functions and PackedFuncs. TensorIR functions and many external libraries adopt a destination-passing convention: we explicitly allocate the output and pass it in as an argument to the function. We use dps (destination passing) to denote this convention. dps is very important in low-level ML optimizations, as it allows us to globally allocate the intermediate storage in a single shot where possible, and to execute the computation without dynamic memory allocation.
Calling a dps function means that the result is passed back via a function argument (e.g., result in the example below) instead of through the function's return value.
```c
// not destination passing
int func(int x) {
  return 1;
}

// destination passing
void func(int x, int *result) {
  *result = 1;
}
```
dps style is inherently about mutation (of the output). We need a way to bridge such calls into the high-level (pure) dataflow land, so that we can perform computational-graph-style rewriting on a sequence of TIR calls.
D2a: call_tir
`call_tir` is an intrinsic that bridges the gap. The name means "calling a TIR convention".
```python
def call_tir(output_shape: Shape, lowlevel_func: Expr, inputs: Tuple[Expr]) -> Expr:
    """Example code to demonstrate the semantics of call_tir"""
    out_tensor = alloc_tensor(output_shape, current_expr.dtype)
    lowlevel_func(*inputs, out_tensor)
    return out_tensor
```
call_tir takes in the output shape, the lowlevel_func (which can be a packed func or a TIR PrimFunc), and a tuple of inputs. The semantics of call_tir are demonstrated by the code above. Notably, when we lower `call_tir`, we do not need to choose separate output tensor allocations for each call: the compiler can create a memory plan for the intermediate tensors and tie things together for effective reuse.
Notably, the `output_shape` parameter of the call_tir intrinsic can be an opaque shape value, a symbolic integer tuple, or a constant shape.
The `lowlevel_func` can be any function with the signature `fn(input0, input1, ..., out0, out1, ...)`.
The two most common cases are: (1) a TIR function, and (2) an opaque packed func.
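As an illustration of the second case, here is a sketch of an opaque packed function that follows the same destination-passing signature. The name "my_exp" and the use of numpy are just for illustration; `tvm.register_func` and `NDArray.copyfrom` are existing TVM APIs.

```python
import numpy as np
import tvm

@tvm.register_func("my_exp")
def my_exp(x: tvm.nd.NDArray, out: tvm.nd.NDArray):
    # destination passing: the preallocated output is the last argument
    # and nothing is returned
    out.copyfrom(np.exp(x.numpy()))
```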
Implementation note
call_tir can be implemented as a special intrinsic (Op) rather than a standalone IR node, to minimize the impact on the IR. From the AST's point of view, this becomes:
Call(op=Op::Get("relax.call_tir"), shape, lowlevel_func, inputs)
This also allows future iterations of `call_tir` without changing the IR itself, which might be needed at some point, for example to:
- Enable a sequence of multiple mutations on the same array (in the case of concat-related ops).
- Enable passing symbolic shape hints to a fused op.
Implications for Integration
D2 enables us to directly embed lower-level abstractions into the high-level abstraction (R.function). This unlocks a lot of opportunities, including but not limited to:
- Incrementally lower different parts of the program using different strategies.
- Allow automation to take a call_tir into TIR, perform optimizations, and rewrite it into multiple call_tir calls that inform layout-rewriting decisions at the high level.
- Bring the BYOC flow in as a natural part of transformation (by transforming parts of the graph into calls to opaque packed functions).
D2b: Packed function calls
We use `R.call_packed` to indicate a call to a packed function. From the AST's point of view we do not need to introduce an additional call node; instead, we can introduce an ExternFunc construct that represents a packed func that we can call into.
Call(op=ExternFunc("my_packed_func"), *args)
`R.call_packed` only serves as syntactic sugar to represent the above AST node. This allows us to unify all of the calls. Notably, it also allows us to mix packed functions and call_tir when necessary. For example,
lv4: R.Tensor[lv3, "f32"] = R.call_tir(lv3, "custom_func", [lv1])
corresponds to the following AST:
Call(op=Op::Get("relax.call_tir"), shape, ExternFunc("custom_func"), [lv1])
call_tir on external packed functions can be useful when we want to directly integrate low-level libraries (such as cuDNN) into the high level without invoking extra memory allocation.
This section covers additional design considerations that are not directly covered by the three key design points (D0, D1, D2).
In some cases we might need to step outside the strongly typed world of Tensor and Shape and introduce a generic object type (corresponding to tvm.runtime.Object). This is usually due to two common needs:
- The ability to support runtime objects that are not part of the type system (yet), for example the storage allocator or VM state.
- The ability to express more flexible programs (like those in TorchScript) that do not have type information.
We still encourage most code to follow the strongly typed version, and we will require an explicit type cast to convert an object to a Tensor before running operations on it (this turns into a runtime assertion and cast). Thanks to TVM's object system, we can easily support this feature at runtime.
```python
from tvm.script import relax as R

@R.function
def fn(x: R.Tensor[(n, m), "f32"]):
    y: R.Tensor[(n, m), "f32"] = x + 1
    return y
```
The above syntax may not pass the pylint check, because n and m are undefined variables. One possible tradeoff is to allow strings in the signature (mirroring Python's own type-annotation mechanism) to represent shapes.
```python
from tvm.script import relax as R

@R.function
def fn(x: R.Tensor["(n, m)", "f32"]):
    n, m = R.shape_vars("n", "m")
    y: R.Tensor[(n, m), "f32"] = x + 1
    return y
```
We can also declare n, m in the global scope (however, that pollutes the global scope and may not be desirable). Similarly, we might need to use `"_"` in place of `_` for unknown shapes if we do not declare it and want the pylint check to pass.
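For completeness, a sketch of the global-scope alternative mentioned above, reusing the hypothetical R.shape_vars helper from the previous snippet:

```python
from tvm.script import relax as R

# declaring the shape variables at module scope keeps pylint happy,
# at the cost of polluting the global namespace
n, m = R.shape_vars("n", "m")

@R.function
def fn(x: R.Tensor[(n, m), "f32"]):
    y: R.Tensor[(n, m), "f32"] = x + 1
    return y
```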