{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4f4d7519-6942-4ae4-a2f8-27aa167b6050",
   "metadata": {},
   "source": [
    "## Xarray: propagate bounds coordinate with an IntervalIndex in DataArray\n",
    "\n",
    "It should also work with any other coordinate associated with an Xarray index that shares at least one dimension with the DataArray's dimensions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "779c11f2-2ec2-4c16-98e6-4b601bdf6c61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import xarray as xr\n",
    "\n",
    "xr.set_options(display_expand_indexes=True);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e5dafab-e774-4380-9f2c-c37a9f3dd262",
   "metadata": {},
   "source": [
    "Example of an Xarray IntervalIndex that can be associated with a dimension coordinate and its CF bounds coordinate companion."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d5c647f4-b12f-4b22-9351-c201191a452c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections.abc import Hashable\n",
    "\n",
    "from xarray.core.indexes import Index, PandasIndex\n",
    "from xarray import Variable\n",
    "\n",
    "\n",
    "class IntervalIndex(Index):\n",
    "    \"\"\"Xarray index wrapping a :class:`pandas.IntervalIndex`.\n",
    "\n",
    "    It is built from two coordinates that share one dimension ``dim``:\n",
    "\n",
    "    - a 1-D coordinate named ``dim``, exposed as the interval mid-points;\n",
    "    - a 2-D CF-style bounds coordinate, exposed as the stacked\n",
    "      ``[left, right]`` interval edges along ``bounds_dim``.\n",
    "    \"\"\"\n",
    "\n",
    "    # adapted from https://github.com/dcherian/xindexes/blob/main/interval-array.ipynb\n",
    "\n",
    "    def __init__(self, index: PandasIndex, bounds_name: Hashable, bounds_dim: str):\n",
    "        assert isinstance(index.index, pd.IntervalIndex)\n",
    "        self._index = index\n",
    "        self._bounds_name = bounds_name\n",
    "        self._bounds_dim = bounds_dim\n",
    "\n",
    "    @classmethod\n",
    "    def from_variables(cls, variables, options):\n",
    "        \"\"\"Build the index from one 1-D coordinate and one 2-D bounds coordinate.\n",
    "\n",
    "        Raises ValueError if either of the two coordinates is missing.\n",
    "        \"\"\"\n",
    "        assert len(variables) == 2\n",
    "\n",
    "        dim = bounds_name = bounds = None\n",
    "        for name, var in variables.items():\n",
    "            if var.ndim == 2:\n",
    "                bounds_name, bounds = name, var\n",
    "            elif var.ndim == 1:\n",
    "                dim = name\n",
    "\n",
    "        if dim is None or bounds is None:\n",
    "            raise ValueError(\n",
    "                \"IntervalIndex requires one 1-dimensional coordinate and \"\n",
    "                \"one 2-dimensional bounds coordinate\"\n",
    "            )\n",
    "\n",
    "        # move ``dim`` last so that the first axis holds the (left, right) edges\n",
    "        bounds = bounds.transpose(..., dim)\n",
    "        left, right = bounds.data.tolist()\n",
    "        index = PandasIndex(pd.IntervalIndex.from_arrays(left, right), dim)\n",
    "        # {dim}, not set(dim): set(\"time\") would split the name into letters\n",
    "        bounds_dim = (set(bounds.dims) - {dim}).pop()\n",
    "\n",
    "        return cls(index, bounds_name, bounds_dim)\n",
    "\n",
    "    @classmethod\n",
    "    def concat(cls, indexes, dim, positions=None):\n",
    "        \"\"\"Concatenate IntervalIndex objects along ``dim``.\n",
    "\n",
    "        Raises ValueError if ``indexes`` is empty or if the indexes do not all\n",
    "        share the same bounds coordinate name and bounds dimension.\n",
    "        \"\"\"\n",
    "        indexes = list(indexes)\n",
    "        if not indexes:\n",
    "            raise ValueError(\"Cannot concatenate an empty sequence of IntervalIndex objects\")\n",
    "\n",
    "        bounds_name0 = indexes[0]._bounds_name\n",
    "        bounds_dim0 = indexes[0]._bounds_dim\n",
    "        if any(idx._bounds_name != bounds_name0 or idx._bounds_dim != bounds_dim0 for idx in indexes):\n",
    "            raise ValueError(\n",
    "                f\"Cannot concatenate along dimension {dim!r} indexes with different \"\n",
    "                \"boundary coordinate or dimension names\"\n",
    "            )\n",
    "\n",
    "        # ``cls`` is the class here (there is no ``self._index``):\n",
    "        # concatenate the wrapped PandasIndex objects instead\n",
    "        new_index = PandasIndex.concat([idx._index for idx in indexes], dim, positions=positions)\n",
    "        return cls(new_index, bounds_name0, bounds_dim0)\n",
    "\n",
    "    def create_variables(self, variables=None):\n",
    "        \"\"\"Return the mid-point coordinate (named ``dim``) and the bounds coordinate.\n",
    "\n",
    "        Attributes of matching entries in ``variables``, when present, are reused.\n",
    "        \"\"\"\n",
    "        variables = {} if variables is None else variables\n",
    "        dim = self._index.dim\n",
    "        intervals = self._index.index\n",
    "\n",
    "        def _attrs(name):\n",
    "            var = variables.get(name)\n",
    "            return {} if var is None else var.attrs\n",
    "\n",
    "        bounds_var = Variable(\n",
    "            dims=(self._bounds_dim, dim),\n",
    "            data=np.stack([intervals.left, intervals.right], axis=0),\n",
    "            attrs=_attrs(self._bounds_name),\n",
    "        )\n",
    "        mid_var = Variable(\n",
    "            dims=(dim,),\n",
    "            data=intervals.mid,\n",
    "            attrs=_attrs(dim),\n",
    "        )\n",
    "\n",
    "        return {dim: mid_var, self._bounds_name: bounds_var}\n",
    "\n",
    "    def validate_dataarray_coord(self, name, var, dims):\n",
    "        \"\"\"Raise if the 1-D interval coordinate's dimension is not a DataArray dimension.\"\"\"\n",
    "        # checking the \"mid\" coordinate is enough here\n",
    "        if var.ndim == 1 and var.dims[0] not in dims:\n",
    "            raise xr.CoordinateValidationError(\n",
    "                f\"interval coordinate {name!r} has dimensions {var.dims}, but these \"\n",
    "                \"are not a subset of the DataArray \"\n",
    "                f\"dimensions {tuple(dims)}\"\n",
    "            )\n",
    "\n",
    "    def sel(self, labels, **kwargs):\n",
    "        \"\"\"Label-based selection, delegated to the wrapped PandasIndex.\"\"\"\n",
    "        return self._index.sel(labels, **kwargs)\n",
    "\n",
    "    def isel(self, indexers):\n",
    "        \"\"\"Positional selection; returns None (the index is dropped) when it cannot be kept.\"\"\"\n",
    "        if self._bounds_dim in indexers:\n",
    "            # indexing along the bounds dimension breaks the (left, right) pairs;\n",
    "            # PandasIndex.isel would also raise KeyError if ``dim`` is absent\n",
    "            return None\n",
    "        new_index = self._index.isel(indexers)\n",
    "        if new_index is None:\n",
    "            return None\n",
    "        return type(self)(new_index, self._bounds_name, self._bounds_dim)\n",
    "\n",
    "    def roll(self, shifts):\n",
    "        \"\"\"Roll the intervals along ``dim``; returns None (the index is dropped) if\n",
    "        the bounds dimension is rolled.\n",
    "        \"\"\"\n",
    "        if self._bounds_dim in shifts:\n",
    "            # rolling along the bounds dimension may swap left and right edges\n",
    "            return None\n",
    "        new_index = self._index.roll(shifts)\n",
    "        return type(self)(new_index, self._bounds_name, self._bounds_dim)\n",
    "\n",
    "    def rename(self, name_dict, dims_dict):\n",
    "        \"\"\"Rename the wrapped index, the bounds coordinate and the bounds dimension.\"\"\"\n",
    "        new_index = self._index.rename(name_dict, dims_dict)\n",
    "\n",
    "        bounds_name = name_dict.get(self._bounds_name, self._bounds_name)\n",
    "        bounds_dim = dims_dict.get(self._bounds_dim, self._bounds_dim)\n",
    "\n",
    "        return type(self)(new_index, bounds_name, bounds_dim)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return repr(self._index)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4ff39b9-b36d-444f-add9-cfd77cea8b78",
   "metadata": {},
   "source": [
    "Create an example dataset with an IntervalIndex"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1c0ec9e9-234d-4a19-9a9c-23932f14d8a8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "
<xarray.Dataset> Size: 128B\n", "Dimensions: (x: 4, bnds: 2)\n", "Coordinates:\n", " * x (x) float64 32B 1.0 2.0 3.0 4.0\n", " * x_bounds (bnds, x) float64 64B 0.5 1.5 2.5 3.5 1.5 2.5 3.5 4.5\n", "Dimensions without coordinates: bnds\n", "Data variables:\n", " foo (x) int64 32B 1 2 3 4\n", "Indexes:\n", " ┌ x IntervalIndex\n", " └ x_bounds
<xarray.DataArray 'foo' (x: 4)> Size: 32B\n", "array([1, 2, 3, 4])\n", "Coordinates:\n", " * x (x) float64 32B 1.0 2.0 3.0 4.0\n", " * x_bounds (bnds, x) float64 64B 0.5 1.5 2.5 3.5 1.5 2.5 3.5 4.5\n", "Indexes:\n", " ┌ x IntervalIndex\n", " └ x_bounds
<xarray.DataArray 'x' (x: 4)> Size: 32B\n", "array([1., 2., 3., 4.])\n", "Coordinates:\n", " * x (x) float64 32B 1.0 2.0 3.0 4.0\n", " * x_bounds (bnds, x) float64 64B 0.5 1.5 2.5 3.5 1.5 2.5 3.5 4.5\n", "Indexes:\n", " ┌ x IntervalIndex\n", " └ x_bounds
<xarray.DataArray 'foo' (x: 4)> Size: 32B\n", "array([1, 2, 3, 4])\n", "Coordinates:\n", " x (x) float64 32B 1.0 2.0 3.0 4.0
<xarray.DataArray 'foo' (x: 4)> Size: 32B\n", "array([1, 2, 3, 4])\n", "Coordinates:\n", " * x (x) float64 32B 1.0 2.0 3.0 4.0\n", " * x_bounds (bnds, x) float64 64B 0.5 1.5 2.5 3.5 1.5 2.5 3.5 4.5\n", "Indexes:\n", " ┌ x IntervalIndex\n", " └ x_bounds
<xarray.DataArray (x: 4)> Size: 32B\n", "array([0, 1, 2, 3])\n", "Coordinates:\n", " * x (x) float64 32B 1.0 2.0 3.0 4.0\n", " * x_bounds (bnds, x) float64 64B 0.5 1.5 2.5 3.5 1.5 2.5 3.5 4.5\n", "Indexes:\n", " ┌ x IntervalIndex\n", " └ x_bounds
<xarray.DataArray 'x_bounds' (bnds: 2, x: 4)> Size: 64B\n", "array([[0.5, 1.5, 2.5, 3.5],\n", " [1.5, 2.5, 3.5, 4.5]])\n", "Coordinates:\n", " * x (x) float64 32B 1.0 2.0 3.0 4.0\n", " * x_bounds (bnds, x) float64 64B 0.5 1.5 2.5 3.5 1.5 2.5 3.5 4.5\n", "Dimensions without coordinates: bnds\n", "Indexes:\n", " ┌ x IntervalIndex\n", " └ x_bounds
<xarray.Dataset> Size: 128B\n", "Dimensions: (x: 4, bnds: 2)\n", "Coordinates:\n", " * x (x) float64 32B 1.0 2.0 3.0 4.0\n", " * x_bounds (bnds, x) float64 64B 0.5 1.5 2.5 3.5 1.5 2.5 3.5 4.5\n", "Dimensions without coordinates: bnds\n", "Data variables:\n", " foo (x) int64 32B 1 2 3 4\n", "Indexes:\n", " ┌ x IntervalIndex\n", " └ x_bounds
<xarray.DataArray 'foo' (x: 2)> Size: 16B\n", "array([1, 3])\n", "Coordinates:\n", " * x (x) float64 16B 1.0 3.0\n", " * x_bounds (bnds, x) float64 32B 0.5 2.5 1.5 3.5\n", "Indexes:\n", " ┌ x IntervalIndex\n", " └ x_bounds
<xarray.DataArray 'foo' (x: 4)> Size: 32B\n", "array([2, 4, 6, 8])\n", "Coordinates:\n", " * x (x) float64 32B 1.0 2.0 3.0 4.0\n", " * x_bounds (bnds, x) float64 64B 0.5 1.5 2.5 3.5 1.5 2.5 3.5 4.5\n", "Indexes:\n", " ┌ x IntervalIndex\n", " └ x_bounds
<xarray.DataArray 'foo' (x: 4)> Size: 32B\n", "array([4, 1, 2, 3])\n", "Coordinates:\n", " * x (x) float64 32B 4.0 1.0 2.0 3.0\n", " * x_bounds (bnds, x) float64 64B 3.5 0.5 1.5 2.5 4.5 1.5 2.5 3.5\n", "Indexes:\n", " ┌ x IntervalIndex\n", " └ x_bounds
<xarray.DataArray 'foo' (y: 4)> Size: 32B\n", "array([1, 2, 3, 4])\n", "Coordinates:\n", " * y (y) float64 32B 1.0 2.0 3.0 4.0\n", " * y_bounds (bnds, y) float64 64B 0.5 1.5 2.5 3.5 1.5 2.5 3.5 4.5\n", "Indexes:\n", " ┌ y IntervalIndex\n", " └ y_bounds
<xarray.Dataset> Size: 128B\n", "Dimensions: (x: 4, bnds: 2)\n", "Coordinates:\n", " * x (x) float64 32B 1.0 2.0 3.0 4.0\n", " * x_bounds (bnds, x) float64 64B 0.5 1.5 2.5 3.5 1.5 2.5 3.5 4.5\n", "Dimensions without coordinates: bnds\n", "Data variables:\n", " foo (x) int64 32B 1 2 3 4\n", "Indexes:\n", " ┌ x IntervalIndex\n", " └ x_bounds