naver-ai / cgl_fairness
License: Other
Hello author, I have some questions about the experimental data processing of this paper. Taking the ADULT dataset as an example, I have two main questions:
The NaN loss made the accuracy drop rapidly.
I then ran this command:
python main_groupclf.py --model mlp --method scratch \
--dataset adult \
--mode eval \
--version groupclf_val \
--sv 0.8
and got this output:
WARNING:root:Missing Data: 3620 rows removed from AdultDataset.
mode : test
# of 0 group data : [10421 1330]
# of 1 group data : [16763 7663]
preprocessing for ssl...
count the number of data newly!
mode : test
# of 0 group data : [2084 265]
# of 1 group data : [3352 1532]
WARNING:root:Missing Data: 3620 rows removed from AdultDataset.
mode : train
# of 0 group data : [10421 1330]
# of 1 group data : [16763 7663]
preprocessing for ssl...
count the number of data newly!
mode : train
# of 0 group data : [6669 852]
# of 1 group data : [10728 4904]
# of test data : 7233
# of train data : 23153
Dataset loaded.
# of classes, # of groups : 2, 2
Evaluation ----------------
WARNING:root:Missing Data: 3620 rows removed from AdultDataset.
mode : train
# of 0 group data : [10421 1330]
# of 1 group data : [16763 7663]
test acc : 0.32476150974699297
val acc : 0.3248143671213953
11751 24426
Traceback (most recent call last):
File "/data1/cgl_fairness/main_groupclf.py", line 328, in <module>
main()
File "/data1/cgl_fairness/main_groupclf.py", line 306, in main
opt_thres = predict_thres(results['probs'], results['true_idxs'], results['false_idxs'], val_idxs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data1/cgl_fairness/main_groupclf.py", line 98, in predict_thres
r = measure_ood(val_true_maxprob.numpy(), val_false_maxprob.numpy())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data1/cgl_fairness/main_groupclf.py", line 71, in measure_ood
fpr, tpr, thresholds = metrics.roc_curve(labels, preds, pos_label=1)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data1/micromamba/envs/fair/lib/python3.11/site-packages/sklearn/utils/_param_validation.py", line 214, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/data1/micromamba/envs/fair/lib/python3.11/site-packages/sklearn/metrics/_ranking.py", line 1095, in roc_curve
fps, tps, thresholds = _binary_clf_curve(
^^^^^^^^^^^^^^^^^^
File "/data1/micromamba/envs/fair/lib/python3.11/site-packages/sklearn/metrics/_ranking.py", line 810, in _binary_clf_curve
assert_all_finite(y_score)
File "/data1/micromamba/envs/fair/lib/python3.11/site-packages/sklearn/utils/validation.py", line 200, in assert_all_finite
_assert_all_finite(
File "/data1/micromamba/envs/fair/lib/python3.11/site-packages/sklearn/utils/validation.py", line 122, in _assert_all_finite
_assert_all_finite_element_wise(
File "/data1/micromamba/envs/fair/lib/python3.11/site-packages/sklearn/utils/validation.py", line 171, in _assert_all_finite_element_wise
raise ValueError(msg_err)
ValueError: Input contains NaN.
So the NaN in the loss prevents the evaluation from proceeding to the next step.
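To narrow this down, it may help to check the scores for NaN before they reach metrics.roc_curve, so the failure points at the checkpoint rather than at sklearn. The following is only a minimal sketch under stated assumptions: count_non_finite and require_finite are hypothetical helpers, not functions from this repository, and the sample values are made up for illustration. It uses only the standard library and works on any 1-D sequence of floats (for example, the result of calling .numpy() on a 1-D tensor).

```python
import math


def count_non_finite(scores):
    """Return how many entries of `scores` are NaN or +/-inf.

    Hypothetical helper for illustration; not part of cgl_fairness.
    """
    return sum(1 for s in scores if not math.isfinite(s))


def require_finite(scores, name="scores"):
    """Raise a descriptive ValueError if any entry is non-finite.

    Returns `scores` unchanged when every entry is finite.
    """
    bad = count_non_finite(scores)
    if bad:
        raise ValueError(
            f"{name}: {bad} non-finite value(s); "
            "the model that produced them may have diverged during training"
        )
    return scores


# Made-up values standing in for the max-probability scores.
example_scores = [0.91, 0.77, float("nan")]
print(count_non_finite(example_scores))  # 1
```

If a check like this fails on the validation scores, the NaN most likely comes from the trained scratch model itself (consistent with the roughly 0.32 test accuracy above), so the training run is probably the place to look rather than the evaluation script.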
I used the commands below:
# train a scratch model
!python main.py --model mlp --method scratch --dataset compas
!python main.py --model mlp --method mfd \
--dataset compas \
--labelwise \
--version cgl \
--sv 0.8 \
--lamb 100 \
--teacher-path './trained_models/20200101/compas/scratch/mlp_seed0_epochs50_bs128_lr0.001.pt'
got this error:
I found the cause: the command needs --teacher-type mlp, as below:
!python main.py --model mlp --method mfd \
--dataset compas \
--labelwise \
--version cgl \
--sv 0.8 \
--lamb 100 \
--teacher-type mlp \
--teacher-path './trained_models/20200101/compas/scratch/mlp_seed0_epochs50_bs128_lr0.001.pt'
This command works, so the README should probably be updated.