Skip to content

Xarray

convert_ds_dtypes(ds, verbose=False)

Convert all float variables to float32 and all int variables to int32 in an xarray Dataset.

Parameters:

Name Type Description Default
ds Dataset

Input xarray Dataset.

required
verbose bool

Whether to print the names of the variables that were converted.

False

Returns:

Name Type Description
ds_out Dataset

Dataset with all float variables converted to float32 and all int variables to int32.

Source code in isca_tools/utils/xarray.py
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def convert_ds_dtypes(ds: xr.Dataset, verbose: bool = False) -> xr.Dataset:
    """
    Convert all float variables to float32 and all int variables to int32 in an xarray Dataset.

    Only the entries of `ds.data_vars` are inspected. Variables that are already `float32` / `int32`,
    or whose dtype is neither floating nor integer, are left as they are. The conversion is a plain
    `astype` call, so no check is made that the values fit in the narrower type.

    Args:
        ds: Input xarray Dataset.
        verbose: Whether to print the names of the variables that were converted.

    Returns:
        ds_out: Dataset with all float variables converted to float32 and all int variables to int32.
    """
    converted = {}
    float_conv = []
    int_conv = []
    for var_name, da in ds.data_vars.items():
        if np.issubdtype(da.dtype, np.floating) and da.dtype != np.float32:
            converted[var_name] = da.astype(np.float32)
            float_conv.append(var_name)
        elif np.issubdtype(da.dtype, np.integer) and da.dtype != np.int32:
            converted[var_name] = da.astype(np.int32)
            int_conv.append(var_name)
        # Anything else keeps its dtype, so there is nothing to re-assign.
    if verbose:
        if float_conv:
            print(f"Converted the following float variables:\n{float_conv}")
        if int_conv:
            print(f"Converted the following integer variables:\n{int_conv}")
    # Pass the mapping positionally rather than as **kwargs: keyword expansion requires every
    # variable name to be a string and would misroute a variable that happens to be called
    # "variables" into `assign`'s own positional parameter.
    return ds.assign(converted)

flatten_to_numpy(var, keep_dim=None)

Flattens var to a numpy array with at most 2 dimensions.

Examples:

If var has dims=(lat, lon, lev) and keep_dim=lev, it will return a numpy array of size [n_lat*n_lon, n_lev].

If var has dims=(lat, lon) and keep_dim=None, it will return a numpy array of size [n_lat*n_lon].

Parameters:

Name Type Description Default
var DataArray

Variable to flatten.

required
keep_dim Optional[str]

Dimension along which not to flatten.

None

Returns:

Name Type Description
var_flatten ndarray

Numpy array with flattened dimension first, and keep_dim dimension second.

Source code in isca_tools/utils/xarray.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def flatten_to_numpy(var: xr.DataArray, keep_dim: Optional[str] = None) -> np.ndarray:
    """
    Flattens `var` to a numpy array with at most 2 dimensions.

    All dimensions other than `keep_dim` are collapsed, in their original order and in C (row-major)
    order, into a single leading axis. No temporary stacked dimension is created on `var`, so the
    dimension names of `var` are unrestricted.

    Examples:
        If `var` has `dims=(lat, lon, lev)` and `keep_dim=lev`, it will return a numpy array of
            size `[n_lat*n_lon, n_lev]`.

        If `var` has `dims=(lat, lon)` and `keep_dim=None`, it will return a numpy array of
            size `[n_lat*n_lon]`.

        If `var` has `dims=(lev,)` and `keep_dim=lev`, it will return a numpy array of
            size `[1, n_lev]`.

    Args:
        var: Variable to flatten.
        keep_dim: Dimension along which not to flatten.

    Returns:
        var_flatten: Numpy array with flattened dimension first, and `keep_dim` dimension second.

    Raises:
        ValueError: If `keep_dim` is given but is not a dimension of `var`.
    """
    if (keep_dim is not None) and (keep_dim not in var.dims):
        raise ValueError(f"var must have a '{keep_dim}' dimension")

    # Every dimension except the one to keep, in the order they appear in var
    flatten_dims = [d for d in var.dims if d != keep_dim]

    # Length of the flattened axis; the product of an empty list is 1, so a variable with no
    # dimensions left to flatten still gets a leading axis of length 1.
    n_points = int(np.prod([var.sizes[d] for d in flatten_dims], dtype=int))

    if keep_dim is None:
        return var.values.reshape(n_points)

    # Move keep_dim last, then collapse everything in front of it into one axis -> (points, keep_dim)
    ordered = var.transpose(*flatten_dims, keep_dim)
    return ordered.values.reshape(n_points, var.sizes[keep_dim])

print_ds_var_list(ds, phrase=None)

Prints all variables in ds which contain phrase in the variable name or variable long_name.

Parameters:

Name Type Description Default
ds Dataset

Dataset to investigate variables of.

required
phrase Optional[str]

Key phrase to search for in variable info.

None
Source code in isca_tools/utils/xarray.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def print_ds_var_list(ds: xr.Dataset, phrase: Optional[str] = None) -> None:
    """
    Prints all variables in `ds` which contain `phrase` in the variable name or variable `long_name`.

    The search is case-insensitive. Each match is printed as `name: long_name`, or just `name` when
    the variable matched on its name and has no `long_name`. With `phrase=None` every variable is printed.

    Args:
        ds: Dataset to investigate variables of.
        phrase: Key phrase to search for in variable info.

    """
    def _show(name) -> None:
        # Fall back to the bare name for variables that carry no long_name
        try:
            print(f'{name}: {ds[name].long_name}')
        except AttributeError:
            print(f'{name}')

    names = list(ds.keys())

    if phrase is None:
        for name in names:
            _show(name)
        return None

    for name in names:
        needle = phrase.lower()
        if needle in name.lower():
            # Matched on the variable name itself
            _show(name)
            continue
        try:
            long_name = ds[name].long_name
            matched = needle in long_name.lower()
        except AttributeError:
            # No usable long_name to search in, so this variable cannot match
            continue
        if matched:
            print(f'{name}: {long_name}')
    return None

set_attrs(var, overwrite=True, **kwargs)

Set attributes of a given variable.

Examples:

set_attrs(ds.plev, long_name='pressure', units='Pa')

Parameters:

Name Type Description Default
var DataArray

Variable to set attributes of.

required
overwrite bool

If True, overwrite existing attributes, otherwise leave unchanged.

True
**kwargs str

Attributes to set. Common ones include long_name and units

{}

Returns:

Type Description
DataArray

var with attributes set.

Source code in isca_tools/utils/xarray.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
def set_attrs(var: xr.DataArray, overwrite: bool = True, **kwargs: str) -> xr.DataArray:
    """
    Set attributes of a given variable.

    The attributes are written into `var.attrs` in place, and the same object is returned.

    Examples:
        `set_attrs(ds.plev, long_name='pressure', units='Pa')`

    Args:
        var: Variable to set attributes of.
        overwrite: If `True`, overwrite existing attributes, otherwise leave unchanged.
        **kwargs: Attributes to set. Common ones include `long_name` and `units`

    Returns:
        `var` with attributes set.
    """
    for attr_name, attr_value in kwargs.items():
        # An attribute that already exists is only replaced when overwriting is allowed
        if overwrite or attr_name not in var.attrs:
            var.attrs[attr_name] = attr_value
    return var

unflatten_from_numpy(arr, var, keep_dim=None)

Reconstructs an xarray.DataArray from a flattened NumPy array created by flatten_to_numpy.

Examples:

If var had dims=(lat, lon, lev) and keep_dim='lev', and arr has shape (n_lat*n_lon, n_lev), this will return a DataArray with dims (lat, lon, lev).

If var had dims=(lat, lon) and keep_dim=None, and arr has shape (n_lat*n_lon), this will return a DataArray with dims (lat, lon).

Parameters:

Name Type Description Default
arr ndarray

Flattened NumPy array from flatten_to_numpy.

required
var DataArray

The original DataArray used to determine dimension order, shape, and coordinates.

required
keep_dim Optional[str]

Dimension that was kept unflattened in flatten_to_numpy.

None

Returns:

Type Description
DataArray

xr.DataArray: DataArray with the original dimensions and coordinates restored.

Source code in isca_tools/utils/xarray.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def unflatten_from_numpy(arr: np.ndarray, var: xr.DataArray, keep_dim: Optional[str] = None) -> xr.DataArray:
    """
    Reconstructs an xarray.DataArray from a flattened NumPy array created by `flatten_to_numpy`.

    `arr` is reshaped to the sizes of `var`'s dimensions (flattened dimensions first, `keep_dim` last),
    wrapped in a DataArray, and finally transposed back to the dimension order of `var`.
    The result carries the attributes of `var` and the coordinates of `var` which are named after
    one of its dimensions; other coordinates of `var` are not copied.

    Examples:
        If `var` had `dims=(lat, lon, lev)` and `keep_dim='lev'`, and `arr` has shape `(n_lat*n_lon, n_lev)`,
        this will return a DataArray with dims `(lat, lon, lev)`.

        If `var` had `dims=(lat, lon)` and `keep_dim=None`, and `arr` has shape `(n_lat*n_lon)`,
        this will return a DataArray with dims `(lat, lon)`.

    Args:
        arr: Flattened NumPy array from `flatten_to_numpy`.
        var: The original DataArray used to determine dimension order, shape, and coordinates.
        keep_dim: Dimension that was kept unflattened in `flatten_to_numpy`.

    Returns:
        xr.DataArray: DataArray with the original dimensions and coordinates restored.

    Raises:
        ValueError: If `keep_dim` is given but is not a dimension of `var`, or if `arr` does not
            have the right number of elements to be reshaped to the dimensions of `var`.
    """
    # Validate keep_dim
    if (keep_dim is not None) and (keep_dim not in var.dims):
        raise ValueError(f"var must have a '{keep_dim}' dimension")

    # Dimension order of the flattened layout: the flattened dims first, then keep_dim (if any)
    dims = [d for d in var.dims if d != keep_dim]
    if keep_dim is not None:
        dims.append(keep_dim)
    target_shape = [var.sizes[d] for d in dims]

    # Undo the flattening of the leading axis
    try:
        reshaped = arr.reshape(target_shape)
    except ValueError as err:
        raise ValueError(
            f"arr with shape {arr.shape} cannot be reshaped to {tuple(target_shape)}, "
            f"the sizes of dims {tuple(dims)} of var"
        ) from err

    # Re-attach the dimension coordinates and attributes of var
    da_flat = xr.DataArray(reshaped, dims=dims, coords={d: var.coords[d] for d in dims if d in var.coords},
                           attrs=var.attrs)

    # keep_dim was moved to the end when flattening, so put the dimensions back in var's order
    return da_flat.transpose(*var.dims)