
deeprl-chinese's Issues

15_mac_a2c.py: the target network has no effect

Chapter 15 intends to use a target network to compute the TD target, but the code below takes the reward as already covering the predicted value on its own, which makes defining the target network pointless:

        # Note: in simple_spread_v2, the reward is computed from the distance between the current state and the target position, so using the reward directly as the TD target is considered more appropriate here.
        # with torch.no_grad():
        #     td_value = self.target_value(bns).squeeze()
        #     td_value = br + self.gamma * td_value * (1 - bd)

A separate question: on page 127 (Chapter 8), the target network is updated using a hyperparameter r that blends it toward the online network. If I instead set a sync frequency, i.e. synchronize the target network once every n steps, rather than updating it by a fixed proportion every step as in the book, what difference does that make to training performance?
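For comparison, a minimal sketch of the two update schemes being discussed (the names tau and sync_every are illustrative, not the book's):

import torch
import torch.nn as nn

def soft_update(target_net: nn.Module, online_net: nn.Module, tau: float = 0.01):
    # Soft (Polyak) update: every step, move the target weights a small
    # fraction tau toward the online weights (the proportional update the book calls r).
    with torch.no_grad():
        for t_param, o_param in zip(target_net.parameters(), online_net.parameters()):
            t_param.mul_(1.0 - tau).add_(tau * o_param)

def hard_update(target_net: nn.Module, online_net: nn.Module, step: int, sync_every: int = 1000):
    # Hard (periodic) update: copy the online weights wholesale every sync_every steps.
    if step % sync_every == 0:
        target_net.load_state_dict(online_net.state_dict())

Both schemes are common in practice: the soft update lets the target drift slowly every step, while the hard update holds it fixed between syncs; which one trains better usually has to be checked empirically.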

A redundant Categorical operation in 15_mac_a2c.py

class MAC(nn.Module):
    def policy(self, observation, agent):
        # See https://pytorch.org/docs/stable/distributions.html#score-function
        log_prob_action = self.agent2policy[agent].policy(observation)
        m = Categorical(logits=log_prob_action)  # should be constructed with probs= instead
        action = m.sample()
        log_prob_a = m.log_prob(action) 
        return action.item(), log_prob_a

The policy function defined above returns normalized probabilities and their normalized logarithms, so when constructing the Categorical object the keyword argument should be probs, not logits:

m = Categorical(probs=log_prob_action)
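As an aside, a minimal sketch of the difference between the two keyword arguments (PyTorch's keyword is probs, not prob); passing log-probabilities to logits= is also valid, since logits are treated as unnormalized log-probabilities, so which constructor is correct depends on what the policy actually returns:

import torch
from torch.distributions import Categorical

p = torch.tensor([0.1, 0.2, 0.7])      # normalized probabilities
log_p = p.log()                         # their logarithms

d_probs = Categorical(probs=p)          # probs= expects probabilities
d_logits = Categorical(logits=log_p)    # logits= expects (unnormalized) log-probabilities
print(torch.allclose(d_probs.probs, d_logits.probs))  # True: both describe the same distribution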

Training loss oscillates wildly and shows no downward trend

Hello, I trained DQN with the current code and the default hyperparameters. After the warm-up phase the training loss jumps sharply and keeps oscillating violently between roughly 1e+0 and 5e+2 (see the figures below); it does not converge within 100_000 steps, and there is no sign it would decrease with more training. Also, although episode_reward and episode_length improve on average as epsilon decays, they oscillate just as violently as the loss. May I ask which hyperparameters I should try adjusting, and in which direction? Thanks!

(Figures: loss, episode_reward, and episode_length versus training steps.)
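Not the book's answer, but common knobs to try when a DQN loss oscillates like this are a smaller learning rate, a Huber loss in place of MSE, gradient clipping, and adjusting the target-network sync frequency. A self-contained sketch of the first three (the shapes and names below are stand-ins, not the repo's code):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the Q network, optimizer, and a batch of TD targets.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # smaller learning rate

qvals = model(torch.randn(32, 4)).gather(1, torch.zeros(32, 1, dtype=torch.long)).squeeze(1)
td_target = torch.randn(32)

# Huber loss is less sensitive than MSE to the occasional huge TD error,
# which is often what produces loss spikes of the kind described above.
loss = F.smooth_l1_loss(qvals, td_target)

optimizer.zero_grad()
loss.backward()
# Gradient clipping keeps a single bad batch from blowing up the weights.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
optimizer.step()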

04_dqn.py has a bug

When I run the file with the command python .\04_dqn.py --do_train, I get:

Traceback (most recent call last):
  File "E:\code\github\wangshusen\DeepRL-Chinese\04_dqn.py", line 242, in <module>
    main()
  File "E:\code\github\wangshusen\DeepRL-Chinese\04_dqn.py", line 235, in main
    train(args, env, agent)
  File "E:\code\github\wangshusen\DeepRL-Chinese\04_dqn.py", line 129, in train
    action = agent.get_action(torch.from_numpy(state))
  File "E:\code\github\wangshusen\DeepRL-Chinese\04_dqn.py", line 43, in get_action
    qvals = self.Q(state)
  File "E:\develop\anaconda3\envs\ray\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\code\github\wangshusen\DeepRL-Chinese\04_dqn.py", line 29, in forward
    x = F.relu(self.fc1(state))
  File "E:\develop\anaconda3\envs\ray\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "E:\develop\anaconda3\envs\ray\lib\site-packages\torch\nn\modules\linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

The line

action = agent.get_action(torch.from_numpy(state))

should become:

action = agent.get_action(torch.from_numpy(state).to(args.device))

Similarly, DeepRL-Chinese/04_dqn.py, lines 158 to 162 (commit 55a9f5e):

bs = torch.tensor(bs, dtype=torch.float32)
ba = torch.tensor(ba, dtype=torch.long)
br = torch.tensor(br, dtype=torch.float32)
bd = torch.tensor(bd, dtype=torch.float32)
bns = torch.tensor(bns, dtype=torch.float32)

should become:
bs = torch.tensor(bs, dtype=torch.float32, device=args.device)
ba = torch.tensor(ba, dtype=torch.long, device=args.device)
br = torch.tensor(br, dtype=torch.float32, device=args.device)
bd = torch.tensor(bd, dtype=torch.float32, device=args.device)
bns = torch.tensor(bns, dtype=torch.float32, device=args.device)
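A more general way to avoid this class of device mismatch is to funnel every NumPy-to-tensor conversion through one helper that always attaches the device; a small sketch (the helper name to_device_tensor is hypothetical, not part of the repo):

import numpy as np
import torch

def to_device_tensor(x, device, dtype=torch.float32):
    # Convert a NumPy array (or list) to a tensor already placed on `device`.
    return torch.as_tensor(np.asarray(x), dtype=dtype, device=device)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
state = np.zeros(4, dtype=np.float32)
state_t = to_device_tensor(state, device)   # plays the role of torch.from_numpy(state).to(...)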

The DDQN implementation is incorrect

Below is the DDQN implementation in the repo. When computing the next Q value for the next state, the action used is not selected by self.model; the code simply takes the action with the maximum value under the target NN at the next state. This implementation is ordinary DQN with a target network, not a true Double DQN:

class DoubleDQN:
    def __init__(self, dim_obs=None, num_act=None, discount=0.9):
        self.discount = discount
        self.model = QNet(dim_obs, num_act)
        self.target_model = QNet(dim_obs, num_act)
        self.target_model.load_state_dict(self.model.state_dict())

    def get_action(self, obs):
        qvals = self.model(obs)
        return qvals.argmax()

    def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):
        # Compute current Q value based on current states and actions.
        qvals = self.model(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
        # The next-state value is detached from gradient computation to avoid divergence.
        next_qvals, _ = self.target_model(next_s_batch).detach().max(dim=1)
        loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)
        return loss

A true DDQN should instead be written as:

    def ddqn_compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):
        # Compute current Q value based on current states and actions.
        qvals = self.model(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
        # Double DQN: the online network selects the next action, the target network evaluates it.
        next_s_action = self.model(next_s_batch).argmax(dim=1)
        next_qvals = self.target_model(next_s_batch).gather(1, next_s_action.unsqueeze(1)).squeeze(1).detach()
        loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)
        return loss

Testing with the original code, the former's average eval reward is -142.57142857142858 and the latter's is -138.25, so DDQN does indeed perform better. But from what I observed, DDQN may need more training time: when I set max step to 200k, DDQN actually came out worse, which is very strange. Strange point 2: I reran with max step = 100k, and this time the original compute_loss method collapsed outright (avg reward = -200) while DDQN reached an avg reward of -136. My second conclusion is that DQN training is simply too unstable.

I also swapped ER for PER, and the results are still very unstable: sometimes PER is better, sometimes ER is better. RL really is too unstable; even my DDQN here is not necessarily better than the earlier version.

09_trpo.py reports an error when run; how should it be fixed?

python : Traceback (most recent call last):
At line:1 char:1

  • python -u 09_trpo.py --do_train --output_dir output/trpo 2>&1 | tee o ...
  •   + CategoryInfo          : NotSpecified: (Traceback (most recent call last)::String) [], RemoteException
      + FullyQualifiedErrorId : NativeCommandError
    
    File "09_trpo.py", line 497, in <module>
      main()
    File "09_trpo.py", line 490, in main
      train(args, env, agent)
    File "09_trpo.py", line 400, in train
      action = agent.get_action(torch.from_numpy(state)).item()
    

TypeError: expected np.ndarray (got tuple)
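The "expected np.ndarray (got tuple)" error usually means env.reset() is returning a tuple rather than a bare observation, which is the behaviour of Gym from version 0.26 onward (reset() returns (observation, info)). Assuming that is the cause here, a minimal sketch of the adjustment:

import gym
import numpy as np
import torch

env = gym.make("CartPole-v1")   # stand-in environment; the repo may use a different one

result = env.reset()
# Gym >= 0.26 returns (observation, info); older versions return only the observation.
state = result[0] if isinstance(result, tuple) else result
state_t = torch.from_numpy(np.asarray(state, dtype=np.float32))

The same API change affects env.step(), which now returns five values (obs, reward, terminated, truncated, info), so the training loop may need similar unpacking.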
