Comments (9)

boris-il-forte commented on July 19, 2024

We are working on it right now. Support will be added in the next few days, in version 1.7.0.

from mushroom-rl.

boris-il-forte commented on July 19, 2024

I've implemented the feature in the above-mentioned commit.
It is still experimental, and it's available in the dev branch.
Please feel free to report bugs or propose enhancements.

davidenitti commented on July 19, 2024

Is there an example of this? I would like to get the loss value to plot it in TensorBoard.
Thanks!

boris-il-forte commented on July 19, 2024

Yes! TorchApproximator has a property called loss_fit that returns the average loss of the last fit call.
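
To make "average loss of the last fit call" concrete, here is a toy stand-in. It is not MushroomRL's TorchApproximator: ToyApproximator is a hypothetical class that takes precomputed minibatch losses instead of training a network, and averaging over all minibatches of all epochs is an assumption about how the average is taken.

```python
class ToyApproximator:
    """Hypothetical stand-in that only mimics the loss_fit bookkeeping."""

    def __init__(self):
        self.loss_fit = None  # no fit call yet

    def fit(self, losses_per_epoch):
        # losses_per_epoch: one list of minibatch losses per epoch,
        # precomputed here instead of running real gradient steps.
        all_losses = [loss for epoch in losses_per_epoch for loss in epoch]
        # Only the overall mean is kept; per-epoch values are discarded.
        self.loss_fit = sum(all_losses) / len(all_losses)


approx = ToyApproximator()
approx.fit([[1.0, 3.0], [0.5, 1.5]])  # two epochs, two minibatches each
print(approx.loss_fit)  # 1.5
```

The per-epoch means (2.0 and 1.0) are no longer available after the call, which is the limitation described in the next paragraph.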

Unfortunately, when a fit call runs multiple epochs, we don't log the loss of each individual epoch.

However, if you are working with algorithms such as DQN, you just need to write a callback that does the following:

  1. Take the Q-function approximator.
  2. Use the "model" property to access the TorchApproximator.
  3. Use the "loss_fit" property to get the loss of the last fit call (in DQN, this is just one step on a single minibatch).
  4. Log the data as you prefer.
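
The four steps above can be sketched as a plain function-based fit callback. This is a minimal sketch, assuming the agent exposes agent.approximator.model.loss_fit (the attribute chain used later in this thread) and that fit callbacks are called with the dataset as their only argument; make_loss_callback and the stand-in agent are hypothetical names for illustration.

```python
from types import SimpleNamespace


def make_loss_callback(agent, log):
    """Return a fit callback that logs agent.approximator.model.loss_fit.

    `log` is any callable that accepts the loss value (e.g. list.append).
    """
    def callback(dataset):
        # The dataset is ignored: the loss is read from the agent instead.
        approximator = agent.approximator  # 1. the Q-function approximator
        model = approximator.model         # 2. the wrapped TorchApproximator
        loss = model.loss_fit              # 3. loss of the last fit call
        if loss is not None:               # skip if nothing was fitted yet
            log(loss)                      # 4. log it however you prefer
    return callback


# Usage with a hypothetical stand-in agent (no MushroomRL required)
losses = []
fake_agent = SimpleNamespace(
    approximator=SimpleNamespace(model=SimpleNamespace(loss_fit=0.25)))
callback = make_loss_callback(fake_agent, losses.append)
callback([])  # the core would pass the real dataset here
print(losses)  # [0.25]
```

With a real agent, the returned callback would be registered with the core as a fit callback, assuming the core accepts a list of them.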

We may add more functionality for logging the loss in the future, but right now we have other priorities: we are porting to the new version of MuJoCo and soon to the new version of Gym, so our time for other improvements is limited.

boris-il-forte commented on July 19, 2024

@davidenitti I've also updated the documentation, which was out of date. We probably need a tutorial, but I don't expect one to be ready soon, for the time reasons explained above.

davidenitti commented on July 19, 2024

@boris-il-forte Thanks! I added a writer (from TensorBoard) to the core and added this right after the if fit_condition(): line:

if self.writer is not None and self.agent.approximator.model.loss_fit is not None:
    self.writer.add_scalar("train/loss", self.agent.approximator.model.loss_fit, self._total_steps_counter)

If there are multiple losses, would loss_fit contain them all? Is there another way to do it, maybe in the regressor?
Thanks!

boris-il-forte commented on July 19, 2024

You don't need to modify the core; you could just use a fit callback (passing it to the core). The callback can also be an object.

Unfortunately, we didn't test multiple losses, and I suspect it will not work. However, if you pass DQN a loss class that computes all the losses you want, you could simply log them there manually.
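
A loss object of that kind could look like the sketch below. LoggingLoss is a hypothetical class, not part of MushroomRL; it assumes the loss is called as loss(prediction, target), like the PyTorch functional losses, and it simply sums its terms while remembering the last value of each one so you can log them yourself.

```python
class LoggingLoss:
    """Hypothetical composite loss that remembers each term's last value."""

    def __init__(self, **terms):
        # terms: name -> callable(prediction, target) returning a scalar
        self.terms = terms
        self.last_values = {}

    def __call__(self, prediction, target):
        total = 0
        for name, term in self.terms.items():
            value = term(prediction, target)
            # float() also works on 0-d torch tensors; the sum below keeps
            # the original values, so gradients can still flow through it.
            self.last_values[name] = float(value)
            total = total + value
        return total


# Usage with plain Python lists and two toy loss terms
def mse(prediction, target):
    return sum((p - t) ** 2 for p, t in zip(prediction, target)) / len(target)


def l1(prediction, target):
    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(target)


loss = LoggingLoss(mse=mse, l1=l1)
total = loss([1.0, 2.0], [0.0, 2.0])
print(total, loss.last_values)  # 1.0 {'mse': 0.5, 'l1': 0.5}
```

A fit callback could then read loss.last_values after each fit, in the same way it would read loss_fit.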

davidenitti commented on July 19, 2024

I don't know whether it's possible to access self.agent.approximator.model.loss_fit from a fit callback, since the core passes it only the dataset (unless I modify the callbacks_fit call to pass other variables, so that I can get loss_fit):

for c in self.callbacks_fit:
    c(dataset)

boris-il-forte commented on July 19, 2024

The callback can be an object. Just pass the .model attribute to the callback constructor, and then you have access to it.
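
Passing the model to the constructor might look like the following minimal sketch. LossLogger and ListWriter are hypothetical names; the writer is assumed to offer add_scalar(tag, value, step) like torch.utils.tensorboard.SummaryWriter, and the step is simply the number of fit calls, because the callback does not see the core's step counter.

```python
from types import SimpleNamespace


class LossLogger:
    """Fit callback that gets the model through its constructor."""

    def __init__(self, model, writer, tag="train/loss"):
        self.model = model    # assumed to expose a loss_fit attribute
        self.writer = writer  # assumed to offer add_scalar(tag, value, step)
        self.tag = tag
        self.n_fits = 0       # used as the x-axis step

    def __call__(self, dataset):
        # The core passes only the dataset; the model was stored in __init__.
        self.n_fits += 1
        loss = self.model.loss_fit
        if loss is not None:
            self.writer.add_scalar(self.tag, loss, self.n_fits)


class ListWriter:
    """Hypothetical stand-in for a TensorBoard writer."""

    def __init__(self):
        self.records = []

    def add_scalar(self, tag, value, step):
        self.records.append((tag, value, step))


model = SimpleNamespace(loss_fit=None)  # stand-in for agent.approximator.model
writer = ListWriter()
logger = LossLogger(model, writer)
logger([])            # nothing fitted yet, so nothing is logged
model.loss_fit = 0.3
logger([])
print(writer.records)  # [('train/loss', 0.3, 2)]
```

With TensorBoard installed, you could instead build the logger as LossLogger(agent.approximator.model, SummaryWriter()) and register it with the core as a fit callback, leaving the core itself unmodified.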

Anyway, logging is certainly another area that needs improvement, as it is more relevant nowadays. I'll try to look at it in the near future.
