stefanonardo / pytorch-esn

An Echo State Network module for PyTorch.
MIT License

How to train the ESN in online mode? #21

Open ruitang-git opened 6 months ago

ruitang-git commented 6 months ago

When I train the readout with SVD on the Mackey-Glass data, the error drops to 3.2e-11.

When I switch to gradient descent, the error is still about 1e-4 after 100 epochs. I have no idea how to tune the parameters to get the error that low. Can anybody help?

    # Wrap the input/target tensors in Dataset objects
    train_dataset = torch.utils.data.TensorDataset(trX, trY)
    test_dataset = torch.utils.data.TensorDataset(tsX, tsY)

    # Wrap the datasets in DataLoaders (no shuffling, so time order is kept)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1024, shuffle=False)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1024, shuffle=False)

    model = ESN(input_size, hidden_size, output_size, readout_training='gd')
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0)
    epochs = 100
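    for epoch in range(epochs):
        hidden = None  # reset the reservoir state at the start of each epoch
        train_loss = 0
        for batch in train_dataloader:
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            # carry the hidden state over from the previous batch
            output, hidden = model(x, washout, hidden)
            # compare against the targets that remain after the washout steps
            loss = loss_fcn(output, y[washout[0]:])
            opt.zero_grad()
            loss.backward()
            opt.step()
            train_loss += loss.item()
        # mean loss over the batches of this epoch
        print("Training error:", train_loss/len(train_dataloader))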
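To make the comparison concrete, here is a minimal sketch in plain PyTorch that does not use the pytorch-esn API. It builds a small fixed random reservoir, fits a linear readout once in closed form with `torch.linalg.lstsq` and once with Adam, and prints both errors. The toy signal (a stand-in for Mackey-Glass), the reservoir size, the spectral radius and the learning-rate schedule are all assumptions chosen for illustration, so the numbers will not match the ones above. The point is only that the closed-form fit is the least-squares minimum on the training data, so gradient descent can approach it but not beat it.

```python
import torch

torch.manual_seed(0)
dtype = torch.float64

# Toy signal standing in for Mackey-Glass (assumption, for illustration only).
t = torch.arange(0, 600, dtype=dtype)
series = torch.sin(0.2 * t) * torch.cos(0.031 * t)
u, target = series[:-1], series[1:]  # one-step-ahead prediction

# Fixed random reservoir (hypothetical sizes, not the pytorch-esn internals).
hidden_size, washout = 50, 100
w_in = torch.rand(hidden_size, dtype=dtype) - 0.5
w = torch.rand(hidden_size, hidden_size, dtype=dtype) - 0.5
w *= 0.9 / torch.linalg.eigvals(w).abs().max()  # rescale to spectral radius 0.9

# Run the reservoir over the input and collect its states.
h = torch.zeros(hidden_size, dtype=dtype)
states = []
for u_t in u:
    h = torch.tanh(w_in * u_t + w @ h)
    states.append(h)

# Drop the washout steps and append a bias column.
X = torch.stack(states)[washout:]
X = torch.cat([X, torch.ones(len(X), 1, dtype=dtype)], dim=1)
y = target[washout:].unsqueeze(1)

# Closed-form readout: least-squares solution of X @ beta = y.
beta_ls = torch.linalg.lstsq(X, y).solution
mse_ls = torch.mean((X @ beta_ls - y) ** 2).item()

# Gradient-descent readout: same linear model, trained full-batch with Adam.
beta = torch.zeros(X.shape[1], 1, dtype=dtype, requires_grad=True)
opt = torch.optim.Adam([beta], lr=0.01)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=100, gamma=0.5)
for step in range(300):
    opt.zero_grad()
    loss = torch.mean((X @ beta - y) ** 2)
    loss.backward()
    opt.step()
    sched.step()  # halve the learning rate every 100 steps
mse_gd = torch.mean((X @ beta.detach() - y) ** 2).item()

print(f"closed-form MSE: {mse_ls:.3e}")
print(f"Adam MSE:        {mse_gd:.3e}")
```

If the gap in a setup like this stays large, the usual things to try are more optimizer steps, a decaying learning rate, and double precision. Whether any of these closes the gap for the Mackey-Glass run above is untested here.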