Source code for earthkit.plots.styles

# Copyright 2024-, European Centre for Medium Range Weather Forecasts.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from earthkit.plots import metadata, plottypes, styles
from earthkit.plots.schemas import schema
from earthkit.plots.styles import auto, colors, legends, levels
from earthkit.plots.styles.colors import magics_colors_to_rgb

__all__ = [
    "colors",
    "legends",
    "levels",
    "auto",
    "Style",
    "DEFAULT_STYLE",
    "_STYLE_KWARGS",
    "_OVERRIDE_KWARGS",
    "load_style",
    "list_styles",
]


def linspace_datetime64(start_date, end_date, n):
    """
    Generate a linearly spaced array of datetime64 objects.

    Parameters
    ----------
    start_date : numpy.datetime64
        The starting date.
    end_date : numpy.datetime64
        The ending date.
    n : int
        The number of dates to generate.
    """
    return np.linspace(0, 1, n) * (end_date - start_date) + start_date


def spline_interpolate(x, *ys, n=None):
    """
    Fit a cubic spline to one or more y arrays against x and return a dense
    smooth (x_smooth, y_smooth, ...) tuple.

    Handles datetime64 x values and deduplicates x before fitting (necessary
    when wrapped climatology data contains repeated timestamps).

    Parameters
    ----------
    x : array-like
        Independent variable. May be datetime64 or numeric.
    *ys : array-like
        One or more dependent variable arrays, each the same length as x.
    n : int, optional
        Number of points in the smooth output. Defaults to max(300, len(x)*5).
    """
    from scipy.interpolate import make_interp_spline

    x = np.asarray(x)
    n = n or max(300, len(x) * 5)

    if np.issubdtype(x.dtype, np.datetime64):
        x_smooth = linspace_datetime64(x.min(), x.max(), n)
    else:
        x_smooth = np.linspace(x.min(), x.max(), n)

    # Deduplicate — wrapped climatology data can have repeated timestamps
    _, unique_idx = np.unique(x.astype(np.float64), return_index=True)
    x_fit = x[unique_idx]

    k = min(3, len(x_fit) - 1)
    if k < 1:
        # Fewer than 2 unique points — can't interpolate, return as-is
        result = [x_fit]
        for y in ys:
            result.append(np.asarray(y)[unique_idx])
        return tuple(result)

    result = [x_smooth]
    for y in ys:
        y_fit = np.asarray(y)[unique_idx]
        result.append(make_interp_spline(x_fit, y_fit, k=k)(x_smooth))
    return tuple(result)


def _validate_projection_for_tricontour(ccrs) -> bool:
    """
    Validate that the projection is suitable for tricontour plotting.

    Banned list established from iterative search.
    """
    if ccrs is None:
        return True

    bad_list = [
        "TransverseMercator",
        "Sinusoidal",
        "Robinson",
        "Hammer",
        "EqualEarth",
        "LambertAzimuthalEqualArea",
    ]
    if any(proj in str(ccrs.__class__.__name__) for proj in bad_list):
        raise ValueError(
            f"Projection {ccrs} is not suitable for tricontour plotting. Please use a different projection."
        )
    return True


[docs] class Style: """ A style for plotting data. Parameters ---------- colors : str or list or matplotlib.colors.Colormap, optional The colors to be used in this `Style`. This can be a `named matplotlib colormap <https://matplotlib.org/stable/gallery/color/colormap_reference.html>`__, a list of colors (as named CSS4 colors, hexadecimal colors or three (four)-element lists of RGB(A) values), or a pre-defined matplotlib colormap object. If not provided, the default colormap of the active `schema` will be used. levels : list or earthkit.maps.styles.levels.Levels, optional The levels to use in this `Style`. This can be a list of specific levels, or an earthkit `Levels` object. If not provided, some suitable levels will be generated automatically (experimental!). gradients : list, optional The number of colors to insert between each level in `levels`. If None, one color level will be inserted between each level. normalize : bool, optional If `True` (default), then the colors will be normalized over the level range. units : str, optional The units in which the levels are defined. If this `Style` is used with data not using the given units, then a conversion will be attempted; any data incompatible with these units will not be able to use this `Style`. If `units` are not provided, then data plotted using this `Style` will remain in their original units. See :doc:`/examples/examples/introduction/08-unit-conversion` for examples. units_label : str, optional The label to use in titles and legends to represent the units of the data. legend_style : str, optional The style of legend to use by default with this style. Must be one of `colorbar` (default), `disjoint`, `histogram`, or `None` (no legend). bin_labels : list, optional A list of categorical labels for each bin in the legend. """
[docs] @classmethod def from_dict(cls, kwargs): """Create a `Style` from a dictionary.""" style_type = kwargs.pop("type") name = kwargs.pop("name", None) if "levels" in kwargs: kwargs["levels"] = levels.Levels.from_config(kwargs["levels"]) instance = getattr(styles, style_type)(**kwargs) instance._name = name return instance
def __init__( self, colors=schema.default_cmap, levels=None, gradients=None, normalize=True, units=None, scale_factor=None, units_label=None, legend_style="colorbar", legend_kwargs=None, categories=None, ticks=None, preferred_method="contourf", resample=None, vmin=None, vmax=None, **kwargs, ): # Handle cmap as an alias for colors if "cmap" in kwargs: if colors != schema.default_cmap: raise ValueError("Cannot specify both 'colors' and 'cmap'. They are aliases for the same parameter.") colors = kwargs.pop("cmap") if categories is not None and levels is None: levels = range(len(categories) + 1) self._colors = colors if isinstance(self._colors, (list, tuple)) and schema.color_mode == "magics": self._colors = magics_colors_to_rgb(self._colors) if isinstance(levels, dict): levels = styles.levels.Levels(**levels) self._levels = ( levels if isinstance(levels, styles.levels.Levels) else styles.levels.Levels( levels, categorical=categories is not None or self.__class__.__name__ == "Categorical", ) ) self.normalize = normalize self.gradients = gradients self.resample = resample self._units = units self._units_label = units_label self.scale_factor = scale_factor self._legend_style = legend_style if self._legend_style == "None": self._legend_style = None self._bin_labels = categories self._legend_kwargs = legend_kwargs or dict() if ticks is not None: self._legend_kwargs["ticks"] = ticks # Strip internal earthkit-plots keys that are not valid matplotlib kwargs. kwargs.pop("regrid", None) self._kwargs = kwargs self._preferred_method = preferred_method self._vmin = vmin self._vmax = vmax self._name = None # TODO # def to_yaml(self): # pass # TODO # def to_magics_style(self): # pass def __eq__(self, other): keys = ["_name", "_levels", "_colors"] return compare_attributes(self, other, keys) def _get_config(self): """Return a dict of constructor kwargs representing the current state.""" import copy levels_config = None if hasattr(self._levels, "_levels"): if hasattr(self._levels, "_step") and self._levels._step is not None: levels_config = {"step": self._levels._step} if hasattr(self._levels, "_reference") and self._levels._reference is not None: levels_config["reference"] = self._levels._reference if hasattr(self._levels, "_divergence_point") and self._levels._divergence_point is not None: levels_config["divergence_point"] = self._levels._divergence_point else: levels_config = self._levels._levels config = { "colors": self._colors, "levels": levels_config, "gradients": self.gradients, "normalize": self.normalize, "units": self._units, "scale_factor": self.scale_factor, "units_label": self._units_label, "legend_style": self._legend_style, "legend_kwargs": copy.deepcopy(self._legend_kwargs) if self._legend_kwargs else None, "categories": self._bin_labels, "preferred_method": self._preferred_method, "resample": self.resample, "vmin": self._vmin, "vmax": self._vmax, } config.update(copy.deepcopy(self._kwargs)) return config
[docs] def with_overrides(self, **overrides): """ Create a copy of this style with some parameters overridden. This method creates a new Style instance with the same configuration as the current one, but with specific parameters overridden. The original style is not modified. Parameters ---------- **overrides : dict Keyword arguments to override in the new style. Common parameters include: - colors (or cmap): color scheme or colormap - levels: contour levels - gradients: gradient steps between levels - normalize: whether to normalize colors - units: data units - legend_style: type of legend - categories: categorical labels - ticks: tick locations Returns ------- Style A new Style instance with overridden parameters. Examples -------- >>> style = Style(colors="viridis", levels=[0, 10, 20, 30]) >>> new_style = style.with_overrides(levels=[0, 5, 10, 15, 20]) >>> # original style is unchanged """ if "cmap" in overrides and "colors" in overrides: raise ValueError("Cannot specify both 'cmap' and 'colors'. They are aliases for the same parameter.") if "cmap" in overrides: overrides["colors"] = overrides.pop("cmap") config = self._get_config() config.update(overrides) return self.__class__(**config)
[docs] def levels(self, data=None): """ Generate levels specific to some data. Parameters ---------- data : numpy.ndarray or xarray.DataArray or earthkit.data.core.Base The data for which to generate a list of levels. Returns ------- list """ if data is None: if self._levels._levels is not None: return self._levels._levels else: raise ValueError("this style uses dynamic levels; include the `data` argument to generate levels") return self._levels.apply(data)
@property def extend(self): """Convenience access to 'extend' kwarg.""" return self._kwargs.get("extend") @property def units(self): """Formatted units for use in figure text.""" if self._units_label is not None: return self._units_label elif self._units is not None: return self._units
[docs] def apply_scale_factor(self, values): """Apply the scale factor to some values.""" if self.scale_factor is not None: values *= self.scale_factor return values
[docs] def convert_units(self, values, source_units, short_name=""): """ Convert some values from their source units to this `Style`'s units. Parameters ---------- values : numpy.ndarray The values to convert from their source units to this `Style`'s units. source_units : str The source units of the given values. short_name : str, optional The short name of the variable, which is used to make extra assumptions about the data's unit covnersion (for example, temperature anomalies need special consideration when converting between Celsius and Kelvin). """ if self._units is None or source_units is None: return values # For temperature anomalies we do not want to convert values, just # change the units string if "anomaly" in short_name.lower() and metadata.units.anomaly_equivalence(source_units): return values return metadata.units.convert(values, source_units, self._units)
[docs] def to_matplotlib_kwargs(self, data, extend_levels=True): """ Generate matplotlib arguments required for plotting data in this `Style`. Parameters ---------- data : numpy.ndarray The data to be plotted using this `Style`. """ # vmin/vmax path: produce a continuous Normalize instead of BoundaryNorm. # Only used when vmin or vmax is explicitly set AND no levels are configured # (explicit levels always take precedence over vmin/vmax). has_levels = self._levels._levels is not None or self._levels._step is not None if (self._vmin is not None or self._vmax is not None) and not has_levels: vmin = self._vmin if self._vmin is not None else (float(np.nanmin(data)) if data is not None else None) vmax = self._vmax if self._vmax is not None else (float(np.nanmax(data)) if data is not None else None) # Resolve colormap directly — avoid cmap_and_norm which requires # discrete levels to build its ListedColormap. colors_spec = self._colors if isinstance(colors_spec, mpl.colors.Colormap): cmap = colors_spec elif isinstance(colors_spec, str): cmap = mpl.colormaps[colors_spec] else: # List of colours: build a continuous LinearSegmentedColormap. cmap = mpl.colors.LinearSegmentedColormap.from_list("", colors_spec) cmap = cmap.with_extremes(bad=(0, 0, 0, 0)) norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) return {**{"cmap": cmap, "norm": norm}, **self._kwargs} levels = self.levels(data) if self.gradients is not None: self._legend_kwargs.setdefault("ticks", None) # Let matplotlib auto-generate ticks return colors.gradients( levels, self._colors, self.gradients, self.normalize, **self._kwargs, ) cmap, norm = styles.colors.cmap_and_norm( self._colors, levels, self.normalize, self.extend, extend_levels=extend_levels, ) cmap = cmap.with_extremes(bad=(0, 0, 0, 0)) return { **{"cmap": cmap, "norm": norm, "levels": levels}, **self._kwargs, }
@staticmethod def _xy_for_contour(x, y): if x.ndim == 1 and y.ndim == 1: x, y = np.meshgrid(x, y) return x, y @staticmethod def _xy_for_scatter(x, y): if x.ndim == 2 and y.ndim == 2 and x.shape == y.shape: x = x.flatten() y = y.flatten() return x, y
[docs] def to_contourf_kwargs(self, data): """ Generate `contourf` arguments required for plotting data in this `Style`. Parameters ---------- data : numpy.ndarray The data to be plotted using this `Style`. """ kwargs = self.to_matplotlib_kwargs(data) # hatches is only valid for contourf when used by Hatched subclass; # popped here so it doesn't leak into plain contourf calls from base Style. kwargs.pop("hatches", None) return kwargs
[docs] def to_contour_kwargs(self, data): """ Generate `contour` arguments required for plotting data in this `Style`. Parameters ---------- data : numpy.ndarray The data to be plotted using this `Style`. """ return self.to_matplotlib_kwargs(data)
[docs] def to_pcolormesh_kwargs(self, data): """ Generate `pcolormesh` arguments required for plotting data in this `Style`. Parameters ---------- data : numpy.ndarray The data to be plotted using this `Style`. """ kwargs = self.to_matplotlib_kwargs(data, extend_levels=False) kwargs.pop("levels", None) kwargs.pop("transform_first", None) kwargs.pop("extend", None) kwargs.pop("labels", None) kwargs.pop("hatches", None) return kwargs
[docs] def to_add_geometries_kwargs(self, data): """ Generate `add_geometries` arguments required for plotting data in this `Style`. Parameters ---------- data : numpy.ndarray The data to be plotted using this `Style`. """ kwargs = self.to_matplotlib_kwargs(data, extend_levels=False) kwargs.pop("levels", None) kwargs.pop("extend", None) kwargs.pop("labels", None) kwargs.pop("linecolors", None) kwargs.pop("hatches", None) masked = np.ma.masked_invalid(data) kwargs["facecolor"] = kwargs.get("cmap")(kwargs.get("norm")(masked)) return kwargs
[docs] def to_scatter_kwargs(self, data): """ Generate `scatter` arguments required for plotting data in this `Style`. Parameters ---------- data : numpy.ndarray The data to be plotted using this `Style`. """ kwargs = self.to_matplotlib_kwargs(data, extend_levels=False) kwargs.pop("levels", None) return kwargs
[docs] def to_quiver_kwargs(self, data): """ Generate `quiver` arguments required for plotting data in this `Style`. Parameters ---------- data : numpy.ndarray The data to be plotted using this `Style`. """ kwargs = self.to_matplotlib_kwargs(data) kwargs.pop("levels", None) return kwargs
[docs] def plot(self, *args, **kwargs): """Plot the data using the `Style`'s defaults.""" return self.pcolormesh(*args, **kwargs)
[docs] def contourf(self, ax, x, y, values, *args, **kwargs): """ Plot shaded contours using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The x coordinates of the data to be plotted. y : numpy.ndarray The y coordinates of the data to be plotted. values : numpy.ndarray The values of the data to be plotted. **kwargs Any additional arguments accepted by `matplotlib.axes.Axes.contourf`. """ if kwargs.get("interpolate") is None and values.ndim == 1: return self.tricontourf(ax, x, y, values, *args, **kwargs) kwargs = {**self.to_contourf_kwargs(values), **kwargs} x, y = self._xy_for_contour(x, y) return ax.contourf(x, y, values, *args, **kwargs)
[docs] def quiver(self, ax, x, y, u, v, *args, **kwargs): """ Plot quiver arrows using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The x coordinates of the data to be plotted. y : numpy.ndarray The y coordinates of the data to be plotted. u : numpy.ndarray The u-component of the data to be plotted. v : numpy.ndarray The v-component of the data to be plotted. **kwargs Any additional arguments accepted by `matplotlib.axes.Axes.quiver`. """ if self.extend is not None or kwargs.get("extend") is not None: raise ValueError("'extend' is not supported for quiver styles. Remove 'extend' from your Style definition.") cmap = None norm = None magnitude = None kwargs = {**self._kwargs, **kwargs} if self._colors: magnitude = np.sqrt(u**2 + v**2) kwargs = {**self.to_quiver_kwargs(magnitude), **kwargs} cmap = kwargs.pop("cmap", None) norm = kwargs.pop("norm", None) if cmap and norm and magnitude is not None: mappable = ax.quiver(x, y, u, v, magnitude.ravel(), *args, cmap=cmap, norm=norm, **kwargs) else: mappable = ax.quiver(x, y, u, v, *args, **kwargs) # Mark as non-coloured so colorbar is skipped mappable.cmap = None mappable.norm = None return mappable
def barbs(self, ax, x, y, u, v, *args, **kwargs): if self.extend is not None or kwargs.get("extend") is not None: raise ValueError("'extend' is not supported for barbs styles. Remove 'extend' from your Style definition.") cmap = None norm = None magnitude = None kwargs = {**self._kwargs, **kwargs} if self._colors: magnitude = np.sqrt(u**2 + v**2) kwargs = {**self.to_quiver_kwargs(magnitude), **kwargs} cmap = kwargs.pop("cmap", None) norm = kwargs.pop("norm", None) if cmap and norm and magnitude is not None: mappable = ax.barbs(x, y, u, v, magnitude.ravel(), *args, cmap=cmap, norm=norm, **kwargs) else: mappable = ax.barbs(x, y, u, v, *args, **kwargs) mappable.cmap = None mappable.norm = None return mappable
[docs] def streamplot(self, ax, x, y, u, v, *args, **kwargs): """ Plot streamlines using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The x coordinates of the data to be plotted. y : numpy.ndarray The y coordinates of the data to be plotted. u : numpy.ndarray The u-component of the data to be plotted. v : numpy.ndarray The v-component of the data to be plotted. **kwargs Any additional arguments accepted by `matplotlib.axes.Axes.streamplot`. """ return ax.streamplot(x, y, u, v, *args, **kwargs)
[docs] def tricontour(self, ax, x, y, values, *args, **kwargs): """ Plot triangulated contour lines using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The x coordinates of the data to be plotted. y : numpy.ndarray The y coordinates of the data to be plotted. values : numpy.ndarray The values of the data to be plotted. **kwargs Any additional arguments accepted by `matplotlib.axes.Axes.tricontour`. """ kwargs = {**self.to_contour_kwargs(values), **kwargs} kwargs.pop("labels", None) _validate_projection_for_tricontour(kwargs.get("transform", None)) return ax.tricontour(x, y, values, *args, **kwargs)
[docs] def tricontourf(self, ax, x, y, values, *args, **kwargs): """ Plot triangulated shaded contours using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The x coordinates of the data to be plotted. y : numpy.ndarray The y coordinates of the data to be plotted. values : numpy.ndarray The values of the data to be plotted. **kwargs Any additional arguments accepted by `matplotlib.axes.Axes.tricontourf`. """ kwargs = {**self.to_contourf_kwargs(values), **kwargs} _validate_projection_for_tricontour(kwargs.get("transform", None)) return ax.tricontourf(x, y, values, *args, **kwargs)
[docs] def tripcolor(self, ax, x, y, values, *args, **kwargs): """ Plot triangulated shaded contours using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The x coordinates of the data to be plotted. y : numpy.ndarray The y coordinates of the data to be plotted. values : numpy.ndarray The values of the data to be plotted. **kwargs Any additional arguments accepted by `matplotlib.axes.Axes.tricontourf`. """ kwargs = {**self.to_pcolormesh_kwargs(values), **kwargs} return ax.tripcolor(x, y, values, *args, **kwargs)
[docs] def contour(self, ax, x, y, values, *args, **kwargs): """ Plot line contours using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The x coordinates of the data to be plotted. y : numpy.ndarray The y coordinates of the data to be plotted. values : numpy.ndarray The values of the data to be plotted. **kwargs Any additional arguments accepted by `matplotlib.axes.Axes.contour`. """ if kwargs.get("interpolate") is None and values.ndim == 1: return self.tricontour(ax, x, y, values, *args, **kwargs) kwargs = {**self.to_contour_kwargs(values), **kwargs} kwargs.pop("labels", None) x, y = self._xy_for_contour(x, y) return ax.contour(x, y, values, *args, **kwargs)
[docs] def pcolormesh(self, ax, x, y, values, *args, **kwargs): """ Plot a pcolormesh using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The x coordinates of the data to be plotted. y : numpy.ndarray The y coordinates of the data to be plotted. values : numpy.ndarray The values of the data to be plotted. **kwargs Any additional arguments accepted by `matplotlib.axes.Axes.pcolormesh`. """ kwargs.pop("transform_first", None) kwargs = {**self.to_pcolormesh_kwargs(values), **kwargs} result = ax.pcolormesh(x, y, values, *args, **kwargs) return result
[docs] def imshow(self, ax, _x, _y, values, *args, **kwargs): """ Plot an image using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The x coordinates (unused directly; pass ``extent=`` if needed). y : numpy.ndarray The y coordinates (unused directly; pass ``extent=`` if needed). values : numpy.ndarray The 2-D image array to be plotted. **kwargs Any additional arguments accepted by `matplotlib.axes.Axes.imshow`. """ kwargs.pop("transform_first", None) kwargs = {**self.to_pcolormesh_kwargs(values), **kwargs} result = ax.imshow(values, *args, **kwargs) return result
[docs] def scatter(self, ax, x, y, values, s=3, *args, **kwargs): """ Plot a scatter plot using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The x coordinates of the data to be plotted. y : numpy.ndarray The y coordinates of the data to be plotted. values : numpy.ndarray The values of the data to be plotted. **kwargs Any additional arguments accepted by `matplotlib.axes.Axes.scatter`. """ kwargs.pop("transform_first", None) missing_values = kwargs.pop("missing_values", None) original_kwargs = kwargs.copy() if values is not None: kwargs = {**self.to_scatter_kwargs(values), **kwargs} kwargs.pop("extend", None) if values is not None and missing_values is not None and np.isnan(values).any(): nan_mask = np.isnan(values.ravel()) missing_x = x.ravel()[nan_mask] missing_y = y.ravel()[nan_mask] missing_s = s.ravel()[nan_mask] if not np.isscalar(s) else s x = x.ravel()[~nan_mask] y = y.ravel()[~nan_mask] values = values.ravel()[~nan_mask] s = s.ravel()[~nan_mask] if not np.isscalar(s) else s if missing_values: ax.scatter( missing_x, missing_y, s=missing_s, *args, **{**original_kwargs, **missing_values}, ) if values is not None: kwargs["c"] = kwargs.pop("c", values) if isinstance(kwargs.get("c"), str): kwargs.pop("cmap", None) kwargs.pop("norm", None) x, y = self._xy_for_scatter(x, y) return ax.scatter(x, y, s=s, *args, **kwargs)
[docs] def quantiles( self, ax, x, y, values, *args, type="band", quantiles=[0, 0.25, 0.5, 0.75, 1], **kwargs, ): """ Compute and plot quantiles using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The coordinates of the data to be plotted. values : numpy.ndarray The data of which to compute the statistics. The computation is applied along axis 0, so the size along axis 1 must match the size of the coordinates. type : "box" | "band" The type of plot used to represent the quantile ranges. quantiles : array_like Probabilities of the quantiles to compute. Values must be between 0 and 1 inclusive. """ quantiles = np.sort(quantiles) stats = np.quantile(values, quantiles, axis=0) if type == "box": mappable = self.boxplot(ax, x, stats, *args, **kwargs) elif type == "band": mappable = self.bandplot(ax, x, stats, *args, **kwargs) else: raise NotImplementedError(f"Plot of type {type} not yet implemented.") return mappable
def to_quantiles_kwargs(self, n, c=None): if c is None: c = self._colors if isinstance(c, str): # Generate symmetric colors if c in mpl.colormaps: c = mpl.colormaps[c](np.abs(np.linspace(-1.0, 1.0, n))) else: c = colors.symmetric_from_color(c, n) return {"colors": c, **self._kwargs} def bandplot(self, ax, x, values, colors=None, *args, **kwargs): num_bands = len(values) - 1 kwargs = {**self.to_quantiles_kwargs(num_bands, c=colors), **kwargs} return plottypes.bandplot(ax, x, values, *args, **kwargs) def boxplot(self, ax, x, values, colors=None, *args, **kwargs): num_bands = len(values) - 1 kwargs = {**self.to_quantiles_kwargs(num_bands, c=colors), **kwargs} return plottypes.boxplot(ax, x, values, *args, **kwargs)
[docs] def line(self, ax, x, y, values, *args, **kwargs): """ Plot a scatter plot using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The x coordinates of the data to be plotted. y : numpy.ndarray The y coordinates of the data to be plotted. values : numpy.ndarray The values of the data to be plotted. **kwargs Any additional arguments accepted by `matplotlib.axes.Axes.scatter`. """ kwargs.pop("transform_first", None) if values is not None: kwargs = {**self.to_scatter_kwargs(values), **kwargs} kwargs.pop("extend", None) kwargs["c"] = kwargs.pop("c", values) mode = kwargs.pop("drawstyle", "linear") if mode not in ("spline", "smooth"): # Real matplotlib drawstyle — put it back for ax.plot if mode != "linear": kwargs["drawstyle"] = mode mode = "linear" if mode == "spline": x_smooth, y_smooth = spline_interpolate(x, y) marker = kwargs.pop("marker", None) mappable = ax.plot(x_smooth, y_smooth, *args, **kwargs) if marker is not None: kwargs.pop("linewidth", None) color = mappable[0].get_color() self.line( ax, x, y, values, *args, marker=marker, color=color, linewidth=0, **kwargs, ) elif mode == "smooth": if np.issubdtype(x.dtype, np.datetime64): x_smooth = linspace_datetime64(x.min(), x.max(), max(300, len(x) * 5)) else: x_smooth = np.linspace(x.min(), x.max(), max(300, len(x) * 5)) from scipy.interpolate import interp1d func = interp1d( x, y, axis=0, # interpolate along columns bounds_error=False, kind="linear", fill_value=(y[0], y[-1]), ) y_smooth = func(x_smooth) marker = kwargs.pop("marker", None) mappable = ax.plot(x_smooth, y_smooth, *args, **kwargs) if marker is not None: kwargs.pop("linewidth", None) color = mappable[0].get_color() self.line( ax, x, y, values, *args, marker=marker, color=color, linewidth=0, **kwargs, ) else: mappable = ax.plot(x, y, *args, **kwargs) return mappable
[docs] def bar(self, ax, x, y, values, *args, **kwargs): """ Plot a scatter plot using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The x coordinates of the data to be plotted. y : numpy.ndarray The y coordinates of the data to be plotted. values : numpy.ndarray The values of the data to be plotted. **kwargs Any additional arguments accepted by `matplotlib.axes.Axes.scatter`. """ kwargs.pop("transform_first", None) if values is not None: kwargs = {**self.to_scatter_kwargs(values), **kwargs} kwargs.pop("extend", None) kwargs["c"] = kwargs.pop("c", values) mappable = ax.bar(x, y, *args, **kwargs) return mappable
[docs] def stripes(self, ax, x, y, values, *args, **kwargs): """ Plot climate stripes: one vertical colored bar per data point. Each bar spans between the midpoints of adjacent x positions, and is colored according to the y value mapped through this Style's colormap and normalisation. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot. x : numpy.ndarray The x coordinates (e.g. time values as matplotlib date floats or numeric values). y : numpy.ndarray The data values used to determine each bar's color. values : numpy.ndarray Unused (kept for API consistency with other Style methods). **kwargs Forwarded to :func:`matplotlib.axes.Axes.broken_barh`. Common useful kwargs: ``ymin``, ``ymax``. """ import matplotlib.dates as mdates from matplotlib.transforms import blended_transform_factory kwargs.pop("transform_first", None) # Convert datetime64 x-values to matplotlib float dates for arithmetic x = np.asarray(x) if np.issubdtype(x.dtype, np.datetime64): x_num = mdates.date2num(x.astype("datetime64[ms]").astype("object")) else: x_num = x.astype(float) y = np.asarray(y, dtype=float) # Build bar edge positions at midpoints between adjacent x values if len(x_num) == 1: edges = np.array([x_num[0] - 0.5, x_num[0] + 0.5]) else: mids = (x_num[:-1] + x_num[1:]) / 2.0 edges = np.concatenate([ [x_num[0] - (mids[0] - x_num[0])], mids, [x_num[-1] + (x_num[-1] - mids[-1])], ]) # Resolve colormap + norm from the style. When neither explicit vmin/vmax # nor discrete levels have been configured, default to a symmetric # zero-centred range based on the absolute maximum of the data. has_explicit_range = ( self._vmin is not None or self._vmax is not None or self._levels._levels is not None or self._levels._step is not None ) if not has_explicit_range: absmax = float(np.nanmax(np.abs(y))) colors_spec = self._colors if isinstance(colors_spec, mpl.colors.Colormap): cmap = colors_spec elif isinstance(colors_spec, str): cmap = mpl.colormaps[colors_spec] else: cmap = mpl.colors.LinearSegmentedColormap.from_list("", colors_spec) cmap = cmap.with_extremes(bad=(0, 0, 0, 0)) norm = mpl.colors.Normalize(vmin=-absmax, vmax=absmax) else: mpl_kwargs = self.to_scatter_kwargs(y) cmap = mpl_kwargs["cmap"] norm = mpl_kwargs["norm"] ymin = kwargs.pop("ymin", 0) ymax = kwargs.pop("ymax", 1) kwargs.pop("transform", None) # Blended transform: x in data coordinates, y in axes fraction (0=bottom, 1=top) transform = blended_transform_factory(ax.transData, ax.transAxes) # Use ax.axvspan-style rendering: draw all bars via a PolyCollection for # efficiency, one Rectangle patch per stripe. from matplotlib.patches import Rectangle for i, val in enumerate(y): x0 = edges[i] width = edges[i + 1] - x0 rect = Rectangle( (x0, ymin), width, ymax - ymin, facecolor=cmap(norm(val)), linewidth=0, transform=transform, ) ax.add_patch(rect) # Set x limits in data coordinates to cover the full stripe range ax.set_xlim(edges[0], edges[-1]) # Return a ScalarMappable so colorbar() works sm = mpl.cm.ScalarMappable(cmap=cmap, norm=norm) sm.set_array(y) return sm
[docs] def values_to_colors(self, values, data=None): """ Convert a value or list of values to colors based on this `Style`. Parameters ---------- values : float or list of floats The values to convert to colors on this `Style`'s color scale. """ mpl_kwargs = self.to_matplotlib_kwargs(data=data) cmap = mpl_kwargs["cmap"] norm = mpl_kwargs["norm"] return cmap(norm(values))
[docs] def legend(self, *args, **kwargs): """ Create the default legend for this `Style`. Parameters ---------- *args : list Arguments to be passed to the legend method. **kwargs : dict Keyword arguments to be passed to the legend method. """ if self._legend_style is None: return try: method = getattr(self, self._legend_style) except AttributeError: raise AttributeError(f"invalid legend type '{self._legend_style}'") return method(*args, **kwargs)
[docs] def colorbar(self, *args, **kwargs): """Create a colorbar legend for this `Style`.""" # A colorbar requires a ScalarMappable. If the layer's mappable is a # list (e.g. Line2D objects from a line plot) skip silently. layer = args[0] if args else None if layer is not None: m = layer.mappable if not hasattr(m, "cmap"): return None # Uncoloured quiver/barbs have no array data assigned — norm.vmin stays None if getattr(getattr(m, "norm", None), "vmin", None) is None: return None ticks = self._legend_kwargs.get("ticks") if ticks is None and self._levels._levels is not None: if len(np.unique(np.ediff1d(self._levels._levels))) != 1: self._legend_kwargs["ticks"] = self._levels._levels if self.extend is not None: self._legend_kwargs.setdefault("extend", self.extend) return styles.legends.colorbar(*args, **kwargs)
[docs] def disjoint(self, *args, **kwargs): """Create a disjoint legend for this `Style`.""" return styles.legends.disjoint(*args, **kwargs)
[docs] def vector(self, *args, **kwargs): """Create a vector legend for this `Style`.""" return styles.legends.vector(*args, **kwargs)
[docs] def quiverkey(self, *args, **kwargs): """Create a quiverkey legend for this `Style`.""" return styles.legends.quiverkey(*args, **kwargs)
[docs] def save_legend(self, data=None, label=None, filename="legend.png", transparent=True, **kwargs): """ Save a standalone image of the legend associated with this `Style`. Parameters ---------- data : earthkit.data.core.Base, optional It can sometimes be useful to pass some data in order to automatically generate legend labels or color ranges, depending on the `Style`. label : str, optional The label to use for the legend. If not provided, the label will be generated automatically based on the `Style`'s units. filename : str The name of the file to save the legend to. The file format will be determined by the file extension (e.g., `.png`, `.pdf`, etc.). By default, the legend will be saved as a file named `legend.png`. transparent : bool, optional If `True`, the saved legend will have a transparent background. Otherwise, it will have a white background. Default is `True`. **kwargs Additional keyword arguments to be passed to the legend method. """ from earthkit.plots import Subplot if label is None and data is None: label = "{units}" if self.units is not None else "" plot_data = [[1, 2], [3, 4]] chart = Subplot() chart.contourf(plot_data, style=self) legend = chart.legend(label=label, **kwargs)[0] chart.fig.canvas.draw() bbox = legend.ax.get_window_extent().transformed(chart.fig.dpi_scale_trans.inverted()) title_bbox = legend.ax.xaxis.label.get_window_extent().transformed(chart.fig.dpi_scale_trans.inverted()) x, y = chart.fig.get_size_inches() xmod, ymod = (0.05, 0.01) if legend.orientation == "horizontal" else (0.01, 0.05) bbox.x0 = min(bbox.x0, title_bbox.x0) - x * xmod bbox.x1 = max(bbox.x1, title_bbox.x1) + x * xmod bbox.y0 = min(bbox.y0, title_bbox.y0) - y * ymod bbox.y1 = max(bbox.y1, title_bbox.y1) + y * ymod chart.ax.set_xlim(bbox.x0, bbox.x1) chart.ax.set_ylim(bbox.y0, bbox.y1) plt.savefig(filename, dpi="figure", bbox_inches=bbox, transparent=transparent)
class Categorical(Style): """A style for plotting categorical data.""" def __init__(self, *args, **kwargs): kwargs["legend_style"] = "disjoint" if isinstance(kwargs.get("levels"), dict): kwargs["levels"], kwargs["categories"] = zip(*kwargs["levels"].items()) if "categories" not in kwargs: kwargs["categories"] = kwargs.get("levels") super().__init__(*args, **kwargs) class Vector(Style): """A style for plotting vector data. Parameters ---------- colors : str or list or matplotlib.colors.Colormap, optional The colors to be used in this `Style`. This can be a named matplotlib colormap, a list of colors (as named CSS4 colors, hexadecimal colors or three (four)-element lists of RGB(A) values), or a pre-defined matplotlib colormap object. If not provided, the default colormap of the active `schema` will be used. **kwargs Additional keyword arguments to be passed to the vector methods. """ def __init__(self, *args, colors=None, **kwargs): kwargs.setdefault("legend_style", "vector") super().__init__(*args, colors=colors, **kwargs) class Quiver(Vector): def __init__(self, *args, colors=None, preferred_method="quiver", **kwargs): kwargs.setdefault("legend_style", "vector") super().__init__(*args, colors=colors, preferred_method=preferred_method, **kwargs) class Contour(Style): """ A style for plotting contour data. Parameters ---------- colors : str or list or matplotlib.colors.Colormap, optional The colors to be used in this `Style`. This can be a named matplotlib colormap, a list of colors (as named CSS4 colors, hexadecimal colors or three (four)-element lists of RGB(A) values), or a pre-defined matplotlib colormap object. If not provided, the default colormap of the active `schema` will be used. colors : str or list or matplotlib.colors.Colormap, optional The colors to be used for contour lines. This can be a named matplotlib colormap, a list of colors (as named CSS4 colors, hexadecimal colors or three (four)-element lists of RGB(A) values), or a pre-defined matplotlib colormap object. If not provided, the default colormap of the active ``schema`` will be used. labels : bool, optional If `True`, then contour labels will be displayed. label_kwargs : dict, optional Additional keyword arguments to be passed to the `clabel` method. interpolate : bool, optional If `True`, then the data will be interpolated before plotting. preferred_method : str, optional The preferred method for plotting the data. Must be one of `contour`, `contourf`, or `pcolormesh`. **kwargs Additional keyword arguments to be passed to the `contour` or `contourf` method. """ def __init__( self, colors=schema.default_cmap, labels=False, label_kwargs=None, interpolate=True, preferred_method="contour", **kwargs, ): super().__init__(colors=colors, preferred_method=preferred_method, **kwargs) self.labels = labels self._label_kwargs = label_kwargs or dict() self._interpolate = interpolate def _get_config(self): import copy config = super()._get_config() config["labels"] = self.labels config["label_kwargs"] = copy.deepcopy(self._label_kwargs) if self._label_kwargs else None config["interpolate"] = self._interpolate return config def plot(self, *args, **kwargs): """Plot the data using the `Style`'s defaults.""" if self._colors is not None: if self._interpolate: return self.contourf(*args, **kwargs) else: return self.pcolormesh(*args, **kwargs) else: return self.contour(*args, **kwargs) def to_contour_kwargs(self, data): """ Generate `contour` arguments required for plotting data in this `Style`. Parameters ---------- data : numpy.ndarray The data to be plotted using this `Style`. """ levels = self.levels(data) # Use cmap+norm path only when colors is a named colormap or Colormap object. # Plain colour strings (e.g. "black", "#ff0000") and lists of colours are # passed directly as colors= to ax.contour, matching matplotlib's own API. colors_is_cmap = isinstance(self._colors, mpl.colors.Colormap) or ( isinstance(self._colors, str) and self._colors in mpl.colormaps ) if self._colors is not None and colors_is_cmap and len(levels) > 1: cmap, norm = styles.colors.cmap_and_norm( self._colors, levels, self.normalize, self.extend, ) return { **{"cmap": cmap, "norm": norm, "levels": levels}, **self._kwargs, } return {"levels": levels, "colors": self._colors, **self._kwargs} def contourf(self, ax, x, y, values, *args, **kwargs): """ Plot shaded contours using this `Style`. Parameters ---------- ax : matplotlib.axes.Axes The axes on which to plot the data. x : numpy.ndarray The x coordinates of the data to be plotted. y : numpy.ndarray The y coordinates of the data to be plotted. values : numpy.ndarray The values of the data to be plotted. **kwargs Any additional arguments accepted by `matplotlib.axes.Axes.contourf`. """ mappable = super().contourf(ax, x, y, values, *args, **kwargs) return mappable def contour(self, *args, **kwargs): """ Plot line contours using this `Style`. Parameters ---------- *args The positional arguments to pass to the `contour` method. **kwargs The keyword arguments to pass to the `contour` method. """ mappable = super().contour(*args, **kwargs) if self.labels: self.contour_labels(mappable, **self._label_kwargs) return mappable def contour_labels( self, mappable, label_fontsize=7, label_colors=None, label_frequency=1, label_background=None, label_fmt=None, ): """ Add labels to a contour plot. Parameters ---------- mappable : matplotlib.contour.ContourSet The contour plot to which to add labels. label_fontsize : int, optional The fontsize of the labels. label_colors : str or list, optional The colors of the labels. label_frequency : int, optional The frequency of contour levels at which to add labels. label_background : str, optional The background color of the labels. label_fmt : str, optional The string format of the labels. """ clabels = mappable.axes.clabel( mappable, mappable.levels[0::label_frequency], inline=True, fontsize=label_fontsize, colors=label_colors, fmt=label_fmt, inline_spacing=2, ) if label_background is not None: for label in clabels: label.set_backgroundcolor(label_background) return clabels class Hatched(Contour): """ A style for plotting hatched contours. Parameters ---------- *args The positional arguments to pass to the `Contour` constructor. hatches : str, optional The pattern of hatching to use. background_colors : list, optional The colors to use for the background of the hatched contours. **kwargs The keyword arguments to pass to the `Contour` constructor. """ def __init__(self, *args, hatches=".", background_colors=None, **kwargs): super().__init__(*args, **kwargs) self.hatches = hatches self._foreground_colors = self._colors self._colors = background_colors or [(0, 0, 0, 0)] def __eq__(self, other): keys = ["_levels", "_colors", "_foreground_colors", "hatches"] return all([getattr(self, key, None) == getattr(other, key, None) for key in keys]) def contourf(self, *args, **kwargs): """ Plot hatched shaded contours using this `Style`. Parameters ---------- *args The positional arguments to pass to the `contourf` method. **kwargs The keyword arguments to pass to the `contourf` method. """ mappable = super().contourf(*args, hatches=self.hatches, **kwargs) linecolors = colors.expand(self._foreground_colors, mappable.levels) mappable.set_edgecolors(linecolors) mappable.set_facecolors([(0, 0, 0, 0)] * len(mappable.levels)) mappable.set_linewidth(0) return mappable def colorbar(self, *args, **kwargs): """ Create a colorbar legend for this `Style`. Parameters ---------- *args The positional arguments to pass to the `colorbar` method. **kwargs The keyword arguments to pass to the `colorbar` method. """ colorbar = super().colorbar(*args, **kwargs) levels = colorbar.mappable.levels linecolors = colors.expand(self._foreground_colors, levels) for i, artist in enumerate(colorbar.solids_patches): artist.set_edgecolor(linecolors[i]) return colorbar def disjoint(self, layer, *args, **kwargs): """ Create a disjoint legend for this `Style`. Parameters ---------- layer : earthkit.maps.charts.layers.Layer The layer for which to create a legend. *args The positional arguments to pass to the `dis **kwargs The keyword arguments to pass to the `disjoint` method. """ legend = super().disjoint(layer, *args, **kwargs) linecolors = colors.expand(self._foreground_colors, layer.mappable.levels) for color, artist in zip(linecolors, legend.get_patches()): artist.set_edgecolor(color) artist.set_linewidth(0.0) return legend DEFAULT_STYLE = Style() DEFAULT_VECTOR_STYLE = Vector() _STYLE_KWARGS = list(set(inspect.getfullargspec(Style)[0] + inspect.getfullargspec(Contour)[0])) _OVERRIDE_KWARGS = ["labels"] def compare_attributes(self, other, keys): def is_equal(x, y): is_x_arr = isinstance(x, np.ndarray) is_y_arr = isinstance(y, np.ndarray) # Check if both are numpy arrays if is_x_arr and is_y_arr: return np.array_equal(x, y) # If one is an array and the other isn't, they are not equal if is_x_arr != is_y_arr: return False # Default to standard equality check for non-array types return x == y # Use the is_equal function for each key in your check try: return all(is_equal(getattr(self, key, None), getattr(other, key, None)) for key in keys) except ValueError: return False # Imported here (after DEFAULT_STYLE) to avoid circular imports, since auto.py # imports from this module. from earthkit.plots.styles.auto import list_styles, load_style # noqa: E402, F401 def get_style_class(method: str) -> type[Style]: """ Get the `Style` class for a given plotting method. Parameters ---------- method : str The plotting method for which to get the `Style` class for. Returns ------- type[Style] The default `Style` class for the given plotting method. """ if method in ["quiver"]: return Quiver elif method in ["barbs"]: return Vector elif method in ["contour", "contourf"]: return Contour return Style