Closed johankit closed 1 year ago
It definitely looks like you should have enough memory to preprocess this dataset. Let's get to the bottom of this.
The OOM error is on an array with 167 million elements that takes up 612 GiB. That means each element (node_id) takes up about 4 KB on average, which seems far too high for a simple node_id, especially since your triple file is only 46 GB (which implies only about 46 bytes per triple with 1B triples).
What does the tsv look like? Worth checking if the delimiter is actually a tab. Does it look something like:
src_label rel_label dst_label
"node_src_0" "rel_0" "node_dst_0"
"node_src_1" "rel_1" "node_dst_1"
"node_src_2" "rel_2" "node_dst_2"
"node_src_3" "rel_3" "node_dst_3"
I'm wondering if numpy is doing something strange in the backend when handling strings, like allocating too much memory per element in this operation: np.concatenate([unique_src.astype(str), unique_dst.astype(str)])
We have two options for a workaround here:
Option 1: write a custom preprocessing script that uses the TorchEdgeListConverter class directly, similar to ogbl_citation2. A 2D numpy array holding your edge list is passed into the TorchEdgeListConverter class, and the raw string ids are converted into integer ids manually before input. Using np.loadtxt("/mount/data/works.tsv") with the proper arguments should work to read in your edge list. This goes down a different code path than the line that throws the error and should bypass the problem.
The usage of the TorchEdgeListConverter will look like:
import numpy as np
from marius.tools.preprocess.converters.torch_converter import TorchEdgeListConverter

header_length = 1  # number of header rows in the tsv
# read the raw string triples, skipping the header row
train_list = np.loadtxt("/mount/data/works.tsv", delimiter="\t", dtype=str, skiprows=header_length)
# convert string ids to integer ids using custom code written by you
train_list = convert_to_integers(train_list)
converter = TorchEdgeListConverter(
    output_dir="/mount/marius_preprocessed/works/",
    train_edges=train_list,
    num_partitions=4,
    columns=[0, 1, 2],
    format="numpy",
    splits=[.8, .1, .1])
converter.convert()
Full list of arguments for the TorchEdgeListConverter located here
Option 2: pass --spark to the preprocessor. This will be a bit slower than the in-memory preprocessing but should be able to handle strings better.
Thanks for the quick response!
I tried the conversion using Spark just now, it gives me the same error unfortunately. See the full log below. Interestingly, the allocated memory peaks at approx. 160GB during reading and remapping, then declines at one point until numpy's OOM error crashes the execution.
My .tsv has the following structure. According to Notepad++, real tabs are used as delimiters (not whitespaces).
node1 label node2
n1:W2294004008 n2:hasRelatedWork n1:W2258524308
n1:W2026714175 n2:citedByCount 1
n1:W3157275129 n2:isRetracted False
I'll look into writing the custom script for manual preprocessing.
Do I understand correctly that the TorchEdgeListConverter class requires an edge list in the form below?
That is, I'd go through the .tsv and assign each unique entry an integer ID. Just as an example, not related to my tsv above:
node1 label node2
0 1 2
0 3 4
2 1 5
The error log:
Preprocess custom dataset
Reading edges
Remapping Edges
Traceback (most recent call last):
File "/usr/local/bin/marius_preprocess", line 11, in <module>
load_entry_point('marius==0.0.2', 'console_scripts', 'marius_preprocess')()
File "/usr/local/lib/python3.6/dist-packages/marius/tools/marius_preprocess.py", line 160, in main
columns=args.columns,
File "/usr/local/lib/python3.6/dist-packages/marius/tools/preprocess/custom.py", line 66, in preprocess
converter.convert()
File "/usr/local/lib/python3.6/dist-packages/marius/tools/preprocess/converters/torch_converter.py", line 521, in convert
sequential_deg_nodes=self.sequential_deg_nodes,
File "/usr/local/lib/python3.6/dist-packages/marius/tools/preprocess/converters/torch_converter.py", line 148, in map_edge_lists
return map_edge_list_dfs(edge_lists, known_node_ids, sequential_train_nodes, sequential_deg_nodes)
File "/usr/local/lib/python3.6/dist-packages/marius/tools/preprocess/converters/torch_converter.py", line 88, in map_edge_list_dfs
unique_nodes = np.unique(np.concatenate([unique_src.astype(str), unique_dst.astype(str)]))
numpy.core._exceptions.MemoryError: Unable to allocate 612. GiB for an array with shape (167416627,) and data type <U981
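For reference, the dtype in the last line of the traceback accounts for the size of the allocation. numpy's <U981 is a fixed-width Unicode type, so every element is padded to the length of the longest string (981 characters) at 4 bytes per character, no matter how short most ids are. A quick check of the arithmetic, using only the numbers from the error message:

```python
# Values taken from the MemoryError above.
num_elements = 167_416_627   # shape of the array
chars_per_element = 981      # width of the <U981 dtype
bytes_per_char = 4           # numpy stores fixed-width Unicode as 4-byte code points

bytes_per_element = chars_per_element * bytes_per_char
total_gib = num_elements * bytes_per_element / 2**30

print(bytes_per_element)     # 3924
print(round(total_gib))      # 612
```

This suggests that at least one value in the node columns is 981 characters long, and that single value sets the width for all 167 million entries.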
Hmm I will investigate the spark issue, sorry for the problems!
Do I understand correctly, that the TorchEdgeListConverter class requires an edge list like the below form?
Yes that is correct, however, the edge types (labels) and nodes don't need to be considered together when assigning ids. So all node ids are in the range [0, num_nodes - 1] and all edge_types are in the range [0, num_edge_types - 1]. So your example edge list will look like:
node1 label node2
0 0 1
0 1 2
1 0 3
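A minimal sketch of such a mapping, assuming a tab-delimited file with one header row and exactly three columns. The helper name map_triples_to_ids and the sample labels are hypothetical and not part of Marius. Nodes and relations get separate id spaces, each assigned in first-seen order:

```python
import csv
import io


def map_triples_to_ids(lines, delimiter="\t", header_length=1):
    """Map string triples to integer ids.

    Node ids and relation ids are assigned independently, each starting at 0.
    """
    node_ids = {}
    rel_ids = {}
    edges = []
    for i, row in enumerate(csv.reader(lines, delimiter=delimiter)):
        if i < header_length:
            continue  # skip header rows
        src, rel, dst = row
        # setdefault returns the existing id, or stores the next free one
        src_id = node_ids.setdefault(src, len(node_ids))
        rel_id = rel_ids.setdefault(rel, len(rel_ids))
        dst_id = node_ids.setdefault(dst, len(node_ids))
        edges.append((src_id, rel_id, dst_id))
    return edges, node_ids, rel_ids


# Hypothetical sample input standing in for the real .tsv file.
sample = io.StringIO(
    "node1\tlabel\tnode2\n"
    "n1:A\tn2:hasRelatedWork\tn1:B\n"
    "n1:A\tn2:cites\tn1:C\n"
    "n1:B\tn2:hasRelatedWork\tn1:D\n"
)
edges, node_ids, rel_ids = map_triples_to_ids(sample)
print(edges)  # [(0, 0, 1), (0, 1, 2), (1, 0, 3)]
```

For a real file you would pass an open file handle instead of the StringIO object and write the integer triples back out as you go. Keeping both dictionaries in memory should be far cheaper than the fixed-width string array, but I have not measured it at this scale.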
One other thing, I noticed that your edge list contains edges where the destination nodes are literal values, e.g. n1:W2026714175 n2:citedByCount 1 and n1:W3157275129 n2:isRetracted False. You may want to consider removing those edges, as I'm not sure how well a graph embedding model will be able to learn embeddings for literals, and it might adversely impact the quality of the embeddings for the entities in your graph. If you don't remove these edges, you will learn embeddings for each literal value that appears in your edge list (0, 1, 2, ..., True, False, etc.), but these embeddings are unlikely to be useful or meaningful.
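One possible way to drop those edges is sketched below. It assumes that every entity carries the n1: prefix seen in the sample rows and that literals never do. The function name drop_literal_edges is hypothetical, and the prefix test would need adjusting if your data uses other namespaces:

```python
def drop_literal_edges(triples, entity_prefix="n1:"):
    """Keep only triples whose source and destination are both entities."""
    return [
        (src, rel, dst)
        for src, rel, dst in triples
        if src.startswith(entity_prefix) and dst.startswith(entity_prefix)
    ]


# The three sample rows from the .tsv shown earlier in the thread.
triples = [
    ("n1:W2294004008", "n2:hasRelatedWork", "n1:W2258524308"),
    ("n1:W2026714175", "n2:citedByCount", "1"),
    ("n1:W3157275129", "n2:isRetracted", "False"),
]
kept = drop_literal_edges(triples)
print(kept)  # [('n1:W2294004008', 'n2:hasRelatedWork', 'n1:W2258524308')]
```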
Okay, thanks for clearing that up!
I managed to convert my .tsv to an integer-only format like the one you described. I'm saving the transformed mapping as a .csv file for re-use, since the graph I'm working with is larger than the file described above (only a subset for initial testing).
I'm passing this .csv as the input file to the marius_preprocess command with the additional parameters --no_remap_ids -d ','
I further modified a version of custom.py, which I assume is the module that marius_preprocess.py uses to handle custom datasets. In that custom.py, I'm passing the explicit arguments num_nodes and num_rels to the TorchEdgeListConverter.
The setup works as far as I can tell and partitions the given edges without any expensive remapping.
It still requires a lot of memory though, roughly 4-fold the file size if not more, but that may be due to the high number of unique entities. With Spark, memory allocation was about the size of the file at first, but the run was very slow and aborted in phase 10 with a java.lang.OutOfMemoryError: Java heap space error.
Anyways I was able to run pre-processing without spark and am now in the process of training.
For this, I am now wondering how the required GPU memory is determined.
Are all embeddings loaded into GPU memory when the config file is set to embeddings: type: DEVICE, or just partially? Since the 24GB my GPU has to offer kept giving me CUDA memory allocation errors, I then switched to embeddings: type: HOST_MEMORY.
Also, does the CPU memory scale linearly with the number of edges / inversely with the number of partitions? In other words, I am thinking how to reduce both GPU and CPU memory requirements by working with more partitions or modifying the config file.
Plus, regarding the literals. I am removing them for now but planning to incorporate them into the embeddings later on as additional information describing the entities. Thanks for the hint nevertheless!
Storage configuration
For this, I am now wondering how the required GPU memory is determined. Are all embeddings loaded into GPU mem. when the config file is set to embeddings: type: DEVICE or just partially? Since the 24GB my GPU has to offer kept giving me CUDA memory allocation errors, I then switched to embeddings: type: HOST_MEMORY.
embeddings.type: DEVICE_MEMORY should not be used as that will attempt to load all embeddings into GPU memory.
In order to use the partition buffer your storage should be set to embeddings.type: PARTITION_BUFFER. In this configuration, buffer_capacity partitions are stored in CPU memory during training.
Example storage configuration:
storage:
  device_type: cuda
  dataset:
    dataset_dir: <your_path>
  edges:
    type: HOST_MEMORY
    options:
      dtype: int
  embeddings:
    type: PARTITION_BUFFER
    options:
      dtype: float
      num_partitions: 4
      buffer_capacity: 2
In your case, I actually don't think you need to use partitioning to train since you likely have enough CPU memory (500 GB). I recommend storing your embeddings in CPU memory as follows.
storage:
  device_type: cuda
  dataset:
    dataset_dir: <your_path>
  edges:
    type: HOST_MEMORY
    options:
      dtype: int
  embeddings:
    type: HOST_MEMORY
    options:
      dtype: float
You do not need to preprocess the graph again to use this configuration
Memory usage
Also, does the CPU memory scale linearly with the number of edges / inversely with the number of partitions? In other words, I am thinking how to reduce both GPU and CPU memory requirements by working with more partitions or modifying the config file.
Memory usage for GPU and CPU depends on the exact configuration, but it is largely driven by the number of embeddings stored in memory.
Edges: 3 * num_edges * 4 bytes
Embeddings: 2 * embedding_dim * num_nodes * 4 bytes
So if storage.embeddings or storage.edges is set to DEVICE_MEMORY, the system will attempt to allocate that many bytes in the GPU.
If storage.embeddings or storage.edges is set to HOST_MEMORY, the system will attempt to allocate that many bytes in the CPU. The GPU memory usage in this case is driven by the batch size. This is likely the best option for you unless you are training very large (d > 400) embeddings.
For partitioned training (storage.embeddings set to PARTITION_BUFFER) the partitions are stored in CPU memory. The CPU memory used by the embeddings is d * num_nodes * (buffer_capacity / num_partitions) * 4 bytes.
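To make the formulas above concrete, here is a rough calculator. The function names are made up for illustration. The node count is the one from the preprocessing error, the edge count is the "about 1 billion" figure from the issue description, and embedding_dim = 50 is only an assumed value. Reading the factor 3 as one source, relation and destination id per edge is my interpretation of the formula.

```python
GIB = 2**30


def edge_bytes(num_edges):
    # 3 values per edge, 4 bytes each
    return 3 * num_edges * 4


def embedding_bytes(num_nodes, embedding_dim):
    # 2 * embedding_dim * num_nodes * 4 bytes, as quoted above
    return 2 * embedding_dim * num_nodes * 4


def partition_buffer_bytes(num_nodes, embedding_dim, num_partitions, buffer_capacity):
    # only buffer_capacity of the num_partitions partitions are held in CPU memory
    return embedding_dim * num_nodes * (buffer_capacity / num_partitions) * 4


num_edges = 1_000_000_000   # approximate, from the issue description
num_nodes = 167_416_627     # from the MemoryError shape
embedding_dim = 50          # assumed for illustration

print(f"edges:            {edge_bytes(num_edges) / GIB:.1f} GiB")
print(f"embeddings:       {embedding_bytes(num_nodes, embedding_dim) / GIB:.1f} GiB")
print(f"partition buffer: {partition_buffer_bytes(num_nodes, embedding_dim, 4, 2) / GIB:.1f} GiB")
```

These are lower bounds from the formulas only. Actual usage will be higher because of batches, sampling structures and other overhead.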
Configuring the pipeline
To improve training throughput, you can use asynchronous training (although this may slightly degrade model quality). Here's a pipeline configuration which should work well for your machine:
pipeline:
  sync: false
  staleness_bound: 16
  batch_loader_threads: 8
  gradient_update_threads: 8
Alternatively, you can use synchronous training (this is the default), which will be slower but won't have any loss in quality:
pipeline:
  sync: true
Thanks for the detailed response! It very much helped me to get the initial training going.
However, I encountered an error regarding the Partition Buffer for the embedding storage.
First I tried training with the HOST_MEMORY option for both edges and embeddings, but the process ran out of memory.
I then set out to partition the embeddings using the PARTITION_BUFFER setting, in accordance with the number of partitions I created during preprocessing (40 partitions).
The excerpt of the config .yaml I used:
edges:
  type: HOST_MEMORY
embeddings:
  type: PARTITION_BUFFER
  options:
    dtype: float
    num_partitions: 40
    buffer_capacity: 10
save_model: false
The preprocessing command used:
marius_preprocess --output_dir /mount/all_objects_v8_40/ --edges /mount/all_objects_v8_mapped.csv -d ',' --dataset_split 0.99 0.005 0.005 --columns 0 1 2 --num_partitions 40 --no_remap_ids --overwrite
With these two steps the training process returned the following error and aborted training:
$ marius_train /mcon/lp_50dim_partitioned.yaml
[2022-10-12 18:11:34.322] [info] [marius.cpp:41] Start initialization
[10/12/22 18:17:18.597] Initialization Complete: 344.291s
[10/12/22 18:17:18.605] Generating New Beta Ordering
[10/12/22 18:18:53.705] ################ Starting training epoch 1 ################
terminate called after throwing an instance of 'c10::IndexError'
what(): index out of range in self
Exception raised from operator() at ../aten/src/ATen/native/TensorAdvancedIndexing.cpp:849 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x14ead0cd9a22 in /usr/local/lib/python3.6/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0x128fd5a (0x14ead23e7d5a in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #2: GOMP_parallel + 0x3f (0x14eb7720d05f in /usr/local/lib/python3.6/dist-packages/torch/lib/libgomp-a34b3233.so.1)
frame #3: <unknown function> + 0x128e314 (0x14ead23e6314 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #4: at::native::index_select_out_cpu_(at::Tensor const&, long, at::Tensor const&, at::Tensor&) + 0x223a (0x14ead23fa3fa in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #5: at::native::index_select_cpu_(at::Tensor const&, long, at::Tensor const&) + 0x60 (0x14ead23faec0 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0x1a17fe2 (0x14ead2b6ffe2 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #7: at::redispatch::index_select(c10::DispatchKeySet, at::Tensor const&, long, at::Tensor const&) + 0xb4 (0x14ead29f23a4 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0x30280c1 (0x14ead41800c1 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #9: <unknown function> + 0x3028515 (0x14ead4180515 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #10: at::Tensor::index_select(long, at::Tensor const&) const + 0x155 (0x14ead2eb96b5 in /usr/local/lib/python3.6/dist-packages/torch/lib/libtorch_cpu.so)
frame #11: PartitionBuffer::indexRead(at::Tensor) + 0x46 (0x14e9e42c4086 in /usr/local/lib/python3.6/dist-packages/marius/libmarius.so)
frame #12: PartitionBufferStorage::indexRead(at::Tensor) + 0x42 (0x14e9e42dca62 in /usr/local/lib/python3.6/dist-packages/marius/libmarius.so)
frame #13: GraphModelStorage::getNodeEmbeddings(at::Tensor) + 0x5e (0x14e9e42ce80e in /usr/local/lib/python3.6/dist-packages/marius/libmarius.so)
frame #14: DataLoader::loadCPUParameters(std::shared_ptr<Batch>) + 0x79 (0x14e9e423bba9 in /usr/local/lib/python3.6/dist-packages/marius/libmarius.so)
frame #15: DataLoader::getBatch(c10::optional<c10::Device>, bool) + 0xdd (0x14e9e423fdfd in /usr/local/lib/python3.6/dist-packages/marius/libmarius.so)
frame #16: LoadBatchWorker::run() + 0x15c (0x14e9e42ae34c in /usr/local/lib/python3.6/dist-packages/marius/libmarius.so)
frame #17: <unknown function> + 0xd44c0 (0x14eb76ec64c0 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #18: <unknown function> + 0x76db (0x14eb7a2a66db in /lib/x86_64-linux-gnu/libpthread.so.0)
frame #19: clone + 0x3f (0x14eb7a5df61f in /lib/x86_64-linux-gnu/libc.so.6)
Aborted (core dumped)
How can I interpret this error?
Right now I am training with the entire set of embeddings in HOST_MEMORY and the edges as FLAT_FILE. It is running smoothly but fairly slowly.
First I tried training with the HOST_MEMORY option for both edges and embeddings, the process ran out of memory though.
This is good to know, we are aware of some issues with the memory overhead of the edges and have identified some fixes.
Glad to hear that the configuration with embeddings in HOST_MEMORY and edges in FLAT_FILE is working. When you say it's slow do you have an estimate of the number of edges per second? Throughput depends on the model, batch size, negative sampling, and pipeline configuration used. For well-optimized configurations of simple models (e.g. DistMult, TransE, ComplEx), I would expect about 500k to 1 million edges per second. If you post your full config file I can give some pointers for optimizations.
With these two steps the training process returned the following error and aborted training:
This is a tough bug. It is saying that the node ids used to read the embeddings are not in the expected range. I'm not sure what the cause is exactly but my initial guess would be that the preprocessing is not correct. What are the contents of /mount/all_objects_v8_40/dataset.yaml? I'm guessing that the --no_remap_ids option is not able to detect the correct number of nodes for the graph, causing the expected range of node_ids to be incorrect. If this is the case then the fix should be simple.
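One way to test that guess independently of Marius is to scan the mapped edge file and compare the largest ids against the counts in dataset.yaml. The sketch below assumes comma-delimited rows of integer src, rel, dst columns, as in the preprocessing command above. The helper name check_id_ranges and the sample data are hypothetical:

```python
import csv
import io


def check_id_ranges(rows, num_nodes, num_relations):
    """Return the id extremes and whether they fit the declared counts."""
    min_id = 0
    max_node = -1
    max_rel = -1
    for src, rel, dst in rows:
        src, rel, dst = int(src), int(rel), int(dst)
        min_id = min(min_id, src, rel, dst)
        max_node = max(max_node, src, dst)
        max_rel = max(max_rel, rel)
    return {
        "max_node_id": max_node,
        "max_rel_id": max_rel,
        # valid node ids are 0 .. num_nodes - 1, relation ids 0 .. num_relations - 1
        "nodes_ok": min_id >= 0 and max_node < num_nodes,
        "rels_ok": min_id >= 0 and max_rel < num_relations,
    }


# Hypothetical stand-in for the mapped .csv file.
sample = io.StringIO("0,0,1\n0,1,2\n1,0,3\n")
report = check_id_ranges(csv.reader(sample), num_nodes=4, num_relations=2)
print(report)
```

If nodes_ok comes back False for the real file with the num_nodes value from dataset.yaml, an out-of-range index read like the one in the stack trace would be the expected result.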
Thanks for the response!
Glad to hear that the configuration with embeddings in HOST_MEMORY and edges in FLAT_FILE is working. When you say it's slow do you have an estimate of the number of edges per second?
My training log shows 6610355ms runtime and 452470.78 edges per second. With the same config but active pipelining, this rate went down to around 170321 edges per second.
Below is the config I used for the run without pipelining:
model:
  learning_task: LINK_PREDICTION
  encoder:
    layers:
      - - type: EMBEDDING
          output_dim: 100
  decoder:
    type: DISTMULT
    options:
      input_dim: 50
  loss:
    type: SOFTMAX_CE
    options:
      reduction: SUM
  dense_optimizer:
    type: ADAM
    options:
      learning_rate: 0.1
  sparse_optimizer:
    type: ADAGRAD
    options:
      learning_rate: 0.1
storage:
  device_type: cuda
  dataset:
    dataset_dir: /mount/all_objects_v8_40/
  edges:
    type: FLAT_FILE
    options:
      dtype: int
  embeddings:
    type: HOST_MEMORY
    options:
      dtype: float
  save_model: false
training:
  batch_size: 16000
  negative_sampling:
    num_chunks: 10
    negatives_per_positive: 500
    degree_fraction: 0.0
    filtered: false
  num_epochs: 3
  pipeline:
    sync: false
    staleness_bound: 16
    batch_loader_threads: 8
    gradient_update_threads: 8
  epochs_per_shuffle: 1
  logs_per_epoch: 20
evaluation:
  batch_size: 1000
  negative_sampling:
    filtered: false
  pipeline:
    sync: false
    staleness_bound: 16
    batch_loader_threads: 8
    gradient_update_threads: 8
When I ran a very similar config but with a 3-layer GAT or GraphSage encoder, as in the configs suggested on your documentation page, I encountered CUDA OOM errors. For these runs, I initially reduced the training batch size to 3000.
Is it correct to assume the error may be due to the three layer networks trying to load all 3-hop neighbors into GPU memory? I tried reducing the training batch size further, however this didn't significantly change the required GPU memory. Next, I'll decrease the network size to two layers. The CUDA OOM error is below.
[10/16/22 19:29:01.673] ################ Starting training epoch 1 ################
[W BucketizationUtils.h:20] Warning: input value tensor is non-contiguous, this will lower the performance due to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value tensor if possible (function operator())
terminate called after throwing an instance of 'c10::CUDAOutOfMemoryError'
what(): CUDA out of memory. Tried to allocate 73.36 GiB
Is there perhaps any way to make training of the GNNs more slim without dramatically reducing the dimensions?
What are the contents of /mount/all_objects_v8_40/dataset.yaml? I'm guessing that the --no_remap_ids option is not able to detect the correct number of nodes for the graph, causing the expected range of node_ids to be incorrect. If this is the case then the fix should be simple.
The contents of the .yaml file are:
dataset_dir: /mount/mmm/mdat/all_objects_v8_40/
num_edges: 2990992480
num_nodes: 430973178
num_relations: 6
num_train: 2990992480
num_valid: 2996986
num_test: 2996987
node_feature_dim: -1
rel_feature_dim: -1
num_classes: -1
initialized: false
And in the custom.py script I passed the same number of nodes and relations for preprocessing to TorchEdgeListConverter.
I'll re-run the preprocessing again with a lower number of partitions and double check the node count to see if this makes a difference!
Thank you!
Sorry for the delay!
My training log shows 6610355ms runtime and 452470.78 edges per second. With the same config but active pipelining, this rate went down to around 170321 edges per second.
Nice, these are expected throughputs for DistMult.
Is it correct to assume the error may be due to the three layer networks trying to load all 3-hop neighbors into GPU memory? I tried reducing the training batch size further, however this didn't significantly change the required GPU memory. Next, I'll decrease the network size to two layers. The CUDA OOM error is below.
Yes that is correct. In order to train a GNN at this scale you will need to perform neighbor sampling instead of getting all neighbors.
Here is a GNN configuration I typically use on large-scale graphs. It has a single graph sage layer where neighbors are sampled. This will train somewhere around 2x slower than DistMult. If you wanted to use GAT, this may give better quality embeddings but will train much slower as it is more computationally expensive (maybe around 10x slower than DistMult).
model:
  learning_task: LINK_PREDICTION
  encoder:
    train_neighbor_sampling:
      - type: UNIFORM
        options:
          max_neighbors: 10
    layers:
      - - type: EMBEDDING
          output_dim: 50
          bias: true
      - - type: GNN
          input_dim: 50
          output_dim: 50
          options:
            type: GRAPH_SAGE
          bias: true
  decoder:
    type: DISTMULT
    options:
      input_dim: 50
  loss:
    type: SOFTMAX_CE
    options:
      reduction: SUM
  dense_optimizer:
    type: ADAM
    options:
      learning_rate: 0.1
  sparse_optimizer:
    type: ADAGRAD
    options:
      learning_rate: 0.1
You can add more GNN layers if needed, but you may need to reduce the batch size or the number of neighbors sampled at each layer. Here's a 2 layer GNN encoder as an example.
encoder:
  train_neighbor_sampling:
    - type: UNIFORM
      options:
        max_neighbors: 15
    - type: UNIFORM
      options:
        max_neighbors: 5
  layers:
    - - type: EMBEDDING
        output_dim: 50
        bias: true
    - - type: GNN
        input_dim: 50
        output_dim: 50
        options:
          type: GRAPH_SAGE
        bias: true
    - - type: GNN
        input_dim: 50
        output_dim: 50
        options:
          type: GRAPH_SAGE
        bias: true
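To get a feel for why the number of sampled neighbors matters so much, the sketch below computes an upper bound on how many nodes a batch can touch for a given list of max_neighbors values. It assumes the values apply hop by hop starting from the batch nodes, which may not match the order Marius uses internally. The function name and the seed count of 1000 are only illustrative:

```python
def max_sampled_nodes(num_seed_nodes, fanouts):
    """Upper bound on nodes touched when each hop samples at most `fanout` neighbors."""
    total = num_seed_nodes
    frontier = num_seed_nodes
    for fanout in fanouts:
        frontier *= fanout   # every node in the frontier adds at most `fanout` neighbors
        total += frontier
    return total


print(max_sampled_nodes(1000, [10]))          # 11000
print(max_sampled_nodes(1000, [15, 5]))       # 91000
print(max_sampled_nodes(1000, [15, 15, 15]))  # 3616000
```

The bound grows multiplicatively with each extra layer, so adding a layer usually has to be paid for with a smaller batch size or a smaller max_neighbors. Without sampling, the per-hop factor is the real node degree, which can be far larger.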
Thanks for the two suggested GNN configurations. Using these, I was able to train a few networks!
With the option pipeline: sync: false the evaluation of GNNs did not work, as it generated an error with the message terminate called recursively (memory was not ...). It then worked using pipeline: sync: true for evaluation.
I'm planning to compare the GNNs to more conventional approaches such as TransE or ComplEx on the full data set. However for both these models, the training goes through but the evaluation fails due to OOM.
How can this be explained? Since the number of embeddings stored in memory does not increase when evaluating?
I stored the edges as FLAT_FILE and the embeddings in HOST_MEMORY. Perhaps the evaluation set is loaded into memory as one large chunk as well?
Is there any way to reduce the additional CPU memory load that the evaluation causes? I already tried modifying the batch size and the negative sampling parameters, without noticeable effect.
Thank you!
One thing: make sure evaluation.negative_sampling.filtered is set to false. If enabled, this will cause the evaluation to use a large amount of memory, as data structures are maintained in memory for fast negative filtering over the entire graph. This works fine on small graphs but not large ones. If you need to evaluate with filtered negatives, I have a python script lying around which is able to compute filtered MRR on large-scale graphs. If needed I can clean it up and share it.
Do you see this issue with the non-GNN models? (TransE, ComplEx, etc.). GNNs require additional data structures in memory, which might be the root cause of this OOM. On our end, we have been working on optimizations to reduce the memory overhead.
As a workaround you can try the following:
You can disable evaluation during training and then evaluate once all training epochs are complete with marius_eval.
To do this, set evaluation.epochs_per_eval > training.num_epochs. So if you are training 3 epochs, then set evaluation.epochs_per_eval to at least 4. This will cause the evaluation to be skipped during training.
Once training is complete, you can run marius_eval <config_file>. This will return the MRR on the test set.
If you want to evaluate every epoch, you can perform the above process one epoch at a time by setting training.resume_training=true
Thank you for the explanation. I was able to train and evaluate separately using that workaround!
With this, I'm closing this issue as it has been resolved for me. Thanks again for the support.
Problem description
When preprocessing my data set in the .tsv format to prepare it for training and splitting into four partitions, I receive an out-of-memory error.
I used the command:
marius_preprocess --output_dir /mount/marius_preprocessed/works/ --edges /mount/data/works.tsv --dataset_split 0.8 0.1 0.1 --columns 0 1 2 --num_partitions 4
However, during preprocessing I encounter the error:
unique_nodes = np.unique(np.concatenate([unique_src.astype(str), unique_dst.astype(str)])) numpy.core._exceptions.MemoryError: Unable to allocate 612GiB for an array with shape (167416627, ).
The input file is 46GB in size and contains about 1 billion lines (triples). And the instance I'm using has 500GB of memory. It seems the array has the 'length' equal to the number of unique entities in the input file.
The error occurs after the remapping of edges step has started. Changing the number of partitions did not help. I am running the tool in a Docker container, the behavior without container was similar though.
I understand Marius was built for efficient embedding generation on lower-capacity machines when training. Is there any way to reduce the resource needs during preprocessing as well? Perhaps there are modifications I could make to the .tsv file to support the preprocessing?
Expected behavior
Preprocessing of raw input (.nt or .tsv files) into ready-to-train partitions, with comparable resource requirements as during the embeddings training.
Environment
Thank you!