Source code for coronalyze.core.statistics

"""Statistical functions for masked arrays and small-sample corrections.

Implements JAX-native masked statistics and the Mawet et al. (2014)
small-sample penalty for high-contrast imaging SNR calculations.
"""

import jax.numpy as jnp


def masked_mean(values: jnp.ndarray, mask: jnp.ndarray) -> float:
    """Compute the mean of masked values.

    Excluded entries are forced to zero with ``jnp.where`` instead of
    relying on multiplication by the mask alone, because ``nan * 0`` and
    ``inf * 0`` are both ``nan``: a non-finite value sitting in an excluded
    slot would otherwise contaminate the whole result.

    Args:
        values: 1D array of values.
        mask: Boolean mask (True for valid values).

    Returns:
        Mean of valid values, or 0 if no valid values.
    """
    valid = jnp.asarray(mask).astype(bool)
    count = jnp.sum(mask)
    # Keep the ``values * mask`` product so results for finite inputs are
    # identical to the purely multiplicative form; ``where`` only overrides
    # the excluded slots.
    masked_sum = jnp.sum(jnp.where(valid, values * mask, 0))
    # Clamp the denominator so an all-False mask yields 0 rather than 0/0.
    return masked_sum / jnp.maximum(count, 1.0)
[docs] def masked_std( values: jnp.ndarray, mask: jnp.ndarray, mean: float | None = None, ) -> float: """Compute the standard deviation of masked values. Uses Bessel's correction (N-1 denominator) for unbiased estimation. Args: values: 1D array of values. mask: Boolean mask (True for valid values). mean: Pre-computed mean. If None, computed from masked values. Returns: Standard deviation of valid values. """ if mean is None: mean = masked_mean(values, mask) count = jnp.sum(mask) residuals = (values - mean) * mask variance = jnp.sum(residuals**2) / jnp.maximum(count - 1, 1.0) return jnp.sqrt(variance)
[docs] def small_sample_penalty(n: int | jnp.ndarray) -> float: """Compute the Mawet et al. (2014) small-sample statistics correction. At small angular separations, fewer reference apertures are available, which inflates the noise estimate. This penalty factor accounts for the additional uncertainty. Reference: Mawet et al. (2014) ApJ Equation 9: sigma_corrected = sigma * sqrt(1 + 1/n) Args: n: Number of reference apertures. Returns: Correction factor sqrt(1 + 1/n). """ return jnp.sqrt(1 + 1 / jnp.maximum(n, 1.0))