Pogchamp-company / alembic-postgresql-enum

Alembic autogenerate support for creation, alteration and deletion of enums

Enums are not created if the schema is created in the same migration #40

Closed jankatins closed 11 months ago

jankatins commented 11 months ago

Not sure if this is an alembic bug/problem or an alembic-postgresql-enum one:

I generate a migration from sqlalchemy models which include a new schema and some tables. I also have some code which automatically creates schemas. The problem now is that this code doesn't see the new schema yet:

@alembic.autogenerate.comparators.dispatch_for("schema")
def compare_enums(autogen_context: AutogenContext, upgrade_ops: UpgradeOps, schema_names: Iterable[Union[str, None]]):
    add_create_type_false(upgrade_ops) # Here the original enums are set to "create_type=False"

    log.info("All Schemas %r", schema_names) # Here you see that the new schema is not yet included

    for schema in schema_names:
        ...
        # And here would the new enums be created, but this is never reached as the schema is not included
        create_new_enums(definitions, declarations.enum_values, schema, upgrade_ops) 

After finding this out, I now create the schema manually in one migration and then generate the real migration on top of it. But it seems that "something" here should actually be aware that there is another schema.

I now have some other code which also dispatches on "schema" and uses this code snippet to get all used schemas:

    used_schemas = set()
    for operations_group in upgrade_ops.ops:
        # For me only CreateTableOp is relevant, not sure if there could be others?
        # For the enum case, it needs to at least union in the current schema_names
        if isinstance(operations_group, CreateTableOp) and operations_group.schema:
            used_schemas.add(operations_group.schema)
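
For reference, here is a sketch of how this could be folded into the enum comparator from my first snippet, unioning the discovered schemas into schema_names before iterating (the per-schema enum handling stays elided, as above):

from collections.abc import Iterable
from typing import Union

import alembic
from alembic.autogenerate.api import AutogenContext
from alembic.operations.ops import CreateTableOp, UpgradeOps

@alembic.autogenerate.comparators.dispatch_for("schema")
def compare_enums(
    autogen_context: AutogenContext, upgrade_ops: UpgradeOps, schema_names: Iterable[Union[str, None]]
) -> None:
    # Schemas alembic already knows about; a schema that is only created
    # in this very migration is missing here
    all_schemas = set(schema_names)

    # Union in the schemas of tables that are about to be created
    for operations_group in upgrade_ops.ops:
        if isinstance(operations_group, CreateTableOp) and operations_group.schema:
            all_schemas.add(operations_group.schema)

    for schema in all_schemas:
        ...  # per-schema enum handling, as in the first snippet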
RustyGuard commented 11 months ago

Do you have the include_schemas parameter set to True in env.py? That is the "something" that informs alembic about the existence of other schemas.

jankatins commented 11 months ago

Yes, in both offline and online:

[...]
    context.configure(
        [...]
        include_schemas=True,
        [...]
    )
[...]
RustyGuard commented 11 months ago

Please show the whole env.py

RustyGuard commented 11 months ago

Since schema creation can only be done by executing raw SQL, a separate migration is necessary. However, once the schema is created, it should be picked up via include_schemas. So there should be no need to check for additional schemas in upgrade_ops.ops.
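
For example, the separate migration can be as small as this (a minimal sketch; "my_schema" is a made-up name):

"""create schema my_schema"""
from alembic import op

revision = "0001"
down_revision = None
branch_labels = None
depends_on = None

def upgrade() -> None:
    # schema creation has no autogenerate support, so emit raw SQL
    op.execute("CREATE SCHEMA my_schema")

def downgrade() -> None:
    op.execute("DROP SCHEMA my_schema")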

jankatins commented 11 months ago

I actually have a workaround to create schemas automatically. Here is the env.py (removed most schemas, because $work):

import logging
import os
from collections.abc import Iterable
from logging.config import fileConfig
from typing import Any

import alembic

# For better enum handling
import alembic_postgresql_enum  # type: ignore[import-untyped]
import sqlalchemy.sql.base
from alembic import context
from alembic.autogenerate.api import AutogenContext
from alembic.operations.ops import (
    CreateTableOp,
    ExecuteSQLOp,
    UpgradeOps,
)
from alembic.script import ScriptDirectory
from alembic_utils.pg_grant_table import PGGrantTable
from alembic_utils.replaceable_entity import register_entities
from sqlalchemy import engine_from_config, pool

import elt_service.ingestion.sqlalchemy_models as ingestion_metadata_classes
from app.sqlalchemy import (
    a_sqla_classes,
    b_sqla_classes,
    # ... lots more ...
)
from elt_service.settings_run_migrations import SettingsRunMigrations

logging.basicConfig(level=logging.INFO)

# To not get it removed again on autoformat
assert alembic_postgresql_enum  # noqa: S101

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# a local logger
_logger = logging.getLogger("alembic")

# Add here the SQL Alchemy metadata entries
target_metadata = [
    ingestion_metadata_classes.Base.metadata,
    a_sqla_classes.Base.metadata,
    b_sqla_classes.Base.metadata,
    # ... again, lots more ...
]

# add functions
replaceable_entities = []
replaceable_entities.extend(ingestion_metadata_classes.replaceable_entities)

register_entities(replaceable_entities)

class ExecuteArbitraryDDLOp(ExecuteSQLOp):
    def __init__(
        self,
        ddl: sqlalchemy.sql.base.Executable | str,
        reverse_ddl: sqlalchemy.sql.base.Executable | str,
        *,
        execution_options: dict[str, Any] | None = None,
    ) -> None:
        """A DDL Operation with both upgrade and downgrade commands."""
        super().__init__(ddl, execution_options=execution_options)
        self.reverse_ddl = reverse_ddl

    def reverse(self) -> "ExecuteArbitraryDDLOp":
        """Return the reverse of this ArbitraryDDL operation (used for downgrades)."""
        return ExecuteArbitraryDDLOp(
            ddl=self.reverse_ddl, reverse_ddl=self.sqltext, execution_options=self.execution_options
        )

@alembic.autogenerate.comparators.dispatch_for("schema")
def create_missing_schemas(
    autogen_context: AutogenContext, upgrade_ops: UpgradeOps, schema_names: Iterable[str | None]
) -> None:
    """Creates missing schemas.

    This depends on sqla/alembic to give us all existing
    schemas in the schema_names argument.
    """
    used_schemas = set()
    for operations_group in upgrade_ops.ops:
        # We only care about Tables at the top level, so this is enough.
        if isinstance(operations_group, CreateTableOp) and operations_group.schema:
            used_schemas.add(operations_group.schema)

    existing_schemas = set(schema_names)
    missing_schemas = used_schemas - existing_schemas
    if missing_schemas:
        for schema in missing_schemas:
            _logger.info("Add migration ops for schema: %s", schema)
            upgrade_ops.ops.insert(
                0,
                ExecuteArbitraryDDLOp(
                    ddl=f"CREATE SCHEMA {schema}",
                    reverse_ddl=f"DROP SCHEMA {schema}",
                ),
            )

def process_revision_directives(context, revision, directives) -> None:  # noqa: ANN001, missing type annotations
    """Change the numbering of migrations to something sane."""
    # From: https://stackoverflow.com/a/67398484/1380673
    # extract Migration
    migration_script = directives[0]
    # extract current head revision
    head_revision = ScriptDirectory.from_config(context.config).get_current_head()

    if head_revision is None:
        # edge case with first migration
        new_rev_id = 1
    else:
        # default branch with incrementation
        last_rev_id = int(head_revision.lstrip("0"))
        new_rev_id = last_rev_id + 1
    # fill zeros up to 4 digits: 1 -> 0001
    migration_script.rev_id = f"{new_rev_id:04}"

settings = SettingsRunMigrations()

# settings.POSTGRES_URL must be a string
config.set_main_option(name="sqlalchemy.url", value=str(settings.POSTGRES_URL))

def include_object(obj, name, type_, reflected, compare_to) -> bool:  # noqa: ANN001 no annotation
    # These are actually correct, but for now we keep the grants we get by default
    if isinstance(obj, PGGrantTable):
        return False
    return True

def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well.  By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,  # type: ignore[arg-type]
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        process_revision_directives=process_revision_directives,
        include_schemas=True,
        include_object=include_object,
        compare_server_default=True,
        compare_type=True,
    )

    with context.begin_transaction():
        context.run_migrations()

def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    In this scenario we need to create an Engine
    and associate a connection with the context.

    """
    connectable = context.config.attributes.get("connection", None)

    if connectable is None:
        prd = process_revision_directives
        connectable = engine_from_config(
            config.get_section(config.config_ini_section) or {},  # type: ignore[arg-type]
            prefix="sqlalchemy.",
            poolclass=pool.NullPool,
        )
    else:
        # This branch is taken in alembic util tests, where we do not want the
        # migrations to get contaminated, which would spook the tests
        def prd(context, revision, directives) -> None:  # noqa: ANN001, missing type annotations
            pass

    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,  # type: ignore[arg-type]
            process_revision_directives=prd,
            include_schemas=True,
            include_object=include_object,
            compare_server_default=True,
            compare_type=True,
        )

        with context.begin_transaction():
            context.run_migrations()

if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
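
With this in place, an autogenerated migration for models in a brand-new schema starts roughly like this (a sketch with made-up table and schema names, assuming ExecuteSQLOp renders as a plain op.execute() call):

import sqlalchemy as sa
from alembic import op

def upgrade() -> None:
    # prepended by create_missing_schemas() above
    op.execute("CREATE SCHEMA my_schema")
    op.create_table(
        "my_table",
        sa.Column("id", sa.Integer(), primary_key=True),
        schema="my_schema",
    )

def downgrade() -> None:
    op.drop_table("my_table", schema="my_schema")
    op.execute("DROP SCHEMA my_schema")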
RustyGuard commented 11 months ago

As far as I can see, this also affects users that add CREATE SCHEMA manually.

RustyGuard commented 11 months ago

So searching for schemas in upgrade_ops.ops is justified.

jankatins commented 11 months ago

Thanks a lot for the fix and the new release!