You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

get_time.py 4.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. import itertools
  2. import numpy as np
  3. from pprint import pprint
  4. domains = ["w", "d", "a"]
  5. pairs = [[f"{d1}2{d2}" for d1 in domains] for d2 in domains]
  6. pairs = list(itertools.chain.from_iterable(pairs))
  7. pairs.sort()
  8. # print(pairs)
  9. ROUND_FACTOR = 2
  10. all_accs_list = {}
  11. files = [
  12. # "all.txt",
  13. # "all2.txt",
  14. # "all3.txt",
  15. # "all_step_scheduler.txt",
  16. # "all_source_trained_1.txt",
  17. "all_source_trained_2_specific_hp.txt",
  18. ]
  19. for file in files:
  20. with open(file) as f:
  21. name = None
  22. for line in f:
  23. splitted = line.split(" ")
  24. if splitted[0] == "##": # e.g. ## DANN
  25. name = splitted[1].strip()
  26. splitted = line.split(",") # a2w, acc1, acc2,
  27. if splitted[0] in pairs:
  28. pair = splitted[0]
  29. name = name.lower()
  30. acc = float(splitted[5].strip()) #/ 60 / 60
  31. if name not in all_accs_list:
  32. all_accs_list[name] = {p:[] for p in pairs}
  33. all_accs_list[name][pair].append(acc)
  34. # all_accs_list format: {'dann': {'a2w': [acc1, acc2, acc3]}}
  35. acc_means_by_model_name = {}
  36. vars_by_model_name = {}
  37. for name, pair_list in all_accs_list.items():
  38. accs = {p:[] for p in pairs}
  39. vars = {p:[] for p in pairs}
  40. for pair, acc_list in pair_list.items():
  41. if len(acc_list) > 0:
  42. ## Calculate average and round
  43. accs[pair] = round(sum(acc_list) / len(acc_list), ROUND_FACTOR)
  44. vars[pair] = round(np.var(acc_list) * 100, ROUND_FACTOR)
  45. print(vars[pair], "|||", acc_list)
  46. acc_means_by_model_name[name] = accs
  47. vars_by_model_name[name] = vars
  48. # for name, acc_list in acc_means_by_model_name.items():
  49. # pprint(name)
  50. # pprint(all_accs_list)
  51. # pprint(acc_means_by_model_name)
  52. # pprint(vars_by_model_name)
  53. print()
  54. latex_table = ""
  55. header = [pair for pair in pairs if pair[0] != pair[-1]]
  56. table = []
  57. var_table = []
  58. for name, acc_list in acc_means_by_model_name.items():
  59. if "target" in name or "source" in name:
  60. continue
  61. print("~~~~%%%%~~~", name)
  62. var_list = vars_by_model_name[name]
  63. valid_accs = []
  64. table_row = []
  65. var_table_row = []
  66. for pair in pairs:
  67. acc = acc_list[pair]
  68. var = var_list[pair]
  69. if pair[0] != pair[-1]: # exclude w2w, ...
  70. table_row.append(acc)
  71. var_table_row.append(var)
  72. if acc != None:
  73. valid_accs.append(acc)
  74. acc_average = round(sum(valid_accs) / len(header), ROUND_FACTOR)
  75. table_row.append(acc_average)
  76. table.append(table_row)
  77. var =round(np.var(valid_accs), ROUND_FACTOR)
  78. print(var, ">>>", valid_accs)
  79. var_table_row.append(var)
  80. var_table.append(var_table_row)
  81. t = np.array(table)
  82. t[t==None] = np.nan
  83. # pprint(t)
  84. col_max = t.min(axis=0)
  85. pprint(table)
  86. latex_table = ""
  87. header = [pair for pair in pairs if pair[0] != pair[-1]]
  88. name_map = {"base_source": "Source-Only"}
  89. j = 0
  90. for name, acc_list in acc_means_by_model_name.items():
  91. if "target" in name or "source" in name:
  92. continue
  93. latex_name = name
  94. if name in name_map:
  95. latex_name= name_map[name]
  96. latex_row = f"{latex_name.replace('_','-').upper()} &"
  97. acc_sum = 0
  98. for i, acc in enumerate(table[j]):
  99. if i == len(table[j]) - 1:
  100. acc_str = f"${acc}$"
  101. else:
  102. acc_str = f"${acc}$"
  103. if acc == col_max[i]:
  104. latex_row += f" \\underline{{{acc_str}}} &"
  105. else:
  106. latex_row += f" {acc_str} &"
  107. latex_row = f"{latex_row[:-1]} \\\\ \hline"
  108. latex_table += f"{latex_row}\n"
  109. j += 1
  110. print(*header, sep=" & ")
  111. print(latex_table)
  112. data = np.array(table)
  113. legend = [key for key in acc_means_by_model_name.keys() if "source" not in key]
  114. labels = [*header, "avg"]
  115. # data = np.array([[71.75, 75.94 ,67.38, 90.99, 68.91, 96.67, 78.61], [64.0, 66.67, 37.32, 94.97, 45.74, 98.5, 67.87]])
  116. # legend = ["CDAN", "Source-only"]
  117. # labels = [*header, "avg"]
  118. import matplotlib.pyplot as plt
  119. # Assume your matrix is called 'data'
  120. n, m = data.shape
  121. # Create an array of x-coordinates for the bars
  122. x = np.arange(m)
  123. # Plot the bars for each row side by side
  124. for i in range(n):
  125. row = data[i, :]
  126. plt.bar(x + (i-n/2)*0.1, row, width=0.08, align='center')
  127. # Set x-axis tick labels and labels
  128. plt.xticks(x, labels=labels)
  129. # plt.xlabel("Task")
  130. plt.ylabel("Time (s)")
  131. # Add a legend
  132. plt.legend(legend)
  133. plt.show()