Skip to content

Commit fe95395

Browse files
committed
Support splitting OPA filter requests in batch mode by batch size
1 parent 993a088 commit fe95395

File tree

6 files changed

+376
-256
lines changed

6 files changed

+376
-256
lines changed

plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import io.airlift.configuration.Config;
1717
import io.airlift.configuration.ConfigDescription;
1818
import jakarta.validation.constraints.NotNull;
19+
import jakarta.validation.constraints.Positive;
1920

2021
import java.net.URI;
2122
import java.util.Optional;
@@ -31,6 +32,7 @@ public class OpaConfig
3132
private Optional<URI> opaRowFiltersUri = Optional.empty();
3233
private Optional<URI> opaColumnMaskingUri = Optional.empty();
3334
private Optional<URI> opaBatchColumnMaskingUri = Optional.empty();
35+
private Optional<Integer> opaBatchSize = Optional.empty();
3436

3537
@NotNull
3638
public URI getOpaUri()
@@ -46,6 +48,19 @@ public OpaConfig setOpaUri(@NotNull URI opaUri)
4648
return this;
4749
}
4850

51+
@Config("opa.policy.batch-size")
52+
@ConfigDescription("Size of a single batch for OPA requests")
53+
public OpaConfig setOpaBatchSize(Integer batchSize)
54+
{
55+
this.opaBatchSize = Optional.ofNullable(batchSize);
56+
return this;
57+
}
58+
59+
public Optional<@Positive Integer> getOpaBatchSize()
60+
{
61+
return this.opaBatchSize;
62+
}
63+
4964
@NotNull
5065
public Optional<URI> getOpaBatchUri()
5166
{

plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHttpClient.java

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
import static com.google.common.collect.ImmutableList.toImmutableList;
5050
import static com.google.common.collect.ImmutableMap.toImmutableMap;
5151
import static com.google.common.collect.ImmutableSet.toImmutableSet;
52+
import static com.google.common.collect.ImmutableSetMultimap.flatteningToImmutableSetMultimap;
53+
import static com.google.common.collect.Lists.partition;
5254
import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
5355
import static com.google.common.net.MediaType.JSON_UTF_8;
5456
import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler;
@@ -66,6 +68,7 @@ public class OpaHttpClient
6668
private final boolean logRequests;
6769
private final boolean logResponses;
6870
private static final Logger log = Logger.get(OpaHttpClient.class);
71+
private final Optional<Integer> opaBatchSize;
6972

7073
@Inject
7174
public OpaHttpClient(
@@ -79,6 +82,7 @@ public OpaHttpClient(
7982
this.executor = requireNonNull(executor, "executor is null");
8083
this.logRequests = config.getLogRequests();
8184
this.logResponses = config.getLogResponses();
85+
this.opaBatchSize = config.getOpaBatchSize();
8286
}
8387

8488
public <T> FluentFuture<T> submitOpaRequest(OpaQueryInput input, URI uri, JsonCodec<T> deserializer)
@@ -158,25 +162,30 @@ public <T> Set<T> batchFilterFromOpa(Collection<T> items, Function<List<T>, OpaQ
158162

159163
public <K, V> Map<K, Set<V>> parallelBatchFilterFromOpa(Map<K, ? extends Collection<V>> items, BiFunction<K, List<V>, OpaQueryInput> requestBuilder, URI uri, JsonCodec<? extends OpaBatchQueryResult> deserializer)
160164
{
161-
List<Map.Entry<K, ImmutableList<V>>> parallelRequestItems = items.entrySet()
165+
List<Map.Entry<K, List<V>>> parallelRequestItems = items.entrySet()
162166
.stream()
163167
.filter(entry -> !entry.getValue().isEmpty())
164-
.map(entry -> Map.entry(entry.getKey(), ImmutableList.copyOf(entry.getValue())))
168+
.flatMap(entry -> partition(ImmutableList.copyOf(entry.getValue()), this.opaBatchSize.orElse(entry.getValue().size()))
169+
.stream()
170+
.map(partition -> Map.entry(entry.getKey(), partition)))
165171
.collect(toImmutableList());
172+
166173
return parallelRequest(
167-
parallelRequestItems,
168-
entry -> requestBuilder.apply(entry.getKey(), entry.getValue()),
169-
(entry, result) ->
170-
Optional.of(requireNonNullElse(result.result(), ImmutableList.<Integer>of()))
171-
.flatMap(indices -> indices.isEmpty() ? Optional.empty() : Optional.of(indices))
172-
.map(indices -> indices.stream()
173-
.map(index -> entry.getValue().get(index))
174-
.collect(toImmutableSet()))
175-
.map(values -> Map.entry(entry.getKey(), values)),
176-
uri,
177-
deserializer)
174+
parallelRequestItems,
175+
entry -> requestBuilder.apply(entry.getKey(), entry.getValue()),
176+
(entry, result) ->
177+
Optional.of(requireNonNullElse(result.result(), ImmutableList.<Integer>of()))
178+
.flatMap(indices -> indices.isEmpty() ? Optional.empty() : Optional.of(indices))
179+
.map(indices -> indices.stream().map(index -> entry.getValue().get(index)).collect(toImmutableSet()))
180+
.map(values -> Map.entry(entry.getKey(), values)),
181+
uri,
182+
deserializer)
183+
.stream()
184+
.collect(flatteningToImmutableSetMultimap(Map.Entry::getKey, entry -> entry.getValue().stream()))
185+
.asMap()
186+
.entrySet()
178187
.stream()
179-
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
188+
.collect(toImmutableMap(Map.Entry::getKey, entry -> ImmutableSet.copyOf(entry.getValue())));
180189
}
181190

182191
private <T> T parseOpaResponse(FullJsonResponseHandler.JsonResponse<T> response, URI uri)

plugin/trino-opa/src/test/java/io/trino/plugin/opa/RequestTestUtilities.java

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,20 @@ private RequestTestUtilities() {}
3434

3535
private static final JsonMapper jsonMapper = new JsonMapper();
3636

37+
public static JsonNode toJsonNode(String item)
38+
{
39+
try {
40+
return jsonMapper.readTree(item);
41+
}
42+
catch (IOException e) {
43+
throw new IllegalStateException("Cannot parse to a JsonNode", e);
44+
}
45+
}
46+
3747
public static void assertStringRequestsEqual(Set<String> expectedRequests, Collection<JsonNode> actualRequests, String extractPath)
3848
{
3949
Set<JsonNode> parsedExpectedRequests = expectedRequests.stream()
40-
.map(expectedRequest -> {
41-
try {
42-
return jsonMapper.readTree(expectedRequest);
43-
}
44-
catch (IOException e) {
45-
throw new AssertionError("Cannot parse expected request", e);
46-
}
47-
})
50+
.map(RequestTestUtilities::toJsonNode)
4851
.collect(toImmutableSet());
4952
Set<JsonNode> extractedActualRequests = actualRequests.stream().map(node -> node.at(extractPath)).collect(toImmutableSet());
5053
assertThat(extractedActualRequests).containsExactlyInAnyOrderElementsOf(parsedExpectedRequests);

plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ public static Map<String, String> opaConfigToDict(OpaConfig config)
9292
config.getOpaRowFiltersUri().ifPresent(rowFiltersUri -> configBuilder.put("opa.policy.row-filters-uri", rowFiltersUri.toString()));
9393
config.getOpaColumnMaskingUri().ifPresent(columnMaskingUri -> configBuilder.put("opa.policy.column-masking-uri", columnMaskingUri.toString()));
9494
config.getOpaBatchColumnMaskingUri().ifPresent(batchColumnMaskingUri -> configBuilder.put("opa.policy.batch-column-masking-uri", batchColumnMaskingUri.toString()));
95+
config.getOpaBatchSize().ifPresent(batchSize -> configBuilder.put("opa.policy.batch-size", batchSize.toString()));
9596
return configBuilder.buildOrThrow();
9697
}
9798

0 commit comments

Comments
 (0)