
coca-pytorch's Introduction

CoCa - Pytorch

Implementation of CoCa, Contrastive Captioners are Image-Text Foundation Models, in Pytorch. The authors elegantly fold contrastive learning into a conventional encoder / decoder (image-to-text) transformer, achieving SOTA 91.0% top-1 accuracy on ImageNet with a finetuned encoder.

This repository also adopts the specific transformer architecture from PaLM for both the unimodal and multimodal transformers, as well as the cross attention blocks (parallel SwiGLU feedforwards).
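Concretely, a parallel block computes self attention and the SwiGLU feedforward from one shared pre-norm of the input and sums both into the residual stream, rather than stacking them sequentially as in a vanilla transformer. A minimal sketch of that pattern (an illustration, not the repo's exact module):

import torch
import torch.nn.functional as F
from torch import nn

class ParallelSwiGLUBlock(nn.Module):
    # PaLM-style parallel layer: out = x + attn(norm(x)) + swiglu_ff(norm(x))
    def __init__(self, dim, heads = 8, ff_mult = 4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)
        inner = dim * ff_mult
        self.ff_in = nn.Linear(dim, inner * 2, bias = False)  # gate and value branches
        self.ff_out = nn.Linear(inner, dim, bias = False)

    def forward(self, x):
        normed = self.norm(x)  # one shared pre-norm feeds both branches
        attn_out, _ = self.attn(normed, normed, normed)
        gate, value = self.ff_in(normed).chunk(2, dim = -1)
        ff_out = self.ff_out(F.silu(gate) * value)  # SwiGLU: swish(gate) * value
        return x + attn_out + ff_out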

Update: CoCa has been trained by the good folks over at OpenCLIP

Install

$ pip install coca-pytorch

Usage

First install vit-pytorch for the image encoder, which needs to be pretrained:

$ pip install "vit-pytorch>=0.40.2"

Then

import torch

# import vision transformer

from vit_pytorch.simple_vit_with_patch_dropout import SimpleViT
from vit_pytorch.extractor import Extractor

vit = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    patch_dropout = 0.5  # https://arxiv.org/abs/2212.00794
)

vit = Extractor(vit, return_embeddings_only = True, detach = False)

# the extractor makes the vision transformer return its embeddings

# import CoCa and instantiate it

from coca_pytorch.coca_pytorch import CoCa

coca = CoCa(
    dim = 512,                     # model dimension
    img_encoder = vit,             # vision transformer - image encoder, returning image embeddings as (batch, seq, dim)
    image_dim = 1024,              # image embedding dimension, if not the same as model dimensions
    num_tokens = 20000,            # number of text tokens
    unimodal_depth = 6,            # depth of the unimodal transformer
    multimodal_depth = 6,          # depth of the multimodal transformer
    dim_head = 64,                 # dimension per attention head
    heads = 8,                     # number of attention heads
    caption_loss_weight = 1.,      # weight on the autoregressive caption loss
    contrastive_loss_weight = 1.,  # weight on the contrastive loss between image and text CLS embeddings
).cuda()

# mock text and images

text = torch.randint(0, 20000, (4, 512)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# train by giving CoCa your text and images with `return_loss = True`

loss = coca(
    text = text,
    images = images,
    return_loss = True  # set this to True to get the full caption + contrastive loss
)

loss.backward()

# do the above with as much text and image data as you can
# then you can get the caption logits like so

logits = coca(
    text = text,
    images = images
) # (4, 512, 20000)

# and the CLIP-like text and image embeddings as

text_embeds, image_embeds = coca(
    text = text,
    images = images,
    return_embeddings = True
) # (4, 512), (4, 512)

Citations

@inproceedings{Yu2022CoCaCC,
  title   = {CoCa: Contrastive Captioners are Image-Text Foundation Models},
  author  = {Jiahui Yu and Zirui Wang and Vijay Vasudevan and Legg Yeung and Mojtaba Seyedhosseini and Yonghui Wu},
  year    = {2022}
}
@inproceedings{Chowdhery2022PaLMSL,
  title   = {PaLM: Scaling Language Modeling with Pathways},
  author  = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
  year    = {2022}
}


coca-pytorch's Issues

Image tokens dimension mismatch with the layernorm bias dimension

File "/home/usr/anaconda3/envs/varpt13/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/usr/anaconda3/envs/varpt13/lib/python3.8/site-packages/coca_pytorch/coca_pytorch.py", line 246, in forward
context = self.context_norm(context)
File "/home/usr/anaconda3/envs/varpt13/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/usr/anaconda3/envs/varpt13/lib/python3.8/site-packages/coca_pytorch/coca_pytorch.py", line 25, in forward
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
File "/home/usr/anaconda3/envs/varpt13/lib/python3.8/site-packages/torch/nn/functional.py", line 2515, in layer_norm
return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: Expected weight to be of same shape as normalized_shape, but got weight of shape [1024] and normalized_shape = [1000]
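The shapes in the error are telling: the LayerNorm weight was built with the ViT embedding dimension (1024), while the incoming tensor has the classifier's output dimension (1000), which suggests the image encoder is returning class logits rather than patch embeddings. Wrapping the ViT as shown in the usage section above should resolve it (a sketch, assuming the same SimpleViT setup):

from vit_pytorch.extractor import Extractor

# return_embeddings_only makes the ViT return (batch, seq, 1024) patch embeddings
# instead of (batch, 1000) classifier logits
vit = Extractor(vit, return_embeddings_only = True, detach = False)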

Convolution encoder gives better results than ViT

Thank you for this work.
The generalization ability of neural networks based on convolution layers is much greater.

import torch
from efficientnet_pytorch import EfficientNet

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model = EfficientNet.from_pretrained('efficientnet-b4')
        # replace the classifier head so the backbone emits 1024-dim features
        self.model._fc = torch.nn.Linear(1792, 1024)
        # expand the single pooled vector into 128 "tokens"
        self.conv1D = torch.nn.Conv1d(1, 128, 3, padding='same')

    def forward(self, x):
        x = self.model(x)          # (batch, 1024)
        x = torch.unsqueeze(x, 1)  # (batch, 1, 1024)
        x = self.conv1D(x)         # (batch, 128, 1024)
        # return (batch, seq, dim)
        return x
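Presumably this encoder is then passed to CoCa in place of the ViT. A sketch under that assumption (the Conv1d above expands the single pooled vector into 128 tokens of dimension 1024, matching the expected (batch, seq, dim)):

coca = CoCa(
    dim = 512,
    img_encoder = Model(),  # the EfficientNet-based encoder above
    image_dim = 1024,       # matches the Conv1d output dimension
    num_tokens = 20000,
    unimodal_depth = 6,
    multimodal_depth = 6,
    dim_head = 64,
    heads = 8
)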

How to train the model using my own dataset?

Can someone tell me how to train the model using my own dataset? Is it like below? But I have many images and texts...

# train by giving CoCa your text and images with `return_loss = True`
loss = coca(
    text = text,
    images = images,
    return_loss = True  # set this to True to get the full caption + contrastive loss
)
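For context, training on your own dataset is exactly the snippet above wrapped in a loop over batches. A minimal sketch, assuming a DataLoader that yields (images, text) pairs already tokenized into the model's vocabulary (the dataset, optimizer choice, and hyperparameters here are placeholders, not part of this repo):

import torch
from torch.utils.data import DataLoader

# `dataset` yielding (images, text) pairs is assumed to exist;
# `coca` is the model instantiated as in the README above
dataloader = DataLoader(dataset, batch_size = 32, shuffle = True)
optimizer = torch.optim.Adam(coca.parameters(), lr = 3e-4)

for epoch in range(10):
    for images, text in dataloader:
        images, text = images.cuda(), text.cuda()
        loss = coca(text = text, images = images, return_loss = True)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()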

attn_mask

cls_mask = rearrange(text!=self.pad_id, 'b j -> b 1 j')  
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value=True)

attn_mask = rearrange(attn_mask, 'b i j -> b 1 i j')  
sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)

Hello, I am confused by the implementation of "attn_mask". I think this padding function can only mask the last row of "sim". Could you please explain it? Perhaps it's a very basic question. Thank you so much.
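For what it's worth, the shape arithmetic of that pad call can be checked directly: F.pad's (0, 1, seq, 0) tuple adds one True column on the right of the last dimension and seq all-True rows on top of the row dimension, so in the resulting (b, seq + 1, j + 1) mask only the final row carries cls_mask, exactly as the question observes:

import torch
import torch.nn.functional as F
from einops import rearrange

b, j, seq, pad_id = 2, 5, 5, 0  # illustrative sizes
text = torch.randint(0, 10, (b, j))
cls_mask = rearrange(text != pad_id, 'b j -> b 1 j')       # (b, 1, j)
attn_mask = F.pad(cls_mask, (0, 1, seq, 0), value = True)  # (b, seq + 1, j + 1)
print(attn_mask.shape)  # torch.Size([2, 6, 6])
# rows 0 .. seq-1 are all True (unmasked); only the last row reflects cls_mask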

Reproduction problem

Thanks for this repo. I'm using my own dataset for pre-training with CoCa, and I found that the contrastive loss is basically unchanged from batch to batch. Is it the contrastive_label that needs to be changed, or is there some other place where I need to make corresponding changes?
Thank you!

Generating the caption of a given image

Hello,

Thank you for having implemented this model. Have you already implemented some code to generate the caption of a given image? If not, do you have an idea of how you would do it in this particular architecture?

Thank you in advance.
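For reference, no generation helper appears in the usage section above, but a greedy decoder can be built on the caption logits. A minimal sketch, assuming token id 0 serves as the start token and eos_id matches your tokenizer (both are assumptions about your setup, not this repo's API):

import torch

@torch.no_grad()
def greedy_caption(coca, image, max_len = 64, sos_id = 0, eos_id = 1):
    # image: (1, 3, 256, 256); sos_id / eos_id are assumed tokenizer conventions
    tokens = torch.full((1, 1), sos_id, dtype = torch.long, device = image.device)
    for _ in range(max_len):
        logits = coca(text = tokens, images = image)  # (1, seq, num_tokens)
        next_token = logits[:, -1].argmax(dim = -1, keepdim = True)
        tokens = torch.cat((tokens, next_token), dim = -1)
        if next_token.item() == eos_id:
            break
    return tokens[:, 1:]  # drop the start token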

Why train the ViT visual encoder first?

Hi, thanks for sharing this repo. In the CoCa paper, both the visual encoder and text encoder are trained end-to-end. But in this repo, the ViT is first pretrained and then fixed to train CoCa.

register_buffer for masks and position encodings breaks DDP

Hi!

Unfortunately, using buffers to cache masks and pos encodings fails when running with DDP.

def get_mask(self, n, device):
    if self.mask is not None and self.mask.shape[-1] >= n:
        return self.mask[:n, :n]
    mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
    self.register_buffer("mask", mask, persistent=False)
    return mask

def get_rotary_embedding(self, n, device):
    if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
        return self.pos_emb[:n]
    pos_emb = self.rotary_emb(n, device=device)
    self.register_buffer("pos_emb", pos_emb, persistent=False)
    return pos_emb

Each rank has a different sequence length because text comes in different sizes. PyTorch buffers are synced by DDP but fail to be reduced since the tensors have different dims on each rank.

I found that using buffers is redundant here anyway, since we don't store them in the state_dict (persistent=False). Unless you know of a good reason why buffers are preferable that I am missing, @lucidrains?

This code worked in DDP setting:

class ParallelTransformerBlock(nn.Module):
    ...
        self.mask = None
        self.pos_emb = None

    def get_mask(self, n: int, device: torch.device) -> Tensor:
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.mask = mask
        return mask

    def get_rotary_embedding(self, n: int, device: torch.device) -> Tensor:
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        self.pos_emb = pos_emb
        return pos_emb

Thanks,
George

LayerNorm after attentional_pooler

Hi, if I understand correctly, a single LayerNorm is applied to all the queries output by the attentional pooler. However, in the paper it seems like they use a different one for the query used by the contrastive loss and for those used as context by the multimodal part. Does it make a difference, or is it the same, or am I just wrong?
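For reference, the variant the question describes would look roughly like this: one LayerNorm for the single contrastive query and a separate one for the caption-context queries (a sketch with illustrative names, not this repo's code):

import torch
from torch import nn

class SplitPoolerNorms(nn.Module):
    # separate norms for the two uses of the attentional pooler's outputs
    def __init__(self, dim):
        super().__init__()
        self.cls_norm = nn.LayerNorm(dim)      # for the contrastive query
        self.context_norm = nn.LayerNorm(dim)  # for the multimodal-context queries

    def forward(self, pooled):  # pooled: (batch, num_queries + 1, dim)
        cls_embed, context = pooled[:, 0], pooled[:, 1:]
        return self.cls_norm(cls_embed), self.context_norm(context)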

Contrastive loss should be applied to L2-normed embeddings instead of layer normed?

Hi @lucidrains, thank you for the implementation. Just wanted to confirm this with you: based on your code, we normalize the image embedding and text embedding respectively using a learnable LayerNorm transformation before applying the contrastive loss. But to my understanding, contrastive loss typically maximizes relative cosine similarity, so the embeddings should be L2-normed instead of layer-normed? Thank you.
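For reference, the CLIP-style formulation the question alludes to L2-normalizes both embeddings so the pairwise logits are cosine similarities. A minimal sketch (the temperature value is illustrative):

import torch
import torch.nn.functional as F

def clip_contrastive_loss(text_embeds, image_embeds, temperature = 0.07):
    # L2-normalize so the dot products below are cosine similarities
    text_embeds = F.normalize(text_embeds, dim = -1)
    image_embeds = F.normalize(image_embeds, dim = -1)
    sim = text_embeds @ image_embeds.t() / temperature  # (batch, batch)
    labels = torch.arange(sim.shape[0], device = sim.device)
    # symmetric InfoNCE: match each text to its image and vice versa
    return (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2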

Reproducing the results in the paper

Thanks for this repo. Curious: is this an independent implementation of the CoCa paper? If yes, did you reproduce any results in the paper to verify the correctness of the implementation?
