1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253 |
- import numpy as np
- import os
- from os.path import join, split, isdir, isfile, abspath
- from PIL import Image
- import random
- import collections
- import jittor.transform as transforms
- import jittor as jt
- from jittor.dataset import Dataset
- class SemanLineDatasetTest(Dataset):
- def __init__(self, root_dir, label_file, transform=None, t_transform=None):
- super().__init__()
- lines = [line.rstrip('\n') for line in open(label_file)]
- self.image_path = [join(root_dir, i+".jpg") for i in lines]
- self.transform = transform
- self.t_transform = t_transform
- self.set_attrs(total_len=len(self.image_path), keep_numpy_array=True)
-
- def __getitem__(self, item):
- assert isfile(self.image_path[item]), self.image_path[item]
- image = Image.open(self.image_path[item]).convert('RGB')
- w, h = image.size
- if self.transform is not None:
- image = self.transform(image)
- return image, self.image_path[item].split('/')[-1], (h, w)
- def collate_batch(self, batch):
- images, names, sizes = list(zip(*batch))
- images = jt.stack([jt.array(image) for image in images])
- return images, names, sizes
- def get_loader(root_dir, label_file, batch_size, img_size=0, num_thread=4, pin=True, test=False, split='train'):
- if test is False:
- raise NotImplementedError
- else:
- transform = transforms.Compose([
- transforms.Resize((400, 400)),
- transforms.ToTensor(),
- transforms.ImageNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
- ])
- dataset = SemanLineDatasetTest(root_dir, label_file, transform=transform, t_transform=None)
- if test is False:
- raise NotImplementedError
- else:
- dataset.set_attrs(batch_size=batch_size, shuffle=False)
- print('Get dataset success.')
- return dataset
|