Closed by gorold 2 months ago
Hi @gorold,
Sorry, I don't quite understand. How do these changes with `torch.roll` help to fix the weights? Don't we still have `/ dim` in our codebase? https://github.com/SalesforceAIResearch/uni2ts/blob/27616d912b632fc77731b207e0b74970bed5870b/src/uni2ts/module/ts_embed.py#L178
The issue only concerns layers with `self.dim > 1`, which in this case is just `param_proj.proj.weights_logits`.
Before the code change, `.weights_logits.out_features_ls = (32, 64, 128, 256, 512)` and the lookup did not have `feat_size // self.dim`, leading to this layer learning the weights for patch size 32 in the first position, 64 in the second position, and so on.
As a result, `weights_logits.weight[0:3]` contain the correct weights for patch sizes 32, 64, 128, whereas `weights_logits.weight[3:5]` are useless.
The fix uses `torch.roll` to push the weights into the correct positions for patch sizes 32, 64, 128, and leaves 256, 512 in the first two positions as dummies.
Ran it on the PF benchmark, excluding weather:
index | dataset | test_metrics/MSE[mean] | test_metrics/MSE[0.5] | test_metrics/MAE[0.5] | test_metrics/MASE[0.5] | test_metrics/MAPE[0.5] | test_metrics/sMAPE[0.5] | test_metrics/MSIS | test_metrics/RMSE[mean] | test_metrics/NRMSE[mean] | test_metrics/ND[0.5] | test_metrics/mean_weighted_sum_quantile_loss |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | electricity | 3846122.0 | 4047497.75 | 219.01437377929688 | 0.9806302189826965 | 0.1304870843887329 | 0.13351434469223022 | 8.014351844787598 | 1961.1531982421875 | 0.8221926093101501 | 0.09181945025920868 | 0.07224071025848389 |
1 | solar-energy | 1229.113525390625 | 1439.4451904296875 | 19.23911476135254 | 1.4672776460647583 | 2.4206011295318604 | 1.444728970527649 | 8.45895767211914 | 35.0587158203125 | 1.138300895690918 | 0.6246635317802429 | 0.4719245433807373 |
2 | walmart | 27342336.0 | 19882204.0 | 2114.254150390625 | 0.9929503202438354 | 0.24787983298301697 | 0.17269687354564667 | 8.747377395629883 | 5228.990234375 | 0.29910993576049805 | 0.12094007432460785 | 0.09698692709207535 |
3 | istanbul_traffic | 135.40859985351562 | 154.15487670898438 | 8.99282455444336 | 1.0580122470855713 | 0.5734052658081055 | 0.3754161596298218 | 5.677515506744385 | 11.636520385742188 | 0.31070834398269653 | 0.24011866748332977 | 0.17352095246315002 |
4 | turkey_power | 755125.1875 | 758453.9375 | 358.3751525878906 | 0.9450287818908691 | 0.5046748518943787 | 0.3893551230430603 | 7.002157211303711 | 868.9793701171875 | 0.14972686767578125 | 0.061748750507831573 | 0.048235006630420685 |
index | dataset | test_metrics/MSE[mean] | test_metrics/MSE[0.5] | test_metrics/MAE[0.5] | test_metrics/MASE[0.5] | test_metrics/MAPE[0.5] | test_metrics/sMAPE[0.5] | test_metrics/MSIS | test_metrics/RMSE[mean] | test_metrics/NRMSE[mean] | test_metrics/ND[0.5] | test_metrics/mean_weighted_sum_quantile_loss |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | electricity | 23829644.0 | 22212912.0 | 560.96923828125 | 2.2790398597717285 | 0.33599036931991577 | 0.25199955701828003 | 12.980779647827148 | 4881.5615234375 | 2.0465428829193115 | 0.23518039286136627 | 0.17356589436531067 |
1 | solar-energy | 1965.3592529296875 | 2181.220458984375 | 23.293292999267578 | 1.7744004726409912 | 2.6420609951019287 | 1.4974911212921143 | 13.509282112121582 | 44.33237075805664 | 1.4394018650054932 | 0.7562963366508484 | 0.6452235579490662 |
2 | walmart | 31941660.0 | 22301972.0 | 2262.530517578125 | 1.0623531341552734 | 0.2722017765045166 | 0.1812155842781067 | 9.110936164855957 | 5651.6953125 | 0.3232896327972412 | 0.12942181527614594 | 0.10336102545261383 |
3 | istanbul_traffic | 202.15194702148438 | 205.9799346923828 | 11.197728157043457 | 1.317776083946228 | 1.6119800806045532 | 0.4647684693336487 | 7.559725761413574 | 14.21801471710205 | 0.379637211561203 | 0.2989920973777771 | 0.21395711600780487 |
4 | turkey_power | 5094314.5 | 2785156.5 | 762.2886352539062 | 2.0671753883361816 | 2.91452693939209 | 0.4795369803905487 | 25.759414672851562 | 2257.058837890625 | 0.388895720243454 | 0.13134384155273438 | 0.10627298802137375 |
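To make the `torch.roll` step described earlier concrete, here is a minimal sketch on a toy tensor. It is not the actual weight-conversion script. The one-column tensor is a hypothetical stand-in for `weights_logits.weight`, and `dim = 4` and `shifts=2` are assumptions inferred from the positions mentioned in the thread.

```python
import torch

# Feature sizes the layer was built with (from the thread).
out_features_ls = (32, 64, 128, 256, 512)
# Assumption: dim = 4, so that feat_size // dim yields 8, 16, 32, 64, 128.
dim = 4

# Toy stand-in for weights_logits.weight: one row per slot, tagged with the
# patch size whose weights that slot actually learned (0.0 = never trained).
old_weight = torch.tensor([32.0, 64.0, 128.0, 0.0, 0.0]).unsqueeze(-1)

# Shift every slot two positions along dim 0. The last two (untrained) slots
# wrap around to the front and act as dummies.
new_weight = torch.roll(old_weight, shifts=2, dims=0)

# Patch size that each slot serves under the feat_size // dim lookup.
patch_sizes = [feat_size // dim for feat_size in out_features_ls]

# Map each patch size to the tag of the weights now sitting in its slot.
slot_contents = dict(zip(patch_sizes, new_weight.squeeze(-1).tolist()))
```

After the roll, the slots for patch sizes 32, 64, 128 hold the weights that were trained for them, while the slots for 8 and 16 hold the untrained rows.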
Feel free to test on other sizes and datasets. You can directly download the model in the PR branch with the `revision` argument:
from uni2ts.model.moirai import MoiraiModule
mm = MoiraiModule.from_pretrained("Salesforce/moirai-1.0-R-small", revision="pr/7")
Note that the PR number might be different for the other model sizes.
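A small helper can keep the per-size revisions in one place. This is only a sketch: `PR_REVISIONS` and `load_pr_module` are hypothetical names. It assumes that the revision string follows the same `pr/<number>` pattern as above and that the number matches the discussion number in the Hugging Face links posted in this thread (7 for small, 6 for base, 6 for large).

```python
# Revisions taken from the Hugging Face discussion numbers linked in this thread.
# Assumption: the PR number equals the discussion number in each URL.
PR_REVISIONS = {
    "small": "pr/7",
    "base": "pr/6",
    "large": "pr/6",
}


def load_pr_module(size: str):
    """Load the patched weights for one model size (requires uni2ts installed)."""
    # Imported lazily so the mapping can be inspected without uni2ts.
    from uni2ts.model.moirai import MoiraiModule

    return MoiraiModule.from_pretrained(
        f"Salesforce/moirai-1.0-R-{size}",
        revision=PR_REVISIONS[size],
    )
```

For example, `load_pr_module("base")` would fetch the base model from its PR branch, provided the assumptions above hold.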
index | dataset | test_metrics/MSE[mean] | test_metrics/MSE[0.5] | test_metrics/MAE[0.5] | test_metrics/MASE[0.5] | test_metrics/MAPE[0.5] | test_metrics/sMAPE[0.5] | test_metrics/MSIS | test_metrics/RMSE[mean] | test_metrics/NRMSE[mean] | test_metrics/ND[0.5] | test_metrics/mean_weighted_sum_quantile_loss |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | electricity | 1708712.5 | 1711201.625 | 164.1307373046875 | 0.7915405631065369 | 0.10031454265117645 | 0.11076250672340393 | 6.184873580932617 | 1307.17724609375 | 0.5480201840400696 | 0.06881006807088852 | 0.054687876254320145 |
1 | solar-energy | 1011.0943603515625 | 1108.408935546875 | 16.981399536132812 | 1.2911229133605957 | 2.296311855316162 | 1.4095485210418701 | 7.017038345336914 | 31.797710418701172 | 1.0324212312698364 | 0.5513591766357422 | 0.41874560713768005 |
2 | walmart | 26299296.0 | 19072352.0 | 2049.69384765625 | 0.9657745957374573 | 0.23114198446273804 | 0.1677016019821167 | 8.415294647216797 | 5128.28369140625 | 0.29334932565689087 | 0.117247074842453 | 0.09353204816579819 |
3 | istanbul_traffic | 37.16828918457031 | 40.77793884277344 | 4.562923431396484 | 0.5369675755500793 | 0.2586391270160675 | 0.2553446292877197 | 3.8279149532318115 | 6.0965800285339355 | 0.1627856343984604 | 0.12183525413274765 | 0.09833786636590958 |
4 | turkey_power | 473797.125 | 474377.5625 | 295.6066589355469 | 0.8949130177497864 | 0.16863861680030823 | 0.37849825620651245 | 6.532022476196289 | 688.3292236328125 | 0.11860048770904541 | 0.050933610647916794 | 0.040024567395448685 |
index | dataset | test_metrics/MSE[mean] | test_metrics/MSE[0.5] | test_metrics/MAE[0.5] | test_metrics/MASE[0.5] | test_metrics/MAPE[0.5] | test_metrics/sMAPE[0.5] | test_metrics/MSIS | test_metrics/RMSE[mean] | test_metrics/NRMSE[mean] | test_metrics/ND[0.5] | test_metrics/mean_weighted_sum_quantile_loss |
---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | electricity | 7544522.5 | 6037268.0 | 300.3915710449219 | 1.4005271196365356 | 0.1897597312927246 | 0.17114804685115814 | 8.438410758972168 | 2746.7294921875 | 1.1515370607376099 | 0.12593597173690796 | 0.09828455746173859 |
1 | solar-energy | 6939.76318359375 | 3234.031982421875 | 46.51622772216797 | 3.5624678134918213 | 15.757192611694336 | 1.4608030319213867 | 40.661155700683594 | 83.30523681640625 | 2.704789161682129 | 1.5103082656860352 | 1.162361741065979 |
2 | walmart | 20425108.0 | 19084788.0 | 1990.737548828125 | 0.9465618133544922 | 0.22026395797729492 | 0.16590428352355957 | 8.317343711853027 | 4519.41455078125 | 0.25852063298225403 | 0.1138746440410614 | 0.09144116938114166 |
3 | istanbul_traffic | 122.82603454589844 | 160.76173400878906 | 8.947684288024902 | 1.0514392852783203 | 1.8842673301696777 | 0.375012069940567 | 5.328226566314697 | 11.082691192626953 | 0.2959204614162445 | 0.23891335725784302 | 0.17149078845977783 |
4 | turkey_power | 7537457.0 | 2658505.75 | 731.1760864257812 | 1.9673036336898804 | 5.040554523468018 | 0.4480009973049164 | 29.197540283203125 | 2745.44287109375 | 0.4730452299118042 | 0.1259830892086029 | 0.11150600761175156 |
Made PRs on Hugging Face to fix the regression of the Moirai-1.0-R model weights caused by the recent code changes:
https://huggingface.co/Salesforce/moirai-1.0-R-small/discussions/7
https://huggingface.co/Salesforce/moirai-1.0-R-base/discussions/6
https://huggingface.co/Salesforce/moirai-1.0-R-large/discussions/6
Please help to test, I only managed to run some small tests locally :)
The changes to model weights were just: