cgl_fairness's People

Contributors

sangwon79

cgl_fairness's Issues

How to split the experimental data

Hello author, I have some questions about how the experimental data is processed in this paper. Taking the ADULT dataset as an example, I have two main questions:

  1. Do all the algorithms compared in Figure 1 of the paper use the same test set? How is the test set split for different group label ratios (%)? Does the test set change when the group label ratio (%) changes?
  2. When using the "scratch" method, are the sensitive attributes manually removed from the ADULT dataset before the model is trained?

I would greatly appreciate an answer.
[image: Snipaste_2024-01-15_16-52-32]

How to solve this problem?

I want to reproduce the experiment, and I have trained a group classifier:
[image: 01]
But when I try to find a threshold and save the group classifier's predictions as described in the README, I run into the following problem:
[image: 02]
How can I solve this problem? Thank you!

Cannot use lbc

[image: a6465ec272f0e89eb632dcddc7aa2a5]
I see that the tutorial uses lbc, so I tried this command:

python main.py --model resnet18 --method lbc \
    --dataset celeba \
    --version cgl \
    --sv 0.8

But I got the error below:
[image]

Another question: the --iter argument does not appear in main.py, so I wonder how --iter is meant to be used.
[image]
[image]

Training on adult with `main_groupclf.py` produces a NaN loss

[image]
The NaN loss caused the accuracy to drop rapidly.

Then I ran this command:

python main_groupclf.py --model mlp --method scratch \
    --dataset adult \
    --mode eval \
    --version groupclf_val \
    --sv 0.8

and got this output:

WARNING:root:Missing Data: 3620 rows removed from AdultDataset.
mode : test
# of 0 group data :  [10421  1330]
# of 1 group data :  [16763  7663]
preprocessing for ssl...
count the number of data newly!
mode : test
# of 0 group data :  [2084  265]
# of 1 group data :  [3352 1532]
WARNING:root:Missing Data: 3620 rows removed from AdultDataset.
mode : train
# of 0 group data :  [10421  1330]
# of 1 group data :  [16763  7663]
preprocessing for ssl...
count the number of data newly!
mode : train
# of 0 group data :  [6669  852]
# of 1 group data :  [10728  4904]
# of test data : 7233
# of train data : 23153
Dataset loaded.
# of classes, # of groups : 2, 2
Evaluation ----------------
WARNING:root:Missing Data: 3620 rows removed from AdultDataset.
mode : train
# of 0 group data :  [10421  1330]
# of 1 group data :  [16763  7663]
test acc : 0.32476150974699297
val acc : 0.3248143671213953
11751 24426
Traceback (most recent call last):
  File "/data1/cgl_fairness/main_groupclf.py", line 328, in <module>
    main()
  File "/data1/cgl_fairness/main_groupclf.py", line 306, in main
    opt_thres = predict_thres(results['probs'], results['true_idxs'], results['false_idxs'], val_idxs)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data1/cgl_fairness/main_groupclf.py", line 98, in predict_thres
    r = measure_ood(val_true_maxprob.numpy(), val_false_maxprob.numpy())
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data1/cgl_fairness/main_groupclf.py", line 71, in measure_ood
    fpr, tpr, thresholds = metrics.roc_curve(labels, preds, pos_label=1)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data1/micromamba/envs/fair/lib/python3.11/site-packages/sklearn/utils/_param_validation.py", line 214, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data1/micromamba/envs/fair/lib/python3.11/site-packages/sklearn/metrics/_ranking.py", line 1095, in roc_curve
    fps, tps, thresholds = _binary_clf_curve(
                           ^^^^^^^^^^^^^^^^^^
  File "/data1/micromamba/envs/fair/lib/python3.11/site-packages/sklearn/metrics/_ranking.py", line 810, in _binary_clf_curve
    assert_all_finite(y_score)
  File "/data1/micromamba/envs/fair/lib/python3.11/site-packages/sklearn/utils/validation.py", line 200, in assert_all_finite
    _assert_all_finite(
  File "/data1/micromamba/envs/fair/lib/python3.11/site-packages/sklearn/utils/validation.py", line 122, in _assert_all_finite
    _assert_all_finite_element_wise(
  File "/data1/micromamba/envs/fair/lib/python3.11/site-packages/sklearn/utils/validation.py", line 171, in _assert_all_finite_element_wise
    raise ValueError(msg_err)
ValueError: Input contains NaN.

So the "nan" in the loss prevents us from proceeding to the next step.
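
To narrow down where the NaN comes from, the scores could be checked before they reach `metrics.roc_curve`. The following is only a minimal sketch, not code from the repository: `split_finite` and `check_scores` are hypothetical helper names, and the sketch assumes the scores are available as a plain sequence of floats (for NumPy arrays, `numpy.isfinite` gives the same check).

```python
import math


def split_finite(scores):
    """Split a sequence of floats into finite values and indices of NaN/inf entries."""
    finite, bad = [], []
    for i, s in enumerate(scores):
        if math.isfinite(s):
            finite.append(s)
        else:
            bad.append(i)
    return finite, bad


def check_scores(scores, name="scores"):
    """Raise a descriptive error if any score is NaN or infinite (hypothetical helper)."""
    finite, bad = split_finite(scores)
    if bad:
        raise ValueError(
            f"{name}: {len(bad)} of {len(scores)} values are NaN/inf "
            f"(first at index {bad[0]})"
        )
    return finite
```

Calling such a check on each score array just before the ROC computation would show which input already contains NaN, which points back to the training run rather than to scikit-learn.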

Cannot use mfd (missing arg)

I used the commands below:

# train a scratch model
!python main.py --model mlp --method scratch --dataset compas
!python main.py --model mlp --method mfd \
    --dataset compas \
    --labelwise \
    --version cgl \
    --sv 0.8 \
    --lamb 100 \
    --teacher-path './trained_models/20200101/compas/scratch/mlp_seed0_epochs50_bs128_lr0.001.pt'

and got this error:
[image]
I found the cause: the --teacher-type argument was missing. With it added:

!python main.py --model mlp --method mfd \
    --dataset compas \
    --labelwise \
    --version cgl \
    --sv 0.8 \
    --lamb 100 \
    --teacher-type mlp \
    --teacher-path './trained_models/20200101/compas/scratch/mlp_seed0_epochs50_bs128_lr0.001.pt'

This command works, so the README should probably be updated.
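
As an illustration of how a command-line interface could report this kind of missing argument up front, here is a minimal sketch using `argparse`. It is not the repository's `main.py`: the flag names mirror the command above, but `build_parser`, `parse_args`, and the validation rule itself are assumptions made for the example.

```python
import argparse


def build_parser():
    # Hypothetical parser; only the flags relevant to the example are defined.
    parser = argparse.ArgumentParser(description="hypothetical training CLI")
    parser.add_argument("--method", default="scratch")
    parser.add_argument("--teacher-type", default=None)
    parser.add_argument("--teacher-path", default=None)
    return parser


def parse_args(argv=None):
    parser = build_parser()
    args = parser.parse_args(argv)
    # Assumed rule: a teacher checkpoint is only usable if its model type is given.
    if args.teacher_path is not None and args.teacher_type is None:
        parser.error("--teacher-path requires --teacher-type")
    return args
```

With a check like this, omitting `--teacher-type` would stop the run immediately with a usage message instead of failing later when the teacher model is loaded.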
