# dataloader.py
  1. import numpy as np
  2. import os
  3. from os.path import join, split, isdir, isfile, abspath
  4. import torch
  5. from PIL import Image
  6. import random
  7. import collections
  8. from torchvision import transforms
  9. from torch.utils.data import Dataset, DataLoader
  10. class SemanLineDataset(Dataset):
  11. def __init__(self, root_dir, label_file, split='train', transform=None, t_transform=None):
  12. lines = [line.rstrip('\n') for line in open(label_file)]
  13. self.image_path = [join(root_dir, i+".jpg") for i in lines]
  14. self.data_path = [join(root_dir, i+".npy") for i in lines]
  15. self.split = split
  16. self.transform = transform
  17. self.t_transform = t_transform
  18. def __getitem__(self, item):
  19. assert isfile(self.image_path[item]), self.image_path[item]
  20. image = Image.open(self.image_path[item]).convert('RGB')
  21. data = np.load(self.data_path[item], allow_pickle=True).item()
  22. hough_space_label8 = data["hough_space_label8"].astype(np.float32)
  23. if self.transform is not None:
  24. image = self.transform(image)
  25. hough_space_label8 = torch.from_numpy(hough_space_label8).unsqueeze(0)
  26. gt_coords = data["coords"]
  27. if self.split == 'val':
  28. return image, hough_space_label8, gt_coords, self.image_path[item].split('/')[-1]
  29. elif self.split == 'train':
  30. return image, hough_space_label8, gt_coords, self.image_path[item].split('/')[-1]
  31. def __len__(self):
  32. return len(self.image_path)
  33. def collate_fn(self, batch):
  34. images, hough_space_label8, gt_coords, names = list(zip(*batch))
  35. images = torch.stack([image for image in images])
  36. hough_space_label8 = torch.stack([hough_space_label for hough_space_label in hough_space_label8])
  37. return images, hough_space_label8, gt_coords, names
  38. class SemanLineDatasetTest(Dataset):
  39. def __init__(self, root_dir, label_file, transform=None, t_transform=None):
  40. lines = [line.rstrip('\n') for line in open(label_file)]
  41. self.image_path = [join(root_dir, i+".jpg") for i in lines]
  42. self.transform = transform
  43. self.t_transform = t_transform
  44. def __getitem__(self, item):
  45. assert isfile(self.image_path[item]), self.image_path[item]
  46. image = Image.open(self.image_path[item]).convert('RGB')
  47. w, h = image.size
  48. if self.transform is not None:
  49. image = self.transform(image)
  50. return image, self.image_path[item].split('/')[-1], (h, w)
  51. def __len__(self):
  52. return len(self.image_path)
  53. def collate_fn(self, batch):
  54. images, names, sizes = list(zip(*batch))
  55. images = torch.stack([image for image in images])
  56. return images, names, sizes
  57. def get_loader(root_dir, label_file, batch_size, img_size=0, num_thread=4, pin=True, test=False, split='train'):
  58. if test is False:
  59. transform = transforms.Compose([
  60. # transforms.Resize((400, 400)),# Not used for current version.
  61. transforms.ToTensor(),
  62. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  63. ])
  64. dataset = SemanLineDataset(root_dir, label_file, transform=transform, t_transform=None, split=split)
  65. else:
  66. transform = transforms.Compose([
  67. transforms.Resize((400, 400)),
  68. transforms.ToTensor(),
  69. transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  70. ])
  71. dataset = SemanLineDatasetTest(root_dir, label_file, transform=transform, t_transform=None)
  72. if test is False:
  73. data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_thread,
  74. pin_memory=pin, collate_fn=dataset.collate_fn)
  75. else:
  76. data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False, num_workers=num_thread,
  77. pin_memory=pin, collate_fn=dataset.collate_fn)
  78. return data_loader