DiffEqBayes.jl

This repository provides extension functionality for estimating the parameters of differential equations using Bayesian methods. It lets you choose between CmdStan.jl (https://github.com/StanJulia/CmdStan.jl), Turing.jl, DynamicHMC.jl and ApproxBayes.jl to perform Bayesian estimation for a differential equation problem specified via the DifferentialEquations.jl interface.

To begin, add this package with the following commands:

using Pkg
Pkg.add("DiffEqBayes")
using DiffEqBayes

stan_inference

stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing;
               alg=:rk45, num_samples=1000, num_warmup=1000, 
               reltol=1e-3, abstol=1e-6, maxiter=Int(1e5),likelihood=Normal,
               vars=(StanODEData(),InverseGamma(3,3)),nchains=1, sample_u0 = false, 
               save_idxs = nothing, diffeq_string = nothing, printsummary = true)

stan_inference uses CmdStan.jl to perform the Bayesian inference, so a working Stan (CmdStan) installation is required to use this function. The first argument is a DEProblem, t is the array of time points, and data is a matrix whose first dimension (rows) indexes the system variables and whose second dimension (columns) indexes the time points, so data[:,i] holds the observations at t[i]. priors is an array of prior distributions, one per parameter, specified as Distributions.jl types. alg selects one of Stan's two internal integrators, :rk45 or :bdf. num_samples is the number of samples to take per chain, and num_warmup is the number of MCMC warmup steps. abstol and reltol are the tolerances passed to the internal integrator. likelihood is the likelihood distribution, whose arguments are taken from vars, and vars is a tuple of priors for the likelihood hyperparameters. The special value StanODEData() in this tuple marks the position that the ODE solution takes in the likelihood's parameter list: with the defaults, each observation is modeled as Normal(ODE solution, σ) with σ ~ InverseGamma(3,3). With the diffeq_string kwarg you can pass in your own, more complex ODE specification as a string if the need arises.
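
The data layout described above can be sketched in plain Julia. This is an illustration only: the observation vectors are made-up numbers rather than samples from an ODE solution, and reduce(hcat, ...) is used as a stand-in for the convert(Array, VectorOfArray(...)) call in the example; both stack one observation vector per time point as a column.

```julia
# Made-up observations: one 2-element state vector (x, y) per time point.
t = [1.0, 2.0, 3.0]
obs = [[1.0, 0.5], [1.2, 0.4], [0.9, 0.7]]

# Stack the vectors as columns: rows index variables, columns index time points.
data = reduce(hcat, obs)

println(size(data))   # (2, 3)
println(data[:, 2])   # all variables observed at t[2]
println(data[1, :])   # variable x observed at every time in t
```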

turing_inference

turing_inference(prob::DiffEqBase.DEProblem,alg,t,data,priors;
                    likelihood_dist_priors = [InverseGamma(2, 3)], 
                    likelihood = (u,p,t,σ) -> MvNormal(u, σ[1]*ones(length(u))),
                    num_samples=1000, sampler = Turing.NUTS(0.65),
                    syms = [Turing.@varname(theta[i]) for i in 1:length(priors)],
                    sample_u0 = false, save_idxs = nothing, progress = false, kwargs...)

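turing_inference uses Turing.jl to perform its parameter inference. prob can be any DEProblem, solved with the corresponding alg choice. t is the array of time points, and data[:,i] is the set of observations for the differential equation system at time point t[i] (data may also be higher dimensional). priors is an array of prior distributions, one per parameter, specified as Distributions.jl types. likelihood_dist_priors are the priors for the parameters of the likelihood, and likelihood is a function of (u,p,t,σ) that returns the distribution of the observations given the model state u; by default it is a multivariate normal centred on u with standard deviation σ[1] for every component, with σ[1] ~ InverseGamma(2, 3). num_samples is the number of samples per MCMC chain, and sampler is the Turing sampler to use (NUTS(0.65) by default). The extra kwargs are given to the internal differential equation solver.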
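
Assuming the vector second argument of MvNormal is read as per-component standard deviations (the behaviour of the Distributions.jl versions this default was written against), the default likelihood treats each component of an observation as independently normal around the model state with the shared scale σ[1]. The sketch below writes that log density out by hand instead of calling Distributions.jl; the helpers iso_normal_logpdf and total_loglik and the small matrices are hypothetical.

```julia
# Log density of an isotropic normal: each component x[k] ~ Normal(u[k], σ).
# Hypothetical helper, written out by hand instead of using Distributions.jl.
function iso_normal_logpdf(x::AbstractVector, u::AbstractVector, σ::Real)
    n = length(x)
    sq = sum(abs2, x .- u)                  # squared Euclidean residual
    return -n / 2 * log(2π) - n * log(σ) - sq / (2σ^2)
end

# Total log-likelihood over all time points: data[:, i] is compared with the
# model state at t[i], stored here as column i of `sol`.
function total_loglik(data::AbstractMatrix, sol::AbstractMatrix, σ::Real)
    return sum(iso_normal_logpdf(data[:, i], sol[:, i], σ) for i in axes(data, 2))
end

sol  = [1.0 1.1; 0.5 0.6]    # made-up model states (2 variables × 2 times)
data = [1.0 1.3; 0.5 0.6]    # made-up observations
println(total_loglik(data, sol, 0.1))
```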

dynamichmc_inference

dynamichmc_inference(problem::DiffEqBase.DEProblem, algorithm, t, data,parameter_priors, 
                    parameter_transformations=as(Vector, asℝ₊, length(parameter_priors));
                    σ_priors = fill(Normal(0, 5), size(data, 1)),
                    rng = Random.GLOBAL_RNG, num_samples = 1000,
                    AD_gradient_kind = Val(:ForwardDiff),solve_kwargs = (), 
                    mcmc_kwargs = (initialization = (q = zeros(length(parameter_priors) + 2),),), sample_u0 = false)

dynamichmc_inference uses DynamicHMC.jl to perform the Bayesian parameter estimation. problem can be any DEProblem, solved with algorithm, and data is the set of observations for the model that is used in the Bayesian inference. t is the array of time points. parameter_priors is an array of Distributions.jl distributions giving the priors for the parameters to be estimated. parameter_transformations is a TransformVariables.jl transformation that constrains the parameter values to specific domains; the default, as(Vector, asℝ₊, length(parameter_priors)), restricts every parameter to the positive reals. σ_priors are the priors for the observation noise scales, one per row of data. rng is the random number generator used for MCMC and defaults to the global one. num_samples is the number of MCMC draws (default: 1000). AD_gradient_kind is passed on to LogDensityProblems.ADgradient; make sure to import the corresponding library. solve_kwargs is passed on to solve, and mcmc_kwargs are passed on as keyword arguments to DynamicHMC.mcmc_with_warmup.
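
The default parameter_transformations lets the sampler work in unconstrained space. Assuming the exponential map that asℝ₊ is based on, a positive parameter θ is written as θ = exp(x) for real x, and the log density picks up the log-Jacobian term x. The sketch below is a hand-written illustration of that change of variables using an Exponential(1) prior (log density -θ) as a stand-in; it does not call TransformVariables.jl, and all function names are hypothetical.

```julia
# Hypothetical prior on a positive parameter θ: Exponential(1), log p(θ) = -θ.
logprior_constrained(θ::Real) = θ > 0 ? -θ : -Inf

# Map an unconstrained real x to θ = exp(x) > 0 (assumed exponential map).
to_positive(x::Real) = exp(x)

# Log density in unconstrained space: log p(exp(x)) + log|dθ/dx|,
# where log|dθ/dx| = log(exp(x)) = x.
logprior_unconstrained(x::Real) = logprior_constrained(to_positive(x)) + x

for x in (-1.0, 0.0, 1.0)
    println("x = $x  θ = $(to_positive(x))  logp = $(logprior_unconstrained(x))")
end
```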

abc_inference

abc_inference(prob::DEProblem, alg, t, data, priors; ϵ=0.001,
     distancefunction = euclidean, ABCalgorithm = ABCSMC, progress = false,
     num_samples = 500, maxiterations = 10^5, kwargs...)

abc_inference uses ApproxBayes.jl, which performs the parameter inference with Approximate Bayesian Computation (ABC). prob can be any DEProblem, solved with the corresponding alg choice. t is the array of time points and data[:,i] is the set of observations for the differential equation system at time point t[i] (data may also be higher dimensional). priors is an array of prior distributions, one per parameter, specified as Distributions.jl types. num_samples is the number of posterior samples. ϵ is the target distance between the observed data and the simulated data. distancefunction is a distance metric from the Distances.jl package; the default is euclidean. ABCalgorithm is the ABC algorithm to use; the options are ABCSMC and ABCRejection from ApproxBayes.jl, and the default, ABCSMC, is the more efficient of the two. maxiterations is the maximum number of iterations before the algorithm terminates. The extra kwargs are given to the internal differential equation solver.
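
To make the roles of ϵ and distancefunction concrete, here is a minimal ABC rejection sketch in plain Julia. It is not ApproxBayes.jl's implementation: it assumes a hypothetical one-parameter decay model u(t) = exp(-a t) with a closed-form solution in place of an ODE solve, a Uniform(0, 3) prior, and a hand-written Euclidean distance, and all names are hypothetical. Prior draws whose simulated data land within ϵ of the observations are kept as approximate posterior samples.

```julia
using Random

# Hypothetical model with a closed-form solution: u(t) = exp(-a * t).
simulate(a, t) = exp.(-a .* t)

# Euclidean distance between observed and simulated data.
euclid(x, y) = sqrt(sum(abs2, x .- y))

function abc_rejection(data, t; ϵ = 0.1, num_samples = 200,
                       maxiterations = 10^6, rng = MersenneTwister(1))
    accepted = Float64[]
    for _ in 1:maxiterations
        a = 3 * rand(rng)                       # draw from the prior Uniform(0, 3)
        if euclid(data, simulate(a, t)) <= ϵ   # keep draws that reproduce the data
            push!(accepted, a)
            length(accepted) == num_samples && break
        end
    end
    return accepted
end

t = collect(0.5:0.5:5.0)
rng = MersenneTwister(42)
data = simulate(0.8, t) .+ 0.01 .* randn(rng, length(t))   # noisy synthetic data

samples = abc_rejection(data, t)
println(length(samples), " accepted; posterior mean ≈ ", sum(samples) / length(samples))
```

Plain rejection, as here, discards most prior draws; an SMC scheme such as ABCSMC instead refines a population of particles over a sequence of shrinking tolerances, which is why it is usually more efficient.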

Example

using ParameterizedFunctions, OrdinaryDiffEq, RecursiveArrayTools, Distributions
f1 = @ode_def LotkaVolterra begin
 dx = a*x - x*y
 dy = -3*y + x*y
end a

p = [1.5]
u0 = [1.0,1.0]
tspan = (0.0,10.0)
prob1 = ODEProblem(f1,u0,tspan,p)

σ = 0.01                         # noise, fixed for now
t = collect(1.:10.)   # observation times
sol = solve(prob1,Tsit5())
priors = [Normal(1.5, 1)]
randomized = VectorOfArray([(sol(t[i]) + σ * randn(2)) for i in 1:length(t)])
data = convert(Array,randomized)

using CmdStan #required for using the Stan backend
bayesian_result_stan = stan_inference(prob1,t,data,priors)

bayesian_result_turing = turing_inference(prob1,Tsit5(),t,data,priors)

using DynamicHMC #required for DynamicHMC backend
bayesian_result_hmc = dynamichmc_inference(prob1, Tsit5(), t, data, priors)

bayesian_result_abc = abc_inference(prob1, Tsit5(), t, data, priors)

Using save_idxs to declare observables

You don't always have data for all of the variables of the model. When some variables are latent (unobserved), you can use the save_idxs kwarg to declare which variables are observed and run the inference with any of the backends, as shown below.

sol = solve(prob1,Tsit5(),save_idxs=[1])
randomized = VectorOfArray([(sol(t[i]) + σ * randn(1)) for i in 1:length(t)])
data = convert(Array,randomized)

using CmdStan #required for using the Stan backend
bayesian_result_stan = stan_inference(prob1,t,data,priors,save_idxs=[1])

bayesian_result_turing = turing_inference(prob1,Tsit5(),t,data,priors,save_idxs=[1])

using DynamicHMC #required for DynamicHMC backend
bayesian_result_hmc = dynamichmc_inference(prob1,Tsit5(),t,data,priors,save_idxs = [1])

bayesian_result_abc = abc_inference(prob1,Tsit5(),t,data,priors,save_idxs=[1])
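
With only some variables observed, the data matrix keeps one row per observed variable. The plain-Julia sketch below, with made-up numbers standing in for a full two-variable trajectory, shows why an index vector such as [1] keeps the result two-dimensional (a 1×N matrix, matching what the 1-element observation vectors above stack into), whereas the scalar index 1 drops it to a plain vector. The names full and observed are hypothetical.

```julia
# Made-up full trajectory: 2 variables (rows) × 4 time points (columns).
full = [1.0 1.2 0.9 1.1;
        0.5 0.4 0.7 0.6]

observed = [1]                 # indices of the observed variables, like save_idxs=[1]
data = full[observed, :]       # vector index keeps a 1×4 matrix
println(size(data))            # (1, 4)

flat = full[1, :]              # scalar index returns a plain vector instead
println(size(flat))            # (4,)
```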

Contributors

abhigupta768, asinghvi17, astupidbear, ayush-iitkgp, chrisrackauckas, devmotion, github-actions[bot], juliatagbot, mauro3, mohamed82008, sebastianm-c, vaibhavdixit02, xukai92
