FactoryBoy / factory_boy

A test fixtures replacement for Python
https://factoryboy.readthedocs.io/
MIT License
3.52k stars 397 forks source link

How to attach RelatedFactoryList result to instance? #1092

Open albertalexandrov opened 1 month ago

albertalexandrov commented 1 month ago

Hi!

I have a question about using RelatedFactoryList with async SQLAlchemy. RelatedFactoryList creates the related instances, but they are not attached to the parent instance.

Base factory overridden for async use (adapted from discussions in this repository):

import inspect

from factory.alchemy import SESSION_PERSISTENCE_COMMIT, SESSION_PERSISTENCE_FLUSH, SQLAlchemyModelFactory
from factory.base import FactoryOptions
from factory.builder import StepBuilder, BuildStep, parse_declarations
from factory import FactoryError, RelatedFactoryList, CREATE_STRATEGY
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError, NoResultFound

def use_postgeneration_results(self, step, instance, results):
    """Pass post-generation results to the factory's ``_after_postgeneration`` hook.

    The hook's return value is handed back to the caller rather than dropped,
    so a caller can inspect it (for example, to await it).

    Args:
        self: Options object exposing the owning factory as ``self.factory``.
        step: Build step; ``step.builder.strategy`` decides the ``create`` flag.
        instance: The object that was just generated.
        results: Mapping of post-declaration names to their results.

    Returns:
        Whatever ``self.factory._after_postgeneration`` returns.
    """
    hook = self.factory._after_postgeneration
    is_create = step.builder.strategy == CREATE_STRATEGY
    return hook(instance, create=is_create, results=results)

# Monkeypatch FactoryOptions with the function above, which returns the hook's
# result; AsyncStepBuilder.build awaits that result when it is awaitable.
FactoryOptions.use_postgeneration_results = use_postgeneration_results

class SQLAlchemyFactory(SQLAlchemyModelFactory):
    """Async counterpart of ``SQLAlchemyModelFactory``.

    The generation and persistence hooks are coroutines and every session call
    they make is awaited, so the configured session must expose awaitable
    ``execute``/``flush``/``commit``/``rollback`` methods
    (NOTE(review): presumably an ``AsyncSession`` — confirm).
    """

    @classmethod
    async def _generate(cls, strategy, params):
        """Build one instance through ``AsyncStepBuilder``.

        Args:
            strategy: Strategy passed on to the step builder.
            params: Declaration overrides supplied by the caller.

        Returns:
            The instance returned by ``AsyncStepBuilder.build``.

        Raises:
            FactoryError: If the factory is abstract.
        """
        if cls._meta.abstract:
            raise FactoryError(
                "Cannot generate instances of abstract factory %(f)s; "
                "Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract "
                "is either not set or False." % dict(f=cls.__name__)
            )

        # ``_get_or_create`` reads ``cls._original_params`` to retry the lookup
        # after an IntegrityError; record the caller's params here, otherwise
        # that fallback can never run.
        cls._original_params = params

        step = AsyncStepBuilder(cls._meta, params, strategy)
        return await step.build()

    @classmethod
    async def _create(cls, model_class, *args, **kwargs):
        """Resolve awaitable keyword values, then delegate to the parent ``_create``.

        Args:
            model_class: Model class to instantiate.
            *args: Positional constructor arguments.
            **kwargs: Keyword constructor arguments; awaitable values are
                awaited and replaced by their results before delegation.

        Returns:
            The awaited result of the parent class's ``_create``.
        """
        # Only existing keys are reassigned, so the dict size does not change
        # while it is being iterated.
        for key, value in kwargs.items():
            if inspect.isawaitable(value):
                kwargs[key] = await value
        return await super()._create(model_class, *args, **kwargs)

    @classmethod
    async def create_batch(cls, size, **kwargs):
        """Create ``size`` instances one after another.

        Args:
            size: Number of instances to create.
            **kwargs: Overrides forwarded to every ``create`` call.

        Returns:
            A list with the created instances, in creation order.
        """
        return [await cls.create(**kwargs) for _ in range(size)]

    @classmethod
    async def _save(cls, model_class, session, args, kwargs):
        """Instantiate the model, add it to the session and persist it as configured.

        Args:
            model_class: Model class to instantiate.
            session: Session the new object is added to.
            args: Positional constructor arguments.
            kwargs: Keyword constructor arguments.

        Returns:
            The newly constructed object.
        """
        session_persistence = cls._meta.sqlalchemy_session_persistence
        obj = model_class(*args, **kwargs)
        session.add(obj)
        # Any other persistence setting leaves the object pending in the session.
        if session_persistence == SESSION_PERSISTENCE_FLUSH:
            await session.flush()
        elif session_persistence == SESSION_PERSISTENCE_COMMIT:
            await session.commit()
        return obj

    @classmethod
    async def _get_or_create(cls, model_class, session, args, kwargs):
        """Return the row matching the ``sqlalchemy_get_or_create`` fields, creating it if absent.

        Args:
            model_class: Model class to query and, if needed, instantiate.
            session: Session used for the lookup and the save.
            args: Positional arguments for the lookup and the constructor.
            kwargs: Keyword arguments; the get-or-create fields are popped
                from this dict and used as the lookup filter.

        Returns:
            The existing or newly created object.

        Raises:
            FactoryError: If a get-or-create field has no value in ``kwargs``.
            IntegrityError: If saving fails and no row can be found from the
                caller's original params.
        """
        key_fields = {}
        for field in cls._meta.sqlalchemy_get_or_create:
            if field not in kwargs:
                raise FactoryError(
                    "sqlalchemy_get_or_create - "
                    "Unable to find initialization value for '%s' in factory %s" % (field, cls.__name__)
                )
            key_fields[field] = kwargs.pop(field)

        obj = (await session.execute(select(model_class).filter_by(*args, **key_fields))).scalars().one_or_none()

        if not obj:
            try:
                obj = await cls._save(model_class, session, args, {**key_fields, **kwargs})
            except IntegrityError as e:
                # Like the other session calls in this class, rollback must be
                # awaited; without the await the rollback never executes.
                await session.rollback()

                if cls._original_params is None:
                    raise e

                # Retry the lookup using only the get-or-create fields that the
                # caller passed explicitly.
                get_or_create_params = {
                    lookup: value
                    for lookup, value in cls._original_params.items()
                    if lookup in cls._meta.sqlalchemy_get_or_create
                }
                if get_or_create_params:
                    try:
                        obj = (
                            (await session.execute(select(model_class).filter_by(**get_or_create_params)))
                            .scalars()
                            .one()
                        )
                    except NoResultFound:
                        # Original params are not a valid lookup and triggered a create(),
                        # that resulted in an IntegrityError.
                        raise e
                else:
                    raise e

        return obj

class AsyncStepBuilder(StepBuilder):
    """``StepBuilder`` whose ``build`` awaits instantiation and awaitable post-generation results."""

    async def build(self, parent_step=None, force_sequence=None):
        """Build a factory instance.

        Args:
            parent_step: Enclosing build step, if any; passed to ``BuildStep``.
            force_sequence: Sequence number to use; when ``None`` the builder's
                ``force_init_sequence`` is used, and failing that the next
                sequence of the factory options.

        Returns:
            The instantiated object, after all post-declarations have run.
        """
        # TODO: Handle "batch build" natively
        meta = self.factory_meta
        pre, post = parse_declarations(
            self.extras,
            base_pre=meta.pre_declarations,
            base_post=meta.post_declarations,
        )

        # Sequence priority: explicit argument, then builder-level override,
        # then the next value from the factory options.
        sequence = force_sequence
        if sequence is None:
            sequence = self.force_init_sequence
        if sequence is None:
            sequence = meta.next_sequence()

        step = BuildStep(builder=self, sequence=sequence, parent_step=parent_step)
        step.resolve(pre)

        args, kwargs = meta.prepare_arguments(step.attributes)
        instance = await meta.instantiate(step=step, args=args, kwargs=kwargs)

        results = {}
        for name in post.sorted():
            entry = post[name]
            outcome = entry.declaration.evaluate_post(
                instance=instance,
                step=step,
                overrides=entry.context,
            )
            if inspect.isawaitable(outcome):
                outcome = await outcome
            if isinstance(entry.declaration, RelatedFactoryList):
                # Resolve awaitable items in place so the stored object is the
                # same sequence the declaration returned.
                for position, element in enumerate(outcome):
                    if inspect.isawaitable(element):
                        outcome[position] = await element
            results[name] = outcome

        pending = meta.use_postgeneration_results(
            instance=instance,
            step=step,
            results=results,
        )
        if inspect.isawaitable(pending):
            await pending
        return instance

models.py

class TtzFile(Base):
    """TTZ file model."""

    __tablename__ = "ttz_files"
    # NOTE(review): presumably enabled so server-generated values are loaded at
    # flush time rather than lazily — confirm.
    __mapper_args__ = {"eager_defaults": True}

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Foreign key to the owning Ttz row.
    ttz_id: Mapped[int] = mapped_column(ForeignKey("ttz.id"))
    attachment_id: Mapped[UUID] = mapped_column()
    # Parent object; the other side of Ttz.files.
    ttz: Mapped["Ttz"] = relationship(back_populates="files")

class Ttz(Base):
    """TTZ model; parent of the TtzFile rows exposed through ``files``."""

    __tablename__ = "ttz"

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String(250))
    # Child files; the other side of TtzFile.ttz, with "all, delete-orphan" cascade.
    files: Mapped[list["TtzFile"]] = relationship(cascade="all, delete-orphan", back_populates="ttz")

factories.py

class TtzFactory(SQLAlchemyFactory):
    """Factory for Ttz rows that also creates two related TtzFile rows."""

    name = Sequence(lambda n: f"ТТЗ {n + 1}")
    # NOTE(review): the following five fields are not declared on the Ttz model
    # shown in this thread; presumably the model excerpt is abridged — confirm.
    start_date = FuzzyDate(parse_date("2024-02-23"))
    is_deleted = False
    output_message = None
    input_message = None
    error_output_message = None
    # Post-generation: two TtzFileFactory objects, each receiving this
    # instance through its "ttz" attribute.
    files = RelatedFactoryList("tests.factories.ttz.TtzFileFactory", 'ttz', 2)

    class Meta:
        model = Ttz
        sqlalchemy_get_or_create = ["name"]
        sqlalchemy_session_factory = Session
        sqlalchemy_session_persistence = SESSION_PERSISTENCE_FLUSH

    @classmethod
    def _after_postgeneration(cls, instance, create, results=None):
        """Refresh ``instance.files`` after the post-declarations have run.

        Args:
            instance: The generated Ttz object.
            create: Whether the create strategy was used (unused here).
            results: Post-generation results (unused here).

        Returns:
            The result of ``session.refresh``; AsyncStepBuilder.build awaits it
            when it is awaitable.
        """
        # A session is obtained from the configured session factory on each call.
        session = cls._meta.sqlalchemy_session_factory()
        return session.refresh(instance, attribute_names=["files"])

class TtzFileFactory(SQLAlchemyFactory):
    """Factory for TtzFile rows; builds a parent Ttz through TtzFactory unless one is passed."""

    ttz = SubFactory(TtzFactory)
    # NOTE(review): file_name is not declared on the TtzFile model shown in
    # this thread; presumably the model excerpt is abridged — confirm.
    file_name = Faker("file_name")
    attachment_id = FuzzyUuid()

    class Meta:
        model = TtzFile
        sqlalchemy_get_or_create = ["attachment_id"]
        sqlalchemy_session_factory = Session
        sqlalchemy_session_persistence = SESSION_PERSISTENCE_FLUSH

To make Ttz.files available, I have to do a refresh:

@classmethod
def _after_postgeneration(cls, instance, create, results=None):
    # Reload only the "files" attribute of the generated instance, using a
    # session obtained from the factory's configured session factory.
    session = cls._meta.sqlalchemy_session_factory()
    return session.refresh(instance, attribute_names=["files"])

My question is: is this the only way to get Ttz.files? I mean, do I have to write an _after_postgeneration method in each factory where I need to get the related list?

rbarrois commented 1 month ago

Thanks for providing the full code example.

It is, however, quite complex to read without prior knowledge of your project.

By default, with a RelatedFactoryList, the behaviour is akin to:

# The parent is created and added first; each file is then constructed with
# ttz=<parent> and added to the session on its own.
ttz = Ttz(name="TTZ 1")
session.add(ttz)
for i in range(2):
  session.add(TtzFile(ttz=ttz, file_name="some_file_name", attachment_id=SomeUUID()))

How would you write that piece of code without factories in order to get the files attribute populated?

albertalexandrov commented 1 month ago

Hi, @rbarrois !

I would write like this:

# The files are built first, without a parent, and handed to the Ttz
# constructor through files=; only the parent is added to the session.
files = []

for i in range(2):
    file = TtzFile(file_name="some_file_name", attachment_id=SomeUUID())
    files.append(file) 

ttz = Ttz(name="TTZ 1", files=files)
session.add(ttz)

As far as I know, factory_boy first creates the main object and then the related list.

rbarrois commented 1 month ago

Your snippet wouldn't work, the ttz is not created beforehand!

However, if that's the way you'd write it, I suggest using a factory.List and a factory.SubFactory:

# Placeholder for the file factory; its body is elided in this example.
class FileFactory:
  ...

class TtzFactory:
  # Declares files as a list of two FileFactory sub-factories instead of a
  # RelatedFactoryList.
  files = factory.List([
    factory.SubFactory(FileFactory),
    factory.SubFactory(FileFactory),
  ])

This might work, instantiating the two File objects before attaching them.

albertalexandrov commented 1 month ago

There was a copy-paste mistake. I fixed it.

Does SubFactory(FileFactory) return a stub object? As you can see, TtzFile cannot be created without a Ttz.

Sorry, I can't check it because I don't have access to my computer. I'll try in a week.

rbarrois commented 1 month ago

Thanks! Can you try the approach I suggested above, i.e a list of subfactories instead of a RelatedFactoryList?

albertalexandrov commented 1 month ago

I'll try in about a week, when I'm back at my computer. Thanks.