add image processing to ollama

Arnav Agrawal 2025-02-26 22:36:25 -05:00
parent 821e9d7e20
commit 07eec6b9e3
8 changed files with 43 additions and 14 deletions


@@ -5,6 +5,7 @@ from core.completion.base_completion import (
 )
 from ollama import AsyncClient
+BASE_64_PREFIX = "data:image/png;base64,"
 
 class OllamaCompletionModel(BaseCompletionModel):
     """Ollama completion model implementation"""
@@ -16,18 +17,33 @@ class OllamaCompletionModel(BaseCompletionModel):
     async def complete(self, request: CompletionRequest) -> CompletionResponse:
         """Generate completion using Ollama API"""
         # Construct prompt with context
-        context = "\n\n".join(request.context_chunks)
+        images, context = [], []
+        for chunk in request.context_chunks:
+            if chunk.startswith(BASE_64_PREFIX):
+                image_b64 = chunk.split(',', 1)[1]
+                images.append(image_b64)
+            else:
+                context.append(chunk)
+        context = "\n\n".join(context)
 
         prompt = f"""You are a helpful assistant. Use the provided context to answer questions accurately.
-Context:
-{context}
+<QUESTION>
+{request.query}
+</QUESTION>
 
-Question: {request.query}"""
+<CONTEXT>
+{context}
+</CONTEXT>
+"""
 
         # Call Ollama API
         response = await self.client.chat(
             model=self.model_name,
-            messages=[{"role": "user", "content": prompt}],
+            messages=[{
+                "role": "user",
+                "content": prompt,
+                "images": images[:1],  # at most one image; [images[0]] would raise IndexError when empty
+            }],
             options={
                 "num_predict": request.max_tokens,
                 "temperature": request.temperature,


@@ -1,6 +1,7 @@
 from typing import Dict, Any, List, Optional, Literal
 from enum import Enum
 from datetime import UTC, datetime
+from PIL import Image
 from pydantic import BaseModel, Field, field_validator
 import uuid
 import logging
@@ -88,7 +89,7 @@ class ChunkResult(BaseModel):
     filename: Optional[str] = None
     download_url: Optional[str] = None
 
-    def augmented_content(self, doc: DocumentResult) -> str:
+    def augmented_content(self, doc: DocumentResult) -> str | Image.Image:
         match self.metadata:
             case m if "timestamp" in m:
                 # if timestamp present, then must be a video. In that case,
@@ -110,5 +111,21 @@ class ChunkResult(BaseModel):
                     for t in timestamps
                 ]
                 return "\n\n".join(augmented_contents)
+            # case m if m.get("is_image", False):
+            #     try:
+            #         # Handle data URI format "data:image/png;base64,..."
+            #         content = self.content
+            #         if content.startswith('data:'):
+            #             # Extract the base64 part after the comma
+            #             content = content.split(',', 1)[1]
+            #         # Now decode the base64 string
+            #         image_bytes = base64.b64decode(content)
+            #         content = Image.open(io.BytesIO(image_bytes))
+            #         return content
+            #     except Exception as e:
+            #         print(f"Error processing image: {str(e)}")
+            #         # Fall back to using the content as text
+            #         return self.content
             case _:
                 return self.content
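The commented-out branch sketches how an `is_image` chunk would be decoded back into a `PIL.Image`. As a standalone illustration, a minimal sketch of that decoding path (the fallback-to-text behavior mirrors the comment above; the `base64` and `io` imports are assumed):

```python
import base64
import io

from PIL import Image


def decode_image_content(content: str) -> Image.Image | str:
    """Decode a base64 (optionally data-URI) image string into a PIL Image.

    Falls back to returning the raw string if decoding fails.
    """
    try:
        if content.startswith("data:"):
            # Strip the "data:image/...;base64," prefix before decoding
            content = content.split(",", 1)[1]
        return Image.open(io.BytesIO(base64.b64decode(content)))
    except Exception:
        return content
```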


@@ -27,7 +27,7 @@ from core.services.rules_processor import RulesProcessor
 from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
 from core.vector_store.multi_vector_store import MultiVectorStore
 import filetype
-from filetype.types import IMAGE, archive  # , DOCUMENT, document
+from filetype.types import IMAGE  # , DOCUMENT, document
 import pdf2image
 from PIL.Image import Image
@@ -363,7 +363,7 @@ class DocumentService:
             case file_type if file_type in IMAGE:
                 return [Chunk(content=file_content_base64, metadata={"is_image": True})]
             case "application/pdf":
-                logger.info(f"Working with PDF file!")
+                logger.info("Working with PDF file!")
                 images = pdf2image.convert_from_bytes(file_content)
                 images_b64 = [self.img_to_base64_str(image) for image in images]
                 return [
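`img_to_base64_str` is called but not defined in this diff, so the helper below is an assumption (PNG data-URI output, matching the `BASE_64_PREFIX` the Ollama model strips); a sketch of the full PDF-to-chunks path:

```python
import base64
import io

import pdf2image
from PIL.Image import Image


def img_to_base64_str(image: Image) -> str:
    """Assumed helper: encode a PIL image as a PNG data URI."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/png;base64,{encoded}"


def pdf_to_image_chunks(file_content: bytes) -> list[str]:
    """Render each PDF page to an image and encode it as a data URI."""
    images = pdf2image.convert_from_bytes(file_content)
    return [img_to_base64_str(image) for image in images]
```

pdf2image requires Poppler to be installed on the host.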


@@ -4,7 +4,6 @@ import os
 import logging
 from pathlib import Path
 from pdf2image import convert_from_path
-from PIL import Image
 from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
 from core.vector_store.multi_vector_store import MultiVectorStore


@@ -2,7 +2,6 @@ import pytest
 import base64
 import io
 import numpy as np
-import torch
 from PIL import Image
 from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel


@@ -2,8 +2,7 @@ import pytest
 import asyncio
 import torch
 import numpy as np
-import psycopg
-from pgvector.psycopg import Bit, register_vector
+from pgvector.psycopg import Bit
 import logging
 from core.vector_store.multi_vector_store import MultiVectorStore
 from core.models.chunk import DocumentChunk


@@ -12,7 +12,7 @@ dev_permissions = ["read", "write", "admin"] # Default dev permissions
 
 [completion]
 provider = "ollama"
-model_name = "llama3.2"
+model_name = "llama3.2-vision"
 default_max_tokens = "1000"
 default_temperature = 0.7
 # base_url = "http://ollama:11434" # Just use the service name
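Switching to `llama3.2-vision` is what makes the `images` field in the chat call meaningful; the plain `llama3.2` text models do not accept image inputs. A quick smoke test, assuming a local Ollama server with the model already pulled (`ollama pull llama3.2-vision`); the image path is hypothetical:

```python
import asyncio

from ollama import AsyncClient


async def main() -> None:
    client = AsyncClient()  # defaults to http://localhost:11434
    response = await client.chat(
        model="llama3.2-vision",
        messages=[{
            "role": "user",
            "content": "Describe this image.",
            "images": ["example.png"],  # hypothetical path; base64 strings also work
        }],
        options={"num_predict": 1000, "temperature": 0.7},
    )
    print(response["message"]["content"])


asyncio.run(main())
```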


@@ -7,7 +7,6 @@ from urllib.parse import urlparse
 import httpx
 import jwt
 from PIL.Image import Image as PILImage
-from PIL import Image
 from .models import (
     Document,