Closed RasmusOrsoe closed 1 year ago
I am not sure I follow this: if I train a model with the PositionReconstruction task, MSELoss, and target=["position_x", "position_y", "position_z"], the training runs without errors and the predictions look sensible at first glance. Is there a pressing need to use the custom scaling, PassOutput3 task, custom target label, etc.?
Describe the bug
We cannot reconstruct vertex position. This is due to two issues.
Issue one
_validate_and_set_transforms in Task only supports transforms that don't slice their input. If one uses a transform that does slice its input, the function fails because of the dimensions of the mock data.
Assuming that the target doesn't require slicing is an issue that I believe has been present since #97; it effectively removed our ability to reconstruct vertex position, as this requires scaling of the input data.
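To make issue one concrete, here is a minimal sketch of how a column-slicing transform can break on mock validation data. The transform `f`, the scale factor of 500, and the 1-D shape of the mock tensor are all assumptions for illustration; the actual mock data used inside _validate_and_set_transforms may differ.

```python
import torch


def f(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical target transform that slices columns of its input."""
    # The scale factor 500.0 is an arbitrary, illustrative choice.
    return torch.stack(
        [x[:, 0] / 500.0, x[:, 1] / 500.0, x[:, 2] / 500.0], dim=1
    )


# Assumption: the validation step probes the transform with 1-D mock data.
mock = torch.logspace(-6, 6, 13)

try:
    f(mock)
    failed = False
except IndexError as err:
    # Column slicing is not defined for a 1-D tensor.
    failed = True
    print(f"Validation would fail: {err}")

# The same transform works on a [batch_size, 3] truth tensor.
batch = torch.tensor([[100.0, -200.0, 300.0], [0.0, 50.0, -450.0]])
scaled = f(batch)
print(scaled.shape)
```

The transform itself is valid for [batch_size, 3] input; only the probe with differently shaped mock data trips it up.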
Issue two
Because we chose not to pass the entire graph object to loss functions, we must pass a single "target" to the Task, which means that targets such as direction and interaction vertex must be stored in a single field, i.e.
graph['direction'] = torch.cat([dir_x, dir_y, dir_z])
This becomes an issue in Task, because https://github.com/graphnet-team/graphnet/blob/main/src/graphnet/models/task/task.py#L134:L136 turns the [1,3]-dimensional direction (or vertex) truth variable into a [batch_size, 1, 3]-dimensional tensor, which results in slicing errors in transform functions like the f(x) mentioned above.
To Reproduce
Steps to reproduce the behavior:
Take a training script from the examples folder.
Change the Task to PassOutput3 with the following settings: transform_target = scale_XYZ, transform_inference = unscale_XYZ
Go to Dataset.py (https://github.com/graphnet-team/graphnet/blob/e619034ed36768e27426a11c0b41f28a97c5b1db/src/graphnet/data/dataset.py#L606) and add the following label
Set target = 'vertex'
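The issue does not show the definitions of scale_XYZ and unscale_XYZ, so the following is only a sketch of what such a pair might look like, with a hypothetical scale factor. It shows how column slicing behaves on a [batch_size, 3] tensor compared with a [batch_size, 1, 3] tensor.

```python
import torch

SCALE = 500.0  # hypothetical scale factor, not taken from the issue


def scale_XYZ(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical transform_target: scale each coordinate column."""
    x = x.clone()
    x[:, 0] = x[:, 0] / SCALE
    x[:, 1] = x[:, 1] / SCALE
    x[:, 2] = x[:, 2] / SCALE
    return x


def unscale_XYZ(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical transform_inference: invert scale_XYZ."""
    x = x.clone()
    x[:, 0] = x[:, 0] * SCALE
    x[:, 1] = x[:, 1] * SCALE
    x[:, 2] = x[:, 2] * SCALE
    return x


# A [batch_size, 3] truth tensor round-trips as expected.
vertex = torch.tensor([[100.0, -200.0, 300.0], [0.0, 50.0, -450.0]])
roundtrip = unscale_XYZ(scale_XYZ(vertex))

# With an extra dimension, [batch_size, 1, 3], the column slicing breaks:
# dimension 1 now has size 1, so index 1 is out of bounds.
vertex_3d = vertex.unsqueeze(1)
try:
    scale_XYZ(vertex_3d)
    error_message = None
except IndexError as err:
    error_message = str(err)
    print(f"Slicing error: {error_message}")
```

Note that `x[:, 0]` on the 3-D tensor does not raise; it silently selects a [batch_size, 3] slice, and the error only surfaces at `x[:, 1]`.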
Expected behavior
_validate_and_set_transforms should not assume that the truth variable has a single row and a single column (which makes slicing fail).
Task should not change the dimensions of the truth variable from [batch_size, d] to [batch_size, 1, d], as this complicates the transform functions (or at the very least we need to have a big, fat red sign somewhere, because people might make bad mistakes).
Full traceback
Please include the full error message to allow for debugging.
Additional context
I think we should reconsider the decision not to pass the entire graph object to loss functions.
I was very surprised to see the [batch_size, 1, d] dimensions of the direction and vertex variables. I was unsure whether this had any impact on the direction reconstructions that I made for northern tracks, so I've gone back and re-run those trials to check.
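One way to avoid being surprised by an unexpected extra dimension is an explicit shape check before any transform is applied. The helper below, `check_truth_shape`, is a hypothetical name and not part of graphnet; it is only a sketch of what a loud warning could look like.

```python
import torch


def check_truth_shape(truth: torch.Tensor, n_dims: int) -> torch.Tensor:
    """Hypothetical guard: require a [batch_size, n_dims] truth tensor."""
    if truth.dim() != 2 or truth.shape[1] != n_dims:
        raise ValueError(
            f"Expected truth of shape [batch_size, {n_dims}], "
            f"got {list(truth.shape)}."
        )
    return truth


# A [batch_size, d] tensor passes through unchanged.
ok = check_truth_shape(torch.zeros(4, 3), n_dims=3)

# A [batch_size, 1, d] tensor is rejected with a clear message.
try:
    check_truth_shape(torch.zeros(4, 1, 3), n_dims=3)
    raised = False
except ValueError as err:
    raised = True
    print(err)
```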