
Repository for our NeurIPS 2022 paper "Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off" and our NeurIPS 2023 paper "Learning to Receive Help: Intervention-Aware Concept Embedding Models"

Home Page: https://arxiv.org/abs/2209.09056

License: MIT License


cem's Introduction

Concept Embedding Models


This repository contains the official PyTorch implementation of our two papers:

TL;DR

Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off

CEM Architecture

We propose Concept Embedding Models (CEMs), a novel family of concept-based interpretable neural architectures that can achieve high task performance while being capable of producing concept-based explanations for their predictions. These models can be trained using a very limited number of concept annotations in the task of interest and allow effective test-time concept interventions, enabling CEMs to drastically improve their task performance in a human-in-the-loop setting.

Learning to Receive Help: Intervention-Aware Concept Embedding Models

IntCEM Architecture

In this paper, we argue that previous concept-based architectures, including CEMs to some extent, lack training incentives to improve their performance when they are intervened on, even though at test time we expect these models to correctly take in expert feedback and improve their predictions. Here, we propose Intervention-aware Concept Embedding Models (IntCEMs), an alternative training framework for CEMs that exposes them to informative trajectories of interventions during training, incurring a larger penalty when the model mispredicts its task after several interventions have been performed than when no interventions are performed. Through this process, we concurrently learn a CEM that is more receptive to test-time interventions and an intervention policy that suggests which concepts to intervene on next so as to yield the largest decrease in the model's uncertainty.

Abstracts

Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off (NeurIPS 2022)

Deploying AI-powered systems requires trustworthy models supporting effective human interactions, going beyond raw prediction accuracy. Concept bottleneck models promote trustworthiness by conditioning classification tasks on an intermediate level of human-like concepts. This enables human interventions which can correct mispredicted concepts to improve the model's performance. However, existing concept bottleneck models are unable to find optimal compromises between high task accuracy, robust concept-based explanations, and effective interventions on concepts -- particularly in real-world conditions where complete and accurate concept supervisions are scarce. To address this, we propose Concept Embedding Models, a novel family of concept bottleneck models which goes beyond the current accuracy-vs-interpretability trade-off by learning interpretable high-dimensional concept representations. Our experiments demonstrate that Concept Embedding Models (1) attain better or competitive task accuracy w.r.t. standard neural models without concepts, (2) provide concept representations capturing meaningful semantics including and beyond their ground truth labels, (3) support test-time concept interventions whose effect in test accuracy surpasses that in standard concept bottleneck models, and (4) scale to real-world conditions where complete concept supervisions are scarce.

Learning to Receive Help: Intervention-Aware Concept Embedding Models (NeurIPS 2023)

Concept Bottleneck Models (CBMs) tackle the opacity of neural architectures by constructing and explaining their predictions using a set of high-level concepts. A special property of these models is that they permit concept interventions, wherein users can correct mispredicted concepts and thus improve the model's performance. Recent work, however, has shown that intervention efficacy can be highly dependent on the order in which concepts are intervened on and on the model's architecture and training hyperparameters. We argue that this is rooted in a CBM's lack of train-time incentives for the model to be appropriately receptive to concept interventions. To address this, we propose Intervention-aware Concept Embedding models (IntCEMs), a novel CBM-based architecture and training paradigm that improves a model's receptiveness to test-time interventions. Our model learns a concept intervention policy in an end-to-end fashion from where it can sample meaningful intervention trajectories at train-time. This conditions IntCEMs to effectively select and receive concept interventions when deployed at test-time. Our experiments show that IntCEMs significantly outperform state-of-the-art concept-interpretable models when provided with test-time concept interventions, demonstrating the effectiveness of our approach.

Models

Concept Bottleneck Models (CBMs) have recently gained attention as high-performing and interpretable neural architectures that can explain their predictions using a set of human-understandable high-level concepts. Nevertheless, the need for a strict activation bottleneck as part of the architecture, together with the requirement that the concept annotations used during training be fully descriptive of the downstream task of interest, forces CBMs to trade downstream performance for interpretability. This severely limits their applicability in real-world settings, where data rarely comes with concept annotations that are fully descriptive of any task of interest.

Concept Embedding Models (CEMs) tackle these two big challenges. Our neural architecture expands a CBM's bottleneck and allows information related to unseen concepts to flow through the model's bottleneck. We achieve this by learning a high-dimensional representation (i.e., a concept embedding) for each concept provided during training. Naively extending the bottleneck, however, may directly impede the use of test-time concept interventions, where one corrects a mispredicted concept in order to improve the end model's downstream performance. This is a crucial element motivating the creation of traditional CBMs and is therefore a highly desirable feature to preserve. In order to use concept embeddings in the bottleneck while still permitting effective test-time interventions, CEMs construct each concept's representation as a linear combination of two concept embeddings, where each embedding has fixed semantics. Specifically, we learn one embedding to represent the "active" state of a concept and one to represent its "inactive" state, allowing us to select between these two embeddings at test time when intervening on a concept and thereby improve downstream performance. Our entire architecture is visualized in the figure above and formally described in our paper.
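
As a rough illustration of this mixing step, the sketch below computes a single concept's representation in PyTorch. The variable names and the single-layer generators are made up for this example and are not the package's actual API; they only show the idea of mixing an "active" and an "inactive" embedding via the predicted concept probability.

import torch
import torch.nn as nn

emb_size = 16     # size of each concept embedding
latent_dim = 128  # size of the latent code produced by the backbone

# One "active" and one "inactive" embedding generator for a single concept (sketch).
pos_embedding_gen = nn.Linear(latent_dim, emb_size)
neg_embedding_gen = nn.Linear(latent_dim, emb_size)
# Scoring function producing the predicted concept probability from both embeddings.
prob_gen = nn.Linear(2 * emb_size, 1)

h = torch.randn(32, latent_dim)            # latent code for a batch of 32 samples
c_pos = torch.relu(pos_embedding_gen(h))   # embedding for the concept being active
c_neg = torch.relu(neg_embedding_gen(h))   # embedding for the concept being inactive
p = torch.sigmoid(prob_gen(torch.cat([c_pos, c_neg], dim=-1)))  # predicted concept probability

# The concept's final representation is a p-weighted mixture of the two embeddings.
# Intervening on this concept simply amounts to overwriting p with 0 or 1.
c_mixed = p * c_pos + (1 - p) * c_neg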

Installation

You can locally install this package by first cloning this repository:

$ git clone https://github.com/mateoespinosa/cem

We provide an automatic installation mechanism using Python's setup process via our standalone setup.py. To install our package, you therefore only need to move into the cloned directory (cd cem) and run:

$ python setup.py install

After running this, you should be able to import our package locally using:

import cem

Usage

High-level Usage

In this repository, we include a standalone PyTorch implementation of Concept Embedding Models (CEMs), Intervention-aware Concept Embedding Models (IntCEMs), and several variants of Concept Bottleneck Models (CBMs). These models can be easily trained from scratch given a set of samples annotated with a downstream task and a set of binary concepts. In order to use our implementation, however, you first need to install all our code's requirements (listed in requirements.txt) by following the installation instructions above.

After installation has been completed, you should be able to import cem as a package and use it to train a CEM as follows:

import pytorch_lightning as pl
from cem.models.cem import ConceptEmbeddingModel

#####
# Define your PyTorch DataLoaders
#####

train_dl = ...
val_dl = ...

#####
# Construct the model
#####

cem_model = ConceptEmbeddingModel(
  n_concepts=n_concepts, # Number of training-time concepts
  n_tasks=n_tasks, # Number of output labels
  emb_size=16,
  concept_loss_weight=0.1,
  learning_rate=1e-3,
  optimizer="adam",
  c_extractor_arch=latent_code_generator_model, # Replace this appropriately
  training_intervention_prob=0.25, # RandInt probability
)

#####
# Train it
#####

trainer = pl.Trainer(
    accelerator="gpu",  # or "cpu" if no GPU available
    devices="auto",
    max_epochs=100,
    check_val_every_n_epoch=5,
)
# train_dl and val_dl are the DataLoaders built above...
trainer.fit(cem_model, train_dl, val_dl)
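
Once training finishes, the resulting model is a standard PyTorch Lightning module, so it can be evaluated on a held-out split using the usual Lightning workflow (assuming you have built a test_dl DataLoader in the same way as train_dl and val_dl):

test_dl = ...  # held-out DataLoader, built like train_dl and val_dl
trainer.test(cem_model, test_dl)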

Similarly, you can train an IntCEM as follows:

import pytorch_lightning as pl
from cem.models.intcbm import IntAwareConceptEmbeddingModel

#####
# Define your PyTorch DataLoaders
#####

train_dl = ...
val_dl = ...

#####
# Construct the model
#####

intcem_model = IntAwareConceptEmbeddingModel(
  n_concepts=n_concepts, # Number of training-time concepts
  n_tasks=n_tasks, # Number of output labels
  emb_size=16,
  concept_loss_weight=0.1,
  learning_rate=1e-3,
  optimizer="adam",
  c_extractor_arch=latent_code_generator_model, # Replace this appropriately
  training_intervention_prob=0.25, # RandInt probability
  intervention_task_discount=1.1, # Penalty factor "gamma" for misprediction after interventions
  intervention_weight=1, # The weight lambda_roll of the intervention loss in IntCEM's training objective
)

#####
# Train it
#####

trainer = pl.Trainer(
    accelerator="gpu",  # or "cpu" if no GPU available
    devices="auto",
    max_epochs=100,
    check_val_every_n_epoch=5,
)
# train_dl and val_dl are the DataLoaders built above...
trainer.fit(intcem_model, train_dl, val_dl)

For a full example showing how to generate a dataset and configure a CEM for training on your own custom dataset, please see our Dot example notebook for a step-by-step walkthrough. The same setup used in this notebook can be used for training IntCEMs or CBMs as defined in this library.

Included Models

Besides CEMs and IntCEMs, this repository also includes a PyTorch implementation of Concept Bottleneck Models (CBMs) and some of their variants. These models should be trainable out of the box if one follows the same steps used for training an IntCEM/CEM as indicated above.

You can import CBMs by including

from cem.models.cbm import ConceptBottleneckModel
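
A CBM can then be constructed and trained following the same pattern shown for CEMs above. The snippet below is only a minimal sketch with illustrative argument values; as noted in the next section, the CBM constructor mirrors most of the CEM arguments (with some extras and slightly different semantics), so please check cem/models/cbm.py for its exact signature.

import pytorch_lightning as pl
from cem.models.cbm import ConceptBottleneckModel

# Minimal sketch; argument values are illustrative only.
cbm_model = ConceptBottleneckModel(
    n_concepts=n_concepts,  # Number of training-time concepts
    n_tasks=n_tasks,        # Number of output labels
    concept_loss_weight=0.1,
    learning_rate=1e-3,
    optimizer="adam",
    c_extractor_arch=latent_code_generator_model,  # Replace this appropriately
)

# Trained exactly as in the CEM/IntCEM examples above:
trainer = pl.Trainer(accelerator="gpu", devices="auto", max_epochs=100)
trainer.fit(cbm_model, train_dl, val_dl)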

Class Arguments

Our CEM module takes the following initialization arguments:

  • n_concepts (int): The number of concepts given at training time.

  • n_tasks (int): The number of output classes of the CEM.

  • emb_size (int): The size of each concept embedding. Defaults to 16.

  • training_intervention_prob (float): RandInt probability. Defaults to 0.25.

  • embedding_activation (str): A valid nonlinearity name to use for the generated embeddings. It must be one of [None, "sigmoid", "relu", "leakyrelu"] and defaults to "leakyrelu".

  • shared_prob_gen (bool): Whether or not weights are shared across all probability generators. Defaults to True.

  • concept_loss_weight (float): Weight to be used for the final loss' component corresponding to the concept classification loss. Default is 0.01.

  • task_loss_weight (float): Weight to be used for the final loss' component corresponding to the output task classification loss. Default is 1.

  • c2y_model (torch.nn.Module): A valid PyTorch module used to map the CEM's bottleneck (with size n_concepts * emb_size) to n_tasks output activations (i.e., the output of the CEM). If not given, then a simple leaky-ReLU MLP, whose hidden layers have sizes c2y_layers, will be used.

  • c2y_layers (List[int]): List of integers defining the size of the hidden layers to be used in the MLP to predict classes from the bottleneck if c2y_model was NOT provided. If not given, then we will use a simple linear layer to map the bottleneck to the output classes.

  • c_extractor_arch (Callable[[int], torch.nn.Module]): A generator function for the latent code generator model. It takes the desired size of the latent code (via an argument called output_dim) and returns a valid PyTorch module that maps this CEM's inputs to a latent space of the requested size (see the sketch after this list).

  • optimizer (str): The name of the optimizer to use. Must be one of adam or sgd. Default is adam.

  • momentum (float): Momentum used for optimization. Default is 0.9.

  • learning_rate (float): Learning rate used for optimization. Default is 0.01.

  • weight_decay (float): The weight decay factor used during optimization. Default is 4e-05.

  • weight_loss (List[float]): Either None or a list with n_concepts elements indicating the weights assigned to each predicted concept during the loss computation. Could be used to improve performance/fairness in imbalanced datasets.

  • task_class_weights (List[float]): Either None or a list with n_tasks elements indicating the weights assigned to each output class during the loss computation. Could be used to improve performance/fairness in imbalanced datasets.

  • active_intervention_values (List[float]): A list of n_concepts values to use when positively intervening in a given concept (i.e., setting concept c_i to 1 would imply setting its corresponding predicted concept to active_intervention_values[i]). If not given, then we will assume that we use 1 for all concepts. This parameter is important when intervening in CEMs that do not have sigmoidal concepts, as the intervention thresholds must then be inferred from their empirical training distribution.

  • inactive_intervention_values (List[float]): A list of n_concepts values to use when negatively intervening in a given concept (i.e., setting concept c_i to 0 would imply setting its corresponding predicted concept to inactive_intervention_values[i]). If not given, then we will assume that we use 0 for all concepts.

  • intervention_policy (an instance of InterventionPolicy as defined in cem.interventions.intervention_policy): An optional intervention policy to be used when intervening on a test batch of samples x (first argument), with corresponding true concepts c (second argument) and true labels y (third argument). The policy must produce as output a list of concept indices to intervene on (in batch form) or a batch of binary masks indicating which concepts we will intervene on.

  • top_k_accuracy (List[int]): List of top k values to report accuracy for during training/testing when the number of tasks is high.
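
For reference (as mentioned in the c_extractor_arch entry above), a latent code generator function could look like the following sketch. The backbone used here is entirely illustrative; the only requirement described above is that the function accepts an output_dim argument and returns a PyTorch module mapping the model's inputs to a latent code of that size.

import torch.nn as nn

def latent_code_generator_model(output_dim):
    # Illustrative backbone only; any module producing a latent code with
    # `output_dim` units from the raw input works here.
    return nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d((4, 4)),
        nn.Flatten(),
        nn.Linear(32 * 4 * 4, output_dim),
    )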

Notice that the CBM module takes similar arguments, albeit with some extra ones and some with slightly different semantics (e.g., x2c_model maps the input directly to the bottleneck).

Similarly, the IntCEM module takes the same arguments as its CEM counterpart with the following additional arguments:

  • intervention_task_discount (float): Penalty to be applied for mispredicting the task label after some interventions have been performed vs when no interventions have been performed. This is what we call gamma in the paper. Defaults to 1.1.
  • intervention_weight (float): Weight to be used for the intervention policy loss during training. This is what we call lambda_roll in the paper. Defaults to 1.
  • concept_map (Dict[Any, List[int]]): A map between concept group names (e.g., "wing_colour") and the lists of concept indices that represent each group. If concept groups are known, and we want to learn an intervention policy that acts on groups rather than on individual concepts, then this dictionary should be provided both before training and at intervention time (see the sketch after this list). Defaults to every concept being its own group.
  • use_concept_groups (bool): Set this to true if concept_map is provided and you want interventions to be done on entire groups of concepts at a time rather than on individual concepts. Defaults to True.
  • num_rollouts (int): The number of Monte Carlo rollouts we will perform when sampling trajectories (i.e., the number of trajectories one will sample per training step). Defaults to 1.
  • max_horizon (int): The final maximum number of interventions to be made on a single training trajectory. Defaults to 6.
  • initial_horizon (int): The initial maximum number of interventions to be made on a single training trajectory. Defaults to 2.
  • horizon_rate (float): The factor by which the maximum trajectory horizon grows on every training step; the horizon starts at initial_horizon and increases by a factor of horizon_rate each step until it reaches max_horizon. Defaults to 1.005.
  • int_model_layers (List[int]): The sizes of the layers to be used for the MLP forming IntCEM's learnable policy psi.
  • int_model_use_bn (bool): Whether or not we use batch normalization between each layer of the intervention policy model psi. Defaults to True.
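
As an illustration of how concept_map and use_concept_groups fit together (the group names and concept indices below are hypothetical, and the remaining arguments are taken from the IntCEM example above), grouping concepts lets IntCEM sample and apply interventions on whole groups at a time:

from cem.models.intcbm import IntAwareConceptEmbeddingModel

# Hypothetical grouping: indices refer to positions in the binary concept vector.
concept_map = {
    "wing_colour": [0, 1, 2],
    "beak_shape": [3, 4],
    "has_crest": [5],
}

intcem_model = IntAwareConceptEmbeddingModel(
    n_concepts=6,   # must match the total number of concept indices above
    n_tasks=n_tasks,
    c_extractor_arch=latent_code_generator_model,  # as in the examples above
    concept_map=concept_map,
    use_concept_groups=True,  # intervene on entire groups at a time
)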

Experiment Reproducibility

Downloading the Datasets

In order to properly run our experiments, you will have to download the pre-processed CUB dataset (found here) to cem/data/CUB200/ and the CelebA dataset (found here) to cem/data/celeba/. You may opt to download them to different locations, but their paths will then have to be updated in the respective experiment configs (via the root_dir parameter) or via the DATASET_DIR environment variable.

Running Experiments

To reproduce the experiments discussed in our papers, please use the scripts in the experiments directory after installing the cem package as indicated above. For example, to run our experiments on the DOT dataset (see our paper), you can execute the following command:

$ python experiments/run_experiments.py -c experiments/configs/dot.yaml

Once execution has terminated, this should print a summary of all results in the form of a table and dump all results, trained models, and logs into the given output directory (dot_results/ in this case).

Similarly, you can recreate our CUB, CelebA, and MNIST-Add experiments (or those on any other synthetic dataset) by running

$ python experiments/run_experiments.py -c experiments/configs/{cub/celeba/mnist_add}.yaml

These scripts will also run the intervention experiments and generate the test accuracies for all models as one intervenes on an increasing number of concepts.

Once an experiment is over, the script will dump all testing results and statistics, summarized over 5 random initializations, into a single results.joblib dictionary which you can then analyze. At its top level, this dictionary is keyed by the seed id (e.g., '0', '1', etc.); each seed maps to a dictionary that maps run names (e.g., 'CEM') to a dictionary containing all test and validation metrics computed.
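
As a minimal sketch (assuming the dot_results/ output directory used in the DOT example above), you can load and walk this dictionary with joblib:

import joblib

results = joblib.load("dot_results/results.joblib")
for seed_id, runs in results.items():        # e.g., '0', '1', ...
    for run_name, metrics in runs.items():   # e.g., 'CEM', 'IntCEM', ...
        print(seed_id, run_name, metrics)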

Citation

If you would like to cite this repository, or the accompanying paper, we would appreciate it if you could please use the following citations to include both the CEM and IntCEM papers (although if IntCEMs are not used, only the CEM citation is necessary):

CEM Paper

@article{EspinosaZarlenga2022cem,
  title={Concept Embedding Models: Beyond the Accuracy-Explainability Trade-Off},
  author={
    Espinosa Zarlenga, Mateo and Barbiero, Pietro and Ciravegna, Gabriele and
    Marra, Giuseppe and Giannini, Francesco and Diligenti, Michelangelo and
    Shams, Zohreh and Precioso, Frederic and Melacci, Stefano and
    Weller, Adrian and Lio, Pietro and Jamnik, Mateja
  },
  journal={Advances in Neural Information Processing Systems},
  volume={35},
  year={2022}
}

IntCEM Paper

@article{EspinosaZarlenga2023intcem,
  title={Learning to Receive Help: Intervention-Aware Concept Embedding Models},
  author={
    Espinosa Zarlenga, Mateo and Collins, Katie and Dvijotham,
    Krishnamurthy and Weller, Adrian and Shams, Zohreh and Jamnik, Mateja
  },
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  year={2023}
}

cem's People

Contributors

gabrieleciravegna, mateoespinosa


cem's Issues

Unexpected results while training CBM on CelebA

Hey Mateo,

Thanks for your work (I really admire your papers on concept-based models 😁) and for releasing your code!

I am trying to train Concept Bottleneck Models using this repository on CelebA. Following your approach, I am using 8 concepts, but I don't have any hidden concepts (so, all 8 concepts are used). I tried training independent and sequential concept bottleneck models. This is the config:

trials: 1
results_dir: /mnt/qb/work/bethge/bkr046/CEM/results/CelebA

dataset: celeba
root_dir: /mnt/qb/work/bethge/bkr046/DATASETS/celeba_torchvision #cem/data/
image_size: 64
num_classes: 1000
batch_size: 512
use_imbalance: True
use_binary_vector_class: True
num_concepts: 8
label_binary_width: 1
label_dataset_subsample: 12
#num_hidden_concepts: 0
selected_concepts: False
num_workers: 8
sampling_percent: 1
test_subsampling: 1

intervention_freq: 1
intervention_batch_size: 1024
intervention_policies:
- "group_random_no_prior"

competence_levels: [1, 0]
incompetence_intervention_policies:
- "group_random_no_prior"

skip_repr_evaluation: True

shared_params:
  top_k_accuracy: [3, 5, 10]
  save_model: True
  max_epochs: 200
  patience: 15
  emb_size: 16
  extra_dims: 0
  concept_loss_weight: 1
  learning_rate: 0.005
  weight_decay: 0.000004
  weight_loss: False
  c_extractor_arch: resnet34
  optimizer: sgd
  early_stopping_monitor: val_loss
  early_stopping_mode: min
  early_stopping_delta: 0.0
  momentum: 0.9
  sigmoidal_prob: False

runs:
  - architecture: 'SequentialConceptBottleneckModel'
    extra_name: "NoInterventionInTrainingNoHiddenConcepts"
    sigmoidal_embedding: False
    concat_prob: False
    embedding_activation: "leakyrelu"
    bool: False
    extra_dims: 0
    sigmoidal_extra_capacity: False
    sigmoidal_prob: True
    training_intervention_prob: 0

While doing this, I noticed a few unexpected things:

  1. Even with 100% interventions, the accuracy of the independent model is less than 50%. Since all concepts are visible, and it's an independent model, the c2y component would have seen all possible concept combinations during training, and therefore should be able to perform well when all ground-truth concepts are provided during interventions.

+---------------------------------------------------------------------------+-----------------+------------------+-----------------+-----------------+-----------------+-----------------+-----------------+
| Method | Task Accuracy | Concept Accuracy | Concept AUC | 25% Int Acc | 50% Int Acc | 75% Int Acc | 100% Int Acc |
+---------------------------------------------------------------------------+-----------------+------------------+-----------------+-----------------+-----------------+-----------------+-----------------+
| IndependentConceptBottleneckModel | 0.2743 ± 0.0000 | 0.8237 ± 0.0000 | 0.8163 ± 0.0000 | 0.3009 ± 0.0000 | 0.3229 ± 0.0000 | 0.3457 ± 0.0000 | 0.3711 ± 0.0000 |
| SequentialConceptBottleneckModel | 0.2784 ± 0.0000 | 0.8237 ± 0.0000 | 0.8163 ± 0.0000 | 0.3107 ± 0.0000 | 0.3409 ± 0.0000 | 0.3584 ± 0.0000 | 0.3711 ± 0.0000 |
+---------------------------------------------------------------------------+-----------------+------------------+-----------------+-----------------+-----------------+-----------------+-----------------+

  2. The training accuracy of the independent model is also pretty low, which is concerning. This is from the 49th epoch (which looks like the last epoch to me):

Epoch 49: 86%|████████▌ | 24/28 [00:01<00:00, 14.65it/s, loss=3.13, y_accuracy=0.381, y_top_3_accuracy=0.548, y_top_5_accuracy=0.619, y_top_10_accuracy=0.714, val_y_accuracy=0.384, val_y_top_3_accuracy=0.557, val_y_top_5_accuracy=0.629, val_y_top_10_accuracy=0.697]

  3. The number of classes in the dataset is 230 instead of 256. I guess that is because some of the 256 possible classes never appear in the dataset. Did you see this as well?

Do you think these observations are expected? If yes, it would be really helpful if you could provide some intuition as to why.

Computing CAS for non-concept models

Thanks for the great paper and repo!

I have a question on how to compute the concept alignment score (CAS) for non-concept-based models.
In your paper, you report CAS for "No concepts" models, but it is not clear on which embeddings the score should be computed.
Should it be computed on the latent code h?

TypeError: __init__() got an unexpected keyword argument 'task_class_weights'

When executing the command "python experiments/run_experiments.py dot -o dot_results/", there is an error.

Error Info:
File "D:\0develop\miniconda3\envs\CEM_NeurIPS\lib\site-packages\cem-1.0.0-py3.8.egg\cem\train\training.py", line 151, in construct_sequential_models
TypeError: __init__() got an unexpected keyword argument 'task_class_weights'

This is after I successfully executed "$ python setup.py install" because cmd showed "Finished processing dependencies for cem==1.0.0".

So, is there a code problem with the packaged "cem-1.0.0-py3.8.egg" file that "python setup.py install" creates, and how can I solve this problem?
