catalyst-team / catalyst

Accelerated deep learning R&D
https://catalyst-team.com
Apache License 2.0

FocalLossMultiClass does not work on TPUs #1347

Closed: SupreethRao99 closed this issue 3 years ago

SupreethRao99 commented 3 years ago

When I try to use FocalLossMultiClass as a loss function on TPUs, I get the following error: RuntimeError: Unimplemented backend XLA

The full error message is as follows:

Exception in device=TPU:0: Unimplemented backend XLA
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/opt/conda/lib/python3.7/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
    fn(gindex, *args)
  File "/opt/conda/lib/python3.7/site-packages/catalyst/core/runner.py", line 831, in _run_stage
    self._run_epoch()
  File "/opt/conda/lib/python3.7/site-packages/catalyst/core/runner.py", line 824, in _run_epoch
    self._run_loader()
  File "/opt/conda/lib/python3.7/site-packages/catalyst/core/runner.py", line 815, in _run_loader
    self._run_batch()
  File "/opt/conda/lib/python3.7/site-packages/catalyst/core/runner.py", line 805, in _run_batch
    self._run_event("on_batch_end")
  File "/opt/conda/lib/python3.7/site-packages/catalyst/core/runner.py", line 786, in _run_event
    getattr(callback, event)(self)
  File "/opt/conda/lib/python3.7/site-packages/catalyst/callbacks/metric.py", line 207, in on_batch_end
    metrics = self._update_metric(metrics_inputs)
  File "/opt/conda/lib/python3.7/site-packages/catalyst/callbacks/metric.py", line 152, in _update_value_metric
    return self._metric_update_method(*value_inputs)
  File "/opt/conda/lib/python3.7/site-packages/catalyst/metrics/_functional_metric.py", line 96, in update_key_value
    value = self.update(batch_size, *args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/catalyst/metrics/_functional_metric.py", line 80, in update
    value = self.metric_fn(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/catalyst/callbacks/criterion.py", line 108, in _metric_fn
    return self.criterion(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/catalyst/contrib/nn/criterion/focal.py", line 95, in forward
    loss += self.loss_fn(cls_label_input, cls_label_target)
  File "/opt/conda/lib/python3.7/site-packages/catalyst/metrics/functional/_focal.py", line 33, in sigmoid_focal_loss
    targets = targets.type(outputs.type())
RuntimeError: Unimplemented backend XLA

Exception in device=TPU:5: Unimplemented backend XLA
(followed by the same traceback as for TPU:0; the two processes' output was interleaved)
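The innermost frame shows the cast that fails: `targets.type(outputs.type())` round-trips through a legacy type string. A common device-agnostic way to express the same cast is `targets.to(dtype=outputs.dtype)`, which changes only the dtype and leaves the tensor on its current device. The sketch below is a minimal, hypothetical re-implementation of a sigmoid focal loss plus a one-vs-all multi-class wrapper that uses that cast. The function names `sigmoid_focal_loss_xla_safe` and `focal_loss_multiclass_xla_safe` and the defaults `gamma=2.0`, `alpha=0.25` (the values commonly used in the focal loss paper) are assumptions for illustration, not catalyst's API. It runs on CPU; whether it fully avoids the error on a TPU is an untested assumption.

```python
import torch
import torch.nn.functional as F


def sigmoid_focal_loss_xla_safe(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    gamma: float = 2.0,
    alpha: float = 0.25,
    reduction: str = "mean",
) -> torch.Tensor:
    """Hypothetical binary focal loss on logits (Lin et al., 2017).

    Casts targets with .to(dtype=...) instead of .type(outputs.type()):
    only the dtype changes and the tensor stays on its current device,
    with no round trip through a legacy type string.
    """
    targets = targets.to(dtype=outputs.dtype)
    # Element-wise BCE on logits; for 0/1 targets this equals -log(p_t).
    bce = F.binary_cross_entropy_with_logits(outputs, targets, reduction="none")
    p_t = torch.exp(-bce)  # probability the model assigns to the true label
    loss = (1.0 - p_t) ** gamma * bce  # down-weight easy, confident examples
    # alpha weights positive targets, (1 - alpha) weights negative ones.
    loss = (alpha * targets + (1.0 - alpha) * (1.0 - targets)) * loss
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # reduction="none"


def focal_loss_multiclass_xla_safe(
    logits: torch.Tensor, labels: torch.Tensor, **kwargs
) -> torch.Tensor:
    """Hypothetical one-vs-all multi-class focal loss.

    logits: (N, C) raw scores; labels: (N,) int64 class indices.
    Sums the binary loss over classes, loosely modeled on the per-class
    `loss += self.loss_fn(...)` accumulation seen in the traceback.
    """
    num_classes = logits.size(1)
    one_hot = F.one_hot(labels, num_classes=num_classes)  # int64, shape (N, C)
    loss = logits.new_zeros(())  # scalar with the same device/dtype as logits
    for cls in range(num_classes):
        loss = loss + sigmoid_focal_loss_xla_safe(
            logits[:, cls], one_hot[:, cls], **kwargs
        )
    return loss


if __name__ == "__main__":
    logits = torch.randn(8, 3)
    labels = torch.randint(0, 3, (8,))
    print(focal_loss_multiclass_xla_safe(logits, labels))
```

The integer labels are cast inside the binary loss, so callers can pass `int64` class indices directly; nothing in the sketch depends on the device the tensors live on.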

github-actions[bot] commented 3 years ago

Hi! Thank you for your contribution! Please re-check all issue template checklists; unfilled issues will be closed automatically. And do not forget to join our Slack for collaboration.

Scitator commented 3 years ago

Hi, yeah, that could be. Could you please try submitting an issue to the pytorch/xla team? Maybe they can help with the XLA backend implementation ;)

SupreethRao99 commented 3 years ago

Sure, will do. Thanks!