coronalyze
==========

.. py:module:: coronalyze

.. autoapi-nested-parse::

   coronalyze: JAX-based post-processing for coronagraphic direct imaging.

   This library provides analysis tools for coronagraphic observations,
   designed as a companion to coronagraphoto.

   Primary SNR API (Mawet et al. 2014):
       - snr(): Calculate SNR at positions
       - snr_map(): Generate 2D SNR detection map
       - snr_estimator(): Factory for JIT-ready SNREstimator objects

   For experimental matched-filter SNR, see coronalyze.core.matched_filter.



Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/coronalyze/analysis/index
   /autoapi/coronalyze/core/index
   /autoapi/coronalyze/datasets/index
   /autoapi/coronalyze/interfaces/index
   /autoapi/coronalyze/pipelines/index


Classes
-------

.. autoapisummary::

   coronalyze.SNREstimator


Functions
---------

.. autoapisummary::

   coronalyze.get_perfect_residuals
   coronalyze.get_photon_noise_map
   coronalyze.simulate_observation
   coronalyze.aperture_photometry
   coronalyze.aperture_solid_angle
   coronalyze.calculate_n_apertures
   coronalyze.circular_aperture_mask
   coronalyze.flux_map
   coronalyze.generate_aperture_coords
   coronalyze.get_center
   coronalyze.get_pca_basis
   coronalyze.inject_planet
   coronalyze.make_aperture_kernel
   coronalyze.make_simple_disk
   coronalyze.masked_mean
   coronalyze.masked_std
   coronalyze.pca_subtract
   coronalyze.radial_distance
   coronalyze.small_sample_penalty
   coronalyze.soft_aperture_mask
   coronalyze.subtract_disk
   coronalyze.subtract_star
   coronalyze.calculate_ccd_snr
   coronalyze.exposure_time_for_snr
   coronalyze.snr
   coronalyze.snr_estimator
   coronalyze.snr_map
   coronalyze.calculate_yield_snr
   coronalyze.klip_subtract


Package Contents
----------------

.. py:function:: get_perfect_residuals(observation, expectation_model)

   Calculate residuals assuming perfect subtraction of static structure.

   This simulates 'Faux-RDI': we subtract the exact expectation value of
   the star and disk. The residual contains only the fundamental noise
   (photon + read noise) and any signals not in the model (planets).

   Args:
       observation: The noisy data (Data = Poisson(Model) + ReadNoise).
       expectation_model: The noiseless expectation of background sources
           (stellar PSF + disk, excluding planets).

   Returns:
       Residual image containing noise + unmodeled signals (planets).
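
   A minimal NumPy sketch of the 'Faux-RDI' idea, assuming the
   Poisson-plus-Gaussian noise model stated above. The helper name
   ``faux_rdi_residual`` and the flat test scene are hypothetical
   illustrations, not the library implementation::

       import numpy as np

       def faux_rdi_residual(observation, expectation_model):
           # Subtract the noiseless expectation; only noise (+ unmodeled signal) remains.
           return observation - expectation_model

       rng = np.random.default_rng(0)
       expectation = np.full((200, 200), 100.0)  # hypothetical star + disk expectation
       read_noise = 3.0
       # Data = Poisson(Model) + ReadNoise
       observation = rng.poisson(expectation) + rng.normal(0.0, read_noise, expectation.shape)
       residual = faux_rdi_residual(observation, expectation)
       # Residual scatter should approach sqrt(100 + 3**2), about 10.4
       print(residual.mean(), residual.std())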


.. py:function:: get_photon_noise_map(expectation_rate, exposure_time, read_noise = 0.0)

   Calculate the theoretical 1-sigma noise map in rate units.

   Properly converts between rate and count units to combine photon noise
   with read noise. Returns noise in the same units as input (counts/sec).

   Formula: Sigma_rate = sqrt(Rate * t + RN^2) / t

   Args:
       expectation_rate: Expected count rate image (counts/sec).
       exposure_time: Integration time in seconds.
       read_noise: Read noise in electrons (per pixel). Default 0.

   Returns:
       1-sigma noise map in rate units (counts/sec), same shape as input.
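
   The formula above, sketched in NumPy. The function name is hypothetical;
   the point is that variances are added in count units before converting
   back to a rate::

       import numpy as np

       def photon_noise_map(expectation_rate, exposure_time, read_noise=0.0):
           # Convert the rate to counts so Poisson and read-noise variances add.
           counts = np.asarray(expectation_rate, dtype=float) * exposure_time
           sigma_counts = np.sqrt(counts + read_noise**2)
           # Convert the 1-sigma noise back to rate units (counts/sec).
           return sigma_counts / exposure_time

       sigma = photon_noise_map(np.array([4.0, 100.0]), exposure_time=4.0)
       # sqrt(16)/4 = 1.0 and sqrt(400)/4 = 5.0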


.. py:function:: simulate_observation(clean_signal, background_model, exposure_time = 1.0, rng_key = None)

   Generate a noisy realization of the scene for yield tests.

   Applies Poisson noise to the combined scene (signal + background).
   Returns data in rate units (counts/sec).

   Args:
       clean_signal: Planet image (counts/sec).
       background_model: Star + Disk expectation (counts/sec).
       exposure_time: Integration time (seconds).
       rng_key: JAX PRNG Key. If None, uses key 0.

   Returns:
       Noisy image in units of counts/sec (Poisson noise added in counts,
       then divided back by exposure time).
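
   A NumPy sketch of the same recipe. The library takes a JAX PRNG key;
   this sketch swaps in a NumPy ``Generator`` seed instead, and the helper
   name is hypothetical::

       import numpy as np

       def simulate_observation_np(clean_signal, background_model, exposure_time=1.0, seed=0):
           rng = np.random.default_rng(seed)
           # Expected counts for the combined scene (rate * time).
           expected_counts = (clean_signal + background_model) * exposure_time
           # Draw Poisson noise in count space, then convert back to a rate.
           return rng.poisson(expected_counts) / exposure_time

       planet = np.zeros((64, 64))
       planet[32, 40] = 50.0
       background = np.full((64, 64), 20.0)
       noisy = simulate_observation_np(planet, background, exposure_time=10.0)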


.. py:function:: aperture_photometry(image, center, radius)

   Perform circular aperture photometry on an image.

   Args:
       image: 2D image array.
       center: Center of aperture (y, x) in pixels.
       radius: Radius of aperture in pixels.

   Returns:
       Total flux within the aperture.
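
   A minimal sketch of hard-edged aperture photometry, assuming a pixel is
   inside the aperture when its center lies within ``radius`` (the library's
   exact edge rule may differ). Both helper names are hypothetical::

       import numpy as np

       def circular_mask(shape, center, radius):
           # Distance of each pixel center from (y, x); True inside the circle.
           yy, xx = np.indices(shape)
           return (yy - center[0]) ** 2 + (xx - center[1]) ** 2 <= radius**2

       def aperture_sum(image, center, radius):
           # Total flux of all pixels selected by the mask.
           return image[circular_mask(image.shape, center, radius)].sum()

       image = np.ones((21, 21))
       flux = aperture_sum(image, center=(10, 10), radius=1.0)
       # Center pixel plus its four neighbours: 5.0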


.. py:function:: aperture_solid_angle(radius_pixels, pixel_scale_arcsec)

   Calculate the solid angle of a circular aperture.

   Args:
       radius_pixels: Aperture radius in pixels.
       pixel_scale_arcsec: Pixel scale in arcseconds per pixel.

   Returns:
       Solid angle in arcsec^2.
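
   Treating the aperture as an ideal circle (rather than a pixelated mask),
   the solid angle is the circle's area in arcsec^2. A short sketch with a
   hypothetical function name::

       import math

       def aperture_solid_angle_arcsec2(radius_pixels, pixel_scale_arcsec):
           # Convert the radius to arcsec, then take the circle's area.
           radius_arcsec = radius_pixels * pixel_scale_arcsec
           return math.pi * radius_arcsec**2

       omega = aperture_solid_angle_arcsec2(2.0, 0.5)  # pi * 1.0**2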


.. py:function:: calculate_n_apertures(radius, fwhm, exclusion_buffer = 0.5)

   Calculate the number of reference apertures at a given radius.

   Uses the Mawet et al. (2014) formula with exclusion buffer correction
   to ensure no overlap with the planet aperture on either side.

   This function provides the canonical calculation used by the SNR module,
   ensuring consistency between visualization and computation.

   Args:
       radius: Radial distance from center in pixels.
       fwhm: Full width at half maximum in pixels.
       exclusion_buffer: Gap between planet and first/last reference aperture
           in units of angular step (default 0.5). Creates a gap on both sides.

   Returns:
       Number of valid reference apertures.

   Example::

       from coronalyze.core.geometry import calculate_n_apertures
       n = calculate_n_apertures(radius=20, fwhm=5.0)
       print(f"{n} reference apertures at r=20px")
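
   For intuition, the uncorrected Mawet et al. (2014) count can be sketched
   as the number of FWHM-wide resolution elements on the annulus minus the
   one occupied by the planet. This hypothetical helper does not apply the
   ``exclusion_buffer`` correction, so the library may return fewer
   apertures::

       import math

       def n_reference_apertures_baseline(radius, fwhm):
           # Non-overlapping FWHM-diameter resolution elements around the annulus.
           n_resel = math.floor(2 * math.pi * radius / fwhm)
           # One element is taken by the planet aperture itself.
           return n_resel - 1

       n = n_reference_apertures_baseline(20, 5.0)  # floor(25.13) - 1 = 24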


.. py:function:: circular_aperture_mask(shape, center, radius)

   Create a circular aperture mask.

   Args:
       shape: Image shape (ny, nx).
       center: Center of aperture (y, x) in pixels.
       radius: Radius of aperture in pixels.

   Returns:
       Boolean mask array with True inside the aperture.


.. py:function:: flux_map(image, kernel)

   Generate a flux map via convolution with an aperture kernel.

   Each pixel in the output represents the integrated flux that would be
   measured if an aperture were centered at that position.

   Args:
       image: 2D image array.
       kernel: Aperture kernel from make_aperture_kernel.

   Returns:
       2D flux map with same shape as input image.
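
   A NumPy shift-and-add sketch of a "same"-size flux map, assuming an
   odd-sized kernel and zero padding at the edges. It computes a
   correlation, which equals convolution for the symmetric kernels used
   for circular apertures. The helper name is hypothetical::

       import numpy as np

       def flux_map_np(image, kernel):
           ky, kx = kernel.shape
           py, px = ky // 2, kx // 2
           # Zero-pad so the output keeps the input shape.
           padded = np.pad(image, ((py, py), (px, px)))
           out = np.zeros(image.shape, dtype=float)
           # Accumulate each kernel weight times the correspondingly shifted image.
           for dy in range(ky):
               for dx in range(kx):
                   out += kernel[dy, dx] * padded[dy:dy + image.shape[0], dx:dx + image.shape[1]]
           return out

       image = np.zeros((9, 9))
       image[4, 4] = 1.0
       fmap = flux_map_np(image, np.ones((3, 3)))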


.. py:function:: generate_aperture_coords(center, radius, planet_angle, n_apertures, max_apertures = 200, fwhm = None, exclusion_buffer = 0.5)

   Generate coordinates for reference apertures at a given radius.

   Uses a fixed-size array with masking for JAX compatibility (static shapes).
   Apertures are distributed evenly around the annulus, excluding the planet position.

   Follows VIP's convention: reference apertures are stepped clockwise from
   the planet position using VIP's angle formula.

   Args:
       center: Image center (cy, cx) in pixels.
       radius: Radial distance from center in pixels.
       planet_angle: Angle of the planet position in radians.
       n_apertures: Actual number of valid apertures to use.
       max_apertures: Maximum buffer size for static array shape.
       fwhm: Full width at half maximum in pixels, used for the VIP-style
           angle calculation. If None, apertures are distributed uniformly.
       exclusion_buffer: Gap between test and first reference aperture in
           units of angular step (default 0.5). Prevents PSF wing leakage.

   Returns:
       Tuple of (y_coords, x_coords, mask) where:
           - y_coords: Y coordinates of aperture centers (size max_apertures)
           - x_coords: X coordinates of aperture centers (size max_apertures)
           - mask: Boolean mask indicating valid apertures
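
   The static-shape idea can be sketched with fixed-size buffers plus a
   validity mask. This hypothetical helper uses plain uniform spacing
   (planet plus ``n_apertures`` references), not VIP's FWHM-based angle
   formula, and ignores ``exclusion_buffer``; treating decreasing angle as
   "clockwise" is also an assumption::

       import numpy as np

       def aperture_coords_fixed(center, radius, planet_angle, n_apertures, max_apertures=200):
           cy, cx = center
           # Planet plus n_apertures references, evenly spaced around the annulus.
           step = 2 * np.pi / (n_apertures + 1)
           k = np.arange(max_apertures)
           # Step away from the planet by decreasing angle, skipping the planet slot.
           angles = planet_angle - step * (k + 1)
           y = cy + radius * np.sin(angles)
           x = cx + radius * np.cos(angles)
           # Static-size buffers: only the first n_apertures entries are valid.
           mask = k < n_apertures
           return y, x, mask

       y, x, mask = aperture_coords_fixed((50.0, 50.0), 20.0, 0.0, n_apertures=7, max_apertures=16)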


.. py:function:: get_center(shape)

   Get the center coordinates of an image (0-indexed geometric center).

   For an axis of length N, the geometric center is at (N-1)/2.
   This matches the convention used in modeling.py for planet injection.

   Args:
       shape: Image shape (ny, nx).

   Returns:
       Center coordinates (cy, cx).


.. py:function:: get_pca_basis(ref_cube, n_modes)

   Compute PCA basis vectors from a reference cube (Snapshot Method).

   Uses eigendecomposition of the small N×N frame covariance matrix instead
   of a full SVD; the cost is O(N²·P) rather than O(P²·N) when the number of
   pixels P is much larger than the number of frames N.

   Args:
       ref_cube: Reference image cube of shape (N_frames, Height, Width).
       n_modes: Number of principal components (modes) to keep.

   Returns:
       basis: The top n_modes eigen-images, shape (n_modes, Height*Width).
       mean_ref: The mean of the reference cube, shape (Height, Width).
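
   A NumPy sketch of the snapshot method described above, assuming the
   returned eigen-images are normalized to unit length (the library's exact
   normalization is not stated here). The function name is hypothetical::

       import numpy as np

       def pca_basis_snapshot(ref_cube, n_modes):
           n_frames = ref_cube.shape[0]
           mean_ref = ref_cube.mean(axis=0)
           # (N, P) matrix of mean-subtracted, flattened reference frames.
           X = (ref_cube - mean_ref).reshape(n_frames, -1)
           # Small N x N Gram matrix instead of a P x P pixel covariance.
           evals, evecs = np.linalg.eigh(X @ X.T)      # ascending eigenvalues
           order = np.argsort(evals)[::-1][:n_modes]   # keep the largest n_modes
           # Map frame-space eigenvectors back to eigen-images; normalize each row.
           basis = evecs[:, order].T @ X
           basis /= np.linalg.norm(basis, axis=1, keepdims=True)
           return basis, mean_ref

       rng = np.random.default_rng(1)
       cube = rng.normal(size=(6, 16, 16))
       basis, mean_ref = pca_basis_snapshot(cube, n_modes=3)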


.. py:function:: inject_planet(image, psf_template, flux, pos, order = 3)

   Inject a fake planet into an image using spline-interpolated sub-pixel
   shifts (cubic by default).

   The PSF template is shifted from the image center to the target position
   with sub-pixel precision, scaled by the flux, and added to the image.

   Args:
       image: 2D image to inject into, shape (ny, nx).
       psf_template: 2D PSF template centered at image center, same shape as image.
       flux: Flux scaling factor for the injected planet.
       pos: Target position (y, x) in pixels.
       order: Interpolation order for sub-pixel shifting (default: 3 = cubic).

   Returns:
       Image with injected planet, same shape as input.


.. py:function:: make_aperture_kernel(radius, soft = False, sharpness = 10.0)

   Create a circular aperture kernel for convolution-based photometry.

   This function is cached (up to 32 configurations) to avoid repeated
   kernel creation overhead when called with the same parameters.

   Args:
       radius: Aperture radius in pixels.
       soft: If True, use sigmoid-based soft edge for differentiability.
             If False, use hard binary mask.
       sharpness: Steepness of sigmoid transition (only used if soft=True).

   Returns:
       2D aperture kernel, not normalized: convolving with it sums the counts
       inside the aperture rather than averaging them.
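
   A sketch of the hard-edged case with caching, assuming the cache behaves
   like ``functools.lru_cache(maxsize=32)``; the helper name and the grid
   sizing rule are hypothetical::

       from functools import lru_cache

       import numpy as np

       @lru_cache(maxsize=32)
       def hard_aperture_kernel(radius):
           # Odd-sized square grid just large enough to contain the aperture.
           half = int(np.ceil(radius)) + 1
           yy, xx = np.mgrid[-half:half + 1, -half:half + 1]
           # Binary, unnormalized weights: convolving sums the enclosed counts.
           return (np.hypot(yy, xx) <= radius).astype(float)

       k1 = hard_aperture_kernel(2.0)
       k2 = hard_aperture_kernel(2.0)  # second call is served from the cache

   Because a cached array is shared between callers, it should not be
   modified in place.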


.. py:function:: make_simple_disk(shape, radius, inclination_deg, width, flux = 1.0, pa_deg = 0.0)

   Generate a simple, optically thin Gaussian ring/disk.

   Analytically projects the disk to avoid interpolation artifacts.
   Flux is normalized so the total integrated flux equals the specified value.

   Args:
       shape: Output image shape (ny, nx).
       radius: Ring radius in pixels.
       inclination_deg: Disk inclination (0 = face-on, 90 = edge-on).
       width: Gaussian width (sigma) of the ring in pixels.
       flux: Total integrated flux of the disk. Default 1.0.
       pa_deg: Position angle of major axis, measured East of North (degrees).

   Returns:
       2D disk image with total flux normalized to the specified value.
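
   A NumPy sketch of an analytically projected Gaussian ring. The
   orientation convention (North = +y, East = -x, PA measured from North
   toward East) and the clipping of cos(i) for edge-on disks are
   assumptions, and the helper name is hypothetical::

       import numpy as np

       def simple_ring(shape, radius, inclination_deg, width, flux=1.0, pa_deg=0.0):
           ny, nx = shape
           yy, xx = np.indices(shape, dtype=float)
           dy = yy - (ny - 1) / 2
           dx = xx - (nx - 1) / 2
           pa = np.deg2rad(pa_deg)
           # Coordinates along the major and minor axes (assumed orientation).
           along_major = dy * np.cos(pa) - dx * np.sin(pa)
           along_minor = dy * np.sin(pa) + dx * np.cos(pa)
           # Deproject the minor-axis coordinate; clip cos(i) so edge-on stays finite.
           cos_i = max(np.cos(np.deg2rad(inclination_deg)), 1e-3)
           r_disk = np.hypot(along_major, along_minor / cos_i)
           ring = np.exp(-0.5 * ((r_disk - radius) / width) ** 2)
           # Normalize so the summed image equals the requested total flux.
           return flux * ring / ring.sum()

       disk = simple_ring((101, 101), radius=30.0, inclination_deg=60.0, width=2.0, flux=2.0)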


.. py:function:: masked_mean(values, mask)

   Compute the mean of masked values.

   Args:
       values: 1D array of values.
       mask: Boolean mask (True for valid values).

   Returns:
       Mean of valid values, or 0 if no valid values.


.. py:function:: masked_std(values, mask, mean = None)

   Compute the standard deviation of masked values.

   Uses Bessel's correction (N-1 denominator) for unbiased estimation.

   Args:
       values: 1D array of values.
       mask: Boolean mask (True for valid values).
       mean: Pre-computed mean. If None, computed from masked values.

   Returns:
       Standard deviation of valid values.
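
   A NumPy sketch of both masked statistics in a JIT-friendly style (no
   boolean indexing, so array shapes stay static). The helper names are
   hypothetical, and behavior with fewer than two valid values is not
   specified here::

       import numpy as np

       def masked_mean_np(values, mask):
           n = mask.sum()
           total = np.where(mask, values, 0.0).sum()
           # An empty mask gives total == 0, so dividing by max(n, 1) returns 0.
           return total / np.maximum(n, 1)

       def masked_std_np(values, mask, mean=None):
           if mean is None:
               mean = masked_mean_np(values, mask)
           n = mask.sum()
           sq_dev = np.where(mask, (values - mean) ** 2, 0.0).sum()
           # Bessel's correction: divide by N - 1.
           return np.sqrt(sq_dev / (n - 1))

       values = np.array([1.0, 2.0, 3.0, 100.0])
       mask = np.array([True, True, True, False])
       # mean 2.0, std sqrt((1 + 0 + 1) / 2) = 1.0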


.. py:function:: pca_subtract(science_image, basis, mean_ref)

   Project science image onto PCA basis and subtract the model.

   Args:
       science_image: 2D science image, shape (Height, Width).
       basis: PCA basis from get_pca_basis, shape (n_modes, Height*Width).
       mean_ref: Mean reference from get_pca_basis, shape (Height, Width).

   Returns:
       Residual image with PSF model subtracted, shape (Height, Width).
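
   A NumPy sketch of the projection step, assuming the basis rows are
   orthonormal. The helper name and the random QR-based test basis are
   hypothetical; a science frame lying entirely in the modeled subspace
   leaves a near-zero residual::

       import numpy as np

       def pca_subtract_np(science_image, basis, mean_ref):
           # Flatten the mean-subtracted science frame to a vector of P pixels.
           centered = (science_image - mean_ref).ravel()
           # Project onto the orthonormal modes and rebuild the PSF model.
           model = basis.T @ (basis @ centered)
           return (centered - model).reshape(science_image.shape)

       rng = np.random.default_rng(2)
       q, _ = np.linalg.qr(rng.normal(size=(64, 3)))
       basis = q.T                       # 3 orthonormal modes over an 8x8 image
       mean_ref = rng.normal(size=(8, 8))
       science = mean_ref + (basis.T @ np.array([2.0, -1.0, 0.5])).reshape(8, 8)
       residual = pca_subtract_np(science, basis, mean_ref)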


.. py:function:: radial_distance(shape, center = None)

   Calculate radial distance from center for each pixel.

   Args:
       shape: Image shape (ny, nx).
       center: Center coordinates (cy, cx). If None, uses image center.

   Returns:
       2D array of radial distances in pixels.
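
   A NumPy sketch, assuming the default center follows the (N-1)/2
   convention described for ``get_center``; the helper name is
   hypothetical::

       import numpy as np

       def radial_distance_np(shape, center=None):
           if center is None:
               # Geometric center: (N - 1) / 2 along each axis.
               center = ((shape[0] - 1) / 2, (shape[1] - 1) / 2)
           yy, xx = np.indices(shape)
           return np.hypot(yy - center[0], xx - center[1])

       r_odd = radial_distance_np((5, 5))   # center at (2, 2)
       r_even = radial_distance_np((4, 4))  # center at (1.5, 1.5)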


.. py:function:: small_sample_penalty(n)

   Compute the Mawet et al. (2014) small-sample statistics correction.

   At small angular separations, fewer reference apertures are available,
   so the empirical noise estimate is itself uncertain. This penalty factor
   inflates the noise to account for that additional uncertainty.

   Reference: Mawet et al. (2014), ApJ, 792, 97, Equation 9:
       sigma_corrected = sigma * sqrt(1 + 1/n)

   Args:
       n: Number of reference apertures.

   Returns:
       Correction factor sqrt(1 + 1/n).
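
   How the penalty behaves with the number of reference apertures, as a
   short NumPy sketch (hypothetical helper name): it is large when only a
   few apertures fit close to the star and tends to 1 far out::

       import numpy as np

       def small_sample_penalty_np(n):
           # Noise inflation factor for n reference apertures.
           return np.sqrt(1.0 + 1.0 / np.asarray(n, dtype=float))

       penalties = small_sample_penalty_np([1, 4, 100])
       # approximately [1.414, 1.118, 1.005]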


.. py:function:: soft_aperture_mask(shape, center, radius, sharpness = 10.0)

   Create a soft (differentiable) circular aperture mask.

   Uses a sigmoid function to create a smooth transition at the aperture
   edge, enabling gradient-based optimization through the mask.

   Args:
       shape: Image shape (ny, nx).
       center: Center of aperture (y, x) in pixels.
       radius: Radius of aperture in pixels.
       sharpness: Steepness of the sigmoid transition.

   Returns:
       Soft mask array with values in [0, 1].
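
   A NumPy sketch assuming a logistic edge of the form
   1 / (1 + exp(sharpness * (r - radius))); the library's exact sigmoid
   parameterization is an assumption, and the helper name is hypothetical::

       import numpy as np

       def soft_mask_np(shape, center, radius, sharpness=10.0):
           yy, xx = np.indices(shape)
           r = np.hypot(yy - center[0], xx - center[1])
           # Exactly 0.5 at r == radius, tending to 1 inside and 0 outside.
           return 1.0 / (1.0 + np.exp(sharpness * (r - radius)))

       m = soft_mask_np((11, 11), (5, 5), radius=3.0)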


.. py:function:: subtract_disk(residual, disk_model, scale = 1.0)

   Subtract disk model from residual image.

   Disk subtraction is typically a separate modeling task from stellar
   speckle subtraction. Call this after subtract_star when analyzing
   systems with circumstellar disks.

   Args:
       residual: Image after stellar subtraction (from subtract_star).
       disk_model: Disk model expectation (electrons).
       scale: Multiplicative scaling factor for the disk model.
           Adjust when disk model brightness doesn't match observation.

   Returns:
       Residual image with disk contribution removed.

   Example::

       # Two-step subtraction
       residual = subtract_star(observation, star_model)
       residual = subtract_disk(residual, disk_model)


.. py:function:: subtract_star(science, star_model, scale = 1.0)

   Subtract stellar PSF model from observation.

   This is the fundamental operation for "perfect" RDI when you have
   a noiseless stellar PSF expectation (e.g., from coronagraphoto).

   Args:
       science: Observed image (electrons).
       star_model: Noiseless stellar PSF expectation (electrons).
       scale: Multiplicative scaling factor for the model before subtraction.
           Use values != 1.0 when the reference brightness differs from
           the science image (e.g., different exposure times or stellar flux).

   Returns:
       Residual image containing noise + planet signal.

   Example::

       residual = subtract_star(observation, star_expectation)
       # With scaling for brightness mismatch:
       residual = subtract_star(observation, star_expectation, scale=0.95)


.. py:class:: SNREstimator(fwhm, soft = True, sharpness = 10.0, fast = False, max_apertures = 200, exclusion_buffer = 0.5)

   Bases: :py:obj:`equinox.Module`


   Standard SNR estimator implementing Mawet et al. (2014).

   Pre-computes the aperture kernel to allow efficient JIT-compilation
   in iterative pipelines. This class is an Equinox module, meaning it
   can be passed into JIT-compiled functions as a PyTree.

   The SNR is calculated using small-sample statistics correction:
       SNR = (x_planet - x_bg_mean) / (sigma_bg * sqrt(1 + 1/n_bg))

   Example::

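       # High-performance pipeline usage
       import jax
       import jax.numpy as jnp
       from coronalyze import snr_estimator

       estimator = snr_estimator(fwhm=4.0, fast=True)

       @jax.jit
       def process_cube(images, positions):
           return jax.vmap(lambda img: estimator(img, positions))(images)

       # Placeholder inputs: (n_frames, ny, nx) cube and (N, 2) (y, x) positions
       image_cube = jax.random.normal(jax.random.PRNGKey(0), (10, 101, 101))
       planet_positions = jnp.array([[50.0, 70.0]])
       snrs = process_cube(image_cube, planet_positions)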

   Reference:
       Mawet et al. (2014), ApJ, 792, 97


   .. py:attribute:: kernel
      :type:  jax.numpy.ndarray


   .. py:attribute:: fwhm
      :type:  float


   .. py:attribute:: exclusion_buffer
      :type:  float


   .. py:attribute:: max_apertures
      :type:  int


   .. py:attribute:: order
      :type:  int


   .. py:method:: __call__(image, positions, validity_map = None)

      Calculate SNR for a list of candidate positions.

      Args:
          image: 2D science image.
          positions: (N, 2) array of (y, x) coordinates.
          validity_map: Optional 2D mask (1=valid, 0=invalid). Used to exclude
              known companions, bad pixels, or edge regions. Off-chip apertures
              are automatically excluded via boundary handling.

      Returns:
          (N,) array of SNR values.



   .. py:method:: map(image, validity_map = None)

      Generate a full SNR detection map for the image.

      This is computationally expensive, since the SNR is evaluated at every
      pixel (O(N²) evaluations for an N×N image), but it is useful for
      generating detection maps.

      Args:
          image: 2D science image.
          validity_map: Optional 2D mask (1=valid, 0=invalid).

      Returns:
          2D array of SNR values matching image shape.



.. py:function:: calculate_ccd_snr(signal, background_noise, read_noise = 0.0, dark_current = 0.0)

   Calculate signal-to-noise ratio using the CCD equation.

   Uses the standard CCD equation for SNR:
       SNR = S / sqrt(S + B + R^2 + D)

   where:
       S = signal (electrons)
       B = background counts (electrons), whose Poisson variance adds to the noise
       R = read noise (electrons)
       D = dark current (electrons)

   This is distinct from Mawet SNR which uses spatial aperture statistics.

   Args:
       signal: Source signal in electrons.
       background_noise: Background counts in electrons (from sky, zodi,
           etc.); enters the variance linearly as B.
       read_noise: Read noise in electrons (per pixel, summed over aperture).
       dark_current: Dark current in electrons (summed over aperture).

   Returns:
       Signal-to-noise ratio.
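
   The CCD equation above as a short Python sketch (hypothetical function
   name), with every term expressed as a variance in electrons^2::

       import math

       def ccd_snr(signal, background, read_noise=0.0, dark_current=0.0):
           # Shot noise of S, B and D plus the read-noise variance R^2.
           variance = signal + background + read_noise**2 + dark_current
           return signal / math.sqrt(variance)

       snr_val = ccd_snr(100.0, 0.0)                                        # 100 / sqrt(100)
       snr_val2 = ccd_snr(100.0, 44.0, read_noise=5.0, dark_current=6.0)  # 100 / sqrt(175)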


.. py:function:: exposure_time_for_snr(target_snr, signal_rate, background_rate, read_noise = 0.0, dark_rate = 0.0)

   Calculate exposure time needed to achieve a target SNR.

   Solves the CCD equation for exposure time:
       SNR = S*t / sqrt(S*t + B*t + R^2 + D*t)

   This is a quadratic equation in t.

   Args:
       target_snr: Desired signal-to-noise ratio.
       signal_rate: Source signal rate in electrons/second.
       background_rate: Background rate in electrons/second.
       read_noise: Read noise in electrons (constant, not per second).
       dark_rate: Dark current rate in electrons/second.

   Returns:
       Required exposure time in seconds.
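
   Squaring the equation gives
   S^2 t^2 - SNR^2 (S + B + D) t - SNR^2 R^2 = 0, whose positive root is
   the required time. A Python sketch with a hypothetical function name::

       import math

       def exposure_time_for_snr_sketch(target_snr, signal_rate, background_rate,
                                        read_noise=0.0, dark_rate=0.0):
           # Quadratic coefficients of a*t^2 + b*t + c = 0.
           a = signal_rate**2
           b = -(target_snr**2) * (signal_rate + background_rate + dark_rate)
           c = -(target_snr**2) * read_noise**2
           # c <= 0, so the "+" root is the non-negative solution.
           return (-b + math.sqrt(b * b - 4 * a * c)) / (2 * a)

       t1 = exposure_time_for_snr_sketch(10.0, 100.0, 0.0)  # SNR^2 / S = 1.0 s
       t2 = exposure_time_for_snr_sketch(5.0, 2.0, 3.0, read_noise=4.0, dark_rate=1.0)
       # Check: S*t2 / sqrt((S + B + D)*t2 + R^2) = 80 / sqrt(256) = 5.0 when t2 = 40 s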


.. py:function:: snr(image, positions, fwhm, soft = True, sharpness = 10.0, fast = False, max_apertures = 200, exclusion_buffer = 0.5, validity_map = None)

   Calculate SNR at specific positions using Mawet et al. (2014).

   This is a convenience wrapper around snr_estimator() for simple use cases.
   For iterative pipelines, use snr_estimator() to avoid repeated kernel creation.

   Args:
       image: 2D science image.
       positions: (N, 2) array of (y, x) coordinates.
       fwhm: Full width at half maximum in pixels.
       soft: Use soft aperture edges.
       sharpness: Sigmoid sharpness for soft apertures.
       fast: Use bilinear interpolation for ~3x speedup.
       max_apertures: Maximum buffer size.
       exclusion_buffer: Gap between test and first reference aperture in
           units of angular step (default 0.5). Prevents PSF wing leakage.
       validity_map: Optional 2D mask (1=valid, 0=invalid) to exclude
           known companions, bad pixels, or edge regions.

   Returns:
       (N,) array of SNR values.

   Reference:
       Mawet et al. (2014), ApJ, 792, 97
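
   Once the planet and reference aperture fluxes have been measured, the
   Mawet et al. (2014) statistic reduces to a few lines. This hypothetical
   helper omits aperture placement and photometry and takes the fluxes
   directly::

       import numpy as np

       def mawet_snr(planet_flux, reference_fluxes):
           ref = np.asarray(reference_fluxes, dtype=float)
           n = ref.size
           # Mean and Bessel-corrected scatter of the reference apertures.
           mean_bg = ref.mean()
           sigma_bg = ref.std(ddof=1)
           # Small-sample corrected SNR: divide by sigma * sqrt(1 + 1/n).
           return (planet_flux - mean_bg) / (sigma_bg * np.sqrt(1.0 + 1.0 / n))

       value = mawet_snr(10.0, [1.0, 2.0, 3.0, 2.0])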


.. py:function:: snr_estimator(fwhm, soft = True, sharpness = 10.0, fast = False, max_apertures = 200, exclusion_buffer = 0.5)

   Create a JIT-ready SNR estimator with pre-computed kernel.

   This factory creates an SNREstimator instance that can be efficiently
   used in JIT-compiled pipelines. The aperture kernel is computed once
   at build time.

   Args:
       fwhm: Full width at half maximum in pixels.
       soft: Use soft aperture edges for differentiability.
       sharpness: Sigmoid sharpness for soft apertures.
       fast: Use bilinear interpolation for ~3x speedup.
       max_apertures: Maximum buffer size for static shapes.
       exclusion_buffer: Gap between test and first reference aperture in
           units of angular step (default 0.5). Prevents PSF wing leakage.

   Returns:
       SNREstimator instance ready for use.

   Example::

       import jax
       from coronalyze import snr_estimator

       estimator = snr_estimator(fwhm=4.0, fast=True)

       @jax.jit
       def pipeline(images, positions):
           return jax.vmap(lambda img: estimator(img, positions))(images)


.. py:function:: snr_map(image, fwhm, soft = True, sharpness = 10.0, fast = False, max_apertures = 200, exclusion_buffer = 0.5)

   Generate a 2D map of SNR values using Mawet et al. (2014).

   Computes SNR at every pixel position. This is computationally expensive
   (O(N²) evaluations for an N×N image) but useful for generating detection maps.

   Args:
       image: 2D science image.
       fwhm: Full width at half maximum in pixels.
       soft: Use soft aperture edges.
       sharpness: Sigmoid sharpness.
       fast: Use bilinear interpolation for speedup.
       max_apertures: Maximum buffer size.
       exclusion_buffer: Gap between test and first reference aperture in
           units of angular step (default 0.5). Prevents PSF wing leakage.

   Returns:
       2D array of SNR values, same shape as input image.


.. py:function:: calculate_yield_snr(science, planet_positions, fwhm, star_model = None, disk_model = None, reference_cube = None, n_modes = 5, method = 'star', star_scale = 1.0, disk_scale = 1.0, exclusion_buffer = 0.5, validity_map = None)

   End-to-end yield SNR calculation.

   Convenience function that performs subtraction and SNR calculation
   in one call. Selects the appropriate subtraction method based on
   the `method` argument.

   Args:
       science: Observed image (electrons).
       planet_positions: Planet positions as (N, 2) array of (y, x) coords.
       fwhm: PSF FWHM in pixels.
       star_model: Noiseless star expectation (required for 'star' method).
       disk_model: Optional noiseless disk expectation.
       reference_cube: Reference library for KLIP (required for 'klip'/'rdi').
       n_modes: Number of PCA modes (for klip method only).
       method: Subtraction method - "star", "rdi", or "klip".
       star_scale: Scaling factor for star model (default 1.0).
       disk_scale: Scaling factor for disk model (default 1.0).
       exclusion_buffer: Gap between test and reference apertures in units
           of angular step (default 0.5). Prevents PSF wing leakage.
       validity_map: Optional 2D mask (1=valid, 0=invalid) to exclude
           known companions, bad pixels, or edge regions.

   Returns:
       SNR values for each planet position.

   Example::

       # Fast yield calculation with static PSF
       snrs = calculate_yield_snr(
           science_image,
           planet_positions,
           fwhm=4.5,
           star_model=star_expectation,
           disk_model=disk_expectation,
           method="star"
       )


.. py:function:: klip_subtract(science, reference_cube, n_modes = 5)

   KLIP/PCA PSF subtraction.

   Uses PCA to build a stellar PSF model from a reference cube
   (e.g., images at different roll angles) and subtracts it.

   This is the most physically realistic stellar subtraction mode
   but also the slowest.

   Args:
       science: Science observation (electrons).
       reference_cube: Reference library of shape (n_frames, ny, nx).
           Typically star-only images at different roll angles.
       n_modes: Number of PCA modes to use for subtraction.

   Returns:
       Residual image after KLIP subtraction.

   Example::

       # ADI with KLIP
       residual = klip_subtract(science, roll_cube, n_modes=10)
       # If you also need disk subtraction, do it after:
       residual = subtract_disk(residual, disk_model)


