Support sqlite3_trace_v2() when Sqlite support is available.

This commit is contained in:
Charles Leifer 2023-05-02 11:39:58 -05:00
parent e1bc4d9669
commit 63c0189248

View file

@ -21,6 +21,7 @@
# misrepresented as being the original software. # misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution. # 3. This notice may not be removed or altered from any source distribution.
import os
import unittest import unittest
from sqlcipher3 import dbapi2 as sqlite from sqlcipher3 import dbapi2 as sqlite
@ -214,6 +215,17 @@ class TraceCallbackTests(unittest.TestCase):
self.assertTrue(traced_statements) self.assertTrue(traced_statements)
self.assertTrue(any("create table foo" in stmt for stmt in traced_statements)) self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
def CheckTraceCallbackError(self):
"""
Test behavior when exception raised in trace callback.
"""
con = sqlite.connect(":memory:")
def trace(statement):
raise Exception('uh-oh')
con.set_trace_callback(trace)
con.execute("create table foo(a, b)")
con.set_trace_callback(None)
def CheckClearTraceCallback(self): def CheckClearTraceCallback(self):
""" """
Test that setting the trace callback to None clears the previously set callback. Test that setting the trace callback to None clears the previously set callback.
@ -247,24 +259,23 @@ class TraceCallbackTests(unittest.TestCase):
"Unicode data %s garbled in trace callback: %s" "Unicode data %s garbled in trace callback: %s"
% (ascii(unicode_value), ', '.join(map(ascii, traced_statements)))) % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))
#@unittest.skipIf(sqlite.sqlite_version_info < (3, 3, 9), "sqlite3_prepare_v2 is not available") def CheckTraceCallbackContent(self):
#def CheckTraceCallbackContent(self): # set_trace_callback() shouldn't produce duplicate content (bpo-26187)
# # set_trace_callback() shouldn't produce duplicate content (bpo-26187) traced_statements = []
# traced_statements = [] def trace(statement):
# def trace(statement): traced_statements.append(statement)
# traced_statements.append(statement)
# queries = ["create table foo(x)", queries = ["create table foo(x)",
# "insert into foo(x) values(1)"] "insert into foo(x) values(1)"]
# self.addCleanup(unlink, TESTFN) self.addCleanup(os.unlink, 'tracecb.db')
# con1 = sqlite.connect(TESTFN, isolation_level=None) con1 = sqlite.connect('tracecb.db', isolation_level=None)
# con2 = sqlite.connect(TESTFN) con2 = sqlite.connect('tracecb.db')
# con1.set_trace_callback(trace) con1.set_trace_callback(trace)
# cur = con1.cursor() cur = con1.cursor()
# cur.execute(queries[0]) cur.execute(queries[0])
# con2.execute("create table bar(x)") con2.execute("create table bar(x)")
# cur.execute(queries[1]) cur.execute(queries[1])
# self.assertEqual(traced_statements, queries) self.assertEqual(traced_statements, queries)
def suite(): def suite():