Parametric Derivatives¶
derivator
¶
Components for computing derivatives of the Eikonax solver.
This module contains two main components. Firstly, the
PartialDerivator
evaluates the partial derivatives of the
global Eikonax update operator \(\mathbf{G}\) w.r.t. the parameter tensor field \(\mathbf{M}\) and the
corresponding solution vector \(\mathbf{u}\) obtained from a forward solve. The
DerivativeSolver
component makes use of the fixed
point/adjoint property of the Eikonax solver to evaluate total parametric derivatives.
Classes:
Name | Description |
---|---|
PartialDerivatorData | Settings for initialization of partial derivator |
PartialDerivator | Component for computing partial derivatives of the Godunov Update operator |
DerivativeSolver | Main component for obtaining gradients from partial derivatives |
eikonax.derivator.PartialDerivatorData
dataclass
¶
Settings for initialization of partial derivator.
See the Forward Solver documentation for more detailed explanations.
Attributes:
Name | Type | Description |
---|---|---|
softminmax_order | int | Order of the soft minmax function for differentiable transformation of the update parameters |
softminmax_cutoff | Real | Cut-off for the soft minmax transformation, beyond which zero sensitivity is assumed. |
eikonax.derivator.PartialDerivator
¶
Bases: eqx.Module
Component for computing partial derivatives of the Godunov Update operator.
Given a tensor field \(\mathbf{M}\) and a solution vector \(\mathbf{u}\), the partial derivator computes the partial derivatives of the global Eikonax update operator with respect to the solution vector, \(\mathbf{G}_u(\mathbf{u}, \mathbf{M})\), and the tensor field, \(\mathbf{G}_M(\mathbf{u}, \mathbf{M})\). All derivatives are computed on the vertex level, exploiting the locality of interactions in the update operator (only adjacent simplices are considered). Therefore, we can assemble the complete derivative operators as sparse data structures, not just Jacobian-vector or vector-Jacobian products, within a single pass over the computational mesh. Atomic functions on the vertex level are differentiated with JAX.
Info
For the computation of the derivatives, so-called 'self-updates' are disabled. These updates occur when a vertex does not receive a lower update value from any direction than the value it currently has. At the correct solution point, this case cannot occur due to the causality of the update stencil.
Methods:
Name | Description |
---|---|
compute_partial_derivatives | Compute the partial derivatives of the Godunov update operator with respect to the solution vector and the parameter tensor field, given a state for both variables |
__init__
¶
__init__(mesh_data: corefunctions.MeshData, derivator_data: PartialDerivatorData, initial_sites: corefunctions.InitialSites) -> None
Constructor for the partial derivator object.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh_data | corefunctions.MeshData | Mesh data object also utilized for the Eikonax solver, contains adjacency data for every vertex. | required |
derivator_data | PartialDerivatorData | Settings for initialization of the derivator. | required |
initial_sites | corefunctions.InitialSites | Locations and values at source points | required |
compute_partial_derivatives
¶
compute_partial_derivatives(solution_vector: jtFloat[jax.Array | npt.NDArray, num_vertices], tensor_field: jtFloat[jax.Array | npt.NDArray, 'num_simplices dim dim']) -> tuple[tuple[jtInt[jax.Array, num_sol_values], jtInt[jax.Array, num_sol_values], jtFloat[jax.Array, num_sol_values]], tuple[jtInt[jax.Array, num_param_values], jtInt[jax.Array, num_param_values], jtFloat[jax.Array, 'num_param_values dim dim']]]
Compute the partial derivatives of the Godunov update operator.
This method provides the main interface for computing the partial derivatives of the global Eikonax update operator with respect to the solution vector and the parameter tensor field. The updates are computed locally for each vertex, such that the resulting data structures are sparse. Subsequently, further zero entries are removed to reduce the memory footprint. The derivatives computed in this component can be used to compute the total parametric derivative via a fixed point equation, given that the provided solution vector is that fixed point. The computation of partial derivatives is possible with a single pass over the mesh, since the solution of the eikonal equation, and therefore causality within the Godunov update scheme, is known.
Note
The derivator expects the metric tensor field as used in the inner product for the
update stencil of the eikonal equation. This is the INVERSE of the conductivity
tensor, which is the actual tensor field in the eikonal equation. The
Tensorfield
component provides the inverse tensor
field.
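The relation between the two tensor fields in the note above can be sketched with plain NumPy. The conductivity values below are hypothetical, and the snippet does not use the Eikonax tensor field component; it only illustrates that the derivator input is the per-simplex matrix inverse.

```python
import numpy as np

# Hypothetical conductivity tensors: one symmetric positive definite
# (dim x dim) matrix per simplex, here for two simplices in 2D
conductivity = np.array([
    [[2.0, 0.0], [0.0, 4.0]],
    [[1.0, 0.5], [0.5, 1.0]],
])

# np.linalg.inv inverts each (dim x dim) block along the leading axis,
# giving the metric tensor field expected by the derivator
metric = np.linalg.inv(conductivity)

# Sanity check: every block multiplied by its inverse is the identity
identity = np.broadcast_to(np.eye(2), conductivity.shape)
is_inverse = np.allclose(conductivity @ metric, identity)
```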
Parameters:
Name | Type | Description | Default |
---|---|---|---|
solution_vector | jax.Array | Current solution | required |
tensor_field | jax.Array | Parameter field | required |
Returns:
Type | Description |
---|---|
tuple[tuple[jtInt[jax.Array, num_sol_values], jtInt[jax.Array, num_sol_values], jtFloat[jax.Array, num_sol_values]], tuple[jtInt[jax.Array, num_param_values], jtInt[jax.Array, num_param_values], jtFloat[jax.Array, 'num_param_values dim dim']]] | tuple[tuple[jax.Array, jax.Array, jax.Array], tuple[jax.Array, jax.Array, jax.Array]]: Partial derivatives with respect to the solution vector and the parameter tensor field. Both quantities are given as arrays over all local contributions, making them sparse in the global context. |
_compress_partial_derivative_solution
¶
_compress_partial_derivative_solution(partial_derivative_solution: jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices 2']) -> tuple[jtInt[jax.Array, num_sol_values], jtInt[jax.Array, num_sol_values], jtFloat[jax.Array, num_sol_values]]
Compress the partial derivative data with respect to the solution vector.
Compression consists of two steps:
- Remove zero entries in the sensitivity vector
- Set the sensitivity vector to zero at the initial sites, but keep them for later computations.
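The first compression step above can be sketched in NumPy as follows. This is a minimal illustration, not the Eikonax implementation: the array layout, the `neighbor_indices` lookup, and all values are hypothetical. Vertex 0 plays the role of an initial site whose sensitivities are already zero.

```python
import numpy as np

# Hypothetical raw data: 3 vertices, up to 2 adjacent simplices each,
# and 2 sensitivities per simplex (one per opposite vertex)
raw_values = np.array([
    [[0.0, 0.0], [0.0, 0.0]],  # vertex 0: initial site, no sensitivity
    [[1.0, 0.0], [0.0, 0.0]],  # vertex 1 depends on vertex 0
    [[0.3, 0.7], [0.0, 0.0]],  # vertex 2 depends on vertices 0 and 1
])
# Hypothetical global indices of the vertices each entry refers to
neighbor_indices = np.array([
    [[1, 2], [2, 1]],
    [[0, 2], [2, 0]],
    [[0, 1], [1, 0]],
])
# Row index of every entry is the vertex that receives the update
row_indices = np.broadcast_to(np.arange(3)[:, None, None], raw_values.shape)

# Keep only non-zero entries, stored as (row, column, value) triplets
mask = raw_values != 0.0
rows = row_indices[mask]
cols = neighbor_indices[mask]
values = raw_values[mask]
```

The resulting triplets have a length that depends on the number of non-zero entries, which is why the output shape cannot be known before the computation.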
Parameters:
Name | Type | Description | Default |
---|---|---|---|
partial_derivative_solution | jax.Array | Raw data from partial derivative computation, with shape | required |
Returns:
Type | Description |
---|---|
tuple[jtInt[jax.Array, num_sol_values], jtInt[jax.Array, num_sol_values], jtFloat[jax.Array, num_sol_values]] | tuple[jax.Array, jax.Array, jax.Array]: Compressed data, represented as rows, columns and values for initialization in sparse matrix. Shape depends on number of non-zero entries |
_compress_partial_derivative_parameter
¶
_compress_partial_derivative_parameter(partial_derivative_parameter: jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices dim dim']) -> tuple[jtInt[jax.Array, num_param_values], jtInt[jax.Array, num_param_values], jtFloat[jax.Array, 'num_param_values dim dim']]
Compress the partial derivative data with respect to the parameter field.
Compression consists of two steps:
- Remove tensor components from the sensitivity data, if all entries are zero
- Set the sensitivity vector to zero at the initial sites, but keep them for later computations.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
partial_derivative_parameter | jax.Array | Raw data from partial derivative computation, with shape | required |
Returns:
Type | Description |
---|---|
tuple[jtInt[jax.Array, num_param_values], jtInt[jax.Array, num_param_values], jtFloat[jax.Array, 'num_param_values dim dim']] | tuple[jax.Array, jax.Array, jax.Array]: Compressed data, represented as rows, columns and values to be further processed for sparse matrix assembly. Shape depends on number of non-zero entries |
_compute_global_partial_derivatives
¶
_compute_global_partial_derivatives(solution_vector: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_simplices dim dim']) -> tuple[jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices 2'], jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices dim dim']]
Compute partial derivatives of the global update operator.
The method is a jitted and vectorized call to the
_compute_vertex_partial_derivatives
method.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
solution_vector | jax.Array | Global solution vector | required |
tensor_field | jax.Array | Global parameter tensor field | required |
Returns:
Type | Description |
---|---|
tuple[jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices 2'], jtFloat[jax.Array, 'num_vertices max_num_adjacent_simplices dim dim']] | tuple[jax.Array, jax.Array]: Raw data for partial derivatives, with shapes (N, num_adjacent_simplices, 2) and (N, num_adjacent_simplices, dim, dim), N depends on the number of identical update paths for the vertices in the mesh. |
_compute_vertex_partial_derivatives
¶
_compute_vertex_partial_derivatives(solution_vector: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_simplices dim dim'], adjacency_data: jtInt[jax.Array, 'max_num_adjacent_simplices 4']) -> tuple[jtFloat[jax.Array, 'max_num_adjacent_simplices 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices dim dim']]
Compute partial derivatives for the update of a single vertex.
The method computes candidates for all respective subterms through calls to further methods. These candidates are filtered for feasibility by means of JAX filters. The softmin function (and its gradient) is applied to the directions of all optimal updates to ensure differentiability; other contributions are discarded. Lastly, the evaluated contributions are combined according to the form of the "total differential" for the partial derivatives.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
solution_vector | jax.Array | Global solution vector | required |
tensor_field | jax.Array | Global parameter tensor field | required |
adjacency_data | jax.Array | Adjacency data for the vertex under consideration | required |
Returns:
Type | Description |
---|---|
tuple[jtFloat[jax.Array, 'max_num_adjacent_simplices 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices dim dim']] | tuple[jax.Array, jax.Array]: Partial derivatives for the given vertex |
_compute_vertex_partial_derivative_candidates
¶
_compute_vertex_partial_derivative_candidates(solution_vector: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_simplices dim dim'], adjacency_data: jtInt[jax.Array, 'max_num_adjacent_simplices 4']) -> tuple[jtFloat[jax.Array, 'max_num_adjacent_simplices 4 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 dim dim']]
Compute partial derivatives corresponding to potential update candidates for a vertex.
Update candidates and corresponding derivatives are computed for all adjacent simplices, and for all possible update parameters per simplex.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
solution_vector | jax.Array | Global solution vector | required |
tensor_field | jax.Array | Global parameter field | required |
adjacency_data | jax.Array | Adjacency data for the given vertex | required |
Returns:
Type | Description |
---|---|
tuple[jtFloat[jax.Array, 'max_num_adjacent_simplices 4 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 dim dim']] | tuple[jax.Array, jax.Array]: Candidates for partial derivatives |
_compute_partial_derivative_candidates_from_adjacent_simplex
¶
_compute_partial_derivative_candidates_from_adjacent_simplex(solution_vector: jtFloat[jax.Array, num_vertices], tensor_field: jtFloat[jax.Array, 'num_simplices dim dim'], adjacency_data: jtInt[jax.Array, 4]) -> tuple[jtFloat[jax.Array, '4 2'], jtFloat[jax.Array, '4 dim dim']]
Compute partial derivatives for all update candidates within an adjacent simplex.
The update candidates are evaluated according to the different candidates for the optimization parameters \(\lambda\). Contributions are combined to the form of the involved total differentials.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
solution_vector | jax.Array | Global solution vector | required |
tensor_field | jax.Array | Global parameter field | required |
adjacency_data | jax.Array | Adjacency data for the given vertex and simplex | required |
Returns:
Type | Description |
---|---|
tuple[jtFloat[jax.Array, '4 2'], jtFloat[jax.Array, '4 dim dim']] | tuple[jax.Array, jax.Array]: Derivative candidate from the given simplex |
_filter_candidates
staticmethod
¶
_filter_candidates(vertex_update_candidates: jtFloat[jax.Array, 'max_num_adjacent_simplices 4'], grad_update_solution_candidates: jtFloat[jax.Array, 'max_num_adjacent_simplices 4 2'], grad_update_parameter_candidates: jtFloat[jax.Array, 'max_num_adjacent_simplices 4 dim dim']) -> tuple[jtFloat[jax.Array, ''], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 dim dim']]
Mask irrelevant derivative candidates so that they are discarded later.
Values are masked by setting them to zero or infinity, depending on the routine in which they are utilized later. Partial derivatives are only relevant if the corresponding update corresponds to an optimal path.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
vertex_update_candidates | jax.Array | Update candidates for a given vertex | required |
grad_update_solution_candidates | jax.Array | Partial derivative candidates w.r.t. the solution vector | required |
grad_update_parameter_candidates | jax.Array | Partial derivative candidates w.r.t. the parameter field | required |
Returns:
Type | Description |
---|---|
tuple[jtFloat[jax.Array, ''], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 2'], jtFloat[jax.Array, 'max_num_adjacent_simplices 4 dim dim']] | tuple[jax.Array, jax.Array, jax.Array]: Optimal update value, masked partial derivatives |
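The masking idea described for this method can be sketched in NumPy. This is only an illustration under assumptions: the candidate values are made up, and the optimality test via `np.isclose` is a hypothetical criterion, not necessarily the one used in Eikonax.

```python
import numpy as np

# Hypothetical update candidates: 2 adjacent simplices, 4 candidates each;
# infeasible candidates are already marked with infinity
update_candidates = np.array([
    [1.2, 0.9, np.inf, 1.5],
    [0.9, 1.1, 2.0, np.inf],
])
# Hypothetical derivative candidates w.r.t. the two opposite vertex values
grad_candidates = np.arange(16, dtype=float).reshape(2, 4, 2) + 1.0

# The optimal update is the smallest candidate over all simplices
optimal_value = update_candidates.min()
is_optimal = np.isclose(update_candidates, optimal_value)

# Values entering a later minimum are masked with infinity,
# derivatives entering a later sum are masked with zero
masked_values = np.where(is_optimal, update_candidates, np.inf)
masked_grads = np.where(is_optimal[..., None], grad_candidates, 0.0)
```

Here two candidates from different simplices share the optimal value, so both derivative entries survive the masking.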
_compute_lambda_grad
¶
_compute_lambda_grad(solution_values: jtFloat[jax.Array, 2], parameter_tensor: jtFloat[jax.Array, 'dim dim'], edges: tuple[jtFloat[jax.Array, dim], jtFloat[jax.Array, dim], jtFloat[jax.Array, dim]]) -> tuple[jtFloat[jax.Array, '4 2'], jtFloat[jax.Array, '4 dim dim']]
Compute the partial derivatives of update parameters for a single vertex.
This method evaluates the partial derivatives of the update parameters with respect to the current solution vector and the given parameter field, for a single triangle.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
solution_values | jax.Array | Current solution values at the opposite vertices of the considered triangle | required |
parameter_tensor | jax.Array | Parameter tensor for the given triangle | required |
edges | tuple[jax.Array, jax.Array, jax.Array] | Edges of the considered triangle | required |
Returns:
Type | Description |
---|---|
tuple[jtFloat[jax.Array, '4 2'], jtFloat[jax.Array, '4 dim dim']] | tuple[jax.Array, jax.Array]: Jacobians of the update parameters w.r.t. the solution vector and the parameter tensor |
eikonax.derivator.DerivativeSolver
¶
Main component for obtaining gradients from partial derivatives.
The Eikonax PartialDerivator
computes partial
derivatives of the global update operator with respect
to the solution vector, \(\mathbf{G}_u\), and the parameter tensor field, \(\mathbf{G}_M\).
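Now we exploit the fact that the solution candidate obtained from a forward solve,
\(\mathbf{u}\in\mathbb{R}^{N_V}\), is, up to a given accuracy, a fixed point of the
global update operator. We further consider the scenario of \(\mathbf{M}(\mathbf{m})\) being
dependent on some parameter \(\mathbf{m}\in\mathbb{R}^M\). This means we can write \(\mathbf{u}\) as
a function of \(\mathbf{m}\), obeying the relation
\[\mathbf{u}(\mathbf{m}) = \mathbf{G}(\mathbf{u}(\mathbf{m}), \mathbf{M}(\mathbf{m})).\]
To obtain the Jacobian \(\mathbf{J} = \frac{d\mathbf{u}}{d\mathbf{m}}\in\mathbb{R}^{N_V\times M}\), we simply differentiate the fixed point relation,
\[\mathbf{J} = \mathbf{G}_u\mathbf{J} + \mathbf{G}_M\frac{d\mathbf{M}}{d\mathbf{m}} \quad\Leftrightarrow\quad \mathbf{J} = (\mathbf{I} - \mathbf{G}_u)^{-1}\mathbf{G}_M\frac{d\mathbf{M}}{d\mathbf{m}}.\]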
\(\mathbf{G}_u\) and \(\mathbf{G}_M\) are provided by the
PartialDerivator
, whereas
\(\frac{d\mathbf{M}}{d\mathbf{m}}\) is computed as the Jacobian of the
TensorField
component.
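Differentiating the fixed point relation can be checked numerically on a toy problem. The sketch below replaces the Eikonax update operator with a hypothetical linear map \(\mathbf{G}(\mathbf{u}, \mathbf{M}) = \mathbf{A}\mathbf{u} + \mathbf{B}\mathbf{M}\) with \(\mathbf{M}(\mathbf{m}) = \mathbf{C}\mathbf{m}\), so that \(\mathbf{G}_u = \mathbf{A}\), \(\mathbf{G}_M = \mathbf{B}\) and \(\frac{d\mathbf{M}}{d\mathbf{m}} = \mathbf{C}\). All matrices are made up for illustration.

```python
import numpy as np

# Hypothetical stand-ins for the partial derivatives
A = np.array([[0.0, 0.0, 0.0],
              [0.5, 0.0, 0.0],
              [0.2, 0.3, 0.0]])  # G_u, strictly lower triangular (causal)
B = np.array([[1.0, 0.0],
              [0.0, 1.0],
              [1.0, 1.0]])       # G_M
C = np.array([[2.0, 0.0],
              [1.0, 1.0]])       # dM/dm


def solve_fixed_point(m, num_iterations=10):
    """Iterate u <- G(u, M(m)); exact after 3 steps since A is nilpotent."""
    u = np.zeros(3)
    for _ in range(num_iterations):
        u = A @ u + B @ (C @ m)
    return u


m0 = np.array([1.0, 2.0])

# Jacobian from the differentiated fixed point relation
jacobian = np.linalg.solve(np.eye(3) - A, B @ C)

# Reference: central finite differences through the forward solve
eps = 1e-6
fd_jacobian = np.column_stack([
    (solve_fixed_point(m0 + eps * e) - solve_fixed_point(m0 - eps * e)) / (2 * eps)
    for e in np.eye(2)
])
```

Both Jacobians agree up to finite-difference rounding, without ever differentiating through the iterations of the forward solve.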
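We are typically not interested in the full Jacobian, but rather in the gradient of some cost functional \(l:\mathbb{R}^{N_V}\to\mathbb{R},\ l=l(\mathbf{u}(\mathbf{m}))\) with respect to \(\mathbf{m}\). With \(\mathbf{l}_u = \frac{dl}{d\mathbf{u}}\), the gradient is given as
\[\mathbf{g}(\mathbf{m}) = \mathbf{J}^T\mathbf{l}_u = \Big(\mathbf{G}_M\frac{d\mathbf{M}}{d\mathbf{m}}\Big)^T(\mathbf{I} - \mathbf{G}_u)^{-T}\mathbf{l}_u = \Big(\mathbf{G}_M\frac{d\mathbf{M}}{d\mathbf{m}}\Big)^T\mathbf{v}.\]
We can identify \(\mathbf{v}\) as the adjoint variable, which is obtained by solving the linear discrete adjoint equation,
\[(\mathbf{I} - \mathbf{G}_u)^T\mathbf{v} = \mathbf{l}_u.\]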
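A minimal NumPy sketch of the adjoint approach, assuming hypothetical dense stand-ins for \(\mathbf{G}_u\), for the product \(\mathbf{G}_M\frac{d\mathbf{M}}{d\mathbf{m}}\), and for the cost gradient \(\mathbf{l}_u\). It compares the gradient from one adjoint solve against the gradient obtained from the full Jacobian.

```python
import numpy as np

# Hypothetical partial derivative w.r.t. the solution (causal structure)
G_u = np.array([[0.0, 0.0, 0.0],
                [0.5, 0.0, 0.0],
                [0.2, 0.3, 0.0]])
# Hypothetical product G_M * dM/dm, shape (num_vertices, num_parameters)
G_M_dM = np.array([[2.0, 0.0],
                   [1.0, 1.0],
                   [3.0, 1.0]])
# Hypothetical derivative of the cost functional w.r.t. the solution
l_u = np.array([1.0, -2.0, 0.5])

# Adjoint route: one linear solve, independent of the number of parameters
v = np.linalg.solve((np.eye(3) - G_u).T, l_u)
gradient_adjoint = G_M_dM.T @ v

# Direct route: assemble the full Jacobian first, then contract with l_u
jacobian = np.linalg.solve(np.eye(3) - G_u, G_M_dM)
gradient_direct = jacobian.T @ l_u
```

The adjoint route needs a single solve per cost functional, whereas the direct route needs one solve per parameter.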
Now comes the catch: Through the strict causality of the Godunov update operator, we can find a unique and consistent ordering of vertex indices, such that the solution at a vertex \(i\) is only informed by the solution at a vertex \(j\), if \(j\) occurs before \(i\) in that ordering. The sparsity pattern of \(\mathbf{G}_u\) therefore forms a directed acyclic graph. This means that there is an orthogonal permutation matrix \(\mathbf{P}\) such that for \(\bar{\mathbf{G}}_u = \mathbf{P}\mathbf{G}_u\mathbf{P}^T\) an entry \((\bar{\mathbf{G}}_u)_{ij}\) is only non-zero if \(i > j\). In total, we can write
\[(\mathbf{I} - \mathbf{G}_u)^T = \mathbf{P}^T(\mathbf{I} - \bar{\mathbf{G}}_u)^T\mathbf{P} = \mathbf{P}^T\bar{\mathbf{A}}\mathbf{P},\]
where \(\bar{\mathbf{A}}\) is an upper triangular matrix with unit diagonal. Hence, the adjoint system can be solved through simple back-substitution.
The DerivativeSolver
component does exactly this: It sets up the matrices \(\mathbf{P}\) and
\(\bar{\mathbf{A}}\), permutes inputs/outputs, and solves the sparse linear system through
back-substitution.
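The permutation argument can be illustrated with dense NumPy arrays. The sketch assumes that a causal ordering is obtained by sorting the vertices by ascending solution value; this is an assumption for illustration, and the solution values and entries of \(\mathbf{G}_u\) are hypothetical.

```python
import numpy as np

# Hypothetical solution values (arrival times) at four vertices
solution = np.array([0.0, 1.5, 0.7, 2.1])

# Hypothetical G_u: entry (i, j) is non-zero if vertex i is updated from
# vertex j, which requires solution[j] < solution[i]
G_u = np.zeros((4, 4))
G_u[2, 0] = 1.0
G_u[1, 0], G_u[1, 2] = 0.4, 0.6
G_u[3, 1], G_u[3, 2] = 0.5, 0.5

# Permutation matrix from the assumed ordering by ascending solution value
order = np.argsort(solution)
P = np.eye(4)[order]

# In the new ordering, G_u becomes strictly lower triangular
G_bar = P @ G_u @ P.T
# ... and the transposed system matrix is upper triangular, unit diagonal
A_bar = (np.eye(4) - G_bar).T
```

In the original vertex numbering, `G_u` has an entry above the diagonal, so the triangular structure only appears after the permutation.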
Speedy gradients
Given a solution vector \(\mathbf{u}\), Eikonax computes derivatives with linear complexity. Moreover, for a given evaluation point, we can evaluate an arbitrary number of gradients through simple back-substitution. All matrices need to be assembled only once.
Change in tooling
In the DerivativeSolver
, we leave JAX and fall back to the numpy
/scipy
stack. While
the sequential solver operation should not be much slower on the CPU, we have to transfer
the data back from the offloading device. We plan to implement a GPU-compatible solver with
CuPy in a future version, or in JAX as soon as it offers the necessary
linear algebra tools.
Methods:
Name | Description |
---|---|
solve | Solve the linear system for the adjoint variable |
sparse_system_matrix
property
¶
Get system matrix \(\bar{\mathbf{A}}\in\mathbb{R}^{{N_V}\times {N_V}}\).
sparse_permutation_matrix
property
¶
Get permutation matrix \(\mathbf{P}\in\mathbb{R}^{{N_V}\times {N_V}}\).
__init__
¶
__init__(solution: jtFloat[jax.Array | npt.NDArray, num_vertices], sparse_partial_update_solution: tuple[jtInt[jax.Array, num_sol_values], jtInt[jax.Array, num_sol_values], jtFloat[jax.Array, num_sol_values]]) -> None
Constructor for the derivative solver.
Initializes the causality-inspired permutation matrix \(\mathbf{P}\), and afterwards the permuted system matrix \(\bar{\mathbf{A}}\), which is triangular.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
solution | jax.Array \| npt.NDArray | Obtained solution of the eikonal equation | required |
sparse_partial_update_solution | tuple[jax.Array, jax.Array, jax.Array] | Sparse representation of the partial derivative \(\mathbf{G}_u\), containing row indices, column indices and values. These structures might contain redundancies, which are automatically removed through summation in the sparse matrix assembly later. | required |
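The removal of redundancies through summation can be illustrated with a small dense stand-in. The triplets below are hypothetical; `np.add.at` accumulates repeated index pairs, which mirrors how SciPy's COO format sums duplicate entries when it is converted to CSR, CSC or a dense array.

```python
import numpy as np

# Hypothetical triplets; the pair (2, 1) appears twice
rows = np.array([1, 2, 2, 2])
cols = np.array([0, 0, 1, 1])
values = np.array([1.0, 0.3, 0.4, 0.3])

# Dense stand-in for the sparse assembly: repeated (row, column) pairs
# are accumulated instead of overwriting each other
dense = np.zeros((3, 3))
np.add.at(dense, (rows, cols), values)
```

A plain assignment `dense[rows, cols] = values` would instead keep only one of the duplicated contributions.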
solve
¶
solve(right_hand_side: jtFloat[jax.Array | npt.NDArray, num_vertices]) -> jtFloat[npt.NDArray, num_parameters]
Solve the linear system for the parametric gradient.
Following the notation from the class docstring, this method solves the linear system for the adjoint variable \(\mathbf{v}\). Given a right-hand side \(\mathbf{l}_u\), this is a three-step process:
- Permute the right-hand side \(\bar{\mathbf{l}}_u = \mathbf{P}\mathbf{l}_u\)
- Solve the linear system \(\bar{\mathbf{A}}\bar{\mathbf{v}} = \bar{\mathbf{l}}_u\)
- Permute solution back to the original ordering \(\mathbf{v} = \mathbf{P}^T\bar{\mathbf{v}}\)
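The three steps above can be sketched with dense NumPy arrays. This is not the Eikonax implementation, which works on sparse matrices; the helper `back_substitution`, the ordering by ascending solution value, and all numbers are assumptions made for illustration.

```python
import numpy as np


def back_substitution(upper, rhs):
    """Solve upper @ x = rhs for an upper triangular matrix (hypothetical helper)."""
    n = rhs.size
    x = np.zeros(n)
    for i in range(n - 1, -1, -1):
        x[i] = (rhs[i] - upper[i, i + 1:] @ x[i + 1:]) / upper[i, i]
    return x


# Hypothetical solution values and causal partial derivative G_u
solution = np.array([0.0, 1.5, 0.7, 2.1])
G_u = np.zeros((4, 4))
G_u[2, 0] = 1.0
G_u[1, 0], G_u[1, 2] = 0.4, 0.6
G_u[3, 1], G_u[3, 2] = 0.5, 0.5

# Assumed causal ordering: vertices sorted by ascending solution value
P = np.eye(4)[np.argsort(solution)]
A_bar = (np.eye(4) - P @ G_u @ P.T).T

l_u = np.array([1.0, 0.0, -1.0, 2.0])

l_bar = P @ l_u                            # 1. permute the right-hand side
v_bar = back_substitution(A_bar, l_bar)    # 2. triangular solve
v = P.T @ v_bar                            # 3. permute back

# Reference: direct solve of the unpermuted adjoint equation
v_reference = np.linalg.solve((np.eye(4) - G_u).T, l_u)
```

Since `P` and `A_bar` do not depend on the right-hand side, further gradients at the same evaluation point only repeat the three steps with a new `l_u`.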
Parameters:
Name | Type | Description | Default |
---|---|---|---|
right_hand_side | jax.Array \| npt.NDArray | RHS for the linear system solve | required |
Returns:
Type | Description |
---|---|
jtFloat[npt.NDArray, num_parameters] | np.ndarray: Solution of the linear system solve, corresponding to the adjoint in an optimization context. |
_assemble_permutation_matrix
¶
Construct permutation matrix \(\mathbf{P}\) for index ordering.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
solution | npt.NDArray | Obtained solution of the eikonal equation | required |
Returns:
Type | Description |
---|---|
sp.sparse.csc_matrix | sp.sparse.csc_matrix: Sparse permutation matrix |
_assemble_system_matrix
¶
_assemble_system_matrix(sparse_partial_update_solution: tuple[jtInt[npt.NDArray, num_sol_values], jtInt[npt.NDArray, num_sol_values], jtFloat[npt.NDArray, num_sol_values]], num_points: int) -> sp.sparse.csc_matrix
Assemble system matrix \(\bar{\mathbf{A}}\) for gradient solver.
Before invoking this method, the permutation matrix \(\mathbf{P}\) must be initialized.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
sparse_partial_update_solution | tuple[npt.NDArray, npt.NDArray, npt.NDArray] | Sparse representation of the partial derivative \(\mathbf{G}_u\), containing row indices, column indices and values. These structures might contain redundancies, which are automatically removed through summation in the sparse matrix assembly. | required |
num_points | int | Number of mesh points | required |
Returns:
Type | Description |
---|---|
sp.sparse.csc_matrix | sp.sparse.csc_matrix: Sparse representation of the permuted system matrix |
eikonax.derivator.compute_eikonax_jacobian
¶
compute_eikonax_jacobian(derivative_solver: DerivativeSolver, partial_derivative_parameter: sp.sparse.coo_matrix) -> npt.NDArray
Compute Jacobian from concatenation of gradients, computed with unit vector RHS.
Warning
This method should only be used for small problems.
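Why the approach is limited to small problems can be seen in a sketch of the idea. The code below uses hypothetical dense stand-ins and `np.linalg.solve` in place of a `DerivativeSolver`; it builds the Jacobian row by row, with one adjoint solve per unit vector, and therefore needs as many solves as there are vertices.

```python
import numpy as np

# Hypothetical partial derivatives for a mesh with three vertices
G_u = np.array([[0.0, 0.0, 0.0],
                [0.5, 0.0, 0.0],
                [0.2, 0.3, 0.0]])
# Hypothetical partial derivative w.r.t. the parameters (num_vertices x M)
G_m = np.array([[2.0, 0.0],
                [1.0, 1.0],
                [3.0, 1.0]])

adjoint_matrix = (np.eye(3) - G_u).T

# One adjoint solve per unit vector; each gradient is one Jacobian row
jacobian_rows = []
for unit_vector in np.eye(3):
    v = np.linalg.solve(adjoint_matrix, unit_vector)
    jacobian_rows.append(G_m.T @ v)
jacobian = np.vstack(jacobian_rows)

# Reference: Jacobian from the differentiated fixed point relation
jacobian_reference = np.linalg.solve(np.eye(3) - G_u, G_m)
```

The result is a dense array of shape (number of vertices, number of parameters), which quickly becomes prohibitive for large meshes.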
Parameters:
Name | Type | Description | Default |
---|---|---|---|
derivative_solver | DerivativeSolver | Initialized derivative solver object | required |
partial_derivative_parameter | sp.sparse.coo_matrix | Partial derivative of the global update operator with respect to the parameter tensor field | required |
Returns:
Type | Description |
---|---|
npt.NDArray | npt.NDArray: (Dense) Jacobian matrix |
eikonax.derivator.compute_eikonax_hessian
¶
Compute Hessian matrix.
Not implemented yet