Comments (5)
I just found out that the state_dict contains the gradients, so they should at least be somewhat reset when loading the global state_dict (with new gradients) into the local nn.
From the PyTorch documentation: "torch.nn.Module.load_state_dict: Loads a model's parameter dictionary using a deserialized state_dict."
To me, that sounds like it loads a copy of the global parameters, meaning that the gradients will be added to the previous global gradients.
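A quick way to check that reading is to try it on two throwaway modules (a minimal sketch, not from the repo; the two nn.Linear layers only stand in for gnet and lnet):

import torch
import torch.nn as nn

gnet, lnet = nn.Linear(3, 1), nn.Linear(3, 1)
lnet(torch.randn(2, 3)).sum().backward()       # lnet now holds some gradients
lnet.load_state_dict(gnet.state_dict())        # load the "global" values into the local net
print(torch.equal(lnet.weight, gnet.weight))   # True: the parameter values were copied over
print(lnet.weight.grad)                        # unchanged: loading a state_dict neither
                                               # resets nor accumulates gradients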
from pytorch-a3c.
I have a similar question about the gradient.
Actually, after lnet.load_state_dict(gnet.state_dict()) is executed, all the parameters in both nets are shared; that is to say, lnet and gnet share their parameters. So opt.zero_grad() will set the gradients in both lnet and gnet to zero, and loss.backward() will give lnet and gnet the same gradients! So after the 1st iteration, gp._grad = lp.grad is useless because they are already the same! I found another implementation here involving an if-return criterion (I guess it corresponds to my claim that the grad assignment is useless after the 1st iteration).
# copied from continuous A3C; consider the case after the 1st iteration
opt.zero_grad()  # zero the gradients in both lnet and gnet
loss.backward()  # the parameters in both lnet and gnet now have the same gradients
for lp, gp in zip(lnet.parameters(), gnet.parameters()):  # the for loop is useless
    # if gp.grad is not None:
    #     return  # this "if-return" check is copied from the link above
    gp._grad = lp.grad
opt.step()  # update gnet's parameters (the parameters in lnet will not change!)
lnet.load_state_dict(gnet.state_dict())  # update lnet's parameters
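A quick way to check where this apparent sharing comes from is the identity test below (a minimal sketch for a single worker, meant to run right after loss.backward() and before the assignment loop; lnet and gnet are the variables from the snippet above):

for lp, gp in zip(lnet.parameters(), gnet.parameters()):
    print(gp.data_ptr() == lp.data_ptr())  # False: parameter storage is never shared;
                                           # load_state_dict() only copies values over
    print(gp.grad is lp.grad)              # False before the 1st push (gp.grad has not been set yet),
                                           # True afterwards: the previous gp._grad = lp.grad
                                           # made them literally the same tensor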
This is confusing to me and it might be a (serious) bug. What if worker A is updating gnet with opt.step() while worker B is clearing/modifying the gradients with opt.zero_grad()/loss.backward()? However, the code just works (see the episode reward curve and the visualization)!
BTW, the state_dict does not contain any gradient info! It is an OrderedDict of the parameters' weights and biases.
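For example, a minimal sketch with a throwaway nn.Linear shows exactly what the keys look like:

import torch.nn as nn

sd = nn.Linear(3, 1).state_dict()
print(type(sd).__name__)  # OrderedDict
print(list(sd.keys()))    # ['weight', 'bias'] -- parameter values only, no gradient entries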
from pytorch-a3c.
The lnet.load_state_dict() function is shown below:
def load_state_dict(self, state_dict):
    # deepcopy, to be consistent with module API
    state_dict = deepcopy(state_dict)
    # Validate the state_dict
    groups = self.param_groups
    saved_groups = state_dict['param_groups']
It uses deepcopy to isolate the parameters from gnet, so there is no memory sharing here.
So after the 1st iteration, gp._grad = lp.grad is useless because they are already the same!
Once the update moves from one local worker to another, the gp._grad = lp.grad assignment is necessary to switch gnet over to the other worker's gradients.
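To see why the assignment has to be repeated on every push, here is a minimal sketch (not the repo's code; lnet_a and lnet_b are hypothetical local nets of two different workers, and the nn.Linear layers just stand in for the real networks):

import torch
import torch.nn as nn

gnet = nn.Linear(3, 1)                               # stand-in for the shared global net
lnet_a, lnet_b = nn.Linear(3, 1), nn.Linear(3, 1)    # two workers' local nets (hypothetical)
for net in (lnet_a, lnet_b):
    net(torch.randn(2, 3)).sum().backward()          # each worker computes its own gradients

gp = next(gnet.parameters())
gp._grad = next(lnet_a.parameters()).grad            # worker A pushes its gradients
print(gp.grad is next(lnet_b.parameters()).grad)     # False: B's gradients live in a different
                                                     # tensor, so B must run gp._grad = lp.grad
                                                     # again before calling opt.step()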
from pytorch-a3c.
Thanks for your reply! @MorvanZhou
- Yes, load_state_dict() does not make the parameters shared. I found that out and have scratched out the sentence above.
- It would be very kind of you to explain whether there might be conflicts between workers without locking the shared model:
  What if worker A is updating gnet with opt.step() while worker B is clearing/modifying the gradients with opt.zero_grad()/loss.backward()?
- Note that I made the following comments by step-by-step debugging.
# copied from continuous A3C; consider the case after the 1st iteration
opt.zero_grad()  # zero the gradients in both lnet and gnet
loss.backward()  # the parameters in both lnet and gnet now have the same gradients
for lp, gp in zip(lnet.parameters(), gnet.parameters()):  # is the for loop useless after the 1st iteration??
    # if gp.grad is not None:
    #     return  # this "if-return" check is copied from the link above
    gp._grad = lp.grad
opt.step()  # update gnet's parameters (the parameters in lnet will not change!)
lnet.load_state_dict(gnet.state_dict())  # update lnet's parameters
from pytorch-a3c.
It would be very kind of you to explain whether there might be conflicts between workers without locking the shared model.
A lock could be applied in this case, but take a look at HOGWILD! for an analysis of backprop without locking.
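For completeness, here is a minimal sketch of what a locked push could look like (an assumption about how it might be done, not what this repo does; the lock would have to be created in the parent process and passed to every worker):

import torch.multiprocessing as mp

lock = mp.Lock()  # created once in the parent and handed to each worker

# inside a worker's push step (hypothetical placement, same variables as the snippet above):
with lock:        # only one worker touches gnet / the shared optimizer at a time
    opt.zero_grad()
    loss.backward()
    for lp, gp in zip(lnet.parameters(), gnet.parameters()):
        gp._grad = lp.grad
    opt.step()
lnet.load_state_dict(gnet.state_dict())  # pull the new weights (could also go inside the lock)

The HOGWILD! argument is that when updates are sparse and small, the lock-free version loses little accuracy while avoiding the cost of serializing every push.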
from pytorch-a3c.