Open d710055071 opened 6 months ago
@FUNCTION_REWRITER.register_rewriter(func_name='copy.deepcopy')
def copy__default(tensor: Tensor, *args, **kwargs) -> Tensor:
"""Rewrite `copy.deepcopy` for default backend.
Replace it with tensor.clone(), or may raise `NYI: Named tensors are not
supported with the tracer`
"""
ctx = FUNCTION_REWRITER.get_context()
# if isinstance(tensor, Tensor) and args == () and kwargs == {}:
if isinstance(tensor, Tensor):
return tensor.clone()
elif isinstance(tensor, dict):
# from copy import deepcopy
def deepcopy_dict(obj,memo={}):
if isinstance(obj, dict):
# 如果obj是字典,则创建新的空字典并递归拷贝其中的值
copied_obj = {}
memo[id(obj)] = copied_obj # 存储已拷贝的字典引用
for key, value in obj.items():
copied_obj[deepcopy_dict(key, memo)] = deepcopy_dict(value, memo)
return copied_obj
elif isinstance(obj, list):
# 如果obj是列表,则创建新的空列表并递归拷贝其中的元素
copied_obj = []
memo[id(obj)] = copied_obj # 存储已拷贝的列表引用
for item in obj:
copied_obj.append(deepcopy_dict(item, memo))
return copied_obj
elif isinstance(obj, set):
# 如果obj是集合,则创建新的空集合并递归拷贝其中的元素
copied_obj = set()
memo[id(obj)] = copied_obj
for item in obj:
copied_obj.add(deepcopy_dict(item, memo))
return copied_obj
elif isinstance(obj, (int, float, complex, str, bytes, tuple, frozenset, type(None))):
# 如果obj是不可变类型,则直接返回
return obj
elif id(obj) in memo:
# 如果obj已经被拷贝过,则直接返回其拷贝
return memo[id(obj)]
else:
# 对于其他类型,尝试使用copy模块的deepcopy(如果需要)
try:
# import copy
return copy__default(obj, memo)
except Exception as e:
raise TypeError(f"Unsupported type {type(obj)} in deepcopy") from e
return deepcopy_dict(tensor, *args, **kwargs)
else:
pass
return ctx.origin_func(tensor, *args, **kwargs)
@RunningLeon
hi, sorry for the issue. This project is not actively maintained. Welcome to PR us to fix any bugs. Thanks for your understanding.
@RunningLeon Thank you, that worked. Colud you do a PR?
@RunningLeon Thank you, that worked. Colud you do a PR?
这个代码没有严格经过测试,网上找的 只能做为临时方案,问题的原因是当是字典时如果不处理还是会调用对象重载的深拷贝函数导致
@RunningLeon I see. tha's true.
@d710055071 🐮🍺
Checklist
Describe the bug
NYI: Named tensors are not supported with the tracer
Reproduction
Environment
Error traceback