dannyneil / pytorch_plstm

MIT License
21 stars 3 forks source link

When peephole connections are included in Phased LSTM, performance seems to drop to that of a plain LSTM? #6

Open evileleven opened 4 months ago

evileleven commented 4 months ago

I used the example to distinguish the frequency of the sine wave.

When I include peephole connections in the Phased LSTM, its performance seems to be the same as that of a plain LSTM. Has anyone tried peephole connections?

evileleven commented 4 months ago

I just changed the code in plstm_cell.py as follows:

`import torch
from torch import nn
import math

# Slope of the small "leak" the time gate lets through outside its open
# window: the gate value there is OFF_SLOPE * (in_cycle_time / period).
OFF_SLOPE=1e-3

def set_grad(var):
    """Build a hook that stores the gradient it is given on ``var.grad``.

    Args:
        var: Object whose ``grad`` attribute the hook overwrites.

    Returns:
        A one-argument callable ``hook(grad)`` that assigns ``grad`` to
        ``var.grad`` and returns ``None``.
    """
    # NOTE(review): presumably meant for ``tensor.register_hook`` so that
    # non-leaf tensors keep their gradient -- confirm against the callers.
    def hook(grad):
        var.grad = grad

    return hook

class GradMod(torch.autograd.Function):
    """``torch.fmod`` with hand-written gradients for both operands.

    For ``r = fmod(x, y) = x - trunc(x / y) * y`` the partial derivatives are
    ``dr/dx = 1`` and ``dr/dy = -trunc(x / y)`` (treating the quotient as
    locally constant).
    """

    @staticmethod
    def forward(ctx, input, other):
        """Return ``torch.fmod(input, other)`` and save both operands.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            input: Dividend tensor.
            other: Divisor tensor (broadcastable against ``input``).

        Returns:
            The element-wise remainder, with the sign of ``input``.
        """
        result = torch.fmod(input, other)
        ctx.save_for_backward(input, other)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` to ``input`` and ``other``.

        Args:
            ctx: Autograd context holding the tensors saved in ``forward``.
            grad_output: Gradient of the loss with respect to the output.

        Returns:
            A pair ``(grad_input, grad_other)``, each reduced to the shape of
            the operand it belongs to.
        """
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is
        # deprecated.
        x, y = ctx.saved_tensors
        # fmod truncates toward zero, so the quotient must be truncated too;
        # a floored quotient is off by one whenever x / y is negative.
        quotient = torch.div(x, y, rounding_mode="trunc")
        grad_input = grad_output * 1
        grad_other = -grad_output * quotient
        # Undo broadcasting so each gradient has its operand's shape.
        return grad_input.sum_to_size(x.shape), grad_other.sum_to_size(y.shape)

class PLSTM(nn.Module):
    """Phased LSTM layer with an optional peephole variant.

    A standard LSTM update is computed at every step and then blended with
    the previous state through a per-unit, time-dependent gate: each hidden
    unit has a period, a phase shift and an open ratio, and only lets the
    update through while the (shifted) timestamp falls in its open window.

    Args:
        input_sz: Number of input features.
        hidden_sz: Number of hidden units.
        peephole: If true, the input and forget gates also see the previous
            cell state and the output gate sees the newly proposed cell
            state, through the element-wise weights ``ci``, ``cf``, ``co``.
    """

    def __init__(self, input_sz, hidden_sz, peephole=False):
        super().__init__()
        self.input_sz = input_sz
        self.hidden_size = hidden_sz
        # Set before init_weights() so the flag is available to anything
        # that runs during initialisation.
        self.peephole = peephole

        # Time-gate parameters, one row per hidden unit.
        self.Periods = nn.Parameter(torch.Tensor(hidden_sz, 1))
        self.Shifts = nn.Parameter(torch.Tensor(hidden_sz, 1))
        self.On_End = nn.Parameter(torch.Tensor(hidden_sz, 1))

        # Input and recurrent weights of the four gates, packed along the
        # last axis as [input | forget | candidate | output].
        self.W = nn.Parameter(torch.Tensor(input_sz, hidden_sz * 4))
        self.U = nn.Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))

        # Element-wise peephole weights. They are always registered so the
        # parameter set does not depend on the `peephole` flag.
        self.ci = nn.Parameter(torch.Tensor(hidden_sz))
        self.cf = nn.Parameter(torch.Tensor(hidden_sz))
        self.co = nn.Parameter(torch.Tensor(hidden_sz))

        self.bias = nn.Parameter(torch.Tensor(hidden_sz * 4))
        self.init_weights()

    def init_weights(self):
        """Initialise every parameter, then override the time-gate ones."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            nn.init.uniform_(weight, -stdv, stdv)

        # Phased LSTM
        # -----------------------------------------------------
        nn.init.constant_(self.On_End, 0.05)  # Set to be 5% "open"
        nn.init.uniform_(self.Shifts, 0, 100)  # Have a wide spread of shifts
        # Uniformly distribute periods in log space between exp(1) and exp(3)
        with torch.no_grad():
            self.Periods.copy_(
                torch.exp((3 - 1) * torch.rand(self.Periods.shape) + 1)
            )
        # -----------------------------------------------------

    def forward(self, x, ts, init_states=None):
        """Run the layer over a batch of timestamped sequences.

        Args:
            x: Inputs of shape (batch, sequence, feature).
            ts: Timestamps, indexed as ``ts[:, t]`` (one per batch element
                and step).
            init_states: Optional ``(h_0, c_0)`` pair; zeros when omitted.

        Returns:
            ``(hidden_seq, (h_t, c_t))`` where ``hidden_seq`` has shape
            (batch, sequence, hidden) and ``h_t`` / ``c_t`` are the final
            hidden and cell states.
        """
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        if init_states is None:
            h_t = torch.zeros(bs, self.hidden_size, device=x.device)
            c_t = torch.zeros(bs, self.hidden_size, device=x.device)
        else:
            h_t, c_t = init_states

        # PHASED LSTM
        # -----------------------------------------------------
        # Precalculate the per-unit quantities, shaped (1, hidden) so they
        # broadcast across the batch.
        shift_broadcast = self.Shifts.view(1, -1)
        period_broadcast = abs(self.Periods.view(1, -1))
        on_mid_broadcast = abs(self.On_End.view(1, -1)) * 0.5 * period_broadcast
        on_end_broadcast = abs(self.On_End.view(1, -1)) * period_broadcast

        def calc_time_gate(time_input_n):
            """Return the gate openness for one timestamp per batch row."""
            # Broadcast the time across all units
            t_broadcast = time_input_n.unsqueeze(-1)
            # Get the time within the period
            in_cycle_time = GradMod.apply(
                t_broadcast + shift_broadcast, period_broadcast
            )

            # Find the phase
            is_up_phase = torch.le(in_cycle_time, on_mid_broadcast)
            is_down_phase = (
                torch.gt(in_cycle_time, on_mid_broadcast)
                & torch.le(in_cycle_time, on_end_broadcast)
            )

            # Rise linearly to 1 during the first half of the open window,
            # fall back to 0 during the second half, and leak a little
            # (OFF_SLOPE) while closed.
            closed_value = OFF_SLOPE * (in_cycle_time / period_broadcast)
            falling_value = (on_end_broadcast - in_cycle_time) / on_mid_broadcast
            rising_value = in_cycle_time / on_mid_broadcast
            return torch.where(
                is_up_phase,
                rising_value,
                torch.where(is_down_phase, falling_value, closed_value),
            )
        # -----------------------------------------------------

        HS = self.hidden_size
        for t in range(seq_sz):
            old_c_t = c_t
            old_h_t = h_t
            x_t = x[:, t, :]
            t_t = ts[:, t]

            # Batch the four gate projections into one matrix multiplication.
            gates = x_t @ self.W + h_t @ self.U + self.bias
            i_pre = gates[:, :HS]
            f_pre = gates[:, HS:HS * 2]
            g_pre = gates[:, HS * 2:HS * 3]
            o_pre = gates[:, HS * 3:]

            if self.peephole:
                # Input and forget gates peek at the previous cell state.
                i_pre = i_pre + c_t * self.ci
                f_pre = f_pre + c_t * self.cf

            i_t = torch.sigmoid(i_pre)  # input
            f_t = torch.sigmoid(f_pre)  # forget
            g_t = torch.tanh(g_pre)  # candidate

            # Both variants use the tanh candidate g_t for the cell update.
            c_t = f_t * c_t + i_t * g_t

            if self.peephole:
                # Output gate peeks at the newly proposed cell state.
                o_pre = o_pre + c_t * self.co
            o_t = torch.sigmoid(o_pre)  # output
            h_t = o_t * torch.tanh(c_t)

            # PHASED LSTM
            # -----------------------------------------------------
            # Get time gate openness
            sleep_wake_mask = calc_time_gate(t_t)
            # Sleep if off, otherwise stay a bit on
            c_t = sleep_wake_mask * c_t + (1. - sleep_wake_mask) * old_c_t
            h_t = sleep_wake_mask * h_t + (1. - sleep_wake_mask) * old_h_t
            # -----------------------------------------------------
            hidden_seq.append(h_t.unsqueeze(0))

        hidden_seq = torch.cat(hidden_seq, dim=0)
        # reshape from (sequence, batch, feature) to (batch, sequence, feature)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        return hidden_seq, (h_t, c_t)
evileleven commented 4 months ago

I increased the number of epochs to 100, and the accuracy of the phased-peephole LSTM only starts to increase after about 20 epochs. image

But the accuracy of the phased LSTM starts to increase after only 5 or 6 epochs. image

Is it because the phased-peephole LSTM has more parameters to train than the vanilla phased LSTM?