{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ArrayIndex\n", "\n", "An Xarray-compatible, cheap index for basic indexing and alignment operations, to use when a `pandas.Index` is not really needed." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from __future__ import annotations\n", "\n", "from typing import Any, TYPE_CHECKING, Mapping, Hashable, Iterable, Sequence\n", "\n", "import numpy as np\n", "import xarray as xr\n", "\n", "from xarray.core.indexes import Index, PandasIndex, IndexVars, is_scalar\n", "from xarray.core.indexing import IndexSelResult\n", "from xarray.core import nputils\n", "from xarray.core.variable import Variable, IndexVariable\n", "\n", "#if TYPE_CHECKING:\n", "from xarray.core.types import JoinOptions, T_Index\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class ArrayIndex(Index):\n", " \"\"\"Numpy-like array index.\n", " \n", " Lightweight, inefficient index as a basic wrapper around\n", " its coordinate array data.\n", " \n", " This index is suited for cases where index build overhead\n", " is an issue and where only basic indexing operations are\n", " needed (i.e., strict alignment, data selection in rare occasions).\n", " \n", " \"\"\"\n", " array: np.ndarray\n", " dim: Hashable\n", " name: Hashable\n", " \n", " # cause AttributeError with `da + da` example below?????\n", " #__slots__ = (\"array\", \"dim\", \"name\")\n", " \n", " def __init__(self, array, dim, name):\n", " if array.ndim > 1:\n", " raise ValueError(\"ArrayIndex only accepts 1-dimensional arrays\")\n", "\n", " self.array = array\n", " self.dim = dim\n", " self.name = name\n", " \n", " @classmethod\n", " def from_variables(\n", " cls: type[T_Index], variables: Mapping[Any, Variable], options\n", " ):\n", " if len(variables) != 1:\n", " raise ValueError(\n", " f\"PandasIndex only accepts one variable, found {len(variables)} variables\"\n", " )\n", "\n", " name, var = next(iter(variables.items()))\n", " \n", " # TODO: use `var.data` instead? (allow lazy/duck arrays)\n", " return cls(var.values, var.dims[0], name)\n", " \n", " @classmethod\n", " def concat(\n", " cls: type[T_Index],\n", " indexes: Sequence[T_Index],\n", " dim: Hashable,\n", " positions: Iterable[Iterable[int]] = None,\n", " ) -> T_Index:\n", " if not indexes:\n", " return cls(np.array([]), dim, dim)\n", " \n", " if not all(idx.dim == dim for idx in indexes):\n", " dims = \",\".join({f\"{idx.dim!r}\" for idx in indexes})\n", " raise ValueError(\n", " f\"Cannot concatenate along dimension {dim!r} indexes with \"\n", " f\"dimensions: {dims}\"\n", " )\n", " \n", " arrays = [idx.array for idx in indexes]\n", " new_array = np.concatenate(arrays)\n", " \n", " if positions is not None:\n", " indices = nputils.inverse_permutation(np.concatenate(positions))\n", " new_array = new_array.take(indices)\n", "\n", " return cls(new_array, dim, indexes[0].name)\n", " \n", " def create_variables(\n", " self, variables: Mapping[Any, Variable] | None = None\n", " ) -> IndexVars:\n", " \n", " #\n", " # TODO: implementation is needed so that the corresponding\n", " # coordinate is indexed properly with Dataset.isel.\n", " # Ideally this shouldn't be needed, though.\n", " #\n", " \n", " if variables is not None and self.name in variables:\n", " var = variables[self.name]\n", " attrs = var.attrs\n", " encoding = var.encoding\n", " else:\n", " attrs = None\n", " encoding = None\n", "\n", " var = Variable(self.dim, self.array, attrs=attrs, encoding=encoding)\n", " return {self.name: var}\n", " \n", " def isel(\n", " self: T_Index, indexers: Mapping[Any, int | slice | np.ndarray | Variable]\n", " ) -> T_Index | PandasIndex | None:\n", " indxr = indexers[self.dim]\n", "\n", " if isinstance(indxr, Variable):\n", " if indxr.dims != (self.dim,):\n", " # can't preserve a index if result has new dimensions\n", " return None\n", " else:\n", " indxr = indxr.data\n", " if not isinstance(indxr, slice) and is_scalar(indxr):\n", " # scalar indexer: drop index\n", " return None\n", " \n", " return type(self)(self.array[indxr], self.dim, self.name)\n", "\n", " def sel(self, labels: dict[Any, Any], **kwargs) -> IndexSelResult:\n", " assert len(labels) == 1\n", " _, label = next(iter(labels.items()))\n", " \n", " if isinstance(label, slice):\n", " # TODO: what exactly do we want to do here?\n", " start = np.argmax(self.array == label.start)\n", " stop = np.argmax(self.array == label.stop)\n", " indexer = slice(start, stop)\n", " elif is_scalar(label):\n", " indexer = np.argmax(self.array == label)\n", " else:\n", " # TODO: other label types we want to support (n-d array-like, etc.)\n", " raise ValueError(f\"label {label} not supported by ArrayIndex\")\n", " \n", " return IndexSelResult({self.dim: indexer})\n", " \n", " def equals(self: T_Index, other: T_Index) -> bool:\n", " return np.array_equal(self.array, other.array)\n", " \n", " def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index:\n", " shift = shifts[self.dim]\n", " \n", " return type(self)(np.roll(self.array, shift), self.dim, self.name)\n", "\n", " " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Example" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "da = xr.DataArray(\n", " np.random.uniform(size=4),\n", " coords={\"x\": [2, 3, 4, 5]},\n", " dims=\"x\",\n", ")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (x: 4)>\n",
       "array([0.5972036 , 0.36977134, 0.23513491, 0.61414618])\n",
       "Coordinates:\n",
       "  * x        (x) int64 2 3 4 5
" ], "text/plain": [ "\n", "array([0.5972036 , 0.36977134, 0.23513491, 0.61414618])\n", "Coordinates:\n", " * x (x) int64 2 3 4 5" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da = da.drop_indexes(\"x\").set_xindex(\"x\", ArrayIndex)\n", "\n", "da" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Indexes:\n", "x: <__main__.ArrayIndex object at 0x16c28ee80>" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da.xindexes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## sel / isel" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray ()>\n",
       "array(0.5972036)\n",
       "Coordinates:\n",
       "    x        int64 2
" ], "text/plain": [ "\n", "array(0.5972036)\n", "Coordinates:\n", " x int64 2" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da_sel = da.sel(x=2)\n", "\n", "da_sel" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (x: 2)>\n",
       "array([0.5972036 , 0.36977134])\n",
       "Coordinates:\n",
       "  * x        (x) int64 2 3
" ], "text/plain": [ "\n", "array([0.5972036 , 0.36977134])\n", "Coordinates:\n", " * x (x) int64 2 3" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da_sel = da.sel(x=slice(2, 4))\n", "\n", "da_sel" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Indexes:\n", "x: <__main__.ArrayIndex object at 0x16c294040>" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da_sel.xindexes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## concat" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (x: 8)>\n",
       "array([0.5972036 , 0.36977134, 0.23513491, 0.61414618, 0.5972036 ,\n",
       "       0.36977134, 0.23513491, 0.61414618])\n",
       "Coordinates:\n",
       "  * x        (x) int64 2 3 4 5 2 3 4 5
" ], "text/plain": [ "\n", "array([0.5972036 , 0.36977134, 0.23513491, 0.61414618, 0.5972036 ,\n", " 0.36977134, 0.23513491, 0.61414618])\n", "Coordinates:\n", " * x (x) int64 2 3 4 5 2 3 4 5" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da_concat = xr.concat([da, da], \"x\")\n", "\n", "da_concat" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Indexes:\n", "x: <__main__.ArrayIndex object at 0x16c294af0>" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da_concat.xindexes" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## roll" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (x: 4)>\n",
       "array([0.23513491, 0.61414618, 0.5972036 , 0.36977134])\n",
       "Coordinates:\n",
       "  * x        (x) int64 4 5 2 3
" ], "text/plain": [ "\n", "array([0.23513491, 0.61414618, 0.5972036 , 0.36977134])\n", "Coordinates:\n", " * x (x) int64 4 5 2 3" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da.roll({\"x\": 2}, roll_coords=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## align\n", "\n", "Only exact alignment is supported." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(\n", " array([0.5972036 , 0.36977134, 0.23513491, 0.61414618])\n", " Coordinates:\n", " * x (x) int64 2 3 4 5,\n", " \n", " array([1.1944072 , 0.73954268, 0.47026983, 1.22829237])\n", " Coordinates:\n", " * x (x) int64 2 3 4 5)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "xr.align(da, da * 2, join=\"exact\")" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (x: 4)>\n",
       "array([1.1944072 , 0.73954268, 0.47026983, 1.22829237])\n",
       "Coordinates:\n",
       "  * x        (x) int64 2 3 4 5
" ], "text/plain": [ "\n", "array([1.1944072 , 0.73954268, 0.47026983, 1.22829237])\n", "Coordinates:\n", " * x (x) int64 2 3 4 5" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da + da" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "ename": "NotImplementedError", "evalue": "<__main__.ArrayIndex object at 0x16c28ee80> doesn't support alignment with inner/outer join method", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mNotImplementedError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/var/folders/xd/3ls911kd6_n2wphwwd74b1dc0000gn/T/ipykernel_23878/2620929422.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mxr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malign\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mda\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mda\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mda\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"x\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjoin\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"inner\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;32m~/Git/github/benbovy/xarray/xarray/core/alignment.py\u001b[0m in \u001b[0;36malign\u001b[0;34m(join, copy, indexes, exclude, fill_value, *objects)\u001b[0m\n\u001b[1;32m 762\u001b[0m \u001b[0mfill_value\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfill_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 763\u001b[0m )\n\u001b[0;32m--> 764\u001b[0;31m \u001b[0maligner\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malign\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 765\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0maligner\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 766\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/Git/github/benbovy/xarray/xarray/core/alignment.py\u001b[0m in \u001b[0;36malign\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfind_matching_unindexed_dims\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 550\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0massert_no_index_conflict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 551\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malign_indexes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 552\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0massert_unindexed_dim_sizes_equal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 553\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/Git/github/benbovy/xarray/xarray/core/alignment.py\u001b[0m in \u001b[0;36malign_indexes\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 402\u001b[0m )\n\u001b[1;32m 403\u001b[0m \u001b[0mjoiner\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_index_joiner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mindex_cls\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 404\u001b[0;31m \u001b[0mjoined_index\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjoiner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmatching_indexes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 405\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"left\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 406\u001b[0m \u001b[0mjoined_index_vars\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmatching_index_vars\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/Git/github/benbovy/xarray/xarray/core/indexes.py\u001b[0m in \u001b[0;36mjoin\u001b[0;34m(self, other, how)\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 94\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mT_Index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mT_Index\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhow\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"inner\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mT_Index\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 95\u001b[0;31m raise NotImplementedError(\n\u001b[0m\u001b[1;32m 96\u001b[0m \u001b[0;34mf\"{self!r} doesn't support alignment with inner/outer join method\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 97\u001b[0m )\n", "\u001b[0;31mNotImplementedError\u001b[0m: <__main__.ArrayIndex object at 0x16c28ee80> doesn't support alignment with inner/outer join method" ] } ], "source": [ "xr.align(da, xr.concat([da, da], \"x\"), join=\"inner\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compare Index build overhead PandasIndex vs. ArrayIndex" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "data = np.random.permutation(np.arange(10_000_000))\n", "ds = xr.Dataset(coords={\"x\": data}).drop_indexes(\"x\")" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "23.5 µs ± 261 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", "308 ns ± 21.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)\n" ] } ], "source": [ "%timeit PandasIndex(data, \"x\")\n", "%timeit ArrayIndex(data, \"x\", \"x\")" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "42.5 µs ± 884 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n", "18.5 µs ± 496 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n" ] } ], "source": [ "# why do we get a ratio here that is not as good as for the index constructors?\n", "# ArrayIndex.create_variables() likely adds overhead.\n", "# TODO: profile this!\n", "\n", "%timeit ds.set_xindex(\"x\", PandasIndex)\n", "%timeit ds.set_xindex(\"x\", ArrayIndex)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:xarray_dev]", "language": "python", "name": "conda-env-xarray_dev-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" } }, "nbformat": 4, "nbformat_minor": 4 }