
vicentevivan / geo-clip

This is an official PyTorch implementation of our NeurIPS 2023 paper "GeoCLIP: Clip-Inspired Alignment between Locations and Images for Effective Worldwide Geo-localization"

Home Page: https://arxiv.org/abs/2309.16020

License: MIT License

Language: Python 100.00%
Topics: deep-learning, geography, geolocalization, geolocation-estimation, gps-embeddings, machine-learning, pytorch

geo-clip's People

Contributors

vicentevivan

geo-clip's Issues

Performance Improvement: Precompute and Cache Location Features

Hello,

I've been using the GeoCLIP model and noticed potential for performance improvement by precomputing and caching the location features. This change can significantly speed up inference time by avoiding redundant computations and data transfers.

Proposed Changes

Current Implementation:

def forward(self, image, location):
    # Compute Features
    image_features = self.image_encoder(image)
    location_features = self.location_encoder(location)
    logit_scale = self.logit_scale.exp()

    # Normalize features
    image_features = F.normalize(image_features, dim=1)
    location_features = F.normalize(location_features, dim=1)

    # Cosine similarity (Image Features & Location Features)
    logits_per_image = logit_scale * (image_features @ location_features.t())

    return logits_per_image

@torch.no_grad()
def predict(self, image_path, top_k):
    image = Image.open(image_path)
    image = self.image_encoder.preprocess_image(image)
    image = image.to(self.device)

    gps_gallery = self.gps_gallery.to(self.device)

    logits_per_image = self.forward(image, gps_gallery)
    probs_per_image = logits_per_image.softmax(dim=-1).cpu()

    # Get top k predictions
    top_pred = torch.topk(probs_per_image, top_k, dim=1)
    top_pred_gps = self.gps_gallery[top_pred.indices[0]]
    top_pred_prob = top_pred.values[0]

    return top_pred_gps, top_pred_prob

Modified Implementation:

class GeoCLIP(nn.Module):
    def __init__(self, from_pretrained=True, queue_size=4096):
        super().__init__()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.image_encoder = ImageEncoder()
        self.location_encoder = LocationEncoder()

        self.gps_gallery = load_gps_data(os.path.join(file_dir, "gps_gallery", "coordinates_100K.csv"))
        self._initialize_gps_queue(queue_size)

        if from_pretrained:
            self.weights_folder = os.path.join(file_dir, "weights")
            self._load_weights()

        self.device = "cpu"
        self.tensors_gps_gallery = self.gps_gallery.to(self.device)
        self.location_features = self.location_encoder(self.tensors_gps_gallery)

    def forward(self, image, location):
        # Compute Features
        image_features = self.image_encoder(image)
        # Use the cached location features; the `location` argument is kept only
        # for API compatibility and is ignored here.
        location_features = self.location_features
        logit_scale = self.logit_scale.exp()

        # Normalize features
        image_features = F.normalize(image_features, dim=1)
        location_features = F.normalize(location_features, dim=1)

        # Cosine similarity (Image Features & Location Features)
        logits_per_image = logit_scale * (image_features @ location_features.t())

        return logits_per_image

    @torch.no_grad()
    def predict(self, image_path, top_k):
        image = Image.open(image_path)
        image = self.image_encoder.preprocess_image(image)
        image = image.to(self.device)

        logits_per_image = self.forward(image, self.tensors_gps_gallery)
        probs_per_image = logits_per_image.softmax(dim=-1).cpu()

        # Get top k predictions
        top_pred = torch.topk(probs_per_image, top_k, dim=1)
        top_pred_gps = self.gps_gallery[top_pred.indices[0]]
        top_pred_prob = top_pred.values[0]

        return top_pred_gps, top_pred_prob

Benefits

  1. Reduced Redundant Computation: By precomputing the location_features during initialization, the modified code reduces redundant computations in every forward pass.
  2. Avoiding Redundant Data Transfer: The modified code avoids repeatedly transferring the gps_gallery to the device, which can further improve performance, especially if the gps_gallery is large.

Potential Issues

  • This change assumes the gps_gallery remains constant. If the gps_gallery changes dynamically, additional handling will be needed to refresh the cached location_features (see the sketch below).
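
If a dynamically changing gallery does need to be supported, one option would be an explicit refresh helper along these lines (a rough sketch against the modified class above; the method name refresh_location_cache is my own and not part of GeoCLIP):

    @torch.no_grad()
    def refresh_location_cache(self, new_gps_gallery=None):
        # Hypothetical helper: recompute the cached location features,
        # e.g. after the GPS gallery has been replaced or extended.
        if new_gps_gallery is not None:
            self.gps_gallery = new_gps_gallery
        self.tensors_gps_gallery = self.gps_gallery.to(self.device)
        self.location_features = self.location_encoder(self.tensors_gps_gallery)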

Conclusion

I believe this change can significantly improve the model's inference speed. I'm happy to discuss this further or assist with implementation and testing.

Thank you for your consideration!

Ten Crop benchmark

@VicenteVivan It is mentioned in the paper that a ten-crop method is used for evaluation, where you average your predictions over the 10 cropped images. How do you perform the averaging?

For example, you could average the predicted GPS coordinates, or you could average the embeddings (or probabilities) before evaluating. The two approaches can give very different results. Thanks for any answer ^^
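
For reference, this is roughly how I picture the two options (a sketch of my own, not the authors' code; crop_feats, loc_feats and scale are assumed inputs, not variables from the repository):

import torch.nn.functional as F

# Assumed inputs:
#   crop_feats: (10, D) normalized image embeddings of the ten crops
#   loc_feats:  (N, D)  normalized GPS-gallery embeddings
#   scale:      scalar, logit_scale.exp()

# (a) Average the crop embeddings first, then compare once against the gallery:
mean_feat = F.normalize(crop_feats.mean(dim=0, keepdim=True), dim=1)
probs_embedding_avg = (scale * mean_feat @ loc_feats.t()).softmax(dim=-1)

# (b) Compare each crop separately and average the resulting probabilities:
probs_probability_avg = (scale * crop_feats @ loc_feats.t()).softmax(dim=-1).mean(dim=0)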

Some questions

Dear author,
I appreciate your work and would like to learn the details of GeoCLIP's training. Can you publish the complete training code? Thank you.

Also, in the GeoCLIP code, self.opt is not defined in __init__. How can this be fixed?

    def _dequeue_and_enqueue(self, gps):
        """ Update GPS queue

    Args:
        gps (torch.Tensor): GPS tensor of shape (batch_size, 2)
    """
        opt = self.opt
        gps_batch_size = gps.shape[0]
        batch_size = opt.batch_size

        gps_ptr = int(self.gps_queue_ptr)
        assert self.queue_size % batch_size == 0

        # Replace the GPS from ptr to ptr+batch_size (dequeue and enqueue)
        self.gps_queue[:, gps_ptr:gps_ptr + gps_batch_size] = gps.t()
        gps_ptr = (gps_ptr + batch_size) % self.queue_size  # move pointer
        self.gps_queue_ptr[0] = gps_ptr
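
One possible workaround (my own sketch, not an official fix) is to drop the dependency on self.opt and use the batch size of the incoming tensor directly:

    def _dequeue_and_enqueue(self, gps):
        """ Update GPS queue without relying on self.opt (sketch).

        Args:
            gps (torch.Tensor): GPS tensor of shape (batch_size, 2)
        """
        gps_batch_size = gps.shape[0]  # use the actual batch size of the incoming tensor
        gps_ptr = int(self.gps_queue_ptr)
        assert self.queue_size % gps_batch_size == 0  # queue size must be divisible by batch size

        # Replace the GPS entries from ptr to ptr + batch_size (dequeue and enqueue)
        self.gps_queue[:, gps_ptr:gps_ptr + gps_batch_size] = gps.t()
        gps_ptr = (gps_ptr + gps_batch_size) % self.queue_size  # move pointer
        self.gps_queue_ptr[0] = gps_ptr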

In the following code, should the self.gps in self._dequeue_and_enqueue(self.gps) be changed to self.gps_queue?

    def append_gps_queue_features(self, gps_features):
        """ Compute the GPS queue features and append them to the given GPS features."""
        # Get the GPS queue features
        location_queue = self.gps_queue.t().detach()
        gps_queue_features = self.location_encoder(location_queue)
        gps_queue_features = F.normalize(gps_queue_features, dim=1)

        # Concatenate Features (GPS Features & GPS Queue Features)
        gps_features = torch.cat([gps_features, gps_queue_features], dim=0)

        # Update GPS queue
        self._dequeue_and_enqueue(self.gps) 

        return gps_features
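
One way to resolve this (again, my own sketch and not necessarily what the authors intended) is to pass the current batch's GPS coordinates in explicitly rather than reading an undefined self.gps:

    def append_gps_queue_features(self, gps_features, gps):
        """ Same as above, but takes the batch's GPS coordinates explicitly (sketch)."""
        location_queue = self.gps_queue.t().detach()
        gps_queue_features = F.normalize(self.location_encoder(location_queue), dim=1)
        gps_features = torch.cat([gps_features, gps_queue_features], dim=0)

        # Update the GPS queue with the coordinates of the current batch
        self._dequeue_and_enqueue(gps)
        return gps_features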

Looking forward to your response, thank you.

Request for detailed GEO-CLIP training code

Thank you very much for the excellent work you are doing! I want to try to train it myself, but I find that there is only a simple training-loop Python snippet on your GitHub. Could you please release your detailed and complete training code? Thank you very much!

Questions regarding the Loss function

Dear authors,

I am currently working on reproducing the results from your paper. It doesn't seem like you have included any code for the implementation of your loss function, so I have some questions on the matter.

[attached image: the loss function from the paper]

From my understanding of the loss, you have modified it to account for the dynamic queue (additional GPS embeddings).
$P$ - corresponds to the number of different views of an image in a given batch; let's take it to be 1 view for simplicity.
$V$ - is the embedded image.
$L$ - is the embedded GPS coordinate.

This simplifies the Loss for a single view of a single image in a batch to the following:

$$L_i = - \log \frac{ \exp(V_i \cdot L_i / \tau)}{\sum_{i = 0} \exp(V_i \cdot L_i / \tau) + \sum_{i = 0} \exp(V_i \cdot \tilde{L}_i / \tau)}$$

where, in the denominator, the first sum runs over the batch of length $B$ and the second sum runs over the dynamic queue of length $S$.

My questions are the following:

    1. It seems like you are using the same index $i$ for the $i^{th}$ sample of the batch, the sum over the batch, and the sum over the dynamic queue. Did you mean to take something like the loss below, with the index $i$ changed to $k$ in the denominator? (A rough sketch of this version is included after these questions.)

$$L_i = - \log \frac{ \exp(V_i \cdot L_i / \tau)}{\sum_{k = 0} \exp(V_i \cdot L_k / \tau) + \sum_{k = 0} \exp(V_i \cdot \tilde{L}_k / \tau)}$$

By doing so, you do contrastive learning of each image over all other coordinates while keeping the same image $V_i$ in the denominator.

    2. If it is true that you do contrastive learning of each image over all other coordinates, why did you decide not to do contrastive learning of each GPS coordinate over all other images? In the original CLIP paper, the cross-entropy loss is applied both horizontally and vertically, yet you have chosen to use it only horizontally. Is there a specific reason for this decision?
    3. Going back to the $P$ augmented views: you mention in your paper that a benefit of using a frozen CLIP backbone is that one can pre-encode all images, making the training process faster. Yet if you perform $P$ augmentations of each image in each batch, don't you have to re-encode the augmented images, and therefore lose this benefit?
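
For concreteness, here is how I would implement the version with index $k$ in the denominator (a rough PyTorch sketch of my own interpretation, not the authors' code):

import torch
import torch.nn.functional as F

def geoclip_style_loss(image_feats, gps_feats, queue_gps_feats, tau=0.07):
    # image_feats:     (B, D) normalized image embeddings V
    # gps_feats:       (B, D) normalized GPS embeddings L (positives on the diagonal)
    # queue_gps_feats: (S, D) normalized GPS embeddings ~L from the dynamic queue (extra negatives)
    logits_batch = image_feats @ gps_feats.t() / tau         # (B, B)
    logits_queue = image_feats @ queue_gps_feats.t() / tau   # (B, S)
    logits = torch.cat([logits_batch, logits_queue], dim=1)  # (B, B + S)
    targets = torch.arange(image_feats.size(0), device=image_feats.device)
    # Row-wise cross-entropy = contrastive loss of each image against all coordinates
    return F.cross_entropy(logits, targets)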

I look forward to hearing from you! Thanks.

Clarification of dataset splits

Thanks to authors for sharing their code.

Can the dataset splits be clarified? That is, can you provide the dataset(s) used for training and validation, as well as the specific split if training / validation are coming from the same dataset?

It is unclear what specifically comprises the validation dataset from the paper (e.g., whether it is a split of MP16, etc). This detail is important to ensure recreation of experiments and fair comparisons of future work.

Thanks!

Higher Resolution for GPS Coordinates

Thanks for the great work!

Did I understand correctly that the sigma parameter controls the resolution of the frequencies, and that if you need a higher resolution for the GPS coordinates you have to increase it? You use [2**0, 2**4, 2**8] for a resolution of up to one km; what about metre-level resolution?
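
For context, this is how I currently picture the role of sigma (a rough sketch of my own, assuming a Gaussian random Fourier feature encoding; not the actual GeoCLIP code):

import math
import torch

def random_fourier_features(coords, sigma, num_features=256):
    # coords: (N, 2) projected GPS coordinates; sigma scales the sampled frequencies,
    # so a larger sigma produces higher-frequency features, i.e. finer spatial resolution.
    freqs = torch.randn(coords.shape[1], num_features) * sigma
    proj = 2 * math.pi * coords @ freqs
    return torch.cat([torch.sin(proj), torch.cos(proj)], dim=-1)

# Hierarchical encoding: one feature set per sigma, concatenated.
# sigmas = [2**0, 2**4, 2**8]; metre-level detail would presumably require adding a larger sigma.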

Thank you very much

When will the testing code be released?

Dear author,
Great work! I have recently been developing a new model to solve the same task and would like to have your TESTING code to test the performance of my model.

I would appreciate your generosity. Looking forward to your response, thank you.

Loss function used for training

Hello, a partner and I are interested in fine-tuning the GeoCLIP model; however, we are unsure about the implementation of the loss function. Could you share the loss function you used, or give any tips for implementing it?
