diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/utils/patching.py b/utils/patching.py
new file mode 100644
index 0000000..e019540
--- /dev/null
+++ b/utils/patching.py
@@ -0,0 +1,392 @@
+from __future__ import annotations
+
+import copy
+import json
+from enum import Enum
+from typing import Any, Dict, List, Optional, Type, Union
+
+from pydantic import BaseModel, Field
+
+
+class PatchOp(str, Enum):
+    """JSON Patch operation types as defined in RFC 6902."""
+
+    ADD = "add"
+    REMOVE = "remove"
+    REPLACE = "replace"
+    MOVE = "move"
+    COPY = "copy"
+    TEST = "test"
+
+
+def gather_paths(schema: Dict[str, Any], base_pointer: str = "") -> Dict[str, Dict[str, Any]]:
+    """
+    Walk a JSON Schema and map every reachable JSON Pointer -> subschema.
+    For objects: recurse into 'properties'.
+    For arrays: expose the items schema at '/-' (append) and '/0' (first element).
+    For additionalProperties: expose the value schema at '/*'.
+    """
+    paths = {}
+
+    # Add current path
+    if base_pointer:
+        paths[base_pointer] = schema
+
+    # Handle base schema type
+    if "type" not in schema and base_pointer:
+        # Ensure there's always a type
+        schema["type"] = "object"
+
+    # Handle object types
+    if schema.get("type") == "object":
+        # Add properties
+        for prop, prop_schema in schema.get("properties", {}).items():
+            child_ptr = f"{base_pointer}/{prop}"
+            paths[child_ptr] = prop_schema
+            # Recursively process nested objects
+            if prop_schema.get("type") == "object":
+                paths.update(gather_paths(prop_schema, child_ptr))
+            # Handle array items
+            elif prop_schema.get("type") == "array" and "items" in prop_schema:
+                array_path = f"{child_ptr}/-"  # '-' represents the end of an array
+                paths[array_path] = prop_schema["items"]
+                # Process array item indexing
+                paths.update(gather_paths(prop_schema["items"], f"{child_ptr}/0"))
+
+        # Handle additionalProperties if present
+        if schema.get("additionalProperties"):
+            add_props = schema["additionalProperties"]
+            if isinstance(add_props, dict):
+                wild_path = f"{base_pointer}/*"
+                paths[wild_path] = add_props
+
+    # Handle array types
+    elif schema.get("type") == "array" and "items" in schema:
+        array_path = f"{base_pointer}/-"
+        paths[array_path] = schema["items"]
+        # Process array item indexing
+        paths.update(gather_paths(schema["items"], f"{base_pointer}/0"))
+
+    return paths
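+
+
+# Example (sketch): for a schema such as
+#   {"type": "object", "properties": {"name": {"type": "string"},
+#                                     "tags": {"type": "array", "items": {"type": "string"}}}}
+# gather_paths should yield pointers along the lines of
+#   {"/name": ..., "/tags": ..., "/tags/-": ..., "/tags/0": ...}
+# i.e. one entry per property, plus '/-' (append) and '/0' entries for array items.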
+
+
+def op_schema(op: str, pointer: str, value_schema: dict | None) -> dict:
+    """
+    Build a schema fragment for one operation / one pointer.
+    """
+    # Ensure value_schema always has a type
+    if value_schema and "type" not in value_schema:
+        value_schema = {"type": "object", **value_schema}
+
+    # Base schema for the operation
+    base = {
+        "type": "object",
+        "required": ["op", "path"],
+        "properties": {"op": {"type": "string", "enum": [op]}, "path": {"type": "string", "const": pointer}},
+        "additionalProperties": False,
+    }
+
+    # Add value property for operations that require it
+    if op in {"add", "replace", "test"}:
+        base["required"].append("value")
+        if value_schema:
+            base["properties"]["value"] = value_schema
+        else:
+            base["properties"]["value"] = {"type": "null"}
+
+    # Add from property for operations that require it
+    if op in {"move", "copy"}:
+        base["required"].append("from")
+        base["properties"]["from"] = {"type": "string"}
+
+    return base
+
+
+def build_patch_schema(model_schema: dict) -> dict:
+    """
+    Build a JSON Schema that validates JSON Patch operations for the given model schema.
+    """
+    # Gather all paths in the schema
+    paths = gather_paths(model_schema)
+
+    # Generate operation schemas for each path
+    variants = []
+    for ptr, subschema in paths.items():
+        # Handle empty or invalid schemas
+        if not subschema or not isinstance(subschema, dict):
+            subschema = {"type": "object"}
+
+        # Ensure the subschema has a type
+        if "type" not in subschema:
+            # Try to infer type from other properties
+            if "properties" in subschema:
+                subschema["type"] = "object"
+            elif "items" in subschema:
+                subschema["type"] = "array"
+            else:
+                # Default to string if can't determine
+                subschema["type"] = "string"
+
+        # Generate variants for different operations
+        variants.append(op_schema("add", ptr, subschema))
+        variants.append(op_schema("replace", ptr, subschema))
+        variants.append(op_schema("remove", ptr, None))
+        variants.append(op_schema("test", ptr, subschema))
+
+        # For move and copy operations, we'd need to do cross-path validation
+        # For simplicity, we'll skip detailed validation of those here
+
+    # Final patch schema - wrap the array in an object to work with OpenAI's API
+    patch_schema = {
+        "$schema": "https://json-schema.org/draft/2020-12/schema",
+        "type": "object",
+        "properties": {"operations": {"type": "array", "items": {"anyOf": variants}}},
+        "required": ["operations"],
+    }
+
+    return patch_schema
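+
+
+# Example (sketch): op_schema("replace", "/name", {"type": "string"}) should produce a
+# fragment roughly like
+#   {"type": "object", "required": ["op", "path", "value"],
+#    "properties": {"op": {"type": "string", "enum": ["replace"]},
+#                   "path": {"type": "string", "const": "/name"},
+#                   "value": {"type": "string"}},
+#    "additionalProperties": False}
+# build_patch_schema then collects one such variant per (operation, pointer) pair into an
+# anyOf under the "operations" array.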
+
+
+def MakePatchSchema(model_class: Type[BaseModel]) -> Type[BaseModel]:
+    """
+    Create a Pydantic model that represents valid JSON Patch operations for the given model.
+
+    Args:
+        model_class: The Pydantic model to create a patch schema for
+
+    Returns:
+        A new Pydantic model class that validates JSON Patch operations
+    """
+    # Get the JSON schema for the model
+    model_schema = model_class.model_json_schema()
+
+    # Build the patch schema
+    patch_schema = build_patch_schema(model_schema)
+
+    # Create a Pydantic model for a single patch operation
+    class PatchOperation(BaseModel):
+        op: PatchOp
+        path: str
+        value: Optional[Any] = None
+        from_: Optional[str] = Field(None, alias="from")
+
+        model_config = {"populate_by_name": True, "extra": "forbid"}
+
+    # Create a model for the patch (array of operations wrapped in an object)
+    class JsonPatchDocument(BaseModel):
+        operations: List[PatchOperation] = Field(default_factory=list)
+
+        model_config = {
+            "populate_by_name": True,
+            "extra": "forbid",
+            "json_schema_extra": lambda schema: schema.update(patch_schema),
+        }
+
+        def __iter__(self):
+            return iter(self.operations)
+
+        def __getitem__(self, item):
+            return self.operations[item]
+
+        @classmethod
+        def parse_obj(cls, obj):
+            """Parse a JSON object into a patch document."""
+            if isinstance(obj, (str, bytes)):
+                obj = json.loads(obj)
+
+            # Expect an object with an "operations" key
+            if isinstance(obj, dict) and "operations" in obj:
+                return cls(operations=obj["operations"])
+            else:
+                raise ValueError(
+                    "Input must be an object with an 'operations' field containing an array of patch operations"
+                )
+
+        def model_json_schema(cls, *args, **kwargs):
+            """Return the JSON Schema as a dict."""
+            return patch_schema
+
+        def apply(self, target: Union[str, dict, BaseModel]) -> dict:
+            """
+            Apply the patch to a target, which can be a JSON string, dict, or Pydantic model.
+
+            Args:
+                target: The target to apply the patch to
+
+            Returns:
+                The patched data as a dict
+            """
+            # Convert the target to a dict
+            if isinstance(target, str):
+                target_dict = json.loads(target)
+            elif isinstance(target, BaseModel):
+                target_dict = target.model_dump(exclude_unset=True)
+            else:
+                target_dict = dict(target)
+
+            # Apply operations in sequence
+            for op in self.operations:
+                path_parts = op.path.strip("/").split("/") if op.path != "/" else []
+
+                if op.op == PatchOp.REPLACE:
+                    target_dict = apply_replace(target_dict, path_parts, op.value)
+                elif op.op == PatchOp.ADD:
+                    target_dict = apply_add(target_dict, path_parts, op.value)
+                elif op.op == PatchOp.REMOVE:
+                    target_dict = apply_remove(target_dict, path_parts)
+                elif op.op == PatchOp.MOVE:
+                    from_parts = op.from_.strip("/").split("/") if op.from_ else []
+                    target_dict = apply_move(target_dict, from_parts, path_parts)
+                elif op.op == PatchOp.COPY:
+                    from_parts = op.from_.strip("/").split("/") if op.from_ else []
+                    target_dict = apply_copy(target_dict, from_parts, path_parts)
+                elif op.op == PatchOp.TEST:
+                    target_dict = apply_test(target_dict, path_parts, op.value)
+
+            return target_dict
+
+    # Update the model name to reflect the original model
+    patch_model_name = f"{model_class.__name__}PatchSchema"
+    JsonPatchDocument.__name__ = patch_model_name
+
+    # Override the model_json_schema method to return our custom schema
+    JsonPatchDocument.model_json_schema = classmethod(lambda cls, *args, **kwargs: patch_schema)
+
+    return JsonPatchDocument
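+
+
+# Example (sketch): given PatchDoc = MakePatchSchema(SomeModel) for a model with "name"
+# and "age" fields, a patch document can be built and applied roughly like
+#   doc = PatchDoc(operations=[
+#       {"op": "replace", "path": "/name", "value": "Ada"},
+#       {"op": "remove", "path": "/age"},
+#   ])
+#   doc.apply({"name": "Arnav", "age": 30})  # -> {"name": "Ada"}
+# where SomeModel and the field names are placeholders.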
+
+
+def apply_replace(target: dict, path_parts: List[str], value: Any) -> dict:
+    """Apply a 'replace' operation."""
+    target = copy.deepcopy(target)  # Work on a copy so the original is not modified
+
+    if not path_parts:
+        # Replace entire document
+        return value
+
+    current = target
+    for i, part in enumerate(path_parts[:-1]):
+        if part.isdigit() and isinstance(current, list):
+            part = int(part)
+        # Lists are checked by index range, dicts by key membership
+        if isinstance(current, list):
+            missing = not isinstance(part, int) or part >= len(current)
+        else:
+            missing = part not in current
+        if missing:
+            raise ValueError(f"Path not found: {'/'.join(path_parts[:i+1])}")
+        current = current[part]
+
+    last_part = path_parts[-1]
+    if last_part.isdigit() and isinstance(current, list):
+        last_part = int(last_part)
+    if isinstance(current, list):
+        if not isinstance(last_part, int) or last_part >= len(current):
+            raise ValueError(f"Path not found: {'/'.join(path_parts)}")
+    elif last_part not in current:
+        raise ValueError(f"Path not found: {'/'.join(path_parts)}")
+
+    current[last_part] = value
+    return target
+
+
+def apply_add(target: dict, path_parts: List[str], value: Any) -> dict:
+    """Apply an 'add' operation."""
+    target = copy.deepcopy(target)
+
+    if not path_parts:
+        # Adding at the root replaces the entire document
+        return value
+
+    current = target
+    for i, part in enumerate(path_parts[:-1]):
+        if part.isdigit() and isinstance(current, list):
+            part = int(part)
+        # Create missing intermediate containers on dicts
+        if isinstance(current, dict) and part not in current:
+            if path_parts[i + 1].isdigit() or path_parts[i + 1] == "-":
+                current[part] = []
+            else:
+                current[part] = {}
+        current = current[part]
+
+    last_part = path_parts[-1]
+    if last_part == "-" and isinstance(current, list):
+        current.append(value)
+    elif last_part.isdigit() and isinstance(current, list):
+        idx = int(last_part)
+        if idx > len(current):
+            raise ValueError(f"Index out of bounds: {idx}")
+        current.insert(idx, value)
+    else:
+        current[last_part] = value
+
+    return target
+
+
+def apply_remove(target: dict, path_parts: List[str]) -> dict:
+    """Apply a 'remove' operation."""
+    target = copy.deepcopy(target)
+
+    if not path_parts:
+        raise ValueError("Cannot remove root document")
+
+    current = target
+    for i, part in enumerate(path_parts[:-1]):
+        if part.isdigit() and isinstance(current, list):
+            part = int(part)
+        if isinstance(current, list):
+            missing = not isinstance(part, int) or part >= len(current)
+        else:
+            missing = part not in current
+        if missing:
+            raise ValueError(f"Path not found: {'/'.join(path_parts[:i+1])}")
+        current = current[part]
+
+    last_part = path_parts[-1]
+    if last_part.isdigit() and isinstance(current, list):
+        idx = int(last_part)
+        if idx >= len(current):
+            raise ValueError(f"Index out of bounds: {idx}")
+        del current[idx]
+    elif last_part not in current:
+        raise ValueError(f"Path not found: {'/'.join(path_parts)}")
+    else:
+        del current[last_part]
+
+    return target
+
+
+def apply_move(target: dict, from_parts: List[str], path_parts: List[str]) -> dict:
+    """Apply a 'move' operation."""
+    # Get the value at the 'from' location
+    value = get_value(target, from_parts)
+
+    # Remove the value from the 'from' location
+    target = apply_remove(target, from_parts)
+
+    # Add the value at the 'path' location
+    return apply_add(target, path_parts, value)
+
+
+def apply_copy(target: dict, from_parts: List[str], path_parts: List[str]) -> dict:
+    """Apply a 'copy' operation."""
+    # Get the value at the 'from' location
+    value = get_value(target, from_parts)
+
+    # Add the value at the 'path' location
+    return apply_add(target, path_parts, value)
+
+
+def apply_test(target: dict, path_parts: List[str], value: Any) -> dict:
+    """Apply a 'test' operation."""
+    current_value = get_value(target, path_parts)
+
+    if current_value != value:
+        raise ValueError(f"Test failed: {'/'.join(path_parts)} does not match expected value")
+
+    return target
+
+
+def get_value(target: dict, path_parts: List[str]) -> Any:
+    """Get a value at a specific path."""
+    if not path_parts:
+        return target
+
+    current = target
+    for i, part in enumerate(path_parts):
+        if part.isdigit() and isinstance(current, list):
+            part = int(part)
+        if isinstance(current, list):
+            missing = not isinstance(part, int) or part >= len(current)
+        else:
+            missing = part not in current
+        if missing:
+            raise ValueError(f"Path not found: {'/'.join(path_parts[:i+1])}")
+        current = current[part]
+
+    return current
diff --git a/utils/patching_test.py b/utils/patching_test.py
new file mode 100644
index 0000000..8213c66
--- /dev/null
+++ b/utils/patching_test.py
@@ -0,0 +1,45 @@
+import json
+import os
+from typing import List as TypeList
+from typing import Optional
+
+import openai
+from dotenv import load_dotenv
+from patching import MakePatchSchema
+from pydantic import BaseModel, Field
+
+# Example usage
+if __name__ == "__main__":
+    load_dotenv(override=True)
+
+    class Person(BaseModel):
+        name: str
+        age: Optional[int] = None
+        tags: TypeList[str] = Field(default_factory=list)
+
+    # Create an instance of the model
+    person = Person(name="Arnav", age=30, tags=["developer", "python"])
+
+    # Create a patch schema for the Person model
+    PersonPatchSchema = MakePatchSchema(Person)
+    client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+
+    # Print the JSON schema for debugging
+    print(json.dumps(PersonPatchSchema.model_json_schema(), indent=2))
+
+    response = client.beta.chat.completions.parse(
+        model="gpt-4o",
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant that generates JSON Patch operations."},
+            {
+                "role": "user",
+                "content": f"""Generate a JSON Patch operation to change the name of the person to 'Ada'
+                and remove the age field. The original data is: {person.model_dump()}""",
+            },
+        ],
+        response_format=PersonPatchSchema,
+    )
+    response_json = response.choices[0].message.parsed
+    print(response_json)
+    patched_person = response_json.apply(person)
+    print(patched_person)
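+
+    # Illustrative local check: the same patch can also be written by hand and applied
+    # without any API call.
+    manual_patch = PersonPatchSchema(
+        operations=[
+            {"op": "replace", "path": "/name", "value": "Ada"},
+            {"op": "remove", "path": "/age"},
+        ]
+    )
+    print(manual_patch.apply(person))  # expected: {'name': 'Ada', 'tags': ['developer', 'python']}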