Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions src/google/adk/sessions/database_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,13 +744,15 @@ async def append_event(self, session: Session, event: Event) -> Event:
storage_session.update_time = update_time
sql_session.add(schema.StorageEvent.from_event(session, event))

# Read revision fields before commit. Post-commit ORM attribute access
# can lazy-load expired columns and trigger MissingGreenlet with asyncpg
# when pool_pre_ping is enabled.
last_update_time = storage_session.get_update_timestamp(is_sqlite)
storage_update_marker = storage_session.get_update_marker()
await sql_session.commit()

# Update timestamp with commit time
session.last_update_time = storage_session.get_update_timestamp(
is_sqlite
)
session._storage_update_marker = storage_session.get_update_marker()
session.last_update_time = last_update_time
session._storage_update_marker = storage_update_marker

# Also update the in-memory session
await super().append_event(session=session, event=event)
Expand Down
79 changes: 79 additions & 0 deletions tests/unittests/sessions/test_session_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,6 +1148,29 @@ def __getattr__(self, name):
return getattr(self._real, name)


class _CommitOrderSpySession:
"""SQLAlchemy session spy that marks when commit() has completed."""

def __init__(self, real_session, on_committed):
self._real = real_session
self._on_committed = on_committed

async def __aenter__(self):
self._real = await self._real.__aenter__()
return self

async def __aexit__(self, *args):
return await self._real.__aexit__(*args)

async def commit(self):
result = await self._real.commit()
self._on_committed()
return result

def __getattr__(self, name):
return getattr(self._real, name)


@pytest.mark.asyncio
async def test_create_session_calls_rollback_on_commit_failure():
"""Verifies that a commit failure during create_session triggers an explicit
Expand Down Expand Up @@ -1246,6 +1269,62 @@ def _spy_factory():
await service.close()


@pytest.mark.asyncio
async def test_append_event_reads_storage_revision_before_commit():
"""append_event captures session revision before commit completes."""
service = DatabaseSessionService('sqlite+aiosqlite:///:memory:')
schema = service._get_schema_classes()
original_get_update_timestamp = schema.StorageSession.get_update_timestamp
original_get_update_marker = schema.StorageSession.get_update_marker
revision_read_state = {'committed': False, 'post_commit_reads': 0}

def _track_revision_read(original):
def wrapper(self, *args, **kwargs):
if revision_read_state['committed']:
revision_read_state['post_commit_reads'] += 1
return original(self, *args, **kwargs)

return wrapper

schema.StorageSession.get_update_timestamp = _track_revision_read(
original_get_update_timestamp
)
schema.StorageSession.get_update_marker = _track_revision_read(
original_get_update_marker
)

try:
session = await service.create_session(
app_name='app', user_id='user', session_id='s1'
)
event_timestamp = session.last_update_time + 10
event = Event(
invocation_id='inv1',
author='user',
timestamp=event_timestamp,
)

original_factory = service.database_session_factory

def _spy_factory():
return _CommitOrderSpySession(
original_factory(),
on_committed=lambda: revision_read_state.update({'committed': True}),
)

service.database_session_factory = _spy_factory

await service.append_event(session, event)

assert revision_read_state['post_commit_reads'] == 0
assert session.last_update_time == pytest.approx(event_timestamp, abs=1e-6)
assert session._storage_update_marker is not None
finally:
schema.StorageSession.get_update_timestamp = original_get_update_timestamp
schema.StorageSession.get_update_marker = original_get_update_marker
await service.close()


@pytest.mark.asyncio
async def test_delete_session_calls_rollback_on_commit_failure():
"""Verifies that a commit failure during delete_session triggers an explicit
Expand Down