idan-tankel / SemOOD


Use the captioning loss instead of the conditionalGenerationLoss #7

Open idan-tankel opened 11 months ago

idan-tankel commented 11 months ago

Problem description

The model loss uses the caption as the instruction (via `input_ids`). In that case, the model then generates a "blob" of text that makes no sense, like `['</s> - stock image\n']`. The loss is then calculated on that kind of sentence, which does not make sense as a meaningful training signal.

However, there is no `BLIP2ForImageCaptioning`, and if we pass only the image through the `generate` function to create a zero-shot caption, no loss is calculated; the model only returns a tensor of output tokens.
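
A minimal sketch of the two code paths described above (the checkpoint, image, and caption are placeholder examples, not necessarily what this repo uses):

```python
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Example checkpoint; any BLIP-2 conditional-generation checkpoint behaves the same way.
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b")

image = Image.open("example.jpg")          # placeholder image
caption = "a photo of a cat on a couch"    # placeholder caption

# Path 1: forward() with the caption as both input_ids and labels.
# A loss is returned, but generation is conditioned on the caption itself.
inputs = processor(images=image, text=caption, return_tensors="pt")
out = model(**inputs, labels=inputs["input_ids"])
print(out.loss)  # scalar tensor

# Path 2: generate() with only the image (zero-shot captioning).
# No loss is returned, only a tensor of generated token ids.
pixel_inputs = processor(images=image, return_tensors="pt")
generated_ids = model.generate(**pixel_inputs)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```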

Questions / points to test

```python
if labels is not None:
    labels = labels.to(logits.device)
    logits = logits[:, -labels.size(1) :, :]
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous().to(logits.device)

    # Flatten the tokens
    loss_fct = CrossEntropyLoss(reduction="mean")

    loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
```

This is from the `model.forward()` method of the Hugging Face version of the model.

Could we apply this computation manually to the zero-shot output of `generate`? Is that a good solution, even though the loss would then be computed only on the output tokens and not on the embeddings?
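
One way to test this, reusing `model` and `pixel_inputs` from the sketch above: `generate` can return its per-step scores, and each `scores[i]` already predicts the next token, so no extra shift is needed (a hedged sketch, not the repo's implementation):

```python
import torch
from torch.nn import CrossEntropyLoss

gen = model.generate(
    **pixel_inputs,
    return_dict_in_generate=True,
    output_scores=True,
)

# gen.scores is a tuple with one (batch, vocab) tensor per generated step;
# gen.sequences also contains the leading BOS/prompt token(s).
logits = torch.stack(gen.scores, dim=1)        # (batch, gen_len, vocab)
labels = gen.sequences[:, -logits.size(1):]    # the tokens those scores predicted

loss_fct = CrossEntropyLoss(reduction="mean")
loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
```

Note that with greedy decoding this loss is trivially low, because the labels are the model's own argmax choices: it measures the model's confidence in its own caption, not agreement with a reference.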

Another observation

When using the `generate` method, the loss is not intended to be calculated by the external BLIP2 wrapper. Within the internal BLIP2 generation method, the inputs are modified by `prepare_inputs_for_generation`, and the labels are dropped!

idan-tankel commented 11 months ago

https://github.com/idan-tankel/SemOOD/blob/5660d8f90f57eb1c7fdf6939346f39ab8dd772c3/SEED-Bench/evaluator/BLIP2Models.py#L163-L164

Write a loss wrapper for `generate`, or use the scores

The `input_ids` that are given to `generate`, unlike in the regular `forward`, are random model parameters concatenated with the image, so trying to create a `model.forward` variant that does not take `input_ids` is going to be harder than just wrapping the loss for `generate`.
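
A sketch of such a wrapper (`generate_with_loss` is a hypothetical helper name; it ignores padding/EOS masking):

```python
import torch
import torch.nn.functional as F

def generate_with_loss(model, model_inputs, **generate_kwargs):
    """Run generate() and also return an NLL-style loss over the generated tokens."""
    gen = model.generate(
        **model_inputs,
        return_dict_in_generate=True,
        output_scores=True,
        **generate_kwargs,
    )
    logits = torch.stack(gen.scores, dim=1)        # (batch, gen_len, vocab)
    tokens = gen.sequences[:, -logits.size(1):]    # generated tokens only
    log_probs = F.log_softmax(logits, dim=-1)
    token_log_probs = log_probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    loss = -token_log_probs.mean()                 # mean NLL, no padding mask
    return gen.sequences, loss
```

This keeps `forward` untouched and works with arbitrary `generate_kwargs`, though beam search would additionally need the beam indices to align scores with the returned sequences.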

idan-tankel commented 11 months ago

The outputs of `forward` are conditioned on the `input_ids`, since:

idan-tankel commented 11 months ago

That being said, what about the hidden states of the vision part, which make up some of the sequence elements of the output hidden states? (In BLIP-2, they are 32 out of 51.)
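
A sketch of how one might separate those positions, reusing `model` and `inputs` from the first sketch. In the Hugging Face implementation, the `num_query_tokens` Q-Former outputs are prepended to the text embeddings before the language model, so the 32-out-of-51 split above would correspond to 32 query positions plus 19 caption-token positions (an assumption worth verifying against the actual shapes):

```python
out = model(**inputs, output_hidden_states=True, return_dict=True)

n_query = model.config.num_query_tokens                      # 32 for the public BLIP-2 checkpoints
last_hidden = out.language_model_outputs.hidden_states[-1]   # (batch, seq_len, dim)

vision_part = last_hidden[:, :n_query, :]  # positions fed from the Q-Former / vision side
text_part = last_hidden[:, n_query:, :]    # positions from the caption tokens
print(vision_part.shape, text_part.shape)
```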