From a4f4730cf5b88f196ba24f8860dc3e742e71ffc1 Mon Sep 17 00:00:00 2001 From: Shannon Shen <22512825+lolipopshock@users.noreply.github.com> Date: Wed, 8 Sep 2021 16:40:54 -0400 Subject: [PATCH] re-org the ocr utils --- src/layoutparser/ocr/__init__.py | 2 + src/layoutparser/ocr/base.py | 67 +++++ src/layoutparser/{ocr.py => ocr/gcv_agent.py} | 240 +----------------- src/layoutparser/ocr/tesseract_agent.py | 169 ++++++++++++ 4 files changed, 241 insertions(+), 237 deletions(-) create mode 100644 src/layoutparser/ocr/__init__.py create mode 100644 src/layoutparser/ocr/base.py rename src/layoutparser/{ocr.py => ocr/gcv_agent.py} (55%) create mode 100644 src/layoutparser/ocr/tesseract_agent.py diff --git a/src/layoutparser/ocr/__init__.py b/src/layoutparser/ocr/__init__.py new file mode 100644 index 0000000..4a45839 --- /dev/null +++ b/src/layoutparser/ocr/__init__.py @@ -0,0 +1,2 @@ +from .gcv_agent import GCVAgent, GCVFeatureType +from .tesseract_agent import TesseractAgent, TesseractFeatureType \ No newline at end of file diff --git a/src/layoutparser/ocr/base.py b/src/layoutparser/ocr/base.py new file mode 100644 index 0000000..e4e9569 --- /dev/null +++ b/src/layoutparser/ocr/base.py @@ -0,0 +1,67 @@ +from abc import ABC, abstractmethod +from enum import IntEnum +import importlib + + +class BaseOCRElementType(IntEnum): + @property + @abstractmethod + def attr_name(self): + pass + + +class BaseOCRAgent(ABC): + @property + @abstractmethod + def DEPENDENCIES(self): + """DEPENDENCIES lists all necessary dependencies for the class.""" + pass + + @property + @abstractmethod + def MODULES(self): + """MODULES instructs how to import these necessary libraries. + + Note: + Sometimes a python module have different installation name and module name (e.g., + `pip install tensorflow-gpu` when installing and `import tensorflow` when using + ). And sometimes we only need to import a submodule but not whole module. MODULES + is designed for this purpose. + + Returns: + :obj: list(dict): A list of dict indicate how the model is imported. + + Example:: + + [{ + "import_name": "_vision", + "module_path": "google.cloud.vision" + }] + + is equivalent to self._vision = importlib.import_module("google.cloud.vision") + """ + pass + + @classmethod + def _import_module(cls): + for m in cls.MODULES: + if importlib.util.find_spec(m["module_path"]): + setattr( + cls, m["import_name"], importlib.import_module(m["module_path"]) + ) + else: + raise ModuleNotFoundError( + f"\n " + f"\nPlease install the following libraries to support the class {cls.__name__}:" + f"\n pip install {' '.join(cls.DEPENDENCIES)}" + f"\n " + ) + + def __new__(cls, *args, **kwargs): + + cls._import_module() + return super().__new__(cls) + + @abstractmethod + def detect(self, image): + pass diff --git a/src/layoutparser/ocr.py b/src/layoutparser/ocr/gcv_agent.py similarity index 55% rename from src/layoutparser/ocr.py rename to src/layoutparser/ocr/gcv_agent.py index 65f5ccc..cc0b016 100644 --- a/src/layoutparser/ocr.py +++ b/src/layoutparser/ocr/gcv_agent.py @@ -1,91 +1,19 @@ -from abc import ABC, abstractmethod -from enum import IntEnum -import importlib import io import os import json -import csv import warnings -import pickle import numpy as np -import pandas as pd from cv2 import imencode -from .elements import * -from .io import load_dataframe - -__all__ = ["GCVFeatureType", "GCVAgent", "TesseractFeatureType", "TesseractAgent"] +from .base import BaseOCRAgent, BaseOCRElementType +from ..elements import Layout, TextBlock, Quadrilateral, TextBlock def _cvt_GCV_vertices_to_points(vertices): return np.array([[vertex.x, vertex.y] for vertex in vertices]) -class BaseOCRElementType(IntEnum): - @property - @abstractmethod - def attr_name(self): - pass - - -class BaseOCRAgent(ABC): - @property - @abstractmethod - def DEPENDENCIES(self): - """DEPENDENCIES lists all necessary dependencies for the class.""" - pass - - @property - @abstractmethod - def MODULES(self): - """MODULES instructs how to import these necessary libraries. - - Note: - Sometimes a python module have different installation name and module name (e.g., - `pip install tensorflow-gpu` when installing and `import tensorflow` when using - ). And sometimes we only need to import a submodule but not whole module. MODULES - is designed for this purpose. - - Returns: - :obj: list(dict): A list of dict indicate how the model is imported. - - Example:: - - [{ - "import_name": "_vision", - "module_path": "google.cloud.vision" - }] - - is equivalent to self._vision = importlib.import_module("google.cloud.vision") - """ - pass - - @classmethod - def _import_module(cls): - for m in cls.MODULES: - if importlib.util.find_spec(m["module_path"]): - setattr( - cls, m["import_name"], importlib.import_module(m["module_path"]) - ) - else: - raise ModuleNotFoundError( - f"\n " - f"\nPlease install the following libraries to support the class {cls.__name__}:" - f"\n pip install {' '.join(cls.DEPENDENCIES)}" - f"\n " - ) - - def __new__(cls, *args, **kwargs): - - cls._import_module() - return super().__new__(cls) - - @abstractmethod - def detect(self, image): - pass - - class GCVFeatureType(BaseOCRElementType): """ The element types from Google Cloud Vision API @@ -341,166 +269,4 @@ def save_response(self, res, file_name): with open(file_name, "w") as f: json_file = json.loads(res) - json.dump(json_file, f) - - -class TesseractFeatureType(BaseOCRElementType): - """ - The element types for Tesseract Detection API - """ - - PAGE = 0 - BLOCK = 1 - PARA = 2 - LINE = 3 - WORD = 4 - - @property - def attr_name(self): - name_cvt = { - TesseractFeatureType.PAGE: "page_num", - TesseractFeatureType.BLOCK: "block_num", - TesseractFeatureType.PARA: "par_num", - TesseractFeatureType.LINE: "line_num", - TesseractFeatureType.WORD: "word_num", - } - return name_cvt[self] - - @property - def group_levels(self): - levels = ["page_num", "block_num", "par_num", "line_num", "word_num"] - return levels[: self + 1] - - -class TesseractAgent(BaseOCRAgent): - """ - A wrapper for `Tesseract `_ Text - Detection APIs based on `PyTesseract `_. - """ - - DEPENDENCIES = ["pytesseract"] - MODULES = [{"import_name": "_pytesseract", "module_path": "pytesseract"}] - - def __init__(self, languages="eng", **kwargs): - """Create a Tesseract OCR Agent. - - Args: - languages (:obj:`list` or :obj:`str`, optional): - You can specify the language code(s) of the documents to detect to improve - accuracy. The supported language and their code can be found on - `its github repo `_. - It supports two formats: 1) you can pass in the languages code as a string - of format like `"eng+fra"`, or 2) you can pack them as a list of strings - `["eng", "fra"]`. - Defaults to 'eng'. - """ - self.lang = languages if isinstance(languages, str) else "+".join(languages) - self.configs = kwargs - - @classmethod - def with_tesseract_executable(cls, tesseract_cmd_path, **kwargs): - - cls._pytesseract.pytesseract.tesseract_cmd = tesseract_cmd_path - return cls(**kwargs) - - def _detect(self, img_content): - res = {} - res["text"] = self._pytesseract.image_to_string( - img_content, lang=self.lang, **self.configs - ) - _data = self._pytesseract.image_to_data( - img_content, lang=self.lang, **self.configs - ) - res["data"] = pd.read_csv( - io.StringIO(_data), quoting=csv.QUOTE_NONE, encoding="utf-8", sep="\t" - ) - return res - - def detect( - self, image, return_response=False, return_only_text=True, agg_output_level=None - ): - """Send the input image for OCR. - - Args: - image (:obj:`np.ndarray` or :obj:`str`): - The input image array or the name of the image file - return_response (:obj:`bool`, optional): - Whether directly return all output (string and boxes - info) from Tesseract. - Defaults to `False`. - return_only_text (:obj:`bool`, optional): - Whether return only the texts in the OCR results. - Defaults to `False`. - agg_output_level (:obj:`~TesseractFeatureType`, optional): - When set, aggregate the GCV output with respect to the - specified aggregation level. Defaults to `None`. - """ - - res = self._detect(image) - - if return_response: - return res - - if return_only_text: - return res["text"] - - if agg_output_level is not None: - return self.gather_data(res, agg_output_level) - - return res["text"] - - @staticmethod - def gather_data(response, agg_level): - """ - Gather the OCR'ed text, bounding boxes, and confidence - in a given aggeragation level. - """ - assert isinstance( - agg_level, TesseractFeatureType - ), f"Invalid agg_level {agg_level}" - res = response["data"] - df = ( - res[~res.text.isna()] - .groupby(agg_level.group_levels) - .apply( - lambda gp: pd.Series( - [ - gp["left"].min(), - gp["top"].min(), - gp["width"].max(), - gp["height"].max(), - gp["conf"].mean(), - gp["text"].str.cat(sep=" "), - ] - ) - ) - .reset_index(drop=True) - .reset_index() - .rename( - columns={ - 0: "x_1", - 1: "y_1", - 2: "w", - 3: "h", - 4: "score", - 5: "text", - "index": "id", - } - ) - .assign(x_2=lambda x: x.x_1 + x.w, y_2=lambda x: x.y_1 + x.h, block_type="rectangle") - .drop(columns=["w", "h"]) - ) - - return load_dataframe(df) - - @staticmethod - def load_response(filename): - with open(filename, "rb") as fp: - res = pickle.load(fp) - return res - - @staticmethod - def save_response(res, file_name): - - with open(file_name, "wb") as fp: - pickle.dump(res, fp, protocol=pickle.HIGHEST_PROTOCOL) + json.dump(json_file, f) \ No newline at end of file diff --git a/src/layoutparser/ocr/tesseract_agent.py b/src/layoutparser/ocr/tesseract_agent.py new file mode 100644 index 0000000..5c64357 --- /dev/null +++ b/src/layoutparser/ocr/tesseract_agent.py @@ -0,0 +1,169 @@ +import io +import csv +import pickle + +import pandas as pd + +from .base import BaseOCRAgent, BaseOCRElementType +from ..io import load_dataframe + +class TesseractFeatureType(BaseOCRElementType): + """ + The element types for Tesseract Detection API + """ + + PAGE = 0 + BLOCK = 1 + PARA = 2 + LINE = 3 + WORD = 4 + + @property + def attr_name(self): + name_cvt = { + TesseractFeatureType.PAGE: "page_num", + TesseractFeatureType.BLOCK: "block_num", + TesseractFeatureType.PARA: "par_num", + TesseractFeatureType.LINE: "line_num", + TesseractFeatureType.WORD: "word_num", + } + return name_cvt[self] + + @property + def group_levels(self): + levels = ["page_num", "block_num", "par_num", "line_num", "word_num"] + return levels[: self + 1] + + +class TesseractAgent(BaseOCRAgent): + """ + A wrapper for `Tesseract `_ Text + Detection APIs based on `PyTesseract `_. + """ + + DEPENDENCIES = ["pytesseract"] + MODULES = [{"import_name": "_pytesseract", "module_path": "pytesseract"}] + + def __init__(self, languages="eng", **kwargs): + """Create a Tesseract OCR Agent. + + Args: + languages (:obj:`list` or :obj:`str`, optional): + You can specify the language code(s) of the documents to detect to improve + accuracy. The supported language and their code can be found on + `its github repo `_. + It supports two formats: 1) you can pass in the languages code as a string + of format like `"eng+fra"`, or 2) you can pack them as a list of strings + `["eng", "fra"]`. + Defaults to 'eng'. + """ + self.lang = languages if isinstance(languages, str) else "+".join(languages) + self.configs = kwargs + + @classmethod + def with_tesseract_executable(cls, tesseract_cmd_path, **kwargs): + + cls._pytesseract.pytesseract.tesseract_cmd = tesseract_cmd_path + return cls(**kwargs) + + def _detect(self, img_content): + res = {} + res["text"] = self._pytesseract.image_to_string( + img_content, lang=self.lang, **self.configs + ) + _data = self._pytesseract.image_to_data( + img_content, lang=self.lang, **self.configs + ) + res["data"] = pd.read_csv( + io.StringIO(_data), quoting=csv.QUOTE_NONE, encoding="utf-8", sep="\t" + ) + return res + + def detect( + self, image, return_response=False, return_only_text=True, agg_output_level=None + ): + """Send the input image for OCR. + + Args: + image (:obj:`np.ndarray` or :obj:`str`): + The input image array or the name of the image file + return_response (:obj:`bool`, optional): + Whether directly return all output (string and boxes + info) from Tesseract. + Defaults to `False`. + return_only_text (:obj:`bool`, optional): + Whether return only the texts in the OCR results. + Defaults to `False`. + agg_output_level (:obj:`~TesseractFeatureType`, optional): + When set, aggregate the GCV output with respect to the + specified aggregation level. Defaults to `None`. + """ + + res = self._detect(image) + + if return_response: + return res + + if return_only_text: + return res["text"] + + if agg_output_level is not None: + return self.gather_data(res, agg_output_level) + + return res["text"] + + @staticmethod + def gather_data(response, agg_level): + """ + Gather the OCR'ed text, bounding boxes, and confidence + in a given aggeragation level. + """ + assert isinstance( + agg_level, TesseractFeatureType + ), f"Invalid agg_level {agg_level}" + res = response["data"] + df = ( + res[~res.text.isna()] + .groupby(agg_level.group_levels) + .apply( + lambda gp: pd.Series( + [ + gp["left"].min(), + gp["top"].min(), + gp["width"].max(), + gp["height"].max(), + gp["conf"].mean(), + gp["text"].str.cat(sep=" "), + ] + ) + ) + .reset_index(drop=True) + .reset_index() + .rename( + columns={ + 0: "x_1", + 1: "y_1", + 2: "w", + 3: "h", + 4: "score", + 5: "text", + "index": "id", + } + ) + .assign(x_2=lambda x: x.x_1 + x.w, y_2=lambda x: x.y_1 + x.h, block_type="rectangle") + .drop(columns=["w", "h"]) + ) + + return load_dataframe(df) + + @staticmethod + def load_response(filename): + with open(filename, "rb") as fp: + res = pickle.load(fp) + return res + + @staticmethod + def save_response(res, file_name): + + with open(file_name, "wb") as fp: + pickle.dump(res, fp, protocol=pickle.HIGHEST_PROTOCOL)