ir-lab / alpha-MDF

Codebase for α-MDF at CoRL 2023
MIT License

How to get the visualization corresponding to attention values in UR5 task? #2

Open acceptallofall opened 9 months ago

acceptallofall commented 9 months ago

Hello, I haven't been able to find the visualization code.

Could you also release the code for generating Fig. 5 in the paper (the predicted joint angle trajectories and the corresponding accumulated attention values for each modality)?

Thanks! This repo will be a great help for my graduation project.

liuxiao1468 commented 9 months ago

Hi there,

If you used this framework for training or testing, you will find that engine_UR5.py collects the attention output at each timestep. You can use the following code block to compute the summed attention values for plotting; `sum_1` ... `sum_4` are the line-plot variables for the four modalities.

    import numpy as np

    # `data` holds the outputs collected in engine_UR5.py;
    # data["attention"] is the attention output recorded at each timestep.
    atten = data["attention"]
    attn_map = np.squeeze(np.array(atten))
    # print(attn_map.shape)

    # One list of 16x16 maps per modality
    attn_visual_1 = []
    attn_visual_2 = []
    attn_visual_3 = []
    attn_visual_4 = []
    for i in range(attn_map.shape[0]):  # loop over timesteps
        # Split the columns into four 256-token blocks, one per modality
        attn_1 = attn_map[i][:, :256]
        attn_2 = attn_map[i][:, 256 : 256 * 2]
        attn_3 = attn_map[i][:, 256 * 2 : 256 * 3]
        attn_4 = attn_map[i][:, 256 * 3 :]
        # Keep each block's diagonal (256 values) and reshape it to a 16x16 grid
        attn_1 = np.diag(attn_1).reshape((16, 16))
        attn_2 = np.diag(attn_2).reshape((16, 16))
        attn_3 = np.diag(attn_3).reshape((16, 16))
        attn_4 = np.diag(attn_4).reshape((16, 16))
        attn_visual_1.append(attn_1)
        attn_visual_2.append(attn_2)
        attn_visual_3.append(attn_3)
        attn_visual_4.append(attn_4)
    # Stack into arrays of shape (timesteps, 16, 16)
    attn_visual_1 = np.array(attn_visual_1)
    attn_visual_2 = np.array(attn_visual_2)
    attn_visual_3 = np.array(attn_visual_3)
    attn_visual_4 = np.array(attn_visual_4)

    # Accumulated attention per timestep: sum over the 16x16 grid -> shape (timesteps,)
    sum_1 = np.sum(np.sum(attn_visual_1, axis=1), axis=1)
    sum_2 = np.sum(np.sum(attn_visual_2, axis=1), axis=1)
    sum_3 = np.sum(np.sum(attn_visual_3, axis=1), axis=1)
    sum_4 = np.sum(np.sum(attn_visual_4, axis=1), axis=1)
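A minimal sketch, not the authors' plotting code for Fig. 5: because summing the 16x16 reshape is the same as summing the diagonal directly, the per-modality sums can be computed without the explicit loop using `np.diagonal` on each 256-column block, then drawn as line plots with matplotlib. It assumes, like the snippet above, that each timestep's attention map has at least 256 rows and exactly 4 × 256 columns. The names `accumulated_attention` and `plot_attention`, the generic modality labels, and the random demo input are hypothetical placeholders.

```python
import numpy as np

TOKENS_PER_MODALITY = 256  # 16 x 16 grid per modality, as in the snippet above
NUM_MODALITIES = 4


def accumulated_attention(attn_map, tokens=TOKENS_PER_MODALITY, num_modalities=NUM_MODALITIES):
    """Return an array of shape (num_modalities, T): the summed diagonal
    attention of each modality block at every timestep.

    attn_map: array of shape (T, Q, num_modalities * tokens) with Q >= tokens.
    """
    attn_map = np.asarray(attn_map)
    sums = []
    for m in range(num_modalities):
        block = attn_map[:, :, m * tokens:(m + 1) * tokens]  # (T, Q, tokens)
        diag = np.diagonal(block, axis1=1, axis2=2)  # (T, tokens)
        sums.append(diag.sum(axis=-1))  # equals summing the 16x16 reshape
    return np.stack(sums)


def plot_attention(sums, labels=None):
    """Line plot of the accumulated attention per modality over time."""
    # Imported lazily so accumulated_attention() does not depend on matplotlib
    import matplotlib.pyplot as plt

    if labels is None:
        # Hypothetical generic labels; replace with the actual modality names
        labels = [f"modality {m + 1}" for m in range(len(sums))]
    fig, ax = plt.subplots(figsize=(8, 3))
    for s, label in zip(sums, labels):
        ax.plot(s, label=label)
    ax.set_xlabel("timestep")
    ax.set_ylabel("accumulated attention")
    ax.legend()
    return fig


# Demo on random data standing in for data["attention"] (hypothetical input)
rng = np.random.default_rng(0)
demo_map = rng.random((5, 256, 4 * 256))
demo_sums = accumulated_attention(demo_map)
print(demo_sums.shape)  # (4, 5)
```

With the real outputs, something like `fig = plot_attention(np.stack([sum_1, sum_2, sum_3, sum_4]))` followed by `fig.savefig("attention.png")` would produce one curve per modality; the joint angle trajectories from the same run could be added as a separate subplot.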