dataloader.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. import numpy as np
  2. import os
  3. from os.path import join, split, isdir, isfile, abspath
  4. from PIL import Image
  5. import random
  6. import collections
  7. import jittor.transform as transforms
  8. import jittor as jt
  9. from jittor.dataset import Dataset
  10. class SemanLineDatasetTest(Dataset):
  11. def __init__(self, root_dir, label_file, transform=None, t_transform=None):
  12. super().__init__()
  13. lines = [line.rstrip('\n') for line in open(label_file)]
  14. self.image_path = [join(root_dir, i+".jpg") for i in lines]
  15. self.transform = transform
  16. self.t_transform = t_transform
  17. self.set_attrs(total_len=len(self.image_path), keep_numpy_array=True)
  18. def __getitem__(self, item):
  19. assert isfile(self.image_path[item]), self.image_path[item]
  20. image = Image.open(self.image_path[item]).convert('RGB')
  21. w, h = image.size
  22. if self.transform is not None:
  23. image = self.transform(image)
  24. return image, self.image_path[item].split('/')[-1], (h, w)
  25. def collate_batch(self, batch):
  26. images, names, sizes = list(zip(*batch))
  27. images = jt.stack([jt.array(image) for image in images])
  28. return images, names, sizes
  29. def get_loader(root_dir, label_file, batch_size, img_size=0, num_thread=4, pin=True, test=False, split='train'):
  30. if test is False:
  31. raise NotImplementedError
  32. else:
  33. transform = transforms.Compose([
  34. transforms.Resize((400, 400)),
  35. transforms.ToTensor(),
  36. transforms.ImageNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  37. ])
  38. dataset = SemanLineDatasetTest(root_dir, label_file, transform=transform, t_transform=None)
  39. if test is False:
  40. raise NotImplementedError
  41. else:
  42. dataset.set_attrs(batch_size=batch_size, shuffle=False)
  43. print('Get dataset success.')
  44. return dataset