ToluClassics / mlx-transformers

MLX Transformers is a library that provides model implementations in MLX. It uses a model interface similar to HuggingFace Transformers and provides a way to load and run models on Apple Silicon devices.
Apache License 2.0

Implemented `RobertaForSequenceClassification` #1

Closed · Seun-Ajayi closed this 4 months ago

Seun-Ajayi commented 4 months ago

Proposed changes

Implemented RobertaForSequenceClassification in the roberta.py module. Added a RobertaClassificationHead alongside it, and added a SequenceClassifierOutput class to the modelling_outputs.py script, since that is the output type for the sequence classifier.
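
For context, SequenceClassifierOutput typically mirrors the HuggingFace dataclass of the same name. A minimal sketch of what such a class can look like (the exact field set is my assumption, not necessarily the PR's code):

    from dataclasses import dataclass
    from typing import Optional, Tuple

    import mlx.core as mx

    @dataclass
    class SequenceClassifierOutput:
        # mirrors transformers.modeling_outputs.SequenceClassifierOutput
        loss: Optional[mx.array] = None
        logits: Optional[mx.array] = None
        hidden_states: Optional[Tuple[mx.array, ...]] = None
        attentions: Optional[Tuple[mx.array, ...]] = None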


For the multi_label_classification problem type, the HuggingFace implementation uses BCEWithLogitsLoss, but there is no equivalent class in MLX:

            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

so I had to use nn.losses.binary_cross_entropy instead:

            elif self.config.problem_type == "multi_label_classification":
                # binary_cross_entropy is a function, not a class; it expects
                # raw logits by default (with_logits=True)
                loss = nn.losses.binary_cross_entropy(logits, labels)

How else could I have handled this?
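
For reference, a standalone sketch of the workaround as it stands (the dummy inputs are mine, not from the PR):

    import mlx.core as mx
    import mlx.nn as nn

    logits = mx.array([[0.5, -1.2, 2.0]])  # raw scores: one example, three labels
    labels = mx.array([[1.0, 0.0, 1.0]])   # multi-label targets in {0, 1}

    # binary_cross_entropy consumes raw logits by default (with_logits=True),
    # which is the same contract as torch's BCEWithLogitsLoss
    loss = nn.losses.binary_cross_entropy(logits, labels, reduction="mean")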

ToluClassics commented 4 months ago

Can you please add a test case?

Seun-Ajayi commented 4 months ago

Implemented RobertaForTokenClassification and RobertaForQuestionAnswering too.

The CrossEntropyLoss class takes ignore_index as an argument, but this parameter is not included in MLX's version, so I just dropped it. This is in the RobertaForQuestionAnswering class:

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)

My workaround

            # a function, not a class; it has no ignore_index parameter
            loss_fct = nn.losses.cross_entropy
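
If the ignore_index behaviour turns out to matter, one option (a sketch of my own, not code from this PR) is to emulate it by masking with ops MLX does provide:

    import mlx.core as mx
    import mlx.nn as nn

    def cross_entropy_ignoring(logits, targets, ignore_index):
        # hypothetical helper: emulates CrossEntropyLoss(ignore_index=...)
        mask = targets != ignore_index
        # swap ignored targets for a valid class id before the lookup
        safe_targets = mx.where(mask, targets, mx.zeros_like(targets))
        per_token = nn.losses.cross_entropy(logits, safe_targets, reduction="none")
        mask = mask.astype(per_token.dtype)
        # average only over positions that are not ignored
        return (per_token * mask).sum() / mx.maximum(mask.sum(), 1)
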
Seun-Ajayi commented 4 months ago

> Can you please add a test case?

I will.
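
Something along these lines, for instance (the import path and call signature below are my assumptions for illustration; the real test may load a pretrained checkpoint instead):

    import unittest

    import mlx.core as mx
    from transformers import RobertaConfig

    # assumed import path, for illustration only
    from mlx_transformers.models import RobertaForSequenceClassification

    class TestRobertaForSequenceClassification(unittest.TestCase):
        def test_logits_shape(self):
            config = RobertaConfig(num_labels=3)
            model = RobertaForSequenceClassification(config)
            input_ids = mx.array([[0, 31414, 232, 2]])  # dummy token ids
            outputs = model(input_ids)
            self.assertEqual(tuple(outputs.logits.shape), (1, 3))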

Seun-Ajayi commented 4 months ago

We are good now

ToluClassics commented 4 months ago

> We are good now

Thank you!