ludokriss opened this issue 3 years ago
I found a workaround/hack to allow Single Table Inheritance, as you mentioned (Joined Table Inheritance does not work):
My example follows the SQLAlchemy example, but it should work with your code as well:
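from typing import Any, Dict, Tuple

from sqlmodel.main import SQLModelMetaclass


class CustomMetaclass(SQLModelMetaclass):
    def __init__(
        cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any
    ) -> None:
        # Temporarily mark every table base as a non-table model so that
        # SQLModel's "only one table model" check passes
        patched = set()
        for i, base in enumerate(bases):
            config = getattr(base, "__config__", None)  # plain mixins have no __config__
            if config and getattr(config, "table", False):
                config.table = False
                patched.add(i)
        super().__init__(classname, bases, dict_, **kw)
        # Restore the table flag on the bases patched above
        for i in patched:
            getattr(bases[i], "__config__").table = True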
from sqlmodel import Field, SQLModel


class Employee(SQLModel, table=True, metaclass=CustomMetaclass):
    __tablename__ = "employee"
    id: int = Field(primary_key=True)
    name: str = Field(nullable=True)
    type: str

    __mapper_args__ = {
        "polymorphic_identity": "employee",
        "polymorphic_on": "type",
    }


class Engineer(Employee, table=True):
    __table_args__ = {"extend_existing": True}
    engineer_info: str = Field(nullable=True)

    __mapper_args__ = {
        "polymorphic_identity": "engineer",
    }


class Manager(Employee, table=True):
    __table_args__ = {"extend_existing": True}
    manager_data: str = Field(nullable=True)

    __mapper_args__ = {
        "polymorphic_identity": "manager",
    }
# Edit: column conflicts are also resolvable.
# The SQLAlchemy equivalent of
@declared_attr
def start_date(cls):
    "Start date column, if not present already."
    return Employee.__table__.c.get('start_date', Column(DateTime))
# would be (inside the subclass body):
start_date: datetime = Field(sa_column=declared_attr(
    lambda cls: Employee.__table__.c.get('start_date', Column(DateTime))
))
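The CustomMetaclass temporarily sets table = False on every base class that is a table model, runs the regular SQLModelMetaclass.__init__, and then restores the flag. This effectively disables the check that only one class in the hierarchy (a base or the current one) may be a table model, and thus allows the inheritance.
It might break some FastAPI functionality, as the comment in the code of the check suggests. Because this is a hack, it would be great to have an officially supported way of disabling this check to support (at least) single table inheritance. It seems to work well outside of FastAPI (I did not test it with FastAPI).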
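The patch-and-restore trick can be shown without SQLModel. In the stdlib-only sketch below, a hypothetical CheckingMeta mimics the check and a hypothetical Config class stands in for a model's __config__. The names are illustrative and not SQLModel's API. One difference from the snippet above: this version restores the flag in a finally block, so a failing class definition cannot leave a base permanently marked as a non-table model.

```python
from typing import Any, Dict, Tuple


class Config:
    """Hypothetical stand-in for a model's __config__."""

    def __init__(self, table: bool = False) -> None:
        self.table = table


class CheckingMeta(type):
    """Hypothetical stand-in for the SQLModel check: a table class may not
    inherit from a base that is itself a table."""

    def __init__(cls, name: str, bases: Tuple[type, ...], ns: Dict[str, Any]) -> None:
        super().__init__(name, bases, ns)
        if getattr(ns.get("__config__"), "table", False):
            for base in bases:
                if getattr(getattr(base, "__config__", None), "table", False):
                    raise TypeError(f"{name}: base {base.__name__} is already a table")


class PatchingMeta(CheckingMeta):
    """Clears `table` on table bases while the check runs, then restores it."""

    def __init__(cls, name: str, bases: Tuple[type, ...], ns: Dict[str, Any]) -> None:
        patched = [
            base.__config__
            for base in bases
            if getattr(getattr(base, "__config__", None), "table", False)
        ]
        for config in patched:
            config.table = False
        try:
            super().__init__(name, bases, ns)
        finally:
            # Restore the flag even if class creation raised
            for config in patched:
                config.table = True


class Employee(metaclass=PatchingMeta):
    __config__ = Config(table=True)


class Engineer(Employee):  # accepted: the base is hidden from the check
    __config__ = Config(table=True)


class Plain(metaclass=CheckingMeta):
    __config__ = Config(table=True)


try:
    class Child(Plain):  # rejected by the unpatched check
        __config__ = Config(table=True)
    rejected = False
except TypeError:
    rejected = True
```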
I see two ways of doing this:
1. An explicit opt-in keyword on the model, e.g. class Employee(SQLModel, table=True, allow_derived_tables=True):
2. Allow it implicitly, but only if __mapper_args__["polymorphic_identity"] is set in all bases that are also tables, and only if the __tablename__ is also the same across those bases (to allow only single table inheritance).

@tiangolo, would you accept a PR with one of the two ideas? If yes, which one would you prefer? Then I would prepare a PR. :)
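The condition in the second proposal can be stated as a small predicate. The stdlib-only sketch below uses a hypothetical function, allows_single_table_inheritance, and plain stand-in classes. It is not part of SQLModel. It only shows the rule as written: every table base must declare a polymorphic_identity, and all table bases must share one __tablename__.

```python
from types import SimpleNamespace
from typing import Sequence


def allows_single_table_inheritance(bases: Sequence[type]) -> bool:
    """Hypothetical check for the second proposal."""
    table_bases = [
        base
        for base in bases
        if getattr(getattr(base, "__config__", None), "table", False)
    ]
    if not table_bases:
        return True  # no table base, so there is nothing to relax
    every_base_polymorphic = all(
        "polymorphic_identity" in getattr(base, "__mapper_args__", {})
        for base in table_bases
    )
    single_table = len({getattr(base, "__tablename__", None) for base in table_bases}) == 1
    return every_base_polymorphic and single_table


# Hypothetical stand-ins for models (plain classes, not SQLModel)
class Employee:
    __config__ = SimpleNamespace(table=True)
    __tablename__ = "employee"
    __mapper_args__ = {"polymorphic_identity": "employee", "polymorphic_on": "type"}


class Customer:
    __config__ = SimpleNamespace(table=True)
    __tablename__ = "customer"
    __mapper_args__ = {"polymorphic_identity": "customer"}


class Invoice:
    __config__ = SimpleNamespace(table=True)
    __tablename__ = "invoice"
    __mapper_args__ = {}


class Mixin:
    pass
```

A subclass of Employee would pass the check. A subclass of Invoice, which has no polymorphic_identity, or of both Employee and Customer, which live in two tables, would not.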
I tried this but it didn't seem to work for me on sqlmodel 0.0.6. Anyway, my project aims to use FastAPI so I am more interested in a proper solution rather than a workaround. Is this by any chance on the roadmap @tiangolo? Thanks a bunch!
I would find this very useful as well.
Dear all, is there any workaround available for polymorphic models? The code from movabo doesn't seem to work with the latest version of SQLModel. Any ideas or plans to support this in future releases?
Thanks a lot
I had posted my solution/workaround here:
It's not super generic, but it could be useful to other people here as well, so here's the copypasta:
Hi all! I'd like to POST/GET a doubly nested model (market_result -> result_curve (1:n) -> curve_point (1:n)), but I am struggling with the creation of the ORM object (object has no attribute '_sa_instance_state'), so I'm probably doing something wrong. I should add that I'm also trying to use single table inheritance: I have subclassed my result_curve model into supply_result_curve and demand_result_curve, polymorphic on the attribute "side". What's the easiest way to achieve this?
I did manage to achieve what I wanted (albeit with more struggle than I would have expected), so I'll share it here in case someone has a similar problem. I ended up using Wouterkoorn/sqlalchemy-pydantic-orm, and this is the code I wrote for the models/schemas:
# models.py (SQLAlchemy models)
from sqlalchemy import Column, DateTime, ForeignKey, Integer, Numeric, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import declared_attr, relationship

Base = declarative_base()


class ETSSResult(Base):
    __tablename__ = "T_ETSS_RESULTS"

    result_id = Column(Integer, primary_key=True)
    MTUid = Column(String)
    deliveryStartDT = Column(DateTime(timezone=True), index=True)
    deliveryDuration = Column(Integer)
    targetMarket = Column(String, index=True)
    biddingZone = Column(String, index=True)
    netPosition = Column(Numeric)
    clearingPrice = Column(Numeric)
    calculatedPrice = Column(Numeric)
    currency = Column(String)
    matchedSupplyQuantityPPT = Column(Numeric)
    matchedSupplyQuantityBLK = Column(Numeric)
    matchedSupplyQuantityHBD = Column(Numeric)
    matchedDemandQuantityPPT = Column(Numeric)
    matchedDemandQuantityBLK = Column(Numeric)
    matchedDemandQuantityHBD = Column(Numeric)
    lastUpdateDT = Column(DateTime(timezone=True), index=True)

    supplyCurve = relationship("ETSSMarketSupplyCurve", back_populates="result", uselist=False)
    demandCurve = relationship("ETSSMarketDemandCurve", back_populates="result", uselist=False)


class ETSSMarketCurve(Base):
    # One table for all curves; "side" is the discriminator (single table inheritance)
    __tablename__ = "T_ETSS_RESULTS_CURVES"

    curve_id = Column(Integer, primary_key=True)
    result_id = Column(ForeignKey("T_ETSS_RESULTS.result_id"))
    side = Column(String)
    points = relationship("ETSSMarketCurvePoint", back_populates="curve")

    __mapper_args__ = {"polymorphic_on": "side", "polymorphic_identity": "curve"}


class ETSSMarketSupplyCurve(ETSSMarketCurve):
    __mapper_args__ = {"polymorphic_identity": "supply"}
    result = relationship("ETSSResult", back_populates="supplyCurve", uselist=False)

    @declared_attr
    def side(cls):
        # Reuse the parent's "side" column instead of declaring a conflicting one
        return ETSSMarketCurve.__table__.c.get("side", Column(String, default="supply"))


class ETSSMarketDemandCurve(ETSSMarketCurve):
    __mapper_args__ = {"polymorphic_identity": "demand"}
    result = relationship("ETSSResult", back_populates="demandCurve", uselist=False)

    @declared_attr
    def side(cls):
        return ETSSMarketCurve.__table__.c.get("side", Column(String, default="demand"))


class ETSSMarketCurvePoint(Base):
    __tablename__ = "T_ETSS_RESULTS_CURVE_POINTS"

    point_id = Column(Integer, primary_key=True)
    curve_id = Column(ForeignKey("T_ETSS_RESULTS_CURVES.curve_id"))
    price = Column(Numeric)
    volume = Column(Numeric)
    curve = relationship("ETSSMarketCurve", back_populates="points", uselist=False)


# schemas.py (Pydantic schemas)
import datetime
from decimal import Decimal
from typing import List, Optional

from pydantic import PrivateAttr
from sqlalchemy_pydantic_orm import ORMBaseSchema

import models as model  # assumption: the SQLAlchemy models above live in models.py


class ETSSMarketCurvePointBase(ORMBaseSchema):
    volume: Decimal
    price: Decimal
    _orm_model = PrivateAttr(model.ETSSMarketCurvePoint)


class ETSSMarketCurveBase(ORMBaseSchema):
    side: Optional[str]
    points: List[ETSSMarketCurvePointBase]
    _orm_model = PrivateAttr(model.ETSSMarketCurve)


class ETSSMarketSupplyCurveBase(ORMBaseSchema):
    side: Optional[str]
    points: List[ETSSMarketCurvePointBase]
    _orm_model = PrivateAttr(model.ETSSMarketSupplyCurve)


class ETSSMarketDemandCurveBase(ORMBaseSchema):
    side: Optional[str]
    points: List[ETSSMarketCurvePointBase]
    _orm_model = PrivateAttr(model.ETSSMarketDemandCurve)


class ETSSResultBase(ORMBaseSchema):
    MTUid: str
    deliveryStartDT: datetime.datetime
    deliveryDuration: int
    targetMarket: str
    biddingZone: str
    netPosition: Decimal
    clearingPrice: Decimal
    calculatedPrice: Decimal
    currency: str
    matchedSupplyQuantityPPT: Decimal
    matchedSupplyQuantityBLK: Decimal
    matchedSupplyQuantityHBD: Decimal
    matchedDemandQuantityPPT: Decimal
    matchedDemandQuantityBLK: Decimal
    matchedDemandQuantityHBD: Decimal
    lastUpdateDT: datetime.datetime
    _orm_model = PrivateAttr(model.ETSSResult)


# GET schemas
class ETSSMarketCurvePoint(ETSSMarketCurvePointBase):
    point_id: Optional[int]
    curve_id: Optional[int]


class ETSSMarketCurve(ETSSMarketCurveBase):
    curve_id: Optional[int]
    result_id: Optional[int]
    points: List[ETSSMarketCurvePoint]


class ETSSMarketSupplyCurve(ETSSMarketSupplyCurveBase):
    curve_id: Optional[int]
    result_id: Optional[int]
    points: List[ETSSMarketCurvePoint]


class ETSSMarketDemandCurve(ETSSMarketDemandCurveBase):
    curve_id: Optional[int]
    result_id: Optional[int]
    points: List[ETSSMarketCurvePoint]


class ETSSResult(ETSSResultBase):
    result_id: Optional[int]
    supplyCurve: ETSSMarketSupplyCurve
    demandCurve: ETSSMarketDemandCurve


class ETSSResultCreate(ETSSResultBase):
    supplyCurve: List[ETSSMarketCurvePointBase]
    demandCurve: List[ETSSMarketCurvePointBase]
Hi, I am struggling with this problem as well. The proposals provided here don't work for me either.
fastapi==0.86.0
pydantic==1.10.2
uvicorn==0.19.0
sqlmodel==0.0.8
psycopg2-binary==2.9.5
@shatteringlass Thanks, but your "solution" is for FastAPI + SQLAlchemy, with no SQLModel involved in your code. This issue is about FastAPI + SQLModel.
Any ideas? Thanks
Hello, this is still not working. On my end, I noticed that the problem is not really the polymorphic inheritance. It is that SQLModel does not support model inheritance more than one layer deep, regardless of whether it is polymorphic or not.
I really need this to work for a project I'm currently working on as I want to use it with FastAPI. Thank you
I have a very ugly but working workaround for 0.0.8 for both single table and joined inheritance.
I started from @movabo's answer, which sadly no longer works.
@DrOncogene was also correct: the problem is that SQLModel doesn't allow inheritance of tables.
The workaround does the following:
- It builds a dict_used that doesn't include the properties of the parent models (to avoid the duplicate field errors).
- It uses a patched DeclarativeMeta that respects the fields we pass in, since by default it automatically overrides them.
All in all, this is the workaround:
"""Workaround to make single class / joined table inheritance work with SQLModel.

https://github.com/tiangolo/sqlmodel/issues/36
"""
from typing import Any

from sqlalchemy import exc
from sqlalchemy.orm import registry
from sqlalchemy.orm.decl_api import _as_declarative  # type: ignore
from sqlmodel.main import (
    BaseConfig,  # type: ignore
    DeclarativeMeta,  # type: ignore
    ForwardRef,  # type: ignore
    ModelField,  # type: ignore
    ModelMetaclass,  # type: ignore
    RelationshipProperty,  # type: ignore
    SQLModelMetaclass,
    get_column_from_field,
    inspect,  # type: ignore
    relationship,  # type: ignore
)


class SQLModelPolymorphicAwareMetaClass(SQLModelMetaclass):
    """Workaround to make single table inheritance work with SQLModel."""

    def __init__(  # noqa: C901, PLR0912
        cls, classname: str, bases: tuple[type, ...], dict_: dict[str, Any], **kw: Any  # noqa: ANN401, N805, ANN101
    ) -> None:
        # Only one of the base classes (or the current one) should be a table model
        # this allows FastAPI cloning a SQLModel for the response_model without
        # trying to create a new SQLAlchemy, for a new table, with the same name, that
        # triggers an error
        base_table: type | None = None
        is_polymorphic = False
        for base in bases:
            config = getattr(base, "__config__")  # noqa: B009
            if config and getattr(config, "table", False):
                base_table = base
                is_polymorphic = bool(getattr(base, "__mapper_args__", {}).get("polymorphic_on"))
                break
        is_polymorphic &= bool(getattr(cls, "__mapper_args__", {}).get("polymorphic_identity"))
        if getattr(cls.__config__, "table", False) and (not base_table or is_polymorphic):
            dict_used = dict_.copy()
            for field_name, field_value in cls.__fields__.items():
                # Do not include fields from the parent table if we are using inheritance
                if base_table and field_name in getattr(base_table, "__fields__", {}):
                    continue
                dict_used[field_name] = get_column_from_field(field_value)
            for rel_name, rel_info in cls.__sqlmodel_relationships__.items():
                # Do not include fields from the parent table if we are using inheritance
                if base_table and rel_name in getattr(base_table, "__sqlmodel_relationships__", {}):
                    continue
                if rel_info.sa_relationship:
                    # There's a SQLAlchemy relationship declared, that takes precedence
                    # over anything else, use that and continue with the next attribute
                    dict_used[rel_name] = rel_info.sa_relationship
                    continue
                ann = cls.__annotations__[rel_name]
                temp_field = ModelField.infer(
                    name=rel_name,
                    value=rel_info,
                    annotation=ann,
                    class_validators=None,
                    config=BaseConfig,
                )
                relationship_to = temp_field.type_
                if isinstance(temp_field.type_, ForwardRef):
                    relationship_to = temp_field.type_.__forward_arg__
                rel_kwargs: dict[str, Any] = {}
                if rel_info.back_populates:
                    rel_kwargs["back_populates"] = rel_info.back_populates
                if rel_info.link_model:
                    ins = inspect(rel_info.link_model)
                    local_table = getattr(ins, "local_table")  # noqa: B009
                    if local_table is None:
                        msg = f"Couldn't find the secondary table for model {rel_info.link_model}"
                        raise RuntimeError(msg)
                    rel_kwargs["secondary"] = local_table
                rel_args: list[Any] = []
                if rel_info.sa_relationship_args:
                    rel_args.extend(rel_info.sa_relationship_args)
                if rel_info.sa_relationship_kwargs:
                    rel_kwargs.update(rel_info.sa_relationship_kwargs)
                rel_value: RelationshipProperty = relationship(  # type: ignore
                    relationship_to, *rel_args, **rel_kwargs
                )
                dict_used[rel_name] = rel_value
                setattr(cls, rel_name, rel_value)  # Fix #315
            PatchedDeclarativeMeta.__init__(cls, classname, bases, dict_used, **kw)  # type: ignore
        else:
            ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)


class PatchedDeclarativeMeta(DeclarativeMeta):  # noqa: D101
    def __init__(
        cls, classname: str, bases: tuple[type, ...], dict_, **kw  # noqa: N805, ANN001, ARG002, ANN101, ANN003
    ) -> None:
        # early-consume registry from the initial declarative base,
        # assign privately to not conflict with subclass attributes named
        # "registry"
        reg = getattr(cls, "_sa_registry", None)
        if reg is None:
            reg = dict_.get("registry", None)
            if not isinstance(reg, registry):
                msg = "Declarative base class has no 'registry' attribute, or registry is not a sqlalchemy.orm.registry() object"
                raise exc.InvalidRequestError(msg)
            cls._sa_registry = reg
        if not cls.__dict__.get("__abstract__", False):
            _as_declarative(reg, cls, dict_)
        type.__init__(cls, classname, bases, dict_)
And you'd use it like this for joined inheritance:
from sqlmodel import Field, SQLModel


class Employee(SQLModel, table=True, metaclass=SQLModelPolymorphicAwareMetaClass):
    __tablename__ = "employee"
    id: int = Field(primary_key=True)
    name: str = Field(nullable=True)
    type: str

    __mapper_args__ = {
        "polymorphic_identity": "employee",
        "polymorphic_on": "type",
    }


class Engineer(Employee, table=True):
    __tablename__ = "engineer"  # joined inheritance: the subclass gets its own table
    employee_id: int = Field(primary_key=True, foreign_key="employee.id")
    engineer_info: str = Field(nullable=True)

    __mapper_args__ = {
        "polymorphic_identity": "engineer",
    }
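To see what joined table inheritance requires at the database level, here is a plain sqlite3 sketch of the tables behind the Employee/Engineer example above. It does not use SQLModel or SQLAlchemy. It assumes the child table is named engineer, and the rows are made up. The child's primary key is also a foreign key to the parent, and loading an Engineer means joining the two tables on that key.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE employee (
        id INTEGER PRIMARY KEY,
        name TEXT,
        type TEXT NOT NULL  -- discriminator ("polymorphic_on")
    );
    CREATE TABLE engineer (
        employee_id INTEGER PRIMARY KEY REFERENCES employee(id),  -- PK is also FK
        engineer_info TEXT
    );
    """
)

# An Engineer is one row in each table, sharing the same key
conn.execute("INSERT INTO employee (id, name, type) VALUES (1, 'Ada', 'engineer')")
conn.execute("INSERT INTO engineer (employee_id, engineer_info) VALUES (1, 'compilers')")
# A plain Employee only has a parent row
conn.execute("INSERT INTO employee (id, name, type) VALUES (2, 'Bob', 'employee')")

# Loading engineers joins child to parent on the inheritance condition
engineers = conn.execute(
    """
    SELECT e.id, e.name, e.type, g.engineer_info
    FROM employee AS e JOIN engineer AS g ON g.employee_id = e.id
    """
).fetchall()
```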
And like this for single table inheritance:
class Employee(SQLModel, table=True, metaclass=SQLModelPolymorphicAwareMetaClass):
    __tablename__ = "employee"
    id: int = Field(primary_key=True)
    name: str = Field(nullable=True)
    type: str

    __mapper_args__ = {
        "polymorphic_identity": "employee",
        "polymorphic_on": "type",
    }


class Engineer(Employee, table=True):
    __tablename__ = None  # Putting an explicit None here is important
    engineer_info: str = Field(nullable=True)

    __mapper_args__ = {
        "polymorphic_identity": "engineer",
    }
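In single table inheritance, every subclass is stored in the same table, and the discriminator column decides which class each row becomes. The stdlib-only sketch below reproduces that mechanism with sqlite3 and dataclasses, without an ORM. The IDENTITIES dict and the load helper are hypothetical names that play the role of polymorphic_identity and polymorphic_on. The sketch also shows why subclass-only columns such as engineer_info have to be nullable: rows for other classes leave them empty.

```python
import sqlite3
from dataclasses import dataclass
from typing import Optional


@dataclass
class Employee:
    id: int
    name: str
    type: str


@dataclass
class Engineer(Employee):
    engineer_info: Optional[str] = None


# polymorphic_identity -> Python class (hypothetical registry)
IDENTITIES = {"employee": Employee, "engineer": Engineer}

conn = sqlite3.connect(":memory:")
# One table for the whole hierarchy; subclass-only columns must be nullable
conn.execute(
    "CREATE TABLE employee ("
    "id INTEGER PRIMARY KEY, name TEXT, type TEXT NOT NULL, engineer_info TEXT)"
)
conn.executemany(
    "INSERT INTO employee (id, name, type, engineer_info) VALUES (?, ?, ?, ?)",
    [(1, "Ada", "engineer", "compilers"), (2, "Bob", "employee", None)],
)


def load(row: tuple) -> Employee:
    """Pick the class from the discriminator column, then build the instance."""
    id_, name, type_, engineer_info = row
    cls = IDENTITIES[type_]
    if cls is Engineer:
        return Engineer(id_, name, type_, engineer_info)
    return cls(id_, name, type_)


people = [
    load(row)
    for row in conn.execute("SELECT id, name, type, engineer_info FROM employee ORDER BY id")
]
```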
It looks like a lot of code, but that's mostly because I needed to copy the original metaclasses and make a few modifications here and there. The actual changes are very small (fewer than 10 lines).
PS. This is a workaround and code WILL MOST LIKELY BREAK when either SQLModel or SQLAlchemy update.
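Without the SQLAlchemy plumbing, the metaclass above makes two decisions. First, it decides whether the class is mapped at all: it is mapped only if it is a table and either has no table base, or its table base declares polymorphic_on and the class declares a polymorphic_identity. Second, it decides which fields become the class's own columns. The stdlib-only sketch below covers only that logic. find_table_base and plan_columns are hypothetical names, and plain dicts stand in for pydantic's __fields__. The real code applies the same parent filtering to relationships.

```python
from types import SimpleNamespace
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple


def find_table_base(bases: Sequence[type]) -> Tuple[Optional[type], bool]:
    """First base marked as a table, plus whether it declares polymorphic_on."""
    for base in bases:
        config = getattr(base, "__config__", None)
        if config and getattr(config, "table", False):
            mapper_args = getattr(base, "__mapper_args__", {})
            return base, bool(mapper_args.get("polymorphic_on"))
    return None, False


def plan_columns(
    is_table: bool,
    mapper_args: Mapping[str, Any],
    fields: Mapping[str, Any],
    bases: Sequence[type],
) -> Optional[Dict[str, Any]]:
    """Fields that become this class's own columns, or None if the class
    stays a plain (unmapped) model."""
    base_table, is_polymorphic = find_table_base(bases)
    is_polymorphic = is_polymorphic and bool(mapper_args.get("polymorphic_identity"))
    if not is_table or (base_table is not None and not is_polymorphic):
        return None
    inherited = getattr(base_table, "__fields__", {}) if base_table else {}
    # Skip fields the parent table already maps, to avoid duplicate columns
    return {name: f for name, f in fields.items() if name not in inherited}


# Hypothetical stand-in for the Employee table model
class Employee:
    __config__ = SimpleNamespace(table=True)
    __mapper_args__ = {"polymorphic_identity": "employee", "polymorphic_on": "type"}
    __fields__ = {"id": int, "name": str, "type": str}


engineer_fields = {"id": int, "name": str, "type": str, "engineer_info": str}
```

For an Engineer subclass with polymorphic_identity "engineer", only engineer_info is left as a new column. Without an identity, the subclass is not mapped as a table.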
Any updates on this one?
just checking
I can't really use metaclasses for my app, so here's an ugly workaround for joined inheritance (SQLModel v0.0.18). I prefer the code from @andruli.
from enum import Enum

from db import engine
from sqlalchemy import ForeignKey
from sqlmodel import Field, SQLModel
from sqlmodel.main import default_registry

PKey = Field(primary_key=True)
FKey = ForeignKey("barcodes.barcode")


class BarcodedType(str, Enum):
    Tool = "Tool"
    User = "User"


class Barcoded(SQLModel, table=True):
    __tablename__ = "barcodes"  # type: ignore
    barcode: str = Field(primary_key=True, max_length=10)
    type: BarcodedType

    __mapper_args__ = {
        "polymorphic_identity": "Barcoded",
        "polymorphic_abstract": True,
        "polymorphic_on": "type",
    }


# @default_registry.mapped  # does not work when there are multiple child classes
class Tool(Barcoded, table=True):
    __tablename__ = "tools"  # type: ignore
    barcode: str = Field(FKey, primary_key=True)
    type: BarcodedType = BarcodedType.Tool
    name: str = Field()

    __mapper_args__ = {
        "polymorphic_identity": BarcodedType.Tool,
        # can't set this here, or future queries will raise
        # _mysql_connector.MySQLInterfaceError: Python type FieldInfo cannot be converted
        # this gets memoised in sqlalchemy, you can't let it get resolved early
        # "inherit_condition": barcode == Barcoded.barcode,
        "inherit_condition": None,  # masks type error
    }


# updates Tool.barcode to replace FieldInfo with InstrumentedAttribute
SQLModel.metadata.create_all(bind=engine)
# set the join condition _after_ converting the Tool.barcode attribute to sqlalchemy native
# otherwise, you get _mysql_connector.MySQLInterfaceError: Python type FieldInfo cannot be converted
Tool.__mapper_args__["inherit_condition"] = Barcoded.barcode

# raises: dictionary changed size during iteration
# Tool = default_registry.mapped(Tool)


# @default_registry.mapped  # does not work when there are multiple child classes
class User(Barcoded, table=True):
    __tablename__ = "users"  # type: ignore
    barcode: str = Field(FKey, primary_key=True)
    name: str = Field()

    __mapper_args__ = {
        "polymorphic_identity": BarcodedType.User,
        "inherit_condition": None,  # masks type error
    }


SQLModel.metadata.create_all(bind=engine)
User.__mapper_args__["inherit_condition"] = Barcoded.barcode

# if you depend on these classes in other models, e.g. for a relationship,
# then you need to decorate the model with @mapped. That fails when any child
# class of the same parent has already been mapped. So, declare all the classes,
# then map them after declaration
Tool = default_registry.mapped(Tool)
User = default_registry.mapped(User)
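With several joined children under one abstract parent, a polymorphic load of Barcoded has to reach every child table. In SQL terms, that means left-joining each child on the condition from the commented-out inherit_condition line (barcode == Barcoded.barcode) and reading the child column that matches the discriminator. The plain sqlite3 sketch below, with made-up rows, illustrates that SQL shape. It is not the query SQLAlchemy actually emits.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE barcodes (barcode TEXT PRIMARY KEY, type TEXT NOT NULL);
    CREATE TABLE tools (barcode TEXT PRIMARY KEY REFERENCES barcodes(barcode), name TEXT);
    CREATE TABLE users (barcode TEXT PRIMARY KEY REFERENCES barcodes(barcode), name TEXT);
    INSERT INTO barcodes VALUES ('T1', 'Tool'), ('U1', 'User');
    INSERT INTO tools VALUES ('T1', 'hammer');
    INSERT INTO users VALUES ('U1', 'alice');
    """
)

# Left-join every child table on child.barcode = parent.barcode, then take the
# subclass-specific column selected by the discriminator
rows = conn.execute(
    """
    SELECT b.barcode, b.type,
           CASE b.type WHEN 'Tool' THEN t.name WHEN 'User' THEN u.name END AS name
    FROM barcodes AS b
    LEFT JOIN tools AS t ON t.barcode = b.barcode
    LEFT JOIN users AS u ON u.barcode = b.barcode
    ORDER BY b.barcode
    """
).fetchall()
```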
I couldn't find a solution for single table inheritance for SQLModel 0.0.22, so I implemented my own solution here.
Maybe related: #438, #488
I have implemented this feature in #1226.
Could a maintainer review it? @tiangolo
Operating System: Windows
Operating System Details: No response
SQLModel Version: 0.0.4
Python Version: 3.8.10
Additional Context
I think I can fall back to SQLAlchemy in this case without any problems, but maybe I'm missing something and it should be done another way. Removing the "table=True" from the inherited classes makes no difference. Maybe this is also an edge case that should not be supported, but either way it would be nice to see how people smarter than me would handle it. I am currently evaluating rewriting a backend in SQLModel, since it is already implemented with FastAPI (which is amazing), and although I know it's early days for this project, I like what it tries to achieve :)