fani-lab / Osprey

Online Predatory Conversation Detection

RNN classifier #25

Open rezaBarzgar opened 1 year ago

rezaBarzgar commented 1 year ago

I will document my process of building a Recurrent Neural Network (RNN) here.

Some useful links:

rezaBarzgar commented 1 year ago

A new issue related to the RNN: torch.nn.RNN does not handle the last batch when its size is smaller than the given batch size. The simplest workaround is to set drop_last=True in the DataLoader, but we need a better way to handle this.
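For reference, a minimal sketch of the two options. It assumes (the thread does not show the code) that the failure comes from an initial hidden state allocated with a fixed batch size. The tensor sizes and variable names are made up for illustration.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy data: 10 sequences, 5 time steps, 4 features (all sizes are illustrative).
features = torch.randn(10, 5, 4)
labels = torch.randint(0, 2, (10,)).float()
dataset = TensorDataset(features, labels)

# Option 1: drop the incomplete final batch (simple, but discards samples).
dropping_loader = DataLoader(dataset, batch_size=4, drop_last=True)
dropped_sizes = [x.size(0) for x, _ in dropping_loader]

# Option 2: keep every batch and size the initial hidden state from the
# batch that actually arrives, instead of from the configured batch size.
rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
full_loader = DataLoader(dataset, batch_size=4)
kept_sizes = []
for x, _ in full_loader:
    h0 = torch.zeros(1, x.size(0), 8)  # (num_layers, batch, hidden_size)
    output, hn = rnn(x, h0)
    kept_sizes.append(output.size(0))
```

With option 2 no samples are lost, since the hidden state always matches the current batch. Omitting `h0` entirely also works, because `nn.RNN` then defaults to a zero hidden state of the right size.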

hamedwaezi01 commented 1 year ago

Hi, I am fixing the code for the RNN baseline. Most of the work is done, but there have been issues with handling different sequence lengths and with NaN predictions from the RNN module. I have already addressed the former by padding the sequences in each batch to the length of the longest sequence. For example, take a batch of 8 sequences in which each token vector has a length of 1000. Before padding, the batch has shape (8, x, 1000), where x is the sequence length and differs from sequence to sequence. If the longest sequence has length 10 (10 tokens), the batch is padded to shape (8, 10, 1000). (I haven't committed the changes yet.) I am now trying to figure out the problem with the RNN's NaN predictions.
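The padding step described above could be sketched roughly as follows. This is not the uncommitted code from the repository. `pad_collate` is a hypothetical helper name, and the random tensors stand in for real token vectors.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_collate(batch):
    """Pad every sequence in the batch to the length of the longest one."""
    sequences, labels = zip(*batch)
    # Keep the true lengths so padded positions can be ignored later.
    lengths = torch.tensor([seq.size(0) for seq in sequences])
    # Zero-pads along the time axis: result is (batch, max_len, features).
    padded = pad_sequence(sequences, batch_first=True)
    return padded, lengths, torch.stack(labels)

# 8 sequences of different lengths, each token a 1000-dimensional vector.
batch = [(torch.randn(n, 1000), torch.tensor(0.0)) for n in (3, 10, 7, 1, 5, 10, 2, 4)]
padded, lengths, labels = pad_collate(batch)
```

Such a function can be passed to `DataLoader` through its `collate_fn` argument. The returned `lengths` can be given to `torch.nn.utils.rnn.pack_padded_sequence` so that the RNN skips the padded time steps.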

hamedwaezi01 commented 1 year ago

I am still working on the RNN problem. Some suggest that it might be caused by the exploding gradient problem. I started using BCEWithLogitsLoss and removed the sigmoid from the last layer, since the loss function applies the sigmoid itself (which is numerically more stable). That did not help much, since I still got NaN values from the RNN. I read a couple of issues, and they suggested decreasing the learning rate significantly. Now I no longer get NaN values, but the model predicts only zeros. If you have any suggestions, please let me know.
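A rough sketch of the setup described above, with one addition that is not from this thread: gradient-norm clipping, a common remedy when exploding gradients are suspected. The model, layer sizes, learning rate, and `max_norm` value are all illustrative assumptions, not the repository's actual configuration.

```python
import torch
from torch import nn

torch.manual_seed(0)

class RnnClassifier(nn.Module):
    """Hypothetical minimal classifier; sizes are illustrative."""

    def __init__(self, input_size=16, hidden_size=8):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, 1)  # raw logits, no sigmoid here

    def forward(self, x):
        _, hn = self.rnn(x)                  # hn: (num_layers, batch, hidden)
        return self.out(hn[-1]).squeeze(-1)  # logits: (batch,)

model = RnnClassifier()
criterion = nn.BCEWithLogitsLoss()           # applies the sigmoid internally
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # deliberately small

x = torch.randn(8, 10, 16)                   # (batch, seq_len, features)
y = torch.randint(0, 2, (8,)).float()

optimizer.zero_grad()
logits = model(x)
loss = criterion(logits, y)
loss.backward()
# Assumption: rescale gradients so their total norm is at most 1.0.
grad_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# The sigmoid is applied only when probabilities are needed, e.g. at inference.
probabilities = torch.sigmoid(logits.detach())
```

Logging `grad_norm` over the first few hundred steps is one way to check whether the gradients are in fact exploding before lowering the learning rate further.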

hosseinfani commented 1 year ago

@hamedwaezi01 Sorry for the late reply. I believe the issue is solved now.