Jittor / jittor

Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators.
https://cg.cs.tsinghua.edu.cn/jittor/
Apache License 2.0

Problem when implementing WGAN #52

Open Exusial opened 4 years ago

Exusial commented 4 years ago

When implementing WGAN, I can't create the convolution layers. Environment: WSL Ubuntu, with Jittor installed via pip + git.
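
A minimal reproduction (a sketch, assuming the failure is in the conv layers themselves rather than in the training logic; the full script is below):

```python
# Minimal check of the layer constructors used in the script below.
import jittor as jt
from jittor import nn

x = jt.random((1, 1, 28, 28))
conv = nn.Conv(1, 64, 5, stride=2, padding=2)   # 1x28x28 -> 64x14x14
deconv = nn.ConvTranspose(64, 1, 8, stride=2)   # 64x14x14 -> 1x34x34
print(deconv(conv(x)).shape)
```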

```python
import jittor as jt
import numpy as np
from jittor import nn,Module,init
from jittor.dataset.mnist import MNIST
import jittor.transform as transform

batch_size = 50
DIM = 64          # base channel width
LAMBDA = 10       # gradient-penalty weight
ITERS = 200000
OUTPUT_DIM = 784  # flattened 28*28 MNIST image
data_transform = transform.Compose([
    transform.Resize(size=(28,28)),
    transform.Gray(),
    transform.ImageNormalize(mean=[0.5],std=[0.5])
])
train_loader = MNIST(train=True,transform=data_transform).set_attrs(batch_size=batch_size,shuffle=True)

class Generator(Module):
    def __init__(self,h_dim):
        super(Generator,self).__init__()
        self.preprocess = nn.Sequential(
            nn.Linear(128, 4 * 4 * 4 * DIM),
            nn.ReLU(),
        )
        self.block1 = nn.Sequential(
            nn.ConvTranspose(4*DIM, 2*DIM, 5),
            nn.ReLU(),
        )
        self.block2 = nn.Sequential(
            nn.ConvTranspose(2*DIM, DIM, 5),
            nn.ReLU(),
        )
        self.deconv_out = nn.ConvTranspose(DIM, 1, 8, stride=2)
        self.sigmoid = nn.Sigmoid()

    def execute(self,x):
        output = self.preprocess(x)
        output = output.reshape((-1,4*DIM,4,4))
        output = self.block1(output)
        output = output[:,:,:7,:7]  # crop 8x8 feature map to 7x7 so the final deconv yields 28x28
        output = self.block2(output)
        output = self.deconv_out(output)
        output = self.sigmoid(output)
        return output.reshape((-1,OUTPUT_DIM))

class Discriminator(Module):
    def __init__(self,h_dim):
        super(Discriminator,self).__init__()
        self.main = nn.Sequential(
            nn.Conv(1, DIM, 5, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv(DIM, 2*DIM, 5, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv(2*DIM, 4*DIM, 5, stride=2, padding=2),
            nn.ReLU(),
        )
        self.linear = nn.Linear(4*4*4*DIM,1)

    def execute(self,x):
        x = x.reshape((-1,1,28,28))
        output = self.main(x)
        output = output.reshape((-1,4*4*4*DIM))
        output = self.linear(output)
        return output.reshape((-1,))

def grad_penalty(D,xr,xf):
    # WGAN-GP gradient penalty: LAMBDA * E[(||grad_xhat D(xhat)||_2 - 1)^2]
    xf = xf.detach()
    xr = xr.detach()
    alpha = jt.random((batch_size,1))
    alpha = alpha.reindex(xr.shape,['i0','0'])  # broadcast alpha across the 784 features
    interpolates = alpha*xr + (1-alpha) * xf
    disc_interpolates = D(interpolates)
    grad = jt.grad(disc_interpolates,interpolates)
    # per-sample L2 norm; note the 1 is subtracted from the norm,
    # not elementwise from the gradient
    grad_norm = jt.sqrt((grad**2).sum(1))
    gp = ((grad_norm-1)**2).mean()*LAMBDA
    return gp

G = Generator(DIM)
D = Discriminator(DIM)
G_optim = nn.Adam(G.parameters(),lr=1e-4,betas=(0.5,0.9))
D_optim = nn.Adam(D.parameters(),lr=1e-4,betas=(0.5,0.9))

def train(epoch):
    for batch_idx,(x_,target) in enumerate(train_loader):
        # critic: several D updates per G update (WGAN training schedule)
        for i in range(0,5):
            mini_batch = x_.shape[0]
            x_ = x_.reshape((batch_size,OUTPUT_DIM))
            D_result = D(x_)
            loss_t = -D_result.mean()
            z = init.gauss((mini_batch,128),'float')
            G_result = G(z).detach()
            D_G_result = D(G_result)
            loss_f = D_G_result.mean()
            loss_D = loss_t+loss_f+grad_penalty(D,x_,G_result)
            loss_D.sync()
            D_optim.step(loss_D)
        # generator update
        z = init.gauss((mini_batch,128),'float')
        G_result = G(z)
        D_G_result = D(G_result)
        loss_G = - D_G_result.mean()
        loss_G.sync()
        G_optim.step(loss_G)
    if epoch%5 == 0:
        print("train:epoch",epoch)

for epoch in range(ITERS):
    train(epoch)
```
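
For reference, `grad_penalty` implements the gradient penalty from Gulrajani et al., "Improved Training of Wasserstein GANs" (WGAN-GP):

$$\lambda\,\mathbb{E}_{\hat{x}}\left[\left(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2-1\right)^2\right],\qquad \hat{x}=\alpha x_r+(1-\alpha)x_f,\quad \alpha\sim U(0,1)$$

with lambda = LAMBDA = 10; the 1 is subtracted from the per-sample L2 norm of the gradient, not elementwise.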
cjld commented 4 years ago

Thanks for the report! The problem is reproducible; we will fix it as soon as possible.

cjld commented 4 years ago

Hi, this issue has been fixed on the latest master branch. Fix commit: https://github.com/Jittor/jittor/commit/602d7056097d074d78ca52fc96e47fc101b9bb8c
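
To verify after updating, a quick sanity check (a sketch; `pip install git+https://github.com/Jittor/jittor.git` matches the pip + git install mentioned in the report):

```python
# Should construct and run without error on a build containing the fix.
import jittor as jt
from jittor import nn

layer = nn.ConvTranspose(256, 128, 5)
out = layer(jt.random((1, 256, 4, 4)))
print(out.shape)  # (1, 128, 8, 8)
```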