Skip to content

Commit 3b1f214

Browse files
committed
[fix] invalid link variables
1 parent f808274 commit 3b1f214

File tree

3 files changed

+78
-37
lines changed

3 files changed

+78
-37
lines changed

EduNLP/Formula/Formula.py

Lines changed: 69 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,36 @@
1414

1515

1616
class Formula(object):
17+
"""
18+
Examples
19+
--------
20+
>>> f = Formula("x")
21+
>>> f
22+
<Formula: x>
23+
>>> f.ast
24+
[{'val': {'id': 0, 'type': 'mathord', 'text': 'x', 'role': None}, \
25+
'structure': {'bro': [None, None], 'child': None, 'father': None, 'forest': None}}]
26+
>>> f.elements
27+
[{'id': 0, 'type': 'mathord', 'text': 'x', 'role': None}]
28+
>>> f.variable_standardization(inplace=True)
29+
<Formula: x>
30+
>>> f.elements
31+
[{'id': 0, 'type': 'mathord', 'text': 'x', 'role': None, 'var': 0}]
32+
"""
33+
1734
def __init__(self, formula: (str, List[Dict]), variable_standardization=False, const_mathord=None,
1835
*args, **kwargs):
36+
"""
37+
38+
Parameters
39+
----------
40+
formula: str or List[Dict]
41+
latex formula string or the parsed abstracted syntax tree
42+
variable_standardization
43+
const_mathord
44+
args
45+
kwargs
46+
"""
1947
self._formula = formula
2048
self._ast = None
2149
self.reset_ast(
@@ -43,11 +71,15 @@ def variable_standardization(self, inplace=False, const_mathord=None, variable_c
4371
return Formula(ast_tree, is_str=False)
4472

4573
@property
46-
def element(self):
74+
def ast(self):
4775
return self._ast
4876

4977
@property
50-
def ast(self) -> (nx.Graph, nx.DiGraph):
78+
def elements(self):
79+
return [self.ast_graph.nodes[node] for node in self.ast_graph.nodes]
80+
81+
@property
82+
def ast_graph(self) -> (nx.Graph, nx.DiGraph):
5183
edges = [(edge[0], edge[1]) for edge in get_edges(self._ast) if edge[2] == 3]
5284
tree = nx.DiGraph()
5385
for node in self._ast:
@@ -67,8 +99,9 @@ def __repr__(self):
6799
else:
68100
return super(Formula, self).__repr__()
69101

70-
def reset_ast(self, formula_ensure_str=True, variable_standardization=False, const_mathord=None, *args, **kwargs):
71-
if formula_ensure_str is True and self.resetable is True:
102+
def reset_ast(self, formula_ensure_str: bool = True, variable_standardization=False, const_mathord=None, *args,
103+
**kwargs):
104+
if formula_ensure_str is True and self.resetable is False:
72105
raise TypeError("formula must be str, now is %s" % type(self._formula))
73106
self._ast = str2ast(self._formula, *args, **kwargs) if isinstance(self._formula, str) else self._formula
74107
if variable_standardization:
@@ -82,34 +115,49 @@ def resetable(self):
82115

83116

84117
class FormulaGroup(object):
118+
"""
119+
Examples
120+
---------
121+
>>> fg = FormulaGroup(["x + y", "y + x", "z + x"])
122+
>>> fg
123+
<FormulaGroup: <Formula: x + y>;<Formula: y + x>;<Formula: z + x>>
124+
>>> fg = FormulaGroup(["x + y", Formula("y + x"), "z + x"])
125+
>>> fg
126+
<FormulaGroup: <Formula: x + y>;<Formula: y + x>;<Formula: z + x>>
127+
>>> fg = FormulaGroup(["x", Formula("y"), "x"])
128+
>>> fg.elements
129+
[{'id': 0, 'type': 'mathord', 'text': 'x', 'role': None}, {'id': 1, 'type': 'mathord', 'text': 'y', 'role': None},\
130+
{'id': 2, 'type': 'mathord', 'text': 'x', 'role': None}]
131+
>>> fg = FormulaGroup(["x", Formula("y"), "x"], variable_standardization=True)
132+
>>> fg.elements
133+
[{'id': 0, 'type': 'mathord', 'text': 'x', 'role': None, 'var': 0}, \
134+
{'id': 1, 'type': 'mathord', 'text': 'y', 'role': None, 'var': 1}, \
135+
{'id': 2, 'type': 'mathord', 'text': 'x', 'role': None, 'var': 0}]
136+
"""
137+
85138
def __init__(self,
86-
formula_list: (List[(str, Formula, dict)]),
139+
formula_list: (list, List[str], List[Formula]),
87140
variable_standardization=False,
88141
const_mathord=None,
89142
detach=True
90143
):
91-
"""
92-
93-
Parameters
94-
----------
95-
formula_list: List[str]
96-
"""
97144
forest = []
98145
self._formulas = []
99-
for index in range(0, len(formula_list)):
100-
formula = formula_list[index]
146+
for formula in formula_list:
101147
if isinstance(formula, str):
102-
tree = str2ast(
148+
formula = Formula(
103149
formula,
104150
forest_begin=len(forest),
105151
)
106-
self._formulas.append(Formula(tree))
152+
self._formulas.append(formula)
153+
tree = formula.ast
107154
elif isinstance(formula, Formula):
108155
if detach:
109156
formula = deepcopy(formula)
110157
tree = formula.reset_ast(
111158
formula_ensure_str=True,
112159
variable_standardization=False,
160+
forest_begin=len(forest),
113161
)
114162
self._formulas.append(formula)
115163
else:
@@ -149,11 +197,15 @@ def __repr__(self):
149197
return "<FormulaGroup: %s>" % ";".join([repr(_formula) for _formula in self._formulas])
150198

151199
@property
152-
def element(self):
200+
def ast(self):
153201
return self._forest
154202

155203
@property
156-
def ast(self) -> (nx.Graph, nx.DiGraph):
204+
def elements(self):
205+
return [self.ast_graph.nodes[node] for node in self.ast_graph.nodes]
206+
207+
@property
208+
def ast_graph(self) -> (nx.Graph, nx.DiGraph):
157209
edges = [(edge[0], edge[1]) for edge in get_edges(self._forest) if edge[2] == 3]
158210
tree = nx.DiGraph()
159211
for node in self._forest:

EduNLP/SIF/tokenization/formula/ast_token.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,12 @@ def ast_tokenize(formula, ord2token=False, var_numbering=False, return_type="for
7979
<Formula: {x + y}^\\frac{\\pi}{2} + 1 = x>
8080
"""
8181
if return_type == "list":
82-
ast = Formula(formula, variable_standardization=True).ast
82+
ast = Formula(formula, variable_standardization=True).ast_graph
8383
return traversal_formula(ast, ord2token=ord2token, var_numbering=var_numbering)
8484
elif return_type == "formula":
8585
return Formula(formula)
8686
elif return_type == "ast":
87-
return Formula(formula).ast
87+
return Formula(formula).ast_graph
8888
else:
8989
raise ValueError()
9090

EduNLP/SIF/tokenization/tokenization.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# 2021/5/18 @ tongshiwei
33

44
import itertools as it
5-
from EduNLP.Formula import link_variable, Formula
5+
from EduNLP.Formula import link_formulas as _link_formulas, Formula
66
from ..constants import Symbol, TEXT_SYMBOL, FIGURE_SYMBOL, FORMULA_SYMBOL, QUES_MARK_SYMBOL
77
from ..segment import (SegmentList, TextSegment, FigureSegment, LatexFormulaSegment, FigureFormulaSegment,
88
QuesMarkSegment, Figure)
@@ -26,15 +26,9 @@ def __init__(self, segment_list: SegmentList, text_params=None, formula_params=N
2626

2727
def _variable_standardization(self):
2828
if self.formula_tokenize_method == "ast":
29-
ast_formulas = [self._tokens[i].element for i in self._formula_tokens if
30-
isinstance(self._tokens[i], Formula)]
29+
ast_formulas = [self._tokens[i] for i in self._formula_tokens if isinstance(self._tokens[i], Formula)]
3130
if ast_formulas:
32-
link_variable(list(it.chain(*ast_formulas)))
33-
self.variable_standardization()
34-
35-
def variable_standardization(self):
36-
for i in self._formula_tokens:
37-
self._tokens[i].variable_standardization(inplace=True)
31+
_link_formulas(*ast_formulas)
3832

3933
@property
4034
def tokens(self):
@@ -114,9 +108,9 @@ def text_tokens(self):
114108
def __add_token(self, token, tokens):
115109
if isinstance(token, Formula):
116110
if self.formula_params.get("return_type") == "list":
117-
tokens.extend(formula.traversal_formula(token.ast, **self.formula_params))
111+
tokens.extend(formula.traversal_formula(token.ast_graph, **self.formula_params))
118112
elif self.formula_params.get("return_type") == "ast":
119-
tokens.append(token.ast)
113+
tokens.append(token.ast_graph)
120114
else:
121115
tokens.append(token)
122116
elif isinstance(token, Figure):
@@ -161,10 +155,5 @@ def link_formulas(*token_list: TokenList):
161155
ast_formulas = []
162156
for tl in token_list:
163157
if tl.formula_tokenize_method == "ast":
164-
ast_formulas.extend([
165-
token.element for token in tl.inner_formula_tokens
166-
if isinstance(token, Formula)
167-
])
168-
link_variable(list(it.chain(*ast_formulas)))
169-
for tl in token_list:
170-
tl.variable_standardization()
158+
ast_formulas.extend([token for token in tl.inner_formula_tokens if isinstance(token, Formula)])
159+
_link_formulas(*ast_formulas)

0 commit comments

Comments
 (0)