Skip to content
This repository was archived by the owner on May 17, 2024. It is now read-only.

Commit 6a4c443

Browse files
committed
Mutual columns
1 parent b1bebee commit 6a4c443

File tree

6 files changed

+130
-81
lines changed

6 files changed

+130
-81
lines changed

data_diff/__main__.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
import logging
66
from itertools import islice
77

8-
from .utils import remove_password_from_url
8+
from .utils import remove_password_from_url, safezip
99

1010
from .diff_tables import (
1111
TableSegment,
1212
TableDiffer,
1313
DEFAULT_BISECTION_THRESHOLD,
1414
DEFAULT_BISECTION_FACTOR,
15+
create_schema,
1516
)
1617
from .databases.connect import connect
1718
from .parse_time import parse_time_before_now, UNITS_STR, ParseError
@@ -39,6 +40,11 @@ def _remove_passwords_in_dict(d: dict):
3940
d[k] = remove_password_from_url(v)
4041

4142

43+
def _get_schema(pair):
44+
db, table_path = pair
45+
return db.query_table_schema(table_path)
46+
47+
4248
@click.command()
4349
@click.argument("database1", required=False)
4450
@click.argument("table1", required=False)
@@ -67,7 +73,12 @@ def _remove_passwords_in_dict(d: dict):
6773
@click.option("--json", "json_output", is_flag=True, help="Print JSONL output for machine readability")
6874
@click.option("-v", "--verbose", is_flag=True, help="Print extra info")
6975
@click.option("-i", "--interactive", is_flag=True, help="Confirm queries, implies --debug")
70-
@click.option("--keep-column-case", is_flag=True, help="Don't use the schema to fix the case of given column names.")
76+
@click.option(
77+
"--case-sensitive",
78+
is_flag=True,
79+
help="Column names are treated as case-sensitive. Otherwise, correct case according to schema.",
80+
)
81+
@click.option("--mutual-columns", is_flag=True, help="XXX")
7182
@click.option(
7283
"-j",
7384
"--threads",
@@ -111,7 +122,8 @@ def _main(
111122
verbose,
112123
interactive,
113124
threads,
114-
keep_column_case,
125+
case_sensitive,
126+
mutual_columns,
115127
json_output,
116128
where,
117129
threads1=None,
@@ -158,35 +170,53 @@ def _main(
158170

159171
db1 = connect(database1, threads1 or threads)
160172
db2 = connect(database2, threads2 or threads)
173+
dbs = db1, db2
161174

162175
if interactive:
163-
db1.enable_interactive()
164-
db2.enable_interactive()
176+
for db in dbs:
177+
db.enable_interactive()
165178

166179
start = time.time()
167180

168181
try:
169182
options = dict(
170183
min_update=max_age and parse_time_before_now(max_age),
171184
max_update=min_age and parse_time_before_now(min_age),
172-
case_sensitive=keep_column_case,
185+
case_sensitive=case_sensitive,
173186
where=where,
174187
)
175188
except ParseError as e:
176189
logging.error("Error while parsing age expression: %s" % e)
177190
return
178191

179-
table1_seg = TableSegment(db1, db1.parse_table_name(table1), key_column, update_column, columns, **options)
180-
table2_seg = TableSegment(db2, db2.parse_table_name(table2), key_column, update_column, columns, **options)
181-
182192
differ = TableDiffer(
183193
bisection_factor=bisection_factor,
184194
bisection_threshold=bisection_threshold,
185195
threaded=threaded,
186196
max_threadpool_size=threads and threads * 2,
187197
debug=debug,
188198
)
189-
diff_iter = differ.diff_tables(table1_seg, table2_seg)
199+
200+
table_names = table1, table2
201+
table_paths = [db.parse_table_name(t) for db, t in safezip(dbs, table_names)]
202+
203+
schemas = list(differ._thread_map(_get_schema, safezip(dbs, table_paths)))
204+
schema1, schema2 = schemas = [
205+
create_schema(db, table_path, schema, case_sensitive)
206+
for db, table_path, schema in safezip(dbs, table_paths, schemas)
207+
]
208+
209+
if mutual_columns:
210+
mutual = schema1.keys() & schema2.keys() # Case-aware, according to case_sensitive
211+
provided_columns = {key_column, update_column} | set(columns)
212+
columns += tuple(mutual - provided_columns)
213+
214+
segments = [
215+
TableSegment(db, table_path, key_column, update_column, columns, **options)._with_raw_schema(raw_schema)
216+
for db, table_path, raw_schema in safezip(dbs, table_paths, schemas)
217+
]
218+
219+
diff_iter = differ.diff_tables(*segments)
190220

191221
if limit:
192222
diff_iter = islice(diff_iter, int(limit))

data_diff/databases/base.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import threading
88
from abc import abstractmethod
99

10-
from data_diff.utils import is_uuid, safezip
10+
from data_diff.utils import CaseAwareMapping, is_uuid, safezip
1111
from .database_types import (
1212
AbstractDatabase,
1313
ColType,
@@ -180,16 +180,19 @@ def select_table_schema(self, path: DbPath) -> str:
180180
f"WHERE table_name = '{table}' AND table_schema = '{schema}'"
181181
)
182182

183-
def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]:
183+
def query_table_schema(self, path: DbPath) -> Dict[str, ColType]:
184184
rows = self.query(self.select_table_schema(path), list)
185185
if not rows:
186186
raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns")
187187

188-
if filter_columns is not None:
189-
accept = {i.lower() for i in filter_columns}
190-
rows = [r for r in rows if r[0].lower() in accept]
188+
d = {r[0]: r for r in rows}
189+
assert len(d) == len(rows)
190+
return d
191191

192-
col_dict: Dict[str, ColType] = {row[0]: self._parse_type(path, *row) for row in rows}
192+
def _process_table_schema(self, path: DbPath, raw_schema: dict, filter_columns: Sequence[str]):
193+
accept = {i.lower() for i in filter_columns}
194+
195+
col_dict = {name: self._parse_type(path, *row) for name, row in raw_schema.items() if name.lower() in accept}
193196

194197
self._refine_coltypes(path, col_dict)
195198

data_diff/databases/database_types.py

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import decimal
22
from abc import ABC, abstractmethod
3-
from typing import Sequence, Optional, Tuple, Union, Dict, List
3+
from typing import Mapping, Sequence, Optional, Tuple, Union, Dict, List
44
from datetime import datetime
55

66
from runtype import dataclass
77

8-
from data_diff.utils import ArithAlphanumeric, ArithUUID, ArithString
8+
from data_diff.utils import ArithAlphanumeric, ArithUUID, CaseAwareMapping
99

1010

1111
DbPath = Tuple[str, ...]
@@ -254,44 +254,4 @@ def _normalize_table_path(self, path: DbPath) -> DbPath:
254254
...
255255

256256

257-
class Schema(ABC):
258-
@abstractmethod
259-
def get_key(self, key: str) -> str:
260-
...
261-
262-
@abstractmethod
263-
def __getitem__(self, key: str) -> ColType:
264-
...
265-
266-
@abstractmethod
267-
def __setitem__(self, key: str, value):
268-
...
269-
270-
@abstractmethod
271-
def __contains__(self, key: str) -> bool:
272-
...
273-
274-
275-
class Schema_CaseSensitive(dict, Schema):
276-
def get_key(self, key):
277-
return key
278-
279-
280-
class Schema_CaseInsensitive(Schema):
281-
def __init__(self, initial):
282-
self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()}
283-
284-
def get_key(self, key: str) -> str:
285-
return self._dict[key.lower()][0]
286-
287-
def __getitem__(self, key: str) -> ColType:
288-
return self._dict[key.lower()][1]
289-
290-
def __setitem__(self, key: str, value):
291-
k = key.lower()
292-
if k in self._dict:
293-
key = self._dict[k][0]
294-
self._dict[k] = key, value
295-
296-
def __contains__(self, key):
297-
return key.lower() in self._dict
257+
Schema = CaseAwareMapping

data_diff/databases/mysql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def create_connection(self):
4747
elif e.errno == mysql.errorcode.ER_BAD_DB_ERROR:
4848
raise ConnectError("Database does not exist") from e
4949
else:
50-
raise ConnectError(*e._args) from e
50+
raise ConnectError(*e) from e
5151

5252
def quote(self, s: str):
5353
return f"`{s}`"

data_diff/diff_tables.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from runtype import dataclass
1414

1515
from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Value
16-
from .utils import safezip, split_space
16+
from .utils import CaseInsensitiveDict, safezip, split_space, CaseSensitiveDict
1717
from .databases.base import Database
1818
from .databases.database_types import (
1919
ArithString,
@@ -23,8 +23,6 @@
2323
PrecisionType,
2424
StringType,
2525
Schema,
26-
Schema_CaseInsensitive,
27-
Schema_CaseSensitive,
2826
)
2927

3028
logger = logging.getLogger("diff_tables")
@@ -35,6 +33,18 @@
3533
DEFAULT_BISECTION_FACTOR = 32
3634

3735

36+
def create_schema(db: Database, table_path: DbPath, schema: dict, case_sensitive: bool) -> Schema:
37+
logger.debug(f"[{db.name}] Schema = {schema}")
38+
39+
if case_sensitive:
40+
return CaseSensitiveDict(schema)
41+
42+
if len({k.lower() for k in schema}) < len(schema):
43+
logger.warning(f'Ambiguous schema for {db}:{".".join(table_path)} | Columns = {", ".join(list(schema))}')
44+
logger.warning("We recommend to disable case-insensitivity (remove --any-case).")
45+
return CaseInsensitiveDict(schema)
46+
47+
3848
@dataclass(frozen=False)
3949
class TableSegment:
4050
"""Signifies a segment of rows (and selected columns) within a table
@@ -116,26 +126,16 @@ def _normalize_column(self, name: str, template: str = None) -> str:
116126

117127
return self.database.normalize_value_by_type(col, col_type)
118128

129+
def _with_raw_schema(self, raw_schema: dict) -> "TableSegment":
130+
schema = self.database._process_table_schema(self.table_path, raw_schema, self._relevant_columns)
131+
return self.new(_schema=create_schema(self.database, self.table_path, schema, self.case_sensitive))
132+
119133
def with_schema(self) -> "TableSegment":
120134
"Queries the table schema from the database, and returns a new instance of TableSegment, with a schema."
121135
if self._schema:
122136
return self
123137

124-
schema = self.database.query_table_schema(self.table_path, self._relevant_columns)
125-
logger.debug(f"[{self.database.name}] Schema = {schema}")
126-
127-
schema_inst: Schema
128-
if self.case_sensitive:
129-
schema_inst = Schema_CaseSensitive(schema)
130-
else:
131-
if len({k.lower() for k in schema}) < len(schema):
132-
logger.warning(
133-
f'Ambiguous schema for {self.database}:{".".join(self.table_path)} | Columns = {", ".join(list(schema))}'
134-
)
135-
logger.warning("We recommend to disable case-insensitivity (remove --any-case).")
136-
schema_inst = Schema_CaseInsensitive(schema)
137-
138-
return self.new(_schema=schema_inst)
138+
return self._with_raw_schema(self.database.query_table_schema(self.table_path))
139139

140140
def _make_key_range(self):
141141
if self.min_key is not None:

data_diff/utils.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import math
2+
from typing import Iterable, Tuple, Union, Any
3+
from typing import TypeVar, Generic
4+
from abc import ABC, abstractmethod
25
from urllib.parse import urlparse
3-
4-
from typing import Union, Any
56
from uuid import UUID
67
import string
78

@@ -150,9 +151,64 @@ def remove_password_from_url(url: str, replace_with: str = "***") -> str:
150151
return replaced.geturl()
151152

152153

153-
def join_iter(joiner: Any, iterable: iter) -> iter:
154+
def join_iter(joiner: Any, iterable: Iterable) -> Iterable:
154155
it = iter(iterable)
155156
yield next(it)
156157
for i in it:
157158
yield joiner
158159
yield i
160+
161+
162+
V = TypeVar("V")
163+
164+
165+
class CaseAwareMapping(ABC, Generic[V]):
166+
@abstractmethod
167+
def get_key(self, key: str) -> str:
168+
...
169+
170+
@abstractmethod
171+
def __getitem__(self, key: str) -> V:
172+
...
173+
174+
@abstractmethod
175+
def __setitem__(self, key: str, value: V):
176+
...
177+
178+
@abstractmethod
179+
def __contains__(self, key: str) -> bool:
180+
...
181+
182+
183+
class CaseInsensitiveDict(CaseAwareMapping):
184+
def __init__(self, initial):
185+
self._dict = {k.lower(): (k, v) for k, v in dict(initial).items()}
186+
187+
def get_key(self, key: str) -> str:
188+
return self._dict[key.lower()][0]
189+
190+
def __getitem__(self, key: str) -> V:
191+
return self._dict[key.lower()][1]
192+
193+
def __setitem__(self, key: str, value):
194+
k = key.lower()
195+
if k in self._dict:
196+
key = self._dict[k][0]
197+
self._dict[k] = key, value
198+
199+
def __contains__(self, key):
200+
return key.lower() in self._dict
201+
202+
def keys(self) -> Iterable[str]:
203+
return self._dict.keys()
204+
205+
def items(self) -> Iterable[Tuple[str, V]]:
206+
return ((k, v[1]) for k, v in self._dict.items())
207+
208+
209+
class CaseSensitiveDict(dict, CaseAwareMapping):
210+
def get_key(self, key):
211+
return key
212+
213+
def as_insensitive(self):
214+
return CaseInsensitiveDict(self)

0 commit comments

Comments
 (0)