MultiPandasIndex¶

In many cases an Xarray custom index may be built on top of one or more PandasIndex instances. This notebook provides a helper class, MultiPandasIndex, that implements all of the boilerplate: for each Index method, the input arguments are dispatched to the encapsulated PandasIndex instances.

from __future__ import annotations

from typing import Any, TYPE_CHECKING, Mapping, Hashable, Iterable, Sequence

import numpy as np
import pandas as pd
import xarray as xr

from xarray.core.indexes import Index, PandasIndex, IndexVars, is_scalar
from xarray.core.indexing import IndexSelResult, merge_sel_results
from xarray.core.utils import Frozen
from xarray.core.variable import Variable

#if TYPE_CHECKING:
from xarray.core.types import ErrorOptions, JoinOptions, T_Index
class MultiPandasIndex(Index):
    """Helper class to implement meta-indexes encapsulating
    one or more (single) pandas indexes.
    
    Each pandas index must relate to a separate dimension.
    
    This class shouldn't be instantiated directly.

    """
    indexes: Frozen[Hashable, PandasIndex]
    dims: Frozen[Hashable, int]
        
    __slots__ = ("indexes", "dims")
    
    def __init__(self, indexes: Mapping[Hashable, PandasIndex]):
        # check for duplicates on the list of dimensions (building a dict first
        # would silently collapse coordinates sharing the same dimension)
        index_dims = [idx.dim for idx in indexes.values()]

        seen = set()
        dup_dims = [d for d in index_dims if d in seen or seen.add(d)]
        if dup_dims:
            raise ValueError(
                f"cannot create a {self.__class__.__name__} from coordinates "
                f"sharing common dimension(s): {dup_dims}"
            )
        
        self.indexes = Frozen(indexes)
        self.dims = Frozen({idx.dim: idx.index.size for idx in indexes.values()})
    
    @classmethod
    def from_variables(
        cls: type[T_Index],
        variables: Mapping[Any, Variable],
        *,
        options: Mapping[str, Any],
    ) -> T_Index:
        indexes = {
            k: PandasIndex.from_variables({k: v}, options={})
            for k, v in variables.items()
        }

        return cls(indexes)
    
    @classmethod
    def concat(
        cls: type[T_Index],
        indexes: Sequence[T_Index],
        dim: Hashable,
        positions: Iterable[Iterable[int]] | None = None,
    ) -> T_Index:
        new_indexes = {}

        # concatenate the wrapped indexes along `dim`; the wrapped indexes
        # relating to other dimensions are taken from the first meta-index
        for k, idx in indexes[0].indexes.items():
            if idx.dim == dim:
                new_indexes[k] = PandasIndex.concat(
                    [midx.indexes[k] for midx in indexes], dim, positions
                )
            else:
                new_indexes[k] = idx
        
        return cls(new_indexes)
    
    def create_variables(
        self, variables: Mapping[Any, Variable] | None = None
    ) -> IndexVars:

        idx_variables = {}

        for idx in self.indexes.values():
            idx_variables.update(idx.create_variables(variables))

        return idx_variables
    
    def isel(
        self: T_Index, indexers: Mapping[Any, int | slice | np.ndarray | Variable]
    ) -> T_Index | PandasIndex | None:
        new_indexes = {}
        
        for k, idx in self.indexes.items():
            if k in indexers:
                new_idx = idx.isel({k: indexers[k]})
                if new_idx is not None:
                    new_indexes[k] = new_idx
            else:
                new_indexes[k] = idx
                
        #
        # How should we deal with dropped index(es) (scalar selection)?
        # - drop the whole index?
        # - always return a MultiPandasIndex with remaining index(es)?
        # - return either a MultiPandasIndex or a PandasIndex?
        #
                
        if not len(new_indexes):
            return None
        elif len(new_indexes) == 1:
            return next(iter(new_indexes.values()))
        else:
            return type(self)(new_indexes)

    def sel(self, labels: dict[Any, Any], **kwargs) -> IndexSelResult:
        results: list[IndexSelResult] = []

        for k, idx in self.indexes.items():
            if k in labels:
                results.append(idx.sel({k: labels[k]}, **kwargs))
                
        return merge_sel_results(results)
    
    def _get_unmatched_names(self: T_Index, other: T_Index) -> set:
        return set(self.indexes).symmetric_difference(other.indexes)
    
    def equals(self: T_Index, other: T_Index) -> bool:
        # We probably don't need to check for matching coordinate names
        # as this is already done during alignment when finding matching indexes.
        # This may change in the future, though.
        # see https://github.com/pydata/xarray/issues/7002
        if self._get_unmatched_names(other):
            return False
        else:
            return all(
                [idx.equals(other.indexes[k]) for k, idx in self.indexes.items()]
            )
        
    def join(self: T_Index, other: T_Index, how: JoinOptions = "inner") -> T_Index:
        new_indexes = {}

        for k, idx in self.indexes.items():
            new_indexes[k] = idx.join(other.indexes[k], how=how)
        
        return type(self)(new_indexes)
    
    def reindex_like(self: T_Index, other: T_Index) -> dict[Hashable, Any]:
        dim_indexers = {}
        
        for k, idx in self.indexes.items():
            dim_indexers.update(idx.reindex_like(other.indexes[k]))
        
        return dim_indexers
    
    def roll(self: T_Index, shifts: Mapping[Any, int]) -> T_Index:
        new_indexes = {}
        
        for k, idx in self.indexes.items():
            if k in shifts:
                new_indexes[k] = idx.roll({k: shifts[k]})
            else:
                new_indexes[k] = idx

        return type(self)(new_indexes)
    
    def rename(
        self: T_Index,
        name_dict: Mapping[Any, Hashable],
        dims_dict: Mapping[Any, Hashable],
    ) -> T_Index:
        new_indexes = {}
        
        for k, idx in self.indexes.items():
            new_indexes[k] = idx.rename(name_dict, dims_dict)
        
        return type(self)(new_indexes)
        
    def copy(self: T_Index, deep: bool = True) -> T_Index:
        new_indexes = {}
        
        for k, idx in self.indexes.items():
            new_indexes[k] = idx.copy(deep=deep)
        
        return type(self)(new_indexes)

Issues:

  • How can custom __init__ options defined in subclasses be passed to all the type(self)(new_indexes) calls inside the MultiPandasIndex "base" class? This could be done via **kwargs passed through each method... However, mypy will certainly complain (Liskov substitution principle). One possible workaround is sketched below.
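One possible workaround (a sketch only, not part of the class above) is to funnel every re-construction through a single overridable hook, so that a subclass with extra __init__ options only has to override that one method. The _replace name and the tolerance option below are made up for illustration:

class MultiPandasIndexWithHook(MultiPandasIndex):
    # Hypothetical variant: the base-class methods would call
    # self._replace(new_indexes) instead of type(self)(new_indexes).
    def _replace(self, new_indexes):
        return type(self)(new_indexes)

    def roll(self, shifts):
        # example of a base-class method rewritten to use the hook
        new_indexes = {
            k: idx.roll({k: shifts[k]}) if k in shifts else idx
            for k, idx in self.indexes.items()
        }
        return self._replace(new_indexes)


class MyMetaIndex(MultiPandasIndexWithHook):
    # Hypothetical subclass with a custom __init__ option.
    def __init__(self, indexes, tolerance=1.0):
        super().__init__(indexes)
        self.tolerance = tolerance

    def _replace(self, new_indexes):
        # propagate the custom option whenever the index is re-created
        return type(self)(new_indexes, tolerance=self.tolerance)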

Example¶

A quick check that it works. MultiPandasIndex shouldn't be used directly in a DataArray or Dataset; in real use a concrete subclass would be attached instead, e.g. the hypothetical sketch below.
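A concrete meta-index might look like this (a hypothetical sketch; the XYIndex name and the custom sel idea are made up and not used in the cells below):

class XYIndex(MultiPandasIndex):
    # Hypothetical concrete meta-index wrapping the "x" and "y" coordinates.

    def sel(self, labels, **kwargs):
        # domain-specific logic could go here, e.g. turning a single
        # "nearest point" query into per-coordinate label indexers before
        # deferring to the wrapped PandasIndex instances
        return super().sel(labels, **kwargs)

It would be attached to existing coordinates like any other custom index, e.g. da.drop_indexes(["x", "y"]).set_xindex(["x", "y"], XYIndex). The cells below use MultiPandasIndex itself just for testing.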

da = xr.DataArray(
    np.random.uniform(size=(4, 5)),
    coords={"x": range(5), "y": range(4)},
    dims=("y", "x"),
)
da = (
    da
    .drop_indexes(["x", "y"])
    .set_xindex(["x", "y"], MultiPandasIndex)
)

da
<xarray.DataArray (y: 4, x: 5)>
array([[0.43947939, 0.87899004, 0.76420298, 0.99212782, 0.83624422],
       [0.75214201, 0.22178014, 0.0969697 , 0.74263207, 0.60629903],
       [0.91366429, 0.25963693, 0.20251133, 0.50972423, 0.3037911 ],
       [0.95073961, 0.31579758, 0.04704333, 0.81686866, 0.56483109]])
Coordinates:
  * x        (x) int64 0 1 2 3 4
  * y        (y) int64 0 1 2 3
da.xindexes
Indexes:
x: <__main__.MultiPandasIndex object at 0x16b5cb580>
y: <__main__.MultiPandasIndex object at 0x16b5cb580>

sel / isel¶

da_sel = da.sel(x=[0, 0], y=[0, 2])

da_sel
<xarray.DataArray (y: 2, x: 2)>
array([[0.43947939, 0.43947939],
       [0.91366429, 0.91366429]])
Coordinates:
  * x        (x) int64 0 0
  * y        (y) int64 0 2
da_sel.xindexes
Indexes:
x: <__main__.MultiPandasIndex object at 0x16b096880>
y: <__main__.MultiPandasIndex object at 0x16b096880>
# This is not right!
# The "x" coordinate should have no index, but it gets the same PandasIndex as "y",
# which doesn't make any sense.
# 
# To fix this, `Index.isel` should probably return a `dict[Hashable, Index]` instead of `Index | None`.
# Or alternatively, the whole `MultiPandasIndex` could be dropped when this happens.

da_sel = da.sel(x=0)

da_sel
<xarray.DataArray (y: 4)>
array([0.43947939, 0.75214201, 0.91366429, 0.95073961])
Coordinates:
  * x        int64 0
  * y        (y) int64 0 1 2 3
da_sel.xindexes
Indexes:
x: <xarray.core.indexes.PandasIndex object at 0x16bf51130>
y: <xarray.core.indexes.PandasIndex object at 0x16bf51130>
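Given the current Index.isel return type, a conservative alternative (a sketch, not what the class above does) would be to drop the whole meta-index as soon as a scalar selection drops one of the wrapped indexes, rather than letting a leftover PandasIndex be attached to both coordinates:

def isel(self, indexers):
    # alternative behaviour: if a scalar selection drops any wrapped index,
    # drop the whole meta-index (return None) instead of returning a plain
    # PandasIndex that would then be assigned to all coordinates
    new_indexes = {}
    for k, idx in self.indexes.items():
        if k in indexers:
            new_idx = idx.isel({k: indexers[k]})
            if new_idx is None:
                return None
            new_indexes[k] = new_idx
        else:
            new_indexes[k] = idx
    return type(self)(new_indexes)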

align¶

xr.align(da, da.isel(x=[0, 1], y=[2, 3]), join="outer")
(<xarray.DataArray (y: 4, x: 5)>
 array([[0.43947939, 0.87899004, 0.76420298, 0.99212782, 0.83624422],
        [0.75214201, 0.22178014, 0.0969697 , 0.74263207, 0.60629903],
        [0.91366429, 0.25963693, 0.20251133, 0.50972423, 0.3037911 ],
        [0.95073961, 0.31579758, 0.04704333, 0.81686866, 0.56483109]])
 Coordinates:
   * x        (x) int64 0 1 2 3 4
   * y        (y) int64 0 1 2 3,
 <xarray.DataArray (y: 4, x: 5)>
 array([[       nan,        nan,        nan,        nan,        nan],
        [       nan,        nan,        nan,        nan,        nan],
        [0.91366429, 0.25963693,        nan,        nan,        nan],
        [0.95073961, 0.31579758,        nan,        nan,        nan]])
 Coordinates:
   * x        (x) int64 0 1 2 3 4
   * y        (y) int64 0 1 2 3)

roll¶

da_roll = da.roll({"x": 2, "y": -1}, roll_coords=True)

da_roll
<xarray.DataArray (y: 4, x: 5)>
array([[0.74263207, 0.60629903, 0.75214201, 0.22178014, 0.0969697 ],
       [0.50972423, 0.3037911 , 0.91366429, 0.25963693, 0.20251133],
       [0.81686866, 0.56483109, 0.95073961, 0.31579758, 0.04704333],
       [0.99212782, 0.83624422, 0.43947939, 0.87899004, 0.76420298]])
Coordinates:
  * x        (x) int64 3 4 0 1 2
  * y        (y) int64 1 2 3 0
da_roll.xindexes
Indexes:
x: <__main__.MultiPandasIndex object at 0x16bf6bb80>
y: <__main__.MultiPandasIndex object at 0x16bf6bb80>

rename¶

da_renamed = da.to_dataset(name="foo").rename({"x": "z"}).to_array()

da_renamed
<xarray.DataArray (variable: 1, y: 4, z: 5)>
array([[[0.43947939, 0.87899004, 0.76420298, 0.99212782, 0.83624422],
        [0.75214201, 0.22178014, 0.0969697 , 0.74263207, 0.60629903],
        [0.91366429, 0.25963693, 0.20251133, 0.50972423, 0.3037911 ],
        [0.95073961, 0.31579758, 0.04704333, 0.81686866, 0.56483109]]])
Coordinates:
  * z         (z) int64 0 1 2 3 4
  * y         (y) int64 0 1 2 3
  * variable  (variable) object 'foo'
da_renamed.xindexes
Indexes:
z: <__main__.MultiPandasIndex object at 0x10e512780>
y: <__main__.MultiPandasIndex object at 0x10e512780>
variable: <xarray.core.indexes.PandasIndex object at 0x16bf3be50>

concat¶

1
2
3
4
5
# Does Xarray support concat along a dimension that is part of a multi-dimension index? It looks like it doesn't.
# It might require reviewing some alignment rules...
# I have no idea how easy/hard it would be to support that.

xr.concat([da.isel(x=[0]), da.isel(x=[1])], dim="x")
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/var/folders/xd/3ls911kd6_n2wphwwd74b1dc0000gn/T/ipykernel_14023/2421416148.py in <module>
      3 # I've no idea how easy/hard would it be to support that.
      4 
----> 5 xr.concat([da.isel(x=[0]), da.isel(x=[1])], dim="x")

~/Git/github/benbovy/xarray/xarray/core/concat.py in concat(objs, dim, data_vars, coords, compat, positions, fill_value, join, combine_attrs)
    229 
    230     if isinstance(first_obj, DataArray):
--> 231         return _dataarray_concat(
    232             objs,
    233             dim=dim,

~/Git/github/benbovy/xarray/xarray/core/concat.py in _dataarray_concat(arrays, dim, data_vars, coords, compat, positions, fill_value, join, combine_attrs)
    655         datasets.append(arr._to_temp_dataset())
    656 
--> 657     ds = _dataset_concat(
    658         datasets,
    659         dim,

~/Git/github/benbovy/xarray/xarray/core/concat.py in _dataset_concat(datasets, dim, data_vars, coords, compat, positions, fill_value, join, combine_attrs)
    464     datasets = [ds.copy() for ds in datasets]
    465     datasets = list(
--> 466         align(*datasets, join=join, copy=False, exclude=[dim], fill_value=fill_value)
    467     )
    468 

~/Git/github/benbovy/xarray/xarray/core/alignment.py in align(join, copy, indexes, exclude, fill_value, *objects)
    762         fill_value=fill_value,
    763     )
--> 764     aligner.align()
    765     return aligner.results
    766 

~/Git/github/benbovy/xarray/xarray/core/alignment.py in align(self)
    546             self.results = (obj.copy(deep=self.copy),)
    547 
--> 548         self.find_matching_indexes()
    549         self.find_matching_unindexed_dims()
    550         self.assert_no_index_conflict()

~/Git/github/benbovy/xarray/xarray/core/alignment.py in find_matching_indexes(self)
    253 
    254         for obj in self.objects:
--> 255             obj_indexes, obj_index_vars = self._normalize_indexes(obj.xindexes)
    256             objects_matching_indexes.append(obj_indexes)
    257             for key, idx in obj_indexes.items():

~/Git/github/benbovy/xarray/xarray/core/alignment.py in _normalize_indexes(self, indexes)
    229                 excl_dims_str = ", ".join(str(d) for d in exclude_dims)
    230                 incl_dims_str = ", ".join(str(d) for d in all_dims - exclude_dims)
--> 231                 raise ValueError(
    232                     f"cannot exclude dimension(s) {excl_dims_str} from alignment because "
    233                     "these are used by an index together with non-excluded dimensions "

ValueError: cannot exclude dimension(s) x from alignment because these are used by an index together with non-excluded dimensions y
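
Until this is supported, one possible workaround (a sketch; it assumes that losing index-based alignment during the concat step is acceptable) is to drop the meta-index before concatenating and re-attach it afterwards:

# drop the MultiPandasIndex from both pieces so that alignment no longer
# involves the excluded "x" dimension, then re-create the index on the result
parts = [obj.drop_indexes(["x", "y"]) for obj in (da.isel(x=[0]), da.isel(x=[1]))]
da_concat = xr.concat(parts, dim="x").set_xindex(["x", "y"], MultiPandasIndex)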