
stan4bart's Introduction

This package is an implementation of a C++ sampler that uses BART for non-parametric mean components and Stan for multilevel/parametric ones.

Installation

  1. Install the developer tools for your platform (Mac OS, Windows). Mac OS users will also need the gfortran build that matches their processor:
    • If on an ARM processor, install gfortran from here
    • If on an Intel processor, install gfortran from here
  2. Execute:
if (length(find.package("remotes", quiet = TRUE)) == 0L)
  install.packages("remotes")
remotes::install_github("vdorie/dbarts")
remotes::install_github("vdorie/stan4bart")

Use

The package uses the flexible, expressive lme4 syntax for specifying group-level structures. See the package documentation ?stan4bart and ?"stan4bart-generics" for more information.

Formulas

  • y ~ bart(x_nuisance_1 + x_nuisance_2) + x_inference - wrap any variables that you want fit non-parametrically in a call to bart(); the +s within bart() are interpreted figuratively, indicating only the names of the variables to include
  • y ~ bart(x_fixef) + (1 + x_ranef | g) - supports all lme4 varying intercept and slope formula constructs
  • y ~ x_1 + (1 + x_2 | g) - if there's no BART component, use rstanarm or lme4 instead
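The split between the bart() part and the parametric part of a formula can be inspected with base R alone. The sketch below is only an illustration of the syntax: bart_vars is a hypothetical helper written for this example, not a stan4bart function, and it does not reproduce how the package itself parses formulas (for instance, it does not expand a `.` inside bart()).

```r
# Hypothetical helper (not part of stan4bart): collect the names of the
# variables that appear inside bart(...) calls in a formula.
bart_vars <- function(expr) {
  if (!is.call(expr)) return(character(0))
  if (identical(expr[[1L]], as.name("bart")))
    return(all.vars(expr))
  # Recurse into the arguments of any other call (+, (, |, ~, ...)
  as.character(unique(unlist(lapply(as.list(expr)[-1L], bart_vars))))
}

f <- y ~ bart(x_nuisance_1 + x_nuisance_2) + x_inference + (1 + x_inference | g)

# Variables fit non-parametrically
nonparametric <- bart_vars(f)

# Remaining right-hand-side variables: parametric terms and grouping factors
parametric <- setdiff(all.vars(f[[3L]]), nonparametric)

print(nonparametric)
print(parametric)
```

Here nonparametric holds the two nuisance variables, while parametric holds x_inference and the grouping factor g.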

Main Event

The main function is stan4bart.

Results

Results are retrieved using the extract, fitted, and predict generics. See ?"stan4bart-generics" for more information.

Known Issues

The name and definition of extract conflict with the function of the same name in rstan. The rstan package is not needed to use stan4bart and does not need to be loaded. If a name collision occurs, the stan4bart extract can be referenced as in:

stan4bart:::extract.stan4bartFit(stan4bart_fit)

Example Code

library(stan4bart)

# Load a test-data function
source(system.file("common", "friedmanData.R", package = "stan4bart"), local = TRUE)

# Relatively low n for illustrative purposes
testData <- generateFriedmanData(n = 100, ranef = TRUE, causal = TRUE, binary = FALSE)

# First level model is:
#   y ~ f(x_1, x_2) + a * x_3^2 + b * x_4 + c * x_5 + z
# Random intercepts are added for g.1 and g.2, and a random slope is placed on x_4
# x_6 through x_10 are pure noise
df <- with(testData, data.frame(x, g.1, g.2, y, z))

# Causal inference example
fit <- stan4bart(y ~ bart(. - g.1 - g.2 - X4 - z) + X4 + z + (1 + X4 | g.1) + (1 | g.2), df,
                 treatment = z,
                 cores = 1, seed = 0,
                 verbose = 1)

samples.mu.train <- extract(fit)
samples.mu.test  <- extract(fit, sample = "test")

# Individual conditional treatment effects
samples.icate <- (samples.mu.train - samples.mu.test) * (2 * testData$z - 1)
# Conditional average treatment effect
samples.cate <- apply(samples.icate, 2, mean)
cate <- mean(samples.cate)
cate.int <- c(cate - 1.96 * sd(samples.cate), cate + 1.96 * sd(samples.cate))

# Samples of the posterior predictive distribution are used in calculating
# the counterfactuals for SATE and for calculating the response under
# the observed treatment condition when estimating PATE.
samples.ppd.test <- extract(fit, type = "ppd", sample = "test")

# Individual sample treatment effects
samples.ite <- (testData$y - samples.ppd.test) * (2 * testData$z - 1)
# Sample average treatment effect
samples.sate <- apply(samples.ite, 2, mean)
sate <- mean(samples.sate)
sate.int <- c(sate - 1.96 * sd(samples.sate), sate + 1.96 * sd(samples.sate))

# Population average treatment effect
samples.ppd.train <- extract(fit, type = "ppd", sample = "train")
samples.pate <- apply((samples.ppd.train - samples.ppd.test) * (2 * testData$z - 1), 2, mean)
pate <- mean(samples.pate)
pate.int <- c(pate - 1.96 * sd(samples.pate), pate + 1.96 * sd(samples.pate))


fitted.mu.train <- fitted(fit)
# equal to: apply(samples.mu.train, 1, mean)
fitted.mu.test  <- fitted(fit, sample = "test")
# equal to: apply(samples.mu.test,  1, mean)

# Observed and counterfactual MSE
mse.train <- with(testData, mean((fitted.mu.train - mu.1 * z - mu.0 * (1 - z))^2))
mse.test  <- with(testData, mean((fitted.mu.test  - mu.1 * (1 - z) - mu.0 * z)^2))
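
The (2 * z - 1) factor in the effect calculations above maps z = 1 to +1 and z = 0 to -1, so each observed-minus-counterfactual difference ends up oriented as treated minus control. The sketch below works through this on made-up numbers, assuming matrices with one row per observation and one column per posterior sample, as the apply(..., 2, mean) calls above imply. The objects z, mu_obs, and mu_cf are hypothetical stand-ins, not output from stan4bart. It also computes a percentile interval from the draws as one alternative to the normal approximation.

```r
# Toy stand-ins (hypothetical values): 4 observations x 3 posterior samples
z <- c(1, 0, 1, 0)

# Expected response under the observed treatment, one column per sample
mu_obs <- matrix(c(5, 2, 6, 1,
                   6, 2, 7, 1,
                   4, 2, 5, 1), nrow = 4)

# Expected response under the flipped (counterfactual) treatment
mu_cf <- matrix(rep(c(3, 4, 3, 4), times = 3), nrow = 4)

# z recycles down each column, so row i is multiplied by (2 * z[i] - 1):
# treated rows keep their sign, control rows are flipped
icate <- (mu_obs - mu_cf) * (2 * z - 1)

# One average effect per posterior sample (column means)
cate_draws <- apply(icate, 2, mean)
cate <- mean(cate_draws)

# Normal-approximation interval, as in the example above
cate_int_normal <- c(cate - 1.96 * sd(cate_draws), cate + 1.96 * sd(cate_draws))

# Percentile interval taken directly from the draws
cate_int_pct <- quantile(cate_draws, probs = c(0.025, 0.975))

print(icate)
print(cate_draws)
```

With only three draws the two intervals are not meaningful in themselves; with a realistic number of posterior samples, the percentile interval avoids assuming that the draws are symmetric.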

Contributors

bgoodri, vdorie
