vimalabs / VIMA

Official Algorithm Implementation of ICML'23 Paper "VIMA: General Robot Manipulation with Multimodal Prompts"
MIT License

Request for Guidance on Implementing imitation_loss #39

Closed nisshimura closed 1 year ago

nisshimura commented 1 year ago

Thank you for the snippet you provided in #28; it has been immensely helpful. I am truly grateful for your remarkable contributions and research.

Using your snippet as a reference, I have written my training code. However, it seems that the imitation_loss you previously mentioned isn't implemented in the released code.

Could you provide guidance on implementing the imitation_loss or suggest another way to compute it? Additionally, if you notice any ambiguities or potential issues in my training code, I would greatly appreciate your insights.

Here's the error I encountered:

Exception has occurred: AttributeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
'dict' object has no attribute 'imitation_loss'
  File "/home/initial/workspace/VIMA2/VIMA/scripts/train.py", line 610, in train
    imitation_loss = dist_dict.imitation_loss(actions=tar_action)
  File "/home/initial/workspace/VIMA2/VIMA/scripts/train.py", line 633, in main
    train(policy, dataloader, optimizer, epochs=cfg.epochs)
  File "/home/initial/workspace/VIMA2/VIMA/scripts/train.py", line 642, in <module>
    main(args)
  File "/home/initial/.pyenv/versions/3.9.16/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/initial/.pyenv/versions/3.9.16/lib/python3.9/runpy.py", line 197, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
AttributeError: 'dict' object has no attribute 'imitation_loss'

And here is the relevant training code snippet:

def iteration(policy, batch):
    obs = batch['obs']
    action = batch['action']
    prompt_token_type = batch['prompt_token_type']
    word_batch = batch['word_batch']
    image_batch = batch['image_batch']

    prompt_tokens, prompt_masks = policy.forward_prompt_assembly(
        (prompt_token_type, word_batch, image_batch)
    )
    obs_tokens, obs_masks = policy.forward_obs_token(obs)

    action = policy.discretize_action(action)
    # cache target action
    tar_action = {k: v.clone() for k, v in action.items()}

    # slice action sequence up to the last one
    action_tokens = policy.forward_action_token(action)

    action_tokens = action_tokens.transpose(1, 0)
    obs_tokens = obs_tokens.transpose(1, 0)
    obs_masks = obs_masks.transpose(1, 0)

    pred_action_tokens = policy.forward(
        obs_token=obs_tokens,
        action_token=action_tokens,
        prompt_token=prompt_tokens,
        prompt_token_mask=prompt_masks,
        obs_mask=obs_masks,
    )  # (L, B, E)
    # pred_action_tokens = pred_action_tokens[-2:].contiguous()  # (2, B, E)
    dist_dict = policy.forward_action_decoder(pred_action_tokens)
    tar_action = policy._de_discretize_actions(tar_action)
    return dist_dict, tar_action

def train(policy, dataloader, optimizer, validation_dataloader=None, epochs=10):
    wandb.init(project="VIMA", name=f"VIMA_{NOW}")
    wandb.watch(policy, log_freq=100)  # log model parameters and gradients

    policy.train()  # set the model to training mode
    best_val_loss = float('inf')
    no_improve_count = 0

    for epoch in range(epochs):
        total_epoch_loss = 0
        for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
            dist_dict, tar_action = iteration(policy, batch)

            # pred_actions = {k: v.mode().detach().clone().requires_grad_() for k, v in dist_dict.items()}

            total_loss = compute_cross_entropy_loss(dist_dict, tar_action)
            # imitation_loss = dist_dict.imitation_loss(actions=tar_action)

            # imitation_loss.backward()
            optimizer.zero_grad()  # clear gradients left over from the previous step
            total_loss.backward()

            optimizer.step()
            total_epoch_loss += total_loss.item()

def get_pred(pred_actions, key, time, index):
    # note: if `criterion` is nn.CrossEntropyLoss, it expects raw logits
    # (`.logits`) rather than normalized probabilities (`.probs`)
    return pred_actions[key]._dists[index].probs[time]

def get_true(tar_action, key, time, index):
    return tar_action[key][:, time, index]

def compute_cross_entropy_loss(pred_actions, tar_action):
    # `criterion` (e.g. nn.CrossEntropyLoss()) is assumed to be defined elsewhere
    keys = ['pose0_position', 'pose1_position', 'pose0_rotation', 'pose1_rotation']
    indices = {
        'pose0_position': [0, 1],
        'pose1_position': [0, 1],
        'pose0_rotation': [0, 1, 2, 3],
        'pose1_rotation': [0, 1, 2, 3]
    }
    times = [0, 1]  # 0 for t2, 1 for t1

    total_loss = 0
    for key in keys:
        for time in times:
            for index in indices[key]:
                pred = get_pred(pred_actions, key, time, index)
                true = get_true(tar_action, key, time, index).long()
                total_loss += criterion(pred, true)

    return total_loss
yunfanjiang commented 1 year ago

Hi there, please note that imitation_loss internally implements the cross-entropy loss. You could define your own dict class and implement it based on torch.nn.functional.cross_entropy.
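A minimal sketch of that suggestion, assuming the decoder output keeps the structure the snippet above already indexes into (`pred_actions[key]._dists[index]`, i.e. a list of `torch.distributions.Categorical` per action dimension, with logits shaped `(L, B, n_bins)`); the class name `ActionDistDict` and the shape conventions are illustrative assumptions, not part of the released VIMA API:

import torch.nn.functional as F

class ActionDistDict(dict):
    """Dict of per-key action distributions with an imitation (cross-entropy) loss.

    Assumes each value exposes a `_dists` list of torch.distributions.Categorical,
    one per action dimension, with logits shaped (L, B, n_bins), matching the
    `pred_actions[key]._dists[index]` access pattern in the snippet above.
    """

    def imitation_loss(self, actions):
        """Sum cross-entropy over action keys, dimensions, and time steps.

        `actions` maps each key to discretized target bin indices shaped
        (B, L, D); use the targets cached *before* de-discretization,
        since cross-entropy needs integer class indices.
        """
        total = 0.0
        for key, dist in self.items():
            target = actions[key].long()  # (B, L, D) bin indices
            for index, categorical in enumerate(dist._dists):
                # Categorical.logits are normalized log-probs; F.cross_entropy's
                # internal log_softmax leaves them unchanged, so this is valid.
                logits = categorical.logits  # (L, B, n_bins)
                total = total + F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),             # (L*B, n_bins)
                    target[..., index].transpose(0, 1).reshape(-1),  # (L*B,)
                )
        return total

With such a wrapper, the commented-out call in the snippet becomes usable: wrap the decoder output, e.g. `dist_dict = ActionDistDict(policy.forward_action_decoder(pred_action_tokens))`, then `imitation_loss = dist_dict.imitation_loss(actions=tar_action)`.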

aopolin-lv commented 1 year ago

Hello, have you trained VIMA successfully? For example, is your reimplementation's result similar to that in the paper?

odinprince commented 9 months ago

It would be beneficial for people like me if you could drop some info about the training code you wrote, or share the training snippet.