mobassir94 closed this 3 years ago.
AFAIU, the newly introduced code for these models is in the optimizer. The lack of batchnorm should not lead to XLA problems such as dynamic shapes or frequent context switches from device to CPU. If you want to test this, you could switch your optimizer to one that is known to be fast, such as torch.optim.SGD (for debugging purposes only).
If that doesn't work, could you collect some more debug info per the troubleshooting doc?
I was incidentally listening to https://www.youtube.com/watch?v=rNkHjZtH0RQ, and around this section they talk about NFNets; it seems the custom operation they use is in the model, not in the optimizer. This is very likely the culprit: the way it is implemented in the repo you're using may not be compatible with XLA.
@taylanbil I can't speak to the OP's specific impl, but there is no special optimizer with these models. The authors suggest a gradient clipping impl, but it is optional; the models can also be used with standard grad clipping.
The only slightly unusual aspect of these models is weight standardization in convs, but it is fairly straightforward PyTorch: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/std_conv.py#L68
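For reference, the weight standardization mentioned here normalizes each output channel's weights over its fan-in. Going by the get_weight code quoted elsewhere in this thread (mean and biased std over dims 1, 2, 3, with eps added to the std), a sketch of the math, where $i$ indexes output channels and $j$ runs over the $N$ fan-in elements (input channels times kernel positions); any extra scaling the "Scaled" variants apply is not covered here:

```latex
\mu_i = \frac{1}{N}\sum_{j=1}^{N} W_{i,j}, \qquad
\sigma_i = \sqrt{\frac{1}{N}\sum_{j=1}^{N} \left(W_{i,j} - \mu_i\right)^2}, \qquad
\hat{W}_{i,j} = \frac{W_{i,j} - \mu_i}{\sigma_i + \epsilon}
```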
Right, the optimizer comment was me confusing NFNets with Normalizer-Free ResNets (a recent paper by Google).
@rwightman I took a brief look at the code, thanks for linking. I quickly found the issue: I had a hunch that torch.std_mean wasn't lowered to XLA, since this is the first time I'm seeing it. I put together a minimal test:
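import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
# Put a small layer on the XLA device
layer = nn.Linear(100, 100)
d = xm.xla_device()
layer = layer.to(d)
# Op under test; if it has no XLA lowering, an aten::std_mean counter appears
std, mean = torch.std_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
# Dump the counters to see which ops fell back
rep = met.metrics_report()
print(rep)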
and observed that the aten::std_mean counter appears in the metrics report:
...
Counter: XrtSubTuple_Empty
Value: 128
Counter: aten::std_mean
Value: 1
Counter: xla::_copy_from
Value: 2
...
So this op needs to be lowered before we can proceed with this workload.
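Reading the report by eye works for one op; for a full model it can help to filter it. A minimal sketch of a hypothetical helper, aten_fallback_counters (not part of torch_xla), that assumes the plain-text "Counter: <name>" / "Value: <int>" layout shown above and keeps only the aten:: counters, which in this thread are the ones flagging ops without an XLA lowering. In practice you would pass it the string returned by met.metrics_report().

```python
import re
from typing import Dict, Optional


def aten_fallback_counters(report: str) -> Dict[str, int]:
    """Return {counter_name: value} for aten:: counters in a metrics report string.

    Assumes each counter is printed as a "Counter: <name>" line followed by a
    "Value: <int>" line, as in the report excerpt in this thread.
    """
    counters: Dict[str, int] = {}
    name: Optional[str] = None
    for raw in report.splitlines():
        line = raw.strip()
        m = re.match(r"Counter:\s*(\S+)", line)
        if m:
            name = m.group(1)
            continue
        m = re.match(r"Value:\s*(\d+)", line)
        if m and name is not None:
            # Keep only ops that went through the aten (non-XLA) path
            if name.startswith("aten::"):
                counters[name] = int(m.group(1))
            name = None
    return counters


# Sample taken from the report excerpt above
sample = """Counter: XrtSubTuple_Empty
  Value: 128
Counter: aten::std_mean
  Value: 1
Counter: xla::_copy_from
  Value: 2"""
print(aten_fallback_counters(sample))  # {'aten::std_mean': 1}
```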
@taylanbil thanks for the quick response. For @mobassir94's benefit: I imagine var_mean (or at least a separate var/std + mean) is lowered? So if he changed the code to my 'alternate' impl, or something similar, at https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/std_conv.py#L124, things should work?
EDIT: or possibly even my F.layer_norm hijack (just change the bool) if F.layer_norm is lowered
Unfortunately var_mean is not lowered either; Counter: aten::var_mean appears in the metrics report if I change to torch.var_mean.
Separating them works:
d = xm.xla_device()
layer = layer.to(d)
# Fused ops without an XLA lowering (each adds an aten:: counter):
#var, mean = torch.var_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
#std, mean = torch.std_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
# Separate std and mean calls, each lowered to XLA:
std = torch.std(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
mean = torch.mean(layer.weight.data, dim=[1, 0], keepdim=True)
rep = met.metrics_report()
print(rep)
results in
Counter: xla::mean
Value: 1
Counter: xla::std
Value: 1
So you could separate them and it will work. We should lower std_mean regardless, though. @mobassir94 if you want to unblock yourself without waiting for the lowering, you can separate them as above.
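The separation can be wrapped as a drop-in for torch.std_mean. A minimal sketch; std_mean_separate is a hypothetical helper name, and it simply mirrors the workaround above (two reductions, torch.std and torch.mean, instead of the fused op). The check at the end runs on CPU and only confirms that the numbers match the fused op; it says nothing about XLA lowering by itself.

```python
import torch


def std_mean_separate(x: torch.Tensor, dim, keepdim: bool = False, unbiased: bool = True):
    """Return (std, mean) like torch.std_mean, but via two separate reductions.

    Hypothetical helper mirroring the workaround in this thread: torch.std and
    torch.mean each have an XLA lowering, while the fused torch.std_mean did not.
    """
    std = torch.std(x, dim=dim, unbiased=unbiased, keepdim=keepdim)
    mean = torch.mean(x, dim=dim, keepdim=keepdim)
    return std, mean


# CPU sanity check against the fused op, using the same arguments as the test above
w = torch.randn(100, 100)
std_a, mean_a = torch.std_mean(w, dim=[1, 0], keepdim=True, unbiased=False)
std_b, mean_b = std_mean_separate(w, dim=[1, 0], keepdim=True, unbiased=False)
print(torch.allclose(std_a, std_b), torch.allclose(mean_a, mean_b))
```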
Curious, @mobassir94 did you try separating std and mean? Are there any other issues you've encountered?
@taylanbil I spent all my Kaggle TPU quota on a Kaggle competition, so I couldn't check this, but tomorrow I will get a new 50-hour weekly TPU quota from Kaggle. I will try this and let you know then, thanks.
Sorry for this silly question. I see @rwightman updated this file 2 days ago: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/std_conv.py#L124
I am currently using the latest version of timm, and it still blocks me from using nf_resnet50 on torch_xla.
I do not understand where to use this code:
d = xm.xla_device()
layer = layer.to(d)
#std, mean = torch.var_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
#std, mean = torch.std_mean(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
std= torch.std(layer.weight, dim=[1, 0], keepdim=True, unbiased=False)
mean = torch.mean(layer.weight.data, dim=[1, 0], keepdim=True)
rep = met.metrics_report()
print(rep)
@mobassir94
If you are using the dm_ model variants, whose weights require SAME padding support, change this line: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/std_conv.py#L136
For the other models, change this one: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/std_conv.py#L94
You could also try setting use_layernorm=True in the layer args, as I suspect layernorm is supported if people are using transformers with PyTorch + TPU. The layernorm option is almost identical (it's basically being used there to do the standardization in one kernel); it only differs in the handling of eps (inside the sqrt with var). I haven't spent enough time validating it to use it as the default, though.
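To make the eps difference concrete, here is a sketch comparing the two formulations on a random conv-shaped weight. The function names are hypothetical, and the layer_norm call assumes the standardization runs over every dim except the output-channel dim (as the std/mean code in this thread does); it is not copied from timm's source.

```python
import torch
import torch.nn.functional as F


def standardize_std(w: torch.Tensor, eps: float) -> torch.Tensor:
    # Per output channel: (w - mean) / (std + eps), eps added outside the sqrt
    dims = list(range(1, w.dim()))
    std = torch.std(w, dim=dims, keepdim=True, unbiased=False)
    mean = torch.mean(w, dim=dims, keepdim=True)
    return (w - mean) / (std + eps)


def standardize_layernorm(w: torch.Tensor, eps: float) -> torch.Tensor:
    # F.layer_norm normalizes over the trailing dims in normalized_shape, i.e. per
    # output channel here, as (w - mean) / sqrt(biased_var + eps): eps inside the sqrt
    return F.layer_norm(w, w.shape[1:], eps=eps)


w = torch.randn(16, 8, 3, 3)
diff = (standardize_std(w, 1e-5) - standardize_layernorm(w, 1e-5)).abs().max()
print(f"max abs difference: {diff.item():.2e}")
```

When a channel's std is well above sqrt(eps), the two results are very close; they diverge for nearly constant channels, where eps dominates the denominator.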
Hi @rwightman, inside the ScaledStdConv2d class I modified the get_weight function like this:
from:
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = (self.weight - mean) / (std + self.eps)
to:
std= torch.std(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
mean = torch.mean(self.weight.data, dim=[1, 2, 3], keepdim=True)
weight = (self.weight - mean) / (std + self.eps)
I am also using use_layernorm=True.
I'm doing the same in the ScaledStdConv2dSame class as well. Am I right?
@mobassir94 yes, that is correct. I suggest trying each approach and seeing which works better (faster). I don't think you want the .data there for the mean, though.
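A minimal, self-contained sketch of the change being discussed, with .data removed so gradients flow through the mean as well as the std. StdConv2dSeparated and its eps default of 1e-5 are hypothetical choices for illustration; timm's ScaledStdConv2d has extra scaling and its own defaults, so the real fix is to edit that class's get_weight rather than copy this one.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StdConv2dSeparated(nn.Conv2d):
    """Hypothetical weight-standardized conv using separate std/mean calls (not timm's class)."""

    def __init__(self, *args, eps: float = 1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps  # assumed default, for illustration only

    def get_weight(self) -> torch.Tensor:
        # Two separate reductions instead of the fused torch.std_mean
        std = torch.std(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
        # No .data here, so the mean stays in the autograd graph
        mean = torch.mean(self.weight, dim=[1, 2, 3], keepdim=True)
        return (self.weight - mean) / (std + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.conv2d(x, self.get_weight(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


conv = StdConv2dSeparated(3, 8, kernel_size=3, padding=1)
y = conv(torch.randn(2, 3, 32, 32))
y.sum().backward()  # gradients reach conv.weight through both std and mean
print(y.shape)  # torch.Size([2, 8, 32, 32])
```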
So should I replace mean = torch.mean(self.weight.data, dim=[1, 2, 3], keepdim=True) with:
mean = torch.mean(self.weight, dim=[1, 2, 3], keepdim=True)
Right?
I have updated the std_conv.py file like this: https://pastebin.com/Yix7zdP2
Then I tried the dm_nfnet_f0 model and got this error:
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
<timed exec> in <module>
/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py in spawn(fn, args, nprocs, join, daemon, start_method)
381 if not _is_xla_config():
382 # If this is not an XLA setup, jump to normal multi-processing.
--> 383 return _run_direct(fn, args, nprocs, join, daemon, start_method)
384
385 pf_cfg = _pre_fork_setup(nprocs)
/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py in _run_direct(fn, args, nprocs, join, daemon, start_method)
345 else:
346 return torch.multiprocessing.spawn(
--> 347 fn, args=args, nprocs=nprocs, join=join, daemon=daemon)
348
349
/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py in spawn(fn, args, nprocs, join, daemon, start_method)
197 ' torch.multiprocessing.start_process(...)' % start_method)
198 warnings.warn(msg)
--> 199 return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
155
156 # Loop on join until it returns True or raises an exception.
--> 157 while not context.join():
158 pass
159
/opt/conda/lib/python3.7/site-packages/torch/multiprocessing/spawn.py in join(self, timeout)
110 raise Exception(
111 "process %d terminated with exit code %d" %
--> 112 (error_index, exitcode)
113 )
114
Exception: process 3 terminated with exit code 1
@taylanbil @rwightman
Haven't read the thread fully, but answering the last comment:
that is not the real error; that's xmp.spawn quitting because the child processes crashed. See if you have another stack trace in the stderr output.
When you are in dev mode like this, it makes sense to use one core instead of eight: you will see things more clearly, and you can also use pdb.
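The one-core suggestion can be wired in with a small switch. A minimal sketch, assuming the usual torch_xla pattern of a training function fn(index, *args) handed to xmp.spawn; launch and its debug flag are hypothetical names, and start_method="fork" is an assumption that matches common Kaggle/Colab TPU notebooks.

```python
from typing import Any, Callable, Tuple


def launch(fn: Callable[..., Any], args: Tuple = (), debug: bool = False, nprocs: int = 8):
    """Hypothetical launcher: run fn in-process on one core when debugging, else spawn.

    fn follows the xmp.spawn convention of receiving the process index first.
    """
    if debug:
        # Same process as the notebook: the real traceback surfaces and pdb works
        return fn(0, *args)
    # Imported lazily so the debug path does not touch the multiprocessing wrapper
    import torch_xla.distributed.xla_multiprocessing as xmp
    return xmp.spawn(fn, args=args, nprocs=nprocs, start_method="fork")
```

With debug=True, an exception from your training function is raised directly in the notebook instead of being reduced to "process N terminated with exit code 1".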
Problem solved after restarting the notebook; it works fine with the separation change @taylanbil suggested. Thanks @taylanbil and @rwightman for your time.
Could you file an issue to lower std_mean and var_mean so it doesn't get lost? Thanks @mobassir94
@taylanbil should I create another XLA issue for lowering std_mean and var_mean?
yes, thanks
I was trying to use the nf_resnet50 model from this repository: https://github.com/rwightman/pytorch-image-models in this notebook: https://www.kaggle.com/mobassir/faster-pytorch-tpu-baseline-for-cld-cv-0-9
As with other models, I followed these steps to use this model:
I included this updated dataset, https://www.kaggle.com/mobassir/timm2021 (the latest version of that timm repository), in this notebook: https://www.kaggle.com/mobassir/faster-pytorch-tpu-baseline-for-cld-cv-0-9
and then replaced these imports:
with this:
Then I used, for example,
kernel_type = 'nf_resnet50'
net_type = 'nf_resnet50'
and the model class that I used is like this:
but when I try to train the model, the code hangs here:
I waited almost 1 hour; nothing was printed and I didn't get any error either. All other models work on TPU, but I don't know why I face this issue with nf_resnet50 or any other NormFreeNets.
I've tested the same baseline on GPU and it works fine there. How can I solve this issue? I am working in a Kaggle kernel with a TPU v3-8. Thank you.