This repository has been archived by the owner on May 22, 2023. It is now read-only.

Relax VM Design

Yuchen Jin edited this page Nov 15, 2021 · 4 revisions


Authors (alphabetical): @altanh, @jroesch, @tqchen, @YuchenJin, @ZihengJiang

Key Goals

  • A flexible register-based VM to execute Relax programs with dynamic shapes and control flow.
  • Minimal instruction set:
    • Calling a packed function is the core instruction.
    • A builtin standard library of packed functions, invoked with the Call instruction (e.g., shape_of(tensor) returns the shape of a tensor).
  • Shape calculations are done by manipulating a shape heap (an NDArray).
    • Suppose tensor A has shape (m, n) at compile time, and the Relax program computes (j, k) = (m+1, n+1). At runtime, the VM builtin function store_shape(A.shape) writes A's shape into indices 0 and 1 of the shape heap. A TIR PrimFunc generated by the shape lowering pass computes m+1 and n+1 and stores j and k at indices 2 and 3 of the shape heap. See the shape lowering pass in the Relax Compilation MVP Design Doc for details.
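To make the index bookkeeping concrete, here is a pure-Python model of the example above. The list `heap` stands in for the shape-heap NDArray, and `store_shape`, `load_shape`, and `compute_jk` are hypothetical stand-ins for the VM builtins and the generated TIR PrimFunc, not their real implementations.

```python
# Hypothetical pure-Python model of the shape heap; the real VM uses an
# NDArray plus builtin packed functions and a generated TIR PrimFunc.

def store_shape(shape, heap, *indices):
    """Write each dimension of `shape` into `heap` at the given indices."""
    for dim, idx in zip(shape, indices):
        heap[idx] = dim


def load_shape(heap, *indices):
    """Build a shape tuple from the heap slots at the given indices."""
    return tuple(heap[idx] for idx in indices)


def compute_jk(heap):
    """Stand-in for the shape-lowering PrimFunc: j = m + 1, k = n + 1."""
    heap[2] = heap[0] + 1
    heap[3] = heap[1] + 1


heap = [0] * 4                 # slots: m, n, j, k
a_shape = (32, 16)             # runtime shape of A, so m = 32, n = 16
store_shape(a_shape, heap, 0, 1)
compute_jk(heap)
print(load_shape(heap, 2, 3))  # (33, 17)
```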

Instruction Set

The Relax VM currently has only two instructions, Call and Ret; control flow instructions such as If will be added later.

Call: call a packed function with the given arguments and, optionally, write the result to a destination register.

call <packed_func> [arg0, ...] dst: optional<reg>

Each argument can be a register, an integer immediate, or a constant.

Ret: return the value in the result register.

ret <res_reg>
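To illustrate how little machinery these two instructions need, here is a toy interpreter in plain Python. The instruction classes, the `(kind, value)` operand tuples, and `run` are hypothetical; they only model the semantics described above, whereas the real VM decodes bytecode and calls TVM packed functions.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

# Operand kinds used by this sketch (hypothetical representation).
REG, IMM, CONST = "reg", "imm", "const"


@dataclass
class Call:
    func: str                        # name of the packed function to call
    args: List[Tuple[str, Any]]      # (kind, value) operands
    dst: Optional[int] = None        # destination register, if any


@dataclass
class Ret:
    reg: int                         # register holding the return value


def run(instrs, inputs: Sequence[Any], funcs: Dict[str, Callable],
        constants: Sequence[Any] = ()):
    # By convention, the k inputs occupy registers r(0)..r(k-1).
    regs: Dict[int, Any] = dict(enumerate(inputs))

    def read(arg):
        kind, value = arg
        if kind == REG:
            return regs[value]
        if kind == IMM:
            return value
        return constants[value]      # CONST: index into the constant pool

    for instr in instrs:
        if isinstance(instr, Call):
            result = funcs[instr.func](*[read(a) for a in instr.args])
            if instr.dst is not None:
                regs[instr.dst] = result
        elif isinstance(instr, Ret):
            return regs[instr.reg]
    raise RuntimeError("function ended without a ret instruction")


funcs = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}
prog = [
    Call("add", [(REG, 0), (IMM, 10)], dst=1),   # r(1) = r(0) + 10
    Call("mul", [(REG, 1), (CONST, 0)], dst=2),  # r(2) = r(1) * c(0)
    Ret(2),
]
print(run(prog, inputs=[5], funcs=funcs, constants=[3]))  # 45
```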

Architecture Overview

[Figure: Relax VM architecture diagram]

ExecBuilder

ExecBuilder is used to emit instructions and build an executable for the virtual machine. It does not maintain any data members of its own; instead, it provides methods that manipulate the contents of the inner Executable. The following example shows how the builder works:

import tvm
from tvm import relax as rx

@tvm.register_func("test.vm.move")
def move(src):
    return src

@tvm.register_func("test.vm.add")
def add(a, b):
    # b may be an NDArray or an integer immediate such as ib.imm(10)
    rhs = b.asnumpy() if isinstance(b, tvm.nd.NDArray) else b
    return tvm.nd.array(a.asnumpy() + rhs)

@tvm.register_func("test.vm.mul")
def mul(a, b):
    ret = a.asnumpy() * b.asnumpy()
    return tvm.nd.array(ret)

ib = rx.ExecBuilder()
with ib.function("main", num_inputs=1):
    # r(1) = constant 0 from the constant pool (assumes the pool has an entry 0)
    ib.emit_call("test.vm.move", [ib.c(0)], dst=ib.r(1))
    # r(2) = input r(0) + immediate 10
    ib.emit_call("test.vm.add", [ib.r(0), ib.imm(10)], dst=ib.r(2))
    # r(3) = r(2) * r(1)
    ib.emit_call("test.vm.mul", [ib.r(2), ib.r(1)], dst=ib.r(3))
    ib.emit_ret(ib.r(3))
executable = ib.get()

Here, ib.r(x) refers to register x in the current frame; ib.c(i) refers to the i-th constant in the constant pool (the constant pool currently supports NDArray and DLDataType); and ib.imm(val) is an immediate (inline constant) whose value is val.

ib.function and the Convention

We use ib.function(func_name, num_inputs) to mark the scope of a function in the emitted code. The number of inputs, k, must be provided because the first k registers, r(0) through r(k-1), store the function's inputs. In the example above, register r(0) stores the single input.

Check and Formalize

Because of this convention, it is useful to check that the user uses registers correctly. When the ib.function scope exits, we verify the following for the emitted instructions:

  • Is an unexpected register used as an input? For example, ib.r(3) is used as an input although the function has only 2 inputs. This raises an error.
  • Is an input register missing? For example, the function has 3 inputs, but only ib.r(0) and ib.r(2) are used. This raises a warning.
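The two rules can be sketched in plain Python. This model describes each instruction only by the registers it reads and the register it writes. The function name and data layout are hypothetical, and treating "unexpected input register" as "a register read before any write that is not an input" is our reading of the rule, not necessarily what CheckExecutable does.

```python
import warnings
from typing import List, Optional, Tuple

# Each instruction is modeled as (registers read, destination register or None).
Instr = Tuple[List[int], Optional[int]]


def check_function(num_inputs: int, instrs: List[Instr]) -> None:
    """Hypothetical check in the spirit of the rules above."""
    defined = set(range(num_inputs))   # r(0)..r(k-1) hold the inputs
    used_inputs = set()
    for reads, dst in instrs:
        for reg in reads:
            if reg not in defined:
                raise ValueError(
                    f"r({reg}) is read before being written and is not an input")
            if reg < num_inputs:
                used_inputs.add(reg)
        if dst is not None:
            defined.add(dst)
    unused = sorted(set(range(num_inputs)) - used_inputs)
    if unused:
        warnings.warn(f"input registers never used: {unused}")


# 3 inputs, but only r(0) and r(2) are read: warns about [1].
check_function(3, [([0, 2], 3), ([3], None)])

# 2 inputs, but r(3) is read without being written: raises an error.
try:
    check_function(2, [([0, 3], 4)])
except ValueError as err:
    print("error:", err)
```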

Checking is implemented in CheckExecutable. Arbitrary register indices are also bad for register allocation: using ib.r(10000) is legal, but we do not want to allocate 10,000 registers at run time. A formalize pass therefore renames registers in order of use; it is implemented in ExecBuilderNode::Formalize.
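A renaming pass of this kind can be sketched as follows. This is a hypothetical model, not ExecBuilderNode::Formalize itself: it assumes the input registers r(0)..r(k-1) keep their numbers and that every other register is numbered by first appearance, so the register file size equals the number of distinct registers.

```python
from typing import Dict, List, Optional, Tuple

Instr = Tuple[List[int], Optional[int]]  # (registers read, destination register)


def formalize(num_inputs: int, instrs: List[Instr]) -> Tuple[List[Instr], int]:
    """Keep inputs fixed and renumber other registers in order of first use."""
    mapping: Dict[int, int] = {i: i for i in range(num_inputs)}

    def rename(reg: int) -> int:
        if reg not in mapping:
            mapping[reg] = len(mapping)  # next free contiguous index
        return mapping[reg]

    renamed = []
    for reads, dst in instrs:
        new_reads = [rename(reg) for reg in reads]
        new_dst = rename(dst) if dst is not None else None
        renamed.append((new_reads, new_dst))
    return renamed, len(mapping)  # len(mapping) is the register file size


prog = [([0, 1], 10000), ([10000], 42), ([42], None)]
new_prog, size = formalize(2, prog)
print(new_prog)  # [([0, 1], 2), ([2], 3), ([3], None)]
print(size)      # 4
```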

Executable

Executable stores everything the virtual machine needs, including the bytecode, function table, and constant pool. It must support serialization and deserialization so that it can be passed to another device and loaded into the VM.

Data Structure

We designed Executable's data structure carefully so that it can be serialized/deserialized easily.

struct VMFunction {
  std::string name;
  Index start_instr;
  Index num_args;
  Index register_file_size;
};

class ExecutableNode : public Object {
  /* the global function information */
  std::vector<VMFunction> global_funcs;
  /* the constant pool */
  std::vector<TVMRetValue> constants;
  /* packed function names, corresponding to the
     func_idx in the Call instruction */
  std::vector<std::string> func_names;
  /* the emitted bytecode of the instructions */
  std::vector<ExecWord> instr_data;
  /* since instructions have variable length,
     we need to store their offsets for indexing */
  std::vector<Index> instr_offset;
  // ...
};

Emit Instruction and Get Instruction

With this data structure, emitting a call instruction such as ib.emit_call("add", [ib.r(0), ib.r(1)], dst=ib.r(2)) proceeds as follows:

  • Add the packed function name "add" to func_names if it is not already present;
  • Record where the new instruction starts by pushing the current size of instr_data onto instr_offset (instr_offset.push_back(instr_data.size())), so the instruction can be indexed later;
  • Append the instruction's opcode, function index, arguments, and destination to instr_data as bytecode.

Getting an instruction is simple: we use instr_offset to locate the instruction in instr_data and decode it from there:

Instruction ExecutableNode::GetInstruction(Index i) const {
  size_t offset = instr_offset[i];
  Opcode op = static_cast<Opcode>(instr_data[offset]);
  switch (op) {
    // ...
    // dispatch according to the op code
  }
}
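The emit and get steps above can be modeled in a few lines of Python. This is a sketch under stated assumptions: the opcode values are hypothetical, a Call is assumed to also store its argument count (needed to decode a variable-length instruction), operands are plain register numbers, and "no destination" is not modeled.

```python
from dataclasses import dataclass, field
from typing import List

OP_CALL, OP_RET = 0, 1  # hypothetical opcode values


@dataclass
class Executable:
    func_names: List[str] = field(default_factory=list)
    instr_data: List[int] = field(default_factory=list)
    instr_offset: List[int] = field(default_factory=list)

    def emit_call(self, name: str, args: List[int], dst: int) -> None:
        # 1. Intern the packed function name.
        if name not in self.func_names:
            self.func_names.append(name)
        func_idx = self.func_names.index(name)
        # 2. Remember where this instruction starts.
        self.instr_offset.append(len(self.instr_data))
        # 3. Append opcode, func index, arg count, args, and destination.
        self.instr_data += [OP_CALL, func_idx, len(args), *args, dst]

    def emit_ret(self, reg: int) -> None:
        self.instr_offset.append(len(self.instr_data))
        self.instr_data += [OP_RET, reg]

    def get_instruction(self, i: int):
        """Decode the i-th instruction starting at instr_offset[i]."""
        off = self.instr_offset[i]
        op = self.instr_data[off]
        if op == OP_CALL:
            func_idx, nargs = self.instr_data[off + 1], self.instr_data[off + 2]
            args = self.instr_data[off + 3 : off + 3 + nargs]
            dst = self.instr_data[off + 3 + nargs]
            return ("call", self.func_names[func_idx], args, dst)
        if op == OP_RET:
            return ("ret", self.instr_data[off + 1])
        raise ValueError(f"unknown opcode {op}")


exe = Executable()
exe.emit_call("add", [0, 1], dst=2)
exe.emit_ret(2)
print(exe.instr_data)          # [0, 0, 2, 0, 1, 2, 1, 2]
print(exe.instr_offset)        # [0, 6]
print(exe.get_instruction(0))  # ('call', 'add', [0, 1], 2)
print(exe.get_instruction(1))  # ('ret', 2)
```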

Serialize and Deserialize

This layout also simplifies serialization and deserialization. Because instructions are already stored as bytecode, there is no need to translate an in-memory Instruction structure into bytecode; we simply serialize instr_data and instr_offset as they are.
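Because the code section is just two flat integer arrays, serializing it can be as simple as writing each array with a length prefix. The sketch assumes little-endian signed 64-bit words and a length-prefixed layout of our own choosing; it is not the actual Relax file format, and it omits the function table and constant pool.

```python
import struct
from typing import List, Tuple


def _pack_words(words: List[int]) -> bytes:
    # Length prefix followed by little-endian signed 64-bit words.
    return struct.pack(f"<q{len(words)}q", len(words), *words)


def _unpack_words(buf: bytes, pos: int) -> Tuple[List[int], int]:
    (n,) = struct.unpack_from("<q", buf, pos)
    pos += 8
    words = list(struct.unpack_from(f"<{n}q", buf, pos))
    return words, pos + 8 * n


def save_code(instr_data: List[int], instr_offset: List[int]) -> bytes:
    return _pack_words(instr_data) + _pack_words(instr_offset)


def load_code(buf: bytes) -> Tuple[List[int], List[int]]:
    instr_data, pos = _unpack_words(buf, 0)
    instr_offset, _ = _unpack_words(buf, pos)
    return instr_data, instr_offset


blob = save_code([0, 0, 2, 0, 1, 2, 1, 2], [0, 6])
print(len(blob))        # 96
print(load_code(blob))  # ([0, 0, 2, 0, 1, 2, 1, 2], [0, 6])
```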

Text Format

We support dumping the executable's contents in text format so that users can inspect the code. For example:

from tvm import relax as rx

ib = rx.ExecBuilder()

with ib.function("func0", num_inputs=2):
    ib.emit_call("vm.op.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
    ib.emit_call("vm.builtin.move", args=[ib.r(2)], dst=ib.r(3))
    ib.emit_call("vm.builtin.print", args=[ib.r(3)])
    ib.emit_ret(ib.r(3))

exec0 = ib.get()
print(exec0.stats())
print(exec0.astext()) 

Output:

Relax VM executable statistics:
  Constant shapes (# 0): []
  Globals (#1): [func0]
  Packed functions (#3): [vm.op.add, vm.builtin.move, vm.builtin.print]

@func0:
  call  vm.op.add        in: %0, %1       dst: %2
  call  vm.builtin.move  in: %2           dst: %3
  call  vm.builtin.print in: %3           dst: void
  ret   ret %3

Dump as Python Code

We also support dumping the executable back to Python code, so that users can easily modify it by hand:

# output of exec0.aspython(), edited by hand to add calls to a "time" packed function
ib = rx.ExecBuilder()
with ib.function("func0", num_inputs=2):
    ib.emit_call("vm.op.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
    ib.emit_call("time", dst=ib.r(100))
    ib.emit_call("vm.builtin.move", args=[ib.r(2)], dst=ib.r(3))
    ib.emit_call("time", dst=ib.r(101))
    ib.emit_call("vm.builtin.print", args=[ib.r(3)])
    ib.emit_ret(ib.r(3))

Virtual Machine

The VM loads an executable, together with a runtime module that contains the generated code, and executes specific functions on the given inputs by interpreting the bytecode.

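import numpy as np
import tvm
from tvm import relax

# building (assumes the packed functions "test.vm.add" and "test.vm.mul"
# have been registered with tvm.register_func)
ib = relax.ExecBuilder()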
with ib.function("func0", num_inputs=2):
    ib.emit_call("test.vm.add", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
    ib.emit_ret(ib.r(2))

with ib.function("func1", num_inputs=2):
    ib.emit_call("test.vm.mul", args=[ib.r(0), ib.r(1)], dst=ib.r(2))
    ib.emit_ret(ib.r(2))
ex = ib.get()

# execution
vm = relax.VirtualMachine(ex)
a = tvm.nd.array(np.random.rand(4,))
b = tvm.nd.array(np.random.rand(4,))
mul_res = vm["func1"](a, b)
add_res = vm["func0"](a, b)
np.testing.assert_allclose(add_res.asnumpy(), a.asnumpy() + b.asnumpy())
np.testing.assert_allclose(mul_res.asnumpy(), a.asnumpy() * b.asnumpy())

VM State

Builtin functions (e.g., alloc_storage) may need to access the VM's internal state (e.g., its memory allocators). This state, VMState, is stored in a special register.

struct VMState {
  /*! \brief The memory allocators. */
  std::vector<Allocator*> allocators;
};

static constexpr RegName kVMStateRegister = 0x008D14FA4379015C; // magic number

// Builtins receive the VM state as an ordinary argument, e.g.:
// ib.emit_call("vm.builtin.alloc_storage", args=[ib.vm_state(), ...])
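The sketch below models how an argument naming the special register could resolve to the VM state when a builtin is called. Only the magic number comes from the snippet above; the `Frame` class, its `read` method, and the `alloc_storage` stand-in (which uses `bytearray` as a toy allocator) are hypothetical.

```python
# Magic register number from the C++ snippet above.
K_VM_STATE_REGISTER = 0x008D14FA4379015C


class VMState:
    """Holds VM-internal state such as memory allocators (toy model)."""

    def __init__(self, allocators):
        self.allocators = allocators


class Frame:
    def __init__(self, state: VMState, size: int):
        self.state = state
        self.regs = [None] * size

    def read(self, reg: int):
        # The special register resolves to the VM state instead of a slot,
        # so builtins can receive it like any other argument.
        if reg == K_VM_STATE_REGISTER:
            return self.state
        return self.regs[reg]


def alloc_storage(vm_state: VMState, nbytes: int):
    """Hypothetical builtin that allocates with the first allocator."""
    allocator = vm_state.allocators[0]
    return allocator(nbytes)


frame = Frame(VMState(allocators=[bytearray]), size=2)
frame.regs[0] = 16
buf = alloc_storage(frame.read(K_VM_STATE_REGISTER), frame.read(0))
print(len(buf))  # 16
```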

Shape calculation

The VM uses a shape heap (an NDArray) for shape calculations. ShapeTuple is also a runtime data type.

import numpy as np
import tvm
from tvm import relax

ib = relax.ExecBuilder()
shape = (32, 16)
x = tvm.nd.array(np.random.rand(*shape))
with ib.function("main", num_inputs=0):
    # allocate a shape heap of size 2 and store it in r(0)
    ib.emit_call("vm.builtin.alloc_shape_heap", args=[ib.imm(2)], dst=ib.r(0))
    # get the shape of tensor x and store the ShapeTuple in r(1)
    ib.emit_call("vm.builtin.shape_of", args=[x], dst=ib.r(1))
    # store the shape of x to index 0 and index 1 of the shape heap:
    # shape_heap[0] = 32, shape_heap[1] = 16
    ib.emit_call("vm.builtin.store_shape", args=[ib.r(1), ib.r(0), ib.imm(0), ib.imm(1)])
    # construct a ShapeTuple from the values at index 0 and index 1 of the shape heap
    ib.emit_call("vm.builtin.load_shape", args=[ib.r(0), ib.imm(0), ib.imm(1)], dst=ib.r(2))
    ib.emit_ret(ib.r(2))

Builtin runtime functions

https://github.com/octoml/relax/blob/relax/src/relax/vm/builtin.cc