Closed: cbhua closed this 5 months ago
Let's also remember to fix the `shifts` in the `torch.roll` distance calculation, as @ngastzepeda noticed, e.g. here. These do not affect calculations in Euclidean problems, but it's best to have it conceptually correct.
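For reference, a common way to compute a tour length with `torch.roll` looks like the sketch below (illustrative only, not the RL4CO code); the comment notes why the sign of `shifts` goes unnoticed in Euclidean problems:

```python
import torch

# locs: [num_nodes, 2] coordinates in visit order (a unit square here).
locs = torch.tensor([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])

# shifts=-1 pairs each node with its successor (the last wraps to the first).
# For symmetric (Euclidean) distances, shifts=+1 yields the same total length,
# which is why a wrong shift sign does not change the result there.
tour_len = (locs - locs.roll(shifts=-1, dims=0)).norm(p=2, dim=-1).sum()
print(tour_len.item())  # 4.0 for the unit square
```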
Notice that we moved most of the above into #169 (without modifications to the environment logic or variables)! We will address the comments and merge soon~
There have been too many changes to track recently, and it seems that several features have already been added.
I will close this for now and come back with a fresh PR if needed!
## Description
Together with the major modeling refactoring (#165), this PR carries out a major, long-overdue refactoring of the RL4CO environments codebase.
## Motivation and Context

This refactoring is driven by the following motivations:
## Changelog

### Environment Structure Refactoring
The refactored structure for environments is as follows. We have restructured the organization of the environment files for improved modularity and clarity: each environment has its own directory, comprising three components.

- `env.py`: The core framework of the environment, managing functions such as `_reset()`, `_step()`, and others. For a comprehensive understanding, please refer to the documentation.
- `generator.py`: Replaces the previous `generate_data()` function; this module randomly initializes instances within the environment. The updated version now supports custom data distributions. See the following sections for more details.
- `render.py`: For visualization of the solution. Its separation from the main environment file enhances overall code readability.

### Data Generator Support
Each environment generator is based on the base `Generator()` class, which has the following functions:

- `__init__()` records all the parameters used to initialize environment instances, for example `num_loc`, `min_loc`, `max_loc`, etc. Thus, the environment's own `__init__()` (e.g. `CVRPEnv.__init__(...)`) only takes `generator` and `generator_params` as input. The various samplers are also initialized here: we provide the `get_sampler()` function, which returns a `torch.distributions` class based on the input variables. By default, we support the `Uniform`, `Normal`, `Exponential`, and `Poisson` distributions for locations, and `center` and `corner` for depots. You can also pass your own distribution sampler; see the following sections for more details.
- `__call__()` is a middle wrapper; at the moment, it is used to regularize the `batch_size` format supported by TorchRL (i.e., a `list` format). Note that in this refactored version, we fix `batch_size` to a single dimension for easier implementation and clearer understanding, since even multi-dimensional batch sizes can easily be flattened into a single dimension.
- `_generate()` is the part you implement for your own environment data generator.
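The pattern above can be sketched schematically as follows. This is a simplified illustration, not the actual RL4CO implementation: the class and method names mirror the PR description, but the toy generator here returns a bare coordinate tensor rather than full instance data.

```python
from typing import List, Union
import torch

class Generator:
    """Schematic base generator following the PR's description."""

    def __init__(self, **kwargs):
        # Record all instance-initialization parameters (num_loc, min_loc, ...).
        self.kwargs = kwargs

    def __call__(self, batch_size: Union[int, List[int]]) -> torch.Tensor:
        # Middle wrapper: regularize batch_size into the single-dimension
        # list format expected by TorchRL.
        if isinstance(batch_size, int):
            batch_size = [batch_size]
        return self._generate(batch_size)

    def _generate(self, batch_size: List[int]) -> torch.Tensor:
        # The part you implement for your own environment data generator.
        raise NotImplementedError

class UniformLocGenerator(Generator):
    """Toy generator: node coordinates uniform in [min_loc, max_loc]^2."""

    def __init__(self, num_loc: int = 20, min_loc: float = 0.0, max_loc: float = 1.0):
        super().__init__(num_loc=num_loc, min_loc=min_loc, max_loc=max_loc)
        self.num_loc, self.min_loc, self.max_loc = num_loc, min_loc, max_loc

    def _generate(self, batch_size: List[int]) -> torch.Tensor:
        return self.min_loc + (self.max_loc - self.min_loc) * torch.rand(
            *batch_size, self.num_loc, 2
        )

locs = UniformLocGenerator(num_loc=10)(64)
print(locs.shape)  # torch.Size([64, 10, 2])
```

With this design, an environment is configured through its generator, e.g. `CVRPEnv(generator_params={"num_loc": 50})`, as described above.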
### New `get_sampler()` function

This implementation mainly refers to @ngastzepeda's code. In the current version, we support the following distributions:
- `center`: For depots. All depots will be initialized in the center of the space.
- `corner`: For depots. All depots will be initialized in the bottom-left corner of the space.
- `Uniform`: Takes `min_val` and `max_val` as input.
- `Exponential` and `Poisson`: Take `mean_val` and `std_val` as input.

You can also use your own `Callable` function as the sampler. This function takes `batch_size: List[int]` as input and returns the sampled `torch.Tensor`.
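A minimal sketch of such a user-defined sampler, following the signature described above (the sampler name and the Gaussian-around-center behavior are hypothetical examples, not part of the library):

```python
from typing import List
import torch

def gaussian_center_depot_sampler(batch_size: List[int]) -> torch.Tensor:
    # Hypothetical custom sampler: depots drawn from a Gaussian around the
    # center of the unit square, clamped back into [0, 1]^2.
    depots = 0.5 + 0.05 * torch.randn(*batch_size, 2)
    return depots.clamp(0.0, 1.0)

depots = gaussian_center_depot_sampler([8])
print(depots.shape)  # torch.Size([8, 2])
```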
### Modifications for `RL4COEnvBase()`
- We moved the checks for `batch_size` and `device` from every environment to the base class for clarity, as shown in https://github.com/ai4co/rl4co/blob/b70566bc2354ade45d249a8eb86c40f0e2b47230/rl4co/envs/common/base.py#L130-L138
- We added a new `_get_reward()` function alongside the original `get_reward()` function and moved `check_solution_validity()` from every environment to the base class for clarity, as shown in https://github.com/ai4co/rl4co/blob/b70566bc2354ade45d249a8eb86c40f0e2b47230/rl4co/envs/common/base.py#L175-L187
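The split between the shared wrapper and the environment-specific hook can be sketched as below. This is a simplified illustration of the pattern, not the actual base class (see the linked `base.py` for the real implementation); `ToyEnv` and its trivial cost are invented for the example.

```python
class RL4COEnvBase:
    """Schematic sketch of the described get_reward() split."""

    def __init__(self, check_solution: bool = True):
        self.check_solution = check_solution

    def get_reward(self, td, actions):
        # Shared logic lives in the base class: optionally validate the
        # solution once, then delegate to the env-specific _get_reward().
        if self.check_solution:
            self.check_solution_validity(td, actions)
        return self._get_reward(td, actions)

    def _get_reward(self, td, actions):
        raise NotImplementedError

    @staticmethod
    def check_solution_validity(td, actions):
        raise NotImplementedError

class ToyEnv(RL4COEnvBase):
    def _get_reward(self, td, actions):
        return -sum(actions)  # reward = -cost

    @staticmethod
    def check_solution_validity(td, actions):
        assert len(actions) > 0, "empty solution"

reward = ToyEnv().get_reward(td=None, actions=[1, 2, 3])
print(reward)  # -6
```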
### Standardization

We standardize the contents of `env.py` with a fixed set of functions in a fixed order. The order is considered natural and easy to follow, and we expect all environments to follow the same order for easier reference and maintenance. In more detail, we have the following standardization:
- Renamed `available` to `visited` for a more intuitive understanding. In the `step()` and `get_action_mask()` calculation, `visited` records which nodes have been visited, and the `action_mask` is derived from it together with the environment constraints (e.g., capacity, time window, etc.). Separating these two variables makes the calculation logic clearer.
- Changed the `_step()` function to a non-static method, following the TorchRL style.
- Standardized the `get_action_mask()` calculation logic, which generally contains three parts: (a) initialize the `action_mask` based on `visited`; (b) update the cities' `action_mask` based on the state; (c) finally, update the depot's `action_mask`. Based on experience, this logic causes fewer conflicts and less mess.
- Feature variables such as `i`, `capacity`, `used_capacity`, etc., are initialized with the size `[*batch_size, 1]` instead of `[*batch_size]`. The reason is that many masking operations require logical calculations between these 1-D features and 2-D features, e.g., capacity with demand. This also stays consistent with the TorchRL implementation.
- Moved environment parameters (e.g., `num_loc`, `min_loc`, `max_loc`) to the generator for clarity.
- Added a `cost` variable to the `get_reward` function for an intuitive understanding. In this case, the return (reward) is `-cost`.
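The `[*batch_size, 1]` convention can be seen in a small broadcasting check (illustrative shapes only):

```python
import torch

# The trailing singleton dimension lets a 1-D state feature broadcast
# directly against 2-D per-node features in masking logic,
# with no unsqueeze needed at each use site.
batch_size = [4]
used_capacity = torch.zeros(*batch_size, 1)   # [4, 1]
demand = torch.rand(*batch_size, 10)          # [4, 10]
exceeds = used_capacity + demand > 1.0        # broadcasts to [4, 10]
print(exceeds.shape)  # torch.Size([4, 10])
```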
### Other Fixes
- Renamed `vehicle_capacity` → `capacity` and `capacity` → `unnorm_capacity` to clarify their meaning.
- The `demand` variable will now also contain the depot. For example, in the previous `CVRPEnv()`, given `num_loc=50`, `td["locs"]` has the size `[batch_size, 51, 2]` (with the depot), while `td["demand"]` has the size `[batch_size, 50]`. This causes index shifting in the `get_action_mask()` function, which requires a few padding operations.
- Replaced hard-coded zero comparisons with a small epsilon (`0` → `1e-5`); for example, in SDVRP, `done = ~(demand > 0).any(-1)` → `done = ~(demand > 1e-5).any(-1)` for better robustness against edge cases.
- For capacity tables keyed by `num_loc` (e.g., `CVRPCAPACITIES`): if the given `num_loc` is not in the table, we will find the closest `num_loc` as a replacement and raise a warning, to increase running robustness.
- `get_reward()`.
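The epsilon-based termination check above can be demonstrated on a toy tensor (values chosen for illustration):

```python
import torch

# Floating-point demand updates can leave tiny residues instead of exact
# zeros, so compare against a small epsilon rather than 0.
demand = torch.tensor([[0.0, 3e-6],   # residual demand below epsilon
                       [0.2, 0.0]])   # genuinely unsatisfied demand

done_exact = ~(demand > 0).any(-1)     # misses the 3e-6 residue
done_eps = ~(demand > 1e-5).any(-1)    # treats the residue as satisfied
print(done_exact.tolist(), done_eps.tolist())  # [False, False] [True, False]
```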
## Notes

- `num_depot` and `num_agents` are initialized by `torch.randint()`.
- `0` is appended to the start and end.

Here is the summary of the refactoring status for each environment:
- Split into `env.py`, `generator.py`, `render.py`; fixed the `__init__()` and `_reset()` functions; `check_solution_validity()` function; the `_step()` and `get_action_mask()` functions are cleaned up with the standard pipeline.

## Types of changes
What types of changes does your code introduce? Remove all that do not apply:
## Checklist
## Thanks, and we need your help

Thanks to @ngastzepeda for the base code for this refactoring!

If you have time, you are welcome to provide your ideas/feedback on this PR. CC: @Furffico @henry-yeh @bokveizen @LTluttmann

There is quite a bit of remaining work for this PR, and I will actively update it here.