IntelLabs / MART

Modular Adversarial Robustness Toolkit
BSD 3-Clause "New" or "Revised" License

Fix a Callable type annotation. #66

Closed mzweilin closed 1 year ago

mzweilin commented 1 year ago

What does this PR do?

Fixes #65

mzweilin commented 1 year ago

Running pytest -v locally (Python 3.9.0) does catch the bug that was intentionally introduced in a01f346 (to be reverted once we fix the test).

Python 3.9.15 (on another workstation) and 3.9.16 (on CI) tolerate the syntax error, which is why CI failed to catch it.

$ pytest -v
=========================================================================== test session starts ===========================================================================
platform linux -- Python 3.9.0, pytest-7.2.0, pluggy-1.0.0 -- /home/weilinxu/coder/MART/.venv/bin/python3.9
cachedir: .pytest_cache
rootdir: /home/weilinxu/coder/MART, configfile: pyproject.toml, testpaths: tests/
plugins: hydra-core-1.2.0, cov-4.0.0
collected 17 items / 16 errors

================================================================================= ERRORS ==================================================================================
________________________________________________________________ ERROR collecting tests/test_adversary.py _________________________________________________________________
tests/test_adversary.py:13: in <module>
    import mart
mart/__init__.py:3: in <module>
    from mart import attack as attack
mart/attack/__init__.py:1: in <module>
    from .adversary import *
mart/attack/adversary.py:15: in <module>
    from .objective import Objective
mart/attack/objective/__init__.py:2: in <module>
    from .classification import *
mart/attack/objective/classification.py:25: in <module>
    class RandomTarget(Objective):
mart/attack/objective/classification.py:26: in RandomTarget
    def __init__(self, nb_classes: int, gain_fn: Callable[torch.Tensor, torch.Tensor]) -> None:
../../.pyenv/versions/3.9.0/lib/python3.9/typing.py:826: in __getitem__
    raise TypeError(f"Callable[args, result]: args must be a list."
E   TypeError: Callable[args, result]: args must be a list. Got <class 'torch.Tensor'>
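
The traceback above shows typing.Callable rejecting a bare type as its first argument: the parameter types are expected to be wrapped in a list, as in Callable[[arg, ...], result]. A minimal sketch of the well-formed annotation follows. It is not MART's code: `Tensor` is a hypothetical stand-in for torch.Tensor so the snippet runs without PyTorch, and `apply_gain` is an illustrative helper, not a function from the repository.

```python
from typing import Callable, get_args

# Hypothetical stand-in for torch.Tensor so the sketch runs without PyTorch.
Tensor = float

# Well-formed alias: the parameter types sit inside a list.
GainFn = Callable[[Tensor], Tensor]


def apply_gain(gain_fn: GainFn, logit: Tensor) -> Tensor:
    """Illustrative helper: call gain_fn on a single value."""
    return gain_fn(logit)


# get_args exposes the parameter list and the return type of the alias.
params, result = get_args(GainFn)

doubled = apply_gain(lambda x: 2.0 * x, 1.5)
```

Applied to the signature in the traceback, the same pattern would presumably read gain_fn: Callable[[torch.Tensor], torch.Tensor].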

lumurillo commented 1 year ago

I haven't been able to reproduce the issue. Can you provide more information about the conditions under which you hit it, such as the OS and Python version?

mzweilin commented 1 year ago

@lumurillo
Let's fix the syntax error first and try to improve testing later.