You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

mmd.py 4.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import concurrent.futures
  2. from functools import partial
  3. import networkx as nx
  4. import numpy as np
  5. from scipy.linalg import toeplitz
  6. import pyemd
  7. def emd(x, y, distance_scaling=1.0):
  8. support_size = max(len(x), len(y))
  9. d_mat = toeplitz(range(support_size)).astype(np.float)
  10. distance_mat = d_mat / distance_scaling
  11. # convert histogram values x and y to float, and make them equal len
  12. x = x.astype(np.float)
  13. y = y.astype(np.float)
  14. if len(x) < len(y):
  15. x = np.hstack((x, [0.0] * (support_size - len(x))))
  16. elif len(y) < len(x):
  17. y = np.hstack((y, [0.0] * (support_size - len(y))))
  18. emd = pyemd.emd(x, y, distance_mat)
  19. return emd
  20. def l2(x, y):
  21. dist = np.linalg.norm(x - y, 2)
  22. return dist
  23. def gaussian_emd(x, y, sigma=1.0, distance_scaling=1.0):
  24. ''' Gaussian kernel with squared distance in exponential term replaced by EMD
  25. Args:
  26. x, y: 1D pmf of two distributions with the same support
  27. sigma: standard deviation
  28. '''
  29. support_size = max(len(x), len(y))
  30. d_mat = toeplitz(range(support_size)).astype(np.float)
  31. distance_mat = d_mat / distance_scaling
  32. # convert histogram values x and y to float, and make them equal len
  33. x = x.astype(np.float)
  34. y = y.astype(np.float)
  35. if len(x) < len(y):
  36. x = np.hstack((x, [0.0] * (support_size - len(x))))
  37. elif len(y) < len(x):
  38. y = np.hstack((y, [0.0] * (support_size - len(y))))
  39. emd = pyemd.emd(x, y, distance_mat)
  40. return np.exp(-emd * emd / (2 * sigma * sigma))
  41. def gaussian(x, y, sigma=1.0):
  42. dist = np.linalg.norm(x - y, 2)
  43. return np.exp(-dist * dist / (2 * sigma * sigma))
  44. def kernel_parallel_unpacked(x, samples2, kernel):
  45. d = 0
  46. for s2 in samples2:
  47. d += kernel(x, s2)
  48. return d
  49. def kernel_parallel_worker(t):
  50. return kernel_parallel_unpacked(*t)
  51. def disc(samples1, samples2, kernel, is_parallel=True, *args, **kwargs):
  52. ''' Discrepancy between 2 samples
  53. '''
  54. d = 0
  55. if not is_parallel:
  56. for s1 in samples1:
  57. for s2 in samples2:
  58. d += kernel(s1, s2, *args, **kwargs)
  59. else:
  60. with concurrent.futures.ProcessPoolExecutor() as executor:
  61. for dist in executor.map(kernel_parallel_worker,
  62. [(s1, samples2, partial(kernel, *args, **kwargs)) for s1 in samples1]):
  63. d += dist
  64. d /= len(samples1) * len(samples2)
  65. return d
  66. def compute_mmd(samples1, samples2, kernel, is_hist=True, *args, **kwargs):
  67. ''' MMD between two samples
  68. '''
  69. # normalize histograms into pmf
  70. if is_hist:
  71. samples1 = [s1 / np.sum(s1) for s1 in samples1]
  72. samples2 = [s2 / np.sum(s2) for s2 in samples2]
  73. # print('===============================')
  74. # print('s1: ', disc(samples1, samples1, kernel, *args, **kwargs))
  75. # print('--------------------------')
  76. # print('s2: ', disc(samples2, samples2, kernel, *args, **kwargs))
  77. # print('--------------------------')
  78. # print('cross: ', disc(samples1, samples2, kernel, *args, **kwargs))
  79. # print('===============================')
  80. return disc(samples1, samples1, kernel, *args, **kwargs) + \
  81. disc(samples2, samples2, kernel, *args, **kwargs) - \
  82. 2 * disc(samples1, samples2, kernel, *args, **kwargs)
  83. def compute_emd(samples1, samples2, kernel, is_hist=True, *args, **kwargs):
  84. ''' EMD between average of two samples
  85. '''
  86. # normalize histograms into pmf
  87. if is_hist:
  88. samples1 = [np.mean(samples1)]
  89. samples2 = [np.mean(samples2)]
  90. # print('===============================')
  91. # print('s1: ', disc(samples1, samples1, kernel, *args, **kwargs))
  92. # print('--------------------------')
  93. # print('s2: ', disc(samples2, samples2, kernel, *args, **kwargs))
  94. # print('--------------------------')
  95. # print('cross: ', disc(samples1, samples2, kernel, *args, **kwargs))
  96. # print('===============================')
  97. return disc(samples1, samples2, kernel, *args, **kwargs),[samples1[0],samples2[0]]
  98. def test():
  99. s1 = np.array([0.2, 0.8])
  100. s2 = np.array([0.3, 0.7])
  101. samples1 = [s1, s2]
  102. s3 = np.array([0.25, 0.75])
  103. s4 = np.array([0.35, 0.65])
  104. samples2 = [s3, s4]
  105. s5 = np.array([0.8, 0.2])
  106. s6 = np.array([0.7, 0.3])
  107. samples3 = [s5, s6]
  108. print('between samples1 and samples2: ', compute_mmd(samples1, samples2, kernel=gaussian_emd,
  109. is_parallel=False, sigma=1.0))
  110. print('between samples1 and samples3: ', compute_mmd(samples1, samples3, kernel=gaussian_emd,
  111. is_parallel=False, sigma=1.0))
  112. if __name__ == '__main__':
  113. test()