X02cinnamondirty opened this issue 8 months ago (status: Open)
By default, the dataloading step casts `target_vars` to `torch.float32`.
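A minimal illustration of what that default cast does to integer targets. It assumes the cast behaves like a plain `Tensor.to(torch.float32)`, since the library's actual code is not shown in this thread. The point is that the source dtype does not matter: an int target still arrives as float32.

```python
import torch

# Integer class labels, as they might be stored in the target column
labels = torch.tensor([0, 2, 1])
print(labels.dtype)  # torch.int64

# Assumed equivalent of the default dataloading cast applied to target_vars
batch_targets = labels.to(torch.float32)
print(batch_targets.dtype)  # torch.float32

# The values survive, but code that expects integer class indices
# now receives floats instead
print(batch_targets.tolist())  # [0.0, 2.0, 1.0]
```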
You can override this using the `transforms` argument. Try modifying it to:
transforms={"ohe_seq": lambda x: x.swapaxes(1, 2), "id_x": lambda x: torch.tensor(x, dtype=torch.long)}
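A sketch of why a per-key transform can undo the default cast. `to_batch` is a hypothetical stand-in for the library's dataloading step, not its real API. It assumes that transforms run *after* the float32 cast (if the library applied them before the cast, the override would be lost again), and it assumes `id_x` is the integer target variable.

```python
import torch


def to_batch(raw, target_vars, transforms=None):
    """Hypothetical dataloading step (not the library's real API).

    Converts each field to a tensor, casts target_vars to float32 by
    default, then applies any per-key transform supplied by the user.
    """
    transforms = transforms or {}
    batch = {}
    for key, value in raw.items():
        tensor = torch.as_tensor(value)
        if key in target_vars:
            tensor = tensor.to(torch.float32)  # default cast
        if key in transforms:
            tensor = transforms[key](tensor)  # user override runs last
        batch[key] = tensor
    return batch


raw = {
    "ohe_seq": torch.zeros(2, 10, 4),  # (batch, length, alphabet)
    "id_x": [0, 3],                    # integer targets
}

default = to_batch(raw, target_vars=["id_x"])
print(default["id_x"].dtype)  # torch.float32

fixed = to_batch(
    raw,
    target_vars=["id_x"],
    transforms={
        "ohe_seq": lambda x: x.swapaxes(1, 2),  # -> (batch, alphabet, length)
        "id_x": lambda x: x.long(),             # restore integer targets
    },
)
print(fixed["id_x"].dtype)     # torch.int64
print(fixed["ohe_seq"].shape)  # torch.Size([2, 4, 10])
```

The sketch uses `x.long()` rather than `torch.tensor(x, dtype=torch.long)` because here `x` is already a tensor, and `torch.tensor` on a tensor emits a copy-construct warning. If the library hands NumPy arrays to the transforms, the form suggested above works as written.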
But my target var is an int, so why does this error happen?