pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[torch.compile] TypeError: object of type 'NoneType' has no len() on `eager` backend #101030

Closed. shingjan closed this issue 1 year ago.

shingjan commented 1 year ago

🐛 Describe the bug

Here is my repro. Note that it works fine without `@torch._dynamo.optimize("eager")`:

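```python
import torch._dynamo
import torch
import torch.nn as nn

class Model(nn.Module):
    export = False

    def __init__(self, linear):
        super().__init__()
        # The same Linear layer is shared three times in the ModuleList.
        self.m = nn.ModuleList([linear] * 3)

    def forward(self, x):
        for i, mod in enumerate(self.m):
            # Layers only run in eval mode; modules start in training mode,
            # so in eager execution forward() returns x unchanged.
            if not self.training:
                x = mod(x)
        return x

# Compiled with the "eager" backend; removing this decorator makes the repro run fine.
@torch._dynamo.optimize("eager")
def f(l, x):
    layers = []
    layers += [Model(l)]  # Model (and its ModuleList) is constructed inside the traced function
    m = nn.Sequential(*layers)
    return m(x)

linear = nn.Linear(3, 3)
x = torch.randn(1, 3)
t = f(linear, x)
```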

Stack trace:

Traceback (most recent call last):
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 425, in _compile
    out_code = transform_code_object(code, transform)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 410, in transform
    tracer.run()
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2010, in run
    super().run()
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 703, in run
    and self.step()
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 663, in step
    getattr(self, inst.opname)(inst)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 385, in wrapper
    return inner_fn(self, inst)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1095, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 554, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/user_defined.py", line 153, in call_function
    return var.add_options(var.call_method(tx, "__init__", args, kwargs))
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/nn_module.py", line 760, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/user_defined.py", line 272, in call_method
    return UserMethodVariable(
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 306, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 269, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 102, in call_function
    return tx.inline_user_function_return(
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 590, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2115, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call_
    tracer.run()
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 703, in run
    and self.step()
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 663, in step
    getattr(self, inst.opname)(inst)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 385, in wrapper
    return inner_fn(self, inst)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1095, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 554, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/torch.py", line 226, in call_function
    return variables.UserDefinedClassVariable(
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/user_defined.py", line 153, in call_function
    return var.add_options(var.call_method(tx, "__init__", args, kwargs))
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/nn_module.py", line 760, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/user_defined.py", line 272, in call_method
    return UserMethodVariable(
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 306, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 269, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 102, in call_function
    return tx.inline_user_function_return(
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 590, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2115, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call_
    tracer.run()
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 703, in run
    and self.step()
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 663, in step
    getattr(self, inst.opname)(inst)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 160, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/builtin.py", line 565, in call_function
    res = binop_handler(tx, args[0], args[1], options)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/builtin.py", line 234, in user_defined_inplace_handler
    return a.call_method(tx, forward_name, [b], {})
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/nn_module.py", line 760, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/user_defined.py", line 272, in call_method
    return UserMethodVariable(
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 306, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 269, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 102, in call_function
    return tx.inline_user_function_return(
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 590, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2115, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call_
    tracer.run()
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 703, in run
    and self.step()
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 663, in step
    getattr(self, inst.opname)(inst)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 385, in wrapper
    return inner_fn(self, inst)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1095, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 554, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/misc.py", line 418, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs).add_options(self)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/nn_module.py", line 760, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/user_defined.py", line 272, in call_method
    return UserMethodVariable(
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 306, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 269, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 102, in call_function
    return tx.inline_user_function_return(
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 590, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2115, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call_
    tracer.run()
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 703, in run
    and self.step()
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 663, in step
    getattr(self, inst.opname)(inst)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 385, in wrapper
    return inner_fn(self, inst)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1095, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 554, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/builtin.py", line 584, in call_function
    result = handler(tx, *args, **kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/builtin.py", line 885, in call_len
    return args[0].call_method(tx, "__len__", args[1:], kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/nn_module.py", line 760, in call_method
    return super().call_method(tx, name, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/user_defined.py", line 272, in call_method
    return UserMethodVariable(
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 306, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 269, in call_function
    return super().call_function(tx, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/functions.py", line 102, in call_function
    return tx.inline_user_function_return(
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 590, in inline_user_function_return
    result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2115, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 2193, in inline_call_
    tracer.run()
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 703, in run
    and self.step()
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 663, in step
    getattr(self, inst.opname)(inst)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 385, in wrapper
    return inner_fn(self, inst)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1095, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 554, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/builtin.py", line 584, in call_function
    result = handler(tx, *args, **kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/builtin.py", line 885, in call_len
    return args[0].call_method(tx, "__len__", args[1:], kwargs)
  File "/home/yj/anaconda3/envs/dyna/lib/python3.8/site-packages/torch/_dynamo/variables/constant.py", line 129, in call_method
    return ConstantVariable(len(self.value), **options)
TypeError: object of type 'NoneType' has no len()
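Reading the innermost frames, the `+=` handled by `user_defined_inplace_handler`, followed by a method call and two nested `call_len` frames, appears to match the path `nn.ModuleList.__init__` takes in eager PyTorch (`self += modules` → `__iadd__` → `extend` → `len(self)` → `__len__` → `len(self._modules)`), with Dynamo apparently tracking `_modules` as a `None` constant at that point. The plain-Python sketch below uses a hypothetical `FakeModuleList` (not PyTorch code) that mirrors that chain with `_modules` deliberately left as `None`, to show how it ends in the same `TypeError`. This is an interpretation of the trace, not a confirmed root cause; in eager PyTorch, `nn.Module.__init__` sets `_modules` to a dict, which is consistent with the repro working without the decorator.

```python
# Hypothetical stand-in mirroring the call chain of torch.nn.ModuleList:
# __init__ -> `self += modules` -> __iadd__ -> extend -> len(self) -> __len__.
# NOT PyTorch code: _modules is left as None on purpose, to mimic what the
# last traceback frame (ConstantVariable calling len() on None) suggests.


class FakeModuleList:
    def __init__(self, modules=None):
        self._modules = None  # assumption: module registry not populated
        if modules is not None:
            self += modules  # dispatches to __iadd__, like ModuleList.__init__

    def __iadd__(self, modules):
        return self.extend(modules)

    def extend(self, modules):
        offset = len(self)  # dispatches to __len__
        for i, module in enumerate(modules):
            self._modules[str(offset + i)] = module
        return self

    def __len__(self):
        return len(self._modules)  # len(None) raises TypeError


def try_build(modules):
    """Return the TypeError message raised while building, or None if it succeeds."""
    try:
        FakeModuleList(modules)
    except TypeError as exc:
        return str(exc)
    return None


print(try_build(["linear"] * 3))
# prints: object of type 'NoneType' has no len()
```

Passing `modules=None` skips the `+=` branch entirely, so no error is raised in that case, just as `nn.ModuleList()` with no arguments never calls `extend`.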

Versions

PyTorch version: 2.1.0.dev20230505+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.5 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
Clang version: Could not collect
CMake version: version 3.26.0
Libc version: glibc-2.31

Python version: 3.8.0 (default, Nov 6 2019, 21:49:08) [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-69-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.8.89
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3070
Nvidia driver version: 520.61.05
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.6.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.6.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 48 bits physical, 48 bits virtual
CPU(s): 24
On-line CPU(s) list: 0-23
Thread(s) per core: 2
Core(s) per socket: 12
Socket(s): 1
NUMA node(s): 1
Vendor ID: AuthenticAMD
CPU family: 25
Model: 33
Model name: AMD Ryzen 9 5900X 12-Core Processor
Stepping: 0
Frequency boost: enabled
CPU MHz: 2200.000
CPU max MHz: 3700.0000
CPU min MHz: 2200.0000
BogoMIPS: 7399.70
Virtualization: AMD-V
L1d cache: 384 KiB
L1i cache: 384 KiB
L2 cache: 6 MiB
L3 cache: 64 MiB
NUMA node0 CPU(s): 0-23
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm

Versions of relevant libraries:
[pip3] clip-anytorch==2.5.2
[pip3] CoCa-pytorch==0.0.7
[pip3] dalle2-pytorch==1.12.4
[pip3] ema-pytorch==0.2.1
[pip3] functorch==1.14.0a0+408bcf1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.21.2
[pip3] open-clip-torch==2.16.0
[pip3] pytorch-lightning==2.0.0
[pip3] pytorch-transformers==1.2.0
[pip3] pytorch-triton==2.1.0+7d1a95b046
[pip3] pytorch-warmup==0.1.1
[pip3] rotary-embedding-torch==0.2.1
[pip3] torch==2.1.0.dev20230505+cu118
[pip3] torch-fidelity==0.3.0
[pip3] torch-scatter==2.1.1+pt20cpu
[pip3] torch-sparse==0.6.17+pt20cpu
[pip3] torch-struct==0.5
[pip3] torchaudio==2.1.0.dev20230505+cu118
[pip3] torchdata==0.7.0.dev20230426
[pip3] torchmetrics==0.11.4
[pip3] torchrec-nightly==2023.3.16
[pip3] torchtext==0.16.0.dev20230426+cpu
[pip3] torchvision==0.16.0.dev20230505+cu118
[pip3] vector-quantize-pytorch==1.1.2
[conda] clip-anytorch 2.5.2 pypi_0 pypi
[conda] coca-pytorch 0.0.7 pypi_0 pypi
[conda] dalle2-pytorch 1.12.4 pypi_0 pypi
[conda] ema-pytorch 0.2.1 pypi_0 pypi
[conda] functorch 1.14.0a0+408bcf1 pypi_0 pypi
[conda] numpy 1.21.2 pypi_0 pypi
[conda] open-clip-torch 2.16.0 pypi_0 pypi
[conda] pytorch-lightning 2.0.0 pypi_0 pypi
[conda] pytorch-transformers 1.2.0 pypi_0 pypi
[conda] pytorch-triton 2.1.0+7d1a95b046 pypi_0 pypi
[conda] pytorch-warmup 0.1.1 pypi_0 pypi
[conda] rotary-embedding-torch 0.2.1 pypi_0 pypi
[conda] torch 2.1.0.dev20230505+cu118 pypi_0 pypi
[conda] torch-fidelity 0.3.0 pypi_0 pypi
[conda] torch-scatter 2.1.1+pt20cpu pypi_0 pypi
[conda] torch-sparse 0.6.17+pt20cpu pypi_0 pypi
[conda] torch-struct 0.5 pypi_0 pypi
[conda] torchaudio 2.1.0.dev20230505+cu118 pypi_0 pypi
[conda] torchdata 0.7.0.dev20230426 pypi_0 pypi
[conda] torchmetrics 0.11.4 pypi_0 pypi
[conda] torchrec-nightly 2023.3.16 pypi_0 pypi
[conda] torchtext 0.16.0.dev20230426+cpu pypi_0 pypi
[conda] torchvision 0.16.0.dev20230505+cu118 pypi_0 pypi
[conda] vector-quantize-pytorch 1.1.2 pypi_0 pypi

cc @ezyang @soumith @msaroufim @wconstab @ngimel @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire

anijain2305 commented 1 year ago

I can take a look next week (I'm on PTO right now). If you need the fix sooner, let me or @yanboliang know.

anijain2305 commented 1 year ago

@shingjan I can't repro the issue here. Can you please recheck on your end with the latest main branch?

shingjan commented 1 year ago

@anijain2305 Thanks for the info. I rebased and can confirm the repro no longer triggers the error above, so I'll close this for now. For your reference, ~the branch of this PR can still produce the same repro~ I accidentally rebased that branch; this commit from it should still reproduce the issue.