ShangtongZhang / reinforcement-learning-an-introduction

Python Implementation of Reinforcement Learning: An Introduction

chapter06/random_walk.py #123

Closed ChenHuaYou closed 4 years ago

ChenHuaYou commented 4 years ago

def rms_error():
    # Same alpha value can appear in both arrays
    td_alphas = [0.15, 0.1, 0.05]
    mc_alphas = [0.01, 0.02, 0.03, 0.04]
    episodes = 100 + 1
    runs = 100
    for i, alpha in enumerate(td_alphas + mc_alphas):
        total_errors = np.zeros(episodes)
        if i < len(td_alphas):
            method = 'TD'
            linestyle = 'solid'
        else:
            method = 'MC'
            linestyle = 'dashdot'
        for r in tqdm(range(runs)):
            errors = []
            current_values = np.copy(VALUES)
            for i in range(0, episodes):
                errors.append(np.sqrt(np.sum(np.power(TRUE_VALUE - current_values, 2)) / 5.0))
                if method == 'TD':
                    temporal_difference(current_values, alpha=alpha)
                else:
                    monte_carlo(current_values, alpha=alpha)
            total_errors += np.asarray(errors)
        total_errors /= runs
        plt.plot(total_errors, linestyle=linestyle, label=method + ', alpha = %.02f' % (alpha))
    plt.xlabel('episodes')
    plt.ylabel('RMS')
    plt.legend()

The line

errors.append(np.sqrt(np.sum(np.power(TRUE_VALUE - current_values, 2)) / 5.0))

should instead divide by 7.0, because N = 7:

errors.append(np.sqrt(np.sum(np.power(TRUE_VALUE - current_values, 2)) / 7.0))

ShangtongZhang commented 4 years ago

Here I didn't consider the two terminal states, so the error is averaged over the 5 non-terminal states only.
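For context, a minimal sketch of that arithmetic, assuming the module-level arrays defined near the top of random_walk.py (indices 0 and 6 are the terminal states and their entries are fixed at the true values): the terminal entries contribute zero to the squared-error sum, so dividing by 5 averages over the five non-terminal states A–E.

import numpy as np

# Assumed setup mirroring random_walk.py: 7 entries, indices 0 and 6 are terminal.
TRUE_VALUE = np.zeros(7)
TRUE_VALUE[1:6] = np.arange(1, 6) / 6.0   # true values of the non-terminal states A..E
TRUE_VALUE[6] = 1

VALUES = np.zeros(7)
VALUES[1:6] = 0.5                         # initial estimates for A..E
VALUES[6] = 1                             # terminal entries always equal the true values

squared = np.power(TRUE_VALUE - VALUES, 2)
print(squared[0], squared[6])             # both 0.0: terminals add nothing to the sum
print(np.sqrt(np.sum(squared) / 5.0))     # RMS averaged over the 5 non-terminal states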