
Comments (7)

lxuechen commented on June 27, 2024

The example seems interesting. I will make v1 the default for now. If you could provide some more detail on how z is passed in, then I might be able to improve the usage of v2. This seems related: I remember running into something similar, where I had to "contextualize" the SDE with a representation produced by GRUs, back when I was working on latent SDEs.

Incidentally, I'm also concerned (but have not tested) that multiplying up the batch dimension will create a peak in our memory usage.

This could be true when the Brownian motion dimension is large. Though if we use adjoints, the issue might not be as prominent, provided we can fit models without this term using backprop through the solver.

from torchsde.

patrick-kidger commented on June 27, 2024

The context is that z is some static (not time-evolving) piece of additional information that is passed to the drift and diffusion.

The way I'm doing this is a bit ugly:

import torch
import torchsde

class SDE(torch.nn.Module):
    sde_type = ...
    noise_type = ...

    def set_data(self, z):
        self._z = z

    def f(self, t, y):
        # use both y and self._z
        ...

    ...

def somefunction(sde: SDE, z):
    sde.set_data(z)
    torchsde.sdeint(sde, ...)

I'm aware that z could be included in the state with zero drift/diffusion, but that's even uglier IMO (and inefficient).
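For concreteness, the zero-drift/diffusion workaround would look roughly like the sketch below. The `AugmentedSDE` wrapper is my own illustration, not torchsde API; it assumes diagonal noise and that z occupies the trailing coordinates of the augmented state:

```python
import torch

class AugmentedSDE(torch.nn.Module):
    # Hypothetical wrapper (not torchsde API): carry static z inside the
    # state, giving it zero drift and zero diffusion so it stays constant.
    sde_type = "ito"
    noise_type = "diagonal"

    def __init__(self, base_f, base_g, z_dim):
        super().__init__()
        self.base_f, self.base_g, self.z_dim = base_f, base_g, z_dim

    def _split(self, aug):
        # split the augmented state back into (y, z)
        return aug[..., :-self.z_dim], aug[..., -self.z_dim:]

    def f(self, t, aug):
        y, z = self._split(aug)
        return torch.cat([self.base_f(t, y, z), torch.zeros_like(z)], dim=-1)

    def g(self, t, aug):
        y, z = self._split(aug)
        return torch.cat([self.base_g(t, y, z), torch.zeros_like(z)], dim=-1)
```

The inefficiency is visible here: the solver integrates (and the Brownian motion spans) z's coordinates even though they never change.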

Thinking about it, we could perhaps include an additional argument to sdeint, sdeint_adjoint corresponding to such static information? This would neaten the above code a lot. (And allow for v2 if we do want it over v1.)
Additionally, the above code can't reset z after calling sdeint because it still needs to be there for the backward pass; if we instead capture it as an argument then that's another wart removed.

Obviously that is departing a little further from our basic duties of solving an SDE, but I'd be happy to offer a PR on that if you're interested.


lxuechen commented on June 27, 2024

> Thinking about it, we could perhaps include an additional argument to sdeint, sdeint_adjoint corresponding to such static information? This would neaten the above code a lot. (And allow for v2 if we do want it over v1.)

Now that I'm starting to remember the hairy issues with latent SDE contextualization, this really makes sense. Especially when using adjoints, the example you presented poses an additional challenge: the grads w.r.t. z won't be recorded at all. Back in the day, I hacked the solver to make this work.

Off the top of my head, a potential modification to fix this would be to allow sdeint and sdeint_adjoint to take in additional_ys and additional_params. More explicitly, something like

sde = ...
additional_ys = ...
additional_params = ...
ys = sdeint(sde, y0, ts, bm, additional_ys=additional_ys, additional_params=additional_params)
ys_from_adjoint = sdeint_adjoint(sde, y0, ts, bm, additional_ys=additional_ys, additional_params=additional_params)

The only thing I'm not too certain about is the format of additional_ys. Having it be a tuple of tensors of size (batch_size, d') makes sense. Though, it would be more useful if it could take in tensors of size (T, batch_size, d') (or (T - 1, batch_size, d')).
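If the (T, batch_size, d') variant were supported, the drift/diffusion would presumably need to look up the slice for the current interval. A hypothetical piecewise-constant lookup might look like this (the `lookup` helper is my own illustration, not torchsde API; ts are the solver times):

```python
import torch

def lookup(additional_ys, ts, t):
    # Hypothetical helper: piecewise-constant lookup into a (T, batch, d')
    # tensor of per-interval context, given solver times ts and query time t.
    t = torch.as_tensor([float(t)])
    # index of the interval [ts[i], ts[i+1]) containing t, clamped to range
    idx = (torch.searchsorted(ts, t, right=True) - 1).clamp(0, len(ts) - 1)
    return additional_ys[idx.item()]
```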


patrick-kidger commented on June 27, 2024

You're thinking that additional_ys represents this additional static state, and whilst we're at it we could add additional_params to augment SDE.parameters() for the adjoint?

If so I'd note that additional_params would only be needed in the adjoint case. We could follow torchdiffeq for consistency on this - there we called it adjoint_params, and if passed then it is used instead of the parameters of the vector field, rather than as well.

On the format of additional_ys: I'm quite keen to avoid explicitly encoding a single batch dimension.
I'd suggest essentially following what autograd.Function does here: accept a tuple of Python objects, and compute gradients with respect to any gradient-requiring tensors among them. Allow tensors to be of any shape.
This does mean that we can't really use v2, as we don't expect to have access to a batch dimension, but I think this kind of batch dimension hacking is quite fragile to the variety of things a user can throw at it anyway.
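As a toy illustration of that autograd.Function convention (accept arbitrary objects, differentiate only through the tensors; `ScaleBy` is my own example, not torchsde code):

```python
import torch

class ScaleBy(torch.autograd.Function):
    # Toy example: scale y by the sum of all tensor-valued extras.
    # Non-tensor extras are accepted and simply get no gradient.
    @staticmethod
    def forward(ctx, y, *extras):
        tensors = [e for e in extras if isinstance(e, torch.Tensor)]
        ctx.save_for_backward(y, *tensors)
        ctx.extras = extras
        scale = sum(t.sum() for t in tensors)
        return y * scale

    @staticmethod
    def backward(ctx, grad_out):
        y, *tensors = ctx.saved_tensors
        scale = sum(t.sum() for t in tensors)
        # every element of every extra tensor scales the whole output equally
        val = (grad_out * y).sum()
        grads = tuple(val * torch.ones_like(e) if isinstance(e, torch.Tensor)
                      else None for e in ctx.extras)
        return (grad_out * scale,) + grads
```

Note the tensors are allowed to be of any shape, and no batch dimension is assumed, which is the point being made above.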

For speeding up v1, there is pytorch/pytorch#42368, which mentions the possibility of a torch.vmap, in particular with a view to batched vjps. I don't know its current state, but it might be interesting to us.


patrick-kidger commented on June 27, 2024

Actually, thinking about it: with the above proposal we wouldn't need an adjoint_params. Whatever extra tensors we need gradients for can just be included in additional_ys and ignored in the drift/diffusion.


lxuechen commented on June 27, 2024

Taking a step back, I think having sdeint take in additional_ys is likely going to overcomplicate the solver code. I'm not too inclined to do this at the moment.

I do see the need to support back-propagating gradients to non-parameter nodes with adjoints. I'm aware of torchdiffeq's adjoint_params, and I can send in a PR on this.


lxuechen commented on June 27, 2024

Re: torch.vmap

I'm not entirely sure this will make our lives easier. Given that there's not much documentation on what's going on there, much of this discussion seems like speculation to me.

