add image processing to ollama

Arnav Agrawal 2025-02-26 22:36:25 -05:00
parent 821e9d7e20
commit 07eec6b9e3
8 changed files with 43 additions and 14 deletions

View File

@@ -5,6 +5,7 @@ from core.completion.base_completion import (
 )
 from ollama import AsyncClient
 
+BASE_64_PREFIX = "data:image/png;base64,"
 
 class OllamaCompletionModel(BaseCompletionModel):
     """Ollama completion model implementation"""
@@ -16,18 +17,33 @@ class OllamaCompletionModel(BaseCompletionModel):
     async def complete(self, request: CompletionRequest) -> CompletionResponse:
         """Generate completion using Ollama API"""
         # Construct prompt with context
-        context = "\n\n".join(request.context_chunks)
+        images, context = [], []
+        for chunk in request.context_chunks:
+            if chunk.startswith(BASE_64_PREFIX):
+                image_b64 = chunk.split(',', 1)[1]
+                images.append(image_b64)
+            else:
+                context.append(chunk)
+        context = "\n\n".join(context)
         prompt = f"""You are a helpful assistant. Use the provided context to answer questions accurately.
 
-Context:
-{context}
+<QUESTION>
+{request.query}
+</QUESTION>
 
-Question: {request.query}"""
+<CONTEXT>
+{context}
+</CONTEXT>
+"""
 
         # Call Ollama API
         response = await self.client.chat(
             model=self.model_name,
-            messages=[{"role": "user", "content": prompt}],
+            messages=[{
+                "role": "user",
+                "content": prompt,
+                "images": [images[0]],
+            }],
             options={
                 "num_predict": request.max_tokens,
                 "temperature": request.temperature,

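A minimal end-to-end sketch of the new multimodal path, assuming the `ollama` Python client and a locally pulled `llama3.2-vision` model. Unlike the committed code, which indexes `images[0]` unconditionally (an IndexError when retrieval returns no image chunks), this sketch guards the no-image case:

import asyncio
import base64
import io

from ollama import AsyncClient
from PIL import Image

BASE_64_PREFIX = "data:image/png;base64,"


def tiny_png_b64() -> str:
    # Build a valid 1x1 PNG so the sketch is self-contained.
    buffer = io.BytesIO()
    Image.new("RGB", (1, 1), "red").save(buffer, format="PNG")
    return BASE_64_PREFIX + base64.b64encode(buffer.getvalue()).decode()


async def demo():
    # Hypothetical retrieval output: one text chunk, one image chunk.
    chunks = ["Q3 revenue grew 12% quarter over quarter.", tiny_png_b64()]
    images = [c.split(",", 1)[1] for c in chunks if c.startswith(BASE_64_PREFIX)]
    context = "\n\n".join(c for c in chunks if not c.startswith(BASE_64_PREFIX))
    message = {
        "role": "user",
        "content": f"<QUESTION>\nWhat color is the image?\n</QUESTION>\n\n<CONTEXT>\n{context}\n</CONTEXT>",
    }
    if images:  # guard: the committed code assumes at least one image
        message["images"] = images[:1]  # the commit forwards only the first image
    response = await AsyncClient().chat(model="llama3.2-vision", messages=[message])
    print(response["message"]["content"])


asyncio.run(demo())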
View File

@ -1,6 +1,7 @@
from typing import Dict, Any, List, Optional, Literal from typing import Dict, Any, List, Optional, Literal
from enum import Enum from enum import Enum
from datetime import UTC, datetime from datetime import UTC, datetime
from PIL import Image
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
import uuid import uuid
import logging import logging
@@ -88,7 +89,7 @@ class ChunkResult(BaseModel):
     filename: Optional[str] = None
     download_url: Optional[str] = None
 
-    def augmented_content(self, doc: DocumentResult) -> str:
+    def augmented_content(self, doc: DocumentResult) -> str | Image.Image:
         match self.metadata:
             case m if "timestamp" in m:
                 # if timestamp present, then must be a video. In that case,
@@ -110,5 +111,21 @@ class ChunkResult(BaseModel):
                     for t in timestamps
                 ]
                 return "\n\n".join(augmented_contents)
+            # case m if m.get("is_image", False):
+            #     try:
+            #         # Handle data URI format "data:image/png;base64,..."
+            #         content = self.content
+            #         if content.startswith('data:'):
+            #             # Extract the base64 part after the comma
+            #             content = content.split(',', 1)[1]
+            #         # Now decode the base64 string
+            #         image_bytes = base64.b64decode(content)
+            #         content = Image.open(io.BytesIO(image_bytes))
+            #         return content
+            #     except Exception as e:
+            #         print(f"Error processing image: {str(e)}")
+            #         # Fall back to using the content as text
+            #         return self.content
+
             case _:
                 return self.content
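If the commented-out branch above is ever re-enabled, it will also need `import base64` and `import io`, which this hunk does not add. A standalone, runnable sketch of the same decode path:

import base64
import io

from PIL import Image


def decode_image_chunk(content: str) -> Image.Image:
    # Accept both a bare base64 string and a data URI ("data:image/png;base64,...").
    if content.startswith("data:"):
        content = content.split(",", 1)[1]
    image_bytes = base64.b64decode(content)
    return Image.open(io.BytesIO(image_bytes))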

View File

@@ -27,7 +27,7 @@ from core.services.rules_processor import RulesProcessor
 from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
 from core.vector_store.multi_vector_store import MultiVectorStore
 import filetype
-from filetype.types import IMAGE, archive  # , DOCUMENT, document
+from filetype.types import IMAGE  # , DOCUMENT, document
 import pdf2image
 from PIL.Image import Image
@@ -363,7 +363,7 @@ class DocumentService:
             case file_type if file_type in IMAGE:
                 return [Chunk(content=file_content_base64, metadata={"is_image": True})]
             case "application/pdf":
-                logger.info(f"Working with PDF file!")
+                logger.info("Working with PDF file!")
                 images = pdf2image.convert_from_bytes(file_content)
                 images_b64 = [self.img_to_base64_str(image) for image in images]
                 return [
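`img_to_base64_str` is referenced here but its body is not part of this diff. One plausible implementation, assuming PNG output so the result matches the `BASE_64_PREFIX` check in the completion model:

import base64
import io

from PIL.Image import Image


def img_to_base64_str(image: Image) -> str:
    # Serialize the PIL image to PNG and wrap it in the data-URI form
    # that OllamaCompletionModel strips back off before calling the API.
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buffer.getvalue()).decode("utf-8")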

View File

@@ -4,7 +4,6 @@ import os
 import logging
 from pathlib import Path
 from pdf2image import convert_from_path
-from PIL import Image
 
 from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel
 from core.vector_store.multi_vector_store import MultiVectorStore

View File

@@ -2,7 +2,6 @@ import pytest
 import base64
 import io
 import numpy as np
-import torch
 from PIL import Image
 
 from core.embedding.colpali_embedding_model import ColpaliEmbeddingModel

View File

@@ -2,8 +2,7 @@ import pytest
 import asyncio
 import torch
 import numpy as np
-import psycopg
-from pgvector.psycopg import Bit, register_vector
+from pgvector.psycopg import Bit
 import logging
 from core.vector_store.multi_vector_store import MultiVectorStore
 from core.models.chunk import DocumentChunk
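For context (not from this diff): `Bit` wraps a boolean vector for Postgres `bit` columns, which is one way a multi-vector store can hold binary-quantized embeddings. A sketch, assuming `Bit` accepts a NumPy boolean array as in the pgvector-python examples:

import numpy as np
from pgvector.psycopg import Bit

embedding = np.random.randn(128)
binary = Bit(embedding > 0)  # sign-based binary quantization
print(str(binary)[:16])      # bit string, e.g. "1010011100010110"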

View File

@@ -12,7 +12,7 @@ dev_permissions = ["read", "write", "admin"] # Default dev permissions
 
 [completion]
 provider = "ollama"
-model_name = "llama3.2"
+model_name = "llama3.2-vision"
 default_max_tokens = "1000"
 default_temperature = 0.7
 # base_url = "http://ollama:11434" # Just use the service name
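Switching `model_name` to a vision-capable model only works if that model has been pulled into the local Ollama daemon. A sketch of a startup check, assuming the `ollama` Python client:

import ollama

model_name = "llama3.2-vision"  # mirrors [completion].model_name above
try:
    ollama.show(model_name)
except ollama.ResponseError:
    raise SystemExit(f"run `ollama pull {model_name}` before starting the service")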

View File

@@ -7,7 +7,6 @@ from urllib.parse import urlparse
 import httpx
 import jwt
 from PIL.Image import Image as PILImage
-from PIL import Image
 
 from .models import (
     Document,