4.1 prompting guide updates: update apply-patch code (#1772)

This commit is contained in:
Noah MacCallum 2025-04-14 12:44:54 -07:00 committed by GitHub
parent 6a47d53c96
commit 9f6939a239
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 259 additions and 202 deletions

View File

@ -282,3 +282,8 @@ nm-openai:
name: "Noah MacCallum"
website: "https://x.com/noahmacca"
avatar: "https://avatars.githubusercontent.com/u/171723556"
julian-openai:
name: "Julian Lee"
website: "https://github.com/julian-openai"
avatar: "https://avatars.githubusercontent.com/u/199828632"

View File

@ -591,12 +591,30 @@
"outputs": [],
"source": [
"#!/usr/bin/env python3\n",
"import os\n",
"\n",
"\"\"\"\n",
"A self-contained **pure-Python 3.9+** utility for applying human-readable\n",
"“pseudo-diff” patch files to a collection of text files.\n",
"\"\"\"\n",
"\n",
"from __future__ import annotations\n",
"\n",
"import pathlib\n",
"from dataclasses import dataclass, field\n",
"from enum import Enum\n",
"from typing import Callable, Dict, List, NoReturn, Optional, Tuple, Union\n",
"from typing import (\n",
" Callable,\n",
" Dict,\n",
" List,\n",
" Optional,\n",
" Tuple,\n",
" Union,\n",
")\n",
"\n",
"\n",
"# --------------------------------------------------------------------------- #\n",
"# Domain objects\n",
"# --------------------------------------------------------------------------- #\n",
"class ActionType(str, Enum):\n",
" ADD = \"add\"\n",
" DELETE = \"delete\"\n",
@ -616,39 +634,19 @@
" changes: Dict[str, FileChange] = field(default_factory=dict)\n",
"\n",
"\n",
"def assemble_changes(\n",
" orig: Dict[str, Optional[str]],\n",
" updated_files: Dict[str, Optional[str]],\n",
") -> Commit:\n",
" commit = Commit()\n",
" for path, new_content in updated_files.items():\n",
" old_content = orig.get(path)\n",
" if old_content == new_content:\n",
" continue\n",
" if old_content is not None and new_content is not None:\n",
" commit.changes[path] = FileChange(\n",
" type=ActionType.UPDATE,\n",
" old_content=old_content,\n",
" new_content=new_content,\n",
" )\n",
" elif new_content is not None:\n",
" commit.changes[path] = FileChange(\n",
" type=ActionType.ADD,\n",
" new_content=new_content,\n",
" )\n",
" elif old_content is not None:\n",
" commit.changes[path] = FileChange(\n",
" type=ActionType.DELETE,\n",
" old_content=old_content,\n",
" )\n",
" else:\n",
" assert False\n",
" return commit\n",
"# --------------------------------------------------------------------------- #\n",
"# Exceptions\n",
"# --------------------------------------------------------------------------- #\n",
"class DiffError(ValueError):\n",
" \"\"\"Any problem detected while parsing or applying a patch.\"\"\"\n",
"\n",
"\n",
"# --------------------------------------------------------------------------- #\n",
"# Helper dataclasses used while parsing patches\n",
"# --------------------------------------------------------------------------- #\n",
"@dataclass\n",
"class Chunk:\n",
" orig_index: int = -1 # line index of the first line in the original file\n",
" orig_index: int = -1\n",
" del_lines: List[str] = field(default_factory=list)\n",
" ins_lines: List[str] = field(default_factory=list)\n",
"\n",
@ -666,79 +664,108 @@
" actions: Dict[str, PatchAction] = field(default_factory=dict)\n",
"\n",
"\n",
"# --------------------------------------------------------------------------- #\n",
"# Patch text parser\n",
"# --------------------------------------------------------------------------- #\n",
"@dataclass\n",
"class Parser:\n",
" current_files: Dict[str, str] = field(default_factory=dict)\n",
" lines: List[str] = field(default_factory=list)\n",
" current_files: Dict[str, str]\n",
" lines: List[str]\n",
" index: int = 0\n",
" patch: Patch = field(default_factory=Patch)\n",
" fuzz: int = 0\n",
"\n",
" # ------------- low-level helpers -------------------------------------- #\n",
" def _cur_line(self) -> str:\n",
" if self.index >= len(self.lines):\n",
" raise DiffError(\"Unexpected end of input while parsing patch\")\n",
" return self.lines[self.index]\n",
"\n",
" @staticmethod\n",
" def _norm(line: str) -> str:\n",
" \"\"\"Strip CR so comparisons work for both LF and CRLF input.\"\"\"\n",
" return line.rstrip(\"\\r\")\n",
"\n",
" # ------------- scanning convenience ----------------------------------- #\n",
" def is_done(self, prefixes: Optional[Tuple[str, ...]] = None) -> bool:\n",
" if self.index >= len(self.lines):\n",
" return True\n",
" if prefixes and self.lines[self.index].startswith(prefixes):\n",
" if (\n",
" prefixes\n",
" and len(prefixes) > 0\n",
" and self._norm(self._cur_line()).startswith(prefixes)\n",
" ):\n",
" return True\n",
" return False\n",
"\n",
" def startswith(self, prefix: Union[str, Tuple[str, ...]]) -> bool:\n",
" assert self.index < len(self.lines), f\"Index: {self.index} >= {len(self.lines)}\"\n",
" if self.lines[self.index].startswith(prefix):\n",
" return True\n",
" return False\n",
" return self._norm(self._cur_line()).startswith(prefix)\n",
"\n",
" def read_str(self, prefix: str = \"\", return_everything: bool = False) -> str:\n",
" assert self.index < len(self.lines), f\"Index: {self.index} >= {len(self.lines)}\"\n",
" if self.lines[self.index].startswith(prefix):\n",
" if return_everything:\n",
" text = self.lines[self.index]\n",
" else:\n",
" text = self.lines[self.index][len(prefix) :]\n",
" def read_str(self, prefix: str) -> str:\n",
" \"\"\"\n",
" Consume the current line if it starts with *prefix* and return the text\n",
" **after** the prefix. Raises if prefix is empty.\n",
" \"\"\"\n",
" if prefix == \"\":\n",
" raise ValueError(\"read_str() requires a non-empty prefix\")\n",
" if self._norm(self._cur_line()).startswith(prefix):\n",
" text = self._cur_line()[len(prefix) :]\n",
" self.index += 1\n",
" return text\n",
" return \"\"\n",
"\n",
" def parse(self) -> NoReturn:\n",
" def read_line(self) -> str:\n",
" \"\"\"Return the current raw line and advance.\"\"\"\n",
" line = self._cur_line()\n",
" self.index += 1\n",
" return line\n",
"\n",
" # ------------- public entry point -------------------------------------- #\n",
" def parse(self) -> None:\n",
" while not self.is_done((\"*** End Patch\",)):\n",
" # ---------- UPDATE ---------- #\n",
" path = self.read_str(\"*** Update File: \")\n",
" if path:\n",
" if path in self.patch.actions:\n",
" raise DiffError(f\"Update File Error: Duplicate Path: {path}\")\n",
" raise DiffError(f\"Duplicate update for file: {path}\")\n",
" move_to = self.read_str(\"*** Move to: \")\n",
" if path not in self.current_files:\n",
" raise DiffError(f\"Update File Error: Missing File: {path}\")\n",
" raise DiffError(f\"Update File Error - missing file: {path}\")\n",
" text = self.current_files[path]\n",
" action = self.parse_update_file(text)\n",
" action.move_path = move_to\n",
" action = self._parse_update_file(text)\n",
" action.move_path = move_to or None\n",
" self.patch.actions[path] = action\n",
" continue\n",
"\n",
" # ---------- DELETE ---------- #\n",
" path = self.read_str(\"*** Delete File: \")\n",
" if path:\n",
" if path in self.patch.actions:\n",
" raise DiffError(f\"Delete File Error: Duplicate Path: {path}\")\n",
" raise DiffError(f\"Duplicate delete for file: {path}\")\n",
" if path not in self.current_files:\n",
" raise DiffError(f\"Delete File Error: Missing File: {path}\")\n",
" self.patch.actions[path] = PatchAction(\n",
" type=ActionType.DELETE,\n",
" )\n",
" raise DiffError(f\"Delete File Error - missing file: {path}\")\n",
" self.patch.actions[path] = PatchAction(type=ActionType.DELETE)\n",
" continue\n",
"\n",
" # ---------- ADD ---------- #\n",
" path = self.read_str(\"*** Add File: \")\n",
" if path:\n",
" if path in self.patch.actions:\n",
" raise DiffError(f\"Add File Error: Duplicate Path: {path}\")\n",
" raise DiffError(f\"Duplicate add for file: {path}\")\n",
" if path in self.current_files:\n",
" raise DiffError(f\"Add File Error: File already exists: {path}\")\n",
" self.patch.actions[path] = self.parse_add_file()\n",
" raise DiffError(f\"Add File Error - file already exists: {path}\")\n",
" self.patch.actions[path] = self._parse_add_file()\n",
" continue\n",
" raise DiffError(f\"Unknown Line: {self.lines[self.index]}\")\n",
" if not self.startswith(\"*** End Patch\"):\n",
" raise DiffError(\"Missing End Patch\")\n",
" self.index += 1\n",
"\n",
" def parse_update_file(self, text: str) -> PatchAction:\n",
" action = PatchAction(\n",
" type=ActionType.UPDATE,\n",
" )\n",
" raise DiffError(f\"Unknown line while parsing: {self._cur_line()}\")\n",
"\n",
" if not self.startswith(\"*** End Patch\"):\n",
" raise DiffError(\"Missing *** End Patch sentinel\")\n",
" self.index += 1 # consume sentinel\n",
"\n",
" # ------------- section parsers ---------------------------------------- #\n",
" def _parse_update_file(self, text: str) -> PatchAction:\n",
" action = PatchAction(type=ActionType.UPDATE)\n",
" lines = text.split(\"\\n\")\n",
" index = 0\n",
" while not self.is_done(\n",
@ -752,100 +779,104 @@
" ):\n",
" def_str = self.read_str(\"@@ \")\n",
" section_str = \"\"\n",
" if not def_str:\n",
" if self.lines[self.index] == \"@@\":\n",
" section_str = self.lines[self.index]\n",
" self.index += 1\n",
" if not def_str and self._norm(self._cur_line()) == \"@@\":\n",
" section_str = self.read_line()\n",
"\n",
" if not (def_str or section_str or index == 0):\n",
" raise DiffError(f\"Invalid Line:\\n{self.lines[self.index]}\")\n",
" raise DiffError(f\"Invalid line in update section:\\n{self._cur_line()}\")\n",
"\n",
" if def_str.strip():\n",
" found = False\n",
" if not [s for s in lines[:index] if s == def_str]:\n",
" # def str is a skip ahead operator\n",
" if def_str not in lines[:index]:\n",
" for i, s in enumerate(lines[index:], index):\n",
" if s == def_str:\n",
" index = i + 1\n",
" found = True\n",
" break\n",
" if not found and not [s for s in lines[:index] if s.strip() == def_str.strip()]:\n",
" # def str is a skip ahead operator\n",
" if not found and def_str.strip() not in [\n",
" s.strip() for s in lines[:index]\n",
" ]:\n",
" for i, s in enumerate(lines[index:], index):\n",
" if s.strip() == def_str.strip():\n",
" index = i + 1\n",
" self.fuzz += 1\n",
" found = True\n",
" break\n",
" next_chunk_context, chunks, end_patch_index, eof = peek_next_section(\n",
" self.lines, self.index\n",
" )\n",
" next_chunk_text = \"\\n\".join(next_chunk_context)\n",
" new_index, fuzz = find_context(lines, next_chunk_context, index, eof)\n",
"\n",
" next_ctx, chunks, end_idx, eof = peek_next_section(self.lines, self.index)\n",
" new_index, fuzz = find_context(lines, next_ctx, index, eof)\n",
" if new_index == -1:\n",
" if eof:\n",
" raise DiffError(f\"Invalid EOF Context {index}:\\n{next_chunk_text}\")\n",
" else:\n",
" raise DiffError(f\"Invalid Context {index}:\\n{next_chunk_text}\")\n",
" ctx_txt = \"\\n\".join(next_ctx)\n",
" raise DiffError(\n",
" f\"Invalid {'EOF ' if eof else ''}context at {index}:\\n{ctx_txt}\"\n",
" )\n",
" self.fuzz += fuzz\n",
" for ch in chunks:\n",
" ch.orig_index += new_index\n",
" action.chunks.append(ch)\n",
" index = new_index + len(next_chunk_context)\n",
" self.index = end_patch_index\n",
" continue\n",
" index = new_index + len(next_ctx)\n",
" self.index = end_idx\n",
" return action\n",
"\n",
" def parse_add_file(self) -> PatchAction:\n",
" lines = []\n",
" def _parse_add_file(self) -> PatchAction:\n",
" lines: List[str] = []\n",
" while not self.is_done(\n",
" (\"*** End Patch\", \"*** Update File:\", \"*** Delete File:\", \"*** Add File:\")\n",
" ):\n",
" s = self.read_str()\n",
" s = self.read_line()\n",
" if not s.startswith(\"+\"):\n",
" raise DiffError(f\"Invalid Add File Line: {s}\")\n",
" s = s[1:]\n",
" lines.append(s)\n",
" return PatchAction(\n",
" type=ActionType.ADD,\n",
" new_file=\"\\n\".join(lines),\n",
" )\n",
" raise DiffError(f\"Invalid Add File line (missing '+'): {s}\")\n",
" lines.append(s[1:]) # strip leading '+'\n",
" return PatchAction(type=ActionType.ADD, new_file=\"\\n\".join(lines))\n",
"\n",
"\n",
"def find_context_core(lines: List[str], context: List[str], start: int) -> Tuple[int, int]:\n",
"# --------------------------------------------------------------------------- #\n",
"# Helper functions\n",
"# --------------------------------------------------------------------------- #\n",
"def find_context_core(\n",
" lines: List[str], context: List[str], start: int\n",
") -> Tuple[int, int]:\n",
" if not context:\n",
" return start, 0\n",
"\n",
" # Prefer identical\n",
" for i in range(start, len(lines)):\n",
" if lines[i : i + len(context)] == context:\n",
" return i, 0\n",
" # RStrip is ok\n",
" for i in range(start, len(lines)):\n",
" if [s.rstrip() for s in lines[i : i + len(context)]] == [s.rstrip() for s in context]:\n",
" if [s.rstrip() for s in lines[i : i + len(context)]] == [\n",
" s.rstrip() for s in context\n",
" ]:\n",
" return i, 1\n",
" # Fine, Strip is ok too.\n",
" for i in range(start, len(lines)):\n",
" if [s.strip() for s in lines[i : i + len(context)]] == [s.strip() for s in context]:\n",
" if [s.strip() for s in lines[i : i + len(context)]] == [\n",
" s.strip() for s in context\n",
" ]:\n",
" return i, 100\n",
" return -1, 0\n",
"\n",
"\n",
"def find_context(lines: List[str], context: List[str], start: int, eof: bool) -> Tuple[int, int]:\n",
"def find_context(\n",
" lines: List[str], context: List[str], start: int, eof: bool\n",
") -> Tuple[int, int]:\n",
" if eof:\n",
" new_index, fuzz = find_context_core(lines, context, len(lines) - len(context))\n",
" if new_index != -1:\n",
" return new_index, fuzz\n",
" new_index, fuzz = find_context_core(lines, context, start)\n",
" return new_index, fuzz + 10000\n",
" return new_index, fuzz + 10_000\n",
" return find_context_core(lines, context, start)\n",
"\n",
"\n",
"def peek_next_section(lines: List[str], index: int) -> Tuple[List[str], List[Chunk], int, bool]:\n",
"def peek_next_section(\n",
" lines: List[str], index: int\n",
") -> Tuple[List[str], List[Chunk], int, bool]:\n",
" old: List[str] = []\n",
" del_lines: List[str] = []\n",
" ins_lines: List[str] = []\n",
" chunks: List[Chunk] = []\n",
" mode = \"keep\"\n",
" orig_index = index\n",
"\n",
" while index < len(lines):\n",
" s = lines[index]\n",
" if s.startswith(\n",
@ -861,9 +892,10 @@
" break\n",
" if s == \"***\":\n",
" break\n",
" elif s.startswith(\"***\"):\n",
" if s.startswith(\"***\"):\n",
" raise DiffError(f\"Invalid Line: {s}\")\n",
" index += 1\n",
"\n",
" last_mode = mode\n",
" if s == \"\":\n",
" s = \" \"\n",
@ -876,6 +908,7 @@
" else:\n",
" raise DiffError(f\"Invalid Line: {s}\")\n",
" s = s[1:]\n",
"\n",
" if mode == \"keep\" and last_mode != mode:\n",
" if ins_lines or del_lines:\n",
" chunks.append(\n",
@ -885,8 +918,8 @@
" ins_lines=ins_lines,\n",
" )\n",
" )\n",
" del_lines = []\n",
" ins_lines = []\n",
" del_lines, ins_lines = [], []\n",
"\n",
" if mode == \"delete\":\n",
" del_lines.append(s)\n",
" old.append(s)\n",
@ -894,6 +927,7 @@
" ins_lines.append(s)\n",
" elif mode == \"keep\":\n",
" old.append(s)\n",
"\n",
" if ins_lines or del_lines:\n",
" chunks.append(\n",
" Chunk(\n",
@ -902,96 +936,61 @@
" ins_lines=ins_lines,\n",
" )\n",
" )\n",
" del_lines = []\n",
" ins_lines = []\n",
"\n",
" if index < len(lines) and lines[index] == \"*** End of File\":\n",
" index += 1\n",
" return old, chunks, index, True\n",
"\n",
" if index == orig_index:\n",
" raise DiffError(f\"Nothing in this section - {index=} {lines[index]}\")\n",
" raise DiffError(\"Nothing in this section\")\n",
" return old, chunks, index, False\n",
"\n",
"\n",
"def text_to_patch(text: str, orig: Dict[str, str]) -> Tuple[Patch, int]:\n",
" lines = text.strip().split(\"\\n\")\n",
" if len(lines) < 2 or not lines[0].startswith(\"*** Begin Patch\") or lines[-1] != \"*** End Patch\":\n",
" raise DiffError(\"Invalid patch text\")\n",
"\n",
" parser = Parser(\n",
" current_files=orig,\n",
" lines=lines,\n",
" index=1,\n",
" )\n",
" parser.parse()\n",
" return parser.patch, parser.fuzz\n",
"\n",
"\n",
"def identify_files_needed(text: str) -> List[str]:\n",
" lines = text.strip().split(\"\\n\")\n",
" result = set()\n",
" for line in lines:\n",
" if line.startswith(\"*** Update File: \"):\n",
" result.add(line[len(\"*** Update File: \") :])\n",
" if line.startswith(\"*** Delete File: \"):\n",
" result.add(line[len(\"*** Delete File: \") :])\n",
" return list(result)\n",
"\n",
"\n",
"def identify_files_added(text: str) -> List[str]:\n",
" lines = text.strip().split(\"\\n\")\n",
" result = set()\n",
" for line in lines:\n",
" if line.startswith(\"*** Add File: \"):\n",
" result.add(line[len(\"*** Add File: \") :])\n",
" return list(result)\n",
"\n",
"\n",
"# --------------------------------------------------------------------------- #\n",
"# Patch → Commit and Commit application\n",
"# --------------------------------------------------------------------------- #\n",
"def _get_updated_file(text: str, action: PatchAction, path: str) -> str:\n",
" assert action.type == ActionType.UPDATE\n",
" if action.type is not ActionType.UPDATE:\n",
" raise DiffError(\"_get_updated_file called with non-update action\")\n",
" orig_lines = text.split(\"\\n\")\n",
" dest_lines = []\n",
" dest_lines: List[str] = []\n",
" orig_index = 0\n",
" dest_index = 0\n",
"\n",
" for chunk in action.chunks:\n",
" # Process the unchanged lines before the chunk\n",
" if chunk.orig_index > len(orig_lines):\n",
" raise DiffError(\n",
" f\"_get_updated_file: {path}: chunk.orig_index {chunk.orig_index} > len(lines) {len(orig_lines)}\"\n",
" f\"{path}: chunk.orig_index {chunk.orig_index} exceeds file length\"\n",
" )\n",
" if orig_index > chunk.orig_index:\n",
" raise DiffError(\n",
" f\"_get_updated_file: {path}: orig_index {orig_index} > chunk.orig_index {chunk.orig_index}\"\n",
" f\"{path}: overlapping chunks at {orig_index} > {chunk.orig_index}\"\n",
" )\n",
" assert orig_index <= chunk.orig_index\n",
"\n",
" dest_lines.extend(orig_lines[orig_index : chunk.orig_index])\n",
" delta = chunk.orig_index - orig_index\n",
" orig_index += delta\n",
" dest_index += delta\n",
" # Process the inserted lines\n",
" if chunk.ins_lines:\n",
" for i in range(len(chunk.ins_lines)):\n",
" dest_lines.append(chunk.ins_lines[i])\n",
" dest_index += len(chunk.ins_lines)\n",
" orig_index = chunk.orig_index\n",
"\n",
" dest_lines.extend(chunk.ins_lines)\n",
" orig_index += len(chunk.del_lines)\n",
" # Final part\n",
"\n",
" dest_lines.extend(orig_lines[orig_index:])\n",
" delta = len(orig_lines) - orig_index\n",
" orig_index += delta\n",
" dest_index += delta\n",
" assert orig_index == len(orig_lines)\n",
" assert dest_index == len(dest_lines)\n",
" return \"\\n\".join(dest_lines)\n",
"\n",
"\n",
"def patch_to_commit(patch: Patch, orig: Dict[str, str]) -> Commit:\n",
" commit = Commit()\n",
" for path, action in patch.actions.items():\n",
" if action.type == ActionType.DELETE:\n",
" commit.changes[path] = FileChange(type=ActionType.DELETE, old_content=orig[path])\n",
" elif action.type == ActionType.ADD:\n",
" commit.changes[path] = FileChange(type=ActionType.ADD, new_content=action.new_file)\n",
" elif action.type == ActionType.UPDATE:\n",
" new_content = _get_updated_file(text=orig[path], action=action, path=path)\n",
" if action.type is ActionType.DELETE:\n",
" commit.changes[path] = FileChange(\n",
" type=ActionType.DELETE, old_content=orig[path]\n",
" )\n",
" elif action.type is ActionType.ADD:\n",
" if action.new_file is None:\n",
" raise DiffError(\"ADD action without file content\")\n",
" commit.changes[path] = FileChange(\n",
" type=ActionType.ADD, new_content=action.new_file\n",
" )\n",
" elif action.type is ActionType.UPDATE:\n",
" new_content = _get_updated_file(orig[path], action, path)\n",
" commit.changes[path] = FileChange(\n",
" type=ActionType.UPDATE,\n",
" old_content=orig[path],\n",
@ -1001,69 +1000,122 @@
" return commit\n",
"\n",
"\n",
"class DiffError(ValueError):\n",
" pass\n",
"# --------------------------------------------------------------------------- #\n",
"# User-facing helpers\n",
"# --------------------------------------------------------------------------- #\n",
"def text_to_patch(text: str, orig: Dict[str, str]) -> Tuple[Patch, int]:\n",
" lines = text.splitlines() # preserves blank lines, no strip()\n",
" if (\n",
" len(lines) < 2\n",
" or not Parser._norm(lines[0]).startswith(\"*** Begin Patch\")\n",
" or Parser._norm(lines[-1]) != \"*** End Patch\"\n",
" ):\n",
" raise DiffError(\"Invalid patch text - missing sentinels\")\n",
"\n",
" parser = Parser(current_files=orig, lines=lines, index=1)\n",
" parser.parse()\n",
" return parser.patch, parser.fuzz\n",
"\n",
"\n",
"def load_files(paths: List[str], open_fn: Callable) -> Dict[str, str]:\n",
" orig = {}\n",
" for path in paths:\n",
" orig[path] = open_fn(path)\n",
" return orig\n",
"def identify_files_needed(text: str) -> List[str]:\n",
" lines = text.splitlines()\n",
" return [\n",
" line[len(\"*** Update File: \") :]\n",
" for line in lines\n",
" if line.startswith(\"*** Update File: \")\n",
" ] + [\n",
" line[len(\"*** Delete File: \") :]\n",
" for line in lines\n",
" if line.startswith(\"*** Delete File: \")\n",
" ]\n",
"\n",
"\n",
"def apply_commit(commit: Commit, write_fn: Callable, remove_fn: Callable) -> None:\n",
"def identify_files_added(text: str) -> List[str]:\n",
" lines = text.splitlines()\n",
" return [\n",
" line[len(\"*** Add File: \") :]\n",
" for line in lines\n",
" if line.startswith(\"*** Add File: \")\n",
" ]\n",
"\n",
"\n",
"# --------------------------------------------------------------------------- #\n",
"# File-system helpers\n",
"# --------------------------------------------------------------------------- #\n",
"def load_files(paths: List[str], open_fn: Callable[[str], str]) -> Dict[str, str]:\n",
" return {path: open_fn(path) for path in paths}\n",
"\n",
"\n",
"def apply_commit(\n",
" commit: Commit,\n",
" write_fn: Callable[[str, str], None],\n",
" remove_fn: Callable[[str], None],\n",
") -> None:\n",
" for path, change in commit.changes.items():\n",
" if change.type == ActionType.DELETE:\n",
" if change.type is ActionType.DELETE:\n",
" remove_fn(path)\n",
" elif change.type == ActionType.ADD:\n",
" elif change.type is ActionType.ADD:\n",
" if change.new_content is None:\n",
" raise DiffError(f\"ADD change for {path} has no content\")\n",
" write_fn(path, change.new_content)\n",
" elif change.type == ActionType.UPDATE:\n",
" elif change.type is ActionType.UPDATE:\n",
" if change.new_content is None:\n",
" raise DiffError(f\"UPDATE change for {path} has no new content\")\n",
" target = change.move_path or path\n",
" write_fn(target, change.new_content)\n",
" if change.move_path:\n",
" write_fn(change.move_path, change.new_content)\n",
" remove_fn(path)\n",
" else:\n",
" write_fn(path, change.new_content)\n",
"\n",
"\n",
"def process_patch(text: str, open_fn: Callable, write_fn: Callable, remove_fn: Callable) -> str:\n",
" assert text.startswith(\"*** Begin Patch\")\n",
"def process_patch(\n",
" text: str,\n",
" open_fn: Callable[[str], str],\n",
" write_fn: Callable[[str, str], None],\n",
" remove_fn: Callable[[str], None],\n",
") -> str:\n",
" if not text.startswith(\"*** Begin Patch\"):\n",
" raise DiffError(\"Patch text must start with *** Begin Patch\")\n",
" paths = identify_files_needed(text)\n",
" orig = load_files(paths, open_fn)\n",
" patch, fuzz = text_to_patch(text, orig)\n",
" patch, _fuzz = text_to_patch(text, orig)\n",
" commit = patch_to_commit(patch, orig)\n",
" apply_commit(commit, write_fn, remove_fn)\n",
" return \"Done!\"\n",
"\n",
"\n",
"# --------------------------------------------------------------------------- #\n",
"# Default FS helpers\n",
"# --------------------------------------------------------------------------- #\n",
"def open_file(path: str) -> str:\n",
" with open(path, \"rt\") as f:\n",
" return f.read()\n",
" with open(path, \"rt\", encoding=\"utf-8\") as fh:\n",
" return fh.read()\n",
"\n",
"\n",
"def write_file(path: str, content: str) -> None:\n",
" if \"/\" in path:\n",
" parent = \"/\".join(path.split(\"/\")[:-1])\n",
" os.makedirs(parent, exist_ok=True)\n",
" with open(path, \"wt\") as f:\n",
" f.write(content)\n",
" target = pathlib.Path(path)\n",
" target.parent.mkdir(parents=True, exist_ok=True)\n",
" with target.open(\"wt\", encoding=\"utf-8\") as fh:\n",
" fh.write(content)\n",
"\n",
"\n",
"def remove_file(path: str) -> None:\n",
" os.remove(path)\n",
" pathlib.Path(path).unlink(missing_ok=True)\n",
"\n",
"\n",
"# --------------------------------------------------------------------------- #\n",
"# CLI entry-point\n",
"# --------------------------------------------------------------------------- #\n",
"def main() -> None:\n",
" import sys\n",
"\n",
" patch_text = sys.stdin.read()\n",
" if not patch_text:\n",
" print(\"Please pass patch text through stdin\")\n",
" print(\"Please pass patch text through stdin\", file=sys.stderr)\n",
" return\n",
" try:\n",
" result = process_patch(patch_text, open_file, write_file, remove_file)\n",
" except DiffError as e:\n",
" print(str(e))\n",
" except DiffError as exc:\n",
" print(exc, file=sys.stderr)\n",
" return\n",
" print(result)\n",
"\n",