NeuralPDE

NeuralPDE.jl is a package of neural network solvers for partial differential equations built on scientific machine learning (SciML) techniques such as physics-informed neural networks (PINNs) and deep BSDE solvers. It uses deep neural networks and neural stochastic differential equations to solve high-dimensional PDEs at a greatly reduced cost, and with much greater generality, than classical methods.
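A PINN turns a PDE into an optimization problem. A candidate function is scored by how much it violates the equation at interior collocation points plus how much it violates the boundary conditions, and training drives that score toward zero. The Julia sketch below illustrates the idea on the same 2D Poisson problem used in the first example. It is an illustration under stated assumptions, not NeuralPDE's implementation: the helpers `laplacian_fd` and `pinn_loss`, the 21-point grid, the equal weighting of the two terms, and the finite-difference Laplacian are all hypothetical choices. NeuralPDE instead parameterizes the candidate with a neural network, differentiates it internally, and picks training points according to its training strategy (for example QuadratureTraining).

```julia
# Hypothetical sketch of a physics-informed loss (NOT NeuralPDE's internal code) for
#   u_xx + u_yy = -sin(πx) sin(πy) on [0,1]^2,  u = 0 on the boundary.

# Central finite-difference approximation of the Laplacian of u(x, y) (assumed scheme).
laplacian_fd(u, x, y; h = 1e-3) =
    (u(x + h, y) - 2u(x, y) + u(x - h, y)) / h^2 +
    (u(x, y + h) - 2u(x, y) + u(x, y - h)) / h^2

source(x, y) = -sin(pi * x) * sin(pi * y)

function pinn_loss(u; n = 21)
    grid = range(0.0, 1.0; length = n)
    # Interior collocation points: mean squared PDE residual.
    interior = [(x, y) for x in grid[2:end-1] for y in grid[2:end-1]]
    pde_loss = sum((laplacian_fd(u, x, y) - source(x, y))^2 for (x, y) in interior) /
               length(interior)
    # Boundary points: mean squared violation of u = 0 on the four edges.
    boundary = vcat([(0.0, t) for t in grid], [(1.0, t) for t in grid],
                    [(t, 0.0) for t in grid], [(t, 1.0) for t in grid])
    bc_loss = sum(u(x, y)^2 for (x, y) in boundary) / length(boundary)
    return pde_loss + bc_loss
end

exact(x, y) = sin(pi * x) * sin(pi * y) / (2pi^2)   # analytic solution
bump(x, y) = 16x * (1 - x) * y * (1 - y)            # satisfies the BCs, not the PDE

println(pinn_loss(exact))   # near zero: only discretization and rounding error remain
println(pinn_loss(bump))    # large: the PDE residual dominates
```

An optimizer would adjust the network parameters to minimize this kind of quantity; the boundary term alone cannot tell `bump` from the true solution, which is why both terms are needed.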

Installation

Assuming Julia is already installed, install NeuralPDE.jl in the standard way: at the Julia REPL, type ] add NeuralPDE (the ] key switches the REPL into Pkg mode). To exit Pkg mode, press Backspace or Ctrl + C.

Tutorials and Documentation

For information on using the package, see the stable documentation. To try features that have not been released yet, use the in-development documentation.

Features

  • Physics-informed neural networks for automated PDE solving.
  • Forward-backward stochastic differential equation (FBSDE) methods for parabolic PDEs.
  • Deep-learning-based solvers for optimal stopping time problems and Kolmogorov backward equations.
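The FBSDE and Kolmogorov solvers listed above rest on the Feynman-Kac link between PDEs and stochastic processes: the solution of a Kolmogorov backward equation at a point equals an expectation over paths of an SDE. The sketch below shows that link with plain Monte Carlo for the simplest case, the heat equation ∂u/∂t = (σ²/2) ∂²u/∂x² with u(0, x) = φ(x), whose solution is u(t, x) = E[φ(x + σ W_t)]. It is not a neural solver and not NeuralPDE's API; the function `kolmogorov_mc`, the test function φ(x) = x², and the sample size are hypothetical choices. Roughly speaking, deep Kolmogorov-type methods replace this per-point average with a neural network trained to approximate x ↦ E[φ(X_t^x)] over a whole region.

```julia
using Random, Statistics

# Hypothetical helper: Monte Carlo estimate of u(t, x) = E[φ(x + σ W_t)], the Feynman-Kac
# solution of the heat equation ∂u/∂t = (σ^2 / 2) ∂²u/∂x² with u(0, x) = φ(x).
function kolmogorov_mc(φ, x, t, σ; n = 100_000, rng = Random.default_rng())
    # W_t ~ Normal(0, t), so sample it as sqrt(t) * standard normal.
    samples = [φ(x + σ * sqrt(t) * randn(rng)) for _ in 1:n]
    return mean(samples)
end

φ(x) = x^2                      # initial condition with a closed-form answer
rng = MersenneTwister(42)       # fixed seed for reproducibility
est = kolmogorov_mc(φ, 1.0, 0.5, 2.0; rng = rng)
exact = 1.0^2 + 2.0^2 * 0.5     # for φ(x) = x^2: u(t, x) = x^2 + σ^2 t
println((est, exact))
```

The estimate converges at the usual Monte Carlo rate of one over the square root of the sample count, independent of the spatial dimension, which is the property that makes these representations attractive for high-dimensional PDEs.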

Example: Solving the 2D Poisson Equation with Physics-Informed Neural Networks

using NeuralPDE, Flux, ModelingToolkit, GalacticOptim, DiffEqFlux
using Quadrature, Cubature
import ModelingToolkit: Interval, infimum, supremum

@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2

# 2D PDE
eq  = Dxx(u(x,y)) + Dyy(u(x,y)) ~ -sin(pi*x)*sin(pi*y)

# Boundary conditions: u = 0 on all four edges (sin(pi*1) = 0)
bcs = [u(0,y) ~ 0.0, u(1,y) ~ 0.0,
       u(x,0) ~ 0.0, u(x,1) ~ 0.0]
# Spatial domain
domains = [x ∈ Interval(0.0,1.0),
           y ∈ Interval(0.0,1.0)]
# Grid spacing, used below only to build the evaluation grid
# (QuadratureTraining chooses its own training points)
dx = 0.1

# Neural network
dim = 2 # number of dimensions
chain = FastChain(FastDense(dim,16,Flux.σ),FastDense(16,16,Flux.σ),FastDense(16,1))

# Initial parameters of the neural network
initθ = Float64.(DiffEqFlux.initial_params(chain))

discretization = PhysicsInformedNN(chain, QuadratureTraining(), init_params = initθ)

@named pde_system = PDESystem(eq,bcs,domains,[x,y],[u(x, y)])
prob = discretize(pde_system,discretization)

cb = function (p,l)
    println("Current loss is: $l")
    return false
end

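# Train with a larger learning rate first, then restart from the trained
# parameters and refine with a smaller learning rate
res = GalacticOptim.solve(prob, ADAM(0.1); cb = cb, maxiters=4000)
prob = remake(prob,u0=res.minimizer)
res = GalacticOptim.solve(prob, ADAM(0.01); cb = cb, maxiters=2000)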
phi = discretization.phi

Then compare the trained network with the analytic solution:

xs,ys = [infimum(d.domain):dx/10:supremum(d.domain) for d in domains]
analytic_sol_func(x,y) = (sin(pi*x)*sin(pi*y))/(2pi^2)

# Rows index y and columns index x, the layout Plots expects for contour data
u_predict = reshape([first(phi([x,y],res.minimizer)) for x in xs for y in ys],(length(ys),length(xs)))
u_real = reshape([analytic_sol_func(x,y) for x in xs for y in ys], (length(ys),length(xs)))
diff_u = abs.(u_predict .- u_real)

using Plots
p1 = plot(xs, ys, u_real, linetype=:contourf,title = "analytic");
p2 = plot(xs, ys, u_predict, linetype=:contourf,title = "predict");
p3 = plot(xs, ys, diff_u,linetype=:contourf,title = "error");
plot(p1,p2,p3)

[Figure: contour plots of the analytic solution, the PINN prediction, and their absolute error]
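The contour plots show where the error sits, but a couple of scalar metrics make runs easier to compare. The sketch below is a hypothetical helper, `error_summary`, that reports the maximum absolute error and the relative L2 error between two grids; the metric choice is an assumption, not something NeuralPDE provides. So that the block runs on its own, it uses stand-in data (the analytic solution shifted by a constant 1e-3) instead of a trained network; with the example above you would pass the `u_predict` and `u_real` matrices directly.

```julia
using LinearAlgebra

# Hypothetical helper: scalar error metrics between a predicted grid and a reference grid.
function error_summary(u_predict::AbstractMatrix, u_real::AbstractMatrix)
    size(u_predict) == size(u_real) || throw(DimensionMismatch("grids differ in size"))
    diff_u = abs.(u_predict .- u_real)
    return (max_abs = maximum(diff_u),                               # worst pointwise error
            rel_l2 = norm(u_predict .- u_real) / norm(u_real))       # Frobenius-norm ratio
end

# Stand-in data on the same 0.01-spaced grid as the example.
xs = ys = 0.0:0.01:1.0
analytic_sol_func(x, y) = sin(pi * x) * sin(pi * y) / (2pi^2)
u_real = [analytic_sol_func(x, y) for y in ys, x in xs]   # rows index y, columns index x
u_predict = u_real .+ 1e-3                                 # pretend prediction with a small bias
s = error_summary(u_predict, u_real)
println(s)
```

The relative L2 error is usually the more informative of the two here, because the solution's peak is only about 0.05, so an absolute error that looks small can still be a sizable fraction of the signal.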

Example: Solving a 100-Dimensional Hamilton-Jacobi-Bellman Equation

using NeuralPDE
using Flux
using DifferentialEquations
using LinearAlgebra
d = 100 # number of dimensions
X0 = fill(0.0f0, d) # initial value of stochastic control process
tspan = (0.0f0, 1.0f0)
λ = 1.0f0

g(X) = log(0.5f0 + 0.5f0 * sum(X.^2))
f(X,u,σᵀ∇u,p,t) = -λ * sum(σᵀ∇u.^2)
μ_f(X,p,t) = zero(X)  # drift: zero vector (d x 1)
σ_f(X,p,t) = Diagonal(sqrt(2.0f0) * ones(Float32, d)) # diffusion: sqrt(2) * I (d x d matrix)
prob = TerminalPDEProblem(g, f, μ_f, σ_f, X0, tspan)
hls = 10 + d # hidden layer size
opt = Flux.ADAM(0.01)  # optimizer
# sub-neural network approximating solutions at the desired point
u0 = Flux.Chain(Dense(d, hls, relu),
                Dense(hls, hls, relu),
                Dense(hls, 1))
# sub-neural network approximating the spatial gradients at each time point
σᵀ∇u = Flux.Chain(Dense(d + 1, hls, relu),
                  Dense(hls, hls, relu),
                  Dense(hls, hls, relu),
                  Dense(hls, d))
pdealg = NNPDENS(u0, σᵀ∇u, opt=opt)
@time sol = solve(prob, pdealg, verbose=true, maxiters=100, trajectories=100,
                  alg=EM(), dt=1.2, pabstol=1f-2)
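An independent reference value helps judge the trained result. Assuming the example above encodes the standard benchmark HJB equation ∂u/∂t + Δu − λ‖∇u‖² = 0 with terminal condition u(T, x) = g(x) (please check this against NeuralPDE's sign conventions for `f`), the substitution v = exp(−λu) (Cole-Hopf) turns it into a backward heat equation, so u(t, x) = −(1/λ) log E[exp(−λ g(x + √2 W_{T−t}))]. The sketch below evaluates that expectation by plain Monte Carlo. It does not call NeuralPDE; the helper `hjb_reference` and the sample size are hypothetical choices, and g is re-declared in Float64.

```julia
using Random, Statistics

# Hypothetical helper: Monte Carlo evaluation of the Cole-Hopf representation
#   u(t, x) = -(1/λ) * log(E[exp(-λ * g(x + sqrt(2) * W_{T-t}))]),
# assumed to solve  ∂u/∂t + Δu - λ‖∇u‖² = 0,  u(T, x) = g(x).
function hjb_reference(g, x, t, T, λ; n = 10_000, rng = Random.default_rng())
    d = length(x)
    # sqrt(2) * W_{T-t} has independent Normal(0, 2(T - t)) components.
    vals = [exp(-λ * g(x .+ sqrt(2 * (T - t)) .* randn(rng, d))) for _ in 1:n]
    return -log(mean(vals)) / λ
end

g(X) = log(0.5 + 0.5 * sum(X .^ 2))   # same terminal condition as the example
d, λ = 100, 1.0
u_ref = hjb_reference(g, zeros(d), 0.0, 1.0, λ; rng = MersenneTwister(1))
println(u_ref)
```

Under the stated assumption, the value returned by `solve` above should land close to this reference (for λ = 1 it is roughly 4.59); a large gap suggests either too little training or a mismatch between the assumed PDE and the one the example actually encodes.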

Citation

If you use NeuralPDE.jl in your research, please cite this paper:

@article{zubov2021neuralpde,
  title={NeuralPDE: Automating Physics-Informed Neural Networks (PINNs) with Error Approximations},
  author={Zubov, Kirill and McCarthy, Zoe and Ma, Yingbo and Calisto, Francesco and Pagliarino, Valerio and Azeglio, Simone and Bottero, Luca and Luj{\'a}n, Emmanuel and Sulzer, Valentin and Bharambe, Ashutosh and others},
  journal={arXiv preprint arXiv:2107.09443},
  year={2021}
}
