Comments (6)

brozi commented on September 13, 2024

It seems that some tokens can make RoBERTa's detokenizer crash. I pushed a fix that makes it output a null token for that character instead of crashing.

Note: RoBERTa's tokenizer was created for natural languages. The advantage over our tokenizer is that you can easily feed it code in any programming language. The disadvantage is that it won't remove tokens that don't change the semantics of the code (e.g. newlines in Java and C++), and its BPE codes and vocabulary won't be tailored to programming languages. If you are training only on languages we support, or languages supported by tree-sitter, you might want to use a tokenizer made for source code.
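
For illustration only (this is not the repo's preprocessing pipeline, and it assumes the Hugging Face transformers package is installed), a quick check shows how a natural-language byte-level BPE such as roberta-base treats source code: newlines and indentation survive as explicit tokens, and identifiers are split into subwords learned from English text.

python3 -c "
from transformers import RobertaTokenizer

# roberta-base's byte-level BPE was trained on English text: spaces appear as
# 'Ġ'-prefixed pieces, newlines as 'Ċ', and identifiers are split into
# English-centric subwords rather than code-aware units.
tok = RobertaTokenizer.from_pretrained('roberta-base')
code = 'int maxValue(int a, int b) {\n    return a > b ? a : b;\n}'
print(tok.tokenize(code))
"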

yazdanbakhsh commented on September 13, 2024

Thanks for the reply and suggestion.

If I understand correctly, I already used the tokenizer that you folks developed (fast instead of roberta), as follows. Are you suggesting that I use a different detokenizer during training?

python3 -m codegen_sources.preprocessing.preprocess /gcs/transcoder_data/train_data_small \
--langs java cpp python \
--mode monolingual_functions \
--bpe_mode fast \
--local true \
--train_splits 8

yazdanbakhsh commented on September 13, 2024

For training, we use the following script given in the README. If I understand correctly, I just need to set roberta_mode to false to use the program-specific tokenizers.

python3 -m torch.distributed.launch --nproc_per_node=$NGPU codegen_sources/model/train.py \
--exp_name mlm \
--dump_path '/gcs/transcoder_data/train_data_small_dump' \
--data_path '/gcs/transcoder_data/train_data_small/XLM-syml' \
--mlm_steps 'java_monolingual,python_monolingual' \
--add_eof_to_stream true \
--word_mask_keep_rand '0.8,0.1,0.1' \
--word_pred '0.15' \
--encoder_only true \
--n_layers 12  \
--emb_dim 768  \
--n_heads 12  \
--lgs 'java_monolingual-python_monolingual' \
--max_vocab 64000 \
--gelu_activation true \
--roberta_mode true \
--amp 2  \
--fp16 true  \
--batch_size 32 \
--bptt 512 \
--epoch_size 100000 \
--max_epoch 100000 \
--split_data_accross_gpu local \
--optimizer 'adam_inverse_sqrt,warmup_updates=10000,lr=0.0001,weight_decay=0.01' \
--save_periodic 0 \
--validation_metrics _valid_mlm_ppl \
--stopping_criterion '_valid_mlm_ppl,10'

yazdanbakhsh commented on September 13, 2024

Setting roberta_mode = false gives a CUDA memory error. I guess this is because the number of heads in the original TransCoder repo is 8 whereas in CodeGen it is set to 12, and the number of layers was 6 compared to 12 in CodeGen. We use 16GB V100 GPUs.

Tried to allocate 96.00 MiB (GPU 6; 15.78 GiB total capacity; 12.64 GiB already allocated; 36.25 MiB free; 14.02 GiB reserved in total by PyTorch

After changing the number of layers and heads, the model trains for a while, then prints the following warnings, and then runs into the CUDA memory issue again.

Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 8192.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 8192.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 8192.0
...
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 4096.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 4096.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 4096.0
...
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 2048.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 2048.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 2048.0
...

brozi commented on September 13, 2024

You made me realize that the example parameters we were showing in transcoder.md are those we used to fine-tune DOBF with a model that uses the same tokenizer and architecture as RoBERTa, in order to compare it with CodeBERT and GraphCodeBERT. That didn't really make sense here, so we updated the parameters to correspond to TransCoder's architecture. Here are the differences:

--n_layers 6  \ (was 12)
--emb_dim 1024  \  (was 768)
--n_heads 8  \ (was 12)
--gelu_activation false \ (was true)
--roberta_mode false \ (was true)
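
Put together, a sketch of the corrected MLM command (reusing the paths and language pairs from the command quoted earlier in this thread; double-check the flags against the current transcoder.md before copying) would be:

python3 -m torch.distributed.launch --nproc_per_node=$NGPU codegen_sources/model/train.py \
--exp_name mlm \
--dump_path '/gcs/transcoder_data/train_data_small_dump' \
--data_path '/gcs/transcoder_data/train_data_small/XLM-syml' \
--mlm_steps 'java_monolingual,python_monolingual' \
--add_eof_to_stream true \
--word_mask_keep_rand '0.8,0.1,0.1' \
--word_pred '0.15' \
--encoder_only true \
--n_layers 6 \
--emb_dim 1024 \
--n_heads 8 \
--lgs 'java_monolingual-python_monolingual' \
--max_vocab 64000 \
--gelu_activation false \
--roberta_mode false \
--amp 2 \
--fp16 true \
--batch_size 32 \
--bptt 512 \
--epoch_size 100000 \
--max_epoch 100000 \
--split_data_accross_gpu local \
--optimizer 'adam_inverse_sqrt,warmup_updates=10000,lr=0.0001,weight_decay=0.01' \
--save_periodic 0 \
--validation_metrics _valid_mlm_ppl \
--stopping_criterion '_valid_mlm_ppl,10'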

We also updated it to train on Python, Java, and C++ instead of only Python and Java as in the original TransCoder paper.

About your memory issues: we train our models on V100 GPUs with 32GB. If you only have 16GB, you will need to either decrease the batch size (the batch_size parameter for MLM, or the tokens_per_batch parameter for DOBF, DAE, and BT) or decrease the size of the model, for instance emb_dim (but in that case you won't be able to reload our pre-trained models). The gradient overflow messages sometimes appear at the beginning of training and are generally not a problem.
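
As a rough starting point on a 16GB card (the values below are only illustrative, not tuned recommendations from the maintainers), the flags mentioned above would be adjusted along these lines:

# MLM: halve the batch and keep lowering it until the model fits (32 -> 16 -> 8 ...)
--batch_size 16
# DOBF / DAE / BT: lower tokens_per_batch instead of batch_size
--tokens_per_batch 2000
# or shrink the model, e.g. --emb_dim 512, at the cost of no longer being able
# to reload the released pre-trained checkpoints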

yazdanbakhsh commented on September 13, 2024

Thanks, Baptiste, for the detailed answers. I can confirm that training is now working on our end.
