huggingface / safetensors

Simple, safe way to store and distribute tensors
https://huggingface.co/docs/safetensors
Apache License 2.0
2.9k stars 200 forks source link

ModuleNotFoundError: No module named 'torch._higher_order_ops' #482

Closed: guotong1988 closed this issue 5 months ago

guotong1988 commented 5 months ago

System Info

torch==2.0.0    
torchtune==0.1.1    
transformers==4.41.1    
safetensors==0.4.3

Information

Reproduction

from torchtune.utils import FullModelHFCheckpointer
from torchtune.models import convert_weights
import torch

# Reproduction script: load a 4-shard LLAMA3 safetensors checkpoint with
# torchtune's FullModelHFCheckpointer, convert its "model" state dict with
# convert_weights.tune_to_meta, and save the result to ./tmp/checkpoint.pth.
# NOTE: with torch==2.0.0 this never gets past the first import. torchtune
# pulls in torchao, which imports torch._higher_order_ops, and that module
# is missing from this torch version.

CHECKPOINT_DIR = "pythonProject/llama3_main/meta-llama-3-8b-instruct/"
SHARD_FILES = [
    "model-00001-of-00004.safetensors",
    "model-00002-of-00004.safetensors",
    "model-00003-of-00004.safetensors",
    "model-00004-of-00004.safetensors",
]

hf_checkpointer = FullModelHFCheckpointer(
    checkpoint_dir=CHECKPOINT_DIR,
    checkpoint_files=SHARD_FILES,
    output_dir="./tmp",
    model_type='LLAMA3',
)

print("loading checkpoint")
tune_state = hf_checkpointer.load_checkpoint()['model']
meta_state = convert_weights.tune_to_meta(tune_state)
print("saving checkpoint")
# Assumes ./tmp already exists when torch.save runs; TODO confirm the
# checkpointer creates output_dir.
torch.save(meta_state, "./tmp/checkpoint.pth")

Expected behavior

Actual behavior: the script fails on its very first import with the following error:

Traceback (most recent call last):
  File "pythonProject/convert.py", line 1, in <module>
    from torchtune.utils import FullModelHFCheckpointer
  File "python3.8/site-packages/torchtune/__init__.py", line 9, in <module>
    from torchtune import datasets, models, modules, utils
  File "python3.8/site-packages/torchtune/datasets/__init__.py", line 7, in <module>
    from torchtune.datasets._alpaca import alpaca_cleaned_dataset, alpaca_dataset
  File "python3.8/site-packages/torchtune/datasets/_alpaca.py", line 10, in <module>
    from torchtune.datasets._instruct import InstructDataset
  File "python3.8/site-packages/torchtune/datasets/_instruct.py", line 12, in <module>
    from torchtune.config._utils import _get_instruct_template
  File "python3.8/site-packages/torchtune/config/__init__.py", line 7, in <module>
    from ._instantiate import instantiate
  File "python3.8/site-packages/torchtune/config/_instantiate.py", line 12, in <module>
    from torchtune.config._utils import _get_component_from_path, _has_component
  File "python3.8/site-packages/torchtune/config/_utils.py", line 16, in <module>
    from torchtune.utils import get_logger, get_world_size_and_rank
  File "python3.8/site-packages/torchtune/utils/__init__.py", line 7, in <module>
    from ._checkpointing import (  # noqa
  File "python3.8/site-packages/torchtune/utils/_checkpointing/__init__.py", line 7, in <module>
    from ._checkpointer import (  # noqa
  File "python3.8/site-packages/torchtune/utils/_checkpointing/_checkpointer.py", line 17, in <module>
    from torchtune.models import convert_weights
  File "python3.8/site-packages/torchtune/models/__init__.py", line 7, in <module>
    from torchtune.models import convert_weights, gemma, llama2, mistral  # noqa
  File "python3.8/site-packages/torchtune/models/gemma/__init__.py", line 7, in <module>
    from ._component_builders import gemma  # noqa
  File "python3.8/site-packages/torchtune/models/gemma/_component_builders.py", line 9, in <module>
    from torchtune.modules import (
  File "python3.8/site-packages/torchtune/modules/__init__.py", line 8, in <module>
    from .common_utils import reparametrize_as_dtype_state_dict_post_hook
  File "python3.8/site-packages/torchtune/modules/common_utils.py", line 12, in <module>
    from torchao.dtypes.nf4tensor import NF4Tensor
  File "python3.8/site-packages/torchao/__init__.py", line 2, in <module>
    from .quantization.quant_api import apply_dynamic_quant
  File "python3.8/site-packages/torchao/quantization/__init__.py", line 7, in <module>
    from .smoothquant import *  # noqa: F403
  File "python3.8/site-packages/torchao/quantization/smoothquant.py", line 18, in <module>
    import torchao.quantization.quant_api as quant_api
  File "python3.8/site-packages/torchao/quantization/quant_api.py", line 22, in <module>
    from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
  File "python3.8/site-packages/torchao/quantization/dynamic_quant.py", line 10, in <module>
    from .quant_primitives import (
  File "python3.8/site-packages/torchao/quantization/quant_primitives.py", line 9, in <module>
    from torch._higher_order_ops.out_dtype import out_dtype
ModuleNotFoundError: No module named 'torch._higher_order_ops'
guotong1988 commented 5 months ago

Solved; see https://github.com/pytorch/torchtune/issues/1038 (the failing import comes from torchao, not from safetensors).