kathrinse / be_great

A novel approach for synthesizing tabular data using pretrained large language models
MIT License

Adding Native Distributed Data Parallel Support #50

Open hiberfil opened 2 months ago

hiberfil commented 2 months ago

Hi, I was wondering if there were any efforts toward having great.py natively support Distributed Data Parallel (DDP)? Currently I am working around it by editing my own trainer file and saving the model via torch.save.

Below is how I invoke it.

torchrun --nproc_per_node=8 ddptest.py

import os
from collections import OrderedDict

import pandas as pd
import torch
import torch.distributed as dist

from be_great import GReaT

def main():
    # Bind this process to the GPU matching its local rank
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    dataFile = "/edit/for/your/own/repo.csv"
    data = pd.read_csv(dataFile)

    great = GReaT("gpt2-xl",
                  batch_size=8,
                  epochs=50,
                  fp16=True
                  )

    # Move the model to the appropriate GPU
    great.model.to(local_rank)

    # Wrap the model for distributed training
    great.model = torch.nn.parallel.DistributedDataParallel(
        great.model, device_ids=[local_rank], output_device=local_rank
    )

    trainer = great.fit(data, data.columns.to_list())

    # Save the model only from the rank 0 process
    if dist.get_rank() == 0:
        # Create a new state dict with corrected key names
        state_dict = great.model.state_dict()
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # strip the `module.` prefix added by DDP
            new_state_dict[name] = v

        # Save the model with the modified state dict
        torch.save(new_state_dict, "/edit/for/your/own/model.pt")

if __name__ == "__main__":
    # Initialize the distributed process group
    dist.init_process_group(backend="nccl")
    main()
    # Clean up the process group once training is done
    dist.destroy_process_group()

Again thank you so much for this awesome framework.

unnir commented 2 months ago

Hi @hiberfil,

Thank you for choosing our framework :)

So far we do not have plans to add native Distributed Data Parallel support. However, it would be great to have, so any contributions are very welcome.

Also, thank you for providing a simple workaround script; it will definitely be useful for others!