Comments (20)

jpfeil commented on July 30, 2024

@lucidrains I ran the Fashion-MNIST data last night, and the model was able to converge:

https://api.wandb.ai/links/pfeiljx/udspvdgu

import torch
from datetime import datetime
from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

# Timestamp used to name this run's output folders and wandb run
RUNTIME = datetime.now().strftime("%y%m%d %H:%M:%S")

tokenizer = VideoTokenizer(
    image_size=32,
    codebook_size=1_024,
    use_gan=True,
    use_fsq=True,
    init_dim=128,                  # From the paper
    adversarial_loss_weight=0.1,   # From the paper
    perceptual_loss_weight=0.1,    # From the paper
    grad_penalty_loss_weight=10.0,
    lfq_entropy_loss_weight=0.3,   # From the paper
    layers=(
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/mnist/TRAIN',
    # 'videos' or 'images'; prior papers have shown pretraining on images
    # to be effective for video synthesis
    dataset_type='images',
    batch_size=5,
    grad_accum_every=5,
    num_train_steps=5_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=2e-5,
    accelerate_kwargs={"split_batches": True},
    random_split_seed=85,
    optimizer_kwargs={"betas": (0.9, 0.99)},  # From the paper
    ema_kwargs={},
    use_wandb_tracking=True,
    checkpoints_folder=f'./runs/{RUNTIME}/checkpoints',
    results_folder=f'./runs/{RUNTIME}/results',
)

with trainer.trackers(project_name='magvit', run_name=f'MNIST v0.1.26 {RUNTIME}'):
    trainer.train()

from magvit2-pytorch.

jpfeil commented on July 30, 2024

Thanks @lucidrains I'll let you know when the wandb report is ready.

lucidrains commented on July 30, 2024

@jpfeil @jacobpfeil i think this repository should support pretraining with 2d conv layers, and then a way to convert it to 3d for video. but let me meditate on the simplest way to achieve this

jpfeil commented on July 30, 2024

Thanks @lucidrains. Let me know if I can help run some tests. I have access to a few A100 GPUs.

lucidrains commented on July 30, 2024

@jpfeil sounds good

let me think about this for a few days or the code will come out wrong

measure twice cut once kinda thing

jpfeil commented on July 30, 2024

@lucidrains After looking at the FashionMNIST results, it looks like the discriminator collapsed to zero loss. So, I think the learning stopped prematurely. I'm also not getting good reconstructions.

[Image: sampled 17]

For VQ-GAN, I've read that the autoencoder needs a couple of epochs to generate good images before the discriminator starts. Is there a way to do that here?
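
To make the question concrete, here is a minimal sketch of the delayed-discriminator idea: the adversarial term is held at zero until a chosen step, so the autoencoder trains on reconstruction loss alone at first. The names `adversarial_weight`, `generator_loss`, and `start_step` are hypothetical and are not part of the magvit2-pytorch API; the 0.1 weight is taken from the config above.

```python
def adversarial_weight(step: int, start_step: int, weight: float = 0.1) -> float:
    """Return 0.0 before `start_step`, then the full adversarial weight.

    Hypothetical helper for illustration; not a magvit2-pytorch function.
    """
    return weight if step >= start_step else 0.0


def generator_loss(recon_loss: float, adv_loss: float,
                   step: int, start_step: int) -> float:
    """Combine reconstruction and adversarial loss with a delayed GAN start."""
    return recon_loss + adversarial_weight(step, start_step) * adv_loss


if __name__ == "__main__":
    # Assumed values: the discriminator term switches on at step 1000.
    for step in (0, 999, 1000, 2500):
        print(step, generator_loss(1.0, 2.0, step, start_step=1000))
```

A hard switch is the simplest variant; a linear ramp from 0 to the full weight is a common alternative when the sudden change destabilizes training.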

lucidrains commented on July 30, 2024

@jpfeil try 0.1.31 with use_gan = False on the VideoTokenizer

jpfeil commented on July 30, 2024

Whoops. My tokenizer change wasn't saved. Running now...

lucidrains commented on July 30, 2024

@jpfeil could you retry with fp32? and train until 5000 steps? also, grad accum of 4-6 is sufficient (32-64 effective batch size)
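
For reference, the effective batch size mentioned here is just the per-step batch multiplied by the number of accumulation steps. A rough sketch follows, with a hypothetical helper name; the `split_batches` handling is an assumption based on how Hugging Face Accelerate describes that option (the batch size is treated as the total across processes when it is enabled).

```python
def effective_batch_size(batch_size: int, grad_accum_every: int,
                         num_processes: int = 1,
                         split_batches: bool = True) -> int:
    """Samples contributing to one optimizer step.

    Hypothetical helper. Assumes that with split_batches=True the given
    batch_size is already the total across all processes, and that with
    split_batches=False each process sees its own batch of that size.
    """
    per_step = batch_size if split_batches else batch_size * num_processes
    return per_step * grad_accum_every


if __name__ == "__main__":
    # Values from the config earlier in the thread: batch_size=5, grad_accum_every=5
    print(effective_batch_size(5, 5))   # 25
    # An example that lands in the suggested 32-64 range
    print(effective_batch_size(8, 4))   # 32
```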

lucidrains commented on July 30, 2024

@jpfeil also share your training curve, try out wandb's report feature for easy sharing

jpfeil commented on July 30, 2024

@lucidrains This was run on 0.1.24, so I'm going to pull the latest version and retry. The loss was slowly improving, but around step 1000 it became NaN. The only change I made was adding a cosine schedule with warmup. I'm also still using bf16, so I'll change that in the next run.

https://api.wandb.ai/links/pfeiljx/p2x7x2x2
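
The exact schedule used in that run is not shown in the thread. As an illustration only, a cosine schedule with linear warmup is often written as a multiplier on the base learning rate; the function name and the `warmup_steps` value below are assumptions.

```python
import math


def warmup_cosine_factor(step: int, warmup_steps: int, total_steps: int,
                         min_factor: float = 0.0) -> float:
    """Learning-rate multiplier: linear warmup, then cosine decay.

    Hypothetical helper; not the schedule from the run above.
    """
    if step < warmup_steps:
        # Linear ramp that reaches 1.0 on the last warmup step
        return (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1]
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_factor + (1.0 - min_factor) * cosine


if __name__ == "__main__":
    base_lr = 2e-5  # learning rate from the config earlier in the thread
    for step in (0, 499, 500, 2750, 5000):
        print(step, base_lr * warmup_cosine_factor(step, 500, 5000))
```

A function of this shape can be passed as `lr_lambda` to `torch.optim.lr_scheduler.LambdaLR` after fixing `warmup_steps` and `total_steps`, for example with `functools.partial`.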

jpfeil commented on July 30, 2024

Hi @lucidrains

I ran it using fp32 and trained for 5000 steps, but I did not see any improvement.

https://api.wandb.ai/links/pfeiljx/8kqeyypi

Let me know if you have any suggestions.

lucidrains commented on July 30, 2024

@jpfeil yea i could add that, but only if need be

what happens if you set adversarial_loss_weight to 0?

it really should converge for fashion mnist quite quickly, even without the GAN system

jpfeil commented on July 30, 2024

I get an assertion error because the self.has_gan attribute gets set to False. Is it okay to override that assertion?

lucidrains commented on July 30, 2024

@jpfeil could you point to the line number?

could you also give 0.1.29 a quick try? may be a bug but not entirely sure

lucidrains commented on July 30, 2024

@jpfeil oh nvm, yes i see it. we should be able to turn off adversarial loss, let me fix

lucidrains commented on July 30, 2024

@jpfeil give the imagenet run another try

there may have been a bug with how I zeroed the gradients a few patches ago

jpfeil commented on July 30, 2024

This is resolved for fashion mnist, but I haven't been able to run through enough imagenet data to see if it works for imagenet. I'm going to close this now and if it comes up again for imagenet, I'll open a new issue.

coolbunnyx commented on July 30, 2024

This is resolved for fashion mnist, but I haven't been able to run through enough imagenet data to see if it works for imagenet. I'm going to close this now and if it comes up again for imagenet, I'll open a new issue.

Hi @jpfeil, do you mind sharing how you ended up solving it? I ran into the same issue (#25).

jpfeil commented on July 30, 2024

Hi @coolbunnyx,

Sorry for the delay. I think you already solved it, but I was able to get good reconstruction after training for longer.
