marcoaversa / diffinfinite

DiffInfinite Official Code
MIT License

how to train my own dataset #2

Closed chinayb closed 4 months ago

chinayb commented 6 months ago

Congratulations on your incredibly impressive work. I'm interested in training the model on my own dataset and was wondering if you could provide any specific instructions or recommendations on how best to approach this. Also, is there a problem with the line `from classifier_free_guidance import Unet, GaussianDiffusion, Trainer` in train_masks.py?

marcoaversa commented 6 months ago

Thanks a lot! I'm glad you like our work! The code provides end-to-end model training; the only step left for you is to adapt it to your custom dataset. For our experiments we prepared the dataset in the PyTorch ImageFolder format, but you can use whatever format you prefer. The dataloader is set up in the Trainer class, which you can find in both dm.py and dm_masks.py. We forgot to update that import: classifier_free_guidance was the old name for dm_masks. Thanks for noticing! I hope this helps.
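
For anyone following along, here is a minimal sketch of the ImageFolder setup mentioned above. It is not part of the repo; the directory path, image size, and batch size are illustrative assumptions:

```python
import torch
from torchvision import datasets, transforms

# Hypothetical layout: data/train/<class_name>/*.png
transform = transforms.Compose([
    transforms.Resize((512, 512)),   # assumed training resolution
    transforms.ToTensor(),           # -> 3 x 512 x 512 tensors in [0, 1]
])
dataset = datasets.ImageFolder("data/train", transform=transform)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

imgs, labels = next(iter(loader))
print(imgs.shape)  # torch.Size([32, 3, 512, 512])
```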

chinayb commented 6 months ago

Thank you for your help. However, a new error comes up after loading my data:

```
Traceback (most recent call last):
  File "G:\download\diffinfinite-master\train.py", line 72, in <module>
    Fire(main)
  File "D:\Anaconda3\envs\pytorch2.0\lib\site-packages\fire\core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, ...)
  File "D:\Anaconda3\envs\pytorch2.0\lib\site-packages\fire\core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(...)
  File "D:\Anaconda3\envs\pytorch2.0\lib\site-packages\fire\core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "G:\download\diffinfinite-master\train.py", line 69, in main
    trainer.train()
  File "G:\download\diffinfinite-master\dm.py", line 790, in train
    loss = self.train_loop(data, masks)
  File "G:\download\diffinfinite-master\dm.py", line 733, in train_loop
    loss = self.model(img=imgs, classes=masks)
  File "D:\Anaconda3\envs\pytorch2.0\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "G:\download\diffinfinite-master\dm.py", line 595, in forward
    return self.p_losses(img, t, *args, **kwargs)
  File "G:\download\diffinfinite-master\dm.py", line 570, in p_losses
    model_out = self.model(x, t, classes)
  File "D:\Anaconda3\envs\pytorch2.0\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "G:\download\diffinfinite-master\dm.py", line 192, in forward
    x = self.init_conv(x)
  File "D:\Anaconda3\envs\pytorch2.0\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\Anaconda3\envs\pytorch2.0\lib\site-packages\torch\nn\modules\conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "D:\Anaconda3\envs\pytorch2.0\lib\site-packages\torch\nn\modules\conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride, ...)
RuntimeError: Given groups=1, weight of size [256, 3, 7, 7], expected input[32, 4, 64, 64] to have 3 channels, but got 4 channels instead
```

Process finished with exit code 1
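
(For reference, the failure reduces to a standard Conv2d channel mismatch. A minimal sketch reproducing it, with the shapes taken from the traceback above; the variable names are illustrative:)

```python
import torch
import torch.nn as nn

# The UNet's init_conv was built for 3-channel RGB input
# (weight of size [256, 3, 7, 7])...
init_conv = nn.Conv2d(in_channels=3, out_channels=256, kernel_size=7, padding=3)

# ...but the tensors reaching it are 4-channel VAE latents.
latents = torch.randn(32, 4, 64, 64)
init_conv(latents)  # RuntimeError: expected input[32, 4, 64, 64] to have 3 channels
```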

marcoaversa commented 6 months ago

In the latent space, the diffusion model expects 4 channels. This is because the Latent Diffusion Model is composed of a VQ-VAE whose encoder maps B×3×H×W images into B×4×(H//8)×(W//8) latents, e.g. B×3×512×512 --> B×4×64×64. We prepare the data with the VQ-VAE encoder inside the Trainer class, but you are free to encode your data into the latent space wherever you prefer.
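
A minimal sketch of that encoding step, mirroring the `self.vae.encode(imgs).latent_dist.sample() / 50` call visible in the traceback. The checkpoint name and the use of diffusers' AutoencoderKL are assumptions, not necessarily the repo's exact setup:

```python
import torch
from diffusers import AutoencoderKL

# Assumed VAE checkpoint; substitute whichever VAE the repo actually loads.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").eval()

imgs = torch.rand(8, 3, 512, 512) * 2 - 1  # B x 3 x 512 x 512 images in [-1, 1]
with torch.no_grad():
    # The encoder downsamples by 8x and outputs 4 latent channels;
    # the /50 scaling matches dm.py's train_loop.
    latents = vae.encode(imgs).latent_dist.sample() / 50
print(latents.shape)  # torch.Size([8, 4, 64, 64])
```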