Complex dataset split - StratifiedGroupShuffleSplit

This question is more than a year old, but I found myself in a similar situation: I have labels and groups, and due to the nature of the groups, all data points of one group must be either in test only or in train only. I wrote this small algorithm using pandas and sklearn; I hope it helps.

from sklearn.model_selection import GroupShuffleSplit
# Build a train/test split of `df` in which every value of the 'groups' column
# ends up on exactly one side. Each 'label' is split separately so that both
# sides receive a share of every label where that is possible.
groups = df.groupby('label')
all_train = []
all_test = []
for _, label_rows in groups:
    # if a group is already taken in test or train it must stay there, so only
    # the still-unassigned groups of this label are split below
    label_rows = label_rows[~label_rows['groups'].isin(all_train + all_test)]
    # nothing left to place for this label
    if label_rows.empty:
        continue
    if label_rows['groups'].nunique() < 2:
        # A single group cannot be divided between two sides; asking
        # GroupShuffleSplit to do so raises ValueError because one side would
        # be empty. Keep the lone group in train instead of crashing.
        all_train += label_rows['groups'].tolist()
        continue
    # Only the first split is consumed, so a single split is requested.
    splitter = GroupShuffleSplit(test_size=valid_size, n_splits=1, random_state=7)
    train_inds, test_inds = next(
        splitter.split(label_rows, groups=label_rows['groups']))

    all_train += label_rows.iloc[train_inds]['groups'].tolist()
    all_test += label_rows.iloc[test_inds]['groups'].tolist()

# Select rows by group membership so whole groups move together.
train = df[df['groups'].isin(all_train)]
test = df[df['groups'].isin(all_test)]

# Sanity check: gather the distinct group ids found on each side of the split
# and compute the ids that appear on both sides.
form_train, form_test = (set(part['groups'].tolist()) for part in (train, test))
inter = form_train & form_test

# Per-label row counts for the full data, then for train, then for test.
for frame in (df, train, test):
    print(frame.groupby('label').count())
print(inter)  # this should be empty

Essentially I need StratifiedGroupShuffleSplit which does not exist (Github issue). This is because the behaviour of such a function is unclear and accomplishing this to yield a dataset which is both grouped and stratified is not always possible (also discussed here) - especially with a heavily imbalanced dataset such as mine. In my case, I want grouping to be done strictly to ensure there is no overlap of groups whatsoever whilst stratification and the dataset ratio split of 60:20:20 to be done approximately i.e. as well as is possible.

As Ghanem mentions, I have no choice but to build a function to split the dataset myself, which I have done below:

def StratifiedGroupShuffleSplit(df_main, *, train_proportion=0.6,
                                hparam_mse_wgt=0.1, random_state=None,
                                verbose=True):
    """Greedily split ``df_main`` into train, validation and test frames.

    All rows sharing a ``subject_id`` are placed in the same output frame, so
    the three frames never share a subject. Stratification on ``category`` and
    the requested size ratio are both only approximated.

    The rows are shuffled, then subjects are visited in order of first
    appearance. The first three subjects seed train, validation and test
    respectively. Every later subject goes to the split with the highest
    score, where the score is a weighted sum of:

    * the reduction in mean squared difference between the split's category
      percentages and those of the whole dataset that adding the subject
      would bring, and
    * the signed squared gap between the split's target share and its current
      share of the rows assigned so far.

    Args:
        df_main: DataFrame with ``subject_id`` and ``category`` columns.
        train_proportion: Target share of rows for train, between 0 and 1.
            Validation and test each target half of the remainder.
        hparam_mse_wgt: Weight of the stratification term, between 0 and 1.
            0 considers only the size ratio, 1 only the stratification.
        random_state: Seed for the shuffle. ``None`` uses numpy's global
            random state.
        verbose: When true, print the three scores for every subject placed
            after the seeding step.

    Returns:
        Tuple ``(df_train, df_val, df_test)`` of DataFrames with a fresh
        ``RangeIndex``. A split that received no subject is an empty frame
        with the columns of ``df_main``.

    Raises:
        ValueError: If ``hparam_mse_wgt`` or ``train_proportion`` lies
            outside ``[0, 1]``.
    """
    # Explicit exceptions rather than assert, which is stripped under -O.
    if not 0 <= hparam_mse_wgt <= 1:
        raise ValueError("hparam_mse_wgt must be between 0 and 1")
    if not 0 <= train_proportion <= 1:
        raise ValueError("train_proportion must be between 0 and 1")
    val_test_proportion = (1 - train_proportion) / 2
    targets = (train_proportion, val_test_proportion, val_test_proportion)

    # Shuffle by position so duplicated index labels cannot break the shuffle.
    rng = np.random if random_state is None else np.random.RandomState(random_state)
    df_main = df_main.iloc[rng.permutation(len(df_main))]

    # Category percentages of the whole dataset: the stratification target.
    main_counts = df_main.groupby('category')['subject_id'].count()
    categories = main_counts.index
    main_pct = main_counts.to_numpy() / len(df_main) * 100

    def calc_mse_loss(cat_counts, n_rows):
        """Mean squared gap between a split's category percentages and the target."""
        split_pct = cat_counts / n_rows * 100
        return np.mean((main_pct - split_pct) ** 2)

    # Running state per split (0 = train, 1 = val, 2 = test). Tracking counts
    # avoids re-copying and re-grouping ever-growing frames for every subject.
    split_frames = ([], [], [])
    split_counts = np.zeros((3, len(categories)))
    split_sizes = [0, 0, 0]

    for i, (_, group) in enumerate(df_main.groupby('subject_id', sort=False)):
        group_counts = (group['category'].value_counts()
                        .reindex(categories, fill_value=0).to_numpy())
        group_size = len(group)

        if i < 3:
            # Seed each split with one subject so every score is defined.
            chosen = i
        else:
            total_records = sum(split_sizes)
            scores = []
            for k in range(3):
                # Positive when adding the subject improves stratification.
                mse_loss_diff = (
                    calc_mse_loss(split_counts[k], split_sizes[k])
                    - calc_mse_loss(split_counts[k] + group_counts,
                                    split_sizes[k] + group_size))
                # Signed square: positive while the split is below its
                # target share, negative once it is above.
                len_diff = targets[k] - split_sizes[k] / total_records
                len_loss_diff = len_diff * abs(len_diff)
                scores.append(hparam_mse_wgt * mse_loss_diff
                              + (1 - hparam_mse_wgt) * len_loss_diff)

            # Ties resolve in the order train, val, test.
            best = max(scores)
            if best == scores[0]:
                chosen = 0
            elif best == scores[1]:
                chosen = 1
            else:
                chosen = 2

            if verbose:
                print(f"Group {i}. loss_train: {scores[0]!s} | "
                      f"loss_val: {scores[1]!s} | "
                      f"loss_test: {scores[2]!s} | ")

        split_frames[chosen].append(group)
        split_counts[chosen] += group_counts
        split_sizes[chosen] += group_size

    def _combine(frames):
        """Concatenate one split's groups, or give an empty frame with matching columns."""
        if not frames:
            return df_main.iloc[0:0].reset_index(drop=True)
        return pd.concat(frames, ignore_index=True)

    df_train, df_val, df_test = (_combine(frames) for frames in split_frames)
    return df_train, df_val, df_test

# Split df_main into the three frames returned by the function: train, validation and test.
df_train, df_val, df_test = StratifiedGroupShuffleSplit(df_main)

I have created some arbitrary loss function based on 2 things:

  1. The average squared difference in the percentage representation of each category compared to the overall dataset
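  2. The signed squared difference between the proportional length of the dataset and what it should be according to the ratio supplied (60:20:20). The sign is kept, so a dataset that is below its target share scores positively and one that is already above it scores negatively.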

Weighting these two inputs to the loss function is done by the static hyperparameter hparam_mse_wgt. For my particular dataset, a value of 0.1 worked well but I would encourage you to play around with it if you use this function. Setting it to 0 will prioritise only maintaining the split ratio and ignore the stratification. Setting it to 1 would be vice versa.

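Using this loss function, I then iterate through each subject (group) and append it to the appropriate dataset (training, validation or test) according to whichever has the highest value of the loss function. As computed in the code, this value measures how much adding the subject would improve that dataset (the stratification error before adding it minus the error after, plus the size term), so the highest value marks the best fit rather than the worst.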

It's not particularly complicated but it does the job for me. It won't necessarily work for every dataset, but the larger it is, the better the chance. Hopefully someone else will find it useful.