yoyohonyang / LearingFaceAgeProgression

Learning Face Age Progression: A Pyramid Architecture of GANs, CVPR 2018
74 stars 14 forks source link

Running on CPU #6

Open MaggieMC opened 5 years ago

MaggieMC commented 5 years ago

Hi,

Since the default testing script runs on GPU, is there a way to easily change this to run only on CPU?

Thank you

ccsvd commented 3 years ago

Hi, do you know how to do it?

codingpy commented 2 years ago

Try this — at the very least, it lets you run the model with PyTorch instead of Lua:

import inspect

import torch

import torch.nn as nn

# https://github.com/bshillingford/python-torchfile

import torchfile

import numpy as np

from PIL import Image

class ConcatTable(nn.Module):
    """Feed the same input to every child module and collect the outputs.

    Port of Torch7 ``nn.ConcatTable``: the result is a plain Python list with
    one entry per child, in registration order (an empty list if there are
    no children).
    """

    def forward(self, x):
        return [branch(x) for branch in self.children()]

class ShaveImage(nn.Module):
    """Crop ``size`` pixels off every spatial border of the input.

    Height and width are read from dimensions 2 and 3, so the input is
    presumably a 4-D (N, C, H, W) tensor; the crop itself is applied to the
    last two dimensions.
    """

    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, x):
        margin = self.size
        height, width = x.size(2), x.size(3)
        return x[..., margin:height - margin, margin:width - margin]

class CAddTable(nn.Module):
    """Element-wise sum of a list of tensors (port of Torch7 ``nn.CAddTable``).

    Mirrors the built-in ``sum``: accumulation starts from the integer 0, so an
    empty list yields 0.
    """

    def forward(self, x):
        total = 0
        for term in x:
            total = total + term
        return total

class MulConstant(nn.Module):
    """Scale the input by a fixed scalar (port of Torch7 ``nn.MulConstant``)."""

    def __init__(self, constant_scalar):
        super().__init__()
        # Kept under the Torch7 field name so the converter can pass it through.
        self.constant_scalar = constant_scalar

    def forward(self, x):
        factor = self.constant_scalar
        return x * factor

class Torch7:
    """Convert a Torch7 (Lua) module tree, as loaded by ``torchfile``, to PyTorch.

    ``components`` maps a Torch7 type name (e.g. ``'nn.SpatialConvolution'``)
    to one of two kinds of entry:

    * a container class, which is instantiated with no arguments and filled
      with the converted children listed in ``thobject.modules``;
    * a converter function, which takes the torchfile object and returns a
      finished leaf ``nn.Module``.

    Leaf converters are added with the ``@Torch7.register(typename)`` decorator.
    """

    # Class-level registry: register() mutates this shared dict, so every
    # converter instance sees every registered type.
    components = dict()

    components.update({
        'nn.Sequential': nn.Sequential,
        'nn.ConcatTable': ConcatTable,
    })

    @classmethod
    def register(cls, typename):
        """Return a decorator that registers its function as the converter for ``typename``."""
        def decorator(func):
            cls.components[typename] = func
            return func

        return decorator

    def convert(self, thobject):
        """Recursively convert ``thobject`` into an equivalent ``nn.Module``.

        Raises:
            KeyError: if the object's Torch7 type has no registered converter.
        """
        typename = thobject._typename
        # torchfile may hand the type name back as bytes; accept str as well.
        if isinstance(typename, bytes):
            typename = typename.decode()

        try:
            factory = self.components[typename]
        except KeyError:
            raise KeyError(
                f'no converter registered for Torch7 type {typename!r}'
            ) from None

        if not inspect.isclass(factory):
            # Leaf layer: the registered function builds the module itself.
            return factory(thobject)

        # Container: convert the children in order. Index names ("0", "1", ...)
        # follow the nn.Sequential convention and, unlike hash()-based names,
        # are identical on every run, so state_dict keys are reproducible.
        net = factory()
        for index, child in enumerate(thobject.modules or ()):
            net.add_module(str(index), self.convert(child))

        return net

@Torch7.register('nn.SpatialReflectionPadding')
def func(thobject):
    """Map Torch7 ``nn.SpatialReflectionPadding`` onto ``nn.ReflectionPad2d``.

    ``nn.ReflectionPad2d`` takes its 4-tuple as (left, right, top, bottom).
    """
    left, right = thobject.pad_l, thobject.pad_r
    top, bottom = thobject.pad_t, thobject.pad_b
    return nn.ReflectionPad2d((left, right, top, bottom))

@Torch7.register('nn.SpatialConvolution')
def func(thobject):
    """Convert Torch7 ``nn.SpatialConvolution`` into an equivalent ``nn.Conv2d``.

    Torch7 stores kernel size, stride and padding as separate H/W fields; they
    are passed to PyTorch as (H, W) tuples and the weights are copied over.

    If the checkpoint carries no bias tensor (``bias`` missing or None —
    presumably a Lua layer built with ``:noBias()``), the PyTorch layer is
    created with ``bias=False`` instead of crashing on ``torch.Tensor(None)``.
    """
    bias = getattr(thobject, 'bias', None)
    has_bias = bias is not None

    module = nn.Conv2d(
        thobject.nInputPlane,
        thobject.nOutputPlane,
        (thobject.kH, thobject.kW),
        stride=(thobject.dH, thobject.dW),
        padding=(thobject.padH, thobject.padW),
        bias=has_bias,
    )

    state = {'weight': torch.Tensor(thobject.weight)}
    if has_bias:
        state['bias'] = torch.Tensor(bias)
    module.load_state_dict(state)

    return module

@Torch7.register('nn.InstanceNormalization')
def func(thobject):
    """Convert Torch7 ``nn.InstanceNormalization`` into ``nn.InstanceNorm2d``.

    The PyTorch layer is affine and tracks running statistics, so in ``eval()``
    mode it normalizes with the stored statistics recovered below.
    """
    batch = thobject.prev_N
    inner_bn = thobject.bn

    # The Lua layer keeps its statistics in an inner ``bn`` module. They are
    # reshaped to (prev_N, -1) and averaged over the first axis, leaving one
    # value per remaining position (presumably one per channel — confirm).
    mean_per_channel = inner_bn.running_mean.reshape(batch, -1).mean(axis=0)
    var_per_channel = inner_bn.running_var.reshape(batch, -1).mean(axis=0)

    module = nn.InstanceNorm2d(
        thobject.nOutput,
        eps=thobject.eps,
        affine=True,
        track_running_stats=True
    )

    module.load_state_dict({
        'weight': torch.Tensor(thobject.weight),
        'bias': torch.Tensor(thobject.bias),
        'running_mean': torch.Tensor(mean_per_channel),
        'running_var': torch.Tensor(var_per_channel),
        # Filled with prev_N as in the original conversion; PyTorch only
        # reads this counter when momentum is None.
        'num_batches_tracked': torch.tensor(batch),
    })

    return module

@Torch7.register('nn.ReLU')
def func(thobject):
    """Translate Torch7 ``nn.ReLU``, carrying its ``inplace`` flag over unchanged."""
    inplace_flag = thobject.inplace
    return nn.ReLU(inplace=inplace_flag)

@Torch7.register('nn.ShaveImage')
def func(thobject):
    """Translate Torch7 ``nn.ShaveImage`` into the local ``ShaveImage`` crop module."""
    margin = thobject.size
    return ShaveImage(margin)

@Torch7.register('nn.CAddTable')
def func(thobject):
    """Translate Torch7 ``nn.CAddTable``; the layer has no fields to copy."""
    adder = CAddTable()
    return adder

@Torch7.register('nn.SpatialFullConvolution')
def func(thobject):
    """Convert Torch7 ``nn.SpatialFullConvolution`` into ``nn.ConvTranspose2d``.

    Kernel size, stride, padding and the ``adjH``/``adjW`` output adjustment
    are passed as (H, W) tuples (the adjustment becomes ``output_padding``),
    and the weights are copied over.

    If the checkpoint carries no bias tensor (``bias`` missing or None —
    presumably a Lua layer built with ``:noBias()``), the PyTorch layer is
    created with ``bias=False`` instead of crashing on ``torch.Tensor(None)``.
    """
    bias = getattr(thobject, 'bias', None)
    has_bias = bias is not None

    module = nn.ConvTranspose2d(
        thobject.nInputPlane,
        thobject.nOutputPlane,
        (thobject.kH, thobject.kW),
        stride=(thobject.dH, thobject.dW),
        padding=(thobject.padH, thobject.padW),
        output_padding=(thobject.adjH, thobject.adjW),
        bias=has_bias,
    )

    state = {'weight': torch.Tensor(thobject.weight)}
    if has_bias:
        state['bias'] = torch.Tensor(bias)
    module.load_state_dict(state)

    return module

@Torch7.register('nn.Tanh')
def func(thobject):
    """Translate Torch7 ``nn.Tanh``; the layer has no fields to copy."""
    activation = nn.Tanh()
    return activation

@Torch7.register('nn.MulConstant')
def func(thobject):
    """Translate Torch7 ``nn.MulConstant`` into the local ``MulConstant`` module."""
    factor = thobject.constant_scalar
    return MulConstant(factor)

@Torch7.register('nn.TotalVariation')
def func(thobject):
    """Replace Torch7 ``nn.TotalVariation`` with a pass-through ``nn.Identity``.

    This substitution is intended for inference only. Presumably the Lua layer
    leaves activations unchanged in its forward pass and only matters during
    training — confirm before reusing the converted net for training.
    """
    passthrough = nn.Identity()
    return passthrough

class FaceAging:
    """Inference wrapper around the converted face-aging generator.

    Call ``load`` before ``generate`` so that ``self.net`` exists.
    """

    def __init__(self):
        # Per-channel means in BGR order, shaped (3, 1, 1) so they broadcast
        # over an (N, 3, H, W) batch (labelled as the VGG means).
        self.mean = torch.Tensor([[[103.939]], [[116.779]], [[123.680]]])

    def generate(self, x):
        """Run the generator on a batch of RGB images.

        Args:
            x: array-like of shape (N, H, W, 3), RGB channel order, pixel
                values in [0, 255] (e.g. uint8).

        Returns:
            ``np.ndarray`` of dtype uint8 in (N, H', W', 3) RGB layout, where
            H' and W' are whatever spatial size the network produces. Values
            are rounded and saturated to [0, 255].
        """
        # evaluation mode (uses InstanceNorm running stats, no dropout etc.)
        self.net.eval()

        # numpy -> torch float, channels last -> channels first
        x = torch.Tensor(x).permute(0, 3, 1, 2)

        # RGB -> BGR (flip returns a copy, so the caller's data is untouched)
        x = x.flip(1)

        with torch.no_grad():
            x = self.net(x - self.mean) + self.mean

        # BGR -> RGB, channels first -> channels last, torch -> numpy
        x = x.flip(1).permute(0, 2, 3, 1).numpy()

        # np.uint8() on out-of-range floats wraps around (or is undefined),
        # and plain truncation turns e.g. 199.99998 into 199. Round, then
        # saturate to the valid pixel range.
        return np.clip(np.rint(x), 0, 255).astype(np.uint8)

    def load(self, path, map_location='cpu'):
        """Load the generator network from ``path`` into ``self.net``.

        ``*.t7`` files are read with torchfile and converted from Torch7;
        anything else is treated as a whole-module checkpoint written by
        ``save``.

        Args:
            path: checkpoint file path (str).
            map_location: forwarded to ``torch.load`` for non-``.t7`` files.
                Defaults to ``'cpu'`` so a checkpoint saved on a GPU machine
                loads on a CPU-only one; ``generate`` builds its tensors on
                the CPU anyway.
        """
        if path.endswith('.t7'):
            converter = Torch7()

            net = torchfile.load(path)
            net = net.netG
            net = converter.convert(net)
        else:
            load_kwargs = {'map_location': map_location}
            # save() pickles the whole nn.Module. torch >= 2.6 defaults to
            # weights_only=True, which rejects such files; older releases do
            # not accept the keyword at all, so pass it only when supported.
            # SECURITY: full unpickling can execute arbitrary code — only load
            # checkpoints from trusted sources.
            if 'weights_only' in inspect.signature(torch.load).parameters:
                load_kwargs['weights_only'] = False
            net = torch.load(path, **load_kwargs)

        self.net = net

    def save(self, path):
        """Pickle the whole network (architecture + weights) to ``path``."""
        torch.save(self.net, path)

if __name__ == '__main__':
    # Convert the released Torch7 checkpoint and age one sample face (CPU only).
    model = FaceAging()
    model.load('CACD_Aging.t7')

    # Sample input (224 x 224 x 3 per the original note), forced to RGB and
    # given a leading batch axis of size 1.
    source = Image.open('LearingFaceAgeProgression/data/CACD/input/1151_0010_30.jpg')
    pixels = np.array(source.convert('RGB'))
    batch = pixels[np.newaxis, ...]

    aged = model.generate(batch)

    # Drop the singleton axes and display the result.
    Image.fromarray(np.squeeze(aged)).show()
ccsvd commented 2 years ago

这是来自QQ邮箱的假期自动回复邮件。 您好,我最近正在休假中,无法亲自回复您的邮件。我将在假期结束后,尽快给您回复。

ccsvd commented 2 years ago

这是来自QQ邮箱的假期自动回复邮件。 您好,我最近正在休假中,无法亲自回复您的邮件。我将在假期结束后,尽快给您回复。