jantic / DeOldify

A Deep Learning based project for colorizing and restoring old images (and video!)
MIT License

fit_one_cycle issue with batch size #109

Closed miaoqiz closed 5 years ago

miaoqiz commented 5 years ago

Hi,

How are you?

Thanks for the update!

I think this line may cause trouble:

 learn_gen.fit_one_cycle(1, pct_start=pct_start, max_lr=slice(5e-8,5e-5))

if the "batch size"/bs is not set to 1.

Error:


cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 1 and 8 in dimension 0 at /pytorch/aten/src/THC/generic/THCTensorMath.cu:83
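For readers unfamiliar with this message: concatenating along dimension 1 (channels) requires every other dimension, including the batch dimension 0, to agree. The sketch below models that rule on plain shape tuples in pure Python. `cat_shape` is a hypothetical helper written for illustration, not a PyTorch function, and the shapes are only illustrative.

```python
def cat_shape(shapes, dim):
    """Return the shape that concatenating tensors of `shapes` along `dim` would have.

    Shape-only model of the rule in the error message: every dimension
    except `dim` must match across all inputs.
    """
    first = shapes[0]
    for shape in shapes[1:]:
        if len(shape) != len(first):
            raise ValueError(f"rank mismatch: {first} vs {shape}")
        for axis, (a, b) in enumerate(zip(first, shape)):
            if axis != dim and a != b:
                raise ValueError(
                    f"sizes must match except in dimension {dim}: "
                    f"got {a} and {b} in dimension {axis}"
                )
    out = list(first)
    out[dim] = sum(shape[dim] for shape in shapes)
    return tuple(out)


# Same batch size on both inputs: the channel counts simply add up.
ok = cat_shape([(8, 512, 48, 48), (8, 256, 48, 48)], dim=1)

# Batch sizes 1 and 8, as in the error message: rejected.
try:
    cat_shape([(1, 512, 48, 48), (8, 256, 48, 48)], dim=1)
    mismatch_message = None
except ValueError as err:
    mismatch_message = str(err)
```

So the error says that one of the two tensors passed to `torch.cat` had a batch dimension of 1 while the other had 8.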


Please let me know if there is a way to specify the batch size in "fit_one_cycle"?

Thanks and have a great day!

bruce-cham commented 5 years ago

sorry

jantic commented 5 years ago

You actually have to set batch size when you retrieve your data (which would be the second cell under the 64px heading): data_gen = get_data(bs=bs, sz=sz, keep_pct=keep_pct)

Or the cell under 128px like this: learn_gen.data = get_data(sz=sz, bs=bs, keep_pct=keep_pct)

I'm assuming you're looking at one of the *Training.ipynb notebooks here.

That all being said...I suspect something else changed that's really the issue here. Did you happen to change image size? You can't go lower than 64px with the Unet....

miaoqiz commented 5 years ago

Hi,

Thanks for the quick feedback!

I did not change anything except the "bs" value. Out of curiosity, I converted the "ColorizeTrainingVideo" notebook to Python code, but that should not affect anything.

jantic commented 5 years ago

Did you change bs to 1? For a number of reasons that would be problematic, even if you didn't run into this bug.

miaoqiz commented 5 years ago

Hi,

I had to change it to "1"; otherwise, the sizes of the two tensors in dimension #1 would not match. :)

Thanks!

jantic commented 5 years ago

Something's amiss here.... You must have changed something else in the process of moving the code to Python. I've -never- used a batch size of 1.

miaoqiz commented 5 years ago

Hi,

I compared the original notebook and the converted python code. The code is the same.

To provide some details:


DeOldify/fasterai/unet.py:123: TracerWarning: Converting a tensor to a Python integer might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  print ( s.size(), up_in.size(), up_out.size() )
torch.Size([1, 1024, 12, 12]) torch.Size([1, 2048, 6, 6]) torch.Size([1, 512, 12, 12])
torch.Size([1, 512, 24, 24]) torch.Size([1, 512, 12, 12]) torch.Size([1, 512, 24, 24])
torch.Size([8, 256, 48, 48]) torch.Size([1, 512, 24, 24]) torch.Size([1, 512, 48, 48])
Error occurs, No graph saved
torch.Size([1, 1024, 12, 12]) torch.Size([1, 2048, 6, 6]) torch.Size([1, 512, 12, 12])
torch.Size([1, 512, 24, 24]) torch.Size([1, 512, 12, 12]) torch.Size([1, 512, 24, 24])
torch.Size([1, 256, 48, 48]) torch.Size([1, 512, 24, 24]) torch.Size([1, 512, 48, 48])
torch.Size([1, 64, 96, 96]) torch.Size([1, 512, 48, 48]) torch.Size([1, 256, 96, 96])
torch.Size([1, 1024, 12, 12]) torch.Size([8, 2048, 6, 6]) torch.Size([8, 512, 12, 12])
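One way to read a log like this is to group the printed sizes into `(s, up_in, up_out)` triples, following the order in the `print` call, and flag any triple whose batch dimensions disagree. The following is a rough sketch using only the standard library. The helper names `parse_triples` and `batch_mismatches` are made up for illustration, and `LOG` is the output above with the warning text left out.

```python
import re

# Sizes copied from the log above, one print call per line.
LOG = """\
torch.Size([1, 1024, 12, 12]) torch.Size([1, 2048, 6, 6]) torch.Size([1, 512, 12, 12])
torch.Size([1, 512, 24, 24]) torch.Size([1, 512, 12, 12]) torch.Size([1, 512, 24, 24])
torch.Size([8, 256, 48, 48]) torch.Size([1, 512, 24, 24]) torch.Size([1, 512, 48, 48])
Error occurs, No graph saved
torch.Size([1, 1024, 12, 12]) torch.Size([1, 2048, 6, 6]) torch.Size([1, 512, 12, 12])
torch.Size([1, 512, 24, 24]) torch.Size([1, 512, 12, 12]) torch.Size([1, 512, 24, 24])
torch.Size([1, 256, 48, 48]) torch.Size([1, 512, 24, 24]) torch.Size([1, 512, 48, 48])
torch.Size([1, 64, 96, 96]) torch.Size([1, 512, 48, 48]) torch.Size([1, 256, 96, 96])
torch.Size([1, 1024, 12, 12]) torch.Size([8, 2048, 6, 6]) torch.Size([8, 512, 12, 12])
"""

SIZE = re.compile(r"torch\.Size\(\[([\d, ]+)\]\)")


def parse_triples(log):
    """Group every printed size into (s, up_in, up_out) triples, in print order."""
    sizes = [
        tuple(int(n) for n in match.group(1).split(","))
        for match in SIZE.finditer(log)
    ]
    return [tuple(sizes[i:i + 3]) for i in range(0, len(sizes) - 2, 3)]


def batch_mismatches(triples):
    """Return (index, batch sizes) for triples whose dimension 0 values differ."""
    found = []
    for index, triple in enumerate(triples):
        batches = tuple(shape[0] for shape in triple)
        if len(set(batches)) > 1:
            found.append((index, batches))
    return found


triples = parse_triples(LOG)
mismatches = batch_mismatches(triples)
```

On this log, `mismatches` contains two entries: the third triple, where `s` has batch size 8 while `up_in` and `up_out` have 1, and the last triple, where `s` has batch size 1 while the other two have 8. In both cases the stored activation `s` and the incoming `up_in` come from inputs with different batch sizes.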


This is inside "class UnetBlockWide(nn.Module)".

What does "hook" do exactly? Does it register previous graph information? Since the graph was not saved, does that affect "hook"?

The error does not occur with "data_crit":

learn_critic = colorize_crit_learner(data=data_crit, nf=256).load(crit_old_checkpoint_name, with_opt=False)

Thanks,

jantic commented 5 years ago

Sorry, where are you seeing hook? Hook is called all over the place. Hooks == Callbacks.

miaoqiz commented 5 years ago

Hi,

In "unet.py" and "class UnetBlockWide(nn.Module)" at line#107

    def forward(self, up_in:Tensor) -> Tensor:
        s = self.hook.stored
        up_out = self.shuf(up_in)
        ssh = s.shape[-2:]
        if ssh != up_out.shape[-2:]:
            up_out = F.interpolate(up_out, s.shape[-2:], mode='nearest')
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv(cat_x)
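To illustrate what `self.hook.stored` implies, here is a shape-only simulation in pure Python. It is a simplified sketch, not the fastai or DeOldify implementation: `StoredOutputHook`, `encoder_layer`, and `decoder_block` are hypothetical stand-ins, and the channel counts are illustrative. The point is that the hook keeps whatever the encoder layer produced on its most recent forward pass. If that pass used a different batch size than the `up_in` reaching the decoder, the concatenation fails on dimension 0. Whether this is what happened in this thread is not established.

```python
class StoredOutputHook:
    """Shape-only stand-in for a forward hook that keeps a layer's last output."""

    def __init__(self):
        self.stored = None

    def __call__(self, output_shape):
        # Overwrites whatever was stored by the previous forward pass.
        self.stored = output_shape


def encoder_layer(batch_size, hook):
    """Pretend encoder layer: records its skip activation, returns deeper features."""
    skip = (batch_size, 256, 48, 48)
    hook(skip)
    return (batch_size, 512, 24, 24)


def decoder_block(up_in, hook):
    """Pretend decoder block: upsample `up_in`, then concatenate with the stored skip."""
    s = hook.stored
    # Upsampling is modelled as matching the stored activation's height and width.
    up_out = (up_in[0], up_in[1], s[2], s[3])
    if up_out[0] != s[0]:
        raise RuntimeError(f"Got {up_out[0]} and {s[0]} in dimension 0")
    # Concatenation along dim=1 adds the channel counts.
    return (up_out[0], up_out[1] + s[1], s[2], s[3])


hook = StoredOutputHook()

# Normal case: encoder and decoder see the same batch of 8.
up_in = encoder_layer(8, hook)
same_pass = decoder_block(up_in, hook)

# Mismatch case: the hook still holds the batch-of-8 activation,
# but the decoder input comes from a batch of 1.
stale_up_in = (1, 512, 24, 24)
try:
    decoder_block(stale_up_in, hook)
    stale_error = None
except RuntimeError as err:
    stale_error = str(err)
```

In the simulation the first call succeeds and the second raises the same "Got 1 and 8 in dimension 0" mismatch reported in the traceback.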

Also, what is the best practice for setting up multiple epochs for the "Repeatable GAN Cycle"?

Thanks!

miaoqiz commented 5 years ago

The issue seems to be gone when I restructured the training set. Thanks!

cduguet commented 4 years ago

@miaoqiz How exactly did you restructure the training set? I'm currently having this problem as well (though I'm using bs=44).

miaoqiz commented 4 years ago

Hi, it has been a long time. You can try various batch sizes.

BTW, using the latest PyTorch may help.