Source code for coronalyze.core.modeling

"""Forward modeling tools for injecting synthetic sources.

Provides utilities for fake planet injection and simple disk generation
for throughput estimation and contrast curve calculations.
"""

import functools

import jax
import jax.numpy as jnp
from hwoutils.transforms import shift_image


[docs] @jax.jit def inject_planet( image: jnp.ndarray, psf_template: jnp.ndarray, flux: float, pos: tuple[float, float], order: int = 3, ) -> jnp.ndarray: """Inject a fake planet into an image using cubic spline shifts. The PSF template is shifted from the image center to the target position with sub-pixel precision, scaled by the flux, and added to the image. Args: image: 2D image to inject into, shape (ny, nx). psf_template: 2D PSF template centered at image center, same shape as image. flux: Flux scaling factor for the injected planet. pos: Target position (y, x) in pixels. order: Interpolation order for sub-pixel shifting (default: 3 = cubic). Returns: Image with injected planet, same shape as input. """ ny, nx = image.shape # Use (N-1)/2 for correct 0-indexed geometric center cy, cx = (ny - 1) / 2.0, (nx - 1) / 2.0 # Shift from center to target position dy = pos[0] - cy dx = pos[1] - cx planet_signal = shift_image(psf_template, dy, dx, order=order) return image + planet_signal * flux
[docs] @functools.partial(jax.jit, static_argnames=["shape"]) def make_simple_disk( shape: tuple[int, int], radius: float, inclination_deg: float, width: float, flux: float = 1.0, pa_deg: float = 0.0, ) -> jnp.ndarray: """Generate a simple, optically thin Gaussian ring/disk. Analytically projects the disk to avoid interpolation artifacts. Flux is normalized so the total integrated flux equals the specified value. Args: shape: Output image shape (ny, nx). radius: Ring radius in pixels. inclination_deg: Disk inclination (0 = face-on, 90 = edge-on). width: Gaussian width (sigma) of the ring in pixels. flux: Total integrated flux of the disk. Default 1.0. pa_deg: Position angle of major axis, measured East of North (degrees). Returns: 2D disk image with total flux normalized to the specified value. """ ny, nx = shape # Use (N-1)/2 for correct 0-indexed geometric center cy, cx = (ny - 1) / 2.0, (nx - 1) / 2.0 y, x = jnp.mgrid[:ny, :nx] # Shift to center y_c = y - cy x_c = x - cx # Rotate by position angle # Add 90 degrees so PA=0 aligns with North (Y-axis) per astronomical convention pa_rad = jnp.deg2rad(pa_deg + 90.0) cos_pa = jnp.cos(pa_rad) sin_pa = jnp.sin(pa_rad) y_rot = y_c * cos_pa + x_c * sin_pa x_rot = -y_c * sin_pa + x_c * cos_pa # De-project inclination (stretch along minor axis) inc_rad = jnp.deg2rad(inclination_deg) cos_inc = jnp.cos(inc_rad) # Avoid division by zero for edge-on disks cos_inc = jnp.maximum(cos_inc, 1e-6) y_deproj = y_rot / cos_inc # Radial distance in disk plane r_deproj = jnp.sqrt(y_deproj**2 + x_rot**2) # Gaussian profile profile = jnp.exp(-0.5 * ((r_deproj - radius) / width) ** 2) # Normalize to specified flux total = jnp.sum(profile) return jnp.where(total > 0, profile / total * flux, profile)
# ============================================================================= # Model Subtraction (inverse of injection) # =============================================================================
[docs] @jax.jit def subtract_star( science: jnp.ndarray, star_model: jnp.ndarray, scale: float = 1.0, ) -> jnp.ndarray: """Subtract stellar PSF model from observation. This is the fundamental operation for "perfect" RDI when you have a noiseless stellar PSF expectation (e.g., from coronagraphoto). Args: science: Observed image (electrons). star_model: Noiseless stellar PSF expectation (electrons). scale: Multiplicative scaling factor for the model before subtraction. Use values != 1.0 when the reference brightness differs from the science image (e.g., different exposure times or stellar flux). Returns: Residual image containing noise + planet signal. Example:: residual = subtract_star(observation, star_expectation) # With scaling for brightness mismatch: residual = subtract_star(observation, star_expectation, scale=0.95) """ return science - scale * star_model
[docs] @jax.jit def subtract_disk( residual: jnp.ndarray, disk_model: jnp.ndarray, scale: float = 1.0, ) -> jnp.ndarray: """Subtract disk model from residual image. Disk subtraction is typically a separate modeling task from stellar speckle subtraction. Call this after subtract_star when analyzing systems with circumstellar disks. Args: residual: Image after stellar subtraction (from subtract_star). disk_model: Disk model expectation (electrons). scale: Multiplicative scaling factor for the disk model. Adjust when disk model brightness doesn't match observation. Returns: Residual image with disk contribution removed. Example:: # Two-step subtraction residual = subtract_star(observation, star_model) residual = subtract_disk(residual, disk_model) """ return residual - scale * disk_model