
keras-deeplab-v3-plus's Introduction

Keras implementation of Deeplabv3+

DeepLab is a state-of-the-art deep learning model for semantic image segmentation.

The model is based on the original TF frozen graph. Pretrained weights can be loaded into it; they are imported directly from the original TF checkpoint.

Segmentation results of original TF model. Output Stride = 8




Segmentation results of this repo model with loaded weights and OS = 8
Results are identical to the TF model




Segmentation results of this repo model with loaded weights and OS = 16
Results are still good




How to get labels

The model returns a tensor of shape (batch_size, height, width, classes). To obtain labels, apply argmax to the logits at the output layer. Example of predicting on image1.jpg:

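from matplotlib import pyplot as plt
import cv2  # used only for resizing; any other resize function works
import numpy as np
from model import Deeplabv3

deeplab_model = Deeplabv3()
img = plt.imread("imgs/image1.jpg")
h, w, _ = img.shape  # image arrays are (height, width, channels)
ratio = 512. / max(h, w)
new_h, new_w = int(ratio * h), int(ratio * w)
resized = cv2.resize(img, (new_w, new_h))  # cv2 expects (width, height)
resized = resized / 127.5 - 1.  # scale pixel values to [-1, 1]
# zero-pad the bottom and right edges up to the 512x512 input size,
# so both landscape and portrait images work
pad_h, pad_w = 512 - new_h, 512 - new_w
padded = np.pad(resized, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant')
res = deeplab_model.predict(np.expand_dims(padded, 0))
labels = np.argmax(res.squeeze(), -1)
plt.imshow(labels[:new_h, :new_w])  # crop the padding away; also safe when no padding was added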
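
The argmax step also works on a whole batch at once. The following is a minimal sketch that uses only NumPy, with random numbers standing in for the model output. The helper name logits_to_labels is hypothetical and not part of this repo.

```python
import numpy as np


def logits_to_labels(logits):
    """Convert logits of shape (batch, height, width, classes)
    into integer label maps of shape (batch, height, width)."""
    logits = np.asarray(logits)
    if logits.ndim != 4:
        raise ValueError("expected a 4-D array, got shape %s" % (logits.shape,))
    # The class with the highest score wins at every pixel.
    return np.argmax(logits, axis=-1)


# Random values stand in for the output of deeplab_model.predict(...).
rng = np.random.RandomState(0)
fake_logits = rng.randn(2, 4, 4, 21)
labels = logits_to_labels(fake_logits)
print(labels.shape)  # (2, 4, 4)
```

Unlike res.squeeze(), this keeps the batch dimension, so a batch of one image still yields an array of shape (1, height, width).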

How to use this model with custom input shape and custom number of classes:

from model import Deeplabv3
deeplab_model = Deeplabv3(input_shape=(384,384,3), classes=4)  

After that you get an ordinary Keras model, which you can train with the .fit and .fit_generator methods.
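
Because the output is a tensor of per-pixel logits, the training targets are typically integer masks of shape (batch_size, height, width). The sketch below computes a softmax cross-entropy in NumPy purely to illustrate how the shapes line up. It is an assumption about how the loss might be set up, not code from this repo, and pixelwise_cross_entropy is a hypothetical name; for real training you would pass an equivalent Keras loss to compile.

```python
import numpy as np


def pixelwise_cross_entropy(logits, masks):
    """Mean softmax cross-entropy over all pixels.

    logits: float array, shape (batch, height, width, classes)
    masks:  int array,   shape (batch, height, width), values in [0, classes)
    """
    logits = np.asarray(logits, dtype=np.float64)
    masks = np.asarray(masks)
    # Subtract the per-pixel maximum for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # Pick the log-probability of the true class at every pixel.
    picked = np.take_along_axis(log_probs, masks[..., None], axis=-1)
    return float(-picked.mean())


# Toy data with 4 classes, shrunk spatially to keep the example fast.
logits = np.zeros((2, 8, 8, 4))
masks = np.zeros((2, 8, 8), dtype=np.int64)
loss = pixelwise_cross_entropy(logits, masks)  # uniform logits give log(4)
```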

How to train this model:

You can find a lot of useful parameters in the original repo: https://github.com/tensorflow/models/blob/master/research/deeplab/train.py
Important notes:

  1. This model has no default weight decay; you need to add it yourself
  2. The Xception backbone should be trained with OS=16 and used with OS=8 only for inference
  3. You can freeze the feature extractor of the Xception backbone (the first 356 layers) and fine-tune only the decoder
  4. If you want to train the BN layers too, use a batch size of at least 12 (16+ is even better)
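
Freezing the feature extractor (note 3) can be sketched as follows. The helper freeze_first_layers is hypothetical, and a stand-in object replaces the Keras model so that the example runs without Keras installed. The default of 356 layers comes from the note above and applies to the Xception backbone only.

```python
from types import SimpleNamespace


def freeze_first_layers(model, n_layers=356):
    """Mark the first n_layers of model.layers as non-trainable.

    Returns the number of layers that remain trainable.
    """
    for layer in model.layers[:n_layers]:
        layer.trainable = False
    return sum(1 for layer in model.layers if layer.trainable)


# Stand-in object so the sketch runs without Keras; with the real model,
# pass deeplab_model instead.
dummy_model = SimpleNamespace(
    layers=[SimpleNamespace(trainable=True) for _ in range(400)]
)
remaining = freeze_first_layers(dummy_model)
print(remaining)  # 44
```

With a real Keras model, call compile() after changing the trainable flags so that the change takes effect.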

How to load model

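To load a model that was saved with model.save(), use this code:

from keras.models import load_model
from model import relu6, BilinearUpsampling

# the custom activation and layer must be passed explicitly, otherwise Keras cannot deserialize them
deeplab_model = load_model('example.h5', custom_objects={'relu6': relu6, 'BilinearUpsampling': BilinearUpsampling})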

Xception vs MobileNetv2

There are two available backbones. The Xception backbone is more accurate but has 25 times more parameters than MobileNetv2. For MobileNetv2, pretrained weights exist only for alpha==1., but you can initialize the model with other values of alpha.

Requirements (it may work with lower versions too, but this is not guaranteed)

Keras==2.1.5
tensorflow-gpu==1.6.0
CUDA==9.0

Contributors

bonlime, udayakumar97
