calico / basenji

Sequential regulatory activity predictions with deep convolutional neural networks.
Apache License 2.0

Question - Akita Test Set Loss Replication #108

Closed: GreatArcStudios closed this issue 2 years ago

GreatArcStudios commented 2 years ago

In the Akita paper, the MSE on the test set is 0.14. In my replication tests using PyTorch, the best test set error was ~0.21. The hyperparameters are the ones from params.json in the akita folder, e.g., batch norm momentum = 0.9265, SGD momentum = 0.997, etc. I'm using the get_data.sh file. Is it safe to assume that the data from get_data.sh has been preprocessed? Otherwise, what issue or difference might be causing this discrepancy?

The one difference in parameter choice between our implementation and the paper is that we use GELU activation instead of ReLU.
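
One porting detail worth flagging here: Keras and PyTorch define batch norm momentum with opposite conventions, so a value like the 0.9265 from params.json may need converting when moving to PyTorch. A minimal sketch of what I mean (the layer sizes and learning rate below are placeholders, not values from params.json):

```python
import torch
import torch.nn as nn

# Keras BatchNormalization(momentum=m) keeps m of the running statistic;
# PyTorch BatchNorm1d(momentum=m) keeps (1 - m) of it. A Keras momentum
# of 0.9265 therefore corresponds to momentum = 1 - 0.9265 in PyTorch.
keras_bn_momentum = 0.9265
bn = nn.BatchNorm1d(96, momentum=1.0 - keras_bn_momentum)  # 96 channels: placeholder

# SGD momentum means the same thing in both frameworks.
model = nn.Sequential(nn.Conv1d(4, 96, kernel_size=11, padding=5), bn, nn.GELU())
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.997)  # lr: placeholder
```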

gfudenberg commented 2 years ago

Hi Eric, the TFRecord files should be preprocessed with akita_data.py. It would be interesting if the GELU vs. ReLU choice made that much of a difference, but the reported hyperparameters were found after tuning with dragonfly. Another thing to look into would be how much data augmentation you're doing in the PyTorch replication test.

GreatArcStudios commented 2 years ago

Hi Geoff,

We downloaded the data from https://github.com/calico/basenji/blob/master/manuscripts/akita/get_data.sh. This is the preprocessed data, right? Could you clarify what you mean by the amount of data augmentation?

gfudenberg commented 2 years ago

yes, those should be the pre-processed tfrecords.

for data augmentation, I'm referring to reverse complementing & shifting: https://github.com/calico/basenji/blob/19f5345be2726600c62df55c223d9f7184a68e7a/basenji/seqnn.py#L114
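
In PyTorch terms, that augmentation amounts to roughly the following (an illustrative sketch, not the actual TF code linked above; the A,C,G,T channel ordering and zero-padding of the shifted-in edge are assumptions of the sketch):

```python
import torch

def augment(seq_1hot: torch.Tensor, max_shift: int = 11) -> torch.Tensor:
    """Reverse-complement / stochastic-shift augmentation for one-hot DNA.

    seq_1hot: (length, 4) tensor, channels assumed ordered A, C, G, T.
    """
    # Stochastic shift: roll by a random offset in [-max_shift, max_shift]
    # and zero out the positions that wrapped around.
    shift = int(torch.randint(-max_shift, max_shift + 1, (1,)))
    if shift != 0:
        seq_1hot = torch.roll(seq_1hot, shifts=shift, dims=0)
        if shift > 0:
            seq_1hot[:shift] = 0
        else:
            seq_1hot[shift:] = 0

    # Reverse complement with probability 1/2: reversing both the position
    # and channel axes maps A<->T and C<->G under A,C,G,T ordering.
    if torch.rand(1).item() < 0.5:
        seq_1hot = torch.flip(seq_1hot, dims=[0, 1])
    return seq_1hot
```

Note that for Akita, reverse-complementing the input sequence also requires transforming the target contact map consistently.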

GreatArcStudios commented 2 years ago

Hmm, ok, we used a stochastic shift of 11. Should that be enough?

davek44 commented 2 years ago

Hi Eric, I did try GELU for Akita, and it was inferior to ReLU. However, I doubt it explains the substantial difference. A stochastic shift of 11 should be fine; the reverse complement augmentation is much more important. When we measure test set accuracy, we also average the forward and reverse complement maps as an ensemble.

In general, I've found it to be very difficult to port complex models across frameworks exactly. I'd encourage you to work with our code on small examples to replicate its behavior in pytorch, and then slowly scale up and add more complexity.

GreatArcStudios commented 2 years ago

Hi David,

I see. By averaging across the forward and reverse complement maps, do you mean you calculated the MSE over the ~400 test sequences for both the forward and reverse complement maps and took the simple average?

Yeah, we started small, trying to port modules of this repo to our code. Framework implementation and such can definitely be quite finicky, so thanks for the replies!

davek44 commented 2 years ago

Yes, ensemble averaging is performed in-graph here: https://github.com/calico/basenji/blob/master/basenji/seqnn.py#L218

Akita was more complicated than the 1D Basenji, but careful upper-triangular matrix bookkeeping allowed it to be done.
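
In dense 2D terms (setting aside the upper-triangular flattening the linked code actually works with), the averaging is conceptually something like this hypothetical sketch:

```python
import torch

def ensemble_fwd_rc(model, seq_1hot: torch.Tensor) -> torch.Tensor:
    """Average forward and reverse-complement contact-map predictions.

    Sketch on a dense (L, L) map; the linked seqnn.py code instead works
    on flattened upper-triangular vectors. `model` is assumed to map a
    (length, 4) one-hot sequence to an (L, L) contact map.
    """
    pred_fwd = model(seq_1hot)

    # Reverse complement the input (A,C,G,T channel ordering assumed).
    seq_rc = torch.flip(seq_1hot, dims=[0, 1])
    pred_rc = model(seq_rc)

    # Bin i of the RC prediction corresponds to bin L-1-i of the forward
    # map, so flip the RC map on both axes before averaging.
    return 0.5 * (pred_fwd + torch.flip(pred_rc, dims=[0, 1]))
```

The test set MSE is then computed against the target map using this averaged prediction.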

GreatArcStudios commented 2 years ago

Gotcha. I'm currently on my phone, so I haven't had a chance to look closely at that portion of the code, but it seems like it takes the non-RC predictions, also generates the RC predictions, and then averages them? Is that correct?

GreatArcStudios commented 2 years ago

Also, is the test set from the get_data shell script preprocessed too?

davek44 commented 2 years ago

Yes, all of the data downloaded by that script has been transformed as described in the paper, and is ready for training/testing.

GreatArcStudios commented 2 years ago

Ok thanks for the help!