dmlc / dgl

Python package built to ease deep learning on graph, on top of existing DL frameworks.
http://dgl.ai
Apache License 2.0

[GraphBolt] modify logic for `HeteroItemSet` indexing #7428

Open Skeleton003 opened 1 month ago

Skeleton003 commented 1 month ago

Description

First, let's look at the current code for indexing a `HeteroItemSet` (in `HeteroItemSet.__getitem__`):

        elif isinstance(index, Iterable):
            if not isinstance(index, torch.Tensor):
                index = torch.tensor(index)
            assert torch.all((index >= 0) & (index < self._length))
            key_indices = (
                torch.searchsorted(self._offsets, index, right=True) - 1
            )
            data = {}
            for key_id, key in enumerate(self._keys):
                mask = (key_indices == key_id).nonzero().squeeze(1)
                if len(mask) == 0:
                    continue
                data[key] = self._itemsets[key][
                    index[mask] - self._offsets[key_id]
                ]
            return data

Say the length of the indices is N and the number of etypes/ntypes is K. The time complexity of the current implementation of indexing a dictionary is then O(N * K), which comes mainly from the line

mask = (key_indices == key_id).nonzero().squeeze(1)

If there are a lot of etypes, this line could easily become the bottleneck.
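To make the cost concrete, here is a minimal pure-Python sketch of the same masking strategy. It is an illustration only, not the DGL code: `bisect.bisect_right` stands in for `torch.searchsorted(..., right=True)`, plain lists stand in for tensors, and the names `index_by_mask`, `keys`, `offsets`, and `itemsets` are hypothetical.

```python
from bisect import bisect_right


def index_by_mask(keys, offsets, itemsets, index):
    """Mask-based lookup: one full pass over `index` per key, so O(N * K)."""
    total = offsets[-1]
    assert all(0 <= i < total for i in index)
    # Map each global index to the position of the key that owns it.
    key_indices = [bisect_right(offsets, i) - 1 for i in index]
    data = {}
    for key_id, key in enumerate(keys):
        # This comprehension scans all N entries for every one of the K keys.
        mask = [pos for pos, k in enumerate(key_indices) if k == key_id]
        if not mask:
            continue
        data[key] = [itemsets[key][index[pos] - offsets[key_id]] for pos in mask]
    return data


# Toy data (assumed): key "a" owns global indices 0..2, key "b" owns 3..4.
keys = ["a", "b"]
offsets = [0, 3, 5]
itemsets = {"a": [10, 11, 12], "b": [20, 21]}

result = index_by_mask(keys, offsets, itemsets, [4, 0, 2])
print(result)  # {'a': [10, 12], 'b': [21]}
```

Within each key the items come back in the order they appear in `index`, which is the behavior any alternative has to preserve.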

This draft PR proposes an alternative to the current logic:

        elif isinstance(index, Iterable):
            if not isinstance(index, torch.Tensor):
                index = torch.tensor(index)
            sorted_index, indices = index.sort()
            assert sorted_index[0] >= 0 and sorted_index[-1] < self._length
            index_offsets = torch.searchsorted(sorted_index, self._offsets)
            data = {}
            for key_id, key in enumerate(self._keys):
                if index_offsets[key_id] == index_offsets[key_id + 1]:
                    continue
                current_indices, _ = indices[
                    index_offsets[key_id] : index_offsets[key_id + 1]
                ].sort()
                data[key] = self._itemsets[key][
                    index[current_indices] - self._offsets[key_id]
                ]
            return data

Its time complexity is O(N * logN), where the log factor comes from the sorting operation.
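The sort-based idea can be sketched in pure Python as follows. This is a rough stand-in rather than the PR's code: `sorted` with a key plays the role of `Tensor.sort()`, `bisect.bisect_left` plays the role of `torch.searchsorted` with its default `right=False`, and `index_by_sort` and the toy data are hypothetical. The early return for an empty `index` is an addition in this sketch, since reading element 0 of an empty sequence would fail.

```python
from bisect import bisect_left


def index_by_sort(keys, offsets, itemsets, index):
    """Sort-based lookup: sort once, then slice out each key's run."""
    if not index:
        return {}
    # Positions of `index` ordered by the value they point at (an argsort).
    order = sorted(range(len(index)), key=lambda pos: index[pos])
    sorted_index = [index[pos] for pos in order]
    assert sorted_index[0] >= 0 and sorted_index[-1] < offsets[-1]
    # Where each key's range begins inside the sorted indices.
    index_offsets = [bisect_left(sorted_index, off) for off in offsets]
    data = {}
    for key_id, key in enumerate(keys):
        start, stop = index_offsets[key_id], index_offsets[key_id + 1]
        if start == stop:
            continue
        # Re-sort the positions so items come back in the caller's order.
        positions = sorted(order[start:stop])
        data[key] = [itemsets[key][index[pos] - offsets[key_id]] for pos in positions]
    return data


# Toy data (assumed): key "a" owns global indices 0..2, key "b" owns 3..4.
keys = ["a", "b"]
offsets = [0, 3, 5]
itemsets = {"a": [10, 11, 12], "b": [20, 21]}

result = index_by_sort(keys, offsets, itemsets, [4, 0, 2])
print(result)  # {'a': [10, 12], 'b': [21]}
```

Note that the loop still visits all K keys and the offsets lookup costs K binary searches, so a fuller bound for this sketch is O(N * logN + K * logN); each loop iteration only touches its own slice instead of all N entries.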

This should improve performance when there are many etypes, but it might be slower when there are few etypes. The open question is how to strike a balance between the two approaches.
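As a back-of-envelope guide to where that balance might sit, the leading terms N * K and N * log2(N) can be compared directly. The helper names below are hypothetical, and the model deliberately ignores constant factors and per-operation overheads, which can dominate in practice, so it only indicates the order of magnitude of K at which sorting could start to pay off.

```python
import math


def leading_terms(n, k):
    """Return (mask_cost, sort_cost) as rough operation counts, constants ignored."""
    mask_cost = n * k
    sort_cost = n * math.log2(n) if n > 1 else 0.0
    return mask_cost, sort_cost


def crossover_types(n):
    """Smallest integer K for which N * K exceeds N * log2(N)."""
    return math.floor(math.log2(n)) + 1


for n in (64, 1024, 65536):
    # Prints the batch size and the K above which the sort term is smaller.
    print(n, crossover_types(n))
```

Under this simplified model the crossover is roughly K > log2(N), for example K above 10 for a batch of 1024 indices; only a measurement on real tensors can confirm or refute it.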

Update on June 18

Benchmark: https://docs.google.com/document/d/1Bbmp8gMekiGIYYxEMVbmXSANRZlZ_nTNbhpWul4RaKA/edit?usp=sharing

The results show that the original algorithm is faster than the new algorithm (theoretical time complexity O(N * logN)) for almost all values of batch_size and num_types.

dgl-bot commented 1 month ago

To trigger regression tests:

dgl-bot commented 1 month ago

Commit ID: 6c3a7f203cb4b94964ea2cd4e1e172190f488ad1

Build ID: 1

Status: ✅ CI test succeeded.

Report path: link

Full logs path: link

Skeleton003 commented 3 weeks ago

@dgl-bot

dgl-bot commented 3 weeks ago

Commit ID: b4c0ea46dfefd4426afc83d6a6098b4495a653b5

Build ID: 2

Status: ✅ CI test succeeded.

Report path: link

Full logs path: link

mfbalin commented 3 weeks ago

Do you have a benchmark comparing the new approach to the old one for different K values?

dgl-bot commented 3 weeks ago

Commit ID: 50da7181836ef6ecedd78cbd58e32bd9c9d5003b

Build ID: 3

Status: ❌ CI test failed in Stage [Distributed Torch CPU Unit test].

Report path: link

Full logs path: link

Skeleton003 commented 2 weeks ago

@mfbalin @frozenbugs See benchmark results in the description. The new implementation does not seem to be as efficient as we thought. Maybe we should keep it as is?

dgl-bot commented 2 weeks ago

Commit ID: 591c122323236759ec8e2df4021308331e93cf6b

Build ID: 4

Status: ✅ CI test succeeded.

Report path: link

Full logs path: link

mfbalin commented 2 weeks ago

> @mfbalin @frozenbugs See benchmark results in the description. The new implementation does not seem to be as efficient as we thought. Maybe we should keep it as is?

Let me take a look at the code to see if we missed anything. Thank you for the benchmark.