These are the 19 failing tests without an implementation of get_module():

For reference: https://app.circleci.com/pipelines/github/pytorch/benchmark/2724/workflows/179ce0f4-0754-4c4d-8977-c46c278f921e/jobs/2807
After some thinking, I believe get_module() should be an optional interface, because in the future we may add benchmarks whose owners don't want to implement this interface, and the stakeholders of this interface may not be interested in supporting it on those models (if they do have interest, they can add implementations incrementally). What do you think, @jansel ?
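For context, a hedged sketch of what "optional" would mean in practice: the shared base class ships a default that raises, and only models that support the interface override it (the class shape here is illustrative, not the actual torchbenchmark base-class code):

```python
class BenchmarkModel:
    def get_module(self):
        # Optional interface: models that support it override this and
        # return a (module, example_inputs) tuple; all other models
        # inherit this default and raise.
        raise NotImplementedError("get_module() is not supported by this model")
```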
I think we need get_module() and should implement it for all models. get_module() is the only way to 1) check the correctness of transformations, and 2) apply custom transformations to models.
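A minimal sketch of the kind of correctness check this enables, assuming get_module() returns the (module, example_inputs) tuple used elsewhere in this thread; check_transform and the transform argument are hypothetical names, not existing APIs:

```python
import torch

def check_transform(benchmark_model, transform, rtol=1e-4, atol=1e-4):
    # Pull the raw nn.Module and its example inputs out of the benchmark.
    module, example_inputs = benchmark_model.get_module()
    module.eval()
    with torch.no_grad():
        expected = module(*example_inputs)
        # Apply the custom transformation (e.g. torch.jit.script or a fusion pass).
        transformed = transform(module)
        actual = transformed(*example_inputs)
    # Fail loudly if the transformation changed the numerics.
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
```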
> I think we need get_module() and should implement it for all models. get_module() is the only way to 1) check the correctness of transformations,
Can you elaborate on how to use get_module() to check the correctness of transformations? For example, if three nn.Module networks are used in train() or eval() and all of them are transformed, how do we guarantee correctness via get_module()? I think instead we should return the output tensor from train() or eval() to cross-validate correctness. I think @wconstab has similar proposals.
> 2) apply custom transformations to models
Sure, it can be used for experimental purposes, but are people actually interested in testing custom transformations on all the models?
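A sketch of the alternative proposed here, under the assumption that train()/eval() are changed to return their output tensors (that return value is the hypothetical part; today these methods return nothing):

```python
import torch

def cross_validate(reference_model, transformed_model, rtol=1e-4, atol=1e-4):
    # Assumes eval() has been changed to return the model's output tensor(s).
    # Every module exercised inside eval() contributes to the output, so all
    # three transformed nn.Modules from the example above are checked at once.
    expected = reference_model.eval()
    actual = transformed_model.eval()
    torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
```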
Here are a couple of examples of get_module() being used to check the correctness of custom transformations:

- Add lazy_bench.py to measure trace overhead and compute efficiency pytorch#68563
- https://github.com/jansel/torchdynamo/blob/main/torchbench.py

Accomplishing the same thing with train()/eval() would not be possible with the current design, especially the custom sync logic in the first example.

> is it of people's interest to test the custom transformations on all the models?

Yes, this is the primary thing I use torchbench for.
Thank you for the explanation. Could you please help add the get_module() interface to the models that are missing it (listed by @aaronenyeshi in the comment above)? After that, we could fix a unit test that makes sure every model implements this interface.
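Such a test could look roughly like the following sketch (list_models is a real torchbenchmark discovery helper, but the constructor arguments and the test shape are assumptions about the harness, not the actual test.py code):

```python
import unittest
from torchbenchmark import list_models

class TestGetModule(unittest.TestCase):
    def test_every_model_implements_get_module(self):
        for model_class in list_models():
            model = model_class(device="cpu", jit=False)
            try:
                module, example_inputs = model.get_module()
            except NotImplementedError:
                self.fail(f"{model_class.__module__} does not implement get_module()")
```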
Also, the implementation of _set_mode() is incorrect, because for some models the module used for training is not the same as the one used for inference. For example, get_module() always returns self.model, but the model used in eval is self.eval_model: https://github.com/pytorch/benchmark/blob/main/torchbenchmark/models/timm_resnest/__init__.py#L71

The name get_module() also does not make it clear whether the module it returns is meant for training or for inference.
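For concreteness, a condensed and illustrative version of that pattern (the concrete modules here are stand-ins; only the self.model / self.eval_model split is taken from the linked timm_resnest code):

```python
import torch

class Model:
    def __init__(self):
        self.example_inputs = torch.randn(8, 3, 224, 224)
        self.model = torch.nn.Conv2d(3, 16, 3)          # stand-in for the train module
        self.eval_model = torch.jit.script(self.model)  # separate module used by eval()

    def get_module(self):
        # Always returns the training module, even though eval() below
        # runs self.eval_model -- the inconsistency described above.
        return self.model, (self.example_inputs,)

    def eval(self, niter=1):
        for _ in range(niter):
            self.eval_model(self.example_inputs)
```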
> Also, the implementation of _set_mode() is incorrect because the module used for some models is not the same for train and inference. [...]
You're right, we should address that in another patch as well.
When a model in TorchBenchmark raises NotImplementedError in the get_module() function, it will skip all the unit tests! This happens because test.py catches the NotImplementedError exception and skips the test. These are the reasons why every test is skipped:

1-2. Two of the tests call model.get_module() directly, which may raise NotImplementedError.
3. The train and eval tests call set_train() and set_eval() respectively before running the method under test.
4. set_train/eval() will call _set_mode(), where (model, _) = self.get_module() is called and may raise NotImplementedError.

For 1 and 2, it may make sense to skip them, since their functionality requires get_module(). However, for 3 and 4, we may choose to either:

1. change _set_mode() to something that works (or pass), as sketched below, or
2. remove _set_mode() and directly run train and eval.
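A minimal sketch of the first option, i.e. a _set_mode() that works even when get_module() is unimplemented (the body is an assumption about what "something that works" could look like, not the actual model.py code; only the (model, _) = self.get_module() call is taken from the thread):

```python
def _set_mode(self, train: bool):
    # Option 1: degrade gracefully for models without get_module(),
    # instead of letting NotImplementedError skip the whole test.
    try:
        (model, _) = self.get_module()
    except NotImplementedError:
        # No module handle available; leave train/eval mode handling to the
        # model's own train()/eval() implementations (i.e. effectively "pass").
        return
    model.train(mode=train)
```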