You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

test.py 870B

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import numpy as np
  2. import torch
  3. def tril_indices(rows, cols, offset=0):
  4. return torch.ones(rows, cols, dtype=torch.uint8).tril(offset).nonzero()
  5. # x = torch.tensor([1., 2., 3., 4., 5., 6.])
  6. # m = torch.zeros((3, 3))
  7. # rows=3
  8. # cols=3
  9. # offset=0
  10. # tril_indices = torch.ones(rows, cols, dtype=torch.uint8).tril(offset).nonzero()
  11. # m[tril_indices[0], tril_indices[1]] = x
  12. # print(m)
  13. def sym(A):
  14. for i in range(A.shape[0]):
  15. for j in range(A.shape[1]):
  16. A[j, i] = A[i, j]
  17. return A
  18. # dm = np.random.rand(6)
  19. # tri = np.zeros((3, 3))
  20. # print(tri)
  21. # print(np.triu_indices(3))
  22. # print(dm)
  23. # tri[np.triu_indices(3)] = dm
  24. # print(tri)
  25. # A = sym(tri)
  26. # print(A)
  27. a = np.zeros((3,3))
  28. a[0,1] = 5
  29. a[2,2] = 6
  30. print(a)
  31. print(np.where(~a.any(axis=1))[0])
  32. missing_node_index = np.where(~a.any(axis=1))[0][0]
  33. print(a.shape)
  34. print(a[missing_node_index, :])