
Group coordination, automatic clustering, and scikit-learn

In a pretty cool experiment run at Tate Modern last year, we put visitors to the gallery into groups of 7, blindfolded them, and asked them to march in time at a given tempo. We’re using this data to explore how people coordinate with each other in groups.

One problem, however, is that people don’t always keep perfect time. Often, they would skip a few beats, step more than once in a single beat, or start late. The challenge was to find a way to automatically cluster the data, so we can say this footstep and this footstep belong to this beat, while this other footstep belongs to a different beat. In this post, I outline how I achieved this using Python’s scikit-learn.

The Data #

First, let’s take a look at the raw data.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from sklearn import cluster as skcluster

# Plot settings
sns.set_style('whitegrid')
mpl.rcParams['figure.figsize'] = (8, 6)
mpl.rcParams['font.size'] = 18
all_data = pd.read_csv('data/raw.csv')
all_data.head()

   group  block camera subject  step   time    isi
0      1      1      A      RA     1  1.534    NaN
1      1      1      A      RA     2  2.037  0.503
2      1      1      A      RA     3  2.519  0.482
3      1      1      A      RA     4  3.012  0.493
4      1      1      A      RA     5  3.467  0.455

print('We have %i rows, and %i columns.\n' % all_data.shape)

print('Unique values (of index columns):')
for col in ['group', 'block', 'camera', 'subject']:
    print('- "%s"\n   %s' % (col, str(all_data[col].unique())))
We have 30805 rows, and 7 columns.

Unique values (of index columns):
- "group"
   [ 1  2  3  4  5  6  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
 26 27  7]
- "block"
   [1 2]
- "camera"
   ['A' 'B']
- "subject"
   ['RA' 'S2' 'S4' 'S5' 'S6' 'S1' 'S3']

We have data from 27 groups, who each completed two blocks. Each group consisted of 7 people (6 subjects and a research assistant). For each step by each person, we have coded the time since the start of the recording (in seconds), and the step number (step 1, 2, 3, etc., for that person, in that recording).
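
The `isi` column looks like an inter-step interval: in the rows shown above, each value equals the difference between consecutive `time` values for the same person (for example, 2.037 - 1.534 = 0.503). Here is a minimal sketch of how such a column could be recomputed, assuming that interpretation is right. The small data frame is built by hand from values shown in this post, not read from the real file.

```python
import pandas as pd

# Toy long-format data in the same layout as raw.csv.
# Times are copied from the tables in this post (first three steps of RA and S1).
toy = pd.DataFrame({
    'group':   [1, 1, 1, 1, 1, 1],
    'block':   [1, 1, 1, 1, 1, 1],
    'subject': ['RA', 'RA', 'RA', 'S1', 'S1', 'S1'],
    'step':    [1, 2, 3, 1, 2, 3],
    'time':    [1.534, 2.037, 2.519, 2.229, 2.666, 3.087],
})

# Difference between consecutive step times, computed separately for each
# person in each recording. A person's first step has no predecessor,
# so its interval is NaN.
toy = toy.sort_values(['group', 'block', 'subject', 'step'])
toy['isi'] = toy.groupby(['group', 'block', 'subject'])['time'].diff().round(3)
print(toy)
```

Grouping before calling `diff()` matters: without it, the first step of one person would be compared with the last step of the previous person.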

Checking Alignment #

Next, let’s extract data from group 1, block 1, and represent it using a wide (participant × step) data frame of step times.

def select_wide_data(all_data: pd.DataFrame, 
                     group: int, block: int) -> pd.DataFrame:
    '''Get wide data for this group, in this block'''
    mask = (all_data['group']==group) & (all_data['block']==block)
    df = all_data[mask] # Long format
    dfx = df.pivot_table(index='subject', columns='step', values='time') # Wide format
    return dfx

group = 1
block = 1
dfx = select_wide_data(all_data, group, block)
dfx

step         1      2      3      4      5      6      7      8      9     10  ...      80      81      82      83      84      85      86      87      88      89
subject
RA       1.534  2.037  2.519  3.012  3.467  3.899  4.382  4.890  5.354  5.823  ...  40.102  40.607  41.088  41.577  42.061  42.549  43.004     NaN     NaN     NaN
S1       2.229  2.666  3.087  3.578  4.045  4.558  5.007  5.439  5.950  6.433  ...  40.269  40.706  41.247  41.721  42.231  42.744  43.226     NaN     NaN     NaN
S2       0.751  1.158  1.619  2.046  2.509  2.973  3.418  3.938  4.437  4.879  ...  38.713  39.197  39.660  40.118  40.619  41.097  41.610  42.095  42.583  43.082
S3       1.120  1.683  2.173  2.598  3.027  3.610  4.028  4.523  4.969  5.503  ...  39.273  39.767  40.236  40.688  41.198  41.657  42.219  42.671  43.253     NaN
S4       1.293  2.166  2.586  2.976  3.500  3.906  4.431  4.926  5.843  6.355  ...  40.093  40.613  41.062  41.524  42.120  42.565  43.062     NaN     NaN     NaN
S5       1.027  1.536  2.007  2.508  2.933  3.370  3.929  4.366  4.855  5.362  ...  39.227  39.675  40.214  40.692  41.181  41.684  42.168  42.645  43.127     NaN
S6       0.582  1.131  1.666  2.096  2.534  2.952  3.480  3.954  4.432  4.910  ...  38.807  39.230  39.716  40.130  40.696  41.137  41.641  42.103  42.655     NaN

7 rows × 89 columns

Which steps belong to which beat clusters? We’ll start by assuming that everyone was perfectly coordinated, so that step n for every participant belongs to cluster n. This is encoded in the participant × step cluster matrix, below.

n_subjects, n_times = dfx.shape
steps = np.arange(n_times)
subjects = np.arange(n_subjects)

cluster_ids = dfx.columns.values
initial_cluster_matrix = np.repeat(cluster_ids, n_subjects).reshape((n_times, n_subjects)).T
initial_cluster_matrix
array([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
        33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
        49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
        65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
        81, 82, 83, 84, 85, 86, 87, 88, 89],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
        33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
        49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
        65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
        81, 82, 83, 84, 85, 86, 87, 88, 89],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
        33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
        49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
        65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
        81, 82, 83, 84, 85, 86, 87, 88, 89],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
        33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
        49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
        65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
        81, 82, 83, 84, 85, 86, 87, 88, 89],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
        33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
        49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
        65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
        81, 82, 83, 84, 85, 86, 87, 88, 89],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
        33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
        49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
        65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
        81, 82, 83, 84, 85, 86, 87, 88, 89],
       [ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
        33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
        49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
        65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
        81, 82, 83, 84, 85, 86, 87, 88, 89]])
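
The repeat/reshape/transpose chain used to build the matrix above is a little hard to read. As a side note, the same matrix can be built more directly with `np.tile`. The sketch below uses a hypothetical small case (3 participants, 5 steps) and checks that the two constructions agree.

```python
import numpy as np

# Hypothetical small example: 3 participants, step numbers 1..5
cluster_ids = np.arange(1, 6)
n_subjects, n_times = 3, len(cluster_ids)

# The construction used in the post: repeat, reshape, transpose
via_repeat = np.repeat(cluster_ids, n_subjects).reshape((n_times, n_subjects)).T

# An equivalent construction: stack one copy of the step numbers per participant
via_tile = np.tile(cluster_ids, (n_subjects, 1))

print(via_tile)
print(np.array_equal(via_repeat, via_tile))
```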

To see how sensible this clustering is, let’s produce a raster plot of all the data from this block. This shows the time of every step for every participant, along with the cluster that step is assigned to (colour-coded, and with numeric labels).

def raster_plot(cluster_matrix: np.ndarray, dfx: pd.DataFrame, text=True):
    '''Produce a raster plot, showing the time of each participants' steps
    and the cluster they've been assigned to.
    
    
    '''
    fig, axes = plt.subplots(4, 1, figsize=(18, 12))
    for i, start in enumerate([0, 12, 24, 36]):
        raster_plot_row(cluster_matrix, dfx, start=start, stop=start+12, ax=axes[i], text=text)
    plt.tight_layout()
    plt.ylabel('Subject')
    plt.xlabel('Time (s)')

def raster_plot_row(cluster_matrix, dfx, start=0, stop=90, text=True, ax=None):
    '''Plot a single row of our raster plot, in the time window provided'''
    if ax is None:
        plt.figure(figsize=(18, 3))
        plt.ylabel('Subject')
        plt.xlabel('Time (s)')
    else:
        plt.sca(ax)
    pal = sns.palettes.color_palette(n_colors=7)
    from itertools import cycle
    pal = cycle(pal)
    X = dfx.values
    subjects = dfx.index.values
    subject_ix = np.arange(len(subjects))
    steps = np.unique(cluster_matrix)
    steps = steps[~np.isnan(steps)]
    for step in steps:
        color = next(pal)
        for sx, subject in enumerate(subjects):
            cm = cluster_matrix[sx]
            t = X[sx, cm==step] # subject cluster times
            t = t[(t >= start) & (t <= stop)]
            n = len(t)
            if n > 0:
                if text:
                    for v in t:
                        plt.text(v, sx, int(step), color=color)
                else:
                    plt.plot(t, np.repeat(sx, n), 'o', color=color)
    plt.ylim(-1, len(subjects)+1)
    plt.yticks(subject_ix+.5, subjects)
    plt.xlim(start, stop)
raster_plot(initial_cluster_matrix, dfx)

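[Figure: raster plot of step times for group 1, block 1, with each step labelled by its initial cluster (its step number)]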

Clearly, the clusters are wrong. If we look at 6 seconds in, for instance, we see that all 7 people step at around the same time, but this is step 10 for the RA, step 9 for subject 1, step 12 for subject 2, and so on.
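
The same mismatch can be checked numerically rather than by eye. The sketch below takes steps 8 to 10 for three participants, copied from the wide table earlier in the post, and finds which step number falls closest to the 6-second mark for each person. The variable names are illustrative only.

```python
import pandas as pd

# Steps 8-10 for three participants, copied from the wide table above
wide = pd.DataFrame(
    [[4.890, 5.354, 5.823],
     [5.439, 5.950, 6.433],
     [4.926, 5.843, 6.355]],
    index=pd.Index(['RA', 'S1', 'S4'], name='subject'),
    columns=pd.Index([8, 9, 10], name='step'),
)

# For each participant, find the step number whose time is closest to t = 6 s.
# idxmin skips NaNs by default, so missing steps are ignored.
target = 6.0
nearest_step = (wide - target).abs().idxmin(axis=1)
print(nearest_step)
```

If everyone were perfectly aligned, every participant would report the same step number here. Instead the RA is on step 10 while S1 and S4 are on step 9.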

Automatic Clustering #

To address this, I use an automatic clustering algorithm from scikit-learn to group the steps into temporal clusters. There are a number of clustering algorithms available, making different assumptions about the nature of the clusters you’re looking for, and with different parameters to be tuned.

Some algorithms, such as k-means and Gaussian mixture models, require that you specify how many clusters you’re looking for. This is problematic here, since different groups end up with different numbers of beat clusters, and it’s not always clear, even to us, how many clusters there should be. For example, how many beats occur in the first two seconds of the plot above? It is possible to use model selection techniques to find the optimal number of clusters for some algorithms, such as Gaussian mixture modelling, but this is time-consuming and complicated. K-means also implicitly assumes that all clusters have similar variance, which doesn’t work here since cluster variance goes up whenever people go out of time.

Other algorithms automatically figure out how many clusters to create, based on some other criterion. I found good results using the Mean Shift algorithm, which requires that you set a bandwidth parameter dictating how broad individual clusters should be. Results were best with bandwidth=0.25, which seems consistent with the fact that participants were supposed to be stepping at 120 BPM, or once every 0.5 seconds.
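
To get a feel for what the bandwidth does, here is a minimal sketch on synthetic data (not the Tate recordings). It assumes 7 people stepping on 8 beats spaced 0.5 s apart, with each step jittered by up to 40 ms. The jitter size and the second bandwidth value are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.cluster import MeanShift

rng = np.random.default_rng(0)

# Synthetic data: 7 people stepping on 8 beats, 0.5 s apart,
# each step jittered by up to +/- 40 ms
n_people, n_beats = 7, 8
beat_times = 0.5 * np.arange(1, n_beats + 1)
jitter = rng.uniform(-0.04, 0.04, size=(n_people, n_beats))
step_times = (beat_times + jitter).reshape(-1, 1)  # column vector for sklearn

# Bandwidth of half the beat period: each beat should form its own cluster
narrow = MeanShift(bandwidth=0.25).fit(step_times)
# A much larger bandwidth merges neighbouring beats into fewer clusters
wide = MeanShift(bandwidth=1.0).fit(step_times)

print(len(narrow.cluster_centers_), len(wide.cluster_centers_))
```

With clean data like this, a bandwidth of half the beat period is narrow enough that a window centred on one beat never reaches the steps of the next one. Real data is messier, which is where the edge cases discussed later come from.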

# Step times as column vector, excluding NaNs
X = dfx.values
step_times = X[~np.isnan(X)].flatten().reshape(-1, 1)
# Fit model
model = skcluster.MeanShift(bandwidth=.25).fit(step_times)
# Cluster labels are not in order
# To reorder them, we generate a dict that maps
# the original labels to new, ordered ones.
labels = model.predict(model.cluster_centers_)
ordered_labels = model.predict(np.sort(model.cluster_centers_, 0))
label_dict = dict(zip(ordered_labels, labels))
label_dict[np.nan] = np.nan
print(label_dict)
{88: 0, 86: 1, 87: 2, 82: 3, 81: 4, 80: 5, 79: 6, 78: 7, 77: 8, 76: 9, 85: 10, 75: 11, 74: 12, 73: 13, 72: 14, 71: 15, 70: 16, 69: 17, 68: 18, 67: 19, 66: 20, 65: 21, 64: 22, 63: 23, 62: 24, 61: 25, 60: 26, 59: 27, 58: 28, 57: 29, 56: 30, 55: 31, 54: 32, 53: 33, 52: 34, 51: 35, 50: 36, 49: 37, 48: 38, 47: 39, 46: 40, 45: 41, 44: 42, 43: 43, 42: 44, 41: 45, 40: 46, 39: 47, 38: 48, 37: 49, 36: 50, 35: 51, 34: 52, 33: 53, 32: 54, 31: 55, 30: 56, 29: 57, 28: 58, 27: 59, 26: 60, 25: 61, 24: 62, 84: 63, 23: 64, 22: 65, 21: 66, 20: 67, 19: 68, 18: 69, 17: 70, 16: 71, 15: 72, 14: 73, 13: 74, 12: 75, 11: 76, 10: 77, 9: 78, 8: 79, 7: 80, 6: 81, 5: 82, 4: 83, 3: 84, 2: 85, 1: 86, 0: 87, 83: 88, nan: nan}
# Generate a matrix of ordered cluster labels from the clustering model
cluster_matrix = np.zeros_like(X)
for i in subjects:
    for j in range(n_times):
        x = X[i, j]
        if not np.isnan(x):
            label = model.predict(x.reshape(-1, 1))[0]
            cluster_matrix[i, j] = label_dict[label]
        else:
            cluster_matrix[i, j] = np.nan
print(cluster_matrix)
[[ 2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19.
  20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37.
  38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55.
  56. 57. 58. 59. 60. 61. 62. 64. 65. 66. 67. 68. 69. 70. 71. 72. 73. 74.
  75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88. nan nan nan]
 [ 3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20.
  21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38.
  39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56.
  57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71. 72. 73. 74.
  75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88. nan nan nan]
 [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
  18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35.
  36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53.
  54. 55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71.
  72. 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88.]
 [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17. 18.
  19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36.
  37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54.
  55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71. 72.
  73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88. nan]
 [ 1.  3.  4.  5.  6.  7.  8.  9. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20.
  21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38.
  39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56.
  57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71. 72. 73. 74.
  75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88. nan nan nan]
 [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17. 18.
  19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36.
  37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54.
  55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71. 72.
  73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88. nan]
 [ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17.
  18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35.
  36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53.
  54. 55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71.
  72. 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. nan]]
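
The loop above calls `model.predict` once per step, which is easy to follow but slow. As an alternative sketch (not the code used for the results in this post), the same ordered matrix can be built in one go from `model.labels_`, which holds the label of each training point in the order the points were passed to `fit`. The toy matrix below is made up: three people, three steps, with the third person starting one beat late.

```python
import numpy as np
from sklearn.cluster import MeanShift

# Toy participant × step matrix of step times (made-up values).
# The third person starts on the second beat and has one missing step.
X = np.array([[0.51, 1.01, 1.52],
              [0.49, 0.99, 1.48],
              [1.00, 1.50, np.nan]])

mask = ~np.isnan(X)
step_times = X[mask].reshape(-1, 1)
model = MeanShift(bandwidth=0.25).fit(step_times)

# Rank each cluster centre by time: rank[k] is the temporal position
# of the cluster that MeanShift labelled k
order = np.argsort(model.cluster_centers_.ravel())
rank = np.empty_like(order)
rank[order] = np.arange(len(order))

# Put the time-ordered labels back where the step times came from;
# missing steps stay NaN
cluster_matrix = np.full(X.shape, np.nan)
cluster_matrix[mask] = rank[model.labels_]
print(cluster_matrix)
```

This works because boolean-mask indexing reads and writes elements in the same row-major order, so the labels land back in the cells their step times were taken from.
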
# Here's a function that does the same thing
def fit_meanshift_clusters(dfx: pd.DataFrame, bandwidth=.25) -> np.ndarray:
    '''Cluster the step times using the MeanShift algorithm.
    Returns a participant × step cluster matrix'''
    # Step times as column vector, excluding NaNs
    X = dfx.values
    step_times = X[~np.isnan(X)].flatten().reshape(-1, 1)
    # Fit model, using the bandwidth passed in by the caller
    model = skcluster.MeanShift(bandwidth=bandwidth).fit(step_times)
    # Cluster labels are not in order
    # To reorder them, we generate a dict that maps
    # the original labels to new, ordered ones.
    labels = model.predict(model.cluster_centers_)
    ordered_labels = model.predict(np.sort(model.cluster_centers_, 0))
    label_dict = dict(zip(ordered_labels, labels))
    label_dict[np.nan] = np.nan
    # Generate a matrix of ordered cluster labels from the clustering model
    cluster_matrix = np.zeros_like(X)
    n_subjects, n_times = X.shape
    for i in range(n_subjects):
        for j in range(n_times):
            x = X[i, j]
            if not np.isnan(x):
                label = model.predict(x.reshape(-1, 1))[0]
                cluster_matrix[i, j] = label_dict[label]
            else:
                cluster_matrix[i, j] = np.nan
    return cluster_matrix
raster_plot(cluster_matrix, dfx)

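[Figure: raster plot of step times for group 1, block 1, with each step labelled by its Mean Shift cluster]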

Better. Steps that occur at around the same time are now part of the same cluster.

Finally, we convert the step times and cluster labels back to long format, and save them to file.

%mkdir -p data/clustered
step_df = dfx.reset_index().melt(id_vars='subject', value_name='time')
cluster_dfx = pd.DataFrame(cluster_matrix, index=dfx.index, columns=dfx.columns)
cluster_df = (cluster_dfx.reset_index()
              .melt(id_vars='subject')
              .rename({'value':'cluster'}, axis=1)
              .sort_values(['step', 'subject']))
result = pd.merge(step_df, cluster_df, 
                  on=['subject', 'step'], how='left')
result.to_csv('data/clustered/group%i_block%i.csv' % (group, block),
              index=False)
result.head()

Edge Cases #

Although it works very well for the vast majority of our data (27 groups, two blocks each), this approach isn’t perfect. There are still some cases that aren’t clustered as we would expect. Here’s the worst example.

group, block = 2, 2
dfx = select_wide_data(all_data, group, block)
cluster_matrix = fit_meanshift_clusters(dfx)
raster_plot(cluster_matrix, dfx)

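[Figure: raster plot of step times for group 2, block 2, with each step labelled by its Mean Shift cluster]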

Cluster 25 contains steps that should really be split across two clusters, while clusters 50 and 51 should probably be a single cluster. There are a few other issues, for example around cluster 68. Several of these issues arise because the clustering algorithm isn’t constrained to have only one step per person per cluster. There are a few approaches we could take to deal with these problems.

Exclude them #

For now, we limit our analyses to clusters of 7 steps. This is certainly the most reliable solution, and does not lead to any further issues in our analyses.
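
A minimal sketch of this kind of filter is below. It reads "clusters of 7 steps" as clusters with exactly one step from each of the 7 people; the one-per-person check is my assumption, added because a cluster could reach 7 steps with one person stepping twice and another missing. The toy data uses 3 people rather than 7, and all values are made up.

```python
import pandas as pd

# Toy long-format result with 3 people instead of 7 (values are made up).
# Cluster 0 is complete, cluster 1 is missing S2, and in cluster 2
# S1 contributes two steps.
result = pd.DataFrame({
    'subject': ['RA', 'S1', 'S2', 'RA', 'S1', 'RA', 'S1', 'S1', 'S2'],
    'time':    [0.50, 0.52, 0.48, 1.01, 0.99, 1.50, 1.45, 1.58, 1.52],
    'cluster': [0, 0, 0, 1, 1, 2, 2, 2, 2],
})
n_people = 3  # would be 7 for the real data

# Count steps and distinct people per cluster
summary = result.groupby('cluster')['subject'].agg(n_steps='size', n_subjects='nunique')

# Keep clusters with exactly one step from every person
is_complete = (summary['n_steps'] == n_people) & (summary['n_subjects'] == n_people)
complete = summary[is_complete].index
filtered = result[result['cluster'].isin(complete)]
print(filtered)
```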

Constrain the clustering algorithm #

In principle, we could modify the code for sklearn.cluster.MeanShift to add the constraints we need. This is likely to be difficult though.

Postprocess the result #

Alternatively, we could do the clustering as normal, but add some additional steps at the end to deal with abnormal clusters when they come up.
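
As one illustration of what such a step might look like, the sketch below handles a single kind of abnormal cluster: a person stepping more than once within it. The rule (keep the step closest to the cluster's median time, unassign the rest) and the function name `drop_extra_steps` are hypothetical, not something we have applied to the real data. It also assumes the data frame has a unique index.

```python
import numpy as np
import pandas as pd

def drop_extra_steps(result: pd.DataFrame) -> pd.DataFrame:
    '''Hypothetical post-processing rule (not from the original analysis):
    if a person has several steps in one cluster, keep only the step closest
    to the cluster's median time and set the cluster of the others to NaN.
    Assumes `result` has a unique index.'''
    out = result.copy()
    out['cluster'] = out['cluster'].astype(float)
    # Distance of each step from its cluster's median time
    median = out.groupby('cluster')['time'].transform('median')
    out['_dist'] = (out['time'] - median).abs()
    # Within each (cluster, subject) pair, find the row with the smallest distance
    keep = (out.dropna(subset=['cluster'])
               .groupby(['cluster', 'subject'])['_dist']
               .idxmin())
    # Every other clustered row is an extra step: unassign it
    extra = out['cluster'].notna() & ~out.index.isin(keep)
    out.loc[extra, 'cluster'] = np.nan
    return out.drop(columns='_dist')

# Toy example: S1 steps twice within cluster 2
toy = pd.DataFrame({
    'subject': ['RA', 'S1', 'S1', 'S2'],
    'time':    [1.50, 1.45, 1.58, 1.52],
    'cluster': [2, 2, 2, 2],
})
cleaned = drop_extra_steps(toy)
print(cleaned)
```

A rule like this does nothing for clusters that should be merged or split, so it would only cover part of the problem.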

Manual fix #

Finally, we could just manually adjust the cluster result to deal with these very occasional edge cases. This is probably the most time-effective way to actually fix the problem. Fixing the aberrant clusters might be quite time-consuming in a spreadsheet program, but could be made much easier by using matplotlib’s advanced features to make the raster plots above interactive, providing a GUI for modifying the cluster labels. This is quite similar to some of the functionality in the excellent mne package for M/EEG analyses.

Loop It #

Finally, the code below runs the clustering algorithm for every group, produces plots, and merges and saves the output.

%mkdir -p figures
%mkdir -p figures/clustering
# Duplicates of the functions defined above
def select_wide_data(all_data: pd.DataFrame, 
                     group: int, block: int) -> pd.DataFrame:
    '''Get wide data for this group, in this block'''
    mask = (all_data['group']==group) & (all_data['block']==block)
    df = all_data[mask] # Long format
    dfx = df.pivot_table(index='subject', columns='step', values='time') # Wide format
    return dfx

def raster_plot(cluster_matrix: np.ndarray, dfx: pd.DataFrame, text=True):
    '''Produce a raster plot, showing the time of each participants' steps
    and the cluster they've been assigned to.
    
    
    '''
    fig, axes = plt.subplots(4, 1, figsize=(18, 12))
    for i, start in enumerate([0, 12, 24, 36]):
        raster_plot_row(cluster_matrix, dfx, start=start, stop=start+12, ax=axes[i], text=text)
    plt.tight_layout()
    plt.ylabel('Subject')
    plt.xlabel('Time (s)')

def raster_plot_row(cluster_matrix, dfx, start=0, stop=90, text=True, ax=None):
    '''Plot a single row of our raster plot, in the time window provided'''
    if ax is None:
        plt.figure(figsize=(18, 3))
        plt.ylabel('Subject')
        plt.xlabel('Time (s)')
    else:
        plt.sca(ax)
    pal = sns.palettes.color_palette(n_colors=7)
    from itertools import cycle
    pal = cycle(pal)
    X = dfx.values
    subjects = dfx.index.values
    subject_ix = np.arange(len(subjects))
    steps = np.unique(cluster_matrix)
    steps = steps[~np.isnan(steps)]
    for step in steps:
        color = next(pal)
        for sx, subject in enumerate(subjects):
            cm = cluster_matrix[sx]
            t = X[sx, cm==step] # subject cluster times
            t = t[(t >= start) & (t <= stop)]
            n = len(t)
            if n > 0:
                if text:
                    for v in t:
                        plt.text(v, sx, int(step), color=color)
                else:
                    plt.plot(t, np.repeat(sx, n), 'o', color=color)
    plt.ylim(-1, len(subjects)+1)
    plt.yticks(subject_ix+.5, subjects)
    plt.xlim(start, stop)

def fit_meanshift_clusters(dfx: pd.DataFrame, bandwidth=.25) -> np.ndarray:
    '''Cluster the step times using the MeanShift algorithm.
    Returns a participant × step cluster matrix'''
    # Step times as column vector, excluding NaNs
    X = dfx.values
    step_times = X[~np.isnan(X)].flatten().reshape(-1, 1)
    # Fit model, using the bandwidth passed in by the caller
    model = skcluster.MeanShift(bandwidth=bandwidth).fit(step_times)
    # Cluster labels are not in order
    # To reorder them, we generate a dict that maps
    # the original labels to new, ordered ones.
    labels = model.predict(model.cluster_centers_)
    ordered_labels = model.predict(np.sort(model.cluster_centers_, 0))
    label_dict = dict(zip(ordered_labels, labels))
    label_dict[np.nan] = np.nan
    # Generate a matrix of ordered cluster labels from the clustering model
    cluster_matrix = np.zeros_like(X)
    n_subjects, n_times = X.shape
    for i in range(n_subjects):
        for j in range(n_times):
            x = X[i, j]
            if not np.isnan(x):
                label = model.predict(x.reshape(-1, 1))[0]
                cluster_matrix[i, j] = label_dict[label]
            else:
                cluster_matrix[i, j] = np.nan
    return cluster_matrix
def do_block(group, block, plot=True):
    '''Put the whole pipeline together'''
    print('Group %i - Block %i' % (group, block))
    dfx = select_wide_data(all_data, group, block)

    # Initial order
    n_subjects, n_times = dfx.shape
    steps = np.arange(n_times)
    subjects = np.arange(n_subjects)
    cluster_ids = dfx.columns.values
    initial_cluster_matrix = np.repeat(cluster_ids, n_subjects).reshape((n_times, n_subjects)).T
    # Plot and save
    if plot:
        raster_plot(initial_cluster_matrix, dfx)
        plt.suptitle('Group %i - Block %i - Initial Labels' % (group, block))
        plt.savefig('figures/clustering/g%i_b%i_initial.png' % (group, block))
        plt.close()

    # Mean Shift Clustering
    cluster_matrix = fit_meanshift_clusters(dfx)
    if plot:
        raster_plot(cluster_matrix, dfx)
        plt.suptitle('Group %i - Block %i - MeanShift Clustered Labels' % (group, block))
        plt.savefig('figures/clustering/g%i_b%i_meanshift.png' % (group, block))
        plt.close()

    ## Export
    step_df = dfx.reset_index().melt(id_vars='subject', value_name='time')
    cluster_dfx = pd.DataFrame(cluster_matrix, index=dfx.index, columns=dfx.columns)
    cluster_df = (cluster_dfx.reset_index()
                  .melt(id_vars='subject')
                  .rename({'value':'cluster'}, axis=1)
                  .sort_values(['step', 'subject']))
    result = pd.merge(step_df, cluster_df, 
                      on=['subject', 'step'], how='left')
    result['group'] = group
    result['block'] = block
    return result
results = []
for (group, block), _ in all_data.groupby(['group', 'block']):
    result = do_block(group, block)
    results.append(result)
Group 1 - Block 1
Group 1 - Block 2
Group 2 - Block 1
...[and so on]...
final_results = pd.concat(results)
final_results.head()

  subject  step   time  cluster  group  block
0      RA     1  1.534      2.0      1      1
1      S1     1  2.229      3.0      1      1
2      S2     1  0.751      0.0      1      1
3      S3     1  1.120      1.0      1      1
4      S4     1  1.293      1.0      1      1

final_results.to_csv('data/clustered.csv', index=False)