RasmussenLab / vamb

Variational autoencoder for metagenomic binning
MIT License

Unable to run taxometer using `--cuda` #360

Open Prunoideae opened 1 week ago

Prunoideae commented 1 week ago

Python: 3.10.14; Vamb: commit 9810ef047ef41f986b7ddcfdd3dc06947ee0ab6c on GitHub
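For a GPU problem like this one, it can help to report the PyTorch build and GPU model alongside the Python and Vamb versions. The following is a minimal sketch that assumes Vamb is installed as a distribution named `vamb` and falls back to a placeholder string if it is not. `environment_report` is a hypothetical helper and is not part of Vamb:

```python
import platform
from importlib.metadata import PackageNotFoundError, version

import torch


def environment_report() -> dict:
    """Collect version and GPU details that are useful in a bug report."""
    try:
        vamb_version = version("vamb")  # read from installed package metadata
    except PackageNotFoundError:
        vamb_version = "not installed"
    report = {
        "python": platform.python_version(),
        "torch": torch.__version__,
        "vamb": vamb_version,
        "cuda_available": torch.cuda.is_available(),
        # torch.version.cuda is None for CPU-only builds of PyTorch
        "torch_cuda_build": torch.version.cuda,
    }
    if torch.cuda.is_available():
        report["gpu"] = torch.cuda.get_device_name(0)
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

For an install from a Git checkout, the reported version may include the commit, similar to the `4.1.4.dev134+g9810ef0` string in the log below.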

It looks like the following line needs to be on CUDA:

https://github.com/RasmussenLab/vamb/blob/bdd14d12855081dbe0ab0c42c3cd7d948f997943/vamb/taxvamb_encode.py#L853
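The content of the linked line is not reproduced here, so the exact change is not shown. As a general sketch of the pattern: calling `.to(device)` or `.cuda()` on a PyTorch module moves only its parameters and registered buffers. A tensor stored as a plain attribute, or created inside `forward` without a `device=` argument, stays on the default device (normally the CPU). `HypotheticalPredictor` below is an illustrative stand-in and is not Vamb's `VAMB2Label`:

```python
import torch
from torch import nn


class HypotheticalPredictor(nn.Module):
    """Toy module illustrating device-safe tensor handling (not Vamb code)."""

    def __init__(self, n_in: int, n_out: int):
        super().__init__()
        self.outputlayer = nn.Linear(n_in, n_out)
        # A constant registered as a buffer follows the module on .to()/.cuda();
        # a plain attribute (self.scale = torch.ones(...)) would stay on the CPU.
        self.register_buffer("scale", torch.ones(n_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Tensors created per call should take the input's device and dtype.
        offset = torch.zeros(x.shape[0], 1, device=x.device, dtype=x.dtype)
        return self.outputlayer(x) * self.scale + offset


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = HypotheticalPredictor(512, 10).to(device)
batch = torch.randn(4, 512).to(device)  # inputs must still be moved explicitly
out = model(batch)
print(out.shape, out.device)
```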

Fixing it resolves the error, but CUDA gives only a modest training speedup on an RTX 3090 (from 9 min/epoch to 7 min/epoch). NVTop showed quite low GPU utilization (the screenshot shows GPU utilization, GPU memory, GPU memory %, CPU utilization and host memory):

[NVTop screenshot]
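A modest speedup combined with low GPU utilization often means that the training loop spends its time waiting on the CPU (assembling batches or copying them to the device) rather than on GPU compute. A rough way to check this in a standard PyTorch `DataLoader` loop is to time each iteration's wait for a batch separately from the training step itself. The model, data and `profile_epoch` helper below are hypothetical placeholders, not Vamb's taxometer code. Only the input width (111), batch size (1024) and learning rate (0.001) are borrowed from the log and traceback:

```python
import time

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def profile_epoch(model, loader, optimizer, loss_fn, device):
    """Return seconds spent waiting for batches vs. running training steps."""
    wait, compute = 0.0, 0.0
    model.train()
    t0 = time.perf_counter()
    for x, y in loader:
        t1 = time.perf_counter()
        wait += t1 - t0  # time the loader took to hand over this batch
        # non_blocking only overlaps the copy when the source is pinned memory
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        if device.type == "cuda":
            torch.cuda.synchronize()  # CUDA kernels run asynchronously
        t0 = time.perf_counter()
        compute += t0 - t1
    return {"wait_s": wait, "compute_s": compute}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = TensorDataset(torch.randn(4096, 111), torch.randint(0, 7, (4096,)))
loader = DataLoader(data, batch_size=1024, shuffle=True,
                    pin_memory=device.type == "cuda")
model = nn.Sequential(nn.Linear(111, 512), nn.ReLU(), nn.Linear(512, 7)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
timings = profile_epoch(model, loader, optimizer, nn.CrossEntropyLoss(), device)
print(timings)
```

If `wait_s` dominates, the input pipeline is the likelier bottleneck. Common options are more `DataLoader` workers or keeping the data on the GPU. If `compute_s` dominates, the GPU is busy and the low utilization reading needs another explanation.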

Log:

2024-09-11 09:58:09.508 | INFO    | Starting Vamb version 4.1.4.dev134+g9810ef0
2024-09-11 09:58:09.509 | INFO    | Random seed is 21359552096367181
2024-09-11 09:58:09.509 | INFO    | Invoked with CLI args: 'f/home/-----/miniconda3/envs/meta/bin/vamb bin taxvamb --outdir taxvamb --fasta assembly.filtered.fa --bamdir filtered_bams --taxonomy extracted.taxa.tsv -m 1500 --cuda -p 16'
2024-09-11 09:58:09.509 | INFO    | Loading TNF
2024-09-11 09:58:09.509 | INFO    |     Minimum sequence length: 1500
2024-09-11 09:58:09.509 | INFO    |     Loading data from FASTA file assembly.filtered.fa
2024-09-11 10:03:07.536 | INFO    |     Kept 8111121845 bases in 2916808 sequences
2024-09-11 10:03:07.536 | INFO    |     Processed TNF in 298.03 seconds.

2024-09-11 10:03:07.536 | INFO    | Loading depths
2024-09-11 10:03:07.536 | INFO    |     Reference hash: ab5a09db778cb776e0d97da9ecfdb9ca
2024-09-11 10:03:07.536 | INFO    |     Parsing 7 BAM files with 16 threads
2024-09-11 10:03:07.536 | INFO    |     Min identity: 0.0
2024-09-11 10:30:01.329 | INFO    |     Order of columns is:
2024-09-11 10:30:01.335 | INFO    |          0: filtered_bams/dongjiang_rb.sorted.bam
2024-09-11 10:30:01.335 | INFO    |          1: filtered_bams/xijiang_rb.sorted.bam
2024-09-11 10:30:01.336 | INFO    |          2: filtered_bams/yujiang_rb.sorted.bam
2024-09-11 10:30:01.336 | INFO    |          3: filtered_bams/unclassified.sorted.bam
2024-09-11 10:30:01.336 | INFO    |          4: filtered_bams/nan-bei_pan_rb.sorted.bam
2024-09-11 10:30:01.336 | INFO    |          5: filtered_bams/beijiang_rb.sorted.bam
2024-09-11 10:30:01.336 | INFO    |          6: filtered_bams/hongliu_rb.sorted.bam
2024-09-11 10:30:01.337 | INFO    |     Processed abundance in 1613.8 seconds.

2024-09-11 10:30:01.337 | INFO    | Predicting missing values from taxonomy
2024-09-11 10:30:12.296 | INFO    | 20190 nodes in the graph
2024-09-11 10:30:31.696 | INFO    |     Created dataloader
2024-09-11 10:30:31.697 | INFO    | Starting training the taxonomy predictor
2024-09-11 10:30:31.697 | INFO    | Using threshold 0.5
2024-09-11 10:30:32.216 | INFO    |     Network properties:
2024-09-11 10:30:32.216 | INFO    |     CUDA: True
2024-09-11 10:30:32.216 | INFO    |     Hierarchical loss: flat_softmax
2024-09-11 10:30:32.216 | INFO    |     Alpha: 0.15
2024-09-11 10:30:32.217 | INFO    |     Beta: 200.0
2024-09-11 10:30:32.217 | INFO    |     Dropout: 0.2
2024-09-11 10:30:32.217 | INFO    |     N hidden: 512, 512, 512, 512
2024-09-11 10:30:32.217 | INFO    |     Training properties:
2024-09-11 10:30:32.217 | INFO    |     N epochs: 256
2024-09-11 10:30:32.217 | INFO    |     Starting batch size: 1024
2024-09-11 10:30:32.217 | INFO    |     Batchsteps: 25, 75, 150, 225
2024-09-11 10:30:32.217 | INFO    |     Learning rate: 0.001
2024-09-11 10:30:32.217 | INFO    |     N labels: torch.Size([2916808, 7])
2024-09-11 10:30:33.033 | ERROR   | An error has been caught in function 'main', process 'MainProcess' (3829388), thread 'MainThread' (140155813640000):
Traceback (most recent call last):

  File "/home/-----/miniconda3/envs/meta/bin/vamb", line 8, in <module>
    sys.exit(main())
    │   │    └ <function main at 0x7f773b61e710>
    │   └ <built-in function exit>
    └ <module 'sys' (built-in)>

> File "/mnt/nvme1n1/public/-----/projects/meta/vamb/vamb/__main__.py", line 2200, in main
    run(runner, opt.common.general)
    │   │       │   │      └ <vamb.__main__.GeneralOptions object at 0x7f7797337e20>
    │   │       │   └ <vamb.__main__.BinnerCommonOptions object at 0x7f773b6397e0>
    │   │       └ <vamb.__main__.BinTaxVambOptions object at 0x7f773b639ba0>
    │   └ functools.partial(<function run_vaevae at 0x7f773b61dea0>, <vamb.__main__.BinTaxVambOptions object at 0x7f773b639ba0>)
    └ <function run at 0x7f773b61cb80>

  File "/mnt/nvme1n1/public/-----/projects/meta/vamb/vamb/__main__.py", line 649, in run
    runner()
    └ functools.partial(<function run_vaevae at 0x7f773b61dea0>, <vamb.__main__.BinTaxVambOptions object at 0x7f773b639ba0>)

  File "/mnt/nvme1n1/public/-----/projects/meta/vamb/vamb/__main__.py", line 1420, in run_vaevae
    predict_taxonomy(
    └ <function predict_taxonomy at 0x7f773b61dd80>

  File "/mnt/nvme1n1/public/-----/projects/meta/vamb/vamb/__main__.py", line 1332, in predict_taxonomy
    model.trainmodel(
    │     └ <function VAMB2Label.trainmodel at 0x7f77fe160d30>
    └ VAMB2Label(
        (encoderlayers): ModuleList(
          (0): Linear(in_features=111, out_features=512, bias=True)
          (1-3): 3 x Linea...

  File "/mnt/nvme1n1/public/-----/projects/meta/vamb/vamb/taxvamb_encode.py", line 1047, in trainmodel
    dataloader = self.trainepoch(
                 │    └ <function VAMB2Label.trainepoch at 0x7f77fe160ca0>
                 └ VAMB2Label(
                     (encoderlayers): ModuleList(
                       (0): Linear(in_features=111, out_features=512, bias=True)
                       (1-3): 3 x Linea...

  File "/mnt/nvme1n1/public/-----/projects/meta/vamb/vamb/taxvamb_encode.py", line 968, in trainepoch
    labels_out = self(depths_in, tnf_in, abundances_in, weights)
                 │    │          │       │              └ tensor([[0.8345],
                 │    │          │       │                        [0.8702],
                 │    │          │       │                        [0.8553],
                 │    │          │       │                        ...,
                 │    │          │       │                        [0.8631],
                 │    │          │       │                        [1.8225],
                 │    │          │       │                        [1.1693]], dev...
                 │    │          │       └ tensor([[ 0.3633],
                 │    │          │                 [ 2.2363],
                 │    │          │                 [ 0.6754],
                 │    │          │                 ...,
                 │    │          │                 [-0.1923],
                 │    │          │                 [ 0.0105],
                 │    │          │                 [-0.2597]...
                 │    │          └ tensor([[-1.1825,  0.0716, -0.6795,  ...,  0.1214, -0.6343,  0.1674],
                 │    │                    [-0.9313, -0.4224,  0.4400,  ...,  0.4502, -0.1...
                 │    └ tensor([[0.0000, 0.4290, 0.1922,  ..., 0.0000, 0.0487, 0.2475],
                 │              [0.3522, 0.1935, 0.0540,  ..., 0.1046, 0.0135, 0.0644...
                 └ VAMB2Label(
                     (encoderlayers): ModuleList(
                       (0): Linear(in_features=111, out_features=512, bias=True)
                       (1-3): 3 x Linea...

  File "/home/-----/miniconda3/envs/meta/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           │    │           │       └ {}
           │    │           └ (tensor([[0.0000, 0.4290, 0.1922,  ..., 0.0000, 0.0487, 0.2475],
           │    │                     [0.3522, 0.1935, 0.0540,  ..., 0.1046, 0.0135, 0.064...
           │    └ <function Module._call_impl at 0x7f784593bbe0>
           └ VAMB2Label(
               (encoderlayers): ModuleList(
                 (0): Linear(in_features=111, out_features=512, bias=True)
                 (1-3): 3 x Linea...
  File "/home/-----/miniconda3/envs/meta/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           │             │       └ {}
           │             └ (tensor([[0.0000, 0.4290, 0.1922,  ..., 0.0000, 0.0487, 0.2475],
           │                       [0.3522, 0.1935, 0.0540,  ..., 0.1046, 0.0135, 0.064...
           └ <bound method VAMB2Label.forward of VAMB2Label(
               (encoderlayers): ModuleList(
                 (0): Linear(in_features=111, out_features=...

  File "/mnt/nvme1n1/public/-----/projects/meta/vamb/vamb/taxvamb_encode.py", line 872, in forward
    labels_out = self._predict(tensor)
                 │    │        └ tensor([[ 0.0000,  0.4290,  0.1922,  ..., -0.6343,  0.1674,  0.3633],
                 │    │                  [ 0.3522,  0.1935,  0.0540,  ..., -0.1319, -0.6...
                 │    └ <function VAMB2Label._predict at 0x7f77fe160940>
                 └ VAMB2Label(
                     (encoderlayers): ModuleList(
                       (0): Linear(in_features=111, out_features=512, bias=True)
                       (1-3): 3 x Linea...

  File "/mnt/nvme1n1/public/-----/projects/meta/vamb/vamb/taxvamb_encode.py", line 866, in _predict
    reconstruction = self.outputlayer(tensor)
                     │                └ tensor([[ 0.5725,  0.0408, -0.4522,  ...,  1.0374, -0.5149,  0.1081],
                     │                          [ 0.2728, -0.5565, -0.4375,  ..., -0.6076, -0.5...
                     └ VAMB2Label(
                         (encoderlayers): ModuleList(
                           (0): Linear(in_features=111, out_features=512, bias=True)
                           (1-3): 3 x Linea...

  File "/home/-----/miniconda3/envs/meta/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           │    │           │       └ {}
           │    │           └ (tensor([[ 0.5725,  0.0408, -0.4522,  ...,  1.0374, -0.5149,  0.1081],
           │    │                     [ 0.2728, -0.5565, -0.4375,  ..., -0.6076, -0....
           │    └ <function Module._call_impl at 0x7f784593bbe0>
           └ Linear(in_features=512, out_features=18375, bias=True)
  File "/home/-----/miniconda3/envs/meta/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           │             │       └ {}
           │             └ (tensor([[ 0.5725,  0.0408, -0.4522,  ...,  1.0374, -0.5149,  0.1081],
           │                       [ 0.2728, -0.5565, -0.4375,  ..., -0.6076, -0....
           └ <bound method Linear.forward of Linear(in_features=512, out_features=18375, bias=True)>
  File "/home/-----/miniconda3/envs/meta/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
           │ │      │      │            └ Linear(in_features=512, out_features=18375, bias=True)
           │ │      │      └ Linear(in_features=512, out_features=18375, bias=True)
           │ │      └ tensor([[ 0.5725,  0.0408, -0.4522,  ...,  1.0374, -0.5149,  0.1081],
           │ │                [ 0.2728, -0.5565, -0.4375,  ..., -0.6076, -0.5...
           │ └ <built-in function linear>
           └ <module 'torch.nn.functional' from '/home/-----/miniconda3/envs/meta/lib/python3.10/site-packages/torch/nn/functional.py'>

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)
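The message names the failing argument (`mat1`, which for `F.linear` is the layer input, here in the `outputlayer` call). It does not say where the stray tensor came from. One debugging aid is a forward pre-hook on every layer that owns parameters, which reports the first layer whose input sits on a different device from its weights. This is sketched below as a hypothetical helper and is not part of Vamb or PyTorch:

```python
import torch
from torch import nn


def add_device_checks(model: nn.Module) -> list:
    """Attach forward pre-hooks that report which submodule receives a tensor
    on a different device from its own parameters (debugging aid only)."""
    handles = []
    for name, module in model.named_modules():
        if next(module.parameters(recurse=False), None) is None:
            continue  # skip containers and parameter-free layers such as ReLU

        def check(mod, args, name=name or "<root>"):
            # Look up the device at call time so later .to()/.cuda() calls count.
            expected = next(mod.parameters(recurse=False)).device
            for i, arg in enumerate(args):
                if isinstance(arg, torch.Tensor) and arg.device != expected:
                    raise RuntimeError(
                        f"{name}: input {i} is on {arg.device}, "
                        f"parameters are on {expected}"
                    )

        handles.append(module.register_forward_pre_hook(check))
    return handles


model = nn.Sequential(nn.Linear(111, 512), nn.ReLU(), nn.Linear(512, 7))
handles = add_device_checks(model)
print(model(torch.randn(3, 111)).shape)  # same device: runs normally
for handle in handles:
    handle.remove()  # detach the checks once debugging is done
```

On a real model, the names come from `named_modules()`, so a failure would point at a path such as `outputlayer` or `encoderlayers.0` rather than at `F.linear`.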

jakobnissen commented 1 week ago

@sgalkina can you take a look?