4.1 prompting guide updates: update apply-patch code (#1772)

This commit is contained in:
Noah MacCallum 2025-04-14 12:44:54 -07:00 committed by GitHub
parent 6a47d53c96
commit 9f6939a239
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 259 additions and 202 deletions

View File

@ -281,4 +281,9 @@ vishnu-oai:
nm-openai: nm-openai:
name: "Noah MacCallum" name: "Noah MacCallum"
website: "https://x.com/noahmacca" website: "https://x.com/noahmacca"
avatar: "https://avatars.githubusercontent.com/u/171723556" avatar: "https://avatars.githubusercontent.com/u/171723556"
julian-openai:
name: "Julian Lee"
website: "https://github.com/julian-openai"
avatar: "https://avatars.githubusercontent.com/u/199828632"

View File

@ -591,12 +591,30 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"#!/usr/bin/env python3\n", "#!/usr/bin/env python3\n",
"import os\n", "\n",
"\"\"\"\n",
"A self-contained **pure-Python 3.9+** utility for applying human-readable\n",
"“pseudo-diff” patch files to a collection of text files.\n",
"\"\"\"\n",
"\n",
"from __future__ import annotations\n",
"\n",
"import pathlib\n",
"from dataclasses import dataclass, field\n", "from dataclasses import dataclass, field\n",
"from enum import Enum\n", "from enum import Enum\n",
"from typing import Callable, Dict, List, NoReturn, Optional, Tuple, Union\n", "from typing import (\n",
" Callable,\n",
" Dict,\n",
" List,\n",
" Optional,\n",
" Tuple,\n",
" Union,\n",
")\n",
"\n", "\n",
"\n", "\n",
"# --------------------------------------------------------------------------- #\n",
"# Domain objects\n",
"# --------------------------------------------------------------------------- #\n",
"class ActionType(str, Enum):\n", "class ActionType(str, Enum):\n",
" ADD = \"add\"\n", " ADD = \"add\"\n",
" DELETE = \"delete\"\n", " DELETE = \"delete\"\n",
@ -616,39 +634,19 @@
" changes: Dict[str, FileChange] = field(default_factory=dict)\n", " changes: Dict[str, FileChange] = field(default_factory=dict)\n",
"\n", "\n",
"\n", "\n",
"def assemble_changes(\n", "# --------------------------------------------------------------------------- #\n",
" orig: Dict[str, Optional[str]],\n", "# Exceptions\n",
" updated_files: Dict[str, Optional[str]],\n", "# --------------------------------------------------------------------------- #\n",
") -> Commit:\n", "class DiffError(ValueError):\n",
" commit = Commit()\n", " \"\"\"Any problem detected while parsing or applying a patch.\"\"\"\n",
" for path, new_content in updated_files.items():\n",
" old_content = orig.get(path)\n",
" if old_content == new_content:\n",
" continue\n",
" if old_content is not None and new_content is not None:\n",
" commit.changes[path] = FileChange(\n",
" type=ActionType.UPDATE,\n",
" old_content=old_content,\n",
" new_content=new_content,\n",
" )\n",
" elif new_content is not None:\n",
" commit.changes[path] = FileChange(\n",
" type=ActionType.ADD,\n",
" new_content=new_content,\n",
" )\n",
" elif old_content is not None:\n",
" commit.changes[path] = FileChange(\n",
" type=ActionType.DELETE,\n",
" old_content=old_content,\n",
" )\n",
" else:\n",
" assert False\n",
" return commit\n",
"\n", "\n",
"\n", "\n",
"# --------------------------------------------------------------------------- #\n",
"# Helper dataclasses used while parsing patches\n",
"# --------------------------------------------------------------------------- #\n",
"@dataclass\n", "@dataclass\n",
"class Chunk:\n", "class Chunk:\n",
" orig_index: int = -1 # line index of the first line in the original file\n", " orig_index: int = -1\n",
" del_lines: List[str] = field(default_factory=list)\n", " del_lines: List[str] = field(default_factory=list)\n",
" ins_lines: List[str] = field(default_factory=list)\n", " ins_lines: List[str] = field(default_factory=list)\n",
"\n", "\n",
@ -666,79 +664,108 @@
" actions: Dict[str, PatchAction] = field(default_factory=dict)\n", " actions: Dict[str, PatchAction] = field(default_factory=dict)\n",
"\n", "\n",
"\n", "\n",
"# --------------------------------------------------------------------------- #\n",
"# Patch text parser\n",
"# --------------------------------------------------------------------------- #\n",
"@dataclass\n", "@dataclass\n",
"class Parser:\n", "class Parser:\n",
" current_files: Dict[str, str] = field(default_factory=dict)\n", " current_files: Dict[str, str]\n",
" lines: List[str] = field(default_factory=list)\n", " lines: List[str]\n",
" index: int = 0\n", " index: int = 0\n",
" patch: Patch = field(default_factory=Patch)\n", " patch: Patch = field(default_factory=Patch)\n",
" fuzz: int = 0\n", " fuzz: int = 0\n",
"\n", "\n",
" # ------------- low-level helpers -------------------------------------- #\n",
" def _cur_line(self) -> str:\n",
" if self.index >= len(self.lines):\n",
" raise DiffError(\"Unexpected end of input while parsing patch\")\n",
" return self.lines[self.index]\n",
"\n",
" @staticmethod\n",
" def _norm(line: str) -> str:\n",
" \"\"\"Strip CR so comparisons work for both LF and CRLF input.\"\"\"\n",
" return line.rstrip(\"\\r\")\n",
"\n",
" # ------------- scanning convenience ----------------------------------- #\n",
" def is_done(self, prefixes: Optional[Tuple[str, ...]] = None) -> bool:\n", " def is_done(self, prefixes: Optional[Tuple[str, ...]] = None) -> bool:\n",
" if self.index >= len(self.lines):\n", " if self.index >= len(self.lines):\n",
" return True\n", " return True\n",
" if prefixes and self.lines[self.index].startswith(prefixes):\n", " if (\n",
" prefixes\n",
" and len(prefixes) > 0\n",
" and self._norm(self._cur_line()).startswith(prefixes)\n",
" ):\n",
" return True\n", " return True\n",
" return False\n", " return False\n",
"\n", "\n",
" def startswith(self, prefix: Union[str, Tuple[str, ...]]) -> bool:\n", " def startswith(self, prefix: Union[str, Tuple[str, ...]]) -> bool:\n",
" assert self.index < len(self.lines), f\"Index: {self.index} >= {len(self.lines)}\"\n", " return self._norm(self._cur_line()).startswith(prefix)\n",
" if self.lines[self.index].startswith(prefix):\n",
" return True\n",
" return False\n",
"\n", "\n",
" def read_str(self, prefix: str = \"\", return_everything: bool = False) -> str:\n", " def read_str(self, prefix: str) -> str:\n",
" assert self.index < len(self.lines), f\"Index: {self.index} >= {len(self.lines)}\"\n", " \"\"\"\n",
" if self.lines[self.index].startswith(prefix):\n", " Consume the current line if it starts with *prefix* and return the text\n",
" if return_everything:\n", " **after** the prefix. Raises if prefix is empty.\n",
" text = self.lines[self.index]\n", " \"\"\"\n",
" else:\n", " if prefix == \"\":\n",
" text = self.lines[self.index][len(prefix) :]\n", " raise ValueError(\"read_str() requires a non-empty prefix\")\n",
" if self._norm(self._cur_line()).startswith(prefix):\n",
" text = self._cur_line()[len(prefix) :]\n",
" self.index += 1\n", " self.index += 1\n",
" return text\n", " return text\n",
" return \"\"\n", " return \"\"\n",
"\n", "\n",
" def parse(self) -> NoReturn:\n", " def read_line(self) -> str:\n",
" \"\"\"Return the current raw line and advance.\"\"\"\n",
" line = self._cur_line()\n",
" self.index += 1\n",
" return line\n",
"\n",
" # ------------- public entry point -------------------------------------- #\n",
" def parse(self) -> None:\n",
" while not self.is_done((\"*** End Patch\",)):\n", " while not self.is_done((\"*** End Patch\",)):\n",
" # ---------- UPDATE ---------- #\n",
" path = self.read_str(\"*** Update File: \")\n", " path = self.read_str(\"*** Update File: \")\n",
" if path:\n", " if path:\n",
" if path in self.patch.actions:\n", " if path in self.patch.actions:\n",
" raise DiffError(f\"Update File Error: Duplicate Path: {path}\")\n", " raise DiffError(f\"Duplicate update for file: {path}\")\n",
" move_to = self.read_str(\"*** Move to: \")\n", " move_to = self.read_str(\"*** Move to: \")\n",
" if path not in self.current_files:\n", " if path not in self.current_files:\n",
" raise DiffError(f\"Update File Error: Missing File: {path}\")\n", " raise DiffError(f\"Update File Error - missing file: {path}\")\n",
" text = self.current_files[path]\n", " text = self.current_files[path]\n",
" action = self.parse_update_file(text)\n", " action = self._parse_update_file(text)\n",
" action.move_path = move_to\n", " action.move_path = move_to or None\n",
" self.patch.actions[path] = action\n", " self.patch.actions[path] = action\n",
" continue\n", " continue\n",
"\n",
" # ---------- DELETE ---------- #\n",
" path = self.read_str(\"*** Delete File: \")\n", " path = self.read_str(\"*** Delete File: \")\n",
" if path:\n", " if path:\n",
" if path in self.patch.actions:\n", " if path in self.patch.actions:\n",
" raise DiffError(f\"Delete File Error: Duplicate Path: {path}\")\n", " raise DiffError(f\"Duplicate delete for file: {path}\")\n",
" if path not in self.current_files:\n", " if path not in self.current_files:\n",
" raise DiffError(f\"Delete File Error: Missing File: {path}\")\n", " raise DiffError(f\"Delete File Error - missing file: {path}\")\n",
" self.patch.actions[path] = PatchAction(\n", " self.patch.actions[path] = PatchAction(type=ActionType.DELETE)\n",
" type=ActionType.DELETE,\n",
" )\n",
" continue\n", " continue\n",
"\n",
" # ---------- ADD ---------- #\n",
" path = self.read_str(\"*** Add File: \")\n", " path = self.read_str(\"*** Add File: \")\n",
" if path:\n", " if path:\n",
" if path in self.patch.actions:\n", " if path in self.patch.actions:\n",
" raise DiffError(f\"Add File Error: Duplicate Path: {path}\")\n", " raise DiffError(f\"Duplicate add for file: {path}\")\n",
" if path in self.current_files:\n", " if path in self.current_files:\n",
" raise DiffError(f\"Add File Error: File already exists: {path}\")\n", " raise DiffError(f\"Add File Error - file already exists: {path}\")\n",
" self.patch.actions[path] = self.parse_add_file()\n", " self.patch.actions[path] = self._parse_add_file()\n",
" continue\n", " continue\n",
" raise DiffError(f\"Unknown Line: {self.lines[self.index]}\")\n",
" if not self.startswith(\"*** End Patch\"):\n",
" raise DiffError(\"Missing End Patch\")\n",
" self.index += 1\n",
"\n", "\n",
" def parse_update_file(self, text: str) -> PatchAction:\n", " raise DiffError(f\"Unknown line while parsing: {self._cur_line()}\")\n",
" action = PatchAction(\n", "\n",
" type=ActionType.UPDATE,\n", " if not self.startswith(\"*** End Patch\"):\n",
" )\n", " raise DiffError(\"Missing *** End Patch sentinel\")\n",
" self.index += 1 # consume sentinel\n",
"\n",
" # ------------- section parsers ---------------------------------------- #\n",
" def _parse_update_file(self, text: str) -> PatchAction:\n",
" action = PatchAction(type=ActionType.UPDATE)\n",
" lines = text.split(\"\\n\")\n", " lines = text.split(\"\\n\")\n",
" index = 0\n", " index = 0\n",
" while not self.is_done(\n", " while not self.is_done(\n",
@ -752,100 +779,104 @@
" ):\n", " ):\n",
" def_str = self.read_str(\"@@ \")\n", " def_str = self.read_str(\"@@ \")\n",
" section_str = \"\"\n", " section_str = \"\"\n",
" if not def_str:\n", " if not def_str and self._norm(self._cur_line()) == \"@@\":\n",
" if self.lines[self.index] == \"@@\":\n", " section_str = self.read_line()\n",
" section_str = self.lines[self.index]\n", "\n",
" self.index += 1\n",
" if not (def_str or section_str or index == 0):\n", " if not (def_str or section_str or index == 0):\n",
" raise DiffError(f\"Invalid Line:\\n{self.lines[self.index]}\")\n", " raise DiffError(f\"Invalid line in update section:\\n{self._cur_line()}\")\n",
"\n",
" if def_str.strip():\n", " if def_str.strip():\n",
" found = False\n", " found = False\n",
" if not [s for s in lines[:index] if s == def_str]:\n", " if def_str not in lines[:index]:\n",
" # def str is a skip ahead operator\n",
" for i, s in enumerate(lines[index:], index):\n", " for i, s in enumerate(lines[index:], index):\n",
" if s == def_str:\n", " if s == def_str:\n",
" index = i + 1\n", " index = i + 1\n",
" found = True\n", " found = True\n",
" break\n", " break\n",
" if not found and not [s for s in lines[:index] if s.strip() == def_str.strip()]:\n", " if not found and def_str.strip() not in [\n",
" # def str is a skip ahead operator\n", " s.strip() for s in lines[:index]\n",
" ]:\n",
" for i, s in enumerate(lines[index:], index):\n", " for i, s in enumerate(lines[index:], index):\n",
" if s.strip() == def_str.strip():\n", " if s.strip() == def_str.strip():\n",
" index = i + 1\n", " index = i + 1\n",
" self.fuzz += 1\n", " self.fuzz += 1\n",
" found = True\n", " found = True\n",
" break\n", " break\n",
" next_chunk_context, chunks, end_patch_index, eof = peek_next_section(\n", "\n",
" self.lines, self.index\n", " next_ctx, chunks, end_idx, eof = peek_next_section(self.lines, self.index)\n",
" )\n", " new_index, fuzz = find_context(lines, next_ctx, index, eof)\n",
" next_chunk_text = \"\\n\".join(next_chunk_context)\n",
" new_index, fuzz = find_context(lines, next_chunk_context, index, eof)\n",
" if new_index == -1:\n", " if new_index == -1:\n",
" if eof:\n", " ctx_txt = \"\\n\".join(next_ctx)\n",
" raise DiffError(f\"Invalid EOF Context {index}:\\n{next_chunk_text}\")\n", " raise DiffError(\n",
" else:\n", " f\"Invalid {'EOF ' if eof else ''}context at {index}:\\n{ctx_txt}\"\n",
" raise DiffError(f\"Invalid Context {index}:\\n{next_chunk_text}\")\n", " )\n",
" self.fuzz += fuzz\n", " self.fuzz += fuzz\n",
" for ch in chunks:\n", " for ch in chunks:\n",
" ch.orig_index += new_index\n", " ch.orig_index += new_index\n",
" action.chunks.append(ch)\n", " action.chunks.append(ch)\n",
" index = new_index + len(next_chunk_context)\n", " index = new_index + len(next_ctx)\n",
" self.index = end_patch_index\n", " self.index = end_idx\n",
" continue\n",
" return action\n", " return action\n",
"\n", "\n",
" def parse_add_file(self) -> PatchAction:\n", " def _parse_add_file(self) -> PatchAction:\n",
" lines = []\n", " lines: List[str] = []\n",
" while not self.is_done(\n", " while not self.is_done(\n",
" (\"*** End Patch\", \"*** Update File:\", \"*** Delete File:\", \"*** Add File:\")\n", " (\"*** End Patch\", \"*** Update File:\", \"*** Delete File:\", \"*** Add File:\")\n",
" ):\n", " ):\n",
" s = self.read_str()\n", " s = self.read_line()\n",
" if not s.startswith(\"+\"):\n", " if not s.startswith(\"+\"):\n",
" raise DiffError(f\"Invalid Add File Line: {s}\")\n", " raise DiffError(f\"Invalid Add File line (missing '+'): {s}\")\n",
" s = s[1:]\n", " lines.append(s[1:]) # strip leading '+'\n",
" lines.append(s)\n", " return PatchAction(type=ActionType.ADD, new_file=\"\\n\".join(lines))\n",
" return PatchAction(\n",
" type=ActionType.ADD,\n",
" new_file=\"\\n\".join(lines),\n",
" )\n",
"\n", "\n",
"\n", "\n",
"def find_context_core(lines: List[str], context: List[str], start: int) -> Tuple[int, int]:\n", "# --------------------------------------------------------------------------- #\n",
"# Helper functions\n",
"# --------------------------------------------------------------------------- #\n",
"def find_context_core(\n",
" lines: List[str], context: List[str], start: int\n",
") -> Tuple[int, int]:\n",
" if not context:\n", " if not context:\n",
" return start, 0\n", " return start, 0\n",
"\n", "\n",
" # Prefer identical\n",
" for i in range(start, len(lines)):\n", " for i in range(start, len(lines)):\n",
" if lines[i : i + len(context)] == context:\n", " if lines[i : i + len(context)] == context:\n",
" return i, 0\n", " return i, 0\n",
" # RStrip is ok\n",
" for i in range(start, len(lines)):\n", " for i in range(start, len(lines)):\n",
" if [s.rstrip() for s in lines[i : i + len(context)]] == [s.rstrip() for s in context]:\n", " if [s.rstrip() for s in lines[i : i + len(context)]] == [\n",
" s.rstrip() for s in context\n",
" ]:\n",
" return i, 1\n", " return i, 1\n",
" # Fine, Strip is ok too.\n",
" for i in range(start, len(lines)):\n", " for i in range(start, len(lines)):\n",
" if [s.strip() for s in lines[i : i + len(context)]] == [s.strip() for s in context]:\n", " if [s.strip() for s in lines[i : i + len(context)]] == [\n",
" s.strip() for s in context\n",
" ]:\n",
" return i, 100\n", " return i, 100\n",
" return -1, 0\n", " return -1, 0\n",
"\n", "\n",
"\n", "\n",
"def find_context(lines: List[str], context: List[str], start: int, eof: bool) -> Tuple[int, int]:\n", "def find_context(\n",
" lines: List[str], context: List[str], start: int, eof: bool\n",
") -> Tuple[int, int]:\n",
" if eof:\n", " if eof:\n",
" new_index, fuzz = find_context_core(lines, context, len(lines) - len(context))\n", " new_index, fuzz = find_context_core(lines, context, len(lines) - len(context))\n",
" if new_index != -1:\n", " if new_index != -1:\n",
" return new_index, fuzz\n", " return new_index, fuzz\n",
" new_index, fuzz = find_context_core(lines, context, start)\n", " new_index, fuzz = find_context_core(lines, context, start)\n",
" return new_index, fuzz + 10000\n", " return new_index, fuzz + 10_000\n",
" return find_context_core(lines, context, start)\n", " return find_context_core(lines, context, start)\n",
"\n", "\n",
"\n", "\n",
"def peek_next_section(lines: List[str], index: int) -> Tuple[List[str], List[Chunk], int, bool]:\n", "def peek_next_section(\n",
" lines: List[str], index: int\n",
") -> Tuple[List[str], List[Chunk], int, bool]:\n",
" old: List[str] = []\n", " old: List[str] = []\n",
" del_lines: List[str] = []\n", " del_lines: List[str] = []\n",
" ins_lines: List[str] = []\n", " ins_lines: List[str] = []\n",
" chunks: List[Chunk] = []\n", " chunks: List[Chunk] = []\n",
" mode = \"keep\"\n", " mode = \"keep\"\n",
" orig_index = index\n", " orig_index = index\n",
"\n",
" while index < len(lines):\n", " while index < len(lines):\n",
" s = lines[index]\n", " s = lines[index]\n",
" if s.startswith(\n", " if s.startswith(\n",
@ -861,9 +892,10 @@
" break\n", " break\n",
" if s == \"***\":\n", " if s == \"***\":\n",
" break\n", " break\n",
" elif s.startswith(\"***\"):\n", " if s.startswith(\"***\"):\n",
" raise DiffError(f\"Invalid Line: {s}\")\n", " raise DiffError(f\"Invalid Line: {s}\")\n",
" index += 1\n", " index += 1\n",
"\n",
" last_mode = mode\n", " last_mode = mode\n",
" if s == \"\":\n", " if s == \"\":\n",
" s = \" \"\n", " s = \" \"\n",
@ -876,6 +908,7 @@
" else:\n", " else:\n",
" raise DiffError(f\"Invalid Line: {s}\")\n", " raise DiffError(f\"Invalid Line: {s}\")\n",
" s = s[1:]\n", " s = s[1:]\n",
"\n",
" if mode == \"keep\" and last_mode != mode:\n", " if mode == \"keep\" and last_mode != mode:\n",
" if ins_lines or del_lines:\n", " if ins_lines or del_lines:\n",
" chunks.append(\n", " chunks.append(\n",
@ -885,8 +918,8 @@
" ins_lines=ins_lines,\n", " ins_lines=ins_lines,\n",
" )\n", " )\n",
" )\n", " )\n",
" del_lines = []\n", " del_lines, ins_lines = [], []\n",
" ins_lines = []\n", "\n",
" if mode == \"delete\":\n", " if mode == \"delete\":\n",
" del_lines.append(s)\n", " del_lines.append(s)\n",
" old.append(s)\n", " old.append(s)\n",
@ -894,6 +927,7 @@
" ins_lines.append(s)\n", " ins_lines.append(s)\n",
" elif mode == \"keep\":\n", " elif mode == \"keep\":\n",
" old.append(s)\n", " old.append(s)\n",
"\n",
" if ins_lines or del_lines:\n", " if ins_lines or del_lines:\n",
" chunks.append(\n", " chunks.append(\n",
" Chunk(\n", " Chunk(\n",
@ -902,96 +936,61 @@
" ins_lines=ins_lines,\n", " ins_lines=ins_lines,\n",
" )\n", " )\n",
" )\n", " )\n",
" del_lines = []\n", "\n",
" ins_lines = []\n",
" if index < len(lines) and lines[index] == \"*** End of File\":\n", " if index < len(lines) and lines[index] == \"*** End of File\":\n",
" index += 1\n", " index += 1\n",
" return old, chunks, index, True\n", " return old, chunks, index, True\n",
"\n",
" if index == orig_index:\n", " if index == orig_index:\n",
" raise DiffError(f\"Nothing in this section - {index=} {lines[index]}\")\n", " raise DiffError(\"Nothing in this section\")\n",
" return old, chunks, index, False\n", " return old, chunks, index, False\n",
"\n", "\n",
"\n", "\n",
"def text_to_patch(text: str, orig: Dict[str, str]) -> Tuple[Patch, int]:\n", "# --------------------------------------------------------------------------- #\n",
" lines = text.strip().split(\"\\n\")\n", "# Patch → Commit and Commit application\n",
" if len(lines) < 2 or not lines[0].startswith(\"*** Begin Patch\") or lines[-1] != \"*** End Patch\":\n", "# --------------------------------------------------------------------------- #\n",
" raise DiffError(\"Invalid patch text\")\n",
"\n",
" parser = Parser(\n",
" current_files=orig,\n",
" lines=lines,\n",
" index=1,\n",
" )\n",
" parser.parse()\n",
" return parser.patch, parser.fuzz\n",
"\n",
"\n",
"def identify_files_needed(text: str) -> List[str]:\n",
" lines = text.strip().split(\"\\n\")\n",
" result = set()\n",
" for line in lines:\n",
" if line.startswith(\"*** Update File: \"):\n",
" result.add(line[len(\"*** Update File: \") :])\n",
" if line.startswith(\"*** Delete File: \"):\n",
" result.add(line[len(\"*** Delete File: \") :])\n",
" return list(result)\n",
"\n",
"\n",
"def identify_files_added(text: str) -> List[str]:\n",
" lines = text.strip().split(\"\\n\")\n",
" result = set()\n",
" for line in lines:\n",
" if line.startswith(\"*** Add File: \"):\n",
" result.add(line[len(\"*** Add File: \") :])\n",
" return list(result)\n",
"\n",
"\n",
"def _get_updated_file(text: str, action: PatchAction, path: str) -> str:\n", "def _get_updated_file(text: str, action: PatchAction, path: str) -> str:\n",
" assert action.type == ActionType.UPDATE\n", " if action.type is not ActionType.UPDATE:\n",
" raise DiffError(\"_get_updated_file called with non-update action\")\n",
" orig_lines = text.split(\"\\n\")\n", " orig_lines = text.split(\"\\n\")\n",
" dest_lines = []\n", " dest_lines: List[str] = []\n",
" orig_index = 0\n", " orig_index = 0\n",
" dest_index = 0\n", "\n",
" for chunk in action.chunks:\n", " for chunk in action.chunks:\n",
" # Process the unchanged lines before the chunk\n",
" if chunk.orig_index > len(orig_lines):\n", " if chunk.orig_index > len(orig_lines):\n",
" raise DiffError(\n", " raise DiffError(\n",
" f\"_get_updated_file: {path}: chunk.orig_index {chunk.orig_index} > len(lines) {len(orig_lines)}\"\n", " f\"{path}: chunk.orig_index {chunk.orig_index} exceeds file length\"\n",
" )\n", " )\n",
" if orig_index > chunk.orig_index:\n", " if orig_index > chunk.orig_index:\n",
" raise DiffError(\n", " raise DiffError(\n",
" f\"_get_updated_file: {path}: orig_index {orig_index} > chunk.orig_index {chunk.orig_index}\"\n", " f\"{path}: overlapping chunks at {orig_index} > {chunk.orig_index}\"\n",
" )\n", " )\n",
" assert orig_index <= chunk.orig_index\n", "\n",
" dest_lines.extend(orig_lines[orig_index : chunk.orig_index])\n", " dest_lines.extend(orig_lines[orig_index : chunk.orig_index])\n",
" delta = chunk.orig_index - orig_index\n", " orig_index = chunk.orig_index\n",
" orig_index += delta\n", "\n",
" dest_index += delta\n", " dest_lines.extend(chunk.ins_lines)\n",
" # Process the inserted lines\n",
" if chunk.ins_lines:\n",
" for i in range(len(chunk.ins_lines)):\n",
" dest_lines.append(chunk.ins_lines[i])\n",
" dest_index += len(chunk.ins_lines)\n",
" orig_index += len(chunk.del_lines)\n", " orig_index += len(chunk.del_lines)\n",
" # Final part\n", "\n",
" dest_lines.extend(orig_lines[orig_index:])\n", " dest_lines.extend(orig_lines[orig_index:])\n",
" delta = len(orig_lines) - orig_index\n",
" orig_index += delta\n",
" dest_index += delta\n",
" assert orig_index == len(orig_lines)\n",
" assert dest_index == len(dest_lines)\n",
" return \"\\n\".join(dest_lines)\n", " return \"\\n\".join(dest_lines)\n",
"\n", "\n",
"\n", "\n",
"def patch_to_commit(patch: Patch, orig: Dict[str, str]) -> Commit:\n", "def patch_to_commit(patch: Patch, orig: Dict[str, str]) -> Commit:\n",
" commit = Commit()\n", " commit = Commit()\n",
" for path, action in patch.actions.items():\n", " for path, action in patch.actions.items():\n",
" if action.type == ActionType.DELETE:\n", " if action.type is ActionType.DELETE:\n",
" commit.changes[path] = FileChange(type=ActionType.DELETE, old_content=orig[path])\n", " commit.changes[path] = FileChange(\n",
" elif action.type == ActionType.ADD:\n", " type=ActionType.DELETE, old_content=orig[path]\n",
" commit.changes[path] = FileChange(type=ActionType.ADD, new_content=action.new_file)\n", " )\n",
" elif action.type == ActionType.UPDATE:\n", " elif action.type is ActionType.ADD:\n",
" new_content = _get_updated_file(text=orig[path], action=action, path=path)\n", " if action.new_file is None:\n",
" raise DiffError(\"ADD action without file content\")\n",
" commit.changes[path] = FileChange(\n",
" type=ActionType.ADD, new_content=action.new_file\n",
" )\n",
" elif action.type is ActionType.UPDATE:\n",
" new_content = _get_updated_file(orig[path], action, path)\n",
" commit.changes[path] = FileChange(\n", " commit.changes[path] = FileChange(\n",
" type=ActionType.UPDATE,\n", " type=ActionType.UPDATE,\n",
" old_content=orig[path],\n", " old_content=orig[path],\n",
@ -1001,69 +1000,122 @@
" return commit\n", " return commit\n",
"\n", "\n",
"\n", "\n",
"class DiffError(ValueError):\n", "# --------------------------------------------------------------------------- #\n",
" pass\n", "# User-facing helpers\n",
"# --------------------------------------------------------------------------- #\n",
"def text_to_patch(text: str, orig: Dict[str, str]) -> Tuple[Patch, int]:\n",
" lines = text.splitlines() # preserves blank lines, no strip()\n",
" if (\n",
" len(lines) < 2\n",
" or not Parser._norm(lines[0]).startswith(\"*** Begin Patch\")\n",
" or Parser._norm(lines[-1]) != \"*** End Patch\"\n",
" ):\n",
" raise DiffError(\"Invalid patch text - missing sentinels\")\n",
"\n",
" parser = Parser(current_files=orig, lines=lines, index=1)\n",
" parser.parse()\n",
" return parser.patch, parser.fuzz\n",
"\n", "\n",
"\n", "\n",
"def load_files(paths: List[str], open_fn: Callable) -> Dict[str, str]:\n", "def identify_files_needed(text: str) -> List[str]:\n",
" orig = {}\n", " lines = text.splitlines()\n",
" for path in paths:\n", " return [\n",
" orig[path] = open_fn(path)\n", " line[len(\"*** Update File: \") :]\n",
" return orig\n", " for line in lines\n",
" if line.startswith(\"*** Update File: \")\n",
" ] + [\n",
" line[len(\"*** Delete File: \") :]\n",
" for line in lines\n",
" if line.startswith(\"*** Delete File: \")\n",
" ]\n",
"\n", "\n",
"\n", "\n",
"def apply_commit(commit: Commit, write_fn: Callable, remove_fn: Callable) -> None:\n", "def identify_files_added(text: str) -> List[str]:\n",
" lines = text.splitlines()\n",
" return [\n",
" line[len(\"*** Add File: \") :]\n",
" for line in lines\n",
" if line.startswith(\"*** Add File: \")\n",
" ]\n",
"\n",
"\n",
"# --------------------------------------------------------------------------- #\n",
"# File-system helpers\n",
"# --------------------------------------------------------------------------- #\n",
"def load_files(paths: List[str], open_fn: Callable[[str], str]) -> Dict[str, str]:\n",
" return {path: open_fn(path) for path in paths}\n",
"\n",
"\n",
"def apply_commit(\n",
" commit: Commit,\n",
" write_fn: Callable[[str, str], None],\n",
" remove_fn: Callable[[str], None],\n",
") -> None:\n",
" for path, change in commit.changes.items():\n", " for path, change in commit.changes.items():\n",
" if change.type == ActionType.DELETE:\n", " if change.type is ActionType.DELETE:\n",
" remove_fn(path)\n", " remove_fn(path)\n",
" elif change.type == ActionType.ADD:\n", " elif change.type is ActionType.ADD:\n",
" if change.new_content is None:\n",
" raise DiffError(f\"ADD change for {path} has no content\")\n",
" write_fn(path, change.new_content)\n", " write_fn(path, change.new_content)\n",
" elif change.type == ActionType.UPDATE:\n", " elif change.type is ActionType.UPDATE:\n",
" if change.new_content is None:\n",
" raise DiffError(f\"UPDATE change for {path} has no new content\")\n",
" target = change.move_path or path\n",
" write_fn(target, change.new_content)\n",
" if change.move_path:\n", " if change.move_path:\n",
" write_fn(change.move_path, change.new_content)\n",
" remove_fn(path)\n", " remove_fn(path)\n",
" else:\n",
" write_fn(path, change.new_content)\n",
"\n", "\n",
"\n", "\n",
"def process_patch(text: str, open_fn: Callable, write_fn: Callable, remove_fn: Callable) -> str:\n", "def process_patch(\n",
" assert text.startswith(\"*** Begin Patch\")\n", " text: str,\n",
" open_fn: Callable[[str], str],\n",
" write_fn: Callable[[str, str], None],\n",
" remove_fn: Callable[[str], None],\n",
") -> str:\n",
" if not text.startswith(\"*** Begin Patch\"):\n",
" raise DiffError(\"Patch text must start with *** Begin Patch\")\n",
" paths = identify_files_needed(text)\n", " paths = identify_files_needed(text)\n",
" orig = load_files(paths, open_fn)\n", " orig = load_files(paths, open_fn)\n",
" patch, fuzz = text_to_patch(text, orig)\n", " patch, _fuzz = text_to_patch(text, orig)\n",
" commit = patch_to_commit(patch, orig)\n", " commit = patch_to_commit(patch, orig)\n",
" apply_commit(commit, write_fn, remove_fn)\n", " apply_commit(commit, write_fn, remove_fn)\n",
" return \"Done!\"\n", " return \"Done!\"\n",
"\n", "\n",
"\n", "\n",
"# --------------------------------------------------------------------------- #\n",
"# Default FS helpers\n",
"# --------------------------------------------------------------------------- #\n",
"def open_file(path: str) -> str:\n", "def open_file(path: str) -> str:\n",
" with open(path, \"rt\") as f:\n", " with open(path, \"rt\", encoding=\"utf-8\") as fh:\n",
" return f.read()\n", " return fh.read()\n",
"\n", "\n",
"\n", "\n",
"def write_file(path: str, content: str) -> None:\n", "def write_file(path: str, content: str) -> None:\n",
" if \"/\" in path:\n", " target = pathlib.Path(path)\n",
" parent = \"/\".join(path.split(\"/\")[:-1])\n", " target.parent.mkdir(parents=True, exist_ok=True)\n",
" os.makedirs(parent, exist_ok=True)\n", " with target.open(\"wt\", encoding=\"utf-8\") as fh:\n",
" with open(path, \"wt\") as f:\n", " fh.write(content)\n",
" f.write(content)\n",
"\n", "\n",
"\n", "\n",
"def remove_file(path: str) -> None:\n", "def remove_file(path: str) -> None:\n",
" os.remove(path)\n", " pathlib.Path(path).unlink(missing_ok=True)\n",
"\n", "\n",
"\n", "\n",
"# --------------------------------------------------------------------------- #\n",
"# CLI entry-point\n",
"# --------------------------------------------------------------------------- #\n",
"def main() -> None:\n", "def main() -> None:\n",
" import sys\n", " import sys\n",
"\n", "\n",
" patch_text = sys.stdin.read()\n", " patch_text = sys.stdin.read()\n",
" if not patch_text:\n", " if not patch_text:\n",
" print(\"Please pass patch text through stdin\")\n", " print(\"Please pass patch text through stdin\", file=sys.stderr)\n",
" return\n", " return\n",
" try:\n", " try:\n",
" result = process_patch(patch_text, open_file, write_file, remove_file)\n", " result = process_patch(patch_text, open_file, write_file, remove_file)\n",
" except DiffError as e:\n", " except DiffError as exc:\n",
" print(str(e))\n", " print(exc, file=sys.stderr)\n",
" return\n", " return\n",
" print(result)\n", " print(result)\n",
"\n", "\n",