Status: Closed (nono-Sang closed this issue 2 months ago)
Problem: when a function attribute is dynamically added to a torch module object, `oneflow_compile` fails.
For example:

* In the normal case, everything works:
```py
import torch
import oneflow as flow

# Import paths assumed from the onediff layout at the time; they may differ between versions.
from onediff.infer_compiler import oneflow_compile
from onediff.infer_compiler.transform import transform_mgr


class PyTorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 3)

    def forward(self, x):
        return self.linear(x)


class OneFlowModel(flow.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = flow.nn.Linear(2, 3)

    def forward(self, x):
        return self.linear(x)


# Register OneFlowModel as the proxy class for PyTorchModel
cls_key = transform_mgr.get_transformed_entity_name(PyTorchModel)
transform_mgr.update_class_proxies({cls_key: OneFlowModel})

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pytorch_model = PyTorchModel().to(device)

of_model = oneflow_compile(pytorch_model)

x = torch.randn(1, 2).to(device)
y_pt = pytorch_model(x)
y_of = of_model(x)
```
* After adding a function attribute, it fails:

```py
def torch_func():
    pass

pytorch_model.func = torch_func  # dynamically add a function attribute
of_model = oneflow_compile(pytorch_model)

x = torch.randn(1, 2).to(device)
y_pt = pytorch_model(x)
y_of = of_model(x)
```
The root cause lies in `ProxySubmodule` and `torch2oflow`; a detailed analysis is at: https://github.com/siliconflow/sd-team/issues/384
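To make this failure mode concrete without needing torch or onediff, here is a minimal pure-Python sketch. It does not reproduce onediff's real `ProxySubmodule` or `torch2oflow`; `Linear`, `PyTorchModel`, `toy_torch2oflow` and `toy_proxy` are hypothetical stand-ins. It shows two things: an attribute assigned on an instance is a plain function stored in the instance `__dict__` (not a bound method, not a submodule), and a converter that dispatches on type, here built with `functools.singledispatch`, fails as soon as a naive proxy hands it a value type it has no rule for. With a real `torch.nn.Module`, `nn.Module.__setattr__` likewise stores a plain function in the instance `__dict__`, while submodules go into `_modules`.

```python
import types
from functools import singledispatch


class Linear:
    """Hypothetical stand-in for torch.nn.Linear (no torch needed)."""

    def __call__(self, x):
        return x


class PyTorchModel:
    """Hypothetical stand-in for the torch.nn.Module in the issue."""

    def __init__(self):
        self.linear = Linear()

    def forward(self, x):
        return self.linear(x)


def torch_func():
    pass


@singledispatch
def toy_torch2oflow(value):
    # Hypothetical fallback: no conversion rule registered for this type.
    raise NotImplementedError(f"no rule for {type(value).__name__}")


@toy_torch2oflow.register
def _(value: Linear):
    # Hypothetical rule for the one type this toy knows about.
    return "oneflow Linear"


def toy_proxy(obj):
    # Naive proxy: convert every instance attribute it finds.
    return {name: toy_torch2oflow(v) for name, v in vars(obj).items()}


model = PyTorchModel()
print(toy_proxy(model))  # {'linear': 'oneflow Linear'}

model.func = torch_func  # dynamically added function attribute
# The function sits in the instance __dict__ and is not bound to the instance,
# unlike forward, which is looked up on the class and bound on access.
print("func" in vars(model), isinstance(model.func, types.MethodType))  # True False
print(isinstance(model.forward, types.MethodType))  # True

try:
    toy_proxy(model)
except NotImplementedError as err:
    print("conversion failed:", err)  # conversion failed: no rule for function
```

The toy proxy converts `linear` fine and only breaks on `func`, mirroring the issue: the unmodified model compiles, the one with an extra function attribute does not.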
Yes, dynamically added attributes are not supported yet.