hvgazula (closed this issue 5 months ago):
The shape of labels coming from the TFRecords is already `(*volume_shape, n_classes)`; see below:

```
<_TakeDataset element_spec=(TensorSpec(shape=(1, 256, 256, 256, 1), dtype=tf.float32, name=None), TensorSpec(shape=(1, 256, 256, 256, 2), dtype=tf.float32, name=None))>
```
Doing one-hot encoding in `map_labels` again leads to a shape mismatch (an extra dimension is appended; see below):

```
<_TakeDataset element_spec=(TensorSpec(shape=(1, 256, 256, 256, 1), dtype=tf.float32, name=None), TensorSpec(shape=(1, 256, 256, 256, 2, 2), dtype=tf.float32, name=None))>
```
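For illustration, here is a minimal sketch of where the extra dimension comes from, assuming `map_labels` one-hot encodes via `tf.one_hot` (an assumption; the actual implementation may differ):

```python
import tensorflow as tf

# Labels decoded from the TFRecords are already one-hot encoded:
# shape (batch, *volume_shape, n_classes) with n_classes = 2.
labels = tf.zeros((1, 256, 256, 256, 2), dtype=tf.float32)

# One-hot encoding them again appends a new axis instead of replacing one,
# producing exactly the (1, 256, 256, 256, 2, 2) spec shown above.
double_encoded = tf.one_hot(tf.cast(labels, tf.int32), depth=2)
print(double_encoded.shape)  # (1, 256, 256, 256, 2, 2)
```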
Solution: the shape of labels should stay `(*volume_shape, 1)` until dataset construction, at which point `n_classes` should be factored in.

TODO: this should also help fix the tests for `n_classes > 1`.
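A rough sketch of the proposed flow, with a hypothetical helper name (not the current nobrainer API): keep labels as integer indices of shape `(*volume_shape, 1)` and one-hot encode exactly once, when the dataset is constructed.

```python
import tensorflow as tf

def one_hot_at_construction(features, labels, n_classes):
    """Hypothetical helper: labels arrive as (*volume_shape, 1) integer indices."""
    labels = tf.squeeze(tf.cast(labels, tf.int32), axis=-1)  # (*volume_shape,)
    return features, tf.one_hot(labels, depth=n_classes)     # (*volume_shape, n_classes)

# Applied exactly once during Dataset construction, e.g.:
# dataset = dataset.map(lambda x, y: one_hot_at_construction(x, y, n_classes=2))
```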
Reason: `map_labels` is being called twice. If we have to pass in the 6/50-class label mapping (a `dict`), the call becomes `dataset_train.map_labels(label_mapping).shuffle(NUM_GPUS)`, but this mapping is already happening once internally (of course, without the `label_mapping`). We need to pass the label mapping dictionary in a cleaner manner. Where do you think we should do this?
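For context, here is a sketch of what applying such a mapping inside the `tf.data` pipeline can look like, using `tf.lookup.StaticHashTable` (the mapping values below are made up for illustration; see the `label_mapping.py` script linked further down for the real 6/50-class dicts):

```python
import tensorflow as tf

# Illustrative mapping only: raw label ids -> contiguous class ids.
label_mapping = {0: 0, 2: 1, 3: 2, 41: 1, 42: 2}

keys = tf.constant(list(label_mapping.keys()), dtype=tf.int32)
vals = tf.constant(list(label_mapping.values()), dtype=tf.int32)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, vals), default_value=0
)

def remap_labels(features, labels):
    """Remap raw label ids to contiguous class ids inside the tf.data graph."""
    return features, table.lookup(tf.cast(labels, tf.int32))

# dataset = dataset.map(remap_labels)
```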
Proposal: add an argument `label_mapping: Dict = None` to both `from_files` and `from_tfrecords` in `Dataset`.
For reference: see https://github.com/neuronets/nobrainer_training_scripts/blob/main/1.2.0/label_mapping.py
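A minimal sketch of how the proposal could look (the method bodies and the `_load_tfrecords` helper are hypothetical; only the `label_mapping` keyword is the actual proposal):

```python
from typing import Dict, Optional

class Dataset:
    @classmethod
    def from_tfrecords(cls, file_pattern, volume_shape, n_classes,
                       label_mapping: Optional[Dict[int, int]] = None, **kwargs):
        """Hypothetical sketch: apply the mapping once, during construction."""
        dataset = cls._load_tfrecords(file_pattern, volume_shape, **kwargs)  # assumed helper
        if label_mapping is not None:
            dataset = dataset.map_labels(label_mapping)
        return dataset

    # from_files would take the same keyword and apply it the same way.
```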
@satra 👍 or 👎 ?
satra replied: Go for it and see how it plays out.