Clustering Functions#
This module provides K-means clustering of HEALPix pixels into regions, together with utilities for extracting masked cutouts, rebuilding full maps from cutouts, and normalizing cluster labels.
- jax_healpy.clustering.find_kmeans_clusters(mask, indices, n_regions, key, max_centroids=None, unassigned=-1.6375e+30, initial_sample_size=3)
Cluster pixels of a HEALPix map into regions using KMeans.
- Parameters:
mask (Array) – HEALPix mask.
indices (Array) – Indices of valid pixels.
n_regions (int) – Number of regions to cluster into.
key (PRNGKeyArray) – JAX random key.
max_centroids (int | None, optional) – Maximum allowed centroids. Defaults to None.
unassigned (float, optional) – Value for unassigned pixels. Defaults to jhp.UNSEEN.
initial_sample_size (int, optional) – Multiplier that sets how many samples are used to initialize the centroids: initial_sample_size * n_regions samples are drawn. Defaults to 3.
- Returns:
Map with clustered region labels.
- Return type:
Array
- Raises:
RuntimeError – If n_regions exceeds max_centroids when provided.
TracerBoolConversionError – If n_regions is a tracer and max_centroids is None.
Example
>>> import numpy as np
>>> from jax import numpy as jnp, random
>>> import jax_healpy as jhp
>>> from jax_healpy.clustering import find_kmeans_clusters
# Load mask and identify valid pixels
>>> mask = np.load("GAL20.npy")
>>> indices, = jnp.where(mask == 1)
>>> key = random.key(0)
# Perform clustering
>>> clustered_map = find_kmeans_clusters(mask, indices, n_regions=5, key=key)
>>> print(jnp.unique(clustered_map))
[0 1 2 3 4]
# Error example when max_centroids constraint is violated
>>> try:
...     clustered_map = find_kmeans_clusters(mask, indices, n_regions=15, key=key, max_centroids=10)
... except RuntimeError as e:
...     print(e)
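To give a sense of the kind of computation involved, here is a minimal NumPy sketch of Lloyd-style K-means on unit vectors, with random unit vectors standing in for HEALPix pixel centers. It is not the implementation behind find_kmeans_clusters: the function name kmeans_unit_vectors, the dot-product assignment rule, and the fixed iteration count are all assumptions made for illustration.

```python
import numpy as np


def kmeans_unit_vectors(vecs, n_regions, rng, n_iter=20):
    """Lloyd-style K-means on 3D unit vectors (illustrative only)."""
    # Seed the centroids with randomly chosen input vectors.
    centroids = vecs[rng.choice(len(vecs), size=n_regions, replace=False)]
    for _ in range(n_iter):
        # Assign each vector to the centroid with the largest dot product,
        # which for unit vectors is the smallest angular separation.
        labels = np.argmax(vecs @ centroids.T, axis=1)
        for k in range(n_regions):
            members = vecs[labels == k]
            if len(members):  # keep the old centroid if a cluster is empty
                mean = members.mean(axis=0)
                centroids[k] = mean / np.linalg.norm(mean)
    return labels, centroids


# Random unit vectors as a stand-in for pixel-center directions.
rng = np.random.default_rng(0)
vecs = rng.normal(size=(500, 3))
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True)

labels, centroids = kmeans_unit_vectors(vecs, n_regions=5, rng=rng)
print(labels.shape)  # (500,)
```

With a real mask, the input vectors would be the directions of the pixels listed in indices, and the resulting labels would be written back into a full-size map.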
- jax_healpy.clustering.get_cutout_from_mask(ful_map, indices, axis=0)
Extract a cutout from a full map using given indices.
- Parameters:
ful_map (Array) – The full HEALPix map.
indices (Array) – Indices for the cutout.
axis (int, optional) – Axis along which to apply the cutout. Defaults to 0.
- Returns:
The cutout map.
- Return type:
Array
Example
>>> mask = np.load("GAL20.npy")
>>> indices, = jnp.where(mask == 1)
>>> full_map = random.normal(random.key(0), shape=(jhp.nside2npix(64),))
>>> cutout = get_cutout_from_mask(full_map, indices)
>>> print(cutout.shape)
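The role of the axis argument is easiest to see with a stack of maps. The sketch below uses plain NumPy and assumes the cutout behaves like an index selection along the chosen axis. That is a reading of the description above, not a statement about the library's internals.

```python
import numpy as np

npix = 12 * 4**2  # HEALPix pixel count for nside = 4
# Three stacked maps, so the pixel axis is axis 1.
maps = np.arange(3 * npix, dtype=float).reshape(3, npix)
indices = np.array([0, 5, 17, 100])

# Select the listed pixels along the pixel axis (assumed cutout semantics).
cutout = np.take(maps, indices, axis=1)
print(cutout.shape)  # (3, 4)
```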
- jax_healpy.clustering.get_fullmap_from_cutout(labels, indices, nside, axis=0)
Reconstruct the full map from a cutout by inserting values along a specified axis.
- Parameters:
labels (Array) – Cutout values to place back into the full map.
indices (Array) – Pixel indices of the cutout within the full map.
nside (int) – HEALPix nside of the full map.
axis (int, optional) – Axis along which the cutout pixels are inserted. Defaults to 0.
- Returns:
Full map array with shape like labels, but with npatch → npix along the specified axis.
- Return type:
Array
Example
>>> mask = np.load("GAL20.npy")
>>> indices, = jnp.where(mask == 1)
>>> full_map = random.normal(random.key(0), shape=(jhp.nside2npix(64),))
>>> cutout = get_cutout_from_mask(full_map, indices)
>>> reconstructed = get_fullmap_from_cutout(cutout, indices, nside=64)
>>> print(jnp.array_equal(reconstructed, full_map))
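The npatch → npix reshaping can be sketched in plain NumPy. The helper name fullmap_from_cutout_sketch is hypothetical, and filling pixels outside indices with UNSEEN is an assumption here; the documentation above does not state which fill value the library uses.

```python
import numpy as np

UNSEEN = -1.6375e30  # healpy's sentinel value for missing pixels


def fullmap_from_cutout_sketch(cutout, indices, nside, axis=0, fill=UNSEEN):
    """Scatter cutout values into a full-size map (illustrative only)."""
    cutout = np.asarray(cutout, dtype=float)
    npix = 12 * nside**2
    shape = list(cutout.shape)
    shape[axis] = npix  # npatch -> npix along `axis`
    full = np.full(shape, fill)
    # moveaxis returns a view, so assigning through it fills `full` in place.
    np.moveaxis(full, axis, 0)[indices] = np.moveaxis(cutout, axis, 0)
    return full


full = fullmap_from_cutout_sketch([10.0, 20.0, 30.0], np.array([1, 4, 7]), nside=2)
print(full.shape)  # (48,)
```

Under this assumption, a round trip through a cutout reproduces the original map only at the pixels listed in indices; every other pixel holds the fill value.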
- jax_healpy.clustering.normalize_by_first_occurrence(arr, n_regions, max_centroids)
Normalize an array by mapping each unique value to the index of its first occurrence, preserving order up to n_regions values.
Any value not among the first n_regions unique elements (determined by order of appearance) is clipped to fit within [0, n_regions - 1], or set to UNSEEN if originally marked as such.
This is useful after clustering or segmentation tasks to ensure label indices are contiguous, compact, and order-consistent for downstream processing.
- Parameters:
- Returns:
An array of same shape as arr, where each label is replaced by its first-seen index, or UNSEEN if it was originally marked or beyond n_regions.
- Return type:
Array
Example
>>> arr = jnp.array([UNSEEN, UNSEEN, 5, 5, 5, 2, 3, 3, 8])
>>> normalize_by_first_occurrence(arr, 4, 10)
Array([UNSEEN, UNSEEN, 0, 0, 0, 1, 2, 2, 3])
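The first-occurrence relabeling can be reproduced with a short NumPy reference sketch. The helper name relabel_by_first_occurrence is hypothetical, the max_centroids argument is left out, and labels beyond the first n_regions unique values are clipped to n_regions - 1, which follows the clipping reading of the description above. It is not the library's implementation.

```python
import numpy as np

UNSEEN = -1.6375e30  # healpy's sentinel value for missing pixels


def relabel_by_first_occurrence(arr, n_regions, unseen=UNSEEN):
    """Map each label to the rank of its first appearance (illustrative only)."""
    arr = np.asarray(arr, dtype=float)
    out = np.full(arr.shape, unseen)
    valid = arr != unseen
    # Unique labels (sorted by value) and where each one first appears.
    values, first_idx = np.unique(arr[valid], return_index=True)
    order = np.argsort(first_idx)
    rank = np.empty_like(order)
    rank[order] = np.arange(order.size)  # rank[i] = first-seen index of values[i]
    pos = np.searchsorted(values, arr[valid])
    # Clip so that every label stays within [0, n_regions - 1].
    out[valid] = np.minimum(rank[pos], n_regions - 1)
    return out


arr = [UNSEEN, UNSEEN, 5, 5, 5, 2, 3, 3, 8]
result = relabel_by_first_occurrence(arr, n_regions=4)
print(result[2:])  # [0. 0. 0. 1. 2. 2. 3.]
```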
K-means Clustering#
Mask and Map Utilities#
Functions for manipulating masks and extracting map regions:
Label Utilities#
Functions for manipulating cluster labels:
Examples#
Basic K-means clustering:
import jax
import jax.numpy as jnp
import jax_healpy as hp
# Generate sample data (jax.random needs an explicit PRNG key)
data = jax.random.normal(jax.random.key(0), (1000, 3))
# Perform K-means clustering
centroids, labels, inertia = hp.kmeans_sample(data, n_clusters=5)
print(f"Final inertia: {inertia}")
print(f"Cluster sizes: {jnp.bincount(labels)}")
Using the KMeans class:
# Initialize K-means object
kmeans = hp.KMeans(n_clusters=3, max_iter=100, tol=1e-4)
# Fit the model
kmeans.fit(data)
# Get predictions for new data
new_data = jax.random.normal(jax.random.key(1), (100, 3))
predictions = kmeans.predict(new_data)
Working with HEALPix masks:
# Create a test mask
nside = 64
npix = hp.nside2npix(nside)
# Generate a random mask (jax.random needs an explicit PRNG key)
mask = jax.random.uniform(jax.random.key(2), (npix,)) > 0.8
# Find connected clusters
cluster_labels, n_clusters = hp.get_clusters(mask, min_size=10)
print(f"Found {n_clusters} clusters")
Extracting map cutouts:
from jax_healpy.clustering import get_cutout_from_mask, get_fullmap_from_cutout
# Create test map and mask
test_map = jax.random.normal(jax.random.key(3), (npix,))
region_mask = cluster_labels == 1 # Focus on cluster 1
# Pixel indices of the selected region
indices, = jnp.where(region_mask)
# Extract cutout
cutout = get_cutout_from_mask(test_map, indices)
# Process cutout data
processed_cutout = cutout * 2.0 # Example processing
# Insert back into full map
result_map = get_fullmap_from_cutout(processed_cutout, indices, nside)
Performance Tips#
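- K-means clustering benefits significantly from GPU acceleration.
- Use JIT compilation for repeated clustering operations:
from functools import partial
import jax
# Mark n_clusters static: a cluster count typically fixes array shapes, which JIT needs at trace time
@partial(jax.jit, static_argnames="n_clusters")
def fast_kmeans(data, n_clusters):
    return hp.kmeans_sample(data, n_clusters)
- For large datasets, consider mini-batch K-means or clustering a random subsample of the data.
- Mask operations are vectorized and run on the GPU when one is available to JAX.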
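The data-sampling tip can be sketched as follows: draw a random subset of rows, cluster only that subset, then assign every point to its nearest centroid in a single vectorized, jitted pass. The helper assign_to_nearest and the stand-in centroids are hypothetical; in practice the centroids would come from running K-means on the subsample.

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)
data_key, sample_key = jax.random.split(key)
data = jax.random.normal(data_key, (10_000, 3))

# Draw 1,000 distinct row indices instead of clustering the full dataset.
sample_idx = jax.random.choice(sample_key, data.shape[0], shape=(1_000,), replace=False)
subsample = data[sample_idx]


@jax.jit
def assign_to_nearest(points, centroids):
    # Squared Euclidean distance from every point to every centroid.
    d2 = jnp.sum((points[:, None, :] - centroids[None, :, :]) ** 2, axis=-1)
    return jnp.argmin(d2, axis=1)


# Stand-in centroids (assumption); replace with the K-means result on `subsample`.
centroids = subsample[:5]
labels = assign_to_nearest(data, centroids)
print(labels.shape)  # (10000,)
```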