Parametric Derivatives

derivator

Components for computing derivatives of the Eikonax solver.

This module contains two main components. First, the PartialDerivator evaluates the partial derivatives of the global Eikonax update operator \(\mathbf{G}\) w.r.t. the parameter tensor field \(\mathbf{M}\) and the corresponding solution vector \(\mathbf{u}\) obtained from a forward solve. Second, the DerivativeSolver makes use of the fixed point/adjoint property of the Eikonax solver to evaluate total parametric derivatives.

Classes:

Name Description
PartialDerivatorData

Settings for initialization of partial derivator

PartialDerivator

Component for computing partial derivatives of the Godunov Update operator

DerivativeSolver

Main component for obtaining gradients from partial derivatives

eikonax.derivator.PartialDerivatorData dataclass

Settings for initialization of partial derivator.

See the Forward Solver documentation for more detailed explanations.

Attributes:

Name Type Description
softminmax_order int

Order of the soft minmax function for differentiable transformation of the update parameters

softminmax_cutoff Real

Cut-off for the soft minmax transformation, beyond which zero sensitivity is assumed.

eikonax.derivator.PartialDerivator

Bases: eqx.Module

Component for computing partial derivatives of the Godunov Update operator.

Given a tensor field \(\mathbf{M}\) and a solution vector \(\mathbf{u}\), the partial derivator computes the partial derivatives of the global Eikonax update operator with respect to the solution vector, \(\mathbf{G}_u(\mathbf{u}, \mathbf{M})\), and the tensor field, \(\mathbf{G}_M(\mathbf{u}, \mathbf{M})\). All derivatives are computed on the vertex level, exploiting the locality of interactions in the update operator (only adjacent simplices are considered). Therefore, we can assemble the complete derivative operators as sparse data structures, not just Jacobian-vector or vector-Jacobian products, within a single pass over the computational mesh. Atomic functions on the vertex level are differentiated with JAX.

Info

For the computation of the derivatives, so-called 'self-updates' are disabled. These updates occur when a vertex does not receive a lower update value from any direction than the value it currently has. At the correct solution point, this case cannot occur due to the causality of the update stencil.

Methods:

Name Description
compute_partial_derivatives

Compute the partial derivatives of the Godunov update operator with respect to the solution vector and the parameter tensor field, given a state for both variables

__init__

__init__(mesh_data: corefunctions.MeshData, derivator_data: PartialDerivatorData, initial_sites: corefunctions.InitialSites) -> None

Constructor for the partial derivator object.

Parameters:

Name Type Description Default
mesh_data corefunctions.MeshData

Mesh data object also utilized for the Eikonax solver, contains adjacency data for every vertex.

required
derivator_data PartialDerivatorData

Settings for initialization of the derivator.

required
initial_sites corefunctions.InitialSites

Locations and values at source points

required

compute_partial_derivatives

compute_partial_derivatives(solution_vector: jtFloat[jax.Array | npt.NDArray, num_vertices], tensor_field: jtFloat[jax.Array | npt.NDArray, 'num_simplices dim dim']) -> tuple[tuple[jtInt[jax.Array, num_sol_values], jtInt[jax.Array, num_sol_values], jtFloat[jax.Array, num_sol_values]], tuple[jtInt[jax.Array, num_param_values], jtInt[jax.Array, num_param_values], jtFloat[jax.Array, 'num_param_values dim dim']]]

Compute the partial derivatives of the Godunov update operator.

This method provides the main interface for computing the partial derivatives of the global Eikonax update operator with respect to the solution vector and the parameter tensor field. The updates are computed locally for each vertex, such that the resulting data structures are sparse. Subsequently, further zero entries are removed to reduce the memory footprint. The derivatives computed in this component can be utilized to compute the total parametric derivative via a fixed point equation, given that the provided solution vector is that fixed point. The computation of partial derivatives is possible with a single pass over the mesh, since the solution of the eikonal equation, and therefore causality within the Godunov update scheme, is known.

Note

The derivator expects the metric tensor field as used in the inner product for the update stencil of the eikonal equation. This is the INVERSE of the conductivity tensor, which is the actual tensor field in the eikonal equation. The Tensorfield component provides the inverse tensor field.

Parameters:

Name Type Description Default
solution_vector jax.Array

Current solution

required
tensor_field jax.Array

Parameter field

required

Returns:

Type Description
tuple[tuple[jtInt[jax.Array, num_sol_values], jtInt[jax.Array, num_sol_values], jtFloat[jax.Array, num_sol_values]], tuple[jtInt[jax.Array, num_param_values], jtInt[jax.Array, num_param_values], jtFloat[jax.Array, 'num_param_values dim dim']]]

tuple[tuple[jax.Array, jax.Array, jax.Array], tuple[jax.Array, jax.Array, jax.Array]]: Partial derivatives with respect to the solution vector and the parameter tensor field. Both quantities are given as arrays over all local contributions, making them sparse in the global context.

_compress_partial_derivative_solution

_compress_partial_derivative_solution(partial_derivative_solution: jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices 2']) -> tuple[jtInt[jax.Array, num_sol_values], jtInt[jax.Array, num_sol_values], jtFloat[jax.Array, num_sol_values]]

Compress the partial derivative data with respect to the solution vector.

Compression consists of two steps:

  1. Remove zero entries in the sensitivity vector
  2. Set the sensitivity vector to zero at the initial sites, but keep them for later computations.
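The two steps can be sketched in plain Python as follows. This is a minimal illustration under stated assumptions, not the Eikonax implementation (which operates on JAX arrays): the names `compress_solution_derivative`, `raw`, `neighbors` and `initial_sites` are hypothetical, and the layout (two sensitivities per vertex and adjacent simplex, with the indices of the two opposite vertices stored alongside) is assumed for the example.

```python
def compress_solution_derivative(raw, neighbors, initial_sites):
    """Convert dense per-vertex sensitivities into (rows, cols, values) triples.

    raw[i][s][k]       : sensitivity of vertex i w.r.t. opposite vertex k of simplex s
    neighbors[i][s][k] : global index of that opposite vertex (assumed layout)
    initial_sites      : set of vertex indices that act as sources
    """
    rows, cols, values = [], [], []
    for i, per_simplex in enumerate(raw):
        for s, pair in enumerate(per_simplex):
            for k, value in enumerate(pair):
                if value == 0.0:
                    continue  # step 1: drop zero entries
                if i in initial_sites:
                    value = 0.0  # step 2: zero the value, but keep the entry
                rows.append(i)
                cols.append(neighbors[i][s][k])
                values.append(value)
    return rows, cols, values


# Three vertices, at most two adjacent simplices each; -1 marks padding.
raw = [
    [[0.5, 0.5], [0.0, 0.0]],
    [[1.0, 0.0], [0.0, 0.0]],
    [[0.3, 0.7], [0.0, 0.0]],
]
neighbors = [
    [[1, 2], [-1, -1]],
    [[0, 2], [-1, -1]],
    [[0, 1], [-1, -1]],
]
rows, cols, values = compress_solution_derivative(raw, neighbors, initial_sites={0})
```

The resulting triples have the form expected by a coordinate-format sparse matrix constructor, which sums duplicate entries during assembly.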

Parameters:

Name Type Description Default
partial_derivative_solution jax.Array

Raw data from partial derivative computation, with shape (N, num_adjacent_simplices, 2), where N depends on the number of identical update paths for the vertices in the mesh.

required

Returns:

Type Description
tuple[jtInt[jax.Array, num_sol_values], jtInt[jax.Array, num_sol_values], jtFloat[jax.Array, num_sol_values]]

tuple[jax.Array, jax.Array, jax.Array]: Compressed data, represented as rows, columns and values for initialization in sparse matrix. Shape depends on number of non-zero entries

_compress_partial_derivative_parameter

_compress_partial_derivative_parameter(partial_derivative_parameter: jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices dim dim']) -> tuple[jtInt[jax.Array, num_param_values], jtInt[jax.Array, num_param_values], jtFloat[jax.Array, 'num_param_values dim dim']]

Compress the partial derivative data with respect to the parameter field.

Compression consists of two steps:

  1. Remove tensor components from the sensitivity data, if all entries are zero
  2. Set the sensitivity vector to zero at the initial sites, but keep them for later computations.

Parameters:

Name Type Description Default
partial_derivative_parameter jax.Array

Raw data from partial derivative computation, with shape (N, num_adjacent_simplices, dim, dim), where N depends on the number of identical update paths for the vertices in the mesh.

required

Returns:

Type Description
tuple[jtInt[jax.Array, num_param_values], jtInt[jax.Array, num_param_values], jtFloat[jax.Array, 'num_param_values dim dim']]

tuple[jax.Array, jax.Array, jax.Array]: Compressed data, represented as rows, columns and values to be further processed for sparse matrix assembly. Shape depends on number of non-zero entries

_compute_global_partial_derivatives

_compute_global_partial_derivatives(solution_vector: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_simplices dim dim']) -> tuple[jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices 2'], jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices dim dim']]

Compute partial derivatives of the global update operator.

The method is a jitted and vectorized call to the _compute_vertex_partial_derivatives method.

Parameters:

Name Type Description Default
solution_vector jax.Array

Global solution vector

required
tensor_field jax.Array

Global parameter tensor field

required

Returns:

Type Description
tuple[jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices 2'], jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices dim dim']]

tuple[jax.Array, jax.Array]: Raw data for partial derivatives, with shapes (N, num_adjacent_simplices, 2) and (N, num_adjacent_simplices, dim, dim), where N depends on the number of identical update paths for the vertices in the mesh.

_compute_vertex_partial_derivatives

_compute_vertex_partial_derivatives(solution_vector: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_simplices dim dim'], adjacency_data: jtInt[jax.Array, 'max_num_adjacent_simplices 4']) -> tuple[jtFloat[jax.Array, 'max_num_adjacent_simplices 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices dim dim']]

Compute partial derivatives for the update of a single vertex.

The method computes candidates for all respective subterms through calls to further methods. These candidates are filtered for feasibility by means of JAX filters. The softmin function (and its gradient) is applied to the directions of all optimal updates to ensure differentiability; other contributions are discarded. Lastly, the evaluated contributions are combined according to the form of the "total differential" for the partial derivatives.
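To illustrate why a softmin keeps the result differentiable when several directions yield the same optimal update, here is a minimal sketch of a generic exponential softmin weighting. It is an assumption for illustration only: the function name `softmin_weights` and the parameter `order` are hypothetical, and the softmin actually used in Eikonax may be defined differently.

```python
import math


def softmin_weights(candidates, order):
    """Return normalized weights that concentrate on the smallest candidates.

    Infinite candidates (masked, infeasible updates) receive zero weight.
    """
    minimum = min(c for c in candidates if math.isfinite(c))
    raw = [
        math.exp(-order * (c - minimum)) if math.isfinite(c) else 0.0
        for c in candidates
    ]
    total = sum(raw)
    return [r / total for r in raw]


# Two equally optimal directions, one masked candidate, one clearly worse one.
weights = softmin_weights([1.0, 1.0, math.inf, 3.0], order=20)
```

With a hard minimum, the derivative would jump between the two tied directions. The weights above split the contribution evenly between them, and the share of the worse candidate vanishes as `order` grows.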

Parameters:

Name Type Description Default
solution_vector jax.Array

Global solution vector

required
tensor_field jax.Array

Global parameter tensor field

required
adjacency_data jax.Array

Adjacency data for the vertex under consideration

required

Returns:

Type Description
tuple[jtFloat[jax.Array, 'max_num_adjacent_simplices 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices dim dim']]

tuple[jax.Array, jax.Array]: Partial derivatives for the given vertex

_compute_vertex_partial_derivative_candidates

_compute_vertex_partial_derivative_candidates(solution_vector: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_simplices dim dim'], adjacency_data: jtInt[jax.Array, 'max_num_adjacent_simplices 4']) -> tuple[jtFloat[jax.Array, 'max_num_adjacent_simplices 4 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 dim dim']]

Compute partial derivatives corresponding to potential update candidates for a vertex.

Update candidates and corresponding derivatives are computed for all adjacent simplices, and for all possible update parameters per simplex.

Parameters:

Name Type Description Default
solution_vector jax.Array

Global solution vector

required
tensor_field jax.Array

Global parameter field

required
adjacency_data jax.Array

Adjacency data for the given vertex

required

Returns:

Type Description
tuple[jtFloat[jax.Array, 'max_num_adjacent_simplices 4 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 dim dim']]

tuple[jax.Array, jax.Array]: Candidates for partial derivatives

_compute_partial_derivative_candidates_from_adjacent_simplex

_compute_partial_derivative_candidates_from_adjacent_simplex(solution_vector: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_simplices dim dim'], adjacency_data: jtInt[jax.Array, 4]) -> tuple[jtFloat[jax.Array, '4 2'], jtFloat[jax.Array, '4 dim dim']]

Compute partial derivatives for all update candidates within an adjacent simplex.

The update candidates are evaluated according to the different candidates for the optimization parameters \(\lambda\). Contributions are combined to the form of the involved total differentials.

Parameters:

Name Type Description Default
solution_vector jax.Array

Global solution vector

required
tensor_field jax.Array

Global parameter field

required
adjacency_data jax.Array

Adjacency data for the given vertex and simplex

required

Returns:

Type Description
tuple[jtFloat[jax.Array, '4 2'], jtFloat[jax.Array, '4 dim dim']]

tuple[jax.Array, jax.Array]: Derivative candidate from the given simplex

_filter_candidates staticmethod

_filter_candidates(vertex_update_candidates: jtFloat[jax.Array, 'max_num_adjacent_simplices 4'], grad_update_solution_candidates: jtFloat[jax.Array, 'max_num_adjacent_simplices 4 2'], grad_update_parameter_candidates: jtFloat[jax.Array, 'max_num_adjacent_simplices 4 dim dim']) -> tuple[jtFloat[jax.Array, ''], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 dim dim']]

Mask irrelevant derivative candidates so that they are discarded later.

Values are masked by setting them to zero or infinity, depending on the routine in which they are utilized later. Partial derivatives are only relevant if the corresponding update corresponds to an optimal path.
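The masking idea can be sketched in plain Python. This is a simplified illustration, not the Eikonax routine: the function name `filter_candidates`, the tolerance `tol`, and the use of scalar derivative candidates (instead of arrays with trailing dimensions) are assumptions made for brevity.

```python
import math


def filter_candidates(update_candidates, grad_candidates, tol=1e-12):
    """Keep derivatives only where the update candidate attains the optimum.

    update_candidates[s][c] : update value from simplex s, candidate c
                              (math.inf marks infeasible candidates)
    grad_candidates[s][c]   : matching derivative candidate (scalar here)
    """
    optimum = min(min(row) for row in update_candidates)
    masked_grads = []
    for update_row, grad_row in zip(update_candidates, grad_candidates):
        masked_grads.append(
            [
                grad if abs(update - optimum) <= tol else 0.0  # zero out non-optimal paths
                for update, grad in zip(update_row, grad_row)
            ]
        )
    return optimum, masked_grads


updates = [[2.0, math.inf], [1.5, 1.5]]
grads = [[0.3, 0.9], [0.4, 0.6]]
optimum, masked = filter_candidates(updates, grads)
```

Only the two candidates of the second simplex attain the optimal value, so only their derivatives survive the mask.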

Parameters:

Name Type Description Default
vertex_update_candidates jax.Array

Update candidates for a given vertex

required
grad_update_solution_candidates jax.Array

Partial derivative candidates w.r.t. the solution vector

required
grad_update_parameter_candidates jax.Array

Partial derivative candidates w.r.t. the parameter field

required

Returns:

Type Description
tuple[jtFloat[jax.Array, ''], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 dim dim']]

tuple[jax.Array, jax.Array, jax.Array]: Optimal update value, masked partial derivatives

_compute_lambda_grad

_compute_lambda_grad(solution_values: jtFloat[jax.Array, 2], parameter_tensor: jtFloat[jax.Array, 'dim dim'], edges: tuple[jtFloat[jax.Array, dim], jtFloat[jax.Array, dim], jtFloat[jax.Array, dim]]) -> tuple[jtFloat[jax.Array, '4 2'], jtFloat[jax.Array, '4 dim dim']]

Compute the partial derivatives of update parameters for a single vertex.

This method evaluates the partial derivatives of the update parameters with respect to the current solution vector and the given parameter field, for a single triangle.

Parameters:

Name Type Description Default
solution_values jax.Array

Current solution values at the opposite vertices of the considered triangle

required
parameter_tensor jax.Array

Parameter tensor for the given triangle

required
edges tuple[jax.Array, jax.Array, jax.Array]

Edges of the considered triangle

required

Returns:

Type Description
tuple[jtFloat[jax.Array, '4 2'], jtFloat[jax.Array, '4 dim dim']]

tuple[jax.Array, jax.Array]: Jacobians of the update parameters w.r.t. the solution vector and the parameter tensor

eikonax.derivator.DerivativeSolver

Main component for obtaining gradients from partial derivatives.

The Eikonax PartialDerivator computes partial derivatives of the global update operator with respect to the solution vector, \(\mathbf{G}_u\), and the parameter tensor field, \(\mathbf{G}_M\). Now we exploit the fact that the solution candidate \(\mathbf{u}\in\mathbb{R}^{N_V}\) obtained from a forward solve is, up to a given accuracy, a fixed point of the global update operator. We further consider the scenario of \(\mathbf{M}(\mathbf{m})\) being dependent on some parameter \(\mathbf{m}\in\mathbb{R}^M\). This means we can write \(\mathbf{u}\) as a function of \(\mathbf{m}\), obeying the relation

\[ \mathbf{u}(\mathbf{m}) = \mathbf{G}(\mathbf{u}(\mathbf{m}), \mathbf{M}(\mathbf{m})) \]

To obtain the Jacobian \(\mathbf{J} = \frac{d\mathbf{u}}{d\mathbf{m}}\in\mathbb{R}^{N_V\times M}\), we simply differentiate the fixed point relation,

\[ \mathbf{J} = \mathbf{G}_u\mathbf{J} + \overbrace{\mathbf{G}_M\frac{d\mathbf{M}(\mathbf{m})}{d\mathbf{m}}}^{\mathbf{G}_m} \quad\Leftrightarrow\quad \mathbf{J} = (\mathbf{I} - \mathbf{G}_u)^{-1}\mathbf{G}_m \]

\(\mathbf{G}_u\) and \(\mathbf{G}_M\) are provided by the PartialDerivator, whereas \(\frac{d\mathbf{M}}{d\mathbf{m}}\) is computed as the Jacobian of the TensorField component.
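As a hedged illustration, consider a toy one-dimensional chain (an assumption for this example, not an Eikonax mesh): three vertices with spacing \(h\), a source at the first vertex, a single scalar parameter \(m\), and the update rule \(u_i = u_{i-1} + h\sqrt{m}\) for \(i = 2, 3\). The partial derivatives and the resulting Jacobian are then

\[ \mathbf{G}_u = \begin{pmatrix} 0 & 0 & 0 \\ 1 & 0 & 0 \\ 0 & 1 & 0 \end{pmatrix}, \quad \mathbf{G}_m = \frac{h}{2\sqrt{m}}\begin{pmatrix} 0 \\ 1 \\ 1 \end{pmatrix}, \quad \mathbf{J} = (\mathbf{I} - \mathbf{G}_u)^{-1}\mathbf{G}_m = \frac{h}{2\sqrt{m}}\begin{pmatrix} 1 & 0 & 0 \\ 1 & 1 & 0 \\ 1 & 1 & 1 \end{pmatrix}\begin{pmatrix} 0 \\ 1 \\ 1 \end{pmatrix} = \frac{h}{2\sqrt{m}}\begin{pmatrix} 0 \\ 1 \\ 2 \end{pmatrix} \]

This agrees with differentiating the closed form \(u_3 = u_1 + 2h\sqrt{m}\) directly, which gives \(\frac{du_3}{dm} = \frac{h}{\sqrt{m}}\).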

We are typically not interested in the full Jacobian, but rather in the gradient of some cost functional \(l:\mathbb{R}^{N_V}\to\mathbb{R},\ l=l(\mathbf{u}(\mathbf{m}))\) with respect to \(\mathbf{m}\). The gradient is given as

\[ \mathbf{g}(\mathbf{m}) = \frac{d l}{d\mathbf{m}} = \overbrace{\frac{d l}{d\mathbf{u}}}^{l_u^T}\mathbf{J} = \overbrace{l_u^T(\mathbf{I} - \mathbf{G}_u)^{-1}}^{\mathbf{v}^T}\mathbf{G}_m. \]

We can identify \(\mathbf{v}\) as the adjoint variable, which is obtained by solving the linear discrete adjoint equation,

\[ (\mathbf{I} - \mathbf{G}_u)^T\mathbf{v} = l_u. \]

Now comes the catch: Through the strict causality of the Godunov update operator, we can find a unique and consistent ordering of vertex indices, such that the solution at a vertex \(i\) is only informed by the solution at a vertex \(j\) if \(j\) occurs before \(i\) in that ordering. The sparsity pattern of \(\mathbf{G}_u\) therefore forms a directed acyclic graph. This means that there is a permutation matrix \(\mathbf{P}\) such that for \(\bar{\mathbf{G}}_u = \mathbf{P}\mathbf{G}_u\mathbf{P}^T\) an entry \((\bar{\mathbf{G}}_u)_{ij}\) is only non-zero if \(i > j\). In total, we can write

\[ \begin{align*} (\mathbf{I}-\mathbf{G}_u)^T\mathbf{v} = \mathbf{l}_u &\Leftrightarrow \mathbf{P}(\mathbf{I}-\mathbf{G}_u)^T\mathbf{P}^T \overbrace{\mathbf{P}\mathbf{v}}^{\bar{\mathbf{v}}} = \overbrace{\mathbf{P}\mathbf{l}_u}^{\bar{\mathbf{l}}_u} \\ & \Leftrightarrow \overbrace{(\mathbf{I} - \bar{\mathbf{G}}_u)^T}^{\bar{\mathbf{A}}}\bar{\mathbf{v}} =\bar{\mathbf{l}}_u \end{align*} \]

where \(\bar{\mathbf{A}}\) is an upper triangular matrix with unit diagonal. Hence, it is invertible through simple back-substitution.

The DerivativeSolver component does exactly this: It sets up the matrices \(\mathbf{P}\) and \(\bar{\mathbf{A}}\), permutes inputs/outputs, and solves the sparse linear system through back-substitution.
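The role of the ordering can be sketched in plain Python. The sketch assumes that the causal ordering is obtained by sorting vertices by ascending solution value, which is one natural choice but is not confirmed here as the exact Eikonax construction. The helper names `causal_order` and `permute_entries` are hypothetical.

```python
def causal_order(solution):
    """order[k] is the original index of the vertex with the k-th smallest value."""
    return sorted(range(len(solution)), key=lambda i: solution[i])


def permute_entries(rows, cols, values, order):
    """Relabel sparse entries of G_u according to the causal ordering."""
    position = {vertex: k for k, vertex in enumerate(order)}
    return [(position[r], position[c], v) for r, c, v in zip(rows, cols, values)]


# Four vertices; vertex 1 is the source (smallest value).
solution = [0.8, 0.0, 0.5, 1.2]
# Sparse G_u: entry (r, c) means vertex r is updated from vertex c.
rows = [2, 0, 0, 3]
cols = [1, 2, 1, 0]
values = [1.0, 0.6, 0.4, 1.0]

order = causal_order(solution)
permuted = permute_entries(rows, cols, values, order)
# Every permuted entry lies strictly below the diagonal.
is_strictly_lower = all(i > j for i, j, _ in permuted)
```

Because each vertex is only informed by vertices with smaller solution values, every relabeled entry ends up below the diagonal, which is what makes the transposed system upper triangular.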

Speedy gradients

Given a solution vector \(\mathbf{u}\), Eikonax computes derivatives with linear complexity. Moreover, for a given evaluation point, we can evaluate an arbitrary number of gradients through simple back-substitution. All matrices need to be assembled only once.

Change in tooling

In the DerivativeSolver, we leave JAX and fall back to the NumPy/SciPy stack. While the sequential solver operation should not be much slower on the CPU, we have to transfer the data back from the offloading device. We plan to implement a GPU-compatible solver with CuPy in a future version, or in JAX as soon as it offers the necessary linear algebra tools.

Methods:

Name Description
solve

Solve the linear system for the adjoint variable

sparse_system_matrix property

sparse_system_matrix: sp.sparse.csc_matrix

Get system matrix \(\bar{\mathbf{A}}\in\mathbb{R}^{{N_V}\times {N_V}}\).

sparse_permutation_matrix property

sparse_permutation_matrix: sp.sparse.csc_matrix

Get permutation matrix \(\mathbf{P}\in\mathbb{R}^{{N_V}\times {N_V}}\).

__init__

__init__(solution: jtFloat[jax.Array | npt.NDArray, num_vertices], sparse_partial_update_solution: tuple[jtInt[jax.Array, num_sol_values], jtInt[jax.Array, num_sol_values], jtFloat[jax.Array, num_sol_values]]) -> None

Constructor for the derivative solver.

Initializes the causality-inspired permutation matrix \(\mathbf{P}\), and afterwards the permuted system matrix \(\bar{\mathbf{A}}\), which is triangular.

Parameters:

Name Type Description Default
solution jax.Array | npt.NDArray

Obtained solution of the Eikonal equation

required
sparse_partial_update_solution tuple[jax.Array, jax.Array, jax.Array]

Sparse representation of the partial derivative G_u, containing row inds, column inds and values. These structures might contain redundancies, which are automatically removed through summation in the sparse matrix assembly later.

required

solve

solve(right_hand_side: jtFloat[jax.Array | npt.NDArray, num_vertices]) -> jtFloat[npt.NDArray, num_parameters]

Solve the linear system for the parametric gradient.

Following the notation from the class docstring, this method solves the linear system for the adjoint variable \(\mathbf{v}\). Given a right-hand-side \(\mathbf{l}_u\), this is a three-step process:

  1. Permute the right hand side \(\bar{l}_u = \mathbf{P}l_u\)
  2. Solve the linear system \(\bar{\mathbf{A}}\bar{\mathbf{v}} = \bar{l}_u\)
  3. Permute solution back to the original ordering \(\mathbf{v} = \mathbf{P}^T\bar{\mathbf{v}}\)
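The three steps can be sketched in plain Python without any sparse matrix library. This is an illustrative stand-in, not the SciPy-based solver used by the component: the function name `solve_adjoint` and the representation of \(\mathbf{P}\) as an index list `order` (with \((\mathbf{P}\mathbf{x})_k = x_{\text{order}[k]}\)) are assumptions.

```python
def solve_adjoint(order, rows, cols, values, rhs):
    """Solve (I - G_u)^T v = rhs, given a causal ordering of the vertices."""
    n = len(order)
    position = {vertex: k for k, vertex in enumerate(order)}

    # Step 1: permute the right-hand side, l_bar = P l_u.
    rhs_bar = [rhs[vertex] for vertex in order]

    # Off-diagonal part of A_bar = (I - G_bar)^T, stored row by row.
    upper = [[] for _ in range(n)]
    for r, c, value in zip(rows, cols, values):
        i, j = position[r], position[c]  # entry G_bar[i][j] with i > j
        upper[j].append((i, -value))     # A_bar[j][i] = -G_bar[i][j]

    # Step 2: back-substitution (A_bar is upper triangular with unit diagonal).
    v_bar = [0.0] * n
    for k in reversed(range(n)):
        v_bar[k] = rhs_bar[k] - sum(a * v_bar[i] for i, a in upper[k])

    # Step 3: permute back to the original ordering, v = P^T v_bar.
    v = [0.0] * n
    for k, vertex in enumerate(order):
        v[vertex] = v_bar[k]
    return v


order = [1, 2, 0, 3]                # vertices sorted by solution value
rows, cols = [2, 0, 0, 3], [1, 2, 1, 0]
values = [1.0, 0.6, 0.4, 1.0]       # sparse G_u entries
rhs = [0.0, 0.0, 0.0, 1.0]          # l_u for the cost l(u) = u_3
v = solve_adjoint(order, rows, cols, values, rhs)
```

Each unknown is computed once from already known entries, so the cost is linear in the number of non-zero entries of \(\mathbf{G}_u\).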

Parameters:

Name Type Description Default
right_hand_side jax.Array | npt.NDArray

RHS for the linear system solve

required

Returns:

Type Description
jtFloat[npt.NDArray, num_parameters]

np.ndarray: Solution of the linear system solve, corresponding to the adjoint in an optimization context.

_assemble_permutation_matrix

_assemble_permutation_matrix(solution: jtFloat[npt.NDArray, num_vertices]) -> sp.sparse.csc_matrix

Construct permutation matrix \(\mathbf{P}\) for index ordering.

Parameters:

Name Type Description Default
solution npt.NDArray

Obtained solution of the eikonal equation

required

Returns:

Type Description
sp.sparse.csc_matrix

sp.sparse.csc_matrix: Sparse permutation matrix

_assemble_system_matrix

_assemble_system_matrix(sparse_partial_update_solution: tuple[jtInt[npt.NDArray, num_sol_values], jtInt[npt.NDArray, num_sol_values], jtFloat[npt.NDArray, num_sol_values]], num_points: int) -> sp.sparse.csc_matrix

Assemble system matrix \(\bar{\mathbf{A}}\) for gradient solver.

Before invoking this method, the permutation matrix \(\mathbf{P}\) must be initialized.

Parameters:

Name Type Description Default
sparse_partial_update_solution tuple[npt.NDArray, npt.NDArray, npt.NDArray]

Sparse representation of the partial derivative \(G_u\), containing row inds, column inds and values. These structures might contain redundancies, which are automatically removed through summation in the sparse matrix assembly.

required
num_points int

Number of mesh points

required

Returns:

Type Description
sp.sparse.csc_matrix

sp.sparse.csc_matrix: Sparse representation of the permuted system matrix

eikonax.derivator.compute_eikonax_jacobian

compute_eikonax_jacobian(derivative_solver: DerivativeSolver, partial_derivative_parameter: sp.sparse.coo_matrix) -> npt.NDArray

Compute Jacobian from concatenation of gradients, computed with unit vector RHS.

Warning

This method should only be used for small problems.

Parameters:

Name Type Description Default
derivative_solver DerivativeSolver

Initialized derivative solver object

required
partial_derivative_parameter sp.sparse.coo_matrix

Partial derivative of the global update operator with respect to the parameter tensor field

required

Returns:

Type Description
npt.NDArray

npt.NDArray: (Dense) Jacobian matrix
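The idea behind the unit vector construction can be sketched as follows: row \(i\) of \(\mathbf{J}\) equals \(\mathbf{v}_i^T\mathbf{G}_m\), where \(\mathbf{v}_i\) solves the adjoint equation with the \(i\)-th unit vector as right-hand side. The plain-Python sketch below uses dense lists and hypothetical helper names (`solve_adjoint_ordered`, `jacobian_from_unit_vectors`), and assumes the vertices are already in causal order. It does not reproduce the signature of compute_eikonax_jacobian.

```python
def solve_adjoint_ordered(G_u, rhs):
    """Solve (I - G_u)^T v = rhs for a strictly lower triangular G_u."""
    n = len(rhs)
    v = [0.0] * n
    for k in reversed(range(n)):
        v[k] = rhs[k] + sum(G_u[i][k] * v[i] for i in range(k + 1, n))
    return v


def jacobian_from_unit_vectors(G_u, G_m):
    """Assemble the dense Jacobian row by row from unit vector adjoint solves."""
    n, num_params = len(G_m), len(G_m[0])
    jacobian = []
    for i in range(n):
        unit = [1.0 if k == i else 0.0 for k in range(n)]
        v = solve_adjoint_ordered(G_u, unit)
        jacobian.append(
            [sum(v[k] * G_m[k][p] for k in range(n)) for p in range(num_params)]
        )
    return jacobian


# Toy chain with three vertices and one parameter (illustrative values).
G_u = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
G_m = [[0.0], [0.5], [0.5]]
jacobian = jacobian_from_unit_vectors(G_u, G_m)
```

This requires one adjoint solve per vertex and produces a dense matrix, which is why the approach is only suitable for small problems.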

eikonax.derivator.compute_eikonax_hessian

compute_eikonax_hessian() -> None

Compute Hessian matrix.

Not implemented yet