Comments (4)
Summary
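Comparison of the torch.autograd.grad() and oneflow.autograd.grad() interfaces: torch accepts is_grads_batched=True, while the same batched call fails in oneflow with a shape check error.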
Code to reproduce
import torch as torch_original
import oneflow as flow
def exp_reducer(x):
    return x.exp().sum(dim=1)
print("torch result: ")
inputs = torch_original.rand(2,2,requires_grad=True)
outputs = exp_reducer(inputs)
torch_inputs = (inputs,)
torch_outputs = (outputs,)
torch_grad_outputs = (torch_original.eye(2),)
torch_result = torch_original.autograd.grad(
    torch_outputs,
    torch_inputs,
    torch_grad_outputs,
    allow_unused=True,
    create_graph=False,
    retain_graph=None,
    is_grads_batched=True,
)
print(torch_result)
print("oneflow result: ")
flow_inputs = (flow.tensor(inputs.detach().numpy(),requires_grad=True),)
flow_outputs = ( exp_reducer(flow_inputs[0]), )
flow_grad_outputs = (flow.eye(2),)
flow_result = flow.autograd.grad(
    flow_outputs,
    flow_inputs,
    flow_grad_outputs,
    allow_unused=True,
    create_graph=False,
    retain_graph=None,
)
print(flow_result)
Run result
torch result:
(tensor([[[1.0448, 2.2668],
[0.0000, 0.0000]],
[[0.0000, 0.0000],
[2.5238, 1.5629]]]),)
oneflow result:
Traceback (most recent call last):
File "../../test/test.py", line 107, in <module>
flow_result = flow.autograd.grad(
File "/workspace/software/oneflow/python/oneflow/autograd/autograd.py", line 63, in grad
in_grads = grad_api(
oneflow._oneflow_internal.exception.Exception: out_grad's shape must be same as output's ((2,) vs (2,2))
File "oneflow/api/python/autograd/autograd.cpp", line 113, in Grad
CheckAndInitOutGrads(outputs, out_grads)
File "oneflow/api/python/autograd/autograd.cpp", line 73, in CheckAndInitOutGrads
CHECK_OR_RETURN(*(outputs.at(i)->shape()) == *(out_grads.at(i)->shape()))
Error Type: oneflow.ErrorProto.check_failed_error
from oneflow.
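Could you construct an example that calls the autograd.grad interface directly? Let's align this interface first, and then come back to verify the jacobian interface.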
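Got it: this behavior comes from the is_grads_batched parameter. Its main purpose is to pack the grads together so that several backward computations finish in a single pass through the AutogradEngine. I can add support for it later.
If you need this urgently, here is a workaround: since the parameter only packs the grads together, define a batched_autograd_grad function that computes each grad separately, one batch entry at a time (note that the first n-1 calls must pass retain_graph=True), and then stack the results at the end.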
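The workaround described above could look roughly like the sketch below. It is only an illustration, not the implementation that was later merged: the signature of batched_autograd_grad is made up here, and the gradient and stack functions are passed in as arguments (grad_fn, stack_fn) so that the same code can be tried with either flow.autograd.grad / flow.stack or the torch equivalents. It assumes each batched grad_output supports len() and integer indexing along its first dimension.

```python
def batched_autograd_grad(grad_fn, stack_fn, outputs, inputs,
                          batched_grad_outputs, allow_unused=False):
    """Emulate is_grads_batched=True with one backward pass per batch entry.

    grad_fn  : e.g. flow.autograd.grad or torch.autograd.grad
    stack_fn : e.g. flow.stack or torch.stack
    (Hypothetical helper; the name and argument order are assumptions.)
    """
    batch_size = len(batched_grad_outputs[0])
    per_batch = []
    for i in range(batch_size):
        # Take the i-th slice of every batched grad_output.
        grad_outputs = tuple(g[i] for g in batched_grad_outputs)
        grads = grad_fn(
            outputs,
            inputs,
            grad_outputs,
            # Keep the graph alive for the first n-1 passes only.
            retain_graph=(i < batch_size - 1),
            create_graph=False,
            allow_unused=allow_unused,
        )
        per_batch.append(grads)

    # Regroup the results per input and stack along a new leading dimension.
    results = []
    for per_input in zip(*per_batch):
        if any(g is None for g in per_input):
            results.append(None)  # input unused in the graph
        else:
            results.append(stack_fn(list(per_input)))
    return tuple(results)
```

With the reproduction script above, the call would be along the lines of batched_autograd_grad(flow.autograd.grad, flow.stack, flow_outputs, flow_inputs, flow_grad_outputs, allow_unused=True). This runs n separate backward passes, so it gives up the single-pass benefit that is_grads_batched is meant to provide.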
The interface is now supported; please give it a try @lihuizhao