# Copyright 2024-, European Centre for Medium Range Weather Forecasts.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Sequence
from pathlib import Path
from typing import Any
import yaml
from earthkit.plots import styles
from earthkit.plots._plugins import PLUGINS
from earthkit.plots.metadata.units import are_equal
from earthkit.plots.schemas import schema
# Type alias for criteria mappings: each key maps either to a single value or
# to a sequence of values.
METADATA = dict[str, Any | Sequence[Any]]
# Colormaps cycled through when no auto-style is found for a variable.
# Variables are assigned a cmap in the order they are first encountered.
_FALLBACK_CMAPS = ["plasma", "viridis", "pink", "copper"]
# variable name -> colormap name picked for it from _FALLBACK_CMAPS.
_fallback_cmap_assignments: dict[str, str] = {}
# One Style instance per variable name so that the same variable always maps to
# the same object (legend deduplication uses ==), while two different variables
# that happen to share a cmap are still distinct objects.
_fallback_style_cache: dict[str, "styles.Style"] = {}
# Holds the identity-comparing Style subclass once it has been created by
# _get_variable_fallback_style_class(); None until then.
_VariableFallbackStyle = None  # built lazily after styles.Style is available
def clear_fallback_cache():
    """Reset every cache that influences automatic style selection.

    Empties the per-variable fallback colormap assignments and the
    per-variable fallback style instances, then invalidates the YAML
    style-library cache so that it is re-read from disk on next use.

    Call this after adding new identity/style YAML files so that variables
    which were previously given a fallback colormap are evaluated again
    against the updated style library.
    """
    for registry in (_fallback_cmap_assignments, _fallback_style_cache):
        registry.clear()
    _cache.invalidate()
def _get_variable_fallback_style_class():
    """Return the fallback ``Style`` subclass, creating it on first use.

    The class is built lazily and stored in the module-level
    ``_VariableFallbackStyle`` so that every later call returns the very same
    class object. Construction is deferred because ``styles.Style`` must be
    resolvable at the time the subclass is defined (the original author notes
    a circular import between this module and ``styles/__init__.py``).
    """
    global _VariableFallbackStyle
    if _VariableFallbackStyle is not None:
        return _VariableFallbackStyle

    class _VFS(styles.Style):
        """Style whose equality and hash are based purely on object identity.

        Two variables that share a colormap therefore never compare equal and
        are never merged into a single legend entry.
        """

        def __hash__(self):
            return id(self)

        def __eq__(self, other):
            return other is self

    _VariableFallbackStyle = _VFS
    return _VFS
# Maps a criteria key to the metadata keys that criteria_matches() tries, in
# order, when reading that key from a data object; the first alias that yields
# a non-None value is used. Keys not listed here are looked up under their
# own name only.
_METADATA_KEY_ALIASES = {
    "long_name": ["long_name", "parameter.long_name"],
    "short_name": ["short_name", "parameter.variable"],
    "shortName": ["shortName", "parameter.variable"],
    "standard_name": ["standard_name", "parameter.standard_name"],
    "paramId": ["paramId", "parameter.id"],
}
def criteria_matches(data, criteria: METADATA) -> bool:
    """Test whether the metadata of *data* satisfies every entry of *criteria*.

    Parameters
    ----------
    data :
        Object exposing ``metadata(key, default)``.
    criteria : dict
        Mapping of metadata key to the required value. A key listed in
        ``_METADATA_KEY_ALIASES`` is looked up under each of its aliases in
        turn, and the first non-``None`` value is used.

    Returns
    -------
    bool
        ``True`` if every criterion is met (vacuously ``True`` for empty
        criteria); ``False`` as soon as one metadata value is missing or
        differs.

    Notes
    -----
    When both the required and the actual value are non-string collections
    they are compared as sets, so order and duplicates are ignored. Strings
    and bytes are compared by plain equality: they are iterable too, and
    comparing them as sets of characters would make ``"ab"`` match ``"ba"``.
    """
    for key, expected in criteria.items():
        actual = None
        for alias in _METADATA_KEY_ALIASES.get(key, [key]):
            actual = data.metadata(alias, None)
            if actual is not None:
                break
        if actual is None:
            return False

        # str/bytes are Iterable but must be treated as scalar values here.
        both_collections = all(
            isinstance(item, Iterable) and not isinstance(item, (str, bytes))
            for item in (actual, expected)
        )
        if both_collections:
            if set(expected) != set(actual):
                return False
        elif expected != actual:
            return False
    return True
def _fallback_style(data):
    """
    Return a fallback :class:`~earthkit.plots.styles.Style` for *data*.

    The variable name is taken from the first of the metadata keys ``name``,
    ``short_name`` and ``param`` that yields a truthy value; any exception
    raised while probing a key is ignored and the next key is tried.

    * If no name can be determined, ``styles.DEFAULT_STYLE`` is returned.
    * Otherwise the variable is assigned a colormap from ``_FALLBACK_CMAPS``
      in the order variables are first seen, cycling once the list is
      exhausted, and a style built from that colormap is cached and returned.

    Each variable gets its own cached instance of the identity-comparing
    fallback style class, so two variables that happen to share a colormap
    are never treated as equal (and hence never merged into one legend).
    """
    var_name = None
    for key in ("name", "short_name", "param"):
        try:
            candidate = data.metadata(key, None)
            if not candidate:
                continue
            var_name = str(candidate)
        except Exception:
            # Best-effort probing: a failing key simply means "try the next".
            continue
        break

    if var_name is None:
        return styles.DEFAULT_STYLE

    try:
        return _fallback_style_cache[var_name]
    except KeyError:
        pass

    if var_name not in _fallback_cmap_assignments:
        # The slot is derived from how many variables already have a cmap.
        slot = len(_fallback_cmap_assignments) % len(_FALLBACK_CMAPS)
        _fallback_cmap_assignments[var_name] = _FALLBACK_CMAPS[slot]

    style_cls = _get_variable_fallback_style_class()
    style = style_cls(colors=_fallback_cmap_assignments[var_name])
    _fallback_style_cache[var_name] = style
    return style
# ---------------------------------------------------------------------------
# Style library cache
# ---------------------------------------------------------------------------
class _StyleLibraryCache:
"""
Lazy, invalidatable cache for the YAML-based style library.
Motivation
----------
``guess_style()`` and ``load_style()`` previously scanned all identity and
auto-style YAML files from disk on **every call**. With ~100 files in each
directory this adds measurable latency in notebooks where the same variable
is plotted many times.
This cache loads each plugin's YAML files exactly once per Python session
(or after an explicit :meth:`invalidate` call) and exposes fast in-memory
lookup methods.
Thread safety
-------------
The cache is populated in a single-threaded context (notebook / script) and
is never written to after ``_load()`` completes, so no locking is required.
"""
def __init__(self):
# Keyed by plugin name; each value is the paths dict from PLUGINS.
self._loaded_plugin: str | None = None
# List of (criteria_list, identity_id) pairs — order preserved so that
# the first match wins, exactly as the original glob loop did.
self._identities: list[tuple[list[dict], str]] = []
# identity_id → full style_config dict (contains "styles", "optimal", …)
self._style_configs: dict[str, dict] = {}
# name → style_dict for named styles (across ALL plugins, dedup by path).
# Loaded once independently of the active plugin (all plugins contribute).
self._named_styles: dict[str, dict] = {}
self._named_styles_loaded: bool = False
# ------------------------------------------------------------------
# Public interface
# ------------------------------------------------------------------
def invalidate(self):
"""Discard all cached data so the next access reloads from disk."""
self._loaded_plugin = None
self._identities.clear()
self._style_configs.clear()
self._named_styles.clear()
self._named_styles_loaded = False
def find_identity(self, data) -> str | None:
"""Return the identity id whose criteria first match *data*, or ``None``."""
self._ensure_loaded()
for criteria_list, identity_id in self._identities:
if any(criteria_matches(data, c) for c in criteria_list):
return identity_id
return None
def get_style_config(self, identity_id: str) -> dict | None:
"""Return the full style config dict for *identity_id*, or ``None``."""
self._ensure_loaded()
return self._style_configs.get(identity_id)
def get_named_style(self, name: str) -> dict | None:
"""Return the raw style dict for the given *name*, or ``None``."""
self._ensure_loaded_named_styles()
return self._named_styles.get(name)
def list_named_styles(self) -> list[str]:
"""Return a sorted list of all known named-style names."""
self._ensure_loaded_named_styles()
return sorted(self._named_styles)
# ------------------------------------------------------------------
# Internal loading
# ------------------------------------------------------------------
def _current_plugin_key(self) -> str:
"""Derive a cache key from the active style_library setting."""
return str(schema.style_library)
def _resolve_plugin_paths(self) -> tuple[Path, Path]:
"""Return (identities_path, styles_path) for the active plugin."""
if schema.style_library not in PLUGINS:
path = Path(schema.style_library).expanduser()
return path / "identities", path / "auto-styles"
plugin = PLUGINS[schema.style_library]
return plugin["identities"], plugin["styles"]
def _ensure_loaded(self):
"""Load identity + style-config data if the active plugin has changed."""
key = self._current_plugin_key()
if self._loaded_plugin == key:
return
self._identities.clear()
self._style_configs.clear()
identities_path, styles_path = self._resolve_plugin_paths()
self._load_identities(identities_path)
self._load_style_configs(styles_path)
self._loaded_plugin = key
def _ensure_loaded_named_styles(self):
"""Load named-style index across ALL plugins (done once per session)."""
if self._named_styles_loaded:
return
self._load_named_styles()
self._named_styles_loaded = True
def _load_identities(self, identities_path: Path):
if identities_path is None or not identities_path.is_dir():
return
for fpath in sorted(identities_path.iterdir()):
if not fpath.is_file():
continue
with fpath.open() as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
self._identities.append((config["criteria"], config["id"]))
def _load_style_configs(self, styles_path: Path):
if styles_path is None or not styles_path.is_dir():
return
for fpath in styles_path.iterdir():
if not fpath.is_file():
continue
with fpath.open() as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
self._style_configs[config["id"]] = config
def _load_named_styles(self):
"""Index named-style variants from every registered plugin directory."""
seen_paths: set[str] = set()
for plugin_paths in PLUGINS.values():
styles_path = plugin_paths["styles"]
if styles_path is None or not styles_path.is_dir():
continue
for fpath in sorted(styles_path.iterdir()):
fpath_str = str(fpath)
if not fpath.is_file() or fpath_str in seen_paths:
continue
seen_paths.add(fpath_str)
with fpath.open() as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
for style_dict in config.get("styles", {}).values():
name = style_dict.get("name")
if name and name not in self._named_styles:
self._named_styles[name] = style_dict
def _load_named_styles_from(self, styles_path: Path, seen_paths: set[str] | None = None):
"""
Index named-style variants from a single *styles_path* directory.
This is a test helper — production code goes through :meth:`_load_named_styles`
which handles all plugins. Tests call this directly to load from an
isolated ``tmp_path`` directory without touching the real PLUGINS registry.
Parameters
----------
styles_path:
Directory containing auto-style YAML files.
seen_paths:
Optional deduplication set shared across multiple calls (e.g. when
a test simulates multiple plugin directories).
"""
if styles_path is None or not styles_path.is_dir():
return
for fpath in sorted(styles_path.iterdir()):
fpath_str = str(fpath)
if not fpath.is_file():
continue
if seen_paths is not None:
if fpath_str in seen_paths:
continue
seen_paths.add(fpath_str)
with fpath.open() as f:
config = yaml.load(f, Loader=yaml.SafeLoader)
for style_dict in config.get("styles", {}).values():
name = style_dict.get("name")
if name and name not in self._named_styles:
self._named_styles[name] = style_dict
# Module-level singleton: the one cache instance used by clear_fallback_cache(),
# guess_style(), load_style() and list_styles(), and therefore shared by all
# callers in the same process.
_cache = _StyleLibraryCache()
def _select_style_variant(style_variants: dict, target_units: str | None, source_units: str | None) -> dict | None:
    """
    Choose the style variant that best fits the given units.

    Variants are considered in the mapping's iteration order, and the first
    hit of the highest-priority rule wins:

    1. a variant whose ``units`` equal *target_units*;
    2. a variant whose ``units`` equal *source_units*;
    3. a variant without a ``units`` key (unit-agnostic).

    Unit equality is decided by ``are_equal``.

    Parameters
    ----------
    style_variants:
        Mapping of variant key to style dict, as stored in the YAML
        ``styles:`` block.
    target_units:
        The units the caller wants to plot in (may be ``None``).
    source_units:
        The native units of the data (may be ``None``).

    Returns
    -------
    dict or None
        The selected variant, or ``None`` when no rule matched and the caller
        should fall back to a default style.
    """
    variants = tuple(style_variants.values())

    # Rules 1 and 2: try the target units first, then the source units.
    for wanted in (target_units, source_units):
        for variant in variants:
            if are_equal(variant.get("units"), wanted):
                return variant

    # Rule 3: first unit-agnostic variant, if there is one.
    return next((variant for variant in variants if "units" not in variant), None)
def guess_style(data, units=None, **kwargs):
    """
    Guess the style to be applied to the data based on its metadata.

    The metadata of *data* is compared against the identities of the active
    style library; the first matching identity selects a style config, from
    which one variant is chosen (the config's ``optimal`` variant when
    ``schema.use_preferred_units`` is set, otherwise the variant best matching
    the units).

    Parameters
    ----------
    data : earthkit.plots.sources.Source
        The data object containing the metadata. Its ``source_units``
        attribute is read unconditionally.
    units : str, optional
        The target units of the plot. Defaults to the data's source units.
    **kwargs
        Extra style parameters; they override the values of the selected
        variant.

    Returns
    -------
    earthkit.plots.styles.Style
        * ``styles.DEFAULT_STYLE`` if automatic styles are disabled or no
          style library is configured;
        * a per-variable fallback style if no identity, style config or
          variant matches;
        * otherwise a style built from the selected variant merged with
          *kwargs*.
    """
    native_units = data.source_units
    target_units = native_units if units is None else units

    if not (schema.automatic_styles and schema.style_library is not None):
        return styles.DEFAULT_STYLE

    identity = _cache.find_identity(data)
    style_config = None if identity is None else _cache.get_style_config(identity)
    if style_config is None:
        return _fallback_style(data)

    variants = style_config["styles"]
    if schema.use_preferred_units:
        chosen = variants[style_config["optimal"]]
    else:
        chosen = _select_style_variant(variants, target_units, native_units)
    if chosen is None:
        return _fallback_style(data)

    # When the units to plot in differ from the variant's own units, pass them
    # on (unless the caller already supplied ``units`` via kwargs) so that the
    # resulting style carries the plotted units.
    # NOTE(review): target_units falls back to the source units, so with
    # use_preferred_units this also overrides the preferred variant's units
    # when the caller gave no units -- confirm this is intended.
    if target_units is not None and not are_equal(target_units, chosen.get("units")):
        kwargs.setdefault("units", target_units)

    return styles.Style.from_dict({**chosen, **kwargs})
def load_style(name, **kwargs):
    """
    Load a named style by its user-facing name.

    Style names are defined in the ``name`` field of each style variant in the
    auto-styles YAML files (e.g. ``temperature-2m-turbo-celsius``). The full
    list of available names can be retrieved with :func:`list_styles`.

    Parameters
    ----------
    name : str
        The name of the style to load, as shown in the styles gallery.
    **kwargs
        Additional style parameters; they override the values stored for the
        named style before the ``Style`` is built.

    Returns
    -------
    earthkit.plots.styles.Style
        The instantiated style object.

    Raises
    ------
    KeyError
        If no style with the given name is found in any registered style
        library. The message lists the available names.

    Examples
    --------
    >>> import earthkit.plots
    >>> style = earthkit.plots.styles.load_style("temperature-2m-turbo-celsius")
    >>> chart.contourf(data, style=style)
    """
    style_dict = _cache.get_named_style(name)
    if style_dict is None:
        raise KeyError(f"No style named {name!r}. Available styles: {list_styles()}")
    return styles.Style.from_dict({**style_dict, **kwargs})
def list_styles() -> list[str]:
    """
    Return a sorted list of all available named style names.

    These names can be passed to :func:`load_style` or used directly as the
    ``style`` parameter in any plotting method.

    Returns
    -------
    list of str
        The names known to the module-level style cache, sorted
        alphabetically.

    Examples
    --------
    >>> import earthkit.plots
    >>> earthkit.plots.list_styles()
    ['mslp-contour-hpa', 'mslp-contour-pa', 'precipitation-turbo-kg-m2', ...]
    """
    return _cache.list_named_styles()