# Copyright 2024-, European Centre for Medium Range Weather Forecasts.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from earthkit.plots import metadata, plottypes, styles
from earthkit.plots.schemas import schema
from earthkit.plots.styles import auto, colors, legends, levels
from earthkit.plots.styles.colors import magics_colors_to_rgb
# Public names of this module. The first four are the names imported from
# earthkit.plots.styles at the top of the file.
# NOTE(review): DEFAULT_STYLE, _STYLE_KWARGS, _OVERRIDE_KWARGS, load_style and
# list_styles are not defined in the visible part of this file — confirm they
# are defined further down.
__all__ = [
    "colors", "legends", "levels", "auto",
    "Style", "DEFAULT_STYLE", "_STYLE_KWARGS", "_OVERRIDE_KWARGS",
    "load_style", "list_styles",
]
def linspace_datetime64(start_date, end_date, n):
    """
    Return ``n`` evenly spaced datetime64 values from ``start_date`` to ``end_date``.

    Both end points are included, as with :func:`numpy.linspace`. The result
    uses numpy's datetime64/timedelta64 arithmetic, so fractional steps are
    expressed in the time unit of ``end_date - start_date``.

    Parameters
    ----------
    start_date : numpy.datetime64
        The first date.
    end_date : numpy.datetime64
        The last date.
    n : int
        How many dates to generate.

    Returns
    -------
    numpy.ndarray
        Array of ``n`` datetime64 values.
    """
    span = end_date - start_date
    fractions = np.linspace(0, 1, n)
    return start_date + fractions * span
def spline_interpolate(x, *ys, n=None):
    """
    Resample one or more series onto a dense grid through an interpolating spline.

    A spline of degree ``min(3, unique_points - 1)`` is fitted to each of
    ``ys`` against ``x`` (with duplicate x values removed first) and evaluated
    on ``n`` evenly spaced points spanning ``x.min()`` to ``x.max()``.

    Parameters
    ----------
    x : array-like
        Independent variable; numeric or datetime64.
    *ys : array-like
        Dependent variables, each indexed along its first axis like ``x``.
    n : int, optional
        Number of output points. A falsy value (None or 0) means
        ``max(300, len(x) * 5)``.

    Returns
    -------
    tuple
        ``(x_smooth, y1_smooth, y2_smooth, ...)``. When ``x`` has fewer than
        two distinct values no spline can be fitted, and the deduplicated
        ``x`` and ``ys`` are returned instead.
    """
    from scipy.interpolate import make_interp_spline

    x = np.asarray(x)
    num_points = n or max(300, len(x) * 5)
    make_grid = linspace_datetime64 if np.issubdtype(x.dtype, np.datetime64) else np.linspace
    x_smooth = make_grid(x.min(), x.max(), num_points)

    # Splines need strictly increasing, unique x. Wrapped climatology data can
    # repeat timestamps; np.unique sorts and keeps each value's first index.
    _, keep = np.unique(x.astype(np.float64), return_index=True)
    x_fit = x[keep]
    degree = min(3, len(x_fit) - 1)
    fitted = [np.asarray(y)[keep] for y in ys]

    if degree < 1:
        # Fewer than 2 unique points: nothing to interpolate, return as-is.
        return (x_fit, *fitted)
    return (x_smooth, *(make_interp_spline(x_fit, y_fit, k=degree)(x_smooth) for y_fit in fitted))
def _validate_projection_for_tricontour(ccrs) -> bool:
    """
    Check that a projection can be used for tricontour plotting.

    Parameters
    ----------
    ccrs : object or None
        The projection (CRS) instance to check. ``None`` is always accepted.

    Returns
    -------
    bool
        ``True`` when the projection is acceptable.

    Raises
    ------
    ValueError
        If the projection's class name contains any name on the banned list.
        This is a substring match on ``ccrs.__class__.__name__``.

    Notes
    -----
    The banned list was established empirically by iterative search.
    """
    if ccrs is None:
        return True
    class_name = str(ccrs.__class__.__name__)
    banned = (
        "TransverseMercator",
        "Sinusoidal",
        "Robinson",
        "Hammer",
        "EqualEarth",
        "LambertAzimuthalEqualArea",
    )
    for projection in banned:
        if projection in class_name:
            raise ValueError(
                f"Projection {ccrs} is not suitable for tricontour plotting. Please use a different projection."
            )
    return True
[docs]
class Style:
"""
A style for plotting data.
Parameters
----------
colors : str or list or matplotlib.colors.Colormap, optional
The colors to be used in this `Style`. This can be a
`named matplotlib colormap
<https://matplotlib.org/stable/gallery/color/colormap_reference.html>`__,
a list of colors (as named CSS4 colors, hexadecimal colors or three
(four)-element lists of RGB(A) values), or a pre-defined matplotlib
colormap object. If not provided, the default colormap of the active
`schema` will be used.
levels : list or earthkit.maps.styles.levels.Levels, optional
The levels to use in this `Style`. This can be a list of specific
levels, or an earthkit `Levels` object. If not provided, some suitable
levels will be generated automatically (experimental!).
gradients : list, optional
The number of colors to insert between each level in `levels`. If None,
one color level will be inserted between each level.
normalize : bool, optional
If `True` (default), then the colors will be normalized over the level
range.
units : str, optional
The units in which the levels are defined. If this `Style` is used with
data not using the given units, then a conversion will be attempted;
any data incompatible with these units will not be able to use this
`Style`. If `units` are not provided, then data plotted using this
`Style` will remain in their original units. See
:doc:`/examples/examples/introduction/08-unit-conversion` for
examples.
units_label : str, optional
The label to use in titles and legends to represent the units of the
data.
legend_style : str, optional
The style of legend to use by default with this style. Must be one of
`colorbar` (default), `disjoint`, `histogram`, or `None` (no legend).
bin_labels : list, optional
A list of categorical labels for each bin in the legend.
"""
[docs]
@classmethod
def from_dict(cls, kwargs):
    """
    Create a `Style` from a configuration dictionary.

    Parameters
    ----------
    kwargs : dict
        Must contain ``"type"``, the name of the style class to build. The
        class is looked up as an attribute of ``earthkit.plots.styles``.
        - An optional ``"name"`` is stored on the new instance.
        - An optional ``"levels"`` entry is converted with
          ``Levels.from_config``.
        - All other keys are passed to the style class's constructor.
        The dictionary itself is not modified.

    Returns
    -------
    Style
        An instance of the class named by ``"type"`` (not necessarily ``cls``).

    Raises
    ------
    KeyError
        If ``"type"`` is missing.
    """
    # Work on a shallow copy. Popping "type"/"name" and rewriting "levels" in
    # the caller's dict would make a reused configuration fail with
    # KeyError('type') the next time it is passed in.
    config = dict(kwargs)
    style_type = config.pop("type")
    name = config.pop("name", None)
    if "levels" in config:
        config["levels"] = levels.Levels.from_config(config["levels"])
    instance = getattr(styles, style_type)(**config)
    instance._name = name
    return instance
def __init__(
    self,
    colors=schema.default_cmap,
    levels=None,
    gradients=None,
    normalize=True,
    units=None,
    scale_factor=None,
    units_label=None,
    legend_style="colorbar",
    legend_kwargs=None,
    categories=None,
    ticks=None,
    preferred_method="contourf",
    resample=None,
    vmin=None,
    vmax=None,
    **kwargs,
):
    # `cmap` is accepted as an alias for `colors`; passing both is an error.
    if "cmap" in kwargs:
        if colors != schema.default_cmap:
            raise ValueError("Cannot specify both 'colors' and 'cmap'. They are aliases for the same parameter.")
        colors = kwargs.pop("cmap")

    # Categorical styles without explicit levels get one bin per category.
    if categories is not None and levels is None:
        levels = range(len(categories) + 1)

    self._colors = colors
    # In "magics" colour mode, lists of colours are converted to RGB.
    if isinstance(self._colors, (list, tuple)) and schema.color_mode == "magics":
        self._colors = magics_colors_to_rgb(self._colors)

    # Levels may be a Levels object, a dict of Levels keyword arguments, or a
    # value wrapped in a new Levels (e.g. a list, a range or None).
    if isinstance(levels, dict):
        levels = styles.levels.Levels(**levels)
    self._levels = (
        levels
        if isinstance(levels, styles.levels.Levels)
        else styles.levels.Levels(
            levels,
            categorical=categories is not None or self.__class__.__name__ == "Categorical",
        )
    )
    self.normalize = normalize
    self.gradients = gradients
    self.resample = resample
    self._units = units
    self._units_label = units_label
    self.scale_factor = scale_factor
    self._legend_style = legend_style
    # The string "None" is treated the same as None (no legend).
    if self._legend_style == "None":
        self._legend_style = None
    self._bin_labels = categories
    # Take a private copy. This dict is written to below (ticks) and by other
    # methods of this class (e.g. when building a colorbar). Those writes must
    # not leak into a dict the caller may reuse for other styles.
    self._legend_kwargs = dict(legend_kwargs) if legend_kwargs else dict()
    if ticks is not None:
        self._legend_kwargs["ticks"] = ticks
    # Strip internal earthkit-plots keys that are not valid matplotlib kwargs.
    kwargs.pop("regrid", None)
    # Remaining extra kwargs are merged into the plotting kwargs this style builds.
    self._kwargs = kwargs
    self._preferred_method = preferred_method
    self._vmin = vmin
    self._vmax = vmax
    self._name = None
# TODO
# def to_yaml(self):
# pass
# TODO
# def to_magics_style(self):
# pass
def __eq__(self, other):
    """Compare two styles on their name, levels and colors only."""
    # NOTE(review): compare_attributes is neither defined nor imported in the
    # visible part of this file — confirm it exists at module level.
    # Defining __eq__ without __hash__ makes Style instances unhashable.
    compared = ["_name", "_levels", "_colors"]
    return compare_attributes(self, other, compared)
def _get_config(self):
    """
    Snapshot this style's constructor arguments as a dict.

    The levels entry is one of:
    - a dict with ``step`` (plus ``reference`` / ``divergence_point`` when
      set) for step-based levels;
    - the explicit level values otherwise;
    - None when the levels object has no ``_levels`` attribute.

    ``legend_kwargs`` and the extra kwargs are deep-copied, so the result
    can be modified freely.
    """
    import copy

    level_spec = self._levels
    if not hasattr(level_spec, "_levels"):
        levels_config = None
    elif hasattr(level_spec, "_step") and level_spec._step is not None:
        levels_config = {"step": level_spec._step}
        for attr, key in (("_reference", "reference"), ("_divergence_point", "divergence_point")):
            value = getattr(level_spec, attr, None)
            if value is not None:
                levels_config[key] = value
    else:
        levels_config = level_spec._levels

    legend_kwargs = copy.deepcopy(self._legend_kwargs) if self._legend_kwargs else None
    config = dict(
        colors=self._colors,
        levels=levels_config,
        gradients=self.gradients,
        normalize=self.normalize,
        units=self._units,
        scale_factor=self.scale_factor,
        units_label=self._units_label,
        legend_style=self._legend_style,
        legend_kwargs=legend_kwargs,
        categories=self._bin_labels,
        preferred_method=self._preferred_method,
        resample=self.resample,
        vmin=self._vmin,
        vmax=self._vmax,
    )
    # Extra (matplotlib) kwargs come last and may override the entries above.
    config.update(copy.deepcopy(self._kwargs))
    return config
[docs]
def with_overrides(self, **overrides):
    """
    Return a new style of the same class with some parameters replaced.

    The current configuration is collected, the given overrides are applied
    on top, and a fresh instance is built. This style is left untouched.

    Parameters
    ----------
    **overrides : dict
        Constructor arguments to replace. Common ones include:
        - colors (or its alias cmap): color scheme or colormap
        - levels: contour levels
        - gradients: gradient steps between levels
        - normalize: whether to normalize colors
        - units: data units
        - legend_style: type of legend
        - categories: categorical labels
        - ticks: tick locations

    Returns
    -------
    Style
        The new style instance.

    Raises
    ------
    ValueError
        If both ``cmap`` and ``colors`` are given.

    Examples
    --------
    >>> style = Style(colors="viridis", levels=[0, 10, 20, 30])
    >>> new_style = style.with_overrides(levels=[0, 5, 10, 15, 20])
    >>> # original style is unchanged
    """
    uses_cmap_alias = "cmap" in overrides
    if uses_cmap_alias and "colors" in overrides:
        raise ValueError("Cannot specify both 'cmap' and 'colors'. They are aliases for the same parameter.")
    if uses_cmap_alias:
        overrides["colors"] = overrides.pop("cmap")
    return self.__class__(**{**self._get_config(), **overrides})
[docs]
def levels(self, data=None):
    """
    Return the levels of this style, optionally computed for some data.

    Parameters
    ----------
    data : numpy.ndarray or xarray.DataArray or earthkit.data.core.Base, optional
        Data to derive levels from. Without data, only fixed (explicit)
        levels can be returned.

    Returns
    -------
    list

    Raises
    ------
    ValueError
        If no data is given and the style's levels are dynamic.
    """
    if data is not None:
        return self._levels.apply(data)
    fixed_levels = self._levels._levels
    if fixed_levels is None:
        raise ValueError("this style uses dynamic levels; include the `data` argument to generate levels")
    return fixed_levels
@property
def extend(self):
    """The ``extend`` entry of this style's extra kwargs, or None when unset."""
    if "extend" in self._kwargs:
        return self._kwargs["extend"]
    return None
@property
def units(self):
    """Units text for titles/legends: the explicit label, else the units, else None."""
    for candidate in (self._units_label, self._units):
        if candidate is not None:
            return candidate
    return None
[docs]
def apply_scale_factor(self, values):
    """
    Return ``values`` multiplied by this style's ``scale_factor``.

    Parameters
    ----------
    values : numpy.ndarray or scalar
        The values to scale. They are not modified.

    Returns
    -------
    The scaled values. When no scale factor is configured, ``values`` is
    returned unchanged (the same object).
    """
    if self.scale_factor is None:
        return values
    # Multiply out of place. An in-place ``*=`` would mutate the caller's
    # array, so data plotted more than once would be scaled repeatedly. It
    # would also raise for integer arrays combined with a float factor.
    return values * self.scale_factor
[docs]
def convert_units(self, values, source_units, short_name=""):
    """
    Convert some values from their source units to this `Style`'s units.

    Parameters
    ----------
    values : numpy.ndarray
        The values to convert from their source units to this `Style`'s
        units.
    source_units : str
        The source units of the given values.
    short_name : str, optional
        The short name of the variable. It is used to make extra assumptions
        about the data's unit conversion. For example, temperature anomalies
        need special consideration when converting between Celsius and
        Kelvin. ``None`` is treated like an empty string.

    Returns
    -------
    The converted values. ``values`` is returned unchanged in two cases:
    - either this style or the data has no units;
    - the variable is an anomaly whose source units pass
      ``metadata.units.anomaly_equivalence``.
    """
    if self._units is None or source_units is None:
        return values
    # For temperature anomalies we do not want to convert values, just
    # change the units string.
    if "anomaly" in (short_name or "").lower() and metadata.units.anomaly_equivalence(source_units):
        return values
    return metadata.units.convert(values, source_units, self._units)
[docs]
def to_matplotlib_kwargs(self, data, extend_levels=True):
    """
    Generate matplotlib arguments required for plotting data in this `Style`.

    Three paths are possible:

    1. ``vmin``/``vmax`` is set and neither explicit levels nor a level step
       is configured: a continuous ``Normalize`` plus a colormap resolved
       directly from ``self._colors``. No ``levels`` key is returned.
    2. ``gradients`` is set: the result of ``colors.gradients(...)`` is
       returned. Side effect: ``ticks`` is defaulted to None in this style's
       legend kwargs.
    3. Otherwise: the discrete colormap and norm from
       ``styles.colors.cmap_and_norm``, plus ``levels``.

    Parameters
    ----------
    data : numpy.ndarray or None
        The data to be plotted using this `Style`. It is used to compute
        dynamic levels or a missing vmin/vmax bound.
    extend_levels : bool, optional
        Forwarded to ``styles.colors.cmap_and_norm`` (path 3 only).

    Returns
    -------
    dict
        In paths 1 and 3, the style's extra kwargs (``self._kwargs``) are
        merged last, so they override the computed entries.
    """
    # vmin/vmax path: produce a continuous Normalize instead of BoundaryNorm.
    # Only used when vmin or vmax is explicitly set AND no levels are configured
    # (explicit levels always take precedence over vmin/vmax).
    has_levels = self._levels._levels is not None or self._levels._step is not None
    if (self._vmin is not None or self._vmax is not None) and not has_levels:
        # A missing bound is taken from the data (NaN-aware). With no data it
        # stays None and is left for Normalize to autoscale.
        vmin = self._vmin if self._vmin is not None else (float(np.nanmin(data)) if data is not None else None)
        vmax = self._vmax if self._vmax is not None else (float(np.nanmax(data)) if data is not None else None)
        # Resolve colormap directly — avoid cmap_and_norm which requires
        # discrete levels to build its ListedColormap.
        colors_spec = self._colors
        if isinstance(colors_spec, mpl.colors.Colormap):
            cmap = colors_spec
        elif isinstance(colors_spec, str):
            cmap = mpl.colormaps[colors_spec]
        else:
            # List of colours: build a continuous LinearSegmentedColormap.
            cmap = mpl.colors.LinearSegmentedColormap.from_list("", colors_spec)
        # Draw "bad" (NaN/masked) values fully transparent.
        cmap = cmap.with_extremes(bad=(0, 0, 0, 0))
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        return {**{"cmap": cmap, "norm": norm}, **self._kwargs}
    levels = self.levels(data)
    if self.gradients is not None:
        self._legend_kwargs.setdefault("ticks", None)  # Let matplotlib auto-generate ticks
        return colors.gradients(
            levels,
            self._colors,
            self.gradients,
            self.normalize,
            **self._kwargs,
        )
    cmap, norm = styles.colors.cmap_and_norm(
        self._colors,
        levels,
        self.normalize,
        self.extend,
        extend_levels=extend_levels,
    )
    # Draw "bad" (NaN/masked) values fully transparent.
    cmap = cmap.with_extremes(bad=(0, 0, 0, 0))
    return {
        **{"cmap": cmap, "norm": norm, "levels": levels},
        **self._kwargs,
    }
@staticmethod
def _xy_for_contour(x, y):
    """Expand 1-D ``x``/``y`` into 2-D grids with np.meshgrid; other inputs pass through."""
    if not (x.ndim == 1 and y.ndim == 1):
        return x, y
    return tuple(np.meshgrid(x, y))
@staticmethod
def _xy_for_scatter(x, y):
    """Flatten same-shaped 2-D coordinate grids to 1-D copies; other inputs pass through."""
    same_shape_grids = x.ndim == 2 and y.ndim == 2 and x.shape == y.shape
    if same_shape_grids:
        return x.flatten(), y.flatten()
    return x, y
[docs]
def to_contourf_kwargs(self, data):
    """
    Build the keyword arguments for a `contourf` call in this `Style`.

    Parameters
    ----------
    data : numpy.ndarray
        The data to be plotted using this `Style`.

    Returns
    -------
    dict
        The output of `to_matplotlib_kwargs` without any ``hatches`` entry.
    """
    # `hatches` only applies to the Hatched subclass; it is dropped here so it
    # does not leak into plain contourf calls from the base Style.
    return {key: value for key, value in self.to_matplotlib_kwargs(data).items() if key != "hatches"}
[docs]
def to_contour_kwargs(self, data):
    """
    Build the keyword arguments for a `contour` call in this `Style`.

    This is exactly the output of `to_matplotlib_kwargs` (levels included).

    Parameters
    ----------
    data : numpy.ndarray
        The data to be plotted using this `Style`.
    """
    contour_kwargs = self.to_matplotlib_kwargs(data)
    return contour_kwargs
[docs]
def to_pcolormesh_kwargs(self, data):
    """
    Build the keyword arguments for a `pcolormesh` call in this `Style`.

    Levels are not extended. Contour-only keys (``levels``, ``extend``,
    ``labels``, ``hatches``) and ``transform_first`` are removed.

    Parameters
    ----------
    data : numpy.ndarray
        The data to be plotted using this `Style`.
    """
    mesh_kwargs = self.to_matplotlib_kwargs(data, extend_levels=False)
    for unsupported in ("levels", "transform_first", "extend", "labels", "hatches"):
        mesh_kwargs.pop(unsupported, None)
    return mesh_kwargs
[docs]
def to_add_geometries_kwargs(self, data):
    """
    Build the keyword arguments for an `add_geometries` call in this `Style`.

    Contour-only keys are removed, and a ``facecolor`` entry is added. It is
    computed by passing ``data`` (with NaN/inf masked) through the style's
    norm and colormap.

    Parameters
    ----------
    data : numpy.ndarray
        The data to be plotted using this `Style`.
    """
    geometry_kwargs = self.to_matplotlib_kwargs(data, extend_levels=False)
    for unsupported in ("levels", "extend", "labels", "linecolors", "hatches"):
        geometry_kwargs.pop(unsupported, None)
    colormap = geometry_kwargs.get("cmap")
    normaliser = geometry_kwargs.get("norm")
    geometry_kwargs["facecolor"] = colormap(normaliser(np.ma.masked_invalid(data)))
    return geometry_kwargs
[docs]
def to_scatter_kwargs(self, data):
    """
    Build the keyword arguments for a `scatter` call in this `Style`.

    Levels are not extended, and the ``levels`` key is removed.

    Parameters
    ----------
    data : numpy.ndarray
        The data to be plotted using this `Style`.
    """
    all_kwargs = self.to_matplotlib_kwargs(data, extend_levels=False)
    return {key: value for key, value in all_kwargs.items() if key != "levels"}
[docs]
def to_quiver_kwargs(self, data):
    """
    Build the keyword arguments for a `quiver` call in this `Style`.

    This is the output of `to_matplotlib_kwargs` without the ``levels`` key.

    Parameters
    ----------
    data : numpy.ndarray
        The data to be plotted using this `Style`.
    """
    all_kwargs = self.to_matplotlib_kwargs(data)
    return {key: value for key, value in all_kwargs.items() if key != "levels"}
[docs]
def plot(self, *args, **kwargs):
    """Default plot for this `Style`: delegates to `pcolormesh` with the same arguments."""
    default_method = self.pcolormesh
    return default_method(*args, **kwargs)
[docs]
def contourf(self, ax, x, y, values, *args, **kwargs):
    """
    Plot shaded contours using this `Style`.

    One-dimensional ``values`` (unstructured points) are routed to
    `tricontourf` unless an ``interpolate`` keyword is given.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    x : numpy.ndarray
        The x coordinates of the data to be plotted.
    y : numpy.ndarray
        The y coordinates of the data to be plotted.
    values : numpy.ndarray
        The values of the data to be plotted.
    **kwargs
        Any additional arguments accepted by `matplotlib.axes.Axes.contourf`;
        they override the style's own kwargs.
    """
    if kwargs.get("interpolate") is None and values.ndim == 1:
        return self.tricontourf(ax, x, y, values, *args, **kwargs)
    merged = self.to_contourf_kwargs(values)
    merged.update(kwargs)
    grid_x, grid_y = self._xy_for_contour(x, y)
    return ax.contourf(grid_x, grid_y, values, *args, **merged)
[docs]
def _vector_field(self, ax, plot_name, x, y, u, v, *args, **kwargs):
    """
    Shared implementation of `quiver` and `barbs`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    plot_name : str
        ``"quiver"`` or ``"barbs"``. It names the ``ax`` method to call and is
        used in the error message.
    x, y, u, v, *args, **kwargs
        Forwarded to the plotting method. Call-time ``kwargs`` override this
        style's own extra kwargs.

    Returns
    -------
    The artist returned by the plotting method. When the vectors are drawn
    without a colormap, its ``cmap`` and ``norm`` are cleared so that no
    colorbar is produced for it.

    Raises
    ------
    ValueError
        If ``extend`` is set on the style or passed in ``kwargs``.
    """
    if self.extend is not None or kwargs.get("extend") is not None:
        raise ValueError(
            f"'extend' is not supported for {plot_name} styles. Remove 'extend' from your Style definition."
        )
    cmap = None
    norm = None
    magnitude = None
    kwargs = {**self._kwargs, **kwargs}
    if self._colors:
        # Colour each vector by its magnitude using this style's cmap/norm.
        magnitude = np.sqrt(u**2 + v**2)
        kwargs = {**self.to_quiver_kwargs(magnitude), **kwargs}
        cmap = kwargs.pop("cmap", None)
        norm = kwargs.pop("norm", None)
    plot = getattr(ax, plot_name)
    if cmap and norm and magnitude is not None:
        return plot(x, y, u, v, magnitude.ravel(), *args, cmap=cmap, norm=norm, **kwargs)
    mappable = plot(x, y, u, v, *args, **kwargs)
    # Mark as non-coloured so colorbar is skipped
    mappable.cmap = None
    mappable.norm = None
    return mappable

def quiver(self, ax, x, y, u, v, *args, **kwargs):
    """
    Plot quiver arrows using this `Style`.

    If the style has colors, the arrows are coloured by the vector magnitude
    ``sqrt(u**2 + v**2)``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    x : numpy.ndarray
        The x coordinates of the data to be plotted.
    y : numpy.ndarray
        The y coordinates of the data to be plotted.
    u : numpy.ndarray
        The u-component of the data to be plotted.
    v : numpy.ndarray
        The v-component of the data to be plotted.
    **kwargs
        Any additional arguments accepted by `matplotlib.axes.Axes.quiver`.

    Raises
    ------
    ValueError
        If ``extend`` is set on the style or passed in ``kwargs``.
    """
    return self._vector_field(ax, "quiver", x, y, u, v, *args, **kwargs)

def barbs(self, ax, x, y, u, v, *args, **kwargs):
    """
    Plot wind barbs using this `Style`.

    If the style has colors, the barbs are coloured by the vector magnitude
    ``sqrt(u**2 + v**2)``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    x : numpy.ndarray
        The x coordinates of the data to be plotted.
    y : numpy.ndarray
        The y coordinates of the data to be plotted.
    u : numpy.ndarray
        The u-component of the data to be plotted.
    v : numpy.ndarray
        The v-component of the data to be plotted.
    **kwargs
        Any additional arguments accepted by `matplotlib.axes.Axes.barbs`.

    Raises
    ------
    ValueError
        If ``extend`` is set on the style or passed in ``kwargs``.
    """
    return self._vector_field(ax, "barbs", x, y, u, v, *args, **kwargs)
[docs]
def streamplot(self, ax, x, y, u, v, *args, **kwargs):
    """
    Plot streamlines.

    No style settings (colormap, norm or extra kwargs) are applied. The
    arguments are passed straight to ``ax.streamplot``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    x, y : numpy.ndarray
        The coordinates of the data to be plotted.
    u, v : numpy.ndarray
        The vector components of the data to be plotted.
    **kwargs
        Any additional arguments accepted by `matplotlib.axes.Axes.streamplot`.
    """
    draw_streamlines = ax.streamplot
    return draw_streamlines(x, y, u, v, *args, **kwargs)
[docs]
def tricontour(self, ax, x, y, values, *args, **kwargs):
    """
    Plot triangulated contour lines using this `Style`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    x, y : numpy.ndarray
        The coordinates of the data to be plotted.
    values : numpy.ndarray
        The values of the data to be plotted.
    **kwargs
        Any additional arguments accepted by `matplotlib.axes.Axes.tricontour`.
        A ``labels`` entry is discarded.

    Raises
    ------
    ValueError
        If the ``transform`` projection is unsuitable for tricontour plotting.
    """
    merged = {**self.to_contour_kwargs(values), **kwargs}
    merged.pop("labels", None)
    _validate_projection_for_tricontour(merged.get("transform", None))
    return ax.tricontour(x, y, values, *args, **merged)
[docs]
def tricontourf(self, ax, x, y, values, *args, **kwargs):
    """
    Plot triangulated shaded contours using this `Style`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    x, y : numpy.ndarray
        The coordinates of the data to be plotted.
    values : numpy.ndarray
        The values of the data to be plotted.
    **kwargs
        Any additional arguments accepted by `matplotlib.axes.Axes.tricontourf`.

    Raises
    ------
    ValueError
        If the ``transform`` projection is unsuitable for tricontour plotting.
    """
    merged = self.to_contourf_kwargs(values)
    merged.update(kwargs)
    _validate_projection_for_tricontour(merged.get("transform", None))
    return ax.tricontourf(x, y, values, *args, **merged)
[docs]
def tripcolor(self, ax, x, y, values, *args, **kwargs):
    """
    Plot an unstructured (triangulated) pseudocolor mesh using this `Style`.

    Uses the same keyword arguments as `pcolormesh`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    x : numpy.ndarray
        The x coordinates of the data to be plotted.
    y : numpy.ndarray
        The y coordinates of the data to be plotted.
    values : numpy.ndarray
        The values of the data to be plotted.
    **kwargs
        Any additional arguments accepted by `matplotlib.axes.Axes.tripcolor`.
    """
    # transform_first applies to contouring only. It is dropped from the
    # caller's kwargs too, as pcolormesh does, so it is not passed on to the
    # mesh collection.
    kwargs.pop("transform_first", None)
    kwargs = {**self.to_pcolormesh_kwargs(values), **kwargs}
    return ax.tripcolor(x, y, values, *args, **kwargs)
[docs]
def contour(self, ax, x, y, values, *args, **kwargs):
    """
    Plot line contours using this `Style`.

    One-dimensional ``values`` are routed to `tricontour` unless an
    ``interpolate`` keyword is given.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    x, y : numpy.ndarray
        The coordinates of the data to be plotted.
    values : numpy.ndarray
        The values of the data to be plotted.
    **kwargs
        Any additional arguments accepted by `matplotlib.axes.Axes.contour`.
        A ``labels`` entry is discarded.
    """
    if kwargs.get("interpolate") is None and values.ndim == 1:
        return self.tricontour(ax, x, y, values, *args, **kwargs)
    merged = self.to_contour_kwargs(values)
    merged.update(kwargs)
    merged.pop("labels", None)
    grid_x, grid_y = self._xy_for_contour(x, y)
    return ax.contour(grid_x, grid_y, values, *args, **merged)
[docs]
def pcolormesh(self, ax, x, y, values, *args, **kwargs):
    """
    Plot a pcolormesh using this `Style`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    x, y : numpy.ndarray
        The coordinates of the data to be plotted.
    values : numpy.ndarray
        The values of the data to be plotted.
    **kwargs
        Any additional arguments accepted by `matplotlib.axes.Axes.pcolormesh`.
        A ``transform_first`` entry is discarded.
    """
    kwargs.pop("transform_first", None)
    mesh_kwargs = self.to_pcolormesh_kwargs(values)
    mesh_kwargs.update(kwargs)
    return ax.pcolormesh(x, y, values, *args, **mesh_kwargs)
[docs]
def imshow(self, ax, _x, _y, values, *args, **kwargs):
    """
    Plot an image using this `Style` (with the pcolormesh keyword arguments).

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    _x, _y : numpy.ndarray
        Unused. Pass ``extent=`` to position the image.
    values : numpy.ndarray
        The 2-D image array to be plotted.
    **kwargs
        Any additional arguments accepted by `matplotlib.axes.Axes.imshow`.
        A ``transform_first`` entry is discarded.
    """
    kwargs.pop("transform_first", None)
    image_kwargs = self.to_pcolormesh_kwargs(values)
    image_kwargs.update(kwargs)
    return ax.imshow(values, *args, **image_kwargs)
[docs]
def scatter(self, ax, x, y, values, s=3, *args, **kwargs):
    """
    Plot a scatter plot using this `Style`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    x : numpy.ndarray
        The x coordinates of the data to be plotted.
    y : numpy.ndarray
        The y coordinates of the data to be plotted.
    values : numpy.ndarray or None
        The values of the data to be plotted. When given, they become the
        default ``c`` and the style's colormap kwargs are applied.
    s : float or numpy.ndarray, optional
        Marker size(s), passed to `matplotlib.axes.Axes.scatter`.
    **kwargs
        Any additional arguments accepted by `matplotlib.axes.Axes.scatter`.
        Two keys are handled specially:
        - ``transform_first`` is discarded.
        - ``missing_values`` (a dict): when given and ``values`` contains
          NaNs, the NaN points are removed from the main scatter. If the dict
          is non-empty, they are drawn in a separate scatter using the
          caller's own kwargs updated with it; the style's kwargs are not
          applied to them.

    Returns
    -------
    The result of the main ``ax.scatter`` call.
    """
    # NOTE(review): because ``s`` is positional-or-keyword, any extra
    # positional argument binds to ``s`` first; ``*args`` is normally empty.
    kwargs.pop("transform_first", None)
    missing_values = kwargs.pop("missing_values", None)
    # Snapshot the caller's kwargs before style defaults are merged in; the
    # missing-value markers use only these.
    original_kwargs = kwargs.copy()
    if values is not None:
        kwargs = {**self.to_scatter_kwargs(values), **kwargs}
    kwargs.pop("extend", None)
    if values is not None and missing_values is not None and np.isnan(values).any():
        # Split the (flattened) points into NaN and valid subsets.
        nan_mask = np.isnan(values.ravel())
        missing_x = x.ravel()[nan_mask]
        missing_y = y.ravel()[nan_mask]
        missing_s = s.ravel()[nan_mask] if not np.isscalar(s) else s
        x = x.ravel()[~nan_mask]
        y = y.ravel()[~nan_mask]
        values = values.ravel()[~nan_mask]
        s = s.ravel()[~nan_mask] if not np.isscalar(s) else s
        if missing_values:
            ax.scatter(
                missing_x,
                missing_y,
                s=missing_s,
                *args,
                **{**original_kwargs, **missing_values},
            )
    if values is not None:
        kwargs["c"] = kwargs.pop("c", values)
    # A single colour string cannot be combined with a colormap/norm.
    if isinstance(kwargs.get("c"), str):
        kwargs.pop("cmap", None)
        kwargs.pop("norm", None)
    x, y = self._xy_for_scatter(x, y)
    return ax.scatter(x, y, s=s, *args, **kwargs)
[docs]
def quantiles(
    self,
    ax,
    x,
    y,
    values,
    *args,
    type="band",
    quantiles=(0, 0.25, 0.5, 0.75, 1),
    **kwargs,
):
    """
    Compute and plot quantiles using this `Style`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    x : numpy.ndarray
        The coordinates of the data to be plotted.
    y : numpy.ndarray
        Unused; accepted for signature consistency with the other plot
        methods.
    values : numpy.ndarray
        The data of which to compute the statistics. The computation is
        applied along axis 0, so the size along axis 1 must match the size
        of the coordinates.
    type : "box" | "band"
        The type of plot used to represent the quantile ranges.
    quantiles : array_like
        Probabilities of the quantiles to compute. Values must be between
        0 and 1 inclusive. They are sorted before use.
    **kwargs
        Forwarded to `boxplot` / `bandplot`.

    Returns
    -------
    The object returned by the box or band plot.

    Raises
    ------
    NotImplementedError
        For any other ``type``.
    """
    # The default is an immutable tuple, so calls cannot share a mutable
    # default list; np.sort returns a new sorted array either way.
    quantiles = np.sort(quantiles)
    stats = np.quantile(values, quantiles, axis=0)
    if type == "box":
        mappable = self.boxplot(ax, x, stats, *args, **kwargs)
    elif type == "band":
        mappable = self.bandplot(ax, x, stats, *args, **kwargs)
    else:
        raise NotImplementedError(f"Plot of type {type} not yet implemented.")
    return mappable
def to_quantiles_kwargs(self, n, c=None):
    """
    Build the keyword arguments for a quantile box/band plot.

    Parameters
    ----------
    n : int
        Number of colours needed (one per band).
    c : str or sequence, optional
        Colour specification; defaults to this style's colors. A string is
        expanded into ``n`` colours, symmetric about the middle band:
        - a registered colormap name is sampled at ``abs(linspace(-1, 1, n))``;
        - any other string is passed to ``colors.symmetric_from_color``.
        Other values are used as given.

    Returns
    -------
    dict
        ``{"colors": ...}`` followed by this style's extra kwargs.
    """
    palette = self._colors if c is None else c
    if isinstance(palette, str):
        if palette in mpl.colormaps:
            sample_points = np.abs(np.linspace(-1.0, 1.0, n))
            palette = mpl.colormaps[palette](sample_points)
        else:
            palette = colors.symmetric_from_color(palette, n)
    return {"colors": palette, **self._kwargs}
def bandplot(self, ax, x, values, colors=None, *args, **kwargs):
    """Draw filled bands between consecutive rows of ``values`` (one colour per band)."""
    band_count = len(values) - 1
    merged = self.to_quantiles_kwargs(band_count, c=colors)
    merged.update(kwargs)
    return plottypes.bandplot(ax, x, values, *args, **merged)
def boxplot(self, ax, x, values, colors=None, *args, **kwargs):
    """Draw box plots from the rows of ``values`` (one colour per gap between rows)."""
    box_count = len(values) - 1
    merged = self.to_quantiles_kwargs(box_count, c=colors)
    merged.update(kwargs)
    return plottypes.boxplot(ax, x, values, *args, **merged)
[docs]
def line(self, ax, x, y, values, *args, **kwargs):
    """
    Plot a line using this `Style`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    x : numpy.ndarray
        The x coordinates of the data to be plotted.
    y : numpy.ndarray
        The y coordinates of the data to be plotted.
    values : numpy.ndarray or None
        When given, the style's scatter kwargs are merged in and ``values``
        becomes the default ``c``.
    **kwargs
        Any additional arguments accepted by `matplotlib.axes.Axes.plot`.
        ``drawstyle`` is interpreted specially:
        - ``"spline"``: a cubic spline through the points (`spline_interpolate`);
        - ``"smooth"``: linear resampling onto a dense grid;
        - ``"linear"`` (default): a plain line;
        - any other value is passed to matplotlib as its ``drawstyle``.
        In the spline/smooth modes a ``marker`` is drawn only at the original
        points, in the colour of the resampled line.

    Returns
    -------
    list of matplotlib.lines.Line2D
        The lines created by the main ``ax.plot`` call.
    """
    kwargs.pop("transform_first", None)
    if values is not None:
        kwargs = {**self.to_scatter_kwargs(values), **kwargs}
    kwargs.pop("extend", None)
    # Only forward `c` when there is one. An always-present `c` clashes with
    # a caller's `color` (matplotlib rejects aliases given together).
    c = kwargs.pop("c", values)
    if c is not None:
        kwargs["c"] = c

    mode = kwargs.pop("drawstyle", "linear")
    if mode not in ("spline", "smooth"):
        # Real matplotlib drawstyle — put it back for ax.plot
        if mode != "linear":
            kwargs["drawstyle"] = mode
        return ax.plot(x, y, *args, **kwargs)

    if mode == "spline":
        x_smooth, y_smooth = spline_interpolate(x, y)
    else:
        from scipy.interpolate import interp1d

        x_arr = np.asarray(x)
        num_points = max(300, len(x_arr) * 5)
        if np.issubdtype(x_arr.dtype, np.datetime64):
            x_smooth = linspace_datetime64(x_arr.min(), x_arr.max(), num_points)
        else:
            x_smooth = np.linspace(x_arr.min(), x_arr.max(), num_points)
        func = interp1d(
            x_arr,
            y,
            axis=0,  # interpolate along columns
            bounds_error=False,
            kind="linear",
            fill_value=(y[0], y[-1]),
        )
        y_smooth = func(x_smooth)

    marker = kwargs.pop("marker", None)
    mappable = ax.plot(x_smooth, y_smooth, *args, **kwargs)
    if marker is not None:
        # Overlay markers at the original points, without a connecting line,
        # in the resampled line's colour. Colour and line-width aliases are
        # dropped so they cannot clash with the explicit values given here.
        marker_kwargs = {
            key: value for key, value in kwargs.items() if key not in ("c", "color", "linewidth", "lw")
        }
        ax.plot(
            x,
            y,
            *args,
            marker=marker,
            color=mappable[0].get_color(),
            linewidth=0,
            **marker_kwargs,
        )
    return mappable
[docs]
def bar(self, ax, x, y, values, *args, **kwargs):
    """
    Plot a bar chart using this `Style`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot the data.
    x : numpy.ndarray
        The bar positions.
    y : numpy.ndarray
        The bar heights.
    values : numpy.ndarray or None
        Optional values used to colour the bars through the style's colormap
        and norm.
    **kwargs
        Any additional arguments accepted by `matplotlib.axes.Axes.bar`.
        A ``c`` entry is treated like ``color``; numeric data is mapped
        through the colormap when one is available.

    Returns
    -------
    matplotlib.container.BarContainer
    """
    kwargs.pop("transform_first", None)
    if values is not None:
        kwargs = {**self.to_scatter_kwargs(values), **kwargs}
    kwargs.pop("extend", None)
    # Bar patches accept neither `c` nor colormap arguments; passing them
    # raises AttributeError. Translate them into a `color` instead.
    color_source = kwargs.pop("c", values)
    cmap = kwargs.pop("cmap", None)
    norm = kwargs.pop("norm", None)
    if color_source is not None and "color" not in kwargs:
        data = np.asarray(color_source)
        if cmap is not None and norm is not None and np.issubdtype(data.dtype, np.number):
            kwargs["color"] = cmap(norm(data))
        else:
            kwargs["color"] = color_source
    return ax.bar(x, y, *args, **kwargs)
[docs]
def stripes(self, ax, x, y, values, *args, **kwargs):
    """
    Plot climate stripes: one vertical coloured bar per data point.

    Each bar spans the midpoints between adjacent x positions. The outermost
    bars are mirrored so the first and last points sit at their bar's centre;
    a single point gets a bar one x-unit wide. Each bar is coloured by its y
    value, mapped through this Style's colormap and normalisation.

    If no vmin/vmax, levels or level step is configured, the colour range
    is symmetric about zero: ``[-max(|y|), max(|y|)]``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to plot.
    x : array-like
        The x coordinates. datetime64 values are converted to matplotlib
        date numbers; anything else is cast to float.
    y : array-like
        The data values used to determine each bar's color.
    values : numpy.ndarray
        Unused (kept for API consistency with other Style methods).
    **kwargs
        ``ymin`` / ``ymax`` (default 0 and 1) set the vertical extent of the
        stripes as fractions of the axes height. ``transform_first`` and
        ``transform`` are discarded. Other keyword arguments (and ``*args``)
        are currently ignored.

    Returns
    -------
    matplotlib.cm.ScalarMappable
        A mappable carrying the colormap, norm and ``y`` data, so that a
        colorbar can be built from it.
    """
    import matplotlib.dates as mdates
    from matplotlib.patches import Rectangle
    from matplotlib.transforms import blended_transform_factory

    kwargs.pop("transform_first", None)

    # Convert datetime64 x-values to matplotlib float dates for arithmetic.
    x = np.asarray(x)
    if np.issubdtype(x.dtype, np.datetime64):
        x_num = mdates.date2num(x.astype("datetime64[ms]").astype("object"))
    else:
        x_num = x.astype(float)
    y = np.asarray(y, dtype=float)

    # Build bar edge positions at midpoints between adjacent x values.
    if len(x_num) == 1:
        edges = np.array([x_num[0] - 0.5, x_num[0] + 0.5])
    else:
        mids = (x_num[:-1] + x_num[1:]) / 2.0
        edges = np.concatenate([
            [x_num[0] - (mids[0] - x_num[0])],
            mids,
            [x_num[-1] + (x_num[-1] - mids[-1])],
        ])

    # Resolve colormap + norm from the style. When neither explicit vmin/vmax
    # nor discrete levels have been configured, default to a symmetric
    # zero-centred range based on the absolute maximum of the data.
    has_explicit_range = (
        self._vmin is not None
        or self._vmax is not None
        or self._levels._levels is not None
        or self._levels._step is not None
    )
    if not has_explicit_range:
        absmax = float(np.nanmax(np.abs(y)))
        # Guard degenerate data. All-zero (or all-NaN) y would give
        # vmin == vmax (or NaN bounds). Normalize maps vmin == vmax to 0, the
        # bottom of the colormap, instead of the zero-centred middle.
        if not np.isfinite(absmax) or absmax == 0:
            absmax = 1.0
        colors_spec = self._colors
        if isinstance(colors_spec, mpl.colors.Colormap):
            cmap = colors_spec
        elif isinstance(colors_spec, str):
            cmap = mpl.colormaps[colors_spec]
        else:
            cmap = mpl.colors.LinearSegmentedColormap.from_list("", colors_spec)
        cmap = cmap.with_extremes(bad=(0, 0, 0, 0))
        norm = mpl.colors.Normalize(vmin=-absmax, vmax=absmax)
    else:
        mpl_kwargs = self.to_scatter_kwargs(y)
        cmap = mpl_kwargs["cmap"]
        norm = mpl_kwargs["norm"]

    ymin = kwargs.pop("ymin", 0)
    ymax = kwargs.pop("ymax", 1)
    kwargs.pop("transform", None)

    # Blended transform: x in data coordinates, y in axes fraction (0=bottom, 1=top)
    transform = blended_transform_factory(ax.transData, ax.transAxes)

    # One Rectangle patch per stripe.
    for i, val in enumerate(y):
        x0 = edges[i]
        width = edges[i + 1] - x0
        rect = Rectangle(
            (x0, ymin),
            width,
            ymax - ymin,
            facecolor=cmap(norm(val)),
            linewidth=0,
            transform=transform,
        )
        ax.add_patch(rect)

    # Set x limits in data coordinates to cover the full stripe range
    ax.set_xlim(edges[0], edges[-1])

    # Return a ScalarMappable so colorbar() works
    sm = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array(y)
    return sm
[docs]
def values_to_colors(self, values, data=None):
    """
    Map a value or values to RGBA colours on this `Style`'s colour scale.

    Parameters
    ----------
    values : float or list of floats
        The values to convert.
    data : optional
        Data used to build the colour scale (e.g. for dynamic levels). It is
        passed to `to_matplotlib_kwargs`.
    """
    mpl_kwargs = self.to_matplotlib_kwargs(data=data)
    return mpl_kwargs["cmap"](mpl_kwargs["norm"](values))
[docs]
def legend(self, *args, **kwargs):
    """
    Create the default legend for this `Style`.

    ``self._legend_style`` names the method on this `Style` that builds the
    legend; that method is called with the given arguments.

    Parameters
    ----------
    *args : list
        Arguments to be passed to the legend method.
    **kwargs : dict
        Keyword arguments to be passed to the legend method.

    Returns
    -------
    The result of the legend method, or ``None`` when no legend style is set.

    Raises
    ------
    AttributeError
        If this `Style` has no attribute named by ``self._legend_style``.
    """
    if self._legend_style is None:
        return None
    try:
        method = getattr(self, self._legend_style)
    except AttributeError:
        # Suppress the internal lookup error so the traceback shows only the
        # meaningful message about the configured legend type.
        raise AttributeError(f"invalid legend type '{self._legend_style}'") from None
    return method(*args, **kwargs)
[docs]
def colorbar(self, *args, **kwargs):
    """
    Create a colorbar legend for this `Style`.

    Parameters
    ----------
    *args
        Positional arguments forwarded to `styles.legends.colorbar`. The
        first one, if given, is the layer whose ``mappable`` is checked.
    **kwargs
        Keyword arguments forwarded to `styles.legends.colorbar`.

    Returns
    -------
    The result of `styles.legends.colorbar`, or ``None`` when the layer's
    mappable has no ``cmap`` or its norm has no ``vmin``.
    """
    # A colorbar requires a ScalarMappable. If the layer's mappable is a
    # list (e.g. Line2D objects from a line plot) skip silently.
    layer = args[0] if args else None
    if layer is not None:
        mappable = layer.mappable
        if not hasattr(mappable, "cmap"):
            return None
        # Uncoloured quiver/barbs have no array data assigned — norm.vmin stays None
        if getattr(getattr(mappable, "norm", None), "vmin", None) is None:
            return None

    levels = self._levels._levels
    if self._legend_kwargs.get("ticks") is None and levels is not None:
        # Irregularly spaced levels get an explicit tick at every level;
        # evenly spaced levels are left without explicit ticks. Float steps
        # (e.g. 0.1) give differences that are only approximately equal, so
        # compare them with a relative tolerance rather than requiring
        # exactly one unique difference.
        diffs = np.ediff1d(levels)
        if diffs.size and np.issubdtype(diffs.dtype, np.floating):
            evenly_spaced = bool(np.allclose(diffs, diffs[0], rtol=1e-9, atol=0.0))
        else:
            evenly_spaced = len(np.unique(diffs)) == 1
        if not evenly_spaced:
            self._legend_kwargs["ticks"] = levels
    if self.extend is not None:
        self._legend_kwargs.setdefault("extend", self.extend)
    return styles.legends.colorbar(*args, **kwargs)
[docs]
def disjoint(self, *args, **kwargs):
    """
    Create a disjoint legend for this `Style`.

    All positional and keyword arguments are forwarded unchanged to
    `earthkit.plots.styles.legends.disjoint`, whose result is returned.
    """
    build_legend = styles.legends.disjoint
    return build_legend(*args, **kwargs)
[docs]
def vector(self, *args, **kwargs):
    """
    Create a vector legend for this `Style`.

    All positional and keyword arguments are forwarded unchanged to
    `earthkit.plots.styles.legends.vector`, whose result is returned.
    """
    build_legend = styles.legends.vector
    return build_legend(*args, **kwargs)
[docs]
def quiverkey(self, *args, **kwargs):
    """
    Create a quiverkey legend for this `Style`.

    All positional and keyword arguments are forwarded unchanged to
    `earthkit.plots.styles.legends.quiverkey`, whose result is returned.
    """
    build_legend = styles.legends.quiverkey
    return build_legend(*args, **kwargs)
[docs]
def save_legend(self, data=None, label=None, filename="legend.png", transparent=True, **kwargs):
    """
    Save a standalone image of the legend associated with this `Style`.

    The legend is built from a small placeholder contour plot drawn with this
    `Style`, cropped to the legend and its label, written to ``filename``,
    and the temporary figure is then closed.

    Parameters
    ----------
    data : earthkit.data.core.Base, optional
        Currently only used to choose the default label. When neither
        ``data`` nor ``label`` is given, the label defaults to ``"{units}"``
        if this `Style` has units, or to an empty string otherwise. The data
        itself is not plotted.
    label : str, optional
        The label to use for the legend. If not provided (and ``data`` is not
        provided), the label is generated from the `Style`'s units as
        described above.
    filename : str
        The name of the file to save the legend to. The file format will
        be determined by the file extension (e.g., `.png`, `.pdf`, etc.).
        By default, the legend will be saved as a file named `legend.png`.
    transparent : bool, optional
        If `True`, the saved legend will have a transparent background.
        Otherwise, it will have a white background. Default is `True`.
    **kwargs
        Additional keyword arguments to be passed to the legend method.
    """
    from earthkit.plots import Subplot

    if label is None and data is None:
        label = "{units}" if self.units is not None else ""

    # Placeholder field: it only exists so that a mappable styled by this
    # `Style` is available to build the legend from.
    plot_data = [[1, 2], [3, 4]]

    chart = Subplot()
    chart.contourf(plot_data, style=self)
    fig = chart.fig
    try:
        legend = chart.legend(label=label, **kwargs)[0]
        # Draw once so that the window extents of the legend are known.
        fig.canvas.draw()
        px_to_inches = fig.dpi_scale_trans.inverted()
        bbox = legend.ax.get_window_extent().transformed(px_to_inches)
        title_bbox = legend.ax.xaxis.label.get_window_extent().transformed(px_to_inches)

        # Grow the crop box to include the label, then pad it by a fraction of
        # the figure size: more along x for a horizontal legend, more along y
        # otherwise.
        width, height = fig.get_size_inches()
        xmod, ymod = (0.05, 0.01) if legend.orientation == "horizontal" else (0.01, 0.05)
        bbox.x0 = min(bbox.x0, title_bbox.x0) - width * xmod
        bbox.x1 = max(bbox.x1, title_bbox.x1) + width * xmod
        bbox.y0 = min(bbox.y0, title_bbox.y0) - height * ymod
        bbox.y1 = max(bbox.y1, title_bbox.y1) + height * ymod
        chart.ax.set_xlim(bbox.x0, bbox.x1)
        chart.ax.set_ylim(bbox.y0, bbox.y1)
        # Save the chart's own figure, not whichever figure pyplot currently
        # considers "current".
        fig.savefig(filename, dpi="figure", bbox_inches=bbox, transparent=transparent)
    finally:
        # Release the temporary figure so repeated calls do not accumulate
        # open pyplot figures.
        plt.close(fig)
class Categorical(Style):
    """
    A style for plotting categorical data.

    The legend style is always set to ``"disjoint"``. ``levels`` may be given
    as a ``{level: category}`` mapping, in which case its keys become the
    levels and its values the categories. If no ``categories`` are supplied
    otherwise, the levels themselves are used as categories.
    """

    def __init__(self, *args, **kwargs):
        kwargs["legend_style"] = "disjoint"
        requested_levels = kwargs.get("levels")
        if isinstance(requested_levels, dict):
            # Split the mapping into parallel tuples of levels and labels.
            level_keys, category_labels = zip(*requested_levels.items())
            kwargs["levels"] = level_keys
            kwargs["categories"] = category_labels
        kwargs.setdefault("categories", kwargs.get("levels"))
        super().__init__(*args, **kwargs)
class Vector(Style):
    """
    A style for plotting vector data (e.g. barbs).

    Parameters
    ----------
    colors : str or list or matplotlib.colors.Colormap, optional
        The colors to be used in this `Style`. This can be a named matplotlib
        colormap, a list of colors (as named CSS4 colors, hexadecimal colors or
        three (four)-element lists of RGB(A) values), or a pre-defined
        matplotlib colormap object. Defaults to ``None``, which is forwarded
        to `Style` unchanged.
    **kwargs
        Additional keyword arguments passed to the `Style` constructor. Unless
        ``legend_style`` is given, it defaults to ``"vector"``.
    """

    def __init__(self, *args, colors=None, **kwargs):
        if "legend_style" not in kwargs:
            kwargs["legend_style"] = "vector"
        super().__init__(*args, colors=colors, **kwargs)
class Quiver(Vector):
    """
    A `Vector` style whose preferred plotting method defaults to ``"quiver"``.

    Parameters
    ----------
    colors : str or list or matplotlib.colors.Colormap, optional
        Forwarded to `Vector`. Defaults to ``None``.
    preferred_method : str, optional
        The preferred plotting method. Defaults to ``"quiver"``.
    **kwargs
        Additional keyword arguments passed to `Vector`. Unless
        ``legend_style`` is given, it defaults to ``"vector"``.
    """

    def __init__(self, *args, colors=None, preferred_method="quiver", **kwargs):
        if "legend_style" not in kwargs:
            kwargs["legend_style"] = "vector"
        super().__init__(
            *args,
            colors=colors,
            preferred_method=preferred_method,
            **kwargs,
        )
class Contour(Style):
    """
    A style for plotting contour data.

    Parameters
    ----------
    colors : str or list or matplotlib.colors.Colormap, optional
        The colors to be used in this `Style`. A named matplotlib colormap or
        a `matplotlib.colors.Colormap` object is applied through a colormap
        and norm (see `to_contour_kwargs`). Any other value, such as a single
        named CSS4 or hexadecimal colour or a list of colours given as names,
        hexadecimal strings or three (four)-element RGB(A) sequences, is
        passed straight to matplotlib as ``colors``. Defaults to
        ``schema.default_cmap``, evaluated once when this class is defined.
    labels : bool, optional
        If `True`, then contour labels are added by `contour`.
    label_kwargs : dict, optional
        Keyword arguments passed to `contour_labels` (``label_fontsize``,
        ``label_colors``, ``label_frequency``, ``label_background``,
        ``label_fmt``).
    interpolate : bool, optional
        Used by `plot` when colors are set: if `True` the data are drawn with
        `contourf`, otherwise with `pcolormesh`.
    preferred_method : str, optional
        The preferred method for plotting the data. Must be one of `contour`,
        `contourf`, or `pcolormesh`.
    **kwargs
        Additional keyword arguments passed to the `Style` constructor.
        NOTE(review): these presumably end up in ``self._kwargs``, which
        `to_contour_kwargs` merges into the matplotlib arguments. Confirm this
        against `Style.__init__`.
    """

    def __init__(
        self,
        colors=schema.default_cmap,
        labels=False,
        label_kwargs=None,
        interpolate=True,
        preferred_method="contour",
        **kwargs,
    ):
        super().__init__(colors=colors, preferred_method=preferred_method, **kwargs)
        self.labels = labels
        self._label_kwargs = label_kwargs or dict()
        self._interpolate = interpolate

    def _get_config(self):
        """Extend the parent configuration with the contour-specific options."""
        import copy

        config = super()._get_config()
        config["labels"] = self.labels
        # Deep-copy so the returned config does not alias this style's dict.
        config["label_kwargs"] = copy.deepcopy(self._label_kwargs) if self._label_kwargs else None
        config["interpolate"] = self._interpolate
        return config

    def plot(self, *args, **kwargs):
        """
        Plot the data using the `Style`'s defaults.

        With colors set, draws shaded contours (`contourf`) when
        ``interpolate`` is `True` and a `pcolormesh` otherwise. Without
        colors, draws line contours (`contour`).
        """
        if self._colors is not None:
            if self._interpolate:
                return self.contourf(*args, **kwargs)
            else:
                return self.pcolormesh(*args, **kwargs)
        else:
            return self.contour(*args, **kwargs)

    def to_contour_kwargs(self, data):
        """
        Generate `contour` arguments required for plotting data in this `Style`.

        Parameters
        ----------
        data : numpy.ndarray
            The data to be plotted using this `Style`.

        Returns
        -------
        dict
            ``levels`` plus either ``cmap``/``norm`` (when the colors are a
            colormap and there is more than one level) or ``colors``, merged
            with this `Style`'s extra keyword arguments.
        """
        levels = self.levels(data)
        # Use cmap+norm path only when colors is a named colormap or Colormap object.
        # Plain colour strings (e.g. "black", "#ff0000") and lists of colours are
        # passed directly as colors= to ax.contour, matching matplotlib's own API.
        # (A None value is neither, so it also takes the ``colors=`` path.)
        colors_is_cmap = isinstance(self._colors, mpl.colors.Colormap) or (
            isinstance(self._colors, str) and self._colors in mpl.colormaps
        )
        # NOTE(review): with a colormap and a single level, the colormap itself
        # is passed as ``colors=`` below. Confirm matplotlib accepts that.
        if colors_is_cmap and len(levels) > 1:
            cmap, norm = styles.colors.cmap_and_norm(
                self._colors,
                levels,
                self.normalize,
                self.extend,
            )
            return {"cmap": cmap, "norm": norm, "levels": levels, **self._kwargs}
        return {"levels": levels, "colors": self._colors, **self._kwargs}

    def contourf(self, ax, x, y, values, *args, **kwargs):
        """
        Plot shaded contours using this `Style`.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            The axes on which to plot the data.
        x : numpy.ndarray
            The x coordinates of the data to be plotted.
        y : numpy.ndarray
            The y coordinates of the data to be plotted.
        values : numpy.ndarray
            The values of the data to be plotted.
        **kwargs
            Any additional arguments accepted by `matplotlib.axes.Axes.contourf`.
        """
        mappable = super().contourf(ax, x, y, values, *args, **kwargs)
        return mappable

    def contour(self, *args, **kwargs):
        """
        Plot line contours using this `Style`.

        If ``labels`` is `True`, the contours are labelled with
        `contour_labels` using ``label_kwargs``.

        Parameters
        ----------
        *args
            The positional arguments to pass to the `contour` method.
        **kwargs
            The keyword arguments to pass to the `contour` method.
        """
        mappable = super().contour(*args, **kwargs)
        if self.labels:
            self.contour_labels(mappable, **self._label_kwargs)
        return mappable

    def contour_labels(
        self,
        mappable,
        label_fontsize=7,
        label_colors=None,
        label_frequency=1,
        label_background=None,
        label_fmt=None,
    ):
        """
        Add labels to a contour plot.

        Parameters
        ----------
        mappable : matplotlib.contour.ContourSet
            The contour plot to which to add labels.
        label_fontsize : int, optional
            The fontsize of the labels.
        label_colors : str or list, optional
            The colors of the labels.
        label_frequency : int, optional
            Label every ``label_frequency``-th contour level, starting from the
            first.
        label_background : str, optional
            The background color of the labels.
        label_fmt : str, optional
            The string format of the labels.

        Returns
        -------
        list
            The label artists created by ``clabel``.
        """
        clabels = mappable.axes.clabel(
            mappable,
            mappable.levels[0::label_frequency],
            inline=True,
            fontsize=label_fontsize,
            colors=label_colors,
            fmt=label_fmt,
            inline_spacing=2,
        )
        if label_background is not None:
            for label in clabels:
                label.set_backgroundcolor(label_background)
        return clabels
class Hatched(Contour):
    """
    A style for plotting hatched contours.

    The colors handed to the `Contour` constructor are kept as the hatch
    (edge) colors, and this `Style`'s own colors are replaced by
    ``background_colors``.

    Parameters
    ----------
    *args
        The positional arguments to pass to the `Contour` constructor.
    hatches : str, optional
        The pattern of hatching to use, passed to ``contourf`` as
        ``hatches``. Defaults to ``"."``.
    background_colors : list, optional
        The colors to use for the background of the hatched contours.
        Defaults to a single fully transparent colour.
    **kwargs
        The keyword arguments to pass to the `Contour` constructor.
    """

    def __init__(self, *args, hatches=".", background_colors=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.hatches = hatches
        # Keep the requested colors for the hatch lines and swap in the
        # background colors as this style's colors.
        self._foreground_colors = self._colors
        self._colors = background_colors or [(0, 0, 0, 0)]

    def __eq__(self, other):
        # Use the shared helper so array-valued attributes (e.g. colors given
        # as numpy arrays) compare safely instead of raising ValueError.
        keys = ["_levels", "_colors", "_foreground_colors", "hatches"]
        return compare_attributes(self, other, keys)

    def contourf(self, *args, **kwargs):
        """
        Plot hatched shaded contours using this `Style`.

        The hatch edges take the foreground colors; all face colors are made
        fully transparent and the line width is set to zero.

        Parameters
        ----------
        *args
            The positional arguments to pass to the `contourf` method.
        **kwargs
            The keyword arguments to pass to the `contourf` method.
        """
        mappable = super().contourf(*args, hatches=self.hatches, **kwargs)
        linecolors = colors.expand(self._foreground_colors, mappable.levels)
        mappable.set_edgecolors(linecolors)
        mappable.set_facecolors([(0, 0, 0, 0)] * len(mappable.levels))
        mappable.set_linewidth(0)
        return mappable

    def colorbar(self, *args, **kwargs):
        """
        Create a colorbar legend for this `Style`.

        Parameters
        ----------
        *args
            The positional arguments to pass to the `colorbar` method.
        **kwargs
            The keyword arguments to pass to the `colorbar` method.

        Returns
        -------
        The colorbar with hatch edge colors applied, or ``None`` when the
        parent declines to create one.
        """
        cbar = super().colorbar(*args, **kwargs)
        if cbar is None:
            # The parent returns None for mappables that cannot back a colorbar.
            return None
        linecolors = colors.expand(self._foreground_colors, cbar.mappable.levels)
        for color, artist in zip(linecolors, cbar.solids_patches):
            artist.set_edgecolor(color)
        return cbar

    def disjoint(self, layer, *args, **kwargs):
        """
        Create a disjoint legend for this `Style`.

        Parameters
        ----------
        layer : earthkit.maps.charts.layers.Layer
            The layer for which to create a legend.
        *args
            The positional arguments to pass to the `disjoint` method.
        **kwargs
            The keyword arguments to pass to the `disjoint` method.
        """
        legend = super().disjoint(layer, *args, **kwargs)
        linecolors = colors.expand(self._foreground_colors, layer.mappable.levels)
        for color, artist in zip(linecolors, legend.get_patches()):
            artist.set_edgecolor(color)
            artist.set_linewidth(0.0)
        return legend
DEFAULT_STYLE = Style()
DEFAULT_VECTOR_STYLE = Vector()
# Positional-or-keyword parameter names of the Style and Contour constructors,
# de-duplicated. Sorted so the order does not depend on string hash
# randomisation between interpreter runs.
_STYLE_KWARGS = sorted(set(inspect.getfullargspec(Style).args + inspect.getfullargspec(Contour).args))
_OVERRIDE_KWARGS = ["labels"]
def compare_attributes(self, other, keys):
    """
    Compare the named attributes of two objects for equality.

    Attributes missing on either object are read as ``None``. Two numpy
    arrays are compared with `numpy.array_equal`; a numpy array never equals
    a non-array value; everything else uses ``==``.

    Parameters
    ----------
    self : object
        The first object.
    other : object
        The second object.
    keys : iterable of str
        The attribute names to compare.

    Returns
    -------
    bool
        ``True`` if every named attribute compares equal. ``False`` otherwise,
        including when a comparison raises ``ValueError`` (for example an
        ambiguous element-wise comparison of arrays nested in lists).
    """

    def _same(left, right):
        left_is_array = isinstance(left, np.ndarray)
        right_is_array = isinstance(right, np.ndarray)
        if left_is_array and right_is_array:
            return np.array_equal(left, right)
        if left_is_array or right_is_array:
            # Exactly one side is an array: never equal.
            return False
        return left == right

    try:
        for key in keys:
            if not _same(getattr(self, key, None), getattr(other, key, None)):
                return False
        return True
    except ValueError:
        return False
# Imported here (after DEFAULT_STYLE) to avoid circular imports, since auto.py
# imports from this module.
from earthkit.plots.styles.auto import list_styles, load_style # noqa: E402, F401
def get_style_class(method: str) -> type[Style]:
    """
    Look up the default `Style` subclass for a plotting method.

    Parameters
    ----------
    method : str
        Name of the plotting method, e.g. ``"quiver"``, ``"barbs"``,
        ``"contour"`` or ``"contourf"``.

    Returns
    -------
    type[Style]
        `Quiver` for ``"quiver"``, `Vector` for ``"barbs"``, `Contour` for
        ``"contour"`` and ``"contourf"``, and the base `Style` otherwise.
    """
    candidates = (
        (("quiver",), Quiver),
        (("barbs",), Vector),
        (("contour", "contourf"), Contour),
    )
    for method_names, style_class in candidates:
        if method in method_names:
            return style_class
    return Style