zillow / ctds

Python DB-API 2.0 library for MS SQL Server
MIT License
83 stars 12 forks source link

Fix bulk_copy exceptions when attempting to copy no rows #37

Closed kadrach closed 5 years ago

kadrach commented 5 years ago

See #36, seemed a simple enough fix :)

The ctds source explicitly says `/* Always call bcp_done() regardless of previous errors. */`, so I may be completely off track here.

joshuahlang commented 5 years ago

Nice work finding that ancient FreeTDS thread. At first glance this seems fine, but valgrind appears to report memory leaks, which I assume come from no longer freeing memory via a final `bcp_done` call.

How about adding this change in:

diff --git a/src/ctds/connection.c b/src/ctds/connection.c
index 91a16e7..74d63cf 100644
--- a/src/ctds/connection.c
+++ b/src/ctds/connection.c
@@ -1591,6 +1591,7 @@ static PyObject* Connection_bulk_insert(PyObject* self, PyObject* args, PyObject

             RETCODE retcode;
             size_t sent = 0;
+            size_t rows = 0;

             DBINT processed;

@@ -1635,6 +1636,8 @@ static PyObject* Connection_bulk_insert(PyObject* self, PyObject* args, PyObject
                 char msg[ARRAYSIZE(INVALID_SEQUENCE_FMT) + ARRAYSIZE(STRINGIFY(UINT64_MAX))];
                 (void)sprintf(msg, INVALID_SEQUENCE_FMT, sent);

+                rows++;
+
                 sequence = PySequence_Fast(row, msg);
                 if (sequence)
                 {
@@ -1656,7 +1659,11 @@ static PyObject* Connection_bulk_insert(PyObject* self, PyObject* args, PyObject
                 sent++;
             }

-            /* Always call bcp_done() regardless of previous errors. */
+            /*
+                Always call bcp_done() regardless of previous errors.
+                This is required to free memory allocated by FreeTDS in
+                `bcp_init`.
+            */
             Py_BEGIN_ALLOW_THREADS

                 processed = bcp_done(connection->dbproc);
@@ -1667,7 +1674,11 @@ static PyObject* Connection_bulk_insert(PyObject* self, PyObject* args, PyObject
             {
                 saved += processed;
             }
-            else
+            /*
+                `bcp_done` will return -1 if called without sending rows.
+                Ignore the error in this case.
+            */
+            else if (0 != rows)
             {
                 /* Don't overwrite a previous error if bcp_done fails. */
                 if (!PyErr_Occurred())
kadrach commented 5 years ago

I was just looking at `_bcp_free_storage`, which `bcp_done` uses internally. Let me take a look at your solution :)

kadrach commented 5 years ago

You are suppressing all errors from `bcp_done`, regardless of cause, when zero rows were sent. I don't understand TDS well enough to know whether that's safe (can there be other internal failures that should be raised?).

If you think that's not a concern, I'm happy to update the PR. Valgrind is happy with your patch.

joshuahlang commented 5 years ago

That's a good observation. It probably won't matter in 99.99% of use cases, but a better solution may be to call `bcp_init` only if there are rows to send:

diff --git a/src/ctds/connection.c b/src/ctds/connection.c
index 91a16e7..221a1f5 100644
--- a/src/ctds/connection.c
+++ b/src/ctds/connection.c
@@ -1592,47 +1592,56 @@ static PyObject* Connection_bulk_insert(PyObject* self, PyObject* args, PyObject
             RETCODE retcode;
             size_t sent = 0;

-            DBINT processed;
+            DBINT processed = 0;

-            Py_BEGIN_ALLOW_THREADS
+            bool initialized = false;

-                do
+            while (NULL != (row = PyIter_Next(irows)))
+            {
+#define INVALID_SEQUENCE_FMT "invalid sequence for row %zd"
+                PyObject* sequence;
+
+                char msg[ARRAYSIZE(INVALID_SEQUENCE_FMT) + ARRAYSIZE(STRINGIFY(UINT64_MAX))];
+
+                /* Initialize only if there are rows to send. */
+                if (!initialized)
                 {
-                    retcode = bcp_init(connection->dbproc, table, NULL, NULL, DB_IN);
-                    if (FAIL == retcode)
-                    {
-                        break;
-                    }
+                    Py_BEGIN_ALLOW_THREADS

-                    if (Py_True == tablock)
-                    {
-                        static const char s_TABLOCK[] = "TABLOCK";
-                        retcode = bcp_options(connection->dbproc,
-                                              BCPHINTS,
-                                              (BYTE*)s_TABLOCK,
-                                              ARRAYSIZE(s_TABLOCK));
-                        if (FAIL == retcode)
+                        do
                         {
-                            break;
+                            retcode = bcp_init(connection->dbproc, table, NULL, NULL, DB_IN);
+                            if (FAIL == retcode)
+                            {
+                                break;
+                            }
+
+                            if (Py_True == tablock)
+                            {
+                                static const char s_TABLOCK[] = "TABLOCK";
+                                retcode = bcp_options(connection->dbproc,
+                                                      BCPHINTS,
+                                                      (BYTE*)s_TABLOCK,
+                                                      ARRAYSIZE(s_TABLOCK));
+                                if (FAIL == retcode)
+                                {
+                                    break;
+                                }
+                            }
                         }
-                    }
-                }
-                while (0);
+                        while (0);

-            Py_END_ALLOW_THREADS
+                    Py_END_ALLOW_THREADS

-            if (FAIL == retcode)
-            {
-                Connection_raise_lasterror(connection);
-                break;
-            }
+                    if (FAIL == retcode)
+                    {
+                        Connection_raise_lasterror(connection);
+                        break;
+                    }

-            while (NULL != (row = PyIter_Next(irows)))
-            {
-#define INVALID_SEQUENCE_FMT "invalid sequence for row %zd"
-                PyObject* sequence;
+                    initialized = true;
+                }

-                char msg[ARRAYSIZE(INVALID_SEQUENCE_FMT) + ARRAYSIZE(STRINGIFY(UINT64_MAX))];
                 (void)sprintf(msg, INVALID_SEQUENCE_FMT, sent);

                 sequence = PySequence_Fast(row, msg);
@@ -1656,12 +1665,15 @@ static PyObject* Connection_bulk_insert(PyObject* self, PyObject* args, PyObject
                 sent++;
             }

-            /* Always call bcp_done() regardless of previous errors. */
-            Py_BEGIN_ALLOW_THREADS
+            if (initialized)
+            {
+                /* Always call bcp_done() regardless of previous errors. */
+                Py_BEGIN_ALLOW_THREADS

-                processed = bcp_done(connection->dbproc);
+                    processed = bcp_done(connection->dbproc);

-            Py_END_ALLOW_THREADS
+                Py_END_ALLOW_THREADS
+            }

             if (-1 != processed)
             {