davidtvs / pytorch-lr-finder

A learning rate range test implementation in PyTorch

How to use w/ LSTM #51

Closed: phiweger closed this issue 4 years ago

phiweger commented 4 years ago

Hi,

I would like to use the lr-finder with an LSTM. The forward step of my model is called from a training loop like this:

for epoch in range(100):
    model.train()
    hidden = model.init_hidden(batch_size)
    total_loss = 0

    for data, target in dataloader:
        hidden = repackage_hidden(hidden)
        output, hidden = model(data, hidden)
        loss = loss_fn(output, target.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

My dataloader yields x, y, where x is a sequence and y is that sequence shifted by one step, i.e. the next element at each position (think language model), e.g.

data = [1, 2, 3, 4]
target = [2, 3, 4, 5]
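
(For completeness, repackage_hidden is the usual helper from the PyTorch word-language-model example that detaches the hidden state from the previous graph so gradients don't flow through the entire history; roughly:)

import torch

def repackage_hidden(h):
    # Detach hidden states from the graph they were computed in.
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)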

Now when I try to do:

lr_finder.range_test(dataloader, end_lr=100, num_iter=100)

... I get the following error:

TypeError: forward() missing 1 required positional argument: 'hidden'

How can I pass hidden to the model using lr-finder?

NaleRaphael commented 4 years ago

Hi @phiweger

Generally, it's recommended to move code related to data processing into Dataset.__getitem__() or model.forward(), which makes it easier to reuse your model and loss function. (You can check out this comment made by @davidtvs to understand it further.)
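
For instance, the `target.view(-1)` part alone could be moved out of the training loop and into the data pipeline; here is a rough sketch using a custom `collate_fn` (the `dataset` / `batch_size` names are just placeholders, not from your code):

from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

def collate_flat_targets(batch):
    # Batch samples as usual, but return the targets already flattened so
    # that `loss_fn(output, target)` can be called directly later on.
    data, target = default_collate(batch)
    return data, target.view(-1)

dataloader = DataLoader(dataset, batch_size=batch_size,
                        collate_fn=collate_flat_targets)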

If changing the architecture won't take too much effort, you can try rewriting your model along these lines:

class MyLSTM(nn.Module):
    def __init__(self, ...):
        super().__init__()
        # (Assuming `self.input_dim`, `self.hidden_dim` and `self.layers` are
        # set from the constructor arguments.)
        self.lstm = nn.LSTM(self.input_dim, self.hidden_dim, self.layers)
        self.last_layer = nn.Sequential(nn.Linear(...), nn.Sigmoid())

        # Just initialize it to 1; we can change it later via `set_batch_size()`.
        self.batch_size = 1

        # Since the hidden state should be kept for the whole epoch, we can
        # cache it here.
        self.hidden = None

    def set_batch_size(self, val):
        self.batch_size = val

    def init_hidden(self):
        # LSTM requires a tuple (h_n, c_n)
        self.hidden = (
            torch.rand((self.layers, self.batch_size, self.hidden_dim)),
            torch.rand((self.layers, self.batch_size, self.hidden_dim))
        )

    def forward(self, inputs):
        # Call your `repackage_hidden()` function here to detach the hidden
        # state from the previous graph
        self.hidden = repackage_hidden(self.hidden)

        rnn_output, self.hidden = self.lstm(inputs, self.hidden)
        outputs = self.last_layer(rnn_output)
        return outputs

With the new model, you might also need to modify your training loop:

model = MyLSTM(...)
model.set_batch_size(...)

model.train()

for epoch in range(100):
    # Initialize hidden state at the beginning of each epoch
    model.init_hidden()
    total_loss = 0

    # Note that in this example `target` is assumed to be returned by the
    # `dataloader` already in the required shape, which simplifies the
    # training loop below.
    for data, target in dataloader:
        output = model(data)
        loss = loss_fn(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

Then you can use LRFinder in this way:

lr_finder = LRFinder(model, optimizer, criterion)

# Initialize hidden state before running `range_test`
model.init_hidden()

lr_finder.range_test(dataloader, num_iter=100)
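
Afterwards, you can typically inspect the result and restore the model and optimizer to their initial state (as in the README examples):

lr_finder.plot()   # plot the loss vs. learning rate curve
lr_finder.reset()  # restore the model and optimizer to their initial state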

However, that is not a strict convention, and there are still many applications that don't follow it. In that case, you can write a couple of wrappers to make it work (just like the example @davidtvs shows in this post):

class MyModelWrapper(nn.Module):
    def __init__(self, lstm_model):
        super().__init__()
        self.lstm_model = lstm_model
        self.hidden = None

    def init_hidden(self, batch_size):
        # Just call the original impl. of `init_hidden()`
        self.hidden = self.lstm_model.init_hidden(batch_size)

    def forward(self, inputs):
        assert self.hidden is not None, "hidden state is not initialized."

        # Call your `repackage_hidden()` function here
        self.hidden = repackage_hidden(self.hidden)

        output, self.hidden = self.lstm_model(inputs, self.hidden)
        return output

class MyLossFunctionWrapper(nn.Module):
    def __init__(self, loss_fn):
        super().__init__()
        self.loss_fn = loss_fn

    def forward(self, outputs, targets):
        return self.loss_fn(outputs, targets.view(-1))

With these wrappers, you can use LRFinder in this way:

# Wrap your original model and loss function
model_wrapper = MyModelWrapper(model)
loss_wrapper = MyLossFunctionWrapper(loss_fn)

lr_finder = LRFinder(model_wrapper, optimizer, loss_wrapper)

# Initialize hidden state before running `range_test`
model_wrapper.init_hidden(batch_size)

lr_finder.range_test(dataloader, num_iter=100)

The benefit of adopting this approach is that you won't need to modify your existing training loop.
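
For example, the loop from your question would stay structurally the same; you'd only swap in the wrapped objects (a sketch):

for epoch in range(100):
    model_wrapper.init_hidden(batch_size)
    total_loss = 0

    for data, target in dataloader:
        # The hidden state and the `target.view(-1)` reshaping are now
        # handled inside the wrappers.
        output = model_wrapper(data)
        loss = loss_wrapper(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()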

But since I'm not familiar with developing language models with LSTMs, I'm not sure whether it's correct to initialize the hidden state only at the beginning of each training epoch. (I just found this post talking about it.)

Anyway, just feel free to let me know if there is something I missed.

phiweger commented 4 years ago

AWESOME answer, thank you very much. I ended up rewriting the model class to integrate all the hidden init/detach logic and wrapped the loss fn as you suggested. Works like a charm. Again, thank you for taking the time to write such a helpful and detailed answer.