12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202 |
- {
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "import time\n",
- "import argparse\n",
- "\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "import torch.multiprocessing as mp\n",
- "\n",
- "import utils\n",
- "import models.builer as builder\n",
- "import dataloader\n",
- "\n",
def get_args():
    """Parse the command-line options of the auto-encoder trainer.

    ``--batch-size`` and ``--workers`` are node-wide totals; ``main`` splits
    them across processes in parallel mode.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    print('=> parse the args ...')
    parser = argparse.ArgumentParser(description='Trainer for auto encoder')
    parser.add_argument('--arch', default='vgg16', type=str,
                        help='backbone architecture')
    parser.add_argument('--train_list', type=str)
    # %(default)s keeps the advertised defaults in sync with the real ones
    # (the help text previously claimed 0 workers and a print frequency of 10).
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: %(default)s)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
                        help='mini-batch size (default: 256), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')

    parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')

    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: %(default)s)')

    parser.add_argument('--pth-save-fold', default='results/tmp', type=str,
                        help='The folder to save pths')
    parser.add_argument('--pth-save-epoch', default=1, type=int,
                        help='The epoch to save pth')
    parser.add_argument('--parallel', type=int, default=1,
                        help='1 for parallel, 0 for non-parallel')
    parser.add_argument('--dist-url', default='tcp://localhost:10001', type=str,
                        help='url used to set up distributed training')

    args = parser.parse_args()

    return args
- "\n",
def main(args):
    """Launch training: one process per GPU in parallel mode, else one in-process worker.

    With ``args.parallel == 1`` a single-node job is spawned with one process
    per visible GPU, and ``args.workers`` / ``args.batch_size`` (node-wide
    totals) are divided evenly across the processes. Otherwise ``main_worker``
    runs directly in the current process.

    Raises:
        RuntimeError: parallel mode was requested but no CUDA device is
            visible (this used to surface as a ZeroDivisionError).
    """
    print('=> torch version : {}'.format(torch.__version__))
    ngpus_per_node = torch.cuda.device_count()
    print('=> ngpus : {}'.format(ngpus_per_node))

    if args.parallel == 1:
        if ngpus_per_node == 0:
            raise RuntimeError('--parallel 1 needs at least one CUDA device, but none is visible')
        # single machine, multiple cards
        args.gpus = ngpus_per_node
        args.nodes = 1
        args.nr = 0
        args.world_size = args.gpus * args.nodes

        # per-process share of the node-wide totals; a zero batch size would
        # make the DataLoader reject its arguments, so keep at least one sample
        args.workers = int(args.workers / args.world_size)
        args.batch_size = max(1, int(args.batch_size / args.world_size))
        mp.spawn(main_worker, nprocs=args.gpus, args=(args,))
    else:
        args.world_size = 1
        # NOTE(review): the GPU count is passed as `gpu`; in non-parallel mode
        # main_worker only uses it to offset the random seed.
        main_worker(ngpus_per_node, args)
- " \n",
def main_worker(gpu, args):
    """Set up one training worker and run the full training loop.

    Args:
        gpu: local GPU index in parallel mode; also offsets the random seed.
        args: parsed options. This function sets ``args.rank`` and
            ``args.gpus`` (plus ``args.gpu`` in parallel mode).

    Rank 0 writes a checkpoint ``<pth_save_fold>/<epoch:03d>.pth`` every
    ``pth_save_epoch`` epochs containing epoch, arch, model and optimizer state.
    """
    # module-level state: do_train reads current_lr and increments iters
    global iters, current_lr

    utils.init_seeds(1 + gpu, cuda_deterministic=False)
    if args.parallel == 1:
        args.gpu = gpu
        args.rank = args.nr * args.gpus + args.gpu

        torch.cuda.set_device(gpu)
        torch.distributed.init_process_group(
            backend='nccl', init_method=args.dist_url,
            world_size=args.world_size, rank=args.rank)
    else:
        # placeholder values so the rank-dependent code below also works
        # without a distributed process group
        args.rank = 0
        args.gpus = 1

    is_main = args.rank == 0

    if is_main:
        print('=> modeling the network {} ...'.format(args.arch))
    model = builder.BuildAutoEncoder(args)
    if is_main:
        total_params = sum(p.numel() for p in model.parameters())
        # the MB figure assumes 4 bytes per parameter
        print('=> num of params: {} ({}M)'.format(total_params, int(total_params * 4 / (1024*1024))))

    if is_main:
        print('=> building the optimizer ...')
    optimizer = torch.optim.SGD(
        filter(lambda p: p.requires_grad, model.parameters()),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    if is_main:
        print('=> building the dataloader ...')
    train_loader = dataloader.train_loader(args)

    if is_main:
        print('=> building the criterion ...')
    criterion = nn.MSELoss()

    iters = 0

    model.train()
    if is_main:
        print('=> starting training engine ...')
        # torch.save does not create missing directories
        os.makedirs(args.pth_save_fold, exist_ok=True)

    for epoch in range(args.start_epoch, args.epochs):
        current_lr = utils.adjust_learning_rate_cosine(optimizer, epoch, args)

        # Only DistributedSampler has set_epoch. In non-parallel mode the
        # loader shuffles with its default sampler, and the unconditional call
        # raised AttributeError.
        if hasattr(train_loader.sampler, 'set_epoch'):
            train_loader.sampler.set_epoch(epoch)

        # train for one epoch
        do_train(train_loader, model, criterion, optimizer, epoch, args)

        # save pth
        if epoch % args.pth_save_epoch == 0 and is_main:
            torch.save(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                },
                os.path.join(args.pth_save_fold, '{}.pth'.format(str(epoch).zfill(3)))
            )

            print(' : save pth for epoch {}'.format(epoch + 1))
- "\n",
- "\n",
def do_train(train_loader, model, criterion, optimizer, epoch, args):
    """Run one training epoch.

    Each batch is moved to the current CUDA device, passed through ``model``
    and scored against its target with ``criterion``, after which the
    optimizer takes one step. The module-level ``iters`` counter grows by one
    per batch. Rank 0 logs timing, running loss and the module-level
    ``current_lr`` every ``args.print_freq`` batches.
    """
    global iters

    batch_time = utils.AverageMeter('Time', ':6.2f')
    data_time = utils.AverageMeter('Data', ':2.2f')
    losses = utils.AverageMeter('Loss', ':.4f')
    learning_rate = utils.AverageMeter('LR', ':.4f')

    meters = [batch_time, data_time, losses, learning_rate]
    progress = utils.ProgressMeter(len(train_loader), meters,
                                   prefix="Epoch: [{}]".format(epoch+1))
    tick = time.time()

    # the learning rate is fixed for the whole epoch, so it is recorded once
    learning_rate.update(current_lr)

    for step, (images, targets) in enumerate(train_loader):
        data_time.update(time.time() - tick)
        iters += 1

        images = images.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        loss = criterion(model(images), targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # wait for queued GPU work so the timings below are meaningful
        torch.cuda.synchronize()

        losses.update(loss.item(), images.size(0))

        if args.rank == 0:
            batch_time.update(time.time() - tick)
            tick = time.time()

        if step % args.print_freq == 0 and args.rank == 0:
            progress.display(step)
- "\n",
if __name__ == '__main__':
    # read the command line, then launch training with it
    args = get_args()
    main(args)
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import os\n",
- "import sys\n",
- "import math\n",
- "import random\n",
- "import numpy as np\n",
- "\n",
- "from loguru import logger\n",
- "\n",
- "import torch\n",
- "from torch.backends import cudnn\n",
- "\n",
def init_seeds(seed=0, cuda_deterministic=True):
    """Seed Python, NumPy and PyTorch RNGs and pick the cuDNN reproducibility mode.

    Args:
        seed: seed for every RNG; it is also exported as ``PYTHONHASHSEED``.
        cuda_deterministic: truthy -> deterministic cuDNN kernels with autotuning
            off (slower, more reproducible); falsy -> the reverse (faster, less
            reproducible). See https://pytorch.org/docs/stable/notes/randomness.html
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # speed vs. reproducibility trade-off
    cudnn.deterministic = bool(cuda_deterministic)
    cudnn.benchmark = not cuda_deterministic
- "\n",
class AverageMeter:
    """Keep the latest value and the count-weighted running mean of a metric."""

    def __init__(self, name, fmt=':f'):
        # fmt is a format spec such as ':.4f', applied to both val and avg
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Forget every recorded value."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))
- "\n",
- "\n",
class ProgressMeter:
    """Log a one-line progress report: prefix, [batch/total], then every meter."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Log the report for ``batch`` through loguru, fields separated by tabs."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        logger.info("\t".join([head, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # e.g. 120 -> '[{:3d}/120]': the current batch is padded to the total's width
        width = len(str(num_batches // 1))
        field = '{:%dd}' % width
        return '[%s/%s]' % (field, field.format(num_batches))
- "\n",
def adjust_learning_rate(optimizer, epoch, args):
    """Step decay: scale ``args.lr`` by ``args.lr_drop_ratio`` when ``epoch`` is a drop epoch.

    The result is written back to ``args.lr``, so successive drops compound.
    It is also pushed into every optimizer param group, stored in the
    module-level ``current_lr``, and returned.
    """
    global current_lr
    factor = args.lr_drop_ratio if epoch in args.lr_drop_epoch else 1.0
    current_lr = args.lr * factor
    for group in optimizer.param_groups:
        group['lr'] = current_lr
    args.lr = current_lr
    return current_lr
- "\n",
def adjust_learning_rate_cosine(optimizer, epoch, args):
    """Cosine annealing without restart: lr = base * 0.5 * (1 + cos(pi * epoch / epochs)).

    ``args.lr`` is the base rate and is left unchanged. The new rate is pushed
    into every optimizer param group, stored in the module-level
    ``current_lr``, and returned.
    """
    global current_lr
    current_lr = args.lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    for group in optimizer.param_groups:
        group['lr'] = current_lr
    return current_lr
- "\n",
def load_dict(resume_path, model):
    """Load the ``'state_dict'`` entry of a checkpoint file into ``model``.

    Checkpoint entries overwrite the model's current ones; model keys missing
    from the checkpoint keep their current values. The merged dict is passed
    to ``load_state_dict`` with its default (strict) key checking.

    Args:
        resume_path: path to a checkpoint saved as ``{'state_dict': ..., ...}``.
        model: module to load into; it is modified in place.

    Returns:
        The same ``model`` object.

    Raises:
        SystemExit: ``resume_path`` is not an existing file.
    """
    if not os.path.isfile(resume_path):
        sys.exit("=> No checkpoint found at '{}'".format(resume_path))

    # Deserialize onto the CPU. Without map_location, tensors saved from a GPU
    # are restored onto that same GPU: this fails on CPU-only hosts and
    # duplicates the weights in GPU memory. load_state_dict then copies the
    # values into the model's own parameters, wherever they live.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(resume_path, map_location='cpu')
    model_dict = model.state_dict()
    model_dict.update(checkpoint['state_dict'])
    model.load_state_dict(model_dict)
    # drop the reference so the rest of the checkpoint (e.g. optimizer state) can be freed
    del checkpoint
    return model
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "#!/usr/bin/env python\n",
- "\n",
- "import os\n",
- "import time\n",
- "import argparse\n",
- "\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "\n",
- "import utils\n",
- "import models.builer as builder\n",
- "import dataloader\n",
- "\n",
def get_args():
    """Parse options for the evaluation script.

    With ``--folder`` every per-epoch checkpoint for epochs
    ``[start_epoch, epochs)`` is evaluated; otherwise the single checkpoint
    ``--resume`` is. ``parallel`` is always set to 0.
    """
    print('=> parse the args ...')
    parser = argparse.ArgumentParser(description='Evaluate for auto encoder')
    parser.add_argument('--arch', default='vgg16', type=str,
                        help='backbone architecture')
    parser.add_argument('--val_list', type=str)
    # %(default)s keeps the help in sync with the real defaults (it claimed 0 and 10)
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                        help='number of data loading workers (default: %(default)s)')
    parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
                        help='mini-batch size (default: 256), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')

    parser.add_argument('-p', '--print-freq', default=20, type=int,
                        metavar='N', help='print frequency (default: %(default)s)')

    parser.add_argument('--resume', type=str)
    parser.add_argument('--folder', type=str)
    parser.add_argument('--start_epoch', default=0, type=int)
    parser.add_argument('--epochs', default=100, type=int)

    args = parser.parse_args()

    # evaluation runs in a single process
    args.parallel = 0

    return args
- "\n",
def main(args):
    """Evaluate auto-encoder checkpoints on ``args.val_list``.

    With ``--folder``: evaluate ``<folder>/<epoch:03d>.pth`` for every epoch in
    ``[start_epoch, epochs)``, report the lowest mean loss and the 1-based
    epoch where it appears, and scatter-plot loss per epoch to
    ``figs/evalall.jpg``. Without it, evaluate the single checkpoint ``--resume``.
    """
    print('=> torch version : {}'.format(torch.__version__))
    ngpus_per_node = torch.cuda.device_count()
    print('=> ngpus : {}'.format(ngpus_per_node))

    utils.init_seeds(1, cuda_deterministic=False)

    print('=> modeling the network ...')
    model = builder.BuildAutoEncoder(args)
    total_params = sum(p.numel() for p in model.parameters())
    print('=> num of params: {} ({}M)'.format(total_params, int(total_params * 4 / (1024*1024))))

    print('=> building the dataloader ...')
    val_loader = dataloader.val_loader(args)

    print('=> building the criterion ...')
    criterion = nn.MSELoss()

    print('=> starting evaluating engine ...')
    if not args.folder:
        print('=> loading pth from {} ...'.format(args.resume))
        utils.load_dict(args.resume, model)
        loss = do_evaluate(val_loader, model, criterion, args)
        print("Evaluate loss : {:.4f}".format(loss))
        return

    epochs = range(args.start_epoch, args.epochs)
    if not epochs:
        # nothing to evaluate; max() and the plot below would fail on an empty list
        print('=> no epochs to evaluate in [{}, {})'.format(args.start_epoch, args.epochs))
        return

    best_loss = None
    best_epoch = None
    losses = []
    for epoch in epochs:
        print()
        print("Epoch {}".format(epoch+1))
        resume_path = os.path.join(args.folder, "%03d.pth" % epoch)
        print('=> loading pth from {} ...'.format(resume_path))
        utils.load_dict(resume_path, model)
        loss = do_evaluate(val_loader, model, criterion, args)
        print("Evaluate loss : {:.4f}".format(loss))

        losses.append(loss)
        # `is None`, not truthiness: a loss of exactly 0.0 must not be treated
        # as "no best yet". best_epoch now also comes from the actual epoch, so
        # it is correct when start_epoch > 0.
        if best_loss is None or loss < best_loss:
            best_loss = loss
            best_epoch = epoch + 1
    print()
    print("Best loss : {:.4f} Appears in {}".format(best_loss, best_epoch))

    max_loss = max(losses)
    # round the y-limit to one significant digit, but never below the largest
    # loss (e.g. 0.119 * 1.22 rounds to 0.1, which would clip that point)
    y_max = float('%.1g' % (1.22 * max_loss))
    if y_max < max_loss:
        y_max = 1.22 * max_loss

    plt.figure(figsize=(7, 7))

    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.xlim((0, args.epochs + 1))
    plt.ylim([0, y_max])

    # x positions are the evaluated (1-based) epochs; range(1, epochs+1) had
    # the wrong length whenever start_epoch > 0
    plt.scatter([epoch + 1 for epoch in epochs], losses, s=9)

    # savefig does not create missing directories
    os.makedirs("figs", exist_ok=True)
    plt.savefig("figs/evalall.jpg")
- "\n",
- "\n",
def do_evaluate(val_loader, model, criterion, args):
    """Return the sample-weighted mean ``criterion`` loss of ``model`` over ``val_loader``.

    The model is switched to eval mode and run without gradient tracking;
    batches are moved to the current CUDA device. Progress is logged every
    ``args.print_freq`` batches.
    """
    batch_time = utils.AverageMeter('Time', ':6.2f')
    data_time = utils.AverageMeter('Data', ':2.2f')
    losses = utils.AverageMeter('Loss', ':.4f')

    progress = utils.ProgressMeter(len(val_loader), [batch_time, data_time, losses],
                                   prefix="Evaluate ")
    tick = time.time()

    model.eval()
    with torch.no_grad():
        for step, (images, targets) in enumerate(val_loader):
            data_time.update(time.time() - tick)

            images = images.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

            batch_loss = criterion(model(images), targets)

            losses.update(batch_loss.item(), images.size(0))
            batch_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

    return losses.avg
- "\n",
if __name__ == '__main__':
    # read the command line, then run the evaluation with it
    args = get_args()
    main(args)
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "#!/usr/bin/env python\n",
- "\n",
- "import random\n",
- "\n",
- "from PIL import Image, ImageFilter\n",
- "\n",
- "import torch\n",
- "import torch.utils.data as data\n",
- "\n",
- "from torchvision.transforms import transforms\n",
- "\n",
class ImageDataset(data.Dataset):
    """Images listed in an annotation file, served as ``(img, img)`` pairs.

    Each non-empty line of ``ann_file`` is ``<image path> <integer label>``.
    Labels are parsed into ``self.targets``, but ``__getitem__`` returns the
    (transformed) image as both input and target, as an auto-encoder needs.
    """

    def __init__(self, ann_file, transform=None):
        self.ann_file = ann_file
        self.transform = transform
        self.init()

    def init(self):
        """Parse ``self.ann_file`` into ``self.im_names`` and ``self.targets``."""
        self.im_names = []
        self.targets = []
        with open(self.ann_file, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing lines
                # split on the LAST space so image paths containing spaces still parse
                im_name, target = line.rsplit(' ', 1)
                self.im_names.append(im_name)
                self.targets.append(int(target))

    def __getitem__(self, index):
        im_name = self.im_names[index]

        # close the underlying file once decoded; many workers x many images
        # would otherwise keep file handles open until garbage collection
        with Image.open(im_name) as im:
            img = im.convert('RGB')

        # transform defaults to None, so only apply it when one was given
        if self.transform is not None:
            img = self.transform(img)

        return img, img

    def __len__(self):
        return len(self.im_names)
- "\n",
def train_loader(args):
    """Build the training DataLoader over ``args.train_list``.

    Augmentation: random resized 224x224 crop (scale 0.2-1.0), random
    horizontal flip, conversion to a tensor. In parallel mode a shuffling
    DistributedSampler splits the data across ``args.world_size`` ranks;
    otherwise the loader shuffles by itself and drops the last incomplete batch.

    Author's experiment notes:
      * no normalization -- it made training very hard to converge;
      * no colour jitter -- it lowered performance on both train and val sets;
      * Gaussian blur (GaussianBlur below) and random grayscale were tried and
        disabled; blur cut train loss sharply while val loss stayed the same.
    """
    train_trans = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    train_dataset = ImageDataset(args.train_list, transform=train_trans)

    train_sampler = None
    if args.parallel == 1:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            rank=args.rank,
            num_replicas=args.world_size,
            shuffle=True)

    # without a sampler the loader shuffles itself and drops the ragged last batch
    loader_shuffles = train_sampler is None
    return torch.utils.data.DataLoader(
        train_dataset,
        shuffle=loader_shuffles,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
        drop_last=loader_shuffles)
- "\n",
def val_loader(args):
    """Build the validation DataLoader over ``args.val_list`` (fixed order, no augmentation).

    Images are resized to 256, centre-cropped to 224 and converted to tensors.
    """
    val_dataset = ImageDataset(
        args.val_list,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ]),
    )

    return torch.utils.data.DataLoader(
        val_dataset,
        shuffle=False,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True)
- "\n",
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Blurs an image via ``x.filter`` with a radius drawn uniformly from
    ``[sigma[0], sigma[1]]``.
    """

    def __init__(self, sigma=(.1, 2.)):
        # tuple default: a list default would be one mutable object shared by
        # every instance created without an explicit sigma
        self.sigma = sigma

    def __call__(self, x):
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "#!/usr/bin/env python\n",
- "\n",
- "import argparse\n",
- "\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "import torch\n",
- "\n",
- "from torchvision.transforms import transforms\n",
- "\n",
- "import sys\n",
- "sys.path.append(\"./\")\n",
- "\n",
- "import utils\n",
- "import models.builer as builder\n",
- "import dataloader\n",
- "\n",
- "import os\n",
- "\n",
- "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
- "\n",
def get_args():
    """Parse options for the reconstruction visualiser.

    Forces ``parallel = 0``, ``batch_size = 1`` (one image per subplot) and
    ``workers = 0``.
    """
    print('=> parse the args ...')
    # description was copy-pasted from the trainer
    parser = argparse.ArgumentParser(description='Reconstruction for auto encoder')
    parser.add_argument('--arch', default='vgg16', type=str,
                        help='backbone architecture')
    parser.add_argument('--resume', type=str)
    parser.add_argument('--val_list', type=str)

    args = parser.parse_args()

    args.parallel = 0
    args.batch_size = 1
    args.workers = 0

    return args
- "\n",
def main(args):
    """Plot up to 64 validation images next to their reconstructions.

    Each (input, reconstruction) pair occupies two adjacent cells of an
    8x16 subplot grid; the figure is saved to ``figs/reconstruction.jpg``.
    """
    print('=> torch version : {}'.format(torch.__version__))

    utils.init_seeds(1, cuda_deterministic=False)

    print('=> modeling the network ...')
    model = builder.BuildAutoEncoder(args)
    total_params = sum(p.numel() for p in model.parameters())
    print('=> num of params: {} ({}M)'.format(total_params, int(total_params * 4 / (1024*1024))))

    print('=> loading pth from {} ...'.format(args.resume))
    utils.load_dict(args.resume, model)

    print('=> building the dataloader ...')
    val_loader = dataloader.val_loader(args)

    to_pil = transforms.ToPILImage()
    max_images = 64  # 8 rows x 16 columns = 64 (input, output) pairs

    plt.figure(figsize=(16, 9))

    model.eval()
    print('=> reconstructing ...')
    with torch.no_grad():
        # the target equals the input for this dataset, so it is not moved to the GPU
        for i, (images, _) in enumerate(val_loader):
            images = images.cuda(non_blocking=True)
            output = model(images)

            # ToPILImage multiplies float tensors by 255 before the uint8 cast,
            # so reconstructions outside [0, 1] would overflow; clamp first
            input_img = to_pil(images.squeeze().cpu())
            output_img = to_pil(output.squeeze().clamp(0, 1).cpu())

            plt.subplot(8, 16, 2 * i + 1, xticks=[], yticks=[])
            plt.imshow(input_img)

            plt.subplot(8, 16, 2 * i + 2, xticks=[], yticks=[])
            plt.imshow(output_img)

            if i == max_images - 1:
                break

    # savefig does not create missing directories
    os.makedirs('figs', exist_ok=True)
    plt.savefig('figs/reconstruction.jpg')
- "\n",
if __name__ == '__main__':
    # read the command line, then draw the reconstructions
    args = get_args()
    main(args)
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
import os
import argparse

# Write list/<name>_list.txt ("<image path> <class index>" per line) and
# list/<name>_name.txt ("<class index> <folder name>" per line) for an
# ImageFolder-style tree: one sub-directory per class under --path.

parser = argparse.ArgumentParser()
parser.add_argument('--name', type=str)
parser.add_argument('--path', type=str)
args = parser.parse_args()

# Sort so class indices are deterministic. os.listdir order is arbitrary, and
# the train and val lists come from separate runs that must agree on labels.
# Plain files directly under --path are skipped; os.listdir on them used to crash.
folders = sorted(
    entry for entry in os.listdir(args.path)
    if os.path.isdir(os.path.join(args.path, entry))
)

os.makedirs("list", exist_ok=True)

list_path = os.path.join('list', args.name + '_list.txt')
name_path = os.path.join('list', args.name + '_name.txt')

# `with` closes both files even if a directory listing fails midway
with open(list_path, 'w') as fl, open(name_path, 'w') as fn:
    for i, folder in enumerate(folders):

        fn.write(str(i) + ' ' + folder + '\n')

        folder_path = os.path.join(args.path, folder)

        for file in sorted(os.listdir(folder_path)):
            fl.write('{} {}\n'.format(os.path.join(folder_path, file), i))
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "#!/usr/bin/env python\n",
- "\n",
- "import argparse\n",
- "\n",
- "from PIL import Image\n",
- "\n",
- "import torch\n",
- "from torchvision.transforms import transforms\n",
- "\n",
- "import sys\n",
- "sys.path.append(\"./\")\n",
- "\n",
- "import utils\n",
- "import models.builer as builder\n",
- "\n",
- "import os\n",
- "\n",
- "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
- "\n",
def get_args():
    """Parse options for encoding a single image.

    Forces ``parallel = 0``, ``batch_size = 1`` and ``workers = 0``.
    """
    print('=> parse the args ...')
    parser = argparse.ArgumentParser(description='Encoder for auto encoder')
    parser.add_argument('--arch', default='vgg16', type=str,
                        help='backbone architecture')
    parser.add_argument('--resume', type=str)
    parser.add_argument('--img_path', type=str)

    args = parser.parse_args()

    args.parallel = 0
    args.batch_size = 1
    args.workers = 0

    return args
- "\n",
def encode(model, img):
    """Return ``model.module.encoder(img)`` as a NumPy array, computed without gradients.

    NOTE(review): assumes ``model`` wraps the auto-encoder as ``.module``
    (e.g. a DataParallel wrapper built by the project's builder) -- confirm.
    """
    with torch.no_grad():
        latent = model.module.encoder(img)
    return latent.cpu().numpy()
- "\n",
def main(args):
    """Encode the image at ``args.img_path`` with a trained model and print the code's shape."""
    print('=> torch version : {}'.format(torch.__version__))

    utils.init_seeds(1, cuda_deterministic=False)

    print('=> modeling the network ...')
    model = builder.BuildAutoEncoder(args)
    n_params = sum(p.numel() for p in model.parameters())
    print('=> num of params: {} ({}M)'.format(n_params, int(n_params * 4 / (1024*1024))))

    print('=> loading pth from {} ...'.format(args.resume))
    utils.load_dict(args.resume, model)

    # same preprocessing as the validation loader
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
    picture = Image.open(args.img_path).convert("RGB")
    batch = preprocess(picture).unsqueeze(0).cuda()

    model.eval()
    code = encode(model, batch)
    print(code.shape)

    # TODO: any other postprocessing of `code`
- "\n",
if __name__ == '__main__':
    # read the command line, then encode the requested image
    args = get_args()
    main(args)
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "#!/usr/bin/env python\n",
- "\n",
- "import argparse\n",
- "\n",
- "import matplotlib.pyplot as plt\n",
- "\n",
- "import torch\n",
- "\n",
- "from torchvision.transforms import transforms\n",
- "\n",
- "import sys\n",
- "sys.path.append(\"./\")\n",
- "\n",
- "import utils\n",
- "import models.builer as builder\n",
- "\n",
- "import os\n",
- "\n",
- "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
- "\n",
def get_args():
    """Parse options for sampling images from random latent codes.

    Forces ``parallel = 0``, ``batch_size = 1`` and ``workers = 0``.
    """
    print('=> parse the args ...')
    # description was copy-pasted from the trainer
    parser = argparse.ArgumentParser(description='Generation for auto encoder')
    parser.add_argument('--arch', default='vgg16', type=str,
                        help='backbone architecture')
    parser.add_argument('--resume', type=str)

    args = parser.parse_args()

    args.parallel = 0
    args.batch_size = 1
    args.workers = 0

    return args
- "\n",
def random_sample(arch):
    """Draw one standard-normal latent shaped for ``arch``'s decoder.

    Returns a (1, C, 7, 7) tensor with C = 512 for VGG and ResNet-18/34, and
    C = 2048 for ResNet-50/101/152.

    Raises:
        NotImplementedError: for any other architecture name.
    """
    channel_table = (
        (("vgg11", "vgg13", "vgg16", "vgg19", "resnet18", "resnet34"), 512),
        (("resnet50", "resnet101", "resnet152"), 2048),
    )
    for names, channels in channel_table:
        if arch in names:
            return torch.randn((1, channels, 7, 7))
    raise NotImplementedError("Do not have implemention except VGG and ResNet")
- "\n",
def main(args):
    """Decode 128 random latent samples and save them as an 8x16 grid.

    Latents come from ``random_sample(args.arch)``; the grid is written to
    ``figs/generation.jpg``.
    """
    print('=> torch version : {}'.format(torch.__version__))

    utils.init_seeds(1, cuda_deterministic=False)

    print('=> modeling the network ...')
    model = builder.BuildAutoEncoder(args)
    total_params = sum(p.numel() for p in model.parameters())
    print('=> num of params: {} ({}M)'.format(total_params, int(total_params * 4 / (1024*1024))))

    print('=> loading pth from {} ...'.format(args.resume))
    utils.load_dict(args.resume, model)

    trans = transforms.ToPILImage()

    plt.figure(figsize=(16, 9))

    model.eval()
    print('=> Generating ...')
    with torch.no_grad():
        for i in range(128):  # 8 rows x 16 columns
            latent = random_sample(arch=args.arch).cuda()

            output = model.module.decoder(latent)

            # ToPILImage multiplies float tensors by 255 before the uint8 cast,
            # so decoded values outside [0, 1] would overflow; clamp first
            output = trans(output.squeeze().clamp(0, 1).cpu())

            plt.subplot(8, 16, i + 1, xticks=[], yticks=[])
            plt.imshow(output)

    # savefig does not create missing directories
    os.makedirs('figs', exist_ok=True)
    plt.savefig('figs/generation.jpg')
- "\n",
if __name__ == '__main__':
    # read the command line, then generate the sample grid
    args = get_args()
    main(args)
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch\n",
- "import torch.nn as nn\n",
- "\n",
def get_configs(arch='vgg16'):
    """Return the number of conv layers in each of VGG's five stages for ``arch``.

    A fresh list is returned on every call.

    Raises:
        ValueError: ``arch`` is not one of vgg11 / vgg13 / vgg16 / vgg19.
    """
    layer_table = (
        ('vgg11', [1, 1, 2, 2, 2]),
        ('vgg13', [2, 2, 2, 2, 2]),
        ('vgg16', [2, 2, 3, 3, 3]),
        ('vgg19', [2, 2, 4, 4, 4]),
    )
    for name, configs in layer_table:
        if arch == name:
            return configs
    raise ValueError("Undefined model")
- "\n",
class VGGAutoEncoder(nn.Module):
    """VGG-style auto-encoder whose decoder mirrors the encoder's stages.

    Batch norm is enabled in both halves because, per the original author,
    the VGG auto-encoder is hard to train without it.
    """

    def __init__(self, configs):
        super().__init__()
        self.encoder = VGGEncoder(configs=configs, enable_bn=True)
        # the decoder walks the stages in reverse order
        self.decoder = VGGDecoder(configs=configs[::-1], enable_bn=True)

    def forward(self, x):
        """Encode ``x``, then decode the code back to image space."""
        return self.decoder(self.encoder(x))
- "\n",
- "class VGG(nn.Module):\n",
- "    \"\"\"VGG classifier: a VGGEncoder followed by a three-layer fully connected head.\n",
- "\n",
- "    Args:\n",
- "        configs: per-stage conv-layer counts (see ``get_configs``).\n",
- "        num_classes: number of output logits.\n",
- "        img_size: input height and width in pixels (square input assumed).\n",
- "        enable_bn: add BatchNorm after every encoder conv.\n",
- "    \"\"\"\n",
- "\n",
- "    def __init__(self, configs, num_classes=1000, img_size=224, enable_bn=False):\n",
- "        super(VGG, self).__init__()\n",
- "\n",
- "        self.encoder = VGGEncoder(configs=configs, enable_bn=enable_bn)\n",
- "\n",
- "        # Kept unchanged for backward compatibility (e.g. 7.0 for 224).\n",
- "        self.img_size = img_size / 32\n",
- "\n",
- "        # The encoder applies five 2x2/stride-2 max-pools, each flooring the\n",
- "        # spatial size, so the final side is floor(img_size / 32). True\n",
- "        # division over-counts features when img_size % 32 != 0.\n",
- "        feat_size = int(img_size) // 32\n",
- "\n",
- "        self.fc = nn.Sequential(\n",
- "            nn.Linear(in_features=feat_size * feat_size * 512, out_features=4096),\n",
- "            nn.Dropout(p=0.5),\n",
- "            nn.ReLU(inplace=True),\n",
- "            nn.Linear(in_features=4096, out_features=4096),\n",
- "            nn.Dropout(p=0.5),\n",
- "            nn.ReLU(inplace=True),\n",
- "            nn.Linear(in_features=4096, out_features=num_classes)\n",
- "        )\n",
- "\n",
- "        # Small-normal weights and zero biases for every conv and linear layer.\n",
- "        for m in self.modules():\n",
- "            if isinstance(m, nn.Conv2d):\n",
- "                nn.init.normal_(m.weight, mean=0, std=0.01)\n",
- "                if m.bias is not None:\n",
- "                    nn.init.constant_(m.bias, 0)\n",
- "            if isinstance(m, nn.Linear):\n",
- "                nn.init.normal_(m.weight, mean=0, std=0.01)\n",
- "                nn.init.constant_(m.bias, 0)\n",
- "\n",
- "    def forward(self, x):\n",
- "        x = self.encoder(x)\n",
- "        # Flatten everything except the batch dimension for the FC head.\n",
- "        x = torch.flatten(x, 1)\n",
- "        x = self.fc(x)\n",
- "        return x\n",
- "\n",
- "class VGGEncoder(nn.Module):\n",
- "    \"\"\"Five EncoderBlock stages (conv1 ... conv5), each ending in a 2x2 max-pool.\"\"\"\n",
- "\n",
- "    def __init__(self, configs, enable_bn=False):\n",
- "        super().__init__()\n",
- "\n",
- "        if len(configs) != 5:\n",
- "            raise ValueError(\"There should be 5 stage in VGG\")\n",
- "\n",
- "        # (in_channels, out_channels) per stage; the hidden width equals out_channels.\n",
- "        channels = [(3, 64), (64, 128), (128, 256), (256, 512), (512, 512)]\n",
- "        for idx, (c_in, c_out) in enumerate(channels):\n",
- "            # Registered as conv1 ... conv5 so state_dict keys stay the same.\n",
- "            stage = EncoderBlock(input_dim=c_in, output_dim=c_out, hidden_dim=c_out,\n",
- "                                 layers=configs[idx], enable_bn=enable_bn)\n",
- "            setattr(self, 'conv%d' % (idx + 1), stage)\n",
- "\n",
- "    def forward(self, x):\n",
- "        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):\n",
- "            x = stage(x)\n",
- "        return x\n",
- "\n",
- "class VGGDecoder(nn.Module):\n",
- "    \"\"\"Five DecoderBlock stages (conv1 ... conv5) followed by a sigmoid gate.\"\"\"\n",
- "\n",
- "    def __init__(self, configs, enable_bn=False):\n",
- "        super().__init__()\n",
- "\n",
- "        if len(configs) != 5:\n",
- "            raise ValueError(\"There should be 5 stage in VGG\")\n",
- "\n",
- "        # (input_dim, hidden_dim, output_dim) for conv1 ... conv5.\n",
- "        stages = [\n",
- "            (512, 512, 512),\n",
- "            (512, 512, 256),\n",
- "            (256, 256, 128),\n",
- "            (128, 128, 64),\n",
- "            (64, 64, 3),\n",
- "        ]\n",
- "        for idx, (c_in, c_hidden, c_out) in enumerate(stages):\n",
- "            stage = DecoderBlock(input_dim=c_in, output_dim=c_out, hidden_dim=c_hidden,\n",
- "                                 layers=configs[idx], enable_bn=enable_bn)\n",
- "            setattr(self, 'conv%d' % (idx + 1), stage)\n",
- "        # Sigmoid maps the reconstruction into the (0, 1) range.\n",
- "        self.gate = nn.Sigmoid()\n",
- "\n",
- "    def forward(self, x):\n",
- "        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.gate):\n",
- "            x = stage(x)\n",
- "        return x\n",
- "\n",
- "class EncoderBlock(nn.Module):\n",
- "    \"\"\"One encoder stage: ``layers`` EncoderLayers followed by a 2x2 max-pool.\n",
- "\n",
- "    The first layer maps input_dim -> hidden_dim, the last maps\n",
- "    hidden_dim -> output_dim, and any layers in between keep hidden_dim.\n",
- "    A single layer maps input_dim -> output_dim directly.\n",
- "    \"\"\"\n",
- "\n",
- "    def __init__(self, input_dim, hidden_dim, output_dim, layers, enable_bn=False):\n",
- "        super().__init__()\n",
- "\n",
- "        last = layers - 1\n",
- "        for i in range(layers):\n",
- "            c_in = input_dim if i == 0 else hidden_dim\n",
- "            c_out = output_dim if i == last else hidden_dim\n",
- "            self.add_module('%d EncoderLayer' % i,\n",
- "                            EncoderLayer(input_dim=c_in, output_dim=c_out, enable_bn=enable_bn))\n",
- "\n",
- "        self.add_module('%d MaxPooling' % layers, nn.MaxPool2d(kernel_size=2, stride=2))\n",
- "\n",
- "    def forward(self, x):\n",
- "        # Sub-modules run in registration order: convs first, then pooling.\n",
- "        for layer in self.children():\n",
- "            x = layer(x)\n",
- "        return x\n",
- "\n",
- "class DecoderBlock(nn.Module):\n",
- "    \"\"\"One decoder stage: a 2x transposed-conv upsample, then ``layers`` DecoderLayers.\n",
- "\n",
- "    The upsample maps input_dim -> hidden_dim channels, so every following\n",
- "    DecoderLayer consumes hidden_dim channels; the last one emits output_dim.\n",
- "    \"\"\"\n",
- "\n",
- "    def __init__(self, input_dim, hidden_dim, output_dim, layers, enable_bn=False):\n",
- "        super(DecoderBlock, self).__init__()\n",
- "\n",
- "        # kernel_size=2, stride=2 doubles height and width.\n",
- "        upsample = nn.ConvTranspose2d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=2, stride=2)\n",
- "        self.add_module('0 UpSampling', upsample)\n",
- "\n",
- "        for i in range(layers):\n",
- "            # Input is always hidden_dim: the first layer reads the upsample\n",
- "            # output, not the block input. Using input_dim here breaks any\n",
- "            # block whose hidden_dim differs from input_dim.\n",
- "            c_out = output_dim if i == layers - 1 else hidden_dim\n",
- "            layer = DecoderLayer(input_dim=hidden_dim, output_dim=c_out, enable_bn=enable_bn)\n",
- "            self.add_module('%d DecoderLayer' % (i + 1), layer)\n",
- "\n",
- "    def forward(self, x):\n",
- "        # Sub-modules run in registration order: upsample first, then convs.\n",
- "        for name, layer in self.named_children():\n",
- "            x = layer(x)\n",
- "        return x\n",
- "\n",
- "class EncoderLayer(nn.Module):\n",
- "    \"\"\"3x3 stride-1 padding-1 conv, optional BatchNorm, then in-place ReLU.\"\"\"\n",
- "\n",
- "    def __init__(self, input_dim, output_dim, enable_bn):\n",
- "        super().__init__()\n",
- "\n",
- "        modules = [nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1)]\n",
- "        if enable_bn:\n",
- "            modules.append(nn.BatchNorm2d(output_dim))\n",
- "        modules.append(nn.ReLU(inplace=True))\n",
- "        self.layer = nn.Sequential(*modules)\n",
- "\n",
- "    def forward(self, x):\n",
- "        return self.layer(x)\n",
- "\n",
- "class DecoderLayer(nn.Module):\n",
- "    \"\"\"Optional BatchNorm, in-place ReLU, then a 3x3 stride-1 padding-1 conv.\"\"\"\n",
- "\n",
- "    def __init__(self, input_dim, output_dim, enable_bn):\n",
- "        super().__init__()\n",
- "\n",
- "        # Normalisation/activation come before the conv in this layer.\n",
- "        pre_act = [nn.BatchNorm2d(input_dim)] if enable_bn else []\n",
- "        conv = nn.Conv2d(in_channels=input_dim, out_channels=output_dim, kernel_size=3, stride=1, padding=1)\n",
- "        self.layer = nn.Sequential(*pre_act, nn.ReLU(inplace=True), conv)\n",
- "\n",
- "    def forward(self, x):\n",
- "        return self.layer(x)\n",
- "\n",
- "if __name__ == \"__main__\":\n",
- "    # Smoke test: push a random batch of five 3x224x224 tensors through the\n",
- "    # default (vgg16) autoencoder and report the output shape.\n",
- "    batch = torch.randn((5, 3, 224, 224))\n",
- "    model = VGGAutoEncoder(get_configs())\n",
- "    print(model(batch).shape)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch.nn as nn\n",
- "import torch.nn.parallel as parallel\n",
- "\n",
- "from . import vgg, resnet\n",
- "\n",
- "def BuildAutoEncoder(args):\n",
- "    \"\"\"Build the autoencoder named by ``args.arch`` and wrap it for multi-GPU use.\n",
- "\n",
- "    Reads ``args.arch`` and ``args.parallel``, plus ``args.gpu`` when\n",
- "    ``args.parallel == 1``.\n",
- "\n",
- "    Returns:\n",
- "        The model converted to SyncBatchNorm and wrapped in\n",
- "        DistributedDataParallel when ``args.parallel == 1``; otherwise wrapped\n",
- "        in DataParallel and moved with ``.cuda()``.\n",
- "\n",
- "    Raises:\n",
- "        ValueError: if ``args.arch`` is not a supported architecture.\n",
- "    \"\"\"\n",
- "    if args.arch in [\"vgg11\", \"vgg13\", \"vgg16\", \"vgg19\"]:\n",
- "        configs = vgg.get_configs(args.arch)\n",
- "        model = vgg.VGGAutoEncoder(configs)\n",
- "\n",
- "    elif args.arch in [\"resnet18\", \"resnet34\", \"resnet50\", \"resnet101\", \"resnet152\"]:\n",
- "        configs, bottleneck = resnet.get_configs(args.arch)\n",
- "        model = resnet.ResNetAutoEncoder(configs, bottleneck)\n",
- "\n",
- "    else:\n",
- "        # Fail fast: returning None only surfaces later as an AttributeError\n",
- "        # in the caller (e.g. on model.parameters()).\n",
- "        raise ValueError(\"Undefined model: {}\".format(args.arch))\n",
- "\n",
- "    if args.parallel == 1:\n",
- "        # BatchNorm statistics must be synchronised across DDP processes.\n",
- "        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)\n",
- "        model = parallel.DistributedDataParallel(\n",
- "            model.to(args.gpu),\n",
- "            device_ids=[args.gpu],\n",
- "            output_device=args.gpu\n",
- "        )\n",
- "    else:\n",
- "        model = nn.DataParallel(model).cuda()\n",
- "\n",
- "    return model"
- ]
- }
- ],
- "metadata": {
- "language_info": {
- "name": "python"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 2
- }
|