pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

NormFreeNets are not working in pytorch XLA #2776

Closed mobassir94 closed 3 years ago

mobassir94 commented 3 years ago

I was trying to use the nf_resnet50 model in this notebook: https://www.kaggle.com/mobassir/faster-pytorch-tpu-baseline-for-cld-cv-0-9 from this repository: https://github.com/rwightman/pytorch-image-models

Like other models, I followed these steps to use it:

I included this updated dataset: https://www.kaggle.com/mobassir/timm2021 (the latest version of that timm repository) in my notebook: https://www.kaggle.com/mobassir/faster-pytorch-tpu-baseline-for-cld-cv-0-9

and then replaced these package paths:

package_paths = [
    '../input/pytorch-image-models/pytorch-image-models-master', #'../input/efficientnet-pytorch-07/efficientnet_pytorch-0.7.0'
    '../input/image-fmix/FMix-master'
]
for pth in package_paths:
    sys.path.append(pth)

with this :

package_paths = [
    '../input/timm2021/pytorch-image-models-master', #'../input/efficientnet-pytorch-07/efficientnet_pytorch-0.7.0'
    '../input/image-fmix/FMix-master'
]
for pth in package_paths:
    sys.path.append(pth)

Then I set, for example,

kernel_type = 'nf_resnet50'

net_type = 'nf_resnet50'

and the model class I used looks like this:

import timm
import torch.nn as nn

class cassavamodel(nn.Module):
    def __init__(self, model_arch, n_class, pretrained=True):
        super().__init__()
        self.model = timm.create_model(model_arch, pretrained=pretrained)
        # replace the classifier head with a new linear layer for n_class outputs
        n_features = self.model.head.fc.in_features
        self.model.head.fc = nn.Linear(n_features, n_class)

    def forward(self, x):
        x = self.model(x)
        return x

But when I try to train the model, the code hangs here:


import torch
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(rank, flags):
    global acc_list  # defined elsewhere in the notebook
    torch.set_default_tensor_type('torch.FloatTensor')
    res = train_model()  # train_model is defined elsewhere in the notebook

FLAGS = {}
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=8, start_method='fork')

I waited almost an hour; nothing was printed and no error was raised either. All other models work on TPU, but I don't know why I face this issue with nf_resnet50 or any other NormFreeNets model.

I've tested the same baseline on GPU and it works fine there. How can I solve this issue? I am working in a Kaggle kernel with a TPU v3-8. Thank you.

taylanbil commented 3 years ago

AFAIU the newly introduced code for these models is in the optimizer. Lack of batchnorm should not lead to XLA problems such as dynamic shapes, frequent context switches from device to CPU, etc. If you wanted to test this, you could switch your optimizer to something that is known to be fast, such as torch.optim.SGD (for debugging purposes only).
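
E.g. something like this, just for the test (a minimal sketch; model here stands in for the cassavamodel instance from the original post, and the rest of the training loop stays as it is):

import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for the actual model
# Debugging-only swap: a simple, known-fast optimizer to help isolate the hang.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)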

If that doesn't work, could you collect some more debug info per the troubleshooting doc?

taylanbil commented 3 years ago

I was incidentally listening to https://www.youtube.com/watch?v=rNkHjZtH0RQ and around this section they talk about NF nets, and it seems like the custom operation they do is in the model, not in the optimizer. This is very likely to be the culprit; the way it is implemented in the repo you're using may not be compatible w/ XLA.

rwightman commented 3 years ago

@taylanbil I can't speak to the OP's specific impl, but there is no special optimizer with these models. A gradient clipping impl is suggested by the authors, but it is optional and the models can be used with standard grad clipping.

The only slightly unusual aspect of these models is weight standardization in the convs, but it is fairly straightforward PyTorch: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/std_conv.py#L68
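
The core of it looks roughly like this (a simplified sketch of the idea only; it leaves out the learnable gain and scaling that the actual ScaledStdConv2d applies):

import torch
import torch.nn as nn
import torch.nn.functional as F

class WSConv2dSketch(nn.Conv2d):
    # Illustrative weight-standardized conv; not the exact timm implementation.
    def __init__(self, in_ch, out_ch, kernel_size, eps=1e-6, **kwargs):
        super().__init__(in_ch, out_ch, kernel_size, **kwargs)
        self.eps = eps

    def get_weight(self):
        # Standardize each output channel's kernel; this std_mean call is the op
        # discussed below as not being lowered to XLA.
        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
        return (self.weight - mean) / (std + self.eps)

    def forward(self, x):
        return F.conv2d(x, self.get_weight(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)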

taylanbil commented 3 years ago

Right, the optimizer comment was me confusing the NFNets w/ Normalizer Free ResNets (a recent paper by Google).

@rwightman I took a brief look at the code, thanks for linking. I quickly found the issue; I had a hunch that torch.std_mean wasn't lowered to XLA, since I'm seeing this op for the first time. I put together this minimal test:

import torch.nn as nn
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

layer = nn.Linear(100, 100)
d = xm.xla_device()
layer = layer.to(d)
# run std_mean on an XLA tensor; if it isn't lowered, it shows up as an aten:: counter in the report
std, mean = torch.std_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
rep = met.metrics_report()
print(rep)

and observed that the aten counter is there in the metrics report:

...
Counter: XrtSubTuple_Empty
  Value: 128
Counter: aten::std_mean
  Value: 1
Counter: xla::_copy_from
  Value: 2
...

So this op needs to be lowered before we can proceed with this workload.
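
As a side note, you can also pull out just the aten:: fallbacks programmatically instead of reading the whole report, something like this (a small sketch using the counter helpers in torch_xla.debug.metrics):

import torch_xla.debug.metrics as met

# Counters whose names start with "aten::" are ops that fell back to CPU
# instead of being lowered to XLA.
aten_fallbacks = {name: met.counter_value(name)
                  for name in met.counter_names() if name.startswith('aten::')}
print(aten_fallbacks)  # for the snippet above this would show {'aten::std_mean': 1}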

rwightman commented 3 years ago

@taylanbil thanks for the quick response. For @mobassir94's benefit, I imagine var_mean (or at least separate var/std + mean) is lowered? So if he changed the code to my 'alternate' impl or something similar at https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/std_conv.py#L124 ... things should work?

EDIT: or possibly even my F.layer_norm hijack (just change the bool), if F.layer_norm is lowered.
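
Conceptually, the layer_norm path standardizes each output channel's flattened kernel in one call, roughly like this (an illustrative sketch, not the exact timm code):

import torch
import torch.nn.functional as F

def standardize_with_layer_norm(w, eps=1e-6):
    # w: conv weight of shape (out_ch, in_ch, kh, kw); normalize each
    # output channel's flattened kernel to zero mean / unit variance.
    flat = w.view(w.shape[0], -1)
    return F.layer_norm(flat, flat.shape[1:], eps=eps).view_as(w)

w = torch.randn(64, 32, 3, 3)
print(standardize_with_layer_norm(w).shape)  # torch.Size([64, 32, 3, 3])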

taylanbil commented 3 years ago

Unfortunately, var_mean is not lowered either; Counter: aten::var_mean appears in the metrics report if I change the call to torch.var_mean.

taylanbil commented 3 years ago

Separating them works:

# continuing from the minimal test above
d = xm.xla_device()
layer = layer.to(d)
#std, mean = torch.var_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
#std, mean = torch.std_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
std = torch.std(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
mean = torch.mean(layer.weight.data, dim=[1, 0], keepdim=True)
rep = met.metrics_report()
print(rep)

results in

Counter: xla::mean
  Value: 1
Counter: xla::std
  Value: 1

So you could separate and it'll work. We should lower std_mean regardless, though. @mobassir94 If you want to unblock yourself and not wait for the lowering, you can separate as above.

taylanbil commented 3 years ago

Curious, @mobassir94 did you try separating std and mean? Are there any other issues you've encountered?

mobassir94 commented 3 years ago

@taylanbil I spent all my Kaggle TPU quota on a Kaggle competition, so I couldn't check this, but tomorrow I will get a new 50-hour weekly TPU quota from Kaggle. I will try this and let you know then. Thanks.

mobassir94 commented 3 years ago

Sorry for this silly question: I see @rwightman updated this file 2 days ago: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/std_conv.py#L124

I am currently using the latest version of timm, and it still blocks me from using nf_resnet50 on torch XLA.

I do not understand where to use this code:

d = xm.xla_device()
layer = layer.to(d)
#std, mean = torch.var_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
#std, mean = torch.std_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
std= torch.std(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
mean = torch.mean(layer.weight.data, dim=[1, 0], keepdim=True)
rep = met.metrics_report()
print(rep)

rwightman commented 3 years ago

@mobassir94

If you are using the dm_ model variants with weights that require SAME padding support, you'd change this line: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/std_conv.py#L136 ... for the other models, this one: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/std_conv.py#L94

You could also try setting use_layernorm=True in the layer args, as I suspect layer_norm is supported if people are using transformers w/ PyTorch + TPU. The layernorm option is almost identical (it's basically being used there to do the standardization in one kernel); it only differs in the handling of eps (inside the sqrt with the var). I haven't spent enough time validating it to use it as the default though.
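
To illustrate the eps difference (just the math, run on CPU; this is not timm code):

import torch

w = torch.randn(32, 16, 3, 3)
eps = 1e-6
var, mean = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w_std = (w - mean) / (torch.sqrt(var) + eps)  # default path: eps added to the std
w_ln = (w - mean) / torch.sqrt(var + eps)     # layer_norm path: eps inside the sqrt with the var
print((w_std - w_ln).abs().max())             # tiny difference, driven only by eps placement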

mobassir94 commented 3 years ago

Hi @rwightman, inside the ScaledStdConv2d class I modified the get_weight function like this:

from :

std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = (self.weight - mean) / (std + self.eps)

to :

std = torch.std(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
mean = torch.mean(self.weight.data, dim=[1, 2, 3], keepdim=True)
weight = (self.weight - mean) / (std + self.eps)

I am also using use_layernorm=True.

I am also doing the same in the ScaledStdConv2dSame class. Am I right?

rwightman commented 3 years ago

@mobassir94 Yes, that is correct. I suggest trying each approach and seeing which works better (faster). I don't think you want the .data there for the mean, though.

mobassir94 commented 3 years ago

So should I replace mean = torch.mean(self.weight.data, dim=[1, 2, 3], keepdim=True) with:

mean = torch.mean(self.weight, dim=[1, 2, 3], keepdim=True)

right?

mobassir94 commented 3 years ago

I have updated the std_conv.py file like this: https://pastebin.com/Yix7zdP2

Then I tried the dm_nfnet_f0 model and got this error:

---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<timed exec> in <module>

/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py in spawn(fn, args, nprocs, join, daemon, start_method)
    381   if not _is_xla_config():
    382     # If this is not an XLA setup, jump to normal multi-processing.
--> 383     return _run_direct(fn, args, nprocs, join, daemon, start_method)
    384 
    385   pf_cfg = _pre_fork_setup(nprocs)

/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py in _run_direct(fn, args, nprocs, join, daemon, start_method)
    345   else:
    346     return torch.multiprocessing.spawn(
--> 347         fn, args=args, nprocs=nprocs, join=join, daemon=daemon)
    348 
    349 

/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py in spawn(fn, args, nprocs, join, daemon, start_method)
    197                ' torch.multiprocessing.start_process(...)' % start_method)
    198         warnings.warn(msg)
--> 199     return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')

/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
    155 
    156     # Loop on join until it returns True or raises an exception.
--> 157     while not context.join():
    158         pass
    159 

/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py in join(self, timeout)
    110                 raise Exception(
    111                     "process %d terminated with exit code %d" %
--> 112                     (error_index, exitcode)
    113                 )
    114 

Exception: process 3 terminated with exit code 1

@taylanbil @rwightman

taylanbil commented 3 years ago

I haven't read the thread fully, but answering the last comment:

That is not the real error; that's xmp.spawn quitting because a child process crashed. See if you have another stack trace in the stderr output.

taylanbil commented 3 years ago

When you are in dev mode like this, it makes sense to use one core instead of eight; you will see things more clearly, and you can also use pdb.
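
I.e. something like this instead of nprocs=8 (reusing the _mp_fn and FLAGS from the original post):

import torch_xla.distributed.xla_multiprocessing as xmp

# Run on a single TPU core while debugging: prints and stack traces are much
# easier to follow, and you can drop into pdb inside _mp_fn.
xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=1, start_method='fork')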

mobassir94 commented 3 years ago

Problem solved after restarting the notebook; it works fine with the separation change @taylanbil suggested. Thanks @taylanbil and @rwightman for your time.

taylanbil commented 3 years ago

Could you file an issue to lower std_mean and var_mean so it doesn't get lost? Thanks @mobassir94.

mobassir94 commented 3 years ago

@taylanbil should I create another XLA issue for lowering std_mean and var_mean?

taylanbil commented 3 years ago

Yes, thanks.