coronalyze.core
===============

.. py:module:: coronalyze.core

.. autoapi-nested-parse::

   Core JAX-based analysis primitives for coronalyze.

   This module contains pure JAX mathematical functions with no external dependencies.
   All functions are JIT-compilable and differentiable.



Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/coronalyze/core/geometry/index
   /autoapi/coronalyze/core/matched_filter/index
   /autoapi/coronalyze/core/modeling/index
   /autoapi/coronalyze/core/pca/index
   /autoapi/coronalyze/core/photometry/index
   /autoapi/coronalyze/core/snr/index
   /autoapi/coronalyze/core/statistics/index


Classes
-------

.. autoapisummary::

   coronalyze.core.SNREstimator


Functions
---------

.. autoapisummary::

   coronalyze.core.calculate_n_apertures
   coronalyze.core.generate_aperture_coords
   coronalyze.core.get_center
   coronalyze.core.radial_distance
   coronalyze.core.inject_planet
   coronalyze.core.make_simple_disk
   coronalyze.core.subtract_disk
   coronalyze.core.subtract_star
   coronalyze.core.get_pca_basis
   coronalyze.core.pca_subtract
   coronalyze.core.aperture_photometry
   coronalyze.core.aperture_solid_angle
   coronalyze.core.circular_aperture_mask
   coronalyze.core.flux_map
   coronalyze.core.make_aperture_kernel
   coronalyze.core.soft_aperture_mask
   coronalyze.core.calculate_ccd_snr
   coronalyze.core.exposure_time_for_snr
   coronalyze.core.snr
   coronalyze.core.snr_estimator
   coronalyze.core.snr_map
   coronalyze.core.masked_mean
   coronalyze.core.masked_std
   coronalyze.core.small_sample_penalty


Package Contents
----------------

.. py:function:: calculate_n_apertures(radius, fwhm, exclusion_buffer = 0.5)

   Calculate the number of reference apertures at a given radius.

   Uses the Mawet et al. (2014) formula with exclusion buffer correction
   to ensure no overlap with the planet aperture on either side.

   This function provides the canonical calculation used by the SNR module,
   ensuring consistency between visualization and computation.

   Args:
       radius: Radial distance from center in pixels.
       fwhm: Full width at half maximum in pixels.
       exclusion_buffer: Gap between planet and first/last reference aperture
           in units of angular step (default 0.5). Creates a gap on both sides.

   Returns:
       Number of valid reference apertures.

   Example::

       from coronalyze.core.geometry import calculate_n_apertures
       n = calculate_n_apertures(radius=20, fwhm=5.0)
       print(f"{n} reference apertures at r=20px")


.. py:function:: generate_aperture_coords(center, radius, planet_angle, n_apertures, max_apertures = 200, fwhm = None, exclusion_buffer = 0.5)

   Generate coordinates for reference apertures at a given radius.

   Uses a fixed-size array with masking for JAX compatibility (static shapes).
   Apertures are distributed evenly around the annulus, excluding the planet position.

   Follows VIP's convention: apertures are placed clockwise from the planet
   position using VIP's angle formula.

   Args:
       center: Image center (cy, cx) in pixels.
       radius: Radial distance from center in pixels.
       planet_angle: Angle of the planet position in radians.
       n_apertures: Actual number of valid apertures to use.
       max_apertures: Maximum buffer size for static array shape.
       fwhm: Full width at half maximum for VIP-style angle calculation. If None,
             apertures are distributed uniformly.
       exclusion_buffer: Gap between test and first reference aperture in
           units of angular step (default 0.5). Prevents PSF wing leakage.

   Returns:
       Tuple of (y_coords, x_coords, mask) where:
           - y_coords: Y coordinates of aperture centers (size max_apertures)
           - x_coords: X coordinates of aperture centers (size max_apertures)
           - mask: Boolean mask indicating valid apertures
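
   The fixed-size-buffer-plus-mask pattern can be sketched with NumPy. This is
   a minimal illustration of the uniform-distribution case only, not the
   library's implementation: ``aperture_coords_sketch`` is a hypothetical name,
   the equal-step spacing and the sign used for "clockwise" are assumptions,
   and ``exclusion_buffer`` and the VIP angle formula are ignored.

   .. code-block:: python

      import numpy as np

      def aperture_coords_sketch(center, radius, planet_angle, n_apertures,
                                 max_apertures=200):
          cy, cx = center
          idx = np.arange(max_apertures)      # buffer size never changes
          mask = idx < n_apertures            # True only for slots in use
          # Assumption: planet + n_apertures references split the circle into
          # equal steps, so steps 1..n_apertures never land on the planet.
          step = 2.0 * np.pi / (n_apertures + 1)
          angles = planet_angle - (idx + 1) * step   # sign convention assumed
          y = cy + radius * np.sin(angles)
          x = cx + radius * np.cos(angles)
          return y, x, mask

      y, x, mask = aperture_coords_sketch(
          center=(50.0, 50.0), radius=10.0, planet_angle=0.0,
          n_apertures=5, max_apertures=8,
      )

   Because the output shape depends only on ``max_apertures``, the number of
   valid apertures can vary between calls without changing array shapes;
   downstream statistics then use ``mask`` to ignore the unused slots.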


.. py:function:: get_center(shape)

   Get the center coordinates of an image (0-indexed geometric center).

   For an axis of length N, the geometric center is at (N-1)/2.
   This matches the convention used in modeling.py for planet injection.

   Args:
       shape: Image shape (ny, nx).

   Returns:
       Center coordinates (cy, cx).
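
   The (N-1)/2 convention can be illustrated in plain Python. This is a
   stand-in for illustration, not the library's code; ``center_of`` is a
   hypothetical name.

   .. code-block:: python

      def center_of(shape):
          """Return the 0-indexed geometric center (cy, cx) of an (ny, nx) image."""
          ny, nx = shape
          return (ny - 1) / 2, (nx - 1) / 2

      # Odd-length axis: the center falls on a pixel.
      center_odd = center_of((5, 5))
      # Even-length axis: the center falls between two pixels.
      center_even = center_of((4, 6))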


.. py:function:: radial_distance(shape, center = None)

   Calculate radial distance from center for each pixel.

   Args:
       shape: Image shape (ny, nx).
       center: Center coordinates (cy, cx). If None, uses image center.

   Returns:
       2D array of radial distances in pixels.
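
   A minimal NumPy sketch of the same calculation follows. It assumes the
   default center is the (N-1)/2 geometric center described for
   ``get_center``; ``radial_distance_sketch`` is a hypothetical name and the
   library itself operates on JAX arrays.

   .. code-block:: python

      import numpy as np

      def radial_distance_sketch(shape, center=None):
          ny, nx = shape
          if center is None:
              center = ((ny - 1) / 2, (nx - 1) / 2)   # assumed default
          cy, cx = center
          y, x = np.indices(shape)                    # per-pixel row/column indices
          return np.hypot(y - cy, x - cx)

      r = radial_distance_sketch((5, 5))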


.. py:function:: inject_planet(image, psf_template, flux, pos, order = 3)

   Inject a fake planet into an image using spline-interpolated shifts (cubic by default).

   The PSF template is shifted from the image center to the target position
   with sub-pixel precision, scaled by the flux, and added to the image.

   Args:
       image: 2D image to inject into, shape (ny, nx).
       psf_template: 2D PSF template centered at image center, same shape as image.
       flux: Flux scaling factor for the injected planet.
       pos: Target position (y, x) in pixels.
       order: Interpolation order for sub-pixel shifting (default: 3 = cubic).

   Returns:
       Image with injected planet, same shape as input.


.. py:function:: make_simple_disk(shape, radius, inclination_deg, width, flux = 1.0, pa_deg = 0.0)

   Generate a simple, optically thin Gaussian ring/disk.

   Analytically projects the disk to avoid interpolation artifacts.
   Flux is normalized so the total integrated flux equals the specified value.

   Args:
       shape: Output image shape (ny, nx).
       radius: Ring radius in pixels.
       inclination_deg: Disk inclination (0 = face-on, 90 = edge-on).
       width: Gaussian width (sigma) of the ring in pixels.
       flux: Total integrated flux of the disk. Default 1.0.
       pa_deg: Position angle of major axis, measured East of North (degrees).

   Returns:
       2D disk image with total flux normalized to the specified value.


.. py:function:: subtract_disk(residual, disk_model, scale = 1.0)

   Subtract disk model from residual image.

   Disk subtraction is typically a separate modeling task from stellar
   speckle subtraction. Call this after subtract_star when analyzing
   systems with circumstellar disks.

   Args:
       residual: Image after stellar subtraction (from subtract_star).
       disk_model: Disk model expectation (electrons).
       scale: Multiplicative scaling factor for the disk model.
           Adjust when disk model brightness doesn't match observation.

   Returns:
       Residual image with disk contribution removed.

   Example::

       # Two-step subtraction
       residual = subtract_star(observation, star_model)
       residual = subtract_disk(residual, disk_model)


.. py:function:: subtract_star(science, star_model, scale = 1.0)

   Subtract stellar PSF model from observation.

   This is the fundamental operation for "perfect" RDI when you have
   a noiseless stellar PSF expectation (e.g., from coronagraphoto).

   Args:
       science: Observed image (electrons).
       star_model: Noiseless stellar PSF expectation (electrons).
       scale: Multiplicative scaling factor for the model before subtraction.
           Use values != 1.0 when the reference brightness differs from
           the science image (e.g., different exposure times or stellar flux).

   Returns:
       Residual image containing noise + planet signal.

   Example::

       residual = subtract_star(observation, star_expectation)
       # With scaling for brightness mismatch:
       residual = subtract_star(observation, star_expectation, scale=0.95)


.. py:function:: get_pca_basis(ref_cube, n_modes)

   Compute PCA basis vectors from a reference cube (Snapshot Method).

   Uses eigendecomposition of the N×N covariance matrix instead of a full SVD,
   which costs O(N² × P) rather than O(P² × N) when pixels P >> frames N.

   Args:
       ref_cube: Reference image cube of shape (N_frames, Height, Width).
       n_modes: Number of principal components (modes) to keep.

   Returns:
       basis: The top n_modes eigen-images, shape (n_modes, Height*Width).
       mean_ref: The mean of the reference cube, shape (Height, Width).
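
   The snapshot method can be sketched with NumPy as below. This is an
   illustration under stated assumptions rather than the library's code:
   ``pca_basis_snapshot`` is a hypothetical name, and subtracting the mean
   frame and normalizing each eigen-image to unit length are assumed details.

   .. code-block:: python

      import numpy as np

      def pca_basis_snapshot(ref_cube, n_modes):
          n_frames, height, width = ref_cube.shape
          mean_ref = ref_cube.mean(axis=0)
          X = (ref_cube - mean_ref).reshape(n_frames, height * width)  # (N, P)
          cov = X @ X.T                                # small (N, N) matrix
          eigvals, eigvecs = np.linalg.eigh(cov)       # eigenvalues ascending
          top = np.argsort(eigvals)[::-1][:n_modes]    # indices of largest modes
          basis = eigvecs[:, top].T @ X                # back to pixel space (n_modes, P)
          basis /= np.linalg.norm(basis, axis=1, keepdims=True)
          return basis, mean_ref

      rng = np.random.default_rng(0)
      ref_cube = rng.normal(size=(6, 8, 8))
      basis, mean_ref = pca_basis_snapshot(ref_cube, n_modes=3)

   The eigen-images obtained this way are mutually orthogonal because the
   frame-space eigenvectors diagonalize ``X @ X.T``.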


.. py:function:: pca_subtract(science_image, basis, mean_ref)

   Project science image onto PCA basis and subtract the model.

   Args:
       science_image: 2D science image, shape (Height, Width).
       basis: PCA basis from get_pca_basis, shape (n_modes, Height*Width).
       mean_ref: Mean reference from get_pca_basis, shape (Height, Width).

   Returns:
       Residual image with PSF model subtracted, shape (Height, Width).
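
   The projection step can be sketched as follows, assuming the basis rows
   are orthonormal and that the mean reference is removed before projecting.
   ``pca_subtract_sketch`` and the QR-built basis are hypothetical and used
   only for illustration.

   .. code-block:: python

      import numpy as np

      def pca_subtract_sketch(science_image, basis, mean_ref):
          height, width = science_image.shape
          centered = (science_image - mean_ref).ravel()   # (P,)
          coeffs = basis @ centered                       # one coefficient per mode
          model = coeffs @ basis                          # PSF model in pixel space
          return (centered - model).reshape(height, width)

      rng = np.random.default_rng(1)
      # Hypothetical orthonormal basis of 3 modes for an 8x8 image.
      q, _ = np.linalg.qr(rng.normal(size=(64, 3)))
      basis = q.T
      mean_ref = np.zeros((8, 8))
      # An image lying entirely in the span of the basis is removed completely.
      science = (np.array([2.0, -1.0, 0.5]) @ basis).reshape(8, 8)
      residual = pca_subtract_sketch(science, basis, mean_ref)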


.. py:function:: aperture_photometry(image, center, radius)

   Perform circular aperture photometry on an image.

   Args:
       image: 2D image array.
       center: Center of aperture (y, x) in pixels.
       radius: Radius of aperture in pixels.

   Returns:
       Total flux within the aperture.


.. py:function:: aperture_solid_angle(radius_pixels, pixel_scale_arcsec)

   Calculate the solid angle of a circular aperture.

   Args:
       radius_pixels: Aperture radius in pixels.
       pixel_scale_arcsec: Pixel scale in arcseconds per pixel.

   Returns:
       Solid angle in arcsec^2.
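
   For a circular aperture this presumably reduces to the area of a circle
   whose radius is expressed in arcseconds. The sketch below makes that
   assumption explicit; ``aperture_solid_angle_sketch`` is a hypothetical
   name.

   .. code-block:: python

      import math

      def aperture_solid_angle_sketch(radius_pixels, pixel_scale_arcsec):
          radius_arcsec = radius_pixels * pixel_scale_arcsec   # pixels -> arcsec
          return math.pi * radius_arcsec ** 2                  # arcsec^2

      # A 2-pixel radius at 0.5 arcsec/pixel is a 1-arcsec radius.
      omega = aperture_solid_angle_sketch(radius_pixels=2.0, pixel_scale_arcsec=0.5)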


.. py:function:: circular_aperture_mask(shape, center, radius)

   Create a circular aperture mask.

   Args:
       shape: Image shape (ny, nx).
       center: Center of aperture (y, x) in pixels.
       radius: Radius of aperture in pixels.

   Returns:
       Boolean mask array with True inside the aperture.
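
   A hard-edged mask and the corresponding aperture sum can be sketched with
   NumPy. Whether pixels exactly on the boundary count as inside is an
   assumption here (``<=``), and ``circular_mask_sketch`` is a hypothetical
   name.

   .. code-block:: python

      import numpy as np

      def circular_mask_sketch(shape, center, radius):
          cy, cx = center
          y, x = np.indices(shape)
          return np.hypot(y - cy, x - cx) <= radius   # boundary inclusion assumed

      image = np.ones((9, 9))
      mask = circular_mask_sketch(image.shape, center=(4, 4), radius=1.0)
      # Aperture photometry is then the sum of the pixels selected by the mask.
      flux = image[mask].sum()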


.. py:function:: flux_map(image, kernel)

   Generate a flux map via convolution with an aperture kernel.

   Each pixel in the output represents the integrated flux that would be
   measured if an aperture were centered at that position.

   Args:
       image: 2D image array.
       kernel: Aperture kernel from make_aperture_kernel.

   Returns:
       2D flux map with same shape as input image.


.. py:function:: make_aperture_kernel(radius, soft = False, sharpness = 10.0)

   Create a circular aperture kernel for convolution-based photometry.

   This function is cached (up to 32 configurations) to avoid repeated
   kernel creation overhead when called with the same parameters.

   Args:
       radius: Aperture radius in pixels.
       soft: If True, use sigmoid-based soft edge for differentiability.
             If False, use hard binary mask.
       sharpness: Steepness of sigmoid transition (only used if soft=True).

   Returns:
       2D aperture kernel. The kernel is not normalized, so convolving with it
       yields summed counts rather than an average.


.. py:function:: soft_aperture_mask(shape, center, radius, sharpness = 10.0)

   Create a soft (differentiable) circular aperture mask.

   Uses a sigmoid function to create a smooth transition at the aperture
   edge, enabling gradient-based optimization through the mask.

   Args:
       shape: Image shape (ny, nx).
       center: Center of aperture (y, x) in pixels.
       radius: Radius of aperture in pixels.
       sharpness: Steepness of the sigmoid transition.

   Returns:
       Soft mask array with values in [0, 1].
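
   One way to build such a mask is a logistic function of the signed distance
   to the aperture edge. The exact functional form used by the library is not
   stated here, so the expression below is an assumption and
   ``soft_mask_sketch`` is a hypothetical name.

   .. code-block:: python

      import numpy as np

      def soft_mask_sketch(shape, center, radius, sharpness=10.0):
          cy, cx = center
          y, x = np.indices(shape)
          r = np.hypot(y - cy, x - cx)
          # ~1 well inside the aperture, ~0 well outside, 0.5 exactly at r == radius
          return 1.0 / (1.0 + np.exp(-sharpness * (radius - r)))

      soft = soft_mask_sketch((11, 11), center=(5, 5), radius=2.5)

   Larger ``sharpness`` values approach the hard binary mask, while smaller
   values spread the transition over more pixels and give smoother gradients.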


.. py:class:: SNREstimator(fwhm, soft = True, sharpness = 10.0, fast = False, max_apertures = 200, exclusion_buffer = 0.5)

   Bases: :py:obj:`equinox.Module`


   Standard SNR estimator implementing Mawet et al. (2014).

   Pre-computes the aperture kernel to allow efficient JIT-compilation
   in iterative pipelines. This class is an Equinox module, meaning it
   can be passed into JIT-compiled functions as a PyTree.

   The SNR is computed with the small-sample statistics correction::

       SNR = (x_planet - x_bg_mean) / (sigma_bg * sqrt(1 + 1/n_bg))
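
   The formula can be evaluated directly from a test-aperture flux and a set
   of reference-aperture fluxes. The sketch below uses NumPy and assumes the
   background standard deviation is the Bessel-corrected sample value;
   ``mawet_snr_sketch`` is a hypothetical name, not part of the package.

   .. code-block:: python

      import numpy as np

      def mawet_snr_sketch(x_planet, x_bg):
          x_bg = np.asarray(x_bg, dtype=float)
          n_bg = x_bg.size
          sigma_bg = x_bg.std(ddof=1)              # sample std (N-1 denominator)
          penalty = np.sqrt(1.0 + 1.0 / n_bg)      # small-sample correction
          return (x_planet - x_bg.mean()) / (sigma_bg * penalty)

      background = [1.0, 2.0, 3.0, 4.0, 5.0]
      value = mawet_snr_sketch(x_planet=13.0, x_bg=background)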

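   Example::

       import jax
       from coronalyze.core import snr_estimator

       # High-performance pipeline usage
       estimator = snr_estimator(fwhm=4.0, fast=True)

       @jax.jit
       def process_cube(images, positions):
           # Map the estimator over the leading (frame) axis of the cube
           return jax.vmap(lambda img: estimator(img, positions))(images)

       # image_cube: stack of 2D images; planet_positions: (N, 2) array of (y, x)
       snrs = process_cube(image_cube, planet_positions)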

   Reference:
       Mawet et al. (2014), ApJ, 792, 97


   .. py:attribute:: kernel
      :type:  jax.numpy.ndarray


   .. py:attribute:: fwhm
      :type:  float


   .. py:attribute:: exclusion_buffer
      :type:  float


   .. py:attribute:: max_apertures
      :type:  int


   .. py:attribute:: order
      :type:  int


   .. py:method:: __call__(image, positions, validity_map = None)

      Calculate SNR for a list of candidate positions.

      Args:
          image: 2D science image.
          positions: (N, 2) array of (y, x) coordinates.
          validity_map: Optional 2D mask (1=valid, 0=invalid). Used to exclude
              known companions, bad pixels, or edge regions. Off-chip apertures
              are automatically excluded via boundary handling.

      Returns:
          (N,) array of SNR values.



   .. py:method:: map(image, validity_map = None)

      Generate a full SNR detection map for the image.

      This is computationally expensive, O(N²), but useful for
      generating detection maps.

      Args:
          image: 2D science image.
          validity_map: Optional 2D mask (1=valid, 0=invalid).

      Returns:
          2D array of SNR values matching image shape.



.. py:function:: calculate_ccd_snr(signal, background_noise, read_noise = 0.0, dark_current = 0.0)

   Calculate signal-to-noise ratio using the CCD equation.

   Uses the standard CCD equation for SNR::

       SNR = S / sqrt(S + B + R^2 + D)

   where:

   - S = signal (electrons)
   - B = background noise (electrons)
   - R = read noise (electrons)
   - D = dark current (electrons)

   This is distinct from Mawet SNR which uses spatial aperture statistics.

   Args:
       signal: Source signal in electrons.
       background_noise: Background noise in electrons (from sky, zodi, etc.).
       read_noise: Read noise in electrons (per pixel, summed over aperture).
       dark_current: Dark current in electrons (summed over aperture).

   Returns:
       Signal-to-noise ratio.
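
   The equation is simple enough to check by hand. The plain-Python sketch
   below mirrors it term by term; ``ccd_snr_sketch`` is a hypothetical name
   used for illustration only.

   .. code-block:: python

      import math

      def ccd_snr_sketch(signal, background, read_noise=0.0, dark=0.0):
          # Signal, background, and dark counts add as variances;
          # read noise is a standard deviation and so enters squared.
          return signal / math.sqrt(signal + background + read_noise ** 2 + dark)

      # 100 / sqrt(100 + 40 + 3**2 + 20) = 100 / sqrt(169) = 100 / 13
      value = ccd_snr_sketch(signal=100.0, background=40.0, read_noise=3.0, dark=20.0)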


.. py:function:: exposure_time_for_snr(target_snr, signal_rate, background_rate, read_noise = 0.0, dark_rate = 0.0)

   Calculate exposure time needed to achieve a target SNR.

   Solves the CCD equation for exposure time::

       SNR = S*t / sqrt(S*t + B*t + R^2 + D*t)

   This is a quadratic equation in t.

   Args:
       target_snr: Desired signal-to-noise ratio.
       signal_rate: Source signal rate in electrons/second.
       background_rate: Background rate in electrons/second.
       read_noise: Read noise in electrons (constant, not per second).
       dark_rate: Dark current rate in electrons/second.

   Returns:
       Required exposure time in seconds.
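
   Squaring the equation and collecting terms gives
   ``S^2 * t^2 - SNR^2 * (S + B + D) * t - SNR^2 * R^2 = 0``, whose positive
   root is the required exposure time. A plain-Python sketch of that solution
   follows; ``exposure_time_sketch`` is a hypothetical name, and the library's
   own implementation may differ in detail.

   .. code-block:: python

      import math

      def exposure_time_sketch(target_snr, signal_rate, background_rate,
                               read_noise=0.0, dark_rate=0.0):
          snr2 = target_snr ** 2
          a = signal_rate ** 2
          b = -snr2 * (signal_rate + background_rate + dark_rate)
          c = -snr2 * read_noise ** 2
          # c <= 0, so the discriminant is non-negative and the "+" root is
          # the non-negative one.
          return (-b + math.sqrt(b * b - 4.0 * a * c)) / (2.0 * a)

      t = exposure_time_sketch(target_snr=5.0, signal_rate=2.0,
                               background_rate=1.0, read_noise=3.0, dark_rate=0.5)
      # Substitute t back into the CCD equation as a consistency check.
      achieved = 2.0 * t / math.sqrt((2.0 + 1.0 + 0.5) * t + 3.0 ** 2)

   With zero read noise the quadratic collapses to
   ``t = SNR^2 * (S + B + D) / S^2``.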


.. py:function:: snr(image, positions, fwhm, soft = True, sharpness = 10.0, fast = False, max_apertures = 200, exclusion_buffer = 0.5, validity_map = None)

   Calculate SNR at specific positions using Mawet et al. (2014).

   This is a convenience wrapper around snr_estimator() for simple use cases.
   For iterative pipelines, use snr_estimator() to avoid repeated kernel creation.

   Args:
       image: 2D science image.
       positions: (N, 2) array of (y, x) coordinates.
       fwhm: Full width at half maximum in pixels.
       soft: Use soft aperture edges.
       sharpness: Sigmoid sharpness for soft apertures.
       fast: Use bilinear interpolation for ~3x speedup.
       max_apertures: Maximum buffer size.
       exclusion_buffer: Gap between test and first reference aperture in
           units of angular step (default 0.5). Prevents PSF wing leakage.
       validity_map: Optional 2D mask (1=valid, 0=invalid) to exclude
           known companions, bad pixels, or edge regions.

   Returns:
       (N,) array of SNR values.

   Reference:
       Mawet et al. (2014), ApJ, 792, 97


.. py:function:: snr_estimator(fwhm, soft = True, sharpness = 10.0, fast = False, max_apertures = 200, exclusion_buffer = 0.5)

   Create a JIT-ready SNR estimator with pre-computed kernel.

   This factory creates an SNREstimator instance that can be efficiently
   used in JIT-compiled pipelines. The aperture kernel is computed once
   at build time.

   Args:
       fwhm: Full width at half maximum in pixels.
       soft: Use soft aperture edges for differentiability.
       sharpness: Sigmoid sharpness for soft apertures.
       fast: Use bilinear interpolation for ~3x speedup.
       max_apertures: Maximum buffer size for static shapes.
       exclusion_buffer: Gap between test and first reference aperture in
           units of angular step (default 0.5). Prevents PSF wing leakage.

   Returns:
       SNREstimator instance ready for use.

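   Example::

       import jax
       from coronalyze.core import snr_estimator

       estimator = snr_estimator(fwhm=4.0, fast=True)

       @jax.jit
       def pipeline(images, positions):
           # Evaluate the same positions in every image of the stack
           return jax.vmap(lambda img: estimator(img, positions))(images)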


.. py:function:: snr_map(image, fwhm, soft = True, sharpness = 10.0, fast = False, max_apertures = 200, exclusion_buffer = 0.5)

   Generate a 2D map of SNR values using Mawet et al. (2014).

   Computes SNR at every pixel position. This is computationally expensive,
   O(N²), but useful for generating detection maps.

   Args:
       image: 2D science image.
       fwhm: Full width at half maximum in pixels.
       soft: Use soft aperture edges.
       sharpness: Sigmoid sharpness.
       fast: Use bilinear interpolation for speedup.
       max_apertures: Maximum buffer size.
       exclusion_buffer: Gap between test and first reference aperture in
           units of angular step (default 0.5). Prevents PSF wing leakage.

   Returns:
       2D array of SNR values, same shape as input image.


.. py:function:: masked_mean(values, mask)

   Compute the mean of masked values.

   Args:
       values: 1D array of values.
       mask: Boolean mask (True for valid values).

   Returns:
       Mean of valid values, or 0 if no valid values.


.. py:function:: masked_std(values, mask, mean = None)

   Compute the standard deviation of masked values.

   Uses Bessel's correction (N-1 denominator) for unbiased estimation.

   Args:
       values: 1D array of values.
       mask: Boolean mask (True for valid values).
       mean: Pre-computed mean. If None, computed from masked values.

   Returns:
       Standard deviation of valid values.
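
   Masked statistics of this kind are usually written with ``where`` rather
   than boolean indexing, so that array shapes stay fixed under JIT
   compilation. The NumPy sketch below illustrates the idea for both
   ``masked_mean`` and ``masked_std``; the function names are hypothetical,
   and the guard used when fewer than two values are valid is an assumption.

   .. code-block:: python

      import numpy as np

      def masked_mean_sketch(values, mask):
          n = mask.sum()
          total = np.where(mask, values, 0.0).sum()   # invalid entries contribute 0
          return total / np.maximum(n, 1)             # 0 when nothing is valid

      def masked_std_sketch(values, mask, mean=None):
          if mean is None:
              mean = masked_mean_sketch(values, mask)
          n = mask.sum()
          sq = np.where(mask, (values - mean) ** 2, 0.0).sum()
          return np.sqrt(sq / np.maximum(n - 1, 1))   # Bessel's correction

      values = np.array([1.0, 2.0, 3.0, 4.0, 100.0])
      mask = np.array([True, True, True, True, False])   # last entry is ignored
      mean = masked_mean_sketch(values, mask)
      std = masked_std_sketch(values, mask)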


.. py:function:: small_sample_penalty(n)

   Compute the Mawet et al. (2014) small-sample statistics correction.

   At small angular separations, fewer reference apertures are available,
   so the empirical noise estimate is itself more uncertain. This penalty
   factor inflates the noise to account for that additional uncertainty.

   Reference:
       Mawet et al. (2014), ApJ, 792, 97, Equation 9:
       sigma_corrected = sigma * sqrt(1 + 1/n)

   Args:
       n: Number of reference apertures.

   Returns:
       Correction factor sqrt(1 + 1/n).
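
   To give a feel for the size of the correction, the plain-Python sketch
   below evaluates the factor for a few aperture counts.
   ``small_sample_penalty_sketch`` is a hypothetical stand-in for the library
   function.

   .. code-block:: python

      import math

      def small_sample_penalty_sketch(n):
          return math.sqrt(1.0 + 1.0 / n)

      # The factor shrinks toward 1 as more reference apertures become available.
      penalties = {n: small_sample_penalty_sketch(n) for n in (2, 5, 20, 100)}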


