arthurdouillard / incremental_learning.pytorch

A collection of incremental learning paper implementations including PODNet (ECCV20) and Ghost (CVPR-W21).
MIT License

Where can I change the targets value to torch.long? #43

Closed bilan-elaine-gao closed 3 years ago

bilan-elaine-gao commented 3 years ago

I keep getting this error: "tensors used as indices must be long, byte or bool tensors" at lib/losses/base.py, line 77.

I guess I need to change the targets tensor dtype from torch.int32 to torch.int64.

But I can't find the place where the variable targets is originally set...

arthurdouillard commented 3 years ago

You probably have this error because you're using a PyTorch version higher than 1.2, right? My code was tested with 1.2 and may not work as well with later versions.

Weird that the error comes from line 77 (https://github.com/arthurdouillard/incremental_learning.pytorch/blob/master/inclearn/lib/losses/base.py#L77), since no indexing is done there. It may have arisen from a few lines earlier, here: https://github.com/arthurdouillard/incremental_learning.pytorch/blob/master/inclearn/lib/losses/base.py#L73 .

So if you replace line 73 with:

margins[torch.arange(margins.shape[0]).long(), targets.long()] = margin

It should work, although it's a bit ugly. Can you confirm? :)
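To make the suggestion concrete, here is a minimal, self-contained sketch of that assignment with both index tensors cast to long. The function name build_margins, the tensor shapes, and the zero-initialised margins tensor are assumptions for illustration only; they are not taken from base.py.

```python
import torch


def build_margins(similarities, targets, margin):
    """Hypothetical helper: put `margin` at each row's target column.

    Only the indexed assignment mirrors the line suggested above; the
    zero initialisation is an assumption, not the repository's code.
    """
    margins = torch.zeros_like(similarities)
    # Cast both index tensors to int64 so the indexing is accepted
    # regardless of the dtype `targets` arrived with.
    rows = torch.arange(margins.shape[0], device=margins.device).long()
    margins[rows, targets.long()] = margin
    return margins


# Toy data: 2 samples, 3 classes, targets deliberately stored as int32.
similarities = torch.tensor([[0.9, 0.1, 0.3], [0.2, 0.8, 0.5]])
targets = torch.tensor([0, 1], dtype=torch.int32)
margins = build_margins(similarities, targets, margin=0.5)
```

Note that this only fixes the one indexing site; any other place that indexes with the same targets tensor would need the same cast.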

bilan-elaine-gao commented 3 years ago

Hi, sorry for the late response.

Yes, my PyTorch version is 1.8.

I tried your suggested code, but it didn't work: I got the same error on line 81, "targets] = similarities[torch.arange(len(similarities)), targets]".

Then I added one line above line 307 of inclearn/models/icarl.py (https://github.com/arthurdouillard/incremental_learning.pytorch/blob/0d25c2e12bde4a4a25f81d5e316751c90e6f789b/inclearn/models/icarl.py#L307):

"targets=torch.tensor(targets, dtype=torch.long)"

And it worked!
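For anyone landing here later, a rough sketch of why converting once, before the loss is called, covers every indexing site at the same time. The helper name as_long_targets and the toy tensors are made up for illustration. One assumption worth flagging: if targets is already a tensor, torch.tensor(targets, ...) makes a copy and, as far as I know, emits a warning recommending clone().detach(), so the sketch uses .long() for tensors and torch.as_tensor for everything else.

```python
import torch


def as_long_targets(targets):
    """Hypothetical helper: return `targets` as an int64 tensor."""
    if torch.is_tensor(targets):
        # .long() is a no-op when the tensor is already int64.
        return targets.long()
    # Lists and similar array-likes are converted here.
    return torch.as_tensor(targets, dtype=torch.long)


# Toy data: 3 samples, 2 classes, targets stored as int32.
similarities = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
int32_targets = torch.tensor([0, 1, 0], dtype=torch.int32)

# Convert once, then reuse the result at every indexing site.
targets = as_long_targets(int32_targets)
picked = similarities[torch.arange(len(similarities)), targets]
```

The same call also accepts a plain list, e.g. as_long_targets([0, 1, 0]).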

arthurdouillard commented 3 years ago

Great!

Yes, later versions of PyTorch became stricter about the dtype of tensors used for indexing, so in that case you just need a .long() or .bool().