Files
2026-04-10 15:24:42 +02:00

218 lines
7.3 KiB
Python

from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass
from typing import DefaultDict
from markdown_it import MarkdownIt
from markdown_it.token import Token
@dataclass(frozen=True)
class _LinePatch:
    """A single pending link rewrite: replace destination ``old`` with ``new``.

    Instances are immutable (``frozen=True``), so they compare by value and
    can be shared safely between the collection and rewrite passes.
    """

    # Link destination as reported by the Markdown parser (the ``href``).
    old: str
    # Destination that should be written in place of ``old``.
    new: str
class ASTMarkdownEditor:
    """Rewrite inline-link destinations in Markdown while preserving the source text.

    The document is parsed with ``markdown-it`` only to discover which links
    really exist (so text inside code blocks is never touched). The actual
    edit is then made on the raw source with a small hand-written scanner, so
    everything except the replaced destinations stays byte-identical.

    Only inline links of the form ``[text](destination "title")`` are
    rewritten. Images, autolinks and reference-style links are left alone.
    """

    # Characters treated as whitespace inside the parenthesised part of a link.
    # Line endings are included because a rewritten segment may span several
    # source lines of the same paragraph.
    _WHITESPACE = (" ", "\t", "\n", "\r")

    def __init__(self) -> None:
        self._md = MarkdownIt("commonmark")

    def replace_links(self, content: str, replacements: dict[str, str]) -> str:
        """Return ``content`` with link destinations replaced per ``replacements``.

        Args:
            content: Markdown source text.
            replacements: Mapping of current destination -> new destination.

        Returns:
            The rewritten source, or ``content`` unchanged when there is
            nothing to replace.
        """
        if not replacements:
            return content

        tokens = self._md.parse(content)
        spans = self._collect_span_patches(tokens=tokens, replacements=replacements)
        if not spans:
            return content

        # Split exactly like the parser counts lines (LF, CRLF, CR). Using
        # str.splitlines() would also split on form feeds, U+2028, etc. and
        # shift every following line index.
        lines = self._split_lines(content)
        for first, last, patches in spans:
            first = max(first, 0)
            last = min(last, len(lines))
            if first >= last:
                continue
            # A paragraph's inline content may cover several lines; rewrite the
            # whole span at once so links on any of its lines are found.
            segment = "".join(lines[first:last])
            rewritten = self._rewrite_markdown_links_in_line(segment, patches)
            # Keep the list length stable so later spans keep valid indices.
            lines[first:last] = [rewritten] + [""] * (last - first - 1)
        return "".join(lines)

    def _collect_span_patches(
        self,
        tokens: list[Token],
        replacements: dict[str, str],
    ) -> list[tuple[int, int, list[_LinePatch]]]:
        """Collect pending rewrites grouped by the source-line span of each inline token.

        Returns:
            ``(first_line, end_line, patches)`` tuples in document order, where
            ``end_line`` is exclusive. Spans that overlap are merged so every
            source line belongs to at most one span.
        """
        spans: list[tuple[int, int, list[_LinePatch]]] = []
        for token in tokens:
            if token.type != "inline" or not token.children or not token.map:
                continue

            patches: list[_LinePatch] = []
            for child in token.children:
                if child.type != "link_open":
                    continue
                href = child.attrGet("href")
                if not href:
                    continue
                new_href = replacements.get(href)
                if not new_href or new_href == href:
                    continue
                patches.append(_LinePatch(old=href, new=new_href))
            if not patches:
                continue

            first = int(token.map[0])
            # token.map is [first_line, end_line); guard against a degenerate map.
            last = max(int(token.map[1]), first + 1)
            if spans and first < spans[-1][1]:
                prev_first, prev_last, prev_patches = spans[-1]
                spans[-1] = (min(prev_first, first), max(prev_last, last), prev_patches + patches)
            else:
                spans.append((first, last, patches))
        return spans

    def _collect_line_patches(
        self,
        tokens: list[Token],
        replacements: dict[str, str],
    ) -> dict[int, list[_LinePatch]]:
        """Return pending rewrites keyed by the first source line of their block.

        Retained for backward compatibility; ``replace_links`` uses
        ``_collect_span_patches`` so that multi-line paragraphs are covered.
        """
        patches_by_line: DefaultDict[int, list[_LinePatch]] = defaultdict(list)
        for first, _last, patches in self._collect_span_patches(tokens=tokens, replacements=replacements):
            patches_by_line[first].extend(patches)
        return dict(patches_by_line)

    def _rewrite_markdown_links_in_line(self, line: str, patches: list[_LinePatch]) -> str:
        """Apply ``patches`` in order to the inline links found in ``line``.

        ``line`` may be a single source line or several joined lines of one
        block. Text inside code spans and image links is skipped.
        """
        if not patches:
            return line

        chars = list(line)
        patch_index = 0
        i = 0
        while i < len(chars) and patch_index < len(patches):
            char = chars[i]

            if char == "`" and not self._is_escaped(chars, i):
                run = self._count_char_run(chars, i, "`")
                closer = self._find_code_span_close(chars, i + run, run)
                # A backtick run without a matching closer is literal text and
                # must not hide the rest of the segment.
                i = i + run if closer is None else closer + run
                continue

            if char == "[" and not self._is_escaped(chars, i):
                if i > 0 and chars[i - 1] == "!":
                    # Image syntax: "![alt](src)" is not a link destination.
                    i += 1
                    continue
                parsed = self._parse_inline_link(chars, i)
                if parsed is None:
                    i += 1
                    continue
                start_url, end_url, parsed_url, close_index = parsed
                match = self._next_patch_index(patches, patch_index, parsed_url)
                if match is not None:
                    replacement = list(patches[match].new)
                    chars[start_url:end_url] = replacement
                    # Shift the closing-paren position by the length change.
                    close_index += len(replacement) - (end_url - start_url)
                    patch_index = match + 1
                i = close_index + 1
                continue

            i += 1
        return "".join(chars)

    def _next_patch_index(self, patches: list[_LinePatch], start: int, url: str) -> int | None:
        """Return the index of the first patch at or after ``start`` whose ``old`` equals ``url``.

        Looking ahead (instead of comparing only against ``patches[start]``)
        lets patches that have no inline ``(url)`` form in the source, such as
        autolinks or reference links, be skipped instead of blocking the rest.
        """
        for index in range(start, len(patches)):
            if patches[index].old == url:
                return index
        return None

    def _parse_inline_link(self, chars: list[str], open_bracket: int) -> tuple[int, int, str, int] | None:
        """Parse an inline link whose ``[`` is at ``open_bracket``.

        Returns:
            ``(url_start, url_end, url, close_paren)`` with ``url_end``
            exclusive, or ``None`` when the text is not an inline link.
        """
        close_bracket = self._find_link_text_end(chars, open_bracket)
        if close_bracket is None:
            return None

        # CommonMark allows no whitespace between the link text and "(".
        open_paren = close_bracket + 1
        if open_paren >= len(chars) or chars[open_paren] != "(":
            return None

        close_paren = self._find_matching_paren(chars, open_paren)
        if close_paren is None:
            return None

        dest_start = open_paren + 1
        while dest_start < close_paren and chars[dest_start] in self._WHITESPACE:
            dest_start += 1
        if dest_start >= close_paren:
            return None

        if chars[dest_start] == "<":
            # Angle-bracket destination: the URL is the text between "<" and ">".
            dest_end = dest_start + 1
            while dest_end < close_paren and chars[dest_end] != ">":
                dest_end += 1
            if dest_end >= close_paren:
                return None
            url_start = dest_start + 1
            url_end = dest_end
        else:
            url_start = dest_start
            url_end = self._scan_destination_end(chars, start=dest_start, stop=close_paren)

        if url_end <= url_start:
            return None
        parsed_url = "".join(chars[url_start:url_end])
        return url_start, url_end, parsed_url, close_paren

    def _find_link_text_end(self, chars: list[str], open_bracket: int) -> int | None:
        """Return the index of the ``]`` closing the ``[`` at ``open_bracket``, honouring nesting."""
        depth = 1
        index = open_bracket + 1
        while index < len(chars):
            char = chars[index]
            if char == "[" and not self._is_escaped(chars, index):
                depth += 1
            elif char == "]" and not self._is_escaped(chars, index):
                depth -= 1
                if depth == 0:
                    return index
            index += 1
        return None

    def _find_matching_paren(self, chars: list[str], open_paren: int) -> int | None:
        """Return the index of the ``)`` closing the ``(`` at ``open_paren``.

        Parentheses inside a quoted link title are ignored. A quote only
        starts a title when it follows whitespace, so an apostrophe or quote
        inside the destination itself does not open a title.
        """
        depth = 1
        index = open_paren + 1
        in_quote: str | None = None
        while index < len(chars):
            char = chars[index]
            if in_quote is not None:
                if char == in_quote and not self._is_escaped(chars, index):
                    in_quote = None
                index += 1
                continue
            if char in ('"', "'") and chars[index - 1] in self._WHITESPACE:
                in_quote = char
                index += 1
                continue
            if char == "(" and not self._is_escaped(chars, index):
                depth += 1
            elif char == ")" and not self._is_escaped(chars, index):
                depth -= 1
                if depth == 0:
                    return index
            index += 1
        return None

    def _scan_destination_end(self, chars: list[str], start: int, stop: int) -> int:
        """Return the exclusive end index of a bare destination starting at ``start``.

        Scanning stops at ``stop``, at top-level whitespace, or at an
        unbalanced ``)``.
        """
        depth = 0
        index = start
        while index < stop:
            char = chars[index]
            if char in self._WHITESPACE:
                if depth == 0:
                    break
            elif char == "(" and not self._is_escaped(chars, index):
                depth += 1
            elif char == ")" and not self._is_escaped(chars, index):
                if depth == 0:
                    break
                depth -= 1
            index += 1
        return index

    def _find_code_span_close(self, chars: list[str], start: int, run: int) -> int | None:
        """Return the start index of the next backtick run of exactly ``run`` ticks, or ``None``.

        Backslashes are deliberately not considered: inside a code span a
        backslash is literal and cannot escape the closing backticks.
        """
        index = start
        while index < len(chars):
            if chars[index] != "`":
                index += 1
                continue
            length = self._count_char_run(chars, index, "`")
            if length == run:
                return index
            index += length
        return None

    def _split_lines(self, content: str) -> list[str]:
        """Split ``content`` into lines, keeping terminators, on LF, CRLF and CR only."""
        lines: list[str] = []
        start = 0
        index = 0
        size = len(content)
        while index < size:
            char = content[index]
            if char == "\n" or char == "\r":
                if char == "\r" and content[index + 1 : index + 2] == "\n":
                    index += 1  # CRLF counts as a single line ending
                lines.append(content[start : index + 1])
                start = index + 1
            index += 1
        if start < size:
            lines.append(content[start:])
        return lines

    def _is_escaped(self, chars: list[str], pos: int) -> bool:
        """Return True when ``chars[pos]`` is preceded by an odd number of backslashes."""
        backslashes = 0
        index = pos - 1
        while index >= 0 and chars[index] == "\\":
            backslashes += 1
            index -= 1
        return (backslashes % 2) == 1

    def _count_char_run(self, chars: list[str], start: int, char: str) -> int:
        """Return how many consecutive ``char`` characters begin at ``start``."""
        end = start
        while end < len(chars) and chars[end] == char:
            end += 1
        return end - start