{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MultiPandasIndex\n", "\n", "In many cases an Xarray custom index may be built on top of one or more `PandasIndex` instances. This notebook provides a helper class `MultiPandasIndex` with all the boilerplate, i.e., for each method the input arguments are deferred / dispatched to the encapsulated `PandasIndex` instances." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from __future__ import annotations\n", "\n", "from typing import Any, TYPE_CHECKING, Mapping, Hashable, Iterable, Sequence\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import xarray as xr\n", "\n", "from xarray.core.indexes import Index, PandasIndex, IndexVars, is_scalar\n", "from xarray.core.indexing import IndexSelResult, merge_sel_results\n", "from xarray.core.utils import Frozen\n", "from xarray.core.variable import Variable\n", "\n", "#if TYPE_CHECKING:\n", "from xarray.core.types import ErrorOptions, JoinOptions, T_Index\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class MultiPandasIndex(Index):\n", " \"\"\"Helper class to implement meta-indexes encapsulating\n", " one or more (single) pandas indexes.\n", " \n", " Each pandas index must relate to a separate dimension.\n", " \n", " This class shoudn't be instantiated directly.\n", "\n", " \"\"\"\n", " indexes: Frozen[Hashable, PandasIndex]\n", " dims: Frozen[Hashable, int]\n", " \n", " __slots__ = (\"indexes\", \"dims\")\n", " \n", " def __init__(self, indexes: Mapping[Hashable, PandasIndex]):\n", " dims = {idx.dim: idx.index.size for idx in indexes.values()}\n", " \n", " seen = set()\n", " dup_dims = [d for d in dims if d in seen or seen.add(d)]\n", " if dup_dims:\n", " raise ValueError(\n", " f\"cannot create a {self.__class__.__name__} from coordinates \"\n", " f\"sharing common dimension(s): {dup_dims}\"\n", " )\n", " \n", " self.indexes = Frozen(indexes)\n", " self.dims = Frozen(dims)\n", " \n", " @classmethod\n", " def from_variables(\n", " cls: type[T_Index], variables: Mapping[Any, Variable], options\n", " ):\n", " indexes = {\n", " k: PandasIndex.from_variables({k: v}, options={})\n", " for k, v in variables.items()\n", " }\n", "\n", " return cls(indexes)\n", " \n", " @classmethod\n", " def concat(\n", " cls: type[T_Index],\n", " indexes: Sequence[T_Index],\n", " dim: Hashable,\n", " positions: Iterable[Iterable[int]] = None,\n", " ) -> T_Index:\n", " new_indexes = {}\n", " \n", " for k, idx in self.indexes.items():\n", " if idx.dim == dim:\n", " new_indexes[k] = PandasIndex.concat(indexes, dim, positions)\n", " else:\n", " new_indexes[k] = idx\n", " \n", " return cls(new_indexes)\n", " \n", " def create_variables(\n", " self, variables: Mapping[Any, Variable] | None = None\n", " ) -> IndexVars:\n", "\n", " idx_variables = {}\n", "\n", " for idx in self.indexes.values():\n", " idx_variables.update(idx.create_variables(variables))\n", "\n", " return idx_variables\n", " \n", " def isel(\n", " self: T_Index, indexers: Mapping[Any, int | slice | np.ndarray | Variable]\n", " ) -> T_Index | PandasIndex | None:\n", " new_indexes = {}\n", " \n", " for k, idx in self.indexes.items():\n", " if k in indexers:\n", " new_idx = idx.isel({k: indexers[k]})\n", " if new_idx is not None:\n", " new_indexes[k] = new_idx\n", " else:\n", " new_indexes[k] = idx\n", " \n", " #\n", " # How should we deal with dropped index(es) (scalar selection)?\n", " # - drop the whole index?\n", " # - always return a MultiPandasIndex with remaining index(es)?\n", " # - return either a MultiPandasIndex or a PandasIndex?\n", " #\n", " \n", " if not len(new_indexes):\n", " return None\n", " elif len(new_indexes) == 1:\n", " return next(iter(new_indexes.values()))\n", " else:\n", " return type(self)(new_indexes)\n", "\n", " def sel(self, labels: dict[Any, Any], **kwargs) -> IndexSelResult:\n", " results: list[IndexSelResult] = []\n", "\n", " for k, idx in self.indexes.items():\n", " if k in labels:\n", " results.append(idx.sel({k: labels[k]}, **kwargs))\n", " \n", " return merge_sel_results(results)\n", " \n", " def _get_unmatched_names(self: T_Index, other: T_Index) -> set:\n", " return set(self.indexes).symmetric_difference(other.indexes)\n", " \n", " def equals(self: T_Index, other: T_Index) -> bool:\n", " # We probably don't need to check for matching coordinate names\n", " # as this is already done during alignment when finding matching indexes.\n", " # This may change in the future, though.\n", " # see https://github.com/pydata/xarray/issues/7002\n", " if self._get_unmatched_names(other):\n", " return False\n", " else:\n", " return all(\n", " [idx.equals(other.indexes[k]) for k, idx in self.indexes.items()]\n", " )\n", " \n", " def join(self: T_Index, other: T_Index, how: JoinOptions = \"inner\") -> T_Index:\n", " new_indexes = {}\n", "\n", " for k, idx in self.indexes.items():\n", " new_indexes[k] = idx.join(other.indexes[k], how=how)\n", " \n", " return type(self)(new_indexes)\n", " \n", " def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]:\n", " dim_indexers = {}\n", " \n", " for k, idx in self.indexes.items():\n", " dim_indexers.update(idx.reindex_like(other.indexes[k]))\n", " \n", " return dim_indexers\n", " \n", " def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index:\n", " new_indexes = {}\n", " \n", " for k, idx in self.indexes.items():\n", " if k in shifts:\n", " new_indexes[k] = idx.roll({k: shifts[k]})\n", " else:\n", " new_indexes[k] = idx\n", "\n", " return type(self)(new_indexes)\n", " \n", " def rename(\n", " self: T_Index,\n", " name_dict: Mapping[Any, Hashable],\n", " dims_dict: Mapping[Any, Hashable],\n", " ) -> T_Index:\n", " new_indexes = {}\n", " \n", " for k, idx in self.indexes.items():\n", " new_indexes[k] = idx.rename(name_dict, dims_dict)\n", " \n", " return type(self)(new_indexes)\n", " \n", " def copy(self: T_Index, deep: bool = True) -> T_Index:\n", " new_indexes = {}\n", " \n", " for k, idx in self.indexes.items():\n", " new_indexes[k] = idx.copy(deep=deep)\n", " \n", " return type(self)(new_indexes)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Issues:\n", " \n", "- How to allow custom `__init__` options in subclasses be passed to all the `type(self)(new_indexes)` calls inside the `MultiPandasIndex` \"base\" class? This could be done via `**kwargs` passed through... However, mypy will certainly complain (Liskov Substitution Principle)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example\n", "\n", "Just to see if it works well. `MultiPandasIndex` shouldn't be used directly in a DataArray or Dataset." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "da = xr.DataArray(\n", " np.random.uniform(size=(4, 5)),\n", " coords={\"x\": range(5), \"y\": range(4)},\n", " dims=(\"y\", \"x\"),\n", ")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<xarray.DataArray (y: 4, x: 5)>\n",
"array([[0.43947939, 0.87899004, 0.76420298, 0.99212782, 0.83624422],\n",
" [0.75214201, 0.22178014, 0.0969697 , 0.74263207, 0.60629903],\n",
" [0.91366429, 0.25963693, 0.20251133, 0.50972423, 0.3037911 ],\n",
" [0.95073961, 0.31579758, 0.04704333, 0.81686866, 0.56483109]])\n",
"Coordinates:\n",
" * x (x) int64 0 1 2 3 4\n",
" * y (y) int64 0 1 2 3<xarray.DataArray (y: 2, x: 2)>\n",
"array([[0.43947939, 0.43947939],\n",
" [0.91366429, 0.91366429]])\n",
"Coordinates:\n",
" * x (x) int64 0 0\n",
" * y (y) int64 0 2<xarray.DataArray (y: 4)>\n",
"array([0.43947939, 0.75214201, 0.91366429, 0.95073961])\n",
"Coordinates:\n",
" * x int64 0\n",
" * y (y) int64 0 1 2 3<xarray.DataArray (y: 4, x: 5)>\n",
"array([[0.74263207, 0.60629903, 0.75214201, 0.22178014, 0.0969697 ],\n",
" [0.50972423, 0.3037911 , 0.91366429, 0.25963693, 0.20251133],\n",
" [0.81686866, 0.56483109, 0.95073961, 0.31579758, 0.04704333],\n",
" [0.99212782, 0.83624422, 0.43947939, 0.87899004, 0.76420298]])\n",
"Coordinates:\n",
" * x (x) int64 3 4 0 1 2\n",
" * y (y) int64 1 2 3 0<xarray.DataArray (variable: 1, y: 4, z: 5)>\n",
"array([[[0.43947939, 0.87899004, 0.76420298, 0.99212782, 0.83624422],\n",
" [0.75214201, 0.22178014, 0.0969697 , 0.74263207, 0.60629903],\n",
" [0.91366429, 0.25963693, 0.20251133, 0.50972423, 0.3037911 ],\n",
" [0.95073961, 0.31579758, 0.04704333, 0.81686866, 0.56483109]]])\n",
"Coordinates:\n",
" * z (z) int64 0 1 2 3 4\n",
" * y (y) int64 0 1 2 3\n",
" * variable (variable) object 'foo'