NickleDave / visual-search-nets

neural network models of visual search behavior
BSD 3-Clause "New" or "Revised" License

`onehot_from_class_ints` raises RuntimeError about index/src dimension mismatch #89

Closed NickleDave closed 3 years ago

NickleDave commented 3 years ago

When trying to run an experiment with the VSD dataset, I get a RuntimeError, but only during the validation step:

  File "/home/bart/Documents/repos/coding/L2M/visual-search-nets/src/searchnets/transforms/functional.py", line 229, in onehot_from_class_ints
    return onehot.scatter_(0, torch.LongTensor(class_ints), 1)
RuntimeError: Expected index [30] to be smaller than self [20] apart from dimension 0 and to be smaller size than src [20]

Looking at the transforms to refresh my memory, I see that for the VSD dataset, `transforms.util.get_transforms` returns the following:

            target_transform = vis_transforms.Compose([
                transforms.ParseVocXml(),
                transforms.ClassIntsFromXml(),
                transforms.OneHotFromClassInts(),
            ])

Stepping through with the debugger, I see what's happening: for a certain image annotation, there are multiple instances of some classes, which makes the number of `class_ints` for this particular image greater than the total number of classes:

(Pdb) class_ints
[6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 5, 13, 13, 13, 13, 14, 14, 13, 14, 14, 14, 14]
(Pdb) n_classes
20
(Pdb) onehot.shape
torch.Size([20])
(Pdb) onehot
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
(Pdb) onehot.scatter_(0, torch.LongTensor(class_ints), 1)
*** RuntimeError: Expected index [30] to be smaller than self [20] apart from dimension 0 and to be smaller size than src [20]

i.e., for this particular image, '/home/bart/Documents/data/voc/VOCdevkit/VOC2012/Annotations/2008_005139.xml', there are multiple instances of cars, persons, and whatever the other classes are.
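To convince myself this is the mechanism, here is a torch-free sketch of what `scatter_(0, index, 1)` does along dim 0 (the emulation below is mine, not torch's implementation): it writes 1 at each position in `index`. Duplicate indices just overwrite the same slot with 1 again, so the result is well-defined; the problem is only that recent torch versions reject an index tensor longer than the destination before doing any writes.

```python
# Emulate onehot.scatter_(0, index, 1) along dim 0: dest[index[i]] = 1.
# Duplicates are semantically harmless -- they re-write the same slot.
def scatter_ones(dest, index):
    for i in index:
        dest[i] = 1.0
    return dest

n_classes = 20
# The 30 class ints from the pdb session above: more ints than classes.
class_ints = [6] * 18 + [5] + [13, 13, 13, 13, 14, 14, 13, 14, 14, 14, 14]
assert len(class_ints) == 30  # longer than the 20-element one-hot

onehot = scatter_ones([0.0] * n_classes, class_ints)
# The result is still a valid one-hot over 4 classes...
assert [i for i, v in enumerate(onehot) if v] == [5, 6, 13, 14]
# ...it's only torch's up-front size check on the index that trips.
```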

When I re-run, it's always this particular image that triggers the bug.

So I understand why the bug happens, but why didn't it happen before?

One explanation would be that the functions I'm using to create the one-hot vector changed in newer versions of torch. I'm using scatter_ (the "in-place" version) to generate the Tensor. I dug through the torch CHANGELOG but didn't find any obvious change or bugfix that could account for this.

I can avoid the bug if I just remove duplicates from the class ints:

(Pdb) class_ints_uniq = list(set(class_ints))
(Pdb) onehot.scatter_(0, torch.LongTensor(class_ints_uniq), 1)
tensor([0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0.,
        0., 0.])
(Pdb) class_ints_uniq
[13, 5, 6, 14]

which, for the purposes of these experiments, is fine -- even for multi-label classification, all we care about is the presence or absence of each object class, not how many instances there are.
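That equivalence is easy to check directly: deduplicating the class ints never changes the resulting one-hot, it only shrinks the index list. A small torch-free sketch (helper name is mine, for illustration):

```python
# Presence/absence targets: duplicates in class_ints are redundant,
# because set() collapses them before any slot is written.
def onehot_presence(class_ints, n_classes):
    vec = [0.0] * n_classes
    for i in set(class_ints):
        vec[i] = 1.0
    return vec

# Same target whether or not the ints are deduplicated first.
assert onehot_presence([6, 6, 6, 5, 13, 13, 14, 14], 20) == \
       onehot_presence([5, 6, 13, 14], 20)
```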

I think we were "uniquifying" the class ints at some point?

I haven't regenerated the dataset splits, so this one file that causes the crash during the validation step should always have been causing it?

I kind of don't care, since I'm only doing single-label classification for these final experiments, and this transform only gets applied "incidentally" -- the actual target we use for training is "what is the largest object; call that the target". But if this causes crashes then nothing will run (all the transforms get computed for each batch ... an accident of history).

So the quick fix is to uniquify the class ints somehow.

I guess I could add a boolean flag 'unique' to the transforms.ClassIntsFromXml class and the corresponding function. That way, if needed later, I could let a user specify whether or not to make the ints unique. The other option is to always do it inside the OneHotFromClassInts class, since this transform requires the ints to be unique to avoid crashing. The latter is less complicated for now, so I'll do it that way.
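For the record, the "latter way" could look roughly like this. The class name matches the repo's transform, but this minimal body is my assumption (and torch-free for illustration; the real transform would build the vector with torch.zeros and scatter_ on the deduplicated ints):

```python
# Hedged sketch: build deduplication into OneHotFromClassInts itself,
# so the index list can never be longer than the one-hot vector.
class OneHotFromClassInts:
    def __init__(self, n_classes=20):
        self.n_classes = n_classes

    def __call__(self, class_ints):
        onehot = [0.0] * self.n_classes
        for class_int in set(class_ints):  # presence, not count
            onehot[class_int] = 1.0
        return onehot

transform = OneHotFromClassInts(n_classes=20)
target = transform([6, 6, 5, 13, 13, 14])  # duplicates are now harmless
assert sum(target) == 4.0
```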