Solver
main
This module contains the main classes for solving eikonal equations.
They are meant to be used as an easy interface to the methods implemented in this project.
fimjax.main.Mesh
Dataclass holding everything that belongs to the mesh.
Note that, because of Chex, it only supports keyword arguments.
Use it like:
`mesh = Mesh(points=points, elements=elements)`
The array shapes below use the following notation:
- N: number of elements in the mesh (e.g. triangles)
- d_e: number of points per element
- M: number of points in the mesh
- d: dimension of the underlying space
Attributes:
Name | Type | Description |
---|---|---|
points | np.ndarray | [M, d] array of points |
elements | np.ndarray | [N, d_e] array of indices into `points` that correspond to the elements of the mesh |
points_triangle | np.ndarray | [N, d_e, d] array like `elements`, but with the point coordinates instead of indices |
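To make the shape notation concrete, the following sketch builds the arrays for a unit square split into two triangles (N = 2, d_e = 3, M = 4, d = 2). It assumes, without the source confirming it, that `points_triangle` is obtained by indexing `points` with `elements`; it uses plain NumPy and does not call `fimjax`.

```python
import numpy as np

# [M, d] = [4, 2]: corners of the unit square
points = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])

# [N, d_e] = [2, 3]: two triangles given as indices into `points`
elements = np.array([[0, 1, 2], [0, 2, 3]])

# [N, d_e, d] = [2, 3, 2]: same layout as `elements`, but holding coordinates.
# Assumed relation: NumPy fancy indexing replaces every index by its point.
points_triangle = points[elements]

print(points_triangle.shape)  # (2, 3, 2)
```

With `fimjax`, the same two arrays would be passed as `Mesh(points=points, elements=elements)`.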
__init__
Constructor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
points | np.ndarray | [M, d] array of points | required |
elements | np.ndarray | [N, d_e] array of indices into `points` | required |
fimjax.main.InitialValues
Initial values for an eikonal PDE. X denotes the number of points at which initial values are prescribed.
Attributes:
Name | Type | Description |
---|---|---|
locations | np.ndarray | [X] array of indices into the mesh points that mark the initial value locations |
values | np.ndarray | [X] array of the initial values at these locations |
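A minimal sketch of how `locations` and `values` can seed a solution vector of length M before iterating. The helper is hypothetical, and marking all other points with `np.inf` is an assumption of this sketch rather than something the source states. In JAX, where arrays are immutable, the same assignment would be written `u.at[locations].set(values)`.

```python
import numpy as np

def seed_solution(num_points, locations, values):
    """Hypothetical helper: initial solution vector for the iteration."""
    u = np.full(num_points, np.inf)      # points without an initial value: "not reached yet"
    u[np.asarray(locations)] = values    # prescribe the initial values
    return u

u0 = seed_solution(4, locations=[0, 3], values=[0.0, 2.5])
```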
fimjax.main.FIMSolution
Holds the result of a FIM computation.
Attributes:
Name | Type | Description |
---|---|---|
solution | np.ndarray | computed solution, one value per mesh point |
iterations | int | number of iterations performed |
has_converged | bool | whether the solution has converged |
has_converged_after | int | number of iterations needed for convergence |
fimjax.main.ITERATION_SCHEME
Bases: StrEnum
Iteration schemes that can be used for the solver function in FIM.
fimjax.main.Solver
Main class to interact with the eikonal solver.
This class provides an object-oriented interface to the JAX code that computes solutions of the eikonal equation and their parametric derivatives. Further information on the different iteration schemes can be found in the documentation of the corresponding functions.
get_solver_function
get_solver_function(type: ITERATION_SCHEME = ITERATION_SCHEME.FOR, local_update_function: callable = None)
Returns the solver as a JAX function, so that it can be composed with other JAX code.
Returns:
Type | Description |
---|---|
callable | solver function |
solve
solve(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, iter: int, local_update_function: callable = _update_all_triangles) -> FIMSolution
Uses the FIM algorithm to solve an eikonal equation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh | Mesh | mesh on which the equation is solved | required |
initial_values | InitialValues | initial values | required |
metrics | np.ndarray | [N, d, d] metric tensor field | required |
iter | int | number of iterations | required |
local_update_function | callable | function used for the local updates | _update_all_triangles |
Returns:
Type | Description |
---|---|
FIMSolution | solution object |
value_and_vjp
value_and_vjp(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, iter: int, adjoint_vector: np.ndarray, local_update_function: callable = _update_all_triangles) -> np.ndarray
Calculates the value and the vector-Jacobian product for the FIM.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh | Mesh | mesh on which the equation is solved | required |
initial_values | InitialValues | initial values | required |
metrics | np.ndarray | [N, d, d] metric tensor field | required |
iter | int | number of iterations | required |
adjoint_vector | np.ndarray | adjoint vector | required |
local_update_function | callable | function used for the local updates | _update_all_triangles |
Returns:
Type | Description |
---|---|
np.ndarray | adjoint vector |
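When testing `value_and_vjp`, it can help to compare against a finite-difference estimate. The sketch below is a generic NumPy helper (hypothetical, not part of `fimjax`): for a function f and an adjoint vector v it returns f(x) and an approximation of v^T (df/dx) using central differences. It needs 2 * x.size evaluations of f, so it is only suitable for small checks, e.g. on a tiny mesh.

```python
import numpy as np

def fd_value_and_vjp(f, x, v, eps=1e-6):
    """Central finite-difference estimate of (f(x), v^T J_f(x)) for array-valued f."""
    x = np.asarray(x, dtype=float)
    value = f(x)
    vjp = np.zeros(x.size)
    for i in range(x.size):
        dx = np.zeros(x.size)
        dx[i] = eps
        dx = dx.reshape(x.shape)
        # derivative of f along the i-th input coordinate
        df = (f(x + dx) - f(x - dx)) / (2.0 * eps)
        vjp[i] = np.sum(v * df)  # contract with the adjoint vector
    return value, vjp.reshape(x.shape)

# For a linear map f(x) = A x the exact vector-Jacobian product is A^T v
A = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
x = np.array([1.0, -1.0])
v = np.array([1.0, 0.0, 2.0])
value, vjp = fd_value_and_vjp(lambda x: A @ x, x, v)
```

Automatic differentiation gives the product exactly and at the cost of roughly one extra backward pass, which is why the helper is meant as a cross-check only.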
fimjax.main._update_all_triangles
Performs one Jacobi update.
Computes the candidate solution for every update direction and keeps the smallest one for each point.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh | Mesh | mesh | required |
D | np.ndarray | [N, d, d] array with the metric tensor field | required |
solution | np.ndarray | [M] solution before the iteration | required |
Returns:
Type | Description |
---|---|
np.ndarray | [M] new solution after one iteration |
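To illustrate what one Jacobi update does, here is a NumPy sketch for 2D triangle meshes (d = 2, d_e = 3). The function name is hypothetical, and several choices are assumptions rather than properties of `fimjax`: the travel cost along a vector e inside a triangle is taken as sqrt(e^T D^-1 e), the candidate for a vertex is minimised over points on the opposite edge by sampling the edge parameter lambda in [0, 1] instead of solving exactly, and unreached points carry `np.inf`.

```python
import numpy as np

def update_all_triangles_sketch(points, elements, D, solution, n_samples=101):
    """One Jacobi sweep on a 2D triangle mesh (hypothetical stand-in).

    points: [M, 2], elements: [N, 3], D: [N, 2, 2], solution: [M].
    """
    # Replace inf by a large finite value so that the interpolation below
    # never evaluates 0 * inf, which would produce NaN.
    u = np.where(np.isfinite(solution), solution, 1e10)
    lam = np.linspace(0.0, 1.0, n_samples)        # [S] samples on the opposite edge
    D_inv = np.linalg.inv(D)                       # [N, 2, 2] assumed travel metric
    new_solution = solution.astype(float).copy()
    for k in range(3):                             # every local vertex is the target once
        target = elements[:, k]
        a = elements[:, (k + 1) % 3]
        b = elements[:, (k + 2) % 3]
        # points x_lam on edge (a, b) [N, S, 2] and interpolated values u_lam [N, S]
        x_lam = (lam[None, :, None] * points[a][:, None, :]
                 + (1.0 - lam)[None, :, None] * points[b][:, None, :])
        u_lam = lam[None, :] * u[a][:, None] + (1.0 - lam)[None, :] * u[b][:, None]
        e = points[target][:, None, :] - x_lam
        dist = np.sqrt(np.einsum("nsi,nij,nsj->ns", e, D_inv, e))
        candidates = np.min(u_lam + dist, axis=1)  # best update direction per triangle
        # each point keeps the smallest value over its old value and all candidates
        np.minimum.at(new_solution, target, candidates)
    return new_solution

# Unit square, two triangles, isotropic metric, source at point 0
points = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
elements = np.array([[0, 1, 2], [0, 2, 3]])
D = np.tile(np.eye(2), (2, 1, 1))
u = np.array([0.0, np.inf, np.inf, np.inf])
for _ in range(3):
    u = update_all_triangles_sketch(points, elements, D, u)
```

All candidates are computed from the old solution and only merged at the end with `np.minimum.at`, so the order of the triangles does not matter; this is what makes the sweep a Jacobi rather than a Gauss-Seidel update.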
fimjax.main._compute_fim_for
_compute_fim_for(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, iters: int, local_update_function: callable) -> FIMSolution
Uses the Jacobi update with a fixed number of iterations to compute the FIM.
Every iteration is checkpointed, which makes automatic differentiation (AD) more memory efficient.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh | Mesh | mesh | required |
initial_values | InitialValues | initial values, including their locations in the mesh | required |
metrics | np.ndarray | [N, d, d] array corresponding to the metric tensor field | required |
iters | int | number of iterations | required |
local_update_function | callable | function that calculates the local updates, see _update_all_triangles for more information | required |
Returns:
Type | Description |
---|---|
FIMSolution | solution along with a flag indicating whether FIM has converged |
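The bookkeeping behind the fields of `FIMSolution` can be illustrated with a plain NumPy loop: run a fixed number of updates and record the first iteration after which the solution stopped changing. This is a hypothetical sketch. It does not reproduce the checkpointing or JIT compilation of the JAX implementation, it uses a toy 1D update in place of `_update_all_triangles`, and reporting -1 in `has_converged_after` when no convergence was observed is an assumption of the sketch.

```python
from dataclasses import dataclass

import numpy as np

@dataclass
class FIMSolutionSketch:
    """Hypothetical stand-in mirroring the fields of FIMSolution."""
    solution: np.ndarray
    iterations: int
    has_converged: bool
    has_converged_after: int

def run_fixed_iterations(update, u0, iters):
    u = u0
    converged_after = -1
    for i in range(iters):
        u_new = update(u)
        # the first update that leaves the solution unchanged marks convergence
        if converged_after < 0 and np.array_equal(u_new, u):
            converged_after = i
        u = u_new
    return FIMSolutionSketch(u, iters, converged_after >= 0, converged_after)

def chain_update(u):
    """Toy Jacobi update on a 1D chain with unit edge length."""
    left = np.concatenate(([np.inf], u[:-1])) + 1.0
    right = np.concatenate((u[1:], [np.inf])) + 1.0
    return np.minimum(u, np.minimum(left, right))

u0 = np.array([0.0, np.inf, np.inf, np.inf, np.inf])
result = run_fixed_iterations(chain_update, u0, iters=10)
```

Running a fixed number of iterations keeps the loop length static, at the cost of possibly doing redundant updates after convergence.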
fimjax.main._compute_fim_checkpointed_while
_compute_fim_checkpointed_while(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, checkpoints: int, local_update_function: callable) -> FIMSolution
Uses the Jacobi update until convergence is obtained.
Uses a checkpointed while loop from Equinox to make JIT compilation and AD possible. XLA needs static bounds on memory, which are normally not available when the number of iterations is unknown. For more information on how this works, see the documentation of Equinox's checkpointed_while_loop.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh | Mesh | mesh | required |
initial_values | InitialValues | initial values, including their locations in the mesh | required |
metrics | np.ndarray | [N, d, d] array corresponding to the metric tensor field | required |
checkpoints | int | number of checkpoints used in the forward pass for AD | required |
local_update_function | callable | function that calculates the local updates, see _update_all_triangles for more information | required |
Returns:
Type | Description |
---|---|
FIMSolution | solution object |
fimjax.main._compute_fim_while
_compute_fim_while(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, local_update_function: callable) -> FIMSolution
Uses the Jacobi update until convergence is obtained.
Because XLA needs static memory bounds, this function is not jittable.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh | Mesh | mesh | required |
initial_values | InitialValues | initial values, including their locations in the mesh | required |
metrics | np.ndarray | [N, d, d] array corresponding to the metric tensor field | required |
local_update_function | callable | function that calculates the local updates, see _update_all_triangles for more information | required |
Returns:
Type | Description |
---|---|
FIMSolution | solution object |
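The difference to the fixed-iteration variant is that the loop runs until two consecutive iterates agree. A plain NumPy sketch of that control flow (hypothetical helper, with a toy 1D update instead of `_update_all_triangles`) shows that the number of passes is only known at run time.

```python
import numpy as np

def chain_update(u):
    """Toy Jacobi update on a 1D chain with unit edge length (stand-in for a mesh update)."""
    left = np.concatenate(([np.inf], u[:-1])) + 1.0
    right = np.concatenate((u[1:], [np.inf])) + 1.0
    return np.minimum(u, np.minimum(left, right))

def run_until_converged(update, u0):
    """Hypothetical sketch: apply `update` until the solution stops changing.

    Returns the solution and the number of updates evaluated, including the
    last one, which only confirmed that nothing changed anymore.
    """
    u_prev, u = u0, update(u0)
    n = 1
    # The trip count depends on the data, so it is unknown before running.
    while not np.array_equal(u, u_prev):
        u_prev, u = u, update(u)
        n += 1
    return u, n

u, n = run_until_converged(chain_update, np.array([0.0, np.inf, np.inf, np.inf, np.inf]))
```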
fimjax.main._compute_fim_fixed_point
_compute_fim_fixed_point(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, local_update_function: callable) -> np.ndarray
Computes the FIM solution as a fixed point of the Jacobi update.
Because XLA needs static memory bounds, this function is not jittable.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh | Mesh | mesh | required |
initial_values | InitialValues | initial values, including their locations in the mesh | required |
metrics | np.ndarray | [N, d, d] array corresponding to the metric tensor field | required |
local_update_function | callable | function that calculates the local updates, see _update_all_triangles for more information | required |
Returns:
Type | Description |
---|---|
np.ndarray | solution array |