diff --git a/test/hooks.py b/test/hooks.py index c35f044..154a4b7 100644 --- a/test/hooks.py +++ b/test/hooks.py @@ -21,6 +21,7 @@ # misrepresented as being the original software. # 3. This notice may not be removed or altered from any source distribution. +import os import unittest from sqlcipher3 import dbapi2 as sqlite @@ -214,6 +215,17 @@ class TraceCallbackTests(unittest.TestCase): self.assertTrue(traced_statements) self.assertTrue(any("create table foo" in stmt for stmt in traced_statements)) + def CheckTraceCallbackError(self): + """ + Test behavior when exception raised in trace callback. + """ + con = sqlite.connect(":memory:") + def trace(statement): + raise Exception('uh-oh') + con.set_trace_callback(trace) + con.execute("create table foo(a, b)") + con.set_trace_callback(None) + def CheckClearTraceCallback(self): """ Test that setting the trace callback to None clears the previously set callback. @@ -247,24 +259,23 @@ class TraceCallbackTests(unittest.TestCase): "Unicode data %s garbled in trace callback: %s" % (ascii(unicode_value), ', '.join(map(ascii, traced_statements)))) - #@unittest.skipIf(sqlite.sqlite_version_info < (3, 3, 9), "sqlite3_prepare_v2 is not available") - #def CheckTraceCallbackContent(self): - # # set_trace_callback() shouldn't produce duplicate content (bpo-26187) - # traced_statements = [] - # def trace(statement): - # traced_statements.append(statement) + def CheckTraceCallbackContent(self): + # set_trace_callback() shouldn't produce duplicate content (bpo-26187) + traced_statements = [] + def trace(statement): + traced_statements.append(statement) - # queries = ["create table foo(x)", - # "insert into foo(x) values(1)"] - # self.addCleanup(unlink, TESTFN) - # con1 = sqlite.connect(TESTFN, isolation_level=None) - # con2 = sqlite.connect(TESTFN) - # con1.set_trace_callback(trace) - # cur = con1.cursor() - # cur.execute(queries[0]) - # con2.execute("create table bar(x)") - # cur.execute(queries[1]) - # self.assertEqual(traced_statements, queries) + queries = ["create table foo(x)", + "insert into foo(x) values(1)"] + self.addCleanup(os.unlink, 'tracecb.db') + con1 = sqlite.connect('tracecb.db', isolation_level=None) + con2 = sqlite.connect('tracecb.db') + con1.set_trace_callback(trace) + cur = con1.cursor() + cur.execute(queries[0]) + con2.execute("create table bar(x)") + cur.execute(queries[1]) + self.assertEqual(traced_statements, queries) def suite():