amazon-science / tgl

Apache License 2.0

fix a bug in train_dist for TGAT #28

Closed: tedzhouhk closed this pull request 11 months ago

tedzhouhk commented 11 months ago

Issue #: None

Description of changes:

The script train_dist.py throws an error stating that the input MFGs are None when training non-memory-based TGNNs such as TGAT. The DataPipelineThread class, which prepares minibatch data on separate CUDA streams, does not store the resulting MFGs because of an embarrassing bug in the if clause at train_dist.py:145. This PR fixes this bug.
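The PR text does not show the faulty condition itself, so the sketch below only illustrates the intended behavior: a background thread prepares a minibatch and stores its MFGs whether or not the model uses memory, and only the memory-specific work is gated. The class `PipelineThread`, the `use_memory` flag, and the string placeholders for MFGs are hypothetical stand-ins, not the actual DataPipelineThread code, and the CUDA stream handling is omitted so the sketch runs without a GPU.

```python
import threading


class PipelineThread(threading.Thread):
    """Hypothetical stand-in for a background minibatch-preparation thread.

    The real DataPipelineThread also issues its work on a separate CUDA
    stream; that part is left out here so the example runs on CPU only.
    """

    def __init__(self, batch, use_memory):
        super().__init__()
        self.batch = batch
        self.use_memory = use_memory  # hypothetical flag: memory-based TGNN or not
        self.mfgs = None
        self.memory_state = None

    def run(self):
        # Build one placeholder "MFG" per layer (real code would build DGL blocks).
        mfgs = [f"mfg_layer{i}_{self.batch}" for i in range(2)]
        # Only the memory-specific preparation depends on the model type.
        if self.use_memory:
            self.memory_state = f"memory_{self.batch}"
        # Store the MFGs unconditionally so non-memory models (e.g. TGAT) get them.
        self.mfgs = mfgs

    def get_mfgs(self):
        return self.mfgs
```

For a non-memory model, starting and joining the thread now yields non-None MFGs:

```python
t = PipelineThread(batch=3, use_memory=False)
t.start()
t.join()
print(t.get_mfgs())  # ['mfg_layer0_3', 'mfg_layer1_3']
```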

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.