

zou3519 commented on September 26, 2024

Repro:

import torch
from torch._decomp import register_decomposition

lib = torch.library.Library("fsdp_test", "DEF")

lib.define("chunk_cat_(Tensor(a!) ret, Tensor[] tensors, int dim, int num_chunks) -> ()", tags=torch.Tag.needs_fixed_stride_order)

# Meta kernel: runs on meta tensors during tracing (no real data).
@torch.library.impl(lib, "chunk_cat_", "Meta")
def chunk_cat_meta(ret, tensors, dim, num_chunks):
    torch._chunk_cat(tensors, dim, num_chunks, out=ret)

# CUDA kernel: writes the chunked concatenation into `ret` in place.
@torch.library.impl(lib, "chunk_cat_", "CUDA")
def chunk_cat_cuda(ret, tensors, dim, num_chunks):
    torch._chunk_cat(tensors, dim, num_chunks, out=ret)


def f(x, y, z):
    full_default_3: "f32[2, 524544]" = torch.ops.aten.full.default([2, 524544], 1.0, dtype = torch.float32, layout = torch.strided, device = "cuda", pin_memory = False)
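    # chunk_cat_ mutates full_default_3 in place (schema: Tensor(a!) ret) and returns nothing,
    # so the mul below must read the mutated buffer.
    chunk_cat_default_1 = torch.ops.fsdp_test.chunk_cat_.default(full_default_3, [x, y, z], 0, 2)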
    mul_out = torch.mul(full_default_3, full_default_3)
    sum_out = mul_out.sum()
    return sum_out


if __name__ == "__main__":
    device = "cuda"
    x = torch.randn([1024, 512], device=device)
    y = torch.randn([512], device=device)
    z = torch.randn([1024, 512], device=device)
    eager_out = f(x, y, z)

    compiled_aot_eager_f = torch.compile(f, backend="aot_eager", fullgraph=True)
    compiled_aot_eager_out = compiled_aot_eager_f(x, y, z)
    assert torch.allclose(eager_out, compiled_aot_eager_out), f"eager_out: {eager_out}, compiled_aot_eager_out: {compiled_aot_eager_out}"   # passes

    compiled_inductor_f = torch.compile(f, backend="inductor", fullgraph=True)
    compiled_inductor_out = compiled_inductor_f(x, y, z)
    assert torch.allclose(eager_out, compiled_inductor_out), f"eager_out: {eager_out}, compiled_inductor_out: {compiled_inductor_out}"  # fails

