Issue (status: Open), opened by gigaverse-oz 1 week ago
Thanks for the bug report @gigaverse-oz! Can you provide some example code where you see this error? This simple flow works for me:
from prefect import flow, task


@task
def delay_task():
    # log_prints=True on the flow forwards this print to the Prefect logger.
    print("Hello")


@flow(log_prints=True)
def delay_flow():
    # Dispatch the task to a background task worker instead of running it inline.
    delay_task.delay()


if __name__ == "__main__":
    delay_flow()
so I suspect there are some additional variables at play causing this issue.
Hi @desertaxle, Thanks for the response.
The setup is as follows:
I have a pod running the Prefect server. I serve multiple flows in different pods with PREFECT_API_URL="http://localhost:4200/api":
def get_current_iso_time():
    """
    Return the current time as a timezone-aware datetime in UTC.

    Note: despite the name, this returns a ``datetime`` object, not an
    ISO-formatted string.

    Returns:
        datetime: The current datetime with timezone set to UTC.
    """
    return datetime.now(timezone.utc)
class PrefectFlowInputBase(BaseModel):
    """
    Base model for all Prefect flow input models.

    Attributes:
        timestamp (Optional[datetime]): When the input was created.
            Defaults to the current UTC time.
    """

    timestamp: Optional[datetime] = Field(default_factory=get_current_iso_time)


class EndStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for the end-of-stream Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the flow should process.
        recording_url (Optional[str]): URL of the recording, if available.
    """

    channel_id: str
    recording_url: Optional[str] = Field(None, description="URL of the recording, if available")


class StartStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for the start-of-stream Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the flow should process.
    """

    channel_id: str


class SnapshotFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """

    channel_id: str


class TopicPollFlowInput(PrefectFlowInputBase):
    """
    Input model for a topic polling Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the topic poll flow should process.
    """

    channel_id: str
@task(log_prints=False, persist_result=False)
async def topic_poll_task(flow_input: TopicPollFlowInput) -> bool:
    """
    Orchestrate the topic poll workflow for a given channel, initialize the worker
    to create a topic poll, and schedule a new task if a livestream is active.

    Args:
        flow_input (TopicPollFlowInput): Carries the channel_id of the channel to
            poll and the timestamp at which the topic poll flow started.

    Returns:
        bool: True if a new topic poll task was successfully scheduled; False otherwise.
    """
    # SOME PROCESSING, WE NEVER REALLY GET HERE WHEN USING DELAY SO PUT WHAT EVER YOU WANT
    return True


@flow(log_prints=False, persist_result=False)
async def topic_poll_flow(flow_input: TopicPollFlowInput) -> bool:
    """Run the topic poll task for the channel named in ``flow_input``."""
    logger.info(f"Starting topic poll flow: {flow_input.channel_id}. PID: {os.getpid()}")
    await topic_poll_task(flow_input=flow_input)  # THIS WORKS, also topic_poll_task.submit(), but in the same worker.
    await topic_poll_task.delay(flow_input=flow_input)  # FAILS WITH MULTIPLE VARIATIONS (with await, without await) DOESNT WORK
    return True
@flow
async def start_stream_flow(flow_input: StartStreamFlowInput):
    """Launch a topic poll flow for the started stream and wait for it to finish.

    Raises:
        Exception: If any launched flow is still pending after the 600s timeout.
    """
    logger.info(f"Stream {flow_input.channel_id} started. {os.getpid()}")
    list_of_flows = []
    # for i in range(10):
    list_of_flows.append(
        asyncio.create_task(
            topic_poll_flow(TopicPollFlowInput(**flow_input.model_dump())),
            name=topic_poll_flow.__name__,
        )
    )
    # )
    done, pending = await asyncio.wait(list_of_flows, timeout=600)
    if pending:
        raise Exception("Not all tasks are finished")
    # Loop variable renamed from `task` to avoid shadowing the imported
    # prefect `task` decorator.
    for finished in done:
        if finished.exception():
            logger.error(f"{finished.get_name()} failed: {finished.exception()}")
            continue
        logger.info(f"{finished.get_name()} finished successfully")
def serve_multiple_flows(list_of_flows: List[Flow], concurrent_limit: int = 10):
    """
    Serve the given flows as deployments with a shared concurrency limit.

    Args:
        list_of_flows (List[Flow]): Flows to turn into deployments and serve.
        concurrent_limit (int): Maximum number of concurrent flow runs.
    """
    # `served_flow` (not `flow`) so the prefect `flow` decorator is not shadowed.
    list_of_deployments = [served_flow.to_deployment(name=served_flow.name) for served_flow in list_of_flows]
    serve(*list_of_deployments, limit=concurrent_limit)


if __name__ == "__main__":
    list_of_served_flows = [start_stream_flow, topic_poll_flow]
    serve_multiple_flows(list_of_served_flows, concurrent_limit=10)
I copied the relevant code snippets and functions from multiple files, so this is not working code in a single file.
Thanks for the example @gigaverse-oz! Unfortunately, I wasn't able to reproduce the issue with your example.
Here's the code that I ran:
import asyncio
from datetime import datetime, timezone
import os
from typing import List, Optional
from prefect import Flow, flow, serve, task
from prefect.logging import get_logger
from pydantic import BaseModel, Field
logger = get_logger(__name__)
def get_current_iso_time():
    """
    Return the current time as a timezone-aware datetime in UTC.

    Note: despite the name, this returns a ``datetime`` object, not an
    ISO-formatted string.

    Returns:
        datetime: The current datetime with timezone set to UTC.
    """
    return datetime.now(timezone.utc)
class PrefectFlowInputBase(BaseModel):
    """
    Base model for all Prefect flow input models.

    Attributes:
        timestamp (Optional[datetime]): When the input was created.
            Defaults to the current UTC time.
    """

    timestamp: Optional[datetime] = Field(default_factory=get_current_iso_time)


class EndStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for the end-of-stream Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the flow should process.
        recording_url (Optional[str]): URL of the recording, if available.
    """

    channel_id: str
    recording_url: Optional[str] = Field(
        None, description="URL of the recording, if available"
    )


class StartStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for the start-of-stream Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the flow should process.
    """

    channel_id: str


class SnapshotFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """

    channel_id: str


class TopicPollFlowInput(PrefectFlowInputBase):
    """
    Input model for a topic polling Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the topic poll flow should process.
    """

    channel_id: str
@task(log_prints=False, persist_result=False)
async def topic_poll_task(flow_input: TopicPollFlowInput) -> bool:
    """
    Orchestrate the topic poll workflow for a given channel, initialize the worker
    to create a topic poll, and schedule a new task if a livestream is active.

    Args:
        flow_input (TopicPollFlowInput): Carries the channel_id of the channel to
            poll and the timestamp at which the topic poll flow started.

    Returns:
        bool: True if a new topic poll task was successfully scheduled; False otherwise.
    """
    # SOME PROCESSING, WE NEVER REALLY GET HERE WHEN USING DELAY SO PUT WHAT EVER YOU WANT
    return True


@flow(log_prints=False, persist_result=False)
async def topic_poll_flow(flow_input: TopicPollFlowInput) -> bool:
    """Dispatch the topic poll task for the channel named in ``flow_input``."""
    logger.info(
        f"Starting topic poll flow: {flow_input.channel_id}. PID: {os.getpid()}"
    )
    # Schedule the task on a background task worker.
    topic_poll_task.delay(flow_input=flow_input)
    return True
@flow
async def start_stream_flow(flow_input: StartStreamFlowInput):
    """Launch a topic poll flow for the started stream and wait for it to finish.

    Raises:
        Exception: If any launched flow is still pending after the 600s timeout.
    """
    logger.info(f"Stream {flow_input.channel_id} started. {os.getpid()}")
    list_of_flows = []
    # for i in range(10):
    list_of_flows.append(
        asyncio.create_task(
            topic_poll_flow(TopicPollFlowInput(**flow_input.model_dump())),
            name=topic_poll_flow.__name__,
        )
    )
    # )
    done, pending = await asyncio.wait(list_of_flows, timeout=600)
    if pending:
        raise Exception("Not all tasks are finished")
    # Loop variable renamed from `task` to avoid shadowing the imported
    # prefect `task` decorator.
    for finished in done:
        if finished.exception():
            logger.error(f"{finished.get_name()} failed: {finished.exception()}")
            continue
        logger.info(f"{finished.get_name()} finished successfully")
def serve_multiple_flows(list_of_flows: List[Flow], concurrent_limit: int = 10):
    """
    Serve the given flows as deployments with a shared concurrency limit.

    Args:
        list_of_flows (List[Flow]): Flows to turn into deployments and serve.
        concurrent_limit (int): Maximum number of concurrent flow runs.
    """
    # `served_flow` (not `flow`) so the prefect `flow` decorator is not shadowed.
    list_of_deployments = [served_flow.to_deployment(name=served_flow.name) for served_flow in list_of_flows]
    serve(*list_of_deployments, limit=concurrent_limit)


if __name__ == "__main__":
    list_of_served_flows = [start_stream_flow, topic_poll_flow]
    serve_multiple_flows(list_of_served_flows, concurrent_limit=10)
Maybe one of your other flows cannot be pickled? You can check to see if a flow is picklable like this:
import cloudpickle
# Dump the flow with cloudpickle; an exception here means the flow is not picklable.
print(cloudpickle.dumps(start_stream_flow))
Hi @desertaxle ,
The code you run is fine and work on my machine as well.
In practice my code is spread across multiple files. After a long investigation, I found that the following file structure fails. Could you verify that on your machine as well?
BTW, calling cloudpickle.dumps() on the flow in main always works; if it is called inside the first flow, it fails. The serialization seems to be very sensitive to the structure of the files and imports.
Many thanks
# pydantic_models.py
from datetime import datetime, timezone
from typing import Optional

from pydantic import BaseModel, Field


def get_current_iso_time():
    """
    Return the current time as a timezone-aware datetime in UTC.

    Note: despite the name, this returns a ``datetime`` object, not an
    ISO-formatted string.

    Returns:
        datetime: The current datetime with timezone set to UTC.
    """
    return datetime.now(timezone.utc)


class PrefectFlowInputBase(BaseModel):
    """
    Base model for all Prefect flow input models.

    Attributes:
        timestamp (Optional[datetime]): When the input was created.
            Defaults to the current UTC time.
    """

    timestamp: Optional[datetime] = Field(default_factory=get_current_iso_time)


class EndStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for the end-of-stream Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the flow should process.
        recording_url (Optional[str]): URL of the recording, if available.
    """

    channel_id: str
    recording_url: Optional[str] = Field(
        None, description="URL of the recording, if available"
    )


class StartStreamFlowInput(PrefectFlowInputBase):
    """
    Input model for the start-of-stream Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the flow should process.
    """

    channel_id: str


class SnapshotFlowInput(PrefectFlowInputBase):
    """
    Input model for a snapshot-related Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the snapshot flow should process.
    """

    channel_id: str


class TopicPollFlowInput(PrefectFlowInputBase):
    """
    Input model for a topic polling Prefect flow.

    Attributes:
        channel_id (str): The ID of the channel the topic poll flow should process.
    """

    channel_id: str
# topic_poll_flow.py
import os  # was missing: os.getpid() below raised NameError without it

from loguru import logger
from prefect import flow, task

from gv.ai.livestream_prefect_worker.pydantic_models import TopicPollFlowInput


@task(log_prints=False, persist_result=False)
async def topic_poll_task(flow_input: TopicPollFlowInput) -> bool:
    """
    Orchestrate the topic poll workflow for a given channel, initialize the worker
    to create a topic poll, and schedule a new task if a livestream is active.

    Args:
        flow_input (TopicPollFlowInput): Carries the channel_id of the channel to
            poll and the timestamp at which the topic poll flow started.

    Returns:
        bool: True if a new topic poll task was successfully scheduled; False otherwise.
    """
    # SOME PROCESSING, WE NEVER REALLY GET HERE WHEN USING DELAY SO PUT WHAT EVER YOU WANT
    return True


@flow(log_prints=False, persist_result=False)
async def topic_poll_flow(flow_input: TopicPollFlowInput) -> bool:
    """Dispatch the topic poll task for the channel named in ``flow_input``."""
    logger.info(
        f"Starting topic poll flow: {flow_input.channel_id}. PID: {os.getpid()}"
    )
    # Schedule the task on a background task worker; this is the call that
    # fails to serialize the flow run context in this report.
    topic_poll_task.delay(flow_input=flow_input)
    return True
# main.py
import asyncio
import os
from typing import List, Optional

from prefect import Flow, flow, serve, task
from prefect.logging import get_logger

logger = get_logger(__name__)

from gv.ai.livestream_prefect_worker.pydantic_models import TopicPollFlowInput, StartStreamFlowInput
from gv.ai.livestream_prefect_worker.topic_poll_flow import topic_poll_task, topic_poll_flow
# from .pydantic_models import TopicPollFlowInput, StartStreamFlowInput # This import version also raises the error even if topic_poll_task, topic_poll_flow are in the main.py


@flow
async def start_stream_flow(flow_input: StartStreamFlowInput):
    """Launch a topic poll flow for the started stream and wait for it to finish.

    Raises:
        Exception: If any launched flow is still pending after the 600s timeout.
    """
    logger.info(f"Stream {flow_input.channel_id} started. {os.getpid()}")
    list_of_flows = []
    # for i in range(10):
    list_of_flows.append(
        asyncio.create_task(
            topic_poll_flow(flow_input=TopicPollFlowInput(**flow_input.model_dump())),
            name=topic_poll_flow.__name__,
        )
    )
    # )
    done, pending = await asyncio.wait(list_of_flows, timeout=600)
    if pending:
        raise Exception("Not all tasks are finished")
    # Loop variable renamed from `task` to avoid shadowing the imported
    # prefect `task` decorator.
    for finished in done:
        if finished.exception():
            logger.error(f"{finished.get_name()} failed: {finished.exception()}")
            continue
        logger.info(f"{finished.get_name()} finished successfully")


def serve_multiple_flows(list_of_flows: List[Flow], concurrent_limit: int = 10):
    """
    Serve the given flows as deployments with a shared concurrency limit.

    Args:
        list_of_flows (List[Flow]): Flows to turn into deployments and serve.
        concurrent_limit (int): Maximum number of concurrent flow runs.
    """
    # `served_flow` (not `flow`) so the prefect `flow` decorator is not shadowed.
    list_of_deployments = [served_flow.to_deployment(name=served_flow.name) for served_flow in list_of_flows]
    serve(*list_of_deployments, limit=concurrent_limit)


if __name__ == "__main__":
    list_of_served_flows = [start_stream_flow, topic_poll_flow]
    serve_multiple_flows(list_of_served_flows, concurrent_limit=10)
Bug summary
I'm trying to launch a task using delay() from a flow. When starting the task, the context passed to the task includes the flow, which is not a serializable object (by either pickle or JSON).
More debugging information is in the additional context.
Version info (`prefect version` output)

Additional context
I debugged the error and found that the following field was failing serialization:
`context["flow_run_context"]["flow"]`
which is of type `Flow`.
The real exception occurs in `serialize_result()`.
To double-check that this is the only problematic field, I changed the following in the EngineContext and serialization passed, but the task worker then fails because it needs the flow: