Hello, this is my first issue so please bear with me.
TL;DR: Should `model.py` use `torch.tensor()` instead of `torch.Tensor()`?
Explanation: I was adapting the `reference_metric` tutorial to use PyTorch when I ran into a compatibility issue between the passed tensor and the loss function while creating the `Audit` object. Specifically, the loss function `nn.CrossEntropyLoss()` required a `torch.int64` tensor but was receiving a `torch.float32` tensor. This happened even though I passed a `dtype('int64')` NumPy array as the value of the `y` key in the `Dataset` object used by the `InformationSource`s in the `Audit` constructor. I traced the error back to `model.py`, which uses `torch.Tensor()` (which creates a `torch.FloatTensor`) instead of `torch.tensor()` (which infers the dtype of the tensor from the data). After replacing those calls with `torch.tensor()`, the program ran as expected.
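To illustrate the difference, here is a minimal standalone sketch. It is not the actual `model.py` code; the label array and random logits are made up for the example. It shows that `torch.Tensor()` turns int64 labels into float32, while `torch.tensor()` and `torch.from_numpy()` keep int64, which is what `nn.CrossEntropyLoss()` expects for class-index targets.

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical int64 class labels, like the `y` array passed to the Dataset
y = np.array([0, 2, 1], dtype=np.int64)

# Legacy constructor: always produces the default float dtype (float32)
legacy = torch.Tensor(y)
# Factory function: infers the dtype from the data, so int64 is preserved
inferred = torch.tensor(y)
# from_numpy also keeps the NumPy dtype (and shares memory with the array)
shared = torch.from_numpy(y)

print(legacy.dtype)    # torch.float32
print(inferred.dtype)  # torch.int64

# Made-up logits for 3 samples and 4 classes
logits = torch.randn(3, 4)
loss_fn = nn.CrossEntropyLoss()

# Class-index targets must be integer (int64) tensors
loss = loss_fn(logits, inferred)

# Casting the float32 tensor back with .long() also satisfies the loss
loss_from_cast = loss_fn(logits, legacy.long())
```

Casting with `.long()` would be one way to work around the issue on the caller's side, but it only helps where you control the tensor that reaches the loss function.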
Is there a way to make this work without editing the source? Or is `torch.Tensor()` just supposed to be `torch.tensor()`? Thanks!