KTH-FlowAI / DeepReinforcementLearning_RayleighBenard2D_Control

Control of 2D Rayleigh Benard Convection using Deep Reinforcement Learning with Tensorforce and Shenfun.
MIT License
16 stars 8 forks source link

Rewards on train_sarl.py with Figure 7 parameters do not match. #11

Open dmitryshribak opened 2 months ago

dmitryshribak commented 2 months ago

#!/bin/env python
#
# DEEP REINFORCEMENT LEARNING FOR RAYLEIGH-BENARD CONVECTION
#
# Single-Agent Reinforcement Learning launcher
#
# train_sarl.py: main launcher for the SARL framework. 
#
# Pol Suarez, Francisco Alcantara, Colin Vignon & Joel Vasanth
#
# FLOW, KTH Stockholm | 09/04/2023

from __future__ import print_function, division

import os
os.environ["HDF5_DO_MPI_FILE_SYNC"] = 'False'
os.environ["HDF5_DO_MPI_FILE_SYNC"] = '0'

# https://github.com/h5py/h5py/issues/2330

import sys
import time

from tensorforce.agents import Agent
from tensorforce.execution import Runner

from env_utils import generate_node_list, read_node_list

import shutil

#### Set up which case to run
training_case = "RB_2D_SARL"
simu_name = training_case

general_path = os.getcwd()
# Per-case working directory; appended to sys.path so that the case-local
# `parameters.py` (copied below) is importable by the environment.
case_path = os.path.join(general_path, 'data', simu_name)
sys.path.append(case_path)

# Start every run from a clean case directory.  shutil.rmtree with
# ignore_errors handles a missing directory silently (the old
# `os.system('rm -r …')` printed an error), and os.makedirs also creates
# ./data when it does not exist yet (os.mkdir would have failed there).
shutil.rmtree(case_path, ignore_errors=True)
os.makedirs(case_path)

# Snapshot the per-case parameter file into the case directory as
# `parameters.py`, replacing the shell-string `os.system('cp …')` call.
shutil.copyfile(
    os.path.join(general_path, 'parameters', 'parameters_{}.py'.format(training_case)),
    os.path.join(case_path, 'parameters.py'),
)

from sarl_env import Environment2D
# from parameters import nb_actuations, num_episodes, num_servers, simu_name

'''
Input parameters here

'''

# Case name - must match this file's name without the `parameters_` prefix.
# NOTE(review): `case`/`simu_name` duplicate `training_case`/`simu_name`
# assigned earlier in this script with the same value; the re-assignment is
# harmless but redundant.
case = 'RB_2D_SARL'
simu_name = case
dimension = '2D'
reward_function = 'Nusselt'  # reward is derived from the Nusselt number

# Number of calculation processors
nb_proc = 1

# Number of environments run in parallel (one Tensorforce remote env each)
num_servers = 1

# Number of segments (actuators) on the lower boundary
n_seg = 10

# Number of invariant parallel environments ('multi-agents' - set to one for single agent)
nb_inv_envs = 1

# Duration of baseline simulation (in nondimensional simulation time)
simulation_duration = 400
simulation_time_start = 0.0

# Duration of each actuation (in nondimensional simulation time)
delta_t_smooth = 1.5
delta_t_converge = 0.0
smooth_func = 'linear'  # ramp shape between successive actuation values

# Post-processing options
post_process_steps = 200

# Total number of training episodes
num_episodes = 350  # 1

# Number of actuations per episode (also the episode length in agent steps)
nb_actuations = 200 # need this as 200 for the curve
nb_actuations_deterministic = nb_actuations * 4

# Probes (observation points)
# NOTE(review): presumably (8, 32) is the (vertical, horizontal) grid of
# probes -- confirm against the environment's probe-placement code.
probes_location = 'cartesian_grid'
number_of_probes = (8, 32)

# Simulation parameters, bundled for consumption by the environment
simulation_params = {
    'simulation_duration': simulation_duration,
    'simulation_timeframe': [simulation_time_start, simulation_time_start + simulation_duration],
    'delta_t_smooth': delta_t_smooth,
    'delta_t_converge': delta_t_converge,
    'smooth_func': smooth_func,
    'post_process_steps': post_process_steps
}

# Variational input
# NOTE(review): the meaning of "d" and the negative "time" value is not
# visible from this script -- confirm against the solver configuration.
variational_input = {
    'filename': 'RB_2D',
    'porous': False,
    "d": 0,
    "time": -0.25,
    "initial_time": None,
}

# Observation output: probe count and the quantities sampled ('u_T' --
# presumably velocity and temperature; confirm in the environment code).
output_params = {
    'nb_probes': number_of_probes,
    'probe_type': 'u_T'
}

# Optimization / reward shaping
# NOTE(review): offset_reward = 2.6788 looks like the baseline (uncontrolled)
# Nusselt number subtracted from the reward -- TODO confirm; the reward curve
# in the linked issue is flat, so this constant is worth double-checking.
optimization_params = {
    "min_ampl_temp": -1.,
    "max_ampl_temp": 1.,
    # "norm_Temp":                 0.4,
    "norm_reward": 1.,
    # "norm_press":                    2,
    "offset_reward": 2.6788,
}

# Runtime inspection / plotting controls read by the environment
inspection_params = {
    "plot": False,
    "step": 50,
    "dump": 100,
    "show_all_at_reset": True,
    "single_run": False
}

'''
End of parameters

'''

#### Run: baseline environment, PPO agent, parallel training environments
initial_time = time.time()

# Generate the list of compute nodes, one entry per parallel environment
generate_node_list(num_servers=num_servers)

# Read the list of nodes back.
# NOTE(review): nodelist[0] is used for the baseline and nodelist[i + 1] for
# the training environments below, so the list must hold at least
# num_servers + 1 entries -- confirm generate_node_list provides that.
nodelist = read_node_list()

print("\n\nDRL for 2D Rayleigh-Benard convection\n")
print("---------------------------------------\n")
print('Case: ' + simu_name + ' (Single-Agent RL)\n')
# Baseline environment: constructing it runs/loads the uncontrolled reference
# simulation and supplies the state/action specs for Agent.create below.
environment_base = Environment2D(simu_name=simu_name, path=general_path, node=nodelist[0])  # Baseline  #+simu_name

# Two dense hidden layers of 512 units, shared shape for policy and baseline
network = [dict(type='dense', size=512), dict(type='dense', size=512)]

agent = Agent.create(
    # Agent + Environment
    agent='ppo', environment=environment_base, max_episode_timesteps=nb_actuations,
    # Network
    network=network,
    # Optimization
    batch_size=20, learning_rate=1e-3, subsampling_fraction=0.2, multi_step=25,
    # Reward estimation
    likelihood_ratio_clipping=0.2, predict_terminal_values=True,
    baseline=network,
    baseline_optimizer=dict(
        type='multi_step', num_steps=5,
        optimizer=dict(type='adam', learning_rate=1e-3)
    ),
    # Regularization
    entropy_regularization=0.01,
    parallel_interactions=num_servers,
    # Checkpoint every episode, keeping only the most recent checkpoint
    saver=dict(directory=os.path.join(os.getcwd(), 'saver_data'), frequency=1, max_checkpoints=1),
    # parallel_interactions=number_servers,
)

# One training environment per server; ENV_ID/host distinguish the parallel
# instances and each is pinned to its own node (offset by 1 past the baseline)
environments = [Environment2D(simu_name=simu_name, path=general_path, do_baseline=False, ENV_ID=i,
                              host="environment{}".format(i + 1), node=nodelist[i + 1]) for i in range(num_servers)]

# start all environments at the same time (each in its own process)
runner = Runner(agent=agent, environments=environments, remote='multiprocessing')

# now start the episodes and sync_episodes is very useful to update the DANN efficiently
# NOTE(review): despite the comment above, sync_episodes is passed as False
# here -- confirm which setting was used to produce the published results.
runner.run(num_episodes=num_episodes, sync_episodes=False)
runner.close()

# saving all the model data in model-numpy format
agent.save(directory=os.path.join(os.getcwd(), 'model-numpy'), format='numpy', append='episodes')

agent.close()

end_time = time.time()

# Timestamps are raw time.time() floats (seconds since the epoch)
print(
    "DRL simulation :\nStart at : {}.\nEnd at {}\nDone in : {}".format(initial_time, end_time, end_time - initial_time))

EPISODE;REWARD 1;0.00491781547637915 2;0.00491781547637915 3;0.00491781547637915 4;0.00491781547637915 5;0.00491781547637915 6;0.00491781547637915 7;0.00491781547637915 8;0.00491781547637915 9;0.00491781547637915 10;0.00491781547637915 11;0.00491781547637915 12;0.00491781547637915 13;0.00491781547637915 14;0.00491781547637915 15;0.00491781547637915 16;0.00491781547637915 17;0.00491781547637915 18;0.00491781547637915 19;0.00491781547637915 20;0.00491781547637915 21;0.00491781547637915 22;0.00491781547637915 23;0.00491781547637915 24;0.00491781547637915 25;0.00491781547637915 26;0.00491781547637915 27;0.00491781547637915 28;0.00491781547637915 29;0.00491781547637915 30;0.00491781547637915 31;0.00491781547637915 32;0.00491781547637915 33;0.00491781547637915 34;0.00491781547637915 35;0.00491781547637915 36;0.00491781547637915 37;0.00491781547637915 38;0.00491781547637915 39;0.00491781547637915 40;0.00491781547637915 41;0.00491781547637915 42;0.00491781547637915 43;0.00491781547637915 44;0.00491781547637915 45;0.00491781547637915 46;0.00491781547637915 47;0.00491781547637915 48;0.00491781547637915 49;0.00491781547637915 50;0.00491781547637915 51;0.00491781547637915 52;0.00491781547637915 53;0.00491781547637915 54;0.00491781547637915 55;0.00491781547637915 56;0.00491781547637915 57;0.00491781547637915 58;0.00491781547637915 59;0.00491781547637915 60;0.00491781547637915 61;0.00491781547637915 62;0.00491781547637915 63;0.00491781547637915 64;0.00491781547637915 65;0.00491781547637915 66;0.00491781547637915 67;0.00491781547637915 68;0.00491781547637915 69;0.00491781547637915 70;0.00491781547637915 71;0.00491781547637915 72;0.00491781547637915 73;0.00491781547637915 74;0.00491781547637915 75;0.00491781547637915 76;0.00491781547637915 77;0.00491781547637915 78;0.00491781547637915 79;0.00491781547637915 80;0.00491781547637915 81;0.00491781547637915 82;0.00491781547637915 83;0.00491781547637915 84;0.00491781547637915 85;0.00491781547637915 86;0.00491781547637915 
87;0.00491781547637915 88;0.00491781547637915 89;0.00491781547637915 90;0.00491781547637915 91;0.00491781547637915 92;0.00491781547637915 93;0.00491781547637915 94;0.00491781547637915 95;0.00491781547637915 96;0.00491781547637915 97;0.00491781547637915 98;0.00491781547637915 99;0.00491781547637915 100;0.00491781547637915 101;0.00491781547637915 102;0.00491781547637915 103;0.00491781547637915 104;0.00491781547637915 105;0.00491781547637915 106;0.00491781547637915 107;0.00491781547637915 108;0.00491781547637915 109;0.00491781547637915 110;0.00491781547637915 111;0.00491781547637915 112;0.00491781547637915 113;0.00491781547637915 114;0.00491781547637915 115;0.00491781547637915 116;0.00491781547637915 117;0.00491781547637915 118;0.00491781547637915 119;0.00491781547637915 120;0.00491781547637915 121;0.00491781547637915 122;0.00491781547637915 123;0.00491781547637915 124;0.00491781547637915 125;0.00491781547637915 126;0.00491781547637915 127;0.00491781547637915 128;0.00491781547637915 129;0.00491781547637915 130;0.00491781547637915 131;0.00491781547637915 132;0.00491781547637915 133;0.00491781547637915 134;0.00491781547637915 135;0.00491781547637915 136;0.00491781547637915 137;0.00491781547637915 138;0.00491781547637915 139;0.00491781547637915 140;0.00491781547637915 141;0.00491781547637915 142;0.00491781547637915 143;0.00491781547637915 144;0.00491781547637915 145;0.00491781547637915 146;0.00491781547637915 147;0.00491781547637915 148;0.00491781547637915 149;0.00491781547637915 150;0.00491781547637915 151;0.00491781547637915 152;0.00491781547637915 153;0.00491781547637915 154;0.00491781547637915 155;0.00491781547637915 156;0.00491781547637915 157;0.00491781547637915 158;0.00491781547637915 159;0.00491781547637915 160;0.00491781547637915 161;0.00491781547637915 162;0.00491781547637915 163;0.00491781547637915 164;0.00491781547637915 165;0.00491781547637915 166;0.00491781547637915 167;0.00491781547637915 168;0.00491781547637915 169;0.00491781547637915 
170;0.00491781547637915 171;0.00491781547637915 172;0.00491781547637915 173;0.00491781547637915 174;0.00491781547637915 175;0.00491781547637915 176;0.00491781547637915 177;0.00491781547637915 178;0.00491781547637915 179;0.00491781547637915 180;0.00491781547637915 181;0.00491781547637915 182;0.00491781547637915 183;0.00491781547637915 184;0.00491781547637915 185;0.00491781547637915 186;0.00491781547637915 187;0.00491781547637915 188;0.00491781547637915 189;0.00491781547637915 190;0.00491781547637915 191;0.00491781547637915 192;0.00491781547637915 193;0.00491781547637915 194;0.00491781547637915 195;0.00491781547637915 196;0.00491781547637915 197;0.00491781547637915 198;0.00491781547637915 199;0.00491781547637915 200;0.00491781547637915 201;0.00491781547637915 202;0.00491781547637915 203;0.00491781547637915 204;0.00491781547637915 205;0.00491781547637915 206;0.00491781547637915 207;0.00491781547637915 208;0.00491781547637915 209;0.00491781547637915 210;0.00491781547637915 211;0.00491781547637915 212;0.00491781547637915 213;0.00491781547637915 214;0.00491781547637915 215;0.00491781547637915 216;0.00491781547637915 217;0.00491781547637915 218;0.00491781547637915 219;0.00491781547637915 220;0.00491781547637915 221;0.00491781547637915 222;0.00491781547637915 223;0.00491781547637915 224;0.00491781547637915 225;0.00491781547637915 226;0.00491781547637915 227;0.00491781547637915 228;0.00491781547637915 229;0.00491781547637915 230;0.00491781547637915 231;0.00491781547637915 232;0.00491781547637915 233;0.00491781547637915 234;0.00491781547637915 235;0.00491781547637915 236;0.00491781547637915 237;0.00491781547637915 238;0.00491781547637915 239;0.00491781547637915 240;0.00491781547637915 241;0.00491781547637915 242;0.00491781547637915 243;0.00491781547637915 244;0.00491781547637915 245;0.00491781547637915 246;0.00491781547637915 247;0.00491781547637915 248;0.00491781547637915 249;0.00491781547637915 250;0.00491781547637915 251;0.00491781547637915 252;0.00491781547637915 
253;0.00491781547637915 254;0.00491781547637915 255;0.00491781547637915 256;0.00491781547637915 257;0.00491781547637915 258;0.00491781547637915 259;0.00491781547637915 260;0.00491781547637915 261;0.00491781547637915 262;0.00491781547637915 263;0.00491781547637915 264;0.00491781547637915 265;0.00491781547637915 266;0.00491781547637915 267;0.00491781547637915 268;0.00491781547637915 269;0.00491781547637915 270;0.00491781547637915 271;0.00491781547637915 272;0.00491781547637915 273;0.00491781547637915 274;0.00491781547637915 275;0.00491781547637915 276;0.00491781547637915 277;0.00491781547637915 278;0.00491781547637915 279;0.00491781547637915 280;0.00491781547637915 281;0.00491781547637915 282;0.00491781547637915 283;0.00491781547637915 284;0.00491781547637915 285;0.00491781547637915 286;0.00491781547637915 287;0.00491781547637915 288;0.00491781547637915 289;0.00491781547637915 290;0.00491781547637915 291;0.00491781547637915 292;0.00491781547637915 293;0.00491781547637915 294;0.00491781547637915 295;0.00491781547637915 296;0.00491781547637915 297;0.00491781547637915 298;0.00491781547637915 299;0.00491781547637915 300;0.00491781547637915 301;0.00491781547637915 302;0.00491781547637915 303;0.00491781547637915 304;0.00491781547637915 305;0.00491781547637915 306;0.00491781547637915 307;0.00491781547637915 308;0.00491781547637915 309;0.00491781547637915 310;0.00491781547637915 311;0.00491781547637915 312;0.00491781547637915 313;0.00491781547637915 314;0.00491781547637915 315;0.00491781547637915 316;0.00491781547637915 317;0.00491781547637915 318;0.00491781547637915 319;0.00491781547637915 320;0.00491781547637915 321;0.00491781547637915 322;0.00491781547637915 323;0.00491781547637915 324;0.00491781547637915 325;0.00491781547637915 326;0.00491781547637915 327;0.00491781547637915 328;0.00491781547637915 329;0.00491781547637915 330;0.00491781547637915 331;0.00491781547637915 332;0.00491781547637915 333;0.00491781547637915 334;0.00491781547637915 335;0.00491781547637915 
336;0.00491781547637915 337;0.00491781547637915 338;0.00491781547637915 339;0.00491781547637915 340;0.00491781547637915 341;0.00491781547637915 342;0.00491781547637915 343;0.00491781547637915 344;0.00491781547637915 345;0.00491781547637915 346;0.00491781547637915 347;0.00491781547637915 348;0.00491781547637915 349;0.00491781547637915 350;0.00491781547637915

jerabaul29 commented 2 months ago

@dmitryshribak can you explain the issues you encounter in plain text too?

@joelvarunvasanth it looks like @dmitryshribak is having trouble getting the agent to learn. Can you provide a bit of support / some pointers on how to use the code? :) Is it possible that this is an old version of the code, or that there is a typo somewhere? If so, can you push the version you used to generate the figure:

Screenshot from 2024-09-10 15-13-56

? :)

dmitryshribak commented 2 months ago

My issue, in plain text, can be summarized as follows: I ran "train_sarl.py" with the hyperparameters given in the paper (350 episodes, etc.) and am getting no change in the reward curve over episodes.

joelvarunvasanth commented 1 month ago

Hi @dmitryshribak, can you share the piece of code that contains the reward-function computation, and check that you are passing the right parameters to it? There should be some variation in the Nusselt number from one step to the next.