Closed · nisshimura closed this issue 1 year ago
Hi there, please note that imitation_loss internally implements the cross-entropy loss. You could define your own dict class and implement it based on torch.nn.functional.cross_entropy.
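For example, here is a minimal sketch of such a wrapper. The class name is illustrative, and it assumes the decoder returns per-key multi-categorical distributions whose per-dimension Categorical components sit in a `._dists` list with batch shape (T, B), with targets being discrete bin indices of shape (B, T, D), matching the indexing in the training snippet below:

```python
import torch.nn.functional as F


class ActionDistDict(dict):
    """Illustrative dict subclass: maps each action key to its predicted
    distribution and adds the imitation_loss method the traceback expects."""

    def imitation_loss(self, actions):
        # Sum cross-entropy over every action key and every per-dimension
        # Categorical component; targets are discrete bin indices (B, T, D).
        loss = 0.0
        for key, dist in self.items():
            target = actions[key].long()
            for dim, component in enumerate(dist._dists):
                logits = component.logits  # assumed shape (T, B, n_bins)
                loss = loss + F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    target[:, :, dim].transpose(0, 1).reshape(-1),
                )
        return loss
```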
Hello, have you trained VIMA successfully? For example, is your reimplementation result similar to that reported in the paper?
Thank you for the snippet you provided in #28; it has been immensely helpful. I am truly grateful for your remarkable contributions and research.
Using your snippet as a reference, I have crafted my training code. However, it seems that the imitation_loss you previously mentioned hasn't been implemented.
Could you provide guidance on implementing the imitation_loss or suggest another way to compute it? Additionally, if you notice any ambiguities or potential issues in my training code, I would greatly appreciate your insights.
Here's the error I encountered:
```
Exception has occurred: AttributeError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
'dict' object has no attribute 'imitation_loss'
  File "/home/initial/workspace/VIMA2/VIMA/scripts/train.py", line 610, in train
    imitation_loss = dist_dict.imitation_loss(actions=tar_action)
  File "/home/initial/workspace/VIMA2/VIMA/scripts/train.py", line 633, in main
    train(policy, dataloader, optimizer, epochs=cfg.epochs)
  File "/home/initial/workspace/VIMA2/VIMA/scripts/train.py", line 642, in <module>
    main(args)
  File "/home/initial/.pyenv/versions/3.9.16/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/initial/.pyenv/versions/3.9.16/lib/python3.9/runpy.py", line 197, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
AttributeError: 'dict' object has no attribute 'imitation_loss'
```
And here is the relevant training code snippet:
```python
import torch
import torch.nn as nn
import wandb
from datetime import datetime
from tqdm import tqdm

NOW = datetime.now().strftime("%Y%m%d_%H%M%S")  # timestamp used to tag the wandb run
criterion = nn.CrossEntropyLoss()


def iteration(policy, batch):
    obs = batch['obs']
    action = batch['action']
    prompt_token_type = batch['prompt_token_type']
    word_batch = batch['word_batch']
    image_batch = batch['image_batch']
    prompt_tokens, prompt_masks = policy.forward_prompt_assembly(
        (prompt_token_type, word_batch, image_batch)
    )
    obs_tokens, obs_masks = policy.forward_obs_token(obs)
    action = policy.discretize_action(action)
    # cache target action
    tar_action = {k: v.clone() for k, v in action.items()}
    # slice action sequence up to the last one
    action_tokens = policy.forward_action_token(action)
    action_tokens = action_tokens.transpose(1, 0)
    obs_tokens = obs_tokens.transpose(1, 0)
    obs_masks = obs_masks.transpose(1, 0)
    pred_action_tokens = policy.forward(
        obs_token=obs_tokens,
        action_token=action_tokens,
        prompt_token=prompt_tokens,
        prompt_token_mask=prompt_masks,
        obs_mask=obs_masks,
    )  # (L, B, E)
    # pred_action_tokens = pred_action_tokens[-2:].contiguous()  # (2, B, E)
    dist_dict = policy.forward_action_decoder(pred_action_tokens)
    # keep the targets as discrete bin indices: cross-entropy needs class
    # indices, and _de_discretize_actions would map them back to floats
    # tar_action = policy._de_discretize_actions(tar_action)
    return dist_dict, tar_action


def train(policy, dataloader, optimizer, validation_dataloader=None, epochs=10):
    wandb.init(project="VIMA", name=f"VIMA_{NOW}")
    wandb.watch(policy, log_freq=100)  # log model parameters and gradients
    policy.train()  # set the model to training mode
    best_val_loss = float('inf')
    no_improve_count = 0
    for epoch in range(epochs):
        total_epoch_loss = 0
        for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
            dist_dict, tar_action = iteration(policy, batch)
            # pred_actions = {k: v.mode().detach().clone().requires_grad_() for k, v in dist_dict.items()}
            total_loss = compute_cross_entropy_loss(dist_dict, tar_action)
            # imitation_loss = dist_dict.imitation_loss(actions=tar_action)
            # imitation_loss.backward()
            optimizer.zero_grad()  # clear stale gradients before backward
            total_loss.backward()
            optimizer.step()
            total_epoch_loss += total_loss.item()


def get_pred(pred_actions, key, time, index):
    # .logits rather than .probs: nn.CrossEntropyLoss expects unnormalized
    # scores, and Categorical.logits are the matching log-probabilities
    return pred_actions[key]._dists[index].logits[time]


def get_true(tar_action, key, time, index):
    return tar_action[key][:, time, index]


def compute_cross_entropy_loss(pred_actions, tar_action):
    keys = ['pose0_position', 'pose1_position', 'pose0_rotation', 'pose1_rotation']
    indices = {
        'pose0_position': [0, 1],
        'pose1_position': [0, 1],
        'pose0_rotation': [0, 1, 2, 3],
        'pose1_rotation': [0, 1, 2, 3],
    }
    times = [0, 1]  # 0 for t2, 1 for t1
    total_loss = 0
    for key in keys:
        for time in times:
            for index in indices[key]:
                pred = get_pred(pred_actions, key, time, index)
                true = get_true(tar_action, key, time, index).long()
                total_loss += criterion(pred, true)
    return total_loss
```
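If I understand the dict-class suggestion correctly, the intended fix would be something like the following, where ActionDistDict is a hypothetical wrapper (not part of the released code) that adds imitation_loss on top of the plain dict returned by the decoder:

```python
# Hypothetical: wrap the plain dict from forward_action_decoder so that the
# commented-out imitation_loss call above works. ActionDistDict is an
# illustrative wrapper class, as sketched earlier in this thread.
dist_dict = ActionDistDict(policy.forward_action_decoder(pred_action_tokens))
imitation_loss = dist_dict.imitation_loss(actions=tar_action)
```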
It would be beneficial for people like me if you could share some details about the training code you wrote, or the training snippet itself.