Hello, @YuliaRubanova!
I found a small bug here. This line
idx_not_nan = 1 - torch.isnan(mortality_label)
causes a RuntimeError with PyTorch 1.4.0 because torch.isnan returns a torch.BoolTensor, and subtracting a bool tensor from an int is not supported.
Everything works if this line is replaced with the following:
idx_not_nan = ~torch.isnan(mortality_label)
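Below is a minimal standalone sketch of the failure and the fix. It assumes a small made-up `mortality_label` tensor containing one NaN; the values are hypothetical and only stand in for the real labels. It also shows that the resulting bool mask can index the labels directly to drop the NaN entries.

```python
import torch

# Hypothetical labels with one missing (NaN) entry, standing in for mortality_label.
mortality_label = torch.tensor([0.0, 1.0, float("nan"), 1.0])

# torch.isnan returns a tensor of dtype torch.bool.
nan_mask = torch.isnan(mortality_label)

# Original form: subtracting a bool tensor from an int raises RuntimeError.
try:
    idx_not_nan = 1 - nan_mask
except RuntimeError as err:
    print("RuntimeError:", err)

# Fixed form: ~ performs element-wise logical NOT on a bool tensor.
idx_not_nan = ~nan_mask

# The bool mask selects only the non-NaN labels.
labels_clean = mortality_label[idx_not_nan]
print(idx_not_nan)
print(labels_clean)
```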