Skip to content

Commit abfd602

Browse files
committed
Support batch-size for batch mode
1 parent 3bd1f53 commit abfd602

File tree

5 files changed

+202
-104
lines changed

5 files changed

+202
-104
lines changed

plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaConfig.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import io.airlift.configuration.Config;
1717
import io.airlift.configuration.ConfigDescription;
1818
import jakarta.validation.constraints.NotNull;
19+
import jakarta.validation.constraints.Positive;
1920

2021
import java.net.URI;
2122
import java.util.Optional;
@@ -31,6 +32,7 @@ public class OpaConfig
3132
private Optional<URI> opaRowFiltersUri = Optional.empty();
3233
private Optional<URI> opaColumnMaskingUri = Optional.empty();
3334
private Optional<URI> opaBatchColumnMaskingUri = Optional.empty();
35+
private Optional<Integer> opaBatchSize = Optional.empty();
3436

3537
@NotNull
3638
public URI getOpaUri()
@@ -46,6 +48,19 @@ public OpaConfig setOpaUri(@NotNull URI opaUri)
4648
return this;
4749
}
4850

51+
@Config("opa.policy.batch-size")
52+
@ConfigDescription("Size of a single batch for OPA requests")
53+
public OpaConfig setOpaBatchSize(Integer batchSize)
54+
{
55+
this.opaBatchSize = Optional.ofNullable(batchSize);
56+
return this;
57+
}
58+
59+
public Optional<@Positive Integer> getOpaBatchSize()
60+
{
61+
return this.opaBatchSize;
62+
}
63+
4964
@NotNull
5065
public Optional<URI> getOpaBatchUri()
5166
{

plugin/trino-opa/src/main/java/io/trino/plugin/opa/OpaHttpClient.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
import static com.google.common.collect.ImmutableList.toImmutableList;
5050
import static com.google.common.collect.ImmutableMap.toImmutableMap;
5151
import static com.google.common.collect.ImmutableSet.toImmutableSet;
52+
import static com.google.common.collect.ImmutableSetMultimap.flatteningToImmutableSetMultimap;
53+
import static com.google.common.collect.Lists.partition;
5254
import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
5355
import static com.google.common.net.MediaType.JSON_UTF_8;
5456
import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler;
@@ -66,6 +68,7 @@ public class OpaHttpClient
6668
private final boolean logRequests;
6769
private final boolean logResponses;
6870
private static final Logger log = Logger.get(OpaHttpClient.class);
71+
private final Optional<Integer> opaBatchSize;
6972

7073
@Inject
7174
public OpaHttpClient(
@@ -79,6 +82,7 @@ public OpaHttpClient(
7982
this.executor = requireNonNull(executor, "executor is null");
8083
this.logRequests = config.getLogRequests();
8184
this.logResponses = config.getLogResponses();
85+
this.opaBatchSize = config.getOpaBatchSize();
8286
}
8387

8488
public <T> FluentFuture<T> submitOpaRequest(OpaQueryInput input, URI uri, JsonCodec<T> deserializer)
@@ -158,11 +162,14 @@ public <T> Set<T> batchFilterFromOpa(Collection<T> items, Function<List<T>, OpaQ
158162

159163
public <K, V> Map<K, Set<V>> parallelBatchFilterFromOpa(Map<K, ? extends Collection<V>> items, BiFunction<K, List<V>, OpaQueryInput> requestBuilder, URI uri, JsonCodec<? extends OpaBatchQueryResult> deserializer)
160164
{
161-
List<Map.Entry<K, ImmutableList<V>>> parallelRequestItems = items.entrySet()
165+
List<Map.Entry<K, List<V>>> parallelRequestItems = items.entrySet()
162166
.stream()
163167
.filter(entry -> !entry.getValue().isEmpty())
164-
.map(entry -> Map.entry(entry.getKey(), ImmutableList.copyOf(entry.getValue())))
168+
.flatMap(entry -> partition(ImmutableList.copyOf(entry.getValue()), this.opaBatchSize.orElse(entry.getValue().size()))
169+
.stream()
170+
.map(partition -> Map.entry(entry.getKey(), partition)))
165171
.collect(toImmutableList());
172+
166173
return parallelRequest(
167174
parallelRequestItems,
168175
entry -> requestBuilder.apply(entry.getKey(), entry.getValue()),
@@ -176,7 +183,11 @@ public <K, V> Map<K, Set<V>> parallelBatchFilterFromOpa(Map<K, ? extends Collect
176183
uri,
177184
deserializer)
178185
.stream()
179-
.collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
186+
.collect(flatteningToImmutableSetMultimap(Map.Entry::getKey, entry -> entry.getValue().stream()))
187+
.asMap()
188+
.entrySet()
189+
.stream()
190+
.collect(toImmutableMap(Map.Entry::getKey, entry -> ImmutableSet.copyOf(entry.getValue())));
180191
}
181192

182193
private <T> T parseOpaResponse(FullJsonResponseHandler.JsonResponse<T> response, URI uri)

plugin/trino-opa/src/test/java/io/trino/plugin/opa/TestHelpers.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ public static Map<String, String> opaConfigToDict(OpaConfig config)
9292
config.getOpaRowFiltersUri().ifPresent(rowFiltersUri -> configBuilder.put("opa.policy.row-filters-uri", rowFiltersUri.toString()));
9393
config.getOpaColumnMaskingUri().ifPresent(columnMaskingUri -> configBuilder.put("opa.policy.column-masking-uri", columnMaskingUri.toString()));
9494
config.getOpaBatchColumnMaskingUri().ifPresent(batchColumnMaskingUri -> configBuilder.put("opa.policy.batch-column-masking-uri", batchColumnMaskingUri.toString()));
95+
config.getOpaBatchSize().ifPresent(batchUri -> configBuilder.put("opa.policy.batch-size", batchUri.toString()));
9596
return configBuilder.buildOrThrow();
9697
}
9798

0 commit comments

Comments
 (0)