minnervva / torchdetscan

This is a tool for finding non-deterministic functions in your pytorch code.
https://github.com/minnervva/torchdetscan
MIT License

Implementing simple AST-based linter #21

Closed markcoletti closed 6 months ago

markcoletti commented 7 months ago

Use an AST-based approach to parse a Python file and look for code that calls non-deterministic functions. This will also entail exercising the linter on some simple benchmarks, at first.
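A minimal sketch of what such an AST-based scan could look like, using only the standard-library `ast` module. The `FORBIDDEN` set, the `ForbiddenCallFinder` class, and `find_violations()` are hypothetical names for illustration and are not taken from this repository; the real list of names would come from the PyTorch reproducibility documentation.

```python
import ast

# Hypothetical subset of function names to flag; not the project's real list.
FORBIDDEN = {"kthvalue", "AvgPool3d"}


class ForbiddenCallFinder(ast.NodeVisitor):
    """Collect (line, name) pairs for calls whose name is in FORBIDDEN."""

    def __init__(self):
        self.violations = []

    def visit_Call(self, node):
        func = node.func
        # torch.kthvalue(...) parses as an Attribute; kthvalue(...) as a Name.
        if isinstance(func, ast.Attribute):
            name = func.attr
        elif isinstance(func, ast.Name):
            name = func.id
        else:
            name = None
        if name in FORBIDDEN:
            self.violations.append((node.lineno, name))
        # Keep descending so calls nested inside this one are also visited.
        self.generic_visit(node)


def find_violations(source):
    """Parse source text and return sorted (line, name) violations."""
    finder = ForbiddenCallFinder()
    finder.visit(ast.parse(source))
    return sorted(finder.violations)
```

This matches on the bare function name only, so it would also flag an unrelated function that happens to be called `kthvalue`. That trade-off seems acceptable for a crude first pass.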

markcoletti commented 7 months ago

These are the target applications:

Optional is mala.

markcoletti commented 7 months ago

I’d like to incorporate a crude scan for the forbidden PyTorch functions tonight, so that Ada has something to show Prasanna tomorrow.

Something to look out for:

[torch.nn.functional.interpolate()](https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html#torch.nn.functional.interpolate) when attempting to differentiate a CUDA tensor and one of the following modes is used:

- linear
- bilinear
- bicubic
- trilinear

The ast module gives access to call arguments such as the mode, but that means I need special handling in place for these cases.
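A rough sketch of that special handling for `interpolate()`, assuming the mode is passed as a `mode=` keyword with a string literal. `RISKY_MODES` and `risky_interpolate_calls()` are hypothetical names, not part of the tool.

```python
import ast

# The modes listed above for interpolate().
RISKY_MODES = {"linear", "bilinear", "bicubic", "trilinear"}


def risky_interpolate_calls(source):
    """Return sorted (line, mode) pairs for interpolate() calls with a risky mode."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.Call):
            continue
        func = node.func
        if isinstance(func, ast.Attribute):
            name = func.attr
        elif isinstance(func, ast.Name):
            name = func.id
        else:
            continue
        if name != "interpolate":
            continue
        for kw in node.keywords:
            # Only literal strings can be checked statically; a variable
            # passed as mode= is skipped in this sketch.
            if (kw.arg == "mode"
                    and isinstance(kw.value, ast.Constant)
                    and kw.value.value in RISKY_MODES):
                hits.append((node.lineno, kw.value.value))
    return sorted(hits)
```

Static analysis cannot tell whether the tensor lives on a CUDA device or whether it is ever differentiated, so a check like this can only warn about a possible violation. A mode passed positionally is also not handled here.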

Then there’s this:

torch.Tensor.__getitem__() when attempting to differentiate a CPU tensor and the index is a list of tensors

__getitem__ maps to the subscript operator (x[...]), if I’m not mistaken, so we’ll have to see if ast can catch that.
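For what it's worth, Python's `ast` module represents `x[...]` as an `ast.Subscript` node, so indexing is visible to the parser. The sketch below flags subscripts whose index is a list literal; `list_index_subscripts()` is a hypothetical helper, and it has no way of knowing whether the list elements are tensors or whether the indexed object is a CPU tensor.

```python
import ast


def list_index_subscripts(source):
    """Return sorted line numbers of subscripts indexed by a list literal."""
    lines = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.Subscript):
            continue
        index = node.slice
        # Python 3.8 wraps the index in ast.Index; 3.9+ stores it directly.
        if index.__class__.__name__ == "Index":
            index = index.value
        if isinstance(index, ast.List):
            lines.append(node.lineno)
    return sorted(lines)
```

An index built elsewhere and passed in as a variable, as in `x[idx]`, would not be caught by this check.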

Many other functions are context dependent: some are non-deterministic only for certain arguments, and others only in certain invocation contexts. This does add some complexity, but we can do crude matching for a first pass.

Progress:

/Users/may/Projects/minnervva/minnervva/venv/bin/python /Users/may/Projects/minnervva/minnervva/minnervva/minnervva.py -v . 
Linting directory: .
Linting file: basic_pytorch.py
Violations found in basic_pytorch.py:
  7: kthvalue
Done

However, it’s not picking up the second violation in that file, which I copied directly from the documentation: torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward()

Just for giggles, maybe look into the ast.walk() function, as it may do a better job of finding these function calls. But it’s late, so I’m going to turn in. Hopefully I can make progress on this tomorrow.
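The missed line is a chain of nested calls: the outermost call is `backward()`, and the `AvgPool3d(1)` constructor call sits several levels down, as the callee of another call. `ast.walk()` yields every node in the tree regardless of nesting depth, which is why it may help here. The sketch below is not a diagnosis of the actual bug; `called_names()` and the `FORBIDDEN` subset are hypothetical.

```python
import ast

# Hypothetical subset of function names to flag.
FORBIDDEN = {"kthvalue", "AvgPool3d"}


def called_names(source):
    """Return the name of every call target found anywhere in the source."""
    names = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Call):
            func = node.func
            if isinstance(func, ast.Attribute):
                names.append(func.attr)
            elif isinstance(func, ast.Name):
                names.append(func.id)
            # A call whose callee is itself a call, such as AvgPool3d(1)(...),
            # has no name of its own; the inner call is reached separately.
    return names


line = ("torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True)"
        ".cuda()).sum().backward()")
found = [name for name in called_names(line) if name in FORBIDDEN]
```

Here `found` contains `AvgPool3d`, alongside the unflagged `backward`, `sum`, `cuda`, and `randn` calls that `called_names()` also sees.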

markcoletti commented 6 months ago

It's working:

minnervva.py -v . 
Linting directory: /Users/may/Projects/minnervva/minnervva/tests
Linting file: basic_pytorch.py
Non-deterministic function found in line 11: kthvalue
Non-deterministic function found in line 13: AvgPool3d
Done

Well, it only does simple pattern matching against the forbidden function names, nothing sophisticated. There are more clever things we should be doing, too, such as checking whether certain functions are called with arguments that lead to non-deterministic behavior.

However, this is a good enough first cut to close this branch. I'll create separate, more focused tasks for each of the functions that need special attention.