Spijkervet / CLMR

Official PyTorch implementation of Contrastive Learning of Musical Representations
https://arxiv.org/abs/2103.09410
Apache License 2.0

modules.linear_evaluation is missing parameters when calling Accuracy and AveragePrecision #24

Open · fducau opened this issue 1 year ago


When running the LinearEvaluation module, I get the following error:

TypeError: Accuracy.__new__() missing 1 required positional argument: 'task'

A similar TypeError is raised when torchmetrics.AveragePrecision is called without the task argument.

Proposed solution

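        # Newer torchmetrics releases require an explicit `task`; "multilabel"
        # treats each output unit as an independent binary (tag) decision.
        self.accuracy = torchmetrics.Accuracy(
            task="multilabel",
            num_labels=output_dim
        )
        # The task-based AveragePrecision no longer accepts `pos_label`
        # (passing it raises a ValueError), so it is omitted here.
        self.average_precision = torchmetrics.AveragePrecision(
            task="multilabel",
            num_labels=output_dim
        )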