|
5 | 5 | import logging |
6 | 6 | from itertools import islice |
7 | 7 |
|
8 | | -from .utils import remove_password_from_url |
| 8 | +from .utils import remove_password_from_url, safezip |
9 | 9 |
|
10 | 10 | from .diff_tables import ( |
11 | 11 | TableSegment, |
12 | 12 | TableDiffer, |
13 | 13 | DEFAULT_BISECTION_THRESHOLD, |
14 | 14 | DEFAULT_BISECTION_FACTOR, |
| 15 | + create_schema, |
15 | 16 | ) |
16 | 17 | from .databases.connect import connect |
17 | 18 | from .parse_time import parse_time_before_now, UNITS_STR, ParseError |
@@ -39,6 +40,11 @@ def _remove_passwords_in_dict(d: dict): |
39 | 40 | d[k] = remove_password_from_url(v) |
40 | 41 |
|
41 | 42 |
|
| 43 | +def _get_schema(pair): |
| 44 | + db, table_path = pair |
| 45 | + return db.query_table_schema(table_path) |
| 46 | + |
| 47 | + |
42 | 48 | @click.command() |
43 | 49 | @click.argument("database1", required=False) |
44 | 50 | @click.argument("table1", required=False) |
@@ -67,7 +73,12 @@ def _remove_passwords_in_dict(d: dict): |
67 | 73 | @click.option("--json", "json_output", is_flag=True, help="Print JSONL output for machine readability") |
68 | 74 | @click.option("-v", "--verbose", is_flag=True, help="Print extra info") |
69 | 75 | @click.option("-i", "--interactive", is_flag=True, help="Confirm queries, implies --debug") |
70 | | -@click.option("--keep-column-case", is_flag=True, help="Don't use the schema to fix the case of given column names.") |
| 76 | +@click.option( |
| 77 | + "--case-sensitive", |
| 78 | + is_flag=True, |
| 79 | + help="Column names are treated as case-sensitive. Otherwise, correct case according to schema.", |
| 80 | +) |
| 81 | +@click.option("--mutual-columns", is_flag=True, help="XXX") |
71 | 82 | @click.option( |
72 | 83 | "-j", |
73 | 84 | "--threads", |
@@ -111,7 +122,8 @@ def _main( |
111 | 122 | verbose, |
112 | 123 | interactive, |
113 | 124 | threads, |
114 | | - keep_column_case, |
| 125 | + case_sensitive, |
| 126 | + mutual_columns, |
115 | 127 | json_output, |
116 | 128 | where, |
117 | 129 | threads1=None, |
@@ -158,35 +170,53 @@ def _main( |
158 | 170 |
|
159 | 171 | db1 = connect(database1, threads1 or threads) |
160 | 172 | db2 = connect(database2, threads2 or threads) |
| 173 | + dbs = db1, db2 |
161 | 174 |
|
162 | 175 | if interactive: |
163 | | - db1.enable_interactive() |
164 | | - db2.enable_interactive() |
| 176 | + for db in dbs: |
| 177 | + db.enable_interactive() |
165 | 178 |
|
166 | 179 | start = time.time() |
167 | 180 |
|
168 | 181 | try: |
169 | 182 | options = dict( |
170 | 183 | min_update=max_age and parse_time_before_now(max_age), |
171 | 184 | max_update=min_age and parse_time_before_now(min_age), |
172 | | - case_sensitive=keep_column_case, |
| 185 | + case_sensitive=case_sensitive, |
173 | 186 | where=where, |
174 | 187 | ) |
175 | 188 | except ParseError as e: |
176 | 189 | logging.error("Error while parsing age expression: %s" % e) |
177 | 190 | return |
178 | 191 |
|
179 | | - table1_seg = TableSegment(db1, db1.parse_table_name(table1), key_column, update_column, columns, **options) |
180 | | - table2_seg = TableSegment(db2, db2.parse_table_name(table2), key_column, update_column, columns, **options) |
181 | | - |
182 | 192 | differ = TableDiffer( |
183 | 193 | bisection_factor=bisection_factor, |
184 | 194 | bisection_threshold=bisection_threshold, |
185 | 195 | threaded=threaded, |
186 | 196 | max_threadpool_size=threads and threads * 2, |
187 | 197 | debug=debug, |
188 | 198 | ) |
189 | | - diff_iter = differ.diff_tables(table1_seg, table2_seg) |
| 199 | + |
| 200 | + table_names = table1, table2 |
| 201 | + table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)] |
| 202 | + |
| 203 | + schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths))) |
| 204 | + schema1, schema2 = schemas = [ |
| 205 | + create_schema(db, table_path, schema, case_sensitive) |
| 206 | + for db, table_path, schema in safezip(dbs, table_paths, schemas) |
| 207 | + ] |
| 208 | + |
| 209 | + if mutual_columns: |
| 210 | + mutual = schema1.keys() & schema2.keys() # Case-aware, according to case_sensitive |
| 211 | + provided_columns = {key_column, update_column} | set(columns) |
| 212 | + columns += tuple(mutual - provided_columns) |
| 213 | + |
| 214 | + segments = [ |
| 215 | + TableSegment(db, table_path, key_column, update_column, columns, **options)._with_raw_schema(raw_schema) |
| 216 | + for db, table_path, raw_schema in safezip(dbs, table_paths, schemas) |
| 217 | + ] |
| 218 | + |
| 219 | + diff_iter = differ.diff_tables(*segments) |
190 | 220 |
|
191 | 221 | if limit: |
192 | 222 | diff_iter = islice(diff_iter, int(limit)) |
|
0 commit comments