Source code for coronalyze.core.pca
"""Principal Component Analysis (PCA/KLIP) for PSF subtraction.
Implements the efficient 'Snapshot Method' for SVD on large image cubes,
particularly suited for coronagraphy (where pixels >> frames).
All functions are JIT-compilable and fully differentiable.
"""
import functools
import jax
import jax.numpy as jnp
[docs]
@functools.partial(jax.jit, static_argnames=["n_modes"])
def get_pca_basis(
    ref_cube: jnp.ndarray,
    n_modes: int,
    rtol: "float | None" = None,
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Compute PCA basis vectors from a reference cube (Snapshot Method).

    Eigendecomposes the N x N Gram matrix of the mean-subtracted frames instead
    of working with the P x P pixel covariance. This costs O(N^2 * P) to build
    plus O(N^3) to decompose, which is far cheaper when pixels P >> frames N.

    Because the mean is subtracted, the centered data has rank at most N - 1,
    so at least one eigenvalue is zero up to rounding (and may come out
    slightly negative). Dividing by ``sqrt`` of such an eigenvalue would turn
    rounding noise into a huge, non-normalized basis vector. Modes whose
    eigenvalue is not above ``rtol * largest_eigenvalue`` are therefore
    returned as all-zero rows, which contribute nothing to a projection.

    Args:
        ref_cube: Reference image cube of shape (N_frames, Height, Width).
        n_modes: Number of principal components (modes) to keep. Values larger
            than N_frames yield N_frames modes.
        rtol: Relative eigenvalue cutoff below which a mode is treated as
            degenerate and zeroed. Defaults to ``N_frames * eps`` for the
            floating-point dtype of the eigenvalues.

    Returns:
        basis: Eigen-images sorted by decreasing eigenvalue, shape
            (min(n_modes, N_frames), Height*Width). Retained rows have unit
            norm; rows for degenerate modes are all zeros.
        mean_ref: The mean of the reference cube, shape (Height, Width).
    """
    n_frames, ny, nx = ref_cube.shape
    flat_refs = ref_cube.reshape(n_frames, -1)

    # Center the data
    mean_ref = jnp.mean(flat_refs, axis=0)
    centered = flat_refs - mean_ref

    # Gram matrix (N_frames x N_frames)
    gram = centered @ centered.T

    # eigh returns eigenvalues in ascending order; flip to descending.
    all_vals, all_vecs = jnp.linalg.eigh(gram)
    all_vals = all_vals[::-1]
    all_vecs = all_vecs[:, ::-1]

    # Tolerance relative to the largest eigenvalue (clamped at 0 so an
    # all-constant cube, whose Gram matrix is zero, gives tol == 0).
    if rtol is None:
        rtol = n_frames * jnp.finfo(all_vals.dtype).eps
    tol = rtol * jnp.maximum(all_vals[0], 0.0)

    # Truncate to the requested number of modes.
    vals = all_vals[:n_modes]
    vecs = all_vecs[:, :n_modes]

    # Project back to image space (KL transform):
    #   basis_k = (v_k^T . centered) / sqrt(lambda_k)
    # Degenerate modes get a zero normalization; the inner `where` keeps the
    # sqrt/division finite so gradients stay NaN-free.
    keep = vals > tol
    safe_vals = jnp.where(keep, vals, 1.0)
    normalization = jnp.where(keep, 1.0 / jnp.sqrt(safe_vals), 0.0)
    basis = (vecs.T @ centered) * normalization[:, None]
    return basis, mean_ref.reshape(ny, nx)
[docs]
@jax.jit
def pca_subtract(
    science_image: jnp.ndarray,
    basis: jnp.ndarray,
    mean_ref: jnp.ndarray,
) -> jnp.ndarray:
    """Project science image(s) onto a PCA basis and subtract the model.

    Each image is first centered by subtracting ``mean_ref``. The returned
    residual is ``(science - mean_ref) - model``; the mean is not added back.

    Args:
        science_image: Science image of shape (Height, Width), or a stack of
            images with leading batch dimensions (..., Height, Width). Each
            image is processed independently.
        basis: PCA basis from get_pca_basis, shape (n_modes, Height*Width).
            The model is an orthogonal projection only if the rows are
            orthonormal; all-zero rows contribute nothing.
        mean_ref: Mean reference from get_pca_basis with Height*Width elements
            (typically shape (Height, Width)).

    Returns:
        Residual image(s) with the PSF model subtracted, same shape as
        ``science_image``.
    """
    ny, nx = science_image.shape[-2:]
    batch_shape = science_image.shape[:-2]
    flat_sci = science_image.reshape(batch_shape + (ny * nx,))
    flat_mean = mean_ref.reshape(-1)

    # Center target(s)
    centered_sci = flat_sci - flat_mean

    # Coefficients: (..., P) @ (P, n_modes) -> (..., n_modes)
    coeffs = centered_sci @ basis.T

    # Model: (..., n_modes) @ (n_modes, P) -> (..., P)
    model = coeffs @ basis

    # Residual
    residual = centered_sci - model
    return residual.reshape(science_image.shape)