google-deepmind / tapnet

Tracking Any Point (TAP)
https://deepmind-tapir.github.io/blogpost.html
Apache License 2.0
1.28k stars 120 forks source link

Torch-version Online TAPIR Checkpoint Required for Testing #106

Closed nku-zhichengzhang closed 2 months ago

nku-zhichengzhang commented 3 months ago

Hi, Thank you for the outstanding work on the Online TAPIR model. I am currently working on testing this model for a downstream benchmark, and I've encountered an issue with the provided checkpoint and test demo scripts, which are currently in Jax. Unfortunately, there doesn't seem to be a Torch-version Online TAPIR checkpoint available. Could you kindly provide an Online TAPIR checkpoint compatible with Torch?

yangyi02 commented 2 months ago

We have an Online BootsTAPIR checkpoint in PyTorch available — is that good for you? Or do you specifically need the Online TAPIR checkpoint in PyTorch?

https://storage.googleapis.com/dm-tapnet/bootstap/causal_bootstapir_checkpoint.pt

nku-zhichengzhang commented 2 months ago

We have a Online BootsTAPIR checkpoint in PyTorch available, is that good for you? Or do you really need the specific Online TAPIR checkpoint in PyTorch?

https://storage.googleapis.com/dm-tapnet/bootstap/causal_bootstapir_checkpoint.pt

Thanks for your feedback and the checkpoint. We have addressed the issue by using the Jax-version checkpoint and demo of Online TAPIR.

nku-zhichengzhang commented 2 months ago

BTW, is there any toolkit that can convert a Jax checkpoint into a Torch one? We find that the parameters of operations such as Conv are laid out differently: Torch uses [N_out, N_in, kernel, kernel] while Jax uses [kernel, kernel, N_in, N_out].

yangyi02 commented 2 months ago

Here is example code showing how we convert the Jax TAPIR checkpoint to PyTorch; it is quite manual and specific:

def load_weights(model, params):
  """Copy ResNet weights from a Haiku `params` tree into a Torch model.

  Translates each Torch state-dict key into the corresponding Haiku
  parameter path (e.g. 'layer.conv.weight' -> 'tapir/~/resnet/~/layer/~/conv')
  and copies the value, permuting conv kernels from Jax's
  [kH, kW, C_in, C_out] layout into Torch's [C_out, C_in, kH, kW].

  Args:
    model: Torch module whose state-dict layout mirrors the Haiku ResNet.
      Conv layers are assumed to have no bias (a 'conv*.bias' key would
      map to a nonexistent Haiku path).
    params: nested dict of Haiku parameters, keyed by module path.

  Returns:
    The same `model`, with weights loaded and switched to eval mode.
  """
  state_dict = dict(model.state_dict())
  for k, val in state_dict.items():
    if 'conv' in k:
      # Convolution kernel: turn the dotted Torch key into the
      # slash-separated Haiku path, then permute HWIO -> OIHW.
      hk_key = k.replace('.weight', '')
      hk_key = hk_key.replace('s.', '_')
      hk_key = hk_key.replace('.', '/~/')
      # Torch names the residual projection 'proj_conv'; Haiku calls it
      # 'shortcut_conv'.
      hk_key = hk_key.replace('proj_conv', 'shortcut_conv')
      hk_key = "tapir/~/resnet/~/" + hk_key
      new_val = torch.tensor(params[hk_key]['w']).permute(3, 2, 0, 1)
    else:
      # Normalization layer: Haiku stores the affine parameters as
      # scale/offset, Torch as weight/bias.
      hk_key = k.replace('.weight', '')
      hk_key = hk_key.replace('.bias', '')
      hk_key = hk_key.replace('s.', '_')
      hk_key = hk_key.replace('.', '/~/')
      hk_key = hk_key.replace('/bn_', '/instancenorm_')
      hk_key = "tapir/~/resnet/~/" + hk_key
      if 'weight' in k:
        new_val = torch.tensor(params[hk_key]['scale'])
      else:
        new_val = torch.tensor(params[hk_key]['offset'])
    # Direct shape comparison instead of np.allclose on shape arrays:
    # allclose raises (rather than failing the assert) when ranks differ.
    assert val.shape == new_val.shape, (k, hk_key, val.shape, new_val.shape)
    state_dict[k] = new_val
  model.load_state_dict(state_dict)
  model.eval()
  return model

def load_mixer_weights(params, model, net_name):
  """Copy PIPs MLP-mixer weights from a Haiku `params` tree into `model`.

  Args:
    params: nested dict of Haiku parameters, keyed by module path.
    model: Torch module whose state-dict layout mirrors the Haiku mixer
      (a 'blocks' stack plus top-level linear / layer-norm modules).
    net_name: Haiku path prefix, e.g. 'tapir/~/pips_mlp_mixer/'.

  Returns:
    The same `model`, with weights loaded and switched to eval mode.

  Raises:
    ValueError: if a weight tensor has an unexpected rank.
  """
  state_dict = dict(model.state_dict())
  for ok, val in state_dict.items():
    k = ok.split('.')
    if 'blocks' in k:
      # Torch 'blocks.N.<layer>.<param>' maps to Haiku 'block' for N == 0
      # and 'block_N' otherwise.
      n = k[1]
      n = '' if n == '0' else f'_{n}'
      ln = f'{net_name}block{n}/{k[-2]}'
    else:
      ln = net_name + k[0]

    jp = params[ln]
    if 'layer_norm' in k[-2]:
      # Haiku layer-norm affine parameter is named 'scale'.
      weightkey = 'scale'
    else:
      # First letter of 'weight' / 'bias' matches Haiku's 'w' / 'b'.
      weightkey = k[-1][0]

    assert len(jp[weightkey].shape) == len(val.shape)
    if weightkey == 'w':
      # Jax stores weight matrices transposed relative to Torch.
      rank = len(jp[weightkey].shape)
      if rank == 3:
        new_val = torch.tensor(np.array(jp[weightkey])).permute(2, 1, 0)
      elif rank == 2:
        new_val = torch.tensor(np.array(jp[weightkey])).transpose(1, 0)
      else:
        # The original code silently reused `new_val` from the previous
        # iteration here (or crashed with NameError); fail loudly instead.
        raise ValueError(f'Unexpected weight rank {rank} for {ln}')
    else:
      new_val = torch.tensor(np.array(jp[weightkey]))

    state_dict[ok] = new_val

  model.load_state_dict(state_dict)
  model.eval()
  return model

# NOTE(review): this snippet was pasted from inside a method (hence `self`)
# and will not run standalone. It converts the remaining sub-networks and
# dumps each as a Torch checkpoint.
self.resnet_torch = load_weights(self.resnet_torch, params)
self.torch_pips_mixer = load_mixer_weights(params, self.torch_pips_mixer, 'tapir/~/pips_mlp_mixer/')
# Torch module name -> Haiku module name for the cost-volume heads.
mod_names = {'hid1': 'cost_volume_regression_1',
'hid2': 'cost_volume_regression_2',
'hid3': 'cost_volume_occlusion_1',
'hid4': 'cost_volume_occlusion_2',
'occ_out': 'occlusion_out'}

for torch_name, jax_name in mod_names.items():
  jax_params = params[f'tapir/~/{jax_name}']
  torch_mod = self.torch_cost_volume_track_mods[torch_name]
  torch_mod.bias.data = torch.tensor(jax_params['b'])
  # Conv kernels: HWIO (Jax) -> OIHW (Torch); Linear weights: transpose.
  if isinstance(torch_mod, nn.Conv2d):
    torch_mod.weight.data = torch.tensor(jax_params['w']).permute(3, 2, 0, 1)
  elif isinstance(torch_mod, nn.Linear):
    torch_mod.weight.data = torch.tensor(jax_params['w']).permute(1, 0)

# Persist the converted weights as standalone Torch state dicts.
torch.save(self.torch_cost_volume_track_mods.state_dict(), 'torch_cost_volume_track_mods.pt')
torch.save(self.resnet_torch.state_dict(), 'resnet_torch.pt')
torch.save(self.torch_pips_mixer.state_dict(), 'torch_pips_mixer.pt')
nku-zhichengzhang commented 2 months ago

Here is an examplar code for how we convert jax tapir to pytorch, it is quite manual and specific:

def load_weights(model, params):
  # copy weights from Haiku to Torch data structure and load checkpoint into ResNetTorch class
  state_dict = dict(model.state_dict())
  for k, val in state_dict.items():
    if 'conv' in k:
      hk_key = k.replace('.weight', '')
      hk_key = hk_key.replace('s.', '_')
      hk_key = hk_key.replace('.','/~/')
      hk_key = hk_key.replace('proj_conv','shortcut_conv')
      hk_key = "tapir/~/resnet/~/" + hk_key
      new_val = torch.tensor(params[hk_key]['w']).permute(3,2,0,1)
      # print(k, hk_key)
      assert np.allclose(np.array(val.shape), np.array(new_val.shape))
      state_dict[k] = new_val
    else:
      hk_key = k.replace('.weight', '')
      hk_key = hk_key.replace('.bias', '')
      hk_key = hk_key.replace('s.', '_')
      hk_key = hk_key.replace('.','/~/')
      hk_key = hk_key.replace('/bn_','/instancenorm_')
      hk_key = "tapir/~/resnet/~/" + hk_key
      if 'weight' in k:
        new_val = torch.tensor(params[hk_key]['scale'])
      else:
        new_val = torch.tensor(params[hk_key]['offset'])
      assert np.allclose(np.array(val.shape), np.array(new_val.shape))
      state_dict[k] = new_val
  model.load_state_dict(state_dict)
  model.eval()
  return model

def load_mixer_weights(params, model, net_name):
  state_dict = dict(model.state_dict())
  for ok, val in state_dict.items():
    k = ok.split('.')
    if 'blocks' in k:
      n = k[1]
      if n == '0':
        n = ''
      else:
        n = f'_{n}'
      ln = f'{net_name}block{n}/{k[-2]}'
    else:
      ln = net_name + k[0]

    jp = params[ln]
    if 'layer_norm' in k[-2]:
      weightkey = 'scale'
    else:
      weightkey = k[-1][0]

    assert len(jp[weightkey].shape) == len(val.shape)
    if weightkey == 'w':
      if len(jp[weightkey].shape) == 3:
        new_val = torch.tensor(np.array(jp[weightkey])).permute(2,1,0)
      elif len(jp[weightkey].shape) == 2:
        new_val = torch.tensor(np.array(jp[weightkey])).transpose(1,0)
    else:
      new_val = torch.tensor(np.array(jp[weightkey]))

    state_dict[ok] = new_val
    # print('l', ln,
    #       jp[weightkey].sum() - state_dict[ok].detach().numpy().sum(),
    #       jp[weightkey].var() - state_dict[ok].detach().numpy().var(),
    #       jp[weightkey].std() - state_dict[ok].detach().numpy().std())
    #       # val.shape, jp[weightkey].shape, weightkey,

  model.load_state_dict(state_dict)
  model.eval()
  return model

self.resnet_torch = load_weights(self.resnet_torch, params)
self.torch_pips_mixer = load_mixer_weights(params, self.torch_pips_mixer, 'tapir/~/pips_mlp_mixer/')
mod_names = {'hid1': 'cost_volume_regression_1',
'hid2': 'cost_volume_regression_2',
'hid3': 'cost_volume_occlusion_1',
'hid4': 'cost_volume_occlusion_2',
'occ_out': 'occlusion_out'}

for torch_name, jax_name in mod_names.items():
  jax_params = params[f'tapir/~/{jax_name}']
  torch_mod = self.torch_cost_volume_track_mods[torch_name]
  torch_mod.bias.data = torch.tensor(jax_params['b'])
  if isinstance(torch_mod, nn.Conv2d):
    torch_mod.weight.data = torch.tensor(jax_params['w']).permute(3, 2, 0, 1)
  elif isinstance(torch_mod, nn.Linear):
    torch_mod.weight.data = torch.tensor(jax_params['w']).permute(1, 0)

torch.save(self.torch_cost_volume_track_mods.state_dict(), 'torch_cost_volume_track_mods.pt')
torch.save(self.resnet_torch.state_dict(), 'resnet_torch.pt')
torch.save(self.torch_pips_mixer.state_dict(), 'torch_pips_mixer.pt')

Thanks for the script, I'll try it.