Zhendong-Wang / Patch-Diffusion

Apache License 2.0

Ran training on my custom data but generation fails with CUDA memory error. #3

Open jbmaxwell opened 11 months ago

jbmaxwell commented 11 months ago

I appear to have successfully run training on my custom tensor data, but generate.py crashes with a CUDA out-of-memory error. I'm running on a 2x A100 (40 GB) machine. I can see that the model loads, but it crashes on entering the actual sampler_fn.

I also tried device='cpu', but then it seems to hang: it consumes a huge amount of resident memory (in top) but never really progresses, as far as I can tell. My input tensors are (1, 32, 2048) and are actually latents from another autoencoder, so I'm not using the VAE in Patch-Diffusion. Any thoughts?

EDIT: I'm running the EDM sampler, btw.

Zhendong-Wang commented 11 months ago

Hi, if it can be trained, it should be able to generate. Generation doesn't need backprop or any gradient computation, so it is more GPU-memory efficient. The EDM sampler worked fine in my experiments. Your input tensors have shape (1, 32, 2048), so they are not image-shaped? I am not sure what's happening here, lol. But if it can be trained, it should be able to generate.

jbmaxwell commented 11 months ago

Yeah, it makes no sense to me. The last line reached in your code is eps = net(x, t, pos, class_labels).to(torch.float64) in sample_with_cfg. I don't have class labels; could that be the issue? I notice that cfg and class_labels are both None when I reach that function.

Zhendong-Wang commented 11 months ago

In your case, with no class labels, cfg and class_labels being None is correct. Note that in training, eps = net(x, t, pos, class_labels).to(torch.float64) is called here: https://github.com/Zhendong-Wang/Patch-Diffusion/blob/924f1ed4233e22a5917fc3fded8d65492cfc5c99/training/patch_loss.py#L76. Maybe you can check how the inputs differ when you reach these two lines.
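One way to follow that suggestion is to log a compact summary of every argument right before each net(...) call and compare the two. The helpers below are hypothetical (describe_inputs and diff_inputs are not part of Patch-Diffusion); they only assume the arguments expose .shape, .dtype and .device, as torch tensors do, and they report None for anything else, such as class_labels=None.

```python
def describe_inputs(**named_args):
    """Return {name: (shape, dtype, device)} for tensor-like arguments.

    Hypothetical debugging helper: works with any object exposing
    .shape/.dtype/.device (e.g. torch.Tensor); other values map to None.
    """
    summary = {}
    for name, value in named_args.items():
        if hasattr(value, "shape"):
            summary[name] = (
                tuple(value.shape),
                str(getattr(value, "dtype", None)),
                str(getattr(value, "device", None)),
            )
        else:
            summary[name] = None
    return summary


def diff_inputs(train_summary, sample_summary):
    """Names of arguments whose summaries differ between the two call sites."""
    names = sorted(set(train_summary) | set(sample_summary))
    return [n for n in names if train_summary.get(n) != sample_summary.get(n)]
```

For example, calling print(describe_inputs(x=x, t=t, pos=pos, class_labels=class_labels)) just before the network call in both patch_loss.py and sample_with_cfg would show at a glance whether x and pos are patch-sized in one place and full-image-sized in the other.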

jbmaxwell commented 11 months ago

Okay, I see that during training the maximum yn/D_yn size is (32, 32), while during inference it crashes before reaching this point. These are presumably the patches, correct? So I think I misunderstood the generate.py code when I hacked into it, and it's trying to use (32, 2048) patches, which might explain the memory crash. I commented out my changes to generate.py, and it does run from my checkpoint, but it clearly generates (32, 32) outputs. I'm not sure, but I suspect I accidentally trained on (32, 32) crops of my input... bummer. Where should I look to modify the actual image size/shape during training? (I modified the dataset to remove the "squareness" assert, but that's all.)
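A back-of-the-envelope estimate shows why a (32, 2048) input can exhaust memory even without gradients. It assumes, as an illustration not confirmed in this thread, that the network applies full self-attention over the feature map at a level downsampled by 2, storing one float32 score per query-key pair; the downsampling factor and single head are hypothetical choices. Because the score matrix grows with the square of the number of spatial positions, a 64x larger input gives a 4096x larger matrix.

```python
def attention_scores_bytes(height, width, downsample=2, heads=1, batch=1, bytes_per_elem=4):
    """Approximate size of one self-attention score matrix in bytes.

    Assumes full attention over a (height/downsample) x (width/downsample)
    feature map, i.e. tokens**2 scores per head per sample.
    """
    tokens = (height // downsample) * (width // downsample)
    return batch * heads * tokens * tokens * bytes_per_elem


patch = attention_scores_bytes(32, 32)
full = attention_scores_bytes(32, 2048)
print(f"32x32 patch: {patch / 2**20:.2f} MiB")   # 32x32 patch: 0.25 MiB
print(f"32x2048 image: {full / 2**30:.2f} GiB")  # 32x2048 image: 1.00 GiB
print(f"ratio: {full // patch}x")                # ratio: 4096x
```

Under these assumptions the figure is per attention layer, per head, and per sample, so it has to be multiplied by the sampling batch size and the number of attention layers.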

jbmaxwell commented 11 months ago

Actually, when I print the shape of images_ during training (here: https://github.com/Zhendong-Wang/Patch-Diffusion/blob/924f1ed4233e22a5917fc3fded8d65492cfc5c99/training/training_loop.py#L174), it is my expected shape of (32, 2048), so presumably it did train correctly. If you could point me to the change needed to generate (32, 2048) outputs, that would be awesome.

Zhendong-Wang commented 11 months ago

Yes, the model can be trained on patches only. The code here shows how the probability of each patch size is allocated to meet efficiency requirements: https://github.com/Zhendong-Wang/Patch-Diffusion/blob/924f1ed4233e22a5917fc3fded8d65492cfc5c99/training/training_loop.py#L153
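The idea of assigning a probability to each patch size can be sketched as follows. This is a hypothetical illustration, not the code at the linked line: it assumes three candidate sizes (a quarter, half, and the full resolution) and a single parameter, full_ratio, for the share of full-size samples, with the remaining probability split evenly between the two smaller sizes (an arbitrary choice).

```python
import random


def patch_size_distribution(resolution, full_ratio):
    """Hypothetical allocation: full_ratio for the full size, the rest
    split evenly between resolution // 4 and resolution // 2."""
    if not 0.0 <= full_ratio <= 1.0:
        raise ValueError("full_ratio must be in [0, 1]")
    rest = (1.0 - full_ratio) / 2
    return {resolution // 4: rest, resolution // 2: rest, resolution: full_ratio}


def sample_patch_size(dist, rng):
    """Draw one patch size according to the probabilities in dist."""
    sizes = list(dist)
    return rng.choices(sizes, weights=[dist[s] for s in sizes], k=1)[0]


dist = patch_size_distribution(resolution=32, full_ratio=0.5)
print(dist)  # {8: 0.25, 16: 0.25, 32: 0.5}
rng = random.Random(0)
draws = [sample_patch_size(dist, rng) for _ in range(1000)]
print({s: draws.count(s) for s in dist})
```

Setting full_ratio to 0 in this sketch corresponds to the patch-only training described in the next paragraph, where the model never sees a full-size sample.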

I guess in your case you are always training with patch size 32, whereas in our original version we need a small ratio of full-size images to ensure good generation quality; training only on patches degrades performance. I am not sure whether that happens in your case, since your images have a long horizontal shape.

As for generation, in our original setting we sample the whole image directly, since every patch location on the image has been trained. In your case, whole-image generation causes OOM, so you need to slice the sampling: move the patch window, together with its position coordinates, across the image until the whole image is covered. You could also let neighboring windows overlap during denoising for better quality.
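The window placement and overlap blending can be sketched in one dimension, along the width. This is an illustrative sketch, not Patch-Diffusion code: window_starts and blend_windows are hypothetical names, the window width and overlap are arbitrary, and overlapping columns are simply averaged. It assumes that, because the height (32) matches the trained size, each window spans the full height and only slides horizontally; in a real sampler every window would be denoised by the network at each step, with its own position coordinates, before blending.

```python
def window_starts(total, window, overlap):
    """Start offsets of windows of size `window` covering [0, total).

    Consecutive windows share `overlap` columns; if the stride does not
    land exactly on the end, one last window is placed flush with `total`.
    """
    if window > total:
        raise ValueError("window larger than image")
    if not 0 <= overlap < window:
        raise ValueError("overlap must be in [0, window)")
    stride = window - overlap
    starts = list(range(0, total - window + 1, stride))
    if starts[-1] + window < total:
        starts.append(total - window)
    return starts


def blend_windows(total, starts, outputs):
    """Average per-column outputs of overlapping windows into one row.

    outputs[i] holds the values predicted for the window starting at starts[i].
    """
    acc = [0.0] * total
    count = [0] * total
    for start, values in zip(starts, outputs):
        for offset, v in enumerate(values):
            acc[start + offset] += v
            count[start + offset] += 1
    return [a / c for a, c in zip(acc, count)]


print(window_starts(2048, 256, 32))
# [0, 224, 448, 672, 896, 1120, 1344, 1568, 1792]
```

Averaging is the simplest blend; weighting columns near a window's center more heavily is a common refinement, but that is a design choice rather than something stated in this thread.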

jbmaxwell commented 11 months ago

Thanks for the reply!

I did notice a range of sizes when printing during training. Unfortunately I don't remember exactly where that print was, but it showed a range of different shapes (in multiples of 2). Also, I had to hack batch_mul_dict to get it running by adding an 8: 128 entry, i.e., batch_mul_dict = {512: 1, 256: 2, 128: 4, 64: 16, 32: 32, 16: 64, 8: 128}. I assumed that had to do with my unusual shape and the positioning of patches, but I wasn't certain. It did run and converge with that change, though, so I'm guessing it was okay. Basically, I do think it trained the way you originally intended, with varying patch sizes.
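The missing key that the 8: 128 entry works around can be detected ahead of time. The sketch below assumes, without checking it against the repository, that the candidate patch sizes are resolution // 4, resolution // 2 and resolution, and that batch_mul_dict is indexed by patch size; with a resolution of 32 that produces size 8, which the dictionary lacked before the hack. Both helper functions are hypothetical.

```python
def candidate_patch_sizes(resolution):
    """Assumed patch sizes: a quarter, half, and the full resolution."""
    return [resolution // 4, resolution // 2, resolution]


def missing_batch_multipliers(resolution, batch_mul_dict):
    """Candidate patch sizes that have no entry in batch_mul_dict."""
    return [p for p in candidate_patch_sizes(resolution) if p not in batch_mul_dict]


original = {512: 1, 256: 2, 128: 4, 64: 16, 32: 32, 16: 64}
patched = {**original, 8: 128}
print(missing_batch_multipliers(32, original))  # [8]
print(missing_batch_multipliers(32, patched))   # []
```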

For generation, all I really did was change the dimensions in the main batch loop:

        import torch

        image_channel = 1    # single-channel images
        image_width = 2048   # width of the image
        image_height = 32    # height of the image, which is 'resolution' in this context
        x_start, y_start = 0, 0  # offset of the sampled region; 0, 0 for the whole image

        # Pixel-index grids of shape (image_height, image_width).
        x_pos = torch.arange(x_start, x_start + image_width).view(1, -1).repeat(image_height, 1)
        y_pos = torch.arange(y_start, y_start + image_height).view(-1, 1).repeat(1, image_width)
        # Normalize by the image size; with zero offsets this spans [-1, 1].
        x_pos = (x_pos / (image_width - 1) - 0.5) * 2.
        y_pos = (y_pos / (image_height - 1) - 0.5) * 2.

Just for some context: my inputs are latents from an audio encoder (hence the short, wide shape). I was interested in your model partly because I need better training time and performance on small datasets, but also because it seems to me that patches could be good for music, which tends to have a natural repetition structure that might benefit from patch-based optimization.

Finally, yes, what you're saying about the sampling makes sense. I hadn't dug into the actual sampling code in detail, so I was definitely surprised to hit the memory error, but I get it now. If you have any tips on how to slice the sampling, that would be awesome; otherwise, thanks for the info.
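For sliced sampling, each window needs coordinates that describe where it sits in the full image rather than within itself. A minimal sketch, assuming the same [-1, 1] normalization as the coordinate snippet earlier in this thread (dividing by the full image size minus one): window_positions is a hypothetical helper written with plain lists so it runs without torch; in generate.py the same values would be built with torch.arange.

```python
def window_positions(x_start, window_width, image_width,
                     y_start, window_height, image_height):
    """Normalized (x, y) coordinate grids for one window of the full image.

    Pixel indices are taken in full-image coordinates and mapped to [-1, 1]
    using the full image size, so overlapping windows agree on shared pixels.
    """
    def norm(i, size):
        return (i / (size - 1) - 0.5) * 2.0

    xs = [norm(x, image_width) for x in range(x_start, x_start + window_width)]
    ys = [norm(y, image_height) for y in range(y_start, y_start + window_height)]
    x_grid = [list(xs) for _ in ys]             # every row repeats the x coordinates
    y_grid = [[y] * window_width for y in ys]   # every row holds a constant y
    return x_grid, y_grid


# Last 256 columns of a 32 x 2048 latent: x ends at exactly 1.0.
x_grid, y_grid = window_positions(1792, 256, 2048, 0, 32, 32)
print(len(x_grid), len(x_grid[0]), x_grid[0][-1])  # 32 256 1.0
```

Normalizing by the full image size rather than the window size is what lets a window starting at column 224 reuse exactly the coordinates the model saw for those columns during patch training.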

FabianVald commented 1 month ago

I have a problem: I get RuntimeError: shape '[-1, 10]' is invalid for input of size 524288 when I try to run generate.py following your recommendations:

    python generate.py --outdir=out --seeds=0-63 --batch=64 \
        --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl
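The error message itself narrows down the cause: reshaping to (-1, 10) only succeeds when the element count is divisible by 10, and 524288 = 2**19 has no factor of 5, so the tensor being reshaped cannot be a batch of 10-wide label vectors. One plausible but unverified reading: the original EDM preconditioning wrappers take (x, sigma, class_labels, ...) and reshape class_labels to (-1, label_dim), so if generate.py calls net(x, t, pos, class_labels) on a stock EDM checkpoint such as edm-cifar10-32x32-cond-vp.pkl, pos may land in the class_labels slot. The sketch below only checks the arithmetic; can_reshape and factor_out_twos are hypothetical helpers.

```python
def can_reshape(numel, last_dim):
    """True if a tensor with numel elements can be reshaped to (-1, last_dim)."""
    return numel % last_dim == 0


def factor_out_twos(n):
    """Return (k, rest) with n == 2**k * rest and rest odd."""
    k = 0
    while n > 0 and n % 2 == 0:
        n //= 2
        k += 1
    return k, n


numel = 524288
print(can_reshape(numel, 10))  # False
print(factor_out_twos(numel))  # (19, 1)

batch, label_dim = 64, 10  # from --batch=64 and the '[-1, 10]' in the error
print(batch * label_dim)   # 640 elements expected for one-hot labels
```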