[Store] Add generic tensor file I/O over fsspec #2237

Open
stmatengss wants to merge 1 commit into kvcache-ai:main from stmatengss:feature/generic-tensor-file-io

@stmatengss
Collaborator

Description

Module

  • Transfer Engine (mooncake-transfer-engine)
  • Mooncake Store (mooncake-store)
  • Mooncake EP (mooncake-ep)
  • Integration (mooncake-integration)
  • P2P Store (mooncake-p2p-store)
  • Python Wheel (mooncake-wheel)
  • PyTorch Backend (mooncake-pg)
  • Mooncake RL (mooncake-rl)
  • CI/CD
  • Docs
  • Other

Type of Change

  • Bug fix
  • New feature
  • Refactor
  • Breaking change
  • Documentation update
  • Other

How Has This Been Tested?

Checklist

  • I have performed a self-review of my own code.
  • I have formatted my own code using ./scripts/code_format.sh before submitting.
  • I have updated the documentation.
  • I have added tests to prove my changes are effective.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a new module, mooncake/store_file_io.py, that adds filesystem-aware tensor serialization and deserialization (in both standard PyTorch and SafeTensors formats) to MooncakeDistributedStore via dynamic patching. It also includes unit tests for these features. The review feedback highlights four improvements: making tensors contiguous before saving with safetensors to prevent failures, respecting the map_location parameter when loading safetensors files, normalizing path separators to avoid directory creation failures on Windows, and improving fallback tensor name matching by extracting the base name of the file path.

Comment on lines +99 to +103
    def _serialize_tensor(tensor: Any, format_name: str, tensor_name: str) -> bytes:
        if format_name == "safetensors":
            from safetensors.torch import save as safetensors_save

            return safetensors_save({tensor_name: tensor})

high

safetensors requires tensors to be contiguous. If a non-contiguous tensor (e.g., a sliced or transposed KV cache tensor) is passed to safetensors_save, it will raise a ValueError and fail the save operation. Ensuring the tensor is contiguous before saving prevents this failure.

Suggested change
-def _serialize_tensor(tensor: Any, format_name: str, tensor_name: str) -> bytes:
-    if format_name == "safetensors":
-        from safetensors.torch import save as safetensors_save
-        return safetensors_save({tensor_name: tensor})
+def _serialize_tensor(tensor: Any, format_name: str, tensor_name: str) -> bytes:
+    if format_name == "safetensors":
+        from safetensors.torch import save as safetensors_save
+        if hasattr(tensor, "is_contiguous") and not tensor.is_contiguous():
+            tensor = tensor.contiguous()
+        return safetensors_save({tensor_name: tensor})
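
The guard in this suggestion can be looked at in isolation. The following is a minimal sketch, not the PR's code: the helper name `ensure_contiguous` and the `FakeTensor` stand-in are hypothetical and exist only so the snippet runs without PyTorch installed. With a real `torch.Tensor`, `is_contiguous()` and `contiguous()` are the methods that the duck-typed check would reach.

```python
from typing import Any


def ensure_contiguous(tensor: Any) -> Any:
    """Return a contiguous version of `tensor` when it exposes the torch-style API.

    Hypothetical helper mirroring the duck-typed check in the suggested change.
    """
    if hasattr(tensor, "is_contiguous") and not tensor.is_contiguous():
        return tensor.contiguous()
    return tensor


class FakeTensor:
    """Stand-in for torch.Tensor (assumption: only the two methods used above)."""

    def __init__(self, contiguous: bool) -> None:
        self._contiguous = contiguous

    def is_contiguous(self) -> bool:
        return self._contiguous

    def contiguous(self) -> "FakeTensor":
        # A real tensor would copy its data into a dense layout here.
        return FakeTensor(contiguous=True)


view = FakeTensor(contiguous=False)  # e.g. a transposed or sliced view
fixed = ensure_contiguous(view)      # a new, dense object
dense = FakeTensor(contiguous=True)
same = ensure_contiguous(dense)      # already dense, returned untouched
```

Checking `is_contiguous()` first means an already dense input passes through without a copy, and objects that are not tensors at all are left alone.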

Comment on lines +143 to +147
    if format_name == "safetensors":
        from safetensors.torch import load as safetensors_load

        loaded_tensors = safetensors_load(payload)
        return _pick_tensor_entry(loaded_tensors, tensor_name, store_key)

high

The map_location parameter is ignored when loading safetensors files, meaning the loaded tensors will always remain on the default device (usually CPU). This creates an inconsistency with the torch format and can lead to silent performance degradation or device mismatch errors if the caller expected the tensor to be loaded directly to a GPU. Moving the tensor to map_location if specified resolves this.

Suggested change
-    if format_name == "safetensors":
-        from safetensors.torch import load as safetensors_load
-        loaded_tensors = safetensors_load(payload)
-        return _pick_tensor_entry(loaded_tensors, tensor_name, store_key)
+    if format_name == "safetensors":
+        from safetensors.torch import load as safetensors_load
+        loaded_tensors = safetensors_load(payload)
+        tensor = _pick_tensor_entry(loaded_tensors, tensor_name, store_key)
+        if map_location is not None and hasattr(tensor, "to"):
+            tensor = tensor.to(map_location)
+        return tensor
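
As a rough illustration of the post-load move proposed here, the sketch below applies a requested device after a load step that ignores it. The names `move_if_requested` and `FakeTensor` are hypothetical; the stand-in only tracks a device string so the snippet runs without PyTorch.

```python
from typing import Any, Optional


def move_if_requested(tensor: Any, map_location: Optional[Any]) -> Any:
    """Hypothetical helper: apply map_location after a load that ignores it."""
    if map_location is not None and hasattr(tensor, "to"):
        return tensor.to(map_location)
    return tensor


class FakeTensor:
    """Stand-in for torch.Tensor that only tracks its device (assumption)."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


loaded = FakeTensor()                        # what a CPU-only load would return
moved = move_if_requested(loaded, "cuda:0")  # caller asked for a GPU
unchanged = move_if_requested(loaded, None)  # no request, nothing to do
```

One caveat worth keeping in mind: `torch.load` also accepts a callable or a dict for `map_location`, while `Tensor.to` expects a device-like value, so this approach would presumably cover only the string and `torch.device` cases.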

        storage_options: dict[str, Any] | None,
    ) -> None:
        fs, path = _open_fs_target(file_name, filesystem, storage_options)
        parent_dir = posixpath.dirname(path)

medium

Using posixpath.dirname directly on path can fail to extract the parent directory on Windows if the path contains backslashes (\\). This will cause fs.makedirs to be skipped, potentially leading to a FileNotFoundError when writing the file. Normalizing the path separators to forward slashes resolves this issue.

Suggested change
-    parent_dir = posixpath.dirname(path)
+    parent_dir = posixpath.dirname(str(path).replace("\\", "/"))
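
The behavior described in this comment can be reproduced with the standard library alone. The example path below is made up for illustration.

```python
import posixpath

windows_path = "C:\\cache\\kv\\layer0.pt"  # hypothetical Windows-style path

# posixpath treats only "/" as a separator, so no parent directory is found.
raw_parent = posixpath.dirname(windows_path)

# Normalizing the separators first lets dirname see the directory part.
normalized_parent = posixpath.dirname(windows_path.replace("\\", "/"))

print(repr(raw_parent))         # ''
print(repr(normalized_parent))  # 'C:/cache/kv'
```

With an empty `parent_dir`, a guard such as `if parent_dir:` would skip directory creation, which matches the failure mode the comment warns about.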

Comment on lines +123 to +124
    if fallback_name and fallback_name in loaded_tensors:
        return loaded_tensors[fallback_name]

medium

When loading a tensor without specifying a key or tensor_name, fallback_name defaults to the full file path or URL. This will not match the key stored inside the safetensors file (which is usually just the base name of the tensor), leading to a warning. Extracting the base name of the file (without extension) as an additional fallback avoids this warning and improves usability.

    if fallback_name:
        if fallback_name in loaded_tensors:
            return loaded_tensors[fallback_name]
        base_name = posixpath.splitext(posixpath.basename(str(fallback_name).replace("\\", "/")))[0]
        if base_name in loaded_tensors:
            return loaded_tensors[base_name]
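
To see how the base-name fallback behaves on URLs and Windows paths, here is a self-contained sketch. The function `pick_entry` is hypothetical and is not the PR's `_pick_tensor_entry`; in particular, raising `KeyError` on a miss is an assumption made to keep the example short.

```python
import posixpath
from typing import Any, Mapping, Optional


def pick_entry(loaded: Mapping[str, Any], fallback_name: Optional[str]) -> Any:
    """Hypothetical lookup: try the full name, then the extension-less base name."""
    if fallback_name:
        if fallback_name in loaded:
            return loaded[fallback_name]
        normalized = str(fallback_name).replace("\\", "/")
        base_name = posixpath.splitext(posixpath.basename(normalized))[0]
        if base_name in loaded:
            return loaded[base_name]
    # Assumption: fail loudly here; the real helper may warn and fall back instead.
    raise KeyError(fallback_name)


entries = {"layer0": "tensor-placeholder"}
from_url = pick_entry(entries, "s3://bucket/kv/layer0.safetensors")
from_windows = pick_entry(entries, "C:\\kv\\layer0.safetensors")
```

Both lookups resolve to the `layer0` entry because `posixpath.basename` keeps only the last path segment and `posixpath.splitext` strips the `.safetensors` suffix.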

Contributor

Copilot AI left a comment

Pull request overview

Adds filesystem-agnostic tensor persistence APIs to the Python MooncakeDistributedStore client by routing reads/writes through fsspec, and expands the wheel test suite to cover standard torch files, file URIs, and an in-memory fsspec backend.

Changes:

  • Introduces mooncake.store_file_io to patch MooncakeDistributedStore with save_tensor_to_file / load_tensor_from_file plus safetensors helpers.
  • Calls the patch hook at mooncake package import time to expose the new APIs automatically.
  • Extends test_safetensor_functions.py with new persistence scenarios (torch .pt, file:// URIs, memory://, KV-cache aliases).

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| mooncake-wheel/tests/test_safetensor_functions.py | Adds new tests for generic tensor file I/O across local paths, file URIs, and an in-memory fsspec backend. |
| mooncake-wheel/mooncake/store_file_io.py | Implements fsspec-backed read/write + (de)serialization and monkey-patches MooncakeDistributedStore with new persistence methods. |
| mooncake-wheel/mooncake/__init__.py | Auto-enables the patching hook at import time so the new methods are available by default. |

    def test_save_and_load_tensor_with_memory_filesystem(self):
        """Test generic tensor persistence through a non-local fsspec backend."""
        import fsspec

Comment on lines +143 to +152
    if format_name == "safetensors":
        from safetensors.torch import load as safetensors_load

        loaded_tensors = safetensors_load(payload)
        return _pick_tensor_entry(loaded_tensors, tensor_name, store_key)

    import torch

    return torch.load(io.BytesIO(payload), map_location=map_location)

Comment on lines +89 to +109
    def _read_bytes(
        file_name: os.PathLike[str] | str,
        filesystem: str | None,
        storage_options: dict[str, Any] | None,
    ) -> bytes:
        fs, path = _open_fs_target(file_name, filesystem, storage_options)
        with fs.open(path, "rb") as handle:
            return handle.read()


    def _serialize_tensor(tensor: Any, format_name: str, tensor_name: str) -> bytes:
        if format_name == "safetensors":
            from safetensors.torch import save as safetensors_save

            return safetensors_save({tensor_name: tensor})

        import torch

        buffer = io.BytesIO()
        torch.save(tensor, buffer)
        return buffer.getvalue()

Comment on lines +58 to +69
    def _open_fs_target(
        file_name: os.PathLike[str] | str,
        filesystem: str | None,
        storage_options: dict[str, Any] | None,
    ):
        import fsspec

        target_url = _build_target_url(file_name, filesystem)
        options = dict(storage_options or {})
        fs, path = fsspec.core.url_to_fs(target_url, **options)
        return fs, path

Comment on lines +5 to +6

    patch_store_file_io_support()

Comment on lines +369 to +371
    time.sleep(default_kv_lease_ttl / 1000)
    self.store.remove(original_key)
    self.store.remove(restored_key)

@codecov-commenter

Codecov Report

✅ All modified and coverable lines are covered by tests.

3 participants