poppinace / indexnet_matting

(ICCV'19) Indices Matter: Learning to Index for Deep Image Matting

Reproducing results #17

Closed · 983 closed this issue 3 years ago

983 commented 4 years ago

I tried to reproduce the results on the Adobe 1k Dataset and got exactly the same numbers when using the pretrained model. Very good job with that :)

I also tried to train the model from scratch, but did not succeed yet. Do you have any tips?

What I got so far:

troll_retrained

What it should look like:

troll_original

As you can see, your model produces much sharper results.

My training procedure:

Model:

net = hlmobilenetv2(
        pretrained=True,
        freeze_bn=True, 
        output_stride=32,
        apply_aspp=True,
        conv_operator='std_conv',
        decoder='indexnet',
        decoder_kernel_size=5,
        indexnet='depthwise',
        index_mode='m2o',
        use_nonlinear=True,
        use_context=True
)

I've also tried:

I am not sure about first cropping and then resizing, as described in Deep Image Matting, because every batch it produces a few trimaps that are 100% unknown region. Also, it is impossible to crop a 640x640 region from some alpha mattes because they have no unknown pixels to center the crop on.

poppinace commented 4 years ago

Hi, your training details look good to me. Just a reminder that the coordinate pairs should have alpha in [0,1], rather than (0,1).

983 commented 4 years ago

I know that your boss said that you cannot release the training code, but could you maybe release an example of the cropped images, trimaps and alphas in a training batch? Maybe I can find a visual difference.

For example, this is a batch from my training code where I tried without resizing and discarded cropped alpha regions with mean alpha below 0.2 or above 0.8 to better focus on the unknown region (a small sketch of that filter follows the example batch below):

| image | trimap | mask of unknown region for loss function | ground truth alpha | predicted alpha |

100000
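
The filter itself is only a couple of lines. A minimal sketch of what I mean (illustrative only; the function name is_valid_crop is made up, the thresholds are the ones mentioned above):

import numpy as np

def is_valid_crop(alpha_crop, lo=0.2, hi=0.8):
    # alpha_crop: float array in [0, 1]; keep crops that are neither
    # almost pure background nor almost pure foreground
    m = float(alpha_crop.mean())
    return lo <= m <= hi

print(is_valid_crop(np.full((320, 320), 0.5)))  # True
print(is_valid_crop(np.zeros((320, 320))))      # False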

the coordinate pairs should have alpha in [0,1], rather than (0,1).

I don't understand how that would help because all alpha values are in [0, 1].

Xu et al. (Deep Image Matting) say:

First, we randomly crop 320×320 (image, trimap) pairs centered on pixels in the unknown regions.

As I understand it, this means that they choose a random rectangle in the ground-truth alpha matte and discard it if the center pixel is known, in other words, if the center pixel is 100% foreground or 100% background.

Maybe my English is not good, so here is some code to explain:

from PIL import Image
import numpy as np

def crop_centered(alpha):
    while True:
        # pick random rectangle with corner (x, y) and size 320x320
        x = np.random.randint(alpha.shape[1] - 320)
        y = np.random.randint(alpha.shape[0] - 320)

        cropped = alpha[y:y+320, x:x+320]

        center_pixel = cropped[160, 160]

        # found good rectangle if the center pixel is unknown
        if center_pixel != 0 and center_pixel != 255:
            return cropped

Image.fromarray(crop_centered(np.array(Image.open("GT04.png").convert("L")))).show()
poppinace commented 4 years ago

@983 I work at home now. When I return to my office, I'll share some code with you showing how I crop images.

poppinace commented 4 years ago

@983 Here is the code I use to randomly crop images:

class RandomCrop(object):
    """Randomly crop the image.

    Args:
        output_size (int): Desired output size. If int, square crop
            is made.
        scales (list): Desired scales
    """

    def __init__(self, output_size, scales):
        assert isinstance(output_size, int)
        self.output_size = output_size
        self.scales = scales

    def __call__(self, sample):
        image, alpha = sample['image'], sample['alpha']
        h, w = image.shape[:2]

        if min(h, w) < self.output_size:
            s = (self.output_size + 180) / min(h, w)
            nh, nw = int(np.floor(h * s)), int(np.floor(w * s))
            image, alpha = resize_image_alpha(image, alpha, nh, nw)
            h, w = image.shape[:2]

        crop_size = np.floor(self.output_size * np.array(self.scales)).astype('int')
        crop_size = crop_size[crop_size < min(h, w)]
        crop_size = int(random.choice(crop_size))

        c = int(np.ceil(crop_size / 2))
        mask = np.equal(image[:, :, 3], 128).astype(np.uint8)
        if mask[c:h-c+1, c:w-c+1].sum() != 0:
            mask_center = np.zeros((h, w), dtype=np.uint8)
            mask_center[c:h-c+1, c:w-c+1] = 1
            mask = (mask & mask_center)
            idh, idw = np.where(mask == 1)
            ids = random.choice(range(len(idh)))
            hc, wc = idh[ids], idw[ids]
            h1, w1 = hc-c, wc-c
        else:
            idh, idw = np.where(mask == 1)
            ids = random.choice(range(len(idh)))
            hc, wc = idh[ids], idw[ids]
            h1, w1 = np.clip(hc-c, 0, h), np.clip(wc-c, 0, w)
            h2, w2 = h1+crop_size, w1+crop_size
            h1 = h-crop_size if h2 > h else h1
            w1 = w-crop_size if w2 > w else w1

        image = image[h1:h1+crop_size, w1:w1+crop_size, :]
        alpha = alpha[h1:h1+crop_size, w1:w1+crop_size, :]

        if crop_size != self.output_size:
            nh = nw = self.output_size
            image, alpha = resize_image_alpha(image, alpha, nh, nw)

        return {'image': image, 'alpha': alpha}
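
To show how the transform is called, here is a minimal smoke test for the class above (illustrative only, not part of the training pipeline): it uses scales=[1.0] so the resize_image_alpha helper is never reached, and feeds a 4-channel image whose fourth channel is a trimap with value 128 on unknown pixels, which is what the mask computation expects.

import random
import numpy as np

# dummy 4-channel image (RGB + trimap) and single-channel alpha, larger than the crop size
h, w = 600, 800
rgb = np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)
trimap = np.full((h, w, 1), 128, dtype=np.uint8)  # all-unknown trimap for simplicity
alpha = np.random.randint(0, 256, (h, w, 1), dtype=np.uint8)

crop = RandomCrop(output_size=320, scales=[1.0])  # scales=[1.0] keeps crop_size == output_size
out = crop({'image': np.concatenate([rgb, trimap], axis=2), 'alpha': alpha})
print(out['image'].shape, out['alpha'].shape)  # (320, 320, 4) (320, 320, 1)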
983 commented 4 years ago

Thank you very much for the training code, I'll update here once it is finished.

EDIT:

My results are:

| SAD | MSE | Grad | Conn |
| --- | --- | --- | --- |
| 49.91 | 0.0155 | 31.07 | 49.49 |

The model_ckpt.pth.tar file size is 71 821 666 bytes, but the size of the pretrained model indexnet_matting.pth.tar is only 24 085 481 bytes, so I guess the training configuration is different? Do you still have the original somewhere?

yucornetto commented 4 years ago

Thanks for the great work! I have also tried using the provided training code to reproduce the results. The only thing I changed was num_workers (4 to 16) to speed up training. I also get a similar result, with SAD = 49.26 and MSE = 0.0143. The results are good compared to DIM, yet there is still a significant margin to the provided model (SAD = 45.8 and MSE = 0.013). I wonder if you have any clue about what leads to the difference?


poppinace commented 4 years ago

Hi all @983 @yucornetto,

Unfortunately, this is exactly the training setting I use. How many GPUs are used in training? I have tried training with multiple GPUs with sync_bn, but the results are worse than training with a single GPU and standard bn. Another difference I can think of may be the random seed. Can you guys try to retrain the model with a different seed?

I also report the performance using the official MATLAB evaluation code. It is slightly better than the Python evaluation code I implemented.

The reason why model_ckpt.pth.tar is larger than indexnet_matting.pth.tar is that other intermediate variables, such as the optimizer state, are saved as well. I only saved the state_dict in indexnet_matting.pth.tar.
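
To make the size difference concrete, this is roughly what the two files contain (an illustrative sketch with a toy model; the exact keys in the real checkpoints may differ):

import torch

net = torch.nn.Linear(10, 10)                # toy model standing in for the real network
optimizer = torch.optim.Adam(net.parameters())
net(torch.randn(2, 10)).sum().backward()
optimizer.step()                             # populate the optimizer state (exp_avg, exp_avg_sq, ...)

# training checkpoint: weights + optimizer state + bookkeeping -> larger file
torch.save({'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': 30}, 'model_ckpt.pth.tar')

# release file: only the weights -> smaller file
torch.save({'state_dict': net.state_dict()}, 'indexnet_matting.pth.tar')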

983 commented 4 years ago

I also increased the number of workers (to 8), but made no changes otherwise. Maybe that makes a difference? It really shouldn't, but who knows. I'll try 4 this time.

poppinace commented 4 years ago

@983 I don't think that the number of workers is an issue.

But I have run into a problem: when I terminate training halfway and resume from the checkpoint, the final results are always worse than training without stopping. This suggests that how the images are sampled affects the performance. I have stuck with this sampling strategy just to match what is used in Deep Image Matting for a fair comparison, but I think there must be a better way to do data augmentation reliably (e.g., cropping 512x512 instead of 320x320).

Hope my experience helps.

983 commented 4 years ago

I also think that better data augmentation could improve results, but training takes a really long time, so it is hard to evaluate what works and what doesn't.

It might be interesting to train a smaller model on smaller images and evaluate to what extent the findings transfer to larger models. For example, Marco Forte et al. (FBA Matting) recently found that a batch size of 1 works really well, but training took weeks, so it is hard to isolate the exact reason why it works. If the model were faster to train, it would be much faster to run experiments.

yucornetto commented 4 years ago

@poppinace I trained the model with a single GPU without stopping and resuming, as suggested. Thanks for the advice, I will try modifying the sampling strategy to see if it helps.

poppinace commented 4 years ago

@983 I know that paper. I reserve my opinion about the 1-batch strategy because the paper does not report performance when bs >= 16. It is unfair to compare small batch sizes against 1-batch instance norm.

I agree that you should find a proxy task to validate your idea. I have seen some papers use a resized dataset so that the whole dataset can be loaded into memory to speed up training. We also composite each fg with only 2 or 3 bgs to construct a small dataset. The key is that the small dataset should be representative enough to serve as a replacement for the full dataset. You can think about it.

983 commented 4 years ago

Here are the results from the latest run. The SAD after 30 epochs is slightly worse than before (50.59), and after about 20 epochs the error does not decrease much further. However, there are some better values in between, like 48.40 and 48.69, so the gap to 45.8 is quite small now; maybe it is just luck.

epoch: 1, test: 1000/1000, sad: 19.52, SAD: 80.48, MSE: 0.0405, Grad: 60.09, Conn: 83.77, frame: 0.34Hz/0.42Hz
epoch: 2, test: 1000/1000, sad: 17.35, SAD: 71.73, MSE: 0.0336, Grad: 52.21, Conn: 73.62, frame: 0.38Hz/0.43Hz
epoch: 3, test: 1000/1000, sad: 17.23, SAD: 69.67, MSE: 0.0315, Grad: 50.41, Conn: 71.07, frame: 0.38Hz/0.42Hz
epoch: 4, test: 1000/1000, sad: 15.57, SAD: 63.45, MSE: 0.0267, Grad: 46.24, Conn: 65.33, frame: 0.38Hz/0.43Hz
epoch: 5, test: 1000/1000, sad: 13.08, SAD: 56.47, MSE: 0.0229, Grad: 41.28, Conn: 57.25, frame: 0.36Hz/0.42Hz
epoch: 6, test: 1000/1000, sad: 13.03, SAD: 56.34, MSE: 0.0219, Grad: 39.06, Conn: 57.05, frame: 0.35Hz/0.43Hz
epoch: 7, test: 1000/1000, sad: 14.22, SAD: 55.98, MSE: 0.0208, Grad: 36.05, Conn: 55.41, frame: 0.38Hz/0.43Hz
epoch: 8, test: 1000/1000, sad: 13.33, SAD: 60.12, MSE: 0.0211, Grad: 38.11, Conn: 59.12, frame: 0.38Hz/0.42Hz
epoch: 9, test: 1000/1000, sad: 12.97, SAD: 51.39, MSE: 0.0187, Grad: 34.93, Conn: 50.75, frame: 0.38Hz/0.42Hz
epoch: 10, test: 1000/1000, sad: 13.06, SAD: 51.57, MSE: 0.0190, Grad: 30.95, Conn: 51.59, frame: 0.38Hz/0.42Hz
epoch: 11, test: 1000/1000, sad: 11.23, SAD: 52.72, MSE: 0.0187, Grad: 34.69, Conn: 52.89, frame: 0.37Hz/0.42Hz
epoch: 12, test: 1000/1000, sad: 10.77, SAD: 54.39, MSE: 0.0193, Grad: 36.05, Conn: 54.63, frame: 0.38Hz/0.41Hz
epoch: 13, test: 1000/1000, sad: 10.94, SAD: 50.74, MSE: 0.0179, Grad: 34.42, Conn: 50.85, frame: 0.38Hz/0.41Hz
epoch: 14, test: 1000/1000, sad: 10.47, SAD: 54.52, MSE: 0.0185, Grad: 41.60, Conn: 54.90, frame: 0.38Hz/0.42Hz
epoch: 15, test: 1000/1000, sad: 10.97, SAD: 54.40, MSE: 0.0182, Grad: 39.88, Conn: 54.42, frame: 0.39Hz/0.42Hz
epoch: 16, test: 1000/1000, sad: 12.35, SAD: 50.06, MSE: 0.0177, Grad: 30.85, Conn: 48.74, frame: 0.35Hz/0.42Hz
epoch: 17, test: 1000/1000, sad: 11.05, SAD: 54.01, MSE: 0.0180, Grad: 35.54, Conn: 53.58, frame: 0.35Hz/0.42Hz
epoch: 18, test: 1000/1000, sad: 9.95, SAD: 56.45, MSE: 0.0194, Grad: 39.32, Conn: 57.01, frame: 0.37Hz/0.42Hz
epoch: 19, test: 1000/1000, sad: 9.36, SAD: 48.69, MSE: 0.0166, Grad: 31.67, Conn: 48.02, frame: 0.37Hz/0.42Hz
epoch: 20, test: 1000/1000, sad: 9.34, SAD: 49.63, MSE: 0.0162, Grad: 31.99, Conn: 48.89, frame: 0.38Hz/0.42Hz
epoch: 21, test: 1000/1000, sad: 9.14, SAD: 50.50, MSE: 0.0167, Grad: 36.00, Conn: 50.08, frame: 0.37Hz/0.42Hz
epoch: 22, test: 1000/1000, sad: 9.33, SAD: 50.74, MSE: 0.0166, Grad: 35.40, Conn: 50.39, frame: 0.37Hz/0.42Hz
epoch: 23, test: 1000/1000, sad: 9.02, SAD: 51.57, MSE: 0.0170, Grad: 35.14, Conn: 51.21, frame: 0.37Hz/0.42Hz
epoch: 24, test: 1000/1000, sad: 9.19, SAD: 50.63, MSE: 0.0164, Grad: 34.33, Conn: 50.44, frame: 0.37Hz/0.42Hz
epoch: 25, test: 1000/1000, sad: 9.02, SAD: 49.01, MSE: 0.0163, Grad: 32.39, Conn: 48.51, frame: 0.37Hz/0.42Hz
epoch: 26, test: 1000/1000, sad: 9.12, SAD: 48.53, MSE: 0.0157, Grad: 32.38, Conn: 47.81, frame: 0.37Hz/0.42Hz
epoch: 27, test: 1000/1000, sad: 9.23, SAD: 48.40, MSE: 0.0159, Grad: 31.56, Conn: 47.59, frame: 0.35Hz/0.42Hz
epoch: 28, test: 1000/1000, sad: 9.24, SAD: 49.95, MSE: 0.0163, Grad: 34.01, Conn: 49.49, frame: 0.38Hz/0.42Hz
epoch: 29, test: 1000/1000, sad: 9.16, SAD: 49.65, MSE: 0.0162, Grad: 33.64, Conn: 49.25, frame: 0.37Hz/0.42Hz
epoch: 30, test: 1000/1000, sad: 9.16, SAD: 50.59, MSE: 0.0167, Grad: 33.59, Conn: 50.25, frame: 0.34Hz/0.42Hz

I have seen some papers use a resized dataset so that the whole dataset can be loaded into memory to speed up training.

I think most of the training cost is decoding the PNG images. It is probably fine to store them as BMP instead, since natural images don't compress well anyway. My own training code generated training data on the fly from the Adobe dataset and ran in one day instead of three, but the server has lots of RAM, so I can cache all images in it. A fast SSD is probably cheaper and almost as good.
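
For reference, the on-the-fly generation is mostly just the compositing equation I = alpha * F + (1 - alpha) * B applied to each crop; a minimal sketch (my own illustration, assuming float images in [0, 1]):

import numpy as np

def composite(fg, bg, alpha):
    # fg, bg: HxWx3 float arrays in [0, 1]; alpha: HxW float array in [0, 1]
    a = alpha[..., None]
    return a * fg + (1.0 - a) * bg

# random stand-ins for a cropped foreground/background pair
fg = np.random.rand(320, 320, 3)
bg = np.random.rand(320, 320, 3)
alpha = np.random.rand(320, 320)
image = composite(fg, bg, alpha)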

I'll try proxy tasks now, maybe I can find something useful.

Hope my experience helps.

It helps a lot. Thank you very much for your time.

poppinace commented 4 years ago

Hi @983,

Here are my validation results per epoch. They are quite stable in the last few epochs. Online composition is really a good idea to speed up training. Thank you for letting me know.

epoch: 20, test: 1000/1000, SAD: 48.01, MSE: 0.0143, Grad: 27.74, Conn: 46.84
epoch: 21, test: 1000/1000, SAD: 45.41, MSE: 0.0141, Grad: 24.27, Conn: 44.22
epoch: 22, test: 1000/1000, SAD: 45.99, MSE: 0.0138, Grad: 24.97, Conn: 44.71
epoch: 23, test: 1000/1000, SAD: 47.40, MSE: 0.0147, Grad: 25.58, Conn: 46.67
epoch: 24, test: 1000/1000, SAD: 46.79, MSE: 0.0147, Grad: 25.28, Conn: 45.92
epoch: 25, test: 1000/1000, SAD: 45.48, MSE: 0.0133, Grad: 26.20, Conn: 44.55
epoch: 26, test: 1000/1000, SAD: 45.51, MSE: 0.0136, Grad: 25.10, Conn: 44.32
epoch: 27, test: 1000/1000, SAD: 45.85, MSE: 0.0136, Grad: 25.73, Conn: 44.86
epoch: 28, test: 1000/1000, SAD: 45.63, MSE: 0.0139, Grad: 24.72, Conn: 44.60
epoch: 29, test: 1000/1000, SAD: 45.24, MSE: 0.0139, Grad: 24.10, Conn: 44.07
epoch: 30, test: 1000/1000, SAD: 45.79, MSE: 0.0138, Grad: 25.09, Conn: 44.86
hejm37 commented 4 years ago

Hi @poppinace, I've noticed issue #11. I was wondering whether the performance discrepancy is due to the difference in the number of channels of the second convolutional layer in the index block. Is the reported performance of SAD 45.8 from that model?

Thanks a lot!

poppinace commented 4 years ago

Hi @hejm37, the performance is NOT reported on the model with the doubled number of channels. It was just a mistake when I computed the number of parameters.

hejm37 commented 4 years ago

Thanks for your response @poppinace! I've also tried to train the network, but the result I got is similar to what 983 got. The best SAD I got so far is 46.96 (trained three times). Maybe it is just because of a different random seed.

poppinace commented 4 years ago

@hejm37 I see. Maybe it is about the hardware platform. The model was trained on a supercomputer that uses a different system. I have had the experience that the same code (not deep learning) produces different results on Windows and Mac.

I think such numerical differences should be normal, especially for deep learning. Your reproduced results look good to me.

983 commented 3 years ago

I found a solution.

NumPy produces the same "random" values for every worker process and every epoch because torch.utils.data.DataLoader forks the entire process, including the state of the random number generator: https://github.com/pytorch/pytorch/issues/5059

The fix is to seed the RNG differently using worker_init_fn.

I get MSE 0.01286 and SAD 43.8 after just 23 epochs.

poppinace commented 3 years ago

@983 Awesome! I'll fix this.

983 commented 1 year ago

I think the fix could still be improved.

Currently, only np.random is seeded. However, the data loader also uses Python's random module in three places.

https://github.com/poppinace/indexnet_matting/blob/4beb06a47db2eecca87b8003a11f0b268506cea3/scripts/hldataset.py#L109

In addition, np.random is seeded with its own state, which does not do anything to help our problem (I think).

https://github.com/poppinace/indexnet_matting/blob/4beb06a47db2eecca87b8003a11f0b268506cea3/scripts/hltrainval.py#L161

Seeding with worker_id will produce different data for every worker, but the augmentations will still be the same for every epoch. Here is an example for demonstration.

import numpy as np
import torch
import torch.utils.data

torch.manual_seed(0)

class MyDataset(torch.utils.data.Dataset):
    def __getitem__(self, index):
        return np.random.randint(1000)

    def __len__(self):
        return 4

def worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)

dataset = MyDataset()

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    num_workers=2,
    worker_init_fn=worker_init_fn)

for epoch in range(3):
    print("Epoch", epoch)
    for batch in dataloader:
        print(batch)
    print()

The output is the same for each epoch.

Epoch 0
tensor([282])
tensor([684])
tensor([4])
tensor([17])

Epoch 1
tensor([282])
tensor([684])
tensor([4])
tensor([17])

Epoch 2
tensor([282])
tensor([684])
tensor([4])
tensor([17])

The PyTorch documentation recommends the following worker_init_fn:

def worker_init_fn(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
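
As far as I understand, torch.initial_seed() already differs per worker and per epoch, because the DataLoader draws a fresh base seed from the main process every time the workers are spawned, so with this version (plus import random) the demonstration above prints different numbers in every epoch. Hooking it into the loader is just:

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    num_workers=2,
    worker_init_fn=worker_init_fn)  # the version recommended by the PyTorch docs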
poppinace commented 1 year ago


Hi, I appreciate your rigor. Can you submit a pull request? I think your contribution is valuable and should be included in this repository.