
magvit2-pytorch's Issues

Flash attention not working on A100 GPU

I'm trying to train the model on ImageNet, but I'm running into issues getting the model and data to fit in GPU memory. I'm using A100 GPUs, but when the trainer runs I get this error:

File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
  return forward_call(*args, **kwargs)
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 385, in forward
  x = super().forward(x, *args, **kwargs)
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 375, in forward
  out = self.attend(q, k, v)
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1510, in _wrapped_call_impl
  return self._call_impl(*args, **kwargs)
File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _call_impl
  return forward_call(*args, **kwargs)
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/attend.py", line 235, in forward
  return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)
File "/projects/users/pfeiljx/magvit/magvit2_pytorch/attend.py", line 191, in flash_attn
  out = F.scaled_dot_product_attention(
RuntimeError: No available kernel.  Aborting execution

I think this is related to this issue: lucidrains/x-transformers#143

Is there a workaround for this issue?

Thank you!
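A possible workaround (my suggestion, not confirmed in the thread): if the fused flash kernel is rejecting this head dimension/dtype combination, you can force scaled_dot_product_attention onto the plain math backend. A minimal sketch using PyTorch's SDP backend controls:

import torch
import torch.nn.functional as F

# Hypothetical standalone repro/workaround: disable the fused flash and
# mem-efficient kernels so SDPA falls back to the (slower) math kernel.
q = k = v = torch.randn(1, 8, 1024, 64, device = 'cuda', dtype = torch.float16)

with torch.backends.cuda.sdp_kernel(enable_flash = False, enable_mem_efficient = False, enable_math = True):
    out = F.scaled_dot_product_attention(q, k, v)

Alternatively, if the library exposes a flag to construct its attention with flash attention disabled, that would avoid the failing code path entirely.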

About the correctness

Hello, thanks for this great implementation. A question: how can we make sure this code reproduces the results presented in the original paper, i.e., the FID/FVD or IS on the benchmarks? Has anyone measured these scores, and are there any pretrained models that achieve them?

Reconstruction image is always a solid color

Hello,

I've been working on training this on the ImageNet data, but I'm concerned I'm doing something wrong, because the reconstructions are always a solid color. I haven't trained it very long (~1,500 steps, batch size 10), but I just wanted to check whether this is expected.

1300 steps:
[reconstruction: solid color]

1200 steps:
[reconstruction: solid color]

from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

tokenizer = VideoTokenizer(
    image_size = 256,
    codebook_size=1_024,
    use_gan=True,
    use_fsq=True,
    init_dim=128, 
    adversarial_loss_weight=0.1, # From the paper
    perceptual_loss_weight=0.1, # From the paper
    grad_penalty_loss_weight=10.0,
    lfq_entropy_loss_weight=0.3, # From the paper
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'linear_attend_space',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/imagenet/TRAIN',
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 10,
    grad_accum_every = 8,
    num_train_steps = 1_000_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=1e-4, # From the paper
    accelerate_kwargs={"split_batches": True, "mixed_precision": 'fp16'},
    random_split_seed=171,
    optimizer_kwargs={"betas": (0.9, 0.99)}, # From the paper
    ema_kwargs={}
)

trainer.train()

Unsuccessful image reconstruction

I am running into almost the same issue as the closed thread #12. Not sure how it ended up being resolved. Turning off the GAN does not help. Any suggestions are greatly appreciated! Thanks.

The reconstructed images look like these:
[two reconstruction samples]

train on video dataset

Thanks a lot for your implementation! Can this tokenizer be trained on a video dataset in the current version? I found that its recon_loss is very large and does not converge, and discr_loss does not converge either.

Here are the losses on the video dataset:

[loss curves]

expired discord invitation

Hey @lucidrains and all, I'm working on reproducing this tokenizer and would like to join the discussion. Could you update the Discord invitation link? Thanks in advance!

Pixelated image reconstruction

Trained the model on a set of 10K images (randomly cropped 128 x 128 patches) for 6K iterations with a batch size of 8. The output looks pixelated. Is this expected, or should the output look similar to the input? Any feedback is appreciated!

[reconstruction sample]

pretrained weights

Thanks for your great work! Is there any pretrained model that you can share?

‘video_contains_first_frame’ in encoder

Great work! But I think the code has a mistake. At line 1563 of magvit2_pytorch.py, I notice the authors use left padding, so the first frame should be video[:, :, self.time_padding], and the remaining video should be video[:, :, (self.time_padding + 1):]. Please check the code; if I have misunderstood, please point that out.
I have another question. When training on an image dataset with this code, why are the reconstructed images identical for different inputs, whether the model has randomly initialized parameters or is trained? The loss decreases normally, yet the reconstructions are all the same, unless they are completely black or another solid color. How can this be solved? Thanks!
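For reference, a toy illustration (mine, not the repo's code) of the slicing question above: with causal left-padding along time, the original first frame sits at index time_padding of the padded tensor.

import torch
import torch.nn.functional as F

time_padding = 2
video = torch.randn(1, 3, 5, 8, 8)                    # (B, C, T, H, W)
padded = F.pad(video, (0, 0, 0, 0, time_padding, 0))  # left-pad the time axis only

first_frame = padded[:, :, time_padding]              # the original frame 0
rest = padded[:, :, time_padding + 1:]                # original frames 1..T-1

assert torch.equal(first_frame, video[:, :, 0])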

Discriminator loss converges to zero early in training

I compared v0.1.26 without the GAN and v0.1.36 with the GAN on the Fashion-MNIST data, and was able to get better reconstructions without the GAN:
https://api.wandb.ai/links/pfeiljx/f7wdueh0

Do you have any suggestions for improving training?

I'm using a cosine scheduler for the model and discriminator. Should I use a different learning rate schedule for the discriminator?

I saw similar discriminator collapse with VQ-GAN, and I read that delaying the discriminator until the generator is partially optimized may help. Maybe delay the discriminator until a certain reconstruction loss is achieved? (A rough sketch follows the config below.)

After googling some strategies, I saw the unrolled GAN, where the generator stays a few steps ahead of the discriminator. I'm not sure how difficult a similar strategy would be to implement here.

I'm just brainstorming, so feel free to address or ignore any of these comments.

import torch
from datetime import datetime
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

RUNTIME = datetime.now().strftime("%y%m%d_%H%M%S")

tokenizer = VideoTokenizer(
    image_size = 32,
    channels=1,
    use_gan=True,
    use_fsq=False,
    codebook_size=2**13,
    init_dim=64,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 10,
    grad_accum_every = 5,
    num_train_steps = 5_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=2e-5,
    accelerate_kwargs={"split_batches": True, "mixed_precision": "fp16"},
    random_split_seed=85,
    optimizer_kwargs={"betas": (0.9, 0.99)}, # From the paper
    ema_kwargs={},
    use_wandb_tracking=True,
    checkpoints_folder=f'./runs/{RUNTIME}/checkpoints',
    results_folder=f'./runs/{RUNTIME}/results',
)


with trainer.trackers(project_name = 'magvit', run_name = f'MNIST v0.1.26 W/ GAN 2**13 {RUNTIME}'):
    trainer.train()
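For what it's worth, here is a rough sketch of the delayed-discriminator idea mentioned above. It is hypothetical (the trainer does not expose such a hook, as far as I can tell), written as standalone loss-gating logic:

import torch

RECON_THRESHOLD = 0.05  # hypothetical target reconstruction loss
ADV_WEIGHT = 0.1        # adversarial loss weight from the paper
gan_enabled = False

def combined_loss(recon_loss: torch.Tensor, adv_loss: torch.Tensor) -> torch.Tensor:
    # Add the adversarial term only once reconstruction quality passes the
    # threshold, so the generator gets a head start on the discriminator.
    global gan_enabled
    if not gan_enabled and recon_loss.item() < RECON_THRESHOLD:
        gan_enabled = True
    return (recon_loss + ADV_WEIGHT * adv_loss) if gan_enabled else recon_loss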

Running multi-GPU hangs after first step

I'm using accelerate's multi-GPU support to run on a cluster of A100 GPUs.

In which compute environment are you running?
This machine
Which type of machine are you using?
multi-GPU
How many different machines will you use (use more than 1 for multi-node training)? [1]:
Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]:
Do you wish to optimize your script with torch dynamo? [yes/NO]:
Do you want to use DeepSpeed? [yes/NO]:
Do you want to use FullyShardedDataParallel? [yes/NO]:
Do you want to use Megatron-LM? [yes/NO]:
How many GPU(s) should be used for distributed training? [1]: 4
What GPU(s) (by id) should be used for training on this machine as a comma-separated list? [all]: 0,1,2,3
Do you wish to use FP16 or BF16 (mixed precision)?
fp16

I can train on a single GPU, but multi-gpu hangs for me. Is there a recommended configuration for running multi-GPU training?
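As a sanity check (my suggestion, not from the repo), a minimal script to confirm that NCCL communication works at all under the same launcher; if this also hangs, the problem is the cluster/NCCL setup rather than the trainer:

import torch
import torch.distributed as dist

# Run with e.g.: accelerate launch --num_processes 4 check_nccl.py
dist.init_process_group(backend = 'nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

x = torch.ones(1, device = 'cuda')
dist.all_reduce(x)  # hangs here if inter-GPU communication is broken
print(f'rank {dist.get_rank()}: all_reduce ok, sum = {x.item()}')

dist.destroy_process_group()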

Running with GAN raises RuntimeError

v0.1.32 works without the GAN, but I get an error when using the GAN again.

import torch
from datetime import datetime
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

RUNTIME = datetime.now().strftime("%y%m%d_%H%M%S")

tokenizer = VideoTokenizer(
    image_size = 32,
    channels=1,
    use_gan=True,
    use_fsq=False,
    codebook_size=2**13,
    init_dim=64,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 10,
    grad_accum_every = 5,
    num_train_steps = 5_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=2e-5,
    accelerate_kwargs={"split_batches": True, "mixed_precision": "bf16"},
    random_split_seed=85,
    optimizer_kwargs={"betas": (0.9, 0.99)}, # From the paper
    ema_kwargs={},
    use_wandb_tracking=True,
    checkpoints_folder=f'./runs/{RUNTIME}/checkpoints',
    results_folder=f'./runs/{RUNTIME}/results',
)


with trainer.trackers(project_name = 'magvit', run_name = f'MNIST v0.1.26 W/ GAN 2**13 {RUNTIME}'):
    trainer.train()
Traceback (most recent call last):
  File "/projects/users/pfeiljx/magvit/slurm/mnist/run-mnist-test-run.py", line 46, in <module>
    trainer.train()
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 520, in train
    self.train_step(dl_iter)
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/trainer.py", line 341, in train_step
    loss, loss_breakdown = self.model(
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 659, in forward
    return model_forward(*args, **kwargs)
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/accelerate/utils/operations.py", line 647, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.VideoTokenizer.forward) at 0x7fff42669b40>", line 53, in forward
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 1832, in forward
    norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.grad_layer_wrt_loss) at 0x7fff42659900>", line 50, in grad_layer_wrt_loss
  File "/projects/users/pfeiljx/magvit/magvit2_pytorch/magvit2_pytorch.py", line 129, in grad_layer_wrt_loss
    return torch_grad(
  File "/homes/pfeiljx/miniconda3/envs/magvit/lib/python3.10/site-packages/torch/autograd/__init__.py", line 394, in grad
    result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Example for Images

I'm excited to apply this model to image data. I was wondering if you could give a trivial example of how to apply this library to individual images instead of video. Thank you!
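A minimal sketch, pieced together from the trainer configs posted elsewhere in these issues (the folder path and hyperparameters are placeholders, not recommendations):

from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

tokenizer = VideoTokenizer(
    image_size = 128,
    init_dim = 64,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder = '/path/to/images',  # placeholder: a folder of still images
    dataset_type = 'images',             # treat each image as a single-frame video
    num_frames = 1,
    batch_size = 4,
    grad_accum_every = 8,
    num_train_steps = 10_000,
)

trainer.train()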

About GroupNorm described in the MAGVIT V2 paper

Hello, thanks for your nice work. I notice there are some differences between your implementation and the original paper. One notable difference is the use of group normalization in the original paper. From my understanding, directly applying group normalization to a 5D video tensor (B, C, T, H, W) can result in non-causal behavior. Your implementation does not include group normalization. Could you please explain the reasoning behind this choice? Is it related to the issue I mentioned?
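For illustration (my own toy check, not code from the repo), here is why GroupNorm over a 5D video tensor is non-causal along time: its statistics pool over all frames, so perturbing a future frame changes the output at frame 0.

import torch
import torch.nn as nn

gn = nn.GroupNorm(num_groups = 1, num_channels = 4)
x = torch.randn(1, 4, 8, 16, 16)         # (B, C, T, H, W)
y0 = gn(x)[:, :, 0]                      # normalized frame 0

x2 = x.clone()
x2[:, :, -1] += 10.0                     # perturb only the *last* frame
y0_perturbed = gn(x2)[:, :, 0]

print(torch.allclose(y0, y0_perturbed))  # False: frame 0 depends on the future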

About training steps and correctness.

Following the settings in the README, I trained on 40,000 COCO images.

Currently, 11,700 of 1,000,000 steps have been trained, but the model still cannot reconstruct, as shown below.
[reconstruction at step 11,700]

step 20,000: [reconstruction]

step 39,000: [reconstruction]

step 54,000: [reconstruction]

The reconstruction results of the first few steps are shown below.
step 100: [reconstruction]
step 200: [reconstruction]

The training metric curves are shown below.
[training curves]

So, is the current training normal? If not, can you help locate the problem? If it is normal, roughly how many steps of training are needed before the model can reconstruct images? @lucidrains

About training speed.

Hi, I used 2x 8xA100 machines to train this code on video datasets, with accelerate as the DDP launcher.

After 8-9 hours of running, I have only completed about 3,800 steps.

Is this normal?

Is there any requirement on the training images?

Hi, thanks for the great work!

When I tried it out, I used a subfolder of ImageNet (/ILSVRC/Data/CLS-LOC/train/n02096437), which contains a lot of images, as dataset_folder, but I got this error: 0 training samples found at /ILSVRC/Data/CLS-LOC/train/n02096437

I double-checked the folder; it has a lot of images.

I wonder if there is any requirement on the images?
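Not confirmed, but one frequent cause is file-extension matching: ImageNet files end in uppercase ".JPEG", which a lowercase-only glob would miss. A quick diagnostic to see what suffixes the loader must match:

from collections import Counter
from pathlib import Path

folder = Path('/ILSVRC/Data/CLS-LOC/train/n02096437')
print(Counter(p.suffix for p in folder.iterdir() if p.is_file()))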

Training difficulties

Hi, I am experiencing some difficulties training magvit2. I don't know whether I've made a mistake somewhere or where the problem might be coming from.

It seems that my understanding of the paper might be erroneous. I tried with 2 codebooks of size 512, and I can't seem to fit the training data; the training is really unstable. I tried replacing the LFQ with a classical VQ, and that was more stable and able to converge.
What config did you use to train the model?

object has no attribute 'has_multiscale_discrs'

Thank you very much for your work. I've encountered an issue while launching with multiple GPUs via accelerate:

Error message as follows:

File "magvit2-pytorch/magvit2_pytorch/trainer.py", line 203, in __init__
    self.has_multiscale_discrs = self.model.has_multiscale_discrs
File "magvit2-pytorch/magvit2_pytorch/.magvit2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
    raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'DistributedDataParallel' object has no attribute 'has_multiscale_discrs'

How should I resolve this? Thank you.
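A likely cause (my guess, not confirmed in this thread): under multi-GPU accelerate the model is wrapped in DistributedDataParallel, which proxies forward() but not arbitrary attributes; the original module lives at .module. A minimal unwrapping helper, assuming the trainer can be patched to use it:

import torch.nn as nn

def unwrap_module(model: nn.Module) -> nn.Module:
    # DistributedDataParallel (and similar wrappers) keep the wrapped
    # model at .module; plain models pass through unchanged.
    return model.module if hasattr(model, 'module') else model

# e.g. in trainer.py:
# self.has_multiscale_discrs = unwrap_module(self.model).has_multiscale_discrs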

Question about causal 3D CNN

Hi, thank you for your excellent work. I have a question about the causal 3D CNN.
From my point of view, if we use a causal 3D CNN, then we don't need a transformer; the transformer is only used in C-ViViT. But in the code I see linear_attend_space and attend_space.
Is my understanding wrong?

The results for CausalConv3d

Hi @lucidrains, thanks for your awesome work! I used your causal conv implementation and trained a video VQGAN network. The results are as follows:
Original clip sequence:
[original frames]
Reconstructed clip sequence:
[reconstructed frames]
I've noticed that the reconstruction seems to rely heavily on the initial frame. As the sequence progresses, the clarity of the images diminishes, with a stronger blurring effect in each subsequent frame. Could you provide any insight into this phenomenon? Thank you for your time and assistance!

Question about ImageNet Parameters

Hi @lucidrains ,

Thanks again for this great resource. I'm trying to get the training up and running on ImageNet, but I get a strange error midway through training. I was hoping you could take a quick look to see if I'm doing something that doesn't make sense. Thank you!

Traceback (most recent call last):
  File "/projects/grc/users/pfeiljx/magvit2-pytorch/run/test-fashion-mnist.py", line 39, in <module>
  File "/projects/grc/users/pfeiljx/magvit2-pytorch/magvit2_pytorch/trainer.py", line 431, in train
    self.train_step(dl_iter)
  File "/projects/grc/users/pfeiljx/magvit2-pytorch/magvit2_pytorch/trainer.py", line 290, in train_step
    loss, loss_breakdown = self.model(
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.VideoTokenizer.forward) at 0x2ae9f33abb80>", line 53, in forward
  File "/projects/grc/users/pfeiljx/magvit2-pytorch/magvit2_pytorch/magvit2_pytorch.py", line 1561, in forward
    x = self.encode(padded_video, cond = cond)
  File "<@beartype(magvit2_pytorch.magvit2_pytorch.VideoTokenizer.encode) at 0x2ae9f33ab5e0>", line 53, in encode
  File "/projects/grc/users/pfeiljx/magvit2-pytorch/magvit2_pytorch/magvit2_pytorch.py", line 1442, in encode
    x = self.conv_in(video)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/projects/grc/users/pfeiljx/magvit2-pytorch/magvit2_pytorch/magvit2_pytorch.py", line 867, in forward
    return self.conv(x)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 610, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/ui/abv/pfeiljx/miniconda/envs/magvit/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 605, in _conv_forward
    return F.conv3d(
RuntimeError: Given groups=1, weight of size [64, 3, 7, 7, 7], expected input[1, 1, 10, 230, 230] to have 3 channels, but got 1 channels instead
srun: error: mg092: task 0: Exited with exit code 1

Here is the code I'm running:

from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

tokenizer = VideoTokenizer(
    image_size = 256,
    init_dim = 64,
    max_dim = 512,
    channels=3,
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'compress_space',
        ('consecutive_residual', 2),
        'linear_attend_space',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
        'compress_time',
        ('consecutive_residual', 2),
        'compress_time',
        ('consecutive_residual', 2),
        'attend_time',
    )
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='imagenet/ILSVRC/Data/CLS-LOC/train/n01440764',
    dataset_type = 'images',                        # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 1,
    grad_accum_every = 4,
    num_train_steps = 1_000
)

trainer.train()
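The traceback shows the first conv receiving a 1-channel input, so a grayscale image has most likely slipped into the folder (ImageNet contains a few). One possible workaround, assuming converting the offending files in place is acceptable, is to pre-convert everything to RGB:

from pathlib import Path
from PIL import Image

folder = Path('imagenet/ILSVRC/Data/CLS-LOC/train/n01440764')
for path in folder.glob('*.JPEG'):
    img = Image.open(path)
    if img.mode != 'RGB':
        img.convert('RGB').save(path)  # overwrite non-RGB files in place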

Large scale training

Hey, wanted to open this comm channel as I'm looking to do a large-scale training run using some of this code. I'm happy to share graphs/samples as I go along, and wanted to ask a few things to start off:

  • Is the implementation feature-complete relative to the original paper?
  • What is the best configuration you have found? (I'm seeing discussion of LFQ vs. FSQ, and I see code for diff transformers, etc.)

As always, thanks for this! Very helpful

weights

Dear Authors,

I would like to express my sincere gratitude for your outstanding work. I was wondering if you could kindly inform me about the anticipated release date of your model weights.

Thank you very much for your time and consideration.

Has anyone successfully trained this model?

I have tried to train this model for a few days, but the reconstruction results are always abnormal. If anyone has trained it successfully, could you share some tips?
