JWFanggit / LOTVS-CAP

Is it possible that the hidden state h in a GRU network does not update as time t increases? #9

Open DaliMIT01 opened 4 days ago

DaliMIT01 commented 4 days ago

Hello @JWFanggit, there may be an error in the code. When feeding 150 consecutive frames into a GRU network, each time step should take the current frame (a new image frame) together with the hidden state from the previous time step. But in the code, hh stays at its zero initialization and is never updated (see the sketch after the quoted code). From model.py:

def forward(self, x, z, y, toa, w):
        # x: rgb, z: focus, y: label (positive/negative), toa: time to accident, w: word (text)
        losses = {'total_loss': 0}
        all_output = []
        x_11 = x
        # hh is the initial hidden state
        hh = Variable(torch.zeros(self.n_layers, x_11.size(0), self.h_dim))
        hh = hh.to(device)
        for i in range(x.size(1)):
            x1 = x_11[:, i]
            x2 = z[:, i]
            tokens_tensor, input_masks_tensors = self.text(w)
            x = self.fusion(tokens_tensor, input_masks_tensors, x1)
            x = self.features(x)
            x = x.permute(0, 2, 1).contiguous()
            output1 = self.gru_net(x, hh)  # TODO: does hh need to be updated?
            # can also output foucs_p
            foucs_p = self.deconv(x)
            L1 = self._exp_loss(output1, y, i, toa, fps=30.0)
            L2 = self.kl_loss(x2, foucs_p)
            loss_sum = (5 * L1 + L2).mean()
            losses['total_loss'] += loss_sum
            all_output.append(output1)
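
For comparison, here is a minimal, self-contained sketch of how the hidden state is normally threaded through torch.nn.GRU over a frame sequence. This is not the repository's gru_net; the names (frames, feat_dim, etc.) and dimensions are placeholders. The key point is that hh is reassigned from the GRU's second return value on every step, so step t+1 sees the state produced at step t:

    import torch
    import torch.nn as nn

    batch, seq_len, feat_dim, h_dim, n_layers = 2, 150, 512, 256, 1
    gru = nn.GRU(input_size=feat_dim, hidden_size=h_dim,
                 num_layers=n_layers, batch_first=True)

    frames = torch.randn(batch, seq_len, feat_dim)  # stand-in for per-frame features
    hh = torch.zeros(n_layers, batch, h_dim)        # initial hidden state

    outputs = []
    for t in range(seq_len):
        x_t = frames[:, t].unsqueeze(1)   # current frame: (batch, 1, feat_dim)
        out, hh = gru(x_t, hh)            # hh is reassigned, carrying state to the next frame
        outputs.append(out.squeeze(1))

If self.gru_net currently returns only output1, it would presumably also need to return the updated hidden state so the loop can do something like output1, hh = self.gru_net(x, hh) before processing the next frame.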