diff --git a/src/google/adk/auth/auth_credential.py b/src/google/adk/auth/auth_credential.py index 4a2add823c..747c21c987 100644 --- a/src/google/adk/auth/auth_credential.py +++ b/src/google/adk/auth/auth_credential.py @@ -87,6 +87,7 @@ class OAuth2Auth(BaseModelWithConfig): expires_at: int | None = None expires_in: int | None = None audience: str | None = None + prompt: str | None = None code_verifier: str | None = None code_challenge_method: str | None = None token_endpoint_auth_method: ( diff --git a/src/google/adk/auth/auth_handler.py b/src/google/adk/auth/auth_handler.py index 8e8f5d340b..36c82a7e84 100644 --- a/src/google/adk/auth/auth_handler.py +++ b/src/google/adk/auth/auth_handler.py @@ -197,7 +197,7 @@ def generate_auth_uri( ) params = { "access_type": "offline", - "prompt": "consent", + "prompt": auth_credential.oauth2.prompt or "consent", } if auth_credential.oauth2.audience: params["audience"] = auth_credential.oauth2.audience diff --git a/tests/unittests/auth/test_auth_handler.py b/tests/unittests/auth/test_auth_handler.py index c19a5d93fd..135fc492d5 100644 --- a/tests/unittests/auth/test_auth_handler.py +++ b/tests/unittests/auth/test_auth_handler.py @@ -66,6 +66,8 @@ def create_authorization_url(self, url, **kwargs): params = f"client_id={self.client_id}&scope={self.scope}" if kwargs.get("audience"): params += f"&audience={kwargs.get('audience')}" + if kwargs.get("prompt"): + params += f"&prompt={kwargs.get('prompt')}" return f"{url}?{params}", "mock_state" def fetch_token( @@ -250,6 +252,25 @@ def test_generate_auth_uri_with_audience_and_prompt( result = handler.generate_auth_uri() assert "audience=test_audience" in result.oauth2.auth_uri + assert "prompt=consent" in result.oauth2.auth_uri + + @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session) + def test_generate_auth_uri_with_custom_prompt( + self, openid_auth_scheme, oauth2_credentials + ): + """Test generating an auth URI with a custom prompt override.""" + oauth2_credentials.oauth2.prompt = "none" + exchanged = oauth2_credentials.model_copy(deep=True) + + config = AuthConfig( + auth_scheme=openid_auth_scheme, + raw_auth_credential=oauth2_credentials, + exchanged_auth_credential=exchanged, + ) + handler = AuthHandler(config) + result = handler.generate_auth_uri() + + assert "prompt=none" in result.oauth2.auth_uri @patch("google.adk.auth.auth_handler.OAuth2Session", MockOAuth2Session) def test_generate_auth_uri_openid(