jmschrei / bpnet-lite

This repository hosts a minimal version of a Python API for BPNet.
MIT License

Practical limit on # input peaks #7

Closed gregorydonahue closed 5 months ago

gregorydonahue commented 5 months ago

Hi,

I'm running chrombpnet on some ATAC-seq data, and I was able to train the model without issue (chrombpnet fit) and also run predictions (chrombpnet predict). However, the following step (chrombpnet attribute) - which I gather calculates the SHAP scores and maybe runs TF-MoDISco - fails with the following:

$ chrombpnet attribute -p test.pipeline.json
/home/gdonahue/miniconda3/envs/bpnetlite/lib/python3.8/site-packages/tangermeme/ersatz.py:448: NumbaDeprecationWarning: The keyword argument 'nopython=False' was supplied. From Numba 0.59.0 the default is being changed to True and use of 'nopython=False' will raise a warning as the argument will have no effect. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
@numba.jit(params, nopython=False)
/home/gdonahue/miniconda3/envs/bpnetlite/lib/python3.8/site-packages/tangermeme/ersatz.py:448: NumbaDeprecationWarning: The keyword argument 'nopython=False' was supplied. From Numba 0.59.0 the default is being changed to True and use of 'nopython=False' will raise a warning as the argument will have no effect. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.
@numba.jit(params, nopython=False)
Loading Loci: 100%|███████████████████| 100933/100933 [01:08<00:00, 1477.14it/s]
0%|                                 | 255/2017880 [00:48<107:02:15,  5.24it/s]
Traceback (most recent call last):
  File "/home/gdonahue/miniconda3/envs/bpnetlite/bin/bpnet", line 512, in <module>
    X_attr = deep_lift_shap(wrapper.type(dtype), X.type(dtype),
  File "/home/gdonahue/miniconda3/envs/bpnetlite/lib/python3.8/site-packages/tangermeme/deep_lift_shap.py", line 444, in deep_lift_shap
    raise(e)
  File "/home/gdonahue/miniconda3/envs/bpnetlite/lib/python3.8/site-packages/tangermeme/deep_lift_shap.py", line 424, in deep_lift_shap
    y = model(X_)[:, target]
  File "/home/gdonahue/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/gdonahue/miniconda3/envs/bpnetlite/lib/python3.8/site-packages/bpnetlite/bpnet.py", line 128, in forward
    return self.model(X, X_ctl, **kwargs)[1]
  File "/home/gdonahue/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/gdonahue/miniconda3/envs/bpnetlite/lib/python3.8/site-packages/bpnetlite/bpnet.py", line 46, in forward
    return self.model(X)
  File "/home/gdonahue/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/gdonahue/miniconda3/envs/bpnetlite/lib/python3.8/site-packages/bpnetlite/chrombpnet.py", line 111, in forward
    acc_profile, acc_counts = self.accessibility(X)
  File "/home/gdonahue/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/gdonahue/miniconda3/envs/bpnetlite/lib/python3.8/site-packages/bpnetlite/bpnet.py", line 293, in forward
    X_conv = self.rrelus[i](self.rconvs[i](X))
  File "/home/gdonahue/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1547, in _call_impl
    hook_result = hook(self, args, result)
  File "/home/gdonahue/miniconda3/envs/bpnetlite/lib/python3.8/site-packages/tangermeme/deep_lift_shap.py", line 107, in _f_hook
    module.output = outputs.clone().detach()
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 530.00 MiB (GPU 0; 14.62 GiB total capacity; 13.49 GiB already allocated; 194.94 MiB free; 13.53 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Traceback (most recent call last):
  File "/home/gdonahue/miniconda3/envs/bpnetlite/bin/chrombpnet", line 435, in <module>
    subprocess.run(["bpnet", "attribute", "-p", args.parameters], check=True)
  File "/home/gdonahue/miniconda3/envs/bpnetlite/lib/python3.8/subprocess.py", line 516, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['bpnet', 'attribute', '-p', 'test.pipeline.json']' returned non-zero exit status 1.

...the relevant bit being the torch.cuda.OutOfMemoryError. As you can see, I'm loading ~100k ATAC-seq OCRs and then chrombpnet tries to process > 2 million something-or-others. This may be too much. I have tried setting the PYTORCH_CUDA_ALLOC_CONF environment variable to values higher or lower than the requested 530 MiB, but nothing works. Am I just out of luck here? I also tried editing the JSON to restrict the interpret 'chroms' parameter to just chr10, thinking that this might limit the loaded sequences, but that also failed (same error).

Best, Greg

jmschrei commented 5 months ago

You need to set a smaller batch size. You can set the batch size separately for each step in the JSONs. Internally, all sequences initially live on the CPU, are moved to the GPU in batches, and the results are moved back to the CPU after the calculations are done. The total number of sequences doesn't matter -- only the number moved over to the GPU at a time, so GPU memory use depends on the batch size.
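For anyone landing here later, this is a minimal sketch of lowering the batch size for a single step by editing the pipeline JSON programmatically. It assumes the setting is stored under a key named `batch_size` inside a per-step section. The section names used below (`fit_parameters`, `attribute_parameters`) and the helper `set_batch_size` are hypothetical and only for illustration, so check them against your own JSON before relying on this.

```python
import copy
import json


def set_batch_size(config, batch_size, section=None, key="batch_size"):
    """Return (new_config, n_changed) with matching `key` entries set to batch_size.

    If `section` is given, only dicts stored under that name are changed.
    The key name "batch_size" is an assumption about the JSON layout.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    new_config = copy.deepcopy(config)  # leave the caller's dict untouched
    n_changed = 0

    def walk(node, name):
        nonlocal n_changed
        if isinstance(node, dict):
            if key in node and (section is None or name == section):
                node[key] = batch_size
                n_changed += 1
            for child_name, child in node.items():
                walk(child, child_name)
        elif isinstance(node, list):
            for child in node:
                walk(child, name)

    walk(new_config, None)
    return new_config, n_changed


# Hypothetical pipeline layout, for illustration only.
example = {
    "batch_size": 64,
    "fit_parameters": {"batch_size": 64},
    "attribute_parameters": {"batch_size": 64},
}

# Shrink the batch size only for the (assumed) attribution step.
updated, n = set_batch_size(example, 4, section="attribute_parameters")
print(json.dumps(updated, indent=2))
print("entries changed:", n)
```

To apply this to a real file, load it with `json.load`, pass the resulting dict through the helper, and write it back with `json.dump`. If `n_changed` comes back as 0, the key or section name assumed here does not match your file.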

jmschrei commented 5 months ago

Please re-open if issues persist.