diff --git a/sdk/src/opendecree/_watcher_base.py b/sdk/src/opendecree/_watcher_base.py index eee17c4..9714920 100644 --- a/sdk/src/opendecree/_watcher_base.py +++ b/sdk/src/opendecree/_watcher_base.py @@ -20,13 +20,21 @@ class _WatchedFieldBase(Generic[T]): """Common state and helpers shared by WatchedField and AsyncWatchedField.""" - def __init__(self, path: str, type_: type[T], default: T) -> None: + def __init__( + self, + path: str, + type_: type[T], + default: T, + *, + on_callback_error: Callable[[Exception], None] | None = None, + ) -> None: self._path = path self._type = type_ self._default = default self._value: T = default self._is_set = False self._callbacks: list[Callable[[T, T], None]] = [] + self._on_callback_error = on_callback_error @property def path(self) -> str: @@ -62,5 +70,8 @@ def _fire_callbacks(self, old: T, new: T) -> None: for cb in self._callbacks: try: cb(old, new) - except Exception: - _logger.exception("Error in on_change callback for %s", self._path) + except Exception as exc: + if self._on_callback_error is not None: + self._on_callback_error(exc) + else: + _logger.exception("Error in on_change callback for %s", self._path) diff --git a/sdk/src/opendecree/async_watcher.py b/sdk/src/opendecree/async_watcher.py index 09bc2b3..4721b76 100644 --- a/sdk/src/opendecree/async_watcher.py +++ b/sdk/src/opendecree/async_watcher.py @@ -21,7 +21,7 @@ import asyncio import logging import random -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Callable from typing import Any, TypeVar import grpc.aio @@ -47,8 +47,15 @@ class AsyncWatchedField(_WatchedFieldBase[T]): Updated automatically by the watcher's asyncio task. """ - def __init__(self, path: str, type_: type[T], default: T) -> None: - super().__init__(path, type_, default) + def __init__( + self, + path: str, + type_: type[T], + default: T, + *, + on_callback_error: Callable[[Exception], None] | None = None, + ) -> None: + super().__init__(path, type_, default, on_callback_error=on_callback_error) self._change_queue: asyncio.Queue[Change | None] = asyncio.Queue() @property @@ -109,7 +116,14 @@ def __init__( self._task: asyncio.Task | None = None # type: ignore[type-arg] self._stopped = False - def field(self, path: str, type_: type[T], *, default: T) -> AsyncWatchedField[T]: + def field( + self, + path: str, + type_: type[T], + *, + default: T, + on_callback_error: Callable[[Exception], None] | None = None, + ) -> AsyncWatchedField[T]: """Register a field to watch. Must be called before the watcher is started (before __aenter__). @@ -118,13 +132,16 @@ def field(self, path: str, type_: type[T], *, default: T) -> AsyncWatchedField[T path: Dot-separated field path (e.g., "payments.fee"). type_: Python type to convert values to (str, int, float, bool, timedelta). default: Default value when the field is null or not set. + on_callback_error: Optional hook called with the exception when an + on_change callback raises. If not set, the exception is logged. + The hook may re-raise to terminate the watcher's background task. Returns: An AsyncWatchedField that tracks the live value. """ if self._task is not None: raise RuntimeError("Cannot register fields after watcher has started") - watched = AsyncWatchedField(path, type_, default) + watched = AsyncWatchedField(path, type_, default, on_callback_error=on_callback_error) self._fields[path] = watched return watched diff --git a/sdk/src/opendecree/watcher.py b/sdk/src/opendecree/watcher.py index f7c2cc2..32a295d 100644 --- a/sdk/src/opendecree/watcher.py +++ b/sdk/src/opendecree/watcher.py @@ -20,7 +20,7 @@ import random import threading import time -from collections.abc import Iterator +from collections.abc import Callable, Iterator from typing import Any, TypeVar import grpc @@ -46,8 +46,15 @@ class WatchedField(_WatchedFieldBase[T]): Attributes are updated automatically by the watcher's background thread. """ - def __init__(self, path: str, type_: type[T], default: T) -> None: - super().__init__(path, type_, default) + def __init__( + self, + path: str, + type_: type[T], + default: T, + *, + on_callback_error: Callable[[Exception], None] | None = None, + ) -> None: + super().__init__(path, type_, default, on_callback_error=on_callback_error) self._lock = threading.Lock() self._change_queue: queue.Queue[Change] = queue.Queue() @@ -113,7 +120,14 @@ def __init__(self, stub: Any, pb2: Any, tenant_id: str, timeout: float) -> None: self._stream: grpc.Future | None = None self._stop_event = threading.Event() - def field(self, path: str, type_: type[T], *, default: T) -> WatchedField[T]: + def field( + self, + path: str, + type_: type[T], + *, + default: T, + on_callback_error: Callable[[Exception], None] | None = None, + ) -> WatchedField[T]: """Register a field to watch. Must be called before the watcher is started (before __enter__). @@ -122,13 +136,16 @@ def field(self, path: str, type_: type[T], *, default: T) -> WatchedField[T]: path: Dot-separated field path (e.g., "payments.fee"). type_: Python type to convert values to (str, int, float, bool, timedelta). default: Default value when the field is null or not set. + on_callback_error: Optional hook called with the exception when an + on_change callback raises. If not set, the exception is logged. + The hook may re-raise to terminate the watcher's background loop. Returns: A WatchedField that tracks the live value. """ if self._thread is not None: raise RuntimeError("Cannot register fields after watcher has started") - watched = WatchedField(path, type_, default) + watched = WatchedField(path, type_, default, on_callback_error=on_callback_error) self._fields[path] = watched return watched diff --git a/sdk/tests/test_async_watcher.py b/sdk/tests/test_async_watcher.py index c149112..8c0decb 100644 --- a/sdk/tests/test_async_watcher.py +++ b/sdk/tests/test_async_watcher.py @@ -101,6 +101,45 @@ def bad_cb(old: int, new: int) -> None: f._update("2", change) # should not raise assert f.value == 2 + def test_on_callback_error_hook_is_called(self): + errors: list[Exception] = [] + f = AsyncWatchedField("x", int, 0, on_callback_error=errors.append) + f._load_initial("1") + + @f.on_change + def bad_cb(old: int, new: int) -> None: + raise ValueError("boom") + + change = Change(field_path="x", old_value="1", new_value="2", version=1) + f._update("2", change) + + assert len(errors) == 1 + assert isinstance(errors[0], ValueError) + assert str(errors[0]) == "boom" + assert f.value == 2 + + def test_on_callback_error_hook_via_field_method(self): + errors: list[Exception] = [] + stub = MagicMock() + pb2 = MagicMock() + mock_resp = MagicMock() + mock_resp.config.values = [] + stub.GetConfig = AsyncMock(return_value=mock_resp) + + w = AsyncConfigWatcher(stub, pb2, "t1", timeout=5.0) + f = w.field("x", int, default=0, on_callback_error=errors.append) + f._load_initial("1") + + @f.on_change + def bad_cb(old: int, new: int) -> None: + raise RuntimeError("fail") + + change = Change(field_path="x", old_value="1", new_value="2", version=1) + f._update("2", change) + + assert len(errors) == 1 + assert isinstance(errors[0], RuntimeError) + # --- AsyncConfigWatcher unit tests --- diff --git a/sdk/tests/test_watcher.py b/sdk/tests/test_watcher.py index 2f9f107..d609381 100644 --- a/sdk/tests/test_watcher.py +++ b/sdk/tests/test_watcher.py @@ -121,6 +121,49 @@ def bad_cb(old: int, new: int) -> None: f._update("2", change) assert f.value == 2 + def test_on_callback_error_hook_is_called(self): + errors: list[Exception] = [] + f = WatchedField("x", int, 0, on_callback_error=errors.append) + f._load_initial("1") + + @f.on_change + def bad_cb(old: int, new: int) -> None: + raise ValueError("boom") + + from opendecree.types import Change + + change = Change(field_path="x", old_value="1", new_value="2", version=1) + f._update("2", change) + + assert len(errors) == 1 + assert isinstance(errors[0], ValueError) + assert str(errors[0]) == "boom" + assert f.value == 2 + + def test_on_callback_error_hook_via_field_method(self): + errors: list[Exception] = [] + stub = MagicMock() + pb2 = MagicMock() + mock_resp = MagicMock() + mock_resp.config.values = [] + stub.GetConfig.return_value = mock_resp + + w = ConfigWatcher(stub, pb2, "t1", timeout=5.0) + f = w.field("x", int, default=0, on_callback_error=errors.append) + f._load_initial("1") + + @f.on_change + def bad_cb(old: int, new: int) -> None: + raise RuntimeError("fail") + + from opendecree.types import Change + + change = Change(field_path="x", old_value="1", new_value="2", version=1) + f._update("2", change) + + assert len(errors) == 1 + assert isinstance(errors[0], RuntimeError) + # --- ConfigWatcher unit tests ---