Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions src/google/adk/tools/function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,46 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
target_type = type_hints.get(param_name, param.annotation)
if target_type != inspect.Parameter.empty:

# Handle Optional[PydanticModel] types
# Handle Optional[PydanticModel] and Union[Model, Model, ...] types
if get_origin(param.annotation) is Union:
union_args = get_args(param.annotation)
# Find the non-None type in Optional[T] (which is Union[T, None])
# Separate None from the concrete types in the union.
non_none_types = [
arg for arg in union_args if arg is not type(None)
]
if len(non_none_types) == 1:
# Simple Optional[T] — unwrap and let the single-model path
# below handle the conversion.
target_type = non_none_types[0]
elif len(non_none_types) > 1 and isinstance(args[param_name], dict):
# Union[ModelA, ModelB, ...] — if the value is already an
# instance of one of the union members, leave it alone.
# Otherwise try each Pydantic branch in declaration order
# until one validates successfully.
if not any(
isinstance(args[param_name], t)
for t in non_none_types
if inspect.isclass(t)
):
pydantic_types = [
t
for t in non_none_types
if inspect.isclass(t) and issubclass(t, pydantic.BaseModel)
]
for candidate_type in pydantic_types:
try:
converted_args[param_name] = candidate_type.model_validate(
args[param_name]
)
break
except Exception:
continue
else:
logger.warning(
f"Failed to convert argument '{param_name}' to any of"
f' the Union Pydantic types: {pydantic_types}'
)
continue

# Check if the target type is a Pydantic model
if inspect.isclass(target_type) and issubclass(
Expand Down
111 changes: 111 additions & 0 deletions tests/unittests/tools/test_function_tool_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# Pydantic model conversion tests

from typing import Optional
from typing import Union
from unittest.mock import MagicMock

from google.adk.agents.invocation_context import InvocationContext
Expand All @@ -40,6 +41,14 @@ class PreferencesModel(pydantic.BaseModel):
notifications: bool = True


class CompanyModel(pydantic.BaseModel):
"""Test Pydantic model for company data."""

company_name: str
industry: str
employee_count: int


def sync_function_with_pydantic_model(user: UserModel) -> dict:
"""Sync function that takes a Pydantic model."""
return {
Expand Down Expand Up @@ -89,6 +98,17 @@ def function_with_mixed_args(
}


def function_with_union_entity(
entity: Union[UserModel, CompanyModel],
) -> str:
"""Function whose argument may be either of two Pydantic models."""
if isinstance(entity, UserModel):
return f"user:{entity.name}"
if isinstance(entity, CompanyModel):
return f"company:{entity.company_name}"
return f"unexpected:{type(entity)}"


def test_preprocess_args_with_dict_to_pydantic_conversion():
"""Test _preprocess_args converts dict to Pydantic model."""
tool = FunctionTool(sync_function_with_pydantic_model)
Expand Down Expand Up @@ -370,3 +390,94 @@ def place_order(orders: list[UserModel]) -> int:
result = await tool.run_async(args=args, tool_context=tool_context_mock)

assert result == 50


# ---------------------------------------------------------------------------
# Union[ModelA, ModelB] regression tests (issue #5799)
# ---------------------------------------------------------------------------


def test_preprocess_args_union_first_branch_converted():
"""A dict matching the first Union branch is converted to that model."""
tool = FunctionTool(function_with_union_entity)

processed = tool._preprocess_args({"entity": {"name": "Alice", "age": 30}})

assert isinstance(processed["entity"], UserModel), (
f"Expected UserModel, got {type(processed['entity'])}"
)
assert processed["entity"].name == "Alice"


def test_preprocess_args_union_second_branch_converted():
"""A dict matching the second Union branch is converted to that model."""
tool = FunctionTool(function_with_union_entity)

processed = tool._preprocess_args(
{
"entity": {
"company_name": "Acme",
"industry": "tech",
"employee_count": 100,
}
}
)

assert isinstance(processed["entity"], CompanyModel), (
f"Expected CompanyModel, got {type(processed['entity'])}"
)
assert processed["entity"].company_name == "Acme"


def test_preprocess_args_union_already_correct_type_unchanged():
"""If the arg is already a Pydantic instance it must not be re-converted."""
tool = FunctionTool(function_with_union_entity)

user = UserModel(name="Bob", age=25)
processed = tool._preprocess_args({"entity": user})

assert processed["entity"] is user


def test_preprocess_args_union_optional_backward_compat():
"""Optional[T] (Union[T, None]) conversion must still work."""
tool = FunctionTool(function_with_optional_pydantic_model)

processed = tool._preprocess_args({
"user": {"name": "Carol", "age": 40},
"preferences": {"theme": "light"},
})

assert isinstance(processed["user"], UserModel)
assert processed["user"].name == "Carol"
assert isinstance(processed["preferences"], PreferencesModel)


@pytest.mark.asyncio
async def test_run_async_union_end_to_end():
"""run_async end-to-end: Union dict arg must reach the function as a model."""
tool = FunctionTool(function_with_union_entity)

tool_context_mock = MagicMock(spec=ToolContext)
invocation_context_mock = MagicMock(spec=InvocationContext)
session_mock = MagicMock(spec=Session)
invocation_context_mock.session = session_mock
tool_context_mock.invocation_context = invocation_context_mock

result = await tool.run_async(
args={"entity": {"name": "Dave", "age": 20}},
tool_context=tool_context_mock,
)
assert result == "user:Dave"

result2 = await tool.run_async(
args={
"entity": {
"company_name": "Corp",
"industry": "finance",
"employee_count": 500,
}
},
tool_context=tool_context_mock,
)
assert result2 == "company:Corp"