Merging upstream changes to allow compiling python3.13 wheels

This commit is contained in:
laggykiller 2025-02-09 22:52:57 +08:00
parent 70f0118baf
commit 8ea48ee2ad
No known key found for this signature in database
19 changed files with 639 additions and 478 deletions

View file

@ -1,6 +1,6 @@
[project] [project]
name = "sqlcipher3-wheels" name = "sqlcipher3-wheels"
version = "0.5.2.post1" version = "0.5.4"
description = "DB-API 2.0 interface for SQLCipher 3.x" description = "DB-API 2.0 interface for SQLCipher 3.x"
readme = { content-type = "text/markdown", file = "README.md" } readme = { content-type = "text/markdown", file = "README.md" }
authors = [{ name = "Charles Leifer", email = "coleifer@gmail.com" }] authors = [{ name = "Charles Leifer", email = "coleifer@gmail.com" }]

View file

@ -226,7 +226,7 @@ if __name__ == "__main__":
# With pyproject.toml, all are not necessary except ext_modules # With pyproject.toml, all are not necessary except ext_modules
# However, they are kept for building python 3.6 wheels # However, they are kept for building python 3.6 wheels
name="sqlcipher3-wheels", name="sqlcipher3-wheels",
version="0.5.2.post1", version="0.5.4",
package_dir={"sqlcipher3": "sqlcipher3"}, package_dir={"sqlcipher3": "sqlcipher3"},
packages=["sqlcipher3"], packages=["sqlcipher3"],
ext_modules=[module], ext_modules=[module],

View file

@ -24,16 +24,18 @@ int pysqlite_blob_init(pysqlite_Blob *self, pysqlite_Connection* connection,
static void remove_blob_from_connection_blob_list(pysqlite_Blob *self) static void remove_blob_from_connection_blob_list(pysqlite_Blob *self)
{ {
Py_ssize_t i; Py_ssize_t i;
PyObject *item; PyObject *item, *ref;
for (i = 0; i < PyList_GET_SIZE(self->connection->blobs); i++) { for (i = 0; i < PyList_GET_SIZE(self->connection->blobs); i++) {
item = PyList_GET_ITEM(self->connection->blobs, i); item = PyList_GET_ITEM(self->connection->blobs, i);
if (PyWeakref_GetObject(item) == (PyObject *)self) { if (PyWeakref_GetRef(item, &ref) == 1) {
if (ref == (PyObject *)self) {
PyList_SetSlice(self->connection->blobs, i, i+1, NULL); PyList_SetSlice(self->connection->blobs, i, i+1, NULL);
break; break;
} }
} }
} }
}
static void _close_blob_inner(pysqlite_Blob* self) static void _close_blob_inner(pysqlite_Blob* self)
{ {

View file

@ -56,6 +56,10 @@
#define HAVE_ENCRYPTION #define HAVE_ENCRYPTION
#endif #endif
#if PY_VERSION_HEX < 0x030D0000
#define PyLong_AsInt _PyLong_AsInt
#endif
_Py_IDENTIFIER(cursor); _Py_IDENTIFIER(cursor);
static const char * const begin_statements[] = { static const char * const begin_statements[] = {
@ -199,6 +203,7 @@ int pysqlite_connection_init(pysqlite_Connection* self, PyObject* args, PyObject
self->function_pinboard_trace_callback = NULL; self->function_pinboard_trace_callback = NULL;
self->function_pinboard_progress_handler = NULL; self->function_pinboard_progress_handler = NULL;
self->function_pinboard_authorizer_cb = NULL; self->function_pinboard_authorizer_cb = NULL;
self->function_pinboard_busy_handler_cb = NULL;
Py_XSETREF(self->collations, PyDict_New()); Py_XSETREF(self->collations, PyDict_New());
if (!self->collations) { if (!self->collations) {
@ -229,9 +234,7 @@ void pysqlite_do_all_statements(pysqlite_Connection* self, int action, int reset
for (i = 0; i < PyList_Size(self->statements); i++) { for (i = 0; i < PyList_Size(self->statements); i++) {
weakref = PyList_GetItem(self->statements, i); weakref = PyList_GetItem(self->statements, i);
statement = PyWeakref_GetObject(weakref); if (PyWeakref_GetRef(weakref, &statement) == 1) {
if (statement != Py_None) {
Py_INCREF(statement);
if (action == ACTION_RESET) { if (action == ACTION_RESET) {
(void)pysqlite_statement_reset((pysqlite_Statement*)statement); (void)pysqlite_statement_reset((pysqlite_Statement*)statement);
} else { } else {
@ -244,9 +247,9 @@ void pysqlite_do_all_statements(pysqlite_Connection* self, int action, int reset
if (reset_cursors) { if (reset_cursors) {
for (i = 0; i < PyList_Size(self->cursors); i++) { for (i = 0; i < PyList_Size(self->cursors); i++) {
weakref = PyList_GetItem(self->cursors, i); weakref = PyList_GetItem(self->cursors, i);
cursor = (pysqlite_Cursor*)PyWeakref_GetObject(weakref); if (PyWeakref_GetRef(weakref, (PyObject**)&cursor) == 1) {
if ((PyObject*)cursor != Py_None) {
cursor->reset = 1; cursor->reset = 1;
Py_DECREF(cursor);
} }
} }
} }
@ -265,6 +268,7 @@ void pysqlite_connection_dealloc(pysqlite_Connection* self)
Py_XDECREF(self->function_pinboard_trace_callback); Py_XDECREF(self->function_pinboard_trace_callback);
Py_XDECREF(self->function_pinboard_progress_handler); Py_XDECREF(self->function_pinboard_progress_handler);
Py_XDECREF(self->function_pinboard_authorizer_cb); Py_XDECREF(self->function_pinboard_authorizer_cb);
Py_XDECREF(self->function_pinboard_busy_handler_cb);
Py_XDECREF(self->row_factory); Py_XDECREF(self->row_factory);
Py_XDECREF(self->text_factory); Py_XDECREF(self->text_factory);
Py_XDECREF(self->collations); Py_XDECREF(self->collations);
@ -412,9 +416,9 @@ static void pysqlite_close_all_blobs(pysqlite_Connection *self)
for (i = 0; i < PyList_GET_SIZE(self->blobs); i++) { for (i = 0; i < PyList_GET_SIZE(self->blobs); i++) {
weakref = PyList_GET_ITEM(self->blobs, i); weakref = PyList_GET_ITEM(self->blobs, i);
blob = PyWeakref_GetObject(weakref); if (PyWeakref_GetRef(weakref, &blob) == 1) {
if (blob != Py_None) {
pysqlite_blob_close((pysqlite_Blob*)blob); pysqlite_blob_close((pysqlite_Blob*)blob);
Py_DECREF(blob);
} }
} }
} }
@ -936,6 +940,7 @@ static void _pysqlite_drop_unused_statement_references(pysqlite_Connection* self
{ {
PyObject* new_list; PyObject* new_list;
PyObject* weakref; PyObject* weakref;
PyObject* ref;
int i; int i;
/* we only need to do this once in a while */ /* we only need to do this once in a while */
@ -952,7 +957,8 @@ static void _pysqlite_drop_unused_statement_references(pysqlite_Connection* self
for (i = 0; i < PyList_Size(self->statements); i++) { for (i = 0; i < PyList_Size(self->statements); i++) {
weakref = PyList_GetItem(self->statements, i); weakref = PyList_GetItem(self->statements, i);
if (PyWeakref_GetObject(weakref) != Py_None) { if (PyWeakref_GetRef(weakref, &ref) == 1) {
Py_DECREF(ref);
if (PyList_Append(new_list, weakref) != 0) { if (PyList_Append(new_list, weakref) != 0) {
Py_DECREF(new_list); Py_DECREF(new_list);
return; return;
@ -967,6 +973,7 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self)
{ {
PyObject* new_list; PyObject* new_list;
PyObject* weakref; PyObject* weakref;
PyObject* ref;
int i; int i;
/* we only need to do this once in a while */ /* we only need to do this once in a while */
@ -983,7 +990,8 @@ static void _pysqlite_drop_unused_cursor_references(pysqlite_Connection* self)
for (i = 0; i < PyList_Size(self->cursors); i++) { for (i = 0; i < PyList_Size(self->cursors); i++) {
weakref = PyList_GetItem(self->cursors, i); weakref = PyList_GetItem(self->cursors, i);
if (PyWeakref_GetObject(weakref) != Py_None) { if (PyWeakref_GetRef(weakref, &ref) == 1) {
Py_DECREF(ref);
if (PyList_Append(new_list, weakref) != 0) { if (PyList_Append(new_list, weakref) != 0) {
Py_DECREF(new_list); Py_DECREF(new_list);
return; return;
@ -1154,7 +1162,7 @@ static int _authorizer_callback(void* user_arg, int action, const char* arg1, co
} }
else { else {
if (PyLong_Check(ret)) { if (PyLong_Check(ret)) {
rc = _PyLong_AsInt(ret); rc = PyLong_AsInt(ret);
if (rc == -1 && PyErr_Occurred()) { if (rc == -1 && PyErr_Occurred()) {
if (_pysqlite_enable_callback_tracebacks) if (_pysqlite_enable_callback_tracebacks)
PyErr_Print(); PyErr_Print();
@ -1200,6 +1208,36 @@ static int _progress_handler(void* user_arg)
return rc; return rc;
} }
static int _busy_handler(void* user_arg, int n)
{
int rc;
PyObject *ret;
PyGILState_STATE gilstate;
gilstate = PyGILState_Ensure();
ret = PyObject_CallFunction((PyObject*)user_arg, "i", n);
if (ret == NULL) {
if (_pysqlite_enable_callback_tracebacks)
PyErr_Print();
else
PyErr_Clear();
rc = 0;
}
else {
if (PyLong_Check(ret))
rc = PyLong_AsInt(ret);
else
rc = 0;
Py_DECREF(ret);
}
PyGILState_Release(gilstate);
return rc;
}
#ifdef HAVE_TRACE_V2 #ifdef HAVE_TRACE_V2
static int _trace_callback(unsigned int type, void *ctx, void *stmt, void *sql) static int _trace_callback(unsigned int type, void *ctx, void *stmt, void *sql)
{ {
@ -1336,6 +1374,68 @@ static PyObject* pysqlite_connection_set_progress_handler(pysqlite_Connection* s
Py_RETURN_NONE; Py_RETURN_NONE;
} }
static PyObject* pysqlite_connection_set_busy_handler(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
PyObject* busy_handler;
static char *kwlist[] = { "busy_handler", NULL };
if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
return NULL;
}
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O:set_busy_handler",
kwlist, &busy_handler)) {
return NULL;
}
int rc;
if (busy_handler == Py_None) {
rc = sqlite3_busy_handler(self->db, NULL, NULL);
Py_XSETREF(self->function_pinboard_busy_handler_cb, NULL);
}
else {
Py_INCREF(busy_handler);
Py_XSETREF(self->function_pinboard_busy_handler_cb, busy_handler);
rc = sqlite3_busy_handler(self->db, _busy_handler, (void*)busy_handler);
}
if (rc != SQLITE_OK) {
PyErr_SetString(pysqlite_OperationalError, "Error setting busy handler");
Py_XSETREF(self->function_pinboard_busy_handler_cb, NULL);
return NULL;
}
Py_RETURN_NONE;
}
static PyObject* pysqlite_connection_set_busy_timeout(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{
double busy_timeout;
static char *kwlist[] = { "timeout", NULL };
if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
return NULL;
}
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "d:set_busy_timeout",
kwlist, &busy_timeout)) {
return NULL;
}
int rc;
rc = sqlite3_busy_timeout(self->db, (int)busy_timeout * 1000);
if (rc != SQLITE_OK) {
PyErr_SetString(pysqlite_OperationalError, "Error setting busy timeout");
return NULL;
}
else {
Py_XDECREF(self->function_pinboard_busy_handler_cb);
}
Py_RETURN_NONE;
}
static PyObject* pysqlite_connection_set_trace_callback(pysqlite_Connection* self, PyObject* args, PyObject* kwargs) static PyObject* pysqlite_connection_set_trace_callback(pysqlite_Connection* self, PyObject* args, PyObject* kwargs)
{ {
PyObject* trace_callback; PyObject* trace_callback;
@ -1478,8 +1578,6 @@ pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* iso
self->begin_statement = NULL; self->begin_statement = NULL;
} else { } else {
const char * const *candidate; const char * const *candidate;
PyObject *uppercase_level;
_Py_IDENTIFIER(upper);
if (!PyUnicode_Check(isolation_level)) { if (!PyUnicode_Check(isolation_level)) {
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
@ -1488,17 +1586,14 @@ pysqlite_connection_set_isolation_level(pysqlite_Connection* self, PyObject* iso
return -1; return -1;
} }
uppercase_level = _PyObject_CallMethodIdObjArgs( const char *level = PyUnicode_AsUTF8(isolation_level);
(PyObject *)&PyUnicode_Type, &PyId_upper, if (level == NULL) {
isolation_level, NULL);
if (!uppercase_level) {
return -1; return -1;
} }
for (candidate = begin_statements; *candidate; candidate++) { for (candidate = begin_statements; *candidate; candidate++) {
if (_PyUnicode_EqualToASCIIString(uppercase_level, *candidate + 6)) if (sqlite3_stricmp(level, *candidate + 6) == 0)
break; break;
} }
Py_DECREF(uppercase_level);
if (!*candidate) { if (!*candidate) {
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"invalid value for isolation_level"); "invalid value for isolation_level");
@ -1523,9 +1618,6 @@ PyObject* pysqlite_connection_call(pysqlite_Connection* self, PyObject* args, Py
return NULL; return NULL;
} }
if (!_PyArg_NoKeywords(MODULE_NAME ".Connection", kwargs))
return NULL;
if (!PyArg_ParseTuple(args, "U", &sql)) if (!PyArg_ParseTuple(args, "U", &sql))
return NULL; return NULL;
@ -1742,35 +1834,18 @@ pysqlite_connection_backup(pysqlite_Connection *self, PyObject *args, PyObject *
const char *name = "main"; const char *name = "main";
int rc; int rc;
int callback_error = 0; int callback_error = 0;
PyObject *sleep_obj = NULL; double sleep_s = 0.25;
int sleep_ms = 250; int sleep_ms = 0;
sqlite3 *bck_conn; sqlite3 *bck_conn;
sqlite3_backup *bck_handle; sqlite3_backup *bck_handle;
static char *keywords[] = {"target", "pages", "progress", "name", "sleep", NULL}; static char *keywords[] = {"target", "pages", "progress", "name", "sleep", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|$iOsO:backup", keywords, if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!|$iOsd:backup", keywords,
&pysqlite_ConnectionType, &target, &pysqlite_ConnectionType, &target,
&pages, &progress, &name, &sleep_obj)) { &pages, &progress, &name, &sleep_s)) {
return NULL; return NULL;
} }
// XXX: We use _PyTime_ROUND_CEILING to support 3.6.x, but it should
// use _PyTime_ROUND_TIMEOUT instead.
if (sleep_obj != NULL) {
_PyTime_t sleep_secs;
if (_PyTime_FromSecondsObject(&sleep_secs, sleep_obj,
_PyTime_ROUND_CEILING)) {
return NULL;
}
_PyTime_t ms = _PyTime_AsMilliseconds(sleep_secs,
_PyTime_ROUND_CEILING);
if (ms < INT_MIN || ms > INT_MAX) {
PyErr_SetString(PyExc_OverflowError, "sleep is too large");
return NULL;
}
sleep_ms = (int)ms;
}
if (!pysqlite_check_connection((pysqlite_Connection *)target)) { if (!pysqlite_check_connection((pysqlite_Connection *)target)) {
return NULL; return NULL;
} }
@ -1779,6 +1854,11 @@ pysqlite_connection_backup(pysqlite_Connection *self, PyObject *args, PyObject *
PyErr_SetString(PyExc_ValueError, "target cannot be the same connection instance"); PyErr_SetString(PyExc_ValueError, "target cannot be the same connection instance");
return NULL; return NULL;
} }
if (sleep_s < 0) {
PyErr_SetString(PyExc_ValueError, "sleep must be greater-than or equal to zero");
return NULL;
}
sleep_ms = (int)(sleep_s * 1000.0);
#if SQLITE_VERSION_NUMBER < 3008008 #if SQLITE_VERSION_NUMBER < 3008008
/* Since 3.8.8 this is already done, per commit /* Since 3.8.8 this is already done, per commit
@ -1865,15 +1945,10 @@ static PyObject *
pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args) pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args)
{ {
PyObject* callable; PyObject* callable;
PyObject* uppercase_name = 0; PyObject* name = NULL;
PyObject* name;
PyObject* retval; PyObject* retval;
Py_ssize_t i, len; const char *name_str;
_Py_IDENTIFIER(upper);
const char *uppercase_name_str;
int rc; int rc;
unsigned int kind;
const void *data;
if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) { if (!pysqlite_check_thread(self) || !pysqlite_check_connection(self)) {
goto finally; goto finally;
@ -1884,32 +1959,8 @@ pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args)
goto finally; goto finally;
} }
uppercase_name = _PyObject_CallMethodIdObjArgs((PyObject *)&PyUnicode_Type, name_str = PyUnicode_AsUTF8(name);
&PyId_upper, name, NULL); if (!name_str)
if (!uppercase_name) {
goto finally;
}
if (PyUnicode_READY(uppercase_name))
goto finally;
len = PyUnicode_GET_LENGTH(uppercase_name);
kind = PyUnicode_KIND(uppercase_name);
data = PyUnicode_DATA(uppercase_name);
for (i=0; i<len; i++) {
Py_UCS4 ch = PyUnicode_READ(kind, data, i);
if ((ch >= '0' && ch <= '9')
|| (ch >= 'A' && ch <= 'Z')
|| (ch == '_'))
{
continue;
} else {
PyErr_SetString(pysqlite_ProgrammingError, "invalid character in collation name");
goto finally;
}
}
uppercase_name_str = PyUnicode_AsUTF8(uppercase_name);
if (!uppercase_name_str)
goto finally; goto finally;
if (callable != Py_None && !PyCallable_Check(callable)) { if (callable != Py_None && !PyCallable_Check(callable)) {
@ -1918,27 +1969,25 @@ pysqlite_connection_create_collation(pysqlite_Connection* self, PyObject* args)
} }
if (callable != Py_None) { if (callable != Py_None) {
if (PyDict_SetItem(self->collations, uppercase_name, callable) == -1) if (PyDict_SetItem(self->collations, name, callable) == -1)
goto finally; goto finally;
} else { } else {
if (PyDict_DelItem(self->collations, uppercase_name) == -1) if (PyDict_DelItem(self->collations, name) == -1)
goto finally; goto finally;
} }
rc = sqlite3_create_collation(self->db, rc = sqlite3_create_collation(self->db,
uppercase_name_str, name_str,
SQLITE_UTF8, SQLITE_UTF8,
(callable != Py_None) ? callable : NULL, (callable != Py_None) ? callable : NULL,
(callable != Py_None) ? pysqlite_collation_callback : NULL); (callable != Py_None) ? pysqlite_collation_callback : NULL);
if (rc != SQLITE_OK) { if (rc != SQLITE_OK) {
PyDict_DelItem(self->collations, uppercase_name); PyDict_DelItem(self->collations, name);
_pysqlite_seterror(self->db); _pysqlite_seterror(self->db);
goto finally; goto finally;
} }
finally: finally:
Py_XDECREF(uppercase_name);
if (PyErr_Occurred()) { if (PyErr_Occurred()) {
retval = NULL; retval = NULL;
} else { } else {
@ -2063,6 +2112,10 @@ static PyMethodDef connection_methods[] = {
#endif #endif
{"set_authorizer", (PyCFunction)(void(*)(void))pysqlite_connection_set_authorizer, METH_VARARGS|METH_KEYWORDS, {"set_authorizer", (PyCFunction)(void(*)(void))pysqlite_connection_set_authorizer, METH_VARARGS|METH_KEYWORDS,
PyDoc_STR("Sets authorizer callback. Non-standard.")}, PyDoc_STR("Sets authorizer callback. Non-standard.")},
{"set_busy_handler", (PyCFunction)(void(*)(void))pysqlite_connection_set_busy_handler, METH_VARARGS|METH_KEYWORDS,
PyDoc_STR("Sets busy handler. Non-standard.")},
{"set_busy_timeout", (PyCFunction)(void(*)(void))pysqlite_connection_set_busy_timeout, METH_VARARGS|METH_KEYWORDS,
PyDoc_STR("Sets busy timeout. Non-standard.")},
#ifdef HAVE_LOAD_EXTENSION #ifdef HAVE_LOAD_EXTENSION
{"enable_load_extension", (PyCFunction)pysqlite_enable_load_extension, METH_VARARGS, {"enable_load_extension", (PyCFunction)pysqlite_enable_load_extension, METH_VARARGS,
PyDoc_STR("Enable dynamic loading of SQLite extension modules. Non-standard.")}, PyDoc_STR("Enable dynamic loading of SQLite extension modules. Non-standard.")},

View file

@ -90,6 +90,7 @@ typedef struct
PyObject* function_pinboard_trace_callback; PyObject* function_pinboard_trace_callback;
PyObject* function_pinboard_progress_handler; PyObject* function_pinboard_progress_handler;
PyObject* function_pinboard_authorizer_cb; PyObject* function_pinboard_authorizer_cb;
PyObject* function_pinboard_busy_handler_cb;
/* a dictionary of registered collation name => collation callable mappings */ /* a dictionary of registered collation name => collation callable mappings */
PyObject* collations; PyObject* collations;

View file

@ -536,7 +536,7 @@ _pysqlite_query_execute(pysqlite_Cursor* self, int multiple, PyObject* args)
} }
if (pysqlite_build_row_cast_map(self) != 0) { if (pysqlite_build_row_cast_map(self) != 0) {
_PyErr_FormatFromCause(pysqlite_OperationalError, "Error while building row_cast_map"); PyErr_Format(pysqlite_OperationalError, "Error while building row_cast_map");
goto error; goto error;
} }

View file

@ -41,8 +41,6 @@ pysqlite_row_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
assert(type != NULL && type->tp_alloc != NULL); assert(type != NULL && type->tp_alloc != NULL);
if (!_PyArg_NoKeywords("Row", kwargs))
return NULL;
if (!PyArg_ParseTuple(args, "OO", &cursor, &data)) if (!PyArg_ParseTuple(args, "OO", &cursor, &data))
return NULL; return NULL;

View file

@ -157,9 +157,15 @@ _pysqlite_long_as_int64(PyObject * py_val)
} }
else if (sizeof(value) < sizeof(sqlite_int64)) { else if (sizeof(value) < sizeof(sqlite_int64)) {
sqlite_int64 int64val; sqlite_int64 int64val;
#if PY_VERSION_HEX < 0x030D0000
if (_PyLong_AsByteArray((PyLongObject *)py_val, if (_PyLong_AsByteArray((PyLongObject *)py_val,
(unsigned char *)&int64val, sizeof(int64val), (unsigned char *)&int64val, sizeof(int64val),
IS_LITTLE_ENDIAN, 1 /* signed */) >= 0) { IS_LITTLE_ENDIAN, 1 /* signed */) >= 0) {
#else
if (_PyLong_AsByteArray((PyLongObject *)py_val,
(unsigned char *)&int64val, sizeof(int64val),
IS_LITTLE_ENDIAN, 1 /* signed */, 1) >= 0) {
#endif
return int64val; return int64val;
} }
} }

View file

@ -39,4 +39,46 @@ int _pysqlite_seterror(sqlite3* db);
sqlite_int64 _pysqlite_long_as_int64(PyObject * value); sqlite_int64 _pysqlite_long_as_int64(PyObject * value);
#ifndef _Py_CAST
# define _Py_CAST(type, expr) ((type)(expr))
#endif
// Cast argument to PyObject* type.
#ifndef _PyObject_CAST
# define _PyObject_CAST(op) _Py_CAST(PyObject*, op)
#endif
#if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_NewRef)
static inline PyObject* _Py_NewRef(PyObject *obj)
{
Py_INCREF(obj);
return obj;
}
#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
#endif
#if PY_VERSION_HEX < 0x030D0000
static inline int PyWeakref_GetRef(PyObject *ref, PyObject **pobj)
{
PyObject *obj;
if (ref != NULL && !PyWeakref_Check(ref)) {
*pobj = NULL;
PyErr_SetString(PyExc_TypeError, "expected a weakref");
return -1;
}
obj = PyWeakref_GetObject(ref);
if (obj == NULL) {
// SystemError if ref is NULL
*pobj = NULL;
return -1;
}
if (obj == Py_None) {
*pobj = NULL;
return 0;
}
*pobj = Py_NewRef(obj);
return (*pobj != NULL);
}
#endif
#endif #endif

View file

@ -2,14 +2,14 @@ import optparse
import sys import sys
import unittest import unittest
from test.test_backup import suite as backup_suite from tests.backup import suite as backup_suite
from test.test_dbapi import suite as dbapi_suite from tests.dbapi import suite as dbapi_suite
from test.test_factory import suite as factory_suite from tests.factory import suite as factory_suite
from test.test_hooks import suite as hooks_suite from tests.hooks import suite as hooks_suite
from test.test_regression import suite as regression_suite from tests.regression import suite as regression_suite
from test.test_transactions import suite as transactions_suite from tests.transactions import suite as transactions_suite
from test.test_ttypes import suite as types_suite from tests.ttypes import suite as types_suite
from test.test_userfunctions import suite as userfunctions_suite from tests.userfunctions import suite as userfunctions_suite
def test(verbosity=1, failfast=False): def test(verbosity=1, failfast=False):

View file

@ -95,6 +95,20 @@ class BackupTests(unittest.TestCase):
self.assertEqual(len(journal), 1) self.assertEqual(len(journal), 1)
self.assertEqual(journal[0], 0) self.assertEqual(journal[0], 0)
def test_sleep(self):
with self.assertRaises(ValueError) as bm:
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, sleep=-1)
self.assertEqual(str(bm.exception), 'sleep must be greater-than or equal to zero')
with self.assertRaises(TypeError):
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, sleep=None)
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, sleep=10)
self.verify_backup(bck)
def test_non_callable_progress(self): def test_non_callable_progress(self):
with self.assertRaises(TypeError) as cm: with self.assertRaises(TypeError) as cm:
with sqlite.connect(':memory:') as bck: with sqlite.connect(':memory:') as bck:
@ -156,7 +170,14 @@ class BackupTests(unittest.TestCase):
def suite(): def suite():
return unittest.makeSuite(BackupTests) loader = unittest.TestLoader()
tests = [loader.loadTestsFromTestCase(t) for t in (
BackupTests,)]
return unittest.TestSuite(tests)
def test():
runner = unittest.TextTestRunner()
runner.run(suite())
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() test()

File diff suppressed because it is too large Load diff

View file

@ -47,7 +47,7 @@ class ConnectionFactoryTests(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.con.close() self.con.close()
def CheckIsInstance(self): def test_IsInstance(self):
self.assertIsInstance(self.con, MyConnection) self.assertIsInstance(self.con, MyConnection)
class CursorFactoryTests(unittest.TestCase): class CursorFactoryTests(unittest.TestCase):
@ -57,7 +57,7 @@ class CursorFactoryTests(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.con.close() self.con.close()
def CheckIsInstance(self): def test_IsInstance(self):
cur = self.con.cursor() cur = self.con.cursor()
self.assertIsInstance(cur, sqlite.Cursor) self.assertIsInstance(cur, sqlite.Cursor)
cur = self.con.cursor(MyCursor) cur = self.con.cursor(MyCursor)
@ -65,7 +65,7 @@ class CursorFactoryTests(unittest.TestCase):
cur = self.con.cursor(factory=lambda con: MyCursor(con)) cur = self.con.cursor(factory=lambda con: MyCursor(con))
self.assertIsInstance(cur, MyCursor) self.assertIsInstance(cur, MyCursor)
def CheckInvalidFactory(self): def test_InvalidFactory(self):
# not a callable at all # not a callable at all
self.assertRaises(TypeError, self.con.cursor, None) self.assertRaises(TypeError, self.con.cursor, None)
# invalid callable with not exact one argument # invalid callable with not exact one argument
@ -77,7 +77,7 @@ class RowFactoryTestsBackwardsCompat(unittest.TestCase):
def setUp(self): def setUp(self):
self.con = sqlite.connect(":memory:") self.con = sqlite.connect(":memory:")
def CheckIsProducedByFactory(self): def test_IsProducedByFactory(self):
cur = self.con.cursor(factory=MyCursor) cur = self.con.cursor(factory=MyCursor)
cur.execute("select 4+5 as foo") cur.execute("select 4+5 as foo")
row = cur.fetchone() row = cur.fetchone()
@ -91,12 +91,12 @@ class RowFactoryTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.con = sqlite.connect(":memory:") self.con = sqlite.connect(":memory:")
def CheckCustomFactory(self): def test_CustomFactory(self):
self.con.row_factory = lambda cur, row: list(row) self.con.row_factory = lambda cur, row: list(row)
row = self.con.execute("select 1, 2").fetchone() row = self.con.execute("select 1, 2").fetchone()
self.assertIsInstance(row, list) self.assertIsInstance(row, list)
def CheckSqliteRowIndex(self): def test_SqliteRowIndex(self):
self.con.row_factory = sqlite.Row self.con.row_factory = sqlite.Row
row = self.con.execute("select 1 as a_1, 2 as b").fetchone() row = self.con.execute("select 1 as a_1, 2 as b").fetchone()
self.assertIsInstance(row, sqlite.Row) self.assertIsInstance(row, sqlite.Row)
@ -125,7 +125,7 @@ class RowFactoryTests(unittest.TestCase):
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
row[2**1000] row[2**1000]
def CheckSqliteRowIndexUnicode(self): def test_SqliteRowIndexUnicode(self):
self.con.row_factory = sqlite.Row self.con.row_factory = sqlite.Row
row = self.con.execute("select 1 as \xff").fetchone() row = self.con.execute("select 1 as \xff").fetchone()
self.assertEqual(row["\xff"], 1) self.assertEqual(row["\xff"], 1)
@ -134,7 +134,7 @@ class RowFactoryTests(unittest.TestCase):
with self.assertRaises(IndexError): with self.assertRaises(IndexError):
row['\xdf'] row['\xdf']
def CheckSqliteRowSlice(self): def test_SqliteRowSlice(self):
# A sqlite.Row can be sliced like a list. # A sqlite.Row can be sliced like a list.
self.con.row_factory = sqlite.Row self.con.row_factory = sqlite.Row
row = self.con.execute("select 1, 2, 3, 4").fetchone() row = self.con.execute("select 1, 2, 3, 4").fetchone()
@ -152,21 +152,21 @@ class RowFactoryTests(unittest.TestCase):
self.assertEqual(row[0:4:2], (1, 3)) self.assertEqual(row[0:4:2], (1, 3))
self.assertEqual(row[3:0:-2], (4, 2)) self.assertEqual(row[3:0:-2], (4, 2))
def CheckSqliteRowIter(self): def test_SqliteRowIter(self):
"""Checks if the row object is iterable""" """Checks if the row object is iterable"""
self.con.row_factory = sqlite.Row self.con.row_factory = sqlite.Row
row = self.con.execute("select 1 as a, 2 as b").fetchone() row = self.con.execute("select 1 as a, 2 as b").fetchone()
for col in row: for col in row:
pass pass
def CheckSqliteRowAsTuple(self): def test_SqliteRowAsTuple(self):
"""Checks if the row object can be converted to a tuple""" """Checks if the row object can be converted to a tuple"""
self.con.row_factory = sqlite.Row self.con.row_factory = sqlite.Row
row = self.con.execute("select 1 as a, 2 as b").fetchone() row = self.con.execute("select 1 as a, 2 as b").fetchone()
t = tuple(row) t = tuple(row)
self.assertEqual(t, (row['a'], row['b'])) self.assertEqual(t, (row['a'], row['b']))
def CheckSqliteRowAsDict(self): def test_SqliteRowAsDict(self):
"""Checks if the row object can be correctly converted to a dictionary""" """Checks if the row object can be correctly converted to a dictionary"""
self.con.row_factory = sqlite.Row self.con.row_factory = sqlite.Row
row = self.con.execute("select 1 as a, 2 as b").fetchone() row = self.con.execute("select 1 as a, 2 as b").fetchone()
@ -174,7 +174,7 @@ class RowFactoryTests(unittest.TestCase):
self.assertEqual(d["a"], row["a"]) self.assertEqual(d["a"], row["a"])
self.assertEqual(d["b"], row["b"]) self.assertEqual(d["b"], row["b"])
def CheckSqliteRowHashCmp(self): def test_SqliteRowHashCmp(self):
"""Checks if the row object compares and hashes correctly""" """Checks if the row object compares and hashes correctly"""
self.con.row_factory = sqlite.Row self.con.row_factory = sqlite.Row
row_1 = self.con.execute("select 1 as a, 2 as b").fetchone() row_1 = self.con.execute("select 1 as a, 2 as b").fetchone()
@ -208,7 +208,7 @@ class RowFactoryTests(unittest.TestCase):
self.assertEqual(hash(row_1), hash(row_2)) self.assertEqual(hash(row_1), hash(row_2))
def CheckSqliteRowAsSequence(self): def test_SqliteRowAsSequence(self):
""" Checks if the row object can act like a sequence """ """ Checks if the row object can act like a sequence """
self.con.row_factory = sqlite.Row self.con.row_factory = sqlite.Row
row = self.con.execute("select 1 as a, 2 as b").fetchone() row = self.con.execute("select 1 as a, 2 as b").fetchone()
@ -217,7 +217,7 @@ class RowFactoryTests(unittest.TestCase):
self.assertEqual(list(reversed(row)), list(reversed(as_tuple))) self.assertEqual(list(reversed(row)), list(reversed(as_tuple)))
self.assertIsInstance(row, Sequence) self.assertIsInstance(row, Sequence)
def CheckFakeCursorClass(self): def test_FakeCursorClass(self):
# Issue #24257: Incorrect use of PyObject_IsInstance() caused # Issue #24257: Incorrect use of PyObject_IsInstance() caused
# segmentation fault. # segmentation fault.
# Issue #27861: Also applies for cursor factory. # Issue #27861: Also applies for cursor factory.
@ -234,26 +234,26 @@ class TextFactoryTests(unittest.TestCase):
def setUp(self): def setUp(self):
self.con = sqlite.connect(":memory:") self.con = sqlite.connect(":memory:")
def CheckUnicode(self): def test_Unicode(self):
austria = "Österreich" austria = "Österreich"
row = self.con.execute("select ?", (austria,)).fetchone() row = self.con.execute("select ?", (austria,)).fetchone()
self.assertEqual(type(row[0]), str, "type of row[0] must be unicode") self.assertEqual(type(row[0]), str, "type of row[0] must be unicode")
def CheckString(self): def test_String(self):
self.con.text_factory = bytes self.con.text_factory = bytes
austria = "Österreich" austria = "Österreich"
row = self.con.execute("select ?", (austria,)).fetchone() row = self.con.execute("select ?", (austria,)).fetchone()
self.assertEqual(type(row[0]), bytes, "type of row[0] must be bytes") self.assertEqual(type(row[0]), bytes, "type of row[0] must be bytes")
self.assertEqual(row[0], austria.encode("utf-8"), "column must equal original data in UTF-8") self.assertEqual(row[0], austria.encode("utf-8"), "column must equal original data in UTF-8")
def CheckCustom(self): def test_Custom(self):
self.con.text_factory = lambda x: str(x, "utf-8", "ignore") self.con.text_factory = lambda x: str(x, "utf-8", "ignore")
austria = "Österreich" austria = "Österreich"
row = self.con.execute("select ?", (austria,)).fetchone() row = self.con.execute("select ?", (austria,)).fetchone()
self.assertEqual(type(row[0]), str, "type of row[0] must be unicode") self.assertEqual(type(row[0]), str, "type of row[0] must be unicode")
self.assertTrue(row[0].endswith("reich"), "column must contain original data") self.assertTrue(row[0].endswith("reich"), "column must contain original data")
def CheckOptimizedUnicode(self): def test_OptimizedUnicode(self):
# In py3k, str objects are always returned when text_factory # In py3k, str objects are always returned when text_factory
# is OptimizedUnicode # is OptimizedUnicode
self.con.text_factory = sqlite.OptimizedUnicode self.con.text_factory = sqlite.OptimizedUnicode
@ -273,25 +273,25 @@ class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
self.con.execute("create table test (value text)") self.con.execute("create table test (value text)")
self.con.execute("insert into test (value) values (?)", ("a\x00b",)) self.con.execute("insert into test (value) values (?)", ("a\x00b",))
def CheckString(self): def test_String(self):
# text_factory defaults to str # text_factory defaults to str
row = self.con.execute("select value from test").fetchone() row = self.con.execute("select value from test").fetchone()
self.assertIs(type(row[0]), str) self.assertIs(type(row[0]), str)
self.assertEqual(row[0], "a\x00b") self.assertEqual(row[0], "a\x00b")
def CheckBytes(self): def test_Bytes(self):
self.con.text_factory = bytes self.con.text_factory = bytes
row = self.con.execute("select value from test").fetchone() row = self.con.execute("select value from test").fetchone()
self.assertIs(type(row[0]), bytes) self.assertIs(type(row[0]), bytes)
self.assertEqual(row[0], b"a\x00b") self.assertEqual(row[0], b"a\x00b")
def CheckBytearray(self): def test_Bytearray(self):
self.con.text_factory = bytearray self.con.text_factory = bytearray
row = self.con.execute("select value from test").fetchone() row = self.con.execute("select value from test").fetchone()
self.assertIs(type(row[0]), bytearray) self.assertIs(type(row[0]), bytearray)
self.assertEqual(row[0], b"a\x00b") self.assertEqual(row[0], b"a\x00b")
def CheckCustom(self): def test_Custom(self):
# A custom factory should receive a bytes argument # A custom factory should receive a bytes argument
self.con.text_factory = lambda x: x self.con.text_factory = lambda x: x
row = self.con.execute("select value from test").fetchone() row = self.con.execute("select value from test").fetchone()
@ -302,13 +302,15 @@ class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
self.con.close() self.con.close()
def suite(): def suite():
connection_suite = unittest.makeSuite(ConnectionFactoryTests, "Check") loader = unittest.TestLoader()
cursor_suite = unittest.makeSuite(CursorFactoryTests, "Check") tests = [loader.loadTestsFromTestCase(t) for t in (
row_suite_compat = unittest.makeSuite(RowFactoryTestsBackwardsCompat, "Check") ConnectionFactoryTests,
row_suite = unittest.makeSuite(RowFactoryTests, "Check") CursorFactoryTests,
text_suite = unittest.makeSuite(TextFactoryTests, "Check") RowFactoryTestsBackwardsCompat,
text_zero_bytes_suite = unittest.makeSuite(TextFactoryTestsWithEmbeddedZeroBytes, "Check") RowFactoryTests,
return unittest.TestSuite((connection_suite, cursor_suite, row_suite_compat, row_suite, text_suite, text_zero_bytes_suite)) TextFactoryTests,
TextFactoryTestsWithEmbeddedZeroBytes)]
return unittest.TestSuite(tests)
def test(): def test():
runner = unittest.TextTestRunner() runner = unittest.TextTestRunner()

View file

@ -23,27 +23,27 @@
import os import os
import unittest import unittest
from sqlcipher3 import dbapi2 as sqlite from sqlcipher3 import dbapi2 as sqlite
class CollationTests(unittest.TestCase): class CollationTests(unittest.TestCase):
def CheckCreateCollationNotString(self): def test_CreateCollationNotString(self):
con = sqlite.connect(":memory:") con = sqlite.connect(":memory:")
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
con.create_collation(None, lambda x, y: (x > y) - (x < y)) con.create_collation(None, lambda x, y: (x > y) - (x < y))
def CheckCreateCollationNotCallable(self): def test_CreateCollationNotCallable(self):
con = sqlite.connect(":memory:") con = sqlite.connect(":memory:")
with self.assertRaises(TypeError) as cm: with self.assertRaises(TypeError) as cm:
con.create_collation("X", 42) con.create_collation("X", 42)
self.assertEqual(str(cm.exception), 'parameter must be callable') self.assertEqual(str(cm.exception), 'parameter must be callable')
def CheckCreateCollationNotAscii(self): def test_CreateCollationNotAscii(self):
con = sqlite.connect(":memory:") con = sqlite.connect(":memory:")
with self.assertRaises(sqlite.ProgrammingError):
con.create_collation("collä", lambda x, y: (x > y) - (x < y)) con.create_collation("collä", lambda x, y: (x > y) - (x < y))
def CheckCreateCollationBadUpper(self): def test_CreateCollationBadUpper(self):
class BadUpperStr(str): class BadUpperStr(str):
def upper(self): def upper(self):
return None return None
@ -62,7 +62,7 @@ class CollationTests(unittest.TestCase):
@unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 1), @unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 1),
'old SQLite versions crash on this test') 'old SQLite versions crash on this test')
def CheckCollationIsUsed(self): def test_CollationIsUsed(self):
def mycoll(x, y): def mycoll(x, y):
# reverse order # reverse order
return -((x > y) - (x < y)) return -((x > y) - (x < y))
@ -87,7 +87,7 @@ class CollationTests(unittest.TestCase):
result = con.execute(sql).fetchall() result = con.execute(sql).fetchall()
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
def CheckCollationReturnsLargeInteger(self): def test_CollationReturnsLargeInteger(self):
def mycoll(x, y): def mycoll(x, y):
# reverse order # reverse order
return -((x > y) - (x < y)) * 2**32 return -((x > y) - (x < y)) * 2**32
@ -106,7 +106,7 @@ class CollationTests(unittest.TestCase):
self.assertEqual(result, [('c',), ('b',), ('a',)], self.assertEqual(result, [('c',), ('b',), ('a',)],
msg="the expected order was not returned") msg="the expected order was not returned")
def CheckCollationRegisterTwice(self): def test_CollationRegisterTwice(self):
""" """
Register two different collation functions under the same name. Register two different collation functions under the same name.
Verify that the last one is actually used. Verify that the last one is actually used.
@ -120,7 +120,7 @@ class CollationTests(unittest.TestCase):
self.assertEqual(result[0][0], 'b') self.assertEqual(result[0][0], 'b')
self.assertEqual(result[1][0], 'a') self.assertEqual(result[1][0], 'a')
def CheckDeregisterCollation(self): def test_DeregisterCollation(self):
""" """
Register a collation, then deregister it. Make sure an error is raised if we try Register a collation, then deregister it. Make sure an error is raised if we try
to use it. to use it.
@ -133,7 +133,7 @@ class CollationTests(unittest.TestCase):
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll') self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
class ProgressTests(unittest.TestCase): class ProgressTests(unittest.TestCase):
def CheckProgressHandlerUsed(self): def test_ProgressHandlerUsed(self):
""" """
Test that the progress handler is invoked once it is set. Test that the progress handler is invoked once it is set.
""" """
@ -149,7 +149,7 @@ class ProgressTests(unittest.TestCase):
self.assertTrue(progress_calls) self.assertTrue(progress_calls)
def CheckOpcodeCount(self): def test_OpcodeCount(self):
""" """
Test that the opcode argument is respected. Test that the opcode argument is respected.
""" """
@ -172,7 +172,7 @@ class ProgressTests(unittest.TestCase):
second_count = len(progress_calls) second_count = len(progress_calls)
self.assertGreaterEqual(first_count, second_count) self.assertGreaterEqual(first_count, second_count)
def CheckCancelOperation(self): def test_CancelOperation(self):
""" """
Test that returning a non-zero value stops the operation in progress. Test that returning a non-zero value stops the operation in progress.
""" """
@ -186,7 +186,7 @@ class ProgressTests(unittest.TestCase):
curs.execute, curs.execute,
"create table bar (a, b)") "create table bar (a, b)")
def CheckClearHandler(self): def test_ClearHandler(self):
""" """
Test that setting the progress handler to None clears the previously set handler. Test that setting the progress handler to None clears the previously set handler.
""" """
@ -202,7 +202,7 @@ class ProgressTests(unittest.TestCase):
self.assertEqual(action, 0, "progress handler was not cleared") self.assertEqual(action, 0, "progress handler was not cleared")
class TraceCallbackTests(unittest.TestCase): class TraceCallbackTests(unittest.TestCase):
def CheckTraceCallbackUsed(self): def test_TraceCallbackUsed(self):
""" """
Test that the trace callback is invoked once it is set. Test that the trace callback is invoked once it is set.
""" """
@ -215,7 +215,7 @@ class TraceCallbackTests(unittest.TestCase):
self.assertTrue(traced_statements) self.assertTrue(traced_statements)
self.assertTrue(any("create table foo" in stmt for stmt in traced_statements)) self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
def CheckTraceCallbackError(self): def test_TraceCallbackError(self):
""" """
Test behavior when exception raised in trace callback. Test behavior when exception raised in trace callback.
""" """
@ -226,7 +226,7 @@ class TraceCallbackTests(unittest.TestCase):
con.execute("create table foo(a, b)") con.execute("create table foo(a, b)")
con.set_trace_callback(None) con.set_trace_callback(None)
def CheckClearTraceCallback(self): def test_ClearTraceCallback(self):
""" """
Test that setting the trace callback to None clears the previously set callback. Test that setting the trace callback to None clears the previously set callback.
""" """
@ -239,7 +239,7 @@ class TraceCallbackTests(unittest.TestCase):
con.execute("create table foo(a, b)") con.execute("create table foo(a, b)")
self.assertFalse(traced_statements, "trace callback was not cleared") self.assertFalse(traced_statements, "trace callback was not cleared")
def CheckUnicodeContent(self): def test_UnicodeContent(self):
""" """
Test that the statement can contain unicode literals. Test that the statement can contain unicode literals.
""" """
@ -259,7 +259,7 @@ class TraceCallbackTests(unittest.TestCase):
"Unicode data %s garbled in trace callback: %s" "Unicode data %s garbled in trace callback: %s"
% (ascii(unicode_value), ', '.join(map(ascii, traced_statements)))) % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))
def CheckTraceCallbackContent(self): def test_TraceCallbackContent(self):
# set_trace_callback() shouldn't produce duplicate content (bpo-26187) # set_trace_callback() shouldn't produce duplicate content (bpo-26187)
traced_statements = [] traced_statements = []
def trace(statement): def trace(statement):
@ -278,11 +278,49 @@ class TraceCallbackTests(unittest.TestCase):
self.assertEqual(traced_statements, queries) self.assertEqual(traced_statements, queries)
class TestBusyHandlerTimeout(unittest.TestCase):
def test_busy_handler(self):
accum = []
def custom_handler(n):
accum.append(n)
return 0 if n == 3 else 1
self.addCleanup(os.unlink, 'busy.db')
conn1 = sqlite.connect('busy.db')
conn2 = sqlite.connect('busy.db')
conn2.set_busy_handler(custom_handler)
conn1.execute('begin exclusive')
with self.assertRaises(sqlite.OperationalError):
conn2.execute('create table test(id)')
self.assertEqual(accum, [0, 1, 2, 3])
accum.clear()
conn2.set_busy_handler(None)
with self.assertRaises(sqlite.OperationalError):
conn2.execute('create table test(id)')
self.assertEqual(accum, [])
conn2.set_busy_handler(custom_handler)
with self.assertRaises(sqlite.OperationalError):
conn2.execute('create table test(id)')
self.assertEqual(accum, [0, 1, 2, 3])
accum.clear()
conn2.set_busy_timeout(0.01) # Clears busy handler.
with self.assertRaises(sqlite.OperationalError):
conn2.execute('create table test(id)')
self.assertEqual(accum, [])
def suite(): def suite():
collation_suite = unittest.makeSuite(CollationTests, "Check") loader = unittest.TestLoader()
progress_suite = unittest.makeSuite(ProgressTests, "Check") tests = [loader.loadTestsFromTestCase(t) for t in (
trace_suite = unittest.makeSuite(TraceCallbackTests, "Check") CollationTests,
return unittest.TestSuite((collation_suite, progress_suite, trace_suite)) ProgressTests,
TraceCallbackTests,
TestBusyHandlerTimeout)]
return unittest.TestSuite(tests)
def test(): def test():
runner = unittest.TextTestRunner() runner = unittest.TextTestRunner()

View file

@ -35,12 +35,12 @@ class RegressionTests(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.con.close() self.con.close()
def CheckPragmaUserVersion(self): def test_PragmaUserVersion(self):
# This used to crash pysqlite because this pragma command returns NULL for the column name # This used to crash pysqlite because this pragma command returns NULL for the column name
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("pragma user_version") cur.execute("pragma user_version")
def CheckPragmaSchemaVersion(self): def test_PragmaSchemaVersion(self):
# This still crashed pysqlite <= 2.2.1 # This still crashed pysqlite <= 2.2.1
con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES) con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES)
try: try:
@ -50,7 +50,7 @@ class RegressionTests(unittest.TestCase):
cur.close() cur.close()
con.close() con.close()
def CheckStatementReset(self): def test_StatementReset(self):
# pysqlite 2.1.0 to 2.2.0 have the problem that not all statements are # pysqlite 2.1.0 to 2.2.0 have the problem that not all statements are
# reset before a rollback, but only those that are still in the # reset before a rollback, but only those that are still in the
# statement cache. The others are not accessible from the connection object. # statement cache. The others are not accessible from the connection object.
@ -65,7 +65,7 @@ class RegressionTests(unittest.TestCase):
con.rollback() con.rollback()
def CheckColumnNameWithSpaces(self): def test_ColumnNameWithSpaces(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute('select 1 as "foo bar [datetime]"') cur.execute('select 1 as "foo bar [datetime]"')
self.assertEqual(cur.description[0][0], "foo bar [datetime]") self.assertEqual(cur.description[0][0], "foo bar [datetime]")
@ -73,7 +73,7 @@ class RegressionTests(unittest.TestCase):
cur.execute('select 1 as "foo baz"') cur.execute('select 1 as "foo baz"')
self.assertEqual(cur.description[0][0], "foo baz") self.assertEqual(cur.description[0][0], "foo baz")
def CheckStatementFinalizationOnCloseDb(self): def test_StatementFinalizationOnCloseDb(self):
# pysqlite versions <= 2.3.3 only finalized statements in the statement # pysqlite versions <= 2.3.3 only finalized statements in the statement
# cache when closing the database. statements that were still # cache when closing the database. statements that were still
# referenced in cursors weren't closed and could provoke " # referenced in cursors weren't closed and could provoke "
@ -88,7 +88,7 @@ class RegressionTests(unittest.TestCase):
con.close() con.close()
@unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 2), 'needs sqlite 3.2.2 or newer') @unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 2), 'needs sqlite 3.2.2 or newer')
def CheckOnConflictRollback(self): def test_OnConflictRollback(self):
con = sqlite.connect(":memory:") con = sqlite.connect(":memory:")
con.execute("create table foo(x, unique(x) on conflict rollback)") con.execute("create table foo(x, unique(x) on conflict rollback)")
con.execute("insert into foo(x) values (1)") con.execute("insert into foo(x) values (1)")
@ -102,7 +102,7 @@ class RegressionTests(unittest.TestCase):
except sqlite.OperationalError: except sqlite.OperationalError:
self.fail("pysqlite knew nothing about the implicit ROLLBACK") self.fail("pysqlite knew nothing about the implicit ROLLBACK")
def CheckWorkaroundForBuggySqliteTransferBindings(self): def test_WorkaroundForBuggySqliteTransferBindings(self):
""" """
pysqlite would crash with older SQLite versions unless pysqlite would crash with older SQLite versions unless
a workaround is implemented. a workaround is implemented.
@ -111,14 +111,14 @@ class RegressionTests(unittest.TestCase):
self.con.execute("drop table foo") self.con.execute("drop table foo")
self.con.execute("create table foo(bar)") self.con.execute("create table foo(bar)")
def CheckEmptyStatement(self): def test_EmptyStatement(self):
""" """
pysqlite used to segfault with SQLite versions 3.5.x. These return NULL pysqlite used to segfault with SQLite versions 3.5.x. These return NULL
for "no-operation" statements for "no-operation" statements
""" """
self.con.execute("") self.con.execute("")
def CheckTypeMapUsage(self): def test_TypeMapUsage(self):
""" """
pysqlite until 2.4.1 did not rebuild the row_cast_map when recompiling pysqlite until 2.4.1 did not rebuild the row_cast_map when recompiling
a statement. This test exhibits the problem. a statement. This test exhibits the problem.
@ -133,7 +133,7 @@ class RegressionTests(unittest.TestCase):
con.execute("insert into foo(bar) values (5)") con.execute("insert into foo(bar) values (5)")
con.execute(SELECT) con.execute(SELECT)
def CheckErrorMsgDecodeError(self): def test_ErrorMsgDecodeError(self):
# When porting the module to Python 3.0, the error message about # When porting the module to Python 3.0, the error message about
# decoding errors disappeared. This verifies they're back again. # decoding errors disappeared. This verifies they're back again.
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
@ -142,13 +142,13 @@ class RegressionTests(unittest.TestCase):
msg = "Could not decode to UTF-8 column 'colname' with text 'xxx" msg = "Could not decode to UTF-8 column 'colname' with text 'xxx"
self.assertIn(msg, str(cm.exception)) self.assertIn(msg, str(cm.exception))
def CheckRegisterAdapter(self): def test_RegisterAdapter(self):
""" """
See issue 3312. See issue 3312.
""" """
self.assertRaises(TypeError, sqlite.register_adapter, {}, None) self.assertRaises(TypeError, sqlite.register_adapter, {}, None)
def CheckSetIsolationLevel(self): def test_SetIsolationLevel(self):
# See issue 27881. # See issue 27881.
class CustomStr(str): class CustomStr(str):
def upper(self): def upper(self):
@ -170,7 +170,7 @@ class RegressionTests(unittest.TestCase):
con.isolation_level = "DEFERRED" con.isolation_level = "DEFERRED"
pairs = [ pairs = [
(1, TypeError), (b'', TypeError), ("abc", ValueError), (1, TypeError), (b'', TypeError), ("abc", ValueError),
("IMMEDIATE\0EXCLUSIVE", ValueError), ("\xe9", ValueError), ("\xe9", ValueError),
] ]
for value, exc in pairs: for value, exc in pairs:
with self.subTest(level=value): with self.subTest(level=value):
@ -178,7 +178,7 @@ class RegressionTests(unittest.TestCase):
con.isolation_level = value con.isolation_level = value
self.assertEqual(con.isolation_level, "DEFERRED") self.assertEqual(con.isolation_level, "DEFERRED")
def CheckCursorConstructorCallCheck(self): def test_CursorConstructorCallCheck(self):
""" """
Verifies that cursor methods check whether base class __init__ was Verifies that cursor methods check whether base class __init__ was
called. called.
@ -195,14 +195,14 @@ class RegressionTests(unittest.TestCase):
r'^Base Cursor\.__init__ not called\.$'): r'^Base Cursor\.__init__ not called\.$'):
cur.close() cur.close()
def CheckStrSubclass(self): def test_StrSubclass(self):
""" """
The Python 3.0 port of the module didn't cope with values of subclasses of str. The Python 3.0 port of the module didn't cope with values of subclasses of str.
""" """
class MyStr(str): pass class MyStr(str): pass
self.con.execute("select ?", (MyStr("abc"),)) self.con.execute("select ?", (MyStr("abc"),))
def CheckConnectionConstructorCallCheck(self): def test_ConnectionConstructorCallCheck(self):
""" """
Verifies that connection methods check whether base class __init__ was Verifies that connection methods check whether base class __init__ was
called. called.
@ -215,7 +215,7 @@ class RegressionTests(unittest.TestCase):
with self.assertRaises(sqlite.ProgrammingError): with self.assertRaises(sqlite.ProgrammingError):
cur = con.cursor() cur = con.cursor()
def CheckCursorRegistration(self): def test_CursorRegistration(self):
""" """
Verifies that subclassed cursor classes are correctly registered with Verifies that subclassed cursor classes are correctly registered with
the connection object, too. (fetch-across-rollback problem) the connection object, too. (fetch-across-rollback problem)
@ -237,7 +237,7 @@ class RegressionTests(unittest.TestCase):
with self.assertRaises(sqlite.InterfaceError): with self.assertRaises(sqlite.InterfaceError):
cur.fetchall() cur.fetchall()
def CheckAutoCommit(self): def test_AutoCommit(self):
""" """
Verifies that creating a connection in autocommit mode works. Verifies that creating a connection in autocommit mode works.
2.5.3 introduced a regression so that these could no longer 2.5.3 introduced a regression so that these could no longer
@ -245,7 +245,7 @@ class RegressionTests(unittest.TestCase):
""" """
con = sqlite.connect(":memory:", isolation_level=None) con = sqlite.connect(":memory:", isolation_level=None)
def CheckPragmaAutocommit(self): def test_PragmaAutocommit(self):
""" """
Verifies that running a PRAGMA statement that does an autocommit does Verifies that running a PRAGMA statement that does an autocommit does
work. This did not work in 2.5.3/2.5.4. work. This did not work in 2.5.3/2.5.4.
@ -257,21 +257,21 @@ class RegressionTests(unittest.TestCase):
cur.execute("pragma page_size") cur.execute("pragma page_size")
row = cur.fetchone() row = cur.fetchone()
def CheckConnectionCall(self): def test_ConnectionCall(self):
""" """
Call a connection with a non-string SQL request: check error handling Call a connection with a non-string SQL request: check error handling
of the statement constructor. of the statement constructor.
""" """
self.assertRaises(TypeError, self.con, 1) self.assertRaises(TypeError, self.con, 1)
def CheckCollation(self): def test_Collation(self):
def collation_cb(a, b): def collation_cb(a, b):
return 1 return 1
self.assertRaises(sqlite.ProgrammingError, self.con.create_collation, self.assertRaises(UnicodeEncodeError, self.con.create_collation,
# Lone surrogate cannot be encoded to the default encoding (utf8) # Lone surrogate cannot be encoded to the default encoding (utf8)
"\uDC80", collation_cb) "\uDC80", collation_cb)
def CheckRecursiveCursorUse(self): def test_RecursiveCursorUse(self):
""" """
http://bugs.python.org/issue10811 http://bugs.python.org/issue10811
@ -292,7 +292,7 @@ class RegressionTests(unittest.TestCase):
cur.executemany("insert into b (baz) values (?)", cur.executemany("insert into b (baz) values (?)",
((i,) for i in foo())) ((i,) for i in foo()))
def CheckConvertTimestampMicrosecondPadding(self): def test_ConvertTimestampMicrosecondPadding(self):
""" """
http://bugs.python.org/issue14720 http://bugs.python.org/issue14720
@ -318,13 +318,13 @@ class RegressionTests(unittest.TestCase):
datetime.datetime(2012, 4, 4, 15, 6, 0, 123456), datetime.datetime(2012, 4, 4, 15, 6, 0, 123456),
]) ])
def CheckInvalidIsolationLevelType(self): def test_InvalidIsolationLevelType(self):
# isolation level is a string, not an integer # isolation level is a string, not an integer
self.assertRaises(TypeError, self.assertRaises(TypeError,
sqlite.connect, ":memory:", isolation_level=123) sqlite.connect, ":memory:", isolation_level=123)
def CheckNullCharacter(self): def test_NullCharacter(self):
# Issue #21147 # Issue #21147
con = sqlite.connect(":memory:") con = sqlite.connect(":memory:")
self.assertRaises(ValueError, con, "\0select 1") self.assertRaises(ValueError, con, "\0select 1")
@ -333,7 +333,7 @@ class RegressionTests(unittest.TestCase):
self.assertRaises(ValueError, cur.execute, " \0select 2") self.assertRaises(ValueError, cur.execute, " \0select 2")
self.assertRaises(ValueError, cur.execute, "select 2\0") self.assertRaises(ValueError, cur.execute, "select 2\0")
def CheckCommitCursorReset(self): def test_CommitCursorReset(self):
""" """
Connection.commit() did reset cursors, which made sqlite3 Connection.commit() did reset cursors, which made sqlite3
to return rows multiple times when fetched from cursors to return rows multiple times when fetched from cursors
@ -364,7 +364,7 @@ class RegressionTests(unittest.TestCase):
counter += 1 counter += 1
self.assertEqual(counter, 3, "should have returned exactly three rows") self.assertEqual(counter, 3, "should have returned exactly three rows")
def CheckBpo31770(self): def test_Bpo31770(self):
""" """
The interpreter shouldn't crash in case Cursor.__init__() is called The interpreter shouldn't crash in case Cursor.__init__() is called
more than once. more than once.
@ -380,11 +380,11 @@ class RegressionTests(unittest.TestCase):
del ref del ref
#support.gc_collect() #support.gc_collect()
def CheckDelIsolation_levelSegfault(self): def test_DelIsolation_levelSegfault(self):
with self.assertRaises(AttributeError): with self.assertRaises(AttributeError):
del self.con.isolation_level del self.con.isolation_level
def CheckBpo37347(self): def test_Bpo37347(self):
class Printer: class Printer:
def log(self, *args): def log(self, *args):
return sqlite.SQLITE_OK return sqlite.SQLITE_OK
@ -400,10 +400,10 @@ class RegressionTests(unittest.TestCase):
def suite(): def suite():
regression_suite = unittest.makeSuite(RegressionTests, "Check") loader = unittest.TestLoader()
return unittest.TestSuite(( tests = [loader.loadTestsFromTestCase(t) for t in (
regression_suite, RegressionTests,)]
)) return unittest.TestSuite(tests)
def test(): def test():
runner = unittest.TextTestRunner() runner = unittest.TextTestRunner()

View file

@ -52,7 +52,7 @@ class TransactionTests(unittest.TestCase):
except OSError: except OSError:
pass pass
def CheckDMLDoesNotAutoCommitBefore(self): def test_DMLDoesNotAutoCommitBefore(self):
self.cur1.execute("create table test(i)") self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)") self.cur1.execute("insert into test(i) values (5)")
self.cur1.execute("create table test2(j)") self.cur1.execute("create table test2(j)")
@ -60,14 +60,14 @@ class TransactionTests(unittest.TestCase):
res = self.cur2.fetchall() res = self.cur2.fetchall()
self.assertEqual(len(res), 0) self.assertEqual(len(res), 0)
def CheckInsertStartsTransaction(self): def test_InsertStartsTransaction(self):
self.cur1.execute("create table test(i)") self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)") self.cur1.execute("insert into test(i) values (5)")
self.cur2.execute("select i from test") self.cur2.execute("select i from test")
res = self.cur2.fetchall() res = self.cur2.fetchall()
self.assertEqual(len(res), 0) self.assertEqual(len(res), 0)
def CheckUpdateStartsTransaction(self): def test_UpdateStartsTransaction(self):
self.cur1.execute("create table test(i)") self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)") self.cur1.execute("insert into test(i) values (5)")
self.con1.commit() self.con1.commit()
@ -76,7 +76,7 @@ class TransactionTests(unittest.TestCase):
res = self.cur2.fetchone()[0] res = self.cur2.fetchone()[0]
self.assertEqual(res, 5) self.assertEqual(res, 5)
def CheckDeleteStartsTransaction(self): def test_DeleteStartsTransaction(self):
self.cur1.execute("create table test(i)") self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)") self.cur1.execute("insert into test(i) values (5)")
self.con1.commit() self.con1.commit()
@ -85,7 +85,7 @@ class TransactionTests(unittest.TestCase):
res = self.cur2.fetchall() res = self.cur2.fetchall()
self.assertEqual(len(res), 1) self.assertEqual(len(res), 1)
def CheckReplaceStartsTransaction(self): def test_ReplaceStartsTransaction(self):
self.cur1.execute("create table test(i)") self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)") self.cur1.execute("insert into test(i) values (5)")
self.con1.commit() self.con1.commit()
@ -95,7 +95,7 @@ class TransactionTests(unittest.TestCase):
self.assertEqual(len(res), 1) self.assertEqual(len(res), 1)
self.assertEqual(res[0][0], 5) self.assertEqual(res[0][0], 5)
def CheckToggleAutoCommit(self): def test_ToggleAutoCommit(self):
self.cur1.execute("create table test(i)") self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)") self.cur1.execute("insert into test(i) values (5)")
self.con1.isolation_level = None self.con1.isolation_level = None
@ -113,7 +113,7 @@ class TransactionTests(unittest.TestCase):
@unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 2), @unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 2),
'test hangs on sqlite versions older than 3.2.2') 'test hangs on sqlite versions older than 3.2.2')
def CheckRaiseTimeout(self): def test_RaiseTimeout(self):
self.cur1.execute("create table test(i)") self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)") self.cur1.execute("insert into test(i) values (5)")
with self.assertRaises(sqlite.OperationalError): with self.assertRaises(sqlite.OperationalError):
@ -121,7 +121,7 @@ class TransactionTests(unittest.TestCase):
@unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 2), @unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 2),
'test hangs on sqlite versions older than 3.2.2') 'test hangs on sqlite versions older than 3.2.2')
def CheckLocking(self): def test_Locking(self):
""" """
This tests the improved concurrency with pysqlite 2.3.4. You needed This tests the improved concurrency with pysqlite 2.3.4. You needed
to roll back con2 before you could commit con1. to roll back con2 before you could commit con1.
@ -133,7 +133,7 @@ class TransactionTests(unittest.TestCase):
# NO self.con2.rollback() HERE!!! # NO self.con2.rollback() HERE!!!
self.con1.commit() self.con1.commit()
def CheckRollbackCursorConsistency(self): def test_RollbackCursorConsistency(self):
""" """
Checks if cursors on the connection are set into a "reset" state Checks if cursors on the connection are set into a "reset" state
when a rollback is done on the connection. when a rollback is done on the connection.
@ -153,12 +153,12 @@ class SpecialCommandTests(unittest.TestCase):
self.con = sqlite.connect(":memory:") self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor() self.cur = self.con.cursor()
def CheckDropTable(self): def test_DropTable(self):
self.cur.execute("create table test(i)") self.cur.execute("create table test(i)")
self.cur.execute("insert into test(i) values (5)") self.cur.execute("insert into test(i) values (5)")
self.cur.execute("drop table test") self.cur.execute("drop table test")
def CheckPragma(self): def test_Pragma(self):
self.cur.execute("create table test(i)") self.cur.execute("create table test(i)")
self.cur.execute("insert into test(i) values (5)") self.cur.execute("insert into test(i) values (5)")
self.cur.execute("pragma count_changes=1") self.cur.execute("pragma count_changes=1")
@ -171,7 +171,7 @@ class TransactionalDDL(unittest.TestCase):
def setUp(self): def setUp(self):
self.con = sqlite.connect(":memory:") self.con = sqlite.connect(":memory:")
def CheckDdlDoesNotAutostartTransaction(self): def test_DdlDoesNotAutostartTransaction(self):
# For backwards compatibility reasons, DDL statements should not # For backwards compatibility reasons, DDL statements should not
# implicitly start a transaction. # implicitly start a transaction.
self.con.execute("create table test(i)") self.con.execute("create table test(i)")
@ -184,7 +184,7 @@ class TransactionalDDL(unittest.TestCase):
result = self.con.execute("select * from test2").fetchall() result = self.con.execute("select * from test2").fetchall()
self.assertEqual(result, []) self.assertEqual(result, [])
def CheckImmediateTransactionalDDL(self): def test_ImmediateTransactionalDDL(self):
# You can achieve transactional DDL by issuing a BEGIN # You can achieve transactional DDL by issuing a BEGIN
# statement manually. # statement manually.
self.con.execute("begin immediate") self.con.execute("begin immediate")
@ -193,7 +193,7 @@ class TransactionalDDL(unittest.TestCase):
with self.assertRaises(sqlite.OperationalError): with self.assertRaises(sqlite.OperationalError):
self.con.execute("select * from test") self.con.execute("select * from test")
def CheckTransactionalDDL(self): def test_TransactionalDDL(self):
# You can achieve transactional DDL by issuing a BEGIN # You can achieve transactional DDL by issuing a BEGIN
# statement manually. # statement manually.
self.con.execute("begin") self.con.execute("begin")
@ -285,11 +285,13 @@ class DMLStatementDetectionTestCase(unittest.TestCase):
def suite(): def suite():
default_suite = unittest.makeSuite(TransactionTests, "Check") loader = unittest.TestLoader()
special_command_suite = unittest.makeSuite(SpecialCommandTests, "Check") tests = [loader.loadTestsFromTestCase(t) for t in (
ddl_suite = unittest.makeSuite(TransactionalDDL, "Check") TransactionTests,
dml_suite = unittest.makeSuite(DMLStatementDetectionTestCase) SpecialCommandTests,
return unittest.TestSuite((default_suite, special_command_suite, ddl_suite, dml_suite)) TransactionalDDL,
DMLStatementDetectionTestCase)]
return unittest.TestSuite(tests)
def test(): def test():
runner = unittest.TextTestRunner() runner = unittest.TextTestRunner()

View file

@ -40,33 +40,33 @@ class SqliteTypeTests(unittest.TestCase):
self.cur.close() self.cur.close()
self.con.close() self.con.close()
def CheckString(self): def test_String(self):
self.cur.execute("insert into test(s) values (?)", ("Österreich",)) self.cur.execute("insert into test(s) values (?)", ("Österreich",))
self.cur.execute("select s from test") self.cur.execute("select s from test")
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], "Österreich") self.assertEqual(row[0], "Österreich")
def CheckSmallInt(self): def test_SmallInt(self):
self.cur.execute("insert into test(i) values (?)", (42,)) self.cur.execute("insert into test(i) values (?)", (42,))
self.cur.execute("select i from test") self.cur.execute("select i from test")
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], 42) self.assertEqual(row[0], 42)
def CheckLargeInt(self): def test_LargeInt(self):
num = 2**40 num = 2**40
self.cur.execute("insert into test(i) values (?)", (num,)) self.cur.execute("insert into test(i) values (?)", (num,))
self.cur.execute("select i from test") self.cur.execute("select i from test")
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], num) self.assertEqual(row[0], num)
def CheckFloat(self): def test_Float(self):
val = 3.14 val = 3.14
self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("insert into test(f) values (?)", (val,))
self.cur.execute("select f from test") self.cur.execute("select f from test")
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], val) self.assertEqual(row[0], val)
def CheckBlob(self): def test_Blob(self):
sample = b"Guglhupf" sample = b"Guglhupf"
val = memoryview(sample) val = memoryview(sample)
self.cur.execute("insert into test(b) values (?)", (val,)) self.cur.execute("insert into test(b) values (?)", (val,))
@ -74,7 +74,7 @@ class SqliteTypeTests(unittest.TestCase):
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], sample) self.assertEqual(row[0], sample)
def CheckUnicodeExecute(self): def test_UnicodeExecute(self):
self.cur.execute("select 'Österreich'") self.cur.execute("select 'Österreich'")
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], "Österreich") self.assertEqual(row[0], "Österreich")
@ -133,21 +133,21 @@ class DeclTypesTests(unittest.TestCase):
self.cur.close() self.cur.close()
self.con.close() self.con.close()
def CheckString(self): def test_String(self):
# default # default
self.cur.execute("insert into test(s) values (?)", ("foo",)) self.cur.execute("insert into test(s) values (?)", ("foo",))
self.cur.execute('select s as "s [WRONG]" from test') self.cur.execute('select s as "s [WRONG]" from test')
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], "foo") self.assertEqual(row[0], "foo")
def CheckSmallInt(self): def test_SmallInt(self):
# default # default
self.cur.execute("insert into test(i) values (?)", (42,)) self.cur.execute("insert into test(i) values (?)", (42,))
self.cur.execute("select i from test") self.cur.execute("select i from test")
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], 42) self.assertEqual(row[0], 42)
def CheckLargeInt(self): def test_LargeInt(self):
# default # default
num = 2**40 num = 2**40
self.cur.execute("insert into test(i) values (?)", (num,)) self.cur.execute("insert into test(i) values (?)", (num,))
@ -155,7 +155,7 @@ class DeclTypesTests(unittest.TestCase):
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], num) self.assertEqual(row[0], num)
def CheckFloat(self): def test_Float(self):
# custom # custom
val = 3.14 val = 3.14
self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("insert into test(f) values (?)", (val,))
@ -163,7 +163,7 @@ class DeclTypesTests(unittest.TestCase):
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], 47.2) self.assertEqual(row[0], 47.2)
def CheckBool(self): def test_Bool(self):
# custom # custom
self.cur.execute("insert into test(b) values (?)", (False,)) self.cur.execute("insert into test(b) values (?)", (False,))
self.cur.execute("select b from test") self.cur.execute("select b from test")
@ -176,7 +176,7 @@ class DeclTypesTests(unittest.TestCase):
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertIs(row[0], True) self.assertIs(row[0], True)
def CheckUnicode(self): def test_Unicode(self):
# default # default
val = "\xd6sterreich" val = "\xd6sterreich"
self.cur.execute("insert into test(u) values (?)", (val,)) self.cur.execute("insert into test(u) values (?)", (val,))
@ -184,14 +184,14 @@ class DeclTypesTests(unittest.TestCase):
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], val) self.assertEqual(row[0], val)
def CheckFoo(self): def test_Foo(self):
val = DeclTypesTests.Foo("bla") val = DeclTypesTests.Foo("bla")
self.cur.execute("insert into test(foo) values (?)", (val,)) self.cur.execute("insert into test(foo) values (?)", (val,))
self.cur.execute("select foo from test") self.cur.execute("select foo from test")
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], val) self.assertEqual(row[0], val)
def CheckErrorInConform(self): def test_ErrorInConform(self):
val = DeclTypesTests.BadConform(TypeError) val = DeclTypesTests.BadConform(TypeError)
with self.assertRaises(sqlite.InterfaceError): with self.assertRaises(sqlite.InterfaceError):
self.cur.execute("insert into test(bad) values (?)", (val,)) self.cur.execute("insert into test(bad) values (?)", (val,))
@ -204,19 +204,19 @@ class DeclTypesTests(unittest.TestCase):
with self.assertRaises(KeyboardInterrupt): with self.assertRaises(KeyboardInterrupt):
self.cur.execute("insert into test(bad) values (:val)", {"val": val}) self.cur.execute("insert into test(bad) values (:val)", {"val": val})
def CheckUnsupportedSeq(self): def test_UnsupportedSeq(self):
class Bar: pass class Bar: pass
val = Bar() val = Bar()
with self.assertRaises(sqlite.InterfaceError): with self.assertRaises(sqlite.InterfaceError):
self.cur.execute("insert into test(f) values (?)", (val,)) self.cur.execute("insert into test(f) values (?)", (val,))
def CheckUnsupportedDict(self): def test_UnsupportedDict(self):
class Bar: pass class Bar: pass
val = Bar() val = Bar()
with self.assertRaises(sqlite.InterfaceError): with self.assertRaises(sqlite.InterfaceError):
self.cur.execute("insert into test(f) values (:val)", {"val": val}) self.cur.execute("insert into test(f) values (:val)", {"val": val})
def CheckBlob(self): def test_Blob(self):
# default # default
sample = b"Guglhupf" sample = b"Guglhupf"
val = memoryview(sample) val = memoryview(sample)
@ -225,13 +225,13 @@ class DeclTypesTests(unittest.TestCase):
row = self.cur.fetchone() row = self.cur.fetchone()
self.assertEqual(row[0], sample) self.assertEqual(row[0], sample)
def CheckNumber1(self): def test_Number1(self):
self.cur.execute("insert into test(n1) values (5)") self.cur.execute("insert into test(n1) values (5)")
value = self.cur.execute("select n1 from test").fetchone()[0] value = self.cur.execute("select n1 from test").fetchone()[0]
# if the converter is not used, it's an int instead of a float # if the converter is not used, it's an int instead of a float
self.assertEqual(type(value), float) self.assertEqual(type(value), float)
def CheckNumber2(self): def test_Number2(self):
"""Checks whether converter names are cut off at '(' characters""" """Checks whether converter names are cut off at '(' characters"""
self.cur.execute("insert into test(n2) values (5)") self.cur.execute("insert into test(n2) values (5)")
value = self.cur.execute("select n2 from test").fetchone()[0] value = self.cur.execute("select n2 from test").fetchone()[0]
@ -257,7 +257,7 @@ class ColNamesTests(unittest.TestCase):
self.cur.close() self.cur.close()
self.con.close() self.con.close()
def CheckDeclTypeNotUsed(self): def test_DeclTypeNotUsed(self):
""" """
Assures that the declared type is not used when PARSE_DECLTYPES Assures that the declared type is not used when PARSE_DECLTYPES
is not set. is not set.
@ -267,13 +267,13 @@ class ColNamesTests(unittest.TestCase):
val = self.cur.fetchone()[0] val = self.cur.fetchone()[0]
self.assertEqual(val, "xxx") self.assertEqual(val, "xxx")
def CheckNone(self): def test_None(self):
self.cur.execute("insert into test(x) values (?)", (None,)) self.cur.execute("insert into test(x) values (?)", (None,))
self.cur.execute("select x from test") self.cur.execute("select x from test")
val = self.cur.fetchone()[0] val = self.cur.fetchone()[0]
self.assertEqual(val, None) self.assertEqual(val, None)
def CheckColName(self): def test_ColName(self):
self.cur.execute("insert into test(x) values (?)", ("xxx",)) self.cur.execute("insert into test(x) values (?)", ("xxx",))
self.cur.execute('select x as "x y [bar]" from test') self.cur.execute('select x as "x y [bar]" from test')
val = self.cur.fetchone()[0] val = self.cur.fetchone()[0]
@ -283,12 +283,12 @@ class ColNamesTests(unittest.TestCase):
# whitespace should be stripped. # whitespace should be stripped.
self.assertEqual(self.cur.description[0][0], "x y") self.assertEqual(self.cur.description[0][0], "x y")
def CheckCaseInConverterName(self): def test_CaseInConverterName(self):
self.cur.execute("select 'other' as \"x [b1b1]\"") self.cur.execute("select 'other' as \"x [b1b1]\"")
val = self.cur.fetchone()[0] val = self.cur.fetchone()[0]
self.assertEqual(val, "MARKER") self.assertEqual(val, "MARKER")
def CheckCursorDescriptionNoRow(self): def test_CursorDescriptionNoRow(self):
""" """
cursor.description should at least provide the column name(s), even if cursor.description should at least provide the column name(s), even if
no row returned. no row returned.
@ -296,7 +296,7 @@ class ColNamesTests(unittest.TestCase):
self.cur.execute("select * from test where 0 = 1") self.cur.execute("select * from test where 0 = 1")
self.assertEqual(self.cur.description[0][0], "x") self.assertEqual(self.cur.description[0][0], "x")
def CheckCursorDescriptionInsert(self): def test_CursorDescriptionInsert(self):
self.cur.execute("insert into test values (1)") self.cur.execute("insert into test values (1)")
self.assertIsNone(self.cur.description) self.assertIsNone(self.cur.description)
@ -313,19 +313,19 @@ class CommonTableExpressionTests(unittest.TestCase):
self.cur.close() self.cur.close()
self.con.close() self.con.close()
def CheckCursorDescriptionCTESimple(self): def test_CursorDescriptionCTESimple(self):
self.cur.execute("with one as (select 1) select * from one") self.cur.execute("with one as (select 1) select * from one")
self.assertIsNotNone(self.cur.description) self.assertIsNotNone(self.cur.description)
self.assertEqual(self.cur.description[0][0], "1") self.assertEqual(self.cur.description[0][0], "1")
def CheckCursorDescriptionCTESMultipleColumns(self): def test_CursorDescriptionCTESMultipleColumns(self):
self.cur.execute("insert into test values(1)") self.cur.execute("insert into test values(1)")
self.cur.execute("insert into test values(2)") self.cur.execute("insert into test values(2)")
self.cur.execute("with testCTE as (select * from test) select * from testCTE") self.cur.execute("with testCTE as (select * from test) select * from testCTE")
self.assertIsNotNone(self.cur.description) self.assertIsNotNone(self.cur.description)
self.assertEqual(self.cur.description[0][0], "x") self.assertEqual(self.cur.description[0][0], "x")
def CheckCursorDescriptionCTE(self): def test_CursorDescriptionCTE(self):
self.cur.execute("insert into test values (1)") self.cur.execute("insert into test values (1)")
self.cur.execute("with bar as (select * from test) select * from test where x = 1") self.cur.execute("with bar as (select * from test) select * from test where x = 1")
self.assertIsNotNone(self.cur.description) self.assertIsNotNone(self.cur.description)
@ -354,7 +354,7 @@ class ObjectAdaptationTests(unittest.TestCase):
self.cur.close() self.cur.close()
self.con.close() self.con.close()
def CheckCasterIsUsed(self): def test_CasterIsUsed(self):
self.cur.execute("select ?", (4,)) self.cur.execute("select ?", (4,))
val = self.cur.fetchone()[0] val = self.cur.fetchone()[0]
self.assertEqual(type(val), float) self.assertEqual(type(val), float)
@ -372,7 +372,7 @@ class BinaryConverterTests(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.con.close() self.con.close()
def CheckBinaryInputForConverter(self): def test_BinaryInputForConverter(self):
testdata = b"abcdefg" * 10 testdata = b"abcdefg" * 10
result = self.con.execute('select ? as "x [bin]"', (memoryview(zlib.compress(testdata)),)).fetchone()[0] result = self.con.execute('select ? as "x [bin]"', (memoryview(zlib.compress(testdata)),)).fetchone()[0]
self.assertEqual(testdata, result) self.assertEqual(testdata, result)
@ -387,14 +387,14 @@ class DateTimeTests(unittest.TestCase):
self.cur.close() self.cur.close()
self.con.close() self.con.close()
def CheckSqliteDate(self): def test_SqliteDate(self):
d = sqlite.Date(2004, 2, 14) d = sqlite.Date(2004, 2, 14)
self.cur.execute("insert into test(d) values (?)", (d,)) self.cur.execute("insert into test(d) values (?)", (d,))
self.cur.execute("select d from test") self.cur.execute("select d from test")
d2 = self.cur.fetchone()[0] d2 = self.cur.fetchone()[0]
self.assertEqual(d, d2) self.assertEqual(d, d2)
def CheckSqliteTimestamp(self): def test_SqliteTimestamp(self):
ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0) ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0)
self.cur.execute("insert into test(ts) values (?)", (ts,)) self.cur.execute("insert into test(ts) values (?)", (ts,))
self.cur.execute("select ts from test") self.cur.execute("select ts from test")
@ -403,7 +403,7 @@ class DateTimeTests(unittest.TestCase):
@unittest.skipIf(sqlite.sqlite_version_info < (3, 1), @unittest.skipIf(sqlite.sqlite_version_info < (3, 1),
'the date functions are available on 3.1 or later') 'the date functions are available on 3.1 or later')
def CheckSqlTimestamp(self): def test_SqlTimestamp(self):
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
self.cur.execute("insert into test(ts) values (current_timestamp)") self.cur.execute("insert into test(ts) values (current_timestamp)")
self.cur.execute("select ts from test") self.cur.execute("select ts from test")
@ -411,14 +411,14 @@ class DateTimeTests(unittest.TestCase):
self.assertEqual(type(ts), datetime.datetime) self.assertEqual(type(ts), datetime.datetime)
self.assertEqual(ts.year, now.year) self.assertEqual(ts.year, now.year)
def CheckDateTimeSubSeconds(self): def test_DateTimeSubSeconds(self):
ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0, 500000) ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0, 500000)
self.cur.execute("insert into test(ts) values (?)", (ts,)) self.cur.execute("insert into test(ts) values (?)", (ts,))
self.cur.execute("select ts from test") self.cur.execute("select ts from test")
ts2 = self.cur.fetchone()[0] ts2 = self.cur.fetchone()[0]
self.assertEqual(ts, ts2) self.assertEqual(ts, ts2)
def CheckDateTimeSubSecondsFloatingPoint(self): def test_DateTimeSubSecondsFloatingPoint(self):
ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0, 510241) ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0, 510241)
self.cur.execute("insert into test(ts) values (?)", (ts,)) self.cur.execute("insert into test(ts) values (?)", (ts,))
self.cur.execute("select ts from test") self.cur.execute("select ts from test")
@ -426,14 +426,16 @@ class DateTimeTests(unittest.TestCase):
self.assertEqual(ts, ts2) self.assertEqual(ts, ts2)
def suite(): def suite():
sqlite_type_suite = unittest.makeSuite(SqliteTypeTests, "Check") loader = unittest.TestLoader()
decltypes_type_suite = unittest.makeSuite(DeclTypesTests, "Check") tests = [loader.loadTestsFromTestCase(t) for t in (
colnames_type_suite = unittest.makeSuite(ColNamesTests, "Check") SqliteTypeTests,
adaptation_suite = unittest.makeSuite(ObjectAdaptationTests, "Check") DeclTypesTests,
bin_suite = unittest.makeSuite(BinaryConverterTests, "Check") ColNamesTests,
date_suite = unittest.makeSuite(DateTimeTests, "Check") ObjectAdaptationTests,
cte_suite = unittest.makeSuite(CommonTableExpressionTests, "Check") BinaryConverterTests,
return unittest.TestSuite((sqlite_type_suite, decltypes_type_suite, colnames_type_suite, adaptation_suite, bin_suite, date_suite, cte_suite)) DateTimeTests,
CommonTableExpressionTests)]
return unittest.TestSuite(tests)
def test(): def test():
runner = unittest.TextTestRunner() runner = unittest.TextTestRunner()

View file

@ -162,11 +162,11 @@ class FunctionTests(unittest.TestCase):
def tearDown(self): def tearDown(self):
self.con.close() self.con.close()
def CheckFuncErrorOnCreate(self): def test_FuncErrorOnCreate(self):
with self.assertRaises(sqlite.OperationalError): with self.assertRaises(sqlite.OperationalError):
self.con.create_function("bla", -100, lambda x: 2*x) self.con.create_function("bla", -100, lambda x: 2*x)
def CheckFuncRefCount(self): def test_FuncRefCount(self):
def getfunc(): def getfunc():
def f(): def f():
return 1 return 1
@ -178,28 +178,28 @@ class FunctionTests(unittest.TestCase):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select reftest()") cur.execute("select reftest()")
def CheckFuncReturnText(self): def test_FuncReturnText(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select returntext()") cur.execute("select returntext()")
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(type(val), str) self.assertEqual(type(val), str)
self.assertEqual(val, "foo") self.assertEqual(val, "foo")
def CheckFuncReturnUnicode(self): def test_FuncReturnUnicode(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select returnunicode()") cur.execute("select returnunicode()")
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(type(val), str) self.assertEqual(type(val), str)
self.assertEqual(val, "bar") self.assertEqual(val, "bar")
def CheckFuncReturnInt(self): def test_FuncReturnInt(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select returnint()") cur.execute("select returnint()")
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(type(val), int) self.assertEqual(type(val), int)
self.assertEqual(val, 42) self.assertEqual(val, 42)
def CheckFuncReturnFloat(self): def test_FuncReturnFloat(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select returnfloat()") cur.execute("select returnfloat()")
val = cur.fetchone()[0] val = cur.fetchone()[0]
@ -207,90 +207,90 @@ class FunctionTests(unittest.TestCase):
if val < 3.139 or val > 3.141: if val < 3.139 or val > 3.141:
self.fail("wrong value") self.fail("wrong value")
def CheckFuncReturnNull(self): def test_FuncReturnNull(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select returnnull()") cur.execute("select returnnull()")
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(type(val), type(None)) self.assertEqual(type(val), type(None))
self.assertEqual(val, None) self.assertEqual(val, None)
def CheckFuncReturnBlob(self): def test_FuncReturnBlob(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select returnblob()") cur.execute("select returnblob()")
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(type(val), bytes) self.assertEqual(type(val), bytes)
self.assertEqual(val, b"blob") self.assertEqual(val, b"blob")
def CheckFuncReturnLongLong(self): def test_FuncReturnLongLong(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select returnlonglong()") cur.execute("select returnlonglong()")
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1<<31) self.assertEqual(val, 1<<31)
def CheckFuncException(self): def test_FuncException(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
cur.execute("select raiseexception()") cur.execute("select raiseexception()")
cur.fetchone() cur.fetchone()
self.assertEqual(str(cm.exception), 'user-defined function raised exception') self.assertEqual(str(cm.exception), 'user-defined function raised exception')
def CheckParamString(self): def test_ParamString(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select isstring(?)", ("foo",)) cur.execute("select isstring(?)", ("foo",))
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1) self.assertEqual(val, 1)
def CheckParamInt(self): def test_ParamInt(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select isint(?)", (42,)) cur.execute("select isint(?)", (42,))
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1) self.assertEqual(val, 1)
def CheckParamFloat(self): def test_ParamFloat(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select isfloat(?)", (3.14,)) cur.execute("select isfloat(?)", (3.14,))
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1) self.assertEqual(val, 1)
def CheckParamNone(self): def test_ParamNone(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select isnone(?)", (None,)) cur.execute("select isnone(?)", (None,))
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1) self.assertEqual(val, 1)
def CheckParamBlob(self): def test_ParamBlob(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select isblob(?)", (memoryview(b"blob"),)) cur.execute("select isblob(?)", (memoryview(b"blob"),))
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1) self.assertEqual(val, 1)
def CheckParamLongLong(self): def test_ParamLongLong(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select islonglong(?)", (1<<42,)) cur.execute("select islonglong(?)", (1<<42,))
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1) self.assertEqual(val, 1)
def CheckAnyArguments(self): def test_AnyArguments(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select spam(?, ?)", (1, 2)) cur.execute("select spam(?, ?)", (1, 2))
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 2) self.assertEqual(val, 2)
def CheckFuncNonDeterministic(self): def test_FuncNonDeterministic(self):
mock = unittest.mock.Mock(return_value=None) mock = unittest.mock.Mock(return_value=None)
self.con.create_function("deterministic", 0, mock, deterministic=False) self.con.create_function("deterministic", 0, mock, deterministic=False)
self.con.execute("select deterministic() = deterministic()") self.con.execute("select deterministic() = deterministic()")
self.assertEqual(mock.call_count, 2) self.assertEqual(mock.call_count, 2)
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "deterministic parameter not supported") @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "deterministic parameter not supported")
def CheckFuncDeterministic(self): def test_FuncDeterministic(self):
mock = unittest.mock.Mock(return_value=None) mock = unittest.mock.Mock(return_value=None)
self.con.create_function("deterministic", 0, mock, True) self.con.create_function("deterministic", 0, mock, True)
self.con.execute("select 1 where deterministic() AND deterministic()") self.con.execute("select 1 where deterministic() AND deterministic()")
self.assertEqual(mock.call_count, 1) self.assertEqual(mock.call_count, 1)
@unittest.skipIf(sqlite.sqlite_version_info >= (3, 8, 3), "SQLite < 3.8.3 needed") @unittest.skipIf(sqlite.sqlite_version_info >= (3, 8, 3), "SQLite < 3.8.3 needed")
def CheckFuncDeterministicNotSupported(self): def test_FuncDeterministicNotSupported(self):
with self.assertRaises(sqlite.NotSupportedError): with self.assertRaises(sqlite.NotSupportedError):
self.con.create_function("deterministic", 0, int, deterministic=True) self.con.create_function("deterministic", 0, int, deterministic=True)
@ -325,81 +325,81 @@ class AggregateTests(unittest.TestCase):
#self.con.close() #self.con.close()
pass pass
def CheckAggrErrorOnCreate(self): def test_AggrErrorOnCreate(self):
with self.assertRaises(sqlite.OperationalError): with self.assertRaises(sqlite.OperationalError):
self.con.create_function("bla", -100, AggrSum) self.con.create_function("bla", -100, AggrSum)
def CheckAggrNoStep(self): def test_AggrNoStep(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(AttributeError) as cm: with self.assertRaises(AttributeError) as cm:
cur.execute("select nostep(t) from test") cur.execute("select nostep(t) from test")
self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'") self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'")
def CheckAggrNoFinalize(self): def test_AggrNoFinalize(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
cur.execute("select nofinalize(t) from test") cur.execute("select nofinalize(t) from test")
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error") self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
def CheckAggrExceptionInInit(self): def test_AggrExceptionInInit(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
cur.execute("select excInit(t) from test") cur.execute("select excInit(t) from test")
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error") self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")
def CheckAggrExceptionInStep(self): def test_AggrExceptionInStep(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
cur.execute("select excStep(t) from test") cur.execute("select excStep(t) from test")
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error") self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")
def CheckAggrExceptionInFinalize(self): def test_AggrExceptionInFinalize(self):
cur = self.con.cursor() cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm: with self.assertRaises(sqlite.OperationalError) as cm:
cur.execute("select excFinalize(t) from test") cur.execute("select excFinalize(t) from test")
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error") self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
def CheckAggrCheckParamStr(self): def test_AggrCheckParamStr(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select checkType('str', ?)", ("foo",)) cur.execute("select checkType('str', ?)", ("foo",))
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1) self.assertEqual(val, 1)
def CheckAggrCheckParamInt(self): def test_AggrCheckParamInt(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select checkType('int', ?)", (42,)) cur.execute("select checkType('int', ?)", (42,))
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1) self.assertEqual(val, 1)
def CheckAggrCheckParamsInt(self): def test_AggrCheckParamsInt(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select checkTypes('int', ?, ?)", (42, 24)) cur.execute("select checkTypes('int', ?, ?)", (42, 24))
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 2) self.assertEqual(val, 2)
def CheckAggrCheckParamFloat(self): def test_AggrCheckParamFloat(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select checkType('float', ?)", (3.14,)) cur.execute("select checkType('float', ?)", (3.14,))
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1) self.assertEqual(val, 1)
def CheckAggrCheckParamNone(self): def test_AggrCheckParamNone(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select checkType('None', ?)", (None,)) cur.execute("select checkType('None', ?)", (None,))
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1) self.assertEqual(val, 1)
def CheckAggrCheckParamBlob(self): def test_AggrCheckParamBlob(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),)) cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 1) self.assertEqual(val, 1)
def CheckAggrCheckAggrSum(self): def test_AggrCheckAggrSum(self):
cur = self.con.cursor() cur = self.con.cursor()
cur.execute("delete from test") cur.execute("delete from test")
cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)]) cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
@ -407,7 +407,7 @@ class AggregateTests(unittest.TestCase):
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertEqual(val, 60) self.assertEqual(val, 60)
def CheckAggrNoMatch(self): def test_AggrNoMatch(self):
cur = self.con.execute('select mysum(i) from (select 1 as i) where i == 0') cur = self.con.execute('select mysum(i) from (select 1 as i) where i == 0')
val = cur.fetchone()[0] val = cur.fetchone()[0]
self.assertIsNone(val) self.assertIsNone(val)
@ -516,19 +516,16 @@ class AuthorizerLargeIntegerTests(AuthorizerTests):
def suite(): def suite():
function_suite = unittest.makeSuite(FunctionTests, "Check") loader = unittest.TestLoader()
aggregate_suite = unittest.makeSuite(AggregateTests, "Check") tests = [loader.loadTestsFromTestCase(t) for t in (
window_suite = unittest.makeSuite(WindowFunctionTests) FunctionTests,
authorizer_suite = unittest.makeSuite(AuthorizerTests) AggregateTests,
return unittest.TestSuite(( WindowFunctionTests,
function_suite, AuthorizerTests,
aggregate_suite, AuthorizerRaiseExceptionTests,
window_suite, AuthorizerIllegalTypeTests,
authorizer_suite, AuthorizerLargeIntegerTests)]
unittest.makeSuite(AuthorizerRaiseExceptionTests), return unittest.TestSuite(tests)
unittest.makeSuite(AuthorizerIllegalTypeTests),
unittest.makeSuite(AuthorizerLargeIntegerTests),
))
def test(): def test():
runner = unittest.TextTestRunner() runner = unittest.TextTestRunner()