ai4co / rl4co

A PyTorch library for all things Reinforcement Learning (RL) for Combinatorial Optimization (CO)
https://rl4.co
MIT License

Major environment refactoring (draft version) #166

Closed cbhua closed 5 months ago

cbhua commented 7 months ago

[!IMPORTANT] The merge of this pull request is postponed because it contains sensitive modifications to the environment logic, which may introduce hidden bugs, so we should be careful when applying them. This full version of the environment refactoring will therefore be kept as a draft. We opened another, base-version refactoring pull request: https://github.com/ai4co/rl4co/pull/169, which only touches the environment structure and adds the generator without changing any logic, so it can be merged safely in the current state. In the future, building on this draft's full version, we will refactor the environments further, step by step.

Description

Together with the major modeling refactoring in #165, this PR is a major, long-overdue refactoring of the RL4CO environments codebase.

Motivation and Context

This refactoring is driven by the following motivations:

Changelog

Environment Structure Refactoring

The refactored structure for the environments is as follows:

rl4co
├── models/
└── envs/
    ├── eda/
    ├── scheduling/
    └── routing/
        ├── tsp/
        │   ├── env.py
        │   ├── generator.py
        │   └── render.py
        ├── cvrp/
        │   ├── env.py
        │   ├── generator.py
        │   └── render.py
        └── ...

We have restructured the organization of the environment files for improved modularity and clarity. Each environment has its own directory comprising three components: env.py, generator.py, and render.py.

Data Generator Support

Each environment's generator is based on the base Generator() class, which provides the following functions:

from tensordict import TensorDict


class Generator:
    """Base data generator; subclasses implement `_generate()`."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs

    def __call__(self, batch_size) -> TensorDict:
        # Accept both a single int and a list of ints as the batch size
        batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
        return self._generate(batch_size)

    def _generate(self, batch_size, **kwargs) -> TensorDict:
        # To be implemented by each environment-specific generator
        raise NotImplementedError
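
As an illustration, here is a minimal sketch of a concrete generator built on this base class (the name UniformTSPGenerator and its defaults are assumptions for the example, not the actual generator in this PR):

import torch
from tensordict import TensorDict

class UniformTSPGenerator(Generator):
    def __init__(self, num_loc: int = 50, min_loc: float = 0.0, max_loc: float = 1.0):
        super().__init__(num_loc=num_loc, min_loc=min_loc, max_loc=max_loc)
        self.num_loc = num_loc
        self.min_loc = min_loc
        self.max_loc = max_loc

    def _generate(self, batch_size, **kwargs) -> TensorDict:
        # Sample node coordinates uniformly in [min_loc, max_loc]^2
        locs = torch.rand(*batch_size, self.num_loc, 2) * (self.max_loc - self.min_loc) + self.min_loc
        return TensorDict({"locs": locs}, batch_size=batch_size)

Calling UniformTSPGenerator(num_loc=20)(batch_size=32) then returns a TensorDict with a "locs" entry of shape [32, 20, 2].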

New get_sampler() function

This implementation is mainly based on @ngastzepeda's code. In the current version, we support the following distributions:

You can also pass your own Callable as the sampler. This function takes batch_size: List[int] as input and returns the sampled torch.Tensor.
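
For instance, a custom sampler could look like the following sketch (the clustered distribution is just an assumed example of what a user might plug in):

import torch

def clustered_loc_sampler(batch_size: list) -> torch.Tensor:
    # 3 cluster centers per instance, 10 points around each center
    centers = torch.rand(*batch_size, 3, 1, 2)
    offsets = 0.05 * torch.randn(*batch_size, 3, 10, 2)
    return (centers + offsets).flatten(-3, -2)  # [*batch_size, 30, 2]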

Modification for RL4COEnvBase()

We moved the checks for batch_size and device from every environment to the base class for clarity, as shown in

https://github.com/ai4co/rl4co/blob/b70566bc2354ade45d249a8eb86c40f0e2b47230/rl4co/envs/common/base.py#L130-L138

We added a new _get_reward() function alongside the original get_reward() function and moved check_solution_validity() from every environment into the base class for clarity, as shown in

https://github.com/ai4co/rl4co/blob/b70566bc2354ade45d249a8eb86c40f0e2b47230/rl4co/envs/common/base.py#L175-L187
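
Conceptually, the base class now wraps the two as in the following simplified sketch (attribute names are assumptions; see the linked lines for the actual code):

def get_reward(self, td, actions):
    # Optionally verify the solution before computing the reward
    if self.check_solution:
        self.check_solution_validity(td, actions)
    # Environment-specific reward computation lives in _get_reward()
    return self._get_reward(td, actions)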

Standardization

We standardize the contents of env.py with the following functions:

class EnvName(RL4COEnvBase):
    name = "env_name"
    def __init__(self, generator: EnvGenerator, generator_params: dict): pass

    def _step(self, td: TensorDict) -> TensorDict: pass

    @staticmethod
    def get_action_mask(td: TensorDict) -> torch.Tensor: pass

    def _reset(self, td: Optional[TensorDict] = None, batch_size: Optional[list] = None) -> TensorDict: pass

    def _get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor: pass

    @staticmethod
    def check_solution_validity(td: TensorDict, actions: torch.Tensor) -> None: pass

    @staticmethod
    def render(td: TensorDict, actions: torch.Tensor = None, ax = None): pass

    def _make_spec(self, generator: EnvGenerator): pass

The order is considered natural and easy to follow, and we expect all environments to follow the same order for easier reference and maintenance. In more detail, we introduce the following standardizations:

  1. We renamed the variable available to visited for a more intuitive understanding. In the step() and get_action_mask() calculations, visited records which nodes have been visited, and the action_mask is derived from it together with the environment constraints (e.g., capacity, time window, etc.). Separating these two variables makes the calculation logic clearer.
  2. For some environments, we changed the _step() function to a non-static method, following the TorchRL style.
  3. We standardized the get_action_mask() calculation logic, which generally contains three parts: (a) initialize the action_mask based on visited; (b) update the cities' action_mask based on the state; (c) finally, update the depot action_mask. In our experience, this ordering causes fewer conflicts and less mess (see the sketch after this list).
  4. All 1-D features, e.g., i, capacity, used_capacity, etc., are initialized with the size [*batch_size, 1] instead of [*batch_size, ]. The reason is that many masking operations perform logical calculations between these 1-D features and 2-D features, e.g., capacity with demand. This also stays consistent with the TorchRL implementation.
  5. We rewrote the environment docstrings with descriptions of observations, constraints, finish conditions, rewards, and args so that users can better understand each environment. We also moved data-related parameters (e.g., num_loc, min_loc, max_loc) to the generator for clarity.
  6. We added a cost variable to the get_reward() function for an intuitive understanding; the return (reward) is -cost.
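
For illustration, the three-part logic from item 3 roughly looks like the following sketch for a CVRP-like environment (field names and shapes are assumptions, not the exact rl4co code):

import torch
from tensordict import TensorDict

def get_action_mask(td: TensorDict) -> torch.Tensor:
    # (a) initialize from visited: unvisited nodes are candidate actions
    action_mask = ~td["visited"].bool()  # [batch, num_loc + 1], index 0 is the depot

    # (b) update the cities based on the state, e.g. mask cities whose
    #     demand would exceed the remaining capacity
    exceeds_cap = td["demand"] + td["used_capacity"] > td["capacity"]
    action_mask[..., 1:] &= ~exceeds_cap[..., 1:]

    # (c) finally update the depot, e.g. forbid staying at the depot twice in a row
    action_mask[..., 0] = td["current_node"].squeeze(-1) != 0
    return action_mask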

Other Fixes

  1. In CVRP, we renamed vehicle_capacity to capacity and capacity to unnorm_capacity for clarity.
  2. [⚠️ Sensitive Change] Now, the demand variable also contains the depot. For example, in the previous CVRPEnv(), given num_loc=50, td["locs"] has size [batch_size, 51, 2] (with the depot), while td["demand"] has size [batch_size, 50]. This mismatch causes index shifting in the get_action_mask() function and requires a few padding operations; including the depot in demand avoids this.
  3. Fix the SDVRP environment action mask calculation bug.
  4. We added a numerical error bound (0 → 1e-5); for example, in SDVRP, done = ~(demand > 0).any(-1) becomes done = ~(demand > 1e-5).any(-1) for better robustness against edge cases.
  5. In the CVRP, OP, and PCTSP environments, some variables are looked up from tables keyed by num_loc (e.g., the CVRP CAPACITIES table). If the given num_loc is not in the table, we now use the closest available num_loc as a replacement and raise a warning, which makes running with arbitrary sizes more robust (see the sketch after this list).
  6. Fix the return type of get_reward().
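
The table lookup in item 5 roughly works as in the following sketch (hypothetical helper; the table values shown are the commonly used CVRP capacities, but check the actual code):

import warnings

CAPACITIES = {10: 20.0, 20: 30.0, 50: 40.0, 100: 50.0}

def get_capacity(num_loc: int) -> float:
    if num_loc in CAPACITIES:
        return CAPACITIES[num_loc]
    # Fall back to the closest available size and warn the user
    closest = min(CAPACITIES, key=lambda k: abs(k - num_loc))
    warnings.warn(f"num_loc={num_loc} not in the capacity table; using the value for {closest} instead.")
    return CAPACITIES[closest]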

Notes

  1. In the current version, we don't support custom distributions for integer values, e.g., num_depot and num_agents. These values are initialized with torch.randint().
  2. In the reward calculation, for environments constrained to start and end at the depot, the actions should be padded with 0 at the start and end (see the sketch after this list).
  3. In the current version, only routing environments have been refactored. We will also refactor the EDA and Scheduling environments soon.
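
For note 2, the padding can be done as in this sketch (assuming the depot has index 0):

import torch.nn.functional as F

# actions: [batch_size, seq_len]; prepend and append the depot index 0
padded_actions = F.pad(actions, (1, 1), mode="constant", value=0)  # [batch_size, seq_len + 2]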

Here is a summary of the refactoring status for each environment:

| Environment | Decompose | Training Checking | Documentation | Solution Validity | Clean up Logic |
|---|---|---|---|---|---|
| TSP | | | | | |
| CVRP | | | | | |
| CVRPTW | | | | | |
| PCTSP | | | | | |
| OP | | | | | |
| SDVRP | | | | | |
| SVRP | | | | | |
| ATSP | | | | | |
| MTSP | | | | | |
| SPCTSP | | | | | |
| PDP | | | | | |
| MPDP | | | | | |
| MDCPDP | | | | | |

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

Checklist

Thanks, and we need your help

Thanks to @ngastzepeda for the base code for this refactoring!

If you have time, you are welcome to share your ideas/feedback on this PR. CC: @Furffico @henry-yeh @bokveizen @LTluttmann

There is quite a bit of remaining work for this PR, and I will actively update it here.

fedebotu commented 7 months ago

Let's also remember to fix the shifts in the torch.roll distance calculation that @ngastzepeda noticed, e.g. here. These do not affect the calculations in Euclidean problems, but it's best to have them conceptually correct.
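
For reference, a typical tour-length computation looks like this sketch (illustrative; variable names are assumptions):

import torch

def tour_length(locs: torch.Tensor) -> torch.Tensor:
    # locs: [batch, num_nodes, 2], already gathered in visiting order.
    # Rolling by -1 pairs each node with its successor (the last wraps to the
    # first); rolling by +1 pairs it with its predecessor. For symmetric
    # (Euclidean) distances both give the same total, which is why the shift
    # direction does not change the result, but only one matches the intended
    # "distance to the next node" reading.
    next_locs = torch.roll(locs, shifts=-1, dims=-2)
    return (next_locs - locs).norm(p=2, dim=-1).sum(-1)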

fedebotu commented 7 months ago

Note that we moved most of the above to #169 (without modifications to the environment logic or variables)! We will address the comments and merge soon~

fedebotu commented 5 months ago

There have been too many changes to track recently, and it seems that several features have already been added.

I will close this for now and come back with a fresh PR if needed!