import numpy as np
from os.path import join, isfile

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class SemanLineDataset(Dataset):

    def __init__(self, root_dir, label_file, split='train', transform=None, t_transform=None):
        # label_file lists one sample name per line; each name maps to an
        # image (<name>.jpg) and a precomputed label dict (<name>.npy).
        with open(label_file) as f:
            lines = [line.rstrip('\n') for line in f]
        self.image_path = [join(root_dir, i + ".jpg") for i in lines]
        self.data_path = [join(root_dir, i + ".npy") for i in lines]
        self.split = split
        self.transform = transform
        self.t_transform = t_transform

    def __getitem__(self, item):
        assert isfile(self.image_path[item]), self.image_path[item]
        image = Image.open(self.image_path[item]).convert('RGB')
        data = np.load(self.data_path[item], allow_pickle=True).item()
        hough_space_label8 = data["hough_space_label8"].astype(np.float32)
        if self.transform is not None:
            image = self.transform(image)

        hough_space_label8 = torch.from_numpy(hough_space_label8).unsqueeze(0)
        gt_coords = data["coords"]
        # The 'train' and 'val' splits return the same tuple; the last
        # element is the image file's basename.
        return image, hough_space_label8, gt_coords, self.image_path[item].split('/')[-1]

    def __len__(self):
        return len(self.image_path)

    def collate_fn(self, batch):
        # gt_coords can differ in length per sample, so only the images and
        # Hough-space labels are stacked; coords and names stay as tuples.
        images, hough_space_label8, gt_coords, names = list(zip(*batch))
        images = torch.stack(images)
        hough_space_label8 = torch.stack(hough_space_label8)
        return images, hough_space_label8, gt_coords, names
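
# A minimal sketch of the .npy label layout the dataset above assumes: a
# pickled dict holding a float "hough_space_label8" map and a "coords" array
# of ground-truth line coordinates. The shape (300, 300) and the coordinate
# values below are illustrative assumptions, not the dataset's real sizes.
#
#   dummy = {
#       "hough_space_label8": np.zeros((300, 300), dtype=np.float32),
#       "coords": np.array([[0.0, 0.0, 399.0, 399.0]]),
#   }
#   np.save("example.npy", dummy)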

class SemanLineDatasetTest(Dataset):

    def __init__(self, root_dir, label_file, transform=None, t_transform=None):
        with open(label_file) as f:
            lines = [line.rstrip('\n') for line in f]
        self.image_path = [join(root_dir, i + ".jpg") for i in lines]
        self.transform = transform
        self.t_transform = t_transform

    def __getitem__(self, item):
        assert isfile(self.image_path[item]), self.image_path[item]
        image = Image.open(self.image_path[item]).convert('RGB')
        w, h = image.size  # PIL reports (width, height)
        if self.transform is not None:
            image = self.transform(image)
        # Return the original size as (h, w) so predictions made on the
        # resized input can be mapped back to image coordinates.
        return image, self.image_path[item].split('/')[-1], (h, w)

    def __len__(self):
        return len(self.image_path)

    def collate_fn(self, batch):
        images, names, sizes = list(zip(*batch))
        images = torch.stack(images)
        return images, names, sizes

def get_loader(root_dir, label_file, batch_size, img_size=0, num_thread=4, pin=True, test=False, split='train'):
    # img_size is currently unused. Both branches normalize with ImageNet
    # mean/std; images are resized to 400 x 400 only in the test branch.
    if test is False:
        transform = transforms.Compose([
            # transforms.Resize((400, 400)),  # Not used in the current version.
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        dataset = SemanLineDataset(root_dir, label_file, transform=transform, t_transform=None, split=split)
    else:
        transform = transforms.Compose([
            transforms.Resize((400, 400)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        dataset = SemanLineDatasetTest(root_dir, label_file, transform=transform, t_transform=None)
    # Shuffle only during training; each dataset supplies its own collate_fn.
    data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=not test,
                             num_workers=num_thread, pin_memory=pin, collate_fn=dataset.collate_fn)
    return data_loader
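
# A minimal usage sketch of get_loader; the image directory and list file
# below are hypothetical placeholders, not paths shipped with this module.
if __name__ == '__main__':
    loader = get_loader('data/images', 'data/train_list.txt',  # hypothetical paths
                        batch_size=4, num_thread=0, test=False, split='train')
    images, labels, gt_coords, names = next(iter(loader))
    print(images.shape, labels.shape, names)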