vmichals / figureqa-baseline

TensorFlow implementation of the CNN-LSTM, Relation Network and text-only baselines for the paper "FigureQA: An Annotated Figure Dataset for Visual Reasoning"

License: MIT License

Python 100.00%
tensorflow relation-network deep-learning neural-networks python3 figure-analysis visual-question-answering vqa relational-reasoning microsoft

figureqa-baseline's Introduction

FigureQA-Baseline

This repository contains the TensorFlow implementations of the CNN-LSTM, Relation Network and text-only baselines for our paper FigureQA: An Annotated Figure Dataset for Visual Reasoning [Project Page].

If you use code from this repository for your scientific work, please cite

Kahou, S. E., Michalski, V., Atkinson, A., Kadar, A., Trischler, A., & Bengio, Y. (2017). FigureQA: An annotated figure dataset for visual reasoning. arXiv preprint arXiv:1710.07300.

If you use the Relation Network implementation, please also cite

Santoro, A., Raposo, D., Barrett, D. G., Malinowski, M., Pascanu, R., Battaglia, P., & Lillicrap, T. (2017). A simple neural network module for relational reasoning. In Advances in neural information processing systems (pp. 4974-4983).

Getting started

The setup was tested with Python 3, TensorFlow 1.4 and 1.6.0-rc1. We recommend using the Anaconda Python Distribution.

  1. Create a virtual environment, e.g. via
conda create -p ~/venvs/figureqa python=3
  2. Activate the environment:
source activate ~/venvs/figureqa
  3. Install dependencies:
conda install numpy tqdm six matplotlib pandas
pip install tensorflow-gpu
  4. Download the FigureQA data set tar.gz archives (unextracted) into a directory named FigureQA (see the layout sketch below this list).

  5. Clone the baseline repository somewhere locally (here we're using $HOME/workspace):

mkdir -p ~/workspace
cd ~/workspace
git clone git@github.com:vmichals/FigureQA-Baseline.git
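
Regarding step 4 above: the training script extracts the archives itself, so leave them unextracted. Based on the file names the script reports while extracting (see the log further below), the FigureQA directory is expected to contain roughly the following; how exactly it nests under the data path passed to the training script is an assumption:

FigureQA/
    figureqa-train1-v1.tar.gz
    figureqa-sample-train1-v1.tar.gz
    figureqa-validation1-v1.tar.gz
    figureqa-validation2-v1.tar.gz
    figureqa-test1-v1.tar.gz
    figureqa-test2-v1.tar.gz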

Training and Evaluation

Training a model

Run the training script for the model. It takes the following required arguments:

  • --data-path: the directory in which you placed the tar.gz archives of FigureQA, referred to as DATAPATH in the following.
  • --tmp-path: a temporary directory in which the script will extract the data (preferably on fast storage, such as an SSD or a RAM disk), from now on referred to as TMPPATH
  • --model: the model you want to train (rn, cnn or text), from now on referred to as MODEL
  • --num-gpus: the number of GPUs to use (on the same machine), from now on referred to as NUMGPU
  • --val-set: the validation set to use for early stopping (validation1 or validation2), from now on referred to as VALSET
  • (additional configuration options can be found in the *.json files in the cfg subfolder)
cd ~/workspace/FigureQA-Baseline
python -u train_model_on_figureqa.py --data-path DATAPATH --tmp-path TMPPATH \
    --model MODEL --num-gpus NUMGPU --val-set VALSET
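
For example, a concrete invocation might look like the following; the data and temporary paths here are placeholders, not paths shipped with the repository:

cd ~/workspace/FigureQA-Baseline
python -u train_model_on_figureqa.py --data-path ~/data/FigureQA --tmp-path /tmp/figureqa \
    --model rn --num-gpus 1 --val-set validation1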

Resuming interrupted training

To resume interrupted training, run the resume script, which takes the following required arguments:

  • --data-path: same as for the training script
  • --tmp-path: same as for the training script
  • --train-dir: the training directory created by the training script (a subfolder of train_dir), from now on referred to as TRAINDIR
  • --num-gpus: same as for the training script
  • --resume-step: the time step from which to resume (check the training directory for the model-TIMESTEP.meta file with the largest TIMESTEP; see the snippet after the resume command below), from now on referred to as RESUMESTEP
cd ~/workspace/FigureQA-Baseline
python -u resume_train_model_on_figureqa.py --train-dir TRAINDIR --resume-step RESUMESTEP \
    --num-gpus NUMGPU --data-path DATAPATH --tmp-path TMPPATH
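
To find the largest TIMESTEP, listing the checkpoint meta files in the training directory is usually enough; for example (assuming GNU coreutils, with TRAINDIR a placeholder as above):

ls -1v TRAINDIR/model-*.meta | tail -n 1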

Testing

To evaluate a trained model, run the eval script, which takes the following required arguments:

  • --train-dir: same as for the resume script
  • --meta-file: the meta file of the trained model, e.g. "model_val_best.meta" from the training directory, from now on referred to as METAFILE
  • --data-path: same as for the resume and training script
  • --tmp-path: same as for the resume and training script

Example:

cd ~/workspace/FigureQA-Baseline
python -u ./eval_model_on_figureqa.py --train-dir TRAINDIR --meta-file METAFILE \
    --partition test2 --data-path DATAPATH --tmp-path TMPPATH
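
For instance, to also produce predictions for test1 using the checkpoint that performed best on the validation set (model_val_best.meta, mentioned above), with the same placeholder paths as in the training example:

python -u ./eval_model_on_figureqa.py --train-dir TRAINDIR --meta-file model_val_best.meta \
    --partition test1 --data-path ~/data/FigureQA --tmp-path /tmp/figureqa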

For the test1 and test2 partitions, the script will dump your predictions to a csv file. To get the test accuracy, please submit the file here and we will get back to you with the results as soon as possible.

figureqa-baseline's Issues

OP_REQUIRES failed at iterator_ops.cc:902 : Invalid argument: buffer_size must be greater than zero.

mldl@ub1604:~/ub16_prj/FigureQA-baseline$ python3 -u train_model_on_figureqa.py --data-path ../../data/ --tmp-path tmp/ --model cnn --num-gpus 0 --val-set validation1
extracting tmp/FigureQA/figureqa-validation1-v1.tar.gz...
extracting tmp/FigureQA/figureqa-sample-train1-v1.tar.gz...
extracting tmp/FigureQA/figureqa-test2-v1.tar.gz...
extracting tmp/FigureQA/figureqa-test1-v1.tar.gz...
extracting tmp/FigureQA/figureqa-validation2-v1.tar.gz...
extracting tmp/FigureQA/figureqa-train1-v1.tar.gz...
loading training data...
trying to load dictionary from figureqa_dict.json...
building inverse dictionary...
tokenizing questions...
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1327368/1327368 [00:09<00:00, 139195.52it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1327368/1327368 [00:03<00:00, 372262.48it/s]
creating sparse tensor for questions...
adding image filenames and questions to sub dataset list...
WARNING:tensorflow:From /home/mldl/ub16_prj/FigureQA-baseline/data/figureqa.py:178: Dataset.from_sparse_tensor_slices (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.data.Dataset.from_tensor_slices().
adding answers to sub dataset list...
zipping up sub datasets...
shuffling dataset...
adding input parser to dataset pipeline
batching...
loading validation data...
trying to load dictionary from figureqa_dict.json...
building inverse dictionary...
tokenizing questions...
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 265106/265106 [00:01<00:00, 156351.44it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 265106/265106 [00:00<00:00, 547474.07it/s]
creating sparse tensor for questions...
adding image filenames and questions to sub dataset list...
adding answers to sub dataset list...
zipping up sub datasets...
shuffling dataset...
adding input parser to dataset pipeline
batching...
building model graph...
2018-07-24 02:12:50.517562: W tensorflow/core/framework/allocator.cc:108] Allocation of 170922432 exceeds 10% of system memory.
2018-07-24 02:12:52.280888: W tensorflow/core/framework/allocator.cc:108] Allocation of 170922432 exceeds 10% of system memory.
2018-07-24 02:12:52.368087: W tensorflow/core/framework/allocator.cc:108] Allocation of 170922432 exceeds 10% of system memory.
2018-07-24 02:12:54.096779: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2018-07-24 02:13:03.779639: W tensorflow/core/framework/allocator.cc:108] Allocation of 170922432 exceeds 10% of system memory.
2018-07-24 02:13:05.668674: W tensorflow/core/framework/allocator.cc:108] Allocation of 170922432 exceeds 10% of system memory.
2018-07-24 02:13:22.689298: W tensorflow/core/framework/op_kernel.cc:1318] OP_REQUIRES failed at iterator_ops.cc:902 : Invalid argument: buffer_size must be greater than zero.
[[Node: ShuffleDataset = ShuffleDataset[output_shapes=[[], [?,1], [?], [1], [], []], output_types=[DT_STRING, DT_INT64, DT_INT32, DT_INT64, DT_INT64, DT_INT64], reshuffle_each_iteration=true](ZipDataset, ShuffleDataset/buffer_size, ShuffleDataset/buffer_size, ShuffleDataset/buffer_size)]]
Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1322, in _do_call
return fn(*args)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1307, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1409, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: buffer_size must be greater than zero.
[[Node: ShuffleDataset = ShuffleDataset[output_shapes=[[], [?,1], [?], [1], [], []], output_types=[DT_STRING, DT_INT64, DT_INT32, DT_INT64, DT_INT64, DT_INT64], reshuffle_each_iteration=true](ZipDataset, ShuffleDataset/buffer_size, ShuffleDataset/buffer_size, ShuffleDataset/buffer_size)]]
[[Node: OneShotIterator = OneShotIteratorcontainer="", dataset_factory=_make_dataset_rQdfMWzAEbQ[], output_shapes=[[?,256,256,3], [?,?], [?], [?]], output_types=[DT_FLOAT, DT_INT32, DT_INT64, DT_INT64], shared_name="", _device="/job:localhost/replica:0/task:0/device:CPU:0"]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "train_model_on_figureqa.py", line 194, in
train_handle = sess.run(train_iterator.string_handle())
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 900, in run
run_metadata_ptr)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1135, in _run
feed_dict_tensor, options, run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1316, in _do_run
run_metadata)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1335, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: buffer_size must be greater than zero.
[[Node: ShuffleDataset = ShuffleDataset[output_shapes=[[], [?,1], [?], [1], [], []], output_types=[DT_STRING, DT_INT64, DT_INT32, DT_INT64, DT_INT64, DT_INT64], reshuffle_each_iteration=true](ZipDataset, ShuffleDataset/buffer_size, ShuffleDataset/buffer_size, ShuffleDataset/buffer_size)]]
[[Node: OneShotIterator = OneShotIteratorcontainer="", dataset_factory=_make_dataset_rQdfMWzAEbQ[], output_shapes=[[?,256,256,3], [?,?], [?], [?]], output_types=[DT_FLOAT, DT_INT32, DT_INT64, DT_INT64], shared_name="", _device="/job:localhost/replica:0/task:0/device:CPU:0"]]
mldl@ub1604:~/ub16_prj/FigureQA-baseline$

Trying to use the RN code for CLEVR dataset which results in overfitting.

Hi,

Not sure if this is the right place to ask this question, but I adapted the RN code to train on the CLEVR dataset. I set up the hyperparameters by following the paper but I am facing the issue of overfitting. The validation accuracy gets stuck at 50%.

Any suggestions on this will be helpful.
