data1030-midterm's Introduction

Data1030 Final Project: Injury Prediction for Distance Running Athletes

Project overview:

The goal of this project is to predict whether an athlete will get injured on a given day based on their recent training history. This is a binary classification problem: the target variable indicates whether or not an injury occurred. The motivation is that sports teams want to prevent injuries among their athletes, so it is useful to know which factors contribute to injury and to signal athletes to take precautions before an injury occurs.

I trained 4 different models: Logistic Regression, Random Forest Classifier, SVC and KNN. The best-performing model is Logistic Regression, which achieves a mean test F_2 score of 0.506 (0.5 standard deviations above baseline).
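
The F_2 score is the F-beta score with β = 2, which weights recall more heavily than precision. The sketch below is a minimal illustration, assuming (this is not stated in the project) that the baseline is a classifier that predicts injury for every event, and that "standard deviations above baseline" means the gap between the mean test F_2 score and the baseline, divided by the standard deviation of the test scores. Both function names are hypothetical.

```python
import numpy as np
from sklearn.metrics import fbeta_score


def f2_all_positive_baseline(y_true):
    """F_2 score of a classifier that predicts class 1 (injury) for every point.

    Precision equals the positive fraction p and recall is 1, so
    F_2 = (1 + 2**2) * p * 1 / (2**2 * p + 1) = 5p / (4p + 1).
    """
    p = np.mean(y_true)
    return 5 * p / (4 * p + 1)


def std_above_baseline(test_scores, baseline):
    """Number of standard deviations the mean test score lies above the baseline."""
    scores = np.asarray(test_scores, dtype=float)
    return (scores.mean() - baseline) / scores.std()
```

With 10% positive samples, the always-injury baseline is 0.5 / 1.4 ≈ 0.357, which matches `fbeta_score(y, np.ones_like(y), beta=2)` on the same labels.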

The three biggest challenges I faced while working on this project were working with imbalanced data, dealing with overfitting, and setting up train-validation-test sets for time-series data.

Below is an overview of my procedure and key considerations. Please refer to the final project report and the code file for more details.

  • EDA:
    • Choosing one athlete to predict: I examined the distributions of the feature variables and chose the athlete with the highest fraction of injuries out of all events attended as the athlete my model is trained for. I trained a model for only one athlete because, besides the data being a time series, each athlete's training records follow a different distribution, so combining athletes would further violate the iid assumption.
  • Feature engineering: The features contain training data for the past 7 days. I added the average and maximum value of each feature over those 7 days as additional features. Minimum values over the last 7 days are 0 for all features, so they were not included. I also added a lag feature of the target variable, which indicates whether the athlete was injured on the previous event day.
  • Splitting: I split the data into training, validation and test sets in chronological order, and repeated this splitting over 6 folds so that I could evaluate the model on different test sets and account for uncertainty. Please see the report for more details on splitting.
  • Preprocessing: I used StandardScaler for all the variables because they can all be treated as continuous, and most of them have no clear upper bound and a long tail. The lag of the target variable was also transformed with StandardScaler to make the coefficients of my linear model easier to interpret.
  • Evaluation metric: I chose the F_2 score as my evaluation metric because my dataset is highly imbalanced and I want to capture a large proportion of the actual positive (injury) samples; F_2 weights recall more heavily than precision.
  • Training Pipeline: For each algorithm, I trained the model on the training set, used the validation set to choose the best set of hyperparameters, and then predicted on the test set to calculate the test F_2 score. I repeated this process for all 6 folds, giving 6 F_2 scores on 6 different test sets for each algorithm.
  • Results: Below are the results for each algorithm. I chose logistic regression as my final model, but in hindsight I would have chosen random forest, because it has a significantly lower standard deviation, meaning that its performance is more consistent across different test periods.
[Figure: results for each algorithm]
  • Model Interpretation and Feature Importance:
    • Global feature importance: Global feature importance was calculated in three ways: permutation importance, coefficient magnitudes, and SHAP global importance. In all three, the lag feature of the target variable was consistently ranked as the most important feature, by a wide margin, meaning that the model relied most heavily on whether the athlete was injured on the previous event day to predict whether the athlete is injured on a given day. This makes sense, as many of the class 1 (injury) points in the dataset occur on consecutive event days. The one surprising feature was ‘perceived recovery’, which has high importance but a positive coefficient: the model is more likely to predict injury when the athlete’s own perception of recovery is high. This suggests that coaches should perhaps not rely on an athlete’s own judgement of recovery as an indicator of whether they will be injured.

      [Figures: global feature importance plots]
    • Local feature importance: I examined 4 different datapoints (indexed 0, 1, 8 and 14) in the test set for local feature importance using SHAP values, two of which are displayed below.

      • Observation 0 is a true positive, and an injury was present on the previous event day, so, unsurprisingly, the lag feature of the target variable is what pushes this datapoint toward class 1. [Figure: SHAP values for observation 0]
      • Observation 1 is a false positive, and the prediction for this point was also driven largely by the lag of the target variable. This observation suggests that the model relies too heavily on the lag feature, which leads it to misclassify points into class 1 whenever they are preceded by an event day with an injury. [Figure: SHAP values for observation 1]
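
The athlete-selection step from the EDA bullet can be sketched with pandas as follows. The column names `Athlete ID` and `injury`, the helper name, and the toy data are assumptions for illustration, not the project's actual code.

```python
import pandas as pd


def pick_most_injured_athlete(df: pd.DataFrame,
                              athlete_col: str = "Athlete ID",
                              target_col: str = "injury"):
    """Return the athlete with the highest fraction of injured events.

    Each row is one event day and target_col is 1 for injury, 0 otherwise,
    so the per-athlete mean of target_col is that athlete's injury fraction.
    """
    injury_fraction = df.groupby(athlete_col)[target_col].mean()
    return injury_fraction.idxmax()


# Toy data for illustration only (hypothetical column names).
events = pd.DataFrame({
    "Athlete ID": [0, 0, 0, 1, 1, 1, 2, 2],
    "injury":     [0, 0, 1, 1, 1, 0, 0, 0],
})
athlete = pick_most_injured_athlete(events)  # athlete 1: injured on 2 of 3 events
one_athlete = events[events["Athlete ID"] == athlete]
```

Restricting the data to a single athlete this way keeps every row drawn from one athlete's training distribution, which is the motivation given in the EDA bullet.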
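
The feature-engineering step can be sketched as below. It assumes a hypothetical wide layout in which the day-0 value of a feature `f` is stored in column `f` and earlier days in `f.1` through `f.6`, plus `injury` and `Date` columns; these names are illustrative assumptions. It also assumes the rows belong to a single athlete, so shifting the date-sorted target by one row gives the previous event day.

```python
import pandas as pd


def add_window_features(df: pd.DataFrame, base_features, n_days: int = 7) -> pd.DataFrame:
    """Add the mean and max over the last n_days for each base feature.

    Hypothetical layout: day 0 of feature f is in column f, day k
    (k = 1..n_days-1) is in column f"{f}.{k}".
    """
    out = df.copy()
    for f in base_features:
        day_cols = [f] + [f"{f}.{k}" for k in range(1, n_days)]
        out[f"{f}_mean{n_days}"] = out[day_cols].mean(axis=1)
        out[f"{f}_max{n_days}"] = out[day_cols].max(axis=1)
        # Minimums are skipped: they are 0 for every feature in this dataset.
    return out


def add_target_lag(df: pd.DataFrame, target_col: str = "injury",
                   date_col: str = "Date") -> pd.DataFrame:
    """Add a column saying whether the athlete was injured on the previous event day."""
    out = df.sort_values(date_col).copy()
    # The first event has no previous day; treating it as "not injured" is an assumption.
    out[f"{target_col}_lag1"] = out[target_col].shift(1).fillna(0).astype(int)
    return out
```

Because the lag is computed after sorting by date, the value attached to each event comes only from the event before it, so no future information enters the features.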
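
The report describes the exact splitting scheme; the sketch below shows one hypothetical way to produce 6 chronological train-validation-test folds, using an expanding window so that every validation and test point falls after the training data. The chunking rule (n_folds + 2 equal, contiguous chunks) and the function name are assumptions. scikit-learn's `TimeSeriesSplit` implements a similar expanding window but yields only train/test pairs, without a separate validation set.

```python
import numpy as np


def chronological_splits(n_samples: int, n_folds: int = 6):
    """Yield (train_idx, val_idx, test_idx) for an expanding-window scheme.

    Rows are assumed to be sorted by date. They are cut into n_folds + 2
    contiguous chunks; fold i trains on chunks 0..i, validates on chunk
    i + 1 and tests on chunk i + 2.
    """
    chunks = np.array_split(np.arange(n_samples), n_folds + 2)
    for i in range(n_folds):
        train_idx = np.concatenate(chunks[: i + 1])
        yield train_idx, chunks[i + 1], chunks[i + 2]
```

With 80 rows, fold 0 trains on rows 0-9, validates on 10-19 and tests on 20-29, while fold 5 trains on rows 0-59, validates on 60-69 and tests on 70-79.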
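
The preprocessing and training-pipeline bullets fit into one loop: for each fold, fit `StandardScaler` on the training rows only, pick the hyperparameter with the best validation F_2 score, and report that model's test F_2 score. This sketch uses logistic regression with a grid over the regularization strength `C`; the grid values and function name are hypothetical, and the other three algorithms would follow the same pattern with their own hyperparameters. `splits` is any sequence of (train, validation, test) index arrays in chronological order.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import fbeta_score
from sklearn.preprocessing import StandardScaler


def evaluate_logreg(X, y, splits, C_values=(0.01, 0.1, 1.0, 10.0)):
    """Return one test F_2 score per (train, val, test) split."""
    test_scores = []
    for train_idx, val_idx, test_idx in splits:
        # Fit the scaler on the training rows only, so statistics from
        # later periods do not leak into preprocessing.
        scaler = StandardScaler().fit(X[train_idx])
        X_train = scaler.transform(X[train_idx])
        X_val = scaler.transform(X[val_idx])
        X_test = scaler.transform(X[test_idx])

        # Choose C by validation F_2 score.
        best_score, best_model = -1.0, None
        for C in C_values:
            model = LogisticRegression(C=C, max_iter=1000).fit(X_train, y[train_idx])
            score = fbeta_score(y[val_idx], model.predict(X_val), beta=2, zero_division=0)
            if score > best_score:
                best_score, best_model = score, model

        # Score the selected model once on the untouched test period.
        test_scores.append(
            fbeta_score(y[test_idx], best_model.predict(X_test), beta=2, zero_division=0)
        )
    return test_scores
```

The mean and standard deviation of the returned list, `np.mean(scores)` and `np.std(scores)`, summarize how well and how consistently an algorithm performs across test periods.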
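
Two of the three global importance measures can be computed with scikit-learn alone, as in the sketch below: permutation importance scored with F_2, and the fitted coefficients, which are comparable across features because every feature was standardized. The helper name and the synthetic demo data are hypothetical; SHAP global importance comes from the separate `shap` package and is not shown.

```python
import numpy as np
import pandas as pd
from sklearn.inspection import permutation_importance
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import fbeta_score, make_scorer


def global_importance(model, X_test, y_test, feature_names, n_repeats=10, seed=0):
    """Table of coefficients and F_2-based permutation importances, most important first."""
    f2_scorer = make_scorer(fbeta_score, beta=2, zero_division=0)
    perm = permutation_importance(model, X_test, y_test, scoring=f2_scorer,
                                  n_repeats=n_repeats, random_state=seed)
    coef = model.coef_[0]
    table = pd.DataFrame({
        "coef": coef,                       # sign: direction of the effect on injury
        "abs_coef": np.abs(coef),           # size: comparable after standardization
        "perm_mean": perm.importances_mean, # mean drop in F_2 when the column is shuffled
        "perm_std": perm.importances_std,
    }, index=feature_names)
    return table.sort_values("perm_mean", ascending=False)


# Synthetic illustration: only the first feature drives the label.
rng = np.random.default_rng(0)
X_demo = rng.normal(size=(300, 2))
y_demo = (X_demo[:, 0] + 0.1 * rng.normal(size=300) > 0).astype(int)
demo_model = LogisticRegression().fit(X_demo, y_demo)
importance_table = global_importance(demo_model, X_demo, y_demo, ["signal", "noise"])
print(importance_table)
```

A positive `coef` pushes predictions toward class 1 (injury), which is how a feature like ‘perceived recovery’ can rank high in importance while its sign runs against intuition.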
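
For a linear model such as the final logistic regression, SHAP values have a closed form when features are treated as independent: in log-odds units, feature j contributes w_j * (x_j - mean(x_j)), and these contributions plus the average logit add up to the model's logit for that point. The sketch below computes this directly with NumPy as an illustration; it is not necessarily how the project computed its SHAP values (the `shap` package's `LinearExplainer` covers this case), and the helper name and demo data are hypothetical.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def linear_shap_values(model, X_background, x):
    """SHAP values (log-odds units) of a fitted binary linear classifier for one row x.

    Assumes independent features, with expectations taken over X_background.
    """
    w = model.coef_[0]
    background_mean = X_background.mean(axis=0)
    phi = w * (x - background_mean)                           # per-feature contributions
    expected_logit = model.intercept_[0] + w @ background_mean  # base value
    return phi, expected_logit


# Synthetic illustration only.
rng = np.random.default_rng(1)
X_demo = rng.normal(size=(200, 3))
y_demo = (X_demo[:, 0] - X_demo[:, 2] > 0).astype(int)
demo_model = LogisticRegression().fit(X_demo, y_demo)

phi, expected_logit = linear_shap_values(demo_model, X_demo, X_demo[0])
# Local accuracy: base value plus contributions recovers the model's logit.
print(expected_logit + phi.sum(), demo_model.decision_function(X_demo[:1])[0])
ranking = np.argsort(-np.abs(phi))  # feature indices, largest contribution first
```

Sorting features by `np.abs(phi)` gives the kind of per-observation ranking discussed for observations 0 and 1, where a single dominant contribution (the lag feature) can drive a prediction on its own.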

Python and Package Versions:

Python version 3.10.5
numpy version 1.22.4
matplotlib version 3.5.2
sklearn version 1.1.1
pandas version 1.4.2
xgboost version 1.5.1
shap version 0.40.0

Please refer to the yaml file

License:

Please refer to the license file

Contributors:

selinawaang, wangselina