pytransitions / transitions

A lightweight, object-oriented finite state machine implementation in Python with many extensions
MIT License
5.49k stars · 524 forks · source link

AsyncMachine: Improve task management #474

Closed — aleneum closed this issue 3 years ago

aleneum commented 3 years ago

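I created a branch, `dev-async-tasks`, to investigate how task (cancellation) management could be improved. By introducing `AsyncMachine.protected_tasks` (see commit 0f985e5), issue #465 by @AxelVoitier can be solved rather conveniently. Tasks registered in `protected_tasks` are meant to be exempt from the cancellation that `AsyncMachine` applies when a new trigger is processed, so the main loop below is not cancelled when another thread fires `B_to_C`: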

import asyncio
import logging
import time
from threading import Thread

from transitions.extensions.asyncio import AsyncMachine

# Root logging configuration for this demo script. DEBUG gives the most
# detailed output; use level=logging.INFO instead for a quieter run.
logging.basicConfig(level=logging.DEBUG)
# Module-level logger used by MyMachine.do_something and the demo coroutines.
_logger = logging.getLogger(__name__)

class MyMachine(AsyncMachine):
    """Four-state asynchronous machine used to demonstrate task protection.

    States are chained A -> B -> C -> D through the triggers ``A_to_B``,
    ``B_to_C`` and ``C_to_D``; ``reset`` returns to A from any state.
    Every transition awaits :meth:`do_something` as its ``before`` callback.
    """

    STATES = ['A', 'B', 'C', 'D']

    TRANSITIONS = [
        {'trigger': 'A_to_B', 'source': 'A', 'before': 'do_something', 'dest': 'B'},
        {'trigger': 'B_to_C', 'source': 'B', 'before': 'do_something', 'dest': 'C'},
        {'trigger': 'C_to_D', 'source': 'C', 'before': 'do_something', 'dest': 'D'},
        {'trigger': 'reset', 'source': '*', 'before': 'do_something', 'dest': 'A'},
    ]

    def __init__(self):
        """Build the machine from the class-level state/transition tables.

        Starts in state A; auto-generated ``to_<state>`` triggers are disabled
        so only the triggers listed in ``TRANSITIONS`` exist.
        """
        config = {
            'states': self.STATES,
            'transitions': self.TRANSITIONS,
            'initial': 'A',
            'auto_transitions': False,
        }
        super().__init__(**config)

    async def do_something(self):
        """Simulate 0.2 s of asynchronous work, logging before and after it."""
        _logger.info('Do something from state %s', self.state)
        await asyncio.sleep(0.2)
        _logger.info('Do something from state %s finished', self.state)

# Event loop that aio_main() runs on; assigned by aio_main() once it starts so
# that external_trigger() (running in another thread) can schedule work on it.
loop: "asyncio.AbstractEventLoop | None" = None
# Single shared machine instance, driven both by aio_main() and by the
# external trigger thread.
machine = MyMachine()

async def aio_main():
    """Drive ``machine`` from state A to state D.

    A -> B and C -> D are triggered here. B -> C is left to
    ``external_trigger`` running in another thread, so while the machine
    sits in B this coroutine polls every 0.1 s.

    The current task is added to ``machine.protected_tasks`` for the duration
    of the run and removed again on exit. This is intended to exempt it from
    AsyncMachine's task cancellation when another trigger fires; confirm the
    exact semantics in the transitions docs. If the task is cancelled anyway,
    the machine is reset to A and the cancellation is not re-raised.
    """
    global loop
    # Inside a coroutine, get_running_loop() is the recommended way to obtain
    # the loop. Publish it so external_trigger() can schedule work on it.
    loop = asyncio.get_running_loop()

    task = asyncio.current_task()
    machine.protected_tasks.append(task)
    try:
        while machine.state != 'D':
            if machine.state == 'A':
                await machine.A_to_B()
                continue

            if machine.state == 'C':
                await machine.C_to_D()
                continue

            # Remaining state is B: wait for the external thread to fire B_to_C.
            await asyncio.sleep(0.1)

    except asyncio.CancelledError:
        _logger.info('Got cancelled')
        await machine.reset()
    finally:
        # Do not leave a finished task referenced by the long-lived machine.
        # Membership is checked in case it was already removed elsewhere.
        if task in machine.protected_tasks:
            machine.protected_tasks.remove(task)

def external_trigger():
    """Fire ``B_to_C`` on the asyncio loop from this (non-loop) thread.

    Sleeps one second to give ``aio_main`` time to start and publish
    ``loop``. It then schedules the trigger with
    ``asyncio.run_coroutine_threadsafe`` and blocks until it completes;
    ``Future.result()`` re-raises any exception raised by the coroutine.

    Raises:
        RuntimeError: if ``loop`` has still not been published after the delay.
    """
    time.sleep(1)

    # Read the shared global once; aio_main() assigns it on the loop thread.
    target_loop = loop
    if target_loop is None:
        # Without this guard, run_coroutine_threadsafe(call(), None) fails with
        # an opaque AttributeError and leaves the call() coroutine never awaited.
        raise RuntimeError('event loop not running yet; cannot schedule B_to_C')

    async def call():
        _logger.info('Calling B_to_C')
        result = await machine.B_to_C()
        _logger.info('Trigger done (%s), state is %s', result, machine.state)
        assert result
        assert machine.state == 'C'

    fut = asyncio.run_coroutine_threadsafe(call(), target_loop)
    fut.result()

if __name__ == '__main__':
    # The trigger thread sleeps about one second before scheduling B_to_C on
    # the loop that asyncio.run() creates below; the demo relies on that delay.
    trigger_thread = Thread(target=external_trigger)
    trigger_thread.start()
    asyncio.run(aio_main(), debug=True)
aleneum commented 3 years ago

Closing this due to lack of feedback. `protected_tasks` has been merged and is included in release 0.8.4.