Comments (4)

Khrylx avatar Khrylx commented on July 22, 2024

Hi,

I solved this problem by switching to CUDA 9.0 and reinstalling PyTorch. Another option is to use LSTMCell rather than LSTM.

Best,
Ye
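The LSTMCell suggestion above amounts to replacing one nn.LSTM call with a manual loop over time steps. Below is a minimal sketch of that idea; the helper name unroll_lstm_cell and the tensor sizes are made up for illustration and are not taken from pytorch-rl. Whether this avoids the double-backward error depends on the PyTorch and CUDA versions in use.

```python
import torch
import torch.nn as nn


def unroll_lstm_cell(cell, x, state=None):
    """Run an nn.LSTMCell over a (batch, seq_len, features) tensor step by step.

    Hypothetical helper for illustration; returns the stacked hidden states
    with shape (batch, seq_len, hidden_size) and the final (h, c) state.
    """
    batch_size = x.size(0)
    if state is None:
        # Start from an all-zero hidden and cell state.
        h = x.new_zeros(batch_size, cell.hidden_size)
        c = x.new_zeros(batch_size, cell.hidden_size)
    else:
        h, c = state

    outputs = []
    for t in range(x.size(1)):
        # One recurrent step on the t-th element of every sequence.
        h, c = cell(x[:, t], (h, c))
        outputs.append(h)
    return torch.stack(outputs, dim=1), (h, c)


torch.manual_seed(0)
cell = nn.LSTMCell(input_size=4, hidden_size=8)
x = torch.randn(3, 5, 4)  # batch of 3 sequences, 5 steps, 4 features
out, (h_n, c_n) = unroll_lstm_cell(cell, x)
```

The loop is slower than the fused nn.LSTM kernel, but every step goes through ordinary autograd operations on the cell.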

from pytorch-rl.

erschmidt avatar erschmidt commented on July 22, 2024

Thanks for the fast feedback.

Sadly, still no luck. I switched my GRU layer to a GRUCell, which only changed the error to: RuntimeError: GRUFused is not differentiable twice.

Since I'm working in an environment where I can't easily change the CUDA version (currently 7.5), using a different CUDA version is not an option. Are you sure that this would solve the problem?
The relevant functions of my policy look like this:

# Methods of the policy class (an nn.Module subclass). normal_log_density is a
# helper defined elsewhere in the project.
import torch

def forward(self, inputs):
    x = self.hidden_activation(self.input_layer(inputs))

    # Hidden layers
    for hidden_layer in self.hidden_layers:
        x = self.hidden_activation(hidden_layer(x))

    # GRUCell: step through the sequence dimension (dim 1) one step at a time
    outputs = []
    for t in range(x.size(1)):
        self.hidden = self.gru(x[:, t], self.hidden)
        outputs.append(self.hidden)
    x = torch.stack(outputs, 1)

    # Output layer
    action_mean = self.output_layer(x)
    action_log_std = self.a_logstd.expand_as(action_mean)
    action_std = torch.exp(action_log_std)

    return action_mean, action_log_std, action_std

def get_log_prob(self, x, actions):
    # Reset the recurrent state for a batch of x.size(0) sequences
    self.hidden = self.init_hidden(x.size(0))
    action_mean, action_log_std, action_std = self.forward(x)
    return normal_log_density(actions, action_mean, action_log_std, action_std, is_recurrent=True)

def get_fim(self, x):
    self.hidden = self.init_hidden(x.size(0))
    mean, _, _ = self.forward(x)
    cov_inv = self.a_logstd.data.exp().pow(-2).squeeze(0).repeat(x.size(0))
    # Locate a_logstd: its position in the parameter list (std_id) and its
    # offset in the flattened parameter vector (std_index)
    param_count = 0
    std_index = 0
    std_id = 0  # initialized so the return below cannot raise a NameError
    for param_id, (name, param) in enumerate(self.named_parameters()):
        if name == "a_logstd":
            std_id = param_id
            std_index = param_count
        param_count += param.numel()
    return cov_inv, mean, {'std_id': std_id, 'std_index': std_index}

I'm really hoping to solve this issue, since I need to implement this using an RNN policy.

Khrylx avatar Khrylx commented on July 22, 2024

I’m fairly confident that changing the CUDA version will solve the problem if you are using GRUCell: that was my situation, and I didn’t have to change a single line of code. Alternatively, you can use PPO instead of TRPO, which should give you similar performance.
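One likely reason the PPO route sidesteps the error: TRPO implementations typically need second-order information (Hessian- or Fisher-vector products obtained by differentiating through a gradient), which matches the "not differentiable twice" message, whereas PPO's clipped surrogate objective needs only an ordinary first-order backward pass. The sketch below is an illustration under stated assumptions and is not code from pytorch-rl: ppo_clip_loss, clip_eps=0.2, the layer sizes, and the random data are all hypothetical.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal


def ppo_clip_loss(new_log_prob, old_log_prob, advantages, clip_eps=0.2):
    """Clipped surrogate loss of PPO (negated, so it can be minimized)."""
    ratio = torch.exp(new_log_prob - old_log_prob)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    return -torch.min(unclipped, clipped).mean()


torch.manual_seed(0)

# Tiny recurrent Gaussian policy: one GRUCell step followed by a linear head.
gru = nn.GRUCell(input_size=4, hidden_size=8)
head = nn.Linear(8, 2)
log_std = nn.Parameter(torch.zeros(2))

obs = torch.randn(6, 4)      # made-up observations
actions = torch.randn(6, 2)  # made-up actions
advantages = torch.randn(6)  # made-up advantage estimates

hidden = gru(obs, torch.zeros(6, 8))
dist = Normal(head(hidden), log_std.exp())
new_log_prob = dist.log_prob(actions).sum(dim=-1)
old_log_prob = new_log_prob.detach()  # treat the current policy as the "old" one

loss = ppo_clip_loss(new_log_prob, old_log_prob, advantages)
loss.backward()  # a single first-order backward pass through the GRUCell
```

No call here passes create_graph=True or differentiates a gradient again, so the recurrent cell is only ever differentiated once.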

erschmidt avatar erschmidt commented on July 22, 2024

I switched from TRPO to PPO. I still need to compare results, but it seems to train fine. Thank you!