Grouping by multiple dimensions

I built a manual solution. To make it efficient, I bypass xarray entirely and rebuild the indices and values by hand. Every attempt to use more of xarray (e.g. using `sel`, or re-packaging cells into a DataArray; see also https://github.com/pydata/xarray/issues/2452) made it significantly slower.

import itertools
from collections import defaultdict

import numpy as np
import xarray as xr
from xarray import DataArray

class DataAssembly(DataArray):
    """``DataArray`` subclass that adds a hand-rolled groupby over several dimensions at once."""

    def multi_dim_groupby(self, groups, apply):
        """
        Group by several coordinates at once and reduce every combination of group values with `apply`.

        Indices and values are rebuilt by hand on the underlying numpy data; going through xarray
        (`sel`, re-packaging cells into a DataArray) was found to be much slower.

        :param groups: names of the coordinates to group by, at most one per dimension.
            They are re-ordered to follow `self.dims`; the result has one dimension per group.
        :param apply: callable invoked as `apply(cells, **coords)` once per combination of group values.
            `cells` holds the matching values (unwrapped to a Python scalar if there is only one,
            otherwise a numpy array with as many axes as `self`). Every coordinate is passed as a
            keyword argument holding its unique value(s) within the combination; use `**_` to
            ignore them. The return value is stored in a float array, so it must be a single number.
        :return: a new `type(self)` array holding the applied values. Along each result dimension the
            group values appear in order of first appearance in `self`. Coordinates that take more than
            one value inside some group cannot be mapped onto the result and are dropped.
        :raises ValueError: if two groups live on the same dimension.
        """
        # align: order the groups like the dimensions they live on
        groups = sorted(groups, key=lambda group: self.dims.index(self[group].dims[0]))
        group_dims = {self[group].dims: group for group in groups}
        if len(group_dims) != len(groups):
            # the result needs exactly one dimension per group; sharing a dimension cannot be represented
            raise ValueError("multi_dim_groupby supports at most one group per dimension, got %s" % (groups,))

        # build indices
        unique_values = {group: np.unique(self[group].values) for group in groups}
        # group -> value -> every position of that value along the group's dimension in self
        source_indices = defaultdict(lambda: defaultdict(list))
        # group -> value -> position along the corresponding result dimension
        result_indices = defaultdict(dict)
        for group in groups:
            for position, value in enumerate(self[group].values):
                source_indices[group][value].append(position)
                # only the first occurrence claims a new slot (0, 1, 2, ...); later ones are "grouped away"
                result_indices[group].setdefault(value, len(result_indices[group]))

        coords = {coord: (dims, values) for coord, dims, values in walk_coords(self)}

        def simplify(value):
            # unwrap single-element arrays into plain Python scalars
            return value.item() if value.size == 1 else value

        # group and apply
        # making this a DataArray right away and then inserting through .loc would slow things down
        data = self.values  # using DataArray would slow things down, thus we pass coords as kwargs
        result = np.zeros([len(result_indices[group]) for group in groups])
        result_coords = {coord: (dims, [None] * len(result_indices[group_dims[dims]]))
                         for coord, (dims, _) in coords.items()}
        for combination in itertools.product(*unique_values.values()):
            group_values = dict(zip(unique_values.keys(), combination))
            self_indices = {group: source_indices[group][value] for group, value in group_values.items()}
            # np.ix_ selects the cross product of the per-dimension positions. A plain list of index
            # sequences would be paired element-wise instead (only the "diagonal" of the group when
            # several dimensions are non-unique), and recent numpy rejects it as a tuple index.
            cells = simplify(data[np.ix_(*self_indices.values())])
            cell_coords = {coord: (dims, simplify(np.unique(values[self_indices[group_dims[dims]]])))
                           for coord, (dims, values) in coords.items()}

            # ignore dims when passing to function
            passed_coords = {coord: value for coord, (_, value) in cell_coords.items()}
            merge = apply(cells, **passed_coords)
            result_idx = {group: result_indices[group][value] for group, value in group_values.items()}
            result[tuple(result_idx.values())] = merge
            for coord, (dims, value) in cell_coords.items():
                if coord not in result_coords:  # already dropped for an earlier group
                    continue
                if isinstance(value, np.ndarray):  # multiple values for coord -> cannot be mapped, drop it
                    del result_coords[coord]
                    continue
                assert dims == result_coords[coord][0]
                result_coords[coord][1][result_idx[group_dims[dims]]] = value

        # re-package
        return type(self)(result, coords=result_coords, dims=list(itertools.chain(*group_dims.keys())))

def walk_coords(assembly):
    """
    Walk through coords and all levels, just like the `__repr__` function, yielding `(name, dims, values)`.

    A dimension coordinate with `level_names` (i.e. backed by a MultiIndex) is replaced by its
    individual levels, looked up via `assembly.coords[level]`; every other coordinate is yielded as-is.
    Each name is yielded at most once, so levels are not reported twice in case the coords mapping
    also lists them as standalone coordinates (as newer xarray versions may do).
    """
    seen = set()
    for name, values in assembly.coords.items():
        # partly borrowed from xarray.core.formatting#summarize_coord
        is_index = name in assembly.dims
        if is_index and values.variable.level_names:
            entries = [(level, assembly.coords[level]) for level in values.variable.level_names]
        else:
            entries = [(name, values)]
        for entry_name, entry in entries:
            if entry_name in seen:
                continue
            seen.add(entry_name)
            yield entry_name, entry.dims, entry.values

The method `multi_dim_groupby` performs the grouping and the apply in one step. The passed `apply` function receives the values of each group as its first argument and the group's coords as keyword arguments named after the coords; to ignore the coords, add `**_` to its signature. It must return a single number.

It's not particularly pretty and does not cover every possible case, but it at least covers the following test cases:

import DataAssembly

class TestMultiDimGroupby:
    """Tests for `DataAssembly.multi_dim_groupby`."""

    def test_unique_values(self):
        # every group has exactly one member, so an identity apply reproduces the input
        d = DataAssembly([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
                         coords={'a': ['a', 'b', 'c', 'd'],
                                 'b': ['x', 'y', 'z']},
                         dims=['a', 'b'])
        g = d.multi_dim_groupby(['a', 'b'], lambda x, **_: x)
        assert g.equals(d)

    def test_nonunique_singledim(self):
        d = DataAssembly([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
                         coords={'a': ['a', 'a', 'b', 'b'],
                                 'b': ['x', 'y', 'z']},
                         dims=['a', 'b'])
        g = d.multi_dim_groupby(['a', 'b'], lambda x, **_: x.mean())
        assert g.equals(DataAssembly([[2.5, 3.5, 4.5], [8.5, 9.5, 10.5]],
                                     coords={'a': ['a', 'b'], 'b': ['x', 'y', 'z']},
                                     dims=['a', 'b']))

    def test_nonunique_adjacentcoord(self):
        d = DataAssembly([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
                         coords={'a': ('adim', ['a', 'a', 'b', 'b']),
                                 'aa': ('adim', ['a', 'b', 'a', 'b']),
                                 'b': ['x', 'y', 'z']},
                         dims=['adim', 'b'])
        g = d.multi_dim_groupby(['a', 'b'], lambda x, **_: x.mean())
        # the grouped coord keeps its name 'a' and lives on dim 'adim'; `equals` also compares coord names
        assert g.equals(DataAssembly([[2.5, 3.5, 4.5], [8.5, 9.5, 10.5]],
                                     coords={'a': ('adim', ['a', 'b']), 'b': ['x', 'y', 'z']},
                                     dims=['adim', 'b'])), \
            "adjacent coord aa should be discarded due to non-mappability"

    def test_unique_values_swappeddims(self):
        # groups passed out of dimension order are re-aligned to the dims of the input
        d = DataAssembly([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],
                         coords={'a': ['a', 'b', 'c', 'd'],
                                 'b': ['x', 'y', 'z']},
                         dims=['a', 'b'])
        g = d.multi_dim_groupby(['b', 'a'], lambda x, **_: x)
        assert g.equals(d)