
German Traffic Sign Recognition Challenge using ResNets

This repository contains a simple, lightweight, high-accuracy model for the German Traffic Sign Recognition Benchmark (GTSRB) dataset. The model was designed and trained for the Kaggle competition of NYU's Fall 2018 Computer Vision course. All training was done on GPUs in NYU's Prince cluster.

The baseline code for training and producing predictions was obtained here and modified in this repository.

The residual network implemented in model48.py achieved 99.02% accuracy on the GTSRB test set as a single model. Although this is not the best score reported on this dataset, the model can be trained in under 10 minutes on a single GPU. After its first epoch, it already reaches about 93% accuracy on the validation set generated by data48.py.
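The defining feature of a residual network is the identity shortcut: each block learns a residual function F(x) and outputs F(x) + x, so the block only has to learn a correction to its input. The sketch below illustrates only that idea, using plain Python lists in place of tensors; `residual_block` and `double` are hypothetical helpers and are not the layers used in model48.py.

```python
from typing import Callable, List

Vector = List[float]


def residual_block(x: Vector, f: Callable[[Vector], Vector]) -> Vector:
    """Return f(x) + x element-wise (the identity shortcut).

    Real ResNet blocks apply convolutions, batch normalization and ReLUs;
    here f is any function that preserves the length of x.
    """
    fx = f(x)
    if len(fx) != len(x):
        raise ValueError("f must preserve the shape of x for an identity shortcut")
    return [a + b for a, b in zip(fx, x)]


def double(v: Vector) -> Vector:
    # Hypothetical stand-in for the learned residual function F.
    return [2.0 * a for a in v]


print(residual_block([1.0, -2.0, 0.5], double))  # [3.0, -6.0, 1.5]
```

If F learns to output zeros, the block reduces to the identity, which is a common explanation for why deep residual networks are easier to optimize than equally deep plain networks.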

(Note: make sure PyTorch is properly installed before training.)
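The script names (main48_cuda.py, evaluate48_cuda.py) suggest they expect a CUDA-capable GPU. A quick check like the following, a hypothetical helper rather than part of this repository, confirms that PyTorch can be imported and reports whether CUDA is available:

```python
import importlib.util


def check_pytorch() -> bool:
    """Return True if PyTorch is importable and print CUDA availability."""
    if importlib.util.find_spec("torch") is None:
        print("PyTorch is not installed")
        return False
    import torch  # imported lazily so the check itself works without PyTorch

    print("PyTorch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
    return True


if __name__ == "__main__":
    check_pytorch()
```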

Training the model:

  1. First download the dataset from GTSRB
  2. Now run the following command:
python main48_cuda.py --data='<folder-with-data-zips>' --epochs=20 --batch-size=64 --lr=0.01 --wd=0.8 --momentum=0.9

(The code will unzip the dataset for you and create train/validation/test folders)
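By their names, the `--lr`, `--momentum` and `--wd` flags are the learning rate, momentum and weight decay of a stochastic gradient descent optimizer. Assuming main48_cuda.py passes them to `torch.optim.SGD` (not confirmed here), each parameter `w` with gradient `g` is updated as sketched below for a single scalar; `sgd_step` is a hypothetical helper written for illustration.

```python
def sgd_step(w, g, buf, lr=0.01, momentum=0.9, wd=0.8):
    """One SGD step with momentum and L2 weight decay on a scalar parameter.

    Follows the PyTorch SGD update with dampening=0 and nesterov=False.
    buf is the momentum buffer; pass None on the first step.
    Returns the updated (w, buf).
    """
    d = g + wd * w  # gradient plus the weight-decay term
    buf = d if buf is None else momentum * buf + d
    return w - lr * buf, buf


w, buf = 1.0, None
for step in range(3):
    w, buf = sgd_step(w, g=0.5, buf=buf)
    print(f"step {step}: w = {w:.6f}")
```

With `--wd=0.8`, the weight-decay term adds 0.8·w to every gradient, which is large compared with the values around 1e-4 that are commonly used with SGD, so it is worth checking how the script actually uses this flag.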

  3. After each training epoch, the code saves a checkpoint of the model as model_##.pth and evaluates it on the validation set, reporting the validation accuracy. It also saves the per-epoch training and validation losses in losses.p. To visualize them, run the following code:
import pickle
import matplotlib.pyplot as plt

# Each entry in losses.p is (epoch, training loss, validation loss)
with open('losses.p', 'rb') as f:
    losses = pickle.load(f)

epochs = [e[0] for e in losses]
training_loss = [e[1] for e in losses]
val_loss = [e[2] for e in losses]

plt.plot(epochs, training_loss, label="Training loss")
plt.plot(epochs, val_loss, label="Validation loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

This produces a plot like the following:

(Figure: ResNet training and validation loss)
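The plotting code above implies that losses.p holds a list of (epoch, training_loss, validation_loss) tuples. Under that assumption, one way to choose a checkpoint for evaluation is to take the epoch with the lowest validation loss. The helper below is hypothetical, and it also assumes that checkpoints are named model_<epoch>.pth without zero padding, as in model_16.pth.

```python
import pickle


def best_checkpoint(losses_path="losses.p"):
    """Return (epoch, val_loss, checkpoint_name) for the lowest validation loss.

    Assumes the pickle holds (epoch, training_loss, validation_loss) tuples
    and that checkpoints are named model_<epoch>.pth (assumed naming).
    """
    with open(losses_path, "rb") as f:
        losses = pickle.load(f)
    epoch, _, val_loss = min(losses, key=lambda e: e[2])
    return epoch, val_loss, f"model_{epoch}.pth"


if __name__ == "__main__":
    epoch, val_loss, name = best_checkpoint()
    print(f"Lowest validation loss {val_loss:.4f} at epoch {epoch}: {name}")
```

The lowest validation loss does not always coincide with the highest validation accuracy, so it is worth comparing this choice with the accuracies printed during training.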

Generating Predictions:

  1. After selecting the saved checkpoint (e.g. model_16.pth) that you want to evaluate on the test set, run the following command:
python evaluate48_cuda.py --data='<folder-with-data-zips>' --model='<chosen-model.pth>'
  2. This produces the file gtsrb_kaggle.csv, which contains your model's predictions.
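Before submitting, it can help to sanity-check gtsrb_kaggle.csv. GTSRB has 43 sign classes (IDs 0 to 42), so every predicted class ID should fall in that range. The sketch below assumes the file has a header row with a `ClassId` column; that column name is an assumption, so adjust it to match the header the evaluation script actually writes.

```python
import csv
from collections import Counter

NUM_CLASSES = 43  # GTSRB has 43 traffic-sign classes


def check_predictions(path="gtsrb_kaggle.csv", column="ClassId"):
    """Count predictions per class and verify every class ID is in range.

    `column` is an assumed header name; change it if the CSV differs.
    """
    counts = Counter()
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            class_id = int(row[column])
            if not 0 <= class_id < NUM_CLASSES:
                raise ValueError(f"class ID {class_id} out of range")
            counts[class_id] += 1
    return counts


if __name__ == "__main__":
    counts = check_predictions()
    print(f"{sum(counts.values())} predictions across {len(counts)} classes")
```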

Contributors: mmoraes-rafael

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar

Forkers

pharath

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.