groupby on dask objects doesn't handle chunks well #1832

Closed
rabernat opened this issue Jan 16, 2018 · 22 comments
@rabernat
Contributor

rabernat commented Jan 16, 2018

80% of climate data analysis begins with calculating the monthly-mean climatology and subtracting it from the dataset to get an anomaly. Unfortunately this is a fail case for xarray / dask with out-of-core datasets. This is becoming a serious problem for me.

Code Sample

import xarray as xr
import dask.array as da
import pandas as pd
# construct an example dataset chunked in time
nt, ny, nx = 366, 180, 360
time = pd.date_range(start='1950-01-01', periods=nt, freq='10D')
ds = xr.DataArray(da.random.random((nt, ny, nx), chunks=(1, ny, nx)),
                   dims=('time', 'lat', 'lon'),
                   coords={'time': time}).to_dataset(name='field')
# monthly climatology
ds_mm = ds.groupby('time.month').mean(dim='time')
# anomaly
ds_anom = ds.groupby('time.month') - ds_mm
print(ds_anom)
<xarray.Dataset>
Dimensions:  (lat: 180, lon: 360, time: 366)
Coordinates:
  * time     (time) datetime64[ns] 1950-01-01 1950-01-11 1950-01-21 ...
    month    (time) int64 1 1 1 1 2 2 3 3 3 4 4 4 5 5 5 5 6 6 6 7 7 7 8 8 8 ...
Dimensions without coordinates: lat, lon
Data variables:
    field    (time, lat, lon) float64 dask.array<shape=(366, 180, 360), chunksize=(366, 180, 360)>

Problem description

As we can see in the example above, the chunking has been lost: the resulting dataset contains a single huge chunk. This happens with any non-reducing operation on the groupby, even

ds.groupby('time.month').apply(lambda x: x)
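
To see the chunk loss directly, one can inspect the dask chunk structure of the result (a quick check added for illustration, not part of the original report):

# compare the chunk structure of the input and the anomaly
print(ds.field.data.chunks)       # 366 chunks of length 1 along time
print(ds_anom.field.data.chunks)  # a single chunk spanning all 366 timesteps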

Say we wanted to compute some statistics of the anomaly, like the variance:

(ds_anom.field**2).mean(dim='time').load()

This triggers the single big chunk (containing the whole timeseries) to be loaded into memory at once. For out-of-core datasets, this will crash our system.

Expected Output

It seems like we should be able to do this lazily, maintaining a chunk size of (1, 180, 360) for ds_anom.
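
One way to get that chunking (a workaround sketch, not part of the original report) is to skip the groupby arithmetic for the subtraction and instead select the climatology at each timestep's month; the broadcast subtraction should then keep the fine chunking along time:

# select the monthly climatology at each timestep's month (vectorized indexing)
clim_at_times = ds_mm.sel(month=ds['time.month'])
# the subtraction stays lazy and should keep per-timestep chunks along time
ds_anom_alt = ds - clim_at_times
print(ds_anom_alt.field.data.chunks)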

Output of xr.show_versions()

INSTALLED VERSIONS
------------------
commit: None
python: 3.6.2.final.0
python-bits: 64
OS: Darwin
OS-release: 16.7.0
machine: x86_64
processor: i386
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8

xarray: 0.10.0+dev27.g049cbdd
pandas: 0.20.3
numpy: 1.13.1
scipy: 0.19.1
netCDF4: 1.3.1
h5netcdf: 0.4.1
Nio: None
zarr: 2.2.0a2.dev91
bottleneck: 1.2.1
cyordereddict: None
dask: 0.16.0
distributed: 1.20.1
matplotlib: 2.1.0
cartopy: 0.15.1
seaborn: 0.8.1
setuptools: 36.3.0
pip: 9.0.1
conda: None
pytest: 3.2.1
IPython: 6.1.0
sphinx: 1.6.5

Possibly related to #392.

cc @mrocklin

@shoyer
Member

shoyer commented Jan 16, 2018

See also dask/dask#874

@mrocklin
Contributor

# monthly climatology
ds_mm = ds.groupby('time.month').mean(dim='time')
# anomaly
ds_anom = ds.groupby('time.month') - ds_mm

I would actually hope that this would be a little bit nicer than the case in the dask issue, especially if you are chunked by some dimension other than time. In the case that @shoyer points to, we're creating a global aggregation value and then applying it to all of the input data. In @rabernat's case we have at least twelve aggregation points, and possibly more if there are other chunked dimensions like ensemble (or lat/lon, if you choose to chunk those).

@rabernat
Contributor Author

The operation

ds_anom = ds - ds.mean(dim='time')

is also extremely common. Both should work well by default.
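
One common workaround for this pattern (a sketch, assuming the time mean fits comfortably in memory) is to compute the small reduction eagerly and subtract the in-memory result from the lazy data:

# the (lat, lon) time mean is small, so compute it eagerly
ds_mean = ds.mean(dim='time').compute()
# the subtraction stays lazy and keeps the original per-timestep chunks
ds_anom = ds - ds_mean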

@mrocklin
Contributor

Teaching the scheduler to delete-and-recompute is possible, but also expensive to implement. I would not expect it from me in the near term.

@rabernat
Contributor Author

Below is how I work around the issue in practice: writing a loop over each item in the groupby, and then looping over each variable, loading, and writing to disk.

# (fields, save_dir, and prefix are defined earlier in the full script)
gb = ds.groupby('time.month')
for month, dsm in gb:
    dsm_anom2 = ((dsm - ds_mm.sel(month=month))**2).mean(dim='time')
    dsm_anom2 = dsm_anom2.rename({f: f + '2' for f in fields})
    dsm_anom2.coords['month'] = month
    for var in dsm_anom2.data_vars:
        filename = save_dir + '%02d.%s_%s.nc' % (month, prefix, var)
        print(filename)
        ds_out = dsm_anom2[[var]].load()
        ds_out.to_netcdf(filename)

Needless to say, this feels more like my pre-xarray/dask workflow.

Since @mrocklin has made it pretty clear that dask will not automatically solve this for us any time soon, we need to brainstorm some creative ways to make this extremely common use case more friendly with out-of-core data.

@mrocklin
Contributor

Since @mrocklin has made it pretty clear that dask will not automatically solve this for us any time soon, we need to brainstorm some creative ways to make this extremely common use case more friendly with out-of-core data.

That's not entirely true. I've said that delete-and-recompute is unlikely to be resolved in the near future. That is the solution proposed by @shoyer, but it is only one possible solution. The fact that your for-loop solution works well is evidence that delete-and-recompute is not necessary to solve this problem in your case. I'm actively working on this at dask/dask#3066 (fortunately paid for by other groups).

@mrocklin
Contributor

(not to sound too rosy though, these problems have had me stumped for a couple days)

@mrocklin
Contributor

mrocklin commented Jan 16, 2018

This is an interesting example, adapted from something that @rabernat produced:

import dask
import xarray as xr
import dask.array as da
import pandas as pd

from dask.distributed import Client
client = Client(processes=False)
# below I create a random dataset that is typical of high-res climate models
# size of example can be adjusted up and down by changing shape
dims = ('time', 'depth', 'lat', 'lon')
time = pd.date_range('1980-01-01', '1980-12-01', freq='1d')
shape = (len(time), 5, 1800, 360)
# what I consider to be a reasonable chunk size
chunks = (1, 1, 1800, 360)
ds = xr.Dataset({k: (dims, da.random.random(shape, chunks=chunks))
                 for k in ['u', 'v', 'w']},
                coords={'time': time})

# create seasonal climatology
ds_clim = ds.groupby('time.week').mean(dim='time')

# construct seasonal anomaly
ds_anom = ds.groupby('time.week') - ds_clim
# compute variance of seasonal anomaly
ds_anom_var = (ds_anom**2).mean(dim='time')
ds_anom_var.compute()

It works fine locally with processes=False and poorly with processes=True. If anyone has time to help on this issue I recommend investigating what is different in these two cases. If I had time I would start here by trying to improve our understanding with better visual diagnostics, although just poring over logs might also provide some insight.
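
For anyone reproducing this, a hypothetical way to switch to the problematic configuration and re-run the same computation (the worker counts here are arbitrary, not from the original comment):

# re-run the same graph on a process-based local cluster to compare behaviour
client.close()
client = Client(processes=True, n_workers=4, threads_per_worker=1)
ds_anom_var.compute()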

@rabernat
Contributor Author

I am developing a use case for this scenario using real data. I will put the data in cloud storage as soon as #1800 is merged. That should make it easier to debug.

@rabernat
Contributor Author

Or maybe real data just gets in the way of the core dask issue?

@mrocklin
Contributor

mrocklin commented Jan 16, 2018

Looking at the worker diagnostic page during execution is informative. The worker has a ton of work it can do and a ton of communication it needs to do (to share results with other workers for the reductions). In this example it is able to start new work much faster than it is able to communicate results to peers, leading to significant buildup. These two processes happen asynchronously without any back-pressure between them, so most of the input gets produced before it can be reduced and processed.

That's my current guess anyway. I could imagine pausing worker threads if there is a heavy communication buildup. I'm not sure how generally valuable this is though.

@mrocklin
Contributor

I encourage you to look at the diagnostic page for one of your workers if you get a chance. This is typically served on port 8789 if that port is open.

@mrocklin
Contributor

@rabernat you might also consider turning off spill-to-disk. I suspect that by prioritizing the other mechanisms for slowing down processing you'll have a better experience:

worker-memory-target: False # target fraction to stay below
worker-memory-spill: False # fraction at which we spill to disk
worker-memory-pause: 0.80  # fraction at which we pause worker threads
worker-memory-terminate: 0.95  # fraction at which we terminate the worker
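
For reference, a sketch of where these settings lived at the time, assuming the old flat-key ~/.dask/config.yaml that distributed read back then (the file location is an assumption; the keys are the ones above):

# ~/.dask/config.yaml  (location assumed for distributed of that era)
worker-memory-target: False    # disable the "start managing memory" target fraction
worker-memory-spill: False     # disable spilling to disk entirely
worker-memory-pause: 0.80      # pause worker threads at 80% memory use
worker-memory-terminate: 0.95  # terminate the worker at 95% memory use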

@mrocklin
Contributor

Or, this might work in conjunction with dask/dask#3066

diff --git a/distributed/worker.py b/distributed/worker.py
index a1b9f32..62b5f07 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -1227,8 +1227,8 @@ class Worker(WorkerBase):
     def add_task(self, key, function=None, args=None, kwargs=None, task=None,
                  who_has=None, nbytes=None, priority=None, duration=None,
                  resource_restrictions=None, **kwargs2):
-        if isinstance(priority, list):
-            priority.insert(1, self.priority_counter)
+        # if isinstance(priority, list):
+        #     priority.insert(1, self.priority_counter)
         try:
             if key in self.tasks:
                 state = self.task_state[key]

@mrocklin
Contributor

mrocklin commented Feb 1, 2018

@rabernat I recommend trying a combination of these two PRs. They do well for me on the problem listed above.

There is still some memory requirement, but it seems to be under better control.

@rabernat
Contributor Author

rabernat commented Feb 1, 2018

@mrocklin thanks for the updates. I should have some time on Friday morning to give it a try on Cheyenne.

@mrocklin
Contributor

mrocklin commented Feb 1, 2018

The relevant PRs have been merged into master on both repositories.

@mrocklin
Contributor

mrocklin commented Feb 3, 2018

@rabernat you shouldn't need the spill-to-disk configuration from the comment above, just the master branches of both repositories. Ideally you would try your climatology computation again and see if memory use continues to exceed expectations.

@mrocklin
Contributor

mrocklin commented Feb 6, 2018

Checking in here. Any luck? I noticed your comment in dask/distributed#1736 but that seems to be a separate issue about file-based locks rather than about task scheduling priorities. Is the file-based locking stuff getting in the way of you checking for low-memory use?

@rabernat
Contributor Author

rabernat commented Feb 6, 2018

Short answer: no luck. With the latest masters (but without the suggested dask config), I am still getting the same basic performance limitations.

I can update you more when we talk in person later today.

@rabernat
Contributor Author

rabernat commented Jun 6, 2019

In recent versions of xarray (0.12.1) and dask (0.12.1), this issue has been ameliorated significantly. I believe it can now be closed.

@rabernat
Contributor Author

I am trying a new approach to this problem using xarray's new map_blocks function. See this example: https://nbviewer.jupyter.org/gist/rabernat/30e7b747f0e3583b5b776e4093266114
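
The notebook has the details; a rough paraphrase of the idea (assuming a recent xarray with Dataset.map_blocks and its template argument) looks something like this:

# rechunk so each block holds the full time axis and is chunked only in space
ds_spatial = ds.chunk({'time': -1, 'lat': 30, 'lon': 60})

def remove_monthly_climatology(block):
    # climatology and anomaly are computed independently within each spatial block
    gb = block.groupby('time.month')
    anom = gb - gb.mean(dim='time')
    return anom.drop_vars('month')  # keep the variables aligned with the template

# blocks are processed independently, so no single huge chunk is ever materialized
ds_anom = ds_spatial.map_blocks(remove_monthly_climatology, template=ds_spatial)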
