
Decomposition

best_score_excluding_atom(norm_reduction, combinations, atom)

For each sample, find the maximum norm_reduction value among all combinations that do NOT contain atom[i].

Parameters:

    norm_reduction (ndarray): (n_sample, n_comb). Score or reduction value for each sample–combination pair. Required.
    combinations (ndarray): (n_comb, n_atom_select). Atom indices used in each combination. Required.
    atom (ndarray): (n_sample,). Atom index to exclude for each sample. Required.

Returns:

    best_score_excl (ndarray): (n_sample,). Max norm_reduction for each sample excluding combinations that contain atom[i].
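
Example: a toy call (arrays invented for illustration) showing the exclusion behaviour:

import numpy as np
# from isca_tools.utils.decomposition import best_score_excluding_atom

norm_reduction = np.array([[0.9, 0.4, 0.7],
                           [0.2, 0.8, 0.5]])   # 2 samples x 3 combinations
combinations = np.array([[0, 1],
                         [1, 2],
                         [0, 2]])              # atom indices used in each combination
atom = np.array([0, 1])                        # atom to exclude for each sample

best = best_score_excluding_atom(norm_reduction, combinations, atom)
# Sample 0: only combination [1, 2] avoids atom 0 -> 0.4
# Sample 1: only combination [0, 2] avoids atom 1 -> 0.5
print(best)  # [0.4 0.5]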

Source code in isca_tools/utils/decomposition.py
def best_score_excluding_atom(norm_reduction: np.ndarray,
                              combinations: np.ndarray,
                              atom: np.ndarray) -> np.ndarray:
    """
    For each sample, find the maximum norm_reduction value
    among all combinations that do NOT contain atom[i].

    Args:
        norm_reduction: (n_sample, n_comb)
            Score or reduction value for each sample–combination pair.
        combinations: (n_comb, n_atom_select)
            Atom indices used in each combination.
        atom: (n_sample,)
            Atom index to exclude for each sample.

    Returns:
        best_score_excl: (n_sample,)
            Max norm_reduction for each sample excluding combinations that contain atom[i].
    """
    # (n_sample, n_comb): True if this combo contains that sample’s excluded atom
    contains_atom = np.any(combinations[None, :, :] == atom[:, None, None], axis=2)

    # Mask those out
    masked_scores = np.where(~contains_atom, norm_reduction, -np.inf)

    # Take max over combinations
    best_score_excl = masked_scores.max(axis=1)

    return best_score_excl

pca_on_xarray(data, n_modes=4, standardize=True, valid=None, feature_dim_name='lev', reference_mean=True)

Perform PCA (via SVD) on an xarray DataArray. The PCA basis is fit only on samples where valid is True; the resulting components are then used to project all samples in data.

Parameters:

    data (DataArray): DataArray with dims (..., feature_dim_name), e.g. (co2, lat, lon, lev) with feature_dim_name='lev'. Required.
    n_modes (int): Number of PCA modes to keep. Default: 4.
    standardize (bool): If True, divide each feature by its std (computed from valid samples) before SVD so that features with different variances are equalized. If False, SVD is performed on raw deviations from reference_mean. Default: True.
    valid (Optional[DataArray]): Boolean mask with the same non-feature dims as data (e.g. (co2, lat, lon)). True indicates the grid cell is used to compute the PCA basis. If None, all grid cells with finite values across lev are considered valid. Default: None.
    feature_dim_name (str): Name of the dimension containing the features of interest in data. Default: 'lev'.
    reference_mean (Union[bool, DataArray]): 1-D DataArray (dim feature_dim_name) to subtract before SVD. If False, a zero reference mean is used (i.e. PCA on deviations from zero). If True, the reference mean is computed from all valid samples. Default: True.

Returns:

    components (DataArray): EOFs (modes) with dims (mode, feature_dim_name).
    scores (DataArray): PC coefficients with the same dims as data but with mode replacing feature_dim_name.
    mean_profile (DataArray): The reference_mean actually used (dim feature_dim_name).
    std_profile (DataArray): Std used for scaling (dim feature_dim_name). Ones if standardize=False.
    var_explained (ndarray): Variance explained by each of the first n_modes modes (computed from the valid samples).
    frac_var_explained (ndarray): Fraction of total variance explained by each of the first n_modes modes.

Notes
  • This function uses np.linalg.svd directly so there is NO automatic re-centering: the reference_mean you supply (or zero) is the baseline from which deviations are computed.
  • If standardize=True, std_profile is computed from the valid set and used both for the SVD input and for projecting all profiles.
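
Example: a minimal usage sketch on synthetic data (the dimension names, sizes and the reconstruction line are illustrative, not part of the library):

import numpy as np
import xarray as xr
# from isca_tools.utils.decomposition import pca_on_xarray

rng = np.random.default_rng(0)
data = xr.DataArray(rng.normal(size=(10, 20, 15)),
                    dims=("lat", "lon", "lev"),
                    coords={"lat": np.arange(10), "lon": np.arange(20), "lev": np.arange(15)})
valid = xr.DataArray(np.ones((10, 20), dtype=bool), dims=("lat", "lon"),
                     coords={"lat": data.lat, "lon": data.lon})

components, scores, mean_profile, std_profile, var_expl, frac_var_expl = pca_on_xarray(
    data, n_modes=3, standardize=True, valid=valid, feature_dim_name="lev")

# Approximate reconstruction of the original profiles from the retained modes:
# undo the standardization and add back the reference mean.
recon = (scores * components).sum("mode") * std_profile + mean_profile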
Source code in isca_tools/utils/decomposition.py
def pca_on_xarray(data: xr.DataArray, n_modes: int = 4, standardize: bool = True,
                  valid: Optional[xr.DataArray] = None, feature_dim_name: str = "lev",
                  reference_mean: Union[bool, xr.DataArray] = True,
                  ) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray, xr.DataArray, np.ndarray, np.ndarray]:
    """
    Perform PCA (via SVD) on an xarray DataArray. The PCA basis is fit only on samples where `valid` is True. The
    resulting components are then used to project all samples in `data`.

    Args:
        data : DataArray with dims (..., feature_dim_name) (e.g. (co2, lat, lon, lev) with `feature_dim_name=lev`).
        n_modes (int): Number of PCA modes to keep.
        standardize (bool): If True, divide each feature by its std (computed from valid samples)
            *before* SVD so that features with different variances are equalized.
            If False, SVD is performed on raw deviations from `reference_mean`.
        valid: Boolean mask with the same non-feature dims as `data` (e.g. (co2, lat, lon)).
            True indicates the grid cell is used to compute the PCA basis. If None, all
            grid cells with finite values across lev are considered valid.
        feature_dim_name: Name of the dimension containing features of interest in `data`.
        reference_mean: 1-D DataArray (dim `feature_dim_name`) to subtract before SVD.
            If False, a zero reference mean is used (i.e. PCA on deviations from zero).
            If True, a reference mean will be computed from all `valid` samples.

    Returns:
        components: EOFs (modes) with dims (mode, feature_dim_name).
        scores: PC coefficients with same dims as `data` but `mode` replacing `feature_dim_name`.
        mean_profile: The reference_mean actually used (dim `feature_dim_name`).
        std_profile: Std used for scaling (dim `feature_dim_name`). Ones if `standardize=False`.
        var_explained: Variance explained by each of the first `n_modes` modes (computed from the valid samples).
        frac_var_explained: Fraction of total variance explained by each of the first `n_modes` modes.

    Notes:
        - This function uses np.linalg.svd directly so there is NO automatic re-centering:
          the `reference_mean` you supply (or zero) is the baseline from which deviations
          are computed.
        - If standardize=True, std_profile is computed from the valid set and used both
          for the SVD input and for projecting all profiles.
    """
    if feature_dim_name not in data.dims:
        raise ValueError(f"X must have a '{feature_dim_name}' dimension")

    non_feature_dims = [d for d in data.dims if d != feature_dim_name]
    n_feature = data.sizes[feature_dim_name]

    # prepare reference mean (1d array length n_feature)
    if reference_mean is False:  # use `is`: `==` on a DataArray returns an array, not a bool
        reference_mean_vals = np.zeros(n_feature)
    elif reference_mean is True:
        if valid is None:
            reference_mean = data.mean(dim=non_feature_dims)
        else:
            reference_mean = data.where(valid).mean(dim=non_feature_dims)
        reference_mean_vals = reference_mean.values
    else:
        if feature_dim_name not in reference_mean.dims:
            raise ValueError(f"reference_mean must have dimension named {feature_dim_name}")
        # align and extract numeric array in feature order of data
        reference_mean_vals = reference_mean.reindex({feature_dim_name: data[feature_dim_name]}).values

    X_all = flatten_to_numpy(data, feature_dim_name)
    if valid is None:
        X_valid = X_all
    else:
        if list(valid.dims) != non_feature_dims:
            raise ValueError(f"Valid has dims {list(valid.dims)}\nShould have dims {non_feature_dims}\nOrder important too.")
        X_valid = X_all[flatten_to_numpy(valid)]
    n_valid = X_valid.shape[0]
    if n_valid < (n_modes + 1):
        raise ValueError("Too few valid samples for PCA; reduce n_modes or check coverage.")

    # subtract reference mean
    Xc_valid = X_valid - reference_mean_vals[None, :]

    # compute std_profile from valid subset if requested
    if standardize:
        std_profile_vals = Xc_valid.std(axis=0, ddof=1)
        # avoid zeros
        std_profile_vals = np.where(std_profile_vals == 0, 1.0, std_profile_vals)
        Xc_valid = Xc_valid / std_profile_vals[None, :]
    else:
        std_profile_vals = np.ones(n_feature)

    # --- SVD on the prepared valid data (no further centering) ---
    # Xc_valid shape: (n_valid, n_feature). compute thin SVD
    U, S, Vt = np.linalg.svd(Xc_valid, full_matrices=False)
    # components (EOFs) are rows of Vt; keep first n_modes
    components_vals = Vt[:n_modes, :]  # (n_modes, n_feature)

    # --- project ALL profiles using same transform ---
    # subtract reference mean and divide by std_profile (if standardize)
    Xc_all = X_all - reference_mean_vals[None, :]
    if standardize:
        Xc_all = Xc_all / std_profile_vals[None, :]

    # scores_all: (n_samples, n_modes)
    scores_all = Xc_all @ components_vals.T

    # reshape back to original non-feature dims + mode
    out_shape = [data.sizes[d] for d in non_feature_dims] + [n_modes]
    scores_da = xr.DataArray(
        scores_all.reshape(*out_shape),
        dims=non_feature_dims + ["mode"],
        coords={**{d: data[d] for d in non_feature_dims}, "mode": np.arange(n_modes)}
    )

    components_da = xr.DataArray(
        components_vals,
        dims=("mode", feature_dim_name),
        coords={"mode": np.arange(n_modes), feature_dim_name: data[feature_dim_name]}
    )

    mean_profile_da = xr.DataArray(reference_mean_vals, dims=(feature_dim_name,),
                                   coords={feature_dim_name: data[feature_dim_name]})
    std_profile_da = xr.DataArray(std_profile_vals, dims=(feature_dim_name,),
                                  coords={feature_dim_name: data[feature_dim_name]})

    # Variance explained by each mode
    var_explained = (S ** 2) / (Xc_valid.shape[0] - 1)

    # Fractional variance explained
    frac_var_explained = var_explained / var_explained.sum()

    return (components_da, scores_da, mean_profile_da, std_profile_da,
            var_explained[:n_modes], frac_var_explained[:n_modes])

scaled_k_means(x, initial_cluster_mean, valid=None, n_atom_select=1, norm_thresh=0, score_thresh=0.5, score_diff_thresh=0.1, score_diff_thresh_test_converge=0.05, score_thresh_multi_atom=0.05, min_cluster_size=10, n_iter=100, remove_perm=None, atom_ind_no_update=None, use_norm=False)

Perform scaled k-means clustering with optional multi-atom combinations.

This algorithm generalizes k-means by allowing each data point to be represented as a scaled combination of a small subset of cluster "atoms" (mean vectors), optionally including a zero vector (to allow sparse fits). At each iteration, coefficients for all possible atom combinations are computed to minimize residual norm, and cluster means are updated as the dominant direction of assigned samples’ residuals.
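
The two core updates can be sketched as follows (a schematic only, with hypothetical helper names; the actual implementation below additionally handles the zero atom, masking and bookkeeping):

import numpy as np

def fit_combination(x, atoms):
    # Least-squares coefficients so that coef @ atoms best approximates each row of x.
    # x: (n_sample, n_feature); atoms: (n_atom_select, n_feature) -> coef: (n_sample, n_atom_select)
    return np.linalg.solve(atoms @ atoms.T, atoms @ x.T).T

def update_atom(residuals):
    # New atom = leading eigenvector of the residual covariance, sign-fixed to have positive mean.
    eig_val, eig_vec = np.linalg.eigh(residuals.T @ residuals / len(residuals))
    v = eig_vec[:, -1]  # np.linalg.eigh returns eigenvalues in ascending order
    return v * np.sign(v.mean()), eig_val[-1]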

Parameters:

    x (ndarray): Input data of shape (n_sample, n_feature). Required.
    initial_cluster_mean (ndarray): Initial cluster centroids of shape (n_cluster, n_feature). Required.
    valid (Optional[ndarray]): Boolean mask (n_sample,) specifying valid samples for updates. Default: None.
    n_atom_select (int): Number of atoms combined to represent each sample. Default: 1.
    norm_thresh (float): Threshold for treating samples as small-norm (ignored in fitting). Default: 0.
    score_thresh (float): Minimum improvement (norm reduction) for a sample to influence a cluster update. Default: 0.5.
    score_diff_thresh (float): Minimum difference in score between the best and next-best atom to be considered distinct. Default: 0.1.
    score_diff_thresh_test_converge (float): Tolerance for the convergence test (difference between old and new best scores). Default: 0.05.
    score_thresh_multi_atom (float): Threshold for assigning multi-atom fits when the residual difference is small. Default: 0.05.
    min_cluster_size (int): Minimum number of samples required to update a cluster. Default: 10.
    n_iter (int): Maximum number of iterations. Default: 100.
    remove_perm (Optional[ndarray]): List of atom combinations (indices) to exclude. Default: None.
    atom_ind_no_update (Optional[ndarray]): Atom indices that should not be updated. Default: None.
    use_norm (bool): Whether to normalize each residual before updating atoms. Default: False.

Returns:

    norm_cluster_mean (ndarray): Updated normalized cluster mean vectors (atoms).
    cluster_eig_val (ndarray): Leading eigenvalues for each cluster.
    cluster_ind (ndarray): Cluster/combination index assigned to each sample.
    top_score (ndarray): Norm reduction score of the assigned combination for each sample.
    coef_best (ndarray): Coefficients of the best-fitting atom combination per sample.
    atom_perm (ndarray): Array of atom index combinations considered.

Notes
  • The algorithm can handle multi-atom fits by enumerating all valid atom combinations.
  • A zero vector is appended as an additional atom to allow sparse representations.
  • Clusters with fewer than min_cluster_size assigned samples are deactivated.
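
Example: a minimal usage sketch on synthetic data (the data generation is invented for illustration; only scaled_k_means and its return signature come from this module):

import numpy as np
# from isca_tools.utils.decomposition import scaled_k_means

rng = np.random.default_rng(0)
n_sample, n_feature = 500, 12

# Two hypothetical ground-truth directions; each sample is a random scaling of one of them.
true_atoms = rng.normal(size=(2, n_feature))
labels = rng.integers(0, 2, n_sample)
x = true_atoms[labels] * rng.uniform(0.5, 2.0, size=(n_sample, 1))

# Perturbed starting guesses for the two atoms.
initial_cluster_mean = true_atoms + 0.3 * rng.normal(size=true_atoms.shape)

atoms, eig_vals, cluster_ind, top_score, coef_best, atom_perm = scaled_k_means(
    x, initial_cluster_mean, n_atom_select=1, n_iter=50)

# atoms[:-1] are the fitted unit-norm directions (the last row is the appended zero atom);
# atom_perm[cluster_ind[i]] gives the atom combination assigned to sample i.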
Source code in isca_tools/utils/decomposition.py
def scaled_k_means(
    x: np.ndarray,
    initial_cluster_mean: np.ndarray,
    valid: Optional[np.ndarray] = None,
    n_atom_select: int = 1,
    norm_thresh: float = 0,
    score_thresh: float = 0.5,
    score_diff_thresh: float = 0.1,
    score_diff_thresh_test_converge: float = 0.05,
    score_thresh_multi_atom: float = 0.05,
    min_cluster_size: int = 10,
    n_iter: int = 100,
    remove_perm: Optional[np.ndarray] = None,
    atom_ind_no_update: Optional[np.ndarray] = None,
    use_norm: bool = False
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Perform scaled k-means clustering with optional multi-atom combinations.

    This algorithm generalizes k-means by allowing each data point to be represented
    as a *scaled combination* of a small subset of cluster "atoms" (mean vectors),
    optionally including a zero vector (to allow sparse fits). At each iteration,
    coefficients for all possible atom combinations are computed to minimize
    residual norm, and cluster means are updated as the dominant direction
    of assigned samples’ residuals.

    Args:
        x:
            Input data of shape (n_sample, n_feature).
        initial_cluster_mean:
            Initial cluster centroids of shape (n_cluster, n_feature).
        valid:
            Boolean mask (n_sample,) specifying valid samples for updates.
        n_atom_select:
            Number of atoms combined to represent each sample. Defaults to 1.
        norm_thresh:
            Threshold for treating samples as small-norm (ignored in fitting). Defaults to 0.
        score_thresh:
            Minimum improvement (norm reduction) for a sample to influence cluster update. Defaults to 0.5.
        score_diff_thresh:
            Minimum difference in score between best and next-best atom to be considered distinct. Defaults to 0.1.
        score_diff_thresh_test_converge:
            Tolerance for convergence test (difference between old and new best scores). Defaults to 0.05.
        score_thresh_multi_atom:
            Threshold for assigning multi-atom fits when residual difference is small. Defaults to 0.05.
        min_cluster_size:
            Minimum number of samples required to update a cluster. Defaults to 10.
        n_iter:
            Maximum number of iterations. Defaults to 100.
        remove_perm:
            List of atom combinations (indices) to exclude. Defaults to None.
        atom_ind_no_update:
            Atom indices that should not be updated. Defaults to None.
        use_norm:
            Whether to normalize each residual before updating atoms. Defaults to False.

    Returns:
        norm_cluster_mean: Updated normalized cluster mean vectors (atoms).
        cluster_eig_val: Leading eigenvalues for each cluster.
        cluster_ind: Cluster/combination index assigned to each sample.
        top_score: Norm reduction score of the assigned combination for each sample.
        coef_best: Coefficients of best-fitting atom combination per sample.
        atom_perm: Array of atom index combinations considered.

    Notes:
        - The algorithm can handle multi-atom fits by enumerating all valid atom combinations.
        - A zero vector is appended as an additional atom to allow sparse representations.
        - Clusters with fewer than `min_cluster_size` assigned samples are deactivated.
    """

    n_sample, n_feature = x.shape

    # Normalize initial cluster means (atoms)
    norm_cluster_mean = initial_cluster_mean / np.linalg.norm(initial_cluster_mean, axis=1).reshape(-1, 1)

    # Append a zero vector atom to allow sparse/no-fit representations
    norm_cluster_mean = np.vstack([norm_cluster_mean, np.zeros(n_feature)])
    n_atom = norm_cluster_mean.shape[0]

    # Initialize containers
    cluster_eig_val = np.zeros(n_atom)
    cluster_ind = np.full(x.shape[0], -20, dtype=int)
    x_norm = np.linalg.norm(x, axis=1)

    # Identify samples with very small norms — skip coefficient computation for them
    small_norm = x_norm <= norm_thresh

    # Generate all possible atom index combinations (permutations of n_atom_select atoms)
    atom_perm = np.array(list(itertools.combinations(range(n_atom), n_atom_select)))
    atom_perm = np.sort(atom_perm, axis=1)  # Sort to ensure zero atom appears last consistently

    # Optionally remove forbidden combinations
    if remove_perm is not None:
        remove_perm = np.sort(remove_perm, axis=1)
        mask = ~np.isin(
            atom_perm.view([('', atom_perm.dtype)] * atom_perm.shape[1]),
            remove_perm.view([('', atom_perm.dtype)] * remove_perm.shape[1])
        ).squeeze()
        if (~mask).sum() > 0:
            print(f"Removing the following atom permutations:\n{atom_perm[~mask]}")
        atom_perm = atom_perm[mask]

    if atom_ind_no_update is None:
        atom_ind_no_update = np.zeros(0, dtype=int)

    n_perm = len(atom_perm)

    # Identify all permutations that include the zero atom
    perm_zero_ind = np.where([n_atom - 1 in atom_perm[i] for i in range(n_perm)])[0].squeeze()

    # Track permutations to ignore (e.g., if corresponding atoms become inactive)
    ignore_perm = np.zeros(n_perm, dtype=bool)

    for i in range(np.clip(n_iter, 1, 1000)):
        coef = np.zeros((n_sample, n_perm, n_atom_select))  # coefficients for each permutation

        # --- Step 1: Compute coefficients for all permutations ---
        for j in range(n_perm):
            if ignore_perm[j]:
                continue

            if j in perm_zero_ind:
                if n_atom_select > 1:
                    # Compute coefficients for non-zero atoms in combinations including zero
                    A = norm_cluster_mean[atom_perm[j][:-1]]
                    AAT_inv = np.linalg.inv(A @ A.T)
                    coef[~small_norm, j, :-1] = (AAT_inv @ A @ x[~small_norm].T).T
            else:
                # Compute coefficients for full atom combinations
                A = norm_cluster_mean[atom_perm[j]]
                AAT_inv = np.linalg.inv(A @ A.T)
                coef[~small_norm, j] = (AAT_inv @ A @ x[~small_norm].T).T

        cluster_ind_old = cluster_ind.copy()

        # --- Step 2: Compute residuals and assign each sample to the best combination ---
        x_residual = x[:, None] - (coef[..., None] * norm_cluster_mean[atom_perm][None]).sum(axis=-2)
        x_residual_norm = np.linalg.norm(x_residual, axis=-1)

        # Compute fractional norm reduction
        norm_reduction = (x_norm[:, None] - x_residual_norm) / (x_norm[:, None] + 1e-20)

        # Choose combination with smallest residual
        cluster_ind = x_residual_norm.argmin(axis=1)

        # If multi-atom case, prefer those with near-zero residuals that include zero atom
        if n_atom_select > 1:
            good_with_zero = x_residual_norm[:, perm_zero_ind].min(axis=1) <= norm_thresh
            good_with_zero |= (
                norm_reduction.max(axis=1) - norm_reduction[:, perm_zero_ind].max(axis=1) < score_thresh_multi_atom
            )
            cluster_ind[good_with_zero] = perm_zero_ind[
                x_residual_norm[good_with_zero][:, perm_zero_ind].argmin(axis=1)
            ]

        # Assign -1 for samples below norm threshold
        cluster_ind[small_norm] = -1

        # Top score per sample (how much norm was reduced)
        top_score = norm_reduction[np.arange(n_sample), cluster_ind]
        top_score[x_norm <= norm_thresh] = 0

        if n_iter == 0:
            print('n_iter=0 so not updating atoms')
            break

        # --- Step 3: Identify strong assignments to guide cluster updates ---
        score_exclude_atom = [
            best_score_excluding_atom(norm_reduction, atom_perm, atom_perm[cluster_ind][:, k])
            for k in range(n_atom_select)
        ]
        high_score = [
            (top_score > score_thresh) & (top_score - score_exclude_atom[k] > score_diff_thresh)
            for k in range(n_atom_select)
        ]

        # Convergence test: low score difference means cluster assignment has stabilized
        low_score = [
            top_score - score_exclude_atom[k] < score_diff_thresh_test_converge
            for k in range(n_atom_select)
        ]
        low_score = np.any(low_score, axis=0)

        if valid is not None:
            # Restrict updates to valid samples
            high_score = [high_score[k] & valid for k in range(n_atom_select)]
            low_score = low_score | ~valid

        # --- Step 4: Update cluster means (atoms) ---
        for c in range(n_atom - 1):  # skip zero atom
            if c in atom_ind_no_update:
                continue

            my_points = np.zeros((0, n_feature))

            # Collect residuals corresponding to this atom
            for k in range(n_atom_select):
                samples_use = (cluster_ind >= 0) & (atom_perm[cluster_ind, k] == c) & high_score[k]
                if samples_use.sum() > 0:
                    x_use_fit = coef[samples_use, cluster_ind[samples_use], :, None] * norm_cluster_mean[
                        atom_perm[cluster_ind[samples_use]]
                    ]
                    x_use_fit = np.delete(x_use_fit, k, axis=1)
                    x_use_fit = x_use_fit.sum(axis=1)
                    my_points = np.append(my_points, x[samples_use] - x_use_fit, axis=0)

            n_my_points = my_points.shape[0]

            # Deactivate cluster if too few points assigned
            if n_my_points < min_cluster_size:
                norm_cluster_mean[c] = 0
                ignore_perm[np.where([c in atom_perm[k] for k in range(n_perm)])[0].squeeze()] = True
                continue

            if use_norm:
                # Normalize residuals to equalize influence
                my_points = my_points / (np.linalg.norm(my_points, axis=1)[:, None] + 1e-20)

            # Update atom as the leading eigenvector of covariance matrix
            eig_vals, eigs = np.linalg.eig(my_points.T @ my_points / n_my_points)
            best_eig_ind = np.argmax(eig_vals)
            norm_cluster_mean[c] = eigs[:, best_eig_ind] * np.sign(eigs[:, best_eig_ind].mean())
            cluster_eig_val[c] = eig_vals[best_eig_ind]

        # Print number of reassignments to monitor convergence
        print(i + 1, (cluster_ind[~low_score] != cluster_ind_old[~low_score]).sum())

        # Stop if cluster assignments have stabilized
        if (cluster_ind[~low_score] == cluster_ind_old[~low_score]).all():
            print(f"Done after {i + 1} iter")
            break

    # Return best-fit coefficients for each sample
    coef_best = coef[np.arange(x.shape[0]), cluster_ind]

    return norm_cluster_mean, cluster_eig_val, cluster_ind, top_score, coef_best, atom_perm