hyunjimoon / 24_transpo

https://web.mit.edu/1.041/www/schedule.html class codes
0 stars 0 forks source link

transfer learning notebook #2

Open hyunjimoon opened 2 months ago

hyunjimoon commented 2 months ago

episode 한 번 = policy update 한 번 (Q-learning에서는 Q matrix 업데이트); rollout = episode.

image

R (the reward) comes from the environment; Q (the action-value estimate) comes from the agent.

def get_baseline_performance(data_transfer, num_transfer_steps):
    """Compute oracle, exhaustive and greedy-oracle baselines from a transfer matrix.

    ``data_transfer`` is a DataFrame of scores where entry ``[i, j]`` is
    presumably the performance of the model trained on task ``i`` when
    evaluated on task ``j`` (higher is better) -- confirm against the caller.
    Row ``i`` and column ``i`` are assumed to refer to the same task, since
    the diagonal is read as "trained and evaluated on the same task".  All
    lookups below are positional, so the row/column labels may be any values
    (e.g. float speed limits).

    Args:
        data_transfer: pandas DataFrame of transfer scores.
        num_transfer_steps: number of steps to report for each baseline.

    Returns:
        A tuple ``(oracle_transfer, exhaustive_training,
        sequential_oracle_training)`` of lists:

        * ``oracle_transfer`` -- for every column, the best score over all
          rows, averaged over columns; repeated ``num_transfer_steps`` times.
        * ``exhaustive_training`` -- mean of the diagonal; repeated
          ``num_transfer_steps`` times.
        * ``sequential_oracle_training`` -- greedy row selection: entry ``k``
          is the column-averaged best score over the rows chosen so far,
          where each step adds the row that maximizes that value (ties go to
          the lowest row position).  The first step is always computed, so
          this list has ``max(num_transfer_steps, 1)`` entries.

    Raises:
        ValueError: if ``num_transfer_steps`` exceeds the number of rows, as
            there are then not enough distinct rows to select.
    """
    n_rows = len(data_transfer)
    if num_transfer_steps > n_rows:
        raise ValueError(
            f"num_transfer_steps={num_transfer_steps} exceeds the number of "
            f"rows ({n_rows}) available for sequential selection"
        )

    def _coverage(rows):
        # Per column, the best score among the selected rows (by position),
        # then averaged over columns.
        return data_transfer.iloc[rows].max(axis=0).mean()

    # Oracle transfer: every column takes its best row.
    oracle_transfer = [data_transfer.max(axis=0).mean()] * num_transfer_steps

    # Exhaustive training: the diagonal, read positionally.  Chained
    # ``iloc[i][i]`` would do a *label* lookup for the second ``[i]`` and
    # fail (KeyError) when the column labels are floats.
    diagonal = np.diag(data_transfer.to_numpy(dtype=float))
    exhaustive_training = [diagonal.mean()] * num_transfer_steps

    # Sequential oracle training: start from the row with the best mean score,
    # then greedily add the row that most improves the column-wise best.
    selected = [int(data_transfer.mean(axis=1).argmax())]
    sequential_oracle_training = [_coverage(selected)]
    for _ in range(num_transfer_steps - 1):
        candidates = [i for i in range(n_rows) if i not in selected]
        # max() returns the first maximal candidate, i.e. the lowest position
        # wins ties -- the same tie-break as list.index(max(...)).
        best = max(candidates, key=lambda i: _coverage(selected + [i]))
        selected.append(best)
        sequential_oracle_training.append(_coverage(selected))

    return oracle_transfer, exhaustive_training, sequential_oracle_training
hyunjimoon commented 2 months ago
image

resolved

python transfer_main.py --speed 13.0 --model_num 1 --source_path_name "results/intersection_reward-waittime_flow1000_lane4.0_length750_speed13.89_left0.25/" --num_episodes 50

# Command-line options for a transfer run.  Each spec is
# (flag, type, default, help); options are registered in this order.
# NOTE(review): argparse only runs `type` over *string* defaults, so the
# default `--length` stays the int 750 while an explicit `--length 750`
# becomes 750.0 -- anything that formats these values into a path will
# differ between the two cases.
parser = argparse.ArgumentParser(description='Arguments')
for _flag, _type, _default, _help in (
    ('--flow', int, 1000, 'Flow of cars'),
    ('--lane', float, 4.0, 'Number of lanes'),
    ('--length', float, 750, 'Length of lanes'),
    ('--speed', float, 13.89, 'Speed limit'),
    ('--left', float, 0.25, 'Left turn ratio'),
    ('--model_num', int, 1, 'Model number'),
    ('--source_path_name', str,
     "intersection_flow1000_lane4.0_length750.0_speed13.89_left0.25/",
     'pathname'),
    ('--num_episodes', int, 50, 'Number of episodes'),
    ('--reward', str, 'waittime',
     'We only support wait time reward for transferring now.'),
):
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
args = parser.parse_args()