From 18551b870721204ead5cc528548518e8ce5e38d5 Mon Sep 17 00:00:00 2001 From: Charles Leifer Date: Wed, 4 Aug 2021 14:10:22 -0500 Subject: [PATCH] Allow authorizer callback to be cleared --- src/connection.c | 16 +++++++++++----- test/userfunctions.py | 5 +++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/connection.c b/src/connection.c index 7291a4b..7b1bd12 100644 --- a/src/connection.c +++ b/src/connection.c @@ -1225,7 +1225,6 @@ static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, P PyObject* authorizer_cb; static char *kwlist[] = { "authorizer_callback", NULL }; - int rc; if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) { return NULL; @@ -1236,14 +1235,21 @@ static PyObject* pysqlite_connection_set_authorizer(pysqlite_Connection* self, P return NULL; } - rc = sqlite3_set_authorizer(self->db, _authorizer_callback, (void*)authorizer_cb); + int rc; + if (authorizer_cb == Py_None) { + rc = sqlite3_set_authorizer(self->db, NULL, NULL); + Py_XSETREF(self->function_pinboard_authorizer_cb, NULL); + } + else { + Py_INCREF(authorizer_cb); + Py_XSETREF(self->function_pinboard_authorizer_cb, authorizer_cb); + rc = sqlite3_set_authorizer(self->db, _authorizer_callback, (void*)authorizer_cb); + } + if (rc != SQLITE_OK) { PyErr_SetString(pysqlite_OperationalError, "Error setting authorizer callback"); Py_XSETREF(self->function_pinboard_authorizer_cb, NULL); return NULL; - } else { - Py_INCREF(authorizer_cb); - Py_XSETREF(self->function_pinboard_authorizer_cb, authorizer_cb); } Py_RETURN_NONE; } diff --git a/test/userfunctions.py b/test/userfunctions.py index ffdc619..c5f53c0 100644 --- a/test/userfunctions.py +++ b/test/userfunctions.py @@ -482,6 +482,11 @@ class AuthorizerTests(unittest.TestCase): self.con.execute("select c2 from t1") self.assertIn('prohibited', str(cm.exception)) + def test_clear_authorizer(self): + self.con.set_authorizer(None) + self.con.execute('select * from t2') + self.con.execute('select c2 from t1') + class AuthorizerRaiseExceptionTests(AuthorizerTests): @staticmethod def authorizer_cb(action, arg1, arg2, dbname, source):