Open MaggieMC opened 5 years ago
Hi, do you know how to do it?
Try this (at least you can run it with PyTorch instead of Lua):
import inspect
import torch
import torch.nn as nn
# https://github.com/bshillingford/python-torchfile
import torchfile
import numpy as np
from PIL import Image
class ConcatTable(nn.Module):
    """Container that applies every child module to the same input."""

    def forward(self, x):
        """Return a list holding ``child(x)`` for each child module.

        Args:
            x: Input passed unchanged to every child.

        Returns:
            A list with one entry per child, in the order yielded by
            ``self.children()``.
        """
        return [branch(x) for branch in self.children()]
class ShaveImage(nn.Module):
    """Crop a border of ``size`` elements from the last two dimensions."""

    def __init__(self, size):
        """Store the border width.

        Args:
            size: Number of elements removed from each side of the last
                two dimensions.
        """
        super().__init__()
        self.size = size

    def forward(self, x):
        """Return ``x`` with the border removed.

        The extents are read from the trailing two dimensions so they
        always agree with the ellipsis-based slice below; this also lets
        unbatched ``(C, H, W)`` inputs through.

        Args:
            x: Tensor with at least two dimensions.

        Returns:
            A view of ``x`` whose last two dimensions are each shortened
            by ``2 * size``.
        """
        border = self.size
        height, width = x.shape[-2:]
        return x[..., border:height - border, border:width - border]
class CAddTable(nn.Module):
    """Add together all entries of an iterable input."""

    def forward(self, x):
        """Return the result of summing the items of ``x`` with ``+``.

        Args:
            x: Iterable of addable values (e.g. a list of tensors).
        """
        total = sum(x)
        return total
class MulConstant(nn.Module):
    """Multiply the input by a fixed scalar."""

    def __init__(self, constant_scalar):
        """Remember the multiplier.

        Args:
            constant_scalar: Factor applied to every input in ``forward``.
        """
        super().__init__()
        self.constant_scalar = constant_scalar

    def forward(self, x):
        """Return ``x`` scaled by ``self.constant_scalar``."""
        factor = self.constant_scalar
        return x * factor
class Torch7:
    """Convert deserialized Torch7 module objects into PyTorch modules.

    ``components`` maps a Torch7 type name to either

    * a class, which is instantiated without arguments and then filled
      with the converted entries of the object's ``modules``, or
    * a function, which receives the Torch7 object and returns the
      finished PyTorch module.
    """

    # Registry shared by every instance; extended through ``register``.
    components = {
        'nn.Sequential': nn.Sequential,
        'nn.ConcatTable': ConcatTable,
    }

    @classmethod
    def register(cls, typename):
        """Return a decorator that registers a converter for ``typename``.

        Args:
            typename: Torch7 type name, e.g. ``'nn.ReLU'``.

        Returns:
            A decorator that stores the decorated object in
            ``components`` and returns it unchanged.
        """
        def decorator(func):
            cls.components[typename] = func
            return func
        return decorator

    def convert(self, thobject):
        """Recursively convert ``thobject`` into a PyTorch module.

        Args:
            thobject: Object exposing ``_typename`` and, for container
                types, an iterable ``modules`` attribute.

        Returns:
            The converted module.

        Raises:
            KeyError: If no converter is registered for the type name.
        """
        typename = thobject._typename
        if isinstance(typename, bytes):
            typename = typename.decode()
        try:
            component = self.components[typename]
        except KeyError:
            raise KeyError(
                f'unsupported Torch7 module type: {typename}'
            ) from None

        # Torch7 -> PyTorch
        if not inspect.isclass(component):
            return component(thobject)

        net = component()
        # Positional names keep state_dict keys reproducible between runs;
        # hash(module) is derived from the object's identity and is not.
        for index, child in enumerate(thobject.modules):
            net.add_module(str(index), self.convert(child))
        return net
@Torch7.register('nn.SpatialReflectionPadding')
def func(thobject):
    """Build an ``nn.ReflectionPad2d`` from a Torch7 reflection-padding object.

    Args:
        thobject: Object exposing ``pad_l``, ``pad_r``, ``pad_t``, ``pad_b``.
    """
    # Passed in the order left, right, top, bottom.
    padding = (
        thobject.pad_l,
        thobject.pad_r,
        thobject.pad_t,
        thobject.pad_b,
    )
    return nn.ReflectionPad2d(padding)
@Torch7.register('nn.SpatialConvolution')
def func(thobject):
    """Build an ``nn.Conv2d`` carrying the weights of a Torch7 convolution.

    Args:
        thobject: Object exposing ``nInputPlane``, ``nOutputPlane``,
            ``kH``/``kW``, ``dH``/``dW``, ``padH``/``padW``, ``weight``
            and optionally ``bias``.

    Returns:
        The initialised ``nn.Conv2d``; it is created without a bias term
        when the source object has none.
    """
    # A layer stored without bias exposes it as missing/None; feeding that
    # to torch.Tensor() would raise, so mirror it with bias=False instead.
    bias = getattr(thobject, 'bias', None)
    module = nn.Conv2d(
        thobject.nInputPlane,
        thobject.nOutputPlane,
        (thobject.kH, thobject.kW),
        stride=(thobject.dH, thobject.dW),
        padding=(thobject.padH, thobject.padW),
        bias=bias is not None,
    )
    state = {'weight': torch.Tensor(thobject.weight)}
    if bias is not None:
        state['bias'] = torch.Tensor(bias)
    module.load_state_dict(state)
    return module
@Torch7.register('nn.InstanceNormalization')
def func(thobject):
    """Build an ``nn.InstanceNorm2d`` from a Torch7 instance-normalization object.

    The running statistics of the wrapped ``bn`` object are reshaped to
    ``(prev_N, -1)`` and averaged over the first axis before being loaded.

    Args:
        thobject: Object exposing ``nOutput``, ``eps``, ``prev_N``,
            ``weight``, ``bias`` and ``bn.running_mean`` / ``bn.running_var``.
    """
    # NOTE(review): with track_running_stats=True PyTorch normalizes with
    # the stored running statistics in eval mode -- confirm this matches
    # the behaviour of the original Torch7 layer.
    module = nn.InstanceNorm2d(
        thobject.nOutput,
        eps=thobject.eps,
        affine=True,
        track_running_stats=True
    )
    groups = thobject.prev_N
    batch_norm = thobject.bn

    def averaged(stat):
        # Collapse the per-group copies into a single vector.
        return stat.reshape(groups, -1).mean(axis=0)

    mean_stat = averaged(batch_norm.running_mean)
    var_stat = averaged(batch_norm.running_var)
    state = {
        'weight': torch.Tensor(thobject.weight),
        'bias': torch.Tensor(thobject.bias),
        'running_mean': torch.Tensor(mean_stat),
        'running_var': torch.Tensor(var_stat),
        'num_batches_tracked': torch.tensor(groups),
    }
    module.load_state_dict(state)
    return module
@Torch7.register('nn.ReLU')
def func(thobject):
    """Build an ``nn.ReLU`` that reuses the source object's ``inplace`` flag."""
    in_place = thobject.inplace
    return nn.ReLU(inplace=in_place)
@Torch7.register('nn.ShaveImage')
def func(thobject):
    """Build a ``ShaveImage`` module using the source object's ``size``."""
    border = thobject.size
    return ShaveImage(border)
@Torch7.register('nn.CAddTable')
def func(thobject):
    """Build a ``CAddTable`` module; nothing is read from ``thobject``."""
    adder = CAddTable()
    return adder
@Torch7.register('nn.SpatialFullConvolution')
def func(thobject):
    """Build an ``nn.ConvTranspose2d`` carrying the weights of a Torch7 layer.

    Args:
        thobject: Object exposing ``nInputPlane``, ``nOutputPlane``,
            ``kH``/``kW``, ``dH``/``dW``, ``padH``/``padW``,
            ``adjH``/``adjW``, ``weight`` and optionally ``bias``.

    Returns:
        The initialised ``nn.ConvTranspose2d``; it is created without a
        bias term when the source object has none.
    """
    # A layer stored without bias exposes it as missing/None; feeding that
    # to torch.Tensor() would raise, so mirror it with bias=False instead.
    bias = getattr(thobject, 'bias', None)
    module = nn.ConvTranspose2d(
        thobject.nInputPlane,
        thobject.nOutputPlane,
        (thobject.kH, thobject.kW),
        stride=(thobject.dH, thobject.dW),
        padding=(thobject.padH, thobject.padW),
        output_padding=(thobject.adjH, thobject.adjW),
        bias=bias is not None,
    )
    state = {'weight': torch.Tensor(thobject.weight)}
    if bias is not None:
        state['bias'] = torch.Tensor(bias)
    module.load_state_dict(state)
    return module
@Torch7.register('nn.Tanh')
def func(thobject):
    """Build an ``nn.Tanh`` module; nothing is read from ``thobject``."""
    activation = nn.Tanh()
    return activation
@Torch7.register('nn.MulConstant')
def func(thobject):
    """Build a ``MulConstant`` using the source object's ``constant_scalar``."""
    factor = thobject.constant_scalar
    return MulConstant(factor)
@Torch7.register('nn.TotalVariation')
def func(thobject):
    """Replace the Torch7 total-variation layer with a pass-through module."""
    # The converted network is only used for inference, so the layer is
    # swapped for an identity mapping.
    passthrough = nn.Identity()
    return passthrough
class FaceAging:
    """Run a generator network on batches of RGB images.

    Images go in and come out as channels-last arrays
    ``(batch, height, width, 3)`` in RGB order; internally they are
    converted to channels-first BGR and shifted by ``self.mean``.
    """

    def __init__(self):
        # Per-channel offsets (labelled "VGG" by the original author),
        # applied after the RGB -> BGR flip.
        self.mean = torch.Tensor([[[103.939]], [[116.779]], [[123.680]]])
        # Populated by load(); None means no network is available yet.
        self.net = None

    def _require_net(self):
        """Return the loaded network or raise if ``load`` was never called."""
        net = getattr(self, 'net', None)
        if net is None:
            raise RuntimeError('no network loaded; call load() first')
        return net

    def generate(self, x):
        """Pass a batch of images through the network.

        Args:
            x: Array-like of shape ``(batch, height, width, 3)`` in RGB
                order (typically ``uint8``).

        Returns:
            ``numpy.ndarray`` of dtype ``uint8`` in the same layout as the
            input.

        Raises:
            RuntimeError: If no network has been loaded.
        """
        net = self._require_net()
        # evaluation mode
        net.eval()
        # numpy uint8 -> torch float
        x = torch.Tensor(x)
        # channels last -> channels first, then RGB -> BGR
        x = x.permute(0, 3, 1, 2).flip(1)
        with torch.no_grad():
            x = net(x - self.mean) + self.mean
        # BGR -> RGB, then channels first -> channels last
        x = x.flip(1).permute(0, 2, 3, 1)
        # Round and clamp before the cast: converting out-of-range floats
        # straight to uint8 wraps around, and plain truncation turns e.g.
        # 199.99998 into 199.
        x = x.round().clamp(0, 255)
        # torch float -> numpy uint8
        return x.numpy().astype(np.uint8)

    def load(self, path):
        """Load the network from ``path``.

        Files ending in ``.t7`` are read with ``torchfile`` and their
        ``netG`` entry is converted through ``Torch7``; anything else is
        handed to ``torch.load``.

        Args:
            path: Location of the serialized network.
        """
        if str(path).endswith('.t7'):
            serialized = torchfile.load(path)
            net = Torch7().convert(serialized.netG)
        else:
            # SECURITY: torch.load unpickles arbitrary objects -- only open
            # files from a trusted source.
            # NOTE(review): recent PyTorch releases default to
            # weights_only=True, which rejects fully pickled modules;
            # confirm against the targeted PyTorch version.
            net = torch.load(path)
        self.net = net

    def save(self, path):
        """Serialize the loaded network to ``path`` with ``torch.save``.

        Raises:
            RuntimeError: If no network has been loaded.
        """
        torch.save(self._require_net(), path)
if __name__ == '__main__':
    # Demo: convert the Torch7 checkpoint, run one sample image through
    # the generator and display the result.
    model = FaceAging()
    model.load('CACD_Aging.t7')

    # 224 x 224 x 3 (size noted by the original author)
    sample_path = (
        'LearingFaceAgeProgression/data/CACD/input/1151_0010_30.jpg'
    )
    # The context manager closes the underlying file once the pixel data
    # has been copied by convert().
    with Image.open(sample_path) as source:
        image = source.convert('RGB')
    image = np.array(image)

    # Add the batch axis expected by generate(), then drop it again.
    x = np.expand_dims(image, axis=0)
    x = model.generate(x)
    image = np.squeeze(x)

    image = Image.fromarray(image)
    image.show()
这是来自QQ邮箱的假期自动回复邮件。 您好,我最近正在休假中,无法亲自回复您的邮件。我将在假期结束后,尽快给您回复。
这是来自QQ邮箱的假期自动回复邮件。 您好,我最近正在休假中,无法亲自回复您的邮件。我将在假期结束后,尽快给您回复。
Hi,
Since the default testing script runs on GPU, is there a way to easily change this to run only on CPU?
Thank you