Thank you for your outstanding work! Could you kindly help me with the following three specific questions:
1) Why does this case require concatenating an extra embedding along dim=1?
code: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L351
if args.use_attributes:
self.att_embeds = torch.cat([self.att_embeds, torch.matmul(self.att_embeds.squeeze().T, self.att_W).mean(1, keepdim=True).T.unsqueeze(0)], dim=1)
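To make the question concrete, here is a minimal shape walkthrough of that line as I read it, assuming att_embeds is [1, num_attributes, embed_dim] and att_W is [num_attributes, num_classes] (the sizes below are made up for illustration, not taken from the repo):

import torch

# Assumed shapes (my reading, not verified against the repo):
#   att_embeds: [1, num_attributes, embed_dim]  (per-attribute text embeddings)
#   att_W:      [num_attributes, num_classes]   (attribute-to-class weights)
num_attributes, num_classes, embed_dim = 25, 5, 512
att_embeds = torch.randn(1, num_attributes, embed_dim)
att_W = torch.rand(num_attributes, num_classes)

weighted = torch.matmul(att_embeds.squeeze().T, att_W)  # [embed_dim, num_classes]
mean_embed = weighted.mean(1, keepdim=True)             # [embed_dim, 1]
extra = mean_embed.T.unsqueeze(0)                       # [1, 1, embed_dim]

out = torch.cat([att_embeds, extra], dim=1)             # [1, num_attributes + 1, embed_dim]
print(out.shape)  # torch.Size([1, 26, 512])

So, if my reading is right, a single extra embedding (the mean over classes of the attribute-weighted embeddings) is appended to the attribute embeddings, and I would like to understand why this extra entry is needed.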
2) FOMO needs attributes selected for each category, but the current implementation does not guarantee that an equal number of attributes is selected for each category.
code: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L441
_, top_indices = torch.topk(self.att_W.view(-1), num_classes * self.num_attributes_per_class)
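Here is a small sketch of what I mean (again assuming att_W has shape [num_attributes, num_classes]; the per-class top-k at the end is only my own illustration of an alternative, not code from the repo):

import torch

# Assumed shape (my reading): att_W is [num_attributes, num_classes].
num_attributes, num_classes, num_attributes_per_class = 25, 5, 3
att_W = torch.rand(num_attributes, num_classes)

# Global top-k over the flattened matrix: nothing ties the selected entries to
# particular classes, so one class can receive more attributes than another.
_, flat_idx = torch.topk(att_W.view(-1), num_classes * num_attributes_per_class)
class_of_each_pick = flat_idx % num_classes             # column (class) index of each selected entry
counts = torch.bincount(class_of_each_pick, minlength=num_classes)
print(counts)  # e.g. tensor([5, 2, 4, 1, 3]): not guaranteed to be 3 per class

# A per-class top-k would guarantee exactly num_attributes_per_class per class:
_, per_class_idx = torch.topk(att_W, num_attributes_per_class, dim=0)  # [3, num_classes]
print(per_class_idx.shape)

Is the unbalanced selection intended, or should the selection be done per class?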
3) Are the training and evaluation stages consistent in how they compute attribute scores?
Training stage: the attribute scores are plain cosine similarities, without the learnable logit_shift and logit_scale parameters.
code for attribute_refinement: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L394
code for attribute_selection: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L431
cos_sim = cosine_similarity(image_embeddings, self.att_embeds, dim=-1)
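For reference, this is roughly how I understand the training-time scoring (the shapes are my assumptions, not taken from the repo):

import torch
from torch.nn.functional import cosine_similarity

# Assumed shapes (my reading, not verified against the repo):
#   image_embeddings: [batch, 1, embed_dim]
#   att_embeds:       [1, num_attributes, embed_dim]
batch, num_attributes, embed_dim = 4, 25, 512
image_embeddings = torch.randn(batch, 1, embed_dim)
att_embeds = torch.randn(1, num_attributes, embed_dim)

# Training-time attribute score: raw cosine similarity in [-1, 1],
# with no logit_shift / logit_scale applied.
cos_sim = cosine_similarity(image_embeddings, att_embeds, dim=-1)  # [batch, num_attributes]
print(cos_sim.shape)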
Evaluation stage: the class predictor applies the learnable logit_shift and logit_scale parameters:
pred_logits = (pred_logits + logit_shift) * logit_scale
code: https://github.com/orrzohar/FOMO/blob/main/models/FOMO.py#L643
(pred_logits, class_embeds) = self.model.class_predictor(image_feats, self.att_embeds.repeat(batch_size, 1, 1),
self.att_query_mask)
def class_predictor(
self,
image_feats: torch.FloatTensor,
query_embeds: Optional[torch.FloatTensor] = None,
query_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor]:
"""
Args:
image_feats:
Features extracted from the `image_text_embedder`.
query_embeds:
Text query embeddings.
query_mask:
Must be provided with query_embeddings. A mask indicating which query embeddings are valid.
"""
(pred_logits, image_class_embeds) = self.class_head(image_feats, query_embeds, query_mask)
return (pred_logits, image_class_embeds)
class OwlViTClassPredictionHead(nn.Module):
def __init__(self, config: OwlViTConfig):
super().__init__()
out_dim = config.text_config.hidden_size
self.query_dim = config.vision_config.hidden_size
self.dense0 = nn.Linear(self.query_dim, out_dim)
self.logit_shift = nn.Linear(self.query_dim, 1)
self.logit_scale = nn.Linear(self.query_dim, 1)
self.elu = nn.ELU()
def forward(
self,
image_embeds: torch.FloatTensor,
query_embeds: Optional[torch.FloatTensor],
query_mask: Optional[torch.Tensor],
) -> Tuple[torch.FloatTensor]:
image_class_embeds = self.dense0(image_embeds)
if query_embeds is None:
device = image_class_embeds.device
batch_size, num_patches = image_class_embeds.shape[:2]
pred_logits = torch.zeros((batch_size, num_patches, self.query_dim)).to(device)
return (pred_logits, image_class_embeds)
# Normalize image and text features
image_class_embeds /= torch.linalg.norm(image_class_embeds, dim=-1, keepdim=True) + 1e-6
query_embeds /= torch.linalg.norm(query_embeds, dim=-1, keepdim=True) + 1e-6
# Get class predictions
pred_logits = torch.einsum("...pd,...qd->...pq", image_class_embeds, query_embeds)
# Apply a learnable shift and scale to logits
logit_shift = self.logit_shift(image_embeds)
logit_scale = self.logit_scale(image_embeds)
logit_scale = self.elu(logit_scale) + 1
pred_logits = (pred_logits + logit_shift) * logit_scale
if query_mask is not None:
if query_mask.ndim > 1:
query_mask = torch.unsqueeze(query_mask, dim=-2)
pred_logits = pred_logits.to(torch.float64)
pred_logits = torch.where(query_mask == 0, -1e6, pred_logits)
pred_logits = pred_logits.to(torch.float32)
return (pred_logits, image_class_embeds)
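To summarize my third question, here is a small self-contained sketch of how I read the two scoring paths: the evaluation logits appear to be a per-box affine transform of the same normalized dot product (cosine similarity) used at training time. All shapes and the random shift/scale values below are placeholders, not the actual learned parameters:

import torch
import torch.nn.functional as F

# Minimal sketch (shapes assumed, not taken from the repo).
batch, num_patches, num_queries, dim = 2, 6, 5, 512
image_class_embeds = torch.randn(batch, num_patches, dim)
query_embeds = torch.randn(batch, num_queries, dim)

img_n = image_class_embeds / (image_class_embeds.norm(dim=-1, keepdim=True) + 1e-6)
qry_n = query_embeds / (query_embeds.norm(dim=-1, keepdim=True) + 1e-6)

# Training-style score: plain normalized dot product (i.e. cosine similarity).
cos_sim = torch.einsum("...pd,...qd->...pq", img_n, qry_n)   # [batch, num_patches, num_queries]

# Eval-style score: the same similarity, shifted and scaled per box by values
# that the class head predicts from the unnormalized image embeddings.
logit_shift = torch.randn(batch, num_patches, 1)              # stands in for self.logit_shift(image_embeds)
logit_scale = F.elu(torch.randn(batch, num_patches, 1)) + 1   # stands in for elu(self.logit_scale(image_embeds)) + 1
pred_logits = (cos_sim + logit_shift) * logit_scale

print(cos_sim.shape, pred_logits.shape)

If this reading is correct, the attribute scores used during training and the logits used at evaluation differ by the learned shift and scale. Is that discrepancy intentional?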
Your excellent work will be a great help to my research!