Comments (7)
Thanks a lot for the report! Could you provide a short script for reproducing the issue?
from jaxopt.
I made many attempts at creating a small script to reproduce the problem and unfortunately failed. I apologize in advanced that this is a mess to look at but the full script is the only place I was able to get the issue to happen (leading me to believe that it is likely something specific about my model but I am not very adept at diagnosing the compilations that occur).
import numpy as np # 1.21.2
import jax.numpy as jnp # jax version = 0.2.19
from jax.experimental import ode
from jax import jit
from jax.config import config
config.update("jax_enable_x64", True)
#############################################################################
'''
All of the following is establishing constants and equations used throughout
'''
#############################################################################
sigma = 5.670374419E-8 # Stefan-Boltzman constant in W m^-2 K^-4
G = 6.67408E-11 #Gravitational constant in m^3 kg^-1 s^-2
#solar constants
Rsun = 696.34E6 #meters
Msun = 1.9891E30 #kg
Lsun = 3.846E26 #Watts
h = 6.62606957e-34 #Planck constant in m^2 kg/s
hbar = 1.05457173e-34 # h/2pi
kb = 1.38064852E-23 #Boltzman constant in m^2 kg / (s^2 K)
me = 9.109E-31 #mass of electron in kg
mu = 1.66053904E-27 #atomic mass unit in kg
mp = 1.67262178e-27 # mass of proton
ep0 = 8.85418782E-12 #vaccum permittivity in kg^-1 m^-3 s^4 A^2
e=1.60217662E-19 # charge of the electron in coulombs
ev_K=1.16045221E4 # conversion from eV to Kelvin (temperature)
c=299792458 #speed of light
keV_to_K=11604525.006165707 #conversion from keV to Kelvin (temperature)
ecgs=4.80320425E-10 #electronic charge in cgs quantities
keV_to_J=1./6.242e+15 #Conversion from keV to Joules
NA=6.02214076E23 #Avagodaro's number
Amol_coff=mu*NA #Coefficient to describe Abar in moles kg/mol
g_ff = 1 #gaunt factor
a_rad = (4.0*sigma)/c # radiation constant
class Constants:
pass
cst = Constants()
cst.sigma = sigma
cst.G = G
cst.Rs = Rsun
cst.Ms = Msun
cst.Ls = Lsun
cst.h = h
cst.kb = kb
cst.me = me
cst.mu = mu
cst.mp = mp
cst.ep0 = ep0
cst.e = e
cst.ev_K = ev_K
cst.c = c
cst.keV_to_K = keV_to_K
cst.ecgs = ecgs
cst.keV_to_J = keV_to_J
cst.NA = NA
cst.Amol_coff = Amol_coff
cst.g_ff = g_ff
cst.a_rad = a_rad
cst.hbar = hbar
r0 = 10
rint = jnp.arange(r0,1.5*cst.Rs,1E4)
comp = (0.73,0.25,.02)
mu = (2*comp[0] + 0.75*comp[1] + 0.5*comp[2])**-1
gam = 5./3.
def Kappa_ff(density,temp,comp):
X,Y,Z = comp
return (1.0e24)*(1.0+X)*(Z + 0.0001)*((density/1.0e3)**0.7)*(temp**(-3.5))
def Kappa_Hminus(density,temp,comp):
X,Y,Z = comp
return (2.5e-32)*(Z / 0.02)*((density/1.0e3)**0.5)*((temp)**9.0)
def Kappa_bf(density,temp,comp):
rho_cgs = density/1000
X,Y,Z = comp
tog_bf = 2.82*(rho_cgs*(1.0e0 + X))**0.2
return 4.34E25/tog_bf*Z*(1.0e0 + X)*rho_cgs/temp**3.5
def Kappa_e(comp):
X,Y,Z = comp
return 0.2*(1.0 + X)
def Kappa(density,temp,comp):
kap = (Kappa_Hminus(density,temp,comp)**-1 + Kappa_e(comp)**-1 +
Kappa_ff(density,temp,comp)**-1 + Kappa_bf(density,temp,comp)**-1)**-1.0
return kap
def epsilon_PP(density,temp,comp):
X,Y,Z = comp
return 1.07e-7*(density/1.0e5)*(X**2.0)*((temp/1.0e6)**4.0)
def epsilon_CNO(density,temp,comp):
X,Y,Z = comp
return 8.24e-26*(density/1.0e5)*0.03*(X**2.0)*((temp / (1.0e6))**19.9)
def epsilon(density,temp,comp):
return epsilon_PP(density,temp,comp) + epsilon_CNO(density,temp,comp)
def P_degen(density,cst):
return (((3.0*(np.pi)**2.0)**(2.0/3.0)*(cst.hbar**2.0)*
(density/cst.mp)**(5.0/3.0))/(5.0*cst.me))
def P_ideal(density,temp,cst,A=1):
return (density*cst.kb*temp)/(A*cst.mp)
def P_idealplus(density,temp,cst,A=1):
cof = (1-np.exp(-(density/1E5)**2))
return (1+cof) * (density*cst.kb*temp)/(A*cst.mp)
def P_rad(temp,cst):
return (cst.a_rad*(temp)**4.0)/(3.0)
def P(rho,T,cst,Abar):
return P_degen(rho,cst) + P_ideal(rho,T,cst,Abar) + P_rad(T,cst)
def dP_drho_degen(density,cst):
return (((3.0*(np.pi)**2.0)**(2.0/3.0)*(cst.hbar**2.0)*
(density/cst.mp)**(2.0/3.0))/(3.0*cst.me*cst.mp))
def dP_drho_ideal(temp,cst,A=1):
return (cst.kb*temp)/(A*cst.mp)
def dP_drho(density,temp,cst,Abar):
return dP_drho_degen(density,cst) + dP_drho_ideal(temp,cst,Abar)
def dP_dT_ideal(density,cst,A=1):
return (density*cst.kb)/(A*cst.mp)
def dP_dT_rad(temp,cst):
return (4.0*cst.a_rad*temp**3.0)/(3.0)
def dP_dT(density,temp,cst,Abar):
return dP_dT_rad(temp,cst) + dP_dT_ideal(density,cst,Abar)
def dT_dr_radiative(mass, density, radius, temp, luminosity, cst,comp):
T_radiative = ((3.0*Kappa(density,temp,comp)*density*luminosity)/
(16.0*np.pi*4.0*cst.sigma*temp**3.0*radius**2.0))
return -T_radiative
def dT_dr_convective(mass, density, radius, temp, luminosity, gamma, cst, Abar):
pressure = P(density, temp,cst, Abar)
T_convective = ((1.0 - (1.0/gamma))*(temp/pressure)*
((cst.G*mass*density)/(radius**2.0)))
return -T_convective
def dT_dr(mass, density, radius, temp, luminosity, gamma, cst, Abar,comp):
return -jnp.min(jnp.abs(jnp.array([dT_dr_radiative(mass,density,
radius,temp,
luminosity,cst,
comp),
dT_dr_convective(mass,density,
radius,temp,
luminosity,gamma,
cst,Abar)])))
def drho_dr(mass, density, radius, temp, luminosity,gam,cst,Abar,comp):
a = (cst.G*mass*density)/(radius**2.0)
b = (dP_dT(density,temp,cst,Abar)*
dT_dr(mass,density,radius,temp,luminosity,gam,cst,Abar,comp))
c = dP_drho(density,temp,cst,Abar)
return -(a+b)/c
def dM_dr(density, radius):
return 4.0 * np.pi * (radius**2.0) * density
def dL_dr(density,radius,temp,comp):
return 4.0 * np.pi * (radius**2.0) * density * epsilon(density,temp,comp)
def dtau_dr(density,temp,comp):
return Kappa(density,temp,comp) * density
#############################################################################
'''
The following are the functions used in the optimizer
'''
#############################################################################
@jit
def dy_dr_jax(y,r):
'''
Vector of coupled first order ODEs to solve for the stellar model.
Parameters
----------
y : array
Values for y where derivative is calculated.
r : float
Value of radius (independent variable) at which to evaluate the
derivatives.
Returns
-------
dy_dr : array
Values of the derivatives for each of the input variables.
'''
rho,T,M,L,tau = y
frho = drho_dr(M,rho,r,T,L,5./3.,cst,mu,comp)
fT = dT_dr(M,rho,r,T,L,5./3.,cst,mu,comp)
fM = dM_dr(rho,r)
fL = dL_dr(rho,r,T,comp)
ftau = dtau_dr(rho,T,comp)
return jnp.array([frho, fT, fM, fL, ftau])
T0 = 10.5E6
@jit
def L_check_jax(rho0in):
M0 = 4./3. * np.pi * rho0in * r0**3
L0 = epsilon(rho0in,T0,comp)*M0 #J/s
tau0 = Kappa(rho0in,T0,comp)*rho0in
y0 = [rho0in,T0,M0,L0,tau0]
rho,T,M,L,tau = ode.odeint(dy_dr_jax,y0,rint,rtol=1E-15,atol=1E-15)
T = jnp.nan_to_num(T,nan = 1E5)
i_surf = jnp.argmin(abs(jnp.diff(T)))
LT = (4.0*jnp.pi*cst.sigma*(rint[i_surf])**2.0*(T[i_surf])**4.0)
return (L[i_surf] - LT)/jnp.sqrt(L[i_surf]*LT)
from scipy.optimize import bisect # 1.7.1
from jaxopt import Bisection # latest pip install version
print(bisect(L_check_jax,10000,1000000))
print(Bisection(L_check_jax,10000,1000000).run(None))
The top portion of the code just establishes a bunch of constants and algebra equations. The crux of the model are the 2 functions dy_dr_jax
and L_check_jax
. The function being optimized calls the ode solver, integrates the differential equations (dy_dr_jax) and then checks for some condition based on the output. I think everything is jaxfriendly within the functions, they are all jitable, and the scipy bisection method runs fine and gives a correct result. The jaxpot.Bisection class gets built fine i.e.
Bisection(L_check_jax,10000,1000000)
>>> Bisection(optimality_fun=<CompiledFunction object at 0x1910daa00>, lower=10000, upper=1000000, increasing=True, maxiter=30, tol=1e-05, check_bracket=True, implicit_diff=True, verbose=False)
but Bisection(L_check_jax,10000,1000000).run(None)
hangs and does not produce any errors.
from jaxopt.
I would guess we should terminate if high - low
is less than some tolerance, in addition (or instead) of checking the error? e.g., like in
https://github.com/google-research/neural-structural-optimization/blob/1c11b8c6ef50274802a84cf1a244735c3ed9394d/neural_structural_optimization/autograd_lib.py#L213
Otherwise the loop can fail to terminate due to issues with floating point precision.
from jaxopt.
Quick remark before I look into the problem more closely, Bisection
has a maxiter
option (defaulting to 30) so normally it should never hang even in case of numerical issue https://github.com/google/jaxopt/blob/main/jaxopt/_src/bisection.py#L163
from jaxopt.
The issue was that your L_check_jax
function doesn't like integer inputs. So L_check_jax(90000)
hangs while L_check_jax(90000.0)
works. This means that SciPy was converting the bracket interval from integer to floats under the hood. I did the same together with some other improvements in #19. In particular, you no longer need to pass None
to run
.
So after #19 is merged, you should be able to do print(Bisection(L_check_jax,10000,1000000).run().params)
.
In the meantime, you need to do print(Bisection(L_check_jax,10000.0,1000000.0).run(None).params)
.
from jaxopt.
@mblondel Oh wow, that's a silly mistake on my part (which I should've caught). Thanks for looking into this!
from jaxopt.
Fixed in #19.
from jaxopt.
Related Issues (20)
- Type precision issue in BoxOSQP HOT 10
- Garbage collection issues HOT 1
- Wrong failure diagnostic print outs from `ZoomLineSearch` under `vmap` HOT 3
- Attempted boolean conversion of traced array - for hager-zhang HOT 2
- Expression tree API like CVXPY HOT 3
- Unnecessary recompilation of _while_loop_lax HOT 8
- Add type annotations
- Consider switching to pyproject.toml
- OSQP crashing on unexpected params HOT 3
- `verbose=False` is not working as expected for `NonlinearCG` HOT 1
- Stochastic L-BFGS algorithm implementation
- Stopping condition 'madsen-nielsen' incorrect
- unit test failures on aarch64 linux with scipy 1.12
- `LevenbergMarquardt` implementation does not accept PyTree parameters
- diag(JTJ) can be more efficient
- JAXOPT Projected Gradient
- "invalid escape sequence" warning in `BoxOSQP` docstring
- Error when taking gradient wrt parameters in BoxOSQP
- Disable warnings in vmap or print to stderr HOT 2
- Constraint violation causes L-BGFS-B to fail HOT 1
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from jaxopt.