Source code for coronalyze.pipelines.yield_pipeline

"""Yield simulation pipelines for fast SNR calculation.

Provides high-level workflows for yield estimation:

- calculate_yield_snr: End-to-end subtraction + SNR calculation
- klip_subtract: PCA/KLIP PSF subtraction

For subtraction primitives, see coronalyze.core.modeling:
- subtract_star, subtract_disk

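klip_subtract is JIT-compiled with jax.jit. calculate_yield_snr is a plain Python wrapper: it dispatches on `method` and on which optional models are given, then calls the JAX subtraction and SNR routines. The numerical work is built from JAX operations and is intended to be differentiable.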
"""

import functools

import jax
import jax.numpy as jnp

from coronalyze.core.modeling import subtract_disk, subtract_star
from coronalyze.core.pca import get_pca_basis, pca_subtract
from coronalyze.core.snr import snr


@functools.partial(jax.jit, static_argnames=("n_modes",))
def klip_subtract(
    science: jnp.ndarray,
    reference_cube: jnp.ndarray,
    n_modes: int = 5,
) -> jnp.ndarray:
    """Subtract a PCA (KLIP) model of the stellar PSF from a science frame.

    A low-rank PSF model is learned from ``reference_cube`` via
    ``get_pca_basis`` and removed from ``science`` via ``pca_subtract``.
    Of the available stellar subtraction modes this is the most physically
    realistic, but also the slowest.

    ``n_modes`` is a static argument to ``jax.jit``, so each distinct value
    triggers a separate compilation.

    Args:
        science: Science observation (electrons).
        reference_cube: Reference library of shape (n_frames, ny, nx),
            typically star-only images taken at different roll angles.
        n_modes: How many PCA modes the stellar model keeps.

    Returns:
        The residual image left after KLIP subtraction.

    Example::

        # ADI with KLIP
        residual = klip_subtract(science, roll_cube, n_modes=10)

        # Disk subtraction, if needed, is applied afterwards:
        residual = subtract_disk(residual, disk_model)
    """
    # Build the truncated PCA basis plus the mean reference frame.
    pca_basis, reference_mean = get_pca_basis(reference_cube, n_modes)
    residual = pca_subtract(science, pca_basis, reference_mean)
    return residual
[docs] def calculate_yield_snr( science: jnp.ndarray, planet_positions: jnp.ndarray, fwhm: float, star_model: jnp.ndarray = None, disk_model: jnp.ndarray = None, reference_cube: jnp.ndarray = None, n_modes: int = 5, method: str = "star", star_scale: float = 1.0, disk_scale: float = 1.0, exclusion_buffer: float = 0.5, validity_map: jnp.ndarray = None, ) -> jnp.ndarray: """End-to-end yield SNR calculation. Convenience function that performs subtraction and SNR calculation in one call. Selects the appropriate subtraction method based on the `method` argument. Args: science: Observed image (electrons). planet_positions: Planet positions as (N, 2) array of (y, x) coords. fwhm: PSF FWHM in pixels. star_model: Noiseless star expectation (required for 'star' method). disk_model: Optional noiseless disk expectation. reference_cube: Reference library for KLIP (required for 'klip'/'rdi'). n_modes: Number of PCA modes (for klip method only). method: Subtraction method - "star", "rdi", or "klip". star_scale: Scaling factor for star model (default 1.0). disk_scale: Scaling factor for disk model (default 1.0). exclusion_buffer: Gap between test and reference apertures in units of angular step (default 0.5). Prevents PSF wing leakage. validity_map: Optional 2D mask (1=valid, 0=invalid) to exclude known companions, bad pixels, or edge regions. Returns: SNR values for each planet position. 
Example:: # Fast yield calculation with static PSF snrs = calculate_yield_snr( science_image, planet_positions, fwhm=4.5, star_model=star_expectation, disk_model=disk_expectation, method="star" ) """ if method == "star": if star_model is None: raise ValueError("star_model required for 'star' method") residual = subtract_star(science, star_model, star_scale) elif method == "rdi": if reference_cube is None: raise ValueError("reference_cube required for 'rdi' method") # Handle both 2D (single reference) and 3D (cube) inputs if reference_cube.ndim == 2: reference = reference_cube else: reference = reference_cube[0] # Use subtract_star - it's the same operation (science - scale * ref) residual = subtract_star(science, reference, star_scale) elif method == "klip": if reference_cube is None: raise ValueError("reference_cube required for 'klip' method") residual = klip_subtract(science, reference_cube, n_modes) else: raise ValueError(f"Unknown method: {method}. Use 'star', 'rdi', or 'klip'") # Apply disk subtraction if provided if disk_model is not None: residual = subtract_disk(residual, disk_model, disk_scale) return snr( residual, planet_positions, fwhm, exclusion_buffer=exclusion_buffer, validity_map=validity_map, )