SalesforceAIResearch / uni2ts

Unified Training of Universal Time Series Forecasting Transformers
Apache License 2.0

Update README and config files with Moirai-1.1-R and fix Moirai-1.0-R model weights #111

Closed by gorold 2 months ago

gorold commented 2 months ago

Made PRs on Hugging Face to fix the regression of the Moirai-1.0-R model weights caused by the recent code changes:

- https://huggingface.co/Salesforce/moirai-1.0-R-small/discussions/7
- https://huggingface.co/Salesforce/moirai-1.0-R-base/discussions/6
- https://huggingface.co/Salesforce/moirai-1.0-R-large/discussions/6

Please help to test; I only managed to run some small tests locally :)

The changes to model weights were just:

```python
import torch

from uni2ts.model.moirai import MoiraiModule

mm = MoiraiModule.from_pretrained("Salesforce/moirai-1.0-R-small")
# Shift the mixture-weight projection rows into the positions expected by the
# corrected indexing; no_grad is required to mutate parameters in place.
with torch.no_grad():
    mm.param_proj.proj.weights_logits.weight[:] = torch.roll(
        mm.param_proj.proj.weights_logits.weight, 2, dims=0
    )
    mm.param_proj.proj.weights_logits.bias[:] = torch.roll(
        mm.param_proj.proj.weights_logits.bias, 2, dims=0
    )
mm.push_to_hub("Salesforce/moirai-1.0-R-small")
```
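Note that the weight and the bias must be rolled by the same shift: doing so just permutes the layer's output rows, moving each logit to its new index without changing its value. A plain-Python sketch (a toy 3-row "linear layer" with made-up numbers, standing in for `torch.roll` on the real parameters):

```python
def linear(weight, bias, x):
    # toy y = W @ x + b over nested lists
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]

def roll(rows, shift):
    # 1-D analogue of torch.roll(t, shifts=shift, dims=0)
    shift %= len(rows)
    return rows[-shift:] + rows[:-shift]

W = [[1.0, 0.0], [0.0, 1.0], [2.0, 3.0]]
b = [0.1, 0.2, 0.3]
x = [1.0, 2.0]

y = linear(W, b, x)
y_rolled = linear(roll(W, 1), roll(b, 1), x)
# Each output value is unchanged, only its row index shifts by 1:
print(y, y_rolled)  # [1.1, 2.2, 8.3] [8.3, 1.1, 2.2]
```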
liu-jc commented 2 months ago

Hi @gorold,

Sorry, I don't quite understand. How can these changes with `torch.roll` fix the weights? Don't we still have the `/ dim` in our codebase? https://github.com/SalesforceAIResearch/uni2ts/blob/27616d912b632fc77731b207e0b74970bed5870b/src/uni2ts/module/ts_embed.py#L178

gorold commented 2 months ago
  1. The only weights affected are those where `self.dim > 1`, which in this case is just `param_proj.proj.weights_logits`.
  2. During the training of Moirai 1.0, we had `weights_logits.out_features_ls = (32, 64, 128, 256, 512)` and did not have the `feat_size // self.dim` division, leading to this layer learning the weights for patch size 32 in the first position, 64 in the second position, and so on.
  3. This means that `weights_logits.weight[0:3]` contains the correct weights for patch sizes 32, 64, and 128, whereas `weights_logits.weight[3:5]` is useless.
  4. So we use `torch.roll` to shift the trained weights into the positions the corrected indexing expects for patch sizes 32, 64, and 128, leaving the useless 256 and 512 rows in the first two positions as dummies.
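The index shift can be sketched with a toy example (a plain-Python stand-in for `torch.roll`; the `w*`/`junk*` labels are just illustrative placeholders for the weight rows):

```python
def roll(xs, shift):
    # 1-D analogue of torch.roll(t, shifts=shift, dims=0)
    shift %= len(xs)
    return xs[-shift:] + xs[:-shift]

# Before the fix: trained weights for patch sizes 32/64/128 sit at rows 0-2,
# while rows 3-4 were never meaningfully trained.
rows = ["w32", "w64", "w128", "junk256", "junk512"]

# Rolling by 2 moves the trained rows to positions 2-4, where the corrected
# lookup expects them; the junk rows become dummies at positions 0-1.
print(roll(rows, 2))  # ['junk256', 'junk512', 'w32', 'w64', 'w128']
```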
gorold commented 2 months ago

Ran it on the PF benchmark, excluding weather:

Small

with fix:

| index | dataset | MSE[mean] | MSE[0.5] | MAE[0.5] | MASE[0.5] | MAPE[0.5] | sMAPE[0.5] | MSIS | RMSE[mean] | NRMSE[mean] | ND[0.5] | mean_weighted_sum_quantile_loss |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | electricity | 3846122.0 | 4047497.75 | 219.01437377929688 | 0.9806302189826965 | 0.1304870843887329 | 0.13351434469223022 | 8.014351844787598 | 1961.1531982421875 | 0.8221926093101501 | 0.09181945025920868 | 0.07224071025848389 |
| 1 | solar-energy | 1229.113525390625 | 1439.4451904296875 | 19.23911476135254 | 1.4672776460647583 | 2.4206011295318604 | 1.444728970527649 | 8.45895767211914 | 35.0587158203125 | 1.138300895690918 | 0.6246635317802429 | 0.4719245433807373 |
| 2 | walmart | 27342336.0 | 19882204.0 | 2114.254150390625 | 0.9929503202438354 | 0.24787983298301697 | 0.17269687354564667 | 8.747377395629883 | 5228.990234375 | 0.29910993576049805 | 0.12094007432460785 | 0.09698692709207535 |
| 3 | istanbul_traffic | 135.40859985351562 | 154.15487670898438 | 8.99282455444336 | 1.0580122470855713 | 0.5734052658081055 | 0.3754161596298218 | 5.677515506744385 | 11.636520385742188 | 0.31070834398269653 | 0.24011866748332977 | 0.17352095246315002 |
| 4 | turkey_power | 755125.1875 | 758453.9375 | 358.3751525878906 | 0.9450287818908691 | 0.5046748518943787 | 0.3893551230430603 | 7.002157211303711 | 868.9793701171875 | 0.14972686767578125 | 0.061748750507831573 | 0.048235006630420685 |

without fix:

| index | dataset | MSE[mean] | MSE[0.5] | MAE[0.5] | MASE[0.5] | MAPE[0.5] | sMAPE[0.5] | MSIS | RMSE[mean] | NRMSE[mean] | ND[0.5] | mean_weighted_sum_quantile_loss |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | electricity | 23829644.0 | 22212912.0 | 560.96923828125 | 2.2790398597717285 | 0.33599036931991577 | 0.25199955701828003 | 12.980779647827148 | 4881.5615234375 | 2.0465428829193115 | 0.23518039286136627 | 0.17356589436531067 |
| 1 | solar-energy | 1965.3592529296875 | 2181.220458984375 | 23.293292999267578 | 1.7744004726409912 | 2.6420609951019287 | 1.4974911212921143 | 13.509282112121582 | 44.33237075805664 | 1.4394018650054932 | 0.7562963366508484 | 0.6452235579490662 |
| 2 | walmart | 31941660.0 | 22301972.0 | 2262.530517578125 | 1.0623531341552734 | 0.2722017765045166 | 0.1812155842781067 | 9.110936164855957 | 5651.6953125 | 0.3232896327972412 | 0.12942181527614594 | 0.10336102545261383 |
| 3 | istanbul_traffic | 202.15194702148438 | 205.9799346923828 | 11.197728157043457 | 1.317776083946228 | 1.6119800806045532 | 0.4647684693336487 | 7.559725761413574 | 14.21801471710205 | 0.379637211561203 | 0.2989920973777771 | 0.21395711600780487 |
| 4 | turkey_power | 5094314.5 | 2785156.5 | 762.2886352539062 | 2.0671753883361816 | 2.91452693939209 | 0.4795369803905487 | 25.759414672851562 | 2257.058837890625 | 0.388895720243454 | 0.13134384155273438 | 0.10627298802137375 |

Feel free to test on other sizes and datasets. You can directly download the model from the PR branch with the `revision` argument:

```python
mm = MoiraiModule.from_pretrained("Salesforce/moirai-1.0-R-small", revision="pr/7")
```

Note that the PR number might be different for the other model sizes.

gorold commented 2 months ago

moirai-1.0-R-Base

with fix:

| index | dataset | MSE[mean] | MSE[0.5] | MAE[0.5] | MASE[0.5] | MAPE[0.5] | sMAPE[0.5] | MSIS | RMSE[mean] | NRMSE[mean] | ND[0.5] | mean_weighted_sum_quantile_loss |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | electricity | 1708712.5 | 1711201.625 | 164.1307373046875 | 0.7915405631065369 | 0.10031454265117645 | 0.11076250672340393 | 6.184873580932617 | 1307.17724609375 | 0.5480201840400696 | 0.06881006807088852 | 0.054687876254320145 |
| 1 | solar-energy | 1011.0943603515625 | 1108.408935546875 | 16.981399536132812 | 1.2911229133605957 | 2.296311855316162 | 1.4095485210418701 | 7.017038345336914 | 31.797710418701172 | 1.0324212312698364 | 0.5513591766357422 | 0.41874560713768005 |
| 2 | walmart | 26299296.0 | 19072352.0 | 2049.69384765625 | 0.9657745957374573 | 0.23114198446273804 | 0.1677016019821167 | 8.415294647216797 | 5128.28369140625 | 0.29334932565689087 | 0.117247074842453 | 0.09353204816579819 |
| 3 | istanbul_traffic | 37.16828918457031 | 40.77793884277344 | 4.562923431396484 | 0.5369675755500793 | 0.2586391270160675 | 0.2553446292877197 | 3.8279149532318115 | 6.0965800285339355 | 0.1627856343984604 | 0.12183525413274765 | 0.09833786636590958 |
| 4 | turkey_power | 473797.125 | 474377.5625 | 295.6066589355469 | 0.8949130177497864 | 0.16863861680030823 | 0.37849825620651245 | 6.532022476196289 | 688.3292236328125 | 0.11860048770904541 | 0.050933610647916794 | 0.040024567395448685 |

without fix:

| index | dataset | MSE[mean] | MSE[0.5] | MAE[0.5] | MASE[0.5] | MAPE[0.5] | sMAPE[0.5] | MSIS | RMSE[mean] | NRMSE[mean] | ND[0.5] | mean_weighted_sum_quantile_loss |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | electricity | 7544522.5 | 6037268.0 | 300.3915710449219 | 1.4005271196365356 | 0.1897597312927246 | 0.17114804685115814 | 8.438410758972168 | 2746.7294921875 | 1.1515370607376099 | 0.12593597173690796 | 0.09828455746173859 |
| 1 | solar-energy | 6939.76318359375 | 3234.031982421875 | 46.51622772216797 | 3.5624678134918213 | 15.757192611694336 | 1.4608030319213867 | 40.661155700683594 | 83.30523681640625 | 2.704789161682129 | 1.5103082656860352 | 1.162361741065979 |
| 2 | walmart | 20425108.0 | 19084788.0 | 1990.737548828125 | 0.9465618133544922 | 0.22026395797729492 | 0.16590428352355957 | 8.317343711853027 | 4519.41455078125 | 0.25852063298225403 | 0.1138746440410614 | 0.09144116938114166 |
| 3 | istanbul_traffic | 122.82603454589844 | 160.76173400878906 | 8.947684288024902 | 1.0514392852783203 | 1.8842673301696777 | 0.375012069940567 | 5.328226566314697 | 11.082691192626953 | 0.2959204614162445 | 0.23891335725784302 | 0.17149078845977783 |
| 4 | turkey_power | 7537457.0 | 2658505.75 | 731.1760864257812 | 1.9673036336898804 | 5.040554523468018 | 0.4480009973049164 | 29.197540283203125 | 2745.44287109375 | 0.4730452299118042 | 0.1259830892086029 | 0.11150600761175156 |