Iteration Schemes
iterations
Implementation of several different iteration techniques.
We provide several iteration techniques because they behave differently under algorithmic differentiation (AD) and JAX just-in-time (JIT) compilation.
fimjax.iterations.Mesh
Dataclass holding everything that belongs to the mesh.
Note that it only accepts keyword arguments because it is built with Chex.
Use it like `mesh = Mesh(points=points, elements=elements)`.
Shape conventions:
- `N`: number of elements in the mesh (e.g. triangles)
- `d_e`: number of points per element
- `M`: number of points in the mesh
- `d`: dimension of the underlying space
Attributes:
Name | Type | Description |
---|---|---|
points | np.ndarray | [M, d] array of points |
elements | np.ndarray | [N, d_e] array of indices into `points` that define the elements of the mesh |
points_triangle | np.ndarray | [N, d_e, d] like `elements`, but with the point coordinates instead of indices |
__init__
Constructor.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
points | np.ndarray | [M, d] array of points | required |
elements | np.ndarray | [N, d_e] array of indices into `points` that define the elements of the mesh | required |
fimjax.iterations.InitialValues
Initial values for an Eikonal PDE.
Attributes:
Name | Type | Description |
---|---|---|
locations | np.ndarray | [X] array of indices into the mesh at which the initial values are prescribed, where X is the number of initial values |
values | np.ndarray | [X] array of the initial values at these locations |
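A common way to turn `locations` and `values` into a starting iterate is to fill a per-point array with +inf (meaning "not yet reached") and scatter the initial values into it. The helper `initial_solution` below is hypothetical and assumes that the solution is stored per mesh point; it uses JAX's functional `.at[...].set(...)` update because JAX arrays are immutable.

```python
import jax.numpy as jnp


def initial_solution(num_points, locations, values):
    """Hypothetical helper: [M] array with the initial values scattered in."""
    # Points that have not been reached yet start at +inf.
    u = jnp.full((num_points,), jnp.inf, dtype=jnp.float32)
    # Functional scatter: returns a new array instead of mutating u.
    return u.at[locations].set(values)


u0 = initial_solution(5, jnp.array([0, 4]), jnp.array([0.0, 1.0]))
# u0 is 0.0 at point 0, 1.0 at point 4 and inf everywhere else
```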
fimjax.iterations.FIMSolution
Holds the result of a FIM computation.
Attributes:
Name | Type | Description |
---|---|---|
solution | np.ndarray | computed solution |
iterations | int | number of iterations performed |
has_converged | bool | whether the solution has converged |
has_converged_after | int | number of iterations needed for convergence |
fimjax.iterations._compute_fim_for
`_compute_fim_for(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, iters: int, local_update_function: callable) -> FIMSolution`
Uses the Jacobi update with a fixed number of iterations to compute the FIM.
Checkpoints every iteration to make AD more memory efficient.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh | Mesh | mesh on which the FIM is computed | required |
initial_values | InitialValues | initial values and their locations in the mesh | required |
metrics | np.ndarray | [N, d, d] array holding the metric tensor field (one d × d tensor per element) | required |
iters | int | number of iterations | required |
local_update_function | callable | function that computes the local updates; see `_update_all_triangles` for more information | required |
Returns:
Type | Description |
---|---|
FIMSolution | solution, together with a flag indicating whether FIM has converged |
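A fixed number of Jacobi iterations with per-iteration checkpointing can be sketched in JAX with `jax.lax.scan` and `jax.checkpoint`. This is a sketch under assumptions, not the fimjax implementation: `jacobi_step` is a hypothetical toy local update on a 1D chain of points with per-edge slowness (standing in for `local_update_function` / `_update_all_triangles`), and `solve_fixed_iters` is likewise hypothetical. Wrapping the step in `jax.checkpoint` makes reverse-mode AD recompute the step's intermediates in the backward pass, so only the inputs of each step are stored.

```python
import functools

import jax
import jax.numpy as jnp


def jacobi_step(u, slowness, h):
    """Hypothetical toy local update on a 1D chain (all points updated at once)."""
    cost = h * slowness  # [M - 1] travel cost of every edge
    inf = jnp.full((1,), jnp.inf, dtype=u.dtype)
    from_left = jnp.concatenate([inf, u[:-1] + cost])
    from_right = jnp.concatenate([u[1:] + cost, inf])
    # Jacobi: every new value depends only on the previous iterate.
    return jnp.minimum(u, jnp.minimum(from_left, from_right))


@functools.partial(jax.jit, static_argnames=("iters",))
def solve_fixed_iters(u0, slowness, h, iters):
    # Recompute the step's intermediates in the backward pass instead of storing them.
    step = jax.checkpoint(jacobi_step)

    def body(u, _):
        return step(u, slowness, h), None

    # iters must be static: scan needs the trip count at trace time.
    u, _ = jax.lax.scan(body, u0, None, length=iters)
    return u


u0 = jnp.full((5,), jnp.inf, dtype=jnp.float32).at[0].set(0.0)
slowness = jnp.ones(4)
u = solve_fixed_iters(u0, slowness, 0.5, iters=10)
# Reverse-mode AD through all iterations, here w.r.t. the slowness field.
grad = jax.grad(lambda s: solve_fixed_iters(u0, s, 0.5, iters=10).sum())(slowness)
```

For this chain with unit slowness and spacing 0.5, `u` is `[0, 0.5, 1, 1.5, 2]`, and the gradient of the summed solution with respect to edge j is 0.5 times the number of points beyond that edge, i.e. `[2, 1.5, 1, 0.5]`.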
fimjax.iterations._compute_fim_while
`_compute_fim_while(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, local_update_function: callable) -> FIMSolution`
Uses the Jacobi update until convergence is obtained to compute the FIM.
Because XLA needs static memory bounds, this function is not jittable.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh | Mesh | mesh on which the FIM is computed | required |
initial_values | InitialValues | initial values and their locations in the mesh | required |
metrics | np.ndarray | [N, d, d] array holding the metric tensor field (one d × d tensor per element) | required |
local_update_function | callable | function that computes the local updates; see `_update_all_triangles` for more information | required |
Returns:
Type | Description |
---|---|
FIMSolution | solution object |
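One way an "iterate until convergence" loop ends up non-jittable is when the convergence test is evaluated on the host, as in the sketch below. It uses a hypothetical toy `jacobi_step` on a 1D chain; `solve_until_converged`, its `max_iters` safety cap and its exact-equality stopping criterion are assumptions, not the fimjax code. Each step is jitted, but the Python loop and the `bool(...)` conversion need concrete values, so the solver as a whole cannot be traced by `jax.jit`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def jacobi_step(u, slowness, h):
    """Hypothetical toy local update on a 1D chain (all points updated at once)."""
    cost = h * slowness
    inf = jnp.full((1,), jnp.inf, dtype=u.dtype)
    from_left = jnp.concatenate([inf, u[:-1] + cost])
    from_right = jnp.concatenate([u[1:] + cost, inf])
    return jnp.minimum(u, jnp.minimum(from_left, from_right))


def solve_until_converged(u0, slowness, h, max_iters=10_000):
    """Returns (solution, iterations, has_converged)."""
    u = u0
    for it in range(1, max_iters + 1):
        u_new = jacobi_step(u, slowness, h)
        # bool(...) needs a concrete value on the host; under jax.jit this fails
        # because the number of iterations would depend on traced values.
        if bool(jnp.array_equal(u_new, u)):
            return u_new, it, True
        u = u_new
    return u, max_iters, False


u0 = jnp.full((5,), jnp.inf, dtype=jnp.float32).at[0].set(0.0)
u, iterations, has_converged = solve_until_converged(u0, jnp.ones(4), 0.5)
```

Here the loop stops after the first step that leaves the iterate unchanged (step 5 for this 5-point chain).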
fimjax.iterations._compute_fim_checkpointed_while
`_compute_fim_checkpointed_while(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, checkpoints: int, local_update_function: callable) -> FIMSolution`
Uses the Jacobi update until convergence is obtained.
Uses a checkpointed while loop from Equinox to make both JIT compilation and AD possible. XLA needs static bounds on memory, which an unknown number of iterations normally cannot provide. For details on how this works, see the documentation of Equinox's `checkpointed_while_loop`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh | Mesh | mesh on which the FIM is computed | required |
initial_values | InitialValues | initial values and their locations in the mesh | required |
metrics | np.ndarray | [N, d, d] array holding the metric tensor field (one d × d tensor per element) | required |
checkpoints | int | number of checkpoints used in the forward pass for AD | required |
local_update_function | callable | function that computes the local updates; see `_update_all_triangles` for more information | required |
Returns:
Type | Description |
---|---|
FIMSolution | solution object |
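Equinox's checkpointed while loop is not reproduced here. To illustrate why a static bound makes JIT and reverse-mode AD possible, the sketch below swaps in a simpler technique: a bounded while loop written as a `jax.lax.scan` over a fixed maximum number of steps, in which the state is frozen once convergence is detected. The trip count is static, so the loop is jittable and differentiable, at the cost of always executing all `max_iters` steps; this is a simplification, not how Equinox's loop works. `jacobi_step` (a toy 1D-chain update) and `solve_bounded` are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp


def jacobi_step(u, slowness, h):
    """Hypothetical toy local update on a 1D chain (all points updated at once)."""
    cost = h * slowness
    inf = jnp.full((1,), jnp.inf, dtype=u.dtype)
    from_left = jnp.concatenate([inf, u[:-1] + cost])
    from_right = jnp.concatenate([u[1:] + cost, inf])
    return jnp.minimum(u, jnp.minimum(from_left, from_right))


@functools.partial(jax.jit, static_argnames=("max_iters",))
def solve_bounded(u0, slowness, h, max_iters):
    """Returns (solution, has_converged, iterations) after at most max_iters steps."""
    step = jax.checkpoint(jacobi_step)

    def body(carry, _):
        u, done, n = carry
        u_new = step(u, slowness, h)
        converged = jnp.array_equal(u_new, u)
        # After convergence the carry is frozen, so the remaining steps are no-ops.
        u_next = jnp.where(done, u, u_new)
        n_next = jnp.where(done, n, n + 1)
        return (u_next, done | converged, n_next), None

    init = (u0, jnp.asarray(False), jnp.zeros((), dtype=jnp.int32))
    (u, done, n), _ = jax.lax.scan(body, init, None, length=max_iters)
    return u, done, n


u0 = jnp.full((5,), jnp.inf, dtype=jnp.float32).at[0].set(0.0)
slowness = jnp.ones(4)
u, has_converged, iterations = solve_bounded(u0, slowness, 0.5, max_iters=20)
grad = jax.grad(lambda s: solve_bounded(u0, s, 0.5, max_iters=20)[0].sum())(slowness)
```

If `max_iters` is too small, `has_converged` stays `False` and the returned solution is simply the last iterate.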
fimjax.iterations._compute_fim_fixed_point
`_compute_fim_fixed_point(mesh: Mesh, initial_values: InitialValues, metrics: np.ndarray, local_update_function: callable) -> np.ndarray`
Iterates the Jacobi update until a fixed point is reached to compute the FIM.
Because XLA needs static memory bounds, this function is not jittable.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
mesh | Mesh | mesh on which the FIM is computed | required |
initial_values | InitialValues | initial values and their locations in the mesh | required |
metrics | np.ndarray | [N, d, d] array holding the metric tensor field (one d × d tensor per element) | required |
local_update_function | callable | function that computes the local updates; see `_update_all_triangles` for more information | required |
Returns:
Type | Description |
---|---|
np.ndarray | solution array |