snowflakedb / snowpark-python

Snowflake Snowpark Python API
Apache License 2.0

SNOW-1650888: Add missing transform function for snowpark dataframe #2231

Open sfc-gh-gmahadevan opened 2 months ago

sfc-gh-gmahadevan commented 2 months ago

What is the current behavior?

The transform function is not available on the Snowpark DataFrame, whereas it is available on the Spark DataFrame. Customers use this function heavily, so it would be good to add this method to the library.

What is the desired behavior?

Add a transform function to the Snowpark DataFrame class so that it is available when we migrate customer code to Snowpark.

How would this improve snowflake-snowpark-python?

Adding this function would allow customer code to be migrated to Snowpark directly, without additional rewrites. It would also increase SMA code compatibility.

References, Other Background

Sample code :

def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame":
    result = func(self, *args, **kwargs)
    return result

Equivalent implementation in the Spark library: https://spark.apache.org/docs/latest/api/python/_modules/pyspark/sql/dataframe.html#DataFrame.transform
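
To illustrate how the proposed method would behave when chained, here is a minimal sketch. It uses a hypothetical stand-in class, `MiniFrame` (not part of Snowpark), so that it runs without a Snowflake connection; the helper names `add_constant` and `keep_columns` are also illustrative assumptions.

```python
from typing import Any, Callable, Dict, List


class MiniFrame:
    """Hypothetical stand-in for a DataFrame: a list of row dicts."""

    def __init__(self, rows: List[Dict[str, Any]]) -> None:
        self.rows = rows

    def transform(self, func: Callable[..., "MiniFrame"], *args: Any, **kwargs: Any) -> "MiniFrame":
        # Same shape as the proposed method: call func with self first,
        # then forward any extra positional and keyword arguments.
        return func(self, *args, **kwargs)


def add_constant(frame: MiniFrame, name: str, value: Any) -> MiniFrame:
    # Return a new frame with one extra column set to a constant.
    return MiniFrame([{**row, name: value} for row in frame.rows])


def keep_columns(frame: MiniFrame, *names: str) -> MiniFrame:
    # Return a new frame restricted to the named columns.
    return MiniFrame([{n: row[n] for n in names} for row in frame.rows])


frame = MiniFrame([{"a": 1, "b": 2}, {"a": 3, "b": 4}])

# Each transform call returns a new frame, so calls can be chained.
result = (
    frame.transform(add_constant, "c", value=0)
    .transform(keep_columns, "a", "c")
)
print(result.rows)
```

Because the method only forwards its arguments to `func`, any function that takes a frame as its first parameter and returns a frame can take part in the chain.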

sfc-gh-gmahadevan commented 2 months ago

Created a Jira story here: https://snowflakecomputing.atlassian.net/browse/SNOW-1649742

sfc-gh-sghosh commented 2 months ago

Hello @sfc-gh-gmahadevan,

Thanks for raising the issue. Yes, at present the Snowpark DataFrame API doesn't have a direct transform method; we will look into it. In the meantime, you can achieve the same result this way:

```
from snowflake.snowpark import Session
from snowflake.snowpark.functions import col

# `session` is assumed to be an existing, connected snowflake.snowpark.Session.
df = session.create_dataframe([[1, 1.0], [2, 2.0]], schema=["int_col", "float_col"])

def cast_all_to_int(input_df):
    # Cast every column to INTEGER.
    return input_df.select([col(col_name).cast("INTEGER") for col_name in input_df.columns])

def sort_columns_asc(input_df):
    # Reorder the columns alphabetically by name.
    return input_df.select(*sorted(input_df.columns))

def transform(df, func, *args, **kwargs):
    # Stand-in for DataFrame.transform: apply func to df, forwarding extra arguments.
    return func(df, *args, **kwargs)

transformed_df = transform(df, cast_all_to_int)
sorted_transformed_df = transform(transformed_df, sort_columns_asc)

result = sorted_transformed_df.collect()

for row in result:
    print(row)

# Output:
# Row(CAST ("FLOAT_COL" AS INT)=1, CAST ("INT_COL" AS INT)=1)
# Row(CAST ("FLOAT_COL" AS INT)=2, CAST ("INT_COL" AS INT)=2)
```
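
Until a built-in method exists, one option is to attach the helper to the class at runtime so that migrated code can keep calling `df.transform(...)`. This is an assumption, not something confirmed in this thread. The sketch below shows the pattern on a hypothetical `StandInFrame` class; applying the same assignment to `snowflake.snowpark.DataFrame` is untested here.

```python
from typing import Any, Callable, Iterable


def transform(self: Any, func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    # Same logic as the workaround above, written so it can serve as a method.
    return func(self, *args, **kwargs)


class StandInFrame:
    """Hypothetical class standing in for a DataFrame that lacks transform()."""

    def __init__(self, columns: Iterable[str]) -> None:
        self.columns = list(columns)


def sort_columns_asc(frame: StandInFrame) -> StandInFrame:
    # Return a new frame with the column names in ascending order.
    return StandInFrame(sorted(frame.columns))


# Attach the method only if the class does not already provide one,
# so a future built-in implementation would take precedence.
if not hasattr(StandInFrame, "transform"):
    StandInFrame.transform = transform

patched = StandInFrame(["int_col", "float_col"]).transform(sort_columns_asc)
print(patched.columns)
```

The `hasattr` guard is a design choice: if a later library release ships its own `transform`, the patch becomes a no-op instead of overriding it.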

Regards, Sujan

sfc-gh-gmahadevan commented 2 months ago

Thanks @sfc-gh-sghosh for the workaround. Please let me know once it's available.