You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

cluster_frames.py 3.1KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from sklearn.cluster import SpectralClustering, AgglomerativeClustering, KMeans
  2. import numpy as np
  3. from sklearn.metrics.pairwise import cosine_similarity
  4. class VideoClusterer:
  5. def __init__(self, clustering_method='uniform', n_clusters=2, similarity_threshold=0.8):
  6. self.n_clusters = n_clusters
  7. self.similarity_threshold = similarity_threshold
  8. self.clustering_method = clustering_method
  9. # Decide on the clustering method to use
  10. if clustering_method == 'uniform':
  11. self.clusterer = self.uniform_clustering
  12. elif clustering_method == 'spectral':
  13. self.clusterer = SpectralClustering(n_clusters=n_clusters, affinity='precomputed')
  14. elif clustering_method == 'agglomerative':
  15. self.clusterer = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
  16. elif clustering_method == 'kmeans':
  17. self.clusterer = KMeans(n_clusters=n_clusters, n_init=1)
  18. else:
  19. raise ValueError(f"Invalid clustering method: {clustering_method}")
  20. def uniform_clustering(self, features):
  21. n = len(features)
  22. clusters = []
  23. cluster_size = n // self.n_clusters
  24. remainder = n % self.n_clusters
  25. start = 0
  26. for i in range(self.n_clusters):
  27. if i < remainder:
  28. end = start + cluster_size + 1
  29. else:
  30. end = start + cluster_size
  31. clusters.append(list(range(start, end)))
  32. start = end
  33. return clusters
  34. def detect_outliers(self, features):
  35. dot_product_matrix = features.dot(features.T)
  36. average_similarities = np.mean(dot_product_matrix, axis=0)
  37. # Adding a small constant epsilon to the standard deviation to prevent division by zero
  38. epsilon = 1e-8
  39. normal = (average_similarities - np.mean(average_similarities)) / (np.std(average_similarities) + epsilon)
  40. outlier_mask = np.logical_or(normal > 1.5, normal < -1.5)
  41. return outlier_mask
  42. def get_clusters(self, features):
  43. features = features.cpu().numpy()
  44. if self.clustering_method == 'uniform':
  45. return self.uniform_clustering(features)
  46. else:
  47. # For non-uniform methods, follow the original procedure
  48. outlier_mask = self.detect_outliers(features)
  49. if np.sum(~outlier_mask) > self.n_clusters:
  50. features = features[~outlier_mask]
  51. # Compute cosine similarity matrix for spectral clustering
  52. if self.clustering_method == 'spectral':
  53. similarity_matrix = cosine_similarity(features)
  54. labels = self.clusterer.fit_predict(similarity_matrix)
  55. else:
  56. # For agglomerative, k-means, and other clustering methods that don't require a precomputed matrix
  57. labels = self.clusterer.fit_predict(features)
  58. # Organize frames into clusters based on labels
  59. clusters = [[] for _ in range(self.n_clusters)]
  60. for idx, label in enumerate(labels):
  61. clusters[label].append(idx)
  62. return clusters