corsix / amx
Apple AMX Instruction Set
License: MIT License
Hi all,
I'm in the process of researching Apple AMX as a potential way of speeding up IEEE FP BLAS kernels in OpenBLAS.
On the macOS side, it seems that between this repository and other resources, I have all I need to write the kernels.
The issue right now is Linux. Speaking with the folks supporting/developing Asahi Linux (see the Mastodon thread here: https://mast.hpc.social/@fclc/109914828822965657), it came up that Asahi has no plans to support the EL0 CPU state required for AMX.
I'm of the opinion that it may be possible to implement a Linux kernel module to allow for the usage of AMX on M1, M2 and the various SKUs based on those SOCs.
This would probably require fairly tight understanding of AMX and its underlying operations.
I was hoping for insight from any of the folks working on this present project.
After analyzing the die shots and speculating on performance, I came across a major change to the AMX architecture. Would you mind reading through the README of amx-benchmarks and helping me test the hypothesis? You don't need to rent an M2 from the cloud; I can test on my A15.
This is the output of make test on an Apple M3 Max:
Testing AMX_LDX... Failed on iteration 0.0 (operand 0xe7ee80da3d4b9e09)
Testing AMX_LDY... Failed on iteration 0.0 (operand 0xe7ee80da3d4b9e09)
Testing AMX_LDZ... OK
Testing AMX_LDZI... OK
Testing AMX_STX... OK
Testing AMX_STY... OK
Testing AMX_STZ... OK
Testing AMX_STZI... OK
Testing AMX_EXTRX... OK
Testing AMX_EXTRY... OK
Testing AMX_MAC16... OK
Testing AMX_FMA16... OK
Testing AMX_FMA32... OK
Testing AMX_FMA64... OK
Testing AMX_FMS16... OK
Testing AMX_FMS32... OK
Testing AMX_FMS64... OK
Testing AMX_VECINT... OK
Testing AMX_VECFP... OK
Testing AMX_MATINT... Failed on iteration 1.575 (operand 0xd060b15ad5995f37)
Testing AMX_MATFP... OK
Testing AMX_GENLUT... OK
This suggests either Apple broke compatibility with previous versions, or there are new features using some of the previously-ignored bits in the operands of these instructions.
I think the former is unlikely, as I have been writing lots of AMX code lately with excellent test coverage, and I have yet to see any unexplained failures in my own tests, i.e. anything that behaves differently on the M3 than on the M1 I also have here (using your documented M1 features, and also some of the documented M2 features, which work as expected on the M3). So hopefully these are new features in the M3.
I investigated this a bit by changing the random values, and I see that for AMX_LDX and AMX_LDY, out of the previously ignored bits (63, 61 and 59), only bit 61 is always set when a test fails; bits 63 and 59 are sometimes set and sometimes not (indeed, I've seen a failure in which only bit 61 was set, and neither 63 nor 59).
So I wrote a small program to investigate this, and found that bit 61 selects a strided load: when loading pairs, the stride is 4 (that is, if you start at X0, it loads to X0 and X4), whereas when loading 4 registers at a time, the stride is 2 (e.g. X0, X2, X4, X6). I will attach a test program and its output on my M3. For AMX_LDY, the results are identical.
As for AMX_MATINT, I collected a bunch of values where the tests fail:
0xd060b15ad5995f37
0xd26ab256885620e0
0xd060b15ad5995f37
0x0e61b375b73c8104
0xbc7870a58e4864bc
0xce6b3046d4af6812
0xa069335db08b4b0e
0x5a71b34bf47fe485
0xee7172c6ce0a04ec
0xa87ef14a1baca54d
0xb662b045bc40cdb0
0xc4697074e454ab6f
ANDing these together, the common theme is that bits 44, 45, 53 and 54 are set. Having bits 53 and 54 set means an indexed load in ALU mode 8. For that mode, there are two lane-width modes (i.e. bits 45:42): 10 or any other value. However, having bits 44 and 45 set would correspond to 12.
If you'd like to investigate, but don't have access to an M3, I can run any tests you need; just let me know.
Hi,
On my M2 (2022 MacBook Air) I'm getting the following:
AMX_LDX: fail
AMX_LDY: fail
AMX_LDZ: pass
AMX_LDZI: pass
AMX_STX: pass
AMX_STY: pass
AMX_STZ: pass
AMX_STZI: pass
AMX_EXTRX: fail
AMX_EXTRY: fail
AMX_MAC16: pass
AMX_FMA16: pass
AMX_FMA32: pass
AMX_FMA64: pass
AMX_FMS16: pass
AMX_FMS32: pass
AMX_FMS64: pass
AMX_VECINT: fail
AMX_VECFP: fail
AMX_MATINT: pass
AMX_MATFP: fail
AMX_GENLUT: fail
This is with clang 14.0.0 on macOS 13.1. I could be doing something wrong here, or there might be a minor architectural difference between AMX on M1 and M2.
I'm going to investigate further to see if I can get everything to pass on M2, but first I was wondering if there has been any existing work done on M2 yet?
Thanks.
This is a followup on #9.
Our paper mentioned there was published in the first edition of the new IACR Communications in Cryptology journal, so if you'd be so kind as to update the link to the actual journal paper, I'd be grateful. We also note that the repository associated with this paper is now online, so you might be interested in linking to that as well.
Finally, we have another paper out on implementing two other post-quantum cryptosystems on AMX. You may also be interested in adding a link to it.
Thanks.
This is to let you know that I and some colleagues published a preprint for a paper on using AMX for cryptographic applications:
https://eprint.iacr.org/2024/2
Your repository was an invaluable resource in accomplishing our results. Perhaps you'd like to add a link to your references page; we think it might be a good resource to help others.
This is a question rather than an issue.
First of all, thanks for the huge effort spent on reverse engineering AMX and documenting it. It is really appreciated. I've been able to try out AMX for an application of mine and that would have been impossible without your extensive documentation.
I was hoping that, with the experience acquired through your reverse-engineering effort, you could confirm a few things, and perhaps even add some pointers to enhance your already excellent documentation.
Is there any advantage of mac16/fm[as][16,32,64] over matint/matfp when you only need to do outer products? Essentially, by choosing ALU mode 0 or 1 in matint/matfp, it appears to me you can emulate almost all functionality of mac16/fm[as][16,32,64]. Is that right, or do you see any particular advantage to mac16/fm[as][16,32,64] vs matint/matfp? I'm just trying to understand the thought process of the AMX designers, and why they would have separate instructions when they could just merge them together.
One thing that seems to be missing from matint/matfp is the ability to do a multiply-only operation x*y rather than accumulating/subtracting with z. With mac16/fm[as][16,32,64] this can be done by setting bit 27 to 1. This is important e.g. for the first loop iteration in a matrix multiply, when z may already have non-zero values. Is there a combination of bits I can use to do an outer product without accumulating/subtracting from z?
Also, dougallj claims that register 31 means XZR, but do any AMX operations make sense with an all-zero operand value? OTOH, using SP as a GPR can be tricky/dangerous...
Hi again,
Thanks for your excellent research. I've been attempting to optimise OpenJPH, a JPEG 2000 implementation, with AMX. Starting off with just one of the wavelet transform functions, the AMX version is coming out significantly slower than the NEON instructions generated by the compiler.
First I wanted to ask about AMX_SET() / AMX_CLR(). I timed them in a tight loop and the average came to 7.24 nanoseconds for an AMX_SET() / AMX_CLR() pair, which sounds reasonable. I'm not sure what's happening, though, because when I actually put them around my test function, its average time increases by about 2 milliseconds!
This is a rough sketch of how the test goes:
void compress_an_image() {
...
// Calls this 1000s of times
gen_irrev_horz_wvlt_fwd_tx(...)
...
}
// Call this 30 times and take the average
void trial () {
AMX_SET();
compress_an_image();
AMX_CLR();
}
My times are:
Original code, NEON generated by compiler:
Avg run time (ms): 49.935734
Avg run time (ms): 49.773602
Avg run time (ms): 49.785702
Avg run time (ms): 50.063000
Original code, with AMX_SET() / AMX_CLR() around trial(), AMX doing nothing useful:
Avg run time (ms): 52.630169
Avg run time (ms): 52.561501
Avg run time (ms): 52.559669
Avg run time (ms): 52.419498
Converting one stage of one of the wavelet transforms to AMX:
Avg run time (ms): 60.170498
Avg run time (ms): 60.380032
Avg run time (ms): 60.384933
Avg run time (ms): 60.366333
And it just gets worse the more code I convert to AMX.
Regardless of the actual AMX code I wrote, there's some weirdness around AMX_SET() / AMX_CLR(). If I put them around gen_irrev_horz_wvlt_fwd_tx(), which gets called 1000s of times during image compression, it's much slower: around 70 ms.
I was wondering if you have any insights. I know there's a lot going on that could be slowing things down, like the kernel having to save and restore the AMX state when context switching, or caching, or how the CPU cores share the AMX blocks (still trying to get my head around that).
I also wanted to show you some of the wavelet transform code, if you'd like to have a look. There isn't much AMX code out there, so I wrote it in a way that I thought would give a reasonable speed increase, but so far no luck.
The following only shows the first stage of the transform:
Original code: https://github.com/aous72/OpenJPH/blob/master/src/core/transform/ojph_transform.cpp#L357
void gen_irrev_horz_wvlt_fwd_tx(line_buf* line_src, line_buf *line_ldst, line_buf *line_hdst, ui32 width, bool even) {
float *src = line_src->f32;
float *ldst = line_ldst->f32, *hdst = line_hdst->f32;
const ui32 L_width = (width + (even ? 1 : 0)) >> 1;
const ui32 H_width = (width + (even ? 0 : 1)) >> 1;
//extension
src[-1] = src[1];
src[width] = src[width-2];
// predict
float factor = LIFTING_FACTORS::steps[0];
const float* sp = src + (even ? 1 : 0);
float *dph = hdst;
for (ui32 i = H_width; i > 0; --i, sp+=2)
*dph++ = sp[0] + factor * (sp[-1] + sp[1]);
}
My AMX (de)optimisation:
#define _amx_ldx(srcBuf, destIndex, flags) AMX_LDX(((uint64_t)&*(srcBuf)) | ((uint64_t)(destIndex)<<56) | (flags))
#define _amx_ldy(srcBuf, destIndex, flags) AMX_LDY(((uint64_t)&*(srcBuf)) | ((uint64_t)(destIndex)<<56) | (flags))
#define _amx_stz(destBuf, srcIndex, flags) AMX_STZ(((uint64_t)&*(destBuf)) | ((uint64_t)(srcIndex)<<56) | (flags))
#define VECFP_MULTIPLE_2 0
#define VECFP_MULTIPLE_4 (1ull << 25)
// *M2 only* Multiple mode (bit 31=1), regular load (bit 53=0)
#define _amx_vecfp_multiple(xOffset, yOffset, zRow, xShuffle, yShuffle, bMode, laneWidthMode, aluMode, flags) \
AMX_VECFP( \
((uint64_t)(xOffset) << 10) | \
((uint64_t)(yOffset)) | \
((uint64_t)(zRow) << 20) | \
((uint64_t)(xShuffle) << 29) | \
((uint64_t)(yShuffle) << 27) | \
(1ull << 31) | \
((uint64_t)(bMode) << 32) | \
((uint64_t)(laneWidthMode) << 42) | \
((uint64_t)(aluMode) << 47) | \
(flags) \
)
void gen_irrev_horz_wvlt_fwd_tx(line_buf* line_src, line_buf *line_ldst, line_buf *line_hdst, ui32 width, bool even) {
static float amxScratch[32] __attribute__((aligned(128))) = { 0.0f };
// src, ldst, and hdst are aligned to 128 bytes
float *ldst = line_ldst->f32, *hdst = line_hdst->f32, *src = line_src->f32;
// even is always true
// 240 < width < 1920
const ui32 L_width = (width + (even ? 1 : 0)) >> 1;
const ui32 H_width = (width + (even ? 0 : 1)) >> 1;
//extension
src[-1] = src[1];
src[width] = src[width-2];
// predict
const float* sp = src + (even ? 1 : 0);
float *dph = hdst;
amxScratch[0] = LIFTING_FACTORS::steps[0];
_amx_ldy(&amxScratch[0], 0, LDST_MODE_SINGLE);
// Do ceil(H_width / 32) iterations
for (ui32 i = 0; i < (H_width + 31) >> 5; i++) {
// Process 64 floats from sp down to 32 floats to dph
_amx_ldx(&sp[-1], 0, LDST_MODE_QUAD);
// Extension
_amx_ldx(&sp[63], 4, LDST_MODE_SINGLE);
_amx_vecfp_multiple(0, 0, 0, 3, 0, 7, 4, 10, VECFP_MULTIPLE_4); // Z = S3(X) * Y
_amx_vecfp_multiple(8, 0, 0, 3, 0, 7, 4, 0, VECFP_MULTIPLE_4); // Z += (S3(X<<2)) * Y
_amx_vecfp_multiple(4, 0, 0, 3, 0, 0, 4, 11, VECFP_MULTIPLE_4); // Z += (S3(X<<1))
_amx_stz(&dph[0], 0, LDST_MODE_SINGLE);
_amx_stz(&dph[8], 16, LDST_MODE_SINGLE);
_amx_stz(&dph[16], 32, LDST_MODE_SINGLE);
_amx_stz(&dph[24], 48, LDST_MODE_SINGLE);
dph += 32;
sp += 64;
}
}
Disassembly:
**************************************************************
* ojph::local::gen_irrev_horz_wvlt_fwd_tx(ojph::line_buf*... *
**************************************************************
undefined __cdecl gen_irrev_horz_wvlt_fwd_tx(line_buf *
undefined w0:1 <RETURN>
line_buf * x0:8 param_1
line_buf * x1:8 param_2
line_buf * x2:8 param_3
uint w3:4 param_4
bool w4:1 param_5
__ZN4ojph5local26gen_irrev_horz_wvlt_fwd_txEPN XREF[2]: Entry Point(*),
ojph::local::gen_irrev_horz_wvlt_fwd_tx init_wavelet_transform_functions
0001ed6c 0a 08 40 f9 ldr x10,[param_1, #0x10]
0001ed70 48 08 40 f9 ldr x8,[param_3, #0x10]
0001ed74 89 00 00 52 eor w9,param_5,#0x1
0001ed78 29 01 03 0b add w9,w9,param_4
0001ed7c 40 05 40 bd ldr s0,[x10, #0x4]
0001ed80 40 c1 1f bc stur s0,[x10, #-0x4]
0001ed84 6b 08 00 51 sub w11,param_4,#0x2
0001ed88 40 59 6b bc ldr s0,[x10,w11, uxtw #2]
0001ed8c 40 59 23 bc str s0,[x10,param_4, uxtw #2]
0001ed90 8b 57 0c 10 adr x11,0x37880
0001ed94 1f 20 03 d5 nop
0001ed98 6c ce 80 52 mov w12,#0x673
0001ed9c 6c f9 b7 72 movk w12,#0xbfcb, LSL #16
0001eda0 6c 01 00 b9 str w12,[x11]=>ojph::local::gen_irrev_horz_wvlt_fw
0001eda4 2b 10 20 00 __amx_ldy x11
0001eda8 3f 09 00 71 cmp w9,#0x2
0001edac c3 04 00 54 b.cc LAB_0001ee44
0001edb0 29 7d 01 53 lsr w9,w9,#0x1
0001edb4 29 7d 00 11 add w9,w9,#0x1f
0001edb8 e9 17 49 4b neg w9,w9, LSR #0x5
0001edbc 4a 49 24 8b add x10,x10,param_5, UXTW #0x2
0001edc0 4a 11 00 d1 sub x10,x10,#0x4
0001edc4 0b 00 ea d2 mov x11,#0x5000000000000000
0001edc8 0c 40 bc d2 mov x12,#0xe2000000
0001edcc ec 00 c2 f2 movk x12,#0x1007, LSL #32
0001edd0 ac 00 e0 f2 movk x12,#0x5, LSL #48
0001edd4 0d 00 84 d2 mov x13,#0x2000
0001edd8 0d 40 bc f2 movk x13,#0xe200, LSL #16
0001eddc ed 00 c2 f2 movk x13,#0x1007, LSL #32
0001ede0 0e 00 82 d2 mov x14,#0x1000
0001ede4 0e 40 bc f2 movk x14,#0xe200, LSL #16
0001ede8 0e 00 d2 f2 movk x14,#0x9000, LSL #32
0001edec ae 00 e0 f2 movk x14,#0x5, LSL #48
LAB_0001edf0 XREF[1]: 0001ee40(j)
0001edf0 4f 01 0b aa orr x15,x10,x11
0001edf4 0f 10 20 00 __amx_ldx x15
0001edf8 4a 01 04 91 add x10,x10,#0x100
0001edfc 4f 01 46 b2 orr x15,x10,#0x400000000000000
0001ee00 0f 10 20 00 __amx_ldx x15
0001ee04 6c 12 20 00 __amx_ve x12
0001ee08 6d 12 20 00 __amx_ve x13
0001ee0c 6e 12 20 00 __amx_ve x14
0001ee10 a8 10 20 00 __amx_stz x8
0001ee14 0f 81 00 91 add x15,x8,#0x20
0001ee18 ef 01 44 b2 orr x15,x15,#0x1000000000000000
0001ee1c af 10 20 00 __amx_stz x15
0001ee20 0f 01 01 91 add x15,x8,#0x40
0001ee24 ef 01 43 b2 orr x15,x15,#0x2000000000000000
0001ee28 af 10 20 00 __amx_stz x15
0001ee2c 0f 81 01 91 add x15,x8,#0x60
0001ee30 ef 05 44 b2 orr x15,x15,#0x3000000000000000
0001ee34 af 10 20 00 __amx_stz x15
0001ee38 08 01 02 91 add x8,x8,#0x80
0001ee3c 29 05 00 31 adds w9,w9,#0x1
0001ee40 83 fd ff 54 b.cc LAB_0001edf0
LAB_0001ee44 XREF[1]: 0001edac(j)
0001ee44 c0 03 5f d6 ret
I'm not too good at reading assembly, but it looks like the compiler did an OK job of hoisting all those mov / movk instructions that set up the operands for the AMX instructions out of the loop. NEON should be outputting 4 floats per loop iteration, while my AMX code stores 32 floats per iteration (well, technically 40, but there's overlap), so theoretically it should be a lot faster even if my implementation isn't ideal.
If you've got any ideas or info on cycle counts / pipelining etc they would be greatly appreciated!