The purpose of this issue
The purpose of this issue is to communicate the planned changes I'd like to make to nvFuser to enable matmul NN support. The entire change will take multiple PRs; this issue provides the big picture and explains the motivation behind those PRs. Another purpose is to get the design reviewed early.
The challenge
NN support
The challenge with NN is not "how to schedule an NN matmul", but rather "how to define an NN matmul". Currently, `MmaOp` can be considered a fused multiply-sum, and the entire matmul is defined as broadcasts followed by an `MmaOp`:
![](https://user-images.githubusercontent.com/1032377/233734940-7e7d2c2f-448c-4f0d-9153-4ea7e497e657.png)
The currently supported layouts are TT, TN, NT:
- For TT, the input shapes are `[M, K]` and `[K, N]`. We broadcast the inputs to `[M, K, 1]` and `[1, K, N]`; after multiplication we have `[M, K, N]`, and after reduction we get `[M, N]`.
- For TN, the input shapes are `[M, K]` and `[N, K]`. We broadcast the inputs to `[M, 1, K]` and `[1, N, K]`; after multiplication we have `[M, N, K]`, and after reduction we get `[M, N]`.
- For NT, the input shapes are `[K, M]` and `[K, N]`. We broadcast the inputs to `[K, M, 1]` and `[K, 1, N]`; after multiplication we have `[K, M, N]`, and after reduction we get `[M, N]`.
But for NN, the input shapes are `[K, M]` and `[N, K]`, and there is no arrangement of broadcasts, multiplication, and reduction that yields an output of `[M, N]`; it is only possible to get `[N, M]`. In order to support NN, there must be changes in the fusion definition. A minimal sketch of the current definition style, and of why NN does not fit it, is below.
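For concreteness, here is a minimal sketch of how a TN matmul is defined today, assuming the nvFuser C++ fusion API (`makeSymbolicTensor`, `broadcast`, and `fusedMultiplySum`, which creates the `MmaOp`); exact helper names and signatures may differ from the current codebase.

```cpp
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 = makeSymbolicTensor(2, DataType::Half); // A: [M, K]
TensorView* tv1 = makeSymbolicTensor(2, DataType::Half); // B: [N, K]
fusion.addInput(tv0);
fusion.addInput(tv1);

// Broadcast so that the M, N, and K axes line up: [M, 1, K] and [1, N, K].
TensorView* tv0b = broadcast(tv0, {false, true, false});
TensorView* tv1b = broadcast(tv1, {true, false, false});

// Fused multiply-sum over the K axis (axis 2) produces [M, N].
TensorView* tv2 = fusedMultiplySum(tv0b, tv1b, {2});
fusion.addOutput(tv2);
```

For NN inputs `[K, M]` and `[N, K]`, the only way to align the K axes via broadcasts is `[1, K, M]` and `[N, K, 1]`, and reducing K then produces `[N, M]`, not `[M, N]`.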
Forward compatibility
Our design should be compatible with our future needs, even though these needs might not be our priority today. Below are some examples that I think are related to this topic:
- Matmul-transpose fusion: for example, `(A.T @ B).T` (see the sketch after this list). We should allow users to insert transposes anywhere in the fusion definition, and nvFuser should be able to find the correct layout for the matmul with these transposes taken into consideration.
- Transpose support in other schedulers: currently, transpose is supported in a separate scheduler. This is not optimal. We should handle transposes and generate optimal code in all schedulers.
- View support in other schedulers: similar to the above, currently we only support specific patterns of view. Any infrastructure change that could potentially be useful for more complete view support is welcome.
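As an illustration of the matmul-transpose fusion point above, here is a sketch of `(A.T @ B).T` written with explicit ops, assuming `A` is `[K, M]`, `B` is `[K, N]`, and the usual nvFuser C++ ops (`transpose`, `broadcast`, `mul`, `sum`); names and signatures are approximate. Ideally the scheduler would recognize that the whole pattern is just an NT matmul producing `[N, M]` and choose the layout accordingly, instead of materializing the transposes.

```cpp
TensorView* a_t = transpose(tv_a, 0, 1);                 // A.T: [M, K]
TensorView* a_b = broadcast(a_t, {false, false, true});  // [M, K, 1]
TensorView* b_b = broadcast(tv_b, {true, false, false}); // [1, K, N]
TensorView* mm  = sum(mul(a_b, b_b), {1});               // A.T @ B: [M, N]
TensorView* out = transpose(mm, 0, 1);                   // (A.T @ B).T: [N, M]
```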
Principles
Ease of defining a fusion
From the perspective of the user (the person who creates a fusion), fusion definitions should be as flexible as possible. Users shouldn't worry about performance when defining a fusion; the only thing a user should worry about is mathematical correctness. The matmul scheduler should accept all mathematically equivalent fusions and schedule all of them optimally. It is an unacceptable experience if the user needs to think about things like the following when defining the fusion:

> I want to do an Ampere matmul. Ampere has an `ldmatrix.trans` instruction that allows a transpose in the smem-to-register load, and the mma op is always TN, so I will transpose the inputs to `[M, K]` and `[N, K]` first, and then do mma.
Instead, for example, for an `[M, K] @ [K, N] -> [M, N]` matmul, users should be able to define the fusion in whatever way is most convenient to them; options include:
![Screenshot_20230421_184052](https://user-images.githubusercontent.com/1032377/233755265-327d2a41-bb6c-4a84-8e63-32b6aaec7576.png)
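For instance, two of these mathematically equivalent definitions might look like the following, assuming `tv_a` is `[M, K]`, `tv_b` is `[K, N]`, and the usual nvFuser C++ ops (`broadcast`, `mul`, `sum`, `transpose`); the scheduler should accept either and produce the same optimal kernel.

```cpp
// Option 1: broadcast A and B directly and reduce over K (axis 1).
TensorView* a_b1 = broadcast(tv_a, {false, false, true}); // [M, K, 1]
TensorView* b_b1 = broadcast(tv_b, {true, false, false}); // [1, K, N]
TensorView* out1 = sum(mul(a_b1, b_b1), {1});             // [M, N]

// Option 2: transpose B to [N, K] first and use the TN pattern (reduce axis 2).
TensorView* b_t  = transpose(tv_b, 0, 1);                 // [N, K]
TensorView* a_b2 = broadcast(tv_a, {false, true, false}); // [M, 1, K]
TensorView* b_b2 = broadcast(b_t, {true, false, false});  // [1, N, K]
TensorView* out2 = sum(mul(a_b2, b_b2), {2});             // [M, N]
```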
Ease of lowering
For lowering to be easy, the fusion definition should be as close to the hardware as possible. However, different architectures have different matmul flavors, so there is no single "canonical form" that is close to all hardware.
On Volta, matrices are loaded into registers in the same layout as the inputs, and the mma op has TT, TN, NT, and NN variants. So on Volta, the matmul definition closest to the hardware is:
![image](https://user-images.githubusercontent.com/1032377/234374182-78cf028b-7fde-4b51-bdbb-54854d6b0d6c.png)
On Turing & Ampere, the layout of the matrices in shared memory is the same as the inputs. However, when shared-memory matrices are loaded into registers (using `ldmatrix`/`ldmatrix.trans`), the layouts of the matrices always become TN, and there is only one variant of mma, which is TN. So on Turing & Ampere, the matmul definition closest to the hardware is:
![Screenshot_20230421_160350](https://user-images.githubusercontent.com/1032377/233746764-60c903eb-934d-41a3-a1c5-0d5ba4e18325.png)
On Hopper, there are two variants of `wgmma`: the "rs" variant, whose first operand is in registers and whose second operand is in shared memory, and the "ss" variant, whose operands are both in shared memory. For the "rs" variant, the first operand must be `[M, K]`, and the second operand can be either `[N, K]` or `[K, N]`. For the "ss" variant, the first operand can be `[M, K]` or `[K, M]`, and the second operand can be `[N, K]` or `[K, N]`. So the matmul definition closest to the hardware is:
![Screenshot_20230421_162027](https://user-images.githubusercontent.com/1032377/233748037-412b01f8-7686-4e63-9a10-65d5a87e1fb8.png)
![Screenshot_20230421_162255](https://user-images.githubusercontent.com/1032377/233748221-a788da90-03a7-4142-9bb9-2d0ac0ec2032.png)
Design
As described in the last section, the principles of ease of defining a fusion and ease of lowering lead to different fusion definitions. For a given piece of hardware, the "ease of lowering" form is the canonical form that the scheduler needs to transform the fusion into. In order to be able to define the canonical form and transform fusions into it, the following concepts will be added/changed in nvFuser:
Implicit transpose/view/broadcast/squeeze
An implicit transpose is a transpose that happens in the rFactor domain of a tensor whose definition is not a `TransposeOp`. Starting from #148, all transposes are implicit (because we no longer have a `TransposeOp`), but we only practice implicit transpose on tensors defined by `LoadStoreOp`. In order to define the canonical forms of the mma op, we would need to make the output tensor of `MmaOp` implicitly transposed as well.
I would like to go further:
- I would like to support not just implicit transpose, but also implicit view, broadcast, and squeeze.
- I would like to enable implicit transpose/view/broadcast/squeeze not only on `MmaOp`'s output, but on all tensors.
For example, if I have `T1 = sin(T0)` where `T0` has shape `[I0, 1, I1, I2]`, and `T1` has
```
root domain: [I0, 1, I1, I2]
rFactor domain: [I2, I0*I1, 1]
```
then the output `T1`'s shape will be `[I2, I0*I1, 1]`. That is, `T1` does an implicit squeeze-view-transpose-broadcast.
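For reference, here is the explicit-op equivalent of that implicit squeeze-view-transpose-broadcast, written with concrete sizes for illustration (`I0=2`, `I1=3`, `I2=5`) and assuming the usual nvFuser C++ ops (`squeeze`, `reshape`, `permute`, `broadcast`); exact signatures may differ.

```cpp
TensorView* t0 = makeConcreteTensor({2, 1, 3, 5});       // [I0, 1, I1, I2]
fusion.addInput(t0);

TensorView* t1 = sin(t0);                                // [I0, 1, I1, I2]
TensorView* t2 = squeeze(t1, std::vector<bool>{false, true, false, false});
                                                         // [I0, I1, I2]
TensorView* t3 = reshape(t2, {2, 3, 5}, {6, 5});         // [I0*I1, I2]
TensorView* t4 = permute(t3, {1, 0});                    // [I2, I0*I1]
TensorView* t5 = broadcast(t4, {false, false, true});    // [I2, I0*I1, 1]
fusion.addOutput(t5);
```

With implicit transpose/view/broadcast/squeeze, all of `t2`..`t5` would fold into `T1`'s rFactor domain, and only the single `sin` expression would remain.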
The reasons for doing so are:
- When indexing, we only care about how `IterDomain`s are transformed; we don't really care whether the tensor is defined by a `LoadStoreOp` or a `UnaryOpType::ReLU`. I believe this new approach can free up more flexibility without adding too much complexity to our existing system. In fact, I think our system will end up with fewer lines of code in total, because we can just define many of these ops as `LoadStoreOp` instead of each having its own `Expr` subclass.
- `BroadcastOp`, `SqueezeOp`, `ViewOp`, `TransposeOp`, etc. are essentially just data-copying operations with fancy indexing. In the generated C++ code, I am not a fan of reading things like
  ```cpp
  T1[i1] = T0[i0];
  T2[i2] = sin(T1[i1]);
  ```
  I prefer to read just `T2[i2] = sin(T0[i0])`, so I want to reduce redundant copies in the generated C++. For the case of matmul, I would like to have a single `ldmatrix.trans` that does the transpose and broadcast together, so that I don't have to deal with this extra broadcast.
Forward and backward push of the rFactor domain [Obsolete]
I would like to add two methods to `TensorView`: `pushRFactorForward` and `pushRFactorBackward`.
Assume we have the following fusion: `T0 --set--> T1 --set--> T2`, where
```
T0: root = [I0, I1], no rfactor
T1: root = [I0, I1], rfactor = [I1, I0]
T2: root = [I1, I0], no rfactor
```
Then `pushRFactorForward` will transform the fusion into
```
T0: root = [I0, I1], no rfactor
T1: root = [I0, I1], no rfactor
T2: root = [I0, I1], rfactor = [I1, I0]
```
And `pushRFactorBackward` will transform the fusion into
```
T0: root = [I0, I1], rfactor = [I1, I0]
T1: root = [I1, I0], no rfactor
T2: root = [I1, I0], no rfactor
```
It is possible to specify an intermediate state to only push part of the transformations. For example, if the fusion has `T0--view-->T1--sin-->T2` and
```
T0: root = [I0, I1, I2], no rfactor
T1: root = [I0, I1, I2], rfactor = [I1*I0/4, 4, I2]
T2: root = [I1*I0/4, 4, I2], no rfactor
```
then `pushRFactorBackward([I1*I0, I2])` will get
```
T0: root = [I0, I1, I2], rfactor = [I1*I0, I2]
T1: root = [I1*I0, I2], rfactor = [I1*I0/4, 4, I2]
T2: root = [I1*I0/4, 4, I2], no rfactor
```
Note that in this case, `T0` might have an implicit view, that is, a view that happens implicitly at some op that is not a `ViewOp`. And `pushRFactorForward([I1*I0, I2])` will get
```
T0: root = [I0, I1, I2], no rfactor
T1: root = [I0, I1, I2], rfactor = [I1*I0, I2]
T2: root = [I1*I0, I2], rfactor = [I1*I0/4, 4, I2]
```
Note that in this case, `T2` will have an implicit view.
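For clarity, here is a rough sketch of what the proposed interface could have looked like; the method names come from this issue, while the exact signatures (in particular the argument type for the intermediate state) are hypothetical.

```cpp
class TensorView : public Val {
 public:
  // Move this tensor's root->rFactor transforms onto its consumers/producers.
  void pushRFactorForward();
  void pushRFactorBackward();

  // Push only part of the transforms, stopping at the given intermediate
  // state, e.g. t1->pushRFactorBackward({i1_times_i0, i2}) for the partial
  // push in the example above.
  void pushRFactorForward(const std::vector<IterDomain*>& intermediate);
  void pushRFactorBackward(const std::vector<IterDomain*>& intermediate);
  // ...
};
```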
If the current tensor has more than one producer, a backward push will push its rFactor domain to all of its producers. Similarly, if the current tensor has multiple consumers, a forward push will push it to all of its consumers. It gets more complicated: if `c = op(a, b)`, then `a->pushRFactorForward()` will modify both `b` and `c`. The same is true for `pushRFactorBackward` on a tensor with multiple uses.
Not all pushes are valid. A push must be compatible with the current schedule. For example, if you have `T0->T1` where
```
T0: root = [I0, I1, I2], no rfactor, leaf = [I0, I1*I2]
T1: root = [I0, I1, I2], rfactor = [I0*I1, I2], leaf = [I0*I1, I2]
```
then `T1->pushRFactorBackward()` is illegal. However, if
```
T0: root = [I0, I1, I2], no rfactor, leaf = [(I0*I1)*I2]
T1: root = [I0, I1, I2], rfactor = [I0*I1, I2], leaf = [I0*I1, I2]
```
then `T1->pushRFactorBackward()` will get
```
T0: root = [I0, I1, I2], rfactor = [I0*I1, I2], leaf = [(I0*I1)*I2]
T1: root = [I0*I1, I2], no rfactor, leaf = [I0*I1, I2]
```
There may be other types of invalid pushes, but discussing them is beyond the scope of this issue.
The goal of these two methods is to make the manipulation of rFactor domains super easy. I believe that in the long term, this added functionality would be helpful for `view` and `resize` scheduling. For the case of matmul, these two methods should make it easy to canonicalize the user's input into the hardware flavor. For example, the following transformation can be used to schedule an NN matmul for Ampere:
![Screenshot_20230421_190334](https://user-images.githubusercontent.com/1032377/233756191-afad6997-d25f-4568-aa8a-d0bbc8587507.png)
MmaOp vs Mul-Sum
`MmaOp` and `Mul->Sum` are two mathematically equivalent ways to define a matmul. I prefer to define the fusion as `Mul->Sum` in the user-facing API, and let the scheduler convert it to `MmaOp` and fill in the information, as sketched below.
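A minimal sketch of the contrast, again assuming the nvFuser C++ ops `mul`, `sum`, and `fusedMultiplySum` (the helper that builds an `MmaOp`), with `tv_a` and `tv_b` already broadcast to `[M, 1, K]` and `[1, N, K]`:

```cpp
// What the user writes: a plain multiply followed by a sum over K.
TensorView* out = sum(mul(tv_a, tv_b), {2}); // [M, N]

// What the scheduler would rewrite it into, filling in the MmaOp details
// (macro, operand layout, etc.) that the user never had to think about:
// TensorView* out = fusedMultiplySum(tv_a, tv_b, {2});
```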