Open tangh18 opened 2 months ago
nice! thanks for this! Did you have a chance to compare the performance of the original and the new implementation?
This is awesome, @tangh18! 🎉 When running `python convert_ckpt.py`, I'm seeing `jaxlib.xla_extension.XlaRuntimeError: FAILED PRECONDITION: DNN library initialization failed.` Any ideas about how to resolve this? My JAX version is 0.4.26, CUDA version is 12.4, and cuDNN version is 8.9.7.29, which should be compatible.
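For what it's worth, a quick way to confirm whether JAX can see the GPU and initialize its DNN backend at all is a minimal device check (a sketch using only standard JAX calls, independent of this repo):

```python
# Minimal sanity check: if cuDNN fails to initialize, the matmul below
# typically raises the same XlaRuntimeError as reported above.
import jax
import jax.numpy as jnp

print(jax.__version__)
print(jax.devices())                  # should list CUDA devices, not just CPU

x = jnp.ones((4, 4))
print((x @ x).block_until_ready())    # forces execution on the default device
```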
> nice! thanks for this! Did you have a chance to compare the performance of the original and the new implementation?
Thanks for your attention! I ran tests on an A100 server: 45 runs, each with a batch size of 32 and a context length of 2880. PyTorch took 3.89 seconds, while JAX finished in 1.61 seconds. However, this comparison may not be entirely fair: the PyTorch version did not handle padding, and its implementation was not optimized.
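For reference, a timing harness along those lines might look like this (a sketch only; `model` stands in for the PyTorch TimesFM forward pass and is hypothetical, not the code actually used):

```python
# Hypothetical timing sketch matching the numbers above (45 runs,
# batch size 32, context length 2880).
import time
import torch

batch = torch.randn(32, 2880, device="cuda")

torch.cuda.synchronize()                  # make sure prior GPU work is done
start = time.perf_counter()
for _ in range(45):
    with torch.no_grad():
        _ = model(batch)                  # `model` is a placeholder
torch.cuda.synchronize()                  # wait for the queued kernels to finish
print(f"total: {time.perf_counter() - start:.2f} s")
```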
The JAX-related code in `convert_ckpt.py` only involves creating the model and loading the checkpoint. Have you successfully run TimesFM in JAX? If so, the error you're encountering should not occur.
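Since the conversion only needs JAX to build the model and read the checkpoint, forcing JAX onto the CPU may sidestep the cuDNN initialization entirely (a workaround sketch, not part of the script; `JAX_PLATFORMS` is a standard JAX environment variable):

```python
# Workaround sketch: run the JAX side of the conversion on CPU so cuDNN
# is never touched.
import os
os.environ["JAX_PLATFORMS"] = "cpu"   # must be set before jax is first imported
import jax

print(jax.devices())                  # expect [CpuDevice(id=0)]
```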
Thanks for your super cool contribution! I think what @sebastianpinedaar means is performance in terms of accuracy of the model. Have you been able to reproduce the accuracy metrics reported in the original paper?
Thanks for the info @tangh18! As @melopeo pointed out, I was rather curious about the MAE and scaled-MAE comparison between the original JAX TimesFM and the PyTorch version, at least on a couple of datasets, as a sanity check. Although there are some differences in the preprocessing, hopefully the gap is not too big.
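For concreteness, here are standard definitions of the metrics being discussed (a sketch; the repo's evaluation code may differ in details such as scaling and epsilon handling):

```python
# Common definitions of MAE, sMAPE, and one usual form of scaled MAE;
# not necessarily the exact formulas used in the TimesFM evaluation code.
import numpy as np

def mae(pred, true):
    return np.mean(np.abs(pred - true))

def smape(pred, true, eps=1e-8):
    return np.mean(2.0 * np.abs(pred - true) / (np.abs(pred) + np.abs(true) + eps))

def scaled_mae(pred, true, naive_pred):
    # MAE scaled by the error of a naive baseline (e.g., last-value forecast).
    return mae(pred, true) / mae(naive_pred, true)
```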
Although I used the official PyTorch implementation, I ran the same weight-conversion process and was able to reproduce results like the following. Hopefully this helps!
```python
{'mse': np.float32(0.4324413),
 'smape': np.float32(0.7251805),
 'mae': np.float32(0.40476117),
 'wape': np.float32(11.708796),
 'nrmse': np.float32(19.022907),
 'num_elements': 20160,
 'abs_sum': np.float32(696.91077),
 'dataset': 'etth1',
 'freq': 'h',
 'pred_len': 96,
 'context_len': 512}
```
Well, that was my misunderstanding. I may try ETTh1 when available, as @TeddyHuang-00 did with the official PyTorch implementation. As a sanity check, I have verified several times that when there is no padding and the same normalization method is used, the outputs of the JAX version and my implementation are almost identical for the same input. Thanks for pointing that out, @sebastianpinedaar @melopeo.
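That kind of parity check might look roughly like this (a sketch with hypothetical `jax_model` / `torch_model` handles, not the actual test code):

```python
# Hypothetical parity check between the JAX and PyTorch forward passes.
import numpy as np
import torch

x = np.random.randn(1, 512).astype(np.float32)     # one unpadded context window

out_jax = np.asarray(jax_model(x))                  # placeholder JAX forward
out_torch = torch_model(torch.from_numpy(x)).detach().numpy()

print(np.max(np.abs(out_jax - out_torch)))          # expect ~1e-5 or smaller
```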
I haven't learned how to build models with JAX. How can I convert the weights from the JAX checkpoints into a format that can be loaded into a PyTorch model, given that I'm using the official PyTorch version of the model?
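The general recipe is to flatten the JAX/Flax parameter tree into NumPy arrays, rename each entry to the matching PyTorch key, and transpose where the layout conventions differ. A hedged sketch follows; the key names and transposition rule below are hypothetical, and the gist linked in the next reply has the actual conversion used:

```python
# Hypothetical sketch of a JAX-to-PyTorch weight conversion; the real key
# names and transposition rules depend on the specific model definition.
import numpy as np
import torch

def jax_param_to_torch(name: str, value) -> tuple[str, torch.Tensor]:
    """Map one flattened JAX parameter to a PyTorch state_dict entry."""
    tensor = torch.from_numpy(np.asarray(value))
    if name.endswith(".kernel"):      # Flax Dense kernels are (in, out);
        tensor = tensor.T             # PyTorch Linear weights are (out, in)
        name = name[: -len(".kernel")] + ".weight"
    return name, tensor

# `flat_jax_params` is a hypothetical {dotted_name: array} mapping, and
# `torch_model` a hypothetical PyTorch model instance.
state_dict = dict(
    jax_param_to_torch(name, value) for name, value in flat_jax_params.items()
)
torch_model.load_state_dict(state_dict)
```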
https://gist.github.com/TeddyHuang-00/fc2238f6f5956a9906c8c206edef2603
You are welcome :smile:
I implemented a PyTorch version of TimesFM here. It includes the essential components required to run the model effectively. Hope it helps. :D