nod-ai / iree-amd-aie
IREE plugin repository for the AMD AIE accelerator
License: Apache License 2.0
This is a tracker for all the ops we keep adding support for, along with the shape and element-type constraints that are currently supported.
Operation | Shapes | Element type | Action item |
---|---|---|---|
Matmul | 8 x (8, 16, 32, 64, 2048) x (8, 16, 32, 64) | i32 | #124 (comment) |
Matmul transpose | 8 x (8, 16, 32, 64, 2048) x (8, 16, 32, 64) | i32 | #124 (comment) |
Batch matmul | | | #96 (comment) |
This describes a numerical issue and the current fix in #218.
A brief history and summary of findings:
On CI there are also failures, but they are much rarer (roughly 1 failure in 1000 runs).
To narrow down the issue, I wanted to check whether the error reproduces on a different stack. I built mlir-air and managed to run https://github.com/Xilinx/mlir-air/tree/main/test/xrt/08_gemm_extern_vec successfully (thanks @erwei-xilinx for helping me).
I modified the above mlir-air test to use the xclbin and LX6 instructions generated in #208. Running it 2,000 times, I saw no numerical errors. This made me think that this is NOT a compiler error (the xclbin seems to be totally fine when run through mlir-air), but rather something specific to our setup in iree-amd-aie. I started digging into the runtime, adding print statements to see when we sync; see branch https://github.com/newling/iree-amd-aie/tree/intermittent_numerics_debugging for the print additions.
I noticed that the mlir-air matmul test syncs the 'C' tensor (C = matmul(A, B)) before running the kernel, but iree-amd-aie does not (iree-amd-aie only syncs C after running the kernel). @nirvedhmeshram and I found that adding the sync here (see this PR), so that C is synced before the matmul, resolves the intermittent numerical error.
Consider this test case:
func.func @matmul_64x64_32xi32_(%lhs : tensor<64x16xi32>, %rhs : tensor<16x64xi32>) -> tensor<64x64xi32> {
%init_acc = tensor.empty() : tensor<64x64xi32>
%c0_acc_type = arith.constant 7 : i32
%acc = linalg.fill ins(%c0_acc_type : i32) outs(%init_acc : tensor<64x64xi32>) -> tensor<64x64xi32>
%result = linalg.matmul ins(%lhs, %rhs: tensor<64x16xi32>, tensor<16x64xi32>) outs(%acc: tensor<64x64xi32>) -> tensor<64x64xi32>
return %result: tensor<64x64xi32>
}
That is, a matmul whose initial output buffer is filled with 7 (C = 7 + A @ B). Compiling this through iree-amd-aie and then running iree-run-module (before the fix in this PR) as follows:
iree-run-module --module=success.vmfb --device=xrt --input="64x16xi32=1" --input="16x64xi32=2" --expected_output="64x64xi32=39"
succeeds in about 50% of runs (for me, locally). The syncs printed (see branch https://github.com/newling/iree-amd-aie/tree/intermittent_numerics_debugging) for a successful run are as follows:
Syncing instr buffer to device (XCL_BO_SYNC_BO_TO_DEVICE)
Syncing buffer from device (XCL_BO_SYNC_BO_FROM_DEVICE)
This is a buffer of size 4096
This buffer is at address 94100515692544
Syncing buffer to device (XCL_BO_SYNC_BO_TO_DEVICE)
This is a buffer of size 4096
This buffer is at address 94100515692544
Syncing buffer from device (XCL_BO_SYNC_BO_FROM_DEVICE)
This is a buffer of size 4096
This buffer is at address 94100515700736
Syncing buffer to device (XCL_BO_SYNC_BO_TO_DEVICE)
This is a buffer of size 4096
This buffer is at address 94100515700736
EXEC @matmul_64x64_32xi32_
Binding count: 3
Syncing buffer from device (XCL_BO_SYNC_BO_FROM_DEVICE)
This is a buffer of size 16384
This buffer is at address 94100515893248
Syncing buffer to device (XCL_BO_SYNC_BO_TO_DEVICE)
This is a buffer of size 16384
This buffer is at address 94100515893248
[SUCCESS] all function outputs matched their expected values.
The other ~50% of the time, the run fails. The logs for a failing run look like this:
Syncing instr buffer to device (XCL_BO_SYNC_BO_TO_DEVICE)
Syncing buffer from device (XCL_BO_SYNC_BO_FROM_DEVICE)
This is a buffer of size 4096
This buffer is at address 94816848805888
Syncing buffer to device (XCL_BO_SYNC_BO_TO_DEVICE)
This is a buffer of size 4096
This buffer is at address 94816848805888
Syncing buffer from device (XCL_BO_SYNC_BO_FROM_DEVICE)
This is a buffer of size 4096
This buffer is at address 94816848814080
Syncing buffer to device (XCL_BO_SYNC_BO_TO_DEVICE)
This is a buffer of size 4096
This buffer is at address 94816848814080
EXEC @matmul_64x64_32xi32_
Binding count: 3
Syncing buffer from device (XCL_BO_SYNC_BO_FROM_DEVICE)
This is a buffer of size 16384
This buffer is at address 94816849006592
Syncing buffer to device (XCL_BO_SYNC_BO_TO_DEVICE)
This is a buffer of size 16384
This buffer is at address 94816849006592
Syncing buffer from device (XCL_BO_SYNC_BO_FROM_DEVICE)
This is a buffer of size 16384
This buffer is at address 94816849006592
Syncing buffer to device (XCL_BO_SYNC_BO_TO_DEVICE)
This is a buffer of size 16384
This buffer is at address 94816849006592
Syncing buffer from device (XCL_BO_SYNC_BO_FROM_DEVICE)
This is a buffer of size 16384
This buffer is at address 94816849006592
Syncing buffer to device (XCL_BO_SYNC_BO_TO_DEVICE)
This is a buffer of size 16384
This buffer is at address 94816849006592
[FAILED] result[0]: element at index 4080 (0) does not match the expected (39); expected that the view is equal to contents of a view of 64x64xi32
expected:
64x64xi32=[39 39 39 39 39 39 39 39 39 39 ...
Comparing the syncs printed in the successful and failing runs, one curious difference stands out: when the run fails, C (the buffer of size 16384) is synced three times after the run (three from-device/to-device pairs), whereas when it passes, it is synced only once.
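This comparison can be automated. The sketch below is a hypothetical helper (it is not part of the debugging branch) that scans a log in the format printed above. For a given buffer size, it counts the syncs per direction that happen after the EXEC line. The short logs embedded in it are trimmed excerpts of the logs above. For the 16384-byte C buffer, it reports one from-device/to-device pair for the passing excerpt and three for the failing one.

```python
import re
from collections import Counter

def count_syncs_after_exec(log: str, size: int) -> Counter:
    """Count syncs per direction for buffers of `size` bytes, after the EXEC line."""
    counts = Counter()
    after_exec = False
    direction = None
    for line in log.splitlines():
        line = line.strip()
        if line.startswith("EXEC @"):
            after_exec = True
        elif line.startswith("Syncing buffer"):
            # e.g. "Syncing buffer to device (XCL_BO_SYNC_BO_TO_DEVICE)"
            match = re.search(r"\((XCL_BO_SYNC_BO_\w+)\)", line)
            direction = match.group(1) if match else None
        elif line.startswith("This is a buffer of size"):
            if after_exec and direction and int(line.split()[-1]) == size:
                counts[direction] += 1
            direction = None
    return counts

# Trimmed excerpts of the logs above (address lines omitted).
PRE = """Syncing buffer from device (XCL_BO_SYNC_BO_FROM_DEVICE)
This is a buffer of size 4096
"""
PAIR = """Syncing buffer from device (XCL_BO_SYNC_BO_FROM_DEVICE)
This is a buffer of size 16384
Syncing buffer to device (XCL_BO_SYNC_BO_TO_DEVICE)
This is a buffer of size 16384
"""
EXEC = "EXEC @matmul_64x64_32xi32_\nBinding count: 3\n"
PASSING_LOG = PRE + EXEC + PAIR
FAILING_LOG = PRE + EXEC + PAIR * 3

print(count_syncs_after_exec(PASSING_LOG, 16384))
print(count_syncs_after_exec(FAILING_LOG, 16384))
```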
If we instead run the command
iree-run-module --module=success.vmfb --device=xrt --input="64x16xi32=1" --input="16x64xi32=2" --output_max_element_count=99999
so that all output values are printed, then on a failing run the output ends with:
39 39 39 39 39 39 39 39 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
That is, there are 0s where 39s are expected.
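Here is a quick arithmetic check of the numbers above, as a small Python sketch. With A filled with 1, B filled with 2, and K = 16, every element of C = 7 + A @ B is 7 + 16 * 1 * 2 = 39. The flat index 4080 reported by iree-run-module falls in the last row of the 64x64 result. A 64x64xi32 buffer is 16384 bytes, which is why C is the buffer of size 16384 in the logs.

```python
def expected_value(init: int, a: int, b: int, k: int) -> int:
    # C = init + A @ B with A and B filled with constants a and b:
    # each element is init + k * a * b.
    return init + k * a * b

def flat_to_row_col(index: int, num_cols: int) -> tuple:
    # Row-major flattening of a 2-D tensor: index = row * num_cols + col.
    return divmod(index, num_cols)

print(expected_value(init=7, a=1, b=2, k=16))  # 39
print(flat_to_row_col(4080, 64))               # (63, 48)
print(64 * 64 * 4)                             # 16384 bytes in a 64x64xi32 buffer
```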
This makes us think that IREE is writing 0s to C after the AIE has written the result (39) back from the device. PR #224 illustrates this theory: in it, we manually exercise the cases where we (1) do or do not write to the XRT buffer, and (2) do or do not sync it afterwards.
In the case where we (1) write to the XRT buffer and then (2) do not sync, we reproduce the iree-amd-aie bug that this PR aims to fix.
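The theory can be pictured with a deliberately simplified Python model. It is a toy model, not XRT's implementation, and its key assumption is not confirmed anywhere above. The assumption is that zero-initializing C on the host leaves dirty host-side data, which may be written back to device DDR at some arbitrary later time (the hypothetical evict step). Syncing C to the device before the dispatch, as the fix in this PR does, clears that stale data before the kernel writes its result.

```python
class ToyNonCoherentBuffer:
    """Toy model: a host-side (cached) view of a buffer plus the device DDR copy.

    Assumption for illustration only: dirty host data can be written back to
    DDR at an arbitrary later time, modelled by `evict`.
    """

    def __init__(self, n: int):
        self.host = [0] * n
        self.ddr = [0] * n
        self.dirty = False

    def host_fill(self, value: int):
        self.host = [value] * len(self.host)
        self.dirty = True

    def sync_to_device(self):  # analogous to XCL_BO_SYNC_BO_TO_DEVICE
        self.ddr = list(self.host)
        self.dirty = False

    def sync_from_device(self):  # analogous to XCL_BO_SYNC_BO_FROM_DEVICE
        self.host = list(self.ddr)
        self.dirty = False

    def device_write(self, value: int):  # the AIE kernel writes its result
        self.ddr = [value] * len(self.ddr)

    def evict(self):
        # Hypothetical deferred write-back: only happens if host data is dirty.
        if self.dirty:
            self.ddr = list(self.host)
            self.dirty = False


def simulate(sync_before_kernel: bool) -> list:
    c = ToyNonCoherentBuffer(4)
    c.host_fill(0)            # host-side zero-initialization of C
    if sync_before_kernel:
        c.sync_to_device()    # the fix: push host writes before the dispatch
    c.device_write(39)        # kernel computes C = 7 + A @ B
    c.evict()                 # stale host data written back late (if still dirty)
    c.sync_from_device()      # host reads the result
    return c.host

print(simulate(sync_before_kernel=False))  # [0, 0, 0, 0]
print(simulate(sync_before_kernel=True))   # [39, 39, 39, 39]
```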
Working on the assumption that IREE is doing a memset to 0 of 64*64*sizeof(int) bytes on the XRT DDR buffer, where is this happening? At this point I have unfortunately hit a brick wall. I have inserted hundreds of print statements in the IREE runtime but cannot narrow down where such a memset happens.
Note that such a memset is not necessary for correctness: the matmul is not an accumulation, so the initial value of the C buffer can safely be garbage. Ideally IREE would recognize that and not set C to zero.
Suppose instead that IREE must initialize C with zeros. Is there somewhere in IREE where we can call iree_hal_buffer_invalidate_range and/or iree_hal_buffer_flush_range? These are the functions in which we issue the syncs. Here too I have unfortunately hit a wall trying to decipher the runtime.
This is not urgent, but at some point I would like an expert opinion (@benvanik, @antiagainst) on the following:
Where is C being 0-initialized in the HAL runtime, and can this be avoided?
Where is iree_hal_xrt_direct_command_buffer_dispatch called from in the IREE runtime, and would it make sense for that caller to also call iree_hal_buffer_invalidate_range and/or iree_hal_buffer_flush_range?
See also the comment: #218 (comment)
Currently, the backend bails out in the MaterializeInterface pass because we don't have an ExecutableTargetAttr attached:
./tools/iree-compile --iree-hal-target-backends=amd-aie ../SRT/tests/e2e/linalg/large_linalg_matmul.mlir
../SRT/tests/e2e/linalg/large_linalg_matmul.mlir:13:8: error: no executable targets specified for translation
%D = linalg.matmul ins(%lhs, %rhs: tensor<2048x1024xf32>, tensor<1024x512xf32>)
^
Vivek please add details
To reproduce, you need PR #63 and the mlir-air branch https://github.com/spaceotter/mlir-air/tree/iree-up2. Remember to update the IREE checkout and its submodules according to sync_deps.py.
I compile the matmul test like this:
../iree/build/tools/iree-compile --iree-hal-target-backends=amd-aie /proj/xcohdstaff6/erieaton/iree-amd-aie/tests/samples/matmul_fill_static_i32.mlir --iree-codegen-transform-dialect-library=/proj/xcohdstaff6/erieaton/iree-amd-aie/tests/samples/matmul_fill_spec_pad.mlir --iree-amd-aie-peano-install-dir=../install --iree-amd-aie-mlir-aie-install-dir=../mlir-aie/install --iree-amd-aie-vitis-install-dir=/proj/xbuilds/SWIP/2023.2_1013_2256/installs/lin64/Vitis/2023.2 --iree-hal-dump-executable-files-to=/proj/xcohdstaff6/erieaton/test4 --iree-hal-dump-executable-intermediates-to=/proj/xcohdstaff6/erieaton/test4 --iree-amd-aie-show-invoked-commands --mlir-print-ir-after-all --mlir-print-ir-before=air-dma-to-channel
It results in this error:
/proj/xcohdstaff6/erieaton/test4/configured_module_matmul_static_dispatch_0.mlir:19:8: error: failed to legalize operation 'affine.apply' marked as erased
%7 = linalg.matmul ins(%3, %4 : tensor<8x16xi32>, tensor<16x8xi32>) outs(%6 : tensor<8x8xi32>) -> tensor<8x8xi32>
^
/proj/xcohdstaff6/erieaton/test4/configured_module_matmul_static_dispatch_0.mlir:19:8: note: see current operation: %67 = "affine.apply"(%arg6) <{map = affine_map<()[s0] -> (s0 * 8)>}> : (index) -> index
<unknown>:0: note: found live user of result #0: air.execute_terminator %22 : index
<unknown>:0: error: error
Full log here: out.txt
Updating the mlir-air c1dc9a and mlir-aie 5164b3 submodules to latest leads to a failure in the following lit test for simple-pack (already part of main):
func.func @matmul_large(%lhs: tensor<2048x512xi32>, %rhs: tensor<512x2048xi32>) -> tensor<2048x2048xi32> {
%empty = tensor.empty() : tensor<2048x2048xi32>
%cst = arith.constant 0 : i32
%fill = linalg.fill ins(%cst : i32) outs(%empty : tensor<2048x2048xi32>) -> tensor<2048x2048xi32>
%res = linalg.matmul ins(%lhs, %rhs: tensor<2048x512xi32>, tensor<512x2048xi32>)
outs(%fill: tensor<2048x2048xi32>) -> tensor<2048x2048xi32>
return %res : tensor<2048x2048xi32>
}
I've added the failure e2e log and the successful e2e log here.
Dispatch type | Element type | Shapes | Link | Running on AIE |
---|---|---|---|---|
Elementwise | i64/f32 | (8,8) | link | No |
Scan | i64 | (1,8) | link | No |
Elementwise w/ tensor.extract | i64 | (1,8,2048) | link | No |
Reduction + Elementwise | f32 | (8,2048) | link | No |
Reduction + Elementwise | f32 | (8,2048) | link | No |
Matmul transpose b | f32 | (8,2048,2048) | link | No |
Elementwise | f32 | (8,32,64) | link | No |
Elementwise | f32 | (8,32,64) | link | No |
Batch matmul transpose b + Elementwise | f32 | (32,8,8,64) | link | No |
Softmax + Elementwise | f32 | (32,8,8) | link | No |
Batch matmul + Elementwise (transpose) | f32 | (32,8,64,8) | link | No |
Matmul transpose b + Elementwise | f32 | (8,2048,2048) | link | No |
Matmul transpose b + Elementwise | f32 | (8,8192,2048) | link | No |
Matmul transpose b + Elementwise | f32 | (8,2048,8192) | link | No |
Matmul transpose b | f32 | (8,50272,2048) | link | No |
Since we are currently using the run_pp mode, I was wondering whether this crash is expected when changing the kernel.json and the kernel call without changing anything else.
This is the early part of my new kernel.json
"ps-kernels": {
"kernels": [
{
"arguments": [
{
"address-qualifier": "GLOBAL",
"name": "opcode",
"offset": "0x00",
"type": "uint64_t"
},
{
"address-qualifier": "GLOBAL",
"memory-connection": "SRAM",
"name": "instr",
"offset": "0x08",
"type": "char *"
},
{
"address-qualifier": "SCALAR",
"name": "ninstr",
"offset": "0x10",
"type": "uint64_t"
},
At runtime, I am doing the following:
xrt::kernel kernel = *kernel_params.kernel;
xrt::bo instr = *kernel_params.instr;
uint32_t num_instr = kernel_params.num_instr;
xrt::run run = xrt::run(kernel);
// Index to push arguments on the kernel.
iree_host_size_t arg_index = 0;
// Argument 0 is the opcode.
run.set_arg(arg_index++, 0x2);
// Argument 1 is the LX6 instruction buffer.
run.set_arg(arg_index++, instr);
It crashes on the last line with the following backtrace:
terminate called after throwing an instance of 'std::out_of_range'
what(): vector::_M_range_check: __n (which is 65535) >= this->size() (which is 2)
Program received signal SIGABRT, Aborted.
__pthread_kill_implementation (no_tid=0, signo=6, threadid=140737352633344) at ./nptl/pthread_kill.c:44
44 ./nptl/pthread_kill.c: No such file or directory.
(gdb) bt
#0 __pthread_kill_implementation (no_tid=0, signo=6, threadid=140737352633344) at ./nptl/pthread_kill.c:44
#1 __pthread_kill_internal (signo=6, threadid=140737352633344) at ./nptl/pthread_kill.c:78
#2 __GI___pthread_kill (threadid=140737352633344, signo=signo@entry=6) at ./nptl/pthread_kill.c:89
#3 0x00007ffff7442476 in __GI_raise (sig=sig@entry=6) at ../sysdeps/posix/raise.c:26
#4 0x00007ffff74287f3 in __GI_abort () at ./stdlib/abort.c:79
#5 0x00007ffff78a4f26 in ?? () from /lib/x86_64-linux-gnu/libstdc++.so.6
#6 0x00007ffff78b6d9c in ?? () from /lib/x86_64-linux-gnu/libstdc++.so.6
#7 0x00007ffff78b6e07 in std::terminate() () from /lib/x86_64-linux-gnu/libstdc++.so.6
#8 0x00007ffff78b7068 in __cxa_throw () from /lib/x86_64-linux-gnu/libstdc++.so.6
#9 0x00007ffff78a82f1 in ?? () from /lib/x86_64-linux-gnu/libstdc++.so.6
#10 0x00007ffff7d8d9fa in (anonymous namespace)::encoded_bitset<64ul>::test(unsigned long) const () from /opt/xilinx/xrt/lib/libxrt_coreutil.so.2
#11 0x00007ffff7d933cc in xrt::run_impl::validate_bo_at_index(unsigned long, xrt::bo const&) () from /opt/xilinx/xrt/lib/libxrt_coreutil.so.2
#12 0x00007ffff7d89c4e in xrt::run::set_arg_at_index(int, xrt::bo const&) () from /opt/xilinx/xrt/lib/libxrt_coreutil.so.2
#13 0x0000555555629146 in xrt::run::set_arg (this=0x7fffffffb7f8, index=1, boh=...) at /opt/xilinx/xrt/include/xrt/xrt_kernel.h:439
#14 iree_hal_xrt_direct_command_buffer_dispatch (base_command_buffer=<optimized out>, executable=<optimized out>, entry_point=<optimized out>, workgroup_x=<optimized out>, workgroup_y=<optimized out>, workgroup_z=<optimized out>)
at /proj/xsjhdstaff4/nmeshram/iree-amd-aie/runtime/src/iree-amd-aie/driver/xrt/direct_command_buffer.cc:335
I wonder if this is because the instruction bo is not right for what I am trying to do?
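One quick sanity check is whether the set_arg indices line up with the argument order in kernel.json. The sketch below takes the three arguments from the fragment above, wraps them in just enough JSON to parse, and lists them in declaration order. It assumes, without confirmation, that the XRT argument index equals the declaration order. Under that assumption, it shows that index 1 (where the crash happens) is the instr argument with "memory-connection": "SRAM". It does not diagnose the out_of_range error itself.

```python
import json

# The three arguments from the kernel.json fragment above, wrapped so that the
# string is valid JSON (the rest of the real file is omitted).
KERNEL_JSON = """
{"ps-kernels": {"kernels": [{"arguments": [
  {"address-qualifier": "GLOBAL", "name": "opcode", "offset": "0x00", "type": "uint64_t"},
  {"address-qualifier": "GLOBAL", "memory-connection": "SRAM", "name": "instr",
   "offset": "0x08", "type": "char *"},
  {"address-qualifier": "SCALAR", "name": "ninstr", "offset": "0x10", "type": "uint64_t"}
]}]}}
"""

def argument_table(kernel_json: str, kernel_index: int = 0) -> list:
    """Return (index, name, byte offset, type) for each argument, in declaration order."""
    doc = json.loads(kernel_json)
    args = doc["ps-kernels"]["kernels"][kernel_index]["arguments"]
    return [(i, a["name"], int(a["offset"], 16), a["type"]) for i, a in enumerate(args)]

for row in argument_table(KERNEL_JSON):
    print(row)
```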
Hi,
I think it is worth documenting that the -DIREE_CMAKE_PLUGIN_PATHS= flag should point to a path relative to the IREE source directory. I was trying to build in a different location, and it took me a while to realize that the path is not relative to the build directory.
This is because build_tools/cmake/iree_plugin_register.cmake:116 in the IREE repo resolves it against ${IREE_SOURCE_DIR}:
cmake_path(ABSOLUTE_PATH _d BASE_DIRECTORY "${IREE_SOURCE_DIR}" NORMALIZE)
Really low priority, and feel free to close right away.
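The resolution rule can be approximated in a few lines of Python. The sketch below mimics the lexical behaviour of cmake_path(ABSOLUTE_PATH ... BASE_DIRECTORY ... NORMALIZE). A relative path is joined onto the base directory (here the IREE source dir) and normalized, while an absolute path is kept as-is. The directory layout in the example is hypothetical. It shows that the build directory plays no role in the resolution.

```python
import posixpath

def resolve_plugin_path(path: str, iree_source_dir: str) -> str:
    """Approximate cmake_path(ABSOLUTE_PATH _d BASE_DIRECTORY <src> NORMALIZE):
    relative paths are resolved against the IREE *source* dir, never the build dir."""
    if posixpath.isabs(path):
        return posixpath.normpath(path)
    return posixpath.normpath(posixpath.join(iree_source_dir, path))

# Hypothetical layout: IREE at /src/iree, plugin at /src/iree-amd-aie,
# build directory somewhere else entirely (e.g. /build/iree).
print(resolve_plugin_path("../iree-amd-aie", "/src/iree"))    # /src/iree-amd-aie
print(resolve_plugin_path("/src/iree-amd-aie", "/src/iree"))  # /src/iree-amd-aie
```

Passing an absolute path to -DIREE_CMAKE_PLUGIN_PATHS= sidesteps the question of which directory the path is relative to.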
Please refer to this log, where CI captures the issue:
https://gist.github.com/nirvedhmeshram/218ce409d5e26d390d053cfc74419fa4
To reproduce, you can run these CI steps locally:
https://github.com/nod-ai/iree-amd-aie/blob/main/.github/workflows/ci.yml#L102-L110
To trigger the issue, you will need the change from #176.
After the recent IREE bump, running the transform dialect script example https://github.com/nod-ai/iree-amd-aie/blob/main/tests/transform_dialect/matmul_fill_spec_pad_pack.mlir gives the following error:
<stdin>:3:5: error: could not find a nested named sequence with name: __kernel_config
hal.executable.variant public @amdaie_xclbin_fb target(<"amd-aie", "amdaie-xclbin-fb", {target_arch = "chip-tbd", ukernels = "none"}>) {
The example can work with iree-opt --iree-transform-dialect-interpreter after deleting some ops in the transform dialect script. The example is currently not in CI, as it doesn't have any checks.
The goal of this RFC is to discuss how peeled matmul IR can be lowered to AIE code. cc @MaheshRavishankar @erwei-xilinx @nirvedhmeshram @yzhang93 @Abhishek-Varma
A few notes before we start:
To start, here is some sample peeled matmul IR:
...
scf.forall (%arg4, %arg5) in (1, 2) {
%5 = affine.apply #map(%arg4)
%6 = affine.apply #map(%arg5)
%subview_7 = memref.subview %alloc_4[%5, 0] [32, 64] [1, 1] : memref<32x64xi32, 1> to memref<32x64xi32, strided<[64, 1], offset: ?>, 1>
%subview_8 = memref.subview %alloc_3[0, %6] [64, 32] [1, 1] : memref<64x64xi32, 1> to memref<64x32xi32, strided<[64, 1], offset: ?>, 1>
%subview_9 = memref.subview %alloc_2[%5, %6] [32, 32] [1, 1] : memref<32x64xi32, 1> to memref<32x32xi32, strided<[64, 1], offset: ?>, 1>
linalg.fill ins(%c0_i32 : i32) outs(%alloc_1 : memref<4x8x4x8xi32, 2>)
iree_linalg_ext.pack %subview_7 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %alloc_0 : (memref<32x64xi32, strided<[64, 1], offset: ?>, 1> memref<8x8x4x8xi32, 2>)
iree_linalg_ext.pack %subview_8 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %alloc : (memref<64x32xi32, strided<[64, 1], offset: ?>, 1> memref<4x8x8x8xi32, 2>)
linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%alloc_0, %alloc : memref<8x8x4x8xi32, 2>, memref<4x8x8x8xi32, 2>) outs(%alloc_1 : memref<4x8x4x8xi32, 2>) {
^bb0(%in: i32, %in_10: i32, %out: i32):
%7 = arith.muli %in, %in_10 : i32
%8 = arith.addi %out, %7 : i32
linalg.yield %8 : i32
}
iree_linalg_ext.unpack %alloc_1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %subview_9 : (memref<4x8x4x8xi32, 2> memref<32x32xi32, strided<[64, 1], offset: ?>, 1>)
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
linalg.copy ins(%alloc_2 : memref<32x64xi32, 1>) outs(%subview : memref<32x64xi32, strided<[64, 1], offset: ?>>)
scf.for %arg4 = %c64 to %c1024 step %c64 {
%subview_7 = memref.subview %1[0, %arg4] [32, 64] [1, 1] : memref<32x1024xi32, strided<[?, ?], offset: ?>> to memref<32x64xi32, strided<[?, ?], offset: ?>>
%subview_8 = memref.subview %0[%arg4, 0] [64, 64] [1, 1] : memref<1024x64xi32, strided<[?, ?], offset: ?>> to memref<64x64xi32, strided<[?, ?], offset: ?>>
linalg.copy ins(%subview_7 : memref<32x64xi32, strided<[?, ?], offset: ?>>) outs(%alloc_4 : memref<32x64xi32, 1>)
linalg.copy ins(%subview_8 : memref<64x64xi32, strided<[?, ?], offset: ?>>) outs(%alloc_3 : memref<64x64xi32, 1>)
linalg.copy ins(%subview : memref<32x64xi32, strided<[64, 1], offset: ?>>) outs(%alloc_2 : memref<32x64xi32, 1>)
scf.forall (%arg5, %arg6) in (1, 2) {
%5 = affine.apply #map(%arg5)
%6 = affine.apply #map(%arg6)
%subview_9 = memref.subview %alloc_4[%5, 0] [32, 64] [1, 1] : memref<32x64xi32, 1> to memref<32x64xi32, strided<[64, 1], offset: ?>, 1>
%subview_10 = memref.subview %alloc_3[0, %6] [64, 32] [1, 1] : memref<64x64xi32, 1> to memref<64x32xi32, strided<[64, 1], offset: ?>, 1>
%subview_11 = memref.subview %alloc_2[%5, %6] [32, 32] [1, 1] : memref<32x64xi32, 1> to memref<32x32xi32, strided<[64, 1], offset: ?>, 1>
iree_linalg_ext.pack %subview_11 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %alloc_1 : (memref<32x32xi32, strided<[64, 1], offset: ?>, 1> memref<4x8x4x8xi32, 2>)
iree_linalg_ext.pack %subview_9 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %alloc_0 : (memref<32x64xi32, strided<[64, 1], offset: ?>, 1> memref<8x8x4x8xi32, 2>)
iree_linalg_ext.pack %subview_10 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %alloc : (memref<64x32xi32, strided<[64, 1], offset: ?>, 1> memref<4x8x8x8xi32, 2>)
linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%alloc_0, %alloc : memref<8x8x4x8xi32, 2>, memref<4x8x8x8xi32, 2>) outs(%alloc_1 : memref<4x8x4x8xi32, 2>) {
^bb0(%in: i32, %in_12: i32, %out: i32):
%7 = arith.muli %in, %in_12 : i32
%8 = arith.addi %out, %7 : i32
linalg.yield %8 : i32
}
iree_linalg_ext.unpack %alloc_1 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [4, 8] into %subview_11 : (memref<4x8x4x8xi32, 2> memref<32x32xi32, strided<[64, 1], offset: ?>, 1>)
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
linalg.copy ins(%alloc_2 : memref<32x64xi32, 1>) outs(%subview : memref<32x64xi32, strided<[64, 1], offset: ?>>)
}
...
Issues with lowering peeled matmul IR to AIE code:
Now, on the AIE core side, peeled core code can be expressed as shown in the snippet below. The goal is to get to something like this on the AIE core side, starting from the higher-level peeled matmul IR above.
%core_0_2 = aie.core(%tile_0_2) {
%0 = aie.objectfifo.acquire @outC0(Produce, 1) : !aie.objectfifosubview<memref<32xf32>>
%1 = aie.objectfifo.subview.access %0[0] : !aie.objectfifosubview<memref<32xf32>> -> memref<32xf32>
%2 = aie.objectfifo.acquire @inA0(Consume, 1) : !aie.objectfifosubview<memref<32x32xbf16>>
%3 = aie.objectfifo.subview.access %2[0] : !aie.objectfifosubview<memref<32x32xbf16>> -> memref<32x32xbf16>
%4 = aie.objectfifo.acquire @inB(Consume, 1) : !aie.objectfifosubview<memref<32xbf16>>
%5 = aie.objectfifo.subview.access %4[0] : !aie.objectfifosubview<memref<32xbf16>> -> memref<32xbf16>
func.call @zero_vectorized_f32(%1) : (memref<32xf32>) -> ()
func.call @matvec_vectorized_bf16_f32(%3, %5, %1) : (memref<32x32xbf16>, memref<32xbf16>, memref<32xf32>) -> ()
aie.objectfifo.release @inA0(Consume, 1)
aie.objectfifo.release @inB(Consume, 1)
%c0_0 = arith.constant 0 : index
%c1_1 = arith.constant 1 : index
%c1_2 = arith.constant 1 : index
scf.for %arg1 = %c0_0 to %c1_1 step %c1_2 {
%6 = aie.objectfifo.acquire @inA0(Consume, 1) : !aie.objectfifosubview<memref<32x32xbf16>>
%7 = aie.objectfifo.subview.access %6[0] : !aie.objectfifosubview<memref<32x32xbf16>> -> memref<32x32xbf16>
%8 = aie.objectfifo.acquire @inB(Consume, 1) : !aie.objectfifosubview<memref<32xbf16>>
%9 = aie.objectfifo.subview.access %8[0] : !aie.objectfifosubview<memref<32xbf16>> -> memref<32xbf16>
func.call @matvec_vectorized_bf16_f32(%7, %9, %1) : (memref<32x32xbf16>, memref<32xbf16>, memref<32xf32>) -> ()
aie.objectfifo.release @inA0(Consume, 1)
aie.objectfifo.release @inB(Consume, 1)
}
aie.objectfifo.release @outC0(Produce, 1)
aie.end
} {link_with = "mv.o"}
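Both snippets follow the same loop-peeling pattern. In the IR, K = 1024 is tiled by 64. The first K-tile is peeled, and it is the only one that initializes the local accumulator (linalg.fill). The scf.for from %c64 to %c1024 step %c64 then only accumulates (15 iterations). The core snippet mirrors this: zero_vectorized_f32 and one matvec call come first, followed by a loop of matvec calls. The plain-Python sketch below (an illustration of the pattern, not AIE code) checks that peeling the initialization into the first iteration gives the same result as an ordinary matmul.

```python
def matmul_acc(a, b, c):
    """Accumulate c += a @ b in place (a: m x k, b: k x n, c: m x n)."""
    for i in range(len(a)):
        for j in range(len(c[0])):
            c[i][j] += sum(a[i][p] * b[p][j] for p in range(len(b)))

def k_slice(a, b, k0, k1):
    """Columns k0..k1 of a and rows k0..k1 of b (one K-tile)."""
    return [row[k0:k1] for row in a], b[k0:k1]

def matmul_reference(a, b):
    c = [[0] * len(b[0]) for _ in a]
    matmul_acc(a, b, c)
    return c

def matmul_peeled(a, b, tile_k):
    m, n, k = len(a), len(b[0]), len(b)
    # Peeled first iteration: zero-init (cf. linalg.fill / zero_vectorized_f32)
    # followed by the first K-tile.
    c = [[0] * n for _ in range(m)]
    matmul_acc(*k_slice(a, b, 0, tile_k), c)
    # Steady state: remaining K-tiles only accumulate (cf. scf.for from %c64).
    for k0 in range(tile_k, k, tile_k):
        matmul_acc(*k_slice(a, b, k0, k0 + tile_k), c)
    return c

A = [[(3 * i + p) % 5 - 2 for p in range(8)] for i in range(3)]
B = [[(p * j + 1) % 4 - 1 for j in range(4)] for p in range(8)]
print(matmul_peeled(A, B, tile_k=2) == matmul_reference(A, B))  # True
```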
For now, only option 1 has been worked out for discussion:
This option can be accomplished through the following steps:
(AieCombineCoreCode)
The transformations above warrant a demonstration of the conceptual lowering, which you can find here: https://gist.github.com/jtuyls/7e6a41619666fa3186b1a8156978eedc
Not worked out for now.
GEMM size (M x N x K x datatype) | HW launch size | M&N scaling required | K scaling required | "Pad" flow support | "Pack" flow support |
---|---|---|---|---|---|
(2048 x 2048 x 512 x i32) | (64 x 64 x 512) | Yes | No | running on AIE | running on AIE |
(2048 x 2048 x 2048 x i32) | (64 x 64 x 512) | Yes | Yes | running on AIE | running on AIE |
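The two scaling columns follow from comparing each GEMM size with the fixed HW launch size. The sketch below assumes that "scaling" means covering the problem with multiple launches of that size. It counts the launches needed along M/N and along K. For both rows, it reproduces the Yes/No entries in the table (1024 M/N tiles each, and 1 vs. 4 K tiles).

```python
from math import ceil

def launch_plan(gemm, launch):
    """gemm, launch: (M, N, K). Count launches of the fixed size along M/N and K."""
    (m, n, k), (lm, ln, lk) = gemm, launch
    mn_tiles = ceil(m / lm) * ceil(n / ln)
    k_tiles = ceil(k / lk)
    return {
        "mn_tiles": mn_tiles,
        "k_tiles": k_tiles,
        "mn_scaling_required": mn_tiles > 1,
        "k_scaling_required": k_tiles > 1,
    }

for gemm in [(2048, 2048, 512), (2048, 2048, 2048)]:
    print(gemm, launch_plan(gemm, (64, 64, 512)))
```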
Currently, MLIR-AIR's development is focused on delivering discrete compilation pipelines, each covering a corresponding set of new GEMM scenarios. This means that the first working compilation flows for, e.g., big Ns vs. small Ns may require different MLIR-AIR passes/parameters. Once coverage is provided, we will move toward a unified C++ pipeline that generalizes to all of them.
Currently, iree_linalg_ext.pack operations, which map well to AIE DMA operations, are decomposed into air.dma_memcpy_nd, which does not seem to model them as well because it has no concept of a wrap or an inner tile. We should add a new op to the AIR dialect so that the lowering is one-to-one: iree_linalg_ext.pack -> air.dma_op -> aie.dma_op.
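To see why pack maps so naturally onto an n-D DMA, consider the pack with inner tiles in the peeled IR above. Its source is a 32x64 subview with row stride 64, packed with outer_dims_perm = [1, 0] and inner_tiles = [4, 8] into 8x8x4x8. That pack is exactly a 4-D strided read of the source. The sketch below (plain Python, helper names are illustrative) computes the sizes/strides of that read and checks that it yields the packed layout contiguously.

```python
def pack_reference(src, tile_r, tile_c):
    """iree_linalg_ext.pack with outer_dims_perm=[1, 0], inner_dims_pos=[0, 1],
    inner_tiles=[tile_r, tile_c]: dest[j][i][a][b] = src[i*tile_r + a][j*tile_c + b].
    Returns dest flattened in row-major order."""
    rows, cols = len(src), len(src[0])
    return [src[i * tile_r + a][j * tile_c + b]
            for j in range(cols // tile_c)
            for i in range(rows // tile_r)
            for a in range(tile_r)
            for b in range(tile_c)]

def pack_as_dma(rows, cols, tile_r, tile_c):
    """Sizes/strides (in elements, outermost first) of a 4-D strided read of a
    row-major rows x cols source that produces the packed layout contiguously."""
    sizes = [cols // tile_c, rows // tile_r, tile_r, tile_c]
    strides = [tile_c, tile_r * cols, cols, 1]
    return sizes, strides

def strided_read(flat, sizes, strides, offset=0):
    """Read `flat` with an n-D access pattern, like an n-D DMA descriptor."""
    out = []

    def walk(dim, base):
        if dim == len(sizes):
            out.append(flat[base])
            return
        for idx in range(sizes[dim]):
            walk(dim + 1, base + idx * strides[dim])

    walk(0, offset)
    return out

rows, cols = 32, 64  # %subview_7 : memref<32x64xi32, strided<[64, 1]>>
src = [[r * cols + c for c in range(cols)] for r in range(rows)]
flat = [x for row in src for x in row]
sizes, strides = pack_as_dma(rows, cols, 4, 8)
print(sizes, strides)  # [8, 8, 4, 8] [8, 256, 64, 1]
print(strided_read(flat, sizes, strides) == pack_reference(src, 4, 8))  # True
```

The same helper gives sizes [4, 8, 8, 8] for the 64x32 operand packed with inner_tiles = [8, 8], matching memref<4x8x8x8xi32, 2>.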
While modifying the iree-amd-aie pipeline, which currently looks like:
passManager.addPass(createAMDAIEPackToDmaPass());
...
passManager.addPass(createAMDAIECanonicalizeDmaPass());
passManager.addPass(xilinx::air::createCopyToDmaPass());
I thought it would make sense to move the pass that canonicalizes DMAs to after the copy-to-DMA lowering, so that the DMAs generated by that pass are canonicalized too.
This results in an assertion error later in the pipeline:
iree-compile: iree-amd-aie/third_party/mlir-air/mlir/lib/Conversion/AIRRtToIpuPass.cpp:525: void (anonymous namespace)::tileIllegalWrapDim(airrt::DmaMemcpyNdOp): Assertion `!(const_wrap % (AIE2_WRAP_UPPER_BOUND / 2)) && "Currently do not support remainder tiles"' failed.
See https://github.com/newling/iree-amd-aie/tree/reproducer_odd_dims_assertion
I think it would be good if any canonicalization pass could be run anywhere in the pipeline.
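The assertion message states the constraint: when a wrap has to be tiled, it must be a multiple of AIE2_WRAP_UPPER_BOUND / 2, and remainder tiles are not supported. The sketch below only restates that divisibility condition. It is a hypothetical helper, not the mlir-air implementation. The bound is passed in as a parameter, and the value 1024 in the example is an assumption. The reproducer branch name suggests that odd dimensions are involved. If canonicalizing after copy-to-DMA produces such a wrap, it would fail this check.

```python
def wrap_is_supported(const_wrap: int, wrap_upper_bound: int) -> bool:
    """True if a wrap that needs tiling splits evenly into chunks of bound / 2,
    i.e. the condition asserted in tileIllegalWrapDim's message holds."""
    return const_wrap % (wrap_upper_bound // 2) == 0

def split_illegal_wrap(const_wrap: int, wrap_upper_bound: int) -> tuple:
    """Hypothetical split of an oversized wrap into (outer, inner) wraps."""
    inner = wrap_upper_bound // 2
    if not wrap_is_supported(const_wrap, wrap_upper_bound):
        raise ValueError("Currently do not support remainder tiles")
    return const_wrap // inner, inner

BOUND = 1024  # assumed value, for illustration only
print(split_illegal_wrap(2048, BOUND))  # (4, 512)
print(wrap_is_supported(1100, BOUND))   # False
```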
We have been noticing an intermittent numerical error with the pack-peel pipeline and the bf16 data type. To reproduce the problem, run this test locally multiple times; the results and the error from a failing run are below:
error: the actual and expected result matrices disagree at row 48, column 16.
actual value: 0
expected value: -54
left-hand side (rows 40..55 out of 0..63, columns 0..15 out of 0..127)
0 1 2 -2 1 2 1 2 2 -2 0 -1 2 -1 2 -1
-2 1 1 -1 1 1 -1 1 -2 -2 -2 2 0 2 -1 1
0 -1 1 2 0 2 1 1 -1 -2 1 0 2 2 -1 2
2 -1 1 2 0 2 0 -2 1 2 2 2 1 -2 -2 2
-1 -1 -1 1 0 2 0 1 -2 -1 2 2 -2 1 1 1
-1 0 1 2 -2 2 -2 2 1 -2 -1 0 1 -2 0 1
2 -1 -1 1 0 2 0 1 1 0 -2 -1 2 1 -1 -2
1 2 0 -1 -1 -2 0 2 0 2 2 1 -2 0 -1 2
-2 2 -2 0 0 1 1 -2 1 0 -2 -1 1 -1 -1 2
2 -2 -1 2 1 -2 1 -2 -1 1 1 2 -2 1 2 -1
2 0 -2 0 1 2 -1 0 1 -1 1 1 -2 1 0 -1
-1 0 1 -2 -2 -2 0 -2 0 -1 1 2 -1 -2 -1 2
0 -2 -1 -2 2 2 -2 -1 -2 2 1 -1 -2 0 1 2
2 2 2 -1 1 2 1 -2 -2 -1 -1 -1 1 -2 0 0
-2 2 0 -2 2 -1 -1 -1 2 2 2 -1 0 -2 1 1
1 -2 -2 1 0 -1 -1 0 0 2 0 -2 -2 2 2 0
right-hand side (rows 0..15 out of 0..127, columns 8..23 out of 0..63)
-1 2 -2 1 0 2 -2 -1 0 2 1 0 0 -2 -1 -1
-1 0 2 0 1 2 2 1 -2 -2 -2 2 -1 -2 1 0
1 -2 -2 -2 1 0 2 2 1 -1 -1 2 2 -2 -1 1
2 0 0 -1 -1 -2 0 1 -2 1 2 -1 -1 0 0 -2
-1 0 2 -2 0 0 -1 0 -1 1 -2 1 -1 2 0 -1
-2 -2 0 0 -1 2 -2 -1 -2 -2 -1 2 -2 0 -2 1
0 -1 -2 -2 0 2 0 -2 -2 -2 -2 -1 2 1 2 1
2 2 1 -2 0 -1 2 0 2 -1 -1 -1 2 2 -1 -2
-1 1 -2 0 -1 1 0 -1 -2 1 0 1 -1 -2 1 -1
1 -1 0 0 -1 2 -2 1 -1 1 -2 -2 -1 0 2 -2
-2 -2 2 1 -1 0 2 1 2 -1 1 -1 1 1 1 1
0 0 -2 1 -2 -2 1 -1 -2 0 0 1 1 2 2 -2
-1 -2 2 -1 2 -2 1 0 0 -2 -2 0 0 0 0 -2
1 1 -1 0 2 -2 0 -1 -2 0 0 -2 1 -2 0 1
2 -1 0 0 -1 -2 -2 1 2 1 1 1 -1 -2 0 2
1 2 -1 2 1 2 -2 2 -2 2 2 2 1 1 2 2
expected result (rows 40..55 out of 0..63, columns 8..23 out of 0..63)
-29 -32 -35 -41 -12 16 -8 -32 35 -29 0 18 -23 -37 -23 19
13 9 8 -9 12 -26 -8 -13 8 -3 -4 6 48 26 -29 37
-35 -30 -2 -52 25 5 10 -39 25 3 9 -31 24 -26 -19 28
0 -10 -27 22 11 21 8 15 -19 9 -10 15 -10 14 43 -7
31 12 -3 -1 -25 -33 -19 -2 34 24 32 5 -11 -6 -21 49
53 -36 -2 26 -37 16 -41 14 -4 11 -7 3 2 -11 -23 4
-22 4 -17 -33 -1 -4 55 -28 -14 -31 -39 -40 31 12 -20 -37
15 -8 8 34 -14 40 25 23 -29 5 -10 2 -29 24 -12 -2
-33 -17 -33 38 28 -6 21 8 -54🦄 -18🦄 -46🦄 -10🦄 -28🦄 -35🦄 34🦄 8🦄
4🦄 -10🦄 13🦄 21🦄 -5🦄 -17🦄 27🦄 1🦄 -34🦄 -6🦄 -12🦄 -11🦄 -31🦄 34🦄 35🦄 5🦄
23🦄 7🦄 -16🦄 14🦄 9🦄 -25🦄 -26🦄 -32🦄 10 0 27 0 -36 -5 -28 8
30 -25 -7 66 -5 -2 21 30 19 -21 31 19 27 -1 5 28
0 -41 9 26 -8 32 -55 -9 8 35 5 -17 -1 -10 34 11
-7 -22 32 -2 -8 26 -12 -13 -21 -8 35 43 14 10 -6 -20
-15🦄 2🦄 5🦄 2🦄 -2🦄 25🦄 -14🦄 -43🦄 9 12 -18 14 -15 -33 5 30
39🦄 12🦄 25🦄 -4🦄 6🦄 13🦄 -1🦄 26🦄 -4🦄 13🦄 37🦄 -54🦄 22🦄 14🦄 19🦄 25🦄
actual result (rows 40..55 out of 0..63, columns 8..23 out of 0..63)
-29 -32 -35 -41 -12 16 -8 -32 35 -29 0 18 -23 -37 -23 19
13 9 8 -9 12 -26 -8 -13 8 -3 -4 6 48 26 -29 37
-35 -30 -2 -52 25 5 10 -39 25 3 9 -31 24 -26 -19 28
0 -10 -27 22 11 21 8 15 -19 9 -10 15 -10 14 43 -7
31 12 -3 -1 -25 -33 -19 -2 34 24 32 5 -11 -6 -21 49
53 -36 -2 26 -37 16 -41 14 -4 11 -7 3 2 -11 -23 4
-22 4 -17 -33 -1 -4 55 -28 -14 -31 -39 -40 31 12 -20 -37
15 -8 8 34 -14 40 25 23 -29 5 -10 2 -29 24 -12 -2
-33 -17 -33 38 28 -6 21 8 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞
0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞
0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 10 0 27 0 -36 -5 -28 8
30 -25 -7 66 -5 -2 21 30 19 -21 31 19 27 -1 5 28
0 -41 9 26 -8 32 -55 -9 8 35 5 -17 -1 -10 34 11
-7 -22 32 -2 -8 26 -12 -13 -21 -8 35 43 14 10 -6 -20
0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 9 12 -18 14 -15 -33 5 30
0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞 0🐞
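Within the window shown, the zeroed entries come in runs that start and end on 8-column boundaries: row 48 columns 16..23, row 49 columns 8..23, row 50 columns 8..15, and so on. A small helper like the hypothetical one below turns the expected/actual windows into (row, start column, length) runs in absolute coordinates. That makes such patterns easier to compare across failing runs. It uses rows 48 and 49 from the tables above. Note that an expected value that is legitimately 0 would split a run.

```python
def mismatch_runs(expected, actual, row_offset=0, col_offset=0):
    """Return (row, start_col, length) for each run of consecutive mismatching
    elements within a row, in absolute (offset-adjusted) coordinates."""
    runs = []
    for r, (exp_row, act_row) in enumerate(zip(expected, actual)):
        start = None
        for c, (e, a) in enumerate(zip(exp_row, act_row)):
            if e != a and start is None:
                start = c
            elif e == a and start is not None:
                runs.append((row_offset + r, col_offset + start, c - start))
                start = None
        if start is not None:  # run reaches the right edge of the window
            runs.append((row_offset + r, col_offset + start, len(exp_row) - start))
    return runs

# Rows 48 and 49 of the windows above (columns 8..23).
expected = [[-33, -17, -33, 38, 28, -6, 21, 8, -54, -18, -46, -10, -28, -35, 34, 8],
            [4, -10, 13, 21, -5, -17, 27, 1, -34, -6, -12, -11, -31, 34, 35, 5]]
actual = [[-33, -17, -33, 38, 28, -6, 21, 8, 0, 0, 0, 0, 0, 0, 0, 0],
          [0] * 16]
print(mismatch_runs(expected, actual, row_offset=48, col_offset=8))
# [(48, 16, 8), (49, 8, 16)]
```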
Note: this issue can be closed when the corresponding test passes.
Hi, I encountered a RuntimeError while running the sync_deps.py script:
Errors during submodule fetch:
third_party/llvm-project
Traceback (most recent call last):
File "/home/vimal/Edge_ai/iree-amd-aie/sync_deps.py", line 123, in <module>
main()
File "/home/vimal/Edge_ai/iree-amd-aie/sync_deps.py", line 75, in main
run(fetch_args, repo_dir)
File "/home/vimal/Edge_ai/iree-amd-aie/sync_deps.py", line 117, in run
raise RuntimeError(f"Git command failed: {args_text} (from {cwd})"
RuntimeError: Git command failed: git fetch origin 4691fc5bf37d6ba8aeac6a1c5382a595fa900f4c (from /home/vimal/Edge_ai/iree)
Here's a test kernel using bf16:
func.func @matmul_static(%lhs : tensor<8x16xbf16>,
%rhs : tensor<16x8xbf16>) -> tensor<8x8xbf16> {
%empty = tensor.empty() : tensor<8x8xbf16>
%cst = arith.constant 0.0 : bf16
%fill = linalg.fill ins(%cst : bf16) outs(%empty : tensor<8x8xbf16>) -> tensor<8x8xbf16>
%2 = linalg.matmul ins(%lhs, %rhs : tensor<8x16xbf16>, tensor<16x8xbf16>)
outs(%fill : tensor<8x8xbf16>) -> tensor<8x8xbf16>
return %2 : tensor<8x8xbf16>
}
The error you get during compilation right now is in the validator for aiex.ipu.dma_memcpy_nd:
/home/eric/src/test1/configured_module_matmul_static_dispatch_0.mlir:19:8: error: 'aiex.ipu.dma_memcpy_nd' op must be used with memref type i32.
%7 = linalg.matmul ins(%3, %4 : tensor<8x16xbf16>, tensor<16x8xbf16>) outs(%6 : tensor<8x8xbf16>) -> tensor<8x8xbf16>
^
/home/eric/src/test1/configured_module_matmul_static_dispatch_0.mlir:19:8: note: see current operation: "aiex.ipu.dma_memcpy_nd"(%arg0) <{id = 1 : i64, metadata = @airMemcpyId4, operandSegmentSizes = array<i32: 1, 0, 0, 0>, static_offsets = array<i64: 0, 0, 0, 0>, static_sizes = array<i64: 1, 1, 8, 16>, static_strides = array<i64: 0, 0, 16>, x = 0 : i64, y = 0 : i64}> : (memref<8x16xbf16>) -> ()
If I bypass the validator, I'd expect maybe a type error in the MIR or an out-of-bounds runtime error, but instead the assembler gives this mysterious error:
LLVM ERROR: unable to legalize instruction: %422:_(<2 x s64>) = G_ADD %286:_, %82:_ (in function: core_0_2)
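Since the validator insists on an i32 memref, one possible workaround is to describe the bf16 transfer in 32-bit words. This is an assumption for illustration, not necessarily what the project adopted. With two bf16 elements per word, the innermost size and the strides halve. The sketch below does that conversion for the op shown above (static_sizes [1, 1, 8, 16], static_strides [0, 0, 16]). It treats the innermost dimension as contiguous and all offsets as zero.

```python
def to_word_granularity(sizes, strides, elem_bytes, word_bytes=4):
    """Re-express an n-D transfer of `elem_bytes`-sized elements in `word_bytes` words.

    `sizes` lists every dimension (outermost first); `strides` lists the strides of
    all but the innermost dimension, which is assumed contiguous (4 sizes / 3 strides,
    as in the op above). Offsets are assumed to be zero.
    """
    if word_bytes % elem_bytes:
        raise ValueError("element size must divide the word size")
    ratio = word_bytes // elem_bytes  # 2 for bf16 -> 32-bit words
    if sizes[-1] % ratio or any(s % ratio for s in strides):
        raise ValueError("innermost size and strides must be multiples of the ratio")
    return sizes[:-1] + [sizes[-1] // ratio], [s // ratio for s in strides]

print(to_word_granularity([1, 1, 8, 16], [0, 0, 16], elem_bytes=2))
# ([1, 1, 8, 8], [0, 0, 8])
```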
The CI should also test the pack pipeline(s), more matmul sizes, more input ops, with and without accumulation, etc.
Now that we move data tiling inside the scf.for loop so that large inputs fit in the memTile, we have an issue: linalg.fill is created in global memory. As a result, the initialized output needs to be copied from global to shared memory and then from shared to local memory. We'll need to find a way to initialize the data directly in local memory.
The current IR can be found below:
#executable_target_elf = #hal.executable.target<"amd-aie", "elf", {target_arch = "chip-tbd"}>
#map = affine_map<(d0) -> (d0 * 16)>
#map1 = affine_map<(d0) -> (d0 * 64)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [<0, bindings = [<0, storage_buffer, ReadOnly>, <1, storage_buffer, ReadOnly>, <2, storage_buffer>]>]>
#translation = #iree_codegen.translation_info<TransformDialectCodegen codegen_spec = @__transform_main>
#device_target_amd_aie = #hal.device.target<"amd-aie", {executable_targets = [#executable_target_elf], legacy_sync}>
module attributes {hal.device.targets = [#device_target_amd_aie]} {
hal.executable private @matmul_example_dispatch_0 {
hal.executable.variant public @elf target(#executable_target_elf) {
hal.executable.export public @matmul_example_dispatch_0_matmul_4x2048x2048_i8xi8xi32 ordinal(0) layout(#pipeline_layout) attributes {translation_info = #translation} {
^bb0(%arg0: !hal.device):
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
hal.return %c32, %c1, %c1 : index, index, index
}
builtin.module {
func.func @matmul_example_dispatch_0_matmul_4x2048x2048_i8xi8xi32() {
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c64 = arith.constant 64 : index
%c2048 = arith.constant 2048 : index
%c0_i8 = arith.constant 0 : i8
%alloc = memref.alloc() : memref<1x1x8x4x4x8xi32, "local">
%alloc_0 = memref.alloc() : memref<1x1x8x8x8x8xi8, "local">
%alloc_1 = memref.alloc() : memref<1x1x8x4x4x8xi8, "local">
%alloc_2 = memref.alloc() : memref<1x1x16x64xi32, "shared">
%alloc_3 = memref.alloc() : memref<1x1x64x64xi8, "shared">
%alloc_4 = memref.alloc() : memref<1x1x16x64xi8, "shared">
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<4x2048xi8, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %0, 64 : memref<4x2048xi8, #hal.descriptor_type<storage_buffer>>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<2048x2048xi8, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %1, 64 : memref<2048x2048xi8, #hal.descriptor_type<storage_buffer>>
%2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) alignment(64) offset(%c0) : memref<4x2048xi32, #hal.descriptor_type<storage_buffer>>
memref.assume_alignment %2, 64 : memref<4x2048xi32, #hal.descriptor_type<storage_buffer>>
scf.forall (%arg0, %arg1) in (1, 32) {
%3 = affine.apply #map(%arg0)
%4 = affine.apply #map1(%arg1)
%subview = memref.subview %2[%3, %4] [4, 64] [1, 1] : memref<4x2048xi32, #hal.descriptor_type<storage_buffer>> to memref<4x64xi32, strided<[2048, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
linalg.fill ins(%c0_i32 : i32) outs(%subview : memref<4x64xi32, strided<[2048, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
scf.for %arg2 = %c0 to %c2048 step %c64 {
%5 = affine.apply #map(%arg0)
%subview_5 = memref.subview %0[%5, %arg2] [4, 64] [1, 1] : memref<4x2048xi8, #hal.descriptor_type<storage_buffer>> to memref<4x64xi8, strided<[2048, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
%6 = affine.apply #map1(%arg1)
%subview_6 = memref.subview %1[%arg2, %6] [64, 64] [1, 1] : memref<2048x2048xi8, #hal.descriptor_type<storage_buffer>> to memref<64x64xi8, strided<[2048, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>
iree_linalg_ext.pack %subview_5 padding_value(%c0_i8 : i8) inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %alloc_4 : (memref<4x64xi8, strided<[2048, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x16x64xi8, "shared">)
iree_linalg_ext.pack %subview_6 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [64, 64] into %alloc_3 : (memref<64x64xi8, strided<[2048, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x64x64xi8, "shared">)
iree_linalg_ext.pack %subview padding_value(%c0_i32 : i32) inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %alloc_2 : (memref<4x64xi32, strided<[2048, 1], offset: ?>, #hal.descriptor_type<storage_buffer>> memref<1x1x16x64xi32, "shared">)
scf.forall (%arg3, %arg4) in (1, 1) {
%subview_7 = memref.subview %alloc_4[%arg3, 0, 0, 0] [1, 1, 16, 64] [1, 1, 1, 1] : memref<1x1x16x64xi8, "shared"> to memref<1x1x16x64xi8, strided<[1024, 1024, 64, 1], offset: ?>, "shared">
%subview_8 = memref.subview %alloc_3[0, %arg4, 0, 0] [1, 1, 64, 64] [1, 1, 1, 1] : memref<1x1x64x64xi8, "shared"> to memref<1x1x64x64xi8, strided<[4096, 4096, 64, 1], offset: ?>, "shared">
%subview_9 = memref.subview %alloc_2[%arg3, %arg4, 0, 0] [1, 1, 16, 64] [1, 1, 1, 1] : memref<1x1x16x64xi32, "shared"> to memref<1x1x16x64xi32, strided<[1024, 1024, 64, 1], offset: ?>, "shared">
iree_linalg_ext.pack %subview_7 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc_1 : (memref<1x1x16x64xi8, strided<[1024, 1024, 64, 1], offset: ?>, "shared"> memref<1x1x8x4x4x8xi8, "local">)
iree_linalg_ext.pack %subview_8 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [8, 8] into %alloc_0 : (memref<1x1x64x64xi8, strided<[4096, 4096, 64, 1], offset: ?>, "shared"> memref<1x1x8x8x8x8xi8, "local">)
iree_linalg_ext.pack %subview_9 outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %alloc : (memref<1x1x16x64xi32, strided<[1024, 1024, 64, 1], offset: ?>, "shared"> memref<1x1x8x4x4x8xi32, "local">)
linalg.generic {indexing_maps = [#map2, #map3, #map4], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%alloc_1, %alloc_0 : memref<1x1x8x4x4x8xi8, "local">, memref<1x1x8x8x8x8xi8, "local">) outs(%alloc : memref<1x1x8x4x4x8xi32, "local">) {
^bb0(%in: i8, %in_10: i8, %out: i32):
%7 = arith.extsi %in : i8 to i32
%8 = arith.extsi %in_10 : i8 to i32
%9 = arith.muli %7, %8 : i32
%10 = arith.addi %out, %9 : i32
linalg.yield %10 : i32
}
iree_linalg_ext.unpack %alloc outer_dims_perm = [0, 1, 3, 2] inner_dims_pos = [2, 3] inner_tiles = [4, 8] into %subview_9 : (memref<1x1x8x4x4x8xi32, "local"> memref<1x1x16x64xi32, strided<[1024, 1024, 64, 1], offset: ?>, "shared">)
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
iree_linalg_ext.unpack %alloc_2 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %subview : (memref<1x1x16x64xi32, "shared"> memref<4x64xi32, strided<[2048, 1], offset: ?>, #hal.descriptor_type<storage_buffer>>)
}
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
memref.dealloc %alloc_4 : memref<1x1x16x64xi8, "shared">
memref.dealloc %alloc_3 : memref<1x1x64x64xi8, "shared">
memref.dealloc %alloc_2 : memref<1x1x16x64xi32, "shared">
memref.dealloc %alloc_1 : memref<1x1x8x4x4x8xi8, "local">
memref.dealloc %alloc_0 : memref<1x1x8x8x8x8xi8, "local">
memref.dealloc %alloc : memref<1x1x8x4x4x8xi32, "local">
return
}
}
}
}
}
Similar to the MapForallToBlocks and MapNestedForallToThreads transforms used for GPU, we should also consider mapping scf.forall to blocks and cores with block_id/core_id.
scf.forall (%iv0, %iv1) = (%lb0, %lb1) to (%ub0, %ub1) step (%step0, %step1) {
...
scf.forall (%iv2, %iv3) = (%lb2, %lb3) to (%ub2, %ub3) step (%step2, %step3) {
...
} {mapping = [#gpu.thread<y>, #gpu.thread<x>]}
...
} {mapping = [#gpu.block<y>, #gpu.block<x>]}
Let's assume that we use a logical block of size (%block_size_y, %block_size_x). The IR after the transformation should then look like:
air.herd (%core_id_y, %core_id_x) in (%num_cores_y, %num_cores_x) {
%num_blocks_y = ceilDiv(%num_cores_y, %block_size_y)
%num_blocks_x = ceilDiv(%num_cores_x, %block_size_x)
%block_id_y = %core_id_y / %block_size_y
%block_id_x = %core_id_x / %block_size_x
%local_id_y = %core_id_y mod %block_size_y
%local_id_x = %core_id_x mod %block_size_x
scf.for (%iv0, %iv1) = (%lb0 + %block_id_y * %step0, %lb1 + %block_id_x * %step1) to (%ub0, %ub1) step(%num_blocks_y * %step0, %num_blocks_x * %step1) {
...
scf.for (%iv2, %iv3) = (%lb2 + %local_id_y * %step2, %lb3 + %local_id_x * %step3) to (%ub2, %ub3) step(%block_size_y * %step2, %block_size_x * %step3) {
...
}
...
}
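To sanity-check the index arithmetic above, here is a small Python simulation of one dimension of the proposed mapping (the 2-D case applies the same formulas per dimension). It is a sketch, not an AIR or IREE API; `distribute` and `all_iterations` are hypothetical helper names. It follows the cyclic distribution written above and exposes one constraint the pseudocode leaves implicit: every iteration is executed exactly once only when the number of cores is a multiple of the block size.

```python
def ceil_div(a, b):
    """Integer ceiling division for positive b."""
    return -(-a // b)


def distribute(num_cores, block_size, outer, inner):
    """Simulate the block/core mapping in one dimension.

    outer and inner are (lb, ub, step) triples for the two scf.forall loops.
    Returns {core_id: [(iv_outer, iv_inner), ...]}, the iterations each core runs.
    """
    lb0, ub0, step0 = outer
    lb2, ub2, step2 = inner
    num_blocks = ceil_div(num_cores, block_size)
    work = {}
    for core_id in range(num_cores):
        block_id = core_id // block_size  # logical block this core belongs to
        local_id = core_id % block_size   # position of the core inside its block
        items = []
        # Outer loop: blocks take iterations cyclically, striding by num_blocks.
        for iv0 in range(lb0 + block_id * step0, ub0, num_blocks * step0):
            # Inner loop: cores within a block take iterations cyclically.
            for iv2 in range(lb2 + local_id * step2, ub2, block_size * step2):
                items.append((iv0, iv2))
        work[core_id] = items
    return work


def all_iterations(outer, inner):
    """Every (iv_outer, iv_inner) pair of the original nested loops."""
    return [(a, b) for a in range(*outer) for b in range(*inner)]
```

For example, with 4 cores, a block size of 2, outer = (0, 10, 1) and inner = (0, 6, 2), each of the 30 iteration pairs is executed by exactly one core. With 3 cores and a block size of 2, the second block contains only one core, so 5 of the 30 pairs are never executed.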
I'll beautify this once I get hold of Azure storage.
I have attached gemma_7b.mlir along with gemma weights.
For now, I've uploaded all GEMM dispatches here.
GEMMs in the Gemma model appear in two forms: linalg.batch_matmul and linalg.matmul_transpose_b.
Dispatch Type | Element type | Shapes | Running on AIE |
---|---|---|---|
linalg.batch_matmul | f32 | 16x1x256xD | No |
linalg.batch_matmul | f32 | 16x1xDx256 | No |
linalg.batch_matmul | f32 | 16xDx256xD | No |
linalg.batch_matmul | f32 | 16xDxDx256 | No |
linalg.batch_matmul | f32 | 1x128x1x1 | No |
linalg.batch_matmul | f32 | 1x128xDx1 | No |
linalg.matmul_transpose_b | f32 | 1x256000x3072 | No |
linalg.matmul_transpose_b | f32 | Dx256000x3072 | No |
NOTE: I first tried compiling the Gemma model for llvm-cpu and found only batch_mmt4d dispatches; I've added those here.
iree-compile gemma_7b.mlir --iree-input-type=torch \
--iree-hal-target-backends=llvm-cpu \
--iree-hal-dump-executable-sources-to=GEMMA_DISPATCHES \
-o test.vmfb
When I compiled the Gemma model for the amd-aie backend, I found the dispatches listed above.
iree-compile gemma_7b.mlir --iree-input-type=torch \
--iree-hal-target-backends=amd-aie \
--iree-hal-dump-executable-sources-to=GEMMA_DISPATCHES \
-o test.vmfb
@Abhishek-Varma suggested the following command:
iree-compile --mlir-elide-elementsattrs-if-larger=2 \
--iree-hal-target-backends=amd-aie \
--iree-amdaie-use-pipeline=pad-pack \
--iree-amdaie-path-to-ukernels=/mlir-aie/install/aie_kernels/ \
--aie2xclbin-print-ir-before-all \
--iree-amd-aie-enable-chess \
--iree-amdaie-enable-ukernels="all" INPUT.mlir \
--iree-amd-aie-peano-install-dir=/mlir-aie/install \
--iree-amd-aie-mlir-aie-install-dir=/mlir-aie/install \
--iree-amd-aie-vitis-install-dir=/proj/xbuilds/2023.2_released/installs/lin64/Vitis/2023.2 \
--iree-hal-dump-executable-files-to=$PWD \
--iree-amd-aie-show-invoked-commands -o pad_pack.vmfb \
--mlir-print-ir-before-all \
--mlir-disable-threading &> output_ukernel_pad_pack.txt
This compiles INPUT.mlir to use a matmul microkernel compiled with Chess.
It assumes that mlir-aie has been built and installed, with mm.o located at mlir-aie/install/aie_kernels/mm.o.
This almost works for me, except that I get the error "error: 'aie.core' op Failed to link with xbridge". To work around this error, I needed to copy mm.o to the directory I run iree-compile from and remove the flag
--iree-amdaie-path-to-ukernels=/mlir-aie/install/aie_kernels/ \
This fix was suggested by @erwei-xilinx, who has run a similar pipeline through AIR (without IREE) and succeeded when mm.o is in the working directory.
With this change, if INPUT.mlir is
!lhs = tensor<256x256xbf16>
!rhs = tensor<256x256xbf16>
!res = tensor<256x256xf32>
func.func @matmul_small(%lhs : !lhs, %rhs : !rhs) -> !res {
%empty = tensor.empty() : !res
%cst = arith.constant 0.0 : f32
%fill = linalg.fill ins(%cst : f32) outs(%empty : !res) -> !res
%2 = linalg.matmul ins(%lhs, %rhs : !lhs, !rhs)
outs(%fill : !res) -> !res
return %2 : !res
}
I can successfully run the vmfb and get the correct numbers:
iree-run-module --device=xrt --module=pad_pack.vmfb --input="256x256xbf16=1" --input="256x256xbf16=1"
gives
EXEC @matmul_small
result[0]: hal.buffer_view
256x256xf32=[256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256
If the input shape is changed to
!lhs = tensor<308x2432xbf16>
!rhs = tensor<2432x2432xbf16>
!res = tensor<308x2432xf32>
The output is no longer correct:
EXEC @matmul_small
result[0]: hal.buffer_view
308x2432xf32=[912 45600 912 45600 760 68400 760 316920 912 63840 912 212952 912 45600 912 45600 760 68400 760 316920 912 50464 912
For another priority shape, the result is correct. After changing the shapes and running
iree-run-module --device=xrt --module=pad_pack.vmfb --input="8192x2432xbf16=1" --input="2432x9728xbf16=1" --expected_output="8192x9728xf32=2432"
EXEC @matmul_small
[SUCCESS] all function outputs matched their expected values.
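Why 256 and 2432 are the expected values: with all-ones inputs, every output element is a sum of K products 1 x 1, so it equals K regardless of M and N. That gives 256 for the 256x256 case and 2432 for both the 308x2432x2432 and the 8192x2432x9728 runs (hence --expected_output="8192x9728xf32=2432"). A value of 1 is exact in bf16 and integer sums this small are exact in f32 accumulation, so outputs such as 912 or 45600 point to a real bug rather than rounding. A tiny Python reference of this check, using small M and N to stay cheap (the helper names are illustrative):

```python
def matmul(lhs, rhs):
    """Plain triple-loop matmul: (M x K) @ (K x N) -> (M x N)."""
    k = len(rhs)
    n = len(rhs[0])
    return [[sum(row[p] * rhs[p][j] for p in range(k)) for j in range(n)] for row in lhs]


def ones(rows, cols):
    """A rows x cols matrix filled with 1.0, like --input="RxCxbf16=1"."""
    return [[1.0] * cols for _ in range(rows)]


def expected_all_ones(k):
    """Every element of ones(M, K) @ ones(K, N) is a sum of K ones."""
    return float(k)
```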
Since xclbins need to be signed before they are allowed to run, we need a helper program that reads the vmfb flatbuffer, extracts the embedded xclbin files, and signs them automatically. Currently, the only way to run on real hardware is to dump the compilation artifacts to obtain the generated xclbin files.
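A minimal sketch of the extraction step only, assuming the xclbins are stored as raw bytes inside the vmfb and that, per XRT's `struct axlf` in xclbin.h, each image starts with the magic "xclbin2\0" and stores its total size as a little-endian uint64 (`m_header.m_length`) at byte offset 304. Both the offset and `find_xclbins` are assumptions for illustration, not an existing IREE or XRT tool, and the signing step is not shown.

```python
import struct
from typing import List

# Every xclbin image begins with this 8-byte magic string.
XCLBIN_MAGIC = b"xclbin2\x00"
# Assumption: offset of m_header.m_length in struct axlf:
# 8 (magic) + 4 (signature length) + 28 (reserved) + 256 (key block) + 8 (unique id).
LENGTH_OFFSET = 304


def find_xclbins(blob: bytes) -> List[bytes]:
    """Return every complete xclbin image embedded in blob (e.g. a .vmfb)."""
    found = []
    start = blob.find(XCLBIN_MAGIC)
    while start != -1:
        if start + LENGTH_OFFSET + 8 > len(blob):
            break  # too short to contain the length field
        (length,) = struct.unpack_from("<Q", blob, start + LENGTH_OFFSET)
        if length >= LENGTH_OFFSET + 8 and start + length <= len(blob):
            found.append(blob[start:start + length])
            start = blob.find(XCLBIN_MAGIC, start + length)
        else:
            # Magic bytes without a plausible length field: keep scanning.
            start = blob.find(XCLBIN_MAGIC, start + 1)
    return found
```

A helper would read the vmfb with open(path, "rb").read(), write each returned image to its own .xclbin file, and then pass those files to whatever signing tool the setup uses.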
Based on recent discussions, the pad-pack pipeline has allowed us to make progress, but the long-term solution seems to be to further build out the pack pipeline, which is currently not enabled end to end (e2e). I used the patch below to test where it stands today.
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp
@@ -283,13 +283,8 @@ void buildAMDAIETransformPassPipeline(OpPassManager &pm) {
pm.addPass(createAMDAIELowerExecutableTargetPass(options));
}
pm.addPass(createAMDAIELowerWorkgroupCountPass());
- if (clUsePipeline == AIEPassPipeline::PadPackPipeline) {
- auto &modulePassManager = pm.nest<ModuleOp>();
- addMLIRAIRAIELoweringPasses(modulePassManager);
- } else if (clUsePipeline != AIEPassPipeline::PackPipeline) {
- auto &modulePassManager = pm.nest<ModuleOp>();
- addMLIRAIRAIELegacyLoweringPasses(modulePassManager);
- }
+ auto &modulePassManager = pm.nest<ModuleOp>();
+ addMLIRAIRAIELoweringPasses(modulePassManager);
For a matmul with M, N, K all 64, it seems to crash in the air-to-aie lowering. Here is the IR dump, with the crash dump at the end. @erwei-xilinx, any thoughts on why it would crash there?
https://gist.github.com/nirvedhmeshram/25cf5426efea6cb048c06881ee801b60
Tracker task for the shapes we should support with direct codegen
2 compiler errors need resolving. The errors can be reproduced with square matmuls.
The first failure is in airrt-to-ipu.
The error message is:
// iree-compile: iree-amd-aie/third_party/mlir-air/mlir/lib/Conversion/AIRRtToIpuPass.cpp:551: void
(anonymous namespace)::tileIllegalWrapDim(airrt::DmaMemcpyNdOp): Assertion
`!(const_wrap % (AIE2_WRAP_UPPER_BOUND / 2)) && "Currently do not support remainder tiles"' failed.
@erwei-xilinx is aware of this:
https://teams.microsoft.com/l/message/19:meeting_Zjc5ZmZhM2EtZDcxZS00NzYxLTliYmQtNGFlNzY1MDJhNjMy@thread.v2/1713812572221?context=%7B%22contextType%22%3A%22chat%22%7D
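The assertion requires any wrap that has to be tiled to be an exact multiple of AIE2_WRAP_UPPER_BOUND / 2. The Python model below only illustrates that arithmetic; it is not the pass's actual code, `split_wrap` is a hypothetical name, and the bound of 1024 is an assumption to check against AIRRtToIpuPass.cpp. Under that assumption a wrap of 2048 splits cleanly into 4 x 512, while a wrap of 2432 (a size used in the shapes above) leaves a remainder of 384 and would trip this check if it reached it.

```python
# Assumption: wraps must stay below this bound; verify the value of
# AIE2_WRAP_UPPER_BOUND in AIRRtToIpuPass.cpp before relying on it.
AIE2_WRAP_UPPER_BOUND = 1024


def split_wrap(wrap, upper_bound=AIE2_WRAP_UPPER_BOUND):
    """Hypothetical model of the tiling behind the assertion.

    A legal wrap is kept as (1, wrap). An illegal (too large) wrap is split
    into (outer, inner) with inner == upper_bound // 2, which only works when
    wrap is a multiple of upper_bound // 2; otherwise a remainder tile would
    be needed, which the pass does not support.
    """
    if wrap < upper_bound:
        return (1, wrap)
    half = upper_bound // 2
    if wrap % half != 0:
        raise ValueError(f"wrap {wrap} leaves a remainder tile of {wrap % half}")
    return (wrap // half, half)
```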
The second failure is in air-label-scf-for-to-ping-pong.
The error message is quite generic and needs further investigation:
error: block with no terminator, has %40 = "air.wait_all"(%arg8) : (!air.async.token) -> !air.async.token
note: see current operation: %40 = "air.wait_all"(%arg8) : (!air.async.token) -> !air.async.token
Example input IR:
!lhs = tensor<2432x2432xbf16>
!rhs = tensor<2432x2432xbf16>
!out = tensor<2432x2432xf32>
func.func @matmul_32x32_32xf32_(%lhs : !lhs, %rhs : !rhs) -> !out {
%init_acc = tensor.empty() : !out
%c0_acc_type = arith.constant 0.0 : f32
%acc = linalg.fill ins(%c0_acc_type : f32) outs(%init_acc : !out) -> !out
%result = linalg.matmul ins(%lhs, %rhs: !lhs, !rhs) outs(%acc: !out) -> !out
return %result: !out
}
Here is the crash I see for matmul with accumulation:
https://gist.github.com/nirvedhmeshram/32a76d8676b8d4d3ff6e7bbc7a10d05d
related PR: #119
Issue: func.func @<>() in imported torch onnx MLIR
from transformers import LlamaForCausalLM, LlamaTokenizer
import torch

class optModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = LlamaForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-hf")
        self.model.eval()

    def forward(self, tokens):
        attention_mask = torch.ones(tokens.shape, dtype=torch.long)
        return self.model.forward(input_ids=tokens, attention_mask=attention_mask)

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
test_input = tokenizer.encode("The Manhattan bridge", return_tensors="pt")
model = optModel()
onnx_program = torch.onnx.export(model, test_input, "llama-7b.onnx")
The last line exports the ONNX model; if you use a different script, add that line to your setup.
Examining the imported file (head llama-7b.torch-onnx.mlir), you will see that it contains func.func @<>() and that the body of the function is missing.
I have tested a superset of the OPT GEMM sizes, with M/N/K drawn from combinations of [8, 16, 32, 64, 2048, 8192], using the pad-pack pipeline. Most test cases pass; the failing cases are those where two of the M/N/K dimensions are large, e.g., 8192 x 8192 x k, m x 8192 x 2048, and m x 8192 x 8192. The good news is that 8192 x 2048 x 8192 compiles without problems.
For the next goal, we should look into the larger failing cases, e.g., the OPT size 8 x 50272 x 2048.
Given the following IR:
scf.for %arg0 = 0 to 100 step 2 {
...
arith.addi %cst, %arg0
...
}
We want to peel the prologue (first iteration) and the epilogue (last iteration):
scf.for %arg0 = 0 to 2 step 2 {
...
arith.addi %cst, %arg0
...
}
scf.for %arg0 = 2 to 98 step 2 {
...
arith.addi %cst, %arg0
...
}
scf.for %arg0 = 98 to 100 step 2 {
...
arith.addi %cst, %arg0
...
}
The current pass is already achieving prologue peeling.
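The bounds of the three peeled loops follow from the trip count n: the epilogue starts at the induction value of the last iteration, lb + (n - 1) * step, which is not simply ub - step when step does not divide ub - lb (for 0 to 99 step 2 it is 98, not 97). A minimal Python sketch of that bound arithmetic, not the actual pass; `peel_first_and_last` is a hypothetical helper:

```python
def trip_count(lb, ub, step):
    """Number of iterations of `for iv = lb to ub step step` (step > 0)."""
    return max(0, -(-(ub - lb) // step))


def peel_first_and_last(lb, ub, step):
    """Split one loop into (prologue, main, epilogue) loops.

    Each loop is returned as an (lb, ub, step) triple. The prologue runs the
    first iteration, the epilogue the last one, and the main loop the rest.
    Loops with fewer than two iterations are returned unchanged.
    """
    n = trip_count(lb, ub, step)
    if n < 2:
        return [(lb, ub, step)]
    last = lb + (n - 1) * step  # induction value of the final iteration
    return [(lb, lb + step, step), (lb + step, last, step), (last, ub, step)]
```

For the IR above, peel_first_and_last(0, 100, 2) gives the bounds (0, 2), (2, 98) and (98, 100), matching the three scf.for loops.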
I get the following error, followed by a crash:
ElementsAttr does not provide iteration facilities for type mlir::Attribute, see attribute: dense_resource<_layers.0.weight> : tensor<4x3xf32>
invalid T for ElementsAttr::getValues
UNREACHABLE executed at /proj/gdba/kumar/nod/iree/third_party/llvm-project/mlir/include/mlir/IR/BuiltinAttributeInterfaces.h:307!
Please report issues to https://github.com/openxla/iree/issues and include the crash backtrace.
Steps to reproduce the issue:
Save the following code as mlp.onnx.torch.mlir:
module {
func.func @main_graph(%arg0: !torch.vtensor<[8,3],f32>) -> !torch.vtensor<[8,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.3.0"} {
%0 = torch.vtensor.literal(dense_resource<_layers.0.weight> : tensor<4x3xf32>) : !torch.vtensor<[4,3],f32>
%1 = torch.vtensor.literal(dense_resource<_layers.0.bias> : tensor<4xf32>) : !torch.vtensor<[4],f32>
%2 = torch.vtensor.literal(dense_resource<_layers.2.weight> : tensor<5x4xf32>) : !torch.vtensor<[5,4],f32>
%3 = torch.vtensor.literal(dense_resource<_layers.2.bias> : tensor<5xf32>) : !torch.vtensor<[5],f32>
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
%4 = torch.aten.transpose.int %0, %int0, %int1 : !torch.vtensor<[4,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,4],f32>
%5 = torch.aten.mm %arg0, %4 : !torch.vtensor<[8,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[8,4],f32>
%6 = torch.aten.add.Tensor %5, %1, %int1 : !torch.vtensor<[8,4],f32>, !torch.vtensor<[4],f32>, !torch.int -> !torch.vtensor<[8,4],f32>
%7 = torch.aten.relu %6 : !torch.vtensor<[8,4],f32> -> !torch.vtensor<[8,4],f32>
%int0_0 = torch.constant.int 0
%int1_1 = torch.constant.int 1
%8 = torch.aten.transpose.int %2, %int0_0, %int1_1 : !torch.vtensor<[5,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,5],f32>
%9 = torch.aten.mm %7, %8 : !torch.vtensor<[8,4],f32>, !torch.vtensor<[4,5],f32> -> !torch.vtensor<[8,5],f32>
%10 = torch.aten.add.Tensor %9, %3, %int1_1 : !torch.vtensor<[8,5],f32>, !torch.vtensor<[5],f32>, !torch.int -> !torch.vtensor<[8,5],f32>
%11 = torch.aten.relu %10 : !torch.vtensor<[8,5],f32> -> !torch.vtensor<[8,5],f32>
return %11 : !torch.vtensor<[8,5],f32>
}
}
{-#
dialect_resources: {
builtin: {
_layers.0.weight: "0x08000000301A083F90AB963EED1C05BF6842DABD02B53C3E095BD93D48CE6A3E28A1853DC00C5E3D905A05BF6B25483EBF1EF53E",
_layers.0.bias: "0x0800000054A3B53ED6C0B33E55D1A2BDE5D2BB3E",
_layers.2.weight: "0x08000000409F42BD50E864BD98C6783E603FB43C5216D13EE0EBBB3DF499B23E98581F3E10A47EBD40BF023C2884CF3ED06213BE88759B3DCA19E0BEC484253EF87EE5BDCE1DAD3E00DC273B1C028D3EDA38CFBE",
_layers.2.bias: "0x08000000041A503E183382BEFCEEBABE7A3AF0BE0E9DEE3E"
}
}
#-}
Then run:
/iree-build/tools/iree-compile --iree-hal-target-backends=llvm-cpu mlp.onnx.torch.mlir > mlp.onnx.vmfb
I get the following crash stack:
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var LLVM_SYMBOLIZER_PATH
to point to it):
0 libIREECompiler.so 0x00007fe6517ebee7 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 39
1 libIREECompiler.so 0x00007fe6517ea100 llvm::sys::RunSignalHandlers() + 80
2 libIREECompiler.so 0x00007fe6517ec59f
3 libpthread.so.0 0x00007fe64b863420
4 libc.so.6 0x00007fe64b40d00b gsignal + 203
5 libc.so.6 0x00007fe64b3ec859 abort + 299
6 libIREECompiler.so 0x00007fe651770b0f
7 libIREECompiler.so 0x00007fe651735402
8 libIREECompiler.so 0x00007fe65172fd31
9 libIREECompiler.so 0x00007fe65429303f mlir::LLVM::detail::getLLVMConstant(llvm::Type*, mlir::Attribute, mlir::Location, mlir::LLVM::ModuleTranslation const&) + 1807
10 libIREECompiler.so 0x00007fe65429731d mlir::LLVM::ModuleTranslation::convertGlobals() + 1741
11 libIREECompiler.so 0x00007fe65429b9bf mlir::translateModuleToLLVMIR(mlir::Operation*, llvm::LLVMContext&, llvm::StringRef) + 1791
12 libIREECompiler.so 0x00007fe6530f8d71
13 libIREECompiler.so 0x00007fe652d74520
14 libIREECompiler.so 0x00007fe651980696 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 646
15 libIREECompiler.so 0x00007fe651980ea8 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 344
16 libIREECompiler.so 0x00007fe6519856e1
17 libIREECompiler.so 0x00007fe652d752c8
18 libIREECompiler.so 0x00007fe651980696 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 646
19 libIREECompiler.so 0x00007fe651980ea8 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 344
20 libIREECompiler.so 0x00007fe65198634e
21 libIREECompiler.so 0x00007fe65198242b mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool) + 2299
22 libIREECompiler.so 0x00007fe651980831 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 1057
23 libIREECompiler.so 0x00007fe651980ea8 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 344
24 libIREECompiler.so 0x00007fe651983299 mlir::PassManager::run(mlir::Operation*) + 985
25 libIREECompiler.so 0x00007fe651746dbb ireeCompilerInvocationPipeline + 3099
26 libIREECompiler.so 0x00007fe65194b1a5
27 libIREECompiler.so 0x00007fe65194aa5a
28 libc.so.6 0x00007fe64b3ee083 __libc_start_main + 243
29 iree-compile 0x000055fda4bf172e
Based on the recent addition of the pack-based pipeline, I tested matmul sizes without changing any config, using pack_pipeline_e2e.mlir as the base input IR structure. Current state for matmul shape M x N x K with i32 elements:
Failure at air-copy-to-dma
- mlir-print-ir-before-all:
- 8 x 16 x (8, 32, 64, 2048, 8192)
- 8 x 32 x (8, 32, 64, 2048, 8192)
- 8 x 64 x (8, 32, 64, 2048, 8192)
- 8 x 2048 x (8, 32, 64, 2048, 8192)
- 8 x 8192 x (8, 32, 64, 2048, 8192)
- 8 x 50272 x (8, 32, 64, 2048, 8192)
Please find the dispatches to target attached here
Dispatch Name | GEMM Name ([B,]M,N,K) | Elementwise | No. of inputs | No. of outputs | GOPs (2·B·M·N·K / 1e9; B = 1 unless batched) |
---|---|---|---|---|---|
_initializer_46_dispatch_0 | matmul_transpose_b_2x2432x256 | Y | 3 | 1 | 0.00 |
_initializer_48_dispatch_0 | matmul_transpose_b_2x2432x2432 | Y | 3 | 1 | 0.02 |
async_dispatch_3 | matmul_transpose_b_2x2432x2048 | Y | 3 | 1 | 0.02 |
async_dispatch_4 | matmul_transpose_b_154x2432x4096 | Y | 3 | 1 | 3.07 |
async_dispatch_5 | matmul_transpose_b_2x2432x2432 | Y | 4 | 2 | 0.02 |
async_dispatch_6 | matmul_transpose_b_2x14592x2432 | Y | 3 | 1 | 0.14 |
async_dispatch_10 | matmul_transpose_b_154x7296x2432 | N | 2 | 1 | 5.47 |
async_dispatch_15 | matmul_transpose_b_8192x7296x2432 | N | 2 | 1 | 290.72 |
async_dispatch_29 | batch_matmul_2x77x2432x2432 | Y | 5 | 1 | 1.82 |
async_dispatch_32 | matmul_transpose_b_154x9728x2432 | Y | 3 | 1 | 7.29 |
async_dispatch_33 | matmul_transpose_b_154x2432x9728 | N | 2 | 1 | 7.29 |
async_dispatch_36 | batch_matmul_2x4096x2432x2432 | Y | 5 | 1 | 96.91 |
async_dispatch_39 | matmul_transpose_b_8192x9728x2432 | Y | 3 | 1 | 387.62 |
async_dispatch_40 | matmul_transpose_b_8192x2432x9728 | N | 2 | 1 | 387.62 |
async_dispatch_1302 | matmul_transpose_b_2x4864x2432 | Y | 3 | 1 | 0.05 |
async_dispatch_1333 | matmul_transpose_b_8192x64x2432 | N | 2 | 1 | 2.55 |
async_dispatch_10144 | matmul_512x16384x512 | Y | 3 | 1 | 8.59 |
async_dispatch_10149 | matmul_512x16384x512 | Y | 4 | 1 | 8.59 |
async_dispatch_10233 | matmul_256x262144x512 | N | 2 | 1 | 68.72 |
async_dispatch_10266 | matmul_128x1048576x256 | N | 2 | 1 | 68.72 |
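The last column can be recomputed from the dispatch names. A small Python sketch (helper names are hypothetical) that parses the trailing sizes and applies 2·M·N·K / 1e9, multiplying in the batch size for batch_matmul, which is how the batch rows come out (for example, 2 x 77 x 2432 x 2432 gives 1.82):

```python
import re


def gemm_dims(name):
    """Extract the trailing AxBx... sizes from a name like 'matmul_transpose_b_2x2432x256'."""
    match = re.search(r"(\d+(?:x\d+)+)$", name)
    if match is None:
        raise ValueError(f"no shape suffix in {name!r}")
    return [int(d) for d in match.group(1).split("x")]


def gops(name):
    """2 * (product of all sizes) / 1e9; for batch_matmul the batch size is included."""
    total = 2
    for d in gemm_dims(name):
        total *= d
    return total / 1e9
```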
A typical snippet of interest, from async_dispatch_3:
%9 = linalg.matmul_transpose_b ins(%4, %5 : tensor<2x2048xf32>, tensor<2432x2048xf32>) outs(%8 : tensor<2x2432xf32>) -> tensor<2x2432xf32>
%10 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%9, %6 : tensor<2x2432xf32>, tensor<2432xf32>) outs(%7 : tensor<2x2432xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%11 = arith.addf %in, %in_1 : f32
%12 = arith.negf %11 : f32
%13 = math.exp %12 : f32
%14 = arith.addf %13, %cst_0 : f32
%15 = arith.divf %cst_0, %14 : f32
%16 = arith.mulf %15, %11 : f32
linalg.yield %16 : f32
} -> tensor<2x2432xf32>
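Reading the region op by op, and assuming %cst_0 is 1.0 (its definition is not in the snippet), the elementwise epilogue adds a broadcast bias and then applies SiLU: %15 = 1 / (exp(-x) + 1) is sigmoid(x), and %16 = sigmoid(x) * x. A plain-Python reference of the whole dispatch (matmul_transpose_b followed by this epilogue) can help check AIE results on small inputs; the function names are hypothetical:

```python
import math


def matmul_transpose_b(lhs, rhs):
    """lhs: M x K, rhs: N x K -> M x N, with out[i][j] = sum_k lhs[i][k] * rhs[j][k]."""
    return [[sum(a * b for a, b in zip(row, col)) for col in rhs] for row in lhs]


def bias_silu(x, bias):
    """Mirror %11 .. %16, assuming %cst_0 == 1.0: out = t * sigmoid(t), t = x + bias."""
    out = []
    for row in x:
        new_row = []
        for v, b in zip(row, bias):
            t = v + b                       # %11: bias broadcast along d1
            s = 1.0 / (math.exp(-t) + 1.0)  # %12 .. %15: sigmoid(t)
            new_row.append(s * t)           # %16
        out.append(new_row)
    return out
```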
Here is a general command to compile a model that may produce such dispatches with llvm-cpu
./tools/iree-compile --mlir-elide-elementsattrs-if-larger=2 --iree-hal-target-backends=llvm-cpu ../tresleches.mlir --iree-hal-dump-executable-sources-to=../dispatches_no_dtile2 --iree-opt-data-tiling=0
In the end-to-end matmul test script run_matmul_test.sh I observe the following:
The following test of a single matmul:
run_matmul_test \
--name_prefix "single_matmul" \
--lhs_rhs_type "i32" \
--acc_type "i32" \
--m "8" \
--n "32" \
--k "16"
compiles and runs correctly, as does a single matmul test with m=32, n=8, k=16. But when the two matmuls are merged into a single test:
run_matmul_test \
--name_prefix "two_matmuls" \
--lhs_rhs_type "i32" \
--acc_type "i32" \
--m "32,8" \
--n "8,32" \
--k "16,16"
The compilation works, but there is a hang at runtime:
+ echo 'Running command: /proj/gdba/jamesn/workspace/builds/iree-clang/tools/iree-e2e-matmul-test --module=/proj/gdba/jamesn/workspace/iree-amd-aie/build_tools/ci/results_dir_tmp/two_matmuls.vmfb --module=/proj/gdba/jamesn/workspace/iree-amd-aie/build_tools/ci/results_dir_tmp/two_matmuls_calls.vmfb --device=xrt'
Running command: /proj/gdba/jamesn/workspace/builds/iree-clang/tools/iree-e2e-matmul-test --module=/proj/gdba/jamesn/workspace/iree-amd-aie/build_tools/ci/results_dir_tmp/two_matmuls.vmfb --module=/proj/gdba/jamesn/workspace/iree-amd-aie/build_tools/ci/results_dir_tmp/two_matmuls_calls.vmfb --device=xrt
+ eval '/proj/gdba/jamesn/workspace/builds/iree-clang/tools/iree-e2e-matmul-test --module=/proj/gdba/jamesn/workspace/iree-amd-aie/build_tools/ci/results_dir_tmp/two_matmuls.vmfb --module=/proj/gdba/jamesn/workspace/iree-amd-aie/build_tools/ci/results_dir_tmp/two_matmuls_calls.vmfb --device=xrt'
++ /proj/gdba/jamesn/workspace/builds/iree-clang/tools/iree-e2e-matmul-test --module=/proj/gdba/jamesn/workspace/iree-amd-aie/build_tools/ci/results_dir_tmp/two_matmuls.vmfb --module=/proj/gdba/jamesn/workspace/iree-amd-aie/build_tools/ci/results_dir_tmp/two_matmuls_calls.vmfb --device=xrt
--- TEST[matmul_8x32_16xi32__8_16_32_0] ---
Matmul shape (MxKxN): 8x16x32
The compilation in the 2 matmul case results in the following files being created:
results_dir_tmp/
├── configured_module_matmul_32x8_16xi32__dispatch_0.mlir
├── configured_module_matmul_8x32_16xi32__dispatch_0.mlir
├── module_matmul_32x8_16xi32__dispatch_0_amdaie_xclbin_fb
│ ├── aie_cdo_elfs.bin
│ ├── aie_cdo_enable.bin
│ ├── aie_cdo_error_handling.bin
│ ├── aie_cdo_init.bin
│ ├── aie_partition.json
│ ├── design.bif
│ ├── design.pdi
│ ├── input.ll
│ ├── input.o
│ ├── input.opt.ll
│ ├── kernels.json
│ ├── mem_topology.json
│ ├── module_matmul_32x8_16xi32__dispatch_0_amdaie_xclbin_fb.aiecc.mlir
│ ├── module_matmul_32x8_16xi32__dispatch_0_amdaie_xclbin_fb.ipu.txt
│ ├── module_matmul_32x8_16xi32__dispatch_0_amdaie_xclbin_fb.xclbin
│ ├── segment_0_core_0_2.elf
│ └── segment_0_core_0_2.elf.ld
├── module_matmul_32x8_16xi32__dispatch_0_amdaie_xclbin_fb_benchmark.mlir
├── module_matmul_32x8_16xi32__dispatch_0.mlir
├── module_matmul_8x32_16xi32__dispatch_0_amdaie_xclbin_fb
│ ├── aie_cdo_elfs.bin
│ ├── aie_cdo_enable.bin
│ ├── aie_cdo_error_handling.bin
│ ├── aie_cdo_init.bin
│ ├── aie_partition.json
│ ├── design.bif
│ ├── design.pdi
│ ├── input.ll
│ ├── input.o
│ ├── input.opt.ll
│ ├── kernels.json
│ ├── mem_topology.json
│ ├── module_matmul_8x32_16xi32__dispatch_0_amdaie_xclbin_fb.aiecc.mlir
│ ├── module_matmul_8x32_16xi32__dispatch_0_amdaie_xclbin_fb.ipu.txt
│ ├── module_matmul_8x32_16xi32__dispatch_0_amdaie_xclbin_fb.xclbin
│ ├── segment_0_core_0_2.elf
│ └── segment_0_core_0_2.elf.ld
├── module_matmul_8x32_16xi32__dispatch_0_amdaie_xclbin_fb_benchmark.mlir
├── module_matmul_8x32_16xi32__dispatch_0.mlir
├── two_matmuls_calls.mlir
├── two_matmuls_calls.vmfb
├── two_matmuls_ir.mlir
└── two_matmuls.vmfb
Notice that there are two xclbin files above. They are created when iree-compile calls AIETargetBackend::serializeExecutable (twice, once per matmul dispatch), which in turn calls mlir-aie's aie2xclbin through llvm::sys::ExecuteAndWait.
Task: track down this hang, and fix it.
Update: In CI, I do not get a hang but another error:
+ echo 'Running command: /home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/iree-install/tools/iree-e2e-matmul-test --module=/home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/test1/multiple_matmuls.vmfb --module=/home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/test1/multiple_matmuls_calls.vmfb --device=xrt'
+ eval '/home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/iree-install/tools/iree-e2e-matmul-test --module=/home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/test1/multiple_matmuls.vmfb --module=/home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/test1/multiple_matmuls_calls.vmfb --device=xrt'
++ /home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/iree-install/tools/iree-e2e-matmul-test --module=/home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/test1/multiple_matmuls.vmfb --module=/home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/test1/multiple_matmuls_calls.vmfb --device=xrt
Running command: /home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/iree-install/tools/iree-e2e-matmul-test --module=/home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/test1/multiple_matmuls.vmfb --module=/home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/test1/multiple_matmuls_calls.vmfb --device=xrt
terminate called after throwing an instance of 'xrt_core::system_error'
what(): DRM_IOCTL_AMDXDNA_CREATE_HWCTX IOCTL failed (err=16): Device or resource busy
build_tools/ci/run_matmul_test.sh: line 432: 725456 Aborted (core dumped) /home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/iree-install/tools/iree-e2e-matmul-test --module=/home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/test1/multiple_matmuls.vmfb --module=/home/github/actions-runner/_work/iree-amd-aie/iree-amd-aie/test1/multiple_matmuls_calls.vmfb --device=xrt
(see https://github.com/nod-ai/iree-amd-aie/actions/runs/8714586802/job/23905147045?pr=279)
Update:
I am trying to understand why CI and local behaviour differ. Could it be different versions of XRT?
CI version of xrt:
svcnod@sharkbox1:~$ /opt/xilinx/xrt/bin/xclbinutil --version
XRT Build Version: 2.17.0
Build Version Branch: HEAD
Build Version Hash: 41f4221433c6b173316b61cb2e7e3ee5152d8075
Build Version Hash Date: Sun, 11 Feb 2024 14:32:16 +0530
Build Version Date: Wed, 14 Feb 2024 21:28:54 -0800
Local version of xrt:
jnewling@xsjipunuc50:/proj/gdba/jamesn/workspace/iree-amd-aie/build_tools/ci $ /opt/xilinx/xrt/bin/xclbinutil --version
XRT Build Version: 2.17.0
Build Version Branch: HEAD
Build Version Hash: f23d53edd42fea0f0acd08c194b4750ed77127e2
Build Version Hash Date: Sat, 23 Mar 2024 16:53:57 +0530
Build Version Date: Fri, 12 Apr 2024 10:05:46 -0700
So the CI version of XRT is about two months older.
Update:
The directory where the xclbin files are written seems to be critical in determining success or failure:
This will pass:
./run_matmul_test.sh /home/jnewling/data/ ../../../builds/iree-clang $MLIR_AIE_INSTALL_DIR $PEANO_INSTALL_DIR /opt/xilinx/xrt $VITIS_INSTALL_PATH 0
This will hang:
./run_matmul_test.sh /proj/gdba/jamesn/data/ ../../../builds/iree-clang $MLIR_AIE_INSTALL_DIR $PEANO_INSTALL_DIR /opt/xilinx/xrt $VITIS_INSTALL_PATH 0
This took @daveliddell and me about 2 hours of triage to uncover. It is highly mysterious.
Most of IREE's code generation relies on tensor-based approaches. To aid that, the scf.forall operation allows you to do parallel spatial decomposition. This is part of what is modeled by air.launch and air.herd. The latter are more general constructs, but for the current code-generation pipeline they might be too general.
This issue is to track the use of tensor-based code generation (and therefore scf.forall) farther down the stack. Post bufferization, air.launch and air.herd might still be useful constructs for managing the lowering to IPU instructions.
With an IR like this
func.func @matmul_8x32_16xi32_(%lhs: tensor<8x16xi32>, %rhs: tensor<16x32xi32>) -> tensor<8x32xi32> {
  %init_acc = tensor.empty() : tensor<8x32xi32>
  %c0_acc_type = arith.constant 0 : i32
  %acc = linalg.fill ins(%c0_acc_type : i32) outs(%init_acc : tensor<8x32xi32>) -> tensor<8x32xi32>
  %result = linalg.matmul ins(%lhs, %rhs : tensor<8x16xi32>, tensor<16x32xi32>) outs(%acc : tensor<8x32xi32>) -> tensor<8x32xi32>
  return %result : tensor<8x32xi32>
}
func.func @matmul_8x16_16xi32_(%lhs: tensor<8x16xi32>, %rhs: tensor<16x16xi32>) -> tensor<8x16xi32> {
  %init_acc = tensor.empty() : tensor<8x16xi32>
  %c0_acc_type = arith.constant 0 : i32
  %acc = linalg.fill ins(%c0_acc_type : i32) outs(%init_acc : tensor<8x16xi32>) -> tensor<8x16xi32>
  %result = linalg.matmul ins(%lhs, %rhs : tensor<8x16xi32>, tensor<16x16xi32>) outs(%acc : tensor<8x16xi32>) -> tensor<8x16xi32>
  return %result : tensor<8x16xi32>
}
The following compile command sometimes works, but most of the time it crashes, and the segfault is not the same each time either:
./tools/iree-compile build-matmul/matmul_i32_i32_small_amd-aie_xrt_matmuls.mlir --iree-hal-target-backends=amd-aie --iree-amd-aie-peano-install-dir=<path to peano> --iree-amd-aie-mlir-aie-install-dir=<path to mlir-aie> --iree-amd-aie-vitis-install-dir=<path to vitis> -o test.vmfb
Here are some of the segfaults
Note that compiling either function on its own causes no problems.
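The crash is intermittent and the segfault differs between runs, so a rough crash rate over many runs says more than a single attempt. The sketch below repeats a command N times. It tallies clean exits, deaths by signal and other failures. Shells report death by signal as an exit status above 128, as with the segfaults and "Aborted (core dumped)" seen here. The helper name tally_runs is hypothetical.

```shell
#!/usr/bin/env bash
# Sketch: repeat a command and tally outcomes to estimate how often it crashes.

# tally_runs (hypothetical helper): $1 = number of runs, remaining args = command
tally_runs() {
  local runs=$1 i status ok=0 crashed=0 failed=0
  shift
  for (( i = 0; i < runs; i++ )); do
    "$@" >/dev/null 2>&1
    status=$?
    if [ "$status" -eq 0 ]; then
      ok=$((ok + 1))
    elif [ "$status" -gt 128 ]; then
      # Death by signal N shows up as 128 + N (139 for SIGSEGV, 134 for SIGABRT).
      # Heuristic only: a program may also exit with a code above 128 on purpose.
      crashed=$((crashed + 1))
    else
      failed=$((failed + 1))
    fi
  done
  echo "ok=$ok crashed=$crashed failed=$failed"
}
```

For example, run tally_runs 20 ./tools/iree-compile <arguments as above> once with the two-function file. Then run it on each function compiled separately to compare the rates.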
Now that I have tried to connect the passes for the packing pipeline in this PR, I noticed some missing pieces that should be solved with high priority.
Here is a gist showing the transient issue:
https://gist.github.com/nirvedhmeshram/5deeaaa94729053794aa2dd9fbf223e0
git clone https://github.com/nod-ai/SHARK-TestSuite.git
cd SHARK-TestSuite/e2eshark
pip install -r requirements.txt
Run the following command after providing your own values for:
-c : a valid torch-mlir build dir (build https://github.com/llvm/torch-mlir if you do not have a build already). Note that torch MLIR is huge, hence it should be generated.
-i : a valid IREE build dir (build https://github.com/openxla/iree if you do not have a build already)
--hfhome : a directory with more than 10GB of free space, as Hugging Face model weights will be downloaded there
python ./run.py --upto inference --mode onnx -c ../../torch-mlir/build -i ../../iree-build/ --tests pytorch/models/opt-125M --hfhome /proj/gdba/kumar/HF_HOME
Then cd to test-run/pytorch/models/opt-125M, where you can find the torch MLIR and the run logs. You can work in this directory to iterate on your fixes. iree-compile fails with the following assertion:
iree-compile: iree/third_party/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From &) [To = mlir::DenseElementsAttr, From = mlir::Attribute]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
Please report issues to https://github.com/openxla/iree/issues and include the crash backtrace.
Stack dump:
0. Program arguments: /proj/gdba/kumar/nod/iree-build/tools/iree-compile --iree-hal-target-backends=llvm-cpu opt-125M.fp32.onnx.torch.mlir
Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
0 libIREECompiler.so 0x00007fa14a046757 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) + 39
1 libIREECompiler.so 0x00007fa14a044970 llvm::sys::RunSignalHandlers() + 80
2 libIREECompiler.so 0x00007fa14a046e0f
3 libpthread.so.0 0x00007fa14413f420
4 libc.so.6 0x00007fa143ce900b gsignal + 203
5 libc.so.6 0x00007fa143cc8859 abort + 299
6 libc.so.6 0x00007fa143cc8729
7 libc.so.6 0x00007fa143cd9fd6
8 libIREECompiler.so 0x00007fa14ad4042e
9 libIREECompiler.so 0x00007fa14ad2d16b
10 libIREECompiler.so 0x00007fa14ad2cac9
11 libIREECompiler.so 0x00007fa14d3fa9e7
12 libIREECompiler.so 0x00007fa14d3f799e mlir::PatternApplicator::matchAndRewrite(mlir::Operation*, mlir::PatternRewriter&, llvm::function_ref<bool (mlir::Pattern const&)>, llvm::function_ref<void (mlir::Pattern const&)>, llvm::function_ref<mlir::LogicalResult (mlir::Pattern const&)>) + 942
13 libIREECompiler.so 0x00007fa14d3e3389
14 libIREECompiler.so 0x00007fa14d3dfa32 mlir::applyPatternsAndFoldGreedily(mlir::Region&, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) + 1058
15 libIREECompiler.so 0x00007fa14d3a9ddb
16 libIREECompiler.so 0x00007fa14a1dc666 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 646
17 libIREECompiler.so 0x00007fa14a1dce78 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 344
18 libIREECompiler.so 0x00007fa14a1e23ae
19 libIREECompiler.so 0x00007fa14a1de3fb mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl(bool) + 2299
20 libIREECompiler.so 0x00007fa14a1dc801 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 1057
21 libIREECompiler.so 0x00007fa14a1dce78 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 344
22 libIREECompiler.so 0x00007fa14a1e16b1
23 libIREECompiler.so 0x00007fa14b1c1d64
24 libIREECompiler.so 0x00007fa14a1dc666 mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) + 646
25 libIREECompiler.so 0x00007fa14a1dce78 mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) + 344
26 libIREECompiler.so 0x00007fa14a1df269 mlir::PassManager::run(mlir::Operation*) + 985
27 libIREECompiler.so 0x00007fa149fa18b2 ireeCompilerInvocationPipeline + 3714
28 libIREECompiler.so 0x00007fa14a1a6be5
29 libIREECompiler.so 0x00007fa14a1a647a
30 libc.so.6 0x00007fa143cca083 __libc_start_main + 243
31 iree-compile 0x0000564142c0d72e
Aborted (core dumped)