

mseitzer commented on August 20, 2024

Hi,

thank you for spotting this. This is indeed a bug.

I initially copied the normalization code from the PyTorch implementation (https://github.com/pytorch/vision/blob/f87a896f170a34502a60cdc358c19d8b55e72f54/torchvision/models/inception.py#L73-L76), naively assuming that it normalizes the input to the format the pretrained network expects.

But looking into it more, what this really does is transform from inputs that have been normalized using

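import torchvision.transforms as transforms

mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# ToTensor scales uint8 pixels to [0, 1]; Normalize then applies (x - mean) / std per channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)])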

to inputs in range [-1, 1], which is what TensorFlow's pretrained Inception network expects. In other words, it undoes the zero-mean, unit-variance normalization and rescales the result to [-1, 1].
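The constants cancel channel by channel. Let x_c in [0, 1] be the ToTensor output for channel c, and let mu_c and sigma_c be the ImageNet mean and std above. torchvision's rescaling in the inception.py lines linked above multiplies each channel by sigma_c / 0.5 and adds (mu_c - 0.5) / 0.5. Composed with Normalize, this gives:

```latex
\tilde{x}_c = \frac{x_c - \mu_c}{\sigma_c}, \qquad
\hat{x}_c = \tilde{x}_c \cdot \frac{\sigma_c}{0.5} + \frac{\mu_c - 0.5}{0.5}
          = \frac{x_c - \mu_c}{0.5} + \frac{\mu_c - 0.5}{0.5}
          = \frac{x_c - 0.5}{0.5}
          = 2x_c - 1 \in [-1, 1]
```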

I will submit a fix for this shortly. In the meantime, you can scale your images from range [0, 1] to range [-1, 1] using 2 * x - 1 and run with transform_input=False.

Best.
Max

from pytorch-fid.

XavierXiao commented on August 20, 2024

I still have a question about normalization. As you said, TensorFlow's pretrained Inception network expects input in [-1, 1], but if you load the pretrained PyTorch models, it seems like they expect inputs that are in [0, 1] and then normalized in a specific way. Quote from the official PyTorch site:

All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225].

So if you are using the pretrained PyTorch model, you should normalize the image this way rather than to [-1, 1], right?


mseitzer commented on August 20, 2024

As I said above, Inception expects inputs to be in [-1, 1]. This is true for PyTorch's Inception as well, since they just ported TensorFlow's weights to PyTorch. Internally, the PyTorch implementation transforms the input from mean/std-normalized images to [-1, 1]. This way, all torchvision models take the same input format, namely mean/std-normalized images, as you quoted above.
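The claim can be checked numerically with a NumPy-only sketch, so no torch is required. It emulates torchvision's Normalize and the per-channel rescaling that `transform_input=True` applies inside Inception, using the constants from the thread above. It runs this on a random batch standing in for [0, 1] images and compares the result with applying 2 * x - 1 directly. The function names are hypothetical; only the constants come from torchvision.

```python
import numpy as np

# ImageNet statistics used by torchvision's Normalize, shaped for NCHW broadcasting.
MEAN = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)


def imagenet_normalize(x):
    """Emulates transforms.Normalize on an NCHW batch already in [0, 1]."""
    return (x - MEAN) / STD


def inception_transform_input(x):
    """NumPy version of torchvision Inception's per-channel rescaling
    (transform_input=True): x * (std / 0.5) + (mean - 0.5) / 0.5."""
    return x * (STD / 0.5) + (MEAN - 0.5) / 0.5


def fid_style_input(x):
    """Scales [0, 1] directly to [-1, 1], as done inside pytorch-fid's model."""
    return 2 * x - 1


rng = np.random.default_rng(0)
batch = rng.random((4, 3, 8, 8))  # stand-in for images in [0, 1]

torchvision_path = inception_transform_input(imagenet_normalize(batch))
direct_path = fid_style_input(batch)

print(np.allclose(torchvision_path, direct_path))  # True
```

Both paths feed the network the same tensor. The mean/std step in torchvision exists only so that Inception shares the input convention of the other torchvision models.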


AtlantixJJ commented on August 20, 2024

Hi, I noticed you use a [0, 1] image scale in the code:

        images = np.array([imread(str(f)).astype(np.float32)
                           for f in files[start:end]])

        # Reshape to (n_images, 3, height, width)
        images = images.transpose((0, 3, 1, 2))
        images /= 255

while TF's version uses a [-1, 1] image scale. Is this a bug?


mseitzer commented on August 20, 2024

Scaling to range [-1, 1] happens inside the model:

x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)


GiangHLe commented on August 20, 2024

@mseitzer Hi, I know it's late, but I'm trying to implement a more flexible FID version based on your code. I know that you scale the input from [0, 1] to [-1, 1], but as far as I know, the Inception model expects input images normalized with the ImageNet statistics, mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]. Therefore, wouldn't it be better to modify the transform applied when we read the data? Am I right?
