import cv2
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode
-
- class RawVideoExtractor():
- def __init__(self, centercrop=False, framerate=-1, size=224, to_tensor=True):
- self.centercrop = centercrop
- self.framerate = framerate
- self.to_tensor = to_tensor
- self.transform = self._transform(size)
-
- def _transform(self, n_px):
- if self.to_tensor:
-
- return Compose([
- Resize(n_px, interpolation=InterpolationMode.BICUBIC),
- CenterCrop(n_px), ToTensor(),
- Normalize((0.48145466, 0.4578275, 0.40821073),
- (0.26862954, 0.26130258, 0.27577711))])
-
- else:
- return Compose([Resize(n_px, interpolation=InterpolationMode.BICUBIC),CenterCrop(n_px)])
-
- def get_video_data(self, video_path, start_time=None, end_time=None):
- if start_time is not None or end_time is not None:
- assert start_time > -1 and end_time > start_time
- assert self.framerate > -1
-
- cap = cv2.VideoCapture(video_path)
-
- frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
- video_fps = cap.get(cv2.CAP_PROP_FPS)
-
- start_frame = int(start_time * video_fps) if start_time else 0
- end_frame = int(end_time * video_fps) if end_time else frame_count - 1
-
- interval = 1
- if self.framerate > 0:
- interval = video_fps / self.framerate
- else:
- self.framerate = video_fps
- if interval == 0:
- interval = 1
-
- images = []
-
- for i in range(frame_count):
- ret, frame = cap.read()
-
- if not ret:
- break
-
- if i >= start_frame and i <= end_frame:
- if len(images) * interval < i - start_frame:
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
- image = Image.fromarray(frame_rgb)
- image = self.transform(image)
- images.append(image)
-
- cap.release()
-
- return images