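Non-tensor data are now proper leaves in the tensorclass:

```python
# Assumes MyData is a @tensorclass with fields X, y and z, and that X, y and batch_size are defined
data = MyData(X=X, y=y, z="a string", batch_size=batch_size)
assert "z" in data.keys()  # used to fail
```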
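What "proper leaf" means for `keys()` can be sketched with a toy container. This is a hypothetical model, not the tensordict API: nested lists stand in for tensors, and `ToyTensorClass`, `is_tensor_like` and the `include_non_tensor` flag are made up to contrast the old behaviour (only tensor entries listed) with the new one (non-tensor entries listed too).

```python
# Hypothetical toy model of the change, not the tensordict API.
def is_tensor_like(value):
    # Stand-in for a tensor check: nested lists play the role of tensors here.
    return isinstance(value, list)


class ToyTensorClass:
    def __init__(self, batch_size, **fields):
        self.batch_size = tuple(batch_size)
        # Tensor and non-tensor fields are stored in the same mapping.
        self._leaves = dict(fields)

    def keys(self, include_non_tensor=True):
        # include_non_tensor=False mimics the old behaviour, where only
        # tensor entries were leaves; True mimics the new behaviour.
        return [
            key
            for key, value in self._leaves.items()
            if include_non_tensor or is_tensor_like(value)
        ]


X = [[0.0] * 4 for _ in range(3)]
y = [[1.0] * 4 for _ in range(3)]
data = ToyTensorClass((3, 4), X=X, y=y, z="a string")
print(data.keys())  # ['X', 'y', 'z']
print(data.keys(include_non_tensor=False))  # ['X', 'y']
```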
Non-tensor data are now compared when a tensorclass is compared with another tensorclass or a TensorDict:

```python
# Previously: "z" was ignored, so a TensorDict without "z" compared equal
z = "a string!"
tensordict = TensorDict(
    {
        "X": X,
        "y": y,
    },
    batch_size=[3, 4],
)
data = MyData(X=X, y=y, z=z, batch_size=batch_size)
assert (tensordict == data).all()

# Now: "z" is no longer ignored during comparison, so it must be part of the tensordict
tensordict = TensorDict(
    {
        "X": X,
        "y": y,
        "z": z,
    },
    batch_size=[3, 4],
)
assert (tensordict == data).all()
```
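The comparison change can be sketched with plain dicts standing in for a TensorDict and a tensorclass. This is a hypothetical toy model, not tensordict's implementation: lists stand in for tensors, `leaves_equal` and its `compare_non_tensor` flag are made up for illustration, and the result is collapsed to a single bool, as calling `.all()` would do.

```python
# Hypothetical sketch of the comparison semantics, not tensordict code.
def leaves_equal(a, b, compare_non_tensor=True):
    """Return True if two flat {key: value} mappings match.

    Tensor-like leaves (lists here) are always compared; non-tensor leaves
    (e.g. strings) only take part when compare_non_tensor is True, which
    mirrors the behaviour before (False) and after (True) this change.
    """
    def keep(value):
        return compare_non_tensor or isinstance(value, list)

    a_kept = {k: v for k, v in a.items() if keep(v)}
    b_kept = {k: v for k, v in b.items() if keep(v)}
    return a_kept == b_kept


X = [[1, 2], [3, 4]]
td = {"X": X}                    # plays the TensorDict without "z"
tc = {"X": X, "z": "a string!"}  # plays the tensorclass with "z"
print(leaves_equal(td, tc, compare_non_tensor=False))  # True: "z" ignored (old)
print(leaves_equal(td, tc))  # False: "z" is missing from td (new)
print(leaves_equal({"X": X, "z": "a string!"}, tc))  # True
```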
The result of a comparison now includes the non-tensor fields instead of setting them to `None`:

```python
data0 = MyData(X=X, y=y, z="a string", batch_size=batch_size)
data1 = MyData(X=X, y=y, z="another string", batch_size=batch_size)
(data0 == data1).z  # used to be None; now a TensorDict of boolean values
```
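One way to picture a boolean result for a string field is to repeat the outcome of the string comparison at every batch position. This is a hypothetical sketch, assuming the result shares the batch size of the inputs; `full` and `compare_non_tensor_leaf` are made-up helpers and nested lists stand in for a boolean TensorDict.

```python
# Hypothetical toy model, not tensordict code: a non-tensor leaf holds one
# value for the whole batch, so comparing two such leaves gives one boolean
# that is repeated at every position of the batch shape.
def full(shape, value):
    # Build a nested list of the given shape filled with `value`.
    if not shape:
        return value
    return [full(shape[1:], value) for _ in range(shape[0])]


def compare_non_tensor_leaf(a, b, batch_size):
    return full(tuple(batch_size), a == b)


mask = compare_non_tensor_leaf("a string", "another string", [3, 4])
print(mask)  # 3 rows of [False, False, False, False]
```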
Setting non-tensor values in place now raises a `ValueError` instead of a `RuntimeError`.
Indexed assignment of non-tensor data now works, but it converts any `NonTensorData` in the destination into a `NonTensorStack`, since the values now depend on their location in the batch:

```python
data0 = MyData(X=X, y=y, z="a string", batch_size=batch_size)
data1 = MyData(X=X, y=y, z="another string", batch_size=batch_size)
data0[:2] = data1[:2]
data0.z  # used to be a string because __setitem__ ignored it; now a list
```
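Why the assignment has to produce a stack can be illustrated on a 1-D batch of three elements (the example uses a `[3, 4]` batch). The helpers `to_stack` and `set_rows` are hypothetical and only model the idea: once rows 0 and 1 take their string from the source while row 2 keeps its own, no single shared string describes the leaf, so it is stored per element, as a list here.

```python
# Hypothetical sketch, not tensordict code: a non-tensor leaf over a 1-D batch
# is either one shared value or a per-element list (the "stack").
def to_stack(leaf, batch_len):
    # A single shared value is expanded so each batch element owns a copy.
    return list(leaf) if isinstance(leaf, list) else [leaf] * batch_len


def set_rows(dest_leaf, src_leaf, index, batch_len):
    """Assign src[index] into dest[index] for a non-tensor leaf."""
    dest = to_stack(dest_leaf, batch_len)
    src = to_stack(src_leaf, batch_len)
    for i in range(*index.indices(batch_len)):
        dest[i] = src[i]
    return dest  # always a per-element list after the assignment


z = set_rows("a string", "another string", slice(0, 2), batch_len=3)
print(z)  # ['another string', 'another string', 'a string']
```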
This is BC-breaking in the following ways:

- non-tensor data are proper leaves in the tensorclass;
- non-tensor data are compared when a tensorclass is compared with another tensorclass or a TensorDict;
- the result of a comparison contains the non-tensor data instead of `None`;
- setting non-tensor values in place now raises a `ValueError`, not a `RuntimeError`;
- indexed assignment now works, but it converts any `NonTensorData` in the destination into a `NonTensorStack` (since values depend on their location in the batch).