ArrayIndex¶
An Xarray-compatible, lightweight index for basic indexing and alignment operations, for use when a full pandas.Index is not needed.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 | from __future__ import annotations from typing import Any, TYPE_CHECKING, Mapping, Hashable, Iterable, Sequence import numpy as np import xarray as xr from xarray.core.indexes import Index, PandasIndex, IndexVars, is_scalar from xarray.core.indexing import IndexSelResult from xarray.core import nputils from xarray.core.variable import Variable, IndexVariable #if TYPE_CHECKING: from xarray.core.types import JoinOptions, T_Index |
class ArrayIndex(Index):
    """Numpy-like array index.

    Lightweight, inefficient index implemented as a basic wrapper around its
    1-dimensional coordinate array.

    This index is suited for cases where index build overhead is an issue and
    where only basic indexing operations are needed (i.e., strict alignment,
    data selection in rare occasions). Label lookups are linear scans of the
    array and only exact label matches are supported. No ``join`` method is
    implemented, so only exact alignment works.
    """

    array: np.ndarray
    dim: Hashable
    name: Hashable

    # NOTE(review): enabling ``__slots__`` was observed to raise an
    # AttributeError with the ``da + da`` example. A plausible cause is that
    # some xarray code path copies index attributes through ``__dict__``,
    # where slot attributes are not stored -- TODO confirm before re-enabling.
    # __slots__ = ("array", "dim", "name")

    def __init__(self, array, dim, name):
        """Wrap ``array`` as the index of coordinate ``name`` along ``dim``.

        Parameters
        ----------
        array : numpy.ndarray
            1-dimensional coordinate labels (stored as-is, not copied).
        dim : Hashable
            Name of the indexed dimension.
        name : Hashable
            Name of the indexed coordinate.

        Raises
        ------
        ValueError
            If ``array`` is not 1-dimensional.
        """
        if array.ndim != 1:
            raise ValueError("ArrayIndex only accepts 1-dimensional arrays")
        self.array = array
        self.dim = dim
        self.name = name

    @classmethod
    def from_variables(
        cls: type[T_Index], variables: Mapping[Any, Variable], options
    ):
        """Create an index from exactly one 1-dimensional coordinate variable.

        Parameters
        ----------
        variables : mapping
            Mapping of a single coordinate name to its variable.
        options
            Index build options (unused).

        Raises
        ------
        ValueError
            If ``variables`` does not hold exactly one variable, or if that
            variable is not 1-dimensional.
        """
        if len(variables) != 1:
            raise ValueError(
                f"ArrayIndex only accepts one variable, found {len(variables)} variables"
            )

        name, var = next(iter(variables.items()))
        # Check before ``var.dims[0]``, which fails on 0-d variables, and
        # before ``var.values``, which would materialize an unusable array.
        if var.ndim != 1:
            raise ValueError(
                f"ArrayIndex only accepts a 1-dimensional variable, "
                f"variable {name!r} has {var.ndim} dimensions"
            )

        # TODO: use `var.data` instead? (allow lazy/duck arrays)
        return cls(var.values, var.dims[0], name)

    @classmethod
    def concat(
        cls: type[T_Index],
        indexes: Sequence[T_Index],
        dim: Hashable,
        positions: Iterable[Iterable[int]] | None = None,
    ) -> T_Index:
        """Concatenate several indexes along ``dim``.

        Parameters
        ----------
        indexes : sequence of ArrayIndex
            Indexes to concatenate; all must be along ``dim``.
        dim : Hashable
            Dimension of concatenation.
        positions : iterable of iterable of int, optional
            For each input index, the positions its elements take in the
            result. When given, the concatenated array is reordered so that
            element ``i`` lands at the ``i``-th flattened position.

        Returns
        -------
        ArrayIndex
            New index named after the first input index, or an empty index
            named ``dim`` when ``indexes`` is empty.

        Raises
        ------
        ValueError
            If any index is along a dimension other than ``dim``.
        """
        if not indexes:
            return cls(np.array([]), dim, dim)

        if not all(idx.dim == dim for idx in indexes):
            # sorted() keeps the error message deterministic.
            dims = ",".join(sorted({f"{idx.dim!r}" for idx in indexes}))
            raise ValueError(
                f"Cannot concatenate along dimension {dim!r} indexes with "
                f"dimensions: {dims}"
            )

        new_array = np.concatenate([idx.array for idx in indexes])

        if positions is not None:
            indices = nputils.inverse_permutation(np.concatenate(positions))
            new_array = new_array.take(indices)

        return cls(new_array, dim, indexes[0].name)

    def create_variables(
        self, variables: Mapping[Any, Variable] | None = None
    ) -> IndexVars:
        """Return the coordinate variable backed by this index.

        Attributes and encoding are copied from ``variables[self.name]`` when
        that variable is provided.
        """
        #
        # TODO: implementation is needed so that the corresponding
        # coordinate is indexed properly with Dataset.isel.
        # Ideally this shouldn't be needed, though.
        #
        if variables is not None and self.name in variables:
            var = variables[self.name]
            attrs = var.attrs
            encoding = var.encoding
        else:
            attrs = None
            encoding = None

        var = Variable(self.dim, self.array, attrs=attrs, encoding=encoding)
        return {self.name: var}

    def isel(
        self: T_Index, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
    ) -> T_Index | PandasIndex | None:
        """Apply positional indexing along ``self.dim``.

        Returns
        -------
        ArrayIndex or None
            A new index over the selected elements, or ``None`` (index
            dropped) when the indexer is a scalar or introduces other
            dimensions.
        """
        indxr = indexers[self.dim]
        if isinstance(indxr, Variable):
            if indxr.dims != (self.dim,):
                # can't preserve a index if result has new dimensions
                return None
            indxr = indxr.data

        if not isinstance(indxr, slice) and is_scalar(indxr):
            # scalar indexer: drop index
            return None

        return type(self)(self.array[indxr], self.dim, self.name)

    def sel(self, labels: dict[Any, Any], **kwargs) -> IndexSelResult:
        """Translate label-based selection into a positional indexer.

        Only exact matches are supported; when a label occurs more than once,
        its first occurrence is used. Extra keyword arguments (e.g.
        ``method``/``tolerance``) are accepted for API compatibility but are
        ignored.

        Label slices include both bounds, like other xarray label-based
        slicing. A ``None`` bound means "from the start" / "to the end".

        Raises
        ------
        ValueError
            If ``labels`` does not hold exactly one entry, if the slice step
            is not positive, or if the label type is unsupported.
        KeyError
            If a requested label is not present in the index.
        """
        if len(labels) != 1:
            raise ValueError(
                f"ArrayIndex.sel expects labels for exactly one coordinate, "
                f"got {len(labels)}"
            )
        (label,) = labels.values()

        if isinstance(label, slice):
            if label.step is not None and label.step <= 0:
                raise ValueError(
                    f"ArrayIndex only supports positive slice steps, got {label.step!r}"
                )
            start = None if label.start is None else self._label_position(label.start)
            # ``+ 1`` because the stop label is included in the selection.
            stop = None if label.stop is None else self._label_position(label.stop) + 1
            indexer = slice(start, stop, label.step)
        elif is_scalar(label):
            indexer = self._label_position(label)
        else:
            # TODO: other label types we want to support (n-d array-like, etc.)
            raise ValueError(f"label {label} not supported by ArrayIndex")

        return IndexSelResult({self.dim: indexer})

    def _label_position(self, label: Any) -> int:
        """Return the position of the first element equal to ``label``.

        Raises ``KeyError`` if no element matches. ``np.argmax`` on an
        all-False mask would return 0, which silently selects the wrong
        element, so it is not used here.
        """
        matches = np.flatnonzero(self.array == label)
        if matches.size == 0:
            raise KeyError(f"{label!r} not found in ArrayIndex {self.name!r}")
        return int(matches[0])

    def equals(self: T_Index, other: T_Index) -> bool:
        """Return True if ``other`` is an ArrayIndex with identical values."""
        if not isinstance(other, ArrayIndex):
            return False
        return np.array_equal(self.array, other.array)

    def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index:
        """Return a new index with values rolled by ``shifts[self.dim]``."""
        shift = shifts[self.dim]
        return type(self)(np.roll(self.array, shift), self.dim, self.name)
1 |
Example¶
# Build a small 1-D DataArray, then replace its default index on "x"
# with ArrayIndex.
values = np.random.uniform(size=4)
da = xr.DataArray(values, dims="x", coords={"x": [2, 3, 4, 5]})
da = da.drop_indexes("x")
da = da.set_xindex("x", ArrayIndex)
da
<xarray.DataArray (x: 4)> array([0.5972036 , 0.36977134, 0.23513491, 0.61414618]) Coordinates: * x (x) int64 2 3 4 5
# Inspect the indexes: "x" should now be backed by ArrayIndex.
da.xindexes
Indexes: x: <__main__.ArrayIndex object at 0x16c28ee80>
sel / isel¶
# Scalar label selection: ArrayIndex.isel drops the index for scalar
# indexers, so "x" becomes a plain scalar coordinate.
da_sel = da.sel({"x": 2})
da_sel
<xarray.DataArray ()>
array(0.5972036)
Coordinates:
# Previous output (continued):     x        int64 2
# Label-based slice selection along "x".
da_sel = da.sel(x=slice(2, 4))
da_sel
<xarray.DataArray (x: 2)> array([0.5972036 , 0.36977134]) Coordinates: * x (x) int64 2 3
# Slicing keeps the index: ArrayIndex.isel returns a new ArrayIndex for slices.
da_sel.xindexes
Indexes: x: <__main__.ArrayIndex object at 0x16c294040>
concat¶
# Concatenate the DataArray with itself along "x"; the resulting "x"
# coordinate is expected to be indexed by ArrayIndex.concat.
pieces = [da, da]
da_concat = xr.concat(pieces, dim="x")
da_concat
<xarray.DataArray (x: 8)>
array([0.5972036 , 0.36977134, 0.23513491, 0.61414618, 0.5972036 ,
0.36977134, 0.23513491, 0.61414618])
Coordinates:
# Previous output (continued):   * x        (x) int64 2 3 4 5 2 3 4 5
# The concatenated "x" coordinate keeps an ArrayIndex.
da_concat.xindexes
Indexes: x: <__main__.ArrayIndex object at 0x16c294af0>
roll¶
# Roll by 2 along "x"; roll_coords=True also rolls the "x" coordinate.
da.roll(x=2, roll_coords=True)
<xarray.DataArray (x: 4)> array([0.23513491, 0.61414618, 0.5972036 , 0.36977134]) Coordinates: * x (x) int64 4 5 2 3
align¶
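Only exact alignment is supported: ArrayIndex does not implement `join`, so aligning objects whose "x" indexes differ with `join="inner"` or `join="outer"` raises `NotImplementedError` (see the last example below).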
# Exact alignment succeeds because both objects carry equal "x" indexes.
doubled = da * 2
xr.align(da, doubled, join="exact")
(<xarray.DataArray (x: 4)> array([0.5972036 , 0.36977134, 0.23513491, 0.61414618]) Coordinates: * x (x) int64 2 3 4 5, <xarray.DataArray (x: 4)> array([1.1944072 , 0.73954268, 0.47026983, 1.22829237]) Coordinates: * x (x) int64 2 3 4 5)
# Arithmetic aligns its operands; this works because both share equal "x"
# indexes (compare the inner-join failure below, where the indexes differ).
da + da
<xarray.DataArray (x: 4)> array([1.1944072 , 0.73954268, 0.47026983, 1.22829237]) Coordinates: * x (x) int64 2 3 4 5
# Aligning different "x" indexes with an inner join requires an index join,
# which ArrayIndex does not implement: this raises NotImplementedError.
da_twice = xr.concat([da, da], "x")
xr.align(da, da_twice, join="inner")
--------------------------------------------------------------------------- NotImplementedError Traceback (most recent call last) /var/folders/xd/3ls911kd6_n2wphwwd74b1dc0000gn/T/ipykernel_23878/2620929422.py in <module> ----> 1 xr.align(da, xr.concat([da, da], "x"), join="inner") ~/Git/github/benbovy/xarray/xarray/core/alignment.py in align(join, copy, indexes, exclude, fill_value, *objects) 762 fill_value=fill_value, 763 ) --> 764 aligner.align() 765 return aligner.results 766 ~/Git/github/benbovy/xarray/xarray/core/alignment.py in align(self) 549 self.find_matching_unindexed_dims() 550 self.assert_no_index_conflict() --> 551 self.align_indexes() 552 self.assert_unindexed_dim_sizes_equal() 553 ~/Git/github/benbovy/xarray/xarray/core/alignment.py in align_indexes(self) 402 ) 403 joiner = self._get_index_joiner(index_cls) --> 404 joined_index = joiner(matching_indexes) 405 if self.join == "left": 406 joined_index_vars = matching_index_vars[0] ~/Git/github/benbovy/xarray/xarray/core/indexes.py in join(self, other, how) 93 94 def join(self: T_Index, other: T_Index, how: str = "inner") -> T_Index: ---> 95 raise NotImplementedError( 96 f"{self!r} doesn't support alignment with inner/outer join method" 97 ) NotImplementedError: <__main__.ArrayIndex object at 0x16c28ee80> doesn't support alignment with inner/outer join method
1 |
Compare index build overhead: PandasIndex vs. ArrayIndex¶
# Benchmark data: 10 million shuffled integers as an unindexed "x" coordinate.
data = np.random.permutation(10_000_000)
ds = xr.Dataset(coords={"x": data})
ds = ds.drop_indexes("x")
# Raw constructor cost: ArrayIndex.__init__ only checks ndim and stores the
# array, while PandasIndex builds a pandas index from it.
%timeit PandasIndex(data, "x")
%timeit ArrayIndex(data, "x", "x")
23.5 µs ± 261 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 308 ns ± 21.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
# Why is the ratio here not as good as for the bare index constructors?
# In the recorded timings, set_xindex takes 42.5 µs vs 18.5 µs, while the
# constructors alone take 23.5 µs vs 308 ns. That suggests roughly 18-19 µs
# of set_xindex overhead shared by both index types. ArrayIndex.create_variables()
# may be part of it, but this is not established.
# TODO: profile this!
%timeit ds.set_xindex("x", PandasIndex)
%timeit ds.set_xindex("x", ArrayIndex)
42.5 µs ± 884 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each) 18.5 µs ± 496 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
1 |