Comments (7)
The example seems interesting. I will make v1 a default for now. If you could provide some more detail on how z is passed in, then I might be able to improve the usage of v2. This seems related: I remember running into something similar where I had to "contextualize" the SDE based on a representation produced by GRUs, back when working on latent SDEs.
Incidentally, I'm also concerned (but have not tested) that multiplying out the batch dimension will spike our peak memory usage. This could be true when the Brownian motion dimension is large. Though if we use adjoints, the issue might not be as prominent, provided we could fit models without this term using backprop through the solver.
from torchsde.
The context is that z is some additional static (not time-evolving) information that is passed as an extra input to the drift and diffusion. The way I'm doing this is a bit ugly:
class SDE(torch.nn.Module):
    sde_type = ...
    noise_type = ...

    def set_data(self, z):
        self._z = z

    def f(self, t, y):
        # use both y and self._z
        ...

def somefunction(sde: SDE, z):
    sde.set_data(z)
    torchsde.sdeint(sde, ...)
I'm aware that z could be included in the state with zero drift/diffusion, but that's even uglier IMO (and inefficient).
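For concreteness, the state-augmentation hack would look roughly like this. This is a toy sketch using plain Python lists in place of tensors, and `AugmentedSDE`, `base_f`, `base_g` are hypothetical names, not torchsde API:

```python
class AugmentedSDE:
    """Toy sketch: fold static context z into the state, giving it zero dynamics."""

    def __init__(self, base_f, base_g, state_dim):
        self.base_f = base_f        # true drift: base_f(t, y, z)
        self.base_g = base_g        # true diffusion: base_g(t, y, z)
        self.state_dim = state_dim  # dimension of the real state y

    def f(self, t, y_aug):
        y, z = y_aug[:self.state_dim], y_aug[self.state_dim:]
        # zero drift for the context entries, so z stays constant over time
        return self.base_f(t, y, z) + [0.0] * len(z)

    def g(self, t, y_aug):
        y, z = y_aug[:self.state_dim], y_aug[self.state_dim:]
        # zero diffusion for the context entries
        return self.base_g(t, y, z) + [0.0] * len(z)
```

The inefficiency is visible here: the solver carries len(z) dead dimensions through every drift/diffusion evaluation and every step of the integration.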
Thinking about it, we could perhaps include an additional argument to sdeint and sdeint_adjoint corresponding to such static information? This would neaten the above code a lot. (And allow for v2, if we do want it over v1.)
Additionally, the above code can't reset z after calling sdeint, because it still needs to be there for the backward pass; if we instead capture it as an argument, then that's another wart removed.
Obviously that is departing a little further from our basic duties of solving an SDE, but I'd be happy to offer a PR on that if you're interested.
> Thinking about it, we could perhaps include an additional argument to sdeint and sdeint_adjoint corresponding to such static information? This would neaten the above code a lot. (And allow for v2 if we do want it over v1.)
Now that I'm starting to remember the hairy issues with latent SDE contextualization, this really makes sense. Especially when using adjoints, the example you presented poses an additional challenge: the grads w.r.t. z won't be recorded at all. Back in the day, I hacked the solver to make this work.
Off the top of my head, a potential modification to fix this would be to allow sdeint and sdeint_adjoint to take in additional_ys and additional_params. More explicitly, something like
sde = ...
additional_ys = ...
additional_params = ...

ys = sdeint(sde, y0, ts, bm, additional_ys=additional_ys, additional_params=additional_params)
ys_from_adjoint = sdeint_adjoint(sde, y0, ts, bm, additional_ys=additional_ys, additional_params=additional_params)
The only thing that I'm not too certain about is the format of additional_ys. Having it be a tuple of tensors of size (batch_size, d') makes sense. Though, it would be more useful if it could take in tensors of size (T, batch_size, d') (or (T - 1, batch_size, d')).
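If we did allow a time-indexed context of shape (T - 1, batch_size, d'), the solver would need to select the slice for the interval containing the current time t. A minimal sketch of that lookup (context_at is a hypothetical helper; plain Python, with any indexable sequence standing in for the context tensor):

```python
import bisect

def context_at(t, ts, contexts):
    """Return the context entry for the interval [ts[i], ts[i+1]) containing t.

    ts is a sorted sequence of T times; contexts has T - 1 entries,
    one per interval between consecutive times.
    """
    i = bisect.bisect_right(ts, t) - 1
    i = min(max(i, 0), len(contexts) - 1)  # clamp at the endpoints
    return contexts[i]
```

A step-size-adaptive solver could call this at every drift/diffusion evaluation, which is one reason the fixed (batch_size, d') format is simpler to support.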
You're thinking that additional_ys represents this additional static state, and whilst we're at it, we could add additional_params to augment SDE.parameters() for the adjoint?
If so, I'd note that additional_params would only be needed in the adjoint case. We could follow torchdiffeq for consistency on this: there we called it adjoint_params, and if passed, it is used instead of the parameters of the vector field, rather than as well.
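The torchdiffeq-style semantics amount to a simple "instead of, not as well as" fallback, something like this sketch (resolve_adjoint_params is a hypothetical helper; it only assumes the module exposes .parameters()):

```python
def resolve_adjoint_params(module, adjoint_params=None):
    """Pick the tensors the adjoint differentiates with respect to.

    If adjoint_params is given, use exactly those; otherwise fall back
    to the module's own parameters. The two are never combined.
    """
    if adjoint_params is not None:
        return tuple(adjoint_params)
    return tuple(module.parameters())
```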
On the format of additional_ys: I'm quite keen to avoid explicitly encoding a single batch dimension.
I'd suggest essentially following what autograd.Function does on this: accept a tuple of Python objects, and if they're gradient-requiring tensors then compute gradients w.r.t. them. Allow tensors to be of any shape.
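In other words, the adjoint would accept an arbitrary tuple and differentiate only the gradient-requiring tensors in it. A duck-typed sketch of that filtering (filter_differentiable is a hypothetical helper, not existing torchsde code):

```python
def filter_differentiable(args):
    """Return (index, obj) pairs for the entries gradients should flow to.

    Mirrors the torch.autograd.Function convention: non-tensor objects and
    tensors with requires_grad=False are still passed through to f/g, but
    receive no gradient on the backward pass.
    """
    return [(i, a) for i, a in enumerate(args)
            if getattr(a, "requires_grad", False)]
```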
This does mean that we can't really use v2, as we don't expect to have access to a batch dimension, but I think this kind of batch-dimension hacking is quite fragile to the variety of things a user can throw at it anyway.
For speeding up v1, there is pytorch/pytorch#42368, which mentions the possibility of a torch.vmap, in particular with a view to batched vjps. I don't know the state of it, but it might be interesting to us.
Actually, thinking about it: with the above proposal we wouldn't need an adjoint_params. Whatever extra tensors we need to compute gradients with can just be included in additional_ys and ignored in the drift/diffusion.
Taking a step back, I think having sdeint take in additional_ys is likely to overcomplicate the solver code. I'm not too inclined to do this at the moment.
I do feel a need to support back-propagating gradients to non-parameter nodes with adjoints. I am fully aware of torchdiffeq's adjoint_params, and I can send in a PR on this.
Re: torch.vmap. I'm not entirely sure this will make our lives easier. Given that there's not much documentation on what's going on there, much of this discussion seems rather speculative in my opinion.