Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(optimizer): Fix qualify for SEMI/ANTI joins #4622

Merged
merged 3 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2453,6 +2453,10 @@ def hint(self) -> str:
def alias_or_name(self) -> str:
return self.this.alias_or_name

@property
def is_semi_or_anti_join(self) -> bool:
return self.kind in ("SEMI", "ANTI")

def on(
self,
*expressions: t.Optional[ExpOrStr],
Expand Down
13 changes: 9 additions & 4 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def _update_source_columns(source_name: str) -> None:
join_columns = resolver.get_source_columns(join_table)
conditions = []
using_identifier_count = len(using)
is_semi_or_anti_join = join.is_semi_or_anti_join

for identifier in using:
identifier = identifier.name
Expand Down Expand Up @@ -208,10 +209,14 @@ def _update_source_columns(source_name: str) -> None:

# Set all values in the dict to None, because we only care about the key ordering
tables = column_tables.setdefault(identifier, {})
if table not in tables:
tables[table] = None
if join_table not in tables:
tables[join_table] = None

# Do not update the dict if this was a SEMI/ANTI join in
# order to avoid generating COALESCE columns for this join pair
if not is_semi_or_anti_join:
if table not in tables:
tables[table] = None
if join_table not in tables:
tables[join_table] = None

join.args.pop("using")
join.set("on", exp.and_(*conditions, copy=False))
Expand Down
20 changes: 19 additions & 1 deletion sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def clear_cache(self):
self._join_hints = None
self._pivots = None
self._references = None
self._semi_anti_join_tables = None

def branch(
self, expression, scope_type, sources=None, cte_sources=None, lateral_sources=None, **kwargs
Expand All @@ -126,6 +127,7 @@ def _collect(self):
self._raw_columns = []
self._stars = []
self._join_hints = []
self._semi_anti_join_tables = set()

for node in self.walk(bfs=False):
if node is self.expression:
Expand All @@ -139,6 +141,10 @@ def _collect(self):
else:
self._raw_columns.append(node)
elif isinstance(node, exp.Table) and not isinstance(node.parent, exp.JoinHint):
parent = node.parent
if isinstance(parent, exp.Join) and parent.is_semi_or_anti_join:
self._semi_anti_join_tables.add(node.alias_or_name)

self._tables.append(node)
elif isinstance(node, exp.JoinHint):
self._join_hints.append(node)
Expand Down Expand Up @@ -311,6 +317,11 @@ def selected_sources(self):
result = {}

for name, node in self.references:
if name in self._semi_anti_join_tables:
# The RHS table of SEMI/ANTI joins shouldn't be collected as a
# selected source
continue

if name in result:
raise OptimizeError(f"Alias already used: {name}")
if name in self.sources:
Expand Down Expand Up @@ -351,7 +362,10 @@ def external_columns(self):
self._external_columns = left.external_columns + right.external_columns
else:
self._external_columns = [
c for c in self.columns if c.table not in self.selected_sources
c
for c in self.columns
if c.table not in self.selected_sources
and c.table not in self.semi_or_anti_join_tables
]

return self._external_columns
Expand Down Expand Up @@ -387,6 +401,10 @@ def pivots(self):

return self._pivots

@property
def semi_or_anti_join_tables(self):
return self._semi_anti_join_tables or set()

def source_columns(self, source_name):
"""
Get all columns in the current scope for a particular source.
Expand Down
23 changes: 22 additions & 1 deletion tests/fixtures/optimizer/qualify_columns.sql
Original file line number Diff line number Diff line change
Expand Up @@ -785,4 +785,25 @@ SELECT X.A AS FOO FROM X AS X GROUP BY X.A = 1;
# dialect: snowflake
# execute: false
SELECT x.a AS foo FROM x WHERE foo = 1;
SELECT X.A AS FOO FROM X AS X WHERE X.A = 1;
SELECT X.A AS FOO FROM X AS X WHERE X.A = 1;


--------------------------------------
-- SEMI / ANTI Joins
--------------------------------------

# title: SEMI JOIN table is excluded from the scope
SELECT * FROM x SEMI JOIN y USING (b);
SELECT x.a AS a, x.b AS b FROM x AS x SEMI JOIN y AS y ON x.b = y.b;

# title: ANTI JOIN table is excluded from the scope
SELECT * FROM x ANTI JOIN y USING (b);
SELECT x.a AS a, x.b AS b FROM x AS x ANTI JOIN y AS y ON x.b = y.b;

# title: SEMI + normal joins reinclude the table on scope
SELECT * FROM x SEMI JOIN y USING (b) JOIN y USING (b);
SELECT x.a AS a, COALESCE(x.b, y_2.b) AS b, y_2.c AS c FROM x AS x SEMI JOIN y AS y ON x.b = y.b JOIN y AS y_2 ON x.b = y_2.b;

# title: ANTI + normal joins reinclude the table on scope
SELECT * FROM x ANTI JOIN y USING (b) JOIN y USING (b);
SELECT x.a AS a, COALESCE(x.b, y_2.b) AS b, y_2.c AS c FROM x AS x ANTI JOIN y AS y ON x.b = y.b JOIN y AS y_2 ON x.b = y_2.b;