diff --git a/livekit-rtc/livekit/rtc/data_track.py b/livekit-rtc/livekit/rtc/data_track.py index a56f1098..a28846b2 100644 --- a/livekit-rtc/livekit/rtc/data_track.py +++ b/livekit-rtc/livekit/rtc/data_track.py @@ -15,7 +15,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import AsyncIterator, Optional +from typing import AsyncIterator, NoReturn, Optional from ._ffi_client import FfiClient, FfiHandle from ._proto import ffi_pb2 as proto_ffi @@ -231,7 +231,10 @@ async def __anext__(self) -> DataTrackFrame: if self._closed: raise StopAsyncIteration - self._send_read_request() + eos = self._send_read_request() + if eos is not None: + self._handle_eos(eos) + event: proto_ffi.FfiEvent = await self._queue.get() stream_event = event.data_track_stream_event detail = stream_event.WhichOneof("detail") @@ -246,18 +249,34 @@ async def __anext__(self) -> DataTrackFrame: user_timestamp=user_ts, ) elif detail == "eos": - self._close() - if stream_event.eos.HasField("error"): - raise SubscribeDataTrackError(stream_event.eos.error.message) - raise StopAsyncIteration + self._handle_eos(stream_event.eos) else: self._close() raise StopAsyncIteration - def _send_read_request(self) -> None: + def _send_read_request(self) -> Optional[proto_data_track.DataTrackStreamEOS]: req = proto_ffi.FfiRequest() req.data_track_stream_read.stream_handle = self._ffi_handle.handle - FfiClient.instance.request(req) + resp = FfiClient.instance.request(req) + return self._read_response_eos(resp.data_track_stream_read) + + @staticmethod + def _read_response_eos( + read_response: proto_data_track.DataTrackStreamReadResponse, + ) -> Optional[proto_data_track.DataTrackStreamEOS]: + try: + if not read_response.HasField("eos"): + return None + except ValueError: + return None + + return getattr(read_response, "eos") + + def _handle_eos(self, eos: proto_data_track.DataTrackStreamEOS) -> NoReturn: + self._close() + if eos.HasField("error"): + raise SubscribeDataTrackError(eos.error.message) + raise StopAsyncIteration def _close(self) -> None: if not self._closed: