
Comments (28)

yumeng5 commented on July 24, 2024

Hi @chiyuzhang94 ,

Yes, there should be a cross entropy loss for the CLM task, which corresponds to training the CLM outputs of the main model to predict the original tokens. Note that there is a copy mechanism and the CLM loss is computed from the masked positions only. The following code snippet should be helpful for the CLM loss implementation:

clm_outputs = extra['clm_outputs']
clm_losses = modules.cross_entropy(
    clm_outputs.view(-1, clm_outputs.size(-1)),
    gen_targets,
    reduction='none',
    ignore_index=self.padding_idx,
)
with torch.no_grad():
    valid_tokens = targets.ne(self.padding_idx)
    masked_on_valid = masked_tokens[valid_tokens]
    copy_weights = 1.0 - torch.sigmoid(binary_output[masked_on_valid].detach())
    sum_clm_weights = copy_weights.sum()
clm_loss = torch.sum(clm_losses * copy_weights)

I hope this helps! Let me know if you have any other questions.

Best,
Yu

kamalkraj commented on July 24, 2024

@chiyuzhang94

I was able to successfully add the additional states necessary for continual pre-training.
I have pushed the base model weights here: https://huggingface.co/kamalkraj/COCO-LM/tree/main

In the model config file, I am resetting these additional states, so it shouldn't be an issue:
https://github.com/kamalkraj/fairseq/blob/0bbef8be75572d1a986e892611dfbef160b46497/examples/coco_lm/config/base.yaml#L9-L11

The rest of the model config settings follow the base setting mentioned in the paper and are for training from scratch.

chiyuzhang94 commented on July 24, 2024

Hi @kamalkraj ,
I modified your code a bit to make it run on multiple nodes with multiple GPUs. Here are the shell scripts: https://github.com/chiyuzhang94/fairseq/tree/coco-lm/examples/coco_lm

Best,
Chiyu

yumeng5 commented on July 24, 2024

Hi @chiyuzhang94 ,

Unfortunately, we probably will not be able to release the pretraining script in the near future, as that part involves other technologies within Microsoft (e.g., for speeding up large-scale parallel model training) that may require special approval to be open-sourced. However, I believe that the task and criterion scripts (in fairseq) for pretraining will not be too difficult to implement—you could refer to the ELECTRA implementation by the MC-BERT open source project, specifically this file for the task script and this file for the criterion script. I believe those are the two major files required for pretraining, and it should be easier to build COCO-LM upon those.
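
As a rough sketch (this is not our actual pretraining code), the task script would follow the usual fairseq pattern below; the registration name, arguments, and defaults are placeholders, and the dataset construction should follow the MC-BERT/ELECTRA scripts linked above:

import os

from fairseq.data import Dictionary
from fairseq.tasks import FairseqTask, register_task


@register_task('cocolm')  # placeholder registration name
class CocoLMPretrainingTask(FairseqTask):
    """Sketch only; see the MC-BERT task script for the real dataset logic."""

    @staticmethod
    def add_args(parser):
        parser.add_argument('data', help='colon-separated path(s) to data directories')
        parser.add_argument('--span', default=0.9, type=float,
                            help='cropping ratio for the SCL sequences (placeholder default)')

    def __init__(self, args, dictionary):
        super().__init__(args)  # note: the base-class signature differs across fairseq versions
        self.dictionary = dictionary
        self.mask_idx = dictionary.add_symbol('<mask>')

    @classmethod
    def setup_task(cls, args, **kwargs):
        paths = args.data.split(':')
        dictionary = Dictionary.load(os.path.join(paths[0], 'dict.txt'))
        return cls(args, dictionary)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        # build the masked src/tgt datasets as in the MC-BERT/ELECTRA task scripts,
        # plus the cropped span_tokens entry needed for the SCL objective
        raise NotImplementedError

    @property
    def source_dictionary(self):
        return self.dictionary

    @property
    def target_dictionary(self):
        return self.dictionary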

I hope this helps!

Best,
Yu

chiyuzhang94 commented on July 24, 2024

Thanks, @yumeng5!

I will take a look.

Best,
Chiyu

chiyuzhang94 commented on July 24, 2024

Hi @yumeng5 ,

I found this repo, which implements COCO-LM, but I think it may be missing something from your objective.

Their implementation is:
weighted_loss = self.cl_weight * cl_loss + self.gen_weight * mlm_loss + self.disc_weight * disc_loss

My understanding is that cl_loss is the contrastive loss, mlm_loss is the loss of the auxiliary generator, and disc_loss is the binary discrimination loss, but I do not see the CLM loss. Hence, I want to ask you about this detail: for L_LM in Formula 2, do you use the cross-entropy loss function to compute the loss between the logits from the main model and the original tokens?

Thanks!
Chiyu

kamalkraj commented on July 24, 2024

@yumeng5
Is it possible to share the complete code for the forward pass and all loss calculations, like the snippet you shared above?

kamalkraj commented on July 24, 2024

Hi @chiyuzhang94 ,

Yes, there should be a cross entropy loss for the CLM task, which corresponds to training the CLM outputs of the main model to predict the original tokens. Note that there is a copy mechanism and the CLM loss is computed from the masked positions only. The following code snippet should be helpful for the CLM loss implementation:

clm_outputs = extra['clm_outputs']
clm_losses = modules.cross_entropy(
    clm_outputs.view(-1, clm_outputs.size(-1)),
    gen_targets,
    reduction='none',
    ignore_index=self.padding_idx,
)
with torch.no_grad():
    valid_tokens = targets.ne(self.padding_idx)
    masked_on_valid = masked_tokens[valid_tokens]
    copy_weights = 1.0 - torch.sigmoid(binary_output[masked_on_valid].detach())
    sum_clm_weights = copy_weights.sum()
clm_loss = torch.sum(clm_losses * copy_weights)

I hope this helps! Let me know if you have any other questions.

Best, Yu

What is the usage of sum_clm_weights?

yumeng5 commented on July 24, 2024

Hi @kamalkraj ,

I have attached the core functions for the forward pass computation below:

def get_seq_label(self, sim_matrix):
    bsz = sim_matrix.size(0)
    if self.seq_label is None or bsz > self.seq_label.size(0):
        self.seq_label = torch.arange(0, bsz, device=sim_matrix.device).view(-1, 2)
        self.seq_label[:, 0] += 1
        self.seq_label[:, 1] += -1
        # label is [1, 0, 3, 2, 5, 4, ...]
        self.seq_label = self.seq_label.view(-1)
        return self.seq_label
    else:
        return self.seq_label[:bsz]

def seqcontrast(self, out_1, out_2, temperature):
    batch_size = out_1.size(0)
    # [2*B, D], orig and span interleavely
    global_out = torch.cat([out_1, out_2], dim=-1).view(2 * batch_size, -1)
    # [2*B, 2*B]
    sim_matrix = torch.mm(global_out, global_out.t()) / temperature
    global_batch_size = sim_matrix.size(0)
    sim_matrix.masked_fill_(torch.eye(global_batch_size, device=sim_matrix.device, dtype=torch.bool), float('-inf'))
    truth = self.get_seq_label(sim_matrix)
    contrast_loss = 0.5 * F.nll_loss(
        F.log_softmax(sim_matrix, dim=-1, dtype=torch.float32),
        truth,
        reduction='sum',
    )
    return contrast_loss
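
# Worked example of the pairing above (added note, not part of the original snippet):
# with batch_size = 2, global_out interleaves the two views of each sequence as
# [view1_0, view2_0, view1_1, view2_1], and get_seq_label returns [1, 0, 3, 2],
# so each row's positive is the other (cropped) view of the same sequence; the
# diagonal of sim_matrix is set to -inf so a row can never pick itself.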

def forward(self, model, sample, reduce=True):
    masked_tokens = sample['net_input']['src_tokens'].eq(self.mask_idx)
    sample_size = masked_tokens.int().sum()
    gen_logits, binary_output, binary_target, replace_tokens, extra = model(
        **sample['net_input'],
        masked_tokens=masked_tokens,
        targets=sample['target']
    )

    targets = model.get_targets(sample, [gen_logits])
    gen_targets = targets[masked_tokens].view(-1)
    # auxiliary model MLM loss
    gen_loss = modules.cross_entropy(
        gen_logits.view(-1, gen_logits.size(-1)),
        gen_targets,
        reduction='sum',
        ignore_index=self.padding_idx,
    )

    binary_target = binary_target.view(-1)
    binary_output = binary_output.view(-1)
    # binary classification (copy mechanism) loss
    binary_loss = F.binary_cross_entropy_with_logits(
                    binary_output.float(),
                    binary_target.float(),
                    reduction='mean')

    clm_outputs = extra['clm_outputs']
    clm_losses = modules.cross_entropy(
        clm_outputs.view(-1, clm_outputs.size(-1)),
        gen_targets,
        reduction='none',
        ignore_index=self.padding_idx,
    )
    with torch.no_grad():
        valid_tokens = targets.ne(self.padding_idx)
        masked_on_valid = masked_tokens[valid_tokens]
        copy_weights = 1.0 - torch.sigmoid(binary_output[masked_on_valid].detach())
    # CLM loss
    clm_loss = torch.sum(clm_losses * copy_weights)

    seq_emb_1, seq_emb_2 = extra['seq_emb'], extra['span_seq_emb']
    seq_emb_1 = F.normalize(seq_emb_1.float(), dim=-1).type_as(seq_emb_1)
    seq_emb_2 = F.normalize(seq_emb_2.float(), dim=-1).type_as(seq_emb_2)
    query_emb = seq_emb_2
    key_emb = seq_emb_1
    # SCL loss
    scl_loss = self.seqcontrast(query_emb, key_emb, self.args.temperature)
    bsz = targets.size(0)
    scl_loss = scl_loss / bsz * sample_size
    loss = gen_loss + self.args.binary_loss_weight * binary_loss * sample_size + \
           clm_loss + self.args.scl_loss_weight * scl_loss

    # log variables you want to monitor
    logging_output = {}
    return loss, sample_size, logging_output

And sum_clm_weights does not have any usage in training; it was purely for logging/tracking. I have removed it from the code.

I hope this helps! Let me know if you have further questions.

Best,
Yu

kamalkraj commented on July 24, 2024

Hi @yumeng5,
Thank you so much for sharing the code. This is really helpful.

Just confirming:

The model object you pass into the above forward function is an instance of this COCOLM_Model, right?

What is the initial value of self.seq_label? None?

kamalkraj commented on July 24, 2024

span_tokens is the cropped sequence, right? Is the same cropping used for the whole batch, or is a different cropping applied per sample in a batch?

kamalkraj commented on July 24, 2024

@yumeng5
What sample-break-mode is used for model training? complete, complete_doc, or eos?
or
is it like BERT/ELECTRA: [CLS][SENT-A][SEP][SENT-B][SEP]?

Did you apply whole-word masking like BERT, or the dynamic masking used in ELECTRA?

yumeng5 commented on July 24, 2024

Hi @kamalkraj ,

The model object you pass into the above forward function is an instance of this COCOLM_Model, right?

Yes, the forward function I shared directly used the COCOLM_Model as the model object.

What is the initial value of self.seq_label? None?

Correct.

span_tokens is the cropped sequence, right? Is the same cropping used for the whole batch, or is a different cropping applied per sample in a batch?

Right. Each sample is independently cropped (i.e., the length of the resulting cropped sequence depends only on the length of the original sequence and the cropping ratio); this is done during the data preparation step, similar to how the masked sequences are created for MLM training.

What sample-break-mode is used for model training? complete, complete_doc, or eos?

We used complete_doc for sample-break-mode.

Did you apply whole-word masking like BERT, or the dynamic masking used in ELECTRA?

We didn't use whole-word masking; the random masks for each sequence are determined on the fly during training (this should be the same as RoBERTa/ELECTRA dynamic masking).
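
If it helps, here is a rough sketch (illustrative only, with placeholder names) of what on-the-fly masking means: the mask is re-drawn from an RNG seeded on (seed, epoch, index) each time a sample is fetched, so the same sequence gets a different mask every epoch.

import numpy as np

def dynamically_mask(tokens, mask_idx, seed, epoch, index, mask_prob=0.15):
    # re-seeding on (seed, epoch, index) makes the mask reproducible per fetch
    # but different across epochs, as in RoBERTa/ELECTRA dynamic masking
    rng = np.random.default_rng((seed, epoch, index))
    masked = np.array(tokens, copy=True)
    num_to_mask = max(1, int(mask_prob * len(masked)))
    positions = rng.choice(len(masked), size=num_to_mask, replace=False)
    masked[positions] = mask_idx
    return masked, positions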

chiyuzhang94 commented on July 24, 2024

Hi @yumeng5 ,

Thanks for your recent sharing.

I am not familiar with the fairseq package. Thus, I want to ask a few more questions about the pre-training implementation.

For the model architecture, I think I can directly use the model you defined here. The model's registered name is cocolm, right?
For the loss function, I think I should create a script under the criterions directory and put the forward method into a criterion class. The registered name of this loss is cocolm as well.
For the task script, did you directly use this script from MC-BERT?

Thanks for your time.

Best,
Chiyu

kamalkraj commented on July 24, 2024

Hi @yumeng5,

Do you know what the difference is between mask_token_dataset.py and mask_token_dataset2.py?

https://www.diffchecker.com/0r1HZNR6

Did you make this change for COCO-LM?

yumeng5 commented on July 24, 2024

Hi @chiyuzhang94 ,

For the model architecture, I think I can directly use the model you defined here. The model's registered name is cocolm, right?

Yes.

For the loss function, I think I should create a script under the criterions directory and put the forward method into a criterion class. The registered name of this loss is cocolm as well.

Right.
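
A minimal registration skeleton could look like the following (illustrative only; the argument names and defaults are placeholders, and the forward body would be the code I shared earlier in this thread):

from fairseq import modules  # provides the cross_entropy helper used in the snippet above
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('cocolm')  # registered name referenced by the training config
class CocoLMCriterion(FairseqCriterion):

    @staticmethod
    def add_args(parser):
        # weights for combining the losses; defaults here are placeholders
        parser.add_argument('--binary-loss-weight', type=float, default=1.0)
        parser.add_argument('--scl-loss-weight', type=float, default=1.0)
        parser.add_argument('--temperature', type=float, default=1.0)

    def __init__(self, task):
        # NOTE: the constructor signature differs across fairseq versions
        # (older releases use (args, task)); the base class sets self.padding_idx.
        super().__init__(task)
        self.mask_idx = task.mask_idx
        self.seq_label = None

    def forward(self, model, sample, reduce=True):
        # paste the forward computation shared earlier (gen/binary/CLM/SCL losses)
        raise NotImplementedError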

For the task script, did you directly use this script from MC-BERT?

The task script should be almost identical except that the SCL objective in COCO-LM requires creating cropped sequences, which should be handled by a new dataset script like the following:

from functools import lru_cache

import numpy as np
import torch
from fairseq.data import Dictionary, data_utils

from . import BaseWrapperDataset, LRUCacheDataset

class SpanDataset(BaseWrapperDataset):
    """
    A wrapper Dataset for sampling contiguous span
    """

    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        seed: int = 1,
        span: float = 0,
    ):
        self.dataset = LRUCacheDataset(dataset)
        self.span = span
        self.epoch = 0
        self.seed = seed

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        self.epoch = epoch

    def __getitem__(self, index: int):
        return self.__getitem_cached__(self.seed, self.epoch, index)

    @lru_cache(maxsize=16)
    def __getitem_cached__(self, seed: int, epoch: int, index: int):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            item = self.dataset[index]
            sz = len(item)
            if self.span > 1:
                span_length = min(int(self.span), sz)
            else:
                span_length = int(self.span * sz)
            start_idx = np.random.randint(0, sz - span_length + 1)
            new_item = item.clone()
            return new_item[start_idx: start_idx + span_length]

The SpanDataset should be used in the task script for creating cropped sequences which are passed into the final dataset dictionary as a new entry:

span_tokens = SpanDataset(dataset, span=self.args.span, seed=self.args.seed + 1)
# [CLS] could be removed by cropping; add back
span_tokens = PrependTokenDataset(span_tokens, self.source_dictionary.bos())
span_tokens = RightPadDataset(span_tokens, pad_idx=self.source_dictionary.pad(),)
self.datasets[split] = SortDataset(
            NestedDictionaryDataset(
                {
                    'id': IdDataset(),
                    'net_input': {
                        'src_tokens': RightPadDataset(
                            src_dataset,
                            pad_idx=self.source_dictionary.pad(),
                        ),
                        'span_tokens': span_tokens,
                        'src_lengths': NumelDataset(src_dataset, reduce=False),
                    },
                    'target': RightPadDataset(
                        tgt_dataset,
                        pad_idx=self.source_dictionary.pad(),
                    ),
                    'nsentences': NumSamplesDataset(),
                    'ntokens': NumelDataset(src_dataset, reduce=True),
                },
                sizes=[src_dataset.sizes],
            ),
            sort_order=[
                shuffle,
                src_dataset.sizes,
            ],
        )

I hope the above helps! Let me know if anything is unclear.

Best,
Yu

yumeng5 commented on July 24, 2024

Hi @kamalkraj ,

I believe in COCO-LM we did not use mask_whole_words, so the difference between mask_token_dataset.py and mask_token_dataset2.py should not matter. You could use either one together with the SpanDataset I shared above for creating the extra cropped sequences required by COCO-LM.

Best,
Yu

chiyuzhang94 commented on July 24, 2024

Hi @yumeng5 ,

Thanks!

How did you preprocess the pretraining data? Do you have scripts for it?

Best,
Chiyu

yumeng5 commented on July 24, 2024

Hi @chiyuzhang94 ,

I haven't done the pretraining data preprocessing myself, so I may not be able to provide detailed instructions or scripts for that. You could probably look at the MC-BERT repo for reference.

Best,
Yu

kamalkraj commented on July 24, 2024

Hi @chiyuzhang94,

I have put everything together with the latest fairseq codebase here: https://github.com/kamalkraj/fairseq/tree/coco-lm/examples/coco_lm
You can use the instructions there to pre-train the model. Please let me know if you run into any issues or find anything different from the original paper.

@yumeng5
It would be great if you could look over the code and let me know if I messed anything up or if you did anything differently in your implementation.

chiyuzhang94 commented on July 24, 2024

Thanks very much, @kamalkraj !!

I have a question: I wonder if I can load a Google pre-trained ELECTRA checkpoint and continue pre-training. Do you know how to use the ELECTRA vocabulary? The ELECTRA vocabulary file is not an sp.model.

Could you give me any ideas?

Best,
Chiyu

kamalkraj commented on July 24, 2024

@chiyuzhang94
You can't load a Google pre-trained ELECTRA checkpoint in this codebase, but you should be able to continue pre-training from the COCO-LM checkpoints released in this repo.

ELECTRA uses a WordPiece tokenizer. You should be able to use any tokenizer if you are pre-training from scratch.
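
For example, if you stay with sentencepiece (as the released sp.model does), training a new model on your own corpus looks roughly like this (paths and vocabulary size are placeholders):

import sentencepiece as spm

# train a sentencepiece model on a raw-text corpus (placeholder paths/sizes)
spm.SentencePieceTrainer.Train(
    '--input=corpus.txt --model_prefix=sp --vocab_size=32768 --model_type=unigram'
)

# sanity-check the resulting sp.model
sp = spm.SentencePieceProcessor()
sp.Load('sp.model')
print(sp.EncodeAsPieces('pre-training COCO-LM from scratch'))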

chiyuzhang94 commented on July 24, 2024

Thanks for your suggestion, @kamalkraj !

I followed your instructions and tried the code.
I found that we cannot continue training COCO-LM from the released checkpoint, because fairseq expects to load best_loss and the optimizer state, which are not included in the checkpoint.
But I guess I could train from scratch.

Best,
Chiyu

chiyuzhang94 commented on July 24, 2024

Hi @kamalkraj ,

I have a question about distributed training: I wonder how to add distributed training to the fairseq-hydra-train command. Do you have any experience with that?

Thanks.

kamalkraj commented on July 24, 2024

@chiyuzhang94

By default, the code runs on all visible GPUs.

You can control GPU visibility with CUDA_VISIBLE_DEVICES.

chiyuzhang94 commented on July 24, 2024

@chiyuzhang94

By default, the code runs on all visible GPUs.

You can control GPU visibility with CUDA_VISIBLE_DEVICES.

I see. Yes, it works on one node with multiple GPUs, but not on multiple nodes with multiple GPUs.

kamalkraj commented on July 24, 2024

You can look into the fairseq documentation for multi-node distributed training; I haven't done it myself yet.

https://github.com/pytorch/fairseq/blob/main/docs/hydra_integration.md

junxiazju commented on July 24, 2024

Hi @chiyuzhang94 ,

Yes, there should be a cross entropy loss for the CLM task, which corresponds to training the CLM outputs of the main model to predict the original tokens. Note that there is a copy mechanism and the CLM loss is computed from the masked positions only. The following code snippet should be helpful for the CLM loss implementation:

clm_outputs = extra['clm_outputs']
clm_losses = modules.cross_entropy(
    clm_outputs.view(-1, clm_outputs.size(-1)),
    gen_targets,
    reduction='none',
    ignore_index=self.padding_idx,
)
with torch.no_grad():
    valid_tokens = targets.ne(self.padding_idx)
    masked_on_valid = masked_tokens[valid_tokens]
    copy_weights = 1.0 - torch.sigmoid(binary_output[masked_on_valid].detach())
    sum_clm_weights = copy_weights.sum()
clm_loss = torch.sum(clm_losses * copy_weights)

I hope this helps! Let me know if you have any other questions.

Best, Yu

Hi Yu,

It seems that, according to your paper, the copy_weights here should multiply the probability before the log operation in the cross entropy. Right? If so, what is the difference between corrective language modeling and replaced token detection + MLM in terms of code implementation?
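
To make the question concrete (the notation below is mine, not the paper's): the snippet applies the copy weight outside the log,

\mathcal{L}_{\text{CLM}} = \sum_{i \in \mathcal{M}} \bigl(1 - \sigma(s_i)\bigr)\,\bigl(-\log p_{\text{LM}}(x_i^{\text{orig}} \mid h_i)\bigr),

whereas a copy mechanism with the weight applied to the probability before the log would be

\mathcal{L}'_{\text{CLM}} = \sum_{i \in \mathcal{M}} -\log\Bigl[\sigma(s_i)\,\mathbb{1}\bigl(x_i^{\text{input}} = x_i^{\text{orig}}\bigr) + \bigl(1 - \sigma(s_i)\bigr)\,p_{\text{LM}}\bigl(x_i^{\text{orig}} \mid h_i\bigr)\Bigr],

where \sigma(s_i) is the (detached) binary discriminator output at masked position i and \mathcal{M} is the set of masked positions.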

Best,
Jun.

