diff --git a/simvue/api/objects/alert/base.py b/simvue/api/objects/alert/base.py index 8d1258a0..13a8e960 100644 --- a/simvue/api/objects/alert/base.py +++ b/simvue/api/objects/alert/base.py @@ -235,7 +235,9 @@ def get_status(self, run_id: str) -> typing.Literal["ok", "critical"]: ) _url: URL = self.url / f"status/{run_id}" - _response = sv_get(url=f"{_url}", headers=self._headers) + _response = sv_get( + url=f"{_url}", headers=self._headers, verify=self._user_config.server_verify + ) _json_response = get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK], diff --git a/simvue/api/objects/alert/fetch.py b/simvue/api/objects/alert/fetch.py index 2d6d503f..217f744a 100644 --- a/simvue/api/objects/alert/fetch.py +++ b/simvue/api/objects/alert/fetch.py @@ -169,6 +169,7 @@ def get( f"{_url}", headers=_config.headers, params=_params | kwargs, + verify=_config.server_verify, ) _label: str = cls.__name__.lower() diff --git a/simvue/api/objects/alert/user.py b/simvue/api/objects/alert/user.py index 3e9dedac..40a337ff 100644 --- a/simvue/api/objects/alert/user.py +++ b/simvue/api/objects/alert/user.py @@ -149,6 +149,7 @@ def set_status(self, run_id: str, status: typing.Literal["ok", "critical"]) -> N url=self.url / "status" / run_id, data={"status": status}, headers=self._headers, + verify=self._user_config.server_verify, ) get_json_from_response( diff --git a/simvue/api/objects/artifact/base.py b/simvue/api/objects/artifact/base.py index a1ef4356..c482932d 100644 --- a/simvue/api/objects/artifact/base.py +++ b/simvue/api/objects/artifact/base.py @@ -8,6 +8,7 @@ import http import io import logging +import pathlib import typing import pydantic @@ -101,6 +102,7 @@ def attach_to_run(self, run_id: str, category: Category) -> None: url=f"{_run_artifacts_url}", headers=self._headers, json={"category": category}, + verify=self._user_config.server_verify, ) get_json_from_response( @@ -146,6 +148,7 @@ def _upload(self, file: io.BytesIO, timeout: int, file_size: int) -> None: params={}, is_json=False, timeout=timeout, + verify=self.storage_ca_cert, files={"file": file}, data=_fields, ) @@ -157,6 +160,7 @@ def _upload(self, file: io.BytesIO, timeout: int, file_size: int) -> None: headers={}, is_json=False, timeout=timeout, + verify=self.storage_ca_cert, data=file, ) @@ -189,6 +193,15 @@ def _get( **kwargs, ) + @property + def storage_ca_cert(self) -> str | bool: + """Return current storage CA certificate.""" + _ca_cert: pathlib.Path | bool = ( + self._user_config.server.certificates.storage_ca_cert + ) + + return f"{_ca_cert}" if isinstance(_ca_cert, pathlib.Path) else _ca_cert + @property def checksum(self) -> str: """Retrieve the checksum for this artifact. @@ -327,7 +340,9 @@ def get_category(self, run_id: str) -> Category: URL(self._user_config.server.url) / f"runs/{run_id}/artifacts/{self._identifier}" ) - _response = sv_get(url=_run_url, header=self._headers) + _response = sv_get( + url=_run_url, header=self._headers, verify=self._user_config.server_verify + ) _json_response = get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK, http.HTTPStatus.NOT_FOUND], @@ -367,6 +382,7 @@ def download_content(self) -> Generator[bytes]: _response = sv_get( f"{self.download_url}", timeout=_timeout, + verify=self.storage_ca_cert, headers=None, ) diff --git a/simvue/api/objects/artifact/fetch.py b/simvue/api/objects/artifact/fetch.py index b15f7f77..dfb4859c 100644 --- a/simvue/api/objects/artifact/fetch.py +++ b/simvue/api/objects/artifact/fetch.py @@ -133,7 +133,10 @@ def from_run( ) _url = URL(f"{_config.server.url}") / f"runs/{run_id}/artifacts" _response = sv_get( - url=f"{_url}", params={"category": category}, headers=_config.headers + url=f"{_url}", + params={"category": category}, + headers=_config.headers, + verify=_config.server_verify, ) _json_response = get_json_from_response( expected_type=list, @@ -196,7 +199,10 @@ def from_name( ) _url = URL(f"{_config.server.url}") / f"runs/{run_id}/artifacts" _response = sv_get( - url=f"{_url}", params={"name": name}, headers=_config.headers + url=f"{_url}", + params={"name": name}, + headers=_config.headers, + verify=_config.server_verify, ) _json_response = get_json_from_response( expected_type=list, @@ -275,6 +281,7 @@ def get( _url, headers=_config.headers, params=_params | kwargs, + verify=_config.server_verify, ) _label: str = cls.__name__.lower() _label = _label.replace("base", "") diff --git a/simvue/api/objects/base.py b/simvue/api/objects/base.py index 4a729571..6fe31a24 100644 --- a/simvue/api/objects/base.py +++ b/simvue/api/objects/base.py @@ -658,6 +658,7 @@ def _post_batch( headers=self._headers | {"Content-Type": "application/msgpack"}, params=self._params or {}, data=batch_data, + verify=self._user_config.server_verify, is_json=True, ) @@ -698,6 +699,7 @@ def _post_single( params=self._params or {}, data=data or kwargs, is_json=is_json, + verify=self._user_config.server_verify, ) if _response.status_code == http.HTTPStatus.FORBIDDEN: @@ -735,7 +737,11 @@ def _put(self, **kwargs) -> dict[str, typing.Any]: _ = kwargs.pop(key, None) _response = sv_put( - url=f"{self.url}", headers=self._headers, data=kwargs, is_json=True + url=f"{self.url}", + headers=self._headers, + data=kwargs, + is_json=True, + verify=self._user_config.server_verify, ) if _response.status_code == http.HTTPStatus.FORBIDDEN: @@ -769,7 +775,12 @@ def delete(self, **kwargs) -> dict[str, typing.Any]: if not self.url: raise RuntimeError(f"Identifier for instance of {self.label()} Unknown") - _response = sv_delete(url=f"{self.url}", headers=self._headers, params=kwargs) + _response = sv_delete( + url=f"{self.url}", + headers=self._headers, + params=kwargs, + verify=self._user_config.server_verify, + ) _json_response = get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK, http.HTTPStatus.NO_CONTENT], @@ -789,7 +800,10 @@ def _get( raise RuntimeError(f"Identifier for instance of {self.label()} Unknown") _response = sv_get( - url=f"{url or self.url}", headers=self._headers, params=kwargs + url=f"{url or self.url}", + headers=self._headers, + params=kwargs, + verify=self._user_config.server_verify, ) if _response.status_code == http.HTTPStatus.NOT_FOUND: diff --git a/simvue/api/objects/events.py b/simvue/api/objects/events.py index 066c635e..5c9e7355 100644 --- a/simvue/api/objects/events.py +++ b/simvue/api/objects/events.py @@ -193,6 +193,7 @@ def histogram( _response = sv_get( url=_url, headers=self._headers, + verify=self._user_config.server_verify, params={ "run": self._run_id, "window": window, diff --git a/simvue/api/objects/folder.py b/simvue/api/objects/folder.py index 03943336..6734e7d7 100644 --- a/simvue/api/objects/folder.py +++ b/simvue/api/objects/folder.py @@ -295,7 +295,10 @@ def _set_favourite(self, *, starred: bool) -> dict: """Set starred status.""" _url = self.url / "starred" _response = sv_put( - f"{_url}", headers=self._user_config.headers, data={"starred": starred} + f"{_url}", + headers=self._user_config.headers, + data={"starred": starred}, + verify=self._user_config.server_verify, ) return get_json_from_response( expected_status=[http.HTTPStatus.OK], diff --git a/simvue/api/objects/grids.py b/simvue/api/objects/grids.py index de0f156f..33c24f38 100644 --- a/simvue/api/objects/grids.py +++ b/simvue/api/objects/grids.py @@ -101,6 +101,7 @@ def attach_metric_for_run(self, run_id: str, metric_name: str) -> None: _response = sv_put( url=f"{self.run_data_url(run_id)}", headers=self._headers, + verify=self._user_config.server_verify, json={"metric": metric_name}, ) @@ -240,6 +241,7 @@ def get_run_metric_values( _response = sv_get( url=f"{self.run_metric_url(run_id, metric_name) / 'values'}", headers=self._headers, + verify=self._user_config.server_verify, params={"step": step}, ) @@ -271,6 +273,7 @@ def get_run_metric_span(self, *, run_id: str, metric_name: str) -> dict: """ _response = sv_get( url=f"{self.run_metric_url(run_id, metric_name) / 'span'}", + verify=self._user_config.server_verify, headers=self._headers, ) @@ -476,6 +479,7 @@ def _log_values(self, metrics: list[GridMetricSet]) -> None: url=f"{self._user_config.server.url}/{self.run_grids_endpoint(self._run_id)}", headers=self._headers | {"Content-Type": "application/msgpack"}, data=msgpack.packb(metrics, use_bin_type=True), + verify=self._user_config.server_verify, is_json=False, params={}, ) diff --git a/simvue/api/objects/metrics.py b/simvue/api/objects/metrics.py index 12b20f2c..4dc8f983 100644 --- a/simvue/api/objects/metrics.py +++ b/simvue/api/objects/metrics.py @@ -159,7 +159,12 @@ def get( def span(self, run_ids: list[str]) -> dict[str, int | float]: """Returns the metrics span for the given runs""" _url = self._base_url / "span" - _response = sv_get(url=f"{_url}", headers=self._headers, json=run_ids) + _response = sv_get( + url=f"{_url}", + headers=self._headers, + json=run_ids, + verify=self._user_config.server_verify, + ) return get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK], @@ -171,7 +176,10 @@ def names(self, run_ids: list[str]) -> list[str]: """Returns the metric names for the given runs""" _url = self._base_url / "names" _response = sv_get( - url=f"{_url}", headers=self._headers, params={"runs": json.dumps(run_ids)} + url=f"{_url}", + headers=self._headers, + params={"runs": json.dumps(run_ids)}, + verify=self._user_config.server_verify, ) return get_json_from_response( response=_response, diff --git a/simvue/api/objects/run.py b/simvue/api/objects/run.py index b4d894f8..f319ae15 100644 --- a/simvue/api/objects/run.py +++ b/simvue/api/objects/run.py @@ -582,7 +582,10 @@ def _set_favourite(self, *, starred: bool) -> dict: """Set starred status.""" _url = self.url / "starred" _response = sv_put( - f"{_url}", headers=self._user_config.headers, data={"starred": starred} + f"{_url}", + headers=self._user_config.headers, + data={"starred": starred}, + verify=self._user_config.server_verify, ) return get_json_from_response( expected_status=[http.HTTPStatus.OK], @@ -670,7 +673,12 @@ def send_heartbeat(self) -> dict[str, typing.Any] | None: _url = self._base_url _url /= f"{self._identifier}/heartbeat" - _response = sv_put(f"{_url}", headers=self._headers, data={}) + _response = sv_put( + f"{_url}", + headers=self._headers, + data={}, + verify=self._user_config.server_verify, + ) return get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK], @@ -709,7 +717,11 @@ def abort_trigger(self) -> bool: if self._offline or not self._identifier: return False - _response = sv_get(f"{self._abort_url}", headers=self._headers) + _response = sv_get( + f"{self._abort_url}", + headers=self._headers, + verify=self._user_config.server_verify, + ) _json_response = get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK], @@ -729,7 +741,11 @@ def artifacts(self) -> list[dict[str, typing.Any]]: if self._offline or not self._artifact_url: return [] - _response = sv_get(url=self._artifact_url, headers=self._headers) + _response = sv_get( + url=self._artifact_url, + headers=self._headers, + verify=self._user_config.server_verify, + ) return get_json_from_response( response=_response, @@ -750,7 +766,11 @@ def grids(self) -> list[dict[str, str]]: if self._offline or not self._grid_url: return [] - _response = sv_get(url=self._grid_url, headers=self._headers) + _response = sv_get( + url=self._grid_url, + headers=self._headers, + verify=self._user_config.server_verify, + ) return get_json_from_response( response=_response, @@ -778,7 +798,10 @@ def abort(self, reason: str) -> dict[str, typing.Any]: raise RuntimeError("Cannot abort run, no endpoint defined") _response = sv_put( - f"{self._abort_url}", headers=self._headers, data={"reason": reason} + f"{self._abort_url}", + headers=self._headers, + data={"reason": reason}, + verify=self._user_config.server_verify, ) return get_json_from_response( diff --git a/simvue/api/objects/stats.py b/simvue/api/objects/stats.py index a121d4db..bc78803d 100644 --- a/simvue/api/objects/stats.py +++ b/simvue/api/objects/stats.py @@ -137,7 +137,9 @@ def whoami(self) -> dict[str, str]: server response for 'whomai' query. """ _url: URL = URL(self._user_config.server.url) / "whoami" - _response = sv_get(url=f"{_url}", headers=self._headers) + _response = sv_get( + url=f"{_url}", headers=self._headers, verify=self._user_config.server_verify + ) return get_json_from_response( response=_response, expected_status=[http.HTTPStatus.OK], diff --git a/simvue/api/objects/storage/fetch.py b/simvue/api/objects/storage/fetch.py index 1468e9f3..b5aee965 100644 --- a/simvue/api/objects/storage/fetch.py +++ b/simvue/api/objects/storage/fetch.py @@ -100,6 +100,7 @@ def get( _url, headers=_class_instance._headers, params={"start": offset, "count": count} | kwargs, + verify=_class_instance._user_config.server_verify, ) _label: str = _class_instance.__class__.__name__.lower() _label = _label.replace("base", "") diff --git a/simvue/api/request.py b/simvue/api/request.py index 5a568588..540fdf85 100644 --- a/simvue/api/request.py +++ b/simvue/api/request.py @@ -65,6 +65,7 @@ def post( data: typing.Any, is_json: bool = True, timeout: int | None = None, + verify: str | bool = True, files: dict[str, typing.Any] | None = None, ) -> requests.Response: """HTTP POST with retries @@ -101,6 +102,7 @@ def post( data=data_sent, timeout=timeout, files=files, + verify=verify, ) if response.status_code == http.HTTPStatus.UNPROCESSABLE_ENTITY: @@ -135,6 +137,7 @@ def put( data: dict[str, typing.Any] | None = None, json: dict[str, typing.Any] | None = None, is_json: bool = True, + verify: bool | str = True, timeout: int = DEFAULT_API_TIMEOUT, ) -> requests.Response: """HTTP PUT with retries @@ -168,7 +171,12 @@ def put( logging.debug(f"PUT: {url}\n\tdata={data_sent}\n\tjson={json}") response = requests.put( - url, headers=headers, data=data_sent, timeout=timeout, json=json + url, + headers=headers, + data=data_sent, + timeout=timeout, + json=json, + verify=verify, ) if response.status_code in RETRY_STATUSES: @@ -196,6 +204,7 @@ def get( headers: dict[str, str] | None = None, params: dict[str, str | int | float | None] | None = None, timeout: int = DEFAULT_API_TIMEOUT, + verify: str | bool = True, json: dict[str, typing.Any] | None = None, ) -> requests.Response: """HTTP GET @@ -218,7 +227,7 @@ def get( """ logging.debug(f"GET: {url}\n\tparams={params}") response = requests.get( - url, headers=headers, timeout=timeout, params=params, json=json + url, headers=headers, timeout=timeout, params=params, json=json, verify=verify ) if response.status_code in RETRY_STATUSES: @@ -245,6 +254,7 @@ def delete( url: str, headers: dict[str, str], timeout: int = DEFAULT_API_TIMEOUT, + verify: str | bool = True, params: dict[str, typing.Any] | None = None, ) -> requests.Response: """HTTP DELETE @@ -266,7 +276,9 @@ def delete( response from executing DELETE """ logging.debug(f"DELETE: {url}\n\tparams={params}") - response = requests.delete(url, headers=headers, timeout=timeout, params=params) + response = requests.delete( + url, headers=headers, timeout=timeout, params=params, verify=verify + ) if response.status_code in RETRY_STATUSES: raise RetryableHTTPError( diff --git a/simvue/config/parameters.py b/simvue/config/parameters.py index 67f67d79..ffa790dc 100644 --- a/simvue/config/parameters.py +++ b/simvue/config/parameters.py @@ -21,6 +21,29 @@ logger = logging.getLogger(__file__) +class CertificateSpecifications(pydantic.BaseModel): + storage_ca_cert: pydantic.FilePath | bool = True + server_ca_cert: pydantic.FilePath | bool = True + client_cert: pydantic.FilePath | None = None + client_key: pydantic.SecretStr | None = None + + @pydantic.model_validator(mode="before") + @classmethod + def check_for_cert_env( + cls, values: dict[str, pathlib.Path | None | str] + ) -> dict[str, pathlib.Path | None | str]: + """Check for CA certificate for storage specification in environment.""" + if ( + _env_ca_cert := os.environ.get("SIMVUE_STORAGE_CA_CERTIFICATE") + ) is not None: + values["storage_ca_cert"] = _env_ca_cert + if (_env_ca_cert := os.environ.get("SIMVUE_SERVER_CA_CERTIFICATE")) is not None: + values["server_ca_cert"] = _env_ca_cert + if _env_client_cert := os.environ.get("SIMVUE_SERVER_CLIENT_CERTIFICATE"): + values["client_cert"] = _env_ca_cert + return values + + class ServerSpecifications(pydantic.BaseModel): model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict( extra="forbid", @@ -29,6 +52,9 @@ class ServerSpecifications(pydantic.BaseModel): url: pydantic.AnyHttpUrl | None token: pydantic.SecretStr | None env: dict[str, str] | None = None + certificates: CertificateSpecifications = pydantic.Field( + default_factory=CertificateSpecifications + ) @pydantic.field_validator("url") @classmethod @@ -76,7 +102,7 @@ class DefaultRunSpecifications(pydantic.BaseModel): name: str | None = None description: str | None = None tags: list[str] | None = None - folder: str = pydantic.Field("/", pattern=sv_models.FOLDER_REGEX) + folder: str = pydantic.Field(default="/", pattern=sv_models.FOLDER_REGEX) metadata: dict[str, str | int | float | bool] | None = None mode: typing.Literal["offline", "disabled", "online"] = "online" record_shell_vars: list[str] | None = None diff --git a/simvue/config/user.py b/simvue/config/user.py index 067298e7..09c48aee 100644 --- a/simvue/config/user.py +++ b/simvue/config/user.py @@ -69,6 +69,12 @@ class SimvueConfiguration(pydantic.BaseModel): current_profile: str | None = None _server_version: semver.Version | None = None + @property + def server_verify(self) -> str | bool: + """Return current server CA certificate.""" + _ca_cert: pathlib.Path | bool = self.server.certificates.server_ca_cert + return f"{_ca_cert}" if isinstance(_ca_cert, pathlib.Path) else _ca_cert + @property def server_version(self) -> semver.Version: """Retrieve current Server version.""" @@ -111,7 +117,11 @@ def _load_pyproject_configs(cls) -> dict | None: @classmethod @functools.lru_cache def _check_server( - cls, token: str, url: str, mode: typing.Literal["offline", "online", "disabled"] + cls, + token: str, + url: str, + verify: str | bool, + mode: typing.Literal["offline", "online", "disabled"], ) -> semver.Version | None: if mode in ("offline", "disabled"): return None @@ -122,7 +132,7 @@ def _check_server( } try: _url = URL(url) / "version" - _response = sv_get(f"{_url}", headers) + _response = sv_get(f"{_url}", headers, verify=verify) if _response.status_code == http.HTTPStatus.UNAUTHORIZED: raise AssertionError("Unauthorised token") @@ -168,7 +178,10 @@ def check_valid_server(self) -> Self: raise ValueError("No token provided.") self._server_version = self._check_server( - self.server.token.get_secret_value(), self.server.url, self.run.mode + token=self.server.token.get_secret_value(), + url=self.server.url, + verify=self.server_verify, + mode=self.run.mode, ) return self diff --git a/simvue/sender/actions.py b/simvue/sender/actions.py index a01949b1..5721c1d4 100644 --- a/simvue/sender/actions.py +++ b/simvue/sender/actions.py @@ -963,6 +963,7 @@ def _single_item_upload( _response: requests.Response = sv_put( url=f"{_local_config.server.url}/runs/{_online_id}/heartbeat", headers=_local_config.headers, + verify=_local_config.server_verify, ) try: