deven-gqc closed this issue 1 year ago.
@deven-gqc can you print the shape of your y? It seems like you are trying to train a model with more than 1 output. If that's the case, then you need to use the models whose names end with Plus. Right now you are using TST, which can be used for single-target prediction.
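As a rough check of the advice above, the number of target values per sample can be read off y.shape. The helper below is a hypothetical sketch (not part of tsai), assuming y is shaped (n_samples, ...) and that more than one value per sample means a multi-output problem:

```python
from math import prod


def needs_plus_model(y_shape):
    """Hypothetical helper (not a tsai API): True if each sample has more than one target value.

    Assumes y_shape is (n_samples, *per_sample_dims), e.g. the value of y.shape.
    """
    per_sample = y_shape[1:]
    # (n_samples,) or (n_samples, 1) means a single output per sample; prod(()) == 1
    return prod(per_sample) > 1


print(needs_plus_model((672,)))        # False: one target per sample
print(needs_plus_model((672, 15, 1)))  # True: 15 target values per sample
```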
Hey @HasnainKhanNiazi, both inputs and outputs are 3D tensors with the following shapes:
> It seems like you are trying to train a model with more than 1 output. If that's the case, then you need to use models ending with Plus. Right now you are using TST, which can be used for one target prediction.
Where can I read about this? Can you point me to a specific link in the documentation?
Hi @deven-gqc, I ran a few experiments to see what the problem is. As I said earlier, TST can be used for single-variable/single-point prediction, but your target has size (15, 1), so you need to replace TST with TSTPlus.
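```python
import torch
from tsai.all import *  # provides TSRegressor, TST, TSTPlus, TSStandardize, get_splits, rmse, ShowGraph, np
X = torch.rand((672, 15, 1))  # (samples, variables, steps): 15 variables, 1 step
y = torch.rand((672, 15, 1))  # multi-output target: 15 values per sample
splits = get_splits(np.arange(len(X)), valid_size=0.2, stratify=False, show_plot=False)  # splits was undefined in the original snippet
batch_tfms = TSStandardize(by_sample=True)
# arch=TST raises "__init__() got an unexpected keyword argument 'custom_head'"; arch=TSTPlus works
reg = TSRegressor(X, y, splits=splits, path='models', arch=TST, batch_tfms=batch_tfms, metrics=rmse, cbs=ShowGraph(), verbose=True)
```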
The above code gives the same error you are getting, so just replace TST with TSTPlus and it will work fine.
You can read more about TST and TSTPlus in the documentation: https://timeseriesai.github.io/tsai/models.tst.html
Hi @deven-gqc,
It seems you are trying to create a model for a time series with 15 features and only 1 step. You can try, but I don't think you'll get good results with tsai for that type of data. tsai is designed to work with time series and sequential data, but 1 step is not a sequence.
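To make the shape interpretation concrete: tsai expects X as (samples, variables, steps), so (672, 15, 1) means 672 samples of 15 variables with a single time step each. A small hypothetical helper (not a tsai function) that spells this out:

```python
def describe_layout(x_shape):
    """Hypothetical helper: read a 3-D shape using tsai's (samples, variables, steps) convention."""
    if len(x_shape) != 3:
        raise ValueError(f"expected a 3-D shape, got {x_shape}")
    n_samples, n_vars, seq_len = x_shape
    return {
        "n_samples": n_samples,
        "n_vars": n_vars,
        "seq_len": seq_len,
        # With a single step there is no temporal order for a sequence model to exploit
        "is_sequence": seq_len > 1,
    }


print(describe_layout((672, 15, 1)))
# {'n_samples': 672, 'n_vars': 15, 'seq_len': 1, 'is_sequence': False}
```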
Thanks for the replies @oguiza and @HasnainKhanNiazi. I wanted to know: is there a doc on the site that says the custom_head argument is only supported by the Plus models?
@deven-gqc to the best of my knowledge, this isn't written anywhere in the documentation, but if you look at the source code you will find the custom_head param in every Plus model.
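Since this doesn't seem to be documented, one way to find out is to inspect a constructor's signature. A minimal sketch using Python's inspect module; PlainModel and PlusModel are hypothetical stand-ins, not the real tsai classes, though with tsai installed the same helper could be pointed at an actual architecture class:

```python
import inspect


def accepts_kwarg(cls, name):
    """Return True if cls.__init__ accepts a keyword argument called `name`."""
    params = inspect.signature(cls.__init__).parameters
    if name in params:
        return True
    # A **kwargs catch-all also lets the constructor accept arbitrary keywords
    return any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-ins, NOT the real tsai classes
class PlainModel:
    def __init__(self, c_in, c_out, seq_len):
        pass


class PlusModel:
    def __init__(self, c_in, c_out, seq_len, custom_head=None):
        pass


print(accepts_kwarg(PlainModel, "custom_head"))  # False
print(accepts_kwarg(PlusModel, "custom_head"))   # True
```

Note that a **kwargs catch-all only means the constructor itself won't reject the keyword; it can still fail wherever those kwargs are forwarded.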
@oguiza, @HasnainKhanNiazi
Hi, I am encountering a similar error, although mine occurs while attempting to use ROCKET on multivariate data.
If I input any multivariate custom data, I receive the error message __init__() got an unexpected keyword argument 'custom_head'.
I have tried following along with the ROCKET Google Colab notebook and see that I can use multivariate data if it is pulled directly from the UCR datasets using the get_UCR_data function. As a test, I converted this downloaded data to a pandas DataFrame and then back to X, y using the df2xy function. If I do this, I receive the same error message: __init__() got an unexpected keyword argument 'custom_head'.
Any help with this would be greatly appreciated. I assume this is occurring because y.shape changes from (234, 0) before conversion to (234, 10) after conversion. If this is the case, is it due to me converting the data incorrectly, or does the data need to be fed into ROCKET in a different format?
@Zwayeh I think the error occurs because the constructor of the Rocket class doesn't take the custom_head argument.
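The error message itself comes from plain Python: passing a keyword argument that a constructor does not declare raises a TypeError. A minimal reproduction, assuming (as the error suggests) that the learner forwards custom_head to the architecture; RocketLike and try_build are hypothetical names, not tsai code:

```python
class RocketLike:
    """Hypothetical stand-in for an architecture whose __init__ has no custom_head parameter."""

    def __init__(self, c_in, seq_len, n_kernels=10_000):
        self.c_in = c_in
        self.seq_len = seq_len
        self.n_kernels = n_kernels


def try_build(arch, **kwargs):
    """Instantiate arch with kwargs; return the TypeError message, or None on success."""
    try:
        arch(**kwargs)
    except TypeError as err:
        return str(err)
    return None


# Forwarding custom_head to a constructor that doesn't declare it fails
msg = try_build(RocketLike, c_in=3, seq_len=100, custom_head=None)
print(msg)  # message ends with: got an unexpected keyword argument 'custom_head'

# Dropping the unsupported keyword lets construction succeed
print(try_build(RocketLike, c_in=3, seq_len=100))  # None
```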
To quote @HasnainKhanNiazi:
> @deven-gqc to the best of my knowledge, I don't think this is written anywhere in the documentation but if you have a look at the source code then you will find custom_head param in every Plus model.
You would need to use the MiniRocketPlus model available → here
@deven-gqc Thanks so much, I was playing around with MiniRocket before but didn't realise I should've been utilising MiniRocketPlus.
I initially used it on some univariate data with great success so far (thanks to the excellent framework by @oguiza). Unfortunately, I am still running into some issues when using multivariate custom data.
I've been working on this for a day or two without success, so I'm reaching out in case there is something obvious I've missed that you may also have run into, even though you were using a different model entirely.
I am trying to use TSRegressor with the TST architecture, and I get a TypeError: __init__() got an unexpected keyword argument 'custom_head'
I tried the fix from #597, but it doesn't work, and the issue is too complex for me to understand. @oguiza can you please take a look at it?
Here is the entire stack trace: