Comments (28)
Hi @chiyuzhang94 ,
Yes, there should be a cross entropy loss for the CLM task, which corresponds to training the CLM outputs of the main model to predict the original tokens. Note that there is a copy mechanism and the CLM loss is computed from the masked positions only. The following code snippet should be helpful for the CLM loss implementation:
clm_outputs = extra['clm_outputs']
clm_losses = modules.cross_entropy(
    clm_outputs.view(-1, clm_outputs.size(-1)),
    gen_targets,
    reduction='none',
    ignore_index=self.padding_idx,
)
with torch.no_grad():
    valid_tokens = targets.ne(self.padding_idx)
    masked_on_valid = masked_tokens[valid_tokens]
    copy_weights = 1.0 - torch.sigmoid(binary_output[masked_on_valid].detach())
    sum_clm_weights = copy_weights.sum()
clm_loss = torch.sum(clm_losses * copy_weights)
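In equation form, the snippet above computes a copy-weighted cross entropy over the masked positions, with the copy weight detached so no gradient flows through it (this is a restatement of the code, not the exact formula from the paper):

L_CLM = \sum_{i \in masked} (1 - \sigma(o_i^{copy})) \cdot CE(p_i^{CLM}, x_i^{orig})

where o_i^{copy} is the binary (copy) head output at masked position i.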
I hope this helps! Let me know if you have any other questions.
Best,
Yu
I was able to successfully add the additional states necessary for continual pre-training.
I have pushed the base model weights here: https://huggingface.co/kamalkraj/COCO-LM/tree/main
In the model config file, I am resetting these additional states, so it shouldn't be an issue.
https://github.com/kamalkraj/fairseq/blob/0bbef8be75572d1a986e892611dfbef160b46497/examples/coco_lm/config/base.yaml#L9-L11
The rest of the model config settings follow the base setting mentioned in the paper and are for training from scratch.
Hi @kamalkraj ,
I modified your code a bit to make it run on multiple nodes with multiple GPUs. Here are the shell scripts: https://github.com/chiyuzhang94/fairseq/tree/coco-lm/examples/coco_lm
Best,
Chiyu
Hi @chiyuzhang94 ,
Unfortunately, we probably will not be able to release the pretraining script in the near future, as that part involves other technologies within Microsoft (e.g., for speeding up large-scale parallel model training) that may require special approval to be open-sourced. However, I believe that the task and criterion scripts (in fairseq) for pretraining will not be too difficult to implement—you could refer to the ELECTRA implementation by the MC-BERT open source project, specifically this file for the task script and this file for the criterion script. I believe those are the two major files required for pretraining, and it should be easier to build COCO-LM upon those.
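For reference, a minimal skeleton of such a registered fairseq criterion might look like the following (the class name and structure here are an illustrative sketch only, not the official COCO-LM code; the actual loss computation is shared later in this thread):

from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('cocolm')
class CocoLMCriterion(FairseqCriterion):
    """Skeleton only: the real criterion combines the auxiliary MLM,
    binary (copy), CLM, and SCL losses, analogous to the ELECTRA
    criterion in the MC-BERT project."""

    def forward(self, model, sample, reduce=True):
        # Run the model and combine the four loss terms here;
        # see the full forward() posted later in this thread.
        raise NotImplementedError

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        # Aggregate the logged loss terms across workers for reporting.
        pass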
I hope this helps!
Best,
Yu
Thanks, @yumeng5!
I will take a look.
Best,
Chiyu
Hi @yumeng5 ,
I found this repo that implemented COCO-LM, but I think it may be missing something from your objective.
Their implementation is:
weighted_loss = self.cl_weight * cl_loss + self.gen_weight * mlm_loss + self.disc_weight * disc_loss
My understanding is that cl_loss is the contrastive loss, mlm_loss is the loss of the auxiliary generator, and disc_loss is the loss of the binary discrimination. But I do not see the LM loss of the CLM objective. Hence, I want to ask you about this detail. For L_LM in Formula 2, do you use the cross-entropy loss function to compute the loss between the logits from the main model and the original tokens?
Thanks!
Chiyu
@yumeng5
Is it possible to share the complete code for the forward pass and all loss calculations, as you shared above?
Hi @chiyuzhang94 ,
Yes, there should be a cross entropy loss for the CLM task, which corresponds to training the CLM outputs of the main model to predict the original tokens. Note that there is a copy mechanism and the CLM loss is computed from the masked positions only. The following code snippet should be helpful for the CLM loss implementation:
clm_outputs = extra['clm_outputs']
clm_losses = modules.cross_entropy(
    clm_outputs.view(-1, clm_outputs.size(-1)),
    gen_targets,
    reduction='none',
    ignore_index=self.padding_idx,
)
with torch.no_grad():
    valid_tokens = targets.ne(self.padding_idx)
    masked_on_valid = masked_tokens[valid_tokens]
    copy_weights = 1.0 - torch.sigmoid(binary_output[masked_on_valid].detach())
    sum_clm_weights = copy_weights.sum()
clm_loss = torch.sum(clm_losses * copy_weights)
I hope this helps! Let me know if you have any other questions.
Best, Yu
What is the usage of sum_clm_weights?
Hi @kamalkraj ,
I have attached the core functions for the forward pass computation as below:
def get_seq_label(self, sim_matrix):
    bsz = sim_matrix.size(0)
    if self.seq_label is None or bsz > self.seq_label.size(0):
        self.seq_label = torch.arange(0, bsz, device=sim_matrix.device).view(-1, 2)
        self.seq_label[:, 0] += 1
        self.seq_label[:, 1] += -1
        # label is [1, 0, 3, 2, 5, 4, ...]
        self.seq_label = self.seq_label.view(-1)
        return self.seq_label
    else:
        return self.seq_label[:bsz]

def seqcontrast(self, out_1, out_2, temperature):
    batch_size = out_1.size(0)
    # [2*B, D], orig and span interleavely
    global_out = torch.cat([out_1, out_2], dim=-1).view(2 * batch_size, -1)
    # [2*B, 2*B]
    sim_matrix = torch.mm(global_out, global_out.t()) / temperature
    global_batch_size = sim_matrix.size(0)
    sim_matrix.masked_fill_(torch.eye(global_batch_size, device=sim_matrix.device, dtype=torch.bool), float('-inf'))
    truth = self.get_seq_label(sim_matrix)
    contrast_loss = 0.5 * F.nll_loss(
        F.log_softmax(sim_matrix, dim=-1, dtype=torch.float32),
        truth,
        reduction='sum',
    )
    return contrast_loss

def forward(self, model, sample, reduce=True):
    masked_tokens = sample['net_input']['src_tokens'].eq(self.mask_idx)
    sample_size = masked_tokens.int().sum()
    gen_logits, binary_output, binary_target, replace_tokens, extra = model(
        **sample['net_input'],
        masked_tokens=masked_tokens,
        targets=sample['target']
    )
    targets = model.get_targets(sample, [gen_logits])
    gen_targets = targets[masked_tokens].view(-1)
    # auxiliary model MLM loss
    gen_loss = modules.cross_entropy(
        gen_logits.view(-1, gen_logits.size(-1)),
        gen_targets,
        reduction='sum',
        ignore_index=self.padding_idx,
    )
    binary_target = binary_target.view(-1)
    binary_output = binary_output.view(-1)
    # binary classification (copy mechanism) loss
    binary_loss = F.binary_cross_entropy_with_logits(
        binary_output.float(),
        binary_target.float(),
        reduction='mean')
    clm_outputs = extra['clm_outputs']
    clm_losses = modules.cross_entropy(
        clm_outputs.view(-1, clm_outputs.size(-1)),
        gen_targets,
        reduction='none',
        ignore_index=self.padding_idx,
    )
    with torch.no_grad():
        valid_tokens = targets.ne(self.padding_idx)
        masked_on_valid = masked_tokens[valid_tokens]
        copy_weights = 1.0 - torch.sigmoid(binary_output[masked_on_valid].detach())
    # CLM loss
    clm_loss = torch.sum(clm_losses * copy_weights)
    seq_emb_1, seq_emb_2 = extra['seq_emb'], extra['span_seq_emb']
    seq_emb_1 = F.normalize(seq_emb_1.float(), dim=-1).type_as(seq_emb_1)
    seq_emb_2 = F.normalize(seq_emb_2.float(), dim=-1).type_as(seq_emb_2)
    query_emb = seq_emb_2
    key_emb = seq_emb_1
    # SCL loss
    scl_loss = self.seqcontrast(query_emb, key_emb, self.args.temperature)
    bsz = targets.size(0)
    scl_loss = scl_loss / bsz * sample_size
    loss = gen_loss + self.args.binary_loss_weight * binary_loss * sample_size + \
        clm_loss + self.args.scl_loss_weight * scl_loss
    # log variables you want to monitor
    logging_output = {}
    return loss, sample_size, logging_output
And sum_clm_weights does not have any usage in training; it was purely for logging/tracking. I have removed it from the code above.
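As a note on the SCL term: the seqcontrast function above is an InfoNCE-style objective over the 2*B interleaved sequence embeddings, where each original sequence and its cropped span are each other's positives. Restating the code (not quoting the paper):

L_SCL = -0.5 \sum_{i=1}^{2B} \log \frac{\exp(s_i \cdot s_{p(i)} / \tau)}{\sum_{j \neq i} \exp(s_i \cdot s_j / \tau)}

where p(i) swaps each adjacent pair (the [1, 0, 3, 2, ...] labels from get_seq_label) and \tau is the temperature.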
I hope this helps! Let me know if you have further questions.
Best,
Yu
Hi @yumeng5,
Thank you so much for sharing the code. This is really helpful.
Just confirming:
The model object you pass to the above forward function is an instance of this COCOLM_Model, right?
What is the init value of self.seq_label? None?
span_tokens is the cropped sequence, right? Is the same cropping used for the whole batch, or is there a different cropping per sample in a batch?
@yumeng5
Which sample-break-mode is used for model training: complete, complete_doc, or eos?
Or is it like BERT/ELECTRA: [CLS][SENT-A][SEP][SENT-B][SEP]?
Did you apply whole-word-masking like BERT, or the dynamic masking used in ELECTRA?
Hi @kamalkraj ,
The model object you pass in the above forward function is an object of this COCOLM_Model, right?
Yes, the forward function I shared directly uses the COCOLM_Model as the model object.
What is the init value of self.seq_label? None?
Correct.
span_tokens is the cropped sequence, right? Is the same cropping used for the whole batch, or is there a different cropping per sample in a batch?
Right. Each sample is independently cropped (i.e., the length of the resulting cropped sequence depends only on the length of the original sequence and the cropping ratio); this is done during the data preparation step, similar to how the masked sequences are created for MLM training.
Which sample-break-mode is used for model training: complete, complete_doc, or eos?
We used complete_doc for sample-break-mode.
Did you apply whole-word-masking like BERT, or the dynamic masking used in ELECTRA?
We didn't use whole-word-masking; the random masks for each sequence are determined on the fly during training (this should be the same as RoBERTa/ELECTRA dynamic masking).
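For readers less familiar with fairseq, the two answers above roughly correspond to the standard RoBERTa-style setup inside a task's load_dataset(). This is a sketch under the assumption that COCO-LM follows the usual masked_lm task layout; attributes such as self.args.mask_prob are the common defaults, not confirmed COCO-LM values:

from fairseq.data import MaskTokensDataset, TokenBlockDataset

# Pack token streams into samples; 'complete_doc' keeps whole documents together.
dataset = TokenBlockDataset(
    dataset,
    dataset.sizes,
    self.args.tokens_per_sample - 1,  # leave room for the prepended <s>
    pad=self.source_dictionary.pad(),
    eos=self.source_dictionary.eos(),
    break_mode=self.args.sample_break_mode,  # e.g. 'complete_doc'
)

# Dynamic masking: masks are re-drawn every epoch because the wrapper
# seeds its RNG with (seed, epoch, index).
src_dataset, tgt_dataset = MaskTokensDataset.apply_mask(
    dataset,
    self.source_dictionary,
    pad_idx=self.source_dictionary.pad(),
    mask_idx=self.mask_idx,
    seed=self.args.seed,
    mask_prob=self.args.mask_prob,
    leave_unmasked_prob=self.args.leave_unmasked_prob,
    random_token_prob=self.args.random_token_prob,
)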
Hi @yumeng5 ,
Thanks for your recent sharing.
I am not familiar with the fairseq package. Thus, I want to ask a few more questions about the pre-training implementation.
For the model architecture, I think I can directly use the model you defined here. The registered name of the model is cocolm. Right?
For the loss function, I think I should create a script under the criterions directory and put the forward method into a class object. The registered name of this loss is cocolm as well.
For the task script, did you directly use this script from MC-BERT?
Thanks for your time.
Best,
Chiyu
Hi @yumeng5,
Do you know what the difference is between mask_token_dataset.py and mask_token_dataset2.py?
https://www.diffchecker.com/0r1HZNR6
Did you make this change for COCO-LM?
Hi @chiyuzhang94 ,
For the model architecture, I think I can directly use the model you defined here. The model registered name is cocolm. Right?
Yes.
For the loss function, I think I should create a script under the criterions directory and put the forward method into a class object. The registered name of this loss is cocolm as well.
Right.
For the task script, did you directly use this script from MC-BERT?
The task script should be almost identical except that the SCL objective in COCO-LM requires creating cropped sequences, which should be handled by a new dataset script like the following:
from functools import lru_cache

import numpy as np
import torch

from fairseq.data import Dictionary, data_utils

from . import BaseWrapperDataset, LRUCacheDataset


class SpanDataset(BaseWrapperDataset):
    """
    A wrapper Dataset for sampling contiguous span
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        seed: int = 1,
        span: float = 0,
    ):
        self.dataset = LRUCacheDataset(dataset)
        self.span = span
        self.epoch = 0
        self.seed = seed

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index: int):
        return self.__getitem_cached__(self.seed, self.epoch, index)

    @lru_cache(maxsize=16)
    def __getitem_cached__(self, seed: int, epoch: int, index: int):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            sz = len(item)
            if self.span > 1:
                span_length = min(int(self.span), sz)
            else:
                span_length = int(self.span * sz)
            start_idx = np.random.randint(0, sz - span_length + 1)
            new_item = item.clone()
            return new_item[start_idx: start_idx + span_length]
The SpanDataset should be used in the task script for creating cropped sequences, which are passed into the final dataset dictionary as a new entry:
span_tokens = SpanDataset(dataset, span=self.args.span, seed=self.args.seed + 1)
# [CLS] could be removed by cropping; add back
span_tokens = PrependTokenDataset(span_tokens, self.source_dictionary.bos())
span_tokens = RightPadDataset(span_tokens, pad_idx=self.source_dictionary.pad())
self.datasets[split] = SortDataset(
    NestedDictionaryDataset(
        {
            'id': IdDataset(),
            'net_input': {
                'src_tokens': RightPadDataset(
                    src_dataset,
                    pad_idx=self.source_dictionary.pad(),
                ),
                'span_tokens': span_tokens,
                'src_lengths': NumelDataset(src_dataset, reduce=False),
            },
            'target': RightPadDataset(
                tgt_dataset,
                pad_idx=self.source_dictionary.pad(),
            ),
            'nsentences': NumSamplesDataset(),
            'ntokens': NumelDataset(src_dataset, reduce=True),
        },
        sizes=[src_dataset.sizes],
    ),
    sort_order=[
        shuffle,
        src_dataset.sizes,
    ],
)
I hope the above helps! Let me know if anything is unclear.
Best,
Yu
Hi @kamalkraj ,
I believe in COCO-LM we did not use mask_whole_words, so the difference between mask_token_dataset.py and mask_token_dataset2.py should not matter. You could use either one together with the SpanDataset I shared above for creating the extra cropped sequences required by COCO-LM.
Best,
Yu
Hi @yumeng5 ,
Thanks!
How did you preprocess the pre-training data? Do you have scripts for it?
Best,
Chiyu
Hi @chiyuzhang94 ,
I haven't done the pretraining data preprocessing myself, so I may not be able to provide detailed instructions or scripts for that. You could probably look at the MC-BERT repo for some references.
Best,
Yu
Hi @chiyuzhang94,
I have put everything together with the latest fairseq codebase here - https://github.com/kamalkraj/fairseq/tree/coco-lm/examples/coco_lm
You can use the instructions there to pre-train the model. Please let me know if you run into any issues or find something different from the original paper.
@yumeng5
It would be great if you could look into the code and let me know if I messed up anything or if you did something differently in your implementation.
Thanks very much, @kamalkraj !!
I have a question. I wonder if I can load a Google pre-trained ELECTRA checkpoint and continue pre-training. Do you know how to use the ELECTRA vocabulary? The vocabulary file of ELECTRA is not sp.model.
Could you give me any ideas?
Best,
Chiyu
@chiyuzhang94
You can't load a Google pre-trained ELECTRA checkpoint in this codebase.
But you should be able to continue pre-training from the coco-lm checkpoints released in this repo.
ELECTRA uses a WordPiece tokenizer. You should be able to use any tokenizer if you are pre-training from scratch.
Thanks for your suggestion, @kamalkraj !
I followed your instructions and tried the code.
I found that we cannot continue training COCO-LM from the released checkpoint, because fairseq expects to load best_loss and the optimizer state, which are not included in the checkpoint.
But I guess I could train from scratch.
Best,
Chiyu
Hi @kamalkraj ,
I have a question about distributed training. I wonder how to add distributed training to the fairseq-hydra-train command. Do you have any experience with this?
Thanks.
By default, the code runs on all visible GPUs.
You can control GPU visibility with CUDA_VISIBLE_DEVICES.
By default, the code runs on all visible GPUs. You can control GPU visibility with CUDA_VISIBLE_DEVICES.
I see. Yes, it works on one node with multiple GPUs, but not on multiple nodes with multiple GPUs.
You can look into the fairseq documentation for multi-node distributed training. I haven't done it myself yet.
https://github.com/pytorch/fairseq/blob/main/docs/hydra_integration.md
Hi @chiyuzhang94 ,
Yes, there should be a cross entropy loss for the CLM task, which corresponds to training the CLM outputs of the main model to predict the original tokens. Note that there is a copy mechanism and the CLM loss is computed from the masked positions only. The following code snippet should be helpful for the CLM loss implementation:
clm_outputs = extra['clm_outputs']
clm_losses = modules.cross_entropy(
    clm_outputs.view(-1, clm_outputs.size(-1)),
    gen_targets,
    reduction='none',
    ignore_index=self.padding_idx,
)
with torch.no_grad():
    valid_tokens = targets.ne(self.padding_idx)
    masked_on_valid = masked_tokens[valid_tokens]
    copy_weights = 1.0 - torch.sigmoid(binary_output[masked_on_valid].detach())
    sum_clm_weights = copy_weights.sum()
clm_loss = torch.sum(clm_losses * copy_weights)
I hope this helps! Let me know if you have any other questions.
Best, Yu
Hi Yu,
It seems that the copy_weights here should multiply the probability before the log operation in cross_entropy according to your paper. Right? If true, what is the difference between corrective language modeling and replaced token detection + MLM in terms of code implementation?
Best,
Jun.