Offline Training
offline_training
Pretraining of a surrogate model on pseudo-random parameters.
Classes:
Name | Description |
---|---|
OfflineTrainingSettings | Configuration of the offline training run. |
OfflineTrainer | Class for pretraining of surrogate models. |
OfflineTrainingLogger | Logger for information during pretraining. |
surrogate.offline_training.OfflineTrainingSettings
dataclass
Configuration of the offline training run.
Attributes:
Name | Type | Description |
---|---|---|
num_offline_training_points | int | Number of parameter samples to generate for training through Latin Hypercube Sampling. |
num_threads | int | Number of parallel threads to use for pretraining. This only speeds up training if the simulation model calls are dispatched to a backend that actually evaluates them in parallel. |
offline_model_config | dict | Configuration of UMBridge calls to the simulation model server. |
lhs_bounds | list | Dimension-wise bounds for the Latin Hypercube Sampling. |
lhs_seed | list | Seed for the Latin Hypercube Sampling. |
checkpoint_save_name | Path | Name of the checkpoint file to save the surrogate model and data to. |
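To make the interplay of `num_offline_training_points`, `lhs_bounds`, and `lhs_seed` concrete, here is a minimal Latin Hypercube sketch in pure Python. It assumes that `lhs_bounds` holds one `[low, high]` pair per dimension and that the seed is an integer (the attribute is documented as `list`, so how the real class consumes it is not shown here). The function name `latin_hypercube` is hypothetical, and the package may well use a library sampler such as `scipy.stats.qmc.LatinHypercube` instead.

```python
import random
from typing import List, Optional, Sequence


def latin_hypercube(
    num_points: int,
    bounds: Sequence[Sequence[float]],
    seed: Optional[int] = None,
) -> List[List[float]]:
    """Hypothetical helper: draw Latin Hypercube samples within [low, high] bounds per dimension."""
    rng = random.Random(seed)
    samples = [[0.0] * len(bounds) for _ in range(num_points)]
    for dim, (low, high) in enumerate(bounds):
        # Stratum k covers [k / n, (k + 1) / n) of the unit interval.
        strata = list(range(num_points))
        # Shuffle independently per dimension so strata are paired randomly across dimensions.
        rng.shuffle(strata)
        for i, stratum in enumerate(strata):
            # Uniform position inside the stratum, then mapped to [low, high).
            unit = (stratum + rng.random()) / num_points
            samples[i][dim] = low + unit * (high - low)
    return samples


# Example: 5 samples in a 2D box (values chosen for illustration only).
bounds = [[0.0, 1.0], [-2.0, 2.0]]
points = latin_hypercube(num_points=5, bounds=bounds, seed=0)
```

Each dimension's range is split into `num_points` equal strata and every stratum receives exactly one sample, which is what distinguishes LHS from plain uniform sampling. Reusing the same seed reproduces the same design.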
surrogate.offline_training.OfflineTrainer
Class for pretraining of surrogate models.
Implements simple pretraining without the asynchronous server. Input parameters are generated via Latin Hypercube Sampling on the domain of interest. Outputs are obtained from calls to a simulation model server.
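The procedure described above (evaluate the simulation model on a fixed set of sampled inputs, then fit the surrogate once) can be sketched as follows. This is not the `OfflineTrainer` implementation: `run_pretraining` and `MeanSurrogate` are hypothetical stand-ins, the real `BaseSurrogateModel` interface is not documented here, and the thread pool only illustrates what `num_threads` is for.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence


class MeanSurrogate:
    """Hypothetical stand-in for a surrogate model: predicts the mean of its training outputs."""

    def __init__(self) -> None:
        self.inputs: List[List[float]] = []
        self.outputs: List[float] = []
        self.mean = 0.0

    def fit(self, inputs: Sequence[Sequence[float]], outputs: Sequence[float]) -> None:
        self.inputs = [list(x) for x in inputs]
        self.outputs = list(outputs)
        self.mean = sum(self.outputs) / len(self.outputs)


def run_pretraining(
    parameters: Sequence[Sequence[float]],
    simulation_model: Callable[[Sequence[float]], float],
    surrogate: MeanSurrogate,
    num_threads: int = 1,
) -> None:
    """Hypothetical offline loop: evaluate all samples, then fit the surrogate once."""
    # Dispatch simulation calls concurrently; map() returns results in input order.
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        results = list(pool.map(simulation_model, parameters))
    # Fit on the complete offline data set.
    surrogate.fit(parameters, results)


params = [[0.0], [1.0], [2.0], [3.0]]
surrogate = MeanSurrogate()
run_pretraining(params, lambda x: 2.0 * x[0], surrogate, num_threads=2)
```

Because `pool.map` preserves input order, inputs and outputs stay aligned without extra bookkeeping. Threads rather than processes are a reasonable choice only because each call is assumed to wait on an external server, which matches the caveat on `num_threads`.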
Methods:
Name | Description |
---|---|
run | Execute the pretraining. |
__init__
__init__(
training_settings: OfflineTrainingSettings,
logger_settings: utilities.LoggerSettings,
surrogate_model: surrogate_model.BaseSurrogateModel,
simulation_model: Callable,
) -> None
Constructor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
training_settings | OfflineTrainingSettings | Configuration of the offline training run. | required |
logger_settings | utilities.LoggerSettings | Configuration of the logger. | required |
surrogate_model | surrogate_model.BaseSurrogateModel | Surrogate model to train. | required |
simulation_model | Callable | Simulation model that is queried for the evaluations forming the training data. | required |
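The constructor accepts any `Callable` as `simulation_model`, and `offline_model_config` is described as the configuration of UMBridge calls. UMBridge client models (for example `umbridge.HTTPModel(url, name)`) are called as `model(inputs, config)`, where `inputs` is a list of input vectors and the return value is a list of output vectors. The sketch below wraps such a model into a callable that maps one parameter vector to one scalar. That shape is an assumption suggested by the logger's `result: float`, not a documented contract, and `make_simulation_model` and `fake_server_model` are hypothetical names; the stand-in lets the example run without a server.

```python
from typing import Any, Callable, Dict, List, Sequence

# Calling convention of a UMBridge-style model: list of input vectors + config -> list of output vectors.
UMBridgeLike = Callable[[List[List[float]], Dict[str, Any]], List[List[float]]]


def make_simulation_model(
    model: UMBridgeLike, config: Dict[str, Any]
) -> Callable[[Sequence[float]], float]:
    """Hypothetical adapter: one parameter vector in, one scalar out."""

    def simulation_model(parameter: Sequence[float]) -> float:
        # Send a single input vector and return the first entry of the first output vector.
        outputs = model([list(parameter)], config)
        return outputs[0][0]

    return simulation_model


def fake_server_model(inputs: List[List[float]], config: Dict[str, Any]) -> List[List[float]]:
    """Hypothetical local stand-in for a server model: scaled sum of the input vector."""
    return [[config.get("factor", 1.0) * sum(inputs[0])]]


simulation_model = make_simulation_model(fake_server_model, {"factor": 10.0})
value = simulation_model([1.0, 2.0])
```

With a real server, `model` would be the UMBridge client object and `config` the `offline_model_config` dict; the stand-in only mimics the calling convention.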
surrogate.offline_training.OfflineTrainingLogger
Bases: utilities.BaseLogger
Logger for information during pretraining.
The logger records two events:
1. Generation of a training data sample (input and output)
2. Fitting of the surrogate
__init__
Constructor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
logger_settings | utilities.LoggerSettings | Configuration of the logger. | required |
log_simulation_run
Log information on generation of a training sample.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parameter | float \| Iterable | Input parameter. | required |
result | float | Simulation model result. | required |
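Because `parameter` may be either a scalar or an iterable, a logger has to normalize it before formatting. Below is a minimal sketch using the standard `logging` module; the function names and the message format are hypothetical and do not reproduce the output of `OfflineTrainingLogger`, whose base class `utilities.BaseLogger` is not documented here.

```python
import logging
from typing import Iterable, Union

logger = logging.getLogger("offline_training_sketch")


def format_parameter(parameter: Union[float, Iterable[float]]) -> str:
    """Render a scalar or an iterable of floats as a compact string."""
    if isinstance(parameter, (int, float)):
        return f"{parameter:.3e}"
    return "[" + ", ".join(f"{value:.3e}" for value in parameter) + "]"


def log_simulation_run(parameter: Union[float, Iterable[float]], result: float) -> str:
    """Hypothetical counterpart of the documented method: emit and return one log line."""
    message = f"Simulation run | parameter: {format_parameter(parameter)} | result: {result:.3e}"
    logger.info(message)
    return message
```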
log_surrogate_fit
Log information on the fitting of the surrogate.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
scale | float | Log scale parameter of the surrogate model (for GPs). | required |
correlation_length | float \| Iterable | Log correlation length per dimension of the surrogate model (for GPs). | required |
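Both fit parameters are described as "log" values. Assuming this means natural-log-transformed hyperparameters, as is common for Gaussian process kernels (scikit-learn's `Kernel.theta`, for instance, exposes log-transformed hyperparameters), a logger could map them back to their natural scale before reporting. The helper `describe_surrogate_fit` is hypothetical.

```python
import math
from typing import Iterable, List, Union


def describe_surrogate_fit(
    scale: float, correlation_length: Union[float, Iterable[float]]
) -> str:
    """Hypothetical helper: map log-space GP hyperparameters back to natural scale for a log line."""
    # Accept a single (isotropic) length or one length per dimension.
    if isinstance(correlation_length, (int, float)):
        log_lengths: List[float] = [correlation_length]
    else:
        log_lengths = list(correlation_length)
    natural_scale = math.exp(scale)
    natural_lengths = ", ".join(f"{math.exp(value):.3e}" for value in log_lengths)
    return f"Surrogate fit | scale: {natural_scale:.3e} | correlation length: [{natural_lengths}]"
```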