Comments (1)
Adapted from: https://www.tensorflow.org/beta/tutorials/distribute/multi_worker_with_keras
from future import absolute_import, division, print_function, unicode_literals
def get_local_env(cmd_str):
    """Run a shell command and return its standard output as stripped text.

    The command is executed through the shell (``shell=True``) because callers
    pass pipelines (``|``), command lists (``;``) and ``$VAR`` expansions.
    Only pass trusted, hard-coded command strings. The exit status is ignored
    and stderr is not captured (it goes to this process's stderr).

    Args:
      cmd_str: shell command line to execute.

    Returns:
      The command's stdout decoded as UTF-8 (undecodable bytes replaced),
      with leading/trailing whitespace removed.
    """
    import subprocess
    # subprocess.run waits for the child and closes the pipe; the previous
    # Popen(...).stdout.read() left a zombie process and an open descriptor.
    completed = subprocess.run(cmd_str, stdout=subprocess.PIPE, shell=True, check=False)
    return completed.stdout.decode('utf-8', errors='replace').strip()
def get_runtime_env():
    """Print a diagnostic dump of the JVM / Hadoop environment on this host.

    Searches /apache/ for libjvm.so and libhdfs.so, echoes JAVA_HOME,
    JRE_HOME, JDK_HOME and CLASSPATH, lists the hadoop entries under /apache,
    and prints the expanded ``hadoop classpath --glob``. Every section is
    preceded by a dashed separator line. Output goes to stdout; returns None.
    """
    wide_rule = "----------------------\n"
    narrow_rule = "---------------------\n"
    # (separator, label prefix, shell command); an empty label prints the raw output.
    # NOTE(review): the two "ls ... | grep hadoop" commands list the same directory.
    sections = (
        (wide_rule, "libjvm : ", "find /apache/ -name libjvm.so"),
        (wide_rule, "libhdfs.so : ", "find /apache/ -name libhdfs.so"),
        (wide_rule, "java related : ", "echo $JAVA_HOME; echo $JRE_HOME; echo $JDK_HOME"),
        (narrow_rule, "classpath : ", "echo $CLASSPATH"),
        (narrow_rule, "", "ls -alh /apache | grep hadoop"),
        (narrow_rule, "", "/apache/hadoop/bin/hadoop classpath --glob"),
        (narrow_rule, "", "ls -alh /apache/ | grep hadoop"),
    )
    for rule, label, command in sections:
        print(rule)
        print(label + get_local_env(command))
def main_fun(args, ctx):
    """Example demonstrating loading TFRecords directly from disk (e.g. HDFS) without tensorflow_datasets.

    Executed on each TensorFlowOnSpark node. Builds a tf.data pipeline over the
    GZIP-compressed TFRecord files matching ``args.images_labels``, trains a
    small CNN under MultiWorkerMirroredStrategy, writes per-epoch weight
    checkpoints into ``args.model_dir`` and finally exports the model to
    ``args.export_dir`` via ``compat.export_saved_model``.

    Args:
      args: parsed command-line arguments; reads batch_size, buffer_size,
        cluster_size, epochs, images_labels, model_dir and export_dir.
      ctx: TensorFlowOnSpark node context; ``absolute_path`` resolves the
        input pattern and ``job_name == 'chief'`` is passed to the exporter.
    """
    import tensorflow as tf
    # Not referenced below; presumably imported for its side effects
    # (e.g. registering filesystem plugins such as HDFS) -- confirm before removing.
    import tensorflow_io as tfio  # noqa: F401
    from tensorflowonspark import compat

    get_runtime_env()

    strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()

    BUFFER_SIZE = args.buffer_size
    BATCH_SIZE = args.batch_size
    NUM_WORKERS = args.cluster_size

    def parse_tfos(example_proto):
        """Decode one serialized tf.train.Example into an (image, label) pair.

        ``image`` is 28*28*1 int64 values cast to float32, divided by 255
        (presumably 0-255 pixel intensities) and reshaped to (28, 28, 1);
        ``label`` is a length-1 int32 tensor.
        """
        feature_def = {"label": tf.io.FixedLenFeature(1, tf.int64),
                       "image": tf.io.FixedLenFeature(28 * 28 * 1, tf.int64)}
        features = tf.io.parse_single_example(example_proto, feature_def)
        image = tf.cast(features['image'], tf.float32) / 255
        image = tf.reshape(image, (28, 28, 1))
        label = tf.cast(features['label'], tf.int32)
        return (image, label)

    image_pattern = ctx.absolute_path(args.images_labels)
    print("image_pattern is {0}".format(image_pattern))
    ds = tf.data.Dataset.list_files(image_pattern)
    # Repeat the file list once per epoch, then shuffle the file order.
    ds = ds.repeat(args.epochs).shuffle(BUFFER_SIZE)
    ds = ds.interleave(lambda x: tf.data.TFRecordDataset(x, compression_type='GZIP'))
    train_datasets_unbatched = ds.map(parse_tfos)

    def build_and_compile_cnn_model():
        """Return a compiled CNN: Conv2D(32, 3) -> MaxPool -> Flatten -> Dense(64) -> Dense(10, softmax).

        Compiled with sparse categorical cross-entropy, SGD(learning_rate=0.001)
        and an accuracy metric.
        """
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        model.compile(
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
            metrics=['accuracy'])
        return model

    # Single-node equivalent, for reference:
    #   single_worker_model = build_and_compile_cnn_model()
    #   single_worker_model.fit(x=train_datasets, epochs=3)

    # Under a multi-worker strategy `tf.data.Dataset.batch` expects the GLOBAL
    # batch size, so scale the per-worker batch size by the number of workers
    # (e.g. 64 per worker with 2 workers gives 128).
    GLOBAL_BATCH_SIZE = BATCH_SIZE * NUM_WORKERS
    train_datasets = train_datasets_unbatched.batch(GLOBAL_BATCH_SIZE)

    # NOTE: passing args.model_dir itself as the ModelCheckpoint filepath was
    # observed to fail, so per-epoch weight files are written inside the directory.
    tf.io.gfile.makedirs(args.model_dir)
    filepath = args.model_dir + "/weights-{epoch:04d}"
    callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=filepath, verbose=1, save_weights_only=True)]

    # Note: if your part files have an uneven number of records, you may see an "Out of Range" exception
    # at less than the expected number of steps_per_epoch, because the executor with the fewest records will finish first.
    # 60000 is the MNIST training-set size. Keras expects an integer step count,
    # so use floor division (never fewer than one step).
    steps_per_epoch = max(1, 60000 // GLOBAL_BATCH_SIZE)

    with strategy.scope():
        multi_worker_model = build_and_compile_cnn_model()
    multi_worker_model.fit(x=train_datasets, epochs=args.epochs, steps_per_epoch=steps_per_epoch, callbacks=callbacks)

    compat.export_saved_model(multi_worker_model, args.export_dir, ctx.job_name == 'chief')
if __name__ == '__main__':
    # Driver entry point: create the SparkContext, parse arguments and launch
    # main_fun on every executor through TensorFlowOnSpark.
    import argparse
    import socket

    from pyspark.conf import SparkConf
    from pyspark.context import SparkContext
    from tensorflowonspark import TFCluster

    # Advertise this host's resolved IP and a fixed port as the driver address
    # (presumably so executors can reach a driver running inside a notebook
    # container -- confirm the port is open/forwarded).
    notebook_ip = socket.gethostbyname(socket.gethostname())
    notebook_port = "30202"

    conf = SparkConf()
    conf.setAppName("mnist_keras_tfrecord")
    conf.set("spark.driver.host", notebook_ip)
    conf.set("spark.driver.port", notebook_port)
    sc = SparkContext(conf=conf)

    # Default the cluster size to the configured executor count (1 if unset).
    executors = sc.getConf().get("spark.executor.instances")
    num_executors = int(executors) if executors is not None else 1

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", help="number of records per batch", type=int, default=64)
    parser.add_argument("--buffer_size", help="size of shuffle buffer", type=int, default=10000)
    parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
    parser.add_argument("--epochs", help="number of epochs", type=int, default=3)
    # Required: main_fun resolves this path on every executor and cannot work without it.
    parser.add_argument("--images_labels", help="HDFS path to MNIST image_label files in parallelized format", required=True)
    parser.add_argument("--model_dir", help="path to save model/checkpoint", default="mnist_model")
    parser.add_argument("--export_dir", help="path to export saved_model", default="mnist_export")
    parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
    args = parser.parse_args()
    print("args:", args)

    cluster = TFCluster.run(sc, main_fun, args, args.cluster_size, num_ps=0, tensorboard=args.tensorboard,
                            input_mode=TFCluster.InputMode.TENSORFLOW, master_node='chief')
    cluster.shutdown()
from tensorflowonspark.
Related Issues (20)
- Writing checkpoints to HDFS takes long HOT 2
- When using mnist_spark.py, serializer.dump_stream times out while feeding a partition HOT 2
- pkg_resources.DistributionNotFound: The 'tensorflow' distribution was not found and is required by the application HOT 3
- MNIST example - Exception in TF background thread HOT 2
- the doubt about the data policy HOT 1
- Performance issues in the program HOT 2
- Performance issues in examples/mnist/estimator (by P3) HOT 3
- Retaining original columns after inference HOT 2
- tensorflow.python.framework.errors_impl.UnimplementedError: File system scheme 'cosn' not implemented HOT 2
- Model Saved with TF-2.5.0 HOT 3
- How to integrate a model into Spark cluster HOT 12
- Gets stuck at "Added broadcast_0_piece0 in memory on" while running a Spark standalone cluster HOT 1
- ExitCode: 13 executing mnist_data_setup.py on a yarn cluster HOT 3
- can it run on tensorflow-cpu? HOT 1
- Can it run using ParameterServerStrategy? HOT 3
- Do we support writing TensorFlow models in Scala & Java code with tensorflow-core-api? HOT 3
- Evaluator hangs while training HOT 1
- yarn mode error HOT 1
- I have been trying to use TensorFlowOnSpark in Azure Synapse Analytics and I would like to ask if you have any information about its compatibility in this environment
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization: use data art.
-
Game
Something interesting about games, to make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
Tencent's open source team in China.
from tensorflowonspark.