oke-aditya / quickvision

An Easy To Use PyTorch Computer Vision Library
Apache License 2.0

[RFC] Loss functions in quickvision #43

Open oke-aditya opened 3 years ago

oke-aditya commented 3 years ago

🚀 Feature

Frequently re-used losses that can be added.

Motivation

Writing losses is quite repetitive. PyTorch's built-in losses are written with deep interoperability with the C++ API, but most losses from research papers are not available there.

These losses are also building blocks for other, more complicated losses.

Pitch

A non-exhaustive, tentative list of losses that are not in PyTorch but are used often.

Alternatives

Wait for them to land in fvcore or PyTorch. Until then, we keep duplicating this code across models.

Additional context

Note: if we reuse an implementation from any repo, please cite it at the top of the code.

hassiahk commented 3 years ago

Do we include sample_weight and reduction parameters as well, so the user can choose? Currently dice_loss does not consider sample_weight and uses reduction = sum, as shown here:
https://github.com/Quick-AI/quickvision/blob/c6bd28e8d77905c6e4f9538644df5d875af654d2/quickvision/losses/segmentation.py#L6
https://github.com/Quick-AI/quickvision/blob/c6bd28e8d77905c6e4f9538644df5d875af654d2/quickvision/losses/segmentation.py#L21

Below is part of huber_loss from the implementation you mentioned, which handles both sample weights (weights) and reduction (size_average).

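from typing import Optional
import torch

def huber_loss(input, target, delta: float = 1., weights: Optional[torch.Tensor] = None, size_average: bool = True):
    # rest of the code (computes the element-wise `loss` tensor)
    if weights is not None:
        loss *= weights  # optional per-sample weighting
    return loss.mean() if size_average else loss.sum()  # "mean" vs "sum" reduction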
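For reference, here is a self-contained sketch of a weighted Huber loss with the same signature. The body uses the standard Huber definition (quadratic for residuals below delta, linear above); it is an assumption about what the elided "rest of the code" does, not a copy of the referenced implementation.

```python
from typing import Optional

import torch


def huber_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    delta: float = 1.0,
    weights: Optional[torch.Tensor] = None,
    size_average: bool = True,
) -> torch.Tensor:
    # Absolute residual between prediction and target.
    diff = torch.abs(input - target)
    # Quadratic inside delta, linear outside (standard Huber definition).
    loss = torch.where(diff < delta, 0.5 * diff ** 2, delta * (diff - 0.5 * delta))
    if weights is not None:
        # Weights must be broadcastable to the element-wise loss.
        loss = loss * weights
    # size_average=True -> mean reduction, otherwise sum reduction.
    return loss.mean() if size_average else loss.sum()
```

Using `loss = loss * weights` instead of the in-place `loss *= weights` avoids mutating a tensor that autograd may still need.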

I will start working on this in parallel, if it is up for grabs. 😄

oke-aditya commented 3 years ago

Yes! I was about to bring these up. We will include weights and reduction.

Reduction can be none, sum or mean.
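One way to keep the three reduction modes consistent across every loss is a small shared helper. This is a minimal sketch; the name reduce_loss is hypothetical (not an existing quickvision function), and the accepted strings follow PyTorch's reduction='none' | 'sum' | 'mean' convention.

```python
import torch


def reduce_loss(loss: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
    """Apply a PyTorch-style reduction to an element-wise loss tensor."""
    if reduction == "none":
        return loss  # keep per-element values unchanged
    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        return loss.mean()
    raise ValueError(f"reduction must be 'none', 'sum' or 'mean', got {reduction!r}")
```

Each loss then computes its element-wise values, applies optional weights, and finishes with `return reduce_loss(loss, reduction)`.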

P.S. We also stay consistent with torchvision: if it gains a losses API, we will only provide the losses it does not cover. We won't duplicate torchvision losses unless necessary.

P.P.S. Up for grabs!!

hassiahk commented 3 years ago

What should be the structure for these losses? I can think of these three ways:

oke-aditya commented 3 years ago

@zhiqwang you are welcome to contribute 😄

zhiqwang commented 3 years ago

Hi @oke-aditya, I'm looking for anything I can do in this vigorous repo 🚀

oke-aditya commented 3 years ago

Join us on Slack here.

All development discussion happens there! You can freely share your ideas, RFCs and thoughts.

oke-aditya commented 3 years ago

Keeping each loss in a separate file is better; this keeps abstraction minimal.

We could create a folder called nn, but that would clash with torch.nn, which we want to avoid.

Also, should we implement losses as classes or functions?

If we implement them as classes, they should inherit from nn.Module; otherwise we can keep using them as plain functions. This is a slight point of confusion, and I see no clear advantage of one over the other. PyTorch implements them as classes (nn.Module subclasses) and also provides a functional API (torch.nn.functional).

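E.g.

import torch.nn as nn
import torch.nn.functional as F
loss_fn = nn.CrossEntropyLoss()      # class (module) API
loss_v = F.cross_entropy(inp, op)    # functional API; inp: logits, op: target class indices

I think we should follow that and provide a functional API too (if possible).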
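To illustrate the class/functional pairing in PyTorch itself: the module and the function compute the same value, so the class is just a convenient, configurable wrapper. The tensor shapes here are arbitrary example data.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)            # batch of 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 2])  # class indices

loss_fn = nn.CrossEntropyLoss()                # class (module) API
loss_class = loss_fn(logits, targets)
loss_func = F.cross_entropy(logits, targets)   # functional API

print(torch.allclose(loss_class, loss_func))  # True
```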

Losses tied to specific models can be implemented as a mixture of both, since they will only be used with that model; preferably along the lines of detr_loss.
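A minimal sketch of what "a mixture of both" could look like: an nn.Module that holds model-specific configuration (here, per-term weights) and delegates the math to functional losses. The class name ModelLoss, the "cls"/"box" keys and the choice of cross-entropy plus L1 are all hypothetical; this does not attempt to reproduce detr_loss, which also involves matching predictions to targets.

```python
from typing import Callable, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


class ModelLoss(nn.Module):
    """Hypothetical model-specific criterion built from functional losses."""

    def __init__(self, weight_dict: Dict[str, float]):
        super().__init__()
        # Model-specific configuration lives on the module.
        self.weight_dict = weight_dict
        # The actual computations are plain functional losses.
        self.loss_fns: Dict[str, Callable] = {
            "cls": F.cross_entropy,
            "box": F.l1_loss,
        }

    def forward(self, outputs, targets) -> Dict[str, torch.Tensor]:
        # Compute every term, then a weighted total.
        losses = {name: fn(outputs[name], targets[name]) for name, fn in self.loss_fns.items()}
        losses["total"] = sum(self.weight_dict[name] * losses[name] for name in self.loss_fns)
        return losses
```

Returning the individual terms alongside the total makes it easy to log each component during training.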

I propose keeping the API simple:

losses/
├── functional/
│   └── loss_dice.py
├── __init__.py
├── loss_dice.py
└── loss_detr.py

Thoughts @zhiqwang @hassiahk @ramaneswaran?

hassiahk commented 3 years ago

I agree with implementing losses as both classes and functions since it would be easier for people coming from torch to use them. We can just do this:

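import torch.nn as nn

def loss_function(input, target):
    loss = ...  # placeholder: compute the loss tensor here
    return loss

class LossFunction(nn.Module):
    def forward(self, input, target):
        # Thin wrapper that delegates to the functional form.
        return loss_function(input, target)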
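Applied to a concrete loss, the pattern (plus the weights and reduction options agreed on earlier) might look like the sketch below. The name dice_loss matches the existing quickvision function, but this signature is hypothetical; the +1 smoothing term (eps) and the per-sample weighting are assumptions chosen for illustration.

```python
from typing import Optional

import torch
import torch.nn as nn


def dice_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    reduction: str = "mean",
    eps: float = 1.0,
) -> torch.Tensor:
    # inputs: raw logits of shape (N, ...); targets: binary masks of the same shape.
    probs = inputs.sigmoid().flatten(1)
    targets = targets.flatten(1).float()
    numerator = 2 * (probs * targets).sum(1)
    denominator = probs.sum(1) + targets.sum(1)
    # One loss value per sample; eps smooths the ratio for empty masks.
    loss = 1 - (numerator + eps) / (denominator + eps)
    if weights is not None:
        loss = loss * weights  # per-sample weights of shape (N,)
    if reduction == "none":
        return loss
    if reduction == "sum":
        return loss.sum()
    if reduction == "mean":
        return loss.mean()
    raise ValueError(f"Unsupported reduction: {reduction!r}")


class DiceLoss(nn.Module):
    """Module wrapper so the loss can be used like the torch.nn losses."""

    def __init__(self, reduction: str = "mean", eps: float = 1.0):
        super().__init__()
        self.reduction = reduction
        self.eps = eps

    def forward(self, inputs, targets, weights=None):
        return dice_loss(inputs, targets, weights, self.reduction, self.eps)
```

Configuration such as reduction lives on the module, while per-batch data such as weights is passed to forward, mirroring how the torch.nn losses split constructor arguments from call arguments.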

I did not get this part. How would naming the folder nn interfere with torch?

If We could create a folder called nn but that would interfere with torch which we can avoid.

oke-aditya commented 3 years ago

If we call our folder nn, people will have to write from quickvision.nn import loss_fn, whereas nn is the name torch uses for code that is bound to the C++ API. So that name doesn't fit this purpose; to be on the safe side, we can keep the name losses.

hassiahk commented 3 years ago

Based on whatever we discussed above, I will start working on this and let's see how it goes.

oke-aditya commented 3 years ago

@zhiqwang @hassiahk I think we can start with this for release 2.0. Let's propose a template API for this.

hassiahk commented 3 years ago

@oke-aditya, are we implementing all the losses mentioned in the initial comment?

oke-aditya commented 3 years ago

Yes 😀