aws / fmeval

Foundation Model Evaluations Library
http://aws.github.io/fmeval
Apache License 2.0

String enums not comparing as expected #186

Closed (athewsey closed 3 months ago)

athewsey commented 7 months ago

Hi team!

I was surprised today when the following didn't work as expected:

from fmeval.eval import get_eval_algorithm
from fmeval.eval_algorithms import EvalAlgorithm

get_eval_algorithm(EvalAlgorithm.QA_ACCURACY)
# Throws: "EvalAlgorithmClientError: Unknown eval algorithm QA ACCURACY"

The reason, it seems, as discussed on StackOverflow, is that a Python enum with string values needs str as an additional parent class before its members compare equal to those strings. So at the moment:

print(EvalAlgorithm.QA_ACCURACY == "qa_accuracy")  # 'False'
print(EvalAlgorithm.QA_ACCURACY == "QA_ACCURACY")  # Also 'False'
print(EvalAlgorithm.QA_ACCURACY == "QA ACCURACY")  # Also 'False' (despite the error msg above!)
print(EvalAlgorithm.QA_ACCURACY) # 'QA ACCURACY' because of the __str__ method

I propose editing fmeval.eval_algorithms.EvalAlgorithm to inherit from (str, Enum) instead of (Enum), with no other changes needed. From my testing it wouldn't break your custom __str__ method, but it would allow equality comparisons against plain strings to work:

print(EvalAlgorithm.QA_ACCURACY == "qa_accuracy")  # 'True'
print(EvalAlgorithm.QA_ACCURACY) # Still 'QA ACCURACY' because of the __str__ method

If so, I think the same refactor should also be applied to fmeval.eval_algorithms.ModelTask and fmeval.reporting.constants.ListType?
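
For anyone wanting to reproduce the difference outside fmeval, here is a minimal sketch. The two classes are hypothetical stand-ins, not the library's actual definitions. The "qa_accuracy" value and the __str__ that replaces underscores with spaces are assumptions inferred from the outputs above.

```python
from enum import Enum


class PlainAlgorithm(Enum):
    """Hypothetical stand-in for an enum that inherits from Enum only."""

    QA_ACCURACY = "qa_accuracy"  # assumed value, inferred from the issue text

    def __str__(self) -> str:
        # Assumed display form: member name with underscores replaced by spaces.
        return self.name.replace("_", " ")


class StrAlgorithm(str, Enum):
    """Same members, but with str mixed in as an additional parent class."""

    QA_ACCURACY = "qa_accuracy"

    def __str__(self) -> str:
        return self.name.replace("_", " ")


# Without the str mixin, a member never compares equal to a plain string.
plain_equal = PlainAlgorithm.QA_ACCURACY == "qa_accuracy"

# With the str mixin, the member is itself a str, so str equality applies.
mixin_equal = StrAlgorithm.QA_ACCURACY == "qa_accuracy"

# The custom __str__ still controls how the member is displayed.
mixin_text = str(StrAlgorithm.QA_ACCURACY)

print(plain_equal, mixin_equal, mixin_text)
```

In this sketch the mixin changes only how equality (and hashing) behave. The displayed form still comes from the overridden __str__.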

danielezhu commented 7 months ago

Hi, thanks for catching this and for finding a fix! I'll track this issue internally so that we can fix it when someone has bandwidth.

athewsey commented 5 months ago

Seems like this is still not fixed as of v1.0.0 - any updates?

danielezhu commented 4 months ago

Hi, taking a second look at this issue, I don't believe that there is any unexpected behavior occurring. If you want to access the underlying string, you should be using the enum's value attribute. For example, get_eval_algorithm(EvalAlgorithm.QA_ACCURACY.value). Sorry for not noticing this earlier.
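
The .value approach suggested above can be sketched as follows. This is only an illustration under stated assumptions: EvalAlgorithmSketch and normalize_algorithm_name are hypothetical names, not part of fmeval, and the "qa_accuracy" value is inferred from this thread.

```python
from enum import Enum
from typing import Union


class EvalAlgorithmSketch(Enum):
    """Hypothetical stand-in for a plain Enum with string values."""

    QA_ACCURACY = "qa_accuracy"  # assumed value, inferred from the thread

    def __str__(self) -> str:
        # Assumed display form, matching the 'QA ACCURACY' output in the thread.
        return self.name.replace("_", " ")


def normalize_algorithm_name(algorithm: Union[str, EvalAlgorithmSketch]) -> str:
    """Hypothetical helper: return the underlying string for either input type."""
    if isinstance(algorithm, EvalAlgorithmSketch):
        # .value is the raw string assigned to the member, unaffected by __str__.
        return algorithm.value
    return algorithm


from_member = normalize_algorithm_name(EvalAlgorithmSketch.QA_ACCURACY)
from_string = normalize_algorithm_name("qa_accuracy")
print(from_member, from_string)
```

Passing .value explicitly, or normalizing at the boundary as in this sketch, works with a plain Enum because .value bypasses both the custom __str__ and the identity-based equality of enum members.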