approx-ml / approx

Automatic quantization library
https://approx-ml.github.io/approx/
Apache License 2.0

`approx.compare()` API #20

Closed: bushshrub closed this issue 2 years ago

bushshrub commented 2 years ago

Describe the solution you'd like

Public API approx.compare

approx.compare(original_model, quantized_model, *, eval_loop: Optional[Callable], metrics: Optional[list[Metric]])

class Metric(enum.Enum):
    LOSS = 0
    ACCURACY = 1

Where eval_loop is a function that takes a model and returns list[list[float]] or tuple[list[list[float]], list[list[float]]], depending on which metrics are enabled.

Example eval loop:

import torch

def eval(model: torch.nn.Module) -> list[list[float]]:
    loss_fn = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
    epochs = 10
    loss_history: list[list[float]] = [[]]  # one inner list of losses per epoch
    # ... run the evaluation loop, appending losses per epoch ...
    return loss_history
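
A rough usage sketch of the proposed call, assuming original_model and quantized_model already exist and that Metric is importable from approx (the exact import path is not settled in this issue):

# Hypothetical usage of the proposed API, not the final implementation.
comparison = approx.compare(
    original_model,
    quantized_model,
    eval_loop=eval,                          # the example eval loop above
    metrics=[Metric.LOSS, Metric.ACCURACY],  # request both metrics
)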

The eval loop can also be generated automatically (for now, only for PyTorch users).

Either eval_loop or metrics must be specified.
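
A minimal sketch of how that check could look inside compare (names and structure are illustrative, not the actual implementation):

def compare(original_model, quantized_model, *, eval_loop=None, metrics=None):
    # At least one of eval_loop / metrics must be given, otherwise there is
    # nothing to run and nothing to report.
    if eval_loop is None and metrics is None:
        raise ValueError("compare() requires eval_loop and/or metrics")
    ...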

Additional context

We can also use the inspect module to check that the eval loop is valid before running it, instead of discovering a failure partway through evaluation. For example, an eval loop that is a partial function or that has default parameters should also be accepted.
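
A hedged sketch of such a pre-flight check using the standard inspect module; the helper name is hypothetical, and it only verifies that the eval loop can be called with a single positional argument (the model), which also covers partials and defaulted parameters:

import inspect

def _validate_eval_loop(eval_loop) -> None:
    # Hypothetical helper: fail early with a clear error instead of partway
    # through a long evaluation run.
    sig = inspect.signature(eval_loop)  # also resolves functools.partial objects
    try:
        # Must be callable with exactly one positional argument (the model);
        # extra parameters are fine as long as they have defaults.
        sig.bind(object())
    except TypeError as exc:
        raise TypeError(f"eval_loop cannot be called with a single model argument: {exc}") from exc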

compare will run with the currently selected backend and device. Support for specifying a different device may be added in the future, but for now the defaults are used (batteries included).

cc @sudomaze

bushshrub commented 2 years ago

Resolved by #29