_has_compatible_shallow_copy_type,  # checks whether two tensors can share storage via a shallow copy
_has_compatible_shallow_copy_type is a private PyTorch function that checks whether two tensors have compatible types for a shallow copy; it returns a boolean indicating that compatibility¹.
Sources (Bing conversation, 2023/4/4):
(1) Function at::_has_compatible_shallow_copy_type. https://pytorch.org/cppdocs/api/function_namespaceat_1a4662355d238cb7049f75ac6cb287d32a.html (accessed 2023/4/4)
(2) pytorch/module.py at master · pytorch/pytorch · GitHub. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py (accessed 2023/4/4)
(3) How to force compiler error if struct shallow copy? https://stackoverflow.com/questions/68183168/how-to-force-compiler-error-if-struct-shallow-copy (accessed 2023/4/4)
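A minimal usage sketch (illustrative only; the exact True/False results depend on the tensors' TensorImpl types, e.g. layout and device):

import torch

a = torch.randn(2, 3)          # dense CPU tensor
b = torch.randn(4, 5)          # another dense CPU tensor, different shape
s = a.to_sparse()              # sparse tensor uses a different TensorImpl type

print(torch._has_compatible_shallow_copy_type(a, b))  # expected True: same tensor type
print(torch._has_compatible_shallow_copy_type(a, s))  # expected False: dense vs. sparse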
Reference code. PyTorch: https://github.com/pytorch/pytorch
# torch/nn/modules/module.py
def _apply(self, fn):
    for module in self.children():
        module._apply(fn)

    def compute_should_use_set_data(tensor, tensor_applied):
        if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
            # If the new tensor has compatible tensor type as the existing tensor,
            # the current behavior is to change the tensor in-place using `.data =`,
            # and the future behavior is to overwrite the existing tensor. However,
            # changing the current behavior is a BC-breaking change, and we want it
            # to happen in future releases. So for now we introduce the
            # `torch.__future__.get_overwrite_module_params_on_conversion()`
            # global flag to let the user control whether they want the future
            # behavior of overwriting the existing tensor or not.
            return not torch.__future__.get_overwrite_module_params_on_conversion()
        else:
            return False

    for key, param in self._parameters.items():
        if param is None:
            continue
        # Tensors stored in modules are graph leaves, and we don't want to
        # track autograd history of `param_applied`, so we have to use
        # `with torch.no_grad():`
        with torch.no_grad():
            param_applied = fn(param)
        should_use_set_data = compute_should_use_set_data(param, param_applied)
        ...
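For reference, the flag mentioned in the comment above can be toggled from user code. A hedged sketch of what the two behaviors mean for module conversion (the observable difference should be whether the Parameter object's identity is preserved):

import torch

m = torch.nn.Linear(2, 2)
p_before = m.weight

# Default (flag False): conversion uses `.data =`, so the Parameter object is reused.
m.double()
print(p_before is m.weight)  # expected True

# Future behavior (flag True): conversion overwrites the parameter with a new tensor.
torch.__future__.set_overwrite_module_params_on_conversion(True)
m.float()
print(p_before is m.weight)  # expected False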
OneFlow:
# python/oneflow/nn/modules/module.py
def _apply(self, fn):
    # A dict to store tensors that has already been applied.
    # There is no need to apply multiple times on a same tensor.
    if self._oneflow_internal_module_tensor_applied_dict__ is None:
        self._oneflow_internal_module_tensor_applied_dict__ = dict()
    for module in self.children():
        module._oneflow_internal_module_tensor_applied_dict__ = (
            self._oneflow_internal_module_tensor_applied_dict__
        )
        module._apply(fn)
        module._oneflow_internal_module_tensor_applied_dict__ = None

    def can_use_assign_copy(tensor, tensor_applied):
        return tensor.is_local == tensor_applied.is_local

    for (key, param) in self._parameters.items():
        .....
        if can_use_assign_copy(param_applied, param):
            if need_apply:
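In the OneFlow port, the PyTorch compatibility check is approximated by comparing whether both tensors are local (as opposed to global) tensors. A minimal illustration, assuming two local tensors on the default device:

import oneflow as flow

a = flow.randn(2, 3)           # local tensor
b = flow.ones(2, 3)            # also a local tensor
print(a.is_local, b.is_local)  # expected: True True -> the assign-copy path can be taken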
Tensor.char
torch.Tensor.char() is equivalent to self.to(torch.int8)¹: it returns a copy of the tensor converted to dtype torch.int8. For example, tensor.to(torch.int8) performs the same conversion explicitly.
Sources (Bing conversation, 2023/4/4):
(1) torch.Tensor.char — PyTorch 2.0 documentation. https://pytorch.org/docs/stable/generated/torch.Tensor.char.html (accessed 2023/4/4)
(2) torch.Tensor — PyTorch 2.0 documentation. https://pytorch.org/docs/stable/tensors.html (accessed 2023/4/4)
(3) TensorCharts.com. https://tensorcharts.com/ (accessed 2023/4/4)
PyTorch:
# torch/nn/utils/rnn.py
def char(self):
    return self.to(dtype=torch.int8)
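A quick usage check (float values are truncated toward zero on conversion):

import torch

x = torch.tensor([1.7, -2.3, 3.0])
print(x.char())          # tensor([ 1, -2,  3], dtype=torch.int8)
print(x.to(torch.int8))  # equivalent conversion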
Tensor.short() is equivalent to self.to(torch.int16); it returns a tensor with dtype torch.int16.¹²
Sources (Bing conversation, 2023/4/4):
(1) torch.Tensor.short — PyTorch 2.0 documentation. https://pytorch.org/docs/stable/generated/torch.Tensor.short.html (accessed 2023/4/4)
(2) torch.Tensor — PyTorch 2.0 documentation. https://pytorch.org/docs/stable/tensors.html (accessed 2023/4/4)
(3) Introduction to Tensors | TensorFlow Core. https://www.tensorflow.org/guide/tensor (accessed 2023/4/4)
PyTorch:
# torch/nn/utils/rnn.py
def short(self):
    return self.to(dtype=torch.short)
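A quick usage check (torch.short is an alias for torch.int16; float values are truncated toward zero):

import torch

x = torch.tensor([1.0, 2.5, -3.9])
print(x.short())          # tensor([ 1,  2, -3], dtype=torch.int16)
print(x.to(torch.int16))  # equivalent conversion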
Problem encountered: OneFlow has no oneflow.int16 dtype.
# python/oneflow/framework/dtype.py
_dtypes = [
    oneflow.bool,
    oneflow.float,
    oneflow.float32,
    oneflow.double,
    oneflow.float64,
    oneflow.float16,
    oneflow.int8,
    oneflow.int32,
    oneflow.int64,
    oneflow.uint8,
    oneflow.record,
    oneflow.tensor_buffer,
    oneflow.bfloat16,
    oneflow.complex64,
    oneflow.cfloat,
    oneflow.complex128,
    oneflow.cdouble,
]
torch.Tensor.short cannot be worked around for now; the reason is that the oneflow.int16 dtype is missing.
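Once oneflow.int16 is added, the port would presumably mirror the PyTorch implementation. A hedged sketch (the hasattr guard is hypothetical and only illustrates the dependency on the missing dtype):

import oneflow as flow

# Hypothetical sketch: Tensor.short can only be implemented once flow.int16 exists.
if hasattr(flow, "int16"):
    def short(self):
        return self.to(dtype=flow.int16)
else:
    print("oneflow.int16 is not available; Tensor.short cannot be implemented yet")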