[SPARK-51690][SS] Change the protocol of ListState.put()/get()/appendList() from Arrow to simple custom protocol #50488

Closed · wants to merge 5 commits
46 changes: 15 additions & 31 deletions python/pyspark/sql/streaming/list_state_client.py
@@ -14,16 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Dict, Iterator, List, Union, Tuple
from typing import Any, Dict, Iterator, List, Union, Tuple

from pyspark.sql.streaming.stateful_processor_api_client import StatefulProcessorApiClient
from pyspark.sql.types import StructType, TYPE_CHECKING
from pyspark.sql.types import StructType
from pyspark.errors import PySparkRuntimeError
import uuid

if TYPE_CHECKING:
from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike

__all__ = ["ListStateClient"]


@@ -38,9 +35,9 @@ def __init__(
self.schema = self._stateful_processor_api_client._parse_string_schema(schema)
else:
self.schema = schema
# A dictionary to store the mapping between list state name and a tuple of pandas DataFrame
# A dictionary to store the mapping between an iterator id and a tuple of the data batch
# and the index of the next row to read.
self.pandas_df_dict: Dict[str, Tuple["PandasDataFrameLike", int]] = {}
self.data_batch_dict: Dict[str, Tuple[Any, int]] = {}

def exists(self, state_name: str) -> bool:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -67,9 +64,9 @@ def exists(self, state_name: str) -> bool:
def get(self, state_name: str, iterator_id: str) -> Tuple:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage

if iterator_id in self.pandas_df_dict:
if iterator_id in self.data_batch_dict:
# If the state is already in the dictionary, return the next row.
pandas_df, index = self.pandas_df_dict[iterator_id]
data_batch, index = self.data_batch_dict[iterator_id]
else:
# If the state is not in the dictionary, fetch the state from the server.
get_call = stateMessage.ListStateGet(iteratorId=iterator_id)
@@ -85,33 +82,20 @@ def get(self, state_name: str, iterator_id: str) -> Tuple:
response_message = self._stateful_processor_api_client._receive_proto_message()
status = response_message[0]
if status == 0:
iterator = self._stateful_processor_api_client._read_arrow_state()
# We need to exhaust the iterator here to make sure all the arrow batches are read,
# even though there is only one batch in the iterator. Otherwise, the stream might
# block further reads since it thinks there might still be some arrow batches left.
# We only need to read the first batch in the iterator because it's guaranteed that
# there would only be one batch sent from the JVM side.
data_batch = None
for batch in iterator:
if data_batch is None:
data_batch = batch
if data_batch is None:
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError("Error getting next list state row.")
pandas_df = data_batch.to_pandas()
data_batch = self._stateful_processor_api_client._read_list_state()
index = 0
else:
raise StopIteration()

new_index = index + 1
if new_index < len(pandas_df):
if new_index < len(data_batch):
# Update the index in the dictionary.
self.pandas_df_dict[iterator_id] = (pandas_df, new_index)
self.data_batch_dict[iterator_id] = (data_batch, new_index)
else:
# If the index is at the end of the DataFrame, remove the state from the dictionary.
self.pandas_df_dict.pop(iterator_id, None)
pandas_row = pandas_df.iloc[index]
return tuple(pandas_row)
# If the index is at the end of the data batch, remove the state from the dictionary.
self.data_batch_dict.pop(iterator_id, None)
row = data_batch[index]
return tuple(row)
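
For illustration, here is a small standalone sketch of the caching pattern that get() above relies on: each call returns one row, and the fetched batch plus a cursor index are cached per iterator id so that a new server round trip happens only once the cached batch is drained. The names BatchCursorCache and fetch_batch are hypothetical stand-ins for this sketch and are not part of the PR; fetch_batch plays the role of one LISTSTATEGET round trip.

from typing import Any, Callable, Dict, List, Tuple

class BatchCursorCache:
    """Return one row per call, caching (batch, index) per iterator id between calls."""

    def __init__(self, fetch_batch: Callable[[str], List[Tuple[Any, ...]]]) -> None:
        # fetch_batch is assumed to return a non-empty batch or raise StopIteration,
        # mirroring the status check in ListStateClient.get().
        self._fetch_batch = fetch_batch
        self._cache: Dict[str, Tuple[List[Tuple[Any, ...]], int]] = {}

    def next_row(self, iterator_id: str) -> Tuple[Any, ...]:
        if iterator_id in self._cache:
            batch, index = self._cache[iterator_id]
        else:
            batch, index = self._fetch_batch(iterator_id), 0  # new round trip to the server
        if index + 1 < len(batch):
            self._cache[iterator_id] = (batch, index + 1)  # keep the cursor for the next call
        else:
            self._cache.pop(iterator_id, None)  # batch drained; refetch on the next call
        return tuple(batch[index])

# Example with a fake fetcher that serves a single three-row batch:
cache = BatchCursorCache(lambda _id: [(1,), (2,), (3,)])
assert [cache.next_row("it-1") for _ in range(3)] == [(1,), (2,), (3,)]
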

def append_value(self, state_name: str, value: Tuple) -> None:
import pyspark.sql.streaming.proto.StateMessage_pb2 as stateMessage
@@ -143,7 +127,7 @@ def append_list(self, state_name: str, values: List[Tuple]) -> None:

self._stateful_processor_api_client._send_proto_message(message.SerializeToString())

self._stateful_processor_api_client._send_arrow_state(self.schema, values)
self._stateful_processor_api_client._send_list_state(self.schema, values)
Contributor:

nit: should we add a TODO for the other places we might change this in the future?

Contributor Author:

Maybe I can first confirm with benchmarking and then address the other parts. I benchmarked this part, and the change here has to be backed by the numbers.

Contributor:

+1, we use the same framework for MapState as well.

response_message = self._stateful_processor_api_client._receive_proto_message()
status = response_message[0]
if status != 0:
@@ -160,7 +144,7 @@ def put(self, state_name: str, values: List[Tuple]) -> None:

self._stateful_processor_api_client._send_proto_message(message.SerializeToString())

self._stateful_processor_api_client._send_arrow_state(self.schema, values)
self._stateful_processor_api_client._send_list_state(self.schema, values)
response_message = self._stateful_processor_api_client._receive_proto_message()
status = response_message[0]
if status != 0:
20 changes: 20 additions & 0 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -455,6 +455,26 @@ def _send_arrow_state(self, schema: StructType, state: List[Tuple]) -> None:
def _read_arrow_state(self) -> Any:
return self.serializer.load_stream(self.sockfile)

def _send_list_state(self, schema: StructType, state: List[Tuple]) -> None:
for value in state:
bytes = self._serialize_to_bytes(schema, value)
length = len(bytes)
write_int(length, self.sockfile)
self.sockfile.write(bytes)

write_int(-1, self.sockfile)
self.sockfile.flush()

def _read_list_state(self) -> List[Any]:
data_array = []
while True:
length = read_int(self.sockfile)
if length < 0:
break
bytes = self.sockfile.read(length)
data_array.append(self._deserialize_from_bytes(bytes))
return data_array
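
The two helpers above define the replacement for the Arrow path: each list element is sent as a 4-byte length prefix followed by the serialized row bytes, and a length of -1 terminates the stream. Below is a minimal, self-contained sketch of that framing, using an in-memory buffer and pickle as stand-ins for the socket file and the row serializers (_serialize_to_bytes/_deserialize_from_bytes); it is illustrative only, not the PR's code.

import io
import pickle
import struct

def write_frames(stream, values):
    # Mirrors _send_list_state: [int32 length][payload] per element, then -1 as a terminator.
    for value in values:
        payload = pickle.dumps(value)  # stand-in for _serialize_to_bytes(schema, value)
        stream.write(struct.pack(">i", len(payload)))  # big-endian 4-byte length, like write_int
        stream.write(payload)
    stream.write(struct.pack(">i", -1))
    stream.flush()

def read_frames(stream):
    # Mirrors _read_list_state: read length-prefixed payloads until the -1 terminator.
    values = []
    while True:
        (length,) = struct.unpack(">i", stream.read(4))
        if length < 0:
            break
        values.append(pickle.loads(stream.read(length)))  # stand-in for _deserialize_from_bytes
    return values

buf = io.BytesIO()
write_frames(buf, [(1, "a"), (2, "b")])
buf.seek(0)
assert read_frames(buf) == [(1, "a"), (2, "b")]
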

# Parse a string schema into a StructType schema. This method will perform an API call to
# JVM side to parse the schema string.
def _parse_string_schema(self, schema: str) -> StructType:
TransformWithStateInPandasDeserializer.scala
@@ -26,6 +26,7 @@ import org.apache.arrow.vector.ipc.ArrowStreamReader

import org.apache.spark.internal.Logging
import org.apache.spark.sql.Row
import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector}
@@ -57,4 +58,24 @@ class TransformWithStateInPandasDeserializer(deserializer: ExpressionEncoder.Des
reader.close(false)
rows.toSeq
}

def readListElements(stream: DataInputStream, listStateInfo: ListStateInfo): Seq[Row] = {
val rows = new scala.collection.mutable.ArrayBuffer[Row]

var endOfLoop = false
while (!endOfLoop) {
val size = stream.readInt()
if (size < 0) {
endOfLoop = true
} else {
val bytes = new Array[Byte](size)
stream.read(bytes, 0, size)
val newRow = PythonSQLUtils.toJVMRow(bytes, listStateInfo.schema,
listStateInfo.deserializer)
rows.append(newRow)
}
}

rows.toSeq
}
}
TransformWithStateInPandasStateServer.scala
@@ -463,7 +463,7 @@ class TransformWithStateInPandasStateServer(
sendResponse(2, s"state $stateName doesn't exist")
}
case ListStateCall.MethodCase.LISTSTATEPUT =>
val rows = deserializer.readArrowBatches(inputStream)
val rows = deserializer.readListElements(inputStream, listStateInfo)
listStateInfo.listState.put(rows.toArray)
sendResponse(0)
case ListStateCall.MethodCase.LISTSTATEGET =>
@@ -475,20 +475,18 @@
}
if (!iteratorOption.get.hasNext) {
sendResponse(2, s"List state $stateName doesn't contain any value.")
return
} else {
sendResponse(0)
sendIteratorForListState(iteratorOption.get)
}
sendIteratorAsArrowBatches(iteratorOption.get, listStateInfo.schema,
arrowStreamWriterForTest) { data => listStateInfo.serializer(data)}
case ListStateCall.MethodCase.APPENDVALUE =>
val byteArray = message.getAppendValue.getValue.toByteArray
val newRow = PythonSQLUtils.toJVMRow(byteArray, listStateInfo.schema,
listStateInfo.deserializer)
listStateInfo.listState.appendValue(newRow)
sendResponse(0)
case ListStateCall.MethodCase.APPENDLIST =>
val rows = deserializer.readArrowBatches(inputStream)
val rows = deserializer.readListElements(inputStream, listStateInfo)
listStateInfo.listState.appendList(rows.toArray)
sendResponse(0)
case ListStateCall.MethodCase.CLEAR =>
@@ -499,6 +497,28 @@
}
}

private def sendIteratorForListState(iter: Iterator[Row]): Unit = {
// Only write a single batch in each GET request. Stop writing rows once rowCount reaches
// the arrowTransformWithStateInPandasMaxRecordsPerBatch limit. This handles the case where
// there are multiple state variables and the user accesses a different state variable
// while the current state variable's iterator is not exhausted yet.
var rowCount = 0
while (iter.hasNext && rowCount < arrowTransformWithStateInPandasMaxRecordsPerBatch) {
val data = iter.next()

// Serialize the value row as a byte array
val valueBytes = PythonSQLUtils.toPyRow(data)
val lenBytes = valueBytes.length

outputStream.writeInt(lenBytes)
outputStream.write(valueBytes)

rowCount += 1
}
outputStream.writeInt(-1)
outputStream.flush()
}
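
To make the per-GET batching above concrete, here is a rough Python simulation of the exchange: the server hands out at most maxRecordsPerBatch rows per GET, and the client keeps issuing GETs until nothing is left. The helper names serve_one_get and drain_list_state are hypothetical; in the real protocol, exhaustion is signalled by a non-zero status response rather than an empty batch.

from typing import Iterator, List, Tuple

def serve_one_get(it: Iterator[Tuple], max_records: int) -> List[Tuple]:
    # Server side of one GET: drain at most max_records rows from the shared state iterator.
    batch: List[Tuple] = []
    for row in it:
        batch.append(row)
        if len(batch) == max_records:
            break
    return batch

def drain_list_state(it: Iterator[Tuple], max_records: int) -> List[Tuple]:
    # Client side: keep issuing GETs until the server has no more rows for this iterator.
    rows: List[Tuple] = []
    while True:
        batch = serve_one_get(it, max_records)
        if not batch:
            return rows
        rows.extend(batch)

# Five rows served two at a time still arrive in order across three GETs.
assert drain_list_state(iter([(i,) for i in range(5)]), 2) == [(i,) for i in range(5)]
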

private[sql] def handleMapStateRequest(message: MapStateCall): Unit = {
val stateName = message.getStateName
if (!mapStates.contains(stateName)) {
TransformWithStateInPandasStateServerSuite.scala
@@ -103,6 +103,8 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
listStateMap, iteratorMap, mapStateMap, keyValueIteratorMap, expiryTimerIter, listTimerMap)
when(transformWithStateInPandasDeserializer.readArrowBatches(any))
.thenReturn(Seq(getIntegerRow(1)))
when(transformWithStateInPandasDeserializer.readListElements(any, any))
.thenReturn(Seq(getIntegerRow(1)))
}

test("set handle state") {
@@ -260,8 +262,10 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
.setListStateGet(ListStateGet.newBuilder().setIteratorId(iteratorId).build()).build()
stateServer.handleListStateRequest(message)
verify(listState, times(0)).get()
verify(arrowStreamWriter).writeRow(any)
verify(arrowStreamWriter).finalizeCurrentArrowBatch()
// 1 for row, 1 for end of the data, 1 for proto response
verify(outputStream, times(3)).writeInt(any)
// 1 for sending an actual row, 1 for sending proto message
verify(outputStream, times(2)).write(any[Array[Byte]])
}

test("list state get - iterator in map with multiple batches") {
@@ -278,15 +282,20 @@
// First call should send 2 records.
stateServer.handleListStateRequest(message)
verify(listState, times(0)).get()
verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any)
verify(arrowStreamWriter).finalizeCurrentArrowBatch()
// maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto response
verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any)
// maxRecordsPerBatch times for rows, 1 for sending proto message
verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]])
// Second call should send the remaining 2 records.
stateServer.handleListStateRequest(message)
verify(listState, times(0)).get()
// Since Mockito's verify counts the total number of calls, the expected number of writeRow call
// should be 2 * maxRecordsPerBatch.
verify(arrowStreamWriter, times(2 * maxRecordsPerBatch)).writeRow(any)
verify(arrowStreamWriter, times(2)).finalizeCurrentArrowBatch()
// Since Mockito's verify counts the total number of calls, the expected writeInt and write
// counts accumulate on top of the first call's counts; this second call adds the same
// number of writes as the first one.
// maxRecordsPerBatch times for rows, 1 for end of the data, 1 for proto response
verify(outputStream, times(maxRecordsPerBatch * 2 + 4)).writeInt(any)
// maxRecordsPerBatch times for rows, 1 for sending proto message
verify(outputStream, times(maxRecordsPerBatch * 2 + 2)).write(any[Array[Byte]])
}

test("list state get - iterator not in map") {
@@ -302,17 +311,20 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
when(listState.get()).thenReturn(Iterator(getIntegerRow(1), getIntegerRow(2), getIntegerRow(3)))
stateServer.handleListStateRequest(message)
verify(listState).get()

// Verify that only maxRecordsPerBatch (2) rows are written to the output stream while still
// having 1 row left in the iterator.
verify(arrowStreamWriter, times(maxRecordsPerBatch)).writeRow(any)
verify(arrowStreamWriter).finalizeCurrentArrowBatch()
// maxRecordsPerBatch (2) for rows, 1 for end of the data, 1 for proto response
verify(outputStream, times(maxRecordsPerBatch + 2)).writeInt(any)
// 2 for rows, 1 for proto message
verify(outputStream, times(maxRecordsPerBatch + 1)).write(any[Array[Byte]])
}

test("list state put") {
val message = ListStateCall.newBuilder().setStateName(stateName)
.setListStatePut(ListStatePut.newBuilder().build()).build()
stateServer.handleListStateRequest(message)
verify(transformWithStateInPandasDeserializer).readArrowBatches(any)
verify(transformWithStateInPandasDeserializer).readListElements(any, any)
verify(listState).put(any)
}

@@ -328,7 +340,7 @@ class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with Befo
val message = ListStateCall.newBuilder().setStateName(stateName)
.setAppendList(AppendList.newBuilder().build()).build()
stateServer.handleListStateRequest(message)
verify(transformWithStateInPandasDeserializer).readArrowBatches(any)
verify(transformWithStateInPandasDeserializer).readListElements(any, any)
verify(listState).appendList(any)
}
