VectorSpaceLab / OmniGen

OmniGen: Unified Image Generation. https://arxiv.org/pdf/2409.11340
MIT License

Please improve the fine-tuning script! #47

Open win10ogod opened 3 weeks ago

win10ogod commented 3 weeks ago

Please improve the fine-tuning script! After I solved this problem:

Traceback (most recent call last):
  File "E:\OmniGen\train.py", line 371, in <module>
    main(args)
  File "E:\OmniGen\train.py", line 239, in main
    dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
  File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\distributed\c10d_logger.py", line 79, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\distributed\distributed_c10d.py", line 2277, in all_reduce
    group = _get_default_group()
            ^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\distributed\distributed_c10d.py", line 1025, in _get_default_group
    raise ValueError(
ValueError: Default process group has not been initialized, please make sure to call init_process_group.

A new problem occurred:

[2024-10-28 13:54:25] (step=0000100) Train Loss: 0.4717, Train Steps/Sec: 0.57, Epoch: 0.44052863436123346, LR: 2e-05
Traceback (most recent call last):
  File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\peft\peft_model.py", line 529, in __getattr__
    return super().__getattr__(name)  # defer to nn.Module's logic
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1729, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'PeftModel' object has no attribute 'module'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\peft\tuners\lora\model.py", line 273, in __getattr__
    return super().__getattr__(name)  # defer to nn.Module's logic
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1729, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'LoraModel' object has no attribute 'module'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "E:\OmniGen\train.py", line 373, in <module>
    main(args)
  File "E:\OmniGen\train.py", line 266, in main
    model.module.save_pretrained(checkpoint_path)
    ^^^^^^^^^^^^
  File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\peft\peft_model.py", line 531, in __getattr__
    return getattr(self.base_model, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\peft\tuners\lora\model.py", line 275, in __getattr__
    return getattr(self.model, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\jmes1\AppData\Local\Programs\Python\Python311\Lib\site-packages\torch\nn\modules\module.py", line 1729, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'OmniGen' object has no attribute 'module'. Did you mean: 'modules'?
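
For reference, one way to make the checkpoint save independent of whether Accelerate has wrapped the model in DistributedDataParallel. This is only a minimal sketch: save_lora_checkpoint is a hypothetical helper, and accelerator, model, and checkpoint_path are the objects from train.py.

def save_lora_checkpoint(accelerator, model, checkpoint_path):
    # Let Accelerate strip any DDP wrapper instead of accessing .module directly,
    # which raises the AttributeError above when the PeftModel is not wrapped (single-process run).
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(checkpoint_path)
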
win10ogod commented 3 weeks ago

In addition, I have improved compatibility of distributed training with a single card, so you are no longer limited to using multiple cards for training! train.zip
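
For reference, a minimal sketch of the kind of guard that makes the loss reduction safe on a single card. The values below are illustrative stand-ins for the training counters, and this is not necessarily the exact change in the attached train.zip.

import torch
import torch.distributed as dist

running_loss, log_steps = 0.4717, 1  # illustrative values standing in for the training counters
device = "cuda" if torch.cuda.is_available() else "cpu"

# Only run the collective when a process group actually exists;
# on a single card the local average loss is already the global one.
avg_loss = torch.tensor(running_loss / log_steps, device=device)
if dist.is_available() and dist.is_initialized():
    dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
    avg_loss = avg_loss.item() / dist.get_world_size()
else:
    avg_loss = avg_loss.item()
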

staoxiao commented 3 weeks ago

We haven't encountered such issues on our machines; this could be due to differences in library versions or the runtime environment. Anyway, I will check the fine-tuning script again.

win10ogod commented 3 weeks ago

> We haven't encountered such issues on our machines; this could be due to differences in library versions or the runtime environment. Anyway, I will check the fine-tuning script again.

Please check the LoRA fine-tuning in particular.

staoxiao commented 3 weeks ago

@win10ogod , I am unable to reproduce your issue. You can refer to https://github.com/VectorSpaceLab/OmniGen/blob/main/requirements.txt, and try to install our environment.
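
If it helps to compare environments, here is a quick sketch for printing the locally installed versions of the relevant packages; the package list is an assumption based on requirements.txt.

from importlib.metadata import version, PackageNotFoundError

# Print installed versions to compare against the repo's requirements.txt
for pkg in ["torch", "transformers", "accelerate", "peft", "diffusers", "safetensors"]:
    try:
        print(pkg, version(pkg))
    except PackageNotFoundError:
        print(pkg, "not installed")
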

werruww commented 3 weeks ago

!accelerate launch --num_processes=1 /content/OmniGen/train.py \
    --model_name_or_path Shitao/OmniGen-v1 \
    --batch_size_per_device 1 \
    --condition_dropout_prob 0.01 \
    --lr 1e-3 \
    --use_lora \
    --lora_rank 8 \
    --json_file /content/OmniGen/toy_data/toy_subject_data.jsonl \
    --image_path /content/OmniGen/toy_data/images \
    --max_input_length_limit 18000 \
    --keep_raw_resolution \
    --max_image_size 1024 \
    --gradient_accumulation_steps 1 \
    --ckpt_every 10 \
    --epochs 2 \
    --log_every 1 \
    --results_dir /content/results/toy_finetune_lora

The following values were not passed to `accelerate launch` and had defaults used instead:
    --num_machines was set to a value of 1
    --mixed_precision was set to a value of 'no'
    --dynamo_backend was set to a value of 'no'
To avoid this warning pass in values for each of the problematic parameters or run accelerate config.
2024-10-29 18:21:15.635201: I tensorflow/core/tpu/tpu_api_dlsym_initializer.cc:95] Opening library: /usr/local/lib/python3.10/dist-packages/tensorflow/python/platform/../../libtensorflow_cc.so.2
2024-10-29 18:21:15.635390: I tensorflow/core/tpu/tpu_api_dlsym_initializer.cc:119] Libtpu path is: libtpu.so
2024-10-29 18:21:15.676883: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations. To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
[2024-10-29 18:21:18] Experiment directory created at /content/results/toy_finetune_lora
Fetching 10 files: 100% 10/10 [00:00<00:00, 122282.92it/s]
[2024-10-29 18:21:18] Downloaded model to /root/.cache/huggingface/hub/models--Shitao--OmniGen-v1/snapshots/4636aebc43d6a56b512527e2bae26cdbc69337c2
Traceback (most recent call last):
  File "/content/OmniGen/train.py", line 373, in <module>
    main(args)
  File "/content/OmniGen/train.py", line 71, in main
    model = OmniGen.from_pretrained(args.model_name_or_path)
  File "/content/OmniGen/OmniGen/model.py", line 197, in from_pretrained
    model = cls(config)
  File "/content/OmniGen/OmniGen/model.py", line 186, in __init__
    self.llm = Phi3Transformer(config=transformer_config)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/phi3/modeling_phi3.py", line 951, in __init__
    super().__init__(config)
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py", line 1404, in __init__
    config = self._autoset_attn_implementation(
  File "/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py", line 1579, in _autoset_attn_implementation
    elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
  File "/usr/local/lib/python3.10/dist-packages/transformers/utils/import_utils.py", line 634, in is_torch_xla_available
    import torch_xla
  File "/usr/local/lib/python3.10/dist-packages/torch_xla/__init__.py", line 20, in <module>
    import _XLAC
ImportError: /usr/local/lib/python3.10/dist-packages/_XLAC.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c105Error4whatEv

Traceback (most recent call last):
  File "/usr/local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/accelerate_cli.py", line 47, in main
    args.func(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 1023, in launch_command
    simple_launcher(args)
  File "/usr/local/lib/python3.10/dist-packages/accelerate/commands/launch.py", line 643, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/usr/bin/python3', '/content/OmniGen/train.py', '--model_name_or_path', 'Shitao/OmniGen-v1', '--batch_size_per_device', '1', '--condition_dropout_prob', '0.01', '--lr', '1e-3', '--use_lora', '--lora_rank', '8', '--json_file', '/content/OmniGen/toy_data/toy_subject_data.jsonl', '--image_path', '/content/OmniGen/toy_data/images', '--max_input_length_limit', '18000', '--keep_raw_resolution', '--max_image_size', '1024', '--gradient_accumulation_steps', '1', '--ckpt_every', '10', '--epochs', '2', '--log_every', '1', '--results_dir', '/content/results/toy_finetune_lora']' returned non-zer

werruww commented 3 weeks ago

Colab TPU.

yukiarimo commented 3 weeks ago

Don’t use the TPU; switch to a GPU runtime.
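
A quick sanity check before launching, assuming a Colab GPU runtime; the broken torch_xla import above comes from the TPU image, so switching the runtime type to GPU (or uninstalling torch_xla) typically avoids it.

import torch

# Verify a CUDA device is visible before running accelerate launch
print(torch.cuda.is_available(), torch.cuda.device_count())
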

win10ogod commented 2 weeks ago

> @win10ogod , I am unable to reproduce your issue. You can refer to https://github.com/VectorSpaceLab/OmniGen/blob/main/requirements.txt, and try to install our environment.

On Windows, I am still unable to run single-card training.
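
If a process group is genuinely needed on Windows, note that PyTorch only ships the gloo backend there (NCCL is Linux-only). A minimal single-process sketch; the file-based rendezvous path is illustrative.

import torch.distributed as dist

# gloo is the distributed backend available on Windows; the init file path is illustrative.
dist.init_process_group(
    backend="gloo",
    init_method="file:///E:/OmniGen/dist_init_file",
    rank=0,
    world_size=1,
)
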

Geruldas commented 2 weeks ago

> @win10ogod , I am unable to reproduce your issue. You can refer to https://github.com/VectorSpaceLab/OmniGen/blob/main/requirements.txt, and try to install our environment.
>
> On Windows, I am still unable to run single-card training.

Hey, resolved it by unwrapping the class for multi-GPU:

import json
from time import time
import argparse
import logging
import os
from pathlib import Path
import math
from torch.distributed import init_process_group, destroy_process_group

import numpy as np
from PIL import Image
from copy import deepcopy

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms

from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed
from diffusers.optimization import get_scheduler
from accelerate.utils import DistributedType
from peft import LoraConfig, set_peft_model_state_dict, PeftModel, get_peft_model
from peft.utils import get_peft_model_state_dict
from huggingface_hub import snapshot_download
from safetensors.torch import save_file

from diffusers.models import AutoencoderKL

from OmniGen import OmniGen, OmniGenProcessor
from OmniGen.train_helper import DatasetFromJson, TrainDataCollator
from OmniGen.train_helper import training_losses
from OmniGen.utils import (
    create_logger,
    update_ema,
    requires_grad,
    center_crop_arr,
    crop_arr,
    vae_encode,
    vae_encode_list
)

def main(args):

# Setup accelerator
accelerator = Accelerator(
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    mixed_precision=args.mixed_precision,
    log_with=args.report_to,
    project_dir=args.results_dir,
)
device = accelerator.device
accelerator.init_trackers("tensorboard_log", config=args.__dict__)

# Setup an experiment folder:
checkpoint_dir = os.path.join(args.results_dir, "checkpoints")  # Stores saved model checkpoints
logger = create_logger(args.results_dir)
if accelerator.is_main_process:
    os.makedirs(checkpoint_dir, exist_ok=True)
    logger.info(f"Experiment directory created at {args.results_dir}")
    json.dump(args.__dict__, open(os.path.join(args.results_dir, 'train_args.json'), 'w'), indent=4)

# Create model:
if not os.path.exists(args.model_name_or_path):
    cache_folder = os.getenv('HF_HUB_CACHE', "./hf_cache")
    args.model_name_or_path = snapshot_download(
        repo_id=args.model_name_or_path,
        cache_dir=cache_folder,
        ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5']
    )
    logger.info(f"Downloaded model to {args.model_name_or_path}")

model = OmniGen.from_pretrained(args.model_name_or_path)
model.llm.config.use_cache = False
model.llm.gradient_checkpointing_enable()
model = model.to(device)

# Setup VAE
if args.vae_path is None:
    logger.info(f"Model path: {args.model_name_or_path}")
    vae_path = os.path.join(args.model_name_or_path, "vae")
    if os.path.exists(vae_path):
        vae = AutoencoderKL.from_pretrained(vae_path).to(device)
    else:
        logger.info("No VAE found in model, downloading stabilityai/sdxl-vae from HF")
        logger.info("If you have VAE in local folder, please specify the path with --vae_path")
        vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device)
else:
    vae = AutoencoderKL.from_pretrained(args.vae_path).to(device)

# Set weight dtype based on mixed precision
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
    weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
    weight_dtype = torch.bfloat16
vae.to(dtype=torch.float32)
model.to(weight_dtype)

# Initialize processor
processor = OmniGenProcessor.from_pretrained(args.model_name_or_path)

# Freeze VAE parameters
requires_grad(vae, False)

# Setup LoRA if enabled
if args.use_lora:
    if accelerator.distributed_type == DistributedType.FSDP:
        raise NotImplementedError("FSDP does not support LoRA")
    requires_grad(model, False)
    transformer_lora_config = LoraConfig(
        r=args.lora_rank,
        lora_alpha=args.lora_rank,
        init_lora_weights="gaussian",
        target_modules=["qkv_proj", "o_proj"],
    )
    model.llm.enable_input_require_grads()
    model = get_peft_model(model, transformer_lora_config)
    model.to(weight_dtype)
    transformer_lora_parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
    for name, param in model.named_parameters():
        print(f"{name}: requires_grad={param.requires_grad}")
    optimizer = torch.optim.AdamW(transformer_lora_parameters, lr=args.lr, weight_decay=args.adam_weight_decay)
else:
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.adam_weight_decay)

# Setup EMA if enabled
ema = None
if args.use_ema:
    ema = deepcopy(model).to(device)  # Create an EMA of the model for use after training
    requires_grad(ema, False)

# Setup data transformations
crop_func = center_crop_arr if not args.keep_raw_resolution else crop_arr
image_transform = transforms.Compose([
    transforms.Lambda(lambda pil_image: crop_func(pil_image, args.max_image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
])

# Initialize dataset and dataloader
dataset = DatasetFromJson(
    json_file=args.json_file,
    image_path=args.image_path,
    processer=processor,
    image_transform=image_transform,
    max_input_length_limit=args.max_input_length_limit,
    condition_dropout_prob=args.condition_dropout_prob,
    keep_raw_resolution=args.keep_raw_resolution
)
collate_fn = TrainDataCollator(
    pad_token_id=processor.text_tokenizer.eos_token_id,
    hidden_size=model.llm.config.hidden_size,
    keep_raw_resolution=args.keep_raw_resolution
)

dataloader = DataLoader(
    dataset,
    collate_fn=collate_fn,
    batch_size=args.batch_size_per_device,
    shuffle=True,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=True,
)

if accelerator.is_main_process:
    logger.info(f"Dataset contains {len(dataset):,} samples")

# Calculate training steps
num_update_steps_per_epoch = math.ceil(len(dataloader) / args.gradient_accumulation_steps)
max_train_steps = args.epochs * num_update_steps_per_epoch

# Setup learning rate scheduler
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=max_train_steps * args.gradient_accumulation_steps,
)

# Prepare models, optimizer, dataloader, and scheduler with accelerator
if ema is not None:
    model, ema = accelerator.prepare(model, ema)
else:
    model = accelerator.prepare(model)

optimizer, dataloader, lr_scheduler = accelerator.prepare(optimizer, dataloader, lr_scheduler)

# Initialize EMA if used
if ema is not None:
    update_ema(ema, model, decay=0)  # Ensure EMA is initialized with synced weights
    ema.eval()  # EMA model should always be in eval mode

# Set model to training mode
model.train()  # Important! This enables embedding dropout for classifier-free guidance

# Variables for monitoring/logging purposes
train_steps, log_steps = 0, 0
running_loss = 0
start_time = time()

if accelerator.is_main_process:
    logger.info(f"Training for {args.epochs} epochs...")

for epoch in range(args.epochs):
    if accelerator.is_main_process:
        logger.info(f"Beginning epoch {epoch + 1}/{args.epochs}...")

    for batch in dataloader:
        with accelerator.accumulate(model):
            # Encode images with VAE
            with torch.no_grad():
                output_images = batch['output_images']
                input_pixel_values = batch.get('input_pixel_values', None)

                if isinstance(output_images, list):
                    output_images = vae_encode_list(vae, output_images, weight_dtype)
                    if input_pixel_values is not None:
                        input_pixel_values = vae_encode_list(vae, input_pixel_values, weight_dtype)
                else:
                    output_images = vae_encode(vae, output_images, weight_dtype)
                    if input_pixel_values is not None:
                        input_pixel_values = vae_encode(vae, input_pixel_values, weight_dtype)

            # Prepare model inputs
            model_kwargs = {
                'input_ids': batch['input_ids'],
                'input_img_latents': input_pixel_values,
                'input_image_sizes': batch['input_image_sizes'],
                'attention_mask': batch['attention_mask'],
                'position_ids': batch['position_ids'],
                'padding_latent': batch['padding_images'],
                'past_key_values': None,
                'return_past_key_values': False
            }

            # Compute training losses
            loss_dict = training_losses(model, output_images, model_kwargs)
            loss = loss_dict["loss"].mean()

            # Accumulate loss
            running_loss += loss.item()

            # Backpropagation
            accelerator.backward(loss)

            # Gradient clipping
            if args.max_grad_norm is not None and accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            # Optimizer step
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # Update counters
            log_steps += 1
            train_steps += 1

            # Logging
            accelerator.log({"training_loss": loss.item()}, step=train_steps)
            if train_steps % args.gradient_accumulation_steps == 0:
                if accelerator.sync_gradients and ema is not None:
                    update_ema(ema, model)

            # Periodic logging
            if train_steps % (args.log_every * args.gradient_accumulation_steps) == 0 and train_steps > 0:
                torch.cuda.synchronize()
                end_time = time()
                steps_per_sec = log_steps / args.gradient_accumulation_steps / (end_time - start_time)

                # Reduce loss across all processes
                avg_loss = torch.tensor(running_loss / log_steps, device=device)
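                # The collective below runs only when more than one process exists;
                # a single-card run has no default process group, which is what raised
                # the ValueError earlier in this thread.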
                if accelerator.num_processes > 1:
                    dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                    avg_loss = avg_loss.item() / accelerator.num_processes
                else:
                    avg_loss = avg_loss.item()

                if accelerator.is_main_process:
                    current_lr = optimizer.param_groups[0]["lr"]
                    current_epoch = train_steps / len(dataloader)
                    logger.info(
                        f"(step={int(train_steps / args.gradient_accumulation_steps):07d}) "
                        f"Train Loss: {avg_loss:.4f}, "
                        f"Train Steps/Sec: {steps_per_sec:.2f}, "
                        f"Epoch: {current_epoch:.2f}, "
                        f"LR: {current_lr}"
                    )

                # Reset monitoring variables
                running_loss = 0
                log_steps = 0
                start_time = time()

        # Periodic checkpoint saving
        if train_steps % (args.ckpt_every * args.gradient_accumulation_steps) == 0 and train_steps > 0:
            if accelerator.distributed_type == DistributedType.FSDP:
                state_dict = accelerator.get_state_dict(model)
                ema_state_dict = accelerator.get_state_dict(ema) if ema is not None else None
            else:
                if not args.use_lora:
                    state_dict = accelerator.unwrap_model(model).state_dict()
                    ema_state_dict = accelerator.unwrap_model(ema).state_dict() if ema is not None else None
                else:
                    # For LoRA, use the provided save_pretrained method
                    state_dict = None
                    ema_state_dict = None

            if accelerator.is_main_process:
                checkpoint_step = int(train_steps / args.gradient_accumulation_steps)
                checkpoint_path = os.path.join(checkpoint_dir, f"{checkpoint_step:07d}")
                os.makedirs(checkpoint_path, exist_ok=True)

                if args.use_lora:
                    # Save LoRA-specific checkpoints
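                    # save_pretrained is called on the PeftModel itself rather than model.module,
                    # which avoids the AttributeError seen above when no DDP wrapper exists.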
                    model.save_pretrained(checkpoint_path)
                else:
                    # Save main model checkpoint
                    torch.save(state_dict, os.path.join(checkpoint_path, "model.pt"))
                    processor.text_tokenizer.save_pretrained(checkpoint_path)
                    model.llm.config.save_pretrained(checkpoint_path)

                    # Save EMA checkpoint if applicable
                    if ema_state_dict is not None:
                        ema_checkpoint_path = f"{checkpoint_path}_ema"
                        os.makedirs(ema_checkpoint_path, exist_ok=True)
                        torch.save(ema_state_dict, os.path.join(ema_checkpoint_path, "model.pt"))
                        processor.text_tokenizer.save_pretrained(ema_checkpoint_path)
                        model.llm.config.save_pretrained(ema_checkpoint_path)

                logger.info(f"Saved checkpoint to {checkpoint_path}")

            if accelerator.num_processes > 1:
                dist.barrier()

# Finalize training
accelerator.end_training()
model.eval()

if accelerator.is_main_process:
    logger.info("Training completed successfully!")

if __name__ == "__main__":

parser = argparse.ArgumentParser(description="OmniGen Training Script")

# Directory and Path Arguments
parser.add_argument("--results_dir", type=str, default="results", help="Directory to store results and checkpoints.")
parser.add_argument("--model_name_or_path", type=str, default="OmniGen", help="Path to the pretrained model or model identifier from huggingface.co.")
parser.add_argument("--json_file", type=str, required=True, help="Path to the JSON file containing dataset annotations.")
parser.add_argument("--image_path", type=str, default=None, help="Path to the directory containing images.")

# Training Hyperparameters
parser.add_argument("--epochs", type=int, default=1400, help="Number of training epochs.")
parser.add_argument("--batch_size_per_device", type=int, default=1, help="Batch size per device (GPU/CPU).")
parser.add_argument("--vae_path", type=str, default=None, help="Path to the VAE model.")
parser.add_argument("--num_workers", type=int, default=4, help="Number of worker threads for data loading.")
parser.add_argument("--log_every", type=int, default=100, help="Logging frequency (in steps).")
parser.add_argument("--ckpt_every", type=int, default=20000, help="Checkpoint saving frequency (in steps).")
parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Maximum gradient norm for clipping.")
parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate.")
parser.add_argument("--max_input_length_limit", type=int, default=1024, help="Maximum input sequence length.")
parser.add_argument("--condition_dropout_prob", type=float, default=0.1, help="Probability for condition dropout.")
parser.add_argument("--adam_weight_decay", type=float, default=0.0, help="Weight decay for Adam optimizer.")

# Resolution and Image Size
parser.add_argument("--keep_raw_resolution", action="store_true", help="Keep raw image resolutions.")
parser.add_argument("--max_image_size", type=int, default=1344, help="Maximum image size (must be divisible by 16).")

# LoRA Configuration
parser.add_argument("--use_lora", action="store_true", help="Enable LoRA for model training.")
parser.add_argument("--lora_rank", type=int, default=8, help="Rank for LoRA.")

# EMA Configuration
parser.add_argument("--use_ema", action="store_true", help="Enable Exponential Moving Average for model parameters.")

# Scheduler and Optimization
parser.add_argument(
    "--lr_scheduler",
    type=str,
    default="constant",
    choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    help="Learning rate scheduler type."
)
parser.add_argument("--lr_warmup_steps", type=int, default=1000, help="Number of warmup steps for the scheduler.")

# Reporting and Logging
parser.add_argument(
    "--report_to",
    type=str,
    default="tensorboard",
    choices=["tensorboard", "wandb", "comet_ml", "all"],
    help="Integration to report logs to."
)

# Mixed Precision
parser.add_argument(
    "--mixed_precision",
    type=str,
    default="bf16",
    choices=["no", "fp16", "bf16"],
    help=(
        "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). "
        "Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config."
    ),
)

# Gradient Accumulation
parser.add_argument(
    "--gradient_accumulation_steps",
    type=int,
    default=1,
    help="Number of update steps to accumulate before performing a backward/update pass.",
)

args = parser.parse_args()

# Validate arguments
assert args.max_image_size % 16 == 0, "Image size must be divisible by 16."

main(args)

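For completeness, a rough sketch of loading a LoRA checkpoint written by the script above back onto the base model; the checkpoint path is illustrative, and standard PEFT loading is assumed.

from OmniGen import OmniGen
from peft import PeftModel

# Load the base model, then attach the LoRA adapter that save_pretrained wrote above.
base_model = OmniGen.from_pretrained("Shitao/OmniGen-v1")
model = PeftModel.from_pretrained(base_model, "results/toy_finetune_lora/checkpoints/0000010")  # illustrative path
model = model.merge_and_unload()  # optionally fold the LoRA weights back into the base model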

staoxiao commented 1 week ago

I have updated the train.py. You can try the latest code. Feel free to ask me if you have any questions.