johndpope / Emote-hack

Emote Portrait Alive - using AI to reverse-engineer code from the white paper. (abandoned)
https://github.com/johndpope/VASA-1-hack

freezing the FramesEncoderVAE - claudeai #25

Closed by johndpope 8 months ago

johndpope commented 8 months ago

[Screenshots from 2024-03-23 attached]

BEFORE


import torch
import torch.nn as nn
from diffusers import AutoencoderKL
# SpeedEncoder and ReferenceNet are the repo's own modules

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class FramesEncodingVAE(nn.Module):
    def __init__(self, config):
        super(FramesEncodingVAE, self).__init__()
        self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
        self.vae.to(device)  # Move the model to the appropriate device (e.g., GPU)
        self.img_size = config.data.train_height

        self.speed_encoder = SpeedEncoder(config.num_speed_buckets, config.speed_embedding_dim)

        # Create a dummy input tensor to infer the number of latent channels
        dummy_input = torch.randn(1, 3, config.data.train_height, config.data.train_height).to(device)  # Move the dummy input to the same device as the VAE
        with torch.no_grad():
            latent_vector = self.vae.encode(dummy_input).latent_dist.sample()
        latent_channels = latent_vector.shape[1]  # Get the number of latent channels from the second dimension

        self.reference_net = ReferenceNet(self.vae, self.speed_encoder, config, latent_channels)

    def forward(self, reference_image, motion_frames, speed_value):
        # Encode reference and motion frames
        reference_latents = self.vae.encode(reference_image).latent_dist.sample()
        motion_latents = self.vae.encode(motion_frames).latent_dist.sample()

        # Scale the latent vectors (optional, depends on the VAE scaling factor)
        reference_latents = reference_latents * 0.18215
        motion_latents = motion_latents * 0.18215

        # Process reference features with ReferenceNet
        reference_features = self.reference_net(reference_latents, motion_latents, speed_value)

        # Embed speed value
        speed_embedding = self.speed_encoder(speed_value)

        # Combine features
        combined_features = torch.cat([reference_features, motion_latents, speed_embedding], dim=1)

        # Decode the combined features
        reconstructed_frames = self.vae.decode(combined_features).sample

        return reconstructed_frames

    def vae_loss(self, recon_frames, reference_image, motion_frames):
        # Compute VAE loss using the VAE's loss function
        loss = self.vae.loss_function(recon_frames, torch.cat([reference_image, motion_frames], dim=1))
        return loss["loss"]

REDUCED TO

class FramesEncodingVAE(nn.Module):
    def __init__(self, config):
        super(FramesEncodingVAE, self).__init__()
        self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
        self.vae.to(device)
        self.img_size = config.data.train_height

    def forward(self, reference_image, motion_frames):
        # Encode reference and motion frames
        reference_latents = self.vae.encode(reference_image).latent_dist.sample()
        motion_latents = self.vae.encode(motion_frames).latent_dist.sample()

        # Scale the latent vectors (optional, depends on the VAE scaling factor)
        reference_latents = reference_latents * 0.18215
        motion_latents = motion_latents * 0.18215

        # Combine the reference and motion latents
        combined_latents = torch.cat([reference_latents, motion_latents], dim=1)

        # Decode the combined latents
        reconstructed_frames = self.vae.decode(combined_latents).sample

        return reconstructed_frames

    def vae_loss(self, recon_frames, reference_image, motion_frames):
        # Compute VAE loss using the VAE's loss function
        loss = self.vae.loss_function(recon_frames, torch.cat([reference_image, motion_frames], dim=1))
        return loss["loss"]
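
Even the reduced version looks shaky to me: sd-vae-ft-mse works with 4 latent channels, so concatenating the reference and motion latents on dim=1 hands the decoder 8 channels, and I don't think diffusers' AutoencoderKL exposes a loss_function method at all. A quick shape check (throwaway code, not from the repo):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

with torch.no_grad():
    ref = torch.randn(1, 3, 512, 512)
    motion = torch.randn(1, 3, 512, 512)
    ref_latents = vae.encode(ref).latent_dist.sample()          # [1, 4, 64, 64]
    motion_latents = vae.encode(motion).latent_dist.sample()    # [1, 4, 64, 64]
    combined = torch.cat([ref_latents, motion_latents], dim=1)  # [1, 8, 64, 64]
    # vae.decode(combined) fails here: the decoder's post_quant_conv
    # expects the VAE's 4 latent channels, not 8.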
johndpope commented 8 months ago

I think we can throw away this class.

In training_stage_1.py I will simply pass the reference image / motion frames through the VAE (which is already a frozen model): self.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

This reduces the images from 512x512 to 64x64 latents (the VAE downsamples spatially by 8x).

Then feed those latents through ReferenceNet: https://github.com/johndpope/Emote-hack/blob/main/train_stage_1_0.py#L144
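
Roughly what I have in mind for stage 1 (just a sketch, assuming ReferenceNet keeps the (reference_latents, motion_latents, speed_value) signature from the forward() above):

import torch
from diffusers import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"

# frozen, pretrained VAE: no gradients, eval mode
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
vae.requires_grad_(False)
vae.eval()

def encode_to_latents(images):
    # images: [B, 3, 512, 512] -> latents: [B, 4, 64, 64]
    with torch.no_grad():
        latents = vae.encode(images).latent_dist.sample()
    return latents * 0.18215  # SD latent scaling factor

# in the training loop, only reference_net gets gradients:
# ref_latents = encode_to_latents(reference_image)
# motion_latents = encode_to_latents(motion_frames)
# features = reference_net(ref_latents, motion_latents, speed_value)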

https://github.com/huggingface/diffusers/issues/3726

posterior = vae.encode(target).latent_dist
z = posterior.mode()
pred = vae.decode(z).sample
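
posterior.mode() is just the mean of the diagonal Gaussian, so the encode step becomes deterministic; .sample() adds noise scaled by the predicted std. For a frozen VAE used as a feature extractor, something like this (same vae/target as in the snippet above):

with torch.no_grad():
    posterior = vae.encode(target).latent_dist
    z = posterior.mode()       # deterministic: posterior mean, no sampling noise
    # z = posterior.sample()   # stochastic alternative: mean + std * noise
    pred = vae.decode(z).sample

Using mode() keeps the latents reproducible across runs, which should make the frozen encoder path easier to debug.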