You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

dataset_vis.py 1.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import torchvision
  4. from pytorch_adapt.datasets import get_office31
  5. # root="datasets/pytorch-adapt/"
  6. mean = [0.485, 0.456, 0.406]
  7. std = [0.229, 0.224, 0.225]
  8. inv_normalize = torchvision.transforms.Normalize(
  9. mean=[-m / s for m, s in zip(mean, std)], std=[1 / s for s in std]
  10. )
  11. idx = 0
  12. def imshow(img, domain, figsize=(10, 6)):
  13. img = inv_normalize(img)
  14. npimg = img.numpy()
  15. plt.figure(figsize=figsize)
  16. plt.imshow(np.transpose(npimg, (1, 2, 0)))
  17. plt.axis('off')
  18. plt.savefig(f"office31-{idx}")
  19. plt.show()
  20. plt.close("all")
  21. idx += 1
  22. def imshow_many(datasets, src, target):
  23. d = datasets["train"]
  24. for name in ["src_imgs", "target_imgs"]:
  25. domains = src if name == "src_imgs" else target
  26. if len(domains) == 0:
  27. continue
  28. print(domains)
  29. imgs = [d[i][name] for i in np.random.choice(len(d), size=16, replace=False)]
  30. imshow(torchvision.utils.make_grid(imgs))
  31. for src, target in [(["amazon"], ["dslr"]), (["webcam"], [])]:
  32. datasets = get_office31(src, target,folder=root)
  33. imshow_many(datasets, src, target)