Source code for coronalyze.core.geometry

"""Geometric utilities for coronagraphic image analysis.

Functions for coordinate transforms, distance calculations, and generating
aperture positions with static array shapes for JAX compatibility.
"""

import jax.numpy as jnp


[docs] def get_center(shape: tuple[int, int]) -> tuple[float, float]: """Get the center coordinates of an image (0-indexed geometric center). For an image of size N, the geometric center is at (N-1)/2. This matches the convention used in modeling.py for planet injection. Args: shape: Image shape (ny, nx). Returns: Center coordinates (cy, cx). """ return ((shape[0] - 1) / 2.0, (shape[1] - 1) / 2.0)
[docs] def radial_distance( shape: tuple[int, int], center: tuple[float, float] | None = None, ) -> jnp.ndarray: """Calculate radial distance from center for each pixel. Args: shape: Image shape (ny, nx). center: Center coordinates (cy, cx). If None, uses image center. Returns: 2D array of radial distances in pixels. """ ny, nx = shape if center is None: center = get_center(shape) cy, cx = center y, x = jnp.ogrid[:ny, :nx] return jnp.sqrt((y - cy) ** 2 + (x - cx) ** 2)
[docs] def calculate_n_apertures( radius: float, fwhm: float, exclusion_buffer: float = 0.5, ) -> int: """Calculate the number of reference apertures at a given radius. Uses the Mawet et al. (2014) formula with exclusion buffer correction to ensure no overlap with the planet aperture on either side. This function provides the canonical calculation used by the SNR module, ensuring consistency between visualization and computation. Args: radius: Radial distance from center in pixels. fwhm: Full width at half maximum in pixels. exclusion_buffer: Gap between planet and first/last reference aperture in units of angular step (default 0.5). Creates a gap on both sides. Returns: Number of valid reference apertures. Example:: from coronalyze.core.geometry import calculate_n_apertures n = calculate_n_apertures(radius=20, fwhm=5.0) print(f"{n} reference apertures at r=20px") """ half_angle = jnp.arcsin(jnp.minimum(fwhm / 2.0 / jnp.maximum(radius, 0.1), 1.0)) d_theta = 2.0 * half_angle n_theoretical = jnp.floor(2 * jnp.pi / jnp.maximum(d_theta, 0.01)) # Subtract: 1 for planet position + 2*buffer for gap on each side n_actual = max(int(n_theoretical - 1 - 2 * exclusion_buffer), 1) return n_actual
[docs] def generate_aperture_coords( center: tuple[float, float], radius: float, planet_angle: float, n_apertures: int, max_apertures: int = 200, fwhm: float | None = None, exclusion_buffer: float = 0.5, ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Generate coordinates for reference apertures at a given radius. Uses a fixed-size array with masking for JAX compatibility (static shapes). Apertures are distributed evenly around the annulus, excluding the planet position. Matches VIP's clockwise rotation and angle formula from planet position. Args: center: Image center (cy, cx) in pixels. radius: Radial distance from center in pixels. planet_angle: Angle of the planet position in radians. n_apertures: Actual number of valid apertures to use. max_apertures: Maximum buffer size for static array shape. fwhm: Full width half maximum for VIP-style angle calculation. If None, uses uniform distribution. exclusion_buffer: Gap between test and first reference aperture in units of angular step (default 0.0). Prevents PSF wing leakage. Returns: Tuple of (y_coords, x_coords, mask) where: - y_coords: Y coordinates of aperture centers (size max_apertures) - x_coords: X coordinates of aperture centers (size max_apertures) - mask: Boolean mask indicating valid apertures """ cy, cx = center idx_grid = jnp.arange(max_apertures) # Angular step calculation # VIP formula: angle = 2 * arcsin(fwhm/2/radius) if fwhm is not None: half_angle = jnp.arcsin(jnp.minimum(fwhm / 2.0 / jnp.maximum(radius, 0.1), 1.0)) d_theta = 2.0 * half_angle else: # Fallback: uniform distribution d_theta = 2 * jnp.pi / jnp.maximum(n_apertures + 1, 1) # Angles starting from one step past the planet (with optional buffer) # Use NEGATIVE rotation (clockwise) to match VIP's convention angles = planet_angle - (idx_grid + 1 + exclusion_buffer) * d_theta # Compute coordinates y_coords = cy + radius * jnp.sin(angles) x_coords = cx + radius * jnp.cos(angles) # Mask: valid if index < n_apertures mask = idx_grid < n_apertures return y_coords, x_coords, mask