import cv2
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode


class RawVideoExtractor():
    """Extract frames from a video file with OpenCV and apply CLIP-style preprocessing.

    Parameters
    ----------
    centercrop : bool
        Stored for API compatibility; the transform always center-crops regardless.
    framerate : float
        Target sampling rate in frames/second; -1 means "use the video's native fps".
    size : int
        Side length (pixels) used for Resize and CenterCrop.
    to_tensor : bool
        If True, frames become normalized tensors; otherwise resized/cropped PIL images.
    """

    def __init__(self, centercrop=False, framerate=-1, size=224, to_tensor=True):
        self.centercrop = centercrop
        self.framerate = framerate
        self.to_tensor = to_tensor
        self.transform = self._transform(size)

    def _transform(self, n_px):
        """Build the per-frame preprocessing pipeline for a target side of `n_px` pixels."""
        if self.to_tensor:
            return Compose([
                Resize(n_px, interpolation=InterpolationMode.BICUBIC),
                CenterCrop(n_px),
                ToTensor(),
                # CLIP's published RGB mean/std normalization constants.
                Normalize((0.48145466, 0.4578275, 0.40821073),
                          (0.26862954, 0.26130258, 0.27577711)),
            ])
        return Compose([
            Resize(n_px, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(n_px),
        ])

    def get_video_data(self, video_path, start_time=None, end_time=None):
        """Decode `video_path` and return the sampled, transformed frames as a list.

        Parameters
        ----------
        video_path : str
            Path to a video file readable by OpenCV.
        start_time, end_time : float or None
            Optional trim window in seconds. Either may be None independently
            (None start defaults to the first frame, None end to the last frame).
            Providing a trim window requires `self.framerate` to have been set.

        Returns
        -------
        list
            One transformed frame per sampling step (tensor or PIL image,
            depending on `to_tensor`).
        """
        if start_time is not None or end_time is not None:
            # A trimmed read only makes sense with an explicit target framerate.
            assert self.framerate > -1
            # BUGFIX: validate each bound only when present. The original
            # `assert start_time > -1 and end_time > start_time` raised
            # TypeError (None vs int comparison) whenever only one of the
            # two bounds was supplied.
            if start_time is not None:
                assert start_time > -1
            if start_time is not None and end_time is not None:
                assert end_time > start_time

        cap = cv2.VideoCapture(video_path)
        images = []
        try:
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            video_fps = cap.get(cv2.CAP_PROP_FPS)
            # NOTE: falsy check means start_time/end_time == 0 behaves like None
            # (original semantics, preserved for compatibility).
            start_frame = int(start_time * video_fps) if start_time else 0
            end_frame = int(end_time * video_fps) if end_time else frame_count - 1

            # Number of source frames between two consecutive sampled frames.
            interval = 1
            if self.framerate > 0:
                interval = video_fps / self.framerate
            else:
                # NOTE(review): stateful side effect kept for backward
                # compatibility — a second call on a different video will
                # reuse this recorded fps as the target framerate.
                self.framerate = video_fps
            if interval == 0:
                interval = 1

            for i in range(frame_count):
                ret, frame = cap.read()
                if not ret:
                    break
                if start_frame <= i <= end_frame:
                    # BUGFIX: `<=` so the first in-range frame (i == start_frame,
                    # where 0 * interval == 0) is sampled. The original strict `<`
                    # always skipped it and shifted every sample by one frame.
                    if len(images) * interval <= i - start_frame:
                        # OpenCV decodes BGR; PIL/torchvision expect RGB.
                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        images.append(self.transform(Image.fromarray(frame_rgb)))
        finally:
            # BUGFIX: release the capture even if decoding/transform raises.
            cap.release()
        return images