

TensorFlow Estimator example

Example scripts showing how to use TensorFlow's Estimator class.

This repository has an accompanying blog post at https://medium.com/@peter.roelants/tensorflow-estimator-dataset-apis-caeb71e6e196

The main file of interest is src/mnist_estimator.py, which defines an example Estimator that trains a network on MNIST.

Setup environment

With Anaconda Python:

conda env create -f env.yml
source activate tensorflow

(On conda 4.4 and later, conda activate tensorflow is the preferred form.)

Training

Training locally

After setting up the environment you can run the training locally with:

./src/mnist_estimator.py

Training can be monitored with TensorBoard:

tensorboard --logdir=./mnist_training

After training, you can test inference with the trained model by running:

./src/mnist_inference.py

Training on Google Cloud

  1. Create a new project in the Cloud Resource Manager as described here (I named my project mnist-estimator).
  2. Install the Google Cloud SDK.
  3. Enable the ML Engine APIs.
  4. Set up a Google Cloud Storage (GCS) bucket as described here. The bucket is needed to store the model checkpoints. I named my bucket estimator-data.

Run the training job on Google Cloud with:

gcloud ml-engine jobs submit training mnist_estimator_`date +%s` \
    --project mnist-estimator \
    --runtime-version 1.8 \
    --python-version 3.5 \
    --job-dir gs://estimator-data/train \
    --scale-tier BASIC \
    --region europe-west1 \
    --module-name src.mnist_estimator \
    --package-path src/ \
    -- \
    --train-steps 6000 \
    --batch-size 128
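If you prefer to launch jobs from a script rather than from a shell, the command above can be assembled as an argument list. The sketch below uses hypothetical helpers (build_submit_command and split_user_args are illustrative names, not part of this repository) and reuses exactly the flags shown above. The Unix timestamp plays the role of `date +%s`, so every submission gets a unique job name. It only builds the list; you would still pass it to something like subprocess.run yourself.

```python
import time


def build_submit_command(bucket="gs://estimator-data", train_steps=6000,
                         batch_size=128, timestamp=None):
    """Hypothetical helper mirroring the gcloud command above as an argv list."""
    # Same role as `date +%s`: a Unix timestamp makes each job name unique.
    if timestamp is None:
        timestamp = int(time.time())
    job_name = "mnist_estimator_{}".format(timestamp)
    return [
        "gcloud", "ml-engine", "jobs", "submit", "training", job_name,
        "--project", "mnist-estimator",
        "--runtime-version", "1.8",
        "--python-version", "3.5",
        "--job-dir", bucket + "/train",
        "--scale-tier", "BASIC",
        "--region", "europe-west1",
        "--module-name", "src.mnist_estimator",
        "--package-path", "src/",
        # Everything after the bare "--" is passed on to src.mnist_estimator.
        "--",
        "--train-steps", str(train_steps),
        "--batch-size", str(batch_size),
    ]


def split_user_args(cmd):
    """Split an argv list at the first bare '--' into (gcloud_args, script_args)."""
    i = cmd.index("--")
    return cmd[:i], cmd[i + 1:]


cmd = build_submit_command(timestamp=1500000000)
gcloud_args, script_args = split_user_args(cmd)
print(cmd[5])       # mnist_estimator_1500000000
print(script_args)  # ['--train-steps', '6000', '--batch-size', '128']
```

Splitting at the bare -- mirrors how gcloud separates its own flags from the arguments destined for the training module.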

Note:

  • Replace gs://estimator-data/ with the URL of the bucket you created.
  • At the time of writing, the latest Python version supported on ML Engine was 3.5 (although I'm using 3.6 locally).
  • The --project flag refers to the gcloud project (mnist-estimator in my case). To avoid passing this flag, you can set a default project with gcloud config set core/project mnist-estimator.
  • You can pass arguments to the script by adding a bare -- after the gcloud parameters, followed by your custom arguments (--train-steps and --batch-size in this case).
  • The --job-dir argument is also passed on to mnist_estimator, so the script must always accept this parameter.
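On the script side, this means mnist_estimator has to parse --job-dir alongside its own arguments. A minimal argparse sketch follows; it is not necessarily how the repository's script parses its arguments. The defaults of 6000 steps and batch size 128 are copied from the command above, while ./mnist_training as the local job directory is an assumption based on the TensorBoard log directory used for local training. The code avoids f-strings, so it would also run on the Python 3.5 runtime.

```python
import argparse


def parse_args(argv=None):
    """Parse the training script's command-line arguments (sketch)."""
    parser = argparse.ArgumentParser(description="Train an MNIST Estimator.")
    # gcloud forwards --job-dir to the module, so it must always be accepted.
    # The local default is an assumption matching the TensorBoard logdir above.
    parser.add_argument("--job-dir", default="./mnist_training",
                        help="Local path or gs:// URL for checkpoints and summaries.")
    # Custom arguments given after the bare "--" in the gcloud command.
    parser.add_argument("--train-steps", type=int, default=6000,
                        help="Number of training steps.")
    parser.add_argument("--batch-size", type=int, default=128,
                        help="Mini-batch size.")
    # With argv=None, argparse reads sys.argv[1:].
    return parser.parse_args(argv)
```

argparse converts the dashes in flag names to underscores, so the values are available as args.job_dir, args.train_steps and args.batch_size.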

You can follow the training in TensorBoard with:

tensorboard --logdir=gs://estimator-data/train

However, TensorBoard seems to update very slowly when reading from a GCS bucket, and sometimes it did not display all of the data.

After training you can download the checkpoint files from the gcloud bucket.
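For example, gsutil cp -r gs://estimator-data/train . copies the whole job directory. TensorFlow records the newest checkpoint in a small text file named checkpoint inside that directory, and within TensorFlow, tf.train.latest_checkpoint reads it for you. The stdlib-only sketch below (latest_checkpoint_prefix is a hypothetical helper) parses that file directly. It assumes the usual model_checkpoint_path: "..." line and resolves the recorded name against the local download directory, because the recorded path may still point at the bucket.

```python
import os
import re


def latest_checkpoint_prefix(model_dir):
    """Return the local path prefix of the newest checkpoint, or None (sketch).

    Assumes TensorFlow's plain-text 'checkpoint' state file contains a line like:
        model_checkpoint_path: "model.ckpt-6000"
    """
    state_file = os.path.join(model_dir, "checkpoint")
    if not os.path.isfile(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                # The recorded path may be relative, absolute or a gs:// URL;
                # the downloaded files sit next to this state file, so keep
                # only the file name part.
                name = os.path.basename(match.group(1))
                return os.path.join(model_dir, name)
    return None
```

The returned value is a prefix (such as .../model.ckpt-6000), not a single file: the matching .index and .data-* files share it.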

More info

There is a Google Cloud blog post that goes into more detail on training an Estimator in the cloud, if you're interested.

Contributors

peterroelants
