bayex.domain module
Real | Continuous real-valued domain with clipping.
Integer | Discrete integer-valued domain with rounding and clipping.
- class bayex.domain.Domain(dtype)
Bases: object
- sample(key, shape)
- transform(x)
- class bayex.domain.Integer(lower, upper)
Bases: Domain
Discrete integer-valued domain with rounding and clipping.
Represents a parameter that can take integer values within [lower, upper].
- sample(key, shape)
Samples integers uniformly from the domain.
- Parameters:
key (Array) – JAX PRNGKey.
shape (Tuple) – Desired output shape.
- Returns:
Sampled values clipped to valid integer range.
- transform(x)
Rounds and clips values to the integer domain.
- Parameters:
x (Array) – Input values.
- Returns:
Rounded and clipped values as float32.
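The sampling and rounding behaviour described for Integer can be sketched as follows. This is a minimal illustration under stated assumptions, not bayex's own implementation: the helper names `sample_integer` and `transform_integer` are hypothetical, and drawing with `jax.random.randint` is only one plausible way to obtain uniform integers in [lower, upper].

```python
import jax
import jax.numpy as jnp


def sample_integer(key, shape, lower, upper):
    # Hypothetical helper: draw integers uniformly from [lower, upper].
    # jax.random.randint excludes maxval, hence the +1.
    return jax.random.randint(key, shape, minval=lower, maxval=upper + 1)


def transform_integer(x, lower, upper):
    # Round to the nearest integer, clip into [lower, upper], return float32.
    return jnp.clip(jnp.round(x), lower, upper).astype(jnp.float32)


key = jax.random.PRNGKey(0)
samples = sample_integer(key, (128,), 1, 5)
projected = transform_integer(jnp.array([-2.0, 1.4, 3.6, 9.0]), 1, 5)
```

Here out-of-range inputs such as -2.0 and 9.0 are pulled back to the nearest bound, while in-range values are rounded to the nearest integer.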
- class bayex.domain.ParamSpace(space)
Bases: object
Internal class that manages a collection of named parameter domains.
This utility encapsulates logic for sampling, transforming, and handling structured parameter inputs defined by a mapping of variable names to Domain instances (e.g., Real, Integer).
Example
>>> import jax
>>> from bayex.domain import Integer, ParamSpace, Real
>>> space = ParamSpace({
...     "x1": Real(0.0, 1.0),
...     "x2": Integer(1, 5)
... })
>>> key = jax.random.PRNGKey(0)
>>> samples = space.sample_params(key, (128,))
>>> xs = space.to_array(samples)
Notes
This class is intended for internal use by the optimizer and should not be exposed as part of the public API.
- sample_params(key, shape)
- Return type:
dict
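The page gives no description for sample_params beyond its dict return type. One plausible reading is that it draws one array per named parameter. The sketch below illustrates that idea only: the `samplers` mapping and `sample_params_sketch` are hypothetical stand-ins for the domains' sample methods, and splitting the key once per parameter is an assumption rather than documented behaviour.

```python
import jax

# Hypothetical per-parameter samplers standing in for Domain.sample.
samplers = {
    "x1": lambda key, shape: jax.random.uniform(key, shape, minval=0.0, maxval=1.0),
    "x2": lambda key, shape: jax.random.randint(key, shape, minval=1, maxval=6),
}


def sample_params_sketch(key, shape):
    # Assumption: each parameter gets its own independent subkey.
    subkeys = jax.random.split(key, len(samplers))
    return {
        name: fn(subkey, shape)
        for (name, fn), subkey in zip(samplers.items(), subkeys)
    }


samples = sample_params_sketch(jax.random.PRNGKey(0), (128,))
```

The result is a dictionary with one array of the requested shape per parameter name.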
- to_array(tree)
Transforms a batch of parameter values into a 2D array suitable for GP input.
Applies each domain’s .transform() to its corresponding parameter values.
- Parameters:
tree (dict) – A dictionary of parameter name → array of raw values.
- Return type:
Array
- Returns:
A JAX array of shape (batch_size, num_params) with transformed values.
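The transform-then-stack step described for to_array can be sketched as below. The `transforms` mapping and `to_array_sketch` are hypothetical names used for illustration, and the column order (dictionary insertion order) is an assumption, not something this page specifies.

```python
import jax.numpy as jnp

# Hypothetical stand-ins for each domain's transform, keyed by parameter name.
transforms = {
    # Behaves like Real(0.0, 1.0): clip only.
    "x1": lambda x: jnp.clip(x, 0.0, 1.0),
    # Behaves like Integer(1, 5): round, clip, cast to float32.
    "x2": lambda x: jnp.clip(jnp.round(x), 1, 5).astype(jnp.float32),
}


def to_array_sketch(tree):
    # Transform each parameter, then stack the results as columns,
    # giving an array of shape (batch_size, num_params).
    columns = [transforms[name](tree[name]) for name in transforms]
    return jnp.stack(columns, axis=1)


tree = {
    "x1": jnp.array([0.2, 1.7, -0.3]),
    "x2": jnp.array([2.6, 0.0, 4.2]),
}
xs = to_array_sketch(tree)
```

With a batch of three values for two parameters, `xs` has shape (3, 2), one row per candidate point.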
- to_dict(xs)
Converts a stacked parameter matrix back into named parameter trees.
Typically used after optimization in transformed space.
- Parameters:
xs (Array) – A 2D JAX array of shape (batch_size, num_params), with each column corresponding to a parameter.
- Return type:
dict
- Returns:
A dictionary mapping parameter names to individual 1D arrays.
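The column-to-name mapping described for to_dict can be sketched as follows. The `names` list and `to_dict_sketch` are hypothetical, and the sketch assumes columns appear in the same order as the parameter names.

```python
import jax.numpy as jnp

# Hypothetical parameter order; column i is assumed to belong to names[i].
names = ["x1", "x2"]


def to_dict_sketch(xs):
    # Slice each column of the (batch_size, num_params) matrix
    # into a 1D array stored under its parameter name.
    return {name: xs[:, i] for i, name in enumerate(names)}


xs = jnp.array([[0.2, 3.0], [1.0, 1.0], [0.0, 4.0]])
params = to_dict_sketch(xs)
```

Each value in `params` is a 1D array of length batch_size, so a single candidate can be read off by indexing every entry at the same position.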
- class bayex.domain.Real(lower, upper)
Bases: Domain
Continuous real-valued domain with clipping.
Represents a parameter that can take real values within [lower, upper].
- sample(key, shape)
Samples uniformly from the domain.
- Parameters:
key (Array) – JAX PRNGKey.
shape (Tuple) – Desired output shape.
- Returns:
Sampled values clipped to the domain.
- transform(x)
Clips values to the domain range [lower, upper].
- Parameters:
x (Array) – Input values.
- Returns:
Clipped values within bounds.
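The uniform sampling and clipping described for Real can be sketched in a few lines. The helper names `sample_real` and `transform_real` are hypothetical, and using `jax.random.uniform` with minval and maxval is an assumption about how such samples might be drawn, not a statement of bayex's implementation.

```python
import jax
import jax.numpy as jnp


def sample_real(key, shape, lower, upper):
    # Hypothetical helper: draw floats uniformly between lower and upper.
    return jax.random.uniform(key, shape, minval=lower, maxval=upper)


def transform_real(x, lower, upper):
    # Clip values into [lower, upper]; in-range values pass through unchanged.
    return jnp.clip(x, lower, upper)


key = jax.random.PRNGKey(0)
samples = sample_real(key, (128,), 0.0, 1.0)
clipped = transform_real(jnp.array([-0.5, 0.25, 1.5]), 0.0, 1.0)
```

Unlike the integer case, no rounding is applied: only values outside the bounds are changed.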