From baed4ea78846ccb17d983a2877ce39c73c27aabf Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Tue, 1 Apr 2025 16:18:57 +0900 Subject: [PATCH 1/5] Use custom protocol instead of Arrow for state interactions in ListState --- .../sql/streaming/list_state_client.py | 36 +++++++------------ .../stateful_processor_api_client.py | 10 ++++++ ...ansformWithStateInPandasDeserializer.scala | 21 +++++++++++ ...ransformWithStateInPandasStateServer.scala | 30 +++++++++++++--- 4 files changed, 68 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/streaming/list_state_client.py b/python/pyspark/sql/streaming/list_state_client.py index cb618d1a691b3..3441b8b2af2f6 100644 --- a/python/pyspark/sql/streaming/list_state_client.py +++ b/python/pyspark/sql/streaming/list_state_client.py @@ -41,6 +41,7 @@ def __init__( # A dictionary to store the mapping between list state name and a tuple of pandas DataFrame # and the index of the last row that was read. self.pandas_df_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {} + self.data_batch_dict: Dict[str, Tuple[Any, int]] = {} def exists(self, state_name: str) -> bool: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage @@ -67,9 +68,9 @@ def exists(self, state_name: str) -> bool: def get(self, state_name: str, iterator_id: str) -> Tuple: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage - if iterator_id in self.pandas_df_dict: + if iterator_id in self.data_batch_dict: # If the state is already in the dictionary, return the next row. - pandas_df, index = self.pandas_df_dict[iterator_id] + data_batch, index = self.data_batch_dict[iterator_id] else: # If the state is not in the dictionary, fetch the state from the server. get_call = stateMessage.ListStateGet(iteratorId=iterator_id) @@ -85,33 +86,20 @@ def get(self, state_name: str, iterator_id: str) -> Tuple: response_message = self._stateful_processor_api_client._receive_proto_message() status = response_message[0] if status == 0: - iterator = self._stateful_processor_api_client._read_arrow_state() - # We need to exhaust the iterator here to make sure all the arrow batches are read, - # even though there is only one batch in the iterator. Otherwise, the stream might - # block further reads since it thinks there might still be some arrow batches left. - # We only need to read the first batch in the iterator because it's guaranteed that - # there would only be one batch sent from the JVM side. - data_batch = None - for batch in iterator: - if data_batch is None: - data_batch = batch - if data_batch is None: - # TODO(SPARK-49233): Classify user facing errors. - raise PySparkRuntimeError("Error getting next list state row.") - pandas_df = data_batch.to_pandas() + data_batch = self._stateful_processor_api_client._read_list_state() index = 0 else: raise StopIteration() new_index = index + 1 - if new_index < len(pandas_df): + if new_index < len(data_batch): # Update the index in the dictionary. - self.pandas_df_dict[iterator_id] = (pandas_df, new_index) + self.data_batch_dict[iterator_id] = (data_batch, new_index) else: - # If the index is at the end of the DataFrame, remove the state from the dictionary. - self.pandas_df_dict.pop(iterator_id, None) - pandas_row = pandas_df.iloc[index] - return tuple(pandas_row) + # If the index is at the end of the data batch, remove the state from the dictionary. 
+ self.data_batch_dict.pop(iterator_id, None) + row = data_batch[index] + return tuple(row) def append_value(self, state_name: str, value: Tuple) -> None: import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage @@ -143,7 +131,7 @@ def append_list(self, state_name: str, values: List[Tuple]) -> None: self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) - self._stateful_processor_api_client._send_arrow_state(self.schema, values) + self._stateful_processor_api_client._send_list_state(self.schema, values) response_message = self._stateful_processor_api_client._receive_proto_message() status = response_message[0] if status != 0: @@ -160,7 +148,7 @@ def put(self, state_name: str, values: List[Tuple]) -> None: self._stateful_processor_api_client._send_proto_message(message.SerializeToString()) - self._stateful_processor_api_client._send_arrow_state(self.schema, values) + self._stateful_processor_api_client._send_list_state(self.schema, values) response_message = self._stateful_processor_api_client._receive_proto_message() status = response_message[0] if status != 0: diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py index 6fd56481bc612..df0b1aea8f439 100644 --- a/python/pyspark/sql/streaming/stateful_processor_api_client.py +++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py @@ -455,6 +455,16 @@ def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None: def _read_arrow_state(self) -> Any: return self.serializer.load_stream(self.sockfile) + def _read_list_state(self) -> List[Any]: + data_array = [] + while True: + length = read_int(self.sockfile) + if length < 0: + break + bytes = self.sockfile.read(length) + data_array.append(self._deserialize_from_bytes(bytes)) + return data_array + # Parse a string schema into a StructType schema. This method will perform an API call to # JVM side to parse the schema string. 
def _parse_string_schema(self, schema: str) -> StructType: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala index 1a8ffb35c0533..b38697aeb0219 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasDeserializer.scala @@ -26,6 +26,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader import org.apache.spark.internal.Logging import org.apache.spark.sql.Row +import org.apache.spark.sql.api.python.PythonSQLUtils import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} @@ -57,4 +58,24 @@ class TransformWithStateInPandasDeserializer(deserializer: ExpressionEncoder.Des reader.close(false) rows.toSeq } + + def readListElements(stream: DataInputStream, listStateInfo: ListStateInfo): Seq[Row] = { + val rows = new scala.collection.mutable.ArrayBuffer[Row] + + var endOfLoop = false + while (!endOfLoop) { + val size = stream.readInt() + if (size < 0) { + endOfLoop = true + } else { + val bytes = new Array[Byte](size) + stream.read(bytes, 0, size) + val newRow = PythonSQLUtils.toJVMRow(bytes, listStateInfo.schema, + listStateInfo.deserializer) + rows.append(newRow) + } + } + + rows.toSeq + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala index 35e56cd757983..cd7ce5536c74a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala @@ -463,7 +463,7 @@ class TransformWithStateInPandasStateServer( sendResponse(2, s"state $stateName doesn't exist") } case ListStateCall.MethodCase.LISTSTATEPUT => - val rows = deserializer.readArrowBatches(inputStream) + val rows = deserializer.readListElements(inputStream, listStateInfo) listStateInfo.listState.put(rows.toArray) sendResponse(0) case ListStateCall.MethodCase.LISTSTATEGET => @@ -475,12 +475,10 @@ class TransformWithStateInPandasStateServer( } if (!iteratorOption.get.hasNext) { sendResponse(2, s"List state $stateName doesn't contain any value.") - return } else { sendResponse(0) + sendIteratorForListState(iteratorOption.get) } - sendIteratorAsArrowBatches(iteratorOption.get, listStateInfo.schema, - arrowStreamWriterForTest) { data => listStateInfo.serializer(data)} case ListStateCall.MethodCase.APPENDVALUE => val byteArray = message.getAppendValue.getValue.toByteArray val newRow = PythonSQLUtils.toJVMRow(byteArray, listStateInfo.schema, @@ -488,7 +486,7 @@ class TransformWithStateInPandasStateServer( listStateInfo.listState.appendValue(newRow) sendResponse(0) case ListStateCall.MethodCase.APPENDLIST => - val rows = deserializer.readArrowBatches(inputStream) + val rows = deserializer.readListElements(inputStream, listStateInfo) listStateInfo.listState.appendList(rows.toArray) sendResponse(0) case ListStateCall.MethodCase.CLEAR => @@ -499,6 +497,28 @@ class 
TransformWithStateInPandasStateServer(
     }
   }
 
+  private def sendIteratorForListState(iter: Iterator[Row]): Unit = {
+    // Only write a single batch in each GET request. Stop writing rows once rowCount reaches
+    // the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This handles the case where
+    // there are multiple state variables and the user accesses a different state variable
+    // while the current one is not exhausted yet.
+    var rowCount = 0
+    while (iter.hasNext && rowCount < arrowTransformWithStateInPandasMaxRecordsPerBatch) {
+      val data = iter.next()
+
+      // Serialize the value row as a byte array
+      val valueBytes = PythonSQLUtils.toPyRow(data)
+      val lenBytes = valueBytes.length
+
+      outputStream.writeInt(lenBytes)
+      outputStream.write(valueBytes)
+
+      rowCount += 1
+    }
+    outputStream.writeInt(-1)
+    outputStream.flush()
+  }
+
   private[sql] def handleMapStateRequest(message: MapStateCall): Unit = {
     val stateName = message.getStateName
     if (!mapStates.contains(stateName)) {

From 59223246f06f988cdfe824040456713552896190 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Wed, 2 Apr 2025 14:39:41 +0900
Subject: [PATCH 2/5] fix style + missing method

---
 python/pyspark/sql/streaming/list_state_client.py | 2 +-
 .../sql/streaming/stateful_processor_api_client.py | 10 ++++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/sql/streaming/list_state_client.py b/python/pyspark/sql/streaming/list_state_client.py
index 3441b8b2af2f6..ba516297743e6 100644
--- a/python/pyspark/sql/streaming/list_state_client.py
+++ b/python/pyspark/sql/streaming/list_state_client.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-from typing import Dict, Iterator, List, Union, Tuple
+from typing import Any, Dict, Iterator, List, Union, Tuple
 
 from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
 from pyspark.sql.types import StructType, TYPE_CHECKING
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index df0b1aea8f439..72b190aee1cfe 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -455,6 +455,16 @@ def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None:
     def _read_arrow_state(self) -> Any:
         return self.serializer.load_stream(self.sockfile)
 
+    def _send_list_state(self, schema: StructType, state: List[Tuple]) -> None:
+        for value in state:
+            bytes = self._serialize_to_bytes(schema, value)
+            length = len(bytes)
+            write_int(length, self.sockfile)
+            self.sockfile.write(bytes)
+
+        write_int(-1, self.sockfile)
+        self.sockfile.flush()
+
     def _read_list_state(self) -> List[Any]:
         data_array = []
         while True:

From 3c1722c21746af2d1720aa6c38e9074be5816603 Mon Sep 17 00:00:00 2001
From: Jungtaek Lim
Date: Wed, 2 Apr 2025 15:49:00 +0900
Subject: [PATCH 3/5] Fix StateServer suite

---
 ...ormWithStateInPandasStateServerSuite.scala | 36 ++++++++++++-------
 1 file changed, 24 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
index 1f0aa72d27131..305a520f6af80 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
@@ -103,6 +103,8 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
       listStateMap, iteratorMap, mapStateMap, keyValueIteratorMap, expiryTimerIter, listTimerMap)
     when(transformWithStateInPandasDeserializer.readArrowBatches(any))
       .thenReturn(Seq(getIntegerRow(1)))
+    when(transformWithStateInPandasDeserializer.readListElements(any, any))
+      .thenReturn(Seq(getIntegerRow(1)))
   }
 
   test("set handle state") {
@@ -260,8 +262,10 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
       .setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
     stateServer.handleListStateRequest(message)
     verify(listState, times(0)).get()
-    verify(arrowStreamWriter).writeRow(any)
-    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+    // 1 for row, 1 for end of the data, 1 for proto response
+    verify(outputStream, times(3)).writeInt(any)
+    // 1 for sending an actual row, 1 for sending proto message
+    verify(outputStream, times(2)).write(any[Array[Byte]])
   }
 
   test("list state get - iterator in map with multiple batches") {
@@ -278,15 +282,20 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
     // First call should send 2 records.
     stateServer.handleListStateRequest(message)
     verify(listState, times(0)).get()
-    verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any)
-    verify(arrowStreamWriter).finalizeCurrentArrowBatch()
+    // maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto response
+    verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any)
+    // maxRecordsPerBatch times for rows, 1 for sending proto message
+    verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]])
     // Second call should send the remaining 2 records.
     stateServer.handleListStateRequest(message)
     verify(listState, times(0)).get()
-    // Since Mockito's verify counts the total number of calls, the expected number of writeRow call
-    // should be 2 * maxRecordsPerBatch.
-    verify(arrowStreamWriter, times(2 * maxRecordsPerBatch)).writeRow(any)
-    verify(arrowStreamWriter, times(2)).finalizeCurrentArrowBatch()
+    // Since Mockito's verify counts the total number of calls, the expected counts for writeInt
+    // and write accumulate on top of the prior call; the number of calls added here is the same
+    // as in the prior call.
+    // maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto response
+    verify(outputStream, times(maxRecordsPerBatch * 2 + 4)).writeInt(any)
+    // maxRecordsPerBatch times for rows, 1 for sending proto message
+    verify(outputStream, times(maxRecordsPerBatch * 2 + 2)).write(any[Array[Byte]])
   }
 
   test("list state get - iterator not in map") {
@@ -302,17 +311,20 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
     when(listState.get()).thenReturn(Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3)))
     stateServer.handleListStateRequest(message)
     verify(listState).get()
+
     // Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still
     // having 1 row left in the iterator.
- verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any) - verify(arrowStreamWriter).finalizeCurrentArrowBatch() + // maxRecordsPerBatch (2) for rows, 1 for end of the data, 1 for proto response + verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any) + // 2 for rows, 1 for proto message + verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]]) } test("list state put") { val message = ListStateCall.newBuilder().setStateName(stateName) .setListStatePut(ListStatePut.newBuilder().build()).build() stateServer.handleListStateRequest(message) - verify(transformWithStateInPandasDeserializer).readArrowBatches(any) + verify(transformWithStateInPandasDeserializer).readListElements(any, any) verify(listState).put(any) } @@ -328,7 +340,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo val message = ListStateCall.newBuilder().setStateName(stateName) .setAppendList(AppendList.newBuilder().build()).build() stateServer.handleListStateRequest(message) - verify(transformWithStateInPandasDeserializer).readArrowBatches(any) + verify(transformWithStateInPandasDeserializer).readListElements(any, any) verify(listState).appendList(any) } From f68aee60f77e77ae65d70346c20dd658f879ded5 Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 4 Apr 2025 10:11:07 +0900 Subject: [PATCH 4/5] fix nit --- python/pyspark/sql/streaming/list_state_client.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/python/pyspark/sql/streaming/list_state_client.py b/python/pyspark/sql/streaming/list_state_client.py index ba516297743e6..186fd328b0c5f 100644 --- a/python/pyspark/sql/streaming/list_state_client.py +++ b/python/pyspark/sql/streaming/list_state_client.py @@ -38,9 +38,8 @@ def __init__( self.schema = self._stateful_processor_api_client._parse_string_schema(schema) else: self.schema = schema - # A dictionary to store the mapping between list state name and a tuple of pandas DataFrame + # A dictionary to store the mapping between list state name and a tuple of data batch # and the index of the last row that was read. - self.pandas_df_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {} self.data_batch_dict: Dict[str, Tuple[Any, int]] = {} def exists(self, state_name: str) -> bool: From db65e2cf08b2c9135b97d82f2ae1e217f1e9262a Mon Sep 17 00:00:00 2001 From: Jungtaek Lim Date: Fri, 4 Apr 2025 16:00:00 +0900 Subject: [PATCH 5/5] remove unused import --- python/pyspark/sql/streaming/list_state_client.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/pyspark/sql/streaming/list_state_client.py b/python/pyspark/sql/streaming/list_state_client.py index 186fd328b0c5f..66f2640c935e5 100644 --- a/python/pyspark/sql/streaming/list_state_client.py +++ b/python/pyspark/sql/streaming/list_state_client.py @@ -17,13 +17,10 @@ from typing import Any, Dict, Iterator, List, Union, Tuple from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient -from pyspark.sql.types import StructType, TYPE_CHECKING +from pyspark.sql.types import StructType from pyspark.errors import PySparkRuntimeError import uuid -if TYPE_CHECKING: - from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike - __all__ = ["ListStateClient"]
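Note on the wire format adopted by these patches: list state elements are no longer shipped as Arrow batches. Instead, each element is serialized to bytes and written as a 4-byte big-endian length followed by the payload, and a length of -1 terminates the stream (see _send_list_state/_read_list_state on the Python side and sendIteratorForListState/readListElements on the JVM side). The sketch below illustrates only that framing over an in-memory buffer; it is illustrative, assuming pickle as a stand-in for the client's _serialize_to_bytes/_deserialize_from_bytes helpers and struct.pack("!i", ...) in place of PySpark's write_int/read_int.

import io
import pickle
import struct
from typing import Any, BinaryIO, List, Tuple


def send_list_state(out: BinaryIO, rows: List[Tuple]) -> None:
    # Each element: 4-byte big-endian length, then the serialized payload.
    for row in rows:
        payload = pickle.dumps(row)  # stand-in for _serialize_to_bytes(schema, row)
        out.write(struct.pack("!i", len(payload)))
        out.write(payload)
    # A length of -1 marks the end of the list.
    out.write(struct.pack("!i", -1))
    out.flush()


def read_list_state(inp: BinaryIO) -> List[Any]:
    # Keep reading length-prefixed elements until the -1 terminator.
    rows = []
    while True:
        (length,) = struct.unpack("!i", inp.read(4))
        if length < 0:
            break
        rows.append(pickle.loads(inp.read(length)))  # stand-in for _deserialize_from_bytes
    return rows


if __name__ == "__main__":
    buf = io.BytesIO()
    send_list_state(buf, [(1, "a"), (2, "b")])
    buf.seek(0)
    assert read_list_state(buf) == [(1, "a"), (2, "b")]

Because every element is framed independently, the server can stop after arrowTransformWithStateInPandasMaxRecordsPerBatch rows and emit the -1 terminator, which is what sendIteratorForListState does when serving a GET request in batches.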