Skip to content

Commit d212568

Browse files
authored
Rework request channel to receive initial payload as additional parameter (#125)
1 parent 17ce10d commit d212568

File tree

19 files changed

+112
-172
lines changed

19 files changed

+112
-172
lines changed

README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ RSocket interface contains 5 methods:
3131
* Request-Stream:
3232

3333
`fun requestStream(payload: Payload): Flow<Payload>`
34-
* Request-Channel:
34+
* Request-Channel:
3535

36-
`fun requestChannel(payloads: Flow<Payload>): Flow<Payload>`
36+
`fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload>`
3737
* Metadata-Push:
3838

3939
`suspend fun metadataPush(metadata: ByteReadPacket)`

benchmarks/src/kotlinMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketKotlinBenchmark.kt

+5-2
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,10 @@ class RSocketKotlinBenchmark : RSocketBenchmark<Payload>() {
5252
it.release()
5353
payloadsFlow
5454
}
55-
requestChannel { it.flowOn(requestStrategy) }
55+
requestChannel { init, payloads ->
56+
init.release()
57+
payloads.flowOn(requestStrategy)
58+
}
5659
}
5760
}
5861
client = runBlocking {
@@ -80,6 +83,6 @@ class RSocketKotlinBenchmark : RSocketBenchmark<Payload>() {
8083

8184
override suspend fun doRequestStream(): Flow<Payload> = client.requestStream(payloadCopy()).flowOn(requestStrategy)
8285

83-
override suspend fun doRequestChannel(): Flow<Payload> = client.requestChannel(payloadsFlow).flowOn(requestStrategy)
86+
override suspend fun doRequestChannel(): Flow<Payload> = client.requestChannel(payloadCopy(), payloadsFlow).flowOn(requestStrategy)
8487

8588
}

examples/interactions/src/jvmMain/kotlin/RequestChannelExample.kt

+3-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ fun main(): Unit = runBlocking {
2525
val server = LocalServer()
2626
RSocketServer().bind(server) {
2727
RSocketRequestHandler {
28-
requestChannel { request ->
28+
requestChannel { init, request ->
29+
println("Init with: ${init.data.readText()}")
2930
request.flowOn(PrefetchStrategy(3, 0)).take(3).flatMapConcat { payload ->
3031
val data = payload.data.readText()
3132
flow {
@@ -50,7 +51,7 @@ fun main(): Unit = runBlocking {
5051
println("Client: No") //no print
5152
}
5253

53-
val response = rSocket.requestChannel(request)
54+
val response = rSocket.requestChannel(Payload("Init"), request)
5455
response.collect {
5556
val data = it.data.readText()
5657
println("Client receives: $data")

examples/multiplatform-chat/build.gradle.kts

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ val kotlinxSerializationVersion: String by rootProject
2828
kotlin {
2929
jvm("serverJvm")
3030
jvm("clientJvm")
31-
js("clientJs", LEGACY) {
31+
js("clientJs", IR) {
3232
browser {
3333
binaries.executable()
3434
}

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocket.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ interface RSocket : Cancellable {
4242
notImplemented("Request Stream")
4343
}
4444

45-
fun requestChannel(payloads: Flow<Payload>): Flow<Payload> {
45+
fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload> {
46+
initPayload.release()
4647
notImplemented("Request Channel")
4748
}
4849
}

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocketRequestHandler.kt

+7-7
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ import io.rsocket.kotlin.payload.*
2121
import kotlinx.coroutines.*
2222
import kotlinx.coroutines.flow.*
2323

24-
class RSocketRequestHandlerBuilder internal constructor() {
24+
public class RSocketRequestHandlerBuilder internal constructor() {
2525
private var metadataPush: (suspend RSocket.(metadata: ByteReadPacket) -> Unit)? = null
2626
private var fireAndForget: (suspend RSocket.(payload: Payload) -> Unit)? = null
2727
private var requestResponse: (suspend RSocket.(payload: Payload) -> Payload)? = null
2828
private var requestStream: (RSocket.(payload: Payload) -> Flow<Payload>)? = null
29-
private var requestChannel: (RSocket.(payloads: Flow<Payload>) -> Flow<Payload>)? = null
29+
private var requestChannel: (RSocket.(initPayload: Payload, payloads: Flow<Payload>) -> Flow<Payload>)? = null
3030

3131
public fun metadataPush(block: (suspend RSocket.(metadata: ByteReadPacket) -> Unit)) {
3232
check(metadataPush == null) { "Metadata Push handler already configured" }
@@ -48,7 +48,7 @@ class RSocketRequestHandlerBuilder internal constructor() {
4848
requestStream = block
4949
}
5050

51-
public fun requestChannel(block: (RSocket.(payloads: Flow<Payload>) -> Flow<Payload>)) {
51+
public fun requestChannel(block: (RSocket.(initPayload: Payload, payloads: Flow<Payload>) -> Flow<Payload>)) {
5252
check(requestChannel == null) { "Request Channel handler already configured" }
5353
requestChannel = block
5454
}
@@ -58,7 +58,7 @@ class RSocketRequestHandlerBuilder internal constructor() {
5858
}
5959

6060
@Suppress("FunctionName")
61-
fun RSocketRequestHandler(parentJob: Job? = null, configure: RSocketRequestHandlerBuilder.() -> Unit): RSocket {
61+
public fun RSocketRequestHandler(parentJob: Job? = null, configure: RSocketRequestHandlerBuilder.() -> Unit): RSocket {
6262
val builder = RSocketRequestHandlerBuilder()
6363
builder.configure()
6464
return builder.build(Job(parentJob))
@@ -70,7 +70,7 @@ private class RSocketRequestHandler(
7070
private val fireAndForget: (suspend RSocket.(payload: Payload) -> Unit)? = null,
7171
private val requestResponse: (suspend RSocket.(payload: Payload) -> Payload)? = null,
7272
private val requestStream: (RSocket.(payload: Payload) -> Flow<Payload>)? = null,
73-
private val requestChannel: (RSocket.(payloads: Flow<Payload>) -> Flow<Payload>)? = null,
73+
private val requestChannel: (RSocket.(initPayload: Payload, payloads: Flow<Payload>) -> Flow<Payload>)? = null,
7474
) : RSocket {
7575
override suspend fun metadataPush(metadata: ByteReadPacket): Unit =
7676
metadataPush?.invoke(this, metadata) ?: super.metadataPush(metadata)
@@ -84,7 +84,7 @@ private class RSocketRequestHandler(
8484
override fun requestStream(payload: Payload): Flow<Payload> =
8585
requestStream?.invoke(this, payload) ?: super.requestStream(payload)
8686

87-
override fun requestChannel(payloads: Flow<Payload>): Flow<Payload> =
88-
requestChannel?.invoke(this, payloads) ?: super.requestChannel(payloads)
87+
override fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload> =
88+
requestChannel?.invoke(this, initPayload, payloads) ?: super.requestChannel(initPayload, payloads)
8989

9090
}

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ private class ReconnectableRSocket(
112112
emitAll(currentRSocket(payload).requestStream(payload))
113113
}
114114

115-
override fun requestChannel(payloads: Flow<Payload>): Flow<Payload> = flow {
116-
emitAll(currentRSocket().requestChannel(payloads))
115+
override fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload> = flow {
116+
emitAll(currentRSocket(initPayload).requestChannel(initPayload, payloads))
117117
}
118118

119119
}

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ internal class RSocketRequester(
5252

5353
override fun requestStream(payload: Payload): Flow<Payload> = RequestStreamRequesterFlow(payload, this, state)
5454

55-
override fun requestChannel(payloads: Flow<Payload>): Flow<Payload> = RequestChannelRequesterFlow(payloads, this, state)
55+
override fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload> =
56+
RequestChannelRequesterFlow(initPayload, payloads, this, state)
5657

5758
fun createStream(): Int {
5859
checkAvailable()

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt

+4-12
Original file line numberDiff line numberDiff line change
@@ -60,31 +60,23 @@ internal class RSocketResponder(
6060
val response = requestOrCancel(streamId) {
6161
requestHandler.requestStream(initFrame.payload)
6262
} ?: return@launchCancelable
63-
response.collectLimiting(
64-
streamId,
65-
RequestStreamResponderFlowCollector(state, streamId, initFrame.initialRequest)
66-
)
67-
send(CompletePayloadFrame(streamId))
63+
response.collectLimiting(streamId, initFrame.initialRequest)
6864
}.invokeOnCompletion {
6965
initFrame.release()
7066
}
7167
}
7268

7369
fun handleRequestChannel(initFrame: RequestFrame): Unit = with(state) {
7470
val streamId = initFrame.streamId
75-
val receiver = createReceiverFor(streamId, initFrame)
71+
val receiver = createReceiverFor(streamId)
7672

7773
val request = RequestChannelResponderFlow(streamId, receiver, state)
7874

7975
launchCancelable(streamId) {
8076
val response = requestOrCancel(streamId) {
81-
requestHandler.requestChannel(request)
77+
requestHandler.requestChannel(initFrame.payload, request)
8278
} ?: return@launchCancelable
83-
response.collectLimiting(
84-
streamId,
85-
RequestStreamResponderFlowCollector(state, streamId, initFrame.initialRequest)
86-
)
87-
send(CompletePayloadFrame(streamId))
79+
response.collectLimiting(streamId, initFrame.initialRequest)
8880
}.invokeOnCompletion {
8981
initFrame.release()
9082
receiver.closeReceivedElements()

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt

+6-3
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,8 @@ internal class RSocketState(
5252
prioritizer.sendPrioritized(frame)
5353
}
5454

55-
fun createReceiverFor(streamId: Int, initFrame: RequestFrame? = null): ReceiveChannel<RequestFrame> {
55+
fun createReceiverFor(streamId: Int): ReceiveChannel<RequestFrame> {
5656
val receiver = SafeChannel<RequestFrame>(Channel.UNLIMITED)
57-
initFrame?.let(receiver::offer) //used only in RequestChannel on responder side
5857
receivers[streamId] = receiver
5958
return receiver
6059
}
@@ -94,11 +93,15 @@ internal class RSocketState(
9493

9594
suspend inline fun Flow<Payload>.collectLimiting(
9695
streamId: Int,
97-
limitingCollector: LimitingFlowCollector,
96+
initialRequest: Int,
97+
crossinline onStart: () -> Unit = {},
9898
): Unit = coroutineScope {
99+
val limitingCollector = LimitingFlowCollector(this@RSocketState, streamId, initialRequest)
99100
limits[streamId] = limitingCollector
100101
try {
102+
onStart()
101103
collect(limitingCollector)
104+
send(CompletePayloadFrame(streamId))
102105
} catch (e: Throwable) {
103106
limits.remove(streamId)
104107
//if isn't active, then, that stream was cancelled, and so no need for error frame

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/LimitingFlowCollector.kt

+8-5
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,30 @@
1616

1717
package io.rsocket.kotlin.internal.flow
1818

19+
import io.rsocket.kotlin.frame.*
1920
import io.rsocket.kotlin.internal.*
2021
import io.rsocket.kotlin.payload.*
2122
import kotlinx.atomicfu.*
2223
import kotlinx.coroutines.*
2324
import kotlinx.coroutines.flow.*
2425

25-
internal abstract class LimitingFlowCollector(initial: Int) : FlowCollector<Payload> {
26+
internal class LimitingFlowCollector(
27+
private val state: RSocketState,
28+
private val streamId: Int,
29+
initial: Int,
30+
) : FlowCollector<Payload> {
2631
private val requests = atomic(initial)
2732
private val awaiter = atomic<CancellableContinuation<Unit>?>(null)
2833

29-
abstract suspend fun emitValue(value: Payload)
30-
3134
fun updateRequests(n: Int) {
3235
if (n <= 0) return
3336
requests.getAndAdd(n)
3437
awaiter.getAndSet(null)?.resumeSafely()
3538
}
3639

37-
final override suspend fun emit(value: Payload): Unit = value.closeOnError {
40+
override suspend fun emit(value: Payload): Unit = value.closeOnError {
3841
useRequest()
39-
emitValue(value)
42+
state.send(NextPayloadFrame(streamId, value))
4043
}
4144

4245
private suspend fun useRequest() {

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlow.kt

+20-27
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,14 @@ package io.rsocket.kotlin.internal.flow
1919
import io.rsocket.kotlin.*
2020
import io.rsocket.kotlin.frame.*
2121
import io.rsocket.kotlin.internal.*
22-
import io.rsocket.kotlin.internal.cancelConsumed
2322
import io.rsocket.kotlin.payload.*
2423
import kotlinx.atomicfu.*
2524
import kotlinx.coroutines.*
26-
import kotlinx.coroutines.channels.*
2725
import kotlinx.coroutines.flow.*
2826

29-
@OptIn(ExperimentalStreamsApi::class)
27+
@OptIn(ExperimentalStreamsApi::class, ExperimentalCoroutinesApi::class)
3028
internal class RequestChannelRequesterFlow(
29+
private val initPayload: Payload,
3130
private val payloads: Flow<Payload>,
3231
private val requester: RSocketRequester,
3332
private val state: RSocketState,
@@ -40,31 +39,25 @@ internal class RequestChannelRequesterFlow(
4039

4140
val strategy = currentCoroutineContext().requestStrategy()
4241
val initialRequest = strategy.firstRequest()
43-
val streamId = requester.createStream()
44-
val receiverDeferred = CompletableDeferred<ReceiveChannel<RequestFrame>?>()
45-
val request = launchCancelable(streamId) {
46-
payloads.collectLimiting(
47-
streamId,
48-
RequestChannelRequesterFlowCollector(state, streamId, receiverDeferred, initialRequest)
49-
)
50-
if (receiverDeferred.isCompleted && !receiverDeferred.isCancelled) send(CompletePayloadFrame(streamId))
51-
}
52-
request.invokeOnCompletion {
53-
if (receiverDeferred.isCompleted) {
54-
@OptIn(ExperimentalCoroutinesApi::class)
55-
if (it != null && it !is CancellationException) receiverDeferred.getCompleted()?.cancelConsumed(it)
56-
} else {
57-
if (it == null) receiverDeferred.complete(null)
58-
else receiverDeferred.completeExceptionally(it.cause ?: it)
42+
initPayload.closeOnError {
43+
val streamId = requester.createStream()
44+
val receiver = createReceiverFor(streamId)
45+
val request = launchCancelable(streamId) {
46+
payloads.collectLimiting(streamId, 0) {
47+
send(RequestChannelFrame(streamId, initialRequest, initPayload))
48+
}
49+
}
50+
51+
request.invokeOnCompletion {
52+
if (it != null && it !is CancellationException) receiver.cancelConsumed(it)
53+
}
54+
try {
55+
collectStream(streamId, receiver, strategy, collector)
56+
} catch (e: Throwable) {
57+
if (e is CancellationException) request.cancel(e)
58+
else request.cancel("Receiver failed", e)
59+
throw e
5960
}
60-
}
61-
try {
62-
val receiver = receiverDeferred.await() ?: return
63-
collectStream(streamId, receiver, strategy, collector)
64-
} catch (e: Throwable) {
65-
if (e is CancellationException) request.cancel(e)
66-
else request.cancel("Receiver failed", e)
67-
throw e
6861
}
6962
}
7063
}

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlowCollector.kt

-42
This file was deleted.

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestStreamResponderFlowCollector.kt

-31
This file was deleted.

0 commit comments

Comments
 (0)