
xaviolo99 / dual-transfer-caption-generation

4 stars · 2 watchers · 2 forks · 3.36 MB

This repository contains the notebooks used to train and evaluate an image-captioning model: users input images and get captions back. It combines pretrained EfficientNet and GPT-2 models with a small model that merges their outputs.

Jupyter Notebook 99.94% Python 0.06%
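The fusion described above can be sketched roughly as follows. This is a minimal illustration, not the repository's actual code: the class name, layer names, and feature dimensions are all assumptions (1280 is EfficientNet-B0's feature width, 768 is GPT-2's embedding size).

```python
import torch
import torch.nn as nn

class CaptionFusion(nn.Module):
    """Hypothetical sketch: project CNN image features into the language
    model's embedding space and prepend them to the token embeddings."""

    def __init__(self, image_dim=1280, embed_dim=768, vocab_size=50257):
        super().__init__()
        # Stand-ins for the pretrained EfficientNet / GPT-2 components.
        self.image_proj = nn.Linear(image_dim, embed_dim)      # "merge" model
        self.token_embed = nn.Embedding(vocab_size, embed_dim)  # GPT-2 embeddings

    def forward(self, image_features, token_ids):
        img = self.image_proj(image_features).unsqueeze(1)  # (B, 1, E)
        tok = self.token_embed(token_ids)                   # (B, T, E)
        # The combined sequence would then be fed to the GPT-2 transformer.
        return torch.cat([img, tok], dim=1)                 # (B, 1 + T, E)

fused = CaptionFusion()(torch.randn(2, 1280), torch.randint(0, 50257, (2, 8)))
print(fused.shape)  # torch.Size([2, 9, 768])
```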

dual-transfer-caption-generation's People

Contributors: xaviolo99

Stargazers: viriya and 3 others

Watchers: James Cloos and 1 other

Forkers: kukuhaza, josutk

dual-transfer-caption-generation's Issues

Train process

Hi!
Thanks for sharing your work!
When I try to train this model, I run into the problem below. Can you tell me what the problem is? Thank you!
I am using Python 3.6, PyTorch 1.1, and CUDA 10.0.130.


Epoch 1

```
RuntimeError                              Traceback (most recent call last)
in
     11 for idx, (images, texts, masks) in enumerate(train_generator):
     12 
---> 13     loss, outputs, _ = model(texts, images, labels=texts, attention_mask=masks)
     14 
     15     sum_loss += loss.item()

~/miniconda3/envs/dual/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

~/miniconda3/envs/dual/lib/python3.6/site-packages/transformers/modeling_gpt2.py in forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, labels, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)
    776             output_attentions=output_attentions,
    777             output_hidden_states=output_hidden_states,
--> 778             return_dict=return_dict,
    779         )
    780         hidden_states = transformer_outputs[0]

~/miniconda3/envs/dual/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

~/miniconda3/envs/dual/lib/python3.6/site-packages/transformers/modeling_gpt2.py in forward(self, input_ids, past_key_values, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions, output_hidden_states, return_dict, **kwargs)
    651                 encoder_attention_mask=encoder_attention_mask,
    652                 use_cache=use_cache,
--> 653                 output_attentions=output_attentions,
    654             )
    655 

~/miniconda3/envs/dual/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

~/miniconda3/envs/dual/lib/python3.6/site-packages/transformers/modeling_gpt2.py in forward(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)
    289             head_mask=head_mask,
    290             use_cache=use_cache,
--> 291             output_attentions=output_attentions,
    292         )
    293         attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)

~/miniconda3/envs/dual/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

~/miniconda3/envs/dual/lib/python3.6/site-packages/transformers/modeling_gpt2.py in forward(self, hidden_states, layer_past, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask, use_cache, output_attentions)
    225         if layer_past is not None:
    226             past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
--> 227             key = torch.cat((past_key, key), dim=-1)
    228             value = torch.cat((past_value, value), dim=-2)
    229 

RuntimeError: invalid argument 0: Tensors must have same number of dimensions: got 4 and 2 at /tmp/pip-req-build-jh50bw28/aten/src/THC/generic/THCTensorMath.cu:62
```
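The thread does not resolve this, but the failing line concatenates a cached 4-D key tensor with a freshly computed key, and the error says the fresh key came out 2-D. That pattern usually points to a mismatch between the notebook's code and the installed transformers version, since the cache layout changed between releases (this is an inference, not confirmed by the repo). A minimal reproduction of the `torch.cat` constraint behind the error, with illustrative shapes:

```python
import torch

# torch.cat requires all inputs to have the same number of dimensions.
past_key = torch.zeros(1, 12, 64, 5)  # 4-D cached key (shapes illustrative)
new_key = torch.zeros(64, 5)          # 2-D key produced by a mismatched forward

try:
    torch.cat((past_key, new_key), dim=-1)
except RuntimeError as err:
    print("RuntimeError:", err)
```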

test error

Hi, thanks for your work.
I get an error when I test the model with your code here:
```python
MAX_LENGTH = 0
for image, tensor in test_generator:
    display(Image.fromarray(image[0].numpy()))
    #generate_some_text("<|endoftext|>", image=tensor)
    #print(tensor.shape, torch.Tensor(tokenizer.encode("<|endoftext|>")).unsqueeze(0).long().cuda().shape)

    beam_output = model.generate(
        torch.Tensor(tokenizer.encode("<|endoftext|>")).unsqueeze(0).long().cuda(),
        image=tensor,
        pad_token_id=tokenizer.eos_token_id,
        num_beams=256,
        no_repeat_ngram_size=2,
        early_stopping=True,
        num_return_sequences=1
    )

    if len(beam_output.shape) == 1:
        beam_output.unsqueeze(0)

    output = beam_output.to('cpu').numpy()[:, 1:].tolist()
    captions = [tokenizer.decode([t for t in o if t != tokenizer.eos_token_id])
                for o in output]

    for caption in captions:
        print(caption)

    break
```
And the error is:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
in
     12     no_repeat_ngram_size=2,
     13     early_stopping=True,
---> 14     num_return_sequences=1
     15 )
     16 

~/miniconda3/envs/dual/lib/python3.6/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     13     def decorate_context(*args, **kwargs):
     14         with self:
---> 15             return func(*args, **kwargs)
     16     return decorate_context
     17 

~/miniconda3/envs/dual/lib/python3.6/site-packages/transformers/generation_utils.py in generate(self, input_ids, decoder_input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, num_return_sequences, attention_mask, decoder_start_token_id, use_cache, **model_kwargs)
    487                 attention_mask=attention_mask,
    488                 use_cache=use_cache,
--> 489                 model_kwargs=model_kwargs,
    490             )
    491         else:

~/miniconda3/envs/dual/lib/python3.6/site-packages/transformers/generation_utils.py in _generate_beam_search(self, input_ids, cur_len, max_length, min_length, do_sample, early_stopping, temperature, top_k, top_p, repetition_penalty, no_repeat_ngram_size, bad_words_ids, pad_token_id, eos_token_id, batch_size, num_return_sequences, length_penalty, num_beams, vocab_size, attention_mask, use_cache, model_kwargs)
    663                 input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
    664             )
--> 665             outputs = self(**model_inputs, return_dict=True)  # (batch_size * num_beams, cur_len, vocab_size)
    666             next_token_logits = outputs.logits[:, -1, :]  # (batch_size * num_beams, vocab_size)
    667 

~/miniconda3/envs/dual/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

TypeError: forward() got an unexpected keyword argument 'return_dict'
```
What's wrong with this piece of code?
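The traceback shows that `_generate_beam_search` calls `self(**model_inputs, return_dict=True)`, so the installed transformers release forwards `return_dict` to the model, while the model's `forward` was written against an older API that does not accept it. One common workaround (an assumption here, not confirmed by the repo) is to accept and discard unknown keyword arguments:

```python
import torch.nn as nn

class CaptionModel(nn.Module):
    """Hypothetical wrapper illustrating only the signature fix."""

    def forward(self, input_ids=None, image=None, **kwargs):
        # **kwargs swallows arguments such as return_dict that newer
        # versions of generate() pass through to forward().
        kwargs.pop("return_dict", None)
        return input_ids  # real model logic would go here

model = CaptionModel()
out = model(input_ids=[1, 2, 3], return_dict=True)  # no TypeError
print(out)  # [1, 2, 3]
```

The other direction also works: pinning transformers to the version the notebooks were written for, though the thread does not state which version that is.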
