
Iteration Schemes

iterations

Implementation of several different iteration techniques.

We provide several iteration techniques because they behave differently under algorithmic differentiation (AD) and JAX just-in-time (JIT) compilation.

fimjax.iterations.UNDEFINED_VALUE module-attribute

UNDEFINED_VALUE = 10000000000.0

fimjax.iterations.PHI_ITER_COUNT module-attribute

PHI_ITER_COUNT = 1

fimjax.iterations.Mesh

Dataclass holding everything that belongs to the mesh. Because it is a Chex dataclass, it only accepts keyword arguments. Use it like: mesh = Mesh(points=points, elements=elements)

Notation:

- N: number of elements in the mesh (e.g. triangles)
- d_e: number of points per element
- M: number of points in the mesh
- d: dimension of the underlying space

Attributes:

- points (np.ndarray): [M, d] array of points.
- elements (np.ndarray): [N, d_e] array of indices into points that correspond to the elements of the mesh.
- points_triangle (np.ndarray): [N, d_e, d] array like elements, but holding the point coordinates instead of indices.

__init__

__init__(points: np.ndarray, elements: np.ndarray)

Constructor.

Parameters:

- points (np.ndarray, required): [M, d] array of points.
- elements (np.ndarray, required): [N, d_e] array of indices into points that correspond to the elements of the mesh.

fimjax.iterations.InitialValues

Initial values for an Eikonal PDE.

Attributes:

- locations (np.ndarray): [X] array of indices into the mesh points at which initial values are prescribed.
- values (np.ndarray): [X] array of the initial values at those locations.
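One plausible way to turn InitialValues into a starting state is to fill every mesh point with UNDEFINED_VALUE and then overwrite the prescribed locations. The sketch below assumes this is how the solution is seeded; the helper initial_state is hypothetical.

```python
import numpy as np

UNDEFINED_VALUE = 10000000000.0  # same value as fimjax.iterations.UNDEFINED_VALUE


def initial_state(num_points: int, locations: np.ndarray, values: np.ndarray) -> np.ndarray:
    """Hypothetical helper: build an [M] array of starting values."""
    phi = np.full(num_points, UNDEFINED_VALUE)  # every point starts as "not yet reached"
    phi[locations] = values                     # prescribed initial values at the X locations
    return phi


phi0 = initial_state(5, locations=np.array([0, 4]), values=np.array([0.0, 2.5]))
```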

fimjax.iterations.FIMSolution

Holds the result of a FIM computation.

Attributes:

- solution (np.ndarray): the computed solution.
- iterations (int): number of iterations performed.
- has_converged (bool): whether the solution has converged.
- has_converged_after (int): number of iterations needed for convergence.

fimjax.iterations._compute_fim_for

_compute_fim_for(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, iters: int, local_update_function: callable) -> FIMSolution

Uses the Jacobi update with a fixed number of iterations to compute the FIM solution.

Applies checkpointing at every iteration to make AD more memory-efficient.

Parameters:

- mesh (Mesh, required): the mesh.
- initial_values (InitialValues, required): initial values together with their locations in the mesh.
- metrics (np.ndarray, required): [N, d, d] array corresponding to the metric tensor field.
- iters (int, required): number of iterations.
- local_update_function (callable, required): function that calculates the local updates; see _update_all_triangles for more information.

Returns:

- FIMSolution: the solution along with a flag indicating whether FIM has converged.
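The fixed-iteration scheme with per-iteration checkpointing can be sketched on a toy problem: a 1D chain of points with edge weights w, where one Jacobi sweep lets every point take the minimum of its own value and each neighbour's value plus the connecting edge weight. The toy update jacobi_step and the driver solve_for are hypothetical stand-ins for local_update_function and _compute_fim_for. jax.checkpoint makes reverse-mode AD recompute each sweep's intermediates instead of storing them, and jax.lax.fori_loop with static bounds keeps the loop reverse-differentiable.

```python
import jax
import jax.numpy as jnp

UNDEFINED_VALUE = 1e10  # same value as fimjax.iterations.UNDEFINED_VALUE


def jacobi_step(phi, w):
    """Toy Jacobi sweep on a 1D chain: phi has N entries, w has N - 1 edge weights."""
    undefined = jnp.full((1,), UNDEFINED_VALUE, dtype=phi.dtype)
    from_left = jnp.concatenate([undefined, phi[:-1] + w])   # phi[i-1] + w[i-1]
    from_right = jnp.concatenate([phi[1:] + w, undefined])   # phi[i+1] + w[i]
    return jnp.minimum(phi, jnp.minimum(from_left, from_right))


def solve_for(w, phi0, iters):
    """Run a fixed number of sweeps; each sweep is rematerialised during AD."""
    step = jax.checkpoint(jacobi_step)
    return jax.lax.fori_loop(0, iters, lambda _, phi: step(phi, w), phi0)


w = jnp.ones(4)
phi0 = jnp.array([0.0] + [UNDEFINED_VALUE] * 4)
phi = solve_for(w, phi0, iters=4)
grad_w = jax.grad(lambda w: solve_for(w, phi0, iters=4).sum())(w)
```

With unit weights and the source at point 0, four sweeps reach every point, and the gradient of the summed solution with respect to w[k] counts how many points lie beyond edge k.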

fimjax.iterations._compute_fim_while

_compute_fim_while(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, local_update_function: callable) -> FIMSolution

Uses the Jacobi update until convergence is obtained to compute the FIM solution.

Because XLA needs static memory bounds, this function is not jittable.

Parameters:

- mesh (Mesh, required): the mesh.
- initial_values (InitialValues, required): initial values together with their locations in the mesh.
- metrics (np.ndarray, required): [N, d, d] array corresponding to the metric tensor field.
- local_update_function (callable, required): function that calculates the local updates; see _update_all_triangles for more information.

Returns:

- FIMSolution: the solution object.
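A convergence-driven variant can be sketched by checking, after each sweep, whether the values changed. This minimal NumPy sketch makes the stopping decision in Python on concrete values, so the loop as written cannot be traced by jax.jit. The toy chain update, the max_iters guard, the -1 sentinel and the FIMResultSketch container are hypothetical; the container only mirrors the attribute names of FIMSolution.

```python
from typing import NamedTuple

import numpy as np

UNDEFINED_VALUE = 1e10  # same value as fimjax.iterations.UNDEFINED_VALUE


class FIMResultSketch(NamedTuple):
    """Hypothetical container mirroring the attribute names of FIMSolution."""
    solution: np.ndarray
    iterations: int
    has_converged: bool
    has_converged_after: int


def jacobi_step(phi: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Toy Jacobi sweep on a 1D chain: phi has N entries, w has N - 1 edge weights."""
    from_left = np.concatenate([[UNDEFINED_VALUE], phi[:-1] + w])
    from_right = np.concatenate([phi[1:] + w, [UNDEFINED_VALUE]])
    return np.minimum(phi, np.minimum(from_left, from_right))


def solve_while(phi0: np.ndarray, w: np.ndarray, max_iters: int = 1000) -> FIMResultSketch:
    phi, iterations = phi0, 0
    while iterations < max_iters:
        new_phi = jacobi_step(phi, w)
        iterations += 1
        if np.array_equal(new_phi, phi):  # stop once a sweep changes nothing
            return FIMResultSketch(new_phi, iterations, True, iterations)
        phi = new_phi
    # Hypothetical sentinel: -1 marks "did not converge within max_iters".
    return FIMResultSketch(phi, iterations, False, -1)


phi0 = np.array([0.0] + [UNDEFINED_VALUE] * 4)
result = solve_while(phi0, np.ones(4))
```

Here the fifth sweep is the first that leaves the values unchanged, so the loop reports convergence after five iterations.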

fimjax.iterations._compute_fim_checkpointed_while

_compute_fim_checkpointed_while(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, checkpoints: int, local_update_function: callable) -> FIMSolution

Uses the Jacobi update until convergence is obtained.

Uses a checkpointed while loop from Equinox to make jitting and AD possible. XLA needs static bounds on memory, which are normally unavailable when the number of iterations is unknown. For more information on how this works, see the documentation of Equinox's checkpointed_while_loop.

Parameters:

- mesh (Mesh, required): the mesh.
- initial_values (InitialValues, required): initial values together with their locations in the mesh.
- metrics (np.ndarray, required): [N, d, d] array corresponding to the metric tensor field.
- checkpoints (int, required): number of checkpoints used in the forward pass for AD.
- local_update_function (callable, required): function that calculates the local updates; see _update_all_triangles for more information.

Returns:

- FIMSolution: the solution object.
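Equinox's checkpointed while loop is the technique this function uses; the sketch below is a deliberately simpler substitute that shows why a static bound helps. It runs jax.lax.scan for a fixed max_steps, wraps each sweep in jax.checkpoint, and freezes the state with jnp.where once a sweep no longer changes anything. Memory bounds stay static and updates still stop at convergence, but unlike Equinox's approach it always executes max_steps steps and does not reproduce Equinox's checkpointing schedule. The toy chain update and the names solve_bounded and max_steps are hypothetical.

```python
import jax
import jax.numpy as jnp

UNDEFINED_VALUE = 1e10  # same value as fimjax.iterations.UNDEFINED_VALUE


def jacobi_step(phi, w):
    """Toy Jacobi sweep on a 1D chain: phi has N entries, w has N - 1 edge weights."""
    undefined = jnp.full((1,), UNDEFINED_VALUE, dtype=phi.dtype)
    from_left = jnp.concatenate([undefined, phi[:-1] + w])
    from_right = jnp.concatenate([phi[1:] + w, undefined])
    return jnp.minimum(phi, jnp.minimum(from_left, from_right))


def solve_bounded(w, phi0, max_steps):
    """Iterate until convergence, but never more than max_steps (a static bound)."""
    step = jax.checkpoint(jacobi_step)  # recompute sweep internals on the backward pass

    def body(carry, _):
        phi, converged, steps = carry
        new_phi = step(phi, w)
        now_converged = converged | jnp.all(new_phi == phi)
        # Once converged, keep the state frozen; otherwise accept the sweep.
        phi = jnp.where(converged, phi, new_phi)
        steps = steps + jnp.where(converged, 0, 1)
        return (phi, now_converged, steps), None

    init = (phi0, jnp.asarray(False), jnp.zeros((), dtype=jnp.int32))
    (phi, converged, steps), _ = jax.lax.scan(body, init, None, length=max_steps)
    return phi, converged, steps


w = jnp.ones(4)
phi0 = jnp.array([0.0] + [UNDEFINED_VALUE] * 4)
phi, converged, steps = solve_bounded(w, phi0, max_steps=10)
grad_w = jax.grad(lambda w: solve_bounded(w, phi0, max_steps=10)[0].sum())(w)
```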

fimjax.iterations.fixed_point

fixed_point(f, D, phi_guess)

fimjax.iterations.fixed_point_fwd

fixed_point_fwd(f, D, phi_init)

fimjax.iterations.fixed_point_rev

fixed_point_rev(f, res, phi_star_bar)
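The signatures fixed_point(f, D, phi_guess), fixed_point_fwd(f, D, phi_init) and fixed_point_rev(f, res, phi_star_bar) match the implicit-differentiation pattern from the JAX custom_vjp tutorial. The forward pass iterates phi = f(D, phi) to a fixed point phi*. The backward pass solves a second fixed-point problem u = phi_star_bar + (df/dphi)^T u at phi* and pulls u back through df/dD, so no forward iterates need to be stored. The sketch below follows that tutorial pattern; we assume, but have not confirmed, that fimjax works the same way. The tolerance TOL, the helper rev_iter and the newton_sqrt usage example are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

TOL = 1e-6  # hypothetical convergence tolerance


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, D, phi_guess):
    """Iterate phi <- f(D, phi) until successive iterates agree within TOL."""

    def cond_fun(carry):
        phi_prev, phi = carry
        return jnp.max(jnp.abs(phi_prev - phi)) > TOL

    def body_fun(carry):
        _, phi = carry
        return phi, f(D, phi)

    _, phi_star = jax.lax.while_loop(cond_fun, body_fun, (phi_guess, f(D, phi_guess)))
    return phi_star


def fixed_point_fwd(f, D, phi_init):
    phi_star = fixed_point(f, D, phi_init)
    return phi_star, (D, phi_star)  # residuals: only the converged state is kept


def rev_iter(f, packed, u):
    # One step of u <- phi_star_bar + (df/dphi)^T u, evaluated at the fixed point.
    D, phi_star, phi_star_bar = packed
    _, vjp_phi = jax.vjp(lambda phi: f(D, phi), phi_star)
    return phi_star_bar + vjp_phi(u)[0]


def fixed_point_rev(f, res, phi_star_bar):
    D, phi_star = res
    _, vjp_D = jax.vjp(lambda D: f(D, phi_star), D)
    u = fixed_point(partial(rev_iter, f), (D, phi_star, phi_star_bar), phi_star_bar)
    (D_bar,) = vjp_D(u)
    return D_bar, jnp.zeros_like(phi_star)  # no gradient w.r.t. the initial guess


fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)


def newton_sqrt(a):
    # Illustration only: sqrt(a) is the fixed point of phi <- (phi + a / phi) / 2.
    return fixed_point(lambda a, phi: 0.5 * (phi + a / phi), a, a)


value = newton_sqrt(2.0)
slope = jax.grad(newton_sqrt)(2.0)  # analytically 1 / (2 * sqrt(2))
```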

fimjax.iterations._compute_fim_fixed_point

_compute_fim_fixed_point(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, local_update_function: callable) -> np.ndarray

Iterates the Jacobi update to a fixed point (see fixed_point) to compute the FIM solution.

Because XLA needs static memory bounds, this function is not jittable.

Parameters:

- mesh (Mesh, required): the mesh.
- initial_values (InitialValues, required): initial values together with their locations in the mesh.
- metrics (np.ndarray, required): [N, d, d] array corresponding to the metric tensor field.
- local_update_function (callable, required): function that calculates the local updates; see _update_all_triangles for more information.

Returns:

- np.ndarray: the solution array.