
Comments (7)

Samuel-Maddock commented on May 25, 2024

Hi, great question and apologies for the late reply.

You are correct that in multi-class classification the gradient dimension increases to the number of classes K. Unfortunately, the codebase doesn't currently support DP training in a multi-class setting in a straightforward way. In standard (non-private) XGBoost, the simplest approach is for the softmax objective to produce a gradient (and hence a split score) for each class, with K trees trained per boosting round (one per class) from these scores. A similar method could be implemented in the private setting, but some care is needed to ensure the privacy accounting is correct.
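To make the "gradient dimension grows to K" point concrete, here is a minimal sketch of the per-class gradient and diagonal Hessian of the softmax cross-entropy loss for a single sample. It uses only the Python standard library; the function names are illustrative and are not part of this codebase or of XGBoost, and real libraries may scale or clip the Hessian differently.

```python
import math


def softmax(scores):
    """Numerically stable softmax over one row of K raw scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_grad_hess(scores, label):
    """Per-class gradient and diagonal Hessian of softmax cross-entropy.

    scores: K raw scores for one sample; label: integer class in [0, K-1].
    Returns two lists of length K, one entry per class.
    """
    probs = softmax(scores)
    # Gradient w.r.t. score k is p_k - 1[k == label]
    grads = [p - (1.0 if k == label else 0.0) for k, p in enumerate(probs)]
    # Diagonal second derivative is p_k * (1 - p_k)
    hess = [p * (1.0 - p) for p in probs]
    return grads, hess


grads, hess = softmax_grad_hess([0.0, 0.0, 0.0], label=1)
print(grads)
print(hess)
```

Each sample therefore contributes K gradient values instead of one, which is why any private mechanism applied to gradient sums has to account for all K of them.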

A naive approach that works with the current code is to convert your dataset into one-vs-all classification, train a separate PrivateGBDT model for each class, and split the privacy budget evenly between them, e.g., $\varepsilon/K$ each. I will think about adding better support for multi-class classification and leave the issue open for now.
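The one-vs-all conversion and even budget split described above can be sketched as follows. This is a dependency-free illustration only: `one_vs_all_tasks` is a hypothetical helper, not a function in this repository, and it prepares the binary labels and per-class budgets without training any model.

```python
def one_vs_all_tasks(labels, total_eps):
    """Build one binary task per class with an even share of the privacy budget.

    Returns a list of (class_value, binary_labels, class_eps) tuples.
    """
    classes = sorted(set(labels))
    class_eps = total_eps / len(classes)  # even split: eps / K per model
    tasks = []
    for c in classes:
        binary = [1 if y == c else 0 for y in labels]  # class c vs. all others
        tasks.append((c, binary, class_eps))
    return tasks


tasks = one_vs_all_tasks([0, 2, 1, 2, 0], total_eps=3.0)
for c, binary, eps in tasks:
    print(c, binary, eps)
```

Each tuple would then be used to fit one binary model with `epsilon=class_eps`.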

from federated-boosted-dp-trees.

njuptlht commented on May 25, 2024

Thanks for your response. I see the option ‘num_class’ on line 7 of train_monitor.py, but after I change it to ‘num_class’=3, the training cannot be completed. Is this caused by DP (would it work if I don't use DP?), or does the codebase not support multi-class classification in any case?


Samuel-Maddock commented on May 25, 2024

This is correct - DP or not, the codebase does not support multi-classification via num_class or using softmax at this time.

You can still perform multi-class classification using the method I have described above but it takes a bit of work.


doczqa commented on May 25, 2024

Thanks for your response. I see the option ‘num_class’ on line 7 of train_monitor.py, but after I change it to ‘num_class’=3, the training cannot be completed. Is this caused by DP (would it work if I don't use DP?), or does the codebase not support multi-class classification in any case?
Hey, have you found code for multi-class classification yet? Please reply when you can; it would be good to compare notes.


njuptlht commented on May 25, 2024


Samuel-Maddock commented on May 25, 2024

It is straightforward to perform multi-class classification in a one-vs-rest manner without needing to modify the PrivateGBDT code itself.

Below is an example on the Connect4 dataset which has C=3 classes. The example compares XGBoost with PrivateGBDT (eps=0) so the results should be equivalent and indeed, the AUC/accuracy are close.

To use DP with $\varepsilon > 0$, the total privacy budget must be divided by C (each model receives $\varepsilon/C$), since you are training a separate GBDT for each class on the same data.
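As a worked check of this split, assuming basic sequential composition for pure $\varepsilon$-DP (the accounting inside a specific `dp_method` such as `gaussian_cdp` may differ and is not verified here), the per-class budgets add back up to the total:

$$
\varepsilon_c = \frac{\varepsilon}{C}, \qquad \varepsilon_{\text{total}} = \sum_{c=1}^{C} \varepsilon_c = C \cdot \frac{\varepsilon}{C} = \varepsilon
$$

For example, with $\varepsilon = 1$ and the $C = 3$ classes of Connect4, each one-vs-rest model would be trained with $\varepsilon_c = 1/3$.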

I will add automatic multi-class support to the PrivateGBDT class within the next few weeks, but in the meantime you can use the method below.

from federated_gbdt.models.gbdt.private_gbdt import PrivateGBDT
from federated_gbdt.core.loss_functions import SoftmaxCrossEntropyLoss
from experiments.experiment_helpers.data_loader import DataLoader
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.preprocessing import OneHotEncoder
import numpy as np

from xgboost import XGBClassifier

# Load connect4 dataset
dataloader = DataLoader()
X_train, X_test, y_train, y_test = dataloader.load_datasets(
    ["connect_4"], return_dict=False
)[0]
onehot_y_test = OneHotEncoder(sparse_output=False).fit_transform(y_test.reshape(-1, 1))

# XGBoost baseline
xgb = XGBClassifier().fit(X_train, y_train)
xgb_probs = xgb.predict_proba(X_test)
xgb_pred = np.argmax(xgb_probs, axis=1)
print(f"XGBoost AUC - {roc_auc_score(onehot_y_test, xgb_probs)}")
print(f"XGBoost Accuracy - {accuracy_score(y_test, xgb_pred)}")
print("\n")

# PrivateGBDT (eps=0, non-private)
C = len(np.unique(y_train))  # C=3 classes for connect4
total_eps = 0
# Split the privacy budget evenly across the C one-vs-all models;
# here eps=0 (non-private), so the split has no effect
class_eps = total_eps / C
class_probs = []
for c in range(C):
    print(f"Training model... class {c} vs all")
    dp_method = "" if class_eps == 0 else "gaussian_cdp"
    gbdt_model = PrivateGBDT(num_trees=100, epsilon=class_eps, dp_method=dp_method)
    y_train_c = (y_train == c).astype(int)  # one-vs-all labels for class c
    gbdt_model = gbdt_model.fit(X_train, y_train_c)
    class_probs.append(gbdt_model.predict_proba(X_test)[:, 1])  # positive-class probability
# Stack the per-class probabilities into an array of shape (n_samples, C)
y_probs = SoftmaxCrossEntropyLoss().predict(np.column_stack(class_probs))
y_pred = np.argmax(y_probs, axis=1)
print(f"PrivateGBDT (epsilon={total_eps}) AUC - {roc_auc_score(onehot_y_test, y_probs)}")
print(f"PrivateGBDT (epsilon={total_eps}) Accuracy - {accuracy_score(y_test, y_pred)}")
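The final step of the example combines the C one-vs-rest outputs into a single prediction per sample. The sketch below illustrates one way this can work, under the assumption (not verified against the repository) that `SoftmaxCrossEntropyLoss().predict` applies a row-wise softmax; `combine_one_vs_rest` is a hypothetical name and the input values are made up for illustration.

```python
import math


def combine_one_vs_rest(class_probs):
    """Combine per-class scores into normalized probabilities and predictions.

    class_probs[c][i] is model c's positive-class probability for sample i.
    Returns (rows of normalized probabilities, predicted class per sample).
    """
    rows = list(zip(*class_probs))  # transpose to one row per sample
    combined = []
    for row in rows:
        m = max(row)
        exps = [math.exp(p - m) for p in row]  # stable softmax across classes
        total = sum(exps)
        combined.append([e / total for e in exps])
    # Predicted class is the index of the largest normalized score
    preds = [max(range(len(r)), key=r.__getitem__) for r in combined]
    return combined, preds


# Three one-vs-rest models scored on two samples (illustrative values)
probs, preds = combine_one_vs_rest([[0.9, 0.1], [0.05, 0.7], [0.2, 0.3]])
print(probs)
print(preds)
```

Because softmax is monotonic, the predicted class is the same as taking the argmax of the raw one-vs-rest probabilities; the normalization only matters for metrics such as AUC that use the probabilities themselves.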


Samuel-Maddock commented on May 25, 2024

Basic support for multi-class has been added in 91a7d6d using the basic composition approach outlined here w.r.t privacy.

This assumes class labels y are encoded in the range [0, num_classes-1]. An example on Connect4 has also been added.
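If your labels are not already integers in that range, they need to be remapped first. A minimal, dependency-free sketch of such a mapping is shown below; `encode_labels` is a hypothetical helper, not part of this repository, and the string labels are illustrative.

```python
def encode_labels(labels):
    """Map arbitrary hashable, sortable labels to integers 0..num_classes-1.

    Returns (encoded_labels, classes) where classes[i] is the original
    label assigned to integer i.
    """
    classes = sorted(set(labels))
    index = {label: i for i, label in enumerate(classes)}
    return [index[label] for label in labels], classes


encoded, classes = encode_labels(["win", "loss", "draw", "win"])
print(encoded)
print(classes)
```

Keeping `classes` around lets you translate predicted integers back to the original labels; scikit-learn's `LabelEncoder` offers the same kind of sorted mapping if you prefer a library routine.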

