kohya-ss / sd-scripts

Apache License 2.0

Having training issue with FLUX-schnell #1676

Closed. JEFFSEVENTHSENSE closed this issue 1 month ago.

JEFFSEVENTHSENSE commented 1 month ago

CUDA_VISIBLE_DEVICES=7 accelerate launch --mixed_precision bf16 --num_cpu_threads_per_process 1 flux_train_network.py \
  --pretrained_model_name_or_path flux1-schnell.safetensors \
  --clip_l sd3/clip_l.safetensors --t5xxl sd3/t5xxl_fp16.safetensors --ae ae.safetensors \
  --save_model_as safetensors --sdpa --persistent_data_loader_workers --max_data_loader_n_workers 2 \
  --seed 42 --gradient_checkpointing --mixed_precision bf16 --save_precision bf16 \
  --network_module networks.lora_flux --network_dim 4 \
  --optimizer_type adamw8bit --learning_rate 1e-4 \
  --cache_text_encoder_outputs --highvram \
  --max_train_epochs 4 --save_every_n_epochs 1 \
  --dataset_config dataset_1024_bs2.toml \
  --output_dir /home/dluser/development/Jeff/LoRA --output_name flux-lora-jeff \
  --timestep_sampling shift --discrete_flow_shift 3.1582 --model_prediction_type raw --guidance_scale 1.0

This is the command I am running, but training fails at this point with the following error:

  File "/home/sd-scripts/train_network.py", line 1096, in train
    latents = self.encode_images_to_latents(args, accelerator, vae, batch["images"].to(vae_dtype))
  File "/home/sd-scripts/flux_train_network.py", line 326, in encode_images_to_latents
    return vae.encode(images)
  File "/home/sd-scripts/library/flux_models.py", line 347, in encode
    z = self.reg(self.encoder(x))
  File "/home/.virtualenvs/flux/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/sd-scripts/library/flux_models.py", line 210, in forward
    h = self.mid.attn_1(h)
  File "/home/.virtualenvs/flux/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/sd-scripts/library/flux_models.py", line 88, in forward
    return x + self.proj_out(self.attention(x))
  File "/home/sd-scripts/library/flux_models.py", line 83, in attention

RuntimeError: Expected (head_size % 8 == 0) && (head_size <= 128) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
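The RuntimeError is a shape check inside a PyTorch attention kernel: the per-head size must be a multiple of 8 and at most 128. A minimal sketch of that condition, only to make the message easier to read; `kernel_accepts_head_size` is a hypothetical helper, and the limits are copied verbatim from the error text rather than from PyTorch's documentation:

```python
def kernel_accepts_head_size(head_size: int) -> bool:
    """Mirror the check quoted in the RuntimeError:
    (head_size % 8 == 0) && (head_size <= 128).
    Hypothetical helper; the limits come from the error text only.
    """
    return head_size % 8 == 0 and head_size <= 128


if __name__ == "__main__":
    # Illustrative head sizes only; the thread does not say which value
    # the failing AE attention call actually used.
    for head_size in (64, 128, 100, 256, 512):
        verdict = "accepted" if kernel_accepts_head_size(head_size) else "rejected"
        print(f"head_size={head_size}: {verdict}")
```

For example, 64 and 128 pass, while 100 (not a multiple of 8) or 256 (over 128) would trip the check.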


kohya-ss commented 1 month ago

This error occurs in the AE, so it doesn't seem to be related to schnell. It could be that the bucket_reso_steps value in the dataset config is too small, or that the max_bucket_reso value is too large. Could you please share the .toml of the dataset config?
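To see how `bucket_reso_steps`, `min_bucket_reso` and `max_bucket_reso` bound the image sizes that reach the AE, here is a simplified sketch of aspect-ratio bucketing. It is a hypothetical enumeration written for illustration, not sd-scripts' own bucketing code; only the idea (sides are multiples of a step, limited by a min/max side and an area budget) is assumed to carry over:

```python
def enumerate_buckets(max_area: int, min_reso: int, max_reso: int, steps: int) -> list[tuple[int, int]]:
    """Simplified, hypothetical bucket enumeration (not sd-scripts' own code).

    Each side is a multiple of `steps` within [min_reso, max_reso], and
    width * height never exceeds `max_area`.
    """
    buckets = set()
    cap = max_reso // steps * steps          # largest allowed side on the grid
    width = -(-min_reso // steps) * steps    # round min_reso up to a multiple of steps
    while width <= cap:
        # Largest multiple of `steps` that keeps the area within budget.
        height = min(cap, (max_area // width) // steps * steps)
        if height >= min_reso:
            buckets.add((width, height))
            buckets.add((height, width))     # portrait and landscape variants
        width += steps
    return sorted(buckets)


if __name__ == "__main__":
    print(enumerate_buckets(512 * 512, 256, 1024, 64))
```

With an area budget of 512x512, sides between 256 and 1024 and a step of 64, the list runs from (256, 1024) through (512, 512) to (1024, 256). A smaller step yields many more, finer-grained buckets, and a larger maximum side admits more extreme aspect ratios.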

JEFFSEVENTHSENSE commented 1 month ago

INFO  config_util.py:570
[Dataset 0]
  batch_size: 4
  resolution: (256, 256)
  enable_bucket: False
  network_multiplier: 1.0

  [Subset 0 of Dataset 0]
    image_dir: "/home/dluser/development/Jeff/test"
    image_count: 11
    num_repeats: 1
    shuffle_caption: False
    keep_tokens: 2
    keep_tokens_separator:
    caption_separator: ,
    secondary_separator: None
    enable_wildcard: False
    caption_dropout_rate: 0.0
    caption_dropout_every_n_epoches: 0
    caption_tag_dropout_rate: 0.0
    caption_prefix: None
    caption_suffix: None
    color_aug: False
    flip_aug: False
    face_crop_aug_range: None
    random_crop: False
    token_warmup_min: 1,
    token_warmup_step: 0,
    alpha_mask: False,
    is_reg: False
    class_tokens: J3FF
    caption_extension: .txt

I am currently running the script with the following environment:

NVIDIA-SMI 510.47.03, Driver Version: 510.47.03, CUDA Version: 11.6
torch 1.13.1+cu116, torchaudio 0.13.1+cu116, torchmetrics 1.4.2, torchvision 0.14.1+cu116

These versions can't be changed for various reasons.
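Since the versions are fixed, it can help to report them from inside the same virtualenv and compare them with what the cloned branch expects. A sketch using only the standard library; `parse_version`, `has_public_sdpa` and `installed_version` are hypothetical helpers. The one firm fact used is that `torch.nn.functional.scaled_dot_product_attention` became a public API in PyTorch 2.0; whether this repository's FLUX code path requires it is an assumption to verify against the branch's own requirements:

```python
import re
from importlib import metadata
from typing import Optional, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn e.g. '1.13.1+cu116' into (1, 13, 1), dropping the local '+cu116' tag."""
    parts = []
    for piece in version.split("+", 1)[0].split("."):
        match = re.match(r"\d+", piece)  # keep leading digits, e.g. '0rc1' -> 0
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def has_public_sdpa(torch_version: str) -> bool:
    """torch.nn.functional.scaled_dot_product_attention is public from PyTorch 2.0 on."""
    return parse_version(torch_version) >= (2, 0)


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    for package in ("torch", "torchvision", "torchaudio", "torchmetrics"):
        print(f"{package}: {installed_version(package) or 'not installed'}")
    torch_version = installed_version("torch")
    if torch_version is not None and not has_public_sdpa(torch_version):
        print("torch < 2.0: torch.nn.functional.scaled_dot_product_attention is not available")
```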

kohya-ss commented 1 month ago

A resolution of (256, 256) seems to be too small for FLUX.1. Could you please try (512, 512)?
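To make the size difference concrete, a worked sketch of how an input resolution maps to a latent grid and an image-token count. It assumes the FLUX.1 AE downsamples each side by 8 and the transformer packs 2x2 latent patches into one token; neither figure is stated in this thread, and `flux_latent_and_tokens` is a hypothetical helper:

```python
# Assumptions (not stated in the thread): the FLUX.1 AE downsamples each side
# by 8, and the transformer packs 2x2 latent patches into one token.
AE_DOWNSAMPLE = 8
PATCH_SIZE = 2


def flux_latent_and_tokens(width: int, height: int) -> tuple[int, int, int]:
    """Return (latent_width, latent_height, image_tokens) for an input size."""
    multiple = AE_DOWNSAMPLE * PATCH_SIZE  # 16 under the assumptions above
    if width % multiple or height % multiple:
        raise ValueError(f"each side should be a multiple of {multiple}")
    latent_w, latent_h = width // AE_DOWNSAMPLE, height // AE_DOWNSAMPLE
    tokens = (latent_w // PATCH_SIZE) * (latent_h // PATCH_SIZE)
    return latent_w, latent_h, tokens


if __name__ == "__main__":
    for side in (256, 512, 1024):
        print(side, flux_latent_and_tokens(side, side))
```

Under these assumptions, 256x256 gives a 32x32 latent and 256 image tokens, versus a 64x64 latent and 1,024 tokens at 512x512, a quarter as many. This only illustrates the scale difference; it does not by itself explain the AE error.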

JEFFSEVENTHSENSE commented 1 month ago

Thanks for the reply. I have re-cloned the whole repository and am now only facing the issue described in https://github.com/kohya-ss/sd-scripts/issues/1679