EleutherAI / pythia

The hub for EleutherAI's work on interpretability and learning dynamics
Apache License 2.0

cache_dir cannot be the same as model name #163

Open arunasank opened 1 month ago

arunasank commented 1 month ago
```python
from transformers import GPTNeoXForCausalLM, AutoTokenizer

model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-12b",
    revision="step143000",
    cache_dir="./pythia-12b/step143000"
)

tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-12b",
    cache_dir="./pythia-12b/step143000"
)

inputs = tokenizer("Hello, I am", return_tensors="pt")
tokens = model.generate(**inputs)
tokenizer.decode(tokens[0])
```

The above piece of code works for me. However, if I change `cache_dir` to `"EleutherAI/pythia-12b"` (the same string as the model name), the bin files are not downloaded and the code errors out. Shouldn't it work regardless of what `cache_dir` is?
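If I'm reading the loading logic right, this looks like path shadowing rather than a download bug: `cached_file` checks `os.path.isdir(path_or_repo_id)` before it ever talks to the Hub, and the cache directory created on disk is literally named `EleutherAI/pythia-12b`, so from then on the repo id resolves as a local model folder. A local folder is expected to contain `config.json` at its top level, which the HF cache layout (`models--EleutherAI--pythia-12b/snapshots/<hash>/...`) never provides. A minimal sketch of that resolution order (my own illustration, not the transformers source):

```python
import os

# Sketch (assumed resolution order): an existing local directory wins over a
# Hub repo id that is spelled the same way.
name_or_path = "EleutherAI/pythia-12b"    # also used as cache_dir in the failing run
os.makedirs(name_or_path, exist_ok=True)  # from_pretrained creates the cache dir on disk

if os.path.isdir(name_or_path):
    # Treated as a local model folder: transformers now expects config.json
    # directly inside it, but the cache layout nests files under
    # models--EleutherAI--pythia-12b/snapshots/<hash>/ instead -> OSError.
    print("resolved as local directory:", name_or_path)
else:
    print("resolved as Hub repo id:", name_or_path)
```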

Error Trace:
```
---------------------------------------------------------------------------
OSError                                   Traceback (most recent call last)
Cell In[30], line 3
      1 from transformers import GPTNeoXForCausalLM, AutoTokenizer
----> 3 model = GPTNeoXForCausalLM.from_pretrained(
      4     "EleutherAI/pythia-12b",
      5     revision="step143000",
      6     cache_dir="EleutherAI/pythia-12b"
      7 )
      9 tokenizer = AutoTokenizer.from_pretrained(
     10     "EleutherAI/pythia-12b",
     11     cache_dir="EleutherAI/pythia-12b"
     12 )
     14 inputs = tokenizer("Hello, I am", return_tensors="pt")

File /mnt/align4_drive/arunas/miniconda3/envs/rm-score/lib/python3.11/site-packages/transformers/modeling_utils.py:3075, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3073 if not isinstance(config, PretrainedConfig):
   3074     config_path = config if config is not None else pretrained_model_name_or_path
-> 3075     config, model_kwargs = cls.config_class.from_pretrained(
   3076         config_path,
   3077         cache_dir=cache_dir,
   3078         return_unused_kwargs=True,
   3079         force_download=force_download,
   3080         resume_download=resume_download,
   3081         proxies=proxies,
   3082         local_files_only=local_files_only,
   3083         token=token,
   3084         revision=revision,
   3085         subfolder=subfolder,
   3086         _from_auto=from_auto_class,
   3087         _from_pipeline=from_pipeline,
   3088         **kwargs,
   3089     )
   3090 else:
   3091     # In case one passes a config to `from_pretrained` + "attn_implementation"
   3092     # override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
   (...)
   3096     # we pop attn_implementation from the kwargs but this handles the case where users
   3097     # passes manually the config to `from_pretrained`.
   3098     config = copy.deepcopy(config)

File /mnt/align4_drive/arunas/miniconda3/envs/rm-score/lib/python3.11/site-packages/transformers/configuration_utils.py:605, in PretrainedConfig.from_pretrained(cls, pretrained_model_name_or_path, cache_dir, force_download, local_files_only, token, revision, **kwargs)
    601 kwargs["revision"] = revision
    603 cls._set_token_in_kwargs(kwargs, token)
--> 605 config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
    606 if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
    607     logger.warning(
    608         f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
    609         f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
    610     )

File /mnt/align4_drive/arunas/miniconda3/envs/rm-score/lib/python3.11/site-packages/transformers/configuration_utils.py:634, in PretrainedConfig.get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
    632 original_kwargs = copy.deepcopy(kwargs)
    633 # Get config dict associated with the base config file
--> 634 config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
    635 if "_commit_hash" in config_dict:
    636     original_kwargs["_commit_hash"] = config_dict["_commit_hash"]

File /mnt/align4_drive/arunas/miniconda3/envs/rm-score/lib/python3.11/site-packages/transformers/configuration_utils.py:689, in PretrainedConfig._get_config_dict(cls, pretrained_model_name_or_path, **kwargs)
    685 configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
    687 try:
    688     # Load from local folder or from cache or download from model Hub and cache
--> 689     resolved_config_file = cached_file(
    690         pretrained_model_name_or_path,
    691         configuration_file,
    692         cache_dir=cache_dir,
    693         force_download=force_download,
    694         proxies=proxies,
    695         resume_download=resume_download,
    696         local_files_only=local_files_only,
    697         token=token,
    698         user_agent=user_agent,
    699         revision=revision,
    700         subfolder=subfolder,
    701         _commit_hash=commit_hash,
    702     )
    703     commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
    704 except EnvironmentError:
    705     # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
    706     # the original exception.

File /mnt/align4_drive/arunas/miniconda3/envs/rm-score/lib/python3.11/site-packages/transformers/utils/hub.py:356, in cached_file(path_or_repo_id, filename, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, subfolder, repo_type, user_agent, _raise_exceptions_for_missing_entries, _raise_exceptions_for_connection_errors, _commit_hash, **deprecated_kwargs)
    354 if not os.path.isfile(resolved_file):
    355     if _raise_exceptions_for_missing_entries:
--> 356         raise EnvironmentError(
    357             f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
    358             f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
    359         )
    360     else:
    361         return None

OSError: EleutherAI/pythia-12b does not appear to have a file named config.json. Checkout 'https://huggingface.co/EleutherAI/pythia-12b/step143000' for available files.
```
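For anyone who hits the same thing: keeping `cache_dir` from colliding with the repo id avoids the local-directory branch entirely. A sketch of the workaround (the cache path below is a placeholder; any directory not named `EleutherAI/pythia-12b` should do):

```python
from transformers import GPTNeoXForCausalLM, AutoTokenizer

# Workaround sketch: a cache_dir that does not shadow the repo id, so
# "EleutherAI/pythia-12b" still resolves against the Hub.
model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-12b",
    revision="step143000",
    cache_dir="./pythia-12b-cache/step143000",  # placeholder path
)
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-12b",
    cache_dir="./pythia-12b-cache/step143000",  # placeholder path
)
```

With this, `os.path.isdir("EleutherAI/pythia-12b")` stays False and the config download proceeds as usual.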