diff --git a/botorch/acquisition/utils.py b/botorch/acquisition/utils.py
index a930488680..f19ea2f84f 100644
--- a/botorch/acquisition/utils.py
+++ b/botorch/acquisition/utils.py
@@ -534,6 +534,7 @@ def get_optimal_samples(
     posterior_transform: ScalarizedPosteriorTransform | None = None,
     objective: MCAcquisitionObjective | None = None,
     return_transformed: bool = False,
+    options: dict | None = None,
 ) -> tuple[Tensor, Tensor]:
     """Draws sample paths from the posterior and maximizes the samples using GD.
 
@@ -551,7 +552,8 @@
         objective: An MCAcquisitionObjective, used to negate the objective or otherwise
             transform sample outputs. Cannot be combined with `posterior_transform`.
         return_transformed: If True, return the transformed samples.
-
+        options: Options for the generation of initial candidates, passed to
+            `optimize_posterior_samples`.
 
     Returns:
         The optimal input locations and corresponding outputs, x* and f*.
@@ -576,6 +578,12 @@
         sample_transform = None
 
    paths = get_matheron_path_model(model=model, sample_shape=torch.Size([num_optima]))
+    suggested_points = prune_inferior_points(
+        model=model,
+        X=model.train_inputs[0],
+        posterior_transform=posterior_transform,
+        objective=objective,
+    )
     optimal_inputs, optimal_outputs = optimize_posterior_samples(
         paths=paths,
         bounds=bounds,
@@ -583,5 +591,7 @@
         num_restarts=num_restarts,
         sample_transform=sample_transform,
         return_transformed=return_transformed,
+        suggested_points=suggested_points,
+        options=options,
     )
     return optimal_inputs, optimal_outputs
diff --git a/botorch/optim/initializers.py b/botorch/optim/initializers.py
index e5e81f3dcc..040e2b44b3 100644
--- a/botorch/optim/initializers.py
+++ b/botorch/optim/initializers.py
@@ -468,6 +468,128 @@ def gen_batch_initial_conditions(
     return batch_initial_conditions
 
 
+def gen_optimal_input_initial_conditions(
+    acq_function: AcquisitionFunction,
+    bounds: Tensor,
+    q: int,
+    num_restarts: int,
+    raw_samples: int,
+    fixed_features: dict[int, float] | None = None,
+    options: dict[str, bool | float | int] | None = None,
+    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+):
+    r"""Generate a batch of initial conditions for random-restart optimization of
+    information-theoretic acquisition functions (PES & JES), where sampled optimizers
+    of the posterior constitute good initial guesses for further optimization. A
+    fraction of the initial samples (by default: 100%) is drawn as perturbations
+    around `acq_function.optimal_inputs`. On average, this drastically decreases the
+    runtime of acquisition function optimization and yields candidates with higher
+    acquisition function values. See https://github.com/pytorch/botorch/pull/2751
+    for more info.
+
+    Args:
+        acq_function: The acquisition function to be optimized.
+        bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
+        q: The number of candidates to consider.
+        num_restarts: The number of starting points for multistart acquisition
+            function optimization.
+        raw_samples: The number of raw samples to consider in the initialization
+            heuristic. Note: if `sample_around_best` is True (the default is False),
+            then `2 * raw_samples` samples are used.
+        fixed_features: A map `{feature_index: value}` for features that
+            should be fixed to a particular value during generation.
+        options: Options for initial condition generation. These contain all
+            settings for the standard heuristic initialization from
+            `gen_batch_initial_conditions`. In addition, they contain
+            `frac_random` (the fraction of points drawn fully at random, as opposed
+            to around the drawn optimizers from the posterior).
+            `sample_around_best_sigma` dictates the standard deviation of both the
+            samples drawn around posterior maximizers and the samples around the
+            previous best (if enabled).
+        inequality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an inequality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
+        equality_constraints: A list of tuples (indices, coefficients, rhs),
+            with each tuple encoding an equality constraint of the form
+            `\sum_i (X[indices[i]] * coefficients[i]) = rhs`.
+
+    Returns:
+        A `num_restarts x q x d` tensor of initial conditions.
+    """
+    options = options or {}
+    device = bounds.device
+    if not hasattr(acq_function, "optimal_inputs"):
+        raise AttributeError(
+            "gen_optimal_input_initial_conditions can only be used with "
+            "an AcquisitionFunction that has an optimal_inputs attribute."
+        )
+    frac_random: float = options.get("frac_random", 0.0)
+    if not 0 <= frac_random <= 1:
+        raise ValueError(
+            f"frac_random must take on values in [0, 1]. Value: {frac_random}"
+        )
+
+    batch_limit = options.get("batch_limit")
+    num_optima = acq_function.optimal_inputs.shape[:-1].numel()
+    suggestions = acq_function.optimal_inputs.reshape(num_optima, -1)
+    X = torch.empty(0, q, bounds.shape[1], dtype=bounds.dtype)
+    num_random = round(raw_samples * frac_random)
+    if num_random > 0:
+        X_rnd = sample_q_batches_from_polytope(
+            n=num_random,
+            q=q,
+            bounds=bounds,
+            n_burnin=options.get("n_burnin", 10000),
+            n_thinning=options.get("n_thinning", 32),
+            equality_constraints=equality_constraints,
+            inequality_constraints=inequality_constraints,
+        )
+        X = torch.cat((X, X_rnd))
+
+    if num_random < raw_samples:
+        X_perturbed = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * (raw_samples - num_random),
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+            best_X=suggestions,
+        )
+        X_perturbed = X_perturbed.view(
+            raw_samples - num_random, q, bounds.shape[-1]
+        ).cpu()
+        X = torch.cat((X, X_perturbed))
+
+    if options.get("sample_around_best", False):
+        X_best = sample_points_around_best(
+            acq_function=acq_function,
+            n_discrete_points=q * raw_samples,
+            sigma=options.get("sample_around_best_sigma", 1e-2),
+            bounds=bounds,
+        )
+        X_best = X_best.view(raw_samples, q, bounds.shape[-1]).cpu()
+        X = torch.cat((X, X_best))
+
+    X_rnd = fix_features(X, fixed_features=fixed_features).cpu()
+    with torch.no_grad():
+        if batch_limit is None:
+            batch_limit = X.shape[0]
+        # Evaluate the acquisition function on `X_rnd` using `batch_limit`
+        # sized chunks.
+        acq_vals = torch.cat(
+            [
+                acq_function(x_.to(device=device)).cpu()
+                for x_ in X.split(split_size=batch_limit, dim=0)
+            ],
+            dim=0,
+        )
+    idx = boltzmann_sample(
+        function_values=acq_vals,
+        num_samples=num_restarts,
+        eta=options.get("eta", 2.0),
+    )
+    return X[idx]
+
+
 def gen_one_shot_kg_initial_conditions(
     acq_function: qKnowledgeGradient,
     bounds: Tensor,
@@ -1136,6 +1258,7 @@ def sample_points_around_best(
     best_pct: float = 5.0,
     subset_sigma: float = 1e-1,
     prob_perturb: float | None = None,
+    best_X: Tensor | None = None,
 ) -> Tensor | None:
     r"""Find best points and sample nearby points.
 
@@ -1154,60 +1277,62 @@ def sample_points_around_best(
         An optional `n_discrete_points x d`-dim tensor containing the sampled
         points. This is None if no baseline points are found.
     """
-    X = get_X_baseline(acq_function=acq_function)
-    if X is None:
-        return
-    with torch.no_grad():
-        try:
-            posterior = acq_function.model.posterior(X)
-        except AttributeError:
-            warnings.warn(
-                "Failed to sample around previous best points.",
-                BotorchWarning,
-                stacklevel=3,
-            )
+    if best_X is None:
+        X = get_X_baseline(acq_function=acq_function)
+        if X is None:
             return
-        mean = posterior.mean
-        while mean.ndim > 2:
-            # take average over batch dims
-            mean = mean.mean(dim=0)
-        try:
-            f_pred = acq_function.objective(mean)
-        # Some acquisition functions do not have an objective
-        # and for some acquisition functions the objective is None
-        except (AttributeError, TypeError):
-            f_pred = mean
-        if hasattr(acq_function, "maximize"):
-            # make sure that the optimiztaion direction is set properly
-            if not acq_function.maximize:
-                f_pred = -f_pred
-        try:
-            # handle constraints for EHVI-based acquisition functions
-            constraints = acq_function.constraints
-            if constraints is not None:
-                neg_violation = -torch.stack(
-                    [c(mean).clamp_min(0.0) for c in constraints], dim=-1
-                ).sum(dim=-1)
-                feas = neg_violation == 0
-                if feas.any():
-                    f_pred[~feas] = float("-inf")
-                else:
-                    # set objective equal to negative violation
-                    f_pred = neg_violation
-        except AttributeError:
-            pass
-        if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
-            # multi-objective
-            # find pareto set
-            is_pareto = is_non_dominated(f_pred)
-            best_X = X[is_pareto]
-        else:
-            if f_pred.shape[-1] == 1:
-                f_pred = f_pred.squeeze(-1)
-            n_best = max(1, round(X.shape[0] * best_pct / 100))
-            # the view() is to ensure that best_idcs is not a scalar tensor
-            best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
-            best_X = X[best_idcs]
+        with torch.no_grad():
+            try:
+                posterior = acq_function.model.posterior(X)
+            except AttributeError:
+                warnings.warn(
+                    "Failed to sample around previous best points.",
+                    BotorchWarning,
+                    stacklevel=3,
+                )
+                return
+            mean = posterior.mean
+            while mean.ndim > 2:
+                # take average over batch dims
+                mean = mean.mean(dim=0)
+            try:
+                f_pred = acq_function.objective(mean)
+            # Some acquisition functions do not have an objective
+            # and for some acquisition functions the objective is None
+            except (AttributeError, TypeError):
+                f_pred = mean
+            if hasattr(acq_function, "maximize"):
+                # make sure that the optimization direction is set properly
+                if not acq_function.maximize:
+                    f_pred = -f_pred
+            try:
+                # handle constraints for EHVI-based acquisition functions
+                constraints = acq_function.constraints
+                if constraints is not None:
+                    neg_violation = -torch.stack(
+                        [c(mean).clamp_min(0.0) for c in constraints], dim=-1
+                    ).sum(dim=-1)
+                    feas = neg_violation == 0
+                    if feas.any():
+                        f_pred[~feas] = float("-inf")
+                    else:
+                        # set objective equal to negative violation
+                        f_pred = neg_violation
+            except AttributeError:
+                pass
+            if f_pred.ndim == mean.ndim and f_pred.shape[-1] > 1:
+                # multi-objective
+                # find pareto set
+                is_pareto = is_non_dominated(f_pred)
+                best_X = X[is_pareto]
+            else:
+                if f_pred.shape[-1] == 1:
+                    f_pred = f_pred.squeeze(-1)
+                n_best = max(1, round(X.shape[0] * best_pct / 100))
+                # the view() is to ensure that best_idcs is not a scalar tensor
+                best_idcs = torch.topk(f_pred, n_best).indices.view(-1)
+                best_X = X[best_idcs]
+
     use_perturbed_sampling = best_X.shape[-1] >= 20 or prob_perturb is not None
     n_trunc_normal_points = (
         n_discrete_points // 2 if use_perturbed_sampling else n_discrete_points
diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py
index 6f3a5876a9..8d56ddd0d4 100644
--- a/botorch/optim/optimize.py
+++ b/botorch/optim/optimize.py
@@ -20,6 +20,7 @@
     AcquisitionFunction,
     OneShotAcquisitionFunction,
 )
+from botorch.acquisition.joint_entropy_search import qJointEntropySearch
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
 from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
     qHypervolumeKnowledgeGradient,
@@ -33,6 +34,7 @@
     gen_batch_initial_conditions,
     gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
+    gen_optimal_input_initial_conditions,
     TGenInitialConditions,
 )
 from botorch.optim.stopping import ExpMAStoppingCriterion
@@ -174,6 +176,8 @@ def get_ic_generator(self) -> TGenInitialConditions:
             return gen_one_shot_kg_initial_conditions
         elif isinstance(self.acq_function, qHypervolumeKnowledgeGradient):
             return gen_one_shot_hvkg_initial_conditions
+        elif isinstance(self.acq_function, qJointEntropySearch):
+            return gen_optimal_input_initial_conditions
         return gen_batch_initial_conditions
diff --git a/botorch/utils/sampling.py b/botorch/utils/sampling.py
index 7066578b9d..c571009e72 100644
--- a/botorch/utils/sampling.py
+++ b/botorch/utils/sampling.py
@@ -999,10 +999,12 @@ def sparse_to_dense_constraints(
 def optimize_posterior_samples(
     paths: GenericDeterministicModel,
     bounds: Tensor,
-    raw_samples: int = 1024,
-    num_restarts: int = 20,
+    raw_samples: int = 2048,
+    num_restarts: int = 4,
     sample_transform: Callable[[Tensor], Tensor] | None = None,
     return_transformed: bool = False,
+    suggested_points: Tensor | None = None,
+    options: dict | None = None,
 ) -> tuple[Tensor, Tensor]:
     r"""Cheaply maximizes posterior samples by random querying followed by
     gradient-based optimization using SciPy's L-BFGS-B routine.
@@ -1011,12 +1013,19 @@
         paths: Random Fourier Feature-based sample paths from the GP
         bounds: The bounds on the search space.
         raw_samples: The number of samples with which to query the samples initially.
+            Raw samples are cheap to evaluate, so this should ideally be set much higher
+            than num_restarts.
         num_restarts: The number of points selected for gradient-based optimization.
+            Should be set low relative to the number of raw samples for time-efficiency.
         sample_transform: A callable transform of the sample outputs (e.g.
             MCAcquisitionObjective or ScalarizedPosteriorTransform.evaluate) used to
             negate the objective or otherwise transform the output.
         return_transformed: A boolean indicating whether to return the transformed
             or non-transformed samples.
+        suggested_points: Tensor of suggested input locations that are high-valued.
+            These are more densely evaluated during the sampling phase of optimization.
+        options: Options controlling the generation of initial candidates, such as
+            `frac_random`, `sample_around_best_sigma`, and `eta`.
 
     Returns:
         A two-element tuple containing:
@@ -1024,6 +1033,7 @@
         - f_opt: A `num_optima x [batch_size] x m`-dim, optionally
             `num_optima x [batch_size] x 1`-dim, tensor of optimal outputs f*.
""" + options = {} if options is None else options def path_func(x) -> Tensor: res = paths(x) @@ -1032,21 +1042,35 @@ def path_func(x) -> Tensor: return res.squeeze(-1) - candidate_set = unnormalize( - SobolEngine(dimension=bounds.shape[1], scramble=True).draw(n=raw_samples), - bounds=bounds, - ) # queries all samples on all candidates - output shape # raw_samples * num_optima * num_models + frac_random = 1 if suggested_points is None else options.get("frac_random", 0.9) + candidate_set = draw_sobol_samples( + bounds=bounds, n=round(raw_samples * frac_random), q=1 + ).squeeze(-2) + if frac_random < 1: + perturbed_suggestions = sample_truncated_normal_perturbations( + X=suggested_points, + n_discrete_points=round(raw_samples * (1 - frac_random)), + sigma=options.get("sample_around_best_sigma", 1e-2), + bounds=bounds, + ) + candidate_set = torch.cat((candidate_set, perturbed_suggestions)) + candidate_queries = path_func(candidate_set) - argtop_k = torch.topk(candidate_queries, num_restarts, dim=-1).indices - X_top_k = candidate_set[argtop_k, :] + idx = boltzmann_sample( + function_values=candidate_queries.unsqueeze(-1), + num_samples=num_restarts, + eta=options.get("eta", 2.0), + replacement=False, + ) + ics = candidate_set[idx, :] # to avoid circular import, the import occurs here from botorch.generation.gen import gen_candidates_scipy X_top_k, f_top_k = gen_candidates_scipy( - X_top_k, + ics, path_func, lower_bounds=bounds[0], upper_bounds=bounds[1], @@ -1101,8 +1125,9 @@ def boltzmann_sample( eta *= temp_decrease weights = torch.exp(eta * norm_weights) + # squeeze in case of m = 1 (mono-output provided as batch_size x N x 1) return batched_multinomial( - weights=weights, num_samples=num_samples, replacement=replacement + weights=weights.squeeze(-1), num_samples=num_samples, replacement=replacement ) diff --git a/test/acquisition/test_utils.py b/test/acquisition/test_utils.py index b8115ba0af..065ceb6e12 100644 --- a/test/acquisition/test_utils.py +++ b/test/acquisition/test_utils.py @@ -33,6 +33,7 @@ UnsupportedError, ) from botorch.models import SingleTaskGP +from botorch.utils.test_helpers import get_fully_bayesian_model, get_model from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior from gpytorch.distributions import MultivariateNormal @@ -413,17 +414,14 @@ def test_project_to_sample_points(self): class TestGetOptimalSamples(BotorchTestCase): - def test_get_optimal_samples(self): - dims = 3 - dtype = torch.float64 + def _test_get_optimal_samples_base(self, model): + dims = model.train_inputs[0].shape[1] + dtype = model.train_targets.dtype + batch_shape = model.batch_shape for_testing_speed_kwargs = {"raw_samples": 20, "num_restarts": 2} num_optima = 7 - batch_shape = (3,) bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T - X = torch.rand(*batch_shape, 4, dims, dtype=dtype) - Y = torch.sin(2 * 3.1415 * X).sum(dim=-1, keepdim=True).to(dtype) - model = SingleTaskGP(train_X=X, train_Y=Y) posterior_transform = ScalarizedPosteriorTransform( weights=torch.ones(1, dtype=dtype) ) @@ -438,6 +436,7 @@ def test_get_optimal_samples(self): num_optima=num_optima, **for_testing_speed_kwargs, ) + correct_X_shape = (num_optima,) + batch_shape + (dims,) correct_f_shape = (num_optima,) + batch_shape + (1,) self.assertEqual(X_opt_def.shape, correct_X_shape) @@ -519,6 +518,22 @@ def test_get_optimal_samples(self): **for_testing_speed_kwargs, ) + def test_optimal_samples(self): + dims = 3 + dtype = torch.float64 + X = torch.rand(4, dims, dtype=dtype) + Y = torch.sin(2 * 3.1415 * 
+        model = get_model(train_X=X, train_Y=Y)
+        self._test_get_optimal_samples_base(model)
+        fully_bayesian_model = get_fully_bayesian_model(
+            train_X=X,
+            train_Y=Y,
+            num_models=4,
+            standardize_model=True,
+            infer_noise=True,
+        )
+        self._test_get_optimal_samples_base(fully_bayesian_model)
+
 
 class TestPreferenceUtils(BotorchTestCase):
     def test_repeat_to_match_aug_dim(self):
diff --git a/test/optim/test_initializers.py b/test/optim/test_initializers.py
index 187a09d7f3..5d2117a069 100644
--- a/test/optim/test_initializers.py
+++ b/test/optim/test_initializers.py
@@ -13,6 +13,7 @@
 import torch
 from botorch.acquisition.analytic import PosteriorMean
 from botorch.acquisition.fixed_feature import FixedFeatureAcquisitionFunction
+from botorch.acquisition.joint_entropy_search import qJointEntropySearch
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
 from botorch.acquisition.monte_carlo import (
     qExpectedImprovement,
@@ -34,6 +35,7 @@
     gen_batch_initial_conditions,
     gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
+    gen_optimal_input_initial_conditions,
     gen_value_function_initial_conditions,
     initialize_q_batch,
     initialize_q_batch_nonneg,
@@ -47,6 +49,7 @@
 )
 from botorch.sampling.normal import IIDNormalSampler
 from botorch.utils.sampling import manual_seed, unnormalize
+from botorch.utils.test_helpers import get_model
 from botorch.utils.testing import (
     _get_max_violation_of_bounds,
     _get_max_violation_of_constraints,
@@ -1074,6 +1077,110 @@ def test_gen_one_shot_kg_initial_conditions(self):
             )
             self.assertTrue(torch.all(ics[..., -n_value:, :] == 1))
 
+    def test_gen_optimal_input_initial_conditions(self):
+        num_restarts = 10
+        raw_samples = 16
+        q = 3
+        for dtype in (torch.float, torch.double):
+            model = get_model(
+                torch.rand(4, 2, dtype=dtype), torch.rand(4, 1, dtype=dtype)
+            )
+            optimal_inputs = torch.rand(5, 2, dtype=dtype)
+            optimal_outputs = torch.rand(5, 1, dtype=dtype)
+            jes = qJointEntropySearch(
+                model=model,
+                optimal_inputs=optimal_inputs,
+                optimal_outputs=optimal_outputs,
+            )
+            bounds = torch.tensor([[0, 0], [1, 1]], device=self.device, dtype=dtype)
+            # base case
+            ics = gen_optimal_input_initial_conditions(
+                acq_function=jes,
+                bounds=bounds,
+                q=q,
+                num_restarts=num_restarts,
+                raw_samples=raw_samples,
+            )
+            self.assertEqual(ics.shape, torch.Size([num_restarts, q, 2]))
+
+            # since we do sample_around_best, this should generate enough points
+            # despite num_restarts being larger than raw_samples
+            ics = gen_optimal_input_initial_conditions(
+                acq_function=jes,
+                bounds=bounds,
+                q=q,
+                num_restarts=15,
+                raw_samples=8,
+                options={"frac_random": 0.2, "sample_around_best": True},
+            )
+            self.assertEqual(ics.shape, torch.Size([15, q, 2]))
+
+            # test option error
+            with self.assertRaises(ValueError):
+                gen_optimal_input_initial_conditions(
+                    acq_function=jes,
+                    bounds=bounds,
+                    q=1,
+                    num_restarts=num_restarts,
+                    raw_samples=raw_samples,
+                    options={"frac_random": 2.0},
+                )
+
+            ei = qExpectedImprovement(model, 99.9)
+            with self.assertRaisesRegex(
+                AttributeError,
+                "gen_optimal_input_initial_conditions can only be used with "
+                "an AcquisitionFunction that has an optimal_inputs attribute.",
+            ):
+                gen_optimal_input_initial_conditions(
+                    acq_function=ei,
+                    bounds=bounds,
+                    q=1,
+                    num_restarts=num_restarts,
+                    raw_samples=raw_samples,
+                    options={"frac_random": 2.0},
+                )
+            # test generation logic
+            random_ics = torch.rand(raw_samples // 2, q, 2)
+            suggested_ics = torch.rand(raw_samples // 2 * q, 2)
+            with ExitStack() as es:
+                mock_random_ics = es.enter_context(
+                    mock.patch(
+                        "botorch.optim.initializers.sample_q_batches_from_polytope",
+                        return_value=random_ics,
+                    )
+                )
+                mock_suggested_ics = es.enter_context(
+                    mock.patch(
+                        "botorch.optim.initializers.sample_points_around_best",
+                        return_value=suggested_ics,
+                    )
+                )
+                mock_choose = es.enter_context(
+                    mock.patch(
+                        "torch.multinomial",
+                        return_value=torch.arange(0, 10),
+                    )
+                )
+
+                ics = gen_optimal_input_initial_conditions(
+                    acq_function=jes,
+                    bounds=bounds,
+                    q=q,
+                    num_restarts=num_restarts,
+                    raw_samples=raw_samples,
+                    options={"frac_random": 0.5},
+                )
+
+                mock_suggested_ics.assert_called_once()
+                mock_random_ics.assert_called_once()
+                mock_choose.assert_called_once()
+
+                expected_result = torch.cat(
+                    (random_ics, suggested_ics.view(raw_samples // 2, q, 2)[0:2])
+                )
+                self.assertTrue(torch.equal(ics, expected_result))
+
 
 class TestGenOneShotHVKGInitialConditions(BotorchTestCase):
     def test_gen_one_shot_hvkg_initial_conditions(self):
diff --git a/test/optim/test_optimize.py b/test/optim/test_optimize.py
index 8d8be47ea0..95f476b79a 100644
--- a/test/optim/test_optimize.py
+++ b/test/optim/test_optimize.py
@@ -18,6 +18,7 @@
     AcquisitionFunction,
     OneShotAcquisitionFunction,
 )
+from botorch.acquisition.joint_entropy_search import qJointEntropySearch
 from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
 from botorch.acquisition.monte_carlo import qExpectedImprovement
 from botorch.acquisition.multi_objective.hypervolume_knowledge_gradient import (
     qHypervolumeKnowledgeGradient,
@@ -32,6 +33,7 @@
 from botorch.optim.initializers import (
     gen_one_shot_hvkg_initial_conditions,
     gen_one_shot_kg_initial_conditions,
+    gen_optimal_input_initial_conditions,
 )
 from botorch.optim.optimize import (
     _combine_initial_conditions,
@@ -2068,6 +2070,15 @@ def test_get_ic_generator(self):
         ic_generator = opt_inputs.get_ic_generator()
         self.assertIs(ic_generator, gen_one_shot_kg_initial_conditions)
 
+        acqf = qJointEntropySearch(
+            model=m1, optimal_inputs=torch.rand(5, 3), optimal_outputs=torch.rand(5, 1)
+        )
+        opt_inputs = OptimizeAcqfInputs(
+            acq_function=acqf, bounds=bounds, q=1, num_restarts=1, **kwargs
+        )
+        ic_generator = opt_inputs.get_ic_generator()
+        self.assertIs(ic_generator, gen_optimal_input_initial_conditions)
+
 
 def my_gen():
     pass
diff --git a/test/utils/test_sampling.py b/test/utils/test_sampling.py
index 62dd0b5bbd..a56122a394 100644
--- a/test/utils/test_sampling.py
+++ b/test/utils/test_sampling.py
@@ -578,9 +578,13 @@ def test_optimize_posterior_samples(self):
         dims = 2
         dtype = torch.float64
         eps = 1e-4
-        for_testing_speed_kwargs = {"raw_samples": 128, "num_restarts": 4}
-        nums_optima = (1, 7)
-        batch_shapes = ((), (2,), (3, 2))
+        for_testing_speed_kwargs = {
+            "raw_samples": 64,
+            "num_restarts": 2,
+            "options": {"eta": 10},
+        }
+        nums_optima = (1, 5)
+        batch_shapes = ((), (3,))
         posterior_transforms = (
             None,
             ScalarizedPosteriorTransform(weights=-torch.ones(1, dtype=dtype)),
@@ -589,16 +593,19 @@
             nums_optima, batch_shapes, posterior_transforms
         ):
             bounds = torch.tensor([[0, 1]] * dims, dtype=dtype).T
-            X = torch.rand(*batch_shape, 4, dims, dtype=dtype)
+            X = torch.rand(*batch_shape, 3, dims, dtype=dtype)
             Y = torch.pow(X - 0.5, 2).sum(dim=-1, keepdim=True)
 
             # having a noiseless model all but guarantees that the found optima
             # will be better than the observations
-            model = SingleTaskGP(X, Y, torch.full_like(Y, eps))
+            model = SingleTaskGP(
+                train_X=X, train_Y=Y, train_Yvar=torch.full_like(Y, eps)
+            )
             model.covar_module.lengthscale = 0.5
             paths = get_matheron_path_model(
                 model=model, sample_shape=torch.Size([num_optima])
             )
+
             X_opt, f_opt = optimize_posterior_samples(
                 paths=paths,
                 bounds=bounds,
@@ -616,8 +623,6 @@ def test_optimize_posterior_samples(self):
             self.assertTrue(torch.all(X_opt >= bounds[0]))
             self.assertTrue(torch.all(X_opt <= bounds[1]))
 
-            # Check that the all found optima are larger than the observations
-            # This is not 100% deterministic, but just about.
             Y_queries = paths(X)
             # this is when we negate, so the values should be smaller
             if posterior_transform:
@@ -642,7 +647,7 @@ def test_optimize_posterior_samples_multi_objective(self):
         dims = 2
         dtype = torch.float64
         eps = 1e-4
-        for_testing_speed_kwargs = {"raw_samples": 128, "num_restarts": 4}
+        for_testing_speed_kwargs = {"raw_samples": 64, "num_restarts": 2}
         num_optima = 5
         batch_shape = (3,)
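Usage note (not part of the patch): the sketch below illustrates how the pieces added here fit together. `get_optimal_samples` forwards `options` to `optimize_posterior_samples`, which seeds part of its raw candidates around the pruned `suggested_points`, and `optimize_acqf` routes `qJointEntropySearch` to `gen_optimal_input_initial_conditions`, whose behavior is controlled by `options` keys such as `frac_random` and `sample_around_best_sigma`. The model, data, and option values are illustrative assumptions only; model fitting is omitted for brevity.

# Illustrative sketch only -- not part of the patch. Assumes BoTorch with this
# change applied; data, model, and option values are arbitrary.
import torch

from botorch.acquisition.joint_entropy_search import qJointEntropySearch
from botorch.acquisition.utils import get_optimal_samples
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf

train_X = torch.rand(8, 2, dtype=torch.float64)
train_Y = torch.sin(6.28 * train_X).sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y)  # hyperparameter fitting omitted
bounds = torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.float64)

# Sample posterior optimizers; `options` is forwarded to optimize_posterior_samples
# (frac_random=0.9 and eta=2.0 match the in-code defaults introduced by this patch).
optimal_inputs, optimal_outputs = get_optimal_samples(
    model=model,
    bounds=bounds,
    num_optima=8,
    options={"frac_random": 0.9, "eta": 2.0},
)

jes = qJointEntropySearch(
    model=model,
    optimal_inputs=optimal_inputs,
    optimal_outputs=optimal_outputs,
)

# optimize_acqf dispatches qJointEntropySearch to gen_optimal_input_initial_conditions;
# frac_random controls the share of raw samples drawn uniformly rather than perturbed
# around the sampled optimizers.
candidate, acq_value = optimize_acqf(
    acq_function=jes,
    bounds=bounds,
    q=1,
    num_restarts=4,
    raw_samples=128,
    options={"frac_random": 0.25, "sample_around_best_sigma": 1e-2},
)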