You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_MMD.py 1.2KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import torch
  2. import numpy as np
  3. import time
  4. def compute_kernel(x,y):
  5. x_size = x.size(0)
  6. y_size = y.size(0)
  7. dim = x.size(1)
  8. x_tile = x.view(x_size,1,dim)
  9. x_tile = x_tile.repeat(1,y_size,1)
  10. y_tile = y.view(1,y_size,dim)
  11. y_tile = y_tile.repeat(x_size,1,1)
  12. return torch.exp(-torch.mean((x_tile-y_tile)**2,dim = 2)/float(dim))
  13. def compute_mmd(x,y):
  14. x_kernel = compute_kernel(x,x)
  15. # print(x_kernel)
  16. y_kernel = compute_kernel(y,y)
  17. # print(y_kernel)
  18. xy_kernel = compute_kernel(x,y)
  19. # print(xy_kernel)
  20. return torch.mean(x_kernel)+torch.mean(y_kernel)-2*torch.mean(xy_kernel)
  21. # start = time.time()
  22. # x = torch.randn(4000,1).cuda()
  23. # y = torch.randn(4000,1).cuda()
  24. # print(compute_mmd(x,y))
  25. # end = time.time()
  26. # print('GPU time:', end-start)
  27. start = time.time()
  28. torch.manual_seed(123)
  29. batch = 1000
  30. x = torch.randn(batch,1)
  31. y_baseline = torch.randn(batch,1)
  32. y_pred = torch.zeros(batch,1)
  33. print('MMD baseline', compute_mmd(x,y_baseline))
  34. print('MMD prediction', compute_mmd(x,y_pred))
  35. #
  36. # print('before',x)
  37. # print('MMD', compute_mmd(x,y))
  38. # x_idx = np.random.permutation(x.size(0))
  39. # x = x[x_idx,:]
  40. # print('after permutation',x)
  41. # print('MMD', compute_mmd(x,y))
  42. #
  43. #
  44. # end = time.time()
  45. # print('CPU time:', end-start)