lhw828 / timesfm

Google timesfm 实战部署/本地部署详细记录
71 stars 10 forks source link

checkpoint可以下载到本地来加载吗? #5

Open keochoi opened 5 months ago

keochoi commented 5 months ago

google/timesfm-1.0-200m这个可以提前从huggingface下载到本地,从本地加载吗?

tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m") # 这句应该怎么改?

lhw828 commented 5 months ago

当然可以。从Hugging Face Hub预先下载模型到本地,然后从本地加载是一种常见做法,尤其在你可能需要多次使用该模型或网络不稳定的情况下。以下是大致步骤:

下载模型到本地 手动下载:你可以直接从Hugging Face Model Hub的网页上找到该模型页面,通常会有“Download”或“Model weights”之类的按钮,点击即可下载模型的文件或压缩包。对于google/timesfm-1.0-200m这样的模型,你可以在Hugging Face的网站上搜索该模型名,找到对应页面并下载。

使用CLI下载:另外,你也可以使用Hugging Face的命令行工具huggingface-cli来下载模型。首先确保你安装了这个工具,然后执行如下命令:

Bash huggingface-cli login # 登录你的Hugging Face账号,如果还没登录的话 huggingface-cli repo download google/timesfm-1.0-200m --cache_dir ./models # 将模型下载到本地的./models目录 这里--cache_dir指定了模型下载的本地目录,你可以根据需要修改。

从本地加载模型 下载完成后,你可以在代码中指定模型的本地路径来加载它。假设模型下载到了./models/google_timesfm-1.0-200m目录下,你可以这样修改加载模型的代码:

Python from transformers import AutoModel

指定本地路径加载模型

local_model_path = "./models/google_timesfm-1.0-200m" model = AutoModel.from_pretrained(local_model_path)

或者对于timesfm模型,如果你使用的是特定的加载方式,可能是这样的

注意:下面的代码是示意性的,具体加载方法取决于timesfm模型实际的加载逻辑

tfm = TimesFm(context_len=context_len, horizon_len=horizon_len, input_patch_len=32, output_patch_len=128, num_layers=20, model_dims=1280, backend='cpu') tfm.load_local(local_model_path) # 假设timesfm有类似load_local这样的方法来从本地路径加载模型

lhw828 commented 5 months ago

这是ai给的回复。

zhaokui001 commented 5 months ago

google/timesfm-1.0-200m这个可以提前从huggingface下载到本地,从本地加载吗?

tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m") # 这句应该怎么改?

应该改为

local_model_path = "/home/dedong/huggingface/timesfm/checkpoints" tfm.load_from_checkpoint(checkpoint_path=local_model_path) 我将模型本地下载之后,按上面的方法导入是能够成功加载模型并进行预测的

ham114 commented 1 month ago

google/timesfm-1.0-200m这个可以提前从huggingface下载到本地,从本地加载吗? tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m") # 这句应该怎么改?

应该改为

local_model_path = "/home/dedong/huggingface/timesfm/checkpoints" tfm.load_from_checkpoint(checkpoint_path=local_model_path) 我将模型本地下载之后,按上面的方法导入是能够成功加载模型并进行预测的

WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'> WARNING:absl:Configured CheckpointManager using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024. WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume train_state is unpadded. ERROR:absl:For checkpoint version > 1.0, we require users to provide train_state_unpadded_shape_dtype_struct during checkpoint saving/restoring, to avoid potential silent bugs when loading checkpoints to incompatible unpadded shapes of TrainState. Restored checkpoint in 3.01 seconds. Jitting decoding. Jitted decoding in 52.32 seconds. 这些错误和提示信息怎么解决,影响使用吗

lhw828 commented 1 month ago

google/timesfm-1.0-200m这个可以提前从huggingface下载到本地,从本地加载吗? tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m") # 这句应该怎么改?

应该改为

local_model_path = "/home/dedong/huggingface/timesfm/checkpoints" tfm.load_from_checkpoint(checkpoint_path=local_model_path) 我将模型本地下载之后,按上面的方法导入是能够成功加载模型并进行预测的

WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'> WARNING:absl:Configured CheckpointManager using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024. WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume train_state is unpadded. ERROR:absl:For checkpoint version > 1.0, we require users to provide train_state_unpadded_shape_dtype_struct during checkpoint saving/restoring, to avoid potential silent bugs when loading checkpoints to incompatible unpadded shapes of TrainState. Restored checkpoint in 3.01 seconds. Jitting decoding. Jitted decoding in 52.32 seconds. 这些错误和提示信息怎么解决,影响使用吗

这个可以忽略