THUDM / SwissArmyTransformer

SwissArmyTransformer is a flexible and powerful library to develop your own Transformer variants.
https://THUDM.github.io/SwissArmyTransformer
Apache License 2.0

If I want to bypass deepspeed for finetuning, can I just call model.step() directly during training? #172

Open cocoshe opened 6 months ago

cocoshe commented 6 months ago

[image attached]

Alternatively, is there some way (or which parts would I need to modify) to remove the dependency on deepspeed?

1049451037 commented 6 months ago

If you don't need model parallelism, the ZeRO optimizer, or similar techniques, the model that sat constructs can be used as a normal PyTorch module:

import torch
from sat import AutoModel

model, args = AutoModel.from_pretrained("bert-base-uncased")
model = model.cuda()

# Plain PyTorch tensors; no deepspeed engine involved.
inputs = {
    'input_ids': torch.LongTensor([[1, 2, 3]]).cuda(),
    'position_ids': torch.LongTensor([[0, 1, 2]]).cuda(),
    'token_type_ids': torch.LongTensor([[0, 0, 0]]).cuda(),
    'attention_mask': torch.LongTensor([[[[1]]]]).cuda(),
}
output = model(**inputs)[0]
loss = output.sum()
loss.backward()
print(loss)
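
On the original question: there is no model.step() here. Since the model behaves like any torch.nn.Module, updates are driven by a standard PyTorch optimizer, and optimizer.step() plays the role that the deepspeed engine's step() plays in the sat training scripts. A minimal finetuning-loop sketch under that assumption (torch.optim.AdamW, a dummy batch shaped like the example above, and a placeholder loss stand in for your real dataloader and task loss):

import torch
from sat import AutoModel

model, args = AutoModel.from_pretrained("bert-base-uncased")
model = model.cuda()
model.train()

# A standard PyTorch optimizer takes over from the deepspeed engine.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Dummy batch with the same shapes as the forward example above.
batch = {
    'input_ids': torch.LongTensor([[1, 2, 3]]).cuda(),
    'position_ids': torch.LongTensor([[0, 1, 2]]).cuda(),
    'token_type_ids': torch.LongTensor([[0, 0, 0]]).cuda(),
    'attention_mask': torch.LongTensor([[[[1]]]]).cuda(),
}

for step in range(10):           # replace with a real dataloader loop
    logits = model(**batch)[0]
    loss = logits.sum()          # placeholder; use your task's loss here
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()             # optimizer.step(), not model.step()
    print(step, loss.item())

Note that this forgoes whatever deepspeed was providing (ZeRO partitioning, mixed-precision handling, gradient clipping config), so those have to be re-added manually if needed.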