Seanny123 opened this issue 1 year ago
I had the same issue with EfficientAD, but it seems the bug has not been fixed yet.
Yes, I had the same issue with PatchCore when using multi-GPU training.
This is something we need to address. We move the metrics to CPU for computation to avoid OOM on the GPU, but this breaks multi-GPU training. For now, the solution is to remove all the calls that move outputs and metrics to CPU. We plan to make this configurable in the future.
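For context, the calls being referred to follow a pattern roughly like this sketch (illustrative only, not the exact anomalib source; the helper name `outputs_to_cpu` is mine):

```python
import torch


def outputs_to_cpu(output):
    """Sketch of the pattern under discussion: recursively move a step's output
    tensors to CPU so that metric computation does not hold GPU memory."""
    if isinstance(output, dict):
        return {key: outputs_to_cpu(value) for key, value in output.items()}
    if isinstance(output, torch.Tensor):
        return output.cpu()
    return output
```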
@ashwinvaidya17 as I note in my original post, removing all the calls that move outputs to the CPU does not resolve the problem:
> Getting around the `RuntimeError: Tensors must be CUDA and dense` error by removing all `.cpu()` calls in `src/anomalib/models/components/base/anomaly_module.py` results in `image_F1Score` being `0.0` during both testing and validation.
We also have these lines here https://github.com/openvinotoolkit/anomalib/blob/main/src/anomalib/utils/callbacks/post_processing_configuration.py#L75. Did you try removing these as well?
Yes, when I remove these, the metric stops computing; `image_F1Score` is forever `0.0`, which is unhelpful.
FYI, a different way to get around the `RuntimeError: Tensors must be CUDA and dense` error is to add the following to `anomaly_score_threshold.py`, just above the existing `class AnomalyScoreThreshold(PrecisionRecallCurve):` definition:

```python
from typing import Any

import torch
from pytorch_lightning.utilities.distributed import gather_all_tensors


def all_gather_on_cuda(tensor: torch.Tensor, *args: Any, **kwargs: Any) -> list[torch.Tensor]:
    """Gather tensors across ranks on CUDA, then move the results back to the original device."""
    original_device = tensor.device
    return [
        _tensor.to(original_device)
        for _tensor in gather_all_tensors(tensor.to("cuda"), *args, **kwargs)
    ]
```
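The snippet stops at the class declaration, so it isn't explicit about how `all_gather_on_cuda` gets used. One plausible wiring, purely an assumption on my part, is to pass it as torchmetrics' `dist_sync_fn` when the metric is constructed, leaving the rest of the original class body unchanged:

```python
from typing import Any

from torchmetrics import PrecisionRecallCurve


class AnomalyScoreThreshold(PrecisionRecallCurve):
    """Sketch only: route this metric's distributed sync through all_gather_on_cuda."""

    def __init__(self, **kwargs: Any) -> None:
        # dist_sync_fn is a standard torchmetrics Metric kwarg; using the CUDA-aware
        # gather here is my guess at how the helper above is meant to be applied.
        super().__init__(num_classes=1, dist_sync_fn=all_gather_on_cuda, **kwargs)
```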
This is still not fixed; I am still getting the same error. Any suggestions, please? I already removed all of the calls that move tensors to the CPU, and I added the above-mentioned lines of code.
```
work = group.allgather([tensor_list], [tensor])
RuntimeError: No backend type associated with device type cpu
```
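For what it's worth, that error usually means a CPU tensor reached a collective on a process group that only has a CUDA (NCCL) backend. A variant of the helper above that pins the gather to the process-local CUDA device is one hypothetical thing to try; the `gather_device` handling below is my own guess, not a confirmed fix:

```python
from typing import Any

import torch
from pytorch_lightning.utilities.distributed import gather_all_tensors


def all_gather_on_cuda(tensor: torch.Tensor, *args: Any, **kwargs: Any) -> list[torch.Tensor]:
    """Variant of the earlier helper: gather on this rank's own CUDA device."""
    original_device = tensor.device
    gather_device = torch.device("cuda", torch.cuda.current_device())
    return [
        _tensor.to(original_device)
        for _tensor in gather_all_tensors(tensor.to(gather_device), *args, **kwargs)
    ]
```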
> FYI, a different way to get around the `RuntimeError: Tensors must be CUDA and dense` error is to add the `all_gather_on_cuda` snippet to `anomaly_score_threshold.py`.
I tried this method; although no errors were reported and results could be produced, the F1 and AUC accuracy became very low.
That's interesting, because it worked for me on two datasets but then showed a major accuracy drop on a third, so I assumed there was something wrong with my third, new, larger dataset. :thinking:
Describe the bug
Trying to do multi-GPU training of `fastflow` by setting the config `strategy: ddp` and `accelerator: gpu`, I get the error:

```
RuntimeError: Tensors must be CUDA and dense
```

Getting around the error by removing all `.cpu()` calls in `src/anomalib/models/components/base/anomaly_module.py` results in `image_F1Score` being `0.0` during both testing and validation.

Why is `AnomalyScoreThreshold` incompatible with multi-GPU training, and how could it be modified to be compatible?

Dataset
Other (please specify in the text field below)
Model
FastFlow
Steps to reproduce the behavior
See bug description.
OS information
OS information:
Expected behavior
I expected to be able to do multi-GPU training using Fastflow and for the F1 score to be non-zero.
Screenshots
No response
Pip/GitHub
GitHub
What version/branch did you use?
main
Configuration YAML
Logs
Code of Conduct