Closed ChenHuaYou closed 4 years ago
def rms_error():
    """Plot the run-averaged RMS error of TD and MC value estimates.

    For every step size in ``td_alphas`` and ``mc_alphas``, performs ``runs``
    independent runs of ``episodes`` episodes each. Every run starts from a
    fresh copy of ``VALUES``; the error against ``TRUE_VALUE`` is recorded
    *before* each episode's update, so the first point of each curve is the
    error of the initial estimate. The per-episode errors are averaged over
    the runs and drawn as one curve per (method, alpha) pair on the current
    matplotlib figure. Returns nothing.
    """
    td_alphas = [0.15, 0.1, 0.05]
    mc_alphas = [0.01, 0.02, 0.03, 0.04]
    episodes = 100 + 1
    runs = 100
    for i, alpha in enumerate(td_alphas + mc_alphas):
        total_errors = np.zeros(episodes)
        # Pick the method by position in the concatenated list rather than by
        # value: the same alpha could appear in both lists.
        if i < len(td_alphas):
            method = 'TD'
            linestyle = 'solid'
        else:
            method = 'MC'
            linestyle = 'dashdot'
        for _run in tqdm(range(runs)):
            errors = []
            current_values = np.copy(VALUES)
            # The episode loop uses its own throwaway name so it no longer
            # shadows the outer index ``i`` used for method selection.
            for _episode in range(episodes):
                # Mean over 5 entries, not len(TRUE_VALUE).
                # NOTE(review): presumably the two terminal states are
                # excluded on purpose and contribute zero to the sum --
                # confirm against VALUES / TRUE_VALUE.
                errors.append(np.sqrt(np.sum(np.power(TRUE_VALUE - current_values, 2)) / 5.0))
                if method == 'TD':
                    temporal_difference(current_values, alpha=alpha)
                else:
                    monte_carlo(current_values, alpha=alpha)
            total_errors += np.asarray(errors)
        total_errors /= runs
        plt.plot(total_errors, linestyle=linestyle, label=method + ', alpha = %.02f' % (alpha))
    plt.xlabel('episodes')
    plt.ylabel('RMS')
    plt.legend()
errors.append(np.sqrt(np.sum(np.power(TRUE_VALUE - current_values, 2)) / 5.0)) should be as follows:
errors.append(np.sqrt(np.sum(np.power(TRUE_VALUE - current_values, 2)) / 7.0)), because N = 7
Here I didn't consider the two terminal states.
def rms_error():
Same alpha value can appear in both arrays
errors.append(np.sqrt(np.sum(np.power(TRUE_VALUE - current_values, 2)) / 5.0)) should be as follows:
errors.append(np.sqrt(np.sum(np.power(TRUE_VALUE - current_values, 2)) / 7.0)), because N = 7