Pixel Functions (pixelfunc)
This module provides functions related to the HEALPix pixelization scheme, including coordinate conversions, interpolation, and map manipulation.
- jax_healpy.pixelfunc.pix2ang(nside, ipix, nest=False, lonlat=False)
pix2ang : nside,ipix,nest=False,lonlat=False -> theta[rad],phi[rad] (default RING)
- Parameters:
nside (int or array-like) – The healpix nside parameter, must be a power of 2, less than 2**30
ipix (int or array-like) – Pixel indices
nest (bool, optional) – if True, assume NESTED pixel ordering, otherwise, RING pixel ordering
lonlat (bool, optional) – If True, the returned angles are longitude and latitude in degrees; otherwise, they are co-latitude and longitude in radians (default)
- Returns:
theta, phi – The angular coordinates corresponding to ipix. Scalar if all inputs are scalar, array otherwise. Usual numpy broadcasting rules apply.
- Return type:
float, scalar or array-like
Examples
>>> import jax_healpy as hp
>>> hp.pix2ang(16, 1440)
(1.5291175943723188, 0.0)
>>> hp.pix2ang(16, [1440, 427, 1520, 0, 3068])
(array([ 1.52911759, 0.78550497, 1.57079633, 0.05103658, 3.09055608]), array([ 0. , 0.78539816, 1.61988371, 0.78539816, 0.78539816]))
>>> hp.pix2ang([1, 2, 4, 8], 11)
(array([ 2.30052398, 0.84106867, 0.41113786, 0.2044802 ]), array([ 5.49778714, 5.89048623, 5.89048623, 5.89048623]))
>>> hp.pix2ang([1, 2, 4, 8], 11, lonlat=True)
(array([ 315. , 337.5, 337.5, 337.5]), array([-41.8103149 , 41.8103149 , 66.44353569, 78.28414761]))
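As a rough illustration of the RING-scheme geometry behind pix2ang, the sketch below computes (theta, phi) for a single pixel in plain Python. It follows the standard HEALPix relations for the north polar cap, the equatorial belt and the south polar cap as found in the HEALPix C++ library; pix2ang_ring is a hypothetical name, and this is not jax_healpy's vectorized implementation.

```python
import math


def pix2ang_ring(nside, ipix):
    """Scalar sketch of RING-scheme pix2ang; returns (theta, phi) in radians."""
    npix = 12 * nside * nside
    ncap = 2 * nside * (nside - 1)  # number of pixels in the north polar cap
    if not 0 <= ipix < npix:
        raise ValueError("pixel index out of range")
    if ipix < ncap:
        # North polar cap: ring i (counted from the pole) holds 4*i pixels.
        iring = (1 + math.isqrt(1 + 2 * ipix)) // 2
        iphi = ipix + 1 - 2 * iring * (iring - 1)
        z = 1.0 - iring * iring / (3.0 * nside * nside)
        phi = (iphi - 0.5) * (math.pi / 2) / iring
    elif ipix < npix - ncap:
        # Equatorial belt: every ring holds 4*nside pixels; every other ring is shifted.
        ip = ipix - ncap
        iring = ip // (4 * nside) + nside
        iphi = ip % (4 * nside) + 1
        fodd = 1.0 if (iring + nside) % 2 else 0.5
        z = (2 * nside - iring) * 2.0 / (3.0 * nside)
        phi = (iphi - fodd) * math.pi / (2 * nside)
    else:
        # South polar cap: mirror image of the north cap, counted from the south pole.
        ip = npix - ipix
        iring = (1 + math.isqrt(2 * ip - 1)) // 2
        iphi = 4 * iring + 1 - (ip - 2 * iring * (iring - 1))
        z = -1.0 + iring * iring / (3.0 * nside * nside)
        phi = (iphi - 0.5) * (math.pi / 2) / iring
    return math.acos(z), phi


print(pix2ang_ring(16, 1440))
```

For example, pix2ang_ring(16, 1440) returns approximately (1.52912, 0.0), matching the first example above.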
- jax_healpy.pixelfunc.ang2pix(nside, theta, phi, nest=False, lonlat=False)
ang2pix: nside,theta[rad],phi[rad],nest=False,lonlat=False -> ipix (default:RING)
Unlike healpy.ang2pix, specifying a theta not in the range [0, π] does not raise an error, but returns -1.
- Parameters:
nside (int, scalar or array-like) – The healpix nside parameter, must be a power of 2, less than 2**30
theta (float, scalars or array-like) – Angular coordinates of a point on the sphere
phi (float, scalars or array-like) – Angular coordinates of a point on the sphere
nest (bool, optional) – if True, assume NESTED pixel ordering, otherwise, RING pixel ordering
lonlat (bool) – If True, input angles are assumed to be longitude and latitude in degrees, otherwise, they are co-latitude and longitude in radians.
- Returns:
pix – The healpix pixel numbers. Scalar if all inputs are scalar, array otherwise. Usual numpy broadcasting rules apply.
- Return type:
int, scalar or array-like
Examples
Note that some of the test inputs below that lie on pixel boundaries, such as theta=π/2, phi=π/2, have a tiny value of 1e-15 added to them to make them reproducible on i386 machines using the x87 floating-point instruction set (see healpy/healpy#528).
>>> import numpy as np
>>> import jax_healpy as hp
>>> from jax.numpy import pi as π
>>> hp.ang2pix(16, π/2, 0)
Array(1440, dtype=int64)
>>> print(hp.ang2pix(16, np.array([π/2, π/4, π/2, 0, π]), np.array([0., π/4, π/2 + 1e-15, 0, 0])))
[1440 427 1520 0 3068]
>>> print(hp.ang2pix(16, π/2, np.array([0, π/2 + 1e-15])))
[1440 1520]
>>> print(hp.ang2pix(np.array([1, 2, 4, 8, 16]), π/2, 0))
[ 4 12 72 336 1440]
>>> print(hp.ang2pix(np.array([1, 2, 4, 8, 16]), 0, 0, lonlat=True))
[ 4 12 72 336 1440]
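The inverse mapping can be sketched the same way. The hypothetical ang2pix_ring below transcribes the RING-scheme search of the HEALPix C++ library (without its extra-precision branch very close to the poles) and returns -1 for theta outside [0, π], mirroring the jax_healpy behaviour described above. It is a scalar illustration, not the library's implementation.

```python
import math


def ang2pix_ring(nside, theta, phi):
    """Scalar sketch of RING-scheme ang2pix; returns -1 for theta outside [0, pi]."""
    if not 0.0 <= theta <= math.pi:
        return -1  # mirrors the jax_healpy behaviour instead of raising
    npix = 12 * nside * nside
    ncap = 2 * nside * (nside - 1)
    z = math.cos(theta)
    za = abs(z)
    tt = (phi / (math.pi / 2)) % 4.0  # longitude in units of pi/2, in [0, 4)
    if za <= 2.0 / 3.0:
        # Equatorial belt: locate the pixel between two families of edge lines.
        temp1 = nside * (0.5 + tt)
        temp2 = nside * z * 0.75
        jp = int(temp1 - temp2)  # index of the ascending edge line
        jm = int(temp1 + temp2)  # index of the descending edge line
        ir = nside + 1 + jp - jm  # ring number counted from z = 2/3, in 1..2*nside+1
        kshift = 1 - (ir & 1)  # 1 if ir is even, 0 otherwise
        ip = ((jp + jm - nside + kshift + 1) // 2) % (4 * nside)
        return ncap + (ir - 1) * 4 * nside + ip
    # Polar caps (healpy switches to a more accurate formula very near the poles).
    tp = tt - int(tt)
    tmp = nside * math.sqrt(3.0 * (1.0 - za))
    jp = int(tp * tmp)  # increasing edge line index
    jm = int((1.0 - tp) * tmp)  # decreasing edge line index
    ir = jp + jm + 1  # ring number counted from the nearest pole
    ip = int(tt * ir) % (4 * ir)
    if z > 0:
        return 2 * ir * (ir - 1) + ip
    return npix - 2 * ir * (ir + 1) + ip


print([ang2pix_ring(n, math.pi / 2, 0.0) for n in [1, 2, 4, 8, 16]])
```

Because cos(π/2) evaluates to about 6.1e-17 rather than 0 in double precision, the boundary point (π/2, 0) is assigned to pixel 1440 rather than 1504, which is consistent with hp.ang2pix(16, π/2, 0) above and with hp.vec2pix(16, 1, 0, 0) returning 1504 further down.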
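- jax_healpy.pixelfunc.pix2xyf(nside, ipix, nest=False)
pix2xyf : nside,ipix,nest=False -> x,y,face (default RING)
Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc.
- Parameters:
nside (int) – The healpix nside parameter, must be a power of 2
ipix (int, scalars or array-like) – Pixel indices
nest (bool, optional) – if True, assume NESTED pixel ordering, otherwise, RING pixel ordering
- Returns:
x, y (int, scalars or array-like) – Pixel indices within face
face (int, scalars or array-like) – Face number
- Return type:
int, scalars or array-like
See also
xyf2pix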
Examples
>>> import healpy as hp
>>> hp.pix2xyf(16, 1440)
(8, 8, 4)
>>> hp.pix2xyf(16, [1440, 427, 1520, 0, 3068])
(array([ 8, 8, 8, 15, 0]), array([ 8, 8, 7, 15, 0]), array([4, 0, 5, 0, 8]))
The following example passes a list as nside, which works in healpy but not in jax_healpy (nside must be an int):
>>> hp.pix2xyf([1, 2, 4, 8], 11)
(array([0, 1, 3, 7]), array([0, 0, 2, 6]), array([11, 3, 3, 3]))
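To see what (x, y, face) means for a RING index, the following plain-Python sketch transcribes the RING-to-face bookkeeping of the HEALPix C++ library as we understand it. The helper name pix2xyf_ring and the JRLL/JPLL face tables are assumptions of this sketch rather than jax_healpy API, and it handles a single int pixel only.

```python
import math

# Face layout tables as used in the HEALPix C++ library: JRLL[f] * nside is the ring
# index of face f's southern vertex; JPLL[f] * pi/4 is the longitude of face f's center.
JRLL = (2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4)
JPLL = (1, 3, 5, 7, 0, 2, 4, 6, 1, 3, 5, 7)


def pix2xyf_ring(nside, pix):
    """Scalar sketch of RING-scheme pix2xyf: returns (x, y, face)."""
    npix = 12 * nside * nside
    ncap = 2 * nside * (nside - 1)
    nl2 = 2 * nside
    if pix < ncap:  # north polar cap
        iring = (1 + math.isqrt(1 + 2 * pix)) // 2
        iphi = pix + 1 - 2 * iring * (iring - 1)
        kshift, nr = 0, iring
        face = (iphi - 1) // nr
    elif pix < npix - ncap:  # equatorial belt
        ip = pix - ncap
        tmp = ip // (4 * nside)
        iring = tmp + nside
        iphi = ip - tmp * 4 * nside + 1
        kshift, nr = (iring + nside) & 1, nside
        ire, irm = tmp + 1, nl2 + 1 - tmp
        ifm = (iphi - (ire >> 1) + nside - 1) // nside
        ifp = (iphi - (irm >> 1) + nside - 1) // nside
        if ifp == ifm:
            face = ifp | 4  # equatorial face
        elif ifp < ifm:
            face = ifp  # northern face
        else:
            face = ifm + 8  # southern face
    else:  # south polar cap
        ip = npix - pix
        iring = (1 + math.isqrt(2 * ip - 1)) // 2
        iphi = 4 * iring + 1 - (ip - 2 * iring * (iring - 1))
        kshift, nr = 0, iring
        iring = 2 * nl2 - iring  # ring index counted from the north pole
        face = (iphi - 1) // nr + 8
    irt = iring - (2 + (face >> 2)) * nside + 1
    ipt = 2 * iphi - JPLL[face] * nr - kshift - 1
    if ipt >= nl2:
        ipt -= 8 * nside
    return (ipt - irt) >> 1, (-ipt - irt) >> 1, face


print(pix2xyf_ring(16, 1440))
```

It reproduces the example values above, e.g. pix2xyf_ring(16, 1440) == (8, 8, 4) and pix2xyf_ring(2, 11) == (1, 0, 3).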
- jax_healpy.pixelfunc.xyf2pix(nside, x, y, face, nest=False)
xyf2pix : nside,x,y,face,nest=False -> ipix (default:RING)
Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc.
- Parameters:
nside (int) – The healpix nside parameter, must be a power of 2
x (int, scalars or array-like) – Pixel indices within face
y (int, scalars or array-like) – Pixel indices within face
face (int, scalars or array-like) – Face number
nest (bool, optional) – if True, assume NESTED pixel ordering, otherwise, RING pixel ordering
- Returns:
pix – The healpix pixel numbers. Scalar if all inputs are scalar, array otherwise. Usual numpy broadcasting rules apply.
- Return type:
int, scalar or array-like
See also
pix2xyf
Examples
>>> import healpy as hp
>>> hp.xyf2pix(16, 8, 8, 4)
1440
>>> print(hp.xyf2pix(16, [8, 8, 8, 15, 0], [8, 8, 7, 15, 0], [4, 0, 5, 0, 8]))
[1440 427 1520 0 3068]
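Going the other way, a RING index can be rebuilt from (x, y, face) by locating the ring first and then the position within it. The sketch below is a plain-Python transcription (as we understand it) of the corresponding HEALPix C++ routine; xyf2pix_ring, ring_info and the JRLL/JPLL tables are hypothetical names, not jax_healpy API.

```python
# Face layout tables (same convention as the HEALPix C++ library).
JRLL = (2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4)
JPLL = (1, 3, 5, 7, 0, 2, 4, 6, 1, 3, 5, 7)


def ring_info(nside, ring):
    """Return (first pixel, number of pixels, shifted?) for a ring index in 1..4*nside-1."""
    npix = 12 * nside * nside
    ncap = 2 * nside * (nside - 1)
    if ring < nside:  # north polar cap
        return 2 * ring * (ring - 1), 4 * ring, True
    if ring < 3 * nside:  # equatorial belt: every other ring is shifted by half a pixel
        return ncap + (ring - nside) * 4 * nside, 4 * nside, (ring - nside) % 2 == 0
    nr = 4 * nside - ring  # south polar cap, counted from the south pole
    return npix - 2 * nr * (nr + 1), 4 * nr, True


def xyf2pix_ring(nside, x, y, face):
    """Scalar sketch of RING-scheme xyf2pix."""
    jr = JRLL[face] * nside - x - y - 1  # ring index, counted from the north pole
    startpix, ringpix, shifted = ring_info(nside, jr)
    nr = ringpix // 4
    kshift = 0 if shifted else 1
    jp = (JPLL[face] * nr + x - y + 1 + kshift) // 2  # 1-based position in the ring
    if jp < 1:
        jp += 4 * nside  # wrap around phi = 0
    return startpix + jp - 1


print(xyf2pix_ring(16, 8, 8, 4))
```

With these definitions, xyf2pix_ring(16, 8, 8, 4) gives 1440, and the five-pixel example above gives [1440, 427, 1520, 0, 3068].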
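- jax_healpy.pixelfunc.pix2vec(nside, ipix, nest=False)
pix2vec : nside,ipix,nest=False -> x,y,z (default RING)
- Parameters:
nside (int or array-like) – The healpix nside parameter, must be a power of 2, less than 2**30
ipix (int or array-like) – Pixel indices
nest (bool, optional) – if True, assume NESTED pixel ordering, otherwise, RING pixel ordering
- Returns:
x, y, z – The coordinates of the unit vectors corresponding to the input pixels. Scalar if all inputs are scalar, array otherwise. Usual numpy broadcasting rules apply.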
- Return type:
floats, scalar or array-like
Examples
>>> import healpy as hp
>>> hp.pix2vec(16, 1504)
(0.99879545620517241, 0.049067674327418015, 0.0)
>>> hp.pix2vec(16, [1440, 427])
(array([ 0.99913157, 0.5000534 ]), array([ 0. , 0.5000534]), array([ 0.04166667, 0.70703125]))
>>> hp.pix2vec([1, 2], 11)
(array([ 0.52704628, 0.68861915]), array([-0.52704628, -0.28523539]), array([-0.66666667, 0.66666667]))
- jax_healpy.pixelfunc.vec2pix(nside, x, y, z, nest=False)
vec2pix : nside,x,y,z,nest=False -> ipix (default:RING)
- Parameters:
nside (int or array-like) – The healpix nside parameter, must be a power of 2, less than 2**30
x (floats or array-like) – vector coordinates defining point on the sphere
y (floats or array-like) – vector coordinates defining point on the sphere
z (floats or array-like) – vector coordinates defining point on the sphere
nest (bool, optional) – if True, assume NESTED pixel ordering, otherwise, RING pixel ordering
- Returns:
ipix – The healpix pixel number corresponding to the input vector. Scalar if all inputs are scalar, array otherwise. Usual numpy broadcasting rules apply.
- Return type:
int, scalar or array-like
Examples
>>> import healpy as hp
>>> hp.vec2pix(16, 1, 0, 0)
1504
>>> print(hp.vec2pix(16, [1, 0], [0, 1], [0, 0]))
[1504 1520]
>>> print(hp.vec2pix([1, 2, 4, 8], 1, 0, 0))
[ 4 20 88 368]
- jax_healpy.pixelfunc.ang2vec(theta, phi, lonlat=False)
ang2vec : convert angles to 3D position vector
- Parameters:
theta (float, scalar or array-like) – co-latitude in radians measured southward from the North pole (in [0,pi]).
phi (float, scalar or array-like) – longitude in radians measured eastward (in [0, 2*pi]).
lonlat (bool) – If True, input angles are assumed to be longitude and latitude in degrees, otherwise, they are co-latitude and longitude in radians.
- Returns:
vec – If theta and phi are vectors, the result is a 2D array with one vector per row; otherwise, it is a 1D array of shape (3,).
- Return type:
float, array
See also
vec2ang, rotator.dir2vec, rotator.vec2dir
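The conversion performed by ang2vec is the usual spherical-to-Cartesian map, (sin θ cos φ, sin θ sin φ, cos θ). The NumPy sketch below (ang2vec_sketch is a hypothetical name, not the library function) also shows the lonlat convention, where longitude maps to φ and latitude to 90° minus θ, and the (3,) versus (N, 3) output shapes described above.

```python
import numpy as np


def ang2vec_sketch(theta, phi, lonlat=False):
    """NumPy sketch of ang2vec: returns shape (3,) for scalars, (N, 3) for arrays."""
    theta, phi = np.broadcast_arrays(np.asarray(theta, float), np.asarray(phi, float))
    if lonlat:
        # Inputs are (longitude, latitude) in degrees: convert to co-latitude/longitude.
        theta, phi = np.radians(90.0 - phi), np.radians(theta)
    sin_theta = np.sin(theta)
    return np.stack([sin_theta * np.cos(phi), sin_theta * np.sin(phi), np.cos(theta)], axis=-1)


print(ang2vec_sketch(np.arccos(1 / 24), 0.0))
```

For instance, ang2vec_sketch(np.arccos(1/24), 0.0) is approximately (0.99913157, 0, 0.04166667), the vector that pix2vec reports for pixel 1440 at nside=16.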
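- jax_healpy.pixelfunc.vec2ang(vectors, lonlat=False)
vec2ang: vectors [x, y, z] -> theta[rad], phi[rad]
- Parameters:
vectors (float, array-like) – the vector(s) to convert, of shape (3,) or (N, 3)
lonlat (bool, optional) – If True, the returned angles are longitude and latitude in degrees; otherwise, they are co-latitude and longitude in radians (default)
- Returns:
theta, phi – the colatitude and longitude in radians (longitude and latitude in degrees if lonlat=True)
- Return type:
float, scalar or array-like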
See also
ang2vec, rotator.vec2dir, rotator.dir2vec
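vec2ang inverts that map: θ comes from the normalized z component and φ from atan2(y, x), wrapped into [0, 2π). A NumPy sketch under those assumptions follows (vec2ang_sketch is a hypothetical name, not the library function); it always returns 1D arrays, even for a single vector.

```python
import numpy as np


def vec2ang_sketch(vectors, lonlat=False):
    """NumPy sketch of vec2ang: (3,) or (N, 3) vectors -> arrays (theta, phi) of shape (N,)."""
    v = np.asarray(vectors, dtype=float).reshape(-1, 3)
    norm = np.linalg.norm(v, axis=1)  # inputs need not be unit vectors
    theta = np.arccos(v[:, 2] / norm)
    phi = np.arctan2(v[:, 1], v[:, 0]) % (2.0 * np.pi)  # wrap into [0, 2*pi)
    if lonlat:
        return np.degrees(phi), 90.0 - np.degrees(theta)  # (longitude, latitude) in degrees
    return theta, phi


print(vec2ang_sketch([[0, 0, 2], [1, 0, 0], [0, -1, 0]]))
```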
- jax_healpy.pixelfunc.get_interp_weights(nside, theta, phi=None, nest=False, lonlat=False)
Return interpolation weights for given coordinates.
This function performs bilinear interpolation by finding the four nearest pixel centers and computing their interpolation weights. Once its outputs are sorted by pixel index, it matches healpy to machine precision.
- Parameters:
nside (int) – HEALPix nside parameter
theta (ArrayLike) – Colatitude in radians (or pixel indices if phi is None)
phi (ArrayLike, optional) – Longitude in radians (or degrees if lonlat=True)
nest (bool, optional) – NESTED ordering is not supported; passing True raises an error.
lonlat (bool, optional) – If True, interpret theta, phi as longitude, latitude in degrees
- Returns:
pixels (Array) – Array of shape (4, N) containing the four nearest pixel indices. The pixel order is not guaranteed; sort both pixels and weights by pixel value to match healpy exactly.
weights (Array) – Array of shape (4, N) containing the interpolation weights. Weights sum to 1.0 for each point to machine precision.
- Return type:
tuple of (Array, Array)
Notes
For exact healpy compatibility, sort the output by pixel values:
>>> sorted_indices = jnp.argsort(pixels, axis=0)
>>> sorted_pixels = jnp.take_along_axis(pixels, sorted_indices, axis=0)
>>> sorted_weights = jnp.take_along_axis(weights, sorted_indices, axis=0)
Precision and Algorithmic Considerations:
The phi interpolation can differ slightly from healpy, particularly for coordinates near the poles and for high nside values (nside ≥ 256). This is due to:
Phi Interpolation Near Poles: The formula phi_norm = (phi / dphi - shift) % nr can produce different results than healpy’s algorithm when transitioning between rings with different pixel counts (e.g., ring transitions near poles).
High Nside Precision Limits: For nside ≥ 256, floating-point precision limits can cause the JAX implementation to select different (but mathematically valid) interpolation neighbors compared to healpy, especially in challenging geometric regions.
Ring Transition Edge Cases: Coordinates exactly at boundaries between polar caps and equatorial regions may show pixel differences due to algorithmic implementation variations, though weights still sum to 1.0 and maintain interpolation accuracy.
Expected precision levels:
- nside ≤ 64: machine-precision match (< 1e-25 weight error)
- nside 128-256: high precision (< 1e-15 weight error)
- nside ≥ 512: good precision (< 1e-12 to 1e-5 weight error, depending on nside)
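The per-ring step that the phi formula above refers to can be illustrated in isolation: within one ring of nr pixels, find the two pixel centers that bracket phi and split the weight linearly between them. The helper below is hypothetical and covers only this one-ring step; the full bilinear scheme also interpolates in theta between the ring above and the ring below the point.

```python
import math


def ring_phi_neighbours(phi, nr, shifted):
    """Hypothetical helper: the two pixels of one ring that bracket phi, with linear weights.

    Pixel k of a ring with nr pixels is centered at phi_k = (k + 0.5 * shifted) * dphi,
    where dphi = 2 * pi / nr. Returns ((k1, k2), (w1, w2)) with offsets inside the ring.
    """
    dphi = 2.0 * math.pi / nr
    shift = 0.5 if shifted else 0.0
    t = phi / dphi - shift  # fractional pixel position along the ring
    i1 = math.floor(t)  # pixel center just below phi (may be -1 before wrapping)
    w = t - i1  # fraction of the way from center i1 to center i1 + 1
    return (i1 % nr, (i1 + 1) % nr), (1.0 - w, w)


# phi = 0 lies halfway between the last and the first pixel of a shifted 4-pixel ring.
print(ring_phi_neighbours(0.0, 4, True))
```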
Automatic Differentiation:
This function is fully compatible with JAX’s automatic differentiation. The gradients behave as follows:
Gradients of individual weights reflect the continuous dependence on coordinates
Gradients of sum(weights) are always zero since the sum is identically 1.0
Pixel indices have zero gradients since they are discrete selectors
Example gradient usage:
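>>> import jax
>>> import jax.numpy as jnp
>>> from jax_healpy.pixelfunc import get_interp_weights
>>> nside, theta, phi = 16, 1.0, 0.5
>>> map_data = jnp.arange(12 * nside**2, dtype=float)
>>> def interpolate_map(m, theta, phi):
...     pixels, weights = get_interp_weights(nside, theta, phi)
...     return jnp.sum(weights * m[pixels])  # scalar output, as jax.grad requires
>>> grad_func = jax.grad(interpolate_map, argnums=(1, 2))
>>> grad_theta, grad_phi = grad_func(map_data, theta, phi)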
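The gradient behaviour listed above can be reproduced with a one-dimensional toy in JAX: the floor-based index acts as a discrete selector with zero gradient, the fractional weight carries the derivative, and the sum of the weights has zero gradient. This is an analogy under our own assumptions (interp_1d and weight_sum are hypothetical), not jax_healpy code.

```python
import jax
import jax.numpy as jnp


def interp_1d(m, t):
    """Linear interpolation on a periodic 1D grid: a toy analogue of the HEALPix case."""
    i = jnp.floor(t).astype(jnp.int32)  # discrete selector: carries no gradient
    w = t - jnp.floor(t)  # continuous weight in [0, 1)
    n = m.shape[0]
    return (1.0 - w) * m[i % n] + w * m[(i + 1) % n]


def weight_sum(t):
    w = t - jnp.floor(t)
    return (1.0 - w) + w  # identically 1.0, so its gradient is 0


m = jnp.arange(5.0) ** 2  # grid values [0, 1, 4, 9, 16]
value, slope = jax.value_and_grad(interp_1d, argnums=1)(m, 1.5)
# value is 2.5 (midway between 1 and 4); slope is 4 - 1 = 3
print(value, slope, jax.grad(weight_sum)(1.5))
```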
- jax_healpy.pixelfunc.get_interp_val(m, theta, phi=None, nest=False, lonlat=False)
Return interpolated map values at given coordinates.
This function performs bilinear interpolation of map values using the four nearest pixel neighbors and matches healpy to machine precision.
- Parameters:
m (ArrayLike) – HEALPix map(s) to interpolate. Can be 1D (single map) or 2D (multiple maps). Shape: (npix,) or (nmaps, npix)
theta (ArrayLike) – Colatitude in radians (or pixel indices if phi is None)
phi (ArrayLike, optional) – Longitude in radians (or degrees if lonlat=True)
nest (bool, optional) – NESTED ordering is not supported; passing True raises an error.
lonlat (bool, optional) – If True, interpret theta, phi as longitude, latitude in degrees
- Returns:
values – Interpolated map values at the given coordinates. For a single map, the shape is the broadcast shape of theta and phi. For multiple maps, the shape is (nmaps, …), where … is that broadcast shape.
- Return type:
Array
Notes
Uses bilinear interpolation with the four nearest pixel neighbors. For exact healpy compatibility, this function calls get_interp_weights internally and computes result = sum(weights * map_values[pixels]). Results will not match healpy if theta and phi are not valid angles.
Examples
>>> import jax_healpy as hp
>>> import jax.numpy as jnp
>>> m = jnp.arange(12.)
>>> hp.get_interp_val(m, jnp.pi/2, 0.0)
Array(4.5, dtype=float64)
>>> # Multiple coordinates
>>> theta = jnp.array([jnp.pi/4, jnp.pi/2])
>>> phi = jnp.array([0.0, jnp.pi/2])
>>> hp.get_interp_val(m, theta, phi)
Array([2.25, 6. ], dtype=float64)
>>> # Multiple maps
>>> maps = jnp.array([jnp.arange(12.), 2*jnp.arange(12.)])
>>> hp.get_interp_val(maps, jnp.pi/2, 0.0)
Array([4.5, 9. ], dtype=float64)
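The combination step quoted in the Notes, result = sum(weights * map_values[pixels]), extends naturally to a stack of maps by indexing along the last axis. The NumPy sketch below illustrates the resulting shapes; combine_weights is hypothetical, and the pixels and weights passed to it are made-up illustrative values rather than output of get_interp_weights.

```python
import numpy as np


def combine_weights(m, pixels, weights):
    """Sketch of result = sum(weights * m[pixels]) for one map or a stack of maps.

    m: (npix,) or (nmaps, npix); pixels, weights: (4, N).
    Returns (N,) for a single map or (nmaps, N) for a stack of maps.
    """
    m = np.asarray(m)
    gathered = m[..., pixels]  # (4, N) or (nmaps, 4, N)
    return np.sum(weights * gathered, axis=-2)


m = np.arange(12.0)
pixels = np.array([[4], [5], [0], [11]])  # illustrative indices, one query point
weights = np.full((4, 1), 0.25)  # illustrative equal weights
print(combine_weights(m, pixels, weights), combine_weights(np.stack([m, 2 * m]), pixels, weights))
```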
- jax_healpy.pixelfunc.get_all_neighbours(nside, theta, phi=None, nest=False, lonlat=False, get_center=False)
Get the 8 nearest neighbors of the given pixels, optionally including the center pixel.
- Parameters:
nside (int) – HEALPix resolution parameter, must be a power of 2
theta (ArrayLike) – Either colatitude in radians (if phi is provided) or pixel indices (if phi is None)
phi (ArrayLike, optional) – Longitude in radians (or degrees if lonlat=True). If None, theta is treated as pixel indices.
nest (bool, optional) – If True, use NESTED pixel ordering scheme. Default is False (RING ordering).
lonlat (bool, optional) – If True and phi is provided, interpret (theta, phi) as (longitude, latitude) in degrees.
get_center (bool, optional) – If True, return center pixel + 8 neighbors (9 total). If False, return only 8 neighbors. Default is False for backward compatibility with healpy.
- Returns:
neighbors – Array of pixel indices. When get_center=False: shape is (8,) for scalar input or (8, N) for array input, with neighbors in directions [SW, W, NW, N, NE, E, SE, S]. When get_center=True: shape is (9,) for scalar input or (9, N) for array input, with pixels in order [CENTER, SW, W, NW, N, NE, E, SE, S]. Non-existent neighbors (at map boundaries) are marked with -1.
- Return type:
Array
Examples
>>> import jax_healpy as hp
>>> # Get 8 neighbors of pixel 4 at nside=1 (default behavior, matches healpy)
>>> neighbors = hp.get_all_neighbours(1, 4)
>>> print(neighbors)
[11 7 3 -1 0 5 8 -1]
>>> # Get center + 8 neighbors (9 total)
>>> neighbors_with_center = hp.get_all_neighbours(1, 4, get_center=True)
>>> print(neighbors_with_center)
[ 4 11 7 3 -1 0 5 8 -1]
>>> # Works with angular coordinates too
>>> import jax.numpy as jnp
>>> neighbors = hp.get_all_neighbours(1, jnp.pi/2, jnp.pi/2, get_center=True)
>>> print(neighbors)
[ 5 8 4 0 -1 1 6 9 -1]
Notes
healpy Compatibility: The get_center=False (default) behavior maintains perfect compatibility with healpy.get_all_neighbours(). The get_center=True parameter is a jax-healpy-specific extension that does not exist in healpy.
When get_center=False (default):
- Returns 8 neighbors in the same order as healpy: [SW, W, NW, N, NE, E, SE, S]
- Produces bit-for-bit identical results to healpy for all input modes
- Maintains backward compatibility with existing healpy-based code
When get_center=True (jax-healpy extension):
- Returns 9 pixels: the center pixel followed by the 8 neighbors, in order [CENTER, SW, W, NW, N, NE, E, SE, S]
- The center pixel is always the first element (index 0)
- Neighbor ordering from index 1 onward matches the healpy convention
- This functionality does not exist in healpy and is unique to jax-healpy
Performance: The default get_center=False case has no performance overhead compared to the original implementation. The get_center=True case adds minimal computational cost.
JAX Features: This function is fully compatible with JAX transformations including jit compilation, vmap, grad, and automatic differentiation. The get_center parameter is a static argument that allows different compilations for each mode.
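For a pixel well inside its base face, the eight neighbours in the healpy order [SW, W, NW, N, NE, E, SE, S] are unit steps in the face coordinates (x, y). The sketch below assumes that x grows toward the north-east and y toward the north-west, which is consistent with the pix2xyf examples above (the north-pole pixel 0 sits at x = y = nside - 1 of face 0). interior_neighbours_xyf is a hypothetical helper; pixels on face edges, whose neighbours lie on another face, are deliberately not handled.

```python
# Offsets (dx, dy) for the healpy neighbour order [SW, W, NW, N, NE, E, SE, S], assuming
# x grows toward the north-east and y toward the north-west inside a base face.
OFFSETS = ((-1, 0), (-1, 1), (0, 1), (1, 1), (1, 0), (1, -1), (0, -1), (-1, -1))


def interior_neighbours_xyf(nside, x, y, face, get_center=False):
    """Hypothetical helper: neighbours of a pixel strictly inside its face, as (x, y, face)."""
    if not (0 < x < nside - 1 and 0 < y < nside - 1):
        raise ValueError("edge and corner pixels need cross-face logic, not sketched here")
    neighbours = [(x + dx, y + dy, face) for dx, dy in OFFSETS]
    # With get_center=True the center comes first, as in [CENTER, SW, ..., S].
    return [(x, y, face)] + neighbours if get_center else neighbours


print(interior_neighbours_xyf(4, 1, 1, 0, get_center=True))
```

Each (x, y, face) triple can then be turned into a pixel index with xyf2pix(nside, x, y, face, nest=...).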
- jax_healpy.pixelfunc.nest2ring(nside, ipix)
Convert pixel number from NESTED ordering to RING ordering.
Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc.
- Parameters:
nside (int) – the healpix nside parameter
ipix (int, scalar or array-like) – the pixel number in NESTED scheme
ipix – the pixel number in RING scheme
- Return type:
int, scalar or array-like
Examples
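>>> import numpy as np
>>> import healpy as hp
>>> hp.nest2ring(16, 1130)
1504
>>> print(hp.nest2ring(2, np.arange(10)))
[13 5 4 0 15 7 6 1 17 9]
The following example passes a list as nside, which works in healpy but not in jax_healpy (nside must be an int):
>>> print(hp.nest2ring([1, 2, 4, 8], 11))
[ 11 2 12 211]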
- jax_healpy.pixelfunc.ring2nest(nside, ipix)
Convert pixel number from RING ordering to NESTED ordering.
Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc.
- Parameters:
nside (int) – the healpix nside parameter
ipix (int, scalar or array-like) – the pixel number in RING scheme
ipix – the pixel number in NESTED scheme
- Return type:
int, scalar or array-like
Examples
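>>> import numpy as np
>>> import healpy as hp
>>> hp.ring2nest(16, 1504)
1130
>>> print(hp.ring2nest(2, np.arange(10)))
[ 3 7 11 15 2 1 6 5 10 9]
The following example passes a list as nside, which works in healpy but not in jax_healpy (nside must be an int):
>>> print(hp.ring2nest([1, 2, 4, 8], 11))
[ 11 13 61 253]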
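On the NESTED side, the index within a face is simply the bits of x and y interleaved (x on even bit positions, y on odd ones), offset by face * nside**2. A plain-Python sketch of both directions follows; the helper names are hypothetical. Composing nest2xyf with a RING conversion such as xyf2pix(nside, x, y, face, nest=False) is one way to obtain nest2ring, and ring2nest follows the reverse path.

```python
def xyf2nest(nside, x, y, face):
    """NESTED index: face offset plus the bits of x (even) and y (odd) interleaved."""
    ipf, bit = 0, 0
    while (1 << bit) < nside:
        ipf |= ((x >> bit) & 1) << (2 * bit) | ((y >> bit) & 1) << (2 * bit + 1)
        bit += 1
    return face * nside * nside + ipf


def nest2xyf(nside, ipix):
    """Inverse of xyf2nest: de-interleave the in-face part of a NESTED index."""
    face, ipf = divmod(ipix, nside * nside)
    x = y = bit = 0
    while (1 << bit) < nside:
        x |= ((ipf >> (2 * bit)) & 1) << bit
        y |= ((ipf >> (2 * bit + 1)) & 1) << bit
        bit += 1
    return x, y, face


print(nest2xyf(16, 1130))
```

For example, nest2xyf(16, 1130) gives (8, 7, 4), and the RING index of that (x, y, face) at nside=16 is 1504, in agreement with hp.nest2ring(16, 1130) above.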
- jax_healpy.pixelfunc.reorder(map_in, inp=None, out=None, r2n=False, n2r=False, process_by_chunks=False)
Reorder a healpix map from RING/NESTED ordering to NESTED/RING.
Masked arrays are not yet supported.
By default, the maps are processed in one go; if memory is an issue, use the process_by_chunks option (which reproduces healpy behaviour).
- Parameters:
map_in (array-like) – the input map to reorder
inp ('RING' or 'NESTED') – the input ordering
out ('RING' or 'NESTED') – the output ordering
r2n (bool) – if True, reorder from RING to NESTED
n2r (bool) – if True, reorder from NESTED to RING
process_by_chunks (bool) – if True, process the map in chunks to limit memory use
- Returns:
map_out – the reordered map
- Return type:
array-like
Notes
If r2n or n2r is defined, it overrides inp and out.
Examples
>>> import numpy as np
>>> import healpy as hp
>>> hp.reorder(np.arange(48), r2n = True)
array([13, 5, 4, 0, 15, 7, 6, 1, 17, 9, 8, 2, 19, 11, 10, 3, 28, 20, 27, 12, 30, 22, 21, 14, 32, 24, 23, 16, 34, 26, 25, 18, 44, 37, 36, 29, 45, 39, 38, 31, 46, 41, 40, 33, 47, 43, 42, 35])
>>> hp.reorder(np.arange(12), n2r = True)
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
The masked-array examples below show healpy's behaviour; this function does not yet support masked arrays.
>>> hp.reorder(hp.ma(np.arange(12.)), n2r = True)
masked_array(data = [ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11.], mask = False, fill_value = -1.6375e+30)
>>> m = [np.arange(12.), np.arange(12.), np.arange(12.)]
>>> m[0][2] = hp.UNSEEN
>>> m[1][2] = hp.UNSEEN
>>> m[2][2] = hp.UNSEEN
>>> m = hp.ma(m)
>>> hp.reorder(m, n2r = True)
masked_array(data = [[0.0 1.0 -- 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0] [0.0 1.0 -- 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0] [0.0 1.0 -- 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0]], mask = [[False False True False False False False False False False False False] [False False True False False False False False False False False False] [False False True False False False False False False False False False]], fill_value = -1.6375e+30)
- jax_healpy.pixelfunc.ud_grade(map_in, nside_out, pess=False, order_in='RING', order_out=None, power=None, dtype=None)
Upgrade or degrade the resolution (nside) of a map.
This function changes the resolution of a HEALPix map by either upgrading it to a higher resolution (more pixels) or degrading it to a lower resolution (fewer pixels). The algorithm follows the HEALPix specification:
For upgrading: each parent pixel value is replicated to all its children
For degrading: each parent pixel value is the average of its children
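In NESTED ordering the children of a parent pixel occupy a contiguous block of (nside_in/nside_out)**2 indices, so the replicate/average rule above reduces to a repeat or a reshape-and-mean. The NumPy sketch below assumes a single NESTED map (a RING map would be reordered first), uses healpy's UNSEEN sentinel value, and applies the optional power factor to seen pixels only; ud_grade_nest is a hypothetical name, not jax_healpy's implementation.

```python
import numpy as np

UNSEEN = -1.6375e30  # healpy's sentinel value for unobserved pixels


def ud_grade_nest(m, nside_out, pess=False, power=None):
    """Sketch of ud_grade for a single NESTED map (no RING reordering)."""
    m = np.asarray(m, dtype=float)
    nside_in = int(round(np.sqrt(m.size / 12)))
    if nside_out > nside_in:
        # Upgrade: each parent value is copied to its (nside_out/nside_in)**2 children.
        out = np.repeat(m, (nside_out // nside_in) ** 2)
    elif nside_out < nside_in:
        # Degrade: in NESTED order the children of one parent pixel are contiguous.
        blocks = m.reshape(-1, (nside_in // nside_out) ** 2)
        good = blocks != UNSEEN
        nhit = good.sum(axis=1)
        total = np.where(good, blocks, 0.0).sum(axis=1)
        out = total / np.maximum(nhit, 1)  # average of the seen children
        bad = nhit < blocks.shape[1] if pess else nhit == 0
        out[bad] = UNSEEN
    else:
        out = m.copy()
    if power is not None:
        seen = out != UNSEEN
        out[seen] *= (nside_out / nside_in) ** power
    return out


print(ud_grade_nest(np.arange(48.0), 1))
```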
- Parameters:
map_in (array-like) – Input map(s) to be upgraded or degraded. Can be a single map or a sequence of maps.
nside_out (int) – Desired output resolution parameter. Must be a power of 2.
pess (bool, optional) – Pessimistic mask handling during degradation. If True, a parent pixel is marked as UNSEEN if any of its children are UNSEEN. If False (default), a parent pixel is UNSEEN only if all children are UNSEEN.
order_in ({'RING', 'NESTED'}, optional) – Pixel ordering of input map. Default is ‘RING’.
order_out ({'RING', 'NESTED'}, optional) – Pixel ordering of output map. If None, same as order_in.
power (float, optional) – Scaling factor for resolution change. If provided, the output values are multiplied by (nside_out/nside_in)^power.
dtype (data type, optional) – Data type of output map. If None, same as input map dtype.
- Returns:
map_out – Upgraded or degraded map(s) with the same shape as input but different number of pixels corresponding to nside_out.
- Return type:
Array
- Raises:
NotImplementedError – If the requested functionality is not yet implemented.
ValueError – If nside_out is not a valid HEALPix nside parameter.
Examples
>>> import jax_healpy as jhp
>>> import numpy as np
>>> # Create a simple map at nside=4
>>> nside_in = 4
>>> map_in = np.arange(jhp.nside2npix(nside_in), dtype=float)
>>> # Degrade to nside=2
>>> map_out = jhp.ud_grade(map_in, 2)
>>> # Upgrade to nside=8
>>> map_out = jhp.ud_grade(map_in, 8)
Notes
This function can create artifacts in power spectra and should be used with caution for scientific applications. The HEALPix documentation recommends using spherical harmonic transforms for resolution changes when possible.
The algorithm implements the exact same logic as healpy.ud_grade for compatibility, including proper handling of UNSEEN pixels and coordinate system conversions between RING and NESTED schemes.
- jax_healpy.pixelfunc.nside2npix(nside)
Give the number of pixels for the given nside.
- Parameters:
nside (int) – healpix nside parameter
- Returns:
npix – corresponding number of pixels
- Return type:
int
Examples
>>> import jax_healpy as hp
>>> import numpy as np
>>> hp.nside2npix(8)
768
>>> np.all([hp.nside2npix(nside) == 12 * nside**2 for nside in [2**n for n in range(12)]])
True
>>> hp.nside2npix(7)
588
- jax_healpy.pixelfunc.npix2nside(npix)
Give the nside parameter for the given number of pixels.
- Parameters:
npix (int) – the number of pixels
- Returns:
nside – the nside parameter corresponding to npix
- Return type:
int
Notes
Raises a ValueError if the number of pixels does not correspond to a valid HEALPix map size.
Examples
>>> import jax_healpy as hp
>>> import numpy as np
>>> hp.npix2nside(768)
8
>>> np.all([hp.npix2nside(12 * nside**2) == nside for nside in [2**n for n in range(12)]])
True
>>> hp.npix2nside(1000)
Traceback (most recent call last):
...
ValueError: Wrong pixel number (it is not 12*nside**2)
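npix2nside simply inverts npix = 12 * nside**2. The minimal sketch below uses exact integer arithmetic (math.isqrt) and raises the error message shown above; the function names mirror the documented API, but the bodies are our own illustration. Non-power-of-2 sizes such as nside=7 (588 pixels) still round-trip.

```python
import math


def nside2npix(nside):
    return 12 * nside * nside


def npix2nside(npix):
    """Invert npix = 12 * nside**2, rejecting sizes that are not of that form."""
    nside = math.isqrt(npix // 12)
    if 12 * nside * nside != npix:
        raise ValueError("Wrong pixel number (it is not 12*nside**2)")
    return nside


print(npix2nside(768), npix2nside(nside2npix(7)))
```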
- jax_healpy.pixelfunc.nside2order(nside)
Give the resolution order for a given nside.
- Parameters:
nside (int) – healpix nside parameter; an exception is raised if nside is not valid (nside must be a power of 2, less than 2**30)
- Returns:
order – corresponding order where nside = 2**(order)
- Return type:
int
Notes
Raises a ValueError if nside is not valid.
Examples
>>> import jax_healpy as hp
>>> import numpy as np
>>> hp.nside2order(128)
7
>>> all(hp.nside2order(2**o) == o for o in range(30))
True
>>> hp.nside2order(7)
Traceback (most recent call last):
...
ValueError: 7 is not a valid nside parameter (must be a power of 2, less than 2**30)
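Since nside = 2**order, the order is the position of the single set bit of nside, and order2npix is 12 * 4**order. The sketch below encodes the validity rule stated above (a power of 2, less than 2**30) and reuses it in order2nside, so that order 31 fails with the message shown in the order2nside examples; the bodies are our own illustration, not jax_healpy's source.

```python
def nside2order(nside):
    """Return order such that nside == 2**order; reject non-powers of 2 and nside >= 2**30."""
    if nside < 1 or nside & (nside - 1) or nside >= 2 ** 30:
        raise ValueError(
            f"{nside} is not a valid nside parameter (must be a power of 2, less than 2**30)"
        )
    return nside.bit_length() - 1


def order2nside(order):
    nside = 1 << order
    nside2order(nside)  # reuse the validity check; raises ValueError if out of range
    return nside


def order2npix(order):
    nside = order2nside(order)
    return 12 * nside * nside  # equals 12 * 4**order


print(nside2order(128), order2nside(7), order2npix(7))
```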
- jax_healpy.pixelfunc.order2nside(order)
Give the nside parameter for the given resolution order.
- Parameters:
order (int) – the resolution order
- Returns:
nside – the nside parameter corresponding to order
- Return type:
int, scalar or array-like
Notes
Raises a ValueError if order produces an nside out of range.
Examples
>>> import jax_healpy as hp
>>> import numpy as np
>>> hp.order2nside(7)
128
>>> print(hp.order2nside(np.arange(8)))
[ 1 2 4 8 16 32 64 128]
>>> hp.order2nside(31)
Traceback (most recent call last):
...
ValueError: 2147483648 is not a valid nside parameter (must be a power of 2, less than 2**30)
- jax_healpy.pixelfunc.order2npix(order)
Give the number of pixels for the given resolution order.
- Parameters:
order (int) – the resolution order
- Returns:
npix – corresponding number of pixels
- Return type:
int, scalar or array-like
Notes
A convenience function that successively applies order2nside then nside2npix to order.
Examples
>>> import jax_healpy as hp
>>> import numpy as np
>>> hp.order2npix(7)
196608
>>> print(hp.order2npix(np.arange(8)))
[ 12 48 192 768 3072 12288 49152 196608]
>>> hp.order2npix(31)
Traceback (most recent call last):
...
ValueError: 2147483648 is not a valid nside parameter (must be a power of 2, less than 2**30)
- jax_healpy.pixelfunc.npix2order(npix)
Give the resolution order for the given number of pixels.
- Parameters:
npix (int) – the number of pixels
- Returns:
order – corresponding resolution order
- Return type:
int
Notes
A convenience function that successively applies npix2nside then nside2order to npix.
Examples
>>> import jax_healpy as hp
>>> import numpy as np
>>> hp.npix2order(768)
3
>>> np.all([hp.npix2order(12 * 4**order) == order for order in range(12)])
True
>>> hp.npix2order(1000)
Traceback (most recent call last):
...
ValueError: Wrong pixel number (it is not 12*nside**2)
- jax_healpy.pixelfunc.nside2resol(nside, arcmin=False)
Give the approximate resolution (pixel size in radians or arcmin) for nside.
The resolution is simply the square root of the pixel area, which is a rough approximation given the different pixel shapes.
- Parameters:
nside (int) – healpix nside parameter
arcmin (bool, optional) – if True, return the resolution in arcmin, otherwise in radians
- Returns:
resol – approximate pixel size in radians or arcmin
- Return type:
float
Notes
Raises a ValueError if nside is not valid.
Examples
>>> import jax_healpy as hp
>>> hp.nside2resol(128, arcmin = True)
27.483891294539248
>>> hp.nside2resol(256)
0.0039973699529159707
>>> hp.nside2resol(7)
0.1461895297066412
- jax_healpy.pixelfunc.nside2pixarea(nside, degrees=False)
Give the pixel area for the given nside in square radians or square degrees.
- Parameters:
nside (int) – healpix nside parameter
degrees (bool, optional) – if True, return the pixel area in square degrees, otherwise in square radians
- Returns:
pixarea – pixel area in square radians or square degrees
- Return type:
float
Notes
Raises a ValueError if nside is not valid.
Examples
>>> import jax_healpy as hp
>>> hp.nside2pixarea(128, degrees = True)
0.2098234113027917
>>> hp.nside2pixarea(256)
1.5978966540475428e-05
>>> hp.nside2pixarea(7)
0.021371378595848933
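Because all HEALPix pixels have equal area, the pixel area is 4π / npix steradians, and the resolution quoted by nside2resol is its square root. A short sketch that reproduces the example values above (the bodies are our own illustration, not jax_healpy's source):

```python
import math


def nside2pixarea(nside, degrees=False):
    """All HEALPix pixels have equal area: 4*pi / npix steradians."""
    area = 4.0 * math.pi / (12 * nside * nside)
    return area * (180.0 / math.pi) ** 2 if degrees else area


def nside2resol(nside, arcmin=False):
    """Approximate pixel size: the square root of the pixel area."""
    resol = math.sqrt(nside2pixarea(nside))
    return math.degrees(resol) * 60.0 if arcmin else resol


print(nside2pixarea(256), nside2resol(128, arcmin=True))
```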
- jax_healpy.pixelfunc.isnsideok(nside, nest=False)
Returns True if nside is a valid nside parameter, False otherwise. NSIDE needs to be a power of 2 only for nested ordering.
- Parameters:
nside (int, scalar or array-like) – integer value to be tested
nest (bool, optional) – if True, additionally require nside to be a power of 2 (NESTED ordering)
- Returns:
ok – True if the given value is a valid nside, False otherwise.
- Return type:
bool, scalar or array-like
Examples
>>> import jax_healpy as hp
>>> hp.isnsideok(13, nest=True)
False
>>> hp.isnsideok(13, nest=False)
True
>>> hp.isnsideok(32)
True
>>> hp.isnsideok([1, 2, 3, 4, 8, 16], nest=True)
array([ True, True, False, True, True, True], dtype=bool)
- jax_healpy.pixelfunc.isnpixok(npix)
Return True if npix is a valid value for a healpix map size, False otherwise.
- Parameters:
npix (int, scalar or array-like) – integer value to be tested
- Returns:
ok – True if the given value is a valid number of pixels, False otherwise
- Return type:
bool, scalar or array-like
Examples
>>> import jax_healpy as hp
>>> hp.isnpixok(12)
True
>>> hp.isnpixok(768)
True
>>> hp.isnpixok([12, 768, 1002])
array([ True, True, False], dtype=bool)
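The two validity checks can be sketched in scalar form as follows. The upper bound of 2**29 for nside is our reading of "less than 2**30" (the largest allowed power of 2), and this sketch treats npix <= 0 as invalid; both choices are assumptions rather than documented behaviour.

```python
import math

MAX_NSIDE = 2 ** 29  # assumed largest accepted nside ("less than 2**30")


def isnsideok(nside, nest=False):
    """Scalar sketch: a positive integer up to MAX_NSIDE; NESTED also needs a power of 2."""
    if nside != int(nside) or not 0 < nside <= MAX_NSIDE:
        return False
    return not nest or (int(nside) & (int(nside) - 1)) == 0


def isnpixok(npix):
    """True when npix == 12 * nside**2 for some positive integer nside."""
    if npix <= 0 or npix % 12:
        return False
    nside = math.isqrt(npix // 12)
    return 12 * nside * nside == npix


print([isnsideok(n, nest=True) for n in [1, 2, 3, 4, 8, 16]], [isnpixok(n) for n in [12, 768, 1002]])
```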
- jax_healpy.pixelfunc.get_nside(m)
Extract the nside parameter from the map length.
- Parameters:
m (array-like) – HEALPix map or sequence of maps
- Returns:
nside – The nside parameter corresponding to the map size
- Return type:
int
- Raises:
ValueError – If the map size doesn’t correspond to a valid HEALPix map
- jax_healpy.pixelfunc.maptype(m)
Describe the type of the map (valid, single, sequence of maps). Checks the number of maps, that all maps have the same length, and that this length is a valid map size (using isnpixok()).
- Parameters:
m (sequence) – the map to get info from
- Returns:
info – -1 if the given object is not a valid map, 0 if it is a single map, info > 0 if it is a sequence of maps (info is then the number of maps)
- Return type:
int
Examples
>>> import numpy as np
>>> import healpy as hp
>>> hp.pixelfunc.maptype(np.arange(12))
0
>>> hp.pixelfunc.maptype([np.arange(12), np.arange(12)])
2
Coordinate Conversions
Functions for converting between different coordinate representations:
- jax_healpy.pix2ang(nside, ipix, nest=False, lonlat=False)
Same documentation as jax_healpy.pixelfunc.pix2ang above.
- jax_healpy.ang2pix(nside, theta, phi, nest=False, lonlat=False)
Same documentation as jax_healpy.pixelfunc.ang2pix above.
- jax_healpy.pix2vec(nside, ipix, nest=False)
Same documentation as jax_healpy.pixelfunc.pix2vec above.
- jax_healpy.vec2pix(nside, x, y, z, nest=False)
Same documentation as jax_healpy.pixelfunc.vec2pix above.
- jax_healpy.ang2vec(theta, phi, lonlat=False)[source]#
ang2vec : convert angles to 3D position vector
- Parameters:
theta (float, scalar or array-like) – co-latitude in radians measured southward from the North pole (in [0,pi]).
phi (float, scalar or array-like) – longitude in radians measured eastward (in [0, 2*pi]).
lonlat (bool) – If True, input angles are assumed to be longitude and latitude in degrees; otherwise, they are co-latitude and longitude in radians.
- Returns:
vec – if theta and phi are vectors, the result is a 2D array with one vector per row; otherwise, it is a 1D array of shape (3,)
- Return type:
float, array
See also
vec2ang, rotator.dir2vec, rotator.vec2dir
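For reference, ang2vec applies the standard spherical-to-Cartesian mapping x = sin(theta) cos(phi), y = sin(theta) sin(phi), z = cos(theta). The pure-Python sketch below restates that formula for scalar inputs only; `ang2vec_sketch` is a hypothetical name, not the jax_healpy implementation, and the lonlat branch assumes (as in healpy) that theta then carries the longitude and phi the latitude, both in degrees.

```python
import math


def ang2vec_sketch(theta, phi, lonlat=False):
    """Hypothetical scalar version of ang2vec returning (x, y, z)."""
    if lonlat:
        # assumption: theta holds longitude, phi holds latitude, in degrees
        lon, lat = theta, phi
        theta = math.radians(90.0 - lat)  # co-latitude, measured from the North pole
        phi = math.radians(lon)
    sin_theta = math.sin(theta)
    return (sin_theta * math.cos(phi), sin_theta * math.sin(phi), math.cos(theta))


# A point on the equator at phi = 0 lies on the +x axis (up to rounding in cos(pi/2)).
x, y, z = ang2vec_sketch(math.pi / 2, 0.0)
```

The result always has unit length, which is why pix2vec and ang2vec outputs can be fed directly back into vec2pix.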
Pixel Coordinates#
Functions for working with pixel coordinates within HEALPix faces:
- jax_healpy.pix2xyf(nside, ipix, nest=False)[source]#
pix2xyf : nside,ipix,nest=False -> x,y,face (default RING)
Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc.
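- Parameters:
nside (int) – The healpix nside parameter, must be a power of 2
ipix (int, scalars or array-like) – Pixel indices
nest (bool, optional) – if True, assume NESTED pixel ordering, otherwise, RING pixel ordering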
- Returns:
x, y (int, scalars or array-like) – Pixel indices within face
face (int, scalars or array-like) – Face number
- Return type:
See also
Examples
>>> import healpy as hp >>> hp.pix2xyf(16, 1440) (8, 8, 4)
>>> hp.pix2xyf(16, [1440, 427, 1520, 0, 3068]) (array([ 8, 8, 8, 15, 0]), array([ 8, 8, 7, 15, 0]), array([4, 0, 5, 0, 8]))
>>> hp.pix2xyf([1, 2, 4, 8], 11) (array([0, 1, 3, 7]), array([0, 0, 2, 6]), array([11, 3, 3, 3]))
- jax_healpy.xyf2pix(nside, x, y, face, nest=False)[source]#
xyf2pix : nside,x,y,face,nest=False -> ipix (default:RING)
Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc.
- Parameters:
nside (int) – The healpix nside parameter, must be a power of 2
x (int, scalars or array-like) – Pixel indices within face
y (int, scalars or array-like) – Pixel indices within face
face (int, scalars or array-like) – Face number
nest (bool, optional) – if True, assume NESTED pixel ordering, otherwise, RING pixel ordering
- Returns:
pix – The healpix pixel numbers. Scalar if all input are scalar, array otherwise. Usual numpy broadcasting rules apply.
- Return type:
See also
Examples
>>> import healpy as hp >>> hp.xyf2pix(16, 8, 8, 4) 1440
>>> print(hp.xyf2pix(16, [8, 8, 8, 15, 0], [8, 8, 7, 15, 0], [4, 0, 5, 0, 8])) [1440 427 1520 0 3068]
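In the NESTED scheme, the index of a pixel inside a face is obtained by interleaving the bits of its face coordinates x and y, and the face offset is face * nside**2. The sketch below assumes the usual HEALPix convention (bits of x in the even positions, bits of y in the odd positions) and covers only nest=True; the helper names are hypothetical, and the default RING output would additionally require a nest2ring step that is not shown.

```python
def spread_bits(v):
    """Move bit k of v to bit position 2k (hypothetical helper)."""
    out, k = 0, 0
    while v:
        out |= (v & 1) << (2 * k)
        v >>= 1
        k += 1
    return out


def compress_bits(v):
    """Inverse of spread_bits: gather the even-position bits of v."""
    out, k = 0, 0
    while v:
        out |= (v & 1) << k
        v >>= 2
        k += 1
    return out


def xyf2nest_sketch(nside, x, y, face):
    # face offset, then x bits in even positions and y bits in odd positions
    return face * nside * nside + spread_bits(x) + (spread_bits(y) << 1)


def nest2xyf_sketch(nside, ipix):
    face, local = divmod(ipix, nside * nside)
    return compress_bits(local), compress_bits(local >> 1), face
```

As a consistency check, xyf2nest_sketch(2, 1, 1, 0) gives NESTED pixel 3, which the nest2ring(2, np.arange(10)) example maps to RING pixel 0, the first pixel of the northernmost ring.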
Scheme Conversions#
Functions for converting between RING and NESTED pixelization schemes:
- jax_healpy.nest2ring(nside, ipix)[source]#
Convert pixel number from NESTED ordering to RING ordering.
Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc.
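- Parameters:
nside (int) – the healpix nside parameter
ipix (int, scalar or array-like) – the pixel number in NESTED scheme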
- Returns:
ipix – the pixel number in RING scheme
- Return type:
int, scalar or array-like
Examples
>>> import healpy as hp >>> hp.nest2ring(16, 1130) 1504
>>> print(hp.nest2ring(2, np.arange(10))) [13 5 4 0 15 7 6 1 17 9]
>>> print(hp.nest2ring([1, 2, 4, 8], 11)) [ 11 2 12 211]
- jax_healpy.ring2nest(nside, ipix)[source]#
Convert pixel number from RING ordering to NESTED ordering.
Contrary to healpy, nside must be an int. It cannot be a list, array, tuple, etc.
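- Parameters:
nside (int) – the healpix nside parameter
ipix (int, scalar or array-like) – the pixel number in RING scheme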
- Returns:
ipix – the pixel number in NESTED scheme
- Return type:
int, scalar or array-like
Examples
>>> import healpy as hp >>> hp.ring2nest(16, 1504) 1130
>>> print(hp.ring2nest(2, np.arange(10))) [ 3 7 11 15 2 1 6 5 10 9]
>>> print(hp.ring2nest([1, 2, 4, 8], 11)) [ 11 13 61 253]
- jax_healpy.reorder(map_in, inp=None, out=None, r2n=False, n2r=False, process_by_chunks=False)[source]#
Reorder a healpix map from RING/NESTED ordering to NESTED/RING.
Masked arrays are not yet supported.
By default, the maps are processed in one go, but if memory is an issue, use the process_by_chunks option (which reproduces healpy behaviour).
- Parameters:
map_in (array-like) – the input map to reorder (masked arrays are not yet supported)
inp ('RING' or 'NESTED') – define the input ordering
out ('RING' or 'NESTED') – define the output ordering
r2n (bool) – if True, reorder from RING to NESTED
n2r (bool) – if True, reorder from NESTED to RING
process_by_chunks (bool) – if True, process the map in chunks instead of in one go
- Returns:
map_out – the reordered map
- Return type:
array-like
Notes
If r2n or n2r is defined, it overrides inp and out.
Examples
>>> import healpy as hp
>>> hp.reorder(np.arange(48), r2n = True) array([13, 5, 4, 0, 15, 7, 6, 1, 17, 9, 8, 2, 19, 11, 10, 3, 28, 20, 27, 12, 30, 22, 21, 14, 32, 24, 23, 16, 34, 26, 25, 18, 44, 37, 36, 29, 45, 39, 38, 31, 46, 41, 40, 33, 47, 43, 42, 35])
>>> hp.reorder(np.arange(12), n2r = True) array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
>>> hp.reorder(hp.ma(np.arange(12.)), n2r = True) masked_array(data = [ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11.], mask = False, fill_value = -1.6375e+30)
>>> m = [np.arange(12.), np.arange(12.), np.arange(12.)]
>>> m[0][2] = hp.UNSEEN
>>> m[1][2] = hp.UNSEEN
>>> m[2][2] = hp.UNSEEN
>>> m = hp.ma(m)
>>> hp.reorder(m, n2r = True) masked_array(data = [[0.0 1.0 -- 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0] [0.0 1.0 -- 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0] [0.0 1.0 -- 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0]], mask = [[False False True False False False False False False False False False] [False False True False False False False False False False False False] [False False True False False False False False False False False False]], fill_value = -1.6375e+30)
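Since reordering only permutes pixels, it can be sketched with an index table. In the r2n example above, NESTED position j receives RING pixel nest2ring(j); with np.arange(48) as input, the output is therefore the nside=2 nest2ring table itself, and its first ten entries agree with the nest2ring(2, np.arange(10)) example. The pure-Python sketch below copies that table from the example output rather than computing it, and obtains n2r as the inverse permutation; all names are hypothetical.

```python
# nest2ring table for nside=2, copied from the reorder(np.arange(48), r2n=True) output
NEST2RING_NSIDE2 = [13, 5, 4, 0, 15, 7, 6, 1, 17, 9, 8, 2, 19, 11, 10, 3,
                    28, 20, 27, 12, 30, 22, 21, 14, 32, 24, 23, 16, 34, 26, 25, 18,
                    44, 37, 36, 29, 45, 39, 38, 31, 46, 41, 40, 33, 47, 43, 42, 35]


def invert_permutation(perm):
    """Return inv such that inv[perm[j]] == j (here: the ring2nest table)."""
    inv = [0] * len(perm)
    for j, r in enumerate(perm):
        inv[r] = j
    return inv


RING2NEST_NSIDE2 = invert_permutation(NEST2RING_NSIDE2)


def reorder_r2n_sketch(map_ring):
    # NESTED pixel j takes its value from RING pixel nest2ring(j)
    return [map_ring[r] for r in NEST2RING_NSIDE2]


def reorder_n2r_sketch(map_nest):
    # RING pixel i takes its value from NESTED pixel ring2nest(i)
    return [map_nest[n] for n in RING2NEST_NSIDE2]
```

The inverse table reproduces the ring2nest(2, np.arange(10)) example, and applying r2n then n2r returns the original map.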
Map Resolution Functions#
Functions for changing map resolution:
Interpolation#
Functions for interpolating values on the sphere:
- jax_healpy.get_interp_weights(nside, theta, phi=None, nest=False, lonlat=False)[source]#
Return interpolation weights for given coordinates.
This function performs bilinear interpolation by finding the four nearest pixel centers and computing their interpolation weights. Once the outputs are sorted by pixel index, it matches healpy to machine precision.
- Parameters:
nside (int) – HEALPix nside parameter
theta (ArrayLike) – Colatitude in radians (or pixel indices if phi is None)
phi (ArrayLike, optional) – Longitude in radians (or degrees if lonlat=True)
nest (bool, optional) – NESTED ordering is not supported; passing True raises an error
lonlat (bool, optional) – If True, interpret theta, phi as longitude, latitude in degrees
- Returns:
pixels (Array) – Array of shape (4, N) containing the four nearest pixel indices. Pixel order is not guaranteed; sort both pixels and weights by pixel index to match healpy exactly.
weights (Array) – Array of shape (4, N) containing the interpolation weights. Weights sum to 1.0 for each point to machine precision.
- Return type:
Notes
For exact healpy compatibility, sort the output by pixel values:
>>> sorted_indices = jnp.argsort(pixels, axis=0)
>>> sorted_pixels = jnp.take_along_axis(pixels, sorted_indices, axis=0)
>>> sorted_weights = jnp.take_along_axis(weights, sorted_indices, axis=0)
Precision and Algorithmic Considerations:#
The phi interpolation calculation can exhibit precision differences compared to healpy, particularly for coordinates near the poles and for high nside values (256+). This is due to:
Phi Interpolation Near Poles: The formula phi_norm = (phi / dphi - shift) % nr can produce different results than healpy’s algorithm when transitioning between rings with different pixel counts (e.g., ring transitions near poles).
High Nside Precision Limits: For nside ≥ 256, floating-point precision limits can cause the JAX implementation to select different (but mathematically valid) interpolation neighbors compared to healpy, especially in challenging geometric regions.
Ring Transition Edge Cases: Coordinates exactly at boundaries between polar caps and equatorial regions may show pixel differences due to algorithmic implementation variations, though weights still sum to 1.0 and maintain interpolation accuracy.
Expected precision levels:
- nside ≤ 64: machine precision matching (< 1e-25 weight error)
- nside 128-256: high precision (< 1e-15 weight error)
- nside ≥ 512: good precision (< 1e-12 to 1e-5 weight error, depending on nside)
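Within one iso-latitude ring, the phi part of the interpolation locates phi between two neighbouring pixel centres using the phi_norm = (phi / dphi - shift) % nr expression quoted above. The pure-Python sketch below covers only that single along-ring step, not the full get_interp_weights; it assumes the ring holds nr pixels centred at phi_k = (k + shift) * dphi with dphi = 2*pi/nr and shift equal to 0 or 0.5, and `ring_phi_weights` is a hypothetical name.

```python
import math


def ring_phi_weights(phi, nr, shift):
    """Hypothetical helper: linear weights between the two ring pixels around phi.

    Returns ((i1, i2), (w1, w2)) with in-ring indices i1, i2 and w1 + w2 == 1.
    """
    dphi = 2.0 * math.pi / nr
    # position along the ring in units of pixels, wrapped into [0, nr]
    phi_norm = (phi / dphi - shift) % nr
    i1 = int(math.floor(phi_norm))
    frac = phi_norm - i1       # fractional distance from centre i1 toward i2
    i2 = (i1 + 1) % nr         # wrap across phi = 2*pi
    i1 %= nr                   # guard against phi_norm rounding up to exactly nr
    return (i1, i2), (1.0 - frac, frac)


# On a 4-pixel ring with half-pixel shift (centres at pi/4, 3pi/4, ...),
# phi = 0 sits halfway between the last pixel (3) and the first one (0).
(i1, i2), (w1, w2) = ring_phi_weights(0.0, 4, 0.5)
```

The wrap-around is where ring-to-ring differences in nr and shift enter, which is consistent with the notes above locating most discrepancies near the poles and at ring transitions.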
Automatic Differentiation:#
This function is fully compatible with JAX’s automatic differentiation. The gradients behave as follows:
Gradients of individual weights reflect the continuous dependence on coordinates
Gradients of sum(weights) are always zero since the sum is identically 1.0
Pixel indices have zero gradients since they are discrete selectors
Example gradient usage:
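>>> import jax
>>> import jax.numpy as jnp
>>> from jax_healpy import get_interp_weights
>>> nside = 16
>>> map_data = jnp.arange(12 * nside**2, dtype=jnp.float32)
>>> # theta and phi are scalars here: jax.grad requires a scalar output
>>> def interpolate_map(m, theta, phi):
...     pixels, weights = get_interp_weights(nside, theta, phi)
...     return jnp.sum(weights * m[pixels], axis=0)
>>> grad_func = jax.grad(interpolate_map, argnums=(1, 2))
>>> grad_theta, grad_phi = grad_func(map_data, 1.0, 0.5)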
- jax_healpy.get_interp_val(m, theta, phi=None, nest=False, lonlat=False)[source]#
Return interpolated map values at given coordinates.
This function performs bilinear interpolation of map values using the four nearest pixel neighbors and matches healpy to machine precision.
- Parameters:
m (ArrayLike) – HEALPix map(s) to interpolate. Can be 1D (single map) or 2D (multiple maps). Shape: (npix,) or (nmaps, npix)
theta (ArrayLike) – Colatitude in radians (or pixel indices if phi is None)
phi (ArrayLike, optional) – Longitude in radians (or degrees if lonlat=True)
nest (bool, optional) – NESTED ordering is not supported; passing True raises an error
lonlat (bool, optional) – If True, interpret theta, phi as longitude, latitude in degrees
- Returns:
values – Interpolated map values at the given coordinates. Shape matches broadcast of theta, phi for single map. For multiple maps, shape is (nmaps, …) where … is broadcast shape.
- Return type:
Array
Notes
Uses bilinear interpolation with the four nearest pixel neighbors. To match healpy, it calls get_interp_weights internally and computes result = sum(weights * map_values[pixels]). Results will not match healpy if theta and phi are not valid angles.
Examples
>>> import jax_healpy as hp
>>> import jax.numpy as jnp
>>> m = jnp.arange(12.)
>>> hp.get_interp_val(m, jnp.pi/2, 0.0) Array(4.5, dtype=float64)
>>> # Multiple coordinates
>>> theta = jnp.array([jnp.pi/4, jnp.pi/2])
>>> phi = jnp.array([0.0, jnp.pi/2])
>>> hp.get_interp_val(m, theta, phi) Array([2.25, 6. ], dtype=float64)
>>> # Multiple maps
>>> maps = jnp.array([jnp.arange(12.), 2*jnp.arange(12.)])
>>> hp.get_interp_val(maps, jnp.pi/2, 0.0) Array([4.5, 9. ], dtype=float64)
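The Notes formula, result = sum(weights * map_values[pixels]), extends to stacked maps by indexing the last axis of the map array. Below is a NumPy sketch that assumes pixels and weights have shape (4, N), as documented for get_interp_weights; the helper name and the toy pixels/weights are hypothetical stand-ins, not real interpolation output.

```python
import numpy as np


def interp_from_weights(m, pixels, weights):
    """Hypothetical helper: combine (4, N) pixels/weights with one or more maps.

    m may have shape (npix,) or (nmaps, npix). Indexing its last axis gives
    (4, N) or (nmaps, 4, N), so summing over axis -2 removes the
    neighbour axis in both cases.
    """
    m = np.asarray(m)
    return np.sum(weights * m[..., pixels], axis=-2)


# Toy inputs: two points, each averaging four pixels with equal weights.
m = np.arange(12.0)
pixels = np.array([[0, 4], [1, 5], [2, 6], [3, 7]])
weights = np.full((4, 2), 0.25)
single = interp_from_weights(m, pixels, weights)                   # shape (2,)
stacked = interp_from_weights(np.stack([m, 2 * m]), pixels, weights)  # shape (2, 2)
```

Broadcasting weights against the (nmaps, 4, N) gather is what lets one set of weights serve every map in the stack.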
Neighbor Functions#
Functions for finding neighboring pixels:
- jax_healpy.get_all_neighbours(nside, theta, phi=None, nest=False, lonlat=False, get_center=False)[source]#
Get the 8 nearest neighbors of given pixels, optionally including center pixel.
- Parameters:
nside (int) – HEALPix resolution parameter, must be a power of 2
theta (ArrayLike) – Either colatitude in radians (if phi is provided) or pixel indices (if phi is None)
phi (ArrayLike, optional) – Longitude in radians (or degrees if lonlat=True). If None, theta is treated as pixel indices.
nest (bool, optional) – If True, use NESTED pixel ordering scheme. Default is False (RING ordering).
lonlat (bool, optional) – If True and phi is provided, interpret (theta, phi) as (longitude, latitude) in degrees.
get_center (bool, optional) – If True, return center pixel + 8 neighbors (9 total). If False, return only 8 neighbors. Default is False for backward compatibility with healpy.
- Returns:
neighbors – Array of pixel indices. When get_center=False: shape is (8,) for scalar input or (8, N) for array input, with neighbors in directions [SW, W, NW, N, NE, E, SE, S]. When get_center=True: shape is (9,) for scalar input or (9, N) for array input, with pixels in order [CENTER, SW, W, NW, N, NE, E, SE, S]. Non-existent neighbors (at map boundaries) are marked with -1.
- Return type:
Array
Examples
>>> import jax_healpy as hp
>>> # Get 8 neighbors of pixel 4 at nside=1 (default behavior, matches healpy)
>>> neighbors = hp.get_all_neighbours(1, 4)
>>> print(neighbors) [11 7 3 -1 0 5 8 -1]
>>> # Get center + 8 neighbors (9 total)
>>> neighbors_with_center = hp.get_all_neighbours(1, 4, get_center=True)
>>> print(neighbors_with_center) [ 4 11 7 3 -1 0 5 8 -1]
>>> # Works with angular coordinates too: (pi/2, pi/2) lies in pixel 5
>>> import jax.numpy as jnp
>>> neighbors = hp.get_all_neighbours(1, jnp.pi/2, jnp.pi/2, get_center=True)
>>> print(neighbors) [ 5 8 4 0 -1 1 6 9 -1]
Notes
healpy Compatibility: The get_center=False (default) behavior maintains perfect compatibility with healpy.get_all_neighbours(). The get_center=True parameter is a jax-healpy-specific extension that does not exist in healpy.
When get_center=False (default):
- Returns 8 neighbors in the same order as healpy: [SW, W, NW, N, NE, E, SE, S]
- Produces bit-for-bit identical results to healpy for all input modes
- Maintains backward compatibility with existing healpy-based code
When get_center=True (jax-healpy extension):
- Returns 9 pixels: the center pixel followed by the 8 neighbors, in order [CENTER, SW, W, NW, N, NE, E, SE, S]
- The center pixel is always the first element (index 0)
- Neighbor ordering matches the healpy convention starting from index 1
- This functionality does not exist in healpy and is unique to jax-healpy
Performance: The default get_center=False case has no performance overhead compared to the original implementation. The get_center=True case adds minimal computational cost.
JAX Features: This function is fully compatible with JAX transformations, including jit, vmap, and grad. The get_center parameter is a static argument, so each mode is compiled separately.
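Because -1 is also a valid index in Python, NumPy and JAX (it silently selects the last pixel), the -1 markers for missing neighbours should be filtered out before the result is used to index a map. A minimal pure-Python sketch, taking the nside=1 neighbours of pixel 4 from the example above as input; `neighbour_mean` is a hypothetical helper, not part of jax_healpy.

```python
def neighbour_mean(m, neighbours):
    """Hypothetical helper: mean of m over the neighbours that exist.

    `neighbours` follows the get_all_neighbours layout, where -1 marks a
    neighbour that does not exist; those entries are dropped before indexing.
    """
    valid = [p for p in neighbours if p >= 0]
    if not valid:
        raise ValueError("no existing neighbours")
    return sum(m[p] for p in valid) / len(valid)


# nside=1 neighbours of pixel 4, copied from the example output above
neighbours = [11, 7, 3, -1, 0, 5, 8, -1]
m = [float(i) for i in range(12)]           # toy map: value equals pixel index
mean_value = neighbour_mean(m, neighbours)  # averages pixels 11, 7, 3, 0, 5, 8
with_center = [4] + neighbours              # same layout as get_center=True
```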
HEALPix Parameters#
Functions for working with HEALPix resolution parameters:
- jax_healpy.nside2npix(nside)[source]#
Give the number of pixels for the given nside.
- Parameters:
nside (int) – healpix nside parameter
- Returns:
npix – corresponding number of pixels
- Return type:
Examples
>>> import jax_healpy as hp >>> import numpy as np >>> hp.nside2npix(8) 768
>>> np.all([hp.nside2npix(nside) == 12 * nside**2 for nside in [2**n for n in range(12)]]) True
>>> hp.nside2npix(7) 588
- jax_healpy.npix2nside(npix)[source]#
Give the nside parameter for the given number of pixels.
- Parameters:
npix (int) – the number of pixels
- Returns:
nside – the nside parameter corresponding to npix
- Return type:
Notes
Raise a ValueError exception if npix does not correspond to the number of pixels of a healpix map.
Examples
>>> import jax_healpy as hp >>> hp.npix2nside(768) 8
>>> np.all([hp.npix2nside(12 * nside**2) == nside for nside in [2**n for n in range(12)]]) True
>>> hp.npix2nside(1000) Traceback (most recent call last): ... ValueError: Wrong pixel number (it is not 12*nside**2)
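The relation npix = 12 * nside**2 (twelve base faces, each split into nside x nside pixels) makes both conversions short. A hedged pure-Python sketch with hypothetical names; it uses an integer square root so that large maps are not affected by floating-point rounding, and the error message mirrors the one shown above.

```python
import math


def nside2npix_sketch(nside):
    # 12 base faces, each divided into nside * nside pixels
    return 12 * nside * nside


def npix2nside_sketch(npix):
    # integer square root: exact even for very large npix
    nside = math.isqrt(npix // 12) if npix >= 12 else 0
    if nside == 0 or nside2npix_sketch(nside) != npix:
        raise ValueError("Wrong pixel number (it is not 12*nside**2)")
    return nside
```

Like nside2npix(7) in the examples, this accepts non-power-of-2 values, so npix2nside_sketch(588) returns 7.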
- jax_healpy.nside2order(nside)[source]#
Give the resolution order for a given nside.
- Parameters:
nside (int) – healpix nside parameter; an exception is raised if nside is not valid (nside must be a power of 2, less than 2**30)
- Returns:
order – corresponding order where nside = 2**(order)
- Return type:
Notes
Raise a ValueError exception if nside is not valid.
Examples
>>> import jax_healpy as hp >>> import numpy as np >>> hp.nside2order(128) 7
>>> all(hp.nside2order(2**o) == o for o in range(30)) True
>>> hp.nside2order(7) Traceback (most recent call last): ... ValueError: 7 is not a valid nside parameter (must be a power of 2, less than 2**30)
- jax_healpy.order2nside(order)[source]#
Give the nside parameter for the given resolution order.
- Parameters:
order (int) – the resolution order
- Returns:
nside – the nside parameter corresponding to order
- Return type:
Notes
Raise a ValueError exception if order produces an nside out of range.
Examples
>>> import jax_healpy as hp >>> hp.order2nside(7) 128
>>> print(hp.order2nside(np.arange(8))) [ 1 2 4 8 16 32 64 128]
>>> hp.order2nside(31) Traceback (most recent call last): ... ValueError: 2147483648 is not a valid nside parameter (must be a power of 2, less than 2**30)
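A power of two has exactly one bit set, so nside & (nside - 1) == 0 tests it and nside.bit_length() - 1 gives the order. A minimal sketch, assuming "valid" means a power of 2 below 2**30 as stated in the parameter description; the names are hypothetical and the error message copies the one in the examples.

```python
def nside2order_sketch(nside):
    """Hypothetical version of nside2order: order such that nside == 2**order."""
    # nside & (nside - 1) is zero only when exactly one bit is set
    if nside < 1 or nside & (nside - 1) or nside >= 2**30:
        raise ValueError(
            f"{nside} is not a valid nside parameter "
            "(must be a power of 2, less than 2**30)"
        )
    return nside.bit_length() - 1


def order2nside_sketch(order):
    """Hypothetical version of order2nside: nside = 2**order, validated."""
    nside = 1 << order
    nside2order_sketch(nside)  # reuse the same validity check
    return nside
```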
- jax_healpy.order2npix(order)[source]#
Give the number of pixels for the given resolution order.
- Parameters:
order (int) – the resolution order
- Returns:
npix – corresponding number of pixels
- Return type:
Notes
A convenience function that successively applies order2nside then nside2npix to order.
Examples
>>> import jax_healpy as hp >>> hp.order2npix(7) 196608
>>> print(hp.order2npix(np.arange(8))) [ 12 48 192 768 3072 12288 49152 196608]
>>> hp.order2npix(31) Traceback (most recent call last): ... ValueError: 2147483648 is not a valid nside parameter (must be a power of 2, less than 2**30)
- jax_healpy.npix2order(npix)[source]#
Give the resolution order for the given number of pixels.
- Parameters:
npix (int) – the number of pixels
- Returns:
order – corresponding resolution order
- Return type:
Notes
A convenience function that successively applies npix2nside then nside2order to npix.
Examples
>>> import jax_healpy as hp >>> hp.npix2order(768) 3
>>> np.all([hp.npix2order(12 * 4**order) == order for order in range(12)]) True
>>> hp.npix2order(1000) Traceback (most recent call last): ... ValueError: Wrong pixel number (it is not 12*nside**2)
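Chaining the two steps gives npix = 12 * 4**order, so the inverse only has to check that npix / 12 is a power of 4. A pure-Python sketch with hypothetical names; unlike the functions above, it raises one simplified ValueError for every invalid input instead of reproducing the npix2nside and nside2order messages.

```python
def order2npix_sketch(order):
    """Hypothetical version of order2npix: npix = 12 * 4**order."""
    if not 0 <= order <= 29:  # nside = 2**order must stay below 2**30
        raise ValueError(f"invalid order {order}")
    return 12 << (2 * order)


def npix2order_sketch(npix):
    """Hypothetical version of npix2order."""
    nface, rem = divmod(npix, 12)
    # npix / 12 must equal 4**order: a power of 2 whose single bit is at an even position
    if rem or nface < 1 or nface & (nface - 1) or (nface.bit_length() - 1) % 2:
        raise ValueError("Wrong pixel number (not 12 * 4**order)")
    return (nface.bit_length() - 1) // 2
```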
- jax_healpy.nside2resol(nside, arcmin=False)[source]#
Give approximate resolution (pixel size in radian or arcmin) for nside.
The resolution is simply the square root of the pixel area, which is a rough approximation given the varying pixel shapes.
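- Parameters:
nside (int) – healpix nside parameter
arcmin (bool, optional) – if True, return the resolution in arcmin, otherwise in radians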
- Returns:
resol – approximate pixel size in radians or arcmin
- Return type:
Notes
Raise a ValueError exception if nside is not valid.
Examples
>>> import jax_healpy as hp >>> hp.nside2resol(128, arcmin = True) 27.483891294539248
>>> hp.nside2resol(256) 0.0039973699529159707
>>> hp.nside2resol(7) 0.1461895297066412
- jax_healpy.nside2pixarea(nside, degrees=False)[source]#
Give pixel area given nside in square radians or square degrees.
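- Parameters:
nside (int) – healpix nside parameter
degrees (bool, optional) – if True, return the pixel area in square degrees, otherwise in square radians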
- Returns:
pixarea – pixel area in square radian or square degree
- Return type:
Notes
Raise a ValueError exception if nside is not valid.
Examples
>>> import jax_healpy as hp >>> hp.nside2pixarea(128, degrees = True) 0.2098234113027917
>>> hp.nside2pixarea(256) 1.5978966540475428e-05
>>> hp.nside2pixarea(7) 0.021371378595848933
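Both quantities follow from the equal-area property: the 4*pi steradians of the sphere are split into 12 * nside**2 pixels, and the "resolution" is the square root of that area. A pure-Python sketch with hypothetical names and no nside validation, whose results can be compared with the example values for nside2resol and nside2pixarea above.

```python
import math


def nside2pixarea_sketch(nside, degrees=False):
    # the full sphere (4*pi sr) divided among 12 * nside**2 equal-area pixels
    area = 4.0 * math.pi / (12 * nside * nside)
    return area * (180.0 / math.pi) ** 2 if degrees else area


def nside2resol_sketch(nside, arcmin=False):
    # approximate pixel size: square root of the pixel area, in radians
    resol = math.sqrt(nside2pixarea_sketch(nside))
    return math.degrees(resol) * 60.0 if arcmin else resol
```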
Utility Functions#
Helper functions for validation and map properties:
- jax_healpy.isnsideok(nside, nest=False)[source]#
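Returns True if nside is a valid nside parameter, False otherwise.
NSIDE needs to be a power of 2 only for nested ordering.
- Parameters:
nside (int, scalar or array-like) – integer value to be tested
nest (bool, optional) – if True, also require nside to be a power of 2 (NESTED ordering)
- Returns:
ok – True if given value is a valid nside, False otherwise.
- Return type: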
bool, scalar or array-like
Examples
>>> import jax_healpy as hp >>> hp.isnsideok(13, nest=True) False
>>> hp.isnsideok(13, nest=False) True
>>> hp.isnsideok(32) True
>>> hp.isnsideok([1, 2, 3, 4, 8, 16], nest=True) array([ True, True, False, True, True, True], dtype=bool)
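The rule above (any positive integer for RING, a power of 2 for NESTED) can be sketched for scalars as follows. The upper bound of 2**30 is an assumption taken from the parameter descriptions elsewhere on this page, and `isnsideok_sketch` is a hypothetical name; the real function also accepts arrays, as in the last example.

```python
def isnsideok_sketch(nside, nest=False):
    """Hypothetical scalar version of isnsideok.

    Assumes a valid nside is a positive integer below 2**30 and, for
    NESTED ordering only, additionally a power of 2.
    """
    if nside != int(nside) or not 0 < nside < 2**30:
        return False
    nside = int(nside)
    # n & (n - 1) clears the lowest set bit, so it is zero only for powers of 2
    return (not nest) or (nside & (nside - 1)) == 0
```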
- jax_healpy.isnpixok(npix)[source]#
Return True if npix is a valid value for healpix map size, False otherwise.
- Parameters:
npix (int, scalar or array-like) – integer value to be tested
- Returns:
ok – True if given value is a valid number of pixels, False otherwise.
- Return type:
bool, scalar or array-like
Examples
>>> import jax_healpy as hp >>> hp.isnpixok(12) True
>>> hp.isnpixok(768) True
>>> hp.isnpixok([12, 768, 1002]) array([ True, True, False], dtype=bool)
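A vectorized NumPy sketch of the same test, returning a boolean array for list input as in the last example: round npix / 12 to the nearest integer square and check that it reproduces npix exactly. It assumes, consistently with nside2npix(7) returning 588, that non-power-of-2 sizes count as valid and that sizes below 12 do not; whether jax_healpy's isnpixok treats those edge cases the same way is an assumption, and `isnpixok_sketch` is hypothetical.

```python
import numpy as np


def isnpixok_sketch(npix):
    """Hypothetical vectorized check: npix == 12 * nside**2 for an integer nside >= 1."""
    npix = np.asarray(npix, dtype=np.int64)
    # candidate nside from a float square root, then an exact integer check
    nside = np.rint(np.sqrt(np.clip(npix, 0, None) / 12.0)).astype(np.int64)
    return (npix >= 12) & (12 * nside * nside == npix)
```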
- jax_healpy.maptype(m)[source]#
Describe the type of the map (valid, single, sequence of maps). Checks: the number of maps, that all maps have the same length, and that this length is a valid map size (using isnpixok()).
- Parameters:
m (sequence) – the map to get info from
- Returns:
info – -1 if the given object is not a valid map, 0 if it is a single map, info > 0 if it is a sequence of maps (info is then the number of maps)
- Return type:
Examples
>>> import healpy as hp >>> hp.pixelfunc.maptype(np.arange(12)) 0 >>> hp.pixelfunc.maptype([np.arange(12), np.arange(12)]) 2
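The three checks listed above can be restated for plain Python lists as follows. This is a hypothetical sketch, not the jax_healpy implementation: it recognises only `collections.abc.Sequence` objects (NumPy or JAX arrays would need their own branch), and `_isnpixok` is a small stand-in for isnpixok.

```python
import math
from collections.abc import Sequence


def _isnpixok(npix):
    # valid map sizes are 12 * nside**2 for some integer nside >= 1
    if npix < 12 or npix % 12:
        return False
    nside = math.isqrt(npix // 12)
    return 12 * nside * nside == npix


def maptype_sketch(m):
    """Hypothetical restatement of the maptype rules for plain Python lists.

    Returns 0 for a single map, the number of maps for a sequence of
    equal-length maps, and -1 otherwise.
    """
    if not isinstance(m, Sequence) or len(m) == 0:
        return -1
    is_seq = [isinstance(item, Sequence) for item in m]
    if all(is_seq):
        lengths = {len(item) for item in m}
        # all maps must share one length, and that length must be a map size
        if len(lengths) == 1 and _isnpixok(lengths.pop()):
            return len(m)
        return -1
    if any(is_seq):
        return -1  # mixture of scalars and sequences
    return 0 if _isnpixok(len(m)) else -1
```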
Constants#
- jax_healpy.UNSEEN = 1.6375e+30#
Sentinel value used to mark invalid or missing pixels in HEALPix maps.