import logging
import os
from typing import Any, Dict, List, Optional, Union
import numpy as np
import pandas as pd
import snowflake.connector
from box import Box
from pandas.api.types import is_datetime64_any_dtype
from snowflake.connector import SnowflakeConnection
from snowflake.connector.pandas_tools import write_pandas
from gamma_anp_core.config.config import Config
from ...data_filters import RdbmsFilterManager
from ...schemas.response_model.input import GeoMasterSchema
from ..base.base_dbms_manager import BaseDBMSManager
# FIX: `logging.getLogger(name)` referenced an undefined variable `name`
# (NameError at import time); the conventional module logger uses __name__.
logger = logging.getLogger(__name__)
# Canonical geo-master column names (e.g. `internal_geo_code`), used when
# building version-lookup queries below.
gs = GeoMasterSchema()
class SnowflakeManager(BaseDBMSManager):
    """DBMS manager backed by the Snowflake Python connector.

    Credentials are read from environment variables at construction time
    (SNOWFLAKE_ACCOUNT, SNOWFLAKE_PASSWORD, SNOWFLAKE_ROLE, SNOWFLAKE_USER,
    SNOWFLAKE_WAREHOUSE). A fresh connection is opened per operation via
    ``create_connection``.

    NOTE(review): connections opened by the read/write helpers are not
    explicitly closed (except in ``latest_input_version``); they are left to
    the connector/GC. Consider a context-managed connection lifecycle.
    """

    connection_type = "snowflake"

    def __init__(
        self,
        database: str,
        schema: str,
        pre_filters: Optional[Dict[str, Any]] = None,
        tables: Optional[Dict[str, str]] = None,
        read_only: bool = True,
        run_info: Optional[Dict] = None,
        **kwargs,
    ):
        """
        Args:
            - database: Snowflake database name.
            - schema: Snowflake schema name.
            - pre_filters: optional pre-filter mapping (a fresh ``Box`` is
              created when omitted). FIX: the previous default ``Box()`` was a
              shared mutable default argument.
            - tables: logical-name -> physical-table mapping.
            - read_only: passed through to the base manager.
            - run_info: passed through to the base manager.

        Raises:
            KeyError: if any required SNOWFLAKE_* environment variable is unset.
        """
        super().__init__(
            database=database,
            schema=schema,
            read_only=read_only,
            tables=tables,
            pre_filters=Box() if pre_filters is None else pre_filters,
            run_info=run_info,
        )
        try:
            snowflake_account = os.environ["SNOWFLAKE_ACCOUNT"]
            snowflake_password = os.environ["SNOWFLAKE_PASSWORD"]
            snowflake_role = os.environ["SNOWFLAKE_ROLE"]
            snowflake_user = os.environ["SNOWFLAKE_USER"]
            snowflake_warehouse = os.environ["SNOWFLAKE_WAREHOUSE"]
        except KeyError:
            # FIX: the previous message listed SNOWFLAKE_DATABASE (which is not
            # read here) and was missing a separator between two adjacent
            # string literals, fusing "...DATABASE" and "SNOWFLAKE_USER...".
            logger.error(
                "Please set Snowflake credentials through environment variables: "
                "SNOWFLAKE_ACCOUNT, SNOWFLAKE_PASSWORD, SNOWFLAKE_ROLE, "
                "SNOWFLAKE_USER and SNOWFLAKE_WAREHOUSE"
            )
            raise
        logger.info(
            f"[SNOWFLAKE] database = `{self._database}` | schema = `{self._schema}`"
        )
        self.connection_args = {
            "account": snowflake_account,
            "user": snowflake_user,
            "password": snowflake_password,
            "database": self._database,
            "schema": self._schema,
            "warehouse": snowflake_warehouse,
            "role": snowflake_role,
        }

    def create_connection(self) -> SnowflakeConnection:
        """Open and return a new Snowflake connection from stored credentials."""
        return snowflake.connector.connect(**self.connection_args)

    def load_from_db(
        self,
        table_name: str,
        filter_manager: RdbmsFilterManager,
        query_sql: Optional[str] = None,
        **kwargs,
    ) -> pd.DataFrame:
        """Load a table (or the result of an explicit query) from Snowflake.

        Args:
            - table_name: table name used for fetching / filter resolution.
            - filter_manager: provides row filters applied to the query.
            - query_sql: full SQL query to be executed (optional; when given it
              takes precedence over fetching ``table_name`` directly).
        """
        logger.info("Loading from Snowflake: Table name: '%s'", table_name)
        if query_sql:
            return self._execute_fetch_table_query(
                connection=self.create_connection(),
                query_sql=query_sql,
                filter_manager=filter_manager,
                table_name=table_name,
            )
        return self._fetch_table(
            connection=self.create_connection(),
            table_name=table_name,
            filter_manager=filter_manager,
            **kwargs,
        )

    def save_to_db(self, db_table: str, data: pd.DataFrame, **kwargs):
        """Write ``data`` to ``db_table`` via ``write_pandas``.

        ``write_pandas`` is case-sensitive and Snowflake identifiers are
        upper-case by default, so names are upper-cased manually.
        """
        logger.info("Saving to Snowflake: Table name: '%s'", db_table)
        # FIX: work on a copy — the previous code renamed columns and
        # reformatted datetime columns on the caller's DataFrame in place.
        data = data.copy()
        db_table = db_table.upper()
        data.columns = [column.upper() for column in data.columns]
        # Convert datetime columns to varchar as Snowflake is not able to do
        # the implicit conversion.
        for col in data.select_dtypes(include=[np.datetime64]).columns:
            data[col] = data[col].dt.strftime("%Y-%m-%d %H:%M:%S")
        success, num_chunks, num_rows, output = write_pandas(
            conn=self.create_connection(), df=data, table_name=db_table, **kwargs
        )
        logger.info(
            "Saved table to Snowflake. Success: %s. Num chunks: %s. Num rows: %s. Output: '%s'",
            success,
            num_chunks,
            num_rows,
            output,
        )

    def _load_data(
        self,
        table_name: str,
        filter_manager: RdbmsFilterManager,
        query_sql: Optional[str] = None,
        **kwargs,
    ) -> pd.DataFrame:
        """Load data and normalize column names to lowercase."""
        loaded_data = self.load_from_db(
            table_name=table_name,
            query_sql=query_sql,
            filter_manager=filter_manager,
            **kwargs,
        )
        loaded_data.columns = [col.lower() for col in loaded_data.columns]
        return loaded_data

    def _delete_version(self, table_name: str, run_version: str, **kwargs):
        """Delete all rows of ``table_name`` carrying ``run_version``.

        FIX: query parameters are now passed as a tuple (the connector expects
        a sequence/dict, not a bare string) and the cursor is closed.
        """
        cursor = self.create_connection().cursor()
        try:
            cursor.execute(
                f"DELETE FROM {table_name.upper()} WHERE VERSION_CODE=%s",
                (run_version,),
            )
        finally:
            cursor.close()

    def _save_data(
        self,
        data: pd.DataFrame,
        table_name: str,
        tech_fields: Dict[str, Any],
        is_old_version: bool,
        **kwargs,
    ):
        """Resolve the physical table, purge an old version if requested,
        append the technical fields and persist the data."""
        table_name = self.tables[table_name]
        if is_old_version:
            self._delete_version(table_name, tech_fields["run_version"])
        data = self._add_tech_fields(data, tech_fields)
        self.save_to_db(db_table=table_name, data=data, **kwargs)

    # FIXME: At the moment, we transform all TIMESTAMP_NTZ columns. A deeper
    # check needs to be done to verify the behavior of other Snowflake's date
    # types.
    @staticmethod
    def _get_table_columns(
        connection,
        table_name: str,
        requested_columns: Optional[List[str]] = None,
    ) -> List[str]:
        """Return SELECT column expressions for ``table_name``.

        TIMESTAMP_NTZ columns are wrapped so pyarrow interprets them
        correctly; without this the data is loaded but misread.
        """
        cursor = connection.cursor()
        columns = cursor.execute(f"DESCRIBE {table_name}")
        columns = [
            c for c in columns if not requested_columns or c[0] in requested_columns
        ]
        cursor.close()
        select_columns: List[str] = []
        for column_name, column_type, *_ in columns:
            select_columns.append(
                column_name
                if not column_type.startswith("TIMESTAMP_NTZ")
                else f"TO_TIMESTAMP_NTZ({column_name})::TIMESTAMP_NTZ(3) as {column_name}"
            )
        return select_columns

    def _generate_select_statement(
        self, connection, table_name: str, requested_columns: Optional[List[str]] = None
    ) -> str:
        """Generates SQL statement for Snowflake.

        This function transforms NTZ timestamps to a format understandable by
        pyarrow. Without this, the data is loaded but not interpreted
        correctly.
        """
        select_columns = self._get_table_columns(
            connection=connection,
            table_name=table_name,
            requested_columns=requested_columns,
        )
        select_columns_str = ", ".join(select_columns)
        return f"SELECT {select_columns_str} FROM {table_name}"

    def _execute_fetch_table_query(
        self,
        connection: SnowflakeConnection,
        query_sql: str,
        filter_manager: RdbmsFilterManager,
        table_name: str,
    ) -> pd.DataFrame:
        """Fetches a table from Snowflake using the data unloading capability."""
        logger.info(f"Query from Snowflake : {query_sql}")
        cursor = connection.cursor()
        if filter_manager:
            entity_columns = self._get_table_columns(
                connection=connection,
                table_name=table_name,
            )
            query_sql, condition_values = self._filter_data_rdbms(
                table_name, query_sql, entity_columns, filter_manager
            )
            # Convert to Snowflake default uppercase except for placeholder
            # character %s.
            query_sql = query_sql.upper().replace("%S", "%s")
            cursor.execute(query_sql, condition_values)
        else:
            cursor.execute(query_sql)
        df = cursor.fetch_pandas_all()
        cursor.close()
        # Convert column names to lowercase as Snowflake objects are in
        # uppercase by default.
        df.columns = [column.lower() for column in df.columns]
        return df

    def _fetch_table(
        self,
        connection: SnowflakeConnection,
        table_name: str,
        filter_manager: RdbmsFilterManager,
        columns: Optional[List[str]] = None,
    ) -> pd.DataFrame:
        """Fetches a table from Snowflake using the data unloading capability.

        Args:
            - table_name: Name of table in the database or query to execute.
        """
        select = self._generate_select_statement(connection, table_name, columns)
        return self._execute_fetch_table_query(
            connection=connection,
            query_sql=select,
            filter_manager=filter_manager,
            table_name=table_name,
        )

    @property
    def latest_input_version(self) -> Union[str, None]:
        """Latest ``version_code`` matching 'YYYYMMDD_HHMMSS' across the
        managed versioned tables, or None when no table is versioned.

        When an ``internal_geo_code`` pre-filter is set, only tables that also
        carry the geo dimension are considered and the version is restricted
        to that geo code.
        """
        connection = self.create_connection()
        cursor = connection.cursor()
        # First check if any table is versioned and has geo dimension.
        if self.pre_filters.get("internal_geo_code") is not None:
            query_sql = """
                SELECT column_name, table_name
                FROM information_schema.columns a
                WHERE lower(table_schema) = %s
                AND lower(table_name) IN (%s)
                AND lower(column_name) = 'version_code'
                AND EXISTS (SELECT 1 from information_schema.columns b
                            WHERE b.table_schema = a.table_schema
                            AND b.table_name = a.table_name
                            AND lower(b.column_name) = 'internal_geo_code')
            """
        else:
            query_sql = """
                SELECT column_name, table_name
                FROM information_schema.columns a
                WHERE lower(table_schema) = %s
                AND lower(table_name) IN (%s)
                AND lower(column_name) = 'version_code'
            """
        condition_values = [self._schema.lower(), list(self.tables.values())]
        cursor.execute(query_sql, condition_values)
        versioned_tables_df = cursor.fetch_pandas_all()
        # No table is versioned.
        if versioned_tables_df.empty:
            # FIX: close the connection on the early-return path too.
            connection.close()
            return None
        table_name = versioned_tables_df["TABLE_NAME"].iloc[0].lower()
        if self.pre_filters.get("internal_geo_code") is not None:
            # Get the latest version_code available for the specific geo code.
            # FIX: bind the geo code as a query parameter instead of
            # interpolating it into the SQL string.
            query_sql = f"""
                SELECT MAX(version_code) as version_code
                FROM {self._schema.lower()}.{table_name}
                WHERE version_code REGEXP '[0-9]{{8}}_[0-9]{{6}}'
                AND {gs.internal_geo_code} = %s
            """
            params = (self.pre_filters.internal_geo_code,)
        else:
            # Get the latest version_code available.
            query_sql = f"""
                SELECT MAX(version_code) as version_code
                FROM {self._schema.lower()}.{table_name}
                WHERE version_code REGEXP '[0-9]{{8}}_[0-9]{{6}}'
            """
            params = None
        cursor.execute(query_sql, params)
        latest_input_version = cursor.fetch_pandas_all().loc[0, "VERSION_CODE"]
        connection.close()
        return latest_input_version

    def apply_input_version(
        self, input_version: str, config: Config, **kwargs
    ) -> Config:
        """Merge an equality filter on ``version_code`` into the config scope."""
        config._config.merge_update(
            {
                "scope": {
                    "filters": {"shared": {"version_code": [{"equal": input_version}]}}
                }
            }
        )
        return config
What did you expect to see?
It is used by a pytest module to retrieve data from a Snowflake database.
Can you set logging to DEBUG and collect the logs?
Hi @wassimrkik, could you paste the full error message, and briefly describe what you are doing with Snowpark and where the error occurs? From your code alone it is very hard to tell.
Please answer these questions before submitting your issue. Thanks!
What version of Python are you using?
3.8.16
What operating system and processor architecture are you using?
Linux-5.4.231-137.341.amzn2.x86_64-x86_64-with-glibc2.29
What are the component versions in the environment (
pip freeze
)?adal==1.2.7 alembic==1.4.1 argon2-cffi==21.1.0 arviz==0.11.2 asgiref==3.4.1 asn1crypto==1.4.0 attrs==21.2.0 azure-common==1.1.27 azure-core==1.20.1 azure-graphrbac==0.61.1 azure-identity==1.2.0 azure-mgmt-authorization==0.61.0 azure-mgmt-containerregistry==8.2.0 azure-mgmt-core==1.3.0 azure-mgmt-keyvault==2.2.0 azure-mgmt-resource==13.0.0 azure-mgmt-storage==11.2.0 azure-storage-blob==12.1.0 azure-storage-file-datalake==12.0.0 azureml-core==1.22.0 backcall==0.2.0 backports.tempfile==1.0 backports.weakref==1.0.post1 bleach==4.1.0 boto3==1.14.1 botocore==1.17.1 cachetools==4.2.4 certifi==2020.12.5 cffi==1.15.0 cftime==1.5.1.1 chardet==3.0.4 charset-normalizer==2.0.7 click==8.0.3 cloudpickle==2.0.0 cmdstanpy==0.9.76 contextlib2==0.5.5 convertdate==2.3.2 cryptography==3.4.8 cycler==0.11.0 databricks-cli==0.16.2 dataclasses==0.6 debugpy==1.5.1 decorator==5.1.0 defusedxml==0.7.1 dill==0.3.4 docker==5.0.3 docutils==0.15.2 entrypoints==0.3 fastapi==0.68.1 fastapi-utils==0.2.1 Flask==2.0.2 fsspec==2021.11.0 -e git+https://github.com/Sanofi-GitHub/AnP_Spend_Allocation_ML@45aa6a414dac4726414cb59b75249ccb3a8ef661#egg=gamma_anp_core&subdirectory=libs/gamma-anp-core -e git+https://github.com/Sanofi-GitHub/AnP_Spend_Allocation_ML@45aa6a414dac4726414cb59b75249ccb3a8ef661#egg=gamma_data_manager&subdirectory=libs/gamma-data-manager gitdb==4.0.9 GitPython==3.1.24 greenlet==1.1.2 gunicorn==20.1.0 gurobipy==9.1.0 h11==0.12.0 hijri-converter==2.2.2 holidays==0.11.1 idna==2.10 importlib-metadata==4.8.2 importlib-resources==5.4.0 iniconfig==1.1.1 ipykernel==6.5.1 ipython==7.29.0 ipython-genutils==0.2.0 ipywidgets==7.6.5 isodate==0.6.0 itsdangerous==2.0.1 jedi==0.18.1 jeepney==0.7.1 Jinja2==3.0.3 jmespath==0.10.0 joblib==1.1.0 jsonpickle==2.0.0 jsonschema==4.2.1 jupyter==1.0.0 jupyter-client==7.1.0 jupyter-console==6.4.0 jupyter-core==4.9.1 jupyterlab-pygments==0.1.2 jupyterlab-widgets==1.0.2 kiwisolver==1.3.2 korean-lunar-calendar==0.2.1 Mako==1.1.6 MarkupSafe==2.0.1 matplotlib==3.1.2 
matplotlib-inline==0.1.3 mistune==0.8.4 mlflow==1.20.1 msal==1.16.0 msal-extensions==0.1.3 msrest==0.6.21 msrestazure==0.6.4 multiprocess==0.70.12.2 mypy-extensions==0.4.3 nbclient==0.5.9 nbconvert==6.3.0 nbformat==5.1.3 ndg-httpsclient==0.5.1 nest-asyncio==1.5.1 netCDF4==1.5.8 networkx==2.8.4 nose==1.3.7 notebook==6.4.6 numpy==1.18.5 oauthlib==3.1.1 oscrypto==1.2.1 packaging==21.3 pandas==0.25.2 pandera==0.6.1 pandocfilters==1.5.0 parso==0.8.2 pathos==0.2.8 pathspec==0.9.0 patsy==0.5.2 pexpect==4.8.0 pickleshare==0.7.5 plotly==5.10.0 pluggy==1.0.0 ply==3.11 portalocker==1.7.1 pox==0.3.0 ppft==1.6.6.4 prometheus-client==0.12.0 prometheus-flask-exporter==0.18.6 prompt-toolkit==3.0.22 protobuf==3.19.1 ptyprocess==0.7.0 py==1.11.0 pyarrow==0.17.0 pyasn1==0.4.8 pycparser==2.21 pycryptodomex==3.11.0 pydantic==1.8.2 Pygments==2.10.0 PyJWT==1.7.1 PyMeeus==0.5.11 Pympler==0.9 Pyomo==6.0.1 pyOpenSSL==19.1.0 pyparsing==3.0.6 pyrsistent==0.18.0 pytest==6.2.5 python-box==4.2.3 python-dateutil==2.8.2 python-dotenv==0.13.0 python-editor==1.0.4 pytz==2020.5 PyUtilib==6.0.0 PyYAML==5.1 pyzmq==22.3.0 qtconsole==5.2.0 QtPy==1.11.2 querystring-parser==1.2.4 requests==2.26.0 requests-oauthlib==1.3.0 ruamel.yaml==0.17.17 ruamel.yaml.clib==0.2.6 s3fs==0.4.2 s3transfer==0.3.7 scikit-learn==0.24.1 scipy==1.4.1 seaborn==0.10.0 SecretStorage==3.3.1 Send2Trash==1.8.0 six==1.16.0 smmap==5.0.0 snowflake-connector-python==2.3.8 snowflake-sqlalchemy==1.3.2 SQLAlchemy==1.4.41 sqlparse==0.4.2 starlette==0.14.2 statsmodels==0.11.1 tabulate==0.8.9 tenacity==8.0.1 terminado==0.12.1 testpath==0.5.0 threadpoolctl==3.0.0 toml==0.10.2 tornado==6.1 tqdm==4.62.3 traitlets==5.1.1 typing-extensions==3.10.0.2 typing-inspect==0.7.1 ujson==4.3.0 urllib3==1.25.11 uvicorn==0.15.0 wcwidth==0.2.5 webencodings==0.5.1 websocket-client==1.2.1 Werkzeug==2.0.2 widgetsnbextension==3.5.2 wrapt==1.13.3 xarray==0.17.0 xlrd==1.2.0 XlsxWriter==1.1.8 zipp==3.6.0
import logging import os from typing import Any, Dict, List, Optional, Union
import numpy as np import pandas as pd import snowflake.connector from box import Box from pandas.api.types import is_datetime64_any_dtype from snowflake.connector import SnowflakeConnection from snowflake.connector.pandas_tools import write_pandas
from gamma_anp_core.config.config import Config
from ...data_filters import RdbmsFilterManager from ...schemas.response_model.input import GeoMasterSchema from ..base.base_dbms_manager import BaseDBMSManager
logger = logging.getLogger(name)
gs = GeoMasterSchema()
class SnowflakeManager(BaseDBMSManager): connection_type = "snowflake"
it is used by a pytest module to retrieve data from snowflake db
Can you set logging to DEBUG and collect the logs?