[Store] Add generic tensor file I/O over fsspec#2237
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a new module mooncake/store_file_io.py that adds file and filesystem-aware serialization and deserialization support (supporting both standard PyTorch and SafeTensors formats) to MooncakeDistributedStore via dynamic patching. It also includes comprehensive unit tests for these features. The review feedback highlights several important improvements: ensuring tensors are contiguous before saving with safetensors to prevent failures, respecting the map_location parameter when loading safetensors files, normalizing path separators to avoid directory creation failures on Windows, and improving fallback tensor name matching by extracting the base name of the file path.
| def _serialize_tensor(tensor: Any, format_name: str, tensor_name: str) -> bytes: | ||
| if format_name == "safetensors": | ||
| from safetensors.torch import save as safetensors_save | ||
|
|
||
| return safetensors_save({tensor_name: tensor}) |
There was a problem hiding this comment.
safetensors requires tensors to be contiguous. If a non-contiguous tensor (e.g., a sliced or transposed KV cache tensor) is passed to safetensors_save, it will raise a ValueError and fail the save operation. Ensuring the tensor is contiguous before saving prevents this failure.
| def _serialize_tensor(tensor: Any, format_name: str, tensor_name: str) -> bytes: | |
| if format_name == "safetensors": | |
| from safetensors.torch import save as safetensors_save | |
| return safetensors_save({tensor_name: tensor}) | |
| def _serialize_tensor(tensor: Any, format_name: str, tensor_name: str) -> bytes: | |
| if format_name == "safetensors": | |
| from safetensors.torch import save as safetensors_save | |
| if hasattr(tensor, "is_contiguous") and not tensor.is_contiguous(): | |
| tensor = tensor.contiguous() | |
| return safetensors_save({tensor_name: tensor}) |
| if format_name == "safetensors": | ||
| from safetensors.torch import load as safetensors_load | ||
|
|
||
| loaded_tensors = safetensors_load(payload) | ||
| return _pick_tensor_entry(loaded_tensors, tensor_name, store_key) |
There was a problem hiding this comment.
The map_location parameter is ignored when loading safetensors files, meaning the loaded tensors will always remain on the default device (usually CPU). This creates an inconsistency with the torch format and can lead to silent performance degradation or device mismatch errors if the caller expected the tensor to be loaded directly to a GPU. Moving the tensor to map_location if specified resolves this.
| if format_name == "safetensors": | |
| from safetensors.torch import load as safetensors_load | |
| loaded_tensors = safetensors_load(payload) | |
| return _pick_tensor_entry(loaded_tensors, tensor_name, store_key) | |
| if format_name == "safetensors": | |
| from safetensors.torch import load as safetensors_load | |
| loaded_tensors = safetensors_load(payload) | |
| tensor = _pick_tensor_entry(loaded_tensors, tensor_name, store_key) | |
| if map_location is not None and hasattr(tensor, "to"): | |
| tensor = tensor.to(map_location) | |
| return tensor |
| storage_options: dict[str, Any] | None, | ||
| ) -> None: | ||
| fs, path = _open_fs_target(file_name, filesystem, storage_options) | ||
| parent_dir = posixpath.dirname(path) |
There was a problem hiding this comment.
Using posixpath.dirname directly on path can fail to extract the parent directory on Windows if the path contains backslashes (\\). This will cause fs.makedirs to be skipped, potentially leading to a FileNotFoundError when writing the file. Normalizing the path separators to forward slashes resolves this issue.
| parent_dir = posixpath.dirname(path) | |
| parent_dir = posixpath.dirname(str(path).replace("\\", "/")) |
| if fallback_name and fallback_name in loaded_tensors: | ||
| return loaded_tensors[fallback_name] |
There was a problem hiding this comment.
When loading a tensor without specifying a key or tensor_name, fallback_name defaults to the full file path or URL. This will not match the key stored inside the safetensors file (which is usually just the base name of the tensor), leading to a warning. Extracting the base name of the file (without extension) as an additional fallback avoids this warning and improves usability.
if fallback_name:
if fallback_name in loaded_tensors:
return loaded_tensors[fallback_name]
base_name = posixpath.splitext(posixpath.basename(str(fallback_name).replace("\\", "/")))[0]
if base_name in loaded_tensors:
return loaded_tensors[base_name]There was a problem hiding this comment.
Pull request overview
Adds filesystem-agnostic tensor persistence APIs to the Python MooncakeDistributedStore client by routing reads/writes through fsspec, and expands the wheel test suite to cover standard torch files, file URIs, and an in-memory fsspec backend.
Changes:
- Introduces
mooncake.store_file_ioto patchMooncakeDistributedStorewithsave_tensor_to_file/load_tensor_from_fileplus safetensors helpers. - Calls the patch hook at
mooncakepackage import time to expose the new APIs automatically. - Extends
test_safetensor_functions.pywith new persistence scenarios (torch.pt,file://URIs,memory://, KV-cache aliases).
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| mooncake-wheel/tests/test_safetensor_functions.py | Adds new tests for generic tensor file I/O across local paths, file URIs, and an in-memory fsspec backend. |
| mooncake-wheel/mooncake/store_file_io.py | Implements fsspec-backed read/write + (de)serialization and monkey-patches MooncakeDistributedStore with new persistence methods. |
| mooncake-wheel/mooncake/init.py | Auto-enables the patching hook at import time so the new methods are available by default. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
|
||
| def test_save_and_load_tensor_with_memory_filesystem(self): | ||
| """Test generic tensor persistence through a non-local fsspec backend.""" | ||
| import fsspec |
| if format_name == "safetensors": | ||
| from safetensors.torch import load as safetensors_load | ||
|
|
||
| loaded_tensors = safetensors_load(payload) | ||
| return _pick_tensor_entry(loaded_tensors, tensor_name, store_key) | ||
|
|
||
| import torch | ||
|
|
||
| return torch.load(io.BytesIO(payload), map_location=map_location) | ||
|
|
| def _read_bytes( | ||
| file_name: os.PathLike[str] | str, | ||
| filesystem: str | None, | ||
| storage_options: dict[str, Any] | None, | ||
| ) -> bytes: | ||
| fs, path = _open_fs_target(file_name, filesystem, storage_options) | ||
| with fs.open(path, "rb") as handle: | ||
| return handle.read() | ||
|
|
||
|
|
||
| def _serialize_tensor(tensor: Any, format_name: str, tensor_name: str) -> bytes: | ||
| if format_name == "safetensors": | ||
| from safetensors.torch import save as safetensors_save | ||
|
|
||
| return safetensors_save({tensor_name: tensor}) | ||
|
|
||
| import torch | ||
|
|
||
| buffer = io.BytesIO() | ||
| torch.save(tensor, buffer) | ||
| return buffer.getvalue() |
| def _open_fs_target( | ||
| file_name: os.PathLike[str] | str, | ||
| filesystem: str | None, | ||
| storage_options: dict[str, Any] | None, | ||
| ): | ||
| import fsspec | ||
|
|
||
| target_url = _build_target_url(file_name, filesystem) | ||
| options = dict(storage_options or {}) | ||
| fs, path = fsspec.core.url_to_fs(target_url, **options) | ||
| return fs, path | ||
|
|
|
|
||
| patch_store_file_io_support() |
| time.sleep(default_kv_lease_ttl / 1000) | ||
| self.store.remove(original_key) | ||
| self.store.remove(restored_key) |
|
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
Description
Module
mooncake-transfer-engine)mooncake-store)mooncake-ep)mooncake-integration)mooncake-p2p-store)mooncake-wheel)mooncake-pg)mooncake-rl)Type of Change
How Has This Been Tested?
Checklist
./scripts/code_format.shbefore submitting.