transforms.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. # ------------------------------------------------------------------------
  2. # Copyright (c) 2021 megvii-model. All Rights Reserved.
  3. # ------------------------------------------------------------------------
  4. # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
  5. # Copyright (c) 2020 SenseTime. All Rights Reserved.
  6. # ------------------------------------------------------------------------
  7. # Modified from DETR (https://github.com/facebookresearch/detr)
  8. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  9. # ------------------------------------------------------------------------
  10. """
  11. Transforms and data augmentation for both image + bbox.
  12. """
  13. import copy
  14. import random
  15. import PIL
  16. import torch
  17. import torchvision.transforms as T
  18. import torchvision.transforms.functional as F
  19. from PIL import Image, ImageDraw
  20. from util.box_ops import box_xyxy_to_cxcywh
  21. from util.misc import interpolate
  22. import numpy as np
  23. import os
  24. def crop_mot(image, target, region):
  25. cropped_image = F.crop(image, *region)
  26. target = target.copy()
  27. i, j, h, w = region
  28. # should we do something wrt the original size?
  29. target["size"] = torch.tensor([h, w])
  30. fields = ["labels", "area", "iscrowd"]
  31. if 'obj_ids' in target:
  32. fields.append('obj_ids')
  33. if "boxes" in target:
  34. boxes = target["boxes"]
  35. max_size = torch.as_tensor([w, h], dtype=torch.float32)
  36. cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
  37. for i, box in enumerate(cropped_boxes):
  38. l, t, r, b = box
  39. # if l < 0:
  40. # l = 0
  41. # if r < 0:
  42. # r = 0
  43. # if l > w:
  44. # l = w
  45. # if r > w:
  46. # r = w
  47. # if t < 0:
  48. # t = 0
  49. # if b < 0:
  50. # b = 0
  51. # if t > h:
  52. # t = h
  53. # if b > h:
  54. # b = h
  55. if l < 0 and r < 0:
  56. l = r = 0
  57. if l > w and r > w:
  58. l = r = w
  59. if t < 0 and b < 0:
  60. t = b = 0
  61. if t > h and b > h:
  62. t = b = h
  63. cropped_boxes[i] = torch.tensor([l, t, r, b], dtype=box.dtype)
  64. cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
  65. cropped_boxes = cropped_boxes.clamp(min=0)
  66. area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
  67. target["boxes"] = cropped_boxes.reshape(-1, 4)
  68. target["area"] = area
  69. fields.append("boxes")
  70. if "masks" in target:
  71. # FIXME should we update the area here if there are no boxes?
  72. target['masks'] = target['masks'][:, i:i + h, j:j + w]
  73. fields.append("masks")
  74. # remove elements for which the boxes or masks that have zero area
  75. if "boxes" in target or "masks" in target:
  76. # favor boxes selection when defining which elements to keep
  77. # this is compatible with previous implementation
  78. if "boxes" in target:
  79. cropped_boxes = target['boxes'].reshape(-1, 2, 2)
  80. keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
  81. else:
  82. keep = target['masks'].flatten(1).any(1)
  83. for field in fields:
  84. target[field] = target[field][keep]
  85. return cropped_image, target
  86. def random_shift(image, target, region, sizes):
  87. oh, ow = sizes
  88. # step 1, shift crop and re-scale image firstly
  89. cropped_image = F.crop(image, *region)
  90. cropped_image = F.resize(cropped_image, sizes)
  91. target = target.copy()
  92. i, j, h, w = region
  93. # should we do something wrt the original size?
  94. target["size"] = torch.tensor([h, w])
  95. fields = ["labels", "area", "iscrowd"]
  96. if 'obj_ids' in target:
  97. fields.append('obj_ids')
  98. if "boxes" in target:
  99. boxes = target["boxes"]
  100. max_size = torch.as_tensor([w, h], dtype=torch.float32)
  101. cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
  102. for i, box in enumerate(cropped_boxes):
  103. l, t, r, b = box
  104. if l < 0:
  105. l = 0
  106. if r < 0:
  107. r = 0
  108. if l > w:
  109. l = w
  110. if r > w:
  111. r = w
  112. if t < 0:
  113. t = 0
  114. if b < 0:
  115. b = 0
  116. if t > h:
  117. t = h
  118. if b > h:
  119. b = h
  120. # step 2, re-scale coords secondly
  121. ratio_h = 1.0 * oh / h
  122. ratio_w = 1.0 * ow / w
  123. cropped_boxes[i] = torch.tensor([ratio_w * l, ratio_h * t, ratio_w * r, ratio_h * b], dtype=box.dtype)
  124. cropped_boxes = cropped_boxes.reshape(-1, 2, 2)
  125. area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
  126. target["boxes"] = cropped_boxes.reshape(-1, 4)
  127. target["area"] = area
  128. fields.append("boxes")
  129. if "masks" in target:
  130. # FIXME should we update the area here if there are no boxes?
  131. target['masks'] = target['masks'][:, i:i + h, j:j + w]
  132. fields.append("masks")
  133. # remove elements for which the boxes or masks that have zero area
  134. if "boxes" in target or "masks" in target:
  135. # favor boxes selection when defining which elements to keep
  136. # this is compatible with previous implementation
  137. if "boxes" in target:
  138. cropped_boxes = target['boxes'].reshape(-1, 2, 2)
  139. keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
  140. else:
  141. keep = target['masks'].flatten(1).any(1)
  142. for field in fields:
  143. target[field] = target[field][keep]
  144. return cropped_image, target
  145. def crop(image, target, region):
  146. cropped_image = F.crop(image, *region)
  147. target = target.copy()
  148. i, j, h, w = region
  149. # should we do something wrt the original size?
  150. target["size"] = torch.tensor([h, w])
  151. fields = ["labels", "area", "iscrowd"]
  152. if 'obj_ids' in target:
  153. fields.append('obj_ids')
  154. if "boxes" in target:
  155. boxes = target["boxes"]
  156. max_size = torch.as_tensor([w, h], dtype=torch.float32)
  157. cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
  158. cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
  159. cropped_boxes = cropped_boxes.clamp(min=0)
  160. area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
  161. target["boxes"] = cropped_boxes.reshape(-1, 4)
  162. target["area"] = area
  163. fields.append("boxes")
  164. if "masks" in target:
  165. # FIXME should we update the area here if there are no boxes?
  166. target['masks'] = target['masks'][:, i:i + h, j:j + w]
  167. fields.append("masks")
  168. # remove elements for which the boxes or masks that have zero area
  169. if "boxes" in target or "masks" in target:
  170. # favor boxes selection when defining which elements to keep
  171. # this is compatible with previous implementation
  172. if "boxes" in target:
  173. cropped_boxes = target['boxes'].reshape(-1, 2, 2)
  174. keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
  175. else:
  176. keep = target['masks'].flatten(1).any(1)
  177. for field in fields:
  178. target[field] = target[field][keep]
  179. return cropped_image, target
  180. def hflip(image, target):
  181. flipped_image = F.hflip(image)
  182. w, h = image.size
  183. target = target.copy()
  184. if "boxes" in target:
  185. boxes = target["boxes"]
  186. boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
  187. target["boxes"] = boxes
  188. if "masks" in target:
  189. target['masks'] = target['masks'].flip(-1)
  190. return flipped_image, target
  191. def resize(image, target, size, max_size=None):
  192. # size can be min_size (scalar) or (w, h) tuple
  193. def get_size_with_aspect_ratio(image_size, size, max_size=None):
  194. w, h = image_size
  195. if max_size is not None:
  196. min_original_size = float(min((w, h)))
  197. max_original_size = float(max((w, h)))
  198. if max_original_size / min_original_size * size > max_size:
  199. size = int(round(max_size * min_original_size / max_original_size))
  200. if (w <= h and w == size) or (h <= w and h == size):
  201. return (h, w)
  202. if w < h:
  203. ow = size
  204. oh = int(size * h / w)
  205. else:
  206. oh = size
  207. ow = int(size * w / h)
  208. return (oh, ow)
  209. def get_size(image_size, size, max_size=None):
  210. if isinstance(size, (list, tuple)):
  211. return size[::-1]
  212. else:
  213. return get_size_with_aspect_ratio(image_size, size, max_size)
  214. size = get_size(image.size, size, max_size)
  215. rescaled_image = F.resize(image, size)
  216. if target is None:
  217. return rescaled_image, None
  218. ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
  219. ratio_width, ratio_height = ratios
  220. target = target.copy()
  221. if "boxes" in target:
  222. boxes = target["boxes"]
  223. scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
  224. target["boxes"] = scaled_boxes
  225. if "area" in target:
  226. area = target["area"]
  227. scaled_area = area * (ratio_width * ratio_height)
  228. target["area"] = scaled_area
  229. h, w = size
  230. target["size"] = torch.tensor([h, w])
  231. if "masks" in target:
  232. target['masks'] = interpolate(
  233. target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5
  234. return rescaled_image, target
  235. def pad(image, target, padding):
  236. # assumes that we only pad on the bottom right corners
  237. padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
  238. if target is None:
  239. return padded_image, None
  240. target = target.copy()
  241. # should we do something wrt the original size?
  242. target["size"] = torch.tensor(padded_image[::-1])
  243. if "masks" in target:
  244. target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
  245. return padded_image, target
  246. class RandomCrop(object):
  247. def __init__(self, size):
  248. self.size = size
  249. def __call__(self, img, target):
  250. region = T.RandomCrop.get_params(img, self.size)
  251. return crop(img, target, region)
  252. class MotRandomCrop(RandomCrop):
  253. def __call__(self, imgs: list, targets: list):
  254. ret_imgs = []
  255. ret_targets = []
  256. region = T.RandomCrop.get_params(imgs[0], self.size)
  257. for img_i, targets_i in zip(imgs, targets):
  258. img_i, targets_i = crop(img_i, targets_i, region)
  259. ret_imgs.append(img_i)
  260. ret_targets.append(targets_i)
  261. return ret_imgs, ret_targets
  262. class FixedMotRandomCrop(object):
  263. def __init__(self, min_size: int, max_size: int):
  264. self.min_size = min_size
  265. self.max_size = max_size
  266. def __call__(self, imgs: list, targets: list):
  267. ret_imgs = []
  268. ret_targets = []
  269. w = random.randint(self.min_size, min(imgs[0].width, self.max_size))
  270. h = random.randint(self.min_size, min(imgs[0].height, self.max_size))
  271. region = T.RandomCrop.get_params(imgs[0], [h, w])
  272. for img_i, targets_i in zip(imgs, targets):
  273. img_i, targets_i = crop_mot(img_i, targets_i, region)
  274. ret_imgs.append(img_i)
  275. ret_targets.append(targets_i)
  276. return ret_imgs, ret_targets
  277. class MotRandomShift(object):
  278. def __init__(self, bs=1):
  279. self.bs = bs
  280. def __call__(self, imgs: list, targets: list):
  281. ret_imgs = copy.deepcopy(imgs)
  282. ret_targets = copy.deepcopy(targets)
  283. n_frames = len(imgs)
  284. select_i = random.choice(list(range(n_frames)))
  285. w, h = imgs[select_i].size
  286. xshift = (100 * torch.rand(self.bs)).int()
  287. xshift *= (torch.randn(self.bs) > 0.0).int() * 2 - 1
  288. yshift = (100 * torch.rand(self.bs)).int()
  289. yshift *= (torch.randn(self.bs) > 0.0).int() * 2 - 1
  290. ymin = max(0, -yshift[0])
  291. ymax = min(h, h - yshift[0])
  292. xmin = max(0, -xshift[0])
  293. xmax = min(w, w - xshift[0])
  294. region = (int(ymin), int(xmin), int(ymax-ymin), int(xmax-xmin))
  295. ret_imgs[select_i], ret_targets[select_i] = random_shift(imgs[select_i], targets[select_i], region, (h,w))
  296. return ret_imgs, ret_targets
  297. class FixedMotRandomShift(object):
  298. def __init__(self, bs=1, padding=50):
  299. self.bs = bs
  300. self.padding = padding
  301. def __call__(self, imgs: list, targets: list):
  302. ret_imgs = []
  303. ret_targets = []
  304. n_frames = len(imgs)
  305. w, h = imgs[0].size
  306. xshift = (self.padding * torch.rand(self.bs)).int() + 1
  307. xshift *= (torch.randn(self.bs) > 0.0).int() * 2 - 1
  308. yshift = (self.padding * torch.rand(self.bs)).int() + 1
  309. yshift *= (torch.randn(self.bs) > 0.0).int() * 2 - 1
  310. ret_imgs.append(imgs[0])
  311. ret_targets.append(targets[0])
  312. for i in range(1, n_frames):
  313. ymin = max(0, -yshift[0])
  314. ymax = min(h, h - yshift[0])
  315. xmin = max(0, -xshift[0])
  316. xmax = min(w, w - xshift[0])
  317. prev_img = ret_imgs[i-1].copy()
  318. prev_target = copy.deepcopy(ret_targets[i-1])
  319. region = (int(ymin), int(xmin), int(ymax - ymin), int(xmax - xmin))
  320. img_i, target_i = random_shift(prev_img, prev_target, region, (h, w))
  321. ret_imgs.append(img_i)
  322. ret_targets.append(target_i)
  323. return ret_imgs, ret_targets
  324. class RandomSizeCrop(object):
  325. def __init__(self, min_size: int, max_size: int):
  326. self.min_size = min_size
  327. self.max_size = max_size
  328. def __call__(self, img: PIL.Image.Image, target: dict):
  329. w = random.randint(self.min_size, min(img.width, self.max_size))
  330. h = random.randint(self.min_size, min(img.height, self.max_size))
  331. region = T.RandomCrop.get_params(img, [h, w])
  332. return crop(img, target, region)
  333. class MotRandomSizeCrop(RandomSizeCrop):
  334. def __call__(self, imgs, targets):
  335. w = random.randint(self.min_size, min(imgs[0].width, self.max_size))
  336. h = random.randint(self.min_size, min(imgs[0].height, self.max_size))
  337. region = T.RandomCrop.get_params(imgs[0], [h, w])
  338. ret_imgs = []
  339. ret_targets = []
  340. for img_i, targets_i in zip(imgs, targets):
  341. img_i, targets_i = crop(img_i, targets_i, region)
  342. ret_imgs.append(img_i)
  343. ret_targets.append(targets_i)
  344. return ret_imgs, ret_targets
  345. class CenterCrop(object):
  346. def __init__(self, size):
  347. self.size = size
  348. def __call__(self, img, target):
  349. image_width, image_height = img.size
  350. crop_height, crop_width = self.size
  351. crop_top = int(round((image_height - crop_height) / 2.))
  352. crop_left = int(round((image_width - crop_width) / 2.))
  353. return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
  354. class MotCenterCrop(CenterCrop):
  355. def __call__(self, imgs, targets):
  356. image_width, image_height = imgs[0].size
  357. crop_height, crop_width = self.size
  358. crop_top = int(round((image_height - crop_height) / 2.))
  359. crop_left = int(round((image_width - crop_width) / 2.))
  360. ret_imgs = []
  361. ret_targets = []
  362. for img_i, targets_i in zip(imgs, targets):
  363. img_i, targets_i = crop(img_i, targets_i, (crop_top, crop_left, crop_height, crop_width))
  364. ret_imgs.append(img_i)
  365. ret_targets.append(targets_i)
  366. return ret_imgs, ret_targets
  367. class RandomHorizontalFlip(object):
  368. def __init__(self, p=0.5):
  369. self.p = p
  370. def __call__(self, img, target):
  371. if random.random() < self.p:
  372. return hflip(img, target)
  373. return img, target
  374. class MotRandomHorizontalFlip(RandomHorizontalFlip):
  375. def __call__(self, imgs, targets):
  376. if random.random() < self.p:
  377. ret_imgs = []
  378. ret_targets = []
  379. for img_i, targets_i in zip(imgs, targets):
  380. img_i, targets_i = hflip(img_i, targets_i)
  381. ret_imgs.append(img_i)
  382. ret_targets.append(targets_i)
  383. return ret_imgs, ret_targets
  384. return imgs, targets
  385. class RandomResize(object):
  386. def __init__(self, sizes, max_size=None):
  387. assert isinstance(sizes, (list, tuple))
  388. self.sizes = sizes
  389. self.max_size = max_size
  390. def __call__(self, img, target=None):
  391. size = random.choice(self.sizes)
  392. return resize(img, target, size, self.max_size)
  393. class MotRandomResize(RandomResize):
  394. def __call__(self, imgs, targets):
  395. size = random.choice(self.sizes)
  396. ret_imgs = []
  397. ret_targets = []
  398. for img_i, targets_i in zip(imgs, targets):
  399. img_i, targets_i = resize(img_i, targets_i, size, self.max_size)
  400. ret_imgs.append(img_i)
  401. ret_targets.append(targets_i)
  402. return ret_imgs, ret_targets
  403. class RandomPad(object):
  404. def __init__(self, max_pad):
  405. self.max_pad = max_pad
  406. def __call__(self, img, target):
  407. pad_x = random.randint(0, self.max_pad)
  408. pad_y = random.randint(0, self.max_pad)
  409. return pad(img, target, (pad_x, pad_y))
  410. class MotRandomPad(RandomPad):
  411. def __call__(self, imgs, targets):
  412. pad_x = random.randint(0, self.max_pad)
  413. pad_y = random.randint(0, self.max_pad)
  414. ret_imgs = []
  415. ret_targets = []
  416. for img_i, targets_i in zip(imgs, targets):
  417. img_i, target_i = pad(img_i, targets_i, (pad_x, pad_y))
  418. ret_imgs.append(img_i)
  419. ret_targets.append(targets_i)
  420. return ret_imgs, ret_targets
  421. class RandomSelect(object):
  422. """
  423. Randomly selects between transforms1 and transforms2,
  424. with probability p for transforms1 and (1 - p) for transforms2
  425. """
  426. def __init__(self, transforms1, transforms2, p=0.5):
  427. self.transforms1 = transforms1
  428. self.transforms2 = transforms2
  429. self.p = p
  430. def __call__(self, img, target):
  431. if random.random() < self.p:
  432. return self.transforms1(img, target)
  433. return self.transforms2(img, target)
  434. class MotRandomSelect(RandomSelect):
  435. """
  436. Randomly selects between transforms1 and transforms2,
  437. with probability p for transforms1 and (1 - p) for transforms2
  438. """
  439. def __call__(self, imgs, targets):
  440. if random.random() < self.p:
  441. return self.transforms1(imgs, targets)
  442. return self.transforms2(imgs, targets)
  443. class ToTensor(object):
  444. def __call__(self, img, target):
  445. return F.to_tensor(img), target
  446. class MotToTensor(ToTensor):
  447. def __call__(self, imgs, targets):
  448. ret_imgs = []
  449. for img in imgs:
  450. ret_imgs.append(F.to_tensor(img))
  451. return ret_imgs, targets
  452. class RandomErasing(object):
  453. def __init__(self, *args, **kwargs):
  454. self.eraser = T.RandomErasing(*args, **kwargs)
  455. def __call__(self, img, target):
  456. return self.eraser(img), target
  457. class MotRandomErasing(RandomErasing):
  458. def __call__(self, imgs, targets):
  459. # TODO: Rewrite this part to ensure the data augmentation is same to each image.
  460. ret_imgs = []
  461. for img_i, targets_i in zip(imgs, targets):
  462. ret_imgs.append(self.eraser(img_i))
  463. return ret_imgs, targets
  464. class MoTColorJitter(T.ColorJitter):
  465. def __call__(self, imgs, targets):
  466. transform = self.get_params(self.brightness, self.contrast,
  467. self.saturation, self.hue)
  468. ret_imgs = []
  469. for img_i, targets_i in zip(imgs, targets):
  470. ret_imgs.append(transform(img_i))
  471. return ret_imgs, targets
  472. class Normalize(object):
  473. def __init__(self, mean, std):
  474. self.mean = mean
  475. self.std = std
  476. def __call__(self, image, target=None):
  477. if target is not None:
  478. target['ori_img'] = image.clone()
  479. image = F.normalize(image, mean=self.mean, std=self.std)
  480. if target is None:
  481. return image, None
  482. target = target.copy()
  483. h, w = image.shape[-2:]
  484. if "boxes" in target:
  485. boxes = target["boxes"]
  486. boxes = box_xyxy_to_cxcywh(boxes)
  487. boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
  488. target["boxes"] = boxes
  489. return image, target
  490. class MotNormalize(Normalize):
  491. def __call__(self, imgs, targets=None):
  492. ret_imgs = []
  493. ret_targets = []
  494. for i in range(len(imgs)):
  495. img_i = imgs[i]
  496. targets_i = targets[i] if targets is not None else None
  497. img_i, targets_i = super().__call__(img_i, targets_i)
  498. ret_imgs.append(img_i)
  499. ret_targets.append(targets_i)
  500. return ret_imgs, ret_targets
  501. class Compose(object):
  502. def __init__(self, transforms):
  503. self.transforms = transforms
  504. def __call__(self, image, target):
  505. for t in self.transforms:
  506. image, target = t(image, target)
  507. return image, target
  508. def __repr__(self):
  509. format_string = self.__class__.__name__ + "("
  510. for t in self.transforms:
  511. format_string += "\n"
  512. format_string += " {0}".format(t)
  513. format_string += "\n)"
  514. return format_string
  515. class MotCompose(Compose):
  516. def __call__(self, imgs, targets):
  517. for t in self.transforms:
  518. imgs, targets = t(imgs, targets)
  519. return imgs, targets