Skip to content

Commit

Permalink
Merge pull request #2769 from nomme/reconnect-transaction-begin
Browse files Browse the repository at this point in the history
Cover transaction begin in ReconnectMixin
  • Loading branch information
coleifer committed Aug 18, 2023
2 parents 59435f4 + ecf7812 commit 5bc1bf2
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
10 changes: 8 additions & 2 deletions playhouse/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,14 @@ def __init__(self, *args, **kwargs):
def execute_sql(self, sql, params=None, commit=None):
if commit is not None:
__deprecated__('"commit" has been deprecated and is a no-op.')
return self._reconnect(super(ReconnectMixin, self).execute_sql, sql, params)

def begin(self):
return self._reconnect(super(ReconnectMixin, self).begin)

def _reconnect(self, func, *args, **kwargs):
try:
return super(ReconnectMixin, self).execute_sql(sql, params)
return func(*args, **kwargs)
except Exception as exc:
# If we are in a transaction, do not reconnect silently as
# any changes could be lost.
Expand All @@ -275,7 +281,7 @@ def execute_sql(self, sql, params=None, commit=None):
self.close()
self.connect()

return super(ReconnectMixin, self).execute_sql(sql, params)
return func(*args, **kwargs)


def resolve_multimodel_query(query, key='_model_identifier'):
Expand Down
38 changes: 36 additions & 2 deletions tests/shortcuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def cursor(self, named_cursor=None):

# The first (0th) query fails, as do all queries after the 2nd (1st).
if self._query_counter != 1:
def _fake_execute(self, _):
def _fake_execute(self, *args):
raise OperationalError('2006')
cursor.execute = _fake_execute
self._query_counter += 1
Expand All @@ -601,7 +601,7 @@ def _reset_mock(self):
class TestReconnectMixin(DatabaseTestCase):
database = db_loader('mysql', db_class=ReconnectMySQLDatabase)

def test_reconnect_mixin(self):
def test_reconnect_mixin_execute_sql(self):
# Verify initial state.
self.database._reset_mock()
self.assertEqual(self.database._close_counter, 0)
Expand All @@ -625,6 +625,40 @@ def test_reconnect_mixin(self):
self.assertEqual(self.database._close_counter, 1)


def test_reconnect_mixin_begin(self):
# Verify initial state.
self.database._reset_mock()
self.assertEqual(self.database._close_counter, 0)

with self.database.atomic():
self.assertTrue(self.database.in_transaction())
self.assertEqual(self.database._close_counter, 1)
# Prepare mock for commit call
self.database._query_counter = 1

# Due to how we configured our mock, our queries are now failing and we
# can verify a reconnect is occuring *AND* the exception is propagated.
self.assertRaises(OperationalError, self.database.atomic().__enter__)
self.assertEqual(self.database._close_counter, 2)
self.assertFalse(self.database.in_transaction())

# We reset the mock counters. The first query we execute will fail. The
# second query will succeed (which happens automatically, thanks to the
# retry logic).
self.database._reset_mock()
with self.database.atomic():
self.assertTrue(self.database.in_transaction())
self.assertEqual(self.database._close_counter, 1)

# Do not reconnect when nesting transactions
self.assertRaises(OperationalError, self.database.atomic().__enter__)
self.assertEqual(self.database._close_counter, 1)

# Prepare mock for commit call
self.database._query_counter = 1
self.assertFalse(self.database.in_transaction())


class MMA(TestModel):
key = TextField()
value = IntegerField()
Expand Down

0 comments on commit 5bc1bf2

Please sign in to comment.