
ntm-one-shot-tf's Introduction

One-shot learning using Memory-Augmented Neural Networks in TensorFlow.

Update: added support for TensorFlow v1.x.

TensorFlow implementation of the paper One-shot Learning with Memory-Augmented Neural Networks.

Current Progress of Implementation:

  • Utility Functions:
    • Image Handler
    • Metrics (Accuracy)
    • Similarities (Cosine Similarity)
  • LSTM Controller and Memory Unit
  • Batch Generators
  • Omniglot Tester Code
  • Unsupervised Feature Learning through Autoencoders
  • Cattle/New Born Recognition
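The Similarities utility listed above computes the cosine similarity between a controller key and each memory row. In the paper, these similarities are passed through a softmax to produce the read weights, and the read vector is the weighted sum of memory rows. The following is a minimal NumPy sketch, assuming memory is stored as an (N, M) array of N slots. The function names are illustrative and are not the repository's API.

```python
import numpy as np

def cosine_similarity(key, memory, eps=1e-8):
    """Cosine similarity between a key of shape (M,) and each row of memory (N, M)."""
    dot = memory @ key                                        # (N,)
    norms = np.linalg.norm(memory, axis=1) * np.linalg.norm(key)
    return dot / (norms + eps)                                # eps avoids division by zero

def read_weights(key, memory):
    """Softmax over cosine similarities, following the paper's read addressing."""
    sim = cosine_similarity(key, memory)
    exp = np.exp(sim - sim.max())                             # subtract max for numerical stability
    return exp / exp.sum()

def read(key, memory):
    """Read vector: sum over slots i of w_r(i) * M(i)."""
    return read_weights(key, memory) @ memory                 # (M,)

memory = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
key = np.array([1.0, 0.0])
w = read_weights(key, memory)
print(w.argmax())  # slot 0 matches the key exactly
```

Because the softmax is applied to similarities bounded in [-1, 1], the read weights stay fairly soft unless a sharpening factor is added. The paper's formulation does not add one here.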

The benchmark dataset is Omniglot. All datasets should be placed in the data/ folder.

Adam Santoro, Sergey Bartunov, Matthew Botvinick, Daan Wierstra, Timothy Lillicrap, One-shot Learning with Memory-Augmented Neural Networks, [arXiv]

ntm-one-shot-tf's People

Contributors

hmishra2250

ntm-one-shot-tf's Issues

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Hi, I've been exploring your code recently and ran into a problem when I try to run python Omniglot.py.
System: Ubuntu 16.04
TensorFlow version: 1.1.0-rc2
CUDA: 8.0
Python: 2.7

The error output is below. Could you please help me find a solution?

python Omniglot.py
Traceback (most recent call last):
  File "Omniglot.py", line 112, in <module>
    omniglot()
  File "Omniglot.py", line 31, in omniglot
    output_var, output_var_flatten, params = memory_augmented_neural_network(input_ph, target_ph, batch_size=batch_size, nb_class=nb_class, memory_shape=memory_shape, controller_size=controller_size, input_size=input_size, nb_reads=nb_reads)
  File "/data/yangyu/workspace/MANN/MANN/Model.py", line 33, in memory_augmented_neural_network
    W_key = tf.get_variable('W_key', shape=shape,initializer=tf.random_uniform_initializer(-1*high, high))
  File "/data/yangyu/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 1065, in get_variable
    use_resource=use_resource, custom_getter=custom_getter)
  File "/data/yangyu/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 962, in get_variable
    use_resource=use_resource, custom_getter=custom_getter)
  File "/data/yangyu/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 367, in get_variable
    validate_shape=validate_shape, use_resource=use_resource)
  File "/data/yangyu/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/variable_scope.py", line 303, in _true_getter
    is_scalar = shape is not None and not shape

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
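The last frame shows TensorFlow evaluating `not shape`. When `shape` is a NumPy array with more than one element, Python cannot reduce it to a single boolean, and that produces the ValueError above. A likely workaround is to convert the shape to a tuple of plain ints before calling `tf.get_variable`, for example `shape=tuple(int(d) for d in shape)`. This assumes `shape` in Model.py is built as a NumPy array, which the traceback suggests but does not confirm. The snippet below reproduces the failure and the fix with NumPy alone, using a hypothetical shape value.

```python
import numpy as np

# Hypothetical shape, built as a NumPy array the way a NumPy-based helper might.
shape = np.array([200, 40])

raised = False
try:
    is_scalar = shape is not None and not shape   # same expression as variable_scope.py
except ValueError as err:
    raised = True
    print("fails:", err)

# Converting to a tuple of Python ints makes the truth test well defined.
safe_shape = tuple(int(d) for d in shape)
is_scalar = safe_shape is not None and not safe_shape
print(safe_shape, is_scalar)  # (200, 40) False
```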

ValueError: sample larger than population

mldl@mldlUB1604:~/ub16_prj/NTM-One-Shot-TF$ python Omniglot.py
Shapes Recieved in Update: V, dim, val are ==> [64, 128] [64] [64]
Shapes Recieved in Body of Update: v, d2, chg are ==> [128] [] []
wlu_tm1 : [16, 4]
Shapes Recieved in Update: V, dim, val are ==> [16, 128, 40] [16] [16, 40]
Shapes Recieved in Body of Update: v, d2, chg are ==> [128, 40] [] [40]
Compiling the Model
Output, Target shapes: [16, 50, 5] [16, 50, 5]
Done
Training the model
Traceback (most recent call last):
  File "Omniglot.py", line 112, in <module>
    omniglot()
  File "Omniglot.py", line 85, in omniglot
    for i, (batch_input, batch_output) in generator:
  File "/home/mldl/ub16_prj/NTM-One-Shot-TF/MANN/Utils/Generator.py", line 36, in next
    return (self.num_iter - 1), self.sample(self.nb_samples)
  File "/home/mldl/ub16_prj/NTM-One-Shot-TF/MANN/Utils/Generator.py", line 41, in sample
    sampled_character_folders = random.sample(self.character_folders, nb_samples)
  File "/usr/lib/python2.7/random.py", line 323, in sample
    raise ValueError("sample larger than population")
ValueError: sample larger than population
mldl@mldlUB1604:~/ub16_prj/NTM-One-Shot-TF$ ll data/
total 18412
drwxrwxr-x 5 mldl mldl 4096 Aug 4 18:24 ./
drwxrwxr-x 7 mldl mldl 4096 Aug 4 18:16 ../
drwxr-xr-x 32 mldl mldl 4096 Oct 21 2015 images_background/
-rw-rw-r-- 1 mldl mldl 1319022 Aug 4 18:23 images_background_small1.zip
-rw-rw-r-- 1 mldl mldl 1580917 Aug 4 18:23 images_background_small2.zip
-rw-rw-r-- 1 mldl mldl 9464212 Aug 4 18:23 images_background.zip
-rw-rw-r-- 1 mldl mldl 6462886 Aug 4 18:23 images_evaluation.zip
drwxrwxr-x 2 mldl mldl 4096 Aug 3 17:26 omniglot/
drwxrwxr-x 2 mldl mldl 4096 Aug 4 18:23 one-shot-classification/
mldl@mldlUB1604:~/ub16_prj/NTM-One-Shot-TF$ ll data/images_background
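`random.sample` raises this error whenever the requested sample size exceeds the number of items available. Here, that means the generator found fewer character folders than it asked for. The number requested is plausibly the classes per episode, which appears to be 5 from the `[16, 50, 5]` output shape. The listing also shows `data/omniglot/` with a link count of 2, which on typical Linux filesystems means it has no subdirectories, while the extracted images sit in `data/images_background/`. A misplaced or empty dataset folder is therefore one plausible cause.

The guard below turns the failure into a clearer message. The helper name is hypothetical, and it assumes an `<alphabet>/<character>/` layout under the folder passed in.

```python
import os
import random
import tempfile

def sample_character_folders(data_folder, nb_samples):
    """Hypothetical helper: list <alphabet>/<character> folders and sample from them,
    failing with an explicit message instead of 'sample larger than population'."""
    character_folders = [
        os.path.join(data_folder, alphabet, character)
        for alphabet in sorted(os.listdir(data_folder))
        if os.path.isdir(os.path.join(data_folder, alphabet))
        for character in sorted(os.listdir(os.path.join(data_folder, alphabet)))
        if os.path.isdir(os.path.join(data_folder, alphabet, character))
    ]
    if len(character_folders) < nb_samples:
        raise ValueError(
            "found %d character folders under %r but need %d; "
            "check that the dataset is extracted there"
            % (len(character_folders), data_folder, nb_samples))
    return random.sample(character_folders, nb_samples)

# Demo: an empty data folder now produces a descriptive error.
with tempfile.TemporaryDirectory() as empty:
    try:
        sample_character_folders(empty, 5)
    except ValueError as err:
        message = str(err)
print(message)
```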

Deviations from the paper

Hi,
Nice work with the code. I just wanted to document how your code differs from the paper.
For example, you compute an extra vector a_t from W_add, b_add and h_t, whereas the paper simply adds the key k_t in place of a_t.
It is also not clear how the paper actually obtains k_t.

I had also assumed they used CNNs to extract features before the LSTM controller, since the input is an image.
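For reference, the paper's memory write adds the key itself to each slot, scaled by that slot's write weight. The variant described above instead writes a separate add vector computed from the controller state h_t. The second pair of equations is a sketch of that reading. The activation f is not stated in this issue and is left unspecified.

```latex
% Paper (Santoro et al.): write the key k_t into slot i
M_t(i) = M_{t-1}(i) + w^w_t(i)\, k_t

% Variant described in this issue: write an add vector a_t derived from h_t
a_t = f\left(W_{add}\, h_t + b_{add}\right), \qquad
M_t(i) = M_{t-1}(i) + w^w_t(i)\, a_t
```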

IOError: [Errno 21] Is a directory:

When I run this code, the terminal shows the following:
/usr/bin/python2.7 /home/cr/PycharmProjects/NTM-One-Shot-TF-master/Omniglot.py
Shapes Recieved in Update: V, dim, val are ==> [64, 128] [64] [64]
Shapes Recieved in Body of Update: v, d2, chg are ==> [128] [] []
wlu_tm1 : [16, 4]
Shapes Recieved in Update: V, dim, val are ==> [16, 128, 40] [16] [16, 40]
Shapes Recieved in Body of Update: v, d2, chg are ==> [128, 40] [] [40]
Compiling the Model
Output, Target shapes: [16, 50, 5] [16, 50, 5]
Done
Traceback (most recent call last):
  File "/home/cr/PycharmProjects/NTM-One-Shot-TF-master/Omniglot.py", line 113, in <module>
    omniglot()
  File "/home/cr/PycharmProjects/NTM-One-Shot-TF-master/Omniglot.py", line 87, in omniglot
    for i, (batch_input, batch_output) in generator:
  File "/home/cr/PycharmProjects/NTM-One-Shot-TF-master/MANN/Utils/Generator.py", line 36, in next
    return (self.num_iter - 1), self.sample(self.nb_samples)
  File "/home/cr/PycharmProjects/NTM-One-Shot-TF-master/MANN/Utils/Generator.py", line 56, in sample
    for (filename, angle, shift) in zip(image_files, angles, shifts)], dtype=np.float32)
  File "/home/cr/PycharmProjects/NTM-One-Shot-TF-master/MANN/Utils/Images.py", line 29, in load_transform
    original = imread(image_path, flatten=True)
  File "/usr/local/lib/python2.7/dist-packages/scipy-0.19.1-py2.7-linux-x86_64.egg/scipy/misc/pilutil.py", line 156, in imread
    im = Image.open(name)
  File "/usr/lib/python2.7/dist-packages/PIL/Image.py", line 1996, in open
    fp = builtins.open(fp, "rb")
IOError: [Errno 21] Is a directory: './data/omniglot/images_background/Blackfoot_(Canadian_Aboriginal_Syllabics)/character11'
Training the model

Process finished with exit code 1

How can I fix this problem?
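The path in the error ends in a character folder (`.../character11`) rather than an image file, so the loader is treating a directory as an image. One plausible explanation, not confirmed by the repository, is that the generator expects one fewer directory level than the extracted dataset has. For example, it may expect alphabets directly under `data/omniglot/` rather than under `data/omniglot/images_background/`.

The defensive sketch below collects only regular `.png` files from a folder, so a nested directory never reaches `imread`. The helper name is hypothetical.

```python
import os
import tempfile

def list_image_files(character_folder, extensions=(".png",)):
    """Hypothetical helper: return only regular image files inside a folder,
    so a nested directory can never be passed to imread()."""
    return sorted(
        os.path.join(character_folder, name)
        for name in os.listdir(character_folder)
        if os.path.isfile(os.path.join(character_folder, name))
        and name.lower().endswith(extensions)
    )

# Demo: a folder holding one image and one nested directory.
with tempfile.TemporaryDirectory() as folder:
    open(os.path.join(folder, "0001_01.png"), "wb").close()
    os.mkdir(os.path.join(folder, "character11"))
    files = [os.path.basename(p) for p in list_image_files(folder)]
print(files)  # ['0001_01.png']
```

This filter only hides the symptom. If every list comes back empty, the folder layout is what needs fixing.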

Questions about batchsize and memory

Hi, I'm a little confused about the memory module.
The shape of the memory in your code is Batchsize x Memory_size x Slotsize (or vector_dim). So when training finishes, Batchsize memories have been generated, not just one.
That means that changing the value of Batchsize changes the memory, which doesn't seem to make sense.
During testing, Batchsize has to be the same as during training. So if I get a batch that contains N samples (N < Batchsize), which memory should each sample use? (When N = Batchsize, each sample uses its own memory.)

The key question, I think, is why there should be more than one memory.

Thank you.
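In the paper's episodic setup, memory holds the sample-label bindings of a single episode, and only the controller weights are learned across episodes. Under that reading, the leading batch dimension just gives each episode in the batch its own scratch memory. Because the memory itself is not trained, a batch of N < batch_size episodes only needs N memories. If the fixed test batch size comes from a static batch dimension in the graph, deriving the initial memory's batch dimension from the actual input is one way around it.

The following is a minimal NumPy sketch, assuming memory is reset to a small constant at the start of every episode. The 1e-6 value, the input width of 400, and the function names are illustrative assumptions.

```python
import numpy as np

def init_memory(batch_size, nb_slots, slot_size, value=1e-6):
    """Fresh memory for each episode in the batch, shape (batch, N, M).
    Nothing here is learned, so the same controller works for any batch size."""
    return np.full((batch_size, nb_slots, slot_size), value, dtype=np.float32)

def init_memory_like(inputs, nb_slots, slot_size, value=1e-6):
    """Take the batch dimension from the actual input batch instead of a fixed constant."""
    return init_memory(inputs.shape[0], nb_slots, slot_size, value)

train_batch = np.zeros((16, 50, 400), dtype=np.float32)   # 16 episodes of length 50
test_batch = np.zeros((3, 50, 400), dtype=np.float32)     # only 3 episodes
print(init_memory_like(train_batch, 128, 40).shape)  # (16, 128, 40)
print(init_memory_like(test_batch, 128, 40).shape)   # (3, 128, 40)
```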

Memory tends to be the same? And the confusion about equation

Hi, thanks for your code. I'm using it for a sequence prediction task, and I've found that the vectors in different memory slots tend to become the same. Do you know how to fix this?
Also, in the MANN I see w_write(t) = w_read(t-1) + w_lt(t). Why use the read weights from timestep t-1 rather than from timestep t? I would expect w_read(t) to be more closely related to w_write(t). Is there a particular reason for this?
Thanks
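In the paper's Least Recently Used Access (LRUA) module, the write weights are w^w_t = σ(α) w^r_{t-1} + (1 - σ(α)) w^lu_{t-1}. Here α is a learned scalar gate, and w^lu marks the least-used slots, derived from decayed usage weights. Both terms come from step t-1 because, in that formulation, the write at step t happens before the read at step t: the read weights w^r_t are computed from the already-updated memory M_t, so they are not yet available when writing. The gate then chooses between two options:

  • Updating the most recently read slot, to refine an existing binding.
  • Writing to a rarely used slot, to store a new binding.

The NumPy sketch below follows those two equations. The function names and example numbers are illustrative assumptions, not the repository's code.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def least_used_weights(w_u, n):
    """w_lu(i) = 1 if usage w_u(i) is at most the n-th smallest usage value, else 0,
    matching the paper's comparison against m(w_u, n)."""
    threshold = np.sort(w_u)[n - 1]
    return (w_u <= threshold).astype(w_u.dtype)

def write_weights(alpha, w_r_prev, w_lu_prev):
    """LRUA write weights: a sigmoid gate mixes the previous read weights
    and the previous least-used weights (both from step t-1)."""
    g = sigmoid(alpha)
    return g * w_r_prev + (1.0 - g) * w_lu_prev

w_u_prev = np.array([0.9, 0.1, 0.5, 0.0])    # usage per slot at t-1
w_r_prev = np.array([0.0, 0.0, 1.0, 0.0])    # slot 2 was read at t-1
w_lu_prev = least_used_weights(w_u_prev, n=1)  # slot 3 has the lowest usage
w_w = write_weights(alpha=0.0, w_r_prev=w_r_prev, w_lu_prev=w_lu_prev)
print(w_lu_prev)
print(w_w)  # gate = 0.5: half on last-read slot 2, half on least-used slot 3
```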

Results Explainability

Hi,

Thanks for creating this; it works well. However, I'm having trouble understanding the output of the testing script. Could you explain what I'm looking at?
[Screenshot from 2019-09-13 10-34-48]

Warm Regards

When I run the code, I get this type error:

Describe the bug
A clear and concise description of what the bug is.

To Reproduce
Steps to reproduce the behavior:

  1. Go to '...'
  2. Click on '....'
  3. Scroll down to '....'
  4. See error

Expected behavior
A clear and concise description of what you expected to happen.

Screenshots
If applicable, add screenshots to help explain your problem.

Desktop (please complete the following information):

  • OS: [e.g. iOS]
  • Browser [e.g. chrome, safari]
  • Version [e.g. 22]

Smartphone (please complete the following information):

  • Device: [e.g. iPhone6]
  • OS: [e.g. iOS8.1]
  • Browser [e.g. stock browser, safari]
  • Version [e.g. 22]

Additional context
Add any other context about the problem here.

Error while running Omniglot.py

I receive the following error when running it:

Traceback (most recent call last):
  File "/root/tensorflow/anomaly-detector/NTM-One-Shot-TF/Omniglot.py", line 112, in <module>
    omniglot()
  File "/root/tensorflow/anomaly-detector/NTM-One-Shot-TF/Omniglot.py", line 85, in omniglot
    for i, (batch_input, batch_output) in generator:
  File "/root/tensorflow/anomaly-detector/NTM-One-Shot-TF/MANN/Utils/Generator.py", line 36, in next
    return (self.num_iter - 1), self.sample(self.nb_samples)
  File "/root/tensorflow/anomaly-detector/NTM-One-Shot-TF/MANN/Utils/Generator.py", line 56, in sample
    for (filename, angle, shift) in zip(image_files, angles, shifts)], dtype=np.float32)
  File "/root/tensorflow/anomaly-detector/NTM-One-Shot-TF/MANN/Utils/Images.py", line 29, in load_transform
    original = imread(image_path, flatten=True)
  File "/usr/local/lib/python2.7/dist-packages/scipy/misc/pilutil.py", line 156, in imread
    im = Image.open(name)
  File "/usr/lib/python2.7/dist-packages/PIL/Image.py", line 2251, in open
    fp = builtins.open(fp, "rb")
IOError: [Errno 21] Is a directory: './data/omniglot/images_background/Japanese_(katakana)/character45'

The last two directories in the path may vary.
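Both "Is a directory" reports fail on a path of the form `images_background/<alphabet>/<character>` that is opened as if it were an image. This suggests the generator expects images two directory levels below `data/omniglot/` (in `<alphabet>/<character>/`), while the extracted dataset adds an extra `images_background` level. That reading is an inference from the tracebacks, not confirmed by the repository.

The quick diagnostic below reports how many directory levels below a root the `.png` files sit. It is written as a hypothetical standalone helper, not part of the repository. Running it as `png_depths("./data/omniglot")` and getting a depth of 3 rather than 2 would point to the extra level.

```python
import os
import tempfile
from collections import Counter

def png_depths(root):
    """Count .png files by how many directory levels below root they sit."""
    depths = Counter()
    for dirpath, _dirnames, filenames in os.walk(root):
        rel = os.path.relpath(dirpath, root)
        depth = 0 if rel == "." else len(rel.split(os.sep))
        depths[depth] += sum(1 for f in filenames if f.lower().endswith(".png"))
    return {d: n for d, n in depths.items() if n}   # drop levels with no images

# Demo: images_background/<alphabet>/<character>/<image>.png under the root.
with tempfile.TemporaryDirectory() as root:
    char_dir = os.path.join(root, "images_background", "Japanese_(katakana)", "character45")
    os.makedirs(char_dir)
    open(os.path.join(char_dir, "0001_01.png"), "wb").close()
    result = png_depths(root)
print(result)  # {3: 1}
```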
