bsmith89 / StrainFacts

Factorize metagenotypes to infer strains and their abundances
MIT License

Subsampling sites and other considerations for faster fitting #14

Closed: elsherbini closed this issue 6 months ago

elsherbini commented 7 months ago

In the README it says that you used --num-positions 5000 to fit in the paper, but I don't see that argument as an option in sfacts fit.

Two questions:

  1. Did you take anything into account when choosing how to subsample the sites, or did you just pick them randomly (e.g., did you weight by depth at all)?

  2. I made a simulated dataset with 3 strains and ~120 samples with different proportions of the strains. There were 26,000 sites used for fitting. StrainFacts fit in about 2 hours when I told it a priori that there were 3 strains. When I went up to ~50 strains, I think it was estimating ~400 hours to fit. Are there ways you'd recommend to make it run faster besides downsampling sites or choosing a smaller number of strains?

sfacts fit \
    --verbose \
    --strains-per-sample 1 --strain-sample-exponent 0.75 \
    --random-seed 0 \
    merged_synthetic.mgen.nc merged_synthetic_exp075.world.nc
bsmith89 commented 7 months ago

Hey @elsherbini , thanks for your questions!

> In the README it says that you used --num-positions 5000 to fit in the paper, but I don't see that argument as an option in sfacts fit.

That's right. It's instead an argument to sfacts sample_mgen. See https://byronjsmith.com/StrainFacts/filter_data.ipynb.html for how I'm using it.

> Did you take anything into account when choosing how to subsample the sites, or did you just pick them randomly (e.g., did you weight by depth at all)?

Sites are sampled with uniform probability, but only after filtering. In the examples, I filter for sufficient polymorphism, i.e. by passing --min-minor-allele-freq 0.05 to sample_mgen. In the StrainFacts paper I didn't filter sites by depth, because the GT-Pro metagenotyper only returns core-genome positions. If you're using a different metagenotyper (MIDAS, StrainPhlAn, inStrain, etc.), however, you will probably want to filter by depth to remove auxiliary-genome positions, duplicates, contamination from other species, etc.
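
A minimal sketch of the "filter, then sample uniformly" step, in plain Python. It assumes each site's minor allele frequency is computed from read counts pooled across all samples; the function name, its arguments, and the toy counts are hypothetical and this is not the sample_mgen implementation. The defaults simply mirror --min-minor-allele-freq 0.05 and --num-positions 5000.

```python
import random


def filter_and_subsample(site_counts, min_minor_allele_freq=0.05,
                         num_positions=5000, seed=0):
    """Keep sufficiently polymorphic sites, then sample a uniform random subset.

    site_counts maps a site ID to (alt_reads, total_reads) pooled over all
    samples; pooling like this is an assumption of this sketch.
    """
    polymorphic = []
    for site, (alt, total) in site_counts.items():
        if total == 0:
            continue  # no coverage, so no frequency estimate
        alt_freq = alt / total
        minor_freq = min(alt_freq, 1 - alt_freq)
        if minor_freq >= min_minor_allele_freq:
            polymorphic.append(site)
    # Uniform sampling without replacement: every surviving site is equally
    # likely to be chosen, with no weighting by depth.
    rng = random.Random(seed)
    k = min(num_positions, len(polymorphic))
    return sorted(rng.sample(polymorphic, k))


# Toy counts (hypothetical): s2 and s5 are nearly monomorphic, s4 has no reads.
counts = {"s1": (50, 100), "s2": (1, 100), "s3": (10, 100),
          "s4": (0, 0), "s5": (97, 100)}
print(filter_and_subsample(counts, num_positions=2))  # ['s1', 's3']
```

Depth weighting, if you wanted it, would replace the uniform rng.sample call; the filter step is where depth-based exclusion for other metagenotypers would naturally go.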

> I made a simulated dataset with 3 strains and ~120 samples with different proportions of the strains. There were 26,000 sites used for fitting.

Did you use the simulator built into StrainFacts or your own setup? If the former, I expect most positions are polymorphic. If not, you might want to filter for polymorphic sites first (as described above), since dropping non-polymorphic sites should not affect your strain tracking.

> StrainFacts fit in about 2 hours when I told it a priori that there were 3 strains. When I went up to ~50 strains, I think it was estimating ~400 hours to fit.

Ah, this is something I should have documented better. Thanks for raising it. The time estimate shown on the command line when running sfacts fit assumes that fitting runs for all MAX_ITER iterations (default: 1,000,000). When fitting is working well, however, it should never actually run that long.
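
To see why the projected runtime can look alarming, here is the back-of-envelope arithmetic. The per-iteration cost is backed out from the ~400 hour estimate mentioned in the question, and the 30,000-iteration convergence point is a hypothetical number chosen for illustration, not a measured result.

```python
MAX_ITER = 1_000_000  # default maximum number of iterations


def projected_hours(seconds_per_iteration, num_iterations):
    """Wall-clock hours for a given number of iterations at a fixed per-step cost."""
    return seconds_per_iteration * num_iterations / 3600


# Back out the per-iteration cost implied by a ~400 hour estimate at MAX_ITER.
seconds_per_iteration = 400 * 3600 / MAX_ITER  # about 1.44 s per iteration

# Hypothetical: the fit converges after 30,000 iterations instead.
hours = projected_hours(seconds_per_iteration, 30_000)
print(f"{hours:.1f} hours")  # 12.0 hours
```

The per-step cost stays the same; what changes is how many steps actually run before convergence, which is what the hyperparameters below control.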

How long it actually takes depends on how fast it converges, and will be affected by a number of model fitting hyperparameters:

  --lag1 LAG1           Setting for `pyro.optim.ReduceLROnPlateau` 'cooldown' argument. (default: 50)
  --lag2 LAG2           Setting for `pyro.optim.ReduceLROnPlateau` 'patience' argument. (default: 100)
  --no-jit              Don't use the PyTorch JIT; much slower steps, but no warm-up; may be useful for debugging. (default: True)
  --optimizer {Adam,Adamax,Adadelta,Adagrad,AdamW,RMSprop,SGD}
                        Which Pyro optimizer to use. (default: Adamax)
  --optimizer-learning-rate OPTIMIZER_LEARNING_RATE
                        Set the optimizer learning rate; otherwise use the default set in `sfacts.estimation.OPTIMIZERS`. (default: None)
  --min-optimizer-learning-rate MIN_OPTIMIZER_LEARNING_RATE
                        Learning rate threshold in reduction 'schedule' to terminate optimization. (default: 1e-06)
  --optimizer-clip-norm OPTIMIZER_CLIP_NORM
                        Set the clip_norm for Pyro optimizer; otherwise default is None (default: None)
  --anneal-wait ANNEAL_WAIT
                        Number of steps before annealed hyperparameters start stepping. (default: 0)
  --anneal-steps ANNEAL_STEPS
                        Number of steps before annealed hyperparameters are at their final values; includes `--anneal-wait` steps. (default: 0)

The default hyperparameters should generally perform okay. However, for my own work I now use the following for a first pass:

sfacts fit \
        --model-structure model4 \
        --strain-sample-exponent 0.85 \
        --hyperparameters gamma_hyper=1e-10 pi_hyper=1e-3 pi_hyper2=1e-3 rho_hyper=1.0 rho_hyper2=1.0 \
        --anneal-hyperparameters gamma_hyper=0.1 rho_hyper=10.0 rho_hyper2=10.0 \
        --anneal-steps 20_000 \
        --random-seed 0 \
        --optimizer-learning-rate 0.05 \
        --min-optimizer-learning-rate 1e-2

This seems to both fit quickly and perform quite well. When that doesn't produce reasonable-seeming results, tuning these parameters is definitely still a challenge. Please reach out if you're having trouble with this; I'll be happy to try to help.

> Are there ways you'd recommend to make it run faster besides downsampling sites or choosing a smaller number of strains?

Besides the model hyperparameters above (which also affect convergence), the key fitting hyperparameters from that command that you should try adjusting are these:

        --anneal-steps 20_000 \
        --optimizer-learning-rate 0.05 \
        --min-optimizer-learning-rate 1e-2

If you're using hyperparameter annealing, --anneal-steps sets how many iterations are run before the final values are fixed. It therefore determines the minimum number of iterations that the model has to fit for. I do sometimes find that much larger values (e.g. 200,000 annealing iterations) can yield better results, but that comes at a runtime cost.
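
A sketch of how one annealed hyperparameter could move over the course of fitting, following the --anneal-wait and --anneal-steps semantics from the help text (hold during the wait, reach the final value at --anneal-steps, which includes the wait). Two things here are assumptions, not confirmed StrainFacts behavior: the log-linear (geometric) interpolation, and that annealing starts at the --anneal-hyperparameters value and ends at the --hyperparameters value. The function name is hypothetical.

```python
import math


def annealed_value(step, start, end, anneal_steps, anneal_wait=0):
    """Value of an annealed hyperparameter at a given step.

    Holds `start` through `anneal_wait`, reaches `end` at `anneal_steps`
    (which includes the wait), and in between interpolates log-linearly.
    The log-linear shape is an assumption of this sketch.
    """
    if step >= anneal_steps:
        return end
    if step <= anneal_wait:
        return start
    frac = (step - anneal_wait) / (anneal_steps - anneal_wait)
    return math.exp(math.log(start) + frac * (math.log(end) - math.log(start)))


# gamma_hyper values from the first-pass command. ASSUMPTION: annealing starts
# at the --anneal-hyperparameters value (0.1) and ends at the --hyperparameters
# value (1e-10).
for step in (0, 5_000, 10_000, 20_000, 25_000):
    print(step, annealed_value(step, start=0.1, end=1e-10, anneal_steps=20_000))
```

Because the final value is only reached at step anneal_steps, the objective keeps changing until then, which is why --anneal-steps acts as a floor on the number of iterations.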

--optimizer-learning-rate determines how big a step is taken at each iteration at the start of fitting. Larger values will converge faster, but if the value is too large, fitting will overshoot and can end up in a bad local minimum. Very large values often cause numerical issues (errors that mention ELBO == NaN). The default value for the default optimizer (Adamax) is also 0.05.
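
The step-size trade-off is easy to see on a toy problem. This sketch uses plain gradient descent (not Adamax, and not the StrainFacts model) on f(x) = x**2, where each update multiplies x by (1 - 2*lr): a small rate shrinks x slowly, a moderate rate converges in a few steps, and a rate above 1 overshoots the minimum by more than it corrects, so x oscillates and grows. The function name is hypothetical.

```python
import math


def steps_to_converge(lr, x0=1.0, tol=1e-3, max_steps=1_000):
    """Plain gradient descent on f(x) = x**2, whose gradient is 2*x.

    Returns the number of steps until |x| < tol, or None if x never gets
    there (oscillation/divergence) or becomes non-finite.
    """
    x = x0
    for step in range(1, max_steps + 1):
        x = x - lr * 2 * x  # equivalent to x *= (1 - 2*lr)
        if not math.isfinite(x):
            return None  # numerical blow-up, loosely analogous to a NaN ELBO
        if abs(x) < tol:
            return step
    return None


for lr in (0.05, 0.4, 1.1):
    print(lr, steps_to_converge(lr))
# 0.05 66
# 0.4 5
# 1.1 None
```

A real ELBO surface has many local minima rather than one bowl, so overshooting there tends to land the fit somewhere worse rather than simply diverging, but the speed-versus-stability tension is the same.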

--min-optimizer-learning-rate sets the early stopping / convergence criterion. When the parameters have stopped changing for some number of steps (determined by --lag1 and --lag2), the learning rate is halved and fitting continues. This flag sets the threshold learning rate where the model is considered to have converged. Very small values (it defaults to 1e-6) can sometimes mean the fitting procedure creeps along without making much progress. However, in practice, convergence is usually pretty fast after the first halving.
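
A simplified sketch of this plateau-based stopping rule. It mimics the idea behind PyTorch's ReduceLROnPlateau with a halving factor, treating --lag2 as patience and --lag1 as cooldown as the help text describes, but it compares raw loss values without PyTorch's relative improvement threshold, and it assumes fitting stops as soon as the learning rate falls strictly below the minimum. The class and helper names are hypothetical, not StrainFacts code. The helper at the end counts halvings: starting from 0.05, a 1e-2 threshold is crossed after 3 halvings, versus 16 for the 1e-6 default.

```python
class PlateauHalving:
    """Simplified plateau-based learning-rate schedule (hypothetical, not sfacts code).

    Mirrors the idea of PyTorch's ReduceLROnPlateau with factor=0.5, using
    patience (--lag2) and cooldown (--lag1), but compares raw loss values
    instead of using PyTorch's relative improvement threshold.
    """

    def __init__(self, lr, min_lr, patience=100, cooldown=50):
        self.lr = lr
        self.min_lr = min_lr
        self.patience = patience
        self.cooldown = cooldown
        self.best = float("inf")
        self.num_bad = 0
        self.cooldown_left = 0

    def step(self, loss):
        """Record one loss value; return True once lr has dropped below min_lr."""
        if loss < self.best:
            self.best = loss
            self.num_bad = 0
        else:
            self.num_bad += 1
        if self.cooldown_left > 0:
            # Right after a halving, non-improving steps are not counted.
            self.cooldown_left -= 1
            self.num_bad = 0
        if self.num_bad > self.patience:
            self.lr *= 0.5
            self.cooldown_left = self.cooldown
            self.num_bad = 0
        return self.lr < self.min_lr


def halvings_needed(initial_lr, min_lr):
    """How many halvings take initial_lr below min_lr."""
    lr, n = initial_lr, 0
    while lr >= min_lr:
        lr *= 0.5
        n += 1
    return n


print(halvings_needed(0.05, 1e-2))  # 3
print(halvings_needed(0.05, 1e-6))  # 16
```

Each halving costs at least patience plus cooldown iterations of non-improvement, so cutting the number of halvings from 16 to 3 removes most of the slow tail of the fit.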

I hope this helps. Please let me know if you're still having trouble!

bsmith89 commented 7 months ago

> Are there ways you'd recommend to make it run faster besides downsampling sites or choosing a smaller number of strains?

Oh, and I should also mention that running it on a GPU is usually a big speed up. See the --device cuda option.
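
If you launch fits from a Python script, one option is to add --device cuda only when PyTorch can see a GPU. This is a sketch: cuda_available and build_fit_command are hypothetical helpers, and the output file name is made up, while the flags are the first-pass settings and --device cuda from this thread, and torch.cuda.is_available() is the standard PyTorch check. The positional order (input .mgen.nc, then output .world.nc) follows the command earlier in the thread. The resulting list can be passed to subprocess.run(cmd, check=True).

```python
# First-pass settings from this thread, split into argv tokens.
FIRST_PASS_ARGS = [
    "--model-structure", "model4",
    "--strain-sample-exponent", "0.85",
    "--hyperparameters", "gamma_hyper=1e-10", "pi_hyper=1e-3", "pi_hyper2=1e-3",
    "rho_hyper=1.0", "rho_hyper2=1.0",
    "--anneal-hyperparameters", "gamma_hyper=0.1", "rho_hyper=10.0", "rho_hyper2=10.0",
    "--anneal-steps", "20_000",
    "--random-seed", "0",
    "--optimizer-learning-rate", "0.05",
    "--min-optimizer-learning-rate", "1e-2",
]


def cuda_available() -> bool:
    """True if PyTorch is installed and reports a usable CUDA device."""
    try:
        import torch
    except ImportError:
        return False
    return torch.cuda.is_available()


def build_fit_command(mgen_path, world_path, use_gpu):
    """Assemble an `sfacts fit` argv list (hypothetical helper)."""
    cmd = ["sfacts", "fit", *FIRST_PASS_ARGS]
    if use_gpu:
        cmd += ["--device", "cuda"]
    # Positional arguments last: input metagenotype, then output world file.
    cmd += [mgen_path, world_path]
    return cmd


cmd = build_fit_command("merged_synthetic.mgen.nc",
                        "merged_synthetic_firstpass.world.nc",  # hypothetical name
                        use_gpu=cuda_available())
print(" ".join(cmd))
```

Keeping the positional paths after single-valued options avoids them being swallowed by a multi-value option such as --hyperparameters.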

elsherbini commented 7 months ago

Thank you so much for your thorough and thoughtful answer. Great to hear that the estimate is for the maximum number of iterations and that fitting will almost never reach it; that makes total sense in retrospect.

I'll continue to work on this and report back here. I'll give it a go with your new first pass hyperparameters and also look into running it on a GPU. This was super helpful, thanks!

elsherbini commented 6 months ago

Just coming back to say that StrainFacts has run FAST on our cluster's GPUs. I'm going to open another issue to ask about hyperparameters.