This commit is contained in:
Adityavardhan Agrawal 2024-12-04 20:26:14 -05:00
parent 46a7c45b4e
commit 251e38828a
10 changed files with 129 additions and 282 deletions

View File

@ -1,7 +0,0 @@
JWT_SECRET_KEY=test-secret
MONGODB_URI=mongodb://localhost:27017/databridge_test
DATABRIDGE_TEST_URI=databridge://test_dev:your_test_token@localhost:8000
DATABRIDGE_HOST=localhost:8000
OPENAI_API_KEY=your_test_key
AWS_ACCESS_KEY=test
AWS_SECRET_ACCESS_KEY=test

View File

@ -1,13 +1,12 @@
import json
from datetime import datetime, UTC
from typing import List, Optional, Union, Dict, Set
from typing import List, Union
from fastapi import (
FastAPI,
Form,
HTTPException,
Depends,
Header,
APIRouter,
UploadFile
)
from fastapi.middleware.cors import CORSMiddleware
@ -28,7 +27,6 @@ from core.vector_store.mongo_vector_store import MongoDBAtlasVectorStore
from core.storage.s3_storage import S3Storage
from core.parser.unstructured_parser import UnstructuredAPIParser
from core.embedding_model.openai_embedding_model import OpenAIEmbeddingModel
from core.services.uri_service import get_uri_service
# Initialize FastAPI app
@ -160,8 +158,6 @@ async def query_documents(
"""Query documents with specified return type."""
try:
return await document_service.query(request, auth)
# except AttributeError as e:
# raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Query failed: {str(e)}")
raise HTTPException(status_code=400, detail=str(e))
@ -196,61 +192,3 @@ async def get_document(
raise e # Return the HTTPException as is
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
auth_router = APIRouter(prefix="/auth", tags=["auth"])
@auth_router.post("/developer-token")
async def create_developer_token(
dev_id: str,
app_id: Optional[str] = None,
expiry_days: int = 30,
permissions: Optional[Set[str]] = None,
auth: AuthContext = Depends(verify_token)
) -> Dict[str, str]:
"""Create a developer access URI."""
# Verify requesting user has admin permissions
if "admin" not in auth.permissions:
raise HTTPException(
status_code=403,
detail="Admin permissions required"
)
uri_service = get_uri_service()
uri = uri_service.create_developer_uri(
dev_id=dev_id,
app_id=app_id,
expiry_days=expiry_days,
permissions=permissions
)
return {"uri": uri}
@auth_router.post("/user-token")
async def create_user_token(
user_id: str,
expiry_days: int = 30,
permissions: Optional[Set[str]] = None,
auth: AuthContext = Depends(verify_token)
) -> Dict[str, str]:
"""Create a user access URI."""
# Verify requesting user has admin permissions
if "admin" not in auth.permissions:
raise HTTPException(
status_code=403,
detail="Admin permissions required"
)
uri_service = get_uri_service()
uri = uri_service.create_user_uri(
user_id=user_id,
expiry_days=expiry_days,
permissions=permissions
)
return {"uri": uri}
# Add to your main FastAPI app
app.include_router(auth_router)

View File

@ -1,4 +1,4 @@
from typing import Optional, Dict, Any
from typing import Dict, Any
from pydantic import Field
from pydantic_settings import BaseSettings
from functools import lru_cache
@ -23,12 +23,15 @@ class Settings(BaseSettings):
UNSTRUCTURED_API_KEY: str = Field(..., env="UNSTRUCTURED_API_KEY")
# Optional API keys for alternative models
ANTHROPIC_API_KEY: Optional[str] = Field(None, env="ANTHROPIC_API_KEY")
COHERE_API_KEY: Optional[str] = Field(None, env="COHERE_API_KEY")
VOYAGE_API_KEY: Optional[str] = Field(None, env="VOYAGE_API_KEY")
ANTHROPIC_API_KEY: str | None = Field(None, env="ANTHROPIC_API_KEY")
COHERE_API_KEY: str | None = Field(None, env="COHERE_API_KEY")
VOYAGE_API_KEY: str | None = Field(None, env="VOYAGE_API_KEY")
# Model settings
EMBEDDING_MODEL: str = Field("text-embedding-3-small", env="EMBEDDING_MODEL")
EMBEDDING_MODEL: str = Field(
"text-embedding-3-small",
env="EMBEDDING_MODEL"
)
# Document processing settings
CHUNK_SIZE: int = Field(1000, env="CHUNK_SIZE")

View File

@ -1,6 +1,6 @@
from typing import List, Optional, Dict, Any
import logging
from datetime import UTC, datetime
import logging
from typing import Dict, List, Optional, Any
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import ReturnDocument
@ -82,7 +82,6 @@ class MongoDatabase(BaseDatabase):
except PyMongoError as e:
logger.error(f"Error retrieving document metadata: {str(e)}")
raise e
# return None
async def get_documents(
self,
@ -194,7 +193,8 @@ class MongoDatabase(BaseDatabase):
# Check owner access
owner = doc.get("owner", {})
if (owner.get("type") == auth.entity_type and owner.get("id") == auth.entity_id):
if (owner.get("type") == auth.entity_type and
owner.get("id") == auth.entity_id):
return True
# Check permission-specific access

View File

@ -1,8 +1,7 @@
import base64
from collections import defaultdict
from typing import Any, Dict, List, Union, Optional
from typing import Any, Dict, List, Union
import logging
from fastapi import UploadFile
import base64
from core.database.base_database import BaseDatabase
from core.embedding_model.base_embedding_model import BaseEmbeddingModel

View File

@ -1,15 +1,14 @@
# core/auth/uri_service.py
from typing import Optional, Set
from datetime import datetime, timedelta, UTC
import jwt
from functools import lru_cache
from ..models.auth import EntityType, AuthContext
from ..config import Settings
# Currently unused. Will be used for uri generation.
class URIService:
"""Service for creating and validating DataBridge URIs with authentication tokens"""
"""Service for creating and validating DataBridge URIs with authentication tokens."""
def __init__(self, settings: Settings):
self.secret_key = settings.JWT_SECRET_KEY
@ -112,12 +111,3 @@ class URIService:
except (jwt.InvalidTokenError, IndexError, ValueError):
return None
@lru_cache()
def get_uri_service(settings: Settings = None) -> URIService:
"""Get cached URIService instance."""
if settings is None:
from ..config import get_settings
settings = get_settings()
return URIService(settings)

View File

@ -1,37 +0,0 @@
from pathlib import Path
import sys
import pytest
from typing import Generator
import os
from dotenv import load_dotenv
root_dir = Path(__file__).parent.parent.parent
sdk_path = str(root_dir / "sdks" / "python")
core_path = str(root_dir)
sys.path.extend([sdk_path, core_path])
from core.config import get_settings
from databridge import DataBridge
# Load test environment variables
load_dotenv(".env.test")
@pytest.fixture(scope="session")
def settings():
"""Get test settings"""
return get_settings()
@pytest.fixture
async def db() -> Generator[DataBridge, None, None]:
"""DataBridge client fixture"""
uri = os.getenv("DATABRIDGE_TEST_URI")
if not uri:
raise ValueError("DATABRIDGE_TEST_URI not set")
client = DataBridge(uri)
try:
yield client
finally:
await client.close()

View File

@ -44,14 +44,7 @@ def setup_test_environment(event_loop):
# Create a small test PDF if it doesn't exist
pdf_file = TEST_DATA_DIR / "test.pdf"
if not pdf_file.exists():
# Create a minimal PDF for testing
try:
from reportlab.pdfgen import canvas
c = canvas.Canvas(str(pdf_file))
c.drawString(100, 750, "Test PDF Document")
c.save()
except ImportError:
pytest.skip("reportlab not installed, skipping PDF tests")
pytest.skip("PDF file not available, skipping PDF tests")
def create_test_token(
@ -98,14 +91,20 @@ async def test_app(event_loop: asyncio.AbstractEventLoop) -> FastAPI:
@pytest.fixture
async def client(test_app: FastAPI, event_loop: asyncio.AbstractEventLoop) -> AsyncGenerator[AsyncClient, None]:
async def client(
test_app: FastAPI,
event_loop: asyncio.AbstractEventLoop
) -> AsyncGenerator[AsyncClient, None]:
"""Create async test client"""
async with AsyncClient(app=test_app, base_url="http://test") as client:
yield client
@pytest.mark.asyncio
async def test_ingest_text_document(client: AsyncClient, content: str = "Test content for document ingestion"):
async def test_ingest_text_document(
client: AsyncClient,
content: str = "Test content for document ingestion"
):
"""Test ingesting a text document"""
headers = create_auth_header()
@ -264,7 +263,10 @@ async def test_auth_insufficient_permissions(client: AsyncClient):
async def test_query_chunks(client: AsyncClient):
"""Test querying document chunks"""
# First ingest a document to query
doc_id = await test_ingest_text_document(client, content="The quick brown fox jumps over the lazy dog")
doc_id = await test_ingest_text_document(
client,
content="The quick brown fox jumps over the lazy dog"
)
headers = create_auth_header()
# Sleep to allow time for document to be indexed
@ -293,7 +295,15 @@ async def test_query_chunks(client: AsyncClient):
async def test_query_documents(client: AsyncClient):
"""Test querying for full documents"""
# First ingest a document to query
doc_id = await test_ingest_text_document(client, content="Headaches can significantly impact daily life and wellbeing. Common triggers include stress, dehydration, and poor sleep habits. While over-the-counter pain relievers may provide temporary relief, it's important to identify and address the root causes. Maintaining good health through proper nutrition, regular exercise, and stress management can help prevent chronic headaches.")
content = (
"Headaches can significantly impact daily life and wellbeing. "
"Common triggers include stress, dehydration, and poor sleep habits. "
"While over-the-counter pain relievers may provide temporary relief, "
"it's important to identify and address the root causes. "
"Maintaining good health through proper nutrition, regular exercise, "
"and stress management can help prevent chronic headaches."
)
doc_id = await test_ingest_text_document(client, content=content)
headers = create_auth_header()
response = await client.post(

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, List, Optional, Tuple
from core.models.auth import AuthContext
from core.models.documents import DocumentChunk

View File

@ -1,12 +1,10 @@
import json
from typing import List, Dict, Any, Optional
from typing import List, Optional
import logging
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo.errors import PyMongoError
from .base_vector_store import BaseVectorStore
from core.models.documents import DocumentChunk
from core.models.auth import AuthContext, EntityType
logger = logging.getLogger(__name__)
@ -55,13 +53,18 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
doc = chunk.model_dump()
# Ensure we have required fields
if not doc.get('embedding'):
logger.error(f"Missing embedding for chunk {chunk.document_id}-{chunk.chunk_number}")
logger.error(
f"Missing embedding for chunk "
f"{chunk.document_id}-{chunk.chunk_number}"
)
continue
documents.append(doc)
if documents:
# Use ordered=False to continue even if some inserts fail
result = await self.collection.insert_many(documents, ordered=False)
result = await self.collection.insert_many(
documents, ordered=False
)
return len(result.inserted_ids) > 0, result
return False, None
@ -77,7 +80,10 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
) -> List[DocumentChunk]:
"""Find similar chunks using MongoDB Atlas Vector Search."""
try:
logger.debug(f"Searching in database {self.db.name} collection {self.collection.name}")
logger.debug(
f"Searching in database {self.db.name} "
f"collection {self.collection.name}"
)
logger.debug(f"Query vector looks like: {query_embedding}")
logger.debug(f"Doc IDs: {doc_ids}")
logger.debug(f"K is: {k}")
@ -90,7 +96,7 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
"index": self.index_name,
"path": "embedding",
"queryVector": query_embedding,
"numCandidates": k*40, # Get more candidates for better results
"numCandidates": k*40, # Get more candidates
"limit": k,
"filter": {"document_id": {"$in": doc_ids}} if doc_ids else {}
}
@ -128,58 +134,3 @@ class MongoDBAtlasVectorStore(BaseVectorStore):
logger.error(f"MongoDB error: {e._message}")
logger.error(f"Error querying similar chunks: {str(e)}")
raise e
def _build_access_filter(self, auth: AuthContext) -> Dict[str, Any]:
"""Build MongoDB filter for access control."""
base_filter = {
"$or": [
{"owner.id": auth.entity_id},
{"access_control.readers": auth.entity_id},
{"access_control.writers": auth.entity_id},
{"access_control.admins": auth.entity_id}
]
}
if auth.entity_type == EntityType.DEVELOPER and auth.app_id:
# Add app-specific access for developers
base_filter["$or"].append(
{"access_control.app_access": auth.app_id}
)
return base_filter
def _build_metadata_filter(self, filters: Dict[str, Any]) -> Dict[str, Any]:
"""Build MongoDB filter for metadata fields."""
if not filters:
return {}
return filters
metadata_filter = {}
for key, value in filters.items():
metadata_key = f"metadata.{key}"
if isinstance(value, (str, int, float, bool)):
metadata_filter[metadata_key] = value
elif isinstance(value, list):
metadata_filter[metadata_key] = {"$in": value}
elif isinstance(value, dict):
valid_ops = {
"gt": "$gt",
"gte": "$gte",
"lt": "$lt",
"lte": "$lte",
"ne": "$ne"
}
mongo_ops = {}
for op, val in value.items():
if op not in valid_ops:
logger.warning(f"Skipping invalid operator: {op}")
continue
mongo_ops[valid_ops[op]] = val
if mongo_ops:
metadata_filter[metadata_key] = mongo_ops
else:
logger.warning(f"Skipping unsupported filter value type for key {key}")
return metadata_filter