diff --git a/core/config.py b/core/config.py
index 8fc5c6e..42cf1c6 100644
--- a/core/config.py
+++ b/core/config.py
@@ -18,6 +18,7 @@ class Settings(BaseSettings):
     AWS_SECRET_ACCESS_KEY: Optional[str] = None
     OPENAI_API_KEY: Optional[str] = None
     ANTHROPIC_API_KEY: Optional[str] = None
+    ASSEMBLYAI_API_KEY: Optional[str] = None

     # API configuration
     HOST: str
diff --git a/core/logging_config.py b/core/logging_config.py
new file mode 100644
index 0000000..986df50
--- /dev/null
+++ b/core/logging_config.py
@@ -0,0 +1,39 @@
+import logging
+import sys
+from pathlib import Path
+
+
+def setup_logging():
+    # Create logs directory if it doesn't exist
+    log_dir = Path("logs")
+    log_dir.mkdir(exist_ok=True)
+
+    # Configure root logger
+    root_logger = logging.getLogger()
+    root_logger.setLevel(logging.INFO)
+
+    # Create formatters
+    console_formatter = logging.Formatter(
+        "%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+    )
+
+    # Console handler
+    console_handler = logging.StreamHandler(sys.stdout)
+    console_handler.setFormatter(console_formatter)
+    console_handler.setLevel(logging.INFO)
+
+    # File handler
+    file_handler = logging.FileHandler(log_dir / "databridge.log")
+    file_handler.setFormatter(console_formatter)
+    file_handler.setLevel(logging.INFO)
+
+    # Add handlers to root logger
+    root_logger.addHandler(console_handler)
+    root_logger.addHandler(file_handler)
+
+    # Set levels for specific loggers
+    logging.getLogger("uvicorn").setLevel(logging.INFO)
+    logging.getLogger("fastapi").setLevel(logging.INFO)
+
+    # Set debug level for our code
+    logging.getLogger("core").setLevel(logging.DEBUG)
diff --git a/core/parser/combined_parser.py b/core/parser/combined_parser.py
index b53f683..3537760 100644
--- a/core/parser/combined_parser.py
+++ b/core/parser/combined_parser.py
@@ -104,7 +104,7 @@ class CombinedParser(BaseParser):
             assemblyai_api_key=self.assemblyai_api_key,
             frame_sample_rate=self.frame_sample_rate,
         )
-        results = parser.process_video()
+        results = await parser.process_video()
         # Get all frame descriptions
         frame_descriptions = results.frame_descriptions
         # Get all transcript text
diff --git a/core/parser/video/parse_video.py b/core/parser/video/parse_video.py
index 16a2630..f78cec9 100644
--- a/core/parser/video/parse_video.py
+++ b/core/parser/video/parse_video.py
@@ -4,6 +4,10 @@ from openai import OpenAI
 import assemblyai as aai
 import logging
 from core.models.video import TimeSeriesData, ParseVideoResult
+import tomli
+import os
+from typing import Optional, Dict, Any
+from ollama import AsyncClient

 logger = logging.getLogger(__name__)

@@ -12,19 +16,71 @@ def debug_object(title, obj):
     logger.debug("\n".join(["-" * 100, title, "-" * 100, f"{obj}", "-" * 100]))


+def load_config() -> Dict[str, Any]:
+    config_path = os.path.join(os.path.dirname(__file__), "../../../databridge.toml")
+    with open(config_path, "rb") as f:
+        return tomli.load(f)
+
+
+class VisionModelClient:
+    def __init__(self, config: Dict[str, Any]):
+        self.config = config["parser"]["vision"]
+        self.provider = self.config.get("provider", "ollama")
+        self.model_name = self.config.get("model_name", "llama3.2-vision")
+
+        if self.provider == "openai":
+            self.client = OpenAI()
+        elif self.provider == "ollama":
+            base_url = self.config.get("base_url", "http://localhost:11434")
+            self.client = AsyncClient(host=base_url)
+        else:
+            raise ValueError(f"Unsupported vision model provider: {self.provider}")
+
+    async def get_frame_description(self, image_base64: str, context: str) -> str:
+        if self.provider == "openai":
+            response = self.client.chat.completions.create(
+                model=self.model_name,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": [
+                            {"type": "text", "text": context},
+                            {
+                                "type": "image_url",
+                                "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
+                            },
+                        ],
+                    }
+                ],
+                max_tokens=300,
+            )
+            return response.choices[0].message.content
+        else:  # ollama
+            response = await self.client.chat(
+                model=self.model_name,
+                messages=[{"role": "user", "content": context, "images": [image_base64]}],
+            )
+            return response["message"]["content"]
+
+
 class VideoParser:
-    def __init__(self, video_path: str, assemblyai_api_key: str, frame_sample_rate: int = 120):
+    def __init__(
+        self, video_path: str, assemblyai_api_key: str, frame_sample_rate: Optional[int] = None
+    ):
         """
         Initialize the video parser

         Args:
             video_path: Path to the video file
             assemblyai_api_key: API key for AssemblyAI
-            frame_sample_rate: Sample every nth frame for description
+            frame_sample_rate: Sample every nth frame for description (optional, defaults to config value)
         """
         logger.info(f"Initializing VideoParser for {video_path}")
+        self.config = load_config()
         self.video_path = video_path
-        self.frame_sample_rate = frame_sample_rate
+        self.frame_sample_rate = frame_sample_rate or self.config["parser"]["vision"].get(
+            "frame_sample_rate", 120
+        )

         self.cap = cv2.VideoCapture(video_path)
         if not self.cap.isOpened():
@@ -34,15 +90,15 @@ class VideoParser:
         self.fps = self.cap.get(cv2.CAP_PROP_FPS)
         self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
         self.duration = self.total_frames / self.fps
+
+        # Initialize AssemblyAI
         aai.settings.api_key = assemblyai_api_key
-        aai_config = aai.TranscriptionConfig(
-            speaker_labels=True
-        )  # speech_model=aai.SpeechModel.nano
+        aai_config = aai.TranscriptionConfig(speaker_labels=True)
         self.transcriber = aai.Transcriber(config=aai_config)
-        self.transcript = TimeSeriesData(
-            time_to_content={}
-        )  # empty transcript initially - TODO: have this be a lateinit somehow
-        self.gpt = OpenAI()
+        self.transcript = TimeSeriesData(time_to_content={})
+
+        # Initialize vision model client
+        self.vision_client = VisionModelClient(self.config)

         logger.info(f"Video loaded: {self.duration:.2f}s duration, {self.fps:.2f} FPS")

@@ -83,70 +139,65 @@ class VideoParser:
             {u.start / 1000: u.text for u in transcript.utterances} if transcript.utterances else {}
         )
         debug_object("Time to text", time_to_text)
-        self.transcript = TimeSeriesData(time_to_text)
+        self.transcript = TimeSeriesData(time_to_content=time_to_text)
         return self.transcript

-    def get_frame_descriptions(self) -> TimeSeriesData:
+    async def get_frame_descriptions(self) -> TimeSeriesData:
         """
-        Get descriptions for sampled frames using GPT-4
+        Get descriptions for sampled frames using configured vision model

         Returns:
             TimeSeriesData object containing frame descriptions
         """
         logger.info("Starting frame description generation")
+
+        # Return empty TimeSeriesData if frame_sample_rate is -1 (captioning disabled)
+        if self.frame_sample_rate == -1:
+            logger.info("Frame captioning is disabled (frame_sample_rate = -1)")
+            return TimeSeriesData(time_to_content={})
+
         frame_count = 0
         time_to_description = {}
         last_description = None

+        logger.info("Starting main loop for frame description generation")
         while True:
+            logger.info(f"Frame count: {frame_count}")
             ret, frame = self.cap.read()
             if not ret:
+                logger.info("Reached end of video")
                 break

             if frame_count % self.frame_sample_rate == 0:
+                logger.info(f"Processing frame at {frame_count / self.fps:.2f}s")
                 timestamp = frame_count / self.fps
                 logger.debug(f"Processing frame at {timestamp:.2f}s")
                 img_base64 = self.frame_to_base64(frame)

-                response = self.gpt.chat.completions.create(
-                    model="gpt-4o-mini",
-                    messages=[
-                        {
-                            "role": "user",
-                            "content": [
-                                {
-                                    "type": "text",
-                                    "text": f"""Describe this frame from a video. Focus on the main elements, actions, and any notable details. Here is the transcript around the time of the frame:
-                                    ---
-                                    {self.transcript.at_time(timestamp, padding=10)}
-                                    ---
+                context = f"""Describe this frame from a video. Focus on the main elements, actions, and any notable details. Here is the transcript around the time of the frame:
+                ---
+                {self.transcript.at_time(timestamp, padding=10)}
+                ---

-                                    Here is a description of the previous frame:
-                                    ---
-                                    {last_description if last_description else 'No previous frame description available, this is the first frame'}
-                                    ---
+                Here is a description of the previous frame:
+                ---
+                {last_description if last_description else 'No previous frame description available, this is the first frame'}
+                ---

-                                    In your response, only provide the description of the current frame, using the above information as context.
-                                    """,
-                                },
-                                {
-                                    "type": "image_url",
-                                    "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
-                                },
-                            ],
-                        }
-                    ],
-                    max_tokens=300,
+                In your response, only provide the description of the current frame, using the above information as context.
+                """
+
+                last_description = await self.vision_client.get_frame_description(
+                    img_base64, context
                 )
-                last_description = response.choices[0].message.content
                 time_to_description[timestamp] = last_description

             frame_count += 1

         logger.info(f"Generated descriptions for {len(time_to_description)} frames")
-        return TimeSeriesData(time_to_description)
+        return TimeSeriesData(time_to_content=time_to_description)

-    def process_video(self) -> ParseVideoResult:
+    async def process_video(self) -> ParseVideoResult:
         """
         Process the video to get both transcript and frame descriptions

@@ -163,7 +214,7 @@ class VideoParser:
         result = ParseVideoResult(
             metadata=metadata,
             transcript=self.get_transcript(),
-            frame_descriptions=self.get_frame_descriptions(),
+            frame_descriptions=await self.get_frame_descriptions(),
         )
         logger.info("Video processing completed successfully")
         return result
diff --git a/databridge.toml b/databridge.toml
index b320d68..9fc60a4 100644
--- a/databridge.toml
+++ b/databridge.toml
@@ -58,7 +58,18 @@ use_unstructured_api = false
 # chunk_size = 1000
 # chunk_overlap = 200
 # use_unstructured_api = false
-# frame_sample_rate = 120
+
+[parser.vision]
+provider = "ollama"
+model_name = "llama3.2-vision"
+frame_sample_rate = -1 # Set to -1 to disable frame captioning
+# base_url = "http://localhost:11434" # Only used for ollama
+base_url = "http://ollama:11434" # Use if running via docker
+
+# [parser.vision]
+# provider = "openai"
+# model_name = "gpt-4o-mini"
+# frame_sample_rate = -1 # Set to -1 to disable frame captioning

 [reranker]
 use_reranker = false
diff --git a/shell.py b/shell.py
index edb568a..5fb8c48 100644
--- a/shell.py
+++ b/shell.py
@@ -159,9 +159,7 @@ class Cache:
         """Add documents to the cache"""
         return self._client_cache.add_docs(docs)

-    def query(
-        self, query: str, max_tokens: int = None, temperature: float = None
-    ) -> dict:
+    def query(self, query: str, max_tokens: int = None, temperature: float = None) -> dict:
         """Query the cache"""
         response = self._client_cache.query(
             query=query,
diff --git a/start_server.py b/start_server.py
index ce0589e..2e75766 100644
--- a/start_server.py
+++ b/start_server.py
@@ -1,9 +1,13 @@
 import uvicorn
 from dotenv import load_dotenv
 from core.config import get_settings
+from core.logging_config import setup_logging


 def main():
+    # Set up logging first
+    setup_logging()
+
     # Load environment variables from .env file
     load_dotenv()

@@ -16,6 +20,7 @@ def main():
         host=settings.HOST,
         port=settings.PORT,
         loop="asyncio",
+        log_level="info",
         # reload=settings.RELOAD
     )