pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Simple LSTM model freezes on first iteration #1196

Closed: gvamvou closed this issue 5 years ago

gvamvou commented 5 years ago

Hello and thank you for providing a way to run PyTorch on TPUs! I'm trying to test a simple example taken from https://github.com/pytorch/examples/tree/master/time_sequence_prediction. I followed the API Guide to make the necessary changes so it can run on a single TPU. These are the changes I made to the train.py script:


from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import torch_xla
import torch_xla.core.xla_model as xm

class Sequence(torch.nn.Module):
    def __init__(self):
        super(Sequence, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, inp, device, future=0):
        outputs = []
        h_t = torch.zeros(inp.size(0), 51, dtype=torch.double).to(device)
        c_t = torch.zeros(inp.size(0), 51, dtype=torch.double).to(device)
        h_t2 = torch.zeros(inp.size(0), 51, dtype=torch.double).to(device)
        c_t2 = torch.zeros(inp.size(0), 51, dtype=torch.double).to(device)

        for i, input_t in enumerate(inp.chunk(inp.size(1), dim=1)):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]
        for i in range(future):  # if we should predict the future
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]
        outputs = torch.stack(outputs, 1).squeeze(2)
        return outputs

if __name__ == '__main__':
    device = xm.xla_device()

    # set random seed to 0
    np.random.seed(0)
    torch.manual_seed(0)
    # load data and make training set
    data = torch.load('traindata.pt')
    input = torch.from_numpy(data[3:, :-1]).to(device)
    target = torch.from_numpy(data[3:, 1:]).to(device)
    test_input = torch.from_numpy(data[:3, :-1]).to(device)
    test_target = torch.from_numpy(data[:3, 1:]).to(device)
    # build the model
    seq = Sequence().to(device)
    seq.double()
    criterion = nn.MSELoss()
    # the original example uses LBFGS; plain SGD is used here for the TPU test
    optimizer = optim.SGD(seq.parameters(), lr=0.8)
    #begin to train
    for i in range(15):
        print('STEP: ', i)
        def closure():
            optimizer.zero_grad()
            out = seq(input, device)
            loss = criterion(out, target)
            print('loss:', loss.item())
            loss.backward()
            print('after loss backward')
            return loss
        loss = closure()
        print('b4 step')
        xm.optimizer_step(optimizer, barrier=True)
        # begin to predict, no need to track gradient here
        with torch.no_grad():
            future = 1000
            pred = seq(test_input, device, future=future)
            loss = criterion(pred[:, :-future], test_target)
            print('test loss:', loss.item())
            y = pred.detach().cpu().numpy()  # move off the XLA device before converting to numpy

The program gets stuck right after the loss is calculated. Is there something I might have missed? Thank you for your time!

dlibenzi commented 5 years ago

Hi! Can you try this version (I added a couple of debug options):

https://gist.github.com/dlibenzi/f021e40bb0524ce368b58aaf478d689c

Also, export these in the environment:

export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1
export TF_CPP_VMODULE=tensor=5

Given that you gave us a full repro, if you could also include a minimal traindata.pt file, we could try it on our side as well (or provide the required input shape so we can use our xu.SampleGenerator).
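For reference, a minimal sketch of what feeding xu.SampleGenerator could look like with synthetic data in place of traindata.pt (the (97, 999) double shape is an assumption based on the sine-wave generator in the linked example, not something confirmed here):

import torch
import torch_xla.utils.utils as xu

# Synthetic (input, target) pair shaped like the sliced sine-wave data; shapes are assumed.
train_loader = xu.SampleGenerator(
    data=(torch.zeros(97, 999, dtype=torch.double),
          torch.zeros(97, 999, dtype=torch.double)),
    sample_count=1)

for inp, tgt in train_loader:
    pass  # use these in place of the tensors loaded from traindata.pt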

gvamvou commented 5 years ago

Hello and thank you for your answer. I tried your version with the exports you recommended; here's the log: log.txt. The traindata.pt file was generated with the generate_sine_wave.py script in the repo I linked in the message above.

dlibenzi commented 5 years ago

With this, I was able to get out the metrics report:

https://gist.github.com/dlibenzi/a18ab4b6e675208bf8277a53a0f42be5

It shows one op we still need to lower to XLA (it currently falls back to the PyTorch CPU implementation):

Counter: aten::mse_loss
  Value: 1

The model uses torch int64 (long), which is not necessary. Either change it to int32 (torch int) or:

export XLA_USE_32BIT_LONG=1
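If you prefer changing the dtype rather than setting the env var, the cast is just (illustrative tensor only, not taken from the script above):

# Illustrative: build or convert index-like tensors as int32 before moving them to the XLA device.
idx = torch.arange(100, dtype=torch.int32).to(device)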
dlibenzi commented 5 years ago

While we implement that (we will do it soon), maybe you can try some other loss function. If you are on nightly (Docker image and TPU VM), we have L1Loss.
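For example, a possible stopgap in the script above (a sketch; L1Loss changes the objective, and the hand-rolled mean squared error keeps the math while using only element-wise ops):

criterion = nn.L1Loss()  # available on nightly builds
# or compute the squared error manually:
# loss = ((out - target) ** 2).mean()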

dlibenzi commented 5 years ago

I was able to get the source attached below running on a Cloud TPU.

TRIM_GRAPH_CHECK_FREQUENCY=100000 TRIM_GRAPH_SIZE=1000000 XLA_USE_32BIT_LONG=1 python ~/lstm.py

The first run takes a while to compile, but once the compilation is cached (unless the shapes or the graph change), the next runs should come pretty quickly. The ExecuteTime looks pretty small, almost too small 😄

Source:

https://gist.github.com/dlibenzi/0be1687922d99f55188f585ac7ac1534

I have disabled the future prediction in this run; unrolling 1000 extra LSTM steps would definitely be too much.

Metrics:

https://gist.github.com/dlibenzi/54fd4433260ff62c715cb8cb9f5d7747

gvamvou commented 5 years ago

I ran your modified code and after a while it finally worked! I intend to run more complex models (encoder/decoder with attention) using these methods. Thank you very much for your support :smile:

dlibenzi commented 5 years ago

Note that with pytorch/xla some models (like ones with long explicit loops) can have a long compilation time at the beginning. Unless the graph (or tensor shapes) change, subsequent runs will hit the compilation cache (unless you explicitly restart the TPU node).
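One way to check whether you are hitting the cache (a sketch using the debug metrics module, not something from this thread) is to watch the CompileTime metric between steps; its sample count should stop growing once the graph is stable:

import torch_xla.debug.metrics as met

# 'CompileTime' is populated once at least one XLA compilation has happened.
print(met.metric_data('CompileTime'))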

gvamvou commented 5 years ago

I see. Would it be faster if the model were serialized with JIT?

dlibenzi commented 5 years ago

I honestly think native PyTorch on GPU might be faster, but we are at a very early stage and have not done any performance comparison.

gvamvou commented 5 years ago

I understand. Thanks again for all the help :smiley: