@@ -41,55 +41,82 @@ def get_args():
4141 return args
4242
4343
44+ def max_elements (data : List [Dict ], keys : List [str | int ]):
45+ """
46+ Get the maximum number of elements in the array identified by the given list of keys that lead to the
47+ nested array.
48+ """
49+ maxn = 0
50+ for record in data :
51+ # access the nested list or ignore the record if it does not exist
52+ val = record
53+ for k in keys :
54+ if isinstance (val , list ):
55+ val = val [k ]
56+ else :
57+ val = val .get (k )
58+ if val is None :
59+ break
60+ if val is not None :
61+ if isinstance (val , list ):
62+ maxn = max (maxn , len (val ))
63+ else :
64+ maxn = max (maxn , 1 )
65+ return maxn
4466
4567def run (config : dict ):
4668 indata = read_input_file (config ["input" ])
4769 logger .info (f"Read { len (indata )} records from { config ['input' ]} " )
48- # indata is a list of nested dictionaries where each of the values in the dictionaries could
49- # be a scalar value, a nested dictionary, a scalar value or a list of scalar values or
50- # nested dictionaries. We need to flatten this structure into a list of flat dictionaries where
51- # each dictionary contains only scalar values. We will use a recursive function to do this.
52- # The names use dots to separate nested fields and underscores to separate array indices.
53- # Example:
54- # indata = [ { "a": 1, "b": [2, 3], "c": { "d": 4, "e": 5 }, "e": [{ "f": 6 },{ "f": 7 }] } ]
55- # outdata = [ { "a": 1, "b_0": 2, "b_1": 3, "c.d": 4, "c.e": 5, "e_0.f": 6, "e_1.f": 7 } ]
56- # First we analyse all the nested dictionaries in the list to find all the field names
57- # and the maximum number of elements in any array.
58- # We also need to make sure that none of the text fields contain any new lines or tabs before
59- # we write the tsv file.
60- def analyse (indata ):
61- fieldnames = set ()
62- maxarraysize = 0
63- for item in indata :
64- for k , v in item .items ():
65- fieldnames .add (k )
66- if isinstance (v , list ):
67- maxarraysize = max (maxarraysize , len (v ))
68- return fieldnames , maxarraysize
69- # now we actually convert the list of nested dictionaries into a list of flat dictionaries
70- def flatten (indata ):
71- fieldnames , maxarraysize = analyse (indata )
72- flatdata = []
73- for item in indata :
74- flatitem = {}
75- for k in fieldnames :
76- v = item .get (k )
77- if isinstance (v , list ):
78- for i , vi in enumerate (v ):
79- flatitem [f"{ k } _{ i } " ] = vi
80- elif isinstance (v , dict ):
81- for k1 , v1 in v .items ():
82- flatitem [f"{ k } .{ k1 } " ] = v1
83- else :
84- flatitem [k ] = v
85- flatdata .append (flatitem )
86- return flatdata
87- flatdata = flatten (indata )
88- # make sure there are no new lines or tabs in the text fields
89- for item in flatdata :
90- for k , v in item .items ():
91- if isinstance (v , str ):
92- item [k ] = v .replace ("\n " , " " ).replace ("\t " , " " )
70+
71+ # For now we only support instances with a single element in the "checks" field and the pid field
72+ maxn_checks = max_elements (indata , ["checks" ])
73+ if maxn_checks > 1 :
74+ logger .error (f"Only one element in the 'checks' field is supported for now, but found { maxn_checks } " )
75+ sys .exit (1 )
76+ maxn_pids = max_elements (indata , ["checks" , 0 , "pids" ])
77+ if maxn_pids > 1 :
78+ logger .error (f"Only one element in the 'pids' field is supported for now, but found { maxn_pids } " )
79+ sys .exit (1 )
80+ TOPFIELDS_SCALAR = ["qid" , "tags" , "query" , "WikiContradict_ID" , "reasoning_required_c1c2" , "response" , "error" , "pid" , "llm" ]
81+ CHECKFIELDS = ["cid" , "query" , "func" , "metrics" , "pid" , "response" , "llm" , "result" , "error" , "check_for" ]
82+ flatdata = []
83+ # find the maximum number for "facts"
84+ maxn_facts = max_elements (indata , ["facts" ])
85+ unknownfields = Counter ()
86+ for record in indata :
87+ flatrecord = {}
88+ for field in TOPFIELDS_SCALAR :
89+ flatrecord [field ] = record .get (field , "" )
90+ # add the facts fields, not all records have the same number of facts
91+ facts = record .get ("facts" )
92+ if facts is None :
93+ facts = []
94+ elif isinstance (facts , str ):
95+ facts = [facts ]
96+ for i in range (maxn_facts ):
97+ if i < len (facts ):
98+ flatrecord [f"facts_{ i } " ] = facts [i ]
99+ else :
100+ flatrecord [f"facts_{ i } " ] = ""
101+ check = record .get ("checks" , [{}])[0 ]
102+ for field in CHECKFIELDS :
103+ val = check .get (field , "" )
104+ if isinstance (val , list ):
105+ val = ", " .join (val )
106+ flatrecord [f"check.{ field } " ] = val
107+ # check if the check dict has any fields not mentioned in CHECKFIELDS, if so, count them using the
108+ # name check.{unknownfieldname}
109+ for k in check :
110+ if k not in CHECKFIELDS :
111+ if k not in ["cost" ]:
112+ unknownfields [f"check.{ k } " ] += 1
113+ # also check if the top level record has any fields not mentioned in TOPFIELDS_SCALAR
114+ for k in record :
115+ if k not in TOPFIELDS_SCALAR :
116+ if k not in ["checks" , "c1xq" , "c2xq" , "cost" , "pids" , "facts" ]:
117+ unknownfields [k ] += 1
118+ flatdata .append (flatrecord )
119+ # convert to a data frame
93120 df = pd .DataFrame (flatdata )
94121 logger .info (f"Converted to dataframe with { df .shape [0 ]} rows and { df .shape [1 ]} columns" )
95122 # Now we have the dataframe, we can write it to the output file
@@ -98,6 +125,11 @@ def flatten(indata):
98125 outputfile = os .path .splitext (config ["input" ])[0 ] + ".tsv"
99126 df .to_csv (outputfile , sep = "\t " , index = False )
100127 logger .info (f"Output written to { outputfile } " )
128+ # print out the unknown fields, if there are any
129+ if unknownfields :
130+ logger .info ("Unknown fields:" )
131+ for k , v in unknownfields .items ():
132+ logger .info (f"{ k } : { v } " )
101133
102134
103135def main ():
0 commit comments