lucidrains / iTransformer

Unofficial implementation of iTransformer - SOTA Time Series Forecasting using Attention networks, out of Tsinghua / Ant group
MIT License
445 stars 36 forks

Is it possible to implement a custom training loop? #33

Open Mikiya0515 opened 1 month ago

Mikiya0515 commented 1 month ago

Hi @lucidrains, thank you for your work.

Is it possible to implement a custom training loop using the following code?

Model

from iTransformer import iTransformer

model = iTransformer(
    num_variates=10,
    lookback_len=360,
    dim=367,
    depth=9,
    heads=6,
    dim_head=46,
    pred_length=pred_len,
    num_tokens_per_variate=6
)

Loss function and optimizer

criterion = torch.nn.MSELoss()
lr = 0.001
weight = 0.003
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr, weight_decay=weight)

Training loop

for i in range(1, epochs+1):
    train_losses = []
    print(f"Epoch_{i}-------------------------------------")
    model.train()
    for batch_idx, (x_train, y_train) in enumerate(train_loader):
        x_train, y_train = x_train.to(device), y_train.to(device)

        optimizer.zero_grad()
        train_output = model(x_train)

        # the model returns a Dict[int, Tensor] keyed by prediction length
        loss = criterion(train_output[pred_len], y_train)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
        if batch_idx % 50 == 0:
            print(f"train_loss: {loss.item()}[{batch_idx}/{len(train_loader)}]")

    val_losses = []
    model.eval()
    with torch.no_grad():
        for batch_idx, (x_val, y_val) in enumerate(val_loader):
            x_val, y_val = x_val.to(device), y_val.to(device)
            val_output = model(x_val)
            loss = criterion(val_output[pred_len], y_val)
            val_losses.append(loss.item())
            if batch_idx % 50 == 0:
                print(f"val_loss: {loss.item()}[{batch_idx}/{len(val_loader)}]")
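One thing the loop above never does is summarize the `train_losses` and `val_losses` lists it collects each epoch. A minimal sketch of an epoch summary (the `epoch_summary` helper is my own, not part of iTransformer; the list names are taken from the loop above):

```python
# Sketch: average the per-batch losses collected during one epoch.
# train_losses / val_losses are the lists built in the loop above.
def epoch_summary(train_losses, val_losses):
    # guard against empty loaders so we never divide by zero
    avg_train = sum(train_losses) / max(len(train_losses), 1)
    avg_val = sum(val_losses) / max(len(val_losses), 1)
    return avg_train, avg_val

# example with dummy per-batch losses
avg_train, avg_val = epoch_summary([0.5, 0.3], [0.4, 0.2])
print(f"avg_train_loss: {avg_train:.4f} | avg_val_loss: {avg_val:.4f}")
```

Printing (or logging) these averages at the end of each epoch makes it much easier to spot overfitting than the per-batch prints alone.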

ikhsansdqq commented 1 month ago

You can indeed use a custom training loop. iTransformer is built on PyTorch, so you can modify the training code as much as you like, as long as it follows standard PyTorch conventions.

I used a custom training loop and tested it with different loss criteria, and it worked perfectly fine for me.
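Swapping the criterion is a one-line change, since any PyTorch loss with the `(input, target)` call signature drops into the loop unchanged. A quick sketch comparing MSE and L1 (MAE) on dummy tensors (the values here are illustrative, not from the dataset in question):

```python
import torch

# dummy prediction/target pair standing in for model output and y_train
pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target = torch.tensor([[1.5, 2.0], [2.0, 4.0]])

# both losses share the (input, target) signature used in the loop above
mse = torch.nn.MSELoss()(pred, target)  # mean of squared errors
mae = torch.nn.L1Loss()(pred, target)   # mean of absolute errors
print(mse.item(), mae.item())  # → 0.3125 0.375
```

MSE penalizes large errors more heavily than MAE, so for time series with occasional spikes the two can lead to noticeably different forecasts.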

Mikiya0515 commented 1 month ago

@ikhsansdqq Thank you for your answer! I'm going to try a custom training loop!