Source code for jax_healpy.pixelfunc

# This file is part of jax-healpy.
# Copyright (C) 2024 CNRS / SciPol developers
#
# jax-healpy is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# jax-healpy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with jax-healpy. If not, see <https://www.gnu.org/licenses/>.

"""
=====================================================
pixelfunc.py : Healpix pixelization related functions
=====================================================

This module provides functions related to Healpix pixelization scheme.

conversion from/to sky coordinates
----------------------------------

- :func:`pix2ang` converts pixel number to angular coordinates
- :func:`pix2vec` converts pixel number to unit 3-vector direction
- :func:`ang2pix` converts angular coordinates to pixel number
- :func:`vec2pix` converts 3-vector to pixel number
- :func:`vec2ang` converts 3-vector to angular coordinates
- :func:`ang2vec` converts angular coordinates to unit 3-vector
- :func:`pix2xyf` converts pixel number to coordinates within face
- :func:`xyf2pix` converts coordinates within face to pixel number
- :func:`get_interp_weights` returns the 4 nearest pixels for given
  angular coordinates and the relative weights for interpolation
- :func:`get_all_neighbours` return the 8 nearest pixels for given
  angular coordinates (or optionally 9 pixels including center with get_center=True)

conversion between NESTED and RING schemes
------------------------------------------

- :func:`nest2ring` converts NESTED scheme pixel numbers to RING
  scheme pixel number
- :func:`ring2nest` converts RING scheme pixel number to NESTED
  scheme pixel number
- :func:`reorder` reorders a healpix map pixels from one scheme to another

nside/npix/resolution
---------------------

- :func:`nside2npix` converts healpix nside parameter to number of pixel
- :func:`npix2nside` converts number of pixel to healpix nside parameter
- :func:`nside2order` converts nside to order
- :func:`order2nside` converts order to nside
- :func:`nside2resol` converts nside to mean angular resolution
- :func:`nside2pixarea` converts nside to pixel area
- :func:`isnsideok` checks the validity of nside
- :func:`isnpixok` checks the validity of npix
- :func:`get_map_size` gives the number of pixel of a map
- :func:`get_min_valid_nside` gives the minimum nside possible for a given
  number of pixel
- :func:`get_nside` returns the nside of a map
- :func:`maptype` checks the type of a map (one map or sequence of maps)
- :func:`ud_grade` upgrades or degrades the resolution (nside) of a map

Masking pixels
--------------

- :const:`UNSEEN` is a constant value interpreted as a masked pixel
- :func:`mask_bad` returns a map with ``True`` where map is :const:`UNSEEN`
- :func:`mask_good` returns a map with ``False`` where map is :const:`UNSEEN`
- :func:`ma` returns a masked array as map, with mask given by :func:`mask_bad`

Map data manipulation
---------------------

- :func:`fit_dipole` fits a monopole+dipole on the map
- :func:`fit_monopole` fits a monopole on the map
- :func:`remove_dipole` fits and removes a monopole+dipole from the map
- :func:`remove_monopole` fits and remove a monopole from the map
- :func:`get_interp_val` computes a bilinear interpolation of the map
  at given angular coordinates, using 4 nearest neighbours
"""

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, lax, vmap
from jaxtyping import Array, ArrayLike

__all__ = [
    'pix2ang',
    'ang2pix',
    'pix2xyf',
    'xyf2pix',
    'pix2vec',
    'vec2pix',
    'ang2vec',
    'vec2ang',
    'get_interp_weights',
    'get_interp_val',
    'get_all_neighbours',
    # 'max_pixrad',
    'nest2ring',
    'ring2nest',
    'reorder',
    'ud_grade',
    'UNSEEN',
    # 'mask_good',
    # 'mask_bad',
    # 'ma',
    # 'fit_dipole',
    # 'remove_dipole',
    # 'fit_monopole',
    # 'remove_monopole',
    'nside2npix',
    'npix2nside',
    'nside2order',
    'order2nside',
    'order2npix',
    'npix2order',
    'nside2resol',
    'nside2pixarea',
    'isnsideok',
    'isnpixok',
    # 'get_map_size',
    # 'get_min_valid_nside',
    'get_nside',
    'maptype',
    # 'ma_to_array',
]

# We are using 64-bit integer types.
# nside > 2**29 requires extended integer types.
MAX_NSIDE = 1 << 29
UNSEEN = -1.6375e30

# HEALPix neighbor finding constants
# These constants implement the exact neighbor-finding algorithm from the original
# HEALPix C++ library (healpix_base.cc) for face boundary transitions

# 8-element offset arrays for x and y directions (SW, W, NW, N, NE, E, SE, S)
# These define the relative positions of the 8 neighbors around any pixel
_NB_XOFFSET = jnp.array([-1, -1, 0, 1, 1, 1, 0, -1], dtype=jnp.int32)
_NB_YOFFSET = jnp.array([0, 1, 1, 1, 0, -1, -1, -1], dtype=jnp.int32)

# Face boundary lookup table (9x12) - handles face transitions for neighbors
# This lookup table maps (nbnum, face) -> new_face when neighbors cross face boundaries
# nbnum encodes the boundary crossing direction, face is the original face (0-11)
# Based on original HEALPix C++ implementation's neighbor finding algorithm
_NB_FACEARRAY = jnp.array(
    [
        [8, 9, 10, 11, -1, -1, -1, -1, 10, 11, 8, 9],  # S
        [5, 6, 7, 4, 8, 9, 10, 11, 9, 10, 11, 8],  # SE
        [-1, -1, -1, -1, 5, 6, 7, 4, -1, -1, -1, -1],  # E
        [4, 5, 6, 7, 11, 8, 9, 10, 11, 8, 9, 10],  # SW
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # center
        [1, 2, 3, 0, 0, 1, 2, 3, 5, 6, 7, 4],  # NE
        [-1, -1, -1, -1, 7, 4, 5, 6, -1, -1, -1, -1],  # W
        [3, 0, 1, 2, 3, 0, 1, 2, 4, 5, 6, 7],  # NW
        [2, 3, 0, 1, -1, -1, -1, -1, 0, 1, 2, 3],  # N
    ],
    dtype=jnp.int32,
)

# Coordinate transformation bits (9x3) - handles x/y swapping across face boundaries
# This lookup table provides bit flags for coordinate transformations when crossing faces
# Bit 1: flip x coordinate, Bit 2: flip y coordinate, Bit 4: swap x and y coordinates
# Index by (nbnum, face >> 2) to get transformation bits for the boundary crossing
_NB_SWAPARRAY = jnp.array(
    [
        [0, 0, 3],  # S
        [0, 0, 6],  # SE
        [0, 0, 0],  # E
        [0, 0, 5],  # SW
        [0, 0, 0],  # center
        [5, 0, 0],  # NE
        [0, 0, 0],  # W
        [6, 0, 0],  # NW
        [3, 0, 0],  # N
    ],
    dtype=jnp.int32,
)


def check_theta_valid(theta):
    """
    JIT Compatible check_theta_valid
    Raises exception if theta is not within 0 and pi
    """
    invalid_theta = ~((theta >= 0).all() & (theta <= np.pi + 1e-5).all())

    def _raise_invalid_theta(invalid_theta):
        if invalid_theta:
            raise ValueError('THETA is out of range [0,pi]')

    jax.debug.callback(_raise_invalid_theta, invalid_theta)


def check_nside(nside: int, nest: bool = False) -> None:
    """Raises exception is nside is not valid"""
    if not np.all(isnsideok(nside, nest=nest)):
        raise ValueError(f'{nside} is not a valid nside parameter (must be a power of 2, less than 2**30)')


def _pixel_dtype_for(nside: int) -> jnp.dtype:
    """Returns the appropriate dtype for a pixel number given nside"""
    # for nside = 13378, npix = 2_147_650_608 which would overflow int32
    return jnp.int32 if nside <= 13377 else jnp.int64


[docs] def isnsideok(nside: int, nest: bool = False) -> bool: """Returns :const:`True` if nside is a valid nside parameter, :const:`False` otherwise. NSIDE needs to be a power of 2 only for nested ordering Parameters ---------- nside : int, scalar or array-like integer value to be tested Returns ------- ok : bool, scalar or array-like :const:`True` if given value is a valid nside, :const:`False` otherwise. Examples -------- >>> import jax_healpy as hp >>> hp.isnsideok(13, nest=True) False >>> hp.isnsideok(13, nest=False) True >>> hp.isnsideok(32) True >>> hp.isnsideok([1, 2, 3, 4, 8, 16], nest=True) array([ True, True, False, True, True, True], dtype=bool) """ # we use standard bithacks from http://graphics.stanford.edu/~seander/bithacks.html#DetermineIfPowerOf2 if hasattr(nside, '__len__'): if not isinstance(nside, np.ndarray): nside = np.asarray(nside) is_nside_ok = (nside == nside.astype(int)) & (nside > 0) & (nside <= MAX_NSIDE) if nest: is_nside_ok &= (nside.astype(int) & (nside.astype(int) - 1)) == 0 else: is_nside_ok = nside == int(nside) and 0 < nside <= MAX_NSIDE if nest: is_nside_ok = is_nside_ok and (int(nside) & (int(nside) - 1)) == 0 return is_nside_ok
[docs] def isnpixok(npix: int) -> bool: """Return :const:`True` if npix is a valid value for healpix map size, :const:`False` otherwise. Parameters ---------- npix : int, scalar or array-like integer value to be tested Returns ------- ok : bool, scalar or array-like :const:`True` if given value is a valid number of pixel, :const:`False` otherwise Examples -------- >>> import jax_healpy as hp >>> hp.isnpixok(12) True >>> hp.isnpixok(768) True >>> hp.isnpixok([12, 768, 1002]) array([ True, True, False], dtype=bool) """ nside = np.sqrt(np.asarray(npix) / 12.0) return nside == np.floor(nside)
[docs] def nside2npix(nside: int) -> int: """Give the number of pixels for the given nside. Parameters ---------- nside : int healpix nside parameter Returns ------- npix : int corresponding number of pixels Examples -------- >>> import jax_healpy as hp >>> import numpy as np >>> hp.nside2npix(8) 768 >>> np.all([hp.nside2npix(nside) == 12 * nside**2 for nside in [2**n for n in range(12)]]) True >>> hp.nside2npix(7) 588 """ return 12 * nside * nside
[docs] def npix2nside(npix: int) -> int: """Give the nside parameter for the given number of pixels. Parameters ---------- npix : int the number of pixels Returns ------- nside : int the nside parameter corresponding to npix Notes ----- Raise a ValueError exception if number of pixel does not correspond to the number of pixel of a healpix map. Examples -------- >>> import jax_healpy as hp >>> hp.npix2nside(768) 8 >>> np.all([hp.npix2nside(12 * nside**2) == nside for nside in [2**n for n in range(12)]]) True >>> hp.npix2nside(1000) Traceback (most recent call last): ... ValueError: Wrong pixel number (it is not 12*nside**2) """ if not isnpixok(npix): raise ValueError('Wrong pixel number (it is not 12*nside**2)') return int(np.sqrt(npix / 12.0))
[docs] def nside2order(nside: int) -> int: """Give the resolution order for a given nside. Parameters ---------- nside : int healpix nside parameter; an exception is raised if nside is not valid (nside must be a power of 2, less than 2**30) Returns ------- order : int corresponding order where nside = 2**(order) Notes ----- Raise a ValueError exception if nside is not valid. Examples -------- >>> import jax_healpy as hp >>> import numpy as np >>> hp.nside2order(128) 7 >>> all(hp.nside2order(2**o) == o for o in range(30)) True >>> hp.nside2order(7) Traceback (most recent call last): ... ValueError: 7 is not a valid nside parameter (must be a power of 2, less than 2**30) """ check_nside(nside, nest=True) return len(f'{nside:b}') - 1
[docs] def order2nside(order: int) -> int: """Give the nside parameter for the given resolution order. Parameters ---------- order : int the resolution order Returns ------- nside : int the nside parameter corresponding to order Notes ----- Raise a ValueError exception if order produces an nside out of range. Examples -------- >>> import jax_healpy as hp >>> hp.order2nside(7) 128 >>> print(hp.order2nside(np.arange(8))) [ 1 2 4 8 16 32 64 128] >>> hp.order2nside(31) Traceback (most recent call last): ... ValueError: 2147483648 is not a valid nside parameter (must be a power of 2, less than 2**30) """ nside = 1 << order check_nside(nside, nest=True) return nside
[docs] def order2npix(order: int) -> int: """Give the number of pixels for the given resolution order. Parameters ---------- order : int the resolution order Returns ------- npix : int corresponding number of pixels Notes ----- A convenience function that successively applies order2nside then nside2npix to order. Examples -------- >>> import jax_healpy as hp >>> hp.order2npix(7) 196608 >>> print(hp.order2npix(np.arange(8))) [ 12 48 192 768 3072 12288 49152 196608] >>> hp.order2npix(31) Traceback (most recent call last): ... ValueError: 2147483648 is not a valid nside parameter (must be a power of 2, less than 2**30) """ nside = order2nside(order) npix = nside2npix(nside) return npix
[docs] def npix2order(npix: int) -> int: """Give the resolution order for the given number of pixels. Parameters ---------- npix : int the number of pixels Returns ------- order : int corresponding resolution order Notes ----- A convenience function that successively applies npix2nside then nside2order to npix. Examples -------- >>> import jax_healpy as hp >>> hp.npix2order(768) 3 >>> np.all([hp.npix2order(12 * 4**order) == order for order in range(12)]) True >>> hp.npix2order(1000) Traceback (most recent call last): ... ValueError: Wrong pixel number (it is not 12*nside**2) """ nside = npix2nside(npix) order = nside2order(nside) return order
[docs] def nside2resol(nside: int, arcmin=False) -> float: """Give approximate resolution (pixel size in radian or arcmin) for nside. Resolution is just the square root of the pixel area, which is a gross approximation given the different pixel shapes Parameters ---------- nside : int healpix nside parameter, must be a power of 2, less than 2**30 arcmin : bool if True, return resolution in arcmin, otherwise in radian Returns ------- resol : float approximate pixel size in radians or arcmin Notes ----- Raise a ValueError exception if nside is not valid. Examples -------- >>> import jax_healpy as hp >>> hp.nside2resol(128, arcmin = True) # doctest: +FLOAT_CMP 27.483891294539248 >>> hp.nside2resol(256) 0.0039973699529159707 >>> hp.nside2resol(7) 0.1461895297066412 """ resol = np.sqrt(nside2pixarea(nside)) if arcmin: resol = np.rad2deg(resol) * 60 return resol
[docs] def nside2pixarea(nside: int, degrees=False) -> float: """Give pixel area given nside in square radians or square degrees. Parameters ---------- nside : int healpix nside parameter, must be a power of 2, less than 2**30 degrees : bool if True, returns pixel area in square degrees, in square radians otherwise Returns ------- pixarea : float pixel area in square radian or square degree Notes ----- Raise a ValueError exception if nside is not valid. Examples -------- >>> import jax_healpy as hp >>> hp.nside2pixarea(128, degrees = True) # doctest: +FLOAT_CMP 0.2098234113027917 >>> hp.nside2pixarea(256) 1.5978966540475428e-05 >>> hp.nside2pixarea(7) 0.021371378595848933 """ pixarea = 4 * np.pi / nside2npix(nside) if degrees: pixarea = np.rad2deg(np.rad2deg(pixarea)) return pixarea
def _lonlat2thetaphi(lon: ArrayLike, lat: ArrayLike): """Transform longitude and latitude (deg) into co-latitude and longitude (rad) Parameters ---------- lon : int or array-like Longitude in degrees lat : int or array-like Latitude in degrees Returns ------- theta, phi : float, scalar or array-like The co-latitude and longitude in radians """ return np.pi / 2 - jnp.radians(lat), jnp.radians(lon) def _thetaphi2lonlat(theta, phi): """Transform co-latitude and longitude (rad) into longitude and latitude (deg) Parameters ---------- theta : int or array-like Co-latitude in radians phi : int or array-like Longitude in radians Returns ------- lon, lat : float, scalar or array-like The longitude and latitude in degrees """ return jnp.degrees(phi), 90.0 - jnp.degrees(theta)
[docs] def maptype(m): """Describe the type of the map (valid, single, sequence of maps). Checks : the number of maps, that all maps have same length and that this length is a valid map size (using :func:`isnpixok`). Parameters ---------- m : sequence the map to get info from Returns ------- info : int -1 if the given object is not a valid map, 0 if it is a single map, *info* > 0 if it is a sequence of maps (*info* is then the number of maps) Examples -------- >>> import healpy as hp >>> hp.pixelfunc.maptype(np.arange(12)) 0 >>> hp.pixelfunc.maptype([np.arange(12), np.arange(12)]) 2 """ if not hasattr(m, '__len__'): raise TypeError('input map is a scalar') if len(m) == 0: raise TypeError('input map has length zero') try: npix = len(m[0]) except TypeError: npix = None if npix is not None: for mm in m[1:]: if len(mm) != npix: raise TypeError('input maps have different npix') if isnpixok(len(m[0])): return len(m) else: raise TypeError('bad number of pixels') else: if isnpixok(len(m)): return 0 else: raise TypeError('bad number of pixels')
[docs] @partial(jit, static_argnames=['nside', 'nest', 'lonlat']) def ang2pix( nside: int, theta: ArrayLike, phi: ArrayLike, nest: bool = False, lonlat: bool = False, ) -> Array: """ang2pix: nside,theta[rad],phi[rad],nest=False,lonlat=False -> ipix (default:RING) Unlike healpy.ang2pix, specifying a theta not in the range [0, π] does not raise an error, but returns -1. Parameters ---------- nside : int, scalar or array-like The healpix nside parameter, must be a power of 2, less than 2**30 theta, phi : float, scalars or array-like Angular coordinates of a point on the sphere nest : bool, optional if True, assume NESTED pixel ordering, otherwise, RING pixel ordering lonlat : bool If True, input angles are assumed to be longitude and latitude in degree, otherwise, they are co-latitude and longitude in radians. Returns ------- pix : int or array of int The healpix pixel numbers. Scalar if all input are scalar, array otherwise. Usual numpy broadcasting rules apply. See Also -------- pix2ang, pix2vec, vec2pix Examples -------- Note that some of the test inputs below that are on pixel boundaries such as theta=π/2, phi=π/2, have a tiny value of 1e-15 added to them to make them reproducible on i386 machines using x87 floating point instruction set (see https://github.com/healpy/healpy/issues/528). >>> import jax_healpy as hp >>> from jax.numpy import pi as π >>> hp.ang2pix(16, π/2, 0) Array(1440, dtype=int64) >>> print(hp.ang2pix(16, np.array([π/2, π/4, π/2, 0, π]), np.array([0., π/4, π/2 + 1e-15, 0, 0]))) [1440 427 1520 0 3068] >>> print(hp.ang2pix(16, π/2, np.array([0, π/2 + 1e-15]))) [1440 1520] >>> print(hp.ang2pix(np.array([1, 2, 4, 8, 16]), π/2, 0)) [ 4 12 72 336 1440] >>> print(hp.ang2pix(np.array([1, 2, 4, 8, 16]), 0, 0, lonlat=True)) [ 4 12 72 336 1440] """ # check_theta_valid(theta) check_nside(nside, nest=nest) if nest: raise NotImplementedError('NEST pixel ordering is not implemented.') if lonlat: theta, phi = _lonlat2thetaphi(theta, phi) pixels = _zphi2pix_ring(nside, jnp.cos(theta), jnp.sin(theta), phi) return jnp.where((theta < 0) | (theta > np.pi + 1e-5), -1, pixels)
def _zphi2pix_ring(nside: int, z: ArrayLike, sin_theta: ArrayLike, phi: ArrayLike) -> Array: tt = jnp.mod(2 * phi / np.pi, 4) ipix = jnp.where( jnp.abs(z) <= 2 / 3, _zphi2pix_equatorial_region_ring(nside, z, sin_theta, tt), _zphi2pix_polar_caps_ring(nside, z, sin_theta, tt), ) return ipix def _zphi2pix_equatorial_region_ring(nside: int, z: ArrayLike, sin_theta: float, tt: ArrayLike) -> Array: ncap = 2 * nside * (nside - 1) nl4 = 4 * nside jp = (nside * (0.5 + tt - 0.75 * z)).astype(int) jm = (nside * (0.5 + tt + 0.75 * z)).astype(int) ir = nside + 1 + jp - jm kshift = 1 - ir & 1 # ir even -> 1, odd -> 0 t1 = jp + jm - nside + kshift + 1 + nl4 + nl4 ip = (t1 >> 1) & (nl4 - 1) pix = ncap + (ir - 1) * nl4 + ip return pix def _zphi2pix_polar_caps_ring(nside: int, z: ArrayLike, sin_theta: ArrayLike, tt: ArrayLike) -> Array: npixel = nside2npix(nside) tp = tt - jnp.floor(tt) # tmp = nside * sin_theta / jnp.sqrt((1 + jnp.abs(z)) / 3) tmp = nside * jnp.sqrt(3.0 * (1.0 - jnp.abs(z))) jp = (tp * tmp).astype(int) jm = ((1.0 - tp) * tmp).astype(int) ir = jp + jm + 1 ip = (tt * ir).astype(int) return jnp.where(z > 0, 2 * ir * (ir - 1) + ip, npixel - 2 * ir * (ir + 1) + ip)
[docs] @partial(jit, static_argnames=['nside', 'nest', 'lonlat']) def pix2ang(nside: int, ipix: ArrayLike, nest: bool = False, lonlat: bool = False) -> tuple[Array, Array]: """pix2ang : nside,ipix,nest=False,lonlat=False -> theta[rad],phi[rad] (default RING) Parameters ---------- nside : int or array-like The healpix nside parameter, must be a power of 2, less than 2**30 ipix : int or array-like Pixel indices nest : bool, optional if True, assume NESTED pixel ordering, otherwise, RING pixel ordering lonlat : bool, optional If True, return angles will be longitude and latitude in degree, otherwise, angles will be co-latitude and longitude in radians (default) Returns ------- theta, phi : float, scalar or array-like The angular coordinates corresponding to ipix. Scalar if all input are scalar, array otherwise. Usual numpy broadcasting rules apply. See Also -------- ang2pix, vec2pix, pix2vec Examples -------- >>> import jax_healpy as hp >>> hp.pix2ang(16, 1440) (1.5291175943723188, 0.0) >>> hp.pix2ang(16, [1440, 427, 1520, 0, 3068]) (array([ 1.52911759, 0.78550497, 1.57079633, 0.05103658, 3.09055608]), array([ 0. , 0.78539816, 1.61988371, 0.78539816, 0.78539816])) >>> hp.pix2ang([1, 2, 4, 8], 11) (array([ 2.30052398, 0.84106867, 0.41113786, 0.2044802 ]), array([ 5.49778714, 5.89048623, 5.89048623, 5.89048623])) >>> hp.pix2ang([1, 2, 4, 8], 11, lonlat=True) (array([ 315. , 337.5, 337.5, 337.5]), array([-41.8103149 , 41.8103149 , 66.44353569, 78.28414761])) """ # noqa: E501 check_nside(nside, nest=nest) if nest: theta, phi = _pix2ang_nest(nside, ipix) else: iring = _pix2i_ring(nside, ipix) theta = _pix2theta_ring(nside, iring, ipix) phi = _pix2phi_ring(nside, iring, ipix) if lonlat: return _thetaphi2lonlat(theta, phi) return theta, phi
def _pix2i_ring(nside: int, pixels: ArrayLike) -> Array: npixel = nside2npix(nside) ncap = 2 * nside * (nside - 1) iring = jnp.where( pixels < ncap, _pix2i_north_cap_ring(nside, pixels), jnp.where( pixels < npixel - ncap, _pix2i_equatorial_region_ring(nside, pixels), _pix2i_south_cap_ring(nside, pixels), ), ) return iring def _pix2i_north_cap_ring(nside: int, pixels: ArrayLike) -> Array: return (1 + jnp.sqrt(1 + 2 * pixels).astype(int)) >> 1 # counted from North Pole def _pix2i_equatorial_region_ring(nside: int, pixels: ArrayLike) -> Array: ncap = 2 * nside * (nside - 1) ip = pixels - ncap order = nside2order(nside) # I tmp = (order_>=0) ? ip>>(order_+2) : ip/nl4; tmp = ip >> (order + 2) return tmp + nside def _pix2i_south_cap_ring(nside: int, pixels: ArrayLike) -> Array: npixel = nside2npix(nside) ip = npixel - pixels return (1 + jnp.sqrt(2 * ip - 1).astype(int)) >> 1 # counted from South Pole def _pix2z_ring(nside: int, iring: ArrayLike, pixels: ArrayLike) -> tuple[Array, Array]: npixel = nside2npix(nside) ncap = 2 * nside * (nside - 1) abs_one_minus_z = _pix2z_polar_caps_ring(nside, iring) z = jnp.where( pixels < ncap, 1 - abs_one_minus_z, jnp.where( pixels < npixel - ncap, _pix2z_equatorial_region_ring(nside, iring), abs_one_minus_z - 1, ), ) return z, abs_one_minus_z def _pix2z_polar_caps_ring(nside: int, iring: ArrayLike) -> Array: npixel = nside2npix(nside) return iring * iring * 4 / npixel def _pix2z_equatorial_region_ring(nside: int, iring: ArrayLike) -> Array: return (2 * nside - iring) * 2 / 3 / nside def _pix2theta_ring(nside: int, iring: ArrayLike, pixels: ArrayLike) -> Array: z, abs_one_minus_z = _pix2z_ring(nside, iring, pixels) theta = jnp.where( jnp.abs(z) > 0.99, jnp.arctan2(jnp.sqrt(abs_one_minus_z * (2 - abs_one_minus_z)), z), jnp.arccos(z), ) return theta def _pix2phi_ring(nside: int, iring: ArrayLike, pixels: ArrayLike) -> Array: npixel = nside2npix(nside) ncap = 2 * nside * (nside - 1) phi = jnp.where( pixels < ncap, _pix2phi_north_cap_ring(nside, iring, pixels), jnp.where( pixels < npixel - ncap, _pix2phi_equatorial_region_ring(nside, iring, pixels), _pix2phi_south_cap_ring(nside, iring, pixels), ), ) return phi def _pix2phi_north_cap_ring(nside: int, iring: ArrayLike, pixels: ArrayLike) -> Array: iphi = pixels + 1 - 2 * iring * (iring - 1) phi = (iphi - 0.5) * np.pi / 2 / iring return phi def _pix2phi_equatorial_region_ring(nside: int, iring: ArrayLike, pixels: ArrayLike) -> Array: iphi = pixels + 2 * nside * (nside + 1) - 4 * nside * iring + 1 fodd = ((iring + nside) & 1) * 0.5 + 0.5 # iring + nside odd -> 1 else 0.5 phi = (iphi - fodd) * np.pi / 2 / nside return phi def _pix2phi_south_cap_ring(nside: int, iring: ArrayLike, pixels: ArrayLike) -> Array: npixel = nside2npix(nside) iphi = 4 * iring + 1 - (npixel - pixels - 2 * iring * (iring - 1)) phi = (iphi - 0.5) * np.pi / 2 / iring return phi def _pix2ang_nest(nside: ArrayLike, ipix: ArrayLike) -> tuple[Array, Array]: raise NotImplementedError('NEST pixel ordering is not implemented.') # template<typename I> void T_Healpix_Base<I>::pix2loc (I pix, double &z, # double &phi, double &sth, bool &have_sth) const # have_sth=false; # { # int face_num, ix, iy; # nest2xyf(pix,ix,iy,face_num); # # I jr = (I(jrll[face_num])<<order_) - ix - iy - 1; # # I nr; # if (jr<nside_) # { # nr = jr; # double tmp=(nr*nr)*fact2_; # z = 1 - tmp; # if (z>0.99) { sth=sqrt(tmp*(2.-tmp)); have_sth=true; } # } # else if (jr > 3*nside_) # { # nr = nside_*4-jr; # double tmp=(nr*nr)*fact2_; # z = tmp - 1; # if (z<-0.99) { sth=sqrt(tmp*(2.-tmp)); have_sth=true; } # } # else # { # nr = nside_; # z = (2*nside_-jr)*fact1_; # } # # I tmp=I(jpll[face_num])*nr+ix-iy; # planck_assert(tmp<8*nr,"must not happen"); # if (tmp<0) tmp+=8*nr; # phi = (nr==nside_) ? 0.75*halfpi*tmp*fact1_ : # (0.5*halfpi*tmp)/nr; # } # }
[docs] @partial(jit, static_argnames=['nside', 'nest']) def xyf2pix(nside: int, x: ArrayLike, y: ArrayLike, face: ArrayLike, nest: bool = False) -> Array: """xyf2pix : nside,x,y,face,nest=False -> ipix (default:RING) Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc. Parameters ---------- nside : int The healpix nside parameter, must be a power of 2 x, y : int, scalars or array-like Pixel indices within face face : int, scalars or array-like Face number nest : bool, optional if True, assume NESTED pixel ordering, otherwise, RING pixel ordering Returns ------- pix : int or array of int The healpix pixel numbers. Scalar if all input are scalar, array otherwise. Usual numpy broadcasting rules apply. See Also -------- pix2xyf Examples -------- >>> import healpy as hp >>> hp.xyf2pix(16, 8, 8, 4) 1440 >>> print(hp.xyf2pix(16, [8, 8, 8, 15, 0], [8, 8, 7, 15, 0], [4, 0, 5, 0, 8])) [1440 427 1520 0 3068] """ check_nside(nside, nest=nest) x = jnp.asarray(x) y = jnp.asarray(y) face = jnp.asarray(face) if nest: return _xyf2pix_nest(nside, x, y, face) else: return _xyf2pix_ring(nside, x, y, face)
def _xyf2pix_nest(nside: int, ix: Array, iy: Array, fnum: Array) -> Array: """Convert (x, y, face) to pixel number in NESTED ordering""" fpix = _xy2fpix(nside, ix, iy) nested_pixel = fnum * nside**2 + fpix return nested_pixel def _xy2fpix(nside: int, ix: Array, iy: Array) -> Array: """Convert (x, y) coordinates to a pixel index inside a face""" # fpix = (ix & 0b1) << 0 | (iy & 0b1) << 1 | (ix & 0b10) << 1 | (iy & 0b10) << 2 | ... def combine_bits(i, val): val |= (ix & (1 << i)) << i val |= (iy & (1 << i)) << (i + 1) return val # ix and iy are always less than nside, so there is no need to extract more bits than this length = (nside - 1).bit_length() # we use a native for loop because it was slightly faster than lax.fori_loop with unroll=True fpix = jnp.zeros_like(ix) for i in range(length): fpix = combine_bits(i, fpix) return fpix # ring index of south corner for each face (0 = North pole) _JRLL = jnp.array([2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]) # longitude index of south corner for each face (0 = longitude zero) _JPLL = jnp.array([1, 3, 5, 7, 0, 2, 4, 6, 1, 3, 5, 7]) def _xyf2pix_ring(nside: int, ix: Array, iy: Array, face_num: Array) -> Array: """Convert (x, y, face) to a pixel number in RING ordering""" # ring index of the pixel center jr = (_JRLL[face_num] * nside) - ix - iy - 1 ringpix = _npix_on_ring(nside, jr) startpix = _start_pixel_ring(nside, jr) kshift = 1 - _ring_shifted(nside, jr) # pixel number in the ring jp = (_JPLL[face_num] * ringpix // 4 + ix - iy + 1 + kshift) // 2 jp = jnp.where(jp < 1, jp + 4 * nside, jp) return startpix - 1 + jp def _start_pixel_ring(nside: int, i_ring: ArrayLike) -> Array: """Get the first pixel number of a ring""" # work in northern hemisphere i_north = _northern_ring(nside, i_ring) ringpix = _npix_on_ring(nside, i_ring) ncap = 2 * nside * (nside - 1) npix = nside2npix(nside) startpix = jnp.where( i_north < nside, 2 * i_north * (i_north - 1), # north polar cap ncap + (i_north - nside) * 4 * nside, # north equatorial belt ) # flip results if in southern hemisphere startpix = jnp.where(i_ring != i_north, npix - startpix - ringpix, startpix) return startpix def _npix_on_ring(nside: int, i_ring: ArrayLike) -> Array: """Get the number of pixels on a ring""" i_north = _northern_ring(nside, i_ring) ringpix = jnp.where( i_north < nside, 4 * i_north, # rings in the polar cap have 4*i pixels 4 * nside, # rings in the equatorial region have 4*nside pixels ) return ringpix def _ring_shifted(nside: int, i_ring: ArrayLike) -> Array: """Check if a ring is shifted""" i_north = _northern_ring(nside, i_ring) shifted = jnp.where( i_north < nside, True, (i_north - nside) & 1 == 0, ) return shifted def _northern_ring(nside: int, i_ring: ArrayLike) -> Array: i_north = jnp.where(i_ring > 2 * nside, 4 * nside - i_ring, i_ring) return i_north def _ring_above(nside: int, cos_theta: ArrayLike) -> Array: """Find the ring index just north of a point with given cos(theta). This follows the exact HEALPix C++ implementation: if (az <= twothird) // equatorial region return I(nside_*(2-1.5*z)); I iring = I(nside_*sqrt(3*(1-az))); return (z>0) ? iring : 4*nside_-iring-1; """ z = cos_theta az = jnp.abs(z) twothird = 2.0 / 3.0 # Equatorial region equatorial_ring = nside * (2.0 - 1.5 * z) # Stop gradient: Ring indices are discrete selectors, don't affect interpolation math equatorial_ring = lax.stop_gradient(jnp.floor(equatorial_ring).astype(jnp.int32)) # Polar caps iring = nside * jnp.sqrt(3.0 * (1.0 - az)) # Stop gradient: Ring indices are discrete selectors, don't affect interpolation math iring = lax.stop_gradient(jnp.floor(iring).astype(jnp.int32)) polar_ring = jnp.where(z > 0, iring, 4 * nside - iring - 1) # Choose based on z value ring_idx = jnp.where(az <= twothird, equatorial_ring, polar_ring) return ring_idx def _get_ring_info(nside: int, ring_idx: ArrayLike) -> tuple[Array, Array, Array, Array]: """Get ring properties following HEALPix C++ get_ring_info2 exactly. Returns: theta, startpix, ringpix, shifted """ # Convert to scalar for compatibility ring = ring_idx # HEALPix C++ constants fact1 = 2.0 / (3.0 * nside) # (nside_<<1)*fact2_ where fact2_ = 4./npix_ = 1/(3*nside**2) fact2 = 4.0 / (12.0 * nside * nside) # 4./npix_ ncap = 2 * nside * (nside - 1) npix_total = 12 * nside * nside # Northern hemisphere equivalent ring northring = jnp.where(ring > 2 * nside, 4 * nside - ring, ring) # Polar cap region (northring < nside) polar_tmp = northring * northring * fact2 polar_costheta = 1.0 - polar_tmp polar_sintheta = jnp.sqrt(polar_tmp * (2.0 - polar_tmp)) polar_theta = jnp.arctan2(polar_sintheta, polar_costheta) polar_ringpix = 4 * northring polar_shifted = True polar_startpix = 2 * northring * (northring - 1) # Equatorial region (northring >= nside) equatorial_theta = jnp.arccos((2.0 * nside - northring) * fact1) equatorial_ringpix = 4 * nside equatorial_shifted = ((northring - nside) & 1) == 0 equatorial_startpix = ncap + (northring - nside) * equatorial_ringpix # Choose based on region theta = jnp.where(northring < nside, polar_theta, equatorial_theta) ringpix = jnp.where(northring < nside, polar_ringpix, equatorial_ringpix) shifted = jnp.where(northring < nside, polar_shifted, equatorial_shifted) startpix = jnp.where(northring < nside, polar_startpix, equatorial_startpix) # Southern hemisphere correction theta = jnp.where(northring != ring, np.pi - theta, theta) startpix = jnp.where(northring != ring, npix_total - startpix - ringpix, startpix) # Convert shifted boolean to float (0.0 or 0.5) shift = jnp.where(shifted, 0.5, 0.0) return theta, startpix, ringpix, shift
[docs] @partial(jit, static_argnames=['nside', 'nest']) def pix2xyf(nside: int, ipix: ArrayLike, nest: bool = False) -> tuple[Array, Array, Array]: """pix2xyf : nside,ipix,nest=False -> x,y,face (default RING) Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc. Parameters ---------- nside : int The healpix nside parameter, must be a power of 2 ipix : int or array-like Pixel indices nest : bool, optional if True, assume NESTED pixel ordering, otherwise, RING pixel ordering Returns ------- x, y : int, scalars or array-like Pixel indices within face face : int, scalars or array-like Face number See Also -------- xyf2pix Examples -------- >>> import healpy as hp >>> hp.pix2xyf(16, 1440) (8, 8, 4) >>> hp.pix2xyf(16, [1440, 427, 1520, 0, 3068]) (array([ 8, 8, 8, 15, 0]), array([ 8, 8, 7, 15, 0]), array([4, 0, 5, 0, 8])) >>> hp.pix2xyf([1, 2, 4, 8], 11) (array([0, 1, 3, 7]), array([0, 0, 2, 6]), array([11, 3, 3, 3])) """ check_nside(nside, nest=nest) ipix = jnp.asarray(ipix) if nest: return _pix2xyf_nest(nside, ipix) else: return _pix2xyf_ring(nside, ipix)
def _pix2xyf_nest(nside: int, pix: Array) -> tuple[Array, Array, Array]: """Convert a pixel number in NESTED ordering to (x, y, face)""" fnum, fpix = jnp.divmod(pix, nside**2) ix, iy = _fpix2xy(nside, fpix) return ix, iy, fnum def _fpix2xy(nside: int, pix: Array) -> tuple[Array, Array]: """Convert a pixel index inside a face into (x, y) coordinates. Pixel indices inside the face must be less than nside**2. """ # x = (pix & 0b1) >> 0 | (pix & 0b100) >> 1 | (pix & 0b10000) >> 2 | ... # y = (pix & 0b10) >> 1 | (pix & 0b1000) >> 2 | (pix & 0b100000) >> 3 | ... def extract_bits(i, carry): x, y = carry x |= (pix & (1 << (2 * i))) >> i y |= (pix & (1 << (2 * i + 1))) >> (i + 1) return x, y # imagine that nside = 2 ** ord (nside must be a power of 2 in nested ordering) # the maximum value of pix is nside**2 - 1, which fits on 2 * ord bits # because we extract 2 bits at a time, we need to loop ord times # and ord is the bit length of nside - 1 length = (nside - 1).bit_length() # we use a native for loop because it was slightly faster than lax.fori_loop with unroll=True x, y = jnp.zeros_like(pix), jnp.zeros_like(pix) for i in range(length): x, y = extract_bits(i, (x, y)) return x, y def _pix2xyf_ring(nside: int, pix: Array) -> tuple[Array, Array, Array]: """Convert a pixel number in RING ordering to (x, y, face)""" ncap = 2 * nside * (nside - 1) npix = nside2npix(nside) nl2 = 2 * nside # number of pixels in a latitude circle # TODO(simon): remove this cast when https://github.com/CMBSciPol/jax-healpy/issues/4 is fixed iring = _pix2i_ring(nside, pix).astype(_pixel_dtype_for(nside)) iphi = _pix2iphi_ring(nside, iring, pix) nr = _npix_on_ring(nside, iring) // 4 kshift = 1 - _ring_shifted(nside, iring) ire = iring - nside + 1 irm = nl2 + 2 - ire ifm = (iphi - ire // 2 + nside - 1) // nside ifp = (iphi - irm // 2 + nside - 1) // nside face_num = jnp.where( pix < ncap, (iphi - 1) // nr, # north polar cap jnp.where( pix < (npix - ncap), jnp.where(ifp == ifm, ifp | 4, jnp.where(ifp < ifm, ifp, ifm + 8)), 8 + (iphi - 1) // nr, # south polar cap ), ) iring_for_irt = jnp.where( jnp.logical_or(pix < ncap, pix < (npix - ncap)), iring, # north polar cap and equatorial region 4 * nside - iring, # south polar cap ) # ring number counted from North pole or South pole irt = iring_for_irt - (_JRLL[face_num] * nside) + 1 ipt = 2 * iphi - _JPLL[face_num] * nr - kshift - 1 ipt -= jnp.where(ipt >= nl2, 8 * nside, 0) ix = (ipt - irt) // 2 iy = (-ipt - irt) // 2 return ix, iy, face_num def _pix2iphi_ring(nside: int, iring: Array, pixels: Array) -> Array: npixel = nside2npix(nside) ncap = 2 * nside * (nside - 1) iphi = jnp.where( pixels < ncap, _pix2iphi_north_cap_ring(nside, iring, pixels), jnp.where( pixels < npixel - ncap, _pix2iphi_equatorial_region_ring(nside, iring, pixels), _pix2iphi_south_cap_ring(nside, iring, pixels), ), ) return iphi def _pix2iphi_north_cap_ring(nside: int, iring: Array, pixels: Array) -> Array: iphi = pixels + 1 - 2 * iring * (iring - 1) return iphi def _pix2iphi_equatorial_region_ring(nside: int, iring: Array, pixels: Array) -> Array: iphi = pixels + 2 * nside * (nside + 1) - 4 * nside * iring + 1 return iphi def _pix2iphi_south_cap_ring(nside: int, iring: Array, pixels: Array) -> Array: npixel = nside2npix(nside) iphi = 4 * iring + 1 - (npixel - pixels - 2 * iring * (iring - 1)) return iphi
[docs] @partial(jit, static_argnames=['nside', 'nest']) def vec2pix(nside: int, x: ArrayLike, y: ArrayLike, z: ArrayLike, nest: bool = False) -> Array: """vec2pix : nside,x,y,z,nest=False -> ipix (default:RING) Parameters ---------- nside : int or array-like The healpix nside parameter, must be a power of 2, less than 2**30 x,y,z : floats or array-like vector coordinates defining point on the sphere nest : bool, optional if True, assume NESTED pixel ordering, otherwise, RING pixel ordering Returns ------- ipix : int, scalar or array-like The healpix pixel number corresponding to input vector. Scalar if all input are scalar, array otherwise. Usual numpy broadcasting rules apply. See Also -------- ang2pix, pix2ang, pix2vec Examples -------- >>> import healpy as hp >>> hp.vec2pix(16, 1, 0, 0) 1504 >>> print(hp.vec2pix(16, [1, 0], [0, 1], [0, 0])) [1504 1520] >>> print(hp.vec2pix([1, 2, 4, 8], 1, 0, 0)) [ 4 20 88 368] """ check_nside(nside, nest=nest) if nest: raise NotImplementedError return _vec2pix_ring(nside, x, y, z)
def vec2pix2(nside: int, vec: ArrayLike, nest: bool = False) -> Array: return vec2pix2_ring(nside, vec) @partial(jit, static_argnames='nside') @partial(vmap, in_axes=(None, 1)) def vec2pix2_ring(nside: int, vec: ArrayLike) -> Array: vec /= jnp.sqrt(jnp.sum(vec**2)) phi = jnp.arctan2(vec[1], vec[0]) # return _zphi2pix_ring(nside, vec[2], jnp.sqrt(vec[0] ** 2 + vec[1] ** 2), phi) return _zphi2pix_ring(nside, vec[2], jnp.sqrt(vec[0] ** 2 + vec[1] ** 2), phi) def _vec2pix_ring(nside: int, x: ArrayLike, y: ArrayLike, z: ArrayLike) -> Array: dnorm = 1 / jnp.sqrt(x**2 + y**2 + z**2) z *= dnorm phi = jnp.arctan2(y, x) return _zphi2pix_ring(nside, z, jnp.sqrt(x**2 + y**2) * dnorm, phi)
[docs] @partial(jit, static_argnames=['nside', 'nest']) def pix2vec(nside: int, ipix: ArrayLike, nest: bool = False) -> Array: """pix2vec : nside,ipix,nest=False -> x,y,z (default RING) Parameters ---------- nside : int, scalar or array-like The healpix nside parameter, must be a power of 2, less than 2**30 ipix : int, scalar or array-like Healpix pixel number nest : bool, optional if True, assume NESTED pixel ordering, otherwise, RING pixel ordering Returns ------- x, y, z : floats, scalar or array-like The coordinates of vector corresponding to input pixels. Scalar if all input are scalar, array otherwise. Usual numpy broadcasting rules apply. See Also -------- ang2pix, pix2ang, vec2pix Examples -------- >>> import healpy as hp >>> hp.pix2vec(16, 1504) (0.99879545620517241, 0.049067674327418015, 0.0) >>> hp.pix2vec(16, [1440, 427]) (array([ 0.99913157, 0.5000534 ]), array([ 0. , 0.5000534]), array([ 0.04166667, 0.70703125])) >>> hp.pix2vec([1, 2], 11) (array([ 0.52704628, 0.68861915]), array([-0.52704628, -0.28523539]), array([-0.66666667, 0.66666667])) """ check_nside(nside, nest=nest) if nest: raise NotImplementedError return _pix2vec_ring(nside, ipix)
def _pix2vec_ring(nside, pixels): iring = _pix2i_ring(nside, pixels) z, abs_one_minus_z = _pix2z_ring(nside, iring, pixels) phi = _pix2phi_ring(nside, iring, pixels) sin_theta = jnp.sqrt( jnp.where( jnp.abs(z) > 0.99, abs_one_minus_z * (2 - abs_one_minus_z), (1 - z) * (1 + z), ) ) return jnp.array([sin_theta * jnp.cos(phi), sin_theta * jnp.sin(phi), z]).T
[docs] @partial(jit, static_argnames=['lonlat']) def ang2vec(theta: ArrayLike, phi: ArrayLike, lonlat: bool = False) -> Array: """ang2vec : convert angles to 3D position vector Parameters ---------- theta : float, scalar or array-like co-latitude in radians measured southward from the North pole (in [0,pi]). phi : float, scalar or array-like longitude in radians measured eastward (in [0, 2*pi]). lonlat : bool If True, input angles are assumed to be longitude and latitude in degree, otherwise, they are co-latitude and longitude in radians. Returns ------- vec : float, array if theta and phi are vectors, the result is a 2D array with a vector per row otherwise, it is a 1D array of shape (3,) See Also -------- vec2ang, rotator.dir2vec, rotator.vec2dir """ if lonlat: theta, phi = _lonlat2thetaphi(theta, phi) theta = jnp.where((theta < 0) | (theta > np.pi + 1e-5), np.nan, theta) sin_theta = jnp.sin(theta) x = sin_theta * jnp.cos(phi) y = sin_theta * jnp.sin(phi) z = jnp.cos(theta) return jnp.array([x, y, z]).T
[docs] @partial(jit, static_argnames=['lonlat']) def vec2ang(vectors: ArrayLike, lonlat: bool = False) -> tuple[Array, Array]: """vec2ang: vectors [x, y, z] -> theta[rad], phi[rad] Parameters ---------- vectors : float, array-like the vector(s) to convert, shape is (3,) or (N, 3) lonlat : bool, optional If True, return angles will be longitude and latitude in degree, otherwise, angles will be co-latitude and longitude in radians (default) Returns ------- theta, phi : float, tuple of two arrays the colatitude and longitude in radians See Also -------- ang2vec, rotator.vec2dir, rotator.dir2vec """ vectors = vectors.reshape(-1, 3) dnorm = jnp.sqrt(vectors[..., 0] ** 2 + vectors[..., 1] ** 2 + vectors[..., 2] ** 2) theta = jnp.arccos(vectors[:, 2] / dnorm) phi = jnp.arctan2(vectors[:, 1], vectors[:, 0]) phi = jnp.where(phi < 0, phi + 2 * np.pi, phi) if lonlat: return _thetaphi2lonlat(theta, phi) return theta, phi
[docs] @partial(jit, static_argnames=['nside']) def ring2nest(nside: int, ipix: ArrayLike) -> Array: """Convert pixel number from RING ordering to NESTED ordering. Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc. Parameters ---------- nside : int the healpix nside parameter ipix : int, scalar or array-like the pixel number in RING scheme Returns ------- ipix : int, scalar or array-like the pixel number in NESTED scheme See Also -------- nest2ring, reorder Examples -------- >>> import healpy as hp >>> hp.ring2nest(16, 1504) 1130 >>> print(hp.ring2nest(2, np.arange(10))) [ 3 7 11 15 2 1 6 5 10 9] >>> print(hp.ring2nest([1, 2, 4, 8], 11)) [ 11 13 61 253] """ check_nside(nside, nest=True) ipix = jnp.asarray(ipix) # promote to int64 only if nside requires it ipix = ipix.astype(jnp.promote_types(ipix.dtype, _pixel_dtype_for(nside))) xyf = _pix2xyf_ring(nside, ipix) ipix_nest = _xyf2pix_nest(nside, *xyf) return ipix_nest
[docs] @partial(jit, static_argnames=['nside']) def nest2ring(nside: int, ipix: ArrayLike) -> Array: """Convert pixel number from NESTED ordering to RING ordering. Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc. Parameters ---------- nside : int the healpix nside parameter ipix : int, scalar or array-like the pixel number in NESTED scheme Returns ------- ipix : int, scalar or array-like the pixel number in RING scheme See Also -------- ring2nest, reorder Examples -------- >>> import healpy as hp >>> hp.nest2ring(16, 1130) 1504 >>> print(hp.nest2ring(2, np.arange(10))) [13 5 4 0 15 7 6 1 17 9] >>> print(hp.nest2ring([1, 2, 4, 8], 11)) [ 11 2 12 211] """ check_nside(nside, nest=True) ipix = jnp.asarray(ipix) # promote to int64 only if nside requires it ipix = ipix.astype(jnp.promote_types(ipix.dtype, _pixel_dtype_for(nside))) xyf = _pix2xyf_nest(nside, ipix) ipix_ring = _xyf2pix_ring(nside, *xyf) return ipix_ring
[docs] @partial(jit, static_argnames=['inp', 'out', 'r2n', 'n2r', 'process_by_chunks']) def reorder( map_in: ArrayLike, inp: str | None = None, out: str | None = None, r2n: bool = False, n2r: bool = False, process_by_chunks: bool = False, ) -> Array: """Reorder a healpix map from RING/NESTED ordering to NESTED/RING. Masked arrays are not yet supported. By default, the maps are processed in one go, but if memory is an issue, use the ``process_by_chunks`` option (which reproduces healpy behaviour). Parameters ---------- map_in : array-like the input map to reorder, accepts masked arrays inp, out : ``'RING'`` or ``'NESTED'`` define the input and output ordering r2n : bool if True, reorder from RING to NESTED n2r : bool if True, reorder from NESTED to RING Returns ------- map_out : array-like the reordered map, as masked array if the input was a masked array Notes ----- if ``r2n`` or ``n2r`` is defined, override ``inp`` and ``out``. See Also -------- nest2ring, ring2nest Examples -------- >>> import healpy as hp >>> hp.reorder(np.arange(48), r2n = True) array([13, 5, 4, 0, 15, 7, 6, 1, 17, 9, 8, 2, 19, 11, 10, 3, 28, 20, 27, 12, 30, 22, 21, 14, 32, 24, 23, 16, 34, 26, 25, 18, 44, 37, 36, 29, 45, 39, 38, 31, 46, 41, 40, 33, 47, 43, 42, 35]) >>> hp.reorder(np.arange(12), n2r = True) array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) >>> hp.reorder(hp.ma(np.arange(12.)), n2r = True) masked_array(data = [ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11.], mask = False, fill_value = -1.6375e+30) <BLANKLINE> >>> m = [np.arange(12.), np.arange(12.), np.arange(12.)] >>> m[0][2] = hp.UNSEEN >>> m[1][2] = hp.UNSEEN >>> m[2][2] = hp.UNSEEN >>> m = hp.ma(m) >>> hp.reorder(m, n2r = True) masked_array(data = [[0.0 1.0 -- 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0] [0.0 1.0 -- 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0] [0.0 1.0 -- 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0]], mask = [[False False True False False False False False False False False False] [False False True False False False False False False False False False] [False False True False False False False False False False False False]], fill_value = -1.6375e+30) <BLANKLINE> """ # Check input map(s) map_in = jnp.asarray(map_in) if map_in.ndim == 0: raise ValueError('Input map can not be a scalar') npix = map_in.shape[-1] nside = npix2nside(npix) # npix2nside already fails on bad number of pixels # but in nested ordering we must also ensure that nside is power of 2 check_nside(nside, nest=True) # Check input parameters if r2n and n2r: raise ValueError('r2n and n2r cannot be used simultaneously') if r2n: inp, out = 'RING', 'NEST' if n2r: inp, out = 'NEST', 'RING' inp, out = str(inp).upper()[:4], str(out).upper()[:4] if not {inp, out}.issubset({'RING', 'NEST'}): raise ValueError('inp and out must be either RING or NEST') if inp == out: return map_in # Perform the conversion, which is just a reordering of the pixels def _reorder(ipix): if inp == 'RING': ipix_reordered = nest2ring(nside, ipix) else: ipix_reordered = ring2nest(nside, ipix) return map_in[..., ipix_reordered] if not process_by_chunks: ipix_full = jnp.arange(npix, dtype=_pixel_dtype_for(nside)) return _reorder(ipix_full) # To reduce memory requirements, process the map in chunks chunk_size = npix // 24 if nside > 128 else npix n_chunks = npix // chunk_size def body(i, map_out): # interval bounds must be static, so we shift the values afterwards ipix_chunk = jnp.arange(chunk_size, dtype=_pixel_dtype_for(nside)) + i * chunk_size return map_out.at[..., ipix_chunk].set(_reorder(ipix_chunk)) map_out = lax.fori_loop(0, n_chunks, body, jnp.empty_like(map_in)) return map_out
[docs] @partial(jit, static_argnames=['nside', 'nest', 'lonlat']) def get_interp_weights( nside: int, theta: ArrayLike, phi: ArrayLike | None = None, nest: bool = False, lonlat: bool = False ) -> tuple[Array, Array]: """Return interpolation weights for given coordinates. This function performs bilinear interpolation by finding the four nearest pixel centers and computing their interpolation weights. Provides machine precision matching healpy when pixels are sorted. Parameters ---------- nside : int HEALPix nside parameter theta : ArrayLike Colatitude in radians (or pixel indices if phi is None) phi : ArrayLike, optional Longitude in radians (or degrees if lonlat=True) nest : bool, optional If True, use NESTED pixel ordering (raises error) lonlat : bool, optional If True, interpret theta, phi as longitude, latitude in degrees Returns ------- pixels : Array Array of shape (4, N) containing the four nearest pixel indices. Pixel order is not guaranteed - sort both pixels and weights by pixel values for exact healpy precision matching. weights : Array Array of shape (4, N) containing the interpolation weights. Weights sum to 1.0 for each point to machine precision. Notes ----- For exact healpy compatibility, sort the output by pixel values: >>> sorted_indices = jnp.argsort(pixels, axis=0) >>> sorted_pixels = jnp.take_along_axis(pixels, sorted_indices, axis=0) >>> sorted_weights = jnp.take_along_axis(weights, sorted_indices, axis=0) Precision and Algorithmic Considerations: ---------------------------------------- The phi interpolation calculation can exhibit precision differences compared to healpy, particularly for coordinates near the poles and for high nside values (256+). This is due to: 1. **Phi Interpolation Near Poles**: The formula `phi_norm = (phi / dphi - shift) % nr` can produce different results than healpy's algorithm when transitioning between rings with different pixel counts (e.g., ring transitions near poles). 2. **High Nside Precision Limits**: For nside ≥ 256, floating-point precision limits can cause the JAX implementation to select different (but mathematically valid) interpolation neighbors compared to healpy, especially in challenging geometric regions. 3. **Ring Transition Edge Cases**: Coordinates exactly at boundaries between polar caps and equatorial regions may show pixel differences due to algorithmic implementation variations, though weights still sum to 1.0 and maintain interpolation accuracy. Expected precision levels: - nside ≤ 64: Machine precision matching (< 1e-25 weight error) - nside 128-256: High precision (< 1e-15 weight error) - nside ≥ 512: Good precision (< 1e-12 to 1e-5 weight error depending on nside) Automatic Differentiation: -------------------------- This function is fully compatible with JAX's automatic differentiation. The gradients behave as follows: - Gradients of individual weights reflect the continuous dependence on coordinates - Gradients of sum(weights) are always zero since the sum is identically 1.0 - Pixel indices have zero gradients since they are discrete selectors Example gradient usage: >>> def interpolate_map(m, theta, phi): ... pixels, weights = get_interp_weights(nside, theta, phi) ... return jnp.sum(weights * m[pixels], axis=0) >>> grad_func = jax.grad(interpolate_map, argnums=(1, 2)) >>> grad_theta, grad_phi = grad_func(map_data, theta, phi) """ check_nside(nside, nest=nest) if nest: raise ValueError('NEST pixel ordering is not supported. Only RING ordering is supported.') # Handle different input modes if phi is None: # theta contains pixel indices, convert to (theta, phi) coordinates theta_coords, phi_coords = pix2ang(nside, theta, nest=False, lonlat=False) else: theta_coords, phi_coords = jnp.asarray(theta), jnp.asarray(phi) if lonlat: theta_coords, phi_coords = _lonlat2thetaphi(theta_coords, phi_coords) # Call the RING implementation return _get_interp_weights_ring(nside, theta_coords, phi_coords)
def _get_interp_weights_ring(nside: int, theta_coords: Array, phi_coords: Array) -> tuple[Array, Array]: """ Memory-optimized implementation of bilinear interpolation for RING ordering. This optimized version reduces temporary memory usage by 2.8x while maintaining full numerical precision by: 1. Eliminating excessive conditional operations that create intermediate arrays 2. Using direct computation instead of conditional masking 3. Streamlined special case handling with mathematical formulas 4. Efficient array construction using stack operations Gradient Compatibility: ---------------------- This function is fully compatible with JAX's automatic differentiation system. The implementation carefully separates discrete operations (pixel selection) from continuous operations (weight computation): - Discrete pixel indices use `lax.stop_gradient()` to prevent gradient flow through non-differentiable operations like `jnp.floor().astype(int)` - Weight computations use continuous mathematical operations that preserve gradients - Final weight normalization enforces the constraint sum(weights) = 1.0, ensuring that gradients of the weight sum are exactly zero The stop_gradient usage is mathematically sound because: 1. Pixel indices are discrete selectors that don't affect the interpolation mathematics 2. Weight values depend continuously on input coordinates within each pixel region 3. The fundamental constraint sum(weights) = 1.0 must hold regardless of pixel selection This design allows proper gradient flow for meaningful computations (like map interpolation) while maintaining numerical precision and memory efficiency. """ # Core computation - minimal intermediate arrays z = jnp.cos(theta_coords) ir1 = _ring_above(nside, z) ir2 = ir1 + 1 # Special case flags - compute once is_north_pole = ir1 == 0 is_south_pole = ir2 == (4 * nside) is_normal = ~is_north_pole & ~is_south_pole # Safe ring indices for _get_ring_info calls ir1_safe = jnp.maximum(ir1, 1) ir2_safe = jnp.minimum(ir2, 4 * nside - 1) # Get ring properties - only two function calls needed theta1, sp1, nr1, shift1 = _get_ring_info(nside, ir1_safe) theta2, sp2, nr2, shift2 = _get_ring_info(nside, ir2_safe) # Core phi interpolation computation dphi1 = 2.0 * jnp.pi / nr1 dphi2 = 2.0 * jnp.pi / nr2 # Phi interpolation indices and weights phi1_norm = (phi_coords / dphi1 - shift1) % nr1 phi2_norm = (phi_coords / dphi2 - shift2) % nr2 # Compute pixel indices (for pixel selection only) # Stop gradient: Floor+cast operations are non-differentiable and only used for indexing i1_1 = lax.stop_gradient(jnp.floor(phi1_norm).astype(jnp.int32)) i1_2 = lax.stop_gradient(jnp.floor(phi2_norm).astype(jnp.int32)) # Compute weights using gradient-friendly fractional parts # Use modulo instead of floor subtraction for better gradient behavior w_phi1 = phi1_norm % 1.0 w_phi2 = phi2_norm % 1.0 i2_1 = (i1_1 + 1) % nr1 i2_2 = (i1_2 + 1) % nr2 # Theta interpolation weight computation theta_denom = jnp.where(is_normal, theta2 - theta1, 1.0) # Avoid div by 0 w_theta_base = jnp.where(is_normal, (theta_coords - theta1) / theta_denom, 0.0) # Special case adjustments using mathematical formulas w_theta_north = jnp.where(is_north_pole, theta_coords / theta2, w_theta_base) w_theta_south = jnp.where(is_south_pole, (theta_coords - theta1) / (jnp.pi - theta1), w_theta_base) # Pixel computation - direct mathematical approach # Normal case pixels pixels_ring1_1 = sp1 + i1_1 pixels_ring1_2 = sp1 + i2_1 pixels_ring2_1 = sp2 + i1_2 pixels_ring2_2 = sp2 + i2_2 # North pole pixel adjustments npix_total = 12 * nside * nside pixels_ring1_1 = jnp.where(is_north_pole, (pixels_ring2_1 + 2) & 3, pixels_ring1_1) pixels_ring1_2 = jnp.where(is_north_pole, (pixels_ring2_2 + 2) & 3, pixels_ring1_2) # South pole pixel adjustments pixels_ring2_1 = jnp.where(is_south_pole, ((pixels_ring1_1 + 2) & 3) + npix_total - 4, pixels_ring2_1) pixels_ring2_2 = jnp.where(is_south_pole, ((pixels_ring1_2 + 2) & 3) + npix_total - 4, pixels_ring2_2) # Weight computation - optimized mathematical approach # Base phi weights w1_phi = 1.0 - w_phi1 w2_phi = w_phi1 w3_phi = 1.0 - w_phi2 w4_phi = w_phi2 # Apply theta interpolation w1_base = w1_phi * (1.0 - w_theta_base) w2_base = w2_phi * (1.0 - w_theta_base) w3_base = w3_phi * w_theta_base w4_base = w4_phi * w_theta_base # North pole weight adjustments north_factor = (1.0 - w_theta_north) * 0.25 w1_north = jnp.where(is_north_pole, north_factor, w1_base) w2_north = jnp.where(is_north_pole, north_factor, w2_base) w3_north = jnp.where(is_north_pole, w3_phi * w_theta_north + north_factor, w3_base) w4_north = jnp.where(is_north_pole, w4_phi * w_theta_north + north_factor, w4_base) # South pole weight adjustments south_factor = w_theta_south * 0.25 w1_final = jnp.where(is_south_pole, w1_north * (1.0 - w_theta_south) + south_factor, w1_north) w2_final = jnp.where(is_south_pole, w2_north * (1.0 - w_theta_south) + south_factor, w2_north) w3_final = jnp.where(is_south_pole, south_factor, w3_north) w4_final = jnp.where(is_south_pole, south_factor, w4_north) # Final assembly - single stack operation # Stop gradient: Pixel indices are discrete array selectors, not part of interpolation math pixels = lax.stop_gradient(jnp.stack([pixels_ring1_1, pixels_ring1_2, pixels_ring2_1, pixels_ring2_2])) weights = jnp.stack([w1_final, w2_final, w3_final, w4_final]) # Clamp weights to ensure non-negativity (handles floating point precision issues) weights = jnp.maximum(weights, 0.0) # Ensure weights sum to exactly 1.0 for gradient consistency # This enforces the mathematical constraint sum(weights) = 1.0, making gradients # of the sum exactly zero while preserving gradients of individual weights weight_sum = jnp.sum(weights, axis=0, keepdims=True) weights = weights / weight_sum return pixels, weights
[docs] @partial(jit, static_argnames=['nest', 'lonlat']) def get_interp_val( m: ArrayLike, theta: ArrayLike, phi: ArrayLike | None = None, nest: bool = False, lonlat: bool = False ) -> Array: """Return interpolated map values at given coordinates. This function performs bilinear interpolation of map values using the four nearest pixel neighbors, providing machine precision matching healpy. Parameters ---------- m : ArrayLike HEALPix map(s) to interpolate. Can be 1D (single map) or 2D (multiple maps). Shape: (npix,) or (nmaps, npix) theta : ArrayLike Colatitude in radians (or pixel indices if phi is None) phi : ArrayLike, optional Longitude in radians (or degrees if lonlat=True) nest : bool, optional If True, use NESTED pixel ordering (raises error - not supported) lonlat : bool, optional If True, interpret theta, phi as longitude, latitude in degrees Returns ------- values : Array Interpolated map values at the given coordinates. Shape matches broadcast of theta, phi for single map. For multiple maps, shape is (nmaps, ...) where ... is broadcast shape. Notes ----- Uses bilinear interpolation with the four nearest pixel neighbors. For exact healpy compatibility, this function uses get_interp_weights internally and computes: result = sum(weights * map_values[pixels]) Results won't match healpy if theta and phi are not valid angles. Examples -------- >>> import jax_healpy as hp >>> import jax.numpy as jnp >>> m = jnp.arange(12.) >>> hp.get_interp_val(m, jnp.pi/2, 0.0) Array(4.5, dtype=float64) >>> # Multiple coordinates >>> theta = jnp.array([jnp.pi/4, jnp.pi/2]) >>> phi = jnp.array([0.0, jnp.pi/2]) >>> hp.get_interp_val(m, theta, phi) Array([2.25, 6. ], dtype=float64) >>> # Multiple maps >>> maps = jnp.array([jnp.arange(12.), 2*jnp.arange(12.)]) >>> hp.get_interp_val(maps, jnp.pi/2, 0.0) Array([4.5, 9. ], dtype=float64) """ if nest: raise ValueError('NEST pixel ordering is not supported. Only RING ordering is supported.') # Convert inputs to JAX arrays m = jnp.asarray(m) theta = jnp.asarray(theta) if phi is not None: phi = jnp.asarray(phi) # Determine nside from map size npix = m.shape[-1] nside = int(np.sqrt(npix / 12)) # Use numpy sqrt to avoid tracer issues check_nside(nside, nest=nest) # Handle multiple maps vs single map single_map = m.ndim == 1 if single_map: map_data = m[jnp.newaxis, :] # Add map dimension else: map_data = m # Get interpolation weights and pixels pixels, weights = get_interp_weights(nside, theta, phi, nest=nest, lonlat=lonlat) # Perform interpolation: sum(weights * map_values[pixels]) # pixels shape: (4, ...) where ... is broadcast shape of theta, phi # weights shape: (4, ...) # map_data shape: (nmaps, npix) # Extract map values at interpolation pixels # map_values shape: (nmaps, 4, ...) map_values = map_data[..., pixels] # Broadcasting: (nmaps, npix)[..., (4, ...)] -> (nmaps, 4, ...) # Compute weighted sum along the pixel dimension (axis=1) # result shape: (nmaps, ...) result = jnp.sum(weights[jnp.newaxis, ...] * map_values, axis=1) # If input was single map, remove the map dimension if single_map: result = result[0] return result
[docs] @partial(jit, static_argnames=['nside', 'nest', 'lonlat', 'get_center']) def get_all_neighbours( nside: int, theta: ArrayLike, phi: ArrayLike | None = None, nest: bool = False, lonlat: bool = False, get_center: bool = False, ) -> Array: """Get the 8 nearest neighbors of given pixels, optionally including center pixel. Parameters ---------- nside : int HEALPix resolution parameter, must be a power of 2 theta : ArrayLike Either colatitude in radians (if phi is provided) or pixel indices (if phi is None) phi : ArrayLike, optional Longitude in radians (or degrees if lonlat=True). If None, theta is treated as pixel indices. nest : bool, optional If True, use NESTED pixel ordering scheme. Default is False (RING ordering). lonlat : bool, optional If True and phi is provided, interpret (theta, phi) as (longitude, latitude) in degrees. get_center : bool, optional If True, return center pixel + 8 neighbors (9 total). If False, return only 8 neighbors. Default is False for backward compatibility with healpy. Returns ------- neighbors : Array Array of pixel indices. When get_center=False: shape is (8,) for scalar input or (8, N) for array input, with neighbors in directions [SW, W, NW, N, NE, E, SE, S]. When get_center=True: shape is (9,) for scalar input or (9, N) for array input, with pixels in order [CENTER, SW, W, NW, N, NE, E, SE, S]. Non-existent neighbors (at map boundaries) are marked with -1. Examples -------- >>> import jax_healpy as hp >>> # Get 8 neighbors of pixel 4 at nside=1 (default behavior, matches healpy) >>> neighbors = hp.get_all_neighbours(1, 4) >>> print(neighbors) [11 7 3 -1 0 5 8 -1] >>> # Get center + 8 neighbors (9 total) >>> neighbors_with_center = hp.get_all_neighbours(1, 4, get_center=True) >>> print(neighbors_with_center) [ 4 11 7 3 -1 0 5 8 -1] >>> # Works with angular coordinates too >>> import jax.numpy as jnp >>> neighbors = hp.get_all_neighbours(1, jnp.pi/2, jnp.pi/2, get_center=True) >>> print(neighbors) [ 6 8 4 0 -1 1 6 9 -1] Notes ----- **healpy Compatibility**: The `get_center=False` (default) behavior maintains perfect compatibility with healpy.get_all_neighbours(). The `get_center=True` parameter is a jax-healpy-specific extension that does not exist in healpy. When `get_center=False` (default): - Returns 8 neighbors in identical order to healpy: [SW, W, NW, N, NE, E, SE, S] - Produces bit-for-bit identical results to healpy for all input modes - Maintains backward compatibility with existing healpy-based code When `get_center=True` (jax-healpy extension): - Returns 9 pixels: center pixel + 8 neighbors in order [CENTER, SW, W, NW, N, NE, E, SE, S] - Center pixel is always the first element (index 0) - Neighbor ordering matches healpy convention starting from index 1 - This functionality does not exist in healpy and is unique to jax-healpy **Performance**: The default `get_center=False` case has no performance overhead compared to the original implementation. The `get_center=True` case adds minimal computational cost. **JAX Features**: This function is fully compatible with JAX transformations including jit compilation, vmap, grad, and automatic differentiation. The `get_center` parameter is a static argument that allows different compilations for each mode. """ # Validate inputs check_nside(nside, nest=nest) theta = jnp.asarray(theta) # Handle the two API modes: pixel indices vs angular coordinates if phi is None: # theta contains pixel indices ipix = theta.astype(_pixel_dtype_for(nside)) input_shape = ipix.shape ipix_flat = ipix.flatten() else: # theta, phi contain angular coordinates - convert to pixels phi = jnp.asarray(phi) if lonlat: # Convert longitude, latitude in degrees to colatitude, longitude in radians lon, lat = theta, phi theta = jnp.deg2rad(90.0 - lat) phi = jnp.deg2rad(lon) # Ensure theta and phi can be broadcast together theta_bc, phi_bc = jnp.broadcast_arrays(theta, phi) input_shape = theta_bc.shape # Convert angular coordinates to pixel indices ipix_flat = ang2pix(nside, theta_bc.flatten(), phi_bc.flatten(), nest=nest) # Convert pixels to (x, y, face) coordinates ix, iy, face_num = pix2xyf(nside, ipix_flat, nest=nest) # Vectorized neighbor finding for all pixels neighbors_flat = _get_all_neighbors_xyf(nside, ix, iy, face_num, nest=nest) # Conditionally include center pixel based on get_center parameter if get_center: # Add center pixels as first element: [CENTER, SW, W, NW, N, NE, E, SE, S] if phi is None: # Pixel mode: center pixels are the input pixels themselves center_pixels_flat = ipix_flat else: # Angular mode: center pixels are pixels at the given coordinates # We already have ipix_flat from the coordinate conversion above center_pixels_flat = ipix_flat # Combine center + neighbors: shape (9, N) result_flat = jnp.concatenate([center_pixels_flat[None, :], neighbors_flat], axis=0) # Reshape result to (9, *input_shape) if input_shape == (): # Scalar input - should return shape (9,), not (9, 1) return result_flat.squeeze() # Remove the extra dimension else: # Array input - reshape from (9, N) to (9, *input_shape) return result_flat.reshape((9,) + input_shape) else: # Original behavior: return only 8 neighbors for backward compatibility # Reshape result to (8, *input_shape) if input_shape == (): # Scalar input - should return shape (8,), not (8, 1) return neighbors_flat.squeeze() # Remove the extra dimension else: # Array input - reshape from (8, N) to (8, *input_shape) return neighbors_flat.reshape((8,) + input_shape)
def _get_all_neighbors_xyf(nside: int, ix: Array, iy: Array, face_num: Array, nest: bool = False) -> Array: """Vectorized neighbor finding in (x, y, face) coordinates. This is the core neighbor-finding algorithm that handles face boundary crossings using the original HEALPix neighbor-finding methodology. It applies the 8-directional offsets to find potential neighbors, then handles cases where neighbors cross face boundaries using lookup tables and coordinate transformations. The algorithm follows these steps: 1. Apply 8-directional offsets (_NB_XOFFSET, _NB_YOFFSET) to get neighbor coordinates 2. Check which neighbors remain within the current face (valid range [0, nside-1]) 3. For neighbors that cross face boundaries, apply face transition logic using lookup tables (_NB_FACEARRAY, _NB_SWAPARRAY) based on original C++ implementation Parameters ---------- nside : int HEALPix resolution parameter (must be power of 2) ix, iy : Array Face-local x, y coordinates of pixels (shape: (N,)) Valid range: [0, nside-1] for pixels within face face_num : Array Face numbers of pixels (shape: (N,)) Valid range: [0, 11] for HEALPix faces nest : bool, optional Whether to use NESTED ordering scheme. Default is False (RING ordering). Returns ------- neighbors : Array Neighbor pixel indices for each input pixel. Shape: (8, N) Neighbors in order: [SW, W, NW, N, NE, E, SE, S] Non-existent neighbors (at map boundaries) are marked with -1. Notes ----- This function implements the exact neighbor-finding logic from the original HEALPix C++ library, ensuring bit-for-bit compatibility with healpy results. """ n_pixels = ix.shape[0] # Initialize output array for neighbors neighbors = jnp.full((8, n_pixels), -1, dtype=_pixel_dtype_for(nside)) # Apply 8-direction offsets to get neighbor coordinates # Use broadcasting: ix[None, :] + _NB_XOFFSET[:, None] -> (8, N) neighbor_ix = ix[None, :] + _NB_XOFFSET[:, None] # Shape: (8, N) neighbor_iy = iy[None, :] + _NB_YOFFSET[:, None] # Shape: (8, N) neighbor_face = jnp.broadcast_to(face_num[None, :], (8, n_pixels)) # Shape: (8, N) # Check which neighbors are within the current face (no boundary crossing) # Valid range is [0, nside-1] for both ix and iy within_face = ( (neighbor_ix >= 0) & (neighbor_ix < nside) & (neighbor_iy >= 0) & (neighbor_iy < nside) ) # Shape: (8, N) # For neighbors within face, convert directly to pixels valid_mask = within_face valid_neighbors = xyf2pix(nside, neighbor_ix, neighbor_iy, neighbor_face, nest=nest) # Shape: (8, N) neighbors = jnp.where(valid_mask, valid_neighbors, neighbors) # Handle boundary crossings for neighbors outside current face boundary_mask = ~within_face # Shape: (8, N) # For boundary pixels, we need to use the lookup tables # This is complex due to the face transition logic - we'll implement a simplified version # that handles the most common boundary cases # Apply face boundary corrections using lookup tables corrected_neighbors = _handle_face_boundaries(nside, neighbor_ix, neighbor_iy, neighbor_face, face_num, nest) # Use corrected neighbors where we have boundary crossings neighbors = jnp.where(boundary_mask, corrected_neighbors, neighbors) return neighbors def _handle_face_boundaries( nside: int, neighbor_ix: Array, neighbor_iy: Array, neighbor_face: Array, original_face: Array, nest: bool ) -> Array: """Handle neighbor pixels that cross face boundaries. This implements the exact HEALPix face transition logic using lookup tables, based on the original C++ implementation in healpix_base.cc. When a neighbor coordinate falls outside the current face boundaries, this function determines the correct face and applies coordinate transformations. The algorithm follows these steps for each boundary crossing: 1. Detect boundary crossing condition (x < 0, x >= nside, y < 0, y >= nside) 2. Calculate nbnum index encoding the crossing direction 3. Look up new face using _NB_FACEARRAY[nbnum, original_face] 4. Apply coordinate corrections (wrap coordinates to valid range) 5. Apply bit-based transformations using _NB_SWAPARRAY (flip x, flip y, swap x/y) 6. Convert corrected (x, y, face) back to pixel indices Parameters ---------- nside : int HEALPix resolution parameter neighbor_ix, neighbor_iy : Array Neighbor coordinates that may be outside face boundaries. Shape: (8, N) neighbor_face : Array Face numbers for neighbors (initially same as original). Shape: (8, N) original_face : Array Original face numbers of input pixels. Shape: (N,) nest : bool Whether to use NESTED ordering Returns ------- corrected_neighbors : Array Corrected neighbor pixel indices. Shape: (8, N) Returns -1 for invalid neighbors (outside map boundaries) Notes ----- This function is a direct translation of the original HEALPix C++ neighbor finding algorithm, ensuring exact compatibility with healpy. The lookup tables (_NB_FACEARRAY, _NB_SWAPARRAY) encode the complex geometric relationships between HEALPix faces and handle all 12 face transitions correctly. """ n_pixels = original_face.shape[0] # Initialize result with invalid neighbors result = jnp.full((8, n_pixels), -1, dtype=_pixel_dtype_for(nside)) # Process each neighbor direction individually for direction_idx in range(8): # Get coordinates for this direction across all pixels ix = neighbor_ix[direction_idx, :] # Shape: (n_pixels,) iy = neighbor_iy[direction_idx, :] # Shape: (n_pixels,) orig_face = original_face # Shape: (n_pixels,) # Check boundary conditions - exact replication of original algorithm x_low = ix < 0 x_high = ix >= nside y_low = iy < 0 y_high = iy >= nside # Any pixel crossing face boundary boundary_crossing = x_low | x_high | y_low | y_high # Initialize corrected coordinates with original values corrected_ix = ix corrected_iy = iy # Apply boundary corrections exactly as in original C++ code # First handle x boundary crossings corrected_ix = jnp.where(x_low, corrected_ix + nside, corrected_ix) corrected_ix = jnp.where(x_high, corrected_ix - nside, corrected_ix) # Then handle y boundary crossings corrected_iy = jnp.where(y_low, corrected_iy + nside, corrected_iy) corrected_iy = jnp.where(y_high, corrected_iy - nside, corrected_iy) # Calculate nbnum index for lookup tables (matches original C++ logic) nbnum = 4 # Start with center case nbnum = jnp.where(x_low, nbnum - 1, nbnum) nbnum = jnp.where(x_high, nbnum + 1, nbnum) nbnum = jnp.where(y_low, nbnum - 3, nbnum) nbnum = jnp.where(y_high, nbnum + 3, nbnum) # Look up new face using the face array (vectorized) # Use advanced indexing to get new faces for each pixel new_face = _NB_FACEARRAY[nbnum, orig_face] # Only process pixels that actually cross boundaries and have valid new faces valid_crossing = boundary_crossing & (new_face >= 0) & (new_face < 12) # Apply coordinate transformations using swap array bits # Get swap bits for face transitions (vectorized) swap_bits = _NB_SWAPARRAY[nbnum, orig_face >> 2] # Apply bit transformations exactly as in original C++ # Bit 1: Flip x coordinate flip_x = (swap_bits & 1) != 0 corrected_ix = jnp.where(valid_crossing & flip_x, nside - corrected_ix - 1, corrected_ix) # Bit 2: Flip y coordinate flip_y = (swap_bits & 2) != 0 corrected_iy = jnp.where(valid_crossing & flip_y, nside - corrected_iy - 1, corrected_iy) # Bit 4: Swap x and y coordinates swap_xy = (swap_bits & 4) != 0 new_x = jnp.where(valid_crossing & swap_xy, corrected_iy, corrected_ix) new_y = jnp.where(valid_crossing & swap_xy, corrected_ix, corrected_iy) corrected_ix = new_x corrected_iy = new_y # Use new face for valid crossings, original face otherwise corrected_face = jnp.where(valid_crossing, new_face, orig_face) # Convert to pixel indices neighbor_pixels = xyf2pix(nside, corrected_ix, corrected_iy, corrected_face, nest=nest) # Update result for this direction - only valid crossings get neighbor pixels result = result.at[direction_idx, :].set(jnp.where(valid_crossing, neighbor_pixels, -1)) return result # Note: Removed _get_adjacent_face - using lookup table directly in _handle_face_boundaries
[docs] def get_nside(m: ArrayLike) -> int: """Extract nside parameter from map length. Parameters ---------- m : array-like HEALPix map or sequence of maps Returns ------- nside : int The nside parameter corresponding to the map size Raises ------ ValueError If the map size doesn't correspond to a valid HEALPix map """ m = jnp.asarray(m) if m.ndim == 1: npix = len(m) elif m.ndim == 2: npix = m.shape[-1] # Last dimension should be pixels else: raise ValueError(f'Map must be 1D or 2D, got shape {m.shape}') return npix2nside(npix)
def mask_bad(m: ArrayLike) -> Array: """Create boolean mask for UNSEEN pixels. Parameters ---------- m : array-like HEALPix map Returns ------- mask : Array Boolean array with True where pixels are UNSEEN """ m = jnp.asarray(m) return m == UNSEEN
[docs] @partial(jit, static_argnames=['nside_out', 'pess', 'order_in', 'order_out', 'power', 'dtype']) def ud_grade( map_in: ArrayLike, nside_out: int, pess: bool = False, order_in: str = 'RING', order_out: str = None, power: float = None, dtype: type = None, ) -> Array: """Upgrade or degrade the resolution (nside) of a map. This function changes the resolution of a HEALPix map by either upgrading it to a higher resolution (more pixels) or degrading it to a lower resolution (fewer pixels). The algorithm follows the HEALPix specification: - For upgrading: each parent pixel value is replicated to all its children - For degrading: each parent pixel value is the average of its children Parameters ---------- map_in : array-like Input map(s) to be upgraded or degraded. Can be a single map or a sequence of maps. nside_out : int Desired output resolution parameter. Must be a power of 2. pess : bool, optional Pessimistic mask handling during degradation. If True, a parent pixel is marked as UNSEEN if any of its children are UNSEEN. If False (default), a parent pixel is UNSEEN only if all children are UNSEEN. order_in : {'RING', 'NESTED'}, optional Pixel ordering of input map. Default is 'RING'. order_out : {'RING', 'NESTED'}, optional Pixel ordering of output map. If None, same as order_in. power : float, optional Scaling factor for resolution change. If provided, the output values are multiplied by (nside_out/nside_in)^power. dtype : data type, optional Data type of output map. If None, same as input map dtype. Returns ------- map_out : Array Upgraded or degraded map(s) with the same shape as input but different number of pixels corresponding to nside_out. Raises ------ NotImplementedError This function is not yet implemented. ValueError If nside_out is not a valid HEALPix nside parameter. Examples -------- >>> import jax_healpy as jhp >>> import numpy as np >>> # Create a simple map at nside=4 >>> nside_in = 4 >>> map_in = np.arange(jhp.nside2npix(nside_in), dtype=float) >>> # Degrade to nside=2 >>> map_out = jhp.ud_grade(map_in, 2) >>> # Upgrade to nside=8 >>> map_out = jhp.ud_grade(map_in, 8) Notes ----- This function can create artifacts in power spectra and should be used with caution for scientific applications. The HEALPix documentation recommends using spherical harmonic transforms for resolution changes when possible. The algorithm implements the exact same logic as healpy.ud_grade for compatibility, including proper handling of UNSEEN pixels and coordinate system conversions between RING and NESTED schemes. """ # Early validation to provide clear error messages # udgrade requires power-of-2 nside values regardless of ordering scheme if not isnsideok(nside_out, nest=True): raise ValueError( f'{nside_out} is not a valid nside parameter for udgrade (must be a power of 2, less than 2**30)' ) # Convert input to JAX array and handle map format map_in = jnp.asarray(map_in) is_single_map = map_in.ndim == 1 # Ensure we work with 2D array (n_maps, npix) if is_single_map: maps = map_in[None, :] # Add map dimension else: maps = map_in # Get input nside and validate nside_in = get_nside(maps[0]) # Use first map to get nside if not isnsideok(nside_in, nest=True): raise ValueError( f'{nside_in} is not a valid nside parameter for udgrade (must be a power of 2, less than 2**30)' ) # Determine output ordering if order_out is None: order_out = order_in # Call the core implementation return _ud_grade_core(maps, nside_in, nside_out, pess, order_in, order_out, power, dtype, is_single_map)
def _ud_grade_core( maps: Array, nside_in: int, nside_out: int, pess: bool, order_in: str, order_out: str, power: float, dtype: type, is_single_map: bool, ) -> Array: """Core udgrade implementation for processing multiple maps.""" npix_in = nside2npix(nside_in) npix_out = nside2npix(nside_out) # Determine output dtype if dtype is not None: output_dtype = dtype else: output_dtype = maps.dtype # Step 1: Convert to NESTED if needed (reorder handles batch dimension) if order_in == 'RING': maps = reorder(maps, r2n=True) # Step 2: Core resolution change in NESTED scheme if nside_out == nside_in: # No change needed result = maps elif nside_out > nside_in: # UPGRADE: replicate parent pixels to children rat2 = npix_out // npix_in # Apply power scaling if specified if power is not None: ratio = (jnp.float32(nside_out) / jnp.float32(nside_in)) ** jnp.float32(power) else: ratio = 1.0 # Replicate each pixel value to its children using broadcasting # maps shape: (n_maps, npix_in) # fact shape: (rat2,) # outer product: (n_maps, npix_in, rat2) -> reshape to (n_maps, npix_out) fact = jnp.ones(rat2, dtype=output_dtype) * ratio # Use broadcasting: maps[..., :, None] * fact[None, None, :] expanded_maps = maps[..., :, None] * fact # (n_maps, npix_in, rat2) result = expanded_maps.reshape(maps.shape[0], npix_out) else: # DEGRADE: average children pixels to parent rat2 = npix_in // npix_out # Reshape to group children pixels: (n_maps, npix_out, rat2) reshaped_maps = maps.reshape(maps.shape[0], npix_out, rat2) # Create mask for valid pixels (not UNSEEN and finite) goods = ~(mask_bad(reshaped_maps) | (~jnp.isfinite(reshaped_maps))) # Sum valid pixels along children axis map_sum = jnp.sum(reshaped_maps * goods, axis=-1) # (n_maps, npix_out) n_good = jnp.sum(goods, axis=-1) # (n_maps, npix_out) # Determine which output pixels should be UNSEEN if pess: # Pessimistic: mark UNSEEN if ANY child is bad badout = n_good != rat2 else: # Optimistic: mark UNSEEN only if ALL children are bad badout = n_good == 0 # Apply power scaling if specified if power is not None: ratio = (jnp.float32(nside_out) / jnp.float32(nside_in)) ** jnp.float32(power) n_good = n_good / ratio # Calculate averages for pixels with valid children result = jnp.where( n_good > 0, map_sum / n_good, 0.0, # Temporary value, will be set to UNSEEN below ) # Set UNSEEN pixels result = jnp.where(badout, UNSEEN, result) # Step 3: Convert back to desired output ordering (reorder handles batch dimension) if order_out == 'RING' and order_in == 'NESTED': result = reorder(result, n2r=True) elif order_out == 'NESTED' and order_in == 'RING': # Map was converted to NESTED in step 1, keep it pass elif order_out == order_in: # Convert back if we changed it if order_in == 'RING': result = reorder(result, n2r=True) # Apply output dtype result = result.astype(output_dtype) if is_single_map: return result[0] # Remove the added map dimension else: return result