deep-learning-for-causal-inference

Deep Learning Models for Causal Inference (under selection on observables)

UPDATE 06/28/2023: I've updated the paper with a new, significantly streamlined draft of the accompanying review. It focuses on core ideas rather than individual papers, since the literature is evolving so rapidly.

UPDATE 02/23/2023: To keep these tutorials relevant, I added a new one that implements Dragonnet in PyTorch, computes TMLE-style confidence intervals, and interprets feature importance using Integrated Gradients and SHAP.

UPDATE 03/07/2022: I spruced up the tutorials by pinning them to TF 2.8, updating notation, fixing some bugs, and redoing the figures. I also added an extended version of the first tutorial for those with no DL experience!

UPDATE 10/13/2021: The most recent draft of the accompanying review, "Deep Learning of Potential Outcomes", is on arXiv. Check it out!

While there is a lot of interest in using causal inference to improve deep learning, there aren't many examples of how deep learning can be used to estimate causal effects. This repository contains extensive tutorials for building deep learning models to do causal estimation under selection on observables.

I tried to write the tutorials at a very high level so that anybody with a basic understanding of causal inference and machine learning could find them useful. The tutorials assume very little prior knowledge about deep learning and TensorFlow. In addition to featuring relevant models, I hope they serve as a gentle introduction to building, tuning, and evaluating your own complex models in TensorFlow 2.

These are a work in progress. If you have any questions or feedback on how I can improve them, please let me know. The tutorials accompany a review we are currently writing on this literature. Lastly, if you enjoyed these tutorials, feel free to star the repository. Thanks!

1. Introduction to Deep Learning for Causal Inference on Observables. [SHORT]

(For those already familiar with TensorFlow)

This tutorial introduces the idea of representation learning for causal inference. You also build and test a simple conditional average treatment effect (CATE) estimator, TARNet (first introduced in Shalit et al., 2017), using the TF2 functional API.
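The core TARNet idea is a shared representation Φ(x) feeding two outcome heads, one per treatment arm. Each unit is scored only by the head matching its observed treatment. The sketch below illustrates that structure in plain NumPy rather than the tutorial's TF2 code.

It is a minimal, untrained forward pass with several simplifications:

- It uses one hidden layer for Φ and a single linear layer per head, whereas the paper uses deeper heads.
- The helper names (`init_tarnet`, `tarnet_forward`, `factual_mse`) are hypothetical.
- Gradient-based training is left to an autodiff framework.

```python
import numpy as np

rng = np.random.default_rng(0)


def relu(z):
    return np.maximum(z, 0.0)


def init_layer(n_in, n_out):
    # Hypothetical initialization: small Gaussian weights, zero bias.
    return rng.normal(0.0, 0.1, size=(n_in, n_out)), np.zeros(n_out)


def init_tarnet(n_features, n_hidden=16):
    return {
        "phi": init_layer(n_features, n_hidden),  # shared representation Phi(x)
        "h0": init_layer(n_hidden, 1),  # outcome head for control units (t = 0)
        "h1": init_layer(n_hidden, 1),  # outcome head for treated units (t = 1)
    }


def tarnet_forward(params, x):
    W, b = params["phi"]
    rep = relu(x @ W + b)  # every unit passes through the same representation
    W0, b0 = params["h0"]
    W1, b1 = params["h1"]
    y0_hat = (rep @ W0 + b0).ravel()  # predicted potential outcome under control
    y1_hat = (rep @ W1 + b1).ravel()  # predicted potential outcome under treatment
    return y0_hat, y1_hat


def factual_mse(y, t, y0_hat, y1_hat):
    # Each unit contributes loss only through the head for its observed treatment.
    y_pred = np.where(t == 1, y1_hat, y0_hat)
    return np.mean((y - y_pred) ** 2)


# Usage on toy inputs: the CATE estimate is the difference between the two heads.
x = rng.normal(size=(8, 5))
params = init_tarnet(n_features=5)
y0_hat, y1_hat = tarnet_forward(params, x)
cate_hat = y1_hat - y0_hat
```

Because both heads share Φ, every unit's data shapes the representation, even though each head only sees loss from one arm.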

1. Introduction to Deep Learning for Causal Inference on Observables. [LONG]

(For those with no prior DL experience)

This tutorial is an "unabridged" version of the above for those who have never done any DL and may find TensorFlow overwhelming. It introduces S-learners and T-learners before TARNet as a way to get familiar with building custom TensorFlow models.
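The difference between the two meta-learners can be shown with ordinary least squares standing in for the neural networks the tutorial builds:

- An S-learner fits one model with the treatment indicator as an extra feature.
- A T-learner fits a separate model on each treatment arm.

This is a minimal sketch under that linear-model assumption. The function names are hypothetical, and the toy data are noise-free so the behavior is easy to check.

```python
import numpy as np


def fit_ols(X, y):
    # Ordinary least squares with an intercept column prepended.
    X1 = np.column_stack([np.ones(len(X)), X])
    coef, *_ = np.linalg.lstsq(X1, y, rcond=None)
    return coef


def predict_ols(coef, X):
    return np.column_stack([np.ones(len(X)), X]) @ coef


def s_learner_cate(X, t, y):
    # S-learner: a single model with the treatment indicator as one more feature.
    coef = fit_ols(np.column_stack([X, t]), y)
    mu1 = predict_ols(coef, np.column_stack([X, np.ones(len(X))]))
    mu0 = predict_ols(coef, np.column_stack([X, np.zeros(len(X))]))
    return mu1 - mu0


def t_learner_cate(X, t, y):
    # T-learner: separate models fit on the treated and control subsets.
    coef1 = fit_ols(X[t == 1], y[t == 1])
    coef0 = fit_ols(X[t == 0], y[t == 0])
    return predict_ols(coef1, X) - predict_ols(coef0, X)


# Usage on noise-free toy data whose true CATE is 2 + X[:, 0].
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
t = rng.integers(0, 2, size=200)
y = X @ np.array([1.0, 2.0, -1.0]) + t * (2.0 + X[:, 0])
tau_s = s_learner_cate(X, t, y)
tau_t = t_learner_cate(X, t, y)
```

With a linear base model, the S-learner can only return a constant effect (the coefficient on `t`). The T-learner recovers the heterogeneous effect here. Neither limitation nor advantage carries over one-to-one to flexible neural models, but it shows why the choice of structure matters.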

2. Causal Inference Metrics and Hyperparameter Optimization.

Because we do not observe counterfactual outcomes, it's not obvious how to optimize supervised learning models for causal inference. This tutorial introduces some metrics for evaluating model performance. In the first part, you learn how to assess performance on these metrics in TensorBoard. In the second part, we hack Keras Tuner to do hyperparameter optimization for TARNet, and discuss considerations for training models as estimators rather than predictors.
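Here is why counterfactuals make evaluation tricky. A loss computed on observed (factual) outcomes can be perfect while the effect estimates are badly wrong. The sketch below uses two metrics common in this literature, which may differ from the ones the tutorial uses:

- The root PEHE, defined as sqrt(mean((τ − τ̂)²)). Some papers report it without the square root.
- The absolute ATE error.

Both require the true effect τ, so they are only computable on simulated or semi-synthetic data. The function names are hypothetical.

```python
import numpy as np


def factual_mse(y, t, y0_hat, y1_hat):
    # Computable on real data: only the observed (factual) outcome is scored.
    return np.mean((y - np.where(t == 1, y1_hat, y0_hat)) ** 2)


def sqrt_pehe(tau_true, tau_hat):
    # Needs the true CATE, so it is only available on (semi-)synthetic data.
    return np.sqrt(np.mean((tau_true - tau_hat) ** 2))


def ate_abs_error(tau_true, tau_hat):
    return abs(np.mean(tau_true) - np.mean(tau_hat))


# Toy example: true potential outcomes y0 = [0, 0], y1 = [1, 1].
t = np.array([1, 0])
y = np.array([1.0, 0.0])        # observed outcomes
tau_true = np.array([1.0, 1.0])
# A model that fits the factual outcomes exactly but gets counterfactuals wrong.
y0_hat = np.array([5.0, 0.0])
y1_hat = np.array([1.0, -4.0])
tau_hat = y1_hat - y0_hat

fit = factual_mse(y, t, y0_hat, y1_hat)
pehe = sqrt_pehe(tau_true, tau_hat)
ate_err = ate_abs_error(tau_true, tau_hat)
```

In this example the factual loss is zero while the root PEHE and the ATE error are both 5.0. Selecting models on predictive loss alone can therefore favor poor estimators.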

3. Semi-parametric extensions to TARNet

This tutorial highlights some semi-parametric extensions to TARNet featured in Shi et al., 2020. We add treatment modeling (a propensity score head) to our TARNet model and build an augmented inverse propensity weighted (AIPW) estimator. We then briefly describe the Targeted Maximum Likelihood Estimation (TMLE) algorithm as background for building a TARNet with Shi et al.'s Targeted Regularization.
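Both estimators can be written directly in terms of outcome predictions and a propensity score, where `q0`, `q1` are the predicted outcomes under control and treatment, and `g` is the estimated P(T = 1 | X). These are the quantities a TARNet with an added treatment head would output.

This minimal NumPy sketch shows:

- The AIPW point estimate.
- A one-step TMLE-style targeting update with a linear fluctuation and squared-error loss.

Shi et al.'s targeted regularization instead learns the fluctuation parameter ε jointly with the network. Solving for ε in closed form after the fact, as done here, is a simplifying assumption. The function names are hypothetical, and `g` is assumed to lie strictly inside (0, 1).

```python
import numpy as np


def aipw_ate(y, t, q0, q1, g):
    # Outcome-model difference plus an inverse-propensity-weighted residual correction.
    psi_i = q1 - q0 + t * (y - q1) / g - (1 - t) * (y - q0) / (1 - g)
    return np.mean(psi_i)


def targeted_ate(y, t, q0, q1, g):
    # Clever covariate H(t, x) = t / g - (1 - t) / (1 - g), evaluated at t = 1 and t = 0.
    h1 = 1.0 / g
    h0 = -1.0 / (1.0 - g)
    h_obs = np.where(t == 1, h1, h0)
    q_obs = np.where(t == 1, q1, q0)
    # Fluctuation Q* = Q + eps * H; eps is the least-squares fit of the residuals on H.
    eps = np.sum(h_obs * (y - q_obs)) / np.sum(h_obs**2)
    q1_star = q1 + eps * h1
    q0_star = q0 + eps * h0
    return np.mean(q1_star - q0_star)


# Usage: a deliberately bad outcome model (all zeros) with a correct propensity.
t = np.array([1, 0, 1, 0])
y = np.array([3.0, 1.0, 5.0, 1.0])
zeros = np.zeros(4)
g = np.full(4, 0.5)
ate_aipw = aipw_ate(y, t, zeros, zeros, g)
ate_tmle = targeted_ate(y, t, zeros, zeros, g)
```

In this toy case both estimates equal the difference in means (3.0). The residual correction and the targeting step repair the uninformative outcome model. When the outcome model is already exact, the residuals are zero and both reduce to mean(q1 − q0).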

4. Uncertainty and Interpretation. [IN PROGRESS]

This tutorial reimplements Dragonnet in PyTorch and shows how to calculate asymptotically valid confidence intervals for the average treatment effect. We also interpret which features drive heterogeneity in the estimated CATEs using Integrated Gradients and SHAP scores. It is also a good tutorial if you just want to learn how to interpret SHAP scores, independent of causal inference.
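One standard way to get an asymptotically valid interval is to use the influence function of the AIPW estimator. The standard error is the sample standard deviation of the per-unit terms divided by √n, and a normal-approximation interval follows.

This minimal NumPy sketch has several limitations and assumptions:

- It may not match the tutorial's exact construction.
- It plugs in the true nuisance functions for simulated data; in practice, `q0`, `q1`, and `g` would come from a fitted Dragonnet, ideally with sample splitting.
- Validity relies on the usual regularity conditions (overlap, sufficiently accurate nuisance estimates).
- The function name is hypothetical.

```python
import numpy as np


def aipw_ate_with_ci(y, t, q0, q1, g, z=1.96):
    # Per-unit AIPW terms; their mean is the ATE estimate.
    psi_i = q1 - q0 + t * (y - q1) / g - (1 - t) * (y - q0) / (1 - g)
    ate = np.mean(psi_i)
    # Plug-in standard error: sample std of the per-unit terms divided by sqrt(n).
    se = np.std(psi_i, ddof=1) / np.sqrt(len(psi_i))
    return ate, (ate - z * se, ate + z * se)


# Usage on simulated confounded data with a true ATE of 1.5.
rng = np.random.default_rng(1)
n = 2000
x = rng.normal(size=n)
g = 1.0 / (1.0 + np.exp(-x))      # true propensity depends on x
t = rng.binomial(1, g)
y = x + 1.5 * t + rng.normal(size=n)
ate, (lo, hi) = aipw_ate_with_ci(y, t, q0=x, q1=x + 1.5, g=g)
```

`z=1.96` gives a nominal 95% interval. Substituting a different quantile changes the coverage level.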

5. Using Integral Probability Metrics for Causal Inference [OPTIONAL]

This tutorial covers the Counterfactual Regression Network (CFRNet) and the propensity-weighted CFRNet from Shalit et al., 2017, Johansson et al., 2018, and Johansson et al., 2020. This approach relies on Integral Probability Metrics (IPMs; e.g., the MMD and Wasserstein distance, both also used to train GANs) to bound the counterfactual prediction loss and push the treated and control representation distributions closer together. The weighted variant adds adaptive propensity-based weights that provide a consistency guarantee, relax overlap assumptions, and ideally reduce bias.
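To make the IPM penalty concrete, here are two simple MMD variants computed between treated and control representations:

- A squared MMD with a Gaussian (RBF) kernel, using the biased estimator.
- A "linear" MMD, which is the squared distance between mean representations.

This is a minimal NumPy sketch. The kernel bandwidth `sigma` and the function names are illustrative assumptions, and the Wasserstein variant is not shown.

```python
import numpy as np


def rbf_kernel(a, b, sigma=1.0):
    # Pairwise Gaussian kernel values between the rows of a and the rows of b.
    sq_dists = np.sum(a**2, axis=1)[:, None] + np.sum(b**2, axis=1)[None, :] - 2 * a @ b.T
    return np.exp(-sq_dists / (2 * sigma**2))


def mmd2_rbf(rep_treated, rep_control, sigma=1.0):
    # Biased (V-statistic) estimate of the squared MMD between the two samples.
    k_tt = rbf_kernel(rep_treated, rep_treated, sigma).mean()
    k_cc = rbf_kernel(rep_control, rep_control, sigma).mean()
    k_tc = rbf_kernel(rep_treated, rep_control, sigma).mean()
    return k_tt + k_cc - 2 * k_tc


def mmd2_linear(rep_treated, rep_control):
    # Squared Euclidean distance between the mean representations.
    diff = rep_treated.mean(axis=0) - rep_control.mean(axis=0)
    return float(diff @ diff)


# Usage: identical samples give zero; shifted samples give a positive penalty.
rng = np.random.default_rng(0)
rep_control = rng.normal(size=(50, 2))
same = mmd2_rbf(rep_control, rep_control)
shifted = mmd2_rbf(rep_control + 3.0, rep_control)
```

In a CFRNet-style objective, a term like this (scaled by a weight, often written α) would be added to the factual loss, so that training pulls the representations of the two groups together.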

Contributors

kochbj
