- Eoin Travers | AI & Data Science/
Group coordination, automatic clustering, and scikit-learn
In a pretty cool experiment run at Tate Modern last year, we put visitors to the gallery into groups of 7, blindfolded them, and asked them to march in time at a given tempo. We’re using this data to explore how people coordinate with each other in groups.
One problem, however, is that people don’t always keep perfect time.
Often, they would skip a few beats, step more than once in a single beat, or start late.
The challenge was to find a way to automatically cluster the data,
so we can say this footstep and this footstep belong to this beat,
while this other footstep belongs to a different beat.
In this post, I outline how I achieved this using Python’s scikit-learn.
The Data #
First, let’s take a look at the raw data.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from sklearn import cluster as skcluster
# Plot settings
sns.set_style('whitegrid')
mpl.rcParams['figure.figsize'] = (8, 6)
mpl.rcParams['font.size'] = 18
all_data = pd.read_csv('data/raw.csv')
all_data.head()
 | group | block | camera | subject | step | time | isi |
---|---|---|---|---|---|---|---|
0 | 1 | 1 | A | RA | 1 | 1.534 | NaN |
1 | 1 | 1 | A | RA | 2 | 2.037 | 0.503 |
2 | 1 | 1 | A | RA | 3 | 2.519 | 0.482 |
3 | 1 | 1 | A | RA | 4 | 3.012 | 0.493 |
4 | 1 | 1 | A | RA | 5 | 3.467 | 0.455 |
print('We have %i rows, and %i columns.\n' % all_data.shape)
print('Unique values (of index columns):')
for col in ['group', 'block', 'camera', 'subject']:
    print('- "%s"\n %s' % (col, str(all_data[col].unique())))
We have 30805 rows, and 7 columns.
Unique values (of index columns):
- "group"
[ 1 2 3 4 5 6 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
26 27 7]
- "block"
[1 2]
- "camera"
['A' 'B']
- "subject"
['RA' 'S2' 'S4' 'S5' 'S6' 'S1' 'S3']
We have data from 27 groups, who each completed two blocks.
Each group consisted of 7 people (6 subjects and a research assistant).
For each step by each person, we have coded the time since the start of the recording (in seconds),
and the step number (step 1, 2, 3, etc., for that person, in that recording).
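The `isi` column is not described here, but it appears to be the inter-step interval: the gap between consecutive steps by the same person in the same recording (for example, 2.037 − 1.534 = 0.503 for the RA). A minimal sketch of how it could be recomputed under that assumption, using a small toy data frame in the same long format rather than the real file:

```python
import pandas as pd

# Toy long-format data (a few step times for two people in group 1, block 1)
toy = pd.DataFrame({
    'group':   [1, 1, 1, 1, 1, 1],
    'block':   [1, 1, 1, 1, 1, 1],
    'subject': ['RA', 'RA', 'RA', 'S1', 'S1', 'S1'],
    'step':    [1, 2, 3, 1, 2, 3],
    'time':    [1.534, 2.037, 2.519, 2.229, 2.666, 3.087],
})

# Assumed definition: isi = time since the same person's previous step.
# The first step in each recording has no previous step, so its isi is NaN.
toy = toy.sort_values(['group', 'block', 'subject', 'step'])
toy['isi'] = toy.groupby(['group', 'block', 'subject'])['time'].diff()
print(toy)
```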
Checking Alignment #
Next, let’s extract data from group 1, block 1, and represent it using a wide (participant × step) data frame of step times.
def select_wide_data(all_data: pd.DataFrame,
                     group: int, block: int) -> pd.DataFrame:
    '''Get wide data for this group, in this block'''
    mask = (all_data['group']==group) & (all_data['block']==block)
    df = all_data[mask] # Long format
    dfx = df.pivot_table(index='subject', columns='step', values='time') # Wide format
    return dfx
group = 1
block = 1
dfx = select_wide_data(all_data, group, block)
dfx
step | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | ... | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
subject | |||||||||||||||||||||
RA | 1.534 | 2.037 | 2.519 | 3.012 | 3.467 | 3.899 | 4.382 | 4.890 | 5.354 | 5.823 | ... | 40.102 | 40.607 | 41.088 | 41.577 | 42.061 | 42.549 | 43.004 | NaN | NaN | NaN |
S1 | 2.229 | 2.666 | 3.087 | 3.578 | 4.045 | 4.558 | 5.007 | 5.439 | 5.950 | 6.433 | ... | 40.269 | 40.706 | 41.247 | 41.721 | 42.231 | 42.744 | 43.226 | NaN | NaN | NaN |
S2 | 0.751 | 1.158 | 1.619 | 2.046 | 2.509 | 2.973 | 3.418 | 3.938 | 4.437 | 4.879 | ... | 38.713 | 39.197 | 39.660 | 40.118 | 40.619 | 41.097 | 41.610 | 42.095 | 42.583 | 43.082 |
S3 | 1.120 | 1.683 | 2.173 | 2.598 | 3.027 | 3.610 | 4.028 | 4.523 | 4.969 | 5.503 | ... | 39.273 | 39.767 | 40.236 | 40.688 | 41.198 | 41.657 | 42.219 | 42.671 | 43.253 | NaN |
S4 | 1.293 | 2.166 | 2.586 | 2.976 | 3.500 | 3.906 | 4.431 | 4.926 | 5.843 | 6.355 | ... | 40.093 | 40.613 | 41.062 | 41.524 | 42.120 | 42.565 | 43.062 | NaN | NaN | NaN |
S5 | 1.027 | 1.536 | 2.007 | 2.508 | 2.933 | 3.370 | 3.929 | 4.366 | 4.855 | 5.362 | ... | 39.227 | 39.675 | 40.214 | 40.692 | 41.181 | 41.684 | 42.168 | 42.645 | 43.127 | NaN |
S6 | 0.582 | 1.131 | 1.666 | 2.096 | 2.534 | 2.952 | 3.480 | 3.954 | 4.432 | 4.910 | ... | 38.807 | 39.230 | 39.716 | 40.130 | 40.696 | 41.137 | 41.641 | 42.103 | 42.655 | NaN |
7 rows × 89 columns
Which steps belong to which beat clusters? We’ll start by assuming that everyone was perfectly coordinated, so that step (n) for every participant belongs to cluster (n). This is encoded in the participant × step cluster matrix, below.
n_subjects, n_times = dfx.shape
steps = np.arange(n_times)
subjects = np.arange(n_subjects)
cluster_ids = dfx.columns.values
initial_cluster_matrix = np.repeat(cluster_ids, n_subjects).reshape((n_times, n_subjects)).T
initial_cluster_matrix
array([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89],
[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32,
33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,
49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64,
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80,
81, 82, 83, 84, 85, 86, 87, 88, 89]])
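As an aside, the repeat/reshape/transpose combination above just stacks the same row of step numbers once per participant. A sketch of an equivalent construction with `np.tile`, shown on small stand-in values (`cluster_ids` and `n_subjects` here are toy values, not the real data):

```python
import numpy as np

cluster_ids = np.arange(1, 6)   # stand-in for dfx.columns.values
n_subjects = 3                  # 7 in the real data
n_times = len(cluster_ids)

# Construction used in the post
original = np.repeat(cluster_ids, n_subjects).reshape((n_times, n_subjects)).T

# Equivalent: repeat the row of step numbers n_subjects times
tiled = np.tile(cluster_ids, (n_subjects, 1))

print(np.array_equal(original, tiled))
```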
To see how sensible this clustering is, let’s produce a raster plot of all the data from this block. This shows the time of every step for every participant, along with the cluster that step is assigned to (colour-coded, and with numeric labels).
def raster_plot(cluster_matrix: np.ndarray, dfx: pd.DataFrame, text=True):
    '''Produce a raster plot, showing the time of each participants' steps
    and the cluster they've been assigned to.
    '''
    fig, axes = plt.subplots(4, 1, figsize=(18, 12))
    for i, start in enumerate([0, 12, 24, 36]):
        raster_plot_row(cluster_matrix, dfx, start=start, stop=start+12, ax=axes[i], text=text)
    plt.tight_layout()
    plt.ylabel('Subject')
    plt.xlabel('Time (s)')
def raster_plot_row(cluster_matrix, dfx, start=0, stop=90, text=True, ax=None):
    '''Plot a single row of our raster plot, in the time window provided'''
    if ax is None:
        plt.figure(figsize=(18, 3))
        plt.ylabel('Subject')
        plt.xlabel('Time (s)')
    else:
        plt.sca(ax)
    pal = sns.palettes.color_palette(n_colors=7)
    from itertools import cycle
    pal = cycle(pal)
    X = dfx.values
    subjects = dfx.index.values
    subject_ix = np.arange(len(subjects))
    steps = np.unique(cluster_matrix)
    steps = steps[~np.isnan(steps)]
    for step in steps:
        color = next(pal)
        for sx, subject in enumerate(subjects):
            cm = cluster_matrix[sx]
            t = X[sx, cm==step] # subject cluster times
            t = t[(t >= start) & (t <= stop)]
            n = len(t)
            if n > 0:
                if text:
                    for v in t:
                        plt.text(v, sx, int(step), color=color)
                else:
                    plt.plot(t, np.repeat(sx, n), 'o', color=color)
    plt.ylim(-1, len(subjects)+1)
    plt.yticks(subject_ix+.5, subjects)
    plt.xlim(start, stop)
raster_plot(initial_cluster_matrix, dfx)
Clearly, the clusters are wrong. If we look at 6 seconds in, for instance, we see that all 7 people step at around the same time, but this is step 10 for the RA, step 9 for subject 1, step 12 for subject 2, and so on.
Automatic Clustering #
To address this, I use an automatic clustering algorithm from scikit-learn
to group the steps into temporal clusters.
There are a number of clustering algorithms available,
making different assumptions about the nature of the clusters you’re looking for,
and with different parameters to be tuned.
Some algorithms, such as k-means and Gaussian mixture models, require that you specify how many clusters you’re looking for. This is problematic here, since different groups end up with different numbers of beat clusters, and it’s not always clear even to us how many clusters there should be. For example, how many beats occur in the first two seconds of the plot above? It is possible to use model selection techniques to find the optimal number of clusters for some algorithms, such as Gaussian mixture modelling, but this is time-consuming and complicated. K-means also assumes that all clusters have the same variance, which doesn’t work here since cluster variance goes up whenever people go out of time.
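To give a sense of what that model selection would involve, here is a rough sketch that fits Gaussian mixtures with different numbers of components and picks the one with the lowest BIC. It runs on synthetic step times, not the Tate data, and the candidate range, the jitter, and the choice of BIC as the criterion are all assumptions made for illustration.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

# Synthetic data: 4 beats, 7 people stepping near each beat with small jitter
rng = np.random.default_rng(0)
beat_times = np.array([0.5, 1.0, 1.5, 2.0])
step_times = np.repeat(beat_times, 7) + rng.normal(0, 0.03, size=28)
step_times = step_times.reshape(-1, 1)

# Fit one model per candidate number of clusters and record its BIC
candidate_k = list(range(1, 9))
bics = []
for k in candidate_k:
    gmm = GaussianMixture(n_components=k, random_state=0).fit(step_times)
    bics.append(gmm.bic(step_times))

# Lower BIC is better
best_k = candidate_k[int(np.argmin(bics))]
print(best_k)
```

Even in this toy case a separate model has to be fitted for every candidate, and the candidate range would need to cover around 90 beats per recording in the real data.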
Other algorithms automatically figure out how many clusters to create,
based on some other criterion.
I found good results using the Mean Shift algorithm,
which requires that you set a bandwidth parameter
dictating how broad individual clusters should be.
Results were best with bandwidth=0.25,
which seems consistent with the fact that participants
were supposed to be stepping at 120 BPM, or once every 0.5 seconds:
a window of ±0.25 seconds around a cluster centre spans one beat period.
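Before applying this to the real data, here is a small sketch of how the bandwidth affects the result, using synthetic step times for 7 people over 4 beats at 120 BPM. The jitter size and the wider comparison bandwidth of 0.6 are arbitrary choices for illustration.

```python
import numpy as np
from sklearn.cluster import MeanShift

rng = np.random.default_rng(0)
beat_times = np.array([0.5, 1.0, 1.5, 2.0])        # one beat every 0.5 s
jitter = rng.uniform(-0.05, 0.05, size=(7, 4))     # 7 people, slightly out of time
step_times = (beat_times + jitter).reshape(-1, 1)  # column vector, as sklearn expects

# Bandwidth of half a beat period: one cluster per beat
narrow = MeanShift(bandwidth=0.25).fit(step_times)
# Bandwidth wider than a beat period: neighbouring beats get merged
wide = MeanShift(bandwidth=0.6).fit(step_times)

n_narrow = len(narrow.cluster_centers_)
n_wide = len(wide.cluster_centers_)
print(n_narrow, n_wide)
```

Note that the number of clusters is never specified; it falls out of the bandwidth.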
# Step times as column vector, excluding NaNs
X = dfx.values
step_times = X[~np.isnan(X)].flatten().reshape(-1, 1)
# Fit model
model = skcluster.MeanShift(bandwidth=.25).fit(step_times)
# Cluster labels are not in order
# To reorder them, we generate a dict that maps
# the original labels to new, ordered ones.
labels = model.predict(model.cluster_centers_)
ordered_labels = model.predict(np.sort(model.cluster_centers_, 0))
label_dict = dict(zip(ordered_labels, labels))
label_dict[np.nan] = np.nan
print(label_dict)
{88: 0, 86: 1, 87: 2, 82: 3, 81: 4, 80: 5, 79: 6, 78: 7, 77: 8, 76: 9, 85: 10, 75: 11, 74: 12, 73: 13, 72: 14, 71: 15, 70: 16, 69: 17, 68: 18, 67: 19, 66: 20, 65: 21, 64: 22, 63: 23, 62: 24, 61: 25, 60: 26, 59: 27, 58: 28, 57: 29, 56: 30, 55: 31, 54: 32, 53: 33, 52: 34, 51: 35, 50: 36, 49: 37, 48: 38, 47: 39, 46: 40, 45: 41, 44: 42, 43: 43, 42: 44, 41: 45, 40: 46, 39: 47, 38: 48, 37: 49, 36: 50, 35: 51, 34: 52, 33: 53, 32: 54, 31: 55, 30: 56, 29: 57, 28: 58, 27: 59, 26: 60, 25: 61, 24: 62, 84: 63, 23: 64, 22: 65, 21: 66, 20: 67, 19: 68, 18: 69, 17: 70, 16: 71, 15: 72, 14: 73, 13: 74, 12: 75, 11: 76, 10: 77, 9: 78, 8: 79, 7: 80, 6: 81, 5: 82, 4: 83, 3: 84, 2: 85, 1: 86, 0: 87, 83: 88, nan: nan}
# Generate a matrix of ordered cluster labels from the clustering model
cluster_matrix = np.zeros_like(X)
for i in subjects:
    for j in range(n_times):
        x = X[i, j]
        if not np.isnan(x):
            label = model.predict(x.reshape(-1, 1))[0]
            cluster_matrix[i, j] = label_dict[label]
        else:
            cluster_matrix[i, j] = np.nan
print(cluster_matrix)
[[ 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19.
20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37.
38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55.
56. 57. 58. 59. 60. 61. 62. 64. 65. 66. 67. 68. 69. 70. 71. 72. 73. 74.
75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88. nan nan nan]
[ 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20.
21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38.
39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56.
57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71. 72. 73. 74.
75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88. nan nan nan]
[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17.
18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35.
36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53.
54. 55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71.
72. 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88.]
[ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18.
19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36.
37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54.
55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71. 72.
73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88. nan]
[ 1. 3. 4. 5. 6. 7. 8. 9. 11. 12. 13. 14. 15. 16. 17. 18. 19. 20.
21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36. 37. 38.
39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54. 55. 56.
57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71. 72. 73. 74.
75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88. nan nan nan]
[ 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17. 18.
19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35. 36.
37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53. 54.
55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71. 72.
73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. 88. nan]
[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15. 16. 17.
18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31. 32. 33. 34. 35.
36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47. 48. 49. 50. 51. 52. 53.
54. 55. 56. 57. 58. 59. 60. 61. 62. 63. 64. 65. 66. 67. 68. 69. 70. 71.
72. 73. 74. 75. 76. 77. 78. 79. 80. 81. 82. 83. 84. 85. 86. 87. nan]]
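The relabelling dict and the double loop above could also be written in vectorised form. The following is a sketch of one way to do it, on a toy participant × step matrix rather than the real data; it relies on `model.labels_` being in the same order as the non-NaN values passed to `fit`.

```python
import numpy as np
from sklearn.cluster import MeanShift

# Toy participant × step matrix of step times, NaN where a step is missing
X = np.array([
    [0.52, 1.01, 1.48, 2.03],
    [0.47, 0.98, 1.53, np.nan],
    [1.02, 1.49, 1.97, np.nan],   # this person started one beat late
])
mask = ~np.isnan(X)
model = MeanShift(bandwidth=0.25).fit(X[mask].reshape(-1, 1))

# rank[k] = position of original label k once cluster centres are sorted by time
order = np.argsort(model.cluster_centers_.ravel())
rank = np.empty_like(order)
rank[order] = np.arange(len(order))

# Fill the matrix in one go; missing steps stay NaN
cluster_matrix = np.full(X.shape, np.nan)
cluster_matrix[mask] = rank[model.labels_]
print(cluster_matrix)
```

In this toy example the third participant's first step lands in cluster 1 rather than cluster 0, which is the kind of realignment we are after.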
# Here's a function that does the same thing
def fit_meanshift_clusters(dfx: pd.DataFrame, bandwidth=.25) -> np.ndarray:
    '''Cluster the step times using the MeanShift algorithm.
    Returns participant × step cluster matrix'''
    # Step times as column vector, excluding NaNs
    X = dfx.values
    step_times = X[~np.isnan(X)].flatten().reshape(-1, 1)
    # Fit model (using the bandwidth argument rather than a hard-coded value)
    model = skcluster.MeanShift(bandwidth=bandwidth).fit(step_times)
    # Cluster labels are not in order
    # To reorder them, we generate a dict that maps
    # the original labels to new, ordered ones.
    labels = model.predict(model.cluster_centers_)
    ordered_labels = model.predict(np.sort(model.cluster_centers_, 0))
    label_dict = dict(zip(ordered_labels, labels))
    label_dict[np.nan] = np.nan
    # Generate a matrix of ordered cluster labels from the clustering model
    cluster_matrix = np.zeros_like(X)
    n_subjects, n_times = X.shape
    for i in range(n_subjects):
        for j in range(n_times):
            x = X[i, j]
            if not np.isnan(x):
                label = model.predict(x.reshape(-1, 1))[0]
                cluster_matrix[i, j] = label_dict[label]
            else:
                cluster_matrix[i, j] = np.nan
    return cluster_matrix
raster_plot(cluster_matrix, dfx)
Better. Steps that occur at around the same time are now part of the same cluster.
Finally, we convert the step times and cluster labels back to long format, and save them to file.
%mkdir -p data/clustered
step_df = dfx.reset_index().melt(id_vars='subject', value_name='time')
cluster_dfx = pd.DataFrame(cluster_matrix, index=dfx.index, columns=dfx.columns)
cluster_df = (cluster_dfx.reset_index()
.melt(id_vars='subject')
.rename({'value':'cluster'}, axis=1)
.sort_values(['step', 'subject']))
result = pd.merge(step_df, cluster_df,
on=['subject', 'step'], how='left')
result.to_csv('data/clustered/group%i_block%i.csv' % (group, block),
index=False)
result.head()
Edge Cases #
Although it works very well for the vast majority of our data (27 groups, two blocks each), this approach isn’t perfect. There are still some cases that aren’t clustered as we would expect. Here’s the worst example.
group, block = 2, 2
dfx = select_wide_data(all_data, group, block)
cluster_matrix = fit_meanshift_clusters(dfx)
raster_plot(cluster_matrix, dfx)
Cluster 25 contains steps that should really be split across two clusters, while clusters 50 and 51 should probably be a single cluster. There are a few other issues, for example around cluster 68. Several of these issues arise because the clustering algorithm isn’t constrained to have only one step per person per cluster. There are a few approaches we could take to deal with these problems.
Exclude them #
For now, we limit our analyses to clusters of 7 steps. This is certainly the most reliable solution, and does not lead to any further issues in our analyses.
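A minimal sketch of such a filter is below, on a toy long-format result with 3 people instead of 7. Besides counting steps, it also checks that each step in a cluster comes from a different person; that extra check is my assumption about what a valid cluster should look like, not something stated above.

```python
import pandas as pd

# Toy result: cluster 0 is complete, cluster 1 is missing a person,
# cluster 2 has 3 steps but two of them are by the same person
toy = pd.DataFrame({
    'subject': ['RA', 'S1', 'S2', 'RA', 'S1', 'RA', 'S1', 'S1'],
    'cluster': [0, 0, 0, 1, 1, 2, 2, 2],
    'time':    [0.50, 0.52, 0.49, 1.01, 0.98, 1.50, 1.45, 1.60],
})
n_people = 3   # would be 7 for the real data

# Number of steps and number of distinct people in each cluster
stats = toy.groupby('cluster')['subject'].agg(n_steps='size', n_unique='nunique')

# Keep only clusters with exactly one step from every person
complete = stats[(stats['n_steps'] == n_people) & (stats['n_unique'] == n_people)].index
clean = toy[toy['cluster'].isin(complete)]
print(clean)
```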
Constrain the clustering algorithm #
In principle, we could modify the code for sklearn.cluster.MeanShift to add the constraints we need.
This is likely to be difficult though.
Postprocess the result #
Alternatively, we could do the clustering as normal, but add some additional steps at the end to deal with abnormal clusters when they come up.
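As a rough illustration of what a first postprocessing pass might look like, the sketch below only flags suspicious clusters and does not try to fix them. The function name `flag_abnormal_clusters` and both rules (a repeated person suggests a split; two consecutive clusters with no person in common suggest a merge) are hypothetical, and are shown on toy data with 3 people.

```python
import pandas as pd

def flag_abnormal_clusters(result: pd.DataFrame, n_people: int = 7):
    '''Return (clusters to consider splitting, cluster pairs to consider merging).'''
    grouped = result.dropna(subset=['cluster']).groupby('cluster')['subject']
    people = {c: set(s) for c, s in grouped}   # who stepped in each cluster
    sizes = {c: len(s) for c, s in grouped}    # how many steps in each cluster
    # Split candidates: someone stepped more than once in the cluster
    to_split = [int(c) for c in sizes if sizes[c] > len(people[c])]
    # Merge candidates: consecutive clusters that share no person
    # and together contain no more steps than there are people
    ids = sorted(sizes)
    to_merge = [(int(a), int(b)) for a, b in zip(ids[:-1], ids[1:])
                if people[a].isdisjoint(people[b])
                and sizes[a] + sizes[b] <= n_people]
    return to_split, to_merge

toy = pd.DataFrame({
    'subject': ['RA', 'S1', 'S2', 'RA', 'S1', 'S1', 'RA', 'S1', 'S2'],
    'cluster': [0, 0, 0, 1, 1, 1, 2, 2, 3],
})
to_split, to_merge = flag_abnormal_clusters(toy, n_people=3)
print(to_split, to_merge)
```

Deciding where exactly to split, or whether a merge is really justified, would still need extra rules (for instance based on the step times) or a human check.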
Manual fix #
Finally, we could just manually adjust the cluster result to deal with these very occasional edge cases.
This is probably the most time-effective way to actually fix the problem.
Fixing the aberrant clusters might be quite time-consuming in a spreadsheet program,
but could be made much easier by using matplotlib’s advanced features
to make the raster plots above interactive,
providing a GUI for modifying the cluster labels.
This is quite similar to some of the functionality
in the excellent mne package for M/EEG analyses.
Loop It #
Finally, the code below runs the clustering algorithm for every group, produces plots, and merges and saves the output.
%mkdir -p figures
%mkdir -p figures/clustering
def do_block(group, block, plot=True):
    '''Put the whole pipeline together'''
    print('Group %i - Block %i' % (group, block))
    dfx = select_wide_data(all_data, group, block)
    # Initial order
    n_subjects, n_times = dfx.shape
    cluster_ids = dfx.columns.values
    initial_cluster_matrix = np.repeat(cluster_ids, n_subjects).reshape((n_times, n_subjects)).T
    # Plot and save
    if plot:
        raster_plot(initial_cluster_matrix, dfx)
        plt.suptitle('Group %i - Block %i - Initial Labels' % (group, block))
        plt.savefig('figures/clustering/g%i_b%i_initial.png' % (group, block))
        plt.close()
    # Mean Shift Clustering
    cluster_matrix = fit_meanshift_clusters(dfx)
    if plot:
        raster_plot(cluster_matrix, dfx)
        plt.suptitle('Group %i - Block %i - MeanShift Clustered Labels' % (group, block))
        plt.savefig('figures/clustering/g%i_b%i_meanshift.png' % (group, block))
        plt.close()
    ## Export
    step_df = dfx.reset_index().melt(id_vars='subject', value_name='time')
    cluster_dfx = pd.DataFrame(cluster_matrix, index=dfx.index, columns=dfx.columns)
    cluster_df = (cluster_dfx.reset_index()
                  .melt(id_vars='subject')
                  .rename({'value':'cluster'}, axis=1)
                  .sort_values(['step', 'subject']))
    result = pd.merge(step_df, cluster_df,
                      on=['subject', 'step'], how='left')
    result['group'] = group
    result['block'] = block
    return result
results = []
for (group, block), _ in all_data.groupby(['group', 'block']):
    result = do_block(group, block)
    results.append(result)
Group 1 - Block 1
Group 1 - Block 2
Group 2 - Block 1
...[and so on]...
final_results = pd.concat(results)
final_results.head()
 | subject | step | time | cluster | group | block |
---|---|---|---|---|---|---|
0 | RA | 1 | 1.534 | 2.0 | 1 | 1 |
1 | S1 | 1 | 2.229 | 3.0 | 1 | 1 |
2 | S2 | 1 | 0.751 | 0.0 | 1 | 1 |
3 | S3 | 1 | 1.120 | 1.0 | 1 | 1 |
4 | S4 | 1 | 1.293 | 1.0 | 1 | 1 |
final_results.to_csv('data/clustered.csv', index=False)