mirror of
https://github.com/james-m-jordan/morphik-core.git
synced 2025-05-09 19:32:38 +00:00
393 lines
13 KiB
Python
393 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from enum import Enum
|
|
from typing import Any, Dict, List, Optional, Type, Union
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
class PatchOp(str, Enum):
    """Names of the operations a JSON Patch document may contain (RFC 6902).

    Because the enum mixes in ``str``, each member compares equal to its
    wire-format value, e.g. ``PatchOp.ADD == "add"``.
    """

    ADD = "add"          # insert a value at the target location
    REMOVE = "remove"    # delete the value at the target location
    REPLACE = "replace"  # overwrite the value at the target location
    MOVE = "move"        # relocate a value from ``from`` to ``path``
    COPY = "copy"        # duplicate a value from ``from`` to ``path``
    TEST = "test"        # assert that the target equals a given value
|
|
|
|
|
|
def gather_paths(schema: Dict[str, Any], base_pointer: str = "") -> Dict[str, Dict[str, Any]]:
    """
    Walk a JSON Schema and map every reachable JSON Pointer -> subschema.

    For objects: recurse into 'properties'; a dict-valued
    'additionalProperties' is recorded under '<pointer>/*'.
    For arrays: the 'items' schema is recorded under both '<pointer>/-'
    (append position) and '<pointer>/0' (indexed element), and the walk
    continues below '<pointer>/0'.

    Any visited subschema that has no 'type' is given ``"type": "object"``.
    That normalisation is performed on a private deep copy, so the caller's
    ``schema`` is never modified; the returned subschemas belong to that copy.

    Args:
        schema: The JSON Schema (or subschema) to walk.
        base_pointer: Pointer prefix for ``schema`` itself. When non-empty,
            ``schema`` is included in the result under this key.

    Returns:
        Dict mapping pointer strings to their subschemas, in discovery order.
    """
    import copy  # function-scope import keeps this block self-contained

    # Copy once at the entry point. The recursive walk then mutates only the
    # copy, which also keeps parent and child entries pointing at the same
    # (normalised) dict objects.
    return _gather_paths_in_place(copy.deepcopy(schema), base_pointer)


def _gather_paths_in_place(schema: Dict[str, Any], base_pointer: str) -> Dict[str, Dict[str, Any]]:
    """Recursive worker for gather_paths; may add a 'type' key to ``schema``."""
    paths: Dict[str, Dict[str, Any]] = {}

    if base_pointer:
        paths[base_pointer] = schema
        # Ensure every non-root subschema carries a type.
        if "type" not in schema:
            schema["type"] = "object"

    schema_type = schema.get("type")

    if schema_type == "object":
        for prop, prop_schema in schema.get("properties", {}).items():
            child_ptr = f"{base_pointer}/{prop}"
            paths[child_ptr] = prop_schema

            prop_type = prop_schema.get("type")
            if prop_type == "object":
                paths.update(_gather_paths_in_place(prop_schema, child_ptr))
            elif prop_type == "array" and "items" in prop_schema:
                item_schema = prop_schema["items"]
                # '-' addresses the position just past the end of the array.
                paths[f"{child_ptr}/-"] = item_schema
                paths.update(_gather_paths_in_place(item_schema, f"{child_ptr}/0"))

        # Only a schema-valued additionalProperties yields a wildcard entry;
        # a bare ``True`` carries no subschema to record.
        extra_props = schema.get("additionalProperties")
        if extra_props and isinstance(extra_props, dict):
            paths[f"{base_pointer}/*"] = extra_props

    elif schema_type == "array" and "items" in schema:
        item_schema = schema["items"]
        paths[f"{base_pointer}/-"] = item_schema
        paths.update(_gather_paths_in_place(item_schema, f"{base_pointer}/0"))

    return paths
|
|
|
|
|
|
def op_schema(op: str, pointer: str, value_schema: dict | None) -> dict:
    """
    Build the schema fragment describing one operation at one pointer.

    The fragment is a closed object (no additional properties) whose 'op' is
    pinned to ``op`` and whose 'path' is pinned to ``pointer``.

    Args:
        op: Operation name, e.g. "add" or "remove".
        pointer: The exact JSON Pointer the operation must target.
        value_schema: Schema for the 'value' member. Only consulted for
            "add", "replace" and "test"; a falsy schema means the value must
            be null, and a schema without 'type' is treated as an object.

    Returns:
        A JSON Schema dict for the single operation.
    """
    required = ["op", "path"]
    properties = {
        "op": {"type": "string", "enum": [op]},
        "path": {"type": "string", "const": pointer},
    }

    # Operations that carry a payload.
    if op in {"add", "replace", "test"}:
        required.append("value")
        if not value_schema:
            properties["value"] = {"type": "null"}
        elif "type" in value_schema:
            properties["value"] = value_schema
        else:
            # Guarantee the payload schema always declares a type.
            properties["value"] = {"type": "object", **value_schema}

    # Operations that read from a second location.
    if op in {"move", "copy"}:
        required.append("from")
        properties["from"] = {"type": "string"}

    return {
        "type": "object",
        "required": required,
        "properties": properties,
        "additionalProperties": False,
    }
|
|
|
|
|
|
def build_patch_schema(model_schema: dict) -> dict:
    """
    Build a JSON Schema that validates JSON Patch operations for the given model schema.

    For every pointer found by ``gather_paths`` the result accepts an "add",
    "replace", "remove" and "test" operation pinned to that pointer. "move"
    and "copy" variants are not generated because they would need
    cross-path validation.

    The type inference below writes a 'type' key into subschemas that lack
    one. It runs on a deep copy, so ``model_schema`` itself is left untouched.

    Args:
        model_schema: JSON Schema of the document that patches will target.

    Returns:
        A schema for an object with a required 'operations' array whose items
        must match one of the generated operation variants.
    """
    import copy  # function-scope import keeps this block self-contained

    # Isolate the caller's schema from the in-place normalisation below.
    paths = gather_paths(copy.deepcopy(model_schema))

    variants = []
    for ptr, subschema in paths.items():
        # Empty or non-dict subschemas fall back to a generic object.
        if not subschema or not isinstance(subschema, dict):
            subschema = {"type": "object"}

        if "type" not in subschema:
            # Infer the type from structural keywords where possible.
            if "properties" in subschema:
                subschema["type"] = "object"
            elif "items" in subschema:
                subschema["type"] = "array"
            else:
                # Nothing to go on: default to string.
                subschema["type"] = "string"

        variants.append(op_schema("add", ptr, subschema))
        variants.append(op_schema("replace", ptr, subschema))
        variants.append(op_schema("remove", ptr, None))
        variants.append(op_schema("test", ptr, subschema))

    # The operations array is wrapped in an object to work with OpenAI's API.
    return {
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "type": "object",
        "properties": {"operations": {"type": "array", "items": {"anyOf": variants}}},
        "required": ["operations"],
    }
|
|
|
|
|
|
def _split_json_pointer(pointer: Optional[str]) -> List[str]:
    """Split a JSON Pointer into unescaped reference tokens (RFC 6901).

    ``None``, ``""`` and ``"/"`` all yield ``[]``, meaning the whole document.
    Leading and trailing slashes are tolerated. ``~1`` is decoded to ``/``
    before ``~0`` is decoded to ``~``, as the RFC requires.
    """
    if not pointer:
        return []
    trimmed = pointer.strip("/")
    if not trimmed:
        return []
    return [token.replace("~1", "/").replace("~0", "~") for token in trimmed.split("/")]


def MakePatchSchema(model_class: Type[BaseModel]) -> Type[BaseModel]:
    """
    Create a Pydantic model that represents valid JSON Patch operations for the given model.

    The returned class holds a list of operations under ``operations``, can
    be iterated and indexed like that list, reports the schema produced by
    ``build_patch_schema`` from ``model_json_schema()``, and can ``apply``
    itself to a document.

    Args:
        model_class: The Pydantic model to create a patch schema for

    Returns:
        A new Pydantic model class that validates JSON Patch operations
    """
    import copy  # function-scope import keeps this block self-contained

    # JSON Schema of the target model, then the schema of patches against it.
    model_schema = model_class.model_json_schema()
    patch_schema = build_patch_schema(model_schema)

    class PatchOperation(BaseModel):
        """A single JSON Patch operation."""

        op: PatchOp
        path: str
        value: Optional[Any] = None
        # 'from' is a Python keyword, hence the alias.
        from_: Optional[str] = Field(None, alias="from")

        model_config = {"populate_by_name": True, "extra": "forbid"}

    class JsonPatchDocument(BaseModel):
        """A patch: an array of operations wrapped in an object."""

        operations: List[PatchOperation] = Field(default_factory=list)

        model_config = {
            "populate_by_name": True,
            "extra": "forbid",
            "json_schema_extra": lambda schema: schema.update(patch_schema),
        }

        def __iter__(self):
            return iter(self.operations)

        def __getitem__(self, item):
            return self.operations[item]

        @classmethod
        def parse_obj(cls, obj):
            """Parse a JSON object (or JSON text) into a patch document."""
            if isinstance(obj, (str, bytes)):
                obj = json.loads(obj)

            # Expect an object with an "operations" key
            if isinstance(obj, dict) and "operations" in obj:
                return cls(operations=obj["operations"])
            raise ValueError(
                "Input must be an object with an 'operations' field containing an array of patch operations"
            )

        # Declared as a real classmethod so no post-hoc patching of the class
        # is needed for ``JsonPatchDocument.model_json_schema()`` to work.
        @classmethod
        def model_json_schema(cls, *args, **kwargs):
            """Return the JSON Schema as a dict."""
            return patch_schema

        def apply(self, target: Union[str, dict, BaseModel]) -> dict:
            """
            Apply the patch to a target, which can be a JSON string, dict, or Pydantic model.

            The caller's ``target`` is never modified: dict inputs are deep
            copied before the first operation runs.

            Args:
                target: The target to apply the patch to

            Returns:
                The patched data as a dict

            Raises:
                ValueError: If a "move"/"copy" operation has no 'from'
                    pointer, or if one of the apply_* helpers rejects an
                    operation.
            """
            if isinstance(target, str):
                target_dict = json.loads(target)
            elif isinstance(target, BaseModel):
                target_dict = target.model_dump(exclude_unset=True)
            else:
                # A shallow copy would still share nested containers with the
                # caller, and the helpers edit nested containers in place.
                target_dict = copy.deepcopy(dict(target))

            # Apply operations in sequence
            for op in self.operations:
                path_parts = _split_json_pointer(op.path)

                if op.op in (PatchOp.MOVE, PatchOp.COPY):
                    if op.from_ is None:
                        raise ValueError(f"'{op.op.value}' operation requires a 'from' pointer")
                    from_parts = _split_json_pointer(op.from_)
                    if op.op == PatchOp.MOVE:
                        target_dict = apply_move(target_dict, from_parts, path_parts)
                    else:
                        target_dict = apply_copy(target_dict, from_parts, path_parts)
                elif op.op == PatchOp.REPLACE:
                    target_dict = apply_replace(target_dict, path_parts, op.value)
                elif op.op == PatchOp.ADD:
                    target_dict = apply_add(target_dict, path_parts, op.value)
                elif op.op == PatchOp.REMOVE:
                    target_dict = apply_remove(target_dict, path_parts)
                elif op.op == PatchOp.TEST:
                    target_dict = apply_test(target_dict, path_parts, op.value)

            return target_dict

    # Update the model name to reflect the original model
    JsonPatchDocument.__name__ = f"{model_class.__name__}PatchSchema"

    return JsonPatchDocument
|
|
|
|
|
|
def apply_replace(target: dict, path_parts: List[str], value: Any) -> dict:
    """Apply a 'replace' operation.

    Args:
        target: Document to patch. It is deep copied, so neither it nor any
            container nested inside it is modified.
        path_parts: Reference tokens of the location to overwrite. An empty
            list replaces the whole document.
        value: The new value.

    Returns:
        The patched copy (or ``value`` itself when ``path_parts`` is empty).

    Raises:
        ValueError: If any token does not resolve to an existing dict key or
            in-range list index.
    """
    import copy  # function-scope import keeps this block self-contained

    if not path_parts:
        # Replace entire document
        return value

    result = copy.deepcopy(target)

    current = result
    for depth, part in enumerate(path_parts[:-1], start=1):
        if isinstance(current, list):
            # Lists are addressed by position; ``part in current`` would test
            # element values instead. isdecimal() guarantees int() can parse.
            if not (part.isdecimal() and int(part) < len(current)):
                raise ValueError(f"Path not found: {'/'.join(path_parts[:depth])}")
            current = current[int(part)]
        elif isinstance(current, dict):
            if part not in current:
                raise ValueError(f"Path not found: {'/'.join(path_parts[:depth])}")
            current = current[part]
        else:
            # A scalar has no children to descend into.
            raise ValueError(f"Path not found: {'/'.join(path_parts[:depth])}")

    last_part = path_parts[-1]
    if isinstance(current, list):
        if last_part == "-":
            # Lenient: '-' (one past the end) is accepted and appends.
            current.append(value)
        elif last_part.isdecimal() and int(last_part) < len(current):
            current[int(last_part)] = value
        else:
            raise ValueError(f"Path not found: {'/'.join(path_parts)}")
    elif isinstance(current, dict) and last_part in current:
        current[last_part] = value
    else:
        raise ValueError(f"Path not found: {'/'.join(path_parts)}")

    return result
|
|
|
|
|
|
def apply_add(target: dict, path_parts: List[str], value: Any) -> dict:
    """Apply an 'add' operation.

    Missing intermediate dict keys are created on the way down: as a list
    when the following token is a digit string or '-', otherwise as a dict.
    Existing list elements on the path are traversed by index and never
    overwritten.

    Args:
        target: Document to patch. It is deep copied, so neither it nor any
            container nested inside it is modified.
        path_parts: Reference tokens of the location to add at. An empty
            list replaces the whole document.
        value: The value to insert.

    Returns:
        The patched copy (or ``value`` itself when ``path_parts`` is empty).

    Raises:
        ValueError: If a list index on the path is missing or malformed, if
            the insertion index is past the end of the list, or if the path
            runs through a scalar.
    """
    import copy  # function-scope import keeps this block self-contained

    if not path_parts:
        # Replace entire document
        return value

    result = copy.deepcopy(target)

    current = result
    for depth, part in enumerate(path_parts[:-1], start=1):
        if isinstance(current, list):
            # Address list elements by position, not by value membership.
            if not (part.isdecimal() and int(part) < len(current)):
                raise ValueError(f"Path not found: {'/'.join(path_parts[:depth])}")
            current = current[int(part)]
        elif isinstance(current, dict):
            if part not in current:
                # path_parts[depth] is the token that follows ``part``.
                upcoming = path_parts[depth]
                current[part] = [] if (upcoming.isdigit() or upcoming == "-") else {}
            current = current[part]
        else:
            raise ValueError(f"Path not found: {'/'.join(path_parts[:depth])}")

    last_part = path_parts[-1]
    if isinstance(current, list):
        if last_part == "-":
            current.append(value)
        elif last_part.isdecimal():
            idx = int(last_part)
            # idx == len(current) is allowed and appends.
            if idx > len(current):
                raise ValueError(f"Index out of bounds: {idx}")
            current.insert(idx, value)
        else:
            raise ValueError(f"Path not found: {'/'.join(path_parts)}")
    elif isinstance(current, dict):
        current[last_part] = value
    else:
        raise ValueError(f"Path not found: {'/'.join(path_parts)}")

    return result
|
|
|
|
|
|
def apply_remove(target: dict, path_parts: List[str]) -> dict:
    """Apply a 'remove' operation.

    Args:
        target: Document to patch. It is deep copied, so neither it nor any
            container nested inside it is modified.
        path_parts: Reference tokens of the location to delete.

    Returns:
        The patched copy.

    Raises:
        ValueError: If ``path_parts`` is empty (the root cannot be removed),
            if any token does not resolve, or if the final list index is out
            of bounds.
    """
    import copy  # function-scope import keeps this block self-contained

    if not path_parts:
        raise ValueError("Cannot remove root document")

    result = copy.deepcopy(target)

    current = result
    for depth, part in enumerate(path_parts[:-1], start=1):
        if isinstance(current, list):
            # Address list elements by position, not by value membership.
            if not (part.isdecimal() and int(part) < len(current)):
                raise ValueError(f"Path not found: {'/'.join(path_parts[:depth])}")
            current = current[int(part)]
        elif isinstance(current, dict):
            if part not in current:
                raise ValueError(f"Path not found: {'/'.join(path_parts[:depth])}")
            current = current[part]
        else:
            raise ValueError(f"Path not found: {'/'.join(path_parts[:depth])}")

    last_part = path_parts[-1]
    if isinstance(current, list):
        if not last_part.isdecimal():
            raise ValueError(f"Path not found: {'/'.join(path_parts)}")
        idx = int(last_part)
        if idx >= len(current):
            raise ValueError(f"Index out of bounds: {idx}")
        del current[idx]
    elif isinstance(current, dict) and last_part in current:
        del current[last_part]
    else:
        raise ValueError(f"Path not found: {'/'.join(path_parts)}")

    return result
|
|
|
|
|
|
def apply_move(target: dict, from_parts: List[str], path_parts: List[str]) -> dict:
    """Apply a 'move' operation.

    The value at ``from_parts`` is read, removed, and then added at
    ``path_parts``. All work happens on a deep copy of ``target``, so the
    result shares no containers with the caller's document.

    Args:
        target: Document to patch; left unmodified.
        from_parts: Reference tokens of the location to move from.
        path_parts: Reference tokens of the location to move to.

    Returns:
        The patched copy.

    Raises:
        ValueError: If ``from_parts`` is a proper prefix of ``path_parts``
            (RFC 6902 forbids moving a value into one of its own children),
            or if the remove/add steps reject the paths.
    """
    import copy  # function-scope import keeps this block self-contained

    prefix_len = len(from_parts)
    if prefix_len < len(path_parts) and path_parts[:prefix_len] == from_parts:
        raise ValueError(
            f"Cannot move '/{'/'.join(from_parts)}' into its own child '/{'/'.join(path_parts)}'"
        )

    working = copy.deepcopy(target)

    # Read first: once the source is removed the value is no longer reachable.
    value = get_value(working, from_parts)
    working = apply_remove(working, from_parts)
    return apply_add(working, path_parts, value)
|
|
|
|
|
|
def apply_copy(target: dict, from_parts: List[str], path_parts: List[str]) -> dict:
    """Apply a 'copy' operation.

    The value at ``from_parts`` is duplicated and added at ``path_parts``.
    The duplicate is a deep copy, so a later change to one location cannot
    show up at the other. All work happens on a deep copy of ``target``.

    Args:
        target: Document to patch; left unmodified.
        from_parts: Reference tokens of the location to copy from.
        path_parts: Reference tokens of the location to copy to.

    Returns:
        The patched copy.

    Raises:
        ValueError: If the source path does not resolve or the add step
            rejects the destination.
    """
    from copy import deepcopy  # function-scope import keeps this block self-contained

    working = deepcopy(target)

    # Inserting the looked-up object itself would leave source and
    # destination pointing at one shared container.
    duplicate = deepcopy(get_value(working, from_parts))
    return apply_add(working, path_parts, duplicate)
|
|
|
|
|
|
def apply_test(target: dict, path_parts: List[str], value: Any) -> dict:
    """Apply a 'test' operation.

    Looks up the value at ``path_parts`` and compares it with ``value``.

    Args:
        target: Document to inspect.
        path_parts: Reference tokens of the location to check.
        value: The value expected at that location.

    Returns:
        ``target`` itself, unchanged, when the values are equal.

    Raises:
        ValueError: If the values differ (or the lookup itself fails).
    """
    found = get_value(target, path_parts)

    mismatch = found != value
    if mismatch:
        raise ValueError(f"Test failed: {'/'.join(path_parts)} does not match expected value")

    # A passing test leaves the document as it was.
    return target
|
|
|
|
|
|
def get_value(target: dict, path_parts: List[str]) -> Any:
    """Get a value at a specific path.

    Dicts are addressed by key and lists by decimal index.

    Args:
        target: Document to read from.
        path_parts: Reference tokens of the location. An empty list selects
            the whole document.

    Returns:
        The object found at the location (not a copy).

    Raises:
        ValueError: If a token names a missing key, a malformed or
            out-of-range list index, or tries to descend into a scalar.
    """
    current = target
    for depth, part in enumerate(path_parts, start=1):
        if isinstance(current, list):
            # ``part in current`` would test element values, not positions.
            # isdecimal() guarantees int() can parse the token.
            if part.isdecimal() and int(part) < len(current):
                current = current[int(part)]
                continue
        elif isinstance(current, dict):
            if part in current:
                current = current[part]
                continue
        raise ValueError(f"Path not found: {'/'.join(path_parts[:depth])}")

    return current
|