train.py
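
# Trains and validates a semantic-line detection network (model/network.py):
# the model predicts a keypoint map over Hough parameter space (angle x rho),
# supervised with per-cell binary cross-entropy against Hough-space labels.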

import argparse
import os
import random
import shutil
import time
from os.path import isfile, join, split

import torch
import torchvision
import torch.nn as nn
import torch.backends.cudnn as cudnn
import numpy as np
import torch.optim
import tqdm
import yaml
from torch.optim import lr_scheduler
from skimage.measure import label, regionprops
from tensorboardX import SummaryWriter

from logger import Logger
from dataloader import get_loader
from model.network import Net
from mask_component_utils import reverse_mapping, edge_align
from hungarian_matching import caculate_tp_fp_fn

parser = argparse.ArgumentParser(description='PyTorch Semantic-Line Training')
# arguments from command line
parser.add_argument('--config', default="./config.yml", help="path to config file")
parser.add_argument('--resume', default="", help="path to checkpoint to resume from")
parser.add_argument('--tmp', default="", help="directory for logs, checkpoints and visualizations")
args = parser.parse_args()

assert os.path.isfile(args.config)
CONFIGS = yaml.safe_load(open(args.config))

# merge command-line overrides into the config
if args.tmp != "" and args.tmp != CONFIGS["MISC"]["TMP"]:
    CONFIGS["MISC"]["TMP"] = args.tmp
# YAML may parse scientific notation (e.g. 1e-4) as a string; force float
CONFIGS["OPTIMIZER"]["WEIGHT_DECAY"] = float(CONFIGS["OPTIMIZER"]["WEIGHT_DECAY"])
CONFIGS["OPTIMIZER"]["LR"] = float(CONFIGS["OPTIMIZER"]["LR"])

os.makedirs(CONFIGS["MISC"]["TMP"], exist_ok=True)
logger = Logger(os.path.join(CONFIGS["MISC"]["TMP"], "log.txt"))
logger.info(CONFIGS)


def main():
    logger.info(args)
    assert os.path.isdir(CONFIGS["DATA"]["DIR"])

    if CONFIGS['TRAIN']['SEED'] is not None:
        random.seed(CONFIGS['TRAIN']['SEED'])
        torch.manual_seed(CONFIGS['TRAIN']['SEED'])
        cudnn.deterministic = True
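
    # The network outputs a NUMANGLE x NUMRHO keypoint map over Hough
    # parameter space; each cell corresponds to one candidate line (angle, rho).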
    model = Net(numAngle=CONFIGS["MODEL"]["NUMANGLE"], numRho=CONFIGS["MODEL"]["NUMRHO"], backbone=CONFIGS["MODEL"]["BACKBONE"])
    if CONFIGS["TRAIN"]["DATA_PARALLEL"]:
        logger.info("Model Data Parallel")
        model = nn.DataParallel(model).cuda()
    else:
        model = model.cuda(device=CONFIGS["TRAIN"]["GPU_ID"])

    # optimizer
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=CONFIGS["OPTIMIZER"]["LR"],
        weight_decay=CONFIGS["OPTIMIZER"]["WEIGHT_DECAY"]
    )

    # learning rate scheduler
    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=CONFIGS["OPTIMIZER"]["STEPS"],
                                         gamma=CONFIGS["OPTIMIZER"]["GAMMA"])
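    # MultiStepLR multiplies the learning rate by GAMMA each time the epoch
    # counter reaches one of the milestones in OPTIMIZER.STEPS.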

    best_acc1 = 0
    start_epoch = 0
    if args.resume:
        if isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # resume the epoch counter and best accuracy from the checkpoint
            start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            # optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})"
                        .format(args.resume, checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(args.resume))

    # dataloaders
    train_loader = get_loader(CONFIGS["DATA"]["DIR"], CONFIGS["DATA"]["LABEL_FILE"],
                              batch_size=CONFIGS["DATA"]["BATCH_SIZE"], num_thread=CONFIGS["DATA"]["WORKERS"], split='train')
    val_loader = get_loader(CONFIGS["DATA"]["VAL_DIR"], CONFIGS["DATA"]["VAL_LABEL_FILE"],
                            batch_size=1, num_thread=CONFIGS["DATA"]["WORKERS"], split='val')
    logger.info("Data loading done.")

    # Tensorboard summary
    writer = SummaryWriter(log_dir=CONFIGS["MISC"]["TMP"])

    best_acc = best_acc1
    is_best = False
    start_time = time.time()

    if CONFIGS["TRAIN"]["RESUME"] is not None:
        raise NotImplementedError

    if CONFIGS["TRAIN"]["TEST"]:
        validate(val_loader, model, 0, writer, args)
        return

    logger.info("Start training.")
    for epoch in range(start_epoch, CONFIGS["TRAIN"]["EPOCHS"]):
        train(train_loader, model, optimizer, epoch, writer, args)
        acc = validate(val_loader, model, epoch, writer, args)
        scheduler.step()

        if best_acc < acc:
            is_best = True
            best_acc = acc
        else:
            is_best = False

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc,
            'optimizer': optimizer.state_dict()
        }, is_best, path=CONFIGS["MISC"]["TMP"])

        # estimate elapsed and remaining time from the mean time per epoch
        t = time.time() - start_time
        elapsed = DayHourMinute(t)
        t /= (epoch + 1) - start_epoch  # seconds per epoch
        t = (CONFIGS["TRAIN"]["EPOCHS"] - epoch - 1) * t
        remaining = DayHourMinute(t)
        logger.info("Epoch {0}/{1} finished, auxiliaries saved to {2} .\t"
                    "Elapsed {elapsed.days:d} days {elapsed.hours:d} hours {elapsed.minutes:d} minutes.\t"
                    "Remaining {remaining.days:d} days {remaining.hours:d} hours {remaining.minutes:d} minutes.".format(
                        epoch, CONFIGS["TRAIN"]["EPOCHS"], CONFIGS["MISC"]["TMP"], elapsed=elapsed, remaining=remaining))

    logger.info("Optimization done, ALL results saved to %s." % CONFIGS["MISC"]["TMP"])


def train(train_loader, model, optimizer, epoch, writer, args):
    # switch to train mode
    model.train()

    bar = tqdm.tqdm(train_loader)
    iter_num = len(train_loader.dataset) // CONFIGS["DATA"]["BATCH_SIZE"]
    total_loss_hough = 0
    for i, data in enumerate(bar):
        images, hough_space_label, _, names = data
        if CONFIGS["TRAIN"]["DATA_PARALLEL"]:
            images = images.cuda()
            hough_space_label = hough_space_label.cuda()
        else:
            images = images.cuda(device=CONFIGS["TRAIN"]["GPU_ID"])
            hough_space_label = hough_space_label.cuda(device=CONFIGS["TRAIN"]["GPU_ID"])

        keypoint_map = model(images)
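        # per-cell binary cross-entropy (computed from logits) between the
        # predicted Hough-space keypoint map and the Hough-space label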
        hough_space_loss = torch.nn.functional.binary_cross_entropy_with_logits(keypoint_map, hough_space_label)
        writer.add_scalar('train/hough_space_loss', hough_space_loss.item(), epoch * iter_num + i)

        loss = hough_space_loss
        if not torch.isnan(hough_space_loss):
            total_loss_hough += hough_space_loss.item()
        else:
            logger.info("Warning: loss is NaN.")

        # record loss
        bar.set_description('Training Loss:{}'.format(loss.item()))

        # compute gradient and do an optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % CONFIGS["TRAIN"]["PRINT_FREQ"] == 0:
            visualize_save_path = os.path.join(CONFIGS["MISC"]["TMP"], 'visualize', str(epoch))
            os.makedirs(visualize_save_path, exist_ok=True)
            # Do visualization.
            # torchvision.utils.save_image(torch.sigmoid(keypoint_map), join(visualize_save_path, 'rodon_'+names[0]), normalize=True)
            # torchvision.utils.save_image(torch.sum(vis, dim=1, keepdim=True), join(visualize_save_path, 'vis_'+names[0]), normalize=True)

    total_loss_hough = total_loss_hough / iter_num
    writer.add_scalar('train/total_loss_hough', total_loss_hough, epoch)


def validate(val_loader, model, epoch, writer, args):
    # switch to evaluate mode
    model.eval()

    total_loss_hough = 0
    # TP/FP/FN accumulated per matching threshold (0.01 .. 0.99)
    total_tp = np.zeros(99)
    total_fp = np.zeros(99)
    total_fn = np.zeros(99)
    total_tp_align = np.zeros(99)
    total_fp_align = np.zeros(99)
    total_fn_align = np.zeros(99)

    with torch.no_grad():
        bar = tqdm.tqdm(val_loader)
        iter_num = len(val_loader.dataset)  # validation batch size is 1
        for i, data in enumerate(bar):
            images, hough_space_label8, gt_coords, names = data
            if CONFIGS["TRAIN"]["DATA_PARALLEL"]:
                images = images.cuda()
                hough_space_label8 = hough_space_label8.cuda()
            else:
                images = images.cuda(device=CONFIGS["TRAIN"]["GPU_ID"])
                hough_space_label8 = hough_space_label8.cuda(device=CONFIGS["TRAIN"]["GPU_ID"])

            keypoint_map = model(images)
            hough_space_loss = torch.nn.functional.binary_cross_entropy_with_logits(keypoint_map, hough_space_label8)
            writer.add_scalar('val/hough_space_loss', hough_space_loss.item(), epoch * iter_num + i)

            loss = hough_space_loss
            if not torch.isnan(loss):
                total_loss_hough += loss.item()
            else:
                logger.info("Warning: val loss is NaN.")
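
            # Post-processing: threshold the sigmoid of the keypoint map, group
            # surviving cells into connected components, and map each component
            # centroid from Hough space back to a line segment in the image.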
            key_points = torch.sigmoid(keypoint_map)
            binary_kmap = key_points.squeeze().cpu().numpy() > CONFIGS['MODEL']['THRESHOLD']
            kmap_label = label(binary_kmap, connectivity=1)
            props = regionprops(kmap_label)
            plist = []
            for prop in props:
                plist.append(prop.centroid)
            b_points = reverse_mapping(plist, numAngle=CONFIGS["MODEL"]["NUMANGLE"], numRho=CONFIGS["MODEL"]["NUMRHO"], size=(400, 400))
            # b_points: predicted lines as [[y1, x1, y2, x2], ...]
            gt_coords = gt_coords[0].tolist()
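
            # Sweep the matching threshold from 0.01 to 0.99; caculate_tp_fp_fn
            # (Hungarian matching) returns TP/FP/FN counts for the predicted
            # lines against the ground truth at each threshold.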
            for t_idx in range(1, 100):
                tp, fp, fn = caculate_tp_fp_fn(b_points, gt_coords, thresh=t_idx * 0.01)
                total_tp[t_idx - 1] += tp
                total_fp[t_idx - 1] += fp
                total_fn[t_idx - 1] += fn
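
            # Optionally refine each predicted line with edge_align (from
            # mask_component_utils) and repeat the same threshold sweep on
            # the refined predictions.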
            if CONFIGS["MODEL"]["EDGE_ALIGN"]:
                for k in range(len(b_points)):
                    b_points[k] = edge_align(b_points[k], names[0], division=5)
                for t_idx in range(1, 100):
                    tp, fp, fn = caculate_tp_fp_fn(b_points, gt_coords, thresh=t_idx * 0.01)
                    total_tp_align[t_idx - 1] += tp
                    total_fp_align[t_idx - 1] += fp
                    total_fn_align[t_idx - 1] += fn

    total_loss_hough = total_loss_hough / iter_num
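    # Per-threshold precision/recall/F-measure arrays; the 1e-8 epsilon guards
    # against division by zero. Index 94 corresponds to threshold 0.95.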
    total_recall = total_tp / (total_tp + total_fn + 1e-8)
    total_precision = total_tp / (total_tp + total_fp + 1e-8)
    f = 2 * total_recall * total_precision / (total_recall + total_precision + 1e-8)

    writer.add_scalar('val/total_loss_hough', total_loss_hough, epoch)
    writer.add_scalar('val/total_precision', total_precision.mean(), epoch)
    writer.add_scalar('val/total_recall', total_recall.mean(), epoch)
    logger.info('Validation result: ==== Precision: %.5f, Recall: %.5f' % (total_precision.mean(), total_recall.mean()))
    acc = f.mean()
    logger.info('Validation result: ==== F-measure: %.5f' % acc)
    logger.info('Validation result: ==== F-measure@0.95: %.5f' % f[95 - 1])
    writer.add_scalar('val/f-measure', acc, epoch)
    writer.add_scalar('val/f-measure@0.95', f[95 - 1], epoch)

    if CONFIGS["MODEL"]["EDGE_ALIGN"]:
        total_recall_align = total_tp_align / (total_tp_align + total_fn_align + 1e-8)
        total_precision_align = total_tp_align / (total_tp_align + total_fp_align + 1e-8)
        f_align = 2 * total_recall_align * total_precision_align / (total_recall_align + total_precision_align + 1e-8)
        writer.add_scalar('val/total_precision_align', total_precision_align.mean(), epoch)
        writer.add_scalar('val/total_recall_align', total_recall_align.mean(), epoch)
        logger.info('Validation result (Aligned): ==== Precision: %.5f, Recall: %.5f' % (total_precision_align.mean(), total_recall_align.mean()))
        # when edge alignment is enabled, the aligned F-measure is returned
        acc = f_align.mean()
        logger.info('Validation result (Aligned): ==== F-measure: %.5f' % acc)
        logger.info('Validation result (Aligned): ==== F-measure@0.95: %.5f' % f_align[95 - 1])
        # distinct tags so the aligned scores do not overwrite the unaligned ones
        writer.add_scalar('val/f-measure_align', acc, epoch)
        writer.add_scalar('val/f-measure@0.95_align', f_align[95 - 1], epoch)

    return acc


def save_checkpoint(state, is_best, path, filename='checkpoint.pth.tar'):
    torch.save(state, os.path.join(path, filename))
    if is_best:
        shutil.copyfile(os.path.join(path, filename), os.path.join(path, 'model_best.pth'))


def get_lr(optimizer):
    # all parameter groups follow one schedule here, so the first group's lr is representative
    for param_group in optimizer.param_groups:
        return param_group['lr']


class DayHourMinute(object):
    # break a duration in seconds into whole days, hours and minutes
    def __init__(self, seconds):
        self.days = int(seconds // 86400)
        self.hours = int((seconds - self.days * 86400) // 3600)
        self.minutes = int((seconds - self.days * 86400 - self.hours * 3600) // 60)


if __name__ == '__main__':
    main()