snowflakedb / snowpark-python

Snowflake Snowpark Python API
Apache License 2.0

SNOW-763621: error_value = {'done_format_msg': False, 'errno': 255005, 'msg': "Failed to read next arrow batch: b''"} #737

Open wassimrkik opened 1 year ago

wassimrkik commented 1 year ago

Please answer these questions before submitting your issue. Thanks!

  1. What version of Python are you using?

    3.8.16

  2. What operating system and processor architecture are you using?

    Linux-5.4.231-137.341.amzn2.x86_64-x86_64-with-glibc2.29

  3. What are the component versions in the environment (pip freeze)?

adal==1.2.7 alembic==1.4.1 argon2-cffi==21.1.0 arviz==0.11.2 asgiref==3.4.1 asn1crypto==1.4.0 attrs==21.2.0 azure-common==1.1.27 azure-core==1.20.1 azure-graphrbac==0.61.1 azure-identity==1.2.0 azure-mgmt-authorization==0.61.0 azure-mgmt-containerregistry==8.2.0 azure-mgmt-core==1.3.0 azure-mgmt-keyvault==2.2.0 azure-mgmt-resource==13.0.0 azure-mgmt-storage==11.2.0 azure-storage-blob==12.1.0 azure-storage-file-datalake==12.0.0 azureml-core==1.22.0 backcall==0.2.0 backports.tempfile==1.0 backports.weakref==1.0.post1 bleach==4.1.0 boto3==1.14.1 botocore==1.17.1 cachetools==4.2.4 certifi==2020.12.5 cffi==1.15.0 cftime==1.5.1.1 chardet==3.0.4 charset-normalizer==2.0.7 click==8.0.3 cloudpickle==2.0.0 cmdstanpy==0.9.76 contextlib2==0.5.5 convertdate==2.3.2 cryptography==3.4.8 cycler==0.11.0 databricks-cli==0.16.2 dataclasses==0.6 debugpy==1.5.1 decorator==5.1.0 defusedxml==0.7.1 dill==0.3.4 docker==5.0.3 docutils==0.15.2 entrypoints==0.3 fastapi==0.68.1 fastapi-utils==0.2.1 Flask==2.0.2 fsspec==2021.11.0 -e git+https://github.com/Sanofi-GitHub/AnP_Spend_Allocation_ML@45aa6a414dac4726414cb59b75249ccb3a8ef661#egg=gamma_anp_core&subdirectory=libs/gamma-anp-core -e git+https://github.com/Sanofi-GitHub/AnP_Spend_Allocation_ML@45aa6a414dac4726414cb59b75249ccb3a8ef661#egg=gamma_data_manager&subdirectory=libs/gamma-data-manager gitdb==4.0.9 GitPython==3.1.24 greenlet==1.1.2 gunicorn==20.1.0 gurobipy==9.1.0 h11==0.12.0 hijri-converter==2.2.2 holidays==0.11.1 idna==2.10 importlib-metadata==4.8.2 importlib-resources==5.4.0 iniconfig==1.1.1 ipykernel==6.5.1 ipython==7.29.0 ipython-genutils==0.2.0 ipywidgets==7.6.5 isodate==0.6.0 itsdangerous==2.0.1 jedi==0.18.1 jeepney==0.7.1 Jinja2==3.0.3 jmespath==0.10.0 joblib==1.1.0 jsonpickle==2.0.0 jsonschema==4.2.1 jupyter==1.0.0 jupyter-client==7.1.0 jupyter-console==6.4.0 jupyter-core==4.9.1 jupyterlab-pygments==0.1.2 jupyterlab-widgets==1.0.2 kiwisolver==1.3.2 korean-lunar-calendar==0.2.1 Mako==1.1.6 MarkupSafe==2.0.1 matplotlib==3.1.2 matplotlib-inline==0.1.3 mistune==0.8.4 mlflow==1.20.1 msal==1.16.0 msal-extensions==0.1.3 msrest==0.6.21 msrestazure==0.6.4 multiprocess==0.70.12.2 mypy-extensions==0.4.3 nbclient==0.5.9 nbconvert==6.3.0 nbformat==5.1.3 ndg-httpsclient==0.5.1 nest-asyncio==1.5.1 netCDF4==1.5.8 networkx==2.8.4 nose==1.3.7 notebook==6.4.6 numpy==1.18.5 oauthlib==3.1.1 oscrypto==1.2.1 packaging==21.3 pandas==0.25.2 pandera==0.6.1 pandocfilters==1.5.0 parso==0.8.2 pathos==0.2.8 pathspec==0.9.0 patsy==0.5.2 pexpect==4.8.0 pickleshare==0.7.5 plotly==5.10.0 pluggy==1.0.0 ply==3.11 portalocker==1.7.1 pox==0.3.0 ppft==1.6.6.4 prometheus-client==0.12.0 prometheus-flask-exporter==0.18.6 prompt-toolkit==3.0.22 protobuf==3.19.1 ptyprocess==0.7.0 py==1.11.0 pyarrow==0.17.0 pyasn1==0.4.8 pycparser==2.21 pycryptodomex==3.11.0 pydantic==1.8.2 Pygments==2.10.0 PyJWT==1.7.1 PyMeeus==0.5.11 Pympler==0.9 Pyomo==6.0.1 pyOpenSSL==19.1.0 pyparsing==3.0.6 pyrsistent==0.18.0 pytest==6.2.5 python-box==4.2.3 python-dateutil==2.8.2 python-dotenv==0.13.0 python-editor==1.0.4 pytz==2020.5 PyUtilib==6.0.0 PyYAML==5.1 pyzmq==22.3.0 qtconsole==5.2.0 QtPy==1.11.2 querystring-parser==1.2.4 requests==2.26.0 requests-oauthlib==1.3.0 ruamel.yaml==0.17.17 ruamel.yaml.clib==0.2.6 s3fs==0.4.2 s3transfer==0.3.7 scikit-learn==0.24.1 scipy==1.4.1 seaborn==0.10.0 SecretStorage==3.3.1 Send2Trash==1.8.0 six==1.16.0 smmap==5.0.0 snowflake-connector-python==2.3.8 snowflake-sqlalchemy==1.3.2 SQLAlchemy==1.4.41 sqlparse==0.4.2 starlette==0.14.2 statsmodels==0.11.1 tabulate==0.8.9 tenacity==8.0.1 
terminado==0.12.1 testpath==0.5.0 threadpoolctl==3.0.0 toml==0.10.2 tornado==6.1 tqdm==4.62.3 traitlets==5.1.1 typing-extensions==3.10.0.2 typing-inspect==0.7.1 ujson==4.3.0 urllib3==1.25.11 uvicorn==0.15.0 wcwidth==0.2.5 webencodings==0.5.1 websocket-client==1.2.1 Werkzeug==2.0.2 widgetsnbextension==3.5.2 wrapt==1.13.3 xarray==0.17.0 xlrd==1.2.0 XlsxWriter==1.1.8 zipp==3.6.0

  4. What did you do?

import logging
import os
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd
import snowflake.connector
from box import Box
from pandas.api.types import is_datetime64_any_dtype
from snowflake.connector import SnowflakeConnection
from snowflake.connector.pandas_tools import write_pandas

from gamma_anp_core.config.config import Config

from ...data_filters import RdbmsFilterManager
from ...schemas.response_model.input import GeoMasterSchema
from ..base.base_dbms_manager import BaseDBMSManager

logger = logging.getLogger(__name__)

gs = GeoMasterSchema()

class SnowflakeManager(BaseDBMSManager):
    connection_type = "snowflake"

def __init__(
    self,
    database: str,
    schema: str,
    pre_filters: Dict[str, Any] = Box(),
    tables: Optional[Dict[str, str]] = None,
    read_only: bool = True,
    run_info: Optional[Dict] = None,
    **kwargs,
):
    super().__init__(
        database=database,
        schema=schema,
        read_only=read_only,
        tables=tables,
        pre_filters=pre_filters,
        run_info=run_info,
    )

    try:
        snowflake_account = os.environ["SNOWFLAKE_ACCOUNT"]
        snowflake_password = os.environ["SNOWFLAKE_PASSWORD"]
        snowflake_role = os.environ["SNOWFLAKE_ROLE"]
        snowflake_user = os.environ["SNOWFLAKE_USER"]
        snowflake_warehouse = os.environ["SNOWFLAKE_WAREHOUSE"]
    except KeyError:
        logger.error(
            "Please set Snowflake credentials through environment variables: "
            "SNOWFLAKE_ACCOUNT, SNOWFLAKE_PASSWORD, SNOWFLAKE_ROLE, SNOWFLAKE_DATABASE"
            "SNOWFLAKE_USER and SNOWFLAKE_WAREHOUSE"
        )
        raise

    logger.info(
        f"[SNOWFLAKE] database = `{self._database}` | schema = `{self._schema}`"
    )

    self.connection_args = {
        "account": snowflake_account,
        "user": snowflake_user,
        "password": snowflake_password,
        "database": self._database,
        "schema": self._schema,
        "warehouse": snowflake_warehouse,
        "role": snowflake_role,
    }

def create_connection(self) -> SnowflakeConnection:
    return snowflake.connector.connect(**self.connection_args)
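# Note: create_connection() opens a brand-new connection on every call, and the
# callers below (load_from_db, save_to_db, _delete_version) never close these
# connections explicitly.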

def load_from_db(
    self,
    table_name: str,
    filter_manager: RdbmsFilterManager,
    query_sql: Optional[str] = None,
    **kwargs,
) -> pd.DataFrame:
    """
    Args:
        - table_name: SQL query to be executed or a table name.
        - query_sql: Full SQL query to be executed (Optional)
    """
    logger.info("Loading from Snowflake: Table name:  '%s'", table_name)

    if query_sql:
        return self._execute_fetch_table_query(
            connection=self.create_connection(),
            query_sql=query_sql,
            filter_manager=filter_manager,
            table_name=table_name,
        )

    return self._fetch_table(
        connection=self.create_connection(),
        table_name=table_name,
        filter_manager=filter_manager,
        **kwargs,
    )

def save_to_db(self, db_table: str, data: pd.DataFrame, **kwargs):
    logger.info("Saving to Snowflake: Table name:  '%s'", db_table)

    # write_pandas is case-sensitive, and Snowflake objects are uppercase by
    # default, so we have to do the conversion manually
    db_table = db_table.upper()
    data.columns = [column.upper() for column in data.columns]

    # convert datetime columns to varchar, as Snowflake is not able to do the implicit conversion
    for col in data.select_dtypes(include=[np.datetime64]).columns:
        data[col] = data[col].dt.strftime("%Y-%m-%d %H:%M:%S")
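    # After this loop the converted columns hold plain strings (object dtype),
    # so write_pandas uploads them as text rather than as native timestamps.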

    success, num_chunks, num_rows, output = write_pandas(
        conn=self.create_connection(), df=data, table_name=db_table, **kwargs
    )
    logger.info(
        "Saved table to Snowflake. Success: %s. Num chunks: %s. Num rows: %s. Output: '%s'",
        success,
        num_chunks,
        num_rows,
        output,
    )

def _load_data(
    self,
    table_name: str,
    filter_manager: RdbmsFilterManager,
    query_sql: Optional[str] = None,
    **kwargs,
) -> pd.DataFrame:
    loaded_data = self.load_from_db(
        table_name=table_name,
        query_sql=query_sql,
        filter_manager=filter_manager,
        **kwargs,
    )

    loaded_data.columns = [col.lower() for col in loaded_data.columns]
    return loaded_data

def _delete_version(self, table_name: str, run_version: str, **kwargs):
    cursor = self.create_connection().cursor()
    query_sql = f"DELETE FROM {table_name.upper()} WHERE VERSION_CODE=%s"
    cursor.execute(query_sql, (run_version,))

def _save_data(
    self,
    data: pd.DataFrame,
    table_name: str,
    tech_fields: Dict[str, Any],
    is_old_version: bool,
    **kwargs,
):
    table_name = self.tables[table_name]

    if is_old_version:
        self._delete_version(table_name, tech_fields["run_version"])

    data = self._add_tech_fields(data, tech_fields)
    self.save_to_db(db_table=table_name, data=data, **kwargs)

# FIXME: At the moment, we transform all TIMESTAMP_NTZ columns. A deeper check needs to be
#       done to verify the behavior of other Snowflake's date types.
@staticmethod
def _get_table_columns(
    connection,
    table_name: str,
    requested_columns: Optional[List[str]] = None,
) -> List[str]:
    """Generates SQL statetement for Snowflake.

    This function transforms NTZ timestamps to a format understandable by pyarrow. Without
    this, the data is loaded but not interpreted correctly.
    """
    cursor = connection.cursor()
    columns = cursor.execute(f"DESCRIBE {table_name}")
    columns = [
        c for c in columns if not requested_columns or c[0] in requested_columns
    ]
    cursor.close()

    select_columns: List[str] = []
    for column_name, column_type, *_ in columns:
        select_columns.append(
            column_name
            if not column_type.startswith("TIMESTAMP_NTZ")
            else f"TO_TIMESTAMP_NTZ({column_name})::TIMESTAMP_NTZ(3) as {column_name}"
        )

    return select_columns
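# For illustration: a hypothetical table with columns (ID NUMBER, CREATED_AT TIMESTAMP_NTZ)
# would yield ["ID", "TO_TIMESTAMP_NTZ(CREATED_AT)::TIMESTAMP_NTZ(3) as CREATED_AT"].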

def _generate_select_statement(
    self, connection, table_name: str, requested_columns: Optional[List[str]] = None
) -> str:
    """Generates SQL statement for Snowflake.

    This function transforms NTZ timestamps to a format understandable by pyarrow. Without
    this, the data is loaded but not interpreted correctly.
    """
    select_columns = self._get_table_columns(
        connection=connection,
        table_name=table_name,
        requested_columns=requested_columns,
    )
    select_columns_str = ", ".join(select_columns)
    return f"SELECT {select_columns_str} FROM {table_name}"

def _execute_fetch_table_query(
    self,
    connection: SnowflakeConnection,
    query_sql: str,
    filter_manager: RdbmsFilterManager,
    table_name: str,
) -> pd.DataFrame:
    """
    Fetches a table from Snowflake using the data unloading capability
    """
    logger.info(f"Query from Snowflake : {query_sql}")
    cursor = connection.cursor()

    if filter_manager:
        entity_columns = self._get_table_columns(
            connection=connection,
            table_name=table_name,
        )

        query_sql, condition_values = self._filter_data_rdbms(
            table_name, query_sql, entity_columns, filter_manager
        )

        # Convert to Snowflake default uppercase except for placeholder character %s
        query_sql = query_sql.upper().replace("%S", "%s")
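        # Caveat: .upper() is applied to the whole statement, so any string
        # literals embedded in query_sql are uppercased as well.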
        cursor.execute(query_sql, condition_values)
    else:
        cursor.execute(query_sql)

    df = cursor.fetch_pandas_all()
    cursor.close()

    # Convert columns names to lowercase as Snowflake objects are in uppercase by default
    df.columns = [column.lower() for column in df.columns]
    return df

def _fetch_table(
    self,
    connection: SnowflakeConnection,
    table_name: str,
    filter_manager: RdbmsFilterManager,
    columns: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Fetches a table from Snowflake using the data unloading capability
    Args:
        - table_name: Name of table in the database or query to execute
    """
    select = self._generate_select_statement(connection, table_name, columns)

    return self._execute_fetch_table_query(
        connection=connection,
        query_sql=select,
        filter_manager=filter_manager,
        table_name=table_name,
    )

@property
def latest_input_version(self) -> Union[str, None]:
    connection = self.create_connection()
    cursor = connection.cursor()

    # First check if any table is versioned and has geo dimension
    if self.pre_filters.get("internal_geo_code") is not None:
        query_sql = """
            SELECT column_name, table_name
            FROM information_schema.columns a
            WHERE lower(table_schema) = %s
            AND lower(table_name) IN (%s)
            AND lower(column_name) = 'version_code'
            AND EXISTS (SELECT 1 from information_schema.columns b
                        WHERE b.table_schema = a.table_schema
                        AND b.table_name = a.table_name
                        AND lower(b.column_name) = 'internal_geo_code')
        """
    else:
        query_sql = """
        SELECT column_name, table_name
        FROM information_schema.columns a
        WHERE lower(table_schema) = %s
        AND lower(table_name) IN (%s)
        AND lower(column_name) = 'version_code'
        """

    condition_values = [self._schema.lower(), list(self.tables.values())]
    cursor.execute(query_sql, condition_values)
    versioned_tables_df = cursor.fetch_pandas_all()

    # No table is versioned
    if versioned_tables_df.empty:
        return None

    if self.pre_filters.get("internal_geo_code") is not None:
        # Get the latest version_code available for the specific geo code
        table_name = versioned_tables_df["TABLE_NAME"].iloc[0].lower()
        query_sql = f"""
            SELECT MAX(version_code) as version_code
            FROM {self._schema.lower()}.{table_name}
            WHERE version_code REGEXP '[0-9]{{8}}_[0-9]{{6}}'
            AND {gs.internal_geo_code} = '{self.pre_filters.internal_geo_code}'
        """
    else:
        # Get the latest version_code available
        table_name = versioned_tables_df["TABLE_NAME"].iloc[0].lower()
        query_sql = f"""
            SELECT MAX(version_code) as version_code
            FROM {self._schema.lower()}.{table_name}
            WHERE version_code REGEXP '[0-9]{{8}}_[0-9]{{6}}'
        """

    cursor.execute(query_sql)
    latest_input_version = cursor.fetch_pandas_all().loc[0, "VERSION_CODE"]
    connection.close()

    return latest_input_version

def apply_input_version(
    self, input_version: str, config: Config, **kwargs
) -> Config:
    config._config.merge_update(
        {
            "scope": {
                "filters": {"shared": {"version_code": [{"equal": input_version}]}}
            }
        }
    )
    return config
  5. What did you expect to see?

It is used by a pytest module to retrieve data from a Snowflake database.
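For context, a minimal sketch of how a test might drive this class (the database, schema, and table names here are placeholders, not names from the actual project; the SNOWFLAKE_* environment variables read in __init__ must be set):

    # Hypothetical pytest-style usage; all names below are illustrative.
    import pandas as pd

    def test_load_table():
        manager = SnowflakeManager(database="MY_DB", schema="MY_SCHEMA")
        df = manager.load_from_db(table_name="GEO_MASTER", filter_manager=None)
        assert isinstance(df, pd.DataFrame)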

  6. Can you set logging to DEBUG and collect the logs?

    import logging
    
    for logger_name in ('snowflake.snowpark', 'snowflake.connector'):
       logger = logging.getLogger(logger_name)
       logger.setLevel(logging.DEBUG)
       ch = logging.StreamHandler()
       ch.setLevel(logging.DEBUG)
       ch.setFormatter(logging.Formatter('%(asctime)s - %(threadName)s %(filename)s:%(lineno)d - %(funcName)s() - %(levelname)s - %(message)s'))
       logger.addHandler(ch)
sfc-gh-jdu commented 1 year ago

Hi @wassimrkik, could you paste the full error message and briefly describe what you are doing with Snowpark and where the error occurs? Just from your code it's very hard to tell.
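For narrowing this down, a minimal sketch that exercises only the connector's Arrow fetch path — fetch_pandas_all() is the same call used in _execute_fetch_table_query above, and it appears to be the path behind "Failed to read next arrow batch" (connection parameters mirror the environment variables from the issue; the query itself is illustrative):

    import os
    import snowflake.connector

    conn = snowflake.connector.connect(
        account=os.environ["SNOWFLAKE_ACCOUNT"],
        user=os.environ["SNOWFLAKE_USER"],
        password=os.environ["SNOWFLAKE_PASSWORD"],
        warehouse=os.environ["SNOWFLAKE_WAREHOUSE"],
        role=os.environ["SNOWFLAKE_ROLE"],
    )
    try:
        cur = conn.cursor()
        cur.execute("SELECT 1 AS X")   # illustrative query
        df = cur.fetch_pandas_all()    # Arrow-based fetch, as in the reported code
        print(df)
    finally:
        conn.close()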