Simple CNN

Simple CNN is a pipeline for training CNN models with PyTorch and running inference on them with ONNX. It's simple and easy to use! 🔥🔥


Install

  • Clone the repo and install the packages listed in requirements.txt in a Python environment:
    git clone https://github.com/LahiRumesh/simple_cnn.git
    cd simple_cnn
    pip install -r requirements.txt

Data Preparation

  • Split the images into train and val folders inside the Image Folder 📂, with one subfolder per class in each. For example, for cat vs. dog classification there should be a cat folder and a dog folder in both train and val. The following folder structure illustrates 3 classes:
    Image_Folder
    ├── train
    │   ├── class1
    │   │   ├── class1.0.jpg
    │   │   ├── class1.1.jpg
    │   │   ├── class1.2.jpg
    │   │   ├── .........
    │   │   └── class1.500.jpg
    │   ├── class2
    │   │   ├── class2.0.jpg
    │   │   ├── class2.1.jpg
    │   │   ├── class2.2.jpg
    │   │   ├── .........
    │   │   └── class2.500.jpg
    │   └── class3
    │       ├── class3.0.jpg
    │       ├── class3.1.jpg
    │       ├── class3.2.jpg
    │       ├── .........
    │       └── class3.500.jpg
    └── val
        ├── class1
        │   ├── class1.501.jpg
        │   ├── class1.502.jpg
        │   ├── class1.503.jpg
        │   ├── .........
        │   └── class1.600.jpg
        ├── class2
        │   ├── class2.501.jpg
        │   ├── class2.502.jpg
        │   ├── class2.503.jpg
        │   ├── .........
        │   └── class2.600.jpg
        └── class3
            ├── class3.501.jpg
            ├── class3.502.jpg
            ├── class3.503.jpg
            ├── .........
            └── class3.600.jpg
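If your images are not split yet, the train/val layout above can be produced with a short script. The following is a minimal sketch using only the Python standard library; it is not a script that ships with Simple CNN. The helper name `split_dataset`, the assumed source layout (one subfolder per class holding all of that class's images), the accepted image extensions, and the `val_ratio`/`seed` parameters are all illustrative assumptions.

```python
import random
import shutil
from pathlib import Path

# Assumption: these are the image extensions worth copying; adjust as needed.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def split_dataset(src_dir, dst_dir, val_ratio=0.2, seed=0):
    """Copy src_dir/<class>/* into dst_dir/train/<class>/ and dst_dir/val/<class>/.

    Hypothetical helper. Returns {class_name: (n_train, n_val)}.
    """
    src, dst = Path(src_dir), Path(dst_dir)
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    counts = {}
    for class_dir in sorted(p for p in src.iterdir() if p.is_dir()):
        images = sorted(p for p in class_dir.iterdir() if p.suffix.lower() in IMAGE_EXTS)
        rng.shuffle(images)
        n_val = round(len(images) * val_ratio)  # number of images held out for val
        splits = {"val": images[:n_val], "train": images[n_val:]}
        for split, files in splits.items():
            out = dst / split / class_dir.name
            out.mkdir(parents=True, exist_ok=True)
            for f in files:
                shutil.copy2(f, out / f.name)  # copy, so the source folder stays intact
        counts[class_dir.name] = (len(splits["train"]), n_val)
    return counts
```

For example, `split_dataset("Data/Images/raw", "Data/Images/Image_Folder", val_ratio=0.2)` (where `Data/Images/raw` is a hypothetical source folder) would fill the `train` and `val` folders of the Image Folder. Shuffling per class keeps the class proportions roughly equal in both splits.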

Training

After data preparation, it's time to train!

  • Use config.py to set the training parameters. Here are a few of them:
   cfg.data_dir = 'Data/Images/Image_Folder' # path to the Image Folder that contains the train and val folders
   cfg.device = '0' # CUDA device(s), e.g. '0' or '0,1,2,3'
   cfg.image_size = 224 # input image size
   cfg.batch_size = 8 # batch size
   cfg.epochs = 50 # number of epochs

   cfg.model = 'resnet18' # torchvision classification architecture,
                          # e.g. resnet18, vgg16, alexnet, densenet121, squeezenet1_0

   cfg.pretrained = True  # use pretrained weights for training

   # Early stopping
   cfg.use_early_stopping = True # enable early stopping
   cfg.patience = 8 # number of epochs to wait for an accuracy improvement before stopping
   cfg.min_delta = 0.0001 # minimum increase over the best accuracy so far that counts as an improvement
  • The following pretrained models are available in Simple CNN:

    Architecture   Available models
    ------------   ----------------
    ResNet         resnet18, resnet34, resnet50, resnet101, resnet152
    VGG            vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn
    DenseNet       densenet121, densenet169, densenet161, densenet201
    SqueezeNet     squeezenet1_0, squeezenet1_1
    AlexNet        alexnet
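The early-stopping settings in config.py (`use_early_stopping`, `patience`, `min_delta`) can be read as the logic below. This is a minimal sketch assuming validation accuracy is checked once per epoch and higher is better; the `EarlyStopping` class is hypothetical and not necessarily how Simple CNN implements it.

```python
class EarlyStopping:
    """Signal a stop when validation accuracy stops improving (hypothetical sketch).

    An epoch counts as an improvement only if its accuracy exceeds the best value
    seen so far by more than min_delta. After `patience` consecutive epochs without
    an improvement, step() returns True.
    """

    def __init__(self, patience=8, min_delta=0.0001):
        self.patience = patience
        self.min_delta = min_delta
        self.best_acc = None
        self.counter = 0  # consecutive epochs without an improvement

    def step(self, val_acc):
        """Record this epoch's validation accuracy; return True if training should stop."""
        if self.best_acc is None or val_acc - self.best_acc > self.min_delta:
            self.best_acc = val_acc
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


if __name__ == "__main__":
    stopper = EarlyStopping(patience=2, min_delta=0.01)
    for epoch, acc in enumerate([0.70, 0.75, 0.755, 0.752, 0.80]):
        if stopper.step(acc):
            print(f"stopping after epoch {epoch}")  # prints "stopping after epoch 3"
            break
```

In the demo, 0.755 and 0.752 beat the best accuracy (0.75) by less than `min_delta`, so after two such epochs the `patience` of 2 is used up and training stops before the 0.80 epoch is ever reached.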

Run cnn_train.py to start training. All logs are saved to wandb, and for each training experiment an ONNX weight file, named after the model, is saved in the "models/Image_Folder" folder.
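Since each experiment writes a new ONNX file, it can help to locate the newest one before passing it to `--model_path`. This sketch assumes the file names follow the pattern seen in the inference command below (`<dataset>_<model>_exp_<n>.onnx`, e.g. `ImageFolder_resnet18_exp_1.onnx`); that pattern is inferred from a single example, and `latest_experiment` is a hypothetical helper, not part of Simple CNN.

```python
import re
from pathlib import Path


def latest_experiment(models_dir, model_name):
    """Return the .onnx file with the highest experiment number for model_name, or None.

    Assumes names like <dataset>_<model_name>_exp_<n>.onnx.
    """
    # Anchor on "_<model>_exp_" so that e.g. "vgg16" does not match "vgg16_bn" files.
    pattern = re.compile(rf"_{re.escape(model_name)}_exp_(\d+)\.onnx$")
    best, best_n = None, -1
    for path in Path(models_dir).glob("*.onnx"):
        m = pattern.search(path.name)
        if m and int(m.group(1)) > best_n:  # compare numerically, so exp_10 > exp_2
            best, best_n = path, int(m.group(1))
    return best
```

For example, `latest_experiment("models/ImageFolder", "resnet18")` would return the highest-numbered resnet18 export in that folder, or `None` if there is none.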


Inference

  • After training, run inference on an image with the exported ONNX model using cnn_inference.py:
python cnn_inference.py --model_path=models/ImageFolder/ImageFolder_resnet18_exp_1.onnx --class_path=models/ImageFolder/classes.txt --img_path=test1.jpg --image_size=224 --use_transform=True
Args:
   --model_path : ONNX model path
   --class_path : path to the class file (classes.txt) containing the class names
   --img_path : input image path
   --image_size : input image size
   --show_image : display the image
   --use_transform : apply image transforms in the preprocessing step (during training, images are normalized with a mean and standard deviation)
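The preprocessing and prediction steps behind `--use_transform` and `--class_path` can be sketched with NumPy and ONNX Runtime. Several assumptions here are unverified against cnn_inference.py: the image is already resized to `--image_size` and in RGB order, the normalization uses the standard torchvision ImageNet mean and standard deviation, classes.txt holds one class name per line in the model's output order, and the model has a single image input. All function names are hypothetical.

```python
import numpy as np

# Assumption: the standard ImageNet statistics used with torchvision pretrained models.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def preprocess(image_hwc, use_transform=True):
    """Turn an already-resized HxWx3 uint8 RGB array into a 1x3xHxW float32 batch."""
    x = image_hwc.astype(np.float32) / 255.0  # scale pixel values to [0, 1]
    if use_transform:
        x = (x - MEAN) / STD  # per-channel normalization (broadcasts over H and W)
    # HWC -> CHW, add a batch dimension, and make the buffer contiguous for ONNX Runtime.
    return np.ascontiguousarray(np.transpose(x, (2, 0, 1))[np.newaxis, ...])


def load_classes(class_path):
    """Read class names, assuming one name per line (blank lines are skipped)."""
    with open(class_path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def softmax(logits):
    z = logits - logits.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def top_class(logits, classes):
    """Return (class_name, probability) for the highest-scoring class of one image."""
    probs = softmax(np.asarray(logits, dtype=np.float32).reshape(-1))
    idx = int(np.argmax(probs))
    return classes[idx], float(probs[idx])


def run_onnx(model_path, batch):
    """Run the model with ONNX Runtime and return its first output (raw logits assumed)."""
    import onnxruntime as ort  # imported here so the helpers above work without it

    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name
    return session.run(None, {input_name: batch})[0]
```

With Pillow installed, an image could be loaded with `np.asarray(Image.open("test1.jpg").convert("RGB").resize((224, 224)))`, passed through `preprocess` and `run_onnx`, and the first row of the result handed to `top_class` together with `load_classes("models/ImageFolder/classes.txt")`. If the exported model already ends in a softmax layer, the extra softmax only preserves the argmax, so the predicted class is unchanged.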

Calculate Test Accuracy

  • Use test_accuracy.py to calculate the accuracy of the ONNX model on the test data:
python test_accuracy.py --model_path=models/ImageFolder/ImageFolder_resnet18_exp_1.onnx --class_path=models/ImageFolder/classes.txt --img_dir=Image_Folder/test --image_size=224 --use_transform=True

The test image folder uses the same one-subfolder-per-class layout; the following structure illustrates 3 classes:

    Image_Folder
    └── test
        ├── class1
        │   ├── class1.0.jpg
        │   ├── class1.1.jpg
        │   ├── class1.2.jpg
        │   ├── .........
        │   └── class1.500.jpg
        ├── class2
        │   ├── class2.0.jpg
        │   ├── class2.1.jpg
        │   ├── class2.2.jpg
        │   ├── .........
        │   └── class2.500.jpg
        └── class3
            ├── class3.0.jpg
            ├── class3.1.jpg
            ├── class3.2.jpg
            ├── .........
            └── class3.500.jpg

For each test experiment, all test results are saved in the "test_results" folder.

Args:
   --model_path : ONNX model path
   --class_path : path to the class file (classes.txt) containing the class names
   --img_dir : test images folder path
   --image_size : input image size
   --use_transform : apply image transforms in the preprocessing step (during training, images are normalized with a mean and standard deviation)
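Because each test image's true label is the name of the class folder it sits in, accuracy reduces to comparing that folder name with the predicted class. The following is a minimal standard-library sketch; `iter_test_images` and `accuracy_report` are hypothetical helpers, and reporting overall plus per-class top-1 accuracy is an assumption about what is useful, not a description of what test_accuracy.py writes to test_results.

```python
from collections import Counter
from pathlib import Path


def iter_test_images(img_dir):
    """Yield (image_path, true_class) for every file in img_dir/<class>/."""
    for class_dir in sorted(p for p in Path(img_dir).iterdir() if p.is_dir()):
        for img in sorted(class_dir.iterdir()):
            if img.is_file():
                yield img, class_dir.name  # the folder name is the ground-truth label


def accuracy_report(samples):
    """Compute top-1 accuracy from (true_class, predicted_class) pairs.

    Returns (overall_accuracy, {class_name: accuracy}).
    """
    total, correct = Counter(), Counter()
    for true_cls, pred_cls in samples:
        total[true_cls] += 1
        if pred_cls == true_cls:
            correct[true_cls] += 1
    n = sum(total.values())
    overall = sum(correct.values()) / n if n else 0.0
    per_class = {c: correct[c] / total[c] for c in sorted(total)}
    return overall, per_class
```

Given some `predict(path)` function that returns a class name for an image (hypothetical here, e.g. built on the exported ONNX model), `accuracy_report((cls, predict(p)) for p, cls in iter_test_images("Image_Folder/test"))` returns the overall and per-class accuracy. Per-class numbers help spot a class the model consistently confuses, which the overall figure can hide.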

Contributors: lahirumesh