Exclude pragma from is_dml logic and implicit transaction.

This commit is contained in:
Charles Leifer 2023-06-25 20:48:42 -05:00
parent 63c0189248
commit 5a97877e14
2 changed files with 18 additions and 2 deletions

View file

@ -104,7 +104,8 @@ int pysqlite_statement_create(pysqlite_Statement* self, pysqlite_Connection* con
PyOS_strnicmp(p, "alter", 5) && PyOS_strnicmp(p, "alter", 5) &&
PyOS_strnicmp(p, "analyze", 7) && PyOS_strnicmp(p, "analyze", 7) &&
PyOS_strnicmp(p, "reindex", 7) && PyOS_strnicmp(p, "reindex", 7) &&
PyOS_strnicmp(p, "vacuum", 6)); PyOS_strnicmp(p, "vacuum", 6) &&
PyOS_strnicmp(p, "pragma", 6));
break; break;
} }
} }

View file

@ -21,7 +21,7 @@
# misrepresented as being the original software. # misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution. # 3. This notice may not be removed or altered from any source distribution.
import os, unittest import glob, os, unittest
from sqlcipher3 import dbapi2 as sqlite from sqlcipher3 import dbapi2 as sqlite
def get_db_path(): def get_db_path():
@ -212,6 +212,14 @@ class DMLStatementDetectionTestCase(unittest.TestCase):
Use sqlite3_stmt_readonly to determine if the statement is DML or not. Use sqlite3_stmt_readonly to determine if the statement is DML or not.
""" """
def setUp(self):
for f in glob.glob(get_db_path() + '*'):
try:
os.unlink(f)
except OSError:
pass
tearDown = setUp
@unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3),
'needs sqlite 3.8.3 or newer') 'needs sqlite 3.8.3 or newer')
def test_dml_detection_cte(self): def test_dml_detection_cte(self):
@ -268,6 +276,13 @@ class DMLStatementDetectionTestCase(unittest.TestCase):
conn.execute('vacuum') conn.execute('vacuum')
self.assertFalse(conn.in_transaction) self.assertFalse(conn.in_transaction)
def test_dml_detection_vacuum(self):
conn = sqlite.connect(get_db_path())
conn.execute('pragma journal_mode=\'wal\'')
jmode, = conn.execute('pragma journal_mode').fetchone()
self.assertEqual(jmode, 'wal')
self.assertFalse(conn.in_transaction)
def suite(): def suite():
default_suite = unittest.makeSuite(TransactionTests, "Check") default_suite = unittest.makeSuite(TransactionTests, "Check")