Skip to content

Definitions

Setup and basic functions

This file contains commonly used definitions of paths and compute options, gotten from the file $HOME/.jetutils.ini if it exists, otherwise guessed.

It also contains all the constants to do physics, the common timeranges, the full names of jet variables as well as their units, default values and LaTeX symbols.

Finally, it contains a few functions that are useful all over.

Timer dataclass

This is stolen from a gist somewhere I don't remember. Nice context manager timer.

Raises:

Type Description
TimerError

Examples:

>>> with Timer():
...    do_something_long()
Elapsed time: 5.3000 seconds
Source code in jetutils/definitions.py
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
@dataclass
class Timer:
    """
    Context-manager (or manual start/stop) timer, adapted from a gist whose source I don't remember.

    Named timers accumulate their elapsed time in the class-level ``timers`` dict,
    which is shared by every ``Timer`` instance.

    Attributes
    ----------
    name : str, optional
        If not None, every elapsed time is added to ``Timer.timers[name]``.
    text : str
        Format string used to report the elapsed time, in seconds.
    logger : callable, optional
        Called with the formatted report on ``stop()``; pass None to stay silent.

    Raises
    ------
    TimerError
        If ``start()`` is called on a running timer, or ``stop()`` on a stopped one.

    Examples
    --------
    >>> with Timer():  # doctest: +SKIP
    ...     do_something_long()
    Elapsed time: 5.3000 seconds
    """

    timers: ClassVar[Dict[str, float]] = {}
    name: Optional[str] = None
    text: str = "Elapsed time: {:0.4f} seconds"
    logger: Optional[Callable[[str], None]] = print
    _start_time: Optional[float] = field(default=None, init=False, repr=False)

    def __post_init__(self) -> None:
        """Register a named timer in ``Timer.timers`` (starting at 0) after initialization."""
        if self.name is not None:
            self.timers.setdefault(self.name, 0)

    def start(self) -> None:
        """Start a new timer; raises TimerError if it is already running."""
        if self._start_time is not None:
            raise TimerError("Timer is running. Use .stop() to stop it")

        self._start_time = time.perf_counter()

    def stop(self) -> float:
        """
        Stop the timer, report the elapsed time and return it in seconds.

        Raises TimerError if the timer is not running.
        """
        if self._start_time is None:
            raise TimerError("Timer is not running. Use .start() to start it")

        # Calculate elapsed time
        elapsed_time = time.perf_counter() - self._start_time
        self._start_time = None

        # Report elapsed time
        if self.logger:
            self.logger(self.text.format(elapsed_time))
        # Same `is not None` test as __post_init__, so every registered name
        # (including "") accumulates; `.get` tolerates a cleared `timers` dict.
        if self.name is not None:
            self.timers[self.name] = self.timers.get(self.name, 0) + elapsed_time

        return elapsed_time

    def __enter__(self) -> "Timer":
        """Start a new timer as a context manager and return it."""
        self.start()
        return self

    def __exit__(self, *exc_info) -> None:
        """Stop the context manager timer; exceptions from the block are not suppressed."""
        self.stop()

__enter__()

Start a new timer as a context manager

Source code in jetutils/definitions.py
1397
1398
1399
1400
def __enter__(self):
    """
    Begin timing on entry to a ``with`` block.

    Returns the timer itself, so ``with Timer() as t:`` binds the running timer.
    """
    timer = self
    timer.start()
    return timer

__exit__(*exc_info)

Stop the context manager timer

Source code in jetutils/definitions.py
1402
1403
1404
def __exit__(self, *exc_info):
    """
    Stop timing when the ``with`` block is left.

    `exc_info` is ignored, and returning None means any exception raised inside
    the block keeps propagating.
    """
    self.stop()
    return None

__post_init__()

Add timer to dict of timers after initialization

Source code in jetutils/definitions.py
1368
1369
1370
1371
def __post_init__(self) -> None:
    """Register a named timer in the shared ``timers`` dict with a starting total of 0."""
    timer_name = self.name
    if timer_name is None:
        return
    # setdefault keeps any total already accumulated under this name.
    self.timers.setdefault(timer_name, 0)

start()

Start a new timer

Source code in jetutils/definitions.py
1373
1374
1375
1376
1377
1378
def start(self) -> None:
    """Record the current `time.perf_counter()` as the start time; refuse if already running."""
    if self._start_time is None:
        self._start_time = time.perf_counter()
        return
    raise TimerError("Timer is running. Use .stop() to stop it")

stop()

Stop the timer, and report the elapsed time

Source code in jetutils/definitions.py
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
def stop(self) -> float:
    """
    Stop the timer, report the elapsed time and return it in seconds.

    Raises TimerError if the timer is not running.
    """
    if self._start_time is None:
        raise TimerError("Timer is not running. Use .start() to start it")

    # Calculate elapsed time
    elapsed_time = time.perf_counter() - self._start_time
    self._start_time = None

    # Report elapsed time
    if self.logger:
        self.logger(self.text.format(elapsed_time))
    # Same `is not None` test as __post_init__, so every registered name
    # (including "") accumulates; `.get` tolerates a cleared `timers` dict.
    if self.name is not None:
        self.timers[self.name] = self.timers.get(self.name, 0) + elapsed_time

    return elapsed_time

TimerError

Bases: Exception

A custom exception used to report errors in use of Timer class

Source code in jetutils/definitions.py
1342
1343
class TimerError(Exception):
    """Raised when a Timer is misused: started while running, or stopped while not running."""

    pass

case_insensitive_equal(str1, str2)

Returns whether two strings are equal if all letters are lowercased.

Examples:

>>> case_insensitive_equal("AbC", "aBc")
True
Source code in jetutils/definitions.py
711
712
713
714
715
716
717
718
719
720
def case_insensitive_equal(str1: str, str2: str) -> bool:
    """
    Tell whether `str1` and `str2` match once case differences are ignored.

    The comparison uses `str.casefold`, which is stricter than `str.lower`
    (e.g. German "ß" folds to "ss").

    Examples
    --------
    >>> case_insensitive_equal("AbC", "aBc")
    True
    """
    left, right = (text.casefold() for text in (str1, str2))
    return left == right

compute(obj, progress_flag=False, **kwargs)

Computes a Dask object. If a dask client named client exists in the globals, uses it.

Parameters:

Name Type Description Default
obj Any

Dask object to force compute

required
progress_flag bool

Whether to show a progress bar, by default False

False
kwargs

Keyword arguments passed to obj.compute() if no client exists

{}

Returns:

Name Type Description
obj Any

Computed object.

Source code in jetutils/definitions.py
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
def compute(obj, progress_flag: bool = False, **kwargs):
    """
    Computes a Dask object. If a dask client named `client` exists in the globals, uses it.

    `client` is looked up as a plain name, so it resolves in the globals of the
    module defining this function (then the builtins), not in the caller's namespace.

    Parameters
    ----------
    obj : Any
        Dask object to force compute
    progress_flag : bool, optional
        Whether to show a progress bar, by default False
    kwargs
        Keyword arguments passed to `obj.compute()` if no client exists. They are
        merged over `COMPUTE_KWARGS` (explicit `kwargs` take precedence) and are
        not forwarded to the client in the client branch.

    Returns
    -------
    obj : Any
        Computed object. If an `AttributeError` is raised anywhere in the chosen
        branch (e.g. `obj` has no `.compute()`), `obj` is returned unchanged.
        NOTE(review): without `progress_flag`, the client branch returns whatever
        `client.compute(obj)` returns; with a dask.distributed Client that is a
        Future rather than the computed value — confirm callers expect this.
    """
    kwargs = COMPUTE_KWARGS | kwargs  # explicit kwargs override the defaults
    try:
        client  # in globals # type: ignore # noqa: F821
    except NameError:
        # No client: compute locally, optionally inside dask's ProgressBar.
        try:
            if progress_flag:
                with ProgressBar():
                    return obj.compute(**kwargs)
            else:
                return obj.compute(**kwargs)
        except AttributeError:
            # Also catches an AttributeError raised by the computation itself.
            return obj
    try:
        if progress_flag:
            # NOTE(review): `progress` is called after `gather`, i.e. once the
            # result is already local — confirm the bar still shows anything useful.
            obj = client.gather(client.persist(obj))  # type: ignore # noqa: F821
            progress(obj, notebook=False)
            return obj
        else:
            return client.compute(obj)  # type: ignore # noqa: F821
    except AttributeError:
        return obj

degcos(x)

Cosine of an angle expressed in degrees

Parameters:

Name Type Description Default
x float

Angle in degrees

required

Returns:

Type Description
float

Cosine result

Source code in jetutils/definitions.py
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
def degcos(x: float | np.ndarray) -> float:
    """
    Cosine of an angle expressed in degrees

    Parameters
    ----------
    x : float
        Angle in degrees

    Returns
    -------
    float
        Cosine result
    """
    return np.cos(np.radians(x))

degsin(x)

Sine of an angle expressed in degrees

Parameters:

Name Type Description Default
x float

Angle in degrees

required

Returns:

Type Description
float

Sine results

Source code in jetutils/definitions.py
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
def degsin(x: float) -> float:
    """
    Sine of an angle expressed in degrees

    Parameters
    ----------
    x : float
        Angle in degrees

    Returns
    -------
    float
        Sine results
    """
    return np.sin(np.radians(x))

do_rle_fill_hole(df, condition_expr, group_by=None, hole_size=4, unwrap=False)

Wraps around polars' pl.Expr.rle() to find runs of identical values, potentially interrupted by a different value, as long as this interruption is shorter than hole_size.

It can do it for the whole DataFrame or in groups specified by group_by.

Parameters:

Name Type Description Default
df DataFrame

Input DataFrame

required
condition_expr Expr

Expression that evaluates to True or False from one or several columns of df

required
group_by Sequence[str] | Sequence[Expr] | str | Expr

Columns to group by, by default None

None
hole_size int | timedelta

Maximum authorized size of holes that can be in a run without interrupting it, by default 4

4
unwrap bool

If False, returns the whole data as a modified run length encoded DataFrame. If True, returns the True runs exploded. By default False

False

Returns:

Type Description
DataFrame

Modified-run-length-encoded input, or exploded True runs.

Raises:

Type Description
ValueError

If hole_size is specified as a datetime.timedelta but there is no "time", or "time" is in group_by.

Source code in jetutils/definitions.py
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
def do_rle_fill_hole(
    df: pl.DataFrame,
    condition_expr: pl.Expr,
    group_by: Sequence[str] | Sequence[pl.Expr] | str | pl.Expr | None = None,
    hole_size: int | datetime.timedelta = 4,
    unwrap: bool = False,
) -> pl.DataFrame:
    """
    Wraps around polars' `pl.Expr.rle()` to find runs of identical values, potentially interrupted by a different value, as long as this interruption is at most `hole_size` long.

    It can do it for the whole DataFrame or in groups specified by `group_by`. Only holes
    strictly inside a group are filled: a short False run at the very start or at the very
    end of a group is left untouched, so runs are never extended past a group's edges.

    Parameters
    ----------
    df : pl.DataFrame
        Input DataFrame
    condition_expr : pl.Expr
        Expression that evaluates to True or False from one or several columns of `df`
    group_by : Sequence[str] | Sequence[pl.Expr] | str | pl.Expr, optional
        Columns to group by, by default None, in which case the `"member"` and `"cluster"`
        columns present in `df` are used (or a single group if there is none).
    hole_size : int | datetime.timedelta, optional
        Maximum authorized size of holes that can be in a run without interrupting it, by default 4.
        A `datetime.timedelta` is converted to a number of rows using the spacing between the two
        earliest distinct values of `"time"`; in that case a time jump larger than this spacing
        also breaks runs.
    unwrap : bool, optional
        If False, returns the whole data as a modified run length encoded DataFrame. If True, returns the True runs exploded. By default False

    Returns
    -------
    pl.DataFrame
        Modified-run-length-encoded input, or exploded True runs.

    Raises
    ------
    ValueError
        If `hole_size` is specified as a `datetime.timedelta` but there is no `"time"` column,
        `"time"` is in `group_by`, or `"time"` has fewer than two distinct values.
    """
    if isinstance(group_by, str | pl.Expr):
        group_by = [group_by]
    to_drop: list[str] = []
    if not group_by:
        group_by = []
        group_by.extend(get_index_columns(df, ["member", "cluster"]))

    if not group_by:
        # No natural grouping column: use a constant one so `.over()` still works.
        df = df.with_columns(dummy=1)
        group_by.append("dummy")
        to_drop.append("dummy")
    # Row position within its group, used to join the filled holes back.
    df = df.with_columns(
        index=pl.int_range(0, pl.col(group_by[0]).len()).cast(pl.UInt32).over(group_by)
    )
    # df1 holds the negated condition: its True runs are the candidate holes.
    df1: pl.DataFrame = df.with_columns(condition=condition_expr.not_().over(group_by))
    df = df.with_columns(condition=condition_expr.over(group_by))
    if isinstance(hole_size, datetime.timedelta):
        if "time" not in df.columns or "time" in group_by:
            raise ValueError(
                "A timedelta hole_size needs a 'time' column that is not part of group_by"
            )
        times = df["time"].unique().bottom_k(2).sort()
        if len(times) < 2:
            raise ValueError(
                "Cannot infer the time step: 'time' has fewer than two distinct values"
            )
        dt = times[1] - times[0]
        hole_size = int(hole_size / dt)
        # A gap in time interrupts both the runs and the holes.
        no_time_jump_expr = (pl.col("time").diff() <= dt).fill_null(True)
        df = df.with_columns(condition=pl.col("condition") & no_time_jump_expr.over(group_by))
        df1 = df1.with_columns(condition=pl.col("condition") & no_time_jump_expr.over(group_by))

    holes_to_fill = do_rle(df1, group_by=group_by)
    # A hole must be enclosed: neither the first run of its group (start > 0)
    # nor the last one (it must end before the group's last run ends).
    run_end = pl.col("start") + pl.col("len")
    holes_to_fill = holes_to_fill.filter(
        pl.col("len") <= hole_size,
        pl.col("value"),
        pl.col("start") > 0,
        run_end < run_end.max().over(group_by),
    )
    holes_to_fill = (
        explode_rle(holes_to_fill)
        .with_columns(condition=pl.lit(True))
        .drop("len", "start", "value")
    )
    df = df.join(holes_to_fill, on=[*group_by, "index"], how="left")
    df = df.with_columns(
        condition=pl.when(pl.col("condition_right").is_not_null())
        .then(pl.col("condition_right"))
        .otherwise(pl.col("condition"))
    ).drop("condition_right", "index")
    df = do_rle(df, group_by=group_by)

    if not unwrap:
        return df.drop(*to_drop)

    df = df.filter("value")
    to_drop.extend(["len", "start", "value"])
    df = explode_rle(df)
    return df.drop(to_drop)

extract_season_from_df(df, season=None)

Subsets a DataFrame containing a "time" column to a given season.

Source code in jetutils/definitions.py
679
680
681
682
683
684
685
686
687
688
689
690
691
692
def extract_season_from_df(
    df: pl.DataFrame,
    season: list | str | tuple | int | None = None,
) -> pl.DataFrame:
    """
    Subsets a DataFrame containing a `"time"` column to a given season.

    `season` can be None (no filtering), a key of `SEASONS`, a single month
    number, or a collection of month numbers.
    """
    if season is None:
        return df
    months = SEASONS[season] if isinstance(season, str) else season
    if isinstance(months, int):
        months = [months]
    month_of_row = pl.col("time").dt.month()
    return df.filter(month_of_row.is_in(months))

first_elements(arr, n_elements, sort=False)

Get the indices of the n_elements smallest values of arr, searched over the whole (flattened) array.

Parameters:

Name Type Description Default
arr ndarray

Any array

required
n_elements int

Number of elements to return along each axis

required
sort bool

Sort the output, only valid for 1D arr, by default False

False

Returns:

Type Description
ndarray

Raises:

Type Description
RuntimeWarning

If sort=True and arr.ndim > 1 because it's ambiguous.

Source code in jetutils/definitions.py
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
def first_elements(arr: np.ndarray, n_elements: int, sort: bool = False) -> np.ndarray:
    """
    Get the smallest `n_elements` of `arr`, along the last axis.

    Parameters
    ----------
    arr : np.ndarray
        Any array
    n_elements : int
        Number of elements to return along each axis
    sort : bool, optional
        Sort the output, only valid for 1D `arr`, by default False

    Returns
    -------
    np.ndarray

    Raises
    ------
    RuntimeWarning
        If `sort=True` and `arr.ndim > 1` because it's ambiguous.
    """
    ndim = arr.ndim
    if ndim > 1 and sort:
        print("sorting output not supported for arrays with ndim > 1")
        sort = False
        raise RuntimeWarning
    idxs = np.argpartition(arr.ravel(), n_elements)[:n_elements]
    if ndim > 1:
        return np.unravel_index(idxs, arr.shape)
    if sort:
        return idxs[np.argsort(arr[idxs])]
    return idxs

get_index_columns(df, potentials=None)

Finds columns in df that carry index information rather than data, in the context of this package.

Parameters:

Name Type Description Default
df DataFrame

Any DataFrame

required
potentials tuple

Potential names of column indices, by default ( "member", "time", "cluster", "jet ID", "spell", "relative_index", "relative_time", "sample_index", "inside_index", )

None

Returns:

Type Description
list

list of columns in potentials that are columns in df.

Source code in jetutils/definitions.py
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
def get_index_columns(
    df,
    potentials: list[str] | tuple[str] | None = None,
) -> list[str]:
    """
    Finds columns in `df` that represent an index imformation more than a data information in the context of this package.

    Parameters
    ----------
    df : pl.DataFrame
        Any DataFrame
    potentials : tuple, optional
        Potential names of column indices, by default ( "member", "time", "cluster", "jet ID", "spell", "relative_index", "relative_time", "sample_index", "inside_index", )

    Returns
    -------
    list
        list of columns in `potentials` that are columns in `df`.
    """
    if potentials is None:
        potentials: tuple[str] = default_index_columns
    index_columns: list[str] = [ic for ic in potentials if ic in df.collect_schema().names()]
    return index_columns

get_region(da)

Extracts the lon-lat region spanned by an xarray object containing the "lon" and "lat" dimensions.

Parameters:

Name Type Description Default
da DataArray | Dataset

Xarray object

required

Returns:

Name Type Description
minlon float

minimum longitude

maxlon float

maximum longitude

minlat float

minimum latitude

maxlat float

maximum latitude

Source code in jetutils/definitions.py
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
def get_region(da: xr.DataArray | xr.Dataset) -> tuple:
    """
    Extracts the lon-lat region spanned by an xarray object containing the `"lon"` and `"lat"` dimensions.

    If the sorted longitudes contain a gap of at least twice the first spacing
    (e.g. a domain wrapping around the end of the longitude range), the values
    are rotated so the largest gap is the boundary, and `minlon` may then exceed
    `maxlon`. NOTE(review): this assumes the spacing between the two smallest
    longitudes is representative of the grid.

    Parameters
    ----------
    da : xr.DataArray | xr.Dataset
        Xarray object

    Returns
    -------
    minlon: float
        minimum (western-most) longitude

    maxlon: float
        maximum (eastern-most) longitude

    minlat: float
        minimum latitude

    maxlat: float
        maximum latitude
    """
    lon = np.sort(da.lon.values)
    minlat: float = da.lat.min().item()
    maxlat: float = da.lat.max().item()
    if lon.size < 2:
        # A single longitude: no spacing to infer (lon[1] would not exist).
        return lon[0].item(), lon[-1].item(), minlat, maxlat
    dlon = np.abs(lon[1] - lon[0]).item()
    if np.all(np.diff(lon) < 2 * dlon):
        # No gap: the sorted extremes are the region's edges.
        return lon[0].item(), lon[-1].item(), minlat, maxlat
    # Rotate so the largest gap sits between the last and the first element.
    lon = np.roll(lon, -np.diff(lon).argmax() - 1)
    return lon[0].item(), lon[-1].item(), minlat, maxlat

get_runs(mask, cyclic=True)

Obsolete basic implementation of the Run Length Encoding algorithm using itertools.groupby.

With the cyclic argument on, runs are allowed to wrap around the end of the list to its start. For instance, list [True, True, False, ..., False, True, True] will have a True run going from indices -2 to 1 included if cyclic=True.

Source code in jetutils/definitions.py
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
def get_runs(mask, cyclic: bool = True):
    """
    Obsolete basic implementation of the Run Length Encoding algorithm using `itertools.groupby`.

    Returns a list of `(value, start, end)` tuples, with `end` inclusive.

    With the `cyclic` argument on, runs are allowed to wrap around the end of the list to its start.
    The mask is then scanned twice in a row, so the last run can extend past
    `len(mask) - 1`; take its indices modulo `len(mask)`. The first run is still
    reported on its own as well.
    """
    limit = len(mask) if cyclic else None
    sequence = np.tile(mask, 2) if cyclic else mask
    runs = []
    position = 0
    for value, group in groupby(sequence):
        # In cyclic mode, stop once every run starting in the original mask is recorded.
        if cyclic and position >= limit:
            break
        run_length = len(list(group))
        runs.append((value, position, position + run_length - 1))
        position += run_length
    return runs

get_runs_fill_holes(mask, cyclic=True, hole_size=8)

Obsolete algorithm to get potentially interrupted runs of True values. The runs can be uninterrupted like the basic algorithm, or interrupted by False values if the run of False values within the run of True values is shorter than hole_size.

The algorithm first performs RLE using get_runs, then fills the short False runs with True and applies get_runs a second time on the modified input.

Source code in jetutils/definitions.py
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
def get_runs_fill_holes(mask, cyclic: bool = True, hole_size: int = 8):
    """
    Obsolete algorithm to get potentially interrupted runs of `True` values. The runs can be uninterrupted like the basic algorithm, or interrupted by `False` values if the run of `False` values within the run of `True` values is at most `hole_size` long.

    The algorithm first performs RLE using `get_runs`, then fills the short `False` runs with `True` and applies `get_runs` a second time on the modified input.

    Parameters
    ----------
    mask : array-like of bool
        Input mask. It is copied, so the caller's array is left untouched.
    cyclic : bool, optional
        Allow runs to wrap around the end of `mask`, by default True
    hole_size : int, optional
        Maximum length of a `False` run that gets filled, by default 8

    Returns
    -------
    list of np.ndarray
        Index arrays (modulo `len(mask)`) of the `True` runs longer than 10 elements
        after filling. If there is none, a one-element list with the longest `True`
        run (or the first run if the mask has no `True` at all).
    """
    # Work on a copy: the hole filling below flips values in place.
    mask = np.array(mask)
    runs = get_runs(mask, cyclic=cyclic)
    for key, start, end in runs:
        leng = end - start + 1
        if key or leng > hole_size:  # I want negative short spans
            continue
        # A False run touching an edge is only a hole if, cyclically, the value on
        # the other side of the edge is True.
        if start == 0 and (not mask[-1] or not cyclic):
            continue
        if end == len(mask) - 1 and (not mask[0] or not cyclic):
            continue
        end_ = min(len(mask), end + 1)
        mask[start:end_] = ~mask[start:end_]
    runs = get_runs(mask, cyclic=cyclic)
    indices = []
    for key, start, end in runs:
        leng = end - start + 1
        if leng > 10 and key:
            indices.append(np.arange(start, end + 1) % len(mask))
    if len(indices) == 0:
        # Fall back to the longest True run. The leading boolean makes any True
        # run beat False runs, which previously tied with length-1 True runs.
        _, start, end = max(
            runs, key=lambda x: (bool(x[0]), (x[2] - x[1]) * int(x[0]))
        )
        indices.append(np.arange(start, end + 1) % len(mask))
    return indices

infer_direction(to_plot)

Infers the direction of an arbitrary array.

Parameters:

Name Type Description Default
to_plot Any

Array or list of arrays

required

Returns:

Type Description
int

-1 if the data is mostly negative, +1 if it is mostly positive and 0 if the data is symmetric

Source code in jetutils/definitions.py
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
def infer_direction(to_plot: Any) -> int:
    """
    Guess whether some data is mostly positive, mostly negative or symmetric.

    The data is considered symmetric when its global max and min have opposite
    signs and their magnitudes differ by at most two orders of magnitude.

    Parameters
    ----------
    to_plot : Any
        Array or list of arrays; NaNs are ignored

    Returns
    -------
    int
        0 if symmetric, otherwise +1 if the largest magnitude is the maximum and -1 if it is the minimum
    """
    highest = max(np.nanmax(part) for part in to_plot)
    lowest = min(np.nanmin(part) for part in to_plot)
    # Unwrap numpy scalars into plain Python numbers when possible.
    if hasattr(highest, "item"):
        highest = highest.item()
    if hasattr(lowest, "item"):
        lowest = lowest.item()
    if np.sign(highest) == -np.sign(lowest):
        magnitude_gap = np.abs(np.log10(np.abs(highest)) - np.log10(np.abs(lowest)))
        if magnitude_gap <= 2:
            return 0
    if np.abs(highest) > np.abs(lowest):
        return 1
    return -1

iterate_over_year_maybe_member(df=None, da=None, several_years=1, several_members=1)

Constructs iterators over time and member for a polars DataFrame, an xarray DataArray, or both when they share the same indices.

Source code in jetutils/definitions.py
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
def iterate_over_year_maybe_member(
    df: pl.DataFrame | None = None,
    da: xr.DataArray | xr.Dataset | None = None,
    several_years: int = 1,
    several_members: int = 1,
):
    """
    Constructs iterators over time and member for a polars DataFrame, an xarray object, or both when they share the same indices.

    Years (and members, if there is a `"member"` column / dimension) are split into
    consecutive chunks of about `several_years` (`several_members`) values. When there
    are fewer values than the chunk size, they all go into a single chunk.

    Returns
    -------
    iterator or 0
        - ``0`` if both `df` and `da` are None.
        - `df` only: tuples of polars expressions, so ``df.filter(*idx)`` always works.
        - `da` only: dicts mapping ``"time"`` (and ``"member"``) to boolean masks.
        - both: pairs ``(tuple_of_polars_expressions, dict_of_masks)``.
    """

    def split_in_chunks(values, chunk_size: int) -> list:
        # np.array_split needs at least one section; with fewer values than
        # `chunk_size`, keep all of them in a single chunk.
        return np.array_split(values, max(1, len(values) // chunk_size))

    if df is None and da is None:
        return 0
    if da is None:
        years = df["time"].dt.year().unique(maintain_order=True).to_numpy()
        year_lists = split_in_chunks(years, several_years)
        indexer_polars = (
            pl.col("time").dt.year().is_in(year_list) for year_list in year_lists
        )
        if "member" not in df.columns:
            return zip(indexer_polars)
        members = df["member"].unique(maintain_order=True).to_numpy()
        member_lists = split_in_chunks(members, several_members)
        indexer_polars_2 = (
            pl.col("member").is_in(member_list) for member_list in member_lists
        )
        return product(indexer_polars, indexer_polars_2)
    if df is None:
        years = np.unique(da["time"].dt.year.values)
        year_lists = split_in_chunks(years, several_years)
        indexer_xarray = (
            {"time": np.isin(da["time"].dt.year.values, year_list)}
            for year_list in year_lists
        )
        if "member" not in da.dims:
            return indexer_xarray
        members = np.unique(da["member"].values)
        member_lists = split_in_chunks(members, several_members)
        indexer_xarray_2 = (
            {"member": np.isin(da["member"].values, member_list)}
            for member_list in member_lists
        )
        indexer_xarray = product(indexer_xarray, indexer_xarray_2)
        return (indexer[0] | indexer[1] for indexer in indexer_xarray)
    # Both objects: iterate over the years they have in common.
    years = df["time"].dt.year().unique(maintain_order=True).to_numpy()
    years_ = np.unique(da["time"].dt.year.values)
    years = np.intersect1d(years, years_)
    year_lists = split_in_chunks(years, several_years)
    indexer_polars = (
        pl.col("time").dt.year().is_in(year_list) for year_list in year_lists
    )
    indexer_xarray = (
        {"time": np.isin(da["time"].dt.year.values, year_list)}
        for year_list in year_lists
    )
    if "member" not in df.columns:
        # The inner zip wraps each polars expression in a 1-tuple, so callers can
        # always write `for idx in indexer: df.filter(*idx)`, matching the
        # tuples produced by `product` below.
        return zip(zip(indexer_polars), indexer_xarray)
    members = df["member"].unique(maintain_order=True).to_numpy()
    member_lists = split_in_chunks(members, several_members)
    indexer_polars_2 = (
        pl.col("member").is_in(member_list) for member_list in member_lists
    )
    indexer_polars = product(indexer_polars, indexer_polars_2)
    indexer_xarray_2 = (
        {"member": np.isin(da["member"].values, member_list)}
        for member_list in member_lists
    )
    indexer_xarray = product(indexer_xarray, indexer_xarray_2)
    indexer_xarray = (indexer[0] | indexer[1] for indexer in indexer_xarray)
    return zip(indexer_polars, indexer_xarray)

labels_to_mask(labels, as_da=False)

Turns an array of labels into a mask

Parameters:

Name Type Description Default
labels DataArray | ndarray

Array of labels.

required
as_da bool

If labels is a DataArray and as_da is True, then turns the output into a DataArray, by default False

False

Returns:

Type Description
xr.DataArray | np.ndarray of shape (*labels.shape, n_unique_labels)

Boolean mask, of the same shape as labels plus one trailing dimension / axis (one entry per unique label). If turned into a DataArray, that new dimension is named "cluster".

Examples:

>>> labels_to_mask([1, 3, 2, 1])
array([[True, False, False],
   [False, False, True],
   [False, True, False],
   [True, False, False]])
Source code in jetutils/definitions.py
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
def labels_to_mask(
    labels: xr.DataArray | np.ndarray, as_da: bool = False
) -> np.ndarray | xr.DataArray:
    """
    Turns an array of labels into a mask

    Parameters
    ----------
    labels : xr.DataArray | np.ndarray
        Array of labels. Any array-like (e.g. a list) is accepted and treated as a NumPy array.
    as_da : bool, optional
        If `labels` is a DataArray and `as_da` is True, then turns the output into a DataArray, by default False

    Returns
    -------
    xr.DataArray | np.ndarray of shape (*labels.shape, n_unique_labels)
        Boolean mask, of the same shape as labels plus one trailing dimension / axis
        (one entry per sorted unique label). If turned into a DataArray, that new
        dimension is named "cluster".

    Examples
    --------
    >>> labels_to_mask([1, 3, 2, 1])
    array([[ True, False, False],
           [False, False,  True],
           [False,  True, False],
           [ True, False, False]])
    """
    if isinstance(labels, xr.DataArray):
        coords = labels.coords.copy()
        dims = labels.dims
        labels = labels.values
    else:
        # Plain arrays and array-likes carry no coordinates.
        labels = np.asarray(labels)
        as_da = False
    unique_labels = np.unique(labels)
    mask = labels[..., None] == unique_labels[None, :]
    if not as_da:
        return mask
    coords = coords.assign({"cluster": unique_labels})
    # Explicit dims: inferring them from the coords mapping depends on coord
    # order and fails when non-dimension coords are present.
    return xr.DataArray(mask, coords=coords, dims=(*dims, "cluster"))

last_elements(arr, n_elements, sort=False)

Get the indices of the n_elements largest values of arr, searched over the whole (flattened) array.

Parameters:

Name Type Description Default
arr ndarray

Any array

required
n_elements int

Number of elements to return along each axis

required
sort bool

Sort the output, only valid for 1D arr, by default False

False

Returns:

Type Description
ndarray

Raises:

Type Description
RuntimeWarning

If sort=True and arr.ndim > 1 because it's ambiguous.

Source code in jetutils/definitions.py
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
def last_elements(arr: np.ndarray, n_elements: int, sort: bool = False) -> np.ndarray:
    """
    Get the largest `n_elements` of `arr`, along the last axis.

    Parameters
    ----------
    arr : np.ndarray
        Any array
    n_elements : int
        Number of elements to return along each axis
    sort : bool, optional
        Sort the output, only valid for 1D `arr`, by default False

    Returns
    -------
    np.ndarray

    Raises
    ------
    RuntimeWarning
        If `sort=True` and `arr.ndim > 1` because it's ambiguous.
    """
    arr = np.nan_to_num(arr, posinf=0)
    ndim = arr.ndim
    if ndim > 1 and sort:
        print("sorting output not supported for arrays with ndim > 1")
        sort = False
        raise RuntimeWarning
    idxs = np.argpartition(arr.ravel(), -n_elements)[-n_elements:]
    if ndim > 1:
        return np.unravel_index(idxs, arr.shape)
    if sort:
        return idxs[np.argsort(arr[idxs])]
    return idxs

load_pickle(filename)

Load a pickled object from file

Parameters:

Name Type Description Default
filename str | Path

path, it's better if it ends in .pkl

required

Returns:

Type Description
Any

Pickled object

Source code in jetutils/definitions.py
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
def load_pickle(filename: str | Path) -> Any:
    """
    Save a pickleable object to file

    Parameters
    ----------
    filename : str | Path
        path, it's better if it ends in `.pkl`

    Returns
    -------
    Any
        Pickled object
    """
    with open(filename, "rb") as handle:
        to_ret = pkl.load(handle)
    return to_ret

map_maybe_parallel(iterator, func, len_, processes=N_WORKERS, chunksize=None, progress=True, pool_kwargs=None, ctx=None)

Maps a function on the components of an Iterable. Can be parallel if processes is greater than one. In this case the other arguments are used to create a multiprocessing.Pool. In most cases, I recommend using ctx = get_context("spawn") instead of the default (on linux) fork.

Parameters:

Name Type Description Default
iterator Iterable

Data

required
func Callable

Function to apply to each element of iterator

required
len_ int

len of the iterator, so we can display a progress bar.

required
processes int

Number of parallel processes, will not create a Pool if 1, by default N_WORKERS

N_WORKERS
chunksize int

How many elements to send to a worker at once, by default min(len_ // processes, 200)

None
progress bool

Show a progress bar using tqdm, by default True

True
pool_kwargs dict | None

Keyword arguments passed to multiprocessing.Pool, by default None

None
ctx optional

Multiprocessing context, created using multiprocessing.get_context(), by default None, will be spawn on windows and mac, and fork on linux at time of writing, but it should change in python 3.15.

None

Returns:

Type Description
list

result of the map coerced into a list.

Source code in jetutils/definitions.py
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
def map_maybe_parallel(
    iterator: Iterable,
    func: Callable,
    len_: int,
    processes: int = N_WORKERS,
    chunksize: int | None = None,
    progress: bool = True,
    pool_kwargs: dict | None = None,
    ctx=None,
) -> list:
    """
    Maps a function on the components of an Iterable. Can be parallel if `processes` is greater than one. In this case the other arguments are used to create a `multiprocessing.Pool`. In most cases, I recommend using `ctx = get_context("spawn")` instead of the default (on linux) `fork`.

    Parameters
    ----------
    iterator : Iterable
        Data
    func : Callable
        Function to apply to each element of `iterator`. When running in parallel it is sent to worker processes, so it must be picklable.
    len_ : int
        len of the `iterator`, so we can display a progress bar and cap the number of workers.
    processes : int, optional
        Number of parallel processes, by default N_WORKERS. No `Pool` is created if this (or `len_`) is 1 or less; the map then runs serially.
    chunksize : int, optional
        How many elements to send to a worker at once, by default `min(len_ // processes, 200)` (at least 1)
    progress : bool, optional
        Show a progress bar using `tqdm`, by default True
    pool_kwargs : dict | None, optional
        Keyword arguments passed to `multiprocessing.Pool`, by default None
    ctx : optional
        Multiprocessing context, created using `multiprocessing.get_context()`, by default None, will be `spawn` on windows and mac, and `fork` on linux at time of writing, but it should change in python 3.15.

    Returns
    -------
    list
        result of the map coerced into a list, in the same order as `iterator`.
    """
    # Never use more workers than items, and never fewer than one: an empty
    # input (len_ == 0) or a non-positive `processes` used to reach a division
    # by zero below / an invalid Pool; now it simply runs serially.
    processes = max(1, min(processes, len_))
    if processes == 1:
        mapped = map(func, iterator)
        if progress:
            mapped = tqdm(mapped, total=len_)
        return list(mapped)
    if pool_kwargs is None:
        pool_kwargs = {}
    if chunksize is None:
        chunksize = max(1, min(int(len_ // processes), 200))
    pool_func = Pool if ctx is None else ctx.Pool
    with pool_func(processes=processes, **pool_kwargs) as pool:
        # `imap` yields results lazily and in order, so tqdm can advance as
        # results come back.
        results = pool.imap(func, iterator, chunksize=chunksize)
        if progress:
            results = tqdm(results, total=len_)
        # Consume inside the `with`: leaving the context terminates the pool.
        return list(results)

maybe_circular_mean(x)

Circular mean of angles already converted to radians (falls back to the arithmetic mean when the values look like a contiguous, regularly spaced run)

Parameters:

Name Type Description Default
x float

Angles in radians

required

Returns:

Type Description
float

Circular mean

Source code in jetutils/definitions.py
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
def maybe_circular_mean(x: float) -> float:
    """
    Mean of angles already converted to radians, made circular only when needed.

    If every consecutive difference of `x` is smaller than twice the spacing
    between its two smallest distinct values (i.e. `x` looks like a contiguous
    run of regularly spaced values), the plain arithmetic mean is returned.
    Otherwise the circular mean ``arctan2(mean(sin x), mean(cos x))`` is returned.

    Parameters
    ----------
    x : float
        Angles in radians (array-like)

    Returns
    -------
    float
        Arithmetic mean, or circular mean in [-pi, pi].
    """
    unique_values = np.unique(x)
    # With fewer than two distinct values there is no spacing to compare
    # against (indexing [1] used to raise IndexError); the plain mean is exact.
    if unique_values.size < 2:
        return np.mean(x)
    dx = unique_values[1] - unique_values[0]
    if np.all(np.diff(x) < 2 * dx):
        return np.mean(x)
    # np.arctan2 rather than np.atan2: the latter alias only exists in numpy >= 2.
    return np.arctan2(np.mean(np.sin(x)), np.mean(np.cos(x)))

normalize(X)

Normalizes an arbitrary polars DataFrame or numpy Array to zero mean and unit standard deviation along one axis. The 0 axis if numpy, the columns if polars. Returns the original mean and standard deviation to be able to revert.

Parameters:

Name Type Description Default
X ndarray | DataFrame

Input array

required

Returns:

Name Type Description
X same as input

Input standardized to zero mean and unit standard deviation

meanX same as input, with one fewer dimension

Original mean of the data, used to revert this function

stdX same as input, with one fewer dimension

Original standard deviation of the data, used to revert this function

Source code in jetutils/definitions.py
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
def normalize(X):
    """
    Normalizes an arbitrary polars DataFrame or numpy Array to zero mean and unit standard deviation along one axis. The 0 axis if numpy, the columns if polars. Returns the original mean and standard deviation to be able to revert with `revert_normalize()`.

    Parameters
    ----------
    X : np.ndarray | pl.DataFrame
        Input array

    Returns
    -------
    X : same as input
        Input standardized to zero mean and unit standard deviation

    meanX : same as input, with one fewer dimension (a one-row DataFrame for polars)
        Original mean of the data, used to revert this function

    stdX : same as input, with one fewer dimension (a one-row DataFrame for polars)
        Original standard deviation of the data, used to revert this function

    Notes
    -----
    The polars branch uses polars' default ``std`` (``ddof=1``) while the numpy
    branch uses ``ndarray.std`` (``ddof=0``). Columns with zero standard
    deviation produce NaN/inf.
    """
    def expr(col):
        # Standardize one polars column with its own mean and std.
        return (pl.col(col) - pl.col(col).mean()) / pl.col(col).std()

    if isinstance(X, pl.DataFrame):
        # One-row DataFrames of per-column statistics; revert_normalize()
        # reads them back as meanX[0, col] / stdX[0, col].
        meanX = X.mean()
        stdX = X.std()
        X = X.with_columns(expr(col) for col in X.columns)
        return X, meanX, stdX
    meanX = X.mean(axis=0)
    stdX = X.std(axis=0)
    try:
        # Explicit leading axis so per-column statistics broadcast over rows.
        X = (X - meanX[None, :]) / stdX[None, :]
    except IndexError:
        # 1-D input: mean/std are scalars and cannot be indexed with [None, :].
        X = (X - meanX) / stdX
    return X, meanX, stdX

polars_to_xarray(df, index_columns)

Turns a polars DataFrame into a xarray DataArray if possible, a Dataset otherwise. Which columns of df will be dimensions of the xarray output cannot be inferred from df and have to be passed as index_columns.

Parameters:

Name Type Description Default
df DataFrame

Input array

required
index_columns list[str]

Which columns of df to use as dimensions for the xarray object

required

Returns:

Name Type Description
da DataArray or Dataset

Data transformed into an xarray object. If df had only index columns and one other column (inferred to be the data), da will be turned into a DataArray. If there are several other columns, then it stays a Dataset.

Source code in jetutils/definitions.py
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
def polars_to_xarray(df: pl.DataFrame, index_columns: Sequence[str]):
    """
    Turns a polars DataFrame into a xarray DataArray if possible, a Dataset otherwise. Which columns of `df` will be dimensions of the xarray output cannot be inferred from `df` and have to be passed as `index_columns`.

    Parameters
    ----------
    df : pl.DataFrame
        Input array

    index_columns : Sequence[str] | str
        Which columns of `df` to use as dimensions for the xarray object. Any sequence (list, tuple, ...) of column names, or a single column name.

    Returns
    -------
    da : xr.DataArray or xr.Dataset
        Data transformed into an xarray object. If `df` had only index columns and one other column (inferred to be the data), `da` will be turned into a DataArray. If there are several other columns, then it stays a Dataset.
    """
    # pandas' `set_index` only treats a *list* as several keys; a tuple would be
    # looked up as one (missing) column label. Normalise to a list, keeping a
    # bare string as a single index column rather than splitting its characters.
    if isinstance(index_columns, str):
        index_columns = [index_columns]
    else:
        index_columns = list(index_columns)
    ds = xr.Dataset.from_dataframe(df.to_pandas().set_index(index_columns))
    data_vars = list(ds.data_vars)
    if len(data_vars) == 1:
        # A single data column: return it directly as a DataArray.
        return ds[data_vars[0]]
    return ds

revert_normalize(X, meanX, stdX)

Reverts the function normalize().

Source code in jetutils/definitions.py
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
def revert_normalize(X, meanX, stdX):
    """
    Reverts the function normalize(): maps standardized data back to its original scale.

    Parameters
    ----------
    X : np.ndarray | pl.DataFrame
        Data as returned by `normalize()`.
    meanX : same type as the `meanX` returned by `normalize()`
        Mean to add back. For polars input it is read as `meanX[0, col]`.
    stdX : same type as the `stdX` returned by `normalize()`
        Standard deviation to scale by, same layout as `meanX`.

    Returns
    -------
    same as `X`
        ``meanX + stdX * X``, column by column.
    """
    if not isinstance(X, pl.DataFrame):
        try:
            # Explicit leading axis so per-column statistics line up with rows.
            return X * stdX[None, :] + meanX[None, :]
        except IndexError:
            # Scalar statistics (e.g. when X was 1-D) can't be indexed that way.
            return X * stdX + meanX
    restored = [
        (meanX[0, name] + stdX[0, name] * pl.col(name)).alias(name)
        for name in X.columns
    ]
    return X.with_columns(restored)

revert_zero_one(X, Xmin, Xmax)

Reverts the function to_zero_one().

Source code in jetutils/definitions.py
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
def revert_zero_one(X, Xmin, Xmax):
    """
    Reverts the function to_zero_one(): maps data from [0, 1] back to its original range.

    Parameters
    ----------
    X : np.ndarray | pl.DataFrame
        Data as returned by `to_zero_one()`.
    Xmin : same type as the `Xmin` returned by `to_zero_one()`
        Original minimum. For polars input it is read as `Xmin[0, col]`.
    Xmax : same type as the `Xmax` returned by `to_zero_one()`
        Original maximum, same layout as `Xmin`.

    Returns
    -------
    same as `X`
        ``Xmin + (Xmax - Xmin) * X``, column by column.
    """
    if isinstance(X, pl.DataFrame):
        return X.with_columns(
            (Xmin[0, name] + (Xmax[0, name] - Xmin[0, name]) * pl.col(name)).alias(name)
            for name in X.columns
        )
    try:
        # Explicit leading axis so per-column statistics line up with rows.
        return Xmin[None, :] + (Xmax - Xmin)[None, :] * X
    except IndexError:
        # Scalar statistics (e.g. when X was 1-D) can't be indexed that way.
        return Xmin + (Xmax - Xmin) * X

save_pickle(to_save, filename)

Save a pickleable object to file

Parameters:

Name Type Description Default
to_save Any

Pickleable

required
filename str | Path

path, it's better if it ends in .pkl

required
Source code in jetutils/definitions.py
429
430
431
432
433
434
435
436
437
438
439
440
441
def save_pickle(to_save: Any, filename: str | Path) -> None:
    """
    Save a pickleable object to file

    Parameters
    ----------
    to_save : Any
        Pickleable
    filename : str | Path
        path, it's better if it ends in `.pkl`
    """
    with open(filename, "wb") as handle:
        pkl.dump(to_save, handle)

slice_1d(da, indexers, dim='points')

Gets a (N - n + 1) dimensional slice from a N dimensional Xarray object using Xarray's advanced (vectorized) interpolation, by passing n indexers in a dict.

Parameters:

Name Type Description Default
da DataArray | Dataset

Xarray object

required
indexers dict

Dictionary whose keys must be dimensions of da and values are arrays of values along this dimension, and of the correct dtype. Each array must be of the same length. Indexers is to be interpreted as the coordinates of points onto which we wish to interpolate da.

required
dim str

Name of the newly created dimension in the output, that will be of the same length as all of the (equally sized) arrays in indexers. By default "points".

'points'

Returns:

Name Type Description
da_slice same as `da`

Input DataArray interpolated on the points specified by indexers. It retains all the dimensions that are in da but not as keys of indexers. It has lost all the dimensions named in indexers and gained a new dimension named dim and of the same length as all the arrays in indexers.

References

https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing

Source code in jetutils/definitions.py
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
def slice_1d(da: xr.DataArray | xr.Dataset, indexers: dict, dim: str = "points"):
    """
    Interpolates a *N* dimensional Xarray object onto a set of points described by *n* coordinate arrays, giving a *(N - n + 1)* dimensional result (Xarray's advanced, vectorized interpolation).

    Parameters
    ----------
    da : xr.DataArray | xr.Dataset
        Xarray object
    indexers : dict
        Dictionary whose keys must be dimensions of `da` and values are arrays of values along this dimension, and of the correct `dtype`. Each array must be of the same length. Together they give the coordinates of the points onto which `da` is (linearly) interpolated.
    dim : str, optional
        Name of the newly created dimension in the output, that will be of the same length as all of the (equally sized) arrays in `indexers`. By default "points".

    Returns
    -------
    da_slice : same as `da`
        Input interpolated on the points specified by `indexers`. It retains all the dimensions of `da` that are not keys of `indexers`, has lost all the dimensions named in `indexers`, and gained a new dimension named `dim` of the same length as the arrays in `indexers`.

    References
    ----------
    https://docs.xarray.dev/en/latest/user-guide/indexing.html#more-advanced-indexing
    """
    # Wrapping every indexer in a DataArray sharing the single dimension `dim`
    # makes xarray interpolate point by point instead of on the outer product
    # of the indexers.
    point_coords = {}
    for coord_name, values in indexers.items():
        point_coords[coord_name] = xr.DataArray(values, dims=dim)
    # fill_value=None is forwarded to the underlying interpolator; presumably
    # meant to extrapolate outside the grid rather than return NaN — TODO confirm.
    return da.interp(point_coords, method="linear", kwargs={"fill_value": None})

to_expr(expr)

Make sure it's an Expr.

Parameters:

Name Type Description Default
expr Expr | str

Either already an Expr, or a str to be turned into one.

required

Returns:

Type Description
Expr

Same as input or pl.col(expr)

Source code in jetutils/definitions.py
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
def to_expr(expr: Expr | str) -> Expr:
    """
    Coerce a column name into a polars expression; expressions pass through unchanged.

    Parameters
    ----------
    expr : Expr | str
        Either already an `Expr`, or a column name as a `str`.

    Returns
    -------
    Expr
        `pl.col(expr)` for a string, otherwise the input itself.
    """
    return pl.col(expr) if isinstance(expr, str) else expr

to_zero_one(X)

Normalizes an arbitrary polars DataFrame or numpy Array to the range [0, 1] along one axis. The 0 axis if numpy, the columns if polars. Returns the original minimum and maximum to be able to revert.

Parameters:

Name Type Description Default
X ndarray | DataFrame

Input array

required

Returns:

Name Type Description
X same as input

Input normalised to the range [0, 1]

Xmin same as input, with one fewer dimension

Original minimum of the data, used to revert this function

Xmax same as input, with one fewer dimension

Original maximum of the data, used to revert this function

Source code in jetutils/definitions.py
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
def to_zero_one(X: np.ndarray | pl.DataFrame):
    """
    Min-max rescales an arbitrary polars DataFrame or numpy Array to [0, 1] along one axis: axis 0 for numpy, each column for polars. The original minimum and maximum are returned too so `revert_zero_one()` can undo it.

    Parameters
    ----------
    X : np.ndarray | pl.DataFrame
        Input array

    Returns
    -------
    X : same as input
        Input rescaled to the range [0, 1]

    Xmin : same as input, with one fewer dimension (a one-row DataFrame for polars)
        Original minimum of the data (NaNs ignored for numpy), used to revert this function

    Xmax : same as input, with one fewer dimension (a one-row DataFrame for polars)
        Original maximum of the data (NaNs ignored for numpy), used to revert this function
    """
    if isinstance(X, pl.DataFrame):
        Xmin = X.min()
        Xmax = X.max()
        rescaled = X.with_columns(
            (pl.col(name) - pl.col(name).min())
            / (pl.col(name).max() - pl.col(name).min())
            for name in X.columns
        )
        return rescaled, Xmin, Xmax
    Xmin = np.nanmin(X, axis=0)
    Xmax = np.nanmax(X, axis=0)
    span = Xmax - Xmin
    try:
        # Explicit leading axis so per-column statistics broadcast over rows.
        rescaled = (X - Xmin[None, :]) / span[None, :]
    except IndexError:
        # 1-D input: min/max are scalars and cannot be indexed with [None, :].
        rescaled = (X - Xmin) / span
    return rescaled, Xmin, Xmax