Kwai-Kolors / Kolors


Is there a plan to release finetune model code #8

Open kigy1 opened 4 months ago

kigy1 commented 4 months ago

Is there a plan to release finetuning code for the model?

zcdliuwei commented 4 months ago

+1

maybleMyers commented 4 months ago

I think so. But they are probably going to do LoRA first, according to the README.

dcfucheng commented 4 months ago

Is there a plan to release finetuning code for the model?

Maybe you can refer to the SDXL LoRA and full finetuning code.

The main difference: change the text_encoder from the two CLIP encoders to GLM.

The pseudo-code is as follows:

import torch

# load the GLM text encoder (replaces SDXL's two CLIP text encoders)
tokenizer = ChatGLMTokenizer.from_pretrained(
    args.text_encoder_id,
    revision=args.revision,
    trust_remote_code=True)
text_encoder = ChatGLMModel.from_pretrained(
    args.text_encoder_id,
    trust_remote_code=True)

# inputs from the dataloader
image, text_ids = batch['pixel_values'], batch['input_ids']

# run the text encoder (attention_mask / position_ids also come from the tokenizer output)
output = text_encoder(
    input_ids=text_ids,
    attention_mask=attention_mask,
    position_ids=position_ids,
    output_hidden_states=True)

# GLM hidden states are [seq_len, batch, hidden]: permute the penultimate layer
# to [batch, seq_len, hidden] for UNet cross-attention, and take the last token
# of the last layer as the pooled text embedding
prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
text_proj = output.hidden_states[-1][-1, :, :].clone()

# encode images into VAE latents
latents = vae.encode(image.to(torch.float32)).latent_dist.sample()
latents = latents.to(dtype=weight_dtype)
latents = latents * vae.config.scaling_factor

# SDXL-style micro-conditioning (original size, crop coords, target size)
add_time_ids = batch['add_time_ids'].to(latents.device)
encoded_text = {"text_embeds": text_proj, "time_ids": add_time_ids}

# predict the noise with the UNet (noisy_model_input and timesteps come from the
# noise scheduler; see the training-step sketch below)
noise_pred = unet(
    noisy_model_input,
    timesteps,
    timestep_cond=None,
    encoder_hidden_states=prompt_embeds,
    added_cond_kwargs=encoded_text,
).sample

Then, the training phase is similar to SDXL training.
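To make "similar to SDXL training" concrete, here is a minimal sketch of the noise-prediction step that would follow the pseudo-code above. It assumes epsilon prediction, a DDPM-style noise_scheduler, and an already-configured optimizer; those names are assumptions, not Kolors code.

import torch
import torch.nn.functional as F

# sample noise and per-example timesteps, then corrupt the clean latents
noise = torch.randn_like(latents)
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps,
    (latents.shape[0],), device=latents.device).long()
noisy_model_input = noise_scheduler.add_noise(latents, noise, timesteps)

# ... run the UNet as in the pseudo-code above to get noise_pred, then:
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
loss.backward()
optimizer.step()
optimizer.zero_grad()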

cugzhengzhimin commented 4 months ago

Is there a plan to release finetuning code for the model?

Maybe you can refer to the SDXL LoRA and full finetuning code. The main difference: change the text_encoder from the two CLIP encoders to GLM. (Pseudo-code as quoted above.) Then, the training phase is similar to SDXL training.

Has this been verified experimentally? The results don't seem to match inference.

maybleMyers commented 4 months ago

Has this been verified experimentally? The results don't seem to match inference.

It is just pseudo-code, and the noise scheduler for Kolors is different, so it needs refinement.
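One way to account for that, sketched here as an assumption rather than verified Kolors training code: build the training scheduler from the checkpoint's own scheduler config instead of hard-coding SDXL defaults, so the beta schedule, prediction type, and number of training timesteps match the released UNet.

from diffusers import DDPMScheduler

# load the noise-schedule settings shipped with the checkpoint
# (replace the id with a local Kolors checkpoint directory if needed)
noise_scheduler = DDPMScheduler.from_pretrained(
    "Kwai-Kolors/Kolors", subfolder="scheduler")
print(noise_scheduler.config.num_train_timesteps,
      noise_scheduler.config.prediction_type)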

Kyushik commented 1 month ago

I forked this repo and changed train_dreambooth_lora.py for LoRA and full finetuning. I am still testing its performance, but I think it works well.

https://github.com/Kyushik/Kolors/tree/master/training

Run train.sh for full finetuning and train_lora.sh for LoRA finetuning.

Deng-Xian-Sheng commented 1 week ago

https://github.com/Kyushik/Kolors/tree/master/training

This fails to run even with 80 GB of GPU memory. Is this a code issue?

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 9.61 GiB. GPU 0 has a total capacity of 79.20 GiB of which 8.37 GiB is free. Including non-PyTorch memory, this process has 70.82 GiB memory in use. Of the allocated memory 50.75 GiB is allocated by PyTorch, and 19.01 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Deng-Xian-Sheng commented 1 week ago

https://github.com/Kyushik/Kolors/tree/master/training

Did you test full finetuning with this? I would like to test it next, but I don't have another large GPU to try again right now.

Kyushik commented 1 week ago

@Deng-Xian-Sheng

I tested it with an 80 GB GPU. However, if you hit an OOM error, you should change the batch size to 1.
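To make the batch-size change concrete, here is a minimal sketch of trading per-step batch size for gradient accumulation so the effective batch size stays the same while peak memory drops; train_dataset, training_step, and optimizer are hypothetical placeholders, not names from the fork.

import torch

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True)   # was e.g. batch_size=4
grad_accum_steps = 4                             # 1 * 4 keeps the old effective batch size

for step, batch in enumerate(train_dataloader):
    loss = training_step(batch)                  # hypothetical: one forward + loss pass
    (loss / grad_accum_steps).backward()         # scale so accumulated grads average
    if (step + 1) % grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()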