Implementation of MegaPortraits using Claude Opus.
All models and code are in model.py.
(Beware when working with Claude: if you paste copyrighted ResNet code, it will refuse to continue — it "spits the dummy".)
Memory debugging:
mprof run train.py
or just run:
python train.py
Data pipeline steps: warp / crop / spline / remove background / transforms.
For now, to simplify the problem, use the 4 videos in the junk folder. Once the models are validated, point video_dir at the CelebV-HQ dataset from the torrent (magnet link below).
# video_dir: '/Downloads/CelebV-HQ/celebvhq/35666'
video_dir: './junk'
Preprocessing takes 1–2 minutes per video, so the preprocessed results are saved in .npz format for faster reloading.
You can download the CelebV-HQ dataset via the magnet link below or from Academic Torrents.
magnet:?xt=urn:btih:843b5adb0358124d388c4e9836654c246b988ff4&dn=CelebV-HQ&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=https%3A%2F%2Fipv6.academictorrents.com%2Fannounce.php
Base model (Gbase)
Appearance encoder (Eapp): Encodes the appearance of the source frame into volumetric features and a global descriptor.
class Eapp(nn.Module):
    """Appearance encoder (Eapp).

    Encodes the appearance of the source frame into volumetric features and
    a global descriptor.
    """

    # Architecture details omitted for brevity; the full implementation is in model.py.
Motion encoder (Emtn): Encodes the motion from both the source and driving images into head rotations, translations, and latent expression descriptors.
class Emtn(nn.Module):
    """Motion encoder (Emtn).

    Encodes the motion from both source and driving images into head
    rotations, translations, and latent expression descriptors.
    """

    # Architecture details omitted for brevity; the full implementation is in model.py.
Warping generators (Wsrc_to_can and Wcan_to_drv): Remove the motion from the source features and impose the driver's motion onto the canonical features.
class WarpGenerator(nn.Module):
    """Warping generator (used for Wsrc_to_can and Wcan_to_drv).

    Removes motion from the source and imposes driver motion onto canonical
    features.
    """

    # Architecture details omitted for brevity; the full implementation is in model.py.
3D network (G3D): Processes the canonical volumetric features.
class G3D(nn.Module):
    """G3D: processes canonical volumetric features."""

    # Architecture details omitted for brevity; the full implementation is in model.py.
2D generator (G2D): Projects the 3D features into 2D and generates the output image.
class G2D(nn.Module):
    """G2D: projects 3D features into 2D and generates the output image."""

    # Architecture details omitted for brevity; the full implementation is in model.py.
High-resolution model (GHR)
class EncoderHR(nn.Module):
    """Encoder component of the high-resolution model (GHR)."""

    # Architecture details omitted for brevity; the full implementation is in model.py.
class DecoderHR(nn.Module):
    """Decoder component of the high-resolution model (GHR)."""

    # Architecture details omitted for brevity; the full implementation is in model.py.
Student model
class ResNet18(nn.Module):
    """ResNet-18 network (listed under the student model)."""

    # Architecture details omitted for brevity; the full implementation is in model.py.
class SPADEGenerator(nn.Module):
    """SPADE generator (listed under the student model)."""

    # Architecture details omitted for brevity; the full implementation is in model.py.
class VGG16Backbone(nn.Module):
    """VGG-16 backbone (listed under the student model)."""

    # Architecture details omitted for brevity; the full implementation is in model.py.
class KeypointNet(nn.Module):
    """Keypoint network (listed under the student model)."""

    # Architecture details omitted for brevity; the full implementation is in model.py.
class GazeHead(nn.Module):
    """Gaze head (listed under the student model)."""

    # Architecture details omitted for brevity; the full implementation is in model.py.
class BlinkHead(nn.Module):
    """Blink head (listed under the student model)."""

    # Architecture details omitted for brevity; the full implementation is in model.py.
train_base(cfg, Gbase, Dbase, dataloader): Trains the base model using perceptual, adversarial, and cycle-consistency losses.
def train_base(cfg, Gbase, Dbase, dataloader):
    """Train the base model.

    Uses perceptual, adversarial, and cycle-consistency losses.

    Args:
        cfg: Training configuration.
        Gbase: Base generator to train.
        Dbase: Discriminator trained against ``Gbase``.
        dataloader: Source of training batches.
    """
    # Training code omitted for brevity.
train_hr(cfg, GHR, Dhr, dataloader): Trains the high-resolution model using super-resolution objectives and adversarial losses.
def train_hr(cfg, GHR, Dhr, dataloader):
    """Train the high-resolution model.

    Uses super-resolution objectives and adversarial losses.

    Args:
        cfg: Training configuration.
        GHR: High-resolution model to train.
        Dhr: Discriminator trained against ``GHR``.
        dataloader: Source of training batches.
    """
    # Training code omitted for brevity.
train_student(cfg, Student, GHR, dataloader): Distills the high-resolution model into a student model for faster inference.
def train_student(cfg, Student, GHR, dataloader):
    """Distill the high-resolution model into a student model.

    The student is intended for faster inference.

    Args:
        cfg: Training configuration.
        Student: Student model to train.
        GHR: Trained high-resolution model acting as the teacher.
        dataloader: Source of training batches.
    """
    # Training code omitted for brevity.
Implementation:
def main(cfg: OmegaConf) -> None:
    """Run the three training stages end to end.

    Stages, in order:
      1. ``train_base`` trains the base generator ``Gbase`` against ``Dbase``.
      2. ``train_hr`` trains the high-resolution model ``GHR``; its inner
         ``GHR.Gbase`` is initialised from the stage-1 ``Gbase`` weights.
      3. ``train_student`` distills ``GHR`` into ``Student``.

    Each model's ``state_dict`` is written to ``Gbase.pth`` / ``GHR.pth`` /
    ``Student.pth`` in the current working directory as soon as its own stage
    finishes, so a failure in a later stage does not discard earlier results.

    Args:
        cfg: Loaded OmegaConf config. Reads ``cfg.data.train_width``,
            ``cfg.data.train_height`` and ``cfg.training.n_sample_frames``,
            ``sample_rate``, ``video_dir``, ``json_file``. Optionally reads
            ``cfg.training.batch_size`` and ``cfg.training.num_workers``
            (both default to 4).
    """
    use_cuda = torch.cuda.is_available()
    # NOTE(review): `device` is not used below -- no model is moved onto it.
    # Presumably the train_* functions handle device placement; confirm.
    device = torch.device("cuda" if use_cuda else "cpu")

    # Augmentations run on the [0, 1] tensor produced by ToTensor and *before*
    # Normalize. ColorJitter clamps float tensors to [0, 1], so applying it after
    # Normalize (range [-1, 1]) would erase all negative values once any jitter
    # strength is configured. With the default ColorJitter() (a no-op) the output
    # is unchanged, and the horizontal flip commutes with per-channel Normalize.
    # NOTE(review): if EMODataset applies `transform` frame by frame, the random
    # flip is drawn independently per frame -- confirm that is intended.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(),
        transforms.Normalize([0.5], [0.5]),
    ])

    dataset = EMODataset(
        use_gpu=use_cuda,
        width=cfg.data.train_width,
        height=cfg.data.train_height,
        n_sample_frames=cfg.training.n_sample_frames,
        sample_rate=cfg.training.sample_rate,
        img_scale=(1.0, 1.0),
        video_dir=cfg.training.video_dir,
        json_file=cfg.training.json_file,
        transform=transform,
    )

    # Loader settings are configurable; defaults match the previous hard-coded values.
    # NOTE(review): the dataset gets use_gpu=use_cuda; if it touches CUDA inside
    # DataLoader worker processes, forked workers can fail -- confirm.
    batch_size = cfg.training.get("batch_size", 4)
    num_workers = cfg.training.get("num_workers", 4)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    # Stage 1: base model.
    Gbase = model.Gbase()
    Dbase = model.Discriminator()
    train_base(cfg, Gbase, Dbase, dataloader)
    torch.save(Gbase.state_dict(), 'Gbase.pth')

    # Stage 2: high-resolution model, initialised from the trained base generator.
    GHR = model.GHR()
    GHR.Gbase.load_state_dict(Gbase.state_dict())
    Dhr = model.Discriminator()
    train_hr(cfg, GHR, Dhr, dataloader)
    # NOTE(review): GHR is saved before distillation, assuming train_student only
    # reads it as a teacher and does not update its weights -- confirm.
    torch.save(GHR.state_dict(), 'GHR.pth')

    # Stage 3: distill GHR into the student model for faster inference.
    Student = model.Student(num_avatars=100)
    train_student(cfg, Student, GHR, dataloader)
    torch.save(Student.state_dict(), 'Student.pth')
if __name__ == "__main__":
    # Script entry point: load the stage-1 training config and run the full pipeline.
    # The config path is relative, so run this from the repository root.
    config = OmegaConf.load("./configs/training/stage1-base.yaml")
    main(config)
rome/losses: cherry-picked from https://github.com/SamsungLabs/rome
Download the pretrained ResNet-18 weights into state_dicts/:
wget 'https://download.pytorch.org/models/resnet18-5c106cde.pth' -P state_dicts
Install rt_gene:
git clone https://github.com/Tobias-Fischer/rt_gene.git
cd rt_gene/rt_gene
pip install .