From 03b08708b1af1ca8c75d07e47277a320faa70675 Mon Sep 17 00:00:00 2001 From: cloudwebrtc Date: Wed, 29 Apr 2026 21:29:30 +0800 Subject: [PATCH 1/2] E2E Test: data channel send/receive test. --- .github/workflows/tests.yml | 4 +- livekit-rtc/tests/test_dc.py | 373 +++++++++++++++++++++++++++++++++++ 2 files changed, 375 insertions(+), 2 deletions(-) create mode 100644 livekit-rtc/tests/test_dc.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5834e804..3b495c01 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -128,7 +128,7 @@ jobs: LIVEKIT_API_SECRET: ${{ secrets.LIVEKIT_API_SECRET }} run: | source .test-venv/bin/activate - pytest tests/ + pytest tests/ livekit-rtc/tests/ - name: Run tests (Windows) if: runner.os == 'Windows' @@ -136,6 +136,6 @@ jobs: LIVEKIT_URL: ${{ secrets.LIVEKIT_URL }} LIVEKIT_API_KEY: ${{ secrets.LIVEKIT_API_KEY }} LIVEKIT_API_SECRET: ${{ secrets.LIVEKIT_API_SECRET }} - run: .test-venv\Scripts\python.exe -m pytest tests/ + run: .test-venv\Scripts\python.exe -m pytest tests/ livekit-rtc/tests/ shell: pwsh \ No newline at end of file diff --git a/livekit-rtc/tests/test_dc.py b/livekit-rtc/tests/test_dc.py new file mode 100644 index 00000000..418802c5 --- /dev/null +++ b/livekit-rtc/tests/test_dc.py @@ -0,0 +1,373 @@ +"""End-to-end Test for data-channel scenarios. + +Covers one-to-one delivery, broadcasting to all, topic filtering, and targeted +delivery via `destination_identities`. + +Requires the following environment variables to run: + LIVEKIT_URL + LIVEKIT_API_KEY + LIVEKIT_API_SECRET +""" + +from __future__ import annotations + +import asyncio +import os +import uuid +from typing import Callable, Optional + +import pytest + +from livekit import api, rtc +from livekit.rtc.room import EventTypes + + +WAIT_TIMEOUT = 20.0 +WAIT_INTERVAL = 0.1 + + +def skip_if_no_credentials(): + required_vars = ["LIVEKIT_URL", "LIVEKIT_API_KEY", "LIVEKIT_API_SECRET"] + missing = [var for var in required_vars if not os.getenv(var)] + return pytest.mark.skipif( + bool(missing), reason=f"Missing environment variables: {', '.join(missing)}" + ) + + +def create_token(identity: str, room_name: str) -> str: + return ( + api.AccessToken() + .with_identity(identity) + .with_name(identity) + .with_grants( + api.VideoGrants( + room_join=True, + room=room_name, + ) + ) + .to_jwt() + ) + + +def unique_room_name(base: str) -> str: + return f"{base}-{uuid.uuid4().hex[:8]}" + + +async def _wait_until( + predicate: Callable[[], bool], + *, + timeout: float = WAIT_TIMEOUT, + interval: float = WAIT_INTERVAL, + message: str = "condition not met", +) -> None: + loop = asyncio.get_event_loop() + deadline = loop.time() + timeout + while loop.time() < deadline: + if predicate(): + return + await asyncio.sleep(interval) + raise AssertionError(f"timeout waiting: {message}") + + +async def _connect(room: rtc.Room, identity: str, room_name: str) -> str: + token = create_token(identity, room_name) + url = os.environ["LIVEKIT_URL"] + await room.connect(url, token) + return token + + +async def _ensure_all_connected(rooms: list[rtc.Room]) -> None: + await _wait_until( + lambda: all(r.connection_state == rtc.ConnectionState.CONN_CONNECTED for r in rooms), + message="not all rooms reached CONN_CONNECTED", + ) + + +async def _ensure_visible(observer: rtc.Room, identities: list[str]) -> None: + """Wait until `observer` sees every identity in `identities` as a remote participant. + + Targeted publishes resolve identities at publish time, so we must let the + sender's room state catch up before sending.""" + + def _all_visible() -> bool: + seen = {p.identity for p in observer.remote_participants.values()} + return all(ident in seen for ident in identities) + + await _wait_until( + _all_visible, + message=f"not all identities visible to {observer.local_participant.identity}: {identities}", + ) + + +def _expect_event( + room: rtc.Room, + event: EventTypes, + predicate: Optional[Callable[..., bool]] = None, +) -> asyncio.Future: + loop = asyncio.get_event_loop() + fut: asyncio.Future = loop.create_future() + + def _on_event(*args, **kwargs) -> None: + if fut.done(): + return + if predicate is None or predicate(*args, **kwargs): + fut.set_result(args) + + room.on(event, _on_event) + return fut + + +async def _await_event(fut: asyncio.Future, timeout: float = WAIT_TIMEOUT) -> None: + try: + await asyncio.wait_for(fut, timeout=timeout) + except asyncio.TimeoutError as e: + raise AssertionError("timed out waiting for event") from e + + +class _DataCollector: + """Collects `data_received` packets matching `sender_identity` (when set).""" + + def __init__(self, room: rtc.Room, sender_identity: Optional[str] = None) -> None: + self.packets: list[rtc.DataPacket] = [] + self._sender_identity = sender_identity + + def _on_data(packet: rtc.DataPacket) -> None: + if self._sender_identity is not None and ( + packet.participant is None or packet.participant.identity != self._sender_identity + ): + return + self.packets.append(packet) + + room.on("data_received", _on_data) + + def payloads(self) -> list[bytes]: + return [p.data for p in self.packets] + + def topics(self) -> list[str | None]: + return [p.topic for p in self.packets] + + +async def _assert_no_data( + room: rtc.Room, collector: _DataCollector, *, settle: float = 1.0 +) -> None: + """Give the server time to deliver, then assert nothing arrived.""" + await asyncio.sleep(settle) + assert collector.packets == [], ( + f"{room.local_participant.identity} unexpectedly received " + f"{len(collector.packets)} packet(s): {collector.payloads()}" + ) + + +@skip_if_no_credentials() +@pytest.mark.asyncio +async def test_data_one_to_one() -> None: + """sender targets a single identity; only that identity receives.""" + room_name = unique_room_name("py-dc-1to1") + + sender = rtc.Room() + receiver = rtc.Room() + bystander = rtc.Room() + + await _connect(sender, "sender", room_name) + await _connect(receiver, "receiver", room_name) + await _connect(bystander, "bystander", room_name) + await _ensure_all_connected([sender, receiver, bystander]) + await _ensure_visible(sender, ["receiver", "bystander"]) + + receiver_collector = _DataCollector(receiver, sender_identity="sender") + bystander_collector = _DataCollector(bystander, sender_identity="sender") + + receiver_got = _expect_event( + receiver, + "data_received", + predicate=lambda packet: ( + packet.participant is not None and packet.participant.identity == "sender" + ), + ) + + payload = b"hello receiver" + await sender.local_participant.publish_data(payload, destination_identities=["receiver"]) + + await _await_event(receiver_got) + assert receiver_collector.payloads() == [payload] + await _assert_no_data(bystander, bystander_collector) + + await asyncio.gather(sender.disconnect(), receiver.disconnect(), bystander.disconnect()) + + +@skip_if_no_credentials() +@pytest.mark.asyncio +async def test_data_one_to_many_targeted() -> None: + """sender targets a subset of identities; only that subset receives.""" + room_name = unique_room_name("py-dc-1tomany") + + sender = rtc.Room() + r1 = rtc.Room() + r2 = rtc.Room() + excluded = rtc.Room() + + await _connect(sender, "sender", room_name) + await _connect(r1, "r1", room_name) + await _connect(r2, "r2", room_name) + await _connect(excluded, "excluded", room_name) + await _ensure_all_connected([sender, r1, r2, excluded]) + await _ensure_visible(sender, ["r1", "r2", "excluded"]) + + r1_collector = _DataCollector(r1, sender_identity="sender") + r2_collector = _DataCollector(r2, sender_identity="sender") + excluded_collector = _DataCollector(excluded, sender_identity="sender") + + r1_got = _expect_event( + r1, + "data_received", + predicate=lambda packet: ( + packet.participant is not None and packet.participant.identity == "sender" + ), + ) + r2_got = _expect_event( + r2, + "data_received", + predicate=lambda packet: ( + packet.participant is not None and packet.participant.identity == "sender" + ), + ) + + payload = b"hello selected" + await sender.local_participant.publish_data(payload, destination_identities=["r1", "r2"]) + + await asyncio.gather(_await_event(r1_got), _await_event(r2_got)) + assert r1_collector.payloads() == [payload] + assert r2_collector.payloads() == [payload] + await _assert_no_data(excluded, excluded_collector) + + await asyncio.gather( + sender.disconnect(), r1.disconnect(), r2.disconnect(), excluded.disconnect() + ) + + +@skip_if_no_credentials() +@pytest.mark.asyncio +async def test_data_broadcast() -> None: + """Empty `destination_identities` broadcasts to every other participant.""" + room_name = unique_room_name("py-dc-broadcast") + + sender = rtc.Room() + receivers = [rtc.Room() for _ in range(3)] + receiver_idents = [f"r{i}" for i in range(len(receivers))] + + await _connect(sender, "sender", room_name) + for room, ident in zip(receivers, receiver_idents): + await _connect(room, ident, room_name) + await _ensure_all_connected([sender, *receivers]) + await _ensure_visible(sender, receiver_idents) + + collectors = [_DataCollector(room, sender_identity="sender") for room in receivers] + received_futures = [ + _expect_event( + room, + "data_received", + predicate=lambda packet: ( + packet.participant is not None and packet.participant.identity == "sender" + ), + ) + for room in receivers + ] + + payload = b"hello everyone" + await sender.local_participant.publish_data(payload) + + await asyncio.gather(*(_await_event(f) for f in received_futures)) + for ident, collector in zip(receiver_idents, collectors): + assert collector.payloads() == [payload], f"{ident} payloads mismatch" + + await asyncio.gather(sender.disconnect(), *(r.disconnect() for r in receivers)) + + +@skip_if_no_credentials() +@pytest.mark.asyncio +async def test_data_topic_passthrough() -> None: + """Topic field is preserved end-to-end and observable by every receiver.""" + room_name = unique_room_name("py-dc-topic") + + sender = rtc.Room() + r1 = rtc.Room() + r2 = rtc.Room() + + await _connect(sender, "sender", room_name) + await _connect(r1, "r1", room_name) + await _connect(r2, "r2", room_name) + await _ensure_all_connected([sender, r1, r2]) + await _ensure_visible(sender, ["r1", "r2"]) + + r1_collector = _DataCollector(r1, sender_identity="sender") + r2_collector = _DataCollector(r2, sender_identity="sender") + + # Send three messages: two on "chat", one on "telemetry". + messages = [ + (b"chat-1", "chat"), + (b"telemetry-1", "telemetry"), + (b"chat-2", "chat"), + ] + + def _all_received(collector: _DataCollector) -> bool: + return len(collector.packets) >= len(messages) + + for payload, topic in messages: + await sender.local_participant.publish_data(payload, topic=topic) + + await _wait_until( + lambda: _all_received(r1_collector) and _all_received(r2_collector), + message="receivers did not get all topic messages", + ) + + expected_pairs = [(payload, topic) for payload, topic in messages] + for collector, ident in [(r1_collector, "r1"), (r2_collector, "r2")]: + got = list(zip(collector.payloads(), collector.topics())) + assert got == expected_pairs, f"{ident} mismatch: expected {expected_pairs}, got {got}" + + # Also verify `chat`-only filtering at the consumer side works as expected. + chat_only_r1 = [p for p in r1_collector.packets if p.topic == "chat"] + assert [p.data for p in chat_only_r1] == [b"chat-1", b"chat-2"] + + await asyncio.gather(sender.disconnect(), r1.disconnect(), r2.disconnect()) + + +@skip_if_no_credentials() +@pytest.mark.asyncio +async def test_data_targeted_with_topic() -> None: + """Targeted send carries the topic; non-targets receive nothing.""" + room_name = unique_room_name("py-dc-targeted-topic") + + sender = rtc.Room() + target = rtc.Room() + other = rtc.Room() + + await _connect(sender, "sender", room_name) + await _connect(target, "target", room_name) + await _connect(other, "other", room_name) + await _ensure_all_connected([sender, target, other]) + await _ensure_visible(sender, ["target", "other"]) + + target_collector = _DataCollector(target, sender_identity="sender") + other_collector = _DataCollector(other, sender_identity="sender") + + target_got = _expect_event( + target, + "data_received", + predicate=lambda packet: ( + packet.participant is not None and packet.participant.identity == "sender" + ), + ) + + payload = b"private ping" + topic = "private" + await sender.local_participant.publish_data( + payload, destination_identities=["target"], topic=topic + ) + + await _await_event(target_got) + assert target_collector.payloads() == [payload] + assert target_collector.topics() == [topic] + await _assert_no_data(other, other_collector) + + await asyncio.gather(sender.disconnect(), target.disconnect(), other.disconnect()) From 7c92e21b5f38b5a7ded2e78ebdeb06f937aceb03 Mon Sep 17 00:00:00 2001 From: cloudwebrtc Date: Fri, 22 May 2026 18:39:19 +0800 Subject: [PATCH 2/2] add more shared code. --- .../tests/test_change_video_quality.py | 90 ++------ livekit-rtc/tests/test_dc.py | 198 +++++------------- .../tests/test_e2ee_per_participant.py | 18 +- livekit-rtc/tests/test_e2ee_shared_key.py | 18 +- livekit-rtc/tests/utils.py | 167 ++++++++++++++- 5 files changed, 233 insertions(+), 258 deletions(-) diff --git a/livekit-rtc/tests/test_change_video_quality.py b/livekit-rtc/tests/test_change_video_quality.py index debe0587..9e2b2e2a 100644 --- a/livekit-rtc/tests/test_change_video_quality.py +++ b/livekit-rtc/tests/test_change_video_quality.py @@ -42,20 +42,25 @@ import os import sys import time -from typing import Any, Callable, Optional, Tuple +from typing import Optional, Tuple import numpy as np import pytest from livekit import rtc from livekit.rtc._proto.track_publication_pb2 import VideoQuality -from livekit.rtc.room import EventTypes -from utils import create_token, skip_if_no_credentials, unique_room_name # type: ignore[import-not-found] +from utils import ( # type: ignore[import-not-found] + await_event, + create_token, + ensure_rooms_all_connected, + ensure_track_subscribed, + expect_room_event, + skip_if_no_credentials, + unique_room_name, +) -WAIT_TIMEOUT = 30.0 -WAIT_INTERVAL = 0.1 PUBLISH_WIDTH = 1280 PUBLISH_HEIGHT = 720 PUBLISH_FPS = 15 @@ -74,73 +79,6 @@ ] -async def _wait_until( - predicate: Callable[[], bool], - *, - timeout: float = WAIT_TIMEOUT, - interval: float = WAIT_INTERVAL, - message: str = "condition not met", -) -> None: - loop = asyncio.get_running_loop() - deadline = loop.time() + timeout - while loop.time() < deadline: - if predicate(): - return - await asyncio.sleep(interval) - raise AssertionError(f"timeout waiting: {message}") - - -async def _ensure_all_connected(rooms: list[rtc.Room]) -> None: - await _wait_until( - lambda: all(r.connection_state == rtc.ConnectionState.CONN_CONNECTED for r in rooms), - message="not all rooms reached CONN_CONNECTED", - ) - - -async def _ensure_track_subscribed(room: rtc.Room, track_sid: str) -> rtc.RemoteTrackPublication: - holder: dict[str, rtc.RemoteTrackPublication] = {} - - def _has_subscribed() -> bool: - for participant in room.remote_participants.values(): - pub = participant.track_publications.get(track_sid) - if pub is not None and pub.subscribed and pub.track is not None: - holder["pub"] = pub - return True - return False - - await _wait_until( - _has_subscribed, - message=f"room did not subscribe to track {track_sid}", - ) - return holder["pub"] - - -def _expect_event( - room: rtc.Room, - event: EventTypes, - predicate: Optional[Callable[..., bool]] = None, -) -> asyncio.Future: - loop = asyncio.get_running_loop() - fut: asyncio.Future = loop.create_future() - - def _on_event(*args: Any, **kwargs: Any) -> None: - if fut.done(): - return - if predicate is None or predicate(*args, **kwargs): - fut.set_result(args) - room.off(event, _on_event) - - room.on(event, _on_event) - return fut - - -async def _await_event(fut: asyncio.Future, timeout: float = WAIT_TIMEOUT) -> None: - try: - await asyncio.wait_for(fut, timeout=timeout) - except asyncio.TimeoutError as e: - raise AssertionError("timed out waiting for event") from e - - def _make_rolling_i420(width: int, height: int, t: float) -> rtc.VideoFrame: """Build a 1280x720 I420 frame containing 8 vertical color bars that scroll horizontally over time, so the encoder always sees motion.""" @@ -259,7 +197,7 @@ async def test_simulcast_quality_layers( sender, receiver = rtc.Room(), rtc.Room() await sender.connect(url, create_token("sender", room_name)) await receiver.connect(url, create_token("receiver", room_name)) - await _ensure_all_connected([sender, receiver]) + await ensure_rooms_all_connected([sender, receiver]) source = rtc.VideoSource(PUBLISH_WIDTH, PUBLISH_HEIGHT) track = rtc.LocalVideoTrack.create_video_track(f"{mode}-{codec_name}", source) @@ -282,13 +220,13 @@ async def test_simulcast_quality_layers( stream: Optional[rtc.VideoStream] = None try: - track_published = _expect_event( + track_published = expect_room_event( receiver, "track_published", predicate=lambda pub, _p: pub.kind == rtc.TrackKind.KIND_VIDEO, ) local_pub = await sender.local_participant.publish_track(track, options) - await _await_event(track_published) + await await_event(track_published) print( f"[{codec_name}] local_pub: sid={local_pub.sid} " @@ -296,7 +234,7 @@ async def test_simulcast_quality_layers( f"mime_type={local_pub.mime_type} " f"{local_pub.width}x{local_pub.height}" ) - remote_pub = await _ensure_track_subscribed(receiver, local_pub.sid) + remote_pub = await ensure_track_subscribed(receiver, local_pub.sid) assert remote_pub.track is not None print( f"[{codec_name}] remote_pub: sid={remote_pub.sid} " diff --git a/livekit-rtc/tests/test_dc.py b/livekit-rtc/tests/test_dc.py index 418802c5..aff64e3f 100644 --- a/livekit-rtc/tests/test_dc.py +++ b/livekit-rtc/tests/test_dc.py @@ -12,116 +12,22 @@ from __future__ import annotations import asyncio -import os -import uuid -from typing import Callable, Optional +from typing import Optional import pytest -from livekit import api, rtc -from livekit.rtc.room import EventTypes +from livekit import rtc - -WAIT_TIMEOUT = 20.0 -WAIT_INTERVAL = 0.1 - - -def skip_if_no_credentials(): - required_vars = ["LIVEKIT_URL", "LIVEKIT_API_KEY", "LIVEKIT_API_SECRET"] - missing = [var for var in required_vars if not os.getenv(var)] - return pytest.mark.skipif( - bool(missing), reason=f"Missing environment variables: {', '.join(missing)}" - ) - - -def create_token(identity: str, room_name: str) -> str: - return ( - api.AccessToken() - .with_identity(identity) - .with_name(identity) - .with_grants( - api.VideoGrants( - room_join=True, - room=room_name, - ) - ) - .to_jwt() - ) - - -def unique_room_name(base: str) -> str: - return f"{base}-{uuid.uuid4().hex[:8]}" - - -async def _wait_until( - predicate: Callable[[], bool], - *, - timeout: float = WAIT_TIMEOUT, - interval: float = WAIT_INTERVAL, - message: str = "condition not met", -) -> None: - loop = asyncio.get_event_loop() - deadline = loop.time() + timeout - while loop.time() < deadline: - if predicate(): - return - await asyncio.sleep(interval) - raise AssertionError(f"timeout waiting: {message}") - - -async def _connect(room: rtc.Room, identity: str, room_name: str) -> str: - token = create_token(identity, room_name) - url = os.environ["LIVEKIT_URL"] - await room.connect(url, token) - return token - - -async def _ensure_all_connected(rooms: list[rtc.Room]) -> None: - await _wait_until( - lambda: all(r.connection_state == rtc.ConnectionState.CONN_CONNECTED for r in rooms), - message="not all rooms reached CONN_CONNECTED", - ) - - -async def _ensure_visible(observer: rtc.Room, identities: list[str]) -> None: - """Wait until `observer` sees every identity in `identities` as a remote participant. - - Targeted publishes resolve identities at publish time, so we must let the - sender's room state catch up before sending.""" - - def _all_visible() -> bool: - seen = {p.identity for p in observer.remote_participants.values()} - return all(ident in seen for ident in identities) - - await _wait_until( - _all_visible, - message=f"not all identities visible to {observer.local_participant.identity}: {identities}", - ) - - -def _expect_event( - room: rtc.Room, - event: EventTypes, - predicate: Optional[Callable[..., bool]] = None, -) -> asyncio.Future: - loop = asyncio.get_event_loop() - fut: asyncio.Future = loop.create_future() - - def _on_event(*args, **kwargs) -> None: - if fut.done(): - return - if predicate is None or predicate(*args, **kwargs): - fut.set_result(args) - - room.on(event, _on_event) - return fut - - -async def _await_event(fut: asyncio.Future, timeout: float = WAIT_TIMEOUT) -> None: - try: - await asyncio.wait_for(fut, timeout=timeout) - except asyncio.TimeoutError as e: - raise AssertionError("timed out waiting for event") from e +from utils import ( # type: ignore[import-not-found] + await_event, + connect_room, + ensure_participants_visible, + ensure_rooms_all_connected, + expect_room_event, + skip_if_no_credentials, + unique_room_name, + wait_until, +) class _DataCollector: @@ -158,7 +64,7 @@ async def _assert_no_data( ) -@skip_if_no_credentials() +@skip_if_no_credentials() # type: ignore[untyped-decorator] @pytest.mark.asyncio async def test_data_one_to_one() -> None: """sender targets a single identity; only that identity receives.""" @@ -168,16 +74,16 @@ async def test_data_one_to_one() -> None: receiver = rtc.Room() bystander = rtc.Room() - await _connect(sender, "sender", room_name) - await _connect(receiver, "receiver", room_name) - await _connect(bystander, "bystander", room_name) - await _ensure_all_connected([sender, receiver, bystander]) - await _ensure_visible(sender, ["receiver", "bystander"]) + await connect_room("sender", room_name, room=sender) + await connect_room("receiver", room_name, room=receiver) + await connect_room("bystander", room_name, room=bystander) + await ensure_rooms_all_connected([sender, receiver, bystander]) + await ensure_participants_visible(sender, ["receiver", "bystander"]) receiver_collector = _DataCollector(receiver, sender_identity="sender") bystander_collector = _DataCollector(bystander, sender_identity="sender") - receiver_got = _expect_event( + receiver_got = expect_room_event( receiver, "data_received", predicate=lambda packet: ( @@ -188,14 +94,14 @@ async def test_data_one_to_one() -> None: payload = b"hello receiver" await sender.local_participant.publish_data(payload, destination_identities=["receiver"]) - await _await_event(receiver_got) + await await_event(receiver_got) assert receiver_collector.payloads() == [payload] await _assert_no_data(bystander, bystander_collector) await asyncio.gather(sender.disconnect(), receiver.disconnect(), bystander.disconnect()) -@skip_if_no_credentials() +@skip_if_no_credentials() # type: ignore[untyped-decorator] @pytest.mark.asyncio async def test_data_one_to_many_targeted() -> None: """sender targets a subset of identities; only that subset receives.""" @@ -206,25 +112,25 @@ async def test_data_one_to_many_targeted() -> None: r2 = rtc.Room() excluded = rtc.Room() - await _connect(sender, "sender", room_name) - await _connect(r1, "r1", room_name) - await _connect(r2, "r2", room_name) - await _connect(excluded, "excluded", room_name) - await _ensure_all_connected([sender, r1, r2, excluded]) - await _ensure_visible(sender, ["r1", "r2", "excluded"]) + await connect_room("sender", room_name, room=sender) + await connect_room("r1", room_name, room=r1) + await connect_room("r2", room_name, room=r2) + await connect_room("excluded", room_name, room=excluded) + await ensure_rooms_all_connected([sender, r1, r2, excluded]) + await ensure_participants_visible(sender, ["r1", "r2", "excluded"]) r1_collector = _DataCollector(r1, sender_identity="sender") r2_collector = _DataCollector(r2, sender_identity="sender") excluded_collector = _DataCollector(excluded, sender_identity="sender") - r1_got = _expect_event( + r1_got = expect_room_event( r1, "data_received", predicate=lambda packet: ( packet.participant is not None and packet.participant.identity == "sender" ), ) - r2_got = _expect_event( + r2_got = expect_room_event( r2, "data_received", predicate=lambda packet: ( @@ -235,7 +141,7 @@ async def test_data_one_to_many_targeted() -> None: payload = b"hello selected" await sender.local_participant.publish_data(payload, destination_identities=["r1", "r2"]) - await asyncio.gather(_await_event(r1_got), _await_event(r2_got)) + await asyncio.gather(await_event(r1_got), await_event(r2_got)) assert r1_collector.payloads() == [payload] assert r2_collector.payloads() == [payload] await _assert_no_data(excluded, excluded_collector) @@ -245,7 +151,7 @@ async def test_data_one_to_many_targeted() -> None: ) -@skip_if_no_credentials() +@skip_if_no_credentials() # type: ignore[untyped-decorator] @pytest.mark.asyncio async def test_data_broadcast() -> None: """Empty `destination_identities` broadcasts to every other participant.""" @@ -255,15 +161,15 @@ async def test_data_broadcast() -> None: receivers = [rtc.Room() for _ in range(3)] receiver_idents = [f"r{i}" for i in range(len(receivers))] - await _connect(sender, "sender", room_name) + await connect_room("sender", room_name, room=sender) for room, ident in zip(receivers, receiver_idents): - await _connect(room, ident, room_name) - await _ensure_all_connected([sender, *receivers]) - await _ensure_visible(sender, receiver_idents) + await connect_room(ident, room_name, room=room) + await ensure_rooms_all_connected([sender, *receivers]) + await ensure_participants_visible(sender, receiver_idents) collectors = [_DataCollector(room, sender_identity="sender") for room in receivers] received_futures = [ - _expect_event( + expect_room_event( room, "data_received", predicate=lambda packet: ( @@ -276,14 +182,14 @@ async def test_data_broadcast() -> None: payload = b"hello everyone" await sender.local_participant.publish_data(payload) - await asyncio.gather(*(_await_event(f) for f in received_futures)) + await asyncio.gather(*(await_event(f) for f in received_futures)) for ident, collector in zip(receiver_idents, collectors): assert collector.payloads() == [payload], f"{ident} payloads mismatch" await asyncio.gather(sender.disconnect(), *(r.disconnect() for r in receivers)) -@skip_if_no_credentials() +@skip_if_no_credentials() # type: ignore[untyped-decorator] @pytest.mark.asyncio async def test_data_topic_passthrough() -> None: """Topic field is preserved end-to-end and observable by every receiver.""" @@ -293,11 +199,11 @@ async def test_data_topic_passthrough() -> None: r1 = rtc.Room() r2 = rtc.Room() - await _connect(sender, "sender", room_name) - await _connect(r1, "r1", room_name) - await _connect(r2, "r2", room_name) - await _ensure_all_connected([sender, r1, r2]) - await _ensure_visible(sender, ["r1", "r2"]) + await connect_room("sender", room_name, room=sender) + await connect_room("r1", room_name, room=r1) + await connect_room("r2", room_name, room=r2) + await ensure_rooms_all_connected([sender, r1, r2]) + await ensure_participants_visible(sender, ["r1", "r2"]) r1_collector = _DataCollector(r1, sender_identity="sender") r2_collector = _DataCollector(r2, sender_identity="sender") @@ -315,7 +221,7 @@ def _all_received(collector: _DataCollector) -> bool: for payload, topic in messages: await sender.local_participant.publish_data(payload, topic=topic) - await _wait_until( + await wait_until( lambda: _all_received(r1_collector) and _all_received(r2_collector), message="receivers did not get all topic messages", ) @@ -332,7 +238,7 @@ def _all_received(collector: _DataCollector) -> bool: await asyncio.gather(sender.disconnect(), r1.disconnect(), r2.disconnect()) -@skip_if_no_credentials() +@skip_if_no_credentials() # type: ignore[untyped-decorator] @pytest.mark.asyncio async def test_data_targeted_with_topic() -> None: """Targeted send carries the topic; non-targets receive nothing.""" @@ -342,16 +248,16 @@ async def test_data_targeted_with_topic() -> None: target = rtc.Room() other = rtc.Room() - await _connect(sender, "sender", room_name) - await _connect(target, "target", room_name) - await _connect(other, "other", room_name) - await _ensure_all_connected([sender, target, other]) - await _ensure_visible(sender, ["target", "other"]) + await connect_room("sender", room_name, room=sender) + await connect_room("target", room_name, room=target) + await connect_room("other", room_name, room=other) + await ensure_rooms_all_connected([sender, target, other]) + await ensure_participants_visible(sender, ["target", "other"]) target_collector = _DataCollector(target, sender_identity="sender") other_collector = _DataCollector(other, sender_identity="sender") - target_got = _expect_event( + target_got = expect_room_event( target, "data_received", predicate=lambda packet: ( @@ -365,7 +271,7 @@ async def test_data_targeted_with_topic() -> None: payload, destination_identities=["target"], topic=topic ) - await _await_event(target_got) + await await_event(target_got) assert target_collector.payloads() == [payload] assert target_collector.topics() == [topic] await _assert_no_data(other, other_collector) diff --git a/livekit-rtc/tests/test_e2ee_per_participant.py b/livekit-rtc/tests/test_e2ee_per_participant.py index 8062662e..ed5da639 100644 --- a/livekit-rtc/tests/test_e2ee_per_participant.py +++ b/livekit-rtc/tests/test_e2ee_per_participant.py @@ -26,6 +26,7 @@ from utils import ( # type: ignore[import-not-found] assert_eventually, create_token, + publish_dummy_video, skip_if_no_credentials, unique_room_name, ) @@ -68,23 +69,6 @@ def set_key_index_on_all_cryptors(room: rtc.Room, key_index: int) -> None: cryptor.set_key_index(key_index) -async def publish_dummy_video(source: rtc.VideoSource, stop_event: asyncio.Event) -> None: - """Continuously publish frames until stop_event is set.""" - pixel_count = WIDTH * HEIGHT - frame_idx = 0 - while not stop_event.is_set(): - fill = frame_idx % 256 - pixel = bytes((255, fill, (fill + 85) % 256, (fill + 170) % 256)) - buf = pixel * pixel_count - frame = rtc.VideoFrame(WIDTH, HEIGHT, rtc.VideoBufferType.ARGB, buf) - source.capture_frame(frame) - frame_idx += 1 - try: - await asyncio.wait_for(stop_event.wait(), timeout=1.0 / FRAME_RATE) - except asyncio.TimeoutError: - pass - - @pytest.mark.asyncio @skip_if_no_credentials() # type: ignore[untyped-decorator] async def test_e2ee_per_participant() -> None: diff --git a/livekit-rtc/tests/test_e2ee_shared_key.py b/livekit-rtc/tests/test_e2ee_shared_key.py index 4334bb78..230793f9 100644 --- a/livekit-rtc/tests/test_e2ee_shared_key.py +++ b/livekit-rtc/tests/test_e2ee_shared_key.py @@ -26,6 +26,7 @@ from utils import ( # type: ignore[import-not-found] assert_eventually, create_token, + publish_dummy_video, skip_if_no_credentials, unique_room_name, ) @@ -47,23 +48,6 @@ def make_e2ee_options() -> rtc.E2EEOptions: return options -async def publish_dummy_video(source: rtc.VideoSource, stop_event: asyncio.Event) -> None: - """Continuously publish frames until stop_event is set.""" - pixel_count = WIDTH * HEIGHT - frame_idx = 0 - while not stop_event.is_set(): - fill = frame_idx % 256 - pixel = bytes((255, fill, (fill + 85) % 256, (fill + 170) % 256)) - buf = pixel * pixel_count - frame = rtc.VideoFrame(WIDTH, HEIGHT, rtc.VideoBufferType.ARGB, buf) - source.capture_frame(frame) - frame_idx += 1 - try: - await asyncio.wait_for(stop_event.wait(), timeout=1.0 / FRAME_RATE) - except asyncio.TimeoutError: - pass - - @pytest.mark.asyncio @skip_if_no_credentials() # type: ignore[untyped-decorator] async def test_e2ee_shared_key() -> None: diff --git a/livekit-rtc/tests/utils.py b/livekit-rtc/tests/utils.py index 07f51c84..a6b099ff 100644 --- a/livekit-rtc/tests/utils.py +++ b/livekit-rtc/tests/utils.py @@ -19,15 +19,18 @@ import asyncio import os import uuid -from typing import Callable, TypeVar +from typing import Any, Callable, Optional, TypeVar import pytest -from livekit import api +from livekit import api, rtc +from livekit.rtc.room import EventTypes T = TypeVar("T") _REQUIRED_ENV_VARS = ("LIVEKIT_URL", "LIVEKIT_API_KEY", "LIVEKIT_API_SECRET") +DEFAULT_WAIT_TIMEOUT = 30.0 +DEFAULT_WAIT_INTERVAL = 0.1 def skip_if_no_credentials() -> pytest.MarkDecorator: @@ -74,3 +77,163 @@ async def assert_eventually( return last_result await asyncio.sleep(interval) raise AssertionError(f"{message} (last result: {last_result})") + + +async def wait_until( + predicate: Callable[[], bool], + *, + timeout: float = DEFAULT_WAIT_TIMEOUT, + interval: float = DEFAULT_WAIT_INTERVAL, + message: str = "condition not met", +) -> None: + """Poll `predicate()` until it returns True or `timeout` elapses.""" + loop = asyncio.get_running_loop() + deadline = loop.time() + timeout + while loop.time() < deadline: + if predicate(): + return + await asyncio.sleep(interval) + raise AssertionError(f"timeout waiting: {message}") + + +async def ensure_rooms_all_connected( + rooms: list[rtc.Room], + *, + timeout: float = DEFAULT_WAIT_TIMEOUT, +) -> None: + """Wait until every `Room` in `rooms` reaches CONN_CONNECTED.""" + await wait_until( + lambda: all(r.connection_state == rtc.ConnectionState.CONN_CONNECTED for r in rooms), + timeout=timeout, + message="not all rooms reached CONN_CONNECTED", + ) + + +async def ensure_participants_visible( + observer: rtc.Room, + identities: list[str], + *, + timeout: float = DEFAULT_WAIT_TIMEOUT, +) -> None: + """Wait until `observer` sees every identity in `identities` as a remote participant.""" + + def _all_visible() -> bool: + seen = {p.identity for p in observer.remote_participants.values()} + return all(ident in seen for ident in identities) + + await wait_until( + _all_visible, + timeout=timeout, + message=( + f"not all identities visible to {observer.local_participant.identity}: {identities}" + ), + ) + + +def expect_room_event( + room: rtc.Room, + event: EventTypes, + predicate: Optional[Callable[..., bool]] = None, +) -> "asyncio.Future[tuple[Any, ...]]": + """Register a one-shot handler for `event` returning a `Future` resolved with the event args. + + `predicate` (if given) filters events; the handler is unregistered after the future resolves. + """ + loop = asyncio.get_running_loop() + fut: "asyncio.Future[tuple[Any, ...]]" = loop.create_future() + + def _on_event(*args: Any, **kwargs: Any) -> None: + if fut.done(): + return + if predicate is None or predicate(*args, **kwargs): + fut.set_result(args) + room.off(event, _on_event) + + room.on(event, _on_event) + return fut + + +async def await_event( + fut: "asyncio.Future[Any]", + timeout: float = DEFAULT_WAIT_TIMEOUT, +) -> None: + """Await a future from `expect_room_event` with a descriptive timeout failure.""" + try: + await asyncio.wait_for(fut, timeout=timeout) + except asyncio.TimeoutError as e: + raise AssertionError("timed out waiting for event") from e + + +async def connect_room( + identity: str, + room_name: str, + *, + room: Optional[rtc.Room] = None, + options: Optional[rtc.RoomOptions] = None, +) -> rtc.Room: + """Build a token for `identity`/`room_name` and connect. + + If `room` is provided it is connected in place; otherwise a fresh `rtc.Room` + is created. Returns the connected room. + """ + if room is None: + room = rtc.Room() + url = os.environ["LIVEKIT_URL"] + token = create_token(identity, room_name) + if options is None: + await room.connect(url, token) + else: + await room.connect(url, token, options=options) + return room + + +async def ensure_track_subscribed( + room: rtc.Room, + track_sid: str, + *, + timeout: float = DEFAULT_WAIT_TIMEOUT, +) -> rtc.RemoteTrackPublication: + """Wait until some remote participant in `room` has subscribed to `track_sid`.""" + holder: dict[str, rtc.RemoteTrackPublication] = {} + + def _has_subscribed() -> bool: + for participant in room.remote_participants.values(): + pub = participant.track_publications.get(track_sid) + if pub is not None and pub.subscribed and pub.track is not None: + holder["pub"] = pub + return True + return False + + await wait_until( + _has_subscribed, + timeout=timeout, + message=f"room did not subscribe to track {track_sid}", + ) + return holder["pub"] + + +async def publish_dummy_video( + source: rtc.VideoSource, + stop_event: asyncio.Event, + *, + width: int = 320, + height: int = 180, + fps: int = 15, +) -> None: + """Continuously capture ARGB frames into `source` until `stop_event` is set. + + Pixel values vary frame-to-frame so encryption tests see distinct ciphertexts. + """ + pixel_count = width * height + frame_idx = 0 + while not stop_event.is_set(): + fill = frame_idx % 256 + pixel = bytes((255, fill, (fill + 85) % 256, (fill + 170) % 256)) + buf = pixel * pixel_count + frame = rtc.VideoFrame(width, height, rtc.VideoBufferType.ARGB, buf) + source.capture_frame(frame) + frame_idx += 1 + try: + await asyncio.wait_for(stop_event.wait(), timeout=1.0 / fps) + except asyncio.TimeoutError: + pass