"""AST-assisted Markdown link rewriting: parse with markdown-it to find links,
then patch the affected source lines with a formatting-preserving scanner."""
from __future__ import annotations
|
|
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from typing import DefaultDict
|
|
|
|
from markdown_it import MarkdownIt
|
|
from markdown_it.token import Token
|
|
|
|
|
|
@dataclass(frozen=True)
class _LinePatch:
    """One pending link rewrite within a single source line.

    Pairs the URL text currently present in the line with the URL that
    should replace it. Frozen so instances are immutable value objects
    (hashable, safe to share across lines).
    """

    # The href exactly as it should be found in the Markdown source line.
    old: str
    # The replacement href to splice into the line in place of ``old``.
    new: str
|
|
|
|
|
|
class ASTMarkdownEditor:
    """Rewrite inline Markdown link destinations using an AST-guided scan.

    The markdown-it AST is used only to discover which source lines contain
    links whose href has a requested replacement; the affected lines are then
    patched with a small character-level scanner so surrounding formatting
    (code spans, escapes, images, link titles) is preserved byte-for-byte.
    """

    def __init__(self) -> None:
        # CommonMark-only parser; no extensions are needed to locate links.
        self._md = MarkdownIt("commonmark")

    def replace_links(self, content: str, replacements: dict[str, str]) -> str:
        """Return *content* with inline link URLs rewritten per *replacements*.

        Args:
            content: Full Markdown document text.
            replacements: Mapping of current href -> new href. Hrefs are
                compared against markdown-it's parsed ``href`` attribute, so
                they should be supplied in the same (normalized) form.

        Returns:
            The patched document, or *content* unchanged if nothing matched.
        """
        if not replacements:
            return content

        tokens = self._md.parse(content)
        patches_by_line = self._collect_line_patches(tokens=tokens, replacements=replacements)
        if not patches_by_line:
            return content

        lines = content.splitlines(keepends=True)
        for line_index, patches in patches_by_line.items():
            # Guard against token maps pointing outside the line buffer.
            if line_index < 0 or line_index >= len(lines):
                continue
            lines[line_index] = self._rewrite_markdown_links_in_line(lines[line_index], patches)
        return "".join(lines)

    def _collect_line_patches(
        self,
        tokens: list[Token],
        replacements: dict[str, str],
    ) -> dict[int, list[_LinePatch]]:
        """Map each potentially affected source line to its link patches.

        Bug fix: an ``inline`` token for a soft-wrapped paragraph spans
        several source lines but carries a single ``map``. Previously every
        link was attributed to ``map[0]``, so links on continuation lines
        were silently never rewritten. Patches are now registered for every
        line in the token's range and matched by URL in the rewriter, which
        makes the attribution safe (non-matching lines are left untouched).
        """
        patches_by_line: DefaultDict[int, list[_LinePatch]] = defaultdict(list)
        for token in tokens:
            if token.type != "inline" or not token.children:
                continue
            if not token.map:
                continue
            first_line = int(token.map[0])
            last_line = int(token.map[1])

            token_patches: list[_LinePatch] = []
            seen: set[str] = set()
            for child in token.children:
                if child.type != "link_open":
                    continue
                href = child.attrGet("href")
                if not href or href in seen:
                    continue
                new_href = replacements.get(href)
                if not new_href or new_href == href:
                    continue
                seen.add(href)
                token_patches.append(_LinePatch(old=href, new=new_href))

            if not token_patches:
                continue
            # The link may sit on any line of the paragraph; let each line try.
            for line_index in range(first_line, max(last_line, first_line + 1)):
                patches_by_line[line_index].extend(token_patches)
        return dict(patches_by_line)

    def _rewrite_markdown_links_in_line(self, line: str, patches: list[_LinePatch]) -> str:
        """Rewrite inline-link URLs in a single source line.

        Links are matched against *patches* by URL (order-independent), so a
        non-matching link earlier in the line no longer blocks later patches.
        Inline code spans and backslash escapes are respected; image syntax
        (``![...]``) is skipped.
        """
        if not patches:
            return line

        mapping = {patch.old: patch.new for patch in patches}
        chars = list(line)
        i = 0
        in_code = False
        code_ticks = 0

        while i < len(chars):
            char = chars[i]

            # Track inline code spans: a span closes only on a backtick run
            # of the same length that opened it.
            if char == "`" and not self._is_escaped(chars, i):
                run = self._count_char_run(chars, i, "`")
                if not in_code:
                    in_code = True
                    code_ticks = run
                elif run == code_ticks:
                    in_code = False
                    code_ticks = 0
                i += run
                continue

            if in_code:
                i += 1
                continue

            if char == "[" and not self._is_escaped(chars, i):
                # Skip images — but an escaped "!" is literal text, so
                # ``\![x](u)`` is still a real link (previously mis-skipped).
                if i > 0 and chars[i - 1] == "!" and not self._is_escaped(chars, i - 1):
                    i += 1
                    continue
                parsed = self._parse_inline_link(chars, i)
                if parsed is None:
                    i += 1
                    continue
                start_url, end_url, parsed_url, close_index = parsed
                new_url = mapping.get(parsed_url)
                if new_url is not None:
                    replacement = list(new_url)
                    chars[start_url:end_url] = replacement
                    # Splicing shifts every index after the URL.
                    close_index += len(replacement) - (end_url - start_url)
                # Resume after the link either way: its text must not be
                # rescanned as if it were top-level Markdown.
                i = close_index + 1
                continue

            i += 1

        return "".join(chars)

    def _parse_inline_link(self, chars: list[str], open_bracket: int) -> tuple[int, int, str, int] | None:
        """Parse ``[text](dest ...)`` starting at *open_bracket*.

        Returns ``(url_start, url_end, url, close_paren_index)``, or ``None``
        when the text at *open_bracket* is not an inline link.

        NOTE(review): whitespace is tolerated between ``]`` and ``(`` —
        stricter than-CommonMark callers should confirm this is intended.
        """
        close_bracket = self._find_link_text_end(chars, open_bracket)
        if close_bracket is None:
            return None

        cursor = close_bracket + 1
        while cursor < len(chars) and chars[cursor] in (" ", "\t"):
            cursor += 1
        if cursor >= len(chars) or chars[cursor] != "(":
            return None

        close_paren = self._find_matching_paren(chars, cursor)
        if close_paren is None:
            return None

        dest_start = cursor + 1
        while dest_start < close_paren and chars[dest_start] in (" ", "\t"):
            dest_start += 1
        if dest_start >= close_paren:
            return None

        if chars[dest_start] == "<":
            # Angle-bracketed destination: URL runs to the matching ">".
            dest_end = dest_start + 1
            while dest_end < close_paren and chars[dest_end] != ">":
                dest_end += 1
            if dest_end >= close_paren:
                return None
            url_start = dest_start + 1
            url_end = dest_end
        else:
            url_start = dest_start
            url_end = self._scan_destination_end(chars, start=dest_start, stop=close_paren)
            if url_end <= url_start:
                return None

        parsed_url = "".join(chars[url_start:url_end])
        return url_start, url_end, parsed_url, close_paren

    def _find_link_text_end(self, chars: list[str], open_bracket: int) -> int | None:
        """Index of the ``]`` balancing *open_bracket*, or ``None``."""
        depth = 1
        index = open_bracket + 1
        while index < len(chars):
            char = chars[index]
            if char == "[" and not self._is_escaped(chars, index):
                depth += 1
            elif char == "]" and not self._is_escaped(chars, index):
                depth -= 1
                if depth == 0:
                    return index
            index += 1
        return None

    def _find_matching_paren(self, chars: list[str], open_paren: int) -> int | None:
        """Index of the ``)`` balancing *open_paren*, or ``None``.

        Parentheses inside a quoted link title (``'...'`` or ``"..."``) are
        ignored while a quote is open.
        """
        depth = 1
        index = open_paren + 1
        in_quote: str | None = None
        while index < len(chars):
            char = chars[index]
            if in_quote is not None:
                if char == in_quote and not self._is_escaped(chars, index):
                    in_quote = None
                index += 1
                continue
            if char in ('"', "'") and not self._is_escaped(chars, index):
                in_quote = char
                index += 1
                continue
            if char == "(" and not self._is_escaped(chars, index):
                depth += 1
            elif char == ")" and not self._is_escaped(chars, index):
                depth -= 1
                if depth == 0:
                    return index
            index += 1
        return None

    def _scan_destination_end(self, chars: list[str], start: int, stop: int) -> int:
        """End index (exclusive) of a bare (unbracketed) link destination.

        The destination ends at the first top-level whitespace (which would
        begin a title) or at an unbalanced ``)``; balanced parentheses inside
        the URL are allowed, as in CommonMark.
        """
        depth = 0
        index = start
        while index < stop:
            char = chars[index]
            if char in (" ", "\t"):
                if depth == 0:
                    break
            elif char == "(" and not self._is_escaped(chars, index):
                depth += 1
            elif char == ")" and not self._is_escaped(chars, index):
                if depth == 0:
                    break
                depth -= 1
            index += 1
        return index

    def _is_escaped(self, chars: list[str], pos: int) -> bool:
        """True when ``chars[pos]`` is preceded by an odd run of backslashes."""
        backslashes = 0
        index = pos - 1
        while index >= 0 and chars[index] == "\\":
            backslashes += 1
            index -= 1
        return (backslashes % 2) == 1

    def _count_char_run(self, chars: list[str], start: int, char: str) -> int:
        """Length of the run of *char* beginning at index *start*."""
        end = start
        while end < len(chars) and chars[end] == char:
            end += 1
        return end - start