model_zoo.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os.path as osp
  15. import pkg_resources
  16. try:
  17. from collections.abc import Sequence
  18. except:
  19. from collections import Sequence
  20. from ppdet.core.workspace import load_config, create
  21. from ppdet.utils.checkpoint import load_weight
  22. from ppdet.utils.download import get_config_path
  23. from ppdet.utils.logger import setup_logger
  24. logger = setup_logger(__name__)
  25. __all__ = [
  26. 'list_model', 'get_config_file', 'get_weights_url', 'get_model',
  27. 'MODEL_ZOO_FILENAME'
  28. ]
  29. MODEL_ZOO_FILENAME = 'MODEL_ZOO'
  30. def list_model(filters=[]):
  31. model_zoo_file = pkg_resources.resource_filename('ppdet.model_zoo',
  32. MODEL_ZOO_FILENAME)
  33. with open(model_zoo_file) as f:
  34. model_names = f.read().splitlines()
  35. # filter model_name
  36. def filt(name):
  37. for f in filters:
  38. if name.find(f) < 0:
  39. return False
  40. return True
  41. if isinstance(filters, str) or not isinstance(filters, Sequence):
  42. filters = [filters]
  43. model_names = [name for name in model_names if filt(name)]
  44. if len(model_names) == 0 and len(filters) > 0:
  45. raise ValueError("no model found, please check filters seeting, "
  46. "filters can be set as following kinds:\n"
  47. "\tDataset: coco, voc ...\n"
  48. "\tArchitecture: yolo, rcnn, ssd ...\n"
  49. "\tBackbone: resnet, vgg, darknet ...\n")
  50. model_str = "Available Models:\n"
  51. for model_name in model_names:
  52. model_str += "\t{}\n".format(model_name)
  53. logger.info(model_str)
# Models and configs are saved on bcebos under the dygraph directory.
  55. def get_config_file(model_name):
  56. return get_config_path("ppdet://configs/{}.yml".format(model_name))
  57. def get_weights_url(model_name):
  58. return "ppdet://models/{}.pdparams".format(osp.split(model_name)[-1])
  59. def get_model(model_name, pretrained=True):
  60. cfg_file = get_config_file(model_name)
  61. cfg = load_config(cfg_file)
  62. model = create(cfg.architecture)
  63. if pretrained:
  64. load_weight(model, get_weights_url(model_name))
  65. return model