
pytorch-a3c's People

Contributors

jbwasse2, morvanzhou

pytorch-a3c's Issues

Error at the end of discrete and continuous runs (plt.plot)

Be gentle, this is the first issue I have ever written. :)
First, thank you for your work. I am trying to use this for a seminar paper, and this was the first repo I found that actually runs.
To the issue:
I get the following error at the end of the execution in the plt.plot(res) line:

QObject::moveToThread: Current thread (0x557e202d2ac0) is not the object's thread (0x557e1fa88a40).
Cannot move to target thread (0x557e202d2ac0)

qt.qpa.plugin: Could not load the Qt platform plugin "xcb" in "/home/oliverheidmann/.local/lib/python3.8/site-packages/cv2/qt/plugins" even though it was found.
This application failed to start because no Qt platform plugin could be initialized. Reinstalling the application may fix this problem.

Available platform plugins are: xcb, dxcb, xcb, eglfs, linuxfb, minimal, minimalegl, offscreen, vnc, wayland-egl, wayland, wayland-xcomposite-egl, wayland-xcomposite-glx.

zsh: abort (core dumped)  python continuous_A3C.py

After some googling I found that this seems to be a threading issue.
My idea is that 'res' is still bound to a thread, and that is why plt.plot(res) throws this error.
But that is just a first guess.

When I remove the plotting lines and save res with np.save("fileName", res) instead, the program runs without an error.
Any ideas?
Best regards
NodmGatall

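Cross-entropy regularization

Hello, thank you very much for providing this code; it has helped me a lot!
However, one thing confuses me: the loss function for the continuous action space includes a cross-entropy regularization term, but the discrete one does not. Why is that?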
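For reference, a minimal NumPy sketch of what an entropy bonus could look like in the discrete case, keeping the sign convention of a_loss in the discrete loss_func (exp_v is maximised, its negative is the loss). In PyTorch the per-step entropy is available as torch.distributions.Categorical(probs).entropy(). The helper name actor_loss_with_entropy and the coefficient beta=0.005 are assumptions for illustration, not code from this repository.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def actor_loss_with_entropy(logits, actions, td, beta=0.005):
    """Per-step actor loss with an entropy bonus (hypothetical helper).

    logits:  (T, A) policy logits
    actions: (T,)   integer actions taken
    td:      (T,)   advantages, treated as constants (like td.detach())
    beta:    entropy coefficient (assumed value)
    """
    probs = softmax(logits)
    log_probs = np.log(probs + 1e-12)                      # avoid log(0)
    log_pi_a = log_probs[np.arange(len(actions)), actions]
    entropy = -(probs * log_probs).sum(axis=1)             # H(pi(.|s_t)) per step
    exp_v = log_pi_a * td + beta * entropy                 # objective to maximise
    return -exp_v                                          # loss to minimise

# Uniform policy over 2 actions with zero advantage: loss is -beta * log(2) per step.
loss = actor_loss_with_entropy(np.zeros((3, 2)), np.array([0, 1, 0]), np.zeros(3))
```

A larger beta rewards flatter action distributions and therefore more exploration; with beta=0 the function reduces to the discrete a_loss.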

Process hangs at res_queue.get() on Linux

In discrete_A3C.py, the res_queue.get() call in the main function hangs for a very long time (possibly forever) on Linux, but the entire code works perfectly fine on Windows.

workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i) for i in range(mp.cpu_count())]
[w.start() for w in workers]
res = []
while True:
    print('Last printed checkpoint. Printed only during first iteration of while loop')
    r = res_queue.get()
    print('This line is never printed.')
    if r is not None:
        res.append(r)
    else:
        break
[w.join() for w in workers]

No errors are thrown, so presumably PyTorch is installed correctly and working. Inserting print() statements at various checkpoints in the snippet above (and in the Worker class constructor) reveals that the code never moves past the very first call to res_queue.get(). Is anyone else having the same problem on Linux?
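A hang at the very first get() means no worker ever produced a result, so the cause is most likely inside the workers. Two changes to the main loop still help: a timeout turns a silent hang into an exception, and counting one None sentinel per worker (instead of stopping at the first one) avoids a second problem the multiprocessing documentation warns about, namely that joining a process which has put items on a queue can deadlock until those items are consumed. A minimal stdlib sketch; worker and collect are hypothetical stand-ins for Worker.run() and the main loop.

```python
import multiprocessing as mp
import queue  # queue.Empty is raised by get(timeout=...) on a timeout

def worker(res_queue, n_items):
    """Stand-in for Worker.run(): put some results, then one None sentinel."""
    for i in range(n_items):
        res_queue.put(float(i))
    res_queue.put(None)

def collect(res_queue, n_workers, timeout=60.0):
    """Drain results until every worker has sent its sentinel.

    Raises queue.Empty if nothing arrives within `timeout` seconds
    instead of blocking forever.
    """
    res, finished = [], 0
    while finished < n_workers:
        r = res_queue.get(timeout=timeout)
        if r is None:
            finished += 1
        else:
            res.append(r)
    return res

if __name__ == "__main__":
    res_queue = mp.Queue()
    procs = [mp.Process(target=worker, args=(res_queue, 3)) for _ in range(4)]
    for p in procs:
        p.start()
    try:
        res = collect(res_queue, len(procs))  # drain *before* joining
    except queue.Empty:
        raise SystemExit("no result within the timeout; a worker is probably stuck")
    for p in procs:
        p.join()
    print(len(res))  # 12
```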

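Is it possible to share memory on CUDA/GPU?

Hello, in the continuous_A3C code you run the parallel computation on the CPU. I found that if I put gnet on CUDA, all of gnet's initialized weights become 0 after process.start(). What is the reason for this? Can the sharing be done on the GPU?
Before w.start():
[screenshot]
After w.start():
[screenshot]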

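Question about shared_adam

Hello, Morvan! How should shared_adam be written? I looked at several other people's A3C code and found that they are all much the same: apart from share_memory, they more or less copy the Adam source code from the optimizer module. So I'm a bit confused: when writing my own shared_adam, what exactly should I "copy"? Thank you, Morvan!!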

About the lock in A3C multiprocessing

Excuse me, I would like to know whether a lock is needed in the multiprocessing version of A3C. I have seen some A3C implementations that use a lock when a single worker updates the gradients of the shared model. Is the lock necessary?
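There are two kinds of shared state here. A counter such as global_ep is updated with +=, which is a read-modify-write, so concurrent workers can lose increments unless each update holds the Value's lock. The gradient push to the shared model, by contrast, is commonly done without a lock (Hogwild!-style), accepting that workers occasionally overwrite each other's updates; a lock there is a trade-off, not a correctness requirement. A stdlib sketch of the counter case; the function name add_episodes is hypothetical.

```python
import multiprocessing as mp

def add_episodes(counter, n):
    """Increment a shared counter n times, holding its lock for each update."""
    for _ in range(n):
        with counter.get_lock():  # += on mp.Value is read-modify-write
            counter.value += 1

if __name__ == "__main__":
    counter = mp.Value("i", 0)
    procs = [mp.Process(target=add_episodes, args=(counter, 1000)) for _ in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(counter.value)  # 4000
```

Without the with counter.get_lock() line, the final count can come out below 4000 when processes interleave between the read and the write.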

Some questions about the implementation of Net

Hi, Morvan:

Here, why do we need to call self.eval()? What does it actually do?

The PyTorch official documentation says:

Sets the module in evaluation mode. This has any effect only on modules such as Dropout or BatchNorm.

The first question is: this implementation has no Dropout or BatchNorm layer, so does self.eval() have no effect?

The same situation appears in line 52: self.train()

The other question is about the usage of self.eval() and self.train(): if the network does contain Dropout or BatchNorm layers, what does the code above actually do?
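Since Net contains neither Dropout nor BatchNorm, self.eval() and self.train() there only flip the module's training attribute and leave the outputs unchanged. For layers that do depend on that flag, a NumPy sketch of the two behaviours; the function names are hypothetical and BatchNorm's learnable scale and shift are omitted.

```python
import numpy as np

def dropout(x, p, training, rng):
    """Inverted dropout, as nn.Dropout behaves in train() vs eval() mode."""
    if not training or p == 0.0:
        return x                            # eval(): identity
    keep = rng.random(x.shape) >= p         # keep each unit with probability 1 - p
    return x * keep / (1.0 - p)             # rescale so the expected value is unchanged

def batch_norm(x, running_mean, running_var, training, momentum=0.1, eps=1e-5):
    """BatchNorm over axis 0 without the affine part.

    train(): normalise with batch statistics and update the running
    statistics in place; eval(): normalise with the running statistics.
    """
    if training:
        mean, var = x.mean(axis=0), x.var(axis=0)   # biased variance for normalising
        running_mean[:] = (1 - momentum) * running_mean + momentum * mean
        # the running estimate uses the unbiased variance
        running_var[:] = (1 - momentum) * running_var + momentum * x.var(axis=0, ddof=1)
    else:
        mean, var = running_mean, running_var
    return (x - mean) / np.sqrt(var + eps)
```

In eval mode both functions are deterministic, which is why calling eval() before acting and train() before computing the loss matters once such layers are present.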

Unstable Performance

Hi Morvan, thank you so much for sharing this code. However, I find the performance unstable when running the code in both the discrete and continuous settings, and the average reward is much lower than in the graph you drew. Is there a reason for that?

Cannot work when network size gets bigger

Hi @MorvanZhou. Thank you for your tutorial. I'm trying to modify the A3C from CartPole to MsPacman. I found that after I change the network to a CNN, the code gets stuck in the forward function. It runs on Mac without problems, but gets stuck on Linux. To illustrate the problem, I simply changed N_S to 10000 in discrete_A3C.py and used a randomly generated NumPy vector as the state. It also gets stuck in the forward function, without any warning or error message. Do you have any ideas?

"""
Reinforcement Learning (A3C) using PyTorch + multiprocessing.
The simplest implementation for discrete action.

View more on my Chinese tutorial page [莫烦Python](https://morvanzhou.github.io/).
"""

import torch
import numpy as np
import torch.nn as nn
from utils import v_wrap, set_init, push_and_pull, record
import torch.nn.functional as F
import torch.multiprocessing as mp
from shared_adam import SharedAdam
import gym
import os
os.environ["OMP_NUM_THREADS"] = "1"

UPDATE_GLOBAL_ITER = 10
GAMMA = 0.9
MAX_EP = 4000

env = gym.make('CartPole-v0')
N_S = 10000
N_A = env.action_space.n


class Net(nn.Module):
    def __init__(self, s_dim, a_dim):
        super(Net, self).__init__()
        self.s_dim = s_dim
        self.a_dim = a_dim
        self.pi1 = nn.Linear(s_dim, 100)
        self.pi2 = nn.Linear(100, a_dim)
        self.v1 = nn.Linear(s_dim, 100)
        self.v2 = nn.Linear(100, 1)
        set_init([self.pi1, self.pi2, self.v1, self.v2])
        self.distribution = torch.distributions.Categorical

    def forward(self, x):
        pi1 = F.relu(self.pi1(x))
        logits = self.pi2(pi1)
        v1 = F.relu(self.v1(x))
        values = self.v2(v1)
        return logits, values

    def choose_action(self, s):
        self.eval()
        logits, _ = self.forward(s)
        prob = F.softmax(logits, dim=1).data
        m = self.distribution(prob)
        return m.sample().numpy()[0]

    def loss_func(self, s, a, v_t):
        self.train()
        logits, values = self.forward(s)
        td = v_t - values
        c_loss = td.pow(2)
        
        probs = F.softmax(logits, dim=1)
        m = self.distribution(probs)
        exp_v = m.log_prob(a) * td.detach().squeeze()
        a_loss = -exp_v
        total_loss = (c_loss + a_loss).mean()
        return total_loss


class Worker(mp.Process):
    def __init__(self, gnet, opt, global_ep, global_ep_r, res_queue, name):
        super(Worker, self).__init__()
        self.name = 'w%i' % name
        self.g_ep, self.g_ep_r, self.res_queue = global_ep, global_ep_r, res_queue
        self.gnet, self.opt = gnet, opt
        self.lnet = Net(N_S, N_A)           # local network
        self.env = gym.make('MsPacman-v0').unwrapped

    def run(self):
        total_step = 1
        while self.g_ep.value < MAX_EP:
            s = self.env.reset()
            s = np.random.rand(N_S)
            buffer_s, buffer_a, buffer_r = [], [], []
            ep_r = 0.
            while True:
                if self.name == 'w0':
                    self.env.render()
                a = self.lnet.choose_action(v_wrap(s[None, :]))
                s_, r, done, _ = self.env.step(a)
                s_ = np.random.rand(N_S)
                if done: r = -1
                ep_r += r
                buffer_a.append(a)
                buffer_s.append(s)
                buffer_r.append(r)

                if total_step % UPDATE_GLOBAL_ITER == 0 or done:  # update global and assign to local net
                    # sync
                    push_and_pull(self.opt, self.lnet, self.gnet, done, s_, buffer_s, buffer_a, buffer_r, GAMMA)
                    buffer_s, buffer_a, buffer_r = [], [], []

                    if done:  # done and print information
                        record(self.g_ep, ep_r, self.res_queue, self.name, 1, 0)
                        break
                s = s_
                total_step += 1
        self.res_queue.put(None)


if __name__ == "__main__":
    gnet = Net(N_S, N_A)        # global network
    gnet.share_memory()         # share the global parameters in multiprocessing
    opt = SharedAdam(gnet.parameters(), lr=0.0001)      # global optimizer
    global_ep, global_ep_r, res_queue = mp.Value('i', 0), mp.Value('d', 0.), mp.Queue()

    # parallel training
    workers = [Worker(gnet, opt, global_ep, global_ep_r, res_queue, i) for i in range(mp.cpu_count())]
    [w.start() for w in workers]
    res = []                    # record episode reward to plot
    while True:
        r = res_queue.get()
        if r is not None:
            res.append(r)
        else:
            break
    [w.join() for w in workers]

    import matplotlib.pyplot as plt
    plt.plot(res)
    plt.ylabel('Moving average ep reward')
    plt.xlabel('Step')
    plt.show()
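One commonly suggested mitigation, offered as a hedged sketch: it assumes the hang comes from CPU thread pools (OpenMP/MKL) inside forked worker processes rather than from the network itself, which is not confirmed for this repository. Note that the script above sets OMP_NUM_THREADS only after import torch, which is generally too late for OpenMP to pick it up. The helper name limit_threads is hypothetical.

```python
import os

def limit_threads(n=1):
    """Cap OpenMP/MKL thread pools (hypothetical helper).

    Call this before importing torch: OpenMP generally reads these
    variables when the library is first loaded, so later changes may
    have no effect.
    """
    for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
        os.environ[var] = str(n)

limit_threads(1)

try:
    import torch
    torch.set_num_threads(1)  # also cap PyTorch's intra-op thread pool
except ImportError:           # keeps the sketch runnable without PyTorch
    torch = None
```

Calling torch.set_num_threads(1) as the first statement of Worker.run() applies the same cap inside each child. Another option to try is mp.set_start_method("spawn") inside the if __name__ == "__main__": block, which starts workers without inheriting the parent's thread state through fork.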

The discrete_A3C.py is not working...

I got this error:

Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/Users/Petrelli/Desktop/Research/code/pytorch-A3C/continuous_A3C.py", line 88, in run
a = self.lnet.choose_action(v_wrap(s[None, :]))
TypeError: tuple indices must be integers or slices, not tuple

Then I changed line 85 to a = self.lnet.choose_action(v_wrap(s[None :]))

However, then another error occurred:

Traceback (most recent call last):
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/Users/Petrelli/Desktop/Research/code/pytorch-A3C/discrete_A3C.py", line 85, in run
a = self.lnet.choose_action(v_wrap(s[None :]))
File "/Users/Petrelli/Desktop/Research/code/pytorch-A3C/utils.py", line 11, in v_wrap
if np_array.dtype != dtype:
AttributeError: 'tuple' object has no attribute 'dtype'

I'm not sure whether the Python version is the cause (I ran the code with Python 3.9 and 3.10 on a Mac).
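The first traceback suggests that env.reset() returned a tuple. Since gym 0.26, reset() returns (observation, info) and step() returns five values (observation, reward, terminated, truncated, info) instead of four, so a newer gym release is a more likely cause than the Python version. Changing s[None, :] to s[None :] merely slices the tuple, which explains the second error. A hedged compatibility sketch; the helper names reset_obs and step4 are hypothetical.

```python
def reset_obs(env):
    """Return just the observation from env.reset() on old and new gym APIs."""
    out = env.reset()
    # gym >= 0.26 returns (obs, info); older gym returns obs alone.
    # Assumption: the observation itself is never a 2-tuple ending in a dict.
    if isinstance(out, tuple) and len(out) == 2 and isinstance(out[1], dict):
        return out[0]
    return out

def step4(env, action):
    """Return the old 4-tuple (obs, reward, done, info) on either API."""
    out = env.step(action)
    if len(out) == 5:  # gym >= 0.26: (obs, reward, terminated, truncated, info)
        obs, reward, terminated, truncated, info = out
        return obs, reward, terminated or truncated, info
    return out         # older gym: (obs, reward, done, info)
```

In Worker.run() this would read s = reset_obs(self.env) and s_, r, done, _ = step4(self.env, a), after which s is an array again and s[None, :] works. Pinning gym to a version below 0.26 is the alternative.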

About total_loss = (a_loss + c_loss).mean()

I suppose that in the paper the authors compute the actor and critic losses separately. Does it really work better to use total_loss = (a_loss + c_loss).mean() in loss_func()?

Thanks
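In the Net listing on this page, the actor head (pi1, pi2) and the critic head (v1, v2) have disjoint parameters, and the advantage enters the actor term through td.detach(). Under those two conditions, summing the losses before a single backward pass yields exactly the gradients of the two separate losses. Writing theta_pi and theta_v for the two parameter sets and sg for stop-gradient (detach):

```latex
\begin{aligned}
\mathcal{L}(\theta_\pi,\theta_v)
  &= \frac{1}{T}\sum_{t=1}^{T}\Big[\underbrace{\delta_t^{2}}_{\text{c\_loss}}
     \;\underbrace{-\,\log\pi_{\theta_\pi}(a_t\mid s_t)\,\operatorname{sg}(\delta_t)}_{\text{a\_loss}}\Big],
  \qquad \delta_t = v_t - V_{\theta_v}(s_t) \\
\nabla_{\theta_\pi}\mathcal{L}
  &= -\frac{1}{T}\sum_{t=1}^{T}\delta_t\,\nabla_{\theta_\pi}\log\pi_{\theta_\pi}(a_t\mid s_t)
   = \nabla_{\theta_\pi}\,\frac{1}{T}\sum_{t=1}^{T}\text{a\_loss}_t \\
\nabla_{\theta_v}\mathcal{L}
  &= -\frac{2}{T}\sum_{t=1}^{T}\delta_t\,\nabla_{\theta_v}V_{\theta_v}(s_t)
   = \nabla_{\theta_v}\,\frac{1}{T}\sum_{t=1}^{T}\text{c\_loss}_t
\end{aligned}
```

If the two heads shared a trunk, both gradients would mix in the shared layers, and implementations commonly put a weighting coefficient on the critic term; only then does the relative scale of the two losses matter.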

About record() function

Why is the episode reward recorded this way? It makes the reward curve look nice, but it does not reflect the actual rewards.
Why not just record the value of ep_r?

if global_ep_r.value == 0.:
    global_ep_r.value = ep_r
else:
    global_ep_r.value = global_ep_r.value * 0.99 + ep_r * 0.01
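The update above is an exponential moving average: each new episode contributes 1% and the history 99%, so roughly the last hundred episodes dominate and the curve is smoothed and lags the raw returns (the plotting code labels it "Moving average ep reward"). A minimal sketch that keeps both series; ema_update is a hypothetical helper, and it uses None rather than 0.0 as the "not yet set" marker, because the original check also fires whenever the average happens to be exactly 0.

```python
def ema_update(prev, ep_r, alpha=0.01):
    """One step of the moving average used by record(); prev is None at first."""
    if prev is None:
        return ep_r
    return prev * (1.0 - alpha) + ep_r * alpha

raw, smoothed = [], []
avg = None
for ep_r in [10.0, 0.0, 0.0]:
    avg = ema_update(avg, ep_r)
    raw.append(ep_r)       # what actually happened in each episode
    smoothed.append(avg)   # what the reward plot shows
```

Plotting raw and smoothed together shows both the per-episode variance and the trend.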

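Question about the loss function

In loss_function:

exp_v = m.log_prob(a) * td.detach()

log_prob is [prob1, prob2, prob3]
td is [[value1, value2, value3]]
Multiplying them directly gives a two-dimensional matrix, but in A2C shouldn't each step's advantage A be multiplied by that same step's actor_loss term?
Should it be changed to

exp_v = m.log_prob(a) * td.detach()[0]

This is the reward after I changed the loss function; it seems to be more stable?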
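The shape issue can be reproduced with NumPy, which follows the same broadcasting rules as PyTorch. values comes out of nn.Linear(100, 1), so td is 2-D; assuming it is a (T, 1) column, multiplying it by the (T,) log-probabilities broadcasts to a (T, T) matrix. With a column, td[0] selects only the first step's advantage, whereas squeezing the column (as td.detach().squeeze() does in the discrete_A3C.py listing on this page) gives the intended per-step product. The numbers below are made up.

```python
import numpy as np

# Assumed shapes, as in loss_func: m.log_prob(a) -> (T,), td = v_t - values -> (T, 1)
log_prob = np.log(np.array([0.5, 0.25, 0.8]))  # log pi(a_t | s_t), shape (3,)
td = np.array([[1.0], [-2.0], [0.5]])          # advantages, shape (3, 1)

wrong = log_prob * td              # broadcasts to (3, 3): every pair of steps
first_only = log_prob * td[0]      # shape (3,), but step 0's advantage for every step
right = log_prob * td.squeeze(1)   # shape (3,): step t's log-prob times step t's advantage
```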

Error: No handlers could be found for logger "werkzeug"

Hi Morvan,
I am trying to implement A3C in another simulator, based on your continuous, multiprocess A3C algorithm.
However, the following error showed up when running it:

No handlers could be found for logger "werkzeug"

and after pressing Ctrl+C it showed:

Traceback (most recent call last):
File "continuous_A3C.py", line 147, in <module>
r = res_queue.get()
File "/usr/lib/python2.7/multiprocessing/queues.py", line 117, in get
res = self._recv()
File "/usr/local/lib/python2.7/dist-packages/torch/multiprocessing/queue.py", line 21, in recv
buf = self.recv_bytes()
KeyboardInterrupt
Process w0:
Traceback (most recent call last):
File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
Process w1:
Traceback (most recent call last):
File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
self.run()
File "continuous_A3C.py", line 113, in run
s_, r, done = self.env.forward_action(mov_mag, rot_mag)
File "/home/elizabeth/continunous_marvon/environment_ai2thor.py", line 90, in forward_action
self.run()
File "continuous_A3C.py", line 113, in run
self.controller.step(dict(action='MoveAhead', moveMagnitude=move_mag, snapToGrid=False))
s_, r, done = self.env.forward_action(mov_mag, rot_mag)
File "/usr/local/lib/python2.7/dist-packages/ai2thor/controller.py", line 537, in step
File "/home/elizabeth/continunous_marvon/environment_ai2thor.py", line 90, in forward_action
self.controller.step(dict(action='MoveAhead', moveMagnitude=move_mag, snapToGrid=False))
File "/usr/local/lib/python2.7/dist-packages/ai2thor/controller.py", line 537, in step
self.last_event = queue_get(self.request_queue)
File "/usr/local/lib/python2.7/dist-packages/ai2thor/server.py", line 39, in queue_get
self.last_event = queue_get(self.request_queue)
File "/usr/local/lib/python2.7/dist-packages/ai2thor/server.py", line 39, in queue_get
res = que.get(block=True, timeout=0.5)
File "/usr/lib/python2.7/Queue.py", line 177, in get
res = que.get(block=True, timeout=0.5)
File "/usr/lib/python2.7/Queue.py", line 177, in get
self.not_empty.wait(remaining)
File "/usr/lib/python2.7/threading.py", line 359, in wait
self.not_empty.wait(remaining)
File "/usr/lib/python2.7/threading.py", line 359, in wait
_sleep(delay)
KeyboardInterrupt
_sleep(delay)
KeyboardInterrupt

Have you encountered this before?
Is res = que.get(block=True, timeout=0.5) the reason it gets stuck?
Do you have any idea how to solve this?
Looking forward to your reply. Thanks a lot!
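In Python 2, 'No handlers could be found for logger "werkzeug"' only means that something logged to the werkzeug logger (werkzeug is the WSGI library behind Flask, which the simulator's server presumably uses) while no logging handler was configured. The message is therefore a warning, not the cause of the hang. A minimal sketch that configures a root handler at the top of the script so those messages are actually printed, which may reveal what the server is reporting:

```python
import logging

# Attach a handler to the root logger; loggers such as "werkzeug"
# propagate their records to it, so their messages become visible.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(processName)s %(name)s %(levelname)s: %(message)s",
)

logging.getLogger("werkzeug").info("werkzeug records now reach the root handler")
```

The processName field makes it easier to tell which worker (w0, w1, ...) emitted each line.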

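Cannot extend to other envs

Mr. Zhou, I've run into a very strange problem: I zero-padded every variable related to the dimension of s in the PyTorch A3C, i.e. widened the env's s from the original 4 to 2048. After that, the forward function in the code raises no error, but it gets stuck inside that function. I have checked carefully and the network has no dimension mismatch, but after repeated attempts I still can't fix this. Do you have any ideas?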

Question about the push_and_pull function

Hi, Morvan,

In your push_and_pull function, you update the global grad unconditionally, see https://github.com/MorvanZhou/pytorch-A3C/blob/master/discrete_A3C.py#L95. However, other A3C implementations that are quite similar to yours update the global grad only when it is None. See https://github.com/greydanus/baby-a3c/blob/master/baby-a3c.py#L159 and https://github.com/ikostrikov/pytorch-a3c/blob/master/train.py#L13. I am quite confused about this. Would you mind sharing your insight?

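Multiprocessing and multithreading questions

Recently I've been writing PyTorch training code with multithreading and multiprocessing.

I'd like to ask: if I train with multiple threads (threading), do I still need the share memory step?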
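share_memory() (Module.share_memory / Tensor.share_memory_) moves tensor storage into shared memory so that other processes can access it. Threads already share their process's memory, so a model used from several threads needs no such step; what they still need is a lock around read-modify-write updates. A sketch with plain Python objects (all names hypothetical); with PyTorch, keep in mind that the GIL serialises Python-level work, although many tensor operations release it.

```python
import threading

# Threads share one address space: every thread sees the same `shared`
# object, so nothing has to be moved into shared memory first.
shared = {"updates": 0}
lock = threading.Lock()

def train_step_stub(n):
    """Stand-in for a training loop that updates shared state n times."""
    for _ in range(n):
        with lock:                 # += is read-modify-write; guard it
            shared["updates"] += 1

threads = [threading.Thread(target=train_step_stub, args=(1000,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(shared["updates"])  # 4000
```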
