A few questions before I dive in.
```python
SyftTensor(FloatTensor(DataTensor(torch.FloatTensor([1, 2, 3]))))
```
I think it makes sense to write this down for all tensor types before implementing. For a FloatTensor it would look something like this. Is that the direction you were thinking of?
```python
from typing import Union

import syft
import torch
from syft.decorators import syft_decorator


class DataTensor:
    # Innermost layer: wraps the framework-specific (torch) tensor.
    @syft_decorator(typechecking=True)
    def __init__(self, child: Union[torch.FloatTensor, torch.IntTensor]):
        self.child = child

    def __add__(self, other):
        return DataTensor(child=self.child + other.child)


class FloatTensor:
    @syft_decorator(typechecking=True)
    def __init__(self, child: DataTensor):
        self.child = child

    def __add__(self, other):
        return FloatTensor(child=self.child + other.child)


class IntegerTensor:
    @syft_decorator(typechecking=True)
    def __init__(self, child: DataTensor):
        self.child = child

    def __add__(self, other):
        # Return an IntegerTensor (not a FloatTensor) so the sum keeps its type.
        return IntegerTensor(child=self.child + other.child)


class SyftTensor:
    # Outermost layer exposed to users.
    def __init__(self, child: Union[FloatTensor, IntegerTensor]):
        self.child = child

    def __add__(self, other):
        return SyftTensor(child=self.child + other.child)

    @classmethod
    def FloatTensor(cls, data):
        # Inside a method body, FloatTensor resolves to the module-level class,
        # not to this classmethod.
        if isinstance(data, list):
            return cls(child=FloatTensor(child=DataTensor(child=torch.FloatTensor(data))))
        raise TypeError(f"expected a list, got {type(data).__name__}")


# test
def get_children_types(t):
    # Follow .child down to the raw tensor, collecting each layer's type.
    if hasattr(t, "child"):
        return [type(t)] + get_children_types(t.child)
    return [type(t)]


t = SyftTensor.FloatTensor([1, 2, 3])
assert get_children_types(t) == [SyftTensor, FloatTensor, DataTensor, torch.Tensor]

t2 = SyftTensor.FloatTensor([4, 5, 6])
t3 = t + t2
assert all(t3.child.child.child.numpy() == [5.0, 7.0, 9.0])
```
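Since the plan is to write this down for all tensor types, the repeated `__init__`/`__add__` boilerplate could be factored into a single base class. The sketch below is a hypothetical variant, not the code above: `WrapperTensor`, `allowed_children`, and the list-backed `DataTensor` are assumptions chosen so it runs without torch or syft. Returning `type(self)(...)` from `__add__` keeps an `IntegerTensor` sum an `IntegerTensor`, which rules out a class accidentally returning another class's type.

```python
from typing import List, Tuple


class DataTensor:
    """Hypothetical stand-in for the torch-backed leaf: holds a plain list of numbers."""

    def __init__(self, child: List[float]):
        self.child = list(child)

    def __add__(self, other: "DataTensor") -> "DataTensor":
        # Element-wise addition on the raw values.
        return DataTensor([a + b for a, b in zip(self.child, other.child)])


class WrapperTensor:
    """Hypothetical shared base: subclasses only declare which child types they accept."""

    allowed_children: Tuple[type, ...] = ()

    def __init__(self, child):
        # Plain isinstance check standing in for syft_decorator(typechecking=True).
        if not isinstance(child, self.allowed_children):
            raise TypeError(f"{type(self).__name__} cannot wrap {type(child).__name__}")
        self.child = child

    def __add__(self, other):
        if type(other) is not type(self):
            raise TypeError("can only add tensors of the same wrapper type")
        # type(self) makes the result the same wrapper type as the operands.
        return type(self)(self.child + other.child)


class FloatTensor(WrapperTensor):
    allowed_children = (DataTensor,)


class IntegerTensor(WrapperTensor):
    allowed_children = (DataTensor,)


class SyftTensor(WrapperTensor):
    allowed_children = (FloatTensor, IntegerTensor)


def get_children_types(t) -> List[type]:
    """Walk the .child chain iteratively and collect each layer's type."""
    types = []
    while hasattr(t, "child"):
        types.append(type(t))
        t = t.child
    types.append(type(t))
    return types


a = SyftTensor(IntegerTensor(DataTensor([1, 2, 3])))
b = SyftTensor(IntegerTensor(DataTensor([4, 5, 6])))
c = a + b
print(get_children_types(c))
print(c.child.child.child)  # [5, 7, 9]
```

The `isinstance` check here only approximates what a typechecking decorator would enforce; the real `syft_decorator` behavior may differ.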
Hey - it looks good!
Some nitpicking stuff:
1) Yes, I agree that might be a better name.
2) Yes, both could be possible; I don't currently have a good enough overview of the rest of the code and its constraints to assess that.
Hmm, thinking about it again, if we used "tensor" as the name, we would end up writing tensor.tensor, which is a bit confusing.
Description
Map out the initial basic custom tensor classes and their relationship with each other.
SyftTensor
  -> InterfaceDataTensor
    -> Framework-specific, like TorchTensor
FixedPrecisionTensor
IntegerTensor
FloatTensor
AutogradTensor
  self.children -> can only be Decimal / FloatTensor
HETensor
  -> can only be IntegerTensor

Definition of Done
Basic hierarchy in place with some basic ops and examples of nested configurations with tests.
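The parent/child rules from the description can be written down as data and checked against nested configurations, which is roughly what the tests in the Definition of Done would exercise. A minimal sketch, assuming layers are identified by class name: `ALLOWED_CHILDREN`, `validate_chain`, and `is_valid` are hypothetical names, layers the issue gives no rule for are treated as unconstrained, and the example chains are illustrative rather than configurations the issue specifies.

```python
from typing import Dict, FrozenSet, List

# Child-type rules taken from the issue description. Layers not listed here
# are treated as unconstrained (an assumption; the issue does not say).
ALLOWED_CHILDREN: Dict[str, FrozenSet[str]] = {
    # "Decimal" is copied verbatim from the issue; its concrete type is unspecified.
    "AutogradTensor": frozenset({"Decimal", "FloatTensor"}),
    "HETensor": frozenset({"IntegerTensor"}),
}


def validate_chain(chain: List[str]) -> None:
    """Raise ValueError if any parent/child pair in an outer-to-inner chain breaks a rule."""
    for parent, child in zip(chain, chain[1:]):
        allowed = ALLOWED_CHILDREN.get(parent)
        if allowed is not None and child not in allowed:
            raise ValueError(f"{parent} can only wrap {sorted(allowed)}, got {child}")


def is_valid(chain: List[str]) -> bool:
    """Return True if validate_chain accepts the chain, False otherwise."""
    try:
        validate_chain(chain)
    except ValueError:
        return False
    return True


# Hypothetical nested configurations, outermost layer first.
print(is_valid(["SyftTensor", "AutogradTensor", "FloatTensor", "InterfaceDataTensor", "TorchTensor"]))  # True
print(is_valid(["SyftTensor", "HETensor", "IntegerTensor", "InterfaceDataTensor", "TorchTensor"]))  # True
print(is_valid(["SyftTensor", "HETensor", "FloatTensor", "InterfaceDataTensor", "TorchTensor"]))  # False
```

Keeping the rules in one table means a new tensor type only needs a new entry, and each nested configuration in the tests reduces to a single list of names.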