You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

rawvideo_util.py 2.1KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. import cv2
  2. from PIL import Image
  3. from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode
  class RawVideoExtractor():
      """Decode a video with OpenCV and return sampled, preprocessed frames.

      Frames are read sequentially with ``cv2.VideoCapture``, converted from
      BGR to RGB, wrapped as PIL images and passed through a torchvision
      ``Compose`` pipeline (bicubic resize + center crop, optionally followed
      by ``ToTensor`` and ``Normalize``).
      """

      def __init__(self, centercrop=False, framerate=-1, size=224, to_tensor=True):
          """Store the sampling options and build the per-frame transform.

          Args:
              centercrop: Stored on the instance. NOTE(review): no method
                  visible here reads it; the center crop is always applied.
              framerate: Target number of sampled frames per second. ``0``
                  keeps every frame; negative values (including the default
                  ``-1``) are rejected by ``get_video_data``.
              size: Side length handed to ``Resize`` and ``CenterCrop``.
              to_tensor: When true, frames are also converted with
                  ``ToTensor`` and normalized.
          """
          self.centercrop = centercrop
          self.framerate = framerate
          self.to_tensor = to_tensor
          self.transform = self._transform(size)

      def _transform(self, n_px):
          """Return the ``Compose`` pipeline applied to every kept frame."""
          steps = [
              Resize(n_px, interpolation=InterpolationMode.BICUBIC),
              CenterCrop(n_px),
          ]
          if self.to_tensor:
              # Per-channel mean / std; presumably the CLIP preprocessing
              # statistics -- TODO confirm against the model being fed.
              steps.append(ToTensor())
              steps.append(
                  Normalize(
                      (0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711),
                  )
              )
          return Compose(steps)

      @staticmethod
      def _frame_window(start_time, end_time, video_fps, frame_count):
          """Convert an optional time window in seconds to inclusive frame indices."""
          if start_time is None:
              start_frame = 0
          else:
              # Clamp so a slightly negative start cannot shift the sampling phase.
              start_frame = max(0, int(start_time * video_fps))
          if end_time is None:
              end_frame = frame_count - 1
          else:
              end_frame = int(end_time * video_fps)
          return start_frame, end_frame

      @staticmethod
      def _sampling_interval(video_fps, framerate):
          """Return how many source frames separate two consecutive samples."""
          interval = video_fps / framerate if framerate > 0 else 1
          # A container reporting 0 fps would give a zero step; fall back to
          # keeping every frame instead.
          if interval == 0:
              interval = 1
          return interval

      def get_video_data(self, video_path, start_time=None, end_time=None):
          """Read ``video_path`` and return the sampled, transformed frames.

          Args:
              video_path: Path handed to ``cv2.VideoCapture``.
              start_time: Optional window start in seconds (must be > -1).
              end_time: Optional window end in seconds; must be greater than
                  ``start_time`` (or than 0 when no start is given).

          Returns:
              A list with one transformed image per sampled frame, in order.
              The list is empty when no frame could be read.

          Raises:
              ValueError: If the time window is invalid or ``self.framerate``
                  is negative.

          Notes:
              ``self.framerate`` is no longer overwritten with the video's
              native fps, so one extractor can be reused across videos with
              different frame rates.
          """
          # Explicit raises instead of ``assert`` so the checks survive ``-O``.
          if start_time is not None and not start_time > -1:
              raise ValueError("start_time must be greater than -1")
          if end_time is not None:
              lower_bound = 0 if start_time is None else start_time
              if not end_time > lower_bound:
                  raise ValueError("end_time must be greater than start_time")
          if not self.framerate > -1:
              raise ValueError("framerate must be >= 0 to extract frames")

          cap = cv2.VideoCapture(video_path)
          try:
              frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
              video_fps = cap.get(cv2.CAP_PROP_FPS)
              start_frame, end_frame = self._frame_window(
                  start_time, end_time, video_fps, frame_count
              )
              interval = self._sampling_interval(video_fps, self.framerate)

              images = []
              for i in range(frame_count):
                  if i > end_frame:
                      # Nothing past the window is needed; stop decoding.
                      break
                  ret, frame = cap.read()
                  if not ret:
                      break
                  if i < start_frame:
                      continue
                  # ``<=`` (not ``<``) so the first frame of the window is
                  # sampled; afterwards a frame is kept each time the stream
                  # has advanced by another ``interval`` source frames.
                  if len(images) * interval <= i - start_frame:
                      frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                      images.append(self.transform(Image.fromarray(frame_rgb)))
              return images
          finally:
              # Always free the decoder, even if the transform raised.
              cap.release()