|
| 1 | + |
| 2 | +""" |
| 3 | +trafilatura extractor implementation. |
| 4 | +""" |
| 5 | +from typing import Dict, Any, Optional, List |
| 6 | +from dataclasses import dataclass |
| 7 | +from .base import BaseExtractor, ExtractionResult |
| 8 | +from .factory import extractor |
| 9 | +from trafilatura import extract,html2txt,baseline |
| 10 | +import re |
| 11 | + |
| 12 | + |
@dataclass
class TrafilaturaInferenceConfig:
    """Configuration options for the Trafilatura extractor.

    Field names mirror the keyword arguments of ``trafilatura.extract``;
    user-supplied config values overlay these defaults in
    ``TrafilaturaExtractor.__init__``.
    """
    # NOTE(review): trafilatura documents favor_precision and favor_recall as
    # mutually exclusive; enabling both likely lets one silently win — confirm
    # intended behavior before relying on these defaults.
    favor_precision: bool = True     # prefer core content only; drop more boilerplate (sidebars, ads)
    favor_recall: bool = True        # prefer extracting all potentially valid content; fewer omissions
    include_comments: bool = False   # keep user-comment sections (off by default)
    include_tables: bool = True      # keep extracted HTML tables (on by default)
    include_images: bool = False     # keep image information (off by default)
    include_links: bool = False      # keep hyperlinks (off by default)
    with_metadata: bool = False      # keep page metadata (off by default)
    skip_elements: bool = False      # keep CSS-hidden elements (off by default)
    output_format: str = "markdown"  # one of: "csv", "json", "html", "markdown", "txt", "xml"
| 26 | + |
@extractor("trafilatura_txt")
class TrafilaturaExtractor(BaseExtractor):
    """Content extractor backed by the trafilatura library.

    Extraction currently uses ``trafilatura.baseline`` for a fast,
    robust plain-text result. The options held in
    ``TrafilaturaInferenceConfig`` are stored on the instance but are
    only consumed if extraction is switched to ``trafilatura.extract``
    (which accepts them as keyword arguments).
    """

    version = "2.0.0"
    description = "Trafilatura based content extractor"

    def __init__(self, name: str, config: Optional[Dict[str, Any]] = None):
        """Initialize the extractor and overlay user-supplied options.

        Args:
            name: Extractor instance name (forwarded to BaseExtractor).
            config: Optional mapping; keys that match
                TrafilaturaInferenceConfig fields override the defaults,
                unknown keys are silently ignored.
        """
        super().__init__(name, config)
        self.inference_config = TrafilaturaInferenceConfig()

        # Apply user configuration, accepting only known config fields.
        if config:
            for key, value in config.items():
                if hasattr(self.inference_config, key):
                    setattr(self.inference_config, key, value)

    def _setup(self) -> None:
        """Set up the Trafilatura extractor (no initialization required)."""

    def _extract_content(self, html: str, url: Optional[str] = None) -> ExtractionResult:
        """Extract the main text content from *html*.

        Args:
            html: HTML content to extract from.
            url: Optional URL of the page (unused by the baseline
                extraction path).

        Returns:
            ExtractionResult with the extracted text, a best-effort
            title and detected language on success; an error result if
            extraction raises.
        """
        try:
            # baseline() favors accuracy of the plain-text result over the
            # configurable extract() pipeline; it returns a tuple of
            # (post_body_element, text, text_length) — only the text is used.
            _postbody, content, _len_text = baseline(html)

            return ExtractionResult(
                content=content,
                title=self._extract_title(html),
                language=self._detect_language(content),
                success=True,
            )

        except Exception as e:
            return ExtractionResult.create_error_result(
                f"Trafilatura extraction failed: {str(e)}"
            )

    def _extract_title(self, html: str) -> Optional[str]:
        """Return the contents of the first <title> tag, or None."""
        try:
            title_match = re.search(r'<title[^>]*>(.*?)</title>', html,
                                    re.IGNORECASE | re.DOTALL)
            if title_match:
                return title_match.group(1).strip()
        except TypeError:
            # html was not a str/bytes; treat as "no title" instead of raising.
            pass
        return None

    def _detect_language(self, content: str) -> Optional[str]:
        """Crudely detect the content language by character counts.

        Returns "zh" when CJK characters outnumber ASCII letters, "en"
        when any ASCII letters are present, otherwise None.
        """
        if not content:
            return None

        chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', content))
        english_chars = len(re.findall(r'[a-zA-Z]', content))

        if chinese_chars > english_chars:
            return "zh"
        if english_chars > 0:
            return "en"
        return None
0 commit comments