
hidden-stratification's Introduction

No Subclass Left Behind: Fine-Grained Robustness in Coarse-Grained Classification Problems

This code implements the "GEORGE" algorithm from the following paper (to appear in NeurIPS 2020!):

Nimit Sohoni, Jared Dunnmon, Geoffrey Angus, Albert Gu, Christopher Ré

No Subclass Left Behind: Fine-Grained Robustness in Coarse-Grained Classification Problems

Abstract

In real-world classification tasks, each class often comprises multiple finer-grained "subclasses." As the subclass labels are frequently unavailable, models trained using only the coarser-grained class labels often exhibit highly variable performance across different subclasses. This phenomenon, known as hidden stratification, has important consequences for models deployed in safety-critical applications such as medicine. We propose GEORGE, a method to both measure and mitigate hidden stratification even when subclass labels are unknown. We first observe that unlabeled subclasses are often separable in the feature space of deep models, and exploit this fact to estimate subclass labels for the training data via clustering techniques. We then use these approximate subclass labels as a form of noisy supervision in a distributionally robust optimization objective. We theoretically characterize the performance of GEORGE in terms of the worst-case generalization error across any subclass. We empirically validate GEORGE on a mix of real-world and benchmark image classification datasets, and show that our approach boosts worst-case subclass accuracy by up to 22 percentage points compared to standard training techniques, without requiring any information about the subclasses.
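The two steps described in the abstract can be illustrated with a toy, pure-Python sketch. It is not the repository's implementation, and it makes simplifying assumptions. Step 1 estimates subclass labels by clustering feature vectors separately within each coarse class, using plain k-means with a fixed number of clusters as a stand-in for the paper's actual clustering procedure. Step 2 computes the worst-case quantities across the estimated subclasses: the worst-group accuracy, which measures hidden stratification, and the largest per-group average loss, which is the kind of objective a group-level distributionally robust optimizer minimizes. No training loop is included. All function names (`kmeans`, `estimate_subclasses`, `worst_group_accuracy`, `worst_group_loss`) and the toy data are hypothetical.

```python
from collections import defaultdict


def sq_dist(p, q):
    """Squared Euclidean distance between two equal-length tuples."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def kmeans(points, k, iters=50):
    """Plain Lloyd's k-means (a stand-in clusterer); returns a cluster id per point.

    Centroids start at the first k distinct points, which keeps this toy
    example deterministic but is much weaker than k-means++ seeding.
    """
    centroids = []
    for p in points:
        if p not in centroids:
            centroids.append(p)
        if len(centroids) == k:
            break
    assign = []
    for _ in range(iters):
        # Assignment step: each point joins its nearest centroid.
        assign = [min(range(len(centroids)), key=lambda c: sq_dist(p, centroids[c]))
                  for p in points]
        # Update step: move each centroid to the mean of its members.
        updated = []
        for c, old in enumerate(centroids):
            members = [p for p, a in zip(points, assign) if a == c]
            updated.append(tuple(sum(xs) / len(members) for xs in zip(*members))
                           if members else old)
        if updated == centroids:
            break
        centroids = updated
    return assign


def estimate_subclasses(features, labels, k):
    """Cluster each coarse class on its own; return a global subclass id per example."""
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)
    subclass, next_id = [None] * len(labels), 0
    for y in sorted(by_class):
        idx = by_class[y]
        local_to_global = {}
        for i, a in zip(idx, kmeans([features[i] for i in idx], k)):
            if a not in local_to_global:  # give each new cluster a fresh global id
                local_to_global[a] = next_id
                next_id += 1
            subclass[i] = local_to_global[a]
    return subclass


def worst_group_accuracy(preds, labels, groups):
    """Minimum accuracy over (estimated) subclasses."""
    correct, total = defaultdict(int), defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    return min(correct[g] / total[g] for g in total)


def worst_group_loss(losses, groups):
    """Largest average loss over (estimated) subclasses, i.e. a group-level robust objective."""
    sums, counts = defaultdict(float), defaultdict(int)
    for loss, g in zip(losses, groups):
        sums[g] += loss
        counts[g] += 1
    return max(sums[g] / counts[g] for g in sums)


if __name__ == "__main__":
    # Toy data: each coarse class hides two well-separated clusters.
    features = [(0, 0), (0, 1), (10, 10), (10, 11), (20, 0), (21, 0), (0, 20), (0, 21)]
    labels = [0, 0, 0, 0, 1, 1, 1, 1]
    preds = [0, 0, 0, 1, 1, 1, 1, 1]
    groups = estimate_subclasses(features, labels, k=2)
    print(groups)  # [0, 0, 1, 1, 2, 2, 3, 3]
    print(sum(p == y for p, y in zip(preds, labels)) / len(labels))  # 0.875 overall
    print(worst_group_accuracy(preds, labels, groups))  # 0.5 on the hidden subclass
```

The toy run shows the effect the abstract describes: overall accuracy looks high (0.875) while one estimated subclass sits at 0.5.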

Setup instructions

Prerequisites: Make sure you have Python>=3.6 and PyTorch>=1.5 installed. Then, install dependencies with:

pip install -r requirements.txt

Next, either add the base directory of the repository to your PYTHONPATH, or run:

pip install -e .
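Before running anything, it can help to confirm the prerequisites above. The script below is a small sketch, not part of the repository: it compares the running Python and any installed PyTorch against the stated minimums (Python>=3.6, PyTorch>=1.5), and checks whether a top-level `stratification` package is importable, which assumes the package name matches the `stratification/` directory. The helper names are hypothetical.

```python
import importlib.util
import re
import sys


def version_tuple(version):
    """'1.5.0+cu101' -> (1, 5, 0): keep leading integers, drop local/pre-release tags."""
    parts = []
    for piece in str(version).split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.group() != piece:  # e.g. '0a0' or '0rc1': stop after the number
            break
    return tuple(parts)


def meets_minimum(version, minimum):
    """True if `version` is at least `minimum`, comparing numeric components."""
    return version_tuple(version) >= version_tuple(minimum)


def check_environment():
    """Return (name, detail, ok) rows for the README's prerequisites."""
    python = ".".join(str(n) for n in sys.version_info[:3])
    rows = [("Python>=3.6", python, meets_minimum(python, "3.6"))]
    try:
        import torch
        rows.append(("PyTorch>=1.5", str(torch.__version__),
                     meets_minimum(torch.__version__, "1.5")))
    except ImportError:
        rows.append(("PyTorch>=1.5", "not installed", False))
    # Assumption: the repo exposes a top-level `stratification` package once it is
    # on PYTHONPATH or installed with `pip install -e .`.
    found = importlib.util.find_spec("stratification") is not None
    rows.append(("stratification importable", "yes" if found else "no", found))
    return rows


if __name__ == "__main__":
    for name, detail, ok in check_environment():
        print("{:<28} {:<14} {}".format(name, detail, "ok" if ok else "NOT OK"))
```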

Demo

We provide a simple demo notebook at tutorials/Basic-Tutorial.ipynb. This example can also be run as a script:

python stratification/demo.py configs/demo_config.json

Configuration options

The first argument to the script should be the path to the configuration file. Default configurations for the GEORGE experiments in the paper are in the configs/ directory. Individual config values can also be overridden directly from the command line. Use = to separate each key from its value, and use . to access nested dictionaries as specified in the config; for example:

python stratification/run.py configs/mnist_george_config.json exp_dir=checkpoints/new-experiment classification_config.num_epochs=50
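The override syntax above (`key=value`, with `.` stepping into nested dictionaries) can be sketched as follows. This is an illustration of the convention only, not the repository's parser: `apply_overrides` is a hypothetical helper, it assumes values should be parsed as JSON when possible (so `50` becomes an integer) and kept as plain strings otherwise (so `checkpoints/new-experiment` stays a path), and it performs none of the validation implied by stratification/utils/schema.py.

```python
import json


def apply_overrides(config, overrides):
    """Return a copy of `config` with 'a.b.c=value' overrides applied.

    Values are decoded as JSON when possible (50 -> int, true -> bool, 1e-3 -> float);
    anything that is not valid JSON is kept as a plain string.
    """
    config = json.loads(json.dumps(config))  # deep copy of a JSON-style dict
    for item in overrides:
        key, sep, raw = item.partition("=")  # split on the first '=' only
        if not sep:
            raise ValueError("expected key=value, got {!r}".format(item))
        try:
            value = json.loads(raw)
        except json.JSONDecodeError:
            value = raw
        *parents, leaf = key.split(".")
        node = config
        for name in parents:  # walk (or create) the nested dictionaries
            node = node.setdefault(name, {})
        node[leaf] = value
    return config


if __name__ == "__main__":
    # Toy base config (hypothetical values); the override strings mirror the command above.
    base = {"exp_dir": "checkpoints/old-experiment", "classification_config": {"num_epochs": 20}}
    new = apply_overrides(base, ["exp_dir=checkpoints/new-experiment",
                                 "classification_config.num_epochs=50"])
    print(new)
```

Returning a copy rather than mutating the loaded config keeps the original defaults available for comparison, e.g. when recording which values were overridden.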

The modified config will be saved at config['exp_dir']. For the complete configuration definition, see stratification/utils/schema.py.

hidden-stratification's People

Contributors

nimz


hidden-stratification's Issues

Annotations for NIH Chest X-ray

Hi, thanks for sharing this excellent work! Do you have annotations for the chest drains in the NIH Chest X-ray dataset?

Many thanks!

Error on execution

@nimz @stephenbach @henryre

I'm getting an error while executing the demo code:

PYTHONPATH=.:$PYTHONPATH python stratification/demo.py configs/demo_config.json

Full trace

Traceback (most recent call last):
  File "/home/user/anaconda_3/envs/george_env/lib/python3.7/site-packages/jsonargparse/jsonschema.py", line 85, in _check_type
    val, fpath = parse_value_or_config(val, enable_path=self._enable_path)
  File "/home/user/anaconda_3/envs/george_env/lib/python3.7/site-packages/jsonargparse/util.py", line 147, in parse_value_or_config
    value = load_value(cfg_path.get_content(), simple_types=simple_types)
  File "/home/user/anaconda_3/envs/george_env/lib/python3.7/site-packages/jsonargparse/loaders_dumpers.py", line 111, in load_value
    loader = loaders[get_load_value_mode()]
  File "/home/user/anaconda_3/envs/george_env/lib/python3.7/site-packages/jsonargparse/loaders_dumpers.py", line 102, in get_load_value_mode
    mode = parent_parser.get().parser_mode
LookupError: <ContextVar name='parent_parser' at 0x7f60013b1950>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "stratification/demo.py", line 50, in <module>
    main()
  File "stratification/demo.py", line 12, in main
    config = get_config()
  File "/home/user/hidden-stratification/stratification/utils/parse_args.py", line 16, in get_config
    args = parser.parse_args(args_list)
  File "/home/user/anaconda_3/envs/george_env/lib/python3.7/argparse.py", line 1755, in parse_args
    args, argv = self.parse_known_args(args, namespace)
  File "/home/user/anaconda_3/envs/george_env/lib/python3.7/argparse.py", line 1787, in parse_known_args
    namespace, args = self._parse_known_args(args, namespace)
  File "/home/user/anaconda_3/envs/george_env/lib/python3.7/argparse.py", line 1996, in _parse_known_args
    stop_index = consume_positionals(start_index)
  File "/home/user/anaconda_3/envs/george_env/lib/python3.7/argparse.py", line 1952, in consume_positionals
    take_action(action, args)
  File "/home/user/anaconda_3/envs/george_env/lib/python3.7/argparse.py", line 1861, in take_action
    action(self, namespace, argument_values, option_string)
  File "/home/user/anaconda_3/envs/george_env/lib/python3.7/site-packages/jsonargparse/jsonschema.py", line 73, in __call__
    val = self._check_type(args[2])
  File "/home/user/anaconda_3/envs/george_env/lib/python3.7/site-packages/jsonargparse/jsonschema.py", line 93, in _check_type
    except (TypeError, ValueError) + get_jsonschema_exceptions() + get_loader_exceptions() as ex:
  File "/home/user/anaconda_3/envs/george_env/lib/python3.7/site-packages/jsonargparse/loaders_dumpers.py", line 107, in get_loader_exceptions
    return loader_exceptions[get_load_value_mode()]
  File "/home/user/anaconda_3/envs/george_env/lib/python3.7/site-packages/jsonargparse/loaders_dumpers.py", line 102, in get_load_value_mode
    mode = parent_parser.get().parser_mode
LookupError: <ContextVar name='parent_parser' at 0x7f60013b1950>
