mirror of https://github.com/james-m-jordan/morphik-core.git, synced 2025-05-09 19:32:38 +00:00
reformat files
parent b54cdb6e0c
commit 0e4a43645a
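The changes below follow a single pattern: statements that had been wrapped across several lines purely for line length are collapsed onto one line wherever they fit within roughly 100 characters. As an illustrative sketch only (the formatter and its exact line-length setting are not recorded in this commit, and the `Settings` class and `require_mongodb_*` functions here are hypothetical stand-ins, not code from the repository):

# Hypothetical before/after shape of the reformatting applied throughout this commit.
class Settings:
    VECTOR_STORE_PROVIDER = "mongodb"


def require_mongodb_wrapped(settings: Settings) -> None:
    # Before: the call is wrapped because it once exceeded the old line limit.
    if settings.VECTOR_STORE_PROVIDER != "mongodb":
        raise ValueError(
            f"Unsupported vector store provider: {settings.VECTOR_STORE_PROVIDER}"
        )


def require_mongodb_collapsed(settings: Settings) -> None:
    # After: the same statement on one line, since it fits within ~100 characters.
    if settings.VECTOR_STORE_PROVIDER != "mongodb":
        raise ValueError(f"Unsupported vector store provider: {settings.VECTOR_STORE_PROVIDER}")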
core/api.py (24 changed lines)
@@ -62,9 +62,7 @@ match settings.VECTOR_STORE_PROVIDER:
             index_name=settings.VECTOR_INDEX_NAME,
         )
     case _:
-        raise ValueError(
-            f"Unsupported vector store provider: {settings.VECTOR_STORE_PROVIDER}"
-        )
+        raise ValueError(f"Unsupported vector store provider: {settings.VECTOR_STORE_PROVIDER}")

 # Initialize storage
 match settings.STORAGE_PROVIDER:
@@ -110,9 +108,7 @@ match settings.EMBEDDING_PROVIDER:
             base_url=settings.OLLAMA_BASE_URL,
         )
     case _:
-        raise ValueError(
-            f"Unsupported embedding provider: {settings.EMBEDDING_PROVIDER}"
-        )
+        raise ValueError(f"Unsupported embedding provider: {settings.EMBEDDING_PROVIDER}")

 # Initialize completion model
 match settings.COMPLETION_PROVIDER:
@@ -126,9 +122,7 @@ match settings.COMPLETION_PROVIDER:
             model_name=settings.COMPLETION_MODEL,
         )
     case _:
-        raise ValueError(
-            f"Unsupported completion provider: {settings.COMPLETION_PROVIDER}"
-        )
+        raise ValueError(f"Unsupported completion provider: {settings.COMPLETION_PROVIDER}")

 # Initialize document service with configured components
 document_service = DocumentService(
@@ -154,9 +148,7 @@ async def verify_token(authorization: str = Header(None)) -> AuthContext:
         raise HTTPException(status_code=401, detail="Invalid authorization header")

     token = authorization[7:]  # Remove "Bearer "
-    payload = jwt.decode(
-        token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM]
-    )
+    payload = jwt.decode(token, settings.JWT_SECRET_KEY, algorithms=[settings.JWT_ALGORITHM])

     if datetime.fromtimestamp(payload["exp"], UTC) < datetime.now(UTC):
         raise HTTPException(status_code=401, detail="Token expired")
@@ -200,9 +192,7 @@ async def ingest_file(


 @app.post("/retrieve/chunks", response_model=List[ChunkResult])
-async def retrieve_chunks(
-    request: RetrieveRequest, auth: AuthContext = Depends(verify_token)
-):
+async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
     """Retrieve relevant chunks."""
     return await document_service.retrieve_chunks(
         request.query, auth, request.filters, request.k, request.min_score
@@ -210,9 +200,7 @@ async def retrieve_chunks(


 @app.post("/retrieve/docs", response_model=List[DocumentResult])
-async def retrieve_documents(
-    request: RetrieveRequest, auth: AuthContext = Depends(verify_token)
-):
+async def retrieve_documents(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)):
     """Retrieve relevant documents."""
     return await document_service.retrieve_docs(
         request.query, auth, request.filters, request.k, request.min_score

@@ -17,9 +17,7 @@ class BaseDatabase(ABC):
         pass

     @abstractmethod
-    async def get_document(
-        self, document_id: str, auth: AuthContext
-    ) -> Optional[Document]:
+    async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]:
         """
         Retrieve document metadata by ID if user has access.
         Returns: Document if found and accessible, None otherwise

@@ -60,9 +60,7 @@ class MongoDatabase(BaseDatabase):
             logger.error(f"Error storing document metadata: {str(e)}")
             return False

-    async def get_document(
-        self, document_id: str, auth: AuthContext
-    ) -> Optional[Document]:
+    async def get_document(self, document_id: str, auth: AuthContext) -> Optional[Document]:
         """Retrieve document metadata by ID if user has access."""
         try:
             # Build access filter
@@ -92,11 +90,7 @@ class MongoDatabase(BaseDatabase):
             # Build query
             auth_filter = self._build_access_filter(auth)
             metadata_filter = self._build_metadata_filter(filters)
-            query = (
-                {"$and": [auth_filter, metadata_filter]}
-                if metadata_filter
-                else auth_filter
-            )
+            query = {"$and": [auth_filter, metadata_filter]} if metadata_filter else auth_filter

             # Execute paginated query
             cursor = self.collection.find(query).skip(skip).limit(limit)
@@ -157,9 +151,7 @@ class MongoDatabase(BaseDatabase):
             # Build query
             auth_filter = self._build_access_filter(auth)
             metadata_filter = self._build_metadata_filter(filters)
-            query = (
-                {"$and": [auth_filter, metadata_filter]} if metadata_filter else auth_filter
-            )
+            query = {"$and": [auth_filter, metadata_filter]} if metadata_filter else auth_filter

             # Get matching document IDs
             cursor = self.collection.find(query, {"external_id": 1})
@@ -183,10 +175,7 @@ class MongoDatabase(BaseDatabase):

         # Check owner access
         owner = doc.get("owner", {})
-        if (
-            owner.get("type") == auth.entity_type
-            and owner.get("id") == auth.entity_id
-        ):
+        if owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id:
             return True

         # Check permission-specific access

@@ -6,9 +6,7 @@ from core.models.documents import Chunk

 class BaseEmbeddingModel(ABC):
     @abstractmethod
-    async def embed_for_ingestion(
-        self, chunks: Union[Chunk, List[Chunk]]
-    ) -> List[List[float]]:
+    async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[List[float]]:
         """Generate embeddings for input text"""
         pass

@@ -9,17 +9,13 @@ class OllamaEmbeddingModel(BaseEmbeddingModel):
         self.model_name = model_name
         self.client = AsyncClient(host=base_url)

-    async def embed_for_ingestion(
-        self, chunks: Union[Chunk, List[Chunk]]
-    ) -> List[List[float]]:
+    async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[List[float]]:
         if isinstance(chunks, Chunk):
             chunks = [chunks]

         embeddings: List[List[float]] = []
         for c in chunks:
-            response = await self.client.embeddings(
-                model=self.model_name, prompt=c.content
-            )
+            response = await self.client.embeddings(model=self.model_name, prompt=c.content)
             embedding = list(response["embedding"])
             embeddings.append(embedding)

@@ -10,9 +10,7 @@ class OpenAIEmbeddingModel(BaseEmbeddingModel):
         self.client = OpenAI(api_key=api_key)
         self.model_name = model_name

-    async def embed_for_ingestion(
-        self, chunks: Union[Chunk, List[Chunk]]
-    ) -> List[List[float]]:
+    async def embed_for_ingestion(self, chunks: Union[Chunk, List[Chunk]]) -> List[List[float]]:
         chunks = [chunks] if isinstance(chunks, Chunk) else chunks
         text = [c.content for c in chunks]
         response = self.client.embeddings.create(model=self.model_name, input=text)

@@ -118,12 +118,8 @@ class ChunkResult(BaseModel):
         # frame/transcript information as well.
         frame_description = doc.additional_metadata.get("frame_description")
         transcript = doc.additional_metadata.get("transcript")
-        if not isinstance(frame_description, dict) or not isinstance(
-            transcript, dict
-        ):
-            logger.warning(
-                "Invalid frame description or transcript - not a dictionary"
-            )
+        if not isinstance(frame_description, dict) or not isinstance(transcript, dict):
+            logger.warning("Invalid frame description or transcript - not a dictionary")
             return self.content
         ts_frame = TimeSeriesData(frame_description)
         ts_transcript = TimeSeriesData(transcript)

@@ -77,13 +77,8 @@ class TimeSeriesData:
         start_idx = max(0, start_idx)
         end_idx = min(len(self.timestamps) - 1, end_idx)

-        logger.debug(
-            f"Retrieving content between {start_time:.2f}s and {end_time:.2f}s"
-        )
-        return [
-            (self.timestamps[i], self.contents[i])
-            for i in range(start_idx, end_idx + 1)
-        ]
+        logger.debug(f"Retrieving content between {start_time:.2f}s and {end_time:.2f}s")
+        return [(self.timestamps[i], self.contents[i]) for i in range(start_idx, end_idx + 1)]

     def times_for_content(self, content: str) -> List[float]:
         """Get all timestamps where this content appears"""

@@ -27,10 +27,7 @@ class UnstructuredAPIParser(BaseParser):

     async def split_text(self, text: str) -> List[Chunk]:
         """Split plain text into chunks"""
-        return [
-            Chunk(content=chunk, metadata={})
-            for chunk in self.text_splitter.split_text(text)
-        ]
+        return [Chunk(content=chunk, metadata={}) for chunk in self.text_splitter.split_text(text)]

     async def parse_file(
         self, file: bytes, content_type: str
@@ -44,6 +41,4 @@ class UnstructuredAPIParser(BaseParser):
             chunking_strategy="by_title",
         )
         elements = loader.load()
-        return {}, [
-            Chunk(content=element.page_content, metadata={}) for element in elements
-        ]
+        return {}, [Chunk(content=element.page_content, metadata={}) for element in elements]

@@ -13,9 +13,7 @@ def debug_object(title, obj):


 class VideoParser:
-    def __init__(
-        self, video_path: str, assemblyai_api_key: str, frame_sample_rate: int = 120
-    ):
+    def __init__(self, video_path: str, assemblyai_api_key: str, frame_sample_rate: int = 120):
         """
         Initialize the video parser

@@ -82,9 +80,7 @@ class VideoParser:
         transcript = self.get_transcript_object()
         # divide by 1000 because assemblyai timestamps are in milliseconds
         time_to_text = (
-            {u.start / 1000: u.text for u in transcript.utterances}
-            if transcript.utterances
-            else {}
+            {u.start / 1000: u.text for u in transcript.utterances} if transcript.utterances else {}
         )
         debug_object("Time to text", time_to_text)
         self.transcript = TimeSeriesData(time_to_text)
@@ -135,9 +131,7 @@ class VideoParser:
                     },
                     {
                         "type": "image_url",
-                        "image_url": {
-                            "url": f"data:image/jpeg;base64,{img_base64}"
-                        },
+                        "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
                     },
                 ],
             }

@@ -62,9 +62,7 @@ class DocumentService:
         logger.info(f"Found {len(doc_ids)} authorized documents")

         # Search chunks with vector similarity
-        chunks = await self.vector_store.query_similar(
-            query_embedding, k=k, doc_ids=doc_ids
-        )
+        chunks = await self.vector_store.query_similar(query_embedding, k=k, doc_ids=doc_ids)
         logger.info(f"Found {len(chunks)} similar chunks")

         # Create and return chunk results
@@ -104,9 +102,7 @@ class DocumentService:
         chunks = await self.retrieve_chunks(query, auth, filters, k, min_score)
         documents = await self._create_document_results(auth, chunks)

-        chunk_contents = [
-            chunk.augmented_content(documents[chunk.document_id]) for chunk in chunks
-        ]
+        chunk_contents = [chunk.augmented_content(documents[chunk.document_id]) for chunk in chunks]

         # Generate completion
         request = CompletionRequest(
@@ -119,9 +115,7 @@ class DocumentService:
         response = await self.completion_model.complete(request)
         return response

-    async def ingest_text(
-        self, request: IngestTextRequest, auth: AuthContext
-    ) -> Document:
+    async def ingest_text(self, request: IngestTextRequest, auth: AuthContext) -> Document:
         """Ingest a text document."""
         if "write" not in auth.permissions:
             logger.error(f"User {auth.entity_id} does not have write permission")
@@ -190,9 +184,7 @@ class DocumentService:
             base64.b64encode(file_content).decode(), doc.external_id, file.content_type
         )
         doc.storage_info = {"bucket": storage_info[0], "key": storage_info[1]}
-        logger.info(
-            f"Stored file in bucket `{storage_info[0]}` with key `{storage_info[1]}`"
-        )
+        logger.info(f"Stored file in bucket `{storage_info[0]}` with key `{storage_info[1]}`")

         if not chunks:
             raise ValueError("No content chunks extracted from file")
@@ -304,18 +296,14 @@ class DocumentService:

             # Create DocumentContent based on content type
             if doc.content_type == "text/plain":
-                content = DocumentContent(
-                    type="string", value=chunk.content, filename=None
-                )
+                content = DocumentContent(type="string", value=chunk.content, filename=None)
                 logger.debug(f"Created text content for document {doc_id}")
             else:
                 # Generate download URL for file types
                 download_url = await self.storage.get_download_url(
                     doc.storage_info["bucket"], doc.storage_info["key"]
                 )
-                content = DocumentContent(
-                    type="url", value=download_url, filename=doc.filename
-                )
+                content = DocumentContent(type="url", value=download_url, filename=doc.filename)
                 logger.debug(f"Created URL content for document {doc_id}")
             results[doc_id] = DocumentResult(
                 score=chunk.score,

@@ -57,9 +57,7 @@ class BaseStorage(ABC):
         pass

     @abstractmethod
-    async def get_download_url(
-        self, bucket: str, key: str, expires_in: int = 3600
-    ) -> str:
+    async def get_download_url(self, bucket: str, key: str, expires_in: int = 3600) -> str:
         """
         Get temporary download URL.

@@ -61,9 +61,7 @@ class S3Storage(BaseStorage):
             Path(temp_file_path).unlink()
         else:
             # File object
-            self.s3_client.upload_fileobj(
-                file, self.default_bucket, key, ExtraArgs=extra_args
-            )
+            self.s3_client.upload_fileobj(file, self.default_bucket, key, ExtraArgs=extra_args)

         return self.default_bucket, key

@@ -80,9 +78,7 @@ class S3Storage(BaseStorage):
             extension = detect_file_type(content)
             key = f"{key}{extension}"

-            return await self.upload_file(
-                file=decoded_content, key=key, content_type=content_type
-            )
+            return await self.upload_file(file=decoded_content, key=key, content_type=content_type)

         except Exception as e:
             logger.error(f"Error uploading base64 content to S3: {e}")
@@ -97,9 +93,7 @@ class S3Storage(BaseStorage):
             logger.error(f"Error downloading from S3: {e}")
             raise

-    async def get_download_url(
-        self, bucket: str, key: str, expires_in: int = 3600
-    ) -> str:
+    async def get_download_url(self, bucket: str, key: str, expires_in: int = 3600) -> str:
         """Generate presigned download URL."""
         if not key or not bucket:
             return ""

@@ -5,9 +5,7 @@ from core.models.documents import DocumentChunk

 class BaseVectorStore(ABC):
     @abstractmethod
-    async def store_embeddings(
-        self, chunks: List[DocumentChunk]
-    ) -> Tuple[bool, List[str]]:
+    async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
         """Store document chunks and their embeddings"""
         pass

@@ -41,9 +41,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
             logger.error(f"Error initializing vector store indexes: {str(e)}")
             return False

-    async def store_embeddings(
-        self, chunks: List[DocumentChunk]
-    ) -> Tuple[bool, List[str]]:
+    async def store_embeddings(self, chunks: List[DocumentChunk]) -> Tuple[bool, List[str]]:
         """Store document chunks with their embeddings."""
         try:
             if not chunks:
@@ -56,8 +54,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
                 # Ensure we have required fields
                 if not doc.get("embedding"):
                     logger.error(
-                        f"Missing embedding for chunk "
-                        f"{chunk.document_id}-{chunk.chunk_number}"
+                        f"Missing embedding for chunk " f"{chunk.document_id}-{chunk.chunk_number}"
                     )
                     continue
                 documents.append(doc)
@@ -65,9 +62,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
             if documents:
                 # Use ordered=False to continue even if some inserts fail
                 result = await self.collection.insert_many(documents, ordered=False)
-                return len(result.inserted_ids) > 0, [
-                    str(id) for id in result.inserted_ids
-                ]
+                return len(result.inserted_ids) > 0, [str(id) for id in result.inserted_ids]
             else:
                 logger.error(f"No documents to store - here is the input: {chunks}")
                 return False, []
@@ -85,8 +80,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
         """Find similar chunks using MongoDB Atlas Vector Search."""
         try:
             logger.debug(
-                f"Searching in database {self.db.name} "
-                f"collection {self.collection.name}"
+                f"Searching in database {self.db.name} " f"collection {self.collection.name}"
             )
             logger.debug(f"Query vector looks like: {query_embedding}")
             logger.debug(f"Doc IDs: {doc_ids}")

@@ -179,18 +179,14 @@ Root Directory: {root_dir}


 def main():
-    parser = argparse.ArgumentParser(
-        description="Aggregate Python files with directory structure"
-    )
+    parser = argparse.ArgumentParser(description="Aggregate Python files with directory structure")
     parser.add_argument(
         "--mode",
         choices=["all", "core", "sdk", "test"],
         default="all",
         help="Which directories to process",
     )
-    parser.add_argument(
-        "--output", default="aggregated_code.txt", help="Output file name"
-    )
+    parser.add_argument("--output", default="aggregated_code.txt", help="Output file name")
     args = parser.parse_args()

     script_name = os.path.basename(__file__)

@@ -16,9 +16,7 @@ load_dotenv(find_dotenv(), override=True)
 # Set up argument parser
 parser = argparse.ArgumentParser(description="Setup S3 bucket and MongoDB collections")
 parser.add_argument("--debug", action="store_true", help="Enable debug logging")
-parser.add_argument(
-    "--quiet", action="store_true", help="Only show warning and error logs"
-)
+parser.add_argument("--quiet", action="store_true", help="Only show warning and error logs")
 args = parser.parse_args()

 # Configure logging based on command line arguments
@@ -155,14 +153,10 @@ def setup_mongodb():
             type="vectorSearch",
         )
         db[CHUNKS_COLLECTION].create_search_index(model=vector_index)
-        LOGGER.info(
-            "Vector index 'vector_index' created on 'documents_chunk' collection."
-        )
+        LOGGER.info("Vector index 'vector_index' created on 'documents_chunk' collection.")

     except ConnectionFailure:
-        LOGGER.error(
-            "Failed to connect to MongoDB. Check your MongoDB URI and network connection."
-        )
+        LOGGER.error("Failed to connect to MongoDB. Check your MongoDB URI and network connection.")
     except OperationFailure as e:
         LOGGER.error(f"MongoDB operation failed: {e}")
     except Exception as e:

@@ -171,16 +171,12 @@ class AsyncDataBridge:

         try:
             # Prepare multipart form data
-            files = {
-                "file": (filename, file_obj, content_type or "application/octet-stream")
-            }
+            files = {"file": (filename, file_obj, content_type or "application/octet-stream")}

             # Add metadata
             data = {"metadata": json.dumps(metadata or {})}

-            response = await self._request(
-                "POST", "ingest/file", data=data, files=files
-            )
+            response = await self._request("POST", "ingest/file", data=data, files=files)
             return Document(**response)
         finally:
             # Close file if we opened it

@@ -8,9 +8,7 @@ class Document(BaseModel):
     external_id: str = Field(..., description="Unique document identifier")
     content_type: str = Field(..., description="Content type of the document")
     filename: Optional[str] = Field(None, description="Original filename if available")
-    metadata: Dict[str, Any] = Field(
-        default_factory=dict, description="User-defined metadata"
-    )
+    metadata: Dict[str, Any] = Field(default_factory=dict, description="User-defined metadata")
     storage_info: Dict[str, str] = Field(
         default_factory=dict, description="Storage-related information"
     )
@@ -20,18 +18,14 @@ class Document(BaseModel):
     access_control: Dict[str, Any] = Field(
         default_factory=dict, description="Access control information"
     )
-    chunk_ids: List[str] = Field(
-        default_factory=list, description="IDs of document chunks"
-    )
+    chunk_ids: List[str] = Field(default_factory=list, description="IDs of document chunks")


 class IngestTextRequest(BaseModel):
     """Request model for text ingestion"""

     content: str = Field(..., description="Text content to ingest")
-    metadata: Dict[str, Any] = Field(
-        default_factory=dict, description="Optional metadata"
-    )
+    metadata: Dict[str, Any] = Field(default_factory=dict, description="Optional metadata")


 class ChunkResult(BaseModel):
@@ -41,22 +35,16 @@ class ChunkResult(BaseModel):
     score: float = Field(..., description="Relevance score")
     document_id: str = Field(..., description="Parent document ID")
     chunk_number: int = Field(..., description="Chunk sequence number")
-    metadata: Dict[str, Any] = Field(
-        default_factory=dict, description="Document metadata"
-    )
+    metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata")
     content_type: str = Field(..., description="Content type")
     filename: Optional[str] = Field(None, description="Original filename")
-    download_url: Optional[str] = Field(
-        None, description="URL to download full document"
-    )
+    download_url: Optional[str] = Field(None, description="URL to download full document")


 class DocumentContent(BaseModel):
     """Represents either a URL or content string"""

-    type: Literal["url", "string"] = Field(
-        ..., description="Content type (url or string)"
-    )
+    type: Literal["url", "string"] = Field(..., description="Content type (url or string)")
     value: str = Field(..., description="The actual content or URL")
     filename: Optional[str] = Field(None, description="Filename when type is url")

@@ -74,9 +62,7 @@ class DocumentResult(BaseModel):

     score: float = Field(..., description="Relevance score")
     document_id: str = Field(..., description="Document ID")
-    metadata: Dict[str, Any] = Field(
-        default_factory=dict, description="Document metadata"
-    )
+    metadata: Dict[str, Any] = Field(default_factory=dict, description="Document metadata")
     content: DocumentContent = Field(..., description="Document content or URL")


@@ -88,9 +88,7 @@ class DataBridge:
         response.raise_for_status()
         return response.json()

-    def ingest_text(
-        self, content: str, metadata: Optional[Dict[str, Any]] = None
-    ) -> Document:
+    def ingest_text(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> Document:
         """
         Ingest a text document into DataBridge.

@@ -166,9 +164,7 @@ class DataBridge:

         try:
             # Prepare multipart form data
-            files = {
-                "file": (filename, file_obj, content_type or "application/octet-stream")
-            }
+            files = {"file": (filename, file_obj, content_type or "application/octet-stream")}

             # Add metadata
             data = {"metadata": json.dumps(metadata or {})}
@@ -312,9 +308,7 @@ class DataBridge:
             next_page = db.list_documents(skip=10, limit=10, filters={"department": "research"})
             ```
         """
-        response = self._request(
-            "GET", f"documents?skip={skip}&limit={limit}&filters={filters}"
-        )
+        response = self._request("GET", f"documents?skip={skip}&limit={limit}&filters={filters}")
         return [Document(**doc) for doc in response]

     def get_document(self, document_id: str) -> Document: