
Surrogate Control

surrogate_control

Surrogate control server for asynchronous requests and retraining.

Classes:

Name Description
ControlSettings

Configuration of the surrogate control

CallInfo

Logger Info object for control calls

UpdateInfo

Logger Info object for surrogate updates

SurrogateControl

Surrogate control server for asynchronous requests and retraining

SurrogateLogger

Logger for the surrogate control during runs

surrogate.surrogate_control.ControlSettings dataclass

Configuration of the surrogate control.

Attributes:

Name Type Description
port str

Port on which the UMBridge control server is served; only used in the run_server script

name str

Name of the UMBridge control server

minimum_num_training_points int

Number of training points that need to be provided before the surrogate is first used

update_interval_rule Callable

Callable that maps the current number of training points to the number of training points after which the surrogate model is next retrained

variance_threshold float

Threshold for the variance of the surrogate prediction (absolute or relative), below which the surrogate mean is used as the predictor. Otherwise, a simulation model run is triggered.

overwrite_checkpoint bool

Whether to overwrite checkpoints. If not, checkpoint names are numbered with increasing IDs. Defaults to True.
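To illustrate how these fields fit together, the following sketch mirrors the documented attributes in a stand-alone dataclass. The class name `ControlSettingsSketch`, the example values, and the doubling rule are hypothetical; the real `ControlSettings` may define other defaults or validation.

```python
from collections.abc import Callable
from dataclasses import dataclass


@dataclass
class ControlSettingsSketch:
    """Hypothetical stand-in mirroring the documented ControlSettings attributes."""

    port: str
    name: str
    minimum_num_training_points: int
    update_interval_rule: Callable[[int], int]
    variance_threshold: float
    overwrite_checkpoint: bool = True  # documented default


def doubling_rule(num_training_points: int) -> int:
    """Hypothetical rule: retrain again once the training set has doubled in size."""
    return 2 * num_training_points


settings = ControlSettingsSketch(
    port="4243",  # illustrative value only
    name="surrogate",
    minimum_num_training_points=5,
    update_interval_rule=doubling_rule,
    variance_threshold=1e-6,
)
```

A geometric rule like this keeps the number of retrains logarithmic in the number of simulation runs, which limits the cost of refitting as the training set grows.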

surrogate.surrogate_control.CallInfo dataclass

Logger Info object for control calls.

Attributes:

Name Type Description
parameters list[list[float]]

Parameters the control was called with

surrogate_result np.ndarray

Mean prediction from surrogate

simulation_result list[list[float]]

Result requested from UMBridge simulation model server

variance np.ndarray

Variance prediction from surrogate

surrogate_used bool

Whether surrogate has been used for prediction

num_training_points int

Overall number of training points generated so far

surrogate.surrogate_control.UpdateInfo dataclass

Logger Info object for surrogate updates, meaning that new training data has been provided.

Attributes:

Name Type Description
new_fit bool

Whether the surrogate has been retrained with the new data

num_updates int

Number of surrogate retrains performed so far

next_update int

Number of training samples after which the next retrain is scheduled

scale float

Scale parameter of the trained surrogate kernel (for GPs)

correlation_length float | np.ndarray

Correlation length parameter(s) of the trained surrogate kernel, one per input dimension (for GPs)
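The fields `num_updates` and `next_update` follow from the `update_interval_rule` in `ControlSettings`. The sketch below simulates that bookkeeping for a growing number of training points. It assumes, without confirmation from this page, that the first fit happens once `minimum_num_training_points` samples are available and that each later retrain is scheduled by calling the rule with the current sample count; `retrain_schedule` is a hypothetical helper.

```python
from collections.abc import Callable


def retrain_schedule(
    minimum_num_training_points: int,
    update_interval_rule: Callable[[int], int],
    total_points: int,
) -> list[tuple[int, int, int]]:
    """Return (num_training_points, num_updates, next_update) for every retrain."""
    events = []
    num_updates = 0
    next_update = minimum_num_training_points  # assumed first fit point
    for num_training_points in range(1, total_points + 1):
        if num_training_points >= next_update:
            num_updates += 1
            # Ask the rule when the next retrain should happen.
            next_update = update_interval_rule(num_training_points)
            events.append((num_training_points, num_updates, next_update))
    return events


schedule = retrain_schedule(5, lambda n: 2 * n, total_points=50)
# → [(5, 1, 10), (10, 2, 20), (20, 3, 40), (40, 4, 80)]
```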

surrogate.surrogate_control.SurrogateControl

Bases: ub.Model

Surrogate control server for asynchronous requests and retraining.

This is the main component of the surrogate workflow. The control server is a UM-Bridge server that a client can call to request evaluation of the surrogate for a given input parameter set. Internally, the server connects to a simulation model, which is also assumed to be a UM-Bridge server. Each call is dispatched either to the surrogate or to the simulation model, depending on the variance of the surrogate prediction; see the documentation of the __call__ method for details. In addition, the server hosts a background thread that collects new training data whenever the simulation model is invoked and uses it to retrain the surrogate model asynchronously. Threading locks synchronize server requests, training data collection, and surrogate retraining.

Methods:

Name Description
__call__

Call method according to UM-Bridge interface

get_input_sizes

UM-Bridge method to specify dimension of the input parameters

get_output_sizes

UM-Bridge method to specify dimension of the output

supports_evaluate

UM-Bridge flags indicating that the server can be called for evaluation

update_surrogate_model_daemon

Daemon thread for updating the surrogate model

__init__

__init__(
    control_settings: ControlSettings,
    logger_settings: utilities.LoggerSettings,
    surrogate_model: surrogate_model.BaseSurrogateModel,
    simulation_model: Callable,
) -> None

Constructor.

Initializes all data structures based on the provided configuration and starts a daemon thread that asynchronously updates the surrogate when new data is obtained from the simulation model.

Parameters:

Name Type Description Default
control_settings ControlSettings

Configuration of the control server

required
logger_settings utilities.LoggerSettings

Configuration of the logger

required
surrogate_model surrogate_model.BaseSurrogateModel

Surrogate model to be used

required
simulation_model Callable

Simulation model to be used

required

get_input_sizes

get_input_sizes(_config: dict[str, Any]) -> list[int]

UMBridge method to specify dimension of the input parameters.

get_output_sizes

get_output_sizes(_config: dict[str, Any]) -> list[int]

UMBridge method to specify dimension of the output.

supports_evaluate

supports_evaluate() -> bool

UMBridge flags indicating that the server can be called for evaluation.

__call__

__call__(parameters: list[list[float]], config: dict[str, Any]) -> list[list[float]]

Call method according to UMBridge interface.

An evaluation request for a given parameter set is processed as follows. First, the surrogate is invoked, returning the mean and variance of its prediction at the given parameter. If the variance is not below the configured variance threshold, the simulation model is invoked and its result is returned along with a variance of zero. Otherwise, the mean and variance of the surrogate prediction are returned. Each simulation model run also produces a new training sample for the surrogate; this sample is queued and used for retraining by the control server's daemon thread.

Parameters:

Name Type Description Default
parameters list[list[float]]

Parameter set for which evaluation is requested

required
config dict[str, Any]

Configuration for the request, passed on to the simulation model, which is also assumed to be a UMBridge server

required

Returns:

Type Description
list[list[float]]

Result of the request (surrogate or simulation result) in UMBridge format
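The dispatch rule of __call__ can be sketched as a plain function. This is a minimal illustration, not the actual implementation: `dispatch_request`, its arguments, and the stub models are hypothetical, the surrogate is assumed to return a scalar mean and variance, and the threshold is treated as absolute (the real `variance_threshold` may also be relative). How mean and variance are packed into the UMBridge return value is not shown.

```python
from collections.abc import Callable


def dispatch_request(
    parameters: list[list[float]],
    config: dict,
    surrogate: Callable[[list[list[float]]], tuple[float, float]],
    simulation_model: Callable[[list[list[float]], dict], list[list[float]]],
    variance_threshold: float,
    training_queue: list,
) -> tuple[list[list[float]], float, bool]:
    """Return (result, variance, surrogate_used) for one evaluation request."""
    mean, variance = surrogate(parameters)
    if variance < variance_threshold:
        # Surrogate is confident enough: answer with its mean and variance.
        return [[mean]], variance, True
    # Otherwise run the expensive model, report zero variance and queue the
    # new (input, output) pair so the daemon thread can retrain the surrogate.
    result = simulation_model(parameters, config)
    training_queue.append((parameters, result))
    return result, 0.0, False


# Usage with stub models: the surrogate is too uncertain, so the simulation runs.
pending: list = []
result, variance, surrogate_used = dispatch_request(
    [[2.0]], {}, lambda p: (1.0, 0.5), lambda p, c: [[p[0][0] ** 2]], 1e-6, pending
)
```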

update_surrogate_model_daemon

update_surrogate_model_daemon() -> None

Daemon thread for updating the surrogate model.

The control server hosts a background (daemon) thread. This thread checks whether the simulation model has generated new training data and transferred it to a dedicated update queue. The daemon thread collects the new data and retrains the surrogate once the number of samples prescribed by the user-specified update rule is available. Access to the queue and to the surrogate object is synchronized with the process's evaluation requests via threading locks.
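The queue, tap, and retrain cycle described above is a producer/consumer pattern. The sketch below shows one way to build it with the standard `threading` module; `UpdateDaemonSketch`, its two-lock layout, the polling interval, and the stub `fit` callable are assumptions for illustration and need not match the internals of `SurrogateControl`.

```python
import threading
import time
from collections.abc import Callable


class UpdateDaemonSketch:
    """Hypothetical sketch of the queue -> tap -> retrain cycle run by a daemon thread."""

    def __init__(
        self,
        fit: Callable[[list], None],
        minimum_num_training_points: int,
        update_interval_rule: Callable[[int], int],
        poll_interval: float = 0.01,
    ) -> None:
        self._fit = fit  # stand-in for the surrogate's training routine
        self._rule = update_interval_rule
        self._poll_interval = poll_interval
        self._queue_lock = threading.Lock()  # guards the update queue
        self._surrogate_lock = threading.Lock()  # guards training data and surrogate
        self._queue: list = []
        self._training_data: list = []
        self.num_updates = 0
        self.next_update = minimum_num_training_points
        threading.Thread(target=self._run, daemon=True).start()

    def queue_training_data(self, parameters, result) -> None:
        # Request thread: enqueue a new sample after a simulation model run.
        with self._queue_lock:
            self._queue.append((parameters, result))

    def _tap_training_data(self) -> int:
        # Daemon thread: swap out the queue while holding its lock only briefly,
        # then append the samples to the training set under the surrogate lock.
        with self._queue_lock:
            new_samples, self._queue = self._queue, []
        with self._surrogate_lock:
            self._training_data.extend(new_samples)
            return len(self._training_data)

    def _run(self) -> None:
        while True:
            num_points = self._tap_training_data()
            if num_points >= self.next_update:
                with self._surrogate_lock:
                    self._fit(list(self._training_data))
                    self.next_update = self._rule(num_points)
                    self.num_updates += 1  # updated last, so readers see a finished retrain
            time.sleep(self._poll_interval)


# Usage: a stub "fit" that records the training-set size it was given.
fit_sizes: list[int] = []
daemon = UpdateDaemonSketch(lambda data: fit_sizes.append(len(data)), 3, lambda n: 2 * n)
for i in range(3):
    daemon.queue_training_data([[float(i)]], [[float(i) ** 2]])
for _ in range(200):  # wait up to about 2 s for the first retrain
    if daemon.num_updates:
        break
    time.sleep(0.01)
```

Holding the queue lock only for the swap keeps request threads from blocking while the (potentially slow) fit runs under the separate surrogate lock.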

_init_surrogate_model_update_thread

_init_surrogate_model_update_thread() -> threading.Thread

Start the daemon thread for surrogate updates.

_call_surrogate

_call_surrogate(parameters: list[list[float]]) -> np.ndarray

Invoke surrogate, synchronizing with daemon thread.

_retrain_surrogate

_retrain_surrogate() -> None

Retrain surrogate, synchronizing with request thread.

_queue_training_data

_queue_training_data(parameters: list[list[float]], result: list[list[float]]) -> None

Insert new training data into update queue, synchronizing with daemon thread.

Parameters:

Name Type Description Default
parameters list[list[float]]

Input of the data sample

required
result list[list[float]]

Output of the data sample

required

_tap_training_data

_tap_training_data() -> None

Transfer training data from queue to surrogate, synchronizing with request thread.

_get_checkpoint_id

_get_checkpoint_id() -> int

Get ID for a checkpoint.
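The `overwrite_checkpoint` setting in `ControlSettings` determines how checkpoint IDs are assigned. A minimal sketch of that behavior follows; the class name `CheckpointIdSketch`, the constant ID used when overwriting, and starting the count at 0 are all assumptions rather than documented behavior of `_get_checkpoint_id`.

```python
import itertools


class CheckpointIdSketch:
    """Hypothetical sketch of checkpoint numbering driven by overwrite_checkpoint."""

    def __init__(self, overwrite_checkpoint: bool = True) -> None:
        self._overwrite = overwrite_checkpoint
        self._counter = itertools.count()  # yields 0, 1, 2, ...

    def get_checkpoint_id(self) -> int:
        # With overwriting enabled, every checkpoint reuses the same ID (and thus
        # the same file name); otherwise each call returns the next integer.
        if self._overwrite:
            return 0
        return next(self._counter)


numbered = CheckpointIdSketch(overwrite_checkpoint=False)
ids = [numbered.get_checkpoint_id() for _ in range(3)]  # → [0, 1, 2]
```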

surrogate.surrogate_control.SurrogateLogger

Bases: utilities.BaseLogger

Logger for the surrogate control during runs.

The logger processes two types of events:

1. Requests from a client, provided as a CallInfo object
2. Updates of the surrogate model, provided as an UpdateInfo object
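The two event types can be handled by dispatching on the type of the info object. The sketch below uses trimmed-down, hypothetical stand-ins for `CallInfo` and `UpdateInfo` together with an invented one-line format and Python's standard `logging` module; the actual output and abbreviations printed by `SurrogateLogger` are not documented here.

```python
import logging
from dataclasses import dataclass


@dataclass
class CallInfoSketch:
    """Hypothetical subset of the CallInfo fields."""

    surrogate_used: bool
    num_training_points: int


@dataclass
class UpdateInfoSketch:
    """Hypothetical subset of the UpdateInfo fields."""

    new_fit: bool
    num_updates: int
    next_update: int


def format_event(event) -> str:
    # Choose the message layout based on which kind of event was passed in.
    if isinstance(event, CallInfoSketch):
        source = "surrogate" if event.surrogate_used else "simulation"
        return f"call | source: {source} | training points: {event.num_training_points}"
    if isinstance(event, UpdateInfoSketch):
        return (
            f"update | new fit: {event.new_fit} | retrains: {event.num_updates} "
            f"| next retrain at: {event.next_update}"
        )
    raise TypeError(f"Unsupported event type: {type(event).__name__}")


logging.basicConfig(level=logging.INFO, format="%(message)s")
logger = logging.getLogger("surrogate_control_sketch")
logger.info(format_event(CallInfoSketch(surrogate_used=True, num_training_points=12)))
```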

__init__

__init__(logger_settings: utilities.LoggerSettings) -> None

Constructor.

Parameters:

Name Type Description Default
logger_settings utilities.LoggerSettings

Configuration of the logger

required

print_header

print_header() -> None

Print info banner explaining the abbreviations used during logging.

log_control_call_info

log_control_call_info(call_info: CallInfo) -> None

Log info from a call to the control server.

Parameters:

Name Type Description Default
call_info CallInfo

CallInfo object to process

required

log_surrogate_update_info

log_surrogate_update_info(update_info: UpdateInfo) -> None

Log info from an update of the surrogate.

Parameters:

Name Type Description Default
update_info UpdateInfo

UpdateInfo object to process

required