
Tensor Algebra

linalg

Linear algebra utilities and sparse data structures for Eikonax.

This module provides specialized sparse data structures and utilities tailored to the Eikonax solver's derivative computation workflow. Since the derivative operators are inherently sparse due to the local nature of the Godunov update scheme, these structures efficiently store only the non-zero entries along with their connectivity information.

The module defines three main sparse container classes:

  1. EikonaxSparseMatrix: A COO-like (coordinate format) sparse matrix representation compatible with JAX arrays. Stores row indices, column indices, and values separately along with the matrix shape.

  2. DerivatorSparseTensor: Specialized tensor for storing partial derivatives \(\mathbf{G}_M\) computed by the PartialDerivator. Pairs derivative values with simplex adjacency information for each vertex.

  3. TensorfieldSparseTensor: Container for the Jacobian \(\frac{d\mathbf{M}}{d\mathbf{m}}\) computed by the TensorField. Stores derivative values together with parameter indices that indicate which global parameters each simplex depends on.

The key operation in this module is the tensor contraction contract_derivative_tensors, which efficiently combines the derivatives from the Eikonax solver and the tensor field to compute the total parametric derivative \(\mathbf{G}_m = \mathbf{G}_M \frac{d\mathbf{M}}{d\mathbf{m}}\) required for gradient computation.
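Written entrywise, the contraction sums over every simplex adjacent to a vertex and over both metric-tensor indices. The notation here is introduced for illustration: \(\mathcal{S}(i)\) is the set of simplices adjacent to vertex \(i\), \(\mathbf{M}_s\) is the metric tensor of simplex \(s\), and \(p, q\) index its \(\mathrm{dim} \times \mathrm{dim}\) entries.

\[
\left(\mathbf{G}_m\right)_{ik} = \sum_{s \in \mathcal{S}(i)} \sum_{p,q=1}^{\mathrm{dim}} \left(\mathbf{G}_M\right)_{i,\,s,\,pq} \, \frac{d \left(\mathbf{M}_s\right)_{pq}}{d m_k}
\]

Only simplices adjacent to vertex \(i\) contribute, and within each of them only the parameters that simplex depends on give non-zero terms, which is why both factors can be stored sparsely.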

Classes:

| Name | Description |
| --- | --- |
| EikonaxSparseMatrix | COO-style sparse matrix for JAX-compatible derivative storage |
| DerivatorSparseTensor | Sparse tensor for partial derivatives w.r.t. metric tensors |
| TensorfieldSparseTensor | Sparse tensor for the tensor field Jacobian |

Functions:

| Name | Description |
| --- | --- |
| convert_to_scipy_sparse | Convert EikonaxSparseMatrix to SciPy sparse array |
| contract_derivative_tensors | Contract DerivatorSparseTensor with TensorfieldSparseTensor |

eikonax.linalg.EikonaxSparseMatrix

Bases: eqx.Module

COO-style sparse matrix representation compatible with JAX.

This class stores a sparse matrix in coordinate (COO) format using JAX arrays. Unlike standard SciPy sparse matrices, this representation is JAX-compatible: it can be used inside JIT-compiled functions and with automatic differentiation. The matrix is defined by three parallel arrays storing row indices, column indices, and corresponding values, along with the matrix shape.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| row_inds | jax.Array | Row indices of non-zero entries with shape (num_entries,). |
| col_inds | jax.Array | Column indices of non-zero entries with shape (num_entries,). |
| values | jax.Array | Values of non-zero entries with shape (num_entries,). |
| shape | tuple[int, int] | Shape of the matrix (num_rows, num_cols). |
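To make the COO semantics concrete, the sketch below uses a hypothetical NumPy `NamedTuple` (`CooSketch`) with the same four fields in place of the real `eqx.Module`, and scatters the triplets into a dense array. Repeated (row, col) pairs are summed, which is consistent with the duplicate-summing step described for convert_to_scipy_sparse.

```python
from typing import NamedTuple

import numpy as np


class CooSketch(NamedTuple):
    """Hypothetical NumPy stand-in for EikonaxSparseMatrix (not the library class)."""

    row_inds: np.ndarray  # shape (num_entries,)
    col_inds: np.ndarray  # shape (num_entries,)
    values: np.ndarray  # shape (num_entries,)
    shape: tuple[int, int]  # (num_rows, num_cols)


def to_dense(matrix: CooSketch) -> np.ndarray:
    """Scatter the COO triplets into a dense array, summing repeated (row, col) pairs."""
    dense = np.zeros(matrix.shape, dtype=matrix.values.dtype)
    # np.add.at accumulates unbuffered, so duplicate coordinates are summed, not overwritten
    np.add.at(dense, (matrix.row_inds, matrix.col_inds), matrix.values)
    return dense


# Entries (1, 0) appear twice and are summed in the dense result
example = CooSketch(
    row_inds=np.array([0, 1, 1, 1]),
    col_inds=np.array([2, 0, 0, 1]),
    values=np.array([5.0, 1.0, 2.0, 4.0]),
    shape=(2, 3),
)
dense_example = to_dense(example)
```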

eikonax.linalg.DerivatorSparseTensor

Bases: eqx.Module

Sparse tensor for partial derivatives from the PartialDerivator.

This container stores the partial derivatives \(\mathbf{G}_M\) of the Eikonax update operator with respect to the metric tensor field. For each vertex, it stores derivative contributions from all adjacent simplices in a dense local format, along with indices identifying which simplices are adjacent. This structure efficiently represents the sparse global tensor while maintaining JAX compatibility.

The derivative values capture how changes in the metric tensor of each adjacent simplex affect the update value at each vertex. The adjacency data maps these local contributions to the global simplex numbering.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| derivative_values | jax.Array | Partial derivative tensors with shape (num_vertices, max_num_neighbors, dim, dim). For each vertex and adjacent simplex, stores the (dim x dim) derivative of the update w.r.t. that simplex's metric tensor. |
| adjacent_simplex_data | jax.Array | Global simplex indices with shape (num_vertices, max_num_neighbors). Maps local neighbor index to global simplex index. Entries of -1 indicate padding for vertices with fewer than max_num_neighbors adjacent simplices. |
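The padded layout of adjacent_simplex_data can be illustrated by building it from a list of simplices. The helper `build_padded_adjacency` below is hypothetical; Eikonax's own mesh preprocessing may construct this array differently, but the resulting layout (one row per vertex, -1 padding up to max_num_neighbors) matches the description above.

```python
import numpy as np


def build_padded_adjacency(simplices: np.ndarray, num_vertices: int) -> np.ndarray:
    """Hypothetical helper: map each vertex to its adjacent simplices, padded with -1.

    simplices has shape (num_simplices, dim + 1) and lists the vertex indices of each simplex.
    Returns an int array of shape (num_vertices, max_num_neighbors).
    """
    neighbors = [[] for _ in range(num_vertices)]
    for simplex_ind, vertex_inds in enumerate(simplices):
        for vertex_ind in vertex_inds:
            neighbors[int(vertex_ind)].append(simplex_ind)
    max_num_neighbors = max(len(n) for n in neighbors)
    # Start from all-padding, then fill each row with the real simplex indices
    adjacency = np.full((num_vertices, max_num_neighbors), -1, dtype=np.int64)
    for vertex_ind, simplex_inds in enumerate(neighbors):
        adjacency[vertex_ind, : len(simplex_inds)] = simplex_inds
    return adjacency


# Two triangles sharing the edge (1, 2) on a four-vertex mesh
triangles = np.array([[0, 1, 2], [1, 3, 2]])
adjacency = build_padded_adjacency(triangles, num_vertices=4)
```

Vertices 0 and 3 touch only one triangle, so their second slot is padding.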

eikonax.linalg.TensorfieldSparseTensor

Bases: eqx.Module

Sparse tensor for the tensor field Jacobian.

This container stores the Jacobian \(\frac{d\mathbf{M}}{d\mathbf{m}}\) computed by the TensorField component. For each simplex, it stores how the metric tensor depends on the global parameters, along with indices indicating which global parameters affect each simplex.

Since most tensor field parameterizations are local (each simplex depends on only a small subset of global parameters), this sparse representation is much more memory-efficient than storing the full dense Jacobian tensor.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| derivative_values | jax.Array | Jacobian values with shape (num_simplices, dim, dim, num_parameters_mapped). For each simplex, stores how the (dim x dim) metric tensor changes with respect to the relevant parameters. |
| parameter_inds | jax.Array | Global parameter indices with shape (num_simplices, num_parameters_mapped). Maps local parameter index to global parameter index for each simplex. |
| num_parameters_global | int | Total number of global parameters in the full parameter vector. Used to determine output matrix dimensions in contractions. |
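As an illustration of the memory argument, consider a hypothetical parameterization with one scalar parameter per simplex and an isotropic metric \(\mathbf{M}_s = m_s \mathbf{I}\) (an assumption for this sketch, not necessarily one of Eikonax's tensor fields). Each simplex then maps exactly one parameter (num_parameters_mapped = 1), its local Jacobian is the identity, and the sparse storage is smaller than the dense Jacobian by a factor of num_parameters_global.

```python
import numpy as np

num_simplices, dim = 1000, 2
num_parameters_global = num_simplices  # one scalar parameter per simplex (assumption)

# dM_s / dm_s = I for the hypothetical isotropic field M_s = m_s * I
derivative_values = np.broadcast_to(
    np.eye(dim)[None, :, :, None], (num_simplices, dim, dim, 1)
).copy()
# Each simplex depends only on its own parameter
parameter_inds = np.arange(num_simplices)[:, None]  # shape (num_simplices, 1)

# Stored values: sparse layout vs. a dense (num_simplices, dim, dim, num_parameters_global) Jacobian
sparse_entries = derivative_values.size
dense_entries = num_simplices * dim * dim * num_parameters_global
```

With 1000 simplices in 2D, the sparse form holds 4,000 values versus 4,000,000 for the dense Jacobian.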

eikonax.linalg.convert_to_scipy_sparse

convert_to_scipy_sparse(eikonax_sparse_matrix: EikonaxSparseMatrix) -> sp.coo_array

Convert EikonaxSparseMatrix to SciPy sparse COO array.

This function transforms the JAX-compatible EikonaxSparseMatrix format into a standard SciPy coo_array. The conversion involves:

  1. Extracting row indices, column indices, and values as NumPy arrays
  2. Filtering out zero entries to reduce memory footprint
  3. Constructing a SciPy COO array and summing duplicate entries

This conversion is needed when interfacing with SciPy's sparse linear algebra routines, for example in the DerivativeSolver, which requires CSC format for its triangular solve.
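The three conversion steps can be sketched with NumPy alone. The helper name `filter_and_sum_duplicates` is hypothetical; the library function instead hands the filtered arrays to `scipy.sparse.coo_array` and lets SciPy sum duplicates. A resulting `coo_array` can be converted with its `tocsc()` method when CSC format is needed.

```python
import numpy as np


def filter_and_sum_duplicates(row_inds, col_inds, values, shape):
    """NumPy sketch of the conversion steps (the real function builds a SciPy coo_array)."""
    # 1. Pull the (possibly JAX) arrays back to host NumPy arrays
    rows = np.asarray(row_inds)
    cols = np.asarray(col_inds)
    vals = np.asarray(values)
    # 2. Drop explicit zeros, e.g. masked-out contributions
    nonzero = vals != 0
    rows, cols, vals = rows[nonzero], cols[nonzero], vals[nonzero]
    # 3. Sum entries that share a (row, col) coordinate
    linear_inds = np.ravel_multi_index((rows, cols), shape)
    unique_inds, inverse = np.unique(linear_inds, return_inverse=True)
    summed = np.bincount(inverse, weights=vals, minlength=unique_inds.size)
    unique_rows, unique_cols = np.unravel_index(unique_inds, shape)
    return unique_rows, unique_cols, summed


# (0, 1) appears twice and is summed; the explicit zero at (1, 0) is dropped
rows, cols, vals = filter_and_sum_duplicates(
    row_inds=[0, 0, 1, 1], col_inds=[1, 1, 0, 2], values=[2.0, 3.0, 0.0, 4.0], shape=(2, 3)
)
```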

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| eikonax_sparse_matrix | EikonaxSparseMatrix | Sparse matrix in JAX-compatible format. | required |

Returns:

| Type | Description |
| --- | --- |
| sp.coo_array | SciPy sparse array in COO format with zeros removed and duplicate entries summed. |

eikonax.linalg.contract_derivative_tensors

contract_derivative_tensors(
    derivative_sparse_tensor: DerivatorSparseTensor,
    tensorfield_sparse_tensor: TensorfieldSparseTensor,
) -> EikonaxSparseMatrix

Contract derivative tensors to compute total parametric derivative.

This function performs the key operation in Eikonax's gradient computation pipeline: combining the partial derivatives from the solver (\(\mathbf{G}_M\)) with the tensor field Jacobian (\(\frac{d\mathbf{M}}{d\mathbf{m}}\)) to obtain the total parametric derivative \(\mathbf{G}_m = \mathbf{G}_M \frac{d\mathbf{M}}{d\mathbf{m}}\).

The contraction is performed efficiently by exploiting the sparse structure of both inputs:

  1. For each vertex, iterate over adjacent simplices
  2. Extract the (dim x dim) derivative matrix for that vertex-simplex pair from derivative_sparse_tensor
  3. Contract (via Einstein summation) with the corresponding (dim x dim x num_params) Jacobian tensor from tensorfield_sparse_tensor
  4. Map the result to the appropriate global parameter indices
  5. Assemble all contributions into a sparse matrix of shape (num_vertices, num_parameters_global)

The operation is vectorized over all vertices using JAX's vmap and JIT-compiled for performance.
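The five steps above can be sketched in NumPy, with broadcasting over a leading vertex axis standing in for `jax.vmap`. The function name `contract_sketch` is hypothetical, and dropping padded slots before returning (rather than keeping zero-valued entries with column -1) is a choice made in this sketch, not necessarily what the library does. The example also checks the sparse result against a dense contraction of the fully scattered tensors.

```python
import numpy as np


def contract_sketch(derivative_values, adjacent_simplex_data, tensorfield_values, parameter_inds):
    """Batched NumPy sketch of the sparse contraction (hypothetical, not the Eikonax code).

    derivative_values:     (num_vertices, max_num_neighbors, dim, dim)
    adjacent_simplex_data: (num_vertices, max_num_neighbors), -1 marks padding
    tensorfield_values:    (num_simplices, dim, dim, num_parameters_mapped)
    parameter_inds:        (num_simplices, num_parameters_mapped)
    Returns COO triplets (rows, cols, values).
    """
    valid = adjacent_simplex_data >= 0
    # Replace padding with simplex 0 so the gather stays in bounds; masked out below
    safe_inds = np.where(valid, adjacent_simplex_data, 0)
    jacobians = tensorfield_values[safe_inds]  # (V, N, dim, dim, P)
    # Steps 2-3: result[v, n, k] = sum_ij G[v, n, i, j] * J[v, n, i, j, k]
    contracted = np.einsum("vnij,vnijk->vnk", derivative_values, jacobians)
    # Step 4: global parameter (column) index of every contracted entry
    cols = parameter_inds[safe_inds]  # (V, N, P)
    rows = np.broadcast_to(np.arange(cols.shape[0])[:, None, None], cols.shape)
    # Step 5: flatten to COO, dropping padded neighbor slots
    keep = np.broadcast_to(valid[:, :, None], cols.shape)
    return rows[keep], cols[keep], contracted[keep]


# Small example: 3 vertices, 2 simplices, dim 2, 2 mapped and 4 global parameters
rng = np.random.default_rng(0)
G = rng.standard_normal((3, 2, 2, 2))
adjacency = np.array([[0, -1], [0, 1], [1, -1]])
J = rng.standard_normal((2, 2, 2, 2))
param_inds = np.array([[0, 1], [1, 3]])  # simplices 0 and 1 share parameter 1
rows, cols, vals = contract_sketch(G, adjacency, J, param_inds)

# Assemble into a dense (num_vertices, num_parameters_global) array, summing duplicates
G_m = np.zeros((3, 4))
np.add.at(G_m, (rows, cols), vals)

# Dense reference: scatter both factors to full size and contract over (simplex, i, j)
G_dense = np.zeros((3, 2, 2, 2))  # (vertex, simplex, dim, dim)
for v, n in zip(*np.nonzero(adjacency >= 0)):
    G_dense[v, adjacency[v, n]] += G[v, n]
J_dense = np.zeros((2, 2, 2, 4))  # (simplex, dim, dim, num_parameters_global)
for s in range(2):
    for j in range(2):
        J_dense[s, :, :, param_inds[s, j]] += J[s, :, :, j]
G_m_reference = np.einsum("vsij,sijk->vk", G_dense, J_dense)
```

Because two simplices adjacent to the same vertex can depend on the same global parameter, the assembled matrix has to sum duplicate (row, column) entries, which is why `np.add.at` is used here rather than plain assignment.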

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| derivative_sparse_tensor | DerivatorSparseTensor | Partial derivatives \(\mathbf{G}_M\) from the PartialDerivator. | required |
| tensorfield_sparse_tensor | TensorfieldSparseTensor | Tensor field Jacobian \(\frac{d\mathbf{M}}{d\mathbf{m}}\) from the TensorField. | required |

Returns:

| Type | Description |
| --- | --- |
| EikonaxSparseMatrix | Total parametric derivative \(\mathbf{G}_m\) with shape (num_vertices, num_parameters_global). This can be transposed and multiplied with the adjoint vector to compute parametric gradients. |
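Computing a gradient from the returned matrix only needs a transposed matrix-vector product, which can be evaluated directly on the COO triplets without forming the transpose. The sketch assumes the gradient has the form \(\mathbf{G}_m^T \boldsymbol{\lambda}\) for an adjoint vector \(\boldsymbol{\lambda}\) (symbol introduced here); the exact adjoint equation and any sign convention are determined by the solver. `transpose_matvec` is a hypothetical helper.

```python
import numpy as np


def transpose_matvec(row_inds, col_inds, values, num_cols, adjoint):
    """Compute G_m^T @ adjoint directly from COO triplets (hypothetical helper)."""
    gradient = np.zeros(num_cols, dtype=np.result_type(values, adjoint))
    # Each stored entry (i, k, g) contributes g * adjoint[i] to gradient[k]
    np.add.at(gradient, col_inds, values * adjoint[row_inds])
    return gradient


rows = np.array([0, 1, 1, 2])
cols = np.array([0, 0, 2, 1])
vals = np.array([1.0, 2.0, 3.0, 4.0])
adjoint = np.array([1.0, 10.0, 100.0])
grad = transpose_matvec(rows, cols, vals, num_cols=3, adjoint=adjoint)

# Dense cross-check of the same product
dense = np.zeros((3, 3))
np.add.at(dense, (rows, cols), vals)
grad_reference = dense.T @ adjoint
```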

eikonax.linalg._contract_vertex_tensors

_contract_vertex_tensors(
    derivator_tensor: jtReal[jax.Array, "max_num_neighbors dim dim"],
    adjacent_simplex_data: jtInt[jax.Array, "max_num_neighbors"],
    tensorfield_data: jtReal[jax.Array, "num_simplices dim dim num_parameters_mapped"],
    parameter_inds: jtInt[jax.Array, "num_simplices num_parameters_mapped"],
) -> tuple[
    jtReal[jax.Array, "max_num_neighbors num_parameters_mapped"],
    jtInt[jax.Array, "max_num_neighbors num_parameters_mapped"],
]

Contract derivatives for a single vertex with tensor field Jacobian.

This helper function performs the tensor contraction for a single vertex. For each adjacent simplex, it:

  1. Extracts the (dim x dim) derivative matrix from the derivator tensor
  2. Extracts the (dim x dim x num_params) Jacobian tensor for that simplex
  3. Contracts them via Einstein summation: result[k] = sum_ij derivator[i,j] * jacobian[i,j,k]
  4. Maps the result to the global parameter indices for that simplex

Padded neighbor slots (adjacent simplex index of -1) are handled by masking: their contributions are set to zero and their column indices to -1.

This function is called in a vectorized manner by contract_derivative_tensors via jax.vmap.
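A loop-based NumPy sketch of the per-vertex logic follows. It is not the library function: `contract_single_vertex` is a hypothetical name, and the Python `if` on the simplex index only works outside of tracing. Under `jax.jit` and `jax.vmap`, the same masking has to be expressed without data-dependent branching, for example with `jnp.where`.

```python
import numpy as np


def contract_single_vertex(derivator_tensor, adjacent_simplex_data, tensorfield_data, parameter_inds):
    """Loop-based NumPy sketch of the per-vertex contraction (hypothetical helper).

    Returns (values, cols), both of shape (max_num_neighbors, num_parameters_mapped);
    rows for padded neighbors hold 0.0 values and -1 column indices.
    """
    max_num_neighbors = derivator_tensor.shape[0]
    num_parameters_mapped = tensorfield_data.shape[-1]
    values = np.zeros(
        (max_num_neighbors, num_parameters_mapped),
        dtype=np.result_type(derivator_tensor, tensorfield_data),
    )
    cols = np.full((max_num_neighbors, num_parameters_mapped), -1, dtype=np.int64)
    for n, simplex_ind in enumerate(adjacent_simplex_data):
        if simplex_ind == -1:
            continue  # padding slot: keep zero values and -1 column indices
        # result[k] = sum_ij derivator[i, j] * jacobian[i, j, k]
        values[n] = np.einsum("ij,ijk->k", derivator_tensor[n], tensorfield_data[simplex_ind])
        # Map local parameter slots to global parameter indices
        cols[n] = parameter_inds[simplex_ind]
    return values, cols


# One real neighbor (simplex 1) and one padded slot
derivator = np.array([[[1.0, 0.0], [0.0, 2.0]], [[0.0, 0.0], [0.0, 0.0]]])
adjacent = np.array([1, -1])
tensorfield = np.zeros((2, 2, 2, 1))
tensorfield[1, :, :, 0] = np.array([[3.0, 5.0], [7.0, 11.0]])
param_inds = np.array([[0], [4]])
values, cols = contract_single_vertex(derivator, adjacent, tensorfield, param_inds)
```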

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| derivator_tensor | jax.Array | Partial derivatives for one vertex with shape (max_num_neighbors, dim, dim). | required |
| adjacent_simplex_data | jax.Array | Global simplex indices for one vertex with shape (max_num_neighbors,). Entries of -1 indicate padding. | required |
| tensorfield_data | jax.Array | Global tensor field Jacobian with shape (num_simplices, dim, dim, num_parameters_mapped). | required |
| parameter_inds | jax.Array | Global parameter indices for all simplices with shape (num_simplices, num_parameters_mapped). | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[jtReal[jax.Array, 'max_num_neighbors num_parameters_mapped'], jtInt[jax.Array, 'max_num_neighbors num_parameters_mapped']] | tuple[jax.Array, jax.Array]: Contracted derivative values with shape (max_num_neighbors, num_parameters_mapped) and corresponding column indices with the same shape. Entries corresponding to invalid simplices are zeroed/set to -1. |