From 9f6939a23970f15acbaee7c18f1b569ea3be3f3a Mon Sep 17 00:00:00 2001 From: Noah MacCallum <171723556+nm-openai@users.noreply.github.com> Date: Mon, 14 Apr 2025 12:44:54 -0700 Subject: [PATCH] 4.1 prompting guide updates: update apply-patch code (#1772) --- authors.yaml | 7 +- examples/gpt4-1_prompting_guide.ipynb | 454 ++++++++++++++------------ 2 files changed, 259 insertions(+), 202 deletions(-) diff --git a/authors.yaml b/authors.yaml index 7d87c3e..970b0ae 100644 --- a/authors.yaml +++ b/authors.yaml @@ -281,4 +281,9 @@ vishnu-oai: nm-openai: name: "Noah MacCallum" website: "https://x.com/noahmacca" - avatar: "https://avatars.githubusercontent.com/u/171723556" \ No newline at end of file + avatar: "https://avatars.githubusercontent.com/u/171723556" + +julian-openai: + name: "Julian Lee" + website: "https://github.com/julian-openai" + avatar: "https://avatars.githubusercontent.com/u/199828632" \ No newline at end of file diff --git a/examples/gpt4-1_prompting_guide.ipynb b/examples/gpt4-1_prompting_guide.ipynb index 9dc7852..e2b7cb7 100644 --- a/examples/gpt4-1_prompting_guide.ipynb +++ b/examples/gpt4-1_prompting_guide.ipynb @@ -591,12 +591,30 @@ "outputs": [], "source": [ "#!/usr/bin/env python3\n", - "import os\n", + "\n", + "\"\"\"\n", + "A self-contained **pure-Python 3.9+** utility for applying human-readable\n", + "“pseudo-diff” patch files to a collection of text files.\n", + "\"\"\"\n", + "\n", + "from __future__ import annotations\n", + "\n", + "import pathlib\n", "from dataclasses import dataclass, field\n", "from enum import Enum\n", - "from typing import Callable, Dict, List, NoReturn, Optional, Tuple, Union\n", + "from typing import (\n", + " Callable,\n", + " Dict,\n", + " List,\n", + " Optional,\n", + " Tuple,\n", + " Union,\n", + ")\n", "\n", "\n", + "# --------------------------------------------------------------------------- #\n", + "# Domain objects\n", + "# --------------------------------------------------------------------------- #\n", "class ActionType(str, Enum):\n", " ADD = \"add\"\n", " DELETE = \"delete\"\n", @@ -616,39 +634,19 @@ " changes: Dict[str, FileChange] = field(default_factory=dict)\n", "\n", "\n", - "def assemble_changes(\n", - " orig: Dict[str, Optional[str]],\n", - " updated_files: Dict[str, Optional[str]],\n", - ") -> Commit:\n", - " commit = Commit()\n", - " for path, new_content in updated_files.items():\n", - " old_content = orig.get(path)\n", - " if old_content == new_content:\n", - " continue\n", - " if old_content is not None and new_content is not None:\n", - " commit.changes[path] = FileChange(\n", - " type=ActionType.UPDATE,\n", - " old_content=old_content,\n", - " new_content=new_content,\n", - " )\n", - " elif new_content is not None:\n", - " commit.changes[path] = FileChange(\n", - " type=ActionType.ADD,\n", - " new_content=new_content,\n", - " )\n", - " elif old_content is not None:\n", - " commit.changes[path] = FileChange(\n", - " type=ActionType.DELETE,\n", - " old_content=old_content,\n", - " )\n", - " else:\n", - " assert False\n", - " return commit\n", + "# --------------------------------------------------------------------------- #\n", + "# Exceptions\n", + "# --------------------------------------------------------------------------- #\n", + "class DiffError(ValueError):\n", + " \"\"\"Any problem detected while parsing or applying a patch.\"\"\"\n", "\n", "\n", + "# --------------------------------------------------------------------------- #\n", + "# Helper dataclasses used while parsing patches\n", + "# --------------------------------------------------------------------------- #\n", "@dataclass\n", "class Chunk:\n", - " orig_index: int = -1 # line index of the first line in the original file\n", + " orig_index: int = -1\n", " del_lines: List[str] = field(default_factory=list)\n", " ins_lines: List[str] = field(default_factory=list)\n", "\n", @@ -666,79 +664,108 @@ " actions: Dict[str, PatchAction] = field(default_factory=dict)\n", "\n", "\n", + "# --------------------------------------------------------------------------- #\n", + "# Patch text parser\n", + "# --------------------------------------------------------------------------- #\n", "@dataclass\n", "class Parser:\n", - " current_files: Dict[str, str] = field(default_factory=dict)\n", - " lines: List[str] = field(default_factory=list)\n", + " current_files: Dict[str, str]\n", + " lines: List[str]\n", " index: int = 0\n", " patch: Patch = field(default_factory=Patch)\n", " fuzz: int = 0\n", "\n", + " # ------------- low-level helpers -------------------------------------- #\n", + " def _cur_line(self) -> str:\n", + " if self.index >= len(self.lines):\n", + " raise DiffError(\"Unexpected end of input while parsing patch\")\n", + " return self.lines[self.index]\n", + "\n", + " @staticmethod\n", + " def _norm(line: str) -> str:\n", + " \"\"\"Strip CR so comparisons work for both LF and CRLF input.\"\"\"\n", + " return line.rstrip(\"\\r\")\n", + "\n", + " # ------------- scanning convenience ----------------------------------- #\n", " def is_done(self, prefixes: Optional[Tuple[str, ...]] = None) -> bool:\n", " if self.index >= len(self.lines):\n", " return True\n", - " if prefixes and self.lines[self.index].startswith(prefixes):\n", + " if (\n", + " prefixes\n", + " and len(prefixes) > 0\n", + " and self._norm(self._cur_line()).startswith(prefixes)\n", + " ):\n", " return True\n", " return False\n", "\n", " def startswith(self, prefix: Union[str, Tuple[str, ...]]) -> bool:\n", - " assert self.index < len(self.lines), f\"Index: {self.index} >= {len(self.lines)}\"\n", - " if self.lines[self.index].startswith(prefix):\n", - " return True\n", - " return False\n", + " return self._norm(self._cur_line()).startswith(prefix)\n", "\n", - " def read_str(self, prefix: str = \"\", return_everything: bool = False) -> str:\n", - " assert self.index < len(self.lines), f\"Index: {self.index} >= {len(self.lines)}\"\n", - " if self.lines[self.index].startswith(prefix):\n", - " if return_everything:\n", - " text = self.lines[self.index]\n", - " else:\n", - " text = self.lines[self.index][len(prefix) :]\n", + " def read_str(self, prefix: str) -> str:\n", + " \"\"\"\n", + " Consume the current line if it starts with *prefix* and return the text\n", + " **after** the prefix. Raises if prefix is empty.\n", + " \"\"\"\n", + " if prefix == \"\":\n", + " raise ValueError(\"read_str() requires a non-empty prefix\")\n", + " if self._norm(self._cur_line()).startswith(prefix):\n", + " text = self._cur_line()[len(prefix) :]\n", " self.index += 1\n", " return text\n", " return \"\"\n", "\n", - " def parse(self) -> NoReturn:\n", + " def read_line(self) -> str:\n", + " \"\"\"Return the current raw line and advance.\"\"\"\n", + " line = self._cur_line()\n", + " self.index += 1\n", + " return line\n", + "\n", + " # ------------- public entry point -------------------------------------- #\n", + " def parse(self) -> None:\n", " while not self.is_done((\"*** End Patch\",)):\n", + " # ---------- UPDATE ---------- #\n", " path = self.read_str(\"*** Update File: \")\n", " if path:\n", " if path in self.patch.actions:\n", - " raise DiffError(f\"Update File Error: Duplicate Path: {path}\")\n", + " raise DiffError(f\"Duplicate update for file: {path}\")\n", " move_to = self.read_str(\"*** Move to: \")\n", " if path not in self.current_files:\n", - " raise DiffError(f\"Update File Error: Missing File: {path}\")\n", + " raise DiffError(f\"Update File Error - missing file: {path}\")\n", " text = self.current_files[path]\n", - " action = self.parse_update_file(text)\n", - " action.move_path = move_to\n", + " action = self._parse_update_file(text)\n", + " action.move_path = move_to or None\n", " self.patch.actions[path] = action\n", " continue\n", + "\n", + " # ---------- DELETE ---------- #\n", " path = self.read_str(\"*** Delete File: \")\n", " if path:\n", " if path in self.patch.actions:\n", - " raise DiffError(f\"Delete File Error: Duplicate Path: {path}\")\n", + " raise DiffError(f\"Duplicate delete for file: {path}\")\n", " if path not in self.current_files:\n", - " raise DiffError(f\"Delete File Error: Missing File: {path}\")\n", - " self.patch.actions[path] = PatchAction(\n", - " type=ActionType.DELETE,\n", - " )\n", + " raise DiffError(f\"Delete File Error - missing file: {path}\")\n", + " self.patch.actions[path] = PatchAction(type=ActionType.DELETE)\n", " continue\n", + "\n", + " # ---------- ADD ---------- #\n", " path = self.read_str(\"*** Add File: \")\n", " if path:\n", " if path in self.patch.actions:\n", - " raise DiffError(f\"Add File Error: Duplicate Path: {path}\")\n", + " raise DiffError(f\"Duplicate add for file: {path}\")\n", " if path in self.current_files:\n", - " raise DiffError(f\"Add File Error: File already exists: {path}\")\n", - " self.patch.actions[path] = self.parse_add_file()\n", + " raise DiffError(f\"Add File Error - file already exists: {path}\")\n", + " self.patch.actions[path] = self._parse_add_file()\n", " continue\n", - " raise DiffError(f\"Unknown Line: {self.lines[self.index]}\")\n", - " if not self.startswith(\"*** End Patch\"):\n", - " raise DiffError(\"Missing End Patch\")\n", - " self.index += 1\n", "\n", - " def parse_update_file(self, text: str) -> PatchAction:\n", - " action = PatchAction(\n", - " type=ActionType.UPDATE,\n", - " )\n", + " raise DiffError(f\"Unknown line while parsing: {self._cur_line()}\")\n", + "\n", + " if not self.startswith(\"*** End Patch\"):\n", + " raise DiffError(\"Missing *** End Patch sentinel\")\n", + " self.index += 1 # consume sentinel\n", + "\n", + " # ------------- section parsers ---------------------------------------- #\n", + " def _parse_update_file(self, text: str) -> PatchAction:\n", + " action = PatchAction(type=ActionType.UPDATE)\n", " lines = text.split(\"\\n\")\n", " index = 0\n", " while not self.is_done(\n", @@ -752,100 +779,104 @@ " ):\n", " def_str = self.read_str(\"@@ \")\n", " section_str = \"\"\n", - " if not def_str:\n", - " if self.lines[self.index] == \"@@\":\n", - " section_str = self.lines[self.index]\n", - " self.index += 1\n", + " if not def_str and self._norm(self._cur_line()) == \"@@\":\n", + " section_str = self.read_line()\n", + "\n", " if not (def_str or section_str or index == 0):\n", - " raise DiffError(f\"Invalid Line:\\n{self.lines[self.index]}\")\n", + " raise DiffError(f\"Invalid line in update section:\\n{self._cur_line()}\")\n", + "\n", " if def_str.strip():\n", " found = False\n", - " if not [s for s in lines[:index] if s == def_str]:\n", - " # def str is a skip ahead operator\n", + " if def_str not in lines[:index]:\n", " for i, s in enumerate(lines[index:], index):\n", " if s == def_str:\n", " index = i + 1\n", " found = True\n", " break\n", - " if not found and not [s for s in lines[:index] if s.strip() == def_str.strip()]:\n", - " # def str is a skip ahead operator\n", + " if not found and def_str.strip() not in [\n", + " s.strip() for s in lines[:index]\n", + " ]:\n", " for i, s in enumerate(lines[index:], index):\n", " if s.strip() == def_str.strip():\n", " index = i + 1\n", " self.fuzz += 1\n", " found = True\n", " break\n", - " next_chunk_context, chunks, end_patch_index, eof = peek_next_section(\n", - " self.lines, self.index\n", - " )\n", - " next_chunk_text = \"\\n\".join(next_chunk_context)\n", - " new_index, fuzz = find_context(lines, next_chunk_context, index, eof)\n", + "\n", + " next_ctx, chunks, end_idx, eof = peek_next_section(self.lines, self.index)\n", + " new_index, fuzz = find_context(lines, next_ctx, index, eof)\n", " if new_index == -1:\n", - " if eof:\n", - " raise DiffError(f\"Invalid EOF Context {index}:\\n{next_chunk_text}\")\n", - " else:\n", - " raise DiffError(f\"Invalid Context {index}:\\n{next_chunk_text}\")\n", + " ctx_txt = \"\\n\".join(next_ctx)\n", + " raise DiffError(\n", + " f\"Invalid {'EOF ' if eof else ''}context at {index}:\\n{ctx_txt}\"\n", + " )\n", " self.fuzz += fuzz\n", " for ch in chunks:\n", " ch.orig_index += new_index\n", " action.chunks.append(ch)\n", - " index = new_index + len(next_chunk_context)\n", - " self.index = end_patch_index\n", - " continue\n", + " index = new_index + len(next_ctx)\n", + " self.index = end_idx\n", " return action\n", "\n", - " def parse_add_file(self) -> PatchAction:\n", - " lines = []\n", + " def _parse_add_file(self) -> PatchAction:\n", + " lines: List[str] = []\n", " while not self.is_done(\n", " (\"*** End Patch\", \"*** Update File:\", \"*** Delete File:\", \"*** Add File:\")\n", " ):\n", - " s = self.read_str()\n", + " s = self.read_line()\n", " if not s.startswith(\"+\"):\n", - " raise DiffError(f\"Invalid Add File Line: {s}\")\n", - " s = s[1:]\n", - " lines.append(s)\n", - " return PatchAction(\n", - " type=ActionType.ADD,\n", - " new_file=\"\\n\".join(lines),\n", - " )\n", + " raise DiffError(f\"Invalid Add File line (missing '+'): {s}\")\n", + " lines.append(s[1:]) # strip leading '+'\n", + " return PatchAction(type=ActionType.ADD, new_file=\"\\n\".join(lines))\n", "\n", "\n", - "def find_context_core(lines: List[str], context: List[str], start: int) -> Tuple[int, int]:\n", + "# --------------------------------------------------------------------------- #\n", + "# Helper functions\n", + "# --------------------------------------------------------------------------- #\n", + "def find_context_core(\n", + " lines: List[str], context: List[str], start: int\n", + ") -> Tuple[int, int]:\n", " if not context:\n", " return start, 0\n", "\n", - " # Prefer identical\n", " for i in range(start, len(lines)):\n", " if lines[i : i + len(context)] == context:\n", " return i, 0\n", - " # RStrip is ok\n", " for i in range(start, len(lines)):\n", - " if [s.rstrip() for s in lines[i : i + len(context)]] == [s.rstrip() for s in context]:\n", + " if [s.rstrip() for s in lines[i : i + len(context)]] == [\n", + " s.rstrip() for s in context\n", + " ]:\n", " return i, 1\n", - " # Fine, Strip is ok too.\n", " for i in range(start, len(lines)):\n", - " if [s.strip() for s in lines[i : i + len(context)]] == [s.strip() for s in context]:\n", + " if [s.strip() for s in lines[i : i + len(context)]] == [\n", + " s.strip() for s in context\n", + " ]:\n", " return i, 100\n", " return -1, 0\n", "\n", "\n", - "def find_context(lines: List[str], context: List[str], start: int, eof: bool) -> Tuple[int, int]:\n", + "def find_context(\n", + " lines: List[str], context: List[str], start: int, eof: bool\n", + ") -> Tuple[int, int]:\n", " if eof:\n", " new_index, fuzz = find_context_core(lines, context, len(lines) - len(context))\n", " if new_index != -1:\n", " return new_index, fuzz\n", " new_index, fuzz = find_context_core(lines, context, start)\n", - " return new_index, fuzz + 10000\n", + " return new_index, fuzz + 10_000\n", " return find_context_core(lines, context, start)\n", "\n", "\n", - "def peek_next_section(lines: List[str], index: int) -> Tuple[List[str], List[Chunk], int, bool]:\n", + "def peek_next_section(\n", + " lines: List[str], index: int\n", + ") -> Tuple[List[str], List[Chunk], int, bool]:\n", " old: List[str] = []\n", " del_lines: List[str] = []\n", " ins_lines: List[str] = []\n", " chunks: List[Chunk] = []\n", " mode = \"keep\"\n", " orig_index = index\n", + "\n", " while index < len(lines):\n", " s = lines[index]\n", " if s.startswith(\n", @@ -861,9 +892,10 @@ " break\n", " if s == \"***\":\n", " break\n", - " elif s.startswith(\"***\"):\n", + " if s.startswith(\"***\"):\n", " raise DiffError(f\"Invalid Line: {s}\")\n", " index += 1\n", + "\n", " last_mode = mode\n", " if s == \"\":\n", " s = \" \"\n", @@ -876,6 +908,7 @@ " else:\n", " raise DiffError(f\"Invalid Line: {s}\")\n", " s = s[1:]\n", + "\n", " if mode == \"keep\" and last_mode != mode:\n", " if ins_lines or del_lines:\n", " chunks.append(\n", @@ -885,8 +918,8 @@ " ins_lines=ins_lines,\n", " )\n", " )\n", - " del_lines = []\n", - " ins_lines = []\n", + " del_lines, ins_lines = [], []\n", + "\n", " if mode == \"delete\":\n", " del_lines.append(s)\n", " old.append(s)\n", @@ -894,6 +927,7 @@ " ins_lines.append(s)\n", " elif mode == \"keep\":\n", " old.append(s)\n", + "\n", " if ins_lines or del_lines:\n", " chunks.append(\n", " Chunk(\n", @@ -902,96 +936,61 @@ " ins_lines=ins_lines,\n", " )\n", " )\n", - " del_lines = []\n", - " ins_lines = []\n", + "\n", " if index < len(lines) and lines[index] == \"*** End of File\":\n", " index += 1\n", " return old, chunks, index, True\n", + "\n", " if index == orig_index:\n", - " raise DiffError(f\"Nothing in this section - {index=} {lines[index]}\")\n", + " raise DiffError(\"Nothing in this section\")\n", " return old, chunks, index, False\n", "\n", "\n", - "def text_to_patch(text: str, orig: Dict[str, str]) -> Tuple[Patch, int]:\n", - " lines = text.strip().split(\"\\n\")\n", - " if len(lines) < 2 or not lines[0].startswith(\"*** Begin Patch\") or lines[-1] != \"*** End Patch\":\n", - " raise DiffError(\"Invalid patch text\")\n", - "\n", - " parser = Parser(\n", - " current_files=orig,\n", - " lines=lines,\n", - " index=1,\n", - " )\n", - " parser.parse()\n", - " return parser.patch, parser.fuzz\n", - "\n", - "\n", - "def identify_files_needed(text: str) -> List[str]:\n", - " lines = text.strip().split(\"\\n\")\n", - " result = set()\n", - " for line in lines:\n", - " if line.startswith(\"*** Update File: \"):\n", - " result.add(line[len(\"*** Update File: \") :])\n", - " if line.startswith(\"*** Delete File: \"):\n", - " result.add(line[len(\"*** Delete File: \") :])\n", - " return list(result)\n", - "\n", - "\n", - "def identify_files_added(text: str) -> List[str]:\n", - " lines = text.strip().split(\"\\n\")\n", - " result = set()\n", - " for line in lines:\n", - " if line.startswith(\"*** Add File: \"):\n", - " result.add(line[len(\"*** Add File: \") :])\n", - " return list(result)\n", - "\n", - "\n", + "# --------------------------------------------------------------------------- #\n", + "# Patch → Commit and Commit application\n", + "# --------------------------------------------------------------------------- #\n", "def _get_updated_file(text: str, action: PatchAction, path: str) -> str:\n", - " assert action.type == ActionType.UPDATE\n", + " if action.type is not ActionType.UPDATE:\n", + " raise DiffError(\"_get_updated_file called with non-update action\")\n", " orig_lines = text.split(\"\\n\")\n", - " dest_lines = []\n", + " dest_lines: List[str] = []\n", " orig_index = 0\n", - " dest_index = 0\n", + "\n", " for chunk in action.chunks:\n", - " # Process the unchanged lines before the chunk\n", " if chunk.orig_index > len(orig_lines):\n", " raise DiffError(\n", - " f\"_get_updated_file: {path}: chunk.orig_index {chunk.orig_index} > len(lines) {len(orig_lines)}\"\n", + " f\"{path}: chunk.orig_index {chunk.orig_index} exceeds file length\"\n", " )\n", " if orig_index > chunk.orig_index:\n", " raise DiffError(\n", - " f\"_get_updated_file: {path}: orig_index {orig_index} > chunk.orig_index {chunk.orig_index}\"\n", + " f\"{path}: overlapping chunks at {orig_index} > {chunk.orig_index}\"\n", " )\n", - " assert orig_index <= chunk.orig_index\n", + "\n", " dest_lines.extend(orig_lines[orig_index : chunk.orig_index])\n", - " delta = chunk.orig_index - orig_index\n", - " orig_index += delta\n", - " dest_index += delta\n", - " # Process the inserted lines\n", - " if chunk.ins_lines:\n", - " for i in range(len(chunk.ins_lines)):\n", - " dest_lines.append(chunk.ins_lines[i])\n", - " dest_index += len(chunk.ins_lines)\n", + " orig_index = chunk.orig_index\n", + "\n", + " dest_lines.extend(chunk.ins_lines)\n", " orig_index += len(chunk.del_lines)\n", - " # Final part\n", + "\n", " dest_lines.extend(orig_lines[orig_index:])\n", - " delta = len(orig_lines) - orig_index\n", - " orig_index += delta\n", - " dest_index += delta\n", - " assert orig_index == len(orig_lines)\n", - " assert dest_index == len(dest_lines)\n", " return \"\\n\".join(dest_lines)\n", "\n", "\n", "def patch_to_commit(patch: Patch, orig: Dict[str, str]) -> Commit:\n", " commit = Commit()\n", " for path, action in patch.actions.items():\n", - " if action.type == ActionType.DELETE:\n", - " commit.changes[path] = FileChange(type=ActionType.DELETE, old_content=orig[path])\n", - " elif action.type == ActionType.ADD:\n", - " commit.changes[path] = FileChange(type=ActionType.ADD, new_content=action.new_file)\n", - " elif action.type == ActionType.UPDATE:\n", - " new_content = _get_updated_file(text=orig[path], action=action, path=path)\n", + " if action.type is ActionType.DELETE:\n", + " commit.changes[path] = FileChange(\n", + " type=ActionType.DELETE, old_content=orig[path]\n", + " )\n", + " elif action.type is ActionType.ADD:\n", + " if action.new_file is None:\n", + " raise DiffError(\"ADD action without file content\")\n", + " commit.changes[path] = FileChange(\n", + " type=ActionType.ADD, new_content=action.new_file\n", + " )\n", + " elif action.type is ActionType.UPDATE:\n", + " new_content = _get_updated_file(orig[path], action, path)\n", " commit.changes[path] = FileChange(\n", " type=ActionType.UPDATE,\n", " old_content=orig[path],\n", @@ -1001,69 +1000,122 @@ " return commit\n", "\n", "\n", - "class DiffError(ValueError):\n", - " pass\n", + "# --------------------------------------------------------------------------- #\n", + "# User-facing helpers\n", + "# --------------------------------------------------------------------------- #\n", + "def text_to_patch(text: str, orig: Dict[str, str]) -> Tuple[Patch, int]:\n", + " lines = text.splitlines() # preserves blank lines, no strip()\n", + " if (\n", + " len(lines) < 2\n", + " or not Parser._norm(lines[0]).startswith(\"*** Begin Patch\")\n", + " or Parser._norm(lines[-1]) != \"*** End Patch\"\n", + " ):\n", + " raise DiffError(\"Invalid patch text - missing sentinels\")\n", + "\n", + " parser = Parser(current_files=orig, lines=lines, index=1)\n", + " parser.parse()\n", + " return parser.patch, parser.fuzz\n", "\n", "\n", - "def load_files(paths: List[str], open_fn: Callable) -> Dict[str, str]:\n", - " orig = {}\n", - " for path in paths:\n", - " orig[path] = open_fn(path)\n", - " return orig\n", + "def identify_files_needed(text: str) -> List[str]:\n", + " lines = text.splitlines()\n", + " return [\n", + " line[len(\"*** Update File: \") :]\n", + " for line in lines\n", + " if line.startswith(\"*** Update File: \")\n", + " ] + [\n", + " line[len(\"*** Delete File: \") :]\n", + " for line in lines\n", + " if line.startswith(\"*** Delete File: \")\n", + " ]\n", "\n", "\n", - "def apply_commit(commit: Commit, write_fn: Callable, remove_fn: Callable) -> None:\n", + "def identify_files_added(text: str) -> List[str]:\n", + " lines = text.splitlines()\n", + " return [\n", + " line[len(\"*** Add File: \") :]\n", + " for line in lines\n", + " if line.startswith(\"*** Add File: \")\n", + " ]\n", + "\n", + "\n", + "# --------------------------------------------------------------------------- #\n", + "# File-system helpers\n", + "# --------------------------------------------------------------------------- #\n", + "def load_files(paths: List[str], open_fn: Callable[[str], str]) -> Dict[str, str]:\n", + " return {path: open_fn(path) for path in paths}\n", + "\n", + "\n", + "def apply_commit(\n", + " commit: Commit,\n", + " write_fn: Callable[[str, str], None],\n", + " remove_fn: Callable[[str], None],\n", + ") -> None:\n", " for path, change in commit.changes.items():\n", - " if change.type == ActionType.DELETE:\n", + " if change.type is ActionType.DELETE:\n", " remove_fn(path)\n", - " elif change.type == ActionType.ADD:\n", + " elif change.type is ActionType.ADD:\n", + " if change.new_content is None:\n", + " raise DiffError(f\"ADD change for {path} has no content\")\n", " write_fn(path, change.new_content)\n", - " elif change.type == ActionType.UPDATE:\n", + " elif change.type is ActionType.UPDATE:\n", + " if change.new_content is None:\n", + " raise DiffError(f\"UPDATE change for {path} has no new content\")\n", + " target = change.move_path or path\n", + " write_fn(target, change.new_content)\n", " if change.move_path:\n", - " write_fn(change.move_path, change.new_content)\n", " remove_fn(path)\n", - " else:\n", - " write_fn(path, change.new_content)\n", "\n", "\n", - "def process_patch(text: str, open_fn: Callable, write_fn: Callable, remove_fn: Callable) -> str:\n", - " assert text.startswith(\"*** Begin Patch\")\n", + "def process_patch(\n", + " text: str,\n", + " open_fn: Callable[[str], str],\n", + " write_fn: Callable[[str, str], None],\n", + " remove_fn: Callable[[str], None],\n", + ") -> str:\n", + " if not text.startswith(\"*** Begin Patch\"):\n", + " raise DiffError(\"Patch text must start with *** Begin Patch\")\n", " paths = identify_files_needed(text)\n", " orig = load_files(paths, open_fn)\n", - " patch, fuzz = text_to_patch(text, orig)\n", + " patch, _fuzz = text_to_patch(text, orig)\n", " commit = patch_to_commit(patch, orig)\n", " apply_commit(commit, write_fn, remove_fn)\n", " return \"Done!\"\n", "\n", "\n", + "# --------------------------------------------------------------------------- #\n", + "# Default FS helpers\n", + "# --------------------------------------------------------------------------- #\n", "def open_file(path: str) -> str:\n", - " with open(path, \"rt\") as f:\n", - " return f.read()\n", + " with open(path, \"rt\", encoding=\"utf-8\") as fh:\n", + " return fh.read()\n", "\n", "\n", "def write_file(path: str, content: str) -> None:\n", - " if \"/\" in path:\n", - " parent = \"/\".join(path.split(\"/\")[:-1])\n", - " os.makedirs(parent, exist_ok=True)\n", - " with open(path, \"wt\") as f:\n", - " f.write(content)\n", + " target = pathlib.Path(path)\n", + " target.parent.mkdir(parents=True, exist_ok=True)\n", + " with target.open(\"wt\", encoding=\"utf-8\") as fh:\n", + " fh.write(content)\n", "\n", "\n", "def remove_file(path: str) -> None:\n", - " os.remove(path)\n", + " pathlib.Path(path).unlink(missing_ok=True)\n", "\n", "\n", + "# --------------------------------------------------------------------------- #\n", + "# CLI entry-point\n", + "# --------------------------------------------------------------------------- #\n", "def main() -> None:\n", " import sys\n", "\n", " patch_text = sys.stdin.read()\n", " if not patch_text:\n", - " print(\"Please pass patch text through stdin\")\n", + " print(\"Please pass patch text through stdin\", file=sys.stderr)\n", " return\n", " try:\n", " result = process_patch(patch_text, open_file, write_file, remove_file)\n", - " except DiffError as e:\n", - " print(str(e))\n", + " except DiffError as exc:\n", + " print(exc, file=sys.stderr)\n", " return\n", " print(result)\n", "\n",