Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions sdk/src/opendecree/_watcher_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,21 @@
class _WatchedFieldBase(Generic[T]):
"""Common state and helpers shared by WatchedField and AsyncWatchedField."""

def __init__(self, path: str, type_: type[T], default: T) -> None:
def __init__(
self,
path: str,
type_: type[T],
default: T,
*,
on_callback_error: Callable[[Exception], None] | None = None,
) -> None:
self._path = path
self._type = type_
self._default = default
self._value: T = default
self._is_set = False
self._callbacks: list[Callable[[T, T], None]] = []
self._on_callback_error = on_callback_error

@property
def path(self) -> str:
Expand Down Expand Up @@ -62,5 +70,8 @@ def _fire_callbacks(self, old: T, new: T) -> None:
for cb in self._callbacks:
try:
cb(old, new)
except Exception:
_logger.exception("Error in on_change callback for %s", self._path)
except Exception as exc:
if self._on_callback_error is not None:
self._on_callback_error(exc)
else:
_logger.exception("Error in on_change callback for %s", self._path)
27 changes: 22 additions & 5 deletions sdk/src/opendecree/async_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import asyncio
import logging
import random
from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Callable
from typing import Any, TypeVar

import grpc.aio
Expand All @@ -47,8 +47,15 @@ class AsyncWatchedField(_WatchedFieldBase[T]):
Updated automatically by the watcher's asyncio task.
"""

def __init__(self, path: str, type_: type[T], default: T) -> None:
super().__init__(path, type_, default)
def __init__(
self,
path: str,
type_: type[T],
default: T,
*,
on_callback_error: Callable[[Exception], None] | None = None,
) -> None:
super().__init__(path, type_, default, on_callback_error=on_callback_error)
self._change_queue: asyncio.Queue[Change | None] = asyncio.Queue()

@property
Expand Down Expand Up @@ -109,7 +116,14 @@ def __init__(
self._task: asyncio.Task | None = None # type: ignore[type-arg]
self._stopped = False

def field(self, path: str, type_: type[T], *, default: T) -> AsyncWatchedField[T]:
def field(
self,
path: str,
type_: type[T],
*,
default: T,
on_callback_error: Callable[[Exception], None] | None = None,
) -> AsyncWatchedField[T]:
"""Register a field to watch.

Must be called before the watcher is started (before __aenter__).
Expand All @@ -118,13 +132,16 @@ def field(self, path: str, type_: type[T], *, default: T) -> AsyncWatchedField[T
path: Dot-separated field path (e.g., "payments.fee").
type_: Python type to convert values to (str, int, float, bool, timedelta).
default: Default value when the field is null or not set.
on_callback_error: Optional hook called with the exception when an
on_change callback raises. If not set, the exception is logged.
The hook may re-raise to terminate the watcher's background task.

Returns:
An AsyncWatchedField that tracks the live value.
"""
if self._task is not None:
raise RuntimeError("Cannot register fields after watcher has started")
watched = AsyncWatchedField(path, type_, default)
watched = AsyncWatchedField(path, type_, default, on_callback_error=on_callback_error)
self._fields[path] = watched
return watched

Expand Down
27 changes: 22 additions & 5 deletions sdk/src/opendecree/watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import random
import threading
import time
from collections.abc import Iterator
from collections.abc import Callable, Iterator
from typing import Any, TypeVar

import grpc
Expand All @@ -46,8 +46,15 @@ class WatchedField(_WatchedFieldBase[T]):
Attributes are updated automatically by the watcher's background thread.
"""

def __init__(self, path: str, type_: type[T], default: T) -> None:
super().__init__(path, type_, default)
def __init__(
self,
path: str,
type_: type[T],
default: T,
*,
on_callback_error: Callable[[Exception], None] | None = None,
) -> None:
super().__init__(path, type_, default, on_callback_error=on_callback_error)
self._lock = threading.Lock()
self._change_queue: queue.Queue[Change] = queue.Queue()

Expand Down Expand Up @@ -113,7 +120,14 @@ def __init__(self, stub: Any, pb2: Any, tenant_id: str, timeout: float) -> None:
self._stream: grpc.Future | None = None
self._stop_event = threading.Event()

def field(self, path: str, type_: type[T], *, default: T) -> WatchedField[T]:
def field(
self,
path: str,
type_: type[T],
*,
default: T,
on_callback_error: Callable[[Exception], None] | None = None,
) -> WatchedField[T]:
"""Register a field to watch.

Must be called before the watcher is started (before __enter__).
Expand All @@ -122,13 +136,16 @@ def field(self, path: str, type_: type[T], *, default: T) -> WatchedField[T]:
path: Dot-separated field path (e.g., "payments.fee").
type_: Python type to convert values to (str, int, float, bool, timedelta).
default: Default value when the field is null or not set.
on_callback_error: Optional hook called with the exception when an
on_change callback raises. If not set, the exception is logged.
The hook may re-raise to terminate the watcher's background loop.

Returns:
A WatchedField that tracks the live value.
"""
if self._thread is not None:
raise RuntimeError("Cannot register fields after watcher has started")
watched = WatchedField(path, type_, default)
watched = WatchedField(path, type_, default, on_callback_error=on_callback_error)
self._fields[path] = watched
return watched

Expand Down
39 changes: 39 additions & 0 deletions sdk/tests/test_async_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,45 @@ def bad_cb(old: int, new: int) -> None:
f._update("2", change) # should not raise
assert f.value == 2

def test_on_callback_error_hook_is_called(self):
errors: list[Exception] = []
f = AsyncWatchedField("x", int, 0, on_callback_error=errors.append)
f._load_initial("1")

@f.on_change
def bad_cb(old: int, new: int) -> None:
raise ValueError("boom")

change = Change(field_path="x", old_value="1", new_value="2", version=1)
f._update("2", change)

assert len(errors) == 1
assert isinstance(errors[0], ValueError)
assert str(errors[0]) == "boom"
assert f.value == 2

def test_on_callback_error_hook_via_field_method(self):
errors: list[Exception] = []
stub = MagicMock()
pb2 = MagicMock()
mock_resp = MagicMock()
mock_resp.config.values = []
stub.GetConfig = AsyncMock(return_value=mock_resp)

w = AsyncConfigWatcher(stub, pb2, "t1", timeout=5.0)
f = w.field("x", int, default=0, on_callback_error=errors.append)
f._load_initial("1")

@f.on_change
def bad_cb(old: int, new: int) -> None:
raise RuntimeError("fail")

change = Change(field_path="x", old_value="1", new_value="2", version=1)
f._update("2", change)

assert len(errors) == 1
assert isinstance(errors[0], RuntimeError)


# --- AsyncConfigWatcher unit tests ---

Expand Down
43 changes: 43 additions & 0 deletions sdk/tests/test_watcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,49 @@ def bad_cb(old: int, new: int) -> None:
f._update("2", change)
assert f.value == 2

def test_on_callback_error_hook_is_called(self):
errors: list[Exception] = []
f = WatchedField("x", int, 0, on_callback_error=errors.append)
f._load_initial("1")

@f.on_change
def bad_cb(old: int, new: int) -> None:
raise ValueError("boom")

from opendecree.types import Change

change = Change(field_path="x", old_value="1", new_value="2", version=1)
f._update("2", change)

assert len(errors) == 1
assert isinstance(errors[0], ValueError)
assert str(errors[0]) == "boom"
assert f.value == 2

def test_on_callback_error_hook_via_field_method(self):
errors: list[Exception] = []
stub = MagicMock()
pb2 = MagicMock()
mock_resp = MagicMock()
mock_resp.config.values = []
stub.GetConfig.return_value = mock_resp

w = ConfigWatcher(stub, pb2, "t1", timeout=5.0)
f = w.field("x", int, default=0, on_callback_error=errors.append)
f._load_initial("1")

@f.on_change
def bad_cb(old: int, new: int) -> None:
raise RuntimeError("fail")

from opendecree.types import Change

change = Change(field_path="x", old_value="1", new_value="2", version=1)
f._update("2", change)

assert len(errors) == 1
assert isinstance(errors[0], RuntimeError)


# --- ConfigWatcher unit tests ---

Expand Down