reformat files

Arnav Agrawal 2024-12-29 12:48:41 +05:30
parent b54cdb6e0c
commit 0e4a43645a
20 changed files with 56 additions and 171 deletions
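Every hunk below follows the same pattern: statements that were previously wrapped across several lines are collapsed onto a single line, which is consistent with raising the formatter's line-length limit to roughly 100 characters. The commit message does not name the formatter or the exact limit, so the following is only a sketch of the effect, assuming a black-compatible tool:

import black

# Source text formatted the old way: wrapped because it exceeds the default
# 88-character limit (names such as jwt and settings are just sample text here).
wrapped = (
    "payload = jwt.decode(\n"
    "    token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]\n"
    ")\n"
)

# Default mode (line length 88): the call stays wrapped.
print(black.format_str(wrapped, mode=black.Mode()))

# A 100-character limit collapses it onto one line, matching the hunks below.
print(black.format_str(wrapped, mode=black.Mode(line_length=100)))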


@ -62,9 +62,7 @@ match settings.VECTOR_STORE_PROVIDER:
index_name=settings.VECTOR_INDEX_NAME,
)
case _:
raise ValueError(
f"Unsupported vector store provider: {settings.VECTOR_STORE_PROVIDER}"
)
raise ValueError(f"Unsupported vector store provider: {settings.VECTOR_STORE_PROVIDER}")
# Initialize storage
match settings.STORAGE_PROVIDER:
@ -110,9 +108,7 @@ match settings.EMBEDDING_PROVIDER:
base_url=settings.OLLAMA_BASE_URL,
)
case _:
raise ValueError(
f"Unsupported embedding provider: {settings.EMBEDDING_PROVIDER}"
)
raise ValueError(f"Unsupported embedding provider: {settings.EMBEDDING_PROVIDER}")
# Initialize completion model
match settings.COMPLETION_PROVIDER:
@ -126,9 +122,7 @@ match settings.COMPLETION_PROVIDER:
model_name=settings.COMPLETION_MODEL,
)
case _:
raise ValueError(
f"Unsupported completion provider: {settings.COMPLETION_PROVIDER}"
)
raise ValueError(f"Unsupported completion provider: {settings.COMPLETION_PROVIDER}")
# Initialize document service with configured components
document_service = DocumentService(
@ -154,9 +148,7 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext:
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = authorization[7:] # Remove "Bearer "
payload = jwt.decode(
token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
)
payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])
if datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC):
raise HTTPException(status_code=401, detail="Token expired")
@ -200,9 +192,7 @@ async def ingest_file(
@app.post("/retrieve/chunks", response_model=List[ChunkResult])
async def retrieve_chunks(
request: RetrieveRequest, auth: AuthContext = Depends(verify_token)
):
async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
"""Retrieve relevant chunks."""
return await document_service.retrieve_chunks(
request.query, auth, request.filters, request.k, request.min_score
@ -210,9 +200,7 @@ async def retrieve_chunks(
@app.post("/retrieve/docs", response_model=List[DocumentResult])
async def retrieve_documents(
request: RetrieveRequest, auth: AuthContext = Depends(verify_token)
):
async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
"""Retrieve relevant documents."""
return await document_service.retrieve_docs(
request.query, auth, request.filters, request.k, request.min_score


@ -17,9 +17,7 @@ class BaseDatabase(ABC):
pass
@abstractmethod
async def get_document(
self, document_id: str, auth: AuthContext
) -> Optional[Document]:
async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]:
"""
Retrieve document metadata by ID if user has access.
Returns: Document if found and accessible, None otherwise


@ -60,9 +60,7 @@ class MongoDatabase(BaseDatabase):
logger.error(f"Error storing document metadata: {str(e)}")
return False
async def get_document(
self, document_id: str, auth: AuthContext
) -> Optional[Document]:
async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]:
"""Retrieve document metadata by ID if user has access."""
try:
# Build access filter
@ -92,11 +90,7 @@ class MongoDatabase(BaseDatabase):
# Build query
auth_filter = self._build_access_filter(auth)
metadata_filter = self._build_metadata_filter(filters)
query = (
{"$and": [auth_filter, metadata_filter]}
if metadata_filter
else auth_filter
)
query = {"$and": [auth_filter, metadata_filter]} if metadata_filter else auth_filter
# Execute paginated query
cursor = self.collection.find(query).skip(skip).limit(limit)
@ -157,9 +151,7 @@ class MongoDatabase(BaseDatabase):
# Build query
auth_filter = self._build_access_filter(auth)
metadata_filter = self._build_metadata_filter(filters)
query = (
{"$and": [auth_filter, metadata_filter]} if metadata_filter else auth_filter
)
query = {"$and": [auth_filter, metadata_filter]} if metadata_filter else auth_filter
# Get matching document IDs
cursor = self.collection.find(query, {"external_id": 1})
@ -183,10 +175,7 @@ class MongoDatabase(BaseDatabase):
# Check owner access
owner = doc.get("owner", {})
if (
owner.get("type") == auth.entity_type
and owner.get("id") == auth.entity_id
):
if owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id:
return True
# Check permission-specific access


@ -6,9 +6,7 @@ from core.models.documents import Chunk
class BaseEmbeddingModel(ABC):
@abstractmethod
async def embed_for_ingestion(
self, chunks: Union[Chunk, List[Chunk]]
) -> List[List[float]]:
async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[List[float]]:
"""Generate embeddings for input text"""
pass


@ -9,17 +9,13 @@ class OllamaEmbeddingModel(BaseEmbeddingModel):
self.model_name = model_name
self.client = AsyncClient(host=base_url)
async def embed_for_ingestion(
self, chunks: Union[Chunk, List[Chunk]]
) -> List[List[float]]:
async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[List[float]]:
if isinstance(chunks, Chunk):
chunks = [chunks]
embeddings: List[List[float]] = []
for c in chunks:
response = await self.client.embeddings(
model=self.model_name, prompt=c.content
)
response = await self.client.embeddings(model=self.model_name, prompt=c.content)
embedding = list(response["embedding"])
embeddings.append(embedding)


@ -10,9 +10,7 @@ class OpenAIEmbeddingModel(BaseEmbeddingModel):
self.client = OpenAI(api_key=api_key)
self.model_name = model_name
async def embed_for_ingestion(
self, chunks: Union[Chunk, List[Chunk]]
) -> List[List[float]]:
async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[List[float]]:
chunks = [chunks] if isinstance(chunks, Chunk) else chunks
text = [c.content for c in chunks]
response = self.client.embeddings.create(model=self.model_name, input=text)


@ -118,12 +118,8 @@ class ChunkResult(BaseModel):
# frame/transcript information as well.
frame_description = doc.additional_metadata.get("frame_description")
transcript = doc.additional_metadata.get("transcript")
if not isinstance(frame_description, dict) or not isinstance(
transcript, dict
):
logger.warning(
"Invalid frame description or transcript - not a dictionary"
)
if not isinstance(frame_description, dict) or not isinstance(transcript, dict):
logger.warning("Invalid frame description or transcript - not a dictionary")
return self.content
ts_frame = TimeSeriesData(frame_description)
ts_transcript = TimeSeriesData(transcript)


@ -77,13 +77,8 @@ class TimeSeriesData:
start_idx = max(0, start_idx)
end_idx = min(len(self.timestamps) - 1, end_idx)
logger.debug(
f"Retrieving content between {start_time:.2f}s and {end_time:.2f}s"
)
return [
(self.timestamps[i], self.contents[i])
for i in range(start_idx, end_idx + 1)
]
logger.debug(f"Retrieving content between {start_time:.2f}s and {end_time:.2f}s")
return [(self.timestamps[i], self.contents[i]) for i in range(start_idx, end_idx + 1)]
def times_for_content(self, content: str) -> List[float]:
"""Get all timestamps where this content appears"""


@ -27,10 +27,7 @@ class UnstructuredAPIParser(BaseParser):
async def split_text(self, text: str) -> List[Chunk]:
"""Split plain text into chunks"""
return [
Chunk(content=chunk, metadata={})
for chunk in self.text_splitter.split_text(text)
]
return [Chunk(content=chunk, metadata={}) for chunk in self.text_splitter.split_text(text)]
async def parse_file(
self, file: bytes, content_type: str
@ -44,6 +41,4 @@ class UnstructuredAPIParser(BaseParser):
chunking_strategy="by_title",
)
elements = loader.load()
return {}, [
Chunk(content=element.page_content, metadata={}) for element in elements
]
return {}, [Chunk(content=element.page_content, metadata={}) for element in elements]


@ -13,9 +13,7 @@ def debug_object(title, obj):
class VideoParser:
def __init__(
self, video_path: str, assemblyai_api_key: str, frame_sample_rate: int = 120
):
def __init__(self, video_path: str, assemblyai_api_key: str, frame_sample_rate: int = 120):
"""
Initialize the video parser
@ -82,9 +80,7 @@ class VideoParser:
transcript = self.get_transcript_object()
# divide by 1000 because assemblyai timestamps are in milliseconds
time_to_text = (
{u.start / 1000: u.text for u in transcript.utterances}
if transcript.utterances
else {}
{u.start / 1000: u.text for u in transcript.utterances} if transcript.utterances else {}
)
debug_object("Time to text", time_to_text)
self.transcript = TimeSeriesData(time_to_text)
@ -135,9 +131,7 @@ class VideoParser:
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{img_base64}"
},
"image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
},
],
}


@ -62,9 +62,7 @@ class DocumentService:
logger.info(f"Found {len(doc_ids)} authorized documents")
# Search chunks with vector similarity
chunks = await self.vector_store.query_similar(
query_embedding, k=k, doc_ids=doc_ids
)
chunks = await self.vector_store.query_similar(query_embedding, k=k, doc_ids=doc_ids)
logger.info(f"Found {len(chunks)} similar chunks")
# Create and return chunk results
@ -104,9 +102,7 @@ class DocumentService:
chunks = await self.retrieve_chunks(query, auth, filters, k, min_score)
documents = await self._create_document_results(auth, chunks)
chunk_contents = [
chunk.augmented_content(documents[chunk.document_id]) for chunk in chunks
]
chunk_contents = [chunk.augmented_content(documents[chunk.document_id]) for chunk in chunks]
# Generate completion
request = CompletionRequest(
@ -119,9 +115,7 @@ class DocumentService:
response = await self.completion_model.complete(request)
return response
async def ingest_text(
self, request: IngestTextRequest, auth: AuthContext
) -> Document:
async def ingest_text(self, request: IngestTextRequest, auth: AuthContext) -> Document:
"""Ingest a text document."""
if "write" not in auth.permissions:
logger.error(f"User {auth.entity_id} does not have write permission")
@ -190,9 +184,7 @@ class DocumentService:
base64.b64encode(file_content).decode(), doc.external_id, file.content_type
)
doc.storage_info = {"bucket": storage_info[0], "key": storage_info[1]}
logger.info(
f"Stored file in bucket `{storage_info[0]}` with key `{storage_info[1]}`"
)
logger.info(f"Stored file in bucket `{storage_info[0]}` with key `{storage_info[1]}`")
if not chunks:
raise ValueError("No content chunks extracted from file")
@ -304,18 +296,14 @@ class DocumentService:
# Create DocumentContent based on content type
if doc.content_type == "text/plain":
content = DocumentContent(
type="string", value=chunk.content, filename=None
)
content = DocumentContent(type="string", value=chunk.content, filename=None)
logger.debug(f"Created text content for document {doc_id}")
else:
# Generate download URL for file types
download_url = await self.storage.get_download_url(
doc.storage_info["bucket"], doc.storage_info["key"]
)
content = DocumentContent(
type="url", value=download_url, filename=doc.filename
)
content = DocumentContent(type="url", value=download_url, filename=doc.filename)
logger.debug(f"Created URL content for document {doc_id}")
results[doc_id] = DocumentResult(
score=chunk.score,


@ -57,9 +57,7 @@ class BaseStorage(ABC):
pass
@abstractmethod
async def get_download_url(
self, bucket: str, key: str, expires_in: int = 3600
) -> str:
async def get_download_url(self, bucket: str, key: str, expires_in: int = 3600) -> str:
"""
Get temporary download URL.


@ -61,9 +61,7 @@ class S3Storage(BaseStorage):
Path(temp_file_path).unlink()
else:
# File object
self.s3_client.upload_fileobj(
file, self.default_bucket, key, ExtraArgs=extra_args
)
self.s3_client.upload_fileobj(file, self.default_bucket, key, ExtraArgs=extra_args)
return self.default_bucket, key
@ -80,9 +78,7 @@ class S3Storage(BaseStorage):
extension = detect_file_type(content)
key = f"{key}{extension}"
return await self.upload_file(
file=decoded_content, key=key, content_type=content_type
)
return await self.upload_file(file=decoded_content, key=key, content_type=content_type)
except Exception as e:
logger.error(f"Error uploading base64 content to S3: {e}")
@ -97,9 +93,7 @@ class S3Storage(BaseStorage):
logger.error(f"Error downloading from S3: {e}")
raise
async def get_download_url(
self, bucket: str, key: str, expires_in: int = 3600
) -> str:
async def get_download_url(self, bucket: str, key: str, expires_in: int = 3600) -> str:
"""Generate presigned download URL."""
if not key or not bucket:
return ""


@ -5,9 +5,7 @@ from core.models.documents import DocumentChunk
class BaseVectorStore(ABC):
@abstractmethod
async def store_embeddings(
self, chunks: List[DocumentChunk]
) -> Tuple[bool, List[str]]:
async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
"""Store document chunks and their embeddings"""
pass


@ -41,9 +41,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
logger.error(f"Error initializing vector store indexes: {str(e)}")
return False
async def store_embeddings(
self, chunks: List[DocumentChunk]
) -> Tuple[bool, List[str]]:
async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
"""Store document chunks with their embeddings."""
try:
if not chunks:
@ -56,8 +54,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
# Ensure we have required fields
if not doc.get("embedding"):
logger.error(
f"Missing embedding for chunk "
f"{chunk.document_id}-{chunk.chunk_number}"
f"Missing embedding for chunk " f"{chunk.document_id}-{chunk.chunk_number}"
)
continue
documents.append(doc)
@ -65,9 +62,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
if documents:
# Use ordered=False to continue even if some inserts fail
result = await self.collection.insert_many(documents, ordered=False)
return len(result.inserted_ids) > 0, [
str(id) for id in result.inserted_ids
]
return len(result.inserted_ids) > 0, [str(id) for id in result.inserted_ids]
else:
logger.error(f"No documents to store - here is the input: {chunks}")
return False, []
@ -85,8 +80,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
"""Find similar chunks using MongoDB Atlas Vector Search."""
try:
logger.debug(
f"Searching in database {self.db.name} "
f"collection {self.collection.name}"
f"Searching in database {self.db.name} " f"collection {self.collection.name}"
)
logger.debug(f"Query vector looks like: {query_embedding}")
logger.debug(f"Doc IDs: {doc_ids}")


@ -179,18 +179,14 @@ Root Directory: {root_dir}
def main():
parser = argparse.ArgumentParser(
description="Aggregate Python files with directory structure"
)
parser = argparse.ArgumentParser(description="Aggregate Python files with directory structure")
parser.add_argument(
"--mode",
choices=["all", "core", "sdk", "test"],
default="all",
help="Which directories to process",
)
parser.add_argument(
"--output", default="aggregated_code.txt", help="Output file name"
)
parser.add_argument("--output", default="aggregated_code.txt", help="Output file name")
args = parser.parse_args()
script_name = os.path.basename(__file__)


@ -16,9 +16,7 @@ load_dotenv(find_dotenv(), override=True)
# Set up argument parser
parser = argparse.ArgumentParser(description="Setup S3 bucket and MongoDB collections")
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
parser.add_argument(
"--quiet", action="store_true", help="Only show warning and error logs"
)
parser.add_argument("--quiet", action="store_true", help="Only show warning and error logs")
args = parser.parse_args()
# Configure logging based on command line arguments
@ -155,14 +153,10 @@ def setup_mongodb():
type="vectorSearch",
)
db[CHUNKS_COLLECTION].create_search_index(model=vector_index)
LOGGER.info(
"Vector index 'vector_index' created on 'documents_chunk' collection."
)
LOGGER.info("Vector index 'vector_index' created on 'documents_chunk' collection.")
except ConnectionFailure:
LOGGER.error(
"Failed to connect to MongoDB. Check your MongoDB URI and network connection."
)
LOGGER.error("Failed to connect to MongoDB. Check your MongoDB URI and network connection.")
except OperationFailure as e:
LOGGER.error(f"MongoDB operation failed: {e}")
except Exception as e:


@ -171,16 +171,12 @@ class AsyncDataBridge:
try:
# Prepare multipart form data
files = {
"file": (filename, file_obj, content_type or "application/octet-stream")
}
files = {"file": (filename, file_obj, content_type or "application/octet-stream")}
# Add metadata
data = {"metadata": json.dumps(metadata or {})}
response = await self._request(
"POST", "ingest/file", data=data, files=files
)
response = await self._request("POST", "ingest/file", data=data, files=files)
return Document(**response)
finally:
# Close file if we opened it


@ -8,9 +8,7 @@ class Document(BaseModel):
external_id: str = Field(..., description="Unique document identifier")
content_type: str = Field(..., description="Content type of the document")
filename: Optional[str] = Field(None, description="Original filename if available")
metadata: Dict[str, Any] = Field(
default_factory=dict, description="User-defined metadata"
)
metadata: Dict[str, Any] = Field(default_factory=dict, description="User-defined metadata")
storage_info: Dict[str, str] = Field(
default_factory=dict, description="Storage-related information"
)
@ -20,18 +18,14 @@ class Document(BaseModel):
access_control: Dict[str, Any] = Field(
default_factory=dict, description="Access control information"
)
chunk_ids: List[str] = Field(
default_factory=list, description="IDs of document chunks"
)
chunk_ids: List[str] = Field(default_factory=list, description="IDs of document chunks")
class IngestTextRequest(BaseModel):
"""Request model for text ingestion"""
content: str = Field(..., description="Text content to ingest")
metadata: Dict[str, Any] = Field(
default_factory=dict, description="Optional metadata"
)
metadata: Dict[str, Any] = Field(default_factory=dict, description="Optional metadata")
class ChunkResult(BaseModel):
@ -41,22 +35,16 @@ class ChunkResult(BaseModel):
score: float = Field(..., description="Relevance score")
document_id: str = Field(..., description="Parent document ID")
chunk_number: int = Field(..., description="Chunk sequence number")
metadata: Dict[str, Any] = Field(
default_factory=dict, description="Document metadata"
)
metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata")
content_type: str = Field(..., description="Content type")
filename: Optional[str] = Field(None, description="Original filename")
download_url: Optional[str] = Field(
None, description="URL to download full document"
)
download_url: Optional[str] = Field(None, description="URL to download full document")
class DocumentContent(BaseModel):
"""Represents either a URL or content string"""
type: Literal["url", "string"] = Field(
..., description="Content type (url or string)"
)
type: Literal["url", "string"] = Field(..., description="Content type (url or string)")
value: str = Field(..., description="The actual content or URL")
filename: Optional[str] = Field(None, description="Filename when type is url")
@ -74,9 +62,7 @@ class DocumentResult(BaseModel):
score: float = Field(..., description="Relevance score")
document_id: str = Field(..., description="Document ID")
metadata: Dict[str, Any] = Field(
default_factory=dict, description="Document metadata"
)
metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata")
content: DocumentContent = Field(..., description="Document content or URL")


@ -88,9 +88,7 @@ class DataBridge:
response.raise_for_status()
return response.json()
def ingest_text(
self, content: str, metadata: Optional[Dict[str, Any]] = None
) -> Document:
def ingest_text(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> Document:
"""
Ingest a text document into DataBridge.
@ -166,9 +164,7 @@ class DataBridge:
try:
# Prepare multipart form data
files = {
"file": (filename, file_obj, content_type or "application/octet-stream")
}
files = {"file": (filename, file_obj, content_type or "application/octet-stream")}
# Add metadata
data = {"metadata": json.dumps(metadata or {})}
@ -312,9 +308,7 @@ class DataBridge:
next_page = db.list_documents(skip=10, limit=10, filters={"department": "research"})
```
"""
response = self._request(
"GET", f"documents?skip={skip}&limit={limit}&filters={filters}"
)
response = self._request("GET", f"documents?skip={skip}&limit={limit}&filters={filters}")
return [Document(**doc) for doc in response]
def get_document(self, document_id: str) -> Document: