Closed martianbit closed 1 year ago
have you tried mixed precision fp32?
Thanks for your response.
No, I haven't tried it yet, but it seems to say that it's not a valid option.
train_network.py: error: argument --mixed_precision: invalid choice: 'fp32' (choose from 'no', 'fp16', 'bf16')
Just choose "no" instead of fp16 for mixed_precision.
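For anyone wondering why fp32 is rejected: the error message above is what argparse prints when a value is not in an argument's choices list. Here is a minimal sketch that reproduces the behaviour. The parser is a hypothetical stand-in, not the real one from train_network.py; only the three allowed values are taken from the error message, and the default of "no" is an assumption.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-in for the script's parser. Only the choices list
    # comes from the error message in this thread; the default is assumed.
    parser = argparse.ArgumentParser(prog="train_network.py")
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help='"no" disables mixed precision, i.e. plain full-precision training',
    )
    return parser


if __name__ == "__main__":
    # "no" is accepted; passing "fp32" instead makes argparse exit with
    # an "invalid choice" error like the one quoted above.
    args = build_parser().parse_args(["--mixed_precision", "no"])
    print(args.mixed_precision)
```

So there is no separate fp32 value to pass; "no" is the way to ask for full precision.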
Hey I just stumbled on this thread with the same problem. I have a regular GTX 1660 6GB.
I also run into VRAM issues if I don't use fp16 precision or if I disable xformers, like you've described.
I see this thread was closed 5 days ago. Was there a resolution? I suppose "not planned" suggests that there wasn't. I just wanted to confirm.
In the meantime I've managed to just train some LoRAs on colab.
I got the same NaN problem on my GTX 1660 6GB.
I traced this problem and found that the NaNs first appear while loading the image latents, in BaseDataset.cache_latents in library/train_util.py.
The result of latents = vae.encode(img_tensors).latent_dist.sample().to("cpu") contains NaN.
The same problem occurs when I disable the latents cache and load the latents directly.
Tracing further, I found it is caused by ResnetBlock2D in venv/Lib/site-packages/diffusers/models/resnet.py.
The result of hidden_states = self.conv1(hidden_states) contains NaN (where self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)).
I ran it on an A5000 and the results are correct, so I think there is some error in CUDA or with the GTX 1660.
@kohya-ss
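If someone wants to repeat this kind of tracing on their own card, forward hooks are one way to do it without editing diffusers. This is only a rough sketch: find_nan_layers is a hypothetical helper (not part of kohya_ss or diffusers), and the toy model below stands in for the VAE, which you would pass in instead.

```python
import torch
from torch import nn


def find_nan_layers(model: nn.Module, inputs: torch.Tensor) -> list[str]:
    """Run one forward pass and return the names of the Conv2d layers
    whose output contains NaN, in the order they were executed."""
    nan_layers: list[str] = []
    handles = []

    def make_hook(name: str):
        def hook(module, args, output):
            # Record the layer if any element of its output is NaN.
            if torch.isnan(output).any():
                nan_layers.append(name)
        return hook

    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            handles.append(module.register_forward_hook(make_hook(name)))

    try:
        with torch.no_grad():
            model(inputs)
    finally:
        # Always remove the hooks so the model is left unchanged.
        for handle in handles:
            handle.remove()
    return nan_layers


if __name__ == "__main__":
    # Toy stand-in for the real network; same conv settings as quoted above.
    toy = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.Conv2d(8, 4, kernel_size=3, stride=1, padding=1),
    )
    print(find_nan_layers(toy, torch.randn(1, 3, 16, 16)))
```

The first name in the returned list is the earliest convolution that produced NaN, which is the layer worth comparing between the affected card and a known-good one.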
Just adding my voice to the mix to say I have this issue as well. I spoke to @bmaltais in https://github.com/bmaltais/kohya_ss/issues/722 about this. The link will eventually lead you back to the a1111 repo #4407, where a potential fix was mentioned involving setting torch.backends.cudnn.benchmark = True inside a devices.py file. No clue if this works with kohya_ss, since I don't know where to find this file or what the equivalent fix for kohya_ss would be. If anyone figures out how to use this for a local fix at least, let me know!
Sorry for the late reply. I think you can add it at the very start of the train method in train_network.py, like this:
def train(args):
    torch.backends.cudnn.benchmark = True
    session_id = random.randint(0, 2**32)
If this fix works, I will add an option to enable this. Please let me know the result!
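For reference, an opt-in switch for this could look roughly like the sketch below. It is not the implementation that ended up in the repo: the flag name --cudnn_benchmark and both helper functions are hypothetical. The only real API used is torch.backends.cudnn.benchmark.

```python
import argparse

import torch


def add_cudnn_option(parser: argparse.ArgumentParser) -> None:
    # Hypothetical flag name, chosen for illustration only.
    parser.add_argument(
        "--cudnn_benchmark",
        action="store_true",
        help="set torch.backends.cudnn.benchmark = True before training",
    )


def apply_cudnn_option(args: argparse.Namespace) -> None:
    # Leave the PyTorch default untouched unless the user opts in, since
    # the setting is only a workaround for the affected cards.
    if args.cudnn_benchmark:
        torch.backends.cudnn.benchmark = True


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_cudnn_option(parser)
    args = parser.parse_args(["--cudnn_benchmark"])
    apply_cudnn_option(args)
    print(torch.backends.cudnn.benchmark)
```

Keeping it off by default means cards that do not need the workaround keep their current behaviour.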
Yes, this works perfectly, thank you very much for your help! Have a great day!
That's good! I will add an option to enable it.
Original author of the webui PR here: it causes some noticeable slowdown on non-Turing cards. Also, holy cow, you can train LoRA on 6 GB of VRAM? Or is it a modded card?
Hey, I have a NVIDIA GeForce 1660 SUPER 6GB card, and I wanted to train LoRA models with it. This is my configuration:
accelerate launch --num_cpu_threads_per_process 4 train_network.py --network_module="networks.lora" --pretrained_model_name_or_path=/mnt/models/animefull-final-pruned.ckpt --vae=/mnt/models/animevae.pt --train_data_dir=/mnt/datasets/character --output_dir=/mnt/out --output_name=character --caption_extension=.txt --shuffle_caption --prior_loss_weight=1 --network_alpha=128 --resolution=512 --enable_bucket --min_bucket_reso=320 --max_bucket_reso=768 --train_batch_size=1 --gradient_accumulation_steps=1 --learning_rate=0.0001 --text_encoder_lr=0.00005 --max_train_epochs=20 --mixed_precision=fp16 --save_precision=fp16 --use_8bit_adam --xformers --save_every_n_epochs=1 --save_model_as=safetensors --clip_skip=2 --flip_aug --color_aug --face_crop_aug_range="2.0,4.0" --network_dim=128 --max_token_length=225 --lr_scheduler=constant
The training directory is named 3_Concept1, so 3 repetitions are used. The script does not throw any errors, but it reports loss=nan and produces corrupted U-Nets. I tried setting mixed_precision to no, but then I ran out of VRAM. I also tried disabling xformers, but again ran out of VRAM. I compiled xformers myself, using
pip install ninja && MAX_JOBS=4 pip install -v .
Also tried several other xformers versions, like 0.0.16 and the one suggested in the README. Tried both CUDA 11.6 and 11.7.
Python version: 3.10.6
PyTorch version: torch==1.12.1+cu116 torchvision==0.13.1+cu116
Any help is much appreciated! Thank you!
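A note on the folder-name convention mentioned above (3_Concept1 giving 3 repetitions): the numeric prefix before the first underscore is the repeat count. A minimal sketch of that rule follows. parse_repeats is a hypothetical helper written for illustration, and the training scripts may parse the name differently.

```python
def parse_repeats(dir_name: str) -> tuple[int, str]:
    """Split a '<repeats>_<concept>' folder name into its two parts."""
    prefix, sep, concept = dir_name.partition("_")
    if not sep:
        raise ValueError(f"no '_' separator in folder name: {dir_name!r}")
    try:
        repeats = int(prefix)
    except ValueError:
        raise ValueError(f"no numeric repeat prefix in: {dir_name!r}") from None
    return repeats, concept


if __name__ == "__main__":
    repeats, concept = parse_repeats("3_Concept1")
    print(repeats, concept)
```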