starsimhub / starsim

Starsim disease modeling framework
http://starsim.org
MIT License

Fix gotcha of incorrectly named method #600

Open cliffckerr opened 2 months ago

cliffckerr commented 2 months ago

Not sure if there's a Pythonic way to do this, but one thing that's really annoying is that a slightly mistyped class method will be silently ignored, e.g.:

class NastyCough(ss.Disease):

  def update(self):
    recovered = (self.infected & (self.ti_recovered <= self.sim.ti)).uids
    ...

But oops, it's called update_pre(). What about a decorator that raises an exception if the method isn't run at the expected time (e.g. init, run, post)?

class NastyCough(ss.Disease):

  @ss.expect_when('run')
  def update(self):
    recovered = (self.infected & (self.ti_recovered <= self.sim.ti)).uids
    ...

Then we just have to worry about people mistyping the method name and forgetting to add the decorator 🙃 Is this just too annoying / too much overhead?
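A minimal sketch of how such a decorator could work, written without Starsim so it runs on its own. Everything here is an assumption for illustration: `expect_when`, the `Module` stand-in base class, `check_expected`, and the `_called` set are hypothetical names, not part of the Starsim API.

```python
import functools

def expect_when(stage):
    """Hypothetical decorator: mark a method as one that must be called during `stage`."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            self._called.add(func.__name__)  # Record that the method actually ran
            return func(self, *args, **kwargs)
        wrapper._expected_when = stage
        return wrapper
    return decorator

class Module:
    """Stand-in for a framework base class; not the real ss.Disease."""

    def __init__(self):
        self._called = set()

    def update_pre(self):
        pass  # Default no-op, which is why a mistyped override goes unnoticed

    def run(self):
        self.update_pre()  # The framework only ever calls update_pre()
        self.check_expected('run')

    def check_expected(self, stage):
        """Raise if a method marked for `stage` was never called."""
        for name in dir(type(self)):
            method = getattr(type(self), name, None)
            if getattr(method, '_expected_when', None) == stage and name not in self._called:
                raise RuntimeError(
                    f"{type(self).__name__}.{name}() was expected during '{stage}' but was never called"
                )

class NastyCough(Module):
    @expect_when('run')
    def update(self):  # Mistyped: the framework calls update_pre()
        pass

class FixedCough(Module):
    @expect_when('run')
    def update_pre(self):
        pass

FixedCough().run()  # Passes silently
try:
    NastyCough().run()
except RuntimeError as err:
    print(err)
```

The check only fires at the end of the stage, so the mistake surfaces on the first run rather than at class definition. As noted above, it still relies on the author remembering to add the decorator.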

cliffckerr commented 1 month ago

I wonder if with a decorator we could also get around the need to call super().every_method_name() -- I don't really want to add an extra line of boilerplate, but if we replace one line of boilerplate (super) with another (a decorator), it might be ok.
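One way this might look, as a rough sketch: a descriptor that uses `__set_name__` to learn which class it was defined in, then calls the parent's method of the same name before the decorated one. The name `call_parent_first` and the `Base`/`Child` classes are hypothetical and only illustrate the idea.

```python
class call_parent_first:
    """Hypothetical decorator: run the parent's method of the same name first."""

    def __init__(self, func):
        self.func = func

    def __set_name__(self, owner, name):
        # Called automatically when the class body is evaluated
        self.owner = owner
        self.name = name

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self  # Accessed on the class rather than an instance

        def bound(*args, **kwargs):
            # Equivalent to calling super().<name>(...) inside the method body
            parent_method = getattr(super(self.owner, obj), self.name)
            parent_method(*args, **kwargs)
            return self.func(obj, *args, **kwargs)

        return bound

class Base:
    def __init__(self):
        self.log = []

    def step(self):
        self.log.append('Base.step')

class Child(Base):
    @call_parent_first
    def step(self):  # No explicit super().step() needed
        self.log.append('Child.step')

child = Child()
child.step()
print(child.log)  # ['Base.step', 'Child.step']
```

This does trade one line of boilerplate for another, and it fixes the order (parent first), which an explicit `super()` call leaves up to the author.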

cliffckerr commented 1 month ago

Something like this? (Thanks ChatGPT!)

import functools

# Dictionary to store the call counts for methods
call_counts = {}

def call_counter(func, callclass='noclass'):
    @functools.wraps(func)  # Preserve the wrapped method's name and docstring
    def wrapper(*args, **kwargs):
        method_name = callclass + '_' + func.__name__
        if method_name not in call_counts:
            call_counts[method_name] = 0
        call_counts[method_name] += 1
        print(f"{method_name} has been called {call_counts[method_name]} times")
        return func(*args, **kwargs)

    return wrapper

class AutoDecorated:
    def __init_subclass__(cls, **kwargs):
        # Iterate over all class attributes
        for attr_name, attr_value in cls.__dict__.items():
            # If the attribute is a function, wrap it with call_counter
            if callable(attr_value) and not attr_name.startswith("__"):
                setattr(cls, attr_name, call_counter(attr_value, cls.__name__))
        super().__init_subclass__(**kwargs)

# Example class that will automatically have its methods decorated
class MyClass(AutoDecorated):
    def method_one(self):
        print("Executing method_one")

    def method_two(self):
        print("Executing method_two")

# Example usage
obj = MyClass()
obj.method_one()  # Prints "MyClass_method_one has been called 1 times", then "Executing method_one"
obj.method_one()  # Prints "MyClass_method_one has been called 2 times", then "Executing method_one"
obj.method_two()  # Prints "MyClass_method_two has been called 1 times", then "Executing method_two"

# Inspect the call_counts dictionary
print(call_counts)  # Output: {'MyClass_method_one': 2, 'MyClass_method_two': 1}

cliffckerr commented 1 month ago

Implemented here: https://github.com/starsimhub/starsim/tree/add-calldebug

cliffckerr commented 1 month ago

Re-implemented (sort of) here: https://github.com/sciris/sciris/issues/606

cliffckerr commented 1 month ago

Yes, I think a decorator is a good idea. I'm leaning towards ss.required_on(), which can also catch (or replace?) super() not being called. Maybe?

cliffckerr commented 1 month ago

Update: should work (thanks again ChatGPT!)

import inspect
import pytest

def immutable(method):
    method._immutable = True
    return method

def check_immutable_methods(cls):
    # Get all methods from the base classes that are marked as immutable
    immutable_methods = {}
    for base in inspect.getmro(cls)[1:]:  # Iterate through base classes, excluding cls itself
        for name, method in base.__dict__.items():
            if callable(method) and hasattr(method, '_immutable'):
                immutable_methods[name] = method

    # Check if any method in the derived class overrides an immutable method
    for name, method in cls.__dict__.items():
        if name in immutable_methods and callable(method):
            raise Exception(f"Method '{name}' in class '{cls.__name__}' is immutable and cannot be redefined")

    return cls

@check_immutable_methods
class MyClass:
    def foo(self):
        return 'foo'

    @immutable
    def bar(self):
        return 'bar'

@check_immutable_methods
class DerivedClass1(MyClass):
    def foo(self):
        return 'ok'

# Test cases
dc1 = DerivedClass1()
print(dc1.foo())  # Valid: foo is not marked immutable

# The check runs at class-definition time, so the exception is raised
# when DerivedClass2 is decorated, not when it is instantiated
with pytest.raises(Exception):
    @check_immutable_methods
    class DerivedClass2(MyClass):
        def bar(self):
            return 'not ok'

devclinton commented 1 month ago

I personally love decorators. I find them a great way to add a lot of power to code quickly. I think this is the best approach short of moving to something heavier like pluggy.
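As a possible variation on the immutable-method check earlier in the thread, the two ideas (a marker decorator plus `__init_subclass__`) could be combined so that subclasses do not need their own class decorator. This is only a sketch: `ImmutableChecked` and the example classes are hypothetical names, and raising `TypeError` is an arbitrary choice.

```python
def immutable(method):
    """Mark a method as one that subclasses must not override."""
    method._immutable = True
    return method

class ImmutableChecked:
    """Hypothetical base class: checks every subclass when it is defined."""

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        for base in cls.__mro__[1:]:  # All ancestors, excluding cls itself
            for name, attr in vars(base).items():
                if getattr(attr, '_immutable', False) and name in vars(cls):
                    raise TypeError(
                        f"Method '{name}' in class '{cls.__name__}' is immutable and cannot be redefined"
                    )

class MyClass(ImmutableChecked):
    def foo(self):
        return 'foo'

    @immutable
    def bar(self):
        return 'bar'

class DerivedClass1(MyClass):  # Fine: foo is not marked immutable
    def foo(self):
        return 'ok'

try:
    class DerivedClass2(MyClass):  # Raises while the class is being defined
        def bar(self):
            return 'not ok'
except TypeError as err:
    print(err)
```

Because the check lives in the base class, a user who forgets the decorator on their subclass is still covered.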

daniel-klein commented 1 month ago

very cool, this will be helpful for users like myself who often forget to rename a method.