scikit-learn: A walkthrough of GroupKFold.split()

Yao Yao on January 25, 2018

Suppose $X[\text{"groups"}] = \begin{bmatrix} a \\ b \\ b \\ c \\ c \\ c \end{bmatrix}$ and n_splits=3.

Then GroupKFold.split(X, y, X["groups"]) ends up calling the _iter_test_indices method, which yields the test indices of each fold in turn. The method starts like this:

# Parameter groups == X["groups"]
unique_groups, groups = np.unique(groups, return_inverse=True)

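The returned groups is an array of integer codes: if X["groups"] has $n$ unique values, each entry of groups is an index between $0$ and $n-1$ telling which unique value the corresponding sample holds. It can therefore map any array of $n$ markers onto the samples. E.g.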

markers = np.array(['△', '○', '□'])
markers[[0, 1, 1, 2, 2, 2]] == array(['△', '○', '○', '□', '□', '□'], dtype='<U1')

In particular, unique_groups[groups] reproduces X["groups"] element by element.
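
The behaviour of return_inverse can be checked in isolation. The following is a minimal sketch using the example labels from this post; the variable names group_labels and reconstructed are introduced here only for illustration.

```python
import numpy as np

# The example group labels from this post, one label per sample
group_labels = np.array(["a", "b", "b", "c", "c", "c"])

# unique_groups: the sorted distinct labels
# groups: for each sample, the position of its label inside unique_groups
unique_groups, groups = np.unique(group_labels, return_inverse=True)

print(unique_groups.tolist())  # ['a', 'b', 'c']
print(groups.tolist())         # [0, 1, 1, 2, 2, 2]

# Indexing the unique values with the inverse rebuilds the original array
reconstructed = unique_groups[groups]
print(np.array_equal(reconstructed, group_labels))  # True
```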

n_groups = len(unique_groups)  # 3
 
# Weight groups by their number of occurrences
n_samples_per_group = np.bincount(groups)
# Distribute the most frequent groups first
indices = np.argsort(n_samples_per_group)[::-1]
n_samples_per_group = n_samples_per_group[indices]
# Total weight of each fold
n_samples_per_fold = np.zeros(self.n_splits)  # [0, 0, 0]

# Mapping from group index to fold index
group_to_fold = np.zeros(len(unique_groups))  # [0, 0, 0]

# Distribute samples by adding the largest weight to the lightest fold
for group_index, weight in enumerate(n_samples_per_group):
    lightest_fold = np.argmin(n_samples_per_fold)
    n_samples_per_fold[lightest_fold] += weight
    group_to_fold[indices[group_index]] = lightest_fold
  • group_index = 0; weight = 3
    • lightest_fold = 0
    • n_samples_per_fold[0] = 3
    • group_to_fold[2] = 0
  • group_index = 1; weight = 2
    • lightest_fold = 1
    • n_samples_per_fold[1] = 2
    • group_to_fold[1] = 1
  • group_index = 2; weight = 1
    • lightest_fold = 2
    • n_samples_per_fold[2] = 1
    • group_to_fold[0] = 2
indices = group_to_fold[groups]

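Key step! group_to_fold plays the role of the marker array from earlier: it holds one fold number per group ([2, 1, 0] here), so indexing it with groups gives the fold of every sample, [2, 1, 1, 0, 0, 0].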
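
This lookup can be reproduced on its own. The sketch below hard-codes the values reached in the trace above; sample_to_fold is a name chosen here for readability (the scikit-learn code reuses the name indices).

```python
import numpy as np

# Values taken from the trace above
groups = np.array([0, 1, 1, 2, 2, 2])      # group index of each sample
group_to_fold = np.array([2.0, 1.0, 0.0])  # fold assigned to groups 0, 1, 2

# Fancy indexing: look up each sample's fold through its group index
sample_to_fold = group_to_fold[groups]
print(sample_to_fold.tolist())  # [2.0, 1.0, 1.0, 0.0, 0.0, 0.0]
```

The values are floats because group_to_fold was created with np.zeros, whose default dtype is float64; the later comparison indices == f still works as expected.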

for f in range(self.n_splits):
    yield np.where(indices == f)[0]  # note that `np.where` here returns a one-element tuple
  • The 1st split: f = 0, yield np.array([3, 4, 5])
  • The 2nd split: f = 1, yield np.array([1, 2])
  • The 3rd split: f = 2, yield np.array([0])
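
Putting the pieces together, the steps traced so far can be sketched as a standalone generator. This is an illustrative re-implementation, not the scikit-learn method itself: the function name iter_test_indices and the variable names order and sample_to_fold are assumptions made here, and input validation (for example, checking that n_splits does not exceed the number of groups) is left out.

```python
import numpy as np

def iter_test_indices(group_labels, n_splits):
    """Yield the test indices of each fold, following the steps traced above."""
    unique_groups, groups = np.unique(group_labels, return_inverse=True)

    # Weight groups by their number of occurrences, largest first
    n_samples_per_group = np.bincount(groups)
    order = np.argsort(n_samples_per_group)[::-1]

    n_samples_per_fold = np.zeros(n_splits)
    group_to_fold = np.zeros(len(unique_groups))

    # Greedy balancing: give the next-largest group to the lightest fold
    for rank, weight in enumerate(n_samples_per_group[order]):
        lightest_fold = np.argmin(n_samples_per_fold)
        n_samples_per_fold[lightest_fold] += weight
        group_to_fold[order[rank]] = lightest_fold

    # Fold of every sample, then one array of sample indices per fold
    sample_to_fold = group_to_fold[groups]
    for f in range(n_splits):
        yield np.where(sample_to_fold == f)[0]

folds = [fold.tolist() for fold in iter_test_indices(list("abbccc"), n_splits=3)]
print(folds)  # [[3, 4, 5], [1, 2], [0]]
```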
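
split() itself is defined on the base class, which turns each set of test indices into a boolean mask and then into a (train_index, test_index) pair:

# This is an abstract class, `_iter_test_indices` being the abstract method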
class BaseCrossValidator(with_metaclass(ABCMeta)):
    def split(self, X, y=None, groups=None):
        X, y, groups = indexable(X, y, groups)
        indices = np.arange(_num_samples(X))  # array([0, 1, 2, 3, 4, 5]) here
        for test_index in self._iter_test_masks(X, y, groups):
            train_index = indices[np.logical_not(test_index)]
            test_index = indices[test_index]
            yield train_index, test_index

    def _iter_test_masks(self, X=None, y=None, groups=None):
        """Generates boolean masks corresponding to test sets.
        By default, delegates to _iter_test_indices(X, y, groups)
        """
        for test_index in self._iter_test_indices(X, y, groups):
            test_mask = np.zeros(_num_samples(X), dtype=bool)  # the `np.bool` alias is unavailable in some NumPy releases
            test_mask[test_index] = True
            yield test_mask

    def _iter_test_indices(self, X=None, y=None, groups=None):
        """Generates integer indices corresponding to test sets."""
        raise NotImplementedError
  • The 1st split:
    • test_mask == np.array([False, False, False, True, True, True])
    • train_index == np.array([0, 1, 2])
    • test_index == np.array([3, 4, 5])
  • The 2nd split:
    • test_mask == np.array([False, True, True, False, False, False])
    • train_index == np.array([0, 3, 4, 5])
    • test_index == np.array([1, 2])
  • The 3rd split:
    • test_mask == np.array([True, False, False, False, False, False])
    • train_index == np.array([1, 2, 3, 4, 5])
    • test_index == np.array([0])
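
The conversion from test indices to masks and then to train/test index arrays can be replayed with NumPy alone. The sketch below hard-codes the three test-index arrays from the trace; the list splits is introduced here only to collect the results.

```python
import numpy as np

n_samples = 6
indices = np.arange(n_samples)  # array([0, 1, 2, 3, 4, 5])

splits = []
# Test indices yielded for the three folds in the trace above
for test_index in ([3, 4, 5], [1, 2], [0]):
    # Boolean mask that is True exactly at the test positions
    test_mask = np.zeros(n_samples, dtype=bool)
    test_mask[test_index] = True

    # Everything outside the mask is training data
    train = indices[np.logical_not(test_mask)]
    test = indices[test_mask]
    splits.append((train.tolist(), test.tolist()))

for train, test in splits:
    print(train, test)
# [0, 1, 2] [3, 4, 5]
# [0, 3, 4, 5] [1, 2]
# [1, 2, 3, 4, 5] [0]
```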

P.S. Note that GroupKFold’s output is fully determined by its input: nothing in the procedure above is random, so no random seed is needed.
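
A quick way to check this is to run the splitter twice on the same data and compare the results. This sketch assumes scikit-learn is installed and that GroupKFold is used with its default settings; the zero-filled X and y are placeholders, since the splitter here relies on the group labels and the number of samples.

```python
import numpy as np
from sklearn.model_selection import GroupKFold

X = np.zeros((6, 1))  # placeholder features
y = np.zeros(6)       # placeholder targets
groups = np.array(["a", "b", "b", "c", "c", "c"])

cv = GroupKFold(n_splits=3)

# Collect the (train, test) pairs from two independent calls to split()
first = [(tr.tolist(), te.tolist()) for tr, te in cv.split(X, y, groups)]
second = [(tr.tolist(), te.tolist()) for tr, te in cv.split(X, y, groups)]

print(first == second)  # True
```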


