Fix case-insensitive string comparison in sqlite3.Row indexing.
This commit is contained in:
parent
102ec435ff
commit
a6b7b19bc9
2 changed files with 49 additions and 37 deletions
61
src/row.c
61
src/row.c
|
|
@ -76,16 +76,38 @@ PyObject* pysqlite_row_item(pysqlite_Row* self, Py_ssize_t idx)
|
||||||
return item;
|
return item;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static int
|
||||||
|
equal_ignore_case(PyObject *left, PyObject *right)
|
||||||
|
{
|
||||||
|
int eq = PyObject_RichCompareBool(left, right, Py_EQ);
|
||||||
|
if (eq) {
|
||||||
|
return eq;
|
||||||
|
}
|
||||||
|
if (!PyUnicode_Check(left) || !PyUnicode_Check(right)) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
if (!PyUnicode_IS_ASCII(left) || !PyUnicode_IS_ASCII(right)) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
Py_ssize_t len = PyUnicode_GET_LENGTH(left);
|
||||||
|
if (PyUnicode_GET_LENGTH(right) != len) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
const Py_UCS1 *p1 = PyUnicode_1BYTE_DATA(left);
|
||||||
|
const Py_UCS1 *p2 = PyUnicode_1BYTE_DATA(right);
|
||||||
|
for (; len; len--, p1++, p2++) {
|
||||||
|
if (Py_TOLOWER(*p1) != Py_TOLOWER(*p2)) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
PyObject* pysqlite_row_subscript(pysqlite_Row* self, PyObject* idx)
|
PyObject* pysqlite_row_subscript(pysqlite_Row* self, PyObject* idx)
|
||||||
{
|
{
|
||||||
Py_ssize_t _idx;
|
Py_ssize_t _idx;
|
||||||
const char *key;
|
|
||||||
Py_ssize_t nitems, i;
|
Py_ssize_t nitems, i;
|
||||||
const char *compare_key;
|
|
||||||
|
|
||||||
const char *p1;
|
|
||||||
const char *p2;
|
|
||||||
|
|
||||||
PyObject* item;
|
PyObject* item;
|
||||||
|
|
||||||
if (PyLong_Check(idx)) {
|
if (PyLong_Check(idx)) {
|
||||||
|
|
@ -98,44 +120,23 @@ PyObject* pysqlite_row_subscript(pysqlite_Row* self, PyObject* idx)
|
||||||
Py_XINCREF(item);
|
Py_XINCREF(item);
|
||||||
return item;
|
return item;
|
||||||
} else if (PyUnicode_Check(idx)) {
|
} else if (PyUnicode_Check(idx)) {
|
||||||
key = PyUnicode_AsUTF8(idx);
|
|
||||||
if (key == NULL)
|
|
||||||
return NULL;
|
|
||||||
|
|
||||||
nitems = PyTuple_Size(self->description);
|
nitems = PyTuple_Size(self->description);
|
||||||
|
|
||||||
for (i = 0; i < nitems; i++) {
|
for (i = 0; i < nitems; i++) {
|
||||||
PyObject *obj;
|
PyObject *obj;
|
||||||
obj = PyTuple_GET_ITEM(self->description, i);
|
obj = PyTuple_GET_ITEM(self->description, i);
|
||||||
obj = PyTuple_GET_ITEM(obj, 0);
|
obj = PyTuple_GET_ITEM(obj, 0);
|
||||||
compare_key = PyUnicode_AsUTF8(obj);
|
int eq = equal_ignore_case(idx, obj);
|
||||||
if (!compare_key) {
|
if (eq < 0) {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
p1 = key;
|
if (eq) {
|
||||||
p2 = compare_key;
|
|
||||||
|
|
||||||
while (1) {
|
|
||||||
if ((*p1 == (char)0) || (*p2 == (char)0)) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((*p1 | 0x20) != (*p2 | 0x20)) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
p1++;
|
|
||||||
p2++;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((*p1 == (char)0) && (*p2 == (char)0)) {
|
|
||||||
/* found item */
|
/* found item */
|
||||||
item = PyTuple_GetItem(self->data, i);
|
item = PyTuple_GetItem(self->data, i);
|
||||||
Py_INCREF(item);
|
Py_INCREF(item);
|
||||||
return item;
|
return item;
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PyErr_SetString(PyExc_IndexError, "No item with that key");
|
PyErr_SetString(PyExc_IndexError, "No item with that key");
|
||||||
|
|
|
||||||
|
|
@ -98,16 +98,14 @@ class RowFactoryTests(unittest.TestCase):
|
||||||
|
|
||||||
def CheckSqliteRowIndex(self):
|
def CheckSqliteRowIndex(self):
|
||||||
self.con.row_factory = sqlite.Row
|
self.con.row_factory = sqlite.Row
|
||||||
row = self.con.execute("select 1 as a, 2 as b").fetchone()
|
row = self.con.execute("select 1 as a_1, 2 as b").fetchone()
|
||||||
self.assertIsInstance(row, sqlite.Row)
|
self.assertIsInstance(row, sqlite.Row)
|
||||||
|
|
||||||
col1, col2 = row["a"], row["b"]
|
self.assertEqual(row["a_1"], 1, "by name: wrong result for column 'a_1'")
|
||||||
self.assertEqual(col1, 1, "by name: wrong result for column 'a'")
|
self.assertEqual(row["b"], 2, "by name: wrong result for column 'b'")
|
||||||
self.assertEqual(col2, 2, "by name: wrong result for column 'a'")
|
|
||||||
|
|
||||||
col1, col2 = row["A"], row["B"]
|
self.assertEqual(row["A_1"], 1, "by name: wrong result for column 'A_1'")
|
||||||
self.assertEqual(col1, 1, "by name: wrong result for column 'A'")
|
self.assertEqual(row["B"], 2, "by name: wrong result for column 'B'")
|
||||||
self.assertEqual(col2, 2, "by name: wrong result for column 'B'")
|
|
||||||
|
|
||||||
self.assertEqual(row[0], 1, "by index: wrong result for column 0")
|
self.assertEqual(row[0], 1, "by index: wrong result for column 0")
|
||||||
self.assertEqual(row[1], 2, "by index: wrong result for column 1")
|
self.assertEqual(row[1], 2, "by index: wrong result for column 1")
|
||||||
|
|
@ -116,6 +114,10 @@ class RowFactoryTests(unittest.TestCase):
|
||||||
|
|
||||||
with self.assertRaises(IndexError):
|
with self.assertRaises(IndexError):
|
||||||
row['c']
|
row['c']
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
row['a_\x11']
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
row['a_\x7f1']
|
||||||
with self.assertRaises(IndexError):
|
with self.assertRaises(IndexError):
|
||||||
row[2]
|
row[2]
|
||||||
with self.assertRaises(IndexError):
|
with self.assertRaises(IndexError):
|
||||||
|
|
@ -123,6 +125,15 @@ class RowFactoryTests(unittest.TestCase):
|
||||||
with self.assertRaises(IndexError):
|
with self.assertRaises(IndexError):
|
||||||
row[2**1000]
|
row[2**1000]
|
||||||
|
|
||||||
|
def CheckSqliteRowIndexUnicode(self):
|
||||||
|
self.con.row_factory = sqlite.Row
|
||||||
|
row = self.con.execute("select 1 as \xff").fetchone()
|
||||||
|
self.assertEqual(row["\xff"], 1)
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
row['\u0178']
|
||||||
|
with self.assertRaises(IndexError):
|
||||||
|
row['\xdf']
|
||||||
|
|
||||||
def CheckSqliteRowSlice(self):
|
def CheckSqliteRowSlice(self):
|
||||||
# A sqlite.Row can be sliced like a list.
|
# A sqlite.Row can be sliced like a list.
|
||||||
self.con.row_factory = sqlite.Row
|
self.con.row_factory = sqlite.Row
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue