
pytorch-trpo's People

Contributors

ikostrikov2, kaixhin


pytorch-trpo's Issues

What is volatile=True for?

Dear author,
I found your code very helpful. However, I have trouble understanding the following line:

action_means, action_log_stds, action_stds = policy_net(Variable(states, volatile=volatile))

I am wondering about the purpose of the volatile flag. When do you set volatile to True, and when to False?
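
For readers on PyTorch 0.4 or later, the volatile flag no longer exists; here is a minimal sketch of the modern equivalent, with a hypothetical stand-in network and batch of states so it runs on its own:

    import torch
    import torch.nn as nn

    policy_net = nn.Linear(4, 3)   # hypothetical stand-in for the policy network
    states = torch.randn(10, 4)    # hypothetical batch of 10 four-dimensional states

    # volatile=True used to mark a Variable as inference-only so no autograd graph
    # was built. In PyTorch >= 0.4 the same effect comes from torch.no_grad():
    with torch.no_grad():
        action_means = policy_net(states)  # no gradients are tracked here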

main.py, line 128

The KL there is always a zero tensor. Is this a problem with the torch version?

The range over t in main.py is larger than necessary

In your main.py, line 147: for t in range(10000): # Don't infinite loop while learning
In practice, t only reaches 50 because the environment is done after 50 steps, so range(10000) is far larger than necessary.

Computing the Fisher-vector product

Hello, I want to ask about line 67 in your trpo.py: you end up with two terms there, yet the TRPO paper says the second term vanishes. You also add v * damping, whose purpose I guess is to ensure positive definiteness. Could you explain this in detail? Thank you very much!
Also, in line 117 of your main.py, could you explain in detail why that expression approximates the average KL? Thank you very much!
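
For reference, the quantity being asked about is a Hessian-vector product of the mean KL, obtained by differentiating twice rather than forming the Hessian explicitly; a minimal sketch of that pattern with the damping term included (the function and parameter names here are assumptions, not the repository's exact code):

    import torch

    def fisher_vector_product(kl, params, v, damping=0.1):
        """Compute (H + damping * I) v, where H is the Hessian of the mean KL.
        kl: scalar mean-KL tensor built from the current policy parameters.
        params: list of policy parameters. v: flat vector matching their total size."""
        grads = torch.autograd.grad(kl, params, create_graph=True)
        flat_grad_kl = torch.cat([g.contiguous().view(-1) for g in grads])
        # Dot the KL gradient with v, then differentiate again: this yields H v
        # without ever materialising the full Hessian (the "double backprop" trick).
        grad_v = (flat_grad_kl * v).sum()
        grads2 = torch.autograd.grad(grad_v, params)
        flat_grad_grad_kl = torch.cat([g.contiguous().view(-1) for g in grads2])
        # The damping term keeps the product positive definite, which the
        # conjugate-gradient solver relies on.
        return flat_grad_grad_kl + v * damping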

Object oriented

It would be nice if the agent was an object (with methods "get_action" and "remember" or similar) so that it could be more easily reused.
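
A rough sketch of what such a wrapper might look like (the class and method names are hypothetical and not part of this repository):

    import torch
    import torch.nn as nn

    class TRPOAgent:
        """Hypothetical object-oriented wrapper around a policy network."""

        def __init__(self, policy_net: nn.Module):
            # policy_net is expected to map states to (mean, log_std, std),
            # matching the interface quoted in the other issues.
            self.policy_net = policy_net
            self.memory = []  # (state, action, reward, mask) tuples

        def get_action(self, state: torch.Tensor) -> torch.Tensor:
            with torch.no_grad():
                mean, _, std = self.policy_net(state.unsqueeze(0))
            return torch.normal(mean, std).squeeze(0)  # sample from the Gaussian policy

        def remember(self, state, action, reward, mask):
            self.memory.append((state, action, reward, mask))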

Bootstrapping the value function?

Currently, the target for the value function is the discounted sum of all future rewards. This gives an unbiased estimate but results in higher variance. An alternative is to use a bootstrapped estimate, i.e. something like
target[i] = rewards[i] + gamma * prev_values * masks[i]

Bootstrapping is often preferred due to its lower variance, even though it results in a biased gradient estimate.
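
A minimal sketch contrasting the two targets (the helper names, gamma value, and tensor shapes are assumptions for illustration; masks are 0 at episode boundaries):

    import torch

    def monte_carlo_targets(rewards, masks, gamma=0.99):
        """Discounted sum of all future rewards (unbiased, higher variance)."""
        targets = torch.zeros_like(rewards)
        running = 0.0
        for i in reversed(range(len(rewards))):
            running = rewards[i] + gamma * running * masks[i]
            targets[i] = running
        return targets

    def bootstrapped_targets(rewards, masks, values, gamma=0.99):
        """One-step bootstrapped targets (biased, lower variance).
        values holds the value estimates of the visited states; the mask zeroes
        the bootstrap term at episode boundaries."""
        next_values = torch.cat([values[1:], values.new_zeros(1)])
        return rewards + gamma * next_values * masks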

It seems that the importance sampling code part is wrong.

pytorch-trpo/main.py

Lines 108 to 119 in e200eb8

fixed_log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds).data.clone()

def get_loss(volatile=False):
    if volatile:
        with torch.no_grad():
            action_means, action_log_stds, action_stds = policy_net(Variable(states))
    else:
        action_means, action_log_stds, action_stds = policy_net(Variable(states))

    log_prob = normal_log_density(Variable(actions), action_means, action_log_stds, action_stds)
    action_loss = -Variable(advantages) * torch.exp(log_prob - Variable(fixed_log_prob))
    return action_loss.mean()

The fixed_log_prob line and the corresponding line inside the get_loss function are exactly the same.
Since the two parts are executed consecutively, the two values (fixed_log_prob and log_prob) are exactly the same.
Is there a reason you wrote the code like this?

Is the get_kl() function correct?

Thanks for your great code!
I notice that in the function get_kl(), you use the policy net to generate the mean, log_std and std, then copy these three values and calculate the KL divergence between the originals and the copies, which is obviously zero all the time. Is this a bug or intended behavior?
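
For reference, the zero value is expected: the KL between the current policy and a detached copy of itself is zero, but its second derivatives with respect to the live parameters are not, and those are what the Fisher-vector product uses. A minimal, self-contained toy version of a Gaussian get_kl written in that style (not the repository's exact function):

    import torch

    def get_kl(mean, log_std, std):
        """KL between a detached copy of the policy and the live policy.
        The value is identically zero, but gradients still flow through
        mean, log_std and std, so differentiating twice through this
        quantity yields Fisher-vector products."""
        mean0 = mean.detach()
        log_std0 = log_std.detach()
        std0 = std.detach()
        kl = log_std - log_std0 + (std0.pow(2) + (mean0 - mean).pow(2)) / (2.0 * std.pow(2)) - 0.5
        return kl.sum(1, keepdim=True)

    # Toy usage: the value is zero, yet the expression is still differentiable.
    mean = torch.zeros(5, 2, requires_grad=True)
    log_std = torch.zeros(5, 2, requires_grad=True)
    kl = get_kl(mean, log_std, log_std.exp()).mean()
    print(kl.item())  # 0.0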

other env

Hello, I notice your code targets MuJoCo, and I wonder how to modify it to fit other environments. I have tried but failed. Thanks a lot!
ikostrikov, thanks very much. I have tried the continuous control task "MountainCarContinuous-v0" from classic control and it works.
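
For anyone else trying this, the main thing to check is that the environment has a continuous (Box) action space, since the policy outputs a Gaussian; a minimal sketch of swapping the environment id with gym (the env id is the one mentioned above, everything else is an assumption):

    import gym

    # Any environment with continuous observations and actions should fit the
    # same Gaussian-policy setup; discrete-action environments need a different
    # policy head (see the discrete-actions issue below).
    env = gym.make("MountainCarContinuous-v0")
    num_inputs = env.observation_space.shape[0]   # size of the state vector
    num_actions = env.action_space.shape[0]       # size of the action vector
    print(num_inputs, num_actions)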

What and when to send to the GPU?

I'm new to PyTorch and am having a hard time getting used to handling variables properly on the CPU and GPU. As we are calculating our own losses here, I am having trouble understanding what to send to the device (GPU) and when. I would really appreciate an explanation of how to go about this. The code is quite well written and easy to understand, by the way.
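
As a general PyTorch pattern (not specific to this repository), the model parameters and every tensor that participates in the loss have to live on the same device; a minimal hedged sketch:

    import torch
    import torch.nn as nn

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    policy_net = nn.Linear(4, 3).to(device)   # hypothetical policy network
    states = torch.randn(10, 4)               # rollouts are typically collected on CPU

    # Move each batch to the device right before computing the loss; the network
    # output is then already on the same device as the parameters.
    out = policy_net(states.to(device))
    loss = out.pow(2).mean()
    loss.backward()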

Use pytorch 0.2.0?

There is no torch.autograd.grad in 0.1.2, the newest release version.

I don't know what "neggdotstepdir" is for. Thanks!!!

Thank you very much for the code you provided! I learned a lot from it. I would like to ask what these lines of code do and whether there is a mathematical justification for them; they seem to differ from the original paper. Thank you very much!!!

neggdotstepdir = (-loss_grad * stepdir).sum(0, keepdim=True)
expected_improve = expected_improve_rate * stepfrac
ratio = actual_improve / expected_improve
 if ratio.item() > accept_ratio and actual_improve.item() > 0:
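
For context, neggdotstepdir is the dot product of the negative loss gradient with the search direction, i.e. the first-order prediction of how much the surrogate loss should improve per unit step; the backtracking line search accepts a step fraction only if the actual improvement is a large enough fraction of that prediction. A minimal sketch of such a line search (the function and parameter names are assumptions, not the repository's exact code):

    def backtracking_line_search(get_loss, set_flat_params, x0, fullstep,
                                 expected_improve_rate, max_backtracks=10,
                                 accept_ratio=0.1):
        """Accept the largest step fraction whose actual improvement of the
        surrogate loss is at least accept_ratio times the first-order
        prediction expected_improve_rate (derived from -g . stepdir)."""
        loss_before = get_loss()
        for stepfrac in [0.5 ** n for n in range(max_backtracks)]:
            set_flat_params(x0 + stepfrac * fullstep)     # trial parameter update
            actual_improve = loss_before - get_loss()      # re-evaluate the loss
            expected_improve = expected_improve_rate * stepfrac
            if (actual_improve / expected_improve).item() > accept_ratio \
                    and actual_improve.item() > 0:
                return True                                 # keep this step
        set_flat_params(x0)                                 # restore old parameters
        return False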

what is shs?

    flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads]).data

    return flat_grad_grad_kl + v * damping

stepdir = conjugate_gradients(Fvp, -loss_grad, 10)

shs = 0.5 * (stepdir * Fvp(stepdir)).sum(0, keepdim=True)

lm = torch.sqrt(shs / max_kl)
fullstep = stepdir / lm[0]

According to the TRPO formula,
$\text{direction}=\sqrt{\frac{2\delta}{g^T F^{-1} g}}\, F^{-1} g$,
so $shs = g^T F^{-1} g$,
but your code is different from that. Why?
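
For what it's worth, the two seem to agree once the substitution $s = F^{-1} g$ (the conjugate-gradient solution, ignoring the damping term) is made: $shs = \tfrac{1}{2} s^T F s = \tfrac{1}{2} g^T F^{-1} g$ since $Fvp(stepdir) \approx F s$ and $F s = g$; then $lm = \sqrt{shs / \delta}$ and $fullstep = s / lm = \sqrt{\frac{2\delta}{g^T F^{-1} g}}\, F^{-1} g$, which is the direction in the formula above with $\delta$ equal to max_kl. The extra factor $\tfrac{1}{2}$ in shs is cancelled by the $2$ under the square root.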

doc?

For example, imagine I have my own policy that takes in a state and outputs an action, or perhaps a distribution over actions, and I have a world that takes an action and returns a reward and a new state. How would I plug these into this TRPO implementation?
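
Judging from the code quoted in the other issues, the implementation expects a policy module that maps a batch of states to (action_means, action_log_stds, action_stds), and an environment with a gym-style reset()/step() interface; a minimal hedged sketch of a custom policy with that output signature (shapes and names are assumptions, not documented API):

    import torch
    import torch.nn as nn

    class MyGaussianPolicy(nn.Module):
        """Hypothetical policy: states in, (mean, log_std, std) out."""

        def __init__(self, num_inputs, num_actions):
            super().__init__()
            self.hidden = nn.Linear(num_inputs, 64)
            self.mean_head = nn.Linear(64, num_actions)
            self.log_std = nn.Parameter(torch.zeros(1, num_actions))

        def forward(self, states):
            h = torch.tanh(self.hidden(states))
            mean = self.mean_head(h)
            log_std = self.log_std.expand_as(mean)
            return mean, log_std, log_std.exp()

    # Toy usage: sample an action for a single 4-dimensional state.
    policy = MyGaussianPolicy(num_inputs=4, num_actions=2)
    mean, log_std, std = policy(torch.randn(1, 4))
    action = torch.normal(mean, std)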

How to modify the code for discrete actions?

Hi, thanks once again for implementing a really interesting algorithm in PyTorch 👍.

I was wondering how to modify the code so it can be used for environments with discrete actions (say CartPole, as in the other PyTorch TRPO implementation, or maybe even Atari games)?
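
One hedged approach (not part of this repository) is to swap the Gaussian policy head for a categorical one and replace the log-density and KL helpers accordingly; a minimal sketch:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class CategoricalPolicy(nn.Module):
        """Hypothetical discrete-action policy: states in, action logits out."""

        def __init__(self, num_inputs, num_actions):
            super().__init__()
            self.hidden = nn.Linear(num_inputs, 64)
            self.logits = nn.Linear(64, num_actions)

        def forward(self, states):
            return self.logits(torch.tanh(self.hidden(states)))

    def categorical_log_prob(logits, actions):
        # Log-probability of the taken actions, replacing normal_log_density.
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs.gather(1, actions.long().unsqueeze(1))

    def categorical_kl(logits):
        # KL(detached current policy || current policy): zero in value, but its
        # second derivatives give the Fisher information, as in the Gaussian case.
        log_p = F.log_softmax(logits, dim=-1)
        log_p0 = log_p.detach()
        return (log_p0.exp() * (log_p0 - log_p)).sum(-1, keepdim=True)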

What is get_kl() doing in main.py?

Hi. Thanks for publishing an implementation of TRPO.

I have a question about get_kl().

I thought get_kl() was supposed to calculate the KL divergence between the old policy and the new policy, but it seems to always return 0.

Also, I do not see the KL-constraint part in the parameter update process.

Is this code a modification of TRPO, or do I have a misunderstanding?

Thanks,
