# morphik-core/utils/patching.py
# 2025-04-24 03:23:11 -04:00
#
# 393 lines
# 13 KiB
# Python

from __future__ import annotations

import copy
import json
from enum import Enum
from typing import Any, Dict, List, Optional, Type, Union

from pydantic import BaseModel, Field
class PatchOp(str, Enum):
    """Names of the operations a JSON Patch document may contain (RFC 6902).

    Every member is also a ``str``, so it compares equal to the raw
    operation name (``PatchOp.ADD == "add"``).
    """

    # Insert a value at the target location.
    ADD = "add"
    # Delete the value at the target location.
    REMOVE = "remove"
    # Overwrite the value at the target location.
    REPLACE = "replace"
    # Relocate a value from one location to another.
    MOVE = "move"
    # Duplicate a value from one location to another.
    COPY = "copy"
    # Assert that the target location holds a given value.
    TEST = "test"
def gather_paths(schema: Dict[str, Any], base_pointer: str = "") -> Dict[str, Dict[str, Any]]:
    """
    Walk a JSON Schema and map every reachable pointer string -> subschema.

    For objects: every entry of 'properties' gets the pointer
    ``<base>/<name>``; nested objects are walked recursively. A non-empty
    dict under 'additionalProperties' is recorded as ``<base>/*``.
    For arrays: the 'items' schema is recorded twice, as ``<base>/-``
    (append position) and as ``<base>/0`` (representative element), and is
    walked recursively under the ``/0`` pointer.

    A schema reached with a non-empty pointer that has no 'type' key is
    reported with ``"type": "object"`` added. The addition is made on a
    copy, so the schema passed in by the caller is never modified.

    Args:
        schema: The JSON Schema (or sub-schema) to walk.
        base_pointer: Pointer prefix of ``schema``; "" for the root, which
            itself gets no entry in the result.

    Returns:
        A dict mapping pointer strings to the sub-schema found there.
    """
    paths: Dict[str, Dict[str, Any]] = {}

    if base_pointer:
        # Default the type BEFORE recording the schema so the recorded entry
        # and the walk below see the same typed view.
        schema = _with_object_type(schema)
        paths[base_pointer] = schema

    schema_type = schema.get("type")

    if schema_type == "object":
        for prop, prop_schema in schema.get("properties", {}).items():
            child_ptr = f"{base_pointer}/{prop}"
            paths[child_ptr] = prop_schema
            prop_type = prop_schema.get("type")
            if prop_type == "object":
                paths.update(gather_paths(prop_schema, child_ptr))
            elif prop_type == "array" and "items" in prop_schema:
                items = _with_object_type(prop_schema["items"])
                # '-' addresses the position just past the end of the array.
                paths[f"{child_ptr}/-"] = items
                paths.update(gather_paths(items, f"{child_ptr}/0"))

        add_props = schema.get("additionalProperties")
        if add_props and isinstance(add_props, dict):
            paths[f"{base_pointer}/*"] = add_props

    elif schema_type == "array" and "items" in schema:
        items = _with_object_type(schema["items"])
        paths[f"{base_pointer}/-"] = items
        paths.update(gather_paths(items, f"{base_pointer}/0"))

    return paths


def _with_object_type(schema: Dict[str, Any]) -> Dict[str, Any]:
    """Return ``schema`` itself if it has a 'type', else a copy with ``"type": "object"`` added."""
    if "type" in schema:
        return schema
    return {**schema, "type": "object"}
def op_schema(op: str, pointer: str, value_schema: dict | None) -> dict:
"""
Build a schema fragment for one operation / one pointer.
"""
# Ensure value_schema always has a type
if value_schema and "type" not in value_schema:
value_schema = {"type": "object", **value_schema}
# Base schema for the operation
base = {
"type": "object",
"required": ["op", "path"],
"properties": {"op": {"type": "string", "enum": [op]}, "path": {"type": "string", "const": pointer}},
"additionalProperties": False,
}
# Add value property for operations that require it
if op in {"add", "replace", "test"}:
base["required"].append("value")
if value_schema:
base["properties"]["value"] = value_schema
else:
base["properties"]["value"] = {"type": "null"}
# Add from property for operations that require it
if op in {"move", "copy"}:
base["required"].append("from")
base["properties"]["from"] = {"type": "string"}
return base
def build_patch_schema(model_schema: dict) -> dict:
    """
    Build a JSON Schema that validates JSON Patch operations for the given model schema.

    For every pointer reported by ``gather_paths`` four operation variants
    are generated: "add", "replace", "remove" and "test". No variants are
    generated for "move" or "copy". The variants are combined with ``anyOf``
    as the item schema of an ``operations`` array, which is wrapped in an
    object (an object root is what OpenAI's API expects).

    ``model_schema`` is never modified: the walk runs on a private deep copy
    and any inferred 'type' is written to a copy of the sub-schema.

    Args:
        model_schema: JSON Schema of the model the patches will target.

    Returns:
        The patch-document JSON Schema as a dict.
    """
    # Work on a private copy so nothing downstream can write into the
    # caller's schema.
    paths = gather_paths(copy.deepcopy(model_schema))

    variants = []
    for ptr, subschema in paths.items():
        if not subschema or not isinstance(subschema, dict):
            # Empty or non-dict schema: fall back to a bare object schema.
            subschema = {"type": "object"}
        elif "type" not in subschema:
            # Infer a type from structural keywords; default to string.
            # NOTE(review): sub-schemas that only carry "$ref"/"anyOf" also land
            # on "string" here - confirm that is intended.
            if "properties" in subschema:
                inferred = "object"
            elif "items" in subschema:
                inferred = "array"
            else:
                inferred = "string"
            subschema = {**subschema, "type": inferred}

        variants.append(op_schema("add", ptr, subschema))
        variants.append(op_schema("replace", ptr, subschema))
        variants.append(op_schema("remove", ptr, None))
        variants.append(op_schema("test", ptr, subschema))
        # move/copy would need cross-path validation and are not generated.

    return {
        "$schema": "https://json-schema.org/draft/2020-12/schema",
        "type": "object",
        "properties": {"operations": {"type": "array", "items": {"anyOf": variants}}},
        "required": ["operations"],
    }
def MakePatchSchema(model_class: Type[BaseModel]) -> Type[BaseModel]:
    """
    Create a Pydantic model that represents valid JSON Patch operations for the given model.

    The returned class holds a list of operations under ``operations``,
    reports the path-aware schema from ``build_patch_schema`` as its JSON
    Schema, and can apply itself to a document via ``apply``.

    Args:
        model_class: The Pydantic model to create a patch schema for

    Returns:
        A new Pydantic model class that validates JSON Patch operations,
        named ``<ModelName>PatchSchema``.
    """
    patch_schema = build_patch_schema(model_class.model_json_schema())

    def _split_pointer(pointer: str) -> List[str]:
        """Split a pointer into path segments; [] addresses the whole document."""
        # "" is the root per RFC 6901; "/" is also accepted as the root for
        # backward compatibility with existing callers.
        # NOTE: "~0"/"~1" escape sequences are not decoded here.
        if pointer in ("", "/"):
            return []
        return pointer.strip("/").split("/")

    class PatchOperation(BaseModel):
        # Operation name, target pointer, payload, and source pointer ("from").
        op: PatchOp
        path: str
        value: Optional[Any] = None
        from_: Optional[str] = Field(None, alias="from")
        model_config = {"populate_by_name": True, "extra": "forbid"}

    class JsonPatchDocument(BaseModel):
        # The patch: an array of operations wrapped in an object.
        operations: List[PatchOperation] = Field(default_factory=list)
        model_config = {
            "populate_by_name": True,
            "extra": "forbid",
            "json_schema_extra": lambda schema: schema.update(patch_schema),
        }

        def __iter__(self):
            return iter(self.operations)

        def __getitem__(self, item):
            return self.operations[item]

        @classmethod
        def parse_obj(cls, obj):
            """Parse a JSON string/bytes or a dict with an "operations" key into a patch document."""
            if isinstance(obj, (str, bytes)):
                obj = json.loads(obj)
            if isinstance(obj, dict) and "operations" in obj:
                return cls(operations=obj["operations"])
            raise ValueError(
                "Input must be an object with an 'operations' field containing an array of patch operations"
            )

        @classmethod
        def model_json_schema(cls, *args, **kwargs):
            """Return the path-aware patch JSON Schema as a dict (arguments are ignored)."""
            return patch_schema

        def apply(self, target: Union[str, dict, BaseModel]) -> dict:
            """
            Apply the patch to a target, which can be a JSON string, dict, or Pydantic model.

            The target itself is never modified; operations run on a copy.

            Args:
                target: The target to apply the patch to

            Returns:
                The patched data as a dict

            Raises:
                ValueError: If a "move"/"copy" operation has no "from", or
                    if an operation cannot be applied.
            """
            if isinstance(target, (str, bytes, bytearray)):
                target_dict = json.loads(target)
            elif isinstance(target, BaseModel):
                # NOTE(review): exclude_unset drops fields left at their
                # defaults, so they cannot be replaced/removed - confirm intended.
                target_dict = target.model_dump(exclude_unset=True)
            else:
                # Deep copy so nested containers of the caller's data are not
                # shared with the working document.
                target_dict = copy.deepcopy(dict(target))

            for op in self.operations:
                path_parts = _split_pointer(op.path)

                if op.op in (PatchOp.MOVE, PatchOp.COPY):
                    if op.from_ is None:
                        raise ValueError(f"'{op.op.value}' operation requires a 'from' location")
                    from_parts = _split_pointer(op.from_)
                    if op.op == PatchOp.MOVE:
                        target_dict = apply_move(target_dict, from_parts, path_parts)
                    else:
                        target_dict = apply_copy(target_dict, from_parts, path_parts)
                elif op.op == PatchOp.REPLACE:
                    target_dict = apply_replace(target_dict, path_parts, op.value)
                elif op.op == PatchOp.ADD:
                    target_dict = apply_add(target_dict, path_parts, op.value)
                elif op.op == PatchOp.REMOVE:
                    target_dict = apply_remove(target_dict, path_parts)
                elif op.op == PatchOp.TEST:
                    target_dict = apply_test(target_dict, path_parts, op.value)

            return target_dict

    # Name the generated class after the model it patches.
    JsonPatchDocument.__name__ = f"{model_class.__name__}PatchSchema"
    return JsonPatchDocument
def apply_replace(target: dict, path_parts: List[str], value: Any) -> dict:
    """Apply a 'replace' operation.

    Args:
        target: Document to patch; it is not modified.
        path_parts: Path segments of the location to overwrite. An empty
            list addresses the whole document.
        value: The replacement value.

    Returns:
        A patched deep copy of ``target``, or ``value`` itself when the
        whole document is replaced.

    Raises:
        ValueError: If any segment of the path does not exist.
    """
    if not path_parts:
        # Replace entire document
        return value

    # A shallow copy would still share nested containers with the caller's
    # document, so writes below would leak into it.
    result = copy.deepcopy(target)
    current: Any = result
    last_depth = len(path_parts) - 1

    for depth, part in enumerate(path_parts):
        location = "/".join(path_parts[: depth + 1])
        key: Any
        if isinstance(current, list):
            # Lists are addressed by position; `in` would test element values.
            if not (part.isascii() and part.isdigit()) or int(part) >= len(current):
                raise ValueError(f"Path not found: {location}")
            key = int(part)
        elif isinstance(current, dict):
            if part not in current:
                raise ValueError(f"Path not found: {location}")
            key = part
        else:
            # Scalars have no children.
            raise ValueError(f"Path not found: {location}")

        if depth == last_depth:
            current[key] = value
        else:
            current = current[key]

    return result
def apply_add(target: dict, path_parts: List[str], value: Any) -> dict:
    """Apply an 'add' operation.

    Missing intermediate object members are created on the way: as a list
    when the following segment is numeric or "-", otherwise as a dict.
    Existing list elements on the path are traversed by index and left
    intact.

    Args:
        target: Document to patch; it is not modified.
        path_parts: Path segments of the location to add at. An empty list
            addresses the whole document. For a list, the last segment is
            either an insertion index or "-" (append).
        value: The value to add.

    Returns:
        A patched deep copy of ``target``, or ``value`` itself when the
        whole document is addressed.

    Raises:
        ValueError: If the path runs through a missing list element or a
            scalar, or if the insertion index is out of bounds.
    """
    if not path_parts:
        # Replace entire document
        return value

    # Deep copy: nested containers must not be shared with the caller's document.
    result = copy.deepcopy(target)
    current: Any = result

    for depth, part in enumerate(path_parts[:-1]):
        location = "/".join(path_parts[: depth + 1])
        if isinstance(current, list):
            # Index lookup, not value membership - an existing element must
            # be descended into, never overwritten.
            if not (part.isascii() and part.isdigit()) or int(part) >= len(current):
                raise ValueError(f"Path not found: {location}")
            current = current[int(part)]
        elif isinstance(current, dict):
            if part not in current:
                next_part = path_parts[depth + 1]
                current[part] = [] if (next_part.isdigit() or next_part == "-") else {}
            current = current[part]
        else:
            raise ValueError(f"Path not found: {location}")

    last_part = path_parts[-1]
    if isinstance(current, list):
        if last_part == "-":
            current.append(value)
        elif last_part.isascii() and last_part.isdigit():
            idx = int(last_part)
            # idx == len(current) is allowed: it appends.
            if idx > len(current):
                raise ValueError(f"Index out of bounds: {idx}")
            current.insert(idx, value)
        else:
            raise ValueError(f"Path not found: {'/'.join(path_parts)}")
    elif isinstance(current, dict):
        current[last_part] = value
    else:
        raise ValueError(f"Path not found: {'/'.join(path_parts)}")

    return result
def apply_remove(target: dict, path_parts: List[str]) -> dict:
    """Apply a 'remove' operation.

    Args:
        target: Document to patch; it is not modified.
        path_parts: Path segments of the location to delete. Must not be
            empty.

    Returns:
        A patched deep copy of ``target``.

    Raises:
        ValueError: If the path is empty (the root cannot be removed), if
            any segment does not exist, or if a list index is out of bounds.
    """
    if not path_parts:
        raise ValueError("Cannot remove root document")

    # Deep copy: nested containers must not be shared with the caller's document.
    result = copy.deepcopy(target)
    current: Any = result

    for depth, part in enumerate(path_parts[:-1]):
        location = "/".join(path_parts[: depth + 1])
        if isinstance(current, list):
            # Lists are addressed by position; `in` would test element values.
            if not (part.isascii() and part.isdigit()) or int(part) >= len(current):
                raise ValueError(f"Path not found: {location}")
            current = current[int(part)]
        elif isinstance(current, dict) and part in current:
            current = current[part]
        else:
            raise ValueError(f"Path not found: {location}")

    last_part = path_parts[-1]
    full_path = "/".join(path_parts)
    if isinstance(current, list):
        if not (last_part.isascii() and last_part.isdigit()):
            raise ValueError(f"Path not found: {full_path}")
        idx = int(last_part)
        if idx >= len(current):
            raise ValueError(f"Index out of bounds: {idx}")
        del current[idx]
    elif isinstance(current, dict) and last_part in current:
        del current[last_part]
    else:
        raise ValueError(f"Path not found: {full_path}")

    return result
def apply_move(target: dict, from_parts: List[str], path_parts: List[str]) -> dict:
    """Apply a 'move' operation.

    The value at ``from_parts`` is removed and then added at ``path_parts``.

    Args:
        target: Document to patch; it is not modified.
        from_parts: Path segments of the source location.
        path_parts: Path segments of the destination location.

    Returns:
        The patched document.

    Raises:
        ValueError: If the source is a proper prefix of the destination
            (a location cannot be moved into one of its own children, per
            RFC 6902 section 4.4), or if the source does not exist.
    """
    if len(from_parts) < len(path_parts) and path_parts[: len(from_parts)] == from_parts:
        raise ValueError(
            f"Cannot move {'/'.join(from_parts)} into its own child {'/'.join(path_parts)}"
        )

    # Deep copy once up front so no step below can touch nested containers
    # of the caller's document.
    working = copy.deepcopy(target)
    value = get_value(working, from_parts)
    working = apply_remove(working, from_parts)
    return apply_add(working, path_parts, value)
def apply_copy(target: dict, from_parts: List[str], path_parts: List[str]) -> dict:
    """Apply a 'copy' operation.

    Args:
        target: Document to patch; it is not modified.
        from_parts: Path segments of the source location.
        path_parts: Path segments of the destination location.

    Returns:
        The patched document, holding an independent copy of the source
        value at the destination.

    Raises:
        ValueError: If the source location does not exist.
    """
    # Deep copy so nested containers of the caller's document are never written to.
    working = copy.deepcopy(target)
    # Copy the value too: otherwise source and destination would share one
    # object and a later change to either would show up in both.
    value = copy.deepcopy(get_value(working, from_parts))
    return apply_add(working, path_parts, value)
def apply_test(target: dict, path_parts: List[str], value: Any) -> dict:
    """Apply a 'test' operation.

    Looks up the value at ``path_parts`` and compares it with ``value``.
    Returns ``target`` unchanged when they match and raises ``ValueError``
    when they differ.
    """
    mismatch = get_value(target, path_parts) != value
    if mismatch:
        raise ValueError(f"Test failed: {'/'.join(path_parts)} does not match expected value")
    return target
def get_value(target: dict, path_parts: List[str]) -> Any:
    """Get a value at a specific path.

    Args:
        target: Document to read from.
        path_parts: Path segments to follow. Segments applied to a list are
            interpreted as indices; segments applied to a dict are keys. An
            empty list addresses the whole document.

    Returns:
        The value found at the path (the object itself, not a copy).

    Raises:
        ValueError: If a segment names a missing key, an invalid or
            out-of-range list index, or a child of a scalar.
    """
    current: Any = target
    for depth, part in enumerate(path_parts):
        location = "/".join(path_parts[: depth + 1])
        if isinstance(current, list):
            # Lists are addressed by position; `in` would test element values.
            if not (part.isascii() and part.isdigit()) or int(part) >= len(current):
                raise ValueError(f"Path not found: {location}")
            current = current[int(part)]
        elif isinstance(current, dict):
            if part not in current:
                raise ValueError(f"Path not found: {location}")
            current = current[part]
        else:
            # Scalars have no children.
            raise ValueError(f"Path not found: {location}")
    return current