NVIDIA / spark-rapids

Spark RAPIDS plugin - accelerate Apache Spark with GPUs
https://nvidia.github.io/spark-rapids
Apache License 2.0

[BUG] While running a query, GpuSortOrder throws an Exception #187

Closed: razajafri closed this issue 4 years ago

razajafri commented 4 years ago

Describe the bug
While running a query in a Jupyter notebook, GpuSortOrder throws the following exception:

Py4JJavaError: An error occurred while calling o2391.collectToPython.
: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: makeCopy, tree: count(species)#10979L DESC NULLS LAST
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:56)
    at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:458)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:431)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:350)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:314)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDown$3(TreeNode.scala:314)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChild$2(TreeNode.scala:368)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$4(TreeNode.scala:427)
    at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:238)
    at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
    at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
    at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
    at scala.collection.TraversableLike.map(TraversableLike.scala:238)
    at scala.collection.TraversableLike.map$(TraversableLike.scala:231)
    at scala.collection.AbstractTraversable.map(Traversable.scala:108)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$mapChildren$1(TreeNode.scala:427)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:237)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:397)
    at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:350)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:314)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:298)
    at org.apache.spark.sql.execution.columnar.InMemoryTableScanExec.updateAttribute(InMemoryTableScanExec.scala:170)
    at org.apache.spark.sql.execution.columnar.InMemoryTableScanExec.outputPartitioning(InMemoryTableScanExec.scala:179)
    at org.apache.spark.sql.execution.exchange.EnsureRequirements.$anonfun$ensureDistributionAndOrdering$1(EnsureRequirements.scala:54)
    at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:238)
    at scala.collection.immutable.List.foreach(List.scala:392)
    at scala.collection.TraversableLike.map(TraversableLike.scala:238)
    at scala.collection.TraversableLike.map$(TraversableLike.scala:231)
    at scala.collection.immutable.List.map(List.scala:298)
    at org.apache.spark.sql.execution.exchange.EnsureRequirements.org$apache$spark$sql$execution$exchange$EnsureRequirements$$ensureDistributionAndOrdering(EnsureRequirements.scala:53)
    at org.apache.spark.sql.execution.exchange.EnsureRequirements$$anonfun$apply$1.applyOrElse(EnsureRequirements.scala:226)
    at org.apache.spark.sql.execution.exchange.EnsureRequirements$$anonfun$apply$1.applyOrElse(EnsureRequirements.scala:218)
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformUp$2(TreeNode.scala:333)
    at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:72)
    at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:333)
    at org.apache.spark.sql.execution.exchange.EnsureRequirements.apply(EnsureRequirements.scala:218)
    at org.apache.spark.sql.execution.exchange.EnsureRequirements.apply(EnsureRequirements.scala:37)
    at org.apache.spark.sql.execution.QueryExecution$.$anonfun$prepareForExecution$1(QueryExecution.scala:316)
    at scala.collection.LinearSeqOptimized.foldLeft(LinearSeqOptimized.scala:126)
    at scala.collection.LinearSeqOptimized.foldLeft$(LinearSeqOptimized.scala:122)
    at scala.collection.immutable.List.foldLeft(List.scala:89)
    at org.apache.spark.sql.execution.QueryExecution$.prepareForExecution(QueryExecution.scala:316)
    at org.apache.spark.sql.execution.QueryExecution.$anonfun$executedPlan$1(QueryExecution.scala:107)
    at org.apache.spark.sql.catalyst.QueryPlanningTracker.measurePhase(QueryPlanningTracker.scala:111)
    at org.apache.spark.sql.execution.QueryExecution.$anonfun$executePhase$1(QueryExecution.scala:133)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:764)
    at org.apache.spark.sql.execution.QueryExecution.executePhase(QueryExecution.scala:133)
    at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:107)
    at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:100)
    at org.apache.spark.sql.execution.QueryExecution.$anonfun$writePlans$5(QueryExecution.scala:199)
    at org.apache.spark.sql.catalyst.plans.QueryPlan$.append(QueryPlan.scala:381)
    at org.apache.spark.sql.execution.QueryExecution.org$apache$spark$sql$execution$QueryExecution$$writePlans(QueryExecution.scala:199)
    at org.apache.spark.sql.execution.QueryExecution.toString(QueryExecution.scala:207)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:95)
    at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:160)
    at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:87)
    at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:764)
    at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
    at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3614)
    at org.apache.spark.sql.Dataset.collectToPython(Dataset.scala:3445)
    at jdk.internal.reflect.GeneratedMethodAccessor163.invoke(Unknown Source)
    at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.base/java.lang.reflect.Method.invoke(Method.java:566)
    at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
    at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
    at py4j.Gateway.invoke(Gateway.java:282)
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
    at py4j.commands.CallCommand.execute(CallCommand.java:79)
    at py4j.GatewayConnection.run(GatewayConnection.java:238)
    at java.base/java.lang.Thread.run(Thread.java:834)
Caused by: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: 
Failed to copy node.
Is otherCopyArgs specified correctly for GpuSortOrder.
Exception message: argument type mismatch
ctor: public ai.rapids.spark.GpuSortOrder(ai.rapids.spark.GpuExpression,org.apache.spark.sql.catalyst.expressions.SortDirection,org.apache.spark.sql.catalyst.expressions.NullOrdering,scala.collection.immutable.Set,org.apache.spark.sql.catalyst.expressions.Expression)?
types: class org.apache.spark.sql.catalyst.expressions.AttributeReference, class org.apache.spark.sql.catalyst.expressions.Descending$, class org.apache.spark.sql.catalyst.expressions.NullsLast$, class scala.collection.immutable.Set$EmptySet$, class org.apache.spark.sql.catalyst.expressions.AttributeReference
args: count(species)#10979L, Descending, NullsLast, Set(), count(species)#10979L
           , tree: count(species)#10979L DESC NULLS LAST
    at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$makeCopy$1(TreeNode.scala:505)
    at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:52)
    ... 69 more
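
For context, Catalyst's TreeNode.makeCopy rebuilds a node by reflectively invoking its constructor with the (possibly substituted) children, and the error above shows that reflective call rejecting an AttributeReference where the GpuSortOrder constructor declares a GpuExpression. Below is a minimal, Spark-free sketch of that failure mode; the class names are illustrative stand-ins, not the plugin's types.

    // Stand-ins: NarrowExpr plays the role of GpuExpression, Attr the role of
    // AttributeReference, and Node the role of GpuSortOrder.
    trait Expr
    trait NarrowExpr extends Expr
    case class Attr(name: String) extends Expr

    class Node(val child: NarrowExpr) // constructor requires the narrower type

    object MakeCopyDemo {
      def main(args: Array[String]): Unit = {
        val ctor = classOf[Node].getConstructors.head
        // Supplying an Attr (only an Expr) where NarrowExpr is declared fails at
        // reflection time with "argument type mismatch", as in the trace above.
        try ctor.newInstance(Attr("count(species)"))
        catch { case e: IllegalArgumentException => println(e.getMessage) }
      }
    }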

Steps/Code to reproduce bug
Load the attached notebook and run it:

{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Base.py + Spark Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Base.py imports  \n",
    "\n",
    "import sys\n",
    "\n",
    "try:\n",
    "    from StringIO import BytesIO\n",
    "except ImportError:\n",
    "    from io import BytesIO\n",
    "\n",
    "try:\n",
    "    from urllib import quote\n",
    "except ImportError:\n",
    "    from urllib.parse import quote\n",
    "\n",
    "import base64\n",
    "from itertools import combinations\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.use('Agg')\n",
    "\n",
    "import numpy as np\n",
    "import json\n",
    "import pandas as pd\n",
    "#import spark_df_profiling.formatters as formatters, spark_df_profiling.templates as templates\n",
    "import formatters\n",
    "#import templates\n",
    "from matplotlib import pyplot as plt\n",
    "from pkg_resources import resource_filename\n",
    "import six\n",
    "\n",
    "from pyspark.sql import DataFrame as SparkDataFrame\n",
    "from pyspark.sql.functions import (abs as df_abs, col, count, countDistinct,\n",
    "                                   max as df_max, mean, min as df_min,\n",
    "                                   sum as df_sum, when\n",
    "                                   )\n",
    "\n",
    "# Backwards compatibility with Spark 1.5:\n",
    "try:\n",
    "    from pyspark.sql.functions import variance, stddev, kurtosis, skewness\n",
    "    spark_version = \"1.6+\"\n",
    "except ImportError:\n",
    "    from pyspark.sql.functions import pow as df_pow, sqrt\n",
    "    def variance_custom(column, mean, count):\n",
    "        return df_sum(df_pow(column - mean, int(2))) / float(count-1)\n",
    "    def skewness_custom(column, mean, count):\n",
    "        return ((np.sqrt(count) * df_sum(df_pow(column - mean, int(3)))) / df_pow(sqrt(df_sum(df_pow(column - mean, int(2)))),3))\n",
    "    def kurtosis_custom(column, mean, count):\n",
    "        return ((count*df_sum(df_pow(column - mean, int(4)))) / df_pow(df_sum(df_pow(column - mean, int(2))),2)) -3\n",
    "    spark_version = \"<1.6\"\n",
    "    \n",
    "# Spark imports\n",
    "\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark.conf import SparkConf\n",
    "from pyspark.sql import SQLContext \n",
    "from pyspark import SparkContext\n",
    "import pyspark\n",
    "\n",
    "from pandas_profiling import ProfileReport\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading iris.csv as DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Initially, load the dataframe as a pandas .csv file \n",
    "iris_d = pd.read_csv('iris.csv')\n",
    "copy_iris_d = iris_d[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading iris.csv as Spark DataFrame "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataFrame[sepal_length: double, sepal_width: double, petal_length: double, petal_width: double]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iris_sparkdf = spark.createDataFrame(copy_iris_d)\n",
    "iris_sparkdf = iris_sparkdf.cache()\n",
    "iris_sparkdf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Timer Class for Benchmarking - Source: https://realpython.com/python-timer/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "class TimerError(Exception):\n",
    "    \"\"\"A custom exception used to report errors in use of Timer class\"\"\"\n",
    "\n",
    "class Timer:\n",
    "    def __init__(self):\n",
    "        self._start_time = None\n",
    "\n",
    "    def start(self):\n",
    "        \"\"\"Start a new timer\"\"\"\n",
    "        if self._start_time is not None:\n",
    "            raise TimerError(f\"Timer is running. Use .stop() to stop it\")\n",
    "\n",
    "        self._start_time = time.perf_counter()\n",
    "\n",
    "    def stop(self):\n",
    "        \"\"\"Stop the timer, and report the elapsed time\"\"\"\n",
    "        if self._start_time is None:\n",
    "            raise TimerError(f\"Timer is not running. Use .start() to start it\")\n",
    "\n",
    "        elapsed_time = time.perf_counter() - self._start_time\n",
    "        self._start_time = None\n",
    "        print(f\"Elapsed time: {elapsed_time:0.4f} seconds\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Describe Function (from base.py) - Returns all of the profiling statistics that pandas-profiling would normally output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def describe(df, bins, corr_reject, config, **kwargs):\n",
    "    if not isinstance(df, SparkDataFrame):\n",
    "        raise TypeError(\"df must be of type pyspark.sql.DataFrame\")\n",
    "\n",
    "    # Number of rows:\n",
    "    table_stats = {\"n\": df.count()}\n",
    "    if table_stats[\"n\"] == 0:\n",
    "        raise ValueError(\"df cannot be empty\")\n",
    "\n",
    "    try:\n",
    "        # reset matplotlib style before use\n",
    "        # Fails in matplotlib 1.4.x so plot might look bad\n",
    "        matplotlib.style.use(\"default\")\n",
    "    except:\n",
    "        pass\n",
    "\n",
    "    matplotlib.style.use(resource_filename(__name__, \"spark_df_profiling.mplstyle\"))\n",
    "\n",
    "    # Function to \"pretty name\" floats:\n",
    "    def pretty_name(x):\n",
    "        x *= 100\n",
    "        if x == int(x):\n",
    "            return '%.0f%%' % x\n",
    "        else:\n",
    "            return '%.1f%%' % x\n",
    "\n",
    "    # Function to compute the correlation matrix:\n",
    "    def corr_matrix(df, columns=None):\n",
    "        if columns is None:\n",
    "            columns = df.columns\n",
    "        col_combinations = combinations(columns, 2)\n",
    "\n",
    "        df_cleaned = df.select(*columns).na.drop(how=\"any\")\n",
    "\n",
    "        corr_result = pd.DataFrame(np.eye(len(columns)))\n",
    "        corr_result.columns = columns\n",
    "        corr_result.index = columns\n",
    "\n",
    "        for i, j in col_combinations:\n",
    "            corr_result[i][j] = corr_result[j][i] = df_cleaned.corr(str(i), str(j))\n",
    "\n",
    "        return corr_result\n",
    "\n",
    "    # Compute histogram (is not as easy as it looks):\n",
    "    def create_hist_data(df, column, minim, maxim, bins=10):\n",
    "\n",
    "        def create_all_conditions(current_col, column, left_edges, count=1):\n",
    "            \"\"\"\n",
    "            Recursive function that exploits the\n",
    "            ability to call the Spark SQL Column method\n",
    "            .when() in a recursive way.\n",
    "            \"\"\"\n",
    "            left_edges = left_edges[:]\n",
    "            if len(left_edges) == 0:\n",
    "                return current_col\n",
    "            if len(left_edges) == 1:\n",
    "                next_col = current_col.when(col(column) >= float(left_edges[0]), count)\n",
    "                left_edges.pop(0)\n",
    "                return create_all_conditions(next_col, column, left_edges[:], count+1)\n",
    "            next_col = current_col.when((float(left_edges[0]) <= col(column))\n",
    "                                        & (col(column) < float(left_edges[1])), count)\n",
    "            left_edges.pop(0)\n",
    "            return create_all_conditions(next_col, column, left_edges[:], count+1)\n",
    "\n",
    "        num_range = maxim - minim\n",
    "        bin_width = num_range / float(bins)\n",
    "        left_edges = [minim]\n",
    "        for _bin in range(bins):\n",
    "            left_edges = left_edges + [left_edges[-1] + bin_width]\n",
    "        left_edges.pop()\n",
    "        expression_col = when((float(left_edges[0]) <= col(column))\n",
    "                              & (col(column) < float(left_edges[1])), 0)\n",
    "        left_edges_copy = left_edges[:]\n",
    "        left_edges_copy.pop(0)\n",
    "        bin_data = (df.select(col(column))\n",
    "                    .na.drop()\n",
    "                    .select(col(column),\n",
    "                            create_all_conditions(expression_col,\n",
    "                                                  column,\n",
    "                                                  left_edges_copy\n",
    "                                                 ).alias(\"bin_id\")\n",
    "                           )\n",
    "                    .groupBy(\"bin_id\").count()\n",
    "                   ).toPandas()\n",
    "\n",
    "        # If no data goes into one bin, it won't \n",
    "        # appear in bin_data; so we should fill\n",
    "        # in the blanks:\n",
    "        bin_data.index = bin_data[\"bin_id\"]\n",
    "        new_index = list(range(bins))\n",
    "        bin_data = bin_data.reindex(new_index)\n",
    "        bin_data[\"bin_id\"] = bin_data.index\n",
    "        bin_data = bin_data.fillna(0)\n",
    "\n",
    "        # We add the left edges and bin width:\n",
    "        bin_data[\"left_edge\"] = left_edges\n",
    "        bin_data[\"width\"] = bin_width\n",
    "\n",
    "        return bin_data\n",
    "\n",
    "    def mini_histogram(histogram_data):\n",
    "        # Small histogram\n",
    "        imgdata = BytesIO()\n",
    "        hist_data = histogram_data\n",
    "        figure = plt.figure(figsize=(2, 0.75))\n",
    "        plot = plt.subplot()\n",
    "        plt.bar(hist_data[\"left_edge\"],\n",
    "                hist_data[\"count\"],\n",
    "                width=hist_data[\"width\"],\n",
    "                facecolor='#337ab7')\n",
    "        plot.axes.get_yaxis().set_visible(False)\n",
    "        plot.set_facecolor(\"w\")\n",
    "        xticks = plot.xaxis.get_major_ticks()\n",
    "        for tick in xticks[1:-1]:\n",
    "            tick.set_visible(False)\n",
    "            tick.label.set_visible(False)\n",
    "        for tick in (xticks[0], xticks[-1]):\n",
    "            tick.label.set_fontsize(8)\n",
    "        plot.figure.subplots_adjust(left=0.15, right=0.85, top=1, bottom=0.35, wspace=0, hspace=0)\n",
    "        plot.figure.savefig(imgdata)\n",
    "        imgdata.seek(0)\n",
    "        result_string = 'data:image/png;base64,' + quote(base64.b64encode(imgdata.getvalue()))\n",
    "        plt.close(plot.figure)\n",
    "        return result_string\n",
    "\n",
    "\n",
    "    def describe_integer_1d(df, column, current_result, nrows):\n",
    "        if spark_version == \"1.6+\":\n",
    "            stats_df = df.select(column).na.drop().agg(mean(col(column)).alias(\"mean\"),\n",
    "                                                       df_min(col(column)).alias(\"min\"),\n",
    "                                                       df_max(col(column)).alias(\"max\"),\n",
    "                                                       variance(col(column)).alias(\"variance\"),\n",
    "                                                       kurtosis(col(column)).alias(\"kurtosis\"),\n",
    "                                                       stddev(col(column)).alias(\"std\"),\n",
    "                                                       skewness(col(column)).alias(\"skewness\"),\n",
    "                                                       df_sum(col(column)).alias(\"sum\"),\n",
    "                                                       count(col(column) == 0.0).alias('n_zeros')\n",
    "                                                       ).toPandas()\n",
    "        else:\n",
    "            stats_df = df.select(column).na.drop().agg(mean(col(column)).alias(\"mean\"),\n",
    "                                                       df_min(col(column)).alias(\"min\"),\n",
    "                                                       df_max(col(column)).alias(\"max\"),\n",
    "                                                       df_sum(col(column)).alias(\"sum\"),\n",
    "                                                       count(col(column) == 0.0).alias('n_zeros')\n",
    "                                                       ).toPandas()\n",
    "            stats_df[\"variance\"] = df.select(column).na.drop().agg(variance_custom(col(column),\n",
    "                                                                                   stats_df[\"mean\"].iloc[0],\n",
    "                                                                                   current_result[\"count\"])).toPandas().iloc[0][0]\n",
    "            stats_df[\"std\"] = np.sqrt(stats_df[\"variance\"])\n",
    "            stats_df[\"skewness\"] = df.select(column).na.drop().agg(skewness_custom(col(column),\n",
    "                                                                                   stats_df[\"mean\"].iloc[0],\n",
    "                                                                                   current_result[\"count\"])).toPandas().iloc[0][0]\n",
    "            stats_df[\"kurtosis\"] = df.select(column).na.drop().agg(kurtosis_custom(col(column),\n",
    "                                                                                   stats_df[\"mean\"].iloc[0],\n",
    "                                                                                   current_result[\"count\"])).toPandas().iloc[0][0]\n",
    "\n",
    "        for x in [0.05, 0.25, 0.5, 0.75, 0.95]:\n",
    "            stats_df[pretty_name(x)] = (df.select(column)\n",
    "                                        .na.drop()\n",
    "                                        .selectExpr(\"percentile(`{col}`,CAST({n} AS DOUBLE))\"\n",
    "                                                    .format(col=column, n=x)).toPandas().iloc[:,0]\n",
    "                                        )\n",
    "        stats = stats_df.iloc[0].copy()\n",
    "        stats.name = column\n",
    "        stats[\"range\"] = stats[\"max\"] - stats[\"min\"]\n",
    "        stats[\"iqr\"] = stats[pretty_name(0.75)] - stats[pretty_name(0.25)]\n",
    "        stats[\"cv\"] = stats[\"std\"] / float(stats[\"mean\"])\n",
    "        stats[\"mad\"] = (df.select(column)\n",
    "                        .na.drop()\n",
    "                        .select(df_abs(col(column)-stats[\"mean\"]).alias(\"delta\"))\n",
    "                        .agg(df_sum(col(\"delta\"))).toPandas().iloc[0,0] / float(current_result[\"count\"]))\n",
    "        stats[\"type\"] = \"NUM\"\n",
    "        stats['p_zeros'] = stats['n_zeros'] / float(nrows)\n",
    "\n",
    "        # Large histogram\n",
    "        imgdata = BytesIO()\n",
    "        hist_data = create_hist_data(df, column, stats[\"min\"], stats[\"max\"], bins)\n",
    "        figure = plt.figure(figsize=(6, 4))\n",
    "        plot = plt.subplot()\n",
    "        plt.bar(hist_data[\"left_edge\"],\n",
    "                hist_data[\"count\"],\n",
    "                width=hist_data[\"width\"],\n",
    "                facecolor='#337ab7')\n",
    "        plot.set_ylabel(\"Frequency\")\n",
    "        plot.figure.subplots_adjust(left=0.15, right=0.95, top=0.9, bottom=0.1, wspace=0, hspace=0)\n",
    "        plot.figure.savefig(imgdata)\n",
    "        imgdata.seek(0)\n",
    "        stats['histogram'] = 'data:image/png;base64,' + quote(base64.b64encode(imgdata.getvalue()))\n",
    "        #TODO Think about writing this to disk instead of caching them in strings\n",
    "        plt.close(plot.figure)\n",
    "\n",
    "        stats['mini_histogram'] = mini_histogram(hist_data)\n",
    "\n",
    "        return stats\n",
    "\n",
    "    def describe_float_1d(df, column, current_result, nrows):\n",
    "        if spark_version == \"1.6+\":\n",
    "            stats_df = df.select(column).na.drop().agg(mean(col(column)).alias(\"mean\"),\n",
    "                                                       df_min(col(column)).alias(\"min\"),\n",
    "                                                       df_max(col(column)).alias(\"max\"),\n",
    "                                                       variance(col(column)).alias(\"variance\"),\n",
    "                                                       kurtosis(col(column)).alias(\"kurtosis\"),\n",
    "                                                       stddev(col(column)).alias(\"std\"),\n",
    "                                                       skewness(col(column)).alias(\"skewness\"),\n",
    "                                                       df_sum(col(column)).alias(\"sum\"),\n",
    "                                                       count(col(column) == 0.0).alias('n_zeros')\n",
    "                                                       ).toPandas()\n",
    "        else:\n",
    "            stats_df = df.select(column).na.drop().agg(mean(col(column)).alias(\"mean\"),\n",
    "                                                       df_min(col(column)).alias(\"min\"),\n",
    "                                                       df_max(col(column)).alias(\"max\"),\n",
    "                                                       df_sum(col(column)).alias(\"sum\"),\n",
    "                                                       count(col(column) == 0.0).alias('n_zeros')\n",
    "                                                       ).toPandas()\n",
    "            stats_df[\"variance\"] = df.select(column).na.drop().agg(variance_custom(col(column),\n",
    "                                                                                   stats_df[\"mean\"].iloc[0],\n",
    "                                                                                   current_result[\"count\"])).toPandas().iloc[0][0]\n",
    "            stats_df[\"std\"] = np.sqrt(stats_df[\"variance\"])\n",
    "            stats_df[\"skewness\"] = df.select(column).na.drop().agg(skewness_custom(col(column),\n",
    "                                                                                   stats_df[\"mean\"].iloc[0],\n",
    "                                                                                   current_result[\"count\"])).toPandas().iloc[0][0]\n",
    "            stats_df[\"kurtosis\"] = df.select(column).na.drop().agg(kurtosis_custom(col(column),\n",
    "                                                                                   stats_df[\"mean\"].iloc[0],\n",
    "                                                                                   current_result[\"count\"])).toPandas().iloc[0][0]\n",
    "\n",
    "        for x in [0.05, 0.25, 0.5, 0.75, 0.95]:\n",
    "            stats_df[pretty_name(x)] = (df.select(column)\n",
    "                                        .na.drop()\n",
    "                                        .selectExpr(\"percentile_approx(`{col}`,CAST({n} AS DOUBLE))\"\n",
    "                                                    .format(col=column, n=x)).toPandas().iloc[:,0]\n",
    "                                        )\n",
    "        stats = stats_df.iloc[0].copy()\n",
    "        stats.name = column\n",
    "        stats[\"range\"] = stats[\"max\"] - stats[\"min\"]\n",
    "        stats[\"iqr\"] = stats[pretty_name(0.75)] - stats[pretty_name(0.25)]\n",
    "        stats[\"cv\"] = stats[\"std\"] / float(stats[\"mean\"])\n",
    "        stats[\"mad\"] = (df.select(column)\n",
    "                        .na.drop()\n",
    "                        .select(df_abs(col(column)-stats[\"mean\"]).alias(\"delta\"))\n",
    "                        .agg(df_sum(col(\"delta\"))).toPandas().iloc[0,0] / float(current_result[\"count\"]))\n",
    "        stats[\"type\"] = \"NUM\"\n",
    "        stats['p_zeros'] = stats['n_zeros'] / float(nrows)\n",
    "\n",
    "        # Large histogram\n",
    "        imgdata = BytesIO()\n",
    "        hist_data = create_hist_data(df, column, stats[\"min\"], stats[\"max\"], bins)\n",
    "        figure = plt.figure(figsize=(6, 4))\n",
    "        plot = plt.subplot()\n",
    "        plt.bar(hist_data[\"left_edge\"],\n",
    "                hist_data[\"count\"],\n",
    "                width=hist_data[\"width\"],\n",
    "                facecolor='#337ab7')\n",
    "        plot.set_ylabel(\"Frequency\")\n",
    "        plot.figure.subplots_adjust(left=0.15, right=0.95, top=0.9, bottom=0.1, wspace=0, hspace=0)\n",
    "        plot.figure.savefig(imgdata)\n",
    "        imgdata.seek(0)\n",
    "        stats['histogram'] = 'data:image/png;base64,' + quote(base64.b64encode(imgdata.getvalue()))\n",
    "        #TODO Think about writing this to disk instead of caching them in strings\n",
    "        plt.close(plot.figure)\n",
    "\n",
    "        stats['mini_histogram'] = mini_histogram(hist_data)\n",
    "\n",
    "        return stats\n",
    "\n",
    "    def describe_date_1d(df, column):\n",
    "        stats_df = df.select(column).na.drop().agg(df_min(col(column)).alias(\"min\"),\n",
    "                                                   df_max(col(column)).alias(\"max\")\n",
    "                                                  ).toPandas()\n",
    "        stats = stats_df.iloc[0].copy()\n",
    "        stats.name = column\n",
    "\n",
    "        # Convert Pandas timestamp object to regular datetime:\n",
    "        if isinstance(stats[\"max\"], pd.Timestamp):\n",
    "            stats = stats.astype(object)\n",
    "            stats[\"max\"] = str(stats[\"max\"].to_pydatetime())\n",
    "            stats[\"min\"] = str(stats[\"min\"].to_pydatetime())\n",
    "        # Range only got when type is date\n",
    "        else:\n",
    "            stats[\"range\"] = stats[\"max\"] - stats[\"min\"]\n",
    "        stats[\"type\"] = \"DATE\"\n",
    "        return stats\n",
    "\n",
    "    def guess_json_type(string_value):\n",
    "        try:\n",
    "            obj = json.loads(string_value)\n",
    "        except:\n",
    "            return None\n",
    "\n",
    "        return type(obj)\n",
    "\n",
    "    def describe_categorical_1d(df, column):\n",
    "        count_column_name = \"count({c})\".format(c=column)\n",
    "\n",
    "        value_counts = (df.select(column).na.drop()\n",
    "                        .groupBy(column)\n",
    "                        .agg(count(col(column)))\n",
    "                        .orderBy(count_column_name, ascending=False)\n",
    "                       ).cache()\n",
    "\n",
    "        # Get the top 50 classes by value count,\n",
    "        # and put the rest of them grouped at the\n",
    "        # end of the Series:\n",
    "        top_50 = value_counts.limit(50).toPandas().sort_values(count_column_name,\n",
    "                                                               ascending=False)\n",
    "\n",
    "        stats = top_50.take([0]).rename(columns={column: 'top', count_column_name: 'freq'}).iloc[0]\n",
    "\n",
    "        others_count = 0\n",
    "        others_distinct_count = 0\n",
    "        unique_categories_count = value_counts.count()\n",
    "        if unique_categories_count > 50:\n",
    "            others_count = value_counts.select(df_sum(count_column_name)).toPandas().iloc[0, 0] - top_50[count_column_name].sum()\n",
    "            others_distinct_count = unique_categories_count - 50\n",
    "\n",
    "        value_counts.unpersist()\n",
    "        top = top_50.set_index(column)[count_column_name]\n",
    "        top[\"***Other Values***\"] = others_count\n",
    "        top[\"***Other Values Distinct Count***\"] = others_distinct_count\n",
    "        stats[\"value_counts\"] = top\n",
    "        stats[\"type\"] = \"CAT\"\n",
    "        unparsed_valid_jsons = df.select(column).na.drop().rdd.map(\n",
    "            lambda x: guess_json_type(x[column])).filter(\n",
    "            lambda x: x).distinct().collect()\n",
    "        stats[\"unparsed_json_types\"] = unparsed_valid_jsons\n",
    "        return stats\n",
    "\n",
    "    def describe_constant_1d(df, column):\n",
    "        stats = pd.Series(['CONST'], index=['type'], name=column)\n",
    "        stats[\"value_counts\"] = (df.select(column)\n",
    "                                 .na.drop()\n",
    "                                 .limit(1)).toPandas().iloc[:,0].value_counts()\n",
    "        return stats\n",
    "\n",
    "    def describe_unique_1d(df, column):\n",
    "        stats = pd.Series(['UNIQUE'], index=['type'], name=column)\n",
    "        stats[\"value_counts\"] = (df.select(column)\n",
    "                                 .na.drop()\n",
    "                                 .limit(50)).toPandas().iloc[:,0].value_counts()\n",
    "        return stats\n",
    "\n",
    "    def describe_1d(df, column, nrows, lookup_config=None):\n",
    "        column_type = df.select(column).dtypes[0][1]\n",
    "        # TODO: think about implementing analysis for complex\n",
    "        # data types:\n",
    "        if (\"array\" in column_type) or (\"stuct\" in column_type) or (\"map\" in column_type):\n",
    "            raise NotImplementedError(\"Column {c} is of type {t} and cannot be analyzed\".format(c=column, t=column_type))\n",
    "\n",
    "        results_data = df.select(countDistinct(col(column)).alias(\"distinct_count\"),\n",
    "                                 count(col(column).isNotNull()).alias('count')).toPandas()\n",
    "        results_data[\"p_unique\"] = results_data[\"distinct_count\"] / float(results_data[\"count\"])\n",
    "        results_data[\"is_unique\"] = results_data[\"distinct_count\"] == nrows\n",
    "        results_data[\"n_missing\"] = nrows - results_data[\"count\"]\n",
    "        results_data[\"p_missing\"] = results_data[\"n_missing\"] / float(nrows)\n",
    "        results_data[\"p_infinite\"] = 0\n",
    "        results_data[\"n_infinite\"] = 0\n",
    "        result = results_data.iloc[0].copy()\n",
    "        result[\"memorysize\"] = 0\n",
    "        result.name = column\n",
    "\n",
    "        if result[\"distinct_count\"] <= 1:\n",
    "            result = result.append(describe_constant_1d(df, column))\n",
    "        elif column_type in {\"tinyint\", \"smallint\", \"int\", \"bigint\"}:\n",
    "            result = result.append(describe_integer_1d(df, column, result, nrows))\n",
    "        elif column_type in {\"float\", \"double\", \"decimal\"}:\n",
    "            result = result.append(describe_float_1d(df, column, result, nrows))\n",
    "        elif column_type in {\"date\", \"timestamp\"}:\n",
    "            result = result.append(describe_date_1d(df, column))\n",
    "        elif result[\"is_unique\"] == True:\n",
    "            result = result.append(describe_unique_1d(df, column))\n",
    "        else:\n",
    "            result = result.append(describe_categorical_1d(df, column))\n",
    "            # Fix to also count MISSING value in the distict_count field:\n",
    "            if result[\"n_missing\"] > 0:\n",
    "                result[\"distinct_count\"] = result[\"distinct_count\"] + 1\n",
    "\n",
    "        # TODO: check whether it is worth it to\n",
    "        # implement the \"real\" mode:\n",
    "        if (result[\"count\"] > result[\"distinct_count\"] > 1):\n",
    "            try:\n",
    "                result[\"mode\"] = result[\"top\"]\n",
    "            except KeyError:\n",
    "                result[\"mode\"] = 0\n",
    "        else:\n",
    "            try:\n",
    "                result[\"mode\"] = result[\"value_counts\"].index[0]\n",
    "            except KeyError:\n",
    "                result[\"mode\"] = 0\n",
    "            # If and IndexError happens,\n",
    "            # it is because all column are NULLs:\n",
    "            except IndexError:\n",
    "                result[\"mode\"] = \"MISSING\"\n",
    "\n",
    "        if lookup_config:\n",
    "            lookup_object = lookup_config['object']\n",
    "            col_name_in_db = lookup_config['col_name_in_db'] if 'col_name_in_db' in lookup_config else None\n",
    "            try:\n",
    "                matched, unmatched = lookup_object.lookup(df.select(column), col_name_in_db)\n",
    "                result['lookedup_values'] = str(matched.count()) + \"/\" + str(df.select(column).count())\n",
    "            except:\n",
    "                result['lookedup_values'] = 'FAILED'\n",
    "        else:\n",
    "            result['lookedup_values'] = ''\n",
    "\n",
    "        return result\n",
    "\n",
    "\n",
    "    # Do the thing:\n",
    "    ldesc = {}\n",
    "    for colum in df.columns:\n",
    "        if colum in config:\n",
    "            if 'lookup' in config[colum]:\n",
    "                lookup_config = config[colum]['lookup']\n",
    "                desc = describe_1d(df, colum, table_stats[\"n\"], lookup_config=lookup_config)\n",
    "            else:\n",
    "                desc = describe_1d(df, colum, table_stats[\"n\"])\n",
    "        else:\n",
    "            desc = describe_1d(df, colum, table_stats[\"n\"])\n",
    "        ldesc.update({colum: desc})\n",
    "\n",
    "    # Compute correlation matrix\n",
    "    if corr_reject is not None:\n",
    "        computable_corrs = [colum for colum in ldesc if ldesc[colum][\"type\"] in {\"NUM\"}]\n",
    "\n",
    "        if len(computable_corrs) > 0:\n",
    "            corr = corr_matrix(df, columns=computable_corrs)\n",
    "            for x, corr_x in corr.iterrows():\n",
    "                for y, corr in corr_x.iteritems():\n",
    "                    if x == y:\n",
    "                        break\n",
    "\n",
    "                    if corr >= corr_reject:\n",
    "                        ldesc[x] = pd.Series(['CORR', y, corr], index=['type', 'correlation_var', 'correlation'], name=x)\n",
    "\n",
    "    # Convert ldesc to a DataFrame\n",
    "    variable_stats = pd.DataFrame(ldesc)\n",
    "\n",
    "    # General statistics\n",
    "    table_stats[\"nvar\"] = len(df.columns)\n",
    "    table_stats[\"total_missing\"] = float(variable_stats.loc[\"n_missing\"].sum()) / (table_stats[\"n\"] * table_stats[\"nvar\"])\n",
    "    memsize = 0\n",
    "    table_stats['memsize'] = formatters.fmt_bytesize(memsize)\n",
    "    table_stats['recordsize'] = formatters.fmt_bytesize(memsize / table_stats['n'])\n",
    "    table_stats.update({k: 0 for k in (\"NUM\", \"DATE\", \"CONST\", \"CAT\", \"UNIQUE\", \"CORR\")})\n",
    "    table_stats.update(dict(variable_stats.loc['type'].value_counts()))\n",
    "    table_stats['REJECTED'] = table_stats['CONST'] + table_stats['CORR']\n",
    "\n",
    "    freq_dict = {}\n",
    "    for var in variable_stats:\n",
    "        if \"value_counts\" not in variable_stats[var]:\n",
    "            pass\n",
    "        elif not(variable_stats[var][\"value_counts\"] is np.nan):\n",
    "            freq_dict[var] = variable_stats[var][\"value_counts\"]\n",
    "        else:\n",
    "            pass\n",
    "    try:\n",
    "        variable_stats = variable_stats.drop(\"value_counts\")\n",
    "    except (ValueError, KeyError):\n",
    "        pass\n",
    "\n",
    "    return {'table': table_stats, 'variables': variable_stats.T, 'freq': freq_dict}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Running the describe function on the Iris dataset (Returns all of the raw profiling statistics) + Recording Time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Elapsed time: 11.5530 seconds\n"
     ]
    }
   ],
   "source": [
    "spark.conf.set('spark.rapids.sql.enabled', 'true')\n",
    "spark.conf.set('spark.rapids.sql.improvedFloatOps.enabled', 'true')\n",
    "iris_timer = Timer()\n",
    "iris_timer.start()\n",
    "iris_dataset_profile = describe(df=iris_sparkdf, bins=10, corr_reject=None, config={})\n",
    "iris_timer.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'table': {'n': 150,\n",
       "  'nvar': 4,\n",
       "  'total_missing': 0.0,\n",
       "  'memsize': '0.0 B',\n",
       "  'recordsize': '0.0 B',\n",
       "  'NUM': 4,\n",
       "  'DATE': 0,\n",
       "  'CONST': 0,\n",
       "  'CAT': 0,\n",
       "  'UNIQUE': 0,\n",
       "  'CORR': 0,\n",
       "  'REJECTED': 0},\n",
       " 'variables':              distinct_count count  p_unique is_unique n_missing p_missing  \\\n",
       " sepal_length             35   150  0.233333     False         0         0   \n",
       " sepal_width              23   150  0.153333     False         0         0   \n",
       " petal_length             43   150  0.286667     False         0         0   \n",
       " petal_width              22   150  0.146667     False         0         0   \n",
       " \n",
       "              p_infinite n_infinite memorysize     mean  ... range  iqr  \\\n",
       " sepal_length          0          0          0  5.84333  ...   3.6  1.3   \n",
       " sepal_width           0          0          0    3.054  ...   2.4  0.5   \n",
       " petal_length          0          0          0  3.75867  ...   5.9  3.5   \n",
       " petal_width           0          0          0  1.19867  ...   2.4  1.5   \n",
       " \n",
       "                     cv       mad type p_zeros  \\\n",
       " sepal_length  0.141711  0.687556  NUM       1   \n",
       " sepal_width   0.141976  0.333093  NUM       1   \n",
       " petal_length  0.469427   1.56192  NUM       1   \n",
       " petal_width   0.636675  0.658933  NUM       1   \n",
       " \n",
       "                                                       histogram  \\\n",
       " sepal_length  data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA...   \n",
       " sepal_width   data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA...   \n",
       " petal_length  data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA...   \n",
       " petal_width   data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA...   \n",
       " \n",
       "                                                  mini_histogram mode  \\\n",
       " sepal_length  data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA...    0   \n",
       " sepal_width   data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA...    0   \n",
       " petal_length  data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA...    0   \n",
       " petal_width   data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA...    0   \n",
       " \n",
       "              lookedup_values  \n",
       " sepal_length                  \n",
       " sepal_width                   \n",
       " petal_length                  \n",
       " petal_width                   \n",
       " \n",
       " [4 rows x 33 columns],\n",
       " 'freq': {}}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "iris_dataset_profile"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get all acquisition data between two given years stored into a pandas DataFrame\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Load data for each quarter --> Will write function later on to do this more efficiently\n",
    "\n",
    "\n",
    "#This generates a full acquisition dataframe for a desired start year and end year (with all quarters Q1-Q4 considered)\n",
    "\n",
    "def get_acq_df(start_year, end_year):\n",
    "    data = []\n",
    "    num_quarters = 4\n",
    "    for year in range(start_year, end_year+1):\n",
    "        for i in range(1, num_quarters+1):\n",
    "            filepath = 'acq/Acquisition_' + str(year) + 'Q' + str(i) + '.csv'\n",
    "            acq_quarter_df = pd.read_csv(filepath, sep='|')\n",
    "            acq_quarter_df.columns = ['loan_id', 'orig_channel', 'seller_name', 'orig_interest_rate', 'orig_upb', 'orig_loan_term', 'orig_date', 'first_payment_date', 'orig_loan_to_value', 'orig_combined_loan_to_value', 'num_borrowers', 'original_debt_income_ratio', 'borrower_credit_score_origination', 'first_time_home_buyer', 'loan_purpose', 'property_type', 'num_units', 'occupancy_type', 'property_state', 'zip_code_short', 'primary_mortgage_insurance_percent', 'product_type', 'co_borrower_credit_score_origination', 'mortgage_insurance_type', 'relocation_mortgage_indicator', 'const']\n",
    "            data.append(acq_quarter_df)\n",
    "\n",
    "    return pd.concat(data).reset_index(drop=True)\n",
    "\n",
    "acq_test_data = get_acq_df(2000, 2000)         "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Take Numerical Data (for Now)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "acq_numeric_test_data = acq_test_data[['loan_id', 'orig_interest_rate', 'orig_upb', 'orig_loan_term']]\n",
    "acq_numeric_test_data = acq_numeric_test_data.fillna(0)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get all performance data between two given years stored into a Pandas DataFrame"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_perf_df(start_year, end_year):\n",
    "    data = []\n",
    "    num_quarters = 4\n",
    "    for year in range(start_year, end_year+1):\n",
    "        for i in range(1, num_quarters+1):\n",
    "            filepath = 'perf/Performance_' + str(year) + 'Q' + str(i) + '.csv'\n",
    "            perf_quarter_df = pd.read_csv(filepath, sep='|')\n",
    "            perf_quarter_df.columns = [\n",
    "        \"loan_id\", \"monthly_reporting_period\", \"servicer\", \"interest_rate\", \"current_actual_upb\",\n",
    "        \"loan_age\", \"remaining_months_to_legal_maturity\", \"adj_remaining_months_to_maturity\",\n",
    "        \"maturity_date\", \"msa\", \"current_loan_delinquency_status\", \"mod_flag\", \"zero_balance_code\",\n",
    "        \"zero_balance_effective_date\", \"last_paid_installment_date\", \"foreclosed_after\",\n",
    "        \"disposition_date\", \"foreclosure_costs\", \"prop_preservation_and_repair_costs\",\n",
    "        \"asset_recovery_costs\", \"misc_holding_expenses\", \"holding_taxes\", \"net_sale_proceeds\",\n",
    "        \"credit_enhancement_proceeds\", \"repurchase_make_whole_proceeds\", \"other_foreclosure_proceeds\",\n",
    "        \"non_interest_bearing_upb\", \"principal_forgiveness_upb\", \"repurchase_make_whole_proceeds_flag\",\n",
    "        \"foreclosure_principal_write_off_amount\", \"servicing_activity_indicator\"]\n",
    "            data.append(perf_quarter_df)\n",
    "\n",
    "    return pd.concat(data).reset_index(drop=True)\n",
    "\n",
    "perf_test_data = get_perf_df(2000, 2000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "perf_test_data.head(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert to Spark DataFrame - Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "perf_test_data_to_spark = spark.createDataFrame(perf_test_data)\n",
    "perf_test_data_to_spark"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Convert to Spark DataFrame - Acquisition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "acq_test_data_to_spark = spark.createDataFrame(acq_numeric_test_data)\n",
    "acq_test_data_to_spark = acq_test_data_to_spark.cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Benchmarking Time for Spark Profiling - Acquisition\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Elapsed time: 41.7574 seconds\n"
     ]
    }
   ],
   "source": [
    "spark.conf.set('spark.rapids.sql.enabled', 'false')\n",
    "spark.conf.set('spark.rapids.sql.improvedFloatOps.enabled', 'false')\n",
    "acq_2000_timer = Timer()\n",
    "acq_2000_timer.start()\n",
    "describe(df=acq_test_data_to_spark, bins=10, corr_reject=None, config={})\n",
    "acq_2000_timer.stop()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Notes "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Acquisition Data (from 1 GB Mortgage E2E) - GPU Time (31.9239 seconds) with numerical data \n",
    "#### Acquisition Data (from 1 GB Mortgage E2E) - CPU Time (41.7574 seconds) with numerical data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}

Expected behavior
It should run without a problem on the GPU, just like it does on the CPU.

razajafri commented 4 years ago

@jlowe hypothesized, from a brief look at the trace, that the plugin might be passing an AttributeReference to GpuSortExec instead of a GpuAttributeReference. I am looking into this further.

razajafri commented 4 years ago

Upon further inspection I confirmed that this is indeed the case: Spark tries to recreate the GpuSortOrder from its outputs and passes an Attribute in place of a GpuAttribute, which causes the crash.
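
For comparison, the plain CPU SortOrder survives exactly this kind of attribute substitution because its child parameter is typed as Expression, so the makeCopy triggered by transform succeeds. A sketch of that contrast, assuming the Spark 3.0.x Catalyst API (where SortOrder.sameOrderExpressions is a Set); this is not plugin code:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.LongType

    object SortOrderCopyDemo {
      def main(args: Array[String]): Unit = {
        // Two attributes with the same name but distinct exprIds, mimicking the
        // re-resolution that InMemoryTableScanExec performs on its outputs.
        val oldAttr = AttributeReference("count(species)", LongType)()
        val newAttr = AttributeReference("count(species)", LongType)()
        val order = SortOrder(oldAttr, Descending, NullsLast, Set.empty)
        // transform substitutes the child and internally calls makeCopy; this
        // works for SortOrder because its child is declared as Expression.
        val rewritten = order.transform { case `oldAttr` => newAttr }
        println(rewritten)
      }
    }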

It was decided that we will have to get rid of GpuAttribute, and any objects that currently take a GpuExpression as a parameter will have to be changed to accept Expression instead. I have also added a unit test that covers this case as part of the PR.
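
A hedged sketch of that direction, assuming the five-argument constructor shown in the error message above; the field names, in particular the last one, are illustrative rather than the plugin's actual names:

    import org.apache.spark.sql.catalyst.expressions.{Expression, NullOrdering, SortDirection}

    // Widening `child` from GpuExpression to Catalyst's Expression lets
    // TreeNode.makeCopy substitute a plain AttributeReference without hitting
    // the reflective "argument type mismatch".
    case class GpuSortOrderSketch(
        child: Expression,                     // was: GpuExpression
        direction: SortDirection,
        nullOrdering: NullOrdering,
        sameOrderExpressions: Set[Expression],
        originalCpuChild: Expression)          // hypothetical name for the fifth ctor arg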