google-research / vision_transformer

Apache License 2.0

How can I extend this to multi-label classification? #264

Open JeanHung opened 1 year ago

andsteing commented 1 year ago

You would need to extend the script in two places:

First, use something like

label = tf.scatter_nd(data['label'][:, None], tf.ones(tf.shape(data['label'])[0]), (num_classes,))

to encode multiple labels as a single multi-hot vector (a 1 at every class index that applies) in the input processing:

https://github.com/google-research/vision_transformer/blob/85c4f53febd929c43e70e8ff598f9f00d52948b7/vit_jax/input_pipeline.py#L215
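
To make the target format concrete, here is a minimal NumPy sketch of the multi-hot encoding the scatter above is meant to produce. It is an illustration only, not the tf.data code itself: the helper name `multi_hot` is hypothetical, and it assumes each example carries a variable-length list of integer class ids.

```python
import numpy as np


def multi_hot(label_ids, num_classes):
    """Return a float32 vector with 1.0 at every index listed in label_ids."""
    target = np.zeros(num_classes, dtype=np.float32)
    # Plain assignment keeps the value at 1.0 even if an id is repeated.
    target[np.asarray(label_ids, dtype=np.int64)] = 1.0
    return target


# An image tagged with classes 1 and 3 out of 5 possible classes.
example = multi_hot([1, 3], num_classes=5)
print(example)  # [0. 1. 0. 1. 0.]
```

One difference worth knowing: `tf.scatter_nd` adds up updates that land on the same index, so a class id listed twice would yield 2.0 there, whereas the sketch above stays at 1.0.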

Second, use a per-class sigmoid (binary cross-entropy) loss instead of the cross-entropy loss here:

https://github.com/google-research/vision_transformer/blob/85c4f53febd929c43e70e8ff598f9f00d52948b7/vit_jax/train.py#L52-L58
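
A sigmoid loss for multi-hot targets could look roughly like the JAX sketch below. The function name `sigmoid_xent_loss` and the keyword-only `logits`/`labels` signature are assumptions made for illustration, not code taken from the repository; the shapes are assumed to be `[batch, num_classes]` for both arguments.

```python
import jax
import jax.numpy as jnp


def sigmoid_xent_loss(*, logits, labels):
    """Binary cross-entropy per class, summed over classes, averaged over the batch.

    logits: [batch, num_classes] raw scores.
    labels: [batch, num_classes] multi-hot targets (0.0 or 1.0).
    """
    log_p = jax.nn.log_sigmoid(logits)       # log-probability that the class is present
    log_not_p = jax.nn.log_sigmoid(-logits)  # log-probability that the class is absent
    per_example = -jnp.sum(labels * log_p + (1.0 - labels) * log_not_p, axis=-1)
    return jnp.mean(per_example)


# Two examples, three classes, each example with two active labels.
logits = jnp.array([[2.0, -1.0, 0.5], [-3.0, 0.0, 4.0]])
labels = jnp.array([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])
loss = sigmoid_xent_loss(logits=logits, labels=labels)
```

Using `log_sigmoid` on `logits` and `-logits` avoids computing `log(1 - sigmoid(x))` directly, which can lose precision for large logits. Summing over classes is one common reduction; averaging over classes instead only rescales the loss.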

JeanHung commented 1 year ago

Thank you so much!