ArrayIndex¶

A cheap, Xarray-compatible index for basic indexing and alignment operations, intended for cases where a full pandas.Index is not really needed.

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
from __future__ import annotations

from typing import Any, TYPE_CHECKING, Mapping, Hashable, Iterable, Sequence

import numpy as np
import xarray as xr

from xarray.core.indexes import Index, PandasIndex, IndexVars, is_scalar
from xarray.core.indexing import IndexSelResult
from xarray.core import nputils
from xarray.core.variable import Variable, IndexVariable

#if TYPE_CHECKING:
from xarray.core.types import JoinOptions, T_Index
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class ArrayIndex(Index):
    """Numpy-like array index.

    Lightweight, inefficient index as a basic wrapper around
    its coordinate array data.

    This index is suited for cases where index build overhead
    is an issue and where only basic indexing operations are
    needed (i.e., strict alignment, data selection in rare occasions).

    Label-based selection scans the wrapped array with an element-wise
    comparison on every call; nothing is precomputed at build time.

    Attributes
    ----------
    array : np.ndarray
        The 1-dimensional coordinate data wrapped by the index.
    dim : Hashable
        Name of the dimension the index applies to.
    name : Hashable
        Name of the coordinate variable the index was built from.
    """

    array: np.ndarray
    dim: Hashable
    name: Hashable

    # NOTE(review): `__slots__ = ("array", "dim", "name")` is intentionally
    # left out. It was suspected of causing an AttributeError with the
    # `da + da` example -- to be confirmed before enabling it.

    def __init__(self, array: np.ndarray, dim: Hashable, name: Hashable) -> None:
        """Wrap a 1-dimensional array.

        Raises
        ------
        ValueError
            If ``array`` is not exactly 1-dimensional.
        """
        # 0-d arrays are rejected too: every other method indexes or
        # concatenates along a single axis.
        if array.ndim != 1:
            raise ValueError("ArrayIndex only accepts 1-dimensional arrays")

        self.array = array
        self.dim = dim
        self.name = name

    @classmethod
    def from_variables(
        cls: type[T_Index], variables: Mapping[Any, Variable], options
    ):
        """Build an index from a single 1-dimensional coordinate variable.

        Raises
        ------
        ValueError
            If ``variables`` does not contain exactly one variable, or if
            that variable is not 1-dimensional.
        """
        if len(variables) != 1:
            raise ValueError(
                f"ArrayIndex only accepts one variable, found {len(variables)} variables"
            )

        name, var = next(iter(variables.items()))

        # Check here so a scalar variable gets a clear error instead of an
        # IndexError from `var.dims[0]`.
        if var.ndim != 1:
            raise ValueError("ArrayIndex only accepts 1-dimensional arrays")

        # TODO: use `var.data` instead? (allow lazy/duck arrays)
        return cls(var.values, var.dims[0], name)

    @classmethod
    def concat(
        cls: type[T_Index],
        indexes: Sequence[T_Index],
        dim: Hashable,
        positions: Iterable[Iterable[int]] | None = None,
    ) -> T_Index:
        """Concatenate several indexes along ``dim``.

        Parameters
        ----------
        indexes
            Indexes to concatenate, all defined on dimension ``dim``.
            An empty sequence yields an empty index named after ``dim``.
        dim
            Dimension to concatenate along.
        positions
            Optional target positions of the elements of each index in the
            result; when given, the concatenated array is reordered using
            the inverse of the concatenated permutation.

        Raises
        ------
        ValueError
            If any index is defined on a dimension other than ``dim``.
        """
        if not indexes:
            return cls(np.array([]), dim, dim)

        if not all(idx.dim == dim for idx in indexes):
            dims = ",".join({f"{idx.dim!r}" for idx in indexes})
            raise ValueError(
                f"Cannot concatenate along dimension {dim!r} indexes with "
                f"dimensions: {dims}"
            )

        new_array = np.concatenate([idx.array for idx in indexes])

        if positions is not None:
            indices = nputils.inverse_permutation(np.concatenate(positions))
            new_array = new_array.take(indices)

        # The result takes the name of the first index.
        return cls(new_array, dim, indexes[0].name)

    def create_variables(
        self, variables: Mapping[Any, Variable] | None = None
    ) -> IndexVars:
        """Return the coordinate variable backed by this index.

        ``attrs`` and ``encoding`` are copied from the variable named
        ``self.name`` in ``variables`` when it is present.
        """
        #
        # TODO: implementation is needed so that the corresponding
        # coordinate is indexed properly with Dataset.isel.
        # Ideally this shouldn't be needed, though.
        #
        if variables is not None and self.name in variables:
            var = variables[self.name]
            attrs = var.attrs
            encoding = var.encoding
        else:
            attrs = None
            encoding = None

        var = Variable(self.dim, self.array, attrs=attrs, encoding=encoding)
        return {self.name: var}

    def isel(
        self: T_Index, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
    ) -> T_Index | PandasIndex | None:
        """Index positionally along ``self.dim``.

        Returns
        -------
        A new index over the selected elements, or ``None`` when the index
        cannot be kept (scalar indexer, or a ``Variable`` indexer whose
        dimensions differ from ``(self.dim,)``).
        """
        indxr = indexers[self.dim]

        if isinstance(indxr, Variable):
            if indxr.dims != (self.dim,):
                # can't preserve an index if the result has new dimensions
                return None
            indxr = indxr.data

        if not isinstance(indxr, slice) and is_scalar(indxr):
            # scalar indexer: drop index
            return None

        return type(self)(self.array[indxr], self.dim, self.name)

    def _position_of(self, label: Any) -> np.intp:
        """Return the position of the first element equal to ``label``.

        Raises
        ------
        KeyError
            If no element of the array equals ``label``.
        """
        matches = np.asarray(self.array == label)
        # The comparison can collapse to a scalar for incomparable types;
        # an empty array cannot contain the label either.
        if matches.shape == self.array.shape and matches.size:
            position = np.argmax(matches)
            # argmax returns 0 for an all-False mask, so confirm the hit.
            if matches[position]:
                return position
        raise KeyError(f"label {label!r} not found in ArrayIndex {self.name!r}")

    def sel(self, labels: dict[Any, Any], **kwargs) -> IndexSelResult:
        """Translate a label-based query into a positional indexer.

        Parameters
        ----------
        labels
            Mapping with exactly one entry whose value is either a scalar
            label or a slice of labels.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        IndexSelResult
            Positional indexer for ``self.dim``. For a slice of labels the
            resulting positional slice excludes the position of ``stop``;
            ``None`` bounds and the step are passed through unchanged.

        Raises
        ------
        KeyError
            If a requested label is not found in the array.
        ValueError
            If ``labels`` does not hold exactly one entry or the label type
            is not supported.
        """
        if len(labels) != 1:
            raise ValueError(
                f"ArrayIndex.sel expects exactly one label, found {len(labels)}"
            )
        (label,) = labels.values()

        if isinstance(label, slice):
            # TODO: what exactly do we want to do here?
            # Open-ended bounds stay open rather than being looked up.
            start = None if label.start is None else self._position_of(label.start)
            stop = None if label.stop is None else self._position_of(label.stop)
            indexer = slice(start, stop, label.step)
        elif is_scalar(label):
            indexer = self._position_of(label)
        else:
            # TODO: other label types we want to support (n-d array-like, etc.)
            raise ValueError(f"label {label} not supported by ArrayIndex")

        return IndexSelResult({self.dim: indexer})

    def equals(self: T_Index, other: T_Index) -> bool:
        """Return True if ``other`` is an ArrayIndex wrapping equal data."""
        if not isinstance(other, ArrayIndex):
            return False
        return bool(np.array_equal(self.array, other.array))

    def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index:
        """Return a new index with its data rolled by ``shifts[self.dim]``."""
        shift = shifts[self.dim]

        return type(self)(np.roll(self.array, shift), self.dim, self.name)

    
1
 

Example¶

1
2
3
4
5
# Random 1-d data labelled by an integer "x" coordinate.
x_labels = [2, 3, 4, 5]
da = xr.DataArray(
    np.random.uniform(size=len(x_labels)),
    dims="x",
    coords={"x": x_labels},
)
1
2
3
# Replace the index on "x" with an ArrayIndex: drop the existing one first,
# then attach the custom index class.
da_without_index = da.drop_indexes("x")
da = da_without_index.set_xindex("x", ArrayIndex)

da
<xarray.DataArray (x: 4)>
array([0.5972036 , 0.36977134, 0.23513491, 0.61414618])
Coordinates:
  * x        (x) int64 2 3 4 5
xarray.DataArray
  • x: 4
  • 0.5972 0.3698 0.2351 0.6141
    array([0.5972036 , 0.36977134, 0.23513491, 0.61414618])
    • x
      (x)
      int64
      2 3 4 5
      array([2, 3, 4, 5])
1
da.xindexes
Indexes:
x: <__main__.ArrayIndex object at 0x16c28ee80>

sel / isel¶

1
2
3
da_sel = da.sel(x=2)

da_sel
<xarray.DataArray ()>
array(0.5972036)
Coordinates:
    x        int64 2
xarray.DataArray
  • 0.5972
    array(0.5972036)
    • x
      ()
      int64
      2
      array(2)
1
2
3
da_sel = da.sel(x=slice(2, 4))

da_sel
<xarray.DataArray (x: 2)>
array([0.5972036 , 0.36977134])
Coordinates:
  * x        (x) int64 2 3
xarray.DataArray
  • x: 2
  • 0.5972 0.3698
    array([0.5972036 , 0.36977134])
    • x
      (x)
      int64
      2 3
      array([2, 3])
1
da_sel.xindexes
Indexes:
x: <__main__.ArrayIndex object at 0x16c294040>

concat¶

1
2
3
da_concat = xr.concat([da, da], "x")

da_concat
<xarray.DataArray (x: 8)>
array([0.5972036 , 0.36977134, 0.23513491, 0.61414618, 0.5972036 ,
       0.36977134, 0.23513491, 0.61414618])
Coordinates:
  * x        (x) int64 2 3 4 5 2 3 4 5
xarray.DataArray
  • x: 8
  • 0.5972 0.3698 0.2351 0.6141 0.5972 0.3698 0.2351 0.6141
    array([0.5972036 , 0.36977134, 0.23513491, 0.61414618, 0.5972036 ,
           0.36977134, 0.23513491, 0.61414618])
    • x
      (x)
      int64
      2 3 4 5 2 3 4 5
      array([2, 3, 4, 5, 2, 3, 4, 5])
1
da_concat.xindexes
Indexes:
x: <__main__.ArrayIndex object at 0x16c294af0>

roll¶

1
da.roll({"x": 2}, roll_coords=True)
<xarray.DataArray (x: 4)>
array([0.23513491, 0.61414618, 0.5972036 , 0.36977134])
Coordinates:
  * x        (x) int64 4 5 2 3
xarray.DataArray
  • x: 4
  • 0.2351 0.6141 0.5972 0.3698
    array([0.23513491, 0.61414618, 0.5972036 , 0.36977134])
    • x
      (x)
      int64
      4 5 2 3
      array([4, 5, 2, 3])

align¶

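Only exact alignment is supported; any other join method (e.g. "inner" or "outer") raises NotImplementedError, as the last example below shows.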

1
xr.align(da, da * 2, join="exact")
(<xarray.DataArray (x: 4)>
 array([0.5972036 , 0.36977134, 0.23513491, 0.61414618])
 Coordinates:
   * x        (x) int64 2 3 4 5,
 <xarray.DataArray (x: 4)>
 array([1.1944072 , 0.73954268, 0.47026983, 1.22829237])
 Coordinates:
   * x        (x) int64 2 3 4 5)
1
da + da
<xarray.DataArray (x: 4)>
array([1.1944072 , 0.73954268, 0.47026983, 1.22829237])
Coordinates:
  * x        (x) int64 2 3 4 5
xarray.DataArray
  • x: 4
  • 1.194 0.7395 0.4703 1.228
    array([1.1944072 , 0.73954268, 0.47026983, 1.22829237])
    • x
      (x)
      int64
      2 3 4 5
      array([2, 3, 4, 5])
1
xr.align(da, xr.concat([da, da], "x"), join="inner")
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
/var/folders/xd/3ls911kd6_n2wphwwd74b1dc0000gn/T/ipykernel_23878/2620929422.py in <module>
----> 1 xr.align(da, xr.concat([da, da], "x"), join="inner")

~/Git/github/benbovy/xarray/xarray/core/alignment.py in align(join, copy, indexes, exclude, fill_value, *objects)
    762         fill_value=fill_value,
    763     )
--> 764     aligner.align()
    765     return aligner.results
    766 

~/Git/github/benbovy/xarray/xarray/core/alignment.py in align(self)
    549         self.find_matching_unindexed_dims()
    550         self.assert_no_index_conflict()
--> 551         self.align_indexes()
    552         self.assert_unindexed_dim_sizes_equal()
    553 

~/Git/github/benbovy/xarray/xarray/core/alignment.py in align_indexes(self)
    402                         )
    403                     joiner = self._get_index_joiner(index_cls)
--> 404                     joined_index = joiner(matching_indexes)
    405                     if self.join == "left":
    406                         joined_index_vars = matching_index_vars[0]

~/Git/github/benbovy/xarray/xarray/core/indexes.py in join(self, other, how)
     93 
     94     def join(self: T_Index, other: T_Index, how: str = "inner") -> T_Index:
---> 95         raise NotImplementedError(
     96             f"{self!r} doesn't support alignment with inner/outer join method"
     97         )

NotImplementedError: <__main__.ArrayIndex object at 0x16c28ee80> doesn't support alignment with inner/outer join method
1
 

Compare index build overhead: PandasIndex vs. ArrayIndex¶

1
2
# Ten million shuffled integers, used as an "x" coordinate that carries no index.
data = np.arange(10_000_000)
data = np.random.permutation(data)
coords_only = xr.Dataset(coords={"x": data})
ds = coords_only.drop_indexes("x")
1
2
%timeit PandasIndex(data, "x")
%timeit ArrayIndex(data, "x", "x")
23.5 µs ± 261 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
308 ns ± 21.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
1
2
3
4
5
6
# Why is the speed-up here smaller than the one measured for the bare index constructors?
# ArrayIndex.create_variables() likely adds overhead.
# TODO: profile this!

%timeit ds.set_xindex("x", PandasIndex)
%timeit ds.set_xindex("x", ArrayIndex)
42.5 µs ± 884 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
18.5 µs ± 496 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
1