Forward Solver
solver
Eikonax forward solver.
Classes:
Name | Description |
---|---|
SolverData | Settings for the initialization of the Eikonax Solver. |
Solution | Eikonax solution object, returned by the solver. |
Solver | Eikonax solver class. |
eikonax.solver.SolverData
dataclass
Settings for the initialization of the Eikonax Solver.
See the Forward Solver documentation for more detailed explanations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loop_type | str | Type of loop for the iterations; options are 'jitted_for', 'jitted_while', 'nonjitted_while'. | required |
max_value | Real | Maximum value for the initialization of the solution vector. | required |
use_soft_update | bool | Flag for using the soft minmax approximation for optimization parameters. | required |
softminmax_order | int \| None | Order of the soft minmax approximation for optimization parameters. Only required if use_soft_update is True. | required |
softminmax_cutoff | Real \| None | Cutoff distance from [0,1] for the soft minmax function. Only required if use_soft_update is True. | required |
max_num_iterations | int | Maximum number of iterations after which to terminate the solver. Required for all loop types. | required |
tolerance | Real | Threshold on the absolute difference between consecutive iterates in the supremum norm; the solver terminates once the difference falls below it. Required for while loop types. | None |
log_interval | int | Iteration interval after which log info is written. Required for the non-jitted while loop type. | None |
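The requirement rules spread across the table above (which fields each loop type and the soft update need) can be summarized in a small check. This is only a sketch: SolverSettingsSketch and missing_fields are hypothetical names that are not part of Eikonax, and it is an assumption that the real SolverData validates its fields this way, or at all.

```python
from __future__ import annotations

from dataclasses import dataclass
from numbers import Real

LOOP_TYPES = ("jitted_for", "jitted_while", "nonjitted_while")


@dataclass
class SolverSettingsSketch:
    """Hypothetical stand-in mirroring the documented SolverData fields."""

    loop_type: str
    max_value: Real
    use_soft_update: bool
    softminmax_order: int | None
    softminmax_cutoff: Real | None
    max_num_iterations: int  # required for all loop types
    tolerance: Real | None = None
    log_interval: int | None = None


def missing_fields(settings: SolverSettingsSketch) -> list[str]:
    """Return the fields that the documented rules require but that are unset."""
    if settings.loop_type not in LOOP_TYPES:
        raise ValueError(f"Invalid loop type: {settings.loop_type!r}")
    missing = []
    if settings.use_soft_update:
        if settings.softminmax_order is None:
            missing.append("softminmax_order")
        if settings.softminmax_cutoff is None:
            missing.append("softminmax_cutoff")
    # Both while loops stop on the tolerance criterion.
    if settings.loop_type in ("jitted_while", "nonjitted_while") and settings.tolerance is None:
        missing.append("tolerance")
    # Only the Python while loop writes log output.
    if settings.loop_type == "nonjitted_while" and settings.log_interval is None:
        missing.append("log_interval")
    return missing


settings = SolverSettingsSketch(
    loop_type="jitted_while", max_value=1000.0, use_soft_update=False,
    softminmax_order=None, softminmax_cutoff=None, max_num_iterations=1000,
)
print(missing_fields(settings))  # ['tolerance']
```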
eikonax.solver.Solution
dataclass
Eikonax solution object, returned by the solver.
See the Forward Solver documentation for more detailed explanations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
values | jax.Array | Actual solution vector. | required |
num_iterations | int | Number of iterations performed in the solve. | required |
tolerance | float \| jax.Array | Tolerance from the last two iterates, or the entire tolerance history. | None |
eikonax.solver.Solver
Bases: eqx.Module
Eikonax solver class.
The solver class is the main component for computing the solution \(u\) of the Eikonal equation for a given geometry \(\Omega\) of dimension \(d\), tensor field \(\mathbf{M}\), and initial sites \(\Gamma\),
\[
\sqrt{\big(\nabla u(\mathbf{x}),\mathbf{M}(\mathbf{x})\nabla u(\mathbf{x})\big)} = 1,\quad \mathbf{x}\in\Omega, \qquad u(\mathbf{x}_0) = u_0,\quad \mathbf{x}_0\in\Gamma.
\]
On the discrete level, we assume that the eikonal equation is solved on a triangulation formed by \(N_V\) vertices and \(N_S\) associated triangles. This means that for a tensor field \(\mathbf{M}\in\mathbb{R}^{N_S\times d\times d}\), the solver computes a solution vector \(\mathbf{u}\in\mathbb{R}^{N_V}\) through iterative updates
\[
\mathbf{u}^{(j+1)} = \mathbf{G}\big(\mathbf{u}^{(j)}, \mathbf{M}\big),
\]
where the global update function \(\mathbf{G}\) is derived from Godunov-type upwinding principles. The solver can either be run with a fixed number of iterations, or until the difference between two consecutive iterates in the supremum norm falls below a user-defined tolerance.
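To make the fixed-point structure with a global update G concrete, the sketch below iterates a toy G for the 1D isotropic eikonal equation \(|u'| = 1/c\) on a uniform grid, where each vertex update is the smaller neighbor value plus \(h/c\), and stops on the supremum-norm criterion. This is a hypothetical stand-in chosen for illustration: the grid, the speed c, and the names global_update and solve are assumptions, and the real Eikonax update works on triangles and anisotropic tensors.

```python
def global_update(u: list[float], h: float, speed: float, sources: set[int]) -> list[float]:
    """Toy 1D Godunov-type update G: every vertex looks at its upwind (smaller) neighbor."""
    new_u = []
    for i, old in enumerate(u):
        if i in sources:
            new_u.append(old)  # values at the initial sites stay fixed
            continue
        left = u[i - 1] if i > 0 else float("inf")
        right = u[i + 1] if i < len(u) - 1 else float("inf")
        # Jacobi-style: only values from the previous iterate u are read.
        new_u.append(min(old, min(left, right) + h / speed))
    return new_u


def solve(num_vertices: int, h: float, speed: float, sources: set[int],
          max_value: float = 1e3, max_num_iterations: int = 1000, tolerance: float = 1e-12):
    """Iterate u <- G(u) until the sup-norm change between iterates drops below tolerance."""
    u = [0.0 if i in sources else max_value for i in range(num_vertices)]
    for iteration in range(1, max_num_iterations + 1):
        new_u = global_update(u, h, speed, sources)
        diff = max(abs(a - b) for a, b in zip(new_u, u))  # supremum norm
        u = new_u
        if diff < tolerance:
            break
    return u, iteration, diff


print(solve(5, h=1.0, speed=1.0, sources={0}))  # ([0.0, 1.0, 2.0, 3.0, 4.0], 5, 0.0)
```

Starting from max_value away from the source, information travels one vertex per sweep, so this toy needs four iterations to propagate the front and a fifth to detect that nothing changes anymore.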
The Eikonax solver works on the vertex level, meaning that it considers updates from all adjacent triangles to a vertex, instead of all updates for all vertices per triangle. This makes it possible to establish causality in the final solution, which is important for the efficient computation of parametric derivatives.
The solver class is mainly a wrapper around different loop constructs, which call vectorized forms of the methods implemented in the corefunctions module. These loop constructs are built around the loop functionality provided by JAX.
Methods:
Name | Description |
---|---|
run | Main interface for Eikonax runs. |
__init__
__init__(mesh_data: corefunctions.MeshData, solver_data: SolverData, initial_sites: corefunctions.InitialSites, logger: logging.Logger | None = None) -> None
Constructor of the solver class.
The constructor initializes all data structures that are re-used in many-query scenarios, such as the solution of inverse problems.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh_data | corefunctions.MeshData | Vertex-based mesh data. | required |
solver_data | SolverData | Settings for the solver. | required |
initial_sites | corefunctions.InitialSites | Vertices and values for source points. | required |
logger | logging.Logger \| None | Logger object, only required for non-jitted while loops. Defaults to None. | None |
run
Main interface for conducting solver runs.
The method initializes the solution vector and dispatches to the run method for the selected loop type.
Note
The solver expects the metric tensor field as used in the inner product for the update stencil of the eikonal equation. This is the INVERSE of the conductivity tensor, which is the actual tensor field in the eikonal equation. The Tensorfield component provides the inverse tensor field.
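If you assemble conductivity tensors yourself instead of going through the Tensorfield component, the inversion the note asks for can be done for all simplices at once, because jnp.linalg.inv acts on the last two axes of a batched array. The two 2D tensors below are made-up example values; only the need to pass the inverse comes from the note.

```python
import jax.numpy as jnp

# Hypothetical conductivity tensors for two simplices in 2D, shape (num_simplices, d, d).
conductivity = jnp.array([
    [[1.0, 0.0], [0.0, 1.0]],  # isotropic
    [[4.0, 0.0], [0.0, 1.0]],  # higher conductivity along the x-axis
])

# Batched inverse: one d x d inversion per simplex, giving the metric tensor field for run().
metric_tensor_field = jnp.linalg.inv(conductivity)
```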
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tensor_field | jax.Array | Parameter field for which to solve the Eikonal equation. Provides an anisotropy tensor for each simplex of the mesh. | required |
Raises:
Type | Description |
---|---|
ValueError | Checks that the chosen loop type is valid. |
Returns:
Name | Type | Description |
---|---|---|
Solution | Solution | Eikonax solution object. |
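The two steps that run performs before iterating, initializing the solution vector and dispatching on the loop type, might look roughly as follows. Filling non-source vertices with max_value and overwriting the initial sites with their values is an assumption based on the max_value and initial_sites descriptions; initialize_solution and select_loop are hypothetical helpers, not Eikonax functions.

```python
from collections.abc import Callable

import jax
import jax.numpy as jnp


def initialize_solution(num_vertices: int, max_value: float,
                        site_inds: jax.Array, site_values: jax.Array) -> jax.Array:
    """Fill every vertex with max_value, then pin the initial sites to their values."""
    initial_guess = jnp.full(num_vertices, max_value)
    return initial_guess.at[site_inds].set(site_values)


def select_loop(loop_type: str, runners: dict[str, Callable]) -> Callable:
    """Pick the run method for loop_type; unknown loop types raise ValueError."""
    try:
        return runners[loop_type]
    except KeyError:
        raise ValueError(f"Invalid loop type: {loop_type!r}") from None


initial_guess = initialize_solution(5, 1e3, jnp.array([0, 3]), jnp.array([0.0, 0.5]))
```

A dictionary lookup keeps the ValueError for invalid loop types in one place, which matches the single Raises entry documented for run.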
_run_jitted_for_loop
_run_jitted_for_loop(initial_guess: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_vertices dim dim']) -> tuple[jtFloat[jax.Array, num_vertices], int, float]
Solver run with a jitted for loop.
The method constructs a JAX-type for loop with a fixed number of iterations. In every iteration, a new solution vector is computed by the _compute_global_update method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
initial_guess | jax.Array | Initial solution vector. | required |
tensor_field | jax.Array | Parameter field. | required |
Returns:
Type | Description |
---|---|
tuple[jtFloat[jax.Array, num_vertices], int, float] | tuple[jax.Array, int, float]: Solution values, number of iterations, tolerance |
_run_jitted_while_loop
_run_jitted_while_loop(initial_guess: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_vertices dim dim']) -> tuple[jtFloat[jax.Array, num_vertices], int, float]
Solver run with a jitted while loop.
The iteration is tolerance-based and terminates once the difference between two consecutive iterates in the supremum norm falls below a user-defined tolerance. In every iteration, a new solution vector is computed by the _compute_global_update method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
initial_guess | jax.Array | Initial solution vector. | required |
tensor_field | jax.Array | Parameter field. | required |
Raises:
Type | Description |
---|---|
ValueError | Checks that tolerance has been provided by the user. |
Returns:
Type | Description |
---|---|
tuple[jtFloat[jax.Array, num_vertices], int, float] | tuple[jax.Array, int, float]: Solution values, number of iterations, tolerance |
_run_nonjitted_while_loop
_run_nonjitted_while_loop(initial_guess: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_vertices dim dim']) -> tuple[jtFloat[jax.Array, num_vertices], int, jtFloat[jax.Array, ...]]
Solver run with a standard Python while loop.
While less performant than the jitted loops, the Python while loop allows logging information between iterations. The iteration is tolerance-based and terminates once the difference between two consecutive iterates in the supremum norm falls below a user-defined tolerance. In every iteration, a new solution vector is computed by the _compute_global_update method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
initial_guess | jax.Array | Initial solution vector. | required |
tensor_field | jax.Array | Parameter field. | required |
Raises:
Type | Description |
---|---|
ValueError | Checks that tolerance has been provided by the user. |
ValueError | Checks that log interval has been provided by the user. |
ValueError | Checks that logger object has been provided by the user. |
Returns:
Type | Description |
---|---|
tuple[jtFloat[jax.Array, num_vertices], int, jtFloat[jax.Array, ...]] | tuple[jax.Array, int, jax.Array]: Solution values, number of iterations, tolerance vector over all iterations |
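A sketch of the same iteration in plain Python, showing where the three documented checks and the periodic logging would go. run_python_while and toy_update are hypothetical names; logging through the standard logging module matches the constructor's logger parameter, but the message format is made up.

```python
import logging

import jax.numpy as jnp


def run_python_while(update, initial_guess, tensor_field, tolerance,
                     max_num_iterations, log_interval, logger):
    """Plain Python loop: slower than the jitted loops, but free to log between iterations."""
    if tolerance is None:
        raise ValueError("A tolerance is required.")
    if log_interval is None:
        raise ValueError("A log interval is required.")
    if logger is None:
        raise ValueError("A logger is required.")

    solution, iteration, diff, history = initial_guess, 0, float("inf"), []
    while diff >= tolerance and iteration < max_num_iterations:
        new_solution = update(solution, tensor_field)
        # float() pulls the value to the host; this is only possible outside of jit.
        diff = float(jnp.max(jnp.abs(new_solution - solution)))
        solution = new_solution
        history.append(diff)
        iteration += 1
        if iteration % log_interval == 0:
            logger.info("Iteration %d: sup-norm difference %.3e", iteration, diff)
    return solution, iteration, jnp.array(history)


def toy_update(solution, target):
    """Placeholder for the global update: moves halfway toward a fixed point."""
    return 0.5 * (solution + target)


logging.basicConfig(level=logging.INFO)
solution, num_iterations, history = run_python_while(
    toy_update, jnp.zeros(3), jnp.ones(3), tolerance=1e-3,
    max_num_iterations=100, log_interval=2, logger=logging.getLogger("eikonax_sketch"),
)
```

Collecting the difference after every iteration is what turns the third return value into a history vector instead of a single tolerance.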
_compute_global_update
_compute_global_update(solution_vector: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_vertices dim dim']) -> jtFloat[jax.Array, num_vertices]
Given a current state and tensor field, compute a new solution vector.
This method is essentially a vectorized call to the _compute_vertex_update method, evaluated over all vertices of the mesh.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
solution_vector | jax.Array | Current state. | required |
tensor_field | jax.Array | Parameter field. | required |
Returns:
Type | Description |
---|---|
jtFloat[jax.Array, num_vertices] | jax.Array: New iterate |
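The vectorized call can be expressed with jax.vmap, mapping a per-vertex update over the rows of a stacked adjacency array while sharing the solution vector and tensor field. Assumptions: adjacency data for all vertices is stacked along a leading vertex axis, and toy_vertex_update is a placeholder; in the toy usage each row simply lists neighbor vertex indices, not the 4-column Eikonax layout.

```python
import jax
import jax.numpy as jnp


def compute_global_update(vertex_update, solution_vector, tensor_field, adjacency_data):
    """Evaluate a single-vertex update for every vertex with one jax.vmap call.

    solution_vector and tensor_field are shared by all vertices (in_axes=None);
    adjacency_data is mapped over its leading (vertex) axis.
    """
    return jax.vmap(vertex_update, in_axes=(None, None, 0))(
        solution_vector, tensor_field, adjacency_data
    )


def toy_vertex_update(solution_vector, step, neighbors):
    """Placeholder update: smallest neighbor value plus a constant step."""
    return jnp.min(solution_vector[neighbors]) + step


solution_vector = jnp.array([0.0, 10.0, 10.0])
neighbors = jnp.array([[1, 1], [0, 2], [1, 1]])  # one row of neighbor indices per vertex
new_solution = compute_global_update(toy_vertex_update, solution_vector, 1.0, neighbors)
```

Every vertex reads only from the previous iterate, so all per-vertex updates are independent and vmap can evaluate them in parallel without write conflicts.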
_compute_vertex_update
_compute_vertex_update(old_solution_vector: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_vertices dim dim'], adjacency_data: jtInt[jax.Array, 'max_num_adjacent_simplices 4']) -> jtFloat[jax.Array, '']
Compute the update value for a single vertex.
This method links to the main logic of the solver routine, based on functions in the corefunctions module.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
old_solution_vector | jax.Array | Current state. | required |
tensor_field | jax.Array | Parameter field. | required |
adjacency_data | jax.Array | Information on all adjacent triangles and their vertices for the current vertex. | required |
Returns:
Type | Description |
---|---|
jtFloat[jax.Array, ''] | jax.Array: Optimal update value for the current vertex |
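Since adjacency_data has the fixed shape (max_num_adjacent_simplices, 4), vertices with fewer neighbors presumably need padded rows that must not influence the minimum. The sketch shows one way to mask them. Everything concrete here is an assumption: the row layout [vertex, neighbor_1, neighbor_2, simplex], padding with -1, taking the minimum with the old value, and toy_candidate, which replaces the actual local update from corefunctions. tensor_field is omitted because the toy candidate does not use it.

```python
import jax
import jax.numpy as jnp


def masked_vertex_update(old_solution_vector, adjacency_data, candidate_fn):
    """Minimum over the candidates from all adjacent simplices, ignoring padded rows.

    Hypothetical row layout: [vertex, neighbor_1, neighbor_2, simplex], where rows
    of -1 pad vertices with fewer than max_num_adjacent_simplices neighbors.
    The first row is assumed to be valid (every vertex lies in some simplex).
    """
    vertex = adjacency_data[0, 0]
    is_valid = adjacency_data[:, 0] >= 0
    # Padded rows index with -1, which JAX wraps to the last entry without error;
    # their candidates are replaced by +inf before the reduction.
    candidates = jax.vmap(candidate_fn, in_axes=(None, 0))(old_solution_vector, adjacency_data)
    candidates = jnp.where(is_valid, candidates, jnp.inf)
    return jnp.minimum(old_solution_vector[vertex], jnp.min(candidates))


def toy_candidate(solution_vector, row):
    """Placeholder for the local simplex update: smaller neighbor value plus one."""
    return jnp.minimum(solution_vector[row[1]], solution_vector[row[2]]) + 1.0
```

Masking with +inf instead of branching keeps the shapes static, which is what allows the update to be vmapped and jitted.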