This repository contains the code for Building a simple Keras + deep learning REST API, published on the Keras.io blog.
The method covered here is intended to be instructional. It is not meant to be production-level or capable of scaling under heavy load. If you're interested in a more advanced Keras REST API that leverages message queues and batching, please refer to this tutorial.
For an even more advanced version that includes deploying a model to production, refer to this blog post.
I assume you already have Keras (and a supported backend) installed on your system. From there you need to install Flask, gevent, and requests:
$ pip install flask gevent requests
Next, clone the repo:
$ git clone https://github.com/jrosebr1/simple-keras-rest-api.git
The image we wish to classify shows a dog, more specifically a beagle.
The Flask + Keras server can be started by running:
$ python run_keras_server.py
Using TensorFlow backend.
* Loading Keras model and Flask starting server...please wait until server has fully started
...
* Running on http://127.0.0.1:5000
You can now access the REST API via http://127.0.0.1:5000.
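Because the startup log asks you to wait until the server has fully started, a client script may want to block until the port accepts connections before sending requests. The following is a minimal sketch, not part of the repository: `wait_for_server` is a hypothetical helper, and it assumes that the server only begins listening on port 5000 once the model has been loaded.

```python
import socket
import time


def wait_for_server(host="127.0.0.1", port=5000, timeout=60.0, interval=0.5):
    """Return True once a TCP connection to host:port succeeds, False on timeout.

    Hypothetical helper: it only checks that something is listening on the
    port, not that the /predict endpoint itself is healthy.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # create_connection raises OSError while nothing is listening yet
            with socket.create_connection((host, port), timeout=interval):
                return True
        except OSError:
            time.sleep(interval)
    return False
```

Calling `wait_for_server()` before the first request avoids connection errors while the model weights are still loading.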
Requests can be submitted via cURL:
$ curl -X POST -F image=@dog.jpg 'http://localhost:5000/predict'
{
  "predictions": [
    {
      "label": "beagle",
      "probability": 0.9901360869407654
    },
    {
      "label": "Walker_hound",
      "probability": 0.002396771451458335
    },
    {
      "label": "pot",
      "probability": 0.0013951235450804234
    },
    {
      "label": "Brittany_spaniel",
      "probability": 0.001283277408219874
    },
    {
      "label": "bluetick",
      "probability": 0.0010894243605434895
    }
  ],
  "success": true
}
Or programmatically:
$ python simple_request.py
1. beagle: 0.9901
2. Walker_hound: 0.0024
3. pot: 0.0014
4. Brittany_spaniel: 0.0013
5. bluetick: 0.0011
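For readers who want to see roughly what such a client involves, here is a minimal sketch. It is not the repository's simple_request.py: the function names `classify` and `format_predictions` are hypothetical, and the form field name `image` is an assumption based on the cURL example. The JSON keys (`predictions`, `label`, `probability`, `success`) are taken from the response shown above.

```python
def format_predictions(payload):
    """Turn the server's JSON payload into ranked, human-readable lines."""
    if not payload.get("success"):
        return ["Request failed"]
    return [
        "{}. {}: {:.4f}".format(rank, pred["label"], pred["probability"])
        for rank, pred in enumerate(payload["predictions"], start=1)
    ]


def classify(image_path, url="http://localhost:5000/predict"):
    """POST an image file to the REST API and return the decoded JSON."""
    # Imported here so the formatting helper works without requests installed.
    import requests

    with open(image_path, "rb") as f:
        # "image" is the assumed multipart field name expected by the server.
        response = requests.post(url, files={"image": f}, timeout=30)
    response.raise_for_status()
    return response.json()
```

With the server running, `print("\n".join(format_predictions(classify("path/to/image.jpg"))))` would print one ranked line per prediction, in the same style as the output above.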
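If model.predict fails with the following error:
ValueError: Tensor Tensor("fc1000/Softmax:0", shape=(?, 1000), dtype=float32) is not an element of this graph.
store a reference to the default TensorFlow graph when the model is loaded, and run each prediction inside that graph. Note that tf.get_default_graph is a TensorFlow 1.x API; in TensorFlow 2.x it is available as tf.compat.v1.get_default_graph.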
import tensorflow as tf
...
# run the prediction inside the graph captured by load_model()
with graph.as_default():
    preds = model.predict(image)
...
def load_model():
    # load the pre-trained Keras model (here we are using a model
    # pre-trained on ImageNet and provided by Keras, but you can
    # substitute in your own networks just as easily)
    global model
    model = ResNet50(weights="imagenet")
    # keep a reference to the graph the model was loaded into
    global graph
    graph = tf.get_default_graph()