Clustering Functions

This module provides advanced clustering algorithms and utilities for astronomical data analysis, including K-means clustering and mask manipulation functions.

jax_healpy.clustering.combine_masks(cutouts, indices, nside, axis=0)
Parameters:
Return type:

Array

jax_healpy.clustering.find_kmeans_clusters(mask, indices, n_regions, key, max_centroids=None, unassigned=-1.6375e+30, initial_sample_size=3)

Cluster pixels of a HEALPix map into regions using KMeans.

Parameters:
  • mask (Array) – HEALPix mask.

  • indices (Array) – Indices of valid pixels.

  • n_regions (int) – Number of regions to cluster into.

  • key (PRNGKeyArray) – JAX random key.

  • max_centroids (int | None, optional) – Maximum allowed centroids. Defaults to None.

  • unassigned (float, optional) – Value for unassigned pixels. Defaults to jhp.UNSEEN.

  • initial_sample_size (int, optional) – Multiplier for the size of the initial sample used to initialize the centroids: initial_sample_size * n_regions pixels are sampled. Defaults to 3.

Returns:

Map with clustered region labels.

Return type:

Array

Raises:
  • RuntimeError – If n_regions exceeds max_centroids when provided.

  • TracerBoolConversionError – If n_regions is a tracer and max_centroids is None.

Example

>>> import numpy as np
>>> from jax import numpy as jnp, random
>>> import jax_healpy as jhp

>>> # Load mask and identify valid pixels
>>> mask = np.load("GAL20.npy")
>>> indices, = jnp.where(mask == 1)
>>> key = random.key(0)

>>> # Perform clustering
>>> clustered_map = find_kmeans_clusters(mask, indices, n_regions=5, key=key)
>>> print(jnp.unique(clustered_map))
[0 1 2 3 4]

>>> # Error example when max_centroids constraint is violated
>>> try:
...     clustered_map = find_kmeans_clusters(mask, indices, n_regions=15, key=key, max_centroids=10)
... except RuntimeError as e:
...     print(e)
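As a rough picture of what clustering pixels with KMeans involves, the sketch below runs Lloyd's algorithm on unit vectors in plain NumPy. It is not jax_healpy's implementation: the helper name `kmeans_on_sphere`, the random test points, the fixed iteration count, and the way the initial sample of `initial_sample_size * n_regions` points is reduced to `n_regions` starting centroids are all assumptions made for illustration.

```python
import numpy as np


def kmeans_on_sphere(vectors, n_regions, rng, initial_sample_size=3, n_iter=20):
    """Illustrative Lloyd's k-means on unit vectors (hypothetical helper)."""
    # Draw initial_sample_size * n_regions candidate points without replacement.
    n_sample = min(initial_sample_size * n_regions, len(vectors))
    sample = vectors[rng.choice(len(vectors), size=n_sample, replace=False)]
    # Assumption: simply keep the first n_regions candidates as starting centroids.
    centroids = sample[:n_regions].copy()

    for _ in range(n_iter):
        # Assign each point to the centroid with the largest dot product,
        # i.e. the smallest angular separation on the sphere.
        labels = np.argmax(vectors @ centroids.T, axis=1)
        for k in range(n_regions):
            members = vectors[labels == k]
            if len(members) > 0:
                # New centroid: mean direction of the members, renormalized.
                mean = members.sum(axis=0)
                centroids[k] = mean / np.linalg.norm(mean)

    # Final assignment against the updated centroids.
    labels = np.argmax(vectors @ centroids.T, axis=1)
    return labels, centroids


rng = np.random.default_rng(0)
points = rng.normal(size=(500, 3))
points /= np.linalg.norm(points, axis=1, keepdims=True)  # project onto the unit sphere
labels, centroids = kmeans_on_sphere(points, n_regions=5, rng=rng)
```

Each point ends up with an integer label in the range 0 to n_regions - 1, which is the kind of region map the function above returns for the valid pixels.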

jax_healpy.clustering.get_cutout_from_mask(ful_map, indices, axis=0)

Extract a cutout from a full map using given indices.

Parameters:
  • ful_map (Array) – The full HEALPix map.

  • indices (Array) – Indices for the cutout.

  • axis (int, optional) – Axis along which to apply the cutout. Defaults to 0.

Returns:

The cutout map.

Return type:

Array

Example

>>> mask = np.load("GAL20.npy")
>>> indices, = jnp.where(mask == 1)
>>> full_map = random.normal(random.key(0), shape=(jhp.nside2npix(64),))
>>> cutout = get_cutout_from_mask(full_map, indices)
>>> print(cutout.shape)
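In NumPy terms, extracting a cutout along an axis corresponds to an indexed take. The sketch below uses `np.take` on a toy array with hypothetical sizes and a made-up mask; it illustrates the shape behaviour described above and is not the library's code.

```python
import numpy as np

npix = 12 * 4**2  # number of pixels of an NSIDE=4 HEALPix map
rng = np.random.default_rng(0)

# Two stacked maps, pixel axis last: shape (2, npix)
full_maps = rng.normal(size=(2, npix))

# Hypothetical mask selecting every third pixel
mask = np.arange(npix) % 3 == 0
(indices,) = np.where(mask)

# Select the masked pixels along the pixel axis (axis=1 here)
cutout = np.take(full_maps, indices, axis=1)
print(cutout.shape)  # (2, 64)
```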

jax_healpy.clustering.get_fullmap_from_cutout(labels, indices, nside, axis=0)

Reconstruct the full map from a cutout by inserting values along a specified axis.

Parameters:
  • labels (Array) – The cutout array, shape […, npatch, …].

  • indices (Array) – The pixel indices for the cutout.

  • nside (int) – HEALPix NSIDE.

  • axis (int) – The axis in labels that corresponds to the patch dimension (to be expanded to npix).

Returns:

Full map array with the same shape as labels, except that the specified axis has length npix instead of npatch.

Return type:

Array

Example

>>> mask = np.load("GAL20.npy")
>>> indices, = jnp.where(mask == 1)
>>> full_map = random.normal(random.key(0), shape=(jhp.nside2npix(64),))
>>> cutout = get_cutout_from_mask(full_map, indices)
>>> reconstructed = get_fullmap_from_cutout(cutout, indices, nside=64)
>>> print(jnp.array_equal(reconstructed, full_map))
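The shape change along `axis` can be sketched with a NumPy scatter. The helper name `fullmap_from_cutout_sketch` is hypothetical, and filling the pixels outside the cutout with UNSEEN is an assumption; the description above does not state which value the library writes there.

```python
import numpy as np

UNSEEN = -1.6375e30  # sentinel value shown as the `unassigned` default above


def fullmap_from_cutout_sketch(labels, indices, nside, axis=0):
    """Scatter a cutout back into a full-size array (illustrative only)."""
    npix = 12 * nside**2
    # Move the patch axis to the front so the scatter is a plain row assignment.
    moved = np.moveaxis(np.asarray(labels, dtype=float), axis, 0)
    full = np.full((npix,) + moved.shape[1:], UNSEEN)  # assumed fill value
    full[indices] = moved
    # Restore the original axis order, now with npix in place of npatch.
    return np.moveaxis(full, 0, axis)


nside = 4
indices = np.array([0, 5, 7, 100])
cutout = np.arange(8.0).reshape(2, 4)  # shape (2, npatch) with npatch = 4
full = fullmap_from_cutout_sketch(cutout, indices, nside, axis=1)
print(full.shape)  # (2, 192)
```

Under this assumption, a round trip through a cutout reproduces the original map only at the selected pixels; every other pixel holds the fill value.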

jax_healpy.clustering.normalize_by_first_occurrence(arr, n_regions, max_centroids)

Normalize an array by mapping each unique value to the index of its first occurrence, preserving order up to n_regions values.

Any value not among the first n_regions unique elements (determined by order of appearance) is clipped to fit within [0, n_regions - 1], or set to UNSEEN if originally marked as such.

This is useful after clustering or segmentation tasks to ensure label indices are contiguous, compact, and order-consistent for downstream processing.

Parameters:
  • arr (Array) – Integer array (1D or ND) containing raw labels, including possible UNSEEN markers.

  • n_regions (int) – Maximum number of regions (unique labels) to preserve. Others are clipped.

  • max_centroids (int) – Maximum number of unique labels expected (must be static for JIT).

Returns:

An array of same shape as arr, where each label is replaced by its first-seen index, or UNSEEN if it was originally marked or beyond n_regions.

Return type:

Array

Example

>>> arr = jnp.array([UNSEEN, UNSEEN, 5, 5, 5, 2, 3, 3, 8])
>>> normalize_by_first_occurrence(arr, 4, 10)
Array([UNSEEN, UNSEEN, 0, 0, 0, 1, 2, 2, 3])
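The relabeling rule can be written out in plain NumPy as a reference for the semantics. This is a sketch, not the library's JIT-compatible implementation: the helper name is hypothetical, it ignores `max_centroids` (which only matters for static shapes under JIT), and it assumes that labels beyond the first `n_regions` are clipped to `n_regions - 1`.

```python
import numpy as np

UNSEEN = -1.6375e30  # sentinel value shown as the `unassigned` default above


def normalize_first_occurrence_sketch(arr, n_regions):
    """Relabel values by order of first appearance (illustrative only)."""
    arr = np.asarray(arr, dtype=float)
    out = np.full(arr.shape, UNSEEN)
    seen = arr != UNSEEN
    # np.unique sorts by value; first_idx lets us recover order of appearance.
    values, first_idx = np.unique(arr[seen], return_index=True)
    ordered = values[np.argsort(first_idx)]
    for new_label, value in enumerate(ordered):
        # Assumption: labels past the first n_regions are clipped to n_regions - 1.
        out[arr == value] = min(new_label, n_regions - 1)
    return out


arr = np.array([UNSEEN, UNSEEN, 5, 5, 5, 2, 3, 3, 8])
result = normalize_first_occurrence_sketch(arr, n_regions=4)
```

For the input above, the values 5, 2, 3 and 8 appear in that order and are mapped to 0, 1, 2 and 3, while the two UNSEEN entries are left untouched.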

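K-means Clustering

  • find_kmeans_clusters

Mask and Map Utilities

Functions for manipulating masks and extracting map regions:

  • combine_masks

  • get_cutout_from_mask

  • get_fullmap_from_cutout

Label Utilities

Functions for manipulating cluster labels:

  • normalize_by_first_occurrence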

Examples

Basic K-means clustering:

import jax
import jax.numpy as jnp
import jax_healpy as hp

# Generate sample data: 1000 points in 3D drawn from a standard normal.
# Sampling in JAX needs an explicit PRNG key.
key_data, key_new, key_mask, key_map = jax.random.split(jax.random.key(0), 4)
data = jax.random.normal(key_data, (1000, 3))

# Perform K-means clustering
centroids, labels, inertia = hp.kmeans_sample(data, n_clusters=5)

print(f"Final inertia: {inertia}")
print(f"Cluster sizes: {jnp.bincount(labels)}")

Using the KMeans class:

# Initialize K-means object
kmeans = hp.KMeans(n_clusters=3, max_iter=100, tol=1e-4)

# Fit the model
kmeans.fit(data)

# Get predictions for new data
new_data = jax.random.normal(key_new, (100, 3))
predictions = kmeans.predict(new_data)

Working with HEALPix masks:

# Create a test mask
nside = 64
npix = hp.nside2npix(nside)

# Generate a random mask that keeps roughly 20% of the pixels
mask = jax.random.uniform(key_mask, (npix,)) > 0.8

# Find connected clusters
cluster_labels, n_clusters = hp.get_clusters(mask, min_size=10)

print(f"Found {n_clusters} clusters")

Extracting map cutouts:

# Create test map and mask
test_map = jax.random.normal(key_map, (npix,))
region_mask = cluster_labels == 1  # Focus on cluster 1

# Extract cutout: get_cutout_from_mask takes pixel indices, not a boolean mask
indices, = jnp.where(region_mask)
cutout = hp.get_cutout_from_mask(test_map, indices)

# Process cutout data
processed_cutout = cutout * 2.0  # Example processing

# Insert back into full map
result_map = hp.get_fullmap_from_cutout(processed_cutout, indices, nside)

Performance Tips

  • K-means clustering benefits significantly from GPU acceleration

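  • Use JIT compilation for repeated clustering operations, marking the number of clusters as static so that it is not traced:

from functools import partial
import jax

@partial(jax.jit, static_argnames="n_clusters")
def fast_kmeans(data, n_clusters):
    return hp.kmeans_sample(data, n_clusters)

  • For large datasets, consider using mini-batch K-means or data sampling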

  • Mask operations are vectorized and GPU-accelerated automatically
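The JIT tip interacts with static shapes: any argument that determines an output size, such as the number of clusters, has to be a compile-time constant. A minimal sketch, using a hypothetical `cluster_sizes` helper rather than any jax_healpy function:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="n_clusters")
def cluster_sizes(labels, n_clusters):
    # `length` fixes the output shape, so n_clusters must be static under jit.
    return jnp.bincount(labels, length=n_clusters)


labels = jnp.array([0, 1, 1, 2, 2, 2])
sizes = cluster_sizes(labels, n_clusters=3)
print(sizes)  # [1 2 3]
```

Calling the function with a different `n_clusters` triggers a recompilation, since each static value produces its own compiled version.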