From aa0e04081a37d4090706d8f307e215c9edfea209 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 20 Nov 2024 21:16:19 -0800 Subject: [PATCH] [FEAT] connect: `with_columns_renamed` --- .../src/translation/logical_plan.rs | 5 +++ .../logical_plan/with_columns_renamed.rs | 45 +++++++++++++++++++ tests/connect/test_with_columns_renamed.py | 24 ++++++++++ 3 files changed, 74 insertions(+) create mode 100644 src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs create mode 100644 tests/connect/test_with_columns_renamed.py diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index b6097d17ad..4073334ff4 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -7,6 +7,7 @@ use tracing::warn; use crate::translation::logical_plan::{ aggregate::aggregate, drop::drop, filter::filter, local_relation::local_relation, project::project, range::range, read::read, to_df::to_df, with_columns::with_columns, + with_columns_renamed::with_columns_renamed, }; mod aggregate; @@ -18,6 +19,7 @@ mod range; mod read; mod to_df; mod with_columns; +mod with_columns_renamed; pub struct Plan { pub builder: LogicalPlanBuilder, @@ -76,6 +78,9 @@ pub async fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::LocalRelation(l) => { local_relation(l).wrap_err("Failed to apply local_relation to logical plan") } + RelType::WithColumnsRenamed(w) => with_columns_renamed(*w) + .await + .wrap_err("Failed to apply with_columns_renamed to logical plan"), RelType::Read(r) => read(r) .await .wrap_err("Failed to apply read to logical plan"), diff --git a/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs b/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs new file mode 100644 index 0000000000..d5b4e4ace7 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/with_columns_renamed.rs @@ -0,0 +1,45 @@ +use daft_dsl::col; +use eyre::{bail, Context}; + +use crate::translation::Plan; + +pub async fn with_columns_renamed( + with_columns_renamed: spark_connect::WithColumnsRenamed, +) -> eyre::Result { + let spark_connect::WithColumnsRenamed { + input, + rename_columns_map, + renames, + } = with_columns_renamed; + + let Some(input) = input else { + bail!("Input is required"); + }; + + let mut plan = Box::pin(crate::translation::to_logical_plan(*input)).await?; + + // todo: do we want to implement this directly into daft? + + // Convert the rename mappings into expressions + let rename_exprs = if !rename_columns_map.is_empty() { + // Use rename_columns_map if provided (legacy format) + rename_columns_map + .into_iter() + .map(|(old_name, new_name)| col(old_name.as_str()).alias(new_name.as_str())) + .collect() + } else { + // Use renames if provided (new format) + renames + .into_iter() + .map(|rename| col(rename.col_name.as_str()).alias(rename.new_col_name.as_str())) + .collect() + }; + + // Apply the rename expressions to the plan + plan.builder = plan + .builder + .select(rename_exprs) + .wrap_err("Failed to apply rename expressions to logical plan")?; + + Ok(plan) +} diff --git a/tests/connect/test_with_columns_renamed.py b/tests/connect/test_with_columns_renamed.py new file mode 100644 index 0000000000..4e330d43b7 --- /dev/null +++ b/tests/connect/test_with_columns_renamed.py @@ -0,0 +1,24 @@ +from __future__ import annotations + + +def test_with_columns_renamed(spark_session): + # Test withColumnRenamed + df = spark_session.range(5) + renamed_df = df.withColumnRenamed("id", "number") + + collected = renamed_df.collect() + assert len(collected) == 5 + assert "number" in renamed_df.columns + assert "id" not in renamed_df.columns + assert [row["number"] for row in collected] == list(range(5)) + + # todo: this fails but is this expected or no? + # # Test withColumnsRenamed + # df = spark_session.range(2) + # renamed_df = df.withColumnsRenamed({"id": "number", "id": "character"}) + + # collected = renamed_df.collect() + # assert len(collected) == 2 + # assert set(renamed_df.columns) == {"number", "character"} + # assert "id" not in renamed_df.columns + # assert [(row["number"], row["character"]) for row in collected] == [(0, 0), (1, 1)]