123456789101112131415161718192021222324252627282930313233343536373839404142 |
- import matplotlib.pyplot as plt
- import numpy as np
- import torchvision
- from pytorch_adapt.datasets import get_office31
-
- # root="datasets/pytorch-adapt/"
-
-
# Per-channel statistics that the displayed images are assumed to have been
# normalized with (TODO confirm against the dataset's transform).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Normalize applies (x - mean) / std per channel.  Feeding it mean' = -m / s
# and std' = 1 / s therefore yields x * s + m, i.e. the inverse of a
# normalization performed with ``mean`` and ``std``.
_inv_mean = [-m / s for m, s in zip(mean, std)]
_inv_std = [1 / s for s in std]
inv_normalize = torchvision.transforms.Normalize(mean=_inv_mean, std=_inv_std)

# Running counter used to number the saved figure files ("office31-<idx>").
idx = 0
-
def imshow(img, domain=None, figsize=(10, 6)):
    """Un-normalize an image tensor, display it, and save it to a numbered file.

    Args:
        img: Image tensor.  It is passed through ``inv_normalize`` and its
            axes are reordered (0, 1, 2) -> (1, 2, 0) before plotting, so it
            is presumably channel-first (C, H, W) -- TODO confirm.
        domain: Currently unused.  Optional so that callers which only pass
            the image keep working.
        figsize: Size of the matplotlib figure, forwarded to ``plt.figure``.

    Side effects:
        Saves the figure as ``office31-<idx>``, shows it, closes all open
        figures, and increments the module-level ``idx`` counter.
    """
    # ``idx`` is rebound at the end of this function.  Without the ``global``
    # declaration Python treats it as a local name and the f-string below
    # raises UnboundLocalError.
    global idx

    img = inv_normalize(img)
    npimg = img.numpy()
    plt.figure(figsize=figsize)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis("off")

    plt.savefig(f"office31-{idx}")
    plt.show()
    # Close every figure so repeated calls do not accumulate open figures.
    plt.close("all")
    idx += 1
-
def imshow_many(datasets, src, target):
    """Show a grid of randomly sampled images for each non-empty domain list.

    Args:
        datasets: Mapping with a ``"train"`` entry.  That entry must support
            ``len()`` and integer indexing, and each item must be a mapping
            containing ``"src_imgs"`` and/or ``"target_imgs"``.
        src: Source domain names; when empty, no source grid is drawn.
        target: Target domain names; when empty, no target grid is drawn.
    """
    d = datasets["train"]
    # Pair each image key with the domain list that decides whether it is shown.
    for name, domains in (("src_imgs", src), ("target_imgs", target)):
        if not domains:
            continue
        print(domains)
        # Sampling without replacement fails if more items are requested than
        # exist, so cap the sample at the dataset size.
        sample_size = min(16, len(d))
        indices = np.random.choice(len(d), size=sample_size, replace=False)
        imgs = [d[i][name] for i in indices]
        # ``imshow`` takes the domain as its second argument.
        imshow(torchvision.utils.make_grid(imgs), domains)
-
# Dataset folder handed to ``get_office31``.  The only other definition of
# ``root`` in this file is commented out, which left the name undefined here;
# the value is taken from that commented-out assignment.
root = "datasets/pytorch-adapt/"

# Visualize one source/target pairing and one source-only configuration.
for src, target in [(["amazon"], ["dslr"]), (["webcam"], [])]:
    datasets = get_office31(src, target, folder=root)
    imshow_many(datasets, src, target)
|