Comments (3)
You should use ot.emd2, which returns the OT loss (no need to sum it) as a value with proper gradients. The ot.emd function returns an OT plan that is detached from the input, because the exact OT plan is not differentiable. Could you please try that and tell us whether you get the same error?
from pot.
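The reason the loss can carry gradients while the plan cannot is a standard envelope (Danskin) argument. It is sketched here for the cost matrix $M$ only. The notation is ours, not POT's.

```latex
W(a, b, M) = \min_{G \in \Pi(a, b)} \langle G, M \rangle_F,
\qquad
\Pi(a, b) = \{ G \in \mathbb{R}_{+}^{n \times m} : G \mathbf{1}_m = a,\ G^\top \mathbf{1}_n = b \}

\nabla_M W(a, b, M) = G^\star(a, b, M)
\quad \text{whenever the optimal plan } G^\star \text{ is unique}
```

$W$ is a minimum of functions that are linear in $M$, so its gradient is simply the optimal plan. The plan $G^\star$ itself is piecewise constant in $M$, because it jumps between vertices of $\Pi(a, b)$. Its derivative is therefore zero almost everywhere and undefined at the jumps, which is why it is returned detached.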
I tried with ot.emd2. Same issue.
Like I said, the issue isn't with grad. It also affects jit, vmap, ... any of JAX's code transformation functions.
In the ot backend, numpy is being used where jax.numpy should be!?
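The sketch below illustrates why NumPy calls break every JAX transformation, not only `grad`. The functions are hypothetical and do not reproduce POT's backend code. Inside `jax.grad`, `jax.jit`, or `jax.vmap`, the argument is a tracer. Calling `np.asarray` on a tracer raises `jax.errors.TracerArrayConversionError`, while the `jax.numpy` version traces fine.

```python
import numpy as np
import jax
import jax.numpy as jnp


def sq_norm_numpy(x):
    # np.asarray needs a concrete array; a JAX tracer cannot be converted.
    return np.sum(np.asarray(x) ** 2)


def sq_norm_jnp(x):
    # jax.numpy operations are traceable by every JAX transformation.
    return jnp.sum(x ** 2)


def traces_ok(fn, transform, x):
    """Return True if transform(fn)(x) runs, False on a tracer conversion error."""
    try:
        transform(fn)(x)
        return True
    except jax.errors.TracerArrayConversionError:
        return False


x = jnp.array([1.0, 2.0, 3.0])
results = {}
for name, transform in [("grad", jax.grad), ("jit", jax.jit), ("vmap", jax.vmap)]:
    results[name] = (traces_ok(sq_norm_numpy, transform, x),
                     traces_ok(sq_norm_jnp, transform, x))
    print(f"{name}: numpy ok={results[name][0]}, jax.numpy ok={results[name][1]}")
```

The same NumPy call fails under all three transformations. A backend that mixes `numpy` into code paths meant for JAX arrays will break `jit` and `vmap` as well as `grad`.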
jit and vmap will NOT work for the exact OT solver: it relies on a dedicated C++ solver, and the backend design does not let us handle that properly with JAX. Grad should work and should be tested. We will look into it after some pressing deadlines. If you have ideas, please help us: although we provide a JAX backend, we are mainly PyTorch users, not JAX experts.
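Here is one possible direction, given only as a hedged sketch and not as POT's API. It keeps the exact solver outside JAX tracing and supplies the gradient by hand. It assumes a recent JAX with `jax.pure_callback` and `jax.custom_vjp`, and SciPy installed. `exact_ot_loss` and `_host_plan` are hypothetical names. `scipy.optimize.linear_sum_assignment` stands in for POT's C++ solver, which is only valid for uniform weights and square cost matrices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import linear_sum_assignment


def _host_plan(M):
    """Exact OT plan for uniform weights and a square cost matrix (runs on the host).

    Stand-in for a C++ exact solver: with a = b = 1/n an optimal plan is a
    permutation matrix scaled by 1/n, which linear_sum_assignment finds.
    """
    M = np.asarray(M)
    n = M.shape[0]
    rows, cols = linear_sum_assignment(M)
    G = np.zeros_like(M)
    G[rows, cols] = 1.0 / n
    return G


def _exact_plan(M):
    # pure_callback lets traced/jitted code call the NumPy/SciPy solver;
    # JAX only needs to know the output shape and dtype.
    out_spec = jax.ShapeDtypeStruct(M.shape, M.dtype)
    return jax.pure_callback(_host_plan, out_spec, M)


@jax.custom_vjp
def exact_ot_loss(M):
    """Hypothetical exact OT loss <G*, M> with uniform weights."""
    return jnp.sum(_exact_plan(M) * M)


def _loss_fwd(M):
    G = _exact_plan(M)
    # Keep the plan as the residual for the backward pass.
    return jnp.sum(G * M), G


def _loss_bwd(G, g):
    # Envelope theorem: d<G*, M>/dM = G*, with the plan treated as a constant.
    return (g * G,)


exact_ot_loss.defvjp(_loss_fwd, _loss_bwd)

M = jnp.array([[1.0, 2.0], [3.0, 0.0]])
print(jax.jit(exact_ot_loss)(M))   # loss computed under jit
print(jax.grad(exact_ot_loss)(M))  # gradient equals the optimal plan
```

The plan is never differentiated. The backward rule returns it as a constant, consistent with the remark that the exact plan is not differentiable. `vmap` is not handled in this sketch. It would need a batching strategy for the callback, which recent JAX versions configure through the `vmap_method` argument of `jax.pure_callback`.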
Related Issues (20)
- `gromov_wasserstein` returning the zero array as the optimal coupling
- `sinkhorn_lpl1_mm` performs unnecessary computations
- Find correspondences between a set of 3D gaussian distributions (quadrics or ellipsoids) and a set of 2d gaussian distributions (conics or ellipses)
- Parallelization problem for 3D tensor
- Modernize/Refactor the network_simplex method
- Remove jax version constraint when pymanopt use new jax config format
- ot.solve uses GPU even though tensors are on CPU?
- How to conduct scenario reduction with this project?
- Too low tolerance in `test_solve_sample_methods`
- Upstream Licencing conflict with GPL: POT and CVXOPT
- Issue with Importing POT Library in Apple Silicon Environment with TensorFlow
- Numpy 2.0 compatibility
- Questions for ot.barycenter
- Deprecate distutils in favor of setuptools
- Add task runners for tests and extend contributing docs
- [Bug] Linesearch hidden in scipy 1.14
- Incompatibility with numpy 2.0
- Mean computed without weights in empirical gaussain OT
- UnbalancedSinkhorn Transport fails to transform due to "nx.array_equal"