helmholtz-analytics / heat

Distributed tensors and Machine Learning framework with GPU and MPI acceleration in Python
https://heat.readthedocs.io/
MIT License

[Bug]: factories.array illegal parameter combinations copy & split #1321

Closed. mtar closed this issue 6 months ago.

mtar commented 9 months ago

What happened?

`ht.array` creates a new DNDarray on the CPU when called with `copy=False` and `split >= 0`, even though the input tensor lives on the GPU. This should not happen, since `copy=False` asks for no new array to be created. In addition, `copy=None` returns the same DNDarray, which does not infer the device.
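The `copy` semantics at stake can be sketched in plain Python. This is a minimal, hypothetical model (the function `plan_array_copy` is not part of heat). It assumes NumPy 2-style meanings for `copy` (`True` always copies, `None` copies only when needed, `False` never copies and raises if a copy would be required). It also assumes that distributing the input via `split` always requires a copy, which is what would make `copy=False` together with `split` an illegal combination.

```python
from typing import Optional


def plan_array_copy(copy: Optional[bool], split: Optional[int]) -> bool:
    """Hypothetical helper: return True if the factory must copy the input.

    Assumed NumPy 2-style semantics:
      copy=True  -> always copy
      copy=None  -> copy only when required
      copy=False -> never copy; raise if a copy would be required
    Assumption: splitting the input across processes (split is not None)
    builds new local tensors, so it always requires a copy.
    """
    needs_copy = split is not None  # assumed: distributing data needs a copy
    if copy is True:
        return True
    if copy is None:
        return needs_copy
    if copy is False:
        if needs_copy:
            raise ValueError(
                "copy=False cannot be combined with split: a copy would be required"
            )
        return False
    raise TypeError(f"copy must be True, False or None, got {copy!r}")
```

Under this model, `ht.array(a, copy=False, split=0)` would raise a `ValueError` instead of silently returning a new array on a different device, while `copy=None` with a `split` would copy as needed.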

Code snippet triggering the error

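```python
import torch
import heat as ht

m, n = 40, 20

# Input tensor on the GPU (requires a CUDA device)
a = torch.randn(m, n, dtype=torch.double, device='cuda')
# copy=False should not create a new array; split=0 distributes rows across processes
b = ht.array(a, copy=False, split=0, dtype=ht.double)

print('a.device =', a.device)
print('b.device =', b.device)  # expected to match a's device, but reports cpu:0
print('b.split =', b.split)
```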

Error message or erroneous outcome

```
a.device = cuda:0
b.device = cpu:0
b.split = 0
```

Version

1.3.x

Python version

None

PyTorch version

None

MPI version

No response

github-actions[bot] commented 8 months ago

Branch bugs/1321-_Bug_factories_array_illegal_parameter_combinationscopy&_split created!