bayex.domain module


Real – Continuous real-valued domain with clipping.

Integer – Discrete integer-valued domain with rounding and clipping.

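class bayex.domain.Domain(dtype)

Bases: object

Base class for parameter domains. Real and Integer subclass it and implement sample() and transform().

sample(key, shape)

transform(x)
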
class bayex.domain.Integer(lower, upper)

Bases: Domain

Discrete integer-valued domain with rounding and clipping.

Represents a parameter that can take integer values within [lower, upper].

sample(key, shape)

Samples integers uniformly from the domain.

Parameters:
  • key (Array) – JAX PRNGKey.

  • shape (Tuple) – Desired output shape.

Returns:

Sampled values, clipped to the valid integer range [lower, upper].
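A minimal sketch of uniform integer sampling over [lower, upper], assuming both bounds are inclusive. It uses `jax.random.randint`, whose `maxval` is exclusive, hence `upper + 1`. The function name `sample_integer` and the float32 output dtype are assumptions for illustration, not the bayex implementation.

```python
import jax
import jax.numpy as jnp


def sample_integer(key, shape, lower, upper):
    """Hypothetical stand-in for Integer.sample (not the bayex implementation)."""
    # randint's maxval is exclusive, so add 1 to make `upper` reachable.
    draws = jax.random.randint(key, shape, minval=lower, maxval=upper + 1)
    # Clip defensively; casting to float32 is an assumption made here so the
    # samples share the dtype that Integer.transform is documented to return.
    return jnp.clip(draws, lower, upper).astype(jnp.float32)


key = jax.random.PRNGKey(0)
samples = sample_integer(key, (1000,), 1, 5)
```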

transform(x)

Rounds and clips values to the integer domain.

Parameters:

x (Array) – Input values.

Returns:

Rounded and clipped values as float32.
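The rounding-then-clipping behaviour described above can be sketched as follows. `integer_transform` is a hypothetical helper, not part of bayex. Note that `jnp.round` rounds halves to the nearest even integer, so 2.5 maps to 2.0; whether bayex uses the same rounding rule is an assumption.

```python
import jax.numpy as jnp


def integer_transform(x, lower, upper):
    """Hypothetical stand-in for Integer.transform: round, clip, cast to float32."""
    # jnp.round uses round-half-to-even (2.5 -> 2.0, 3.5 -> 4.0).
    rounded = jnp.round(x)
    # With integer bounds, clipping after rounding keeps every value an integer.
    return jnp.clip(rounded, lower, upper).astype(jnp.float32)


out = integer_transform(jnp.array([0.4, 2.6, 7.2, -3.0]), 1, 5)
# out == [1., 3., 5., 1.]
```

Out-of-range inputs such as 7.2 and -3.0 are not rejected. They are snapped to the nearest bound, so any real-valued candidate produced by an optimizer maps to a valid integer.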

class bayex.domain.ParamSpace(space)

Bases: object

Internal class that manages a collection of named parameter domains.

This utility encapsulates logic for sampling, transforming, and handling structured parameter inputs defined by a mapping of variable names to Domain instances (e.g., Real, Integer).

Example

>>> import jax
>>> from bayex.domain import Integer, ParamSpace, Real
>>> space = ParamSpace({
...     "x1": Real(0.0, 1.0),
...     "x2": Integer(1, 5)
... })
>>> key = jax.random.PRNGKey(0)
>>> samples = space.sample_params(key, (128,))
>>> xs = space.to_array(samples)

Notes

This class is intended for internal use by the optimizer and should not be exposed as part of the public API.

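sample_params(key, shape)

Samples values for every named parameter domain and returns them as a dictionary keyed by parameter name.
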
Return type:

dict
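A sketch of how sampling over a mapping of named domains might work: split the PRNG key once per parameter and let each domain draw its own batch. The `UniformReal` class and the `sample_params` function below are hypothetical stand-ins, not bayex code. The real ParamSpace may split keys or order parameters differently.

```python
import jax
import jax.numpy as jnp


class UniformReal:
    """Hypothetical minimal domain exposing a sample(key, shape) method."""

    def __init__(self, lower, upper):
        self.lower, self.upper = lower, upper

    def sample(self, key, shape):
        return jax.random.uniform(key, shape, minval=self.lower, maxval=self.upper)


def sample_params(space, key, shape):
    """Draw an independent batch for every named domain; returns a dict."""
    # One subkey per parameter so the draws are statistically independent.
    keys = jax.random.split(key, len(space))
    return {name: dom.sample(k, shape) for (name, dom), k in zip(space.items(), keys)}


space = {"lr": UniformReal(1e-4, 1e-1), "momentum": UniformReal(0.0, 0.99)}
params = sample_params(space, jax.random.PRNGKey(0), (128,))
```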

to_array(tree)

Transforms a batch of parameter values into a 2D array suitable as input to the Gaussian process (GP).

Applies each domain’s .transform() to its corresponding parameter values.

Parameters:

tree (dict) – A dictionary mapping each parameter name to an array of raw values.

Return type:

Array

Returns:

A JAX array of shape (batch_size, num_params) with transformed values.
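The column-stacking step can be sketched as below. `to_array` here is a hypothetical stand-in that takes plain callables in place of Domain objects, and it assumes columns follow the mapping's key order (an assumption; a fixed order is what lets the inverse conversion recover names).

```python
import jax.numpy as jnp


def to_array(tree, transforms):
    """Hypothetical stand-in for ParamSpace.to_array.

    `transforms` maps each parameter name to a callable playing the role of
    that domain's .transform(); columns follow the mapping's key order.
    """
    names = list(transforms)
    columns = [transforms[name](jnp.asarray(tree[name])) for name in names]
    # Each column has shape (batch_size,); stacking on axis 1 gives
    # an array of shape (batch_size, num_params).
    return jnp.stack(columns, axis=1)


transforms = {
    "x1": lambda x: jnp.clip(x, 0.0, 1.0),  # Real(0.0, 1.0)-style clipping
    "x2": lambda x: jnp.clip(jnp.round(x), 1, 5).astype(jnp.float32),  # Integer(1, 5)-style
}
xs = to_array({"x1": jnp.array([0.2, 1.7]), "x2": jnp.array([2.4, 9.0])}, transforms)
# xs == [[0.2, 2.], [1., 5.]]
```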

to_dict(xs)

Converts a stacked parameter matrix back into a dictionary of named parameters.

Typically used after optimizing in the transformed space, to recover the value of each named parameter.

Parameters:

xs (Array) – A 2D JAX array of shape (batch_size, num_params), with each column corresponding to a parameter.

Return type:

dict

Returns:

A dictionary mapping each parameter name to a 1D array of shape (batch_size,).
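A sketch of the inverse conversion, assuming column i of the matrix belongs to the i-th parameter name in a fixed order. `to_dict` and the explicit `names` argument are hypothetical; the real method presumably takes the order from the ParamSpace itself.

```python
import jax.numpy as jnp


def to_dict(xs, names):
    """Hypothetical stand-in for ParamSpace.to_dict: column i -> names[i]."""
    # xs has shape (batch_size, num_params); each slice xs[:, i] is 1D.
    return {name: xs[:, i] for i, name in enumerate(names)}


xs = jnp.array([[0.2, 2.0], [1.0, 5.0]])
params = to_dict(xs, ["x1", "x2"])
# params["x2"] == [2., 5.]
```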

class bayex.domain.Real(lower, upper)

Bases: Domain

Continuous real-valued domain with clipping.

Represents a parameter that can take real values within [lower, upper].

sample(key, shape)

Samples uniformly from the domain.

Parameters:
  • key (Array) – JAX PRNGKey.

  • shape (Tuple) – Desired output shape.

Returns:

Sampled values clipped to the domain.
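A minimal sketch of uniform sampling on [lower, upper] with `jax.random.uniform`. `sample_real` is a hypothetical helper, not the bayex implementation; the final clip mirrors the documented clipping and guards against floating-point round-off.

```python
import jax
import jax.numpy as jnp


def sample_real(key, shape, lower, upper):
    """Hypothetical stand-in for Real.sample (not the bayex implementation)."""
    draws = jax.random.uniform(key, shape, minval=lower, maxval=upper)
    # Clip so round-off can never place a sample outside [lower, upper].
    return jnp.clip(draws, lower, upper)


samples = sample_real(jax.random.PRNGKey(0), (4, 3), -2.0, 3.0)
```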

transform(x)

Clips values to the domain range [lower, upper].

Parameters:

x (Array) – Input values.

Returns:

Values clipped to [lower, upper].
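For comparison with Integer.transform, a sketch of pure clipping: in-range values pass through unchanged, including their fractional parts, and out-of-range values snap to the nearest bound. `real_transform` is a hypothetical helper, not part of bayex.

```python
import jax.numpy as jnp


def real_transform(x, lower, upper):
    """Hypothetical stand-in for Real.transform: clip into [lower, upper]."""
    return jnp.clip(x, lower, upper)


out = real_transform(jnp.array([-0.5, 0.25, 1.5]), 0.0, 1.0)
# out == [0., 0.25, 1.]
```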