Jittor / jittor

Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators.
https://cg.cs.tsinghua.edu.cn/jittor/
Apache License 2.0

modify stride positive check in jt.nn.conv_transpose3d/jt.nn.conv_transpose; add input shape check in jt.nn.conv_transpose3d/jt.nn.conv_transpose #553

Closed: fansunqi closed this pull request 2 months ago

fansunqi commented 2 months ago
  1. modify stride positive check in jt.nn.conv_transpose3d/jt.nn.conv_transpose:

Previous:

if stride <= 0:
    raise RuntimeError("non-positive stride is not supported")
stride = stride if isinstance(stride, tuple) else (stride, stride, stride)

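This fails when stride is passed as a tuple: the positivity check runs before the value is normalized, and in Python 3 comparing a tuple with an int (stride <= 0) raises a TypeError instead of performing the intended check.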
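A minimal standalone reproduction in plain Python (no Jittor needed) shows the failure. The function name old_stride_check is hypothetical; its body only copies the two statements quoted above.

```python
def old_stride_check(stride):
    # Hypothetical wrapper around the original code: check first, normalize second.
    if stride <= 0:  # int <= int works; tuple <= int raises TypeError
        raise RuntimeError("non-positive stride is not supported")
    return stride if isinstance(stride, tuple) else (stride, stride, stride)

print(old_stride_check(2))  # (2, 2, 2)

try:
    old_stride_check((2, 2, 2))
except TypeError as e:
    # A tuple stride never reaches the RuntimeError branch.
    print("TypeError:", e)
```

An int stride passes through and is expanded, while a tuple stride never gets as far as the positivity check.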

After modification:

stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
if stride[0] <= 0 or stride[1] <= 0 or stride[2] <= 0:
    raise RuntimeError("non-positive stride is not supported")

Because stride is normalized to a 3-tuple first, the same check now works whether stride is given as a single number or as a tuple.
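The fixed ordering (normalize first, then check every element) can be sketched as a small standalone helper. The name check_stride_3d is hypothetical and this is not Jittor's actual code. It uses any() over the elements, which is equivalent to the three-way or for a 3-tuple, and it adds a length check that the PR text does not mention (the PR version would raise IndexError on a shorter tuple).

```python
def check_stride_3d(stride):
    """Hypothetical helper mirroring the modified check (not Jittor's code).

    Expands an int stride to a 3-tuple, then rejects non-positive entries.
    """
    # Normalize first, so ints and tuples go through the same check.
    stride = stride if isinstance(stride, tuple) else (stride, stride, stride)
    # Extra assumption, not in the PR: a 3-D transposed conv needs 3 strides.
    if len(stride) != 3:
        raise RuntimeError(f"expected 3 stride values, got {len(stride)}")
    # Same condition as stride[0] <= 0 or stride[1] <= 0 or stride[2] <= 0.
    if any(s <= 0 for s in stride):
        raise RuntimeError("non-positive stride is not supported")
    return stride

print(check_stride_3d(2))          # (2, 2, 2)
print(check_stride_3d((1, 2, 1)))  # (1, 2, 1)
```

Passing (1, 0, 1) or 0 raises the RuntimeError, so both input forms now reach the intended error message.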

  2. add input shape check in jt.nn.conv_transpose3d/jt.nn.conv_transpose
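The PR text does not show what the input shape check looks like. The following is only a hypothetical sketch of a typical check for a 3-D transposed convolution, operating on plain shape tuples so it runs without Jittor. It assumes an input layout of (N, C, D, H, W) and a weight layout of (in_channels, out_channels // groups, kD, kH, kW), the common convention for transposed convolutions. The helper name and error messages are illustrative and not taken from Jittor.

```python
def check_conv_transpose3d_shapes(x_shape, w_shape):
    """Hypothetical shape check for a 3-D transposed convolution.

    Assumed layouts: input (N, C, D, H, W),
    weight (in_channels, out_channels // groups, kD, kH, kW).
    """
    x_shape, w_shape = tuple(x_shape), tuple(w_shape)
    # A 3-D transposed conv needs batch, channel and three spatial dims.
    if len(x_shape) != 5:
        raise RuntimeError(f"expected 5-D input (N, C, D, H, W), got shape {x_shape}")
    if len(w_shape) != 5:
        raise RuntimeError(f"expected 5-D weight, got shape {w_shape}")
    # The input channel count must match the weight's first dimension.
    if x_shape[1] != w_shape[0]:
        raise RuntimeError(
            f"input has {x_shape[1]} channels but weight expects {w_shape[0]}"
        )

# Passes silently: 4 input channels match weight.shape[0] == 4.
check_conv_transpose3d_shapes((2, 4, 8, 8, 8), (4, 6, 3, 3, 3))
```

Failing early with a readable message is usually preferable to letting a mismatched shape surface as an obscure error deeper in the kernel.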