ai4co / rl4co

A PyTorch library for all things Reinforcement Learning (RL) for Combinatorial Optimization (CO)
https://rl4.co
MIT License

[Feature Request] Does the Floyd-Warshall algorithm perform better for tmat_class ATSP instance generation? #225

Open abcdhhhh opened 1 day ago

abcdhhhh commented 1 day ago

Motivation

ATSPGenerator._generate seems to have high time and space complexity when generating tmat_class instances.

Solution

I have implemented the Floyd-Warshall algorithm as an alternative. In my benchmark below it is faster on both small and large problems (roughly 2.7x to 33x), and its space complexity is $O(n^2)$, versus $O(n^3)$ for the original.
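To give a rough sense of what the $O(n^2)$ versus $O(n^3)$ gap means in memory, here is a back-of-the-envelope sketch. It is not a measurement and is separate from the benchmark: it assumes float32 entries and that a cubic method materializes one n x n x n intermediate per instance. The helper name `intermediate_bytes` is hypothetical, and actual peak usage depends on the implementation.

```python
# Back-of-the-envelope memory estimate, not a measurement.
# Assumptions: float32 entries (4 bytes); an O(n^3) method materializes one
# n x n x n tensor per instance, an O(n^2) method one extra n x n tensor.
BYTES_PER_FLOAT32 = 4


def intermediate_bytes(num_loc: int, batch_size: int, power: int) -> int:
    """Bytes for a batch of intermediates with num_loc**power entries each."""
    return batch_size * num_loc**power * BYTES_PER_FLOAT32


for num_loc, batch_size in [(100, 10000), (500, 128), (1000, 4)]:
    cubic = intermediate_bytes(num_loc, batch_size, 3)
    quadratic = intermediate_bytes(num_loc, batch_size, 2)
    print(
        f"num_loc={num_loc}, batch_size={batch_size}: "
        f"O(n^3) ~ {cubic / 1e9:.2f} GB, O(n^2) ~ {quadratic / 1e9:.3f} GB"
    )
```

Under these assumptions the two estimates differ by a factor of num_loc, so the gap grows with problem size.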

My code is:

from rl4co.envs.routing.atsp.generator import ATSPGenerator
import torch
from tensordict.tensordict import TensorDict
import time
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)

class FloydGenerator(ATSPGenerator):
    def _generate(self, batch_size) -> TensorDict:
        # Generate distance matrices inspired by the reference MatNet (Kwon et al., 2021)
        # We satisfy the triangle inequality (TMAT class) for the whole batch at once
        batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
        dms = (
            self.dist_sampler.sample((batch_size + [self.num_loc, self.num_loc]))
            * (self.max_dist - self.min_dist)
            + self.min_dist
        )
        dms[..., torch.arange(self.num_loc), torch.arange(self.num_loc)] = 0
        log.info("Using TMAT class (triangle inequality): {}".format(self.tmat_class))
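        if self.tmat_class:
            # Floyd-Warshall: relax every pair (a, b) through intermediate node i,
            # broadcasting column i (..., n, 1) against row i (..., 1, n)
            for i in range(self.num_loc):
                dms = torch.minimum(dms, dms[..., :, [i]] + dms[..., [i], :])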
        return TensorDict({"cost_matrix": dms}, batch_size=batch_size)

for num_loc, batch_size in [(20, 10000), (50, 10000), (100, 10000), (200, 128), (500, 128), (1000, 4)]:
    print(f'num_loc: {num_loc}, batch_size: {batch_size}')
    gen_0 = ATSPGenerator(num_loc=num_loc, tmat_class=True)
    gen_1 = FloydGenerator(num_loc=num_loc, tmat_class=True)

    torch.manual_seed(2024)
    t = time.time()
    data_0 = gen_0._generate(batch_size=batch_size)["cost_matrix"]
    print(f'Original Generator: {time.time() - t}')

    torch.manual_seed(2024)
    t = time.time()
    data_1 = gen_1._generate(batch_size=batch_size)["cost_matrix"]
    print(f'Floyd Generator: {time.time() - t}')

    assert (data_0 == data_1).all()

    print()

and the output is:

num_loc: 20, batch_size: 10000
Original Generator: 0.34020304679870605
Floyd Generator: 0.032858848571777344

num_loc: 50, batch_size: 10000
Original Generator: 4.837661981582642
Floyd Generator: 1.722236156463623

num_loc: 100, batch_size: 10000
Original Generator: 38.64327096939087
Floyd Generator: 14.100801706314087

num_loc: 200, batch_size: 128
Original Generator: 4.392717361450195
Floyd Generator: 0.3182206153869629

num_loc: 500, batch_size: 128
Original Generator: 54.52213501930237
Floyd Generator: 19.21849036216736

num_loc: 1000, batch_size: 4
Original Generator: 16.113645792007446
Floyd Generator: 0.48308658599853516
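
For easier comparison, the speedups implied by the timings above can be computed with a few lines of Python. This is only arithmetic on the numbers already printed (single runs on one machine), so it says nothing about variance or other hardware. The names `timings` and `speedup` are just for this sketch.

```python
# Timings in seconds, copied from the output above:
# (num_loc, batch_size) -> (original generator, Floyd generator)
timings = {
    (20, 10000): (0.34020304679870605, 0.032858848571777344),
    (50, 10000): (4.837661981582642, 1.722236156463623),
    (100, 10000): (38.64327096939087, 14.100801706314087),
    (200, 128): (4.392717361450195, 0.3182206153869629),
    (500, 128): (54.52213501930237, 19.21849036216736),
    (1000, 4): (16.113645792007446, 0.48308658599853516),
}


def speedup(original: float, floyd: float) -> float:
    """How many times faster the Floyd generator was in a single run."""
    return original / floyd


speedups = {key: speedup(*pair) for key, pair in timings.items()}
for (num_loc, batch_size), ratio in speedups.items():
    print(f"num_loc={num_loc:>4}, batch_size={batch_size:>5}: {ratio:.1f}x")
```

The ratios range from about 2.7x to about 33x across the six configurations.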

Alternatives

I notice that the networkx package also has an implementation of the Floyd-Warshall algorithm, but I don't know whether it can be integrated here or whether it would perform better.

Additional context

I have only benchmarked this on my own machine. Could you please check whether it is still faster in other settings?
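
For checking correctness in more situations, a slow but easy-to-audit reference may help alongside the timing runs. The sketch below is a plain-Python Floyd-Warshall plus a triangle-inequality check on a small random asymmetric matrix. The function names are hypothetical and not part of rl4co, and the tolerance `tol` is an assumption to absorb floating-point rounding.

```python
import random


def floyd_warshall(dist):
    """Return all-pairs shortest-path costs for a square cost matrix (list of lists)."""
    n = len(dist)
    d = [row[:] for row in dist]  # copy so the input is left untouched
    for k in range(n):  # allow node k as an intermediate stop
        for i in range(n):
            for j in range(n):
                if d[i][k] + d[k][j] < d[i][j]:
                    d[i][j] = d[i][k] + d[k][j]
    return d


def satisfies_triangle_inequality(d, tol=1e-9):
    """Check d[i][j] <= d[i][k] + d[k][j] for all i, j, k, up to tol."""
    n = len(d)
    return all(
        d[i][j] <= d[i][k] + d[k][j] + tol
        for i in range(n)
        for j in range(n)
        for k in range(n)
    )


random.seed(0)
n = 8
# Random asymmetric costs with a zero diagonal, as in the generator above
raw = [[0.0 if i == j else random.random() for j in range(n)] for i in range(n)]
closed = floyd_warshall(raw)
print(satisfies_triangle_inequality(closed))
```

The triple loop is far too slow for generating large batches; it is meant only as a check that a generator's output on small instances satisfies the triangle inequality.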

fedebotu commented 1 day ago

This looks awesome @abcdhhhh ! Could you submit a pull request with your implementation? We will benchmark the generation speed then ~

abcdhhhh commented 1 day ago

Hi @fedebotu, I have submitted pull request #226.