You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

vis_hook.py 1.8KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. import matplotlib.pyplot as plt
  2. import pandas as pd
  3. import seaborn as sns
  4. import torch
  5. import umap
  6. from datetime import datetime
  7. from pytorch_adapt.adapters import DANN
  8. from pytorch_adapt.containers import Models, Optimizers, LRSchedulers
  9. from pytorch_adapt.datasets import DataloaderCreator, get_office31
  10. from pytorch_adapt.frameworks.ignite import CheckpointFnCreator, Ignite
  11. from pytorch_adapt.models import Discriminator, office31C, office31G
  12. from pytorch_adapt.validators import AccuracyValidator, IMValidator, ScoreHistory
  13. class VizHook:
  14. def __init__(self, **kwargs):
  15. self.required_data = ["src_val",
  16. "target_val", "target_val_with_labels"]
  17. self.kwargs = kwargs
  18. def __call__(self, epoch, src_val, target_val, target_val_with_labels, **kwargs):
  19. accuracy_validator = AccuracyValidator()
  20. accuracy = accuracy_validator.compute_score(src_val=src_val)
  21. print("src_val accuracy:", accuracy)
  22. accuracy_validator = AccuracyValidator()
  23. accuracy = accuracy_validator.compute_score(src_val=target_val_with_labels)
  24. print("target_val accuracy:", accuracy)
  25. if epoch >= 2 and epoch % kwargs.get("frequency", 5) != 0:
  26. return
  27. features = [src_val["features"], target_val["features"]]
  28. domain = [src_val["domain"], target_val["domain"]]
  29. features = torch.cat(features, dim=0).cpu().numpy()
  30. domain = torch.cat(domain, dim=0).cpu().numpy()
  31. emb = umap.UMAP().fit_transform(features)
  32. df = pd.DataFrame(emb).assign(domain=domain)
  33. df["domain"] = df["domain"].replace({0: "Source", 1: "Target"})
  34. sns.set_theme(style="white", rc={"figure.figsize": (8, 6)})
  35. sns.scatterplot(data=df, x=0, y=1, hue="domain", s=10)
  36. plt.savefig(f"{self.kwargs['output_dir']}/val_{epoch}.png")
  37. plt.close('all')