123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- # --------------------------------------------------------
- # Swin Transformer
- # Copyright (c) 2021 Microsoft
- # Licensed under The MIT License [see LICENSE for details]
- # Written by Ze Liu
- # --------------------------------------------------------
- import os
- import zipfile
- import io
- import numpy as np
- from PIL import Image
- from PIL import ImageFile
# Tell PIL to load truncated image files instead of raising, so archive
# members with incomplete image data can still be opened.
setattr(ImageFile, 'LOAD_TRUNCATED_IMAGES', True)
def is_zip_path(img_or_path):
    """Tell whether *img_or_path* is written in ``<archive>.zip@<member>`` form.

    This is a plain substring test for the marker ``'.zip@'``; the archive
    itself is not opened or checked.
    """
    zip_marker = '.zip@'
    found = zip_marker in img_or_path
    return found
class ZipReader(object):
    """Read members of zip archives addressed as ``<archive>.zip@<inner/path>``.

    Opened archives are cached in the class-level ``zip_bank`` dict, keyed by
    archive path, so each archive is opened at most once per process. The
    cached handles are never closed explicitly.
    """

    # Class-level cache: archive path -> open zipfile.ZipFile. It is shared
    # by every caller because all methods are staticmethods; no locking.
    zip_bank = dict()

    def __init__(self):
        super(ZipReader, self).__init__()

    @staticmethod
    def get_zipfile(path):
        """Return an open ``zipfile.ZipFile`` for *path*, caching it on first use."""
        zip_bank = ZipReader.zip_bank
        if path not in zip_bank:
            zip_bank[path] = zipfile.ZipFile(path, 'r')
        return zip_bank[path]

    @staticmethod
    def split_zip_style_path(path):
        """Split ``'<zip_path>@<inner_path>'`` at the first ``'@'``.

        Returns:
            ``(zip_path, inner_path)``, with leading/trailing ``'/'`` removed
            from *inner_path*.

        Raises:
            ValueError: if *path* contains no ``'@'``.
        """
        pos_at = path.find('@')
        # str.index would raise its own ValueError before any check ran, and
        # asserts are stripped under -O, so raise a descriptive error here.
        if pos_at == -1:
            raise ValueError(
                "character '@' is not found from the given path '%s'" % path)
        zip_path = path[:pos_at]
        folder_path = path[pos_at + 1:].strip('/')
        return zip_path, folder_path

    @staticmethod
    def _relative_name(name, folder_path):
        """Return *name* relative to *folder_path*, or None if it lies outside it."""
        if not folder_path:
            # Archive root: every member is "inside", keep its full inner path.
            return name
        # Match whole path components so 'train' does not also match 'train2/...'.
        prefix = folder_path + '/'
        if name.startswith(prefix):
            return name[len(prefix):]
        return None

    @staticmethod
    def list_folder(path):
        """List directory-like members below the folder addressed by *path*.

        A member counts as a folder when its name (with '/' stripped) has no
        file extension. Names are returned relative to the requested folder,
        including nested sub-folders; for the archive root (``'x.zip@'``) the
        full inner paths are returned. The folder itself is not listed.
        """
        zip_path, folder_path = ZipReader.split_zip_style_path(path)
        zfile = ZipReader.get_zipfile(zip_path)
        folder_list = []
        for raw_name in zfile.namelist():
            name = raw_name.strip('/')
            # Heuristic: no extension means "directory".
            if os.path.splitext(name)[-1]:
                continue
            relative = ZipReader._relative_name(name, folder_path)
            # Empty result means the entry is the folder itself (or the root).
            if relative:
                folder_list.append(relative)
        return folder_list

    @staticmethod
    def list_files(path, extension=None):
        """List file members below the folder addressed by *path*.

        Args:
            path: ``'<archive>.zip@<folder>'``.
            extension: iterable of extensions such as ``['.jpg', '.png']``,
                or a single extension string. Matching is case-insensitive.
                ``None`` (the default) or any list containing ``'.*'`` matches
                every file.

        Returns:
            Member names relative to the folder (full inner paths for the
            archive root). Directory entries are never included.
        """
        if extension is None:
            extension = ['.*']
        if isinstance(extension, str):
            # A bare string would otherwise do substring matching.
            extension = [extension]
        allowed = {ext.lower() for ext in extension}
        match_all = '.*' in allowed

        zip_path, folder_path = ZipReader.split_zip_style_path(path)
        zfile = ZipReader.get_zipfile(zip_path)
        file_lists = []
        for raw_name in zfile.namelist():
            # Zip directory entries are stored with a trailing '/'.
            if raw_name.endswith('/'):
                continue
            name = raw_name.strip('/')
            relative = ZipReader._relative_name(name, folder_path)
            if not relative:
                continue
            if match_all or os.path.splitext(name)[-1].lower() in allowed:
                file_lists.append(relative)
        return file_lists

    @staticmethod
    def read(path):
        """Return the raw bytes of the member addressed by ``'<zip>@<member>'``.

        Raises:
            KeyError: from ``ZipFile.read`` when the member does not exist.
        """
        zip_path, path_img = ZipReader.split_zip_style_path(path)
        zfile = ZipReader.get_zipfile(zip_path)
        return zfile.read(path_img)

    @staticmethod
    def imread(path):
        """Open the image member addressed by *path* as a PIL image.

        If PIL cannot open the bytes, the member name is printed and a random
        224x224x3 uint8 image is returned instead, so callers keep running.

        NOTE: ``Image.open`` is lazy. Errors in the pixel data may only
        surface later, when the image is decoded, and are not caught here.
        """
        zip_path, path_img = ZipReader.split_zip_style_path(path)
        zfile = ZipReader.get_zipfile(zip_path)
        data = zfile.read(path_img)
        try:
            im = Image.open(io.BytesIO(data))
        except Exception:  # was a bare except; don't swallow Ctrl-C/SystemExit
            print("ERROR IMG LOADED: ", path_img)
            random_img = np.random.rand(224, 224, 3) * 255
            im = Image.fromarray(np.uint8(random_img))
        return im
|