Comments (21)
@MatheusMRFM, I have managed to modify the A3C LSTM implementation to work for Pong using OpenAI Gym.
I have not observed the -20-reward-forever problem that you are seeing. The most common problem I observed was that the agent would train up to a mean reward of -15, then the gradient would explode and everything would be forgotten. This was mitigated by increasing the size of the LSTM episode buffer from 30 to 100.
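For reference, the buffer size is just the rollout-length check in the worker loop. A sketch of the relevant block using the notebook's variable names (the 100 is simply the value that worked for me):

```python
# Inside Worker.work(), after appending each transition (notebook-style names).
# Training every 100 steps instead of 30 gives the LSTM longer rollouts and,
# in my runs, stopped the gradient exploding around a mean reward of -15.
if len(episode_buffer) == 100 and not d and episode_step_count != max_episode_length - 1:
    # The episode hasn't finished, so bootstrap the return from the
    # current value estimate.
    v1 = sess.run(self.local_AC.value,
                  feed_dict={self.local_AC.inputs: [s],
                             self.local_AC.state_in[0]: rnn_state[0],
                             self.local_AC.state_in[1]: rnn_state[1]})[0, 0]
    v_l, p_l, e_l, g_n, v_n = self.train(episode_buffer, sess, gamma, v1)
    episode_buffer = []
```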
I have been training this agent for about 16 hours now and the mean reward is +13. Although I am pleased to have this working, my training time is a lot slower than what was reported in DeepMind's asynchronous methods paper.
If you need some more hints, take a look at my changes.
pong_a3c_train.py
Thanks for the response!
Actually, I found one tiny detail that made a huge impact in the training process:
- In the current implementation, the 'train' function uses rnn_state = self.local_AC.state_init as the LSTM state when updating the global network. If you check OpenAI's A3C, they instead use the LSTM state from the end of the current worker's previous batch (unless the episode has ended).
Therefore, I corrected this by storing the LSTM state with the batch and recovering the last LSTM state before updating the global network. You can see this in my updated repository on GitHub, in the 'train' method inside Worker.py.
With this, I managed to get a mean reward of +15 within 17k episodes in Pong. This took me around one or two days using 8 threads, so it is still very slow as well.
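In code, the fix amounts to something like this sketch of the 'train' method (batch_rnn_state is just the name I use for the carried-over state; yours may differ):

```python
def train(self, rollout, sess, gamma, bootstrap_value):
    # ... compute discounted_rewards, observations, actions, advantages
    #     from the rollout, exactly as before ...

    # Old behaviour: every update ran the LSTM from the zero state,
    # as if each batch were the start of a fresh episode.
    # rnn_state = self.local_AC.state_init

    # Fix: resume from the state left over from the previous batch.
    # work() resets self.batch_rnn_state to state_init at episode start.
    feed_dict = {self.local_AC.target_v: discounted_rewards,
                 self.local_AC.inputs: np.vstack(observations),
                 self.local_AC.actions: actions,
                 self.local_AC.advantages: advantages,
                 self.local_AC.state_in[0]: self.batch_rnn_state[0],
                 self.local_AC.state_in[1]: self.batch_rnn_state[1]}
    # Also fetch state_out so the next update continues from this state.
    v_l, p_l, e_l, g_n, v_n, self.batch_rnn_state, _ = sess.run(
        [self.local_AC.value_loss, self.local_AC.policy_loss,
         self.local_AC.entropy, self.local_AC.grad_norms,
         self.local_AC.var_norms, self.local_AC.state_out,
         self.local_AC.apply_grads],
        feed_dict=feed_dict)
    return v_l, p_l, e_l, g_n, v_n
```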
Good to hear that you have made it work, and thanks for sharing your code.
Interesting find. It looks like my code matches what you are doing with the LSTM state, although I might be missing something.
I suspect that the slow training times we are seeing are the result of non-optimal hyperparameters, or some sort of inefficiency in the code.
My training progress with Pong can be seen in the attached image. 18 hours is about 460 epochs running over 16 cores + 1 GPU.
Thanks @BenPorebski and @MatheusMRFM for pointing this out. I'm really glad you've both figured this out, as I know it was giving a number of people issues. I've now updated the code in the Jupyter notebook to reflect the proper way of remembering the LSTM state during training.
Thanks for the update and for writing the initial agent @awjuliani!
I spent quite some time trying to implement A3C from scratch without any luck. Your code definitely put me on the right track.
@awjuliani, I'm just testing out your changes, but unpacking the sess.run call with self.batch_rnn_state throws an error from other calls to train. Not sure if this is just broken in my code.
```
v_l,p_l,e_l,g_n,v_n, self.batch_rnn_state,_ = sess.run([self.local_AC.value_loss, self.local_AC.policy_loss, self.local_AC.entropy,

  File "tf_a3c_parallel.py", line 319, in
    worker_work = lambda: worker.work(max_episode_length, gamma, sess, coord, saver)
  File "tf_a3c_parallel.py", line 233, in work
    v_l,p_l,e_l,g_n,v_n = self.train(episode_buffer, sess, gamma, v1)
  File "tf_a3c_parallel.py", line 175, in train
    feed_dict=feed_dict)
ValueError: not enough values to unpack (expected 7, got 6)
```
I have removed the self.batch_rnn_state from the unpack and it seems to be running, but I'm not sure if it has undone your intended changes.
@BenPorebski I had that problem because I missed adding one input var to that line; you can check the history of the file to see (it's a pain because of the notebook formatting):
```python
v_l, p_l, e_l, g_n, v_n, self.batch_rnn_state, _ = sess.run(
    [self.local_AC.value_loss,
     self.local_AC.policy_loss,
     self.local_AC.entropy,
     self.local_AC.grad_norms,
     self.local_AC.var_norms,
     self.local_AC.state_out,
     self.local_AC.apply_grads],
    feed_dict=feed_dict)
```
@DMTSource, oh, of course!! Thank you.
Sorry @awjuliani, ignore my last.
@BenPorebski is your program applicable to other Atari games?
@chrisplyn, I've not tested it, but it probably would work for the other Atari games. You might need to double check that the number of actions is set correctly for the new environment.
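If it helps, you can read the action count straight off the environment instead of hard-coding it (standard Gym API; 'Breakout-v0' here is just an example):

```python
import gym

env = gym.make('Breakout-v0')       # swap in whichever Atari game you want
a_size = env.action_space.n         # feed this to the network's action output
print(a_size, env.observation_space.shape)
```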
@BenPorebski Thanks Ben, I am also trying to leverage your code to play flappybird-v0. I don't understand the LSTM part of your code; can you refer me to any tutorial I can read?
@chrisplyn, it's been a while since I was last playing with this code, so I'm not sure I understand it well enough to explain. But the general idea of the LSTM is for it to function as a short-term memory of the last 50 frames, which is set on line 227. I might be wrong, but I believe it makes the computation a bit quicker than feeding an 84x84x50 array through the entire network.
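For a rough picture, this is the shape of the idea in TF1 style (a sketch, not the exact code from the script):

```python
import tensorflow as tf  # TF 1.x, as in the original implementation

# One 84x84 frame goes through the conv layers per timestep; the LSTM's
# 256-unit recurrent state carries the memory of earlier frames, instead of
# pushing an explicit 84x84x50 stack through the whole network every step.
frame = tf.placeholder(tf.float32, [None, 84, 84, 1])
conv = tf.layers.conv2d(frame, 16, 8, strides=4, activation=tf.nn.elu)
conv = tf.layers.conv2d(conv, 32, 4, strides=2, activation=tf.nn.elu)
feat = tf.layers.dense(tf.layers.flatten(conv), 256, activation=tf.nn.elu)

cell = tf.nn.rnn_cell.BasicLSTMCell(256)
rnn_in = tf.expand_dims(feat, 0)              # [1, time, 256]
state_in = cell.zero_state(1, tf.float32)     # carried between batches
rnn_out, state_out = tf.nn.dynamic_rnn(cell, rnn_in, initial_state=state_in)
```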
If you are after a bit more detail than that, I might have some time tomorrow to have a bit of a play. Maybe it will refresh my memory.
@BenPorebski Thanks a lot Ben. Based on your experiments, does A3C work better than DQN techniques for Atari games?
@chrisplyn In my limited experience playing with RL, I found A3C to be more stable than DQN. It could be that I botched my DQN implementation, but it struggled to learn to play Pong with a positive score. The A3C implementation that I posted does work; however, I still found replicate training runs that occasionally go out of control and forget everything they have learnt.
I don't fully understand why this happens. I'm not sure if it is something specific to my implementation of the RL algorithms, or if it is a common experience.
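The only general guard I know of for that kind of blow-up is clipping the global gradient norm before the update, which I believe this implementation already does (with a threshold of 40). A minimal sketch of the pattern, with stand-in variables for the real network and loss:

```python
import tensorflow as tf  # TF 1.x

# Stand-ins for the real global and worker networks; the scope names and
# the 40.0 threshold follow my reading of the notebook, so treat them as
# assumptions rather than the exact code.
with tf.variable_scope('global'):
    w_g = tf.get_variable('w', [4, 2])
with tf.variable_scope('worker_0'):
    w_l = tf.get_variable('w', [4, 2])
    x = tf.placeholder(tf.float32, [None, 4])
    loss = tf.reduce_sum(tf.square(tf.matmul(x, w_l)))  # stand-in A3C loss

local_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'worker_0')
global_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'global')
grads = tf.gradients(loss, local_vars)
# Rescale the whole gradient vector when its norm exceeds the threshold.
clipped, grad_norms = tf.clip_by_global_norm(grads, 40.0)
# Workers compute gradients locally but apply them to the global network.
trainer = tf.train.AdamOptimizer(1e-4)
apply_grads = trainer.apply_gradients(zip(clipped, global_vars))
```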
@BenPorebski Hi Ben, do you know how to do separate testing using the trained model?
@chrisplyn, Hi Chris, do you mean having the trained model play without learning?
@BenPorebski yeah, I am not quite familiar with TensorFlow; can you share your code for doing this?
@chrisplyn, so I think this worked for playing Pong last time I checked: https://gist.github.com/BenPorebski/0df9a2da264bdc33aec26f0809685d8a
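The core of it is just restoring the checkpoint and running the policy greedily without ever calling train(). A rough sketch, where model_path, env, network, and process_frame stand in for whatever your training script defines:

```python
import numpy as np
import tensorflow as tf  # TF 1.x

# Hypothetical evaluation loop: restore the trained weights and act from
# the policy, never updating the network.
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    s = process_frame(env.reset())
    rnn_state = network.state_init   # zero LSTM state at episode start
    done = False
    while not done:
        a_dist, v, rnn_state = sess.run(
            [network.policy, network.value, network.state_out],
            feed_dict={network.inputs: [s],
                       network.state_in[0]: rnn_state[0],
                       network.state_in[1]: rnn_state[1]})
        a = np.argmax(a_dist[0])               # greedy instead of sampling
        s_next, r, done, _ = env.step(a)
        s = process_frame(s_next)
```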
@BenPorebski Hi Ben, did you ever find the reason why it happens that "I still found replicate training runs to occasionally go out of control and forget everything it has learnt"?
It also happens to me with another A3C implementation, and I would like to know whether it is a common experience, a result of A3C's instability, or down to the way A3C was implemented.
@rjgarciap Hi Ricardo, I have not thoroughly explored this any further with my implementation. However, this does seem to be a very common experience in reinforcement learning [1][2].
@BenPorebski Thank you very much.