rentruewang / koila

Prevent PyTorch's `CUDA error: out of memory` in just 1 line of code.
https://rentruewang.com/koila/
Apache License 2.0
1.82k stars 62 forks

cannot get "device" attribute from LazyTensor #6

Closed arpol closed 2 years ago

arpol commented 2 years ago

I have code that depends on getting the device on which a tensor is stored. The device is then used to initialize a new empty tensor that my model needs. Long story short, if a tensor x is wrapped in a LazyTensor, then accessing x.device raises an error.

Maybe you need to consider transparently exposing most (if not all) attributes of the wrapped tensor?
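To make the request concrete, here is a minimal pure-Python sketch of the pattern described above and of "transparent" attribute forwarding. `FakeTensor`, `ForwardingWrapper` and `new_zeros_on_same_device` are hypothetical names standing in for `torch.Tensor`, a wrapper such as LazyTensor, and the user's allocation code; this is not koila's API.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: just data plus a device string."""

    def __init__(self, data, device="cpu"):
        self.data = list(data)
        self.device = device


class ForwardingWrapper:
    """Hypothetical wrapper that forwards unknown attributes to a wrapped tensor."""

    def __init__(self, tensor):
        self._tensor = tensor

    def __getattr__(self, name):
        # Only called when normal lookup fails, so `_tensor` itself is found
        # normally; everything else (e.g. `device`) is read from the wrapped tensor.
        return getattr(self._tensor, name)


def new_zeros_on_same_device(x, size):
    # The pattern from the issue: read x.device, allocate a new buffer there.
    return FakeTensor([0] * size, device=x.device)


x = ForwardingWrapper(FakeTensor([1, 2, 3], device="cuda:0"))
buf = new_zeros_on_same_device(x, 4)
print(x.device, buf.device, buf.data)  # cuda:0 cuda:0 [0, 0, 0, 0]
```

Forwarding like this assumes there is always a concrete tensor to forward to; a wrapper that holds an unevaluated computation has nothing to read `device` from without evaluating it first.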

rentruewang commented 2 years ago

Thanks for the feedback!

The current implementation works as follows: a LazyTensor doesn't directly wrap a Tensor. Its content can be either a Tensor or an unevaluated Evaluation, which is basically a function stored together with its arguments. This prevents Tensors from being evaluated prematurely (which would use GPU memory), but it also means I can't really 'expose' the underlying Tensor's attributes, because sometimes there isn't a Tensor hidden in a LazyTensor to begin with.
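A minimal sketch of the "either a value or an Evaluation" idea, assuming nothing about koila's real classes: `Evaluation`, `Lazy`, `evaluate` and `traced_add` are hypothetical names. It shows that until the stored function runs, there is simply no concrete result whose attributes could be exposed.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass
class Evaluation:
    """Hypothetical: a function stored together with its unevaluated arguments."""

    func: Callable[..., Any]
    args: Tuple[Any, ...]


@dataclass
class Lazy:
    """Hypothetical lazy value: `content` is either a concrete value or an Evaluation."""

    content: Any

    def is_concrete(self) -> bool:
        return not isinstance(self.content, Evaluation)


def evaluate(x: Any) -> Any:
    """Force a value: run stored Evaluations (recursively), pass other values through."""
    if isinstance(x, Lazy):
        x = x.content
    if isinstance(x, Evaluation):
        return x.func(*(evaluate(a) for a in x.args))
    return x


calls = []


def traced_add(a, b):
    calls.append((a, b))  # records when the work actually happens
    return a + b


z = Lazy(Evaluation(traced_add, (Lazy(2), Lazy(3))))
# Nothing has been computed yet, so there is no value to inspect.
print(z.is_concrete(), calls)  # False []
print(evaluate(z), calls)      # 5 [(2, 3)]
```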

The implementation also assumes that any member not found among the attributes is a method wrapping a global function of the same name. For example, LazyTensor.add looks up torch.add directly and passes self to it. The benefit of this approach is that it saves time: instead of registering both Tensor.add and torch.add, I can register the name "add" and call it a day. For that reason, accessing .device, which is an attribute rather than a method, does not work in the current implementation.
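The name-based lookup can be sketched with `__getattr__`. This is an illustration under assumptions, not koila's code: `Deferred` is a hypothetical class, and the standard-library `operator` module stands in for the `torch` namespace, so `x.mul(4)` resolves to `operator.mul(x, 4)` the way `LazyTensor.add` resolves to `torch.add`.

```python
import operator
from typing import Any, Callable


class Deferred:
    """Hypothetical lazy node: a function plus its arguments, run on demand."""

    def __init__(self, func: Callable[..., Any], *args: Any) -> None:
        self.func = func
        self.args = args

    def evaluate(self) -> Any:
        return self.func(*(a.evaluate() if isinstance(a, Deferred) else a for a in self.args))

    def __getattr__(self, name: str) -> Callable[..., "Deferred"]:
        # Called only for members not found normally. Assumption mirroring the
        # description: treat the name as a module-level function and bind `self`
        # as its first argument. `operator` stands in for the `torch` namespace.
        func = getattr(operator, name)  # AttributeError if no such function exists

        def method(*args: Any) -> "Deferred":
            return Deferred(func, self, *args)

        return method


x = Deferred(operator.add, 2, 3)  # lazy 2 + 3
y = x.mul(4)                      # resolved via operator.mul; nothing computed yet
print(y.evaluate())               # 20

try:
    x.device  # there is no operator.device, so the lookup fails
except AttributeError as err:
    print("AttributeError:", err)
```

The convention replaces per-method registration with a single rule: every function in the stand-in namespace becomes a lazy method for free, but a plain data attribute like `device` has no function to map to.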

Also, because this library works with both LazyTensor and Tensor, it currently tells the two types apart by calling isinstance. This leads to complicated code, and it can cause a RecursionError if .device is defined as a property: if the .device property is defined on a protocol, checking whether the tensor is a member of that protocol evaluates the property, which again checks whether the tensor is a member of that protocol... You get the idea.
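The cycle can be reproduced in plain Python. In this hypothetical sketch, `looks_like_tensor` plays the role of the protocol membership check (before Python 3.12, isinstance against a runtime-checkable Protocol used hasattr, which evaluates properties), and `LazySketch` is an invented class whose `device` property depends on that same check.

```python
def looks_like_tensor(obj: object) -> bool:
    # Duck-typed membership check, playing the role of a protocol isinstance().
    # hasattr() really reads the attribute, so a property getter runs here.
    return hasattr(obj, "device")


class LazySketch:
    """Hypothetical lazy tensor whose `device` property relies on that check."""

    @property
    def device(self) -> str:
        # To decide how to compute its device, the property first asks whether
        # `self` is tensor-like, which reads `self.device` again, and so on.
        if looks_like_tensor(self):
            return "cpu"
        return "unknown"


try:
    LazySketch().device
except RecursionError:
    print("RecursionError: the check and the property keep calling each other")
```

hasattr only swallows AttributeError, so the RecursionError propagates to the caller instead of being silently turned into False.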

I'm thinking of also subclassing Tensor, and creating an instance of that subclass whenever a Tensor instance is found. This would solve the issues mentioned above and simplify the code, so I'll give it a try when I can.
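A pure-Python sketch of why subclassing helps, under stated assumptions: `BaseTensor` is a hypothetical stand-in for `torch.Tensor`, and `LazyTensorSketch` and `to_lazy` are hypothetical names, not koila's code. It ignores laziness entirely and only shows that inherited attributes such as `device` come for free and that one plain isinstance check covers both types. A real implementation would go through PyTorch's own Tensor-subclassing hooks (such as `__torch_function__`).

```python
class BaseTensor:
    """Hypothetical stand-in for torch.Tensor (real PyTorch is not used here)."""

    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device


class LazyTensorSketch(BaseTensor):
    """Hypothetical subclass: inherits every BaseTensor attribute, including `device`."""

    @classmethod
    def from_tensor(cls, t: BaseTensor) -> "LazyTensorSketch":
        # Re-wrap a plain tensor as the subclass; leave existing instances alone.
        if isinstance(t, cls):
            return t
        return cls(t.data, device=t.device)


def to_lazy(obj):
    # One plain isinstance check is enough, because every LazyTensorSketch is
    # also a BaseTensor; no protocol or property-evaluating check is involved.
    return LazyTensorSketch.from_tensor(obj) if isinstance(obj, BaseTensor) else obj


lt = to_lazy(BaseTensor([1, 2, 3], device="cuda:0"))
print(type(lt).__name__, lt.device, isinstance(lt, BaseTensor))  # LazyTensorSketch cuda:0 True
print(to_lazy(5))  # 5
```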

TL;DR: implementation issues mean properties such as .device are not supported currently. Will solve it in the future with Tensor subclasses.

rentruewang commented 2 years ago

Tracking this with #18. I'll close this issue now.