# Source code for earthkit.plots.components.figures

# Copyright 2024-, European Centre for Medium Range Weather Forecasts.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
import re

import matplotlib.pyplot as plt

from earthkit.plots.ancillary import find_logo
from earthkit.plots.components.layers import LayerGroup
from earthkit.plots.components.layouts import rows_cols
from earthkit.plots.components.subplots import Subplot
from earthkit.plots.metadata import formatters
from earthkit.plots.schemas import schema
from earthkit.plots.utils import string_utils


class Figure:
    """
    The overall canvas onto which subplots are drawn.

    A Figure is a container for one or more Subplots, each of which can
    contain one or more Layers. The Figure is responsible for managing the
    layout of Subplots and Layers, as well as providing methods for adding
    common elements like legends and titles.

    Parameters
    ----------
    rows : int, optional
        The number of rows in the figure. If ``gridspec`` is given this must
        match its number of rows (or be omitted).
    columns : int, optional
        The number of columns in the figure. If ``gridspec`` is given this
        must match its number of columns (or be omitted).
    figsize : list or tuple, optional
        The ``(width, height)`` of the figure. Each item may be a number
        (inches) or a string: a plain number (``"6"``) or a number followed
        by a unit, one of ``"in"``, ``"cm"``, ``"mm"`` or ``"px"`` (pixels are
        converted using matplotlib's ``figure.dpi`` rcParam).
    domain : earthkit.geo.Domain, optional
        The domain of the data being plotted. This is used to set the extent
        and projection of the map.
    crs : cartopy.crs.CRS, optional
        The CRS of the map. If not provided, it will be inferred from the
        domain. See
        https://cartopy.readthedocs.io/stable/reference/projections.html#cartopy-projections
        for a list of available CRSs.
    size : list or tuple, optional
        Deprecated alias for ``figsize`` (emits a :class:`DeprecationWarning`).
        Ignored when ``figsize`` is also given.
    gridspec : matplotlib.gridspec.GridSpec, optional
        An existing GridSpec to lay subplots out on. Its geometry provides
        ``rows`` and ``columns``; ``kwargs`` are not used in this case.
    chainable : bool, optional
        If True, methods that support it return the Figure itself so calls
        can be chained.
    kwargs : dict, optional
        Additional keyword arguments to pass to
        :class:`matplotlib.gridspec.GridSpec`.

    Raises
    ------
    ValueError
        If ``rows``/``columns`` conflict with the geometry of ``gridspec``, or
        if a ``figsize`` string cannot be parsed.
    """

    def __init__(
        self,
        rows=None,
        columns=None,
        figsize=None,
        domain=None,
        crs=None,
        size=None,
        gridspec=None,
        chainable=False,
        **kwargs,
    ):
        self._chainable = chainable

        if size is not None:
            import warnings

            warnings.warn(
                "The 'size' argument is deprecated and will be removed in a future release. "
                "Use 'figsize' instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            if figsize is None:
                figsize = size

        self._external_gridspec = gridspec
        if gridspec is not None:
            # An explicit GridSpec dictates the grid shape; reject contradictory rows/columns.
            nrows, ncols = gridspec.get_geometry()
            if rows is not None and rows != nrows:
                raise ValueError(f"rows={rows} conflicts with the provided GridSpec ({nrows} rows).")
            if columns is not None and columns != ncols:
                raise ValueError(f"columns={columns} conflicts with the provided GridSpec ({ncols} columns).")
            rows = nrows
            columns = ncols

        self.rows = rows
        self.columns = columns

        # matplotlib objects are created lazily by _setup().
        self.fig = None
        self.gridspec = None
        self._style_context = None
        self._jupyter_display_hook = None

        self._row = 0
        self._col = 0
        self._figsize = self._parse_size(figsize)
        self._gridspec_kwargs = kwargs
        self._domain = domain
        self._crs = crs

        self.subplots = []
        self._last_subplot_location = None
        self._isubplot = 0
        self._queue = []
        self._subplot_queue = []
        self._released = False

        self.attributions = []
        self.logos = []
        self._ancillary_cache = {}

        # With a known grid shape the figure can be created straight away;
        # otherwise setup waits until the shape is determined.
        if None not in (self.rows, self.columns):
            self._setup()

    def _setup(self):
        """
        Set up the figure the first time it is needed.

        Enters the earthkit style context (left active until
        :meth:`_exit_style_context` is called), creates a constrained-layout
        matplotlib figure, attaches the user's GridSpec or creates a new one,
        and registers the Jupyter auto-display hook.
        """
        self._style_context = schema.style_context()
        self._style_context.__enter__()
        self.fig = plt.figure(figsize=self._figsize, constrained_layout=True)
        if self._external_gridspec is not None:
            # Bind the caller-supplied GridSpec to the figure just created.
            self._external_gridspec.figure = self.fig
            self.gridspec = self._external_gridspec
        else:
            self.gridspec = self.fig.add_gridspec(self.rows, self.columns, **self._gridspec_kwargs)
        self._register_jupyter_display()

    def _register_jupyter_display(self):
        """
        Register a one-shot IPython ``post_execute`` hook that shows the figure.

        Does nothing outside IPython (``get_ipython`` undefined or returning
        None). The hook unregisters itself, prepares the figure, calls
        :func:`matplotlib.pyplot.show` and always exits the style context.
        """
        try:
            ip = get_ipython()  # noqa: F821
        except NameError:
            return
        if ip is None:
            return

        def _display_once():
            self._jupyter_display_hook = None
            ip.events.unregister("post_execute", _display_once)
            self._prepare_for_display()
            try:
                plt.show()
            finally:
                self._exit_style_context()

        self._jupyter_display_hook = (_display_once, ip)
        ip.events.register("post_execute", _display_once)

    def _cancel_jupyter_display(self):
        """Unregister the auto-display hook (called when show/save is explicit)."""
        if self._jupyter_display_hook is not None:
            hook_fn, ip = self._jupyter_display_hook
            self._jupyter_display_hook = None
            try:
                ip.events.unregister("post_execute", hook_fn)
            except ValueError:
                # The hook may already have been removed; nothing to undo.
                pass

    def _exit_style_context(self):
        """Exit the style context, restoring matplotlib's global rcParams."""
        if self._style_context is not None:
            self._style_context.__exit__(None, None, None)
            self._style_context = None

    def _defer_until_setup(method):
        """
        Decorator to defer calling a method until the figure has subplots.

        While ``self.subplots`` is empty the call is appended to ``self._queue``
        and None is returned; otherwise the method runs immediately.
        """

        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            if not self.subplots:
                self._queue.append((method, args, kwargs))
            else:
                return method(self, *args, **kwargs)

        return wrapper

    def _defer_subplot(method):
        """
        Decorator to defer calling a method until the grid shape is known.

        While ``self.rows`` or ``self.columns`` is None the call is appended to
        ``self._subplot_queue`` and None is returned; otherwise it runs now.
        """

        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            if self.rows is None or self.columns is None:
                self._subplot_queue.append((method, args, kwargs))
            else:
                return method(self, *args, **kwargs)

        return wrapper

    def __len__(self):
        return len(self.subplots)

    def __getitem__(self, i):
        return self.subplots[i]

    def _parse_size(self, size):
        """
        Parse a figure size specification into inches.

        Parameters
        ----------
        size : iterable or None
            ``(width, height)``. Numbers are taken as inches. Strings may be a
            plain number or a number followed by ``in``, ``cm``, ``mm`` or
            ``px`` (case-insensitive); pixels use matplotlib's ``figure.dpi``.

        Returns
        -------
        list or None
            The lengths in inches, or None when *size* is None.

        Raises
        ------
        ValueError
            If a string cannot be parsed or uses an unsupported unit.
        """
        if size is None:
            return None

        figsize = []
        for length in size:
            if isinstance(length, str):
                match = re.fullmatch(r"\s*([0-9]+(?:\.[0-9]*)?|\.[0-9]+)\s*([a-z]*)\s*", length, re.I)
                if match is None:
                    raise ValueError(f"Could not parse figure size {length!r}.")
                value = float(match.group(1))
                units = match.group(2).lower()
                if units in ("", "in"):
                    length = value
                elif units == "cm":
                    # 1 inch == 2.54 cm
                    length = value / 2.54
                elif units == "mm":
                    length = value / 25.4
                elif units == "px":
                    from matplotlib import rcParams as _rc

                    length = value / _rc["figure.dpi"]
                else:
                    raise ValueError(f"Unsupported figure size unit {units!r} in {length!r}.")
            figsize.append(length)
        return figsize
def apply_to_subplots(method):
    """
    Decorator that forwards a Figure method call to its subplots.

    The decorated method's own body is never executed; only its name is used.
    For every subplot that provides an attribute of that name, it is called
    with the same arguments. Subplots without it (e.g. a non-map subplot when
    calling ``coastlines``) are skipped.

    Returns
    -------
    The Figure when it was created with ``chainable=True``, otherwise None.

    Raises
    ------
    NotImplementedError
        If no subplot provides the method (including when there are none).
    """
    name = method.__name__

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        success = False
        for subplot in self.subplots:
            subplot_method = getattr(subplot, name, None)
            if subplot_method is None:
                # Not every subplot type implements every method.
                continue
            subplot_method(*args, **kwargs)
            success = True
        if not success:
            raise NotImplementedError(f"No subplots have method '{name}'")
        return self if self._chainable else None

    return wrapper
def iterate_subplots(method):
    """
    Decorator to iterate simultaneously over data items and subplots.

    The decorated method's body is never executed; only its name is used.
    The wrapper splits ``data`` into items:

    - with a ``groupby`` keyword: one item per group (via the
      ``earthkit.plots.quickplot`` grouping helpers);
    - an :class:`xarray.Dataset`: one DataArray per data variable;
    - any object with ``__len__``: one item per element;
    - anything else: a single item.

    If the figure has no subplots yet, the grid shape is derived from the
    number of items, the figure is set up if needed, and one map subplot is
    added per item. The same-named method is then called on each subplot with
    its paired item. Items and subplots are paired with :func:`zip`, so any
    surplus on either side is ignored.

    Returns
    -------
    The Figure when it was created with ``chainable=True``, otherwise None.
    """

    @functools.wraps(method)
    def wrapper(self, data, *args, **kwargs):
        import xarray as xr

        groupby = kwargs.pop("groupby", None)
        if groupby is not None:
            from earthkit.plots.quickplot import _coerce_to_fieldlist, _group_data

            fields = _coerce_to_fieldlist(data)
            grouped = _group_data(fields, groupby)
            data_items = list(grouped.values())
        elif isinstance(data, xr.Dataset):
            # Yield one DataArray per variable so subplots pair correctly.
            data_items = [data[v] for v in data.data_vars]
        else:
            if not hasattr(data, "__len__"):
                data = [data]
            data_items = list(data)

        if not self.subplots:
            self.rows, self.columns = rows_cols(len(data_items), rows=self.rows, columns=self.columns)
            # The matplotlib figure may already exist (e.g. __init__ creates it
            # when rows and columns are both given). Setting up again would
            # leave an orphan empty pyplot figure and re-enter the style context.
            # NOTE(review): assumes rows_cols returns the given rows/columns
            # unchanged when both are already set.
            if self.fig is None:
                self._setup()
            for _ in data_items:
                self.add_map()

        for datum, subplot in zip(data_items, self.subplots):
            getattr(subplot, method.__name__)(datum, *args, **kwargs)

        return self if self._chainable else None

    return wrapper
def _determine_row_column(self, row, column): """Determine the row and column of the next subplot.""" if row is not None and column is not None: pass else: if self._last_subplot_location is None: row, column = (0, -1) if row is None: row = self._last_subplot_location[0] if column is None: column = self._last_subplot_location[1] if column < self.columns - 1: column = column + 1 else: column = 0 row = row + 1 self._last_subplot_location = row, column return row, column
@apply_to_subplots
def xticks(self, *args, **kwargs):
    """
    Configure the x-axis ticks of each subplot in the figure.

    Every positional and keyword argument is handed unchanged to
    :meth:`~earthkit.plots.components.subplots.Subplot.xticks` on each
    subplot; see that method for the accepted parameters.
    """
@apply_to_subplots
def yticks(self, *args, **kwargs):
    """
    Configure the y-axis ticks of each subplot in the figure.

    Every positional and keyword argument is handed unchanged to
    :meth:`~earthkit.plots.components.subplots.Subplot.yticks` on each
    subplot; see that method for the accepted parameters.
    """
@apply_to_subplots
def xlabel(self, *args, **kwargs):
    """
    Label the x-axis of each subplot in the figure.

    All arguments are passed through to each subplot's :meth:`Subplot.xlabel`,
    which in turn uses :meth:`matplotlib.axes.Axes.set_xlabel`.
    """
@apply_to_subplots
def ylabel(self, *args, **kwargs):
    """
    Label the y-axis of each subplot in the figure.

    All arguments are passed through to each subplot's :meth:`Subplot.ylabel`,
    which in turn uses :meth:`matplotlib.axes.Axes.set_ylabel`.
    """
def add_subplot(self, row=None, column=None, **kwargs):
    """
    Create a generic :class:`Subplot` and append it to the figure.

    Parameters
    ----------
    row : int, optional
        Grid row of the new subplot.
    column : int, optional
        Grid column of the new subplot. When either ``row`` or ``column`` is
        omitted, the position is derived from the previously placed subplot
        (see :meth:`_determine_row_column`).
    kwargs : dict, optional
        Extra keyword arguments for the :class:`Subplot` constructor.

    Returns
    -------
    Subplot
        The subplot that was added.
    """
    grid_row, grid_column = self._determine_row_column(row, column)
    new_subplot = Subplot(
        row=grid_row,
        column=grid_column,
        figure=self,
        chainable=self._chainable,
        **kwargs,
    )
    self.subplots.append(new_subplot)
    return new_subplot
@_defer_subplot
def add_map(self, row=None, column=None, domain=None, crs=None, **kwargs):
    """
    Create a :class:`~earthkit.plots.components.maps.Map` subplot.

    If the figure's rows or columns are not known yet, the call is queued
    and None is returned.

    Parameters
    ----------
    row : int, optional
        Grid row of the new map.
    column : int, optional
        Grid column of the new map.
    domain : earthkit.geo.Domain, optional
        Domain of the plotted data, used for the map's extent and projection.
        Falls back to the Figure's own ``domain``.
    crs : cartopy.crs.CRS, optional
        Map projection. Falls back to the Figure's ``crs``; if that is also
        unset it is inferred from the domain or defaults to PlateCarree
        (regular lat-lon).
    kwargs : dict, optional
        Extra keyword arguments for the :class:`Map` constructor.

    Returns
    -------
    Map
        The map subplot that was added.
    """
    map_domain = self._domain if domain is None else domain
    map_crs = self._crs if crs is None else crs

    from earthkit.plots.components.maps import Map

    grid_row, grid_column = self._determine_row_column(row, column)
    new_map = Map(
        row=grid_row,
        column=grid_column,
        domain=map_domain,
        crs=map_crs,
        figure=self,
        chainable=self._chainable,
        **kwargs,
    )
    self.subplots.append(new_map)
    return new_map
@_defer_subplot
def add_timeseries(self, row=None, column=None, **kwargs):
    """
    Create a :class:`~earthkit.plots.temporal.timeseries.TimeSeries` subplot.

    The subplot is set up for time series visualisation: it gets a sensible
    default size, and time-axis margins are removed automatically on
    show/save. If the figure's rows or columns are not known yet, the call is
    queued and None is returned.

    Parameters
    ----------
    row : int, optional
        Grid row of the new subplot.
    column : int, optional
        Grid column of the new subplot.
    kwargs : dict, optional
        Extra keyword arguments for the
        :class:`~earthkit.plots.temporal.timeseries.TimeSeries` constructor.

    Returns
    -------
    TimeSeries

    Examples
    --------
    >>> fig = ekp.Figure(rows=2, columns=1)
    >>> ts1 = fig.add_timeseries()
    >>> ts1.line(t2m_da, x="valid_time", units="celsius")
    >>> ts2 = fig.add_timeseries()
    >>> ts2.band(mean_da, std_da, x="valid_time", units="celsius")
    >>> fig.show()
    """
    from earthkit.plots.temporal.timeseries import TimeSeries

    grid_row, grid_column = self._determine_row_column(row, column)
    new_subplot = TimeSeries(
        row=grid_row,
        column=grid_column,
        size=None,
        figure=self,
        chainable=self._chainable,
        **kwargs,
    )
    self.subplots.append(new_subplot)
    return new_subplot
def add_hovmoller(self, row=None, column=None, **kwargs):
    """
    Create a :class:`~earthkit.plots.temporal.hovmoller.Hovmoller` subplot.

    The subplot is set up for Hovmöller diagrams: time on one axis and
    pressure/height on the other, with pressure coordinates inverted
    automatically.

    Parameters
    ----------
    row : int, optional
        Grid row of the new subplot.
    column : int, optional
        Grid column of the new subplot.
    **kwargs :
        Extra keyword arguments for the
        :class:`~earthkit.plots.temporal.hovmoller.Hovmoller` constructor,
        notably ``time_axis`` (``"x"`` or ``"y"``) and ``invert_vertical``
        (``True``, ``False`` or ``"auto"``).

    Returns
    -------
    Hovmoller

    Notes
    -----
    Placement needs the grid shape, so create the Figure with ``rows`` and
    ``columns`` before calling this method.

    Examples
    --------
    >>> fig = ekp.Figure(rows=1, columns=1)
    >>> hov = fig.add_hovmoller()
    >>> hov.contourf(da, style="auto")
    >>> fig.show()
    """
    from earthkit.plots.temporal.hovmoller import Hovmoller

    grid_row, grid_column = self._determine_row_column(row, column)
    new_subplot = Hovmoller(
        row=grid_row,
        column=grid_column,
        size=None,
        figure=self,
        chainable=self._chainable,
        **kwargs,
    )
    self.subplots.append(new_subplot)
    return new_subplot
def add_climatology(self, row=None, column=None, **kwargs):
    """
    Create a :class:`~earthkit.plots.temporal.climatology.Climatology` subplot.

    The returned subplot's :meth:`line` method splits multi-year data by year
    and remaps every year onto a shared Jan-to-Dec x-axis.

    Parameters
    ----------
    row : int, optional
        Grid row of the new subplot.
    column : int, optional
        Grid column of the new subplot.
    **kwargs :
        Extra keyword arguments for the
        :class:`~earthkit.plots.temporal.climatology.Climatology` constructor.

    Returns
    -------
    Climatology

    Examples
    --------
    >>> fig = ekp.Figure(rows=1, columns=1)
    >>> ax = fig.add_climatology()
    >>> ax.line(da)
    >>> fig.show()
    """
    from earthkit.plots.temporal.climatology import Climatology

    grid_row, grid_column = self._determine_row_column(row, column)
    new_subplot = Climatology(
        row=grid_row,
        column=grid_column,
        size=None,
        figure=self,
        chainable=self._chainable,
        **kwargs,
    )
    self.subplots.append(new_subplot)
    return new_subplot
def subplot_titles(self, *args, **kwargs):
    """
    Give every subplot a title.

    All arguments are forwarded to each subplot's ``title`` method. Common
    options include:

    label : str, optional
        Title text; may contain ``{}`` format keys that are filled from the
        metadata of the plotted layers.
    unique : bool, optional
        If True, format keys whose value is the same across subplots/layers
        collapse to a single value; if False they expand to a list of values.

    Returns
    -------
    list
        The value returned by each subplot's ``title`` call, in subplot order.
    """
    titles = []
    for subplot in self.subplots:
        titles.append(subplot.title(*args, **kwargs))
    return titles
def distinct_legend_layers(self, subplots=None):
    """
    Group legend layers that share the same style.

    Parameters
    ----------
    subplots : list, optional
        Restrict the search to these subplots (default: all subplots).

    Returns
    -------
    list of LayerGroup
        One group per distinct style, in order of first appearance. Layers
        are compared by ``layer.style ==`` against the first layer of each
        existing group.
    """
    candidates = self.subplots if subplots is None else subplots
    all_layers = [layer for subplot in candidates for layer in subplot.distinct_legend_layers]

    buckets = []
    for layer in all_layers:
        home = next((bucket for bucket in buckets if bucket[0].style == layer.style), None)
        if home is None:
            buckets.append([layer])
        else:
            home.append(layer)

    return [LayerGroup(bucket) for bucket in buckets]
@_defer_until_setup
@schema.legend.apply()
def legend(self, *args, subplots=None, location=None, **kwargs):
    """
    Add legends to the figure.

    One legend is drawn per distinct layer style. Non-colorbar legends are
    re-anchored to match a colorbar's axes anchor when a colorbar is drawn.
    In addition, any layer with a ``proxy_label`` (e.g. spaghetti or
    labelled contours) contributes a line entry to a matplotlib legend on its
    subplot.

    Parameters
    ----------
    subplots : list, optional
        Only these subplots get legends (default: all subplots).
    location : str or list, optional
        Legend location; a list/tuple gives one location per distinct style.
    kwargs : dict, optional
        Extra keyword arguments for each style's ``legend`` method.

    Returns
    -------
    The Figure when chainable, otherwise the list of legends created by the
    styles. Returns None if the call was deferred because the figure has no
    subplots yet.
    """
    import matplotlib.lines as mlines

    def _proxy_handle(layer, label):
        # Prefer an explicit proxy colour; otherwise borrow the first edge
        # colour of the layer's mappable, falling back to black.
        colour = getattr(layer, "_proxy_color", None)
        width = getattr(layer, "_proxy_linewidth", 1.0)
        if colour is None:
            try:
                colour = layer.mappable.collections[0].get_edgecolor()[0]
            except (AttributeError, IndexError):
                colour = "black"
        return mlines.Line2D([], [], color=colour, linewidth=width, label=label)

    created = []
    colorbar_anchor = None
    anchored_later = []

    for index, layer in enumerate(self.distinct_legend_layers(subplots)):
        loc = location[index] if isinstance(location, (list, tuple)) else location
        if layer.style is None:
            continue
        new_legend = layer.style.legend(
            layer,
            *args,
            location=loc,
            **kwargs,
        )
        if new_legend.__class__.__name__ == "Colorbar":
            colorbar_anchor = layer.axes[0].get_anchor()
        else:
            anchored_later.append(layer)
        created.append(new_legend)

    if colorbar_anchor is not None:
        for layer in anchored_later:
            for ax in layer.axes:
                ax.set_anchor(colorbar_anchor)

    for subplot in self.subplots if subplots is None else subplots:
        handles = []
        for layer in subplot.layers:
            label = getattr(layer, "proxy_label", None)
            if label is not None:
                handles.append(_proxy_handle(layer, label))
        if handles:
            subplot.ax.legend(handles=handles)

    return self if self._chainable else created
@_defer_until_setup
@apply_to_subplots
def cities(self, *args, **kwargs):
    """
    Draw cities on each :class:`~earthkit.plots.components.maps.Map` subplot.

    Takes exactly the arguments of
    :meth:`~earthkit.plots.components.maps.Map.cities`.
    """
@_defer_until_setup
@apply_to_subplots
def coastlines(self, *args, **kwargs):
    """
    Draw coastlines on each :class:`~earthkit.plots.components.maps.Map` subplot.

    Takes exactly the arguments of
    :meth:`~earthkit.plots.components.maps.Map.coastlines`.
    """
@_defer_until_setup
@apply_to_subplots
def countries(self, *args, **kwargs):
    """
    Draw country boundaries on each :class:`~earthkit.plots.components.maps.Map` subplot.

    Takes exactly the arguments of
    :meth:`~earthkit.plots.components.maps.Map.countries`.
    """
@_defer_until_setup
@apply_to_subplots
def urban_areas(self, *args, **kwargs):
    """
    Draw urban areas on each :class:`~earthkit.plots.components.maps.Map` subplot.

    Takes exactly the arguments of
    :meth:`~earthkit.plots.components.maps.Map.urban_areas`.
    """
@_defer_until_setup
@apply_to_subplots
def land(self, *args, **kwargs):
    """
    Draw land polygons on each :class:`~earthkit.plots.components.maps.Map` subplot.

    Takes exactly the arguments of
    :meth:`~earthkit.plots.components.maps.Map.land`.
    """
@_defer_until_setup
@apply_to_subplots
def borders(self, *args, **kwargs):
    """
    Draw country borders on each :class:`~earthkit.plots.components.maps.Map` subplot.

    Takes exactly the arguments of
    :meth:`~earthkit.plots.components.maps.Map.borders`.
    """
@_defer_until_setup
@apply_to_subplots
def standard_layers(self, *args, **kwargs):
    """
    Draw the standard geographic layers on each ``Map`` subplot.

    Takes exactly the arguments of
    :meth:`~earthkit.plots.components.maps.Map.standard_layers`.
    """
@_defer_until_setup
@apply_to_subplots
def administrative_areas(self, *args, **kwargs):
    """
    Draw administrative areas on each :class:`~earthkit.plots.components.maps.Map` subplot.

    Takes exactly the arguments of
    :meth:`~earthkit.plots.components.maps.Map.administrative_areas`.
    """
@_defer_until_setup
@apply_to_subplots
def stock_img(self, *args, **kwargs):
    """
    Draw a stock background image on each :class:`~earthkit.plots.components.maps.Map` subplot.

    Takes exactly the arguments of
    :meth:`~earthkit.plots.components.maps.Map.stock_img`.
    """
@iterate_subplots
def block(self, *args, **kwargs):
    """
    Draw a pcolormesh on each subplot, one data item per subplot.

    .. deprecated::
        Use :meth:`pcolormesh` instead.

    Parameters
    ----------
    data : list, numpy.ndarray, xarray.DataArray, or earthkit.data.core.Base, optional
        Data to draw. If None, x, y and z must be supplied.
    style : earthkit.plots.styles.Style, optional
        Style to apply; generated from the data when None.
    units : str, optional
        Target units for value conversion (e.g. ``"celsius"``). See
        :doc:`/examples/examples/introduction/08-unit-conversion` for examples.
    **kwargs
        Passed on to :meth:`matplotlib.axes.Axes.pcolormesh`.
    """
@iterate_subplots
def gridpoints(self, *args, **kwargs):
    """
    Mark grid point centroids on each subplot, one data item per subplot.

    Parameters
    ----------
    data : xarray.DataArray or earthkit.data.core.Base, optional
        Data source whose grid points are drawn.
    x : str, optional
        Name of the x-coordinate variable in the data source.
    y : str, optional
        Name of the y-coordinate variable in the data source.
    **kwargs
        Passed on to :meth:`matplotlib.axes.Axes.scatter`.
    """
@iterate_subplots
def quickplot(self, *args, **kwargs):
    """
    Let each subplot choose a plot type automatically for its data item.

    Data items and subplots are walked in step, and
    :meth:`~earthkit.plots.components.subplots.Subplot.quickplot` is called
    on every pair.

    Parameters
    ----------
    data : xarray.DataArray, xarray.Dataset, or earthkit.data.core.Base
        Data to draw.
    **kwargs
        Forwarded to each subplot's
        :meth:`~earthkit.plots.components.subplots.Subplot.quickplot`.
    """
@iterate_subplots
def pcolormesh(self, *args, **kwargs):
    """
    Draw a pseudocolor mesh on each subplot, one data item per subplot.

    Parameters
    ----------
    data : list, numpy.ndarray, xarray.DataArray, or earthkit.data.core.Base, optional
        Data to draw. If None, x, y and z must be supplied.
    x, y, z : str, list, numpy.ndarray, or xarray.DataArray, optional
        Coordinates and values. With ``data`` given, these name coordinates
        in the data; otherwise ``data`` must be given when they are None.
    style : earthkit.plots.styles.Style, optional
        Style to apply; generated from the data when None.
    units : str, optional
        Target units for value conversion (e.g. ``"celsius"``). See
        :doc:`/examples/examples/introduction/08-unit-conversion` for examples.
    **kwargs
        Passed on to :meth:`matplotlib.axes.Axes.pcolormesh`; see the
        `matplotlib pcolormesh documentation
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.pcolormesh.html>`_.
    """
@iterate_subplots
def grid_cells(self, *args, **kwargs):
    """
    Draw data as grid cells on each subplot, one data item per subplot.

    HEALPix and octahedral reduced Gaussian grids are rendered with the fast
    pixel-sampling ``nnshow`` backends; every other grid type falls back to
    plain pcolormesh rendering.

    Parameters
    ----------
    data : xarray.DataArray or earthkit.data.core.Base, optional
        Data to draw.
    x, y, z : str, array-like, or None, optional
        Explicit coordinates / values.
    style : earthkit.plots.styles.Style, optional
        Style to apply; generated from the data when None.
    units : str, optional
        Target units for value conversion (e.g. ``"celsius"``). See
        :doc:`/examples/examples/introduction/08-unit-conversion` for examples.
    grid : str or GridSpec, optional
        Grid specification for rendering; ``"auto"`` (the default) detects
        the grid type from the data's metadata.
    **kwargs
        Forwarded to the underlying plot method.
    """
@iterate_subplots
def point_cloud(self, *args, **kwargs):
    """
    Draw values as a coloured point cloud on each subplot.

    Each data point becomes one scatter marker coloured by its value, which
    suits sparse or unstructured observation data.

    Parameters
    ----------
    data : xarray.DataArray or earthkit.data.core.Base, optional
        Data to draw.
    x : str, optional
        Name of the x-coordinate variable in the data source.
    y : str, optional
        Name of the y-coordinate variable in the data source.
    units : str, optional
        Target units for value conversion (e.g. ``"celsius"``). See
        :doc:`/examples/examples/introduction/08-unit-conversion` for examples.
    **kwargs
        Passed on to :meth:`matplotlib.axes.Axes.scatter`.
    """
@iterate_subplots
def imshow(self, *args, **kwargs):
    """
    Draw an image on each subplot, one data item per subplot.

    Parameters
    ----------
    data : list, numpy.ndarray, xarray.DataArray, or earthkit.data.core.Base, optional
        Data to draw. If None, x, y and z must be supplied.
    x, y, z : str, list, numpy.ndarray, or xarray.DataArray, optional
        Coordinates and values. With ``data`` given, these name coordinates
        in the data; otherwise ``data`` must be given when they are None.
    style : earthkit.plots.styles.Style, optional
        Style to apply; generated from the data when None.
    units : str, optional
        Target units for value conversion (e.g. ``"celsius"``). See
        :doc:`/examples/examples/introduction/08-unit-conversion` for examples.
    **kwargs
        Passed on to :meth:`matplotlib.axes.Axes.imshow`.
    """
@iterate_subplots
def plot(self, *args, **kwargs):
    """
    Plot data on each subplot via its ``plot`` method, one item per subplot.

    The kind of plot drawn is decided by each subplot's
    :meth:`~earthkit.plots.components.subplots.Subplot.plot`.

    Parameters
    ----------
    data : list, numpy.ndarray, xarray.DataArray, or earthkit.data.core.Base, optional
        Data to draw. If None, x, y and z must be supplied.
    x, y, z : str, list, numpy.ndarray, or xarray.DataArray, optional
        Coordinates and values. With ``data`` given, these name coordinates
        in the data; otherwise ``data`` must be given when they are None.
    style : earthkit.plots.styles.Style, optional
        Style to apply; generated from the data when None.
    units : str, optional
        Target units for value conversion (e.g. ``"celsius"``). See
        :doc:`/examples/examples/introduction/08-unit-conversion` for examples.
    **kwargs
        Forwarded to each subplot's ``plot`` method.
    """
@iterate_subplots
def contourf(self, *args, **kwargs):
    """
    Draw filled contours on each subplot, one data item per subplot.

    Parameters
    ----------
    data : list, numpy.ndarray, xarray.DataArray, or earthkit.data.core.Base, optional
        Data to draw. If None, x, y and z must be supplied.
    x, y, z : str, list, numpy.ndarray, or xarray.DataArray, optional
        Coordinates and values. With ``data`` given, these name coordinates
        in the data; otherwise ``data`` must be given when they are None.
    style : earthkit.plots.styles.Style, optional
        Style for the filled contours; generated from the data when None.
    units : str, optional
        Target units for value conversion (e.g. ``"celsius"``). See
        :doc:`/examples/examples/introduction/08-unit-conversion` for examples.
    **kwargs
        Passed on to :meth:`matplotlib.axes.Axes.contourf`; see the
        `matplotlib contourf documentation
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.contourf.html>`_.
    """
@iterate_subplots
def contour(self, *args, **kwargs):
    """
    Draw contour lines on each subplot, one data item per subplot.

    Parameters
    ----------
    data : list, numpy.ndarray, xarray.DataArray, or earthkit.data.core.Base, optional
        Data to draw. If None, x, y and z must be supplied.
    x, y, z : str, list, numpy.ndarray, or xarray.DataArray, optional
        Coordinates and values. With ``data`` given, these name coordinates
        in the data; otherwise ``data`` must be given when they are None.
    style : earthkit.plots.styles.Style, optional
        Style for the contour lines; generated from the data when None.
    units : str, optional
        Target units for value conversion (e.g. ``"celsius"``). See
        :doc:`/examples/examples/introduction/08-unit-conversion` for examples.
    **kwargs
        Passed on to :meth:`matplotlib.axes.Axes.contour`; see the
        `matplotlib contour documentation
        <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.contour.html>`_.
    """
@iterate_subplots
def line(self, *args, **kwargs):
    """
    Draw a line on each subplot, one data item per subplot.

    Parameters
    ----------
    data : xarray.DataArray or array-like
        Data to draw.
    **kwargs
        Forwarded to each subplot's
        :meth:`~earthkit.plots.components.subplots.Subplot.line`.
    """
@iterate_subplots
def multiboxplot(self, *args, **kwargs):
    """
    Draw a multiboxplot on each subplot, one data item per subplot.

    Parameters
    ----------
    data : xarray.DataArray
        Data to draw.
    **kwargs
        Forwarded to each subplot's
        :meth:`~earthkit.plots.components.subplots.Subplot.multiboxplot`.
    """
def _plot(
    self,
    method,
    data,
    *args,
    row=None,
    col=None,
    subplot_class=None,
    subplot_titles=None,
    rows=None,
    columns=None,
    figsize=None,
    size=None,
    **kwargs,
):
    """
    Apply a plotting method across panels of an xarray Dataset.

    This is the generic FacetGrid-style engine for ``Figure``. It splits
    *data* into a grid of subplots according to *row* and *col*, creates one
    subplot per panel using *subplot_class*, and calls *method* on each
    panel's data slice.

    Parameters
    ----------
    method : str
        Name of the subplot method to call on each panel (e.g. ``"line"``,
        ``"bar"``, ``"contourf"``).
    data : xarray.Dataset or xarray.DataArray
        The data to distribute across panels. When *data* is a Dataset,
        ``"variable"`` is a special token for *row* / *col* that means
        "split by data variable". Any other string is treated as a coordinate
        name along which to select unique values (in first-seen order).
    *args :
        Positional arguments forwarded to the subplot method.
    row : str or None, optional
        Dimension to vary along rows. Use ``"variable"`` to put each Dataset
        variable in its own row, or pass a coordinate name (e.g. ``"step"``).
        Default is ``None`` (single row).
    col : str or None, optional
        Dimension to vary along columns. Same tokens as *row*. Default is
        ``None`` (single column).
    subplot_class : type, optional
        Subplot class to instantiate for each panel. Defaults to
        :class:`~earthkit.plots.components.subplots.Subplot`.
    subplot_titles : str or None, optional
        Format string for per-panel titles. Supports metadata placeholders
        such as ``"{variable_name}"``. Set to ``None`` to suppress titles.
        Titles are best-effort: if setting one fails, a warning is emitted
        and plotting continues.
    rows : int, optional
        Override the total number of rows in the Figure grid.
    columns : int, optional
        Override the total number of columns in the Figure grid.
    figsize : tuple, optional
        Figure size ``(width, height)`` in inches. Only used when the
        underlying Matplotlib figure has not been created yet. Defaults to
        ``(8 * n_cols, 4 * n_rows)``.
    size : tuple, optional
        Deprecated alias for *figsize*; emits a :class:`DeprecationWarning`.
    **kwargs :
        Additional keyword arguments forwarded to the subplot method.

    Returns
    -------
    Figure or None
        ``self`` if the Figure is chainable, otherwise ``None``.

    Notes
    -----
    Coordinate values are kept as NumPy scalars rather than converted with
    ``ndarray.tolist()``. ``tolist()`` turns ``datetime64[ns]`` and
    ``timedelta64[ns]`` values into plain integers (nanoseconds), which
    ``.sel`` cannot match against a time or timedelta coordinate.

    Examples
    --------
    Two-variable Dataset, one row per variable:

    >>> fig = ekp.Figure()
    >>> fig.plot("line", ds, row="variable")
    >>> fig.show()

    Variable × step grid:

    >>> fig = ekp.Figure()
    >>> fig.plot("line", ds, row="variable", col="step")
    >>> fig.show()

    Single DataArray across ensemble members:

    >>> fig = ekp.Figure()
    >>> fig.plot("line", ds["t2m"], col="number")
    >>> fig.show()
    """
    import warnings

    import xarray as xr

    if subplot_class is None:
        subplot_class = Subplot

    # --- Resolve row/col dimensions into (row_vals, col_vals) lists ------
    def _dim_vals(data, dim):
        """Return the unique values (first-seen order) for a panel dimension token."""
        if dim is None:
            return [None]
        if dim == "variable":
            if isinstance(data, xr.Dataset):
                return list(data.data_vars)
            return [None]
        # Treat as a coordinate name
        if isinstance(data, xr.Dataset):
            coord = data[list(data.data_vars)[0]][dim]
        else:
            coord = data[dim]
        # Iterate the array itself so values stay NumPy scalars; tolist()
        # would turn datetime64[ns]/timedelta64[ns] into unselectable ints.
        return list(dict.fromkeys(coord.values))

    row_vals = _dim_vals(data, row)
    col_vals = _dim_vals(data, col)

    n_rows = rows if rows is not None else len(row_vals)
    n_cols = columns if columns is not None else len(col_vals)

    # --- Set up the Figure grid if not already done ----------------------
    if size is not None:
        warnings.warn(
            "The 'size' argument is deprecated and will be removed in a future release. Use 'figsize' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        if figsize is None:
            figsize = size

    if self.rows is None or self.columns is None:
        self.rows = n_rows
        self.columns = n_cols

    if self.fig is None:
        if figsize is None:
            figsize = (8 * n_cols, 4 * n_rows)
        self._figsize = self._parse_size(figsize)
        self._setup()

    # --- Build panels -----------------------------------------------------
    def _slice(data, row_dim, row_val, col_dim, col_val):
        """Extract the DataArray/Dataset slice for one panel."""
        result = data
        for dim, val in ((row_dim, row_val), (col_dim, col_val)):
            if dim is None or val is None:
                continue
            if dim == "variable":
                result = result[val] if isinstance(result, xr.Dataset) else result
            else:
                result = result.sel({dim: val})
        return result

    for r_i, r_val in enumerate(row_vals):
        for c_i, c_val in enumerate(col_vals):
            panel_data = _slice(data, row, r_val, col, c_val)
            sp = subplot_class(row=r_i, column=c_i, figure=self)
            self.subplots.append(sp)
            getattr(sp, method)(panel_data, *args, **kwargs)
            if subplot_titles is not None:
                try:
                    sp.title(subplot_titles)
                except Exception as err:
                    # Titles are best-effort, but failures should not be silent.
                    warnings.warn(
                        f"Could not set title {subplot_titles!r} on subplot "
                        f"(row={r_i}, column={c_i}): {err}",
                        stacklevel=2,
                    )

    return self if self._chainable else None
def timeseries(
    self,
    data,
    *args,
    row=None,
    col=None,
    plot="line",
    subplot_titles="{variable_name}",
    rows=None,
    columns=None,
    figsize=None,
    size=None,
    xticks=None,
    yticks=None,
    xlabel=None,
    ylabel=None,
    **kwargs,
):
    """
    Plot time series data across a grid of panels.

    Panels are built by :meth:`_plot` using
    :class:`~earthkit.plots.temporal.timeseries.TimeSeries` subplots. Labels
    and tick settings are then applied to every ``TimeSeries`` subplot of the
    figure.

    If *row* and *col* are both ``None`` and *data* is an xarray Dataset with
    more than one data variable, *row* is set to ``"variable"``, giving one
    row per variable.

    Parameters
    ----------
    data : xarray.Dataset or xarray.DataArray
        Time series data to split into panels.
    *args :
        Positional arguments for the per-panel plotting method.
    row : str or None, optional
        Dimension varied along rows (see the multi-variable default above).
    col : str or None, optional
        Dimension varied along columns, e.g. a coordinate such as ``"step"``
        or ``"number"``. Default ``None``.
    plot : str, optional
        Name of the subplot method called per panel. Default ``"line"``.
    subplot_titles : str or None, optional
        Per-panel title template. Default ``"{variable_name}"``.
    rows : int, optional
        Override for the number of grid rows.
    columns : int, optional
        Override for the number of grid columns.
    figsize : tuple, optional
        Figure size ``(width, height)`` in inches.
    size : tuple, optional
        Deprecated alias for *figsize*; emits a :class:`DeprecationWarning`.
    xticks : str or dict, optional
        x-axis tick settings for every panel. A string is passed as
        ``frequency=``; a dict is unpacked as keyword arguments.
    yticks : str or dict, optional
        y-axis tick settings, interpreted like *xticks*.
    xlabel : str, optional
        x-axis label for every panel.
    ylabel : str, optional
        y-axis label for every panel.
    **kwargs :
        Extra keyword arguments for the per-panel plotting method.

    Returns
    -------
    Figure or None
        ``self`` if the Figure is chainable, otherwise ``None``.

    Examples
    --------
    Multi-variable Dataset – one row per variable:

    >>> fig = ekp.Figure()
    >>> fig.timeseries(ds)
    >>> fig.show()

    Variable × step grid:

    >>> fig = ekp.Figure()
    >>> fig.timeseries(ds, row="variable", col="step")
    >>> fig.show()
    """
    import xarray as xr

    from earthkit.plots.temporal.timeseries import TimeSeries

    if row is None and col is None:
        if isinstance(data, xr.Dataset) and len(data.data_vars) > 1:
            row = "variable"

    if size is not None:
        import warnings

        warnings.warn(
            "The 'size' argument is deprecated and will be removed in a future release. Use 'figsize' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        figsize = size if figsize is None else figsize

    self._plot(
        plot,
        data,
        *args,
        row=row,
        col=col,
        subplot_class=TimeSeries,
        subplot_titles=subplot_titles,
        rows=rows,
        columns=columns,
        figsize=figsize,
        **kwargs,
    )

    def _set_ticks(panel, setter_name, spec):
        # A string is shorthand for a tick frequency; anything else is unpacked.
        if spec is None:
            return
        setter = getattr(panel, setter_name)
        if isinstance(spec, str):
            setter(frequency=spec)
        else:
            setter(**spec)

    for panel in self.subplots:
        if not isinstance(panel, TimeSeries):
            continue
        if xlabel is not None:
            panel.xlabel(xlabel)
        if ylabel is not None:
            panel.ylabel(ylabel)
        _set_ticks(panel, "xticks", xticks)
        _set_ticks(panel, "yticks", yticks)

    if self._chainable:
        return self
    return None
@_defer_until_setup
def gridlines(self, *args, sharex=True, sharey=True, **kwargs):
    """
    Call ``gridlines`` on every subplot, sharing edge labels between panels.

    Intended for :class:`Map` subplots. Each subplot's label sides are
    judged against the subplots in its own column/row, so panels at the
    edge of an incomplete grid still get their labels.

    Parameters
    ----------
    sharex : bool, optional
        If True and every subplot in a column has the same domain, only the
        top-most subplot of that column keeps ``"top"`` labels. Likewise,
        only the bottom-most keeps ``"bottom"`` labels.
    sharey : bool, optional
        If True and every subplot in a row has the same domain, only the
        left-most subplot of that row keeps ``"left"`` labels. Likewise,
        only the right-most keeps ``"right"`` labels.
    draw_labels : bool or list of str, optional
        Sides on which to draw labels. Defaults to ``["left", "bottom"]``.
        ``True`` means all four sides; a falsy value disables labels.
    kwargs : dict, optional
        Additional keyword arguments to pass to each subplot's ``gridlines``
        method.

    Returns
    -------
    Figure or None
        ``self`` if the Figure is chainable, otherwise ``None``.
    """
    draw_labels = kwargs.pop("draw_labels", ["left", "bottom"])
    if draw_labels is True:
        draw_labels = ["left", "right", "bottom", "top"]

    # Group subplots by column and row so edges are computed per column/row,
    # not from the global maximum (which breaks incomplete grids).
    by_column = {}
    by_row = {}
    for sp in self.subplots:
        by_column.setdefault(sp.column, []).append(sp)
        by_row.setdefault(sp.row, []).append(sp)

    for subplot in self.subplots:
        if not draw_labels:
            subplot_draw_labels = False
        else:
            removed = set()
            column_mates = by_column[subplot.column]
            if sharex and all(sp.domain == subplot.domain for sp in column_mates):
                if subplot.row != min(sp.row for sp in column_mates):
                    removed.add("top")
                if subplot.row != max(sp.row for sp in column_mates):
                    removed.add("bottom")
            row_mates = by_row[subplot.row]
            if sharey and all(sp.domain == subplot.domain for sp in row_mates):
                if subplot.column != min(sp.column for sp in row_mates):
                    removed.add("left")
                if subplot.column != max(sp.column for sp in row_mates):
                    removed.add("right")
            subplot_draw_labels = [loc for loc in draw_labels if loc not in removed]
        subplot.gridlines(*args, draw_labels=subplot_draw_labels, **kwargs)
    return self if self._chainable else None
@schema.suptitle.apply()
def title(self, label=None, unique=True, grouped=True, y=None, **kwargs):
    """
    Add a top-level title (suptitle) to the figure.

    The ``schema.suptitle.apply()`` decorator is defined elsewhere; it
    presumably supplies schema defaults for the keyword arguments.

    Parameters
    ----------
    label : str, optional
        Title text. It may contain ``{}``-delimited format keys, which are
        filled from the metadata of the plotted layers. Defaults to the
        figure's default title template.
    unique : bool, optional
        If True, a format key that evaluates to the same value across all
        subplots/layers yields that value only once. For example, if every
        layer has the same ``variable_name``, it appears once in the title.
        If False, each format key evaluates to the list of values found
        across subplots/layers.
    grouped : bool, optional
        If True, one title represents all data layers, with each format key
        evaluating to a list where layers differ. For example,
        ``"{variable} at {time}"`` might become
        ``"temperature and wind at 2023-01-01 00:00"``. If False, the
        template is expanded per subplot and the results are joined. For
        example, ``"temperature at 2023-01-01 00:00 and wind at
        2023-01-01 00:00"``.
    y : float, optional
        Vertical position in figure coordinates. When omitted it is computed
        by ``_get_suptitle_y``.
    kwargs : dict, optional
        Additional keyword arguments to pass to
        :meth:`matplotlib.figure.Figure.suptitle`.

    Returns
    -------
    Figure or matplotlib.text.Text
        ``self`` if the Figure is chainable, otherwise whatever ``suptitle``
        returns.
    """
    template = self._default_title_template if label is None else label
    text = self.format_string(template, unique, grouped)
    y_position = self._get_suptitle_y() if y is None else y
    suptitle = self.fig.suptitle(text, y=y_position, **kwargs)
    if self._chainable:
        return self
    return suptitle
def set_title(self, label=None, **kwargs):
    """
    Set the figure's top-level title; alias of :meth:`title`.

    Provided to mirror Matplotlib's ``set_title`` naming. All arguments are
    handed straight to :meth:`title`, and its return value is returned.

    Parameters
    ----------
    label : str, optional
        Title text, optionally with metadata keys in curly braces such as
        ``"{variable_name}"``.
    **kwargs
        Passed through to :meth:`title`.
    """
    set_suptitle = self.title
    return set_suptitle(label, **kwargs)
def draw(self):
    """
    Render the figure and then reset the face colours of all layers.

    Calls :meth:`matplotlib.backend_bases.FigureCanvasBase.draw` on the
    figure's canvas. It then calls ``reset_facecolors()`` on every layer of
    every subplot, in subplot order.
    """
    canvas = self.fig.canvas
    canvas.draw()
    every_layer = (layer for subplot in self.subplots for layer in subplot.layers)
    for layer in every_layer:
        layer.reset_facecolors()
def _get_suptitle_y(self):
    """
    Estimate a y position for the figure suptitle.

    The title is placed above the highest axes in the figure. The position
    is the sum of:

    - the top edge of the highest axes (from
      :meth:`matplotlib.axes.Axes.get_position`);
    - the estimated height of one line of title text;
    - a fixed padding of 0.15 figure heights.

    The text height uses the font size of the first axes' title when that
    title is non-empty, otherwise 12 pt. Font sizes are in typographic
    points (1/72 inch), so they are converted to inches by dividing by 72.
    The figure DPI plays no part in that conversion.

    Returns
    -------
    float
        The y position for the suptitle in figure coordinates, or 0.95
        when the figure has no subplots or no axes.
    """
    default_y = 0.95
    points_per_inch = 72.0

    if not self.subplots:
        return default_y
    axes = self.fig.axes
    if not axes:
        # Without axes there is nothing to sit above (max() would fail).
        return default_y

    # Find the highest subplot position
    max_ax_top = max(ax.get_position().y1 for ax in axes)
    fig_height = self.fig.get_size_inches()[1]

    # Use the first axes' title font size if it has a title, else 12 pt.
    title_fontsize = 12
    first_ax = axes[0]
    if first_ax.get_title():
        title_fontsize = first_ax.title.get_fontsize()

    # points -> inches -> fraction of figure height
    title_height_fig = (title_fontsize / points_per_inch) / fig_height

    # Padding above the title, as a fraction of figure height.
    title_padding = 0.15

    return max_ax_top + title_height_fig + title_padding
def format_string(self, string, unique=True, grouped=True):
    """
    Fill the format keys of *string* from the figure's subplots.

    Parameters
    ----------
    string : str
        Template containing ``{}``-delimited metadata keys.
    unique : bool, optional
        If True, a key whose value is the same across all subplots/layers
        yields that value only once. For example, if every layer has the
        same ``variable_name``, it appears once.
    grouped : bool, optional
        If True, a single string represents all data layers, with each key
        evaluating to a list where layers differ. For example,
        ``"{variable} at {time}"`` might become
        ``"temperature and wind at 2023-01-01 00:00"``. If False, the
        template is formatted separately for each subplot and the results
        are joined into a human-readable list. For example,
        ``"temperature at 2023-01-01 00:00 and wind at 2023-01-01 00:00"``.

    Returns
    -------
    str
        The formatted string.
    """
    if grouped:
        return formatters.FigureFormatter(self.subplots, unique=unique).format(string)
    per_subplot = [sp.format_string(string, unique, grouped) for sp in self.subplots]
    return string_utils.list_to_human(per_subplot)
@property
def _default_title_template(self):
    """Title template used when no label is given; taken from the first subplot."""
    return self.subplots[0]._default_title_template


def _release_queue(self):
    """
    Replay all deferred work once, then add attributions and logos.

    Steps, in order:

    1. If this has already run (``self._released``), return immediately.
    2. If subplot requests were queued, size the grid with ``rows_cols``
       from the number of requests, call ``_setup()``, and replay each
       queued ``(method, args, kwargs)`` as ``method(self, *args, **kwargs)``.
    3. Replay the general deferred queue ``self._queue`` the same way.
    4. Draw ``self.attributions``. Attributions are grouped by location and
       joined with ``"; "``. Each location is styled from the first
       attribution registered there (default: 9 pt, gray, wrapped).
    5. Draw ``self.logos`` in a row at the bottom right, right-most first.

    Both queues are cleared after replay.

    Returns
    -------
    Figure
        ``self``.
    """
    if self._released:
        return self
    self._released = True

    if self._subplot_queue:
        self.rows, self.columns = rows_cols(len(self._subplot_queue), rows=self.rows, columns=self.columns)
        self._setup()
    for method, args, kwargs in self._subplot_queue:
        method(self, *args, **kwargs)
    self._subplot_queue.clear()

    for queued_method, queued_args, queued_kwargs in self._queue:
        queued_method(self, *queued_args, **queued_kwargs)
    self._queue.clear()

    if self.attributions:
        # (x, y, ha, va) anchor in figure coordinates for each legend-style location.
        _location_coords = {
            "upper left": (0.0, 1.0, "left", "bottom"),
            "upper center": (0.5, 1.0, "center", "bottom"),
            "upper right": (1.0, 1.0, "right", "bottom"),
            "center left": (0.0, 0.5, "left", "center"),
            "center": (0.5, 0.5, "center", "center"),
            "center right": (1.0, 0.5, "right", "center"),
            "lower left": (0.0, -0.02, "left", "top"),
            "lower center": (0.5, -0.02, "center", "top"),
            "lower right": (1.0, -0.02, "right", "top"),
        }
        # Group attributions by location
        from collections import defaultdict

        groups = defaultdict(list)
        group_kwargs = {}
        for text, loc, kw in self.attributions:
            text = self.format_string(text)
            groups[loc].append(text)
            # The first attribution registered at a location sets its styling.
            group_kwargs.setdefault(loc, kw)

        for loc, texts in groups.items():
            combined = "; ".join(texts)
            # Unknown locations fall back to the "lower center" anchor.
            x, y, ha, va = _location_coords.get(loc, (0.5, -0.02, "center", "top"))
            text_kwargs = dict(
                ha=ha,
                va=va,
                fontsize=9,
                color="gray",
                wrap=True,
            )
            text_kwargs.update(group_kwargs[loc])
            self.fig.text(x, y, combined, **text_kwargs)

    if self.logos:
        # Imported once here rather than on every loop iteration.
        import matplotlib.image as mpimg

        # Sizes are fractions of the figure width/height.
        logo_width = 0.12
        logo_height = 0.05
        spacing = 0.01  # horizontal gap between logos
        right_margin = 0.05
        # Negative: the logo row sits just below the figure's bottom edge
        # (presumably relies on bbox_inches="tight" when saving to include it).
        bottom = -0.05
        # Start from the right and move left.
        for i, image_file in enumerate(reversed(self.logos)):
            if not os.path.exists(image_file):
                image_file = find_logo(image_file)
            logo = mpimg.imread(image_file)
            left = 1.0 - (i + 1) * logo_width - i * spacing - right_margin
            ax_logo = self.fig.add_axes([left, bottom, logo_width, logo_height], zorder=100)
            ax_logo.imshow(logo)
            ax_logo.axis("off")

    return self


def _exit_style_context(self):
    """Exit the style context if one is active, restoring global rcParams."""
    if self._style_context is not None:
        self._style_context.__exit__(None, None, None)
        self._style_context = None


def _apply_subplot_pre_render(self):
    """Apply any pre-render hooks on subplots (e.g. tight time axis)."""
    from earthkit.plots.temporal.timeseries import TimeSeries

    for subplot in self.subplots:
        if isinstance(subplot, TimeSeries):
            subplot._apply_tight_time_axis()
def show(self, *args, **kwargs):
    """
    Display the figure with :func:`matplotlib.pyplot.show`.

    Before showing, the pending Jupyter auto-display is cancelled and
    deferred work is flushed via ``_prepare_for_display``. The active style
    context is exited afterwards, even if showing raises.

    Returns
    -------
    Figure or None
        ``self`` if the Figure is chainable, otherwise ``None``.
    """
    self._cancel_jupyter_display()
    self._prepare_for_display()
    try:
        plt.show(*args, **kwargs)
    finally:
        # Always restore global rcParams, even when rendering fails.
        self._exit_style_context()
    if self._chainable:
        return self
    return None
def save(self, *args, bbox_inches="tight", **kwargs):
    """
    Save the figure to a file.

    This figure's own Matplotlib figure (``self.fig``) is saved. Previously
    :func:`matplotlib.pyplot.savefig` was used, which saves pyplot's
    *current* figure. That can be a different figure if another one was
    created later, or a blank one after ``plt.show()`` has closed figures.
    If ``self.fig`` has not been created, pyplot's current figure is used,
    as before.

    Parameters
    ----------
    fname : str or file-like object
        The file to which to save the figure.
    bbox_inches : str, optional
        The bounding box in inches to use when saving the figure.
        Default ``"tight"``.
    kwargs : dict, optional
        Additional keyword arguments to pass to
        :meth:`matplotlib.figure.Figure.savefig`. If ``dpi`` is not given,
        ``rcParams["figure.dpi"]`` (read after the figure's style has been
        applied) is used.

    Returns
    -------
    Figure or None
        ``self`` if the Figure is chainable, otherwise ``None``.
    """
    self._cancel_jupyter_display()
    self._prepare_for_display()
    try:
        if "dpi" in kwargs:
            dpi = kwargs.pop("dpi")
        else:
            from matplotlib import rcParams as _rc

            dpi = _rc["figure.dpi"]
        target = self.fig if self.fig is not None else plt.gcf()
        target.savefig(*args, bbox_inches=bbox_inches, dpi=dpi, **kwargs)
        # Mirror pyplot.savefig, which redraws after saving (e.g. to reset
        # colours changed by transparent=True).
        target.canvas.draw_idle()
    finally:
        self._exit_style_context()
    return self if self._chainable else None
def _prepare_for_display(self):
    """
    Get the figure ready to be shown or saved.

    Runs the subplot pre-render hooks first, then flushes the deferred
    queue. The queue flush is a no-op once it has run, so this is safe to
    call repeatedly.
    """
    self._apply_subplot_pre_render()
    self._release_queue()


def _resize(self):
    """
    Flush deferred work, then shrink the figure to its axes.

    Returns
    -------
    matplotlib.figure.Figure
        The result of ``resize_figure_to_fit_axes(self.fig)``.
    """
    self._release_queue()
    resized = resize_figure_to_fit_axes(self.fig)
    return resized
def attribution(self, attribution, location="lower center", **kwargs):
    """
    Register an attribution text to be drawn on the figure.

    Each ``(text, location, kwargs)`` combination is stored only once;
    registering an identical entry again has no effect. The text is drawn
    later, when deferred work is flushed.

    Parameters
    ----------
    attribution : str
        The attribution text.
    location : str, optional
        Where to place the text, using Matplotlib legend location names:
        'upper left', 'upper center', 'upper right', 'center left',
        'center', 'center right', 'lower left', 'lower center',
        'lower right'. Default 'lower center'.
    **kwargs
        Keyword arguments for ``matplotlib.figure.Figure.text``.

    Returns
    -------
    Figure or None
        ``self`` if the Figure is chainable, otherwise ``None``.
    """
    candidate = (attribution, location, kwargs)
    already_registered = candidate in self.attributions
    if not already_registered:
        self.attributions.append(candidate)
    if self._chainable:
        return self
    return None
def resize_figure_to_fit_axes(fig):
    """
    Adjust the size of a Matplotlib figure so that it fits its axes.

    The union of all axes positions (in figure coordinates) is computed, and
    the figure is resized to exactly that extent. Subplot margins are then
    set to zero with ``subplots_adjust``.

    A figure without axes is returned unchanged. Computing an extent from
    no axes would give a negative size, which Matplotlib rejects.

    Parameters
    ----------
    fig : :class:`matplotlib.figure.Figure`
        A Matplotlib Figure object to resize.

    Returns
    -------
    :class:`matplotlib.figure.Figure`
        The same figure object, resized in place.
    """
    boxes = [ax.get_position() for ax in fig.axes]
    if not boxes:
        return fig

    width_in, height_in = fig.get_size_inches()

    # Outer bounds of all axes, in figure-fraction coordinates.
    left = min(box.x0 for box in boxes)
    right = max(box.x1 for box in boxes)
    bottom = min(box.y0 for box in boxes)
    top = max(box.y1 for box in boxes)

    fig.set_size_inches((right - left) * width_in, (top - bottom) * height_in)
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    return fig