From adc0b2dbb8874b1a7ca179a8828627e9de88c5ca Mon Sep 17 00:00:00 2001 From: Adityavardhan Agrawal Date: Tue, 18 Mar 2025 23:27:53 -0400 Subject: [PATCH] Add batch ingestion (#55) --- core/api.py | 117 ++++++++++- core/models/request.py | 10 +- core/services/document_service.py | 44 ++++- core/tests/integration/test_api.py | 299 ++++++++++++++++++++++++++++- sdks/python/databridge/async_.py | 166 ++++++++++++---- sdks/python/databridge/models.py | 3 +- sdks/python/databridge/sync.py | 130 ++++++++++++- shell.py | 114 +++++++++++ 8 files changed, 837 insertions(+), 46 deletions(-) diff --git a/core/api.py b/core/api.py index fb6dd5f..704f724 100644 --- a/core/api.py +++ b/core/api.py @@ -1,16 +1,17 @@ +import asyncio import json from datetime import datetime, UTC, timedelta from pathlib import Path import sys from typing import Any, Dict, List, Optional -from fastapi import FastAPI, Form, HTTPException, Depends, Header, UploadFile +from fastapi import FastAPI, Form, HTTPException, Depends, Header, UploadFile, File from fastapi.middleware.cors import CORSMiddleware import jwt import logging from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from core.completion.openai_completion import OpenAICompletionModel from core.embedding.ollama_embedding_model import OllamaEmbeddingModel -from core.models.request import RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest +from core.models.request import RetrieveRequest, CompletionQueryRequest, IngestTextRequest, CreateGraphRequest, BatchIngestResponse from core.models.completion import ChunkSource, CompletionResponse from core.models.documents import Document, DocumentResult, ChunkResult from core.models.graph import Graph @@ -360,6 +361,118 @@ async def ingest_file( raise HTTPException(status_code=403, detail=str(e)) +@app.post("/ingest/files", response_model=BatchIngestResponse) +async def batch_ingest_files( + files: List[UploadFile] = File(...), + metadata: str = Form("{}"), + rules: str = Form("[]"), + use_colpali: Optional[bool] = Form(None), + parallel: bool = Form(True), + auth: AuthContext = Depends(verify_token), +) -> BatchIngestResponse: + """ + Batch ingest multiple files. + + Args: + files: List of files to ingest + metadata: JSON string of metadata (either a single dict or list of dicts) + rules: JSON string of rules list. 
Can be either: + - A single list of rules to apply to all files + - A list of rule lists, one per file + use_colpali: Whether to use ColPali-style embedding + parallel: Whether to process files in parallel + auth: Authentication context + + Returns: + BatchIngestResponse containing: + - documents: List of successfully ingested documents + - errors: List of errors encountered during ingestion + """ + if not files: + raise HTTPException( + status_code=400, + detail="No files provided for batch ingestion" + ) + + try: + metadata_value = json.loads(metadata) + rules_list = json.loads(rules) + except json.JSONDecodeError as e: + raise HTTPException(status_code=400, detail=f"Invalid JSON: {str(e)}") + + # Validate metadata if it's a list + if isinstance(metadata_value, list) and len(metadata_value) != len(files): + raise HTTPException( + status_code=400, + detail=f"Number of metadata items ({len(metadata_value)}) must match number of files ({len(files)})" + ) + + # Validate rules if it's a list of lists + if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list): + if len(rules_list) != len(files): + raise HTTPException( + status_code=400, + detail=f"Number of rule lists ({len(rules_list)}) must match number of files ({len(files)})" + ) + + documents = [] + errors = [] + + async with telemetry.track_operation( + operation_type="batch_ingest", + user_id=auth.entity_id, + metadata={ + "file_count": len(files), + "metadata_type": "list" if isinstance(metadata_value, list) else "single", + "rules_type": "per_file" if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list) else "shared", + }, + ): + if parallel: + tasks = [] + for i, file in enumerate(files): + metadata_item = metadata_value[i] if isinstance(metadata_value, list) else metadata_value + file_rules = rules_list[i] if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list) else rules_list + task = document_service.ingest_file( + file=file, + metadata=metadata_item, + auth=auth, + rules=file_rules, + use_colpali=use_colpali + ) + tasks.append(task) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, result in enumerate(results): + if isinstance(result, Exception): + errors.append({ + "filename": files[i].filename, + "error": str(result) + }) + else: + documents.append(result) + else: + for i, file in enumerate(files): + try: + metadata_item = metadata_value[i] if isinstance(metadata_value, list) else metadata_value + file_rules = rules_list[i] if isinstance(rules_list, list) and rules_list and isinstance(rules_list[0], list) else rules_list + doc = await document_service.ingest_file( + file=file, + metadata=metadata_item, + auth=auth, + rules=file_rules, + use_colpali=use_colpali + ) + documents.append(doc) + except Exception as e: + errors.append({ + "filename": file.filename, + "error": str(e) + }) + + return BatchIngestResponse(documents=documents, errors=errors) + + @app.post("/retrieve/chunks", response_model=List[ChunkResult]) async def retrieve_chunks(request: RetrieveRequest, auth: AuthContext = Depends(verify_token)): """Retrieve relevant chunks.""" diff --git a/core/models/request.py b/core/models/request.py index c05f1f7..2cde0f9 100644 --- a/core/models/request.py +++ b/core/models/request.py @@ -1,6 +1,8 @@ -from typing import Dict, Any, Optional, List +from typing import Dict, List, Optional, Any from pydantic import BaseModel, Field +from core.models.documents import Document + class RetrieveRequest(BaseModel): """Base 
retrieve request model""" @@ -49,3 +51,9 @@ class CreateGraphRequest(BaseModel): documents: Optional[List[str]] = Field( None, description="Optional list of specific document IDs to include" ) + + +class BatchIngestResponse(BaseModel): + """Response model for batch ingestion""" + documents: List[Document] + errors: List[Dict[str, str]] diff --git a/core/services/document_service.py b/core/services/document_service.py index 0315be5..5e52e9c 100644 --- a/core/services/document_service.py +++ b/core/services/document_service.py @@ -405,6 +405,24 @@ class DocumentService: # Read file content file_content = await file.read() file_type = filetype.guess(file_content) + + # Set default mime type for cases where filetype.guess returns None + mime_type = "" + if file_type is not None: + mime_type = file_type.mime + elif file.filename: + # Try to determine by file extension as fallback + import mimetypes + guessed_type = mimetypes.guess_type(file.filename)[0] + if guessed_type: + mime_type = guessed_type + else: + # Default for text files + mime_type = "text/plain" + else: + mime_type = "application/octet-stream" # Generic binary data + + logger.info(f"Determined MIME type: {mime_type} for file {file.filename}") # Parse file to text first additional_metadata, text = await self.parser.parse_file_to_text( @@ -423,7 +441,7 @@ class DocumentService: # Create document record doc = Document( - content_type=file_type.mime or "", + content_type=mime_type, filename=file.filename, metadata=metadata, owner={"type": auth.entity_type, "id": auth.entity_id}, @@ -495,8 +513,19 @@ class DocumentService: def _create_chunks_multivector( self, file_type, file_content_base64: str, file_content: bytes, chunks: List[Chunk] ): - logger.info(f"Creating chunks for multivector embedding for file type {file_type.mime}") - match file_type.mime: + # Handle the case where file_type is None + mime_type = file_type.mime if file_type is not None else "text/plain" + logger.info(f"Creating chunks for multivector embedding for file type {mime_type}") + + # If file_type is None, treat it as a text file + if file_type is None: + logger.info("File type is None, treating as text") + return [ + Chunk(content=chunk.content, metadata=(chunk.metadata | {"is_image": False})) + for chunk in chunks + ] + + match mime_type: case file_type if file_type in IMAGE: return [Chunk(content=file_content_base64, metadata={"is_image": True})] case "application/pdf": @@ -1017,6 +1046,15 @@ class DocumentService: file_type = filetype.guess(file_content) if file_type: doc.content_type = file_type.mime + else: + # If filetype.guess failed, try to determine from filename + import mimetypes + guessed_type = mimetypes.guess_type(file.filename)[0] + if guessed_type: + doc.content_type = guessed_type + else: + # Default fallback + doc.content_type = "text/plain" if file.filename.endswith('.txt') else "application/octet-stream" # Update filename doc.filename = file.filename diff --git a/core/tests/integration/test_api.py b/core/tests/integration/test_api.py index 2830c36..b42ebd7 100644 --- a/core/tests/integration/test_api.py +++ b/core/tests/integration/test_api.py @@ -8,7 +8,7 @@ from typing import AsyncGenerator, Dict from httpx import AsyncClient from fastapi import FastAPI from httpx import ASGITransport -from core.api import app, get_settings +from core.api import get_settings import filetype import logging from sqlalchemy.ext.asyncio import create_async_engine @@ -1426,6 +1426,303 @@ async def test_query_with_graph(client: AsyncClient): assert 
response_no_graph.status_code == 200 +@pytest.mark.asyncio +async def test_batch_ingest_with_shared_metadata( + client: AsyncClient +): + """Test batch ingestion with shared metadata for all files.""" + headers = create_auth_header() + # Create test files + files = [ + ("files", ("test1.txt", b"Test content 1")), + ("files", ("test2.txt", b"Test content 2")), + ] + + # Shared metadata for all files + metadata = {"category": "test", "batch": "shared"} + + response = await client.post( + "/ingest/files", + files=files, + data={ + "metadata": json.dumps(metadata), + "rules": json.dumps([]), + "use_colpali": "true", + "parallel": "true", + }, + headers=headers, + ) + + assert response.status_code == 200 + result = response.json() + assert len(result["documents"]) == 2 + assert len(result["errors"]) == 0 + + # Verify all documents got the same metadata + for doc in result["documents"]: + assert doc["metadata"]["category"] == "test" + assert doc["metadata"]["batch"] == "shared" + + +@pytest.mark.asyncio +async def test_batch_ingest_with_individual_metadata( + client: AsyncClient +): + """Test batch ingestion with individual metadata per file.""" + headers = create_auth_header() + # Create test files + files = [ + ("files", ("test1.txt", b"Test content 1")), + ("files", ("test2.txt", b"Test content 2")), + ] + + # Individual metadata + metadata = [ + {"category": "test1", "batch": "individual"}, + {"category": "test2", "batch": "individual"}, + ] + + response = await client.post( + "/ingest/files", + files=files, + data={ + "metadata": json.dumps(metadata), + "rules": json.dumps([]), + "use_colpali": "true", + "parallel": "true", + }, + headers=headers, + ) + + assert response.status_code == 200 + result = response.json() + assert len(result["documents"]) == 2 + assert len(result["errors"]) == 0 + + # Verify each document got its correct metadata + assert result["documents"][0]["metadata"]["category"] == "test1" + assert result["documents"][1]["metadata"]["category"] == "test2" + + +@pytest.mark.asyncio +async def test_batch_ingest_metadata_validation( + client: AsyncClient +): + """Test validation when metadata list length doesn't match files.""" + headers = create_auth_header() + files = [ + ("files", ("test1.txt", b"Test content 1")), + ("files", ("test2.txt", b"Test content 2")), + ] + + # Metadata list with wrong length + metadata = [ + {"category": "test1"}, + {"category": "test2"}, + {"category": "test3"}, # Extra metadata + ] + + response = await client.post( + "/ingest/files", + files=files, + data={ + "metadata": json.dumps(metadata), + "rules": json.dumps([]), + "use_colpali": "true", + "parallel": "true", + }, + headers=headers, + ) + + assert response.status_code == 400 + assert "must match number of files" in response.json()["detail"] + + +@pytest.mark.asyncio +async def test_batch_ingest_sequential( + client: AsyncClient +): + """Test sequential batch ingestion.""" + headers = create_auth_header() + files = [ + ("files", ("test1.txt", b"Test content 1")), + ("files", ("test2.txt", b"Test content 2")), + ] + + metadata = {"category": "test"} + + response = await client.post( + "/ingest/files", + files=files, + data={ + "metadata": json.dumps(metadata), + "rules": json.dumps([]), + "use_colpali": "true", + "parallel": "false", # Process sequentially + }, + headers=headers, + ) + + assert response.status_code == 200 + result = response.json() + assert len(result["documents"]) == 2 + assert len(result["errors"]) == 0 + + +@pytest.mark.asyncio +async def test_batch_ingest_with_rules( + 
client: AsyncClient +): + """Test batch ingestion with rules applied.""" + headers = create_auth_header() + files = [ + ("files", ("test1.txt", b"Test content 1")), + ("files", ("test2.txt", b"Test content 2")), + ] + + # Test shared rules for all files + shared_rules = [{"type": "natural_language", "prompt": "Extract keywords"}] + + response = await client.post( + "/ingest/files", + files=files, + data={ + "metadata": json.dumps({}), + "rules": json.dumps(shared_rules), + "use_colpali": "true", + "parallel": "true", + }, + headers=headers, + ) + + assert response.status_code == 200 + result = response.json() + assert len(result["documents"]) == 2 + assert len(result["errors"]) == 0 + + # Test per-file rules + per_file_rules = [ + [{"type": "natural_language", "prompt": "Extract keywords"}], # Rules for first file + [{"type": "metadata_extraction", "schema": {"title": "string"}}], # Rules for second file + ] + + response = await client.post( + "/ingest/files", + files=files, + data={ + "metadata": json.dumps({}), + "rules": json.dumps(per_file_rules), + "use_colpali": "true", + "parallel": "true", + }, + headers=headers, + ) + + assert response.status_code == 200 + result = response.json() + assert len(result["documents"]) == 2 + assert len(result["errors"]) == 0 + + +@pytest.mark.asyncio +async def test_batch_ingest_rules_validation( + client: AsyncClient +): + """Test validation of rules format and length.""" + headers = create_auth_header() + files = [ + ("files", ("test1.txt", b"Test content 1")), + ("files", ("test2.txt", b"Test content 2")), + ] + + # Test invalid rules format + invalid_rules = "not a list" + + response = await client.post( + "/ingest/files", + files=files, + data={ + "metadata": json.dumps({}), + "rules": invalid_rules, + "use_colpali": "true", + "parallel": "true", + }, + headers=headers, + ) + + assert response.status_code == 400 + assert "Invalid JSON" in response.json()["detail"] + + # Test per-file rules with wrong length + per_file_rules = [ + [{"type": "natural_language", "prompt": "Extract keywords"}], # Only one set of rules + ] + + response = await client.post( + "/ingest/files", + files=files, + data={ + "metadata": json.dumps({}), + "rules": json.dumps(per_file_rules), + "use_colpali": "true", + "parallel": "true", + }, + headers=headers, + ) + + assert response.status_code == 400 + assert "must match number of files" in response.json()["detail"] + + +@pytest.mark.asyncio +async def test_batch_ingest_sequential_vs_parallel( + client: AsyncClient +): + """Test both sequential and parallel batch ingestion.""" + headers = create_auth_header() + files = [ + ("files", ("test1.txt", b"Test content 1")), + ("files", ("test2.txt", b"Test content 2")), + ("files", ("test3.txt", b"Test content 3")), + ] + + # Test parallel processing + response = await client.post( + "/ingest/files", + files=files, + data={ + "metadata": json.dumps({}), + "rules": json.dumps([]), + "use_colpali": "true", + "parallel": "true", + }, + headers=headers, + ) + + assert response.status_code == 200 + result = response.json() + assert len(result["documents"]) == 3 + assert len(result["errors"]) == 0 + + # Test sequential processing + response = await client.post( + "/ingest/files", + files=files, + data={ + "metadata": json.dumps({}), + "rules": json.dumps([]), + "use_colpali": "true", + "parallel": "false", + }, + headers=headers, + ) + + assert response.status_code == 200 + result = response.json() + assert len(result["documents"]) == 3 + assert len(result["errors"]) == 0 + + 
@pytest.mark.asyncio async def test_cross_document_query_with_graph(client: AsyncClient): """Test cross-document information retrieval using knowledge graph.""" diff --git a/sdks/python/databridge/async_.py b/sdks/python/databridge/async_.py index 7f43e1c..7f879c4 100644 --- a/sdks/python/databridge/async_.py +++ b/sdks/python/databridge/async_.py @@ -1,5 +1,6 @@ -from io import BytesIO +from io import BytesIO, IOBase import json +import logging from pathlib import Path from typing import Dict, Any, List, Optional, Union, BinaryIO from urllib.parse import urlparse @@ -20,6 +21,8 @@ from .models import ( ) from .rules import Rule +logger = logging.getLogger(__name__) + # Type alias for rules RuleOrDict = Union[Rule, Dict[str, Any]] @@ -222,41 +225,7 @@ class AsyncDataBridge: rules: Optional[List[RuleOrDict]] = None, use_colpali: bool = True, ) -> Document: - """ - Ingest a file document into DataBridge. - - Args: - file: File to ingest (path string, bytes, file object, or Path) - filename: Name of the file - metadata: Optional metadata dictionary - rules: Optional list of rules to apply during ingestion. Can be: - - MetadataExtractionRule: Extract metadata using a schema - - NaturalLanguageRule: Transform content using natural language - use_colpali: Whether to use ColPali-style embedding model to ingest the file (slower, but significantly better retrieval accuracy for text and images) - Returns: - Document: Metadata of the ingested document - - Example: - ```python - from databridge.rules import MetadataExtractionRule, NaturalLanguageRule - from pydantic import BaseModel - - class DocumentInfo(BaseModel): - title: str - author: str - department: str - - doc = await db.ingest_file( - "document.pdf", - filename="document.pdf", - metadata={"category": "research"}, - rules=[ - MetadataExtractionRule(schema=DocumentInfo), - NaturalLanguageRule(prompt="Extract key points only") - ] - ) - ``` - """ + """Ingest a file document into DataBridge.""" # Handle different file input types if isinstance(file, (str, Path)): file_path = Path(file) @@ -290,6 +259,131 @@ class AsyncDataBridge: if isinstance(file, (str, Path)): file_obj.close() + async def ingest_files( + self, + files: List[Union[str, bytes, BinaryIO, Path]], + metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + rules: Optional[List[RuleOrDict]] = None, + use_colpali: bool = True, + parallel: bool = True, + ) -> List[Document]: + """ + Ingest multiple files into DataBridge. 
+ + Args: + files: List of files to ingest (path strings, bytes, file objects, or Paths) + metadata: Optional metadata (single dict for all files or list of dicts) + rules: Optional list of rules to apply + use_colpali: Whether to use ColPali-style embedding + parallel: Whether to process files in parallel + + Returns: + List[Document]: List of successfully ingested documents + + Raises: + ValueError: If metadata list length doesn't match files length + """ + # Convert files to format expected by API + file_objects = [] + for file in files: + if isinstance(file, (str, Path)): + path = Path(file) + file_objects.append(("files", (path.name, open(path, "rb")))) + elif isinstance(file, bytes): + file_objects.append(("files", ("file.bin", file))) + else: + file_objects.append(("files", (getattr(file, "name", "file.bin"), file))) + + try: + # Prepare request data + # Convert rules appropriately based on whether it's a flat list or list of lists + if rules: + if all(isinstance(r, list) for r in rules): + # List of lists - per-file rules + converted_rules = [[self._convert_rule(r) for r in rule_list] for rule_list in rules] + else: + # Flat list - shared rules for all files + converted_rules = [self._convert_rule(r) for r in rules] + else: + converted_rules = [] + + data = { + "metadata": json.dumps(metadata or {}), + "rules": json.dumps(converted_rules), + "use_colpali": str(use_colpali).lower() if use_colpali is not None else None, + "parallel": str(parallel).lower(), + } + + response = await self._request("POST", "ingest/files", data=data, files=file_objects) + + if response.get("errors"): + # Log errors but don't raise exception + for error in response["errors"]: + logger.error(f"Failed to ingest {error['filename']}: {error['error']}") + + docs = [Document(**doc) for doc in response["documents"]] + for doc in docs: + doc._client = self + return docs + finally: + # Clean up file objects + for _, (_, file_obj) in file_objects: + if isinstance(file_obj, (IOBase, BytesIO)) and not file_obj.closed: + file_obj.close() + + async def ingest_directory( + self, + directory: Union[str, Path], + recursive: bool = False, + pattern: str = "*", + metadata: Optional[Dict[str, Any]] = None, + rules: Optional[List[RuleOrDict]] = None, + use_colpali: bool = True, + parallel: bool = True, + ) -> List[Document]: + """ + Ingest all files in a directory into DataBridge. + + Args: + directory: Path to directory containing files to ingest + recursive: Whether to recursively process subdirectories + pattern: Optional glob pattern to filter files (e.g. 
"*.pdf") + metadata: Optional metadata dictionary to apply to all files + rules: Optional list of rules to apply + use_colpali: Whether to use ColPali-style embedding + parallel: Whether to process files in parallel + + Returns: + List[Document]: List of ingested documents + + Raises: + ValueError: If directory not found + """ + directory = Path(directory) + if not directory.is_dir(): + raise ValueError(f"Directory not found: {directory}") + + # Collect all files matching pattern + if recursive: + files = list(directory.rglob(pattern)) + else: + files = list(directory.glob(pattern)) + + # Filter out directories + files = [f for f in files if f.is_file()] + + if not files: + return [] + + # Use ingest_files with collected paths + return await self.ingest_files( + files=files, + metadata=metadata, + rules=rules, + use_colpali=use_colpali, + parallel=parallel + ) + async def retrieve_chunks( self, query: str, diff --git a/sdks/python/databridge/models.py b/sdks/python/databridge/models.py index 156d348..bc2c732 100644 --- a/sdks/python/databridge/models.py +++ b/sdks/python/databridge/models.py @@ -1,5 +1,4 @@ -from typing import Dict, Any, List, Literal, Optional, Union -from io import BinaryIO +from typing import Dict, Any, List, Literal, Optional, Union, BinaryIO from pathlib import Path from datetime import datetime from pydantic import BaseModel, Field, field_validator diff --git a/sdks/python/databridge/sync.py b/sdks/python/databridge/sync.py index 01ae77b..2f757b3 100644 --- a/sdks/python/databridge/sync.py +++ b/sdks/python/databridge/sync.py @@ -1,9 +1,10 @@ import base64 -from io import BytesIO +from io import BytesIO, IOBase import io from PIL.Image import Image as PILImage from PIL import Image import json +import logging from pathlib import Path from typing import Dict, Any, List, Optional, Union, BinaryIO from urllib.parse import urlparse @@ -23,6 +24,8 @@ from .models import ( ) from .rules import Rule +logger = logging.getLogger(__name__) + # Type alias for rules RuleOrDict = Union[Rule, Dict[str, Any]] @@ -294,6 +297,131 @@ class DataBridge: if isinstance(file, (str, Path)): file_obj.close() + def ingest_files( + self, + files: List[Union[str, bytes, BinaryIO, Path]], + metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + rules: Optional[List[RuleOrDict]] = None, + use_colpali: bool = True, + parallel: bool = True, + ) -> List[Document]: + """ + Ingest multiple files into DataBridge. 
+ + Args: + files: List of files to ingest (path strings, bytes, file objects, or Paths) + metadata: Optional metadata (single dict for all files or list of dicts) + rules: Optional list of rules to apply + use_colpali: Whether to use ColPali-style embedding + parallel: Whether to process files in parallel + + Returns: + List[Document]: List of successfully ingested documents + + Raises: + ValueError: If metadata list length doesn't match files length + """ + # Convert files to format expected by API + file_objects = [] + for file in files: + if isinstance(file, (str, Path)): + path = Path(file) + file_objects.append(("files", (path.name, open(path, "rb")))) + elif isinstance(file, bytes): + file_objects.append(("files", ("file.bin", file))) + else: + file_objects.append(("files", (getattr(file, "name", "file.bin"), file))) + + try: + # Prepare request data + # Convert rules appropriately based on whether it's a flat list or list of lists + if rules: + if all(isinstance(r, list) for r in rules): + # List of lists - per-file rules + converted_rules = [[self._convert_rule(r) for r in rule_list] for rule_list in rules] + else: + # Flat list - shared rules for all files + converted_rules = [self._convert_rule(r) for r in rules] + else: + converted_rules = [] + + data = { + "metadata": json.dumps(metadata or {}), + "rules": json.dumps(converted_rules), + "use_colpali": str(use_colpali).lower() if use_colpali is not None else None, + "parallel": str(parallel).lower(), + } + + response = self._request("POST", "ingest/files", data=data, files=file_objects) + + if response.get("errors"): + # Log errors but don't raise exception + for error in response["errors"]: + logger.error(f"Failed to ingest {error['filename']}: {error['error']}") + + docs = [Document(**doc) for doc in response["documents"]] + for doc in docs: + doc._client = self + return docs + finally: + # Clean up file objects + for _, (_, file_obj) in file_objects: + if isinstance(file_obj, (IOBase, BytesIO)) and not file_obj.closed: + file_obj.close() + + def ingest_directory( + self, + directory: Union[str, Path], + recursive: bool = False, + pattern: str = "*", + metadata: Optional[Dict[str, Any]] = None, + rules: Optional[List[RuleOrDict]] = None, + use_colpali: bool = True, + parallel: bool = True, + ) -> List[Document]: + """ + Ingest all files in a directory into DataBridge. + + Args: + directory: Path to directory containing files to ingest + recursive: Whether to recursively process subdirectories + pattern: Optional glob pattern to filter files (e.g. 
"*.pdf") + metadata: Optional metadata dictionary to apply to all files + rules: Optional list of rules to apply + use_colpali: Whether to use ColPali-style embedding + parallel: Whether to process files in parallel + + Returns: + List[Document]: List of ingested documents + + Raises: + ValueError: If directory not found + """ + directory = Path(directory) + if not directory.is_dir(): + raise ValueError(f"Directory not found: {directory}") + + # Collect all files matching pattern + if recursive: + files = list(directory.rglob(pattern)) + else: + files = list(directory.glob(pattern)) + + # Filter out directories + files = [f for f in files if f.is_file()] + + if not files: + return [] + + # Use ingest_files with collected paths + return self.ingest_files( + files=files, + metadata=metadata, + rules=rules, + use_colpali=use_colpali, + parallel=parallel + ) + def retrieve_chunks( self, query: str, diff --git a/shell.py b/shell.py index 271d2f5..b7961b7 100644 --- a/shell.py +++ b/shell.py @@ -137,6 +137,120 @@ class DB: ) return doc if as_object else doc.model_dump() + def ingest_files( + self, + files: List[str], + metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None, + rules: Optional[List[Dict[str, Any]]] = None, + use_colpali: bool = True, + parallel: bool = True, + as_objects: bool = False, + ) -> List[Union[dict, 'Document']]: + """ + Batch ingest multiple files into DataBridge. + + Args: + files: List of file paths to ingest + metadata: Optional metadata (single dict for all files or list of dicts) + rules: Optional list of rules. Can be either: + - A single list of rules to apply to all files + - A list of rule lists, one per file + use_colpali: Whether to use ColPali-style embedding model + parallel: Whether to process files in parallel + as_objects: If True, returns Document objects with update methods, otherwise returns dicts + + Returns: + List of document metadata (dicts or Document objects) + + Example: + ```python + # Ingest multiple files with shared metadata + docs = db.ingest_files( + ["doc1.pdf", "doc2.pdf"], + metadata={"category": "research"}, + parallel=True + ) + + # Ingest files with individual metadata + docs = db.ingest_files( + ["doc1.pdf", "doc2.pdf"], + metadata=[ + {"category": "research", "author": "Alice"}, + {"category": "reports", "author": "Bob"} + ] + ) + ``` + """ + # Convert file paths to Path objects + file_paths = [Path(f) for f in files] + + # Ingest files using the client + docs = self._client.ingest_files( + files=file_paths, + metadata=metadata, + rules=rules, + use_colpali=use_colpali, + parallel=parallel, + ) + + return docs if as_objects else [doc.model_dump() for doc in docs] + + def ingest_directory( + self, + directory: str, + recursive: bool = False, + pattern: str = "*", + metadata: Optional[Dict[str, Any]] = None, + rules: Optional[List[Dict[str, Any]]] = None, + use_colpali: bool = True, + parallel: bool = True, + as_objects: bool = False, + ) -> List[Union[dict, 'Document']]: + """ + Ingest all files in a directory into DataBridge. + + Args: + directory: Path to directory containing files to ingest + recursive: Whether to recursively process subdirectories + pattern: Optional glob pattern to filter files (e.g. "*.pdf") + metadata: Optional metadata dictionary to apply to all files + rules: Optional list of rules. 
Can be either: + - A single list of rules to apply to all files + - A list of rule lists, one per file + use_colpali: Whether to use ColPali-style embedding model + parallel: Whether to process files in parallel + as_objects: If True, returns Document objects with update methods, otherwise returns dicts + + Returns: + List of document metadata (dicts or Document objects) + + Example: + ```python + # Ingest all PDFs in a directory and its subdirectories + docs = db.ingest_directory( + "data/documents", + recursive=True, + metadata={"category": "research"}, + pattern="*.pdf" + ) + ``` + """ + # Convert directory to Path + dir_path = Path(directory) + + # Ingest directory using the client + docs = self._client.ingest_directory( + directory=dir_path, + recursive=recursive, + pattern=pattern, + metadata=metadata, + rules=rules, + use_colpali=use_colpali, + parallel=parallel, + ) + + return docs if as_objects else [doc.model_dump() for doc in docs] + def retrieve_chunks( self, query: str,
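
For reference, a minimal sketch of exercising the new `/ingest/files` endpoint from a client, mirroring the integration tests added above. The base URL and bearer token are placeholders; the multipart fields and the `documents`/`errors` response shape follow `batch_ingest_files` and `BatchIngestResponse` as introduced in this patch.

```python
import json
import httpx  # same HTTP client family the integration tests use

BASE_URL = "http://localhost:8000"          # placeholder deployment URL
AUTH = {"Authorization": "Bearer <token>"}  # placeholder JWT accepted by verify_token

# Two small files with per-file metadata and shared (empty) rules. A metadata
# list must be one-per-file, matching the validation in batch_ingest_files.
files = [
    ("files", ("test1.txt", b"Test content 1")),
    ("files", ("test2.txt", b"Test content 2")),
]
data = {
    "metadata": json.dumps([{"category": "test1"}, {"category": "test2"}]),
    "rules": json.dumps([]),
    "use_colpali": "true",
    "parallel": "true",
}

response = httpx.post(f"{BASE_URL}/ingest/files", files=files, data=data, headers=AUTH)
response.raise_for_status()
result = response.json()

# BatchIngestResponse: per-file failures are reported in "errors", not raised.
for doc in result["documents"]:
    print("ingested:", doc["filename"])
for err in result["errors"]:
    print("failed:", err["filename"], "-", err["error"])
```

Individual file failures come back in `errors` rather than failing the whole request, which is why the tests assert on both lists.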
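
The Python SDK additions wrap the same endpoint. A sketch using the synchronous client's new `ingest_files` and `ingest_directory` methods; the import path and client construction are assumptions (not shown in this patch), while the method signatures follow `sdks/python/databridge/sync.py` as added above.

```python
from databridge import DataBridge  # import path assumed; class defined in sdks/python/databridge/sync.py

db = DataBridge("databridge://...")  # placeholder URI; construct the client however you already do

# Batch-ingest two files with individual metadata; per-file errors are logged, not raised.
docs = db.ingest_files(
    ["doc1.pdf", "doc2.pdf"],
    metadata=[{"author": "Alice"}, {"author": "Bob"}],
    parallel=True,
)

# Or ingest every PDF under a directory tree with shared metadata.
docs = db.ingest_directory(
    "data/documents",
    recursive=True,
    pattern="*.pdf",
    metadata={"category": "research"},
)
print(f"ingested {len(docs)} documents")
```

The async client exposes the same two methods with identical parameters, awaited on `AsyncDataBridge`.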