
autoencoder.ipynb 44KB

# %% Cell 1 — training script: distributed (or single-GPU) autoencoder trainer

import os
import time
import argparse

import torch
import torch.nn as nn
import torch.multiprocessing as mp

import utils
import models.builer as builder
import dataloader

def get_args():
    # parse the args
    print('=> parse the args ...')
    parser = argparse.ArgumentParser(description='Trainer for auto encoder')
    parser.add_argument('--arch', default='vgg16', type=str,
                        help='backbone architecture')
    parser.add_argument('--train_list', type=str)
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
                        help='mini-batch size (default: 256), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')

    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')

    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')

    parser.add_argument('--pth-save-fold', default='results/tmp', type=str,
                        help='the folder to save checkpoints (.pth)')
    parser.add_argument('--pth-save-epoch', default=1, type=int,
                        help='save a checkpoint every N epochs')
    parser.add_argument('--parallel', type=int, default=1,
                        help='1 for parallel, 0 for non-parallel')
    parser.add_argument('--dist-url', default='tcp://localhost:10001', type=str,
                        help='url used to set up distributed training')

    args = parser.parse_args()

    return args

def main(args):
    print('=> torch version : {}'.format(torch.__version__))
    ngpus_per_node = torch.cuda.device_count()
    print('=> ngpus : {}'.format(ngpus_per_node))

    if args.parallel == 1:
        # single machine, multiple GPUs
        args.gpus = ngpus_per_node
        args.nodes = 1
        args.nr = 0
        args.world_size = args.gpus * args.nodes

        # split the global worker count and batch size evenly across processes
        args.workers = int(args.workers / args.world_size)
        args.batch_size = int(args.batch_size / args.world_size)
        mp.spawn(main_worker, nprocs=args.gpus, args=(args,))
    else:
        args.world_size = 1
        main_worker(ngpus_per_node, args)

def main_worker(gpu, args):
    utils.init_seeds(1 + gpu, cuda_deterministic=False)
    if args.parallel == 1:
        args.gpu = gpu
        args.rank = args.nr * args.gpus + args.gpu

        torch.cuda.set_device(gpu)
        torch.distributed.init_process_group(backend='nccl', init_method=args.dist_url,
                                             world_size=args.world_size, rank=args.rank)
    else:
        # two dummy variables, not real
        args.rank = 0
        args.gpus = 1

    if args.rank == 0:
        print('=> modeling the network {} ...'.format(args.arch))
    model = builder.BuildAutoEncoder(args)

    if args.rank == 0:
        total_params = sum(p.numel() for p in model.parameters())
        print('=> num of params: {} ({}M)'.format(total_params, int(total_params * 4 / (1024*1024))))

    if args.rank == 0:
        print('=> building the optimizer ...')
    # optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), args.lr)
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    if args.rank == 0:
        print('=> building the dataloader ...')
    train_loader = dataloader.train_loader(args)

    if args.rank == 0:
        print('=> building the criterion ...')
    criterion = nn.MSELoss()

    global iters
    iters = 0

    model.train()
    if args.rank == 0:
        print('=> starting training engine ...')
    for epoch in range(args.start_epoch, args.epochs):

        global current_lr
        current_lr = utils.adjust_learning_rate_cosine(optimizer, epoch, args)

        # reshuffle the shards each epoch (DistributedSampler only; the
        # non-parallel loader has sampler=None)
        if args.parallel == 1:
            train_loader.sampler.set_epoch(epoch)

        # train for one epoch
        do_train(train_loader, model, criterion, optimizer, epoch, args)

        # save checkpoint
        if epoch % args.pth_save_epoch == 0 and args.rank == 0:
            state_dict = model.state_dict()

            os.makedirs(args.pth_save_fold, exist_ok=True)
            torch.save(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': state_dict,
                    'optimizer': optimizer.state_dict(),
                },
                os.path.join(args.pth_save_fold, '{}.pth'.format(str(epoch).zfill(3)))
            )

            print(' : save pth for epoch {}'.format(epoch + 1))


def do_train(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = utils.AverageMeter('Time', ':6.2f')
    data_time = utils.AverageMeter('Data', ':2.2f')
    losses = utils.AverageMeter('Loss', ':.4f')
    learning_rate = utils.AverageMeter('LR', ':.4f')

    progress = utils.ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, learning_rate],
        prefix="Epoch: [{}]".format(epoch+1))
    end = time.time()

    # update lr
    learning_rate.update(current_lr)

    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        global iters
        iters += 1

        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        output = model(input)

        loss = criterion(output, target)

        # compute gradient and do solver step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # synchronize so the timings below are accurate
        torch.cuda.synchronize()

        # record loss
        losses.update(loss.item(), input.size(0))

        # measure elapsed time
        if args.rank == 0:
            batch_time.update(time.time() - end)
            end = time.time()

        if i % args.print_freq == 0 and args.rank == 0:
            progress.display(i)

if __name__ == '__main__':

    args = get_args()

    main(args)
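
# %% Usage sketch for the trainer above — a hypothetical launch, assuming the
# cell is saved as train.py and list/train_list.txt was produced by the list
# builder further down (both names are placeholders, not fixed by the code):
#
#   python train.py --arch vgg16 --train_list list/train_list.txt \
#       --batch-size 256 --workers 16 --epochs 20 --parallel 1
#
# With --parallel 1, main() spawns one process per visible GPU and divides the
# global budget evenly; e.g. on a 4-GPU node:
global_batch, global_workers, ngpus = 256, 16, 4
print('per-GPU batch:', global_batch // ngpus,        # 64
      '| per-GPU workers:', global_workers // ngpus)  # 4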
# %% Cell 2 — utilities: seeding, meters, LR schedules, checkpoint loading
# (imported above as `utils`)

import os
import sys
import math
import random
import numpy as np

from loguru import logger

import torch
from torch.backends import cudnn

def init_seeds(seed=0, cuda_deterministic=True):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Speed-reproducibility tradeoff: https://pytorch.org/docs/stable/notes/randomness.html
    if cuda_deterministic:  # slower, more reproducible
        cudnn.deterministic = True
        cudnn.benchmark = False
    else:  # faster, less reproducible
        cudnn.deterministic = False
        cudnn.benchmark = True

class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        logger.info("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

def adjust_learning_rate(optimizer, epoch, args):
    """Step decay; expects args.lr_drop_epoch / args.lr_drop_ratio.
    Unused by the cosine trainer above."""
    decay = args.lr_drop_ratio if epoch in args.lr_drop_epoch else 1.0
    lr = args.lr * decay
    global current_lr
    current_lr = lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    args.lr = current_lr
    return current_lr

def adjust_learning_rate_cosine(optimizer, epoch, args):
    """Cosine learning rate annealing without restarts."""
    lr = args.lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    global current_lr
    current_lr = lr
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return current_lr

def load_dict(resume_path, model):
    if os.path.isfile(resume_path):
        checkpoint = torch.load(resume_path)
        model_dict = model.state_dict()
        model_dict.update(checkpoint['state_dict'])
        model.load_state_dict(model_dict)
        # delete to release more space
        del checkpoint
    else:
        sys.exit("=> No checkpoint found at '{}'".format(resume_path))
    return model
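
# %% A quick CPU-only check of the helpers above (values are hypothetical).
# AverageMeter tracks the latest value and the running average; the cosine
# schedule decays args.lr smoothly toward zero over args.epochs.
import math
from types import SimpleNamespace

meter = AverageMeter('Loss', ':.4f')
for v in (0.9, 0.7, 0.5):
    meter.update(v, n=8)   # pretend each update covers a batch of 8 samples
print(meter)               # -> Loss 0.5000 (0.7000)

cfg = SimpleNamespace(lr=0.1, epochs=20)
lrs = [cfg.lr * 0.5 * (1. + math.cos(math.pi * e / cfg.epochs)) for e in range(cfg.epochs)]
print('first / last LR: {:.4f} / {:.4f}'.format(lrs[0], lrs[-1]))  # 0.1000 / 0.0006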
# %% Cell 3 — evaluation script: score one checkpoint, or sweep a folder of
# checkpoints and plot the loss curve

#!/usr/bin/env python

import os
import time
import argparse

import matplotlib.pyplot as plt

import torch
import torch.nn as nn

import utils
import models.builer as builder
import dataloader

def get_args():
    # parse the args
    print('=> parse the args ...')
    parser = argparse.ArgumentParser(description='Evaluate for auto encoder')
    parser.add_argument('--arch', default='vgg16', type=str,
                        help='backbone architecture')
    parser.add_argument('--val_list', type=str)
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: 16)')
    parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
                        help='mini-batch size (default: 256), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')

    parser.add_argument('-p', '--print-freq', default=20, type=int,
                        metavar='N', help='print frequency (default: 20)')

    parser.add_argument('--resume', type=str)
    parser.add_argument('--folder', type=str)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--epochs', default=100, type=int)

    args = parser.parse_args()

    args.parallel = 0

    return args

def main(args):
    print('=> torch version : {}'.format(torch.__version__))
    ngpus_per_node = torch.cuda.device_count()
    print('=> ngpus : {}'.format(ngpus_per_node))

    utils.init_seeds(1, cuda_deterministic=False)

    print('=> modeling the network ...')
    model = builder.BuildAutoEncoder(args)
    total_params = sum(p.numel() for p in model.parameters())
    print('=> num of params: {} ({}M)'.format(total_params, int(total_params * 4 / (1024*1024))))

    print('=> building the dataloader ...')
    val_loader = dataloader.val_loader(args)

    print('=> building the criterion ...')
    criterion = nn.MSELoss()

    print('=> starting evaluating engine ...')
    if args.folder:
        best_loss = None
        best_epoch = 1
        losses = []
        for epoch in range(args.start_epoch, args.epochs):
            print()
            print("Epoch {}".format(epoch+1))
            resume_path = os.path.join(args.folder, "%03d.pth" % epoch)
            print('=> loading pth from {} ...'.format(resume_path))
            utils.load_dict(resume_path, model)
            loss = do_evaluate(val_loader, model, criterion, args)
            print("Evaluate loss : {:.4f}".format(loss))

            losses.append(loss)
            if best_loss is not None:
                if loss < best_loss:
                    best_loss = loss
                    best_epoch = epoch + 1
            else:
                best_loss = loss
        print()
        print("Best loss : {:.4f} Appears in {}".format(best_loss, best_epoch))

        max_loss = max(losses)

        plt.figure(figsize=(7, 7))

        plt.xlabel("epoch")
        plt.ylabel("loss")
        plt.xlim((0, args.epochs+1))
        plt.ylim([0, float('%.1g' % (1.22*max_loss))])

        plt.scatter(range(1, args.epochs+1), losses, s=9)

        os.makedirs("figs", exist_ok=True)
        plt.savefig("figs/evalall.jpg")

    else:
        print('=> loading pth from {} ...'.format(args.resume))
        utils.load_dict(args.resume, model)
        loss = do_evaluate(val_loader, model, criterion, args)
        print("Evaluate loss : {:.4f}".format(loss))


def do_evaluate(val_loader, model, criterion, args):
    batch_time = utils.AverageMeter('Time', ':6.2f')
    data_time = utils.AverageMeter('Data', ':2.2f')
    losses = utils.AverageMeter('Loss', ':.4f')

    progress = utils.ProgressMeter(
        len(val_loader),
        [batch_time, data_time, losses],
        prefix="Evaluate ")
    end = time.time()

    model.eval()
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)

            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            output = model(input)

            loss = criterion(output, target)

            # record loss
            losses.update(loss.item(), input.size(0))
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

    return losses.avg

if __name__ == '__main__':

    args = get_args()

    main(args)
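
# %% Usage sketch for the evaluator — hypothetical launches, assuming the cell
# is saved as eval.py and the checkpoints were written by the trainer above:
#
#   # score a single checkpoint
#   python eval.py --arch vgg16 --val_list list/val_list.txt --resume results/tmp/019.pth
#
#   # sweep results/tmp/000.pth .. 019.pth and plot the curve to figs/evalall.jpg
#   python eval.py --arch vgg16 --val_list list/val_list.txt --folder results/tmp --epochs 20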
# %% Cell 4 — dataset and dataloaders (imported above as `dataloader`);
# each sample's target is the input image itself, since the autoencoder
# reconstructs its input

#!/usr/bin/env python

import random

from PIL import Image, ImageFilter

import torch
import torch.utils.data as data

from torchvision.transforms import transforms

class ImageDataset(data.Dataset):
    def __init__(self, ann_file, transform=None):
        self.ann_file = ann_file
        self.transform = transform
        self.init()

    def init(self):
        self.im_names = []
        self.targets = []
        with open(self.ann_file, 'r') as f:
            lines = f.readlines()
            for line in lines:
                data = line.strip().split(' ')
                self.im_names.append(data[0])
                self.targets.append(int(data[1]))

    def __getitem__(self, index):
        im_name = self.im_names[index]
        target = self.targets[index]

        img = Image.open(im_name).convert('RGB')
        if img is None:
            print(im_name)

        img = self.transform(img)

        # input and target are the same image
        return img, img

    def __len__(self):
        return len(self.im_names)

def train_loader(args):

    # [NO] do not use normalization here: it makes training very hard to converge
    # [NO] do not use ColorJitter: it leads to a performance drop on both the train and val sets

    # [?] Gaussian blur leads to a significant drop in train loss while the val loss stays the same

    augmentation = [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        # transforms.RandomGrayscale(p=0.2),
        # transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]

    train_trans = transforms.Compose(augmentation)

    train_dataset = ImageDataset(args.train_list, transform=train_trans)

    if args.parallel == 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            rank=args.rank,
            num_replicas=args.world_size,
            shuffle=True)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=(train_sampler is None),
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=(train_sampler is None))

    return train_loader

def val_loader(args):

    val_trans = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ])

    val_dataset = ImageDataset(args.val_list, transform=val_trans)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        shuffle=False,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True)

    return val_loader

class GaussianBlur(object):
    """Gaussian blur augmentation from SimCLR: https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x
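
# %% The annotation file consumed by ImageDataset is one "<image path> <label>"
# pair per line. A tiny sketch that writes a synthetic list and parses it back;
# the image paths are placeholders (real files are only needed once
# __getitem__ is called).
import os
os.makedirs('list', exist_ok=True)
with open('list/demo_list.txt', 'w') as f:
    f.write('/data/cat/001.jpg 0\n')
    f.write('/data/dog/001.jpg 1\n')

ds = ImageDataset('list/demo_list.txt')     # no transform; we only inspect parsing
print(len(ds), ds.im_names[0], ds.targets)  # -> 2 /data/cat/001.jpg [0, 1]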
# %% Cell 5 — visualization: reconstruct the first 64 validation images and
# save an 8x16 grid of (input, reconstruction) pairs

#!/usr/bin/env python

import argparse

import matplotlib.pyplot as plt

import torch

from torchvision.transforms import transforms

import sys
sys.path.append("./")

import utils
import models.builer as builder
import dataloader

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def get_args():
    # parse the args
    print('=> parse the args ...')
    parser = argparse.ArgumentParser(description='Reconstruction for auto encoder')
    parser.add_argument('--arch', default='vgg16', type=str,
                        help='backbone architecture')
    parser.add_argument('--resume', type=str)
    parser.add_argument('--val_list', type=str)

    args = parser.parse_args()

    args.parallel = 0
    args.batch_size = 1
    args.workers = 0

    return args

def main(args):
    print('=> torch version : {}'.format(torch.__version__))

    utils.init_seeds(1, cuda_deterministic=False)

    print('=> modeling the network ...')
    model = builder.BuildAutoEncoder(args)
    total_params = sum(p.numel() for p in model.parameters())
    print('=> num of params: {} ({}M)'.format(total_params, int(total_params * 4 / (1024*1024))))

    print('=> loading pth from {} ...'.format(args.resume))
    utils.load_dict(args.resume, model)

    print('=> building the dataloader ...')
    val_loader = dataloader.val_loader(args)

    plt.figure(figsize=(16, 9))

    model.eval()
    print('=> reconstructing ...')
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):

            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            output = model(input)

            input = transforms.ToPILImage()(input.squeeze().cpu())
            output = transforms.ToPILImage()(output.squeeze().cpu())

            # inputs in the odd grid slots, reconstructions in the even ones
            plt.subplot(8, 16, 2*i+1, xticks=[], yticks=[])
            plt.imshow(input)

            plt.subplot(8, 16, 2*i+2, xticks=[], yticks=[])
            plt.imshow(output)

            if i == 63:
                break

    os.makedirs('figs', exist_ok=True)
    plt.savefig('figs/reconstruction.jpg')

if __name__ == '__main__':

    args = get_args()

    main(args)
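
# %% Usage sketch for the reconstruction grid — a hypothetical launch, assuming
# the cell is saved as a script (the file name is a placeholder):
#
#   python reconstruct.py --arch vgg16 --resume results/tmp/019.pth \
#       --val_list list/val_list.txt
#
# Output: figs/reconstruction.jpg with 64 input/reconstruction pairs.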
# %% Cell 6 — list builder: turn an ImageNet-style directory tree
# (<path>/<class folder>/<images>) into the "<image path> <class index>"
# annotation files the dataset reads

import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str)
parser.add_argument('--path', type=str)
args = parser.parse_args()

# sort for a deterministic class-index mapping across runs
folders = sorted(os.listdir(args.path))

if not os.path.exists("list"):
    os.makedirs("list")

fl = open('list/' + args.name + '_list.txt', 'w')
fn = open('list/' + args.name + '_name.txt', 'w')

for i, folder in enumerate(folders):

    fn.write(str(i) + ' ' + folder + '\n')

    folder_path = os.path.join(args.path, folder)
    files = os.listdir(folder_path)

    for file in files:
        fl.write('{} {}\n'.format(os.path.join(folder_path, file), i))

fl.close()
fn.close()
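
# %% For example (hypothetical paths), `--name train --path /data/imagenet/train`
# writes two files:
#   list/train_list.txt : "/data/imagenet/train/n01440764/x.JPEG 0", one line per image
#   list/train_name.txt : "0 n01440764", one line per class folder
# list/train_list.txt is the format --train_list / --val_list expect above.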
# %% Cell 7 — encode tool: run one image through the encoder and print the
# shape of its latent code

#!/usr/bin/env python

import argparse

from PIL import Image

import torch
from torchvision.transforms import transforms

import sys
sys.path.append("./")

import utils
import models.builer as builder

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def get_args():
    # parse the args
    print('=> parse the args ...')
    parser = argparse.ArgumentParser(description='Encoder for auto encoder')
    parser.add_argument('--arch', default='vgg16', type=str,
                        help='backbone architecture')
    parser.add_argument('--resume', type=str)
    parser.add_argument('--img_path', type=str)

    args = parser.parse_args()

    args.parallel = 0
    args.batch_size = 1
    args.workers = 0

    return args

def encode(model, img):

    with torch.no_grad():
        # .module unwraps the DataParallel container built by BuildAutoEncoder
        code = model.module.encoder(img).cpu().numpy()

    return code

def main(args):
    print('=> torch version : {}'.format(torch.__version__))

    utils.init_seeds(1, cuda_deterministic=False)

    print('=> modeling the network ...')
    model = builder.BuildAutoEncoder(args)
    total_params = sum(p.numel() for p in model.parameters())
    print('=> num of params: {} ({}M)'.format(total_params, int(total_params * 4 / (1024*1024))))

    print('=> loading pth from {} ...'.format(args.resume))
    utils.load_dict(args.resume, model)

    trans = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor()
    ])

    img = Image.open(args.img_path).convert("RGB")

    img = trans(img).unsqueeze(0).cuda()

    model.eval()

    code = encode(model, img)

    print(code.shape)

    # To do : any other postprocessing

if __name__ == '__main__':

    args = get_args()

    main(args)
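
# %% For vgg16 the printed code shape is (1, 512, 7, 7): five stride-2 poolings
# reduce 224 -> 7 spatially while the channel count grows to 512. A hypothetical
# launch, assuming the cell is saved as a script and the paths exist:
#
#   python encode.py --arch vgg16 --resume results/tmp/019.pth --img_path demo.jpg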
# %% Cell 8 — generation tool: decode random latent codes into images
# (note: the latent space was never fit to a prior, so torch.randn samples
# are out-of-distribution for the decoder)

#!/usr/bin/env python

import argparse

import matplotlib.pyplot as plt

import torch

from torchvision.transforms import transforms

import sys
sys.path.append("./")

import utils
import models.builer as builder

import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def get_args():
    # parse the args
    print('=> parse the args ...')
    parser = argparse.ArgumentParser(description='Generator for auto encoder')
    parser.add_argument('--arch', default='vgg16', type=str,
                        help='backbone architecture')
    parser.add_argument('--resume', type=str)

    args = parser.parse_args()

    args.parallel = 0
    args.batch_size = 1
    args.workers = 0

    return args

def random_sample(arch):

    # match the encoder's output shape for the given backbone
    if arch in ["vgg11", "vgg13", "vgg16", "vgg19", "resnet18", "resnet34"]:
        return torch.randn((1, 512, 7, 7))
    elif arch in ["resnet50", "resnet101", "resnet152"]:
        return torch.randn((1, 2048, 7, 7))
    else:
        raise NotImplementedError("No implementation for architectures other than VGG and ResNet")

def main(args):
    print('=> torch version : {}'.format(torch.__version__))

    utils.init_seeds(1, cuda_deterministic=False)

    print('=> modeling the network ...')
    model = builder.BuildAutoEncoder(args)
    total_params = sum(p.numel() for p in model.parameters())
    print('=> num of params: {} ({}M)'.format(total_params, int(total_params * 4 / (1024*1024))))

    print('=> loading pth from {} ...'.format(args.resume))
    utils.load_dict(args.resume, model)

    trans = transforms.ToPILImage()

    plt.figure(figsize=(16, 9))

    model.eval()
    print('=> generating ...')
    with torch.no_grad():
        for i in range(128):

            input = random_sample(arch=args.arch).cuda()

            output = model.module.decoder(input)

            output = trans(output.squeeze().cpu())

            plt.subplot(8, 16, i+1, xticks=[], yticks=[])
            plt.imshow(output)

    os.makedirs('figs', exist_ok=True)
    plt.savefig('figs/generation.jpg')

if __name__ == '__main__':

    args = get_args()

    main(args)
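
# %% random_sample() mirrors the encoder's output shape per backbone; a quick
# CPU check that runs without any checkpoint:
for arch in ('vgg16', 'resnet50'):
    print(arch, tuple(random_sample(arch).shape))
# vgg16 (1, 512, 7, 7)
# resnet50 (1, 2048, 7, 7)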
# %% Cell 9 — VGG autoencoder: a mirrored encoder/decoder built from the
# standard VGG stage configurations

import torch
import torch.nn as nn

def get_configs(arch='vgg16'):
    # number of conv layers in each of the five VGG stages
    if arch == 'vgg11':
        configs = [1, 1, 2, 2, 2]
    elif arch == 'vgg13':
        configs = [2, 2, 2, 2, 2]
    elif arch == 'vgg16':
        configs = [2, 2, 3, 3, 3]
    elif arch == 'vgg19':
        configs = [2, 2, 4, 4, 4]
    else:
        raise ValueError("Undefined model")

    return configs

class VGGAutoEncoder(nn.Module):

    def __init__(self, configs):
        super(VGGAutoEncoder, self).__init__()

        # a VGG autoencoder without BN is hard to train, so BN is enabled here
        self.encoder = VGGEncoder(configs=configs, enable_bn=True)
        self.decoder = VGGDecoder(configs=configs[::-1], enable_bn=True)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)

        return x

class VGG(nn.Module):

    def __init__(self, configs, num_classes=1000, img_size=224, enable_bn=False):
        super(VGG, self).__init__()

        self.encoder = VGGEncoder(configs=configs, enable_bn=enable_bn)

        self.img_size = img_size / 32

        self.fc = nn.Sequential(
            nn.Linear(in_features=int(self.img_size*self.img_size*512), out_features=4096),
            nn.Dropout(p=0.5),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=4096),
            nn.Dropout(p=0.5),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=4096, out_features=num_classes)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

class VGGEncoder(nn.Module):

    def __init__(self, configs, enable_bn=False):
        super(VGGEncoder, self).__init__()

        if len(configs) != 5:
            raise ValueError("There should be 5 stages in VGG")

        self.conv1 = EncoderBlock(input_dim=3, output_dim=64, hidden_dim=64, layers=configs[0], enable_bn=enable_bn)
        self.conv2 = EncoderBlock(input_dim=64, output_dim=128, hidden_dim=128, layers=configs[1], enable_bn=enable_bn)
        self.conv3 = EncoderBlock(input_dim=128, output_dim=256, hidden_dim=256, layers=configs[2], enable_bn=enable_bn)
        self.conv4 = EncoderBlock(input_dim=256, output_dim=512, hidden_dim=512, layers=configs[3], enable_bn=enable_bn)
        self.conv5 = EncoderBlock(input_dim=512, output_dim=512, hidden_dim=512, layers=configs[4], enable_bn=enable_bn)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)

        return x

class VGGDecoder(nn.Module):

    def __init__(self, configs, enable_bn=False):
        super(VGGDecoder, self).__init__()

        if len(configs) != 5:
            raise ValueError("There should be 5 stages in VGG")

        self.conv1 = DecoderBlock(input_dim=512, output_dim=512, hidden_dim=512, layers=configs[0], enable_bn=enable_bn)
        self.conv2 = DecoderBlock(input_dim=512, output_dim=256, hidden_dim=512, layers=configs[1], enable_bn=enable_bn)
        self.conv3 = DecoderBlock(input_dim=256, output_dim=128, hidden_dim=256, layers=configs[2], enable_bn=enable_bn)
        self.conv4 = DecoderBlock(input_dim=128, output_dim=64, hidden_dim=128, layers=configs[3], enable_bn=enable_bn)
        self.conv5 = DecoderBlock(input_dim=64, output_dim=3, hidden_dim=64, layers=configs[4], enable_bn=enable_bn)
        # squash the output into [0, 1] to match ToTensor() inputs
        self.gate = nn.Sigmoid()

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.gate(x)

        return x

class EncoderBlock(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim, layers, enable_bn=False):
        super(EncoderBlock, self).__init__()

        if layers == 1:
            layer = EncoderLayer(input_dim=input_dim, output_dim=output_dim, enable_bn=enable_bn)
            self.add_module('0 EncoderLayer', layer)
        else:
            for i in range(layers):
                if i == 0:
                    layer = EncoderLayer(input_dim=input_dim, output_dim=hidden_dim, enable_bn=enable_bn)
                elif i == (layers - 1):
                    layer = EncoderLayer(input_dim=hidden_dim, output_dim=output_dim, enable_bn=enable_bn)
                else:
                    layer = EncoderLayer(input_dim=hidden_dim, output_dim=hidden_dim, enable_bn=enable_bn)

                self.add_module('%d EncoderLayer' % i, layer)

        maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.add_module('%d MaxPooling' % layers, maxpool)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)

        return x

class DecoderBlock(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim, layers, enable_bn=False):
        super(DecoderBlock, self).__init__()

        upsample = nn.ConvTranspose2d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=2, stride=2)
        self.add_module('0 UpSampling', upsample)

        if layers == 1:
            layer = DecoderLayer(input_dim=input_dim, output_dim=output_dim, enable_bn=enable_bn)
            self.add_module('1 DecoderLayer', layer)
        else:
            for i in range(layers):
                if i == 0:
                    layer = DecoderLayer(input_dim=input_dim, output_dim=hidden_dim, enable_bn=enable_bn)
                elif i == (layers - 1):
                    layer = DecoderLayer(input_dim=hidden_dim, output_dim=output_dim, enable_bn=enable_bn)
                else:
                    layer = DecoderLayer(input_dim=hidden_dim, output_dim=hidden_dim, enable_bn=enable_bn)

                self.add_module('%d DecoderLayer' % (i+1), layer)

    def forward(self, x):
        for layer in self.children():
            x = layer(x)

        return x

class EncoderLayer(nn.Module):

    def __init__(self, input_dim, output_dim, enable_bn):
        super(EncoderLayer, self).__init__()

        if enable_bn:
            self.layer = nn.Sequential(
                nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(output_dim),
                nn.ReLU(inplace=True),
            )
        else:
            self.layer = nn.Sequential(
                nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
            )

    def forward(self, x):
        return self.layer(x)

class DecoderLayer(nn.Module):

    def __init__(self, input_dim, output_dim, enable_bn):
        super(DecoderLayer, self).__init__()

        if enable_bn:
            self.layer = nn.Sequential(
                nn.BatchNorm2d(input_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1),
            )
        else:
            self.layer = nn.Sequential(
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1),
            )

    def forward(self, x):
        return self.layer(x)

if __name__ == "__main__":

    input = torch.randn((5, 3, 224, 224))

    configs = get_configs()

    model = VGGAutoEncoder(configs)

    output = model(input)

    print(output.shape)  # torch.Size([5, 3, 224, 224])
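
# %% Shape sanity check on CPU (a hypothetical smoke test): the encoder halves
# the resolution five times (224 -> 7) while the decoder mirrors it back.
configs = get_configs('vgg16')
enc = VGGEncoder(configs, enable_bn=True)
dec = VGGDecoder(configs[::-1], enable_bn=True)
with torch.no_grad():
    z = enc(torch.randn(1, 3, 224, 224))
    y = dec(z)
print(z.shape, y.shape)  # torch.Size([1, 512, 7, 7]) torch.Size([1, 3, 224, 224])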
# %% Cell 10 — model builder (imported above as `models.builer`): wraps the
# chosen autoencoder in DistributedDataParallel or DataParallel

import torch.nn as nn
import torch.nn.parallel as parallel

from . import vgg, resnet

def BuildAutoEncoder(args):

    if args.arch in ["vgg11", "vgg13", "vgg16", "vgg19"]:
        configs = vgg.get_configs(args.arch)
        model = vgg.VGGAutoEncoder(configs)

    elif args.arch in ["resnet18", "resnet34", "resnet50", "resnet101", "resnet152"]:
        configs, bottleneck = resnet.get_configs(args.arch)
        model = resnet.ResNetAutoEncoder(configs, bottleneck)

    else:
        raise ValueError("Unsupported architecture: {}".format(args.arch))

    if args.parallel == 1:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        model = parallel.DistributedDataParallel(
            model.to(args.gpu),
            device_ids=[args.gpu],
            output_device=args.gpu
        )
    else:
        model = nn.DataParallel(model).cuda()

    return model
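
# %% Minimal non-parallel build, done the same way the scripts above do it
# (hypothetical args; assumes the repo's `models` package is importable and a
# CUDA device is present, since the non-parallel branch calls
# DataParallel(model).cuda()):
from types import SimpleNamespace
import models.builer as builder

args = SimpleNamespace(arch='vgg16', parallel=0)
model = builder.BuildAutoEncoder(args)
print(type(model.module).__name__)  # -> VGGAutoEncoder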