# This file is part of jax-healpy.
# Copyright (C) 2024 CNRS / SciPol developers
#
# jax-healpy is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# jax-healpy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with jax-healpy. If not, see <https://www.gnu.org/licenses/>.
from functools import partial
import jax
import numpy as np
from jax import numpy as jnp
from jaxtyping import Array, PRNGKeyArray
import jax_healpy as jhp
from ..pixelfunc import UNSEEN
from ._kmeans import kmeans_sample
def call_back_check(n_regions: Array, max_centroids: None) -> None:
"""Check if the number of regions exceeds the maximum centroids.
Args:
n_regions (Array): Number of regions requested.
max_centroids (None): Maximum allowed centroids.
Raises:
RuntimeError: If n_regions exceeds max_centroids.
"""
if max_centroids is not None:
if n_regions > max_centroids:
raise RuntimeError("""
In function [get_clusters] in the comp_sep module:
Number of regions (n_regions) is greater than max_centroids.
Either:
- Increase max_centroids.
- Set max_centroids to None, but n_regions will have
to be static and can no longer be a tracer.
""")
[docs]
@partial(jax.jit, static_argnums=(2,))
def get_cutout_from_mask(ful_map: Array, indices: Array, axis: int = 0) -> Array:
    """Select the entries of a map at the given indices along one axis.

    ``ful_map`` may be a single array or a pytree of arrays; the selection is
    applied to every leaf with ``jnp.take``.

    Args:
        ful_map (Array): Map (or pytree of maps) to take entries from.
        indices (Array): Positions to keep along ``axis``.
        axis (int, optional): Axis along which entries are taken. It is a
            static argument of the jitted function. Defaults to 0.

    Returns:
        Array: Object with the same tree structure as ``ful_map`` whose leaves
        contain only the selected entries along ``axis``.

    Example:
        >>> full_map = jnp.arange(12.0)
        >>> indices = jnp.array([1, 3, 5])
        >>> get_cutout_from_mask(full_map, indices).shape
        (3,)
    """

    def select(leaf):
        return jnp.take(leaf, indices, axis=axis)

    return jax.tree.map(select, ful_map)
[docs]
@partial(jax.jit, static_argnums=(2, 3))
def combine_masks(cutouts: list[Array], indices: list[Array], nside: int, axis: int = 0) -> Array:
    """Merge several cutouts into one full-size map.

    Every cutout is written into a map of ``12 * nside**2`` pixels at the
    positions given by the matching entry of ``indices``. Pixels that no
    cutout writes to keep the value ``UNSEEN``. Cutouts are written in list
    order, so where index sets overlap the later cutout wins.

    Args:
        cutouts: Cutouts to combine. Each may be an array or a pytree of
            arrays; all must share the same tree structure.
        indices: Pixel indices of each cutout, one entry per cutout.
        nside: HEALPix NSIDE of the output map (static under ``jit``).
        axis: Axis of each leaf that holds the pixel dimension (static under
            ``jit``). Defaults to 0.

    Returns:
        A pytree with the structure of ``cutouts[0]`` whose leaves have
        ``12 * nside**2`` entries along ``axis``.

    Raises:
        ValueError: If ``cutouts`` and ``indices`` differ in length, or if the
            cutouts do not all have the same tree structure.
    """
    if len(cutouts) != len(indices):
        raise ValueError(' The number of cutouts and indices must match.')

    structure = jax.tree.structure(cutouts[0])
    for cutout in cutouts[1:]:
        if jax.tree.structure(cutout) != structure:
            raise ValueError('All cutouts must have the same structure.')

    npix = 12 * nside**2

    def empty_full_map(leaf):
        # Shape is derived per leaf so leaves with different trailing
        # dimensions each get a correctly sized output.
        shape = list(leaf.shape)
        shape[axis] = npix
        return jnp.full(shape, UNSEEN)

    full_maps = jax.tree.map(empty_full_map, cutouts[0])

    # A distinct loop name avoids shadowing the `indices` parameter.
    for cutout, patch_indices in zip(cutouts, indices):

        def insert(full_leaf, patch_leaf, patch_indices=patch_indices):
            selector = [slice(None)] * patch_leaf.ndim
            selector[axis] = patch_indices
            return full_leaf.at[tuple(selector)].set(patch_leaf)

        full_maps = jax.tree.map(insert, full_maps, cutout)

    return full_maps
[docs]
@partial(jax.jit, static_argnums=(2, 3))
def get_fullmap_from_cutout(labels: Array, indices: Array, nside: int, axis: int = 0) -> Array:
    """Scatter a cutout back into a full-size map.

    For every leaf of ``labels`` a new array is created whose size along
    ``axis`` is ``12 * nside**2``. It is filled with ``UNSEEN``, and the leaf
    is written at the positions ``indices`` along that axis.

    Args:
        labels (Array): The cutout (array or pytree of arrays), shape
            ``[..., npatch, ...]``.
        indices (Array): Pixel indices the cutout entries belong to.
        nside (int): HEALPix NSIDE of the full map (static under ``jit``).
        axis (int): Axis of ``labels`` holding the patch dimension (static
            under ``jit``). Defaults to 0.

    Returns:
        Array: Same tree structure as ``labels``; each leaf has ``npatch``
        replaced by ``12 * nside**2`` along ``axis``, with ``UNSEEN`` at every
        position not listed in ``indices``.

    Example:
        >>> indices = jnp.array([0, 2])
        >>> cutout = jnp.array([1.0, 3.0])
        >>> full = get_fullmap_from_cutout(cutout, indices, nside=1)
        >>> full.shape
        (12,)
    """
    n_pixels = 12 * nside**2

    def scatter(leaf):
        # Same shape as the leaf, except for the pixel axis.
        target_shape = list(leaf.shape)
        target_shape[axis] = n_pixels

        # Select everything on the other axes and `indices` on the pixel axis.
        selector = [slice(None) for _ in range(leaf.ndim)]
        selector[axis] = indices

        return jnp.full(target_shape, UNSEEN).at[tuple(selector)].set(leaf)

    return jax.tree.map(scatter, labels)
[docs]
def find_kmeans_clusters(
    mask: Array,
    indices: Array,
    n_regions: int,
    key: PRNGKeyArray,
    max_centroids: int | None = None,
    unassigned: float = UNSEEN,
    initial_sample_size: int = 3,
) -> Array:
    """Group selected HEALPix pixels into regions with K-means.

    The coordinates returned by ``jhp.pix2ang(..., lonlat=True)`` for the
    pixels in ``indices`` are clustered by ``kmeans_sample``. The resulting
    labels are written into a map that has one entry per pixel of ``mask``.

    Args:
        mask (Array): HEALPix mask; only its size is used, to derive the
            number of pixels and NSIDE.
        indices (Array): Indices of the pixels to cluster.
        n_regions (int): Number of regions to cluster into.
        key (PRNGKeyArray): JAX random key forwarded to ``kmeans_sample``.
        max_centroids (int | None, optional): Maximum allowed centroids,
            forwarded to ``kmeans_sample``. Defaults to None.
        unassigned (float, optional): Value given to pixels that are not in
            ``indices``. Defaults to ``UNSEEN``.
        initial_sample_size (int, optional): Forwarded to ``kmeans_sample``.
            Defaults to 3.

    Returns:
        Array: One-dimensional map of length ``mask.size`` holding the cluster
        label of each selected pixel and ``unassigned`` elsewhere.

    Raises:
        RuntimeError: Raised by ``call_back_check`` (run through
            ``jax.debug.callback``) when ``max_centroids`` is given and
            ``n_regions`` exceeds it.

    Example:
        >>> import numpy as np
        >>> from jax import numpy as jnp, random
        >>> mask = np.load("GAL20.npy")
        >>> indices, = jnp.where(mask == 1)
        >>> clustered_map = find_kmeans_clusters(mask, indices, n_regions=5, key=random.key(0))
    """
    # Validate through a callback so the check runs on concrete values.
    jax.debug.callback(call_back_check, n_regions, max_centroids)

    npix = mask.size
    nside = jhp.npix2nside(npix)

    # Coordinates of every pixel, then restricted to the selected ones.
    all_pixels = jnp.arange(npix)
    lon, lat = jhp.pix2ang(nside, all_pixels, lonlat=True)
    coordinates = jnp.stack([lon[indices], lat[indices]], axis=-1)

    clustering = kmeans_sample(
        key,
        coordinates,
        n_regions,
        max_centroids=max_centroids,
        maxiter=100,
        tol=1.0e-5,
        initial_sample_size=initial_sample_size,
    )

    # Start from an all-`unassigned` map and write labels at the selected pixels.
    selected_pixels = all_pixels[indices]
    region_map = jnp.full(npix, unassigned)
    return region_map.at[selected_pixels].set(clustering.labels)
[docs]
@partial(jax.jit, static_argnums=(2,))
def normalize_by_first_occurrence(arr: Array, n_regions: int, max_centroids: int) -> Array:
    """
    Relabel values by the order in which they first appear in ``arr``.

    Each distinct value is replaced by a rank derived from the position of its
    first occurrence. Entries equal to ``UNSEEN`` stay ``UNSEEN``. Every other
    entry is shifted by ``(max_centroids - n_regions) + 1`` and then clipped
    into ``[0, n_regions - 1]``.

    This is useful after clustering or segmentation tasks to obtain compact,
    order-consistent label indices for downstream processing.

    Args:
        arr: Array of raw labels, possibly containing ``UNSEEN`` markers.
            NOTE(review): it is concatenated with a 1-D sentinel array, which
            suggests ``arr`` must be 1-D; confirm.
        n_regions: Number of label indices to keep; results for non-``UNSEEN``
            entries are clipped to ``[0, n_regions - 1]``.
        max_centroids: Upper bound on the number of distinct labels. It sets
            the fixed output size of ``jnp.unique`` and is the static argument
            of the jitted function.

    Returns:
        An array with the same shape as ``arr`` holding the normalized labels,
        with ``UNSEEN`` wherever ``arr`` equals ``UNSEEN``.

    Example:
        >>> arr = jnp.array([UNSEEN, UNSEEN, 5, 5, 5, 2, 3, 3, 8])
        >>> normalize_by_first_occurrence(arr, 4, 10)
        Array([UNSEEN, UNSEEN, 0, 0, 0, 1, 2, 2, 3])
    """
    # Prepend one UNSEEN entry so UNSEEN is always present and first occurs at index 0.
    arr_unseen = jnp.concatenate([jnp.array([UNSEEN]), arr])
    # `size` gives a fixed-length result (needed under jit): one slot for UNSEEN
    # plus up to `max_centroids` labels. Unused slots are padded by jnp.unique.
    unique_vals, first_idxs = jnp.unique(arr_unseen, size=max_centroids + 1, return_index=True)
    # Reorder the unique values by where each one first appears.
    order = jnp.argsort(first_idxs)
    unique_by_first = unique_vals[order]
    # Compare every entry against every unique value; argmax picks the first match.
    matches = arr_unseen[..., None] == unique_by_first
    idxs = jnp.argmax(matches, axis=-1)
    # Entries that match no unique value are marked UNSEEN instead of index 0.
    no_match = ~jnp.any(matches, axis=-1)
    normalized = jnp.where(no_match, UNSEEN, idxs)
    # Drop the sentinel that was prepended above.
    normalized = normalized[1:]
    # NOTE(review): the offset below presumably skips the UNSEEN slot plus the
    # padded slots, assuming padding sorts ahead of real labels and that exactly
    # `n_regions` distinct labels are present. Confirm against jnp.unique padding.
    normalized = jnp.where(
        arr == UNSEEN, UNSEEN, jnp.clip(normalized - (max_centroids - n_regions) - 1, 0, n_regions - 1)
    )
    return normalized
def shuffle_labels(arr: Array, rng: np.random.Generator | None = None) -> np.ndarray:
    """
    Randomly permute the label values of an array with NumPy.

    The distinct values of ``arr`` other than ``UNSEEN`` are mapped onto a
    random permutation of themselves, giving a bijective relabeling. Shape and
    dtype are preserved, and entries equal to ``UNSEEN`` are left unchanged.

    This is intended for visualization purposes: shuffling label indices can
    reduce misleading visual patterns (e.g., color clumping) in plots such as
    ``mollview``, making class structure easier to interpret.

    Args:
        arr: Array of label indices, e.g. ``[0, 0, 1, 2, UNSEEN]``.
        rng: Optional NumPy random generator used to draw the permutation.
            When ``None`` (default) the global ``np.random`` state is used.

    Returns:
        A NumPy array with the same shape and dtype as ``arr`` in which the
        labels are randomly permuted and ``UNSEEN`` entries are unchanged. An
        empty input yields an empty output.

    Example:
        >>> arr = np.array([0, 0, 1, 1, 2, UNSEEN])
        >>> shuffle_labels(arr)
        array([2, 2, 0, 0, 1, UNSEEN])  # result will vary
    """
    labels = np.asarray(arr)
    is_label = labels != UNSEEN

    # `inverse` gives, for each labelled entry, the position of its value in
    # `unique_vals`, so the remapping is one vectorized lookup.
    unique_vals, inverse = np.unique(labels[is_label], return_inverse=True)

    if rng is None:
        shuffled_vals = np.random.permutation(unique_vals)
    else:
        shuffled_vals = rng.permutation(unique_vals)

    # Copy so the input is never modified and the result is writable.
    shuffled = labels.copy()
    shuffled[is_label] = shuffled_vals[inverse]
    return shuffled