
Comments (22)

poppinace avatar poppinace commented on September 8, 2024 1

@983 I work at home now. When I return to my office, I'll share you with some pieces of code on how I crop images.


poppinace avatar poppinace commented on September 8, 2024 1

@983 Here is the code I use to randomly crop images:

import random

import numpy as np

# Note: resize_image_alpha is assumed to be a helper defined elsewhere in the
# training code that resizes both the image and the alpha matte to (nh, nw).

class RandomCrop(object):
    """Randomly crop the image.

    Args:
        output_size (int): Desired output size. If int, a square crop
            is made.
        scales (list): Desired scales.
    """

    def __init__(self, output_size, scales):
        assert isinstance(output_size, int)
        self.output_size = output_size
        self.scales = scales

    def __call__(self, sample):
        image, alpha = sample['image'], sample['alpha']
        h, w = image.shape[:2]

        # Upscale first if the image is smaller than the desired output size.
        if min(h, w) < self.output_size:
            s = (self.output_size + 180) / min(h, w)
            nh, nw = int(np.floor(h * s)), int(np.floor(w * s))
            image, alpha = resize_image_alpha(image, alpha, nh, nw)
            h, w = image.shape[:2]

        # Pick a random crop size among the allowed scales that still fits the image.
        crop_size = np.floor(self.output_size * np.array(self.scales)).astype('int')
        crop_size = crop_size[crop_size < min(h, w)]
        crop_size = int(random.choice(crop_size))

        # Center the crop on a pixel marked as unknown (value 128) in the
        # 4th channel of `image` (the trimap).
        c = int(np.ceil(crop_size / 2))
        mask = np.equal(image[:, :, 3], 128).astype(np.uint8)
        if mask[c:h-c+1, c:w-c+1].sum() != 0:
            # There are unknown pixels far enough from the border:
            # the crop fits entirely inside the image.
            mask_center = np.zeros((h, w), dtype=np.uint8)
            mask_center[c:h-c+1, c:w-c+1] = 1
            mask = (mask & mask_center)
            idh, idw = np.where(mask == 1)
            ids = random.choice(range(len(idh)))
            hc, wc = idh[ids], idw[ids]
            h1, w1 = hc-c, wc-c
        else:
            # All unknown pixels are near the border: clamp the crop so it
            # stays inside the image.
            idh, idw = np.where(mask == 1)
            ids = random.choice(range(len(idh)))
            hc, wc = idh[ids], idw[ids]
            h1, w1 = np.clip(hc-c, 0, h), np.clip(wc-c, 0, w)
            h2, w2 = h1+crop_size, w1+crop_size
            h1 = h-crop_size if h2 > h else h1
            w1 = w-crop_size if w2 > w else w1

        image = image[h1:h1+crop_size, w1:w1+crop_size, :]
        alpha = alpha[h1:h1+crop_size, w1:w1+crop_size, :]

        # Resize back to the desired output size if a larger scale was chosen.
        if crop_size != self.output_size:
            nh = nw = self.output_size
            image, alpha = resize_image_alpha(image, alpha, nh, nw)

        return {'image': image, 'alpha': alpha}
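
A minimal usage sketch under my reading of the snippet (the trimap appears to be packed as the fourth channel of image, since the centering logic checks image[:, :, 3] == 128; with these toy sizes, resize_image_alpha is never actually called):

import numpy as np

# Toy 4-channel input: RGB + trimap (0 = background, 128 = unknown, 255 = foreground).
image = np.zeros((480, 640, 4), dtype=np.uint8)
image[:, :, 3] = 128                              # mark everything unknown for this toy example
alpha = np.zeros((480, 640, 1), dtype=np.uint8)

crop = RandomCrop(output_size=320, scales=[1, 1.5, 2])
sample = crop({'image': image, 'alpha': alpha})
print(sample['image'].shape, sample['alpha'].shape)  # (320, 320, 4) (320, 320, 1)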


poppinace avatar poppinace commented on September 8, 2024

Hi, your training details look good to me.
Just a reminder that the coordinate pairs should have alpha in [0,1], rather than (0,1).


983 avatar 983 commented on September 8, 2024

I know that your boss said that you cannot release the training code, but can you maybe release an example of the cropped images, trimaps, and alphas in a training batch? Maybe I can find a visual difference.

For example, this is a batch from my training code where I tried cropping without resizing and discarded crops whose mean alpha is below 0.2 or above 0.8 to better focus on the unknown region (a small sketch of that filtering is below):

[batch preview with columns: image | trimap | mask of unknown region for loss function | ground truth alpha | predicted alpha]
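
For reference, a minimal sketch of that mean-alpha filtering (the function name and the 0.2/0.8 thresholds are my own choices, not code from the repository):

import numpy as np

def accept_crop(alpha_crop, low=0.2, high=0.8):
    # Keep a crop only if its mean alpha lies in [low, high], i.e. the crop
    # is not dominated by pure foreground or pure background.
    mean_alpha = alpha_crop.astype(np.float32).mean() / 255.0  # assumes uint8 alpha
    return low <= mean_alpha <= high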

the coordinate pairs should have alpha in [0,1], rather than (0,1).

I don't understand how that would help because all alpha values are in [0, 1].

Xu et al. (Deep Image Matting) say:

First, we randomly crop 320×320 (image, trimap) pairs centered on pixels in the unknown regions.

As I understand it, this means that they choose a random rectangle in the ground truth alpha matte and discard it if the center pixel is known, or in other words, if the center pixel is 100% foreground or 100% background.

Maybe my English is not good, so here is some code to explain:

from PIL import Image
import numpy as np

def crop_centered(alpha):
    while True:
        # pick random rectangle with corner (x, y) and size 320x320
        x = np.random.randint(alpha.shape[1] - 320)
        y = np.random.randint(alpha.shape[0] - 320)

        cropped = alpha[y:y+320, x:x+320]

        center_pixel = cropped[160, 160]

        # found good rectangle if the center pixel is unknown
        if center_pixel != 0 and center_pixel != 255:
            return cropped

Image.fromarray(crop_centered(np.array(Image.open("GT04.png").convert("L")))).show()


983 avatar 983 commented on September 8, 2024

Thank you very much for the training code, I'll update here once it is finished.

EDIT:

My results are:

| SAD | MSE | Grad | Conn |
| --- | --- | --- | --- |
| 49.91 | 0.0155 | 31.07 | 49.49 |

The model_ckpt.pth.tar file size is 71 821 666 bytes, but the size of the pretrained model indexnet_matting.pth.tar is only 24 085 481 bytes, so I guess the training configuration is different? Do you still have the original somewhere?


yucornetto avatar yucornetto commented on September 8, 2024

Thanks for the great work! I have also tried using the provided training code to reproduce the results. The only thing I changed was num_workers (from 4 to 16) to speed up the training. I also get a similar result, with SAD = 49.26 and MSE = 0.0143. The results are good compared to DIM, yet there is still a significant margin to the provided model (SAD = 45.8 and MSE = 0.013). I wonder if you have any clue about what leads to the difference?



poppinace avatar poppinace commented on September 8, 2024

Hi all @983 @yucornetto,
Unfortunately, this is exactly the training setting I use.
How many GPUs are used in training? I have tried training with multiple GPUs and sync_bn, but the results are worse than training with a single GPU and standard bn.
Another difference I can think of may be the random seed. Can you guys try retraining the model with a different seed?
I also report the performance using the official MATLAB implementation. It is slightly better than the Python evaluation code I implemented.
The reason why model_ckpt.pth.tar is larger than indexnet_matting.pth.tar is that some other intermediate variables are saved, such as the optimizer state. I only saved the state_dict in indexnet_matting.pth.tar.
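
For illustration, a minimal sketch of the difference (the model and checkpoint keys here are placeholders, not necessarily those used in this repository):

import torch

# A toy model and optimizer, only to illustrate file contents.
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters())

# Full training checkpoint: weights plus optimizer state (and other metadata),
# which is why model_ckpt.pth.tar is considerably larger.
torch.save({
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'epoch': 30,
}, 'model_ckpt.pth.tar')

# Weights only, as in indexnet_matting.pth.tar.
torch.save(model.state_dict(), 'indexnet_matting.pth.tar')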


983 avatar 983 commented on September 8, 2024

I also increased the number of workers (to 8), but made no changes otherwise. Maybe that makes a difference? It really shouldn't, but who knows. I'll try 4 this time.


poppinace avatar poppinace commented on September 8, 2024

@983 I don't think that the number of workers is an issue.

But I suffered from a problem: when I terminate the training halfway and resume from the checkpoint, the final results are always worse than training without stopping. This suggests that how the images are sampled affects the performance. I have stuck with this sampling strategy just to match what is used in Deep Image Matting for a fair comparison, but I think there must exist a better way to do data augmentation reliably (e.g., cropping 512x512 instead of 320x320).

Hope my experience helps.


983 avatar 983 commented on September 8, 2024

I also think that better data augmentation could improve results, but training takes a really long time, so it is hard to evaluate what works and what doesn't.

It might be interesting to train a smaller model on smaller images and evaluate to what extent the findings transfer to larger models. For example, Marco Forte et al. (FBA matting) recently found that a batch size of 1 works really well, but training took weeks, so it is hard to isolate the exact reason why it works. If the model were faster to train, it would be much faster to run experiments.


yucornetto avatar yucornetto commented on September 8, 2024

@poppinace I trained the model with a single GPU, without stopping and resuming, as suggested. Thanks for the advice, I will try modifying the sampling strategy to see if it helps.


poppinace avatar poppinace commented on September 8, 2024

@983 I know that paper. I reserve my opinion about the 1-batch strategy, because the paper does not report performance when bs >= 16. It is unfair to compare small batch sizes against 1-batch instance norm.

I agree that you should find a proxy task to validate your ideas. I saw some papers use a resized dataset so that the whole dataset can be loaded into memory to speed up training. We also composite each fg with only 2 or 3 bgs to construct a small dataset. The key is that the small dataset should be representative enough to serve as a replacement for the full dataset. You can think about it.


983 avatar 983 commented on September 8, 2024

Here are the results from the latest run. The SAD after 30 epochs is slightly worse than before (50.59), and after 20 epochs the error does not decrease much more. However, there are some better values in between, like 48.40 and 48.69, so the gap to 45.8 is quite small now; maybe it is just luck.

epoch: 1, test: 1000/1000, sad: 19.52, SAD: 80.48, MSE: 0.0405, Grad: 60.09, Conn: 83.77, frame: 0.34Hz/0.42Hz
epoch: 2, test: 1000/1000, sad: 17.35, SAD: 71.73, MSE: 0.0336, Grad: 52.21, Conn: 73.62, frame: 0.38Hz/0.43Hz
epoch: 3, test: 1000/1000, sad: 17.23, SAD: 69.67, MSE: 0.0315, Grad: 50.41, Conn: 71.07, frame: 0.38Hz/0.42Hz
epoch: 4, test: 1000/1000, sad: 15.57, SAD: 63.45, MSE: 0.0267, Grad: 46.24, Conn: 65.33, frame: 0.38Hz/0.43Hz
epoch: 5, test: 1000/1000, sad: 13.08, SAD: 56.47, MSE: 0.0229, Grad: 41.28, Conn: 57.25, frame: 0.36Hz/0.42Hz
epoch: 6, test: 1000/1000, sad: 13.03, SAD: 56.34, MSE: 0.0219, Grad: 39.06, Conn: 57.05, frame: 0.35Hz/0.43Hz
epoch: 7, test: 1000/1000, sad: 14.22, SAD: 55.98, MSE: 0.0208, Grad: 36.05, Conn: 55.41, frame: 0.38Hz/0.43Hz
epoch: 8, test: 1000/1000, sad: 13.33, SAD: 60.12, MSE: 0.0211, Grad: 38.11, Conn: 59.12, frame: 0.38Hz/0.42Hz
epoch: 9, test: 1000/1000, sad: 12.97, SAD: 51.39, MSE: 0.0187, Grad: 34.93, Conn: 50.75, frame: 0.38Hz/0.42Hz
epoch: 10, test: 1000/1000, sad: 13.06, SAD: 51.57, MSE: 0.0190, Grad: 30.95, Conn: 51.59, frame: 0.38Hz/0.42Hz
epoch: 11, test: 1000/1000, sad: 11.23, SAD: 52.72, MSE: 0.0187, Grad: 34.69, Conn: 52.89, frame: 0.37Hz/0.42Hz
epoch: 12, test: 1000/1000, sad: 10.77, SAD: 54.39, MSE: 0.0193, Grad: 36.05, Conn: 54.63, frame: 0.38Hz/0.41Hz
epoch: 13, test: 1000/1000, sad: 10.94, SAD: 50.74, MSE: 0.0179, Grad: 34.42, Conn: 50.85, frame: 0.38Hz/0.41Hz
epoch: 14, test: 1000/1000, sad: 10.47, SAD: 54.52, MSE: 0.0185, Grad: 41.60, Conn: 54.90, frame: 0.38Hz/0.42Hz
epoch: 15, test: 1000/1000, sad: 10.97, SAD: 54.40, MSE: 0.0182, Grad: 39.88, Conn: 54.42, frame: 0.39Hz/0.42Hz
epoch: 16, test: 1000/1000, sad: 12.35, SAD: 50.06, MSE: 0.0177, Grad: 30.85, Conn: 48.74, frame: 0.35Hz/0.42Hz
epoch: 17, test: 1000/1000, sad: 11.05, SAD: 54.01, MSE: 0.0180, Grad: 35.54, Conn: 53.58, frame: 0.35Hz/0.42Hz
epoch: 18, test: 1000/1000, sad: 9.95, SAD: 56.45, MSE: 0.0194, Grad: 39.32, Conn: 57.01, frame: 0.37Hz/0.42Hz
epoch: 19, test: 1000/1000, sad: 9.36, SAD: 48.69, MSE: 0.0166, Grad: 31.67, Conn: 48.02, frame: 0.37Hz/0.42Hz
epoch: 20, test: 1000/1000, sad: 9.34, SAD: 49.63, MSE: 0.0162, Grad: 31.99, Conn: 48.89, frame: 0.38Hz/0.42Hz
epoch: 21, test: 1000/1000, sad: 9.14, SAD: 50.50, MSE: 0.0167, Grad: 36.00, Conn: 50.08, frame: 0.37Hz/0.42Hz
epoch: 22, test: 1000/1000, sad: 9.33, SAD: 50.74, MSE: 0.0166, Grad: 35.40, Conn: 50.39, frame: 0.37Hz/0.42Hz
epoch: 23, test: 1000/1000, sad: 9.02, SAD: 51.57, MSE: 0.0170, Grad: 35.14, Conn: 51.21, frame: 0.37Hz/0.42Hz
epoch: 24, test: 1000/1000, sad: 9.19, SAD: 50.63, MSE: 0.0164, Grad: 34.33, Conn: 50.44, frame: 0.37Hz/0.42Hz
epoch: 25, test: 1000/1000, sad: 9.02, SAD: 49.01, MSE: 0.0163, Grad: 32.39, Conn: 48.51, frame: 0.37Hz/0.42Hz
epoch: 26, test: 1000/1000, sad: 9.12, SAD: 48.53, MSE: 0.0157, Grad: 32.38, Conn: 47.81, frame: 0.37Hz/0.42Hz
epoch: 27, test: 1000/1000, sad: 9.23, SAD: 48.40, MSE: 0.0159, Grad: 31.56, Conn: 47.59, frame: 0.35Hz/0.42Hz
epoch: 28, test: 1000/1000, sad: 9.24, SAD: 49.95, MSE: 0.0163, Grad: 34.01, Conn: 49.49, frame: 0.38Hz/0.42Hz
epoch: 29, test: 1000/1000, sad: 9.16, SAD: 49.65, MSE: 0.0162, Grad: 33.64, Conn: 49.25, frame: 0.37Hz/0.42Hz
epoch: 30, test: 1000/1000, sad: 9.16, SAD: 50.59, MSE: 0.0167, Grad: 33.59, Conn: 50.25, frame: 0.34Hz/0.42Hz

I saw some papers use resized dataset such that the whole dataset can be loaded into the memory to speed up training.

I think most of the training cost is decoding the PNG images. It is probably fine to store them as BMP instead, since natural images don't compress well anyway.
My own training code generated training data on the fly from the Adobe dataset and ran in one day instead of three, but the server has lots of RAM, so I can cache all images in it.
A fast SSD is probably cheaper and almost as good.
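
A minimal sketch of the on-the-fly composition I mean (the array layout and names are my assumptions, not the repository's loader):

import numpy as np

def composite(fg, alpha, bg):
    # Standard matting composition I = alpha * F + (1 - alpha) * B,
    # with uint8 alpha scaled to [0, 1] and broadcast over the color channels.
    a = alpha.astype(np.float32)[..., None] / 255.0
    out = a * fg.astype(np.float32) + (1.0 - a) * bg.astype(np.float32)
    return out.astype(np.uint8)

# Toy example: a flat foreground over a flat background with half-transparent alpha.
fg = np.full((4, 4, 3), 200, dtype=np.uint8)
bg = np.full((4, 4, 3), 50, dtype=np.uint8)
alpha = np.full((4, 4), 128, dtype=np.uint8)
print(composite(fg, alpha, bg)[0, 0])  # roughly [125 125 125]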

I'll try proxy tasks now, maybe I can find something useful.

Hope my experience helps.

It helps a lot. Thank you very much for your time.


poppinace avatar poppinace commented on September 8, 2024

Hi @983,

Here are my validation results per epoch. They are quite stable in the last few epochs.
Online composition is really a good idea to speed up training. Thank you for letting me know.

epoch: 20, test: 1000/1000, SAD: 48.01, MSE: 0.0143, Grad: 27.74, Conn: 46.84
epoch: 21, test: 1000/1000, SAD: 45.41, MSE: 0.0141, Grad: 24.27, Conn: 44.22
epoch: 22, test: 1000/1000, SAD: 45.99, MSE: 0.0138, Grad: 24.97, Conn: 44.71
epoch: 23, test: 1000/1000, SAD: 47.40, MSE: 0.0147, Grad: 25.58, Conn: 46.67
epoch: 24, test: 1000/1000, SAD: 46.79, MSE: 0.0147, Grad: 25.28, Conn: 45.92
epoch: 25, test: 1000/1000, SAD: 45.48, MSE: 0.0133, Grad: 26.20, Conn: 44.55
epoch: 26, test: 1000/1000, SAD: 45.51, MSE: 0.0136, Grad: 25.10, Conn: 44.32
epoch: 27, test: 1000/1000, SAD: 45.85, MSE: 0.0136, Grad: 25.73, Conn: 44.86
epoch: 28, test: 1000/1000, SAD: 45.63, MSE: 0.0139, Grad: 24.72, Conn: 44.60
epoch: 29, test: 1000/1000, SAD: 45.24, MSE: 0.0139, Grad: 24.10, Conn: 44.07
epoch: 30, test: 1000/1000, SAD: 45.79, MSE: 0.0138, Grad: 25.09, Conn: 44.86


hejm37 avatar hejm37 commented on September 8, 2024

Hi @poppinace, I've noticed issue #11. I was wondering if the performance discrepancy is due to the difference in the number of channels of the second convolutional layer in the index block? Is the reported performance of SAD 45.8 from that model?

Thanks a lot!


poppinace avatar poppinace commented on September 8, 2024

Hi @hejm37, the performance is NOT reported on the model with the doubled number of channels. It was just a mistake when I computed the number of parameters.


hejm37 avatar hejm37 commented on September 8, 2024

Thanks for your response @poppinace! I've also tried to train the network, but the result I got is similar to what 983 got. The best SAD I got so far is 46.96 (trained three times). Maybe it is just because of a different random seed.


poppinace avatar poppinace commented on September 8, 2024

@hejm37 I see. Maybe it is about the hardware platform. The model was trained on a supercomputer that uses a different system. I have had the experience that the same code (not deep learning) produces different results on Windows and Mac.

I think such numerical differences should be normal, especially for deep learning. Your reproduced results look good to me.


983 avatar 983 commented on September 8, 2024

I found a solution.

NumPy produces the same "random" values for every worker process and every epoch because torch.utils.data.DataLoader forks the entire process, including the state of the random number generator: pytorch/pytorch#5059

The fix is to seed the RNG differently using worker_init_fn.

I get MSE 0.01286 and SAD 43.8 after just 23 epochs.


poppinace avatar poppinace commented on September 8, 2024

@983 Awesome! I'll fix this.


983 avatar 983 commented on September 8, 2024

I think the fix could still be improved.

Currently, only np.random is seeded. However, the data loader also uses Python's random module in three places.

crop_size = int(random.choice(crop_size))

In addition, np.random is seeded with its own state, which does not do anything to help our problem (I think).

np.random.seed(np.random.get_state()[1][0] + worker_id)

Seeding with worker_id will produce different data for every worker, but the augmentations will still be the same for every epoch. Here is an example for demonstration.

import numpy as np
import torch
import torch.utils.data

torch.manual_seed(0)

class MyDataset(torch.utils.data.Dataset):
    def __getitem__(self, index):
        return np.random.randint(1000)

    def __len__(self):
        return 4

def worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)

dataset = MyDataset()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    num_workers=2,
    worker_init_fn=worker_init_fn)

for epoch in range(3):
    print("Epoch", epoch)
    for batch in dataloader:
        print(batch)
    print()

The output is the same for each epoch.

Epoch 0
tensor([282])
tensor([684])
tensor([4])
tensor([17])

Epoch 1
tensor([282])
tensor([684])
tensor([4])
tensor([17])

Epoch 2
tensor([282])
tensor([684])
tensor([4])
tensor([17])

The PyTorch documentation recommends the following worker_init_fn:

def worker_init_fn(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
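
A minimal sketch of wiring that recommendation into the DataLoader above (reusing the toy MyDataset; with this, the printed batches should differ across workers and across epochs):

import random
import numpy as np
import torch
import torch.utils.data

def worker_init_fn(worker_id):
    # torch.initial_seed() already differs per worker and per epoch,
    # so deriving the NumPy and Python seeds from it avoids repeated augmentations.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

dataloader = torch.utils.data.DataLoader(
    MyDataset(),            # the toy dataset from the example above
    batch_size=1,
    num_workers=2,
    worker_init_fn=worker_init_fn)

for epoch in range(3):
    print("Epoch", epoch)
    for batch in dataloader:
        print(batch)
    print()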


poppinace avatar poppinace commented on September 8, 2024


Hi, I appreciate your rigor. Can you submit a pull request?
I think your contribution is valuable and should be included in this repository.

