RitaRamo / smallcap

SmallCap: Lightweight Image Captioning Prompted with Retrieval Augmentation

Questions about TrainDataset (src/utils.py L54) #16

Closed junha1125 closed 3 months ago

junha1125 commented 3 months ago

I love your work. Thank you for sharing the code.

I have a question about how the inputs and labels are loaded in the TrainDataset class. When self.rag = False, the outputs of the __getitem__ method were as follows:

prefix = 'This image shows'
prefix_ids = [1212, 2939, 2523, 220]
text = caption = 'people stand facing a person sitting on a horse'
text_ids = [15332, 1302, 6476, 257, 1048, 5586, 319, 257, 8223 ...]

input_ids = [1212, 2939, 2523, 220, 15332, 1302, 6476, 257, 1048, 5586, 319, 257, 8223, 0, ...]
label_ids = [-100, -100, -100, 15332, 1302, 6476, 257, 1048, 5586, 319, 257, 8223, 13, -100, ...]
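Reading the two printed sequences side by side (a quick check using only the values quoted above, with the trailing padding dropped) shows that every non-masked label lines up with the next input token:

```python
# Values copied from the printed __getitem__ outputs above
# (13 is GPT-2's token for '.'; the trailing padding is omitted).
input_ids = [1212, 2939, 2523, 220, 15332, 1302, 6476, 257, 1048, 5586, 319, 257, 8223]
label_ids = [-100, -100, -100, 15332, 1302, 6476, 257, 1048, 5586, 319, 257, 8223, 13]

# Every non-masked label equals the NEXT input token, i.e. these labels
# are already shifted one position to the left relative to input_ids.
assert all(lab == -100 or lab == input_ids[i + 1]
           for i, lab in enumerate(label_ids[:-1]))
```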

However, according to modeling_gpt2.py#L1330 in Hugging Face Transformers:

loss = None
if labels is not None:
    # move labels to correct device to enable model parallelism
    labels = labels.to(lm_logits.device)
    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

The loss is computed from shift_logits and shift_labels, i.e. the labels are shifted internally by one position.

Therefore, as I understand it, the input_ids and label_ids should be as follows:

input_ids = [1212, 2939, 2523, 220, 15332, 1302, 6476, 257, 1048, 5586, 319, 257, 8223, 0, ...]
label_ids = [-100, -100, -100, -100, 15332, 1302, 6476, 257, 1048, 5586, 319, 257, 8223, -100, ...]

Note that '15332, 1302, 6476, ...' should then occupy the same positions in both variables.

Could you please point out where I might have misunderstood?

junha1125 commented 3 months ago

I understand now! It is because of src/vision_encoder_decoder.py L510. Sorry to bother you, and thank you 😊
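For anyone landing here later: a plausible reading (a sketch, not the actual SmallCap code at vision_encoder_decoder.py L510) is that the loss there is computed without GPT-2's internal shift, in which case the dataset's pre-shifted labels produce exactly the same (position, target) pairs as unshifted labels would under the Hugging Face shift:

```python
# Hypothetical illustration, NOT the actual SmallCap code.
# Token ids are taken from the issue, truncated for brevity.
input_ids        = [1212, 2939, 2523, 220, 15332, 1302, 6476]
labels_preshift  = [-100, -100, -100, 15332, 1302, 6476, -100]  # as TrainDataset builds them
labels_unshifted = [-100, -100, -100, -100, 15332, 1302, 6476]  # as proposed in the issue

# Loss with NO internal shift, on pre-shifted labels:
pairs_no_shift = [(i, t) for i, t in enumerate(labels_preshift) if t != -100]
# Loss WITH the Transformers shift (labels[..., 1:]), on unshifted labels:
pairs_hf_shift = [(i, t) for i, t in enumerate(labels_unshifted[1:]) if t != -100]

# Both supervise the same positions with the same targets.
assert pairs_no_shift == pairs_hf_shift
```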