open-mmlab / mmengine

OpenMMLab Foundational Library for Training Deep Learning Models
https://mmengine.readthedocs.io/
Apache License 2.0

[Bug] Error Encountered with mmengine Dependency Involving JSON and Time Modules #1523

Open Duguce opened 7 months ago

Duguce commented 7 months ago

Prerequisite

Environment

OrderedDict([('sys.platform', 'linux'), ('Python', '3.10.14 (main, Mar 21 2024, 16:24:04) [GCC 11.2.0]'), ('CUDA available', True), ('MUSA available', False), ('numpy_random_seed', 2147483648), ('GPU 0,1,2,3,4,5,6,7', 'NVIDIA A800-SXM4-40GB'), ('CUDA_HOME', '/usr/local/cuda-11.8'), ('NVCC', 'Cuda compilation tools, release 11.8, V11.8.89'), ('GCC', 'gcc (GCC) 4.8.5 20150623 (Red Hat 4.8.5-44)'), ('PyTorch', '2.1.2+cu121'), ('PyTorch compiling details', 'PyTorch built with:\n - GCC 9.3\n - C++ Version: 201703\n - Intel(R) oneAPI Math Kernel Library Version 2022.2-Product Build 20220804 for Intel(R) 64 architecture applications\n - Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)\n - OpenMP 201511 (a.k.a. OpenMP 4.5)\n - LAPACK is enabled (usually provided by MKL)\n - NNPACK is enabled\n - CPU capability usage: AVX512\n - CUDA Runtime 12.1\n - NVCC architecture flags: -gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90\n - CuDNN 8.9.2\n - Magma 2.6.1\n - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=12.1, CUDNN_VERSION=8.9.2, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=old-style-cast -Wno-invalid-partial-specialization 
-Wno-unused-private-field -Wno-aligned-allocation-unavailable -Wno-missing-braces -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.1.2, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=1, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, \n'), ('TorchVision', '0.16.2+cu121'), ('OpenCV', '4.9.0'), ('MMEngine', '0.10.3')])

Reproduces the problem - code sample

import torch
from datasets import load_dataset
from mmengine.dataset import DefaultSampler
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from peft import LoraConfig
from torch.optim import AdamW
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          BitsAndBytesConfig)

from xtuner.dataset import process_hf_dataset
from xtuner.dataset.collate_fns import default_collate_fn
from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
                                 VarlenAttnArgsToMessageHubHook)
from xtuner.engine.runner import TrainLoop
from xtuner.model import SupervisedFinetune
from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE

import time
from mmengine.visualization.vis_backend import WandbVisBackend
from mmengine.visualization.visualizer import Visualizer

Reproduces the problem - command or script

CUDA_VISIBLE_DEVICES=7 xtuner train /mnt/data61/qingchen/codes/OpenJudge/yqc/ft/qwen1_5_0_5b_chat_qlora/qwen1_5_0_5b_chat_qlora_alpaca_e3_copy.py --deepspeed deepspeed_zero2

Reproduces the problem - error message

[2024-04-01 11:00:04,308] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-04-01 11:00:11,840] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Traceback (most recent call last):
  File "/mnt/data61/qingchen/envs/xtuner-env/lib/python3.10/site-packages/xtuner/tools/train.py", line 307, in <module>
    main()
  File "/mnt/data61/qingchen/envs/xtuner-env/lib/python3.10/site-packages/xtuner/tools/train.py", line 300, in main
    runner = RUNNERS.build(cfg)
  File "/mnt/data61/qingchen/envs/xtuner-env/lib/python3.10/site-packages/mmengine/registry/registry.py", line 570, in build
    return self.build_func(cfg, *args, **kwargs, registry=self)
  File "/mnt/data61/qingchen/envs/xtuner-env/lib/python3.10/site-packages/mmengine/registry/build_functions.py", line 196, in build_runner_from_cfg
    runner = runner_cls.from_cfg(args)  # type: ignore
  File "/mnt/data61/qingchen/envs/xtuner-env/lib/python3.10/site-packages/mmengine/runner/_flexible_runner.py", line 422, in from_cfg
    cfg = copy.deepcopy(cfg)
  File "/mnt/data61/qingchen/envs/xtuner-env/lib/python3.10/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/mnt/data61/qingchen/envs/xtuner-env/lib/python3.10/site-packages/mmengine/config/config.py", line 1527, in __deepcopy__
    super(Config, other).__setattr__(key, copy.deepcopy(value, memo))
  File "/mnt/data61/qingchen/envs/xtuner-env/lib/python3.10/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/mnt/data61/qingchen/envs/xtuner-env/lib/python3.10/site-packages/mmengine/config/config.py", line 144, in __deepcopy__
    other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
  File "/mnt/data61/qingchen/envs/xtuner-env/lib/python3.10/copy.py", line 161, in deepcopy
    rv = reductor(4)
TypeError: cannot pickle 'module' object

Additional information

When I use xtuner, I encounter an error related to mmengine that seems to involve built-in modules such as json and time. For example, when I comment out the import time line in the code above, the error no longer occurs.
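The traceback bottoms out in copy.deepcopy, which falls back to the pickle protocol for objects it has no special handler for, and module objects refuse to be pickled. A minimal sketch (plain Python, not mmengine or xtuner code) that reproduces the same failure mode:

```python
# Minimal repro sketch: deepcopy of any container holding a module
# object fails, because deepcopy's fallback path calls __reduce_ex__
# (the pickle protocol) and modules cannot be pickled.
import copy
import time

# Simulates an `import time` leaking a module object into a config dict.
cfg = {"lr": 2e-4, "time": time}

try:
    copy.deepcopy(cfg)
except TypeError as e:
    print(e)  # cannot pickle 'module' object
```

This matches the trace above: copy.py line 161 (`rv = reductor(4)`) is exactly the pickle fallback inside deepcopy.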

h123c commented 5 months ago

I have the same problem. Did you solve it?

wangzhen0518 commented 3 months ago

I ran into similar errors and found that they are caused by manually adding imports of built-in modules to the config file. For example, I imported os here.

❯ xtuner train internlm2_7b_qlora_colorist_e5_copy.py
/home/zhenwang/miniconda3/envs/tune/lib/python3.10/site-packages/mmengine/optim/optimizer/zero_optimizer.py:11: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
  from torch.distributed.optim import \
[2024-08-08 17:16:25,504] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.4
 [WARNING]  using untested triton version (3.0.0), only 1.0.0 is known to be compatible
/home/zhenwang/miniconda3/envs/tune/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:49: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(ctx, input, weight, bias=None):
/home/zhenwang/miniconda3/envs/tune/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:67: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, grad_output):
/home/zhenwang/miniconda3/envs/tune/lib/python3.10/site-packages/mmengine/optim/optimizer/zero_optimizer.py:11: DeprecationWarning: `TorchScript` support for functional optimizers is deprecated and will be removed in a future PyTorch release. Consider using the `torch.compile` optimizer instead.
  from torch.distributed.optim import \
[2024-08-08 17:16:27,826] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.4
 [WARNING]  using untested triton version (3.0.0), only 1.0.0 is known to be compatible
/home/zhenwang/miniconda3/envs/tune/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:49: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(ctx, input, weight, bias=None):
/home/zhenwang/miniconda3/envs/tune/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:67: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, grad_output):
Traceback (most recent call last):
  File "/home/zhenwang/miniconda3/envs/tune/lib/python3.10/site-packages/xtuner/tools/train.py", line 360, in <module>
    main()
  File "/home/zhenwang/miniconda3/envs/tune/lib/python3.10/site-packages/xtuner/tools/train.py", line 349, in main
    runner = Runner.from_cfg(cfg)
  File "/home/zhenwang/miniconda3/envs/tune/lib/python3.10/site-packages/mmengine/runner/runner.py", line 461, in from_cfg
    cfg = copy.deepcopy(cfg)
  File "/home/zhenwang/miniconda3/envs/tune/lib/python3.10/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/home/zhenwang/miniconda3/envs/tune/lib/python3.10/site-packages/mmengine/config/config.py", line 1531, in __deepcopy__
    super(Config, other).__setattr__(key, copy.deepcopy(value, memo))
  File "/home/zhenwang/miniconda3/envs/tune/lib/python3.10/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/home/zhenwang/miniconda3/envs/tune/lib/python3.10/site-packages/mmengine/config/config.py", line 144, in __deepcopy__
    other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
  File "/home/zhenwang/miniconda3/envs/tune/lib/python3.10/copy.py", line 161, in deepcopy
    rv = reductor(4)
TypeError: cannot pickle 'module' object
# internlm2_7b_qlora_colorist_e5_copy.py
import os

import torch
...

It seems that Config.fromfile requires modification. I will post my solution once I solve the problem.

wangzhen0518 commented 3 months ago

I do not know whether this workaround could cause other problems.

I ran training successfully after deleting lines 446-448 of /path/to/conda/miniconda3/envs/tune/lib/python3.10/site-packages/mmengine/config/utils.py in the function _gather_abs_import_lazyobj (mmengine version 0.10.4), as shown below.

def _gather_abs_import_lazyobj(tree: ast.Module,
                               filename: Optional[str] = None):
    """Experimental implementation of gathering absolute import information."""
    if isinstance(filename, str):
        filename = filename.encode('unicode_escape').decode()
    imported = defaultdict(list)
    abs_imported = set()
    new_body: List[ast.stmt] = []
    # module2node is used to get lineno when Python < 3.10
    module2node: dict = dict()
    for node in tree.body:
        if isinstance(node, ast.Import):
            for alias in node.names:
                # Skip converting built-in module to LazyObject
                # ! LINES TO BE DELETED ! if _is_builtin_module(alias.name): 
                # ! LINES TO BE DELETED !    new_body.append(node)
                # ! LINES TO BE DELETED !    continue
                module = alias.name.split('.')[0]
                module2node.setdefault(module, node)
                imported[module].append(alias)
            continue
        new_body.append(node)

    for key, value in imported.items():
        names = [_value.name for _value in value]
        if hasattr(value[0], 'lineno'):
            lineno = value[0].lineno
        else:
            lineno = module2node[key].lineno
        lazy_module_assign = ast.parse(
            f'{key} = LazyObject({names}, location="{filename}, line {lineno}")'  # noqa: E501
        )
        abs_imported.add(key)
        new_body.insert(0, lazy_module_assign.body[0])
    tree.body = new_body
    return tree, abs_imported

To emphasize it again: I do not know whether this workaround could cause other problems.

I also do not understand why we need to skip built-in modules here. It might be for performance reasons.

I suggest removing all of the special-handling code for built-in modules in utils.py, including lines 158-182 (the _is_builtin_module definition), lines 318-328 in __init__ of the ImportTransformer class, lines 422-423 in visit_Import, and lines 446-448 in _gather_abs_import_lazyobj (mmengine version 0.10.4).
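Until this is fixed upstream, an alternative to patching mmengine is to keep built-in imports out of config files entirely, and to check a config before it reaches copy.deepcopy. A hypothetical helper (not part of mmengine or xtuner) that scans a nested config dict for module-valued entries:

```python
import types

def find_module_entries(cfg: dict, prefix: str = "") -> list:
    """Return dotted key paths whose values are module objects.

    Any path reported here would make copy.deepcopy(cfg) raise
    TypeError: cannot pickle 'module' object.
    """
    hits = []
    for key, value in cfg.items():
        path = f"{prefix}.{key}" if prefix else str(key)
        if isinstance(value, types.ModuleType):
            hits.append(path)
        elif isinstance(value, dict):
            hits.extend(find_module_entries(value, path))
    return hits

import time
cfg = {"optim_wrapper": {"type": "AmpOptimWrapper"}, "time": time}
print(find_module_entries(cfg))  # ['time']
```

Running this on a loaded config and deleting the reported keys (or removing the corresponding imports from the config file) avoids the deepcopy failure without modifying mmengine's source.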