Source code for jax_healpy._query_disc

# This file is part of jax-healpy.
# Copyright (C) 2024 CNRS / SciPol developers
#
# jax-healpy is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# jax-healpy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with jax-healpy. If not, see <https://www.gnu.org/licenses/>.

"""
Query disc implementation for HEALPix spherical disc queries.
"""

import math
from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp
from jax import jit, lax
from jaxtyping import Array, ArrayLike

from .pixelfunc import _get_ring_info, _ring_above, nside2resol, pix2vec

# Names exported by `from ... import *`. Besides the three public functions the
# list also includes the underscore-prefixed `_query_disc_bruteforce`.
__all__ = ['query_disc', 'estimate_disc_pixel_count', 'estimate_disc_radius', '_query_disc_bruteforce']


def _ring2z(nside: int, ring_idx: ArrayLike) -> Array:
    """Convert ring index to z-coordinate following HEALPix C++ implementation.

    This follows the exact logic from the C++ healpix_base.cc ring2z function.

    Parameters
    ----------
    nside : int
        HEALPix nside parameter
    ring_idx : ArrayLike
        Ring index (1 to 4*nside-1)

    Returns
    -------
    z : Array
        z-coordinate (cos(theta)) for the ring
    """
    ring = jnp.asarray(ring_idx, dtype=jnp.int32)

    # Convert to northern hemisphere ring number
    northring = jnp.where(ring > 2 * nside, 4 * nside - ring, ring)

    # HEALPix constants
    fact2 = 4.0 / (12.0 * nside * nside)  # 4/npix

    # Polar cap region (northring < nside)
    polar_z = 1.0 - (northring * northring) * fact2

    # Equatorial region (northring >= nside)
    equatorial_z = (2.0 * nside - northring) * 2.0 / (3.0 * nside)

    # Select based on region
    z = jnp.where(northring < nside, polar_z, equatorial_z)

    # Handle southern hemisphere (original ring > 2*nside)
    z = jnp.where(ring > 2 * nside, -z, z)

    return z


def _max_pixrad(nside: int) -> float:
    """Calculate maximum pixel radius for the given nside.

    This approximates the maximum angular distance from a pixel center
    to any point within the pixel. Based on HEALPix C++ max_pixrad.

    Parameters
    ----------
    nside : int
        HEALPix nside parameter

    Returns
    -------
    max_radius : float
        Maximum pixel radius in radians
    """
    # Approximate maximum pixel radius
    # This is a conservative estimate - in reality varies by pixel location
    resol = jnp.sqrt(4.0 * jnp.pi / (12.0 * nside * nside))
    return resol * 0.6  # Conservative factor from HEALPix


def estimate_disc_pixel_count(nside: int, radius: float) -> int:
    """Estimate number of pixels in a disc of given radius.

    Uses the analytical approximation:
    n_approx = 6 * nside^2 * (1 - cos(radius))

    The expression is evaluated in the algebraically identical form
    ``12 * nside^2 * sin(radius / 2)^2`` using Python double precision.
    Evaluating ``1 - cos(radius)`` directly (in particular in single
    precision) cancels catastrophically for small radii and can return 0 for
    a disc that actually contains pixels.

    It provides an estimate for setting the max_length parameter in
    query_disc functions.

    Parameters
    ----------
    nside : int
        HEALPix nside parameter
    radius : float
        Disc radius in radians; values outside [0, pi] are clipped. Must be
        a concrete scalar (not a traced JAX value) because the result is a
        Python int.

    Returns
    -------
    pixel_count : int
        Estimated number of pixels in the disc (``12 * nside**2`` when
        ``radius >= pi``)

    Notes
    -----
    The approximation assumes uniform pixel density across the sphere, which
    is accurate for large discs but less precise for small discs where pixel
    shape variations matter more.

    Examples
    --------
    >>> import jax_healpy as hp
    >>> nside = 64
    >>> radius = 0.1  # ~5.7 degrees
    >>> estimated_count = hp.estimate_disc_pixel_count(nside, radius)
    >>> # Use as max_length for query_disc
    >>> actual_disc = hp.query_disc(nside, [1, 0, 0], radius, max_length=estimated_count)
    """
    npix_total = 12 * nside**2

    # Clip radius to the valid range [0, pi]
    radius = min(max(float(radius), 0.0), math.pi)

    # Full sphere: return exactly npix
    if radius >= math.pi:
        return npix_total

    # 1 - cos(r) == 2 * sin(r / 2)^2, which stays accurate for tiny r
    half_sine = math.sin(0.5 * radius)
    pixel_count = npix_total * half_sine * half_sine

    # Rounding up must never report more pixels than exist on the sphere.
    return min(npix_total, math.ceil(pixel_count))
def estimate_disc_radius(nside: int, pixel_count: int) -> float:
    """Estimate radius needed for a disc containing given pixel count.

    Inverse of estimate_disc_pixel_count. Uses the analytical relationship:
    radius = arccos(1 - pixel_count / (6 * nside^2))

    The relationship is evaluated in the equivalent half-angle form
    ``2 * arcsin(sqrt(pixel_count / npix))`` using Python double precision,
    so that small pixel counts at large nside do not round to a zero radius.

    Parameters
    ----------
    nside : int
        HEALPix nside parameter
    pixel_count : int
        Desired number of pixels in the disc; must be a concrete value

    Returns
    -------
    radius : float
        Estimated disc radius in radians, as a Python float. Returns ``pi``
        when ``pixel_count >= 12 * nside**2`` and ``0.0`` when
        ``pixel_count <= 0``.

    Notes
    -----
    This is the inverse of estimate_disc_pixel_count and has the same
    accuracy characteristics: exact at full sphere, less precise for small
    discs.

    Examples
    --------
    >>> import jax_healpy as hp
    >>> nside = 64
    >>> target_pixels = 1000
    >>> estimated_radius = hp.estimate_disc_radius(nside, target_pixels)
    >>> # Verify the relationship
    >>> back_pixels = hp.estimate_disc_pixel_count(nside, estimated_radius)
    >>> abs(back_pixels - target_pixels) < 10  # Should be close
    True
    """
    npix_total = 12 * nside**2

    if pixel_count >= npix_total:
        return math.pi  # Full sphere

    if pixel_count <= 0:
        return 0.0

    # 1 - cos(r) = p / (6 nside^2)  <=>  sin(r / 2)^2 = p / npix
    covered_fraction = float(pixel_count) / npix_total
    return 2.0 * math.asin(math.sqrt(covered_fraction))
def _query_disc_ring_single(
    nside: int,
    vec: Array,
    radius: float,
    inclusive: bool = False,
    fact: int = 4,
    max_length: Optional[int] = None,
) -> Array:
    """Geometric single-disc query for the RING scheme.

    The disc is intersected ring by ring, following the structure of the
    HEALPix C++ ``query_disc`` routine: for every candidate ring the
    longitude half-width ``dphi`` of the disc on that ring is computed and
    converted into a range of in-ring pixel indices.

    Algorithm Steps:
    1. Normalize the centre vector, clip the radius, derive ``rsmall``/``rbig``
    2. Compute the centre colatitude/longitude and the candidate ring range
    3. Add whole polar caps if the pole test of the C++ code triggers
    4. Loop over all rings (static bounds) and OR each ring's pixel mask
    5. Convert the boolean mask into a fixed-size, sentinel-padded index array

    Parameters
    ----------
    nside : int
        HEALPix nside parameter
    vec : Array
        Vector (3,) defining the disc centre. It is normalized here; a vector
        with norm <= 1e-10 is replaced by (1, 0, 0).
    radius : float
        Disc radius in radians, clipped to [0, pi]
    inclusive : bool, optional
        If True, enlarge the radius by the pixel-radius margin from
        ``_max_pixrad`` and widen the ring range by one ring on each side
        (default: False)
    fact : int, optional
        Oversampling factor used for the finer margin in inclusive mode
        (default: 4)
    max_length : int, optional
        Length of the returned array. If None, defaults to npix.

    Returns
    -------
    pixels : Array
        int32 array of shape (max_length,). Selected pixel indices come first
        in ascending order; unused entries hold the sentinel value npix. If
        more than max_length pixels are selected, only the first max_length
        are returned.

    Notes
    -----
    Every ring iteration builds a boolean mask of length npix, so the work is
    proportional to ``4 * nside * npix`` per disc.
    """
    npix = 12 * nside * nside
    if max_length is None:
        max_length = npix

    two_pi = 2.0 * jnp.pi
    inv_twopi = 1.0 / two_pi

    # Step 1: setup and geometric bounds
    vec_norm = jnp.linalg.norm(vec)
    safe_vec = jnp.where(vec_norm > 1e-10, vec / vec_norm, jnp.array([1.0, 0.0, 0.0]))

    radius = jnp.clip(radius, 0.0, jnp.pi)

    if inclusive:
        # Margins from the finer (fact * nside) and the original grid
        rsmall = radius + _max_pixrad(fact * nside)
        rbig = radius + _max_pixrad(nside)
    else:
        rsmall = rbig = radius

    full_sphere = rsmall >= jnp.pi
    rbig = jnp.minimum(jnp.pi, rbig)
    cosrbig = jnp.cos(rbig)

    # Step 2: disc centre coordinates and candidate ring range
    z0 = safe_vec[2]  # cos(theta)
    phi0 = jnp.arctan2(safe_vec[1], safe_vec[0])
    # arctan2 yields (-pi, pi], but the in-ring index arithmetic below only
    # handles a single wrap and therefore needs the longitude in [0, 2*pi].
    # Without this shift a centre with negative longitude selects the whole
    # arc from its lower edge to the end of the ring.
    phi0 = jnp.where(phi0 < 0.0, phi0 + two_pi, phi0)

    # 1 / sin(theta), with a large finite stand-in at the poles
    sin_theta_sq = (1.0 - z0) * (1.0 + z0)
    xa = jnp.where(sin_theta_sq > 1e-10, 1.0 / jnp.sqrt(sin_theta_sq), 1e10)

    theta0 = jnp.arccos(jnp.clip(z0, -1.0, 1.0))
    rlat1 = theta0 - rsmall
    rlat2 = theta0 + rsmall

    zmax = jnp.cos(jnp.maximum(0.0, rlat1))
    irmin = _ring_above(nside, zmax) + 1
    zmin = jnp.cos(jnp.minimum(jnp.pi, rlat2))
    irmax = _ring_above(nside, zmin)

    # Inclusive mode: widen the ring range by one ring where possible
    irmin = jnp.where(inclusive & (rlat1 > 0), jnp.maximum(1, irmin - 1), irmin)
    irmax = jnp.where(inclusive & (rlat2 < jnp.pi), jnp.minimum(4 * nside - 1, irmax + 1), irmax)

    # Pole tests as in the C++ code
    north_pole_in_disc = (rlat1 <= 0.0) & (irmin > 1)
    south_pole_in_disc = (rlat2 >= jnp.pi) & (irmax + 1 < 4 * nside)

    # Loop-invariant index vector shared by all mask constructions
    pixel_indices = jnp.arange(npix)
    empty_mask = jnp.zeros(npix, dtype=bool)

    # Step 3: whole polar caps
    def north_cap_mask():
        # All pixels of rings 1 .. max(1, irmin - 1)
        boundary_ring = jnp.maximum(1, irmin - 1)
        ring_info = _get_ring_info(nside, boundary_ring)
        north_cap_pixels = ring_info[1] + ring_info[2]  # startpix + ringpix
        return pixel_indices < north_cap_pixels

    def south_cap_mask():
        # All pixels from the first pixel of ring irmax + 1 to npix - 1
        ring_info = _get_ring_info(nside, irmax + 1)
        south_start_pixel = ring_info[1]  # startpix
        return pixel_indices >= south_start_pixel

    north_mask = lax.cond(north_pole_in_disc, north_cap_mask, lambda: empty_mask)
    south_mask = lax.cond(south_pole_in_disc, south_cap_mask, lambda: empty_mask)
    pixel_mask = north_mask | south_mask

    # Step 4: per-ring intersection
    def process_ring_iteration(ring_idx, mask):
        """OR the pixels of ring ``ring_idx`` that lie in the disc into ``mask``."""

        def process_valid_ring():
            z = _ring2z(nside, ring_idx)
            ring_info = _get_ring_info(nside, ring_idx)
            ipix1 = ring_info[1]  # first pixel index of the ring
            nr = ring_info[2]  # number of pixels in the ring
            shifted = ring_info[3]  # whether pixel centres are offset by half a pixel

            x = (cosrbig - z * z0) * xa
            ysq = 1.0 - z * z - x * x

            def ring_fully_in_or_out():
                # ysq <= 0: the disc boundary does not cross this ring, so
                # testing one representative point (phi = 0) decides whether
                # the whole ring is inside or outside.
                ring_phi = 0.0
                ring_vec = jnp.array(
                    [
                        jnp.sqrt(1 - z * z) * jnp.cos(ring_phi),
                        jnp.sqrt(1 - z * z) * jnp.sin(ring_phi),
                        z,
                    ]
                )
                ring_inside_disc = jnp.dot(ring_vec, safe_vec) >= cosrbig
                return jnp.where(ring_inside_disc, jnp.pi - 1e-15, 0.0)

            def ring_crosses_boundary():
                # Half-width in longitude of the disc on this ring
                return jnp.arctan2(jnp.sqrt(ysq), x)

            dphi = lax.cond(ysq <= 0, ring_fully_in_or_out, ring_crosses_boundary)

            # Longitude interval -> in-ring pixel index interval
            shift = jnp.where(shifted, 0.5, 0.0)
            ip_lo = jnp.floor(nr * inv_twopi * (phi0 - dphi) - shift).astype(jnp.int32) + 1
            ip_hi = jnp.floor(nr * inv_twopi * (phi0 + dphi) - shift).astype(jnp.int32)

            # Full circle: make sure the interval spans all nr pixels
            # (C++: if (ip_hi-ip_lo<nr-1) { if (ip_lo>0) --ip_lo; else ++ip_hi; })
            fullcircle = dphi >= (jnp.pi - 1e-10)
            needs_expansion = fullcircle & ((ip_hi - ip_lo) < (nr - 1))
            grow_low = needs_expansion & (ip_lo > 0)
            grow_high = needs_expansion & (ip_lo <= 0)
            ip_lo, ip_hi = (
                jnp.where(grow_low, ip_lo - 1, ip_lo),
                jnp.where(grow_high, ip_hi + 1, ip_hi),
            )

            # dphi <= 0 means no pixels; ip_lo > ip_hi means the arc contains
            # no pixel centre (the C++ code skips the ring in that case).
            nonempty = (dphi > 0) & (ip_lo <= ip_hi)

            # Bring an interval that overshoots the ring end back by one turn
            wraps_high = ip_hi >= nr
            lo = jnp.where(wraps_high, ip_lo - nr, ip_lo)
            hi = jnp.where(wraps_high, ip_hi - nr, ip_hi)

            in_ring = (pixel_indices >= ipix1) & (pixel_indices < ipix1 + nr)
            # lo >= 0: one contiguous segment [lo, hi]
            single = (pixel_indices >= ipix1 + lo) & (pixel_indices < ipix1 + hi + 1)
            # lo < 0: two segments [0, hi] and [lo + nr, nr - 1]
            split = (pixel_indices < ipix1 + hi + 1) | (pixel_indices >= ipix1 + lo + nr)

            # in_ring also guarantees a ring never marks pixels of a neighbour ring
            ring_mask = in_ring & jnp.where(lo < 0, split, single) & nonempty
            return mask | ring_mask

        # The loop bounds already guarantee 1 <= ring_idx < 4 * nside
        in_range = (ring_idx >= irmin) & (ring_idx <= irmax)
        return lax.cond(in_range, process_valid_ring, lambda: mask)

    pixel_mask = lax.fori_loop(1, 4 * nside, process_ring_iteration, pixel_mask)

    # Step 5: fixed-size output
    def get_all_pixels():
        # Entries past npix - 1 (possible when max_length > npix) must carry
        # the sentinel rather than out-of-range pixel numbers.
        return jnp.minimum(jnp.arange(max_length, dtype=jnp.int32), npix)

    def get_geometric_pixels():
        # First max_length set indices in ascending order, padded with npix
        selected = jnp.nonzero(pixel_mask, size=max_length, fill_value=npix)[0]
        return selected.astype(jnp.int32)

    return lax.cond(full_sphere, get_all_pixels, get_geometric_pixels)
@partial(jit, static_argnames=['nside', 'inclusive', 'fact', 'nest', 'max_length'])
def query_disc(
    nside: int,
    vec: ArrayLike,
    radius: float,
    inclusive: bool = False,
    fact: int = 4,
    nest: bool = False,
    max_length: Optional[int] = None,
) -> Array:
    """Find pixels within a disc on the sphere.

    This function supports both single and batched queries and is compiled
    with ``jax.jit``; ``nside``, ``inclusive``, ``fact``, ``nest`` and
    ``max_length`` are static arguments.

    Parameters
    ----------
    nside : int
        The resolution parameter of the HEALPix map
    vec : array-like
        Either a single three-component vector (3,) defining the center of
        the disc, or a batch of vectors (B, 3) defining B disc centers
    radius : float
        The radius of the disc in radians (one scalar shared by all centers)
    inclusive : bool, optional
        If False (default), return pixels whose centers lie within the disc.
        If True, return pixels that overlap with the disc.
        Note: inclusive=True may produce slightly different results compared
        to healpy due to algorithm differences in determining pixel overlap.
    fact : int, optional
        For inclusive queries, the pixelization factor (default: 4)
    nest : bool, optional
        If True, assume NESTED pixel ordering, otherwise RING ordering
        (default: False). Only RING ordering is implemented.
    max_length : int, optional
        Maximum number of pixels to return per disc. If None, defaults to
        npix. Smaller values reduce the size of the returned array.

    Returns
    -------
    ipix : array
        For single vector input (3,): shape (max_length,)
        For batch vector input (B, 3): shape (B, max_length), one row per disc
        Entries that do not correspond to a pixel in the disc are set to npix
        (sentinel value).

    Raises
    ------
    NotImplementedError
        If nest=True (nested ordering not yet supported)

    Notes
    -----
    The returned array has fixed size for JIT compatibility - unused entries
    have value npix. When indexing a JAX array, these sentinel entries should
    be ignored. If indexing a numpy array, they will raise an out-of-bounds
    error.

    Because ``nside`` is a static argument, it must be a concrete Python int
    when this function is called from inside another jitted function.

    Examples
    --------
    Single disc query:

    >>> import jax_healpy as hp
    >>> import jax.numpy as jnp
    >>> import jax
    >>> nside = 16
    >>> vec = jnp.array([1.0, 0.0, 0.0])  # Point on equator
    >>> radius = 0.1  # ~5.7 degrees
    >>> disc = hp.query_disc(nside, vec, radius)
    >>> # Use directly for indexing
    >>> hmap = jnp.zeros(hp.nside2npix(nside))
    >>> hmap = hmap.at[disc].set(1.0)

    Batch disc query:

    >>> vecs = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # (2, 3) - two centers
    >>> discs = hp.query_disc(nside, vecs, radius, max_length=1000)  # (2, 1000)
    >>> # Each row contains pixels for one disc
    >>> map1 = hmap.at[discs[0]].set(1.0)  # First disc
    >>> map2 = hmap.at[discs[1]].set(2.0)  # Second disc

    For JIT compilation (nside is closed over so that it stays static):

    >>> jit_query_disc = jax.jit(
    ...     lambda v, r: hp.query_disc(nside, v, r)
    ... )
    >>> disc_jit = jit_query_disc(vec, radius)
    """
    # Nested ordering is rejected up front; this check runs at trace time
    # because `nest` is a static argument.
    if nest:
        raise NotImplementedError('Nested ordering not yet supported')

    return _query_disc_ring(nside, vec, radius, inclusive, fact, max_length)
def _query_disc_ring(
    nside: int, vec: ArrayLike, radius: float, inclusive: bool, fact: int, max_length: Optional[int]
) -> Array:
    """RING-scheme disc query for one centre or a batch of centres.

    A single centre is temporarily promoted to a batch of one; the batch is
    then mapped over ``_query_disc_ring_single`` with ``jax.vmap``.

    Parameters
    ----------
    nside : int
        HEALPix nside parameter
    vec : ArrayLike
        Either single vector (3,) or batch of vectors (B, 3)
    radius : float
        Disc radius in radians, shared by every centre
    inclusive : bool
        If True, include pixels that overlap the disc boundary
    fact : int
        Oversampling factor for inclusive mode
    max_length : Optional[int]
        Maximum number of pixels to return per disc (npix when None)

    Returns
    -------
    Array
        For single vector: shape (max_length,)
        For batch: shape (B, max_length)
        Unused entries are marked as npix (sentinel value)
    """
    centers = jnp.asarray(vec, dtype=jnp.float64)
    disc_radius = jnp.asarray(radius, dtype=jnp.float64)

    if max_length is None:
        max_length = 12 * nside * nside

    # Remember whether the caller passed a bare (3,) vector so the batch
    # axis can be removed again on the way out.
    single_input = centers.ndim == 1
    batch = centers[None, :] if single_input else centers  # (B, 3)

    per_disc = jax.vmap(
        lambda center: _query_disc_ring_single(nside, center, disc_radius, inclusive, fact, max_length)
    )(batch)  # (B, max_length)

    return jnp.squeeze(per_disc, axis=0) if single_input else per_disc


def _query_disc_bruteforce(
    nside: int, vec: ArrayLike, radius: float, inclusive: bool, fact: int, max_length: Optional[int]
) -> Array:
    """DEPRECATED: Brute-force disc query with O(batch_size × npix) complexity.

    ⚠️ **WARNING: NOT RECOMMENDED FOR PRODUCTION USE** ⚠️

    Every pixel centre on the sphere is compared against every disc centre,
    so both run time and memory grow as npix × batch_size, where
    npix = 12 × nside². Prefer `query_disc()`; this implementation is kept
    for reference and testing purposes.

    Algorithm Overview:
    1. Promote the input to (batch_dims, 3) and fill in the default max_length
    2. Normalize the centre vectors and clip the radius to [0, π]
    3. Derive the cosine threshold (radius widened by resolution / fact when
       inclusive)
    4. Compute the dot product of every pixel vector with every centre
    5. Mark pixels whose dot product reaches the threshold (minus 1e-6)
    6. Keep the max_length pixels with the largest dot product per disc and
       replace entries that were not in the disc by npix
    7. Print a warning through jax.debug.print if any disc was truncated
    8. Drop the batch axis again for single-vector input

    Returns
    -------
    Array
        Shape (max_length, batch_dims) for batched input, (max_length,) for a
        single vector. Entries are ordered by ascending dot product with the
        disc centre; entries outside the disc hold npix.
    """
    # Step 1: input standardization
    centers = jnp.asarray(vec, dtype=jnp.float64)
    single_input = centers.ndim == 1
    if single_input:
        centers = centers[None, :]  # (3,) → (1, 3)

    n_discs = centers.shape[0]
    npix = 12 * nside * nside
    radius = jnp.asarray(radius, dtype=jnp.float64)

    if max_length is None:
        max_length = npix

    # Step 2: unit centre vectors, with (1, 0, 0) standing in for zero vectors
    norms = jnp.linalg.norm(centers, axis=1)  # (batch_dims,)
    fallback_dir = jnp.array([1.0, 0.0, 0.0])[None, :]
    unit_centers = jnp.where(norms[:, None] > 1e-10, centers / norms[:, None], fallback_dir)

    radius = jnp.clip(radius, 0.0, jnp.pi)

    # Step 3: cosine threshold for the dot-product test
    if inclusive:
        widened_radius = radius + nside2resol(nside) / fact
        cos_threshold = jnp.cos(jnp.clip(widened_radius, 0, jnp.pi))
    else:
        cos_threshold = jnp.cos(radius)

    # Step 4: dot product of every pixel vector with every centre
    pixel_ids = jnp.arange(npix, dtype=jnp.int32)
    pixel_vecs = pix2vec(nside, pixel_ids, nest=False)  # (npix, 3)
    cosines = jnp.dot(pixel_vecs, unit_centers.T)  # (npix, batch_dims)

    # Step 5: membership mask, with a small tolerance for rounding
    in_disc = cosines >= (cos_threshold - 1e-6)  # (npix, batch_dims)

    # Step 6: rank pixels per disc; non-members sort first via -inf
    ranking_keys = jnp.where(in_disc, cosines, -jnp.inf)
    order = jnp.argsort(ranking_keys, axis=0)  # ascending, best pixels last
    chosen = order[-max_length:]  # (max_length, batch_dims)
    chosen_keys = ranking_keys[chosen, jnp.arange(n_discs)]
    result = jnp.where(chosen_keys == -jnp.inf, npix, chosen)

    # Step 7: runtime warning when a disc holds more pixels than max_length
    if max_length < npix:
        valid_counts = jnp.sum(in_disc.astype(jnp.int32), axis=0)  # (batch_dims,)
        n_truncated = jnp.sum(valid_counts > max_length)
        lax.cond(
            n_truncated > 0,
            lambda: jax.debug.print('Warning: {} valid pixels exceeded max_length={}', valid_counts.max(), max_length),
            lambda: None,
        )

    # Step 8: (max_length, 1) → (max_length,) for single-vector input
    if single_input:
        result = jnp.squeeze(result, axis=1)

    return result


# This file is part of jax-healpy.