From 215fd6ac6c938de598268664134fefb588bd5cf2 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Wed, 20 Nov 2024 21:31:49 -0800 Subject: [PATCH] [FEAT] connect: add more (internally P2) column operations --- .../translation/expr/unresolved_function.rs | 2 + tests/connect/test_basic_column.py | 57 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/src/daft-connect/src/translation/expr/unresolved_function.rs b/src/daft-connect/src/translation/expr/unresolved_function.rs index af70a1e8f2..74501ed873 100644 --- a/src/daft-connect/src/translation/expr/unresolved_function.rs +++ b/src/daft-connect/src/translation/expr/unresolved_function.rs @@ -32,6 +32,8 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result=" => handle_binary_op(arguments, daft_dsl::Operator::GtEq) .wrap_err("Failed to handle >= function"), + "and" => handle_binary_op(arguments, daft_dsl::Operator::And) + .wrap_err("Failed to handle and function"), "%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus) .wrap_err("Failed to handle % function"), "sum" => handle_sum(arguments).wrap_err("Failed to handle sum function"), diff --git a/tests/connect/test_basic_column.py b/tests/connect/test_basic_column.py index 95dcb5cdd0..fb1471f8ad 100644 --- a/tests/connect/test_basic_column.py +++ b/tests/connect/test_basic_column.py @@ -44,3 +44,60 @@ def test_column_name(spark_session): # df = spark_session.range(10) # df_item = df.select(col("id")[0]) # assert df_item.toPandas()["id"].iloc[0] == 0, "getitem should return first element" + + +def test_column_astype(spark_session): + df = spark_session.range(10) + df_astype = df.select(col("id").astype(StringType())) + assert df_astype.schema.fields[0].dataType == StringType(), "astype should change data type" + + +def test_column_between(spark_session): + df = spark_session.range(10) + df_between = df.select(col("id").between(3, 6).alias("in_range")) + assert df_between.toPandas()["in_range"].tolist() == [False, False, 
False, True, True, True, True, False, False, False] + + +# TODO: Uncomment when string operations are implemented +# def test_column_string_ops(spark_session): +# df_str = spark_session.createDataFrame([("hello",), ("world",)], ["text"]) +# df_contains = df_str.select(col("text").contains("o").alias("has_o")) +# assert df_contains.toPandas()["has_o"].tolist() == [True, True] +# df_startswith = df_str.select(col("text").startswith("h").alias("starts_h")) +# assert df_startswith.toPandas()["starts_h"].tolist() == [True, False] +# df_endswith = df_str.select(col("text").endswith("d").alias("ends_d")) +# assert df_endswith.toPandas()["ends_d"].tolist() == [False, True] +# df_substr = df_str.select(col("text").substr(1, 2).alias("first_two")) +# assert df_substr.toPandas()["first_two"].tolist() == ["he", "wo"] + + +# TODO: Uncomment when struct operations are implemented +# def test_column_struct_ops(spark_session): +# df_struct = spark_session.createDataFrame([ +# ({"a": 1, "b": 2},), +# ({"a": 3, "b": 4},) +# ], ["data"]) +# df_getfield = df_struct.select(col("data").getField("a").alias("a_val")) +# assert df_getfield.toPandas()["a_val"].tolist() == [1, 3] +# df_dropfields = df_struct.select(col("data").dropFields("a").alias("no_a")) +# assert "a" not in df_dropfields.toPandas()["no_a"][0] +# df_withfield = df_struct.select(col("data").withField("c", col("data.a") + 10).alias("with_c")) +# assert df_withfield.toPandas()["with_c"][0]["c"] == 11 + + +# TODO: Uncomment when array operations are implemented +# def test_column_array_ops(spark_session): +# df_array = spark_session.createDataFrame([([1, 2, 3],), ([4, 5, 6],)], ["numbers"]) +# df_getitem = df_array.select(col("numbers").getItem(0).alias("first")) +# assert df_getitem.toPandas()["first"].tolist() == [1, 4] + + +# TODO: Uncomment when when/otherwise operations are implemented (also import `when` from pyspark.sql.functions) +# def test_column_case_when(spark_session): +# df = spark_session.range(10) +# df_case = df.select( +# when(col("id") < 5, 
"low") +# .otherwise("high") +# .alias("category") +# ) +# assert df_case.toPandas()["category"].tolist() == ["low"] * 5 + ["high"] * 5