Git Product home page Git Product logo

Comments (1)

jordanFisherYzw avatar jordanFisherYzw commented on May 24, 2024

Adapted from: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras

from future import absolute_import, division, print_function, unicode_literals

def get_local_env(cmd_str):
    """Run a shell command and return its standard output as stripped text.

    Args:
        cmd_str: Command line handed to the shell (``shell=True``), so pipes,
            ``;`` sequences and ``$VAR`` expansion work. Pass trusted,
            hard-coded strings only.

    Returns:
        The command's stdout decoded as UTF-8 with leading/trailing
        whitespace removed. stderr is not captured and the exit status is
        not checked, so a failing command simply yields whatever it printed
        to stdout (often the empty string).
    """
    import subprocess

    # The context manager closes the stdout pipe and waits for the child,
    # so neither a file descriptor nor an unreaped process is left behind.
    with subprocess.Popen(cmd_str, stdout=subprocess.PIPE, shell=True) as proc:
        raw_output, _ = proc.communicate()
    return raw_output.decode('utf-8').strip()

def get_runtime_env():
    """Print a diagnostic dump of the local JVM/Hadoop environment.

    Each probe is a shell command run through ``get_local_env``; its output
    is printed after a dashed separator line, optionally prefixed by a label.
    Nothing is returned.
    """
    # The original output uses two separator widths; both are kept verbatim.
    wide_rule = "----------------------\n"
    narrow_rule = "---------------------\n"

    # (separator, label, shell command) in the order they are reported.
    probes = (
        (wide_rule, "libjvm : ", "find /apache/ -name libjvm.so"),
        (wide_rule, "libhdfs.so : ", "find /apache/ -name libhdfs.so"),
        (wide_rule, "java related : ", "echo $JAVA_HOME; echo $JRE_HOME; echo $JDK_HOME"),
        (narrow_rule, "classpath : ", "echo $CLASSPATH"),
        (narrow_rule, "", "ls -alh /apache | grep hadoop"),
        (narrow_rule, "", "/apache/hadoop/bin/hadoop classpath --glob"),
        (narrow_rule, "", "ls -alh /apache/ | grep hadoop"),
    )

    for rule, label, command in probes:
        print(rule)
        print(label + get_local_env(command))

def main_fun(args, ctx):
    """Example demonstrating loading TFRecords directly from disk (e.g. HDFS) without tensorflow_datasets.

    Builds a GZIP-compressed TFRecord input pipeline, trains a small CNN under
    ``MultiWorkerMirroredStrategy`` with per-epoch weight checkpoints, then
    exports the trained model.

    Args:
        args: Parsed command-line namespace. This function reads
            ``buffer_size``, ``batch_size``, ``cluster_size``,
            ``images_labels``, ``epochs``, ``model_dir`` and ``export_dir``.
        ctx: Node context supplied by the caller; ``absolute_path()`` and
            ``job_name`` are used here.
    """
    import tensorflow as tf
    # NOTE(review): `tfio` is never referenced below; presumably imported for
    # its import-time side effects -- confirm before removing.
    import tensorflow_io as tfio
    from tensorflowonspark import compat

    # Dump JVM/Hadoop environment details to the executor log for debugging.
    get_runtime_env()
    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    BUFFER_SIZE = args.buffer_size
    BATCH_SIZE = args.batch_size
    NUM_WORKERS = args.cluster_size

    def parse_tfos(example_proto):
        """Decode one serialized example into a (28x28x1 float image, int32 label) pair."""
        feature_def = {"label": tf.io.FixedLenFeature(1, tf.int64),
                       "image": tf.io.FixedLenFeature(28 * 28 * 1, tf.int64)}
        features = tf.io.parse_single_example(example_proto, feature_def)
        # Scale the stored integer pixel values by 1/255.
        image = tf.cast(features['image'], tf.float32) / 255
        image = tf.reshape(image, (28, 28, 1))
        label = tf.cast(features['label'], tf.int32)
        return (image, label)

    image_pattern = ctx.absolute_path(args.images_labels)
    print("image_pattern is {0}".format(image_pattern))
    ds = tf.data.Dataset.list_files(image_pattern)
    # Repeat/shuffle operate on the list of file names, not on records.
    ds = ds.repeat(args.epochs).shuffle(BUFFER_SIZE)
    ds = ds.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type='GZIP'))
    train_datasets_unbatched = ds.map(parse_tfos)

    def build_and_compile_cnn_model():
        """Return a compiled Conv2D -> MaxPool -> Dense(64) -> Dense(10, softmax) classifier."""
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        model.compile(
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
            metrics=['accuracy'])
        return model

    # single node
    # single_worker_model = build_and_compile_cnn_model()
    # single_worker_model.fit(x=train_datasets, epochs=3)

    # Here the batch size scales up by the number of workers since
    # `tf.data.Dataset.batch` expects the global batch size (e.g. a per-worker
    # batch of 64 on two workers becomes 128).
    GLOBAL_BATCH_SIZE = BATCH_SIZE * NUM_WORKERS
    train_datasets = train_datasets_unbatched.batch(GLOBAL_BATCH_SIZE)

    # this fails
    # callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=args.model_dir)]
    tf.io.gfile.makedirs(args.model_dir)
    filepath = args.model_dir + "/weights-{epoch:04d}"
    callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=filepath, verbose=1, save_weights_only=True)]

    # Note: if your part files have an uneven number of records, you may see an "Out of Range" exception
    # at less than the expected number of steps_per_epoch, because the executor with the least amount of
    # records will finish first.
    # Keras documents `steps_per_epoch` as an integer, so use floor division
    # (true division would pass a float) and never request zero steps.
    # 60000 is the hard-coded number of training records this example expects.
    steps_per_epoch = max(1, 60000 // GLOBAL_BATCH_SIZE)

    with strategy.scope():
        multi_worker_model = build_and_compile_cnn_model()
    multi_worker_model.fit(x=train_datasets, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks)

    # The third argument tells the helper whether this node is the 'chief'.
    compat.export_saved_model(multi_worker_model, args.export_dir, ctx.job_name == 'chief')

# Script entry point: configure Spark, parse CLI options and launch the
# TensorFlowOnSpark cluster that runs `main_fun` on every node.
if __name__ == '__main__':
    import argparse
    from pyspark.context import SparkContext
    from pyspark.conf import SparkConf
    import socket

    # Pin the driver to this host's resolved IP and a fixed port.
    notebook_ip = socket.gethostbyname(socket.gethostname())
    notebook_port = "30202"
    conf = SparkConf()
    conf.setAppName("mnist_keras_tfrecord")
    conf.set("spark.driver.host", notebook_ip)
    conf.set("spark.driver.port", notebook_port)
    sc = SparkContext(conf=conf)
    # Default the cluster size to the configured executor count (1 if unset).
    executors = sc._conf.get("spark.executor.instances")
    num_executors = int(executors) if executors is not None else 1

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", help="number of records per batch", type=int, default=64)
    parser.add_argument("--buffer_size", help="size of shuffle buffer", type=int, default=10000)
    parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
    parser.add_argument("--epochs", help="number of epochs", type=int, default=3)
    parser.add_argument("--images_labels", help="HDFS path to MNIST image_label files in parallelized format")
    parser.add_argument("--model_dir", help="path to save model/checkpoint", default="mnist_model")
    parser.add_argument("--export_dir", help="path to export saved_model", default="mnist_export")
    parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")

    args = parser.parse_args()
    print("args:", args)

    from tensorflowonspark import TFCluster

    cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard,
                            input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief')
    cluster.shutdown()

from tensorflowonspark.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.