
Comments (4)

lihuizhao commented on May 24, 2024

Summary

Comparing the torch.autograd.grad() and oneflow.autograd.grad() interfaces

Code to reproduce

import torch as torch_original
import oneflow as flow

# Sums exp(x) over each row; works for both torch and oneflow tensors.
def exp_reducer(x):
    return x.exp().sum(dim=1)

print("torch result: ")
inputs = torch_original.rand(2, 2, requires_grad=True)
outputs = exp_reducer(inputs)  # shape (2,)

torch_inputs = (inputs,)
torch_outputs = (outputs,)
# With is_grads_batched=True, the first dimension of each grad_output is a
# batch dimension: each row of eye(2) (shape (2,)) drives one backward pass.
torch_grad_outputs = (torch_original.eye(2),)
torch_result = torch_original.autograd.grad(
    torch_outputs,
    torch_inputs,
    torch_grad_outputs,
    allow_unused=True,
    create_graph=False,
    retain_graph=None,
    is_grads_batched=True,
)
print(torch_result)

print("oneflow result: ")
# Same inputs and grad_outputs. oneflow.autograd.grad has no is_grads_batched
# argument here, so the (2, 2) grad_output is checked against the (2,) output.
flow_inputs = (flow.tensor(inputs.detach().numpy(), requires_grad=True),)
flow_outputs = (exp_reducer(flow_inputs[0]),)
flow_grad_outputs = (flow.eye(2),)
flow_result = flow.autograd.grad(
    flow_outputs,
    flow_inputs,
    flow_grad_outputs,
    allow_unused=True,
    create_graph=False,
    retain_graph=None,
)
print(flow_result)

Run result

torch result: 
(tensor([[[1.0448, 2.2668],
         [0.0000, 0.0000]],

        [[0.0000, 0.0000],
         [2.5238, 1.5629]]]),)
oneflow result: 
Traceback (most recent call last):
  File "../../test/test.py", line 107, in <module>
    flow_result = flow.autograd.grad(
  File "/workspace/software/oneflow/python/oneflow/autograd/autograd.py", line 63, in grad
    in_grads = grad_api(
oneflow._oneflow_internal.exception.Exception: out_grad's shape must be same as output's ((2,) vs (2,2))
  File "oneflow/api/python/autograd/autograd.cpp", line 113, in Grad
    CheckAndInitOutGrads(outputs, out_grads)
  File "oneflow/api/python/autograd/autograd.cpp", line 73, in CheckAndInitOutGrads
    CHECK_OR_RETURN(*(outputs.at(i)->shape()) == *(out_grads.at(i)->shape()))
Error Type: oneflow.ErrorProto.check_failed_error


wyg1997 commented on May 24, 2024

Could you construct an example that calls the autograd.grad interface directly? We'll align this interface first and then verify the jacobian interface.


wyg1997 commented on May 24, 2024

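Got it: this functionality comes from the is_grads_batched parameter. Its main purpose is to pack the grads together so that multiple backward computations can be completed with a single pass through the AutogradEngine. I can add support for it later.

If you need the functionality urgently, here is a workaround: since the point is to pack the grads, define a batched_autograd_grad function that computes each grad separately, one batch entry at a time (note that the first n-1 calls must pass retain_graph=True), and then stack the results at the end.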
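The workaround above can be sketched as follows. This is a minimal sketch, not code from the thread: only the name batched_autograd_grad comes from the comment, while its signature, the hypothetical `framework` argument (meant to be the torch or oneflow module, both of which provide autograd.grad and stack), and the handling of unused inputs are assumptions. Each row of the batched grad_outputs drives one ordinary backward pass; every pass except the last keeps the graph with retain_graph=True, and the per-row gradients are stacked along a new leading dimension.

```python
def batched_autograd_grad(framework, outputs, inputs, batched_grad_outputs,
                          allow_unused=True):
    """Emulate autograd.grad(..., is_grads_batched=True) with one pass per row.

    framework: hypothetical parameter; the torch or oneflow module.
    batched_grad_outputs: one tensor per output; its first dim is the batch dim.
    """
    outputs, inputs = tuple(outputs), tuple(inputs)
    batch_size = batched_grad_outputs[0].shape[0]
    per_row = []
    for b in range(batch_size):
        # Row b of every grad_output has the same shape as its output.
        grad_outputs = tuple(g[b] for g in batched_grad_outputs)
        per_row.append(framework.autograd.grad(
            outputs,
            inputs,
            grad_outputs,
            # Keep the graph for the first n-1 passes; free it on the last.
            retain_graph=b < batch_size - 1,
            allow_unused=allow_unused,
        ))
    stacked = []
    for i in range(len(inputs)):
        column = [grads[i] for grads in per_row]
        # An input unused by the outputs gets None; keep it as None.
        if any(g is None for g in column):
            stacked.append(None)
        else:
            # New leading dim indexes the batch, as with is_grads_batched.
            stacked.append(framework.stack(column, dim=0))
    return tuple(stacked)


if __name__ == "__main__":
    import torch  # assumption: passing the oneflow module should work the same way

    x = torch.rand(2, 2, requires_grad=True)
    y = x.exp().sum(dim=1)
    print(batched_autograd_grad(torch, (y,), (x,), (torch.eye(2),)))
```

Compared with is_grads_batched, this runs one backward pass per batch entry instead of a single pass through the engine, so it is slower for large batches, but it needs no extra support from the framework.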


wyg1997 commented on May 24, 2024

The interface is now supported; please give it a try @lihuizhao
