
Solver

main

This module contains the main classes used to solve eikonal equations.

Use the classes in this module as a simple interface to the methods implemented in this project.

fimjax.main.Mesh

Dataclass holding everything that belongs to the mesh. Because it is a Chex dataclass, it only supports keyword arguments. Use it like: mesh = Mesh(points=points, elements=elements)

N: number of elements in the mesh (e.g. triangles)
d_e: number of points per element
M: number of points in the mesh
d: dimension of the underlying space

Attributes:

points (np.ndarray): [M, d] array of points
elements (np.ndarray): [N, d_e] array of indices into points that correspond to the elements of the mesh
points_triangle (np.ndarray): [N, d_e, d] like elements, but with the point coordinates instead of indices
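The relation between the three attributes can be checked with plain NumPy. The sketch below builds a made-up mesh of two triangles on the unit square and derives points_triangle by fancy indexing; it does not use fimjax itself and assumes that points_triangle equals points[elements].

```python
import numpy as np

# Made-up mesh: M = 4 points in d = 2 dimensions (corners of the unit square).
points = np.array([[0.0, 0.0],
                   [1.0, 0.0],
                   [0.0, 1.0],
                   [1.0, 1.0]])

# N = 2 triangles, each with d_e = 3 point indices; they share the edge (1, 2).
elements = np.array([[0, 1, 2],
                     [1, 3, 2]])

# Fancy indexing replaces every index by its coordinates: shape [N, d_e, d].
points_triangle = points[elements]

print(points_triangle.shape)  # (2, 3, 2)
```

With fimjax available, the same arrays would be passed by keyword as Mesh(points=points, elements=elements).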

__init__

__init__(points: np.ndarray, elements: np.ndarray)

Constructor.

Parameters:

points (np.ndarray, required): [M, d] array of points
elements (np.ndarray, required): [N, d_e] array of indices into points that correspond to the elements of the mesh

fimjax.main.InitialValues

Initial values for an Eikonal PDE.

Attributes:

locations (np.ndarray): [X] array of indices into the mesh points that mark the initial value locations
values (np.ndarray): [X] array of the initial values at those locations
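A common way to use such initial values is to seed a solution vector with +inf everywhere except at the given locations, so that min-based updates can only lower the remaining entries. The sketch below assumes that convention; whether fimjax seeds its solution exactly this way is not stated here, and the arrays are hypothetical.

```python
import jax.numpy as jnp

M = 4  # number of mesh points in a hypothetical mesh

# Hypothetical initial condition: the front starts at point 0 with value 0.
locations = jnp.array([0])
values = jnp.array([0.0])

# Assumed seeding: unknown points start at +inf, known points get their values.
solution = jnp.full(M, jnp.inf).at[locations].set(values)
```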

fimjax.main.FIMSolution

Holds the result of a FIM solve.

Attributes:

solution (np.ndarray): the computed solution
iterations (int): number of iterations performed
has_converged (bool): whether the solution has converged
has_converged_after (int): number of iterations needed for convergence

fimjax.main.ITERATION_SCHEME

Bases: StrEnum

Iteration schemes that can be used for the solver function in FIM.

fimjax.main.Solver

Main class to interact with the Eikonal Solver.

Use this class as an object-oriented interface to the JAX code that computes solutions and parametric derivatives of the eikonal equation. Further information on the individual iteration schemes can be found in the documentation of the corresponding functions.

get_solver_function

get_solver_function(type: ITERATION_SCHEME = ITERATION_SCHEME.FOR, local_update_function: callable = None)

Returns the solver as a JAX function that can be composed with other JAX code.

Returns:

callable: the solver function.
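Because the returned solver is an ordinary JAX function, it can be wrapped in JAX transformations such as jax.jit, jax.vmap and jax.grad. The argument list of the returned function is not documented here, so the sketch uses a hypothetical toy_solver that maps a [N, d, d] metric field to one value per element; only the composition pattern is meant to carry over.

```python
import jax
import jax.numpy as jnp

def toy_solver(metrics):
    """Hypothetical stand-in for a solver: [N, d, d] metrics -> [N] traces."""
    return jnp.trace(metrics, axis1=1, axis2=2)

# Compose the plain JAX function with transformations.
batched_solver = jax.jit(jax.vmap(toy_solver))         # many metric fields at once
grad_solver = jax.grad(lambda m: toy_solver(m).sum())  # sensitivity w.r.t. metrics

base = jnp.tile(jnp.eye(2), (3, 1, 1))                 # N = 3 identity metrics, d = 2
metrics_batch = jnp.stack([base, 2.0 * base])          # [2, 3, 2, 2]

values = batched_solver(metrics_batch)   # shape (2, 3): traces 2.0 and 4.0
sensitivities = grad_solver(base)        # shape (3, 2, 2): identity per element
```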

solve

solve(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, iter: int, local_update_function: callable = _update_all_triangles) -> FIMSolution

Uses the FIM algorithm to solve an eikonal equation.

Parameters:

mesh (Mesh, required): the mesh
initial_values (InitialValues, required): initial values
metrics (np.ndarray, required): [N, d, d] metric tensor field
iter (int, required): number of iterations
local_update_function (callable, optional): function used for the local updates. Defaults to _update_all_triangles.

Returns:

FIMSolution: the solution object

value_and_vjp

value_and_vjp(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, iter: int, adjoint_vector: np.ndarray, local_update_function: callable = _update_all_triangles) -> np.ndarray

Calculates the value and the vector-Jacobian product (VJP) of the FIM solution.

Parameters:

mesh (Mesh, required): the mesh
initial_values (InitialValues, required): initial values
metrics (np.ndarray, required): [N, d, d] metric tensor field
iter (int, required): number of iterations
adjoint_vector (np.ndarray, required): adjoint vector
local_update_function (callable, optional): function used for the local updates. Defaults to _update_all_triangles.

Returns:

np.ndarray: adjoint vector
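For a solution u(θ) with M entries and an adjoint vector v of length M, the VJP is v^T J with J = du/dθ, which reverse-mode AD computes in one backward pass without forming J. The generic sketch below applies jax.vjp to a hypothetical toy_solve standing in for the FIM solve; it illustrates the concept and is not fimjax's implementation.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[1.0, 0.0],
               [1.0, 1.0],
               [0.0, 2.0]])  # fixed toy matrix, [M = 3, P = 2]

def toy_solve(params):
    """Hypothetical differentiable map from [P] parameters to an [M] solution."""
    return jnp.sqrt(A @ params)

params = jnp.array([4.0, 5.0])
value, vjp_fn = jax.vjp(toy_solve, params)   # forward pass, keeps a pullback
adjoint_vector = jnp.array([1.0, 0.0, 0.0])  # [M] cotangent
(cotangent_params,) = vjp_fn(adjoint_vector) # [P] = adjoint_vector^T @ J
```

Here value is sqrt([4, 9, 10]), and because the adjoint vector selects the first entry, cotangent_params is the first row of J, namely [1 / (2 * 2), 0].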

fimjax.main._update_all_triangles

_update_all_triangles(mesh: Mesh, D: np.ndarray, solution: np.ndarray) -> np.ndarray

Performs one Jacobi update.

Computes the candidate solution for every update direction and picks the smallest one for each point.

Parameters:

mesh (Mesh, required): the mesh
D (np.ndarray, required): [N, d, d] array with the metric tensor field
solution (np.ndarray, required): [M] array with the solution before the iteration

Returns:

np.ndarray: [M] new solution after one iteration
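The Jacobi structure (all candidates computed from the old solution, then a per-point minimum) can be sketched with a scatter-min. The hypothetical edge_jacobi_update below is a deliberate simplification: it only propagates along element edges, assumes the travel time along an edge e of element k is sqrt(e^T D_k e), omits the update across the element interior that a real triangle solver performs, and takes raw points and elements arrays instead of a Mesh.

```python
import jax.numpy as jnp

def edge_jacobi_update(points, elements, D, solution):
    """One simplified Jacobi sweep over all elements (hypothetical, edges only).

    points: [M, d], elements: [N, d_e], D: [N, d, d], solution: [M].
    """
    d_e = elements.shape[1]
    new = solution
    for a in range(d_e):          # vertex slot being updated
        for b in range(d_e):      # vertex slot the front arrives from
            if a == b:
                continue
            i, j = elements[:, a], elements[:, b]        # [N] point indices
            e = points[i] - points[j]                    # [N, d] edge vectors
            # Assumed edge travel time: sqrt(e^T D e) per element.
            cost = jnp.sqrt(jnp.einsum("nd,nde,ne->n", e, D, e))
            # Candidates read only the OLD solution, which makes this Jacobi.
            candidate = solution[j] + cost
            # Scatter-min keeps the smallest value per point, also across
            # elements that share a point.
            new = new.at[i].min(candidate)
    return new

points = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
elements = jnp.array([[0, 1, 2], [1, 3, 2]])
D = jnp.tile(jnp.eye(2), (2, 1, 1))               # isotropic unit metric
u0 = jnp.array([0.0, jnp.inf, jnp.inf, jnp.inf])  # source at point 0
u1 = edge_jacobi_update(points, elements, D, u0)  # reaches points 1 and 2
u2 = edge_jacobi_update(points, elements, D, u1)  # reaches point 3
```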

fimjax.main._compute_fim_for

_compute_fim_for(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, iters: int, local_update_function: callable) -> FIMSolution

Uses the Jacobi update with a fixed number of iterations to compute the FIM solution.

Uses checkpointing at every iteration to make automatic differentiation (AD) more memory-efficient.

Parameters:

mesh (Mesh, required): the mesh
initial_values (InitialValues, required): initial values, including their locations in the mesh
metrics (np.ndarray, required): [N, d, d] array corresponding to the metric tensor field
iters (int, required): number of iterations
local_update_function (callable, required): function that calculates the local updates, see _update_all_triangles for more information

Returns:

FIMSolution: the solution along with a flag indicating whether FIM has converged.
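A fixed iteration count lets the loop be written as a jax.lax.scan, and wrapping the loop body in jax.checkpoint is one standard way to checkpoint every iteration: reverse-mode AD then stores each iteration's input carry and recomputes the body's intermediates in the backward pass. The sketch below assumes this construction, which may differ from fimjax's; the 1D chain update with speed c is a hypothetical stand-in for local_update_function.

```python
import jax
import jax.numpy as jnp

def fixed_iteration_solve(update, u0, iters):
    """Apply `update` exactly `iters` times with per-iteration checkpointing."""
    @jax.checkpoint  # recompute the body's intermediates during the backward pass
    def step(u, _):
        return update(u), None

    u, _ = jax.lax.scan(step, u0, xs=None, length=iters)
    return u

def make_chain_update(c, h=1.0):
    """Hypothetical 1D update: point k is reached from point k-1 at cost h / c."""
    def update(u):
        from_left = jnp.concatenate([u[:1], u[:-1] + h / c])
        return jnp.minimum(u, from_left)
    return update

def final_time(c):
    u0 = jnp.array([0.0, jnp.inf, jnp.inf, jnp.inf, jnp.inf])
    u = fixed_iteration_solve(make_chain_update(c), u0, iters=4)
    return u[-1]

print(final_time(2.0))            # 2.0, i.e. 4 * (1 / 2)
print(jax.grad(final_time)(2.0))  # -1.0, i.e. d(4 / c)/dc at c = 2
```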

fimjax.main._compute_fim_checkpointed_while

_compute_fim_checkpointed_while(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, checkpoints: int, local_update_function: callable) -> FIMSolution

Uses the Jacobi update until convergence is reached.

Uses a checkpointed while loop from Equinox to make jitting and AD possible: XLA needs static memory bounds, which an unknown number of iterations normally rules out. For details on how this works, see the documentation of Equinox's checkpointed_while_loop.

Parameters:

mesh (Mesh, required): the mesh
initial_values (InitialValues, required): initial values, including their locations in the mesh
metrics (np.ndarray, required): [N, d, d] array corresponding to the metric tensor field
checkpoints (int, required): number of checkpoints used in the forward pass for AD
local_update_function (callable, required): function that calculates the local updates, see _update_all_triangles for more information

Returns:

FIMSolution: the solution object

fimjax.main._compute_fim_while

_compute_fim_while(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, local_update_function: callable) -> FIMSolution

Uses the Jacobi update until convergence is reached to compute the FIM solution.

Because XLA needs static memory bounds, this function is not jittable.

Parameters:

mesh (Mesh, required): the mesh
initial_values (InitialValues, required): initial values, including their locations in the mesh
metrics (np.ndarray, required): [N, d, d] array corresponding to the metric tensor field
local_update_function (callable, required): function that calculates the local updates, see _update_all_triangles for more information

Returns:

FIMSolution: the solution object
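Non-jittable behavior of this kind arises, for example, from a Python-level while loop whose convergence test converts a JAX array to a Python bool: under jax.jit that conversion needs a concrete value and fails on a tracer. The sketch below assumes that convergence means the update leaves the solution unchanged and adds a max_iters safety cap; both, and the hypothetical chain_update stand-in, are illustrative choices. The returned tuple mirrors the solution, iteration count, and convergence flag of FIMSolution.

```python
import jax.numpy as jnp

def chain_update(u, step=1.0):
    """Hypothetical 1D update: point k is reached from point k-1 at cost `step`."""
    from_left = jnp.concatenate([u[:1], u[:-1] + step])
    return jnp.minimum(u, from_left)

def solve_until_converged(update, u0, max_iters=1000):
    """Iterate `update` until the solution stops changing."""
    u, it = u0, 0
    while it < max_iters:
        u_new = update(u)
        it += 1
        # bool() on a JAX array needs a concrete value, so jax.jit cannot trace this.
        if bool(jnp.array_equal(u_new, u)):
            return u_new, it, True
        u = u_new
    return u, it, False

u0 = jnp.array([0.0, jnp.inf, jnp.inf, jnp.inf, jnp.inf])
solution, iterations, has_converged = solve_until_converged(chain_update, u0)
print(iterations, has_converged)  # 5 True
```

The fifth iteration changes nothing and only confirms convergence, which is why the count is one higher than the number of points the front had to reach.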

fimjax.main._compute_fim_fixed_point

_compute_fim_fixed_point(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, local_update_function: callable) -> np.ndarray

Iterates the Jacobi update until a fixed point is reached to compute the FIM solution.

Because XLA needs static memory bounds, this function is not jittable.

Parameters:

mesh (Mesh, required): the mesh
initial_values (InitialValues, required): initial values, including their locations in the mesh
metrics (np.ndarray, required): [N, d, d] array corresponding to the metric tensor field
local_update_function (callable, required): function that calculates the local updates, see _update_all_triangles for more information

Returns:

np.ndarray: the solution array