Merging upstream changes to allow compiling python3.13 wheels

This commit is contained in:
laggykiller 2025-02-09 22:52:57 +08:00
parent 70f0118baf
commit 8ea48ee2ad
No known key found for this signature in database
19 changed files with 639 additions and 478 deletions

0
tests/__init__.py Normal file
View file

39
tests/__main__.py Normal file
View file

@ -0,0 +1,39 @@
import optparse
import sys
import unittest
from tests.backup import suite as backup_suite
from tests.dbapi import suite as dbapi_suite
from tests.factory import suite as factory_suite
from tests.hooks import suite as hooks_suite
from tests.regression import suite as regression_suite
from tests.transactions import suite as transactions_suite
from tests.ttypes import suite as types_suite
from tests.userfunctions import suite as userfunctions_suite
def test(verbosity=1, failfast=False):
runner = unittest.TextTestRunner(verbosity=verbosity, failfast=failfast)
all_tests = unittest.TestSuite((
backup_suite(),
dbapi_suite(),
factory_suite(),
hooks_suite(),
regression_suite(),
transactions_suite(),
types_suite(),
userfunctions_suite()))
results = runner.run(all_tests)
return results.failures, results.errors
if __name__ == '__main__':
parser = optparse.OptionParser()
parser.add_option('-v', '--verbosity', default=1, dest='verbosity',
type='int', help='output verbosity, default=1')
parser.add_option('-f', '--failfast', action='store_true', dest='failfast')
options, args = parser.parse_args()
failures, errors = test(options.verbosity, options.failfast)
if failures or errors:
sys.exit(1)

183
tests/backup.py Normal file
View file

@ -0,0 +1,183 @@
from sqlcipher3 import dbapi2 as sqlite
import unittest
@unittest.skipIf(sqlite.sqlite_version_info < (3, 6, 11), "Backup API not supported")
class BackupTests(unittest.TestCase):
def setUp(self):
cx = self.cx = sqlite.connect(":memory:")
cx.execute('CREATE TABLE foo (key INTEGER)')
cx.executemany('INSERT INTO foo (key) VALUES (?)', [(3,), (4,)])
cx.commit()
def tearDown(self):
self.cx.close()
def verify_backup(self, bckcx):
result = bckcx.execute("SELECT key FROM foo ORDER BY key").fetchall()
self.assertEqual(result[0][0], 3)
self.assertEqual(result[1][0], 4)
def test_bad_target_none(self):
with self.assertRaises(TypeError):
self.cx.backup(None)
def test_bad_target_filename(self):
with self.assertRaises(TypeError):
self.cx.backup('some_file_name.db')
def test_bad_target_same_connection(self):
with self.assertRaises(ValueError):
self.cx.backup(self.cx)
def test_bad_target_closed_connection(self):
bck = sqlite.connect(':memory:')
bck.close()
with self.assertRaises(sqlite.ProgrammingError):
self.cx.backup(bck)
def test_bad_target_in_transaction(self):
bck = sqlite.connect(':memory:')
bck.execute('CREATE TABLE bar (key INTEGER)')
bck.executemany('INSERT INTO bar (key) VALUES (?)', [(3,), (4,)])
with self.assertRaises(sqlite.OperationalError) as cm:
self.cx.backup(bck)
if sqlite.sqlite_version_info < (3, 8, 8):
self.assertEqual(str(cm.exception), 'target is in transaction')
def test_keyword_only_args(self):
with self.assertRaises(TypeError):
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, 1)
def test_simple(self):
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck)
self.verify_backup(bck)
def test_progress(self):
journal = []
def progress(status, remaining, total):
journal.append(status)
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, pages=1, progress=progress)
self.verify_backup(bck)
self.assertEqual(len(journal), 2)
self.assertEqual(journal[0], sqlite.SQLITE_OK)
self.assertEqual(journal[1], sqlite.SQLITE_DONE)
def test_progress_all_pages_at_once_1(self):
journal = []
def progress(status, remaining, total):
journal.append(remaining)
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, progress=progress)
self.verify_backup(bck)
self.assertEqual(len(journal), 1)
self.assertEqual(journal[0], 0)
def test_progress_all_pages_at_once_2(self):
journal = []
def progress(status, remaining, total):
journal.append(remaining)
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, pages=-1, progress=progress)
self.verify_backup(bck)
self.assertEqual(len(journal), 1)
self.assertEqual(journal[0], 0)
def test_sleep(self):
with self.assertRaises(ValueError) as bm:
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, sleep=-1)
self.assertEqual(str(bm.exception), 'sleep must be greater-than or equal to zero')
with self.assertRaises(TypeError):
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, sleep=None)
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, sleep=10)
self.verify_backup(bck)
def test_non_callable_progress(self):
with self.assertRaises(TypeError) as cm:
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, pages=1, progress='bar')
self.assertEqual(str(cm.exception), 'progress argument must be a callable')
def test_modifying_progress(self):
journal = []
def progress(status, remaining, total):
if not journal:
self.cx.execute('INSERT INTO foo (key) VALUES (?)', (remaining+1000,))
self.cx.commit()
journal.append(remaining)
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, pages=1, progress=progress)
self.verify_backup(bck)
result = bck.execute("SELECT key FROM foo"
" WHERE key >= 1000"
" ORDER BY key").fetchall()
self.assertEqual(result[0][0], 1001)
self.assertEqual(len(journal), 3)
self.assertEqual(journal[0], 1)
self.assertEqual(journal[1], 1)
self.assertEqual(journal[2], 0)
def test_failing_progress(self):
def progress(status, remaining, total):
raise SystemError('nearly out of space')
with self.assertRaises(SystemError) as err:
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, progress=progress)
self.assertEqual(str(err.exception), 'nearly out of space')
def test_database_source_name(self):
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, name='main')
#with sqlite.connect(':memory:') as bck:
# self.cx.backup(bck, name='temp')
with self.assertRaises(sqlite.OperationalError) as cm:
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, name='non-existing')
self.assertIn(
str(cm.exception),
['SQL logic error', 'SQL logic error or missing database']
)
self.cx.execute("ATTACH DATABASE ':memory:' AS attached_db")
self.cx.execute('CREATE TABLE attached_db.foo (key INTEGER)')
self.cx.executemany('INSERT INTO attached_db.foo (key) VALUES (?)', [(3,), (4,)])
self.cx.commit()
with sqlite.connect(':memory:') as bck:
self.cx.backup(bck, name='attached_db')
self.verify_backup(bck)
def suite():
loader = unittest.TestLoader()
tests = [loader.loadTestsFromTestCase(t) for t in (
BackupTests,)]
return unittest.TestSuite(tests)
def test():
runner = unittest.TextTestRunner()
runner.run(suite())
if __name__ == "__main__":
test()

1224
tests/dbapi.py Normal file

File diff suppressed because it is too large Load diff

320
tests/factory.py Normal file
View file

@ -0,0 +1,320 @@
#-*- coding: iso-8859-1 -*-
# pysqlite2/test/factory.py: tests for the various factories in pysqlite
#
# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
# claim that you wrote the original software. If you use this software
# in a product, an acknowledgment in the product documentation would be
# appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
import unittest
from sqlcipher3 import dbapi2 as sqlite
from collections.abc import Sequence
class MyConnection(sqlite.Connection):
def __init__(self, *args, **kwargs):
sqlite.Connection.__init__(self, *args, **kwargs)
def dict_factory(cursor, row):
d = {}
for idx, col in enumerate(cursor.description):
d[col[0]] = row[idx]
return d
class MyCursor(sqlite.Cursor):
def __init__(self, *args, **kwargs):
sqlite.Cursor.__init__(self, *args, **kwargs)
self.row_factory = dict_factory
class ConnectionFactoryTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:", factory=MyConnection)
def tearDown(self):
self.con.close()
def test_IsInstance(self):
self.assertIsInstance(self.con, MyConnection)
class CursorFactoryTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def tearDown(self):
self.con.close()
def test_IsInstance(self):
cur = self.con.cursor()
self.assertIsInstance(cur, sqlite.Cursor)
cur = self.con.cursor(MyCursor)
self.assertIsInstance(cur, MyCursor)
cur = self.con.cursor(factory=lambda con: MyCursor(con))
self.assertIsInstance(cur, MyCursor)
def test_InvalidFactory(self):
# not a callable at all
self.assertRaises(TypeError, self.con.cursor, None)
# invalid callable with not exact one argument
self.assertRaises(TypeError, self.con.cursor, lambda: None)
# invalid callable returning non-cursor
self.assertRaises(TypeError, self.con.cursor, lambda con: None)
class RowFactoryTestsBackwardsCompat(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def test_IsProducedByFactory(self):
cur = self.con.cursor(factory=MyCursor)
cur.execute("select 4+5 as foo")
row = cur.fetchone()
self.assertIsInstance(row, dict)
cur.close()
def tearDown(self):
self.con.close()
class RowFactoryTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def test_CustomFactory(self):
self.con.row_factory = lambda cur, row: list(row)
row = self.con.execute("select 1, 2").fetchone()
self.assertIsInstance(row, list)
def test_SqliteRowIndex(self):
self.con.row_factory = sqlite.Row
row = self.con.execute("select 1 as a_1, 2 as b").fetchone()
self.assertIsInstance(row, sqlite.Row)
self.assertEqual(row["a_1"], 1, "by name: wrong result for column 'a_1'")
self.assertEqual(row["b"], 2, "by name: wrong result for column 'b'")
self.assertEqual(row["A_1"], 1, "by name: wrong result for column 'A_1'")
self.assertEqual(row["B"], 2, "by name: wrong result for column 'B'")
self.assertEqual(row[0], 1, "by index: wrong result for column 0")
self.assertEqual(row[1], 2, "by index: wrong result for column 1")
self.assertEqual(row[-1], 2, "by index: wrong result for column -1")
self.assertEqual(row[-2], 1, "by index: wrong result for column -2")
with self.assertRaises(IndexError):
row['c']
with self.assertRaises(IndexError):
row['a_\x11']
with self.assertRaises(IndexError):
row['a_\x7f1']
with self.assertRaises(IndexError):
row[2]
with self.assertRaises(IndexError):
row[-3]
with self.assertRaises(IndexError):
row[2**1000]
def test_SqliteRowIndexUnicode(self):
self.con.row_factory = sqlite.Row
row = self.con.execute("select 1 as \xff").fetchone()
self.assertEqual(row["\xff"], 1)
with self.assertRaises(IndexError):
row['\u0178']
with self.assertRaises(IndexError):
row['\xdf']
def test_SqliteRowSlice(self):
# A sqlite.Row can be sliced like a list.
self.con.row_factory = sqlite.Row
row = self.con.execute("select 1, 2, 3, 4").fetchone()
self.assertEqual(row[0:0], ())
self.assertEqual(row[0:1], (1,))
self.assertEqual(row[1:3], (2, 3))
self.assertEqual(row[3:1], ())
# Explicit bounds are optional.
self.assertEqual(row[1:], (2, 3, 4))
self.assertEqual(row[:3], (1, 2, 3))
# Slices can use negative indices.
self.assertEqual(row[-2:-1], (3,))
self.assertEqual(row[-2:], (3, 4))
# Slicing supports steps.
self.assertEqual(row[0:4:2], (1, 3))
self.assertEqual(row[3:0:-2], (4, 2))
def test_SqliteRowIter(self):
"""Checks if the row object is iterable"""
self.con.row_factory = sqlite.Row
row = self.con.execute("select 1 as a, 2 as b").fetchone()
for col in row:
pass
def test_SqliteRowAsTuple(self):
"""Checks if the row object can be converted to a tuple"""
self.con.row_factory = sqlite.Row
row = self.con.execute("select 1 as a, 2 as b").fetchone()
t = tuple(row)
self.assertEqual(t, (row['a'], row['b']))
def test_SqliteRowAsDict(self):
"""Checks if the row object can be correctly converted to a dictionary"""
self.con.row_factory = sqlite.Row
row = self.con.execute("select 1 as a, 2 as b").fetchone()
d = dict(row)
self.assertEqual(d["a"], row["a"])
self.assertEqual(d["b"], row["b"])
def test_SqliteRowHashCmp(self):
"""Checks if the row object compares and hashes correctly"""
self.con.row_factory = sqlite.Row
row_1 = self.con.execute("select 1 as a, 2 as b").fetchone()
row_2 = self.con.execute("select 1 as a, 2 as b").fetchone()
row_3 = self.con.execute("select 1 as a, 3 as b").fetchone()
row_4 = self.con.execute("select 1 as b, 2 as a").fetchone()
row_5 = self.con.execute("select 2 as b, 1 as a").fetchone()
self.assertTrue(row_1 == row_1)
self.assertTrue(row_1 == row_2)
self.assertFalse(row_1 == row_3)
self.assertFalse(row_1 == row_4)
self.assertFalse(row_1 == row_5)
self.assertFalse(row_1 == object())
self.assertFalse(row_1 != row_1)
self.assertFalse(row_1 != row_2)
self.assertTrue(row_1 != row_3)
self.assertTrue(row_1 != row_4)
self.assertTrue(row_1 != row_5)
self.assertTrue(row_1 != object())
with self.assertRaises(TypeError):
row_1 > row_2
with self.assertRaises(TypeError):
row_1 < row_2
with self.assertRaises(TypeError):
row_1 >= row_2
with self.assertRaises(TypeError):
row_1 <= row_2
self.assertEqual(hash(row_1), hash(row_2))
def test_SqliteRowAsSequence(self):
""" Checks if the row object can act like a sequence """
self.con.row_factory = sqlite.Row
row = self.con.execute("select 1 as a, 2 as b").fetchone()
as_tuple = tuple(row)
self.assertEqual(list(reversed(row)), list(reversed(as_tuple)))
self.assertIsInstance(row, Sequence)
def test_FakeCursorClass(self):
# Issue #24257: Incorrect use of PyObject_IsInstance() caused
# segmentation fault.
# Issue #27861: Also applies for cursor factory.
class FakeCursor(str):
__class__ = sqlite.Cursor
self.con.row_factory = sqlite.Row
self.assertRaises(TypeError, self.con.cursor, FakeCursor)
self.assertRaises(TypeError, sqlite.Row, FakeCursor(), ())
def tearDown(self):
self.con.close()
class TextFactoryTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def test_Unicode(self):
austria = "Österreich"
row = self.con.execute("select ?", (austria,)).fetchone()
self.assertEqual(type(row[0]), str, "type of row[0] must be unicode")
def test_String(self):
self.con.text_factory = bytes
austria = "Österreich"
row = self.con.execute("select ?", (austria,)).fetchone()
self.assertEqual(type(row[0]), bytes, "type of row[0] must be bytes")
self.assertEqual(row[0], austria.encode("utf-8"), "column must equal original data in UTF-8")
def test_Custom(self):
self.con.text_factory = lambda x: str(x, "utf-8", "ignore")
austria = "Österreich"
row = self.con.execute("select ?", (austria,)).fetchone()
self.assertEqual(type(row[0]), str, "type of row[0] must be unicode")
self.assertTrue(row[0].endswith("reich"), "column must contain original data")
def test_OptimizedUnicode(self):
# In py3k, str objects are always returned when text_factory
# is OptimizedUnicode
self.con.text_factory = sqlite.OptimizedUnicode
austria = "Österreich"
germany = "Deutchland"
a_row = self.con.execute("select ?", (austria,)).fetchone()
d_row = self.con.execute("select ?", (germany,)).fetchone()
self.assertEqual(type(a_row[0]), str, "type of non-ASCII row must be str")
self.assertEqual(type(d_row[0]), str, "type of ASCII-only row must be str")
def tearDown(self):
self.con.close()
class TextFactoryTestsWithEmbeddedZeroBytes(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.con.execute("create table test (value text)")
self.con.execute("insert into test (value) values (?)", ("a\x00b",))
def test_String(self):
# text_factory defaults to str
row = self.con.execute("select value from test").fetchone()
self.assertIs(type(row[0]), str)
self.assertEqual(row[0], "a\x00b")
def test_Bytes(self):
self.con.text_factory = bytes
row = self.con.execute("select value from test").fetchone()
self.assertIs(type(row[0]), bytes)
self.assertEqual(row[0], b"a\x00b")
def test_Bytearray(self):
self.con.text_factory = bytearray
row = self.con.execute("select value from test").fetchone()
self.assertIs(type(row[0]), bytearray)
self.assertEqual(row[0], b"a\x00b")
def test_Custom(self):
# A custom factory should receive a bytes argument
self.con.text_factory = lambda x: x
row = self.con.execute("select value from test").fetchone()
self.assertIs(type(row[0]), bytes)
self.assertEqual(row[0], b"a\x00b")
def tearDown(self):
self.con.close()
def suite():
loader = unittest.TestLoader()
tests = [loader.loadTestsFromTestCase(t) for t in (
ConnectionFactoryTests,
CursorFactoryTests,
RowFactoryTestsBackwardsCompat,
RowFactoryTests,
TextFactoryTests,
TextFactoryTestsWithEmbeddedZeroBytes)]
return unittest.TestSuite(tests)
def test():
runner = unittest.TextTestRunner()
runner.run(suite())
if __name__ == "__main__":
test()

330
tests/hooks.py Normal file
View file

@ -0,0 +1,330 @@
#-*- coding: iso-8859-1 -*-
# pysqlite2/test/hooks.py: tests for various SQLite-specific hooks
#
# Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
# claim that you wrote the original software. If you use this software
# in a product, an acknowledgment in the product documentation would be
# appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
import os
import unittest
from sqlcipher3 import dbapi2 as sqlite
class CollationTests(unittest.TestCase):
def test_CreateCollationNotString(self):
con = sqlite.connect(":memory:")
with self.assertRaises(TypeError):
con.create_collation(None, lambda x, y: (x > y) - (x < y))
def test_CreateCollationNotCallable(self):
con = sqlite.connect(":memory:")
with self.assertRaises(TypeError) as cm:
con.create_collation("X", 42)
self.assertEqual(str(cm.exception), 'parameter must be callable')
def test_CreateCollationNotAscii(self):
con = sqlite.connect(":memory:")
con.create_collation("collä", lambda x, y: (x > y) - (x < y))
def test_CreateCollationBadUpper(self):
class BadUpperStr(str):
def upper(self):
return None
con = sqlite.connect(":memory:")
mycoll = lambda x, y: -((x > y) - (x < y))
con.create_collation(BadUpperStr("mycoll"), mycoll)
result = con.execute("""
select x from (
select 'a' as x
union
select 'b' as x
) order by x collate mycoll
""").fetchall()
self.assertEqual(result[0][0], 'b')
self.assertEqual(result[1][0], 'a')
@unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 1),
'old SQLite versions crash on this test')
def test_CollationIsUsed(self):
def mycoll(x, y):
# reverse order
return -((x > y) - (x < y))
con = sqlite.connect(":memory:")
con.create_collation("mycoll", mycoll)
sql = """
select x from (
select 'a' as x
union
select 'b' as x
union
select 'c' as x
) order by x collate mycoll
"""
result = con.execute(sql).fetchall()
self.assertEqual(result, [('c',), ('b',), ('a',)],
msg='the expected order was not returned')
con.create_collation("mycoll", None)
with self.assertRaises(sqlite.OperationalError) as cm:
result = con.execute(sql).fetchall()
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
def test_CollationReturnsLargeInteger(self):
def mycoll(x, y):
# reverse order
return -((x > y) - (x < y)) * 2**32
con = sqlite.connect(":memory:")
con.create_collation("mycoll", mycoll)
sql = """
select x from (
select 'a' as x
union
select 'b' as x
union
select 'c' as x
) order by x collate mycoll
"""
result = con.execute(sql).fetchall()
self.assertEqual(result, [('c',), ('b',), ('a',)],
msg="the expected order was not returned")
def test_CollationRegisterTwice(self):
"""
Register two different collation functions under the same name.
Verify that the last one is actually used.
"""
con = sqlite.connect(":memory:")
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
result = con.execute("""
select x from (select 'a' as x union select 'b' as x) order by x collate mycoll
""").fetchall()
self.assertEqual(result[0][0], 'b')
self.assertEqual(result[1][0], 'a')
def test_DeregisterCollation(self):
"""
Register a collation, then deregister it. Make sure an error is raised if we try
to use it.
"""
con = sqlite.connect(":memory:")
con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
con.create_collation("mycoll", None)
with self.assertRaises(sqlite.OperationalError) as cm:
con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
class ProgressTests(unittest.TestCase):
def test_ProgressHandlerUsed(self):
"""
Test that the progress handler is invoked once it is set.
"""
con = sqlite.connect(":memory:")
progress_calls = []
def progress():
progress_calls.append(None)
return 0
con.set_progress_handler(progress, 1)
con.execute("""
create table foo(a, b)
""")
self.assertTrue(progress_calls)
def test_OpcodeCount(self):
"""
Test that the opcode argument is respected.
"""
con = sqlite.connect(":memory:")
progress_calls = []
def progress():
progress_calls.append(None)
return 0
con.set_progress_handler(progress, 1)
curs = con.cursor()
curs.execute("""
create table foo (a, b)
""")
first_count = len(progress_calls)
progress_calls = []
con.set_progress_handler(progress, 2)
curs.execute("""
create table bar (a, b)
""")
second_count = len(progress_calls)
self.assertGreaterEqual(first_count, second_count)
def test_CancelOperation(self):
"""
Test that returning a non-zero value stops the operation in progress.
"""
con = sqlite.connect(":memory:")
def progress():
return 1
con.set_progress_handler(progress, 1)
curs = con.cursor()
self.assertRaises(
sqlite.OperationalError,
curs.execute,
"create table bar (a, b)")
def test_ClearHandler(self):
"""
Test that setting the progress handler to None clears the previously set handler.
"""
con = sqlite.connect(":memory:")
action = 0
def progress():
nonlocal action
action = 1
return 0
con.set_progress_handler(progress, 1)
con.set_progress_handler(None, 1)
con.execute("select 1 union select 2 union select 3").fetchall()
self.assertEqual(action, 0, "progress handler was not cleared")
class TraceCallbackTests(unittest.TestCase):
def test_TraceCallbackUsed(self):
"""
Test that the trace callback is invoked once it is set.
"""
con = sqlite.connect(":memory:")
traced_statements = []
def trace(statement):
traced_statements.append(statement)
con.set_trace_callback(trace)
con.execute("create table foo(a, b)")
self.assertTrue(traced_statements)
self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
def test_TraceCallbackError(self):
"""
Test behavior when exception raised in trace callback.
"""
con = sqlite.connect(":memory:")
def trace(statement):
raise Exception('uh-oh')
con.set_trace_callback(trace)
con.execute("create table foo(a, b)")
con.set_trace_callback(None)
def test_ClearTraceCallback(self):
"""
Test that setting the trace callback to None clears the previously set callback.
"""
con = sqlite.connect(":memory:")
traced_statements = []
def trace(statement):
traced_statements.append(statement)
con.set_trace_callback(trace)
con.set_trace_callback(None)
con.execute("create table foo(a, b)")
self.assertFalse(traced_statements, "trace callback was not cleared")
def test_UnicodeContent(self):
"""
Test that the statement can contain unicode literals.
"""
unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
con = sqlite.connect(":memory:")
traced_statements = []
def trace(statement):
traced_statements.append(statement)
con.set_trace_callback(trace)
con.execute("create table foo(x)")
# Can't execute bound parameters as their values don't appear
# in traced statements before SQLite 3.6.21
# (cf. http://www.sqlite.org/draft/releaselog/3_6_21.html)
con.execute('insert into foo(x) values ("%s")' % unicode_value)
con.commit()
self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
"Unicode data %s garbled in trace callback: %s"
% (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))
def test_TraceCallbackContent(self):
# set_trace_callback() shouldn't produce duplicate content (bpo-26187)
traced_statements = []
def trace(statement):
traced_statements.append(statement)
queries = ["create table foo(x)",
"insert into foo(x) values(1)"]
self.addCleanup(os.unlink, 'tracecb.db')
con1 = sqlite.connect('tracecb.db', isolation_level=None)
con2 = sqlite.connect('tracecb.db')
con1.set_trace_callback(trace)
cur = con1.cursor()
cur.execute(queries[0])
con2.execute("create table bar(x)")
cur.execute(queries[1])
self.assertEqual(traced_statements, queries)
class TestBusyHandlerTimeout(unittest.TestCase):
def test_busy_handler(self):
accum = []
def custom_handler(n):
accum.append(n)
return 0 if n == 3 else 1
self.addCleanup(os.unlink, 'busy.db')
conn1 = sqlite.connect('busy.db')
conn2 = sqlite.connect('busy.db')
conn2.set_busy_handler(custom_handler)
conn1.execute('begin exclusive')
with self.assertRaises(sqlite.OperationalError):
conn2.execute('create table test(id)')
self.assertEqual(accum, [0, 1, 2, 3])
accum.clear()
conn2.set_busy_handler(None)
with self.assertRaises(sqlite.OperationalError):
conn2.execute('create table test(id)')
self.assertEqual(accum, [])
conn2.set_busy_handler(custom_handler)
with self.assertRaises(sqlite.OperationalError):
conn2.execute('create table test(id)')
self.assertEqual(accum, [0, 1, 2, 3])
accum.clear()
conn2.set_busy_timeout(0.01) # Clears busy handler.
with self.assertRaises(sqlite.OperationalError):
conn2.execute('create table test(id)')
self.assertEqual(accum, [])
def suite():
loader = unittest.TestLoader()
tests = [loader.loadTestsFromTestCase(t) for t in (
CollationTests,
ProgressTests,
TraceCallbackTests,
TestBusyHandlerTimeout)]
return unittest.TestSuite(tests)
def test():
runner = unittest.TextTestRunner()
runner.run(suite())
if __name__ == "__main__":
test()

413
tests/regression.py Normal file
View file

@ -0,0 +1,413 @@
#-*- coding: iso-8859-1 -*-
# pysqlite2/test/regression.py: pysqlite regression tests
#
# Copyright (C) 2006-2010 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
# claim that you wrote the original software. If you use this software
# in a product, an acknowledgment in the product documentation would be
# appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
import datetime
import functools
import unittest
from sqlcipher3 import dbapi2 as sqlite
import weakref
#from test import support
class RegressionTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def tearDown(self):
self.con.close()
def test_PragmaUserVersion(self):
# This used to crash pysqlite because this pragma command returns NULL for the column name
cur = self.con.cursor()
cur.execute("pragma user_version")
def test_PragmaSchemaVersion(self):
# This still crashed pysqlite <= 2.2.1
con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES)
try:
cur = self.con.cursor()
cur.execute("pragma schema_version")
finally:
cur.close()
con.close()
def test_StatementReset(self):
# pysqlite 2.1.0 to 2.2.0 have the problem that not all statements are
# reset before a rollback, but only those that are still in the
# statement cache. The others are not accessible from the connection object.
con = sqlite.connect(":memory:", cached_statements=5)
cursors = [con.cursor() for x in range(5)]
cursors[0].execute("create table test(x)")
for i in range(10):
cursors[0].executemany("insert into test(x) values (?)", [(x,) for x in range(10)])
for i in range(5):
cursors[i].execute(" " * i + "select x from test")
con.rollback()
def test_ColumnNameWithSpaces(self):
cur = self.con.cursor()
cur.execute('select 1 as "foo bar [datetime]"')
self.assertEqual(cur.description[0][0], "foo bar [datetime]")
cur.execute('select 1 as "foo baz"')
self.assertEqual(cur.description[0][0], "foo baz")
def test_StatementFinalizationOnCloseDb(self):
# pysqlite versions <= 2.3.3 only finalized statements in the statement
# cache when closing the database. statements that were still
# referenced in cursors weren't closed and could provoke "
# "OperationalError: Unable to close due to unfinalised statements".
con = sqlite.connect(":memory:")
cursors = []
# default statement cache size is 100
for i in range(105):
cur = con.cursor()
cursors.append(cur)
cur.execute("select 1 x union select " + str(i))
con.close()
@unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 2), 'needs sqlite 3.2.2 or newer')
def test_OnConflictRollback(self):
con = sqlite.connect(":memory:")
con.execute("create table foo(x, unique(x) on conflict rollback)")
con.execute("insert into foo(x) values (1)")
try:
con.execute("insert into foo(x) values (1)")
except sqlite.DatabaseError:
pass
con.execute("insert into foo(x) values (2)")
try:
con.commit()
except sqlite.OperationalError:
self.fail("pysqlite knew nothing about the implicit ROLLBACK")
def test_WorkaroundForBuggySqliteTransferBindings(self):
"""
pysqlite would crash with older SQLite versions unless
a workaround is implemented.
"""
self.con.execute("create table foo(bar)")
self.con.execute("drop table foo")
self.con.execute("create table foo(bar)")
def test_EmptyStatement(self):
"""
pysqlite used to segfault with SQLite versions 3.5.x. These return NULL
for "no-operation" statements
"""
self.con.execute("")
def test_TypeMapUsage(self):
"""
pysqlite until 2.4.1 did not rebuild the row_cast_map when recompiling
a statement. This test exhibits the problem.
"""
SELECT = "select * from foo"
con = sqlite.connect(":memory:",detect_types=sqlite.PARSE_DECLTYPES)
con.execute("create table foo(bar timestamp)")
con.execute("insert into foo(bar) values (?)", (datetime.datetime.now(),))
con.execute(SELECT)
con.execute("drop table foo")
con.execute("create table foo(bar integer)")
con.execute("insert into foo(bar) values (5)")
con.execute(SELECT)
def test_ErrorMsgDecodeError(self):
# When porting the module to Python 3.0, the error message about
# decoding errors disappeared. This verifies they're back again.
with self.assertRaises(sqlite.OperationalError) as cm:
self.con.execute("select 'xxx' || ? || 'yyy' colname",
(bytes(bytearray([250])),)).fetchone()
msg = "Could not decode to UTF-8 column 'colname' with text 'xxx"
self.assertIn(msg, str(cm.exception))
def test_RegisterAdapter(self):
"""
See issue 3312.
"""
self.assertRaises(TypeError, sqlite.register_adapter, {}, None)
def test_SetIsolationLevel(self):
# See issue 27881.
class CustomStr(str):
def upper(self):
return None
def __del__(self):
con.isolation_level = ""
con = sqlite.connect(":memory:")
con.isolation_level = None
for level in "", "DEFERRED", "IMMEDIATE", "EXCLUSIVE":
with self.subTest(level=level):
con.isolation_level = level
con.isolation_level = level.lower()
con.isolation_level = level.capitalize()
con.isolation_level = CustomStr(level)
# setting isolation_level failure should not alter previous state
con.isolation_level = None
con.isolation_level = "DEFERRED"
pairs = [
(1, TypeError), (b'', TypeError), ("abc", ValueError),
("\xe9", ValueError),
]
for value, exc in pairs:
with self.subTest(level=value):
with self.assertRaises(exc):
con.isolation_level = value
self.assertEqual(con.isolation_level, "DEFERRED")
def test_CursorConstructorCallCheck(self):
"""
Verifies that cursor methods check whether base class __init__ was
called.
"""
class Cursor(sqlite.Cursor):
def __init__(self, con):
pass
con = sqlite.connect(":memory:")
cur = Cursor(con)
with self.assertRaises(sqlite.ProgrammingError):
cur.execute("select 4+5").fetchall()
with self.assertRaisesRegex(sqlite.ProgrammingError,
r'^Base Cursor\.__init__ not called\.$'):
cur.close()
def test_StrSubclass(self):
"""
The Python 3.0 port of the module didn't cope with values of subclasses of str.
"""
class MyStr(str): pass
self.con.execute("select ?", (MyStr("abc"),))
def test_ConnectionConstructorCallCheck(self):
"""
Verifies that connection methods check whether base class __init__ was
called.
"""
class Connection(sqlite.Connection):
def __init__(self, name):
pass
con = Connection(":memory:")
with self.assertRaises(sqlite.ProgrammingError):
cur = con.cursor()
def test_CursorRegistration(self):
"""
Verifies that subclassed cursor classes are correctly registered with
the connection object, too. (fetch-across-rollback problem)
"""
class Connection(sqlite.Connection):
def cursor(self):
return Cursor(self)
class Cursor(sqlite.Cursor):
def __init__(self, con):
sqlite.Cursor.__init__(self, con)
con = Connection(":memory:")
cur = con.cursor()
cur.execute("create table foo(x)")
cur.executemany("insert into foo(x) values (?)", [(3,), (4,), (5,)])
cur.execute("select x from foo")
con.rollback()
with self.assertRaises(sqlite.InterfaceError):
cur.fetchall()
def test_AutoCommit(self):
"""
Verifies that creating a connection in autocommit mode works.
2.5.3 introduced a regression so that these could no longer
be created.
"""
con = sqlite.connect(":memory:", isolation_level=None)
def test_PragmaAutocommit(self):
"""
Verifies that running a PRAGMA statement that does an autocommit does
work. This did not work in 2.5.3/2.5.4.
"""
cur = self.con.cursor()
cur.execute("create table foo(bar)")
cur.execute("insert into foo(bar) values (5)")
cur.execute("pragma page_size")
row = cur.fetchone()
def test_ConnectionCall(self):
"""
Call a connection with a non-string SQL request: check error handling
of the statement constructor.
"""
self.assertRaises(TypeError, self.con, 1)
def test_Collation(self):
def collation_cb(a, b):
return 1
self.assertRaises(UnicodeEncodeError, self.con.create_collation,
# Lone surrogate cannot be encoded to the default encoding (utf8)
"\uDC80", collation_cb)
def test_RecursiveCursorUse(self):
"""
http://bugs.python.org/issue10811
Recursively using a cursor, such as when reusing it from a generator led to segfaults.
Now we catch recursive cursor usage and raise a ProgrammingError.
"""
con = sqlite.connect(":memory:")
cur = con.cursor()
cur.execute("create table a (bar)")
cur.execute("create table b (baz)")
def foo():
cur.execute("insert into a (bar) values (?)", (1,))
yield 1
with self.assertRaises(sqlite.ProgrammingError):
cur.executemany("insert into b (baz) values (?)",
((i,) for i in foo()))
def test_ConvertTimestampMicrosecondPadding(self):
"""
http://bugs.python.org/issue14720
The microsecond parsing of convert_timestamp() should pad with zeros,
since the microsecond string "456" actually represents "456000".
"""
con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES)
cur = con.cursor()
cur.execute("CREATE TABLE t (x TIMESTAMP)")
# Microseconds should be 456000
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.456')")
# Microseconds should be truncated to 123456
cur.execute("INSERT INTO t (x) VALUES ('2012-04-04 15:06:00.123456789')")
cur.execute("SELECT * FROM t")
values = [x[0] for x in cur.fetchall()]
self.assertEqual(values, [
datetime.datetime(2012, 4, 4, 15, 6, 0, 456000),
datetime.datetime(2012, 4, 4, 15, 6, 0, 123456),
])
def test_InvalidIsolationLevelType(self):
# isolation level is a string, not an integer
self.assertRaises(TypeError,
sqlite.connect, ":memory:", isolation_level=123)
def test_NullCharacter(self):
# Issue #21147
con = sqlite.connect(":memory:")
self.assertRaises(ValueError, con, "\0select 1")
self.assertRaises(ValueError, con, "select 1\0")
cur = con.cursor()
self.assertRaises(ValueError, cur.execute, " \0select 2")
self.assertRaises(ValueError, cur.execute, "select 2\0")
def test_CommitCursorReset(self):
"""
Connection.commit() did reset cursors, which made sqlite3
to return rows multiple times when fetched from cursors
after commit. See issues 10513 and 23129 for details.
"""
con = sqlite.connect(":memory:")
con.executescript("""
create table t(c);
create table t2(c);
insert into t values(0);
insert into t values(1);
insert into t values(2);
""")
self.assertEqual(con.isolation_level, "")
counter = 0
for i, row in enumerate(con.execute("select c from t")):
with self.subTest(i=i, row=row):
con.execute("insert into t2(c) values (?)", (i,))
con.commit()
if counter == 0:
self.assertEqual(row[0], 0)
elif counter == 1:
self.assertEqual(row[0], 1)
elif counter == 2:
self.assertEqual(row[0], 2)
counter += 1
self.assertEqual(counter, 3, "should have returned exactly three rows")
def test_Bpo31770(self):
"""
The interpreter shouldn't crash in case Cursor.__init__() is called
more than once.
"""
def callback(*args):
pass
con = sqlite.connect(":memory:")
cur = sqlite.Cursor(con)
ref = weakref.ref(cur, callback)
cur.__init__(con)
del cur
# The interpreter shouldn't crash when ref is collected.
del ref
#support.gc_collect()
def test_DelIsolation_levelSegfault(self):
with self.assertRaises(AttributeError):
del self.con.isolation_level
def test_Bpo37347(self):
class Printer:
def log(self, *args):
return sqlite.SQLITE_OK
for method in [self.con.set_trace_callback,
functools.partial(self.con.set_progress_handler, n=1),
self.con.set_authorizer]:
printer_instance = Printer()
method(printer_instance.log)
method(printer_instance.log) # Register twice, incref twice.
self.con.execute('select 1') # Triggers segfault.
method(None)
def suite():
loader = unittest.TestLoader()
tests = [loader.loadTestsFromTestCase(t) for t in (
RegressionTests,)]
return unittest.TestSuite(tests)
def test():
runner = unittest.TextTestRunner()
runner.run(suite())
if __name__ == "__main__":
test()

301
tests/transactions.py Normal file
View file

@ -0,0 +1,301 @@
#-*- coding: iso-8859-1 -*-
# pysqlite2/test/transactions.py: tests transactions
#
# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
# claim that you wrote the original software. If you use this software
# in a product, an acknowledgment in the product documentation would be
# appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
import glob, os, unittest
from sqlcipher3 import dbapi2 as sqlite
def get_db_path():
return "sqlite_testdb"
class TransactionTests(unittest.TestCase):
def setUp(self):
try:
os.remove(get_db_path())
except OSError:
pass
self.con1 = sqlite.connect(get_db_path(), timeout=0.1)
self.cur1 = self.con1.cursor()
self.con2 = sqlite.connect(get_db_path(), timeout=0.1)
self.cur2 = self.con2.cursor()
def tearDown(self):
self.cur1.close()
self.con1.close()
self.cur2.close()
self.con2.close()
try:
os.unlink(get_db_path())
except OSError:
pass
def test_DMLDoesNotAutoCommitBefore(self):
self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)")
self.cur1.execute("create table test2(j)")
self.cur2.execute("select i from test")
res = self.cur2.fetchall()
self.assertEqual(len(res), 0)
def test_InsertStartsTransaction(self):
self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)")
self.cur2.execute("select i from test")
res = self.cur2.fetchall()
self.assertEqual(len(res), 0)
def test_UpdateStartsTransaction(self):
self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)")
self.con1.commit()
self.cur1.execute("update test set i=6")
self.cur2.execute("select i from test")
res = self.cur2.fetchone()[0]
self.assertEqual(res, 5)
def test_DeleteStartsTransaction(self):
self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)")
self.con1.commit()
self.cur1.execute("delete from test")
self.cur2.execute("select i from test")
res = self.cur2.fetchall()
self.assertEqual(len(res), 1)
def test_ReplaceStartsTransaction(self):
self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)")
self.con1.commit()
self.cur1.execute("replace into test(i) values (6)")
self.cur2.execute("select i from test")
res = self.cur2.fetchall()
self.assertEqual(len(res), 1)
self.assertEqual(res[0][0], 5)
def test_ToggleAutoCommit(self):
self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)")
self.con1.isolation_level = None
self.assertEqual(self.con1.isolation_level, None)
self.cur2.execute("select i from test")
res = self.cur2.fetchall()
self.assertEqual(len(res), 1)
self.con1.isolation_level = "DEFERRED"
self.assertEqual(self.con1.isolation_level , "DEFERRED")
self.cur1.execute("insert into test(i) values (5)")
self.cur2.execute("select i from test")
res = self.cur2.fetchall()
self.assertEqual(len(res), 1)
@unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 2),
'test hangs on sqlite versions older than 3.2.2')
def test_RaiseTimeout(self):
self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)")
with self.assertRaises(sqlite.OperationalError):
self.cur2.execute("insert into test(i) values (5)")
@unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 2),
'test hangs on sqlite versions older than 3.2.2')
def test_Locking(self):
"""
This tests the improved concurrency with pysqlite 2.3.4. You needed
to roll back con2 before you could commit con1.
"""
self.cur1.execute("create table test(i)")
self.cur1.execute("insert into test(i) values (5)")
with self.assertRaises(sqlite.OperationalError):
self.cur2.execute("insert into test(i) values (5)")
# NO self.con2.rollback() HERE!!!
self.con1.commit()
def test_RollbackCursorConsistency(self):
"""
Checks if cursors on the connection are set into a "reset" state
when a rollback is done on the connection.
"""
con = sqlite.connect(":memory:")
cur = con.cursor()
cur.execute("create table test(x)")
cur.execute("insert into test(x) values (5)")
cur.execute("select 1 union select 2 union select 3")
con.rollback()
with self.assertRaises(sqlite.InterfaceError):
cur.fetchall()
class SpecialCommandTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor()
def test_DropTable(self):
self.cur.execute("create table test(i)")
self.cur.execute("insert into test(i) values (5)")
self.cur.execute("drop table test")
def test_Pragma(self):
self.cur.execute("create table test(i)")
self.cur.execute("insert into test(i) values (5)")
self.cur.execute("pragma count_changes=1")
def tearDown(self):
self.cur.close()
self.con.close()
class TransactionalDDL(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
def test_DdlDoesNotAutostartTransaction(self):
# For backwards compatibility reasons, DDL statements should not
# implicitly start a transaction.
self.con.execute("create table test(i)")
self.con.rollback()
result = self.con.execute("select * from test").fetchall()
self.assertEqual(result, [])
self.con.execute("alter table test rename to test2")
self.con.rollback()
result = self.con.execute("select * from test2").fetchall()
self.assertEqual(result, [])
def test_ImmediateTransactionalDDL(self):
# You can achieve transactional DDL by issuing a BEGIN
# statement manually.
self.con.execute("begin immediate")
self.con.execute("create table test(i)")
self.con.rollback()
with self.assertRaises(sqlite.OperationalError):
self.con.execute("select * from test")
def test_TransactionalDDL(self):
# You can achieve transactional DDL by issuing a BEGIN
# statement manually.
self.con.execute("begin")
self.con.execute("create table test(i)")
self.con.rollback()
with self.assertRaises(sqlite.OperationalError):
self.con.execute("select * from test")
def tearDown(self):
self.con.close()
class DMLStatementDetectionTestCase(unittest.TestCase):
"""
https://bugs.python.org/issue36859
Use sqlite3_stmt_readonly to determine if the statement is DML or not.
"""
def setUp(self):
for f in glob.glob(get_db_path() + '*'):
try:
os.unlink(f)
except OSError:
pass
tearDown = setUp
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3),
'needs sqlite 3.8.3 or newer')
def test_dml_detection_cte(self):
conn = sqlite.connect(':memory:')
conn.execute('create table kv ("key" text, "val" integer)')
self.assertFalse(conn.in_transaction)
conn.execute('insert into kv (key, val) values (?, ?), (?, ?)',
('k1', 1, 'k2', 2))
self.assertTrue(conn.in_transaction)
conn.commit()
self.assertFalse(conn.in_transaction)
rc = conn.execute('update kv set val=val + ?', (10,))
self.assertEqual(rc.rowcount, 2)
self.assertTrue(conn.in_transaction)
conn.commit()
self.assertFalse(conn.in_transaction)
rc = conn.execute('with c(k, v) as (select key, val + ? from kv) '
'update kv set val=(select v from c where k=kv.key)',
(100,))
self.assertEqual(rc.rowcount, 2)
self.assertTrue(conn.in_transaction)
curs = conn.execute('select key, val from kv order by key')
self.assertEqual(curs.fetchall(), [('k1', 111), ('k2', 112)])
def test_dml_detection_sql_comment(self):
conn = sqlite.connect(':memory:')
conn.execute('create table kv ("key" text, "val" integer)')
conn.execute('insert into kv (key, val) values (?, ?), (?, ?)',
('k1', 1, 'k2', 2))
conn.commit()
self.assertFalse(conn.in_transaction)
rc = conn.execute('-- a comment\nupdate kv set val=val + ?', (10,))
self.assertEqual(rc.rowcount, 2)
self.assertTrue(conn.in_transaction)
curs = conn.execute('select key, val from kv order by key')
self.assertEqual(curs.fetchall(), [('k1', 11), ('k2', 12)])
conn.rollback()
def test_dml_detection_begin_exclusive(self):
conn = sqlite.connect(':memory:')
conn.execute('begin exclusive')
self.assertTrue(conn.in_transaction)
conn.execute('rollback')
self.assertFalse(conn.in_transaction)
def test_dml_detection_vacuum(self):
conn = sqlite.connect(':memory:')
conn.execute('vacuum')
self.assertFalse(conn.in_transaction)
def test_dml_detection_pragma(self):
conn = sqlite.connect(get_db_path())
conn.execute('pragma journal_mode=\'wal\'')
jmode, = conn.execute('pragma journal_mode').fetchone()
self.assertEqual(jmode, 'wal')
self.assertFalse(conn.in_transaction)
def suite():
loader = unittest.TestLoader()
tests = [loader.loadTestsFromTestCase(t) for t in (
TransactionTests,
SpecialCommandTests,
TransactionalDDL,
DMLStatementDetectionTestCase)]
return unittest.TestSuite(tests)
def test():
runner = unittest.TextTestRunner()
runner.run(suite())
if __name__ == "__main__":
test()

445
tests/ttypes.py Normal file
View file

@ -0,0 +1,445 @@
#-*- coding: iso-8859-1 -*-
# pysqlite2/test/types.py: tests for type conversion and detection
#
# Copyright (C) 2005 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
# claim that you wrote the original software. If you use this software
# in a product, an acknowledgment in the product documentation would be
# appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
import datetime
import unittest
from sqlcipher3 import dbapi2 as sqlite
try:
import zlib
except ImportError:
zlib = None
class SqliteTypeTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor()
self.cur.execute("create table test(i integer, s varchar, f number, b blob)")
def tearDown(self):
self.cur.close()
self.con.close()
def test_String(self):
self.cur.execute("insert into test(s) values (?)", ("Österreich",))
self.cur.execute("select s from test")
row = self.cur.fetchone()
self.assertEqual(row[0], "Österreich")
def test_SmallInt(self):
self.cur.execute("insert into test(i) values (?)", (42,))
self.cur.execute("select i from test")
row = self.cur.fetchone()
self.assertEqual(row[0], 42)
def test_LargeInt(self):
num = 2**40
self.cur.execute("insert into test(i) values (?)", (num,))
self.cur.execute("select i from test")
row = self.cur.fetchone()
self.assertEqual(row[0], num)
def test_Float(self):
val = 3.14
self.cur.execute("insert into test(f) values (?)", (val,))
self.cur.execute("select f from test")
row = self.cur.fetchone()
self.assertEqual(row[0], val)
def test_Blob(self):
sample = b"Guglhupf"
val = memoryview(sample)
self.cur.execute("insert into test(b) values (?)", (val,))
self.cur.execute("select b from test")
row = self.cur.fetchone()
self.assertEqual(row[0], sample)
def test_UnicodeExecute(self):
self.cur.execute("select 'Österreich'")
row = self.cur.fetchone()
self.assertEqual(row[0], "Österreich")
class DeclTypesTests(unittest.TestCase):
class Foo:
def __init__(self, _val):
if isinstance(_val, bytes):
# sqlite3 always calls __init__ with a bytes created from a
# UTF-8 string when __conform__ was used to store the object.
_val = _val.decode('utf-8')
self.val = _val
def __eq__(self, other):
if not isinstance(other, DeclTypesTests.Foo):
return NotImplemented
return self.val == other.val
def __conform__(self, protocol):
if protocol is sqlite.PrepareProtocol:
return self.val
else:
return None
def __str__(self):
return "<%s>" % self.val
class BadConform:
def __init__(self, exc):
self.exc = exc
def __conform__(self, protocol):
raise self.exc
def setUp(self):
self.con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES)
self.cur = self.con.cursor()
self.cur.execute("create table test(i int, s str, f float, b bool, u unicode, foo foo, bin blob, n1 number, n2 number(5), bad bad)")
# override float, make them always return the same number
sqlite.converters["FLOAT"] = lambda x: 47.2
# and implement two custom ones
sqlite.converters["BOOL"] = lambda x: bool(int(x))
sqlite.converters["FOO"] = DeclTypesTests.Foo
sqlite.converters["BAD"] = DeclTypesTests.BadConform
sqlite.converters["WRONG"] = lambda x: "WRONG"
sqlite.converters["NUMBER"] = float
def tearDown(self):
del sqlite.converters["FLOAT"]
del sqlite.converters["BOOL"]
del sqlite.converters["FOO"]
del sqlite.converters["BAD"]
del sqlite.converters["WRONG"]
del sqlite.converters["NUMBER"]
self.cur.close()
self.con.close()
def test_String(self):
# default
self.cur.execute("insert into test(s) values (?)", ("foo",))
self.cur.execute('select s as "s [WRONG]" from test')
row = self.cur.fetchone()
self.assertEqual(row[0], "foo")
def test_SmallInt(self):
# default
self.cur.execute("insert into test(i) values (?)", (42,))
self.cur.execute("select i from test")
row = self.cur.fetchone()
self.assertEqual(row[0], 42)
def test_LargeInt(self):
# default
num = 2**40
self.cur.execute("insert into test(i) values (?)", (num,))
self.cur.execute("select i from test")
row = self.cur.fetchone()
self.assertEqual(row[0], num)
def test_Float(self):
# custom
val = 3.14
self.cur.execute("insert into test(f) values (?)", (val,))
self.cur.execute("select f from test")
row = self.cur.fetchone()
self.assertEqual(row[0], 47.2)
def test_Bool(self):
# custom
self.cur.execute("insert into test(b) values (?)", (False,))
self.cur.execute("select b from test")
row = self.cur.fetchone()
self.assertIs(row[0], False)
self.cur.execute("delete from test")
self.cur.execute("insert into test(b) values (?)", (True,))
self.cur.execute("select b from test")
row = self.cur.fetchone()
self.assertIs(row[0], True)
def test_Unicode(self):
# default
val = "\xd6sterreich"
self.cur.execute("insert into test(u) values (?)", (val,))
self.cur.execute("select u from test")
row = self.cur.fetchone()
self.assertEqual(row[0], val)
def test_Foo(self):
val = DeclTypesTests.Foo("bla")
self.cur.execute("insert into test(foo) values (?)", (val,))
self.cur.execute("select foo from test")
row = self.cur.fetchone()
self.assertEqual(row[0], val)
def test_ErrorInConform(self):
val = DeclTypesTests.BadConform(TypeError)
with self.assertRaises(sqlite.InterfaceError):
self.cur.execute("insert into test(bad) values (?)", (val,))
with self.assertRaises(sqlite.InterfaceError):
self.cur.execute("insert into test(bad) values (:val)", {"val": val})
val = DeclTypesTests.BadConform(KeyboardInterrupt)
with self.assertRaises(KeyboardInterrupt):
self.cur.execute("insert into test(bad) values (?)", (val,))
with self.assertRaises(KeyboardInterrupt):
self.cur.execute("insert into test(bad) values (:val)", {"val": val})
def test_UnsupportedSeq(self):
class Bar: pass
val = Bar()
with self.assertRaises(sqlite.InterfaceError):
self.cur.execute("insert into test(f) values (?)", (val,))
def test_UnsupportedDict(self):
class Bar: pass
val = Bar()
with self.assertRaises(sqlite.InterfaceError):
self.cur.execute("insert into test(f) values (:val)", {"val": val})
def test_Blob(self):
# default
sample = b"Guglhupf"
val = memoryview(sample)
self.cur.execute("insert into test(bin) values (?)", (val,))
self.cur.execute("select bin from test")
row = self.cur.fetchone()
self.assertEqual(row[0], sample)
def test_Number1(self):
self.cur.execute("insert into test(n1) values (5)")
value = self.cur.execute("select n1 from test").fetchone()[0]
# if the converter is not used, it's an int instead of a float
self.assertEqual(type(value), float)
def test_Number2(self):
"""Checks whether converter names are cut off at '(' characters"""
self.cur.execute("insert into test(n2) values (5)")
value = self.cur.execute("select n2 from test").fetchone()[0]
# if the converter is not used, it's an int instead of a float
self.assertEqual(type(value), float)
class ColNamesTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES)
self.cur = self.con.cursor()
self.cur.execute("create table test(x foo)")
sqlite.converters["FOO"] = lambda x: "[%s]" % x.decode("ascii")
sqlite.converters["BAR"] = lambda x: "<%s>" % x.decode("ascii")
sqlite.converters["EXC"] = lambda x: 5/0
sqlite.converters["B1B1"] = lambda x: "MARKER"
def tearDown(self):
del sqlite.converters["FOO"]
del sqlite.converters["BAR"]
del sqlite.converters["EXC"]
del sqlite.converters["B1B1"]
self.cur.close()
self.con.close()
def test_DeclTypeNotUsed(self):
"""
Assures that the declared type is not used when PARSE_DECLTYPES
is not set.
"""
self.cur.execute("insert into test(x) values (?)", ("xxx",))
self.cur.execute("select x from test")
val = self.cur.fetchone()[0]
self.assertEqual(val, "xxx")
def test_None(self):
self.cur.execute("insert into test(x) values (?)", (None,))
self.cur.execute("select x from test")
val = self.cur.fetchone()[0]
self.assertEqual(val, None)
def test_ColName(self):
self.cur.execute("insert into test(x) values (?)", ("xxx",))
self.cur.execute('select x as "x y [bar]" from test')
val = self.cur.fetchone()[0]
self.assertEqual(val, "<xxx>")
# Check if the stripping of colnames works. Everything after the first
# whitespace should be stripped.
self.assertEqual(self.cur.description[0][0], "x y")
def test_CaseInConverterName(self):
self.cur.execute("select 'other' as \"x [b1b1]\"")
val = self.cur.fetchone()[0]
self.assertEqual(val, "MARKER")
def test_CursorDescriptionNoRow(self):
"""
cursor.description should at least provide the column name(s), even if
no row returned.
"""
self.cur.execute("select * from test where 0 = 1")
self.assertEqual(self.cur.description[0][0], "x")
def test_CursorDescriptionInsert(self):
self.cur.execute("insert into test values (1)")
self.assertIsNone(self.cur.description)
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "CTEs not supported")
class CommonTableExpressionTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.cur = self.con.cursor()
self.cur.execute("create table test(x foo)")
def tearDown(self):
self.cur.close()
self.con.close()
def test_CursorDescriptionCTESimple(self):
self.cur.execute("with one as (select 1) select * from one")
self.assertIsNotNone(self.cur.description)
self.assertEqual(self.cur.description[0][0], "1")
def test_CursorDescriptionCTESMultipleColumns(self):
self.cur.execute("insert into test values(1)")
self.cur.execute("insert into test values(2)")
self.cur.execute("with testCTE as (select * from test) select * from testCTE")
self.assertIsNotNone(self.cur.description)
self.assertEqual(self.cur.description[0][0], "x")
def test_CursorDescriptionCTE(self):
self.cur.execute("insert into test values (1)")
self.cur.execute("with bar as (select * from test) select * from test where x = 1")
self.assertIsNotNone(self.cur.description)
self.assertEqual(self.cur.description[0][0], "x")
self.cur.execute("with bar as (select * from test) select * from test where x = 2")
self.assertIsNotNone(self.cur.description)
self.assertEqual(self.cur.description[0][0], "x")
class ObjectAdaptationTests(unittest.TestCase):
def cast(obj):
return float(obj)
cast = staticmethod(cast)
def setUp(self):
self.con = sqlite.connect(":memory:")
try:
del sqlite.adapters[int]
except:
pass
sqlite.register_adapter(int, ObjectAdaptationTests.cast)
self.cur = self.con.cursor()
def tearDown(self):
del sqlite.adapters[(int, sqlite.PrepareProtocol)]
self.cur.close()
self.con.close()
def test_CasterIsUsed(self):
self.cur.execute("select ?", (4,))
val = self.cur.fetchone()[0]
self.assertEqual(type(val), float)
@unittest.skipUnless(zlib, "requires zlib")
class BinaryConverterTests(unittest.TestCase):
def convert(s):
return zlib.decompress(s)
convert = staticmethod(convert)
def setUp(self):
self.con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_COLNAMES)
sqlite.register_converter("bin", BinaryConverterTests.convert)
def tearDown(self):
self.con.close()
def test_BinaryInputForConverter(self):
testdata = b"abcdefg" * 10
result = self.con.execute('select ? as "x [bin]"', (memoryview(zlib.compress(testdata)),)).fetchone()[0]
self.assertEqual(testdata, result)
class DateTimeTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:", detect_types=sqlite.PARSE_DECLTYPES)
self.cur = self.con.cursor()
self.cur.execute("create table test(d date, ts timestamp)")
def tearDown(self):
self.cur.close()
self.con.close()
def test_SqliteDate(self):
d = sqlite.Date(2004, 2, 14)
self.cur.execute("insert into test(d) values (?)", (d,))
self.cur.execute("select d from test")
d2 = self.cur.fetchone()[0]
self.assertEqual(d, d2)
def test_SqliteTimestamp(self):
ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0)
self.cur.execute("insert into test(ts) values (?)", (ts,))
self.cur.execute("select ts from test")
ts2 = self.cur.fetchone()[0]
self.assertEqual(ts, ts2)
@unittest.skipIf(sqlite.sqlite_version_info < (3, 1),
'the date functions are available on 3.1 or later')
def test_SqlTimestamp(self):
now = datetime.datetime.utcnow()
self.cur.execute("insert into test(ts) values (current_timestamp)")
self.cur.execute("select ts from test")
ts = self.cur.fetchone()[0]
self.assertEqual(type(ts), datetime.datetime)
self.assertEqual(ts.year, now.year)
def test_DateTimeSubSeconds(self):
ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0, 500000)
self.cur.execute("insert into test(ts) values (?)", (ts,))
self.cur.execute("select ts from test")
ts2 = self.cur.fetchone()[0]
self.assertEqual(ts, ts2)
def test_DateTimeSubSecondsFloatingPoint(self):
ts = sqlite.Timestamp(2004, 2, 14, 7, 15, 0, 510241)
self.cur.execute("insert into test(ts) values (?)", (ts,))
self.cur.execute("select ts from test")
ts2 = self.cur.fetchone()[0]
self.assertEqual(ts, ts2)
def suite():
loader = unittest.TestLoader()
tests = [loader.loadTestsFromTestCase(t) for t in (
SqliteTypeTests,
DeclTypesTests,
ColNamesTests,
ObjectAdaptationTests,
BinaryConverterTests,
DateTimeTests,
CommonTableExpressionTests)]
return unittest.TestSuite(tests)
def test():
runner = unittest.TextTestRunner()
runner.run(suite())
if __name__ == "__main__":
test()

535
tests/userfunctions.py Normal file
View file

@ -0,0 +1,535 @@
#-*- coding: iso-8859-1 -*-
# pysqlite2/test/userfunctions.py: tests for user-defined functions and
# aggregates.
#
# Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
#
# This file is part of pysqlite.
#
# This software is provided 'as-is', without any express or implied
# warranty. In no event will the authors be held liable for any damages
# arising from the use of this software.
#
# Permission is granted to anyone to use this software for any purpose,
# including commercial applications, and to alter it and redistribute it
# freely, subject to the following restrictions:
#
# 1. The origin of this software must not be misrepresented; you must not
# claim that you wrote the original software. If you use this software
# in a product, an acknowledgment in the product documentation would be
# appreciated but is not required.
# 2. Altered source versions must be plainly marked as such, and must not be
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.
import unittest
import unittest.mock
from sqlcipher3 import dbapi2 as sqlite
def func_returntext():
return "foo"
def func_returnunicode():
return "bar"
def func_returnint():
return 42
def func_returnfloat():
return 3.14
def func_returnnull():
return None
def func_returnblob():
return b"blob"
def func_returnlonglong():
return 1<<31
def func_raiseexception():
5/0
def func_isstring(v):
return type(v) is str
def func_isint(v):
return type(v) is int
def func_isfloat(v):
return type(v) is float
def func_isnone(v):
return type(v) is type(None)
def func_isblob(v):
return isinstance(v, (bytes, memoryview))
def func_islonglong(v):
return isinstance(v, int) and v >= 1<<31
def func(*args):
return len(args)
class AggrNoStep:
def __init__(self):
pass
def finalize(self):
return 1
class AggrNoFinalize:
def __init__(self):
pass
def step(self, x):
pass
class AggrExceptionInInit:
def __init__(self):
5/0
def step(self, x):
pass
def finalize(self):
pass
class AggrExceptionInStep:
def __init__(self):
pass
def step(self, x):
5/0
def finalize(self):
return 42
class AggrExceptionInFinalize:
def __init__(self):
pass
def step(self, x):
pass
def finalize(self):
5/0
class AggrCheckType:
def __init__(self):
self.val = None
def step(self, whichType, val):
theType = {"str": str, "int": int, "float": float, "None": type(None),
"blob": bytes}
self.val = int(theType[whichType] is type(val))
def finalize(self):
return self.val
class AggrCheckTypes:
def __init__(self):
self.val = 0
def step(self, whichType, *vals):
theType = {"str": str, "int": int, "float": float, "None": type(None),
"blob": bytes}
for val in vals:
self.val += int(theType[whichType] is type(val))
def finalize(self):
return self.val
class AggrSum:
def __init__(self):
self.val = 0.0
def step(self, val):
self.val += val
def finalize(self):
return self.val
class FunctionTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
self.con.create_function("returntext", 0, func_returntext)
self.con.create_function("returnunicode", 0, func_returnunicode)
self.con.create_function("returnint", 0, func_returnint)
self.con.create_function("returnfloat", 0, func_returnfloat)
self.con.create_function("returnnull", 0, func_returnnull)
self.con.create_function("returnblob", 0, func_returnblob)
self.con.create_function("returnlonglong", 0, func_returnlonglong)
self.con.create_function("raiseexception", 0, func_raiseexception)
self.con.create_function("isstring", 1, func_isstring)
self.con.create_function("isint", 1, func_isint)
self.con.create_function("isfloat", 1, func_isfloat)
self.con.create_function("isnone", 1, func_isnone)
self.con.create_function("isblob", 1, func_isblob)
self.con.create_function("islonglong", 1, func_islonglong)
self.con.create_function("spam", -1, func)
def tearDown(self):
self.con.close()
def test_FuncErrorOnCreate(self):
with self.assertRaises(sqlite.OperationalError):
self.con.create_function("bla", -100, lambda x: 2*x)
def test_FuncRefCount(self):
def getfunc():
def f():
return 1
return f
f = getfunc()
globals()["foo"] = f
# self.con.create_function("reftest", 0, getfunc())
self.con.create_function("reftest", 0, f)
cur = self.con.cursor()
cur.execute("select reftest()")
def test_FuncReturnText(self):
cur = self.con.cursor()
cur.execute("select returntext()")
val = cur.fetchone()[0]
self.assertEqual(type(val), str)
self.assertEqual(val, "foo")
def test_FuncReturnUnicode(self):
cur = self.con.cursor()
cur.execute("select returnunicode()")
val = cur.fetchone()[0]
self.assertEqual(type(val), str)
self.assertEqual(val, "bar")
def test_FuncReturnInt(self):
cur = self.con.cursor()
cur.execute("select returnint()")
val = cur.fetchone()[0]
self.assertEqual(type(val), int)
self.assertEqual(val, 42)
def test_FuncReturnFloat(self):
cur = self.con.cursor()
cur.execute("select returnfloat()")
val = cur.fetchone()[0]
self.assertEqual(type(val), float)
if val < 3.139 or val > 3.141:
self.fail("wrong value")
def test_FuncReturnNull(self):
cur = self.con.cursor()
cur.execute("select returnnull()")
val = cur.fetchone()[0]
self.assertEqual(type(val), type(None))
self.assertEqual(val, None)
def test_FuncReturnBlob(self):
cur = self.con.cursor()
cur.execute("select returnblob()")
val = cur.fetchone()[0]
self.assertEqual(type(val), bytes)
self.assertEqual(val, b"blob")
def test_FuncReturnLongLong(self):
cur = self.con.cursor()
cur.execute("select returnlonglong()")
val = cur.fetchone()[0]
self.assertEqual(val, 1<<31)
def test_FuncException(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
cur.execute("select raiseexception()")
cur.fetchone()
self.assertEqual(str(cm.exception), 'user-defined function raised exception')
def test_ParamString(self):
cur = self.con.cursor()
cur.execute("select isstring(?)", ("foo",))
val = cur.fetchone()[0]
self.assertEqual(val, 1)
def test_ParamInt(self):
cur = self.con.cursor()
cur.execute("select isint(?)", (42,))
val = cur.fetchone()[0]
self.assertEqual(val, 1)
def test_ParamFloat(self):
cur = self.con.cursor()
cur.execute("select isfloat(?)", (3.14,))
val = cur.fetchone()[0]
self.assertEqual(val, 1)
def test_ParamNone(self):
cur = self.con.cursor()
cur.execute("select isnone(?)", (None,))
val = cur.fetchone()[0]
self.assertEqual(val, 1)
def test_ParamBlob(self):
cur = self.con.cursor()
cur.execute("select isblob(?)", (memoryview(b"blob"),))
val = cur.fetchone()[0]
self.assertEqual(val, 1)
def test_ParamLongLong(self):
cur = self.con.cursor()
cur.execute("select islonglong(?)", (1<<42,))
val = cur.fetchone()[0]
self.assertEqual(val, 1)
def test_AnyArguments(self):
cur = self.con.cursor()
cur.execute("select spam(?, ?)", (1, 2))
val = cur.fetchone()[0]
self.assertEqual(val, 2)
def test_FuncNonDeterministic(self):
mock = unittest.mock.Mock(return_value=None)
self.con.create_function("deterministic", 0, mock, deterministic=False)
self.con.execute("select deterministic() = deterministic()")
self.assertEqual(mock.call_count, 2)
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "deterministic parameter not supported")
def test_FuncDeterministic(self):
mock = unittest.mock.Mock(return_value=None)
self.con.create_function("deterministic", 0, mock, True)
self.con.execute("select 1 where deterministic() AND deterministic()")
self.assertEqual(mock.call_count, 1)
@unittest.skipIf(sqlite.sqlite_version_info >= (3, 8, 3), "SQLite < 3.8.3 needed")
def test_FuncDeterministicNotSupported(self):
with self.assertRaises(sqlite.NotSupportedError):
self.con.create_function("deterministic", 0, int, deterministic=True)
class AggregateTests(unittest.TestCase):
def setUp(self):
self.con = sqlite.connect(":memory:")
cur = self.con.cursor()
cur.execute("""
create table test(
t text,
i integer,
f float,
n,
b blob
)
""")
cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
("foo", 5, 3.14, None, memoryview(b"blob"),))
self.con.create_aggregate("nostep", 1, AggrNoStep)
self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
self.con.create_aggregate("checkType", 2, AggrCheckType)
self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
self.con.create_aggregate("mysum", 1, AggrSum)
def tearDown(self):
#self.cur.close()
#self.con.close()
pass
def test_AggrErrorOnCreate(self):
with self.assertRaises(sqlite.OperationalError):
self.con.create_function("bla", -100, AggrSum)
def test_AggrNoStep(self):
cur = self.con.cursor()
with self.assertRaises(AttributeError) as cm:
cur.execute("select nostep(t) from test")
self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'")
def test_AggrNoFinalize(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
cur.execute("select nofinalize(t) from test")
val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
def test_AggrExceptionInInit(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
cur.execute("select excInit(t) from test")
val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")
def test_AggrExceptionInStep(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
cur.execute("select excStep(t) from test")
val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")
def test_AggrExceptionInFinalize(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
cur.execute("select excFinalize(t) from test")
val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
def test_AggrCheckParamStr(self):
cur = self.con.cursor()
cur.execute("select checkType('str', ?)", ("foo",))
val = cur.fetchone()[0]
self.assertEqual(val, 1)
def test_AggrCheckParamInt(self):
cur = self.con.cursor()
cur.execute("select checkType('int', ?)", (42,))
val = cur.fetchone()[0]
self.assertEqual(val, 1)
def test_AggrCheckParamsInt(self):
cur = self.con.cursor()
cur.execute("select checkTypes('int', ?, ?)", (42, 24))
val = cur.fetchone()[0]
self.assertEqual(val, 2)
def test_AggrCheckParamFloat(self):
cur = self.con.cursor()
cur.execute("select checkType('float', ?)", (3.14,))
val = cur.fetchone()[0]
self.assertEqual(val, 1)
def test_AggrCheckParamNone(self):
cur = self.con.cursor()
cur.execute("select checkType('None', ?)", (None,))
val = cur.fetchone()[0]
self.assertEqual(val, 1)
def test_AggrCheckParamBlob(self):
cur = self.con.cursor()
cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
val = cur.fetchone()[0]
self.assertEqual(val, 1)
def test_AggrCheckAggrSum(self):
cur = self.con.cursor()
cur.execute("delete from test")
cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
cur.execute("select mysum(i) from test")
val = cur.fetchone()[0]
self.assertEqual(val, 60)
def test_AggrNoMatch(self):
cur = self.con.execute('select mysum(i) from (select 1 as i) where i == 0')
val = cur.fetchone()[0]
self.assertIsNone(val)
@unittest.skipIf(sqlite.sqlite_version_info < (3, 25, 0),
'requires sqlite with window-function support')
class WindowFunctionTests(unittest.TestCase):
def setUp(self):
self.conn = sqlite.connect(":memory:")
self.conn.execute('create table sample (id integer primary key, '
'counter integer, value real);')
self.conn.execute('insert into sample (counter, value) values '
'(?,?),(?,?),(?,?),(?,?),(?,?)', (
1, 10.,
1, 20.,
2, 1.,
2, 2.,
3, 100.))
def test_user_defined_window_function(self):
class MySum(object):
def __init__(self): self._value = 0
def step(self, value): self._value += value
def inverse(self, value): self._value -= value
def value(self): return self._value
def finalize(self): return self._value
self.conn.create_window_function('mysum', -1, MySum)
q = ('select counter, value, mysum(value) over (partition by counter) '
'from sample order by id')
self.assertEqual(self.conn.execute(q).fetchall(), [
(1, 10., 30.), (1, 20., 30.),
(2, 1., 3.), (2, 2., 3.),
(3, 100., 100.)])
class AuthorizerTests(unittest.TestCase):
@staticmethod
def authorizer_cb(action, arg1, arg2, dbname, source):
if action != sqlite.SQLITE_SELECT:
return sqlite.SQLITE_DENY
if arg2 == 'c2' or arg1 == 't2':
return sqlite.SQLITE_DENY
return sqlite.SQLITE_OK
def setUp(self):
self.con = sqlite.connect(":memory:")
self.con.executescript("""
create table t1 (c1, c2);
create table t2 (c1, c2);
insert into t1 (c1, c2) values (1, 2);
insert into t2 (c1, c2) values (4, 5);
""")
# For our security test:
self.con.execute("select c2 from t2")
self.con.set_authorizer(self.authorizer_cb)
def tearDown(self):
pass
def test_table_access(self):
with self.assertRaises(sqlite.DatabaseError) as cm:
self.con.execute("select * from t2")
self.assertIn('prohibited', str(cm.exception))
def test_column_access(self):
with self.assertRaises(sqlite.DatabaseError) as cm:
self.con.execute("select c2 from t1")
self.assertIn('prohibited', str(cm.exception))
def test_clear_authorizer(self):
self.con.set_authorizer(None)
self.con.execute('select * from t2')
self.con.execute('select c2 from t1')
class AuthorizerRaiseExceptionTests(AuthorizerTests):
@staticmethod
def authorizer_cb(action, arg1, arg2, dbname, source):
if action != sqlite.SQLITE_SELECT:
raise ValueError
if arg2 == 'c2' or arg1 == 't2':
raise ValueError
return sqlite.SQLITE_OK
class AuthorizerIllegalTypeTests(AuthorizerTests):
@staticmethod
def authorizer_cb(action, arg1, arg2, dbname, source):
if action != sqlite.SQLITE_SELECT:
return 0.0
if arg2 == 'c2' or arg1 == 't2':
return 0.0
return sqlite.SQLITE_OK
class AuthorizerLargeIntegerTests(AuthorizerTests):
@staticmethod
def authorizer_cb(action, arg1, arg2, dbname, source):
if action != sqlite.SQLITE_SELECT:
return 2**32
if arg2 == 'c2' or arg1 == 't2':
return 2**32
return sqlite.SQLITE_OK
def suite():
loader = unittest.TestLoader()
tests = [loader.loadTestsFromTestCase(t) for t in (
FunctionTests,
AggregateTests,
WindowFunctionTests,
AuthorizerTests,
AuthorizerRaiseExceptionTests,
AuthorizerIllegalTypeTests,
AuthorizerLargeIntegerTests)]
return unittest.TestSuite(tests)
def test():
runner = unittest.TextTestRunner()
runner.run(suite())
if __name__ == "__main__":
test()