Segment Anything Model in Multi-Backend Keras

This is an implementation of the Segment Anything predictor and automatic mask generator in Keras 3.

The demos use KerasCV's Segment Anything model. Note that we depend on KerasCV's source directly until v0.7.0 has branched.

Install the package

pip install git+https://github.com/tirthasheshpatel/segment_anything_keras.git

Install the required dependencies:

pip install Pillow numpy keras-nightly git+https://github.com/keras-team/keras-cv.git

Install TensorFlow, JAX, or PyTorch, whichever backend you'd like to use.

To install all the dependencies, including every backend, needed to run the demos, run:

pip install -r requirements.txt
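Before running anything, it can help to check which backend packages are actually installed and to set KERAS_BACKEND before anything imports Keras, because Keras 3 reads that variable at import time. The sketch below uses only the standard library; `available_backends` and `select_backend` are hypothetical helpers for illustration, not part of sam_keras, KerasCV, or Keras.

```python
import importlib.util
import os
import sys

# Keras 3 backend names mapped to the package that provides each one.
BACKEND_PACKAGES = {"tensorflow": "tensorflow", "jax": "jax", "torch": "torch"}


def available_backends():
    """Return the backend names whose packages can be found (without importing them)."""
    return [
        name
        for name, package in BACKEND_PACKAGES.items()
        if importlib.util.find_spec(package) is not None
    ]


def select_backend(name):
    """Set KERAS_BACKEND; this only has an effect if Keras has not been imported yet."""
    if name not in BACKEND_PACKAGES:
        raise ValueError(f"unknown backend {name!r}; expected one of {sorted(BACKEND_PACKAGES)}")
    if "keras" in sys.modules:
        raise RuntimeError("keras is already imported; set KERAS_BACKEND before importing it")
    os.environ["KERAS_BACKEND"] = name
    return name


if __name__ == "__main__":
    print("Installed backends:", available_backends())
```

Call `select_backend` at the very top of your script, before importing `keras_cv` or `sam_keras`.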

Getting the pretrained Segment Anything Model

# Select the backend before importing Keras or KerasCV.
# "tensorflow" is used here; "jax" and "torch" also work.
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

from keras_cv.models import SegmentAnythingModel
from sam_keras import SAMPredictor

# Get the huge model trained on the SA-1B dataset.
# Other available options are:
#   - "sam_base_sa1b"
#   - "sam_large_sa1b"
model = SegmentAnythingModel.from_preset("sam_huge_sa1b")

# Create the predictor
predictor = SAMPredictor(model)

# Now you can use the predictor just like the one in the original repo.
# The only difference is that a list of input dicts isn't supported;
# instead, pass each input dict separately to the `predict` method.
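If your code already builds a list of input dicts (as with the original repo's batched interface), a small loop recovers the same behaviour. This is a minimal sketch assuming, as the note says, that `predict` accepts a single input dict; `predict_each` and `EchoPredictor` are hypothetical names, and the dict keys shown are placeholders, so check what SAMPredictor actually expects.

```python
def predict_each(predictor, batched_input):
    """Call `predictor.predict` once per input dict and collect the results in order.

    Assumption: `predict` takes one input dict, per the README note.
    """
    return [predictor.predict(single_input) for single_input in batched_input]


class EchoPredictor:
    """Hypothetical stand-in for SAMPredictor, used only to exercise the loop."""

    def predict(self, single_input):
        return {"num_points": len(single_input.get("point_coords", []))}


results = predict_each(EchoPredictor(), [
    {"point_coords": [[10, 20]]},
    {"point_coords": [[1, 2], [3, 4]]},
])
print(results)  # [{'num_points': 1}, {'num_points': 2}]
```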

Notes

Currently, JAX and TensorFlow have a large compile-time overhead: the prompt encoder recompiles each time a different combination of prompts (points only, points + boxes, boxes only, etc.) is passed. To avoid this, compile the model with run_eagerly=True.
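To make the recompilation behaviour concrete, here is a toy model in plain Python, not Keras, JAX, or TensorFlow internals: compiled graphs are cached per input signature, and whether each prompt type is present or None is part of that signature, so every new combination pays the compile cost once. In broad strokes this mirrors how jax.jit and tf.function retrace when the structure of their inputs changes. Passing run_eagerly=True to the Keras model's compile call (model.compile(run_eagerly=True)) skips this compilation step, typically at the cost of slower individual calls. All names below are hypothetical.

```python
from functools import lru_cache

compile_count = 0  # how many times the toy "compiler" has run


@lru_cache(maxsize=None)
def compile_prompt_encoder(signature):
    """Toy stand-in for tracing/compiling a graph for one input signature."""
    global compile_count
    compile_count += 1
    return f"compiled graph for {signature}"


def prompt_signature(points=None, boxes=None, masks=None):
    """Record which prompt types are present; None vs. data changes the signature."""
    return (points is not None, boxes is not None, masks is not None)


def encode_prompts(points=None, boxes=None, masks=None):
    # Look up (or build) the graph for this particular combination of prompts.
    return compile_prompt_encoder(prompt_signature(points, boxes, masks))


encode_prompts(points=[[10, 20]])                        # compiles: points only
encode_prompts(points=[[30, 40]])                        # reuses the cached graph
encode_prompts(points=[[10, 20]], boxes=[[0, 0, 5, 5]])  # compiles: points + boxes
encode_prompts(boxes=[[0, 0, 5, 5]])                     # compiles: boxes only
print(compile_count)  # 3
```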

Benchmarks

All benchmarks were run on Colab with the following configurations:

  • For A100: 40 GB GPU RAM, 51 GB CPU RAM
  • For V100: 16 GB GPU RAM, 51 GB CPU RAM
| Model | Device | End-to-End (Huge) | End-to-End (Large) | End-to-End (Base) | Fixed Image |
|---|---|---|---|---|---|
| PyTorch Native | A100 | 445 ms ± 4.76 ms | 272 ms ± 3.73 ms | 126 ms ± 624 µs | 8.54 ms ± 53.2 µs |
| PyTorch (Keras 3) | A100 | 482 ms ± 1.86 ms | 293 ms ± 1.82 ms | 146 ms ± 907 µs | 36.4 ms ± 424 µs |
| TensorFlow (Keras 3) | A100 | N/A | N/A | N/A | N/A |
| JAX (Keras 3) | A100 | 125 ms ± 476 µs | 84.8 ms ± 193 µs | 44.2 ms ± 210 µs | 6.78 ms ± 135 µs |
| PyTorch Native | V100 | 585 ms ± 3.67 ms | 339 ms ± 1.2 ms | 153 ms ± 575 µs | 8.54 ms ± 266 µs |
| PyTorch (Keras 3) | V100 | 616 ms ± 1.22 ms | 365 ms ± 2.52 ms | 153 ms ± 575 µs | 37.6 ms ± 1.09 ms |
| TensorFlow (Keras 3) | V100 | N/A | N/A | N/A | N/A |
| JAX (Keras 3) | V100 | 545 ms ± 3.02 ms | 313 ms ± 1.07 ms | 125 ms ± 441 µs | 7.17 ms ± 101 µs |
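The table reports absolute latencies; a quick way to compare backends is the ratio of mean latencies. This sketch uses only the mean values from the benchmark table (ignoring the ± spread) and compares JAX (Keras 3) against native PyTorch end to end; the dictionary layout and the `speedup` helper are illustrative, not part of the repo.

```python
# Mean end-to-end latencies in milliseconds, taken from the benchmark table.
end_to_end_ms = {
    "A100": {
        "PyTorch Native": {"Huge": 445, "Large": 272, "Base": 126},
        "JAX (Keras 3)": {"Huge": 125, "Large": 84.8, "Base": 44.2},
    },
    "V100": {
        "PyTorch Native": {"Huge": 585, "Large": 339, "Base": 153},
        "JAX (Keras 3)": {"Huge": 545, "Large": 313, "Base": 125},
    },
}


def speedup(device, size, baseline="PyTorch Native", candidate="JAX (Keras 3)"):
    """How many times faster `candidate` is than `baseline` (ratio of mean latencies)."""
    table = end_to_end_ms[device]
    return table[baseline][size] / table[candidate][size]


for device in end_to_end_ms:
    for size in ("Huge", "Large", "Base"):
        print(f"{device} {size}: {speedup(device, size):.2f}x")
```

For example, on the A100 the Huge model comes out at 445 / 125 ≈ 3.56x, while on the V100 the same ratio is only about 1.07x (585 / 545).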

Contributors

tirthasheshpatel