
Comments (13)

MaratZakirov commented on July 21, 2024

@fmassa Thank you, that is exactly what I was asking for! So you provide the information about the real image via the mask.


fmassa commented on July 21, 2024

Hi again,

So, I think that the main thing we need to take into account here is that a Transformer encoder is permutation-equivariant. This means that we can shuffle all the pixels in the feature map, and the output of the encoder will be shuffled accordingly. In the same vein, the Transformer decoder is permutation-invariant with respect to the feature maps that we feed it, which means that the order in which we feed the input pixels doesn't matter.

With that in mind, the only way the transformer can predict relative coordinates is via the positional encoding that we feed it. I described in my previous post that the positional encoding only takes the region inside the image into account. But I didn't describe what the mask is and how it is computed, which I'm doing now.

If you look at

detr/util/misc.py

Lines 283 to 300 in be9d447

def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], :img.shape[2]] = False
    else:
        raise ValueError('not supported')
    return NestedTensor(tensor, mask)
, which is used in the collate_fn that we pass to the DataLoader, you'll see that images of different sizes are padded so that they all have the same size. But we also keep another tensor (named mask) around, which holds the information about which pixels belong to the image and which ones are just padding and should not be taken into account.
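
For illustration, here is one quick way to see what this function produces (the image sizes are just an example; this assumes the detr repo is on the Python path so that util.misc imports):

    import torch
    from util.misc import nested_tensor_from_tensor_list  # path inside the detr repo

    img1 = torch.randn(3, 200, 300)
    img2 = torch.randn(3, 300, 200)
    batch = nested_tensor_from_tensor_list([img1, img2])

    print(batch.tensors.shape)              # torch.Size([2, 3, 300, 300]): padded to the max size per axis
    print(batch.mask.shape)                 # torch.Size([2, 300, 300])
    print(batch.mask[0, :200, :300].any())  # tensor(False): real pixels of img1
    print(batch.mask[0, 200:, :].all())     # tensor(True): padded rows of img1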

This mask is used at a few places:

  • the positional encoding, in order to properly compute the coordinates inside the image without taking padding into account
  • the transformer, which does not take the padded regions into account

So while the feature maps from the CNN will indeed be different for different image sizes, the transformer will only look at the regions that do not correspond to padding. And for predicting the 0-1 relative coordinates of the boxes, it will also only look at the features inside the image.

This can be a bit hard to explain, but please let us know if this isn't clear and we will try to explain it better.
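
As an aside, here is a minimal sketch, not DETR's exact transformer code, of how such a padding mask can be handed to PyTorch attention layers as a key-padding mask so that padded pixels are never attended to (the shapes below are made up for the example):

    import torch
    from torch import nn

    b, c, h, w = 2, 256, 25, 38                    # hypothetical feature-map shape
    features = torch.randn(b, c, h, w)
    mask = torch.zeros(b, h, w, dtype=torch.bool)  # True marks padded pixels
    mask[0, :, 30:] = True                         # pretend image 0 is narrower than the batch

    src = features.flatten(2).permute(2, 0, 1)     # (h*w, b, c): sequence-first layout
    key_padding_mask = mask.flatten(1)             # (b, h*w)

    layer = nn.TransformerEncoderLayer(d_model=c, nhead=8)
    encoder = nn.TransformerEncoder(layer, num_layers=1)
    out = encoder(src, src_key_padding_mask=key_padding_mask)  # padded positions are ignored by attention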


fmassa commented on July 21, 2024

@MaratZakirov about the boxes:
This is needed because we perform the box normalization (to be between 0 and 1) as the last step of our transformations. Doing it this way is easier because handling crops and some other transformations becomes simpler. Because of that, and in order to keep the transformations correct, we need to scale the boxes as well, so that if one day we decide to predict everything in absolute coordinates we just need to change one line, see

boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
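
A tiny numerical check (an assumed example, not DETR code) of why normalizing last works out: scaling the boxes together with the image leaves the final normalized coordinates identical to normalizing against the original size.

    import torch

    orig_w, orig_h = 640, 480
    box = torch.tensor([100., 120., 300., 360.])  # a box in original pixel coordinates

    new_w, new_h = 1200, 800                      # size after a Resize transform
    ratio = torch.tensor([new_w / orig_w, new_h / orig_h, new_w / orig_w, new_h / orig_h])
    scaled_box = box * ratio                      # what the resize transform stores in "boxes"

    norm_from_scaled = scaled_box / torch.tensor([new_w, new_h, new_w, new_h])
    norm_from_orig = box / torch.tensor([orig_w, orig_h, orig_w, orig_h])
    assert torch.allclose(norm_from_scaled, norm_from_orig)  # same relative coordinates either way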

For the area, this is not actually used anymore in the codebase. We used it to be able to separate between small, medium and large objects. In fact, we might not have needed to scale it even when it was used, but for consistency it was preferable to perform the data transformation over all fields.

Let us know if you have further questions!


alcinos commented on July 21, 2024

Hi @MaratZakirov
Let me try to clarify the situation with an example. Imagine your input consists of two images, img1 of size (200, 300) and img2 of size (300, 200)

  • "orig_size" corresponds to the size of each image in the batch, before any pre-processing. In this situation, the tensor will contain [[200, 300], [300, 200]]. See for example how the postprocessor uses it to compute the boxes with respect to the original size:

    detr/models/detr.py

    Lines 279 to 282 in be9d447

    # and from relative [0, 1] to absolute [0, height] coordinates
    img_h, img_w = target_sizes.unbind(1)
    scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
    boxes = boxes * scale_fct[:, None, :]

  • "size" corresponds to the size of each image in the batch, after pre-processing but before batching. In our validation pre-processing for example, all images are resized so that their small side is 800. The tensor will thus contain [[800, 1200], [1200, 800]]. This field is mostly used in panoptic segmentation, if you're only interested in detection you can ignore it.

  • The actual tensor size will be obtained after batching, which implies padding. In our example, the final size will be something like [2, 3, 1200, 1200]

DETR's predictions are, by design, always made with respect to the original image size, disregarding potential padding. That's why one can't use the size of the batched tensor as H, W, and must instead rely on the "orig_size" field.
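
For completeness, here is a hedged sketch of the post-processing step quoted above; the (cx, cy, w, h) to corner-coordinate conversion is re-implemented here just to keep the example self-contained:

    import torch

    def cxcywh_to_xyxy(b):
        cx, cy, w, h = b.unbind(-1)
        return torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)

    pred_boxes = torch.rand(2, 100, 4)                   # (batch, num_queries, 4), relative [0, 1] cxcywh
    orig_sizes = torch.tensor([[200, 300], [300, 200]])  # per-image (h, w), as in the example above

    boxes = cxcywh_to_xyxy(pred_boxes)
    img_h, img_w = orig_sizes.unbind(1)
    scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)  # per-image (w, h, w, h)
    abs_boxes = boxes * scale_fct[:, None, :]                     # absolute pixel coordinates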

Hope this helps. I think I have answered your question, and as such I'm closing this, but feel free to reach out if anything remains unclear.


MaratZakirov commented on July 21, 2024

@alcinos the code fragment you provided works ONLY at the evaluation stage; at the training stage the problem I described remains.

Although it seems that we understood each other, let me describe it more distinctly.

Suppose we have batch_size=2 and 3 images in our dataset

  1. with size (500, 500)
  2. with size (1000, 500)
  3. with size (500, 1000)

Let's say that somehow we switched off the preprocessing/augmentations (resizes, crops, etc.), so "orig_size" will always be equal to "size". But due to random batching we could obtain the following batches:

Batch one:
image 1 with orig_size=(500, 500), size=(500, 500), as part of a tensor of shape (2, 3, 1000, 500)
image 2 with orig_size=(1000, 500), size=(1000, 500), as part of a tensor of shape (2, 3, 1000, 500)

Batch two:
image 1 with orig_size=(500, 500), size=(500, 500), as part of a tensor of shape (2, 3, 500, 1000)
image 3 with orig_size=(500, 1000), size=(500, 1000), as part of a tensor of shape (2, 3, 500, 1000)

In these two cases the box coordinates for image 1 will be ABSOLUTELY the same. So at the training stage DETR will learn the same image 1 coordinates for DIFFERENT image 1 tensors. This does not seem normal, does it?

PS

Maybe DETR implicitly (as a super strong CNN) cleans off the padding zeros in the tensor, somehow accounts for this fact, and in the end refers to the same tensor features in both cases? ...


MaratZakirov commented on July 21, 2024

I did some investigation and put a print just before the criterion call, where all the modifications to the targets have already been done and they go to the HungarianMatcher and the other cool stuff as they are:

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        <<<<<< I show the image here; see the images below

        outputs = model(samples)
        loss_dict = criterion(outputs, targets)

This is what I get for batch size equal to 1:
(image: batch1)
This is what I get for batch size equal to 2:
(image: batch2)
I print the bounding rectangles according to the input tensor size, but please notice that although the tensor sizes are completely different in these two cases, the box coordinates are absolutely the same at the training stage!!!!
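
For reference, the padding bookkeeping available at that point in the loop can also be inspected directly (a rough sketch; in DETR's training loop samples is a NestedTensor, so it carries both the padded tensor and the padding mask):

    print(samples.tensors.shape)      # padded batch, e.g. torch.Size([2, 3, 1000, 500])
    print(samples.mask.shape)         # e.g. torch.Size([2, 1000, 500])
    print(targets[0]['size'])         # un-padded size of image 0 after augmentation
    print(targets[0]['orig_size'])    # size of image 0 before any pre-processing
    print(targets[0]['boxes'][:3])    # normalized (cx, cy, w, h) w.r.t. the un-padded size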


fmassa commented on July 21, 2024

Hi,

I believe the issue arises due to a bug in your visualization code.

Before pasting the boxes on the image, you need to crop the image so that it is of size target['size'], in order to take into account the padding that was added when batching multiple images together.

Here is what I would add to your implementation:

H, W = target['size']
img = img[:, :H, :W]
boxes[:, 0::2] *= W
boxes[:, 1::2] *= H

Note that our predictions are made in 0-1, and we take extra care in our positional encoding to account only for the regions which are inside the image, see

not_mask = ~mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)
x_embed = not_mask.cumsum(2, dtype=torch.float32)

which is basically a batched mesh-grid, but that takes the original image size into account to avoid the issues you mentioned.
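
A small numerical check of that claim (assumed shapes; this mirrors only the cumsum-and-normalize part of the positional encoding, not the full sine embedding): the normalized coordinates inside the valid region come out the same no matter how much padding is added.

    import torch

    def normalized_coords(mask, eps=1e-6):
        not_mask = ~mask
        y = not_mask.cumsum(1, dtype=torch.float32)
        x = not_mask.cumsum(2, dtype=torch.float32)
        y = y / (y[:, -1:, :] + eps)  # normalize by the last valid row/column,
        x = x / (x[:, :, -1:] + eps)  # as done when normalize=True
        return x, y

    valid_h, valid_w = 10, 15         # feature-map size of the real (un-padded) image
    ref = None
    for pad_h, pad_w in [(10, 15), (20, 15), (10, 30)]:  # increasing amounts of padding
        mask = torch.ones(1, pad_h, pad_w, dtype=torch.bool)
        mask[:, :valid_h, :valid_w] = False              # False marks real pixels
        x, y = normalized_coords(mask)
        crop = torch.stack([x[0, :valid_h, :valid_w], y[0, :valid_h, :valid_w]])
        ref = crop if ref is None else ref
        assert torch.allclose(crop, ref)                 # identical for every padding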


MaratZakirov commented on July 21, 2024

Hi @fmassa, my question was not about visualization; the visualization I made is just for illustration. And if you read my post carefully you would see the following.

        H, W = target['size'] <<< works well only with this
        W, H = img.width, img.height <<< but it must work with this!!!

So I already know what you are trying to tell me, but my question is about training, not visualization, which is just an illustration; you may forget about the green rectangles and look at the white numbers at the top left of the image, which I printed.

The point I tried to make as clear as possible is that DETR is forced to learn the SAME box coordinates for DIFFERENT tensor sizes (due to batch padding), which, I believe, is completely wrong. Look at the two images above. They have different WxH but are coupled with the SAME coordinates, which I printed (in white).

So in one case (one batch) we are asking DETR to output one set of coordinates (in [0, 1]) for the phone, for example.
And in batch two we are asking DETR to output the same set of coordinates as in the previous case, but for the same image padded differently.

This scheme can be trained successfully only if DETR is smart enough to implicitly account for the padding and pay attention to the same visual features in both batches (or you must provide the 'size' field explicitly as an NN input and hope for the best).


MaratZakirov commented on July 21, 2024

So, for example, suppose two batches of size 2 both have the same image at position 0.
So formally we have:
B1.shape = (2, 3, H1, W1)
B2.shape = (2, 3, H2, W2)
H_c = min(H1, H2)
W_c = min(W1, W2)
B1[0, :, 0:H_c, 0:W_c] == B2[0, :, 0:H_c, 0:W_c]
B1.targets[0]['boxes'] == B2.targets[0]['boxes']

Suppose we have an important 'phone' feature at address [0, :, i, j].
So DETR, according to it, is asked to produce the same coordinates (a, b, h, w).



linzzzzzz commented on July 21, 2024


Hi @alcinos, I have a quick question on this last comment: "DETR's predictions are, by design, always made with respect to the original image size, disregarding potential padding".

I saw that in transforms.py you are adjusting the boxes and area of the target when doing resizing. However, why would you need to adjust the target if DETR's predictions are always made with respect to the original image size?

detr/datasets/transforms.py

Lines 114 to 123 in 10a2c75

target = target.copy()
if "boxes" in target:
    boxes = target["boxes"]
    scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
    target["boxes"] = scaled_boxes

if "area" in target:
    area = target["area"]
    scaled_area = area * (ratio_width * ratio_height)
    target["area"] = scaled_area


jiangwei221 commented on July 21, 2024

Hi DETR team, I have the same question that @MaratZakirov raised before.
I understand that the Transformer encoder is permutation-equivariant and the Transformer decoder is permutation-invariant.
But the black border will actually change the values of the feature maps through the convolutions, if I understand correctly.

I did the following runs and printed the values after this line:

outputs_coord = self.bbox_embed(hs).sigmoid()

python main.py --batch_size 1 --no_aux_loss --eval --resume https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth --coco_path "d:/dataset/coco"

In [5]: outputs_coord[-1][0,0,:]
Out[5]: tensor([0.0791, 0.5397, 0.0183, 0.0526], device='cuda:0')

python main.py --batch_size 2 --no_aux_loss --eval --resume https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth --coco_path "d:/dataset/coco"

In [1]: outputs_coord[-1][0,0,:]
Out[1]: tensor([0.0787, 0.5397, 0.0188, 0.0531], device='cuda:0')

The only difference between these two runs is the batch size, and you can see the predicted values are not exactly the same.

Here is a paper I'd like to mention: https://arxiv.org/abs/2001.08248, but I'm not sure it is related.


fmassa commented on July 21, 2024

@jiangwei221 correct, the padding will affect the convolution output, and thus potentially change the output of the model.

I would like to note though that the same happens with other methods as well, such as Faster R-CNN.
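
A toy illustration of this point (an assumed two-layer CNN, not DETR's ResNet backbone): after the first layer the padded area is generally no longer zero, so deeper layers see different context near the image border than they would for the un-padded image.

    import torch
    from torch import nn

    torch.manual_seed(0)
    net = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 8, kernel_size=3, padding=1),
    )

    h, w, pad = 32, 32, 16
    img = torch.randn(1, 3, h, w)
    padded = torch.zeros(1, 3, h + pad, w + pad)
    padded[:, :, :h, :w] = img                     # same image, batch-style zero padding

    with torch.no_grad():
        diff = (net(img) - net(padded)[:, :, :h, :w]).abs()

    print(diff[:, :, 2:-2, 2:-2].max().item())     # ~0: interior features are unchanged
    print(diff.max().item())                       # noticeably larger: the border band changed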


jiangwei221 commented on July 21, 2024

I see, thanks for the clarification!

