RoboEden / Luxai-s2-Baseline


Memory leak in train.py 🐛 #5

Closed: hugsclane closed this issue 10 months ago

hugsclane commented 10 months ago

We logged GPU and RAM usage and found that RAM usage grew quickly until it spilled into swap. We tested with 30 GB of RAM and with 15 GB (RAM capped at half). We think it's a memory leak.

CLI command: python train.py --total-timesteps 1050 --num-envs 1 --save-interval=999 --train-num-collect=1024

Running in WSL, configured according to the documentation in this repo. Here are the logs: log.json

This is how we logged it:

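import psutil  # third-party package; provides per-process memory statistics

# Assumed: the original snippet did not show how `process` was created.
process = psutil.Process()  # no argument = the current process

memory_tracker = {}  # label -> cumulative RSS change (bytes) attributed to that label
last_check = process.memory_info().rss

def log_memory(name):
    """Attribute the RSS change since the previous call to `name`."""
    global last_check
    rss = process.memory_info().rss  # read once so the delta and the new baseline agree
    memory_tracker[name] = memory_tracker.get(name, 0) + (rss - last_check)
    last_check = rss

def print_memory():
    """Print each label's cumulative RSS change and the current RSS, in GB."""
    memory_tracker_formatted = {k: "{:.2f}".format(v / 10**9) for k, v in memory_tracker.items()}
    memory_tracker_formatted["current"] = "{:.2f}".format(process.memory_info().rss / 10**9)
    print("current:", memory_tracker_formatted)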
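RSS deltas show *that* memory grows but not *which line* keeps it alive. As a complementary technique (not what was used in this thread), Python's standard-library tracemalloc module can diff two heap snapshots and rank source lines by growth. A minimal sketch, assuming the growth comes from Python-level allocations; memory allocated by torch's own C++ allocator is likely invisible to tracemalloc. `buffer`, `leaky_step` and `top_growth` are hypothetical stand-ins for the training loop.

```python
import tracemalloc

buffer = []  # hypothetical stand-in for a rollout buffer kept in a Python list

def leaky_step():
    # Each call keeps a ~1 MB object alive by appending it to a list.
    buffer.append(bytearray(1_000_000))

def top_growth(n_steps=20, limit=3):
    """Run n_steps iterations and return the source lines whose allocations grew most."""
    tracemalloc.start()
    before = tracemalloc.take_snapshot()
    for _ in range(n_steps):
        leaky_step()
    after = tracemalloc.take_snapshot()
    tracemalloc.stop()
    # compare_to groups allocations by source line, sorted by size difference
    stats = after.compare_to(before, "lineno")
    return stats[:limit]

for stat in top_growth():
    print(stat)  # the bytearray line inside leaky_step should be at the top
```

In a real run, `leaky_step()` would be replaced by one iteration of the collection loop; lines whose size keeps increasing across repeated diffs are the ones holding references.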

This is the code that allocates the memory. We think it has something to do with torch tensors not being detached.

# beginning of code block
for player_id, player in enumerate(['player_0', 'player_1']):
    obs[player] += envs.split(next_obs[player])
    dones[train_step] = next_done
    log_memory("-1")

    # ALGO LOGIC: action logic
    # torch.no_grad() disables gradient calculation: https://pytorch.org/docs/stable/generated/torch.no_grad.html
    with torch.no_grad():
        log_memory("0")
        # Under no_grad, any tensor created as the result of a computation
        # has requires_grad=False, so torch does not record a graph for it.
        # This avoids spending memory on values that don't need gradients.
        valid_action = envs.get_valid_actions(player_id)
        # np2torch = lambda x, dtype: torch.tensor(x).type(dtype).to(device).detach()
        log_memory("1")
        global_feature = np2torch(next_obs[player]['global_feature'], torch.float32)
        map_feature = np2torch(next_obs[player]['map_feature'], torch.float32)

        log_memory("1.25")
        # Note: np2torch is a helper lambda (its definition is shown commented out above)
        # that converts a NumPy array into a tensor of the given dtype on `device`.
        action_feature = tree.map_structure(lambda x: np2torch(x, torch.int16), next_obs[player]['action_feature'])
        log_memory("1.5")
        va = tree.map_structure(lambda x: np2torch(x, torch.bool), valid_action)
        log_memory("1.75")

        # Calling agent() like a function actually calls agent.forward() under the hood;
        # it's defined in net.py. The observation features are passed in, and an action
        # that is neither fully abstract nor yet in Lux format is returned.
        # Let's call it the intermediate action space.
        logprob, value, raw_action, _ = agent(
            global_feature,
            map_feature,
            action_feature,
            va
        )
        values[player][train_step] = value
        log_memory("2")
        # Action arrays are partitioned like so:
        # {"transfer_power": [10, 11, 12]}, where 10, 11 and 12 are the
        # actions to perform in environments 0, 1 and 2 (vectorized environments).
        # This function splits the actions into independent trees per environment:
        # [{'transfer_power': 10}, {'transfer_power': 11}, {'transfer_power': 12}]
        valid_actions[player] += envs.split(valid_action)
        log_memory("3")
    # As above: split raw_action into independent trees per vectorized environment
    actions[player] += envs.split(raw_action)
    action[player_id] = raw_action
    logprobs[player][train_step] = logprob
    log_memory("4")
hugsclane commented 10 months ago

This is one of the culprits:

map_feature = np2torch(next_obs[player]['map_feature'], torch.float32)

The ~10 MB per loop iteration was coming from copy=True; changing it to False stops most of the rapid increase in RAM.

How important is this argument?

class LuxSyncVectorEnv(gym.vector.AsyncVectorEnv):
    def __init__(self, env_fns, observation_space=None, action_space=None, copy=False, shared_memory=False, worker=lux_worker):
Schopenhauer-loves-Hegel commented 10 months ago

Thank you for your report. I've attempted to reproduce the problem: I ran python train.py for an extended period but did not observe continuous memory growth.

As part of my investigation, I executed the following code snippet to check memory usage:

with torch.no_grad():
    log_memory('1')
    valid_action = envs.get_valid_actions(player_id)
    np2torch = lambda x, dtype: torch.tensor(x).type(dtype).to(device)
    log_memory('2')
    global_feature = np2torch(next_obs[player]['global_feature'], torch.float32)
    log_memory('test')
    a = np.random.randint(0, 4, size=(600, 32, 32))
    log_memory('3')
    map_feature = np2torch(next_obs[player]['map_feature'], torch.float32)
    log_memory('4')
    action_feature = tree.map_structure(lambda x: np2torch(x, torch.int16), next_obs[player]['action_feature'])
    log_memory('5')
    va = tree.map_structure(lambda x: np2torch(x, torch.bool), valid_action)
    log_memory('6')

The result is:

current: {'1': '1432.89', '2': '0.00', 'test': '0.00', '3': '11.36', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '1779.76'}
current: {'1': '3189.15', '2': '0.71', 'test': '0.00', '3': '29.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3555.28'}
current: {'1': '3189.73', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3574.87'}
current: {'1': '3190.61', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.75'}
current: {'1': '3190.61', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.75'}
current: {'1': '3190.61', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.75'}
current: {'1': '3190.61', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.75'}
current: {'1': '3190.61', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.75'}
current: {'1': '3190.77', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.91'}
current: {'1': '3190.77', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.91'}
current: {'1': '3190.77', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.91'}
current: {'1': '3190.78', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.92'}
current: {'1': '3190.78', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.92'}
current: {'1': '3190.78', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.92'}
current: {'1': '3190.78', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.92'}
current: {'1': '3190.78', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.92'}
current: {'1': '3190.81', '2': '0.71', 'test': '0.00', '3': '48.91', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3575.95'}
current: {'1': '3190.81', '2': '0.71', 'test': '0.00', '3': '67.64', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3594.67'}
current: {'1': '3190.81', '2': '0.71', 'test': '0.00', '3': '67.64', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3594.67'}
current: {'1': '3190.84', '2': '0.71', 'test': '0.00', '3': '67.64', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3594.70'}
current: {'1': '3190.84', '2': '0.71', 'test': '0.00', '3': '67.64', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3594.70'}
current: {'1': '3190.84', '2': '0.73', 'test': '0.00', '3': '67.64', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3594.71'}
current: {'1': '3190.84', '2': '0.73', 'test': '0.00', '3': '75.60', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3602.68'}
current: {'1': '3190.84', '2': '0.73', 'test': '0.00', '3': '75.60', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3602.68'}
current: {'1': '3190.84', '2': '0.73', 'test': '0.00', '3': '75.60', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3602.68'}
current: {'1': '3191.02', '2': '0.73', 'test': '0.00', '3': '79.00', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3606.26'}
current: {'1': '3191.02', '2': '0.73', 'test': '0.00', '3': '79.00', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3606.26'}
current: {'1': '3191.06', '2': '0.73', 'test': '0.00', '3': '79.00', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3606.30'}
current: {'1': '3191.07', '2': '0.73', 'test': '0.00', '3': '79.00', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3606.30'}
current: {'1': '3191.07', '2': '0.73', 'test': '0.00', '3': '79.00', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3606.31'}
current: {'1': '3191.07', '2': '0.73', 'test': '0.00', '3': '80.99', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3608.29'}
current: {'1': '3191.07', '2': '0.73', 'test': '0.00', '3': '80.99', '4': '1.02', '5': '2.59', '6': '0.00', 'current': '3608.29'}

The profiling results indicate that the np2torch lambda is not the cause of the memory increase. Instead, the growth appears to be associated with transition data stored in a list-based buffer.
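If the growth does come from transitions accumulated in Python lists, one common pattern in PPO-style trainers is to preallocate fixed-size arrays for a single rollout and overwrite them in place on every collection phase, so memory stays flat across updates. A minimal sketch, assuming a hypothetical `RolloutBuffer` class with one fixed observation shape; the Lux observations in this repo are nested dicts, so a real version would need one array per leaf (for example built with tree.map_structure).

```python
import numpy as np

class RolloutBuffer:
    """Hypothetical fixed-size rollout storage, reused across collection phases."""

    def __init__(self, num_steps, num_envs, obs_shape):
        self.obs = np.zeros((num_steps, num_envs, *obs_shape), dtype=np.float32)
        self.dones = np.zeros((num_steps, num_envs), dtype=np.bool_)
        self.num_steps = num_steps
        self.step = 0

    def add(self, obs, done):
        if self.step >= self.num_steps:
            raise IndexError("buffer full; call reset() after the update")
        self.obs[self.step] = obs  # copies into the preallocated storage
        self.dones[self.step] = done
        self.step += 1

    def reset(self):
        self.step = 0  # arrays are reused, not reallocated

    def nbytes(self):
        return self.obs.nbytes + self.dones.nbytes


buf = RolloutBuffer(num_steps=4, num_envs=2, obs_shape=(3,))
size_before = buf.nbytes()
for update in range(3):  # several collect/update cycles
    for t in range(4):
        buf.add(np.full((2, 3), t, dtype=np.float32), np.zeros(2, dtype=bool))
    buf.reset()  # after the (omitted) PPO update
print(size_before, buf.nbytes())  # 104 104
```

The trade-off is that the rollout length and observation shapes must be known up front; in exchange, nothing from earlier rollouts stays referenced.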

One potential mitigation is to set the copy argument to False, which avoids copying the observation data and reduces memory usage.
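One caveat worth checking before relying on copy=False: in Gym's vector environments, copy=True makes reset() and step() return a copy of the batched observations, while copy=False can hand back the same internal buffer, which is overwritten on the next step. Whether LuxSyncVectorEnv and lux_worker behave this way isn't shown here, so the toy below is a hypothetical illustration (`ToyVectorEnv` and `collect` are invented names): if the returned array is stored in a list without copying, every stored step ends up showing the last observation.

```python
import numpy as np

class ToyVectorEnv:
    """Hypothetical vector env that writes observations into one preallocated buffer."""

    def __init__(self, num_envs=2, obs_dim=3, copy=True):
        self.copy = copy
        self._obs = np.zeros((num_envs, obs_dim), dtype=np.float32)
        self._t = 0

    def step(self):
        self._t += 1
        self._obs[:] = self._t  # overwrite the shared buffer in place
        return self._obs.copy() if self.copy else self._obs

def collect(copy, n_steps=3):
    env = ToyVectorEnv(copy=copy)
    rollout = [env.step() for _ in range(n_steps)]  # list-based buffer
    return [float(obs[0, 0]) for obs in rollout]

print(collect(copy=True))   # [1.0, 2.0, 3.0]: each stored step keeps its own values
print(collect(copy=False))  # [3.0, 3.0, 3.0]: every entry aliases the last-written buffer
```

torch.tensor() always copies its input, so data that only flows through np2torch does not depend on the env-level copy; NumPy arrays appended directly to a Python list do.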

Please feel free to provide more details or context if needed, and we will continue to investigate and address this matter. Your feedback is valuable in improving the baseline.