fidelity / stoke

A lightweight wrapper for PyTorch that provides a simple declarative API for context switching between devices, distributed modes, mixed-precision, and PyTorch extensions.
https://fidelity.github.io/stoke/
Apache License 2.0

How to use a loss function that is itself a CNN? DDP training issue #32

Closed: rushi-the-neural-arch closed this issue 2 years ago

rushi-the-neural-arch commented 2 years ago

Describe the bug

This is not exactly a bug, but I need some context on how to implement this. I am trying to implement a novel loss function (ref: https://github.com/gfxdisp/mdf). The gist is to use a pre-trained neural network as the loss for low-level vision tasks such as image denoising and super-resolution (SR): the discriminator (a CNN) itself serves as the loss function, taking the model's prediction as input and outputting a metric. However, I couldn't implement it in a way that is compatible with Stoke, which leads to the standard error: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! ... Can you please suggest a way to fix this, or how to handle this task efficiently?
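The idea of a CNN acting as the loss can be sketched as a small `nn.Module` wrapper. This is a minimal, hypothetical sketch (`CNNAsLoss` is not part of MDF or Stoke): it simply averages a frozen network's output on the prediction, whereas the actual MDF loss is more involved. The point that matters for DDP is that the scoring CNN is registered as a submodule, so moving the wrapper with `.to(device)` also moves its weights.

```python
import torch
import torch.nn as nn


class CNNAsLoss(nn.Module):
    """Hypothetical wrapper: a frozen CNN scores the prediction and yields a scalar loss."""

    def __init__(self, scorer: nn.Module) -> None:
        super().__init__()
        # Registered as a submodule, so .to(device) / .cuda() move its weights as well
        self.scorer = scorer
        for p in self.scorer.parameters():
            p.requires_grad_(False)  # the loss network itself is never trained
        self.scorer.eval()

    def train(self, mode: bool = True) -> "CNNAsLoss":
        super().train(mode)
        self.scorer.eval()  # keep dropout/batch-norm in the scorer in inference mode
        return self

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # `target` is unused in this simplified version; it is accepted so the
        # call matches the usual loss(outputs, targets) signature.
        return self.scorer(prediction).mean()


if __name__ == "__main__":
    scorer = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 3, padding=1))
    loss_fn = CNNAsLoss(scorer)
    prediction = torch.randn(2, 3, 16, 16, requires_grad=True)
    loss = loss_fn(prediction, torch.randn(2, 3, 16, 16))
    loss.backward()  # gradients reach the prediction, not the frozen scorer
    print(loss.item(), prediction.grad.shape)
```

Overriding `train()` keeps the scorer in inference mode even when a training loop calls `.train()` on the wrapper.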

To Reproduce

Steps to reproduce the behavior:

  1. Code/pseudo-code:
# `model`, `optimizer`, `opt`, the schedulers, the *_config objects and
# `train_log` are defined elsewhere in the script.
from stoke import Stoke, DistributedOptions, ClipGradNormConfig
from mdfloss import MDFLoss

path_disc = "mdf/weights/Ds_SISR.pth"
# MDFLoss loads the pre-trained discriminator weights itself
loss = MDFLoss(path_disc, cuda_available=True)

stoke_model = Stoke(
    model=model,
    verbose=True,
    optimizer=optimizer,
    loss=loss,
    batch_size_per_device=opt.batchSize,
    gpu=True,
    fp16=None,  # alternatively FP16Options.amp.value
    distributed=DistributedOptions.ddp.value,
    fairscale_oss=True,
    fairscale_sddp=True,
    grad_accum_steps=1,
    configs=[amp_config, ddp_config, oss_config],
    grad_clip=ClipGradNormConfig(max_norm=opt.grad_clip, norm_type=2.0),
)


def train(train_dataloader, stoke_model: Stoke, scheduler1, scheduler2, epoch: int):
    example_ct = 0  # number of examples seen
    batch_ct = 0
    sum_loss = 0

    stoke_model.print_on_devices(f"Starting Epoch {epoch + 1}")
    stoke_model.model_access.train()

    for idx, (inputs, targets) in enumerate(train_dataloader):
        # Call the model and the loss through the Stoke object interface
        outputs = stoke_model.model(inputs)
        train_loss = stoke_model.loss(outputs, targets)

        stoke_model.print_ema_loss(prepend_msg=f"Step {idx + 1} -- EMA Loss")

        # Call backward through the Stoke object interface
        stoke_model.backward(loss=train_loss)

        # Call step through the Stoke object interface, then advance both schedulers
        stoke_model.step()
        scheduler1.step()
        scheduler2.step()

        sum_loss += stoke_model.detach_and_sync_loss(loss=train_loss)

        example_ct += len(inputs)
        batch_ct += 1

        # Report metrics every 50th batch
        if batch_ct % 50 == 0:
            train_log(train_loss, example_ct, epoch)  # user-defined logging helper

    avg_loss = sum_loss / len(train_dataloader)

    return avg_loss
  2. Launched with: env CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 Stoke-DDP.py --projectName "Stoke-4K-2X-DDP" --batchSize 18 --nEpochs 2 --lr 1e-3 --weight_decay 1e-4 --grad_clip 0.1

  3. Error produced: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper__cudnn_convolution)
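Because each DDP process drives a different GPU, every tensor the loss touches must live on that process's device. A minimal sketch, assuming the launcher exposes the local rank either as a `--local_rank`/`--local-rank` argument or as a `LOCAL_RANK` environment variable (which one depends on the PyTorch version and launch flags): pick the per-rank device once, then check the loss network before the first forward pass, so a misplaced weight fails with a readable message instead of inside cuDNN. The helper names `local_device` and `check_on_device` are hypothetical, and the check only sees tensors registered on the module (parameters and buffers), not ones held in plain Python attributes or lists.

```python
import argparse
import os

import torch
import torch.nn as nn


def local_device() -> torch.device:
    """Return this process's device, assuming the launcher provides the local rank."""
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument(
        "--local_rank", "--local-rank", dest="local_rank", type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    args, _ = parser.parse_known_args()  # ignore the script's own flags
    if torch.cuda.is_available():
        torch.cuda.set_device(args.local_rank)  # make bare .cuda() default to this GPU
        return torch.device("cuda", args.local_rank)
    return torch.device("cpu")


def check_on_device(module: nn.Module, device: torch.device) -> None:
    """Raise if any registered parameter or buffer of `module` is not on `device`."""
    tensors = list(module.named_parameters()) + list(module.named_buffers())
    stray = {name: str(t.device) for name, t in tensors if t.device != device}
    if stray:
        raise RuntimeError(f"{type(module).__name__} expected on {device}, found: {stray}")


if __name__ == "__main__":
    device = local_device()
    loss_net = nn.Conv2d(3, 1, 3).to(device)
    check_on_device(loss_net, device)  # passes: every weight is on this rank's device
```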

Environment:

Thanks!

ncilfone commented 2 years ago

Hi @rushi-the-neural-arch, I was away for a bit, hence the lack of response. I'll look into this over the next few days; hopefully it's a simple issue!

ncilfone commented 2 years ago

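It looks like they are using torch.load to load the saved discriminator without any regard for devices (see here)... Without a map_location argument, torch.load puts each tensor back on the device it was saved from, which would explain the cuda:1 vs cuda:0 mismatch.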

You'll most likely have to handle all the placement manually :-/
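One way to do that manual placement is to load the pre-trained weights yourself with `torch.load(..., map_location=...)`, which remaps every stored tensor onto the given device instead of the one it was saved from. The sketch below assumes the checkpoint is a plain `state_dict` and uses a hypothetical `TinyDiscriminator` in place of the real MDF architecture; if a file holds a whole pickled module instead, `map_location` still applies but `torch.load` returns the module directly (and recent PyTorch versions may require `weights_only=False` to unpickle it).

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyDiscriminator(nn.Module):
    """Hypothetical stand-in for a pre-trained discriminator (not the MDF network)."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 8, 3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(8, 1, 3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def load_frozen_discriminator(path: str, device: torch.device) -> nn.Module:
    # Without map_location, torch.load restores tensors to the device they were
    # saved from (e.g. cuda:0 on every rank); with it, they go straight to `device`.
    state = torch.load(path, map_location=device)
    disc = TinyDiscriminator().to(device)
    disc.load_state_dict(state)
    disc.eval()
    for p in disc.parameters():
        p.requires_grad_(False)  # used only as a fixed loss network
    return disc


if __name__ == "__main__":
    # In a DDP run this would be torch.device("cuda", local_rank) for the current process
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "disc.pth")
        torch.save(TinyDiscriminator().state_dict(), path)
        disc = load_frozen_discriminator(path, device)
        print({p.device for p in disc.parameters()})
```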

rushi-the-neural-arch commented 2 years ago

Yeah, I was afraid of the same; manual handling is a pain! No worries, I'm closing this issue for now, but let me know if there's a different way to handle this. Thanks!