Git Product home page Git Product logo

Comments (7)

mblondel avatar mblondel commented on May 5, 2024

Thanks a lot for the report! Could you provide a short script for reproducing the issue?

from jaxopt.

jjruby09 avatar jjruby09 commented on May 5, 2024

I made many attempts at creating a small script to reproduce the problem and unfortunately failed. I apologize in advanced that this is a mess to look at but the full script is the only place I was able to get the issue to happen (leading me to believe that it is likely something specific about my model but I am not very adept at diagnosing the compilations that occur).

import numpy as np # 1.21.2
import jax.numpy as jnp # jax version = 0.2.19
from jax.experimental import ode
from jax import jit
from jax.config import config
config.update("jax_enable_x64", True)

#############################################################################
'''
All of the following is establishing constants and equations used throughout
'''
#############################################################################
sigma = 5.670374419E-8 # Stefan-Boltzman constant in W m^-2 K^-4
G = 6.67408E-11 #Gravitational constant in m^3 kg^-1 s^-2

#solar constants
Rsun = 696.34E6 #meters
Msun = 1.9891E30 #kg
Lsun = 3.846E26 #Watts

h = 6.62606957e-34 #Planck constant in  m^2 kg/s
hbar = 1.05457173e-34      # h/2pi
kb = 1.38064852E-23 #Boltzman constant in  m^2 kg / (s^2 K)
me = 9.109E-31 #mass of electron in  kg
mu = 1.66053904E-27 #atomic mass unit in kg
mp = 1.67262178e-27 # mass of proton

ep0 = 8.85418782E-12 #vaccum permittivity in kg^-1 m^-3 s^4 A^2
e=1.60217662E-19 # charge of the electron in coulombs
ev_K=1.16045221E4 # conversion from eV to Kelvin (temperature)
c=299792458 #speed of light
keV_to_K=11604525.006165707 #conversion from keV to Kelvin (temperature)
ecgs=4.80320425E-10 #electronic charge in cgs quantities

keV_to_J=1./6.242e+15 #Conversion from keV to Joules
NA=6.02214076E23 #Avagodaro's number
Amol_coff=mu*NA #Coefficient to describe Abar in moles kg/mol
g_ff = 1 #gaunt factor

a_rad = (4.0*sigma)/c     # radiation constant

class Constants:
     pass
cst = Constants()
cst.sigma = sigma
cst.G = G
cst.Rs = Rsun
cst.Ms = Msun
cst.Ls = Lsun
cst.h = h
cst.kb = kb
cst.me = me
cst.mu = mu
cst.mp = mp
cst.ep0 = ep0
cst.e = e
cst.ev_K = ev_K
cst.c = c
cst.keV_to_K = keV_to_K
cst.ecgs = ecgs
cst.keV_to_J = keV_to_J
cst.NA = NA
cst.Amol_coff = Amol_coff
cst.g_ff = g_ff
cst.a_rad = a_rad
cst.hbar = hbar

r0 = 10
rint = jnp.arange(r0,1.5*cst.Rs,1E4)

comp = (0.73,0.25,.02)
mu = (2*comp[0] + 0.75*comp[1] + 0.5*comp[2])**-1
gam = 5./3.

def Kappa_ff(density,temp,comp):
    X,Y,Z = comp
    return (1.0e24)*(1.0+X)*(Z + 0.0001)*((density/1.0e3)**0.7)*(temp**(-3.5))

def Kappa_Hminus(density,temp,comp):
    X,Y,Z = comp
    return (2.5e-32)*(Z / 0.02)*((density/1.0e3)**0.5)*((temp)**9.0)
    
def Kappa_bf(density,temp,comp):
    rho_cgs = density/1000
    X,Y,Z = comp
    tog_bf = 2.82*(rho_cgs*(1.0e0 + X))**0.2
    return 4.34E25/tog_bf*Z*(1.0e0 + X)*rho_cgs/temp**3.5

def Kappa_e(comp):
    X,Y,Z = comp
    return 0.2*(1.0 + X)

def Kappa(density,temp,comp):
    kap = (Kappa_Hminus(density,temp,comp)**-1 + Kappa_e(comp)**-1 + 
           Kappa_ff(density,temp,comp)**-1 + Kappa_bf(density,temp,comp)**-1)**-1.0
    return kap

def epsilon_PP(density,temp,comp):
    X,Y,Z = comp
    return 1.07e-7*(density/1.0e5)*(X**2.0)*((temp/1.0e6)**4.0)

def epsilon_CNO(density,temp,comp):
    X,Y,Z = comp
    return 8.24e-26*(density/1.0e5)*0.03*(X**2.0)*((temp / (1.0e6))**19.9)

def epsilon(density,temp,comp):
    return epsilon_PP(density,temp,comp) + epsilon_CNO(density,temp,comp)

def P_degen(density,cst):
    return (((3.0*(np.pi)**2.0)**(2.0/3.0)*(cst.hbar**2.0)*
             (density/cst.mp)**(5.0/3.0))/(5.0*cst.me))

def P_ideal(density,temp,cst,A=1):
    return (density*cst.kb*temp)/(A*cst.mp)

def P_idealplus(density,temp,cst,A=1):
    cof = (1-np.exp(-(density/1E5)**2))
    return (1+cof) * (density*cst.kb*temp)/(A*cst.mp)

def P_rad(temp,cst):
    return (cst.a_rad*(temp)**4.0)/(3.0)

def P(rho,T,cst,Abar):
    return P_degen(rho,cst) + P_ideal(rho,T,cst,Abar) + P_rad(T,cst)

def dP_drho_degen(density,cst):
    return (((3.0*(np.pi)**2.0)**(2.0/3.0)*(cst.hbar**2.0)*
             (density/cst.mp)**(2.0/3.0))/(3.0*cst.me*cst.mp))

def dP_drho_ideal(temp,cst,A=1):
    return (cst.kb*temp)/(A*cst.mp)

def dP_drho(density,temp,cst,Abar):
    return dP_drho_degen(density,cst) + dP_drho_ideal(temp,cst,Abar)

def dP_dT_ideal(density,cst,A=1):
    return (density*cst.kb)/(A*cst.mp)

def dP_dT_rad(temp,cst):
    return (4.0*cst.a_rad*temp**3.0)/(3.0)

def dP_dT(density,temp,cst,Abar):
    return dP_dT_rad(temp,cst) + dP_dT_ideal(density,cst,Abar)

def dT_dr_radiative(mass, density, radius, temp, luminosity, cst,comp):
    T_radiative = ((3.0*Kappa(density,temp,comp)*density*luminosity)/
                   (16.0*np.pi*4.0*cst.sigma*temp**3.0*radius**2.0))
    return -T_radiative
    
def dT_dr_convective(mass, density, radius, temp, luminosity, gamma, cst, Abar):
    pressure = P(density, temp,cst, Abar)
    T_convective = ((1.0 - (1.0/gamma))*(temp/pressure)* 
                    ((cst.G*mass*density)/(radius**2.0)))
    return -T_convective

def dT_dr(mass, density, radius, temp, luminosity, gamma, cst, Abar,comp):
    return -jnp.min(jnp.abs(jnp.array([dT_dr_radiative(mass,density,
                                                       radius,temp,
                                                       luminosity,cst,
                                                       comp),
                                       dT_dr_convective(mass,density,
                                                        radius,temp,
                                                        luminosity,gamma,
                                                        cst,Abar)])))

def drho_dr(mass, density, radius, temp, luminosity,gam,cst,Abar,comp):
    a = (cst.G*mass*density)/(radius**2.0)
    b = (dP_dT(density,temp,cst,Abar)*
         dT_dr(mass,density,radius,temp,luminosity,gam,cst,Abar,comp))
    c = dP_drho(density,temp,cst,Abar)
    return -(a+b)/c

def dM_dr(density, radius):
    return 4.0 * np.pi * (radius**2.0) * density

def dL_dr(density,radius,temp,comp):
    return 4.0 * np.pi * (radius**2.0) * density * epsilon(density,temp,comp)

def dtau_dr(density,temp,comp):
    return Kappa(density,temp,comp) * density

#############################################################################
'''
The following are the functions used in the optimizer
'''
#############################################################################
@jit
def dy_dr_jax(y,r):
    '''
    Vector of coupled first order ODEs to solve for the stellar model. 
    
    Parameters
    ----------
    y : array
        Values for y where derivative is calculated. 
        
    r : float
        Value of radius (independent variable) at which to evaluate the 
        derivatives. 
    
    Returns
    -------
    dy_dr : array
        Values of the derivatives for each of the input variables. 
        
    '''
    rho,T,M,L,tau = y
    frho = drho_dr(M,rho,r,T,L,5./3.,cst,mu,comp)
    fT = dT_dr(M,rho,r,T,L,5./3.,cst,mu,comp)
    fM = dM_dr(rho,r)
    fL = dL_dr(rho,r,T,comp)
    ftau = dtau_dr(rho,T,comp)
    return jnp.array([frho, fT, fM, fL, ftau])


T0 = 10.5E6
    
@jit
def L_check_jax(rho0in):
    M0 = 4./3. * np.pi * rho0in * r0**3
    L0 = epsilon(rho0in,T0,comp)*M0 #J/s
    tau0 = Kappa(rho0in,T0,comp)*rho0in
    y0 = [rho0in,T0,M0,L0,tau0]
    rho,T,M,L,tau = ode.odeint(dy_dr_jax,y0,rint,rtol=1E-15,atol=1E-15)
    T = jnp.nan_to_num(T,nan = 1E5)
    i_surf = jnp.argmin(abs(jnp.diff(T)))
    LT = (4.0*jnp.pi*cst.sigma*(rint[i_surf])**2.0*(T[i_surf])**4.0)
    return (L[i_surf] - LT)/jnp.sqrt(L[i_surf]*LT)



from scipy.optimize import bisect # 1.7.1
from jaxopt import Bisection # latest pip install version

print(bisect(L_check_jax,10000,1000000))
print(Bisection(L_check_jax,10000,1000000).run(None))

The top portion of the code just establishes a bunch of constants and algebra equations. The crux of the model are the 2 functions dy_dr_jax and L_check_jax. The function being optimized calls the ode solver, integrates the differential equations (dy_dr_jax) and then checks for some condition based on the output. I think everything is jaxfriendly within the functions, they are all jitable, and the scipy bisection method runs fine and gives a correct result. The jaxpot.Bisection class gets built fine i.e.

Bisection(L_check_jax,10000,1000000)
>>> Bisection(optimality_fun=<CompiledFunction object at 0x1910daa00>, lower=10000, upper=1000000, increasing=True, maxiter=30, tol=1e-05, check_bracket=True, implicit_diff=True, verbose=False)

but Bisection(L_check_jax,10000,1000000).run(None) hangs and does not produce any errors.

from jaxopt.

shoyer avatar shoyer commented on May 5, 2024

I would guess we should terminate if high - low is less than some tolerance, in addition (or instead) of checking the error? e.g., like in
https://github.com/google-research/neural-structural-optimization/blob/1c11b8c6ef50274802a84cf1a244735c3ed9394d/neural_structural_optimization/autograd_lib.py#L213

Otherwise the loop can fail to terminate due to issues with floating point precision.

from jaxopt.

mblondel avatar mblondel commented on May 5, 2024

Quick remark before I look into the problem more closely, Bisection has a maxiter option (defaulting to 30) so normally it should never hang even in case of numerical issue https://github.com/google/jaxopt/blob/main/jaxopt/_src/bisection.py#L163

from jaxopt.

mblondel avatar mblondel commented on May 5, 2024

The issue was that your L_check_jax function doesn't like integer inputs. So L_check_jax(90000) hangs while L_check_jax(90000.0) works. This means that SciPy was converting the bracket interval from integer to floats under the hood. I did the same together with some other improvements in #19. In particular, you no longer need to pass None to run.
So after #19 is merged, you should be able to do print(Bisection(L_check_jax,10000,1000000).run().params).
In the meantime, you need to do print(Bisection(L_check_jax,10000.0,1000000.0).run(None).params).

from jaxopt.

jjruby09 avatar jjruby09 commented on May 5, 2024

@mblondel Oh wow, that's a silly mistake on my part (which I should've caught). Thanks for looking into this!

from jaxopt.

mblondel avatar mblondel commented on May 5, 2024

Fixed in #19.

from jaxopt.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.