{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ArrayIndex\n",
    "\n",
    "An Xarray-compatible, cheap index for basic indexing and alignment operations, to use when a `pandas.Index` is not really needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import annotations\n",
    "\n",
    "from typing import Any, TYPE_CHECKING, Mapping, Hashable, Iterable, Sequence\n",
    "\n",
    "import numpy as np\n",
    "import xarray as xr\n",
    "\n",
    "from xarray.core.indexes import Index, PandasIndex, IndexVars, is_scalar\n",
    "from xarray.core.indexing import IndexSelResult\n",
    "from xarray.core import nputils\n",
    "from xarray.core.variable import Variable, IndexVariable\n",
    "\n",
    "#if TYPE_CHECKING:\n",
    "from xarray.core.types import JoinOptions, T_Index\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ArrayIndex(Index):\n",
    "    \"\"\"Numpy-like array index.\n",
    "\n",
    "    Lightweight, inefficient index as a basic wrapper around\n",
    "    its coordinate array data.\n",
    "\n",
    "    This index is suited for cases where index build overhead\n",
    "    is an issue and where only basic indexing operations are\n",
    "    needed (i.e., strict alignment, data selection in rare occasions).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    array : numpy.ndarray\n",
    "        1-dimensional array of coordinate labels.\n",
    "    dim : Hashable\n",
    "        Name of the dimension along which ``array`` lies.\n",
    "    name : Hashable\n",
    "        Name of the coordinate variable backing this index.\n",
    "    \"\"\"\n",
    "\n",
    "    array: np.ndarray\n",
    "    dim: Hashable\n",
    "    name: Hashable\n",
    "\n",
    "    # NOTE(review): ``__slots__ = (\"array\", \"dim\", \"name\")`` was tried and led\n",
    "    # to an AttributeError in the ``da + da`` example below; presumably xarray\n",
    "    # copies indexes through ``__dict__``. Confirm before re-enabling slots.\n",
    "\n",
    "    def __init__(self, array, dim, name):\n",
    "        # ``dim`` names a single dimension, so only 1-d arrays are accepted\n",
    "        # (0-d arrays are rejected too, as the error message states).\n",
    "        if array.ndim != 1:\n",
    "            raise ValueError(\"ArrayIndex only accepts 1-dimensional arrays\")\n",
    "\n",
    "        self.array = array\n",
    "        self.dim = dim\n",
    "        self.name = name\n",
    "\n",
    "    @classmethod\n",
    "    def from_variables(\n",
    "        cls: type[T_Index], variables: Mapping[Any, Variable], options\n",
    "    ):\n",
    "        \"\"\"Create an index from a mapping holding exactly one 1-d variable.\n",
    "\n",
    "        ``options`` is accepted for interface compatibility and is ignored.\n",
    "        \"\"\"\n",
    "        if len(variables) != 1:\n",
    "            raise ValueError(\n",
    "                f\"{cls.__name__} only accepts one variable, found {len(variables)} variables\"\n",
    "            )\n",
    "\n",
    "        name, var = next(iter(variables.items()))\n",
    "\n",
    "        # Check before ``var.dims[0]``, which would raise a bare IndexError\n",
    "        # for a 0-d variable.\n",
    "        if var.ndim != 1:\n",
    "            raise ValueError(\n",
    "                f\"{cls.__name__} only accepts a 1-dimensional variable, \"\n",
    "                f\"variable {name!r} has {var.ndim} dimensions\"\n",
    "            )\n",
    "\n",
    "        # TODO: use `var.data` instead? (allow lazy/duck arrays)\n",
    "        return cls(var.values, var.dims[0], name)\n",
    "\n",
    "    @classmethod\n",
    "    def concat(\n",
    "        cls: type[T_Index],\n",
    "        indexes: Sequence[T_Index],\n",
    "        dim: Hashable,\n",
    "        positions: Iterable[Iterable[int]] | None = None,\n",
    "    ) -> T_Index:\n",
    "        \"\"\"Concatenate ``indexes`` (all along ``dim``) into a new index.\n",
    "\n",
    "        If ``positions`` is given, the concatenated labels are reordered with\n",
    "        the inverse permutation of the concatenated ``positions``. An empty\n",
    "        ``indexes`` yields an empty index named after ``dim``.\n",
    "\n",
    "        Raises ValueError if any index lies along another dimension.\n",
    "        \"\"\"\n",
    "        if not indexes:\n",
    "            return cls(np.array([]), dim, dim)\n",
    "\n",
    "        if not all(idx.dim == dim for idx in indexes):\n",
    "            # sorted so that the error message is deterministic\n",
    "            dims = \",\".join(sorted({f\"{idx.dim!r}\" for idx in indexes}))\n",
    "            raise ValueError(\n",
    "                f\"Cannot concatenate along dimension {dim!r} indexes with \"\n",
    "                f\"dimensions: {dims}\"\n",
    "            )\n",
    "\n",
    "        new_array = np.concatenate([idx.array for idx in indexes])\n",
    "\n",
    "        if positions is not None:\n",
    "            # list() so that a one-shot outer iterable (e.g. a generator) works\n",
    "            indices = nputils.inverse_permutation(np.concatenate(list(positions)))\n",
    "            new_array = new_array.take(indices)\n",
    "\n",
    "        return cls(new_array, dim, indexes[0].name)\n",
    "\n",
    "    def create_variables(\n",
    "        self, variables: Mapping[Any, Variable] | None = None\n",
    "    ) -> IndexVars:\n",
    "        \"\"\"Return ``{self.name: Variable}`` wrapping this index's array.\n",
    "\n",
    "        ``attrs`` and ``encoding`` are taken from ``variables[self.name]``\n",
    "        when that variable is provided.\n",
    "        \"\"\"\n",
    "        # TODO: implementation is needed so that the corresponding\n",
    "        # coordinate is indexed properly with Dataset.isel.\n",
    "        # Ideally this shouldn't be needed, though.\n",
    "        attrs = None\n",
    "        encoding = None\n",
    "        if variables is not None and self.name in variables:\n",
    "            source = variables[self.name]\n",
    "            attrs = source.attrs\n",
    "            encoding = source.encoding\n",
    "\n",
    "        var = Variable(self.dim, self.array, attrs=attrs, encoding=encoding)\n",
    "        return {self.name: var}\n",
    "\n",
    "    def isel(\n",
    "        self: T_Index, indexers: Mapping[Any, int | slice | np.ndarray | Variable]\n",
    "    ) -> T_Index | PandasIndex | None:\n",
    "        \"\"\"Apply the positional indexer for ``self.dim`` to this index.\n",
    "\n",
    "        Returns ``None`` (the index is dropped) for scalar indexers and for\n",
    "        ``Variable`` indexers whose dimensions are not exactly ``(self.dim,)``.\n",
    "        \"\"\"\n",
    "        indxr = indexers[self.dim]\n",
    "\n",
    "        if isinstance(indxr, Variable):\n",
    "            if indxr.dims != (self.dim,):\n",
    "                # can't preserve an index if the result has new dimensions\n",
    "                return None\n",
    "            indxr = indxr.data\n",
    "\n",
    "        if not isinstance(indxr, slice) and is_scalar(indxr):\n",
    "            # scalar indexer: drop index\n",
    "            return None\n",
    "\n",
    "        return type(self)(self.array[indxr], self.dim, self.name)\n",
    "\n",
    "    def sel(self, labels: dict[Any, Any], **kwargs) -> IndexSelResult:\n",
    "        \"\"\"Label-based selection by exact match against ``self.array``.\n",
    "\n",
    "        - A scalar label selects the position of its first occurrence.\n",
    "        - A slice of labels is inclusive of both ``start`` and ``stop``\n",
    "          (the xarray convention for label-based slicing). A ``None`` bound\n",
    "          means \"from the first\" / \"up to the last\" element, and a positive\n",
    "          ``step`` is applied positionally.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        KeyError\n",
    "            If a label or slice bound is not found in the index.\n",
    "        ValueError\n",
    "            If the label type or a negative slice step is not supported.\n",
    "        NotImplementedError\n",
    "            If inexact matching (``method`` / ``tolerance``) is requested.\n",
    "        \"\"\"\n",
    "        assert len(labels) == 1\n",
    "        _, label = next(iter(labels.items()))\n",
    "\n",
    "        # Only exact matches are implemented; ignoring these options would\n",
    "        # silently return wrong positions.\n",
    "        if kwargs.get(\"method\") is not None or kwargs.get(\"tolerance\") is not None:\n",
    "            raise NotImplementedError(\"ArrayIndex only supports exact label matching\")\n",
    "\n",
    "        if isinstance(label, slice):\n",
    "            if label.step is not None and label.step < 0:\n",
    "                raise ValueError(\"negative slice step not supported by ArrayIndex\")\n",
    "            start = 0 if label.start is None else self._first_position(label.start)\n",
    "            if label.stop is None:\n",
    "                stop = len(self.array)\n",
    "            else:\n",
    "                # +1: the stop label itself is part of the selection\n",
    "                stop = self._first_position(label.stop) + 1\n",
    "            indexer = slice(start, stop, label.step)\n",
    "        elif is_scalar(label):\n",
    "            indexer = self._first_position(label)\n",
    "        else:\n",
    "            # TODO: other label types we want to support (n-d array-like, etc.)\n",
    "            raise ValueError(f\"label {label} not supported by ArrayIndex\")\n",
    "\n",
    "        return IndexSelResult({self.dim: indexer})\n",
    "\n",
    "    def _first_position(self, label: Any) -> int:\n",
    "        \"\"\"Return the position of the first element of ``self.array`` equal to ``label``.\n",
    "\n",
    "        Raises KeyError when nothing matches (``np.argmax`` on an all-False\n",
    "        mask would silently return 0 instead).\n",
    "        \"\"\"\n",
    "        matches = np.flatnonzero(self.array == label)\n",
    "        if matches.size == 0:\n",
    "            raise KeyError(f\"{label!r} not found in ArrayIndex {self.name!r}\")\n",
    "        return int(matches[0])\n",
    "\n",
    "    def equals(self: T_Index, other: T_Index) -> bool:\n",
    "        \"\"\"Return True if ``other`` is an ArrayIndex along the same dimension\n",
    "        with element-wise equal labels.\"\"\"\n",
    "        if not isinstance(other, ArrayIndex):\n",
    "            return False\n",
    "        return self.dim == other.dim and np.array_equal(self.array, other.array)\n",
    "\n",
    "    def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index:\n",
    "        \"\"\"Return a new index with labels rolled by ``shifts[self.dim]`` positions.\"\"\"\n",
    "        return type(self)(np.roll(self.array, shifts[self.dim]), self.dim, self.name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "da = xr.DataArray(\n",
    "    np.random.uniform(size=4),\n",
    "    coords={\"x\": [2, 3, 4, 5]},\n",
    "    dims=\"x\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "
<xarray.DataArray (x: 4)>\n",
"array([0.5972036 , 0.36977134, 0.23513491, 0.61414618])\n",
"Coordinates:\n",
" * x (x) int64 2 3 4 5<xarray.DataArray ()>\n",
"array(0.5972036)\n",
"Coordinates:\n",
" x int64 2<xarray.DataArray (x: 2)>\n",
"array([0.5972036 , 0.36977134])\n",
"Coordinates:\n",
" * x (x) int64 2 3<xarray.DataArray (x: 8)>\n",
"array([0.5972036 , 0.36977134, 0.23513491, 0.61414618, 0.5972036 ,\n",
" 0.36977134, 0.23513491, 0.61414618])\n",
"Coordinates:\n",
" * x (x) int64 2 3 4 5 2 3 4 5<xarray.DataArray (x: 4)>\n",
"array([0.23513491, 0.61414618, 0.5972036 , 0.36977134])\n",
"Coordinates:\n",
" * x (x) int64 4 5 2 3<xarray.DataArray (x: 4)>\n",
"array([1.1944072 , 0.73954268, 0.47026983, 1.22829237])\n",
"Coordinates:\n",
" * x (x) int64 2 3 4 5