Multidimensional lazy-evaluated WCS coordinates in xarray
Data used here are available from http://docs.virtualsolar.org/. The two files to download and untar are from:
- http://netdrms01.nispdc.nso.edu/cgi-bin/netdrms/drms_export.cgi?series=aia__lev1;record=94_1283731250-1283731790
- http://netdrms01.nispdc.nso.edu/cgi-bin/netdrms/drms_export.cgi?series=aia__lev1;record=304_1283731244-1283731796
Each is about 300 MB.
Author: https://github.com/ianthomas23, with thanks to https://github.com/benbovy and https://github.com/Cadair. Based on prior work: https://nbviewer.org/gist/ianthomas23/bbf85d1a38f8f161a2c1f641d9465096 and https://notebooksharing.space/view/2e33c4554e5dfe754306515dbb5f223615ca4f0bbbf54bfbf1494b9417e33d14
1 2 3 4 5 6 7 8 | from astropy.io import fits from astropy.wcs import WCS import dask.array as daskarray from dateutil.parser import isoparse from glob import glob import numpy as np import warnings import xarray as xr |
# Glob pattern for the input FITS files; ``{wavelength}`` is substituted per wavelength.
filename_pattern = "/Users/iant/data/panhelio/aia.lev1.{wavelength}A_2017-09-06T0*"
wavelengths = [94, 304]
max_files_per_wavelength = 10

# Minimal set of header attributes needed to transform to WCS is in the
# following two lists.

# These are the same for each input file (all string-valued).
wcs_constant_names = [
    'CTYPE1', 'CTYPE2',
    'CUNIT1', 'CUNIT2',
]

# These are different for each input file (all float-valued).
wcs_variable_names = [
    'CDELT1', 'CDELT2',
    'CRPIX1', 'CRPIX2',
    'CRLN_OBS', 'CRLT_OBS',
    'CROTA2',
]

wcs_attr_prefix = 'WCS_'  # Prefix to identify WCS attributes in Dataset
date_attr = 'DATE-OBS'  # Header key holding the observation timestamp
def get_filenames():
    """Return a 2D object array of filenames of shape (nwavelength, ntime).

    Row ``i`` holds the first ``max_files_per_wavelength`` sorted matches of
    ``filename_pattern`` formatted with ``wavelengths[i]``.

    Raises:
        ValueError: If fewer than ``max_files_per_wavelength`` files match
            for any wavelength.
    """
    nwavelength = len(wavelengths)
    ntime = max_files_per_wavelength
    filenames = np.empty((nwavelength, ntime), dtype=object)
    for i, wavelength in enumerate(wavelengths):
        pattern = filename_pattern.format(wavelength=wavelength)
        matches = sorted(glob(pattern))[:ntime]
        # Check explicitly: numpy would silently broadcast a single match
        # across the whole row, and raise an obscure broadcasting error for
        # any other short list.
        if len(matches) != ntime:
            raise ValueError(
                f'Expected {ntime} files matching {pattern!r} '
                f'but found {len(matches)}'
            )
        filenames[i] = matches
    return filenames


def fits_to_xarray(filenames):
    """Read image data and WCS header values from a grid of FITS files.

    Args:
        filenames: 2D array of filenames of shape (nwavelength, ntime), as
            returned by ``get_filenames``.

    Returns:
        Tuple ``(image, time2d, wcs_constants, wcs_variables)``:
        ``image`` is a dask array of shape (nwavelength, ntime, ny, nx) with
        one chunk per file; ``time2d`` is a ``datetime64[s]`` array of shape
        (nwavelength, ntime) taken from the ``date_attr`` header;
        ``wcs_constants`` maps prefixed names in ``wcs_constant_names`` to the
        values in the first file's header; ``wcs_variables`` maps each
        (unprefixed) name in ``wcs_variable_names`` to a float64 array of
        shape (nwavelength, ntime).

    Raises:
        ValueError: If ``filenames`` is empty.
    """
    nwavelength, ntime = filenames.shape
    if filenames.size == 0:
        raise ValueError('No input files to read')

    image, time2d, wcs_constants, wcs_variables = None, None, None, None
    for w in range(nwavelength):
        for t in range(ntime):
            with fits.open(filenames[w, t]) as hdul:
                with warnings.catch_warnings():
                    # Silence warnings raised while fixing and reading the HDU.
                    warnings.simplefilter('ignore')
                    hdul.verify('fix')
                    header = hdul[1].header
                    data = hdul[1].data

                if image is None:
                    # Delay creating arrays until the size and dtype of each
                    # file's data are known (taken from the first file read).
                    ny, nx = data.shape
                    image = daskarray.empty(
                        (nwavelength, ntime, ny, nx),
                        dtype=data.dtype,
                        chunks=(1, 1, ny, nx),
                    )
                    time2d = np.empty((nwavelength, ntime), dtype='datetime64[s]')
                    wcs_constants = {
                        f'{wcs_attr_prefix}{name}': header[name]
                        for name in wcs_constant_names
                    }
                    wcs_variables = {
                        name: np.empty((nwavelength, ntime), dtype=np.float64)
                        for name in wcs_variable_names
                    }

                image[w, t] = data
                # Drop sub-second precision to match the datetime64[s] dtype.
                time2d[w, t] = isoparse(header[date_attr]).replace(microsecond=0)
                for name in wcs_variable_names:
                    wcs_variables[name][w, t] = header[name]

    return image, time2d, wcs_constants, wcs_variables


def create_dataset(image, time2d, wcs_constants, wcs_variables):
    """Assemble the arrays returned by ``fits_to_xarray`` into an xarray Dataset.

    Args:
        image: 4D array of shape (nwavelength, ntime, ny, nx).
        time2d: 2D array of observation times, shape (nwavelength, ntime).
        wcs_constants: Mapping stored unchanged as the Dataset attributes.
        wcs_variables: Mapping of unprefixed WCS name to a 2D array of shape
            (nwavelength, ntime); stored as data variables whose names carry
            ``wcs_attr_prefix``.

    Returns:
        Dataset with index coordinates ``wavelength``, ``time``, ``y`` and
        ``x``. ``time``, ``y`` and ``x`` are plain integer indices; the
        actual observation times are in the ``time2d`` data variable.
    """
    nwavelength, ntime, ny, nx = image.shape
    dims_image = ('wavelength', 'time', 'y', 'x')
    dims_2d = ('wavelength', 'time')

    prefixed_variables = {
        f'{wcs_attr_prefix}{name}': (dims_2d, values)
        for name, values in wcs_variables.items()
    }

    return xr.Dataset(
        coords=dict(
            wavelength=(['wavelength'], wavelengths),
            time=(['time'], np.arange(ntime)),
            y=(['y'], np.arange(ny)),
            x=(['x'], np.arange(nx)),
        ),
        data_vars=dict(
            image=xr.DataArray(data=image, dims=dims_image),
            time2d=xr.DataArray(data=time2d, dims=dims_2d),
            **prefixed_variables,
        ),
        attrs=wcs_constants,
    )


# Structured dtype so that a single map_blocks call can return both world
# coordinates for every pixel.
pixel_to_world_result_dtype = [('x', np.double), ('y', np.double)]


def map_blocks_func(x_pixel, y_pixel, wcs_callback, block_info=None):
    """Convert one chunk of pixel coordinates to world coordinates.

    Intended to be passed to ``dask.array.map_blocks``, which supplies
    ``block_info`` for each chunk.

    Args:
        x_pixel: Array of x pixel coordinates for the chunk.
        y_pixel: Array of y pixel coordinates for the chunk.
        wcs_callback: Callable ``(wavelength_index, time_index) -> WCS`` used
            to obtain the transform for this chunk.
        block_info: Block information from dask; the first two entries of the
            output ``'chunk-location'`` are used as the wavelength and time
            indices.

    Returns:
        Structured array of dtype ``pixel_to_world_result_dtype`` with the
        same shape as ``x_pixel``.

    Raises:
        ValueError: If ``block_info`` is None, as the chunk (and hence the
            WCS to use) cannot then be identified.
    """
    if block_info is None:
        # NOTE(review): dask may also call this without block_info while
        # inferring output metadata and is expected to tolerate the failure,
        # as it did the UnboundLocalError raised here previously - confirm.
        raise ValueError(
            'map_blocks_func requires block_info to identify the chunk being evaluated'
        )

    wavelength_index, time_index = block_info[None]['chunk-location'][:2]
    wcs = wcs_callback(wavelength_index, time_index)
    x, y = wcs.pixel_to_world_values(x_pixel, y_pixel)
    # Wrap x from (180, 360] to negative values so it is centred on zero.
    x = np.where(x > 180, x - 360, x)

    result = np.empty(x_pixel.shape, dtype=pixel_to_world_result_dtype)
    result['x'][:] = x
    result['y'][:] = y
    return result


@xr.register_dataset_accessor('wcs')
class WCSAccessor2D:
    """Dataset accessor (``ds.wcs``) adding lazily evaluated WCS coordinates.

    The WCS for each (wavelength, time) image is rebuilt on demand from the
    Dataset's data variables and attributes whose names start with
    ``wcs_attr_prefix``.
    """

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def _wcs_from_attr(self, wavelength_index, time_index):
        """Return the WCS for the image at the given wavelength and time index.

        Raises:
            ValueError: If either index is outside the Dataset's
                ``wavelength`` or ``time`` dimension.
        """
        ds = self._obj
        if wavelength_index < 0 or wavelength_index >= ds.sizes['wavelength']:
            raise ValueError(f'Invalid wavelength index {wavelength_index}')
        if time_index < 0 or time_index >= ds.sizes['time']:
            raise ValueError(f'Invalid time index {time_index}')

        # Variable attributes. This works if the attribute values are a dask
        # array or not.
        variable_attrs = {
            name.removeprefix(wcs_attr_prefix):
                ds.data_vars[name][wavelength_index, time_index].values.item()
            for name in ds.data_vars
            if name.startswith(wcs_attr_prefix)
        }
        # Constant attributes; these take precedence on a name clash.
        constant_attrs = {
            name.removeprefix(wcs_attr_prefix): value
            for name, value in ds.attrs.items()
            if name.startswith(wcs_attr_prefix)
        }
        return WCS(variable_attrs | constant_attrs)

    def add_world_coordinates(self, data_array_name):
        """Return a new Dataset with lazy ``x_wcs`` and ``y_wcs`` coordinates.

        Args:
            data_array_name: Name of a 4D variable with dimensions
                (wavelength, time, y, x) whose shape the new coordinates take.

        Returns:
            Copy of the Dataset with 4D coordinates ``x_wcs`` and ``y_wcs``,
            backed by dask arrays with one chunk per (wavelength, time) image.
        """
        da = self._obj[data_array_name]
        nwavelength, ntime, ny, nx = da.shape
        shape = (nwavelength, ntime, ny, nx)
        chunks = (1, 1, ny, nx)

        # This sets up the lazy evaluation of WCS coordinates for a chunk
        # that is a single wavelength and time index.
        x_pixel = daskarray.empty(shape, dtype=np.float64, chunks=chunks)
        y_pixel = daskarray.empty(shape, dtype=np.float64, chunks=chunks)
        # The pixel grid is the same for every image, so build it once.
        grid_x, grid_y = daskarray.meshgrid(np.arange(nx), np.arange(ny))
        for i in range(nwavelength):  # This loop isn't ideal.
            x_pixel[i, :] = grid_x
            y_pixel[i, :] = grid_y

        res = daskarray.map_blocks(
            map_blocks_func,
            x_pixel,
            y_pixel,
            self._wcs_from_attr,
            dtype=pixel_to_world_result_dtype,
        )

        dims_wcs = ('wavelength', 'time', 'y', 'x')
        return self._obj.assign_coords(
            y_wcs=(dims_wcs, res['y']),
            x_wcs=(dims_wcs, res['x']),
        )
Load FITS data into xarray Dataset, without WCS coordinates
# Locate the input files, read them, and assemble the Dataset
# (no WCS coordinates yet).
filenames = get_filenames()
loaded = fits_to_xarray(filenames)
image, time2d, wcs_constants, wcs_variables = loaded
ds = create_dataset(*loaded)
ds  # Last expression in the cell, so the notebook displays the Dataset
<xarray.Dataset> Dimensions: (wavelength: 2, time: 10, y: 4096, x: 4096) Coordinates: * wavelength (wavelength) int64 94 304 * time (time) int64 0 1 2 3 4 5 6 7 8 9 * y (y) int64 0 1 2 3 4 5 6 ... 4089 4090 4091 4092 4093 4094 4095 * x (x) int64 0 1 2 3 4 5 6 ... 4089 4090 4091 4092 4093 4094 4095 Data variables: image (wavelength, time, y, x) int16 dask.array<chunksize=(1, 1, 4096, 4096), meta=np.ndarray> time2d (wavelength, time) datetime64[ns] 2017-09-06T00:00:11 ... 2... WCS_CDELT1 (wavelength, time) float64 0.6001 0.6001 ... 0.6002 0.6002 WCS_CDELT2 (wavelength, time) float64 0.6001 0.6001 ... 0.6002 0.6002 WCS_CRPIX1 (wavelength, time) float64 2.07e+03 2.07e+03 ... 2.07e+03 WCS_CRPIX2 (wavelength, time) float64 2.012e+03 2.012e+03 ... 2.011e+03 WCS_CRLN_OBS (wavelength, time) float64 88.88 88.88 88.87 ... 88.86 88.86 WCS_CRLT_OBS (wavelength, time) float64 7.256 7.256 7.256 ... 7.256 7.256 WCS_CROTA2 (wavelength, time) float64 -0.1377 -0.1376 ... -0.1316 -0.1316 Attributes: WCS_CTYPE1: HPLN-TAN WCS_CTYPE2: HPLT-TAN WCS_CUNIT1: arcsec WCS_CUNIT2: arcsec
Add lazy-evaluation WCS coordinates
# Attach the lazily evaluated x_wcs / y_wcs coordinates, shaped like 'image'.
ds = ds.wcs.add_world_coordinates(data_array_name='image')
ds  # Last expression in the cell, so the notebook displays the Dataset
<xarray.Dataset> Dimensions: (wavelength: 2, time: 10, y: 4096, x: 4096) Coordinates: * wavelength (wavelength) int64 94 304 * time (time) int64 0 1 2 3 4 5 6 7 8 9 * y (y) int64 0 1 2 3 4 5 6 ... 4089 4090 4091 4092 4093 4094 4095 * x (x) int64 0 1 2 3 4 5 6 ... 4089 4090 4091 4092 4093 4094 4095 y_wcs (wavelength, time, y, x) float64 dask.array<chunksize=(1, 1, 4096, 4096), meta=np.ndarray> x_wcs (wavelength, time, y, x) float64 dask.array<chunksize=(1, 1, 4096, 4096), meta=np.ndarray> Data variables: image (wavelength, time, y, x) int16 dask.array<chunksize=(1, 1, 4096, 4096), meta=np.ndarray> time2d (wavelength, time) datetime64[ns] 2017-09-06T00:00:11 ... 2... WCS_CDELT1 (wavelength, time) float64 0.6001 0.6001 ... 0.6002 0.6002 WCS_CDELT2 (wavelength, time) float64 0.6001 0.6001 ... 0.6002 0.6002 WCS_CRPIX1 (wavelength, time) float64 2.07e+03 2.07e+03 ... 2.07e+03 WCS_CRPIX2 (wavelength, time) float64 2.012e+03 2.012e+03 ... 2.011e+03 WCS_CRLN_OBS (wavelength, time) float64 88.88 88.88 88.87 ... 88.86 88.86 WCS_CRLT_OBS (wavelength, time) float64 7.256 7.256 7.256 ... 7.256 7.256 WCS_CROTA2 (wavelength, time) float64 -0.1377 -0.1376 ... -0.1316 -0.1316 Attributes: WCS_CTYPE1: HPLN-TAN WCS_CTYPE2: HPLT-TAN WCS_CUNIT1: arcsec WCS_CUNIT2: arcsec
Note the new 4D coordinates `x_wcs` and `y_wcs`, which are dask arrays with one chunk per image.
Try out lazy-evaluating some of the WCS coords:
# Evaluate the lazy x_wcs coordinate for a few (wavelength, time) index pairs
# and print the minimum of each.
sample_indices = [(0, 0), (0, 3), (1, 9)]
for w, t in sample_indices:
    x_min = ds.x_wcs[w, t].min().compute().item()
    print(w, t, x_min)
0 0 -0.34572133103091574 0 3 -0.3457204906230231 1 9 -0.3456356691837641
Visualise using hvplot, panel, datashader
1 2 3 | import colorcet as cc import hvplot.xarray import panel as pn |
cmaps = [cc.fire, cc.bkr]  # Separate colormap per wavelength for clarity

# Radio buttons map a display label to a wavelength index.
wavelength_widget = pn.widgets.RadioBoxGroup(
    options={f"{v} Å": i for i, v in enumerate(ds.wavelength.values)},
    inline=True,
)
# Slider maps an observation time to a time index; starts on wavelength 0.
time_widget = pn.widgets.DiscreteSlider(
    options={k: v for v, k in enumerate(ds.time2d[0].values)},
)
# Remember which wavelength the slider options currently describe.
time_widget.wavelength = 0


def callback(wavelength_index, time_index):
    """Return the plot of the image at the selected wavelength and time index.

    Bound to the two widgets with ``pn.bind``. Every call redraws the image;
    the time slider's options are only rebuilt when the wavelength changes.
    """
    # Need to identify what has changed. All calls need to update the image,
    # but only if wavelength has changed does the time need to be updated.
    if time_widget.wavelength != wavelength_index:
        time_widget.options = {
            k: v for v, k in enumerate(ds.time2d[wavelength_index].values)
        }
        time_widget.wavelength = wavelength_index

    image = ds.image.isel(wavelength=wavelength_index, time=time_index).compute()
    # Truncate the timestamp to whole seconds (first 19 characters).
    timestamp = str(ds.time2d[wavelength_index, time_index].values)[:19]
    title = f"{wavelength_widget.labels[wavelength_index]}: {timestamp}"
    # x_wcs / y_wcs are in degrees (astropy rescales the arcsec header units),
    # which is why the axis limits are +/-0.34 and the labels say degrees.
    return image.hvplot.quadmesh(
        x="x_wcs",
        y="y_wcs",
        aspect=1,
        frame_height=600,
        cnorm='eq_hist',
        cmap=cmaps[wavelength_index],
        hover=False,
        rasterize=True,
        xlim=(-0.34, 0.34),
        ylim=(-0.34, 0.34),
        title=title,
        xlabel="Solar longitude (degrees)",
        ylabel="Solar latitude (degrees)",
    )


interactive = pn.bind(callback, wavelength_index=wavelength_widget, time_index=time_widget)

pn.Column(
    pn.Row(
        time_widget,
        pn.Column(pn.widgets.StaticText(value='Wavelength'), wavelength_widget),
    ),
    interactive,
)
1 |