def load_from_torch_shard_ckpt(model, ckpt_dir):
    """
    Load sharded checkpoints directly from huggingface dir.

    Reads ``pytorch_model.bin.index.json`` from ``ckpt_dir``, groups the
    parameter names listed in its ``weight_map`` by the shard file each one
    is mapped to, and hands that grouping to ``load_from_map``.

    Args:
        model: Forwarded unchanged to ``load_from_map`` (presumably the
            model that receives the weights -- confirm against
            ``load_from_map``).
        ckpt_dir: Directory containing ``pytorch_model.bin.index.json``;
            also forwarded to ``load_from_map``.

    Raises:
        OSError: If the index file cannot be opened.
        json.JSONDecodeError: If the index file is not valid JSON.
        KeyError: If the index has no ``weight_map`` entry.
    """
    index_path = os.path.join(ckpt_dir, 'pytorch_model.bin.index.json')
    # Decode explicitly as UTF-8 so parsing does not depend on the
    # platform's locale encoding.
    with open(index_path, encoding='utf-8') as fp:
        ckpt_index = json.load(fp)

    # Only ``weight_map`` is needed here; the ``metadata`` section is not
    # read, so an index without it is still accepted.
    weight_map = ckpt_index['weight_map']

    # Invert {param name: filename} into {filename: [param names]},
    # preserving the order in which names appear in the index.
    file_weight_map = {}
    for param_name, filename in weight_map.items():
        file_weight_map.setdefault(filename, []).append(param_name)

    load_from_map(model, ckpt_dir, file_weight_map)
def load_from_torch_shard_ckpt(model, ckpt_dir): """ Load sharded checkpoints directly from huggingface dir. """ with open(os.path.join(ckpt_dir, 'pytorch_model.bin.index.json')) as fp: ckpt_index = json.load(fp)