NVlabs / Sana

SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformer
https://nvlabs.github.io/Sana

OOM when training on a 48GB GPU with batch 1, 1 process #49

Open · recoilme opened this issue 4 hours ago

recoilme commented 4 hours ago

I've been waiting for Sana for so long so I could train on a potato, but it's not working on an A40 with 48GB(

(sana) root@c88159d783a4:/workspace/sana# bash train_scripts/train.sh   configs/sana_config/1024ms/Sana_1600M_img1024.yaml   --data.data_dir="[asset/example_data]"   --data.type=SanaImgDataset   --model.multi_scale=false
2024-11-26 23:33:11 - [Sana] - INFO - Distributed environment: MULTI_GPU  Backend: nccl
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda:0

Mixed precision type: fp16

2024-11-26 23:33:11 - [Sana] - INFO - Config: 
{
    "data": {
        "data_dir": [
            "asset/example_data"
        ],
        "caption_proportion": {
            "prompt": 1
        },
        "external_caption_suffixes": [
            "",
            "_InternVL2-26B",
            "_VILA1-5-13B"
        ],
        "external_clipscore_suffixes": [
            "_InternVL2-26B_clip_score",
            "_VILA1-5-13B_clip_score",
            "_prompt_clip_score"
        ],
        "clip_thr_temperature": 0.1,
        "clip_thr": 25.0,
        "sort_dataset": false,
        "load_text_feat": false,
        "load_vae_feat": false,
        "transform": "default_train",
        "type": "SanaImgDataset",
        "image_size": 1024,
        "hq_only": false,
        "valid_num": 0,
        "data": null,
        "extra": null
    },
    "model": {
        "model": "SanaMS_1600M_P1_D20",
        "image_size": 1024,
        "mixed_precision": "fp16",
        "fp32_attention": true,
        "load_from": null,
        "resume_from": {
            "checkpoint": "latest",
            "load_ema": false,
            "resume_optimizer": true,
            "resume_lr_scheduler": true
        },
        "aspect_ratio_type": "ASPECT_RATIO_1024",
        "multi_scale": false,
        "pe_interpolation": 1.0,
        "micro_condition": false,
        "attn_type": "linear",
        "autocast_linear_attn": false,
        "ffn_type": "glumbconv",
        "mlp_acts": [
            "silu",
            "silu",
            null
        ],
        "mlp_ratio": 2.5,
        "use_pe": false,
        "qk_norm": false,
        "class_dropout_prob": 0.1,
        "linear_head_dim": 32,
        "cross_norm": false,
        "cfg_scale": 4,
        "guidance_type": "classifier-free",
        "pag_applied_layers": [
            8
        ],
        "extra": null
    },
    "vae": {
        "vae_type": "dc-ae",
        "vae_pretrained": "mit-han-lab/dc-ae-f32c32-sana-1.0",
        "scale_factor": 0.41407,
        "vae_latent_dim": 32,
        "vae_downsample_rate": 32,
        "sample_posterior": true,
        "extra": null
    },
    "text_encoder": {
        "text_encoder_name": "gemma-2-2b-it",
        "caption_channels": 2304,
        "y_norm": true,
        "y_norm_scale_factor": 0.01,
        "model_max_length": 300,
        "chi_prompt": [
            "Given a user prompt, generate an \"Enhanced prompt\" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:",
            "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.",
            "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.",
            "Here are examples of how to transform or refine prompts:",
            "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.",
            "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.",
            "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:",
            "User Prompt: "
        ],
        "extra": null
    },
    "scheduler": {
        "train_sampling_steps": 1000,
        "predict_v": true,
        "noise_schedule": "linear_flow",
        "pred_sigma": false,
        "learn_sigma": true,
        "vis_sampler": "flow_dpm-solver",
        "flow_shift": 3.0,
        "weighting_scheme": "logit_normal",
        "logit_mean": 0.0,
        "logit_std": 1.0,
        "extra": null
    },
    "train": {
        "num_workers": 1,
        "seed": 1,
        "train_batch_size": 1,
        "num_epochs": 100,
        "gradient_accumulation_steps": 1,
        "grad_checkpointing": true,
        "gradient_clip": 0.1,
        "gc_step": 1,
        "optimizer": {
            "betas": [
                0.9,
                0.999,
                0.9999
            ],
            "eps": [
                1e-30,
                1e-16
            ],
            "lr": 0.0001,
            "type": "CAMEWrapper",
            "weight_decay": 0.0
        },
        "lr_schedule": "constant",
        "lr_schedule_args": {
            "num_warmup_steps": 2000
        },
        "auto_lr": {
            "rule": "sqrt"
        },
        "ema_rate": 0.9999,
        "eval_batch_size": 16,
        "use_fsdp": false,
        "use_flash_attn": false,
        "eval_sampling_steps": 500,
        "lora_rank": 4,
        "log_interval": 1,
        "mask_type": "null",
        "mask_loss_coef": 0.0,
        "load_mask_index": false,
        "snr_loss": false,
        "real_prompt_ratio": 1.0,
        "save_image_epochs": 1,
        "save_model_epochs": 5,
        "save_model_steps": 500,
        "visualize": true,
        "null_embed_root": "output/pretrained_models/",
        "valid_prompt_embed_root": "output/tmp_embed/",
        "validation_prompts": [
            "dog",
            "portrait photo of a girl, photograph, highly detailed face, depth of field",
            "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece"
        ],
        "local_save_vis": true,
        "deterministic_validation": true,
        "online_metric": false,
        "eval_metric_step": 2000,
        "online_metric_dir": "metric_helper",
        "work_dir": "output/debug",
        "skip_step": 0,
        "loss_type": "huber",
        "huber_c": 0.001,
        "num_ddim_timesteps": 50,
        "w_max": 15.0,
        "w_min": 3.0,
        "ema_decay": 0.95,
        "debug_nan": false,
        "extra": null
    },
    "work_dir": "output/debug",
    "resume_from": "latest",
    "load_from": null,
    "debug": true,
    "caching": false,
    "report_to": "tensorboard",
    "tracker_project_name": "t2i-evit-baseline",
    "name": "tmp",
    "loss_report_name": "loss"
}
2024-11-26 23:33:11 - [Sana] - INFO - World_size: 1, seed: 1
2024-11-26 23:33:11 - [Sana] - INFO - Initializing: DDP for training
[DC-AE] Loading model from mit-han-lab/dc-ae-f32c32-sana-1.0
Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.42it/s]
2024-11-26 23:33:16 - [Sana] - INFO - vae type: dc-ae
2024-11-26 23:33:16 - [Sana] - INFO - Complex Human Instruct: Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.
Here are examples of how to transform or refine prompts:
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
User Prompt: 
2024-11-26 23:33:16 - [Sana] - INFO - v-prediction: True, noise schedule: linear_flow, flow shift: 3.0, flow weighting: logit_normal, logit-mean: 0.0, logit-std: 1.0
2024-11-26 23:33:28 - [Sana] - WARNING - use pe: False, position embed interpolation: 1.0, base size: 32
2024-11-26 23:33:28 - [Sana] - WARNING - attention type: linear; ffn type: glumbconv; autocast linear attn: false
2024-11-26 23:33:41 - [Sana] - INFO - SanaMS:SanaMS_1600M_P1_D20, Model Parameters: 1604.46M
2024-11-26 23:33:41 - [Sana] - INFO - Constructing dataset SanaImgDataset...
2024-11-26 23:33:41 - [Sana] - INFO - Dataset is repeat 2000 times for toy dataset
2024-11-26 23:33:41 - [Sana] - INFO - Dataset samples: 4000
2024-11-26 23:33:41 - [Sana] - INFO - Loading external caption json from: original_filename['', '_InternVL2-26B', '_VILA1-5-13B'].json
2024-11-26 23:33:41 - [Sana] - INFO - Loading external clipscore json from: original_filename['_InternVL2-26B_clip_score', '_VILA1-5-13B_clip_score', '_prompt_clip_score'].json
2024-11-26 23:33:41 - [Sana] - INFO - external caption clipscore threshold: 25.0, temperature: 0.1
2024-11-26 23:33:41 - [Sana] - INFO - T5 max token length: 300
2024-11-26 23:33:41 - [Sana] - INFO - Dataset SanaImgDataset constructed: time: 0.00 s, length (use/ori): 4000/4000
2024-11-26 23:33:41 - [Sana] - INFO - Automatically adapt lr to 0.00001 (using sqrt scaling rule).
2024-11-26 23:33:41 - [Sana] - INFO - CAMEWrapper Optimizer: total 316 param groups, 316 are learnable, 0 are fix. Lr group: 316 params with lr 0.00001; Weight decay group: 316 params with weight decay 0.0.
2024-11-26 23:33:41 - [Sana] - INFO - Lr schedule: constant, num_warmup_steps:2000.
2024-11-26 23:33:41 - [Sana] - WARNING - Basic Setting: lr: 0.00001, bs: 1, gc: True, gc_accum_step: 1, qk norm: False, fp32 attn: True, attn type: linear, ffn type: glumbconv, text encoder: gemma-2-2b-it, captions: {'prompt': 1}, precision: fp16
2024-11-26 23:33:58 - [Sana] - INFO - Epoch: 1 | Global Step: 1 | Local Step: 1 // 4000, total_eta: 71 days, 3:34:13, epoch_eta:17:04:17, time: all:15.368, model:14.333, data:0.201, lm:0.304, vae:0.529, lr:3.125e-09, Cap: VILA1-5-13B, s:(32, 32), loss:4.3361, grad_norm:61.1122
2024-11-26 23:33:58 - [Sana] - INFO - Running validation... 
[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 441, in model_fn
[rank0]:     noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 386, in noise_pred_fn
[rank0]:     output = model(x, t_input, cond, **model_kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_multi_scale.py", line 348, in forward_with_dpmsolver
[rank0]:     model_out = self.forward(x, timestep, y, data_info=data_info, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/accelerate/utils/operations.py", line 823, in forward
[rank0]:     return model_forward(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/accelerate/utils/operations.py", line 811, in __call__
[rank0]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_multi_scale.py", line 327, in forward
[rank0]:     x = auto_grad_checkpoint(
[rank0]:   File "/workspace/sana/diffusion/model/utils.py", line 72, in auto_grad_checkpoint
[rank0]:     return checkpoint(module, *args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
[rank0]:     return CheckpointFunction.apply(function, preserve, *args)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 255, in forward
[rank0]:     outputs = run_function(*args)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_multi_scale.py", line 162, in forward
[rank0]:     x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_blocks.py", line 160, in forward
[rank0]:     qkv = self.qkv(x).reshape(B, N, 3, C)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
[rank0]:     return F.linear(input, self.weight, self.bias)
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 28.00 MiB. GPU 0 has a total capacity of 44.34 GiB of which 20.81 MiB is free. Process 265335 has 44.31 GiB memory in use. Of the allocated memory 42.84 GiB is allocated by PyTorch, and 837.16 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

[rank0]: During handling of the above exception, another exception occurred:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/workspace/sana/train_scripts/train.py", line 974, in <module>
[rank0]:     main()
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/pyrallis/argparsing.py", line 158, in wrapper_inner
[rank0]:     response = fn(cfg, *args, **kwargs)
[rank0]:   File "/workspace/sana/train_scripts/train.py", line 959, in main
[rank0]:     train(
[rank0]:   File "/workspace/sana/train_scripts/train.py", line 479, in train
[rank0]:     log_validation(
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/workspace/sana/train_scripts/train.py", line 154, in log_validation
[rank0]:     image_logs += run_sampling(init_z=None, label_suffix="", vae=vae, sampler=vis_sampler)
[rank0]:   File "/workspace/sana/train_scripts/train.py", line 127, in run_sampling
[rank0]:     denoised = dpm_solver.sample(
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 1529, in sample
[rank0]:     model_prev_list = [self.model_fn(x, t)]
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 689, in model_fn
[rank0]:     return self.data_prediction_fn(x, t)
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 677, in data_prediction_fn
[rank0]:     noise = self.noise_prediction_fn(x, t)
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 671, in noise_prediction_fn
[rank0]:     return self.model(x, t)
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 616, in <lambda>
[rank0]:     self.model = lambda x, t: model_fn(x, t.expand(x.shape[0]))
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 443, in model_fn
[rank0]:     noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in)[0].chunk(2)
[rank0]:   File "/workspace/sana/diffusion/model/dpm_solver.py", line 386, in noise_pred_fn
[rank0]:     output = model(x, t_input, cond, **model_kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_multi_scale.py", line 348, in forward_with_dpmsolver
[rank0]:     model_out = self.forward(x, timestep, y, data_info=data_info, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/accelerate/utils/operations.py", line 823, in forward
[rank0]:     return model_forward(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/accelerate/utils/operations.py", line 811, in __call__
[rank0]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/amp/autocast_mode.py", line 43, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_multi_scale.py", line 327, in forward
[rank0]:     x = auto_grad_checkpoint(
[rank0]:   File "/workspace/sana/diffusion/model/utils.py", line 72, in auto_grad_checkpoint
[rank0]:     return checkpoint(module, *args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/_compile.py", line 31, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 600, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 481, in checkpoint
[rank0]:     return CheckpointFunction.apply(function, preserve, *args)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/autograd/function.py", line 574, in apply
[rank0]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 255, in forward
[rank0]:     outputs = run_function(*args)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_multi_scale.py", line 162, in forward
[rank0]:     x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/workspace/sana/diffusion/model/nets/sana_blocks.py", line 160, in forward
[rank0]:     qkv = self.qkv(x).reshape(B, N, 3, C)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 117, in forward
[rank0]:     return F.linear(input, self.weight, self.bias)
[rank0]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 28.00 MiB. GPU 0 has a total capacity of 44.34 GiB of which 20.81 MiB is free. Process 265335 has 44.31 GiB memory in use. Of the allocated memory 42.87 GiB is allocated by PyTorch, and 800.33 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
E1126 15:34:00.242000 124914645387072 torch/distributed/elastic/multiprocessing/api.py:833] failed (exitcode: 1) local_rank: 0 (pid: 18768) of binary: /root/miniconda3/envs/sana/bin/python
Traceback (most recent call last):
  File "/root/miniconda3/envs/sana/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 348, in wrapper
    return f(*args, **kwargs)
  File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/distributed/run.py", line 901, in main
    run(args)
  File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 133, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/root/miniconda3/envs/sana/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train_scripts/train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2024-11-26_15:34:00
  host      : c88159d783a4
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 18768)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
(sana) root@c88159d783a4:/workspace/sana# 
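As the traceback itself suggests, the allocator hint can be tried directly on the launch command. Note this only mitigates fragmentation; it will not recover a genuine capacity shortfall, so it is a hedged suggestion rather than a confirmed fix for this issue:

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True bash train_scripts/train.sh configs/sana_config/1024ms/Sana_1600M_img1024.yaml --data.data_dir="[asset/example_data]" --data.type=SanaImgDataset --model.multi_scale=false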
lawrence-cj commented 3 hours ago

It seems my training stays under 48GB:

bash train_scripts/train.sh configs/sana_config/1024ms/Sana_1600M_img1024.yaml --data.data_dir="[asset/example_data]" --data.type=SanaImgDataset --model.multi_scale=false --data.load_vae_feat=false --train.train_batch_size=1

Refer to: [screenshot of GPU memory usage]

Actually, if you switch the optimizer type in the config file to AdamW, GPU memory usage will be lower. We will release a newer CAME in the future, which will use even less memory than AdamW.

train:
  optimizer:
    lr: 1.0e-4
    type: AdamW
    weight_decay: 0.01
    eps: 1.0e-8
    betas: [0.9, 0.999]

Refer to: [screenshot of GPU memory usage]
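For a rough sense of why the optimizer choice matters at this scale, here is back-of-the-envelope arithmetic for AdamW's fp32 state on a 1.6B-parameter model. This is generic math, not a measurement of Sana:

# AdamW keeps two fp32 tensors per parameter (exp_avg and exp_avg_sq),
# so optimizer state alone costs roughly 2 * 4 bytes * num_params.
num_params = 1.6e9                 # SanaMS_1600M_P1_D20, per the log above
state_bytes = num_params * 2 * 4   # two fp32 states per parameter
print(f"~{state_bytes / 2**30:.1f} GiB")  # ~11.9 GiB, before weights/grads/activations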
recoilme commented 2 hours ago

But how can that be? You have a 32x VAE vs. the 8x one in SDXL and a smaller model, yet you need more than 2.5x the memory of SDXL to train at 1024?

lawrence-cj commented 2 hours ago

The model you are using is 1.6B, and both the VAE and the text encoder are extracting features online.
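To make that concrete: keeping the DiT plus the online encoders resident means the weights alone occupy several GiB before any gradients, optimizer state, or activations. Rough numbers only; the Gemma parameter count is approximate:

# Resident weights in fp16, rough arithmetic:
dit = 1.6e9    # SanaMS_1600M, from the log above
gemma = 2.6e9  # gemma-2-2b-it text encoder, approximate parameter count
print(f"~{(dit + gemma) * 2 / 2**30:.1f} GiB of fp16 weights")  # ~7.8 GiB, DC-AE on top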

FurkanGozukara commented 2 hours ago

It is all about the training scripts.

Currently, using Kohya, we are able to fully fine-tune the 12-billion-parameter FLUX dev in 16-bit precision even on 6GB GPUs via block swapping :)
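For reference, the core idea behind block swapping looks roughly like this. This is a minimal sketch only, not Kohya's actual implementation, which adds prefetching, gradient handling, and finer granularity:

import torch

def forward_with_block_swap(blocks, x, device="cuda"):
    # Keep the full block stack in host RAM; move one block at a time
    # onto the GPU, run it, then evict it. Trades speed for peak VRAM.
    for block in blocks:
        block.to(device)
        x = block(x)
        block.to("cpu")
    return x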

lawrence-cj commented 1 hour ago

Cool. Hh, I can't imagine the speed. lol : )

recoilme commented 1 hour ago

I really hope you will add at least some basic optimizations in the future. It's dead for full fine-tuning for now; the A100 is very expensive. Thanks for the reply, and great model!

lawrence-cj commented 1 hour ago

What do you mean it's dead?

FurkanGozukara commented 1 hour ago

> Cool. Hh, I can't imagine the speed. lol : )

With the latest improvements, speeds are really decent:

RTX 3090: ~7 seconds per sample image (batch size 1); RTX 4090: ~5 seconds; RTX A6000: ~6 seconds.

recoilme commented 1 hour ago

> dead

It's dead for full fine-tuning for the GPU-poor. We rent GPUs to train, and it's very expensive; needing 48GB+ to train at batch size 1 is a showstopper for most of us. We need latent/TE caching, multi-aspect-ratio training, and probably a slow optimizer like Adafactor to train fine details like eyes with a low LR.
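The config above already exposes load_text_feat and load_vae_feat flags, so the training loop appears to anticipate cached features; a one-time offline pass along these lines is the usual pattern. This is a hedged sketch with hypothetical encode calls; the real DC-AE and Gemma interfaces in Sana may differ:

import os
import torch

@torch.no_grad()
def cache_features(dataset, vae, text_encoder, tokenizer, out_dir):
    # One-time pass: store VAE latents and text embeddings on disk so
    # neither encoder needs to stay in VRAM during DiT training.
    os.makedirs(out_dir, exist_ok=True)
    for i, (image, prompt) in enumerate(dataset):
        latent = vae.encode(image.unsqueeze(0).to("cuda"))  # hypothetical encode API
        tokens = tokenizer(prompt, return_tensors="pt",
                           truncation=True, max_length=300).to("cuda")
        text_emb = text_encoder(**tokens).last_hidden_state
        torch.save({"latent": latent.cpu(), "text_emb": text_emb.cpu()},
                   os.path.join(out_dir, f"{i:08d}.pt"))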