unified-io-inference's Introduction

UnifiedIO

This repo contains code to run models from our paper Unified-IO: A Unified Model for Vision, Language, and Multi-Modal Tasks.

Installation

Install JAX; note that this might require manually installing the CUDA and cuDNN toolkits if using GPUs.
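For example, at the time of writing a CUDA-enabled build could be installed with something like the following (the exact package extra and wheel index depend on your CUDA/cuDNN versions, so check the official JAX installation guide for the command matching your setup):

pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html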

Then install the supporting libraries with:

pip install -r requirements.txt

Model weights

Model weights can be found on AWS S3.

To download, run:

wget https://ai2-prior-uio.s3.us-west-2.amazonaws.com/public/model-weights-bin/small_1000k.bin -O small.bin

or download with aws-cli:

aws s3 cp s3://ai2-prior-uio/public/model-weights-bin/small_1000k.bin small.bin 

Usage

Download an image to test on:

wget https://farm2.staticflickr.com/1362/1261465554_95741e918b_z.jpg -O dbg_img.png

Tasks can then be done using the ModelRunner class:

from uio import runner
from PIL import Image
import numpy as np

model = runner.ModelRunner("small", "small.bin")

with Image.open("dbg_img.png") as img:
  image = np.array(img.convert('RGB'))

# Answer a VQA question, note this might take over a minute the first time it is 
# called while the function is compiled by jax
output = model.vqa(image, "What color is the sofa?")
print(output["text"])  # Should print `green`

This example can be run end-to-end by demo_script.py. ModelRunner supports many more tasks; examples can be seen in the demo notebook.

ModelRunner also provides a lower-level API that can be called with arbitrary text/image inputs and can generate text and/or image outputs, as well as supporting batched input:

out = model.run([image], ["What is the depth map of the image ?"], 
               output_text_len=1, generate_image=True, num_decodes=None)
depth_image = out["image"][0]
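The same call can generate text for a batch of inputs. A minimal sketch, assuming run returns generated text under a "text" key (mirroring the vqa output above) when output_text_len is large enough:

images = [image, image]
prompts = ["What color is the sofa?", "What is in this image?"]
out = model.run(images, prompts, output_text_len=32, generate_image=False, num_decodes=None)
print(out["text"])  # one generated answer per input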

Demo notebook

More tasks are shown in demo.ipynb; running it additionally requires installing jupyter and matplotlib:

pip install matplotlib notebook

Then it can be run with:

jupyter notebook demo.ipynb

Just-in-time compilation

By default, ModelRunner compiles the underlying inference calls the first time they are used; this results in faster performance at a one-time cost. This can be disabled by setting the compile parameter to False. You can set the environment variable JAX_LOG_COMPILES=1 to see when a function is being compiled.
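For example, a minimal sketch (it assumes compile is accepted by the ModelRunner constructor; also note the environment variable must be set before jax is first imported):

import os
os.environ["JAX_LOG_COMPILES"] = "1"  # log a message whenever jax compiles a function

from uio import runner

# Disable the one-time compilation in exchange for slower inference calls
model = runner.ModelRunner("small", "small.bin", compile=False)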

Implementation Details

Running UnifiedIO on a task is a 4-step process:

  1. Convert task inputs into (image_input, prompt) pairs; the image_input can be None. This step is task-specific and involves things like selecting a prompt for the task or converting region locations into region-location tokens that are then embedded in the prompt.
  2. Preprocess these components: the image is preprocessed by utils.preprocess_image and the input prompt is converted into tokens using a T5Tokenizer.
  3. Run the model on these preprocessed inputs, done in model.py. This produces text tokens and/or a 256x256 image as output.
  4. Post-process the results. This step is task-specific and can involve converting the output tokens into text or image locations and/or resizing/cropping the output image.

In ModelRunner, run does steps 2 and 3, and the task-specific methods do steps 1 and 4 for their various tasks.
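As an illustration, a task-specific method such as vqa can be thought of as the following sketch (the internals shown here are simplified assumptions for illustration, not the actual implementation):

def vqa_sketch(model, image, question):
  # Step 1: build the (image_input, prompt) pair; for VQA the question itself is the prompt
  prompt = question
  # Steps 2 and 3: `run` preprocesses the image, tokenizes the prompt, and runs the model
  out = model.run([image], [prompt], output_text_len=32, generate_image=False)
  # Step 4: post-process; for VQA this is just extracting the generated answer text
  # (assumes `run` returns generated text under a "text" key)
  return {"text": out["text"][0]}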

The main neural network code itself can be found in modules.Transformer.

Hardware requirements

We have run the XL model on GPUs with 24GB of memory; GPUs with less memory should be able to run the smaller models but might not be able to run the XL model.

Citation

If you use this codebase, please cite:

@article{lu2022unified,
  title={Unified-IO: A Unified Model for Vision, Language, and Multi-Modal Tasks},
  author={Lu, Jiasen and Clark, Christopher and Zellers, Rowan and Mottaghi, Roozbeh and Kembhavi, Aniruddha},
  journal={arXiv preprint arXiv:2206.08916},
  year={2022}
}

unified-io-inference's Issues

Reverting to CPU

After running pip install -r requirements.txt, I found I was not able to run UIO on GPU or TPU. I was able to reproduce the issue on both Paperspace and Google Colab.

Would you be willing to offer some guidance?

How does UIO differentiate between instances of different classes?

How does UIO distinguish between instances of different categories during segmentation training? My understanding is: when UIO trains for segmentation, each sample only requires one category to be segmented, and randomly generated colors are enough to represent different instances. However, this cannot handle the situation where one image has requests for multiple categories.

Original error: UNIMPLEMENTED: DNN library is not found.

2022-10-13 06:34:12.209025: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:421] Loaded runtime CuDNN library: 8.0.5 but source was compiled with: 8.2.4. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.

Traceback (most recent call last):
  File "run.py", line 24, in <module>
    output = model.vqa(image, question)
  File "/home/phucpx/EVJVQA/uio/runner.py", line 297, in vqa
    generate_image=False, num_decodes=num_decodes)
  File "/home/phucpx/EVJVQA/uio/runner.py", line 190, in run
    return_all_decodes=True
  File "/home/phucpx/miniconda3/envs/phucpx/lib/python3.7/site-packages/flax/linen/module.py", line 402, in wrapped_module_method
    return self._call_wrapped_method(fun, args, kwargs)
  File "/home/phucpx/miniconda3/envs/phucpx/lib/python3.7/site-packages/flax/linen/module.py", line 705, in _call_wrapped_method
    y = fun(self, *args, **kwargs)
  File "/home/phucpx/EVJVQA/uio/model.py", line 388, in predict_batch_with_aux
    mutable=['cache'])
  File "/home/phucpx/EVJVQA/uio/network.py", line 1054, in __call__
    image_decoder_tokens = self.discrete_vae.get_codebook_indices(image_decoder_targets, vae_decode)  # 0 is the start token.
  File "/home/phucpx/EVJVQA/uio/network.py", line 348, in get_codebook_indices
    h = self.encoder(x, training)
  File "/home/phucpx/EVJVQA/uio/network.py", line 207, in __call__
    name='conv_in')(x)
  File "/home/phucpx/EVJVQA/uio/t5x_layers.py", line 540, in __call__
    precision=self.precision)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm for:
%cudnn-conv = (f32[1,256,256,128]{2,1,3,0}, u8[0]{0}) custom-call(f32[1,256,256,3]{2,1,3,0} %copy, f32[3,3,3,128]{1,0,2,3} %copy.1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01io->b01f, custom_call_target="__cudnn$convForward", metadata={op_name="jit(conv_general_dilated)/jit(main)/conv_general_dilated[window_strides=(1, 1) padding=((1, 1), (1, 1)) lhs_dilation=(1, 1) rhs_dilation=(1, 1) dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 3, 1, 2), rhs_spec=(3, 2, 0, 1), out_spec=(0, 3, 1, 2)) feature_group_count=1 batch_group_count=1 lhs_shape=(1, 256, 256, 3) rhs_shape=(3, 3, 3, 128) precision=None preferred_element_type=None]" source_file="/home/phucpx/EVJVQA/uio/t5x_layers.py" source_line=540}, backend_config="{"conv_result_scale":1,"activation_mode":"0","side_input_scale":0}"

Original error: UNIMPLEMENTED: DNN library is not found.

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false. Please also file a bug for the root cause of failing autotuning.

(The stack trace above excludes JAX-internal frames.)

Hi there, I got the above error when running with the pretrained Unified-IO model xl_100k.bin.
How can I fix it?

Thanks all

Table 3

Dear authors,
How can I reproduce the results in Table 3?

Flax backward compatibility

I got this error when running the demo:

ImportError: cannot import name 'optim' from 'flax'

The error goes away when I switch back to an older version of flax:

pip install flax==0.5.0

Object detection in inference

For object detection I followed what the paper says and provided the single prompt "What objects are in the image ?" to the model. But the model's output text does not contain any object localization, so it cannot be processed by ModelRunner._extract_boxes.
Could you please kindly tell me how to do object detection at inference time? Thanks!

Multiple object detection

I am using the runner to create detection outputs. Even though there are multiple objects of the same class, the model always outputs a single detection. Even if I use num_decodes > 1, it gives the same prediction multiple times. Is this a bug in the code, or is the model not trained to predict multiple objects in the same image?

Request for releasing testing code on COCO pose dataset.

Hello and thank you for your interesting and impressive work! I was wondering if it would be possible to release the testing code for the COCO pose dataset.

I have tried running the provided demo on each COCO original image and its corresponding ground truth bounding box. However, the resulting AP score I obtained was only 25. Could you please provide some insights or suggestions on any additional tricks or configurations that could potentially improve the results?

Thank you in advance for your assistance and looking forward to your response.

Loss function used in the multi-task training

Hi! Thank you for releasing the code.
I'd like to ask what loss function you used in the multi-task training. Is it a uniform loss function, or do different tasks have different loss functions?

JAX version

Hi! Thank you for releasing the code.
I'd like to ask what version of JAX, including CUDA and CuDNN versions, you used to run the code.
I followed the instructions for JAX installation via pip and tried to run the code with jax 0.3.17 and jaxlib 0.3.15+cuda11.cudnn805. The server I use has CUDA 11.6 and CuDNN 8.0.5. However, when I try to execute the notebook demo, I get an error, AttributeError: module 'jax' has no attribute '_src', that seems to be thrown by JAX itself.
