Closed simeetnayan81 closed 3 months ago
Tests should be added at the end of the test_metric.py file?
Yes, you can add them at the end of the file.
`skip_unrolling = False` is already covered by all the prior tests. I have added a test for when `skip_unrolling = True`. Kindly review and let me know what changes are needed.
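To illustrate what such a test exercises, here is a minimal plain-Python sketch of the unrolling behaviour that `skip_unrolling` controls. This is not ignite's actual implementation (which works on torch tensors inside `Metric.iteration_completed`); the function and argument names below are illustrative assumptions.

```python
def iteration_completed(output, update, skip_unrolling=False):
    """Sketch: decide whether to unroll a multi-output (y_pred, y) pair.

    `output` is a (y_pred, y) tuple; `update` stands in for the metric's
    update method. Not ignite's real code -- just the control flow.
    """
    y_pred, y = output
    if (not skip_unrolling
            and isinstance(y_pred, (list, tuple))
            and isinstance(y, (list, tuple))):
        # Default behaviour: one update call per (prediction, target) pair.
        for p, t in zip(y_pred, y):
            update((p, t))
    else:
        # skip_unrolling=True: the multi-output structure is passed through intact.
        update(output)
```

With `skip_unrolling=False`, `([p1, p2], [t1, t2])` produces two update calls `(p1, t1)` and `(p2, t2)`; with `skip_unrolling=True`, the whole tuple reaches `update` once, unchanged.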
@vfdev-5 Before adding the example in the docstring, I wanted to confirm: to make `skip_unrolling` effective for the loss function, we might also need to change this: https://github.com/pytorch/ignite/blob/master/ignite/metrics/loss.py#L77 Prev:
```python
def __init__(
    self,
    loss_fn: Callable,
    output_transform: Callable = lambda x: x,
    batch_size: Callable = len,
    device: Union[str, torch.device] = torch.device("cpu"),
):
    super(Loss, self).__init__(output_transform, device=device)
```
Change to:
```python
def __init__(
    self,
    loss_fn: Callable,
    output_transform: Callable = lambda x: x,
    batch_size: Callable = len,
    device: Union[str, torch.device] = torch.device("cpu"),
    skip_unrolling: bool = False,
):
    super(Loss, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling)
```
@simeetnayan81 yes, you are right, we need to add this new arg to every metric that defines its own constructor. Let's update the Loss metric here and update the other metrics in a follow-up PR.
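A minimal sketch of why every metric with its own constructor needs the update. The class bodies below are simplified stand-ins, not ignite's real `Metric`/`Loss`/`Accuracy` code: a subclass that forwards the new kwarg works, while one that was not updated rejects it outright.

```python
class Metric:
    # Simplified stand-in for ignite's base Metric.
    def __init__(self, output_transform=lambda x: x, skip_unrolling=False):
        self._output_transform = output_transform
        self._skip_unrolling = skip_unrolling


class Loss(Metric):
    # Mirrors the Loss change above: accepts and forwards skip_unrolling.
    def __init__(self, loss_fn, output_transform=lambda x: x, skip_unrolling=False):
        super().__init__(output_transform, skip_unrolling=skip_unrolling)
        self._loss_fn = loss_fn


class Accuracy(Metric):
    # A metric not yet updated: its constructor neither accepts nor
    # forwards skip_unrolling, so passing it raises TypeError -- hence
    # the follow-up PR touching all such constructors.
    def __init__(self, output_transform=lambda x: x):
        super().__init__(output_transform)
```

Here `Loss(loss_fn, skip_unrolling=True)` reaches the base class, but `Accuracy(skip_unrolling=True)` fails with a `TypeError`.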
Things to do in a follow-up PR:

- Add `skip_unrolling` arg as required, add tests and docstring

Thanks for the updates and the TODO. Can we do this point here?

> Add test for updated Loss class
Alright @vfdev-5.
I have made the changes; the new test works locally.
The test is failing because `list[torch.Tensor, torch.Tensor]` is only supported on Python 3.9 and above. Let me modify it a bit.
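For context, subscripting the builtin `list` at runtime (PEP 585) only works on Python 3.9+; on 3.8 it raises `TypeError`, which is the failure mode described above. The helper below is a hypothetical sketch of a version-safe fallback via `typing.List`, not code from this PR.

```python
import sys
from typing import List


def pair_annotation():
    """Return a list annotation that works on the running interpreter.

    On Python >= 3.9 the builtin generic `list[int]` is subscriptable;
    on 3.8 that raises TypeError, so fall back to typing.List[int].
    """
    try:
        return list[int]  # PEP 585 builtin generic, Python >= 3.9
    except TypeError:
        return List[int]  # typing fallback for Python 3.8
```

In the actual test, the same idea applies: annotate with `typing.List` (or guard on `sys.version_info`) so the module imports cleanly on every supported Python version.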
Fixes #2940
Description: Introduce a `skip_unrolling` variable in the `Metric` class, as discussed here: https://discord.com/channels/831462531327328276/1110662056622964860/1253769540710567977
Check list: