pytorch / ignite

High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.
https://pytorch-ignite.ai
BSD 3-Clause "New" or "Revised" License

Introduce a variable skip_unrolling in class Metric #3258

Closed simeetnayan81 closed 3 months ago

simeetnayan81 commented 3 months ago

Fixes #2940

Description: Introduce a variable skip_unrolling in class Metric as discussed here https://discord.com/channels/831462531327328276/1110662056622964860/1253769540710567977

simeetnayan81 commented 3 months ago

Should the tests be added at the end of the test_metric.py file?

vfdev-5 commented 3 months ago

Yes, you can add them at the end of the file.

simeetnayan81 commented 3 months ago

skip_unrolling = False is already covered by all the prior tests. I have added a test for the case skip_unrolling = True. Kindly review and let me know if any changes are needed.
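
For readers following along, here is a rough sketch of the two code paths such tests would exercise. It assumes the flag decides whether an output made of two lists is split into one update() call per pair or passed to update() in a single call. ToyMetric, its `updates` attribute, and the sample data are hypothetical stand-ins for illustration only, not ignite's actual Metric implementation.

```python
from typing import Any, List, Sequence


class ToyMetric:
    """Simplified, hypothetical stand-in for a metric (not ignite's Metric)."""

    def __init__(self, skip_unrolling: bool = False) -> None:
        self.skip_unrolling = skip_unrolling
        self.updates: List[Any] = []  # records every argument passed to update()

    def update(self, output: Any) -> None:
        self.updates.append(output)

    def iteration_completed(self, output: Sequence[Any]) -> None:
        y_pred, y = output
        both_lists = isinstance(y_pred, list) and isinstance(y, list)
        if both_lists and not self.skip_unrolling:
            # Unroll: one update() call per (prediction, target) pair.
            for p, t in zip(y_pred, y):
                self.update((p, t))
        else:
            # Skip unrolling: pass the whole output in a single update() call.
            self.update(output)


output = ([1, 2, 3], [4, 5, 6])

unrolled = ToyMetric(skip_unrolling=False)
unrolled.iteration_completed(output)

skipped = ToyMetric(skip_unrolling=True)
skipped.iteration_completed(output)
```

In this sketch the default instance receives three per-pair updates, while the skip_unrolling=True instance receives the original two-list output once, which is the behaviour a metric needs if it is meant to consume the lists as a whole.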

simeetnayan81 commented 3 months ago

@vfdev-5 Before adding the example to the docstring, I wanted to confirm: to make skip_unrolling effective for the Loss metric, we might also need to change this: https://github.com/pytorch/ignite/blob/master/ignite/metrics/loss.py#L77

Previous:

def __init__(
    self,
    loss_fn: Callable,
    output_transform: Callable = lambda x: x,
    batch_size: Callable = len,
    device: Union[str, torch.device] = torch.device("cpu"),
):
    super(Loss, self).__init__(output_transform, device=device)

Change to:

def __init__(
    self,
    loss_fn: Callable,
    output_transform: Callable = lambda x: x,
    batch_size: Callable = len,
    device: Union[str, torch.device] = torch.device("cpu"),
    skip_unrolling: bool = False,
):
    # Forward the new flag to the base Metric constructor.
    super(Loss, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling)

vfdev-5 commented 3 months ago

@simeetnayan81 Yes, you are right: we need to add this new argument to every metric that defines its own constructor. Let's update the Loss metric here and update the other metrics in a follow-up PR.

simeetnayan81 commented 3 months ago

Things to do in a follow-up PR.

vfdev-5 commented 3 months ago

Thanks for the updates and the TODO. Can we do this point here?

Add test for updated Loss class

simeetnayan81 commented 3 months ago

Alright @vfdev-5

simeetnayan81 commented 3 months ago

I have made the changes; the new test passes locally.

simeetnayan81 commented 3 months ago

The test is failing because list[torch.Tensor, torch.Tensor] is only supported on Python 3.9 and above. Let me modify it a bit.
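
One possible way to keep such an annotation working on Python 3.8 is to use the generic aliases from the typing module instead of subscripting the built-in list. The sketch below is only an illustration of that idea: pair_outputs is a hypothetical helper, and plain floats stand in for torch.Tensor so the snippet runs without PyTorch. It is not the change actually made in this PR.

```python
from typing import List, Tuple


# Hypothetical helper for illustration; floats stand in for torch.Tensor.
# typing.List / typing.Tuple can be subscripted on Python 3.8, whereas
# subscripting the built-in list (list[float]) needs Python 3.9 or newer.
def pair_outputs(y_pred: List[float], y: List[float]) -> List[Tuple[float, float]]:
    if len(y_pred) != len(y):
        raise ValueError("y_pred and y must have the same length")
    # Pair each prediction with its target, preserving order.
    return list(zip(y_pred, y))


pairs = pair_outputs([0.1, 0.9], [0.0, 1.0])
```

The same substitution applies to test parametrization or any other place where the subscripted type is evaluated at runtime.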