tortoise / tortoise-orm

Familiar asyncio ORM for Python, built with relations in mind
https://tortoise.github.io
Apache License 2.0
4.5k stars 369 forks source link

m2m_changed signal django analog #1406

Open Primobolancode opened 1 year ago

Primobolancode commented 1 year ago

I'm porting my project from Django to FastAPI + Tortoise ORM, and all of its logic lives in signals. I need to catch the event of adding an object to an m2m (many-to-many) field. I'm looking for an analog of Django's m2m_changed signal, but I can't find one among Tortoise ORM's signals. Are there any workarounds?

bdaene commented 7 months ago

Same problem here: I want to update a list of players on a web page when a player is added to a tournament, and multiple pages may be viewing the same tournament at the same time.

I would also need a way to remove a listener; otherwise we would get memory leaks or calls to objects that have already been removed.

Here is a workaround I built for my needs. Other signals could be handled the same way.

import asyncio
from typing import Callable, Optional, Tuple, Type, Dict, List

from tortoise import Tortoise, fields, run_async, ConfigurationError, BaseDBAsyncClient, models
from tortoise.models import MODEL
from tortoise.signals import Signals

# Should be in tortoise.fields.relational ##############################################################################

class ManyToManyRelation(fields.ManyToManyRelation[MODEL]):
    """
    Many-to-many relation manager that dispatches signals when objects are added.

    Listeners live in a class-level registry keyed first by signal and then by
    ``(owner model class, m2m field name)``. Only ``post_save`` is dispatched,
    from :meth:`add`. The other signal keys exist so that registering for them
    does not raise ``KeyError``.
    """

    # Registry shared by every relation instance:
    # signal -> (owner model class, field name) -> list of listeners.
    _listeners: Dict[Signals, Dict[Tuple[Type[MODEL], str], List[Callable]]] = {  # type: ignore
        Signals.pre_save: {},
        Signals.post_save: {},
        Signals.pre_delete: {},
        Signals.post_delete: {},
    }

    @classmethod
    def register_listener(cls, sender: Tuple[Type[MODEL], str], signal: Signals, listener: Callable):
        """
        Register a listener for a signal on one many-to-many field.

        Registering the same listener twice for the same sender and signal is a no-op.

        :param sender: tuple of (model class owning the m2m field, field name)
        :param signal: one of tortoise.signals.Signals
        :param listener: callable listener

        :raises ConfigurationError: When listener is not callable
        """
        if not callable(listener):
            raise ConfigurationError("Signal listener must be callable!")
        cls_listeners = cls._listeners[signal].setdefault(sender, [])
        if listener not in cls_listeners:
            cls_listeners.append(listener)

    @classmethod
    def unregister_listener(cls, sender: Tuple[Type[MODEL], str], signal: Signals, listener: Callable):
        """
        Unregister a listener previously added with ``register_listener``.

        :param sender: tuple of (model class owning the m2m field, field name)
        :param signal: one of tortoise.signals.Signals
        :param listener: callable listener

        :raises KeyError: When nothing was ever registered for this signal and sender
        :raises ValueError: When listener is not currently registered for this sender
        """
        cls_listeners = cls._listeners[signal][sender]
        cls_listeners.remove(listener)

    async def _post_save(
            self,
            instances: Tuple[MODEL, ...],
            using_db: Optional[BaseDBAsyncClient] = None,
    ) -> None:
        """
        Run every post_save listener registered for this relation's field.

        Each listener is called as ``listener(sender, relation, instances, using_db)``,
        and the resulting awaitables are run concurrently with ``asyncio.gather``.

        :param instances: the objects that were just added to the relation
        :param using_db: connection used for the add; defaults to the remote model's db
        """
        # The sender identifies the field from the owning side, e.g. (Event, 'participants').
        sender = (self.instance.__class__, self.field.model_field_name)
        using_db = using_db or self.remote_model._meta.db
        cls_listeners = self._listeners.get(Signals.post_save, {}).get(sender, [])
        await asyncio.gather(*[listener(sender, self, instances, using_db) for listener in cls_listeners])

    async def add(self, *instances: MODEL, using_db: "Optional[BaseDBAsyncClient]" = None) -> None:
        """
        Add instances to the relation, then dispatch post_save listeners.

        Listeners run only after the parent ``add`` has completed.

        :param instances: model instances to link
        :param using_db: optional connection to use for the add
        """
        await super().add(*instances, using_db=using_db)
        await self._post_save(instances, using_db=using_db)

# End of tortoise.fields.relational workaround #########################################################################

# Should be in tortoise.models #########################################################################################

# Monkeypatch: swap tortoise.models' ManyToManyRelation name for the signal-aware
# subclass above. This presumably works because tortoise.models builds m2m
# managers by looking this name up at module level. Verify against the installed
# tortoise version.
setattr(models, "ManyToManyRelation", ManyToManyRelation)

class Model(models.Model):
    """Project base model that adds ``unregister_listener`` to tortoise's Model."""

    @classmethod
    def unregister_listener(cls, signal: Signals, listener: Callable):
        """
        Unregister a listener from the current model class for the given Signal.

        :param signal: one of tortoise.signals.Signals
        :param listener: callable listener

        :raises KeyError: When no listener was ever registered on this class for the signal
        :raises ValueError: When listener is not currently registered on this class
        """
        # Mirrors how tortoise stores listeners: _listeners[signal][model class] -> list.
        cls_listeners = cls._listeners[signal][cls]
        cls_listeners.remove(listener)

    class Meta:
        # Tortoise.init(modules=...) registers every non-abstract Model subclass found
        # in the listed modules. Without this flag, this helper base class would be
        # created as its own table.
        abstract = True

# End of tortoise.models workaround ####################################################################################

# Should be in tortoise.signals ########################################################################################

def post_save(*senders) -> Callable:
    """
    Register the decorated function as a post_save listener for the given senders.

    :param senders: model classes and/or ``(model class, m2m field name)`` tuples.
        Model classes register through their own ``register_listener``. Tuples
        register on the patched ManyToManyRelation, so the listener fires after
        ``relation.add(...)``.
    :return: a decorator that registers the function and returns it unchanged
    """

    def decorator(f):
        for sender in senders:
            if isinstance(sender, tuple):
                # A (model, field) tuple identifies a many-to-many field, not a model.
                ManyToManyRelation.register_listener(sender, Signals.post_save, f)
            else:
                sender.register_listener(Signals.post_save, f)
        return f

    return decorator

# End of tortoise.signals workaround ###################################################################################

class Event(Model):
    """Demo model: an event with a many-to-many set of participating teams."""
    name = fields.TextField()
    # Forward side of the m2m; related_name="events" gives Team its reverse accessor.
    # Adding to this relation is what triggers the (Event, 'participants') listeners.
    participants: fields.ManyToManyRelation["Team"] = fields.ManyToManyField("models.Team", related_name="events")

class Team(Model):
    """Demo model: a team that can take part in events."""
    name = fields.TextField()
    # Annotation only: the reverse accessor created by Event.participants (related_name="events").
    events: fields.ManyToManyRelation[Event]

@post_save(Event, Team, (Event, 'participants'))
async def save_signal_received(sender, instance, *args):
    """Demo listener: print every post_save notification it receives."""
    payload = (sender, instance) + args
    print('Saved:', *payload)

async def run():
    """
    Demonstrate listener registration and removal.

    Listeners fire for model saves and m2m ``add``. After unregistering, the
    Team and m2m listeners stay silent. Uses an in-memory SQLite database.
    """
    await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]})
    await Tortoise.generate_schemas()

    event = await Event.create(name="Event#1")  # Saved: Event event True database None
    team = await Team.create(name="Team#1")  # Saved: Team team True database None
    # The m2m sender is (owner model, field name); the owner of `participants` is Event.
    await event.participants.add(team)  # Saved: (Event, 'participants') event.participants (team,) database

    # Only Team's model listener and the m2m listener are removed; Event's stays registered.
    Team.unregister_listener(Signals.post_save, save_signal_received)
    ManyToManyRelation.unregister_listener((Event, 'participants'), Signals.post_save, save_signal_received)

    event = await Event.create(name="Event#2")  # Saved: Event event True database None
    team = await Team.create(name="Team#2")  # Nothing
    await event.participants.add(team)  # Nothing

if __name__ == "__main__":
    run_async(run())