From d5bb146b2f36f14188a4111786ca54a0109ff601 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 15 May 2026 14:23:53 -0400 Subject: [PATCH 1/2] feat: inline encoding for Python aggregate and window UDFs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends the PythonLogicalCodec / PythonPhysicalCodec inline encoding introduced for scalar UDFs to also cover Python-defined aggregate and window UDFs. The cloudpickle tuple shape per family is: DFPYUDA (agg) (name, accumulator_factory, input_schema_bytes, return_schema_bytes, state_schema_bytes, volatility_str) DFPYUDW (window) (name, evaluator_factory, input_schema_bytes, return_schema_bytes, volatility_str) Same wire-framing as scalar (family magic + version byte + cloudpickle blob), same schema serde (arrow-rs native IPC), same cached cloudpickle handle. The agg state schema is encoded as a full IPC schema so the post-decode UDF reports the same names + nullability + metadata as the sender — relevant for accumulators whose StateFieldsArgs consumers key off names rather than positional DataType. Required restructuring two existing UDF impls so the codec can grab the Python callable directly: * udaf.rs: replaces create_udaf + AccumulatorFactoryFunction closure with a named PythonFunctionAggregateUDF that stores the Py accumulator factory. Synthesizes state_{i} field names when the Python constructor passes only Vec; from_parts preserves the full state schema on the decode side. * udwf.rs: renames MultiColumnWindowUDF -> PythonFunctionWindowUDF, drops the PartitionEvaluatorFactory PtrEq wrapper, stores the Py evaluator directly. PartialEq and Hash get the same pointer-identity fast path + debug-log exception handling already on PythonFunctionScalarUDF. User-facing surface: * AggregateUDF.name and WindowUDF.name properties (parallel to the ScalarUDF.name shipped in PR1). * Existing UDAF/UDWF construction paths are unchanged. The per-session with_python_udf_inlining toggle, sender-side context, strict refusal, and user-guide docs land in PRs 3-4 of this series. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/codec.rs | 315 ++++++++++++++++++++++++++++-- crates/core/src/udaf.rs | 181 +++++++++++++++-- crates/core/src/udwf.rs | 113 +++++++---- python/datafusion/expr.py | 25 +-- python/datafusion/ipc.py | 19 +- python/datafusion/user_defined.py | 20 ++ python/tests/test_pickle_expr.py | 116 ++++++++++- 7 files changed, 697 insertions(+), 92 deletions(-) diff --git a/crates/core/src/codec.rs b/crates/core/src/codec.rs index cc038edc9..65557ab27 100644 --- a/crates/core/src/codec.rs +++ b/crates/core/src/codec.rs @@ -66,12 +66,14 @@ //! | Layer + kind | Family prefix | //! | ----------------------------- | ------------- | //! | `PythonLogicalCodec` scalar | `DFPYUDF` | +//! | `PythonLogicalCodec` agg | `DFPYUDA` | +//! | `PythonLogicalCodec` window | `DFPYUDW` | //! | `PythonPhysicalCodec` scalar | `DFPYUDF` | +//! | `PythonPhysicalCodec` agg | `DFPYUDA` | +//! | `PythonPhysicalCodec` window | `DFPYUDW` | //! | User FFI extension codec | user-chosen | //! | Default codec | (none) | //! -//! Aggregate and window UDF families are reserved for follow-on work. -//! //! Current wire-format version is [`WIRE_VERSION_CURRENT`]; supported //! receive range is `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`. //! Bump [`WIRE_VERSION_CURRENT`] whenever the cloudpickle tuple shape @@ -94,8 +96,8 @@ use datafusion::datasource::TableProvider; use datafusion::datasource::file_format::FileFormatFactory; use datafusion::execution::TaskContext; use datafusion::logical_expr::{ - AggregateUDF, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature, TypeSignature, - Volatility, WindowUDF, + AggregateUDF, AggregateUDFImpl, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature, + TypeSignature, Volatility, WindowUDF, WindowUDFImpl, }; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ExecutionPlan; @@ -105,7 +107,10 @@ use pyo3::prelude::*; use pyo3::sync::PyOnceLock; use pyo3::types::{PyBytes, PyTuple}; +use crate::errors::to_datafusion_err; +use crate::udaf::PythonFunctionAggregateUDF; use crate::udf::PythonFunctionScalarUDF; +use crate::udwf::PythonFunctionWindowUDF; // Wire-format framing for inlined Python UDF payloads. // @@ -126,6 +131,16 @@ use crate::udf::PythonFunctionScalarUDF; /// volatility). pub(crate) const PY_SCALAR_UDF_FAMILY: &[u8] = b"DFPYUDF"; +/// Family prefix for an inlined Python aggregate UDF +/// (cloudpickled tuple of name, accumulator factory, input schema, +/// return type, state types schema, volatility). +pub(crate) const PY_AGG_UDF_FAMILY: &[u8] = b"DFPYUDA"; + +/// Family prefix for an inlined Python window UDF +/// (cloudpickled tuple of name, evaluator factory, input schema, +/// return type, volatility). +pub(crate) const PY_WINDOW_UDF_FAMILY: &[u8] = b"DFPYUDW"; + /// Wire-format version this build emits. pub(crate) const WIRE_VERSION_CURRENT: u8 = 1; @@ -299,18 +314,30 @@ impl LogicalExtensionCodec for PythonLogicalCodec { } fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + if try_encode_python_agg_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udaf(node, buf) } fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + if let Some(udaf) = try_decode_python_agg_udf(buf)? { + return Ok(udaf); + } self.inner.try_decode_udaf(name, buf) } fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + if try_encode_python_window_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udwf(node, buf) } fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + if let Some(udwf) = try_decode_python_window_udf(buf)? { + return Ok(udwf); + } self.inner.try_decode_udwf(name, buf) } } @@ -389,18 +416,30 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec { } fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + if try_encode_python_agg_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udaf(node, buf) } fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + if let Some(udaf) = try_decode_python_agg_udf(buf)? { + return Ok(udaf); + } self.inner.try_decode_udaf(name, buf) } fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + if try_encode_python_window_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udwf(node, buf) } fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + if let Some(udwf) = try_decode_python_window_udf(buf)? { + return Ok(udwf); + } self.inner.try_decode_udwf(name, buf) } } @@ -425,12 +464,8 @@ pub(crate) fn try_encode_python_scalar_udf(node: &ScalarUDF, buf: &mut Vec) }; Python::attach(|py| -> Result { - let py_version = current_python_version(py) - .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; - let bytes = encode_python_scalar_udf(py, py_udf) - .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; - write_wire_header(buf, PY_SCALAR_UDF_FAMILY, py_version); - buf.extend_from_slice(&bytes); + let bytes = encode_python_scalar_udf(py, py_udf).map_err(to_datafusion_err)?; + append_framed_payload(py, buf, PY_SCALAR_UDF_FAMILY, &bytes)?; Ok(true) }) } @@ -441,14 +476,11 @@ pub(crate) fn try_encode_python_scalar_udf(node: &ScalarUDF, buf: &mut Vec) /// `FunctionRegistry`). pub(crate) fn try_decode_python_scalar_udf(buf: &[u8]) -> Result>> { Python::attach(|py| -> Result>> { - let py_version = current_python_version(py) - .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; - let Some(payload) = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF", py_version)? + let Some(payload) = read_framed_payload(py, buf, PY_SCALAR_UDF_FAMILY, "scalar UDF")? else { return Ok(None); }; - let udf = decode_python_scalar_udf(py, payload) - .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + let udf = decode_python_scalar_udf(py, payload).map_err(to_datafusion_err)?; Ok(Some(Arc::new(ScalarUDF::new_from_impl(udf)))) }) } @@ -564,6 +596,11 @@ fn build_single_field_schema_bytes(field: &Field) -> PyResult> { schema_to_ipc_bytes(&Schema::new(vec![field.clone()])).map_err(arrow_to_py_err) } +/// Emit a multi-field IPC schema blob. +fn build_schema_bytes(fields: Vec) -> PyResult> { + schema_to_ipc_bytes(&Schema::new(fields)).map_err(arrow_to_py_err) +} + /// Decode the per-arg `DataType`s the encoder wrote via /// [`build_input_schema_bytes`]. fn read_input_dtypes(bytes: &[u8]) -> PyResult> { @@ -624,6 +661,37 @@ fn current_python_version(py: Python<'_>) -> PyResult<(u8, u8)> { Ok((major, minor)) } +/// Stamp `buf` with the framing header for `family` plus the current +/// Python `(major, minor)`, then append `payload`. Bundles the +/// `current_python_version` lookup with the header write so each +/// encoder call site stays one line. +fn append_framed_payload( + py: Python<'_>, + buf: &mut Vec, + family: &[u8], + payload: &[u8], +) -> Result<()> { + let py_version = current_python_version(py).map_err(to_datafusion_err)?; + write_wire_header(buf, family, py_version); + buf.extend_from_slice(payload); + Ok(()) +} + +/// Inspect `buf`'s framing against `family` + the current Python +/// `(major, minor)`. Returns `Ok(None)` when `buf` does not carry +/// `family` (caller should delegate); `Ok(Some(payload))` when the +/// framing matches; `Err(_)` for a recognised family at the wrong +/// wire-format or Python version (see [`strip_wire_header`]). +fn read_framed_payload<'a>( + py: Python<'_>, + buf: &'a [u8], + family: &[u8], + kind: &str, +) -> Result> { + let py_version = current_python_version(py).map_err(to_datafusion_err)?; + strip_wire_header(buf, family, kind, py_version) +} + /// Cached handle to the `cloudpickle` module. /// /// The encode/decode helpers above would otherwise re-resolve the @@ -642,6 +710,186 @@ fn cloudpickle<'py>(py: Python<'py>) -> PyResult> { .map(|cached| cached.bind(py).clone()) } +// ============================================================================= +// Shared Python window UDF encode / decode helpers +// +// Cloudpickle tuple shape: `(name, evaluator_factory, input_schema_bytes, +// return_schema_bytes, volatility_str)`. The evaluator factory is the +// Python callable that produces a new evaluator instance per partition. +// ============================================================================= + +pub(crate) fn try_encode_python_window_udf(node: &WindowUDF, buf: &mut Vec) -> Result { + let Some(py_udf) = node.inner().downcast_ref::() else { + return Ok(false); + }; + + Python::attach(|py| -> Result { + let bytes = encode_python_window_udf(py, py_udf).map_err(to_datafusion_err)?; + append_framed_payload(py, buf, PY_WINDOW_UDF_FAMILY, &bytes)?; + Ok(true) + }) +} + +pub(crate) fn try_decode_python_window_udf(buf: &[u8]) -> Result>> { + Python::attach(|py| -> Result>> { + let Some(payload) = read_framed_payload(py, buf, PY_WINDOW_UDF_FAMILY, "window UDF")? + else { + return Ok(None); + }; + let udf = decode_python_window_udf(py, payload).map_err(to_datafusion_err)?; + Ok(Some(Arc::new(WindowUDF::new_from_impl(udf)))) + }) +} + +fn encode_python_window_udf(py: Python<'_>, udf: &PythonFunctionWindowUDF) -> PyResult> { + let signature = WindowUDFImpl::signature(udf); + let input_dtypes = signature_input_dtypes(signature, "PythonFunctionWindowUDF")?; + let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?; + let return_field = Field::new("result", udf.return_type().clone(), true); + let return_schema_bytes = build_single_field_schema_bytes(&return_field)?; + let volatility = volatility_wire_str(signature.volatility); + + let payload = PyTuple::new( + py, + [ + WindowUDFImpl::name(udf).into_pyobject(py)?.into_any(), + udf.evaluator().bind(py).clone().into_any(), + PyBytes::new(py, &input_schema_bytes).into_any(), + PyBytes::new(py, &return_schema_bytes).into_any(), + volatility.into_pyobject(py)?.into_any(), + ], + )?; + + cloudpickle(py)? + .call_method1("dumps", (payload,))? + .extract::>() +} + +fn decode_python_window_udf(py: Python<'_>, payload: &[u8]) -> PyResult { + let tuple = cloudpickle(py)? + .call_method1("loads", (PyBytes::new(py, payload),))? + .cast_into::()?; + + let name: String = tuple.get_item(0)?.extract()?; + let evaluator: Py = tuple.get_item(1)?.unbind(); + let input_schema_bytes: Vec = tuple.get_item(2)?.extract()?; + let return_schema_bytes: Vec = tuple.get_item(3)?.extract()?; + let volatility_str: String = tuple.get_item(4)?.extract()?; + + let input_types = read_input_dtypes(&input_schema_bytes)?; + let return_type = read_single_return_field(&return_schema_bytes, "PythonFunctionWindowUDF")? + .data_type() + .clone(); + let volatility = parse_volatility_str(&volatility_str)?; + + Ok(PythonFunctionWindowUDF::new( + name, + evaluator, + input_types, + return_type, + volatility, + )) +} + +// ============================================================================= +// Shared Python aggregate UDF encode / decode helpers +// +// Cloudpickle tuple shape: `(name, accumulator_factory, input_schema_bytes, +// return_type_bytes, state_schema_bytes, volatility_str)`. The accumulator +// factory is the Python callable that produces a new accumulator instance +// per partition. +// ============================================================================= + +pub(crate) fn try_encode_python_agg_udf(node: &AggregateUDF, buf: &mut Vec) -> Result { + let Some(py_udf) = node.inner().downcast_ref::() else { + return Ok(false); + }; + + Python::attach(|py| -> Result { + let bytes = encode_python_agg_udf(py, py_udf).map_err(to_datafusion_err)?; + append_framed_payload(py, buf, PY_AGG_UDF_FAMILY, &bytes)?; + Ok(true) + }) +} + +pub(crate) fn try_decode_python_agg_udf(buf: &[u8]) -> Result>> { + Python::attach(|py| -> Result>> { + let Some(payload) = read_framed_payload(py, buf, PY_AGG_UDF_FAMILY, "aggregate UDF")? + else { + return Ok(None); + }; + let udf = decode_python_agg_udf(py, payload).map_err(to_datafusion_err)?; + Ok(Some(Arc::new(AggregateUDF::new_from_impl(udf)))) + }) +} + +fn encode_python_agg_udf(py: Python<'_>, udf: &PythonFunctionAggregateUDF) -> PyResult> { + let signature = AggregateUDFImpl::signature(udf); + let input_dtypes = signature_input_dtypes(signature, "PythonFunctionAggregateUDF")?; + let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?; + let return_field = Field::new("result", udf.return_type().clone(), true); + let return_schema_bytes = build_single_field_schema_bytes(&return_field)?; + let state_fields: Vec = udf + .state_fields_ref() + .iter() + .map(|f| f.as_ref().clone()) + .collect(); + let state_schema_bytes = build_schema_bytes(state_fields)?; + let volatility = volatility_wire_str(signature.volatility); + + let payload = PyTuple::new( + py, + [ + AggregateUDFImpl::name(udf).into_pyobject(py)?.into_any(), + udf.accumulator().bind(py).clone().into_any(), + PyBytes::new(py, &input_schema_bytes).into_any(), + PyBytes::new(py, &return_schema_bytes).into_any(), + PyBytes::new(py, &state_schema_bytes).into_any(), + volatility.into_pyobject(py)?.into_any(), + ], + )?; + + cloudpickle(py)? + .call_method1("dumps", (payload,))? + .extract::>() +} + +fn decode_python_agg_udf(py: Python<'_>, payload: &[u8]) -> PyResult { + let tuple = cloudpickle(py)? + .call_method1("loads", (PyBytes::new(py, payload),))? + .cast_into::()?; + + let name: String = tuple.get_item(0)?.extract()?; + let accumulator: Py = tuple.get_item(1)?.unbind(); + let input_schema_bytes: Vec = tuple.get_item(2)?.extract()?; + let return_schema_bytes: Vec = tuple.get_item(3)?.extract()?; + let state_schema_bytes: Vec = tuple.get_item(4)?.extract()?; + let volatility_str: String = tuple.get_item(5)?.extract()?; + + let input_types = read_input_dtypes(&input_schema_bytes)?; + let return_type = read_single_return_field(&return_schema_bytes, "PythonFunctionAggregateUDF")? + .data_type() + .clone(); + // Preserve the encoded state field metadata (names, nullability, + // arbitrary key/value attributes) so the post-decode UDF reports + // the same state schema as the sender's instance — important for + // accumulators whose `StateFieldsArgs` consumers key off names or + // nullability rather than positional `DataType`. + let state_schema = schema_from_ipc_bytes(&state_schema_bytes).map_err(arrow_to_py_err)?; + let state_fields: Vec = + state_schema.fields().iter().cloned().collect(); + let volatility = parse_volatility_str(&volatility_str)?; + + Ok(PythonFunctionAggregateUDF::from_parts( + name, + accumulator, + input_types, + return_type, + state_fields, + volatility, + )) +} + #[cfg(test)] mod wire_header_tests { use super::*; @@ -729,7 +977,7 @@ mod wire_header_tests { } #[test] - fn write_then_strip_round_trips_payload() { + fn write_then_strip_round_trips_scalar_payload() { let mut buf = Vec::new(); write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY, TEST_PY); buf.extend_from_slice(b"scalar-payload"); @@ -739,4 +987,39 @@ mod wire_header_tests { .unwrap(); assert_eq!(payload, b"scalar-payload"); } + + #[test] + fn write_then_strip_round_trips_agg_payload() { + let mut buf = Vec::new(); + write_wire_header(&mut buf, PY_AGG_UDF_FAMILY, TEST_PY); + buf.extend_from_slice(b"agg-payload"); + + let payload = strip_wire_header(&buf, PY_AGG_UDF_FAMILY, "aggregate UDF", TEST_PY) + .unwrap() + .unwrap(); + assert_eq!(payload, b"agg-payload"); + } + + #[test] + fn write_then_strip_round_trips_window_payload() { + let mut buf = Vec::new(); + write_wire_header(&mut buf, PY_WINDOW_UDF_FAMILY, TEST_PY); + buf.extend_from_slice(b"window-payload"); + + let payload = strip_wire_header(&buf, PY_WINDOW_UDF_FAMILY, "window UDF", TEST_PY) + .unwrap() + .unwrap(); + assert_eq!(payload, b"window-payload"); + } + + #[test] + fn strip_does_not_match_a_different_family() { + let mut buf = Vec::new(); + write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY, TEST_PY); + buf.extend_from_slice(b"payload"); + assert!(matches!( + strip_wire_header(&buf, PY_WINDOW_UDF_FAMILY, "window UDF", TEST_PY), + Ok(None) + )); + } } diff --git a/crates/core/src/udaf.rs b/crates/core/src/udaf.rs index 80ef51716..279ef744d 100644 --- a/crates/core/src/udaf.rs +++ b/crates/core/src/udaf.rs @@ -19,12 +19,13 @@ use std::ptr::NonNull; use std::sync::Arc; use datafusion::arrow::array::ArrayRef; -use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::datatypes::{DataType, Field, FieldRef}; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::common::ScalarValue; use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, create_udaf, + Accumulator, AggregateUDF, AggregateUDFImpl, Signature, Volatility, }; use datafusion_ffi::udaf::FFI_AggregateUDF; use datafusion_python_util::parse_volatility; @@ -144,15 +145,157 @@ impl Accumulator for RustAccumulator { } } -pub fn to_rust_accumulator(accum: Py) -> AccumulatorFactoryFunction { - Arc::new(move |_args| -> Result> { - let accum = Python::attach(|py| { - accum - .call0(py) - .map_err(|e| DataFusionError::Execution(format!("{e}"))) - })?; - Ok(Box::new(RustAccumulator::new(accum))) - }) +fn instantiate_accumulator(accum: &Py) -> Result> { + let instance = Python::attach(|py| { + accum + .call0(py) + .map_err(|e| DataFusionError::Execution(format!("{e}"))) + })?; + Ok(Box::new(RustAccumulator::new(instance))) +} + +/// Named-struct `AggregateUDFImpl` for Python-defined aggregate UDFs. +/// Holds the Python accumulator factory directly so the codec can +/// downcast and cloudpickle it across process boundaries. +#[derive(Debug)] +pub(crate) struct PythonFunctionAggregateUDF { + name: String, + accumulator: Py, + signature: Signature, + return_type: DataType, + state_fields: Vec, +} + +impl PythonFunctionAggregateUDF { + fn new( + name: String, + accumulator: Py, + input_types: Vec, + return_type: DataType, + state_types: Vec, + volatility: Volatility, + ) -> Self { + let signature = Signature::exact(input_types, volatility); + let state_fields = state_types + .into_iter() + .enumerate() + .map(|(i, t)| Arc::new(Field::new(format!("state_{i}"), t, true))) + .collect(); + Self { + name, + accumulator, + signature, + return_type, + state_fields, + } + } + + /// Stored Python callable that returns a fresh accumulator instance + /// per partition. Consumed by the codec to cloudpickle the factory + /// across process boundaries. + pub(crate) fn accumulator(&self) -> &Py { + &self.accumulator + } + + pub(crate) fn return_type(&self) -> &DataType { + &self.return_type + } + + pub(crate) fn state_fields_ref(&self) -> &[FieldRef] { + &self.state_fields + } + + /// Reconstruct a `PythonFunctionAggregateUDF` from the parts emitted + /// by the codec. `state_fields` carries the full state schema + /// (names, data types, nullability, metadata) — the codec extracts + /// it from the IPC payload, so the post-decode state schema is + /// identical to the pre-encode one. Use [`Self::new`] when only + /// `Vec` is available (e.g. the Python constructor path, + /// where field names are synthesized). + pub(crate) fn from_parts( + name: String, + accumulator: Py, + input_types: Vec, + return_type: DataType, + state_fields: Vec, + volatility: Volatility, + ) -> Self { + Self { + name, + accumulator, + signature: Signature::exact(input_types, volatility), + return_type, + state_fields, + } + } +} + +impl Eq for PythonFunctionAggregateUDF {} +impl PartialEq for PythonFunctionAggregateUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.signature == other.signature + && self.return_type == other.return_type + && self.state_fields == other.state_fields + // Pointer-identity fast path: `Arc`-shared clones of the + // same UDF skip the GIL roundtrip. Falls through to Python + // `__eq__` only for two distinct callables. + && (self.accumulator.as_ptr() == other.accumulator.as_ptr() + || Python::attach(|py| { + // See `PythonFunctionScalarUDF::eq` for the + // rationale on swallowing the exception as `false` + // and logging at `debug`. FIXME: revisit if + // upstream `AggregateUDFImpl` exposes a fallible + // `PartialEq`. + self.accumulator + .bind(py) + .eq(other.accumulator.bind(py)) + .unwrap_or_else(|e| { + log::debug!( + target: "datafusion_python::udaf", + "PythonFunctionAggregateUDF {:?} __eq__ raised; treating as unequal: {e}", + self.name, + ); + false + }) + })) + } +} + +impl std::hash::Hash for PythonFunctionAggregateUDF { + fn hash(&self, state: &mut H) { + // See `PythonFunctionScalarUDF`'s `Hash` impl for the + // rationale: hash the identifying header only and let + // `PartialEq` disambiguate callables. + self.name.hash(state); + self.signature.hash(state); + self.return_type.hash(state); + for f in &self.state_fields { + f.hash(state); + } + } +} + +impl AggregateUDFImpl for PythonFunctionAggregateUDF { + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + instantiate_accumulator(&self.accumulator) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(self.state_fields.clone()) + } } fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult { @@ -190,14 +333,15 @@ impl PyAggregateUDF { state_type: PyArrowType>, volatility: &str, ) -> PyResult { - let function = create_udaf( - name, + let py_udf = PythonFunctionAggregateUDF::new( + name.to_string(), + accumulator, input_type.0, - Arc::new(return_type.0), + return_type.0, + state_type.0, parse_volatility(volatility)?, - to_rust_accumulator(accumulator), - Arc::new(state_type.0), ); + let function = AggregateUDF::new_from_impl(py_udf); Ok(Self { function }) } @@ -231,4 +375,9 @@ impl PyAggregateUDF { fn __repr__(&self) -> PyResult { Ok(format!("AggregateUDF({})", self.function.name())) } + + #[getter] + fn name(&self) -> &str { + self.function.name() + } } diff --git a/crates/core/src/udwf.rs b/crates/core/src/udwf.rs index 40e6208c4..4a55f032c 100644 --- a/crates/core/src/udwf.rs +++ b/crates/core/src/udwf.rs @@ -24,10 +24,9 @@ use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow}; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::function::{PartitionEvaluatorArgs, WindowUDFFieldArgs}; -use datafusion::logical_expr::ptr_eq::PtrEq; use datafusion::logical_expr::window_state::WindowAggState; use datafusion::logical_expr::{ - PartitionEvaluator, PartitionEvaluatorFactory, Signature, Volatility, WindowUDF, WindowUDFImpl, + PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, }; use datafusion::scalar::ScalarValue; use datafusion_ffi::udwf::FFI_WindowUDF; @@ -197,15 +196,13 @@ impl PartitionEvaluator for RustPartitionEvaluator { } } -pub fn to_rust_partition_evaluator(evaluator: Py) -> PartitionEvaluatorFactory { - Arc::new(move || -> Result> { - let evaluator = Python::attach(|py| { - evaluator - .call0(py) - .map_err(|e| DataFusionError::Execution(e.to_string())) - })?; - Ok(Box::new(RustPartitionEvaluator::new(evaluator))) - }) +fn instantiate_partition_evaluator(evaluator: &Py) -> Result> { + let instance = Python::attach(|py| { + evaluator + .call0(py) + .map_err(|e| DataFusionError::Execution(e.to_string())) + })?; + Ok(Box::new(RustPartitionEvaluator::new(instance))) } /// Represents an WindowUDF @@ -233,14 +230,14 @@ impl PyWindowUDF { volatility: &str, ) -> PyResult { let return_type = return_type.0; - let input_types = input_types.into_iter().map(|t| t.0).collect(); + let input_types: Vec = input_types.into_iter().map(|t| t.0).collect(); - let function = WindowUDF::from(MultiColumnWindowUDF::new( + let function = WindowUDF::from(PythonFunctionWindowUDF::new( name, + evaluator, input_types, return_type, parse_volatility(volatility)?, - to_rust_partition_evaluator(evaluator), )); Ok(Self { function }) } @@ -275,47 +272,94 @@ impl PyWindowUDF { fn __repr__(&self) -> PyResult { Ok(format!("WindowUDF({})", self.function.name())) } + + #[getter] + fn name(&self) -> &str { + self.function.name() + } } -#[derive(Hash, Eq, PartialEq)] -pub struct MultiColumnWindowUDF { +#[derive(Debug)] +pub(crate) struct PythonFunctionWindowUDF { name: String, + evaluator: Py, signature: Signature, return_type: DataType, - partition_evaluator_factory: PtrEq, } -impl std::fmt::Debug for MultiColumnWindowUDF { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("WindowUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("return_type", &"") - .field("partition_evaluator_factory", &"") - .finish() - } -} - -impl MultiColumnWindowUDF { - pub fn new( +impl PythonFunctionWindowUDF { + pub(crate) fn new( name: impl Into, + evaluator: Py, input_types: Vec, return_type: DataType, volatility: Volatility, - partition_evaluator_factory: PartitionEvaluatorFactory, ) -> Self { let name = name.into(); let signature = Signature::exact(input_types, volatility); Self { name, + evaluator, signature, return_type, - partition_evaluator_factory: partition_evaluator_factory.into(), } } + + /// Stored Python callable that produces a fresh partition + /// evaluator instance per partition. Consumed by the codec to + /// cloudpickle the evaluator factory across process boundaries. + pub(crate) fn evaluator(&self) -> &Py { + &self.evaluator + } + + pub(crate) fn return_type(&self) -> &DataType { + &self.return_type + } +} + +impl Eq for PythonFunctionWindowUDF {} +impl PartialEq for PythonFunctionWindowUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.signature == other.signature + && self.return_type == other.return_type + // Pointer-identity fast path: `Arc`-shared clones of the + // same UDF skip the GIL roundtrip. Falls through to Python + // `__eq__` only for two distinct callables. + && (self.evaluator.as_ptr() == other.evaluator.as_ptr() + || Python::attach(|py| { + // See `PythonFunctionScalarUDF::eq` for the + // rationale on swallowing the exception as `false` + // and logging at `debug`. FIXME: revisit if + // upstream `WindowUDFImpl` exposes a fallible + // `PartialEq`. + self.evaluator + .bind(py) + .eq(other.evaluator.bind(py)) + .unwrap_or_else(|e| { + log::debug!( + target: "datafusion_python::udwf", + "PythonFunctionWindowUDF {:?} __eq__ raised; treating as unequal: {e}", + self.name, + ); + false + }) + })) + } +} + +impl std::hash::Hash for PythonFunctionWindowUDF { + fn hash(&self, state: &mut H) { + // See `PythonFunctionScalarUDF`'s `Hash` impl for the + // rationale: hash the identifying header only and let + // `PartialEq` disambiguate evaluators. + self.name.hash(state); + self.signature.hash(state); + self.return_type.hash(state); + } } -impl WindowUDFImpl for MultiColumnWindowUDF { +impl WindowUDFImpl for PythonFunctionWindowUDF { fn name(&self) -> &str { &self.name } @@ -334,7 +378,6 @@ impl WindowUDFImpl for MultiColumnWindowUDF { &self, _partition_evaluator_args: PartitionEvaluatorArgs, ) -> Result> { - let _ = _partition_evaluator_args; - (self.partition_evaluator_factory)() + instantiate_partition_evaluator(&self.evaluator) } } diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 645bd9c18..7e95bc127 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -449,14 +449,14 @@ def to_bytes(self, ctx: SessionContext | None = None) -> bytes: installed :class:`LogicalExtensionCodec`. When ``ctx`` is ``None``, the default codec is used. - Built-in functions and Python scalar UDFs travel inside the - returned bytes; the worker does not need to pre-register them. - UDFs imported via the FFI capsule protocol travel by name only - and must be registered on the worker. + Built-in functions and Python UDFs (scalar, aggregate, window) + travel inside the returned bytes; the worker does not need to + pre-register them. UDFs imported via the FFI capsule protocol + travel by name only and must be registered on the worker. .. warning:: Security Bytes returned here may embed a cloudpickled Python - callable (when the expression carries a Python scalar UDF). + callable (when the expression carries a Python UDF). Reconstructing them via :meth:`from_bytes` or :func:`pickle.loads` executes arbitrary Python on the receiver. Only accept payloads from trusted sources. @@ -526,7 +526,7 @@ def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr: ``ctx`` is ``None`` the worker context installed via :func:`datafusion.ipc.set_worker_ctx` is consulted; if no worker context is installed, the global :class:`SessionContext` is used - (sufficient for built-ins and Python scalar UDFs, plus any UDFs + (sufficient for built-ins and Python UDFs, plus any UDFs registered on the global context). .. warning:: Security @@ -560,12 +560,13 @@ def __reduce__(self) -> tuple[Callable[[bytes], Expr], tuple[bytes]]: Lets expressions be shipped to worker processes via :func:`pickle.dumps` / :func:`pickle.loads`. Built-in functions - and Python scalar UDFs travel inside the pickle bytes; only - FFI-capsule UDFs require pre-registration on the worker. The - worker's :class:`SessionContext` for resolving those references - is looked up via :func:`datafusion.ipc.set_worker_ctx`, falling - back to the global :class:`SessionContext` if none has been - installed on the worker. + and Python UDFs (scalar, aggregate, window) travel inside the + pickle bytes; only FFI-capsule UDFs require pre-registration on + the worker. The worker's :class:`SessionContext` for resolving + those references is looked up via + :func:`datafusion.ipc.set_worker_ctx`, falling back to the + global :class:`SessionContext` if none has been installed on + the worker. .. warning:: Security :func:`pickle.loads` on the returned tuple executes diff --git a/python/datafusion/ipc.py b/python/datafusion/ipc.py index 78b6873f7..8dd7fc463 100644 --- a/python/datafusion/ipc.py +++ b/python/datafusion/ipc.py @@ -35,16 +35,17 @@ def init_worker(): ctx.register_udaf(my_ffi_aggregate) set_worker_ctx(ctx) -Built-in functions and Python scalar UDFs travel inside the shipped -expression itself and do not need pre-registration on the worker. +Built-in functions and Python UDFs (scalar, aggregate, window) travel +inside the shipped expression itself and do not need pre-registration +on the worker. .. note:: Serialization model - Expressions containing Python scalar UDFs are serialized using - :mod:`cloudpickle`. The callable itself travels **by value** - (bytecode and closure cells inlined), but any names the callable - resolves via ``import`` are captured **by reference** and must be - importable on the receiving worker. + Expressions containing Python UDFs (scalar, aggregate, window) are + serialized using :mod:`cloudpickle`. The callable itself travels + **by value** (bytecode and closure cells inlined), but any names the + callable resolves via ``import`` are captured **by reference** and + must be importable on the receiving worker. The serialized payload is stamped with the sender's Python ``(major, minor)`` version. Loading on a different minor version @@ -97,8 +98,8 @@ def clear_worker_ctx() -> None: After clearing, expressions reconstructed in this worker fall back to the global :class:`SessionContext` — adequate for built-ins and Python - scalar UDFs, but anything imported via the FFI capsule protocol must - be registered on the global context to resolve. + UDFs (scalar, aggregate, window), but anything imported via the FFI + capsule protocol must be registered on the global context to resolve. Examples: >>> from datafusion import SessionContext diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index d79cf22e8..3eb50a094 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -441,6 +441,16 @@ def __init__( str(volatility), ) + @property + def name(self) -> str: + """Return the registered name of this UDAF. + + For UDAFs imported via the FFI capsule protocol, this is the + name the capsule itself reports — not the ``name`` argument + passed to the constructor (which is ignored on the FFI path). + """ + return self._udaf.name + def __repr__(self) -> str: """Print a string representation of the Aggregate UDF.""" return self._udaf.__repr__() @@ -851,6 +861,16 @@ def __init__( name, func, input_types, return_type, str(volatility) ) + @property + def name(self) -> str: + """Return the registered name of this UDWF. + + For UDWFs imported via the FFI capsule protocol, this is the + name the capsule itself reports — not the ``name`` argument + passed to the constructor (which is ignored on the FFI path). + """ + return self._udwf.name + def __repr__(self) -> str: """Print a string representation of the Window UDF.""" return self._udwf.__repr__() diff --git a/python/tests/test_pickle_expr.py b/python/tests/test_pickle_expr.py index 5d8d9285f..3f5ef342c 100644 --- a/python/tests/test_pickle_expr.py +++ b/python/tests/test_pickle_expr.py @@ -17,10 +17,10 @@ """In-process pickle round-trip tests for :class:`Expr`. -Built-in functions and Python scalar UDFs travel with the pickled -expression and do not need worker-side pre-registration. The worker -context (:mod:`datafusion.ipc`) is only consulted for UDFs imported -via the FFI capsule protocol. +Built-in functions and Python UDFs (scalar, aggregate, window) travel +with the pickled expression and do not need worker-side pre-registration. +The worker context (:mod:`datafusion.ipc`) is only consulted for UDFs +imported via the FFI capsule protocol. """ from __future__ import annotations @@ -147,6 +147,114 @@ def test_multi_arg_udf_round_trip(self): assert "add_scaled" in decoded.canonical_name() +class TestAggregateUDFCodec: + """Python aggregate UDFs travel inline like scalar UDFs.""" + + def _build_aggregate_udf(self): + from datafusion import udaf + from datafusion.user_defined import Accumulator + + class CountAcc(Accumulator): + def __init__(self): + self._count = 0 + + def state(self): + return [pa.scalar(self._count, type=pa.int64())] + + def update(self, values): + self._count += len(values) + + def merge(self, states): + for s in states: + self._count += s[0].as_py() + + def evaluate(self): + return pa.scalar(self._count, type=pa.int64()) + + return udaf( + CountAcc, + [pa.int64()], + pa.int64(), + [pa.int64()], + "immutable", + name="count_all", + ) + + def test_agg_udf_self_contained_blob(self): + u = self._build_aggregate_udf() + e = u(col("a")) + blob = pickle.dumps(e) + assert len(blob) > 200 + + def test_agg_udf_decodes_into_fresh_ctx(self): + u = self._build_aggregate_udf() + e = u(col("a")) + blob = e.to_bytes() + fresh = SessionContext() + decoded = Expr.from_bytes(blob, ctx=fresh) + assert "count_all" in decoded.canonical_name() + + def test_agg_udf_decodes_via_pickle_with_no_worker_ctx(self): + u = self._build_aggregate_udf() + e = u(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "count_all" in decoded.canonical_name() + + def test_agg_udf_evaluates_after_roundtrip(self): + """End-to-end: the decoded aggregate UDF runs and merges across + partitions, exercising the round-tripped state-field schema.""" + u = self._build_aggregate_udf() + e = u(col("a")) + decoded = pickle.loads(pickle.dumps(e)) # noqa: S301 + + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 2, 3, 4, 5]}) + out = df.aggregate([], [decoded.alias("n")]).to_pydict() + assert out["n"] == [5] + + +class TestWindowUDFCodec: + """Python window UDFs travel inline like scalar UDFs.""" + + def _build_window_udf(self): + from datafusion import udwf + from datafusion.user_defined import WindowEvaluator + + class CountUpEvaluator(WindowEvaluator): + def evaluate_all(self, values, num_rows): + return pa.array(list(range(num_rows))) + + return udwf( + CountUpEvaluator, + [pa.int64()], + pa.int64(), + "immutable", + name="count_up", + ) + + def test_window_udf_self_contained_blob(self): + u = self._build_window_udf() + e = u(col("a")) + blob = pickle.dumps(e) + assert len(blob) > 200 + + def test_window_udf_decodes_into_fresh_ctx(self): + u = self._build_window_udf() + e = u(col("a")) + blob = e.to_bytes() + fresh = SessionContext() + decoded = Expr.from_bytes(blob, ctx=fresh) + assert "count_up" in decoded.canonical_name() + + def test_window_udf_decodes_via_pickle_with_no_worker_ctx(self): + u = self._build_window_udf() + e = u(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "count_up" in decoded.canonical_name() + + class TestErrorPaths: def test_from_bytes_rejects_garbage(self): with pytest.raises(Exception): # noqa: B017 From 4a3237fc27cffa2b02c67b7a48c5711fdf284f9d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 19 May 2026 15:07:04 -0400 Subject: [PATCH 2/2] feat: restore pub UDAF/UDWF helpers and document inline encoding Re-export `to_rust_accumulator`, `to_rust_partition_evaluator`, and `PythonFunctionWindowUDF` (with a `MultiColumnWindowUDF` alias) by promoting `udaf` and `udwf` to `pub mod` so prior downstream Rust consumers keep their API surface after the inline-encoding refactor. Adds an end-to-end window UDF pickle round-trip test that runs the decoded evaluator over a real session, mirroring the aggregate test. Documents the cloudpickle-based shipping behavior of Python aggregate and window UDFs in the user-guide aggregations and windows pages. Co-Authored-By: Claude Opus 4.7 (1M context) --- crates/core/src/lib.rs | 4 +-- crates/core/src/udaf.rs | 13 ++++++++- crates/core/src/udwf.rs | 27 ++++++++++++++++--- .../common-operations/aggregations.rst | 18 +++++++++++++ .../user-guide/common-operations/windows.rst | 18 +++++++++++++ python/tests/test_pickle_expr.py | 17 ++++++++++++ 6 files changed, 91 insertions(+), 6 deletions(-) diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index e3551c937..94edc3a3a 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -59,11 +59,11 @@ mod array; #[cfg(feature = "substrait")] pub mod substrait; #[allow(clippy::borrow_deref_ref)] -mod udaf; +pub mod udaf; #[allow(clippy::borrow_deref_ref)] mod udf; pub mod udtf; -mod udwf; +pub mod udwf; #[cfg(feature = "mimalloc")] #[global_allocator] diff --git a/crates/core/src/udaf.rs b/crates/core/src/udaf.rs index 279ef744d..caf7b97bc 100644 --- a/crates/core/src/udaf.rs +++ b/crates/core/src/udaf.rs @@ -25,7 +25,7 @@ use datafusion::common::ScalarValue; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::{ - Accumulator, AggregateUDF, AggregateUDFImpl, Signature, Volatility, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, Signature, Volatility, }; use datafusion_ffi::udaf::FFI_AggregateUDF; use datafusion_python_util::parse_volatility; @@ -154,6 +154,17 @@ fn instantiate_accumulator(accum: &Py) -> Result> { Ok(Box::new(RustAccumulator::new(instance))) } +/// Wrap a Python accumulator factory in an `AccumulatorFactoryFunction`. +/// +/// Retained for downstream callers that previously consumed this +/// helper to build a [`AccumulatorFactoryFunction`] for `create_udaf` +/// or similar factory-based APIs. New in-crate code should construct +/// a [`PythonFunctionAggregateUDF`] directly so the codec can downcast +/// and ship it inline. +pub fn to_rust_accumulator(accum: Py) -> AccumulatorFactoryFunction { + Arc::new(move |_args| instantiate_accumulator(&accum)) +} + /// Named-struct `AggregateUDFImpl` for Python-defined aggregate UDFs. /// Holds the Python accumulator factory directly so the codec can /// downcast and cloudpickle it across process boundaries. diff --git a/crates/core/src/udwf.rs b/crates/core/src/udwf.rs index 4a55f032c..ebec8f3bd 100644 --- a/crates/core/src/udwf.rs +++ b/crates/core/src/udwf.rs @@ -26,7 +26,7 @@ use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::function::{PartitionEvaluatorArgs, WindowUDFFieldArgs}; use datafusion::logical_expr::window_state::WindowAggState; use datafusion::logical_expr::{ - PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, + PartitionEvaluator, PartitionEvaluatorFactory, Signature, Volatility, WindowUDF, WindowUDFImpl, }; use datafusion::scalar::ScalarValue; use datafusion_ffi::udwf::FFI_WindowUDF; @@ -205,6 +205,17 @@ fn instantiate_partition_evaluator(evaluator: &Py) -> Result) -> PartitionEvaluatorFactory { + Arc::new(move || instantiate_partition_evaluator(&evaluator)) +} + /// Represents an WindowUDF #[pyclass( from_py_object, @@ -279,16 +290,26 @@ impl PyWindowUDF { } } +/// `WindowUDFImpl` for Python-defined window UDFs. +/// +/// Holds the Python evaluator factory directly so the codec can +/// downcast and cloudpickle it across process boundaries. Replaces +/// the prior factory-erased `MultiColumnWindowUDF`; the old name is +/// kept as a type alias below for backward compatibility. #[derive(Debug)] -pub(crate) struct PythonFunctionWindowUDF { +pub struct PythonFunctionWindowUDF { name: String, evaluator: Py, signature: Signature, return_type: DataType, } +/// Backward-compatible alias for downstream crates that referenced the +/// previous struct name. New code should use [`PythonFunctionWindowUDF`]. +pub type MultiColumnWindowUDF = PythonFunctionWindowUDF; + impl PythonFunctionWindowUDF { - pub(crate) fn new( + pub fn new( name: impl Into, evaluator: Py, input_types: Vec, diff --git a/docs/source/user-guide/common-operations/aggregations.rst b/docs/source/user-guide/common-operations/aggregations.rst index f59b62ab4..8f218abd8 100644 --- a/docs/source/user-guide/common-operations/aggregations.rst +++ b/docs/source/user-guide/common-operations/aggregations.rst @@ -434,3 +434,21 @@ The available aggregate functions are: - :py:meth:`datafusion.expr.GroupingSet.cube` - :py:meth:`datafusion.expr.GroupingSet.grouping_sets` +User-Defined Aggregate Functions +-------------------------------- + +You can ship custom aggregations to the engine by subclassing +:py:class:`~datafusion.user_defined.Accumulator` and registering it via +:py:func:`~datafusion.udaf`. See :py:mod:`datafusion.user_defined` for +the accumulator interface and worked examples. + +.. note:: Serialization + + Python aggregate UDFs travel inline inside pickled or + :py:meth:`~datafusion.expr.Expr.to_bytes`-serialized expressions — + the accumulator class is captured by value via :mod:`cloudpickle`, + so worker processes do not need to pre-register the UDF. Any names + the accumulator resolves via ``import`` are captured **by reference** + and must be importable on the receiving worker. See + :py:mod:`datafusion.ipc` for the full IPC model and security caveats. + diff --git a/docs/source/user-guide/common-operations/windows.rst b/docs/source/user-guide/common-operations/windows.rst index d77881bcf..127f691b5 100644 --- a/docs/source/user-guide/common-operations/windows.rst +++ b/docs/source/user-guide/common-operations/windows.rst @@ -213,3 +213,21 @@ The possible window functions are: 3. Aggregate Functions - All :ref:`Aggregation Functions` can be used as window functions. + +User-Defined Window Functions +----------------------------- + +You can ship custom window functions to the engine by subclassing +:py:class:`~datafusion.user_defined.WindowEvaluator` and registering it +via :py:func:`~datafusion.udwf`. See :py:mod:`datafusion.user_defined` +for the evaluator interface and worked examples. + +.. note:: Serialization + + Python window UDFs travel inline inside pickled or + :py:meth:`~datafusion.expr.Expr.to_bytes`-serialized expressions — + the evaluator class is captured by value via :mod:`cloudpickle`, so + worker processes do not need to pre-register the UDF. Any names the + evaluator resolves via ``import`` are captured **by reference** and + must be importable on the receiving worker. See + :py:mod:`datafusion.ipc` for the full IPC model and security caveats. diff --git a/python/tests/test_pickle_expr.py b/python/tests/test_pickle_expr.py index 3f5ef342c..5015efe5b 100644 --- a/python/tests/test_pickle_expr.py +++ b/python/tests/test_pickle_expr.py @@ -254,6 +254,23 @@ def test_window_udf_decodes_via_pickle_with_no_worker_ctx(self): decoded = pickle.loads(blob) # noqa: S301 assert "count_up" in decoded.canonical_name() + def test_window_udf_evaluates_after_roundtrip(self): + """End-to-end: decoded window UDF runs and emits per-row values + produced by the round-tripped evaluator factory.""" + from datafusion.expr import WindowFrame + + u = self._build_window_udf() + e = u(col("a")) + decoded = pickle.loads(pickle.dumps(e)) # noqa: S301 + + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 2, 3, 4, 5]}) + framed = ( + decoded.window_frame(WindowFrame("rows", None, None)).build().alias("c") + ) + out = df.select(framed).to_pydict() + assert out["c"] == [0, 1, 2, 3, 4] + class TestErrorPaths: def test_from_bytes_rejects_garbage(self):