DaskIndex (out-of-core array index)¶

Implementation is taken from this NumpyIndex prototype, with only minor adaptation.

We could merge both NumpyIndex and DaskIndex into a generic (and basic) ArrayIndex, which would work with any duck array (lazy or not).

1
2
3
4
5
from typing import Any, Hashable, Iterable, Mapping, Self, Sequence

import dask.array as da
import numpy as np
import xarray as xr

Implementation¶

1
2
3
4
from xarray import Variable
from xarray.indexes import Index
from xarray.core.indexing import IndexSelResult
from xarray.core.utils import is_scalar
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
class DaskIndex(Index):
    """Out-of-core (dask) array index.

    Lightweight, inefficient index as a basic wrapper around
    its coordinate array data.

    This index is suited for cases where index build overhead
    is an issue and where only basic indexing operations are
    needed (i.e., strict alignment, data selection in rare occasions).

    """

    # 1-dimensional coordinate data wrapped by the index
    array: da.Array
    # name of the dimension along which the coordinate is defined
    dim: Hashable
    # name of the indexed coordinate
    name: Hashable

    def __init__(self, array: da.Array, dim: Hashable, name: Hashable):
        """Wrap a 1-dimensional array as an index.

        Parameters
        ----------
        array
            Coordinate data; it is stored as-is (no copy, no computation).
        dim
            Name of the dimension of ``array``.
        name
            Name of the coordinate.

        Raises
        ------
        ValueError
            If ``array`` is not 1-dimensional.
        """
        # 0-d arrays are rejected too: the index always refers to one dimension.
        if array.ndim != 1:
            raise ValueError("DaskIndex only accepts 1-dimensional arrays")

        self.array = array
        self.dim = dim
        self.name = name

    @classmethod
    def from_variables(
        cls: type[Self],
        variables: Mapping[Any, Variable],
        options: Mapping[str, Any],
    ) -> Self:
        """Build the index from exactly one 1-dimensional coordinate variable.

        ``options`` is accepted for API compatibility and is not used.

        Raises
        ------
        ValueError
            If ``variables`` does not contain exactly one variable, or if
            that variable is not 1-dimensional.
        """
        if len(variables) != 1:
            raise ValueError(
                f"DaskIndex only accepts one variable, found {len(variables)} variables"
            )

        name, var = next(iter(variables.items()))

        # Checked here so that a scalar variable gives a clear error instead
        # of an IndexError on ``var.dims[0]`` below.
        if var.ndim != 1:
            raise ValueError(
                "DaskIndex only accepts 1-dimensional variables, found "
                f"{var.ndim} dimension(s) for variable {name!r}"
            )

        return cls(var.data, var.dims[0], name)

    @classmethod
    def concat(
        cls: type[Self],
        indexes: Sequence[Self],
        dim: Hashable,
        positions: Iterable[Iterable[int]] | None = None,
    ) -> Self:
        """Concatenation is not supported by this index.

        Raises
        ------
        NotImplementedError
            Always.
        """
        # The previous draft implementation sat unreachable behind this raise
        # and relied on an undefined helper, so it has been removed.
        raise NotImplementedError("DaskIndex doesn't support concatenation (yet)")

    def create_variables(
        self, variables: Mapping[Any, Variable] | None = None
    ) -> dict[Hashable, Variable]:
        """Return the coordinate variable backed by the index array.

        Attributes and encoding are taken from ``variables[self.name]`` when
        that entry is provided.
        """
        #
        # TODO: implementing this method is needed so that
        # the corresponding coordinate is indexed properly with Dataset.isel.
        # Ideally this shouldn't be needed, though (we only extract and
        # shallow copy the coordinate variable here, but really not even
        # a copy is needed).
        #

        if variables is not None and self.name in variables:
            var = variables[self.name]
            attrs = var.attrs
            encoding = var.encoding
        else:
            attrs = None
            encoding = None

        var = Variable(self.dim, self.array, attrs=attrs, encoding=encoding)
        return {self.name: var}

    def isel(
        self: Self, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
    ) -> Self | None:
        """Index the wrapped array by position along ``self.dim``.

        Returns
        -------
        A new index wrapping the selected data, or ``None`` when the index
        cannot be preserved (scalar indexer, or a ``Variable`` indexer whose
        dimensions differ from ``(self.dim,)``).
        """
        indxr = indexers[self.dim]

        if isinstance(indxr, Variable):
            if indxr.dims != (self.dim,):
                # can't preserve an index if the result has new dimensions
                return None
            indxr = indxr.data

        if not isinstance(indxr, slice) and is_scalar(indxr):
            # scalar indexer: drop index
            return None

        return type(self)(self.array[indxr], self.dim, self.name)

    def _position_of(self, label: Any):
        """Return the (lazy) position of the first element equal to ``label``.

        NOTE: ``argmax`` over an all-False mask is 0, so a label that is
        absent from the array resolves to position 0 instead of raising.
        Detecting this would require computing the mask.
        """
        return da.argmax(self.array == label)

    def sel(self, labels: dict[Any, Any], **kwargs) -> IndexSelResult:
        """Translate one label (scalar or slice) into a positional indexer.

        Extra keyword arguments are accepted for API compatibility and are
        not used.

        For a slice, each bound is replaced by the position of the first
        matching element; a ``None`` bound stays open-ended and the step is
        forwarded unchanged. The result is a plain positional slice, so the
        element at the stop label is not included.

        Raises
        ------
        ValueError
            If ``labels`` does not hold exactly one entry, or if the label is
            neither a slice nor a scalar.
        """
        if len(labels) != 1:
            raise ValueError(
                f"DaskIndex only accepts one label, found {len(labels)} labels"
            )
        label = next(iter(labels.values()))

        if isinstance(label, slice):
            # TODO: what exactly do we want to do here?
            # Open-ended bounds must stay None: comparing the array with None
            # would resolve to position 0 and select nothing.
            start = None if label.start is None else self._position_of(label.start)
            stop = None if label.stop is None else self._position_of(label.stop)
            indexer = slice(start, stop, label.step)
        elif is_scalar(label):
            indexer = self._position_of(label)
        else:
            # TODO: other label types we want to support (n-d array-like, etc.)
            raise ValueError(f"label {label} not (yet) supported by DaskIndex")

        return IndexSelResult({self.dim: indexer})

    def equals(self: Self, other: Self) -> bool:
        """Return True if ``other`` is a DaskIndex holding the same values.

        Note that comparing two indexes of equal size triggers the
        computation of the element-wise comparison.
        """
        if not isinstance(other, DaskIndex):
            return False
        if self.array.size != other.array.size:
            return False

        # Reduce to a real bool as promised by the signature, rather than
        # handing a lazy 0-d array back to the caller.
        return bool(da.all(self.array == other.array))

    def roll(self: Self, shifts: Mapping[Any, int]) -> Self:
        """Return a new index whose array is rolled by ``shifts[self.dim]``."""
        shift = shifts[self.dim]

        return type(self)(da.roll(self.array, shift), self.dim, self.name)

Example¶

Construction¶

Create coordinates "x" and "y", with no index for "x" and a default (pandas) index for "y".

Only works with https://github.com/pydata/xarray/pull/8094!

1
2
# "x" is a large lazy (dask) coordinate; indexes={} means no index is built for it.
coords = xr.Coordinates({"x": ("x", da.arange(100_000_000))}, indexes={})
# "y" is assigned afterwards and gets a default (pandas) index.
coords["y"] = np.arange(100)

Create a dataset using the coordinates above, and set a DaskIndex for the "x" coordinate. Coordinate data remains lazy.

1
2
3
4
5
6
7
8
# Dataset with one lazy 2-d variable "foo" over dimensions ("y", "x").
ds = xr.Dataset(
    data_vars={"foo": (("y", "x"), da.random.random((100, 100_000_000)))},
    coords=coords,
)

# Attach a DaskIndex to the "x" coordinate, which had no index so far.
ds = ds.set_xindex("x", DaskIndex)

ds
<xarray.Dataset>
Dimensions:  (y: 100, x: 100000000)
Coordinates:
  * x        (x) int64 dask.array<chunksize=(16777216,), meta=np.ndarray>
  * y        (y) int64 0 1 2 3 4 5 6 7 8 9 10 ... 90 91 92 93 94 95 96 97 98 99
Data variables:
    foo      (y, x) float64 dask.array<chunksize=(100, 167772), meta=np.ndarray>
Indexes:
    x        DaskIndex
xarray.Dataset
    • y: 100
    • x: 100000000
    • x
      (x)
      int64
      dask.array<chunksize=(16777216,), meta=np.ndarray>
      Array Chunk
      Bytes 762.94 MiB 128.00 MiB
      Shape (100000000,) (16777216,)
      Dask graph 6 chunks in 1 graph layer
      Data type int64 numpy.ndarray
      100000000 1
    • y
      (y)
      int64
      0 1 2 3 4 5 6 ... 94 95 96 97 98 99
      array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
             18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
             36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
             54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
             72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
             90, 91, 92, 93, 94, 95, 96, 97, 98, 99])
    • foo
      (y, x)
      float64
      dask.array<chunksize=(100, 167772), meta=np.ndarray>
      Array Chunk
      Bytes 74.51 GiB 128.00 MiB
      Shape (100, 100000000) (100, 167772)
      Dask graph 597 chunks in 1 graph layer
      Data type float64 numpy.ndarray
      100000000 100
    • y
      PandasIndex
      PandasIndex(Int64Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
                  17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
                  34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
                  51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
                  68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
                  85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99],
                 dtype='int64', name='y'))
    • x
      DaskIndex
      <__main__.DaskIndex object at 0x109492a50>

Label-based selection¶

Select data by coordinate labels.

1
2
# Select "y" by a list of labels and "x" by a label slice.
# The recorded output has 10000 "x" values (10000..19999): the stop label is excluded.
ds_subset = ds.sel(y=[10, 12], x=slice(10_000, 20_000))
ds_subset
<xarray.Dataset>
Dimensions:  (y: 2, x: 10000)
Coordinates:
  * x        (x) int64 dask.array<chunksize=(10000,), meta=np.ndarray>
  * y        (y) int64 10 12
Data variables:
    foo      (y, x) float64 dask.array<chunksize=(2, 10000), meta=np.ndarray>
Indexes:
    x        DaskIndex
xarray.Dataset
    • y: 2
    • x: 10000
    • x
      (x)
      int64
      dask.array<chunksize=(10000,), meta=np.ndarray>
      Array Chunk
      Bytes 78.12 kiB 78.12 kiB
      Shape (10000,) (10000,)
      Dask graph 1 chunks in 2 graph layers
      Data type int64 numpy.ndarray
      10000 1
    • y
      (y)
      int64
      10 12
      array([10, 12])
    • foo
      (y, x)
      float64
      dask.array<chunksize=(2, 10000), meta=np.ndarray>
      Array Chunk
      Bytes 156.25 kiB 156.25 kiB
      Shape (2, 10000) (2, 10000)
      Dask graph 1 chunks in 2 graph layers
      Data type float64 numpy.ndarray
      10000 2
    • y
      PandasIndex
      PandasIndex(Int64Index([10, 12], dtype='int64', name='y'))
    • x
      DaskIndex
      <__main__.DaskIndex object at 0x114f517d0>

The operation above is fully lazy:

1
%timeit ds.sel(y=[10, 12], x=slice(10_000, 20_000))
655 ms ± 13.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

The coordinates and data variables of the selection remain lazy

1
ds_subset.x
<xarray.DataArray 'x' (x: 10000)>
dask.array<getitem, shape=(10000,), dtype=int64, chunksize=(10000,), chunktype=numpy.ndarray>
Coordinates:
  * x        (x) int64 dask.array<chunksize=(10000,), meta=np.ndarray>
Indexes:
    x        DaskIndex
xarray.DataArray
'x'
  • x: 10000
  • dask.array<chunksize=(10000,), meta=np.ndarray>
    Array Chunk
    Bytes 78.12 kiB 78.12 kiB
    Shape (10000,) (10000,)
    Dask graph 1 chunks in 2 graph layers
    Data type int64 numpy.ndarray
    10000 1
    • x
      (x)
      int64
      dask.array<chunksize=(10000,), meta=np.ndarray>
      Array Chunk
      Bytes 78.12 kiB 78.12 kiB
      Shape (10000,) (10000,)
      Dask graph 1 chunks in 2 graph layers
      Data type int64 numpy.ndarray
      10000 1
    • x
      DaskIndex
      <__main__.DaskIndex object at 0x114f517d0>
1
ds_subset.foo
<xarray.DataArray 'foo' (y: 2, x: 10000)>
dask.array<getitem, shape=(2, 10000), dtype=float64, chunksize=(2, 10000), chunktype=numpy.ndarray>
Coordinates:
  * x        (x) int64 dask.array<chunksize=(10000,), meta=np.ndarray>
  * y        (y) int64 10 12
Indexes:
    x        DaskIndex
xarray.DataArray
'foo'
  • y: 2
  • x: 10000
  • dask.array<chunksize=(2, 10000), meta=np.ndarray>
    Array Chunk
    Bytes 156.25 kiB 156.25 kiB
    Shape (2, 10000) (2, 10000)
    Dask graph 1 chunks in 2 graph layers
    Data type float64 numpy.ndarray
    10000 2
    • x
      (x)
      int64
      dask.array<chunksize=(10000,), meta=np.ndarray>
      Array Chunk
      Bytes 78.12 kiB 78.12 kiB
      Shape (10000,) (10000,)
      Dask graph 1 chunks in 2 graph layers
      Data type int64 numpy.ndarray
      10000 1
    • y
      (y)
      int64
      10 12
      array([10, 12])
    • y
      PandasIndex
      PandasIndex(Int64Index([10, 12], dtype='int64', name='y'))
    • x
      DaskIndex
      <__main__.DaskIndex object at 0x114f517d0>

The selection is small so computing the Dataset is fast

1
ds_subset.compute()
<xarray.Dataset>
Dimensions:  (y: 2, x: 10000)
Coordinates:
  * x        (x) int64 10000 10001 10002 10003 10004 ... 19996 19997 19998 19999
  * y        (y) int64 10 12
Data variables:
    foo      (y, x) float64 0.8685 0.06375 0.2268 ... 0.7615 0.8553 0.8133
Indexes:
    x        DaskIndex
xarray.Dataset
    • y: 2
    • x: 10000
    • x
      (x)
      int64
      10000 10001 10002 ... 19998 19999
      array([10000, 10001, 10002, ..., 19997, 19998, 19999])
    • y
      (y)
      int64
      10 12
      array([10, 12])
    • foo
      (y, x)
      float64
      0.8685 0.06375 ... 0.8553 0.8133
      array([[0.86852983, 0.0637496 , 0.22682869, ..., 0.02478069, 0.59563265,
              0.88391897],
             [0.74251054, 0.78986414, 0.91074379, ..., 0.76147836, 0.85527309,
              0.81327248]])
    • y
      PandasIndex
      PandasIndex(Int64Index([10, 12], dtype='int64', name='y'))
    • x
      DaskIndex
      <__main__.DaskIndex object at 0x115002890>

Roll¶

This is lazy too

1
ds.roll(x=5)
<xarray.Dataset>
Dimensions:  (y: 100, x: 100000000)
Coordinates:
  * x        (x) int64 dask.array<chunksize=(16777216,), meta=np.ndarray>
  * y        (y) int64 0 1 2 3 4 5 6 7 8 9 10 ... 90 91 92 93 94 95 96 97 98 99
Data variables:
    foo      (y, x) float64 dask.array<chunksize=(100, 167772), meta=np.ndarray>
Indexes:
    x        DaskIndex
xarray.Dataset
    • y: 100
    • x: 100000000
    • x
      (x)
      int64
      dask.array<chunksize=(16777216,), meta=np.ndarray>
      Array Chunk
      Bytes 762.94 MiB 128.00 MiB
      Shape (100000000,) (16777216,)
      Dask graph 6 chunks in 1 graph layer
      Data type int64 numpy.ndarray
      100000000 1
    • y
      (y)
      int64
      0 1 2 3 4 5 6 ... 94 95 96 97 98 99
      array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
             18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
             36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
             54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
             72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
             90, 91, 92, 93, 94, 95, 96, 97, 98, 99])
    • foo
      (y, x)
      float64
      dask.array<chunksize=(100, 167772), meta=np.ndarray>
      Array Chunk
      Bytes 74.51 GiB 128.00 MiB
      Shape (100, 100000000) (100, 167772)
      Dask graph 597 chunks in 5 graph layers
      Data type float64 numpy.ndarray
      100000000 100
    • y
      PandasIndex
      PandasIndex(Int64Index([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
                  17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
                  34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
                  51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
                  68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
                  85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99],
                 dtype='int64', name='y'))
    • x
      DaskIndex
      <__main__.DaskIndex object at 0x109492a50>
1
ds.roll(x=5).foo.data.dask

HighLevelGraph

HighLevelGraph with 5 layers and 3582 keys from all layers.

Layer1: random_sample

random_sample-acc21eb09941670cf01836dec0ac388a

layer_type MaterializedLayer
is_materialized True
number of outputs 597
shape (100, 100000000)
dtype float64
chunksize (100, 167772)
type dask.array.core.Array
chunk_type numpy.ndarray
100000000 100

Layer2: getitem

getitem-4ac0bb0d0de92654303bf4a075878043

layer_type MaterializedLayer
is_materialized True
number of outputs 597
shape (100, 99999995)
dtype float64
chunksize (100, 167772)
type dask.array.core.Array
chunk_type numpy.ndarray
depends on random_sample-acc21eb09941670cf01836dec0ac388a
99999995 100

Layer3: getitem

getitem-98d9e40a49433bb8174db9d6f8751789

layer_type MaterializedLayer
is_materialized True
number of outputs 1
shape (100, 5)
dtype float64
chunksize (100, 5)
type dask.array.core.Array
chunk_type numpy.ndarray
depends on random_sample-acc21eb09941670cf01836dec0ac388a
5 100

Layer4: concatenate

concatenate-c0d3fe82da806cf22304a8920339aa2c

layer_type MaterializedLayer
is_materialized True
number of outputs 598
shape (100, 100000000)
dtype float64
chunksize (100, 167772)
type dask.array.core.Array
chunk_type numpy.ndarray
depends on getitem-98d9e40a49433bb8174db9d6f8751789
getitem-4ac0bb0d0de92654303bf4a075878043
100000000 100

Layer5: rechunk-merge

rechunk-merge-8bdcaead68d616d69dbf2aefcac6a961

layer_type MaterializedLayer
is_materialized True
number of outputs 1789
shape (100, 100000000)
dtype float64
chunksize (100, 167772)
type dask.array.core.Array
chunk_type numpy.ndarray
depends on concatenate-c0d3fe82da806cf22304a8920339aa2c
100000000 100

Alignment¶

Alignment, re-indexing, concatenation, etc. are not (well) supported. These operations may either fail (good) or try to compute and/or load all data (bad).

1
2
3
4
ds2 = xr.Dataset(coords={"x": [-2, -1]})  # small second dataset with its own "x" labels
ds2 = ds2.drop_indexes("x").set_xindex("x", DaskIndex)  # replace the existing "x" index by a DaskIndex

xr.align(ds, ds2)
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[13], line 4
      1 ds2 = xr.Dataset(coords={"x": [-2, -1]})
      2 ds2 = ds2.drop_indexes("x").set_xindex("x", DaskIndex)
----> 4 xr.align(ds, ds2)

File ~/Git/github/benbovy/xarray/xarray/core/alignment.py:783, in align(join, copy, indexes, exclude, fill_value, *objects)
    587 """
    588 Given any number of Dataset and/or DataArray objects, returns new
    589 objects with aligned indexes and dimension sizes.
   (...)
    773 
    774 """
    775 aligner = Aligner(
    776     objects,
    777     join=join,
   (...)
    781     fill_value=fill_value,
    782 )
--> 783 aligner.align()
    784 return aligner.results

File ~/Git/github/benbovy/xarray/xarray/core/alignment.py:568, in Aligner.align(self)
    566 self.find_matching_unindexed_dims()
    567 self.assert_no_index_conflict()
--> 568 self.align_indexes()
    569 self.assert_unindexed_dim_sizes_equal()
    571 if self.join == "override":

File ~/Git/github/benbovy/xarray/xarray/core/alignment.py:422, in Aligner.align_indexes(self)
    415     raise ValueError(
    416         "cannot align objects with join='exact' where "
    417         "index/labels/sizes are not equal along "
    418         "these coordinates (dimensions): "
    419         + ", ".join(f"{name!r} {dims!r}" for name, dims in key[0])
    420     )
    421 joiner = self._get_index_joiner(index_cls)
--> 422 joined_index = joiner(matching_indexes)
    423 if self.join == "left":
    424     joined_index_vars = matching_index_vars[0]

File ~/Git/github/benbovy/xarray/xarray/core/indexes.py:285, in Index.join(self, other, how)
    267 def join(self: T_Index, other: T_Index, how: JoinOptions = "inner") -> T_Index:
    268     """Return a new index from the combination of this index with another
    269     index of the same type.
    270 
   (...)
    283         A new Index object.
    284     """
--> 285     raise NotImplementedError(
    286         f"{self!r} doesn't support alignment with inner/outer join method"
    287     )

NotImplementedError: <__main__.DaskIndex object at 0x109492a50> doesn't support alignment with inner/outer join method
1
xr.align(ds, ds2, join="inner")
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[14], line 1
----> 1 xr.align(ds, ds2, join="inner")

File ~/Git/github/benbovy/xarray/xarray/core/alignment.py:783, in align(join, copy, indexes, exclude, fill_value, *objects)
    587 """
    588 Given any number of Dataset and/or DataArray objects, returns new
    589 objects with aligned indexes and dimension sizes.
   (...)
    773 
    774 """
    775 aligner = Aligner(
    776     objects,
    777     join=join,
   (...)
    781     fill_value=fill_value,
    782 )
--> 783 aligner.align()
    784 return aligner.results

File ~/Git/github/benbovy/xarray/xarray/core/alignment.py:568, in Aligner.align(self)
    566 self.find_matching_unindexed_dims()
    567 self.assert_no_index_conflict()
--> 568 self.align_indexes()
    569 self.assert_unindexed_dim_sizes_equal()
    571 if self.join == "override":

File ~/Git/github/benbovy/xarray/xarray/core/alignment.py:422, in Aligner.align_indexes(self)
    415     raise ValueError(
    416         "cannot align objects with join='exact' where "
    417         "index/labels/sizes are not equal along "
    418         "these coordinates (dimensions): "
    419         + ", ".join(f"{name!r} {dims!r}" for name, dims in key[0])
    420     )
    421 joiner = self._get_index_joiner(index_cls)
--> 422 joined_index = joiner(matching_indexes)
    423 if self.join == "left":
    424     joined_index_vars = matching_index_vars[0]

File ~/Git/github/benbovy/xarray/xarray/core/indexes.py:285, in Index.join(self, other, how)
    267 def join(self: T_Index, other: T_Index, how: JoinOptions = "inner") -> T_Index:
    268     """Return a new index from the combination of this index with another
    269     index of the same type.
    270 
   (...)
    283         A new Index object.
    284     """
--> 285     raise NotImplementedError(
    286         f"{self!r} doesn't support alignment with inner/outer join method"
    287     )

NotImplementedError: <__main__.DaskIndex object at 0x109492a50> doesn't support alignment with inner/outer join method
1