
Commit 3a1b410

Proper implementation of 2tsv
1 parent d840489 commit 3a1b410

File tree

1 file changed: +77, -45 lines


ragability/ragability_2tsv.py

Lines changed: 77 additions & 45 deletions
@@ -41,55 +41,82 @@ def get_args():
     return args


+def max_elements(data: List[Dict], keys: List[str|int]):
+    """
+    Get the maximum number of elements in the array identified by the given list of keys that lead to the
+    nested array.
+    """
+    maxn = 0
+    for record in data:
+        # access the nested list or ignore the record if it does not exist
+        val = record
+        for k in keys:
+            if isinstance(val, list):
+                val = val[k]
+            else:
+                val = val.get(k)
+            if val is None:
+                break
+        if val is not None:
+            if isinstance(val, list):
+                maxn = max(maxn, len(val))
+            else:
+                maxn = max(maxn, 1)
+    return maxn

 def run(config: dict):
     indata = read_input_file(config["input"])
     logger.info(f"Read {len(indata)} records from {config['input']}")
-    # indata is a list of nested dictionaries where each of the values in the dictionaries could
-    # be a scalar value, a nested dictionary, a scalar value or a list of scalar values or
-    # nested dictionaries. We need to flatten this structure into a list of flat dictionaries where
-    # each dictionary contains only scalar values. We will use a recursive function to do this.
-    # The names use dots to separate nested fields and underscores to separate array indices.
-    # Example:
-    # indata = [ { "a": 1, "b": [2, 3], "c": { "d": 4, "e": 5 }, "e": [{ "f": 6 },{ "f": 7 }] } ]
-    # outdata = [ { "a": 1, "b_0": 2, "b_1": 3, "c.d": 4, "c.e": 5, "e_0.f": 6, "e_1.f": 7 } ]
-    # First we analyse all the nested dictionaries in the list to find all the field names
-    # and the maximum number of elements in any array.
-    # We also need to make sure that none of the text fields contain any new lines or tabs before
-    # we write the tsv file.
-    def analyse(indata):
-        fieldnames = set()
-        maxarraysize = 0
-        for item in indata:
-            for k, v in item.items():
-                fieldnames.add(k)
-                if isinstance(v, list):
-                    maxarraysize = max(maxarraysize, len(v))
-        return fieldnames, maxarraysize
-    # now we actually convert the list of nested dictionaries into a list of flat dictionaries
-    def flatten(indata):
-        fieldnames, maxarraysize = analyse(indata)
-        flatdata = []
-        for item in indata:
-            flatitem = {}
-            for k in fieldnames:
-                v = item.get(k)
-                if isinstance(v, list):
-                    for i, vi in enumerate(v):
-                        flatitem[f"{k}_{i}"] = vi
-                elif isinstance(v, dict):
-                    for k1, v1 in v.items():
-                        flatitem[f"{k}.{k1}"] = v1
-                else:
-                    flatitem[k] = v
-            flatdata.append(flatitem)
-        return flatdata
-    flatdata = flatten(indata)
-    # make sure there are no new lines or tabs in the text fields
-    for item in flatdata:
-        for k, v in item.items():
-            if isinstance(v, str):
-                item[k] = v.replace("\n", " ").replace("\t", " ")
+
+    # For now we only support instances with a single element in the "checks" field and the pid field
+    maxn_checks = max_elements(indata, ["checks"])
+    if maxn_checks > 1:
+        logger.error(f"Only one element in the 'checks' field is supported for now, but found {maxn_checks}")
+        sys.exit(1)
+    maxn_pids = max_elements(indata, ["checks", 0, "pids"])
+    if maxn_pids > 1:
+        logger.error(f"Only one element in the 'pids' field is supported for now, but found {maxn_pids}")
+        sys.exit(1)
+    TOPFIELDS_SCALAR = ["qid", "tags", "query", "WikiContradict_ID", "reasoning_required_c1c2", "response", "error", "pid", "llm"]
+    CHECKFIELDS = ["cid", "query", "func", "metrics", "pid", "response", "llm", "result", "error", "check_for"]
+    flatdata = []
+    # find the maximum number for "facts"
+    maxn_facts = max_elements(indata, ["facts"])
+    unknownfields = Counter()
+    for record in indata:
+        flatrecord = {}
+        for field in TOPFIELDS_SCALAR:
+            flatrecord[field] = record.get(field, "")
+        # add the facts fields, not all records have the same number of facts
+        facts = record.get("facts")
+        if facts is None:
+            facts = []
+        elif isinstance(facts, str):
+            facts = [facts]
+        for i in range(maxn_facts):
+            if i < len(facts):
+                flatrecord[f"facts_{i}"] = facts[i]
+            else:
+                flatrecord[f"facts_{i}"] = ""
+        check = record.get("checks", [{}])[0]
+        for field in CHECKFIELDS:
+            val = check.get(field, "")
+            if isinstance(val, list):
+                val = ", ".join(val)
+            flatrecord[f"check.{field}"] = val
+        # check if the check dict has any fields not mentioned in CHECKFIELDS, if so, count them using the
+        # name check.{unknownfieldname}
+        for k in check:
+            if k not in CHECKFIELDS:
+                if k not in ["cost"]:
+                    unknownfields[f"check.{k}"] += 1
+        # also check if the top level record has any fields not mentioned in TOPFIELDS_SCALAR
+        for k in record:
+            if k not in TOPFIELDS_SCALAR:
+                if k not in ["checks", "c1xq", "c2xq", "cost", "pids", "facts"]:
+                    unknownfields[k] += 1
+        flatdata.append(flatrecord)
+    # convert to a data frame
     df = pd.DataFrame(flatdata)
     logger.info(f"Converted to dataframe with {df.shape[0]} rows and {df.shape[1]} columns")
     # Now we have the dataframe, we can write it to the output file
@@ -98,6 +125,11 @@ def flatten(indata):
     outputfile = os.path.splitext(config["input"])[0] + ".tsv"
     df.to_csv(outputfile, sep="\t", index=False)
     logger.info(f"Output written to {outputfile}")
+    # print out the unknown fields, if there are any
+    if unknownfields:
+        logger.info("Unknown fields:")
+        for k, v in unknownfields.items():
+            logger.info(f"{k}: {v}")


 def main():
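
For illustration, a minimal usage sketch of the new max_elements helper added above. The sample records and the import path are invented here (they assume the module imports cleanly from the repository root) and are not part of the commit:

    from ragability.ragability_2tsv import max_elements

    # Made-up records, shaped like the instances run() expects: each record
    # may carry a "checks" list whose first element holds a "pids" list.
    records = [
        {"qid": "q1", "checks": [{"cid": "c1", "pids": ["p1"]}]},
        {"qid": "q2", "checks": [{"cid": "c2", "pids": ["p1", "p2"]}]},
        {"qid": "q3"},  # no "checks": the key-path walk hits None, record skipped
    ]

    assert max_elements(records, ["checks"]) == 1             # longest "checks" list
    assert max_elements(records, ["checks", 0, "pids"]) == 2  # longest nested "pids" list

These are the same two key paths run() probes; on this sample it would log an error and exit, since record "q2" carries two pids.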

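And a sketch of the flat record the rewritten run() builds per input record, with invented values (only a few of the TOPFIELDS_SCALAR columns shown): top-level scalars keep their names, facts are padded out to facts_0 through facts_{maxn_facts-1}, and the single check's fields get a "check." prefix:

    import pandas as pd

    # Invented values; the column names follow the commit's naming scheme.
    flatrecord = {
        "qid": "q1",
        "query": "Who wrote X?",
        "llm": "some-model",
        "facts_0": "first supporting fact",
        "facts_1": "",                    # padding: one fact, maxn_facts == 2
        "check.cid": "c1",
        "check.metrics": "f1, accuracy",  # list values are joined with ", "
        "check.result": "pass",
    }

    df = pd.DataFrame([flatrecord])
    df.to_csv("example.tsv", sep="\t", index=False)  # the same writer call run() uses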