Comments (4)
Hi,
I solved this problem by switching to CUDA 9.0 and reinstalling PyTorch. Another option is to use LSTMCell rather than LSTM.
Best,
Ye
from pytorch-rl.
Thanks for the fast feedback.
Sadly, still no luck. I switched my GRU layer to a GRUCell, which only changed the error to "RuntimeError: GRUFused is not differentiable twice".
Since I'm working in an environment where I can't easily change the CUDA version (currently 7.5), using a different CUDA version is not an option. Are you sure that this would solve the problem?
The relevant functions of my policy look like this:
def forward(self, inputs):
    x = self.hidden_activation(self.input_layer(inputs))
    # Hidden Layers
    for hidden_layer in self.hidden_layers:
        x = self.hidden_activation(hidden_layer(x))
    # GRUCell
    outputs = []
    for seq in range(x.size(1)):
        self.hidden = self.gru(x[:, seq], self.hidden)
        outputs.append(self.hidden)
    x = torch.stack(outputs, 1)
    # Output Layer
    action_mean = self.output_layer(x)
    action_log_std = self.a_logstd.expand_as(action_mean)
    action_std = torch.exp(action_log_std)
    return action_mean, action_log_std, action_std

def get_log_prob(self, x, actions):
    self.hidden = self.init_hidden(x.size(0))
    action_mean, action_log_std, action_std = self.forward(x)
    return normal_log_density(actions, action_mean, action_log_std, action_std, is_recurrent=True)

def get_fim(self, x):
    self.hidden = self.init_hidden(x.size(0))
    mean, _, _ = self.forward(x)
    cov_inv = self.a_logstd.data.exp().pow(-2).squeeze(0).repeat(x.size(0))
    param_count = 0
    std_index = 0
    id = 0
    for name, param in self.named_parameters():
        if name == "a_logstd":
            std_id = id
            std_index = param_count
        param_count += param.data.view(-1).shape[0]
        id += 1
    return cov_inv, mean, {'std_id': std_id, 'std_index': std_index}
I'm really hoping to solve this issue, since I need to implement this using an RNN policy.
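For context, TRPO needs second-order information (a Hessian- or Fisher-vector product), so autograd has to differentiate through the recurrent cell twice. The error above says the fused GRU kernel does not support that. One possible workaround, not tested in this thread and not part of pytorch-rl, is a GRU cell written from elementary operations, which autograd can differentiate twice. In the sketch below, `ManualGRUCell` and `hessian_vector_product` are hypothetical names, and the sizes and loss are placeholders chosen only for illustration.

```python
import torch
import torch.nn as nn


class ManualGRUCell(nn.Module):
    """GRU cell built from basic ops (hypothetical helper, not from pytorch-rl)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # One linear map each for the input and the hidden state;
        # each produces the pre-activations for the r, z and n gates.
        self.x2h = nn.Linear(input_size, 3 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 3 * hidden_size)

    def forward(self, x, h):
        x_r, x_z, x_n = self.x2h(x).chunk(3, dim=1)
        h_r, h_z, h_n = self.h2h(h).chunk(3, dim=1)
        r = torch.sigmoid(x_r + h_r)    # reset gate
        z = torch.sigmoid(x_z + h_z)    # update gate
        n = torch.tanh(x_n + r * h_n)   # candidate state
        return (1 - z) * n + z * h


def hessian_vector_product(loss, params, vec):
    """Compute H @ vec with two backward passes (double backward)."""
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    grad_dot_vec = (flat_grad * vec).sum()
    second = torch.autograd.grad(grad_dot_vec, params)
    return torch.cat([s.reshape(-1) for s in second])


torch.manual_seed(0)
cell = ManualGRUCell(input_size=4, hidden_size=8)
x = torch.randn(5, 3, 4)                # (batch, sequence, features)
h = torch.zeros(5, 8)
for t in range(x.size(1)):              # unroll over the sequence dimension
    h = cell(x[:, t], h)
loss = h.pow(2).mean()                  # placeholder loss for illustration

params = list(cell.parameters())
vec = torch.randn(sum(p.numel() for p in params))
hvp = hessian_vector_product(loss, params, vec)
```

A cell like this is typically slower than the fused kernel, so it is a trade-off rather than a drop-in improvement.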
I’m pretty sure that changing the CUDA version will solve the problem if you are using GRUCell: that is what fixed it for me, and I didn’t change a single line of code. Alternatively, you can use PPO instead of TRPO, which should give you similar performance.
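The reason PPO sidesteps the error is that its clipped surrogate objective only needs first-order gradients, so the recurrent cell is never differentiated twice. A minimal sketch of that objective follows, assuming precomputed log-probabilities and advantages. The function name, the argument names and the `clip_epsilon=0.2` default are illustrative assumptions, not code taken from pytorch-rl.

```python
import torch


def ppo_clipped_loss(log_probs, fixed_log_probs, advantages, clip_epsilon=0.2):
    """Clipped surrogate loss (illustrative sketch, first-order only)."""
    # Probability ratio between the current policy and the old (fixed) policy.
    ratio = torch.exp(log_probs - fixed_log_probs)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages
    # Take the pessimistic bound and negate it, since optimizers minimize.
    return -torch.min(surr1, surr2).mean()


# Toy values: in practice log_probs would come from the policy network.
log_probs = torch.tensor([-1.0, -0.5, -2.0], requires_grad=True)
fixed_log_probs = torch.tensor([-1.0, -1.0, -1.0])
advantages = torch.tensor([1.0, 1.0, -1.0])

loss = ppo_clipped_loss(log_probs, fixed_log_probs, advantages)
loss.backward()                         # a single backward pass is enough
```

In this toy batch the second and third samples fall outside the clipping range in the direction that would improve the objective, so they contribute no gradient.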
I did change from TRPO to PPO. I will compare the results, but it seems to train fine. Thank you!