kakaobrain / minDALL-E
PyTorch implementation of a 1.3B text-to-image generation model trained on 14 million image-text pairs
License: Other
Hi there! I just want to quickly congratulate everyone on the effort put into this project!
Will the models and tokenizers also be stored as GitHub release binaries? That could serve as a backup or alternative download location.
This is an incredible project! For reproducibility, and for some of my own work, would you mind sharing or pointing me to code for fine-tuning VQGAN models (e.g., vqgan_imagenet_f16_16384) on custom datasets? This is different from code for training a VQGAN from scratch on a new dataset.
Additionally, how long does fine-tuning take?
Hi, I really want to try this project; the results seem quite general. Does anyone have a fine-tuning script for this, along with an explanation of the expected data format? Thanks in advance.
I was trying to run sampling_ex.py, but no matter how low I set the num_candidates value (even one or two), it always runs out of memory. I am using an NVIDIA Quadro M5000 with 8 GB of VRAM.
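One workaround when the full candidate batch does not fit in VRAM is to generate candidates in small chunks and concatenate. The helper below is a sketch; `sample_fn` stands in for a wrapper around `model.sampling` from examples/sampling_ex.py (the exact call is an assumption, shown only in the comment):

```python
import torch

def sample_in_chunks(sample_fn, num_candidates, chunk_size=2):
    """Call sample_fn(n) repeatedly with small n to bound peak VRAM."""
    outs = []
    remaining = num_candidates
    while remaining > 0:
        n = min(chunk_size, remaining)
        outs.append(sample_fn(n))  # each call only allocates memory for n candidates
        remaining -= n
    return torch.cat(outs, dim=0)

# Hypothetical usage with minDALL-E (signature assumed from sampling_ex.py):
#   images = sample_in_chunks(
#       lambda n: model.sampling(prompt=args.prompt, top_k=args.top_k,
#                                top_p=args.top_p,
#                                softmax_temperature=args.softmax_temperature,
#                                num_candidates=n, device=device),
#       args.num_candidates)
```

Peak memory is then set by `chunk_size` rather than by `num_candidates`, at the cost of more sampling passes.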
I am fine-tuning the minDALL-E model on a self-made dataset, but my tokenized text prompts are sometimes longer than 64 tokens. What would be the best technique for increasing the length of the positional encodings to, e.g., 128? I was thinking of keeping the original 64 embeddings and appending 64 more that are trained from scratch. However, I worry this might interfere with fine-tuning, since the embeddings sit in the very first layer.
Are there better options or techniques to accomplish this?
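The "keep the original 64, append 64 new" idea can be sketched as follows. Names here are hypothetical (`old_emb` stands in for the model's text positional-embedding table; adapt to the actual attribute in minDALL-E), and initializing the new rows from the statistics of the pretrained rows is one common heuristic, not the project's documented method:

```python
import torch
import torch.nn as nn

old_len, dim, new_len = 64, 16, 128   # toy dim; minDALL-E uses a larger one
old_emb = nn.Embedding(old_len, dim)  # stands in for the pretrained table

# Build a longer table, copy the pretrained rows, and draw the appended
# rows from a normal distribution matched to the pretrained weights' scale.
new_emb = nn.Embedding(new_len, dim)
with torch.no_grad():
    new_emb.weight[:old_len] = old_emb.weight
    new_emb.weight[old_len:].normal_(mean=0.0, std=old_emb.weight.std().item())
```

To reduce interference with fine-tuning, one option is a lower learning rate (or a brief freeze) on the copied rows while the new rows warm up.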
Hi, I want to know whether the code can do inference when we input the text and half of the image, as in iGPT and Taming Transformers. If so, would you mind pointing me to the relevant code?
I found that the sampling code examples/sampling_ex.py fails to save the images if num_candidates is smaller than 16.
This is because the value 16 is hard-coded on line 61:
for i in range(16):
The modification below works for lower num_candidates values:
for i in range(min(16, args.num_candidates)):
For info, on Google Colab the provided notebook examples/sampling_interactive_demo.ipynb has to be slightly edited.
One has to run:
%cd /content
!git clone https://github.com/kakaobrain/minDALL-E.git
%cd /content/minDALL-E
%pip install -q pytorch-lightning omegaconf einops tokenizers
%pip install -q git+https://github.com/openai/CLIP.git
I could have run:
%pip install -q -r requirements.txt
However, that takes a long time for no added value, since some of the packages are already installed on Colab.
Hi, thanks for sharing the code.
In the forward function of Transformer1d, the text indices are sliced as 0 … N-2 and the image indices as N-1 … N-1+(T-1):
B, T = images.shape
_, N = texts.shape
...
x = torch.cat([texts, images], axis=1).contiguous()
...
texts = x[:, :N-1].contiguous()
images = x[:, N-1:-1].contiguous()
Could you please clarify why you didn't slice like the following instead? Thanks!
texts = x[:, :N]
images = x[:, N:]
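For context, this off-by-one pattern is what typically appears when a concatenated sequence is trained with next-token prediction: output position i predicts token i+1, so text logits are read from positions 0 … N-2 and image logits from N-1 onward, with the final position predicting nothing. A minimal sketch with toy shapes (this is an illustration of the shift, not the actual minDALL-E code):

```python
import torch

B, N, T, V = 2, 4, 6, 10                 # batch, text len, image len, vocab (toy sizes)
texts = torch.randint(0, V, (B, N))
images = torch.randint(0, V, (B, T))

x = torch.cat([texts, images], dim=1)    # full sequence of length N + T
logits = torch.randn(B, N + T, V)        # stand-in for the transformer output

# Output position i predicts token i+1: positions 0..N-2 predict the
# text tokens texts[:, 1:], positions N-1..N+T-2 predict the image tokens,
# and the last position has no target.
text_logits = logits[:, :N-1]            # targets: texts[:, 1:]
image_logits = logits[:, N-1:-1]         # targets: images
assert text_logits.shape[1] + image_logits.shape[1] == N + T - 1
```

Under that reading, slicing with x[:, :N] and x[:, N:] would misalign predictions and targets by one position.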
Recently OpenAI released GLIDE, a diffusion model for generating images from text, much like DALL-E.
Would it be possible to compare minDALL-E to GLIDE and put the results on GitHub?
Thank you in advance!
Also, I have to say this is amazing!
First off, great work.
Will information about the training be published anywhere? I'm specifically interested in the number of training epochs and the learning rate.
This looks great!
Could you share some information on the setup you used to train the transformer model?
It would be helpful to have this information to better understand the cost of training DALL-E-style models.
Hi,
It is mentioned in the "Transfer Learning Examples" section that you fine-tuned the pre-trained DALL-E on 8 V100 GPUs. I tried running your transfer_learning_ex.py script on V100 GPUs (16 GB of memory per GPU), and it throws a CUDA OOM error. Can you please share the exact specs of the hardware you used for this?
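Until the exact hardware specs are confirmed, a common way to fit fine-tuning into 16 GB is gradient accumulation (possibly combined with mixed precision): take several small backward passes before each optimizer step, so the effective batch size is preserved while peak memory tracks the small per-step batch. A self-contained sketch with a toy model standing in for the DALL-E transformer (the real script's model and loss are assumptions here):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 8)                       # toy stand-in for the fine-tuned model
opt = torch.optim.SGD(model.parameters(), lr=0.1)
accum_steps = 4                               # effective batch = per-step batch * accum_steps

for step in range(8):
    x = torch.randn(2, 8)                     # small per-step batch that fits in memory
    loss = model(x).pow(2).mean() / accum_steps  # scale so accumulated grads average
    loss.backward()                           # gradients accumulate across the steps
    if (step + 1) % accum_steps == 0:
        opt.step()                            # one optimizer update per accum_steps batches
        opt.zero_grad()
```

Mixed precision (torch.cuda.amp.autocast, which the codebase already imports per the method list above) can further roughly halve activation memory.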
Hi, in minDALL-E, inappropriate dependency version constraints can introduce risks.
Below are the dependencies and version constraints that the project is using:
torch==1.8.0
torchvision>=0.8.2
tokenizers>=0.10.2
pyflakes>=2.2.0
tqdm>=4.46.0
pytorch-lightning>=1.5
einops
omegaconf
git+https://github.com/openai/CLIP.git
matplotlib
The == constraint introduces a risk of dependency conflicts because the allowed version range is too strict.
Constraints with no upper bound (or *) introduce a risk of missing-API errors, because the latest version of a dependency may remove APIs the project calls.
After further analysis of this project:
The version constraint of the tqdm dependency can be changed to >=4.36.0,<=4.64.0.
This modification reduces dependency conflicts as much as possible while allowing the newest versions that do not raise errors in the project.
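In requirements.txt syntax, the suggested change looks like the fragment below. Only the tqdm bound comes from the analysis above; the other upper bounds are illustrative and should be verified against the APIs the project actually calls:

```
# Bounded constraints (tqdm range per the analysis; others are illustrative)
torch>=1.8.0,<1.9.0
tqdm>=4.36.0,<=4.64.0
pytorch-lightning>=1.5,<1.6
```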
The project currently invokes all of the following methods:
tqdm.tqdm.set_description tqdm.tqdm
self.resid_drop torch.cuda.manual_seed_all PIL.Image.fromarray PIL.Image.fromarray.save ExpConfig self.key hashlib.md5 module.weight.data.normal_ self.head pytorch_lightning.loggers.TensorBoardLogger self.lr_schedulers.get_last_lr text_features.image_features.F.cosine_similarity.squeeze W.B.device.H.torch.arange.repeat.transpose numpy.transpose min argparse.ArgumentParser.add_argument self.quantize.get_codebook_entry self.v sorted_idx_remove_cond.scatter self.quant_conv RuntimeError self.apply ImageNetDataModule self.sos.repeat pytorch_lightning.Trainer.fit torchvision.transforms.Compose self.stage2.sos AttnBlock model.stage1.from_ckpt from_file reversed get_positional_encoding datetime.datetime.now tokens.to.unsqueeze torch.nn.functional.cosine_similarity probs.torch.multinomial.clone self.encode pl_module.stage1 self.down.append Normalize self.mid.block_1 download self.conv1 Downsample z_q.permute.contiguous self.conv OptConfig torch.nn.functional.pad Stage1Hparams self.embedding super w_.permute.permute i.images.astype source.info.get from_file.enable_truncation self.norm2 random.seed numpy.random.seed os.path.expanduser x.self.query.view codes.device.T.torch.arange.repeat layers.Block device.args.num_candidates.args.softmax_temperature.args.top_p.args.top_k.args.prompt.model.sampling.cpu self.conv_in device.H.torch.arange.repeat self.mlp.transpose cutoff_topp_probs.masked_fill self.norm1 k.reshape.reshape torch.cuda.amp.autocast x.contiguous.contiguous loop.update argparse.ArgumentParser.parse_args prompt.clip.tokenize.to self.tok_emb_txt device.args.num_candidates.args.softmax_temperature.args.top_p.args.top_k.args.prompt.model.sampling.cpu.numpy Stage2Hparams os.path.dirname torch.tril self.ln1 pytorch_lightning.callbacks.ModelCheckpoint cnt.code_.unsqueeze model_clip.encode_text y.transpose.contiguous.view ImageNetDataModule.setup tuple enumerate torch.nn.Linear self.resid_drop.transpose tokenizer.build_tokenizer i_block.i_level.self.down.attn 
self.register_buffer self.dropout torchvision.utils.make_grid self.mid.attn_1 x.self.value.view torch.randn output.write self.pos_emb_img self.n_heads.C.self.n_heads.B.T.x.self.key.view.transpose self.ln2 self.nin_shortcut self.stage2.eval self.lr_schedulers.step self.blocks os.path.abspath model.stage2.from_ckpt torch.multinomial self.encoder quant.permute.permute min_encoding_indices.self.embedding.view torch.nn.functional.interpolate labels.self.sos.unsqueeze print torchvision.transforms.Normalize sys.path.append self.decoder torch.einsum self.norm_out torch.optim.AdamW images.self.stage1.get_codes.detach.view MultiHeadSelfAttention einops.rearrange urllib.parse.urlparse stage2.transformer.Transformer1d self.stage1.get_codes DataConfig self.drop omegaconf.OmegaConf.structured dalle.models.Dalle.from_pretrained.sampling preprocess_clip images.torch.stack.to tqdm.tqdm.set_description utils.config.get_base_config tqdm.tqdm x.self.key.view self.n_heads.C.self.n_heads.B.T.x.self.query.view.transpose torch.cat.clone self.decode self.stage2 self.query i_level.self.up.upsample urllib.request.urlopen torch.nn.ModuleList.append self.conv2 source.info self.n_heads.C.self.n_heads.B.T.x.self.value.view.transpose self.lr_schedulers layers.Encoder tarfile.open images.self.stage1.get_codes.detach model_clip.encode_image cutoff_topk_logits utils.sampling.sampling torch.nn.Sequential torch.nn.ModuleList setup_callbacks self.value tokens.to.to self.log math.sqrt isinstance omegaconf.OmegaConf.merge open torch.cat torch.ones torch.topk self.proj_out.reshape torch.argmin self.q self.stage1.parameters os.path.join os.path.exists torch.utils.data.DataLoader self.embedding.weight.data.uniform_ scores.torch.argsort.cpu torch.nn.Module cutoff_topk_logits.to dalle.utils.utils.clip_score int cutoff_topk_logits.clone N.x.contiguous f.extract torch.stack torch.sort self.attn_drop.masked_fill torchvision.datasets.ImageNet torchvision.transforms.CenterCrop optimizer.step 
download_target.open.read cnt.pos_enc_code_.unsqueeze args.config_downstream.os.path.basename.split self torch.optim.lr_scheduler.CosineAnnealingLR stage1.vqgan.VQGAN ValueError torch.argsort Stage1Config range torch.nn.functional.avg_pool2d omegaconf.OmegaConf.load self.sos x.transpose.contiguous torch.manual_seed os.path.isfile image.astype present.torch.stack.clone pl_module.logger.experiment.add_image os.path.basename ImageLogger self.stage1.eval pytorch_lightning.seed_everything torch.cat.size v.reshape.reshape sos.self.stage2.sos.unsqueeze torchvision.transforms.Resize url.split clip.tokenize datetime.datetime.now.strftime device.W.torch.arange.repeat torch.nn.Conv2d torch.nn.LayerNorm dalle.utils.utils.set_seed cls_idx.torch.LongTensor.to torch.nn.functional.softmax i_block.i_level.self.up.attn ResnetBlock torch.nn.functional.cross_entropy probs.torch.multinomial.clone.detach float images.texts.torch.cat.contiguous f.getmembers z_q.permute.contiguous.view dalle.models.Dalle.from_pretrained source.read VectorQuantizer pytorch_lightning.Trainer torch.sigmoid self.tok_emb_img i_block.i_level.self.down.block torch.clamp self.tokenizer.encode h.self.quantize.view self.conv_out nonlinearity model_clip.to self.ln_f q.permute.reshape torch.arange self.load_state_dict q.permute.permute self.k functools.partial torch.sum self.stage2.sos.repeat self.norm self.mid.block_2 self.head_txt cls utils.realpath_url_or_path torch.load torch.no_grad format past.append torchvision.transforms.ToTensor device.N.torch.arange.repeat presents.append self.stage1.decode_code self.quantize from_file.token_to_id os.makedirs self.pos_emb_txt torch.nn.Embedding utils.sampling.sampling_igpt code.clone.detach dalle.models.ImageGPT.from_pretrained z_q.permute.contiguous.permute torchvision.transforms.RandomCrop self.attn Upsample stage2.transformer.iGPT self.post_quant_conv torch.cumsum super.__init__ download_target.open.read.hashlib.md5.hexdigest self.proj_out i_level.self.down.downsample 
h.sos.torch.cat.contiguous ImageNetDataModule.train_dataloader self.stage2.view self.head_img self.proj ImageNetDataModule.valid_dataloader self.parameters len z.rearrange.contiguous torch.clip torch.nn.GroupNorm torch.nn.Parameter model.sampling argparse.ArgumentParser torch.nn.Dropout sorted_idx_remove_cond.clone block.sample torch.LongTensor self.log_img from_file.enable_padding torch.bmm self.mlp self.conv_shortcut y.transpose.contiguous recons.cpu.cpu module.bias.data.zero_ GELU self.up.insert dataclasses.field module.weight.data.fill_ clip.load torch.nn.functional.gelu i_block.i_level.self.up.block present.torch.stack.clone.detach from_file.add_special_tokens Stage2Config torch.repeat_interleave dalle.models.Dalle.from_pretrained.to layers.Decoder scores.torch.argsort.cpu.numpy cutoff_topp_probs self.mask.torch.tril.view sos.self.stage2.sos.unsqueeze.repeat torch.cat.transpose images.cpu.cpu self.attn_drop quant.rearrange.contiguous z.rearrange.contiguous.view
@developer
Could you please help me check this issue?
May I open a pull request to fix it?
Thank you very much.
Hi, what steps do I need to follow to fine-tune minDALL-E on a custom dataset?
I looked at the transfer_learning_ex.py code :)
During fine-tuning, it doesn't look like text for each image is passed in separately.
Is the folder name used as the text for each image?
Thank you for making this.
Could you add a way to complete images, i.e., given a partial image, generate its completion?