Open jinmin527 opened 1 year ago
https://github.com/OpenBMB/CPM-Bee/pull/100 这是qlora的pr,还没有合,然后bmtrain不知道修复没有int8相关问题,未修复的时候需要本地魔改下BMTrain.blocklayer下面这处就可以了: (blocklayer中不能很好地传requires_grad,需要手动判断dtype类型并设置requires_grad = False)
if dtype == torch.uint8:
storage_param = torch.nn.Parameter(
torch.tensor([], dtype=dtype, device=device).set_(storage_param_buffer),
requires_grad = False,
)
else:
storage_param = torch.nn.Parameter(
torch.tensor([], dtype=dtype, device=device).set_(storage_param_buffer),
)
我git clone bmtrain 0.2.2版本的代码,用bmtrain_qlora目录替换BMTrain中的bmtrain文件夹,然后进行python setup.py develop操作。测试的时候结果报错
不确定是不是我将bmtrain目录覆盖导致的问题,能否提供一个类似于BMTrain完整的工程