fastapi / sqlmodel

SQL databases in Python, designed for simplicity, compatibility, and robustness.
https://sqlmodel.tiangolo.com/
MIT License

How do you define polymorphic models similar to the sqlalchemy ones? #36

Open ludokriss opened 3 years ago

ludokriss commented 3 years ago

Example Code

from sqlmodel import Relationship, SQLModel, Field, create_engine
from typing import Optional
import uuid

class Principal(SQLModel, table=True):
    __tablename__ = "users"
    id: Optional[uuid.UUID] = Field(primary_key=True,nullable=False,default_factory=uuid.uuid4)
    is_active:bool = Field(default=True)
    type: str = Field(default="principal")
    __mapper_args__ = {
        'polymorphic_on':'type',
        'polymorphic_identity':'principal'
    }

class User(Principal,table=True):
    email: str
    __mapper_args__ = {
        'polymorphic_identity':'user'
    }

class ServiceUser(Principal,table=True):
    name: str
    owner_id: Optional[uuid.UUID] = Field(default=None, foreign_key=('users.id'))
    owner: "User" = Relationship()
    __mapper_args__ = {
        'polymorphic_identity':'serviceuser'
    }

sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"

engine = create_engine(sqlite_url, echo=True)

SQLModel.metadata.create_all(engine)

Operating System

Windows

Operating System Details

No response

SQLModel Version

0.0.4

Python Version

3.8.10

Additional Context

I think I can fall back to sqlalchemy in this case without any problems, but maybe I am at a loss and it should be done in another way. Removing the "table=True" from the inherited classes makes no difference. Maybe this is also an edge case that should not be supported, but anyway it would be nice to see how this should be handled by people smarter than me. I am currently evaluating rewriting a backend to sqlmodel as it is already implemented in FastApi (which is amazing), and although I know it's early days for this project, I like what it tries to achieve :)
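
For readers unfamiliar with the pattern being asked about: single table inheritance keeps every subclass in one table and uses a discriminator column (`type` above) to decide which class a row becomes. The sketch below illustrates only that idea with the standard library's `sqlite3`; it is not SQLModel or SQLAlchemy code. The integer ids, the `IDENTITIES` registry and the `load_all` helper are assumptions made for illustration.

```python
import sqlite3
from dataclasses import dataclass
from typing import Optional


@dataclass
class Principal:
    id: int
    is_active: bool


@dataclass
class User(Principal):
    email: Optional[str] = None


@dataclass
class ServiceUser(Principal):
    name: Optional[str] = None


# Hypothetical registry: discriminator value -> (class, extra columns it reads).
IDENTITIES = {
    "principal": (Principal, ()),
    "user": (User, ("email",)),
    "serviceuser": (ServiceUser, ("name",)),
}


def load_all(conn: sqlite3.Connection) -> list:
    """Read every row of the shared table and build the class its `type` names."""
    conn.row_factory = sqlite3.Row
    rows = conn.execute("SELECT * FROM users ORDER BY id").fetchall()
    result = []
    for row in rows:
        cls, extra = IDENTITIES[row["type"]]
        kwargs = {col: row[col] for col in extra}
        result.append(cls(id=row["id"], is_active=bool(row["is_active"]), **kwargs))
    return result


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE users ("
    "id INTEGER PRIMARY KEY, is_active INTEGER NOT NULL, "
    "type TEXT NOT NULL, email TEXT, name TEXT)"
)
conn.executemany(
    "INSERT INTO users (is_active, type, email, name) VALUES (?, ?, ?, ?)",
    [(1, "user", "a@example.com", None), (1, "serviceuser", None, "backup-bot")],
)
people = load_all(conn)
```

An ORM with polymorphic support does this dispatch for you, which is what `polymorphic_on` and `polymorphic_identity` configure in SQLAlchemy.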

movabo commented 2 years ago

I found a workaround/hack on how to allow Single Table Inheritance as mentioned by you (Joined Table Inheritance does not work):

My example refers to the SQLAlchemy example but should work with your code as well:

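from typing import Any, Dict, Tuple

from sqlmodel import Field, SQLModel
from sqlmodel.main import SQLModelMetaclass

class CustomMetaclass(SQLModelMetaclass):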
    def __init__(
            cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
    ) -> None:
        patched = set()
        for i, base in enumerate(bases):
            config = getattr(base, "__config__")
            if config and getattr(config, "table", False):
                config.table = False
                patched.add(i)

        super().__init__(classname, bases, dict_, **kw)
        for i in patched:
            getattr(bases[i], "__config__").table = True

class Employee(SQLModel, table=True, metaclass=CustomMetaclass):
    __tablename__ = "employee"
    id: int = Field(primary_key=True)
    name: str = Field(nullable=True)
    type: str

    __mapper_args__ = {
        "polymorphic_identity": "employee",
        "polymorphic_on": "type"
    }

class Engineer(Employee, table=True):
    __table_args__ = {'extend_existing': True}
    engineer_info: str = Field(nullable=True)

    __mapper_args__ = {
        "polymorphic_identity": "engineer"
    }

class Manager(Employee, table=True):
    __table_args__ = {'extend_existing': True}
    manager_data: str = Field(nullable=True)

    __mapper_args__ = {
        "polymorphic_identity": "manager"
    }

# Edit: Column Conflicts are also resolvable:
# The equivalent of

@declared_attr
def start_date(cls):
    "Start date column, if not present already."
    return Employee.__table__.c.get('start_date', Column(DateTime))

# would be:

start_date: datetime = Field(sa_column=declared_attr(
    lambda cls: Employee.__table__.c.get('start_date', Column(DateTime))
))

The CustomMetaclass basically disables this check and thus allows the inheritance.

It might break some FastAPI functionality as the comment in the code of the check suggests. Because this is a hack it would be great if we could add an officially supported way of disabling this check to support (at least) single table inheritance. It seems to work well outside of FastAPI (I did not test it with FastAPI).
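
The trick itself (temporarily flipping a flag on the base classes while the parent metaclass runs, then restoring it) can be shown without SQLModel. The following is a toy sketch only: `ToyMeta`, `PatchedMeta`, `is_table` and `MAPPED` are made-up names, and `ToyMeta` merely imitates a metaclass that skips mapping when a base is already a table. It is not SQLModel's actual implementation.

```python
MAPPED = []  # names of the classes the toy metaclass decided to "map"


class ToyMeta(type):
    """Toy stand-in: maps a class only if none of its bases is a table."""

    def __new__(mcls, name, bases, ns, table=False, **kw):
        cls = super().__new__(mcls, name, bases, ns, **kw)
        cls.is_table = table
        return cls

    def __init__(cls, name, bases, ns, table=False, **kw):
        super().__init__(name, bases, ns, **kw)
        base_is_table = any(getattr(b, "is_table", False) for b in bases)
        if table and not base_is_table:
            MAPPED.append(name)


class PatchedMeta(ToyMeta):
    """Hides the bases' table flag while ToyMeta runs, then restores it."""

    def __init__(cls, name, bases, ns, **kw):
        patched = [b for b in bases if getattr(b, "is_table", False)]
        for b in patched:
            b.is_table = False
        try:
            super().__init__(name, bases, ns, **kw)
        finally:
            for b in patched:
                b.is_table = True


class PlainBase(metaclass=ToyMeta, table=True):
    pass


class PlainChild(PlainBase, table=True):  # skipped: its base is a table
    pass


class Employee(metaclass=PatchedMeta, table=True):
    pass


class Engineer(Employee, table=True):  # mapped: the flag was hidden during init
    pass
```

The `try`/`finally` is a small addition over the snippet above: it restores the flag even if the parent `__init__` raises.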

I see two ways of doing this:

  1. Explicit: Adding a keyword argument to disable the check (e.g. class Employee(SQLModel, table=True, allow_derived_tables=True):)
  2. Implicit: Disabling the check if __mapper_args__["polymorphic_identity"] is set in all bases that are also tables, and only if the __tablename__ is the same across the bases (so that only single table inheritance is allowed).
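
To make the implicit option concrete, here is a rough sketch of the rule as a standalone predicate. The function name `single_table_inheritance_allowed` and the idea of passing the table bases in explicitly are assumptions for illustration; nothing like this exists in SQLModel, and the classes below are plain Python classes rather than models.

```python
from typing import Iterable


def single_table_inheritance_allowed(cls: type, table_bases: Iterable[type]) -> bool:
    """True if every table base has a polymorphic identity and shares cls's table name."""
    tablename = getattr(cls, "__tablename__", None)
    for base in table_bases:
        mapper_args = getattr(base, "__mapper_args__", {})
        if "polymorphic_identity" not in mapper_args:
            return False  # the base is not part of a polymorphic hierarchy
        if getattr(base, "__tablename__", None) != tablename:
            return False  # a different table would mean joined inheritance
    return True


class Employee:
    __tablename__ = "employee"
    __mapper_args__ = {"polymorphic_identity": "employee", "polymorphic_on": "type"}


class Engineer(Employee):  # inherits __tablename__, so it shares the table
    __mapper_args__ = {"polymorphic_identity": "engineer"}


class Contractor(Employee):  # its own table name: rejected by the rule
    __tablename__ = "contractor"
    __mapper_args__ = {"polymorphic_identity": "contractor"}


class Plain:  # a table base without any polymorphic configuration
    __tablename__ = "plain"


class Intern(Plain):
    pass
```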

@tiangolo, would you accept a PR with one of the two ideas? If yes, which one would you prefer? Then I would prepare a PR. :)

shatteringlass commented 2 years ago

I found a workaround/hack on how to allow Single Table Inheritance as mentioned by you (Joined Table Inheritance does not work):

I tried this but it didn't seem to work for me on sqlmodel 0.0.6. Anyway, my project aims to use FastAPI so I am more interested in a proper solution rather than a workaround. Is this by any chance on the roadmap @tiangolo? Thanks a bunch!

BigTMiami commented 2 years ago

I would find this very useful as well.

jkehrbaum commented 2 years ago

Dear all, is there any workaround available for having polymorphic models? The code from movabo doesn’t seem to work within the latest version of SQLModel. Any ideas or plans to have it in future releases?

thx a lot

shatteringlass commented 2 years ago

I had posted my solution/workaround here:

https://matrix.to/#/!MYdZNkXpgAXqXRdstX:gitter.im/$cqBN6h_4QSiFRgjivn6rh3lsHAZIdDwvYLlwL8otP8s?via=gitter.im&via=matrix.org&via=averyan.ru

It's not super generic, but it could be useful to other people here as well, so here's the copypasta:

Hi all! i'd like to POST/GET a doubly-nested model (market_result -> result_curve (1:n) -> curve_point (1:n)) but I am struggling with the creation of the ORM object (object has no attribute '_sa_instance_state') so I'm probably doing something wrong. I should add that I'm also trying to use single-table-inheritance as i have subclassed my result_curve model into supply_result_curve and demand_result_curve, polymorphic on attribute "side". What's the easiest way one could achieve this?

I did manage to achieve what I wanted (albeit with more struggle than I would've expected), so I'll testify here in case someone has a similar problem. I ended up using Wouterkoorn/sqlalchemy-pydantic-orm and this is the code I wrote for models/schemas:

from sqlalchemy import (Boolean, Column, DateTime, ForeignKey, Integer,
                        Numeric, String)
from sqlalchemy.orm import relationship, declared_attr

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_pydantic_orm import ORMBaseSchema

from typing import Any, List, Optional
from decimal import Decimal
from pydantic import PrivateAttr
import datetime

# SQLAlchemy models

Base = declarative_base()

class ETSSResult(Base):
    __tablename__ = "T_ETSS_RESULTS"

    result_id = Column(Integer, primary_key=True)

    MTUid = Column(String)
    deliveryStartDT = Column(DateTime(timezone=True), index=True)
    deliveryDuration = Column(Integer)
    targetMarket = Column(String, index=True)
    biddingZone = Column(String, index=True)
    netPosition = Column(Numeric)
    clearingPrice = Column(Numeric)
    calculatedPrice = Column(Numeric)
    currency = Column(String)
    matchedSupplyQuantityPPT = Column(Numeric)
    matchedSupplyQuantityBLK = Column(Numeric)
    matchedSupplyQuantityHBD = Column(Numeric)
    matchedDemandQuantityPPT = Column(Numeric)
    matchedDemandQuantityBLK = Column(Numeric)
    matchedDemandQuantityHBD = Column(Numeric)
    lastUpdateDT = Column(DateTime(timezone=True), index=True)

    supplyCurve = relationship(
        "ETSSMarketSupplyCurve", back_populates="result", uselist=False)
    demandCurve = relationship(
        "ETSSMarketDemandCurve", back_populates="result", uselist=False)

class ETSSMarketCurve(Base):
    __tablename__ = "T_ETSS_RESULTS_CURVES"
    curve_id = Column(Integer, primary_key=True)
    result_id = Column(ForeignKey('T_ETSS_RESULTS.result_id'))
    side = Column(String)
    points = relationship("ETSSMarketCurvePoint", back_populates="curve")

    __mapper_args__ = {
        'polymorphic_on': "side",
        'polymorphic_identity': "curve"
    }

class ETSSMarketSupplyCurve(ETSSMarketCurve):
    __mapper_args__ = {'polymorphic_identity': 'supply'}
    result = relationship(
        "ETSSResult", back_populates="supplyCurve", uselist=False)

    @declared_attr
    def side(cls):
        return ETSSMarketCurve.__table__.c.get('side', Column(String, default="supply"))

class ETSSMarketDemandCurve(ETSSMarketCurve):
    __mapper_args__ = {'polymorphic_identity': 'demand'}
    result = relationship(
        "ETSSResult", back_populates="demandCurve", uselist=False)

    @declared_attr
    def side(cls):
        return ETSSMarketCurve.__table__.c.get('side', Column(String, default="demand"))

class ETSSMarketCurvePoint(Base):
    __tablename__ = "T_ETSS_RESULTS_CURVE_POINTS"
    point_id = Column(Integer, primary_key=True)
    curve_id = Column(ForeignKey('T_ETSS_RESULTS_CURVES.curve_id'))
    price = Column(Numeric)
    volume = Column(Numeric)

    curve = relationship("ETSSMarketCurve", back_populates="points", uselist=False)

# Pydantic schemas (`model` presumably refers to the separate module that holds the SQLAlchemy models above)

class ETSSMarketCurvePointBase(ORMBaseSchema):
    volume: Decimal
    price: Decimal
    _orm_model = PrivateAttr(model.ETSSMarketCurvePoint)

class ETSSMarketCurveBase(ORMBaseSchema):
    side: Optional[str]
    points: List[ETSSMarketCurvePointBase]
    _orm_model = PrivateAttr(model.ETSSMarketCurve)

class ETSSMarketSupplyCurveBase(ORMBaseSchema):
    side: Optional[str]
    points: List[ETSSMarketCurvePointBase]
    _orm_model = PrivateAttr(model.ETSSMarketSupplyCurve)

class ETSSMarketDemandCurveBase(ORMBaseSchema):
    side: Optional[str]
    points: List[ETSSMarketCurvePointBase]
    _orm_model = PrivateAttr(model.ETSSMarketDemandCurve)

class ETSSResultBase(ORMBaseSchema):
    MTUid: str
    deliveryStartDT: datetime.datetime
    deliveryDuration: int
    targetMarket: str
    biddingZone: str
    netPosition: Decimal
    clearingPrice: Decimal
    calculatedPrice: Decimal
    currency: str
    matchedSupplyQuantityPPT: Decimal
    matchedSupplyQuantityBLK: Decimal
    matchedSupplyQuantityHBD: Decimal
    matchedDemandQuantityPPT: Decimal
    matchedDemandQuantityBLK: Decimal
    matchedDemandQuantityHBD: Decimal
    lastUpdateDT: datetime.datetime
    _orm_model = PrivateAttr(model.ETSSResult)

# GET schemas

class ETSSMarketCurvePoint(ETSSMarketCurvePointBase):
    point_id: Optional[int]
    curve_id: Optional[int]

class ETSSMarketCurve(ETSSMarketCurveBase):
    curve_id: Optional[int]
    result_id: Optional[int]
    points: List[ETSSMarketCurvePoint]

class ETSSMarketSupplyCurve(ETSSMarketSupplyCurveBase):
    curve_id: Optional[int]
    result_id: Optional[int]
    points: List[ETSSMarketCurvePoint]

class ETSSMarketDemandCurve(ETSSMarketDemandCurveBase):
    curve_id: Optional[int]
    result_id: Optional[int]
    points: List[ETSSMarketCurvePoint]

class ETSSResult(ETSSResultBase):
    result_id: Optional[int]
    supplyCurve: ETSSMarketSupplyCurve
    demandCurve: ETSSMarketDemandCurve

class ETSSResultCreate(ETSSResultBase):
    supplyCurve: List[ETSSMarketCurvePointBase]
    demandCurve: List[ETSSMarketCurvePointBase]

LaQuay commented 1 year ago

Hi, I am struggling with this problem as well. The proposals provided here don't work either for me.

fastapi==0.86.0
pydantic==1.10.2
uvicorn==0.19.0
sqlmodel==0.0.8
psycopg2-binary==2.9.5

@shatteringlass Thanks but your "solution" is for FastAPI + SQLAlchemy, with no SQLModel involved in your code. This issue is for FastAPI + SQLModel.

Any ideas? Thanks

DrOncogene commented 1 year ago

Hello, this is still not working. On my end, I noticed that the problem is not really the polymorphic inheritance. It is that SQLModel does not support model inheritance that is more than one layer deep, regardless of whether it is polymorphic or not.

I really need this to work for a project I'm currently working on as I want to use it with FastAPI. Thank you

andruli commented 1 year ago

I have a very ugly but working workaround for 0.0.8 for single table and joined inheritance.

I started from @movabo's answer; sadly, that no longer works.

@DrOncogene was also correct, the problem is SQLModel doesn't allow inheritance of tables.

The things the workaround does:

  1. Declares a new metaclass for SQLModel that allows inheritance
  2. When inheriting from a table, crafts a specific dict_used that doesn't include the properties of the parent models (to avoid the duplicate field errors)
  3. Declares a modified DeclarativeMeta that respects the fields we send, since by default it automatically overrides them
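
Step 2 boils down to a dictionary filter. A minimal sketch with plain dictionaries, where the helper name `own_namespace` and the string placeholders standing in for columns are assumptions for illustration:

```python
from typing import Any, Dict


def own_namespace(
    child_fields: Dict[str, Any],
    parent_fields: Dict[str, Any],
    namespace: Dict[str, Any],
) -> Dict[str, Any]:
    """Copy the class namespace, adding only fields the parent does not define."""
    dict_used = dict(namespace)
    for name, value in child_fields.items():
        if name in parent_fields:
            continue  # the column already lives on the parent table
        dict_used[name] = value
    return dict_used


# Placeholder strings stand in for the real column objects.
parent = {"id": "Column(id)", "name": "Column(name)", "type": "Column(type)"}
child = {**parent, "engineer_info": "Column(engineer_info)"}
result = own_namespace(child, parent, {"__tablename__": None})
```

In the full workaround the same filtering is applied to both fields and relationships.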

All in all, this is the workaround

"""Workaround to make single table / joined table inheritance work with SQLModel.

https://github.com/tiangolo/sqlmodel/issues/36
"""

from typing import Any

from sqlalchemy import exc
from sqlalchemy.orm import registry
from sqlalchemy.orm.decl_api import _as_declarative  # type: ignore
from sqlmodel.main import (
    BaseConfig,  # type: ignore
    DeclarativeMeta,  # type: ignore
    ForwardRef,  # type: ignore
    ModelField,  # type: ignore
    ModelMetaclass,  # type: ignore
    RelationshipProperty,  # type: ignore
    SQLModelMetaclass,
    get_column_from_field,
    inspect,  # type: ignore
    relationship,  # type: ignore
)

class SQLModelPolymorphicAwareMetaClass(SQLModelMetaclass):
    """Workaround to make single table inheritance work with SQLModel."""

    def __init__(  # noqa: C901, PLR0912
        cls, classname: str, bases: tuple[type, ...], dict_: dict[str, Any], **kw: Any  # noqa: ANN401, N805, ANN101
    ) -> None:
        # Only one of the base classes (or the current one) should be a table model
        # this allows FastAPI cloning a SQLModel for the response_model without
        # trying to create a new SQLAlchemy, for a new table, with the same name, that
        # triggers an error
        base_table: type | None = None
        is_polymorphic = False
        for base in bases:
            config = getattr(base, "__config__")  # noqa: B009
            if config and getattr(config, "table", False):
                base_table = base
                is_polymorphic = bool(getattr(base, "__mapper_args__", {}).get("polymorphic_on"))
                break
        is_polymorphic &= bool(getattr(cls, "__mapper_args__", {}).get("polymorphic_identity"))
        if getattr(cls.__config__, "table", False) and (not base_table or is_polymorphic):
            dict_used = dict_.copy()
            for field_name, field_value in cls.__fields__.items():
                # Do not include fields from the parent table if we are using inheritance
                if base_table and field_name in getattr(base_table, "__fields__", {}):
                    continue
                dict_used[field_name] = get_column_from_field(field_value)
            for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
                # Do not include fields from the parent table if we are using inheritance
                if base_table and rel_name in getattr(base_table, "__sqlmodel_relationships__", {}):
                    continue
                if rel_info.sa_relationship:
                    # There's a SQLAlchemy relationship declared, that takes precedence
                    # over anything else, use that and continue with the next attribute
                    dict_used[rel_name] = rel_info.sa_relationship
                    continue
                ann = cls.__annotations__[rel_name]
                temp_field = ModelField.infer(
                    name=rel_name,
                    value=rel_info,
                    annotation=ann,
                    class_validators=None,
                    config=BaseConfig,
                )
                relationship_to = temp_field.type_
                if isinstance(temp_field.type_, ForwardRef):
                    relationship_to = temp_field.type_.__forward_arg__
                rel_kwargs: dict[str, Any] = {}
                if rel_info.back_populates:
                    rel_kwargs["back_populates"] = rel_info.back_populates
                if rel_info.link_model:
                    ins = inspect(rel_info.link_model)
                    local_table = getattr(ins, "local_table")  # noqa: B009
                    if local_table is None:
                        msg = f"Couldn't find the secondary table for model {rel_info.link_model}"
                        raise RuntimeError(msg)
                    rel_kwargs["secondary"] = local_table
                rel_args: list[Any] = []
                if rel_info.sa_relationship_args:
                    rel_args.extend(rel_info.sa_relationship_args)
                if rel_info.sa_relationship_kwargs:
                    rel_kwargs.update(rel_info.sa_relationship_kwargs)
                rel_value: RelationshipProperty = relationship(  # type: ignore
                    relationship_to, *rel_args, **rel_kwargs
                )
                dict_used[rel_name] = rel_value
                setattr(cls, rel_name, rel_value)  # Fix #315
            PatchedDeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw)  # type: ignore
        else:
            ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)

class PatchedDeclarativeMeta(DeclarativeMeta):  # noqa: D101
    def __init__(
        cls, classname: str, bases: tuple[type, ...], dict_, **kw  # noqa: N805, ANN001, ARG002, ANN101, ANN003
    ) -> None:
        # early-consume registry from the initial declarative base,
        # assign privately to not conflict with subclass attributes named
        # "registry"
        reg = getattr(cls, "_sa_registry", None)
        if reg is None:
            reg = dict_.get("registry", None)
            if not isinstance(reg, registry):
                msg = "Declarative base class has no 'registry' attribute, or registry is not a sqlalchemy.orm.registry() object"
                raise exc.InvalidRequestError(msg)
            cls._sa_registry = reg

        if not cls.__dict__.get("__abstract__", False):
            _as_declarative(reg, cls, dict_)
        type.__init__(cls, classname, bases, dict_)

And you'd use it like this for joined inheritance

class Employee(SQLModel, table=True, metaclass=SQLModelPolymorphicAwareMetaClass):
    __tablename__ = "employee"
    id: int = Field(primary_key=True)
    name: str = Field(nullable=True)
    type: str

    __mapper_args__ = {
        "polymorphic_identity": "employee",
        "polymorphic_on": "type"
    }

class Engineer(Employee, table=True):
    __tablename__ = "engineer"  # joined inheritance needs a table of its own

    employee_id: int = Field(primary_key=True, foreign_key="employee.id")
    engineer_info: str = Field(nullable=True)

    __mapper_args__ = {
        "polymorphic_identity": "engineer"
    }
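
For orientation, joined table inheritance means the subclass's extra columns live in a second table whose primary key is also a foreign key to the parent table. The sketch below shows only that table layout with the standard library's `sqlite3`; the `engineer` table name and the sample rows are assumptions, and none of this is SQLModel code.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE employee (
        id INTEGER PRIMARY KEY,
        name TEXT,
        type TEXT NOT NULL
    );
    -- The child table's primary key doubles as the link to the parent row.
    CREATE TABLE engineer (
        employee_id INTEGER PRIMARY KEY REFERENCES employee(id),
        engineer_info TEXT
    );
    INSERT INTO employee (id, name, type) VALUES (1, 'Ada', 'engineer');
    INSERT INTO employee (id, name, type) VALUES (2, 'Bob', 'employee');
    INSERT INTO engineer (employee_id, engineer_info) VALUES (1, 'compilers');
    """
)

# Loading the hierarchy joins the parent table to the child table.
rows = conn.execute(
    """
    SELECT e.id, e.name, e.type, g.engineer_info
    FROM employee AS e
    LEFT JOIN engineer AS g ON g.employee_id = e.id
    ORDER BY e.id
    """
).fetchall()
```

Single table inheritance, by contrast, keeps `engineer_info` as a nullable column on `employee` itself.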

And this for single table inheritance

class Employee(SQLModel, table=True, metaclass=SQLModelPolymorphicAwareMetaClass):
    __tablename__ = "employee"
    id: int = Field(primary_key=True)
    name: str = Field(nullable=True)
    type: str

    __mapper_args__ = {
        "polymorphic_identity": "employee",
        "polymorphic_on": "type"
    }

class Engineer(Employee, table=True):
    __tablename__ = None  # Putting an explicit None here is important

    employee_id: int = Field(primary_key=True, foreign_key="employee.id")
    engineer_info: str = Field(nullable=True)

    __mapper_args__ = {
        "polymorphic_identity": "engineer"
    }

It looks like a lot of code, but it's mostly because I needed to copy the original metaclasses and just do a few modifications here and there. The actual changes are very little (<10 lines).

PS. This is a workaround and the code WILL MOST LIKELY BREAK when either SQLModel or SQLAlchemy is updated.

bosukas commented 9 months ago

Any updates on this one?

hslira commented 5 months ago

just checking

fsackur commented 3 months ago

I can't really use metaclasses for my app, so here's an ugly workaround for joined inheritance. I prefer the code from @andruli.

v0.0.18.

from enum import Enum

from db import engine
from sqlalchemy import ForeignKey
from sqlmodel import Field, SQLModel
from sqlmodel.main import default_registry

PKey = Field(primary_key=True)
FKey = ForeignKey("barcodes.barcode")

class BarcodedType(str, Enum):
    Tool = "Tool"
    User = "User"

class Barcoded(SQLModel, table=True):
    __tablename__ = "barcodes"  # type: ignore
    barcode: str = Field(primary_key=True, max_length=10)
    type: BarcodedType
    __mapper_args__ = {
        "polymorphic_identity": "Barcoded",
        "polymorphic_abstract": True,
        "polymorphic_on": "type",
    }

# @default_registry.mapped  # does not work when there are multiple child classes
class Tool(Barcoded, table=True):
    __tablename__ = "tools"  # type: ignore
    barcode: str = Field(FKey, primary_key=True)
    type: BarcodedType = BarcodedType.Tool
    name: str = Field()
    __mapper_args__ = {
        "polymorphic_identity": BarcodedType.Tool,
        # can't set this here, or future queries will raise
        # _mysql_connector.MySQLInterfaceError: Python type FieldInfo cannot be converted
        # this gets memoised in sqlalchemy, you can't let it get resolved early
        # "inherit_condition": barcode == Barcoded.barcode,
        "inherit_condition": None,  # masks type error
    }

# updates Tool.barcode to replace FieldInfo with InstrumentedAttribute
SQLModel.metadata.create_all(bind=engine)

# set the join condition _after_ converting the Tool.barcode attribute to sqlalchemy native
# otherwise, you get _mysql_connector.MySQLInterfaceError: Python type FieldInfo cannot be converted
Tool.__mapper_args__["inherit_condition"] = Barcoded.barcode

# raises: dictionary changed size during iteration
# Tool = default_registry.mapped(Tool)

# @default_registry.mapped  # does not work when there are multiple child classes
class User(Barcoded, table=True):
    __tablename__ = "users"  # type: ignore
    barcode: str = Field(FKey, primary_key=True)
    name: str = Field()
    __mapper_args__ = {
        "polymorphic_identity": BarcodedType.User,
        "inherit_condition": None,  # masks type error
    }

SQLModel.metadata.create_all(bind=engine)
User.__mapper_args__["inherit_condition"] = Barcoded.barcode

# if you depend on these classes in other models, e.g. for a relationship,
# then you need to decorate the model with @mapped. That fails when any child
# class of the same parent has already been mapped. So, declare all the classes,
# then map them after declaration
Tool = default_registry.mapped(Tool)
User = default_registry.mapped(User)

schema

```sql
CREATE TABLE `barcodes` (
  `barcode` varchar(255) NOT NULL,
  `type` enum('Tool','User','Team') NOT NULL,
  PRIMARY KEY (`barcode`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci

CREATE TABLE `tools` (
  `barcode` varchar(255) NOT NULL,
  `type` enum('Tool','User','Team') NOT NULL,
  `name` varchar(255) NOT NULL,
  PRIMARY KEY (`barcode`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci
```

emitted warnings

```log
site-packages/pydantic/_internal/_fields.py:200: UserWarning: Field name "barcode" in "Tool" shadows an attribute in parent "Barcoded"
site-packages/pydantic/_internal/_fields.py:200: UserWarning: Field name "type" in "Tool" shadows an attribute in parent "Barcoded"
site-packages/pydantic/_internal/_fields.py:200: UserWarning: Field name "barcode" in "User" shadows an attribute in parent "Barcoded"
model.py:72: SAWarning: Implicitly combining column barcodes.barcode with column tools.barcode under attribute 'barcode'. Please configure one or more attributes for these same-named columns explicitly.
model.py:72: SAWarning: Implicitly combining column barcodes.type with column tools.type under attribute 'type'. Please configure one or more attributes for these same-named columns explicitly.
model.py:73: SAWarning: Implicitly combining column barcodes.barcode with column users.barcode under attribute 'barcode'. Please configure one or more attributes for these same-named columns explicitly.
model.py:73: SAWarning: Implicitly combining column barcodes.type with column users.type under attribute 'type'. Please configure one or more attributes for these same-named columns explicitly.
```

ixycoexzckwpmlcu commented 2 weeks ago

I couldn't find a solution for single table inheritance for SQLModel 0.0.22, so I implemented my own solution here.

Maybe related: #438, #488