Jittor / jittor

Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators.
https://cg.cs.tsinghua.edu.cn/jittor/
Apache License 2.0

jt.Var.to does not deal with kargs #544

Closed: zhc7 closed this issue 3 months ago.

zhc7 commented 3 months ago

Describe the bug


In jittor/misc.py, to() is implemented as follows:

def to(x, *args, **kargs):
    if len(args) >= 1:
        s = args[0]
        if isinstance(s, jt.NanoString) or callable(s):
            return x.cast(s)
        s = str(s)
        if "cuda" in s:
            jt.flags.use_cuda = 1
        elif "cpu" in s:
            jt.flags.use_cuda = 0
    return x.clone()

It silently discards all keyword arguments, so the following does not work as expected:

import jittor as jt

a = jt.randn(3, 4)
b = a.to(dtype=jt.float16)
print(b.dtype)  # expected float16, but prints float32

Full Log

Executing the above code produces the following output:

[i 0522 02:44:48.140803 84 compiler.py:956] Jittor(1.3.9.5) src: /root/anaconda3/envs/jdiffusion/lib/python3.9/site-packages/jittor
[i 0522 02:44:48.144123 84 compiler.py:957] g++ at /usr/bin/g++(9.4.0)
[i 0522 02:44:48.144212 84 compiler.py:958] cache_path: /root/.cache/jittor/jt1.3.9/g++9.4.0/py3.9.19/Linux-5.15.0-1x63/IntelRXeonRGolx7a/21e8/default
[i 0522 02:44:48.167781 84 __init__.py:412] Found /usr/local/cuda/bin/nvcc(12.0.140) at /usr/local/cuda/bin/nvcc.
[i 0522 02:44:48.172487 84 __init__.py:412] Found addr2line(2.34) at /usr/bin/addr2line.
[i 0522 02:44:48.331424 84 compiler.py:1011] cuda key:cu12.0.140_sm_86
[i 0522 02:44:49.040516 84 __init__.py:227] Total mem: 251.54GB, using 16 procs for compiling.
[i 0522 02:44:49.257756 84 jit_compiler.cc:28] Load cc_path: /usr/bin/g++
[i 0522 02:44:49.405901 84 init.cc:63] Found cuda archs: [86,]
float32

Minimal Reproduce

As shown above.
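For readers without Jittor installed, the same control flow can be reproduced in plain Python. This is a minimal sketch under stated assumptions: DType, FakeVar and the flags dict are hypothetical stand-ins for Jittor's dtype objects, jt.Var and jt.flags, and only the branching of the quoted to() is copied. It shows that a dtype passed positionally reaches cast(), while the same dtype passed as dtype= falls through to clone().

```python
class DType:
    """Hypothetical stand-in for a Jittor dtype (jt.NanoString or a dtype callable)."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


float16, float32 = DType("float16"), DType("float32")


class FakeVar:
    """Hypothetical stand-in for jt.Var that only records its dtype."""

    def __init__(self, dtype):
        self.dtype = dtype

    def cast(self, dtype):
        return FakeVar(dtype)

    def clone(self):
        return FakeVar(self.dtype)


flags = {"use_cuda": 0}  # hypothetical stand-in for jt.flags.use_cuda


def to(x, *args, **kargs):
    # Same branching as the quoted jittor/misc.py code; kargs is never read.
    if len(args) >= 1:
        s = args[0]
        if isinstance(s, DType):  # replaces the NanoString / callable() check
            return x.cast(s)
        s = str(s)
        if "cuda" in s:
            flags["use_cuda"] = 1
        elif "cpu" in s:
            flags["use_cuda"] = 0
    return x.clone()


a = FakeVar(float32)
print(to(a, float16).dtype)        # float16: the positional dtype reaches cast()
print(to(a, dtype=float16).dtype)  # float32: the keyword lands in kargs and is ignored
```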

Expected behavior

I expect to() to handle keyword arguments properly. Judging from the implementation, which accepts *args but only inspects args[0], it should perhaps also be able to handle multiple positional arguments.

If ignoring keyword arguments and looking only at the first argument is intended, I think the signature should simply be changed to def to(x, arg) so that it does not mislead users.
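One possible direction, as a minimal sketch rather than a proposed patch: the keyword names device and dtype are borrowed from PyTorch's Tensor.to and are an assumption, not confirmed Jittor API, and DType, FakeVar and flags are hypothetical stand-ins so the sketch runs without Jittor. It classifies every positional argument instead of only args[0], accepts the same values as keywords, and raises TypeError on unknown keywords instead of discarding them.

```python
class DType:
    """Hypothetical stand-in for a Jittor dtype (jt.NanoString or a dtype callable)."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name


float16, float32 = DType("float16"), DType("float32")


class FakeVar:
    """Hypothetical stand-in for jt.Var that only records its dtype."""

    def __init__(self, dtype):
        self.dtype = dtype

    def cast(self, dtype):
        return FakeVar(dtype)

    def clone(self):
        return FakeVar(self.dtype)


flags = {"use_cuda": 0}  # hypothetical stand-in for jt.flags.use_cuda


def to(x, *args, device=None, dtype=None, **kwargs):
    # Refuse unknown keywords instead of silently dropping them.
    if kwargs:
        raise TypeError(f"to() got unexpected keyword argument(s): {sorted(kwargs)}")
    # Sort every positional argument into the dtype or device slot.
    for arg in args:
        if isinstance(arg, DType):
            if dtype is not None:
                raise TypeError("to() got dtype more than once")
            dtype = arg
        else:
            if device is not None:
                raise TypeError("to() got device more than once")
            device = arg
    # Device handling mirrors the original string check on "cuda" / "cpu".
    if device is not None:
        s = str(device)
        if "cuda" in s:
            flags["use_cuda"] = 1
        elif "cpu" in s:
            flags["use_cuda"] = 0
    # Cast if a dtype was given either way; otherwise copy, as before.
    return x.cast(dtype) if dtype is not None else x.clone()


a = FakeVar(float32)
print(to(a, dtype=float16).dtype)             # float16
print(to(a, "cpu", float16).dtype)            # float16
print(to(a, "cpu").dtype, flags["use_cuda"])  # float32 0
```

Raising on unrecognized keywords is also what the stricter def to(x, arg) signature would give for free, since Python rejects an unexpected dtype= keyword there. Either change would stop the silent fallthrough to clone().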