strawberry-graphql / strawberry-sqlalchemy

A SQLAlchemy Integration for strawberry-graphql
MIT License

Relationship fails when attribute name != column name #21

Open · shish opened this issue 1 year ago

shish commented 1 year ago

My model has a Survey class with an owner_id attribute that maps to a differently named column (user_id) for historical reasons:

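# models.py
from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

class Base(DeclarativeBase):
    pass

class User(Base):
    __tablename__ = "user"
    user_id: Mapped[int] = mapped_column("id", primary_key=True)
    username: Mapped[str]

class Survey(Base):
    __tablename__ = "survey"
    survey_id: Mapped[int] = mapped_column("id", primary_key=True)
    name: Mapped[str]
    # Python attribute owner_id is stored in the user_id column
    owner_id: Mapped[int] = mapped_column("user_id", ForeignKey("user.id"))
    owner: Mapped[User] = relationship("User", backref="surveys", lazy=True)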
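
import typing

import strawberry
from sqlalchemy import select
from strawberry.types import Info
from strawberry_sqlalchemy_mapper import StrawberrySQLAlchemyMapper

import models

strawberry_sqlalchemy_mapper = StrawberrySQLAlchemyMapper()

@strawberry_sqlalchemy_mapper.type(models.User)
class User:
    pass

@strawberry_sqlalchemy_mapper.type(models.Survey)
class Survey:
    pass

@strawberry.type
class Query:
    @strawberry.field
    def survey(self, info: Info, survey_id: int) -> typing.Optional[Survey]:
        db = info.context["db"]
        return db.execute(
            select(models.Survey).where(models.Survey.survey_id == survey_id)
        ).scalars().first()

strawberry_sqlalchemy_mapper.finalize()
schema = strawberry.Schema(query=Query)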

In relationship_resolver_for, the code calls getattr(self, sql_column_name) instead of getattr(self, python_attr_name), so resolving owner in the following query raises an AttributeError:

query MyQuery {
  survey(surveyId: 1) {
    name
    owner {
      username
    }
  }
}

  File ".../strawberry_sqlalchemy_mapper/mapper.py", line 409, in <listcomp>
    getattr(self, local.key)
AttributeError: 'Survey' object has no attribute 'user_id'

cpsnowden commented 1 year ago

@TimDumol, we ran into this issue ourselves and see errors in two places where the relationship key is read from a row using the SQL column name (sql_column_name) rather than the Python attribute name (python_attr_name):

StrawberrySQLAlchemyLoader#loader_for

def group_by_remote_key(row: Any) -> Tuple:
    return tuple(
        [
            getattr(row, remote.key)  # <- uses sql_column_name
            for _, remote in relationship.local_remote_pairs
        ]
    )

StrawberrySQLAlchemyMapper#relationship_resolver_for

relationship_key = tuple(
    [
        getattr(self, local.key)  # <- uses sql_column_name
        for local, _ in relationship.local_remote_pairs
    ]
)

We have a temporary workaround: we override both methods and build a column-name-to-attribute-name map from the relevant relationship mapper. We'd still be keen to have a central fix for this.

I'm happy to contribute a fix if we can agree an approach.

Example fix:

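from typing import Any, Callable

def build_get_col(mapper) -> Callable[[Any, str], Any]:
    # mapper.c is keyed by mapped attribute name (e.g. "owner_id"); its values are
    # Column objects whose .key (e.g. "user_id") is what local_remote_pairs report.
    # Relationship attributes are skipped because they are not in mapper.c.
    attr_names = mapper.attrs.keys()
    col_to_attr = {
        mapper.c[attr_name].key: attr_name
        for attr_name in attr_names
        if attr_name in mapper.c
    }

    def get_col(row: Any, col: str):
        # Translate the column key to the Python attribute name before reading it
        attr = col_to_attr[col]
        return getattr(row, attr)

    return get_col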

##StrawberrySQLAlchemyLoader
def loader_for(self, relationship: RelationshipProperty) -> DataLoader:
        """
        Retrieve or create a DataLoader for the given relationship
        """
        try:
            return self._loaders[relationship]
        except KeyError:
            related_model = relationship.entity.entity

            get_col = build_get_col(relationship.mapper)  # get_col built once per relationship, from the related (target) mapper
            async def load_fn(keys: List[Tuple]) -> List[Any]:
                query = select(related_model).filter(
                    tuple_(
                        *[remote for _, remote in relationship.local_remote_pairs]
                    ).in_(keys)
                )
                if relationship.order_by:
                    query = query.order_by(*relationship.order_by)
                rows = self.bind.scalars(query).all()

                def group_by_remote_key(row: Any) -> Tuple:
                    return tuple(
                        [
                            get_col(row, remote.key)
                            for _, remote in relationship.local_remote_pairs
                        ]
                    )

                grouped_keys: Mapping[Tuple, List[Any]] = defaultdict(list)
                for row in rows:
                    grouped_keys[group_by_remote_key(row)].append(row)
                if relationship.uselist:
                    return [grouped_keys[key] for key in keys]
                else:
                    return [
                        grouped_keys[key][0] if grouped_keys[key] else None
                        for key in keys
                    ]

            self._loaders[relationship] = DataLoader(load_fn=load_fn)
            return self._loaders[relationship]
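The grouping step in load_fn is where the remote side of the translation matters: related rows are bucketed by the tuple of their remote-key values and returned in the order the keys were requested, one entry per key, as a DataLoader load function must. A stand-alone sketch of both branches with a plain dataclass in place of ORM rows; Row and group_rows are hypothetical names, and the attribute names are passed in directly rather than derived from a mapper. The rows play the role of Survey rows loaded for User.surveys, whose remote column user_id maps to the attribute owner_id.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple


@dataclass
class Row:
    # Stand-in for a Survey row: the attribute is owner_id,
    # although the database column would be user_id
    survey_id: int
    owner_id: int


def group_rows(
    rows: Sequence[Any],
    keys: Sequence[Tuple[Any, ...]],
    remote_attrs: Sequence[str],
    uselist: bool,
) -> List[Any]:
    # Bucket rows by their remote key, read through Python attribute names
    grouped: Dict[Tuple[Any, ...], List[Any]] = defaultdict(list)
    for row in rows:
        grouped[tuple(getattr(row, attr) for attr in remote_attrs)].append(row)
    # Exactly one result per requested key, in request order
    if uselist:
        return [grouped[key] for key in keys]
    return [grouped[key][0] if grouped[key] else None for key in keys]


rows = [Row(1, 10), Row(2, 10), Row(3, 20)]
keys = [(10,), (30,), (20,)]  # (30,) has no matching rows

by_owner = group_rows(rows, keys, ["owner_id"], uselist=True)
print([[r.survey_id for r in bucket] for bucket in by_owner])  # [[1, 2], [], [3]]

first = group_rows(rows, keys, ["owner_id"], uselist=False)
print([r.survey_id if r is not None else None for r in first])  # [1, None, 3]
```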

##StrawberrySQLAlchemyMapper
def relationship_resolver_for(
        self, relationship: RelationshipProperty
    ) -> Callable[..., Awaitable[Any]]:
        """
        Return an async field resolver for the given relationship,
        so as to avoid n+1 query problem.
        """
        get_col = build_get_col(relationship.parent)  # get_col built once per relationship, from the parent (local-side) mapper
        async def resolve(self, info: Info):
            instance_state = cast(InstanceState, inspect(self))
            if relationship.key not in instance_state.unloaded:
                related_objects = getattr(self, relationship.key)
            else:
                relationship_key = tuple(
                    [
                        get_col(self, local.key)
                        for local, _ in relationship.local_remote_pairs
                    ]
                )
                if any(item is None for item in relationship_key):
                    if relationship.uselist:
                        return []
                    else:
                        return None
                if isinstance(info.context, dict):
                    loader = info.context["sqlalchemy_loader"]
                else:
                    loader = info.context.sqlalchemy_loader
                related_objects = await loader.loader_for(relationship).load(
                    relationship_key
                )
            return related_objects

        setattr(resolve, _IS_GENERATED_RESOLVER_KEY, True)

        return resolve
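relationship_resolver_for avoids the n+1 problem by building one key per parent row and handing it to a DataLoader, so keys requested by resolvers running concurrently are fetched in a single batch. The sketch below imitates that flow with asyncio only. BatchLoader is a hypothetical, deliberately minimal stand-in (no caching or de-duplication, unlike strawberry's DataLoader), and resolve_owner mirrors the key construction and the None short-circuit, reading attribute names such as owner_id rather than column keys.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, List, Optional, Sequence, Set, Tuple

KeyT = Tuple[Any, ...]


class BatchLoader:
    """Hypothetical minimal batching loader (no caching or de-duplication)."""

    def __init__(self, load_fn: Callable[[List[KeyT]], Awaitable[List[Any]]]):
        self.load_fn = load_fn
        self.queue: List[Tuple[KeyT, "asyncio.Future[Any]"]] = []
        self.batches: List[List[KeyT]] = []  # keys of each dispatched batch
        self._tasks: Set["asyncio.Task[None]"] = set()  # keep dispatch tasks alive

    def load(self, key: KeyT) -> "asyncio.Future[Any]":
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        if not self.queue:
            # First key of a new batch; the dispatch task is queued behind the
            # tasks already waiting to run, so their keys join this batch
            task = loop.create_task(self._dispatch())
            self._tasks.add(task)
            task.add_done_callback(self._tasks.discard)
        self.queue.append((key, future))
        return future

    async def _dispatch(self) -> None:
        batch, self.queue = self.queue, []
        keys = [key for key, _ in batch]
        self.batches.append(keys)
        results = await self.load_fn(keys)  # one "query" for the whole batch
        for (_, future), result in zip(batch, results):
            future.set_result(result)


@dataclass
class Survey:
    survey_id: int
    owner_id: Optional[int]  # Python attribute; the column would be user_id


USERNAMES = {10: "alice", 20: "bob"}


async def load_usernames(keys: List[KeyT]) -> List[Optional[str]]:
    return [USERNAMES.get(key[0]) for key in keys]


async def resolve_owner(
    survey: Survey, loader: BatchLoader, local_attrs: Sequence[str] = ("owner_id",)
) -> Optional[str]:
    # Build the relationship key from attribute names, as get_col does
    key = tuple(getattr(survey, attr) for attr in local_attrs)
    if any(item is None for item in key):
        return None  # no owner: skip the loader entirely
    return await loader.load(key)


async def main() -> Tuple[List[Optional[str]], List[List[KeyT]]]:
    loader = BatchLoader(load_usernames)
    surveys = [Survey(1, 10), Survey(2, 20), Survey(3, 10), Survey(4, None)]
    owners = await asyncio.gather(*(resolve_owner(s, loader) for s in surveys))
    return list(owners), loader.batches


owners, batches = asyncio.run(main())
print(owners)  # ['alice', 'bob', 'alice', None]
print(batches)  # [[(10,), (20,), (10,)]]
```

Three resolvers share one batch call, and the survey without an owner never reaches the loader.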

TimDumol commented 1 year ago

Hi @cpsnowden, sorry, I totally forgot I'd assigned myself to this. Your proposed fix looks good to me. Feel free to PR it!

cpsnowden commented 1 year ago

Thanks @TimDumol. I see that @gravy-jones-locker is addressing this in https://github.com/strawberry-graphql/strawberry-sqlalchemy/pull/25