You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

get_average.py 4.3KB

1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. import itertools
  2. import numpy as np
  3. from pprint import pprint
  # Domain identifiers; every transfer task is named "<source>2<target>".
  domains = ["w", "d", "a"]

  # All ordered (source, target) combinations, self-pairs such as "w2w"
  # included, in alphabetical order.
  pairs = sorted(f"{src}2{dst}" for src, dst in itertools.product(domains, repeat=2))
  # print(pairs)

  # Number of decimal places kept when rounding reported statistics.
  ROUND_FACTOR = 3
  # Per-run accuracies grouped by lower-cased model name and task, e.g.
  # {'dann': {'a2w': [acc1, acc2, acc3], ...}, ...}
  all_accs_list = {}
  files = [
      "all.txt",
      "all2.txt",
      "all3.txt",
      # "all_step_scheduler.txt",
      # "all_source_trained_1.txt",
      # "all_source_trained_2_specific_hp.txt",
  ]
  for file in files:
      with open(file) as f:
          # Model announced by the most recent "## <model>" header line.
          name = None
          for line_no, line in enumerate(f, start=1):
              tokens = line.split(" ")
              if tokens[0] == "##" and len(tokens) > 1:  # e.g. ## DANN
                  name = tokens[1].strip()

              fields = line.split(",")  # a2w, acc1, acc2,
              pair = fields[0].strip()
              if pair not in pairs:
                  continue

              # Fail with a message that points at the offending line instead
              # of an AttributeError/IndexError deep inside the parsing code.
              if name is None:
                  raise ValueError(
                      f"{file}:{line_no}: result row for {pair!r} appears "
                      "before any '## <model>' header"
                  )
              if len(fields) < 3:
                  raise ValueError(
                      f"{file}:{line_no}: expected at least 3 comma-separated "
                      f"fields, got {len(fields)}"
                  )

              name = name.lower()
              # The third comma-separated field is the accuracy that gets recorded.
              acc = float(fields[2].strip())
              if name not in all_accs_list:
                  all_accs_list[name] = {p: [] for p in pairs}
              all_accs_list[name][pair].append(acc)
  # Per-model summary statistics, keyed by model name and then by task:
  #   acc_means_by_model_name[name][pair] -> mean accuracy * 100, rounded
  #   vars_by_model_name[name][pair]      -> np.var of the runs * 100, rounded
  # A task without any recorded run keeps the value None.
  acc_means_by_model_name = {}
  vars_by_model_name = {}
  for name, pair_list in all_accs_list.items():
      # None (not an empty list) marks "no data", so that consumers can test
      # for it with a plain `is None` / `!= None` check.
      accs = {p: None for p in pairs}
      pair_vars = {p: None for p in pairs}
      for pair, acc_list in pair_list.items():
          if acc_list:
              ## Calculate average and round
              accs[pair] = round(100 * sum(acc_list) / len(acc_list), ROUND_FACTOR)
              # NOTE(review): the variance is scaled by 100 while the mean is
              # also scaled by 100 -- confirm this is the intended unit.
              pair_vars[pair] = round(np.var(acc_list) * 100, ROUND_FACTOR)
              print(pair_vars[pair], "|||", acc_list)
      acc_means_by_model_name[name] = accs
      vars_by_model_name[name] = pair_vars
  # for name, acc_list in acc_means_by_model_name.items():
  #     pprint(name)
  # pprint(all_accs_list)
  # pprint(acc_means_by_model_name)
  # pprint(vars_by_model_name)
  print()
  latex_table = ""
  # Only genuine transfer tasks are reported; self-pairs (w2w, ...) are excluded.
  header = [pair for pair in pairs if pair[0] != pair[-1]]

  # One row per reported model: the per-task values in `header` order followed
  # by one aggregate column (mean accuracy in `table`, variance across the
  # tasks in `var_table`). Missing values are stored as None.
  table = []
  var_table = []
  for name, acc_list in acc_means_by_model_name.items():
      if "target" in name:
          continue
      var_list = vars_by_model_name[name]
      table_row = []
      var_table_row = []
      for pair in header:
          acc = acc_list[pair]
          var = var_list[pair]
          # Anything that is not a number (None, or an empty-list placeholder)
          # means that no run was recorded for this task.
          if not isinstance(acc, (int, float)):
              acc = None
          if not isinstance(var, (int, float)):
              var = None
          table_row.append(acc)
          var_table_row.append(var)

      valid_accs = [acc for acc in table_row if acc is not None]
      if valid_accs:
          # Average over the tasks that actually have data, so that a missing
          # task does not silently pull the mean towards zero.
          acc_average = round(sum(valid_accs) / len(valid_accs), ROUND_FACTOR)
          var = round(np.var(valid_accs), ROUND_FACTOR)
      else:
          acc_average = None
          var = None
      table_row.append(acc_average)
      table.append(table_row)
      print(var, "~~~~~~", valid_accs)
      var_table_row.append(var)
      var_table.append(var_table_row)

  # A float array turns every None into NaN; nanmax then ignores those entries
  # when picking the best value of each column.
  t = np.array(table, dtype=float)
  # pprint(t)
  col_max = np.nanmax(t, axis=0) if t.size else np.array([])
  pprint(table)
  # Render the accuracy table as LaTeX rows. Every per-task cell is printed as
  # "$acc \pm var$", the final (average) cell as "$acc$", and the best value of
  # each column is underlined.
  header = [pair for pair in pairs if pair[0] != pair[-1]]
  name_map = {"base_source": "Source-Only"}

  # Rows of `table` / `var_table` exist only for models whose name does not
  # contain "target", in dictionary order; rebuild that list to pair them up.
  reported_names = [name for name in acc_means_by_model_name if "target" not in name]

  latex_rows = []
  for name, table_row, var_row in zip(reported_names, table, var_table):
      latex_name = name_map.get(name, name)
      last_index = len(table_row) - 1
      cells = []
      for i, acc in enumerate(table_row):
          shown = "-" if acc is None else acc  # no data recorded for this cell
          if i == last_index:
              acc_str = f"${shown}$"
          else:
              # Escaped backslashes: "\p" is not a valid string escape.
              acc_str = f"${shown} \\pm {var_row[i]}$"
          if acc == col_max[i]:
              acc_str = f"\\underline{{{acc_str}}}"
          cells.append(acc_str)
      row_label = latex_name.replace('_', '-').upper()
      latex_rows.append(f"{row_label} & {' & '.join(cells)}  \\\\ \\hline")

  latex_table = "".join(f"{latex_row}\n" for latex_row in latex_rows)
  print(*header, sep=" & ")
  print(latex_table)
  import matplotlib.pyplot as plt

  # Grouped bar chart: one group per task (plus "avg"), one bar per model.
  # The figure uses the hard-coded rows below. To plot the computed results
  # instead, use `data = np.array(table)` together with a legend holding the
  # model names that have a row in `table`.
  data = np.array([[71.75, 75.94, 67.38, 90.99, 68.91, 96.67, 78.61], [64.0, 66.67, 37.32, 94.97, 45.74, 98.5, 67.87]])
  legend = ["CDAN", "Source-only"]
  labels = [*header, "avg"]

  n, m = data.shape
  # Create an array of x-coordinates for the bars
  x = np.arange(m)
  bar_step = 0.3    # distance between neighbouring bars of one group
  bar_width = 0.25
  # Plot the bars for each row side by side; offsetting by (n - 1) / 2 keeps
  # every group centred on its tick.
  for i in range(n):
      offset = (i - (n - 1) / 2) * bar_step
      plt.bar(x + offset, data[i, :], width=bar_width, align='center')
  # Set x-axis tick labels and labels
  plt.xticks(x, labels=labels)
  # plt.xlabel("Task")
  plt.ylabel("Accuracy")
  # Add a legend
  plt.legend(legend)
  plt.show()