coronalyze.core.snr
===================

.. py:module:: coronalyze.core.snr

.. autoapi-nested-parse::

   SNR calculations for high-contrast coronagraphic imaging.

   This module estimates SNR using the method of Mawet et al. (2014), a widely
   used approach for exoplanet detection in coronagraphic images that corrects
   for the small number of independent resolution elements at small separations.

   Primary API:
       - snr(): Calculate SNR at positions
       - snr_map(): Generate 2D SNR detection map
       - snr_estimator(): Factory for JIT-ready SNREstimator objects

   Classes:
       - SNREstimator: Equinox module for efficient batch SNR computation

   For experimental matched-filter SNR, see coronalyze.core.matched_filter.

   Reference: Mawet et al. (2014), ApJ, 792, 97



Classes
-------

.. autoapisummary::

   coronalyze.core.snr.SNREstimator


Functions
---------

.. autoapisummary::

   coronalyze.core.snr.snr_estimator
   coronalyze.core.snr.snr
   coronalyze.core.snr.snr_map
   coronalyze.core.snr._snr_batch_core
   coronalyze.core.snr.calculate_ccd_snr
   coronalyze.core.snr.exposure_time_for_snr


Module Contents
---------------

.. py:class:: SNREstimator(fwhm, soft = True, sharpness = 10.0, fast = False, max_apertures = 200, exclusion_buffer = 0.5)

   Bases: :py:obj:`equinox.Module`


   Standard SNR estimator implementing Mawet et al. (2014).

   Pre-computes the aperture kernel to allow efficient JIT compilation
   in iterative pipelines. Because this class is an Equinox module, it is a
   PyTree and can be passed into JIT-compiled functions.

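   The SNR includes the small-sample (Student's t) correction of Mawet et al. (2014):
       SNR = (x_planet - x_bg_mean) / (sigma_bg * sqrt(1 + 1/n_bg))

   where x_planet is the flux in the test aperture, and x_bg_mean and sigma_bg
   are the mean and standard deviation of the fluxes in the n_bg reference
   apertures at the same separation.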
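
   The statistic above can be written out directly for a single position once
   the aperture fluxes are known. The plain-Python sketch below is illustrative
   only: it assumes x_planet is the flux summed in the test aperture, x_bg holds
   the fluxes summed in the reference apertures at the same separation, and
   sigma_bg is the sample standard deviation (denominator n_bg - 1). The helper
   name ``mawet_snr`` is hypothetical and is not part of this module.

   .. code-block:: python

      import math
      import statistics


      def mawet_snr(x_planet, x_bg):
          """Hypothetical helper: small-sample SNR for one test aperture.

          Assumes sigma_bg is the sample standard deviation (n_bg - 1 denominator).
          """
          n_bg = len(x_bg)
          if n_bg < 2:
              # The sample standard deviation needs at least two reference apertures.
              raise ValueError("need at least two reference apertures")
          x_bg_mean = statistics.fmean(x_bg)
          sigma_bg = statistics.stdev(x_bg)
          return (x_planet - x_bg_mean) / (sigma_bg * math.sqrt(1.0 + 1.0 / n_bg))


      # Four reference apertures with mean 2.0; the test aperture holds 10.0.
      print(mawet_snr(10.0, [1.0, 2.0, 3.0, 2.0]))

   The sqrt(1 + 1/n_bg) factor matters most at small separations, where only a
   few reference apertures fit on the circle.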

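   Example::

       # High-performance pipeline usage
       import jax
       from coronalyze.core.snr import snr_estimator

       # Build once: the aperture kernel is pre-computed here.
       estimator = snr_estimator(fwhm=4.0, fast=True)

       @jax.jit
       def process_cube(images, positions):
           # images: (T, H, W) cube; positions: (N, 2) array of (y, x)
           return jax.vmap(lambda img: estimator(img, positions))(images)

       # image_cube and planet_positions are user-supplied arrays.
       snrs = process_cube(image_cube, planet_positions)  # shape (T, N)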

   Reference:
       Mawet et al. (2014), ApJ, 792, 97


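   .. py:attribute:: kernel
      :type:  jax.numpy.ndarray

      Pre-computed aperture kernel.


   .. py:attribute:: fwhm
      :type:  float

      Full width at half maximum in pixels.


   .. py:attribute:: exclusion_buffer
      :type:  float

      Gap between the test aperture and the first reference aperture, in
      units of angular step.


   .. py:attribute:: max_apertures
      :type:  int

      Maximum number of apertures in the fixed-size (static-shape) buffer.


   .. py:attribute:: order
      :type:  int

      Interpolation order (1 = bilinear, 3 = cubic).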


   .. py:method:: __call__(image, positions, validity_map = None)

      Calculate SNR for a list of candidate positions.

      Args:
          image: 2D science image.
          positions: (N, 2) array of (y, x) coordinates.
          validity_map: Optional 2D mask (1=valid, 0=invalid). Used to exclude
              known companions, bad pixels, or edge regions. Off-chip apertures
              are automatically excluded via boundary handling.

      Returns:
          (N,) array of SNR values.
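
      A ``validity_map`` is a 2D array of 1s and 0s, presumably of the same
      shape as ``image``. As a minimal sketch, assuming you want to exclude a
      known companion with a hard circular cut, the hypothetical helper below
      builds such a mask in plain Python; it could then be converted to an
      array (for example with ``jax.numpy.asarray``) and passed as
      ``validity_map``. The helper name and the circular shape are
      illustrative choices, not part of this module.

      .. code-block:: python

         def circular_exclusion_mask(shape, center, radius):
             """Hypothetical helper: 1 = valid, 0 = invalid inside a circle.

             shape: (height, width); center: (y, x) of the region to exclude.
             """
             height, width = shape
             cy, cx = center
             mask = []
             for y in range(height):
                 row = []
                 for x in range(width):
                     inside = (y - cy) ** 2 + (x - cx) ** 2 <= radius ** 2
                     row.append(0 if inside else 1)
                 mask.append(row)
             return mask


         # Exclude a companion at (y, x) = (2, 2) with a 1-pixel radius in a 5x5 frame.
         mask = circular_exclusion_mask((5, 5), (2, 2), 1.0)
         for row in mask:
             print(row)

      A radius of roughly one FWHM or more is a natural starting point, so
      that the companion's PSF core stays out of the reference apertures.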



   .. py:method:: map(image, validity_map = None)

      Generate a full SNR detection map for the image.

      This is computationally expensive: the SNR is evaluated at every pixel,
      i.e. at O(N²) positions for an N×N image.

      Args:
          image: 2D science image.
          validity_map: Optional 2D mask (1=valid, 0=invalid).

      Returns:
          2D array of SNR values matching image shape.



.. py:function:: snr_estimator(fwhm, soft = True, sharpness = 10.0, fast = False, max_apertures = 200, exclusion_buffer = 0.5)

   Create a JIT-ready SNR estimator with pre-computed kernel.

   This factory creates an SNREstimator instance that can be used efficiently
   in JIT-compiled pipelines. The aperture kernel is computed once, when the
   estimator is created, instead of on every call.

   Args:
       fwhm: Full width at half maximum in pixels.
       soft: Use soft aperture edges for differentiability.
       sharpness: Sigmoid sharpness for soft apertures.
       fast: Use bilinear interpolation for ~3x speedup.
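       max_apertures: Maximum number of apertures in the fixed-size
           (static-shape) buffer required for JIT compilation.
       exclusion_buffer: Gap between test and first reference aperture in
           units of angular step (default 0.5). Reduces leakage of the test
           source's PSF wings into the reference apertures.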

   Returns:
       SNREstimator instance ready for use.
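
   ``max_apertures`` fixes the size of the aperture buffer so that array shapes
   do not depend on each position's separation, as ``jax.jit`` requires. A
   common way to handle such a padded buffer is to carry a 0/1 weight per slot
   and compute weighted statistics. The plain-Python sketch below assumes that
   pattern purely for illustration; ``masked_mean_std`` and the padding scheme
   are hypothetical and are not taken from this module's internals.

   .. code-block:: python

      import math


      def masked_mean_std(values, weights):
          """Hypothetical helper: mean and sample std over slots with weight 1.

          values, weights: equal-length sequences (length = buffer size). Unused
          slots carry weight 0, so a finite padding value never affects the result.
          """
          n = sum(weights)
          mean = sum(w * v for v, w in zip(values, weights)) / n
          var = sum(w * (v - mean) ** 2 for v, w in zip(values, weights)) / (n - 1)
          return mean, math.sqrt(var), n


      # A buffer of size 6 holding 4 real aperture fluxes and 2 padding slots.
      fluxes = [1.0, 2.0, 3.0, 2.0, 0.0, 0.0]
      used = [1, 1, 1, 1, 0, 0]
      print(masked_mean_std(fluxes, used))

   Under ``jax.jit`` the same arithmetic maps onto whole-buffer array operations
   such as ``jnp.sum(w * x) / jnp.sum(w)``, so no data-dependent shapes arise.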

   Example::

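       import jax
       from coronalyze.core.snr import snr_estimator

       estimator = snr_estimator(fwhm=4.0, fast=True)

       @jax.jit
       def pipeline(images, positions):
           # Map the estimator over the leading (frame) axis of the cube.
           return jax.vmap(lambda img: estimator(img, positions))(images)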


.. py:function:: snr(image, positions, fwhm, soft = True, sharpness = 10.0, fast = False, max_apertures = 200, exclusion_buffer = 0.5, validity_map = None)

   Calculate SNR at specific positions using Mawet et al. (2014).

   This is a convenience wrapper around snr_estimator() for simple use cases.
   For iterative pipelines, use snr_estimator() to avoid repeated kernel creation.

   Args:
       image: 2D science image.
       positions: (N, 2) array of (y, x) coordinates.
       fwhm: Full width at half maximum in pixels.
       soft: Use soft aperture edges.
       sharpness: Sigmoid sharpness for soft apertures.
       fast: Use bilinear interpolation for ~3x speedup.
       max_apertures: Maximum number of apertures in the fixed-size
           (static-shape) buffer.
       exclusion_buffer: Gap between test and first reference aperture in
           units of angular step (default 0.5). Reduces leakage of the test
           source's PSF wings into the reference apertures.
       validity_map: Optional 2D mask (1=valid, 0=invalid) to exclude
           known companions, bad pixels, or edge regions.

   Returns:
       (N,) array of SNR values.

   Reference:
       Mawet et al. (2014), ApJ, 792, 97
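
   How many reference apertures a position gets, and therefore how large the
   ``max_apertures`` buffer must be, depends on its separation. The sketch below
   follows the resolution-element counting of Mawet et al. (2014), approximating
   the number of FWHM-sized apertures on a circle of radius r (pixels) as
   floor(2πr / FWHM) and removing the test aperture. It ignores
   ``exclusion_buffer``, which can only lower the count. The helper names are
   hypothetical, and the exact placement used by this module may differ.

   .. code-block:: python

      import math


      def reference_aperture_count(separation, fwhm):
          """Hypothetical helper: approximate Mawet-style reference aperture count.

          Assumes non-overlapping apertures of diameter `fwhm` around the circle of
          radius `separation` (pixels), minus the one used by the test aperture.
          """
          n_total = math.floor(2.0 * math.pi * separation / fwhm)
          return max(n_total - 1, 0)


      def small_sample_factor(n_bg):
          """The sqrt(1 + 1/n_bg) term from the SNR denominator (needs n_bg >= 1)."""
          return math.sqrt(1.0 + 1.0 / n_bg)


      for r in (4.0, 8.0, 40.0, 130.0):
          n_bg = reference_aperture_count(r, fwhm=4.0)
          print(r, n_bg, round(small_sample_factor(n_bg), 3))

   Under this approximation, with ``fwhm=4.0`` positions beyond roughly 129 px
   would need more than the default 200 apertures, so wide fields may call for
   a larger ``max_apertures``.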


.. py:function:: snr_map(image, fwhm, soft = True, sharpness = 10.0, fast = False, max_apertures = 200, exclusion_buffer = 0.5)

   Generate a 2D map of SNR values using Mawet et al. (2014).

   Computes the SNR at every pixel position, i.e. at O(N²) positions for an
   N×N image, which makes it computationally expensive.

   Args:
       image: 2D science image.
       fwhm: Full width at half maximum in pixels.
       soft: Use soft aperture edges.
       sharpness: Sigmoid sharpness.
       fast: Use bilinear interpolation for speedup.
       max_apertures: Maximum number of apertures in the fixed-size
           (static-shape) buffer.
       exclusion_buffer: Gap between test and first reference aperture in
           units of angular step (default 0.5). Reduces leakage of the test
           source's PSF wings into the reference apertures.

   Returns:
       2D array of SNR values, same shape as input image.
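
   A detection map is usually thresholded to list candidates. The plain-Python
   sketch below is a minimal illustration, assuming the map has been converted
   to a nested list (for example with ``.tolist()``) and using a threshold of 5,
   a common but arbitrary choice. ``candidates_above`` is hypothetical and not
   part of this module. As Mawet et al. (2014) point out, the same threshold
   corresponds to a higher false-positive probability at small separations.

   .. code-block:: python

      def candidates_above(snr_values, threshold=5.0):
          """Hypothetical helper: (y, x, snr) for pixels at or above `threshold`.

          snr_values: 2D nested list. NaN entries compare False and are skipped.
          Results are sorted from highest to lowest SNR.
          """
          hits = [
              (y, x, value)
              for y, row in enumerate(snr_values)
              for x, value in enumerate(row)
              if value >= threshold
          ]
          return sorted(hits, key=lambda hit: hit[2], reverse=True)


      toy_map = [
          [0.1, 0.4, 0.2],
          [0.3, 6.2, 5.1],
          [0.0, 0.5, 0.2],
      ]
      print(candidates_above(toy_map))

   Neighbouring pixels of a single source often all exceed the threshold, so in
   practice one would also keep only local maxima or merge hits closer than
   about one FWHM.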


.. py:function:: _snr_batch_core(image, positions, kernel, fwhm, max_apertures, order, exclusion_buffer = 0.5, validity_map = None)

   JIT-compiled batch SNR calculation (Mawet method).

   Args:
       image: 2D science image.
       positions: (N, 2) array of (y, x) coordinates.
       kernel: Pre-computed aperture kernel.
       fwhm: Full width at half maximum in pixels.
       max_apertures: Maximum number of apertures in the fixed-size
           (static-shape) buffer.
       order: Interpolation order (1=bilinear, 3=cubic).
       exclusion_buffer: Gap between test and first reference aperture in
           units of angular step.
       validity_map: Optional 2D mask (1=valid, 0=invalid). Off-chip
           locations are assigned 0 automatically via constant-value (cval)
           boundary handling.
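
   The ``order`` and ``cval`` behaviour described above can be illustrated with
   a plain-Python bilinear sampler (order 1) that substitutes a constant value
   for neighbours falling off the chip. Sampling a validity map this way gives
   0 outside the image, so off-chip apertures carry no weight. This is a sketch
   assuming (y, x) pixel-center coordinates and ``cval=0.0``; the name
   ``bilinear_sample`` is hypothetical, and this module's sampler (including
   its cubic path) may differ.

   .. code-block:: python

      import math


      def bilinear_sample(image, y, x, cval=0.0):
          """Hypothetical helper: order-1 interpolation with constant boundary.

          image: 2D nested list indexed as image[row][col]; (y, x) are pixel
          coordinates with integer values at pixel centers.
          """
          height, width = len(image), len(image[0])

          def pixel(r, c):
              # Neighbours outside the image take the constant value `cval`.
              if 0 <= r < height and 0 <= c < width:
                  return image[r][c]
              return cval

          y0, x0 = math.floor(y), math.floor(x)
          dy, dx = y - y0, x - x0
          return (
              (1 - dy) * (1 - dx) * pixel(y0, x0)
              + (1 - dy) * dx * pixel(y0, x0 + 1)
              + dy * (1 - dx) * pixel(y0 + 1, x0)
              + dy * dx * pixel(y0 + 1, x0 + 1)
          )


      valid = [[1.0, 1.0], [1.0, 1.0]]
      print(bilinear_sample(valid, 0.5, 0.5))  # all four neighbours on chip
      print(bilinear_sample(valid, 1.5, 0.5))  # bottom two neighbours off chip
      print(bilinear_sample(valid, 5.0, 5.0))  # entirely off chip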


.. py:function:: calculate_ccd_snr(signal, background_noise, read_noise = 0.0, dark_current = 0.0)

   Calculate signal-to-noise ratio using the CCD equation.

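   Uses the standard CCD equation for SNR:
       SNR = S / sqrt(S + B + R^2 + D)

   where:

   - S = signal (electrons)
   - B = background (electrons), entering as a Poisson variance term
   - R = read noise (electrons RMS, combined over the aperture)
   - D = dark current (electrons, summed over the aperture)

   Unlike the Mawet SNR, which is estimated empirically from spatial aperture
   statistics, this is an analytic noise-budget calculation.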

   Args:
       signal: Source signal in electrons.
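       background_noise: Background in electrons (sky, zodiacal light, etc.),
           added directly to the variance as Poisson-distributed counts.
       read_noise: Read noise in electrons RMS, combined in quadrature over the
           aperture; it enters the variance as read_noise**2.
       dark_current: Dark current in electrons, summed over the aperture.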

   Returns:
       Signal-to-noise ratio.
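
   The equation above amounts to a one-line calculation. The plain-Python sketch
   below mirrors it term by term under the same conventions (all quantities in
   electrons, read noise already combined over the aperture); ``ccd_snr`` is a
   hypothetical stand-in, not this function's implementation.

   .. code-block:: python

      import math


      def ccd_snr(signal, background, read_noise=0.0, dark=0.0):
          """Hypothetical sketch of SNR = S / sqrt(S + B + R^2 + D), all in electrons."""
          variance = signal + background + read_noise ** 2 + dark
          return signal / math.sqrt(variance)


      # 900 e- of signal on a 100 e- background with 10 e- read noise and no dark.
      print(ccd_snr(900.0, 100.0, read_noise=10.0))

   In the photon-limited case (B = R = D = 0) this reduces to sqrt(S).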


.. py:function:: exposure_time_for_snr(target_snr, signal_rate, background_rate, read_noise = 0.0, dark_rate = 0.0)

   Calculate exposure time needed to achieve a target SNR.

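   Solves the CCD equation for exposure time t:
       SNR = S*t / sqrt(S*t + B*t + R^2 + D*t)

   where S, B and D are the signal, background and dark rates
   (electrons/second) and R is the read noise (electrons). Squaring both
   sides gives a quadratic in t, whose positive root is the required
   exposure time.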

   Args:
       target_snr: Desired signal-to-noise ratio.
       signal_rate: Source signal rate in electrons/second.
       background_rate: Background rate in electrons/second.
       read_noise: Read noise in electrons RMS (independent of exposure time).
       dark_rate: Dark current rate in electrons/second.

   Returns:
       Required exposure time in seconds.
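
   Squaring the CCD equation gives S^2 t^2 - SNR^2 (S + B + D) t - SNR^2 R^2 = 0.
   Because the constant term is non-positive, the "+" root of the quadratic
   formula is the non-negative solution. The plain-Python sketch below takes
   that root and checks it by substituting back into the CCD equation; the name
   ``exposure_time`` is hypothetical, and the sketch assumes a positive
   ``signal_rate``.

   .. code-block:: python

      import math


      def exposure_time(target_snr, signal_rate, background_rate,
                        read_noise=0.0, dark_rate=0.0):
          """Hypothetical sketch: non-negative root of the squared CCD equation."""
          a = signal_rate ** 2
          b = -(target_snr ** 2) * (signal_rate + background_rate + dark_rate)
          c = -(target_snr ** 2) * read_noise ** 2
          # c <= 0, so the discriminant is non-negative and the '+' root is >= 0.
          return (-b + math.sqrt(b * b - 4.0 * a * c)) / (2.0 * a)


      t = exposure_time(10.0, signal_rate=5.0, background_rate=20.0, read_noise=3.0)
      # Substituting t back into the CCD equation recovers the target SNR.
      snr_check = 5.0 * t / math.sqrt((5.0 + 20.0) * t + 3.0 ** 2)
      print(t, snr_check)

   With no background, dark current or read noise the root reduces to
   t = SNR^2 / S, the photon-limited exposure time.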


