Comments (6)
Thank you for your answer. I had misunderstood: I thought the pred in your code was logits, so that its dimension would be [batch, cls_num]. I will try to adapt your code to my program.
from balancedmse.
Yes, from the code's point of view there will be a dimension mismatch problem.
Logits finds the similarity between a prediction and all labels in the batch; its shape should be [BS, BS]. For example, Logit[i][j] computes the i-th prediction's similarity to the j-th label. Ultimately, we'd like the prediction to be most similar to its corresponding label, i.e., the i-th prediction should be most similar to the i-th label in the batch. This is referred to as "classifying within a batch" in the paper. Therefore, we use the cross-entropy loss for the "classification" and use torch.arange(pred.shape[0]) as the ground truth, which means exactly that the i-th prediction's ground-truth index is i.
There should be a more straightforward way to implement BMC. I write it this way for conciseness and to show its resemblance to supervised contrastive loss, as explained in Eqn. 3.15.
@painlove1999 Could you provide more information on the dimension mismatch? The current implementation requires both pred and label to be [BS, 1].
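For reference, the [BS, 1] requirement comes from broadcasting. A minimal sketch (shapes chosen purely for illustration) of how the pairwise difference becomes [BS, BS]:

```python
import torch

BS = 4
pred = torch.randn(BS, 1)    # predictions, shape [BS, 1]
target = torch.randn(BS, 1)  # labels, shape [BS, 1]

# [BS, 1] minus [1, BS] broadcasts to the pairwise matrix [BS, BS]
diff = pred - target.T
print(diff.shape)  # torch.Size([4, 4])
```

If pred were instead [batch, cls_num] with cls_num != 1 (as in the misunderstanding above), [BS, cls_num] against [1, BS] would fail to broadcast unless cls_num happened to equal BS.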
I have added more descriptions to variable sizes in readme, hope this helps.
import torch
import torch.nn.functional as F

def bmc_loss(pred, target, noise_var):
    """Compute the Balanced MSE Loss (BMC) between `pred` and the ground truth `target`.
    Args:
        pred: A float tensor of size [batch, 1].
        target: A float tensor of size [batch, 1].
        noise_var: A float number or tensor.
    Returns:
        loss: A float tensor. Balanced MSE Loss.
    """
    logits = - (pred - target.T).pow(2) / (2 * noise_var)        # logit size: [batch, batch]
    loss = F.cross_entropy(logits, torch.arange(pred.shape[0]))  # contrastive-like loss
    loss = loss * (2 * noise_var).detach()  # optional: restore the loss scale; 'detach' when noise_var is learnable
    return loss
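As a quick sanity check, the function can be exercised on random tensors. A minimal sketch (the function is redefined here so the snippet is self-contained, with noise_var as a plain float, so the `.detach()` scale-restoring step is folded into a direct multiply; the batch size and noise_var value are arbitrary):

```python
import torch
import torch.nn.functional as F

def bmc_loss(pred, target, noise_var):
    # same computation as above: pairwise squared-error logits, then
    # cross-entropy against the diagonal (i-th pred matches i-th label)
    logits = - (pred - target.T).pow(2) / (2 * noise_var)  # [batch, batch]
    return F.cross_entropy(logits, torch.arange(pred.shape[0])) * (2 * noise_var)

pred = torch.randn(8, 1)    # predictions, [BS, 1]
target = torch.randn(8, 1)  # labels, [BS, 1]
loss = bmc_loss(pred, target, noise_var=1.0)
print(loss.item())  # a non-negative scalar
```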
Great, love the response.
Another question: Is the range of target [0,1] or [0, inf)?
Thanks! Same as the standard MSE loss, the target range can be anything, e.g., in age estimation, the target range is [0, 120]. It can be unbounded as well, i.e., (-inf, inf).