keras_multi_gpu's Introduction

Keras_Multi_GPU

Multi-GPU training for Keras

Usage

from multi_gpu import to_multi_gpu

# Add this one line before model.compile()
model = to_multi_gpu(model, n_gpus=4)
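
The wrapper code quoted under Issues below follows the usual data-parallel pattern: slice each batch, run every slice through a copy of the model, and concatenate the results. The following is a minimal pure-Python sketch of that idea only. It is not the implementation of to_multi_gpu; the names split_batch, data_parallel_apply and toy_model are hypothetical, and the loop over slices stands in for work that would be placed on separate GPUs.

```python
def split_batch(batch, n_gpus):
    """Split a list of samples into n_gpus contiguous slices.

    Assumption: the last slice takes any remainder.
    """
    size = len(batch) // n_gpus
    slices = []
    for i in range(n_gpus):
        start = i * size
        end = len(batch) if i == n_gpus - 1 else start + size
        slices.append(batch[start:end])
    return slices


def data_parallel_apply(model_fn, batch, n_gpus):
    """Apply model_fn to each slice and concatenate the outputs in order."""
    outputs = []
    for part in split_batch(batch, n_gpus):
        # In a real wrapper each slice would run on its own device.
        outputs.extend(model_fn(part))
    return outputs


def toy_model(samples):
    """Stand-in for a model: doubles every sample."""
    return [2 * x for x in samples]


parts = split_batch(list(range(8)), n_gpus=4)
result = data_parallel_apply(toy_model, list(range(8)), n_gpus=4)
print(parts)
print(result)
```

Because the slices are concatenated in their original order, the combined output matches what a single model would produce on the whole batch.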

Example

Environment:

  • Keras-2.0.4
  • TensorFlow-1.1.0
  • Ubuntu 14.04
  • 1080Ti, K80

Test on two 1080Ti GPUs with:

model = to_multi_gpu(model, n_gpus=2)
$ CUDA_VISIBLE_DEVICES=0,1 python mnist_cnn.py

60000/60000 [==============================] - 9s - loss: 0.3368 - acc: 0.8978 - val_loss: 0.0734 - val_acc: 0.9768
Epoch 2/12
60000/60000 [==============================] - 5s - loss: 0.1134 - acc: 0.9660 - val_loss: 0.0518 - val_acc: 0.9837
Epoch 3/12
60000/60000 [==============================] - 5s - loss: 0.0850 - acc: 0.9746 - val_loss: 0.0419 - val_acc: 0.9856
Epoch 4/12
60000/60000 [==============================] - 5s - loss: 0.0709 - acc: 0.9794 - val_loss: 0.0364 - val_acc: 0.9870
Epoch 5/12
60000/60000 [==============================] - 5s - loss: 0.0626 - acc: 0.9816 - val_loss: 0.0328 - val_acc: 0.9892
Epoch 6/12
60000/60000 [==============================] - 5s - loss: 0.0576 - acc: 0.9828 - val_loss: 0.0316 - val_acc: 0.9888
Epoch 7/12
60000/60000 [==============================] - 5s - loss: 0.0503 - acc: 0.9845 - val_loss: 0.0300 - val_acc: 0.9896
Epoch 8/12
60000/60000 [==============================] - 5s - loss: 0.0461 - acc: 0.9866 - val_loss: 0.0305 - val_acc: 0.9897
Epoch 9/12
60000/60000 [==============================] - 5s - loss: 0.0446 - acc: 0.9864 - val_loss: 0.0269 - val_acc: 0.9905
Epoch 10/12
60000/60000 [==============================] - 5s - loss: 0.0422 - acc: 0.9865 - val_loss: 0.0266 - val_acc: 0.9911
Epoch 11/12
60000/60000 [==============================] - 5s - loss: 0.0393 - acc: 0.9886 - val_loss: 0.0288 - val_acc: 0.9906
Epoch 12/12
60000/60000 [==============================] - 5s - loss: 0.0373 - acc: 0.9889 - val_loss: 0.0292 - val_acc: 0.9906
Test loss: 0.0291657444344
Test accuracy: 0.9906

real    1m16.904s
user    2m30.140s
sys     0m35.792s

Test on four K80 GPUs with:

model = to_multi_gpu(model, n_gpus=4)
$ CUDA_VISIBLE_DEVICES=0,1,2,3 python mnist_cnn.py

60000/60000 [==============================] - 11s - loss: 0.3432 - acc: 0.8956 - val_loss: 0.0774 - val_acc: 0.9766
Epoch 2/12
60000/60000 [==============================] - 7s - loss: 0.1151 - acc: 0.9656 - val_loss: 0.0540 - val_acc: 0.9825
Epoch 3/12
60000/60000 [==============================] - 7s - loss: 0.0873 - acc: 0.9746 - val_loss: 0.0448 - val_acc: 0.9849
Epoch 4/12
60000/60000 [==============================] - 7s - loss: 0.0696 - acc: 0.9793 - val_loss: 0.0390 - val_acc: 0.9866
Epoch 5/12
60000/60000 [==============================] - 7s - loss: 0.0626 - acc: 0.9820 - val_loss: 0.0351 - val_acc: 0.9886
Epoch 6/12
60000/60000 [==============================] - 7s - loss: 0.0545 - acc: 0.9841 - val_loss: 0.0328 - val_acc: 0.9888
Epoch 7/12
60000/60000 [==============================] - 7s - loss: 0.0503 - acc: 0.9847 - val_loss: 0.0334 - val_acc: 0.9897
Epoch 8/12
60000/60000 [==============================] - 7s - loss: 0.0458 - acc: 0.9868 - val_loss: 0.0303 - val_acc: 0.9891
Epoch 9/12
60000/60000 [==============================] - 7s - loss: 0.0428 - acc: 0.9870 - val_loss: 0.0290 - val_acc: 0.9903
Epoch 10/12
60000/60000 [==============================] - 7s - loss: 0.0422 - acc: 0.9877 - val_loss: 0.0301 - val_acc: 0.9905
Epoch 11/12
60000/60000 [==============================] - 7s - loss: 0.0388 - acc: 0.9880 - val_loss: 0.0295 - val_acc: 0.9902
Epoch 12/12
60000/60000 [==============================] - 7s - loss: 0.0377 - acc: 0.9891 - val_loss: 0.0281 - val_acc: 0.9910
Test loss: 0.0280628399889
Test accuracy: 0.991

real    1m49.362s
user    4m58.328s
sys     0m57.112s

Test on one K80 GPU with:

model = to_multi_gpu(model, n_gpus=1)
$ CUDA_VISIBLE_DEVICES=0 python mnist_cnn.py

60000/60000 [==============================] - 15s - loss: 0.3335 - acc: 0.8989 - val_loss: 0.0744 - val_acc: 0.9765
Epoch 2/12
60000/60000 [==============================] - 10s - loss: 0.1163 - acc: 0.9660 - val_loss: 0.0530 - val_acc: 0.9829
Epoch 3/12
60000/60000 [==============================] - 10s - loss: 0.0856 - acc: 0.9747 - val_loss: 0.0427 - val_acc: 0.9851
Epoch 4/12
60000/60000 [==============================] - 10s - loss: 0.0733 - acc: 0.9777 - val_loss: 0.0384 - val_acc: 0.9873
Epoch 5/12
60000/60000 [==============================] - 10s - loss: 0.0646 - acc: 0.9810 - val_loss: 0.0380 - val_acc: 0.9872
Epoch 6/12
60000/60000 [==============================] - 10s - loss: 0.0577 - acc: 0.9833 - val_loss: 0.0326 - val_acc: 0.9886
Epoch 7/12
60000/60000 [==============================] - 10s - loss: 0.0522 - acc: 0.9845 - val_loss: 0.0321 - val_acc: 0.9885
Epoch 8/12
60000/60000 [==============================] - 10s - loss: 0.0472 - acc: 0.9860 - val_loss: 0.0314 - val_acc: 0.9892
Epoch 9/12
60000/60000 [==============================] - 10s - loss: 0.0448 - acc: 0.9868 - val_loss: 0.0300 - val_acc: 0.9898
Epoch 10/12
60000/60000 [==============================] - 10s - loss: 0.0415 - acc: 0.9876 - val_loss: 0.0307 - val_acc: 0.9898
Epoch 11/12
60000/60000 [==============================] - 10s - loss: 0.0399 - acc: 0.9882 - val_loss: 0.0309 - val_acc: 0.9896
Epoch 12/12
60000/60000 [==============================] - 10s - loss: 0.0371 - acc: 0.9889 - val_loss: 0.0301 - val_acc: 0.9901
Test loss: 0.0301153413276
Test accuracy: 0.9901

real    2m20.542s
user    2m13.728s
sys     0m22.752s
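
The two K80 runs can be compared directly, since they use the same GPU model. The sketch below computes the wall-clock speedup and parallel efficiency from the "real" times reported above. The variable and function names are illustrative, and the 1080Ti figure is included only for reference because it comes from different hardware.

```python
# Wall-clock ("real") times copied from the runs above, in seconds.
real_seconds = {
    "2x 1080Ti": 1 * 60 + 16.904,
    "4x K80": 1 * 60 + 49.362,
    "1x K80": 2 * 60 + 20.542,
}


def speedup(baseline_s, parallel_s):
    """Ratio of single-GPU time to multi-GPU time."""
    return baseline_s / parallel_s


k80_speedup = speedup(real_seconds["1x K80"], real_seconds["4x K80"])
# Efficiency: fraction of the ideal 4x speedup that was achieved.
k80_efficiency = k80_speedup / 4

print("4x K80 vs 1x K80 speedup: %.2f" % k80_speedup)     # 1.29
print("parallel efficiency:      %.2f" % k80_efficiency)  # 0.32
```

The whole-run times include startup and data loading. The steady-state epoch times (10 s on one K80 versus 7 s on four) give a somewhat higher ratio of about 1.4.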

keras_multi_gpu's People

Contributors

kuixu

keras_multi_gpu's Issues

Incompatible shapes

Hi,
It seems the wrapper is not always able to split the tensors. I get an "Incompatible shapes" error.
Perhaps it cannot handle odd lengths?

InvalidArgumentError: Incompatible shapes: [17,258] vs. [35,258]
	 [[Node: sequential_4/gru_3/while/add = Add[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](sequential_4/gru_3/while/strided_slice, sequential_4/gru_3/while/MatMul)]]

Caused by op 'sequential_4/gru_3/while/add', defined at:
  File "/usr/lib/python3.5/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/pawel/venv/lib/python3.5/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/home/pawel/venv/lib/python3.5/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()

(...)

  File "/home/pawel/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2630, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/pawel/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1204, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

InvalidArgumentError (see above for traceback): Incompatible shapes: [17,258] vs. [35,258]
	 [[Node: sequential_4/gru_3/while/add = Add[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](sequential_4/gru_3/while/strided_slice, sequential_4/gru_3/while/MatMul)]]
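
The first shape in the error, 17, equals 35 // 2, which would be consistent with a batch of 35 samples being sliced for two GPUs while another tensor keeps the full batch size. The number of GPUs is not stated in the report, so this is only a guess. The sketch below shows the slice arithmetic under that assumption, plus a hypothetical helper for trimming a dataset so that no short, uneven batch occurs. Both function names and the sample count 1059 are made up for illustration.

```python
def slice_sizes(batch_size, n_gpus):
    """Per-GPU slice sizes, assuming the last slice absorbs the remainder."""
    base = batch_size // n_gpus
    sizes = [base] * n_gpus
    sizes[-1] += batch_size % n_gpus
    return sizes


def trim_to_multiple(n_samples, multiple):
    """Largest sample count <= n_samples that is divisible by multiple."""
    return n_samples - n_samples % multiple


print(slice_sizes(35, 2))          # [17, 18]
print(slice_sizes(36, 2))          # [18, 18]
print(trim_to_multiple(1059, 32))  # 1056
```

If uneven slices are indeed the trigger, choosing a batch size divisible by n_gpus and trimming the training set to a multiple of the batch size would avoid them. This has not been tested against this library.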

Multi-GPU does not work

Hi, I tried to use this method to parallelize my model, but it does not seem to work.
For example, I use 3 GPUs, but it actually seems to use only one.
The code is:

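from __future__ import print_function  # keeps the print() calls valid on Python 2

import tensorflow as tf

# slice_batch is defined elsewhere; its definition was not included in the report.


def multi_gpu_wrapper(single_model, num_gpu):
    inputs = single_model.inputs
    towers = []
    concate_layer = tf.keras.layers.Concatenate(axis=0)
    for gpu_id in range(num_gpu):
        print('cur gpu is', gpu_id)
        # Build one replica ("tower") of the model on each GPU.
        with tf.device('/gpu:' + str(gpu_id)):
            splited_layer = tf.keras.layers.Lambda(lambda x: slice_batch(x, num_gpu, gpu_id))
            # Give each tower its own slice of every model input.
            cur_inputs = [splited_layer(model_input) for model_input in inputs]
            towers.append(single_model(cur_inputs))
            print(towers[-1])
    outputs = []
    # Assumes the model returns a list of outputs.
    num_output = len(towers[-1])
    # Concatenate the per-tower outputs along the batch axis on the CPU.
    with tf.device('/cpu:0'):
        for i in range(num_output):
            tmp_outputs = [towers[j][i] for j in range(num_gpu)]
            outputs.append(concate_layer(tmp_outputs))
    multi_gpu_model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    return multi_gpu_model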

The output of nvidia-smi is:
[image: nvidia-smi output]

Do you know how to fix it?
Thank you!
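
One detail in the wrapper above that may be worth ruling out is the expression lambda x: slice_batch(x, num_gpu, gpu_id). A Python lambda looks up gpu_id when it is called, not when it is created. In the wrapper the Lambda layer is called inside the same loop iteration, so this may well not be the cause of the single-GPU behaviour, but it matters if the lambda is ever called again after the loop has finished. The sketch below shows the general pitfall in plain Python; the function names are hypothetical and no TensorFlow is involved.

```python
def build_slicers_late(num_gpu):
    """Lambdas that read gpu_id at call time (late binding)."""
    slicers = []
    for gpu_id in range(num_gpu):
        slicers.append(lambda x: (gpu_id, x))
    return slicers


def build_slicers_bound(num_gpu):
    """Lambdas that freeze gpu_id at creation time via a default argument."""
    slicers = []
    for gpu_id in range(num_gpu):
        slicers.append(lambda x, gpu_id=gpu_id: (gpu_id, x))
    return slicers


# Call every lambda only after its loop has finished.
late = [f("batch")[0] for f in build_slicers_late(3)]
bound = [f("batch")[0] for f in build_slicers_bound(3)]
print(late)   # [2, 2, 2]
print(bound)  # [0, 1, 2]
```

Binding the loop variable as a default argument is a common way to make each lambda keep its own index.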
