{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Xarray-compatible \"functional\" index (demo)\n",
    "\n",
    "Notes:\n",
    "\n",
    "This currently works with https://github.com/pydata/xarray/pull/6971 only!\n",
    "\n",
    "WIP documentation on implementing custom indexes: https://github.com/pydata/xarray/pull/6975\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import xarray as xr\n",
    "\n",
    "from xarray.core.indexes import IndexSelResult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FunctionalIndex(xr.Index):\n",
    "    \"\"\"Basic 1-dimensional index with linear function.\n",
    "\n",
    "    World labels and pixel values are related through\n",
    "    ``world = slope * pixel + intercept``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    pixel_data : array-like\n",
    "        1-d pixel values along ``dim``. Assumed to be sorted in increasing\n",
    "        order, since ``sel`` locates pixels with ``np.searchsorted``.\n",
    "    slope, intercept : scalar\n",
    "        Coefficients of the pixel -> world linear function.\n",
    "    dim : hashable\n",
    "        Name of the dimension this index applies to.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, pixel_data, slope, intercept, dim):\n",
    "        self.pixel_data = pixel_data\n",
    "        self.slope = slope\n",
    "        self.intercept = intercept\n",
    "\n",
    "        # World values of the first and last pixels; `sel` uses them\n",
    "        # for bounds checking.\n",
    "        wmin = pixel_data[0] * slope + intercept\n",
    "        wmax = pixel_data[-1] * slope + intercept\n",
    "        self.extent = (wmin, wmax)\n",
    "\n",
    "        self.dim = dim\n",
    "\n",
    "    @classmethod\n",
    "    def from_variables(cls, variables, options):\n",
    "        \"\"\"Build the index from a single 1-d coordinate variable.\n",
    "\n",
    "        The variable values are used as pixel data; the coefficients of the\n",
    "        linear function are read from its ``slope`` and ``intercept``\n",
    "        attributes. ``options`` is not used.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError\n",
    "            If ``variables`` does not hold exactly one non-empty 1-d\n",
    "            variable carrying both attributes.\n",
    "        \"\"\"\n",
    "        if len(variables) != 1:\n",
    "            raise ValueError(\n",
    "                f\"FunctionalIndex expects exactly one coordinate variable, got {len(variables)}\"\n",
    "            )\n",
    "\n",
    "        name, var = next(iter(variables.items()))\n",
    "        if var.ndim != 1 or var.size == 0:\n",
    "            raise ValueError(f\"Variable {name!r} must be 1-dimensional and non-empty\")\n",
    "\n",
    "        missing = [key for key in (\"slope\", \"intercept\") if key not in var.attrs]\n",
    "        if missing:\n",
    "            raise ValueError(f\"Variable {name!r} is missing attribute(s): {missing}\")\n",
    "\n",
    "        return cls(\n",
    "            var.values,\n",
    "            var.attrs[\"slope\"],\n",
    "            var.attrs[\"intercept\"],\n",
    "            var.dims[0]\n",
    "        )\n",
    "\n",
    "    def sel(self, labels):\n",
    "        \"\"\"Convert a slice of world labels into a positional slice.\n",
    "\n",
    "        Each bound is mapped to its nearest pixel value by inverting the\n",
    "        linear function, then to the position of that pixel in\n",
    "        ``pixel_data``. The start position is included and the stop position\n",
    "        is excluded. A bound left to ``None`` keeps that side open.\n",
    "\n",
    "        NOTE: ``label.step`` is ignored.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        TypeError\n",
    "            If the label is not a slice.\n",
    "        ValueError\n",
    "            If a bound lies outside of ``self.extent`` or if the slope is\n",
    "            zero (the function cannot be inverted).\n",
    "        \"\"\"\n",
    "        # This implementation only works with slices!\n",
    "        label = next(iter(labels.values()))\n",
    "\n",
    "        if not isinstance(label, slice):\n",
    "            raise TypeError(\"Selection using this index only works with slices\")\n",
    "\n",
    "        start, stop = label.start, label.stop\n",
    "        # Only compare the bounds that are actually given (None = open end).\n",
    "        if (start is not None and start < self.extent[0]) or (\n",
    "            stop is not None and stop > self.extent[1]\n",
    "        ):\n",
    "            raise ValueError(\"Out of bounds selection\")\n",
    "        if self.slope == 0:\n",
    "            raise ValueError(\"Cannot select world labels with a zero slope\")\n",
    "\n",
    "        def convert(val):\n",
    "            if val is None:\n",
    "                return None\n",
    "            # Invert world = slope * pixel + intercept. Rounding to the\n",
    "            # nearest pixel keeps floating-point noise from shifting the\n",
    "            # result by one.\n",
    "            pixel = int(round((val - self.intercept) / self.slope))\n",
    "            # Pixel value -> position (pixel data may not start at 0).\n",
    "            return int(np.searchsorted(self.pixel_data, pixel))\n",
    "\n",
    "        int_slice = slice(convert(start), convert(stop))\n",
    "\n",
    "        return IndexSelResult({self.dim: int_slice})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Coefficients of the linear pixel -> world function, stored as attributes\n",
    "# of the coordinate so that `FunctionalIndex.from_variables` can read them.\n",
    "wcs_attrs = {\"slope\": 0.5, \"intercept\": 0.0}\n",
    "\n",
    "da = xr.DataArray(\n",
    "    np.random.uniform(size=50),\n",
    "    coords={\"x_wcs\": (\"x\", range(50), wcs_attrs)},\n",
    "    dims=\"x\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (x: 50)>\n",
       "array([0.55286327, 0.56095464, 0.16611634, 0.06475896, 0.7124942 ,\n",
       "       0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805,\n",
       "       0.48883895, 0.27718464, 0.12142079, 0.98683591, 0.87797165,\n",
       "       0.70290819, 0.00964021, 0.83901826, 0.95960231, 0.23635901,\n",
       "       0.67194208, 0.7097981 , 0.1988284 , 0.29785769, 0.6247646 ,\n",
       "       0.32563944, 0.20781743, 0.51757803, 0.04724169, 0.95930763,\n",
       "       0.55790878, 0.00819702, 0.5159818 , 0.8344841 , 0.66735477,\n",
       "       0.36696268, 0.87265844, 0.43291816, 0.65711125, 0.36713875,\n",
       "       0.23912852, 0.32327912, 0.83723518, 0.51004308, 0.0257541 ,\n",
       "       0.18629893, 0.47381572, 0.12414299, 0.9304311 , 0.03163012])\n",
       "Coordinates:\n",
       "    x_wcs    (x) int64 0 1 2 3 4 5 6 7 8 9 10 ... 40 41 42 43 44 45 46 47 48 49\n",
       "Dimensions without coordinates: x
" ], "text/plain": [ "\n", "array([0.55286327, 0.56095464, 0.16611634, 0.06475896, 0.7124942 ,\n", " 0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805,\n", " 0.48883895, 0.27718464, 0.12142079, 0.98683591, 0.87797165,\n", " 0.70290819, 0.00964021, 0.83901826, 0.95960231, 0.23635901,\n", " 0.67194208, 0.7097981 , 0.1988284 , 0.29785769, 0.6247646 ,\n", " 0.32563944, 0.20781743, 0.51757803, 0.04724169, 0.95930763,\n", " 0.55790878, 0.00819702, 0.5159818 , 0.8344841 , 0.66735477,\n", " 0.36696268, 0.87265844, 0.43291816, 0.65711125, 0.36713875,\n", " 0.23912852, 0.32327912, 0.83723518, 0.51004308, 0.0257541 ,\n", " 0.18629893, 0.47381572, 0.12414299, 0.9304311 , 0.03163012])\n", "Coordinates:\n", " x_wcs (x) int64 0 1 2 3 4 5 6 7 8 9 10 ... 40 41 42 43 44 45 46 47 48 49\n", "Dimensions without coordinates: x" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Indexes:\n", " *empty*" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da.xindexes" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (x: 50)>\n",
       "array([0.55286327, 0.56095464, 0.16611634, 0.06475896, 0.7124942 ,\n",
       "       0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805,\n",
       "       0.48883895, 0.27718464, 0.12142079, 0.98683591, 0.87797165,\n",
       "       0.70290819, 0.00964021, 0.83901826, 0.95960231, 0.23635901,\n",
       "       0.67194208, 0.7097981 , 0.1988284 , 0.29785769, 0.6247646 ,\n",
       "       0.32563944, 0.20781743, 0.51757803, 0.04724169, 0.95930763,\n",
       "       0.55790878, 0.00819702, 0.5159818 , 0.8344841 , 0.66735477,\n",
       "       0.36696268, 0.87265844, 0.43291816, 0.65711125, 0.36713875,\n",
       "       0.23912852, 0.32327912, 0.83723518, 0.51004308, 0.0257541 ,\n",
       "       0.18629893, 0.47381572, 0.12414299, 0.9304311 , 0.03163012])\n",
       "Coordinates:\n",
       "  * x_wcs    (x) int64 0 1 2 3 4 5 6 7 8 9 10 ... 40 41 42 43 44 45 46 47 48 49\n",
       "Dimensions without coordinates: x
" ], "text/plain": [ "\n", "array([0.55286327, 0.56095464, 0.16611634, 0.06475896, 0.7124942 ,\n", " 0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805,\n", " 0.48883895, 0.27718464, 0.12142079, 0.98683591, 0.87797165,\n", " 0.70290819, 0.00964021, 0.83901826, 0.95960231, 0.23635901,\n", " 0.67194208, 0.7097981 , 0.1988284 , 0.29785769, 0.6247646 ,\n", " 0.32563944, 0.20781743, 0.51757803, 0.04724169, 0.95930763,\n", " 0.55790878, 0.00819702, 0.5159818 , 0.8344841 , 0.66735477,\n", " 0.36696268, 0.87265844, 0.43291816, 0.65711125, 0.36713875,\n", " 0.23912852, 0.32327912, 0.83723518, 0.51004308, 0.0257541 ,\n", " 0.18629893, 0.47381572, 0.12414299, 0.9304311 , 0.03163012])\n", "Coordinates:\n", " * x_wcs (x) int64 0 1 2 3 4 5 6 7 8 9 10 ... 40 41 42 43 44 45 46 47 48 49\n", "Dimensions without coordinates: x" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da_indexed = da.set_xindex(\"x_wcs\", FunctionalIndex)\n", "\n", "da_indexed" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Indexes:\n", "x_wcs: <__main__.FunctionalIndex object at 0x164a66280>" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da_indexed.xindexes" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<__main__.FunctionalIndex at 0x164a66280>" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da_indexed.xindexes[\"x_wcs\"]" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (x: 5)>\n",
       "array([0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805])\n",
       "Coordinates:\n",
       "    x_wcs    (x) int64 5 6 7 8 9\n",
       "Dimensions without coordinates: x
" ], "text/plain": [ "\n", "array([0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805])\n", "Coordinates:\n", " x_wcs (x) int64 5 6 7 8 9\n", "Dimensions without coordinates: x" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ],
   "source": [
    "# selection with \"world\" coordinate labels works!\n",
    "\n",
    "da_selected = da_indexed.sel(x_wcs=slice(10, 20))\n",
    "da_selected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [ "Indexes:\n", " *empty*" ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# problem: index not propagated\n",
    "\n",
    "da_selected.xindexes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# implement `create_variables` and `isel`\n",
    "# in order to propagate the index in the selected dataset!\n",
    "\n",
    "\n",
    "class FunctionalIndex(xr.Index):\n",
    "    \"\"\"Basic 1-dimensional index with linear function.\n",
    "\n",
    "    World labels and pixel values are related through\n",
    "    ``world = slope * pixel + intercept``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    pixel_data : array-like\n",
    "        1-d pixel values along ``dim``. Assumed to be sorted in increasing\n",
    "        order, since ``sel`` locates pixels with ``np.searchsorted``.\n",
    "    slope, intercept : scalar\n",
    "        Coefficients of the pixel -> world linear function.\n",
    "    dim : hashable\n",
    "        Name of the dimension this index applies to.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, pixel_data, slope, intercept, dim):\n",
    "        self.pixel_data = pixel_data\n",
    "        self.slope = slope\n",
    "        self.intercept = intercept\n",
    "\n",
    "        # World values of the first and last pixels; `sel` uses them\n",
    "        # for bounds checking.\n",
    "        wmin = pixel_data[0] * slope + intercept\n",
    "        wmax = pixel_data[-1] * slope + intercept\n",
    "        self.extent = (wmin, wmax)\n",
    "\n",
    "        self.dim = dim\n",
    "\n",
    "    @classmethod\n",
    "    def from_variables(cls, variables, options):\n",
    "        \"\"\"Build the index from a single 1-d coordinate variable.\n",
    "\n",
    "        The variable values are used as pixel data; the coefficients of the\n",
    "        linear function are read from its ``slope`` and ``intercept``\n",
    "        attributes. ``options`` is not used.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        ValueError\n",
    "            If ``variables`` does not hold exactly one non-empty 1-d\n",
    "            variable carrying both attributes.\n",
    "        \"\"\"\n",
    "        if len(variables) != 1:\n",
    "            raise ValueError(\n",
    "                f\"FunctionalIndex expects exactly one coordinate variable, got {len(variables)}\"\n",
    "            )\n",
    "\n",
    "        name, var = next(iter(variables.items()))\n",
    "        if var.ndim != 1 or var.size == 0:\n",
    "            raise ValueError(f\"Variable {name!r} must be 1-dimensional and non-empty\")\n",
    "\n",
    "        missing = [key for key in (\"slope\", \"intercept\") if key not in var.attrs]\n",
    "        if missing:\n",
    "            raise ValueError(f\"Variable {name!r} is missing attribute(s): {missing}\")\n",
    "\n",
    "        return cls(\n",
    "            var.values,\n",
    "            var.attrs[\"slope\"],\n",
    "            var.attrs[\"intercept\"],\n",
    "            var.dims[0]\n",
    "        )\n",
    "\n",
    "    def create_variables(self, variables=None):\n",
    "        \"\"\"Create the indexed coordinate variable from the pixel data.\n",
    "\n",
    "        The new variable is named after the first entry of ``variables``\n",
    "        and carries the ``slope`` and ``intercept`` attributes, so that the\n",
    "        index can be rebuilt from it with ``from_variables``. Returns an\n",
    "        empty dict when no variable (hence no name) is given.\n",
    "        \"\"\"\n",
    "        if not variables:\n",
    "            return {}\n",
    "\n",
    "        name = next(iter(variables))\n",
    "        attrs = {\"slope\": self.slope, \"intercept\": self.intercept}\n",
    "\n",
    "        new_var = xr.IndexVariable(self.dim, self.pixel_data, attrs=attrs)\n",
    "        return {name: new_var}\n",
    "\n",
    "    def isel(self, indexers):\n",
    "        \"\"\"Return the index of a positional selection, or None to drop it.\n",
    "\n",
    "        The index is kept for non-empty slices and for 1-d array-like\n",
    "        indexers with more than one element. It is dropped for scalar\n",
    "        indexers, for indexers that are not 1-d along ``self.dim`` and for\n",
    "        selections of at most one element with an array.\n",
    "        \"\"\"\n",
    "        indxr = indexers[self.dim]\n",
    "\n",
    "        if isinstance(indxr, xr.Variable):\n",
    "            # A 1-d index cannot be kept if the result gets other dimensions.\n",
    "            if indxr.dims != (self.dim,):\n",
    "                return None\n",
    "            indxr = indxr.data\n",
    "\n",
    "        if not isinstance(indxr, slice):\n",
    "            indxr = np.asarray(indxr)\n",
    "            # Scalars have no len(); like 0- or 1-element arrays they drop\n",
    "            # the index.\n",
    "            if indxr.ndim != 1 or indxr.size <= 1:\n",
    "                return None\n",
    "\n",
    "        new_pixel_data = np.asarray(self.pixel_data)[indxr]\n",
    "        if len(new_pixel_data) == 0:\n",
    "            # Empty selection: there is no extent to compute.\n",
    "            return None\n",
    "\n",
    "        return type(self)(new_pixel_data, self.slope, self.intercept, self.dim)\n",
    "\n",
    "    def sel(self, labels):\n",
    "        \"\"\"Convert a slice of world labels into a positional slice.\n",
    "\n",
    "        Each bound is mapped to its nearest pixel value by inverting the\n",
    "        linear function, then to the position of that pixel in\n",
    "        ``pixel_data``. The start position is included and the stop position\n",
    "        is excluded. A bound left to ``None`` keeps that side open.\n",
    "\n",
    "        NOTE: ``label.step`` is ignored.\n",
    "\n",
    "        Raises\n",
    "        ------\n",
    "        TypeError\n",
    "            If the label is not a slice.\n",
    "        ValueError\n",
    "            If a bound lies outside of ``self.extent`` or if the slope is\n",
    "            zero (the function cannot be inverted).\n",
    "        \"\"\"\n",
    "        label = next(iter(labels.values()))\n",
    "\n",
    "        if not isinstance(label, slice):\n",
    "            raise TypeError(\"Selection using this index only works with slices\")\n",
    "\n",
    "        start, stop = label.start, label.stop\n",
    "        # Only compare the bounds that are actually given (None = open end).\n",
    "        if (start is not None and start < self.extent[0]) or (\n",
    "            stop is not None and stop > self.extent[1]\n",
    "        ):\n",
    "            raise ValueError(\"Out of bounds selection\")\n",
    "        if self.slope == 0:\n",
    "            raise ValueError(\"Cannot select world labels with a zero slope\")\n",
    "\n",
    "        def convert(val):\n",
    "            if val is None:\n",
    "                return None\n",
    "            # Invert world = slope * pixel + intercept. Rounding to the\n",
    "            # nearest pixel keeps floating-point noise from shifting the\n",
    "            # result by one.\n",
    "            pixel = int(round((val - self.intercept) / self.slope))\n",
    "            # Pixel value -> position: after a previous selection the pixel\n",
    "            # data no longer starts at 0.\n",
    "            return int(np.searchsorted(self.pixel_data, pixel))\n",
    "\n",
    "        int_slice = slice(convert(start), convert(stop))\n",
    "\n",
    "        return IndexSelResult({self.dim: int_slice})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "da_indexed2 = da.set_xindex(\"x_wcs\", FunctionalIndex)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [ "
\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "\n", "
<xarray.DataArray (x: 5)>\n",
       "array([0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805])\n",
       "Coordinates:\n",
       "  * x_wcs    (x) int64 5 6 7 8 9\n",
       "Dimensions without coordinates: x
" ], "text/plain": [ "\n", "array([0.49553549, 0.51480291, 0.86151124, 0.46034642, 0.08336805])\n", "Coordinates:\n", " * x_wcs (x) int64 5 6 7 8 9\n", "Dimensions without coordinates: x" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da_selected2 = da_indexed2.sel(x_wcs=slice(10, 20))\n", "da_selected2" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Indexes:\n", "x_wcs: <__main__.FunctionalIndex object at 0x164a552e0>" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "da_selected2.xindexes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:xarray_dev]", "language": "python", "name": "conda-env-xarray_dev-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 4 }