
Comments (6)

painlove1999 commented on July 30, 2024

Logits finds the similarity between the prediction and all labels in the batch; its shape should be [BS, BS]. For example, Logit[i][j] computes the similarity between the i-th prediction and the j-th label. Ultimately, we want the prediction to be most similar to its corresponding label, i.e., the i-th prediction should be most similar to the i-th label in the batch. This is referred to as "classifying within a batch" in the paper. Therefore, we use the cross-entropy loss for the "classification" and use torch.arange(pred.shape[0]) as the ground truth, which means that the i-th prediction's ground-truth index is i.

There should be a more straightforward way to implement BMC. I wrote it this way for conciseness and to show its resemblance to the supervised contrastive loss, as explained in Eqn. 3.15.

@painlove1999 Could you provide more information on the dimension mismatch? The current implementation requires both pred and label to be [BS, 1].

Thank you for your answer. I had mistakenly assumed that the pred in your code was logits, so that its dimension would be [batch, cls_num]. I will try to adapt your code in my program.

from balancedmse.

painlove1999 commented on July 30, 2024

Yes, from the code's point of view there would be a dimension mismatch problem.

from balancedmse.

jiawei-ren commented on July 30, 2024

Logits finds the similarity between a prediction and all labels in the batch; its shape should be [BS, BS]. For example, Logit[i][j] computes the i-th prediction's similarity to the j-th label. Ultimately, we'd like the prediction to be most similar to its corresponding label, i.e., the i-th prediction should be most similar to the i-th label in the batch. This is referred to as "classifying within a batch" in the paper. Therefore, we use the cross-entropy loss for the "classification" and use torch.arange(pred.shape[0]) as the ground truth, which means exactly that the i-th prediction's ground-truth index is i.

There should be a more straightforward way to implement BMC. I write it this way for conciseness and to show its resemblance to supervised contrastive loss, as explained in Eqn. 3.15.
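
As an illustration of that equivalence, here is a minimal sketch (not the repository's code; the name bmc_loss_direct is made up here, and the optional scale-restoring step is omitted). Cross-entropy against torch.arange targets picks out the diagonal of the [BS, BS] logits matrix, so the same loss can be written directly with logsumexp, assuming pred and target are [BS, 1] tensors:

import torch

def bmc_loss_direct(pred, target, noise_var):
    # pred, target: [BS, 1]; logits[i][j] = -(pred_i - target_j)^2 / (2 * noise_var), shape [BS, BS]
    logits = - (pred - target.T).pow(2) / (2 * noise_var)
    # cross-entropy with ground-truth index i for row i reduces to
    # -logits[i, i] + logsumexp(logits[i, :]), averaged over the batch
    loss = (-logits.diagonal() + torch.logsumexp(logits, dim=1)).mean()
    return loss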

@painlove1999 Could you provide more information on the dimension mismatch? The current implementation requires both pred and label to be [BS,1].

from balancedmse.

jiawei-ren commented on July 30, 2024

I have added more descriptions of the variable sizes to the README; hope this helps.

import torch
import torch.nn.functional as F

def bmc_loss(pred, target, noise_var):
    """Compute the Balanced MSE Loss (BMC) between `pred` and the ground truth `target`.
    Args:
      pred: A float tensor of size [batch, 1].
      target: A float tensor of size [batch, 1].
      noise_var: A float number or tensor.
    Returns:
      loss: A float tensor. Balanced MSE Loss.
    """
    logits = - (pred - target.T).pow(2) / (2 * noise_var)   # logit size: [batch, batch]
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))     # contrastive-like loss
    loss = loss * (2 * noise_var).detach()  # optional: restore the loss scale, 'detach' when noise is learnable

    return loss
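
A minimal usage sketch, assuming the imports above and the [batch, 1] shapes from the docstring (the batch size 16 is arbitrary, and noise_var is passed as a tensor so that the .detach() call is valid):

pred = torch.randn(16, 1)        # predictions, shape [batch, 1]
target = torch.randn(16, 1)      # ground-truth values, shape [batch, 1]
noise_var = torch.tensor(1.0)    # noise variance; a tensor here so that .detach() works
loss = bmc_loss(pred, target, noise_var)   # scalar tensor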

from balancedmse.

caotong0 commented on July 30, 2024

Great, love the response.
Another question: Is the range of target [0,1] or [0, inf)?

from balancedmse.

jiawei-ren commented on July 30, 2024

Great, love the response. Another question: Is the range of target [0,1] or [0, inf)?

Thanks! Same as the standard MSE loss, the target range can be anything, e.g., in age estimation, the target range is [0, 120]. It can be unbounded as well, i.e., (-inf, inf).

from balancedmse.
