Surrogate Control
surrogate_control
Surrogate control server for asynchronous requests and retraining.
Classes:
Name | Description |
---|---|
ControlSettings | Configuration of the surrogate control |
CallInfo | Logger Info object for control calls |
UpdateInfo | Logger Info object for surrogate updates |
SurrogateControl | Surrogate control server for asynchronous requests and retraining |
SurrogateLogger | Logger for the surrogate control during runs |
surrogate.surrogate_control.ControlSettings
dataclass
Configuration of the surrogate control.
Attributes:
Name | Type | Description |
---|---|---|
port | str | Port to serve the UMBridge control server to, only used in |
name | str | Name of the UMBridge control server |
minimum_num_training_points | int | Number of training points that need to be provided before the surrogate is first used |
update_interval_rule | Callable | Callable that, given the current number of training points, returns the number of training points after which the surrogate model is next retrained |
variance_threshold | float | Threshold on the variance of the surrogate prediction (absolute or relative), below which the surrogate mean is used as predictor. Otherwise, a simulation model run is triggered |
overwrite_checkpoint | bool | Whether to overwrite checkpoints. If not, checkpoint names are suffixed with increasing ID numbers. Defaults to True. |
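The interplay of `minimum_num_training_points` and `update_interval_rule` can be illustrated with a small standalone sketch. It assumes, as the attribute description suggests, that the rule receives the current number of training points and returns the number at which the next retrain is due. The geometric rule and the helper `retrain_schedule` are hypothetical and not part of the package.

```python
from typing import Callable


def geometric_rule(num_points: int) -> int:
    """Hypothetical update rule: retrain once the data set has grown by about 50%."""
    # max(...) guarantees that the schedule always advances by at least one sample.
    return max(num_points + 1, int(1.5 * num_points))


def retrain_schedule(
    minimum_num_training_points: int,
    update_interval_rule: Callable[[int], int],
    max_points: int,
) -> list[int]:
    """List the training-set sizes at which a retrain would be triggered (hypothetical helper)."""
    schedule = []
    # Assumption: the first fit happens as soon as the minimum number of points is available.
    next_update = minimum_num_training_points
    while next_update <= max_points:
        schedule.append(next_update)
        next_update = update_interval_rule(next_update)
    return schedule


print(retrain_schedule(10, geometric_rule, 100))  # [10, 15, 22, 33, 49, 73]
```

A constant interval such as `lambda n: n + 10` retrains after every ten new samples, whereas a geometric rule spaces retrains further apart as the training set grows.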
surrogate.surrogate_control.CallInfo
dataclass
Logger Info object for control calls.
Attributes:
Name | Type | Description |
---|---|---|
parameters | list[list[float]] | Parameters the control was called with |
surrogate_result | np.ndarray | Mean prediction from surrogate |
simulation_result | list[list[float]] | Result requested from UMBridge simulation model server |
variance | np.ndarray | Variance prediction from surrogate |
surrogate_used | bool | Whether surrogate has been used for prediction |
num_training_points | int | Overall number of training points generated so far |
surrogate.surrogate_control.UpdateInfo
dataclass
Logger Info object for surrogate updates, i.e. events in which new training data has been provided.
Attributes:
Name | Type | Description |
---|---|---|
new_fit | bool | Whether the surrogate has been retrained with the new data |
num_updates | int | Number of surrogate retrains performed so far |
next_update | int | Number of training samples after which the next retrain is scheduled |
scale | float | Scale parameter of the trained surrogate kernel (for GPs) |
correlation_length | float \| np.ndarray | Correlation length parameter of the trained surrogate kernel, either a scalar or one value per input dimension (for GPs) |
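For a Gaussian process, a scale and a per-dimension correlation length are commonly the amplitude (variance) and the length scales of an anisotropic squared-exponential kernel. The sketch below assumes that parameterisation purely for illustration; the kernel actually used by the surrogate model is not specified in this section, and `squared_exponential` is a hypothetical helper.

```python
import math
from typing import Sequence


def squared_exponential(
    x: Sequence[float],
    y: Sequence[float],
    scale: float,
    correlation_length: Sequence[float],
) -> float:
    """Assumed kernel: k(x, y) = scale * exp(-0.5 * sum_i ((x_i - y_i) / l_i)^2)."""
    if not (len(x) == len(y) == len(correlation_length)):
        raise ValueError("x, y and correlation_length must have the same dimension")
    # Each input dimension is rescaled by its own correlation length l_i.
    r2 = sum(((xi - yi) / li) ** 2 for xi, yi, li in zip(x, y, correlation_length))
    return scale * math.exp(-0.5 * r2)


# Identical points: the kernel equals the scale parameter.
print(squared_exponential([1.0, 2.0], [1.0, 2.0], 3.0, [0.5, 0.5]))  # 3.0
```

A large correlation length in one dimension means the prediction varies slowly along that input. A scalar `correlation_length` corresponds to using the same value in every dimension (an isotropic kernel).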
surrogate.surrogate_control.SurrogateControl
Bases: ub.Model
Surrogate control server for asynchronous requests and retraining.
This is the main component of the surrogate workflow. The control server is a UM-Bridge server,
which can be called by a client to request evaluation of the surrogate for a given input
parameter set. Internally, the server connects to a simulation model, which is also assumed to
be a UM-Bridge server. A call to the server is dispatched either to the surrogate or to the
simulation model, depending on the variance of the surrogate prediction. See the documentation
of the __call__ method for further details. In addition, the server hosts a background thread
that collects new training data whenever the simulation model is invoked. The data is used to
retrain the surrogate model asynchronously. Synchronization between server requests, training
data collection, and surrogate retraining is ensured by threading locks.
Methods:
Name | Description |
---|---|
__call__ | Call method according to UM-Bridge interface |
get_input_sizes | UM-Bridge method to specify dimension of the input parameters |
get_output_sizes | UM-Bridge method to specify dimension of the output |
supports_evaluate | UM-Bridge flags indicating that the server can be called for evaluation |
update_surrogate_model_daemon | Daemon thread for updating the surrogate model |
__init__
__init__(
control_settings: ControlSettings,
logger_settings: utilities.LoggerSettings,
surrogate_model: surrogate_model.BaseSurrogateModel,
simulation_model: Callable,
) -> None
Constructor.
Initializes all data structures based on the provided configuration. Additionally, starts a daemon thread that asynchronously updates the surrogate when new data is obtained from the simulation model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
control_settings | ControlSettings | Configuration of the control server | required |
logger_settings | utilities.LoggerSettings | Configuration of the logger | required |
surrogate_model | surrogate_model.BaseSurrogateModel | Surrogate model to be used | required |
simulation_model | Callable | Simulation model to be used | required |
get_input_sizes
UMBridge method to specify dimension of the input parameters.
get_output_sizes
UMBridge method to specify dimension of the output.
supports_evaluate
UMBridge flags indicating that the server can be called for evaluation.
__call__
Call method according to UMBridge interface.
An evaluation request for a given parameter set is processed as follows. First, the surrogate is invoked, returning mean and variance of the estimate at the given parameter set. If the variance exceeds the configured variance threshold, the simulation model is invoked, and its result is returned along with a variance of zero. Otherwise, mean and variance of the surrogate prediction are returned. Whenever the simulation model is invoked, its result forms a new training sample for the surrogate. This sample is queued and used for retraining by the daemon thread of the control server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parameters | list[list[float]] | Parameter set for which evaluation is requested | required |
config | dict[str, Any] | Configuration for the request, passed on to the simulation model, which is also assumed to be a UMBridge server | required |
Returns:
Type | Description |
---|---|
list[list[float]] | Result of the request (surrogate or simulation result) in UMBridge format |
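The dispatch logic described above can be sketched as a plain function. This is a minimal illustration, not the package's implementation, and it rests on several assumptions: the surrogate is any callable returning `(mean, variance)` for a single output vector, the threshold is applied as an absolute bound on the largest variance component, and a standard-library `Queue` stands in for the update queue. The names `dispatch_request`, `toy_surrogate` and `toy_simulation` are hypothetical.

```python
from queue import Queue
from typing import Any, Callable

Params = list[list[float]]


def dispatch_request(
    parameters: Params,
    config: dict[str, Any],
    surrogate: Callable[[Params], tuple[list[float], list[float]]],
    simulation_model: Callable[[Params, dict[str, Any]], Params],
    variance_threshold: float,
    training_queue: Queue,
) -> tuple[list[float], list[float], bool]:
    """Return (result, variance, surrogate_used) for one request (hypothetical helper)."""
    mean, variance = surrogate(parameters)
    # Assumption: absolute threshold on the largest component of the predicted variance.
    if max(variance) <= variance_threshold:
        return mean, variance, True
    # Surrogate too uncertain: run the simulation model instead.
    result = simulation_model(parameters, config)
    # The simulation run doubles as a new training sample for the background thread.
    training_queue.put((parameters, result))
    output = result[0]
    return output, [0.0] * len(output), False


def toy_surrogate(params: Params) -> tuple[list[float], list[float]]:
    x = params[0][0]
    # Pretend the surrogate is only trustworthy for |x| < 1.
    return [2.0 * x], [0.01 if abs(x) < 1.0 else 1.0]


def toy_simulation(params: Params, config: dict[str, Any]) -> Params:
    return [[2.0 * params[0][0]]]


q: Queue = Queue()
print(dispatch_request([[0.5]], {}, toy_surrogate, toy_simulation, 0.1, q))  # ([1.0], [0.01], True)
print(dispatch_request([[3.0]], {}, toy_surrogate, toy_simulation, 0.1, q))  # ([6.0], [0.0], False)
print(q.qsize())  # 1
```

Only the second request reaches the simulation model, so only it produces a queued training sample.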
update_surrogate_model_daemon
Daemon thread for updating the surrogate model.
The control server hosts a background (daemon) thread. This thread checks whether new training data has been generated by the simulation model and transferred to a dedicated update queue. The daemon thread collects the new data and retrains the surrogate once the number of samples prescribed by the user-specified update rule is available. Access to the queue and to the surrogate object is synchronized with the evaluation requests of the server via threading locks.
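A stripped-down version of this producer/consumer pattern can be sketched with the standard library alone. The class `ToyControl`, its `fit` callback and the polling interval are hypothetical; the sketch only mirrors the documented roles: the request thread queues samples under one lock, and a daemon thread moves them into the training set and refits under a second lock once the size prescribed by the update rule is reached.

```python
import threading
import time
from typing import Any, Callable


class ToyControl:
    """Hypothetical stand-in for the queue/daemon mechanics of a surrogate control server."""

    def __init__(self, fit: Callable[[list], None], first_update: int,
                 update_interval_rule: Callable[[int], int], poll_interval: float = 0.01) -> None:
        self._fit = fit
        self._next_update = first_update
        self._rule = update_interval_rule
        self._poll_interval = poll_interval
        self._queue: list = []          # samples produced by simulation runs
        self._training_data: list = []  # samples already handed to the surrogate
        self._queue_lock = threading.Lock()      # guards the update queue
        self._surrogate_lock = threading.Lock()  # guards the surrogate during refits
        self.num_updates = 0
        threading.Thread(target=self._update_daemon, daemon=True).start()

    def queue_training_data(self, parameters: Any, result: Any) -> None:
        # Request thread: called whenever the simulation model was run.
        with self._queue_lock:
            self._queue.append((parameters, result))

    def call_surrogate(self, predict: Callable[[Any], Any], parameters: Any) -> Any:
        # Request thread: evaluate the surrogate only while no refit is in progress.
        with self._surrogate_lock:
            return predict(parameters)

    def _update_daemon(self) -> None:
        while True:
            # Swap out the queue quickly so requests are blocked as briefly as possible.
            with self._queue_lock:
                new_samples, self._queue = self._queue, []
            self._training_data.extend(new_samples)
            if len(self._training_data) >= self._next_update:
                with self._surrogate_lock:
                    self._fit(list(self._training_data))
                self.num_updates += 1
                self._next_update = self._rule(len(self._training_data))
            time.sleep(self._poll_interval)
```

For example, with `first_update=3` and `update_interval_rule=lambda n: n + 2`, the first refit sees three samples and the second one five.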
_init_surrogate_model_update_thread
Start the daemon thread for surrogate updates.
_call_surrogate
Invoke surrogate, synchronizing with daemon thread.
_retrain_surrogate
Retrain surrogate, synchronizing with request thread.
_queue_training_data
Insert new training data into update queue, synchronizing with daemon thread.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parameters | list[list[float]] | Input of the data sample | required |
result | list[list[float]] | Output of the data sample | required |
_tap_training_data
Transfer training data from queue to surrogate, synchronizing with request thread.
surrogate.surrogate_control.SurrogateLogger
Bases: utilities.BaseLogger
Logger for the surrogate control during runs.
The logger processes two types of events:
1. Requests from a client, provided as a CallInfo object
2. Updates of the surrogate model, provided as an UpdateInfo object
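The split into two event types can be illustrated with a minimal standalone logger built on Python's `logging` module. The dataclasses copy only a few of the documented CallInfo and UpdateInfo fields, and the class name `ToySurrogateLogger` as well as the message format are hypothetical; the actual output format of SurrogateLogger is not described here.

```python
import logging
from dataclasses import dataclass


@dataclass
class CallEvent:
    """Subset of the documented CallInfo fields (hypothetical stand-in)."""
    parameters: list[list[float]]
    surrogate_used: bool
    num_training_points: int


@dataclass
class UpdateEvent:
    """Subset of the documented UpdateInfo fields (hypothetical stand-in)."""
    new_fit: bool
    num_updates: int
    next_update: int


class ToySurrogateLogger:
    """One logging method per event type, mirroring the two documented methods."""

    def __init__(self, name: str = "surrogate_control") -> None:
        self._logger = logging.getLogger(name)

    def log_control_call_info(self, info: CallEvent) -> str:
        source = "surrogate" if info.surrogate_used else "simulation"
        message = f"call | params={info.parameters} | source={source} | n_train={info.num_training_points}"
        self._logger.info(message)
        return message

    def log_surrogate_update_info(self, info: UpdateEvent) -> str:
        message = f"update | new_fit={info.new_fit} | n_updates={info.num_updates} | next={info.next_update}"
        self._logger.info(message)
        return message
```

Keeping one method per event type lets client requests and background retrains be logged independently, even though they originate from different threads.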
__init__
Constructor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logger_settings | utilities.LoggerSettings | Configuration of the logger | required |
print_header
Print info banner explaining the abbreviations used during logging.
log_control_call_info
Log info from a call to the control server.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
call_info | CallInfo | CallInfo object to process | required |
log_surrogate_update_info
Log info from an update of the surrogate.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
update_info | UpdateInfo | UpdateInfo object to process | required |