JeanHung opened this issue 1 year ago.
You would need to extend the script in two places:
First, use something like
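labels = data['label']  # all class ids of one example, as a 1-D int tensor
label = tf.scatter_nd(labels[:, None], tf.ones(tf.shape(labels)[0]), (num_classes,))
to encode the multiple labels of each example as a multi-hot vector (1.0 for every class present, 0.0 elsewhere) in the input processing: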
https://github.com/google-research/vision_transformer/blob/85c4f53febd929c43e70e8ff598f9f00d52948b7/vit_jax/input_pipeline.py#L215
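As a hedged sketch of the first step, the helper below wraps the `tf.scatter_nd` call in a function that could be applied inside the preprocessing `map`. It assumes each example's `label` feature is a 1-D integer tensor of class ids; the name `to_multi_hot` is hypothetical and not part of vit_jax. Because `tf.scatter_nd` sums updates that land on the same index, the sketch first drops repeated ids with `tf.unique` so every entry stays 0.0 or 1.0.

```python
import tensorflow as tf


def to_multi_hot(labels, num_classes):
  """Hypothetical helper: 1-D int tensor of class ids -> float32 multi-hot vector."""
  labels = tf.cast(labels, tf.int64)
  # tf.scatter_nd adds together updates that hit the same index, so remove
  # repeated ids to keep every entry at 0.0 or 1.0.
  labels, _ = tf.unique(labels)
  return tf.scatter_nd(
      indices=labels[:, None],                         # shape [k, 1]
      updates=tf.ones_like(labels, dtype=tf.float32),  # one 1.0 per class id
      shape=[num_classes])


# Example: an image tagged with classes 0 and 3 (3 listed twice), 5 classes total.
print(to_multi_hot(tf.constant([0, 3, 3]), num_classes=5).numpy())
# [1. 0. 0. 1. 0.]
```

In a `tf.data` pipeline this would be used as something like `ds.map(lambda d: {**d, 'label': to_multi_hot(d['label'], num_classes)})`, assuming the dataset stores the label ids under `'label'`.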
Second, replace the softmax cross-entropy loss with a sigmoid (binary) cross-entropy loss here:
https://github.com/google-research/vision_transformer/blob/85c4f53febd929c43e70e8ff598f9f00d52948b7/vit_jax/train.py#L52-L58
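A minimal sketch of the second step in JAX, assuming the logits have shape `[batch, num_classes]` and the labels are multi-hot vectors of the same shape. The name `sigmoid_cross_entropy_loss` is hypothetical; it takes keyword-only `logits`/`labels` but is not copied from vit_jax. Each class contributes an independent binary cross-entropy term, summed over classes and averaged over the batch, and `jax.nn.log_sigmoid(-x)` is used as a numerically stable log(1 - sigmoid(x)).

```python
import jax
import jax.numpy as jnp


def sigmoid_cross_entropy_loss(*, logits, labels):
  """Multi-label loss: independent binary cross-entropy per class.

  logits: float array [batch, num_classes]; labels: multi-hot array, same shape.
  """
  log_p = jax.nn.log_sigmoid(logits)        # log(sigmoid(x))
  log_not_p = jax.nn.log_sigmoid(-logits)   # log(1 - sigmoid(x)), computed stably
  per_example = -jnp.sum(labels * log_p + (1.0 - labels) * log_not_p, axis=-1)
  return jnp.mean(per_example)


logits = jnp.array([[2.0, -1.0, 0.5]])
labels = jnp.array([[1.0, 0.0, 1.0]])  # classes 0 and 2 are present
print(sigmoid_cross_entropy_loss(logits=logits, labels=labels))
```

Unlike a softmax, the per-class sigmoid probabilities do not have to sum to one, so the model can assign high probability to several classes of the same image.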
Thank you so much!