nanoporetech / bonito

A PyTorch Basecaller for Oxford Nanopore Reads
https://nanoporetech.com/

CUDA implementation for training requires improvement #7

Closed: kishwarshafin closed this issue 4 years ago.

kishwarshafin commented 4 years ago

Hi @iiSeymour,

Over the weekend, I redesigned the training implementation for Bonito in my fork. The main improvement was to use PyTorch's DataParallel class. I've also changed some parameters for my convenience.

Using AMP together with DataParallel allows the batch size to go from 64 to 384 on an 8-GPU machine, and the ETA per epoch drops to ~50 mins, compared with ~6 hours without DataParallel.
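For readers unfamiliar with the pattern, here is a minimal sketch of a DataParallel + AMP training step. It is not the code from the fork: `ToyModel`, `train_steps`, the layer sizes and the random data are hypothetical stand-ins for Bonito's network and training chunks, and it falls back to plain CPU training (with AMP disabled) when no GPU is present.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for the basecaller network."""

    def __init__(self, n_features=16, n_classes=5):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(n_features, 64), nn.ReLU(), nn.Linear(64, n_classes))

    def forward(self, x):
        return self.net(x)


def train_steps(n_steps=3, batch_size=384):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = ToyModel().to(device)
    # DataParallel splits each batch along dim 0 across the visible GPUs
    # and gathers the outputs back on the first device.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # Loss scaling and autocast are only enabled on CUDA; on CPU they are no-ops here.
    if hasattr(getattr(torch, "amp", None), "GradScaler"):  # newer PyTorch
        scaler = torch.amp.GradScaler("cuda", enabled=use_cuda)
    else:  # older PyTorch
        scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
    loss_fn = nn.CrossEntropyLoss()
    losses = []
    for _ in range(n_steps):
        x = torch.randn(batch_size, 16, device=device)
        y = torch.randint(0, 5, (batch_size,), device=device)
        optimizer.zero_grad()
        # Mixed-precision forward pass (float16 on GPU when enabled).
        with torch.autocast(device_type=device.type, enabled=use_cuda):
            loss = loss_fn(model(x), y)
        # Scale the loss to avoid float16 gradient underflow, then step and update the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        losses.append(loss.item())
    return losses
```

With 8 GPUs, a batch of 384 is scattered as 48 samples per device, so each GPU holds less activation memory than a single GPU processing the whole batch would.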

Runtime estimation: [image not included]

GPU usage: [image not included]

I think the best approach to training would be to use AMP with DistributedDataParallel to increase the training speed further. But first, I want to make sure that this method works.

To that end, could you please let me know which species these reads are from, so I can validate the models on a held-out species? Also, please feel free to close the issue.

iiSeymour commented 4 years ago

Hey @kishwarshafin

That's really great! I'll keep an eye on your fork. Multi-GPU training is/was on my todo list but I'm still parallelising ideas and doing hyperparameter searches across GPUs right now.

It's also worth noting that the best models I have managed to train came before 2d500126fc169c610b96b5bd70e0f3d86c74df6b, where I moved away from taking fixed-length chunks. I suspect that the default value of 3% for samples-per-read is too low.
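The thread does not define samples-per-read precisely. Assuming it means the fraction of each read's signal that is sampled as random fixed-length chunks, this small sketch shows how little of a read the model would see at 3%. The function names, chunk length and read length are hypothetical, not Bonito's.

```python
import random


def random_chunks(signal_len, chunk_len, fraction, seed=0):
    """Hypothetical: draw random fixed-length chunks covering roughly `fraction` of a read."""
    if signal_len < chunk_len:
        raise ValueError("read is shorter than one chunk")
    rng = random.Random(seed)
    n_chunks = max(1, round(fraction * signal_len / chunk_len))
    starts = [rng.randrange(0, signal_len - chunk_len + 1) for _ in range(n_chunks)]
    return [(s, s + chunk_len) for s in starts]


def coverage(chunks, signal_len):
    """Fraction of signal samples that fall inside at least one chunk."""
    covered = bytearray(signal_len)
    for start, end in chunks:
        covered[start:end] = b"\x01" * (end - start)
    return sum(covered) / signal_len
```

For a 1,000,000-sample read and 3,000-sample chunks, a 3% fraction yields 10 chunks, and any overlap between them pushes the real coverage below 3%; the other ~97% of the read never reaches the model.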

The training data provided is a mix of Human, Yeast and E. coli.

HTH.

Chris,

kishwarshafin commented 4 years ago

Great, please share notes as you go. A validation set from another species would also help to benchmark the models we train. The consensus accuracy metric you have implemented is good, but I think the calls need to go through a polishing pipeline such as MP-HELEN or racon-medaka to see whether the errors are actually improving.

kishwarshafin commented 4 years ago

@iiSeymour,

I was finally able to validate this basecaller. It looks like you were right about the sample-to-read ratio. The model I trained performed very poorly compared with the model you uploaded here. Either the training got stuck in a local minimum, for which I have no evidence (I checked the gradient distribution at different timesteps of training), or, as you suggest, the 3% default is too low.

Details about the validation set:
Chemistry: R941
Device: PromethION
Species: Staphylococcus aureus
Total reads: 401,624

Details about the analysis: basecalled with Bonito -> assembled the genome with Shasta in modal consensus mode -> compared the assembly against the truth using Pomoxis.

I did the basecalling on a single GPU. Here are the results: [image not included]

You can see that the model you trained is much closer to Guppy 305 than the one I trained; a difference this large can't be due to hyperparameters alone. Do you have any plans for how to tackle this? Also, there are significantly larger sets of haploid genomes available for training, but they need careful data curation. Let me know how you are planning to move forward; I will be happy to provide any help.
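Assembly-versus-truth comparisons like this are often summarized as a Phred-scaled quality, Q = -10 * log10(error rate). A small helper, assuming you already have the error and aligned-base counts from the comparison (the function name is hypothetical; this is not Pomoxis's own code or output format):

```python
import math


def phred_q(errors, aligned_bases):
    """Phred-scaled accuracy: Q = -10 * log10(errors / aligned_bases)."""
    if aligned_bases <= 0:
        raise ValueError("aligned_bases must be positive")
    if errors == 0:
        return math.inf  # no observed errors: Q is unbounded
    return -10 * math.log10(errors / aligned_bases)
```

For example, 1 error per 1,000 aligned bases is Q30, and 1 per 10,000 is Q40, so small differences in error rate between models show up as readable gaps in Q.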

iiSeymour commented 4 years ago

Random subsampling really was a bad idea :man_facepalming: The idea was that more chunks would help; clearly not!

I've moved to a variable sequence and signal approach with full (single) coverage, overlap, and longer sequence lengths (see b7be126cced99eccf199bed1a8a133856d8b1817). This is looking a lot better: after 24 hours on a DGX-2, I have a model that beats the released model.

I haven't pushed my multi-GPU support yet, as I have seen hanging on some machines (https://github.com/pytorch/pytorch/issues/1637#issuecomment-338268158) that I want to investigate further.
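To make "full (single) coverage with overlap" concrete, here is a minimal sketch that tiles a read's signal with overlapping windows so every sample is covered. It illustrates only the coverage-and-overlap idea; it does not model the variable sequence lengths, and the function name and parameters are hypothetical rather than taken from the commit.

```python
def overlapping_chunks(signal_len, chunk_len, overlap):
    """Hypothetical: tile a read with windows of `chunk_len` samples that overlap by
    `overlap` samples. The last window is aligned to the end of the read, so every
    sample is covered at least once."""
    if signal_len <= chunk_len:
        return [(0, signal_len)]
    stride = chunk_len - overlap
    if stride <= 0:
        raise ValueError("overlap must be smaller than chunk_len")
    starts = list(range(0, signal_len - chunk_len + 1, stride))
    # Add a final end-aligned window if the regular stride stopped short of the end.
    if starts[-1] + chunk_len < signal_len:
        starts.append(signal_len - chunk_len)
    return [(s, s + chunk_len) for s in starts]
```

Compared with random subsampling, every region of the read is seen once, and the samples near each window boundary also appear in the neighbouring window.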

Thanks for the analysis, it's really useful. I can open some issues if you want to contribute.

kishwarshafin commented 4 years ago

@iiSeymour, I wonder why the previous chunking strategy failed! I'll pull in the new changes you've made and train models over the break. I am personally interested in this basecaller because it gives us a cleaner segue to integrate the basecaller observations into a pipeline like Shasta-marginPolish-HELEN for genome assembly and polishing. This would probably mean that we need substantial support for outputting intermediate results (I am happy to take on that responsibility). I will talk to others during our weekly meeting and see what the consensus is.

Regardless of that, I'd be happy to contribute, as I think this is a much cleaner solution to a problem that we often over-complicate. I understand your concerns regarding multi-GPU support; I'll wait for a satisfactory solution. Once you are convinced that you have a good training scheme, let me know and I'll do my best to help with training and evaluation. I am also planning to incorporate Hyperband at some point for automated hyperparameter tuning; I have had tremendous results using Hyperband in HELEN.
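Hyperband is built on successive halving: evaluate many configurations on a small budget, keep the best fraction, and give the survivors a larger budget. The sketch below shows only that successive-halving subroutine, not full Hyperband (which also loops over several brackets trading off the number of configurations against the starting budget). The `evaluate(config, budget)` callback is a hypothetical interface that would, for example, train for `budget` epochs and return a validation loss.

```python
def successive_halving(configs, evaluate, min_budget=1, eta=3):
    """Successive halving, the inner loop of Hyperband.

    `evaluate(config, budget)` returns a loss (lower is better). Each round keeps
    the best 1/eta of the surviving configs and multiplies the budget by eta."""
    survivors = list(configs)
    if not survivors:
        raise ValueError("need at least one config")
    budget = min_budget
    while len(survivors) > 1:
        ranked = sorted(survivors, key=lambda c: evaluate(c, budget))
        keep = max(1, len(survivors) // eta)
        survivors = ranked[:keep]
        budget *= eta
    return survivors[0]
```

With 9 configurations and eta=3, the first round trains 9 configs at budget 1, the second trains 3 at budget 3, and the winner is returned, so most of the compute goes to the promising candidates.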

kishwarshafin commented 4 years ago

@iiSeymour,

Was the default training set provided here created using basecalls from Guppy? That is, are the per-read references used for training Guppy-basecalled reads?

iiSeymour commented 4 years ago

@kishwarshafin yes, this is the set taiyaki uses to produce the production models in Guppy.

kishwarshafin commented 4 years ago

OK, thanks. I think I see a bias toward the Guppy calls when testing the model on a human genome. I was very surprised, but it makes sense that it comes from the training set.

kishwarshafin commented 4 years ago

Hi @iiSeymour,

So, I was debugging the training data generation pipeline, and it seems that the model used to generate the training data is crucial to the training of the basecaller model. For example, I found a region in chr20 of HG002 where both Guppy and Bonito make a wrong call of a homopolymer run in the majority of the reads: [image not included] In this case, the true sequence is 8 Ts, but most of the reads are called as 5 Ts.

I wanted to see whether the model would fit this observation if I went back and trained on this case, so I generated chr20-specific training data following these instructions: taiyaki-walkthrough. But I suspect the model that defines the signal boundaries is not sensitive enough. This is a read after labeling: [image not included] Although it's very hard to debug signal-level sequences, it's fairly clear that the boundaries defined for the 8 Ts are not correct. One reason could be that the data is from PromethION and the model was trained on MinION.

However, do you know of any other way to generate the alignment, for example for a new pore, where you wouldn't need a model for the alignment? Maybe an earlier HMM-based basecaller that has a better sense of signal boundaries?
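Spotting a systematic homopolymer error like the 8 Ts versus 5 Ts case comes down to comparing run lengths between the called sequence and the truth. A minimal stdlib sketch; the helper names are hypothetical, and it works on plain sequence strings, not on read alignments or signal-level labels.

```python
from itertools import groupby


def homopolymer_runs(seq, min_len=1):
    """Return (base, start, length) for each run of identical bases of at least min_len."""
    runs = []
    pos = 0
    for base, group in groupby(seq):
        length = sum(1 for _ in group)
        if length >= min_len:
            runs.append((base, pos, length))
        pos += length
    return runs


def longest_run(seq, base):
    """Length of the longest run of `base` in `seq` (0 if absent)."""
    return max((length for b, _, length in homopolymer_runs(seq) if b == base), default=0)
```

Applied to a truth segment `"ACG" + "T" * 8 + "GCA"` and a call `"ACG" + "T" * 5 + "GCA"`, the T runs come out as 8 and 5, a 3-base deletion that a per-read tally over the region would expose.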

iiSeymour commented 4 years ago

Hey @kishwarshafin

it seems that the model used for generating the training data is crucial to the training of the basecaller model

Right, I agree this is the route to producing better basecalling models, and it will be my primary focus starting in January.