Relax VM Design
Authors(alphabetical): @altanh, @jroesch, @tqchen, @YuchenJin, @ZihengJiang
- A flexible register-based VM to execute Relax programs with dynamic shapes and control flow.
- Minimal instruction set:
  - Calling a packed function is the core instruction.
  - Builtin packed function library (e.g., shape_of(tensor) is one of the builtin packed functions, invoked with the Call instruction to get the shape of a tensor).
- Do shape calculations via shape heap (NDArray) manipulation.
  - Suppose Tensor A's shape is (m, n) at compile time, and in the Relax program we want to compute (j, k) = (m+1, n+1). At runtime, A's shape will be stored at index 0 and index 1 of the shape heap by calling the VM builtin function store_shape(A.shape). m+1 and n+1 will be computed by a TIR PrimFunc generated in the shape lowering pass, and j and k will be stored at index 2 and 3 of the shape heap. Please refer to the shape lowering pass in the Relax Compilation MVP Design Doc for more details.
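The shape heap example above can be modelled in a few lines of plain Python. This is only a sketch: a Python list stands in for the NDArray heap, shape_func is a hypothetical stand-in for the TIR PrimFunc produced by the shape lowering pass, and the store_shape and load_shape signatures here are simplified assumptions rather than the real builtins.

```python
# Toy model of the shape heap: a flat list of integers stands in for the NDArray.

def store_shape(shape, heap, *indices):
    """Write each dimension of `shape` into the given heap slots."""
    for dim, idx in zip(shape, indices):
        heap[idx] = dim


def load_shape(heap, *indices):
    """Build a shape tuple from the given heap slots."""
    return tuple(heap[idx] for idx in indices)


def shape_func(heap):
    """Hypothetical stand-in for the generated PrimFunc: j = m + 1, k = n + 1."""
    heap[2] = heap[0] + 1
    heap[3] = heap[1] + 1


heap = [0] * 4                      # slots: m, n, j, k
store_shape((32, 16), heap, 0, 1)   # A.shape = (m, n) = (32, 16)
shape_func(heap)                    # fills slots 2 and 3
jk = load_shape(heap, 2, 3)
print(jk)                           # (33, 17)
```

The point of the design is that symbolic dimensions become plain heap slots, so shape arithmetic needs no dedicated VM instructions.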
The Relax VM has only 4 instructions (Call, Ret, If, Goto).
Call: call a packed function with arguments and optionally write the output to the destination register (dst).
call <packed_func> [arg0, ...] dst: optional<reg>
Each argument can be a register, an int immediate, or a constant.
Ret: return the value in the result register.
ret <res_reg>
If: if the cond register is true, continue to the next instruction (true branch); otherwise increase the pc (program counter) by false_offset to go to the false branch.
If <cond> false_offset
Goto: increase the pc (program counter) by pc_offset.
Goto pc_offset
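To make the pc semantics of the four instructions concrete, here is a toy interpreter in plain Python. It is a sketch under assumptions, not the VM implementation: the instruction classes, the ("r", i) / ("imm", v) argument encoding, and the run function are all hypothetical names chosen for illustration.

```python
import operator
from dataclasses import dataclass
from typing import Callable, Optional, Sequence


@dataclass
class Call:
    func: Callable
    args: Sequence            # each arg is ("r", reg_index) or ("imm", value)
    dst: Optional[int] = None


@dataclass
class Ret:
    reg: int


@dataclass
class If:
    cond: int
    false_offset: int


@dataclass
class Goto:
    pc_offset: int


def run(code, inputs, num_regs):
    """Interpret `code`; the first len(inputs) registers hold the inputs."""
    regs = list(inputs) + [None] * (num_regs - len(inputs))
    pc = 0
    while True:
        ins = code[pc]
        if isinstance(ins, Call):
            vals = [regs[v] if kind == "r" else v for kind, v in ins.args]
            out = ins.func(*vals)
            if ins.dst is not None:
                regs[ins.dst] = out
            pc += 1
        elif isinstance(ins, Ret):
            return regs[ins.reg]
        elif isinstance(ins, If):
            # true: fall through to the next instruction; false: relative jump
            pc += 1 if regs[ins.cond] else ins.false_offset
        elif isinstance(ins, Goto):
            pc += ins.pc_offset
        else:
            raise TypeError(f"unknown instruction: {ins!r}")


# r2 = r0 + 1 if r0 < 10 else r0 * 2
code = [
    Call(operator.lt, [("r", 0), ("imm", 10)], dst=1),
    If(cond=1, false_offset=3),                          # false -> index 4
    Call(operator.add, [("r", 0), ("imm", 1)], dst=2),
    Goto(2),                                             # skip the false branch
    Call(operator.mul, [("r", 0), ("imm", 2)], dst=2),
    Ret(2),
]
small = run(code, [3], num_regs=3)
large = run(code, [20], num_regs=3)
print(small, large)
```

Both offsets are relative to the current instruction, which is why the true branch ends with a Goto that jumps over the false branch.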
ExecBuilder is what we use to emit instructions and build an executable for the virtual machine. It does not maintain any data members itself, but provides methods to manipulate the content of the inner Executable. Here is an example showing how the builder works:
import tvm
from tvm import relax as rx

# The registered name must match the name used in emit_call below.
@tvm.register_func("vm.move")
def move(src):
    return src

@tvm.register_func("vm.add")
def add(a, b):
    ret = a.asnumpy() + b.asnumpy()
    return tvm.nd.array(ret)

@tvm.register_func("vm.mul")
def mul(a, b):
    ret = a.asnumpy() * b.asnumpy()
    return tvm.nd.array(ret)

ib = rx.ExecBuilder()
with ib.function("main", num_inputs=1):
    ib.emit_call("vm.move", [ib.c(0)], dst=ib.r(1))
    ib.emit_call("vm.add", [ib.r(0), ib.imm(10)], dst=ib.r(2))
    ib.emit_call("vm.mul", [ib.r(2), ib.r(1)], dst=ib.r(3))
    ib.emit_ret(ib.r(3))
executable = ib.get()
Here ib.r(x) means register x in the current frame; ib.c(i) means the i-th constant in the constant pool (the constant pool currently supports NDArray and DLDataType); ib.imm(val) means an immediate (inline constant) whose value is val.
We use ib.function(func_name, num_inputs) to annotate the scope of a function in the emitted code. The number of inputs, say k, must be provided so that the first k registers are used to store the function's inputs. In the case above, register r(0) stores the function input.
Because of this calling convention, it is useful to check whether the user uses the registers correctly. We therefore verify several things about the emitted instructions at the exit of the ib.function scope:
- Does the user use an unexpected register as input? For example, ib.r(3) is used as an input while the number of inputs is only 2. We raise an error in this situation.
- Does the user miss any input register? For example, the number of inputs is 3, but the user only uses ib.r(0) and ib.r(2) as inputs. We raise a warning in this situation.
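The two checks above could look roughly like the following sketch. It assumes, for illustration only, that each instruction is reduced to a (read_registers, dst_register) pair and that a register counts as "used as input" when it is read before any instruction writes it; check_inputs is a hypothetical helper, not the actual CheckExecutable logic.

```python
import warnings


def check_inputs(instrs, num_inputs):
    """instrs: list of (read_regs, dst_reg_or_None). Returns unused input registers."""
    written = set()
    used_as_input = set()
    for reads, dst in instrs:
        for r in reads:
            if r not in written:          # read before any write: treated as input
                used_as_input.add(r)
        if dst is not None:
            written.add(dst)

    unexpected = sorted(r for r in used_as_input if r >= num_inputs)
    if unexpected:
        raise ValueError(f"registers {unexpected} used as input, "
                         f"but num_inputs is {num_inputs}")

    missing = sorted(set(range(num_inputs)) - used_as_input)
    if missing:
        warnings.warn(f"input registers {missing} are never used")
    return missing


# num_inputs=3 but only r(0) and r(2) are read: warning
unused = check_inputs([((0, 2), 3)], num_inputs=3)

# r(3) is read without being written while num_inputs=2: error
try:
    check_inputs([((0, 3), 4)], num_inputs=2)
    raised = False
except ValueError:
    raised = True
print(unused, raised)
```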
Checking is done in CheckExecutable. Also, it is bad for register allocation if the user uses an arbitrary register index, say ib.r(10000). Although this is correct, we do not want to allocate 10000 registers during execution. So we have a formalize pass that renames those registers in order of use, implemented in ExecBuilderNode::Formalize.
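A minimal sketch of such a renaming pass, assuming the same simplified (read_registers, dst_register) view of instructions and assuming that input registers keep their indices; formalize here is a hypothetical Python helper, not the ExecBuilderNode::Formalize code.

```python
def formalize(instrs, num_inputs):
    """Rename registers in order of first use so the register file is dense.

    instrs: list of (read_regs, dst_reg_or_None).
    Returns the rewritten instructions and the resulting register file size.
    """
    # Input registers 0..num_inputs-1 keep their indices (calling convention).
    mapping = {i: i for i in range(num_inputs)}

    def rename(reg):
        if reg not in mapping:
            mapping[reg] = len(mapping)   # next free dense index
        return mapping[reg]

    out = []
    for reads, dst in instrs:
        new_reads = tuple(rename(r) for r in reads)
        new_dst = None if dst is None else rename(dst)
        out.append((new_reads, new_dst))
    return out, len(mapping)


instrs = [((0,), 10000), ((10000, 1), 7)]
renamed, num_regs = formalize(instrs, num_inputs=2)
print(renamed, num_regs)   # [((0,), 2), ((2, 1), 3)] 4
```

After renaming, the frame needs 4 registers instead of 10001.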
Executable stores the content required by the virtual machine, including the bytecode, function table, constant pool, etc. It has to support serialization and deserialization so that we can pass it to another device and load it into the VM.
We designed Executable's data structure carefully so that it can be serialized and deserialized easily.
struct VMFunction {
  std::string name;
  Index start_instr;
  Index num_args;
  Index register_file_size;
};

class ExecutableNode : public Object {
  /* the global function information */
  std::vector<VMFunction> global_funcs;
  /* the constant pool */
  std::vector<TVMRetValue> constants;
  /* packed function names, corresponding to the
     func_idx in the Call instruction */
  std::vector<std::string> func_names;
  /* the emitted byte code of instructions */
  std::vector<ExecWord> instr_data;
  /* since the instruction's length is variable,
     we need to store the offset for indexing */
  std::vector<Index> instr_offset;
  ...
};
With this data structure, when we emit a call instruction like ib.emit_call("add", [ib.r(0), ib.r(1)], dst=ib.r(2)), we will:
- Add the packed function name "add" to func_names if it is not already there;
- Push instr_data's current size onto instr_offset: instr_offset.push_back(instr_data.size()), for indexing the instruction;
- Push the instruction's opcode, func index, arguments, and destination as byte code into instr_data.
Getting an instruction is simple: we use instr_offset to index into instr_data:
Instruction ExecutableNode::GetInstruction(Index i) const {
  size_t offset = instr_offset[i];
  Opcode op = static_cast<Opcode>(instr_data[offset]);
  switch (op) {
    // ...
    // dispatch according to the op code
  }
}
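The emit and lookup steps can be mirrored in a small Python model. This is a sketch only: the opcode numbers, the VOID marker, and the word layout (opcode, func index, dst, argument count, arguments) are assumptions made for illustration and are not taken from the real bytecode format; arguments are restricted to register indices to keep it short.

```python
CALL, RET = 0, 1     # hypothetical opcode numbers
VOID = -1            # hypothetical marker for "no destination register"


class ToyExecutable:
    def __init__(self):
        self.func_names = []     # packed function names, indexed by func_idx
        self.instr_data = []     # flat list of words
        self.instr_offset = []   # start of each instruction in instr_data

    def emit_call(self, name, args, dst=VOID):
        if name not in self.func_names:
            self.func_names.append(name)
        # Record where this variable-length instruction starts.
        self.instr_offset.append(len(self.instr_data))
        func_idx = self.func_names.index(name)
        self.instr_data += [CALL, func_idx, dst, len(args), *args]

    def emit_ret(self, reg):
        self.instr_offset.append(len(self.instr_data))
        self.instr_data += [RET, reg]

    def get_instruction(self, i):
        off = self.instr_offset[i]
        op = self.instr_data[off]
        if op == CALL:
            func_idx, dst, n = self.instr_data[off + 1:off + 4]
            args = self.instr_data[off + 4:off + 4 + n]
            return ("call", self.func_names[func_idx], args, dst)
        if op == RET:
            return ("ret", self.instr_data[off + 1])
        raise ValueError(f"unknown opcode {op}")


exe = ToyExecutable()
exe.emit_call("add", [0, 1], dst=2)
exe.emit_call("add", [2, 2], dst=3)
exe.emit_ret(3)
print(exe.instr_offset)          # [0, 6, 12]
print(exe.get_instruction(1))    # ('call', 'add', [2, 2], 3)
```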
Serialization and deserialization are also easier with this layout. Since instructions are not kept as separate in-memory objects that would need translating to byte code, we only need to serialize instr_data and instr_offset, which are byte code already.
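As an illustration of why flat arrays serialize easily, the sketch below writes the two arrays with Python's struct module and reads them back. The on-disk layout (two little-endian 64-bit counts followed by the values) is an assumption made up for this example, not the Executable's actual format, and save and load are hypothetical helpers.

```python
import struct


def save(instr_offset, instr_data):
    """Pack both arrays as little-endian 64-bit integers, prefixed by their lengths."""
    header = struct.pack("<QQ", len(instr_offset), len(instr_data))
    body = struct.pack(f"<{len(instr_offset)}q{len(instr_data)}q",
                       *instr_offset, *instr_data)
    return header + body


def load(blob):
    """Inverse of save: returns (instr_offset, instr_data)."""
    n_off, n_data = struct.unpack_from("<QQ", blob, 0)
    values = struct.unpack_from(f"<{n_off + n_data}q", blob, 16)
    return list(values[:n_off]), list(values[n_off:])


offsets = [0, 6, 12]
data = [0, 0, 2, 2, 0, 1, 0, 0, 3, 2, 2, 2, 1, 3]
blob = save(offsets, data)
restored_offsets, restored_data = load(blob)
print(len(blob), restored_offsets == offsets, restored_data == data)
```

No per-instruction encoding step is needed because the in-memory form is already a sequence of words.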
We support dumping the executable's content in text format so that users can inspect the code. For example:
ib = rx.ExecBuilder()
with ib.function("func0", num_inputs=2):
    ib.emit_call("vm.op.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
    ib.emit_call("vm.builtin.move", args=[ib.r(2)], dst=ib.r(3))
    ib.emit_call("vm.builtin.print", args=[ib.r(3)])
    ib.emit_ret(ib.r(3))
exec0 = ib.get()
print(exec0.stats())
print(exec0.astext())
Output:
Relax VM executable statistics:
Constant shapes (# 0): []
Globals (#1): [func0]
Packed functions (#3): [vm.op.add, vm.builtin.move, vm.builtin.print]
@func0:
call vm.op.add in: %0, %1 dst: %2
call vm.builtin.move in: %2 dst: %3
call vm.builtin.print in: %3 dst: void
ret ret %3
We also support dumping it back to Python code so that users can modify it easily, for example by inserting extra calls:
# output of exec.aspython()
ib = rx.ExecBuilder()
with ib.function("func0", num_inputs=2):
    ib.emit_call("vm.op.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
    ib.emit_call("time", dst=ib.r(100))
    ib.emit_call("vm.builtin.move", args=[ib.r(2)], dst=ib.r(3))
    ib.emit_call("time", dst=ib.r(101))
    ib.emit_call("vm.builtin.print", args=[ib.r(3)])
    ib.emit_ret(ib.r(3))
The VM loads an executable together with a runtime module that contains the generated code, and executes specific functions with the given inputs by interpreting the bytecode.
import numpy as np
import tvm
from tvm import relax

# building
ib = relax.ExecBuilder()
with ib.function("func0", num_inputs=2):
    ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
    ib.emit_ret(ib.r(2))
with ib.function("func1", num_inputs=2):
    ib.emit_call("test.vm.mul", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
    ib.emit_ret(ib.r(2))
ex = ib.get()

# execution
vm = relax.VirtualMachine(ex)
a = tvm.nd.array(np.random.rand(4,))
b = tvm.nd.array(np.random.rand(4,))
mul_res = vm["func1"](a, b)
add_res = vm["func0"](a, b)
np.testing.assert_allclose(add_res.asnumpy(), a.asnumpy() + b.asnumpy())
np.testing.assert_allclose(mul_res.asnumpy(), a.asnumpy() * b.asnumpy())
Builtin functions (e.g., alloc_storage) may need to access VM's internal states (e.g., allocators). The VMState is stored in a special register.
struct VMState {
  /*! \brief The memory allocators. */
  std::vector<Allocator*> allocators;
};

static constexpr RegName kVMStateRegister = 0x008D14FA4379015C;  // magic number
// ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), ...])
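One way to picture the special register is a register read that returns the VM state when it sees the magic register name. The sketch below is a toy model under assumptions: ToyVMState, ToyFrame, and this alloc_storage are hypothetical Python stand-ins, and the allocator is simply bytearray; only the magic number is taken from the snippet above.

```python
K_VM_STATE_REGISTER = 0x008D14FA4379015C   # magic register name from the text


class ToyVMState:
    def __init__(self):
        self.allocators = []               # callables standing in for allocators


class ToyFrame:
    def __init__(self, num_regs, state):
        self.regs = [None] * num_regs
        self.state = state

    def read(self, reg):
        # The magic register name resolves to the VM state, not to a slot.
        if reg == K_VM_STATE_REGISTER:
            return self.state
        return self.regs[reg]


def alloc_storage(state, size):
    """Hypothetical builtin: allocate with the first allocator in the VM state."""
    return state.allocators[0](size)


state = ToyVMState()
state.allocators.append(bytearray)         # bytearray(n) yields n zero bytes
frame = ToyFrame(num_regs=2, state=state)
frame.regs[0] = 16

# A Call whose first argument is the VM state register, like ib.vm_state().
args = [frame.read(K_VM_STATE_REGISTER), frame.read(0)]
buf = alloc_storage(*args)
print(len(buf))
```

Passing the state as an ordinary argument keeps builtins as plain packed functions, so no extra instruction is needed to expose VM internals.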
The VM uses the shape heap(NDArray) to do shape calculations. (ShapeTuple is also a runtime data type.)
ib = relax.ExecBuilder()
shape = (32, 16)
x = tvm.nd.array(np.random.rand(*shape))
with ib.function("main", num_inputs=0):
    # alloc a shape heap of size 2, and store it in r(0)
    ib.emit_call("vm.builtin.alloc_shape_heap", args=[ib.imm(2)], dst=ib.r(0))
    # get the shape of tensor x, and store the ShapeTuple to r(1)
    ib.emit_call("vm.builtin.shape_of", args=[x], dst=ib.r(1))
    # store the shape of x to index 0 and index 1 of the shape heap
    # shape_heap[0] = 32, shape_heap[1] = 16
    ib.emit_call("vm.builtin.store_shape", args=[ib.r(1), ib.r(0), ib.imm(0), ib.imm(1)])
    # construct a ShapeTuple from the values in index 0 and index 1 of the shape heap
    ib.emit_call("vm.builtin.load_shape", args=[ib.r(0), ib.imm(0), ib.imm(1)], dst=ib.r(2))
    ib.emit_ret(ib.r(2))
The VM runtime library is a collection of packed functions, which can be called using the Call instruction (so that the Relax VM has a minimal instruction set).
https://github.com/octoml/relax/blob/relax/src/relax/vm/builtin.cc