Parameter Tensor Field

tensorfield

Composable and differentiable parameter tensor fields.

This module provides ABCs and implementations for the creation of differentiable parameter fields used in Eikonax. Recall that for the Eikonax solver, and particularly for parametric derivatives, we require an input tensor field \(\mathbf{M}: \mathbb{R}^M \times \mathbb{N}_0 \to \mathbb{S}_+^{d\times d}\). This means that the tensor field is a mapping \(\mathbf{M}(\mathbf{m},s)\) that assigns, given a global parameter vector \(\mathbf{m}\), an s.p.d. tensor to every simplex \(s\) in the mesh. To allow for sufficient flexibility in the choice of tensor field, we implement it as a composition of two main components.

  1. AbstractVectorToSimplicesMap provides the interface for a mapping from the global parameter vector \(\mathbf{m}\) to the local parameter values \(\mathbf{m}_s\) required to assemble the tensor \(\mathbf{M}_s\) for simplex \(s\).
  2. AbstractSimplexTensor provides the interface for the assembly of the local tensor \(\mathbf{M}_s\), given the local contributions \(\mathbf{m}_s\) and a simplex \(s\).

Concrete implementations of both components are used to initialize the TensorField object, which vectorizes and differentiates them using JAX, to provide the mapping \(\mathbf{M}(\mathbf{m})\) and its Jacobian tensor \(\frac{d \mathbf{M}}{d \mathbf{m}}\).
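The composition described above can be illustrated with a minimal, stand-alone JAX sketch. It does not use the Eikonax classes; the function names `local_map`, `local_tensor`, and `assemble_field`, as well as the toy sizes, are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp

DIM = 2  # physical dimension d (illustrative choice)


def local_map(simplex_ind, parameters):
    # Vector-to-simplex map: pick the single parameter belonging to simplex s
    return parameters[simplex_ind]


def local_tensor(simplex_ind, local_parameter):
    # Simplex tensor: isotropic tensor m_s * I (the simplex index is unused here)
    return local_parameter * jnp.identity(DIM)


def assemble_field(parameters):
    # Compose both components and vectorize over all simplex indices
    simplex_inds = jnp.arange(parameters.shape[0])
    compose = lambda s: local_tensor(s, local_map(s, parameters))
    return jax.vmap(compose)(simplex_inds)


parameters = jnp.array([1.0, 2.0, 4.0])
field = assemble_field(parameters)                 # shape (3, 2, 2)
jacobian = jax.jacfwd(assemble_field)(parameters)  # shape (3, 2, 2, 3)
```

Here `jax.vmap` plays the role of the vectorization over simplices, and `jax.jacfwd` yields a dense Jacobian tensor of shape \(K \times d \times d \times M\). The actual TensorField may organize these steps differently.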

Classes:

Name Description
AbstractVectorToSimplicesMap

ABC interface contract for vector-to-simplices maps

LinearScalarMap

Simple one-to-one map from global to simplex parameters

AbstractSimplexTensor

ABC interface contract for assembly of the tensor field

LinearScalarSimplexTensor

SimplexTensor implementation relying on one parameter per simplex

InvLinearScalarSimplexTensor

SimplexTensor implementation relying on one parameter per simplex

TensorField

Tensor field component

eikonax.tensorfield.AbstractVectorToSimplicesMap

Bases: eqx.Module

ABC interface contract for vector-to-simplices maps.

Every component derived from this class needs to implement the map method, which returns the relevant parameters for a given simplex from the global parameter vector.

Note

Eikonax assumes that the mapping from global to local parameters is linear, so that a parametric derivative does not have to be provided.

Methods:

Name Description
map

Interface for vector-to-simplex mapping

map abstractmethod

map(simplex_ind: jtInt[jax.Array, ''], parameters: jtReal[jax.Array, num_parameters]) -> jtReal[jax.Array, ...]

Interface for vector-to-simplex mapping.

For the given simplex_ind, return those parameters from the global parameter vector that are relevant for the simplex. This method needs to be broadcastable over simplex_ind by JAX (with vmap).

Parameters:

Name Type Description Default
simplex_ind jax.Array

Index of the simplex under consideration

required
parameters jax.Array

Global parameter vector

required

Raises:

Type Description
NotImplementedError

ABC error indicating that the method needs to be implemented in subclasses

Returns:

Type Description
jtReal[jax.Array, ...]

jax.Array: Relevant parameters for the simplex
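To illustrate what "broadcastable over simplex_ind" can look like in practice, here is a small stand-alone sketch of a map with two parameters per simplex. The layout (simplex \(s\) owns entries \(2s\) and \(2s+1\)) and the name `two_parameter_map` are hypothetical. The function is written as a plain function rather than a subclass of AbstractVectorToSimplicesMap.

```python
import jax
import jax.numpy as jnp


def two_parameter_map(simplex_ind, parameters):
    # Hypothetical layout: simplex s owns entries 2s and 2s+1 of the global vector
    local_inds = 2 * simplex_ind + jnp.arange(2)
    return parameters[local_inds]


parameters = jnp.arange(6.0)   # global vector for 3 simplices
simplex_inds = jnp.arange(3)

# Broadcast over the simplex indices only; the global vector is shared
local_parameters = jax.vmap(two_parameter_map, in_axes=(0, None))(
    simplex_inds, parameters
)
```

Such a selection is linear in the global parameter vector, which is consistent with the linearity assumption stated in the note above.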

eikonax.tensorfield.LinearScalarMap

Bases: AbstractVectorToSimplicesMap

Simple one-to-one map from global to simplex parameters.

Every simplex takes exactly one parameter \(m_s\). These parameters are ordered in the global parameter vector in the same way as the simplices, meaning that \(m_s = \mathbf{m}[s]\).

map

map(simplex_ind: jtInt[jax.Array, ''], parameters: jtReal[jax.Array, num_parameters]) -> jtReal[jax.Array, '']

Return relevant parameters for a given simplex.

Parameters:

Name Type Description Default
simplex_ind jax.Array

Index of the simplex under consideration

required
parameters jax.Array

Global parameter vector

required

Returns:

Type Description
jtReal[jax.Array, '']

jax.Array: relevant parameter (only one)

eikonax.tensorfield.AbstractSimplexTensor

Bases: eqx.Module

ABC interface contract for assembly of the tensor field.

SimplexTensor components assemble the tensor field for a given simplex and a set of parameters for that simplex. The relevant parameters are provided by the VectorToSimplicesMap component from the global parameter vector.

Note

This class provides the metric tensor as used in the inner product for the update stencil of the eikonal equation. This is the INVERSE of the conductivity tensor, which is the actual tensor field in the eikonal equation.

Methods:

Name Description
assemble

Assemble the tensor field for a given simplex and parameters

derivative

Parametric derivative of the assemble method

__check_init__

__check_init__() -> None

Check that dimension is initialized correctly in subclasses.

assemble abstractmethod

assemble(simplex_ind: jtInt[jax.Array, ''], parameters: jtFloat[jax.Array, num_parameters_local]) -> jtFloat[jax.Array, 'dim dim']

Assemble the tensor field for given simplex and parameters.

Given a parameter array of size \(m_s\), the method returns a tensor of size \(d\times d\). The method needs to be broadcastable over simplex_ind by JAX (with vmap).

Parameters:

Name Type Description Default
simplex_ind jax.Array

Index of the simplex under consideration

required
parameters jax.Array

Parameters for the simplex

required

Raises:

Type Description
NotImplementedError

ABC error indicating that the method needs to be implemented in subclasses

Returns:

Type Description
jtFloat[jax.Array, 'dim dim']

jax.Array: Tensor field for the simplex under consideration

derivative abstractmethod

derivative(simplex_ind: jtInt[jax.Array, ''], parameters: jtFloat[jax.Array, num_parameters_local]) -> jtFloat[jax.Array, 'dim dim num_parameters_local']

Parametric derivative of the assemble method.

Given a parameter array of size \(m_s\), the method returns a Jacobian tensor of size \(d\times d\times m_s\). The method needs to be broadcastable over simplex_ind by JAX (with vmap).

Parameters:

Name Type Description Default
simplex_ind jax.Array

Index of the simplex under consideration

required
parameters jax.Array

Parameters for the simplex

required

Raises:

Type Description
NotImplementedError

ABC error indicating that the method needs to be implemented in subclasses

Returns:

Type Description
jtFloat[jax.Array, 'dim dim num_parameters_local']

jax.Array: Jacobian tensor for the simplex under consideration

eikonax.tensorfield.LinearScalarSimplexTensor

Bases: AbstractSimplexTensor

SimplexTensor implementation relying on one parameter per simplex.

Given a scalar parameter \(m_s\), the tensor field is assembled as \(m_s \cdot \mathbf{I}\), where \(\mathbf{I}\) is the identity matrix.
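As a sanity check of this formula, the following stand-alone JAX sketch assembles \(m_s \cdot \mathbf{I}\) and compares a hand-written derivative against automatic differentiation. The function names are illustrative and not the Eikonax implementation. The trailing axis of length one is an assumption based on the documented `dim dim num_parameters_local` return shape.

```python
import jax
import jax.numpy as jnp

DIM = 2  # physical dimension d (illustrative choice)


def assemble(parameter):
    # M_s = m_s * I
    return parameter * jnp.identity(DIM)


def derivative(_parameter):
    # d(m_s * I)/dm_s = I, stored with a trailing parameter axis of length 1
    return jnp.expand_dims(jnp.identity(DIM), axis=-1)


m_s = jnp.array(3.0)
tensor = assemble(m_s)                 # shape (2, 2)
analytic = derivative(m_s)             # shape (2, 2, 1)
autodiff = jax.jacfwd(assemble)(m_s)   # shape (2, 2) for a scalar input
```

Because the map is linear in \(m_s\), the derivative does not depend on the parameter value.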

Methods:

Name Description
assemble

Assemble the tensor field for a parameter vector

derivative

Parametric derivative of the assemble method

__init__

__init__(dimension: int) -> None

Constructor.

Parameters:

Name Type Description Default
dimension int

Dimension of the tensor field

required

assemble

assemble(_simplex_ind: jtInt[jax.Array, ''], parameters: jtFloat[jax.Array, '']) -> jtFloat[jax.Array, 'dim dim']

Assemble tensor for given simplex.

The parameters argument is a scalar here, and _simplex_ind is not used.

Parameters:

Name Type Description Default
_simplex_ind jax.Array

Index of simplex under consideration (not used)

required
parameters jax.Array

Parameter (scalar) for tensor assembly

required

Returns:

Type Description
jtFloat[jax.Array, 'dim dim']

jax.Array: Tensor for the simplex

derivative

derivative(_simplex_ind: jtInt[jax.Array, ''], _parameters: jtFloat[jax.Array, '']) -> jtFloat[jax.Array, 'dim dim num_parameters_local']

Parametric derivative of the assemble method.

Parameters:

Name Type Description Default
_simplex_ind jax.Array

Index of simplex under consideration (not used)

required
_parameters jax.Array

Parameter (scalar) for tensor assembly

required

Returns:

Type Description
jtFloat[jax.Array, 'dim dim num_parameters_local']

jax.Array: Jacobian tensor for the simplex under consideration

eikonax.tensorfield.InvLinearScalarSimplexTensor

Bases: AbstractSimplexTensor

SimplexTensor implementation relying on one parameter per simplex.

Given a scalar parameter \(m_s\), the tensor field is assembled as \(\frac{1}{m_s} \cdot \mathbf{I}\), where \(\mathbf{I}\) is the identity matrix.
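For reference, the parametric derivative of this tensor follows from the elementary rule for differentiating \(1/m_s\). This worked step is added for illustration and is valid for \(m_s \neq 0\):

\[
\frac{d}{d m_s}\left(\frac{1}{m_s}\,\mathbf{I}\right) = -\frac{1}{m_s^2}\,\mathbf{I}
\]

For example, \(m_s = 2\) gives the tensor \(0.5\,\mathbf{I}\) with derivative \(-0.25\,\mathbf{I}\). Since the tensor must be s.p.d., this parametrization presumably expects \(m_s > 0\).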

Methods:

Name Description
assemble

Assemble the tensor field for a parameter vector

derivative

Parametric derivative of the assemble method

__init__

__init__(dimension: int) -> None

Constructor.

Parameters:

Name Type Description Default
dimension int

Dimension of the tensor field

required

assemble

assemble(_simplex_ind: jtInt[jax.Array, ''], parameters: jtFloat[jax.Array, '']) -> jtFloat[jax.Array, 'dim dim']

Assemble tensor for given simplex.

The parameters argument is a scalar here, and _simplex_ind is not used.

Parameters:

Name Type Description Default
_simplex_ind jax.Array

Index of simplex under consideration (not used)

required
parameters jax.Array

Parameter (scalar) for tensor assembly

required

Returns:

Type Description
jtFloat[jax.Array, 'dim dim']

jax.Array: Tensor for the simplex

derivative

derivative(_simplex_ind: jtInt[jax.Array, ''], parameters: jtFloat[jax.Array, '']) -> jtFloat[jax.Array, 'dim dim num_parameters_local']

Parametric derivative of the assemble method.

Parameters:

Name Type Description Default
_simplex_ind jax.Array

Index of simplex under consideration (not used)

required
parameters jax.Array

Parameter (scalar) for tensor assembly

required

Returns:

Type Description
jtFloat[jax.Array, 'dim dim num_parameters_local']

jax.Array: Jacobian tensor for the simplex under consideration

eikonax.tensorfield.TensorField

Bases: eqx.Module

Tensor field component.

Tensor fields combine the functionality of vector-to-simplices maps and simplex tensors according to the composition over inheritance principle. They constitute the full mapping \(\mathbf{M}(\mathbf{m})\) from the global parameter vector to the tensor field over all mesh faces (simplices). In addition, they provide the parametric derivative \(\frac{d\mathbf{M}}{d\mathbf{m}}\) of that mapping, and assemble the full parameter-to-solution partial Jacobian \(\mathbf{G}_m\) from a given partial derivative of the solution vector w.r.t. the tensor field \(\mathbf{G}_M\). This introduces some degree of coupling to the Eikonax solver, but is the simplest interface for computation of the total derivative according to the chain rule. More detailed explanations are given in the assemble_jacobian method.

Methods:

Name Description
assemble_field

Assemble the tensor field for the given parameter vector

assemble_jacobian

Assemble the parametric derivative of a solution vector for a given parameter vector and derivative of the solution vector w.r.t. the tensor field

__init__

__init__(num_simplices: int, vector_to_simplices_map: AbstractVectorToSimplicesMap, simplex_tensor: AbstractSimplexTensor) -> None

Constructor.

Takes information about the mesh simplices, a vector-to-simplices map, and a simplex tensor map.

Parameters:

Name Type Description Default
num_simplices int

Number of simplices in the mesh

required
vector_to_simplices_map AbstractVectorToSimplicesMap

Mapping from global to simplex parameters

required
simplex_tensor AbstractSimplexTensor

Tensor field assembly for a given simplex

required

assemble_field

assemble_field(parameter_vector: jtFloat[jax.Array | npt.NDArray, num_parameters_global]) -> jtFloat[jax.Array, 'num_simplex dim dim']

Assemble global tensor field from global parameter vector.

This method simply chains calls to the vector-to-simplices map and the simplex tensor objects, vectorized over all simplices.

Parameters:

Name Type Description Default
parameter_vector jax.Array | npt.NDArray

Global parameter vector

required

Returns:

Type Description
jtFloat[jax.Array, 'num_simplex dim dim']

jax.Array: Global tensor field

assemble_jacobian

assemble_jacobian(number_of_vertices: int, derivative_solution_tensor: tuple[jtInt[jax.Array, num_values], jtInt[jax.Array, num_values], jtFloat[jax.Array, 'num_values dim dim']], parameter_vector: jtFloat[jax.Array | npt.NDArray, num_parameters_global]) -> sp.sparse.coo_matrix

Assemble partial derivative of the Eikonax solution vector w.r.t. parameters.

The total derivative of the Update operator w.r.t. the global parameter vector is given by the chain rule of differentiation, \(\mathbf{G}_m = \mathbf{G}_M \frac{d\mathbf{M}}{d\mathbf{m}}\). The Eikonax PartialDerivator component evaluates the derivative of the solution vector w.r.t. the tensor field. The tensor field assembles the Jacobian tensor of the tensor field w.r.t. the global parameter vector, and chains it with the solution-to-tensor derivative in a vectorized form. All computations are done in a sparse matrix format.

Consider a given solution-to-tensor derivative \(\mathbf{G}_M\) of shape \(N \times K \times d \times d\), where \(N\) is the number of vertices, \(K\) is the number of simplices, and \(d\) is the physical dimension of the tensor field. This method internally assembles the tensor-to-parameter derivative \(\frac{d\mathbf{M}}{d\mathbf{m}}\) of shape \(K \times d \times d \times M\), where \(M\) is the number of parameters. The total derivative is then computed as a tensor product of \(\mathbf{G}_M\) and \(\frac{d\mathbf{M}}{d\mathbf{m}}\) over their last three and first three dimensions, respectively. The output is a sparse matrix of shape \(N \times M\), returned as a SciPy COO matrix.

The assembly is rather involved, so we handle it internally in this component, at the expense of introducing some additional coupling to the Eikonax Derivator.

Will be changed

The PartialDerivator returns a compressed representation of \(\mathbf{G}_M\), which is hard to handle with standardized tensor product operations. Reducing the compression might allow for a more transparent interface at this point.
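The contraction behind the chain rule can be written out densely for a toy problem. The sketch below is a stand-alone illustration in JAX with made-up sizes and values. It assumes one scalar parameter per simplex with \(\mathbf{M}_s = m_s \cdot \mathbf{I}\), and it does not reproduce the sparse, compressed format that Eikonax uses internally.

```python
import jax.numpy as jnp

N, K, D = 4, 3, 2   # vertices, simplices, physical dimension (toy sizes)
M = K               # one parameter per simplex in this illustration

# Dense stand-in for the solution-to-tensor derivative G_M, shape (N, K, D, D)
G_M = jnp.arange(N * K * D * D, dtype=jnp.float32).reshape(N, K, D, D)

# dM/dm for M_s = m_s * I: entry [k, i, j, p] equals delta_kp * delta_ij
dM_dm = jnp.einsum("kp,ij->kijp", jnp.identity(K), jnp.identity(D))

# Contract the last three axes of G_M with the first three axes of dM/dm
G_m = jnp.einsum("nkij,kijp->np", G_M, dM_dm)   # shape (N, M)
```

In this special case each entry of the result reduces to the trace of the corresponding \(d \times d\) block of \(\mathbf{G}_M\), which gives a quick way to check the contraction.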

Parameters:

Name Type Description Default
number_of_vertices int

Number of vertices in the mesh

required
derivative_solution_tensor tuple[jax.Array, jax.Array, jax.Array]

Solution-tensor derivative of shape N x K x d x d. Provided as a tuple of row indices, simplex indices, and values, already in sparsified format. The row indices are the indices of the relevant vertices, and can be seen as one half of the index set of the resulting sparse matrix. For each row index, the corresponding simplex index indicates the simplex whose tensor values influence the solution at that vertex by means of the derivative.

required
parameter_vector jax.Array

Global parameter vector

required

Returns:

Type Description
sp.sparse.coo_matrix

sp.sparse.coo_matrix: Sparse derivative of the Eikonax solution vector w.r.t. the global parameter vector, of shape N x M
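To make the sparsified input and the COO output more concrete, here is a rough stand-alone sketch using NumPy and SciPy. It assumes one scalar parameter per simplex with \(\mathbf{M}_s = m_s \cdot \mathbf{I}\), so that the column index equals the simplex index. All index and value arrays are made up, and the actual internal assembly in Eikonax may differ.

```python
import numpy as np
from scipy import sparse

number_of_vertices, num_parameters, dim = 4, 3, 2

# Sparsified solution-tensor derivative: (row indices, simplex indices, values)
row_inds = np.array([0, 1, 1, 3])
simplex_inds = np.array([0, 0, 2, 1])
values = np.stack([k * np.ones((dim, dim)) for k in (1.0, 2.0, 3.0, 4.0)])

# For M_s = m_s * I, contracting a d x d block with dM_s/dm_s = I gives its trace
jacobian_values = np.trace(values, axis1=1, axis2=2)
column_inds = simplex_inds  # one parameter per simplex: column = simplex index

jacobian = sparse.coo_matrix(
    (jacobian_values, (row_inds, column_inds)),
    shape=(number_of_vertices, num_parameters),
)
```

The COO format keeps duplicate (row, column) pairs as separate entries and sums them on conversion, for instance via `jacobian.toarray()` or `jacobian.tocsr()`.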

_assemble_jacobian

_assemble_jacobian(simplex_inds: jtFloat[jax.Array, num_values], derivative_solution_tensor_values: jtFloat[jax.Array, num_values], parameter_vector: jtFloat[jax.Array, num_parameters_global]) -> tuple[jtFloat[jax.Array, ...], jtInt[jax.Array, ...]]

Compute the partial derivative \(\frac{d\mathbf{M}}{d\mathbf{m}}\).

Simplex-level derivatives are computed for all provided simplex_inds to match the solution-tensor derivatives obtained from the Eikonax derivator.

Parameters:

Name Type Description Default
simplex_inds jax.Array

Indices of simplices under consideration

required
derivative_solution_tensor_values jax.Array

Solution-tensor derivative values

required
parameter_vector jax.Array

Global parameter vector

required

Returns:

Type Description
tuple[jtFloat[jax.Array, ...], jtInt[jax.Array, ...]]

tuple[jax.Array, jax.Array]: Values and column indices of the Jacobian