```python
while costs_argsort:
    if subbatch_size == len(costs_argsort) or (
        subbatch_size * costs[costs_argsort[subbatch_size]] > max_cost
    ):
        subbatch_item_ids = costs_argsort[:subbatch_size]
        yield subbatch_item_ids, *[
            [items[i] for i in subbatch_item_ids] for items in data
        ]
        costs_argsort = costs_argsort[subbatch_size:]
        subbatch_size = 1
    else:
        subbatch_size += 1
```
I think I ran into some errors, even though I'm using the same package versions.
The error probably comes from this part of the code: https://github.com/thomaslu2000/Incremental-Parsing-Representations/blob/main/src/benepar/subbatching.py#L33-L37
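One thing that may be worth ruling out (unconfirmed, just a guess): the line `yield subbatch_item_ids, *[...]` uses iterable unpacking in an unparenthesized `yield` tuple, which is only valid syntax from Python 3.8 onward. On Python 3.7 or older it is a `SyntaxError`, even when every package version matches. Below is a minimal stand-alone sketch of the same loop with a parenthesized `yield` that parses on older interpreters too. The wrapper name `split_subbatches` and its signature are hypothetical, chosen only so the snippet runs on its own; they are not the repo's actual function.

```python
from typing import Any, Iterator, List, Sequence, Tuple


def split_subbatches(
    *data: Sequence[Any], costs: Sequence[int], max_cost: int
) -> Iterator[Tuple[Any, ...]]:
    """Hypothetical stand-alone version of the quoted loop.

    Yields tuples of the form (item_ids, data_0_subset, data_1_subset, ...).
    """
    # Indices of items in ascending cost order (ties keep their input order).
    costs_argsort: List[int] = sorted(range(len(costs)), key=lambda i: costs[i])
    subbatch_size = 1
    while costs_argsort:
        # Close the current sub-batch if it already holds every remaining item,
        # or if the budget test from the quoted code fails for the next item.
        if subbatch_size == len(costs_argsort) or (
            subbatch_size * costs[costs_argsort[subbatch_size]] > max_cost
        ):
            subbatch_item_ids = costs_argsort[:subbatch_size]
            # Tuple concatenation instead of `yield ids, *[...]`, so this
            # also parses on Python versions older than 3.8.
            yield (subbatch_item_ids,) + tuple(
                [items[i] for i in subbatch_item_ids] for items in data
            )
            costs_argsort = costs_argsort[subbatch_size:]
            subbatch_size = 1
        else:
            subbatch_size += 1


# Example: four items with costs 1..4 and a budget of 4.
print(list(split_subbatches(["a", "bb", "ccc", "dddd"], costs=[1, 2, 3, 4], max_cost=4)))
# → [([0, 1], ['a', 'bb']), ([2, 3], ['ccc', 'dddd'])]
```

If this version runs but the original line does not, checking `python --version` in the failing environment would confirm whether the interpreter is older than 3.8.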