
mathildecaron31 commented on August 19, 2024

Hi @KeremTurgutlu , let me open a new issue :)

@enverfakhan I have incorporated your suggested fix for the floating-point error and have also been trying to improve the forward logic in the vision_transformer.py code. Thanks a lot for your suggestion, and feedback is appreciated if you have some time :).

dino/vision_transformer.py

Lines 174 to 233 in 6687929

def interpolate_pos_encoding(self, x, w, h):
    npatch = x.shape[1] - 1
    N = self.pos_embed.shape[1] - 1
    if npatch == N and w == h:
        return self.pos_embed
    class_pos_embed = self.pos_embed[:, 0]
    patch_pos_embed = self.pos_embed[:, 1:]
    dim = x.shape[-1]
    w0 = w // self.patch_embed.patch_size
    h0 = h // self.patch_embed.patch_size
    # we add a small number to avoid floating point error in the interpolation
    # see discussion at https://github.com/facebookresearch/dino/issues/8
    w0, h0 = w0 + 0.1, h0 + 0.1
    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
        scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
        mode='bicubic',
    )
    assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
    return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def prepare_tokens(self, x):
    B, nc, w, h = x.shape
    x = self.patch_embed(x)  # patch linear embedding
    # add the [CLS] token to the embed patch tokens
    cls_tokens = self.cls_token.expand(B, -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)
    # add positional encoding to each token
    x = x + self.interpolate_pos_encoding(x, w, h)
    return self.pos_drop(x)

def forward(self, x):
    x = self.prepare_tokens(x)
    for blk in self.blocks:
        x = blk(x)
    x = self.norm(x)
    return x[:, 0]

def get_last_selfattention(self, x):
    x = self.prepare_tokens(x)
    for i, blk in enumerate(self.blocks):
        if i < len(self.blocks) - 1:
            x = blk(x)
        else:
            # return attention of the last block
            return blk(x, return_attention=True)

def get_intermediate_layers(self, x, n=1):
    x = self.prepare_tokens(x)
    # we return the output tokens from the `n` last blocks
    output = []
    for i, blk in enumerate(self.blocks):
        x = blk(x)
        if len(self.blocks) - i <= n:
            output.append(self.norm(x))
    return output

I'm closing this issue. Feel free to reopen it if there is any other problem related to the interpolation of the positional encodings.

from dino.

enverfakhan commented on August 19, 2024

Hi @mathildecaron31, thanks for the response; I also really appreciate the insight about interpolation vs. separate pos_embed. I would be curious how that behaves in the wild with completely different sizes. I tried deit_small(patch_size=8) for a retrieval task on in-house data, and it seems to work on par with a supervised VGG pretrained on ImageNet. However, I had to resize the images to [224, 224] because some of them blew up memory during the attention computation.

About the workaround for the floating-point error, I feel that incrementing w0 and h0 by a small amount is more legitimate than zero-padding the pos_embed, but it is probably not a big deal, especially if the image size is relatively big.

Looking forward to DINO on a large, random, uncurated dataset.


mathildecaron31 commented on August 19, 2024

I actually tried the deit_small(patch_size=8) for retrieval task on a in-house data, it seems to be working on par with a supervised vgg imagenet,

That's slightly disappointing :/. Have you tried the other models? For example, ViT-Base/16 should be more manageable memory-wise. As a matter of fact, on copy detection datasets I've found the base models to perform clearly better than the small ones: I get better performance with Base16x16 than with Small8x8, even though Small8x8 is better at k-NN on ImageNet.

About the workaround for the floating point error, I feel like incrementing the w0 and h0 a small amount is more legit than zero padding the pos_embed but it is probably not a big deal especially if the image size is relatively big.

Yes, your solution is definitely better! I'll update that in the code.


steve-landers commented on August 19, 2024

Is there a reason the interpolate call doesn't set the output size directly to (w0,h0) using the size parameter, rather than using the scale_factor parameter?


enverfakhan commented on August 19, 2024

I realized this solution is prone to floating-point error. Here is an example of such an error:

import math
import torch
import torch.nn as nn

dim = 384             # embedding dimension (illustrative value)
N = 784               # pretrained grid: 28 x 28 patch positions
pos_embed = torch.randn(1, N, dim)  # stand-in for the learned patch position embeddings
w, h = 491, 512
patch_size = 8
w0 = w // patch_size  # 61
h0 = h // patch_size  # 64
npatch = w0 * h0      # 3904
pos_embed = nn.functional.interpolate(
    pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
    scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
    mode='bicubic',
)
pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)  # pos_embed.shape[1] is 3840 (60 * 64) which is supposed to be 3904 (61 * 64)

This happens because (w0 / math.sqrt(N)) * math.sqrt(N) is not exactly equal to w0: it evaluates to something like 60.999999999..., and nn.functional.interpolate (I assume) truncates the resulting output size to an integer, so it becomes 60. One solution is to add a small number (0.1) to w0 and h0.
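A minimal pure-Python sketch of this failure mode, assuming (as guessed above) that a scale_factor-based resize computes its output length as floor(input_size * scale_factor). The helper `resized_length` is hypothetical; it reproduces only the size arithmetic, not the interpolation itself.

```python
import math

N = 784              # pretrained patch grid: 28 x 28 positions
side = math.sqrt(N)  # 28.0


def resized_length(target: float) -> int:
    """Output length a scale_factor-based resize would produce for `target`.

    Hypothetical helper; assumes output = floor(input_size * scale_factor).
    """
    scale = target / side             # the scale_factor handed to interpolate
    return math.floor(side * scale)   # the round trip can land just below target


# Integer grid sizes that lose one position without the offset
lost = [w0 for w0 in range(1, 129) if resized_length(w0) != w0]
# The same sizes with the 0.1 offset used in the DINO code
still_lost = [w0 for w0 in range(1, 129) if resized_length(w0 + 0.1) != w0]
print("lost without offset:", lost)
print("lost with +0.1 offset:", still_lost)
```

The offset only needs to exceed the accumulated rounding error (a few ulps) while staying well below 1, so that flooring, and the `int(w0)` in the assert, still yields the intended integer.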

However, this raises another question: does interpolating the positional embeddings make sense? Does this operation happen during training, or are all training images already the right shape so that interpolation never occurs? In the latter case, interpolating the (learned) positional embeddings would be illegitimate, right?


enverfakhan commented on August 19, 2024

I actually tried the deit_small(patch_size=8) for retrieval task on a in-house data, it seems to be working on par with a supervised vgg imagenet

I think I unintentionally spread misinformation. The images were RGBA and I was treating them as RGB; sorry for any confusion. After accounting for the image format, the results still vary from query (image) to query. In some cases DINO outperforms an ImageNet-pretrained vgg_16 by far, but in others they are almost on par, or DINO is even worse. I haven't found a consistent pattern for which images DINO outperforms or underperforms on, but so far, anecdotally, DINO seems to do better on colorful images and is on par with (or worse than) vgg_16 ImNet on black-and-white images. By the way, I'm also testing OpenAI's CLIP model with a ViT, and CLIP seems to be the worst of the three (I was betting that CLIP would be the best, but couldn't have been more wrong :) )

It's hard to come up with a quantitative evaluation. The images are multi-tagged and we are trying to retrieve similar images given a query. The tags are not reliable for evaluation because some similar images don't share any tag and, conversely, different images may share a common tag (an Instagram logo vs. an Instagram app image). I'm planning to fine-tune as a multi-class classification task and try to get some numeric assessment out of that.

I wonder if finetuning DINO models on the in-house data you have might help?

I strongly believe it would help, but I wonder which model would perform better after fine-tuning each. In any case, the preliminary result that DINO works better on colorful images is worth paying attention to.


woctezuma commented on August 19, 2024

By the way, I'm testing OpenAI's clip model with ViT too and Clip model seems to be the worst among the three (I was betting on the clip model that it would be the best, but couldn't be more wrong :) )

Not surprised. :D cf. openai/CLIP#1 and https://openai.com/blog/multimodal-neurons/


enverfakhan commented on August 19, 2024

Thank you for the reference to the issue; it was super fun to read and check the examples (and of course eye-opening :) ). So the CLIP model is out of the question.


mathildecaron31 commented on August 19, 2024

Hi @enverfakhan

Thanks for raising this issue. Yes, I totally agree that something could be done to simplify/unify the code there a bit...

For the floating-point error, I've found this workaround:

dino/vision_transformer.py

Lines 235 to 240 in 1d06521

# zero-pad any row/column that the interpolation lost to rounding
if w0 != patch_pos_embed.shape[-2]:
    helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device)
    patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2)
if h0 != patch_pos_embed.shape[-1]:
    helper = torch.zeros(w0)[None, None, :, None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device)
    patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1)

I mean does this operation happen during training or all the images in the training are just in the right shape and interpolation never occur?

This operation actually happens during training. Indeed, with multi-crop the model is trained with both 224x224 and 96x96 images, so when forwarding a batch of 96x96 images we need to interpolate the encodings. In my experiments I also tried having two sets of encodings, i.e., different encodings for the 224x224 and for the 96x96 inputs. This solution performs exactly the same as bicubic interpolation, which makes me think that the interpolation solution makes sense.
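To make the sizes concrete, here is a small pure-Python sketch of the patch grids involved, assuming the patch embedding keeps only full, non-overlapping patches and the positional embeddings were built for 224x224 inputs. The helper `patch_grid` is hypothetical.

```python
def patch_grid(img_size: int, patch_size: int) -> int:
    """Number of patches along one side; any remainder pixels are dropped."""
    return img_size // patch_size


for patch_size in (16, 8):
    g_global = patch_grid(224, patch_size)  # grid the learned pos_embed covers
    g_local = patch_grid(96, patch_size)    # grid of a multi-crop local view
    print(f"patch {patch_size}: {g_global}x{g_global} -> {g_local}x{g_local} "
          f"({1 + g_local * g_local} tokens incl. [CLS])")
```

So a 96x96 local crop needs the learned 14x14 (patch 16) or 28x28 (patch 8) grid resized to 6x6 or 12x12 respectively.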


enverfakhan commented on August 19, 2024

I picked Small 8x8 because it was shown to perform better on k-NN ImageNet, and since I was going to use it zero-shot for the retrieval task, this choice made more sense at the time. The result was indeed slightly disappointing; however, I haven't experimented exhaustively and I don't have quantitative results either. I only checked qualitatively, which you can only do for a handful of queries, so this result is not definitive at all. I should also say that the in-house data is very different from ImageNet, so I wouldn't be very surprised to get some weird results with either model.


KeremTurgutlu commented on August 19, 2024

But I should say the in-house data is very different than the imagenet, so I wouldn't be very surprised if I got some weird result with either model.

I wonder if fine-tuning DINO models on your in-house data might help? But you mentioned that the results are on par with VGG pretrained on ImageNet, so I am not very sure. Probably still worth trying.


KeremTurgutlu commented on August 19, 2024

@mathildecaron31 I have a question about copy detection. I am trying to evaluate the pretrained DINO models on a copy detection dataset, following the steps from the paper. Table 4 shows a final embedding dimension of 1536 even for different input image sizes. I don't understand how the embedding dimension can stay the same after concatenating the CLS embedding and the GeM-pooled output patch tokens for different input image sizes. Maybe I am missing something. Here is what I did:

Added the following method to VisionTransformer to return output patch tokens and cls output.

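def forward_output_patch_tokens_cls(self, x):
    B, nc, w, h = x.shape
    x = self.patch_embed(x)  # patch linear embedding

    # prepend the [CLS] token
    cls_tokens = self.cls_token.expand(B, -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)
    # interpolate_pos_encoding takes the input width/height (as in the code above)
    pos_embed = self.interpolate_pos_encoding(x, w, h)
    x = x + pos_embed
    x = self.pos_drop(x)

    for blk in self.blocks:
        x = blk(x)
    if self.norm is not None:
        x = self.norm(x)

    # B x (1 + num_patches) x embed_dim: [CLS] output first, then patch tokens
    return x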

Using GeM module from here

import torch
import torch.nn as nn
import torch.nn.functional as F

def gem(x, p=3, eps=1e-6):
    "x: BS x embed_dim x num_tokens; GeM-pools over the last (token) axis"
    return F.avg_pool1d(x.clamp(min=eps).pow(p), (x.size(-1))).pow(1./p)

class GeM(nn.Module):

    def __init__(self, p=3, eps=1e-6):
        super(GeM,self).__init__()
        self.p = nn.Parameter(torch.ones(1)*p)
        self.eps = eps

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        return self.__class__.__name__ + '(' + 'p=' + '{:.4f}'.format(self.p.data.tolist()[0]) + ', ' + 'eps=' + str(self.eps) + ')'

Collect embeddings (CLS + GeM Pooled Output Patch Tokens)

gem_pooling = GeM().cuda()  # assumed: the GeM module defined above, on the same device as the features
all_image_features = []
with torch.no_grad():
    for imgb in progress_bar(image_dl):
        outputs = model.forward_output_patch_tokens_cls(imgb.cuda())
        cls_token, output_patch_tokens = outputs[:,0],outputs[:,1:]

        cls_features   = cls_token
        # (B, num_tokens, dim) -> (B, dim, num_tokens) so that GeM pools over tokens
        patch_features = gem_pooling(output_patch_tokens.permute(0,2,1)).squeeze(-1)
        concat_features = torch.cat([cls_features,patch_features],dim=-1)
        all_image_features.append(concat_features.cpu())

Following this, with an image size of 224 and dino_vitb8, my final embedding dimension is 1536 (I initially got 1568), which can also be calculated as:

cls_feature_dim*2 = 768*2

Question
Also, for the copy detection task, do you learn the pooling parameter p, or is it picked on a validation set? I didn't quite understand the whitening part: is it the same as regular unsupervised PCA whitening?

Found this paper: https://hal.inria.fr/hal-00722622v2/document. I believe the idea comes from there.

Edit:

Figured out the 1536 dimension size. We need to pool across token positions, which gives a pooled embedding with the same dimension as the CLS token embedding (768 for ViT-B), so CLS + pooled = 2 x 768 = 1536 regardless of the input size.
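A pure-Python sketch of GeM pooling over the token axis, showing why the descriptor length does not depend on the number of tokens. The helper `gem_pool_tokens` is hypothetical, uses a fixed p rather than a learned one, and works on plain lists instead of tensors.

```python
def gem_pool_tokens(tokens, p=3.0, eps=1e-6):
    """GeM over token positions.

    tokens: list of num_tokens vectors, each of length embed_dim.
    Returns one vector of length embed_dim: (mean_i max(x_i, eps)^p)^(1/p).
    """
    num_tokens = len(tokens)
    embed_dim = len(tokens[0])
    pooled = []
    for d in range(embed_dim):
        acc = sum(max(t[d], eps) ** p for t in tokens) / num_tokens
        pooled.append(acc ** (1.0 / p))
    return pooled


embed_dim = 4  # toy size; 768 for ViT-B
for num_tokens in (196, 784):  # e.g. a 224px input with patch 16 or patch 8
    tokens = [[float((i + d) % 5) for d in range(embed_dim)] for i in range(num_tokens)]
    cls = [0.5] * embed_dim  # stand-in for the [CLS] output
    descriptor = cls + gem_pool_tokens(tokens)
    print(num_tokens, len(descriptor))  # length is 2 * embed_dim either way
```

Pooling over the embedding axis instead would give one value per token, so the descriptor length would then track the token count, which changes with the input size.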


NightMachinery commented on August 19, 2024

@mathildecaron31

This operation actually happens during training. Indeed, with multi-crop the model is trained both with images of 224x224 and images of 96x96. So when forwarding a batch of 96^2 images we need to interpolate the encodings. As a matter of fact in my experiments I also tried having two sets of encodings. In practice that means that I was using different encodings for the 224x224 and for the 96x96 inputs. This solution has exactly the same performance as when performing bicubic interpolation, which makes me think that the interpolation solution makes sense.

patch_pos_embed = nn.functional.interpolate(
    patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
        0, 3, 1, 2
    ),
    scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
    mode="bicubic",
)

So when a 96x96 crop is fed to the model during training, the positional embeddings of the original 224x224 model get scaled down to fit the 96x96 input? (Assuming a patch size of 8, the 28x28 grid of positional embeddings gets downsized to 12x12.)

Couldn't we just use a subset of the original embeddings for the smaller images? In text models, for instance, we can feed shorter sequences with no problem. Again assuming a patch size of 8, this would mean taking the first 12x12 subset of the full 28x28 grid of positional embeddings.

I guess the current interpolation regime makes the model more invariant to the scale of the images...
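The two options differ in which learned positions a small crop gets to see. A pure-Python sketch, assuming a patch size of 8 (28x28 learned grid, 12x12 crop grid) and the half-pixel convention (align_corners=False) for the resize; both helper names are hypothetical.

```python
def subset_sources(out_size: int) -> list:
    """Source grid positions used when keeping the top-left out_size positions."""
    return [float(i) for i in range(out_size)]


def interp_sources(in_size: int, out_size: int) -> list:
    """Source coordinates sampled when resizing in_size -> out_size,
    assuming the half-pixel mapping src = (dst + 0.5) * in/out - 0.5.
    """
    scale = in_size / out_size
    return [(i + 0.5) * scale - 0.5 for i in range(out_size)]


print([round(c, 2) for c in subset_sources(12)])      # corner of the 28x28 grid only
print([round(c, 2) for c in interp_sources(28, 12)])  # spread over the whole grid
```

With the subset, a 96px crop only ever uses the embeddings of the top-left corner of a 224px image; with interpolation, each token gets the embedding of the same relative location in the full image, which is consistent with the scale-invariance intuition above.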


alexaatm commented on August 19, 2024

Hi!
I stumbled on the same issue when using DINOv2: the code crashed in the same function when using rectangular input...

The function that encodes positions references this GitHub issue:

def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        print(f'DEBUG dinov2 vision_trasnformer.py: w0={w0}, h0={h0}')
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
        print(f'DEBUG dinov2 vision_trasnformer.py: add small number w0={w0}, h0={h0}')

        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
            scale_factor=(sx, sy),
            mode="bicubic",
            antialias=self.interpolate_antialias,
        )
        print(f'DEBUG dinov2 vision_trasnformer.py: patch_pos_embed.shape={patch_pos_embed.shape}')

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

I checked the output by feeding a rectangular image and found that the small addition did not change w0 and h0; see the output:

image_crop.shape=torch.Size([1, 3, 434, 546])
DEBUG dinov2 vision_trasnformer.py: w0=31, h0=39
DEBUG dinov2 vision_trasnformer.py: add small number w0=31.0, h0=39.0
DEBUG dinov2 vision_trasnformer.py: patch_pos_embed.shape=torch.Size([1, 384, 31, 38])

Note that in the init, interpolate_offset=0.1.
Here are the errors I got:

File "home/.cache/torch/hub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py", line 204, in interpolate_pos_encoding
    assert int(w0) == patch_pos_embed.shape[-2]
AssertionError

Note: used pretrained dinov2_vits14_reg model.


zshn25 commented on August 19, 2024

Is there a reason the interpolate call doesn't set the output size directly to (w0,h0) using the size parameter, rather than using the scale_factor parameter?

+1

Why can't we just do this?

# w0, h0: integer patch-grid sizes (before any offset is added)
patch_pos_embed = nn.functional.interpolate(
    patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
    size=(w0, h0),
    mode="bicubic",
    antialias=self.interpolate_antialias,
)
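For the `size` question, here is a standalone sketch that passes `size` instead of `scale_factor`, so the output grid is exactly w0 x h0 and no offset is needed. It is not the DINO or DINOv2 implementation; the function name is hypothetical, and it follows the convention of the code above that w0 maps to dimension -2. Because the sampling positions then come from the exact size ratio rather than from an offset scale factor, the values may differ slightly from the scale_factor-plus-offset version.

```python
import math

import torch
import torch.nn.functional as F


def resize_patch_pos_embed(patch_pos_embed: torch.Tensor, w0: int, h0: int) -> torch.Tensor:
    """Resize learned patch position embeddings to a w0 x h0 grid.

    patch_pos_embed: (1, N, dim), with N a perfect square (the pretrained grid).
    Returns (1, w0 * h0, dim).
    """
    _, n, dim = patch_pos_embed.shape
    m = math.isqrt(n)
    assert m * m == n, "expected a square grid of position embeddings"
    grid = patch_pos_embed.reshape(1, m, m, dim).permute(0, 3, 1, 2)  # (1, dim, m, m)
    # `size` fixes the output shape exactly, so floor() rounding cannot drop a row
    grid = F.interpolate(grid, size=(w0, h0), mode="bicubic", align_corners=False)
    assert grid.shape[-2:] == (w0, h0)
    return grid.permute(0, 2, 3, 1).reshape(1, -1, dim)


pos = torch.randn(1, 28 * 28, 384)         # e.g. 224px input, patch 8, dim 384
out = resize_patch_pos_embed(pos, 61, 64)  # the 61 x 64 grid from the example above
print(out.shape)  # torch.Size([1, 3904, 384])
```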
