Closed nku-zhichengzhang closed 2 months ago
We have a Online BootsTAPIR checkpoint in PyTorch available, is that good for you? Or do you really need the specific Online TAPIR checkpoint in PyTorch?
https://storage.googleapis.com/dm-tapnet/bootstap/causal_bootstapir_checkpoint.pt
We have a Online BootsTAPIR checkpoint in PyTorch available, is that good for you? Or do you really need the specific Online TAPIR checkpoint in PyTorch?
https://storage.googleapis.com/dm-tapnet/bootstap/causal_bootstapir_checkpoint.pt
Thanks for your feedback and checkpoint. We have addressed the issue by using a Jax-version checkpoint and a demo of Online TAPIR.
By the way, is there any toolkit that can convert a JAX checkpoint into a Torch one? We find that the parameters of operations such as Conv are laid out differently: Torch uses [N_out, N_in, kernel, kernel] while JAX uses [kernel, kernel, N_in, N_out].
Here is example code showing how we convert JAX TAPIR to PyTorch; it is quite manual and specific:
def load_weights(model, params):
    """Copy ResNet weights from a Haiku/JAX ``params`` dict into a Torch model.

    Args:
        model: PyTorch ResNet whose ``state_dict`` keys mirror the Haiku module
            tree (e.g. ``blocks.0.conv_1.weight``).
        params: Haiku parameter mapping keyed by module path (prefixed here
            with ``tapir/~/resnet/~/``), holding ``w`` for conv kernels and
            ``scale``/``offset`` for norm layers.

    Returns:
        The same ``model`` with converted weights loaded, switched to eval mode.
    """
    state_dict = dict(model.state_dict())
    for k, val in state_dict.items():
        if 'conv' in k:
            # NOTE(review): assumes conv layers are bias-free (typical for
            # Haiku ResNets whose convs are followed by norms); a
            # 'conv...bias' key would wrongly hit the 'w' lookup -- confirm.
            # Rewrite the Torch key into the Haiku path, e.g.
            # "blocks.0.proj_conv.weight" -> "tapir/~/resnet/~/block_0/~/shortcut_conv".
            hk_key = k.replace('.weight', '')
            hk_key = hk_key.replace('s.', '_')
            hk_key = hk_key.replace('.', '/~/')
            hk_key = hk_key.replace('proj_conv', 'shortcut_conv')
            hk_key = "tapir/~/resnet/~/" + hk_key
            # Haiku conv kernels are [kh, kw, in, out]; Torch wants [out, in, kh, kw].
            new_val = torch.tensor(params[hk_key]['w']).permute(3, 2, 0, 1)
            # Compare shapes directly: np.allclose on shape arrays raises a
            # broadcasting ValueError when the ranks differ instead of a
            # clean assertion failure.
            assert tuple(val.shape) == tuple(new_val.shape), (k, hk_key)
            state_dict[k] = new_val
        else:
            # Norm parameters: Torch 'weight' <- Haiku 'scale',
            # Torch 'bias' <- Haiku 'offset'.
            hk_key = k.replace('.weight', '')
            hk_key = hk_key.replace('.bias', '')
            hk_key = hk_key.replace('s.', '_')
            hk_key = hk_key.replace('.', '/~/')
            hk_key = hk_key.replace('/bn_', '/instancenorm_')
            hk_key = "tapir/~/resnet/~/" + hk_key
            if 'weight' in k:
                new_val = torch.tensor(params[hk_key]['scale'])
            else:
                new_val = torch.tensor(params[hk_key]['offset'])
            assert tuple(val.shape) == tuple(new_val.shape), (k, hk_key)
            state_dict[k] = new_val
    model.load_state_dict(state_dict)
    model.eval()
    return model
def load_mixer_weights(params, model, net_name):
    """Copy PIPs-MLP-mixer weights from a Haiku ``params`` dict into Torch.

    Args:
        params: Haiku parameter mapping keyed by module path.
        model: Torch mixer module. Keys ``blocks.<n>.<mod>...`` map to Haiku
            ``<net_name>block[_<n>]/<mod>``; other attributes map to
            ``<net_name><attr>``.
        net_name: Haiku path prefix, e.g. ``'tapir/~/pips_mlp_mixer/'``.

    Returns:
        The same ``model`` with converted weights loaded, switched to eval mode.
    """
    state_dict = dict(model.state_dict())
    for ok, val in state_dict.items():
        k = ok.split('.')
        if 'blocks' in k:
            # Haiku names repeated modules block, block_1, block_2, ...
            n = k[1]
            suffix = '' if n == '0' else f'_{n}'
            ln = f'{net_name}block{suffix}/{k[-2]}'
        else:
            ln = net_name + k[0]
        jp = params[ln]
        if 'layer_norm' in k[-2]:
            # Haiku LayerNorm stores 'scale'/'offset'.
            # BUG FIX: the Torch 'bias' must come from 'offset'; the original
            # loaded 'scale' into both weight and bias.
            weightkey = 'scale' if k[-1].startswith('w') else 'offset'
        else:
            weightkey = k[-1][0]  # 'weight' -> 'w', 'bias' -> 'b'
        assert len(jp[weightkey].shape) == len(val.shape)
        if weightkey == 'w' and len(jp[weightkey].shape) == 3:
            # Haiku conv1d kernel [kernel, in, out] -> Torch [out, in, kernel].
            new_val = torch.tensor(np.array(jp[weightkey])).permute(2, 1, 0)
        elif weightkey == 'w' and len(jp[weightkey].shape) == 2:
            # Haiku linear [in, out] -> Torch [out, in].
            new_val = torch.tensor(np.array(jp[weightkey])).transpose(1, 0)
        else:
            # Biases and norm scale/offset carry over unchanged.
            new_val = torch.tensor(np.array(jp[weightkey]))
        state_dict[ok] = new_val
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Convert each sub-module from the Haiku params and save Torch checkpoints.
# NOTE(review): this fragment runs inside a method (it uses `self`); the
# enclosing definition and the `nn`/`torch` imports are not shown here.
self.resnet_torch = load_weights(self.resnet_torch, params)
self.torch_pips_mixer = load_mixer_weights(params, self.torch_pips_mixer, 'tapir/~/pips_mlp_mixer/')
# Torch attribute name -> Haiku module name for the cost-volume heads.
mod_names = {'hid1': 'cost_volume_regression_1',
             'hid2': 'cost_volume_regression_2',
             'hid3': 'cost_volume_occlusion_1',
             'hid4': 'cost_volume_occlusion_2',
             'occ_out': 'occlusion_out'}
for torch_name, jax_name in mod_names.items():
    jax_params = params[f'tapir/~/{jax_name}']
    torch_mod = self.torch_cost_volume_track_mods[torch_name]
    # Biases carry over directly; only weight layouts need permuting.
    torch_mod.bias.data = torch.tensor(jax_params['b'])
    if isinstance(torch_mod, nn.Conv2d):
        # Haiku conv kernel [kh, kw, in, out] -> Torch [out, in, kh, kw].
        torch_mod.weight.data = torch.tensor(jax_params['w']).permute(3, 2, 0, 1)
    elif isinstance(torch_mod, nn.Linear):
        # Haiku linear [in, out] -> Torch [out, in].
        torch_mod.weight.data = torch.tensor(jax_params['w']).permute(1, 0)
# Persist the converted weights as standalone PyTorch checkpoint files.
torch.save(self.torch_cost_volume_track_mods.state_dict(), 'torch_cost_volume_track_mods.pt')
torch.save(self.resnet_torch.state_dict(), 'resnet_torch.pt')
torch.save(self.torch_pips_mixer.state_dict(), 'torch_pips_mixer.pt')
Here is example code showing how we convert JAX TAPIR to PyTorch; it is quite manual and specific:
def load_weights(model, params): # copy weights from Haiku to Torch data structure and load checkpoint into ResNetTorch class state_dict = dict(model.state_dict()) for k, val in state_dict.items(): if 'conv' in k: hk_key = k.replace('.weight', '') hk_key = hk_key.replace('s.', '_') hk_key = hk_key.replace('.','/~/') hk_key = hk_key.replace('proj_conv','shortcut_conv') hk_key = "tapir/~/resnet/~/" + hk_key new_val = torch.tensor(params[hk_key]['w']).permute(3,2,0,1) # print(k, hk_key) assert np.allclose(np.array(val.shape), np.array(new_val.shape)) state_dict[k] = new_val else: hk_key = k.replace('.weight', '') hk_key = hk_key.replace('.bias', '') hk_key = hk_key.replace('s.', '_') hk_key = hk_key.replace('.','/~/') hk_key = hk_key.replace('/bn_','/instancenorm_') hk_key = "tapir/~/resnet/~/" + hk_key if 'weight' in k: new_val = torch.tensor(params[hk_key]['scale']) else: new_val = torch.tensor(params[hk_key]['offset']) assert np.allclose(np.array(val.shape), np.array(new_val.shape)) state_dict[k] = new_val model.load_state_dict(state_dict) model.eval() return model def load_mixer_weights(params, model, net_name): state_dict = dict(model.state_dict()) for ok, val in state_dict.items(): k = ok.split('.') if 'blocks' in k: n = k[1] if n == '0': n = '' else: n = f'_{n}' ln = f'{net_name}block{n}/{k[-2]}' else: ln = net_name + k[0] jp = params[ln] if 'layer_norm' in k[-2]: weightkey = 'scale' else: weightkey = k[-1][0] assert len(jp[weightkey].shape) == len(val.shape) if weightkey == 'w': if len(jp[weightkey].shape) == 3: new_val = torch.tensor(np.array(jp[weightkey])).permute(2,1,0) elif len(jp[weightkey].shape) == 2: new_val = torch.tensor(np.array(jp[weightkey])).transpose(1,0) else: new_val = torch.tensor(np.array(jp[weightkey])) state_dict[ok] = new_val # print('l', ln, # jp[weightkey].sum() - state_dict[ok].detach().numpy().sum(), # jp[weightkey].var() - state_dict[ok].detach().numpy().var(), # jp[weightkey].std() - state_dict[ok].detach().numpy().std()) # # 
val.shape, jp[weightkey].shape, weightkey, model.load_state_dict(state_dict) model.eval() return model self.resnet_torch = load_weights(self.resnet_torch, params) self.torch_pips_mixer = load_mixer_weights(params, self.torch_pips_mixer, 'tapir/~/pips_mlp_mixer/') mod_names = {'hid1': 'cost_volume_regression_1', 'hid2': 'cost_volume_regression_2', 'hid3': 'cost_volume_occlusion_1', 'hid4': 'cost_volume_occlusion_2', 'occ_out': 'occlusion_out'} for torch_name, jax_name in mod_names.items(): jax_params = params[f'tapir/~/{jax_name}'] torch_mod = self.torch_cost_volume_track_mods[torch_name] torch_mod.bias.data = torch.tensor(jax_params['b']) if isinstance(torch_mod, nn.Conv2d): torch_mod.weight.data = torch.tensor(jax_params['w']).permute(3, 2, 0, 1) elif isinstance(torch_mod, nn.Linear): torch_mod.weight.data = torch.tensor(jax_params['w']).permute(1, 0) torch.save(self.torch_cost_volume_track_mods.state_dict(), 'torch_cost_volume_track_mods.pt') torch.save(self.resnet_torch.state_dict(), 'resnet_torch.pt') torch.save(self.torch_pips_mixer.state_dict(), 'torch_pips_mixer.pt')
Thanks for the script, I'll try it.
Hi, Thank you for the outstanding work on the Online TAPIR model. I am currently working on testing this model for a downstream benchmark, and I've encountered an issue with the provided checkpoint and test demo scripts, which are currently in Jax. Unfortunately, there doesn't seem to be a Torch-version Online TAPIR checkpoint available. Could you kindly provide an Online TAPIR checkpoint compatible with Torch?