@lucidrains I trained on the Fashion-MNIST data last night and the model was able to converge:
https://api.wandb.ai/links/pfeiljx/udspvdgu
```python
import torch
from datetime import datetime

from magvit2_pytorch import (
    VideoTokenizer,
    VideoTokenizerTrainer
)

RUNTIME = datetime.now().strftime("%y%m%d %H:%M:%S")

tokenizer = VideoTokenizer(
    image_size = 32,
    codebook_size=1_024,
    use_gan=True,
    use_fsq=True,
    init_dim=128,                  # From the paper
    adversarial_loss_weight=0.1,   # From the paper
    perceptual_loss_weight=0.1,    # From the paper
    grad_penalty_loss_weight=10.0,
    lfq_entropy_loss_weight=0.3,   # From the paper
    layers = (
        'residual',
        'compress_space',
        ('consecutive_residual', 2),
        'attend_space',
    ),
)

trainer = VideoTokenizerTrainer(
    tokenizer,
    dataset_folder='/projects/users/pfeiljx/mnist/TRAIN',
    dataset_type = 'images',  # 'videos' or 'images', prior papers have shown pretraining on images to be effective for video synthesis
    batch_size = 5,
    grad_accum_every = 5,
    num_train_steps = 5_000,
    num_frames=1,
    max_grad_norm=1.0,
    learning_rate=2e-5,
    accelerate_kwargs={"split_batches": True},
    random_split_seed=85,
    optimizer_kwargs={"betas": (0.9, 0.99)},  # From the paper
    ema_kwargs={},
    use_wandb_tracking=True,
    checkpoints_folder=f'./runs/{RUNTIME}/checkpoints',
    results_folder=f'./runs/{RUNTIME}/results',
)

with trainer.trackers(project_name = 'magvit', run_name = f'MNIST v0.1.26 {RUNTIME}'):
    trainer.train()
```
from magvit2-pytorch.
Thanks @lucidrains, I'll let you know when the wandb report is ready.
@jpfeil @jacobpfeil I think this repository should support pretraining with 2D conv layers, and then provide a way to convert them to 3D for video. But let me meditate on the simplest way to achieve this.
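One common way to turn pretrained 2D conv weights into 3D ones is "inflation". The sketch below is only an illustration under assumptions: `inflate_conv2d` is a hypothetical helper, not part of magvit2-pytorch, it assumes plain `nn.Conv2d` layers with tuple padding and zero padding mode, and it uses central inflation (the 2D kernel is copied into the middle temporal tap and the other taps are zeroed), so the inflated layer initially reproduces the 2D layer frame by frame.

```python
import torch
import torch.nn as nn


def inflate_conv2d(conv2d: nn.Conv2d, time_kernel: int = 3) -> nn.Conv3d:
    """Hypothetical helper: build a Conv3d that matches conv2d on every frame
    at initialization (central inflation, zeros on the other temporal taps)."""
    assert time_kernel % 2 == 1, "central inflation needs an odd temporal kernel"
    assert isinstance(conv2d.padding, tuple) and conv2d.padding_mode == "zeros"
    kh, kw = conv2d.kernel_size
    conv3d = nn.Conv3d(
        conv2d.in_channels,
        conv2d.out_channels,
        kernel_size=(time_kernel, kh, kw),
        stride=(1, *conv2d.stride),
        padding=(time_kernel // 2, *conv2d.padding),  # keeps the frame count
        dilation=(1, *conv2d.dilation),
        groups=conv2d.groups,
        bias=conv2d.bias is not None,
    )
    with torch.no_grad():
        conv3d.weight.zero_()
        conv3d.weight[:, :, time_kernel // 2] = conv2d.weight  # middle temporal tap
        if conv2d.bias is not None:
            conv3d.bias.copy_(conv2d.bias)
    return conv3d


torch.manual_seed(0)
conv2d = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv3d = inflate_conv2d(conv2d, time_kernel=3)

video = torch.randn(2, 3, 5, 16, 16)  # (batch, channels, frames, height, width)
out3d = conv3d(video)

# Reference: apply the 2D conv to each frame independently.
frames = video.permute(0, 2, 1, 3, 4).reshape(-1, 3, 16, 16)
out2d = conv2d(frames).reshape(2, 5, 8, 16, 16).permute(0, 2, 1, 3, 4)
print(torch.allclose(out3d, out2d, atol=1e-5))  # True
```

For a causal 3D conv (padding only on past frames), the analogous choice would presumably be to copy the 2D kernel into the last temporal tap instead, since that tap is the one aligned with the current frame.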
Thanks @lucidrains. Let me know if I can help run some tests. I have access to a few A100 GPUs.
@jpfeil sounds good.
Let me think about this for a few days, or the code will come out wrong.
Measure twice, cut once kind of thing.
@lucidrains After looking at the Fashion-MNIST results, it looks like the discriminator loss collapsed to zero, so I think learning stopped prematurely. I'm also not getting good reconstructions.
For VQ-GAN, I've read that the autoencoder needs a couple of epochs to generate good images before the discriminator is switched on. Is there a way to do that here?
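A minimal, framework-free sketch of that idea: keep the generator's adversarial weight at zero until a chosen step, then optionally ramp it up. `adversarial_weight`, `generator_loss`, `disc_start`, and `ramp_steps` are hypothetical names (not magvit2-pytorch options), and `disc_start=2_000` and the linear ramp are illustrative assumptions. In a real trainer the discriminator's own update would typically also be skipped before `disc_start`.

```python
def adversarial_weight(step: int, base_weight: float, disc_start: int, ramp_steps: int = 0) -> float:
    """Hypothetical schedule for the generator's adversarial loss weight.

    Returns 0 before `disc_start`, then either jumps straight to `base_weight`
    (ramp_steps <= 0) or ramps up linearly over `ramp_steps` steps.
    """
    if step < disc_start:
        return 0.0
    if ramp_steps <= 0:
        return base_weight
    progress = min(1.0, (step - disc_start + 1) / ramp_steps)
    return base_weight * progress


def generator_loss(recon: float, perceptual: float, adversarial: float, step: int) -> float:
    # Hypothetical total loss; the 0.1 weights mirror adversarial_loss_weight and
    # perceptual_loss_weight in the config above. The GAN term is off during warm-up.
    adv_w = adversarial_weight(step, base_weight=0.1, disc_start=2_000, ramp_steps=500)
    return recon + 0.1 * perceptual + adv_w * adversarial


for step in (0, 1_999, 2_000, 2_249, 3_000):
    print(step, adversarial_weight(step, base_weight=0.1, disc_start=2_000, ramp_steps=500))
```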
@jpfeil try 0.1.31 with use_gan = False on the VideoTokenizer
Whoops. My tokenizer change wasn't saved. Running it again now...
@jpfeil could you retry with fp32 and train for 5000 steps? Also, gradient accumulation of 4-6 steps is sufficient (32-64 effective batch size).
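To make the batch-size arithmetic concrete: assuming each of the `grad_accum_every` accumulation steps draws one full batch of `batch_size` samples, the effective batch per optimizer step is their product, multiplied by the number of processes unless Accelerate's `split_batches` is enabled (in which case `batch_size` is the global batch split across processes). `effective_batch_size` is a hypothetical helper for illustration only.

```python
def effective_batch_size(batch_size: int, grad_accum_every: int,
                         num_processes: int = 1, split_batches: bool = False) -> int:
    """Samples contributing to one optimizer step (hypothetical helper).

    With split_batches=True, batch_size is the global batch divided across
    processes; otherwise every process loads its own batch_size samples.
    """
    per_step = batch_size if split_batches else batch_size * num_processes
    return per_step * grad_accum_every


# The Fashion-MNIST config above: batch_size=5, grad_accum_every=5, split_batches=True.
print(effective_batch_size(5, 5, num_processes=1, split_batches=True))  # 25

# One way to land in the suggested 32-64 range with 4-6 accumulation steps.
for accum in (4, 6):
    print(accum, effective_batch_size(8, accum))
```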
@jpfeil also, share your training curve; try out wandb's report feature for easy sharing.
@lucidrains This was run on 0.1.24, so I'm going to pull the latest version and retry. The loss was slowly improving, but around step 1000 it became NaN. The only change I made was adding a cosine schedule with warmup. I'm also still using bf16, so I'll change that in the next run.
https://api.wandb.ai/links/pfeiljx/p2x7x2x2
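For reference, a linear-warmup-plus-cosine-decay schedule can be written with PyTorch's `LambdaLR`. This is a sketch under assumptions, not necessarily how the schedule in this run was implemented: the 500 warmup steps, 5,000 total steps, and `min_ratio` are illustrative, and the tiny linear model stands in for the tokenizer. The scheduler is stepped once per optimizer step, i.e. once per `grad_accum_every` micro-batches.

```python
import math

import torch


def warmup_cosine(step: int, warmup_steps: int, total_steps: int, min_ratio: float = 0.0) -> float:
    """LR multiplier: linear warmup to 1.0, then cosine decay down to min_ratio."""
    if step < warmup_steps:
        return (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))


model = torch.nn.Linear(4, 4)  # stand-in for the tokenizer
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5, betas=(0.9, 0.99))
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: warmup_cosine(step, warmup_steps=500, total_steps=5_000)
)

for step in range(3):
    loss = model(torch.randn(2, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()  # once per optimizer step
    print(step, scheduler.get_last_lr()[0])
```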
Hi @lucidrains
I ran it using fp32 and trained for 5000 steps, but I did not see any improvement.
https://api.wandb.ai/links/pfeiljx/8kqeyypi
Let me know if you have any suggestions.
@jpfeil yeah, I could add that, but only if need be.
What happens if you set adversarial_loss_weight to 0?
It really should converge for Fashion-MNIST quite quickly, even without the GAN system.
I get an assertion error because the self.has_gan attribute gets set to False. Is it okay to override that assertion?
@jpfeil could you point to the line number?
Could you also give 0.1.29 a quick try? It may be a bug, but I'm not entirely sure.
@jpfeil oh, never mind, yes, I see it. We should be able to turn off the adversarial loss; let me fix it.
@jpfeil give the ImageNet run another try.
There may have been a bug in how I zeroed the gradients a few patches ago.
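The specific bug in the repository isn't described here, but the general rule it touches is easy to check in isolation: with gradient accumulation, gradients must be zeroed once before the micro-batches and again only after the optimizer step. Zeroing inside the accumulation loop would silently keep only the last micro-batch's gradient. This is a generic PyTorch sketch (tiny linear model, equal-sized micro-batches, hypothetical data), not magvit2-pytorch's trainer code; it verifies that the accumulated gradient matches a single large-batch backward pass.

```python
import copy

import torch
import torch.nn.functional as F

torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
reference = copy.deepcopy(model)  # same initial weights, used for one big-batch pass
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

grad_accum_every = 4
xs = [torch.randn(8, 4) for _ in range(grad_accum_every)]  # equal-sized micro-batches
ys = [torch.randn(8, 1) for _ in range(grad_accum_every)]

# Zero once BEFORE accumulating; zeroing inside the loop would discard
# every micro-batch gradient except the last one.
optimizer.zero_grad(set_to_none=True)
for x, y in zip(xs, ys):
    loss = F.mse_loss(model(x), y) / grad_accum_every  # scale so the sum is a mean
    loss.backward()  # .grad accumulates across backward calls

# Reference gradient: one backward pass over the concatenated batch.
F.mse_loss(reference(torch.cat(xs)), torch.cat(ys)).backward()
grads_match = all(
    torch.allclose(p.grad, q.grad, atol=1e-6)
    for p, q in zip(model.parameters(), reference.parameters())
)
print(grads_match)  # True

torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # cf. max_grad_norm=1.0 above
optimizer.step()
optimizer.zero_grad(set_to_none=True)  # clear only after the step
```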
This is resolved for Fashion-MNIST, but I haven't been able to run through enough ImageNet data to see if it works for ImageNet. I'm going to close this now, and if it comes up again for ImageNet, I'll open a new issue.
Hi @jpfeil, do you mind sharing how you ended up solving it? I ran into the same issue (#25).
Hi @coolbunnyx,
Sorry for the delay. I think you already solved it, but I was able to get good reconstructions after training for longer.