
Forward Solver

solver

Eikonax forward solver.

Classes:

Name Description
SolverData

Settings for the initialization of the Eikonax Solver.

Solution

Eikonax solution object, returned by the solver.

Solver

Eikonax solver class.

eikonax.solver.SolverData dataclass

Settings for the initialization of the Eikonax Solver.

See the Forward Solver documentation for more detailed explanations.

Parameters:

Name Type Description Default
loop_type str

Type of loop for iterations, options are 'jitted_for', 'jitted_while', 'nonjitted_while'.

required
max_value Real

Maximum value for the initialization of the solution vector.

required
use_soft_update bool

Flag for using a soft minmax approximation for the optimization parameters.

required
softminmax_order int | None

Order of the soft minmax approximation for optimization parameters. Only required if use_soft_update is True.

required
softminmax_cutoff Real | None

Cutoff distance from [0,1] for the soft minmax function. Only required if use_soft_update is True.

required
max_num_iterations int

Maximum number of iterations after which to terminate the solver. Required for all loop types.

required
tolerance Real | None

Tolerance for the supremum-norm difference between two consecutive iterates; the solver terminates once this difference falls below it. Required for while loop types.

None
log_interval int | None

Number of iterations between log outputs. Required for the non-jitted while loop type.

None
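The dependencies between these fields can be summarized as a small validation routine. The class and function below are hypothetical stand-ins that mirror the documented fields and rules (for instance, that while loop types need a tolerance); they are not eikonax's SolverData, which may leave such checks to the solver methods that raise ValueError.

```python
from __future__ import annotations

from dataclasses import dataclass

# Hypothetical mirror of the documented settings; not eikonax's own class.
VALID_LOOP_TYPES = ("jitted_for", "jitted_while", "nonjitted_while")


@dataclass
class SolverSettingsSketch:
    loop_type: str
    max_value: float
    use_soft_update: bool
    softminmax_order: int | None
    softminmax_cutoff: float | None
    max_num_iterations: int
    tolerance: float | None = None
    log_interval: int | None = None


def check_settings(settings: SolverSettingsSketch) -> None:
    """Raise ValueError if a field required by the chosen options is missing."""
    if settings.loop_type not in VALID_LOOP_TYPES:
        raise ValueError(f"Invalid loop type: {settings.loop_type}")
    if settings.use_soft_update and (
        settings.softminmax_order is None or settings.softminmax_cutoff is None
    ):
        raise ValueError("Soft update requires softminmax_order and softminmax_cutoff")
    if settings.loop_type in ("jitted_while", "nonjitted_while") and settings.tolerance is None:
        raise ValueError("While loop types require a tolerance")
    if settings.loop_type == "nonjitted_while" and settings.log_interval is None:
        raise ValueError("The non-jitted while loop requires a log_interval")


# Usage: a jitted while loop without soft updates needs a tolerance, but no log interval.
settings = SolverSettingsSketch(
    loop_type="jitted_while",
    max_value=1000.0,
    use_soft_update=False,
    softminmax_order=None,
    softminmax_cutoff=None,
    max_num_iterations=1000,
    tolerance=1e-8,
)
check_settings(settings)
```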

eikonax.solver.Solution dataclass

Eikonax solution object, returned by the solver.

See the Forward Solver documentation for more detailed explanations.

Parameters:

Name Type Description Default
values jax.Array

Computed solution vector.

required
num_iterations int

Number of iterations performed in the solve.

required
tolerance float | jax.Array

Supremum-norm difference between the last two iterates, or the entire tolerance history (non-jitted while loop).

None

eikonax.solver.Solver

Bases: eqx.Module

Eikonax solver class.

The solver class is the main component for computing the solution \(u\) of the Eikonal equation for a given geometry \(\Omega\) of dimension \(d\), tensor field \(\mathbf{M}\), and initial sites \(\Gamma\),

\[ \begin{gather*} \sqrt{\big(\nabla u(\mathbf{x}),\mathbf{M}(\mathbf{x})\nabla u(\mathbf{x})\big)} = 1,\quad \mathbf{x}\in\Omega, \\ \nabla u(\mathbf{x}) \cdot \mathbf{n}(\mathbf{x}) \geq 0,\quad \mathbf{x}\in\partial\Omega, \\ u(\mathbf{x}_0) = u_0,\quad \mathbf{x}_0 \in \Gamma. \end{gather*} \]

On the discrete level, we assume that the eikonal equation is solved on a triangulation formed by \(N_V\) vertices and \(N_S\) associated triangles. This means that for a tensor field \(\mathbf{M}\in\mathbb{R}^{N_S\times d\times d}\), the solver computes a solution vector \(\mathbf{u}\in\mathbb{R}^{N_V}\) through iterative updates

\[ \mathbf{u}^{(j+1)} = \mathbf{G}(\mathbf{u}^{(j)}, \mathbf{M}), \]

where the global update function \(\mathbf{G}\) is derived from Godunov-type upwinding principles. The solver can either be run for a fixed number of iterations, or until the supremum-norm difference between two consecutive iterates falls below a user-defined tolerance.
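To make the fixed-point iteration \(\mathbf{u}^{(j+1)} = \mathbf{G}(\mathbf{u}^{(j)}, \mathbf{M})\) concrete, here is a minimal NumPy sketch on a 1D chain of vertices with spacing \(h\) and unit conductivity, where a Godunov-type upwind update reduces to taking the cheaper neighbor plus \(h\). The 1D setting, the function names, and the initialization with a large finite value are assumptions for illustration; the loop shows both termination modes, a fixed iteration count and a supremum-norm tolerance.

```python
import numpy as np


def global_update(u: np.ndarray, h: float, source_index: int) -> np.ndarray:
    """Toy 1D upwind update: each vertex takes its cheaper neighbor plus the spacing h."""
    padded = np.pad(u, 1, constant_values=np.inf)  # no neighbors beyond the chain ends
    neighbor_min = np.minimum(padded[:-2], padded[2:])  # min of left and right neighbor
    u_new = np.minimum(u, neighbor_min + h)
    u_new[source_index] = 0.0  # the initial site keeps its prescribed value
    return u_new


def solve(num_vertices, h, source_index, max_value, max_num_iterations, tolerance=None):
    """Iterate u <- G(u); stop early only if a tolerance is given and undercut."""
    u = np.full(num_vertices, max_value, dtype=float)
    u[source_index] = 0.0
    diff = np.inf
    num_iterations = 0
    while num_iterations < max_num_iterations:
        u_new = global_update(u, h, source_index)
        diff = float(np.max(np.abs(u_new - u)))  # supremum norm of the change
        u = u_new
        num_iterations += 1
        if tolerance is not None and diff < tolerance:
            break
    return u, num_iterations, diff


values, num_iterations, diff = solve(5, 0.5, 0, 1000.0, 100, tolerance=1e-12)
# values == [0.0, 0.5, 1.0, 1.5, 2.0], the distances to the source vertex
```

Each Jacobi sweep propagates the front by one vertex, so the tolerance criterion is met one iteration after the last vertex has settled.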

The Eikonax solver works on the vertex level: for each vertex, it considers the updates from all adjacent triangles, instead of computing, for each triangle, the updates of all its vertices. This establishes causality in the final solution, which is important for the efficient computation of parametric derivatives. The solver class is mainly a wrapper around different loop constructs, which call vectorized forms of the methods implemented in the corefunctions module. These loop constructs are built on the loop functionality provided by JAX.
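Working on the vertex level requires, for every vertex, the list of its adjacent triangles. The shape annotation of _compute_vertex_update ('max_num_adjacent_simplices 4') suggests a padded per-vertex array. The sketch below assembles such an array from a triangle list; the column layout (vertex, two other vertices, simplex index) and the padding value -1 are assumptions, not the documented format of corefunctions.MeshData.

```python
import numpy as np


def build_vertex_adjacency(simplices: np.ndarray, num_vertices: int) -> np.ndarray:
    """Return an array of shape (num_vertices, max_adjacent, 4).

    Each row lists (vertex, other vertex 1, other vertex 2, simplex index) for one
    triangle adjacent to the vertex; unused rows are padded with -1 (hypothetical layout).
    """
    per_vertex = [[] for _ in range(num_vertices)]
    for simplex_index, (a, b, c) in enumerate(simplices):
        # Register the triangle once for each of its three vertices.
        per_vertex[a].append((a, b, c, simplex_index))
        per_vertex[b].append((b, c, a, simplex_index))
        per_vertex[c].append((c, a, b, simplex_index))
    max_adjacent = max(len(rows) for rows in per_vertex)
    adjacency = np.full((num_vertices, max_adjacent, 4), -1, dtype=np.int64)
    for vertex, rows in enumerate(per_vertex):
        if rows:
            adjacency[vertex, : len(rows)] = rows
    return adjacency


# Usage: a unit square split into two triangles sharing the edge (1, 2).
simplices = np.array([[0, 1, 2], [1, 3, 2]])
adjacency = build_vertex_adjacency(simplices, num_vertices=4)
```

Padding to a common length gives every vertex the same array shape, which is what allows a single vectorized update over all vertices.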

Methods:

Name Description
run

Main interface for Eikonax runs.

__init__

__init__(mesh_data: corefunctions.MeshData, solver_data: SolverData, initial_sites: corefunctions.InitialSites, logger: logging.Logger | None = None) -> None

Constructor of the solver class.

The constructor initializes all data structures that are re-used in many-query scenarios, such as the solution of inverse problems.

Parameters:

Name Type Description Default
mesh_data corefunctions.MeshData

Vertex-based mesh data.

required
solver_data SolverData

Settings for the solver.

required
initial_sites corefunctions.InitialSites

Vertices and values for the source points.

required
logger logging.Logger | None

Logger object, only required for non-jitted while loops. Defaults to None.

None

run

run(tensor_field: jtFloat[jax.Array | npt.NDArray, 'num_simplices dim dim']) -> Solution

Main interface for conducting solver runs.

The method initializes the solution vector and dispatches to the run method for the selected loop type.

Note

The solver expects the metric tensor field as used in the inner product for the update stencil of the eikonal equation. This is the INVERSE of the conductivity tensor, which is the actual tensor field in the eikonal equation. The Tensorfield component provides the inverse tensor field.
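As a concrete check of this note, assume an isotropic conductivity \(\mathbf{M} = c^2\mathbf{I}\) with wave speed \(c\). The tensor passed to the solver is then \(\mathbf{M}^{-1} = \mathbf{I}/c^2\), and the travel time along an edge \(\mathbf{e}\) in the update stencil becomes \(\sqrt{\mathbf{e}^T\mathbf{M}^{-1}\mathbf{e}} = \|\mathbf{e}\|/c\). The NumPy sketch below inverts a batch of per-simplex conductivity tensors directly; the variable names are illustrative and the Tensorfield component is not used.

```python
import numpy as np

# Hypothetical isotropic wave speeds, one per simplex (num_simplices = 3, dim = 2).
speeds = np.array([1.0, 2.0, 4.0])
dim = 2

# Conductivity tensor M = c^2 I for each simplex, shape (num_simplices, dim, dim).
conductivity = speeds[:, None, None] ** 2 * np.eye(dim)

# Tensor field to pass to the solver: the batched inverse M^{-1} = I / c^2.
metric = np.linalg.inv(conductivity)

# Travel time along an edge e inside each simplex: sqrt(e^T M^{-1} e) = |e| / c.
edge = np.array([3.0, 4.0])  # edge of length 5
travel_times = np.sqrt(np.einsum("i,sij,j->s", edge, metric, edge))
# travel_times is approximately [5.0, 2.5, 1.25]
```

Passing the conductivity itself instead of its inverse would yield \(c\,\|\mathbf{e}\|\), so faster regions would wrongly appear slower.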

Parameters:

Name Type Description Default
tensor_field jax.Array | npt.NDArray

Parameter field for which to solve the Eikonal equation. Provides an anisotropy tensor for each simplex of the mesh.

required

Raises:

Type Description
ValueError

Raised if the chosen loop type is invalid.

Returns:

Name Type Description
Solution Solution

Eikonax solution object.
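The steps of run (initialize the solution vector, then dispatch on the loop type) can be sketched as follows. Initializing with max_value and imposing the initial-site values follows the documented settings; the helper names, the dictionary-based dispatch, and the placeholder loop are hypothetical.

```python
import numpy as np


def initialize_solution(num_vertices, max_value, site_indices, site_values):
    """Start from max_value everywhere and impose the prescribed values at the initial sites."""
    solution = np.full(num_vertices, max_value, dtype=float)
    solution[np.asarray(site_indices)] = np.asarray(site_values, dtype=float)
    return solution


def run(loop_type, initial_guess, tensor_field, loop_functions):
    """Dispatch to the loop implementation registered under loop_type."""
    if loop_type not in loop_functions:
        raise ValueError(f"Invalid loop type: {loop_type}, options are {sorted(loop_functions)}")
    return loop_functions[loop_type](initial_guess, tensor_field)


# Usage with a hypothetical placeholder loop that returns its input unchanged.
loops = {"jitted_for": lambda guess, field: (guess, 0, 0.0)}
guess = initialize_solution(4, 1000.0, site_indices=[0, 3], site_values=[0.0, 0.5])
values, num_iterations, tolerance = run("jitted_for", guess, np.eye(2)[None], loops)
```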

_run_jitted_for_loop

_run_jitted_for_loop(initial_guess: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_simplices dim dim']) -> tuple[jtFloat[jax.Array, num_vertices], int, float]

Solver run with jitted for loop for iterations.

The method constructs a JAX for loop with a fixed number of iterations, given by max_num_iterations. In every iteration, a new solution vector is computed by the _compute_global_update method.

Parameters:

Name Type Description Default
initial_guess jax.Array

Initial solution vector

required
tensor_field jax.Array

Parameter field

required

Returns:

Type Description
tuple[jtFloat[jax.Array, num_vertices], int, float]

tuple[jax.Array, int, float]: Solution values, number of iterations, tolerance
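A minimal sketch of the jitted for-loop pattern, assuming jax.lax.fori_loop as the loop primitive and a toy 1D upwind update in place of _compute_global_update. The returned tolerance is the supremum-norm difference of the last two iterates, carried through the loop alongside the state; the names H, SOURCE, and run_jitted_for_loop are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

H = 1.0  # toy grid spacing (hypothetical)
SOURCE = 0  # index of the single initial site (hypothetical)


def global_update(u: jax.Array) -> jax.Array:
    """Toy 1D upwind update standing in for _compute_global_update."""
    padded = jnp.pad(u, 1, constant_values=jnp.inf)
    neighbor_min = jnp.minimum(padded[:-2], padded[2:])
    return jnp.minimum(u, neighbor_min + H).at[SOURCE].set(0.0)


@partial(jax.jit, static_argnames="num_iterations")
def _for_loop(initial_guess: jax.Array, num_iterations: int):
    def body(_, carry):
        u = carry[0]
        u_new = global_update(u)
        # Carry the sup-norm difference of the last two iterates next to the state.
        return u_new, jnp.max(jnp.abs(u_new - u))

    init = (initial_guess, jnp.asarray(jnp.inf, dtype=initial_guess.dtype))
    return jax.lax.fori_loop(0, num_iterations, body, init)


def run_jitted_for_loop(initial_guess: jax.Array, num_iterations: int):
    values, tolerance = _for_loop(initial_guess, num_iterations)
    return values, num_iterations, tolerance


initial_guess = jnp.full(5, 1000.0).at[SOURCE].set(0.0)
values, num_iterations, tolerance = run_jitted_for_loop(initial_guess, 4)
```

Because the iteration count is fixed, the loop body never evaluates a convergence check; the returned tolerance only reports how far the last iterate was from the previous one.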

_run_jitted_while_loop

_run_jitted_while_loop(initial_guess: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_simplices dim dim']) -> tuple[jtFloat[jax.Array, num_vertices], int, float]

Solver run with jitted while loop for iterations.

The iteration is tolerance-based: it terminates once the supremum-norm difference between two consecutive iterates falls below a user-defined tolerance, or after max_num_iterations iterations at the latest. In every iteration, a new solution vector is computed by the _compute_global_update method.

Parameters:

Name Type Description Default
initial_guess jax.Array

Initial solution vector

required
tensor_field jax.Array

Parameter field

required

Raises:

Type Description
ValueError

Raised if no tolerance has been provided in the solver settings.

Returns:

Type Description
tuple[jtFloat[jax.Array, num_vertices], int, float]

tuple[jax.Array, int, float]: Solution values, number of iterations, tolerance
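The tolerance-based variant maps naturally onto jax.lax.while_loop, whose condition is re-evaluated inside the compiled loop in every iteration. The sketch below uses the same hypothetical 1D toy update and additionally caps the loop at max_num_iterations; it illustrates the pattern rather than reproducing the eikonax implementation.

```python
import jax
import jax.numpy as jnp

H = 1.0  # toy grid spacing (hypothetical)
SOURCE = 0  # index of the single initial site (hypothetical)


def global_update(u: jax.Array) -> jax.Array:
    """Toy 1D upwind update standing in for _compute_global_update."""
    padded = jnp.pad(u, 1, constant_values=jnp.inf)
    neighbor_min = jnp.minimum(padded[:-2], padded[2:])
    return jnp.minimum(u, neighbor_min + H).at[SOURCE].set(0.0)


@jax.jit
def run_jitted_while_loop(initial_guess, tolerance, max_num_iterations):
    def cond(carry):
        _, iteration, diff = carry
        # Continue while the tolerance is not undercut and iterations remain.
        return (diff >= tolerance) & (iteration < max_num_iterations)

    def body(carry):
        u, iteration, _ = carry
        u_new = global_update(u)
        return u_new, iteration + 1, jnp.max(jnp.abs(u_new - u))

    init = (
        initial_guess,
        jnp.asarray(0, dtype=jnp.int32),
        jnp.asarray(jnp.inf, dtype=initial_guess.dtype),
    )
    # Returns (solution values, number of iterations, last sup-norm difference).
    return jax.lax.while_loop(cond, body, init)


initial_guess = jnp.full(5, 1000.0).at[SOURCE].set(0.0)
values, num_iterations, tolerance = run_jitted_while_loop(initial_guess, 1e-6, 100)
```

The carry types must stay fixed across iterations, which is why the initial difference and counter are created with explicit dtypes.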

_run_nonjitted_while_loop

_run_nonjitted_while_loop(initial_guess: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_simplices dim dim']) -> tuple[jtFloat[jax.Array, num_vertices], int, jtFloat[jax.Array, ...]]

Solver run with standard Python while loop for iterations.

While less performant, the Python while loop allows logging information between iterations. The iteration is tolerance-based: it terminates once the supremum-norm difference between two consecutive iterates falls below a user-defined tolerance, or after max_num_iterations iterations at the latest. In every iteration, a new solution vector is computed by the _compute_global_update method.

Parameters:

Name Type Description Default
initial_guess jax.Array

Initial solution vector

required
tensor_field jax.Array

Parameter field

required

Raises:

Type Description
ValueError

Raised if no tolerance has been provided in the solver settings.

ValueError

Raised if no log interval has been provided in the solver settings.

ValueError

Raised if no logger object has been provided to the constructor.

Returns:

Type Description
tuple[jtFloat[jax.Array, num_vertices], int, jtFloat[jax.Array, ...]]

tuple[jax.Array, int, jax.Array]: Solution values, number of iterations, tolerance vector over all iterations
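Because the non-jitted variant is an ordinary Python loop, it can log between iterations and record the full tolerance history. Below is a NumPy sketch under the same toy 1D assumptions, including the three documented ValueError checks; passing tolerance, log interval, and logger as explicit arguments is a hypothetical simplification.

```python
import logging

import numpy as np

H = 1.0  # toy grid spacing (hypothetical)
SOURCE = 0  # index of the single initial site (hypothetical)


def global_update(u: np.ndarray) -> np.ndarray:
    """Toy 1D upwind update standing in for _compute_global_update."""
    padded = np.pad(u, 1, constant_values=np.inf)
    neighbor_min = np.minimum(padded[:-2], padded[2:])
    u_new = np.minimum(u, neighbor_min + H)
    u_new[SOURCE] = 0.0
    return u_new


def run_nonjitted_while_loop(initial_guess, tolerance, max_num_iterations, log_interval, logger):
    if tolerance is None:
        raise ValueError("A tolerance is required for while loop types")
    if log_interval is None:
        raise ValueError("A log interval is required for the non-jitted while loop")
    if logger is None:
        raise ValueError("A logger is required for the non-jitted while loop")
    u = np.asarray(initial_guess, dtype=float)
    tolerance_history = []
    diff = np.inf
    iteration = 0
    while diff >= tolerance and iteration < max_num_iterations:
        u_new = global_update(u)
        diff = float(np.max(np.abs(u_new - u)))  # sup-norm difference of consecutive iterates
        tolerance_history.append(diff)
        u = u_new
        iteration += 1
        if iteration % log_interval == 0:
            logger.info("Iteration %d: sup-norm difference %.3e", iteration, diff)
    return u, iteration, np.array(tolerance_history)


logger = logging.getLogger("toy_solver")
initial_guess = np.full(5, 1000.0)
initial_guess[SOURCE] = 0.0
values, num_iterations, history = run_nonjitted_while_loop(initial_guess, 1e-6, 100, 2, logger)
```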

_compute_global_update

_compute_global_update(solution_vector: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_simplices dim dim']) -> jtFloat[jax.Array, num_vertices]

Given a current state and tensor field, compute a new solution vector.

This method is essentially a vectorized call to the _compute_vertex_update method, evaluated over all vertices of the mesh.

Parameters:

Name Type Description Default
solution_vector jax.Array

Current state

required
tensor_field jax.Array

Parameter field

required

Returns:

Type Description
jtFloat[jax.Array, num_vertices]

jax.Array: New iterate
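One way to express "a vectorized call over all vertices" in JAX is jax.vmap, mapping a per-vertex function over vertex indices and padded neighbor rows while sharing the current state. The sketch below uses a toy 1D chain with a hypothetical neighbor table instead of eikonax's triangle adjacency data; whether eikonax uses vmap in exactly this way is an assumption.

```python
import jax
import jax.numpy as jnp

H = 1.0  # toy grid spacing (hypothetical)
SOURCE = 0  # index of the single initial site (hypothetical)
# Hypothetical neighbor table for a chain of 5 vertices; -1 marks a missing neighbor.
NEIGHBORS = jnp.array([[-1, 1], [0, 2], [1, 3], [2, 4], [3, -1]])


def vertex_update(u, vertex, neighbors):
    """Update a single vertex from its padded neighbor list."""
    # Masked entries (-1) are replaced by inf so they never win the minimum.
    candidates = jnp.where(neighbors >= 0, u[neighbors] + H, jnp.inf)
    new_value = jnp.minimum(u[vertex], jnp.min(candidates))
    return jnp.where(vertex == SOURCE, 0.0, new_value)


@jax.jit
def global_update(u):
    vertices = jnp.arange(u.shape[0])
    # Share u across all calls (None), map over vertex indices and neighbor rows (0).
    return jax.vmap(vertex_update, in_axes=(None, 0, 0))(u, vertices, NEIGHBORS)


u0 = jnp.full(5, 1000.0).at[SOURCE].set(0.0)
u1 = global_update(u0)
```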

_compute_vertex_update

_compute_vertex_update(old_solution_vector: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_simplices dim dim'], adjacency_data: jtInt[jax.Array, 'max_num_adjacent_simplices 4']) -> jtFloat[jax.Array, '']

Compute the update value for a single vertex.

This method implements the main logic of the solver routine, based on functions in the corefunctions module.

Parameters:

Name Type Description Default
old_solution_vector jax.Array

Current state

required
tensor_field jax.Array

Parameter field

required
adjacency_data jax.Array

Information on all triangles adjacent to the current vertex and their respective vertices.

required

Returns:

Type Description
jtFloat[jax.Array, '']

jax.Array: Optimal update value for the current vertex
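As a sketch of the per-triangle logic, assume the standard Godunov-type local solver for simplicial meshes: the candidate value at vertex \(\mathbf{x}_C\) from a triangle with opposite vertices \(\mathbf{x}_A, \mathbf{x}_B\) minimizes \(\lambda u_A + (1-\lambda) u_B + \sqrt{\mathbf{e}^T\mathbf{M}^{-1}\mathbf{e}}\) over \(\lambda\in[0,1]\), with \(\mathbf{e} = \mathbf{x}_C - \lambda\mathbf{x}_A - (1-\lambda)\mathbf{x}_B\) and \(\mathbf{M}^{-1}\) the metric tensor of that triangle. Since the objective is convex in \(\lambda\), a ternary search finds the minimizer. This is not the corefunctions implementation, which may use a closed-form minimizer and the soft minmax handling of \(\lambda\) controlled by use_soft_update; the function names are hypothetical.

```python
import numpy as np


def triangle_update(x_a, x_b, x_c, u_a, u_b, metric, num_steps=100):
    """Candidate value at x_c from the opposite edge (x_a, x_b) of one triangle.

    Minimizes lam * u_a + (1 - lam) * u_b + sqrt(e^T metric e) over lam in [0, 1],
    where e = x_c - (lam * x_a + (1 - lam) * x_b). The objective is convex in lam,
    so a ternary search on [0, 1] converges to the minimizer.
    """
    x_a, x_b, x_c = (np.asarray(x, dtype=float) for x in (x_a, x_b, x_c))

    def cost(lam):
        edge = x_c - (lam * x_a + (1.0 - lam) * x_b)
        return lam * u_a + (1.0 - lam) * u_b + np.sqrt(edge @ metric @ edge)

    lower, upper = 0.0, 1.0
    for _ in range(num_steps):
        left = lower + (upper - lower) / 3.0
        right = upper - (upper - lower) / 3.0
        if cost(left) <= cost(right):
            upper = right  # for a convex cost, the minimizer lies in [lower, right]
        else:
            lower = left  # for a convex cost, the minimizer lies in [left, upper]
    return float(cost(0.5 * (lower + upper)))


def vertex_update(old_value, candidates):
    """Hypothetical vertex rule: smallest of the old value and all triangle candidates."""
    return min([old_value, *candidates])


identity = np.eye(2)
# Both edge vertices at time 0; the shortest path hits the edge midpoint (lam = 0.5).
symmetric = triangle_update([0.0, 0.0], [1.0, 0.0], [0.5, 1.0], 0.0, 0.0, identity)
# The later vertex x_b is ignored; the optimum sits at lam = 1, i.e. at x_a.
one_sided = triangle_update([0.0, 0.0], [1.0, 0.0], [0.0, 1.0], 0.0, 1.0, identity)
```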