# yolov5hub.py (1.3 KB)
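"""
# File : yolov5hub.py
# Time : 22.5.16 16:47
# Author : FEANGYANG
# Version : python 3.7
# Contact : 1071082183@qq.com
# Description: Wraps a YOLOv5 model loaded from a local torch.hub repo and
#   filters its detections by class name. Run as a script, it prints the
#   "person" detections for every .jpg in a dataset folder.
"""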
import glob
import sys

import cv2
import torch
  12. class Yolov5Detect:
  13. def __init__(self, weight_path, weight, class_f='\W+'):
  14. self.class_f = class_f
  15. self.model = torch.hub.load(weight_path, weight, source='local')
  16. def yolov5_detect(self, img):
  17. results = self.model(img) # includes NMS
  18. all_results = results.pandas().xyxy[0]
  19. filter_cla = self._filter_class(all_results, self.class_f)
  20. # print(filter_cla)
  21. return filter_cla
  22. def _filter_class(self, all_results , class_f):
  23. filter_cla = all_results[all_results['name'].str.contains(class_f)]
  24. return filter_cla
  25. if __name__ == '__main__':
  26. weight_path = './'
  27. weight = 'yolov5l'
  28. detect = Yolov5Detect(weight_path, weight, class_f='person')
  29. for img_path in glob.glob('./datasets/coco128/images/train20/*.jpg'):
  30. img = cv2.imread(img_path)[:,:,::-1]
  31. # img = cv2.imread(img_path)
  32. # img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  33. result = detect.yolov5_detect(img)
  34. print(result)
  35. print('-----------------------------------------------------')