|
15 | 15 | limitations under the License. |
16 | 16 | """ |
17 | 17 | from typing import ( |
18 | | - Optional, List |
| 18 | + Optional, List, Set |
19 | 19 | ) |
| 20 | +from itertools import chain |
20 | 21 |
|
21 | 22 | import dbt.exceptions |
| 23 | +from dbt.adapters.base.relation import BaseRelation, InformationSchema |
| 24 | +from dbt.adapters.base.impl import GET_CATALOG_MACRO_NAME |
22 | 25 | from dbt.adapters.sql import SQLAdapter |
23 | 26 | from dbt.adapters.base.meta import available |
24 | 27 | from dbt.adapters.oracle import OracleAdapterConnectionManager |
25 | 28 | from dbt.adapters.oracle.relation import OracleRelation |
| 29 | +from dbt.contracts.graph.manifest import Manifest |
| 30 | + |
26 | 31 |
|
27 | 32 |
|
28 | 33 | import agate |
|
56 | 61 | '''.strip() |
57 | 62 |
|
58 | 63 | LIST_RELATIONS_MACRO_NAME = 'list_relations_without_caching' |
| 64 | +GET_DATABASE_MACRO_NAME = 'get_database_name' |
59 | 65 |
|
60 | 66 |
|
61 | 67 | class OracleAdapter(SQLAdapter): |
@@ -103,7 +109,7 @@ def verify_database(self, database): |
103 | 109 | if database.startswith('"'): |
104 | 110 | database = database.strip('"') |
105 | 111 | expected = self.config.credentials.database |
106 | | - if database.lower() != expected.lower(): |
| 112 | + if expected and database.lower() != expected.lower(): |
107 | 113 | raise dbt.exceptions.NotImplementedException( |
108 | 114 | 'Cross-db references not allowed in {} ({} vs {})' |
109 | 115 | .format(self.type(), database, expected) |
@@ -150,3 +156,59 @@ def timestamp_add_sql( |
150 | 156 | # '+ interval' syntax used in postgres/redshift is relatively common |
151 | 157 | # and might even be the SQL standard's intention. |
152 | 158 | return f"{add_to} + interval '{number}' {interval}" |
| 159 | + |
| 160 | + def get_relation(self, database: str, schema: str, identifier: str) -> Optional[BaseRelation]: |
| 161 | + if database == 'None': |
| 162 | + database = self.config.credentials.database |
| 163 | + return super().get_relation(database, schema, identifier) |
| 164 | + |
| 165 | + def _get_one_catalog( |
| 166 | + self, |
| 167 | + information_schema: InformationSchema, |
| 168 | + schemas: Set[str], |
| 169 | + manifest: Manifest, |
| 170 | + ) -> agate.Table: |
| 171 | + |
| 172 | + kwargs = {"information_schema": information_schema, "schemas": schemas} |
| 173 | + table = self.execute_macro( |
| 174 | + GET_CATALOG_MACRO_NAME, |
| 175 | + kwargs=kwargs, |
| 176 | + # pass in the full manifest so we get any local project |
| 177 | + # overrides |
| 178 | + manifest=manifest, |
| 179 | + ) |
| 180 | + # In case database is not defined, we can use the the configured database which we set as part of credentials |
| 181 | + for node in chain(manifest.nodes.values(), manifest.sources.values()): |
| 182 | + if not node.database or node.database == 'None': |
| 183 | + node.database = self.config.credentials.database |
| 184 | + |
| 185 | + results = self._catalog_filter_table(table, manifest) |
| 186 | + return results |
| 187 | + |
| 188 | + def list_relations_without_caching( |
| 189 | + self, schema_relation: BaseRelation, |
| 190 | + ) -> List[BaseRelation]: |
| 191 | + |
| 192 | + # Set database if not supplied |
| 193 | + if not self.config.credentials.database: |
| 194 | + self.config.credentials.database = self.execute_macro(GET_DATABASE_MACRO_NAME) |
| 195 | + |
| 196 | + kwargs = {'schema_relation': schema_relation} |
| 197 | + results = self.execute_macro( |
| 198 | + LIST_RELATIONS_MACRO_NAME, |
| 199 | + kwargs=kwargs |
| 200 | + ) |
| 201 | + relations = [] |
| 202 | + for _database, name, _schema, _type in results: |
| 203 | + try: |
| 204 | + _type = self.Relation.get_relation_type(_type) |
| 205 | + except ValueError: |
| 206 | + _type = self.Relation.External |
| 207 | + relations.append(self.Relation.create( |
| 208 | + database=_database, |
| 209 | + schema=_schema, |
| 210 | + identifier=name, |
| 211 | + quote_policy=self.config.quoting, |
| 212 | + type=_type |
| 213 | + )) |
| 214 | + return relations |
0 commit comments