create_tf_record.py 6.8 KB

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. r"""Convert the Oxford pet dataset to TFRecord for object_detection.
  16. See: O. M. Parkhi, A. Vedaldi, A. Zisserman, C. V. Jawahar
  17. Cats and Dogs
  18. IEEE Conference on Computer Vision and Pattern Recognition, 2012
  19. http://www.robots.ox.ac.uk/~vgg/data/pets/
  20. Example usage:
  21. ./create_pet_tf_record --data_dir=/home/user/pet \
  22. --output_dir=/home/user/pet/output
  23. """
import hashlib
import io
import logging
import os
import random
import re

from lxml import etree
import PIL.Image
import tensorflow as tf

from object_detection.utils import dataset_util
from object_detection.utils import label_map_util


def dict_to_tf_example(data,
                       label_map_dict,
                       image_subdirectory,
                       ignore_difficult_instances=False):
  """Convert XML derived dict to tf.Example proto.

  Notice that this function normalizes the bounding box coordinates provided
  by the raw data.

  Args:
    data: dict holding PASCAL XML fields for a single image (obtained by
      running dataset_util.recursive_parse_xml_to_dict)
    label_map_dict: A map from string label names to integer ids.
    image_subdirectory: String specifying subdirectory within the
      Pascal dataset directory holding the actual image data.
    ignore_difficult_instances: Whether to skip difficult instances in the
      dataset (default: False).

  Returns:
    example: The converted tf.Example.

  Raises:
    ValueError: if the image pointed to by data['filename'] is not a valid JPEG
  """
  img_path = os.path.join(image_subdirectory, data['filename'])
  with tf.gfile.GFile(img_path, 'rb') as fid:
    encoded_jpg = fid.read()
  encoded_jpg_io = io.BytesIO(encoded_jpg)
  image = PIL.Image.open(encoded_jpg_io)
  if image.format != 'JPEG':
    raise ValueError('Image format not JPEG')
  key = hashlib.sha256(encoded_jpg).hexdigest()

  width = int(data['size']['width'])
  height = int(data['size']['height'])

  xmin = []
  ymin = []
  xmax = []
  ymax = []
  classes = []
  classes_text = []
  truncated = []
  poses = []
  difficult_obj = []
  for obj in data['object']:
    difficult_obj.append(int(0))

    xmin.append(float(obj['bndbox']['xmin']) / width)
    ymin.append(float(obj['bndbox']['ymin']) / height)
    xmax.append(float(obj['bndbox']['xmax']) / width)
    ymax.append(float(obj['bndbox']['ymax']) / height)
    class_name = obj['name']
    classes_text.append(class_name.encode('utf8'))
    classes.append(label_map_dict[class_name])
    truncated.append(int(0))
    poses.append('Unspecified'.encode('utf8'))

  example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/source_id': dataset_util.bytes_feature(
          data['filename'].encode('utf8')),
      'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
      'image/encoded': dataset_util.bytes_feature(encoded_jpg),
      'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
      'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
      'image/object/truncated': dataset_util.int64_list_feature(truncated),
      'image/object/view': dataset_util.bytes_list_feature(poses),
  }))
  return example


def create_tf_record(output_filename,
                     label_map_dict,
                     annotations_dir,
                     image_dir,
                     examples):
  """Creates a TFRecord file from examples.

  Args:
    output_filename: Path to where output file is saved.
    label_map_dict: The label map dictionary.
    annotations_dir: Directory where annotation files are stored.
    image_dir: Directory where image files are stored.
    examples: Examples to parse and save to tf record.
  """
  writer = tf.python_io.TFRecordWriter(output_filename)
  for idx, example in enumerate(examples):
    if idx % 100 == 0:
      logging.info('On image %d of %d', idx, len(examples))
      print('On image %d of %d' % (idx, len(examples)))
    path = os.path.join(annotations_dir, 'xmls', example + '.xml')

    if not os.path.exists(path):
      logging.warning('Could not find %s, ignoring example.', path)
      continue
    with tf.gfile.GFile(path, 'r') as fid:
      xml_str = fid.read()
    xml = etree.fromstring(xml_str)
    data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']

    tf_example = dict_to_tf_example(data, label_map_dict, image_dir)
    writer.write(tf_example.SerializeToString())

  writer.close()


def main(_):
  label_map_dict = label_map_util.get_label_map_dict(
      '/tensorflow/input/annotations/label_map.pbtxt')

  logging.info('Reading from Pet dataset.')
  image_dir = '/tensorflow/input/images'
  annotations_dir = '/tensorflow/input/annotations'
  examples_path = os.path.join(annotations_dir, 'trainval.txt')
  examples_list = dataset_util.read_examples_list(examples_path)

  # Test images are not included in the downloaded data set, so we shall
  # perform our own split.
  random.seed(42)
  random.shuffle(examples_list)
  num_examples = len(examples_list)
  num_train = int(0.7 * num_examples)
  train_examples = examples_list[:num_train]
  val_examples = examples_list[num_train:]
  print('%d training and %d validation examples.' %
        (len(train_examples), len(val_examples)))

  train_output_path = '/tensorflow/input/train.record'
  val_output_path = '/tensorflow/input/val.record'
  create_tf_record(train_output_path, label_map_dict, annotations_dir,
                   image_dir, train_examples)
  create_tf_record(val_output_path, label_map_dict, annotations_dir,
                   image_dir, val_examples)


if __name__ == '__main__':
  tf.app.run()
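
After the script finishes, it can be useful to read a few records back out of the generated file and confirm that image sizes and class labels look reasonable. The snippet below is not part of the script above; it is a minimal sanity-check sketch that assumes the same TF 1.x API (tf.python_io) and the hardcoded output path used earlier.

import tensorflow as tf

# Path assumed from the script above; adjust if your output path differs.
record_path = '/tensorflow/input/train.record'

# Iterate over the serialized tf.Example protos and print a short summary
# of the first few records.
for i, serialized in enumerate(tf.python_io.tf_record_iterator(record_path)):
  example = tf.train.Example.FromString(serialized)
  feature = example.features.feature
  height = feature['image/height'].int64_list.value[0]
  width = feature['image/width'].int64_list.value[0]
  labels = list(feature['image/object/class/label'].int64_list.value)
  print('record %d: %dx%d, labels=%s' % (i, width, height, labels))
  if i >= 4:  # only inspect the first five records
    break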