ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

Having trouble extracting target network weights through callback #5704

Closed DavidVillero closed 4 years ago

DavidVillero commented 5 years ago

System information

Describe the problem

What am I trying to do? I'm training a DQN/Rainbow agent with noisy nets, and I'm noticing that as the model converges towards the ideal policy, exploration does not decline (maybe this is not a problem). Typically, noisy actions are not a problem as long as the model converges to an optimal policy, but I would like to find the source of this noise, or at least understand better why this happens. I would like to create the same plots shown in the Noisy Networks for Exploration paper, which compare learning curves of the average noise parameter sigma (an average of the sigma weight values in each layer of the target network).

How am I doing it? I managed to get the weight values for each layer of my target network and save them to a .csv by using an on_train_result callback; the callback (and a sketch of how it is registered) is in the source code section below.

Source code / logs

import datetime

import numpy as np
import pandas as pd
import tensorflow as tf


def store_sigmas(metrics, global_engine=None):
    """Append the sigma metrics to a SQL table (via a SQLAlchemy engine) or to a local CSV."""
    if global_engine:
        # DataFrame.to_sql() accepts a SQLAlchemy engine directly.
        metrics.to_sql("sigma_metrics_table", global_engine, if_exists="append", index=False)
    else:
        with open('metrics.csv', 'a') as f:
            metrics.to_csv(f, header=False)

def on_train_result(info, dbarg=None):
    """
    Compute the avg, std, min and max of the sigma weight and bias (w and b)
    vectors for the noisy layers of the online and target Q networks, and
    store them via store_sigmas().
    :param info: callback info dict (contains the agent and the training result).
    :param dbarg: optional database argument (unused here; a module-level
        global_engine is used instead).
    :return: None
    """
    var_names = [
        'default/q_func/action_value/hidden_0_sigma_b',
        'default/target_q_func/action_value/hidden_0_sigma_b',
        'default/q_func/state_value/dueling_hidden_0_sigma_b',
        'default/target_q_func/state_value/dueling_hidden_0_sigma_b',
        'default/q_func/action_value/hidden_0_sigma_w',
        'default/target_q_func/action_value/hidden_0_sigma_w',
        'default/q_func/state_value/dueling_hidden_0_sigma_w',
        'default/target_q_func/state_value/dueling_hidden_0_sigma_w',
    ]
    policy = info['agent'].get_policy()
    sess = policy.sess

    metrics = {op.name + "_{}".format(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")):
                   {"_avg": float(np.mean(tf.cast(op.values(), tf.float64).eval(session=sess))),
                    "_std": float(np.std(tf.cast(op.values(), tf.float64).eval(session=sess))),
                    "_min": float(np.min(tf.cast(op.values(), tf.float64).eval(session=sess))),
                    "_max": float(np.max(tf.cast(op.values(), tf.float64).eval(session=sess))),
                    "variable": op.name,
                    "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}
               for op in policy.sess.graph.get_operations()
               if op.name in var_names}
    metrics = pd.DataFrame(metrics).transpose()
    store_sigmas(metrics, global_engine=global_engine)
    info["result"].update(metrics)

The callback works, but policy.sess.graph.get_operations() gets bigger every episode (I don't understand why this happens), which causes the iteration for op in policy.sess.graph.get_operations() to take longer and longer each episode. Does anyone know what I'm doing wrong, and is there a better way of extracting the information I'm after?

Thank you

ericl commented 5 years ago

I think this is due to the tf.cast() calls, which add nodes to the graph on each iteration. I would try to avoid doing any TF operations inside the callback, or if you need to, do it once and save the result in a global variable for reuse. That way the Tensors don't keep building up.
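
A minimal sketch of that suggestion, assuming the same var_names list and callback signature as in the question (the cache _sigma_tensors and the helper _get_sigma_tensors are hypothetical names, not RLlib API): look the tensors up once, then only call sess.run() inside the callback and compute the statistics in NumPy, so nothing new is ever added to the graph.

import numpy as np

_sigma_tensors = {}  # populated on the first callback invocation, then reused


def _get_sigma_tensors(graph, names):
    # An op named "x" exposes its output as the tensor "x:0"; looking the
    # tensors up by name does not create any new graph nodes.
    if not _sigma_tensors:
        for name in names:
            _sigma_tensors[name] = graph.get_tensor_by_name(name + ":0")
    return _sigma_tensors


def on_train_result(info, dbarg=None):
    policy = info["agent"].get_policy()
    sess = policy.sess
    tensors = _get_sigma_tensors(sess.graph, var_names)

    # Fetch all sigma values in a single sess.run() and compute the statistics
    # in NumPy, so no TF ops are built inside the callback.
    values = sess.run(tensors)
    metrics = {
        name: {"_avg": float(np.mean(v)),
               "_std": float(np.std(v)),
               "_min": float(np.min(v)),
               "_max": float(np.max(v))}
        for name, v in values.items()
    }
    info["result"].update(metrics)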