tlc-pack / tvm-tensorir
License: Apache License 2.0
The main goal of tensorization is loop nest matching (iteration domain + body), i.e. we want to be able to transform and match a tensor intrinsic of the following form:
loop1:
loop2:
body
Tensorization is hard because it introduces quite a lot of constraints:
If we manage the search by randomly generating a sequence of transforms and only trying to tensorize at the end, it is very hard to find a candidate that meets the constraints. This is why tensorization is hard.
On the other hand, the additional constraints introduced by tensorization intrinsics are also a very good thing, because they constrain the search space. E.g. if the tensor intrinsic is a 4x4 matmul, it does not make sense to do a split with tile size 2 in the innermost loop.
To develop this idea further, we can try a gradual matching process to guide the search. The main idea is as follows: we identify candidates for tensorization by gradually refining the match.
Take the following intrinsic (over C, B, A) as an example:
for i in range(8):
for j in range(8):
C[i,j] = B[i, j] * A[i, j]
In the first step, we only try to match the body (C0), so the pattern abstracts to vc = vb * va, where va, vb, vc are variables that can match any expressions. Using this body pattern match, we can first find promising bodies for tensorization.
At this stage, the following code can also be matched by the pattern vc = vb * va:
for i in range(128):
for j in range(-1, 7):
C[i,j+1] = B[i, j+1] * A[i, j+1]
We can also apply the commutative pattern, so C[i,j+1] = A[i, j+1] * B[i, j+1] also matches. But it can certainly reject things like C[i,j] = B[i,j] + 1 (because the body pattern mismatches).
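As a toy illustration (not the actual TIR matcher), the first body-matching step with commutativity could look like the sketch below; the tuple encoding of statements is purely a stand-in for illustration.

```python
# Hypothetical sketch of the first refinement step: match the body
# pattern vc = vb * va, trying both operand orders for commutativity.
# Statements are modeled as nested tuples for illustration only.
def match_body(stmt):
    # stmt has the shape ("assign", lhs, (op, a, b))
    _, lhs, rhs = stmt
    op, a, b = rhs
    if op != "mul":
        return None  # body pattern mismatch: certainly reject
    # commutative: both operand orders are candidate matches
    return [{"vc": lhs, "vb": a, "va": b},
            {"vc": lhs, "vb": b, "va": a}]

ok = match_body(("assign", "C[i,j+1]", ("mul", "B[i,j+1]", "A[i,j+1]")))
rejected = match_body(("assign", "C[i,j]", ("add", "B[i,j]", "1")))  # None
```
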
After the first step, we will have a collection of possible matching candidates, and the following steps will try to "refine the match" by considering more constraints. For example, we can start to consider C1, the type of each iterator, which helps us establish the correspondence of loop iterator types: we know that a data-parallel loop cannot match a reduction loop, and that certain loops are independent.
After we have established the body and iterator mapping, the next set of constraints are the iterator space lengths, which will be used to guide the tiling (e.g. we know we need to tile to a certain size to meet the constraint).
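For instance, a minimal sketch of how an intrinsic's iterator extent can guide tiling (assuming the simple divisible case; tile_for_intrinsic is a hypothetical helper, not part of TVM):

```python
# Toy sketch: to tensorize a loop of extent N with an intrinsic whose
# iterator extent is T, split N = outer * T. This hypothetical helper
# assumes T divides N; otherwise padding or predicates would be needed.
def tile_for_intrinsic(extent, intrin_extent):
    if extent % intrin_extent != 0:
        raise ValueError("extent must be divisible by the intrinsic extent")
    return extent // intrin_extent, intrin_extent

outer, inner = tile_for_intrinsic(1024, 16)  # e.g. a 16x16x16 intrinsic
```
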
Hopefully the idea is clear here --- we decompose the constraints imposed by tensorization by generating a set of search candidates and refining them, so that we reduce the overall search space. Notably, we do not have to refine to the point where we can fully tensorize, because we can always search and verify (the forward rollout steps) as long as we are left with a few promising candidates.
Once we have the promising candidates and the iterator tiles, one thing we could do is first blockize these blocks (knowing that they are candidates for tensorization), then apply the normal AutoTIR search to schedule around the tensorized block.
We can try to combine this idea with MetaSchedule generation and Ansor-style search -- for example, the initial tensorize matching steps could be programmed as a custom rule that generates a sequence of MetaSchedule programs.
Note that this is only an initial high-level idea. In summary, instead of only doing the forward rollout exploration, we also use the target to refine the set of constraints we might have in the forward rollouts.
We still need to refine the set of constraints we want to match, and build a prototype.
for i in grid(0, 10):
can be confusing after we start to support multiple grids, e.g.
for i, j in grid(3, 4)
To keep things simple, let us consider only using the latter format, and always assume the iterator starts from 0 in the grid syntax. Or allow alternatives like
for i, j in grid(3, range(-1, 2))
@spectrometerHBH @Hzfengsy after looking at the implementation a bit.
Here are some notes about BufferFlatten.
Our main goal of this preparation algorithm is to:
The current algorithm first computes the LCA of all buffer references (via constructing a parent-ref data structure, ASTInfo) and then gathers the bounds after computing the LCA.
Here are a few alternative algorithms, just for discussion.
Here is one alternative to detect the LCA of the buffers without the parent-ref data structure.
// The code below is one example of this.
class Visitor {
 public:
  void VisitBufferUse(Buffer buffer) {
    // merge buffer into the current scope's message list
    MergeMessage(&buffer_in_scope_, Message{buffer, 1});
  }
  // new scope, such as a For node
  void VisitNewScope(For* op) {
    // collect the body's messages into a fresh list, keeping the
    // parent scope's messages aside during the visit
    std::vector<Message> res;
    std::swap(buffer_in_scope_, res);
    VisitBody(op);
    std::swap(buffer_in_scope_, res);
    for (auto& entry : res) {
      if (entry.use_count == TotalRefCount(entry.buffer)) {
        // all references are inside this scope:
        // mark current scope as the LCA of the buffer
      } else {
        // some references are outside: merge entry into the parent scope
        MergeMessage(&buffer_in_scope_, entry);
      }
    }
  }

 private:
  struct Message {
    // buffer of interest
    Buffer buffer;
    // reference count within the scope
    int use_count;
  };
  // buffers referenced in the current scope
  std::vector<Message> buffer_in_scope_;
};
Another possibility is to go one step further: use the buffer and its access set as the message.
struct Message {
public:
// buffer of interest
Buffer buffer;
// integer set of the accessed region along each dimension
std::vector<arith::IntSet> region;
int use_count;
};
Of course, the current LCA-then-gather approach also enjoys good algorithmic complexity and readability.
Right now the region gather uses an approach of directly relaxing and merging back to the final location point; this is fine, but there are a few cases to consider.
// scope root
for i in range(0, 2):
// scope a
temp[i * 2 + 0] // set0
temp[i * 2 + 1] // set1
for j:
temp[j]
T0: union(relax(set0, i=range(0,2)), relax(set1, i=range(0,2)))
T1: relax(union(set0, set1), i=range(0,2))
Because our integer set analysis is relaxed and only keeps interval information, T1 can be more accurate than T0.
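A toy check of the two relaxation orders for the example above, using concrete enumeration rather than TVM's arith module (hull is a stand-in for the interval approximation):

```python
# set0 = {temp[2i]} and set1 = {temp[2i+1]} for i in range(2)
def hull(points):
    # interval approximation: keep only (min, max)
    return (min(points), max(points))

set0 = [2 * i for i in range(2)]      # {0, 2}
set1 = [2 * i + 1 for i in range(2)]  # {1, 3}

# T0: relax each set to an interval first, then union the intervals
t0 = hull(list(hull(set0)) + list(hull(set1)))
# T1: union the exact sets first, then relax once
t1 = hull(set0 + set1)
# both are (0, 3) here; in general T1 approximates only once,
# so it never loses precision relative to T0
```
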
We need to start thinking about the memory and thread hierarchy, which is important for GPU.
When traversing up to see which loop_var to relax, the LCA alone is not enough in the case of
the GPU memory hierarchy and general accelerators.
Think of the following case
// scope root
alloc_buffer global temp[10]
thread_scope blockIdx.x in (0, 10):
// scope a
temp[blockIdx.x] = A[blockIdx.x]
B[blockIdx.x] = temp[blockIdx.x]
Because temp is in global memory, we cannot attach it to scope a (even though it is the LCA).
We need to attach it to the scope root instead.
Here is another example of thread binding and memory scope, which is a bit more complicated:
// scope root
alloc_buffer shared temp[10]
thread_scope blockIdx.x in (0, 10):
// scope block
thread_scope threadIdx.x in (0, 10):
for k:
// note just bind things to threadIdx.x
// cooperative fetching
temp[threadIdx.x] = A[threadIdx.x]
B[threadIdx.x] = temp[threadIdx.x]
Here the LCA of temp is at the scope of k. However, because we need to allocate it at the block scope, we will need to relax threadIdx.x but not k.
Note that the memory scope does bring restrictions to scheduling, e.g. the access LCA of a buffer with shared scope cannot be at the global level (the code below is invalid).
// scope root
share temp[10]
thread_scope blockIdx.x in (0, 10):
// cooperative fetching
temp[threadIdx.x] = A[threadIdx.x]
thread_scope blockIdx.x in (0, 10):
B[threadIdx.x] = temp[threadIdx.x]
List of typical points to check during code review, or when writing code.
cc @Hzfengsy @spectrometerHBH let us also improve the list as we find more possible points.
This primitive is very useful for autoscheduling to keep all tiling decisions at one stage.
Example:
for i:
for k:
opaque_produce(C[i*4, i*4+4])
for j:
D[j] = C[j]+1
s.elemwise_reverse_compute_at(D, C, i)
for i:
for k:
opaque_produce(C[i*4, i*4+4])
for j in range(4):
D[i*4 + j] = C[i*4 + j]+1
Why: it is useful to make all tiling decisions at the reduction point, then "reverse inline" the later elementwise stage back into the loop tiles of the original computation. Correctness:
I am not too happy about the name, and would love to see if we have better ideas.
After looking at the reorder PR #37, I feel that it looks a bit too long.
It seems one of the factors is that we are trying to do copy-on-write instead of direct mutation. I wonder if there is a way to introduce a few more idioms to make these mutations easier.
For example, one possible idiom would be something like Schedule->PrepareForMutation(List[SRef]), which takes a list of SRefs and makes sure that everything from these SRefs to the root has a single reference count (if not, run COW). After this call, we can freely use in-place mutation within that range.
Of course, such an idiom does not take things like deletion into consideration. Would love to see if it is possible to have ideas along these lines.
Right now Allocate is printed as
def func():
data = tir.var("handle")
tir.allocate(data, "float32", [1024])
After apache/tvm#6317, we want to make sure that allocate actually uses a var that has a pointer type annotation (non-runtime type) to carry richer information; it would be great to enhance the hybrid parser to support this case.
A slight generalization might look like
def func():
data = tir.var(ty.Ptr[ty.float32])
tir.allocate(data, "float32", [1024])
However, it seems weird to separate the data var declaration. In this case, we can simply do
def func():
# data is a var with type annotation "Ptr[float32]"
data = tir.allocate("float32", [1024])
and in the case of scoping
def func():
with tir.allocate("float32", [1024]) as data:
    ...
So far we have a text printer for Relay, which allows us to print an IRModule in text format. On the TIR side, we still rely on the ReprPrinter.
This issue is for upgrading the text printer so that we can print an IRModule that includes PrimFunc (tir::Function in the upstream) in text format. This will help us to enhance the demo experience.
Ideally we want to land a version in the mainline in about two to three weeks. @spectrometerHBH please see if it is possible for you and @Hzfengsy to coordinate on a format and land a version; then we can pull it back into TensorIR.
Key points
Week 1~3: 07/08/2020 - 07/29/2020
Week 4~5: 07/29/2020 - 08/12/2020
Week 6: 08/12/2020 - 08/17/2020
Week 7-8: 08/17/2020 - 09/03/2020
Week 9-13: 09/03/2020 - 10/7/2020
Week 14: 10/8/2020 - 10/15/2020
This is an RFC thread; unlike DISCUSS, RFC means we have a concrete proposal that can be executed.
The loops after scheduling can contain very deep nests, and it is useful to introduce sugar (in both printing and parsing) for a more concise version of the grid.
For example the following code
for xo, xi, y in grid(10, 20, 200):
    pass
is equivalent to
for xo in range(10):
    for xi in range(20):
        for y in range(200):
            pass
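The sugar can be read as a row-major cross product of ranges; a rough Python model of the expansion (grid here is a hypothetical helper, not the TIR implementation):

```python
import itertools

def grid(*extents):
    # yields index tuples in row-major order, mirroring the
    # nested-loop expansion of the proposed sugar
    return itertools.product(*(range(e) for e in extents))

indices = list(grid(2, 3))
# (0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)
```
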
The alternative names to grid include:
The following items happen upstream
After TIR RFC
Related people
Sync with the current upstream https://github.com/apache/incubator-tvm,
most of the major changes in the upstream are not completed, so we don't need to sync frequently after this sync.
Related refactor PRs:
Python is the most popular language for deep learning frameworks due to its great flexibility and rich ecosystem. This RFC plans to utilize a subset of the Python AST that can express every TIR node. The new dialect will serve as a way to construct and inspect the IR in Python.
ML compilation is an open research area, while its great value has already led to quick transfers to production. We believe Hybrid Script will enable more ML scientists and engineers to quickly implement prototypes of new ML algorithms, which will increase the rate of innovation. For developers of the TVM compiler stack, having a readable and writable text format of the IR will ease the difficulty of implementing and testing data structure transformations and optimizations.
# opt_gemm.py
# after normalize schedule
class Module:
def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None:
# function attr dict
tir.func_attr({"global_symbol": "mmult", "tir.noalias": True})
A_1 = tir.buffer_bind(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
B_1 = tir.buffer_bind(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
C_1 = tir.buffer_bind(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
# body
tir.attr(C_1, "realize_scope", "")
tir.realize(C_1[0:1024, 0:1024])
for x in tir.range(0, 1024):
for y in tir.range(0, 1024):
C_1[x, y] = tir.float32(0)
for k in tir.range(0, 1024):
C_1[x, y] = (C_1[x, y] + (A_1[x, k]*B_1[k, y]))
# after lower
class Module:
def mmult(A: ty.handle, B: ty.handle, C: ty.handle) -> None:
# function attr dict
tir.func_attr({"global_symbol": "mmult", "tir.noalias": True})
A_1 = tir.buffer_bind(A, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
B_1 = tir.buffer_bind(B, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
C_1 = tir.buffer_bind(C, [1024, 1024], elem_offset=0, align=128, offset_factor=1)
# body
for x in tir.range(0, 1024):
for y in tir.range(0, 1024):
C_1.data[((x*1024) + y)] = tir.float32(0)
for k in tir.range(0, 1024):
C_1.data[((x*1024) + y)] = (tir.load("float32", C_1.data, ((x*1024) + y)) + (tir.load("float32", A_1.data, ((x*1024) + k))*tir.load("float32", B_1.data, ((k*1024) + y))))
# host module
class Module:
def mmult(args: ty.handle, arg_type_ids: ty.handle, num_args: ty.int32, out_ret_value: ty.handle, out_ret_tcode: ty.handle) -> ty.int32:
# function attr dict
tir.func_attr({"target": meta[Target][0], "tir.noalias": True, "global_symbol": "mmult", "tir.is_entry_func": True, "calling_conv": 1})
# body
assert (num_args == 3), "mmult: num_args should be 3"
arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle")
arg0_code: ty.int32 = tir.load("int32", arg_type_ids, 0)
arg1: ty.handle = tir.tvm_struct_get(args, 1, 12, dtype="handle")
arg1_code: ty.int32 = tir.load("int32", arg_type_ids, 1)
arg2: ty.handle = tir.tvm_struct_get(args, 2, 12, dtype="handle")
arg2_code: ty.int32 = tir.load("int32", arg_type_ids, 2)
A: ty.handle = tir.tvm_struct_get(arg0, 0, 1, dtype="handle")
tir.attr(A, "storage_alignment", 128)
arg0_shape: ty.handle = tir.tvm_struct_get(arg0, 0, 2, dtype="handle")
arg0_strides: ty.handle = tir.tvm_struct_get(arg0, 0, 3, dtype="handle")
dev_id: ty.int32 = tir.tvm_struct_get(arg0, 0, 9, dtype="int32")
B: ty.handle = tir.tvm_struct_get(arg1, 0, 1, dtype="handle")
tir.attr(B, "storage_alignment", 128)
arg1_shape: ty.handle = tir.tvm_struct_get(arg1, 0, 2, dtype="handle")
arg1_strides: ty.handle = tir.tvm_struct_get(arg1, 0, 3, dtype="handle")
C: ty.handle = tir.tvm_struct_get(arg2, 0, 1, dtype="handle")
tir.attr(C, "storage_alignment", 128)
arg2_shape: ty.handle = tir.tvm_struct_get(arg2, 0, 2, dtype="handle")
arg2_strides: ty.handle = tir.tvm_struct_get(arg2, 0, 3, dtype="handle")
assert ((((arg0_code == 3) or (arg0_code == 13)) or (arg0_code == 7)) or (arg0_code == 4)), "mmult: Expect arg[0] to be pointer"
assert ((((arg1_code == 3) or (arg1_code == 13)) or (arg1_code == 7)) or (arg1_code == 4)), "mmult: Expect arg[1] to be pointer"
assert ((((arg2_code == 3) or (arg2_code == 13)) or (arg2_code == 7)) or (arg2_code == 4)), "mmult: Expect arg[2] to be pointer"
assert (2 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 2"
assert (2 == tir.tvm_struct_get(arg0, 0, 4, dtype="int32")), "arg0.ndim is expected to equal 2"
assert (((tir.tvm_struct_get(arg0, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg0, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg0, 0, 7, dtype="uint16") == tir.uint16(1))), "arg0.dtype is expected to be float32"
assert (1024 == tir.cast("int32", tir.load("int64", arg0_shape, 0))), "Argument arg0.shape[0] has an unsatisfied constraint"
assert (1024 == tir.cast("int32", tir.load("int64", arg0_shape, 1))), "Argument arg0.shape[1] has an unsatisfied constraint"
if not (tir.isnullptr(arg0_strides, dtype="bool")):
assert ((1 == tir.cast("int32", tir.load("int64", arg0_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg0_strides, 0)))), "arg0.strides: expected to be compact array"
tir.evaluate(0)
assert (tir.uint64(0) == tir.tvm_struct_get(arg0, 0, 8, dtype="uint64")), "Argument arg0.byte_offset has an unsatisfied constraint"
assert (1 == tir.tvm_struct_get(arg0, 0, 10, dtype="int32")), "Argument arg0.device_type has an unsatisfied constraint"
assert (2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 2"
assert (2 == tir.tvm_struct_get(arg1, 0, 4, dtype="int32")), "arg1.ndim is expected to equal 2"
assert (((tir.tvm_struct_get(arg1, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg1, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg1, 0, 7, dtype="uint16") == tir.uint16(1))), "arg1.dtype is expected to be float32"
assert (1024 == tir.cast("int32", tir.load("int64", arg1_shape, 0))), "Argument arg1.shape[0] has an unsatisfied constraint"
assert (1024 == tir.cast("int32", tir.load("int64", arg1_shape, 1))), "Argument arg1.shape[1] has an unsatisfied constraint"
if not (tir.isnullptr(arg1_strides, dtype="bool")):
assert ((1 == tir.cast("int32", tir.load("int64", arg1_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg1_strides, 0)))), "arg1.strides: expected to be compact array"
tir.evaluate(0)
assert (tir.uint64(0) == tir.tvm_struct_get(arg1, 0, 8, dtype="uint64")), "Argument arg1.byte_offset has an unsatisfied constraint"
assert (1 == tir.tvm_struct_get(arg1, 0, 10, dtype="int32")), "Argument arg1.device_type has an unsatisfied constraint"
assert (dev_id == tir.tvm_struct_get(arg1, 0, 9, dtype="int32")), "Argument arg1.device_id has an unsatisfied constraint"
assert (2 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 2"
assert (2 == tir.tvm_struct_get(arg2, 0, 4, dtype="int32")), "arg2.ndim is expected to equal 2"
assert (((tir.tvm_struct_get(arg2, 0, 5, dtype="uint8") == tir.uint8(2)) and (tir.tvm_struct_get(arg2, 0, 6, dtype="uint8") == tir.uint8(32))) and (tir.tvm_struct_get(arg2, 0, 7, dtype="uint16") == tir.uint16(1))), "arg2.dtype is expected to be float32"
assert (1024 == tir.cast("int32", tir.load("int64", arg2_shape, 0))), "Argument arg2.shape[0] has an unsatisfied constraint"
assert (1024 == tir.cast("int32", tir.load("int64", arg2_shape, 1))), "Argument arg2.shape[1] has an unsatisfied constraint"
if not (tir.isnullptr(arg2_strides, dtype="bool")):
assert ((1 == tir.cast("int32", tir.load("int64", arg2_strides, 1))) and (1024 == tir.cast("int32", tir.load("int64", arg2_strides, 0)))), "arg2.strides: expected to be compact array"
tir.evaluate(0)
assert (tir.uint64(0) == tir.tvm_struct_get(arg2, 0, 8, dtype="uint64")), "Argument arg2.byte_offset has an unsatisfied constraint"
assert (1 == tir.tvm_struct_get(arg2, 0, 10, dtype="int32")), "Argument arg2.device_type has an unsatisfied constraint"
assert (dev_id == tir.tvm_struct_get(arg2, 0, 9, dtype="int32")), "Argument arg2.device_id has an unsatisfied constraint"
tir.attr(0, "compute_scope", "mmult_compute_")
for x in tir.range(0, 1024):
for y in tir.range(0, 1024):
C[((x*1024) + y)] = tir.float32(0)
for k in tir.range(0, 1024):
C[((x*1024) + y)] = tir.call_llvm_pure_intrin(tir.uint32(97), tir.uint32(3), tir.load("float32", A, ((x*1024) + k)), tir.load("float32", B, ((k*1024) + y)), tir.load("float32", C, ((x*1024) + y)), dtype="float32")
The basic design goal is that the printer can print IR as a readable Python script, which the parser can parse to construct an equivalent IR object at any stage of compilation.
The overall design to support all IR variants is to build a one-to-one correspondence between the scoping structure of the IR and the tree structure of the Python AST.
Ideally, each kind of IRNode corresponds to an AST node in the Python AST. Generally, for an arbitrary node named xxxNode, we can use tir.xxx() in the Python script to represent it, and recursively print the function call arguments according to the node's constructor, since we want no information loss in the round trip.
Scope handlers: for a StmtNode with a body, we can use with tir.xxx() / for xxx in tir.yyy() to represent it in the Python script; we call these nodes and functions scope handlers.
Intrins: the remaining IRNodes (StmtNodes without a body, PrimExprNodes and more) are called intrins.
In principle, we can represent any IR tree using the above two rules, like
with tir.let():
with tir.assert(...):
with tir.assert(...):
with tir.attr(...):
tir.store(...)
tir.evaluate(...)
but this sacrifices readability and writability, and some information is not suitable to be printed in such a format.
For example, many nodes like AllocateNode include a Var as a constructor parameter. Ideally we want to print a Var using only its name_hint, like
with tir.allocate(packedB, "float32", [16, 16, 16], True):
...
But we would lose the dtype information of packedB in the above example. We'd better use
packedB = tir.var("handle")
with tir.allocate(packedB, "float32", [16, 16, 16], True):
Now, packedB = tir.var("handle") doesn't actually correspond to an IRNode; we call these statements and functions special stmts. The declarations of Buffer and Var, and the DictAttrs of a PrimFunc, are all handled by special stmts.
If we follow the default scoping rule above naively, we will see many unnecessary indents in the text form.
with tir.attr(0, "compute_scope", "mmult_compute_"):
with tir.attr(packedB, "storage_scope", "global"):
with tir.attr(packedB, "storage_alignment", 128):
if tir.isnullptr(packedB, dtype="bool"):
tir.evaluate(tir.tvm_throw_last_error(dtype="int32"))
To provide better readability and writability for Hybrid Script, we provide a rule of concise scoping for the printer and parser.
Printer: for a scope handler node, if it is the last stmt of its parent stmt's scope, then we can print its body without explicit indentation.
Parser: if we encounter a stmt corresponding to a scope handler node but not in a With context, the rest of the current scope is parsed as its body:
tir.attr(0, "compute_scope", "mmult_compute_")
tir.attr(packedB, "storage_scope", "global")
tir.attr(packedB, "storage_alignment", 128)
if tir.isnullptr(packedB, dtype="bool"):
tir.evaluate(tir.tvm_throw_last_error(dtype="int32"))
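The printer side of this rule can be modeled roughly as follows (made-up data shapes, not the actual printer): only the last statement of a scope may drop the with and the indent.

```python
# stmts are (name, body) pairs; body is None for leaf statements
def print_stmts(stmts, indent=0):
    lines = []
    for i, (name, body) in enumerate(stmts):
        pad = "  " * indent
        if body is None:
            lines.append(pad + name)
        elif i == len(stmts) - 1:
            # last stmt of the scope: concise form, body at same indent
            lines.append(pad + name)
            lines += print_stmts(body, indent)
        else:
            # not last: must use `with` and an explicit indent
            lines.append(pad + "with " + name + ":")
            lines += print_stmts(body, indent + 1)
    return lines

out = print_stmts([("tir.attr(0, ...)", [("tir.evaluate(0)", None)])])
# ["tir.attr(0, ...)", "tir.evaluate(0)"]
```
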
According to our classification of nodes, nodes within the same category usually follow similar printing and parsing rules. With the evolution of the IR, there will be new nodes in the future; it would be convenient if we can write a few lines of code to register them into our printer and parser.
All the functions (scope handlers, intrins, special stmts) are registered inside the class Registry. Registry holds different dictionaries mapping function names to actual function entries for the different categories of nodes.
The registry mechanism will automatically handle argument parsing and passing, with error reporting for missing arguments. The registered function only needs to consider how to process the arguments.
According to our classification above:
Scope handlers: use register_scope_handler(scope_name, concise) to register. For example:
@register_scope_handler("with_scope", concise=False)
def let(parser, node, var, value, body):
""" With scope handler function let(var, value, body) """
return tvm.tir.LetStmt(var, value, body)
As said, scope handlers are stmts with a body; we further classify them into 2 categories.
With scope handlers (scope_name="with_scope"): since we want to support concise scoping (concise=True), some with scope handlers can appear in two formats
with tir.xxx() # With in Python AST
tir.xxx() # Expr(Call) in Python AST
If we ban concise scoping (concise=False), then the node can only be represented as
with tir.xxx() # With in Python AST
For scope handlers (scope_name="for_scope"): for example,
@register_scope_handler("for_scope")
def range(parser, node, begin, end, for_type="serial"):
""" For scope handler function range(begin, end, for_type)"""
ana = tvm.arith.Analyzer()
extent = end if begin == 0 else ana.simplify(end - begin)
loop_var_name = node.target.id
loop_var = tvm.te.var(loop_var_name, dtype="int32")
parser.scope_emitter.new_scope()
parser.scope_emitter.update_symbol(loop_var_name, loop_var)
body = get_body(parser, node)
parser.scope_emitter.pop_scope()
    for_type_dict = {"serial": 0, "parallel": 1, "vectorized": 2, "unrolled": 3}
return tvm.tir.For(loop_var, begin, extent, for_type_dict[for_type], 0, body)
Their common parsing behavior is
Special stmts: use register_special_stmt to register.
Special stmts can appear in 2 formats. They don't correspond to IR nodes directly.
target = tir.xxx()
, like packedB = tir.var("handle")
@register_special_stmt
def var(parser, node, dtype):
return te.var(parser._assign_target, dtype)
tir.xxx()
, like tir.func_attr({"global_symbol": "default_function", "tir.noalias": True})
@register_special_stmt
def func_attr(parser, node, dict_attr):
parser.dict_attr = dict_attr
Intrins: use register_intrin to register.
@register_intrin
def ramp(base, stride, lanes):
lanes = lanes.value if not isinstance(lanes, int) else lanes
return tvm.tir.Ramp(base, stride, lanes)
An intrin can appear as tir.xxx().
As we mentioned before, it might be useful to think about compute-style syntax (like TE) that can concisely write out a block, without necessarily writing out the outside (loops/schedules). This is a discussion thread to get everyone's thoughts about what the syntax should look like. The primary goals are:
When a block has variables that are not bound, the program first executes the loops that are bound, then it executes the natural loops generated by the unbound block variables.
# simple compute
with block(domain=(10, 20)) as xo, xi:
A[xo, xi] = B[xo + 1, xi]
# compute A[i, :]
with block(domain=grid(10)) as x:
A[x, 0] = B[x, 0]
for i in range(1, 10):
A[x, i] = B[x, i] + A[x, i - 1]
# matmul
with block(domain=(256, 256, 256)) as i, j, k:
sum.step(C[i, j], A[i, k] * B[j, k])
# binded matmul (trivial schedule)
for xi, xj, xk in grid(256, 256, 256):
with block(domain=(256, 256, 256), binds=[xi, xj, xk]) as i, j, k:
sum.step(C[i, j], A[i, k] * B[j, k])
# 8x8 tiled matmul
for xi0, xj0, xk, xi1, xj1 in grid(32, 32, 256, 8, 8):
with block(domain=(256, 256, 256), binds=[xi0 * 8 + xi1, xj0 * 8 + xj1, xk]) as i, j, k:
sum.step(C[i, j], A[i, k] * B[j, k])
Let us use this thread to discuss the Python syntax; I put down some random thoughts here.
Goals
NOTE: I didn't think very carefully about the final syntax, so this is just some food for thought :)
for i in range(10):
with block(info=meta["Block"][0]):
A[i] = B[i] + 1
Meta can store the full block information, like dependencies, loop info, and additional data. The advantage of this is that we can store any info in meta and still be able to read things back, so that the printed code is concise (if we skip meta) but we are still able to get the code back in full form.
Having meta will also help ensure future compatibility (see below).
This is somewhat related to both IR design and syntax. Right now we are doing normal loops, but in the future we can start to think about more interesting iterators (e.g. for sparse data structures):
for i, j in coordinate(A.coo_coord):
for k in range(10):
with block(constraint_info):
A[i] = B[i] + 1
Different iterators will be posed just as different iterator constraints (info included in the block) and properties of the iterator (whether it can be split, reorder relations) in the loops. When designing the infra, we want to think about potential future use-cases.
We should aim to design an infrastructure such that, when we extend the IR to support new transformations, we only need minimal changes to the Python syntax and IR infra. E.g. the special scope_handler would be useful to handle this case.
Python 3.6 starts to have type annotations:
def f(x: int, y: int) -> int:
pass
In some sense, we could view the iterator constraint as a type, or not:
Example mock syntax
for i in range(30):
with block(vi: Axis[30] = i):
pass
We need to look into numpy's API name for the grid.
for i, j in grid(30, 30):
with block((vi, vj) : Grid[(30, 30)] = (i, j)):
pass
Right now all the schedule primitives are in the Schedule class and we have to use operator->() extensively to get the ScheduleNode pointer. Let us do a refactor to move the schedule primitive functions to the ScheduleNode.
Given that master is outdated, let us back up the master and reset dev to master.
Set scope when doing compute_at
We should automatically set the scope when doing compute_at.
scope = "" means its scope can be changed accordingly. In the original TVM schedule:
If a stage is computed at BlockIdx, its scope will be set to "shared" automatically.
If a stage is computed at ThreadIdx, its scope will be set to "local".
If an intermediate stage is computed at an axis of another stage, its scope will be set to "local".
However, our argument of compute_at is a block, not a stage. So I think we should do
Rename Schedule::raw_realize_region and Schedule::raw_realize_scope to Schedule::realize_region and Schedule::realize_scope.
Compare compilation time with the original schedule
Compilation time is important for auto-tuning, as we will compile a lot of programs during auto-tuning for feature extraction. We have to ensure the efficiency of our implementation.
@tvm.hybrid.script
def matmul(a, b, c):
n, m, z = tir.var_tuple(3, "int32")
match a => {
Buffer((m, n)): {
# m and n are defined
}
}
match {a, b} => {
Buffer((m, n)) as A and Buffer((m, z)) as B: {
# m and n are defined
}
}
A = tir.match_buffer(a, (n, z), "float32")
B = tir.match_buffer(b, (m, z), "float32")
C = tir.match_buffer(c, (n, m), "float32")
with tir.block(n, m, tir.reduce_axis(z)) as i, j, k:
sum.step(C[i, j], A[i, k] * B[j, k])
@tvm.hybrid.script
def matmul(a, b, c):
n, m, z = var_tuple(3, "int32")
A = match_buffer(a, (n, z), "float32")
B = match_buffer(b, (m, z), "float32")
C = match_buffer(c, (n, m), "float32")
for xi, xj, xk in tir.grid(n, m, z):
with block(domain=[n, m, tir.reduce_axis(z)], "B") as i, j, k:
tir.block_vars(
i = xi,
j = xj,
k = xk
)
sum.step(C[i, j], A[i, k] * B[j, k])
for xi, xj, xk0, xk1 in tir.grid(n, m, z // 4, 4):
with block([n, m, tir.reduce_axis(z)], "C") as i, j, k:
block.predicate(xk0 * 4 + xk1 < z)
i = block.bind(xi)
j = block.bind(xj)
k = block.bind(xk0 * 4 + xk1)
sum.step(C[i, j], A[i, k] * B[j, k])
for xi, xj, xk0, xk1 in grid(n, m, z // 4, 4):
if xk0 * 4 + xk1 < z:
with block([n, m, tir.reduce_axis(z)], "C") as i, j, k:
for xi, xj, xk0, xk1 in grid(n, m, z // 4, 4):
if tir.predicate(xk0 * 4 + xk1 < z):
with block([n, m, tir.reduce_axis(z)], "C") as i, j, k:
...
for xi, xj, xk0, xk1 in grid(n, m, z // 4, 4):
with block([n, m, tir.reduce_axis(z)], "C") as i, j, k:
tir.predicate(xk0 * 4 + xk1 < z)
tir.block_if
tir.where
...
for xi, xj, xk0, xk1 in grid(n, m, z // 4, 4):
with block([n, m, tir.reduce_axis(z)], "C") as i, j, k:
tir.realize(
i = xi,
j = xj,
predicate=xk0 * 4 + xk1 < z
)
...
@tvm.hybrid.script
def add_pipeline(a, d):
n = tir.var("int32")
A = match_buffer(a, (n,), "float32")
D = match_buffer(d, (n,), "float32")
B = alloc_buffer((n,), "float32")
with block(n) as i:
B[i] = A[i] + 1
C = alloc_buffer((n,), "float32")
with block(n) as i:
C[i] = tir.exp(B[i])
with block(n) as i:
D[i] = C[i] + 1
@tvm.hybrid.script
def add_pipeline(a, d):
n = tir.var("int32")
A = match_buffer(a, (n,), "float32")
D = match_buffer(d, (n,), "float32")
B = compute((n), lambda i: A[i] + 1)
C = compute((n), lambda i: tir.exp(B[i]))
with block(n) as i:
D[i] = C[i] + 1
Python hybrid
Unify TIR and Relay
Schedule primitive
Build
Demo
Key points
The python side has been updated to reflect the final form of the namespace organization
Update Guide apache/tvm#4647
cc @Hzfengsy
After discussion with @Hzfengsy, we put up a proposal for tensorize here.
tensorize
Originally, we wanted to implement blockize first. As for tensorize, we could blockize first and then apply Equal to check the validity of tensorize. But we found it difficult to do so; here are some points.
- blockize requires extracting block vars and inferring their loop bindings. Theoretically, there exist many different valid block vars and loop bindings, as long as they isolate the inner of the block and produce equivalent code, so blockize can use an arbitrary set of block vars and loop bindings.
- tensorize expects a specific set of block vars and loop bindings to match the computation template provided by the user.
For example
for ax0_inner in range(0, 16):
for ax1_inner in range(0, 16):
for ax2_inner in range(0, 16):
with block({vi(0, 1024):(((ax0_outer*256) + (ax0_outer_outer*64)) + ((ax0_outer*16) + ax0_inner)),
vj(0, 1024):(((ax1_outer*256) + (ax1_outer_outer*64)) + ((ax1_outer*16) + ax1_inner)),
vk(0, 1024, iter_type="reduce"):((((ax2_outer_outer*4) + ax2_outer_inner)*16) + ax2_inner)},
writes=[C_wmma.accumulator[vi:(vi + 1), vj:(vj + 1)]],
reads=[C_wmma.accumulator[vi:(vi + 1), vj:(vj + 1)],
A_shared_wmma.matrix_a[vi:(vi + 1), vk:(vk + 1)],
B_shared_wmma.matrix_b[vj:(vj + 1), vk:(vk + 1)]],
name="C"):
reducer0.step(C_wmma.accumulator[vi, vj], (A_shared_wmma.matrix_a[vi, vk]*B_shared_wmma.matrix_b[vj, vk]))
If we simply want to blockize this whole piece of IR, a natural way is that for every loop var outside this subtree, like ax0_outer, ax1_outer, we use a corresponding block var to isolate it. Otherwise it is tricky to further infer the iter types (data_par/reduce/serial/opaque) later. And we haven't come up with another clear and practical rule.
However, if we want to tensorize it to use Tensor Core, the user provides a computation template such as
with block({v0, v1, v2}):
for i in range(0, 16):
for j in range(0, 16):
for k in range(0, 16):
with block({vi(v0, v0 + 16): (v0 + i),
vj(v1, v1 + 16): (v1 + j),
vk(v2, v2 + 16, iter_type="reduce"): (v2 + k)},
writes=[C_wmma.accumulator[vi:(vi + 1), vj:(vj + 1)]],
reads=[C_wmma.accumulator[vi:(vi + 1), vj:(vj + 1)],
A_shared_wmma.matrix_a[vi:(vi + 1), vk:(vk + 1)],
B_shared_wmma.matrix_b[vj:(vj + 1), vk:(vk + 1)]],
name="C"):
reducer0.step(C_wmma.accumulator[vi, vj], (A_shared_wmma.matrix_a[vi, vk]*B_shared_wmma.matrix_b[vj, vk]))
then we have to extract v0 = ((ax0_outer*256) + (ax0_outer_outer*64)) + (ax0_outer*16), and we actually don't care about the iter types of v0, v1, v2.
tensorize
Due to the considerations above, we decided to make tensorize independent of blockize.
Users provide a computation template like the one above and a function that generates the intrinsic block, for example, how to generate a wmma call once v0, v1, v2 are known.
Then tensorize expects a loop_sref/block_sref as argument, which is the tensorize point. We try to match the whole subtree with the computation template, and we can extract values for v0, v1, v2 from the loop bindings. If we don't encounter a contradiction, we can tensorize with this template. Then we substitute the tensorize point with the IR generated by the user-provided function, applying the matched values for v0, v1, v2.
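The extraction step above can be sketched in plain Python. This is a toy matcher over bindings flattened into sums of term strings (a real implementation would walk expression trees); the template binds each block var as `offset + inner_var`, so the inner loop var must occur exactly once and everything else becomes the matched offset (v0, v1, or v2).

```python
def match_offset_binding(binding_terms, inner_var):
    # `binding_terms`: the binding expression flattened into a sum of terms.
    # Succeed only if `inner_var` occurs exactly once; the remaining terms
    # form the matched offset for the template's free variable.
    inner = [t for t in binding_terms if t == inner_var]
    if len(inner) != 1:
        return None  # no match: inner var missing or repeated
    return " + ".join(t for t in binding_terms if t != inner_var)

# Binding of vi from the blockized IR above, flattened into terms:
terms = ["ax0_outer*256", "ax0_outer_outer*64", "ax0_outer*16", "ax0_inner"]
v0 = match_offset_binding(terms, "ax0_inner")
print(v0)  # "ax0_outer*256 + ax0_outer_outer*64 + ax0_outer*16"
```

A contradiction (e.g. two different offsets extracted for the same template variable from different bindings) would make the match fail.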
Quite a lot of places in our IR involve code that introduces a new scope, for example Allocate, Assert, AttrStmt, LetStmt. The simplest way to print these blocks is to print and indent. However, this will create a lot of indentation levels; we should instead use the following special rule for printing (as in the case of printing let blocks in functional programming).
For example, in the following code, we don't have to indent let, because we know that the body of the let is everything that follows it until the end of the current scope.
with block():
A = tir.allocate((10, 20))
tir.let(x = 1)
tir.let(y = 2)
call(x + y)
Similar rules can apply to other cases
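A tiny sketch of the two printing strategies on a toy statement list (hypothetical helpers, not the real printer), contrasting the flat rule with naive per-scope indentation:

```python
def print_scope(stmts):
    # Flat rule: scope-introducing statements (let/allocate) do not indent;
    # the body of each is everything that follows it until the scope ends.
    return "\n".join(stmts)

def print_nested(stmts, indent="  "):
    # Naive alternative: one extra indentation level per introduced scope.
    return "\n".join(indent * i + s for i, s in enumerate(stmts))

stmts = ["tir.let(x = 1)", "tir.let(y = 2)", "call(x + y)"]
print(print_scope(stmts))   # three lines, no indentation
print(print_nested(stmts))  # staircase: one level deeper per let
```

With a long chain of lets the nested form drifts arbitrarily far right, while the flat form stays readable, which is the motivation for the special rule.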
Hardware chips usually have more than one storage and execution hierarchy. NVIDIA GPUs, for example, have GPU blocks (GPU SMs), warps, and CUDA cores, with global, shared and local memory scopes, and each level can access specific memory scopes.
On the other hand, TIR with blocks can also be hierarchical. A block can only access the buffers allocated at the same or an outer scope. Hence, I would like to make it a one-to-one map.
I will walk through a GPU gemm example.
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 32
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 32
block(block_hierarchy="GPU_SM") {
// attr [iter_var(threadIdx.y, range(min=0, ext=8), threadIdx.y)] thread_extent = 8
// attr [iter_var(threadIdx.x, range(min=0, ext=8), threadIdx.x)] thread_extent = 8
block(block_hierarchy="GPU_processor") {
// attr [iter_var(vy, range(min=0, ext=2), vthread)] virtual_thread = 2
// attr [iter_var(vx, range(min=0, ext=2), vthread)] virtual_thread = 2
block(init C.local)
for (k.outer, 0, 256) {
block(copy from A to A.shared)
block(copy from B to B.shared)
for (k.inner.outer, 0, 8) {
block(copy from A.shared to A.local)
block(copy from B.shared to B.local)
block(update C.local)
}
}
block(copy from C.local to C.global)
}
}
- GPU_SM blocks and outer GPU_processor blocks
- GPU_warp blocks can be tensorized to TensorCore intrinsics
We still introduce block_hierarchy for some special uses (e.g. TensorCore) but allow people to directly bind threads without blockize. The possible problem is that it is hard to do checks during the schedule. However, the good news is that we can do checks during BufferFlatten.
I remember that current TVM has a special rule to handle shared memory during bound inference. Maybe I missed some details in that case.
One thing I noticed is that buffer_decl and direct field references can be a bit long. Buffer itself is really a special construct in TVM. Currently, we first create the buffer fields as separate vars, then bind them.
Adata = tvm.var("handle")
A = tir.decl_buffer(data=Adata, shape=(3, 4))
tir.some_func(Adata)
In many cases, when the buffer itself is a variable that first gets declared, we could instead have the following related syntax
A = tir.decl_new_buffer(shape=(3, 4), with_strides=False)
# Refers to Adata
tir.some_func(A.data)
One potential gain is that it might help us reduce the number of lines in the hybrid script. It would be nice to see how concise we can make the buffer-related code.
# A.shape is a var
A = tir.decl_new_buffer(with_strides=False)
B = tir.decl_buffer(shape=A.shape)
tir.some_func(A.data, A.shape[0])
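A toy Python model of the proposed `decl_new_buffer` convenience (hypothetical `Var`/`Buffer` classes, not the TVM API): declaring the buffer also declares its backing data handle and, when no concrete shape is given, symbolic shape vars.

```python
import itertools

_ids = itertools.count()

class Var:
    """A named placeholder variable, standing in for tir.Var."""
    def __init__(self, name, dtype="handle"):
        self.name, self.dtype = name, dtype

class Buffer:
    # Toy model of decl_new_buffer: the data var and shape vars are
    # created together with the buffer instead of being declared first.
    def __init__(self, name=None, shape=None, ndim=1):
        self.name = name or f"buf{next(_ids)}"
        self.data = Var(f"{self.name}_data")  # A.data refers to this handle
        self.shape = shape if shape is not None else tuple(
            Var(f"{self.name}_shape{i}", "int32") for i in range(ndim))

A = Buffer("A")                 # A.data and a symbolic A.shape come for free
B = Buffer("B", shape=A.shape)  # B reuses A's (symbolic) shape, as sketched above
```

This keeps the two-line `tvm.var` + `decl_buffer` pattern to a single declaration, which is exactly the line-count saving discussed above.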
After frequent discussion with @spectrometerHBH, we find there are still some problems with the correctness validation.
Before that, we usually supposed that a block can get correct input as long as the dependency is satisfied (e.g. the consumer block must execute after the producer block). However, that is not always true, especially when the producer block updates its output tensor. Here is an example
for i0 = 0 to 16
for j0 = 0 to 16
block(vj=j0, ..., name="A")
A[j0] = A[j0] + 1
for i1 = 0 to 16
for j1 = 0 to 16
block(...)
B[i1][j1] = A[j1]
In this case, we cannot compute_at block B under loop j0 even though the dependency is satisfied. Here is the wrong code after compute_at.
for i0 = 0 to 16
for j0 = 0 to 16
A[j0] = A[j0] + 1
B[i0][j0] = A[j0]
The problem happens just because the producer block has not yet produced its final, correct output during the loop. The same problem also happens in reorder with branches. Of course, we are not trying to support this case, but we need to ban it from the legal schedules.
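The divergence is easy to demonstrate by simulating both loop nests in plain Python (a sketch of the two programs above, not TIR):

```python
def original(n=16):
    # The two loop nests as written: the producer fully finishes
    # before the consumer starts.
    A = [0] * n
    B = [[0] * n for _ in range(n)]
    for i0 in range(n):
        for j0 in range(n):
            A[j0] += 1                 # block "A" updates A[j0] n times total
    for i1 in range(n):
        for j1 in range(n):
            B[i1][j1] = A[j1]          # consumer sees the final A
    return B

def after_bad_compute_at(n=16):
    # The illegal result of compute_at: the consumer reads A[j0]
    # before the producer has reached its final value.
    A = [0] * n
    B = [[0] * n for _ in range(n)]
    for i0 in range(n):
        for j0 in range(n):
            A[j0] += 1
            B[i0][j0] = A[j0]          # partial value: i0 + 1, not n
    return B

print(original()[0][0], after_bad_compute_at()[0][0])  # 16 1
```

Every element of B should be 16, but the fused version stores the running partial sums, so the dependency check alone is not enough to validate the transformation.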
But it is not an easy thing either.
The most radical solution is to forbid scheduling for IR that comes from hybrid script. It would solve every correctness problem, since we would only support scheduling for te.compute. It is drastic and we will not use it while we have any other way; it would be our last resort.
The key point is that the producer cannot provide valid results. We would like to make sure all the producer blocks are complete. Here is my proposal:
- data_par (not all the block_var)
- compute_at;
- reorder;
- !+= to make the reduction (init update) become a complete block.
With this, the complete block equals the current tvm stage, even the reduction block. We can forbid every risky operation which is not allowed under the completeness restriction.
An insight we can get from the counterexample above is that we haven't come up with a validation mechanism for a block and its surrounding loops. Actually, block A's surrounding loops i0, j0 are not legal. The reasons are stated below.
The spirit of TIR is to express info in the block and generate loops outside the block to satisfy the constraints. Besides range information, we also have to check the binding function, i.e. (vi, vj, ...) = f(i, j, ...). range only tells us which subset of the block's instance space may be run by the outside loops, but it cannot tell how many times a specific instance will be touched (maybe 0, maybe more than 1). In the counterexample above, each instance of block "A" will be run many times because of i0. Hence, if we pose no more constraints, the block info is not complete.
One reasonable constraint for the binding function f is that it builds a 1-1 correspondence between the space of (i,j,k,...) and the space of (vi,vj,vk,...), i.e. the loop iteration and binding function iterate all the instances inside the block var space exactly once. The previous complete-block definition plus this 1-1 correspondence constraint will make sure it is safe to compute_at.
But checking this constraint is difficult, since split and fuse make the binding function rather complicated, and it is hard to come up with a simple pattern that is compatible with the current bindings we have to support and at the same time is provably a 1-1 correspondence.
One algorithm I come up with is that we only consider the following bindings to be legal
1. vi=c1*i+b1, vj=c2*j+b2, vk=c3*k+b3 ... (one loop_var binds one block_var linearly)
2. if f is a legal binding and g is the binding after we apply `split` on f, then g is legal
3. if f is a legal binding and g is the binding after we apply `fuse` on f, then g is legal
and given a binding, we try to reverse the split and fuse effects to get vi=c1*i+b1, vj=c2*j+b2, vk=c3*k+b3. If we fail, then it is not legal.
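For concrete (small) extents, the 1-1 correspondence itself can be checked by brute-force enumeration, which is useful for testing any symbolic legality checker. A minimal sketch (toy code, not the proposed reverse-split/fuse algorithm):

```python
from itertools import product

def is_one_to_one(loop_extents, binding, block_extents):
    # Brute-force check of the 1-1 constraint: the loop nest must produce
    # every block-var instance exactly once.
    seen = set()
    for idx in product(*(range(e) for e in loop_extents)):
        v = binding(*idx)
        if v in seen:
            return False  # some block instance touched more than once
        seen.add(v)
    total = 1
    for e in block_extents:
        total *= e
    return len(seen) == total  # and every instance touched at least once

# A split binding k = xk0 * 4 + xk1 is a bijection onto [0, 16):
ok = is_one_to_one([4, 4], lambda xk0, xk1: (xk0 * 4 + xk1,), [16])
# The counterexample's binding vj = j0 ignores i0, so instances repeat:
bad = is_one_to_one([16, 16], lambda i0, j0: (j0,), [16])
print(ok, bad)  # True False
```

The symbolic reverse-split/fuse check is still needed for symbolic extents, but an enumerator like this pins down the property being proved.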
Between those 2 (excluding Proposal 0), we both prefer Proposal 1 for now.
The core idea is to use python AST’s structure to reflect TIR’s structure.
We divide IRNodes into 3 categories.
Intrinsic
In short, intrinsics are PrimExprs and Stmts without a body, i.e. PrimExpr, Store, ReduceStep.
All these nodes appear as tir.xxx().
Special Stmts
In principle, we can represent the whole IR using only 1 & 2. But some nodes, or some parts of a node, can be printed in a better format.
This is food for thought as possible future work, not necessarily actionable right now.
https://discuss.tvm.ai/t/rfc-bring-in-tensor-expression-autodiff/5987
It discusses how we can introduce tensor-expression-level AD to te.compute. It would be interesting to think about how we can generalize it to the TIR level. In particular, if we place restrictions, such as making sure all blocks are complete, would we be able to run autodiff on TIR written directly in hybrid script?
It would be useful to discuss and align possible designs now so we can be prepared for such a change, if it is possible.
block_list
and dep_graph
One goal we have in the hybrid script is to support round-tripping at any stage of the compilation. I tried to mock up an example to serve as a basis and goalpost for our discussions
def lower_phase0(a, c):
n = tir.var()
A = buffer_bind(a, (n,), "float32")
C = buffer_bind(c, (n,), "float32")
B = buffer_allocate((n,), "float32")
with block([n]) as i:
B[i] = A[i] + 1
with block([n]) as i:
C[i] = B[i] * 2
def lower_phase1_complete(a, c):
n = tir.var()
A = buffer_bind(a, (n,), "float32")
C = buffer_bind(c, (n,), "float32")
B = buffer_allocate((n,), "float32")
for xi0 in grid(n):
with block([n]) as i0:
i0 = block_var_bind(xi0)
B[i0] = A[i0] + 1
for xi1 in grid(n):
with block([n]) as i1:
i1 = block_var_bind(xi1)
C[i1] = B[i1] * 2
def lower_phase2_simplified(a, c):
n = tir.var()
A = buffer_bind(a, (n,), "float32")
C = buffer_bind(c, (n,), "float32")
B = buffer_allocate((n,), "float32")
for xi0 in grid(n):
B[xi0] = A[xi0] + 1
for xi1 in grid(n):
C[xi1] = B[xi1] * 2
def lower_phase3_after_flatten(a, c):
n = tir.var()
A = buffer_bind(a, (n,), "float32")
C = buffer_bind(c, (n,), "float32")
Bdata = allocate((n,), "float32")
for xi0 in grid(n):
tir.store(Bdata, xi0, A.data[xi0] + 1)
for xi1 in grid(n):
tir.store(C.data, xi1, Bdata[xi1] * 2)
def lower_phase4_packedapi_lowering(args: handle, type_code: handle, nargs: int):
arg0 = tir.tvm_load_arg_handle(args, 0)
arg1 = tir.tvm_load_arg_handle(args, 1)
Adata = tir.tvm_struct_get(arg0, TVMStructKind.DATA)
Cdata = tir.tvm_struct_get(arg1, TVMStructKind.DATA)
Ashape = tir.tvm_struct_get(arg0, TVMStructKind.SHAPE)
n = tir.load(Ashape, 0, "int32")
Bdata = allocate((n,), "float32")
for xi0 in grid(n):
tir.store(Bdata, xi0, tir.load(Adata, xi0, "float32") + 1)
for xi1 in grid(n):
tir.store(Cdata, xi1, tir.load(Bdata, xi1, "float32") * 2)
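All the phases above must compute the same values; only the representation changes. A plain-Python sketch (not TIR) contrasting the block form with the flattened store form makes the invariant testable:

```python
def phase0(a):
    # Block form: B[i] = A[i] + 1 ; C[i] = B[i] * 2
    n = len(a)
    B = [a[i] + 1 for i in range(n)]
    return [B[i] * 2 for i in range(n)]

def phase3_flattened(a):
    # After BufferFlatten: explicit stores into a flat allocation,
    # mirroring tir.store(Bdata, xi0, ...) and tir.store(C.data, xi1, ...).
    n = len(a)
    Bdata = [0.0] * n
    Cdata = [0.0] * n
    for xi0 in range(n):
        Bdata[xi0] = a[xi0] + 1
    for xi1 in range(n):
        Cdata[xi1] = Bdata[xi1] * 2
    return Cdata

print(phase0([1.0, 2.0]))            # [4.0, 6.0]
print(phase3_flattened([1.0, 2.0]))  # [4.0, 6.0]
```

Round-trip support means each phase can be printed in hybrid script and re-parsed without changing this observable behavior.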
Thanks to @spectrometerHBH for the first PR of hybrid script. Now that it has almost landed, we can start to think about how to further enhance the syntax and improve the use cases:
Let us come up with a few cases where users manually write hybrid script. Of course, many of these cases will only become useful once we also add automatic insertion of blocks. But it is useful to be able to parse some of them. The goal is to come up with syntax sugar that makes hybrid script as easy as te.compute functions.
@tvm.hybrid.script
def addone(a, c):
A = tir.bind_buffer(a, (100,), "float32")
C = tir.bind_buffer(c, (100,), "float32")
B = tir.alloc_buffer((100, ), "float32")
for i in range(100):
B[i] = A[i] + 1
for i in range(100):
C[i] = tir.exp(B[i])
It might be helpful to look into a few cases of te.hybrid (in topi) and see if we can migrate them to the new hybrid, adding necessary syntax sugar if needed.
Our goal is to eventually remove te.hybrid and use the tvm.hybrid as the only hybrid parser/printer.
Let us make sure we can handle errors that arise along the way, and that error messages are friendly, as we try to manually write cases.
As we know, block vars are bridges between the inside of a block and the outside; that is their basic definition. However, this definition is too general for operations such as compute_at. We may need to add more constraints to simplify the operation.
There are different levels of constraint I can imagine:
producer_block(v1, v2, writes: A[v1][v2], B[v1][v1+v2])
producer_block(reads=A[0:4][0:4], B[0:4][0:8]) # v1 = [0:4], v2=[0:4]
Pros: It is still not easy to do tensorize, but it is possible to generate loops during compute_at. Also we can do block auto-generation and blockization.
Cons: It might not handle the following compute_at case
producer_block(v1, v2, writes: A[v1][v2], B[v1][v2+1])
In this case, to satisfy the constraint, it will be changed to
producer_block(v1, v2, v3, writes: A[v1][v2], B[v1][v3])
This change will not only lose the information v3 = v2 + 1, but will also generate another loop for v3 during compute_at.
Pros: Easy for tensorization, blockization and compute_at.
Cons: Same as L1. However, A[i] = B[i + 1] is not available in this case. It might have problems during padding.
Please have a look and share your ideas about this. @tqchen @spectrometerHBH
NOTE, the upstream structure is now stable, check apache/tvm#4647 for update guide, cc @Hzfengsy @spectrometerHBH
SampleFusibleLoops and bug fix
Test
Direct translation of GPU code
https://github.com/merrymercy/tvm-tensorir/blob/f6825efc70200a39bfcb697f52a187585edcd3ab/tensorir/tests/test_schedule.py#L175-L180
Schedule primitive support : bind
https://github.com/merrymercy/tvm-tensorir/blob/f6825efc70200a39bfcb697f52a187585edcd3ab/tensorir/tests/test_schedule.py#L195-L200
Memory scope, especially GPU memory scope, is a mature solution in TVM for scheduling with hierarchical hardware memory. We now face two problems in the TIR schedule (also in the TE schedule): cooperative fetching, and tensorize with warp-level instructions (Tensor Core).
During GPU gemm scheduling, we will have an intermediate step like this:
for i = 0 to n
for j = 0 to n
A_shared[i, j] = A[i, j]
for i = 0 to n
for j = 0 to n
B_shared[i, j] = B[i, j]
for blockIdx.x = 0 to p
for blockIdx.y = 0 to p
for threadIdx.x = 0 to q
for threadIdx.y = 0 to q
....
Then if we would like to compute_at the shared memory copy under threadIdx.y, we need to relax the threadIdx loops. Here is the target IR:
for blockIdx.x = 0 to p
for blockIdx.y = 0 to p
for threadIdx.x = 0 to q
for threadIdx.y = 0 to q
for i = 0 to n/p
for j = 0 to n/p
A_shared[blockIdx.x*p + i, blockIdx.y * p + j] = A[blockIdx.x*p + i, blockIdx.y * p + j]
for i = 0 to n/p
for j = 0 to n/p
B_shared[blockIdx.x*p + i, blockIdx.y * p + j] = B[blockIdx.x*p + i, blockIdx.y * p + j]
....
In the TE schedule, TensorCore scheduling does not guarantee absolute correctness. We need another check to confirm that the warp-level instructions do execute at warp level.
Each block adds one execution scope: thread, warp, or block for GPU. For example:
block(thread-level)
A_shared[i, j] = A[i, j]
block(warp-level)
wmma.sync()
or
block(warp-level)
for i = 0 to 32 (bind = threadidx.x)
A_local[i] = A_shared[i]
block(block-level)
for i = 0 to n (bind = threadidx.x)
for j = 0 to n (bind = threadidx.y)
A_shared[i, j] = A[i, j]
Usually, a thread-level block can be defined before the schedule, and a warp-level block can be defined by a tensor intrinsic. However, a block-level block may be defined by blockize.
Now we have the block execution scope and the block write buffer. We can check whether a block is in thread-level scope and also writes a shared buffer; if so, we can simply relax threadIdx during compute_at. We can make a builtin map to record that information.
There are two chances to check correctness: during the schedule and after the schedule.
If we check during the schedule, we can get the failure information as early as possible, but one challenge is that we don't have complete information (e.g. we have bound threadIdx.y but have not yet bound threadIdx.x).
If we check after the schedule, the check itself is simple. The only disadvantage is that errors are reported a little late.
I'd like to check the correctness after the schedule and before BufferFlatten.
Might be a good item after apache/tvm#5372
This PR brings most of TIR close to what we have in this repo. It might be a good chance to sync after that PR lands. Some noticeable highlights include:
Some related changes in this repo:
I've come up with several ways to declare a reduction in hybrid script
@tvm.tir.hybrid.script
def add(x, y):
return x + y
@tvm.tir.hybrid.script
def mul(x, y):
return x * y
C[i, j] = update("add", A[i, k] * B[k, j], initial = 0)
C[i, j] = update("mul", A[k], initial = 1)
C[i, j] = update(add(C[i, j], A[i, k] * B[k, j]), initial = 0)
C[i, j] = update(mul(C[i, j], A[k]), initial = 1)
C[i,j] = inc(A[i, k] * B[k, j]) # sugar
C[i,j] = mul(A[k]) # sugar
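Whichever surface syntax wins, the semantics is a fold of the per-k values into an accumulator seeded with `initial`. A toy Python sketch of that semantics (hypothetical `reduce_axis` helper, not any of the proposed APIs):

```python
def add(x, y):
    return x + y

def reduce_axis(extent, combiner, initial, fvalue):
    # Toy semantics of the proposed update(combiner, value, initial):
    # fold the per-k values into an accumulator starting at `initial`.
    acc = initial
    for k in range(extent):
        acc = combiner(acc, fvalue(k))
    return acc

A = [1, 2, 3]
B = [4, 5, 6]
# One output element of C[i, j] = update("add", A[k] * B[k], initial=0):
c = reduce_axis(3, add, 0, lambda k: A[k] * B[k])
print(c)  # 32, i.e. 1*4 + 2*5 + 3*6
```

The combiner must be associative and commutative with `initial` as its identity for the schedule to be free to reorder or parallelize the reduction axis.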
Representing it in C++ is more troublesome.
1. Introduce a Comm_Reduce node; but it's not a Stmt node, so it cannot be directly attached to the AST. We can wrap it with an AttrStmt node, but it seems weird to do so.
2. Replace Comm_Reduce with Reduce; but only the combiner attribute of Reduce is used.
cc @tqchen @Hzfengsy if you have any suggestions or preferences.
This is part of the CUDA pseudocode for gemm using Tensor Core.
// some code for copying to shared memory
for (int k_step = 0; k_step < CHUNK_K; k_step++) {
wmma::fragment<wmma::matrix_a, M, N, K, half, wmma::row_major>
a[WARP_COL_TILES];
wmma::fragment<wmma::matrix_b, M, N, K, half, wmma::col_major>
b[WARP_ROW_TILES];
for (int i = 0; i < WARP_COL_TILES; i++) {
wmma::load_matrix_sync(a[i], tile_ptr, K * CHUNK_K + SKEW_HALF);
}
for (int j = 0; j < WARP_ROW_TILES; j++) {
wmma::load_matrix_sync(b[j], tile_ptr, K * CHUNK_K + SKEW_HALF);
}
for (int i = 0; i < WARP_COL_TILES; i++) {
for (int j = 0; j < WARP_ROW_TILES; j++) {
wmma::mma_sync(c[i][j], a[i], b[j], c[i][j]);
}
}
}
The key difference between Tensor Core and other tensorization is shared memory and the block/thread architecture. Each wmma operator calculates a 16x16 gemm while using all the threads in a warp, so it is necessary to consider a new schedule in the CUDA schedule.
Help wanted:
wmma::mma_sync(c[i][j], a[i], b[j], c[i][j]);
by tensorization, but it is difficult to declare the wmma::fragment, split the data, and store it into fragments.