Comments (3)

xeroxM commented on May 27, 2024

Hey Guys,

I found a solution for this. It turns out that the key names in the trained Places365 checkpoint just don't match the key names of the model downloaded from torchvision. I'm not the biggest Python pro, so I came up with a pretty simple and barbaric solution (feel free to refactor it!).

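    # arch and model_file are the same variables used in the snippet quoted from RiSaMa.
    model = models.__dict__[arch](num_classes=365)
    # Load the checkpoint onto the CPU, wherever it was saved from.
    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)

    # Drop the 'module.' prefix, then collapse 'norm.' / 'conv.' into 'norm' / 'conv'.
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['state_dict'].items()}
    state_dict = {k.replace('norm.', 'norm'): v for k, v in state_dict.items()}
    state_dict = {k.replace('conv.', 'conv'): v for k, v in state_dict.items()}
    # The two steps above also glue together names that must keep their dot, so restore those.
    state_dict = {k.replace('normweight', 'norm.weight'): v for k, v in state_dict.items()}
    state_dict = {k.replace('normrunning', 'norm.running'): v for k, v in state_dict.items()}
    state_dict = {k.replace('normbias', 'norm.bias'): v for k, v in state_dict.items()}
    state_dict = {k.replace('convweight', 'conv.weight'): v for k, v in state_dict.items()}

    # With the names aligned, load with the default strict check.
    model.load_state_dict(state_dict)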

I am basically just replacing the key names of the Places365 model with the key names of the model downloaded from the torchvision.models package. I tested the model, and for me it works :)
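If someone wants to take up the refactoring offer, the chain of replacements above could probably be collapsed into a single regular expression. The following is only a sketch: the helper names `rename_key` and `rename_state_dict` are made up for illustration, and the assumed old key layout (`norm.1`, `conv.2`, and so on) is inferred from the replacements in this comment, not checked against the actual checkpoint.

```python
import re

# Assumed legacy pattern: 'norm.1' / 'conv.2' where the current model expects 'norm1' / 'conv2'.
# Keys such as 'transition1.norm.weight' have no digit after the dot and are left alone.
_LEGACY = re.compile(r"\b(norm|conv)\.(\d+)\b")


def rename_key(key):
    """Map one checkpoint key to the naming assumed for the current model (hypothetical helper)."""
    # Drop a leading 'module.' prefix if present.
    if key.startswith("module."):
        key = key[len("module."):]
    # Join 'norm.1' -> 'norm1' and 'conv.2' -> 'conv2'.
    return _LEGACY.sub(r"\1\2", key)


def rename_state_dict(state_dict):
    """Return a new dict with every key renamed; values are passed through untouched."""
    return {rename_key(k): v for k, v in state_dict.items()}


if __name__ == "__main__":
    old_keys = [
        "module.features.denseblock1.denselayer1.norm.1.weight",
        "module.features.transition1.norm.weight",
        "module.classifier.bias",
    ]
    for k in old_keys:
        print(k, "->", rename_key(k))
```

With a real checkpoint you would pass `checkpoint['state_dict']` to `rename_state_dict` and hand the result to `model.load_state_dict`. Because the digit is required after the dot, the separate fix-up steps for `normweight`, `normbias` and friends are not needed here.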

@RiSaMa strict=False ignores all keys that don't match. So the model keeps its initial weights wherever no matching key is found in the new state_dict (in our case I think that's all of them). Since you are most likely loading an untrained model from the torchvision.models package, this results in random performance.
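To see how many keys strict=False is silently skipping, it can help to compare the two sets of names before loading. This is a minimal sketch in plain Python; `diff_keys` is a hypothetical helper, and the example key names are illustrative assumptions, not taken from the real checkpoint.

```python
def diff_keys(model_keys, checkpoint_keys):
    """Compare parameter names of a model and a checkpoint (hypothetical helper)."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    return {
        # Names present on both sides: these weights would actually be loaded.
        "matched": sorted(model_keys & checkpoint_keys),
        # Names the model expects but the checkpoint lacks: these keep their initial values.
        "missing": sorted(model_keys - checkpoint_keys),
        # Names only in the checkpoint: these are dropped when strict=False.
        "unexpected": sorted(checkpoint_keys - model_keys),
    }


if __name__ == "__main__":
    # Illustrative names only.
    model_keys = ["features.denselayer1.norm1.weight", "classifier.bias"]
    checkpoint_keys = ["features.denselayer1.norm.1.weight", "classifier.bias"]
    report = diff_keys(model_keys, checkpoint_keys)
    for name, keys in report.items():
        print(name, len(keys), keys)
```

With a real model you would call it as `diff_keys(model.state_dict().keys(), state_dict.keys())`. A long "missing" list means most layers are still at their initial weights, which would explain near-random predictions.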

from places365.

zhoubolei commented on May 27, 2024

With the upgrade to PyTorch 0.4, the DenseNet model structure was also updated, so the old model cannot be loaded properly. I don't have a solution for that yet.

RiSaMa commented on May 27, 2024

Same problem here... any solution?

I have been able to load the model as follows:

    model = models.__dict__[arch](num_classes=365)
    checkpoint = torch.load(model_file, map_location=lambda storage, loc: storage)
    state_dict = {str.replace(k,'module.',''): v for k,v in checkpoint['state_dict'].items()}
    model.load_state_dict(state_dict, strict=False)  # NOTICE THE strict=False
    model.cuda()
    model.eval()

But the performance is almost random...

Any idea?
