Xarray-compatible "functional" index (demo)¶

Notes:

This currently works with https://github.com/pydata/xarray/pull/6971 only!

WIP documentation on implementing custom indexes: https://github.com/pydata/xarray/pull/6975

1
2
3
4
5
import numpy as np
import pandas as pd
import xarray as xr

from xarray.core.indexes import IndexSelResult
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class FunctionalIndex(xr.Index):
    """Basic 1-dimensional index with linear function."""
    
    def __init__(self, pixel_data, slope, intercept, dim):
        self.pixel_data = pixel_data
        self.slope = slope
        self.intercept = intercept
        
        wmin = pixel_data[0] * slope + intercept
        wmax = pixel_data[-1] * slope + intercept
        self.extent = (wmin, wmax)
        
        self.dim = dim
    
    @classmethod
    def from_variables(cls, variables, options):
        # We should check the validity of the given variables!

        var = next(iter(variables.values()))
        return cls(
            var.values,
            var.attrs["slope"],
            var.attrs["intercept"],
            var.dims[0]
        )
    
    def sel(self, labels):
        # This implementation only works with slices!

        label = next(iter(labels.values()))
        
        if not isinstance(label, slice):
            raise TypeError("Selection using this index only works with slices")
        if label.start < self.extent[0] or label.stop > self.extent[1]:
            raise ValueError("Out of bounds selection")
            
        def convert(val):
            return int(self.slope * val + self.intercept)
        
        int_slice = slice(convert(label.start), convert(label.stop))
        
        return IndexSelResult({self.dim: int_slice})
1
2
3
4
5
6
7
wcs_attrs = {"slope": 0.5, "intercept": 0.0}

da = xr.DataArray(
    np.random.uniform(size=50),
    coords={"x_wcs": ("x", range(50), wcs_attrs)},
    dims="x"
)
1
da
<xarray.DataArray (x: 50)>
array([0.55286327, 0.56095464, 0.16611634, 0.06475896, 0.7124942 ,
       0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805,
       0.48883895, 0.27718464, 0.12142079, 0.98683591, 0.87797165,
       0.70290819, 0.00964021, 0.83901826, 0.95960231, 0.23635901,
       0.67194208, 0.7097981 , 0.1988284 , 0.29785769, 0.6247646 ,
       0.32563944, 0.20781743, 0.51757803, 0.04724169, 0.95930763,
       0.55790878, 0.00819702, 0.5159818 , 0.8344841 , 0.66735477,
       0.36696268, 0.87265844, 0.43291816, 0.65711125, 0.36713875,
       0.23912852, 0.32327912, 0.83723518, 0.51004308, 0.0257541 ,
       0.18629893, 0.47381572, 0.12414299, 0.9304311 , 0.03163012])
Coordinates:
    x_wcs    (x) int64 0 1 2 3 4 5 6 7 8 9 10 ... 40 41 42 43 44 45 46 47 48 49
Dimensions without coordinates: x
xarray.DataArray
  • x: 50
  • 0.5529 0.561 0.1661 0.06476 0.7125 ... 0.4738 0.1241 0.9304 0.03163
    array([0.55286327, 0.56095464, 0.16611634, 0.06475896, 0.7124942 ,
           0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805,
           0.48883895, 0.27718464, 0.12142079, 0.98683591, 0.87797165,
           0.70290819, 0.00964021, 0.83901826, 0.95960231, 0.23635901,
           0.67194208, 0.7097981 , 0.1988284 , 0.29785769, 0.6247646 ,
           0.32563944, 0.20781743, 0.51757803, 0.04724169, 0.95930763,
           0.55790878, 0.00819702, 0.5159818 , 0.8344841 , 0.66735477,
           0.36696268, 0.87265844, 0.43291816, 0.65711125, 0.36713875,
           0.23912852, 0.32327912, 0.83723518, 0.51004308, 0.0257541 ,
           0.18629893, 0.47381572, 0.12414299, 0.9304311 , 0.03163012])
    • x_wcs
      (x)
      int64
      0 1 2 3 4 5 6 ... 44 45 46 47 48 49
      slope :
      0.5
      intercept :
      0.0
      array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
             17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
             34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49])
1
da.xindexes
Indexes:
    *empty*
1
2
3
da_indexed = da.set_xindex("x_wcs", FunctionalIndex)

da_indexed
<xarray.DataArray (x: 50)>
array([0.55286327, 0.56095464, 0.16611634, 0.06475896, 0.7124942 ,
       0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805,
       0.48883895, 0.27718464, 0.12142079, 0.98683591, 0.87797165,
       0.70290819, 0.00964021, 0.83901826, 0.95960231, 0.23635901,
       0.67194208, 0.7097981 , 0.1988284 , 0.29785769, 0.6247646 ,
       0.32563944, 0.20781743, 0.51757803, 0.04724169, 0.95930763,
       0.55790878, 0.00819702, 0.5159818 , 0.8344841 , 0.66735477,
       0.36696268, 0.87265844, 0.43291816, 0.65711125, 0.36713875,
       0.23912852, 0.32327912, 0.83723518, 0.51004308, 0.0257541 ,
       0.18629893, 0.47381572, 0.12414299, 0.9304311 , 0.03163012])
Coordinates:
  * x_wcs    (x) int64 0 1 2 3 4 5 6 7 8 9 10 ... 40 41 42 43 44 45 46 47 48 49
Dimensions without coordinates: x
xarray.DataArray
  • x: 50
  • 0.5529 0.561 0.1661 0.06476 0.7125 ... 0.4738 0.1241 0.9304 0.03163
    array([0.55286327, 0.56095464, 0.16611634, 0.06475896, 0.7124942 ,
           0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805,
           0.48883895, 0.27718464, 0.12142079, 0.98683591, 0.87797165,
           0.70290819, 0.00964021, 0.83901826, 0.95960231, 0.23635901,
           0.67194208, 0.7097981 , 0.1988284 , 0.29785769, 0.6247646 ,
           0.32563944, 0.20781743, 0.51757803, 0.04724169, 0.95930763,
           0.55790878, 0.00819702, 0.5159818 , 0.8344841 , 0.66735477,
           0.36696268, 0.87265844, 0.43291816, 0.65711125, 0.36713875,
           0.23912852, 0.32327912, 0.83723518, 0.51004308, 0.0257541 ,
           0.18629893, 0.47381572, 0.12414299, 0.9304311 , 0.03163012])
    • x_wcs
      (x)
      int64
      0 1 2 3 4 5 6 ... 44 45 46 47 48 49
      slope :
      0.5
      intercept :
      0.0
      array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
             17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
             34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49])
1
da_indexed.xindexes
Indexes:
x_wcs: <__main__.FunctionalIndex object at 0x164a66280>
1
da_indexed.xindexes["x_wcs"]
<__main__.FunctionalIndex at 0x164a66280>
1
2
3
4
# selection with "world" coordinate labels works!

da_selected = da_indexed.sel(x_wcs=slice(10, 20))
da_selected
<xarray.DataArray (x: 5)>
array([0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805])
Coordinates:
    x_wcs    (x) int64 5 6 7 8 9
Dimensions without coordinates: x
xarray.DataArray
  • x: 5
  • 0.4955 0.5148 0.8615 0.4603 0.08337
    array([0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805])
    • x_wcs
      (x)
      int64
      5 6 7 8 9
      slope :
      0.5
      intercept :
      0.0
      array([5, 6, 7, 8, 9])
1
2
3
# problem: index not propagated

da_selected.xindexes
Indexes:
    *empty*
1
 
1
 
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# implement `create_variables` and `isel`
# in order to propagate the index in the selected dataset!


class FunctionalIndex(xr.Index):
    """Basic 1-dimensional index with linear function."""
    
    def __init__(self, pixel_data, slope, intercept, dim):
        self.pixel_data = pixel_data
        self.slope = slope
        self.intercept = intercept
        
        wmin = pixel_data[0] * slope + intercept
        wmax = pixel_data[-1] * slope + intercept
        self.extent = (wmin, wmax)
        
        self.dim = dim

    @classmethod
    def from_variables(cls, variables, options):
        var = next(iter(variables.values()))
        return cls(
            var.values,
            var.attrs["slope"],
            var.attrs["intercept"],
            var.dims[0]
        )
    
    def create_variables(self, variables):
        name, var = next(iter(variables.items()))
        
        attrs = {"slope": self.slope, "intercept": self.intercept}
        
        new_var = xr.IndexVariable(self.dim, self.pixel_data, attrs=attrs)
        return {name: new_var}
    
    def isel(self, indexers):
        indxr = indexers[self.dim]
        
        new_pixel_data = self.pixel_data[indxr]
        
        if isinstance(indxr, slice):
            return type(self)(new_pixel_data, self.slope, self.intercept, self.dim)
        elif len(indxr) > 1:
            return type(self)(new_pixel_data, self.slope, self.intercept, self.dim)
        else:
            return None
    
    def sel(self, labels):
        label = next(iter(labels.values()))
        
        if not isinstance(label, slice):
            raise TypeError("Selection using this index only works with slices")
        if label.start < self.extent[0] or label.stop > self.extent[1]:
            raise ValueError("Out of bounds selection")
            
        def convert(val):
            return int(self.slope * val + self.intercept)
        
        int_slice = slice(convert(label.start), convert(label.stop))
        
        return IndexSelResult({self.dim: int_slice})
1
da_indexed2 = da.set_xindex("x_wcs", FunctionalIndex)
1
2
da_selected2 = da_indexed2.sel(x_wcs=slice(10, 20))
da_selected2
<xarray.DataArray (x: 5)>
array([0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805])
Coordinates:
  * x_wcs    (x) int64 5 6 7 8 9
Dimensions without coordinates: x
xarray.DataArray
  • x: 5
  • 0.4955 0.5148 0.8615 0.4603 0.08337
    array([0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805])
    • x_wcs
      (x)
      int64
      5 6 7 8 9
      slope :
      0.5
      intercept :
      0.0
      array([5, 6, 7, 8, 9])
1
da_selected2.xindexes
Indexes:
x_wcs: <__main__.FunctionalIndex object at 0x164a552e0>
1
 
1
 
1