mihirp1998 / AlignProp

AlignProp uses direct reward backpropagation for the alignment of large-scale text-to-image diffusion models. Our method is 25x more sample- and compute-efficient than reinforcement learning methods (PPO) for finetuning Stable Diffusion.
https://align-prop.github.io/
MIT License

accelerate config #5

Closed SkylerZheng closed 3 months ago

SkylerZheng commented 1 year ago

Can you help share the accelerate config? It's not working on my side. Also, when I evaluate with the checkpoint trained on the aesthetic reward model, I get an OOM error, even though I have a P4 node with 8 A100s.

Do you have a separate inference code that can specify the image resolution to be generated?

mihirp1998 commented 1 year ago

Is only evaluation failing, or are both training and evaluation failing?

SkylerZheng commented 1 year ago

Evaluation is not working due to the OOM issue. For training, I cannot use accelerate for 4-GPU training; I can only use 1 GPU. Do you mind sharing your accelerate config file?

mihirp1998 commented 1 year ago

This is my default file, when I run `accelerate config default`:

```json
{
  "compute_environment": "LOCAL_MACHINE",
  "deepspeed_config": {},
  "distributed_type": "MULTI_GPU",
  "downcast_bf16": false,
  "dynamo_config": {},
  "fsdp_config": {},
  "machine_rank": 0,
  "main_training_function": "main",
  "megatron_lm_config": {},
  "mixed_precision": "no",
  "num_machines": 1,
  "num_processes": 4,
  "rdzv_backend": "static",
  "same_network": false,
  "tpu_use_cluster": false,
  "tpu_use_sudo": false,
  "use_cpu": false
}
```
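For reference, `accelerate` normally stores this default config as YAML at `~/.cache/huggingface/accelerate/default_config.yaml`. An equivalent YAML version (a sketch derived from the JSON above; field names assumed to carry over unchanged) would look like:

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: false
use_cpu: false
```

With that file in place, training on 4 GPUs should be a matter of `accelerate launch main.py ...`.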

SkylerZheng commented 1 year ago

Thank you very much for sharing! I need to finetune my model and generate non-square images. To do this, how should I change the inference code?

mihirp1998 commented 1 year ago

Maybe changing the dimensions here would be sufficient: https://github.com/mihirp1998/AlignProp/blob/858f4dc7f0833a5ef2c423bbd2bf590790e01c74/main.py#L440
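As a rough guide (an assumption about standard Stable Diffusion, not code from this repo): the UNet operates on latents downsampled 8x from pixel space, so non-square generation means initializing latents of shape `(batch, 4, height // 8, width // 8)` with pixel dimensions divisible by 8. A hypothetical helper:

```python
def latent_shape(batch_size, height, width, vae_scale_factor=8, latent_channels=4):
    """Compute the initial-noise tensor shape for a non-square SD sample.

    Hypothetical helper for illustration; standard SD uses a VAE that
    downsamples by 8 and a 4-channel latent space.
    """
    if height % vae_scale_factor or width % vae_scale_factor:
        raise ValueError("height and width must be multiples of the VAE scale factor")
    return (batch_size, latent_channels, height // vae_scale_factor, width // vae_scale_factor)

# e.g. a landscape 512x768 batch of 4:
print(latent_shape(4, 512, 768))  # (4, 4, 64, 96)
```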

SkylerZheng commented 1 year ago

@mihirp1998 Changing the dimensions does help generate non-square images. But it does not work if I change the base model to SD 2.1: the loss becomes NaN.
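One generic way to localize this kind of failure (a debugging sketch, not something from this repo or thread) is to guard the training loop so non-finite losses are detected and skipped rather than silently corrupting the optimizer state:

```python
import math

def should_step(loss_value):
    """Return False when the loss is NaN/Inf so the optimizer step can be
    skipped and the offending batch inspected. Generic debugging guard,
    not part of AlignProp itself."""
    if not math.isfinite(loss_value):
        # Common next steps: lower the LR, try full fp32 or bf16 instead
        # of fp16, or check the reward model's inputs for this batch.
        return False
    return True

print(should_step(float("nan")), should_step(0.5))  # False True
```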