# Conservative Region Aggregation with Xarray, Geopandas and Sparse
**Goal:** Regrid a global precipitation dataset into countries conservatively, i.e. by exactly partitioning each grid cell into the precise region boundaries.

**Meta Goal:** Demonstrate that we don't necessarily need a dedicated package for this workflow, and showcase some of the new capabilities of GeoPandas along the way.
**Approach:** We proceed in three steps:
- Represent both the original grid and target grid as GeoSeries with Polygon geometry
- Compute their area overlay and turn it into a sparse matrix
- Perform matrix multiplication on the full Xarray dataset (with a time dimension)
It is quite fast and transparent.
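To make the final step concrete, here is a toy numpy version (hypothetical numbers, not the real data): each column of the weights matrix holds the fraction of a region's area contributed by each grid cell, so the regional means fall out of a single matrix product.

```python
import numpy as np

# Toy example: 3 grid cells, 2 regions. Column j holds the fraction of
# region j's area contributed by each grid cell, so each column sums to 1.
precip_cells = np.array([1.0, 2.0, 4.0])     # one value per grid cell
weights = np.array([[0.5, 0.0],
                    [0.5, 0.25],
                    [0.0, 0.75]])            # shape (n_cells, n_regions)

region_means = precip_cells @ weights        # area-weighted mean per region
print(region_means)                          # [1.5  3.5]
```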
```python
import xarray as xr
import geopandas as gp
import pandas as pd
import sparse

%xmode minimal
```
Exception reporting mode: Minimal
## Load Region Data
To make this realistic, we will start from an actual shapefile. To download the data, run the following cell uncommented.
```python
# ! wget https://www.naturalearthdata.com/http//www.naturalearthdata.com/download/50m/cultural/ne_50m_admin_0_countries.zip
# ! unzip ne_50m_admin_0_countries.zip
```
Load with geopandas:
```python
regions_df = gp.read_file("ne_50m_admin_0_countries.shp")
regions_df
```
| | featurecla | scalerank | LABELRANK | SOVEREIGNT | SOV_A3 | ADM0_DIF | LEVEL | TYPE | TLC | ADMIN | ... | FCLASS_TR | FCLASS_ID | FCLASS_PL | FCLASS_GR | FCLASS_IT | FCLASS_NL | FCLASS_SE | FCLASS_BD | FCLASS_UA | geometry |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | Admin-0 country | 1 | 3 | Zimbabwe | ZWE | 0 | 2 | Sovereign country | 1 | Zimbabwe | ... | None | None | None | None | None | None | None | None | None | POLYGON ((31.28789 -22.40205, 31.19727 -22.344... |
| 1 | Admin-0 country | 1 | 3 | Zambia | ZMB | 0 | 2 | Sovereign country | 1 | Zambia | ... | None | None | None | None | None | None | None | None | None | POLYGON ((30.39609 -15.64307, 30.25068 -15.643... |
| 2 | Admin-0 country | 1 | 3 | Yemen | YEM | 0 | 2 | Sovereign country | 1 | Yemen | ... | None | None | None | None | None | None | None | None | None | MULTIPOLYGON (((53.08564 16.64839, 52.58145 16... |
| 3 | Admin-0 country | 3 | 2 | Vietnam | VNM | 0 | 2 | Sovereign country | 1 | Vietnam | ... | None | None | None | None | None | None | None | None | None | MULTIPOLYGON (((104.06396 10.39082, 104.08301 ... |
| 4 | Admin-0 country | 5 | 3 | Venezuela | VEN | 0 | 2 | Sovereign country | 1 | Venezuela | ... | None | None | None | None | None | None | None | None | None | MULTIPOLYGON (((-60.82119 9.13838, -60.94141 9... |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 237 | Admin-0 country | 1 | 3 | Afghanistan | AFG | 0 | 2 | Sovereign country | 1 | Afghanistan | ... | None | None | None | None | None | None | None | None | None | POLYGON ((66.52227 37.34849, 66.82773 37.37129... |
| 238 | Admin-0 country | 1 | 5 | Kashmir | KAS | 0 | 2 | Indeterminate | None | Siachen Glacier | ... | Unrecognized | Unrecognized | Unrecognized | Unrecognized | Unrecognized | Unrecognized | Unrecognized | Unrecognized | Unrecognized | POLYGON ((77.04863 35.10991, 77.00449 35.19634... |
| 239 | Admin-0 country | 3 | 4 | Antarctica | ATA | 0 | 2 | Indeterminate | 1 | Antarctica | ... | None | None | None | None | None | None | None | None | None | MULTIPOLYGON (((-45.71777 -60.52090, -45.49971... |
| 240 | Admin-0 country | 3 | 6 | Netherlands | NL1 | 1 | 2 | Country | 1 | Sint Maarten | ... | None | None | None | None | None | None | None | None | None | POLYGON ((-63.12305 18.06895, -63.01118 18.068... |
| 241 | Admin-0 country | 5 | 6 | Tuvalu | TUV | 0 | 2 | Sovereign country | 1 | Tuvalu | ... | None | None | None | None | None | None | None | None | None | POLYGON ((179.21367 -8.52422, 179.20059 -8.534... |

242 rows × 169 columns
All geodataframes should have a coordinate reference system. This is important (and sometimes unfamiliar to users coming from the global climate world).
```python
regions_df.crs
```
```
<Geographic 2D CRS: EPSG:4326>
Name: WGS 84
Axis Info [ellipsoidal]:
- Lat[north]: Geodetic latitude (degree)
- Lon[east]: Geodetic longitude (degree)
Area of Use:
- name: World.
- bounds: (-180.0, -90.0, 180.0, 90.0)
Datum: World Geodetic System 1984 ensemble
- Ellipsoid: WGS 84
- Prime Meridian: Greenwich
```
```python
crs_orig = "EPSG:4326"
```
We will now transform to an area-preserving projection. This is important because we want to do area-weighted regridding.
```python
# use an area-preserving projection
crs = "ESRI:53034"
regions_df = regions_df.to_crs(crs)
```
Explore the dataset a bit.
```python
regions_df.geometry.plot()
```
<AxesSubplot:>
```python
regions_df.geometry.iloc[0]
```
This is the area of all the land on Earth, in square meters (about 1.47 × 10¹⁴ m², close to the commonly cited figure of roughly 1.49 × 10¹⁴ m²):
```python
regions_df.area.sum()
```
146631193254209.66
## Load Precipitation Data
This is the NASA / NOAA Global Precipitation Climatology Project processed and stored by Pangeo Forge: https://pangeo-forge.org/dashboard/feedstock/42. We want to aggregate the precipitation into countries defined in the shapefile.
Note that we have the entire dataset in one object:
```python
store = 'https://ncsa.osn.xsede.org/Pangeo/pangeo-forge/gpcp-feedstock/gpcp.zarr'
ds = xr.open_dataset(store, engine='zarr', chunks={})
ds
```
```
<xarray.Dataset>
Dimensions:      (latitude: 180, nv: 2, longitude: 360, time: 9226)
Coordinates:
    lat_bounds   (latitude, nv) float32 dask.array<chunksize=(180, 2), meta=np.ndarray>
  * latitude     (latitude) float32 -90.0 -89.0 -88.0 -87.0 ... 87.0 88.0 89.0
    lon_bounds   (longitude, nv) float32 dask.array<chunksize=(360, 2), meta=np.ndarray>
  * longitude    (longitude) float32 0.0 1.0 2.0 3.0 ... 356.0 357.0 358.0 359.0
  * time         (time) datetime64[ns] 1996-10-01 1996-10-02 ... 2021-12-31
    time_bounds  (time, nv) datetime64[ns] dask.array<chunksize=(200, 2), meta=np.ndarray>
Dimensions without coordinates: nv
Data variables:
    precip       (time, latitude, longitude) float32 dask.array<chunksize=(200, 180, 360), meta=np.ndarray>
Attributes: (12/45)
    Conventions:               CF-1.6, ACDD 1.3
    Metadata_Conventions:      CF-1.6, Unidata Dataset Discovery v1.0, NOAA ...
    acknowledgment:            This project was supported in part by a grant...
    cdm_data_type:             Grid
    cdr_program:               NOAA Climate Data Record Program for satellit...
    cdr_variable:              precipitation
    ...                        ...
    standard_name_vocabulary:  CF Standard Name Table (v41, 22 February 2017)
    summary:                   Global Precipitation Climatology Project (GPC...
    time_coverage_duration:    P1D
    time_coverage_end:         1996-10-01T23:59:59Z
    time_coverage_start:       1996-10-01T00:00:00Z
    title:                     Global Precipitation Climatatology Project (G...
```
Now we extract just the horizontal grid information. The dataset has information about the lat and lon bounds of each cell, which we need to create the polygons.
```python
grid = ds.drop(['time', 'time_bounds', 'precip']).reset_coords().load()
grid
```
```
<xarray.Dataset>
Dimensions:     (latitude: 180, nv: 2, longitude: 360)
Coordinates:
  * latitude    (latitude) float32 -90.0 -89.0 -88.0 -87.0 ... 87.0 88.0 89.0
  * longitude   (longitude) float32 0.0 1.0 2.0 3.0 ... 356.0 357.0 358.0 359.0
Dimensions without coordinates: nv
Data variables:
    lat_bounds  (latitude, nv) float32 -90.0 -89.0 -89.0 ... 89.0 89.0 90.0
    lon_bounds  (longitude, nv) float32 0.0 1.0 1.0 2.0 ... 359.0 359.0 360.0
Attributes: (12/45)
    Conventions:               CF-1.6, ACDD 1.3
    Metadata_Conventions:      CF-1.6, Unidata Dataset Discovery v1.0, NOAA ...
    acknowledgment:            This project was supported in part by a grant...
    cdm_data_type:             Grid
    cdr_program:               NOAA Climate Data Record Program for satellit...
    cdr_variable:              precipitation
    ...                        ...
    standard_name_vocabulary:  CF Standard Name Table (v41, 22 February 2017)
    summary:                   Global Precipitation Climatology Project (GPC...
    time_coverage_duration:    P1D
    time_coverage_end:         1996-10-01T23:59:59Z
    time_coverage_start:       1996-10-01T00:00:00Z
    title:                     Global Precipitation Climatatology Project (G...
```
Now we "stack" the data into a single 1D array. This is the first step towards transitioning to pandas.
```python
points = grid.stack(point=("latitude", "longitude"))
points
```
```
<xarray.Dataset>
Dimensions:     (nv: 2, point: 64800)
Coordinates:
  * point       (point) MultiIndex
  - latitude    (point) float64 -90.0 -90.0 -90.0 -90.0 ... 89.0 89.0 89.0 89.0
  - longitude   (point) float64 0.0 1.0 2.0 3.0 4.0 ... 356.0 357.0 358.0 359.0
Dimensions without coordinates: nv
Data variables:
    lat_bounds  (nv, point) float32 -90.0 -90.0 -90.0 -90.0 ... 90.0 90.0 90.0
    lon_bounds  (nv, point) float32 0.0 1.0 2.0 3.0 ... 357.0 358.0 359.0 360.0
Attributes: (12/45)
    Conventions:               CF-1.6, ACDD 1.3
    Metadata_Conventions:      CF-1.6, Unidata Dataset Discovery v1.0, NOAA ...
    acknowledgment:            This project was supported in part by a grant...
    cdm_data_type:             Grid
    cdr_program:               NOAA Climate Data Record Program for satellit...
    cdr_variable:              precipitation
    ...                        ...
    standard_name_vocabulary:  CF Standard Name Table (v41, 22 February 2017)
    summary:                   Global Precipitation Climatology Project (GPC...
    time_coverage_duration:    P1D
    time_coverage_end:         1996-10-01T23:59:59Z
    time_coverage_start:       1996-10-01T00:00:00Z
    title:                     Global Precipitation Climatatology Project (G...
```
This function creates the geometry for a single pair of bounds. It is not fast, but it is fast enough here. It could perhaps be vectorized using pygeos.
```python
from shapely.geometry import Polygon

def bounds_to_poly(lon_bounds, lat_bounds):
    if lon_bounds[0] >= 180:
        # geopandas expects longitudes in [-180, 180], so wrap
        lon_bounds = lon_bounds - 360
    return Polygon([
        (lon_bounds[0], lat_bounds[0]),
        (lon_bounds[0], lat_bounds[1]),
        (lon_bounds[1], lat_bounds[1]),
        (lon_bounds[1], lat_bounds[0])
    ])
```
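As a quick sanity check (the specific bounds here are just an illustration), the function maps a pair of (lon, lat) bound arrays to a shapely box:

```python
import numpy as np

# The grid cell spanning 0-1 degrees E and 90-89 degrees S:
poly = bounds_to_poly(np.array([0.0, 1.0]), np.array([-90.0, -89.0]))
print(poly.bounds)  # (0.0, -90.0, 1.0, -89.0)
```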
We apply this function to each grid cell.
```python
%%time
import numpy as np

boxes = xr.apply_ufunc(
    bounds_to_poly,
    points.lon_bounds,
    points.lat_bounds,
    input_core_dims=[("nv",), ("nv",)],
    output_dtypes=[np.dtype('O')],
    vectorize=True
)
boxes
```
```
CPU times: user 1.14 s, sys: 43.4 ms, total: 1.18 s
Wall time: 1.16 s
```
```
<xarray.DataArray (point: 64800)>
array([<shapely.geometry.polygon.Polygon object at 0x7fb5449d8c70>,
       <shapely.geometry.polygon.Polygon object at 0x7fb5449d8fd0>,
       <shapely.geometry.polygon.Polygon object at 0x7fb5449d8550>, ...,
       <shapely.geometry.polygon.Polygon object at 0x7fb535d69ac0>,
       <shapely.geometry.polygon.Polygon object at 0x7fb535d69af0>,
       <shapely.geometry.polygon.Polygon object at 0x7fb535d69b20>],
      dtype=object)
Coordinates:
  * point      (point) MultiIndex
  - latitude   (point) float64 -90.0 -90.0 -90.0 -90.0 ... 89.0 89.0 89.0 89.0
  - longitude  (point) float64 0.0 1.0 2.0 3.0 4.0 ... 356.0 357.0 358.0 359.0
```
Finally, we convert to a GeoDataFrame. We specify the CRS as EPSG:4326 because the geometry is still in lat/lon coordinates.
```python
grid_df = gp.GeoDataFrame(
    data={"geometry": boxes.values, "latitude": boxes.latitude, "longitude": boxes.longitude},
    index=boxes.indexes["point"],
    crs=crs_orig
)
grid_df
```
| latitude (index) | longitude (index) | geometry | latitude | longitude |
|---|---|---|---|---|
| -90.0 | 0.0 | POLYGON ((0.00000 -90.00000, 0.00000 -89.00000... | -90.0 | 0.0 |
| -90.0 | 1.0 | POLYGON ((1.00000 -90.00000, 1.00000 -89.00000... | -90.0 | 1.0 |
| -90.0 | 2.0 | POLYGON ((2.00000 -90.00000, 2.00000 -89.00000... | -90.0 | 2.0 |
| -90.0 | 3.0 | POLYGON ((3.00000 -90.00000, 3.00000 -89.00000... | -90.0 | 3.0 |
| -90.0 | 4.0 | POLYGON ((4.00000 -90.00000, 4.00000 -89.00000... | -90.0 | 4.0 |
| ... | ... | ... | ... | ... |
| 89.0 | 355.0 | POLYGON ((-5.00000 89.00000, -5.00000 90.00000... | 89.0 | 355.0 |
| 89.0 | 356.0 | POLYGON ((-4.00000 89.00000, -4.00000 90.00000... | 89.0 | 356.0 |
| 89.0 | 357.0 | POLYGON ((-3.00000 89.00000, -3.00000 90.00000... | 89.0 | 357.0 |
| 89.0 | 358.0 | POLYGON ((-2.00000 89.00000, -2.00000 90.00000... | 89.0 | 358.0 |
| 89.0 | 359.0 | POLYGON ((-1.00000 89.00000, -1.00000 90.00000... | 89.0 | 359.0 |

64800 rows × 3 columns
```python
grid_df.crs
```
```
<Geographic 2D CRS: EPSG:4326>
Name: WGS 84
Axis Info [ellipsoidal]:
- Lat[north]: Geodetic latitude (degree)
- Lon[east]: Geodetic longitude (degree)
Area of Use:
- name: World.
- bounds: (-180.0, -90.0, 180.0, 90.0)
Datum: World Geodetic System 1984 ensemble
- Ellipsoid: WGS 84
- Prime Meridian: Greenwich
```
Now we transform to the area-preserving CRS.
```python
grid_df = grid_df.to_crs(crs)
grid_df.crs
```
```
<Derived Projected CRS: ESRI:53034>
Name: Sphere_Cylindrical_Equal_Area
Axis Info [cartesian]:
- E[east]: Easting (metre)
- N[north]: Northing (metre)
Area of Use:
- name: World.
- bounds: (-180.0, -90.0, 180.0, 90.0)
Coordinate Operation:
- name: Sphere_Cylindrical_Equal_Area
- method: Lambert Cylindrical Equal Area (Spherical)
Datum: Not specified (based on Authalic Sphere)
- Ellipsoid: Sphere
- Prime Meridian: Greenwich
```
Plotting just shows the globe completely covered by grid boxes.
```python
grid_df.geometry.plot()
```
<AxesSubplot:>
The total area matches the expected surface area of the globe (about 5.1 × 10¹⁴ m²).
```python
grid_df.geometry.area.sum()
```
510064471909788.25
## Key Step: Overlay the two geometries
This is the magic of geopandas; it can calculate the overlap between the original grid and the regions. It is expensive because it has to compare 64800 grid boxes with 242 regions.
In this dataframe, the `latitude` and `longitude` values come from the grid, while all the other columns come from the regions.
```python
%time overlay = grid_df.overlay(regions_df)
overlay
```
```
CPU times: user 13.4 s, sys: 72.7 ms, total: 13.5 s
Wall time: 13.5 s

/srv/conda/envs/notebook/lib/python3.9/site-packages/geopandas/geodataframe.py:2196: UserWarning: `keep_geom_type=True` in overlay resulted in 1 dropped geometries of different geometry types than df1 has. Set `keep_geom_type=False` to retain all geometries
  return geopandas.overlay(
```
| | latitude | longitude | featurecla | scalerank | LABELRANK | SOVEREIGNT | SOV_A3 | ADM0_DIF | LEVEL | TYPE | ... | FCLASS_TR | FCLASS_ID | FCLASS_PL | FCLASS_GR | FCLASS_IT | FCLASS_NL | FCLASS_SE | FCLASS_BD | FCLASS_UA | geometry |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -90.0 | 0.0 | Admin-0 country | 3 | 4 | Antarctica | ATA | 0 | 2 | Indeterminate | ... | None | None | None | None | None | None | None | None | None | POLYGON ((0.000 -6370029.666, 111194.927 -6370... |
| 1 | -90.0 | 1.0 | Admin-0 country | 3 | 4 | Antarctica | ATA | 0 | 2 | Indeterminate | ... | None | None | None | None | None | None | None | None | None | POLYGON ((111194.927 -6370029.666, 222389.853 ... |
| 2 | -90.0 | 2.0 | Admin-0 country | 3 | 4 | Antarctica | ATA | 0 | 2 | Indeterminate | ... | None | None | None | None | None | None | None | None | None | POLYGON ((222389.853 -6370029.666, 333584.780 ... |
| 3 | -90.0 | 3.0 | Admin-0 country | 3 | 4 | Antarctica | ATA | 0 | 2 | Indeterminate | ... | None | None | None | None | None | None | None | None | None | POLYGON ((333584.780 -6370029.666, 444779.707 ... |
| 4 | -90.0 | 4.0 | Admin-0 country | 3 | 4 | Antarctica | ATA | 0 | 2 | Indeterminate | ... | None | None | None | None | None | None | None | None | None | POLYGON ((444779.707 -6370029.666, 555974.633 ... |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 27634 | 66.0 | 341.0 | Admin-0 country | 1 | 3 | Iceland | ISL | 0 | 2 | Sovereign country | ... | None | None | None | None | None | None | None | None | None | MULTIPOLYGON (((-2001508.680 5820198.111, -202... |
| 27635 | 66.0 | 342.0 | Admin-0 country | 1 | 3 | Iceland | ISL | 0 | 2 | Sovereign country | ... | None | None | None | None | None | None | None | None | None | MULTIPOLYGON (((-1890313.753 5820198.111, -194... |
| 27636 | 66.0 | 343.0 | Admin-0 country | 1 | 3 | Iceland | ISL | 0 | 2 | Sovereign country | ... | None | None | None | None | None | None | None | None | None | POLYGON ((-1890313.753 5828182.772, -1886925.7... |
| 27637 | 66.0 | 344.0 | Admin-0 country | 1 | 3 | Iceland | ISL | 0 | 2 | Sovereign country | ... | None | None | None | None | None | None | None | None | None | POLYGON ((-1779118.826 5843385.413, -1777495.4... |
| 27638 | 66.0 | 345.0 | Admin-0 country | 1 | 3 | Iceland | ISL | 0 | 2 | Sovereign country | ... | None | None | None | None | None | None | None | None | None | MULTIPOLYGON (((-1667923.900 5835808.462, -166... |

27639 rows × 171 columns
This is essentially already a sparse matrix mapping one grid space to the other. How sparse?
```python
sparsity = len(overlay) / (len(grid_df) * len(regions_df))
sparsity
```
0.0017625114784205692
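Very sparse. A rough back-of-the-envelope comparison (a sketch; sizes count values only, ignoring index overhead): a dense weight matrix would need one float64 per grid-cell/region pair, while the sparse form stores only the ~27,600 non-zero overlaps.

```python
# Approximate memory for the weight matrix, dense vs. sparse (values only)
dense_mb = len(grid_df) * len(regions_df) * 8 / 1e6   # 64800 x 242 float64
sparse_mb = len(overlay) * 8 / 1e6                    # one float64 per overlap
print(f"dense: {dense_mb:.0f} MB, sparse: {sparse_mb:.2f} MB")
```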
Let's explore these overlays a little bit:
```python
overlay[overlay.SOVEREIGNT == "Italy"].geometry.plot(edgecolor='k')
```
<AxesSubplot:>
As we can see, filtering by country shows each of the grid boxes, partitioned exactly along the country geometry. Plotting them all gives us back all the land.
```python
overlay.plot()
```
<AxesSubplot:>
We can verify that each country's area is preserved in the overlay operation.
```python
overlay.geometry.area.groupby(overlay.SOVEREIGNT).sum().nlargest(10)
```
```
SOVEREIGNT
Russia                      1.687963e+13
Antarctica                  1.225664e+13
Canada                      9.872058e+12
United States of America    9.449645e+12
China                       9.372241e+12
Brazil                      8.499337e+12
Australia                   7.706575e+12
India                       3.155821e+12
Argentina                   2.782670e+12
Kazakhstan                  2.712770e+12
dtype: float64
```
```python
regions_df.geometry.area.groupby(regions_df.SOVEREIGNT).sum().nlargest(10)
```
```
SOVEREIGNT
Russia                      1.687963e+13
Antarctica                  1.225664e+13
Canada                      9.872058e+12
United States of America    9.449645e+12
China                       9.372241e+12
Brazil                      8.499337e+12
Australia                   7.706575e+12
India                       3.155821e+12
Argentina                   2.782670e+12
Kazakhstan                  2.712770e+12
dtype: float64
```
## Calculate the area fraction for each region
This is another key step. This transform tells us how much of a country's total area comes from each of the grid cells. This is accurate because we used an area-preserving CRS.
```python
grid_cell_fraction = overlay.geometry.area.groupby(overlay.SOVEREIGNT).transform(lambda x: x / x.sum())
grid_cell_fraction
```
```
0        0.000009
1        0.000009
2        0.000009
3        0.000009
4        0.000009
           ...
27634    0.005459
27635    0.004817
27636    0.016131
27637    0.015022
27638    0.001937
Length: 27639, dtype: float64
```
We can verify that these all sum up to one.
```python
grid_cell_fraction.groupby(overlay.SOVEREIGNT).sum()
```
```
SOVEREIGNT
Afghanistan       1.0
Albania           1.0
Algeria           1.0
Andorra           1.0
Angola            1.0
                 ...
Western Sahara    1.0
Yemen             1.0
Zambia            1.0
Zimbabwe          1.0
eSwatini          1.0
Length: 201, dtype: float64
```
## Turn this into a sparse Xarray DataArray
The first step is making a MultiIndex.
```python
multi_index = overlay.set_index(["latitude", "longitude", "SOVEREIGNT"]).index
df_weights = pd.DataFrame({"weights": grid_cell_fraction.values}, index=multi_index)
df_weights
```
| latitude (index) | longitude (index) | SOVEREIGNT (index) | weights |
|---|---|---|---|
| -90.0 | 0.0 | Antarctica | 0.000009 |
| -90.0 | 1.0 | Antarctica | 0.000009 |
| -90.0 | 2.0 | Antarctica | 0.000009 |
| -90.0 | 3.0 | Antarctica | 0.000009 |
| -90.0 | 4.0 | Antarctica | 0.000009 |
| ... | ... | ... | ... |
| 66.0 | 341.0 | Iceland | 0.005459 |
| 66.0 | 342.0 | Iceland | 0.004817 |
| 66.0 | 343.0 | Iceland | 0.016131 |
| 66.0 | 344.0 | Iceland | 0.015022 |
| 66.0 | 345.0 | Iceland | 0.001937 |

27639 rows × 1 columns
We can bring this directly into xarray as a 1D Dataset.
```python
import xarray as xr
ds_weights = xr.Dataset(df_weights)
ds_weights
```
```
<xarray.Dataset>
Dimensions:     (dim_0: 27639)
Coordinates:
  * dim_0       (dim_0) MultiIndex
  - latitude    (dim_0) float64 -90.0 -90.0 -90.0 -90.0 ... 66.0 66.0 66.0 66.0
  - longitude   (dim_0) float64 0.0 1.0 2.0 3.0 4.0 ... 342.0 343.0 344.0 345.0
  - SOVEREIGNT  (dim_0) object 'Antarctica' 'Antarctica' ... 'Iceland' 'Iceland'
Data variables:
    weights     (dim_0) float64 8.803e-06 8.803e-06 ... 0.01502 0.001937
```
Now we unstack it into a sparse array.
```python
weights_sparse = ds_weights.unstack(sparse=True, fill_value=0.).weights
weights_sparse
```
```
<xarray.DataArray 'weights' (latitude: 171, longitude: 360, SOVEREIGNT: 201)>
<COO: shape=(171, 360, 201), dtype=float64, nnz=27627, fill_value=0.0>
Coordinates:
  * latitude    (latitude) float64 -90.0 -89.0 -88.0 -87.0 ... 81.0 82.0 83.0
  * longitude   (longitude) float64 0.0 1.0 2.0 3.0 ... 356.0 357.0 358.0 359.0
  * SOVEREIGNT  (SOVEREIGNT) object 'Afghanistan' 'Albania' ... 'eSwatini'
```
Here we can clearly see that this is a sparse matrix, mapping the input space (lat, lon) to the output space (SOVEREIGNT).
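If you want to poke at the weights for a single country, densifying one slice is cheap (a hypothetical spot check; any name present in SOVEREIGNT works):

```python
# Select one output column and verify its weights sum to ~1
italy_weights = weights_sparse.sel(SOVEREIGNT="Italy")
print(italy_weights.data.todense().sum())  # ~1.0
```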
## Perform Matrix Multiplication
To regrid the data, we just have to multiply the original precip dataset by this matrix. There are many different ways to do this. The simplest one would just be:
```python
regridded = xr.dot(ds.precip, weights_sparse)
regridded
```
```
/srv/conda/envs/notebook/lib/python3.9/site-packages/xarray/core/indexing.py:1228: PerformanceWarning: Slicing is producing a large chunk. To accept the large chunk and silence this warning, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]
To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]

<xarray.DataArray (time: 9226, SOVEREIGNT: 201)>
dask.array<sum-aggregate, shape=(9226, 201), dtype=float64, chunksize=(200, 201), chunktype=numpy.ndarray>
Coordinates:
  * time        (time) datetime64[ns] 1996-10-01 1996-10-02 ... 2021-12-31
  * SOVEREIGNT  (SOVEREIGNT) object 'Afghanistan' 'Albania' ... 'eSwatini'
```
Unfortunately, that doesn't work out of the box: the lazy computation is constructed fine, but computing it fails because sparse doesn't implement einsum (see https://github.com/pydata/sparse/issues/31).
```python
regridded[0].compute()
```
```
TypeError: no implementation found for 'numpy.einsum' on types that implement __array_function__: [<class 'numpy.ndarray'>, <class 'sparse._coo.core.COO'>]
```
There are many ways to work around this, starting from just using a dense matrix instead (sketched below). After some extensive experimentation, the fastest approach I found uses sparse's matmul support, shown after the sketch.
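A minimal sketch of the dense workaround, assuming the dense weight matrix (on the order of 100 MB at this resolution) fits comfortably in memory; `weights_dense` and `regridded_dense` are hypothetical names:

```python
# Densify the sparse weights; xr.dot then dispatches to plain numpy einsum
weights_dense = weights_sparse.copy(data=weights_sparse.data.todense())
regridded_dense = xr.dot(ds.precip, weights_dense)
```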
Sparse does implement matmul, so we can use that. But we have to do some reshaping to make it work with our data.
```python
def apply_weights_matmul_sparse(weights, data):
    assert isinstance(weights, sparse.SparseArray)
    assert isinstance(data, np.ndarray)
    data = sparse.COO.from_numpy(data)
    data_shape = data.shape
    # flatten the grid dimensions: k = nlat * nlon
    n, k = data_shape[0], data_shape[1] * data_shape[2]
    data = data.reshape((n, k))
    weights_shape = weights.shape
    k_, m = weights_shape[0] * weights_shape[1], weights_shape[2]
    assert k == k_
    weights_data = weights.reshape((k, m))

    # (n, k) @ (k, m) -> (n, m): time x region
    regridded = sparse.matmul(data, weights_data)
    assert regridded.shape == (n, m)
    return regridded.todense()
```
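As a quick sanity check of the reshaping logic, we can run the kernel on tiny synthetic inputs (hypothetical toy sizes, not the real grid):

```python
rng = np.random.default_rng(0)

# weights: (nlat=4, nlon=5, nregion=3), mostly zero; data: (ntime=2, nlat=4, nlon=5)
toy_weights = sparse.COO.from_numpy(rng.random((4, 5, 3)) * (rng.random((4, 5, 3)) > 0.9))
toy_data = rng.random((2, 4, 5))

out = apply_weights_matmul_sparse(toy_weights, toy_data)
assert out.shape == (2, 3)  # one value per (time, region) pair
```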
Before applying this to the data, let's load it into memory and then chunk it again. This is not strictly necessary (we could just stream the data from the cloud), but it makes for a cleaner benchmark. Chunking again allows us to leverage dask parallelism. We also eliminate some known corrupted values via a mask.
```python
mask = (ds.precip >= 0) & (ds.precip < 3000)
precip = ds.precip.where(mask)
precip_in_mem = precip.compute().chunk({"time": "10MB"})
precip_in_mem
```
```
<xarray.DataArray 'precip' (time: 9226, latitude: 180, longitude: 360)>
dask.array<xarray-<this-array>, shape=(9226, 180, 360), dtype=float32, chunksize=(38, 180, 360), chunktype=numpy.ndarray>
Coordinates:
  * latitude   (latitude) float32 -90.0 -89.0 -88.0 -87.0 ... 87.0 88.0 89.0
  * longitude  (longitude) float32 0.0 1.0 2.0 3.0 ... 356.0 357.0 358.0 359.0
  * time       (time) datetime64[ns] 1996-10-01 1996-10-02 ... 2021-12-31
Attributes:
    cell_methods:   area: mean time: mean
    long_name:      NOAA Climate Data Record (CDR) of Daily GPCP Satellite-Ga...
    standard_name:  lwe_precipitation_rate
    units:          mm/day
    valid_range:    [0.0, 100.0]
```
```python
precip_regridded = xr.apply_ufunc(
    apply_weights_matmul_sparse,
    weights_sparse,
    precip_in_mem,
    join="left",
    input_core_dims=[["latitude", "longitude", "SOVEREIGNT"], ["latitude", "longitude"]],
    output_core_dims=[["SOVEREIGNT"]],
    dask="parallelized",
    meta=[np.ndarray((0,))]
)
precip_regridded
```
```
/tmp/ipykernel_1777/4033782278.py:1: FutureWarning: ``meta`` should be given in the ``dask_gufunc_kwargs`` parameter. It will be removed as direct parameter in a future version.
  precip_regridded = xr.apply_ufunc(
/srv/conda/envs/notebook/lib/python3.9/site-packages/xarray/core/indexing.py:1228: PerformanceWarning: Slicing is producing a large chunk. To accept the large chunk and silence this warning, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]
To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]

<xarray.DataArray (time: 9226, SOVEREIGNT: 201)>
dask.array<transpose, shape=(9226, 201), dtype=float64, chunksize=(38, 201), chunktype=numpy.ndarray>
Coordinates:
  * SOVEREIGNT  (SOVEREIGNT) object 'Afghanistan' 'Albania' ... 'eSwatini'
  * time        (time) datetime64[ns] 1996-10-01 1996-10-02 ... 2021-12-31
```
Finally, we compute it!
```python
from dask.diagnostics import ProgressBar

with ProgressBar():
    precip_regridded.load()
```
```
/srv/conda/envs/notebook/lib/python3.9/site-packages/sparse/_common.py:232: RuntimeWarning: Nan will not be propagated in matrix multiplication
  warnings.warn(
[########################################] | 100% Completed |  8.5s
```
With this approach, it took less than ten seconds to regrid the entire dataset (9226 timesteps)!
We can now explore the data by region.
```python
precip_regridded.sel(SOVEREIGNT="Italy").resample(time="MS").mean().plot()
```
[<matplotlib.lines.Line2D at 0x7fb5348c9790>]
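From here, any of the usual xarray machinery applies. For instance (a sketch; pick any names present in SOVEREIGNT), we can compare several countries at once:

```python
# Monthly-mean precipitation for a few countries, one line per country
(
    precip_regridded
    .sel(SOVEREIGNT=["Italy", "Spain", "Greece"])
    .resample(time="MS").mean()
    .plot.line(x="time")
)
```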