flox: Faster GroupBy reductions with Xarray

Tuesday, July 18th, 2023


Significantly faster groupby calculations are now possible through a new-ish package in the Xarray/Dask/Pangeo ecosystem called flox. Practically, this means faster climatologies, faster resampling, faster histogramming, and faster compositing of array datasets. It also means that very very many discussions in the Pangeo community are now closed 🎉 😱 🤯 🥳.

How do I use it?

Run mamba install flox, and xarray>=2022.06.0 will use it by default for .groupby, .groupby_bins, and .resample!

A lot of effort was spent on ensuring backwards compatibility, so your workloads should only work better. Let us know if they do not.
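A quick way to confirm that flox is importable in the current environment is sketched below. It uses only the Python standard library; the variable name flox_available is made up for illustration.

```python
import importlib.util

# find_spec returns None when the package cannot be found on the import path.
flox_available = importlib.util.find_spec("flox") is not None

if flox_available:
    print("flox is installed; Xarray can use it for groupby reductions.")
else:
    print("flox is not installed; Xarray will use its default groupby.")
```

Xarray also has a use_flox option in xr.set_options, which is handy for comparing results with and without flox.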

The National Water Model Dataset

To show off, we demonstrate county-wise aggregation of output from the National Water Model (NWM) available on the AWS Public Data Registry.

Quoting the NOAA page for more.

import flox  # make sure it's available
import fsspec
import numpy as np
import rioxarray
import xarray as xr

# open the retrospective NWM output lazily from the public S3 bucket
ds = xr.open_zarr(
    fsspec.get_mapper("s3://noaa-nwm-retrospective-2-1-zarr-pds/rtout.zarr", anon=True),
    consolidated=True,
)

Each field in this dataset is big!

This variable, zwattablrt, represents "Depth to saturated layers (=2m when no saturation; =0 when fully saturated)" (source). So a depth of 2m means an unsaturated soil column, and 0m indicates a fully saturated soil column.

We'll subset to a single variable and a single year for demo purposes.

subset = ds.zwattablrt.sel(time="2001")

Problem description

We want to calculate county-level means for 3-hourly time series data on the 250m grid. Our desired output looks like this:

This is a GroupBy problem

GroupBy is a term for a very common analysis pattern often called "split-apply-combine" (Wickham, 2011), wherein an analyst

  • Splits a dataset into groups (e.g. counties),
  • Applies a transformation to each group of data (here a reduction like .mean)
  • Combines the results of apply to form a new dataset

For this problem we will split the dataset into counties, apply the mean, and then combine the results back.

With Xarray, this would look like
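As a minimal sketch on a tiny synthetic array (the shapes, values, and county labels below are invented for illustration and are not the NWM data):

```python
import numpy as np
import xarray as xr

# Toy stand-in for the real variable: 4 time steps on a 2x3 grid.
data = xr.DataArray(
    np.arange(24, dtype="float64").reshape(4, 2, 3),
    dims=("time", "y", "x"),
    name="zwattablrt",
)

# Hypothetical county raster: one integer label per grid cell.
county = xr.DataArray(
    np.array([[1, 1, 2], [2, 3, 3]]),
    dims=("y", "x"),
    name="county",
)

# Split by county, apply the mean over each county's cells, combine.
county_mean = data.groupby(county).mean()
print(county_mean.sel(county=1).values)
```

The result has one value per time step and county. With flox installed and xarray>=2022.06.0, the same call is handled by flox.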


However, Xarray's default algorithm splits the dataset into groups by indexing and then applies the reduction in a simple for-loop over groups. This approach doesn't work very well for large distributed problems.

Xarray default groupby
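The split-by-indexing, loop-over-groups idea can be sketched in plain NumPy. This is an illustration of the pattern only, not Xarray's actual implementation; the function name groupby_mean_loop and the sample arrays are hypothetical.

```python
import numpy as np


def groupby_mean_loop(values, labels):
    """Mean of `values` per label, using indexing and a for-loop over groups."""
    groups = np.unique(labels)
    result = np.empty(len(groups), dtype="float64")
    for i, group in enumerate(groups):
        # Split: boolean indexing pulls out one group's values.
        # Apply: reduce that group to a single number.
        result[i] = values[labels == group].mean()
    # Combine: the per-group results sit in one array, ordered like `groups`.
    return groups, result


values = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
labels = np.array([10, 20, 10, 30, 20, 10])
groups, means = groupby_mean_loop(values, labels)
print(groups, means)
```

Every iteration scans the full label array, so the cost grows with the number of groups. With thousands of counties and chunked data, that adds up quickly.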

Enter flox.

flox solves a long-standing problem in the Pangeo array computing ecosystem: computing GroupBy reductions efficiently. It implements a parallel groupby algorithm (using a tree reduction) that substantially improves the performance of groupby reductions with dask.

flox default map-reduce groupby
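To show why a groupby mean fits a map-reduce (tree reduction) pattern, here is a small NumPy sketch. It assumes integer labels 0..ngroups-1 and plain lists of arrays standing in for dask chunks; the function names are hypothetical, and flox itself handles far more cases than this.

```python
import numpy as np


def chunk_sums_and_counts(values, labels, ngroups):
    """Map step: per-group sum and count for a single chunk."""
    sums = np.bincount(labels, weights=values, minlength=ngroups)
    counts = np.bincount(labels, minlength=ngroups)
    return sums, counts


def combine(a, b):
    """Reduce step: intermediate sums and counts simply add."""
    return a[0] + b[0], a[1] + b[1]


def groupby_mean_mapreduce(value_chunks, label_chunks, ngroups):
    partials = [
        chunk_sums_and_counts(v, l, ngroups)
        for v, l in zip(value_chunks, label_chunks)
    ]
    # Combine neighbouring partial results pairwise until one remains.
    while len(partials) > 1:
        partials = [
            combine(partials[i], partials[i + 1]) if i + 1 < len(partials) else partials[i]
            for i in range(0, len(partials), 2)
        ]
    sums, counts = partials[0]
    # Groups with no members come out as NaN instead of raising a warning.
    with np.errstate(invalid="ignore", divide="ignore"):
        return sums / counts


values = np.arange(8.0)
labels = np.array([0, 1, 2, 0, 1, 2, 0, 1])
means = groupby_mean_mapreduce(
    np.array_split(values, 3), np.array_split(labels, 3), ngroups=3
)
print(means)
```

Each chunk is reduced independently to a small fixed-size intermediate, so the work parallelizes and only the small intermediates need to be combined.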

  • Specifically, flox speeds up reduction methods like groupby(...).mean(), groupby(...).max(), etc., but not groupby.map.
  • flox also significantly speeds up groupby reductions with pure numpy arrays using optimized implementations in the numpy-groupies package.
  • flox allows more complicated groupby operations such as lazy grouping by a dask array, and grouping by multiple variables. Use flox.xarray.xarray_reduce for these operations. Xarray currently only supports grouping by a single numpy variable.
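Grouping by multiple variables can be reduced to the single-variable case by folding the label arrays into one combined label. The NumPy sketch below shows one common trick for this; it is not a claim about flox's internals, and the county and season arrays are hypothetical. With flox itself you would pass both grouping variables to flox.xarray.xarray_reduce instead.

```python
import numpy as np

values = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# Two hypothetical grouping variables, already coded as integers 0..n-1.
county = np.array([0, 0, 1, 1, 0, 1])  # 2 counties
season = np.array([0, 1, 0, 1, 0, 1])  # 2 seasons
shape = (2, 2)

# Fold each (county, season) pair into a single integer label.
combined = np.ravel_multi_index((county, season), shape)

sums = np.bincount(combined, weights=values, minlength=shape[0] * shape[1])
counts = np.bincount(combined, minlength=shape[0] * shape[1])

# Unfold the result into a (county, season) table of means.
means = (sums / counts).reshape(shape)
print(means)
```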

See here for short examples.


Load county raster for grouping

A raster TIFF file identifying counties by a unique integer was created separately and saved.

We load that using rioxarray

import rioxarray

counties = rioxarray.open_rasterio(
    "s3://nwm-250m-us-counties/Counties_on_250m_grid.tif", chunks="auto"
)

# remove any small floating point error in coordinate locations
_, counties_aligned = xr.align(ds, counties, join="override")
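To see what join="override" does, here is a small sketch with two synthetic arrays whose coordinates differ only by floating point noise. The arrays a and b are made up for illustration.

```python
import numpy as np
import xarray as xr

a = xr.DataArray(np.zeros(3), dims="x", coords={"x": [0.0, 250.0, 500.0]})
# Same grid, but with tiny floating point noise in the coordinate values.
b = xr.DataArray(
    np.ones(3), dims="x", coords={"x": [0.0, 250.0 + 1e-9, 500.0 - 1e-9]}
)

# The default inner join keeps only coordinate values that match exactly.
a_inner, b_inner = xr.align(a, b)
print(a_inner.sizes["x"])

# join="override" keeps all the data and copies the first object's index
# onto the second (the sizes along the dimension must be the same).
a_aligned, b_aligned = xr.align(a, b, join="override")
print(b_aligned["x"].values)
```

After the override, both objects share identical coordinates, so later operations treat them as being on the same grid.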

We'll need the unique county IDs later, so we calculate them now.

county_id = np.unique(counties_aligned.data).compute()
# 0 is used as NULL
county_id = county_id[county_id != 0]
print(f"There are {len(county_id)} counties!")

There are 3108 counties!

GroupBy with Flox

We could run the computation as


This would use flox in the background.

However, it would also load counties_aligned into memory (an unfortunate Xarray implementation detail), which is not so bad (only a gig). To avoid egress charges, we'll instead go through flox.xarray, which allows you to lazily group by a dask array (here counties_aligned) as long as you pass in the expected group labels in expected_groups.

See here for more.
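One way to see why the expected labels matter: when the labels are lazy, the set of groups (and so the output shape) is unknown until the labels are computed. If the labels are supplied up front, every chunk can produce a fixed-size intermediate. The NumPy sketch below illustrates the idea; the function chunk_partial and the sample arrays are hypothetical, and this is not flox's actual code.

```python
import numpy as np


def chunk_partial(values, labels, expected_groups):
    """Per-group sum and count for one chunk, sized by `expected_groups`."""
    expected_groups = np.asarray(expected_groups)  # assumed sorted and unique
    n = len(expected_groups)
    # Position of each label within expected_groups.
    idx = np.clip(np.searchsorted(expected_groups, labels), 0, n - 1)
    # Keep only labels that really are expected (drops e.g. the NULL label 0).
    valid = expected_groups[idx] == labels
    sums = np.bincount(idx[valid], weights=values[valid], minlength=n)
    counts = np.bincount(idx[valid], minlength=n)
    return sums, counts


expected = np.array([3, 7, 9])
values = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
labels = np.array([7, 0, 3, 7, 9])

sums, counts = chunk_partial(values, labels, expected)
print(sums, counts)

# A chunk that sees only some labels still yields the same output shape.
small_sums, small_counts = chunk_partial(values[:2], labels[:2], expected)
print(small_sums.shape)
```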

import flox.xarray

# lazily group by the dask-backed county raster; expected_groups fixes the output labels
county_mean = flox.xarray.xarray_reduce(
    subset,
    counties_aligned.rename("county"),
    func="mean",
    expected_groups=(county_id,),
)

The computation proceeds very nicely, in particular thanks to recent improvements in dask/distributed (1, 2). We don't anticipate trouble scaling this computation up to the full dataset.


flox makes many large GroupBy problems tractable! Use it.

flox also makes many small but more complicated (e.g. multiple variables) GroupBy problems tractable! Use it.

We anticipate upgrading Xarray's interface to enable more complicated GroupBy computations. In the meantime, use flox!

Run mamba install flox, and xarray>=2022.06.0 will use it by default for .groupby, .groupby_bins, and .resample!

See here for short examples on the many ways to use flox!


Thanks to Matt Rocklin (coiled.io) for facilitating easy computation with Dask in the cloud for the demo calculation.

Thanks to Kevin Sampson, Katelyn Fitzgerald, and James McCreight for feedback.

Deepak Cherian's time was funded in part by

  1. NASA-ACCESS 80NSSC18M0156 "Community tools for analysis of NASA Earth Observing System Data in the Cloud"
  2. NASA-OSTFL 80NSSC22K0345 "Enhancing analysis of NASA data with the open-source Python Xarray Library", and
  3. NCAR's Earth System Data Science Initiative.

© 2024, Xarray core developers. Apache 2.0 Licensed.