I created a TypeTransformer for a class (ParseableDocument) so that I could manually set its hash function (instead of relying on the default dataclass hash). In doing so, I observed that my task which returns ParseableDocuments was always running (i.e., always had a cache miss) despite its inputs not changing.
In testing this out I discovered that my issue was an incorrect get_literal_type() function on my TypeTransformer. After fixing that, caching works properly.
The code included below can be run multiple times with the same inputs; it reruns download_document() every time, even though each run reports that it is writing to the cache.
Expected behavior
Flyte should provide a warning or error stating that the get_literal_type() output does not match the actual Literal returned, or otherwise note that the cache could not be written properly.
Additional context to reproduce
import datetime
from flytekit import Secret, current_context, task, workflow, map_task
from flytekit.types.file import FlyteFile
from flytekit.extend import TypeTransformer, TypeEngine
from flytekit.core.context_manager import FlyteContext
from flytekit.models.types import LiteralType, SimpleType
from flytekit.models.literals import Literal, LiteralMap, Primitive, Scalar
from flytekit.types.file.file import FlyteFilePathTransformer
from typing import Type
from dataclasses import dataclass
from dataclasses_json import DataClassJsonMixin, dataclass_json
## Define DownloadableDocument with @dataclass.
## @dataclass generates __init__(self, document_id: int, url: str) from the
## annotated fields below, so no hand-written constructor is needed.
@dataclass_json
@dataclass
class DownloadableDocument(DataClassJsonMixin):
    # Integer identifier of the document.
    document_id: int
    # URL string associated with the document.
    url: str
## ParseableDocument is intentionally a plain class (no @dataclass). It is
## handled through a custom Transformer so that the way its hash is computed
## can be modified by hand (this example does not actually modify it).
class ParseableDocument:
    """A document identified by ``document_id`` and ``url`` with an attached file."""

    # Integer identifier of the document.
    document_id: int
    # URL string associated with the document.
    url: str
    # File handle attached to the document.
    file: FlyteFile

    def __init__(self, document_id: int, url: str, file: FlyteFile):
        """Store the three constructor arguments as same-named attributes."""
        self.document_id = document_id
        self.url = url
        self.file = file
class ParseableDocumentTransformer(TypeTransformer[ParseableDocument]):  # type: ignore
    """Type transformer converting ``ParseableDocument`` to and from a Flyte ``Literal``.

    A document is serialized as a literal map with the keys ``document_id``,
    ``url`` and ``file``; ``to_python_value`` rebuilds the object from such a map.
    """

    def __init__(self) -> None:
        super().__init__(name="parseable-document-transform", t=ParseableDocument)

    def get_literal_type(self, t: Type[ParseableDocument]) -> LiteralType:
        """
        Return the ``LiteralType`` declared for ``ParseableDocument``.

        NOTE: the declared type (a map of strings) deliberately does not match
        what ``to_literal`` produces (the ``file`` entry is built by the file
        transformer, not as a string). The mismatch is kept on purpose to
        demonstrate the strange caching behavior.
        """
        # Correct return value, left commented out so the mismatch stays in place
        # (enabling it would also require `os` and `UnionType` to be in scope):
        # variants: list[LiteralType] = [
        #     LiteralType(simple=SimpleType.STRING),
        #     FlyteFilePathTransformer().get_literal_type(os.PathLike)
        # ]
        # return LiteralType(map_value_type=LiteralType(union_type=UnionType(variants=variants)))
        string_type = LiteralType(simple=SimpleType.STRING)
        return LiteralType(map_value_type=string_type)

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: ParseableDocument,
        python_type: Type[ParseableDocument],
        expected: LiteralType,
    ) -> Literal:
        """
        Convert a ``ParseableDocument`` python object into its Literal representation.

        ``document_id`` and ``url`` are stored as string primitives (the id is
        passed through ``str()``); ``file`` is delegated to
        ``FlyteFilePathTransformer``. ``python_type`` is not used.
        """

        def string_literal(text: str) -> Literal:
            # Wrap a python string in Primitive -> Scalar -> Literal.
            return Literal(scalar=Scalar(primitive=Primitive(string_value=text)))

        id_literal = string_literal(str(python_val.document_id))
        url_literal = string_literal(python_val.url)
        # NOTE(review): `expected` is the literal type of the whole document,
        # forwarded unchanged to the file transformer - confirm that is intended.
        file_literal = FlyteFilePathTransformer().to_literal(
            ctx, python_val.file, FlyteFile, expected
        )
        entries = {
            "document_id": id_literal,
            "url": url_literal,
            "file": file_literal,
        }
        return Literal(map=LiteralMap(entries))

    def to_python_value(
        self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[ParseableDocument]
    ) -> ParseableDocument:
        """
        Re-hydrate a ``ParseableDocument`` from the Flyte Literal value ``lv``.

        Reads the ``document_id``, ``url`` and ``file`` entries of ``lv.map``;
        the id is converted back with ``int()`` and the file entry is handed to
        ``FlyteFilePathTransformer``. ``expected_python_type`` is not used.
        """
        fields = lv.map.literals
        document_id = int(fields["document_id"].scalar.primitive.value)
        url = fields["url"].scalar.primitive.value
        file = FlyteFilePathTransformer().to_python_value(ctx, fields["file"], FlyteFile)
        return ParseableDocument(document_id=document_id, url=url, file=file)
# Register an instance of the transformer with the TypeEngine (runs when this module is loaded).
TypeEngine.register(ParseableDocumentTransformer())
# Workflow: create `documents_to_create` documents via
# get_documents_to_download, then apply download_document to each of them
# through map_task and return the resulting list.
# `kickoff_time` is accepted as an input but is not used in the body.
@workflow
def cache_bug_example_transformer_minimal(
    kickoff_time: datetime.datetime,
    documents_to_create: int,
) -> list[ParseableDocument]:
    to_download: list[DownloadableDocument] = get_documents_to_download(
        documents_to_create=documents_to_create
    )
    mapped_download = map_task(download_document)
    return mapped_download(document=to_download)
# Task: build one DownloadableDocument per id in range(documents_to_create),
# each with the url "dummy/location", and return them as a list.
@task
def get_documents_to_download(documents_to_create: int) -> list[DownloadableDocument]:
    # Guard: refuse to create 100 or more documents.
    assert documents_to_create < 100
    return [
        DownloadableDocument(doc_id, "dummy/location")
        for doc_id in range(documents_to_create)
    ]
# Cached task (cache_version "1.0"): turn one DownloadableDocument into a
# ParseableDocument carrying the same id and url plus a fixed FlyteFile.
@task(cache=True, cache_version="1.0")
def download_document(document: DownloadableDocument) -> ParseableDocument:
    # Printed whenever the task body actually runs.
    message = "Running download for document with id {}".format(document.document_id)
    print(message)
    return ParseableDocument(
        document_id=document.document_id,
        url=document.url,
        file=FlyteFile("http://www.google.com"),
    )
Screenshots
No response
Are you sure this issue hasn't been raised already?
Describe the bug
Slack thread where I provided a bit of context on this issue, though in digging in I found the root cause of my issue: https://flyte-org.slack.com/archives/CP2HDHKE1/p1699461676401929
I created a TypeTransformer for a class (ParseableDocument) so that I could manually set its hash function (instead of relying on the default dataclass hash). In doing so, I observed that my task which returns ParseableDocuments was always running (i.e., always had a cache miss) despite its inputs not changing.
In testing this out I discovered that my issue was an incorrect get_literal_type() function on my TypeTransformer. After fixing that, caching works properly.
Code included below can be run multiple times with the same inputs, and will rerun download_document() every time despite saying it is writing to the cache.
Expected behavior
Should provide a warning or error stating that get_literal_type() output does not match the actual Literal returned, or otherwise note that cache could not be properly written to.
Additional context to reproduce
Screenshots
No response
Are you sure this issue hasn't been raised already?
Have you read the Code of Conduct?