diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index a930488680..f19ea2f84f 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -534,6 +534,7 @@ def get_optimal_samples(
     posterior_transform: ScalarizedPosteriorTransform | None = None,
     objective: MCAcquisitionObjective | None = None,
     return_transformed: bool = False,
+    options: dict | None = None,
 ) -> tuple[Tensor, Tensor]:
     """Draws sample paths from the posterior and maximizes the samples using GD.
 
@@ -551,7 +552,8 @@ def get_optimal_samples(
         objective: An MCAcquisitionObjective, used to negate the objective or otherwise
             transform sample outputs. Cannot be combined with `posterior_transform`.
         return_transformed: If True, return the transformed samples.
-
+        options: Options for generation of initial candidates, passed to
+            gen_batch_initial_conditions.
 
     Returns:
         The optimal input locations and corresponding outputs, x* and f*.
@@ -576,6 +578,12 @@ def get_optimal_samples(
         sample_transform = None
 
     paths = get_matheron_path_model(model=model, sample_shape=torch.Size([num_optima]))
+    suggested_points = prune_inferior_points(
+        model=model,
+        X=model.train_inputs[0],
+        posterior_transform=posterior_transform,
+        objective=objective,
+    )
     optimal_inputs, optimal_outputs = optimize_posterior_samples(
         paths=paths,
         bounds=bounds,
@@ -583,5 +591,7 @@
         num_restarts=num_restarts,
         sample_transform=sample_transform,
         return_transformed=return_transformed,
+        suggested_points=suggested_points,
+        options=options,
     )
     return optimal_inputs, optimal_outputs
diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py
index 7066578b9d..c571009e72 100644
--- a/botorch/utils/sampling.py
+++ b/botorch/utils/sampling.py
@@ -999,10 +999,12 @@ def sparse_to_dense_constraints(
 def optimize_posterior_samples(
     paths: GenericDeterministicModel,
     bounds: Tensor,
-    raw_samples: int = 1024,
-    num_restarts: int = 20,
+    raw_samples: int = 2048,
+    num_restarts: int = 4,
     sample_transform: Callable[[Tensor], Tensor] | None = None,
     return_transformed: bool = False,
+    suggested_points: Tensor | None = None,
+    options: dict | None = None,
 ) -> tuple[Tensor, Tensor]:
     r"""Cheaply maximizes posterior samples by random querying followed by
     gradient-based optimization using SciPy's L-BFGS-B routine.
@@ -1011,12 +1013,19 @@
         paths: Random Fourier Feature-based sample paths from the GP
         bounds: The bounds on the search space.
         raw_samples: The number of samples with which to query the samples initially.
+            Raw samples are cheap to evaluate, so this should ideally be set much higher
+            than num_restarts.
         num_restarts: The number of points selected for gradient-based optimization.
+            Should be set low relative to the number of raw samples for time-efficiency.
         sample_transform: A callable transform of the sample outputs (e.g.
            MCAcquisitionObjective or ScalarizedPosteriorTransform.evaluate) used to
            negate the objective or otherwise transform the output.
         return_transformed: A boolean indicating whether to return the transformed
            or non-transformed samples.
+        suggested_points: Tensor of suggested input locations that are high-valued.
+            These are more densely evaluated during the sampling phase of optimization.
+        options: Options for generation of initial candidates, passed to
+            gen_batch_initial_conditions.
 
     Returns:
         A two-element tuple containing:
@@ -1024,6 +1033,7 @@
         - f_opt: A `num_optima x [batch_size] x m`-dim, optionally
             `num_optima x [batch_size] x 1`-dim, tensor of optimal outputs f*.
     """
+    options = {} if options is None else options
 
     def path_func(x) -> Tensor:
         res = paths(x)
@@ -1032,21 +1042,35 @@ def path_func(x) -> Tensor:
 
         return res.squeeze(-1)
 
-    candidate_set = unnormalize(
-        SobolEngine(dimension=bounds.shape[1], scramble=True).draw(n=raw_samples),
-        bounds=bounds,
-    )
     # queries all samples on all candidates - output shape
     # raw_samples * num_optima * num_models
+    frac_random = 1 if suggested_points is None else options.get("frac_random", 0.9)
+    candidate_set = draw_sobol_samples(
+        bounds=bounds, n=round(raw_samples * frac_random), q=1
+    ).squeeze(-2)
+    if frac_random < 1:
+        perturbed_suggestions = sample_truncated_normal_perturbations(
+            X=suggested_points,
+            n_discrete_points=round(raw_samples * (1 - frac_random)),
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+        )
+        candidate_set = torch.cat((candidate_set, perturbed_suggestions))
+
     candidate_queries = path_func(candidate_set)
-    argtop_k = torch.topk(candidate_queries, num_restarts, dim=-1).indices
-    X_top_k = candidate_set[argtop_k, :]
+    idx = boltzmann_sample(
+        function_values=candidate_queries.unsqueeze(-1),
+        num_samples=num_restarts,
+        eta=options.get("eta", 2.0),
+        replacement=False,
+    )
+    ics = candidate_set[idx, :]
 
     # to avoid circular import, the import occurs here
     from botorch.generation.gen import gen_candidates_scipy
 
     X_top_k, f_top_k = gen_candidates_scipy(
-        X_top_k,
+        ics,
         path_func,
         lower_bounds=bounds[0],
         upper_bounds=bounds[1],
@@ -1101,8 +1125,9 @@ def boltzmann_sample(
         eta *= temp_decrease
         weights = torch.exp(eta * norm_weights)
 
+    # squeeze in case of m = 1 (mono-output provided as batch_size x N x 1)
     return batched_multinomial(
-        weights=weights, num_samples=num_samples, replacement=replacement
+        weights=weights.squeeze(-1), num_samples=num_samples, replacement=replacement
     )
diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py
index b8115ba0af..065ceb6e12 100644
--- a/test/acquisition/test_utils.py
+++ b/test/acquisition/test_utils.py
@@ -33,6 +33,7 @@
     UnsupportedError,
 )
 from botorch.models import SingleTaskGP
+from botorch.utils.test_helpers import get_fully_bayesian_model, get_model
 from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
 from gpytorch.distributions import MultivariateNormal
@@ -413,17 +414,14 @@ def test_project_to_sample_points(self):
 
 
 class TestGetOptimalSamples(BotorchTestCase):
-    def test_get_optimal_samples(self):
-        dims = 3
-        dtype = torch.float64
+    def _test_get_optimal_samples_base(self, model):
+        dims = model.train_inputs[0].shape[1]
+        dtype = model.train_targets.dtype
+        batch_shape = model.batch_shape
         for_testing_speed_kwargs = {"raw_samples": 20, "num_restarts": 2}
         num_optima = 7
-        batch_shape = (3,)
         bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T
-        X = torch.rand(*batch_shape, 4, dims, dtype=dtype)
-        Y = torch.sin(2 * 3.1415 * X).sum(dim=-1, keepdim=True).to(dtype)
-        model = SingleTaskGP(train_X=X, train_Y=Y)
         posterior_transform = ScalarizedPosteriorTransform(
             weights=torch.ones(1, dtype=dtype)
         )
@@ -438,6 +436,7 @@
             num_optima=num_optima,
             **for_testing_speed_kwargs,
         )
+
         correct_X_shape = (num_optima,) + batch_shape + (dims,)
         correct_f_shape = (num_optima,) + batch_shape + (1,)
         self.assertEqual(X_opt_def.shape, correct_X_shape)
@@ -519,6 +518,22 @@ def test_get_optimal_samples(self):
                 **for_testing_speed_kwargs,
             )
 
+    def test_optimal_samples(self):
+        dims = 3
+        dtype = torch.float64
+        X = torch.rand(4, dims, dtype=dtype)
+        Y = torch.sin(2 * 3.1415 * X).sum(dim=-1, keepdim=True).to(dtype)
+        model = get_model(train_X=X, train_Y=Y)
+        self._test_get_optimal_samples_base(model)
+        fully_bayesian_model = get_fully_bayesian_model(
+            train_X=X,
+            train_Y=Y,
+            num_models=4,
+            standardize_model=True,
+            infer_noise=True,
+        )
+        self._test_get_optimal_samples_base(fully_bayesian_model)
+
 
 class TestPreferenceUtils(BotorchTestCase):
     def test_repeat_to_match_aug_dim(self):
diff --git a/test/utils/test_sampling.py b/test/utils/test_sampling.py
index 62dd0b5bbd..a56122a394 100644
--- a/test/utils/test_sampling.py
+++ b/test/utils/test_sampling.py
@@ -578,9 +578,13 @@ def test_optimize_posterior_samples(self):
         dims = 2
         dtype = torch.float64
         eps = 1e-4
-        for_testing_speed_kwargs = {"raw_samples": 128, "num_restarts": 4}
-        nums_optima = (1, 7)
-        batch_shapes = ((), (2,), (3, 2))
+        for_testing_speed_kwargs = {
+            "raw_samples": 64,
+            "num_restarts": 2,
+            "options": {"eta": 10},
+        }
+        nums_optima = (1, 5)
+        batch_shapes = ((), (3,))
         posterior_transforms = (
             None,
             ScalarizedPosteriorTransform(weights=-torch.ones(1, dtype=dtype)),
@@ -589,16 +593,19 @@
             nums_optima, batch_shapes, posterior_transforms
         ):
             bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T
-            X = torch.rand(*batch_shape, 4, dims, dtype=dtype)
+            X = torch.rand(*batch_shape, 3, dims, dtype=dtype)
             Y = torch.pow(X - 0.5, 2).sum(dim=-1, keepdim=True)
 
             # having a noiseless model all but guarantees that the found optima
             # will be better than the observations
-            model = SingleTaskGP(X, Y, torch.full_like(Y, eps))
+            model = SingleTaskGP(
+                train_X=X, train_Y=Y, train_Yvar=torch.full_like(Y, eps)
+            )
             model.covar_module.lengthscale = 0.5
             paths = get_matheron_path_model(
                 model=model, sample_shape=torch.Size([num_optima])
             )
+
             X_opt, f_opt = optimize_posterior_samples(
                 paths=paths,
                 bounds=bounds,
@@ -616,8 +623,6 @@
             self.assertTrue(torch.all(X_opt >= bounds[0]))
             self.assertTrue(torch.all(X_opt <= bounds[1]))
 
-            # Check that the all found optima are larger than the observations
-            # This is not 100% deterministic, but just about.
             Y_queries = paths(X)
             # this is when we negate, so the values should be smaller
             if posterior_transform:
@@ -642,7 +647,7 @@ def test_optimize_posterior_samples_multi_objective(self):
         dims = 2
         dtype = torch.float64
         eps = 1e-4
-        for_testing_speed_kwargs = {"raw_samples": 128, "num_restarts": 4}
+        for_testing_speed_kwargs = {"raw_samples": 64, "num_restarts": 2}
         num_optima = 5
         batch_shape = (3,)