Closed sjgosai closed 9 months ago
Update this function:
def filter_state_dict(model, stashed_dict, fill_tensor=False):
    """Select the entries of a stashed state dict that can be loaded into a model.

    Every key of ``model.state_dict()`` is compared against ``stashed_dict``
    and sorted into one of the result lists. A message describing each
    decision is written to ``sys.stderr``.

    Args:
        model: Object exposing ``state_dict()``, which returns a mapping from
            key to tensor.
        stashed_dict: Mapping from key to tensor holding the saved values.
        fill_tensor: If True, a stashed tensor whose shape differs from the
            model's (but has the same number of dimensions) is partially
            loaded: the region common to both shapes is copied into a clone
            of the model's current tensor, and every element outside that
            region keeps the model's current value.

    Returns:
        A dict with the following entries:
            'filtered_state_dict': key -> tensor for every entry to load.
            'passed_keys': model keys whose stashed shape matched exactly.
            'removed_keys': model keys dropped because of a shape mismatch.
            'missing_keys': model keys absent from ``stashed_dict``.
            'unloaded_keys': stashed keys that the model does not have.
            'filled_keys': model keys partially loaded via ``fill_tensor``.
    """
    results_dict = {
        'filtered_state_dict': {},
        'passed_keys': [],
        'removed_keys': [],
        'missing_keys': [],
        'unloaded_keys': [],
        'filled_keys': [],
    }
    old_dict = model.state_dict()

    for m_key, m_value in old_dict.items():
        # Keep the try body down to the one lookup that can raise KeyError.
        try:
            stashed_value = stashed_dict[m_key]
        except KeyError:
            results_dict['missing_keys'].append(m_key)
            print(f'Missing key in dict: {m_key}', file=sys.stderr)
            continue

        if m_value.shape == stashed_value.shape:
            results_dict['filtered_state_dict'][m_key] = stashed_value
            results_dict['passed_keys'].append(m_key)
            print(f'Key {m_key} successfully matched', file=sys.stderr)
            continue

        check_str = 'Size mismatch for key: {}, expected size {}, got {}' \
            .format(m_key, m_value.shape, stashed_value.shape)
        print(check_str, file=sys.stderr)

        # Filling is only well defined when both tensors have the same number
        # of dimensions; otherwise there is no axis-to-axis correspondence.
        same_rank = len(m_value.shape) == len(stashed_value.shape)
        if fill_tensor and same_rank:
            overlap = tuple(
                slice(0, min(model_dim, stashed_dim))
                for model_dim, stashed_dim
                in zip(m_value.shape, stashed_value.shape)
            )
            # Clone so the model's own tensor is not modified in place.
            filled = m_value.clone()
            filled[overlap] = stashed_value.detach()[overlap]
            results_dict['filtered_state_dict'][m_key] = filled
            results_dict['filled_keys'].append(m_key)
            print(f'Filled key {m_key} from overlapping region',
                  file=sys.stderr)
        else:
            results_dict['removed_keys'].append(m_key)

    for m_key, m_value in stashed_dict.items():
        if m_key not in old_dict:
            check_str = 'Skipped loading key: {} of size {}' \
                .format(m_key, m_value.shape)
            results_dict['unloaded_keys'].append(m_key)
            print(check_str, file=sys.stderr)

    return results_dict
Here's a reference for opening a new branch and making changes. https://git-scm.com/book/en/v2/Git-Branching-Basic-Branching-and-Merging
We need to load weights into the branched linear layers when the number of branches changes. This requires updating `/boda2/boda/graph/utils.py`.