IntelLabs / MART

Modular Adversarial Robustness Toolkit
BSD 3-Clause "New" or "Revised" License

Add mart.nn.Get() to extract a value from the kwargs dict. #251

Closed: mzweilin closed this pull request 2 months ago.

mzweilin commented 2 months ago

What does this PR do?

This PR adds mart.nn.Get(key) to get a value from **kwargs by key.

One use case: when model.training_step() returns a dictionary that contains the loss value, we can use mart.nn.Get(key) to fetch the loss inside the gain function.
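The idea can be illustrated with a minimal sketch. This is a hypothetical plain-Python stand-in, not the code merged in this PR: the real mart.nn.Get may be a torch.nn.Module with a different structure, and the `outputs` dict below is made-up example data standing in for what model.training_step() might return.

```python
from typing import Any


class Get:
    """Hypothetical sketch of Get(key): return the keyword argument named `key`."""

    def __init__(self, key: str):
        self.key = key

    def __call__(self, **kwargs: Any) -> Any:
        # Pick the configured key out of the keyword arguments.
        return kwargs[self.key]


# Made-up example of a dict returned by model.training_step() that includes the loss.
outputs = {"loss": 0.25, "preds": [1, 0, 1]}

# Used as a gain function, Get("loss") extracts the loss from the unpacked dict.
gain_fn = Get("loss")
gain = gain_fn(**outputs)
print(gain)  # 0.25
```

Because the callable takes `**kwargs` and returns a single value, the key can be chosen in configuration rather than hard-coded in the attack.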

Type of change

Please check all relevant options.

Testing

Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration.

Before submitting

Did you have fun?

Make sure you had fun coding 🙃

mzweilin commented 2 months ago

> I thought we could use the DotDict to access things within a dict?

Good point. We can extend mart.nn.Get() with DotDict in case a model returns multi-level dicts that nest the loss deep inside. But I don't think we can use DotDict directly as gain_fn:

https://github.com/IntelLabs/MART/blob/07409b4b4a932efb8a70656cde2f5b3f3dc47098/mart/attack/adversary.py#L126
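A rough sketch of the suggested enhancement follows, assuming that the adversary invokes gain_fn with the model outputs as keyword arguments (the linked line is not reproduced here, so this is an assumption). It stays a callable that accepts `**kwargs` and returns one value, while resolving a dotted key such as "losses.total" through nested dicts, roughly what a DotDict-style accessor would offer. The class, the dotted-key convention, and the example data are hypothetical illustrations, not MART's actual DotDict or Get.

```python
from typing import Any, Mapping


class Get:
    """Hypothetical sketch: Get(key) that accepts dotted keys for nested dicts.

    "losses.parts.box" is resolved as kwargs["losses"]["parts"]["box"].
    """

    def __init__(self, key: str):
        self.path = key.split(".")

    def __call__(self, **kwargs: Any) -> Any:
        value: Any = kwargs
        for part in self.path:
            # Only mappings can be indexed by the next path component.
            if not isinstance(value, Mapping):
                raise TypeError(f"cannot look up {part!r} in a non-mapping value")
            value = value[part]
        return value


# Made-up model output that nests the loss two levels deep.
outputs = {"losses": {"total": 1.5, "parts": {"cls": 1.0, "box": 0.5}}}

print(Get("losses.total")(**outputs))      # 1.5
print(Get("losses.parts.box")(**outputs))  # 0.5
```

A plain (dot-free) key such as Get("loss") behaves exactly like the single-level version, so the enhancement would be backward compatible under these assumptions.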