Jittor / jittor

Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators.
https://cg.cs.tsinghua.edu.cn/jittor/
Apache License 2.0

Converting PyTorch code reports that Jittor does not support it yet #127

Open · githubyaww opened this issue 4 years ago

githubyaww commented 4 years ago

Here is the code that fails to convert:

pytorch_code = """

def conv_params(in_size, out_size):
    filters = [3,2,5,4]
    strides = [1,2,3]  # max_stride = 3
    pads = [0,1,2,3]   # max pad

    if out_size == 1:
        return 1, 0, in_size

    for filter_size in filters:
        for pad in pads:
            for stride in strides:
                if ((out_size - 1) * stride == (in_size - filter_size) + 2 * pad):
                    return stride, pad, filter_size
    return None, None, None

class StdConv(nn.Module):
    def __init__(self, nin, nout, filter_size=3, stride=2, padding=1, drop=0.1):
        super().__init__()
        self.conv = nn.Conv2d(nin, nout, filter_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(nout)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        return self.drop(self.bn(F.relu(self.conv(x))))

def flatten_conv(x, k):
    bs, nf, gx, gy = x.size()
    x = x.permute(0,2,3,1).contiguous()  # permute reorders the dimensions; contiguous copies the tensor into one contiguous block of memory
    return x.view(bs, -1, nf//k)

class OutConv(nn.Module):
    def __init__(self, k, nin, num_classes, bias):
        super().__init__()
        self.k = k
        self.oconv1 = nn.Conv2d(nin, (num_classes)*k, 3, padding=1)
        self.oconv1.bias.data.zero_().add_(bias)
        self.oconv2 = nn.Conv2d(nin, 4*k, 3, padding=1)

    def forward(self, x):
        return [flatten_conv(self.oconv1(x), self.k),
                flatten_conv(self.oconv2(x), self.k)]

class SSDHead(nn.Module):
    def __init__(self, grids, anchors_per_cell, num_classes, drop=0.3, bias=-4., nin=2048):
        super().__init__()
        self.bn = nn.BatchNorm2d(nin)
        self.drop = nn.Dropout(drop)

        self.sconvs = nn.ModuleList([])
        self.oconvs = nn.ModuleList([])

        self.anc_grids = grids
        self._k = anchors_per_cell

        self.sconvs.append(StdConv(nin, 256, stride=1, drop=drop))

        for i in range(len(grids)):
            if i == 0:
                stride, pad, filter_size = conv_params(7, grids[i])  # get '7' by base model
            else:
                stride, pad, filter_size = conv_params(grids[i-1], grids[i])

            if stride is None:
                print(grids[i-1], ' --> ', grids[i])
                raise Exception('cannot create model for specified grids')

            self.sconvs.append(StdConv(256, 256, filter_size, stride=stride, padding=pad, drop=drop))
            self.oconvs.append(OutConv(self._k, 256, num_classes=num_classes, bias=bias))

    def forward(self, x):
        x = self.drop(self.bn(F.relu(x)))
        x = self.sconvs[0](x)
        out_classes = []
        out_bboxes = []
        for sconv, oconv in zip(self.sconvs[1:], self.oconvs):
            x = sconv(x)
            out_class, out_bbox = oconv(x)
            out_classes.append(out_class)
            out_bboxes.append(out_bbox)

        return [torch.cat(out_classes, dim=1),
                torch.cat(out_bboxes, dim=1)]

def one_hot_embedding(labels, num_classes):
    return torch.eye(num_classes)[labels.data.cpu()]

class BCE_Loss(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def forward(self, pred, targ):
        t = one_hot_embedding(targ, self.num_classes)
        t = torch.Tensor(t[:,1:].contiguous()).cuda()
        x = pred[:,1:]
        #x = x.sigmoid().clamp(min=0.0001,max=1.0).detach()
        w = self.get_weight(x,t)
        return F.binary_cross_entropy_with_logits(x, t, w, size_average=False)/(self.num_classes-1)

    def get_weight(self, x, t):
        return None

class FocalLoss(BCE_Loss):
    def get_weight(self, x, t):
        alpha, gamma = 0.25, 2
        p = x.sigmoid().clamp(min=0.0001, max=1.0)
        p = x
        pt = p*t + (1-p)*(1-t)
        w = alpha*t + (1-alpha)*(1-t)
        w = w * (1-pt).pow(gamma)
        return w.detach()

""" jittor_code = convert(pytorch_code) print(jittor_code)

The error message:

RuntimeError                              Traceback (most recent call last)
in
    117
    118 """
--> 119 jittor_code = convert(pytorch_code)
    120 print(jittor_code)

~/python/envs/jittor/lib/python3.8/site-packages/jittor/utils/pytorch_converter.py in convert(code)
    417     '''
    418     a = ast.parse(code)
--> 419     dfs(a)
    420     a.body.insert(0, ast.parse('import jittor as jt').body[0])
    421     if 'init' not in import_flag:

~/python/envs/jittor/lib/python3.8/site-packages/jittor/utils/pytorch_converter.py in dfs(a)
    609             delete_flag = []
    610             for i,a_ in enumerate(a.__dict__[k]):
--> 611                 ret = dfs(a_)
    612                 if ret is 'delete':
    613                     delete_flag.append(True)

~/python/envs/jittor/lib/python3.8/site-packages/jittor/utils/pytorch_converter.py in dfs(a)
    609             delete_flag = []
    610             for i,a_ in enumerate(a.__dict__[k]):
--> 611                 ret = dfs(a_)
    612                 if ret is 'delete':
    613                     delete_flag.append(True)

~/python/envs/jittor/lib/python3.8/site-packages/jittor/utils/pytorch_converter.py in dfs(a)
    609             delete_flag = []
    610             for i,a_ in enumerate(a.__dict__[k]):
--> 611                 ret = dfs(a_)
    612                 if ret is 'delete':
    613                     delete_flag.append(True)

~/python/envs/jittor/lib/python3.8/site-packages/jittor/utils/pytorch_converter.py in dfs(a)
    619                 a.__dict__[k] = tmp
    620             else:
--> 621                 ret = dfs(a.__dict__[k])
    622                 if ret is not None:
    623                     a.__dict__[k] = ret

~/python/envs/jittor/lib/python3.8/site-packages/jittor/utils/pytorch_converter.py in dfs(a)
    585             func_name = func[-1]
    586             if func_name in unsupport_ops:
--> 587                 raise_unsupport(func_name)
    588             if func_name in pjmap.keys():
    589                 ags = [astunparse.unparse(ag).strip('\n') for ag in a.args]

~/python/envs/jittor/lib/python3.8/site-packages/jittor/utils/pytorch_converter.py in raise_unsupport(name)
    375
    376 def raise_unsupport(name):
--> 377     raise RuntimeError(f'{name} is not supported in Jittor yet. We will appreciate it if you provide an implementation of {name} and make pull request at https://github.com/Jittor/jittor.')
    378
    379 def replace(a):

RuntimeError: ModuleList is not supported in Jittor yet. We will appreciate it if you provide an implementation of ModuleList and make pull request at https://github.com/Jittor/jittor.
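
For reference, convert() parses the whole source string into an AST and raises as soon as dfs() meets any name on its unsupported-ops list, so the failure is independent of the rest of the model. A minimal sketch that should reproduce the same error (the import path is taken from the traceback above):

from jittor.utils.pytorch_converter import convert

# Any occurrence of nn.ModuleList in the input string is enough to trigger
# the RuntimeError above, regardless of the surrounding code.
pytorch_code = """
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4)])
"""
jittor_code = convert(pytorch_code)  # raises RuntimeError: ModuleList is not supported in Jittor yet...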
Jittor commented 4 years ago

Thank you for your feedback. You can try the code below:

import jittor as jt
from jittor import init
from jittor import nn
import numpy as np

def conv_params(in_size, out_size):
    filters = [3, 2, 5, 4]
    strides = [1, 2, 3]
    pads = [0, 1, 2, 3]
    if (out_size == 1):
        return (1, 0, in_size)
    for filter_size in filters:
        for pad in pads:
            for stride in strides:
                if (((out_size - 1) * stride) == ((in_size - filter_size) + (2 * pad))):
                    return (stride, pad, filter_size)
    return (None, None, None)
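
# Worked example (added for illustration; not part of the original reply):
# conv_params searches for (stride, pad, filter_size) satisfying the usual
# conv output-size relation out = (in + 2*pad - filter) // stride + 1,
# rearranged as (out - 1) * stride == (in - filter) + 2 * pad.
# E.g. conv_params(7, 4) first matches filter=3, pad=1, stride=2, because
# (4 - 1) * 2 == (7 - 3) + 2 * 1, so it returns (2, 1, 3).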

class StdConv(nn.Module):

    def __init__(self, nin, nout, filter_size=3, stride=2, padding=1, drop=0.1):
        super().__init__()
        self.conv = nn.Conv(nin, nout, filter_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm(nout)
        self.drop = nn.Dropout(drop)

    def execute(self, x):
        return self.drop(self.bn(nn.relu(self.conv(x))))

def flatten_conv(x, k):
    (bs, nf, gx, gy) = x.shape
    x = x.permute((0, 2, 3, 1))
    return x.view((bs, (- 1), (nf // k)))
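
# Shape note (added for clarity): flatten_conv maps (bs, nf, gx, gy) to
# (bs, gx*gy*k, nf//k), i.e. one row of nf//k predictions per anchor per
# grid cell, which lets heads from different grid sizes be concatenated
# along dim 1 in SSDHead.execute below.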

class OutConv(nn.Module):

    def __init__(self, k, nin, num_classes, bias):
        super().__init__()
        self.k = k
        self.oconv1 = nn.Conv(nin, (num_classes * k), 3, padding=1)
        init.constant_(self.oconv1.bias, bias)
        self.oconv2 = nn.Conv(nin, (4 * k), 3, padding=1)

    def execute(self, x):
        return [flatten_conv(self.oconv1(x), self.k), flatten_conv(self.oconv2(x), self.k)]

class SSDHead(nn.Module):

    def __init__(self, grids, anchors_per_cell, num_classes, drop=0.3, bias=(- 4.0), nin=2048):
        super().__init__()
        self.bn = nn.BatchNorm(nin)
        self.drop = nn.Dropout(drop)
        self.sconvs = nn.ModuleList([])
        self.oconvs = nn.ModuleList([])
        self.anc_grids = grids
        self._k = anchors_per_cell
        self.sconvs.append(StdConv(nin, 256, stride=1, drop=drop))
        for i in range(len(grids)):
            if (i == 0):
                (stride, pad, filter_size) = conv_params(7, grids[i])
            else:
                (stride, pad, filter_size) = conv_params(grids[(i - 1)], grids[i])
            if (stride is None):
                print(grids[(i - 1)], ' --> ', grids[i])
                raise Exception('cannot create model for specified grids')
            self.sconvs.append(StdConv(256, 256, filter_size, stride=stride, padding=pad, drop=drop))
            self.oconvs.append(OutConv(self._k, 256, num_classes=num_classes, bias=bias))

    def execute(self, x):
        x = self.drop(self.bn(nn.relu(x)))
        x = self.sconvs[0](x)
        out_classes = []
        out_bboxes = []
        for (sconv, oconv) in zip(self.sconvs[1:], self.oconvs):
            x = sconv(x)
            (out_class, out_bbox) = oconv(x)
            out_classes.append(out_class)
            out_bboxes.append(out_bbox)
        return [jt.contrib.concat(out_classes, dim=1), jt.contrib.concat(out_bboxes, dim=1)]

def one_hot_embedding(labels, num_classes):
    return np.eye(num_classes)[labels.data]
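
# Note (added): indexing the identity matrix with the label array selects one
# one-hot row per label, yielding a plain numpy array of shape
# (len(labels), num_classes); BCE_Loss.execute below wraps it in jt.array.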

class BCE_Loss(nn.Module):

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

    def execute(self, pred, targ):
        t = one_hot_embedding(targ, self.num_classes)
        t = jt.array(t[:, 1:])
        x = pred[:, 1:]
        w = self.get_weight(x, t)
        return (nn.binary_cross_entropy_with_logits(x, t, w, size_average=False) / (self.num_classes - 1))

    def get_weight(self, x, t):
        return None

class FocalLoss(BCE_Loss):

    def get_weight(self, x, t):
        (alpha, gamma) = (0.25, 2)
        p = x.sigmoid().clamp(0.0001, 1.0)
        p = x
        pt = ((p * t) + ((1 - p) * (1 - t)))
        w = ((alpha * t) + ((1 - alpha) * (1 - t)))
        w = (w * (1 - pt).pow(gamma))
        return w.detach()
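
The classes above are drop-in replacements for the PyTorch originals (forward becomes execute, torch.cat becomes jt.contrib.concat). As a quick sanity check, a hypothetical smoke test along the following lines should run; the grid sizes and input shape are illustrative assumptions, not values from the original issue:

# Hypothetical smoke test, not part of the original reply.
# grids=[4, 2, 1] is reachable from the 7x7 base grid via conv_params.
head = SSDHead(grids=[4, 2, 1], anchors_per_cell=9, num_classes=21)
x = jt.random((2, 2048, 7, 7))     # stand-in for a backbone feature map
out_classes, out_bboxes = head(x)  # calling a Module invokes execute()
print(out_classes.shape, out_bboxes.shape)
# expected: [2,189,21] and [2,189,4], where 189 = (16+4+1) cells * 9 anchors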