
amx's People

Contributors

corsix, geohot, johshoff, pthariensflame

amx's Issues

Possibility of adding support for Linux for Apple AMX1 and AMX2

Hi all,

I'm in the process of researching Apple AMX as a potential way of speeding up IEEE FP BLAS kernels in OpenBLAS.

On the macOS side, it seems that between this repository and other resources, I have all I need to be able to write the kernels.

The issue as of now is Linux. In discussion with the folks supporting/developing Asahi Linux (see the Mastodon thread here: https://mast.hpc.social/@fclc/109914828822965657), it came up that Asahi has no plans to support the CPU state required to use AMX from EL0.

I'm of the opinion that it may be possible to implement a Linux kernel module to allow the use of AMX on M1, M2 and the various SKUs based on those SoCs.

This would probably require fairly tight understanding of AMX and its underlying operations.

I was hoping for insight from any of the folks working on this present project.

Some tests failing on the M3 Max

This is the output of make test on an Apple M3 Max:

Testing AMX_LDX... Failed on iteration 0.0 (operand 0xe7ee80da3d4b9e09)
Testing AMX_LDY... Failed on iteration 0.0 (operand 0xe7ee80da3d4b9e09)
Testing AMX_LDZ... OK   
Testing AMX_LDZI... OK   
Testing AMX_STX... OK   
Testing AMX_STY... OK   
Testing AMX_STZ... OK   
Testing AMX_STZI... OK   
Testing AMX_EXTRX... OK   
Testing AMX_EXTRY... OK   
Testing AMX_MAC16... OK   
Testing AMX_FMA16... OK   
Testing AMX_FMA32... OK   
Testing AMX_FMA64... OK   
Testing AMX_FMS16... OK   
Testing AMX_FMS32... OK   
Testing AMX_FMS64... OK   
Testing AMX_VECINT... OK   
Testing AMX_VECFP... OK   
Testing AMX_MATINT... Failed on iteration 1.575 (operand 0xd060b15ad5995f37)
Testing AMX_MATFP... OK   
Testing AMX_GENLUT... OK

This suggests that either Apple broke compatibility with the previous versions, or there are new features using some of the previously-ignored bits in the parameters to these instructions.

I think the former is unlikely: I have been writing a lot of AMX code lately, with excellent test coverage, and I have yet to see any unexplained failures in my software tests, i.e. anything that behaves differently on the M3 than on the M1 I also have here (using your documented M1 features, and also some of the documented M2 features, which work as expected on the M3). So hopefully these are new features in the M3.

I investigated this a bit by changing the random values, and I found that for AMX_LDX and AMX_LDY, out of the previously ignored bits (63, 61 and 59), only bit 61 is always set when a test fails; bits 63 and 59 are sometimes set and sometimes not (indeed, I've seen a failure in which bits 63 and 59 were clear and only 61 was set).

So I wrote a small program to investigate this, and found that bit 61 represents a strided load: when loading pairs, the stride is 4 (that is, if you start at X0, it loads to X0 and X4), whereas when loading 4 at a time, the stride is 2 (e.g. X0, X2, X4, X6). I will attach a test program and its output on my M3. For AMX_LDY, results are identical.
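To sanity-check the pattern, here is a tiny model of the register addressing I observed. It illustrates only the behaviour described above, not any documented encoding; the modulo-8 wrap for start registers other than X0 is my assumption.

```c
#include <stdint.h>

/* Model of the strided load observed on M3 when bit 61 of the AMX_LDX /
 * AMX_LDY operand is set: pair loads use a register stride of 4, quad
 * loads a stride of 2. Wrapping modulo 8 (the size of the X/Y register
 * pools) is an assumption for non-zero start registers. */
void strided_dest_regs(int start_reg, int num_regs, int out[4]) {
    int stride = (num_regs == 2) ? 4 : 2;
    for (int i = 0; i < num_regs; i++)
        out[i] = (start_reg + i * stride) & 7;
}
```

For a pair load starting at X0 this yields X0 and X4; for a quad load starting at X0 it yields X0, X2, X4, X6, matching the observations above.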

As for AMX_MATINT, I collected a bunch of values where the tests fail:

0xd060b15ad5995f37
0xd26ab256885620e0
0xd060b15ad5995f37
0x0e61b375b73c8104
0xbc7870a58e4864bc
0xce6b3046d4af6812
0xa069335db08b4b0e
0x5a71b34bf47fe485
0xee7172c6ce0a04ec
0xa87ef14a1baca54d
0xb662b045bc40cdb0
0xc4697074e454ab6f

ANDing these together, the common theme is that bits 44, 45, 53 and 54 are set. Having bits 53 and 54 set means an indexed load in ALU mode 8. For that mode, there are two lane-width modes (i.e. bits 45:42): 10, or any other value. However, having bits 44 and 45 set would correspond to 12.
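For reference, the bitwise AND that isolates those common bits can be reproduced directly from the operand list above:

```c
#include <stdint.h>
#include <stddef.h>

/* Fold a list of failing operands with AND; any bit still set after the
 * fold is common to every failure. */
uint64_t common_bits(const uint64_t *ops, size_t n) {
    uint64_t acc = ~(uint64_t)0;
    for (size_t i = 0; i < n; i++)
        acc &= ops[i];
    return acc;
}
```

Folding the twelve values above gives 0x0060300000000000, i.e. exactly bits 44, 45, 53 and 54.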

If you'd like to investigate, but don't have access to an M3, I can run any tests you need; just let me know.

ldx_m3_src.txt
ldx_m3_out.txt

M2 Compatibility

Hi,

On my M2 (2022 MacBook Air) I'm getting the following:

AMX_LDX: fail
AMX_LDY: fail
AMX_LDZ: pass
AMX_LDZI: pass
AMX_STX: pass
AMX_STY: pass
AMX_STZ: pass
AMX_STZI: pass
AMX_EXTRX: fail
AMX_EXTRY: fail
AMX_MAC16: pass
AMX_FMA16: pass
AMX_FMA32: pass
AMX_FMA64: pass
AMX_FMS16: pass
AMX_FMS32: pass
AMX_FMS64: pass
AMX_VECINT: fail
AMX_VECFP: fail
AMX_MATINT: pass
AMX_MATFP: fail
AMX_GENLUT: fail

This is with clang 14.0.0 on macOS 13.1. I could be doing something wrong here, or there might be a minor architectural difference between AMX on M1 and M2.

I'm going to investigate further to see if I can get everything to pass on the M2, but first I was wondering whether any work has already been done on M2?

Thanks.

Updated and new publications on using AMX for cryptography

This is a follow-up on #9.

Our paper mentioned there was published in the first issue of the new IACR Communications in Cryptology journal, so if you'd be so kind as to update the link to point to the journal version, I'd be grateful. Note that the repository associated with this paper is now online, so you might be interested in linking to that as well.

Finally, we have another paper out on implementing two other post-quantum cryptosystems on AMX. You may also be interested in adding a link to it.

Thanks.

A preprint on cryptography using AMX

This is to let you know that I and some colleagues published a preprint for a paper on using AMX for cryptographic applications:

https://eprint.iacr.org/2024/2

Your repository was an invaluable resource in accomplishing our results. Perhaps you'd like to add a link on your references page; we think it might be a good resource to help others.

matint/matfp vs mac16/fma*/fms*, and how to do multiply-only (without accumulating/subtracting) using matint

This is a question rather than an issue.

First of all, thanks for the huge effort spent on reverse engineering AMX and documenting it. It is really appreciated. I've been able to try out AMX for an application of mine and that would have been impossible without your extensive documentation.

I was hoping that, with the experience you've acquired through this reverse-engineering effort, you could confirm a few things for me, and perhaps add some pointers to enhance your already excellent documentation.

  1. Am I understanding correctly that there is some redundancy between matint/matfp and mac16/fm[as][16,32,64] when you only need outer products? Essentially, by choosing ALU mode 0 or 1 in matint/matfp, it appears you can emulate almost all of the functionality of mac16/fm[as][16,32,64]. Is that right, or do you see any particular advantage of mac16/fm[as][16,32,64] over matint/matfp? I'm trying to understand the thought process of the AMX designers, and why they would keep separate instructions when they could have merged them.
  2. That said, one piece of functionality I was unable to find in matint/matfp is a multiply-only operation, x*y, rather than accumulating into / subtracting from z. With mac16/fm[as][16,32,64] this can be done by setting bit 27 to 1. This matters e.g. for the first iteration of a matrix multiply, when z may still hold non-zero values. Is there a combination of bits that gives an outer product without accumulating into / subtracting from z?
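For what it's worth, this is how I'm building the multiply-only operand on the mac16/fma side; the offset field positions reflect my reading of the operand layout and should be treated as assumptions:

```c
#include <stdint.h>

/* Sketch of a multiply-only fma32-style operand: with bit 27 set, the Z
 * input is skipped, so Z = X*Y instead of Z += X*Y. Placing the Y byte
 * offset in the low bits and the X byte offset at bit 10 is my reading
 * of the layout; treat both as assumptions. */
#define FMA_SKIP_Z (1ull << 27)

uint64_t fma32_operand(uint64_t x_byte_offset, uint64_t y_byte_offset,
                       int multiply_only) {
    uint64_t op = ((x_byte_offset & 0x1ff) << 10) | (y_byte_offset & 0x1ff);
    if (multiply_only)
        op |= FMA_SKIP_Z;
    return op;
}
```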

register 31 - SP or XZR?

dougallj claims that 31 means XZR, but do any AMX operations make sense with an all-zero operand value? OTOH, using SP as a GPR can be tricky/dangerous...

Using AMX for non-ML optimisation

Hi again,

Thanks for your excellent research. I've been attempting to optimise OpenJPH, a JPEG 2000 implementation, with AMX. Starting with just one of the wavelet-transform functions, the AMX version comes out significantly slower than the NEON code generated by the compiler.

First I wanted to ask about AMX_SET() / AMX_CLR(). I timed them in a tight loop and the average came to 7.24 nanoseconds for an AMX_SET() / AMX_CLR() pair, which sounds reasonable. I'm not sure what's happening, though, because when I actually put them around my test function, its average time increases by about 2 milliseconds!

This is a rough sketch of how the test goes:

void compress_an_image() {
    // ...
    // Calls this 1000s of times
    gen_irrev_horz_wvlt_fwd_tx(...);
    // ...
}

// Call this 30 times and take the average
void trial() {
    AMX_SET();
    compress_an_image();
    AMX_CLR();
}

My times are:

Original code, NEON generated by compiler:
Avg run time (ms): 49.935734
Avg run time (ms): 49.773602
Avg run time (ms): 49.785702
Avg run time (ms): 50.063000

Original code, with AMX_SET() / AMX_CLR() around trial(), AMX doing nothing useful:
Avg run time (ms): 52.630169
Avg run time (ms): 52.561501
Avg run time (ms): 52.559669
Avg run time (ms): 52.419498

Converting one stage of one of the wavelet transforms to AMX:
Avg run time (ms): 60.170498
Avg run time (ms): 60.380032
Avg run time (ms): 60.384933
Avg run time (ms): 60.366333

And it just gets worse the more code I convert to AMX.

Regardless of the actual AMX code I wrote, there's some weirdness around AMX_SET() / AMX_CLR(). If I put them around gen_irrev_horz_wvlt_fwd_tx() which gets called 1000s of times during image compression, it's much slower, around 70 ms.

I was wondering if you have any insights. I know there's a lot going on that could be slowing this down, such as the kernel having to save and restore AMX state on context switches, caching effects, or the way the CPU cores share the AMX blocks (I'm still trying to get my head around that).

I also wanted to show you some of the wavelet-transform code, if you'd like to have a look. There isn't much AMX code out there, so I wrote it in a way that I thought would give a reasonable speedup, but so far no go.

The following only shows the first stage of the transform:

Original code: https://github.com/aous72/OpenJPH/blob/master/src/core/transform/ojph_transform.cpp#L357

void gen_irrev_horz_wvlt_fwd_tx(line_buf* line_src, line_buf *line_ldst, line_buf *line_hdst, ui32 width, bool even) {
    float *src = line_src->f32;
    float *ldst = line_ldst->f32, *hdst = line_hdst->f32;

    const ui32 L_width = (width + (even ? 1 : 0)) >> 1;
    const ui32 H_width = (width + (even ? 0 : 1)) >> 1;

    //extension
    src[-1] = src[1];
    src[width] = src[width-2];
    // predict
    float factor = LIFTING_FACTORS::steps[0];
    const float* sp = src + (even ? 1 : 0);
    float *dph = hdst;
    for (ui32 i = H_width; i > 0; --i, sp+=2)
      *dph++ = sp[0] + factor * (sp[-1] + sp[1]);
}
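As a sanity check for the AMX version, I keep a minimal portable copy of the predict step around; the -0.5f factor used in the example below is purely illustrative, since the real value comes from LIFTING_FACTORS::steps[0]:

```c
/* Portable reference for the predict step: hdst[i] = sp[0] + factor *
 * (sp[-1] + sp[1]), with sp advancing by 2. Assumes even == true, width
 * even, and that the caller has already applied the boundary extension
 * (src[-1] = src[1]; src[width] = src[width-2]). */
typedef unsigned int ui32;

void predict_step_ref(const float *src, float *hdst, ui32 width, float factor) {
    ui32 H_width = width >> 1;           /* even == true */
    const float *sp = src + 1;
    for (ui32 i = 0; i < H_width; i++, sp += 2)
        hdst[i] = sp[0] + factor * (sp[-1] + sp[1]);
}
```

For example, with the pre-extended buffer {2, 1, 2, 3, 4, 5, 6, 5} (src pointing at the 1), width 6 and an illustrative factor of -0.5f, the outputs are 0, 0, 1.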

My AMX (de)optimisation:

#define _amx_ldx(srcBuf, destIndex, flags) AMX_LDX(((uint64_t)&*(srcBuf)) | ((uint64_t)(destIndex)<<56) | (flags))
#define _amx_ldy(srcBuf, destIndex, flags) AMX_LDY(((uint64_t)&*(srcBuf)) | ((uint64_t)(destIndex)<<56) | (flags))
#define _amx_stz(destBuf, srcIndex, flags) AMX_STZ(((uint64_t)&*(destBuf)) | ((uint64_t)(srcIndex)<<56) | (flags))

#define VECFP_MULTIPLE_2 0
#define VECFP_MULTIPLE_4 (1ull << 25)

// *M2 only* Multiple mode (bit 31=1), regular load (bit 53=0)
#define _amx_vecfp_multiple(xOffset, yOffset, zRow, xShuffle, yShuffle, bMode, laneWidthMode, aluMode, flags) \
  AMX_VECFP( \
    ((uint64_t)(xOffset) << 10) | \
    ((uint64_t)(yOffset)) | \
    ((uint64_t)(zRow) << 20) | \
    ((uint64_t)(xShuffle) << 29) | \
    ((uint64_t)(yShuffle) << 27) | \
    (1ull << 31) | \
    ((uint64_t)(bMode) << 32) | \
    ((uint64_t)(laneWidthMode) << 42) | \
    ((uint64_t)(aluMode) << 47) | \
    (flags) \
  )

void gen_irrev_horz_wvlt_fwd_tx(line_buf* line_src, line_buf *line_ldst, line_buf *line_hdst, ui32 width, bool even) {
  static float amxScratch[32] __attribute__((aligned(128))) = { 0.0f };

  // src, ldst, and hdst are aligned to 128 bytes
  float *ldst = line_ldst->f32, *hdst = line_hdst->f32, *src = line_src->f32;

  // even is always true
  // 240 < width < 1920

  const ui32 L_width = (width + (even ? 1 : 0)) >> 1;
  const ui32 H_width = (width + (even ? 0 : 1)) >> 1;

  //extension
  src[-1] = src[1];
  src[width] = src[width-2];
  // predict
  const float* sp = src + (even ? 1 : 0);
  float *dph = hdst;

  amxScratch[0] = LIFTING_FACTORS::steps[0];
  _amx_ldy(&amxScratch[0], 0, LDST_MODE_SINGLE);
  // Do ceil(H_width / 32) iterations
  for (ui32 i = 0; i < (H_width + 31) >> 5; i++) {
    // Process 64 floats from sp down to 32 floats to dph
    _amx_ldx(&sp[-1], 0, LDST_MODE_QUAD);
    // Extension
    _amx_ldx(&sp[63], 4, LDST_MODE_SINGLE);

    _amx_vecfp_multiple(0, 0, 0, 3, 0, 7, 4, 10, VECFP_MULTIPLE_4); // Z =  S3(X) * Y
    _amx_vecfp_multiple(8, 0, 0, 3, 0, 7, 4, 0, VECFP_MULTIPLE_4); // Z += (S3(X<<2)) * Y
    _amx_vecfp_multiple(4, 0, 0, 3, 0, 0, 4, 11, VECFP_MULTIPLE_4); // Z += (S3(X<<1))

    _amx_stz(&dph[0], 0, LDST_MODE_SINGLE);
    _amx_stz(&dph[8], 16, LDST_MODE_SINGLE);
    _amx_stz(&dph[16], 32, LDST_MODE_SINGLE);
    _amx_stz(&dph[24], 48, LDST_MODE_SINGLE);

    dph += 32;
    sp += 64;
  }
}

Disassembly:

                             **************************************************************
                             * ojph::local::gen_irrev_horz_wvlt_fwd_tx(ojph::line_buf*... *
                             **************************************************************
                             undefined __cdecl gen_irrev_horz_wvlt_fwd_tx(line_buf * 
             undefined         w0:1           <RETURN>
             line_buf *        x0:8           param_1
             line_buf *        x1:8           param_2
             line_buf *        x2:8           param_3
             uint              w3:4           param_4
             bool              w4:1           param_5
                             __ZN4ojph5local26gen_irrev_horz_wvlt_fwd_txEPN  XREF[2]:     Entry Point(*), 
                             ojph::local::gen_irrev_horz_wvlt_fwd_tx                      init_wavelet_transform_functions
        0001ed6c 0a 08 40 f9     ldr        x10,[param_1, #0x10]
        0001ed70 48 08 40 f9     ldr        x8,[param_3, #0x10]
        0001ed74 89 00 00 52     eor        w9,param_5,#0x1
        0001ed78 29 01 03 0b     add        w9,w9,param_4
        0001ed7c 40 05 40 bd     ldr        s0,[x10, #0x4]
        0001ed80 40 c1 1f bc     stur       s0,[x10, #-0x4]
        0001ed84 6b 08 00 51     sub        w11,param_4,#0x2
        0001ed88 40 59 6b bc     ldr        s0,[x10,w11, uxtw #2]
        0001ed8c 40 59 23 bc     str        s0,[x10,param_4, uxtw #2]
        0001ed90 8b 57 0c 10     adr        x11,0x37880
        0001ed94 1f 20 03 d5     nop
        0001ed98 6c ce 80 52     mov        w12,#0x673
        0001ed9c 6c f9 b7 72     movk       w12,#0xbfcb, LSL #16
        0001eda0 6c 01 00 b9     str        w12,[x11]=>ojph::local::gen_irrev_horz_wvlt_fw
        0001eda4 2b 10 20 00     __amx_ldy  x11
        0001eda8 3f 09 00 71     cmp        w9,#0x2
        0001edac c3 04 00 54     b.cc       LAB_0001ee44
        0001edb0 29 7d 01 53     lsr        w9,w9,#0x1
        0001edb4 29 7d 00 11     add        w9,w9,#0x1f
        0001edb8 e9 17 49 4b     neg        w9,w9, LSR #0x5
        0001edbc 4a 49 24 8b     add        x10,x10,param_5, UXTW  #0x2
        0001edc0 4a 11 00 d1     sub        x10,x10,#0x4
        0001edc4 0b 00 ea d2     mov        x11,#0x5000000000000000
        0001edc8 0c 40 bc d2     mov        x12,#0xe2000000
        0001edcc ec 00 c2 f2     movk       x12,#0x1007, LSL #32
        0001edd0 ac 00 e0 f2     movk       x12,#0x5, LSL #48
        0001edd4 0d 00 84 d2     mov        x13,#0x2000
        0001edd8 0d 40 bc f2     movk       x13,#0xe200, LSL #16
        0001eddc ed 00 c2 f2     movk       x13,#0x1007, LSL #32
        0001ede0 0e 00 82 d2     mov        x14,#0x1000
        0001ede4 0e 40 bc f2     movk       x14,#0xe200, LSL #16
        0001ede8 0e 00 d2 f2     movk       x14,#0x9000, LSL #32
        0001edec ae 00 e0 f2     movk       x14,#0x5, LSL #48
                             LAB_0001edf0                                    XREF[1]:     0001ee40(j)  
        0001edf0 4f 01 0b aa     orr        x15,x10,x11
        0001edf4 0f 10 20 00     __amx_ldx  x15
        0001edf8 4a 01 04 91     add        x10,x10,#0x100
        0001edfc 4f 01 46 b2     orr        x15,x10,#0x400000000000000
        0001ee00 0f 10 20 00     __amx_ldx  x15
        0001ee04 6c 12 20 00     __amx_ve   x12
        0001ee08 6d 12 20 00     __amx_ve   x13
        0001ee0c 6e 12 20 00     __amx_ve   x14
        0001ee10 a8 10 20 00     __amx_stz  x8
        0001ee14 0f 81 00 91     add        x15,x8,#0x20
        0001ee18 ef 01 44 b2     orr        x15,x15,#0x1000000000000000
        0001ee1c af 10 20 00     __amx_stz  x15
        0001ee20 0f 01 01 91     add        x15,x8,#0x40
        0001ee24 ef 01 43 b2     orr        x15,x15,#0x2000000000000000
        0001ee28 af 10 20 00     __amx_stz  x15
        0001ee2c 0f 81 01 91     add        x15,x8,#0x60
        0001ee30 ef 05 44 b2     orr        x15,x15,#0x3000000000000000
        0001ee34 af 10 20 00     __amx_stz  x15
        0001ee38 08 01 02 91     add        x8,x8,#0x80
        0001ee3c 29 05 00 31     adds       w9,w9,#0x1
        0001ee40 83 fd ff 54     b.cc       LAB_0001edf0
                             LAB_0001ee44                                    XREF[1]:     0001edac(j)  
        0001ee44 c0 03 5f d6     ret

I'm not too good at reading assembly, but it looks like the compiler did an OK job of hoisting all those mov/movk instructions that set up the AMX operands out of the loop. The NEON version should produce 4 floats per loop iteration, while my AMX code stores 32 floats per iteration (well, technically it loads 40, but they overlap), so in theory it should be a lot faster even if my implementation isn't ideal.

If you've got any ideas or info on cycle counts / pipelining etc they would be greatly appreciated!
