clip-distillation's Introduction

CLIP Knowledge Distillation

This repository contains code and instructions that enable you to create your own customized image classification models with zero labeled data, by performing knowledge distillation of OpenCLIP models.

Even if you don't need an image classifier directly, you may find this project helpful as inspiration for how you can use knowledge distillation to optimize models for inference, or as an example of how to train models with quantization aware training and structured sparsity for inference on NVIDIA Jetson.

This project includes:

  1. Scripts to search and download relevant data from the LAION database to use for distillation
  2. Scripts to distill any OpenCLIP model into any PyTorch Image Models (timm) CNN model
    • Supports Quantization Aware Training (QAT) for downstream INT8 inference
    • Supports training to enforce 2:4 structured sparsity with the ASP library
  3. Scripts to run inference with NVIDIA TensorRT
    • Supports INT8 models
    • Supports acceleration of 2:4 structured sparse models on certain NVIDIA Jetson platforms, like NVIDIA Jetson Orin Nano.

To get started, follow the instructions below.

If you're new to the subject, check out our tutorial jetson-intro-to-distillation for an introduction to knowledge distillation!

Instructions

  1. Step 1 - Search and download relevant unlabeled images to use for distillation
  2. Step 2 - Pre-compute OpenCLIP embeddings
  3. Step 3 - Train the student CNN model to mimic the OpenCLIP model
  4. Step 4 - Run inference using the distilled model
  5. Step 5 (advanced) - Train a student model with structured sparsity
  6. Step 6 (advanced) - Train a student with quantization aware training and INT8 precision
  7. Next Steps

Step 1 - Search and download images with CLIP filtering

Search for relevant image URLs in the LAION database using CLIP filtering

The first thing we need to do when distilling a model is obtain data to use for distillation.

For this task, we'll look for relevant images by searching the LAION database. We've provided a script to make this simple.

To search for relevant images, first create a file data/text_prompts.txt with the text prompts to query.

Each prompt should exist on its own line.

a dog
a cat

Next, call the script to query the images that match the text prompts.

python3 search_clip_images.py \
    "data/text_prompts.txt" \
    "data/image_urls.txt" \
    -n 5000 \
    -m 10000 \
    --max_workers 2 \
    --append

This will output a file data/image_urls.txt that contains the URLs of images matching our text prompt queries.

For the full set of arguments please type

python3 search_clip_images.py --help
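
Under the hood, the script queries a CLIP retrieval service over HTTP and keeps the image URLs it returns. The sketch below shows the general shape of such a query; the endpoint URL and payload field names are illustrative assumptions, not the script's actual values.

import requests

# Hypothetical endpoint for a clip-retrieval kNN service indexing LAION.
KNN_SERVICE_URL = "https://knn.laion.ai/knn-service"

def search_image_urls(prompt, num_results=40):
    # Payload field names are assumptions about the service's API.
    payload = {
        "text": prompt,            # text query, embedded with CLIP server-side
        "modality": "image",       # retrieve images rather than captions
        "num_images": num_results,
    }
    response = requests.post(KNN_SERVICE_URL, json=payload, timeout=10)
    response.raise_for_status()
    # Keep only results that actually include an image URL.
    return [item["url"] for item in response.json() if "url" in item]

print(search_image_urls("a dog", num_results=5))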

Download images from URL file

Now that we've found relevant images to use for distillation, we need to download them.

To do so, we call the following script to download images to an output folder.

python3 download_images.py \
    "data/image_urls.txt" \
    "data/images" \
    --max_workers 32 \
    --timeout 2

This script will download images to the folder data/images. Each image will be given a unique filename based on its URL.

For the full set of arguments please type

python3 download_images.py --help
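
As a rough illustration of what the downloader does for each URL, here is a minimal sketch. The sha256-based filename and the .jpg extension are assumptions for illustration; the actual script may derive names differently.

import hashlib
import os

import requests

def download_image(url, output_dir, timeout=2.0):
    os.makedirs(output_dir, exist_ok=True)
    # Hash the URL to get a unique, filesystem-safe filename (assumed scheme).
    name = hashlib.sha256(url.encode("utf-8")).hexdigest() + ".jpg"
    path = os.path.join(output_dir, name)
    if not os.path.exists(path):  # skip images we already downloaded
        data = requests.get(url, timeout=timeout).content
        with open(path, "wb") as f:
            f.write(data)
    return path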

Step 2 - Compute OpenCLIP embeddings

The images we downloaded above will be used as inputs to our teacher and student models during distillation. Unfortunately, it can be slow to execute the teacher during training.

To speed up this process, we'll pre-compute the outputs of our teacher model so we don't need to execute the teacher model during training.

To do this, call the compute_openclip_embeddings.py script as follows:

python3 compute_openclip_embeddings.py \
    data/images \
    data/embeddings \
    --batch_size 16 \
    --num_workers 8 \
    --model_name ViT-B-32 \
    --pretrained laion2b_s34b_b79k

This will write the output embeddings to the folder data/embeddings, with filenames that match the image filenames, except for the file extensions.

Note: For available model names and pretrained weight identifiers, please reference the OpenCLIP repo.

For the full set of arguments please type

python3 compute_openclip_embeddings.py --help
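
Conceptually, the script loads the OpenCLIP teacher once and runs every image through its image encoder, saving one embedding per image. A minimal single-image sketch of that idea (batching and data loading omitted; the .npy-per-image layout is an assumption consistent with the matching filenames described above):

import os

import numpy as np
import open_clip
import torch
from PIL import Image

# Load the teacher and its matching preprocessing transform.
model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
model.eval()

@torch.no_grad()
def embed_image(image_path, output_dir):
    image = preprocess(Image.open(image_path).convert("RGB")).unsqueeze(0)
    embedding = model.encode_image(image)  # shape (1, 512) for ViT-B-32
    name = os.path.splitext(os.path.basename(image_path))[0] + ".npy"
    np.save(os.path.join(output_dir, name), embedding.squeeze(0).numpy())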

Step 3 - Train the student CNN model to mimic the OpenCLIP model

Now that we have the data to use for knowledge distillation, we can perform the distillation (student model training) by calling the distil_model_embeddings.py script as follows.

python3 distil_model_embeddings.py \
    resnet18 \
    data/images \
    data/embeddings \
    data/models/resnet18 \
    --output_dim 512 \
    --pretrained

This will output model checkpoints and information to data/models/resnet18.

The distilled model we use in this example is resnet18. This model is highly optimized by TensorRT, and we can readily apply other optimizations like reduced precision and structured sparsity during training. Please see the additional steps below for more information.

For the full set of arguments please type

python3 distil_model_embeddings.py --help
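
At its core, distillation here is regression onto the precomputed teacher embeddings. The sketch below shows one plausible training step, assuming an MSE loss between student and teacher features; the script's actual loss and optimizer settings may differ.

import timm
import torch
import torch.nn.functional as F

# Student: a timm CNN whose classifier is resized to the teacher's
# embedding width (512 for ViT-B-32), matching --output_dim above.
student = timm.create_model("resnet18", pretrained=True, num_classes=512)
optimizer = torch.optim.Adam(student.parameters(), lr=3e-4)

def distillation_step(images, teacher_embeddings):
    # images: (N, 3, H, W); teacher_embeddings: (N, 512), loaded from Step 2.
    student_embeddings = student(images)
    loss = F.mse_loss(student_embeddings, teacher_embeddings)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()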

Step 4 - Run inference using the distilled model

Compute text embeddings

During distillation, we trained our student model to match the features of our OpenCLIP model. However, we're interested in creating a classification model.

To create the zero-shot classification model, we need to generate text embeddings from the text prompts that describe our class labels.

To do so, we use the pre-trained OpenCLIP text encoder.

We call the compute_openclip_text_embeddings.py script to create the text embeddings.

python3 compute_openclip_text_embeddings.py \
    data/text_prompts.txt \
    data/text_embeddings.npy \
    --model_name ViT-B-32

In this instance, we used the same text prompts we used for image search as our text prompts for classification.
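
To make the mechanics concrete, here is a minimal sketch of computing class text embeddings with OpenCLIP and classifying by cosine similarity. The pretrained weight tag and the unit normalization are assumptions for illustration; the scripts may implement the comparison differently.

import numpy as np
import open_clip
import torch
import torch.nn.functional as F

model, _, _ = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
tokenizer = open_clip.get_tokenizer("ViT-B-32")

@torch.no_grad()
def compute_text_embeddings(prompts):
    # One embedding per class prompt, unit-normalized for cosine similarity.
    text_embeddings = model.encode_text(tokenizer(prompts))
    return F.normalize(text_embeddings, dim=-1)

def classify(image_embedding, text_embeddings):
    # Cosine similarity between the student's image embedding and each class.
    logits = F.normalize(image_embedding, dim=-1) @ text_embeddings.T
    return int(logits.argmax(dim=-1))

embeddings = compute_text_embeddings(["a dog", "a cat"])
np.save("data/text_embeddings.npy", embeddings.numpy())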

Predict single image with PyTorch

Now that we have computed the text embeddings for our image classes, we can perform image classification with our PyTorch model as follows:

python3 predict_pytorch.py \
    resnet18 \
    data/models/resnet18/checkpoint.pth \
    data/text_embeddings.npy \
    assets/cat.jpg \
    --text_prompts data/text_prompts.txt

Live demo with camera

We can similarly perform inference on a live camera feed as follows:

python3 demo_pytorch.py \
    resnet18 \
    data/models/resnet18/checkpoint.pth \
    data/text_embeddings.npy \
    --text_prompts data/text_prompts.txt \
    --camera_device 0

Step 5 (advanced) - Train a student model with structured sparsity

The training script offers the ability to train for structured sparsity. This can offer additional acceleration when deploying the model on applicable NVIDIA Jetson platforms with TensorRT.
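
Under the hood, --use_asp presumably wires in NVIDIA's ASP library: the trained dense weights are pruned to a 2:4 pattern, and the masks are re-applied after every optimizer step during fine-tuning. A minimal sketch, assuming the dense checkpoint from Step 3 (the checkpoint dictionary key is a guess):

import timm
import torch
from apex.contrib.sparsity import ASP

model = timm.create_model("resnet18", pretrained=False, num_classes=512)
state = torch.load("data/models/resnet18/checkpoint.pth")
model.load_state_dict(state["model"])  # "model" key is an assumed layout
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Compute 2:4 masks from the trained dense weights and hook the optimizer
# so the masks are re-applied after every step during fine-tuning.
ASP.prune_trained_model(model, optimizer)
# ...continue fine-tuning as in Step 3; pruned weights stay zero.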

Train the model with structured sparsity

python3 distil_model_embeddings.py \
    resnet18 \
    data/images \
    data/embeddings \
    data/models/resnet18_sparse \
    --output_dim 512 \
    --pretrained \
    --init_checkpoint data/models/resnet18/checkpoint.pth \
    --use_asp \
    --num_epochs 25

Predict with PyTorch

python3 predict_pytorch.py \
    resnet18 \
    data/models/resnet18_sparse/checkpoint.pth \
    data/text_embeddings.npy \
    assets/cat.jpg \
    --text_prompts data/text_prompts.txt \
    --use_asp

Demo with PyTorch

python3 demo_pytorch.py \
    resnet18 \
    data/models/resnet18_sparse/checkpoint.pth \
    data/text_embeddings.npy \
    --text_prompts data/text_prompts.txt \
    --camera_device 0 \
    --use_asp

Export to ONNX

python3 export_onnx.py \
    resnet18 \
    data/models/resnet18_sparse/checkpoint.pth \
    data/onnx/resnet18_sparse.onnx \
    --use_asp

Step 6 (advanced) - Train a student with quantization aware training and INT8 precision

In addition to structured sparsity, another technique we can use for additional performance is reduced INT8 precision. Quantization aware training (QAT) is a technique to minimize the quantization error introduced when deploying with INT8 precision. It does so by applying quantization during the model's forward pass during training, which allows the model to adapt to quantization error as it learns. It also avoids the separate calibration step required by post-training quantization.
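
The typical pattern with NVIDIA's pytorch-quantization toolkit (which we assume is what --use_qat enables) is to patch the standard layers before building the model, so that fake quantization runs in every forward pass:

import timm
from pytorch_quantization import quant_modules

# Monkey-patch torch.nn so layers created from here on are quantized
# variants that fake-quantize weights and activations in the forward pass.
quant_modules.initialize()

# Any model built after initialize() picks up the quantized layers.
model = timm.create_model("resnet18", pretrained=True, num_classes=512)
# Train as in Step 3; quantizer ranges are learned during training, so no
# separate post-training calibration pass is needed.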

To distill the model with quantization aware training, follow these steps.

Train the model with quantization aware training (QAT)

python3 distil_model_embeddings.py \
    resnet18 \
    data/images \
    data/embeddings \
    data/models/resnet18_qat \
    --output_dim 512 \
    --pretrained \
    --init_checkpoint data/models/resnet18/checkpoint.pth \
    --use_qat \
    --num_epochs 25

Predict with PyTorch

python3 predict_pytorch.py \
    resnet18 \
    data/models/resnet18_qat/checkpoint.pth \
    data/text_embeddings.npy \
    assets/cat.jpg \
    --text_prompts data/text_prompts.txt \
    --use_qat

Demo with PyTorch

python3 demo_pytorch.py \
    resnet18 \
    data/models/resnet18_qat/checkpoint.pth \
    data/text_embeddings.npy \
    --text_prompts data/text_prompts.txt \
    --camera_device 0 \
    --use_qat

Export to ONNX

python3 export_onnx.py \
    resnet18 \
    data/models/resnet18_qat/checkpoint.pth \
    data/onnx/resnet18_qat.onnx \
    --use_qat
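
After export, the ONNX model can be compiled into a TensorRT engine. One plausible route is the stock trtexec tool; the flags below are standard trtexec options, and the engine path is illustrative rather than taken from this repository:

trtexec \
    --onnx=data/onnx/resnet18_qat.onnx \
    --saveEngine=data/engines/resnet18_qat.engine \
    --int8

For the sparse model from Step 5, adding --sparsity=enable lets TensorRT select 2:4 sparse kernels on supported hardware.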

Next steps

We hope you found this project helpful and that you were able to train your own image classification model without using any labeled data.

As a next step, we recommend reading through the source code to see how we used knowledge distillation in this project, and how you can use the convenient PyTorch libraries for quantization aware training and structured sparsity to get more optimized inference on Jetson.

If you have any questions, or run into any issues, please let us know by opening an issue on GitHub!

clip-distillation's Issues

Unable to run search_clip_images.py

Hi, I attempted to run search_clip_images.py both inside your Docker container and outside it, and I get the same error; nothing is written to the image_urls.txt file:
python3 search_clip_images.py \
    "data/text_prompts.txt" \
    "data/image_urls.txt" \
    -n 5000 \
    -m 10000 \
    --max_workers 2 \
    --append
Found the following 2 text prompts in data/text_prompts.txt
['Drone', 'Helicopter']
Querying images with the following prompts...
Drone
Helicopter
Traceback (most recent call last):
  File "/home/user1/ariel/fed_learn/clip-distillation/search_clip_images.py", line 256, in <module>
    urls = asyncio.run(
  File "/usr/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/usr/lib/python3.10/asyncio/base_events.py", line 649, in run_until_complete
    return future.result()
  File "/home/user1/ariel/fed_learn/clip-distillation/search_clip_images.py", line 118, in clip_search_images_by_multi_text
    results = await asyncio.gather(*coros)
  File "/home/user1/ariel/fed_learn/clip-distillation/search_clip_images.py", line 96, in safe_coro
    return await coro
  File "/home/user1/ariel/fed_learn/clip-distillation/search_clip_images.py", line 71, in clip_search_images_by_text
    item['url'] for item in response.json() if 'url' in item
  File "/usr/lib/python3/dist-packages/requests/models.py", line 900, in json
    return complexjson.loads(self.text, **kwargs)
  File "/usr/lib/python3.10/json/__init__.py", line 346, in loads
    return _default_decoder.decode(s)
  File "/usr/lib/python3.10/json/decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
  File "/usr/lib/python3.10/json/decoder.py", line 355, in raw_decode
    raise JSONDecodeError("Expecting value", s, err.value) from None
json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)

FYI, debugging surfaced the following:

The "502 Bad Gateway" error indicates a communication issue between two servers acting as gateways or proxies to fulfill a request. It doesn't specify a problem with your device or internet connection, but rather a fault within the servers involved.

Here's a breakdown of the key points:

Your request reached the first server (nginx) successfully.
Nginx forwarded the request to another upstream server to complete it.
The upstream server failed to respond correctly, causing nginx to return the 502 error.

Possible causes for this error:

Upstream Server Down: The upstream server might be unavailable or experiencing issues.
Overloaded Upstream Server: The upstream server might be overloaded with requests and unable to handle more.
Network Problems: There might be network connectivity issues between nginx and the upstream server.
Misconfiguration: Nginx or the upstream server might be misconfigured.

Distilled model benchmarks?

Thanks for the great work!
It's a very clever way to compute the embeddings beforehand and use them directly as target values during the backpropagation step.

Questions

  • Have you done any testing to find out how well the distilled model performs compared to the original teacher model?
  • If we use Vision Transformer (ViT) models as the student, should we expect an improvement in embedding quality?
  • Instead of using the distilled model for the classification task by computing class probabilities, how well does it perform if we use the raw embeddings to rank images by cosine distance?
