bayex.optimizer module

Optimizer

Bayesian optimizer using Gaussian Processes and acquisition functions.

class bayex.optimizer.Optimizer(domain, acq='EI', maximize=False)

Bases: object

Bayesian optimizer using Gaussian Processes and acquisition functions.

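This class manages the optimization loop for expensive black-box functions: it models the objective with a Gaussian Process and chooses which parameters to evaluate next with an acquisition function such as EI (expected improvement), PI (probability of improvement), UCB (upper confidence bound), or LCB (lower confidence bound). The objective is minimized by default; pass maximize=True to maximize it.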
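For a Gaussian predictive distribution with mean mu and standard deviation sigma at a candidate point, the acquisition functions listed above have standard closed forms. The plain-Python sketch below writes them out for the minimization case (the default maximize=False). The function names and the xi and kappa parameters with their defaults are illustrative assumptions, not bayex's internals.

```python
import math


def _pdf(z: float) -> float:
    """Standard normal density."""
    return math.exp(-0.5 * z * z) / math.sqrt(2.0 * math.pi)


def _cdf(z: float) -> float:
    """Standard normal cumulative distribution, via math.erf."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def expected_improvement(mu: float, sigma: float, best: float, xi: float = 0.01) -> float:
    """EI for minimization: E[max(best - xi - f(x), 0)] with f(x) ~ N(mu, sigma^2)."""
    improvement = best - mu - xi
    if sigma <= 0.0:
        return max(improvement, 0.0)  # no uncertainty left: improvement is deterministic
    z = improvement / sigma
    return improvement * _cdf(z) + sigma * _pdf(z)


def probability_of_improvement(mu: float, sigma: float, best: float, xi: float = 0.01) -> float:
    """PI for minimization: P(f(x) < best - xi)."""
    if sigma <= 0.0:
        return 1.0 if mu < best - xi else 0.0
    return _cdf((best - mu - xi) / sigma)


def upper_confidence_bound(mu: float, sigma: float, kappa: float = 2.0) -> float:
    """UCB, used when maximizing: the candidate with the largest value is preferred."""
    return mu + kappa * sigma


def lower_confidence_bound(mu: float, sigma: float, kappa: float = 2.0) -> float:
    """LCB, used when minimizing: the candidate with the smallest value is preferred."""
    return mu - kappa * sigma
```

EI and PI are "higher is better" scores; because EI rewards both a low predicted mean and a large sigma, it balances exploitation and exploration without an explicit kappa, whereas UCB and LCB expose that trade-off directly through kappa.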

fit(opt_state, y, new_params)

Updates the optimizer state with a new observation.

Parameters:
  • opt_state – Current optimizer state.

  • y – New objective value.

  • new_params – Parameters that produced y.

Returns:

Updated OptimizerState.
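As documented for OptimizerState, the state holds padded arrays together with a validity mask. One plausible way to record a new observation without changing any array shape is to write it into the first free (padding) slot and flip the matching mask entry. The NumPy sketch below does that with a hypothetical record_observation helper; it is not bayex's implementation, and it assumes minimization and a buffer with spare capacity.

```python
import numpy as np


def record_observation(params, ys, mask, best_score, best_params, y, new_params):
    """Hypothetical fit-style update on fixed-size padded buffers (minimization)."""
    free = np.flatnonzero(~mask)  # indices of padding slots
    if free.size == 0:
        raise ValueError("buffer is full; a real implementation would have to grow it")
    slot = free[0]
    # Work on copies so the caller's arrays stay untouched (JAX arrays are immutable anyway).
    params = {name: values.copy() for name, values in params.items()}
    ys = ys.copy()
    mask = mask.copy()
    for name, value in new_params.items():
        params[name][slot] = value
    ys[slot] = y
    mask[slot] = True  # the slot now holds a real observation
    # Keep the incumbent: the lowest objective value seen so far.
    if y < best_score:
        best_score, best_params = float(y), dict(new_params)
    return params, ys, mask, best_score, best_params
```

Returning new arrays instead of mutating in place mirrors how a functional, NamedTuple-based state is usually threaded through a loop.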

init(ys, params, noise_scale=-8.0)

Initializes the optimizer state from initial data.

Parameters:
  • ys (Array) – Objective values for the initial parameters.

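  • params (dict) – Dictionary mapping each parameter name in domain to an array of initial values (same keys as domain).

  • noise_scale (float) – Initial noise setting for the Gaussian Process (default -8.0).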

Returns:

Initialized OptimizerState.
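Because OptimizerState stores padded arrays plus a mask, initialization has to place the first observations into fixed-size buffers. A minimal NumPy sketch of that step follows; the helper name and its capacity argument are hypothetical, zero padding is an assumption, and the incumbent is chosen as the minimum to match the default maximize=False.

```python
import numpy as np


def pad_initial_data(ys, params, capacity):
    """Hypothetical init-style helper: pad observations to `capacity` and build a mask."""
    ys = np.asarray(ys, dtype=float)
    n = ys.shape[0]
    if n > capacity:
        raise ValueError("more initial points than buffer capacity")
    pad = capacity - n
    padded_ys = np.concatenate([ys, np.zeros(pad)])
    padded_params = {
        name: np.concatenate([np.asarray(values, dtype=float), np.zeros(pad)])
        for name, values in params.items()
    }
    mask = np.arange(capacity) < n  # True for real observations, False for padding
    best = int(np.argmin(ys))  # minimization; a maximizing variant would use argmax
    best_score = float(ys[best])
    best_params = {name: float(values[best]) for name, values in padded_params.items()}
    return padded_params, padded_ys, mask, best_score, best_params
```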

sample(key, state, size=10000)

Samples new parameters using the acquisition function.

Parameters:
  • key – JAX pseudo-random number generator (PRNG) key used for sampling.

  • state – Current optimizer state.

  • size – Number of samples to draw.

Returns:

Sampled parameters (dict).
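One plausible reading of sample is random search over the acquisition function: draw size candidate points inside the domain, score each one, and keep the best. The sketch below illustrates that idea with NumPy's Generator in place of a JAX PRNG key. The box-bounds representation of the domain, the score_fn callback, and the propose name are all assumptions, not bayex's API.

```python
import numpy as np


def propose(rng, bounds, score_fn, size=10_000):
    """Hypothetical random-search proposal: return the candidate with the highest score.

    bounds: dict mapping parameter name -> (low, high) for a continuous box domain.
    score_fn: maps a dict of candidate arrays to an array of acquisition values.
    """
    candidates = {
        name: rng.uniform(low, high, size=size) for name, (low, high) in bounds.items()
    }
    scores = score_fn(candidates)  # one vectorized call over all candidates
    best = int(np.argmax(scores))  # acquisition values: higher is better
    return {name: float(values[best]) for name, values in candidates.items()}


# Toy acquisition that prefers points near x = 0.3 (stands in for EI, PI, ...).
rng = np.random.default_rng(0)
choice = propose(rng, {"x": (0.0, 1.0)}, lambda c: -(c["x"] - 0.3) ** 2)
```

Random search keeps the step trivially vectorizable, at the cost of resolution in high-dimensional domains; a larger size buys more coverage of the search space.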

class bayex.optimizer.OptimizerState(params: dict, ys: Array | ndarray, best_score: float, best_params: dict, mask: Array, gp_params: GPParams)

Bases: NamedTuple

Container for the state of the Bayesian optimizer.

params

Dictionary mapping parameter names to their corresponding padded JAX arrays of observed values.

Type:

dict

ys

Array of objective values for the observed parameters; like params, it includes padding entries (see mask).

Type:

jax.Array or np.ndarray

best_score

Best observed objective value so far.

Type:

float

best_params

Parameter configuration corresponding to the best_score.

Type:

dict

mask

Boolean array indicating which entries in params and ys are valid (i.e., not padding).

Type:

jax.Array

gp_params

Parameters of the Gaussian Process fitted to the observations.

Type:

GPParams
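The mask field lets the state keep fixed-shape arrays (JAX recompiles jitted functions when input shapes change, which is presumably why padding is used) while excluding padding from the model. One exact way to make padded rows inert in a GP posterior is to replace their kernel rows and columns with an identity block and their targets with zero. The NumPy sketch below does this with an RBF kernel; the kernel, length scale, noise value and function names are assumptions for illustration and are not taken from bayex's GPParams.

```python
import numpy as np


def rbf(a, b, length_scale=0.5):
    """Squared-exponential kernel between 1-D input arrays a and b."""
    d = a[:, None] - b[None, :]
    return np.exp(-0.5 * (d / length_scale) ** 2)


def masked_gp_posterior(xs, ys, mask, x_test, noise=1e-6):
    """GP posterior mean/std that ignores padded entries (mask == False).

    Padded rows/columns of the kernel become an identity block and padded targets
    become zero, so they decouple from the real observations.
    """
    m = mask.astype(float)
    K = rbf(xs, xs) * np.outer(m, m) + np.diag(noise * m + (1.0 - m))
    k_star = rbf(xs, x_test) * m[:, None]  # padding has no covariance with test points
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, ys * m))  # K^{-1} y
    mean = k_star.T @ alpha
    v = np.linalg.solve(L, k_star)
    var = 1.0 - np.sum(v * v, axis=0)  # prior variance of this RBF kernel is 1
    return mean, np.sqrt(np.maximum(var, 0.0))
```

Because the padded block is decoupled from the valid block, the mean and standard deviation match a GP fitted on the unpadded data alone, up to floating-point rounding, while every array keeps its full padded length.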
