# Source code for earthkit.plots.quickplot

# Copyright 2024-, European Centre for Medium Range Weather Forecasts.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

from earthkit.plots.components import layouts
from earthkit.plots.components.figures import Figure
from earthkit.plots.metadata.units import are_equal
from earthkit.plots.schemas import schema


def _coerce_to_fieldlist(*args):
    """
    Flatten the positional data arguments into a single FieldList.

    ``FieldList`` arguments contribute each of their fields.  Any other
    argument is added as one entry, after being wrapped with
    ``earthkit.data.from_object`` when it is not already an earthkit-data
    ``Base`` instance.
    """
    import earthkit.data
    from earthkit.data import FieldList
    from earthkit.data.core import Base

    collected = []
    for item in args:
        if isinstance(item, FieldList):
            # Unpack the list so the result stays flat.
            collected += list(item)
            continue
        wrapped = item if isinstance(item, Base) else earthkit.data.from_object(item)
        collected.append(wrapped)
    return FieldList.from_fields(collected)


# Figure size (width, height), in inches, used for a single-panel figure.
_DEFAULT_SINGLE_SIZE = (7, 8)
# Width and height, in inches, allotted to each panel of a multi-panel grid.
_MULTI_PANEL_WIDTH = 5.0
_MULTI_PANEL_HEIGHT = 4.0
# Upper bound (width, height), in inches, for automatically sized figures.
_MAX_FIGURE_SIZE = (40.0, 40.0)


def _auto_figure_size(rows, columns):
    """
    Choose a ``(width, height)`` figure size for a ``rows`` × ``columns`` grid.

    A 1 × 1 grid gets ``_DEFAULT_SINGLE_SIZE`` unchanged.  Any other grid is
    given ``_MULTI_PANEL_WIDTH`` per column and ``_MULTI_PANEL_HEIGHT`` per
    row, with each dimension clamped to the matching entry of
    ``_MAX_FIGURE_SIZE``.
    """
    is_single_panel = rows == 1 and columns == 1
    if is_single_panel:
        return _DEFAULT_SINGLE_SIZE

    max_width, max_height = _MAX_FIGURE_SIZE
    unclamped_width = _MULTI_PANEL_WIDTH * columns
    unclamped_height = _MULTI_PANEL_HEIGHT * rows
    return (min(unclamped_width, max_width), min(unclamped_height, max_height))


def _iter_plot_groups(args, groupby, mode, combine_vectors=False):
    """
    Pick the group iterator matching the type of the supplied data.

    Yields ``(key, [data_item, ...])`` tuples consumed by :func:`plot`.

    Parameters
    ----------
    args : tuple
        Positional arguments passed to :func:`plot`.
    groupby : str or None
        Metadata key / coordinate name to split on.
    mode : str
        ``"auto"``, ``"overlay"``, or ``"split"``.
    combine_vectors : bool
        Forwarded to whichever group iterator is selected (xarray or
        earthkit).
    """
    import xarray as xr

    from earthkit.plots.sources import _is_xarray_backed_earthkit

    is_single = len(args) == 1

    # One xarray object: hand it to the xarray iterator as-is.
    if is_single and isinstance(args[0], (xr.DataArray, xr.Dataset)):
        from earthkit.plots.sources.extractors.xarray import iter_plot_groups as xarray_groups

        yield from xarray_groups(args[0], groupby, mode, combine_vectors=combine_vectors)
        return

    # One earthkit object backed by xarray: convert, then treat as xarray.
    if is_single and _is_xarray_backed_earthkit(args[0]):
        from earthkit.plots.sources.extractors.xarray import iter_plot_groups as xarray_groups

        yield from xarray_groups(args[0].to_xarray(), groupby, mode, combine_vectors=combine_vectors)
        return

    # Several DataArrays: merge into a Dataset so each variable gets its own
    # panel with the correct auto-style.
    if all(isinstance(item, xr.DataArray) for item in args):
        from earthkit.plots.sources.extractors.xarray import iter_plot_groups as xarray_groups

        merged = xr.merge(args, compat="override")
        yield from xarray_groups(merged, groupby, mode, combine_vectors=combine_vectors)
        return

    try:
        from earthkit.data import FieldList
    except ImportError:
        only_fieldlists = False
    else:
        only_fieldlists = all(isinstance(item, FieldList) for item in args)

    if not only_fieldlists:
        # Numpy arrays or other raw data: yield directly, one panel.
        yield None, list(args)
        return

    from earthkit.plots.sources.extractors.earthkit import iter_plot_groups as fieldlist_groups

    yield from fieldlist_groups(_coerce_to_fieldlist(*args), groupby, mode, combine_vectors=combine_vectors)


def _iter_plot_groups_2d(args, row_dim, col_dim, groupby, mode):
    """
    Select the 2-D group iterator used for structured row/column layouts.

    Yields ``(row_key, col_key, [data_item, ...])`` tuples.

    Raises
    ------
    NotImplementedError
        If *args* is anything other than a single xarray DataArray or
        Dataset.
    """
    import xarray as xr

    is_single_xarray = len(args) == 1 and isinstance(args[0], (xr.DataArray, xr.Dataset))
    if not is_single_xarray:
        raise NotImplementedError("row/column dimension layout is only supported for xarray data.")

    from earthkit.plots.sources.extractors.xarray import iter_plot_groups_2d as xarray_groups_2d

    yield from xarray_groups_2d(args[0], row_dim, col_dim, groupby, mode)


def plot(
    *args,
    domain=None,
    crs=None,
    groupby=None,
    row=None,
    column=None,
    rows=None,
    columns=None,
    figsize=None,
    size=None,
    units=None,
    style="auto",
    subplot_titles=None,
    method="quickplot",
    mode="auto",
    combine_vectors=False,
    title=True,
    legend=True,
    coastlines=True,
    **kwargs,
):
    """
    Plot geospatial data as one or more map panels.

    This is the primary high-level function in earthkit-plots.  Pass a single
    data object to get a single map; pass ``groupby`` to get a grid of panels,
    one per unique value of that metadata key (e.g. forecast step, ensemble
    member, pressure level).

    Two layout paths exist.  When *row* or *column* is given, panels are
    placed on a structured 2-D grid keyed by those dimensions.  Otherwise
    panels are added one after another in a flat layout whose shape comes
    from ``layouts.rows_cols``.

    Parameters
    ----------
    *args :
        The data to plot.  Accepts any format supported by earthkit-data
        (GRIB FieldList, xarray DataArray/Dataset, numpy array, …).
    domain : str or list, optional
        Named domain (e.g. ``"Europe"``) or bounding box
        ``[lon_min, lon_max, lat_min, lat_max]``.  If omitted the extent is
        inferred from the data.
    crs : cartopy.crs.CRS, optional
        Map projection.  If omitted an appropriate projection is chosen
        automatically for the domain.
    groupby : str, optional
        Metadata key along which to split the data into separate panels
        (e.g. ``"step"``, ``"number"``, ``"pressure_level"``).
    row : str, optional
        Dimension name (or ``"variable"`` for Dataset variables) to lay out
        along the row axis of the panel grid.  Determines the number of rows
        automatically from unique values.  Use together with *column* to create
        a structured 2-D grid, e.g. ``row="valid_time", column="variable"``.
    column : str, optional
        Dimension name (or ``"variable"``) to lay out along the column axis.
    rows : int, optional
        Number of rows in the panel grid.  If only one of *rows* / *columns*
        is given the other is calculated automatically.  Ignored when *row* is
        a dimension name.
    columns : int, optional
        Number of columns in the panel grid.  Ignored when *column* is a
        dimension name.
    figsize : tuple of float, optional
        Explicit ``(width, height)`` for the whole figure, passed to
        :class:`~earthkit.plots.components.figures.Figure`.  When not provided
        the size is chosen automatically based on the panel grid
        (approximately 5 × 4 inches per panel, capped at 40 inches).
    size : tuple of float, optional
        Deprecated alias for *figsize*.  Passing it emits a
        ``DeprecationWarning``; its value is used only when *figsize* is not
        given.
    units : str or list of str, optional
        Units to convert the data to at plot time (e.g. ``"celsius"``).  A
        single value is applied to every panel; a list or tuple is matched to
        panels by position, and panels beyond its length get no conversion.
        See :doc:`/examples/examples/introduction/08-unit-conversion` for
        examples.
    style : str or Style, optional
        Named style or :class:`~earthkit.plots.styles.Style` object.
        Defaults to ``"auto"`` which selects a style based on the variable.
        Not forwarded to the ``quiver`` call used for vector groups.
    subplot_titles : str, optional
        Format string for per-panel titles.  Metadata placeholders like
        ``{step}`` or ``{valid_time}`` are resolved from each panel's data.
        When *groupby* is set this defaults to ``"{<groupby key>}"``.
    method : str, optional
        The plotting method to call on each subplot (default ``"quickplot"``).
        Groups flagged as vectors in the flat layout are always drawn with
        ``quiver`` instead.
    mode : str, optional
        ``"auto"`` (default), ``"overlay"``, or ``"split"``; forwarded to the
        group iterator that decides how the data is divided into panels.
    combine_vectors : bool, optional
        Forwarded to the group iterator in the flat layout.  It is not
        forwarded in the row/column layout.
    title : bool or str, optional
        ``True`` (default) adds an automatic title from the data metadata.
        Pass a string to use a custom title.  ``False`` suppresses the title.
        For full control (font size, position, etc.) pass ``False`` and call
        ``.title()`` on the returned object.
    legend : bool, optional
        ``True`` (default) adds a legend/colorbar.  ``False`` suppresses it.
        For full control pass ``False`` and call ``.legend()`` on the returned
        object.
    coastlines : bool, optional
        ``True`` (default) overlays coastlines on the map.  ``False``
        suppresses them.  For full control (resolution, styling, etc.) pass
        ``False`` and call ``.coastlines()`` on the returned object.
    **kwargs :
        Additional keyword arguments forwarded to the plotting method.

    Returns
    -------
    Map or Figure
        A :class:`~earthkit.plots.components.maps.Map` when a single panel is
        produced, or a :class:`~earthkit.plots.components.figures.Figure` for
        multi-panel layouts.  Both support ``.show()`` and ``.save()``.

    Raises
    ------
    NotImplementedError
        If *row* or *column* is given and the data is not a single xarray
        DataArray or Dataset.
    Exception
        Any exception raised by the plotting method is re-raised after a
        warning naming the failing panel has been issued.

    Warns
    -----
    DeprecationWarning
        If *size* is passed.
    UserWarning
        Immediately before a plotting failure is re-raised.

    Examples
    --------
    Single panel:

    >>> import earthkit.plots as ekp
    >>> ekp.plot(data, domain="Europe", units="celsius").show()

    Grid of panels, one per forecast step:

    >>> ekp.plot(data, groupby="step", domain="Europe", columns=4).show()

    Structured 2-D grid — variables in columns, time steps in rows:

    >>> ekp.plot(ds, row="valid_time", column="variable").show()

    Override the plot method:

    >>> ekp.plot(data, method="contourf", domain="Europe").show()
    """
    if size is not None:
        warnings.warn(
            "The 'size' argument is deprecated and will be removed in a future release. Use 'figsize' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        # An explicit figsize always wins over the deprecated alias.
        if figsize is None:
            figsize = size

    # --- 2-D structured layout (row/column as dimension names) ---
    if row is not None or column is not None:
        groups_2d = list(_iter_plot_groups_2d(args, row, column, groupby, mode))
        # Determine grid dimensions from unique row/col keys
        # (dict.fromkeys drops duplicates while keeping first-seen order).
        row_keys = list(dict.fromkeys(rk for rk, _, _ in groups_2d))
        col_keys = list(dict.fromkeys(ck for _, ck, _ in groups_2d))
        # A key list of exactly [None] falls back to the explicit rows/columns
        # count (or 1) — presumably the iterator yields None keys for an axis
        # that was not requested; confirm in iter_plot_groups_2d.
        n_rows = len(row_keys) if row_keys != [None] else (rows or 1)
        n_cols = len(col_keys) if col_keys != [None] else (columns or 1)
        figure = Figure(
            rows=n_rows,
            columns=n_cols,
            figsize=figsize or _auto_figure_size(n_rows, n_cols),
            chainable=True,
        )

        # Broadcast a single unit to every panel; a list/tuple is used
        # positionally.
        if not isinstance(units, (list, tuple)):
            units_flat = [units] * len(groups_2d)
        else:
            units_flat = list(units)

        for i, (row_key, col_key, targets) in enumerate(groups_2d):
            # Grid position is the index of the key in first-seen order.
            r = row_keys.index(row_key) if row_keys != [None] else 0
            c = col_keys.index(col_key) if col_keys != [None] else 0
            subplot = figure.add_map(row=r, column=c, domain=domain, crs=crs)
            # Panels beyond the end of a short units list get no conversion.
            unit = units_flat[i] if i < len(units_flat) else None
            for target in targets:
                try:
                    getattr(subplot, method)(target, units=unit, style=style, **kwargs)
                except Exception as err:
                    # Point the user at the failing panel, then propagate.
                    warnings.warn(
                        f"ekp.plot: failed to call '{method}' on panel ({r},{c}):\n"
                        f"{err}\n\n"
                        "Consider building the plot manually using ekp.Figure and ekp.Map."
                    )
                    raise
        _apply_map_decoration(figure, title=title, legend=legend, coastlines=coastlines)
        return _unwrap_if_single(figure)

    # --- Flat layout (original behaviour) ---
    # NOTE(review): subplot_titles is assigned here but not referenced again
    # anywhere in this function, and it is not read at all in the row/column
    # path above — confirm whether it should be applied to the panels.
    if subplot_titles is None and groupby:
        subplot_titles = f"{{{groupby}}}"

    groups = list(_iter_plot_groups(args, groupby, mode, combine_vectors=combine_vectors))
    n_plots = len(groups)

    rows, columns = layouts.rows_cols(n_plots, rows, columns)
    figure = Figure(
        rows=rows,
        columns=columns,
        figsize=figsize or _auto_figure_size(rows, columns),
        chainable=True,
    )

    # Broadcast a single unit to every panel; a list/tuple is used as given.
    if not isinstance(units, (list, tuple)):
        units = [units] * n_plots

    for i, (key, targets) in enumerate(groups):
        # No explicit position: panels are added in group order.
        subplot = figure.add_map(domain=domain, crs=crs)
        unit = units[i] if i < len(units) else None
        # Vector groups are identified by a tuple key starting with the
        # "__vector__" marker.
        is_vector = isinstance(key, tuple) and key[0] == "__vector__"
        try:
            if is_vector:
                import xarray as xr

                target = targets[0]
                if isinstance(target, xr.Dataset):
                    # xarray UV pair sub-Dataset
                    subplot.quiver(target, units=unit, **kwargs)
                else:
                    # earthkit: targets is [u_field, v_field]
                    from earthkit.data import FieldList

                    fl = FieldList.from_fields(targets)
                    subplot.quiver(fl, units=unit, **kwargs)
            else:
                for target in targets:
                    getattr(subplot, method)(target, units=unit, style=style, **kwargs)
        except Exception as err:
            # The message names `method` even when the failing call was the
            # quiver call of the vector branch.
            warnings.warn(
                f"ekp.plot: failed to call '{method}' on panel {i} with:\n"
                f"{err}\n\n"
                "Consider building the plot manually using ekp.Figure and ekp.Map."
            )
            raise
    _apply_map_decoration(figure, title=title, legend=legend, coastlines=coastlines)
    return _unwrap_if_single(figure)


def contourf(*args, style=None, **kwargs):
    """
    Draw filled contours on a map.

    Thin wrapper that calls :func:`plot` with ``method="contourf"`` and
    forwards every other argument unchanged, so it accepts the same
    arguments as :func:`plot`.  Note that *style* defaults to ``None`` here,
    whereas :func:`plot` defaults to ``"auto"``.

    The filled contours are rendered via
    :meth:`matplotlib.axes.Axes.contourf` — see the `matplotlib contourf
    documentation
    <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.contourf.html>`_
    for the full list of accepted ``**kwargs``.
    """
    result = plot(*args, method="contourf", style=style, **kwargs)
    return result


def contour(*args, style=None, **kwargs):
    """
    Draw contour lines on a map.

    Thin wrapper that calls :func:`plot` with ``method="contour"`` and
    forwards every other argument unchanged, so it accepts the same
    arguments as :func:`plot`.  Note that *style* defaults to ``None`` here,
    whereas :func:`plot` defaults to ``"auto"``.

    The contour lines are rendered via
    :meth:`matplotlib.axes.Axes.contour` — see the `matplotlib contour
    documentation
    <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.contour.html>`_
    for the full list of accepted ``**kwargs``.
    """
    result = plot(*args, method="contour", style=style, **kwargs)
    return result


def pcolormesh(*args, style=None, **kwargs):
    """
    Draw a pseudocolor mesh on a map.

    Thin wrapper that calls :func:`plot` with ``method="pcolormesh"`` and
    forwards every other argument unchanged, so it accepts the same
    arguments as :func:`plot`.  Note that *style* defaults to ``None`` here,
    whereas :func:`plot` defaults to ``"auto"``.

    The mesh is rendered via
    :meth:`matplotlib.axes.Axes.pcolormesh` — see the `matplotlib pcolormesh
    documentation
    <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.pcolormesh.html>`_
    for the full list of accepted ``**kwargs``.
    """
    result = plot(*args, method="pcolormesh", style=style, **kwargs)
    return result


def _unwrap_if_single(figure):
    """
    Collapse a one-panel figure to its only subplot.

    Returns ``figure.subplots[0]`` when the figure holds exactly one subplot;
    otherwise returns *figure* itself.
    """
    return figure.subplots[0] if len(figure.subplots) == 1 else figure


def _apply_map_decoration(figure, title=True, legend=True, coastlines=True):
    """
    Add the standard coastlines, legend and titles to a geo Figure.

    Every decoration call is best-effort: an exception raised by one call
    (for instance a title that cannot be built because metadata is missing)
    is discarded so that the remaining calls still run.

    Coastlines and the legend are requested once on the figure.  Titles are
    requested on each subplot in turn, using *title* as the label when it is
    a string and ``None`` otherwise.  Passing ``title=False`` skips titles.
    """

    def _attempt(target, step, *params):
        # Look up and call inside the try so lookup failures are ignored too.
        try:
            getattr(target, step)(*params)
        except Exception:
            pass

    if coastlines:
        _attempt(figure, "coastlines")
    if legend:
        _attempt(figure, "legend")
    if title is False:
        return

    label = title if isinstance(title, str) else None
    for panel in figure.subplots:
        _attempt(panel, "title", label)


def _single_map_function(method_name, data_args, domain, crs, kwargs):
    """
    Build a one-panel map figure and call *method_name* on its subplot.

    ``title``, ``legend`` and ``coastlines`` are removed from *kwargs*
    (defaulting to ``True``) and used for the final decoration step; the
    remaining entries are forwarded to the plotting method.

    Data is passed through untouched when every item of *data_args* is an
    xarray DataArray/Dataset, is omitted entirely when *data_args* is empty,
    and is otherwise converted to a single FieldList first.
    """
    import xarray as xr

    decoration = {
        "title": kwargs.pop("title", True),
        "legend": kwargs.pop("legend", True),
        "coastlines": kwargs.pop("coastlines", True),
    }

    figure = Figure(rows=1, columns=1, chainable=True)
    panel = figure.add_map(domain=domain, crs=crs)

    if not data_args:
        positional = ()
    elif all(isinstance(item, (xr.DataArray, xr.Dataset)) for item in data_args):
        positional = tuple(data_args)
    else:
        positional = (_coerce_to_fieldlist(*data_args),)
    getattr(panel, method_name)(*positional, **kwargs)

    _apply_map_decoration(figure, **decoration)
    return _unwrap_if_single(figure)


def grid_cells(
    *args,
    domain=None,
    crs=None,
    **kwargs,
):
    """
    Draw data as grid cells on a single map.

    Uses specialised nnshow backends for HEALPix and octahedral reduced
    Gaussian grids; falls back to pcolormesh for other grid types.  This is
    the fastest way to visualise the native grid structure of your data.

    Parameters
    ----------
    *args :
        The data to plot.
    domain : str or list, optional
        Named domain or bounding box ``[lon_min, lon_max, lat_min, lat_max]``.
    crs : cartopy.crs.CRS, optional
        Map projection.
    **kwargs :
        ``title``, ``legend`` and ``coastlines`` (each defaulting to ``True``)
        control the decoration step; everything else is forwarded to
        :meth:`Map.grid_cells`.

    Returns
    -------
    Map
    """
    return _single_map_function(
        method_name="grid_cells",
        data_args=args,
        domain=domain,
        crs=crs,
        kwargs=kwargs,
    )


def grid_points(
    *args,
    domain=None,
    crs=None,
    **kwargs,
):
    """
    Draw the grid point centroids of the data as scatter points on a map.

    Useful for inspecting the spatial coverage and density of a dataset's
    native grid.

    Parameters
    ----------
    *args :
        The data whose grid points to plot.
    domain : str or list, optional
        Named domain or bounding box ``[lon_min, lon_max, lat_min, lat_max]``.
    crs : cartopy.crs.CRS, optional
        Map projection.
    **kwargs :
        ``title``, ``legend`` and ``coastlines`` (each defaulting to ``True``)
        control the decoration step; everything else is forwarded to
        :meth:`Map.grid_points` (and ultimately to
        :func:`matplotlib.pyplot.scatter`).

    Returns
    -------
    Map
    """
    return _single_map_function(
        method_name="grid_points",
        data_args=args,
        domain=domain,
        crs=crs,
        kwargs=kwargs,
    )


def point_cloud(
    *args,
    domain=None,
    crs=None,
    **kwargs,
):
    """
    Draw data values as a coloured point cloud on a single map.

    Each data point is rendered as a scatter point coloured by its value.
    Suitable for sparse or unstructured observation data.

    Parameters
    ----------
    *args :
        The data to plot.
    domain : str or list, optional
        Named domain or bounding box ``[lon_min, lon_max, lat_min, lat_max]``.
    crs : cartopy.crs.CRS, optional
        Map projection.
    **kwargs :
        ``title``, ``legend`` and ``coastlines`` (each defaulting to ``True``)
        control the decoration step; everything else is forwarded to
        :meth:`Map.point_cloud` (and ultimately to
        :func:`matplotlib.pyplot.scatter`).

    Returns
    -------
    Map
    """
    return _single_map_function(
        method_name="point_cloud",
        data_args=args,
        domain=domain,
        crs=crs,
        kwargs=kwargs,
    )


def rgb_composite(
    *args,
    domain=None,
    crs=None,
    title=True,
    legend=True,
    coastlines=True,
    **kwargs,
):
    """
    Draw an RGB composite image on a single map.

    Combines three data fields (red, green, blue channels) into a single
    colour image. Each channel is normalised to [0, 1] before compositing.

    Parameters
    ----------
    *args :
        Either three separate data objects (red, green, blue) or a single
        iterable of three data objects.
    domain : str or list, optional
        Named domain or bounding box ``[lon_min, lon_max, lat_min, lat_max]``.
    crs : cartopy.crs.CRS, optional
        Map projection.
    title : bool or str, optional
        ``True`` (default) requests an automatic title, a string is used as
        the title text, ``False`` suppresses the title.
    legend : bool, optional
        Whether to request a legend (default ``True``).
    coastlines : bool, optional
        Whether to overlay coastlines (default ``True``).
    **kwargs :
        Additional keyword arguments forwarded to :meth:`Map.rgb_composite`.

    Returns
    -------
    Map
    """
    decoration = {"title": title, "legend": legend, "coastlines": coastlines}

    figure = Figure(rows=1, columns=1, chainable=True)
    panel = figure.add_map(domain=domain, crs=crs)
    panel.rgb_composite(*args, **kwargs)

    _apply_map_decoration(figure, **decoration)
    return _unwrap_if_single(figure)


def choropleth(
    data,
    domain=None,
    crs=None,
    title=True,
    legend=True,
    coastlines=True,
    **kwargs,
):
    """
    Draw a choropleth map from a GeoDataFrame on a single map panel.

    Parameters
    ----------
    data : geopandas.GeoDataFrame or earthkit-data object
        The data to plot. GeoDataFrame objects are used directly; earthkit-data
        objects are converted via ``to_geopandas()`` first.
    domain : str or list, optional
        Named domain or bounding box ``[lon_min, lon_max, lat_min, lat_max]``.
    crs : cartopy.crs.CRS, optional
        Map projection.
    title : bool or str, optional
        ``True`` (default) requests an automatic title, a string is used as
        the title text, ``False`` suppresses the title.
    legend : bool, optional
        Whether to request a legend (default ``True``).
    coastlines : bool, optional
        Whether to overlay coastlines (default ``True``).
    **kwargs :
        Additional keyword arguments forwarded to :meth:`Map.choropleth`.

    Returns
    -------
    Map
    """
    decoration = {"title": title, "legend": legend, "coastlines": coastlines}

    figure = Figure(rows=1, columns=1, chainable=True)
    panel = figure.add_map(domain=domain, crs=crs)
    panel.choropleth(data, **kwargs)

    _apply_map_decoration(figure, **decoration)
    return _unwrap_if_single(figure)


def spaghetti(
    *args,
    domain=None,
    crs=None,
    levels=None,
    color="#0673e0",
    label=None,
    highlight=None,
    highlight_kwargs=None,
    highlight_label=None,
    title=True,
    legend=True,
    coastlines=True,
    **kwargs,
):
    """
    Draw spaghetti contours for ensemble data on a single map.

    Each ensemble member is drawn as a thin contour line.  An optional
    ``highlight`` selector can draw specific members (e.g. the control
    forecast) with a different style.  All positional data arguments are
    first combined into one FieldList.

    Parameters
    ----------
    *args :
        The ensemble data to plot.  Accepts any format supported by
        earthkit-data (GRIB FieldList, xarray DataArray, …).  All fields
        are plotted on the same map as individual contour lines.
    domain : str or list, optional
        Named domain (e.g. ``"Europe"``) or bounding box
        ``[lon_min, lon_max, lat_min, lat_max]``.
    crs : cartopy.crs.CRS, optional
        Map projection.
    levels : float or list of float, optional
        Contour level(s) to draw for each member.  Accepts a single value
        (e.g. ``5400``) or multiple values (e.g. ``[5400, 5700, 5900]``).
        If omitted, contour levels are chosen automatically.
    color : str, default ``"#0673e0"``
        Line colour for normal ensemble members.
    label : str, optional
        Legend label for the ensemble members.  When set, a legend is
        automatically added to the plot.
    highlight : dict, optional
        Metadata criteria used to select members for highlighted rendering,
        e.g. ``{"dataType": "cf"}`` to pick out the control forecast.
    highlight_kwargs : dict, optional
        Keyword arguments passed to the contour method for highlighted
        members.  Defaults to ``{"color": "red", "linewidths": 1.5}``.
    highlight_label : str, optional
        Legend label for the highlighted members.  Defaults to
        ``"Control"`` when ``label`` is set and ``highlight`` is used.
    title : bool or str, optional
        ``True`` (default) requests an automatic title, a string is used as
        the title text, ``False`` suppresses the title.
    legend : bool, optional
        Whether to request a legend (default ``True``).
    coastlines : bool, optional
        Whether to overlay coastlines (default ``True``).
    **kwargs :
        Additional keyword arguments forwarded to the underlying
        ``contour`` call (e.g. ``linewidths``, ``alpha``).

    Returns
    -------
    Map
        An earthkit-plots :class:`~earthkit.plots.components.maps.Map`.

    Examples
    --------
    Single contour level across all members:

    >>> import earthkit.plots as ekp
    >>> ekp.spaghetti(data, levels=5400, domain="Europe").show()

    Multiple levels, with control forecast highlighted and a legend:

    >>> ekp.spaghetti(
    ...     data,
    ...     levels=[5400, 5700],
    ...     domain="Europe",
    ...     label="Ensemble",
    ...     highlight={"dataType": "cf"},
    ...     highlight_label="Control",
    ... ).show()
    """
    # Convert the data before any figure is created.
    members = _coerce_to_fieldlist(*args)

    member_options = {
        "levels": levels,
        "color": color,
        "label": label,
        "highlight": highlight,
        "highlight_kwargs": highlight_kwargs,
        "highlight_label": highlight_label,
    }
    decoration = {"title": title, "legend": legend, "coastlines": coastlines}

    figure = Figure(rows=1, columns=1, chainable=True)
    panel = figure.add_map(domain=domain, crs=crs)
    panel.spaghetti(members, **member_options, **kwargs)

    _apply_map_decoration(figure, **decoration)
    return _unwrap_if_single(figure)


def quiver(*args, domain=None, crs=None, **kwargs):
    """
    Draw wind / vector data as arrows on a single map.

    Builds a single-panel Map and calls ``quiver`` on it.  Accepts the same
    arguments as :meth:`~earthkit.plots.components.maps.Map.quiver`; the
    ``title``, ``legend`` and ``coastlines`` keywords (each defaulting to
    ``True``) are consumed by the decoration step instead of being forwarded.
    """
    return _single_map_function(
        method_name="quiver",
        data_args=args,
        domain=domain,
        crs=crs,
        kwargs=kwargs,
    )


def streamplot(*args, domain=None, crs=None, **kwargs):
    """
    Draw wind / vector data as streamlines on a single map.

    Builds a single-panel Map and calls ``streamplot`` on it.  Accepts the
    same arguments as
    :meth:`~earthkit.plots.components.subplots.Subplot.streamplot`; the
    ``title``, ``legend`` and ``coastlines`` keywords (each defaulting to
    ``True``) are consumed by the decoration step instead of being forwarded.
    """
    return _single_map_function(
        method_name="streamplot",
        data_args=args,
        domain=domain,
        crs=crs,
        kwargs=kwargs,
    )


def barbs(*args, domain=None, crs=None, **kwargs):
    """
    Draw wind barbs on a single map.

    Builds a single-panel Map and calls ``barbs`` on it.  Accepts the same
    arguments as
    :meth:`~earthkit.plots.components.subplots.Subplot.barbs`; the
    ``title``, ``legend`` and ``coastlines`` keywords (each defaulting to
    ``True``) are consumed by the decoration step instead of being forwarded.
    """
    return _single_map_function(
        method_name="barbs",
        data_args=args,
        domain=domain,
        crs=crs,
        kwargs=kwargs,
    )


def quickplot(*args, **kwargs):
    """
    Alias for :func:`plot`. Use ``ekp.plot()`` instead.

    All positional and keyword arguments are forwarded to :func:`plot`
    unchanged, and its return value is returned as-is.
    """
    return plot(*args, **kwargs)
def climatology(
    data,
    *args,
    plot="line",
    title=None,
    xticks=None,
    yticks=None,
    xlabel=None,
    ylabel=None,
    **kwargs,
):
    """
    Create a climatology (annual-cycle) plot.

    Splits multi-year timeseries data by calendar year and plots each year on
    a common Jan-to-Dec x-axis.  Leap years are mapped onto the reference year
    2000; non-leap years onto 2001, so Feb 29 is naturally absent for non-leap
    years.

    Parameters
    ----------
    data : xarray.DataArray
        Multi-year timeseries data with a time coordinate.
    *args :
        Additional positional arguments forwarded to the plotting method.
    plot : str, optional
        The plotting method to call on the Climatology subplot.  One of
        ``"line"``, ``"scatter"``, ``"bar"``, ``"fill_between"``.  Default
        ``"line"``.
    title : str, optional
        Plot title.
    xticks : str or dict, optional
        X-axis tick configuration.  If a string, treated as a frequency
        (e.g. ``"M"``, ``"M3"``).  If a dict, passed as kwargs to
        ``xticks()``.
    yticks : str or dict, optional
        Y-axis tick configuration.  Same format as *xticks*.
    xlabel : str, optional
        Label for the x-axis.
    ylabel : str, optional
        Label for the y-axis.
    **kwargs :
        Additional keyword arguments forwarded to the plotting method
        (e.g. ``color``, ``linewidth``).  A ``size`` entry is passed to the
        ``Climatology`` constructor instead.

    Returns
    -------
    Climatology
        An earthkit-plots
        :class:`~earthkit.plots.temporal.climatology.Climatology` subplot that
        can be further customised and displayed with ``.show()`` or saved with
        ``.save()``.

    Examples
    --------
    >>> import earthkit.plots as ekp
    >>> ekp.climatology.line(da).show()
    >>> ekp.climatology.bar(da, ylabel="Temperature (°C)", title="Annual cycle").show()
    """
    from earthkit.plots.temporal.climatology import Climatology

    # Constructor-only options must not reach the plotting method.
    class_kwargs = {k: kwargs.pop(k) for k in _TIMESERIES_CLASS_KWARGS if k in kwargs}
    ts = Climatology(chainable=True, **class_kwargs)
    getattr(ts, plot)(data, *args, **kwargs)

    if xlabel is not None:
        ts.xlabel(xlabel)
    if ylabel is not None:
        ts.ylabel(ylabel)
    _apply_ticks(ts, xticks, yticks)
    if title:
        ts.title(title)
    ts.figure.legend()
    return ts


# Keyword arguments that belong to the subplot class constructor rather than
# to the plotting method.
_TIMESERIES_CLASS_KWARGS = {"size"}


def _apply_ticks(subplot, xticks, yticks):
    """
    Apply xticks/yticks configuration to a subplot.

    Each of *xticks* / *yticks* may be ``None`` (axis left untouched), a
    string (passed as ``frequency=``) or a mapping (expanded as keyword
    arguments).
    """
    if xticks is not None:
        if isinstance(xticks, str):
            subplot.xticks(frequency=xticks)
        else:
            subplot.xticks(**xticks)
    if yticks is not None:
        if isinstance(yticks, str):
            subplot.yticks(frequency=yticks)
        else:
            subplot.yticks(**yticks)


def _run_timeseries_subplot_workflow(ts, xlabel=None, ylabel=None, xticks=None, yticks=None):
    """
    Apply schema.timeseries.subplot.decorate steps to a TimeSeries subplot.

    Each step names a method on *ts*.  The ``xlabel`` / ``ylabel`` steps
    receive the explicit label when one was given and are otherwise called
    without arguments.  A failing step is reported with a warning and the
    remaining steps still run.  Tick configuration is applied last.
    """
    for m in schema.timeseries.subplot.decorate:
        try:
            if m == "xlabel":
                if xlabel is not None:
                    ts.xlabel(xlabel)
                else:
                    ts.xlabel()
            elif m == "ylabel":
                if ylabel is not None:
                    ts.ylabel(ylabel)
                else:
                    ts.ylabel()
            else:
                getattr(ts, m)()
        except Exception as err:
            warnings.warn(f"timeseries subplot workflow step '{m}' failed with:\n{err}")
    _apply_ticks(ts, xticks, yticks)


def _run_timeseries_fig_workflow(fig, subplot_titles=None):
    """
    Apply schema.timeseries.fig.decorate steps to a Figure.

    Each step names a method on *fig*.  The ``subplot_titles`` step receives
    the explicit format string when one was given and is otherwise called
    without arguments, like every other step.  A failing step is reported
    with a warning and the remaining steps still run.
    """
    for m in schema.timeseries.fig.decorate:
        try:
            if m == "subplot_titles" and subplot_titles:
                fig.subplot_titles(subplot_titles)
            else:
                getattr(fig, m)()
        except Exception as err:
            warnings.warn(f"timeseries fig workflow step '{m}' failed with:\n{err}")


def timeseries(
    data,
    *args,
    overlay=False,
    groupby=None,
    rows=None,
    columns=None,
    title=None,
    subplot_titles="{variable_name}",
    xticks=None,
    yticks=None,
    xlabel=None,
    ylabel=None,
    plot="line",
    **kwargs,
):
    """
    Create a time series plot with automatic configuration.

    This is a convenience function that creates a TimeSeries subplot (or a
    Figure of multiple TimeSeries subplots) with sensible defaults for time
    series visualization.

    When *data* is an xarray Dataset with more than one data variable, a
    separate subplot is created for each variable by default (one per row).
    Pass ``overlay=True`` to plot all variables on a single subplot instead;
    variables whose units differ from those already plotted get their own
    twin y-axis.  Use ``groupby`` to split the data along a coordinate
    dimension (produces one panel per unique value).

    Parameters
    ----------
    data : array-like, xarray DataArray/Dataset, or earthkit data source
        The time series data to plot.
    *args : tuple
        Additional positional arguments passed to the plotting method.
    overlay : bool, optional
        When *data* is a multi-variable Dataset, plot all variables on a
        single subplot instead of creating one subplot per variable.  Default
        is False.
    groupby : str, optional
        Coordinate name along which to split the data into separate panels.
    rows : int, optional
        Override the number of rows in the Figure layout.
    columns : int, optional
        Override the number of columns in the Figure layout.
    title : str, optional
        Figure-level title.  Supports format strings like
        ``{variable_name}``.
    subplot_titles : str, optional
        Per-panel title format string.  Default is ``"{variable_name}"``.
    xticks : str or dict, optional
        Configuration for x-axis ticks.  If a string, treated as frequency
        (e.g. ``"Y"``, ``"M6"``, ``"D7"``, ``"h"``).  If a dict, passed as
        kwargs to ``xticks()``.
    yticks : str or dict, optional
        Configuration for y-axis ticks.  Same format as *xticks*.
    xlabel : str, optional
        Label for the x-axis.
    ylabel : str, optional
        Label for the y-axis.
    plot : str, optional
        Plotting method to call on each TimeSeries subplot.  Default is
        ``"line"``.
    **kwargs :
        Additional keyword arguments passed to the plotting method.  A
        ``size`` entry is never forwarded to the plotting method; it is
        handed to the figure or subplot being created instead.

    Returns
    -------
    TimeSeries or Figure
        A ``Figure`` for the one-subplot-per-variable and *groupby* layouts,
        a ``TimeSeries`` subplot otherwise.
    """
    import xarray as xr

    from earthkit.plots.metadata.formatters import LayerFormatter
    from earthkit.plots.temporal.timeseries import TimeSeries

    is_multi_var = isinstance(data, xr.Dataset) and len(data.data_vars) > 1

    # ------------------------------------------------------------------
    # Multi-variable Dataset, not overlaid → one row per variable
    # ------------------------------------------------------------------
    if not overlay and is_multi_var:
        size = kwargs.pop("size", None)
        fig = Figure(chainable=True)
        fig.timeseries(
            data,
            *args,
            row="variable",
            col=groupby,
            plot=plot,
            subplot_titles=subplot_titles,
            rows=rows,
            columns=columns,
            size=size,
            xticks=xticks,
            yticks=yticks,
            xlabel=xlabel,
            ylabel=ylabel,
            **kwargs,
        )
        if title:
            # Best effort: a title that cannot be formatted is skipped.
            try:
                fig.title(title)
            except Exception:
                pass
        # subplot_titles already handled inside fig.timeseries(); pass None to
        # avoid a second overwriting call from the workflow.
        _run_timeseries_fig_workflow(fig, subplot_titles=None)
        return fig

    # ------------------------------------------------------------------
    # Multi-variable Dataset, overlaid → all vars on one subplot
    # ------------------------------------------------------------------
    if overlay and is_multi_var:
        class_kwargs = {k: kwargs.pop(k) for k in _TIMESERIES_CLASS_KWARGS if k in kwargs}
        ts = TimeSeries(chainable=True, **class_kwargs)

        axes_by_units = {}  # units -> axes carrying variables with those units
        layers_by_ax = {}  # axes -> layers drawn on it
        primary_ax = None

        for var_name in data.data_vars:
            da = data[var_name]
            var_units = da.attrs.get("units", None)

            if not axes_by_units:
                # First variable: draw on the subplot's own axes.
                layers_before = len(ts.layers)
                getattr(ts, plot)(da, *args, **kwargs)
                primary_ax = ts._ax
                axes_by_units[var_units] = primary_ax
                layers_by_ax[primary_ax] = ts.layers[layers_before:]
                continue

            # Reuse an axes whose units match; otherwise add a twin y-axis.
            target_ax = next(
                (ax for u, ax in axes_by_units.items() if are_equal(var_units, u)),
                None,
            )
            if target_ax is None:
                target_ax = primary_ax.twinx()
                axes_by_units[var_units] = target_ax
                layers_by_ax[target_ax] = []

            # Temporarily point the subplot at the target axes while plotting;
            # always restore it, even if the plotting call raises.
            original_ax = ts._ax
            ts._ax = target_ax
            layers_before = len(ts.layers)
            try:
                getattr(ts, plot)(da, *args, **kwargs)
            finally:
                ts._ax = original_ax
            layers_by_ax[target_ax].extend(ts.layers[layers_before:])

        # The y-labels are set per axes below, so none is passed here.
        _run_timeseries_subplot_workflow(ts, xlabel=xlabel, ylabel=None, xticks=xticks, yticks=yticks)

        for ax, ax_layers in layers_by_ax.items():
            if ylabel is not None:
                ax.set_ylabel(ylabel)
            elif ax_layers:
                # Best effort: label from the first layer's metadata.
                try:
                    src = ax_layers[0].sources[0]
                    units = src.y.metadata("units")
                    lbl = "{variable_name} ({units})" if units else "{variable_name}"
                    ax.set_ylabel(LayerFormatter(ax_layers[0]).format(lbl))
                except Exception:
                    pass

        if title:
            ts.title(title)
        _run_timeseries_fig_workflow(ts.figure, subplot_titles=subplot_titles)
        return ts

    # ------------------------------------------------------------------
    # groupby → one panel per unique value (DataArray or single-var Dataset)
    # ------------------------------------------------------------------
    # For single-variable paths, "{variable_name}" as the default subplot title
    # is redundant (all panels show the same variable); suppress it unless the
    # user set it explicitly or we're splitting by groupby (where we auto-set it).
    if groupby is None and subplot_titles == "{variable_name}":
        subplot_titles = None

    if groupby is not None:
        # NOTE(review): with groupby set and subplot_titles left at its
        # default, the titles stay "{variable_name}"; the "{<groupby>}"
        # fallback below only applies when subplot_titles=None is passed
        # explicitly — confirm this is intended.
        if subplot_titles is None:
            subplot_titles = f"{{{groupby}}}"

        groups = list(_iter_plot_groups((data,), groupby, mode="split"))
        n_plots = len(groups)
        _rows, _cols = layouts.rows_cols(n_plots, rows, columns)

        # Honour a caller-supplied size as the figure size (the same
        # size -> figsize mapping that plot() applies) instead of silently
        # discarding it; it must not reach the plotting method either way.
        size = kwargs.pop("size", None)
        figure_options = {} if size is None else {"figsize": size}
        fig = Figure(rows=_rows, columns=_cols, chainable=True, **figure_options)

        for _, targets in groups:
            ts = fig.add_timeseries()
            for target in targets:
                getattr(ts, plot)(target, *args, **kwargs)
            _run_timeseries_subplot_workflow(ts, xlabel=xlabel, ylabel=ylabel, xticks=xticks, yticks=yticks)

        if title:
            # Best effort: a title that cannot be formatted is skipped.
            try:
                fig.title(title)
            except Exception:
                pass
        _run_timeseries_fig_workflow(fig, subplot_titles=subplot_titles)
        return fig

    # ------------------------------------------------------------------
    # Single panel
    # ------------------------------------------------------------------
    class_kwargs = {k: kwargs.pop(k) for k in _TIMESERIES_CLASS_KWARGS if k in kwargs}
    ts = TimeSeries(chainable=True, **class_kwargs)
    getattr(ts, plot)(data, *args, **kwargs)
    _run_timeseries_subplot_workflow(ts, xlabel=xlabel, ylabel=ylabel, xticks=xticks, yticks=yticks)
    if title:
        ts.title(title)
    _run_timeseries_fig_workflow(ts.figure, subplot_titles=subplot_titles)
    return ts