zipreader.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. # --------------------------------------------------------
  2. # Swin Transformer
  3. # Copyright (c) 2021 Microsoft
  4. # Licensed under The MIT License [see LICENSE for details]
  5. # Written by Ze Liu
  6. # --------------------------------------------------------
  7. import os
  8. import zipfile
  9. import io
  10. import numpy as np
  11. from PIL import Image
  12. from PIL import ImageFile
  # Module-level PIL setting: allow truncated image files to be loaded.
  13. ImageFile.LOAD_TRUNCATED_IMAGES = True
  def is_zip_path(img_or_path):
      """Return True if *img_or_path* is a zip-style path string.

      A zip-style path contains the marker ``.zip@``, e.g.
      ``data/train.zip@/cat/1.jpg``. Anything that is not a ``str`` (for
      example an already-loaded image object) is reported as not a zip path
      instead of raising ``TypeError`` from the ``in`` test.
      """
      return isinstance(img_or_path, str) and '.zip@' in img_or_path
  class ZipReader(object):
      """Read files stored inside zip archives via ``<archive>@<member>`` paths.

      A zip-style path is the path of the archive, an ``@`` separator, and the
      path of a file or folder inside the archive, for example
      ``data/train.zip@/cat/1.jpg``. All public methods are static. Opened
      archives are cached in ``zip_bank`` so a given archive path is opened
      only once.
      """

      # Maps archive path -> open ``zipfile.ZipFile``. The dict lives on the
      # class, so it is shared by every caller; nothing in this class closes
      # the cached handles.
      zip_bank = dict()

      def __init__(self):
          super(ZipReader, self).__init__()

      @staticmethod
      def get_zipfile(path):
          """Return the cached read-only ``ZipFile`` for *path*.

          The archive is opened on first use and stored in ``zip_bank``.
          """
          zip_bank = ZipReader.zip_bank
          if path not in zip_bank:
              zip_bank[path] = zipfile.ZipFile(path, 'r')
          return zip_bank[path]

      @staticmethod
      def split_zip_style_path(path):
          """Split a zip-style path at its first ``@``.

          Returns:
              ``(zip_path, inner_path)`` where ``inner_path`` has leading and
              trailing ``/`` removed.

          Raises:
              ValueError: if *path* contains no ``@``.
          """
          # find() + explicit raise (rather than index() + assert) so the error
          # names the offending path and still fires under ``python -O``.
          pos_at = path.find('@')
          if pos_at == -1:
              raise ValueError(
                  "character '@' is not found from the given path '%s'" % path)
          zip_path = path[:pos_at]
          folder_path = path[pos_at + 1:].strip('/')
          return zip_path, folder_path

      @staticmethod
      def _iter_members(zfile, folder_path):
          """Yield ``(raw_name, name, relative)`` for entries inside *folder_path*.

          ``raw_name`` is the name as stored in the archive, ``name`` is the
          same with surrounding ``/`` stripped, and ``relative`` is ``name``
          with the ``folder_path/`` prefix removed. An empty *folder_path*
          selects every entry.
          """
          # Match on "folder/" rather than "folder" so that a sibling such as
          # "folder2/..." is not mistaken for a child of "folder".
          prefix = folder_path + '/' if folder_path else ''
          for raw_name in zfile.namelist():
              name = raw_name.strip('/')
              if name.startswith(prefix) and len(name) > len(prefix):
                  yield raw_name, name, name[len(prefix):]

      @staticmethod
      def list_folder(path):
          """List folders below the folder named by the zip-style *path*.

          An entry counts as a folder when its name has no file extension, so
          extension-less files are reported as well and folders whose last
          component contains a dot are not. Nested folders are included.
          Returned names are relative to the requested folder.
          """
          zip_path, folder_path = ZipReader.split_zip_style_path(path)
          zfile = ZipReader.get_zipfile(zip_path)
          return [
              relative
              for _, name, relative in ZipReader._iter_members(zfile, folder_path)
              if not os.path.splitext(name)[-1]
          ]

      @staticmethod
      def list_files(path, extension=None):
          """List files below the folder named by the zip-style *path*.

          Args:
              path: zip-style path of the folder to scan (recursively).
              extension: a single extension or an iterable of extensions
                  (each including the dot, e.g. ``'.jpg'``), compared
                  case-insensitively. ``None`` or the wildcard ``'.*'`` selects
                  every non-directory entry.

          Returns:
              File names relative to the requested folder, in archive order.
          """
          zip_path, folder_path = ZipReader.split_zip_style_path(path)
          zfile = ZipReader.get_zipfile(zip_path)

          if extension is None:
              extension = ['.*']
          elif isinstance(extension, str):
              # A bare string would otherwise be substring-matched by ``in``.
              extension = [extension]
          wanted = {ext.lower() for ext in extension}
          match_all = '.*' in wanted

          file_lists = []
          for raw_name, name, relative in ZipReader._iter_members(zfile, folder_path):
              if match_all:
                  # Directory entries are stored with a trailing '/'.
                  keep = not raw_name.endswith('/')
              else:
                  keep = os.path.splitext(name)[-1].lower() in wanted
              if keep:
                  file_lists.append(relative)
          return file_lists

      @staticmethod
      def read(path):
          """Return the raw bytes of the archive member named by zip-style *path*."""
          zip_path, path_img = ZipReader.split_zip_style_path(path)
          return ZipReader.get_zipfile(zip_path).read(path_img)

      @staticmethod
      def imread(path):
          """Open the archive member named by zip-style *path* as a PIL image.

          Only the image decoding step is best-effort: if ``Image.open`` fails,
          an error line is printed and a 224x224x3 uint8 random-noise image is
          returned instead. Errors from locating or reading the member itself
          propagate to the caller.
          """
          zip_path, path_img = ZipReader.split_zip_style_path(path)
          data = ZipReader.get_zipfile(zip_path).read(path_img)
          try:
              im = Image.open(io.BytesIO(data))
          except Exception:
              # Deliberate fallback; ``Exception`` (not a bare except) so that
              # KeyboardInterrupt / SystemExit are not swallowed.
              print("ERROR IMG LOADED: ", path_img)
              random_img = np.random.rand(224, 224, 3) * 255
              im = Image.fromarray(np.uint8(random_img))
          return im