Comments (1)
Repro:
import torch
from torch._decomp import register_decomposition

# Define a custom in-place op that writes the chunk-cat result into `ret`.
lib = torch.library.Library("fsdp_test", "DEF")
lib.define(
    "chunk_cat_(Tensor(a!) ret, Tensor[] tensors, int dim, int num_chunks) -> ()",
    tags=torch.Tag.needs_fixed_stride_order,
)


@torch.library.impl(lib, "chunk_cat_", "Meta")
def chunk_cat_(ret, tensors, dim, num_chunks):
    torch._chunk_cat(tensors, dim, num_chunks, out=ret)


@torch.library.impl(lib, "chunk_cat_", "CUDA")
def chunk_cat_(ret, tensors, dim, num_chunks):
    torch._chunk_cat(tensors, dim, num_chunks, out=ret)


def f(x, y, z):
    full_default_3: "f32[2, 524544]" = torch.ops.aten.full.default([2, 524544], 1.0, dtype=torch.float32, layout=torch.strided, device="cuda", pin_memory=False)
    # Mutates full_default_3 in place; the op itself returns nothing.
    chunk_cat_default_1 = torch.ops.fsdp_test.chunk_cat_.default(full_default_3, [x, y, z], 0, 2)
    # These reads must observe the mutation above.
    mul_out = torch.mul(full_default_3, full_default_3)
    sum_out = mul_out.sum()
    return sum_out


if __name__ == "__main__":
    device = "cuda"
    x = torch.randn([1024, 512], device=device)
    y = torch.randn([512], device=device)
    z = torch.randn([1024, 512], device=device)
    eager_out = f(x, y, z)

    compiled_aot_eager_f = torch.compile(f, backend="aot_eager", fullgraph=True)
    compiled_aot_eager_out = compiled_aot_eager_f(x, y, z)
    assert torch.allclose(eager_out, compiled_aot_eager_out), f"eager_out: {eager_out}, compiled_aot_eager_out: {compiled_aot_eager_out}"  # passes

    compiled_inductor_f = torch.compile(f, backend="inductor", fullgraph=True)
    compiled_inductor_out = compiled_inductor_f(x, y, z)
    assert torch.allclose(eager_out, compiled_inductor_out), f"eager_out: {eager_out}, compiled_inductor_out: {compiled_inductor_out}"  # fails
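For readers wondering where the hard-coded `[2, 524544]` buffer shape comes from, here is a small torch-free sketch of the arithmetic. It assumes that, for `dim=0`, `torch._chunk_cat` splits each input into `num_chunks` pieces along dim 0 (rounding the chunk size up), flattens each piece, and concatenates them along the last dimension. That is my reading of the op, not something stated in the report, and the helper name `chunk_cat_out_shape` is hypothetical.

```python
import math


def chunk_cat_out_shape(shapes, num_chunks):
    """Expected `out` shape for dim=0 under the assumed _chunk_cat semantics."""
    cols = 0
    for shape in shapes:
        # Rows each chunk receives along dim 0 (rounded up, i.e. padded if needed).
        rows_per_chunk = math.ceil(shape[0] / num_chunks)
        # Flattened length of one chunk; math.prod(()) == 1 covers 1-D inputs.
        cols += rows_per_chunk * math.prod(shape[1:])
    return (num_chunks, cols)


# Shapes of x, y, z from the repro, with num_chunks=2.
print(chunk_cat_out_shape([(1024, 512), (512,), (1024, 512)], 2))  # (2, 524544)
```

Under that assumption the buffer in the repro is exactly the size the op needs, so the mismatch under inductor would not be a shape problem.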