pytorch_lite

  • A Flutter package for running PyTorch Lite models, supporting image classification as well as YOLOv5 and YOLOv8 object detection.

Example for classification

(screenshot)

Example for object detection

(screenshot)

Usage

Preparing the model

  • Classification
import torch
from torch.utils.mobile_optimizer import optimize_for_mobile

# Load the trained model on the CPU and switch it to inference mode
model = torch.load('model_scripted.pt', map_location="cpu")
model.eval()

# Trace the model with a dummy input of the expected size (1, 3, 224, 224)
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)

# Optimize for mobile and save in the lite-interpreter format
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("model.pt")
  • Object detection (YOLOv5)
!python export.py --weights "the weights of your model" --include torchscript --img 640 --optimize

Example

!python export.py --weights yolov5s.pt --include torchscript --img 640 --optimize
  • Object detection (YOLOv8)
!yolo mode=export model="your model" format=torchscript optimize

Example

!yolo mode=export model=yolov8s.pt format=torchscript optimize

Installation

To use this plugin, add pytorch_lite as a dependency in your pubspec.yaml file (for example by running flutter pub add pytorch_lite).

Create an assets folder with your PyTorch model and labels (if needed), and list them in pubspec.yaml:

assets:
  - assets/models/model_classification.pt
  - assets/labels_classification.txt
  - assets/models/model_objectDetection.torchscript
  - assets/labels_objectDetection.txt

Run flutter pub get

For release

  • Go to android/app/build.gradle
  • Add the following lines to the release config:
shrinkResources false
minifyEnabled false

Example

    buildTypes {
        release {
            shrinkResources false
            minifyEnabled false
            // TODO: Add your own signing config for the release build.
            // Signing with the debug keys for now, so `flutter run --release` works.
            signingConfig signingConfigs.debug
        }
    }

Import the library

import 'package:pytorch_lite/pytorch_lite.dart';

Load model

Either a classification model:

ClassificationModel classificationModel= await PytorchLite.loadClassificationModel(
          "assets/models/model_classification.pt", 224, 224,
          labelPath: "assets/labels/label_classification_imageNet.txt");

Or an object detection model:

ModelObjectDetection objectModel = await PytorchLite.loadObjectDetectionModel(
          "assets/models/yolov5s.torchscript", 80, 640, 640,
          labelPath: "assets/labels/labels_objectDetection_Coco.txt",
          objectDetectionModelType: ObjectDetectionModelType.yolov5);
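
The later snippets assume the loaded models are kept in fields such as _imageModel and _objectModel. A minimal sketch of loading both once, reusing the calls above (the loadModels helper name is only illustrative):

ClassificationModel? _imageModel;
ModelObjectDetection? _objectModel;

Future<void> loadModels() async {
  // Use the asset paths you declared in pubspec.yaml.
  _imageModel = await PytorchLite.loadClassificationModel(
      "assets/models/model_classification.pt", 224, 224,
      labelPath: "assets/labels/label_classification_imageNet.txt");
  _objectModel = await PytorchLite.loadObjectDetectionModel(
      "assets/models/yolov5s.torchscript", 80, 640, 640,
      labelPath: "assets/labels/labels_objectDetection_Coco.txt",
      objectDetectionModelType: ObjectDetectionModelType.yolov5);
}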

Get classification prediction as label

String imagePrediction = await classificationModel.getImagePrediction(await File(image.path).readAsBytes());
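
A minimal end-to-end sketch, assuming the image_picker package (not part of pytorch_lite) supplies the image:

import 'dart:io';
import 'package:image_picker/image_picker.dart';

// Pick an image from the gallery and run classification on its bytes.
final XFile? image = await ImagePicker().pickImage(source: ImageSource.gallery);
if (image != null) {
  String imagePrediction = await classificationModel
      .getImagePrediction(await File(image.path).readAsBytes());
  print(imagePrediction);
}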

Get classification prediction as label from camera image

String imagePrediction = await _imageModel!.getCameraImagePrediction(
        cameraImage,
        rotation, // check example for rotation values
        );
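
A sketch of how this fits into an image stream from the camera package; passing the sensor orientation as the rotation is a simplification (the package's example also accounts for the current device orientation):

// Assumes an initialized CameraController from the camera package.
await cameraController.startImageStream((CameraImage cameraImage) async {
  final String prediction = await _imageModel!.getCameraImagePrediction(
    cameraImage,
    cameraController.description.sensorOrientation, // simplified rotation value
  );
  print(prediction);
});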

Get classification prediction as raw output layer

List<double>? predictionList = await _imageModel!.getImagePredictionList(
      await File(image.path).readAsBytes(),
    );

Get classification prediction as raw output layer from camera image

List<double>? predictionList = await _imageModel!.getCameraImagePredictionList(
        cameraImage,
        rotation, // check example for rotation values
    );

Get classification prediction as probabilities (in case the model is not using softmax)

List<double>? predictionListProbabilities = await _imageModel!.getImagePredictionListProbabilities(
      await File(image.path).readAsBytes(),
    );
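
To make the raw probabilities readable, they can be paired with the bundled labels file; a sketch assuming the asset contains one label per line:

import 'package:flutter/services.dart' show rootBundle;

// Load the labels that were bundled as an asset (one label per line).
final labels = (await rootBundle.loadString('assets/labels_classification.txt'))
    .split('\n')
    .where((line) => line.trim().isNotEmpty)
    .toList();

// Pair every probability with its label and sort by descending probability.
final probs = predictionListProbabilities!;
final ranked = [
  for (var i = 0; i < probs.length; i++) MapEntry(labels[i], probs[i])
]..sort((a, b) => b.value.compareTo(a.value));

// Print the top-5 predictions as "label: probability".
print(ranked.take(5).map((e) => '${e.key}: ${e.value.toStringAsFixed(3)}').join(', '));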

Get classification prediction as probabilities from camera image (in case the model is not using softmax)

List<double>? predictionListProbabilities = await _imageModel!.getCameraPredictionListProbabilities(
        cameraImage,
        rotation, // check example for rotation values
    );

Get object detection prediction for an image

 List<ResultObjectDetection> objDetect = await _objectModel.getImagePrediction(await File(image.path).readAsBytes(),
        minimumScore: 0.1, IOUThershold: 0.3);
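
Each ResultObjectDetection can then be inspected; a minimal sketch, assuming the result exposes className, score and a rect with coordinates normalized to the image size:

for (final ResultObjectDetection detection in objDetect) {
  // score is the confidence; rect values are normalized to [0, 1].
  print('${detection.className} '
      '(${(detection.score * 100).toStringAsFixed(1)}%) at '
      'left=${detection.rect.left}, top=${detection.rect.top}, '
      'width=${detection.rect.width}, height=${detection.rect.height}');
}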

Get object detection prediction from camera image

 List<ResultObjectDetection> objDetect = await _objectModel.getCameraImagePrediction(
        cameraImage,
        rotation, // check example for rotation values
        minimumScore: 0.1, IOUThershold: 0.3);

Render detection boxes on the image

objectModel.renderBoxesOnImage(_image!, objDetect)
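
This yields a widget with the boxes drawn over the image, so it can be placed directly in a build method; a minimal sketch, assuming _image is the picked File and objDetect the detections from the previous step:

@override
Widget build(BuildContext context) {
  return Scaffold(
    body: Center(
      child: _image == null
          ? const Text('No image selected')
          // Overlays the detection boxes and labels on the image.
          : objectModel.renderBoxesOnImage(_image!, objDetect),
    ),
  );
}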

Classification prediction with custom mean and std

final mean = [0.5, 0.5, 0.5];
final std = [0.5, 0.5, 0.5];
String prediction = await classificationModel.getImagePrediction(
    await File(image.path).readAsBytes(),
    mean: mean, std: std);
