
data-centric-deep-learning's Introduction

Variational Item Response Theory (VIBO)

This repository contains PyTorch and Pyro code for "Variational Item Response Theory: Fast, Accurate, and Expressive" (https://arxiv.org/abs/2002.00276). The Pyro code has fewer features than the PyTorch version but can be easily extended by the motivated reader.

In this repository you can also find code for IRT with Hamiltonian Monte Carlo (HMC) and two purely deep neural baselines: MLE and Deep IRT (i.e. DKVMN-IRT). In our experiments we also compared VIBO with Maximum Marginal Likelihood approaches (e.g. Expectation-Maximization); for these, we used the MIRT package.

NOTE: we are unable to release the Gradescope dataset publicly. Please contact us if you urgently require that data. All other datasets used in the paper are supported here.

NOTE: we will be releasing a pip package supporting HMC and VIBO for IRT in Python. Please stay tuned!

Abstract

Item Response Theory (IRT) is a ubiquitous model for understanding humans based on their responses to questions, used in fields as diverse as education, medicine and psychology. Large modern datasets offer opportunities to capture more nuances in human behavior, potentially improving test scoring and better informing public policy. Yet larger datasets pose a difficult speed / accuracy challenge to contemporary algorithms for fitting IRT models. We introduce a variational Bayesian inference algorithm for IRT, and show that it is fast and scalable without sacrificing accuracy. Using this inference approach we then extend classic IRT with expressive Bayesian models of responses. Applying this method to five large-scale item response datasets from cognitive science and education yields higher log likelihoods and improvements in imputing missing data. The algorithm implementation is open-source, and easily usable.

Setup Instructions

We use Python 3 and a Conda environment. Please follow the instructions below.

conda create -n vibo python=3 anaconda
conda activate vibo
conda install pytorch torchvision -c pytorch
pip install pyro-ppl
pip install tqdm nltk dotmap scikit-learn

The config.py file contains several useful global variables, including the data paths. Edit those paths to suit your own setup.

Downloading Data

I have included the real-world data (with the exception of Gradescope) in a public Google Drive folder: https://drive.google.com/drive/folders/1ja9P5yzeUDyzzm748p5JObAEs_Evysgc?usp=sharing. Unzip the folders into the DATA_DIR specified in config.py.

How to Use

We will walk through a few commands for data processing and model training. First, this repository is set up as a package, so in every fresh terminal we need to run

source init_env.sh

in order to add the correct paths.

Create Simulation Data

To generate simulated data from an IRT model:

python src/simulate.py --irt-model 2pl --num-person 10000 --num-item 100 --ability-dim 1 

The generated data will be saved in a new folder inside DATA_DIR. We recommend using 1pl or 2pl unless the dataset is very large.
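To make the simulation step concrete, here is a minimal standard-library sketch of drawing data from a 2PL model. The parameterization P(r_ij = 1) = sigmoid(a_j * (theta_i - b_j)) and the priors (theta and b standard normal, log a normal with std 0.5) are assumptions chosen for illustration. src/simulate.py may parameterize the model and its priors differently, and `simulate_2pl` is a hypothetical name.

```python
import math
import random

def sigmoid(x):
    # Stable logistic function: avoids overflow in math.exp for large |x|.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def simulate_2pl(num_person, num_item, seed=0):
    """Draw abilities, item parameters, and 0/1 responses from a 2PL model.

    Assumed model: P(r_ij = 1) = sigmoid(a_j * (theta_i - b_j)),
    theta_i ~ N(0, 1), b_j ~ N(0, 1), log a_j ~ N(0, 0.5^2).
    """
    rng = random.Random(seed)
    ability = [rng.gauss(0.0, 1.0) for _ in range(num_person)]
    difficulty = [rng.gauss(0.0, 1.0) for _ in range(num_item)]
    # Exponentiating keeps every discrimination strictly positive.
    discrimination = [math.exp(rng.gauss(0.0, 0.5)) for _ in range(num_item)]
    responses = [
        [int(rng.random() < sigmoid(a * (theta - b)))
         for a, b in zip(discrimination, difficulty)]
        for theta in ability
    ]
    return ability, discrimination, difficulty, responses
```

For example, `simulate_2pl(10000, 100)` gives a response matrix of the same shape as the command above. The rows are persons and the columns are items.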

Fitting MLE

To fit a Maximum Likelihood model, use the following command:

python src/torch_core/mle.py --irt-model 2pl --dataset 2pl_simulation --gpu-device 0 --cuda --num-person 10000 --num-item 100

This script has many command-line arguments, which you should inspect carefully. If you do not have a CUDA device, remove the --cuda flag.

If the dataset is simulated, the script reads the --num-person and --num-item flags to determine which sub-directory of DATA_DIR to load the data from. For real (non-simulated) datasets, those two flags have no effect.
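To illustrate what "maximum likelihood" means for a 2PL model, here is a minimal sketch. It jointly fits abilities and item parameters by full-batch gradient ascent on the Bernoulli log-likelihood, skipping entries marked `None`. This is not a reproduction of src/torch_core/mle.py, which this README describes as a neural baseline. The function names, the exp-parameterized discrimination, and the step-size scaling are assumptions. Note that joint MLE has no finite solution for a person who answers every item correctly (or incorrectly), which is one motivation for the Bayesian treatment.

```python
import math

def log_sigmoid(x):
    # Numerically stable log(1 / (1 + exp(-x))).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def log_likelihood(responses, theta, log_a, b):
    """Bernoulli log-likelihood of a person-by-item 0/1 matrix; None marks missing."""
    total = 0.0
    for i, row in enumerate(responses):
        for j, r in enumerate(row):
            if r is None:
                continue
            z = math.exp(log_a[j]) * (theta[i] - b[j])
            total += log_sigmoid(z) if r == 1 else log_sigmoid(-z)
    return total

def fit_2pl_mle(responses, lr=0.2, num_steps=100):
    """Joint MLE of abilities and item parameters by full-batch gradient ascent."""
    num_person, num_item = len(responses), len(responses[0])
    theta = [0.0] * num_person
    log_a = [0.0] * num_item  # discrimination a_j = exp(log_a_j) stays positive
    b = [0.0] * num_item
    for _ in range(num_steps):
        g_theta = [0.0] * num_person
        g_log_a = [0.0] * num_item
        g_b = [0.0] * num_item
        for i, row in enumerate(responses):
            for j, r in enumerate(row):
                if r is None:
                    continue
                a = math.exp(log_a[j])
                diff = theta[i] - b[j]
                resid = r - math.exp(log_sigmoid(a * diff))  # d loglik / dz
                g_theta[i] += resid * a           # dz/dtheta_i = a_j
                g_b[j] -= resid * a               # dz/db_j = -a_j
                g_log_a[j] += resid * a * diff    # dz/dlog_a_j = a_j * diff
        # Average per person / per item so the step does not grow with data size.
        theta = [t + lr * g / num_item for t, g in zip(theta, g_theta)]
        b = [v + lr * g / num_person for v, g in zip(b, g_b)]
        log_a = [v + lr * g / num_person for v, g in zip(log_a, g_log_a)]
    return theta, log_a, b
```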

To test missing-data imputation, add the --artificial-missing-perc flag, which artificially hides some of the entries.
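A minimal sketch of an artificial-missingness test follows. It hides a random fraction of the observed entries before training, then scores a model's predictions on exactly those entries. It assumes the flag value is a fraction in [0, 1], as the 0.2 in the Deep IRT command suggests. `hide_entries`, `imputation_accuracy`, and the 0.5 threshold are hypothetical choices, not the repository's code.

```python
import random

def hide_entries(responses, missing_perc, seed=0):
    """Copy `responses`, setting a random fraction of observed entries to None.

    Also returns the hidden (person, item) indices so predictions can be scored later.
    """
    rng = random.Random(seed)
    observed = [(i, j) for i, row in enumerate(responses)
                for j, r in enumerate(row) if r is not None]
    hidden = rng.sample(observed, int(round(missing_perc * len(observed))))
    masked = [list(row) for row in responses]  # leave the original untouched
    for i, j in hidden:
        masked[i][j] = None
    return masked, hidden

def imputation_accuracy(responses, hidden, predict_prob):
    """Fraction of hidden entries whose thresholded prediction matches the truth.

    `predict_prob(i, j)` is any model's estimate of P(r_ij = 1).
    """
    if not hidden:
        raise ValueError("no hidden entries to score")
    correct = sum(int(predict_prob(i, j) >= 0.5) == responses[i][j] for i, j in hidden)
    return correct / len(hidden)
```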

Fitting VIBO

The commands for VIBO are quite similar to those for MLE. To run the un-amortized version, do:

python src/torch_core/vi.py --irt-model 2pl --dataset 2pl_simulation --gpu-device 0 --cuda --num-person 10000 --num-item 100

To run the amortized version, do:

python src/torch_core/vibo.py --irt-model 2pl --dataset 2pl_simulation --gpu-device 0 --cuda --num-person 10000 --num-item 100
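Both commands maximize an evidence lower bound (ELBO). As the names suggest, the un-amortized version learns each person's variational parameters directly, while the amortized version predicts them from that person's responses with an inference network. Below is a minimal sketch of a single-person Monte Carlo ELBO with a reparameterized Gaussian q(theta) = N(mu, sigma^2). For brevity it assumes a standard normal prior, the 2PL form sigmoid(a_j * (theta - b_j)), and fixed item parameters, whereas VIBO also infers the item parameters. The names are hypothetical. In PyTorch the same expression would be differentiated with autograd.

```python
import math
import random

def log_sigmoid(x):
    # Numerically stable log(1 / (1 + exp(-x))).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def log_normal(x, mean, std):
    return -0.5 * math.log(2.0 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2

def elbo_estimate(responses_i, a, b, mu, log_sigma, num_samples=1, seed=0):
    """Monte Carlo ELBO for one person's ability under q(theta) = N(mu, sigma^2).

    Prior p(theta) = N(0, 1); item parameters a, b are held fixed in this sketch.
    """
    rng = random.Random(seed)
    sigma = math.exp(log_sigma)
    total = 0.0
    for _ in range(num_samples):
        # Reparameterization: theta is a differentiable function of (mu, sigma).
        theta = mu + sigma * rng.gauss(0.0, 1.0)
        log_lik = 0.0
        for r, a_j, b_j in zip(responses_i, a, b):
            if r is None:
                continue
            z = a_j * (theta - b_j)
            log_lik += log_sigmoid(z) if r == 1 else log_sigmoid(-z)
        # ELBO sample: log p(r | theta) + log p(theta) - log q(theta).
        total += log_lik + log_normal(theta, 0.0, 1.0) - log_normal(theta, mu, sigma)
    return total / num_samples
```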

Here, we highlight a few command-line flags. First, --conditional-posterior, if present, conditions the approximate posterior over ability on the sampled items. Second, --n-norm-flows adds normalizing flows to the approximate posterior, so the resulting distribution need no longer be Gaussian but remains reparameterizable (this is not mentioned in the main text but may be useful).
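The normalizing-flow idea can be sketched in one dimension. A reparameterized Gaussian sample is pushed through invertible maps, and each map subtracts log|f'(z)| from the log density, so log q stays computable for the ELBO. This sketch assumes planar flows f(z) = z + u * tanh(w * z + b), which are one common flow family. The repository may use a different family, and `planar_flow_1d` and `flow_log_density` are hypothetical names.

```python
import math

def planar_flow_1d(z, u, w, b):
    """Apply f(z) = z + u * tanh(w * z + b) to a scalar; return (f(z), log|f'(z)|).

    f'(z) = 1 + u * w * (1 - tanh^2(w * z + b)); u * w > -1 keeps f strictly increasing.
    """
    h = math.tanh(w * z + b)
    z_new = z + u * h
    log_det = math.log(abs(1.0 + u * w * (1.0 - h * h)))
    return z_new, log_det

def flow_log_density(eps, mu, sigma, flows):
    """Reparameterize z0 = mu + sigma * eps, apply each (u, w, b) flow, track log q."""
    z = mu + sigma * eps
    # log N(z0; mu, sigma^2) written in terms of eps = (z0 - mu) / sigma.
    log_q = -0.5 * math.log(2.0 * math.pi) - math.log(sigma) - 0.5 * eps * eps
    for u, w, b in flows:
        z, log_det = planar_flow_1d(z, u, w, b)
        log_q -= log_det  # change of variables: q_K(z_K) = q_0(z_0) / prod |f_k'|
    return z, log_q
```

Because every step is a deterministic function of `eps` and the flow parameters, gradients still pass through the sample, which is what "still reparameterizable" refers to.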

Several scripts are included to evaluate models. To get inferred latent variables, use src/torch_core/infer.py; to compute log marginal likelihoods, use src/torch_core/marginal.py; to analyze posterior predictives, use src/torch_core/predictives.py.
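For the log marginal likelihood, a standard approach is importance sampling with the approximate posterior as the proposal: log p(r) is approximated by logsumexp_k [log p(r | theta_k) + log p(theta_k) - log q(theta_k)] - log K. Here is a minimal single-person sketch of that estimator. It assumes a standard normal prior, the 2PL form used above, fixed item parameters, and a Gaussian proposal. src/torch_core/marginal.py may differ in its details, and `log_marginal_is` is a hypothetical name.

```python
import math
import random

def log_sigmoid(x):
    # Numerically stable log(1 / (1 + exp(-x))).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def log_normal(x, mean, std):
    return -0.5 * math.log(2.0 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2

def log_marginal_is(responses_i, a, b, mu, sigma, num_samples=1000, seed=0):
    """Importance-sampling estimate of log p(r_i) with proposal q = N(mu, sigma^2)."""
    rng = random.Random(seed)
    log_w = []
    for _ in range(num_samples):
        theta = rng.gauss(mu, sigma)
        log_lik = 0.0
        for r, a_j, b_j in zip(responses_i, a, b):
            if r is None:
                continue
            z = a_j * (theta - b_j)
            log_lik += log_sigmoid(z) if r == 1 else log_sigmoid(-z)
        log_w.append(log_lik + log_normal(theta, 0.0, 1.0) - log_normal(theta, mu, sigma))
    # Log-mean-exp of the importance weights, shifted by the max for stability.
    m = max(log_w)
    return m + math.log(sum(math.exp(w - m) for w in log_w)) - math.log(num_samples)
```

With one item at difficulty 0, symmetry gives p(r = 1) = 0.5 exactly, which is a convenient sanity check for the estimator.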

Fitting VIBO in Pyro

If you are more comfortable using a probabilistic programming language, we also include an implementation of VIBO in Pyro (which we have confirmed returns results similar to the PyTorch implementation). Run:

python src/pyro_core/vibo.py --irt-model 2pl --dataset 2pl_simulation --gpu-device 0 --cuda --num-person 10000 --num-item 100

Fitting MCMC in Pyro

Pyro additionally makes it very easy to do inference with MCMC or HMC. We thus leverage it to compare VIBO to traditional methods of approximate Bayesian inference:

python src/pyro_core/hmc.py --irt-model 2pl --dataset 2pl_simulation --cuda --num-person 10000 --num-item 100

We emphasize that --num-samples and --num-warmup are important for obtaining good posterior samples. If you have several GPUs/CPUs, consider setting --num-chains greater than one.
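To show why warmup and sample counts matter, here is a minimal HMC sketch for a single scalar parameter. It uses leapfrog integration with a Metropolis correction, and warmup draws are discarded. It uses a fixed step size and unit mass; Pyro's samplers can adapt the step size during warmup, and this sketch does not. The target helper assumes a 1PL model, P(r_j = 1) = sigmoid(theta - b_j), with a standard normal prior on ability. All names are hypothetical, and running several chains would correspond to calling `hmc` with different seeds.

```python
import math
import random

def log_sigmoid(x):
    # Numerically stable log(1 / (1 + exp(-x))).
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def hmc(log_prob, grad_log_prob, init, step_size=0.2, num_leapfrog=10,
        num_samples=1000, num_warmup=200, seed=0):
    """Basic HMC for one scalar parameter: fixed step size, unit mass, no adaptation."""
    rng = random.Random(seed)
    x = init
    samples = []
    for it in range(num_warmup + num_samples):
        p0 = rng.gauss(0.0, 1.0)  # fresh momentum each iteration
        x_new = x
        # Leapfrog: half momentum step, alternating full steps, final half step.
        p = p0 + 0.5 * step_size * grad_log_prob(x_new)
        for step in range(num_leapfrog):
            x_new += step_size * p
            if step < num_leapfrog - 1:
                p += step_size * grad_log_prob(x_new)
        p += 0.5 * step_size * grad_log_prob(x_new)
        # Metropolis correction with Hamiltonian H(x, p) = -log_prob(x) + p^2 / 2.
        log_accept = (log_prob(x_new) - 0.5 * p * p) - (log_prob(x) - 0.5 * p0 * p0)
        if rng.random() < math.exp(min(0.0, log_accept)):
            x = x_new
        if it >= num_warmup:  # discard warmup draws
            samples.append(x)
    return samples

def one_pl_ability_target(responses_i, b):
    """Unnormalized log posterior and its gradient for one ability under 1PL."""
    def log_prob(theta):
        total = -0.5 * theta * theta  # N(0, 1) prior
        for r, b_j in zip(responses_i, b):
            z = theta - b_j
            total += log_sigmoid(z) if r == 1 else log_sigmoid(-z)
        return total

    def grad(theta):
        # d/dtheta of the log-likelihood term is r_j - sigmoid(theta - b_j).
        return -theta + sum(r - math.exp(log_sigmoid(theta - b_j))
                            for r, b_j in zip(responses_i, b))
    return log_prob, grad
```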

Fitting Deep IRT

Deep IRT is not a true inference model; it can only make predictions. Thus --artificial-missing-perc should be greater than 0. Use the command:

python src/dkvmn_irt/train.py --artificial-missing-perc 0.2 --num-person 10000 --num-item 100 --gpu-device 0 --cuda

Tuning Parameters

In some settings, especially with small datasets, the KL regularization term in VIBO may regularize too strongly. In practice, we find that adding a weight (less than 1) on the KL regularization terms helps circumvent this problem.
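Here is a minimal sketch of a KL-weighted objective. It assumes Gaussian approximate posteriors and standard normal priors, so each KL term has the closed form KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * (mu^2 + sigma^2 - 1) - log(sigma). The names and the `kl_weight` argument are hypothetical and do not describe the repository's actual flag. A weight of 1 recovers the usual negative ELBO, and smaller weights loosen the regularization.

```python
import math

def gaussian_kl_to_std_normal(mu, sigma):
    """Closed-form KL( N(mu, sigma^2) || N(0, 1) )."""
    return 0.5 * (mu * mu + sigma * sigma - 1.0) - math.log(sigma)

def weighted_negative_elbo(log_lik, mus, sigmas, kl_weight=1.0):
    """Loss to minimize: -(log_lik - kl_weight * sum of per-variable KL terms)."""
    kl = sum(gaussian_kl_to_std_normal(m, s) for m, s in zip(mus, sigmas))
    return -(log_lik - kl_weight * kl)
```

For example, `weighted_negative_elbo(-10.0, [1.0], [1.0], kl_weight=0.5)` returns 10.25. The single KL term is 0.5, and it is halved before being added to the negative log-likelihood.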

data-centric-deep-learning's People

Contributors

blowoffvalve, kevinsbarnard, mhw32


data-centric-deep-learning's Issues

Requirements.txt not being installed properly in gitpod

Below is the log of pip install -r requirements.txt, with the relevant error:

Collecting scikit-learn==1.0.2 (from -r requirements.txt (line 17))
  Using cached scikit-learn-1.0.2.tar.gz (6.7 MB)
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'error'
  error: subprocess-exited-with-error
  
  × Preparing metadata (pyproject.toml) did not run successfully.
  │ exit code: 1
  ╰─> [261 lines of output]
      Partial import of sklearn during the build process.
      setup.py:128: DeprecationWarning:
      
        `numpy.distutils` is deprecated since NumPy 1.23.0, as a result
        of the deprecation of `distutils` itself. It will be removed for
        Python >= 3.12. For older Python versions it will remain present.
        It is recommended to use `setuptools < 60.0` for those Python versions.
        For more details, see:
          https://numpy.org/devdocs/reference/distutils_status_migration.html
      
      
        from numpy.distutils.command.build_ext import build_ext  # noqa
      INFO: C compiler: gcc -DNDEBUG -g -fwrapv -O3 -Wall -march=x86-64 -mtune=generic -O3 -pipe -fno-plt -fexceptions -Wp,-D_FORTIFY_SOURCE=2 -Wformat -Werror=format-security -fstack-clash-protection -fcf-protection -g -ffile-prefix-map=/build/python/src=/usr/src/debug/python -flto=auto -ffat-lto-objects -march=x86-64 -mtune=generic -O3 -pipe -fno-plt -fexceptions -Wp,-D_FORTIFY_SOURCE=2 -Wformat -Werror=format-security -fstack-clash-protection -fcf-protection -g -ffile-prefix-map=/build/python/src=/usr/src/debug/python -flto=auto -march=x86-64 -mtune=generic -O3 -pipe -fno-plt -fexceptions -Wp,-D_FORTIFY_SOURCE=2 -Wformat -Werror=format-security -fstack-clash-protection -fcf-protection -g -ffile-prefix-map=/build/python/src=/usr/src/debug/python -flto=auto -fPIC
      
      INFO: compile options: '-c'
      
      .
      .
      .
        note: This error originates from a subprocess, and is likely not a problem with pip.
error: metadata-generation-failed

× Encountered error while generating package metadata.
╰─> See above for output.

note: This is an issue with the package mentioned above, not pip.
hint: See above for details.

Going back to Python 3.10.11 (the environment currently uses 3.11.4) seems likely to solve this.

drf_yasg error when starting label_studio, for all versions > 1.4.1 and < 1.6.0

Running label-studio with the versions specified in the Dockerfile and requirements.txt causes the error below:
=> Database and media directory: /home/gitpod/.local/share/label-studio
=> Static URL is set to: /static/
Traceback (most recent call last):
  File "/home/gitpod/.pyenv/versions/3.8.16/bin/label-studio", line 8, in <module>
    sys.exit(main())
  File "/home/gitpod/.pyenv/versions/3.8.16/lib/python3.8/site-packages/label_studio/server.py", line 282, in main
    _setup_env()
  File "/home/gitpod/.pyenv/versions/3.8.16/lib/python3.8/site-packages/label_studio/server.py", line 40, in _setup_env
    application = get_wsgi_application()
  File "/home/gitpod/.pyenv/versions/3.8.16/lib/python3.8/site-packages/django/core/wsgi.py", line 12, in get_wsgi_application
    django.setup(set_prefix=False)
  File "/home/gitpod/.pyenv/versions/3.8.16/lib/python3.8/site-packages/django/__init__.py", line 19, in setup
    configure_logging(settings.LOGGING_CONFIG, settings.LOGGING)
  File "/home/gitpod/.pyenv/versions/3.8.16/lib/python3.8/site-packages/django/conf/__init__.py", line 82, in __getattr__
    self._setup(name)
  File "/home/gitpod/.pyenv/versions/3.8.16/lib/python3.8/site-packages/django/conf/__init__.py", line 69, in _setup
    self._wrapped = Settings(settings_module)
  File "/home/gitpod/.pyenv/versions/3.8.16/lib/python3.8/site-packages/django/conf/__init__.py", line 170, in __init__
    mod = importlib.import_module(self.SETTINGS_MODULE)
  File "/home/gitpod/.pyenv/versions/3.8.16/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1014, in _gcd_import
  File "<frozen importlib._bootstrap>", line 991, in _find_and_load
  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 843, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/home/gitpod/.pyenv/versions/3.8.16/lib/python3.8/site-packages/label_studio/core/settings/label_studio.py", line 43, in <module>
    from label_studio.core.utils.common import collect_versions
  File "/home/gitpod/.pyenv/versions/3.8.16/lib/python3.8/site-packages/label_studio/core/utils/common.py", line 36, in <module>
Starting new HTTPS connection (1): o227124.ingest.sentry.io:443
    from drf_yasg.inspectors import CoreAPICompatInspector, NotHandled
  File "/home/gitpod/.pyenv/versions/3.8.16/lib/python3.8/site-packages/drf_yasg/inspectors/__init__.py", line 5, in <module>
    from .field import (
  File "/home/gitpod/.pyenv/versions/3.8.16/lib/python3.8/site-packages/drf_yasg/inspectors/field.py", line 406, in <module>
    (serializers.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
AttributeError: module 'rest_framework.serializers' has no attribute 'NullBooleanField'
Sentry is attempting to send 2 pending error messages
Waiting up to 2 seconds
Press Ctrl-C to quit
https://o227124.ingest.sentry.io:443 "POST /api/5820521/store/ HTTP/1.1" 200 41

libgl1 is not installed, but clean.py requires it

When running clean.py for the integration tests the following error is produced:
Traceback (most recent call last):
  File "src/tests/clean.py", line 10, in <module>
    import cv2
  File "/home/gitpod/.pyenv/versions/3.8.13/lib/python3.8/site-packages/cv2/__init__.py", line 8, in <module>
    from .cv2 import *
ImportError: libGL.so.1: cannot open shared object file: No such file or directory

This can be resolved by installing the libgl1 package manually:
sudo apt-get update && sudo apt-get install libgl1

Or, better still, by including this in the Dockerfile:
RUN sudo apt-get update && sudo apt-get install -y libgl1
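To confirm the fix before running the integration tests, a small check can try to load the shared library that `import cv2` failed on. This is a minimal sketch: `libgl_available` is a hypothetical helper, and it only checks for the exact soname "libGL.so.1" named in the ImportError above, so it is meaningful on Linux only.

```python
import ctypes

def libgl_available():
    """Return True if the shared library libGL.so.1 can be loaded by the dynamic linker."""
    try:
        ctypes.CDLL("libGL.so.1")
    except OSError:
        # Same root cause as the ImportError raised by `import cv2`.
        return False
    return True
```

Calling it inside the Gitpod workspace should return True once libgl1 is installed.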
