diff --git a/core/completion/ollama_completion.py b/core/completion/ollama_completion.py
index 2a79923..84f39a9 100644
--- a/core/completion/ollama_completion.py
+++ b/core/completion/ollama_completion.py
@@ -5,6 +5,7 @@ from core.completion.base_completion import (
)
from ollama import AsyncClient
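+# Data-URI prefix that marks base64-encoded image chunks in retrieved context
+# (assumes ingestion stores images as PNG data URIs)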
+BASE_64_PREFIX = "data:image/png;base64,"
class OllamaCompletionModel(BaseCompletionModel):
"""Ollama completion model implementation"""
@@ -16,18 +17,33 @@ class OllamaCompletionModel(BaseCompletionModel):
async def complete(self, request: CompletionRequest) -> CompletionResponse:
"""Generate completion using Ollama API"""
# Construct prompt with context
- context = "\n\n".join(request.context_chunks)
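+ # separate base64-encoded image chunks from plain-text context chunks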
+ images, context = [], []
+ for chunk in request.context_chunks:
+ if chunk.startswith(BASE_64_PREFIX):
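+ # strip the data-URI header; Ollama expects the raw base64 payload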
+ image_b64 = chunk.split(',', 1)[1]
+ images.append(image_b64)
+ else:
+ context.append(chunk)
+ context = "\n\n".join(context)
prompt = f"""You are a helpful assistant. Use the provided context to answer questions accurately.
-Context:
-{context}
-Question: {request.query}"""
+
+Question:
+{request.query}
+
+Context:
+{context}
+"""
# Call Ollama API
response = await self.client.chat(
model=self.model_name,
- messages=[{"role": "user", "content": prompt}],
+ messages=[{
+ "role": "user",
+ "content": prompt,
+ "images": [images[0]],
+ }],
options={
"num_predict": request.max_tokens,
"temperature": request.temperature,
diff --git a/core/models/documents.py b/core/models/documents.py
index 39d68f8..58da9b1 100644
--- a/core/models/documents.py
+++ b/core/models/documents.py
@@ -1,6 +1,7 @@
from typing import Dict, Any, List, Optional, Literal
from enum import Enum
from datetime import UTC, datetime
+from PIL import Image
from pydantic import BaseModel, Field, field_validator
import uuid
import logging
@@ -88,7 +89,7 @@ class ChunkResult(BaseModel):
filename: Optional[str] = None
download_url: Optional[str] = None
- def augmented_content(self, doc: DocumentResult) -> str:
+ def augmented_content(self, doc: DocumentResult) -> str | Image.Image:
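+ # image chunks may come back as PIL Images (see the is_image case below)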
match self.metadata:
case m if "timestamp" in m:
# if timestamp present, then must be a video. In that case,
@@ -110,5 +111,21 @@ class ChunkResult(BaseModel):
for t in timestamps
]
return "\n\n".join(augmented_contents)
+ # case m if m.get("is_image", False):
+ # try:
+ # # Handle data URI format "data:image/png;base64,..."
+ # content = self.content
+ # if content.startswith('data:'):
+ # # Extract the base64 part after the comma
+ # content = content.split(',', 1)[1]
+
+ # # Now decode the base64 string
+ # image_bytes = base64.b64decode(content)
+ # content = Image.open(io.BytesIO(image_bytes))
+ # return content
+ # except Exception as e:
+ # print(f"Error processing image: {str(e)}")
+ # # Fall back to using the content as text
+ # return self.content
case _:
return self.content
diff --git a/core/services/document_service.py b/core/services/document_service.py
index eea5b99..3bc7979 100644
--- a/core/services/document_service.py
+++ b/core/services/document_service.py
@@ -27,7 +27,7 @@ from core.services.rules_processor import RulesProcessor
from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
from core.vector_store.multi_vector_store import MultiVectorStore
import filetype
-from filetype.types import IMAGE, archive # , DOCUMENT, document
+from filetype.types import IMAGE # , DOCUMENT, document
import pdf2image
from PIL.Image import Image
@@ -363,7 +363,7 @@ class DocumentService:
case file_type if file_type in IMAGE:
return [Chunk(content=file_content_base64, metadata={"is_image": True})]
case "application/pdf":
- logger.info(f"Working with PDF file!")
+ logger.info("Working with PDF file!")
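+ # each PDF page becomes its own base64-encoded image chunk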
images = pdf2image.convert_from_bytes(file_content)
images_b64 = [self.img_to_base64_str(image) for image in images]
return [
diff --git a/core/tests/integration/test_colpali_integrate_multivector.py b/core/tests/integration/test_colpali_integrate_multivector.py
index ccb717e..bcf32c6 100644
--- a/core/tests/integration/test_colpali_integrate_multivector.py
+++ b/core/tests/integration/test_colpali_integrate_multivector.py
@@ -4,7 +4,6 @@ import os
import logging
from pathlib import Path
from pdf2image import convert_from_path
-from PIL import Image
from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
from core.vector_store.multi_vector_store import MultiVectorStore
diff --git a/core/tests/unit/test_colpali_embedding.py b/core/tests/unit/test_colpali_embedding.py
index 9d44e9c..848688a 100644
--- a/core/tests/unit/test_colpali_embedding.py
+++ b/core/tests/unit/test_colpali_embedding.py
@@ -2,7 +2,6 @@ import pytest
import base64
import io
import numpy as np
-import torch
from PIL import Image
from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
diff --git a/core/tests/unit/test_multivector.py b/core/tests/unit/test_multivector.py
index 30fb914..7f948bf 100644
--- a/core/tests/unit/test_multivector.py
+++ b/core/tests/unit/test_multivector.py
@@ -2,8 +2,7 @@ import pytest
import asyncio
import torch
import numpy as np
-import psycopg
-from pgvector.psycopg import Bit, register_vector
+from pgvector.psycopg import Bit
import logging
from core.vector_store.multi_vector_store import MultiVectorStore
from core.models.chunk import DocumentChunk
diff --git a/databridge.toml b/databridge.toml
index 725c1eb..4b67c6f 100644
--- a/databridge.toml
+++ b/databridge.toml
@@ -12,7 +12,7 @@ dev_permissions = ["read", "write", "admin"] # Default dev permissions
[completion]
provider = "ollama"
-model_name = "llama3.2"
+model_name = "llama3.2-vision"
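+# vision-capable model: image chunks are passed to the chat API as base64 images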
default_max_tokens = "1000"
default_temperature = 0.7
# base_url = "http://ollama:11434" # Just use the service name
diff --git a/sdks/python/databridge/async_.py b/sdks/python/databridge/async_.py
index 3740426..3cc5e4a 100644
--- a/sdks/python/databridge/async_.py
+++ b/sdks/python/databridge/async_.py
@@ -7,7 +7,6 @@ from urllib.parse import urlparse
import httpx
import jwt
from PIL.Image import Image as PILImage
-from PIL import Image
from .models import (
Document,