added frame/transcript augmentation for video retrieval

Arnav Agrawal 2024-12-29 12:45:12 +05:30
parent 7830b42c6b
commit 80db083471
4 changed files with 61 additions and 30 deletions


@@ -5,6 +5,8 @@ from pydantic import BaseModel, Field, field_validator
 import uuid
 import logging
+
+from core.models.video import TimeSeriesData

 logger = logging.getLogger(__name__)
@@ -69,19 +71,6 @@ class Chunk(BaseModel):
     )
-
-class ChunkResult(BaseModel):
-    """Query result at chunk level"""
-
-    content: str
-    score: float
-    document_id: str  # external_id
-    chunk_number: int
-    metadata: Dict[str, Any]
-    content_type: str
-    filename: Optional[str] = None
-    download_url: Optional[str] = None
-
 class DocumentContent(BaseModel):
     """Represents either a URL or content string"""
@@ -106,3 +95,46 @@ class DocumentResult(BaseModel):
     document_id: str  # external_id
     metadata: Dict[str, Any]
     content: DocumentContent
+    additional_metadata: Dict[str, Any]
+
+
+class ChunkResult(BaseModel):
+    """Query result at chunk level"""
+
+    content: str
+    score: float
+    document_id: str  # external_id
+    chunk_number: int
+    metadata: Dict[str, Any]
+    content_type: str
+    filename: Optional[str] = None
+    download_url: Optional[str] = None
+
+    def augmented_content(self, doc: DocumentResult) -> str:
+        match self.metadata:
+            case m if "timestamp" in m:
+                # If a timestamp is present, the chunk must come from a video.
+                # In that case, obtain the original document and augment the
+                # content with frame/transcript information as well.
+                frame_description = doc.additional_metadata.get("frame_description")
+                transcript = doc.additional_metadata.get("transcript")
+                if not isinstance(frame_description, dict) or not isinstance(
+                    transcript, dict
+                ):
+                    logger.warning(
+                        "Invalid frame description or transcript - not a dictionary"
+                    )
+                    return self.content
+                ts_frame = TimeSeriesData(frame_description)
+                ts_transcript = TimeSeriesData(transcript)
+                timestamps = (
+                    ts_frame.content_to_times[self.content]
+                    + ts_transcript.content_to_times[self.content]
+                )
+                augmented_contents = [
+                    f"Frame description: {ts_frame.at_time(t)} \n \n Transcript: {ts_transcript.at_time(t)}"
+                    for t in timestamps
+                ]
+                return "\n\n".join(augmented_contents)
+            case _:
+                return self.content
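
Note: augmented_content assumes TimeSeriesData can be built from a plain timestamp-to-text dict and exposes a content_to_times reverse index plus an at_time lookup. A minimal sketch of that assumed interface (not the actual core/models/video.py implementation):

    # Minimal sketch, assuming TimeSeriesData wraps a {timestamp: text} dict.
    # Names mirror the calls made in augmented_content; the real class may differ.
    from collections import defaultdict
    from typing import Dict, List

    class TimeSeriesData:
        def __init__(self, time_to_content: Dict[float, str]):
            self.time_to_content = dict(sorted(time_to_content.items()))
            # Reverse index: content string -> all timestamps it appears at.
            self.content_to_times: Dict[str, List[float]] = defaultdict(list)
            for t, text in self.time_to_content.items():
                self.content_to_times[text].append(t)

        def at_time(self, t: float) -> str:
            # Most recent entry at or before t; empty string if none.
            earlier = [k for k in self.time_to_content if k <= t]
            return self.time_to_content[max(earlier)] if earlier else ""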


@@ -5,7 +5,6 @@ import tempfile
 import magic
 from core.models.documents import Chunk
-from core.models.video import TimeSeriesData
 from core.parser.base_parser import BaseParser
 from core.parser.unstructured_parser import UnstructuredAPIParser
 from core.parser.video.parse_video import VideoParser


@@ -1,6 +1,4 @@
-from numbers import Number
 import cv2
-from typing import Dict, Union
 import base64
 from openai import OpenAI
 import assemblyai as aai


@@ -83,11 +83,11 @@ class DocumentService:
         """Retrieve relevant documents."""
         # Get chunks first
         chunks = await self.retrieve_chunks(query, auth, filters, k, min_score)

         # Convert to document results
         results = await self._create_document_results(auth, chunks)
-        logger.info(f"Returning {len(results)} document results")
-        return results
+        documents = list(results.values())
+        logger.info(f"Returning {len(documents)} document results")
+        return documents

     async def query(
         self,
@@ -102,7 +102,11 @@ class DocumentService:
         """Generate completion using relevant chunks as context."""
         # Get relevant chunks
         chunks = await self.retrieve_chunks(query, auth, filters, k, min_score)
-        chunk_contents = [chunk.content for chunk in chunks]
+        documents = await self._create_document_results(auth, chunks)
+        chunk_contents = [
+            chunk.augmented_content(documents[chunk.document_id]) for chunk in chunks
+        ]

         # Generate completion
         request = CompletionRequest(
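
To make the new query flow concrete, a hypothetical end-to-end example, with fabricated data, the ChunkResult/DocumentResult models above, and the TimeSeriesData interface sketched earlier:

    # Hypothetical data; all field values are illustrative only.
    doc = DocumentResult(
        score=0.9,
        document_id="vid-1",
        metadata={},
        content=DocumentContent(
            type="url", value="https://example.com/talk.mp4", filename="talk.mp4"
        ),
        additional_metadata={
            "frame_description": {0.0: "Title slide", 5.0: "Speaker at podium"},
            "transcript": {0.0: "Welcome everyone", 5.0: "Today: video retrieval"},
        },
    )
    chunk = ChunkResult(
        content="Welcome everyone",
        score=0.9,
        document_id="vid-1",
        chunk_number=0,
        metadata={"timestamp": 0.0},  # timestamp marks the chunk as video-derived
        content_type="video/mp4",
    )
    print(chunk.augmented_content(doc))
    # -> "Frame description: Title slide ... Transcript: Welcome everyone"

A chunk whose metadata carries no timestamp falls through to the default case and returns its content unchanged.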
@@ -277,7 +281,7 @@ class DocumentService:
     async def _create_document_results(
         self, auth: AuthContext, chunks: List[ChunkResult]
-    ) -> List[DocumentResult]:
+    ) -> Dict[str, DocumentResult]:
         """Group chunks by document and create DocumentResult objects."""
         # Group chunks by document and get highest scoring chunk per doc
         doc_chunks: Dict[str, ChunkResult] = {}
@@ -289,7 +293,7 @@ class DocumentService:
                 doc_chunks[chunk.document_id] = chunk
         logger.info(f"Grouped chunks into {len(doc_chunks)} documents")
         logger.info(f"Document chunks: {doc_chunks}")
-        results = []
+        results = {}
         for doc_id, chunk in doc_chunks.items():
             # Get document metadata
             doc = await self.db.get_document(doc_id, auth)
@@ -313,14 +317,12 @@ class DocumentService:
                     type="url", value=download_url, filename=doc.filename
                 )
                 logger.debug(f"Created URL content for document {doc_id}")
-            results.append(
-                DocumentResult(
-                    score=chunk.score,
-                    document_id=doc_id,
-                    metadata=doc.metadata,
-                    content=content,
-                )
-            )
+            results[doc_id] = DocumentResult(
+                score=chunk.score,
+                document_id=doc_id,
+                metadata=doc.metadata,
+                content=content,
+                additional_metadata=doc.additional_metadata,
+            )

         logger.info(f"Created {len(results)} document results")