Mirror of https://github.com/james-m-jordan/morphik-core.git (synced 2025-05-09 19:32:38 +00:00)
add support for patching pydantic (#113)
This commit is contained in: parent de1a7d2fd7 · commit 3a03112636
utils/__init__.py (new file, 0 lines)
utils/patching.py (new file, 392 lines)
@@ -0,0 +1,392 @@
from __future__ import annotations

import json
from enum import Enum
from typing import Any, Dict, List, Optional, Type, Union

from pydantic import BaseModel, Field


class PatchOp(str, Enum):
    """JSON Patch operation types as defined in RFC 6902."""

    ADD = "add"
    REMOVE = "remove"
    REPLACE = "replace"
    MOVE = "move"
    COPY = "copy"
    TEST = "test"


def gather_paths(schema: Dict[str, Any], base_pointer: str = "") -> Dict[str, Dict[str, Any]]:
    """
    Walk a JSON Schema and map every reachable JSON Pointer -> subschema.
    For objects: recurse into 'properties'.
    For arrays: expose the items schema under '/-' (append position) and '/0' (first index).
    For additionalProperties: expose the value schema under '/*'.
    """
    paths = {}

    # Add current path
    if base_pointer:
        paths[base_pointer] = schema

    # Handle base schema type
    if "type" not in schema and base_pointer:
        # Ensure there's always a type
        schema["type"] = "object"

    # Handle object types
    if schema.get("type") == "object":
        # Add properties
        for prop, prop_schema in schema.get("properties", {}).items():
            child_ptr = f"{base_pointer}/{prop}"
            paths[child_ptr] = prop_schema
            # Recursively process nested objects
            if prop_schema.get("type") == "object":
                paths.update(gather_paths(prop_schema, child_ptr))
            # Handle array items
            elif prop_schema.get("type") == "array" and "items" in prop_schema:
                array_path = f"{child_ptr}/-"  # '-' represents the end of an array
                paths[array_path] = prop_schema["items"]
                # Process array item indexing
                paths.update(gather_paths(prop_schema["items"], f"{child_ptr}/0"))

        # Handle additionalProperties if present
        if schema.get("additionalProperties"):
            add_props = schema["additionalProperties"]
            if isinstance(add_props, dict):
                wild_path = f"{base_pointer}/*"
                paths[wild_path] = add_props

    # Handle array types
    elif schema.get("type") == "array" and "items" in schema:
        array_path = f"{base_pointer}/-"
        paths[array_path] = schema["items"]
        # Process array item indexing
        paths.update(gather_paths(schema["items"], f"{base_pointer}/0"))

    return paths
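# Illustrative sketch of what gather_paths yields, assuming a plain object schema of
# the kind Pydantic emits for a model with a string field and a list field (the field
# names here are hypothetical, not taken from the diff):
#
#   gather_paths({
#       "type": "object",
#       "properties": {
#           "name": {"type": "string"},
#           "tags": {"type": "array", "items": {"type": "string"}},
#       },
#   })
#   # -> {
#   #     "/name":   {"type": "string"},
#   #     "/tags":   {"type": "array", "items": {"type": "string"}},
#   #     "/tags/-": {"type": "string"},
#   #     "/tags/0": {"type": "string"},
#   # }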


def op_schema(op: str, pointer: str, value_schema: dict | None) -> dict:
    """
    Build a schema fragment for one operation / one pointer.
    """
    # Ensure value_schema always has a type
    if value_schema and "type" not in value_schema:
        value_schema = {"type": "object", **value_schema}

    # Base schema for the operation
    base = {
        "type": "object",
        "required": ["op", "path"],
        "properties": {"op": {"type": "string", "enum": [op]}, "path": {"type": "string", "const": pointer}},
        "additionalProperties": False,
    }

    # Add value property for operations that require it
    if op in {"add", "replace", "test"}:
        base["required"].append("value")
        if value_schema:
            base["properties"]["value"] = value_schema
        else:
            base["properties"]["value"] = {"type": "null"}

    # Add from property for operations that require it
    if op in {"move", "copy"}:
        base["required"].append("from")
        base["properties"]["from"] = {"type": "string"}

    return base


def build_patch_schema(model_schema: dict) -> dict:
    """
    Build a JSON Schema that validates JSON Patch operations for the given model schema.
    """
    # Gather all paths in the schema
    paths = gather_paths(model_schema)

    # Generate operation schemas for each path
    variants = []
    for ptr, subschema in paths.items():
        # Handle empty or invalid schemas
        if not subschema or not isinstance(subschema, dict):
            subschema = {"type": "object"}

        # Ensure the subschema has a type
        if "type" not in subschema:
            # Try to infer type from other properties
            if "properties" in subschema:
                subschema["type"] = "object"
            elif "items" in subschema:
                subschema["type"] = "array"
            else:
                # Default to string if can't determine
                subschema["type"] = "string"

        # Generate variants for different operations
        variants.append(op_schema("add", ptr, subschema))
        variants.append(op_schema("replace", ptr, subschema))
        variants.append(op_schema("remove", ptr, None))
        variants.append(op_schema("test", ptr, subschema))

    # For move and copy operations, we'd need to do cross-path validation
    # For simplicity, we'll skip detailed validation of those here

    # Final patch schema - wrap the array in an object to work with OpenAI's API
    patch_schema = {
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "type": "object",
        "properties": {"operations": {"type": "array", "items": {"anyOf": variants}}},
        "required": ["operations"],
    }

    return patch_schema
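# Illustrative sketch of one generated variant (hypothetical pointer "/name" with a
# string subschema); the final schema then wraps all such variants under the
# "operations" array via {"items": {"anyOf": [...]}}:
#
#   op_schema("replace", "/name", {"type": "string"})
#   # -> {
#   #     "type": "object",
#   #     "required": ["op", "path", "value"],
#   #     "properties": {
#   #         "op": {"type": "string", "enum": ["replace"]},
#   #         "path": {"type": "string", "const": "/name"},
#   #         "value": {"type": "string"},
#   #     },
#   #     "additionalProperties": False,
#   # }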


def MakePatchSchema(model_class: Type[BaseModel]) -> Type[BaseModel]:
    """
    Create a Pydantic model that represents valid JSON Patch operations for the given model.

    Args:
        model_class: The Pydantic model to create a patch schema for

    Returns:
        A new Pydantic model class that validates JSON Patch operations
    """
    # Get the JSON schema for the model
    model_schema = model_class.model_json_schema()

    # Build the patch schema
    patch_schema = build_patch_schema(model_schema)

    # Create a Pydantic model for a single patch operation
    class PatchOperation(BaseModel):
        op: PatchOp
        path: str
        value: Optional[Any] = None
        from_: Optional[str] = Field(None, alias="from")

        model_config = {"populate_by_name": True, "extra": "forbid"}

    # Create a model for the patch (array of operations wrapped in an object)
    class JsonPatchDocument(BaseModel):
        operations: List[PatchOperation] = Field(default_factory=list)

        model_config = {
            "populate_by_name": True,
            "extra": "forbid",
            "json_schema_extra": lambda schema: schema.update(patch_schema),
        }

        def __iter__(self):
            return iter(self.operations)

        def __getitem__(self, item):
            return self.operations[item]

        @classmethod
        def parse_obj(cls, obj):
            """Parse a JSON object into a patch document."""
            if isinstance(obj, (str, bytes)):
                obj = json.loads(obj)

            # Expect an object with an "operations" key
            if isinstance(obj, dict) and "operations" in obj:
                return cls(operations=obj["operations"])
            else:
                raise ValueError(
                    "Input must be an object with an 'operations' field containing an array of patch operations"
                )

        def model_json_schema(cls, *args, **kwargs):
            """Return the JSON Schema as a dict."""
            return patch_schema

        def apply(self, target: Union[str, dict, BaseModel]) -> dict:
            """
            Apply the patch to a target, which can be a JSON string, dict, or Pydantic model.

            Args:
                target: The target to apply the patch to

            Returns:
                The patched data as a dict
            """
            # Convert the target to a dict
            if isinstance(target, str):
                target_dict = json.loads(target)
            elif isinstance(target, BaseModel):
                target_dict = target.model_dump(exclude_unset=True)
            else:
                target_dict = dict(target)

            # Apply operations in sequence
            for op in self.operations:
                path_parts = op.path.strip("/").split("/") if op.path != "/" else []

                if op.op == PatchOp.REPLACE:
                    target_dict = apply_replace(target_dict, path_parts, op.value)
                elif op.op == PatchOp.ADD:
                    target_dict = apply_add(target_dict, path_parts, op.value)
                elif op.op == PatchOp.REMOVE:
                    target_dict = apply_remove(target_dict, path_parts)
                elif op.op == PatchOp.MOVE:
                    from_parts = op.from_.strip("/").split("/") if op.from_ else []
                    target_dict = apply_move(target_dict, from_parts, path_parts)
                elif op.op == PatchOp.COPY:
                    from_parts = op.from_.strip("/").split("/") if op.from_ else []
                    target_dict = apply_copy(target_dict, from_parts, path_parts)
                elif op.op == PatchOp.TEST:
                    target_dict = apply_test(target_dict, path_parts, op.value)

            return target_dict

    # Update the model name to reflect the original model
    patch_model_name = f"{model_class.__name__}PatchSchema"
    JsonPatchDocument.__name__ = patch_model_name

    # Override the model_json_schema method to return our custom schema
    JsonPatchDocument.model_json_schema = classmethod(lambda cls, *args, **kwargs: patch_schema)

    return JsonPatchDocument
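# Illustrative sketch of using the generated model directly, without an LLM in the
# loop (Person is a hypothetical example model, not part of this module):
#
#   class Person(BaseModel):
#       name: str
#       age: Optional[int] = None
#
#   PersonPatch = MakePatchSchema(Person)
#   doc = PersonPatch(
#       operations=[
#           {"op": "replace", "path": "/name", "value": "Ada"},
#           {"op": "remove", "path": "/age"},
#       ]
#   )
#   doc.apply(Person(name="Arnav", age=30))  # -> {"name": "Ada"}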


def apply_replace(target: dict, path_parts: List[str], value: Any) -> dict:
    """Apply a 'replace' operation."""
    target = target.copy()  # Create a copy to avoid modifying the original

    if not path_parts:
        # Replace entire document
        return value

    current = target
    for i, part in enumerate(path_parts[:-1]):
        if part.isdigit() and isinstance(current, list):
            part = int(part)
        if part not in current and i < len(path_parts) - 1:
            raise ValueError(f"Path not found: {'/'.join(path_parts[:i+1])}")
        current = current[part]

    last_part = path_parts[-1]
    if last_part.isdigit() and isinstance(current, list):
        last_part = int(last_part)
    if last_part not in current and not (isinstance(current, list) and last_part == "-"):
        raise ValueError(f"Path not found: {'/'.join(path_parts)}")

    current[last_part] = value
    return target


def apply_add(target: dict, path_parts: List[str], value: Any) -> dict:
    """Apply an 'add' operation."""
    target = target.copy()

    if not path_parts:
        # Replace entire document
        return value

    current = target
    for i, part in enumerate(path_parts[:-1]):
        if part.isdigit() and isinstance(current, list):
            part = int(part)
        if part not in current and i < len(path_parts) - 1:
            if path_parts[i + 1].isdigit() or path_parts[i + 1] == "-":
                current[part] = []
            else:
                current[part] = {}
        current = current[part]

    last_part = path_parts[-1]
    if last_part == "-" and isinstance(current, list):
        current.append(value)
    elif last_part.isdigit() and isinstance(current, list):
        idx = int(last_part)
        if idx > len(current):
            raise ValueError(f"Index out of bounds: {idx}")
        current.insert(idx, value)
    else:
        current[last_part] = value

    return target


def apply_remove(target: dict, path_parts: List[str]) -> dict:
    """Apply a 'remove' operation."""
    target = target.copy()

    if not path_parts:
        raise ValueError("Cannot remove root document")

    current = target
    for i, part in enumerate(path_parts[:-1]):
        if part.isdigit() and isinstance(current, list):
            part = int(part)
        if part not in current:
            raise ValueError(f"Path not found: {'/'.join(path_parts[:i+1])}")
        current = current[part]

    last_part = path_parts[-1]
    if last_part.isdigit() and isinstance(current, list):
        idx = int(last_part)
        if idx >= len(current):
            raise ValueError(f"Index out of bounds: {idx}")
        del current[idx]
    elif last_part not in current:
        raise ValueError(f"Path not found: {'/'.join(path_parts)}")
    else:
        del current[last_part]

    return target


def apply_move(target: dict, from_parts: List[str], path_parts: List[str]) -> dict:
    """Apply a 'move' operation."""
    target = target.copy()

    # Get the value at the 'from' location
    value = get_value(target, from_parts)

    # Remove the value from the 'from' location
    target = apply_remove(target, from_parts)

    # Add the value at the 'path' location
    return apply_add(target, path_parts, value)


def apply_copy(target: dict, from_parts: List[str], path_parts: List[str]) -> dict:
    """Apply a 'copy' operation."""
    target = target.copy()

    # Get the value at the 'from' location
    value = get_value(target, from_parts)

    # Add the value at the 'path' location
    return apply_add(target, path_parts, value)


def apply_test(target: dict, path_parts: List[str], value: Any) -> dict:
    """Apply a 'test' operation."""
    current_value = get_value(target, path_parts)

    if current_value != value:
        raise ValueError(f"Test failed: {'/'.join(path_parts)} does not match expected value")

    return target


def get_value(target: dict, path_parts: List[str]) -> Any:
    """Get a value at a specific path."""
    if not path_parts:
        return target

    current = target
    for i, part in enumerate(path_parts):
        if part.isdigit() and isinstance(current, list):
            part = int(part)
        if part not in current:
            raise ValueError(f"Path not found: {'/'.join(path_parts[:i+1])}")
        current = current[part]

    return current
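# Illustrative sketch of the low-level helpers on a nested dict (hypothetical data).
# Note that the `part not in current` checks in get_value and the intermediate
# traversal match dict keys; for list elements they test value membership rather
# than index bounds:
#
#   doc = {"profile": {"name": "Arnav", "age": 30}}
#   get_value(doc, ["profile", "name"])    # -> "Arnav"
#   apply_remove(doc, ["profile", "age"])  # -> {"profile": {"name": "Arnav"}}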
utils/patching_test.py (new file, 45 lines)
@@ -0,0 +1,45 @@
import json
import os
from typing import List as TypeList
from typing import Optional

import openai
from dotenv import load_dotenv
from patching import MakePatchSchema
from pydantic import BaseModel, Field

# Example usage
if __name__ == "__main__":
    load_dotenv(override=True)

    class Person(BaseModel):
        name: str
        age: Optional[int] = None
        tags: TypeList[str] = Field(default_factory=list)

    # Create an instance of the model
    person = Person(name="Arnav", age=30, tags=["developer", "python"])

    # Create a patch schema for the Person model
    PersonPatchSchema = MakePatchSchema(Person)
    client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    # Print the JSON schema for debugging
    print(json.dumps(PersonPatchSchema.model_json_schema(), indent=2))

    response = client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are a helpful assistant that generates JSON Patch operations."},
            {
                "role": "user",
                "content": f"""Generate a JSON Patch operation to change the name of the person to 'Ada'
                and remove the age field. The original data is: {person.model_dump()}""",
            },
        ],
        response_format=PersonPatchSchema,
    )
    response_json = response.choices[0].message.parsed
    print(response_json)
    patched_person = response_json.apply(person)
    print(patched_person)
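# A run prints the generated patch schema, the parsed patch document, and the patched
# dict. With the prompt above, a plausible result is the following (illustrative only;
# the model's output may vary):
#
#   operations = [
#       {"op": "replace", "path": "/name", "value": "Ada"},
#       {"op": "remove", "path": "/age"},
#   ]
#   # patched_person -> {"name": "Ada", "tags": ["developer", "python"]}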