cswry / OSEDiff


When trying to adapt to SDXL, I am unable to properly pass time_ids to UNet2DConditionModel #14

Open saiFarmer opened 1 month ago

saiFarmer commented 1 month ago
    elif self.config.addition_embed_type == "text_time":
        # SDXL - style
        if "text_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
            )
        text_embeds = added_cond_kwargs.get("text_embeds")
        if "time_ids" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
            )
        time_ids = added_cond_kwargs.get("time_ids")
        time_embeds = self.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(add_embeds)

        model_pred = self.unet(
            lq_latent,
            self.timesteps,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={
                "text_embeds": prompt_embeds,
                "time_ids": self.timesteps,
            },
        ).sample

I am passing self.timesteps as a placeholder for time_ids; I know this is not the correct value.

cswry commented 1 month ago

Hello, the SDXL and SD2 UNets are different architectures, so the LoRA weights cannot be applied directly to SDXL.
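
On the time_ids question itself: in the "text_time" branch quoted above, time_ids is not the diffusion timestep. In the diffusers SDXL pipeline it is the size conditioning (original size, crop top-left, target size), six numbers per sample, and text_embeds is the pooled embedding from the second text encoder rather than the per-token prompt_embeds. Below is a minimal sketch of how those values could be assembled. The helper names and the constants 256 and 1280 are assumptions for illustration based on the SDXL base config, so check unet.config on your own checkpoint.

```python
from typing import List, Tuple

# Assumed values from the SDXL base UNet config; verify against unet.config.
ADDITION_TIME_EMBED_DIM = 256  # unet.config.addition_time_embed_dim
POOLED_TEXT_EMBED_DIM = 1280   # projection dim of the second text encoder


def build_sdxl_time_ids(
    original_size: Tuple[int, int],
    crop_top_left: Tuple[int, int],
    target_size: Tuple[int, int],
    batch_size: int,
) -> List[List[float]]:
    """Hypothetical helper: one row of six size-conditioning values per sample."""
    row = [float(v) for v in (*original_size, *crop_top_left, *target_size)]
    return [list(row) for _ in range(batch_size)]


def expected_add_embedding_input_dim(
    num_time_ids: int = 6,
    time_embed_dim: int = ADDITION_TIME_EMBED_DIM,
    pooled_dim: int = POOLED_TEXT_EMBED_DIM,
) -> int:
    """Width of torch.concat([text_embeds, time_embeds], dim=-1) in the branch above."""
    return pooled_dim + num_time_ids * time_embed_dim


# Example: a batch of two 1024x1024 images with no cropping.
time_ids = build_sdxl_time_ids((1024, 1024), (0, 0), (1024, 1024), batch_size=2)
add_embed_dim = expected_add_embedding_input_dim()
```

To use this, convert time_ids with torch.tensor(time_ids, dtype=prompt_embeds.dtype, device=prompt_embeds.device) and pass it as added_cond_kwargs["time_ids"], with the pooled text embedding of shape (batch, 1280) as "text_embeds". If add_embed_dim does not match unet.config.projection_class_embeddings_input_dim, self.add_embedding will fail with a shape error. This only makes the forward call well-formed; it does not make LoRA weights trained on the SD2 UNet compatible with SDXL.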