Closed — jweihe closed this issue 1 month ago.
# NOTE(review): fragment from inside a multimodal embedding-merge routine
# (LLaVA/Vary-style). `inputs_embeds`, `input_ids`, `image_features`,
# `im_patch_token`, and `self.mm_projector*` come from the enclosing scope,
# which is not visible in this chunk; the loop body below is also truncated —
# only the "sample has no image tokens" branch is shown.

# Build zero-valued dummy features on the same device/dtype as the text
# embeddings. The hard-coded (256, 1024) presumably means 256 visual patch
# tokens of hidden width 1024 — TODO confirm against the vision encoder config.
dummy_image_features_1 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)

# Run the dummies through both projection heads and concatenate along the
# feature dim, mirroring how real image features are presumably combined.
dummy_image_features_1 = self.mm_projector(dummy_image_features_1)
dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2)
dummy_image_features = torch.cat((dummy_image_features_1, dummy_image_features_2), dim=-1)

use_im_start_end = True  # unconditionally True here; any config flag is ignored in this fragment
new_input_embeds = []

# Merge per sample: walk token ids, their embeddings, and the image features in lockstep.
for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
    if (cur_input_ids == im_patch_token).sum() == 0:
        # multimodal LLM, but the current sample is not multimodal
        # Adding (0. * dummy).sum() contributes exactly zero to the values
        # but keeps both projectors in the autograd graph for this sample —
        # presumably to avoid unused-parameter errors under DDP; confirm.
        cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
        new_input_embeds.append(cur_input_embeds)
        continue