
Comments (2)

AntiLibrary5 commented on July 20, 2024

A simple implementation looks like this, assuming a model with the standard image and patch size (so the patches form a 14x14 grid):

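import torch.nn as nn

def patch_cosSim(model, pos_i, pos_j):
    """
    Arguments:
    model: ViT (assumed to expose embeddings.position_embeddings of shape
           (1, 1 + 14*14, hidden), with the [CLS] position at index 0)
    pos_i: row index of the reference patch (int)
    pos_j: column index of the reference patch (int)

    returns: list of cos similarities between the pos emb of patch (pos_i, pos_j)
             and the pos emb of every patch, in row-major order
    """
    cos = nn.CosineSimilarity(dim=0)
    s = model.embeddings.position_embeddings.shape
    # Drop the [CLS] position (assumed to be index 0), then lay the patches out on the 14x14 grid.
    pos_patch = model.embeddings.position_embeddings.view(*s[1:])[1:].reshape(14, 14, -1)
    n_rows = pos_patch.shape[0]
    n_cols = pos_patch.shape[1]
    patches = []
    for i in range(n_rows):
        for j in range(n_cols):
            # Compare the reference patch with the patch at (i, j).
            patches.append(cos(pos_patch[pos_i][pos_j], pos_patch[i][j]).detach().cpu().numpy())
    return patches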

Plot the cos similarity of the pos emb of a single patch, here the one at row 0, column 1:

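import numpy as np
import matplotlib.pyplot as plt

i = 0  # row of the reference patch
j = 1  # column of the reference patch
patch = patch_cosSim(model, i, j)  # model: the ViT loaded beforehand
fig = plt.figure()
ax = fig.add_subplot()
ax.set_axis_off()
ax.imshow(np.array(patch).reshape(14,14), cmap='hot', interpolation='nearest')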

[image: plot of the resulting similarity map]

You can put it in a loop to get the cos similarity of the positional encoding for all the patches.
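If the loop gets slow, the same thing can be done for all patches at once. This is only a rough sketch under the same assumptions as above (14x14 grid, [CLS] position at index 0); `all_patch_cos_sim` and `dummy_pos_emb` are names made up for this example, and the random tensor merely stands in for `model.embeddings.position_embeddings`:

```python
import torch
import torch.nn.functional as F

def all_patch_cos_sim(pos_emb: torch.Tensor, grid: int = 14) -> torch.Tensor:
    """Cosine similarity between every pair of patch position embeddings.

    pos_emb: tensor of shape (1, 1 + grid*grid, hidden); the [CLS] position
             is assumed to sit at index 0.
    returns: tensor of shape (grid, grid, grid, grid); entry [i, j] is the
             (grid x grid) similarity map for the patch at row i, column j.
    """
    patches = pos_emb.detach()[0, 1:]        # (grid*grid, hidden), [CLS] dropped
    patches = F.normalize(patches, dim=-1)   # scale every embedding to unit length
    sim = patches @ patches.T                # dot products of unit vectors = cosine similarity
    return sim.reshape(grid, grid, grid, grid)

# Hypothetical stand-in for the real position embeddings (random values, right shape only).
dummy_pos_emb = torch.randn(1, 1 + 14 * 14, 768)
sim_maps = all_patch_cos_sim(dummy_pos_emb)
```

`sim_maps[i, j]` can then be passed to `imshow` for each (i, j), e.g. on a 14x14 grid of subplots.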
Hope it helps.

[Edit: added plot]

from vision_transformer.

lucasb-eyer commented on July 20, 2024

Hi, will not share code, but it is basically just distance (I don't remember if dot product or euclidean) between all position embeddings to all other position embeddings.
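For illustration only (this is not the original code): an all-pairs comparison of the kind described could be sketched as below, covering both the dot-product and the Euclidean variant since it is not stated which one was used. The function name and the random `emb` matrix are placeholders invented for this sketch.

```python
import torch

def pairwise_dot_and_euclidean(emb: torch.Tensor):
    """emb: (N, D) matrix with one position embedding per row.

    Returns two (N, N) matrices: all pairwise dot products and
    all pairwise Euclidean distances.
    """
    dot = emb @ emb.T                    # dot[a, b] = <emb[a], emb[b]>
    dist = torch.cdist(emb, emb, p=2)    # dist[a, b] = ||emb[a] - emb[b]||_2
    return dot, dist

# Placeholder data: 196 patch positions with 768-dim embeddings (random values).
emb = torch.randn(196, 768)
dot, dist = pairwise_dot_and_euclidean(emb)
```

Row `a` of either matrix, reshaped to the patch grid, gives one map per position embedding.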

