added frame/transcript augmentation for video retrieval

This commit is contained in:
Arnav Agrawal 2024-12-29 12:45:12 +05:30
parent 7830b42c6b
commit 80db083471
4 changed files with 61 additions and 30 deletions

View File

@ -5,6 +5,8 @@ from pydantic import BaseModel, Field, field_validator
import uuid
import logging
from core.models.video import TimeSeriesData
logger = logging.getLogger(__name__)
@ -69,19 +71,6 @@ class Chunk(BaseModel):
)
class ChunkResult(BaseModel):
    """Query result at chunk level"""
    content: str  # raw text of the matched chunk
    score: float  # retrieval similarity score for this chunk
    document_id: str  # external_id
    chunk_number: int  # position of this chunk within its parent document
    metadata: Dict[str, Any]  # chunk-level metadata attached at ingestion
    content_type: str  # content type of the source document -- presumably MIME-like, verify against ingestion path
    filename: Optional[str] = None  # source filename, when known
    download_url: Optional[str] = None  # URL for fetching the original file, when available
class DocumentContent(BaseModel):
"""Represents either a URL or content string"""
@ -106,3 +95,46 @@ class DocumentResult(BaseModel):
document_id: str # external_id
metadata: Dict[str, Any]
content: DocumentContent
additional_metadata: Dict[str, Any]
class ChunkResult(BaseModel):
    """Query result at chunk level."""
    content: str  # raw text of the matched chunk
    score: float  # retrieval similarity score for this chunk
    document_id: str  # external_id
    chunk_number: int  # position of this chunk within its parent document
    metadata: Dict[str, Any]  # chunk-level metadata attached at ingestion
    content_type: str  # content type of the source document
    filename: Optional[str] = None  # source filename, when known
    download_url: Optional[str] = None  # URL for fetching the original file, when available

    def augmented_content(self, doc: DocumentResult) -> str:
        """Return this chunk's content, augmented with video context when possible.

        A "timestamp" key in the chunk metadata marks the chunk as coming
        from a video. In that case look up the parent document's
        frame-description and transcript time series and return the
        frame/transcript pair for every timestamp at which this chunk's
        content occurs. For non-video chunks, or when the video metadata
        is missing/malformed, return the raw content unchanged.

        Args:
            doc: The DocumentResult for this chunk's parent document,
                whose additional_metadata may carry "frame_description"
                and "transcript" time-series dictionaries.

        Returns:
            The augmented content string, or the original content when
            no augmentation applies.
        """
        # Plain membership test: only video chunks carry a timestamp.
        if "timestamp" not in self.metadata:
            return self.content

        frame_description = doc.additional_metadata.get("frame_description")
        transcript = doc.additional_metadata.get("transcript")
        if not isinstance(frame_description, dict) or not isinstance(
            transcript, dict
        ):
            logger.warning(
                "Invalid frame description or transcript - not a dictionary"
            )
            return self.content

        ts_frame = TimeSeriesData(frame_description)
        ts_transcript = TimeSeriesData(transcript)
        # The chunk may originate from either series, so a key missing in
        # one of the maps must not raise KeyError - default to no times.
        timestamps = ts_frame.content_to_times.get(
            self.content, []
        ) + ts_transcript.content_to_times.get(self.content, [])
        if not timestamps:
            # Nothing to augment with; keep the original content rather
            # than returning an empty string.
            return self.content
        augmented_contents = [
            f"Frame description: {ts_frame.at_time(t)} \n \n Transcript: {ts_transcript.at_time(t)}"
            for t in timestamps
        ]
        return "\n\n".join(augmented_contents)

View File

@ -5,7 +5,6 @@ import tempfile
import magic
from core.models.documents import Chunk
from core.models.video import TimeSeriesData
from core.parser.base_parser import BaseParser
from core.parser.unstructured_parser import UnstructuredAPIParser
from core.parser.video.parse_video import VideoParser

View File

@ -1,6 +1,4 @@
from numbers import Number
import cv2
from typing import Dict, Union
import base64
from openai import OpenAI
import assemblyai as aai

View File

@ -83,11 +83,11 @@ class DocumentService:
"""Retrieve relevant documents."""
# Get chunks first
chunks = await self.retrieve_chunks(query, auth, filters, k, min_score)
# Convert to document results
results = await self._create_document_results(auth, chunks)
logger.info(f"Returning {len(results)} document results")
return results
documents = list(results.values())
logger.info(f"Returning {len(documents)} document results")
return documents
async def query(
self,
@ -102,7 +102,11 @@ class DocumentService:
"""Generate completion using relevant chunks as context."""
# Get relevant chunks
chunks = await self.retrieve_chunks(query, auth, filters, k, min_score)
chunk_contents = [chunk.content for chunk in chunks]
documents = await self._create_document_results(auth, chunks)
chunk_contents = [
chunk.augmented_content(documents[chunk.document_id]) for chunk in chunks
]
# Generate completion
request = CompletionRequest(
@ -277,7 +281,7 @@ class DocumentService:
async def _create_document_results(
self, auth: AuthContext, chunks: List[ChunkResult]
) -> List[DocumentResult]:
) -> Dict[str, DocumentResult]:
"""Group chunks by document and create DocumentResult objects."""
# Group chunks by document and get highest scoring chunk per doc
doc_chunks: Dict[str, ChunkResult] = {}
@ -289,7 +293,7 @@ class DocumentService:
doc_chunks[chunk.document_id] = chunk
logger.info(f"Grouped chunks into {len(doc_chunks)} documents")
logger.info(f"Document chunks: {doc_chunks}")
results = []
results = {}
for doc_id, chunk in doc_chunks.items():
# Get document metadata
doc = await self.db.get_document(doc_id, auth)
@ -313,14 +317,12 @@ class DocumentService:
type="url", value=download_url, filename=doc.filename
)
logger.debug(f"Created URL content for document {doc_id}")
results.append(
DocumentResult(
score=chunk.score,
document_id=doc_id,
metadata=doc.metadata,
content=content,
)
results[doc_id] = DocumentResult(
score=chunk.score,
document_id=doc_id,
metadata=doc.metadata,
content=content,
additional_metadata=doc.additional_metadata,
)
logger.info(f"Created {len(results)} document results")