Tensor Algebra
linalg
Linear algebra utilities and sparse data structures for Eikonax.
This module provides specialized sparse data structures and utilities tailored to the Eikonax solver's derivative computation workflow. Since the derivative operators are inherently sparse due to the local nature of the Godunov update scheme, these structures efficiently store only the non-zero entries along with their connectivity information.
The module defines three main sparse container classes:
- EikonaxSparseMatrix: A COO-like (coordinate format) sparse matrix representation compatible with JAX arrays. Stores row indices, column indices, and values separately, along with the matrix shape.
- DerivatorSparseTensor: Specialized tensor for storing the partial derivatives \(\mathbf{G}_M\) computed by the PartialDerivator. Pairs derivative values with simplex adjacency information for each vertex.
- TensorfieldSparseTensor: Container for the Jacobian \(\frac{d\mathbf{M}}{d\mathbf{m}}\) computed by the TensorField. Stores derivative values together with parameter indices that indicate which global parameters each simplex depends on.
The key operation in this module is the tensor contraction contract_derivative_tensors, which efficiently combines the derivatives from the Eikonax solver and the tensor field to compute the total parametric derivative \(\mathbf{G}_m = \mathbf{G}_M \frac{d\mathbf{M}}{d\mathbf{m}}\) required for gradient computation.
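Written out entry by entry, the contraction reads as follows. This index form is a restatement inferred from the tensor shapes documented on this page, not a formula quoted from the Eikonax source; \(\mathcal{N}(v)\) denotes the set of simplices adjacent to vertex \(v\) and \(d\) is the spatial dimension (dim):

\[
\left(\mathbf{G}_m\right)_{v p} = \sum_{s \in \mathcal{N}(v)} \sum_{i,j=1}^{d} \left(\mathbf{G}_M\right)_{v s i j} \left(\frac{d\mathbf{M}}{d\mathbf{m}}\right)_{s i j p}
\]

Because \(\mathbf{G}_M\) couples a vertex only to its adjacent simplices, and for local parameterizations each simplex's Jacobian is non-zero only for its few mapped parameters, most entries of \(\mathbf{G}_m\) vanish, which is why the result is stored as a sparse matrix.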
Classes:
| Name | Description |
|---|---|
| EikonaxSparseMatrix | COO-style sparse matrix for JAX-compatible derivative storage |
| DerivatorSparseTensor | Sparse tensor for partial derivatives w.r.t. metric tensors |
| TensorfieldSparseTensor | Sparse tensor for the tensor field Jacobian |
Functions:
| Name | Description |
|---|---|
| convert_to_scipy_sparse | Convert an EikonaxSparseMatrix to a SciPy sparse array |
| contract_derivative_tensors | Contract a DerivatorSparseTensor with a TensorfieldSparseTensor |
eikonax.linalg.EikonaxSparseMatrix
Bases: eqx.Module
COO-style sparse matrix representation compatible with JAX.
This class stores a sparse matrix in coordinate (COO) format using JAX arrays. Unlike SciPy sparse matrices, this representation is JAX-compatible, so it can be used inside JIT-compiled functions and with automatic differentiation. The matrix is defined by three parallel arrays holding row indices, column indices, and the corresponding values, together with the matrix shape.
Attributes:
| Name | Type | Description |
|---|---|---|
| row_inds | jax.Array | Row indices of the stored entries with shape (num_entries,). |
| col_inds | jax.Array | Column indices of the stored entries with shape (num_entries,). |
| values | jax.Array | Values of the stored entries with shape (num_entries,). Stored entries may include explicit zeros, which convert_to_scipy_sparse filters out. |
| shape | tuple[int, int] | Shape of the matrix (num_rows, num_cols). |
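To make the COO layout concrete, here is a minimal sketch using a hypothetical plain-NumPy stand-in, CooSketch (the real class is an eqx.Module with jax.Array fields). It assumes, as the duplicate summing in convert_to_scipy_sparse suggests, that the same (row, col) pair may be stored more than once and that such entries add up.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class CooSketch:
    """Hypothetical NumPy stand-in for EikonaxSparseMatrix."""

    row_inds: np.ndarray  # shape (num_entries,)
    col_inds: np.ndarray  # shape (num_entries,)
    values: np.ndarray    # shape (num_entries,)
    shape: tuple[int, int]

    def to_dense(self) -> np.ndarray:
        """Scatter-add the triplets into a dense matrix; repeated
        (row, col) pairs are summed, matching COO semantics."""
        dense = np.zeros(self.shape, dtype=self.values.dtype)
        np.add.at(dense, (self.row_inds, self.col_inds), self.values)
        return dense


# Three stored entries; the pair (0, 1) appears twice and is accumulated.
matrix = CooSketch(
    row_inds=np.array([0, 0, 1]),
    col_inds=np.array([1, 1, 2]),
    values=np.array([1.5, 2.0, -1.0]),
    shape=(2, 3),
)
dense = matrix.to_dense()
```

np.add.at is used instead of plain fancy-index assignment because the latter would keep only one of the duplicated entries rather than summing them.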
eikonax.linalg.DerivatorSparseTensor
Bases: eqx.Module
Sparse tensor for partial derivatives from the PartialDerivator.
This container stores the partial derivatives \(\mathbf{G}_M\) of the Eikonax update operator with respect to the metric tensor field. For each vertex, it stores derivative contributions from all adjacent simplices in a dense local format, along with indices identifying which simplices are adjacent. This structure efficiently represents the sparse global tensor while maintaining JAX compatibility.
The derivative values capture how changes in the metric tensor of each adjacent simplex affect the update value at each vertex. The adjacency data maps these local contributions to the global simplex numbering.
Attributes:
| Name | Type | Description |
|---|---|---|
| derivative_values | jax.Array | Partial derivative tensors with shape (num_vertices, max_num_neighbors, dim, dim). For each vertex and adjacent simplex, stores the (dim x dim) derivative of the update w.r.t. that simplex's metric tensor. |
| adjacent_simplex_data | jax.Array | Global simplex indices with shape (num_vertices, max_num_neighbors). Maps each local neighbor index to a global simplex index. Entries of -1 indicate padding for vertices with fewer than max_num_neighbors adjacent simplices. |
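The padded layout of adjacent_simplex_data can be illustrated with a small sketch. The helper build_adjacency below is hypothetical and only shows the shape and -1 padding convention described in the table; how Eikonax itself builds and orders this array is not documented here.

```python
import numpy as np


def build_adjacency(simplices: np.ndarray, num_vertices: int) -> np.ndarray:
    """Hypothetical helper: map each vertex to the global indices of the
    simplices containing it, padded with -1 up to max_num_neighbors."""
    neighbors = [[] for _ in range(num_vertices)]
    for simplex_ind, simplex in enumerate(simplices):
        for vertex in simplex:
            neighbors[vertex].append(simplex_ind)
    max_num_neighbors = max(len(n) for n in neighbors)
    adjacency = np.full((num_vertices, max_num_neighbors), -1, dtype=np.int64)
    for vertex, simplex_inds in enumerate(neighbors):
        adjacency[vertex, : len(simplex_inds)] = simplex_inds
    return adjacency


# Two triangles sharing the edge (1, 2) on a 4-vertex mesh.
simplices = np.array([[0, 1, 2], [1, 3, 2]])
adjacency = build_adjacency(simplices, num_vertices=4)
```

For this mesh, vertices 0 and 3 touch a single triangle, so their second slot is -1; a matching derivative_values array would have shape (4, 2, dim, dim), and its padded slots are masked out during contraction.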
eikonax.linalg.TensorfieldSparseTensor
Bases: eqx.Module
Sparse tensor for the tensor field Jacobian.
This container stores the Jacobian \(\frac{d\mathbf{M}}{d\mathbf{m}}\) computed by the TensorField component. For each simplex, it stores how the metric tensor depends on the global parameters, along with indices indicating which global parameters affect that simplex.
Since most tensor field parameterizations are local (each simplex depends on only a small subset of the global parameters), this representation stores num_parameters_mapped values per metric-tensor entry instead of num_parameters_global, which makes it far more memory-efficient than the full dense Jacobian of shape (num_simplices, dim, dim, num_parameters_global).
Attributes:
| Name | Type | Description |
|---|---|---|
| derivative_values | jax.Array | Jacobian values with shape (num_simplices, dim, dim, num_parameters_mapped). For each simplex, stores how the (dim x dim) metric tensor changes with respect to the relevant parameters. |
| parameter_inds | jax.Array | Global parameter indices with shape (num_simplices, num_parameters_mapped). Maps each local parameter index to a global parameter index for each simplex. |
| num_parameters_global | int | Total number of global parameters in the full parameter vector. Used to determine output matrix dimensions in contractions. |
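As a sketch of how these fields fit together, consider a hypothetical isotropic parameterization with one scalar per simplex, \(\mathbf{M}_s = m_s \mathbf{I}\). This is an assumption chosen for illustration, not a parameterization taken from Eikonax. The Jacobian of each simplex is then the identity with respect to a single parameter, so num_parameters_mapped is 1, and the memory gap against a dense Jacobian is easy to count.

```python
import numpy as np

# Hypothetical parameterization: M_s = m_s * I, one parameter per simplex.
# Then dM_s/dm_s = I and each simplex maps to exactly one global parameter.
num_simplices, dim = 1000, 2
num_parameters_global = num_simplices

identity = np.eye(dim)
# Shape (num_simplices, dim, dim, num_parameters_mapped=1).
derivative_values = np.broadcast_to(
    identity[None, :, :, None], (num_simplices, dim, dim, 1)
).copy()
# Simplex s depends only on global parameter s.
parameter_inds = np.arange(num_simplices)[:, None]

# Stored numbers: sparse layout vs. a dense Jacobian of shape
# (num_simplices, dim, dim, num_parameters_global).
sparse_entries = derivative_values.size + parameter_inds.size
dense_entries = num_simplices * dim * dim * num_parameters_global
```

Here the sparse layout stores 5,000 numbers against 4,000,000 for the dense tensor, and the ratio grows linearly with the number of simplices.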
eikonax.linalg.convert_to_scipy_sparse
Convert EikonaxSparseMatrix to SciPy sparse COO array.
This function transforms the JAX-compatible EikonaxSparseMatrix format into a standard SciPy coo_array. The conversion involves:
- Extracting row indices, column indices, and values as NumPy arrays
- Filtering out zero entries to reduce memory footprint
- Constructing a SciPy COO array and summing duplicate entries
This conversion is necessary when interfacing with SciPy's sparse linear algebra routines, for example in the DerivativeSolver, which requires CSC format for its triangular solve.
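The three conversion steps can be sketched as follows. The function to_scipy_coo is hypothetical and takes bare arrays rather than an EikonaxSparseMatrix; it relies on scipy.sparse.coo_array (SciPy 1.8 or newer). Note that filtering zeros also discards padding entries, which carry value 0 and column index -1 and would otherwise be rejected by SciPy as negative indices.

```python
import numpy as np
from scipy import sparse as sp


def to_scipy_coo(row_inds, col_inds, values, shape) -> sp.coo_array:
    """Hypothetical sketch of the documented conversion steps."""
    # 1. Move the data to NumPy (np.asarray also accepts jax.Array inputs).
    row_inds = np.asarray(row_inds)
    col_inds = np.asarray(col_inds)
    values = np.asarray(values)
    # 2. Drop explicit zeros, including zeroed padding entries whose
    #    column index may be -1.
    nonzero = values != 0
    matrix = sp.coo_array(
        (values[nonzero], (row_inds[nonzero], col_inds[nonzero])), shape=shape
    )
    # 3. Merge repeated (row, col) pairs by summing them (in place).
    matrix.sum_duplicates()
    return matrix


# One duplicated pair (0, 1) and one zeroed padding entry with column -1.
coo = to_scipy_coo(
    row_inds=[0, 0, 1, 1],
    col_inds=[1, 1, -1, 2],
    values=[1.5, 2.0, 0.0, -1.0],
    shape=(2, 3),
)
csc = coo.tocsc()  # CSC layout, as needed for a triangular solve
```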
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| eikonax_sparse_matrix | EikonaxSparseMatrix | Sparse matrix in JAX-compatible format. | required |
Returns:
| Type | Description |
|---|---|
| sp.coo_array | SciPy sparse array in COO format with zero entries removed and duplicate entries summed. |
eikonax.linalg.contract_derivative_tensors
contract_derivative_tensors(
derivative_sparse_tensor: DerivatorSparseTensor,
tensorfield_sparse_tensor: TensorfieldSparseTensor,
) -> EikonaxSparseMatrix
Contract derivative tensors to compute total parametric derivative.
This function performs the key operation in Eikonax's gradient computation pipeline: combining the partial derivatives from the solver (\(\mathbf{G}_M\)) with the tensor field Jacobian (\(\frac{d\mathbf{M}}{d\mathbf{m}}\)) to obtain the total parametric derivative \(\mathbf{G}_m = \mathbf{G}_M \frac{d\mathbf{M}}{d\mathbf{m}}\).
The contraction is performed efficiently by exploiting the sparse structure of both inputs:
- For each vertex, iterate over its adjacent simplices.
- Extract the (dim x dim) derivative matrix for that vertex-simplex pair from derivative_sparse_tensor.
- Contract it (via Einstein summation) with the corresponding (dim x dim x num_parameters_mapped) Jacobian tensor from tensorfield_sparse_tensor.
- Map the result to the appropriate global parameter indices.
- Assemble all contributions into a sparse matrix of shape (num_vertices, num_parameters_global).
The operation is vectorized over all vertices using JAX's vmap and JIT-compiled for performance.
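The steps above can be sketched in NumPy for all vertices at once. The function contract_sketch is hypothetical: it uses array broadcasting and np.einsum in place of jax.vmap and jax.jit, and it returns bare COO triplets instead of an EikonaxSparseMatrix. Padding slots are redirected to simplex 0 for the gather and then masked, so they contribute value 0 with column index -1.

```python
import numpy as np


def contract_sketch(derivative_values, adjacent_simplex_data,
                    tensorfield_values, parameter_inds):
    """Hypothetical NumPy version of G_m = G_M dM/dm.

    derivative_values:     (num_vertices, max_num_neighbors, dim, dim)
    adjacent_simplex_data: (num_vertices, max_num_neighbors), -1 = padding
    tensorfield_values:    (num_simplices, dim, dim, num_parameters_mapped)
    parameter_inds:        (num_simplices, num_parameters_mapped)
    """
    valid = adjacent_simplex_data >= 0
    # Point padding at simplex 0 so the gather stays in bounds; the mask
    # below zeroes these contributions.
    simplex_inds = np.where(valid, adjacent_simplex_data, 0)
    jacobians = tensorfield_values[simplex_inds]  # (V, N, dim, dim, P)
    # Einstein summation over the two metric-tensor indices i, j.
    values = np.einsum("vnij,vnijp->vnp", derivative_values, jacobians)
    values = np.where(valid[..., None], values, 0.0)
    cols = np.where(valid[..., None], parameter_inds[simplex_inds], -1)
    rows = np.broadcast_to(
        np.arange(values.shape[0])[:, None, None], values.shape
    )
    return rows.ravel(), cols.ravel(), values.ravel()


# Usage on a 4-vertex, 2-simplex toy problem with 3 global parameters.
rng = np.random.default_rng(0)
num_vertices, num_simplices, dim, num_params = 4, 2, 2, 3
adjacency = np.array([[0, -1], [0, 1], [0, 1], [1, -1]])
g_M = rng.normal(size=(num_vertices, 2, dim, dim))
jac = rng.normal(size=(num_simplices, dim, dim, 2))
param_inds = np.array([[0, 1], [1, 2]])  # both simplices use parameter 1

rows, cols, vals = contract_sketch(g_M, adjacency, jac, param_inds)
keep = cols >= 0
g_m = np.zeros((num_vertices, num_params))
np.add.at(g_m, (rows[keep], cols[keep]), vals[keep])  # sums duplicates
```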
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| derivative_sparse_tensor | DerivatorSparseTensor | Partial derivatives \(\mathbf{G}_M\) of the update operator with respect to the metric tensor field. | required |
| tensorfield_sparse_tensor | TensorfieldSparseTensor | Tensor field Jacobian \(\frac{d\mathbf{M}}{d\mathbf{m}}\). | required |
Returns:
| Name | Type | Description |
|---|---|---|
| EikonaxSparseMatrix | EikonaxSparseMatrix | Total parametric derivative \(\mathbf{G}_m\) with shape (num_vertices, num_parameters_global). It can be transposed and multiplied with the adjoint vector to compute parametric gradients. |
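The product \(\mathbf{G}_m^T \boldsymbol{\lambda}\) mentioned in the return description can be formed directly from the COO triplets without building the matrix. The helper transpose_matvec below is hypothetical, and the adjoint vector is a placeholder; how Eikonax computes the adjoint is outside the scope of this page.

```python
import numpy as np


def transpose_matvec(row_inds, col_inds, values, num_parameters_global, adjoint):
    """Hypothetical helper: compute G_m^T @ adjoint from COO triplets of
    G_m, which has shape (num_vertices, num_parameters_global)."""
    # Skip padding entries (column index -1, value 0).
    keep = col_inds >= 0
    result = np.zeros(num_parameters_global)
    # Entry (v, p) of G_m adds values * adjoint[v] to result[p];
    # np.add.at accumulates repeated column indices correctly.
    np.add.at(result, col_inds[keep], values[keep] * adjoint[row_inds[keep]])
    return result


# G_m with shape (2, 3); the (0, 1) entry is stored twice.
rows = np.array([0, 0, 1, 1])
cols = np.array([1, 1, -1, 2])
vals = np.array([1.5, 2.0, 0.0, -1.0])
adjoint = np.array([2.0, 3.0])  # placeholder adjoint vector
gradient = transpose_matvec(rows, cols, vals, 3, adjoint)
```

Working on triplets avoids materializing \(\mathbf{G}_m^T\); with SciPy, the equivalent is converting first and computing matrix.T @ adjoint.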
eikonax.linalg._contract_vertex_tensors
_contract_vertex_tensors(
derivator_tensor: jtReal[jax.Array, "max_num_neighbors dim dim"],
adjacent_simplex_data: jtInt[jax.Array, "max_num_neighbors"],
tensorfield_data: jtReal[jax.Array, "num_simplices dim dim num_parameters_mapped"],
parameter_inds: jtInt[jax.Array, "num_simplices num_parameters_mapped"],
) -> tuple[
jtReal[jax.Array, "max_num_neighbors num_parameters_mapped"],
jtInt[jax.Array, "max_num_neighbors num_parameters_mapped"],
]
Contract derivatives for a single vertex with tensor field Jacobian.
This helper function performs the tensor contraction for a single vertex. For each adjacent simplex, it:
- Extracts the (dim x dim) derivative matrix from the derivator tensor
- Extracts the (dim x dim x num_params) Jacobian tensor for that simplex
- Contracts them via Einstein summation: result[k] = sum_ij derivator[i,j] * jacobian[i,j,k]
- Maps the result to the global parameter indices for that simplex
Padded neighbor slots (entries of adjacent_simplex_data equal to -1) are handled by masking: their contributions are set to zero and their column indices to -1.
This function is called in a vectorized manner by contract_derivative_tensors via jax.vmap.
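A per-vertex version of the contraction can be sketched as below. The name contract_vertex_sketch is hypothetical and NumPy stands in for jax.numpy. The sketch highlights why masking matters: indexing with -1 silently selects the last simplex rather than failing, so only the mask keeps padded slots from contributing.

```python
import numpy as np


def contract_vertex_sketch(derivator_tensor, adjacent_simplex_data,
                           tensorfield_data, parameter_inds):
    """Hypothetical NumPy analogue of the per-vertex contraction.

    derivator_tensor:      (max_num_neighbors, dim, dim)
    adjacent_simplex_data: (max_num_neighbors,), -1 marks padding
    tensorfield_data:      (num_simplices, dim, dim, num_parameters_mapped)
    parameter_inds:        (num_simplices, num_parameters_mapped)
    """
    valid = (adjacent_simplex_data != -1)[:, None]  # (max_num_neighbors, 1)
    # Index -1 wraps to the last simplex, so the gather never fails; the
    # mask below is what discards padded slots.
    jacobians = tensorfield_data[adjacent_simplex_data]
    # result[n, k] = sum_ij derivator[n, i, j] * jacobian[n, i, j, k]
    values = np.einsum("nij,nijk->nk", derivator_tensor, jacobians)
    values = np.where(valid, values, 0.0)
    col_inds = np.where(valid, parameter_inds[adjacent_simplex_data], -1)
    return values, col_inds


# One vertex with two neighbor slots; the second slot is padding.
rng = np.random.default_rng(1)
derivator_tensor = rng.normal(size=(2, 2, 2))
tensorfield_data = rng.normal(size=(3, 2, 2, 2))
parameter_inds = np.array([[0, 1], [1, 2], [2, 3]])
values, col_inds = contract_vertex_sketch(
    derivator_tensor, np.array([1, -1]), tensorfield_data, parameter_inds
)
```

In JAX, one plausible way to map such a function over vertices is jax.vmap with in_axes=(0, 0, None, None), which batches the per-vertex inputs while sharing the two per-simplex arrays; whether Eikonax uses exactly these axes is an assumption.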
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| derivator_tensor | jax.Array | Partial derivatives for one vertex with shape (max_num_neighbors, dim, dim). | required |
| adjacent_simplex_data | jax.Array | Global simplex indices for one vertex with shape (max_num_neighbors,). Entries of -1 indicate padding. | required |
| tensorfield_data | jax.Array | Global tensor field Jacobian with shape (num_simplices, dim, dim, num_parameters_mapped). | required |
| parameter_inds | jax.Array | Global parameter indices for all simplices with shape (num_simplices, num_parameters_mapped). | required |
Returns:
| Type | Description |
|---|---|
| tuple[jtReal[jax.Array, 'max_num_neighbors num_parameters_mapped'], jtInt[jax.Array, 'max_num_neighbors num_parameters_mapped']] | Contracted derivative values with shape (max_num_neighbors, num_parameters_mapped) and the corresponding column (global parameter) indices with the same shape. Entries belonging to padded simplex slots have value 0 and column index -1. |