Astropy / Xarray WCS coordinates and index
Playing around with Astropy to create fully lazy world coordinates (dask arrays) and index.
Note: this only works with Xarray version 2022.11.0 and above.
Based on https://nbviewer.org/gist/ianthomas23/bbf85d1a38f8f161a2c1f641d9465096 (author: Ian Thomas)
(author: Benoît Bovy https://github.com/benbovy)
1 2 3 4 5 6 7 8 9 10 11 | from astropy.wcs import WCS import colorcet as cc import pandas as pd import dask import dask.array as daskarray import datashader import hvplot.xarray import intake import numpy as np import warnings import xarray as xr |
1 2 3 | #from distributed import Client #client = Client() |
# Workaround for dask's map_blocks not supporting multi-output functions
# (see https://github.com/dask/dask/issues/9510).
# The trick is to wrap the function so that both outputs are packed into a
# single numpy array with a record dtype.
pixel_to_world_result_dtype = [('x', np.double), ('y', np.double)]


def pixel_to_world_wrapper(wcs_method, x_pixel, y_pixel):
    """Convert pixel to world coordinate values.

    Pack the two results into a single numpy (record) array, so that
    the function can be used with dask array's map_blocks().

    Parameters
    ----------
    wcs_method : callable
        Called as ``wcs_method(x_pixel, y_pixel)``; must return a pair
        ``(x_world, y_world)``.
    x_pixel, y_pixel
        Pixel coordinates, forwarded unchanged to ``wcs_method``.

    Returns
    -------
    numpy.ndarray
        Array of dtype ``pixel_to_world_result_dtype`` with fields ``'x'``
        and ``'y'``. It has the shape of ``x_world``; scalar results are
        promoted to shape ``(1,)``.
    """
    x_world, y_world = wcs_method(x_pixel, y_pixel)
    # Promote scalars to 1-d and keep the shape of N-d results, so that the
    # wrapper works for blocks of any dimensionality (not only flat ones).
    x_world = np.atleast_1d(x_world)
    y_world = np.atleast_1d(y_world)

    result = np.empty(x_world.shape, dtype=pixel_to_world_result_dtype)
    result['x'] = x_world
    result['y'] = y_world
    return result
Dataset accessor to add pixel and world coordinates.
@xr.register_dataset_accessor("wcs")
class WCSAccessor:
    """Dataset accessor (``ds.wcs``) adding pixel and world coordinates.

    A ``WCS`` object is built from the dataset attributes (minus the
    ``COMMENT`` entry) when the accessor is created.
    """

    def __init__(self, xarray_obj):
        self._obj = xarray_obj
        self._wcs = WCS(self._get_wcs_attrs())

    def _get_wcs_attrs(self):
        """Return a copy of the dataset attributes without ``COMMENT``."""
        attrs = self._obj.attrs.copy()
        # Error about non-ASCII characters if don't do this.
        attrs.pop("COMMENT", None)
        return attrs

    def _check_dims(self, *dims):
        """Raise ``KeyError`` if any of ``dims`` is not a dataset dimension."""
        for dim in dims:
            if dim not in self._obj.dims:
                raise KeyError(f"Dataset has no dimension {dim!r}")

    def add_pixel_coords(self, x_dim, y_dim):
        """Return a new dataset with 0-based integer pixel coordinates.

        One coordinate ``np.arange(size)`` is assigned along each of
        ``x_dim`` and ``y_dim``.

        Raises
        ------
        KeyError
            If ``x_dim`` or ``y_dim`` is not a dimension of the dataset.
        """
        self._check_dims(x_dim, y_dim)
        sizes = self._obj.sizes
        coords = {dim: np.arange(sizes[dim]) for dim in (x_dim, y_dim)}
        return self._obj.assign_coords(coords)

    def add_world_coords(self, x_dim, y_dim):
        """Return a new dataset with lazy world coordinates.

        Three coordinates are added: ``x_wcs`` and ``y_wcs`` (2-d dask
        arrays with dimensions ``(y_dim, x_dim)``) and the scalar
        ``ref_wcs`` (value 0) whose attributes hold the WCS parameters.

        Raises
        ------
        KeyError
            If ``x_dim`` or ``y_dim`` is not a dimension of the dataset.
        """
        self._check_dims(x_dim, y_dim)
        nx = self._obj.sizes[x_dim]
        ny = self._obj.sizes[y_dim]

        # full lazy version
        x_pixel, y_pixel = daskarray.meshgrid(np.arange(nx), np.arange(ny))

        result = daskarray.map_blocks(
            pixel_to_world_wrapper,
            self._wcs.pixel_to_world_values,
            x_pixel.ravel(),
            y_pixel.ravel(),
            dtype=pixel_to_world_result_dtype,
        )

        x_wcs = result["x"].reshape((ny, nx))
        # Map x values above 180 to x - 360 (presumably degrees, so that
        # the coordinate is centred on zero -- TODO confirm units).
        x_wcs = daskarray.where(x_wcs > 180, x_wcs - 360, x_wcs)
        y_wcs = result["y"].reshape((ny, nx))

        # Use the caller-supplied dimension names rather than assuming that
        # the dataset dimensions are literally called "y" and "x".
        dims_wcs = (y_dim, x_dim)

        return self._obj.assign_coords(
            y_wcs=(dims_wcs, y_wcs),
            x_wcs=(dims_wcs, x_wcs),
            ref_wcs=((), 0, self._get_wcs_attrs()),
        )
A basic Xarray-compatible index for data selection based on world coordinates.
from xarray.indexes import Index
from xarray.core.indexes import IndexSelResult


class WCSIndex(Index):
    """Basic Xarray index for data selection based on world coordinates.

    Holds the names and dimensions of the two world coordinates together
    with the ``WCS`` object used to convert world values to array indices.
    """

    def __init__(self, xy_names, xy_dims, wcs):
        self._xy_names = xy_names
        self._xy_dims = xy_dims
        # Accept a ready-made WCS object, or anything WCS() can be built from.
        self._wcs = wcs if isinstance(wcs, WCS) else WCS(wcs)

    @property
    def wcs(self):
        """The ``WCS`` object wrapped by this index."""
        return self._wcs

    @classmethod
    def from_variables(cls, variables, *, options):
        """Build the index from exactly three coordinate variables.

        The first two variables are assumed to be the WCS coordinates and
        the last one to hold the WCS parameters in its attributes.
        """
        if len(variables) != 3:
            raise ValueError("WCSIndex needs 3 coordinates (x, y, wcs_ref)")

        x_vname, y_vname, wcs_vname = variables

        dims = variables[x_vname].dims
        if dims != variables[y_vname].dims:
            raise ValueError("x and y coordinates must have the same dimensions")

        return cls([x_vname, y_vname], dims, variables[wcs_vname].attrs)

    def create_variables(self, variables):
        # An explicit empty dict makes the index coordinates be treated
        # as regular coordinates.
        return {}

    def to_pandas_index(self):
        # Hack to make it work with hvplot
        # (note: hvplot should use the `.xindexes` property instead
        # of `.indexes` for checking dimensions, etc.)
        return pd.Index([])

    def isel(self, indexers):
        # For now the index is simply copied, whatever the indexers are.
        # TODO: check if x and/or y dimension is reduced to a scalar
        index_cls = type(self)
        return index_cls(self._xy_names, self._xy_dims, self._wcs)

    def sel(self, labels, method=None, tolerance=None):
        """Very basic point-wise, nearest-neighbour selection.

        ``labels`` must contain an xarray ``DataArray`` or ``Variable``
        for both world coordinates. ``tolerance`` is accepted but not used.
        """
        if method != "nearest":
            raise ValueError("WCSIndex only supports selection with method='nearest'")

        # Labels are required for both the x and the y world coordinates.
        xname, yname = self._xy_names
        if xname not in labels or yname not in labels:
            raise ValueError("Selection using a WCSIndex requires both x and y labels")

        # Only Xarray advanced (point-wise) indexing with xarray objects
        # is supported (not sure how to support orthogonal indexing).
        for label in labels.values():
            if not isinstance(label, (xr.DataArray, xr.Variable)):
                raise TypeError("WCSIndex only supports advanced (point-wise) indexing")

        x_label = labels[xname]
        y_label = labels[yname]

        # NOTE(review): astropy documents world_to_array_index_values as
        # returning indices in array (row, column) order. They are paired
        # positionally with self._xy_dims, so this presumably relies on the
        # coordinate dimensions being stored in array order -- confirm.
        first_idx, second_idx = self.wcs.world_to_array_index_values(
            x_label.values, y_label.values
        )
        first_dim, second_dim = self._xy_dims

        results = {}
        pairs = (
            (first_dim, x_label, first_idx),
            (second_dim, y_label, second_idx),
        )
        for dim, label, idx in pairs:
            # Return the same kind of xarray object as the input label.
            if isinstance(label, xr.Variable):
                results[dim] = xr.Variable(label.dims, idx)
            else:
                # dataarray
                results[dim] = xr.DataArray(idx, dims=label.dims)

        return IndexSelResult(results)
Load example dataset (lazy).
1 2 3 4 | cat = intake.open_catalog("https://github.com/fsspec/kerchunk/raw/main/examples/intake_catalog.yml") with warnings.catch_warnings(): warnings.simplefilter("ignore") dataset = cat.SDO.to_dask() |
Add pixel and world coordinates. Set `WCSIndex` for the world coordinates.
1 2 3 4 5 6 7 8 | dataset = ( dataset .wcs.add_pixel_coords("x", "y") .wcs.add_world_coords("x", "y") .set_xindex(["x_wcs", "y_wcs", "ref_wcs"], WCSIndex) ) dataset |
<xarray.Dataset> Dimensions: (DATE-OBS: 1800, y: 4096, x: 4096) Coordinates: * DATE-OBS (DATE-OBS) datetime64[us] 2012-09-23T00:00:01 ... 2012-09-23T05... * x (x) int64 0 1 2 3 4 5 6 7 ... 4089 4090 4091 4092 4093 4094 4095 * y (y) int64 0 1 2 3 4 5 6 7 ... 4089 4090 4091 4092 4093 4094 4095 * x_wcs (y, x) float64 dask.array<chunksize=(4096, 4096), meta=np.ndarray> * y_wcs (y, x) float64 dask.array<chunksize=(4096, 4096), meta=np.ndarray> * ref_wcs int64 0 Data variables: 094 (DATE-OBS, y, x) int16 dask.array<chunksize=(1, 4096, 4096), meta=np.ndarray> 131 (DATE-OBS, y, x) int16 dask.array<chunksize=(1, 4096, 4096), meta=np.ndarray> 171 (DATE-OBS, y, x) int16 dask.array<chunksize=(1, 4096, 4096), meta=np.ndarray> 193 (DATE-OBS, y, x) int16 dask.array<chunksize=(1, 4096, 4096), meta=np.ndarray> 211 (DATE-OBS, y, x) int16 dask.array<chunksize=(1, 4096, 4096), meta=np.ndarray> 304 (DATE-OBS, y, x) int16 dask.array<chunksize=(1, 4096, 4096), meta=np.ndarray> 334 (DATE-OBS, y, x) int16 dask.array<chunksize=(1, 4096, 4096), meta=np.ndarray> Indexes: x_wcs WCSIndex y_wcs WCSIndex ref_wcs WCSIndex Attributes: (12/185) ACS_CGT: GT3 ACS_ECLP: NO ACS_MODE: SCIENCE ACS_SAFE: NO ACS_SUNP: YES AECDELAY: 1535 ... ... T_OBS: 2012-09-23T00:00:02.57Z T_REC: 2012-09-23T00:00:03Z WAVEUNIT: angstrom WAVE_STR: 94_THIN X0_MP: 2059.610107 Y0_MP: 2039.26001
1 | dataset.xindexes |
Indexes: DATE-OBS PandasIndex x PandasIndex y PandasIndex x_wcs WCSIndex y_wcs WCSIndex ref_wcs WCSIndex
1 | dataset.xindexes["ref_wcs"].wcs |
WCS Keywords Number of WCS axes: 2 CTYPE : 'HPLN-TAN' 'HPLT-TAN' CRVAL : 0.0 0.0 CRPIX : 2060.610107 2040.26001 NAXIS : 4096 4096
Select single image.
1 | da = dataset["094"][0] |
1 | #da = da.persist() # Force download of single image array to see how long it takes.
|
Selection using pixel coordinates.
1 2 3 | da_pixel_selected = da.sel(x=slice(0, 100), y=slice(0, 100)) da_pixel_selected |
<xarray.DataArray '094' (y: 101, x: 101)> dask.array<getitem, shape=(101, 101), dtype=int16, chunksize=(101, 101), chunktype=numpy.ndarray> Coordinates: DATE-OBS datetime64[ns] 2012-09-23T00:00:01 * x (x) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99 100 * y (y) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99 100 * x_wcs (y, x) float64 dask.array<chunksize=(101, 101), meta=np.ndarray> * y_wcs (y, x) float64 dask.array<chunksize=(101, 101), meta=np.ndarray> * ref_wcs int64 0 Indexes: x_wcs WCSIndex y_wcs WCSIndex ref_wcs WCSIndex
Selection using world coordinates. This triggers the computation of the `x_wcs` and `y_wcs` world coordinate values; it is not clear whether Xarray supports lazy computation for this case yet.
1 2 3 4 5 6 7 | xw = xr.Variable("points", [-0.1, 0.1]) yw = xr.Variable("points", [0.1, 0.2]) da_world_selected = da.sel(x_wcs=xw, y_wcs=yw, method="nearest") # Load world coordinates and compare their values with the selection da_world_selected.load() |
<xarray.DataArray '094' (points: 2)> array([ 0, -1], dtype=int16) Coordinates: DATE-OBS datetime64[ns] 2012-09-23T00:00:01 x (points) int64 1458 2657 y (points) int64 2638 3240 * x_wcs (points) float64 -0.1 0.1001 * y_wcs (points) float64 0.1 0.1999 * ref_wcs int64 0 Dimensions without coordinates: points Indexes: x_wcs WCSIndex y_wcs WCSIndex ref_wcs WCSIndex
Plot image.
1 2 3 4 5 6 | da.hvplot.quadmesh( x='x_wcs', y='y_wcs', rasterize=True, aspect=1, frame_height=500, cmap=cc.fire, cnorm="linear", clim=(0, 40), # Somewhat arbitrary color limits. #cnorm="eq_hist", ) |
1 |