Source code for coronalyze.core.photometry

"""Convolution-based aperture photometry for coronagraphic images.

Provides circular aperture kernels (hard binary disks, or soft sigmoid-edged
disks for differentiability), convolution-based flux map generation for
efficient photometry calculations, and direct aperture-mask photometry helpers.
"""

import functools

import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


@functools.lru_cache(maxsize=32)
def make_aperture_kernel(
    radius: float,
    soft: bool = False,
    sharpness: float = 10.0,
) -> jnp.ndarray:
    """Build a centred circular aperture kernel for convolution photometry.

    Results are memoised with ``functools.lru_cache`` (up to 32
    configurations), so arguments must be hashable. Repeated calls with the
    same configuration return the very same array object.

    Args:
        radius: Aperture radius in pixels.
        soft: When True, the edge is a sigmoid ramp so the weights vary
            smoothly with distance (differentiable edge). When False, the
            kernel is a binary disk.
        sharpness: Slope of the sigmoid ramp; only consulted when ``soft``
            is True.

    Returns:
        Square 2D kernel with odd side ``2 * half + 1``, where
        ``half = int(2 * radius + 3) // 2``. Weights are not normalised, so
        convolving with the kernel sums counts rather than averaging them.
    """
    # Half-width with a margin past the radius. The side length is always
    # odd, so the kernel has a well-defined central pixel.
    half = int(2 * radius + 3) // 2
    offsets = jnp.arange(-half, half + 1)
    rows = offsets[:, None]
    cols = offsets[None, :]
    dist = jnp.sqrt(cols**2 + rows**2)

    if not soft:
        # Binary disk: 1.0 where the pixel lies within the radius, else 0.0.
        return (dist <= radius).astype(jnp.float32)
    # sigmoid(s * (r - d)) == 1 / (1 + exp(s * (d - r))).
    # The value is > 0.5 inside, exactly 0.5 at d == r, and < 0.5 outside.
    return jax.nn.sigmoid(sharpness * (radius - dist))
@jax.jit
def flux_map(image: jnp.ndarray, kernel: jnp.ndarray) -> jnp.ndarray:
    """Compute the aperture-integrated flux at every pixel position.

    At each output pixel, convolving ``image`` with an aperture kernel gives
    the kernel-weighted sum of the surrounding image values. That is the flux
    an aperture centred on that pixel would collect. ``convolve2d`` flips the
    kernel, so this matches a sliding-aperture sum only for symmetric kernels
    such as those from ``make_aperture_kernel``.

    Non-finite input pixels are not masked and propagate into the output.

    Args:
        image: 2D image array.
        kernel: 2D aperture kernel, typically from ``make_aperture_kernel``.

    Returns:
        2D flux map with the same shape as ``image`` (``mode="same"``).
        Positions beyond the image border contribute zero.
    """
    fluxes = convolve2d(image, kernel, mode="same")
    return fluxes
# ============================================================================== # Aperture Mask Functions (consolidated from top-level photometry.py) # ==============================================================================
def circular_aperture_mask(
    shape: tuple[int, int],
    center: tuple[float, float],
    radius: float,
) -> jnp.ndarray:
    """Return a boolean disk mask over an image grid.

    Args:
        shape: Image shape as (ny, nx).
        center: Aperture centre as (y, x) in pixel index coordinates.
            Sub-pixel values are allowed.
        radius: Aperture radius in pixels.

    Returns:
        Boolean array of ``shape``. An element is True when its (row, col)
        index lies within ``radius`` of ``center``.
    """
    n_rows, n_cols = shape
    row0, col0 = center
    rows, cols = jnp.ogrid[:n_rows, :n_cols]
    dy = rows - row0
    dx = cols - col0
    return jnp.sqrt(dy**2 + dx**2) <= radius
def soft_aperture_mask(
    shape: tuple[int, int],
    center: tuple[float, float],
    radius: float,
    sharpness: float = 10.0,
) -> jnp.ndarray:
    """Return a sigmoid-edged (differentiable) disk mask over an image grid.

    The hard edge of a binary mask is replaced by a sigmoid ramp. The mask
    therefore varies smoothly with ``center`` and ``radius``, which supports
    gradient-based optimisation.

    Args:
        shape: Image shape as (ny, nx).
        center: Aperture centre as (y, x) in pixel index coordinates.
        radius: Aperture radius in pixels.
        sharpness: Slope of the sigmoid transition; larger is closer to a
            hard edge.

    Returns:
        Array of ``shape`` with values in [0, 1]. The value is exactly 0.5 at
        distance == radius.
    """
    n_rows, n_cols = shape
    row0, col0 = center
    rows, cols = jnp.ogrid[:n_rows, :n_cols]
    dy = rows - row0
    dx = cols - col0
    dist = jnp.sqrt(dy**2 + dx**2)
    # Positive argument inside the disk, so weights there exceed 0.5.
    return jax.nn.sigmoid(sharpness * (radius - dist))
@jax.jit
def aperture_photometry(
    image: jnp.ndarray,
    center: tuple[float, float],
    radius: float,
) -> float:
    """Perform hard-edged circular aperture photometry on an image.

    Pixels outside the aperture are excluded with ``jnp.where`` rather than
    multiplied by a 0/1 mask. As a result, NaN or inf values elsewhere in the
    image (e.g. regions blanked as NaN) do not contaminate the sum. Non-finite
    pixels *inside* the aperture still propagate, as they should.

    Args:
        image: 2D image array.
        center: Center of aperture (y, x) in pixels; sub-pixel values allowed.
        radius: Radius of aperture in pixels. A pixel is included when its
            (row, col) index lies within ``radius`` of ``center``.

    Returns:
        Total flux within the aperture, as a scalar JAX array.
    """
    inside = circular_aperture_mask(image.shape, center, radius)
    # where() instead of image * mask: NaN * 0 and inf * 0 both give NaN,
    # which would leak out-of-aperture bad pixels into the total.
    return jnp.sum(jnp.where(inside, image, 0))
def aperture_solid_angle(
    radius_pixels: float,
    pixel_scale_arcsec: float,
) -> float:
    """Return the on-sky area of a circular aperture.

    Args:
        radius_pixels: Aperture radius in pixels.
        pixel_scale_arcsec: Plate scale in arcseconds per pixel.

    Returns:
        Solid angle pi * (radius in arcsec)**2, in arcsec^2.
    """
    return jnp.pi * (radius_pixels * pixel_scale_arcsec) ** 2