aws / fmeval

Foundation Model Evaluations Library
http://aws.github.io/fmeval
Apache License 2.0
155 stars 40 forks

fix: replace add_column with map in _generate_prompt_column #161

Closed danielezhu closed 6 months ago

danielezhu commented 6 months ago

Issue #, if available:

Description of changes: This PR replaces the last remaining instance of dataset.add_column (which I somehow missed when submitting PR #115) with dataset.map. Note that this PR is not necessary for correctness; it simply makes our code uniform, since it is odd to have a single function use add_column while all the others use map.

Note: a question that likely arises is: "How did the integration test added to test_factual_knowledge.py (which was supposed to validate the fix) pass, given that _generate_prompt_column was still using add_column, which should have caused inconsistencies in the batch formats in the Ray task graph?"

The test passed despite my not replacing every instance of add_column because we don't strictly need to replace every instance; only the calls to add_column that occur immediately before a dataset aggregation operation (for example, dataset.mean) need to be replaced with map. I should have explained this more clearly in my description of PR #115.

If you recall, errors like 'DataFrame' object has no attribute 'num_columns' and 'pyarrow.lib.Table' object has no attribute 'reset_index' occur when reducing mapped outputs (see this PR). Drilling deeper, you will see that only the batch format of the mapped outputs directly consumed by the reduction/aggregation operation matters.

Thus, it is perfectly valid to do the following:

import ray

def identity(batch):
    return batch

num_parts = 4  # any positive partition count works here

xs = list(range(100))
ds = ray.data.from_items([{"A": (x % 3), "B": x} for x in xs]).repartition(num_parts)
grouped_ds = (
    ds.groupby("A")
    .map_groups(identity, batch_format="pandas")  # first introduce the pandas batch format
    .map_batches(identity)  # but then go back to using the default format
)

agg_ds = grouped_ds.groupby("A").max("B")
assert agg_ds.count() == 3
assert list(agg_ds.sort("A").iter_rows()) == [
    {"A": 0, "max(B)": 99},
    {"A": 1, "max(B)": 97},
    {"A": 2, "max(B)": 98},
]

Since PR #115 replaced all of the calls to add_column that occur immediately before aggregate_evaluation_scores (which calls dataset.mean), it eliminated the root cause of the 'DataFrame' object has no attribute 'num_columns' and 'pyarrow.lib.Table' object has no attribute 'reset_index' errors.

Replacing all other calls to add_column (i.e. the one that I missed in _generate_prompt_column) is purely a matter of style and not correctness.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.