oralcancer_classification's Introduction

Binary Classification Model Training and Evaluation

This script trains and evaluates machine learning models for binary classification. Supported models include Random Forest, k-nearest neighbors (KNN), support vector machine (SVM), Logistic Regression, and XGBoost.

Features

  • Train and evaluate multiple models for binary classification.
  • Plot ROC curves and calculate AUC scores for model comparison.
  • Calculate and plot feature importance using XGBoost models.
  • Generate SHAP summary plots for model interpretability.
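
As a rough illustration of what the AUC comparison measures, the sketch below computes AUC from its pairwise definition: the fraction of (positive, negative) sample pairs in which the positive sample receives the higher score, with ties counted as half. This is not the script's own code. The function name auc_score, the model names, and the scores are made up for illustration; in practice scikit-learn's roc_auc_score would normally be used instead.

```python
def auc_score(y_true, y_score):
    """AUC as the probability that a random positive outranks a random negative."""
    positives = [s for y, s in zip(y_true, y_score) if y == 1]
    negatives = [s for y, s in zip(y_true, y_score) if y == 0]
    if not positives or not negatives:
        raise ValueError("AUC needs at least one positive and one negative sample")

    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # ties count as half a win
    return wins / (len(positives) * len(negatives))


# Hypothetical labels and predicted scores from two hypothetical models.
y_true = [0, 0, 1, 1]
model_scores = {
    "model_a": [0.1, 0.4, 0.35, 0.8],
    "model_b": [0.2, 0.3, 0.6, 0.9],
}

for name, scores in model_scores.items():
    print(f"{name}: AUC = {auc_score(y_true, scores):.2f}")
```

A higher AUC means the model separates the two classes more cleanly; 0.5 corresponds to random ranking and 1.0 to perfect ranking.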

Requirements

To run this script, you will need the following packages:

  • pandas
  • scikit-learn
  • xgboost
  • matplotlib
  • joblib
  • shap
  • numpy

You can install these packages using the following command:

pip install pandas scikit-learn xgboost matplotlib joblib shap numpy

Usage

To use this script, you need to provide a CSV data file and specify the feature columns. The data file should contain a column named "Label" for the target labels and other columns for the features.

Example command:

python binary_classification.py data.csv 1-5

In this example, data.csv is the path to the CSV data file, and 1-5 specifies that columns 1 to 5 (inclusive) are the feature columns.

Input

  • data_file: Path to the CSV data file.
  • feature_columns: Feature column specification, given as a comma-separated list (e.g., 1,3,6), an inclusive range (e.g., 1-5), or a single column (e.g., 3).
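
The three feature_columns forms could be expanded into a list of column numbers along the following lines. This is a minimal sketch, not the script's actual parser. The function name parse_feature_columns is hypothetical, support for mixed specs such as 1,3-5 is an assumption, and how the resulting numbers map onto CSV columns (zero-based or one-based) is not stated in this README.

```python
def parse_feature_columns(spec: str) -> list[int]:
    """Expand a spec such as '1,3,6', '1-5', or '3' into sorted column numbers."""
    columns = set()
    for part in spec.split(","):
        part = part.strip()
        if not part:
            continue  # tolerate stray commas
        if "-" in part:
            # Inclusive range, e.g. '1-5' -> 1, 2, 3, 4, 5
            start_text, end_text = part.split("-", 1)
            start, end = int(start_text), int(end_text)
            if start > end:
                raise ValueError(f"Invalid range: {part!r}")
            columns.update(range(start, end + 1))
        else:
            columns.add(int(part))
    if not columns:
        raise ValueError("No feature columns specified")
    return sorted(columns)


print(parse_feature_columns("1-5"))    # [1, 2, 3, 4, 5]
print(parse_feature_columns("1,3,6"))  # [1, 3, 6]
print(parse_feature_columns("3"))      # [3]
```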

Output

  • Models trained for each class pair are saved in the models directory.
  • ROC curves are saved in the plots directory.
  • Feature importance plots are saved in the plots directory.
  • SHAP summary plots are saved in the plots directory.
  • Performance metrics of the trained models are saved to a CSV file named model_performance_results.csv.

Structure

The main components of the script are:

  • create_binary_labels(): Creates binary labels for a given class pair.
  • split_data(): Splits the data into training and testing sets.
  • plot_auc_combined(): Plots the ROC curves for all models.
  • train_and_evaluate_all(): Trains and evaluates all models for a given class pair.
  • get_feature_importance(): Retrieves the feature importance from an XGBoost model.
  • plot_shap_values(): Plots the SHAP values for an XGBoost model.
  • main(): The main function that orchestrates the training and evaluation process.
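
To make the "class pair" idea concrete, here is one way the labeling step might work, assuming the Label column can hold more than two classes and that each pair of classes is turned into its own binary problem. The signature shown for create_binary_labels is a guess, the class names are invented, and plain lists are used here for self-containment where the real script presumably works on pandas data.

```python
from itertools import combinations


def create_binary_labels(labels, class_pair):
    """Keep only samples from the two classes in class_pair and map them to 0/1.

    Returns the indices of the kept samples and their binary labels.
    Assumed convention: first class of the pair -> 0, second class -> 1.
    """
    negative_class, positive_class = class_pair
    kept_indices = []
    binary_labels = []
    for index, label in enumerate(labels):
        if label == negative_class:
            kept_indices.append(index)
            binary_labels.append(0)
        elif label == positive_class:
            kept_indices.append(index)
            binary_labels.append(1)
        # Samples from any other class are dropped for this pair.
    return kept_indices, binary_labels


# Hypothetical Label column with three classes.
labels = ["A", "B", "C", "A", "C"]

# One binary problem per unordered pair of classes.
for pair in combinations(sorted(set(labels)), 2):
    indices, binary = create_binary_labels(labels, pair)
    print(pair, indices, binary)
```

Under these assumptions, a dataset with k classes yields k(k-1)/2 binary problems, which would match the per-class-pair models described in the Output section.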

Note

  • The script assumes that the target column in the data file is named "Label".
