diff --git a/core/completion/ollama_completion.py b/core/completion/ollama_completion.py
index 2a79923..84f39a9 100644
--- a/core/completion/ollama_completion.py
+++ b/core/completion/ollama_completion.py
@@ -5,6 +5,7 @@ from core.completion.base_completion import (
 )
 from ollama import AsyncClient
 
+BASE_64_PREFIX = "data:image/png;base64,"
 
 class OllamaCompletionModel(BaseCompletionModel):
     """Ollama completion model implementation"""
@@ -16,18 +17,33 @@ class OllamaCompletionModel(BaseCompletionModel):
     async def complete(self, request: CompletionRequest) -> CompletionResponse:
         """Generate completion using Ollama API"""
         # Construct prompt with context
-        context = "\n\n".join(request.context_chunks)
+        images, context = [], []
+        for chunk in request.context_chunks:
+            if chunk.startswith(BASE_64_PREFIX):
+                image_b64 = chunk.split(',', 1)[1]
+                images.append(image_b64)
+            else:
+                context.append(chunk)
+        context = "\n\n".join(context)
         prompt = f"""You are a helpful assistant. Use the provided context to answer questions accurately.
 
-Context:
-{context}
+
+{request.query}
+
 
-Question: {request.query}"""
+
+{context}
+
+"""
 
         # Call Ollama API
         response = await self.client.chat(
             model=self.model_name,
-            messages=[{"role": "user", "content": prompt}],
+            messages=[{
+                "role": "user",
+                "content": prompt,
+                "images": [images[0]],
+            }],
             options={
                 "num_predict": request.max_tokens,
                 "temperature": request.temperature,
diff --git a/core/models/documents.py b/core/models/documents.py
index 39d68f8..58da9b1 100644
--- a/core/models/documents.py
+++ b/core/models/documents.py
@@ -1,6 +1,7 @@
 from typing import Dict, Any, List, Optional, Literal
 from enum import Enum
 from datetime import UTC, datetime
+from PIL import Image
 from pydantic import BaseModel, Field, field_validator
 import uuid
 import logging
@@ -88,7 +89,7 @@ class ChunkResult(BaseModel):
     filename: Optional[str] = None
     download_url: Optional[str] = None
 
-    def augmented_content(self, doc: DocumentResult) -> str:
+    def augmented_content(self, doc: DocumentResult) -> str | Image.Image:
         match self.metadata:
             case m if "timestamp" in m:
                 # if timestamp present, then must be a video. In that case,
@@ -110,5 +111,21 @@
                     for t in timestamps
                 ]
                 return "\n\n".join(augmented_contents)
+            # case m if m.get("is_image", False):
+            #     try:
+            #         # Handle data URI format "data:image/png;base64,..."
+            #         content = self.content
+            #         if content.startswith('data:'):
+            #             # Extract the base64 part after the comma
+            #             content = content.split(',', 1)[1]
+
+            #         # Now decode the base64 string
+            #         image_bytes = base64.b64decode(content)
+            #         content = Image.open(io.BytesIO(image_bytes))
+            #         return content
+            #     except Exception as e:
+            #         print(f"Error processing image: {str(e)}")
+            #         # Fall back to using the content as text
+            #         return self.content
             case _:
                 return self.content
diff --git a/core/services/document_service.py b/core/services/document_service.py
index eea5b99..3bc7979 100644
--- a/core/services/document_service.py
+++ b/core/services/document_service.py
@@ -27,7 +27,7 @@ from core.services.rules_processor import RulesProcessor
 from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
 from core.vector_store.multi_vector_store import MultiVectorStore
 import filetype
-from filetype.types import IMAGE, archive # , DOCUMENT, document
+from filetype.types import IMAGE # , DOCUMENT, document
 import pdf2image
 from PIL.Image import Image
 
@@ -363,7 +363,7 @@ class DocumentService:
             case file_type if file_type in IMAGE:
                 return [Chunk(content=file_content_base64, metadata={"is_image": True})]
             case "application/pdf":
-                logger.info(f"Working with PDF file!")
+                logger.info("Working with PDF file!")
                 images = pdf2image.convert_from_bytes(file_content)
                 images_b64 = [self.img_to_base64_str(image) for image in images]
                 return [
diff --git a/core/tests/integration/test_colpali_integrate_multivector.py b/core/tests/integration/test_colpali_integrate_multivector.py
index ccb717e..bcf32c6 100644
--- a/core/tests/integration/test_colpali_integrate_multivector.py
+++ b/core/tests/integration/test_colpali_integrate_multivector.py
@@ -4,7 +4,6 @@ import os
 import logging
 from pathlib import Path
 from pdf2image import convert_from_path
-from PIL import Image
 from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
 from core.vector_store.multi_vector_store import MultiVectorStore
 
diff --git a/core/tests/unit/test_colpali_embedding.py b/core/tests/unit/test_colpali_embedding.py
index 9d44e9c..848688a 100644
--- a/core/tests/unit/test_colpali_embedding.py
+++ b/core/tests/unit/test_colpali_embedding.py
@@ -2,7 +2,6 @@ import pytest
 import base64
 import io
 import numpy as np
-import torch
 from PIL import Image
 from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
 
diff --git a/core/tests/unit/test_multivector.py b/core/tests/unit/test_multivector.py
index 30fb914..7f948bf 100644
--- a/core/tests/unit/test_multivector.py
+++ b/core/tests/unit/test_multivector.py
@@ -2,8 +2,7 @@ import pytest
 import asyncio
 import torch
 import numpy as np
-import psycopg
-from pgvector.psycopg import Bit, register_vector
+from pgvector.psycopg import Bit
 import logging
 from core.vector_store.multi_vector_store import MultiVectorStore
 from core.models.chunk import DocumentChunk
diff --git a/databridge.toml b/databridge.toml
index 725c1eb..4b67c6f 100644
--- a/databridge.toml
+++ b/databridge.toml
@@ -12,7 +12,7 @@ dev_permissions = ["read", "write", "admin"] # Default dev permissions
 
 [completion]
 provider = "ollama"
-model_name = "llama3.2"
+model_name = "llama3.2-vision"
 default_max_tokens = "1000"
 default_temperature = 0.7
 # base_url = "http://ollama:11434" # Just use the service name
diff --git a/sdks/python/databridge/async_.py b/sdks/python/databridge/async_.py
index 3740426..3cc5e4a 100644
--- a/sdks/python/databridge/async_.py
+++ b/sdks/python/databridge/async_.py
@@ -7,7 +7,6 @@ from urllib.parse import urlparse
 import httpx
 import jwt
 from PIL.Image import Image as PILImage
-from PIL import Image
 from .models import (
     Document,