# This file is part of jax-healpy.
# Copyright (C) 2024 CNRS / SciPol developers
#
# jax-healpy is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# jax-healpy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with jax-healpy. If not, see <https://www.gnu.org/licenses/>.
"""
=====================================================
pixelfunc.py : Healpix pixelization related functions
=====================================================
This module provides functions related to Healpix pixelization scheme.
conversion from/to sky coordinates
----------------------------------
- :func:`pix2ang` converts pixel number to angular coordinates
- :func:`pix2vec` converts pixel number to unit 3-vector direction
- :func:`ang2pix` converts angular coordinates to pixel number
- :func:`vec2pix` converts 3-vector to pixel number
- :func:`vec2ang` converts 3-vector to angular coordinates
- :func:`ang2vec` converts angular coordinates to unit 3-vector
- :func:`pix2xyf` converts pixel number to coordinates within face
- :func:`xyf2pix` converts coordinates within face to pixel number
- :func:`get_interp_weights` returns the 4 nearest pixels for given
angular coordinates and the relative weights for interpolation
- :func:`get_all_neighbours` returns the 8 nearest pixels for given
angular coordinates (or optionally 9 pixels including center with get_center=True)
conversion between NESTED and RING schemes
------------------------------------------
- :func:`nest2ring` converts NESTED scheme pixel numbers to RING
scheme pixel number
- :func:`ring2nest` converts RING scheme pixel number to NESTED
scheme pixel number
- :func:`reorder` reorders a healpix map pixels from one scheme to another
nside/npix/resolution
---------------------
- :func:`nside2npix` converts healpix nside parameter to number of pixel
- :func:`npix2nside` converts number of pixel to healpix nside parameter
- :func:`nside2order` converts nside to order
- :func:`order2nside` converts order to nside
- :func:`nside2resol` converts nside to mean angular resolution
- :func:`nside2pixarea` converts nside to pixel area
- :func:`isnsideok` checks the validity of nside
- :func:`isnpixok` checks the validity of npix
- :func:`get_map_size` gives the number of pixel of a map
- :func:`get_min_valid_nside` gives the minimum nside possible for a given
number of pixel
- :func:`get_nside` returns the nside of a map
- :func:`maptype` checks the type of a map (one map or sequence of maps)
- :func:`ud_grade` upgrades or degrades the resolution (nside) of a map
Masking pixels
--------------
- :const:`UNSEEN` is a constant value interpreted as a masked pixel
- :func:`mask_bad` returns a map with ``True`` where map is :const:`UNSEEN`
- :func:`mask_good` returns a map with ``False`` where map is :const:`UNSEEN`
- :func:`ma` returns a masked array as map, with mask given by :func:`mask_bad`
Map data manipulation
---------------------
- :func:`fit_dipole` fits a monopole+dipole on the map
- :func:`fit_monopole` fits a monopole on the map
- :func:`remove_dipole` fits and removes a monopole+dipole from the map
- :func:`remove_monopole` fits and removes a monopole from the map
- :func:`get_interp_val` computes a bilinear interpolation of the map
at given angular coordinates, using 4 nearest neighbours
"""
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, lax, vmap
from jaxtyping import Array, ArrayLike
# Names exported by ``from ... import *``.
# Entries that are commented out are names that are NOT exported here
# (presumably healpy functions that have not been ported yet -- TODO confirm).
__all__ = [
    'pix2ang',
    'ang2pix',
    'pix2xyf',
    'xyf2pix',
    'pix2vec',
    'vec2pix',
    'ang2vec',
    'vec2ang',
    'get_interp_weights',
    'get_interp_val',
    'get_all_neighbours',
    # 'max_pixrad',
    'nest2ring',
    'ring2nest',
    'reorder',
    'ud_grade',
    'UNSEEN',
    # 'mask_good',
    # 'mask_bad',
    # 'ma',
    # 'fit_dipole',
    # 'remove_dipole',
    # 'fit_monopole',
    # 'remove_monopole',
    'nside2npix',
    'npix2nside',
    'nside2order',
    'order2nside',
    'order2npix',
    'npix2order',
    'nside2resol',
    'nside2pixarea',
    'isnsideok',
    'isnpixok',
    # 'get_map_size',
    # 'get_min_valid_nside',
    'get_nside',
    'maptype',
    # 'ma_to_array',
]
# We are using 64-bit integer types.
# nside > 2**29 requires extended integer types.
# MAX_NSIDE is the largest nside accepted by :func:`isnsideok`.
MAX_NSIDE = 1 << 29
# Sentinel value; the module docstring describes it as the value interpreted
# as a masked pixel.
UNSEEN = -1.6375e30
# HEALPix neighbor finding constants
# These constants implement the exact neighbor-finding algorithm from the original
# HEALPix C++ library (healpix_base.cc) for face boundary transitions
# 8-element offset arrays for x and y directions (SW, W, NW, N, NE, E, SE, S)
# These define the relative positions of the 8 neighbors around any pixel
# Both arrays are int32 and are meant to be indexed with the same neighbor index.
_NB_XOFFSET = jnp.array([-1, -1, 0, 1, 1, 1, 0, -1], dtype=jnp.int32)
_NB_YOFFSET = jnp.array([0, 1, 1, 1, 0, -1, -1, -1], dtype=jnp.int32)
# Face boundary lookup table (9x12) - handles face transitions for neighbors
# This lookup table maps (nbnum, face) -> new_face when neighbors cross face boundaries
# nbnum encodes the boundary crossing direction, face is the original face (0-11)
# Based on original HEALPix C++ implementation's neighbor finding algorithm
# NOTE(review): -1 entries presumably mean "no neighboring face in that
# direction" -- confirm against the code that consumes this table.
_NB_FACEARRAY = jnp.array(
    [
        [8, 9, 10, 11, -1, -1, -1, -1, 10, 11, 8, 9],  # S
        [5, 6, 7, 4, 8, 9, 10, 11, 9, 10, 11, 8],  # SE
        [-1, -1, -1, -1, 5, 6, 7, 4, -1, -1, -1, -1],  # E
        [4, 5, 6, 7, 11, 8, 9, 10, 11, 8, 9, 10],  # SW
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],  # center
        [1, 2, 3, 0, 0, 1, 2, 3, 5, 6, 7, 4],  # NE
        [-1, -1, -1, -1, 7, 4, 5, 6, -1, -1, -1, -1],  # W
        [3, 0, 1, 2, 3, 0, 1, 2, 4, 5, 6, 7],  # NW
        [2, 3, 0, 1, -1, -1, -1, -1, 0, 1, 2, 3],  # N
    ],
    dtype=jnp.int32,
)
# Coordinate transformation bits (9x3) - handles x/y swapping across face boundaries
# This lookup table provides bit flags for coordinate transformations when crossing faces
# Bit 1: flip x coordinate, Bit 2: flip y coordinate, Bit 4: swap x and y coordinates
# Index by (nbnum, face >> 2) to get transformation bits for the boundary crossing
# Rows are in the same order as the rows of _NB_FACEARRAY; the three columns
# correspond to face >> 2 in {0, 1, 2}.
_NB_SWAPARRAY = jnp.array(
    [
        [0, 0, 3],  # S
        [0, 0, 6],  # SE
        [0, 0, 0],  # E
        [0, 0, 5],  # SW
        [0, 0, 0],  # center
        [5, 0, 0],  # NE
        [0, 0, 0],  # W
        [6, 0, 0],  # NW
        [3, 0, 0],  # N
    ],
    dtype=jnp.int32,
)
def check_theta_valid(theta):
    """Validate co-latitudes in a way that can be used under ``jit``.

    The range test is computed on the array and handed to a host callback
    through :func:`jax.debug.callback`; the callback raises ``ValueError``
    when at least one value is outside ``[0, pi]`` (a tolerance of ``1e-5``
    is allowed above ``pi``).

    Parameters
    ----------
    theta : array
        Co-latitudes in radians. Comparisons on it must expose ``.all()``.
    """
    lower_ok = (theta >= 0).all()
    upper_ok = (theta <= np.pi + 1e-5).all()
    out_of_range = ~(lower_ok & upper_ok)

    def _fail_if_flagged(flag):
        if flag:
            raise ValueError('THETA is out of range [0,pi]')

    jax.debug.callback(_fail_if_flagged, out_of_range)
def check_nside(nside: int, nest: bool = False) -> None:
    """Raise a ``ValueError`` if ``nside`` is not a valid nside parameter.

    Validity is delegated to :func:`isnsideok`; when ``nside`` is array-like,
    every element has to be valid.
    """
    if np.all(isnsideok(nside, nest=nest)):
        return
    raise ValueError(f'{nside} is not a valid nside parameter (must be a power of 2, less than 2**30)')
def _pixel_dtype_for(nside: int) -> jnp.dtype:
"""Returns the appropriate dtype for a pixel number given nside"""
# for nside = 13378, npix = 2_147_650_608 which would overflow int32
return jnp.int32 if nside <= 13377 else jnp.int64
[docs]
def isnsideok(nside: int, nest: bool = False) -> bool:
    """Tell whether ``nside`` is a valid nside parameter.

    A valid nside is a strictly positive integer value not larger than
    ``MAX_NSIDE``. It additionally has to be a power of 2 when ``nest`` is
    :const:`True`.

    Parameters
    ----------
    nside : int, scalar or array-like
        Value(s) to be tested.
    nest : bool, optional
        If :const:`True`, also require a power of 2 (NESTED ordering).

    Returns
    -------
    ok : bool, scalar or array-like
        :const:`True` where the given value is a valid nside.

    Examples
    --------
    >>> import jax_healpy as hp
    >>> hp.isnsideok(13, nest=True)
    False
    >>> hp.isnsideok(13, nest=False)
    True
    >>> hp.isnsideok(32)
    True
    >>> hp.isnsideok([1, 2, 3, 4, 8, 16], nest=True)
    array([ True,  True, False,  True,  True,  True], dtype=bool)
    """
    # Power-of-2 test: standard bithack, see
    # http://graphics.stanford.edu/~seander/bithacks.html#DetermineIfPowerOf2
    if not hasattr(nside, '__len__'):
        # scalar input: plain Python boolean logic
        valid = nside == int(nside) and 0 < nside <= MAX_NSIDE
        if nest:
            as_int = int(nside)
            valid = valid and (as_int & (as_int - 1)) == 0
        return valid

    # sized input: element-wise test on a numpy array
    values = nside if isinstance(nside, np.ndarray) else np.asarray(nside)
    valid = (values == values.astype(int)) & (values > 0) & (values <= MAX_NSIDE)
    if nest:
        as_int = values.astype(int)
        valid &= (as_int & (as_int - 1)) == 0
    return valid
[docs]
def isnpixok(npix: int) -> bool:
    """Tell whether ``npix`` is a valid healpix map size.

    The test is that ``sqrt(npix / 12)`` is a whole number.

    Parameters
    ----------
    npix : int, scalar or array-like
        Value(s) to be tested.

    Returns
    -------
    ok : bool, scalar or array-like
        :const:`True` where the given value is a valid number of pixels.

    Examples
    --------
    >>> import jax_healpy as hp
    >>> hp.isnpixok(12)
    True
    >>> hp.isnpixok(768)
    True
    >>> hp.isnpixok([12, 768, 1002])
    array([ True,  True, False], dtype=bool)
    """
    candidate_nside = np.sqrt(np.asarray(npix) / 12.0)
    return candidate_nside == np.floor(candidate_nside)
[docs]
def nside2npix(nside: int) -> int:
    """Return the number of pixels of a map with the given nside.

    Parameters
    ----------
    nside : int
        healpix nside parameter (not validated here)

    Returns
    -------
    npix : int
        ``12 * nside * nside``

    Examples
    --------
    >>> import jax_healpy as hp
    >>> hp.nside2npix(8)
    768
    >>> hp.nside2npix(7)
    588
    """
    twelve_nside = 12 * nside
    return twelve_nside * nside
[docs]
def npix2nside(npix: int) -> int:
    """Return the nside parameter matching the given number of pixels.

    Parameters
    ----------
    npix : int
        the number of pixels

    Returns
    -------
    nside : int
        the nside parameter corresponding to npix

    Raises
    ------
    ValueError
        If ``npix`` is rejected by :func:`isnpixok`.

    Examples
    --------
    >>> import jax_healpy as hp
    >>> hp.npix2nside(768)
    8
    >>> hp.npix2nside(1000)
    Traceback (most recent call last):
        ...
    ValueError: Wrong pixel number (it is not 12*nside**2)
    """
    if isnpixok(npix):
        return int(np.sqrt(npix / 12.0))
    raise ValueError('Wrong pixel number (it is not 12*nside**2)')
[docs]
def nside2order(nside: int) -> int:
    """Return the resolution order of a given nside.

    Parameters
    ----------
    nside : int
        healpix nside parameter; it must be a power of 2, less than 2**30

    Returns
    -------
    order : int
        corresponding order, such that ``nside == 2**order``

    Raises
    ------
    ValueError
        If ``nside`` does not pass :func:`check_nside` with ``nest=True``.

    Examples
    --------
    >>> import jax_healpy as hp
    >>> hp.nside2order(128)
    7
    >>> all(hp.nside2order(2**o) == o for o in range(30))
    True
    >>> hp.nside2order(7)
    Traceback (most recent call last):
        ...
    ValueError: 7 is not a valid nside parameter (must be a power of 2, less than 2**30)
    """
    check_nside(nside, nest=True)
    # a power of 2 is written in binary as a 1 followed by `order` zeros
    binary_digits = format(nside, 'b')
    return len(binary_digits) - 1
[docs]
def order2nside(order: int) -> int:
    """Return the nside parameter of a given resolution order.

    Parameters
    ----------
    order : int
        the resolution order

    Returns
    -------
    nside : int
        ``1 << order``

    Raises
    ------
    ValueError
        If the resulting nside does not pass :func:`check_nside`.

    Examples
    --------
    >>> import jax_healpy as hp
    >>> hp.order2nside(7)
    128
    >>> print(hp.order2nside(np.arange(8)))
    [  1   2   4   8  16  32  64 128]
    >>> hp.order2nside(31)
    Traceback (most recent call last):
        ...
    ValueError: 2147483648 is not a valid nside parameter (must be a power of 2, less than 2**30)
    """
    candidate = 1 << order
    check_nside(candidate, nest=True)
    return candidate
[docs]
def order2npix(order: int) -> int:
    """Return the number of pixels of a given resolution order.

    Convenience function chaining :func:`order2nside` and :func:`nside2npix`.

    Parameters
    ----------
    order : int
        the resolution order

    Returns
    -------
    npix : int
        corresponding number of pixels

    Examples
    --------
    >>> import jax_healpy as hp
    >>> hp.order2npix(7)
    196608
    >>> hp.order2npix(31)
    Traceback (most recent call last):
        ...
    ValueError: 2147483648 is not a valid nside parameter (must be a power of 2, less than 2**30)
    """
    return nside2npix(order2nside(order))
[docs]
def npix2order(npix: int) -> int:
    """Return the resolution order of a given number of pixels.

    Convenience function chaining :func:`npix2nside` and :func:`nside2order`.

    Parameters
    ----------
    npix : int
        the number of pixels

    Returns
    -------
    order : int
        corresponding resolution order

    Examples
    --------
    >>> import jax_healpy as hp
    >>> hp.npix2order(768)
    3
    >>> hp.npix2order(1000)
    Traceback (most recent call last):
        ...
    ValueError: Wrong pixel number (it is not 12*nside**2)
    """
    return nside2order(npix2nside(npix))
[docs]
def nside2resol(nside: int, arcmin=False) -> float:
    """Return the approximate resolution (pixel size) for nside.

    The resolution is the square root of the pixel area, which is only a
    gross approximation given the different pixel shapes. ``nside`` is not
    validated by this function.

    Parameters
    ----------
    nside : int
        healpix nside parameter
    arcmin : bool
        if True, return the resolution in arcmin, otherwise in radian

    Returns
    -------
    resol : float
        approximate pixel size in radians or arcmin

    Examples
    --------
    >>> import jax_healpy as hp
    >>> hp.nside2resol(128, arcmin = True)  # doctest: +FLOAT_CMP
    27.483891294539248
    >>> hp.nside2resol(256)
    0.0039973699529159707
    >>> hp.nside2resol(7)
    0.1461895297066412
    """
    resol_rad = np.sqrt(nside2pixarea(nside))
    if not arcmin:
        return resol_rad
    return np.rad2deg(resol_rad) * 60
[docs]
def nside2pixarea(nside: int, degrees=False) -> float:
    """Return the pixel area for nside, in square radians or square degrees.

    ``nside`` is not validated by this function.

    Parameters
    ----------
    nside : int
        healpix nside parameter
    degrees : bool
        if True, return the area in square degrees, in square radians otherwise

    Returns
    -------
    pixarea : float
        pixel area in square radian or square degree

    Examples
    --------
    >>> import jax_healpy as hp
    >>> hp.nside2pixarea(128, degrees = True)  # doctest: +FLOAT_CMP
    0.2098234113027917
    >>> hp.nside2pixarea(256)
    1.5978966540475428e-05
    >>> hp.nside2pixarea(7)
    0.021371378595848933
    """
    # the sphere (4 pi steradians) is split into equal-area pixels
    area_sr = 4 * np.pi / nside2npix(nside)
    if not degrees:
        return area_sr
    # applying the radian -> degree factor twice converts an area
    return np.rad2deg(np.rad2deg(area_sr))
def _lonlat2thetaphi(lon: ArrayLike, lat: ArrayLike):
"""Transform longitude and latitude (deg) into co-latitude and longitude (rad)
Parameters
----------
lon : int or array-like
Longitude in degrees
lat : int or array-like
Latitude in degrees
Returns
-------
theta, phi : float, scalar or array-like
The co-latitude and longitude in radians
"""
return np.pi / 2 - jnp.radians(lat), jnp.radians(lon)
def _thetaphi2lonlat(theta, phi):
"""Transform co-latitude and longitude (rad) into longitude and latitude (deg)
Parameters
----------
theta : int or array-like
Co-latitude in radians
phi : int or array-like
Longitude in radians
Returns
-------
lon, lat : float, scalar or array-like
The longitude and latitude in degrees
"""
return jnp.degrees(phi), 90.0 - jnp.degrees(theta)
[docs]
def maptype(m):
    """Describe the type of the map (single map or sequence of maps).

    Checks the number of maps, that all maps have the same length and that
    this length is a valid map size (using :func:`isnpixok`).

    Parameters
    ----------
    m : sequence
        the map to get info from

    Returns
    -------
    info : int
        0 if ``m`` is a single map, otherwise the number of maps in the
        sequence.

    Raises
    ------
    TypeError
        If ``m`` is a scalar, is empty, holds maps of different lengths, or
        has a number of pixels rejected by :func:`isnpixok`.

    Examples
    --------
    >>> import healpy as hp
    >>> hp.pixelfunc.maptype(np.arange(12))
    0
    >>> hp.pixelfunc.maptype([np.arange(12), np.arange(12)])
    2
    """
    if not hasattr(m, '__len__'):
        raise TypeError('input map is a scalar')
    n_items = len(m)
    if n_items == 0:
        raise TypeError('input map has length zero')

    # If the first element has a length, m is treated as a sequence of maps.
    try:
        npix = len(m[0])
    except TypeError:
        npix = None

    if npix is None:
        # single map: its own length is the number of pixels
        if not isnpixok(n_items):
            raise TypeError('bad number of pixels')
        return 0

    for other in m[1:]:
        if len(other) != npix:
            raise TypeError('input maps have different npix')
    if not isnpixok(npix):
        raise TypeError('bad number of pixels')
    return n_items
[docs]
@partial(jit, static_argnames=['nside', 'nest', 'lonlat'])
def ang2pix(
    nside: int,
    theta: ArrayLike,
    phi: ArrayLike,
    nest: bool = False,
    lonlat: bool = False,
) -> Array:
    """ang2pix: nside,theta[rad],phi[rad],nest=False,lonlat=False -> ipix (default:RING)

    Unlike healpy.ang2pix, specifying a theta not in the range [0, pi] does
    not raise an error, but returns -1.

    Parameters
    ----------
    nside : int
        The healpix nside parameter. It is a static argument of the jitted
        function and is validated with :func:`check_nside`.
    theta, phi : float, scalars or array-like
        Angular coordinates of a point on the sphere
    nest : bool, optional
        NESTED ordering is not implemented: ``True`` raises
        ``NotImplementedError``.
    lonlat : bool
        If True, input angles are assumed to be longitude and latitude in degree,
        otherwise, they are co-latitude and longitude in radians.

    Returns
    -------
    pix : array of int
        The healpix pixel numbers in RING ordering, -1 where the co-latitude
        is outside ``[0, pi]``. Usual broadcasting rules apply.

    See Also
    --------
    pix2ang, pix2vec, vec2pix

    Examples
    --------
    Some of the inputs below lie on pixel boundaries and have a tiny value
    of 1e-15 added to them to make them reproducible
    (see https://github.com/healpy/healpy/issues/528).

    >>> import jax_healpy as hp
    >>> from jax.numpy import pi as π
    >>> print(hp.ang2pix(16, np.array([π/2, π/4, π/2, 0, π]), np.array([0., π/4, π/2 + 1e-15, 0, 0])))
    [1440  427 1520    0 3068]
    >>> print(hp.ang2pix(16, π/2, np.array([0, π/2 + 1e-15])))
    [1440 1520]
    """  # noqa: E501
    check_nside(nside, nest=nest)
    if nest:
        raise NotImplementedError('NEST pixel ordering is not implemented.')
    if lonlat:
        theta, phi = _lonlat2thetaphi(theta, phi)

    # check_theta_valid is not called: out-of-range co-latitudes map to -1.
    cos_theta = jnp.cos(theta)
    sin_theta = jnp.sin(theta)
    ring_pixels = _zphi2pix_ring(nside, cos_theta, sin_theta, phi)
    theta_out_of_range = (theta < 0) | (theta > np.pi + 1e-5)
    return jnp.where(theta_out_of_range, -1, ring_pixels)
def _zphi2pix_ring(nside: int, z: ArrayLike, sin_theta: ArrayLike, phi: ArrayLike) -> Array:
    """Return RING pixel numbers from ``z = cos(theta)`` and the longitude ``phi``.

    Both regional formulas are evaluated for every input; the result is then
    selected element-wise, the equatorial one applying where ``|z| <= 2/3``.
    """
    # longitude in units of pi/2, wrapped into [0, 4)
    tt = jnp.mod(2 * phi / np.pi, 4)
    equatorial_pix = _zphi2pix_equatorial_region_ring(nside, z, sin_theta, tt)
    polar_pix = _zphi2pix_polar_caps_ring(nside, z, sin_theta, tt)
    in_equatorial_region = jnp.abs(z) <= 2 / 3
    return jnp.where(in_equatorial_region, equatorial_pix, polar_pix)
def _zphi2pix_equatorial_region_ring(nside: int, z: ArrayLike, sin_theta: float, tt: ArrayLike) -> Array:
ncap = 2 * nside * (nside - 1)
nl4 = 4 * nside
jp = (nside * (0.5 + tt - 0.75 * z)).astype(int)
jm = (nside * (0.5 + tt + 0.75 * z)).astype(int)
ir = nside + 1 + jp - jm
kshift = 1 - ir & 1 # ir even -> 1, odd -> 0
t1 = jp + jm - nside + kshift + 1 + nl4 + nl4
ip = (t1 >> 1) & (nl4 - 1)
pix = ncap + (ir - 1) * nl4 + ip
return pix
def _zphi2pix_polar_caps_ring(nside: int, z: ArrayLike, sin_theta: ArrayLike, tt: ArrayLike) -> Array:
    """Return RING pixel numbers using the polar-cap formula.

    ``z`` is ``cos(theta)`` and ``tt`` the longitude in units of pi/2.
    ``sin_theta`` is not used: the ring scale is always computed from
    ``sqrt(3 * (1 - |z|))``.
    """
    npixel = nside2npix(nside)

    # fractional position inside the quadrant
    frac = tt - jnp.floor(tt)
    scale = nside * jnp.sqrt(3.0 * (1.0 - jnp.abs(z)))

    jp = (frac * scale).astype(int)
    jm = ((1.0 - frac) * scale).astype(int)

    ring = jp + jm + 1  # ring number counted from the closest pole
    in_ring = (tt * ring).astype(int)

    north_pix = 2 * ring * (ring - 1) + in_ring
    south_pix = npixel - 2 * ring * (ring + 1) + in_ring
    return jnp.where(z > 0, north_pix, south_pix)
[docs]
@partial(jit, static_argnames=['nside', 'nest', 'lonlat'])
def pix2ang(nside: int, ipix: ArrayLike, nest: bool = False, lonlat: bool = False) -> tuple[Array, Array]:
    """pix2ang : nside,ipix,nest=False,lonlat=False -> theta[rad],phi[rad] (default RING)

    Parameters
    ----------
    nside : int
        The healpix nside parameter. It is a static argument of the jitted
        function and is validated with :func:`check_nside`.
    ipix : int or array-like
        Pixel indices
    nest : bool, optional
        if True, assume NESTED pixel ordering (delegated to
        ``_pix2ang_nest``), otherwise, RING pixel ordering
    lonlat : bool, optional
        If True, return angles will be longitude and latitude in degree,
        otherwise, angles will be co-latitude and longitude in radians (default)

    Returns
    -------
    theta, phi : float, scalar or array-like
        The angular coordinates corresponding to ipix.

    See Also
    --------
    ang2pix, vec2pix, pix2vec

    Examples
    --------
    >>> import jax_healpy as hp
    >>> hp.pix2ang(16, 1440)
    (1.5291175943723188, 0.0)
    """
    check_nside(nside, nest=nest)
    if nest:
        theta, phi = _pix2ang_nest(nside, ipix)
    else:
        ring = _pix2i_ring(nside, ipix)
        theta = _pix2theta_ring(nside, ring, ipix)
        phi = _pix2phi_ring(nside, ring, ipix)
    return _thetaphi2lonlat(theta, phi) if lonlat else (theta, phi)
def _pix2i_ring(nside: int, pixels: ArrayLike) -> Array:
    """Return the ring index of RING-ordered pixels.

    The three regional formulas (north cap, equatorial region, south cap) are
    all evaluated; the result is selected element-wise from the pixel number.
    """
    npixel = nside2npix(nside)
    ncap = 2 * nside * (nside - 1)

    in_north_cap = pixels < ncap
    before_south_cap = pixels < npixel - ncap

    belt_or_south = jnp.where(
        before_south_cap,
        _pix2i_equatorial_region_ring(nside, pixels),
        _pix2i_south_cap_ring(nside, pixels),
    )
    return jnp.where(in_north_cap, _pix2i_north_cap_ring(nside, pixels), belt_or_south)
def _pix2i_north_cap_ring(nside: int, pixels: ArrayLike) -> Array:
return (1 + jnp.sqrt(1 + 2 * pixels).astype(int)) >> 1 # counted from North Pole
def _pix2i_equatorial_region_ring(nside: int, pixels: ArrayLike) -> Array:
ncap = 2 * nside * (nside - 1)
ip = pixels - ncap
order = nside2order(nside)
# I tmp = (order_>=0) ? ip>>(order_+2) : ip/nl4;
tmp = ip >> (order + 2)
return tmp + nside
def _pix2i_south_cap_ring(nside: int, pixels: ArrayLike) -> Array:
    """Return the ring index, counted from the South Pole, of south-cap pixels."""
    from_south = nside2npix(nside) - pixels
    truncated_root = jnp.sqrt(2 * from_south - 1).astype(int)
    return (truncated_root + 1) >> 1
def _pix2z_ring(nside: int, iring: ArrayLike, pixels: ArrayLike) -> tuple[Array, Array]:
    """Return ``z`` for RING-ordered pixels, along with the polar-cap term.

    Returns
    -------
    z : array
        ``1 - t`` in the north cap, ``t - 1`` in the south cap (``t`` being
        the polar-cap term) and the equatorial formula in between.
    abs_one_minus_z : array
        The polar-cap term ``t`` from ``_pix2z_polar_caps_ring``, returned
        for every pixel whatever its region.
    """
    npixel = nside2npix(nside)
    ncap = 2 * nside * (nside - 1)
    cap_term = _pix2z_polar_caps_ring(nside, iring)

    in_north_cap = pixels < ncap
    before_south_cap = pixels < npixel - ncap

    belt_or_south = jnp.where(
        before_south_cap,
        _pix2z_equatorial_region_ring(nside, iring),
        cap_term - 1,
    )
    z = jnp.where(in_north_cap, 1 - cap_term, belt_or_south)
    return z, cap_term
def _pix2z_polar_caps_ring(nside: int, iring: ArrayLike) -> Array:
    """Return ``4 * iring**2 / npix``, the polar-cap term used to build ``z``."""
    ring_squared = iring * iring
    return ring_squared * 4 / nside2npix(nside)
def _pix2z_equatorial_region_ring(nside: int, iring: ArrayLike) -> Array:
return (2 * nside - iring) * 2 / 3 / nside
def _pix2theta_ring(nside: int, iring: ArrayLike, pixels: ArrayLike) -> Array:
    """Return the co-latitude (rad) of RING-ordered pixels.

    Where ``|z| > 0.99`` the angle is obtained with ``arctan2`` from the
    polar-cap term, otherwise with ``arccos(z)``.
    """
    z, cap_term = _pix2z_ring(nside, iring, pixels)
    sin_theta_near_pole = jnp.sqrt(cap_term * (2 - cap_term))
    theta_near_pole = jnp.arctan2(sin_theta_near_pole, z)
    near_pole = jnp.abs(z) > 0.99
    return jnp.where(near_pole, theta_near_pole, jnp.arccos(z))
def _pix2phi_ring(nside: int, iring: ArrayLike, pixels: ArrayLike) -> Array:
    """Return the longitude (rad) of RING-ordered pixels.

    The three regional formulas are all evaluated; the result is selected
    element-wise from the pixel number.
    """
    npixel = nside2npix(nside)
    ncap = 2 * nside * (nside - 1)

    in_north_cap = pixels < ncap
    before_south_cap = pixels < npixel - ncap

    belt_or_south = jnp.where(
        before_south_cap,
        _pix2phi_equatorial_region_ring(nside, iring, pixels),
        _pix2phi_south_cap_ring(nside, iring, pixels),
    )
    return jnp.where(in_north_cap, _pix2phi_north_cap_ring(nside, iring, pixels), belt_or_south)
def _pix2phi_north_cap_ring(nside: int, iring: ArrayLike, pixels: ArrayLike) -> Array:
iphi = pixels + 1 - 2 * iring * (iring - 1)
phi = (iphi - 0.5) * np.pi / 2 / iring
return phi
def _pix2phi_equatorial_region_ring(nside: int, iring: ArrayLike, pixels: ArrayLike) -> Array:
iphi = pixels + 2 * nside * (nside + 1) - 4 * nside * iring + 1
fodd = ((iring + nside) & 1) * 0.5 + 0.5 # iring + nside odd -> 1 else 0.5
phi = (iphi - fodd) * np.pi / 2 / nside
return phi
def _pix2phi_south_cap_ring(nside: int, iring: ArrayLike, pixels: ArrayLike) -> Array:
    """Return the longitude (rad) of south-cap pixels, ``iring`` being counted from the South Pole."""
    npixel = nside2npix(nside)
    # number of pixels from `pixels` to the end of its ring, inclusive
    pixels_to_ring_end = npixel - pixels - 2 * iring * (iring - 1)
    iphi = 4 * iring + 1 - pixels_to_ring_end
    return (iphi - 0.5) * np.pi / 2 / iring
def _pix2ang_nest(nside: ArrayLike, ipix: ArrayLike) -> tuple[Array, Array]:
raise NotImplementedError('NEST pixel ordering is not implemented.')
# template<typename I> void T_Healpix_Base<I>::pix2loc (I pix, double &z,
# double &phi, double &sth, bool &have_sth) const
# have_sth=false;
# {
# int face_num, ix, iy;
# nest2xyf(pix,ix,iy,face_num);
#
# I jr = (I(jrll[face_num])<<order_) - ix - iy - 1;
#
# I nr;
# if (jr<nside_)
# {
# nr = jr;
# double tmp=(nr*nr)*fact2_;
# z = 1 - tmp;
# if (z>0.99) { sth=sqrt(tmp*(2.-tmp)); have_sth=true; }
# }
# else if (jr > 3*nside_)
# {
# nr = nside_*4-jr;
# double tmp=(nr*nr)*fact2_;
# z = tmp - 1;
# if (z<-0.99) { sth=sqrt(tmp*(2.-tmp)); have_sth=true; }
# }
# else
# {
# nr = nside_;
# z = (2*nside_-jr)*fact1_;
# }
#
# I tmp=I(jpll[face_num])*nr+ix-iy;
# planck_assert(tmp<8*nr,"must not happen");
# if (tmp<0) tmp+=8*nr;
# phi = (nr==nside_) ? 0.75*halfpi*tmp*fact1_ :
# (0.5*halfpi*tmp)/nr;
# }
# }
[docs]
@partial(jit, static_argnames=['nside', 'nest'])
def xyf2pix(nside: int, x: ArrayLike, y: ArrayLike, face: ArrayLike, nest: bool = False) -> Array:
    """xyf2pix : nside,x,y,face,nest=False -> ipix (default:RING)

    Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc.

    Parameters
    ----------
    nside : int
        The healpix nside parameter, validated with :func:`check_nside`
    x, y : int, scalars or array-like
        Pixel indices within face
    face : int, scalars or array-like
        Face number
    nest : bool, optional
        if True, assume NESTED pixel ordering, otherwise, RING pixel ordering

    Returns
    -------
    pix : array of int
        The healpix pixel numbers. Usual broadcasting rules apply.

    See Also
    --------
    pix2xyf

    Examples
    --------
    >>> import healpy as hp
    >>> hp.xyf2pix(16, 8, 8, 4)
    1440
    >>> print(hp.xyf2pix(16, [8, 8, 8, 15, 0], [8, 8, 7, 15, 0], [4, 0, 5, 0, 8]))
    [1440  427 1520    0 3068]
    """
    check_nside(nside, nest=nest)
    x, y, face = (jnp.asarray(value) for value in (x, y, face))
    converter = _xyf2pix_nest if nest else _xyf2pix_ring
    return converter(nside, x, y, face)
def _xyf2pix_nest(nside: int, ix: Array, iy: Array, fnum: Array) -> Array:
    """Convert (x, y, face) to pixel number in NESTED ordering"""
    # each face holds nside**2 consecutive pixel numbers
    first_pixel_of_face = fnum * nside**2
    return first_pixel_of_face + _xy2fpix(nside, ix, iy)
def _xy2fpix(nside: int, ix: Array, iy: Array) -> Array:
"""Convert (x, y) coordinates to a pixel index inside a face"""
# fpix = (ix & 0b1) << 0 | (iy & 0b1) << 1 | (ix & 0b10) << 1 | (iy & 0b10) << 2 | ...
def combine_bits(i, val):
val |= (ix & (1 << i)) << i
val |= (iy & (1 << i)) << (i + 1)
return val
# ix and iy are always less than nside, so there is no need to extract more bits than this
length = (nside - 1).bit_length()
# we use a native for loop because it was slightly faster than lax.fori_loop with unroll=True
fpix = jnp.zeros_like(ix)
for i in range(length):
fpix = combine_bits(i, fpix)
return fpix
# ring index of south corner for each face (0 = North pole)
# One entry per face (12 faces); the value is multiplied by nside where it is
# used to build a ring index.
_JRLL = jnp.array([2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4])
# longitude index of south corner for each face (0 = longitude zero)
# One entry per face (12 faces), indexed by face number like _JRLL.
_JPLL = jnp.array([1, 3, 5, 7, 0, 2, 4, 6, 1, 3, 5, 7])
def _xyf2pix_ring(nside: int, ix: Array, iy: Array, face_num: Array) -> Array:
    """Convert (x, y, face) to a pixel number in RING ordering"""
    # ring index of the pixel center
    ring = (_JRLL[face_num] * nside) - ix - iy - 1

    pixels_on_ring = _npix_on_ring(nside, ring)
    first_pixel = _start_pixel_ring(nside, ring)
    kshift = 1 - _ring_shifted(nside, ring)

    # pixel number in the ring, starting at 1
    in_ring = (_JPLL[face_num] * pixels_on_ring // 4 + ix - iy + 1 + kshift) // 2
    # indices below 1 wrap around the ring
    in_ring = jnp.where(in_ring < 1, in_ring + 4 * nside, in_ring)

    return first_pixel - 1 + in_ring
def _start_pixel_ring(nside: int, i_ring: ArrayLike) -> Array:
    """Get the first pixel number of a ring"""
    npix = nside2npix(nside)
    ncap = 2 * nside * (nside - 1)
    ringpix = _npix_on_ring(nside, i_ring)

    # compute the answer for the mirrored ring of the northern hemisphere
    i_north = _northern_ring(nside, i_ring)
    cap_start = 2 * i_north * (i_north - 1)  # north polar cap
    belt_start = ncap + (i_north - nside) * 4 * nside  # north equatorial belt
    north_start = jnp.where(i_north < nside, cap_start, belt_start)

    # a ring that differs from its northern mirror is in the southern
    # hemisphere: count from the end of the map instead
    is_southern = i_ring != i_north
    return jnp.where(is_southern, npix - north_start - ringpix, north_start)
def _npix_on_ring(nside: int, i_ring: ArrayLike) -> Array:
    """Get the number of pixels on a ring"""
    i_north = _northern_ring(nside, i_ring)
    in_polar_cap = i_north < nside
    # rings in the polar cap have 4*i pixels, the other ones 4*nside pixels
    return jnp.where(in_polar_cap, 4 * i_north, 4 * nside)
def _ring_shifted(nside: int, i_ring: ArrayLike) -> Array:
    """Check if a ring is shifted

    Polar-cap rings are always reported as shifted; the other rings are
    shifted when their offset from ring ``nside`` is even.
    """
    i_north = _northern_ring(nside, i_ring)
    even_offset_from_cap = ((i_north - nside) & 1) == 0
    return jnp.where(i_north < nside, True, even_offset_from_cap)
def _northern_ring(nside: int, i_ring: ArrayLike) -> Array:
i_north = jnp.where(i_ring > 2 * nside, 4 * nside - i_ring, i_ring)
return i_north
def _ring_above(nside: int, cos_theta: ArrayLike) -> Array:
"""Find the ring index just north of a point with given cos(theta).
This follows the exact HEALPix C++ implementation:
if (az <= twothird) // equatorial region
return I(nside_*(2-1.5*z));
I iring = I(nside_*sqrt(3*(1-az)));
return (z>0) ? iring : 4*nside_-iring-1;
"""
z = cos_theta
az = jnp.abs(z)
twothird = 2.0 / 3.0
# Equatorial region
equatorial_ring = nside * (2.0 - 1.5 * z)
# Stop gradient: Ring indices are discrete selectors, don't affect interpolation math
equatorial_ring = lax.stop_gradient(jnp.floor(equatorial_ring).astype(jnp.int32))
# Polar caps
iring = nside * jnp.sqrt(3.0 * (1.0 - az))
# Stop gradient: Ring indices are discrete selectors, don't affect interpolation math
iring = lax.stop_gradient(jnp.floor(iring).astype(jnp.int32))
polar_ring = jnp.where(z > 0, iring, 4 * nside - iring - 1)
# Choose based on z value
ring_idx = jnp.where(az <= twothird, equatorial_ring, polar_ring)
return ring_idx
def _get_ring_info(nside: int, ring_idx: ArrayLike) -> tuple[Array, Array, Array, Array]:
"""Get ring properties following HEALPix C++ get_ring_info2 exactly.
Returns: theta, startpix, ringpix, shifted
"""
# Convert to scalar for compatibility
ring = ring_idx
# HEALPix C++ constants
fact1 = 2.0 / (3.0 * nside) # (nside_<<1)*fact2_ where fact2_ = 4./npix_ = 1/(3*nside**2)
fact2 = 4.0 / (12.0 * nside * nside) # 4./npix_
ncap = 2 * nside * (nside - 1)
npix_total = 12 * nside * nside
# Northern hemisphere equivalent ring
northring = jnp.where(ring > 2 * nside, 4 * nside - ring, ring)
# Polar cap region (northring < nside)
polar_tmp = northring * northring * fact2
polar_costheta = 1.0 - polar_tmp
polar_sintheta = jnp.sqrt(polar_tmp * (2.0 - polar_tmp))
polar_theta = jnp.arctan2(polar_sintheta, polar_costheta)
polar_ringpix = 4 * northring
polar_shifted = True
polar_startpix = 2 * northring * (northring - 1)
# Equatorial region (northring >= nside)
equatorial_theta = jnp.arccos((2.0 * nside - northring) * fact1)
equatorial_ringpix = 4 * nside
equatorial_shifted = ((northring - nside) & 1) == 0
equatorial_startpix = ncap + (northring - nside) * equatorial_ringpix
# Choose based on region
theta = jnp.where(northring < nside, polar_theta, equatorial_theta)
ringpix = jnp.where(northring < nside, polar_ringpix, equatorial_ringpix)
shifted = jnp.where(northring < nside, polar_shifted, equatorial_shifted)
startpix = jnp.where(northring < nside, polar_startpix, equatorial_startpix)
# Southern hemisphere correction
theta = jnp.where(northring != ring, np.pi - theta, theta)
startpix = jnp.where(northring != ring, npix_total - startpix - ringpix, startpix)
# Convert shifted boolean to float (0.0 or 0.5)
shift = jnp.where(shifted, 0.5, 0.0)
return theta, startpix, ringpix, shift
[docs]
@partial(jit, static_argnames=['nside', 'nest'])
def pix2xyf(nside: int, ipix: ArrayLike, nest: bool = False) -> tuple[Array, Array, Array]:
    """Convert pixel numbers to (x, y, face) coordinates (default RING ordering).

    Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc.

    Parameters
    ----------
    nside : int
        The healpix nside parameter, must be a power of 2
    ipix : int or array-like
        Pixel indices
    nest : bool, optional
        if True, assume NESTED pixel ordering, otherwise, RING pixel ordering

    Returns
    -------
    x, y : int, scalars or array-like
        Pixel indices within face
    face : int, scalars or array-like
        Face number

    See Also
    --------
    xyf2pix

    Examples
    --------
    Reference values from healpy:

    >>> import healpy as hp
    >>> hp.pix2xyf(16, 1440)
    (8, 8, 4)
    >>> hp.pix2xyf(16, [1440, 427, 1520, 0, 3068])
    (array([ 8,  8,  8, 15,  0]), array([ 8,  8,  7, 15,  0]), array([4, 0, 5, 0, 8]))
    """
    check_nside(nside, nest=nest)
    # nest is static, so the converter is selected at trace time
    converter = _pix2xyf_nest if nest else _pix2xyf_ring
    return converter(nside, jnp.asarray(ipix))
def _pix2xyf_nest(nside: int, pix: Array) -> tuple[Array, Array, Array]:
    """Convert a pixel number in NESTED ordering to (x, y, face).

    The face number is the quotient of the pixel number by ``nside**2``; the
    remainder is the index inside the face, which is decoded by ``_fpix2xy``.
    """
    pixels_per_face = nside**2
    face = jnp.floor_divide(pix, pixels_per_face)
    in_face_index = jnp.remainder(pix, pixels_per_face)
    ix, iy = _fpix2xy(nside, in_face_index)
    return ix, iy, face
def _fpix2xy(nside: int, pix: Array) -> tuple[Array, Array]:
    """De-interleave the bits of an in-face pixel index into (x, y) coordinates.

    Even bits of ``pix`` (0, 2, 4, ...) are packed into ``x`` and odd bits
    (1, 3, 5, ...) into ``y``. Pixel indices inside the face must be less than
    ``nside**2``.
    """
    # With nside = 2**order, pix < nside**2 fits on 2 * order bits. Two bits are
    # consumed per iteration, so `order` iterations are needed, and `order` is
    # the bit length of nside - 1.
    n_bit_pairs = (nside - 1).bit_length()

    ix = jnp.zeros_like(pix)
    iy = jnp.zeros_like(pix)
    # A plain Python loop (unrolled at trace time) was measured to be slightly
    # faster than lax.fori_loop with unroll=True.
    for k in range(n_bit_pairs):
        even_bit = 1 << (2 * k)
        odd_bit = 1 << (2 * k + 1)
        ix = ix | ((pix & even_bit) >> k)
        iy = iy | ((pix & odd_bit) >> (k + 1))
    return ix, iy
def _pix2xyf_ring(nside: int, pix: Array) -> tuple[Array, Array, Array]:
    """Convert a pixel number in RING ordering to (x, y, face).

    The three regions (north polar cap, equatorial belt, south polar cap) are
    all evaluated for every pixel and the applicable result is selected with
    ``jnp.where``.
    """
    npix = nside2npix(nside)
    ncap = 2 * nside * (nside - 1)
    nl2 = 2 * nside
    in_north_cap = pix < ncap
    # True for the north cap and the equatorial belt, False for the south cap
    above_south_cap = pix < (npix - ncap)

    # TODO(simon): remove this cast when https://github.com/CMBSciPol/jax-healpy/issues/4 is fixed
    iring = _pix2i_ring(nside, pix).astype(_pixel_dtype_for(nside))
    iphi = _pix2iphi_ring(nside, iring, pix)
    nr = _npix_on_ring(nside, iring) // 4
    kshift = 1 - _ring_shifted(nside, iring)

    # Face number in the equatorial belt
    ire = iring - nside + 1
    irm = nl2 + 2 - ire
    ifm = (iphi - ire // 2 + nside - 1) // nside
    ifp = (iphi - irm // 2 + nside - 1) // nside
    belt_face = jnp.where(ifp == ifm, ifp | 4, jnp.where(ifp < ifm, ifp, ifm + 8))

    # Face number in the caps: faces 0-3 in the north, 8-11 in the south
    cap_face = (iphi - 1) // nr
    face_num = jnp.where(
        in_north_cap,
        cap_face,
        jnp.where(above_south_cap, belt_face, 8 + cap_face),
    )

    # In the south cap the ring number is replaced by 4 * nside - iring
    ring_for_irt = jnp.where(above_south_cap, iring, 4 * nside - iring)

    irt = ring_for_irt - (_JRLL[face_num] * nside) + 1
    ipt = 2 * iphi - _JPLL[face_num] * nr - kshift - 1
    ipt = ipt - jnp.where(ipt >= nl2, 8 * nside, 0)

    ix = (ipt - irt) // 2
    iy = (-ipt - irt) // 2
    return ix, iy, face_num
def _pix2iphi_ring(nside: int, iring: Array, pixels: Array) -> Array:
    """Return the index ``iphi`` of RING-ordered pixels inside their ring.

    The north cap, equatorial and south cap formulas are all evaluated and the
    one matching each pixel's region is selected.
    """
    ncap = 2 * nside * (nside - 1)
    south_cap_start = nside2npix(nside) - ncap

    iphi_north = _pix2iphi_north_cap_ring(nside, iring, pixels)
    iphi_belt = _pix2iphi_equatorial_region_ring(nside, iring, pixels)
    iphi_south = _pix2iphi_south_cap_ring(nside, iring, pixels)

    iphi_below_north = jnp.where(pixels < south_cap_start, iphi_belt, iphi_south)
    return jnp.where(pixels < ncap, iphi_north, iphi_below_north)
def _pix2iphi_north_cap_ring(nside: int, iring: Array, pixels: Array) -> Array:
    """Return the 1-based index inside the ring for pixels of the north polar cap."""
    # 2 * i * (i - 1) is the number of pixels in the rings above ring i
    pixels_above = 2 * iring * (iring - 1)
    return pixels + 1 - pixels_above
def _pix2iphi_equatorial_region_ring(nside: int, iring: Array, pixels: Array) -> Array:
    """Return the 1-based index inside the ring for pixels of the equatorial region."""
    offset = 2 * nside * (nside + 1)
    ring_term = 4 * nside * iring
    return pixels + offset - ring_term + 1
def _pix2iphi_south_cap_ring(nside: int, iring: Array, pixels: Array) -> Array:
    """Return the 1-based index inside the ring for pixels of the south polar cap."""
    # number of pixels from `pixels` (inclusive) to the end of the map
    remaining = nside2npix(nside) - pixels
    pixels_below = 2 * iring * (iring - 1)
    return 4 * iring + 1 - (remaining - pixels_below)
[docs]
@partial(jit, static_argnames=['nside', 'nest'])
def vec2pix(nside: int, x: ArrayLike, y: ArrayLike, z: ArrayLike, nest: bool = False) -> Array:
    """vec2pix : nside,x,y,z,nest=False -> ipix (RING ordering only)

    Contrary to healpy, nside must be an int (it is a static jit argument) and
    NESTED ordering is not implemented.

    Parameters
    ----------
    nside : int
        The healpix nside parameter; validated by ``check_nside``
    x,y,z : floats or array-like
        vector coordinates defining point on the sphere; the vector does not
        need to be normalized
    nest : bool, optional
        must be False (RING pixel ordering)

    Returns
    -------
    ipix : int, scalar or array-like
        The healpix pixel number corresponding to input vector.

    Raises
    ------
    NotImplementedError
        If ``nest`` is True.

    See Also
    --------
    ang2pix, pix2ang, pix2vec

    Examples
    --------
    Reference values from healpy:

    >>> import healpy as hp
    >>> hp.vec2pix(16, 1, 0, 0)
    1504
    >>> print(hp.vec2pix(16, [1, 0], [0, 1], [0, 0]))
    [1504 1520]
    """
    check_nside(nside, nest=nest)
    if nest:
        raise NotImplementedError
    # Python sequences reach a jitted function as lists of scalars, which do not
    # support arithmetic; convert them to arrays first (no-op for arrays).
    x, y, z = jnp.asarray(x), jnp.asarray(y), jnp.asarray(z)
    return _vec2pix_ring(nside, x, y, z)
def vec2pix2(nside: int, vec: ArrayLike, nest: bool = False) -> Array:
    """Convert 3-vectors to RING pixel numbers by delegating to ``vec2pix2_ring``.

    Parameters
    ----------
    nside : int
        The healpix nside parameter
    vec : array-like
        vectors stored along the second axis (``vec2pix2_ring`` maps over axis 1)
    nest : bool, optional
        must be False (RING pixel ordering)

    Raises
    ------
    NotImplementedError
        If ``nest`` is True. NESTED output is not implemented, and returning RING
        indices silently would be wrong; this matches ``vec2pix``.
    """
    if nest:
        raise NotImplementedError
    return vec2pix2_ring(nside, vec)
@partial(jit, static_argnames='nside')
@partial(vmap, in_axes=(None, 1))
def vec2pix2_ring(nside: int, vec: ArrayLike) -> Array:
    """Convert 3-vectors to RING pixel numbers.

    ``vmap`` maps the function over axis 1 of ``vec``, so the body below sees a
    single vector whose components are ``vec[0]``, ``vec[1]`` and ``vec[2]``.
    """
    unit = vec / jnp.sqrt(jnp.sum(vec**2))
    sin_theta = jnp.sqrt(unit[0] ** 2 + unit[1] ** 2)
    phi = jnp.arctan2(unit[1], unit[0])
    return _zphi2pix_ring(nside, unit[2], sin_theta, phi)
def _vec2pix_ring(nside: int, x: ArrayLike, y: ArrayLike, z: ArrayLike) -> Array:
    """Convert (not necessarily normalized) vector components to RING pixel numbers.

    The normalized ``z`` component, the sine of the colatitude and the longitude
    are computed and passed to ``_zphi2pix_ring``.
    """
    inv_norm = 1 / jnp.sqrt(x**2 + y**2 + z**2)
    # Build a new value rather than using ``z *= ...``: the augmented assignment
    # operates in place on a caller-owned NumPy array when called outside a trace.
    z_unit = z * inv_norm
    sin_theta = jnp.sqrt(x**2 + y**2) * inv_norm
    phi = jnp.arctan2(y, x)
    return _zphi2pix_ring(nside, z_unit, sin_theta, phi)
[docs]
@partial(jit, static_argnames=['nside', 'nest'])
def pix2vec(nside: int, ipix: ArrayLike, nest: bool = False) -> Array:
    """pix2vec : nside,ipix,nest=False -> vec (RING ordering only)

    Contrary to healpy, nside must be an int (it is a static jit argument),
    NESTED ordering is not implemented, and the result is a single array rather
    than an ``(x, y, z)`` tuple.

    Parameters
    ----------
    nside : int
        The healpix nside parameter; validated by ``check_nside``
    ipix : int, scalar or array-like
        Healpix pixel number
    nest : bool, optional
        must be False (RING pixel ordering)

    Returns
    -------
    vec : Array
        The unit vector(s) of the input pixels, with the x, y, z components along
        the last axis: shape (3,) for a scalar pixel, (N, 3) for N pixels.

    Raises
    ------
    NotImplementedError
        If ``nest`` is True.

    See Also
    --------
    ang2pix, pix2ang, vec2pix

    Examples
    --------
    Reference values from healpy (which returns the components as a tuple):

    >>> import healpy as hp
    >>> hp.pix2vec(16, 1504)
    (0.99879545620517241, 0.049067674327418015, 0.0)
    """
    check_nside(nside, nest=nest)
    if nest:
        raise NotImplementedError
    # Accept Python sequences of pixel numbers, as pix2xyf and ring2nest do
    return _pix2vec_ring(nside, jnp.asarray(ipix))
def _pix2vec_ring(nside, pixels):
    """Return the unit vectors of RING-ordered pixels, components along the last axis."""
    iring = _pix2i_ring(nside, pixels)
    z, abs_one_minus_z = _pix2z_ring(nside, iring, pixels)
    phi = _pix2phi_ring(nside, iring, pixels)

    # For |z| > 0.99, sin(theta)**2 is built from the |1 - z| value returned by
    # _pix2z_ring instead of (1 - z) * (1 + z).
    near_pole = jnp.abs(z) > 0.99
    sin_theta_squared = jnp.where(
        near_pole,
        abs_one_minus_z * (2 - abs_one_minus_z),
        (1 - z) * (1 + z),
    )
    sin_theta = jnp.sqrt(sin_theta_squared)

    x = sin_theta * jnp.cos(phi)
    y = sin_theta * jnp.sin(phi)
    return jnp.array([x, y, z]).T
[docs]
@partial(jit, static_argnames=['lonlat'])
def ang2vec(theta: ArrayLike, phi: ArrayLike, lonlat: bool = False) -> Array:
    """ang2vec : convert angles to 3D position vector

    Parameters
    ----------
    theta : float, scalar or array-like
        co-latitude in radians measured southward from the North pole (in [0,pi]).
    phi : float, scalar or array-like
        longitude in radians measured eastward (in [0, 2*pi]).
    lonlat : bool
        If True, input angles are assumed to be longitude and latitude in degree,
        otherwise, they are co-latitude and longitude in radians.

    Returns
    -------
    vec : float, array
        if theta and phi are vectors, the result is a 2D array with a vector per row
        otherwise, it is a 1D array of shape (3,).
        Entries whose co-latitude is below 0 or above ``pi + 1e-5`` are NaN.

    See Also
    --------
    vec2ang, rotator.dir2vec, rotator.vec2dir
    """
    # Python sequences reach a jitted function as lists of scalars, which do not
    # support comparisons or arithmetic; convert them to arrays (no-op for arrays).
    theta, phi = jnp.asarray(theta), jnp.asarray(phi)
    if lonlat:
        theta, phi = _lonlat2thetaphi(theta, phi)
    # Exceptions cannot be raised on traced values: flag invalid co-latitudes with NaN
    theta = jnp.where((theta < 0) | (theta > np.pi + 1e-5), np.nan, theta)
    sin_theta = jnp.sin(theta)
    x = sin_theta * jnp.cos(phi)
    y = sin_theta * jnp.sin(phi)
    z = jnp.cos(theta)
    return jnp.array([x, y, z]).T
[docs]
@partial(jit, static_argnames=['lonlat'])
def vec2ang(vectors: ArrayLike, lonlat: bool = False) -> tuple[Array, Array]:
    """vec2ang: vectors [x, y, z] -> theta[rad], phi[rad]

    Parameters
    ----------
    vectors : float, array-like
        the vector(s) to convert, shape is (3,) or (N, 3)
    lonlat : bool, optional
        If True, return angles will be longitude and latitude in degree,
        otherwise, angles will be co-latitude and longitude in radians (default)

    Returns
    -------
    theta, phi : float, tuple of two arrays
        the colatitude and longitude in radians, or the output of
        ``_thetaphi2lonlat(theta, phi)`` when ``lonlat`` is True.
        Each array has one entry per input vector, i.e. shape (1,) for a
        single (3,) vector. The vectors do not need to be normalized.

    See Also
    --------
    ang2vec, rotator.vec2dir, rotator.dir2vec
    """
    # jnp.asarray makes Python sequences usable too: they have no ``reshape`` method
    vectors = jnp.asarray(vectors).reshape(-1, 3)
    x, y, z = vectors[:, 0], vectors[:, 1], vectors[:, 2]
    dnorm = jnp.sqrt(x**2 + y**2 + z**2)
    theta = jnp.arccos(z / dnorm)
    phi = jnp.arctan2(y, x)
    # arctan2 returns values in [-pi, pi]; map negative longitudes to [pi, 2*pi)
    phi = jnp.where(phi < 0, phi + 2 * np.pi, phi)
    if lonlat:
        return _thetaphi2lonlat(theta, phi)
    return theta, phi
[docs]
@partial(jit, static_argnames=['nside'])
def ring2nest(nside: int, ipix: ArrayLike) -> Array:
    """Convert pixel number from RING ordering to NESTED ordering.

    Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc.

    Parameters
    ----------
    nside : int
        the healpix nside parameter
    ipix : int, scalar or array-like
        the pixel number in RING scheme

    Returns
    -------
    ipix : int, scalar or array-like
        the pixel number in NESTED scheme

    See Also
    --------
    nest2ring, reorder

    Examples
    --------
    Reference values from healpy:

    >>> import healpy as hp
    >>> hp.ring2nest(16, 1504)
    1130
    >>> print(hp.ring2nest(2, np.arange(10)))
    [ 3  7 11 15  2  1  6  5 10  9]
    """
    check_nside(nside, nest=True)
    ring_pixels = jnp.asarray(ipix)
    # widen to the pixel dtype required by this nside (int64 only if needed)
    wide_dtype = jnp.promote_types(ring_pixels.dtype, _pixel_dtype_for(nside))
    ix, iy, face = _pix2xyf_ring(nside, ring_pixels.astype(wide_dtype))
    return _xyf2pix_nest(nside, ix, iy, face)
[docs]
@partial(jit, static_argnames=['nside'])
def nest2ring(nside: int, ipix: ArrayLike) -> Array:
    """Convert pixel number from NESTED ordering to RING ordering.

    Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc.

    Parameters
    ----------
    nside : int
        the healpix nside parameter
    ipix : int, scalar or array-like
        the pixel number in NESTED scheme

    Returns
    -------
    ipix : int, scalar or array-like
        the pixel number in RING scheme

    See Also
    --------
    ring2nest, reorder

    Examples
    --------
    Reference values from healpy:

    >>> import healpy as hp
    >>> hp.nest2ring(16, 1130)
    1504
    >>> print(hp.nest2ring(2, np.arange(10)))
    [13  5  4  0 15  7  6  1 17  9]
    """
    check_nside(nside, nest=True)
    nest_pixels = jnp.asarray(ipix)
    # widen to the pixel dtype required by this nside (int64 only if needed)
    wide_dtype = jnp.promote_types(nest_pixels.dtype, _pixel_dtype_for(nside))
    ix, iy, face = _pix2xyf_nest(nside, nest_pixels.astype(wide_dtype))
    return _xyf2pix_ring(nside, ix, iy, face)
[docs]
@partial(jit, static_argnames=['inp', 'out', 'r2n', 'n2r', 'process_by_chunks'])
def reorder(
    map_in: ArrayLike,
    inp: str | None = None,
    out: str | None = None,
    r2n: bool = False,
    n2r: bool = False,
    process_by_chunks: bool = False,
) -> Array:
    """Reorder a healpix map from RING/NESTED ordering to NESTED/RING.

    Masked arrays are not yet supported.

    By default, the maps are processed in one go, but if memory is an issue,
    use the ``process_by_chunks`` option (which reproduces healpy behaviour).

    Parameters
    ----------
    map_in : array-like
        the input map(s) to reorder; the pixels are along the last axis
    inp, out : ``'RING'`` or ``'NESTED'``
        define the input and output ordering; case-insensitive, only the first
        four characters are compared
    r2n : bool
        if True, reorder from RING to NESTED
    n2r : bool
        if True, reorder from NESTED to RING
    process_by_chunks : bool
        if True, fill the output chunk by chunk inside a ``lax.fori_loop``;
        24 chunks are used when nside > 128, a single chunk otherwise

    Returns
    -------
    map_out : array-like
        the reordered map; the input is returned as is when ``inp`` and ``out``
        designate the same ordering

    Raises
    ------
    ValueError
        If the input is a scalar, if ``r2n`` and ``n2r`` are both set, or if the
        resulting orderings are not RING or NEST(ED).

    Notes
    -----
    if ``r2n`` or ``n2r`` is defined, override ``inp`` and ``out``.

    See Also
    --------
    nest2ring, ring2nest

    Examples
    --------
    Reference values from healpy:

    >>> import healpy as hp
    >>> hp.reorder(np.arange(48), r2n = True)
    array([13,  5,  4,  0, 15,  7,  6,  1, 17,  9,  8,  2, 19, 11, 10,  3, 28,
           20, 27, 12, 30, 22, 21, 14, 32, 24, 23, 16, 34, 26, 25, 18, 44, 37,
           36, 29, 45, 39, 38, 31, 46, 41, 40, 33, 47, 43, 42, 35])
    >>> hp.reorder(np.arange(12), n2r = True)
    array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
    """
    # Check input map(s)
    map_in = jnp.asarray(map_in)
    if map_in.ndim == 0:
        raise ValueError('Input map can not be a scalar')
    npix = map_in.shape[-1]
    # npix2nside already fails on a bad number of pixels, but NESTED ordering
    # additionally requires nside to be a power of 2
    nside = npix2nside(npix)
    check_nside(nside, nest=True)

    # Resolve the requested conversion
    if r2n and n2r:
        raise ValueError('r2n and n2r cannot be used simultaneously')
    if r2n:
        inp, out = 'RING', 'NEST'
    elif n2r:
        inp, out = 'NEST', 'RING'
    inp = str(inp).upper()[:4]
    out = str(out).upper()[:4]
    valid_orderings = ('RING', 'NEST')
    if inp not in valid_orderings or out not in valid_orderings:
        raise ValueError('inp and out must be either RING or NEST')
    if inp == out:
        return map_in

    # For each output pixel, look up the pixel it comes from in the input
    # scheme: the conversion is a pure gather along the last axis.
    to_input_scheme = nest2ring if inp == 'RING' else ring2nest
    pixel_dtype = _pixel_dtype_for(nside)

    def gather(ipix_out):
        return map_in[..., to_input_scheme(nside, ipix_out)]

    if not process_by_chunks:
        return gather(jnp.arange(npix, dtype=pixel_dtype))

    # To reduce memory requirements, process the map in chunks
    chunk_size = npix // 24 if nside > 128 else npix
    n_chunks = npix // chunk_size

    def fill_chunk(i, map_out):
        # interval bounds must be static, so the values are shifted afterwards
        ipix_chunk = jnp.arange(chunk_size, dtype=pixel_dtype) + i * chunk_size
        return map_out.at[..., ipix_chunk].set(gather(ipix_chunk))

    return lax.fori_loop(0, n_chunks, fill_chunk, jnp.empty_like(map_in))
[docs]
@partial(jit, static_argnames=['nside', 'nest', 'lonlat'])
def get_interp_weights(
    nside: int, theta: ArrayLike, phi: ArrayLike | None = None, nest: bool = False, lonlat: bool = False
) -> tuple[Array, Array]:
    """Return the bilinear interpolation pixels and weights for given coordinates.

    The coordinates are resolved to co-latitude and longitude in radians, then
    the four neighbouring pixels and their weights are computed by
    ``_get_interp_weights_ring``.

    Parameters
    ----------
    nside : int
        HEALPix nside parameter
    theta : ArrayLike
        Colatitude in radians, or longitude in degrees if lonlat=True,
        or pixel indices if phi is None
    phi : ArrayLike, optional
        Longitude in radians, or latitude in degrees if lonlat=True.
        If None, ``theta`` holds RING pixel indices which are converted to the
        coordinates of their centers with ``pix2ang``.
    nest : bool, optional
        Must be False: NESTED pixel ordering is not supported
    lonlat : bool, optional
        If True, interpret theta, phi as longitude, latitude in degrees.
        Only used when phi is given.

    Returns
    -------
    pixels : Array
        Array of shape (4, ...) containing the four pixel indices: two on the
        ring above the point followed by two on the ring below.
    weights : Array
        Array of shape (4, ...) containing the interpolation weights, which are
        normalized to sum to 1 for each point.

    Raises
    ------
    ValueError
        If ``nest`` is True.

    Notes
    -----
    The order of the four pixels is not guaranteed to match healpy. To compare
    with healpy, sort both outputs by pixel value:

    >>> sorted_indices = jnp.argsort(pixels, axis=0)
    >>> sorted_pixels = jnp.take_along_axis(pixels, sorted_indices, axis=0)
    >>> sorted_weights = jnp.take_along_axis(weights, sorted_indices, axis=0)

    The pixel indices are wrapped in ``lax.stop_gradient``; the weights remain
    differentiable with respect to the coordinates, e.g.:

    >>> def interpolate_map(m, theta, phi):
    ...     pixels, weights = get_interp_weights(nside, theta, phi)
    ...     return jnp.sum(weights * m[pixels], axis=0)
    >>> grad_func = jax.grad(interpolate_map, argnums=(1, 2))
    >>> grad_theta, grad_phi = grad_func(map_data, theta, phi)
    """
    check_nside(nside, nest=nest)
    if nest:
        raise ValueError('NEST pixel ordering is not supported. Only RING ordering is supported.')

    # Resolve the input mode into (co-latitude, longitude) in radians
    if phi is None:
        # theta contains pixel indices
        theta_coords, phi_coords = pix2ang(nside, theta, nest=False, lonlat=False)
    elif lonlat:
        theta_coords, phi_coords = _lonlat2thetaphi(jnp.asarray(theta), jnp.asarray(phi))
    else:
        theta_coords, phi_coords = jnp.asarray(theta), jnp.asarray(phi)

    return _get_interp_weights_ring(nside, theta_coords, phi_coords)
def _interp_ring_phi_neighbours(
    phi: Array, startpix: Array, ringpix: Array, shift: Array
) -> tuple[Array, Array, Array]:
    """Return the two pixels of a ring bracketing ``phi`` and the weight of the second one."""
    dphi = 2.0 * jnp.pi / ringpix
    # Position along the ring in units of pixels, measured from the center of
    # the first pixel and wrapped to the ring length
    position = (phi_position := phi / dphi - shift) % ringpix
    del phi_position
    # Floor + cast is non-differentiable and only used for indexing.
    # The floating-point modulo above can round up to exactly ``ringpix`` when
    # its operand is a tiny negative number, so the index is wrapped once more
    # to keep it on the ring instead of selecting the start of the next ring.
    first = lax.stop_gradient(jnp.floor(position).astype(jnp.int32)) % ringpix
    second = (first + 1) % ringpix
    # Fractional part via modulo rather than floor subtraction, for the gradient
    weight_second = position % 1.0
    return startpix + first, startpix + second, weight_second


def _get_interp_weights_ring(nside: int, theta_coords: Array, phi_coords: Array) -> tuple[Array, Array]:
    """Bilinear interpolation pixels and weights in RING ordering.

    Follows the structure of HEALPix C++ ``get_interpol``: the point is located
    between the ring above (``ir1``) and the ring below (``ir2``), two pixels
    are selected on each ring by linear interpolation in longitude, and the two
    rings are blended linearly in co-latitude. Above the first ring
    (``ir1 == 0``) and below the last one (``ir2 == 4 * nside``) the missing
    ring is replaced by the four pixels around the pole, each with a quarter of
    the polar weight.

    All cases are computed for every point and selected with ``jnp.where``.

    Parameters
    ----------
    nside : int
        HEALPix nside parameter
    theta_coords, phi_coords : Array
        Co-latitude and longitude in radians

    Returns
    -------
    pixels : Array
        Shape (4, ...): two pixels on the upper ring, then two on the lower ring.
        Wrapped in ``lax.stop_gradient`` since they are discrete selectors.
    weights : Array
        Shape (4, ...): non-negative weights normalized to sum to 1 along axis 0.
    """
    z = jnp.cos(theta_coords)
    ir1 = _ring_above(nside, z)
    ir2 = ir1 + 1

    # Special case flags
    is_north_pole = ir1 == 0
    is_south_pole = ir2 == (4 * nside)
    is_normal = ~is_north_pole & ~is_south_pole

    # Ring properties, with the ring indices clipped to the valid range [1, 4 * nside - 1]
    theta1, sp1, nr1, shift1 = _get_ring_info(nside, jnp.maximum(ir1, 1))
    theta2, sp2, nr2, shift2 = _get_ring_info(nside, jnp.minimum(ir2, 4 * nside - 1))

    # Longitude interpolation on each ring
    pixels_ring1_1, pixels_ring1_2, w_phi1 = _interp_ring_phi_neighbours(phi_coords, sp1, nr1, shift1)
    pixels_ring2_1, pixels_ring2_2, w_phi2 = _interp_ring_phi_neighbours(phi_coords, sp2, nr2, shift2)

    # Co-latitude interpolation weight of the lower ring.
    # At the poles both clipped rings coincide, so the denominator is replaced by 1.
    theta_denom = jnp.where(is_normal, theta2 - theta1, 1.0)
    w_theta_base = jnp.where(is_normal, (theta_coords - theta1) / theta_denom, 0.0)
    w_theta_north = jnp.where(is_north_pole, theta_coords / theta2, w_theta_base)
    w_theta_south = jnp.where(is_south_pole, (theta_coords - theta1) / (jnp.pi - theta1), w_theta_base)

    # Pole adjustments: the missing ring is replaced by pixels of the polar ring
    npix_total = 12 * nside * nside
    pixels_ring1_1 = jnp.where(is_north_pole, (pixels_ring2_1 + 2) & 3, pixels_ring1_1)
    pixels_ring1_2 = jnp.where(is_north_pole, (pixels_ring2_2 + 2) & 3, pixels_ring1_2)
    pixels_ring2_1 = jnp.where(is_south_pole, ((pixels_ring1_1 + 2) & 3) + npix_total - 4, pixels_ring2_1)
    pixels_ring2_2 = jnp.where(is_south_pole, ((pixels_ring1_2 + 2) & 3) + npix_total - 4, pixels_ring2_2)

    # Longitude weights
    w1_phi = 1.0 - w_phi1
    w2_phi = w_phi1
    w3_phi = 1.0 - w_phi2
    w4_phi = w_phi2

    # Normal case: blend the two rings in co-latitude
    w1_base = w1_phi * (1.0 - w_theta_base)
    w2_base = w2_phi * (1.0 - w_theta_base)
    w3_base = w3_phi * w_theta_base
    w4_base = w4_phi * w_theta_base

    # North pole: the polar weight is shared equally between the four pixels
    north_factor = (1.0 - w_theta_north) * 0.25
    w1_north = jnp.where(is_north_pole, north_factor, w1_base)
    w2_north = jnp.where(is_north_pole, north_factor, w2_base)
    w3_north = jnp.where(is_north_pole, w3_phi * w_theta_north + north_factor, w3_base)
    w4_north = jnp.where(is_north_pole, w4_phi * w_theta_north + north_factor, w4_base)

    # South pole: same treatment, mirrored
    south_factor = w_theta_south * 0.25
    w1_final = jnp.where(is_south_pole, w1_north * (1.0 - w_theta_south) + south_factor, w1_north)
    w2_final = jnp.where(is_south_pole, w2_north * (1.0 - w_theta_south) + south_factor, w2_north)
    w3_final = jnp.where(is_south_pole, south_factor, w3_north)
    w4_final = jnp.where(is_south_pole, south_factor, w4_north)

    # Pixel indices are discrete array selectors, not part of the interpolation math
    pixels = lax.stop_gradient(jnp.stack([pixels_ring1_1, pixels_ring1_2, pixels_ring2_1, pixels_ring2_2]))
    weights = jnp.stack([w1_final, w2_final, w3_final, w4_final])

    # Clamp to non-negative values (floating point precision issues), then
    # normalize so that the weights sum to 1 for each point
    weights = jnp.maximum(weights, 0.0)
    weights = weights / jnp.sum(weights, axis=0, keepdims=True)
    return pixels, weights
[docs]
@partial(jit, static_argnames=['nest', 'lonlat'])
def get_interp_val(
    m: ArrayLike, theta: ArrayLike, phi: ArrayLike | None = None, nest: bool = False, lonlat: bool = False
) -> Array:
    """Return interpolated map values at given coordinates.

    The values are the weighted sum of the four pixels returned by
    ``get_interp_weights``: ``sum(weights * m[..., pixels])``.

    Parameters
    ----------
    m : ArrayLike
        HEALPix map(s) to interpolate, with the pixels along the last axis.
        Shape: (npix,) for a single map or (nmaps, npix) for multiple maps;
        additional leading dimensions are handled the same way.
    theta : ArrayLike
        Colatitude in radians, or longitude in degrees if lonlat=True,
        or pixel indices if phi is None
    phi : ArrayLike, optional
        Longitude in radians, or latitude in degrees if lonlat=True
    nest : bool, optional
        Must be False: NESTED pixel ordering is not supported
    lonlat : bool, optional
        If True, interpret theta, phi as longitude, latitude in degrees

    Returns
    -------
    values : Array
        Interpolated map values at the given coordinates.
        Shape matches the coordinates for a single map.
        For multiple maps, shape is (nmaps, ...) where ... is the coordinates shape.

    Raises
    ------
    ValueError
        If ``nest`` is True.

    Notes
    -----
    nside is derived from the map length with ``npix2nside``, which rejects
    lengths that are not valid numbers of pixels.
    Results won't match healpy if theta and phi are not valid angles.

    Examples
    --------
    >>> import jax_healpy as hp
    >>> import jax.numpy as jnp
    >>> m = jnp.arange(12.)
    >>> value = hp.get_interp_val(m, jnp.pi/2, 0.0)
    >>> # Multiple coordinates
    >>> theta = jnp.array([jnp.pi/4, jnp.pi/2])
    >>> phi = jnp.array([0.0, jnp.pi/2])
    >>> values = hp.get_interp_val(m, theta, phi)
    >>> # Multiple maps
    >>> maps = jnp.array([jnp.arange(12.), 2*jnp.arange(12.)])
    >>> values_per_map = hp.get_interp_val(maps, jnp.pi/2, 0.0)
    """
    if nest:
        raise ValueError('NEST pixel ordering is not supported. Only RING ordering is supported.')

    # Convert inputs to JAX arrays
    m = jnp.asarray(m)
    theta = jnp.asarray(theta)
    if phi is not None:
        phi = jnp.asarray(phi)

    # The map length is a static shape, so nside is a plain Python int even
    # under jit. npix2nside validates the length instead of silently truncating.
    nside = npix2nside(m.shape[-1])
    check_nside(nside, nest=nest)

    # pixels and weights both have shape (4, ...)
    pixels, weights = get_interp_weights(nside, theta, phi, nest=nest, lonlat=lonlat)

    # Indexing the last axis gives shape (*leading, 4, ...); the weights
    # broadcast against the trailing (4, ...) dimensions.
    map_values = m[..., pixels]
    # The axis of the four neighbours comes right after the leading map dimensions
    return jnp.sum(weights * map_values, axis=m.ndim - 1)
[docs]
@partial(jit, static_argnames=['nside', 'nest', 'lonlat', 'get_center'])
def get_all_neighbours(
    nside: int,
    theta: ArrayLike,
    phi: ArrayLike | None = None,
    nest: bool = False,
    lonlat: bool = False,
    get_center: bool = False,
) -> Array:
    """Get the 8 nearest neighbors of given pixels, optionally including center pixel.

    Parameters
    ----------
    nside : int
        HEALPix resolution parameter, must be a power of 2
    theta : ArrayLike
        Either colatitude in radians (if phi is provided) or pixel indices (if phi is None)
    phi : ArrayLike, optional
        Longitude in radians (or degrees if lonlat=True). If None, theta is treated as pixel indices.
    nest : bool, optional
        If True, use NESTED pixel ordering scheme. Default is False (RING ordering).
    lonlat : bool, optional
        If True and phi is provided, interpret (theta, phi) as (longitude, latitude) in degrees.
    get_center : bool, optional
        If True, return center pixel + 8 neighbors (9 total). If False, return only 8 neighbors.
        Default is False, as healpy has no such option.

    Returns
    -------
    neighbors : Array
        Array of pixel indices. When get_center=False: shape is (8,) for scalar input or
        (8, *input_shape) for array input, with neighbors in directions
        [SW, W, NW, N, NE, E, SE, S].
        When get_center=True: shape is (9,) for scalar input or (9, *input_shape) for
        array input, with pixels in order [CENTER, SW, W, NW, N, NE, E, SE, S].
        Non-existent neighbors (at map boundaries) are marked with -1.

    Examples
    --------
    >>> import jax_healpy as hp
    >>> # Get 8 neighbors of pixel 4 at nside=1 (default behavior)
    >>> neighbors = hp.get_all_neighbours(1, 4)
    >>> print(neighbors)
    [11  7  3 -1  0  5  8 -1]
    >>> # Get center + 8 neighbors (9 total)
    >>> neighbors_with_center = hp.get_all_neighbours(1, 4, get_center=True)
    >>> print(neighbors_with_center)
    [ 4 11  7  3 -1  0  5  8 -1]
    >>> # Works with angular coordinates too
    >>> import jax.numpy as jnp
    >>> neighbors = hp.get_all_neighbours(1, jnp.pi/2, jnp.pi/2, get_center=True)

    Notes
    -----
    ``get_center`` is a jax-healpy extension that does not exist in healpy. The
    center pixel is always the first element, and the neighbors follow in the
    healpy order from index 1. In angular mode, the center pixel is the pixel
    containing the given coordinates.

    ``nside``, ``nest``, ``lonlat`` and ``get_center`` are static jit arguments:
    each combination of values is compiled separately.
    """
    check_nside(nside, nest=nest)
    theta = jnp.asarray(theta)

    # Resolve the two API modes into a flat array of center pixels
    if phi is None:
        # theta contains pixel indices
        ipix = theta.astype(_pixel_dtype_for(nside))
        input_shape = ipix.shape
        ipix_flat = ipix.flatten()
    else:
        phi = jnp.asarray(phi)
        if lonlat:
            # (longitude, latitude) in degrees -> (colatitude, longitude) in radians
            lon, lat = theta, phi
            theta = jnp.deg2rad(90.0 - lat)
            phi = jnp.deg2rad(lon)
        theta_bc, phi_bc = jnp.broadcast_arrays(theta, phi)
        input_shape = theta_bc.shape
        ipix_flat = ang2pix(nside, theta_bc.flatten(), phi_bc.flatten(), nest=nest)

    # Neighbors are found in (x, y, face) coordinates; result has shape (8, N)
    ix, iy, face_num = pix2xyf(nside, ipix_flat, nest=nest)
    neighbors_flat = _get_all_neighbors_xyf(nside, ix, iy, face_num, nest=nest)

    if get_center:
        # In both modes the center pixels are ipix_flat: prepend them -> (9, N)
        result_flat = jnp.concatenate([ipix_flat[None, :], neighbors_flat], axis=0)
        n_out = 9
    else:
        result_flat = neighbors_flat
        n_out = 8

    # For scalar input, input_shape is () and this drops the pixel axis: (n_out, 1) -> (n_out,)
    return result_flat.reshape((n_out,) + input_shape)
def _get_all_neighbors_xyf(nside: int, ix: Array, iy: Array, face_num: Array, nest: bool = False) -> Array:
    """Find the eight neighbours of pixels given in (x, y, face) coordinates.

    Every input pixel is shifted by the eight direction offsets stored in
    ``_NB_XOFFSET`` / ``_NB_YOFFSET``. Shifted positions that still lie inside
    the pixel's own face are converted straight to pixel indices with
    :func:`xyf2pix`; positions that leave the face are resolved by
    :func:`_handle_face_boundaries`.

    Parameters
    ----------
    nside : int
        HEALPix resolution parameter.
    ix, iy : Array
        Face-local x and y coordinates of the pixels, shape ``(N,)``.
        In-face positions satisfy ``0 <= value < nside``.
    face_num : Array
        Face number of each pixel, shape ``(N,)``.
    nest : bool, optional
        Forwarded to :func:`xyf2pix` and :func:`_handle_face_boundaries` to
        select NESTED (True) or RING (False, default) pixel numbering.

    Returns
    -------
    neighbors : Array
        Neighbour pixel indices, shape ``(8, N)``, one row per direction in the
        order [SW, W, NW, N, NE, E, SE, S]. Directions for which
        :func:`_handle_face_boundaries` finds no neighbour hold -1.
    """
    n_pixels = ix.shape[0]

    # Candidate neighbour positions: one row per direction, one column per pixel.
    cand_x = ix[None, :] + _NB_XOFFSET[:, None]
    cand_y = iy[None, :] + _NB_YOFFSET[:, None]
    cand_face = jnp.broadcast_to(face_num[None, :], (8, n_pixels))

    # A candidate stays on the source face when both coordinates are in [0, nside).
    x_inside = (cand_x >= 0) & (cand_x < nside)
    y_inside = (cand_y >= 0) & (cand_y < nside)
    inside = x_inside & y_inside

    # Both conversions are evaluated for every candidate; the mask below picks
    # the one that applies, so out-of-range inputs to xyf2pix are never returned.
    same_face_pixels = xyf2pix(nside, cand_x, cand_y, cand_face, nest=nest)
    crossed_pixels = _handle_face_boundaries(nside, cand_x, cand_y, cand_face, face_num, nest)

    return jnp.where(inside, same_face_pixels, crossed_pixels)
def _handle_face_boundaries(
    nside: int, neighbor_ix: Array, neighbor_iy: Array, neighbor_face: Array, original_face: Array, nest: bool
) -> Array:
    """Resolve neighbour coordinates that fall outside their source face.

    All eight directions are processed in a single vectorised pass over the
    ``(8, N)`` inputs (no per-direction Python loop, no scatter updates).
    For every candidate position:

    1. Detect which edge(s) were crossed (x < 0, x >= nside, y < 0, y >= nside).
    2. Wrap the offending coordinate back into ``[0, nside)``.
    3. Encode the crossing as ``nbnum`` in 0..8 (4 means "no crossing";
       an x crossing contributes -1/+1 and a y crossing contributes -3/+3).
    4. Look up the destination face in ``_NB_FACEARRAY[nbnum, face]``.
    5. Look up the transform bits in ``_NB_SWAPARRAY[nbnum, face >> 2]`` and
       apply them in order: mask 1 flips x, mask 2 flips y, mask 4 swaps x/y.
    6. Convert the resulting (x, y, face) to a pixel index with :func:`xyf2pix`.

    Parameters
    ----------
    nside : int
        HEALPix resolution parameter.
    neighbor_ix, neighbor_iy : Array
        Candidate neighbour coordinates, possibly outside ``[0, nside)``.
        Shape ``(8, N)``.
    neighbor_face : Array
        Accepted for interface compatibility; not used by this function
        (the source face is taken from ``original_face``).
    original_face : Array
        Face number of each source pixel, shape ``(N,)``.
    nest : bool
        Forwarded to :func:`xyf2pix` to select NESTED or RING numbering.

    Returns
    -------
    corrected_neighbors : Array
        Shape ``(8, N)``, dtype ``_pixel_dtype_for(nside)``. Holds the neighbour
        pixel index where the candidate crossed a face edge and the lookup
        table provides a destination face in 0..11; holds -1 everywhere else
        (including candidates that did not cross any edge).
    """
    # Source face broadcast against the direction axis.
    src_face = original_face[None, :]

    x_low = neighbor_ix < 0
    x_high = neighbor_ix >= nside
    y_low = neighbor_iy < 0
    y_high = neighbor_iy >= nside
    crossing = x_low | x_high | y_low | y_high

    # Wrap coordinates back into the valid range. "low" and "high" are mutually
    # exclusive for a given coordinate, so nesting the selects is equivalent to
    # applying them one after the other.
    wrapped_x = jnp.where(x_low, neighbor_ix + nside, jnp.where(x_high, neighbor_ix - nside, neighbor_ix))
    wrapped_y = jnp.where(y_low, neighbor_iy + nside, jnp.where(y_high, neighbor_iy - nside, neighbor_iy))

    # Lookup-table row: start from the centre entry (4) and move by the crossing direction.
    x_step = jnp.where(x_low, -1, jnp.where(x_high, 1, 0))
    y_step = jnp.where(y_low, -1, jnp.where(y_high, 1, 0))
    nbnum = 4 + x_step + 3 * y_step

    new_face = _NB_FACEARRAY[nbnum, src_face]
    swap_bits = _NB_SWAPARRAY[nbnum, src_face >> 2]

    # Only genuine crossings that land on an existing face are transformed and returned.
    valid = crossing & (new_face >= 0) & (new_face < 12)

    # Transform order matters: flip x, flip y, then swap the (already flipped) axes.
    flipped_x = jnp.where(valid & ((swap_bits & 1) != 0), nside - wrapped_x - 1, wrapped_x)
    flipped_y = jnp.where(valid & ((swap_bits & 2) != 0), nside - wrapped_y - 1, wrapped_y)
    do_swap = valid & ((swap_bits & 4) != 0)
    final_x = jnp.where(do_swap, flipped_y, flipped_x)
    final_y = jnp.where(do_swap, flipped_x, flipped_y)

    # Keep the source face for masked-out entries so xyf2pix always receives a real face number.
    final_face = jnp.where(valid, new_face, src_face)

    pixels = xyf2pix(nside, final_x, final_y, final_face, nest=nest)
    return jnp.where(valid, pixels, -1).astype(_pixel_dtype_for(nside))
# Note: Removed _get_adjacent_face - using lookup table directly in _handle_face_boundaries
[docs]
def get_nside(m: ArrayLike) -> int:
    """Return the nside parameter implied by the number of pixels in a map.

    Parameters
    ----------
    m : array-like
        A single HEALPix map (1D) or a stack of maps (2D) whose last axis
        is the pixel axis.

    Returns
    -------
    nside : int
        Result of :func:`npix2nside` applied to the pixel count.

    Raises
    ------
    ValueError
        If ``m`` is neither 1D nor 2D. Any error for a pixel count that does
        not match a HEALPix map would come from :func:`npix2nside`.
    """
    arr = jnp.asarray(m)
    if arr.ndim not in (1, 2):
        raise ValueError(f'Map must be 1D or 2D, got shape {arr.shape}')
    # For a 1D map the only axis is the pixel axis; for 2D it is the last one.
    return npix2nside(arr.shape[-1])
def mask_bad(m: ArrayLike) -> Array:
    """Flag the pixels of a map that hold the UNSEEN sentinel value.

    The comparison is an exact element-wise equality test against ``UNSEEN``.

    Parameters
    ----------
    m : array-like
        HEALPix map (any shape; the test is element-wise).

    Returns
    -------
    mask : Array
        Boolean array of the same shape as ``m``, True where ``m`` equals UNSEEN.
    """
    return jnp.equal(jnp.asarray(m), UNSEEN)
[docs]
@partial(jit, static_argnames=['nside_out', 'pess', 'order_in', 'order_out', 'power', 'dtype'])
def ud_grade(
    map_in: ArrayLike,
    nside_out: int,
    pess: bool = False,
    order_in: str = 'RING',
    order_out: str = None,
    power: float = None,
    dtype: type = None,
) -> Array:
    """Upgrade or degrade the resolution (nside) of a map.

    The input is validated here and the resolution change itself is delegated
    to :func:`_ud_grade_core`:

    - upgrading replicates each parent pixel value into all of its children;
    - degrading sets each parent pixel to the average of its valid children.

    All arguments except ``map_in`` are static under ``jit``: a new combination
    of them triggers a re-trace, and the validation below runs at trace time.

    Parameters
    ----------
    map_in : array-like
        A single map of shape ``(npix,)`` or several maps of shape ``(n_maps, npix)``.
    nside_out : int
        Desired output resolution parameter. Must be a power of 2.
    pess : bool, optional
        Pessimistic mask handling during degradation. If True, a parent pixel is
        marked as UNSEEN if any of its children is bad. If False (default),
        a parent pixel is UNSEEN only if all of its children are bad.
    order_in : {'RING', 'NESTED'}, optional
        Pixel ordering of the input map. Default is 'RING'.
    order_out : {'RING', 'NESTED'} or None, optional
        Pixel ordering of the output map. If None (default), same as ``order_in``.
    power : float or None, optional
        If given, output values are multiplied by ``(nside_out / nside_in) ** power``.
    dtype : data type or None, optional
        Data type of the output map. If None (default), the input dtype is kept.

    Returns
    -------
    map_out : Array
        Map(s) at resolution ``nside_out``: shape ``(npix_out,)`` for a 1D input,
        ``(n_maps, npix_out)`` for a 2D input.

    Raises
    ------
    ValueError
        If ``nside_out`` or the nside implied by the input map size is rejected
        by ``isnsideok(..., nest=True)``, or if ``map_in`` is not 1D or 2D.

    Examples
    --------
    >>> import jax_healpy as jhp
    >>> import numpy as np
    >>> # Create a simple map at nside=4
    >>> nside_in = 4
    >>> map_in = np.arange(jhp.nside2npix(nside_in), dtype=float)
    >>> # Degrade to nside=2
    >>> map_out = jhp.ud_grade(map_in, 2)
    >>> # Upgrade to nside=8
    >>> map_out = jhp.ud_grade(map_in, 8)

    Notes
    -----
    This function can create artifacts in power spectra and should be used with
    caution for scientific applications. The HEALPix documentation recommends
    using spherical harmonic transforms for resolution changes when possible.

    The interface is modelled on ``healpy.ud_grade``.
    """
    # The resolution change works on the NESTED pixel hierarchy, so nside must be
    # a power of 2 whatever the requested ordering is.
    if not isnsideok(nside_out, nest=True):
        raise ValueError(
            f'{nside_out} is not a valid nside parameter for udgrade (must be a power of 2, less than 2**30)'
        )

    map_in = jnp.asarray(map_in)
    # Reject scalars and >2D stacks up front instead of failing later with an
    # unrelated indexing/reshape error.
    if map_in.ndim not in (1, 2):
        raise ValueError(f'map_in must be a 1D map or a 2D (n_maps, npix) array, got shape {map_in.shape}')
    is_single_map = map_in.ndim == 1

    # The core routine always works on a 2D (n_maps, npix) array.
    maps = map_in[None, :] if is_single_map else map_in

    # get_nside reads the pixel count from the last axis.
    nside_in = get_nside(maps)
    if not isnsideok(nside_in, nest=True):
        raise ValueError(
            f'{nside_in} is not a valid nside parameter for udgrade (must be a power of 2, less than 2**30)'
        )

    if order_out is None:
        order_out = order_in

    return _ud_grade_core(maps, nside_in, nside_out, pess, order_in, order_out, power, dtype, is_single_map)
def _ud_grade_core(
    maps: Array,
    nside_in: int,
    nside_out: int,
    pess: bool,
    order_in: str,
    order_out: str,
    power: float,
    dtype: type,
    is_single_map: bool,
) -> Array:
    """Change the resolution of a batch of maps.

    The maps are brought to NESTED ordering, where the ``rat2`` children of a
    parent pixel are contiguous, upgraded or degraded, and finally converted to
    the requested output ordering.

    Parameters
    ----------
    maps : Array
        Input maps, shape ``(n_maps, npix_in)``.
    nside_in, nside_out : int
        Input and output resolution parameters (static Python integers).
    pess : bool
        Degrade only. If True a parent is UNSEEN as soon as one child is bad;
        if False only when all children are bad.
    order_in, order_out : str
        Pixel orderings. A value whose upper-cased form starts with ``'RING'``
        means RING; anything else is treated as NESTED. ``order_out=None``
        means "same as ``order_in``".
    power : float or None
        If not None, values are scaled by ``(nside_out / nside_in) ** power``
        when the resolution changes.
    dtype : data type or None
        Output dtype; None keeps ``maps.dtype``.
    is_single_map : bool
        If True the leading batch axis is dropped from the result.

    Returns
    -------
    Array
        Shape ``(npix_out,)`` if ``is_single_map`` else ``(n_maps, npix_out)``.
    """
    if order_out is None:
        order_out = order_in
    # Normalise the ordering flags once so that spellings such as 'ring' or 'NEST'
    # are handled the same way on input and on output.
    ring_in = str(order_in).upper().startswith('RING')
    ring_out = str(order_out).upper().startswith('RING')

    npix_in = nside2npix(nside_in)
    npix_out = nside2npix(nside_out)
    output_dtype = maps.dtype if dtype is None else dtype

    # Computed with Python floats so the factor is not rounded to float32 before
    # being applied to higher-precision maps.
    ratio = None if power is None else (float(nside_out) / float(nside_in)) ** power

    # Step 1: work in NESTED ordering.
    if ring_in:
        maps = reorder(maps, r2n=True)

    # Step 2: resolution change.
    if nside_out == nside_in:
        result = maps
    elif nside_out > nside_in:
        # UPGRADE: copy each parent value to its rat2 consecutive children.
        # Plain replication keeps the values exact (integer maps never go
        # through a float multiplication unless a power scaling is requested).
        rat2 = npix_out // npix_in
        result = jnp.repeat(maps, rat2, axis=-1)  # (n_maps, npix_out)
        if ratio is not None:
            result = result * ratio
    else:
        # DEGRADE: average the valid children of each parent.
        rat2 = npix_in // npix_out
        children = maps.reshape(maps.shape[0], npix_out, rat2)

        # A child is usable when it is neither UNSEEN nor NaN/inf.
        goods = ~mask_bad(children) & jnp.isfinite(children)

        # Select rather than multiply by the mask: NaN * 0 and inf * 0 are NaN,
        # which would let a single bad child poison the whole sum.
        map_sum = jnp.sum(jnp.where(goods, children, 0), axis=-1)  # (n_maps, npix_out)
        n_good = jnp.sum(goods, axis=-1)  # (n_maps, npix_out)

        if pess:
            badout = n_good != rat2
        else:
            badout = n_good == 0

        # Parents without any valid child are overwritten with UNSEEN below
        # (n_good == 0 implies badout in both modes), so a dummy denominator
        # of 1 is safe and avoids a division by zero.
        mean = map_sum / jnp.maximum(n_good, 1)
        if ratio is not None:
            mean = mean * ratio

        result = jnp.where(badout, UNSEEN, mean)

    # Step 3: the data is NESTED at this point; convert only if RING is requested.
    if ring_out:
        result = reorder(result, n2r=True)

    result = result.astype(output_dtype)
    return result[0] if is_single_map else result