timeseriesAI / tsai

State-of-the-art Deep Learning library for Time Series and Sequences in Pytorch / fastai
https://timeseriesai.github.io/tsai/
Apache License 2.0
5.07k stars, 633 forks

can not use TST in summary #710

Closed: YYKKKKXX closed this issue 1 year ago

YYKKKKXX commented 1 year ago

Hello, I have a problem using TST. My code is:

tst = TST(1,1,66,d_model=8,n_heads=4).cuda()
summary(tst,(1,66),batch_size=2)

and it raises the error: can't multiply sequence by non-int of type 'list'

I would appreciate your help!
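
For reference, the error text is a standard Python TypeError, raised when a sequence such as a list or tuple is multiplied by a list instead of an int. The sketch below only reproduces the message in pure Python. It does not show the failing line inside the summary function, which the thread does not include, and the variable names are hypothetical.

```python
# Pure-Python illustration of the error message only.
# This is NOT the code of the summary function; names are hypothetical.
shape = [1, 66]   # a shape stored as a Python list
factor = [2]      # a value that is a list where an int was expected

# Multiplying a list by an int repeats it, which is valid.
repeated = shape * 2

# Multiplying a list by another list is not defined and raises TypeError.
try:
    shape * factor
    message = None
except TypeError as exc:
    message = str(exc)

print(repeated)
print(message)
```

A message like this usually suggests that some shape-handling code received a nested list or tuple where it expected plain numbers.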

oguiza commented 1 year ago

Hi @YYKKKKXX, I don't know where the summary function is coming from. The way to use summary in tsai is this:

from tsai.basics import *
from tsai.models.TST import TST

X = np.random.rand(16, 1, 66)
y = np.random.rand(X.shape[0])
splits = TSSplitter(show_plot=False)(y)
tfms = [None, TSClassification()]
batch_tfms = TSStandardize(by_sample=True)
dls = get_ts_dls(X, y, splits=splits, tfms=tfms, batch_tfms=batch_tfms)
tst = TST(1,1,66,d_model=8,n_heads=4)
learn = ts_learner(dls, tst, metrics=accuracy, cbs=[ShowGraph()])
learn.summary()

or the equivalent:

from tsai.basics import *

X = np.random.rand(16, 1, 66)
y = np.random.rand(X.shape[0])
splits = TSSplitter(show_plot=False)(y)
tfms = [None, TSClassification()]
batch_tfms = TSStandardize(by_sample=True)
arch_config = dict(d_model=8,n_heads=4)
tscls = TSClassifier(X, y, splits=splits, tfms=tfms, batch_tfms=batch_tfms, 
                      arch="TST", arch_config=arch_config, metrics=accuracy, cbs=[ShowGraph()])
tscls.summary()

YYKKKKXX commented 1 year ago

Hi, thank you for the explanation. The summary function is imported from torchsummary; what I imported is 'from torchsummary import summary'. That function shows the output shape of each layer in a model, and I wanted to use it to show the outputs of my TST model. Now I know that I can use learn.summary()! Thank you very much.
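
For readers who want per-layer output shapes without a learner, one general approach in plain PyTorch is to attach forward hooks and run a single forward pass. The sketch below is an assumption-laden illustration, not how learn.summary() or torchsummary is implemented: the helper name layer_output_shapes is hypothetical, and a small stand-in network is used in place of TST so that the block is self-contained. The input has shape (2, 1, 66), matching the batch size and input size mentioned in the question.

```python
import torch
import torch.nn as nn


def layer_output_shapes(model, example_input):
    """Run one forward pass and record the output shape of each leaf module.

    Hypothetical helper for illustration; it is not part of tsai or torchsummary.
    """
    shapes = []
    hooks = []

    def make_hook(name):
        def hook(module, inputs, output):
            # Some modules return tuples; only plain tensor outputs are recorded.
            if isinstance(output, torch.Tensor):
                shapes.append((name, type(module).__name__, tuple(output.shape)))
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            hooks.append(module.register_forward_hook(make_hook(name)))

    try:
        model.eval()  # note: this switches the model to evaluation mode
        with torch.no_grad():
            model(example_input)
    finally:
        # Always remove the hooks so the model is left without them.
        for handle in hooks:
            handle.remove()
    return shapes


# Stand-in model (not TST): 1 input channel, sequence length 66.
toy = nn.Sequential(
    nn.Conv1d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 66, 1),
)

shapes = layer_output_shapes(toy, torch.rand(2, 1, 66))
for name, kind, shape in shapes:
    print(name, kind, shape)
```

Modules that return tuples are skipped by this sketch rather than summarized, so a model with such layers would show fewer rows than it has leaf modules.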