bowang-lab / BIONIC

Biological Network Integration using Convolutions
MIT License

Inference Pass OOM #55

Closed. addiewc closed this issue 1 year ago.

addiewc commented 1 year ago

Although the model trains successfully on a set of 6 networks, it runs out of GPU memory (OOM) during the inference pass, and simplifying the model architecture does not prevent this. Also, the model file is not saved until after the inference pass, so we can't use the `pretrained_model_file` option in the config to separate training and inference.

For reference, the output at the end of model training looks like this:

```
Loaded best model from epoch 2984 with loss 5289.523550

Traceback (most recent call last):
  File "/homes/gws/addiewc/anaconda3/envs/bionic/bin/bionic", line 8, in <module>
    sys.exit(main())
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/bionic/cli.py", line 29, in main
    app()
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/typer/main.py", line 214, in __call__
    return get_command(self)(*args, **kwargs)
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/click/core.py", line 829, in __call__
    return self.main(*args, **kwargs)
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/click/core.py", line 782, in main
    rv = self.invoke(ctx)
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/click/core.py", line 1066, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/click/core.py", line 610, in invoke
    return callback(*args, **kwargs)
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/typer/main.py", line 497, in wrapper
    return callback(**use_params)  # type: ignore
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/bionic/cli.py", line 23, in train
    trainer.forward()
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/bionic/train.py", line 399, in forward
    _, emb, _, learned_scales, label_preds = self.model(data_flows, mask, evaluate=True)
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/bionic/model/model.py", line 205, in forward
    x = self.encodersnet_idx
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/bionic/model/model.py", line 70, in forward
    x = self.gat((x, x[: size[1]]), edge_index, size=size, edge_weights=weights)
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/bionic/model/layers.py", line 94, in forward
    out = self.propagate(edge_index, x=(x_l, x_r), alpha=(alpha_l, alpha_r), size=size)
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 233, in propagate
    coll_dict = self.__collect__(self.__user_args__, edge_index, size,
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 157, in __collect__
    data = self.__lift__(data, edge_index,
  File "/homes/gws/addiewc/anaconda3/envs/bionic/lib/python3.8/site-packages/torch_geometric/nn/conv/message_passing.py", line 127, in __lift__
    return src.index_select(self.node_dim, index)
RuntimeError: CUDA out of memory. Tried to allocate 4.92 GiB (GPU 0; 15.74 GiB total capacity; 8.65 GiB already allocated; 3.21 GiB free; 11.01 GiB reserved in total by PyTorch)
```

duncster94 commented 1 year ago

Hi, thanks for opening an issue.

Your training loss is very large. Could you provide details about the sizes of your 6 networks (number of nodes and number of edges)? It sounds like you might have very dense networks which GNNs generally don't handle well, see here. Also, it would be helpful if you provided the hyperparameters you used for the model and the GPU you are running on.

In BIONIC's forward (inference) pass, the embedding for each node is computed one node at a time using the node's full neighborhood, rather than a sampled neighborhood as during training, so that the embedding contains information from all neighbors. If you have very dense networks and many encoder layers (gat_shapes.n_layers), computing a single node's embedding could require a very large number of nodes. In the extreme case of a complete graph (where every node connects to every other node), all nodes would be required to compute the embedding for a single node. This problem can be rectified by sparsifying your networks to include only the strongest edges.
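To make the receptive-field argument concrete: in standard message passing, computing one node's output after L layers needs every node within L hops of it. The sketch below is not BIONIC code; it uses a plain breadth-first search on two hypothetical toy graphs (`complete_graph`, `path_graph`) to count how many nodes a single node's L-hop neighborhood contains in a maximally dense versus a very sparse graph.

```python
from collections import deque
from typing import Dict, List, Set


def k_hop_neighborhood(adj: Dict[int, List[int]], source: int, k: int) -> Set[int]:
    """Return every node within k hops of `source` (including `source` itself)."""
    seen = {source}
    queue = deque([(source, 0)])
    while queue:
        node, depth = queue.popleft()
        if depth == k:
            continue  # do not expand past k hops
        for nbr in adj[node]:
            if nbr not in seen:
                seen.add(nbr)
                queue.append((nbr, depth + 1))
    return seen


def complete_graph(n: int) -> Dict[int, List[int]]:
    """Hypothetical toy graph: every node is connected to every other node."""
    return {i: [j for j in range(n) if j != i] for i in range(n)}


def path_graph(n: int) -> Dict[int, List[int]]:
    """Hypothetical toy graph: nodes 0..n-1 connected in a line."""
    return {i: [j for j in (i - 1, i + 1) if 0 <= j < n] for i in range(n)}


for layers in (1, 2, 3):
    dense = len(k_hop_neighborhood(complete_graph(200), 0, layers))
    sparse = len(k_hop_neighborhood(path_graph(200), 0, layers))
    print(f"{layers} layer(s): complete graph needs {dense} nodes, path graph needs {sparse}")
```

On the complete graph a single layer already pulls in all 200 nodes, while on the path graph the neighborhood grows by one node per layer. Real networks fall in between, but with high average degree the L-hop neighborhood approaches the whole graph after very few layers.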

addiewc commented 1 year ago

Thanks for the feedback. I've been using all of the default hyperparameters for BIONIC on an A4000 GPU. My 6 networks all have 22.2K nodes, and have 492K, 23K, 76K, 7.79M, 7.78M, and 3.82M edges, respectively.
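A quick back-of-the-envelope check shows why these networks count as dense. The sketch assumes each undirected edge is counted once and uses the rounded figures reported above (22.2K nodes; the exact node count is not given). If an edge list stores both directions of every edge, halve the results.

```python
NUM_NODES = 22_200  # reported as "22.2K"; exact count not given
EDGE_COUNTS = [492_000, 23_000, 76_000, 7_790_000, 7_780_000, 3_820_000]


def average_degree(n_edges: int, n_nodes: int) -> float:
    """Mean number of neighbors per node in an undirected graph."""
    return 2 * n_edges / n_nodes


def density(n_edges: int, n_nodes: int) -> float:
    """Fraction of all possible undirected node pairs that are connected."""
    max_edges = n_nodes * (n_nodes - 1) / 2
    return n_edges / max_edges


for i, m in enumerate(EDGE_COUNTS, start=1):
    print(
        f"network {i}: average degree ~{average_degree(m, NUM_NODES):.0f}, "
        f"density ~{density(m, NUM_NODES):.2%}"
    )
```

Under these assumptions, the two largest networks have an average degree of roughly 700 (about 3% of all possible pairs), compared with about 2 for the sparsest network, so a 2- or 3-hop neighborhood in them can easily cover most of the 22.2K nodes.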

I'll try sparsifying the networks to avoid the neighborhood issue.
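One common way to keep only the strongest edges is per-node top-k filtering. This is a swapped-in technique, not something BIONIC provides: an undirected edge is kept if it is among the k heaviest edges of either endpoint, which preserves at least one edge for every node that had any. The sketch assumes each network has already been read into a list of `(node_u, node_v, weight)` tuples; `sparsify_top_k` is a hypothetical helper, and `k` is a tuning choice.

```python
from collections import defaultdict
from typing import Dict, List, Set, Tuple

Edge = Tuple[str, str, float]  # (node_u, node_v, weight); assumed input format


def sparsify_top_k(edges: List[Edge], k: int) -> List[Edge]:
    """Keep an undirected edge if it ranks among the k heaviest edges of either endpoint."""
    # Collect (weight, edge index) pairs for every edge touching each node.
    incident: Dict[str, List[Tuple[float, int]]] = defaultdict(list)
    for idx, (u, v, w) in enumerate(edges):
        incident[u].append((w, idx))
        incident[v].append((w, idx))

    keep: Set[int] = set()
    for pairs in incident.values():
        pairs.sort(key=lambda p: p[0], reverse=True)  # heaviest first
        keep.update(idx for _, idx in pairs[:k])

    # Return the surviving edges in their original order.
    return [edges[i] for i in sorted(keep)]


edges = [("a", "b", 0.9), ("a", "c", 0.1), ("a", "d", 0.5), ("b", "c", 0.8)]
print(sparsify_top_k(edges, k=1))
```

Compared with a single global weight threshold, per-node top-k avoids disconnecting low-degree nodes whose edges are all relatively weak; a global threshold is simpler if the weights are comparable across the whole network.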

duncster94 commented 1 year ago

Yes, those networks are too dense. I would also set gat_shapes.n_layers to 2 rather than 3 (or even 1 if you still have OOM issues). This will reduce the effective receptive field and memory usage without much change in performance (in my experience).
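Changing the number of encoder layers is a config edit. The sketch below assumes the config is a JSON file in which `n_layers` is nested under a `gat_shapes` object, as the key name `gat_shapes.n_layers` suggests; the helper names and file paths are hypothetical.

```python
import copy
import json


def with_gat_layers(config: dict, n_layers: int) -> dict:
    """Return a copy of a config dict with gat_shapes.n_layers overridden (assumed layout)."""
    updated = copy.deepcopy(config)  # leave the caller's config untouched
    updated.setdefault("gat_shapes", {})["n_layers"] = n_layers
    return updated


def rewrite_config(src_path: str, dst_path: str, n_layers: int) -> None:
    """Read a JSON config, override the GAT layer count, and write it to a new file."""
    with open(src_path) as f:
        config = json.load(f)
    with open(dst_path, "w") as f:
        json.dump(with_gat_layers(config, n_layers), f, indent=4)


# Example usage (hypothetical paths): try 2 layers first, then 1 if OOM persists.
# rewrite_config("my_config.json", "my_config_2_layers.json", n_layers=2)
```

Writing a new file rather than editing in place keeps the original config available for comparison if the shallower model changes results.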