nvidia / fuser Goto Github PK
View Code? Open in Web Editor NEWA Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")
License: Other
A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")
License: Other
The following code results in a segfault as of yesterday (e.g. commit 1a5db86)
import torch
import torch.nn.functional as F
import nvfuser
def test_embedding(
    vocab_size=4096,
    embedding_dim=512,
    sentence_len=1024,
    batch_size=128,
):
    """Reproduce the reported segfault when gathering from broadcast inputs.

    Looks up one embedding row per term, first with ``F.embedding`` as a
    PyTorch reference and then with an nvFuser ``gather`` over broadcast
    operands. The reported crash happens while the fusion is scheduled during
    ``fd.execute``.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: width of each embedding row.
        sentence_len: not used by this repro.
        batch_size: number of terms looked up.
    """
    terms = torch.randint(0, vocab_size, [batch_size], dtype=torch.long, device='cuda')
    embedding_table = torch.randn([vocab_size, embedding_dim], dtype=torch.float32, device='cuda')
    # PyTorch reference; note it is never compared against the nvFuser result below.
    torch_embedded = F.embedding(terms, embedding_table)  # successful

    with nvfuser.FusionDefinition() as fd:
        t = fd.from_pytorch(terms)
        e = fd.from_pytorch(embedding_table)
        # look up each term's embedding and output the embedded terms
        tb = fd.ops.broadcast(t, [False, True, True])  # bcast in term and embedding dims
        eb = fd.ops.broadcast(e, [True, False, False])  # bcast in batch dimension
        embedded = fd.ops.gather(eb, tb, dim=1)
        fd.add_output(embedded)

    # Executing the finished definition is where the segfault is reported.
    nvf_out, _ = fd.execute([terms, embedding_table])
    embedded = nvf_out[0].squeeze()


test_embedding()
A partial backtrace is here:
#0 0x00007fff6c003818 in std::vector<nvfuser::Val*, std::allocator<nvfuser::Val*> >::size (this=0x48)
at /usr/include/c++/9/bits/stl_vector.h:916
#1 0x00007fff6c595e76 in nvfuser::ComputeAtRootDomainMapBuilder::initializeBcastMap (this=0x7fffffffb110, tv=0x55555c0e6890,
id=0x55555c168100) at /opt/pytorch/nvfuser/csrc/root_domain_map.cpp:792
#2 0x00007fff6c598eb0 in nvfuser::ComputeAtRootDomainMapBuilder::handle (this=0x7fffffffb110, tv=0x55555c0e6890)
at /opt/pytorch/nvfuser/csrc/root_domain_map.cpp:1195
#3 0x00007fff6c079a9c in nvfuser::Val::dispatch<nvfuser::OptOutDispatch*> (handler=0x7fffffffb110, val=0x55555c0e6890)
at /opt/pytorch/nvfuser/csrc/dispatch.cpp:92
#4 0x00007fff6c0768a1 in nvfuser::OptOutDispatch::handle (this=0x7fffffffb110, v=0x55555c0e6890)
at /opt/pytorch/nvfuser/csrc/dispatch.cpp:762
#5 0x00007fff6c331747 in nvfuser::BackwardVisitor::handle (this=0x7fffffffb110, val=0x55555c0e6890)
at /opt/pytorch/nvfuser/csrc/iter_visitor.cpp:396
#6 0x00007fff6c079165 in nvfuser::Statement::dispatch<nvfuser::OptOutDispatch*> (handler=0x7fffffffb110, stmt=0x55555c0e6890)
at /opt/pytorch/nvfuser/csrc/dispatch.cpp:330
#7 0x00007fff6c07684d in nvfuser::OptOutDispatch::handle (this=0x7fffffffb110, s=0x55555c0e6890)
at /opt/pytorch/nvfuser/csrc/dispatch.cpp:754
#8 0x00007fff6c3316f3 in nvfuser::BackwardVisitor::handle (this=0x7fffffffb110, stmt=0x55555c0e6890)
at /opt/pytorch/nvfuser/csrc/iter_visitor.cpp:388
#9 0x00007fff6c331d79 in nvfuser::BackwardVisitor::traverseTo (this=0x7fffffffb110, fusion=0x55555a10de90,
from=std::vector of length 1, capacity 1 = {...}, traverseAllPaths=false) at /opt/pytorch/nvfuser/csrc/iter_visitor.cpp:462
#10 0x00007fff6c595930 in nvfuser::ComputeAtRootDomainMapBuilder::ComputeAtRootDomainMapBuilder (this=0x7fffffffb110, root_map=...,
map_through_reduction=false) at /opt/pytorch/nvfuser/csrc/root_domain_map.cpp:758
#11 0x00007fff6c593b13 in nvfuser::ComputeAtRootDomainMap::build (this=0x7fffffffb3f0, map_through_reduction=false)
at /opt/pytorch/nvfuser/csrc/root_domain_map.cpp:483
#12 0x00007fff6bfe3d0e in nvfuser::MaxPosCalculator::buildUnmappableDims (this=0x7fffffffb650, compute_at_only=false)
at /opt/pytorch/nvfuser/csrc/inlining.cpp:31
#13 0x00007fff6bfe3c79 in nvfuser::MaxPosCalculator::MaxPosCalculator (this=0x7fffffffb650,
uninlinable_ids=std::unordered_set with 0 elements, compute_at_only=false) at /opt/pytorch/nvfuser/csrc/inlining.cpp:21
#14 0x00007fff6bfe55d5 in nvfuser::inlineAllAt (reference_tv=0x55555c0fe5a0, reference_pos=2, best_effort=true,
uninlinable_ids=std::unordered_set with 0 elements) at /opt/pytorch/nvfuser/csrc/inlining.cpp:290
#15 0x00007fff6c5b8c54 in nvfuser::schedulePointwise (fusion=0x55555a10de90, params=...)
at /opt/pytorch/nvfuser/csrc/scheduler/pointwise.cpp:789
#16 0x00007fff6c60757d in nvfuser::(anonymous namespace)::PointWiseScheduler::schedule (this=0x55555bdb4860, fusion=0x55555a10de90)
at /opt/pytorch/nvfuser/csrc/scheduler/registry.cpp:1598
#17 0x00007fff6c357abe in nvfuser::FusionKernelRuntime::runKernelWithInput (this=0x55555aacc470, args=..., sg=0x55555a9b2800)
at /opt/pytorch/nvfuser/csrc/kernel_cache.cpp:365
#18 0x00007fff6c359f55 in nvfuser::FusionKernelRuntime::runWithInput (this=0x55555aacc470, args=...)
at /opt/pytorch/nvfuser/csrc/kernel_cache.cpp:659
The problem comes from dereferencing `tv->definition()` without checking for `nullptr` at https://github.com/NVIDIA/Fuser/blob/main/csrc/root_domain_map.cpp#L794. Changing this line to
`(tv->definition() && tv->definition()->outputs().size() > 1) ||`
raises an informative (but uncaught) exception, and the Python script exits with a `RuntimeError`.
In a use case from nanoGPT, where the activations from the input linears of multi-head attention are split, we should generate a horizontal fusion with 3 parallel sequences of slice + reshape + permute. The resulting fusion from nvFuser gets segmented into 6 kernels, which is not great.
Repro:
import torch
from nvfuser import FusionDefinition, DataType
inputs = [
    torch.randn(16, 128, 3072, device='cuda'),
]


def nvfuser_fusion(fd: FusionDefinition) -> None:
    """Split one [16, 128, 3072] input three ways, then reshape and permute each part.

    Models the split of the input-linear activations in nanoGPT's multi-head
    attention. Ideally the three independent slice -> reshape -> permute chains
    end up in one horizontal fusion instead of being segmented into several
    kernels.
    """
    T0 = fd.from_pytorch(inputs[0])

    # Slice the last dimension into [0, 1024), [1024, 2048) and [2048, 3072).
    T0_slice1 = fd.ops.slice(T0, [0, 0, 0], [16, 128, 1024], [1, 1, 1])
    T0_slice2 = fd.ops.slice(T0, [0, 0, 1024], [16, 128, 2048], [1, 1, 1])
    T0_slice3 = fd.ops.slice(T0, [0, 0, 2048], [16, 128, 3072], [1, 1, 1])

    # [16, 128, 1024] -> [16, 128, 16, 64]
    # (presumably 16 attention heads of width 64 -- confirm against the model).
    T1_slice1 = fd.ops.reshape(T0_slice1, [16, 128, 1024], [16, 128, 16, 64])
    T1_slice2 = fd.ops.reshape(T0_slice2, [16, 128, 1024], [16, 128, 16, 64])
    T1_slice3 = fd.ops.reshape(T0_slice3, [16, 128, 1024], [16, 128, 16, 64])

    # Swap dims 1 and 2: [16, 128, 16, 64] -> [16, 16, 128, 64].
    T2_slice1 = fd.ops.permute(T1_slice1, [0, 2, 1, 3])
    T2_slice2 = fd.ops.permute(T1_slice2, [0, 2, 1, 3])
    T2_slice3 = fd.ops.permute(T1_slice3, [0, 2, 1, 3])

    fd.add_output(T2_slice1)
    fd.add_output(T2_slice2)
    fd.add_output(T2_slice3)


with FusionDefinition() as fd:
    nvfuser_fusion(fd)

out = fd.execute(inputs)
Nsys cmd:
nsys nvprof --print-gpu-trace python test.py
Nsys output:
Start (ns) Duration (ns) CorrId GrdX GrdY GrdZ BlkX BlkY BlkZ Reg/Trd StcSMem (MB) DymSMem (MB) Bytes (MB) Throughput (MBps) SrcMemKd DstMemKd Device Ctx Strm Name
---------- ------------- ------ ---- ---- ---- ---- ---- ---- ------- ------------ ------------ ---------- ----------------- -------- -------- -------------------- --- ---- ----------------------------------------------------------------------------------------------------
1307711678 23104 146 912 1 1 256 1 1 40 0.000 0.000 NVIDIA H100 PCIe (0) 1 7 void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
1499928416 6560 270 256 16 1 128 1 1 20 0.000 0.000 NVIDIA H100 PCIe (0) 1 7 CudaCodeGen::kernel1(CudaCodeGen::Tensor<float, (int)3>, CudaCodeGen::Tensor<float, (int)3>)
1648504226 6048 311 256 16 1 128 1 1 20 0.000 0.000 NVIDIA H100 PCIe (0) 1 7 CudaCodeGen::kernel2(CudaCodeGen::Tensor<float, (int)3>, CudaCodeGen::Tensor<float, (int)3>)
1796967171 7936 356 256 16 1 128 1 1 20 0.000 0.000 NVIDIA H100 PCIe (0) 1 7 CudaCodeGen::kernel3(CudaCodeGen::Tensor<float, (int)3>, CudaCodeGen::Tensor<float, (int)3>)
1949040639 11680 397 16 256 1 128 1 1 16 0.000 0.000 NVIDIA H100 PCIe (0) 1 7 CudaCodeGen::kernel4(CudaCodeGen::Tensor<float, (int)3>, CudaCodeGen::Tensor<float, (int)4>)
2101463421 11713 442 16 256 1 128 1 1 16 0.000 0.000 NVIDIA H100 PCIe (0) 1 7 CudaCodeGen::kernel5(CudaCodeGen::Tensor<float, (int)3>, CudaCodeGen::Tensor<float, (int)4>)
2253434746 11744 483 16 256 1 128 1 1 16 0.000 0.000 NVIDIA H100 PCIe (0) 1 7 CudaCodeGen::kernel6(CudaCodeGen::Tensor<float, (int)3>, CudaCodeGen::Tensor<float, (int)4>)
Analogous to XLA's pad operation:
(1) Easy fix for `setupDivMaxSoftmaxDropoutForward`: just change `contiguity({true, false, false, true})` to `contiguity({true, c10::nullopt, c10::nullopt, true})`.
(2) `BiasDropoutAddLayernormBwd1_fp32` fails in the assert:
TORCH_INTERNAL_ASSERT(
val.has_value(),
"Tried to evaluate the extent, ",
extent->toInlineString(),
" for the ptype: ",
p_type,
" to set launch bounds but could not.");
It works fine if this assert is skipped with:
if(!val.has_value()) {
continue;
}
reshape/view in nvfuser doesn't imply memory alias, so we'll be referring to this as reshape in this issue to keep the conversation simple and accurate.
nvfuser reshape is implemented via translating to a series of keep, merge and split:
Lines 20 to 63 in 86d5dd3
Currently we rely on some runtime checks to ensure that the reshape parsing, i.e. ViewOp
in the fusion, is still semantically correct. This works fine for our TorchScript integration, where we can rely on a guard
operator that queries the backend API
Fuser/csrc/register_interface.cpp
Lines 430 to 431 in 86d5dd3
This workflow is harder to do with our python integration though. There're a few reasons:
reshape
ops.reshape
node in FusionRecord would be lowered to different fusion based on input shapes, that's some nasty patching to the design. cc'ing @kevinstephano @jacobhinkle for reference.IIUC, we are moving forward with more plumbing to support our reshape
logic in python API, a few on-going items (cc'ing @csarofeen @naoyam for reference):
reshape
ops.nvfuser::analyzeViewConstraint
to our cache system, so that we can map the inferred shape to pick the right fusion object in order to pick up the right fusion.This is a lot of refactor that needs to happen in order for the new workflow to work. It feels like we are doing quite a lot plumbing on the codegen as well as the python API side in order to mimic a reshape
op in the codegen.
But in the end, we are not doing anything more than a decomposition. A decomposition would be much easier to perform and validate at program acquisition time. IIUC, the only missing piece that currently stops us from doing that is shape inference in our integration.
I know this is mostly just a design decision and we are pushing to expose nvfuser expression evaluation to client facing APIs. I'm not sure if we could really expect our expression evaluation to replace a shape inference mechanism on our integration, merely due to the fact that nvfuser op coverage is limited, and the awkward program flow where expression evaluation is only available after we have a fusion IR.
In segmented fusions, tensors passed from one segment to the next are allocated here:
https://github.com/NVIDIA/Fuser/blob/main/csrc/executor.cpp#L660-L710
Make sure these tensors are deallocated immediately after their uses are done.
For the example in `FusionAmpereSwizzle_CUDA`, the generated code contains trivial predicates:
#pragma unroll
for(nvfuser_index_t i653 = 0; i653 < 4; ++i653) {
int i10749;
i10749 = 32 * i653;
#pragma unroll
for(nvfuser_index_t i654 = 0; i654 < 8; ++i654) {
if (((nvfuser_index_t)blockIdx.x) < ((ceilDiv(T1.size[1], 128)) * 4)) {
Ampere::M16N8K16TN<16>(
reinterpret_cast<Array<float,4,4>*>(&T5[(i10749 + (2 * i654))]),
&(reinterpret_cast<Array<__half,8,8>*>(&T2)[i653]),
&(reinterpret_cast<Array<__half,4,4>*>(&T3)[i654]));
}
}
}
where `((nvfuser_index_t)blockIdx.x) < ((ceilDiv(T1.size[1], 128)) * 4)` is trivial because the right-hand side of `<` is identical to `gridDim.x`. We should simplify this trivial predicate.
On RTX 3090, the perf with and without that trivial predicate is 20.8374 ms vs. 16.1956 ms, respectively.
I have no idea why, but I am seeing the following failure non-deterministically since we became our own repo.
[ RUN ] LoopRotationTest.MultipleDoubleBuffer_CUDA
unknown file: Failure
C++ exception with description "aten_output_tensor.allclose( fusion_output_tensor.to(aten_output_tensor.dtype()), tolerance_values.second, tolerance_values.first, true) INTERNAL ASSERT FAILED at "/home/gaoxiang/Fuser/test/test_gpu_validator.h":400, please report a bug to PyTorch.
Validation error in output 0 on line 712 in file /home/gaoxiang/Fuser/test/test_loop_rotation.cpp.
Detected abs error of: 2.55172
absolute tolerance was set to 1.68222e-06
and relative tolerance set to 2.23704e-06
Exception raised from testValidate at /home/gaoxiang/Fuser/test/test_gpu_validator.h:400 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x5c (0x7f33003bccdc in /home/gaoxiang/pytorch-viable/build/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x64 (0x7f33003865f6 in /home/gaoxiang/pytorch-viable/build/lib/libc10.so)
frame #2: c10::detail::torchInternalAssertFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x4f (0x7f33003bacaf in /home/gaoxiang/pytorch-viable/build/lib/libc10.so)
frame #3: <unknown function> + 0x37c27d (0x560a19f3b27d in ./build/bin/nvfuser_tests)
frame #4: <unknown function> + 0x3813af (0x560a19f403af in ./build/bin/nvfuser_tests)
frame #5: void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x87 (0x560a1a096097 in ./build/bin/nvfuser_tests)
frame #6: testing::Test::Run() + 0xf6 (0x560a1a08a6b6 in ./build/bin/nvfuser_tests)
frame #7: <unknown function> + 0x4cb8b5 (0x560a1a08a8b5 in ./build/bin/nvfuser_tests)
frame #8: <unknown function> + 0x4cbfba (0x560a1a08afba in ./build/bin/nvfuser_tests)
frame #9: testing::internal::UnitTestImpl::RunAllTests() + 0x754 (0x560a1a08b9f4 in ./build/bin/nvfuser_tests)
frame #10: bool testing::internal::HandleExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) + 0x87 (0x560a1a096607 in ./build/bin/nvfuser_tests)
frame #11: testing::UnitTest::Run() + 0x91 (0x560a1a08a9d1 in ./build/bin/nvfuser_tests)
frame #12: <unknown function> + 0x119882 (0x560a19cd8882 in ./build/bin/nvfuser_tests)
frame #13: <unknown function> + 0x23790 (0x7f32cec3c790 in /usr/lib/libc.so.6)
frame #14: __libc_start_main + 0x8a (0x7f32cec3c84a in /usr/lib/libc.so.6)
frame #15: _start + 0x25 (0x560a19d079a5 in ./build/bin/nvfuser_tests)
" thrown in the test body.
[ FAILED ] LoopRotationTest.MultipleDoubleBuffer_CUDA (477 ms)
It was used to indicate that domain extents differ between input and output domains, but now that they are correctly modeled as different by `PairwiseRootDomainMap`, we may be able to remove it. At least that seems to be the case for `torch_gather`.
On my Titan RTX, the FusionPersistentSoftmaxLocalShared_CUDA test fails:
C++ exception with description "(dynamic_smem_size) < (available_dynamic_smem_without_reconfiguration + additional_dynamic_smem_available_through_reconfiguration) INTERNAL ASSERT FAILED at "/home/nmaruyama/pytorch/debug3/nvfuser/csrc/executor.cpp":910, please report a bug to PyTorch. The total shared memory allocation is larger than available memory. Dynamic size: 66048. Available size: 49136. Configured smem size: 49152. Device limit size: 65536
It seems this issue started with PR #148. It is unclear how that PR could affect shared memory usage.
from nvfuser import FusionDefinition, DataType
import torch
def nvfuser_fusion_id3(fd: FusionDefinition) -> None:
    """Define a fusion whose only output is a [4, 4] Double tensor filled with 0."""
    C0 = fd.define_constant(0, dtype=DataType.Int)
    T1 = fd.ops.full([4, 4], C0, dtype=DataType.Double)
    fd.add_output(T1)


with FusionDefinition() as fd:
    nvfuser_fusion_id3(fd)

out = fd.execute([])
# Print the Python regenerated from this definition. In the reported output the
# `full` call comes back with its arguments in the wrong positions, so the
# printed code is not valid.
print(fd)
gives:
def nvfuser_fusion_id3(fd : FusionDefinition) -> None :
S0 = fd.define_constant(0, dtype=DataType.Int)
T1 = fd.ops.full(S0, shape=[4, 4], dtype=DataType.Double)
fd.add_output(T1)
The above code is invalid because `full`'s signature is `full(size, fill_value, dtype)`.
We do not support parallelization of non-leaf domains. For example, it's an error if a parallelized domain is split:
https://github.com/NVIDIA/Fuser/blob/main/csrc/tensor_view.cpp#L769-L774
However, this isn't enforced in `IterDomain::parallelize(ParallelType)`, so it's possible to, e.g., manually parallelize non-leaf domains, and in some complex fusions it may happen through parallel type propagation. It indeed happened accidentally with the inner-outer scheduler (csarofeen/pytorch#2400).
The purpose of this issue is to communicate the planned change to nvFuser that I'd like to make to enable matmul NN support. The entire change will take multiple PRs; this issue provides the big picture to help explain the motivation behind them. Another purpose is to get the design reviewed early.
The challenge with NN is not "how to schedule an NN matmul", but rather "how to define an NN matmul". Currently, `MmaOp` can be considered a fused multiply-sum, and the entire matmul is defined as broadcasts followed by `MmaOp`:
The currently supported layouts are TT, TN, and NT:
- TT: the inputs are `[M, K]` and `[K, N]`. We broadcast them into `[M, K, 1]` and `[1, K, N]`; after multiplication we have `[M, K, N]`, and after reduction we get `[M, N]`.
- TN: the inputs are `[M, K]` and `[N, K]`. We broadcast them into `[M, 1, K]` and `[1, N, K]`; after multiplication we have `[M, N, K]`, and after reduction we get `[M, N]`.
- NT: the inputs are `[K, M]` and `[K, N]`. We broadcast them into `[K, M, 1]` and `[K, 1, N]`; after multiplication we have `[K, M, N]`, and after reduction we get `[M, N]`.

But for NN, the input shapes are `[K, M]` and `[N, K]`, and there is no such way to get an output of `[M, N]`; it is only possible to get `[N, M]`. To support NN, the fusion definition must change.
Our design should be compatible with our future needs, although these needs might not be our priority today. Below are some examples that I think are related to this topic:
`(A.T @ B).T`. We should allow users to insert transposes anywhere in the fusion definition, and nvFuser should be able to find the correct layout for the matmul with these transposes taken into consideration.
From the perspective of the user (the person who creates a fusion), fusion definitions should be as flexible as possible. Users shouldn't worry about performance when defining a fusion; the only thing a user should worry about is mathematical correctness. The matmul scheduler should accept all mathematically equivalent fusions and schedule all of them optimally. It is an unacceptable experience if the user needs to think about things like the following when defining the fusion:
I want to do an Ampere matmul, Ampere has an ldmatrix.trans instruction that allows transpose in the smem to register load, and the mma op is always TN, so I will transpose the inputs to [M, K] and [N, K] first, and then do mma.
Instead, for the `[M, K] @ [K, N] -> [M, N]` matmul, for example, users should be able to define the fusion in whatever way is most convenient to them. Options include:
For lowering to be easy, the fusion definition should be as close to the hardware as possible. However, different architectures have different matmul flavors, so there is no single "canonical form" that is close to all hardware.
On Volta, matrices are loaded into registers in the same layout as the inputs, and the mma op has TT, TN, NT, and NN variants. So on Volta, the matmul definition closest to hardware is
On Turing & Ampere, the layout of the matrices in shared memory is the same as that of the inputs. However, when shared-memory matrices are loaded into registers (using `ldmatrix`/`ldmatrix.trans`), the layouts of the matrices always become TN, and there is only one variant of mma, which is TN. So on Turing & Ampere, the matmul definition closest to hardware is
On Hopper, there are two variants of `wgmma`. In the "rs" variant, the first operand is in registers and the second operand is in shared memory. In the "ss" variant, both operands are in shared memory. For the "rs" variant, the first operand must be `[M, K]`, and the second operand can be either `[N, K]` or `[K, N]`. For the "ss" variant, the first operand can be `[M, K]` or `[K, M]`, and the second operand can be `[N, K]` or `[K, N]`. So the matmul definition closest to hardware is:
As described in the last section, the principles of easiness to define a fusion and easiness of lowering would lead to different fusion definitions. For a given hardware, the "easiness of lowering" form is the canonical form that the scheduler needs to transform the fusion into. In order to be able to define the canonical form and transform to the canonical form, the following concepts will be added/changed for nvFuser:
An implicit transpose is a transpose that happens in the rFactor domain of a tensor whose definition is not a `TransposeOp`. Starting from #148, all transposes are implicit (because we no longer have a `TransposeOp`), but so far we only apply implicit transposes to tensors defined by `LoadStoreOp`. To define canonical forms of the mma op, we would need to make the output tensor of `MmaOp` implicitly transposed as well.
I would like to go further and support implicit transformations not only on `MmaOp`'s output, but on all tensors.
For example, suppose I have `T1 = sin(T0)`, where `T0` has shape `[I0, 1, I1, I2]` and `T1` has
root domain: [I0, 1, I1, I2]
rFactor domain: [I2, I0*I1, 1]
Then the output `T1`'s shape will be `[I2, I0*I1, 1]`. That is, `T1` does an implicit squeeze-view-transpose-broadcast.
The reasons for doing so are:
- From the point of view of how `IterDomain`s are transformed, we don't really care whether the tensor is defined by `LoadStoreOp` or `UnaryOpType::ReLU`. I believe this new approach can free up more flexibility without adding too much complexity to our existing system. In fact, I think our system will have fewer lines of code in total, because we can define many of these ops as `LoadStoreOp` instead of giving each one its own `Expr` subclass.
- `BroadcastOp`, `SqueezeOp`, `ViewOp`, `TransposeOp`, etc. are essentially just data-copying operations with fancy indexing. In the generated C++ code, I am not a fan of reading things like
  T1[i1] = T0[i0];
  T2[i2] = sin(T1[i1]);
  when it could simply be `T2[i2] = sin(T0[i0])`. So I want to reduce redundant copies in the generated C++. For the case of matmul, I would like to have a single `ldmatrix.trans` that does transpose + broadcast together, so that I don't have to deal with this extra broadcast.

I would like to add two methods to `TensorView`: `pushRFactorForward` and `pushRFactorBackward`.
Assume we have the following fusion: T0 --set--> T1 --set--> T2
where
T0: root = [I0, I1], no rfactor
T1: root = [I0, I1], rfactor = [I1, I0]
T2: root = [I1, I0], no rfactor
Then pushRFactorForward
will transform the fusion as
T0: root = [I0, I1], no rfactor
T1: root = [I0, I1], no rfactor
T2: root = [I0, I1], rfactor = [I1, I0]
And pushRFactorBackward
will transform the fusion as
T0: root = [I0, I1], rfactor = [I1, I0]
T1: root = [I1, I0], no rfactor
T2: root = [I1, I0], no rfactor
It is possible to specify intermediate state to only push part of the transformations. For example, if the fusion has T0--view-->T1--sin-->T2
and
T0: root = [I0, I1, I2], no rfactor
T1: root = [I0, I1, I2], rfactor = [I1*I0/4, 4, I2]
T2: root = [I1*I0/4, 4, I2], no rfactor
Then pushRFactorBackward([I1*I0, I2])
will get
T0: root = [I0, I1, I2], rfactor = [I1*I0, I2]
T1: root = [I1*I0, I2], rfactor = [I1*I0/4, 4, I2]
T2: root = [I1*I0/4, 4, I2], no rfactor
Note that in this case, `T0` might have an implicit view, that is, a view that happens implicitly at some op that is not a `ViewOp`. And `pushRFactorForward([I1*I0, I2])`
will get
T0: root = [I0, I1, I2], no rfactor
T1: root = [I0, I1, I2], rfactor = [I1*I0, I2]
T2: root = [I1*I0, I2], rfactor = [I1*I0/4, 4, I2]
Note that in this case, `T2` will have an implicit view.
If the current tensor has more than one producer, a backward push will push its rFactor domain to all of its producers. Similarly, if the current tensor has multiple consumers, a forward push will push it to all of its consumers. It gets more complicated: if `c = op(a, b)`, then `a->pushRFactorForward()` will modify both `b` and `c`. The same applies to `pushRFactorBackward` on a tensor with multiple uses.
Not all pushes are valid. A push must be compatible with the current schedule. For example, if you have T0->T1
where
T0: root = [I0, I1, I2], no rfactor, leaf = [I0, I1*I2]
T1: root = [I0, I1, I2], rfactor = [I0*I1, I2], leaf = [I0*I1, I2]
Then T1->pushRFactorBackward()
is illegal. However, if
T0: root = [I0, I1, I2], no rfactor, leaf = [(I0*I1)*I2]
T1: root = [I0, I1, I2], rfactor = [I0*I1, I2], leaf = [I0*I1, I2]
Then T1->pushRFactorBackward()
will get
T0: root = [I0, I1, I2], rfactor = [I0*I1, I2], leaf = [(I0*I1)*I2]
T1: root = [I0*I1, I2], no rfactor, leaf = [I0*I1, I2]
There may be other types of invalid push, but discussing this is not in the scope of this issue.
The goal of these two methods is to make manipulating rFactor domains super easy. I believe that, in the long term, this functionality would be helpful for `view` and `resize` scheduling. For the case of matmul, these two methods should make it easy to canonicalize the user's input into the hardware's flavor. For example, the following transformation can be used to schedule an NN matmul for Ampere:
`MmaOp` and `Mul->Sum` are two mathematically equivalent ways to define a matmul. I prefer to define the fusion as `Mul->Sum` in the user-facing API, and let the scheduler convert it to `MmaOp` and fill in the information.
Currently we have remaining things to handle:
Masaki ran into a build failure where a bare-metal build doesn't have permission to write to `/usr/local/include`:
-- Up-to-date: /usr/local/include
CMake Error at third_party/googletest/googlemock/cmake_install.cmake:46 (file):
file INSTALL cannot set permissions on "/usr/local/include": Operation not
permitted.
Call Stack (most recent call first):
third_party/googletest/cmake_install.cmake:47 (include)
cmake_install.cmake:47 (include)
The pip build also seems to write to global folders, as reported by @t-vi in #220 (comment).
switch to setuptools_scm for versioning instead of using home brewed methods. Requested by @zasdfgbnm #161 (comment)
switch to lazyNVRTC on pytorch for driver&nvrtc calls, instead of linking directly against driver & nvrtc library. This may or may not be a necessary thing or good idea, if we want to ship nvfuser separate from pytorch later. For torch GPU build running on cpu only environment, we might be able to guard the prim import of nvfuser to disable it on those systems.
^^^ note that we are not doing this. Instead, we stick to linking against libnvrtc and driver directly. We are currently relying on pip installed libnvrtc on runtime system. #176
Unnecessary recompilation: `Building CXX object CMakeFiles/nvfuser_codegen.dir/csrc/executor_utils.cpp.o` happens every time you run `python setup.py build`.
For tensors, we do a dependency analysis to find the expressions that are actually used to produce fusion outputs. For scalars, that's not the case: `Val::uses()` can contain expressions that are outside of the given fusion. This hasn't been an issue, likely because they are just scalars, but it is inconsistent with tensors, whose uses are truly within the fusion.
I think this can be fixed by using the `traverse_members` and `traverse_attributes` options in this `getExprs`:
https://github.com/NVIDIA/Fuser/blob/main/csrc/fusion.cpp#L586
BUILD_NVFUSER=0 python setup.py develop
$ cd /path/to/nvidia/fuser
$ pip install -v -e .
Using pip 22.3.1 from /usr/local/lib/python3.10/site-packages/pip (python 3.10)
Obtaining file:///opt/pytorch/nvfuser
Preparing metadata (setup.py): started
Running command python setup.py egg_info
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
...
running egg_info
creating /tmp/pip-pip-egg-info-8qqasdji/nvfuser.egg-info
writing /tmp/pip-pip-egg-info-8qqasdji/nvfuser.egg-info/PKG-INFO
writing dependency_links to /tmp/pip-pip-egg-info-8qqasdji/nvfuser.egg-info/dependency_links.txt
writing entry points to /tmp/pip-pip-egg-info-8qqasdji/nvfuser.egg-info/entry_points.txt
writing top-level names to /tmp/pip-pip-egg-info-8qqasdji/nvfuser.egg-info/top_level.txt
writing manifest file '/tmp/pip-pip-egg-info-8qqasdji/nvfuser.egg-info/SOURCES.txt'
reading manifest file '/tmp/pip-pip-egg-info-8qqasdji/nvfuser.egg-info/SOURCES.txt'
adding license file 'LICENSE'
writing manifest file '/tmp/pip-pip-egg-info-8qqasdji/nvfuser.egg-info/SOURCES.txt'
Traceback (most recent call last):
File "<string>", line 2, in <module>
File "<pip-setuptools-caller>", line 34, in <module>
File "/opt/pytorch/nvfuser/setup.py", line 340, in <module>
main()
File "/opt/pytorch/nvfuser/setup.py", line 336, in main
subprocess.check_call(["patch-nvfuser"])
File "/usr/local/lib/python3.10/subprocess.py", line 364, in check_call
retcode = call(*popenargs, **kwargs)
File "/usr/local/lib/python3.10/subprocess.py", line 345, in call
with Popen(*popenargs, **kwargs) as p:
File "/usr/local/lib/python3.10/subprocess.py", line 971, in __init__
self._execute_child(args, executable, preexec_fn, close_fds,
File "/usr/local/lib/python3.10/subprocess.py", line 1847, in _execute_child
raise child_exception_type(errno_num, err_msg, err_filename)
FileNotFoundError: [Errno 2] No such file or directory: 'patch-nvfuser'
error: subprocess-exited-with-error
× python setup.py egg_info did not run successfully.
│ exit code: 1
╰─> See above for output.
...
The entire log is: https://gist.github.com/crcrpar/35d308ad57e6bbb789d34a7463d48688.
There's a crash with `nan` constants for some reason; `float("inf")` works!
from nvfuser import FusionDefinition, DataType
import torch
def nvfuser_fusion_id(fd: FusionDefinition) -> None:
    """Record a fusion that adds a NaN scalar constant to a 1-D double tensor.

    The single input is a contiguous, symbolically sized (``[-1]``) Double
    tensor that is not on the CPU (``is_cpu=False``). The single output is
    ``input + nan``. This is the minimal repro for the crash reported with
    ``nan`` constants.
    """
    double_input = fd.define_tensor(
        symbolic_sizes=[-1],
        contiguous=[True],
        dtype=DataType.Double,
        is_cpu=False,
    )
    nan_scalar = fd.define_constant(float("nan"))
    fd.add_output(fd.ops.add(double_input, nan_scalar))
# Record the fusion, then execute it on a random 1-D double tensor of length 8
# on the CUDA device; this execution is the step reported to crash when the
# fusion contains a NaN constant.
with FusionDefinition() as fd:
    nvfuser_fusion_id(fd)
a = torch.randn(8, dtype=torch.double, device='cuda')
fd.execute([a])
I am seeing
Traceback (most recent call last):
File "/home/gaoxiang/Fuser/setup.py", line 340, in <module>
main()
File "/home/gaoxiang/Fuser/setup.py", line 336, in main
subprocess.check_call(["patch-nvfuser"])
File "/usr/lib/python3.10/subprocess.py", line 364, in check_call
retcode = call(*popenargs, **kwargs)
File "/usr/lib/python3.10/subprocess.py", line 345, in call
with Popen(*popenargs, **kwargs) as p:
File "/usr/lib/python3.10/subprocess.py", line 971, in __init__
self._execute_child(args, executable, preexec_fn, close_fds,
File "/usr/lib/python3.10/subprocess.py", line 1847, in _execute_child
raise child_exception_type(errno_num, err_msg, err_filename)
FileNotFoundError: [Errno 2] No such file or directory: 'patch-nvfuser'
at the end of the build.
Also, at the beginning, I am seeing:
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
==== importing nvfuser failed ====
try run `patch-nvfuser` if https://github.com/NVIDIA/Fuser is installed via pip package
Should we use the `patch-nvfuser` script in the repo instead of the one on the user's PATH? Users might not have it on their PATH.
[ 89%] Linking CXX executable nvfuser_tests
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/csrc/kernel_db/test/test_nvfuser_kernel_db_open.cpp.o: in function `nvfuser::NVFuserTest_KernelDb_Open_CUDA_Test::TestBody()':
test_nvfuser_kernel_db_open.cpp:(.text+0x384c): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: test_nvfuser_kernel_db_open.cpp:(.text+0x3a02): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: test_nvfuser_kernel_db_open.cpp:(.text+0x3b5f): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: test_nvfuser_kernel_db_open.cpp:(.text+0x3c8a): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: test_nvfuser_kernel_db_open.cpp:(.text+0x3e31): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/csrc/kernel_db/test/test_nvfuser_kernel_db_open.cpp.o:test_nvfuser_kernel_db_open.cpp:(.text+0x3f82): more undefined references to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)' follow
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/main.cpp.o: in function `add_negative_flag(std::string const&)':
main.cpp:(.text+0x2b): undefined reference to `testing::FLAGS_gtest_filter'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/main.cpp.o: in function `main':
main.cpp:(.text+0x1c3): undefined reference to `testing::FLAGS_gtest_filter'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu1.cpp.o: in function `testing::internal::PrintTo(std::string const&, std::ostream*)':
test_gpu1.cpp:(.text._ZN7testing8internal7PrintToERKSsPSo[_ZN7testing8internal7PrintToERKSsPSo]+0x23): undefined reference to `testing::internal::PrintStringTo(std::string const&, std::ostream*)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu1.cpp.o: in function `testing::AssertionResult testing::internal::CmpHelperEQFailure<std::string, std::string>(char const*, char const*, std::string const&, std::string const&)':
test_gpu1.cpp:(.text._ZN7testing8internal18CmpHelperEQFailureISsSsEENS_15AssertionResultEPKcS4_RKT_RKT0_[_ZN7testing8internal18CmpHelperEQFailureISsSsEENS_15AssertionResultEPKcS4_RKT_RKT0_]+0x7f): undefined reference to `testing::internal::EqFailure(char const*, char const*, std::string const&, std::string const&, bool)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu2.cpp.o: in function `nvfuser::NVFuserTest_FusionParallelDimensionMap3_CUDA_Test::TestBody()':
test_gpu2.cpp:(.text+0x9315b): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: test_gpu2.cpp:(.text+0x933b2): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu2.cpp.o: in function `testing::AssertionResult testing::internal::CmpHelperEQFailure<c10::optional<long>, int>(char const*, char const*, c10::optional<long> const&, int const&)':
test_gpu2.cpp:(.text._ZN7testing8internal18CmpHelperEQFailureIN3c108optionalIlEEiEENS_15AssertionResultEPKcS7_RKT_RKT0_[_ZN7testing8internal18CmpHelperEQFailureIN3c108optionalIlEEiEENS_15AssertionResultEPKcS7_RKT_RKT0_]+0x7f): undefined reference to `testing::internal::EqFailure(char const*, char const*, std::string const&, std::string const&, bool)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu3.cpp.o: in function `nvfuser::NVFuserMultithreadedTest_SingleFunction_CUDA_Test::TestBody()::{lambda()#1}::operator()() const':
test_gpu3.cpp:(.text+0x38f88): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu3.cpp.o: in function `nvfuser::NVFuserMultithreadedTest_MultipleFunctions_CUDA_Test::TestBody()::{lambda()#1}::operator()() const':
test_gpu3.cpp:(.text+0x39b9e): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu3.cpp.o: in function `nvfuser::NVFuserTest_FusionIssue2074_CUDA_Test::TestBody()':
test_gpu3.cpp:(.text+0x781f9): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu3.cpp.o: in function `nvfuser::NVFuserTest_FusionIssue2077_CUDA_Test::TestBody()':
test_gpu3.cpp:(.text+0x78c2d): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu3.cpp.o: in function `nvfuser::NVFuserTest_FusionIssue2372_CUDA_Test::TestBody()':
test_gpu3.cpp:(.text+0x79b4f): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu3.cpp.o:test_gpu3.cpp:(.text+0x79cb2): more undefined references to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)' follow
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu3.cpp.o: in function `testing::AssertionResult::AppendMessage(testing::Message const&)':
test_gpu3.cpp:(.text._ZN7testing15AssertionResult13AppendMessageERKNS_7MessageE[_ZN7testing15AssertionResult13AppendMessageERKNS_7MessageE]+0x8a): undefined reference to `testing::Message::GetString() const'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu3.cpp.o: in function `testing::AssertionResult testing::internal::CmpHelperEQFailure<c10::ScalarType, c10::ScalarType>(char const*, char const*, c10::ScalarType const&, c10::ScalarType const&)':
test_gpu3.cpp:(.text._ZN7testing8internal18CmpHelperEQFailureIN3c1010ScalarTypeES3_EENS_15AssertionResultEPKcS6_RKT_RKT0_[_ZN7testing8internal18CmpHelperEQFailureIN3c1010ScalarTypeES3_EENS_15AssertionResultEPKcS6_RKT_RKT0_]+0x7f): undefined reference to `testing::internal::EqFailure(char const*, char const*, std::string const&, std::string const&, bool)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu3.cpp.o: in function `testing::AssertionResult testing::internal::CmpHelperEQFailure<std::vector<nvfuser::Val*, std::allocator<nvfuser::Val*> >, std::vector<nvfuser::Val*, std::allocator<nvfuser::Val*> > >(char const*, char const*, std::vector<nvfuser::Val*, std::allocator<nvfuser::Val*> > const&, std::vector<nvfuser::Val*, std::allocator<nvfuser::Val*> > const&)':
test_gpu3.cpp:(.text._ZN7testing8internal18CmpHelperEQFailureISt6vectorIPN7nvfuser3ValESaIS5_EES7_EENS_15AssertionResultEPKcSA_RKT_RKT0_[_ZN7testing8internal18CmpHelperEQFailureISt6vectorIPN7nvfuser3ValESaIS5_EES7_EENS_15AssertionResultEPKcSA_RKT_RKT0_]+0x7f): undefined reference to `testing::internal::EqFailure(char const*, char const*, std::string const&, std::string const&, bool)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu3.cpp.o: in function `testing::AssertionResult testing::internal::CmpHelperEQFailure<nvfuser::Val*, nvfuser::Val*>(char const*, char const*, nvfuser::Val* const&, nvfuser::Val* const&)':
test_gpu3.cpp:(.text._ZN7testing8internal18CmpHelperEQFailureIPN7nvfuser3ValES4_EENS_15AssertionResultEPKcS7_RKT_RKT0_[_ZN7testing8internal18CmpHelperEQFailureIPN7nvfuser3ValES4_EENS_15AssertionResultEPKcS7_RKT_RKT0_]+0x7f): undefined reference to `testing::internal::EqFailure(char const*, char const*, std::string const&, std::string const&, bool)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu3.cpp.o: in function `testing::AssertionResult testing::internal::CmpHelperEQFailure<unsigned long, long>(char const*, char const*, unsigned long const&, long const&)':
test_gpu3.cpp:(.text._ZN7testing8internal18CmpHelperEQFailureImlEENS_15AssertionResultEPKcS4_RKT_RKT0_[_ZN7testing8internal18CmpHelperEQFailureImlEENS_15AssertionResultEPKcS4_RKT_RKT0_]+0x7f): undefined reference to `testing::internal::EqFailure(char const*, char const*, std::string const&, std::string const&, bool)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_expr_simplifier.cpp.o: in function `nvfuser::ExprSimplifierTest_Compare_CUDA_Test::TestBody()':
test_expr_simplifier.cpp:(.text+0xe8c0): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: test_expr_simplifier.cpp:(.text+0xe9fe): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: test_expr_simplifier.cpp:(.text+0xeb3c): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: test_expr_simplifier.cpp:(.text+0xec74): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: test_expr_simplifier.cpp:(.text+0xedb3): undefined reference to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_expr_simplifier.cpp.o:test_expr_simplifier.cpp:(.text+0xeeeb): more undefined references to `testing::internal::GetBoolAssertionFailureMessage(testing::AssertionResult const&, char const*, char const*, char const*)' follow
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_expr_simplifier.cpp.o: in function `testing::AssertionResult testing::internal::CmpHelperEQFailure<std::string, char [98]>(char const*, char const*, std::string const&, char const (&) [98])':
test_expr_simplifier.cpp:(.text._ZN7testing8internal18CmpHelperEQFailureISsA98_cEENS_15AssertionResultEPKcS5_RKT_RKT0_[_ZN7testing8internal18CmpHelperEQFailureISsA98_cEENS_15AssertionResultEPKcS5_RKT_RKT0_]+0x7f): undefined reference to `testing::internal::EqFailure(char const*, char const*, std::string const&, std::string const&, bool)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_expr_simplifier.cpp.o: in function `testing::AssertionResult testing::internal::CmpHelperEQFailure<std::string, char [30]>(char const*, char const*, std::string const&, char const (&) [30])':
test_expr_simplifier.cpp:(.text._ZN7testing8internal18CmpHelperEQFailureISsA30_cEENS_15AssertionResultEPKcS5_RKT_RKT0_[_ZN7testing8internal18CmpHelperEQFailureISsA30_cEENS_15AssertionResultEPKcS5_RKT_RKT0_]+0x7f): undefined reference to `testing::internal::EqFailure(char const*, char const*, std::string const&, std::string const&, bool)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_expr_simplifier.cpp.o: in function `testing::AssertionResult testing::internal::CmpHelperEQFailure<c10::optional<bool>, bool>(char const*, char const*, c10::optional<bool> const&, bool const&)':
test_expr_simplifier.cpp:(.text._ZN7testing8internal18CmpHelperEQFailureIN3c108optionalIbEEbEENS_15AssertionResultEPKcS7_RKT_RKT0_[_ZN7testing8internal18CmpHelperEQFailureIN3c108optionalIbEEbEENS_15AssertionResultEPKcS7_RKT_RKT0_]+0x7f): undefined reference to `testing::internal::EqFailure(char const*, char const*, std::string const&, std::string const&, bool)'
/usr/bin/ld: CMakeFiles/nvfuser_tests.dir/test/test_gpu_swizzle.cpp.o: in function `testing::AssertionResult testing::internal::CmpHelperEQFailure<std::vector<int, std::allocator<int> >, std::vector<int, std::allocator<int> > >(char const*, char const*, std::vector<int, std::allocator<int> > const&, std::vector<int, std::allocator<int> > const&)':
test_gpu_swizzle.cpp:(.text._ZN7testing8internal18CmpHelperEQFailureISt6vectorIiSaIiEES4_EENS_15AssertionResultEPKcS7_RKT_RKT0_[_ZN7testing8internal18CmpHelperEQFailureISt6vectorIiSaIiEES4_EENS_15AssertionResultEPKcS7_RKT_RKT0_]+0x7f): undefined reference to `testing::internal::EqFailure(char const*, char const*, std::string const&, std::string const&, bool)'
collect2: error: ld returned 1 exit status
make[2]: *** [CMakeFiles/nvfuser_tests.dir/build.make:614: nvfuser_tests] Error 1
make[1]: *** [CMakeFiles/Makefile2:407: CMakeFiles/nvfuser_tests.dir/all] Error 2
make: *** [Makefile:136: all] Error 2
Traceback (most recent call last):
File "setup.py", line 273, in <module>
main()
File "setup.py", line 234, in main
cmake()
File "setup.py", line 229, in cmake
subprocess.check_call(cmd_str)
File "/usr/lib/python3.8/subprocess.py", line 364, in check_call
raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['cmake', '--build', 'build', '--target', 'install', '--', '-j', '64']' returned non-zero exit status 2.
I saw two failures, one in lintrunner and the other in clang-format, when one of my files was not formatted. I don't think it is a problem, as clang-format is pretty fast; it just feels a bit weird to have this duplication. We should probably consider removing
Fuser/.github/workflows/lint.yml
Lines 110 to 128 in 8c9bb39
Every time I run the entire test suite, I see the following error:
[ RUN ] NVFuserTest.FusionResizePad6_CUDA
unknown file: Failure
C++ exception with description "CUDA driver error: out of memory
Exception raised from runFusion at /home/gaoxiang/Fuser/csrc/executor.cpp:1358 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x5c (0x7ff1e37bcacc in /home/gaoxiang/pytorch-viable/build/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x64 (0x7ff1e3785644 in /home/gaoxiang/pytorch-viable/build/lib/libc10.so)
frame #2: nvfuser::FusionExecutor::runFusion(nvfuser::KernelArgumentHolder&, nvfuser::LaunchParams const&, nvfuser::CompileParams, std::vector<at::Tensor, std::allocator<at::Tensor> > const&) + 0x39ac (0x7ff217be919c in /home/gaoxiang/pytorch-viable/build/lib/libnvfuser_codegen.so)
frame #3: <unknown function> + 0x28fa9b (0x55a8fe871a9b in ./build/bin/nvfuser_tests)
frame #4: <unknown function> + 0x28fbf0 (0x55a8fe871bf0 in ./build/bin/nvfuser_tests)
frame #5: <unknown function> + 0x4a2f90 (0x55a8fea84f90 in ./build/bin/nvfuser_tests)
frame #6: void testing::internal::HandleExceptionsInMethodIfSupported<testing::Test, void>(testing::Test*, void (testing::Test::*)(), char const*) + 0x87 (0x55a8feba2ea7 in ./build/bin/nvfuser_tests)
frame #7: <unknown function> + 0x5b5826 (0x55a8feb97826 in ./build/bin/nvfuser_tests)
frame #8: <unknown function> + 0x5b5a62 (0x55a8feb97a62 in ./build/bin/nvfuser_tests)
frame #9: <unknown function> + 0x5b615a (0x55a8feb9815a in ./build/bin/nvfuser_tests)
frame #10: testing::internal::UnitTestImpl::RunAllTests() + 0x6d4 (0x55a8feb98b14 in ./build/bin/nvfuser_tests)
frame #11: bool testing::internal::HandleExceptionsInMethodIfSupported<testing::internal::UnitTestImpl, bool>(testing::internal::UnitTestImpl*, bool (testing::internal::UnitTestImpl::*)(), char const*) + 0x87 (0x55a8feba3417 in ./build/bin/nvfuser_tests)
frame #12: testing::UnitTest::Run() + 0x91 (0x55a8feb97b81 in ./build/bin/nvfuser_tests)
frame #13: <unknown function> + 0x1b8572 (0x55a8fe79a572 in ./build/bin/nvfuser_tests)
frame #14: <unknown function> + 0x23790 (0x7ff1e2e3c790 in /usr/lib/libc.so.6)
frame #15: __libc_start_main + 0x8a (0x7ff1e2e3c84a in /usr/lib/libc.so.6)
frame #16: _start + 0x25 (0x55a8fe7c8375 in ./build/bin/nvfuser_tests)
" thrown in the test body.
[ FAILED ] NVFuserTest.FusionResizePad6_CUDA (15925 ms)
This failure is quite reproducible; it fails on my system almost every time. I checked FusionResizePad6_CUDA and didn't see anything that uses a lot of memory, so I don't think FusionResizePad6_CUDA itself is the test to blame.
On Hopper, efficient GEMM requires warp specialization, which nvFuser does not currently support. This doc proposes extending nvFuser to support such optimizations. I believe this will benefit not only matmul but also other cases such as optimal cat/stack scheduling and horizontal fusion; see the "Potential applications" section for more detail.
Notation: I will mostly use the term "task parallelism" for the new thing being added to nvFuser. "Warp-specialization" is a special case of "vertical task parallelism" (described below) on thread index.
In order to use task parallelism, we first need to partition the DAG into tasks. Tasks are non-overlapping and dense: every `Val` in the fusion definition except fusion inputs belongs to a task (fusion inputs are special because they are given rather than computed), and each `Val` can belong to only one task. Initially, all `Val`s belong to task 0. Example partition:
Tasks are further grouped into task groups. Task groups form a hierarchical structure.
A task group can be parallelized, for example
group1->parallelize(ParallelType::TIDy);
group3->parallelize(ParallelType::TIDz);
Not all task groups can be parallelized. A parallelizable group is either a "horizontal group" or a "vertical group". A "horizontal group" is a group whose members have no data dependency with each other. For example, group 1 is a horizontal group. A "vertical group" is a group whose members are connected, for example group 3 is a vertical group.
Below is an example where group 4 is neither a horizontal group nor a vertical group:
However, you can make it a horizontal group by grouping group 2 and group 3 together:
Expression sorting must be task and task group aware. For the above example, the sorted expressions can be
group 3
which zoom into
group 1
group 2
which zoom into
task 1
task 2
task 3
task 4
task 5
// the following order is also valid, and expr sort is free to choose one from all valid orders
task 3
task 4
task 1
task 2
task 5
// There are more valid orders...
// In later context, we assume the first order is picked
which zoom into
tv1 = sin(tv0);
tv5 = cos(tv1);
tv2 = tan(tv0);
tv6 = relu(tv2);
tv3 = exp(tv0);
tv7 = sigmoid(tv3);
tv4 = log(tv0);
tv8 = tanh(tv4);
tv9 = cat([tv5, tv6, tv7, tv8]);
tv10 = neg(tv9);
When generating the loop nest, an unparallelized group simply generates its members one after another. For parallelized groups, it generates `kir::IfThenElse`s to dispatch between its members.
Assuming `tv0`-`tv8` have `[BIDx, TIDy{size0}, TIDx]`, and `tv9` and `tv10` have `[BIDx, TIDy{size0*4}, TIDx]` (the cat dim is 1; I am assuming the cat dim is untouched by the scheduler in this example), and everything is inlined as much as possible ("inline most").
Then the generated loop nest structure will be
FOR blockIdx.x in BIDx:
IF threadIdx.z == 0:
IF threadIdx.y >= 0 && threadIdx.y < size0:
FOR threadIdx.y in TIDy{size0}:
FOR threadIdx.x in TIDx:
tv1 = sin(tv0);
tv5 = cos(tv1);
ELSE IF threadIdx.y - size0 >= 0 && threadIdx.y - size0 < size0:
FOR threadIdx.y - size0 in TIDy{size0}:
FOR threadIdx.x in TIDx:
tv2 = tan(tv0);
tv6 = relu(tv2);
ELSE IF threadIdx.y - 2*size0 >= 0 && threadIdx.y - 2*size0 < size0:
FOR threadIdx.y - 2*size0 in TIDy{size0}:
FOR threadIdx.x in TIDx:
tv3 = exp(tv0);
tv7 = sigmoid(tv3);
ELSE IF threadIdx.y - 3*size0 >= 0 && threadIdx.y - 3*size0 < size0:
FOR threadIdx.y - 3*size0 in TIDy{size0}:
FOR threadIdx.x in TIDx:
tv4 = log(tv0);
tv8 = tanh(tv4);
ELSE IF threadIdx.z == 1:
FOR threadIdx.y in TIDy{4*size0}:
FOR threadIdx.x in TIDx:
tv9 = cat([tv5, tv6, tv7, tv8]);
tv10 = neg(tv9);
For the parallelization of horizontal task groups, synchronization must happen before and after the dispatch. Depending on the parallel type, a block sync or a grid sync might be needed. For the parallelization of vertical task groups (a.k.a. warp specialization), the parallelization boundary (in this case `tv5`-`tv8`) must be double/circular buffered, and an arrive-wait barrier is used for synchronization.
Warp specialization is used, and we are doing load and mma+store in different warps.
For example, suppose we have multiple separate fusions, independently scheduled. If all fusions use only `BIDx` and not `BIDy` or `BIDz`, then we can trivially fuse them horizontally by making each fusion a task in the combined fusion and horizontally parallelizing these tasks on `BIDx`.
For cat/stack, the size of the output tensor is naturally the sum of the sizes of the inputs, so we could parallelize the computation of the inputs in a way similar to the parallelization of group 1 in the above example.
This is a pretty involved topic. The original question arises from @jacobhinkle's comment on the lack of codegen support for padding broadcast dimensions. #10 (comment)
I had an offline discussion with @naoyam on this and want to open up the issue to track the problem.
def fn0():
t0 = torch.randn(1) # how are we define t0 IterDomain?
t1 = torch.randn(5)
o1 = torch._C._nn.pad(t0, (2, 1))
o2 = t0 + t1
`t0`'s IterDomain needs to be marked as broadcast, since we need to resolve the output shape of `o2` from `t0 + t1`, where `t1` has a non-broadcast IterDomain.
Meanwhile, `t0` is being padded, so we can't really work around it without proper backend support.
Lines 382 to 384 in 86d5dd3
I doubt that we actually can support it.
def fn1(pad_left, pad_right):
t0 = torch.empty([1])
t1 = torch._C._nn.pad(t0, (pad_left, pad_right)) # how are we define t1 IterDomain?
Given the example above, `t0` is rank-1 with size 1, which means it comes with a broadcast IterDomain.
So the tricky part is: what do we do with `t1`? Does `t1` have a broadcast or a non-broadcast IterDomain? Well, that depends on the `pad_left`/`pad_right` values!
If we are given zero padding at runtime, then the output `t1` needs to have a broadcast IterDomain, while with non-zero padding we'll have a non-broadcast IterDomain (note: it can't be a broadcast with an expanded extent, since we are padding with a specific value here!).
>>> x = torch.randn(5)
>>> o = torch._C._nn.pad(x, (-2, -2))
>>> o
tensor([-0.6133])
A few things we think are needed:
Note that point 3 here is not a decision that we can make here without looking at the models that we want to support. cc'ing @kevinstephano Do you think for nanogpt, static size padding is sufficient for now?
Tracking things:
Run `python setup.py install` with a clean build.
Check: ls /home/rspring/miniconda3/lib/python3.9/site-packages/nvfuser-0.0.8+gitaa78d0f-py3.9-linux-x86_64.egg/nvfuser
First install:
cmake __init__.py lib nvfuser_version.py __pycache__ pytorch_utils.py version.py
Second install:
_C.cpython-39-x86_64-linux-gnu.so cmake __init__.py lib nvfuser_version.py __pycache__ pytorch_utils.py version.py
Triggers this error when importing nvfuser:
>>> import nvfuser
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/rspring/miniconda3/lib/python3.9/site-packages/nvfuser-0.0.8+gitaa78d0f-py3.9-linux-x86_64.egg/nvfuser/__init__.py", line 15, in <module>
from . import _C
ImportError: cannot import name '_C' from partially initialized module 'nvfuser' (most likely due to a circular import) (/home/rspring/miniconda3/lib/python3.9/site-packages/nvfuser-0.0.8+gitaa78d0f-py3.9-linux-x86_64.egg/nvfuser/__init__.py)
See #143
This error is triggered in dot-product attention that includes `slice` and `softmax`. You can reproduce the issue with the following branch: `add_py_slice_api_plus_codegen_fix`
Error output:
ERROR: test_nanogpt_slice (__main__.TestNvFuserFrontend)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/opt/pytorch/nvfuser/python_tests/test_python_frontend.py", line 1538, in test_nanogpt_slice
out = fd.execute(inputs)
File "/opt/pytorch/pytorch/nvfuser/__init__.py", line 51, in execute
return self._execute(inputs, override_user_schedule)
RuntimeError: (int64_t)( num_blocks_per_SM * at::cuda::getDeviceProperties(options_.device.index()) ->multiProcessorCount) >= launch_params_.gdimx() * launch_params_.gdimy() * launch_params_.gdimz() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/executor.cpp":1190, please report a bug to PyTorch. Wanted to laun
ch a cooperative kernel, however the number of blocks is greater than what can be resident on the GPU at once. Need: 32768 (32768 * 1 * 1) but limited to 4 * 108
Test case:
import torch
from nvfuser import FusionDefinition, DataType
inputs = [
torch.randn(1, 1, 1024, 1204, device='cuda'),
torch.randn(16, 16, 128, 128, device='cuda'),
]
def nvfuser_fusion(fd : FusionDefinition) -> None :
T0 = fd.define_tensor(symbolic_sizes=[1, 1, -1, -1], contiguous=[None, None, True, True], dtype=DataType.Float, is_cpu=False)
T1 = fd.define_tensor(symbolic_sizes=[-1, -1, -1, -1], contiguous=[True, True, True, True], dtype=DataType.Float, is_cpu=False)
S1 = fd.define_constant(0.125000, dtype=DataType.Double)
T2 = fd.ops.mul(T1, S1)
T0_slice = fd.ops.slice(T0, [0, 0, 0, 0], [1, 1, 128, 128], [1, 1, 1, 1])
S2 = fd.define_constant(0.00000, dtype=DataType.Double)
T3 = fd.ops.eq(S2, T0_slice)
T4 = fd.ops.broadcast_in_dim(T3, output_shape=[16, 16, 128, 128], broadcast_dims=[0, 1, 2, 3])
S5 = fd.define_constant(float("-inf"), dtype=DataType.Double)
T6 = fd.ops.where(T4, S5, T2)
S7 = fd.define_constant(-1, dtype=DataType.Int)
S8 = fd.define_constant(4, dtype=DataType.Int)
S9 = fd.ops.add(S7, S8)
T10 = fd.ops.max(T6, axes=[3], keepdim=False, dtype=DataType.Null)
T11 = fd.ops.broadcast_in_dim(T10, output_shape=[16, 16, 128, 1], broadcast_dims=[0, 1, 2])
T12 = fd.ops.broadcast_in_dim(T11, output_shape=[16, 16, 128, 128], broadcast_dims=[0, 1, 2, 3])
T13 = fd.ops.sub(T6, T12)
T14 = fd.ops.exp(T13)
S15 = fd.define_constant(-1, dtype=DataType.Int)
S16 = fd.define_constant(4, dtype=DataType.Int)
S17 = fd.ops.add(S15, S16)
T18 = fd.ops.sum(T14, axes=[3], keepdim=False, dtype=DataType.Null)
T19 = fd.ops.broadcast_in_dim(T18, output_shape=[16, 16, 128, 1], broadcast_dims=[0, 1, 2])
T20 = fd.ops.broadcast_in_dim(T19, output_shape=[16, 16, 128, 128], broadcast_dims=[0, 1, 2, 3])
T21 = fd.ops.div(T14, T20)
S22 = fd.define_constant(16, dtype=DataType.Int)
S23 = fd.define_constant(16, dtype=DataType.Int)
S24 = fd.define_constant(128, dtype=DataType.Int)
S25 = fd.define_constant(128, dtype=DataType.Int)
S26 = fd.define_constant(0.00000, dtype=DataType.Double)
S27 = fd.define_constant(1.00000, dtype=DataType.Double)
T28 = fd.ops.uniform(S26, S27, shape=[S22, S23, S24, S25], dtype=DataType.Float)
S29 = fd.define_constant(0.900000, dtype=DataType.Double)
T30 = fd.ops.lt(T28, S29)
T31 = fd.ops.cast(T30, dtype=DataType.Float)
T32 = fd.ops.mul(T21, T31)
S33 = fd.define_constant(1.11111, dtype=DataType.Double)
T34 = fd.ops.mul(T32, S33)
fd.add_output(T34)
with FusionDefinition() as fd:
nvfuser_fusion(fd)
out = fd.execute(inputs)
In the example code below, nvFuser produces a result with one incorrect element on an RTX 3080 Ti:
from nvfuser import FusionDefinition, DataType
import torch
torch.manual_seed(1)
def nvfuser_fusion(fd : FusionDefinition) -> None :
T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1, 1], contiguous=[True, True, True, None], dtype=DataType.BFloat16, is_cpu=False)
T3 = fd.ops.cast(T0, dtype=DataType.Float)
T4 = fd.ops.sum(T3, axes=[1, 3], keepdim=False, dtype=DataType.Null)
S5 = fd.define_constant(7.00000, dtype=DataType.Double)
T6 = fd.ops.div(T4, S5)
T7 = fd.ops.cast(T6, dtype=DataType.BFloat16)
fd.add_output(T7)
def mean_reference(a, dim):
a_f32 = a.to(torch.float32)
a_sum = a_f32.sum(dim)
reduced_elements = 7.0
return (a_sum / reduced_elements).to(a.dtype)
a = torch.randn(8, 7, 5, 1, dtype=torch.bfloat16, device="cuda")
with FusionDefinition() as fd:
nvfuser_fusion(fd)
nv_out = fd.execute([a])[0]
ref_out = mean_reference(a, [1, 3])
torch.testing.assert_close(nv_out, ref_out, atol=1e-3, rtol=0.0)
AssertionError: Tensor-likes are not close!
Mismatched elements: 1 / 40 (2.5%)
Greatest absolute difference: 0.00390625 at index (2, 3) (up to 0.001 allowed)
Greatest relative difference: 0.0052490234375 at index (2, 3) (up to 0.0 allowed)
Changing `T6 = fd.ops.div(T4, S5)` to `T6 = fd.ops.mul(T4, fd.ops.reciprocal(S5))` fixes the problem.
The same problem occurs with `var_mean`:
from nvfuser import FusionDefinition, DataType
import torch
torch.manual_seed(1)
a = torch.randn(8, 7, 5, 1, dtype=torch.bfloat16, device="cuda")
def mean_reference(a, dim):
a_f32 = a.to(torch.float32)
a_sum = a_f32.sum(dim)
reduced_elements = 7.0
return (a_sum / reduced_elements).to(a.dtype)
ref_out = mean_reference(a, [1, 3])
def nvfuser_fusion_var_mean(fd : FusionDefinition) -> None :
T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1, 1], contiguous=[True, True, True, None], dtype=DataType.BFloat16, is_cpu=False)
T3 = fd.ops.cast(T0, dtype=DataType.Float)
_, mean = fd.ops.var_mean(T3, axes=[1, 3], keepdim=False, correction=0)
mean = fd.ops.cast(mean, dtype=DataType.BFloat16)
fd.add_output(mean)
with FusionDefinition() as fd:
nvfuser_fusion_var_mean(fd)
# This also fails
nv_out = fd.execute([a])[0]
torch.testing.assert_close(nv_out, ref_out, atol=1e-3, rtol=0.0)
I am seeing the following error:
[ 43%] Linking CXX shared library libnvfuser_codegen.so
[ 80%] Built target nvfuser_codegen
[ 81%] Linking CXX shared module libnvfuser.so
[ 81%] Linking CXX executable nvfuser_bench
[ 81%] Built target nvfuser
[ 81%] Linking CXX executable nvfuser_tests
[ 89%] Built target nvfuser_bench
[100%] Built target nvfuser_tests
Install the project...
-- Install configuration: ""
-- Up-to-date: /usr/local/include
CMake Error at third_party/googletest/googlemock/cmake_install.cmake:46 (file):
file INSTALL cannot set permissions on "/usr/local/include": Operation not
permitted.
Call Stack (most recent call first):
third_party/googletest/cmake_install.cmake:47 (include)
cmake_install.cmake:47 (include)
make: *** [Makefile:100: install] Error 1
Traceback (most recent call last):
File "/home/gaoxiang/Fuser/setup.py", line 273, in <module>
main()
File "/home/gaoxiang/Fuser/setup.py", line 234, in main
cmake()
File "/home/gaoxiang/Fuser/setup.py", line 229, in cmake
subprocess.check_call(cmd_str)
File "/usr/lib/python3.10/subprocess.py", line 369, in check_call
raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['cmake', '--build', 'build', '--target', 'install', '--', '-j', '64']' returned non-zero exit status 2.
Hi nvFuser team, I am an AI framework engineer. I know that nvFuser was originally designed for PyTorch, and that you have now moved the corresponding source code out of PyTorch into a standalone repository. Do you plan to maintain it as a project with general support for AI frameworks besides PyTorch?
If so, how do you plan to decouple nvFuser from PyTorch?
Once #24 has landed, we should be able to replace `AnalyzeViewConstraint` with `AnalyzeViewResult`.
I thought it was a TorchScript-specific issue and not an nvFuser one, but then I tried `torch.compile` with Inductor and with `nvprims_nvfuser`. Only nvFuser failed.
This is the script I have been playing with:
import locale
import torch
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
torch._C._jit_override_can_fuse_on_cpu(False)
torch._C._jit_override_can_fuse_on_gpu(False)
torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_nvfuser_enabled(False)
@torch.jit.script
def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
"""Bias-GeLU fused"""
x = inp + bias
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
class Fusion(torch.nn.Module):
def __init__(self):
super(Fusion, self).__init__()
def forward(self, inp, bias):
"""Bias-GeLU fused"""
x = inp + bias
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))
model = torch.compile(Fusion(), backend='nvprims_nvfuser')
H = 768
inp = torch.randn(H, H, device="cuda")
bias = torch.randn(H, device="cuda")
assert locale.getpreferredencoding() == "UTF-8", f"Preferred encoding: {locale.getpreferredencoding()}"
out = model(inp, bias)
assert locale.getpreferredencoding() == "UTF-8", f"Preferred encoding: {locale.getpreferredencoding()}"
out = model(inp, bias)
assert locale.getpreferredencoding() == "UTF-8", f"Preferred encoding: {locale.getpreferredencoding()}"
The original ask comes from csarofeen/pytorch#2556
Currently we are trying to support Embedding & CrossEntropyLoss without fusion segmentation. This feature request is an umbrella item that I'm using to host follow-up issues & PRs:
`numpy.take` and `numpy.take_along_axis` (we may need to clean up & add nvFuser API); @jjsjann123
A better size to think about would be [8192, 32768], where you should have lots of waves.
Though we might want to get more for perf tuning?! #278
This program generates an illegal memory access error (as expected, because the contiguity info doesn't match the input):
import os
os.environ["TORCH_SHOW_CPP_STACKTRACES"] = "1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1" # set to 0 for the error to come from the sync after execute
from nvfuser import FusionDefinition, DataType
import torch
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
T0 = fd.define_tensor(symbolic_sizes=[-1, -1, -1, -1], contiguous=[True, True, True, True], dtype=DataType.Float, is_cpu=False)
T1 = fd.define_tensor(symbolic_sizes=[-1, -1, -1, -1], contiguous=[True, True, True, True], dtype=DataType.Bool, is_cpu=False)
S0 = fd.define_constant(0.00000, dtype=DataType.Float)
T79 = fd.ops.where(T1, S0, T0)
fd.add_output(T79)
t0 = torch.randn(8, 12, 1024, 1024, dtype=torch.float32, device='cuda')
t2 = torch.rand(1024, 1024, device='cuda').to(dtype=torch.bool).as_strided([8, 12, 1024, 1024], [0, 0, 1024, 1])
with FusionDefinition() as fd:
nvfuser_fusion_id0(fd)
fusion_out = fd.execute([t0, t2])
Using `CUDA_LAUNCH_BLOCKING=1`, the program fails at the `fd.execute` call with no Python exception (and no stack trace in gdb):
error: CUDA_ERROR_ILLEGAL_ADDRESS failed with error an illegal memory access was encountered
Using `CUDA_LAUNCH_BLOCKING=0` and calling `torch.cuda.synchronize()` at the end results in a Python-catchable exception:
try:
torch.cuda.synchronize()
except RuntimeError as e:
print("Caught RuntimeError: " + str(e))
With `TORCH_SHOW_CPP_STACKTRACES=1`, it outputs:
Caught RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at /home/iyashchuk/dev/pytorch/master/c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x55 (0x7f15440c0e25 in /home/iyashchuk/dev/pytorch/master/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xe0 (0x7f1544086fbe in /home/iyashchuk/dev/pytorch/master/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x3a8 (0x7f154414fcf8 in /home/iyashchuk/dev/pytorch/master/torch/lib/libc10_cuda.so)
frame #3: THCPModule_cudaSynchronize(_object*, _object*) + 0x18 (0x7f15580e2348 in /home/iyashchuk/dev/pytorch/master/torch/lib/libtorch_python.so)
<omitting python frames>
frame #17: __libc_start_main + 0xf3 (0x7f158e82c083 in /lib/x86_64-linux-gnu/libc.so.6)
terminate called after throwing an instance of 'c10::Error'
what(): CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at /home/iyashchuk/dev/pytorch/master/c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x55 (0x7f15440c0e25 in /home/iyashchuk/dev/pytorch/master/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xe0 (0x7f1544086fbe in /home/iyashchuk/dev/pytorch/master/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x3a8 (0x7f154414fcf8 in /home/iyashchuk/dev/pytorch/master/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0x482f86 (0x7f1557a9af86 in /home/iyashchuk/dev/pytorch/master/torch/lib/libtorch_python.so)
frame #4: c10::TensorImpl::~TensorImpl() + 0x9 (0x7f154409e859 in /home/iyashchuk/dev/pytorch/master/torch/lib/libc10.so)
frame #5: <unknown function> + 0x708e18 (0x7f1557d20e18 in /home/iyashchuk/dev/pytorch/master/torch/lib/libtorch_python.so)
frame #6: THPVariable_subclass_dealloc(_object*) + 0x2d6 (0x7f1557d21176 in /home/iyashchuk/dev/pytorch/master/torch/lib/libtorch_python.so)
<omitting python frames>
frame #14: __libc_start_main + 0xf3 (0x7f158e82c083 in /lib/x86_64-linux-gnu/libc.so.6)
TEST_F(SwizzleTest, SwizzleIndexing170_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
auto tv0 = makeConcreteTensor({64, 64});
fusion.addInput(tv0);
auto tv1 = set(tv0);
auto tv2 = set(tv1);
fusion.addOutput(tv2);
tv1->setMemoryType(MemoryType::Shared);
tv1->split(1, 8);
tv1->split(1, 4);
tv1->split(0, 8);
tv1->split(0, 4);
// [2 4 8 2 4 8]
tv1->swizzle(Swizzle2DType::XOR, 1, 4);
tv1->merge(0);
tv1->merge(0);
tv1->merge(1);
tv1->merge(1);
for (auto tv : {tv1, tv2}) {
tv->merge(0);
tv->split(0, 256);
tv->axis(1)->parallelize(ParallelType::TIDx);
}
auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
at::Tensor t = at::randn({64, 64}, options);
FusionExecutor fe;
fe.compileFusion(&fusion);
auto outputs = fe.runFusion({t});
testValidate(&fusion, outputs, {t}, {t}, __LINE__, __FILE__);
}
CUDA code
__global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 2> T2) {
alignas(16) extern __shared__ char array[];
unsigned smem_offset = 0;
NVFUSER_DEFINE_MAGIC_ZERO
int i911;
i911 = ((nvfuser_index_t)threadIdx.x) / 64;
int i936;
i936 = ((nvfuser_index_t)threadIdx.x) % 64;
int i1007;
i1007 = (T0.stride[0] * i911) + (T0.stride[1] * i936);
int i1008;
i1008 = T0.stride[0] * 4;
int i3040;
i3040 = i936 / 8;
int i3045;
i3045 = ((64 * i911) + (32 * (i3040 / 4))) + (i936 % 8);
int i3047;
i3047 = i3040 % 4;
int i3372;
i3372 = ((nvfuser_index_t)threadIdx.x) / 8;
int i3875;
i3875 = ((32 * (i3372 / 4)) + (8 * (0 ^ (i3372 % 4)))) + (((nvfuser_index_t)threadIdx.x) % 8);
bool b5027;
b5027 = i936 < 64;
int i5028;
i5028 = -64 + i911;
int i5126;
i5126 = -4096 + ((nvfuser_index_t)threadIdx.x);
smem_offset = alignBufferSize(smem_offset, 16);
float* T1 = reinterpret_cast<float*>(array + smem_offset);
smem_offset += (4096 * sizeof(float));
#pragma unroll
for(nvfuser_index_t i91 = 0; i91 < 16; ++i91) {
int i125;
i125 = i91 + nvfuser_zero;
if ((b5027 && (i5028 < (-(4 * i125))))) {
T1[((i3045 + (256 * i91)) + (8 * (i3047 ^ (((i911 + (4 * i91)) / 8) % 4))))]
= T0[(i1007 + (i1008 * i125))];
}
}
NVFUSER_UPDATE_MAGIC_ZERO
__syncthreads();
#pragma unroll
for(nvfuser_index_t i92 = 0; i92 < 16; ++i92) {
int i3937;
i3937 = 256 * (i92 + nvfuser_zero);
if ((i5126 < (-i3937))) {
T2[(((nvfuser_index_t)threadIdx.x) + i3937)]
= T1[(i3875 + (256 * i92))];
}
}
NVFUSER_UPDATE_MAGIC_ZERO
}
First, I don't think it makes sense to alias register tensors, regardless of their size. Modern compilers commonly convert user code into SSA. For CUDA, user C++ code is lowered to NVVM IR, which is based on LLVM IR, which is SSA. Aliasing register tensors is at best a no-op, and at worst it would increase the compilation time of the C++ -> NVVM IR lowering in NVRTC. So we should only focus on the aliasing of shared memory and global memory.
Currently, `reuseMemoryAllocations` can only alias tensors of the same size, which does not work well for applications like matmul. In the case of matmul, the shared memory tensors in the prologue have sizes `cta_tile.M x cta_tile.K` and `cta_tile.N x cta_tile.K`. In the epilogue, the shared memory tensor has size `cta_tile.M x cta_tile.N`, which is typically different from the previous shared memory tensor sizes. We need a smarter algorithm to be able to reuse the prologue tensors' memory for the epilogue.
Reference implementation: csarofeen/pytorch#1979
Note that matmul is not the only affected use case.
TEST_F(NVFuserTest, FusionBroadcastInt64Indexing_CUDA) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
auto tv0 = makeSymbolicTensor(1, DataType::Bool);
auto tv1 = makeSymbolicTensor(1, DataType::Bool);
fusion->addInput(tv0);
fusion->addInput(tv1);
auto tv2 = broadcast(tv0, {false, true});
auto tv3 = broadcast(tv1, {true, false});
auto tv4 = bitwise_xor(tv2, tv3);
fusion->addOutput(tv4);
FusionExecutorCache executor_cache(std::move(fusion));
executor_cache.profile(true);
constexpr int size = (1L << 16L);
const auto options = at::TensorOptions().device(at::kCUDA, 0);
auto input = at::randn({size}, options) > 0;
auto cg_outputs = executor_cache.runFusionWithInputs({input, input});
auto expect = input.view({size, 1}) + input.view({1, size});
ASSERT_FALSE(executor_cache.getMostRecentKernelRuntime()->isSegmented());
ASSERT_EQ(
executor_cache.getMostRecentKernelRuntime()
->getMostRecentExecutorLog()
.fusion_executor->kernel()
->indexType(),
PrimDataType::Int);
testValidate(
executor_cache.fusion(),
cg_outputs,
{input, input},
{expect},
__LINE__,
__FILE__);
}
https://github.com/NVIDIA/Fuser/blob/main/csrc/fusion.cpp#L573-L606
This is especially costly when just copying a fusion. We can update the use info just once, after all the expressions are copied.
More fundamentally, this shouldn't always be done eagerly; it should be done lazily, only when it's necessary.
I am getting the following error on `python setup.py build`:
-- Installing: /usr/local/bin/nvfuser_tests
CMake Error at cmake_install.cmake:121 (file):
file INSTALL cannot copy file "/home/gaoxiang/Fuser/build/nvfuser_tests" to
"/usr/local/bin/nvfuser_tests": Permission denied.
The result of `const int64_t warp_size = std::min(warp_size_based_on_l1, warp_size_based_on_l2);` in `innerPersistentHeuristic` can be an odd number such as 15, which seems wrong; 16 or 32 would be preferable. Noticed while working on PR #68 (review).
When indexing a 3-D tensor with size `[4, 4, 4]` with `2` or `-1` in `merge`, I think we should get the same error. Note that the second error should also occur when the dim is specified as `1` and `-2`.
With absolute indexing, I get the following error:
import torch
from nvfuser import FusionDefinition
inputs = [
torch.randn(4, 4, 4, device="cuda"),
]
class InputError(FusionDefinition):
def definition(self):
self.t0 = self.from_pytorch(inputs[0])
self.t1 = self.ops.sum(self.t0, axis=-1)
self.add_output(self.t1)
def schedule(self):
self.sched.merge(fd.t1, 2),
fd = InputError()
_ = fd.execute(inputs)
Error:
RuntimeError: Invalid merge detected, either one or both axes are outside of TensorView's range.
With relative indexing I get:
import torch
from nvfuser import FusionDefinition
inputs = [
torch.randn(4, 4, 4, device="cuda"),
]
class InputError(FusionDefinition):
def definition(self):
self.t0 = self.from_pytorch(inputs[0])
self.t1 = self.ops.sum(self.t0, axis=-1)
self.add_output(self.t1)
def schedule(self):
self.sched.merge(fd.t1, -1),
fd = InputError()
_ = fd.execute(inputs)
Error:
RuntimeError: Merging IterDomains requires that their iteration types match. Outer: iS3{i0}, Inner: rS5{i2}
The `slice` implementation inside nvFuser does not have a runtime check to determine whether the slice range is beyond the size of the dimension being sliced. PyTorch returns a zero-element tensor, while nvFuser returns a tensor with the full size of the slice(s).
Example:
import torch
from nvfuser import FusionDefinition, DataType
acts = [
torch.randn(5, 5, device='cuda'),
]
def legal(fd: FusionDefinition) -> None :
T0 = fd.from_pytorch(acts[0])
T1 = fd.ops.slice(T0, start_indices=[6, 6], end_indices=[8, 8], strides=[1, 1])
fd.add_output(T1)
with FusionDefinition() as fd:
legal(fd)
out = fd.execute(acts)
print(out[0].size(), out[0].stride())
out_eager = acts[0][6:8, 6:8]
print(out_eager.size(), out_eager.stride())
Output:
$ python test.py
tensor([[0., 0.],
[0., 0.]], device='cuda:0')
torch.Size([2, 2]) (2, 1)
tensor([], device='cuda:0', size=(0, 0))
torch.Size([0, 0]) (5, 1)
Naoya's comments on the feasibility of a runtime check:
I think we would need some runtime check. Currently we are ignoring a pattern like:
t0: [I0]
t1 = slice(t0, {{0, 1}})
t1: [I1]
Here, the size of I1 is just 1, so it should be marked as a broadcast domain
But at the codegen side, we are not doing this
In this particular case, this should be trivial, as we can obviously find that the extent is 1, but more generally, whether the extent is 1 or not depends on the extent of the sliced domain as well as the start and end offsets.
So, we need to do some check at the run time using the runtime info.
And that also seems to be what happens in the case above.
Currently we are able to create reduced-precision scalar constants via e.g. `IrBuilder::create<Double>(3.14, DataType::Half)`. In Python, we can call `fd.define_constant(3.14, dtype=DataType.Half)`, and similarly for `fd.define_scalar`. However, if we use them in a fusion, we get compile errors:
auto tv0 = makeSymbolicTensor(1, DataType::Half);
fusion->addInput(tv0);
auto tv2 = full_like(tv0, IrBuilder::create<Double>(1.5, DataType::Half));
fusion->addOutput(tv2);
// C++ exception with description "false INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/ir_interface_nodes.h":111, please report a bug
// to PyTorch. Invalid data type: __half Exception raised from toString at /opt/pytorch/nvfuser/csrc/ir_interface_nodes.h:111 (most recent call first):
I believe there are just a few places where we need to handle printing constants and scalars to the kernel. We just need to inspect the kernels to ensure no unnecessary casts are inserted.
I am seeing the following error:
[7/64] Building CXX object CMakeFiles/nvfuser_tests.dir/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp.o
FAILED: CMakeFiles/nvfuser_tests.dir/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp.o
/usr/bin/g++-10 -DUSE_C10D_GLOO -DUSE_C10D_MPI -DUSE_C10D_NCCL -DUSE_DISTRIBUTED -DUSE_GTEST -DUSE_NCCL_WITH_UCC -DUSE_RPC -DUSE_TENSORPIPE -I/home/gaoxiang/Fuser/third_party/benchmark/include -I/home/gaoxiang/Fuser -I/home/gaoxiang/Fuser/csrc -isystem /home/gaoxiang/Fuser/third_party/googletest/googlemock/include -isystem /home/gaoxiang/Fuser/third_party/googletest/googletest/include -isystem /home/gaoxiang/Fuser/third_party/googletest/googletest -isystem /home/gaoxiang/Fuser/third_party/googletest/googlemock -isystem /home/gaoxiang/.local/lib/python3.10/site-packages/torch/include -isystem /home/gaoxiang/.local/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/gaoxiang/cuda-11.8/include -D_GLIBCXX_USE_CXX11_ABI=1 -std=gnu++17 -Wall -Wno-unused-function -D_GLIBCXX_USE_CXX11_ABI=1 -Werror -MD -MT CMakeFiles/nvfuser_tests.dir/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp.o -MF CMakeFiles/nvfuser_tests.dir/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp.o.d -o CMakeFiles/nvfuser_tests.dir/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp.o -c /home/gaoxiang/Fuser/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp
In file included from /home/gaoxiang/Fuser/csrc/python_frontend/fusion_state.h:10,
from /home/gaoxiang/Fuser/csrc/python_frontend/fusion_definition.h:13,
from /home/gaoxiang/Fuser/csrc/python_frontend/fusion_record.h:14,
from /home/gaoxiang/Fuser/csrc/python_frontend/test/test_nvfuser_fusion_record.cpp:13:
/home/gaoxiang/Fuser/csrc/serde/python_fusion_cache_generated.h:6:10: fatal error: flatbuffers/flatbuffers.h: No such file or directory
6 | #include "flatbuffers/flatbuffers.h"
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
[8/64] Building CXX object CMakeFiles/nvfuser_tests.dir/csrc/python_frontend/test/test_nvfuser_fusion_definition.cpp.o
This error can be resolved by installing `flatbuffers` on my system. Are we adding `flatbuffers` as a mandatory dependency? If so, should we add `flatbuffers` (Edit: this won't help) to https://github.com/NVIDIA/Fuser/blob/main/requirements.txt?
nvFuser should support a `device` argument.
We could have factory methods in a fusion that doesn't have a tensor input, so a `device` argument would be necessary to determine the GPU device for kernel generation.
e.g. in cpp tests
Fuser/test/test_gpu_tensor_factories.cpp
Lines 51 to 56 in 44e99fe
Note that these arguments should also be plumbed all the way through to python API.
Reduction on a tensor produced by `full` is probably not supported. How difficult would it be to support?
from nvfuser import FusionDefinition, DataType
import torch
def nvfuser_fusion_id3(fd : FusionDefinition) -> None :
T0 = fd.define_tensor(symbolic_sizes=[-1, 1], contiguous=[True, None], dtype=DataType.Double, is_cpu=False)
T1 = fd.define_tensor(symbolic_sizes=[-1, -1], contiguous=[True, True], dtype=DataType.Double, is_cpu=False)
T2 = fd.ops.broadcast_in_dim(T0, output_shape=[4, 4], broadcast_dims=[0, 1])
T3 = fd.ops.ge(T1, T2)
S4 = fd.define_constant(0, dtype=DataType.Int)
T5 = fd.ops.full([4, 4], S4, dtype=DataType.Double)
S6 = fd.define_constant(0, dtype=DataType.Int)
T7 = fd.ops.full([4, 4], S6, dtype=DataType.Double)
T8 = fd.ops.sum(T7, axes=[1], keepdim=False, dtype=DataType.Null)
T9 = fd.ops.broadcast_in_dim(T8, output_shape=[4, 1], broadcast_dims=[0])
fd.add_output(T3)
fd.add_output(T5)
fd.add_output(T9)
with FusionDefinition() as fd:
nvfuser_fusion_id3(fd)
a = torch.randn(4, 1, dtype=torch.double, device='cuda')
b = torch.randn(4, 4, dtype=torch.double, device='cuda')
out = fd.execute([a, b])
RuntimeError: !tv_inps.empty() INTERNAL ASSERT FAILED at "nvfuser/csrc/scheduler/reduction.cpp":926, please report a bug to PyTorch. Tried to schedule a fusion with no tensor inputs, currently not supported.
https://numpy.org/doc/stable/reference/generated/numpy.take_along_axis.html
Similar to `torch.gather` but less flexible. Unlike with `torch.gather`, the input dimensions that are not gathered must have the same sizes as the corresponding dimensions of the index tensor.
`take_along_axis` with non-fusion inputs (#250)
`take_along_axis` with softmax rather than the final reduction
This may be the same underlying issue as #93, but it manifests in a different way. The following example results in an error when compiling the kernel:
// Test unsqueezing an index vector before gathering, as is done for
// cross_entropy_loss.
TEST_F(NVFuserTest, FusionTorchGatherUnsqueezed_CUDA) {
  int64_t N = 16384;
  int64_t num_classes = 60;
  std::vector<int64_t> input_dims{N, num_classes};
  std::vector<int64_t> index_dims{N};

  at::manual_seed(0);
  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto options_i = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0);

  auto fusion_ptr = std::make_unique<Fusion>();
  Fusion& fusion = *fusion_ptr.get();
  FusionGuard fg(&fusion);

  TensorView* tv_in1 = makeConcreteTensor(input_dims);
  TensorView* tv_idx = makeConcreteTensor(index_dims, DataType::Int);
  fusion.addInput(tv_in1);
  fusion.addInput(tv_idx);

  // [N] -> [N, 1]: give the index tensor the same rank as tv_in1 so it can
  // drive a gather along dimension 1.
  auto tv_idx_unsqueeze = broadcast(tv_idx, {false, true});
  auto tv_gather = torch_gather(tv_in1, 1, tv_idx_unsqueeze);
  // [N, 1] -> [N]
  auto tv_gather_squeeze = reshape(tv_gather, {N, 1}, {N});
  fusion.addOutput(tv_gather_squeeze);

  at::Tensor input_1 = at::randn(input_dims, options);
  at::Tensor input_idx =
      at::randint(0, num_classes, index_dims, options_i);
  // ATen reference: same unsqueeze -> gather -> squeeze pipeline.
  auto t_gather = at::gather(input_1, 1, input_idx.unsqueeze(1)).squeeze();

  std::vector<c10::IValue> aten_inputs = {input_1, input_idx};

  // The cache takes ownership of the Fusion; `fusion` stays valid because the
  // cache keeps the object alive for the rest of this test.
  FusionExecutorCache executor_cache(std::move(fusion_ptr));
  auto cg_outputs = executor_cache.runFusionWithInputs(aten_inputs);
  testValidate(
      &fusion, cg_outputs, aten_inputs, {t_gather}, __LINE__, __FILE__);
}
The generated kernel is:
// Kernel generated by nvFuser, presumably for FusionTorchGatherUnsqueezed_CUDA
// above, quoted verbatim from the report.
//
// NOTE(review): this dump was mangled when pasted, so it does not compile as
// shown and the exact nesting is uncertain:
//   - Several line breaks were lost. For example, `#pragma unroll for(...)` and
//     `} NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll for(...)` share a line,
//     whereas the else-branch shows `#pragma unroll` on its own line.
//   - It contains two more closing braces than opening braces.
//
// The reported bug: the gather load indexes T0 with `T0.stride[1] * T2[0]`,
// but no T2 buffer is declared anywhere in this kernel. NVRTC reports this as
// "identifier T2 is undefined". The index values are loaded from T1 into T5,
// but T5 is never read afterwards; presumably the gather should have used
// them (TODO confirm).
__global__ void kernel1(Tensor<float, 2> T0, Tensor<int64_t, 1> T1, Tensor<float, 1> T4) {
NVFUSER_DEFINE_MAGIC_ZERO
int i318;
i318 = ((T1.stride[0] * 2) * ((nvfuser_index_t)threadIdx.x)) + ((T1.stride[0] * 256) * ((nvfuser_index_t)blockIdx.x));
int i614;
i614 = ((T0.stride[0] * 2) * ((nvfuser_index_t)threadIdx.x)) + ((T0.stride[0] * 256) * ((nvfuser_index_t)blockIdx.x));
int i802;
i802 = 2 * ((nvfuser_index_t)threadIdx.x);
int i803;
i803 = 256 * ((nvfuser_index_t)blockIdx.x);
int i804;
i804 = i802 + i803;
bool b1981;
b1981 = ((1 + i802) + i803) < 16384;
if ((((i802 + 1) + i803) < 16384)) {
int64_t T5[2];
#pragma unroll for(nvfuser_index_t i111 = 0; i111 < 2; ++i111) {
T5[i111] = 0;
} NVFUSER_UPDATE_MAGIC_ZERO #pragma unroll for(nvfuser_index_t i111 = 0; i111 < 2; ++i111) {
T5[i111] = T1[(i318 + (T1.stride[0] * (i111 + nvfuser_zero)))];
} NVFUSER_UPDATE_MAGIC_ZERO
Array<float, 2, 2> T6; #pragma unroll
for(nvfuser_index_t i112 = 0; i112 < 2; ++i112) { float T3[1];
T3[0] = 0; T3[0]
// NOTE(review): T2 is not declared anywhere in this kernel (the NVRTC error).
= T0[((i614 + (T0.stride[0] * (i112 + nvfuser_zero))) + (T0.stride[1] * T2[0]))]; T6[i112]
= T3[0];
}
NVFUSER_UPDATE_MAGIC_ZERO
loadLocalToGlobal<float, 2, false>( &T4[i804], &T6[0]);
} else {
int64_t T5[2];
#pragma unroll
for(nvfuser_index_t i111 = 0; i111 < 2; ++i111) {
T5[i111] = 0;
}
}
NVFUSER_UPDATE_MAGIC_ZERO
Array<float, 2, 2> T6;
#pragma unroll
for(nvfuser_index_t i112 = 0; i112 < 2; ++i112) {
float T3[1];
T3[0] = 0;
if (b1981) {
T3[0]
// NOTE(review): same undeclared T2 reference in the predicated path.
= T0[((i614 + (T0.stride[0] * (i112 + nvfuser_zero))) + (T0.stride[1] * T2[0]))];
}
T6[i112]
= T3[0];
}
NVFUSER_UPDATE_MAGIC_ZERO
if (b1981) {
loadLocalToGlobal<float, 2, false>( &T4[i804], &T6[0]);
}
}
}
}
/*
CUDA NVRTC compile error: __tmp_kernel1.cu(9404): error: identifier "T2" is undefined
= T0[((i614 + (T0.stride[0] * (i112 + nvfuser_zero))) + (T0.stride[1] * T2[0]))];
^
__tmp_kernel1.cu(9386): warning #550-D: variable "T5" was set but never used
int64_t T5[2];
^
Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"
*/
Removing the reshape does not change the error.
A declarative, efficient, and flexible JavaScript library for building user interfaces.
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
An Open Source Machine Learning Framework for Everyone
The Web framework for perfectionists with deadlines.
A PHP framework for web artisans
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
Something interesting about the web. A new door to the world.
A server is a program made to process requests and deliver data to clients.
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
Something interesting about visualization: using data as art.
Something interesting about games: making everyone happy.
We are working to build community through open source technology. NB: members must have two-factor auth.
Open source projects and samples from Microsoft.
Google ❤️ Open Source for everyone.
Alibaba Open Source for everyone
Data-Driven Documents codes.
China Tencent open source team.