
marcovirgolin / cogs


A baseline genetic algorithm for the discovery of counterfactuals, implemented in Python for ease of use and heavily leveraging NumPy for speed.

License: GNU General Public License v3.0

Python 58.09% Jupyter Notebook 41.91%
counterfactual-explanations counterfactuals genetic-algorithm explainable-ai explainable-ml explainable-machine-learning interpretable-ai interpretable-ml interpretable-machine-learning

cogs's Introduction

cogs

CoGS is intended to be a baseline search algorithm for the discovery of counterfactuals. As the name suggests, CoGS is a genetic algorithm: it employs a population of candidate counterfactuals and does not require the machine learning model for which counterfactuals are sought to expose gradients. CoGS is implemented in Python for ease of use, and heavily relies on NumPy for speed.

Colab example: https://colab.research.google.com/drive/1HQ4wcViJ5YV6w648yUtmiCoa2fGj4ftE

Reference

Please consider citing the paper for which CoGS has been developed:

@article{virgolin2022on,
  author = {Marco Virgolin and Saverio Fracaros},
  title = {On the Robustness of Sparse Counterfactual Explanations to Adverse Perturbations},
  journal = {Artificial Intelligence},
  pages = {103840},
  year = {2022},
  issn = {0004-3702},
  doi = {10.1016/j.artint.2022.103840},
  url = {https://www.sciencedirect.com/science/article/pii/S0004370222001801},
}

Installation

Clone this repository, then run pip install . from within it.

Usage

CoGS is relatively simple to set up and run. Here's an example:

from cogs.evolution import Evolution
from cogs.fitness import gower_fitness_function

cogs = Evolution(
        # hyper-parameters of the problem (required!)
        x=x,  # the starting point, for which the black-box model gives an undesired class prediction
        fitness_function=gower_fitness_function,  # a classic fitness function for counterfactual explanations
        fitness_function_kwargs={'blackbox':bbm,'desired_class': desired_class},  # bbm is the black-box model, these params are necessary
        feature_intervals=feature_intervals,  # intervals within which the search operates
        indices_categorical_features=indices_categorical_features,  # the indices of the features that are categorical
        plausibility_constraints=pcs, # can be None if no constraints need to be set
        """ hyper-parameters of the optimization (all optional) """
        evolution_type='classic', # the type of evolution, classic works well in general and is relatively fast to execute
        population_size=1000,   # how many candidate counterfactual examples to evolve simultaneously
        n_generations=100,  # number of iterations for the evolution
        selection_name='tournament_4', # selection pressure
        init_temperature=0.8, # how "far" from x we initialize
        num_features_mutation_strength=0.25, # strength of random mutations for numerical features
        num_features_mutation_strength_decay=0.5, # decay for the hyper-param. above
        num_features_mutation_strength_decay_generations=[50,75,90], # when to apply the decay
        """ other optional hyper-parameters """
        verbose=True  # logs progress at every generation 
)
cogs.run()
result = cogs.elite   # closest-found point to 'x' for which 'bbm.predict' returns 'desired_class'

The black-box model (bbm) can be anything, as long as it exposes a predict function, just like scikit-learn models do. There's a full example in our notebook example.ipynb.
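
For instance, any model that exposes a scikit-learn-style predict can be plugged in directly. The sketch below is our own minimal illustration (the dataset, the model choice, and the (min, max) format used for feature_intervals are assumptions for illustration, not taken from example.ipynb):

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier

# Any fitted model that exposes .predict can serve as the black-box model.
X, y = load_breast_cancer(return_X_y=True)
bbm = RandomForestClassifier(random_state=0).fit(X, y)

# Starting point: an instance the model currently assigns to the undesired class (0 here).
desired_class = 1
x = X[bbm.predict(X) == 0][0]

# Feature intervals taken from the observed per-feature min/max;
# this data set has no categorical features.
feature_intervals = [(X[:, j].min(), X[:, j].max()) for j in range(X.shape[1])]
indices_categorical_features = []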

Customization

Fitness function

The quality of a candidate counterfactual z for the starting point x is called fitness (to be maximized). The fitness implemented in CoGS is:

-1 * ( 0.5 * gower_distance(z, x) + 0.5 * L0(z, x) + int(bbm.predict(z) != desired_class) )

which takes values in (-inf, 0) (the closer to 0, the better). You can change the fitness function in cogs/fitness.py to any fitness you like, as long as it is meant to be maximized. It is strongly recommended to evaluate the entire population of candidate counterfactuals with NumPy operations to maintain high efficiency.
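
For illustration, a hypothetical replacement fitness (not part of the repository) that scores the whole population at once could look like the sketch below. The keyword arguments mirror those that Evolution passes to the fitness function (see the traceback under Issues further down); blackbox and desired_class arrive via fitness_function_kwargs, and the sketch assumes all features are numerical:

import numpy as np

def euclidean_fitness_function(genes, x, feature_intervals,
                               indices_categorical_features=None,
                               plausibility_constraints=None,
                               blackbox=None, desired_class=1, **kwargs):
    # genes holds one candidate counterfactual per row: shape (population_size, n_features).
    genes = np.asarray(genes, dtype=float)
    distances = np.linalg.norm(genes - np.asarray(x, dtype=float), axis=1)
    # Penalty of 1 for every candidate not yet predicted as the desired class.
    wrong_class = (blackbox.predict(genes) != desired_class).astype(float)
    return -(distances + wrong_class)  # to be maximized: the closer to 0, the better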

Crossover and mutation

Currently, CoGS implements the following operators to generate offspring counterfactuals:

  • Crossover: Generates two offspring counterfactuals from two parent counterfactuals by swapping the feature values (called genes) of the two parents uniformly at random (well-suited to the L0 term of the fitness; a sketch follows this list).
  • Linear crossover: Works like the previous crossover for categorical features; for numerical features, the feature values of the offspring are random linear combinations of those of the parents (implemented but not used by default).
  • Mutation: Generates an offspring counterfactual out of each parent. For categorical features, a new random category is sampled; for numerical features, the feature value of the parent is altered by a magnitude that depends on the interval of variation for that feature times the num_features_mutation_strength hyper-parameter. Since mutation can result in feature values that are out-of-bounds for the data set at hand, corrections are implemented.
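
To make the uniform crossover concrete, here is a minimal sketch of the idea (an illustration, not the implementation in cogs/variation.py), assuming the population's genes are stored as a 2D NumPy array with one candidate per row:

import numpy as np

def uniform_crossover(parent_genes, rng=None):
    # Pair consecutive parents and, for each pair, exchange every gene between
    # the two offspring with probability 0.5 (single feature values are kept
    # intact, which plays well with the L0 term of the fitness).
    rng = np.random.default_rng() if rng is None else rng
    offspring = parent_genes.copy()
    for i in range(0, len(offspring) - 1, 2):
        swap = rng.random(offspring.shape[1]) < 0.5  # which genes to exchange
        offspring[i, swap], offspring[i + 1, swap] = (
            offspring[i + 1, swap], offspring[i, swap])  # fancy indexing yields copies, so the swap is safe
    return offspring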

You can create your own crossover or mutation operator in cogs/variation.py. If your operator can generate values outside the intervals within which features take values, you should consider implementing a correction mechanism (as currently present in variation.generate_plausible_mutations), or use the option apply_fixes in fitness.gower_fitness_function.
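
As an example of such a correction, a simplified numerical mutation with clamping could look like the sketch below (an illustration of the idea only; the repository's generate_plausible_mutations additionally handles categorical features and plausibility constraints):

import numpy as np

def mutate_numerical(genes, feature_intervals, mutation_strength=0.25, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    intervals = np.asarray(feature_intervals, dtype=float)  # shape (n_features, 2)
    low, high = intervals[:, 0], intervals[:, 1]
    # Perturb each gene by Gaussian noise scaled by the feature's range...
    noise = rng.normal(scale=mutation_strength, size=genes.shape) * (high - low)
    mutated = genes + noise
    # ...then clamp back into the allowed interval (the correction step).
    return np.clip(mutated, low, high)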

cogs's People

Contributors

marcovirgolin

Forkers

kc2fresh

cogs's Issues

Setting None for plausibility constraints

The Evolution initializer's comment says that plausibility constraints can be set to None, but it actually requires a list of Nones whose length is the number of features.
Otherwise, if cogs = Evolution(.., plausibility_constraints="None", ..), then cogs.run() returns the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In [49], line 1
----> 1 cogs.run()

File ~/TRUST_AI/robust/CoGS/cogs/evolution.py:140, in Evolution.run(self)
    138 def run(self):
--> 140   self.population.initialize(self.x, self.feature_intervals, self.indices_categorical_features,
    141     self.plausibility_constraints, self.init_temperature)
    143   self.population.fitnesses = self.fitness_function(genes=self.population.genes, x=self.x, 
    144     feature_intervals=self.feature_intervals, indices_categorical_features=self.indices_categorical_features,
    145     plausibility_constraints=self.plausibility_constraints,
    146     **self.fitness_function_kwargs)
    148   best_fitness_idx = np.argmax(self.population.fitnesses)

File ~/TRUST_AI/robust/CoGS/cogs/population.py:26, in Population.initialize(self, x, feature_intervals, indices_categorical_features, plausibility_constraints, temperature)
     24     init_feat_i = x[i]
     25   else:
---> 26     raise ValueError("Unrecognized plausibility constraint",plausibility_constraints[i],"for feature",i)
     27 else:
     28   init_feat_i = np.random.choice(feature_intervals[i], size=n)

ValueError: ('Unrecognized plausibility constraint', 'N', 'for feature', 0)
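
Until this is fixed, a workaround consistent with the report is to pass one None entry per feature instead of a single None:

# Workaround: one (unconstrained) entry per feature rather than a single None.
pcs = [None] * len(feature_intervals)
# ...then pass plausibility_constraints=pcs to Evolution, as in the usage example above.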
