dhlee347 / pytorchic-bert
Pytorch Implementation of Google BERT
License: Apache License 2.0
def load(self, model_file, pretrain_file):
    """ load saved model or pretrained transformer (a part of model) """
    if model_file:
        print('Loading the model from', model_file)
        self.model.load_state_dict(torch.load(model_file))
    elif pretrain_file: # use pretrained transformer
        print('Loading the pretrained model from', pretrain_file)
        if pretrain_file.endswith('.ckpt'): # checkpoint file in tensorflow
            checkpoint.load_model(self.model.transformer, pretrain_file)
        elif pretrain_file.endswith('.pt'): # pretrain model file in pytorch
            self.model.transformer.load_state_dict(
                {key[12:]: value
                 for key, value in torch.load(pretrain_file).items()
                 if key.startswith('transformer')}
            ) # load only transformer parts
Could I kindly ask what key[12:]: value means when you load a pretrained model? Is it just keeping the last layers? Thanks, I hope for your reply.
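For context, the slice looks like it simply strips the 'transformer.' prefix (12 characters) from the full model's state-dict keys so they match the sub-module's own keys; a minimal sketch with hypothetical keys:

# Hypothetical state-dict keys, just to show what key[12:] does.
full_state = {
    'transformer.blocks.0.attn.proj_q.weight': '<tensor>',
    'classifier.weight': '<tensor>',   # non-transformer key, dropped by the filter
}
prefix = 'transformer.'                # len(prefix) == 12, hence key[12:]
sub_state = {key[len(prefix):]: value
             for key, value in full_state.items()
             if key.startswith(prefix)}
print(sub_state)                       # {'blocks.0.attn.proj_q.weight': '<tensor>'}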
Hello,
Thanks for the excellent work in compressing the HF code in a single repo for BERT.
Just a couple of questions:
a) Is it possible to load pretrained BERT weights and then fine-tune on top of them on my own dataset?
b) Does this support multi-GPU training?
Thanks
Abhishek
Hi,
Have you checked the total number of parameters in the model? I found it to be 220 million, which is more than the 110 million reported for the original BERT-base model. Hope for your reply! Thanks!
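For what it's worth, a hedged sketch of how the count can be double-checked (assuming model is the pretraining model built from the BERT-base config):

import torch

def count_parameters(model: torch.nn.Module) -> int:
    # parameters() yields each unique Parameter once, so genuinely tied
    # weights are not double-counted, but separate (untied) copies are.
    return sum(p.numel() for p in model.parameters())

# e.g. print(f'{count_parameters(model) / 1e6:.1f}M parameters')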
I am using a Chinese corpus to pre-train a BERT with this project.
I find that my loss almost stops decreasing once it reaches about 4.0. I have never trained an English BERT. Are there any more training logs for English BERT? I just want to know the final token-level MLM loss for English BERT pre-training. Thanks in advance.
Thank you for your great code. I'm a student and a beginner in data analysis.
I want to execute your code, but I have some questions. They may be silly questions, but can you give me some details about the files?
We need a $DATA_FILE as a training set, but what is vocab.txt? I can get the vocab.txt file from Google's GitHub. Should I just use it, or can I customize it? (Because I want to make a BERT with fewer parameters than BERT-BASE.)
Also, is the output file model_steps_xxxx.pt compatible with BERT in Google's GitHub?
Sorry, I am not an expert, so maybe my questions are silly. Thank you.
How can we use it on the test dataset for the CoLA task?
Can I fine-tune this model to make it run on SQuAD?
Hi, I want to pretrain on Chinese data as the data file. The format is like this:
今天 天气 好
And can I use my own vocab.txt?
Thanks a lot.
Hey, if I want to restore a checkpointed model and continue pretraining, how would I do that?
Hi,
First of all, thank you so much for the great work you guys have done in your scripts.
I would like to visualize the attention weights obtained after training. I tried to use this visualization tool: https://github.com/jessevig/bertviz#attention-head-view
However, it seems to only work for the pre-trained BERT model, not the fine-tuned one we make ourselves via your script.
What would you recommend for visualization?
Thank you!
I like it! You may want to check the work NVIDIA did to incorporate FP16 training in our repo. It really speeds up the model on recent GPUs (a 4x speed-up on a V100!).
You basically just have to change the LayerNorm module in the model and tweak the training a bit to use NVIDIA's apex.
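For anyone trying this, a minimal sketch of the apex mixed-precision pattern (assuming model and optimizer are already constructed; get_loss here is a stand-in for this repo's task loss, and the LayerNorm swap would use apex's FusedLayerNorm):

# Minimal apex AMP sketch (not this repo's code): wrap the model/optimizer once,
# then scale the loss so FP16 gradients do not underflow.
from apex import amp

model, optimizer = amp.initialize(model, optimizer, opt_level='O2')

for batch in data_iter:
    optimizer.zero_grad()
    loss = get_loss(model, batch)                 # stand-in for the task-specific loss
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()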
Is GEGLU innovative, or is it derived from a certain paper?
Hey,
I am having difficulty running the pretraining; any help would be appreciated.
So I've prepared corpus.txt (quite small, about 1000 lines) that looks like this:
document 1 line 1...
document 1 line 2...
document 1 line 3...
document 2 line 1...
document 2 line 2...
document 2 line 3...
I ran pretrain.py but got an error in train.py, on this line:
print('Epoch %d/%d : Average Loss %5.3f'%(e+1, self.cfg.n_epochs, loss_sum/(i+1)))
So for the time being I commented that line out.
After running it again, here is what I got:
Iter (loss=X.XXX): 0it [00:00, ?it/s]
Iter (loss=X.XXX): 0it [00:00, ?it/s]
Iter (loss=X.XXX): 0it [00:00, ?it/s]
Iter (loss=X.XXX): 0it [00:00, ?it/s]
Iter (loss=X.XXX): 0it [00:00, ?it/s]
Iter (loss=X.XXX): 0it [00:00, ?it/s]
Iter (loss=X.XXX): 0it [00:00, ?it/s]
Iter (loss=X.XXX): 0it [00:00, ?it/s]
Iter (loss=X.XXX): 0it [00:00, ?it/s]
....
Could you please point out where I might have made a mistake?
Thanks!
P.S. I have commented out some parts of the code in train.py (the part where it loads the checkpoint, because I have not installed TensorFlow for a reason). What I want to do for now is train a pretrained BERT model on my own data. I am not sure if that is causing the error above.
If I pretrain BERT with the masking strategy, then in principle I should be able to predict a masked word given a source sentence and a masked target sentence.
Can anyone tell me how to do that?
Thank you.
Can you tell me a data set that can replace 'books_large_all.txt'?
Thank you.
https://github.com/dhlee347/pytorchic-bert/blob/master/classify.py#L137
When fine-tuning, are '[SEP]' and '[CLS]' not needed?
In the google-research code, they add '[CLS]' and '[SEP]' during fine-tuning.
https://github.com/santhoshkolloju/Abstractive-Summarization-With-Transfer-Learning/blob/master/preprocess.py#L197
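For reference, the google-research convention packs every fine-tuning example the same way as in pre-training; a small sketch of that packing (token strings only, before conversion to ids):

# Standard BERT packing for a sentence pair; single-sentence tasks simply omit tokens_b.
tokens_a = ['how', 'are', 'you']
tokens_b = ['i', 'am', 'fine']

tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
assert len(tokens) == len(segment_ids)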
In the pretrain get_loss function, loss_lm is reduced with a plain mean.
Because of this, all the zero values in loss_lm (the non-masked, padded positions) are counted as if they were correct answers and dilute the loss.
So I think we need to change the mean to a numerator / denominator form, like in the TensorFlow code:
loss_lm = (loss_lm * masked_weights.float()).mean()
to
loss_lm_numerator = (loss_lm*masked_weights.float()).sum()
loss_lm_denominator = masked_weights.sum() + 1e-5
loss_lm = loss_lm_numerator / loss_lm_denominator
Is it correct?
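For illustration, a small self-contained sketch of the difference between the two reductions (made-up numbers; loss_lm is the per-position LM loss before reduction and masked_weights is 1 at real masked positions, 0 at padding):

import torch

loss_lm = torch.tensor([2.0, 3.0, 0.0, 0.0])          # per-position LM loss (last two are padding)
masked_weights = torch.tensor([1.0, 1.0, 0.0, 0.0])   # 1 = real masked token, 0 = padding

plain_mean = (loss_lm * masked_weights).mean()                                      # 1.25, diluted by padding
weighted_mean = (loss_lm * masked_weights).sum() / (masked_weights.sum() + 1e-5)    # ~2.5
print(plain_mean.item(), weighted_mean.item())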
On this code line,
the pad index 0 is the same as the first segment index.
So it may not convey the segment information exactly.
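A tiny illustration of the collision (hypothetical ids): padded positions and first-segment positions both get segment id 0, so the segment embedding alone cannot tell them apart:

import torch

# segment ids: 0 = first sentence, 1 = second sentence, and 0 is also reused for padding
segment_ids = torch.tensor([0, 0, 0, 1, 1, 0, 0])   # last two positions are padding
seg_embed = torch.nn.Embedding(2, 8)

# padding positions receive the same embedding as first-sentence positions,
# so only the attention mask distinguishes them.
out = seg_embed(segment_ids)
print(torch.allclose(out[0], out[-1]))  # True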
Hi, thank you very much for the implementation!
I'm trying to compare your implementation head-to-head with the official TF BERT on the Gutenberg dataset (since the BookCorpus dataset is no longer available).
I assume that the text input file format is the same as in huggingface's implementation. Is that correct? A direct clarification of the text dataset format would be great for new users.
There might be a corner case in seek_random_offset() when using a UTF-8 text dataset (like the above) for pre-training. When doing f.seek(randint(0, max_offset), 0), if the seek happens to land in the middle of a multi-byte UTF-8 character such as ’ (i.e. cutting \xe2\x80\x99 down to something like \x99), pretrain.py will raise an error like the following:
File "/home/tkdrlf9202/PycharmProjects/pytorchic-bert/pretrain.py", line 88, in __iter__
seek_random_offset(self.f_neg)
File "/home/tkdrlf9202/PycharmProjects/pytorchic-bert/pretrain.py", line 41, in seek_random_offset
f.readline() # throw away an incomplete sentence
File "/home/tkdrlf9202/anaconda3/envs/p36/lib/python3.6/codecs.py", line 321, in decode
(result, consumed) = self._buffer_decode(data, self.errors, final)
UnicodeDecodeError: 'utf-8' codec can't decode byte 0x99 in position 0: invalid start byte
The error can be mitigated by using
self.f_pos = open(file, "r", encoding='utf-8', errors='ignore')
self.f_neg = open(file, "r", encoding='utf-8', errors='ignore')
instead of self.f_pos = open(file, 'r') in SentPairDataLoader, but half-silently removing some characters might lead to reproducibility issues (I guess the chances are minimal, since the f.readline() right after f.seek(randint(0, max_offset), 0) is there to discard the incomplete sentence anyway).
I'd like to hear your opinions, and thanks again for the contribution!
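For what it's worth, a minimal alternative sketch that avoids the decode error entirely by seeking in binary mode and only decoding whole lines (not the repo's code, just one possible variant of seek_random_offset):

from random import randint

def seek_random_offset_binary(f, back_margin=2000):
    """Seek a binary file handle to a random offset and discard the partial line."""
    f.seek(0, 2)                              # jump to end of file to measure its size
    max_offset = max(f.tell() - back_margin, 0)
    f.seek(randint(0, max_offset), 0)
    f.readline()                              # throw away the possibly mid-character partial line

with open('corpus.txt', 'rb') as f:           # binary mode: no decoding happens during the seek
    seek_random_offset_binary(f)
    line = f.readline().decode('utf-8')       # whole lines always start on a character boundary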