
FengLi-ust commented on June 8, 2024

Hey, thanks for your interest in our work. We are glad to see that you implemented DINO based on DN-DETR. Is there any difference between your re-implemented DINO and our DINO? Can it achieve the same performance as our DINO? Could you provide more information so we can discuss it in more detail?

from dn-detr.

Vallum commented on June 8, 2024

First of all, here are the results with ResNet-50 + Deformable DETR + DN + DINO, 36 epochs:

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.509
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.691
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.556
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.337
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.542
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.653
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.380
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.658
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.730
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.570
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.772
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.880

I think the performance is almost the same as your official DINO.

For MQS (mixed query selection), I wanted to keep the choice between the original Deformable DETR two-stage variant and your other variants:

        if self.two_stage:
            output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)

            # hack implementation for two-stage Deformable DETR
            enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
            enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals

            topk = self.two_stage_num_proposals
            topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
            topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
            topk_coords_unact = topk_coords_unact.detach()
            reference_points = topk_coords_unact.sigmoid()

            # MQS is dab + mqs
            if self.use_mqs:
                assert self.use_dab
                # sometimes the target is empty; add a zero-weighted term of tgt_embed to avoid unused parameters
                # (out-of-place, so reference_points is not modified through an alias, and no hard-coded .cuda())
                reference_points_mqs = reference_points + self.tgt_embed.weight[0, 0] * 0.0
                # learnable content queries shared across the batch: (bs, nq, d_model)
                tgt_mqs = self.tgt_embed.weight[:, None, :].repeat(1, bs, 1).transpose(0, 1)

                # query_embed is not None when training.
                if query_embed is not None:
                    reference_points_dab = query_embed[..., self.d_model:].sigmoid()
                    tgt_dab = query_embed[..., :self.d_model]

                    reference_points = torch.cat([reference_points_dab, reference_points_mqs], dim=1)
                    tgt = torch.cat([tgt_dab, tgt_mqs], dim=1)
                else:
                    reference_points = reference_points_mqs
                    tgt = tgt_mqs
            else:
                pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact, self.d_model)))
                query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
        else:
            if self.use_dab:
                reference_points = query_embed[..., self.d_model:].sigmoid()
                tgt = query_embed[..., :self.d_model]
                # tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
            else:
                query_embed, tgt = torch.split(query_embed, c, dim=1)
                query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
                tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
                reference_points = self.reference_points(query_embed).sigmoid()
                # bs, num_queries, 2

        init_reference_out = reference_points
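A minimal, self-contained sketch of what the `use_mqs` branch does at inference time: positional queries (reference boxes) come from the top-k encoder proposals, detached and sigmoided, while content queries are the static learnable `tgt_embed`. Assumptions: `mixed_query_selection` is a hypothetical helper, proposals are scored by the maximum class logit (as DINO does) rather than by `enc_outputs_class[..., 0]`, `tgt_embed` has exactly one row per selected query, and the concatenation with DN queries during training is left out.

```python
import torch


def mixed_query_selection(enc_logits, enc_coords_unact, tgt_embed, topk):
    """Build decoder queries the MQS way (hypothetical helper).

    enc_logits:       (bs, hw, num_classes) encoder class logits per proposal
    enc_coords_unact: (bs, hw, 4) encoder boxes before sigmoid
    tgt_embed:        (topk, d_model) learnable content queries
    """
    bs = enc_logits.shape[0]
    # Score each proposal by its best class and keep the top-k indices.
    scores = enc_logits.max(dim=-1).values                       # (bs, hw)
    topk_idx = torch.topk(scores, topk, dim=1).indices           # (bs, topk)
    coords = torch.gather(
        enc_coords_unact, 1, topk_idx.unsqueeze(-1).expand(-1, -1, 4)
    )                                                            # (bs, topk, 4)
    # Positional queries: selected boxes, detached so no gradient flows back
    # into the encoder through the reference points.
    reference_points = coords.detach().sigmoid()
    # Content queries: the same learnable embeddings for every image.
    tgt = tgt_embed.unsqueeze(0).expand(bs, -1, -1)              # (bs, topk, d_model)
    return reference_points, tgt


bs, hw, num_classes, d_model, nq = 2, 50, 91, 256, 10
ref, tgt = mixed_query_selection(
    torch.randn(bs, hw, num_classes),
    torch.randn(bs, hw, 4),
    torch.nn.Parameter(torch.randn(nq, d_model)),
    topk=nq,
)
```

The split mirrors the snippet: only the content queries carry gradients into `tgt_embed`, while the reference points are fixed per image by the encoder.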

For negative sampling, the official DINO moves the negative sampling out of dn_components.py, which made it hard for me to use DN-DETR and DINO at the same time. So I implemented the double (positive + negative) sampling inside dn_components.py:

# in def prepare_for_dn
    if training:
        if contrastive:
            new_targets = []
            for t in targets:
                new_t = {}
                # negative copy: same boxes, labeled num_classes ("no object"); full_like keeps dtype and device
                new_t['labels'] = torch.cat([t['labels'], torch.full_like(t['labels'], num_classes)], dim=0)
                new_t['boxes'] = torch.cat([t['boxes'], t['boxes']], dim=0)
                new_targets.append(new_t)
            targets = new_targets
        known = [(torch.ones_like(t['labels'])).cuda() for t in targets] # [ [ 1, 1], [1, 1, 1], ... ]
        know_idx = [torch.nonzero(t) for t in known] # [ [0, 1], [0, 1, 2], ... ]
        known_num = [sum(k) for k in known] # [ 2, 3, ... ]
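The target-doubling step of the contrastive branch can be isolated as a small function and checked on dummy targets. This is a sketch: `add_negative_targets` is a hypothetical name, and, like the snippet, it keeps only `labels` and `boxes`. Whatever noise is applied to the negative copies afterwards is outside this sketch.

```python
import torch


def add_negative_targets(targets, num_classes):
    """Append a negative copy of every ground-truth box (hypothetical helper).

    The copies reuse the same boxes but are labeled ``num_classes``, the
    "no object" index, so they can be trained to predict background.
    """
    new_targets = []
    for t in targets:
        labels, boxes = t["labels"], t["boxes"]
        new_targets.append({
            "labels": torch.cat([labels, torch.full_like(labels, num_classes)]),
            "boxes": torch.cat([boxes, boxes], dim=0),
        })
    return new_targets


targets = [{"labels": torch.tensor([3, 7]), "boxes": torch.rand(2, 4)}]
out = add_negative_targets(targets, num_classes=91)
```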

With this implementation, I measured the performance step by step, adding one module at a time from DN-DETR to DN+DINO (LFT = look forward twice, CDN = contrastive denoising):

source  model (ResNet-50)            epochs  AP    AP50  AP75  APS   APM   APL
paper   dino-MQS-LFT-4scale          12      47.9  65.3  52.1  31.2  50.9  61.9
paper   dino-MQS-LFT-5scale          12      48.3  65.8  52.4  32.2  51.3  62.2
paper   DN-DDETR-4scale              12      43.4  61.9  47.2  24.8  46.8  59.4
self    DN-DDETR-MQS-4scale          12      48.2  66.0  52.6  29.9  51.4  63.0
self    DN-DDETR-MQS-LFT-4scale      12      48.1  65.3  52.4  30.4  51.3  62.7
self    DN-DDETR-CDN-MQS-LFT-4scale  12      48.2  65.4  52.5  31.1  51.1  63.3

source  model (ResNet-50)            epochs  AP    AP50  AP75  APS   APM   APL
paper   dino-MQS-LFT-4scale          36      50.5  68.3  55.1  32.7  53.9  64.9
paper   dino-MQS-LFT-5scale          36      51.0  69.0  55.6  34.1  53.6  65.6
paper   DN-DDETR-4scale              50      48.6  67.4  52.7  31.0  52.0  63.7
self    DN-DDETR-MQS-4scale          36      49.9  68.2  54.0  34.6  53.2  64.6
self    DN-DDETR-MQS-LFT-4scale      36      50.3  68.2  55.0  32.7  53.3  65.1
self    DN-DDETR-CDN-MQS-LFT-4scale  36      50.7  68.7  55.4  33.2  54.1  65.2
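To read the ablation more directly, the AP of the "self" rows above can be turned into per-module gains, each row compared with the one before it. The numbers are copied from the table; the dictionary layout and `step_gains` are only illustrative.

```python
# AP of the self-reported 4-scale runs, copied from the table above.
ap = {
    12: {"MQS": 48.2, "MQS-LFT": 48.1, "CDN-MQS-LFT": 48.2},
    36: {"MQS": 49.9, "MQS-LFT": 50.3, "CDN-MQS-LFT": 50.7},
}


def step_gains(row):
    """AP change from each module added on top of the previous row."""
    names = list(row)
    return {b: round(row[b] - row[a], 1) for a, b in zip(names, names[1:])}


for epochs, row in ap.items():
    print(epochs, step_gains(row))
```

At 36 epochs, LFT and CDN each add about 0.4 AP; at 12 epochs the three variants are within 0.1 AP of each other.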

FengLi-ust commented on June 8, 2024

That's great. It looks like you are comparing with the old DINO in your table. If you initialize the parameters as in the new DINO, you can reach around 49.0 AP in 12 epochs.

Vallum commented on June 8, 2024

@FengLi-ust Could you let me know which parameters differ between the newer and the older DINO?
I followed the settings in the older paper, but I cannot find what is different. To me, the parameters look exactly the same.

With my latest 12-epoch setting, I got:

 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.487
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.664
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.530
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.309
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.520
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.631

It looks similar, but the small-object AP is comparatively lower (0.309 vs. 0.32), so I am not sure our settings are the same or even similar.

Namespace(amp=False, aux_loss=True, backbone='resnet50', backbone_freeze_keywords=None,
batch_norm_type='FrozenBatchNorm2d', batch_size=2, bbox_loss_coef=5, box_noise_scale=0.4, 
clip_max_norm=0.1, cls_loss_coef=1, coco_panoptic_path=None, coco_path='data/coco', contrastive=True, 
dataset_file='coco', debug=False, dec_layers=6, dec_n_points=4, device='cuda', 
dice_loss_coef=1, dilation=False, dim_feedforward=2048, 
dist_backend='nccl', dist_url='env://', distributed=True, drop_lr_now=False, 
dropout=0.0, enc_layers=6, enc_n_points=4, eos_coef=0.1, epochs=12, eval=False, 
find_unused_params=False, finetune_ignore=None, fix_size=False, focal_alpha=0.25, 
frozen_weights=None, giou_loss_coef=2, gpu=0, hidden_dim=256, 
label_noise_scale=0.5, local_rank=0, lr=0.0001, lr_backbone=1e-05, lr_drop=10, 
mask_loss_coef=1, masks=False, modelname='dn_dab_deformable_detr', nheads=8, note='', 
num_feature_levels=4, num_patterns=0, num_queries=900, num_results=300, num_select=300, 
num_workers=10, output_dir='exps/r50_dn_dab_deformable_detr_two_stage_refactor_12epochs', 
pe_temperatureH=20, pe_temperatureW=20, position_embedding='sine', pre_norm=False, 
pretrain_model_path=None, random_refpoints_xy=False, rank=0, remove_difficult=False, 
return_interm_layers=False,
save_checkpoint_interval=10, save_log=False, save_results=False, scalar=200, seed=42, 
set_cost_bbox=5, set_cost_class=2,set_cost_giou=2, 
start_epoch=0, transformer_activation='relu', tsst=False, two_stage=True, 
use_dn=True, weight_decay=0.0001, world_size=8)

FengLi-ust commented on June 8, 2024

I just had a quick look. lr_drop should be set to 11.
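Why one epoch matters: assuming the training script builds its scheduler the way DETR's main.py does, i.e. `torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)` with the default `gamma=0.1`, stepped once per epoch, then in a 12-epoch run `lr_drop=10` trains two epochs at the reduced LR while `lr_drop=11` trains only the last one. The helpers below are hypothetical and only mimic that schedule, using `lr=0.0001` from the Namespace above.

```python
def lr_at_epoch(epoch, base_lr=1e-4, lr_drop=11, gamma=0.1):
    """LR for a 0-indexed epoch under an assumed StepLR schedule stepped per epoch."""
    return base_lr * gamma ** (epoch // lr_drop)


def reduced_lr_epochs(epochs, lr_drop):
    """Number of epochs trained after the first LR drop (assumes epochs < 2 * lr_drop)."""
    return max(0, epochs - lr_drop)


for lr_drop in (10, 11):
    print(lr_drop, reduced_lr_epochs(12, lr_drop), [lr_at_epoch(e, lr_drop=lr_drop) for e in range(12)])
```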

FengLi-ust commented on June 8, 2024

@Vallum Hey, you can open a pull request so I can merge your code. We can also discuss here if you run into problems.

Vallum commented on June 8, 2024

@FengLi-ust Thank you for the response. Let me check the results first and get the code ready.

seungyonglee0802 commented on June 8, 2024

@Vallum Thank you for your nice work. I just want to ask a tiny question.

In the code that chooses the top-K proposals from the encoder class outputs, you wrote:
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
But I think enc_outputs_class[..., 0] only considers class 0 (and which category is class 0 depends on the dataset).
In my opinion, topk_proposals = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1] would be better, since it selects proposals by considering all classes (as DINO does).
I'd like to hear your opinion. Thanks in advance :)
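A toy comparison of the two scoring rules, with made-up logits for one image, four proposals and three classes. One caveat: in the original two-stage Deformable DETR, the loss on the encoder outputs relabels every target as class 0, so index 0 acts as a generic foreground score there. If the encoder head in this implementation is trained with the real class labels, scoring by the maximum over classes avoids tying selection to whichever category happens to be class 0.

```python
import torch

# Made-up encoder logits: 1 image, 4 proposals, 3 classes.
enc_outputs_class = torch.tensor([[[ 2.0, -1.0, -1.0],
                                   [-3.0,  4.0, -1.0],
                                   [-2.0, -2.0,  5.0],
                                   [ 1.0,  0.0,  0.0]]])
topk = 2

# Score = class-0 logit only: confident proposals of other classes are skipped.
by_class0 = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
# Score = best logit over all classes, as in the suggestion above.
by_max = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]
print(by_class0.tolist(), by_max.tolist())
```

Here the class-0 rule keeps proposals 0 and 3, while the max rule keeps proposals 2 and 1, which are the two most confident detections overall.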
