sqlglot.transforms
from __future__ import annotations

import typing as t

from sqlglot import expressions as exp
from sqlglot.errors import UnsupportedError
from sqlglot.helper import find_new_name, name_sequence, seq_get


if t.TYPE_CHECKING:
    from sqlglot._typing import E
    from sqlglot.generator import Generator


class SqlHandler(t.Protocol):
    # Structural type for a `Generator.<key>_sql` method: takes an expression and
    # renders it to a SQL string.
    def __call__(self, expression: exp.Expr, *args: t.Any, **kwargs: t.Any) -> str: ...


def preprocess(
    transforms: list[t.Callable[[exp.Expr], exp.Expr]],
    generator: t.Callable[[Generator, exp.Expr], str] | None = None,
) -> t.Callable[[Generator, exp.Expr], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.
        generator: optional SQL generation function that, when given, is used to render the
            transformed expression instead of the handler-lookup logic below.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self: Generator, expression: exp.Expr) -> str:
        # Remember the input type so we can detect when the chain produced an
        # expression of the same type (infinite-recursion guard below).
        expression_type = type(expression)

        try:
            expression = transforms[0](expression)
            for transform in transforms[1:]:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            # Degrade gracefully: record the warning and render whatever we have so far.
            self.unsupported(str(unsupported_error))

        if generator:
            return generator(self, expression)

        # Prefer a dedicated `<key>_sql` method on the generator, if one exists.
        _sql_handler: SqlHandler | None = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expr type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expr) -> exp.Expr:
    """
    Rewrites `UNNEST(GENERATE_DATE_ARRAY(...))` table references into recursive CTEs.

    For dialects without GENERATE_DATE_ARRAY support, each qualifying unnest is replaced
    by a subquery selecting from a recursive CTE that iterates DATE_ADD steps from the
    start date up to (and including) the end date.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes: list[exp.Expr] = []

        for unnest in expression.find_all(exp.Unnest):
            # Only rewrite unnests used as table sources with a single
            # GENERATE_DATE_ARRAY argument.
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start: exp.Expr | None = generate_date_array.args.get("start")
            end: exp.Expr | None = generate_date_array.args.get("end")
            step: exp.Expr | None = generate_date_array.args.get("step")

            # The step must be a concrete INTERVAL for DATE_ADD to be constructible.
            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias: exp.TableAlias | None = unnest.args.get("alias")
            # NOTE: alias.columns holds Identifier nodes, not plain strings.
            column_name: str | exp.Identifier = (
                alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
            )

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            # Suffix subsequent CTEs so multiple rewrites in one SELECT don't collide.
            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            # Prepend the generated CTEs to any existing WITH clause and mark it recursive.
            with_expression: exp.With = expression.args.get("with_") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with_", with_expression)

    return expression
def unnest_generate_series(expression: exp.Expr) -> exp.Expr:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    this = expression.this
    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[this])
        if expression.alias:
            # Preserve the original table alias as the unnest's column alias.
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

        return unnest

    return expression


def eliminate_distinct_on(expression: exp.Expr) -> exp.Expr:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        # The DISTINCT ON columns become the window's PARTITION BY key.
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        order: exp.Order | None = expression.args.get("order")
        if order:
            window.set("order", order.pop())
        else:
            # Without an ORDER BY, fall back to ordering by the partition key itself.
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        expression.select(exp.alias_(window, row_number_window_alias), copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects: list[exp.Expr] = []
        taken_names = {row_number_window_alias}
        for select in expression.selects[:-1]:
            if select.is_star:
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                quoted: bool | None = (
                    select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                )
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        # Keep only the first row of each partition.
        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression
155 """ 156 if ( 157 isinstance(expression, exp.Select) 158 and expression.args.get("distinct") 159 and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple) 160 ): 161 row_number_window_alias = find_new_name(expression.named_selects, "_row_number") 162 163 distinct_cols = expression.args["distinct"].pop().args["on"].expressions 164 window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) 165 166 order: exp.Order | None = expression.args.get("order") 167 if order: 168 window.set("order", order.pop()) 169 else: 170 window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols])) 171 172 expression.select(exp.alias_(window, row_number_window_alias), copy=False) 173 174 # We add aliases to the projections so that we can safely reference them in the outer query 175 new_selects: list[exp.Expr] = [] 176 taken_names = {row_number_window_alias} 177 for select in expression.selects[:-1]: 178 if select.is_star: 179 new_selects = [exp.Star()] 180 break 181 182 if not isinstance(select, exp.Alias): 183 alias = find_new_name(taken_names, select.output_name or "_col") 184 quoted: bool | None = ( 185 select.this.args.get("quoted") if isinstance(select, exp.Column) else None 186 ) 187 select = select.replace(exp.alias_(select, alias, quoted=quoted)) 188 189 taken_names.add(select.output_name) 190 new_selects.append(select.args["alias"]) 191 192 return ( 193 exp.select(*new_selects, copy=False) 194 .from_(expression.subquery("_t", copy=False), copy=False) 195 .where(exp.column(row_number_window_alias).eq(1), copy=False) 196 ) 197 198 return expression 199 200 201def eliminate_qualify(expression: exp.Expr) -> exp.Expr: 202 """ 203 Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. 
204 205 The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: 206 https://docs.snowflake.com/en/sql-reference/constructs/qualify 207 208 Some dialects don't support window functions in the WHERE clause, so we need to include them as 209 projections in the subquery, in order to refer to them in the outer filter using aliases. Also, 210 if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, 211 otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a 212 newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the 213 corresponding expression to avoid creating invalid column references. 214 """ 215 if isinstance(expression, exp.Select) and expression.args.get("qualify"): 216 taken = set(expression.named_selects) 217 for select in expression.selects: 218 if not select.alias_or_name: 219 alias = find_new_name(taken, "_c") 220 select.replace(exp.alias_(select, alias)) 221 taken.add(alias) 222 223 def _select_alias_or_name(select: exp.Expr) -> str | exp.Column: 224 alias_or_name = select.alias_or_name 225 identifier = select.args.get("alias") or select.this 226 if isinstance(identifier, exp.Identifier): 227 return exp.column(alias_or_name, quoted=identifier.args.get("quoted")) 228 return alias_or_name 229 230 outer_selects = exp.select(*map(_select_alias_or_name, expression.selects)) 231 qualify_filters: exp.Expr = expression.args["qualify"].pop().this 232 expression_by_alias: dict[str, exp.Expr] = { 233 select.alias: select.this 234 for select in expression.selects 235 if isinstance(select, exp.Alias) 236 } 237 238 select_candidates = (exp.Window,) if expression.is_star else (exp.Window, exp.Column) 239 for select_candidate in list(qualify_filters.find_all(*select_candidates)): 240 if isinstance(select_candidate, exp.Window): 241 if expression_by_alias: 242 for column in select_candidate.find_all(exp.Column): 243 expr = 
def remove_precision_parameterized_types(expression: exp.Expr) -> exp.Expr:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transforms removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression


def unqualify_unnest(expression: exp.Expr) -> exp.Expr:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                # Drop the leading qualifier part (table/alias) when it refers to an unnest.
                leftmost_part = column.parts[0]
                if leftmost_part.arg_key != "this" and leftmost_part.this in unnest_aliases:
                    leftmost_part.pop()

    return expression


def unnest_to_explode(
    expression: exp.Expr,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expr:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: list[exp.Expr], has_multi_expr: bool
    ) -> list[exp.Expr]:
        # Collapse multiple unnest arguments into a single ARRAYS_ZIP call, or fail
        # if the target dialect can't express that.
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: list[exp.Expr] = [exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> type[exp.Func]:
        # Pick the table-generating function: POSEXPLODE when an offset is requested,
        # INLINE for zipped multi-arrays, otherwise plain EXPLODE.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from_")

        # Case 1: UNNEST used directly as the FROM source.
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest: exp.Unnest = from_.this
            alias: exp.TableAlias | None = unnest.args.get("alias")
            exprs: list[exp.Expr] = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *_ = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            columns: list[exp.Identifier] = alias.columns if alias else []
            offset: exp.Expr | None = unnest.args.get("offset")
            if offset:
                columns.insert(
                    0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")
                )

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(this=this),
                    alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None,
                )
            )

        # Case 2: UNNEST appearing in (possibly lateral) joins -> LATERAL VIEW EXPLODE.
        joins: list[exp.Join] = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")

                if alias is None:
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires an alias"
                    )

                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols: list[exp.Identifier] = alias.columns

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                offset = unnest.args.get("offset")
                if offset:
                    alias_cols.insert(
                        0,
                        offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos"),
                    )

                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=alias_cols),
                        ),
                    )

    return expression
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expr], exp.Expr]:
    """Convert explode/posexplode projections into unnests.

    Args:
        index_offset: the dialect's array index base (0 or 1); used to align the
            generated position series with the unnest offsets.

    Returns:
        A transform function suitable for `preprocess`.
    """

    def _explode_projection_to_unnest(expression: exp.Expr) -> exp.Expr:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: set[str], name: str) -> str:
                # Reserve and return a fresh name that doesn't collide with `names`.
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: list[exp.Condition] = []
            # A generated series provides the position index shared by all exploded arrays.
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias: t.Any = ""
                    explode_alias: t.Any = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias: exp.Expr = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        # Wrap the bare projection in an (empty) alias; re-find the
                        # explode since the replace invalidated the old reference.
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # EXPLODE_OUTER must yield a NULL row for empty/NULL arrays:
                        # substitute a one-element array of a safe NULL-ish access.
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Project the exploded value only on the row whose series position
                    # matches this unnest's offset.
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Also project the position itself, right after the value.
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # Attach the shared position series on the first explode only.
                        if expression.args.get("from_"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Align the shared series with each array: either positions match,
                    # or the series ran past this (shorter) array's end and we pin to
                    # its last element.
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must reach the end of the longest array.
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest
def add_within_group_for_percentiles(expression: exp.Expr) -> exp.Expr:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    needs_within_group = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not needs_within_group:
        return expression

    # Move the aggregated column out of the percentile call and into an
    # ORDER BY inside a wrapping WITHIN GROUP node.
    ordered_column = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=ordered_column)])
    return exp.WithinGroup(this=expression, expression=ordering)
def remove_within_group_for_percentiles(expression: exp.Expr) -> exp.Expr:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        # Rewrite as APPROX_QUANTILE(input, quantile) for dialects without WITHIN GROUP.
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression


def add_recursive_cte_column_names(expression: exp.Expr) -> exp.Expr:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        # Fallback names for projections without a usable output name.
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                # For set operations, derive the column names from the left-hand query.
                if isinstance(query, exp.SetOperation):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression


def epoch_cast_to_ts(expression: exp.Expr) -> exp.Expr:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
def eliminate_semi_and_anti_joins(expression: exp.Expr) -> exp.Expr:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if isinstance(expression, exp.Select):
        for join in list[exp.Join](expression.args.get("joins") or []):
            on: exp.Expr | None = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists: exp.Exists | exp.Not = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                # The join is removed and its predicate moves into the WHERE clause.
                join.pop()
                expression.where(exists, copy=False)

    return expression


def eliminate_full_outer_join(expression: exp.Expr) -> exp.Expr:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins: list[tuple[int, exp.Join]] = [
            (index, join)
            for index, join in enumerate[exp.Join](expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # Build the join condition; with USING, synthesize col-equality predicates.
            tables = (expression.args["from_"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in t.cast(list[exp.Identifier], full_outer_join.args.get("using"))
                ]
            )

            # LEFT side keeps all left rows; RIGHT side adds right-only rows via NOT EXISTS.
            full_outer_join.set("side", "left")
            anti_join_clause = (
                exp.select("1").from_(expression.args["from_"]).where(join_conditions)
            )
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.set("with_", None)  # remove CTEs from RIGHT side
            expression.set("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with: exp.With | None = expression.args.get("with_")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote the first nested one wholesale.
            top_level_with = inner_with.pop()
            expression.set("with_", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the hoisted CTEs before the CTE that referenced them, so
                # definition order still precedes use.
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression


def ensure_bools(expression: exp.Expr) -> exp.Expr:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    # NOTE: this local import deliberately shadows this function's own name; the
    # canonicalize helper drives the walk and calls back into `_ensure_bool`.
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expr) -> None:
        # NOTE(review): `exp.DType.UNKNOWN` differs from the `exp.DataType.*` spelling
        # used elsewhere in this file — presumably an alias; confirm it resolves.
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DType.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            # Rewrite `x` to `x <> 0` so the condition is explicitly boolean.
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression
def unqualify_columns(expression: exp.Expr) -> exp.Expr:
    """Strips table/db/catalog qualifiers from every column reference."""
    for column in expression.find_all(exp.Column):
        # We only wanna pop off the table, db, catalog args
        for part in column.parts[:-1]:
            part.pop()

    return expression


def remove_unique_constraints(expression: exp.Expr) -> exp.Expr:
    """Removes UNIQUE column constraints from a CREATE statement."""
    assert isinstance(expression, exp.Create)
    for constraint in expression.find_all(exp.UniqueColumnConstraint):
        if constraint.parent:
            constraint.parent.pop()

    return expression


def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expr,
    tmp_storage_provider: t.Callable[[exp.Expr], exp.Expr] = lambda e: e,
) -> exp.Expr:
    """Rewrites CREATE TEMPORARY TABLE ... AS into CREATE TEMPORARY VIEW.

    Args:
        expression: a CREATE statement (asserted).
        tmp_storage_provider: hook applied to temporary tables without a SELECT body,
            e.g. to attach dialect-specific storage properties.
    """
    assert isinstance(expression, exp.Create)
    properties: exp.Properties | None = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties is not None else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression


def move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    schema = expression.this
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if isinstance(schema, exp.Schema) and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            # Match partition columns case-insensitively against the schema.
            columns: set[str] = {v.name.upper() for v in prop.this.expressions}
            schema_exprs: list[exp.Expr] = schema.expressions
            partitions = [col for col in schema_exprs if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema_exprs if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression
779 """ 780 assert isinstance(expression, exp.Create) 781 schema = expression.this 782 is_partitionable = expression.kind in {"TABLE", "VIEW"} 783 784 if isinstance(schema, exp.Schema) and is_partitionable: 785 prop = expression.find(exp.PartitionedByProperty) 786 if prop and prop.this and not isinstance(prop.this, exp.Schema): 787 columns: set[str] = {v.name.upper() for v in prop.this.expressions} 788 schema_exprs: list[exp.Expr] = schema.expressions 789 partitions = [col for col in schema_exprs if col.name.upper() in columns] 790 schema.set("expressions", [e for e in schema_exprs if e not in partitions]) 791 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 792 expression.set("this", schema) 793 794 return expression 795 796 797def move_partitioned_by_to_schema_columns(expression: exp.Expr) -> exp.Expr: 798 """ 799 Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE. 800 801 Currently, SQLGlot uses the DATASOURCE format for Spark 3. 802 """ 803 assert isinstance(expression, exp.Create) 804 prop = expression.find(exp.PartitionedByProperty) 805 if ( 806 prop 807 and prop.this 808 and isinstance(prop.this, exp.Schema) 809 and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions) 810 ): 811 prop_this = exp.Tuple( 812 expressions=[exp.to_identifier(e.this) for e in prop.this.expressions] 813 ) 814 schema: exp.Schema = expression.this 815 for e in prop.this.expressions: 816 schema.append("expressions", e) 817 prop.set("this", prop_this) 818 819 return expression 820 821 822def struct_kv_to_alias(expression: exp.Expr) -> exp.Expr: 823 """Converts struct arguments to aliases, e.g. 
def eliminate_join_marks(expression: exp.Expr) -> exp.Expr:
    """https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

    -- example with WHERE
    SELECT d.department_name, sum(e.salary) as total_salary
    FROM departments d, employees e
    WHERE e.department_id(+) = d.department_id
    group by department_name

    -- example of left correlation in select
    SELECT d.department_name, (
        SELECT SUM(e.salary)
        FROM employees e
        WHERE e.department_id(+) = d.department_id) AS total_salary
    FROM departments d;

    -- example of left correlation in from
    SELECT d.department_name, t.total_salary
    FROM departments d, (
        SELECT SUM(e.salary) AS total_salary
        FROM employees e
        WHERE e.department_id(+) = d.department_id
    ) t
    """

    from sqlglot.optimizer.scope import traverse_scope
    from sqlglot.optimizer.normalize import normalize, normalized
    from collections import defaultdict

    # we go in reverse to check the main query for left correlation
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        where: exp.Expr | None = query.args.get("where")
        joins: list[exp.Join] = query.args.get("joins", [])

        # Skip scopes with no (+) marks at all.
        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
            continue

        # knockout: we do not support left correlation (see point 2)
        assert not scope.is_correlated_subquery, "Correlated queries are not supported"

        # make sure we have AND of ORs to have clear join terms
        where = normalize(where.this)
        assert normalized(where), "Cannot normalize JOIN predicates"
        # dict of {name: list of join AND conditions}
        joins_ons: defaultdict[str, list[exp.Expr]] = defaultdict(list)
        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]

            left_join_table = set(col.table for col in join_cols)
            if not left_join_table:
                continue

            # All (+) marks in one conjunct must point at a single table.
            assert not (len(left_join_table) > 1), (
                "Cannot combine JOIN predicates from different tables"
            )

            for col in join_cols:
                col.set("join_mark", False)

            joins_ons[left_join_table.pop()].append(cond)

        old_joins = {join.alias_or_name: join for join in joins}
        new_joins: dict[str, exp.Join] = {}
        query_from = query.args["from_"]

        for table, predicates in joins_ons.items():
            # The marked table becomes the right side of a LEFT JOIN.
            join_what = old_joins.get(table, query_from).this.copy()
            new_joins[join_what.alias_or_name] = exp.Join(
                this=join_what, on=exp.and_(*predicates), kind="LEFT"
            )

            # Detach each predicate from the WHERE clause, patching up the tree
            # (parens, dangling binary operands, emptied WHERE) as we go.
            for p in predicates:
                while isinstance(p.parent, exp.Paren):
                    p.parent.replace(p)

                parent = p.parent
                p.pop()
                if isinstance(parent, exp.Binary):
                    left = parent.args.get("this")
                    parent.replace(parent.right if left is None else left)
                elif isinstance(parent, exp.Where):
                    parent.pop()

        if query_from.alias_or_name in new_joins:
            # The FROM table itself became a join target; pick a remaining table for FROM.
            only_old_joins: set[str] = old_joins.keys() - new_joins.keys()
            assert len(only_old_joins) >= 1, (
                "Cannot determine which table to use in the new FROM clause"
            )

            new_from_name = list[str](only_old_joins)[0]
            query.set("from_", exp.From(this=old_joins[new_from_name].this))

        if new_joins:
            for n, j in old_joins.items():  # preserve any other joins
                if n not in new_joins and n != query.args["from_"].name:
                    if not j.kind:
                        j.set("kind", "CROSS")
                    new_joins[n] = j
            query.set("joins", list(new_joins.values()))

    return expression
different tables" 908 ) 909 910 for col in join_cols: 911 col.set("join_mark", False) 912 913 joins_ons[left_join_table.pop()].append(cond) 914 915 old_joins = {join.alias_or_name: join for join in joins} 916 new_joins: dict[str, exp.Join] = {} 917 query_from = query.args["from_"] 918 919 for table, predicates in joins_ons.items(): 920 join_what = old_joins.get(table, query_from).this.copy() 921 new_joins[join_what.alias_or_name] = exp.Join( 922 this=join_what, on=exp.and_(*predicates), kind="LEFT" 923 ) 924 925 for p in predicates: 926 while isinstance(p.parent, exp.Paren): 927 p.parent.replace(p) 928 929 parent = p.parent 930 p.pop() 931 if isinstance(parent, exp.Binary): 932 left = parent.args.get("this") 933 parent.replace(parent.right if left is None else left) 934 elif isinstance(parent, exp.Where): 935 parent.pop() 936 937 if query_from.alias_or_name in new_joins: 938 only_old_joins: set[str] = old_joins.keys() - new_joins.keys() 939 assert len(only_old_joins) >= 1, ( 940 "Cannot determine which table to use in the new FROM clause" 941 ) 942 943 new_from_name = list[str](only_old_joins)[0] 944 query.set("from_", exp.From(this=old_joins[new_from_name].this)) 945 946 if new_joins: 947 for n, j in old_joins.items(): # preserve any other joins 948 if n not in new_joins and n != query.args["from_"].name: 949 if not j.kind: 950 j.set("kind", "CROSS") 951 new_joins[n] = j 952 query.set("joins", list(new_joins.values())) 953 954 return expression 955 956 957def any_to_exists(expression: exp.Expr) -> exp.Expr: 958 """ 959 Transform ANY operator to Spark's EXISTS 960 961 For example, 962 - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col) 963 - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5) 964 965 Both ANY and EXISTS accept queries but currently only array expressions are supported for this 966 transformation 967 """ 968 if isinstance(expression, exp.Select): 969 for any_expr in expression.find_all(exp.Any): 970 this: exp.Expr = any_expr.this 971 if 
isinstance(this, exp.Query) or isinstance(any_expr.parent, (exp.Like, exp.ILike)):
                continue

            binop = any_expr.parent
            if isinstance(binop, exp.Binary):
                # Turn `<expr> <op> ANY(arr)` into `EXISTS(arr, x -> x <op'> <expr>)`:
                # the ANY node is replaced by the lambda variable inside a copy of the
                # comparison, which then becomes the lambda body.
                lambda_arg = exp.to_identifier("x")
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression


def eliminate_window_clause(expression: exp.Expr) -> exp.Expr:
    """Eliminates the `WINDOW` query clause by inlining each named window."""
    windows: list[exp.Expr] | None = expression.args.get("windows")
    if isinstance(expression, exp.Select) and windows is not None:
        from sqlglot.optimizer.scope import find_all_in_scope

        expression.set("windows", None)

        # Maps each named window (lowercased name) to its definition
        window_expression: dict[str, exp.Expr] = {}

        def _inline_inherited_window(window: exp.Expr) -> None:
            # A window can extend a previously defined named window by referencing its
            # alias (e.g. `WINDOW w1 AS (...), w2 AS (w1 ORDER BY ...)`); copy the
            # inherited components into this window and drop the reference.
            inherited_window = window_expression.get(window.alias.lower())
            if not inherited_window:
                return

            window.set("alias", None)
            for key in ("partition_by", "order", "spec"):
                arg: exp.Expr | None = inherited_window.args.get(key)
                if arg is not None:
                    window.set(key, arg.copy())

        # First resolve inheritance between the named windows themselves (in order),
        # then inline into every window reference in the query body.
        for window in windows:
            _inline_inherited_window(window)
            window_expression[window.name.lower()] = window

        for window in find_all_in_scope(expression, exp.Window):
            _inline_inherited_window(window)

    return expression


def inherit_struct_field_names(expression: exp.Expr) -> exp.Expr:
    """
    Inherit field names from the first struct in an array.

    BigQuery supports implicitly inheriting names from the first STRUCT in an array:

    Example:
        ARRAY[
            STRUCT('Alice' AS name, 85 AS score),  -- defines names
            STRUCT('Bob', 92),                     -- inherits names
            STRUCT('Diana', 95)                    -- inherits names
        ]

    This transformation makes the field names explicit on all structs by adding
    PropertyEQ nodes, in order to facilitate transpilation to other dialects.

    Args:
        expression: The expression tree to transform

    Returns:
        The modified expression with field names inherited in all structs
    """
    if (
        isinstance(expression, exp.Array)
        and expression.args.get("struct_name_inheritance")
        and isinstance(first_item := seq_get(expression.expressions, 0), exp.Struct)
        and all(isinstance(fld, exp.PropertyEQ) for fld in first_item.expressions)
    ):
        field_names: list[exp.Identifier] = [fld.this for fld in first_item.expressions]

        # Apply field names to subsequent structs that don't have them
        for struct in expression.expressions[1:]:
            # Skip non-structs and structs whose arity doesn't match the template
            if not isinstance(struct, exp.Struct) or len(struct.expressions) != len(field_names):
                continue

            # Convert unnamed expressions to PropertyEQ with inherited names
            new_expressions: list[exp.PropertyEQ] = []
            for i, expr in enumerate(struct.expressions):
                if not isinstance(expr, exp.PropertyEQ):
                    # Create PropertyEQ: field_name := value, preserving the type from the inner expression
                    property_eq = exp.PropertyEQ(
                        this=field_names[i].copy(),
                        expression=expr,
                    )
                    property_eq.type = expr.type
                    new_expressions.append(property_eq)
                else:
                    new_expressions.append(expr)

            struct.set("expressions", new_expressions)

    return expression
class SqlHandler(t.Protocol):
    """Structural type of a generator's per-expression `<key>_sql` methods."""

    def __call__(self, expression: exp.Expr, *args: t.Any, **kwargs: t.Any) -> str: ...
Base class for protocol classes.
Protocol classes are defined as::
class Proto(Protocol):
def meth(self) -> int:
...
Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example::
class C:
def meth(self) -> int:
return 0
def func(x: Proto) -> int:
return x.meth()
func(C()) # Passes static type check
See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::
class GenProto(Protocol[T]):
def meth(self) -> T:
...
def _no_init_or_replace_init(self, *args, **kwargs):
    # NOTE(review): this function is CPython's `typing` module internals (the Protocol
    # `__init__` machinery) that the documentation tool rendered into this page -- it is
    # not part of sqlglot itself.
    cls = type(self)

    if cls._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # Already using a custom `__init__`. No need to calculate correct
    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
    if cls.__init__ is not _no_init_or_replace_init:
        return

    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
    # searches for a proper new `__init__` in the MRO. The new `__init__`
    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
    # instantiation of the protocol subclass will thus use the new
    # `__init__` and no longer call `_no_init_or_replace_init`.
    for base in cls.__mro__:
        init = base.__dict__.get('__init__', _no_init_or_replace_init)
        if init is not _no_init_or_replace_init:
            cls.__init__ = init
            break
    else:
        # should not happen
        cls.__init__ = object.__init__

    cls.__init__(self, *args, **kwargs)
def preprocess(
    transforms: list[t.Callable[[exp.Expr], exp.Expr]],
    generator: t.Callable[[Generator, exp.Expr], str] | None = None,
) -> t.Callable[[Generator, exp.Expr], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.
        generator: optional renderer for the transformed expression; when given, it
            bypasses the `_sql` / `TRANSFORMS` lookup described above.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self: Generator, expression: exp.Expr) -> str:
        # Remember the input type so we can detect when no transform changed it (see below)
        expression_type = type(expression)

        try:
            expression = transforms[0](expression)
            for transform in transforms[1:]:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            # Surface the failure through the generator's unsupported-handling (which,
            # depending on its configuration, may warn or raise) and continue with the
            # expression as transformed so far.
            self.unsupported(str(unsupported_error))

        if generator:
            return generator(self, expression)

        _sql_handler: SqlHandler | None = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expr type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate Generator.TRANSFORMS function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expr) -> exp.Expr:
    """
    Rewrites `UNNEST(GENERATE_DATE_ARRAY(start, end, step))` table sources into
    recursive CTEs, for dialects that support `WITH RECURSIVE` but not date-array
    generation. Each qualifying unnest is replaced by a subquery selecting from a
    generated CTE that starts at `start` and repeatedly adds `step` until `end`.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes: list[exp.Expr] = []

        for unnest in expression.find_all(exp.Unnest):
            # Only rewrite unnests used as table sources over a single GENERATE_DATE_ARRAY
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start: exp.Expr | None = generate_date_array.args.get("start")
            end: exp.Expr | None = generate_date_array.args.get("end")
            step: exp.Expr | None = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias: exp.TableAlias | None = unnest.args.get("alias")
            column_name: str = (
                alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
            )

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            # Suffix the CTE name so multiple rewrites in one query don't collide
            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            # Anchor member yields the start date; recursive member keeps adding `step`
            # while the result is still <= `end`
            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            # Prepend the generated CTEs to any existing WITH clause and mark it recursive
            with_expression: exp.With = expression.args.get("with_") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with_", with_expression)

    return expression
def unnest_generate_series(expression: exp.Expr) -> exp.Expr:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    series = expression.this
    if not (isinstance(expression, exp.Table) and isinstance(series, exp.GenerateSeries)):
        return expression

    unnested = exp.Unnest(expressions=[series])

    # Keep the table's alias as the unnest's column alias, under a fresh source name
    table_alias = expression.alias
    if table_alias:
        return exp.alias_(unnested, alias="_u", table=[table_alias], copy=False)

    return unnested
Unnests GENERATE_SERIES or SEQUENCE table references.
def eliminate_distinct_on(expression: exp.Expr) -> exp.Expr:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        # The DISTINCT ON columns become the ROW_NUMBER window's PARTITION BY clause
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        order: exp.Order | None = expression.args.get("order")
        if order:
            window.set("order", order.pop())
        else:
            # Without an ORDER BY, order within each partition by the DISTINCT ON columns
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        expression.select(exp.alias_(window, row_number_window_alias), copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects: list[exp.Expr] = []
        taken_names = {row_number_window_alias}
        # selects[:-1] excludes the ROW_NUMBER projection that was just appended
        for select in expression.selects[:-1]:
            if select.is_star:
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                quoted: bool | None = (
                    select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                )
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        # Keep only the first row per partition in the outer query
        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_qualify(expression: exp.Expr) -> exp.Expr:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Give every anonymous projection a name so the outer query can reference it
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expr) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                # Preserve quoting so the outer reference resolves identically
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*map(_select_alias_or_name, expression.selects))
        qualify_filters: exp.Expr = expression.args["qualify"].pop().this
        expression_by_alias: dict[str, exp.Expr] = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection every column is already available to the outer query,
        # so only window functions need to be hoisted
        select_candidates = (exp.Window,) if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in list(qualify_filters.find_all(*select_candidates)):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Replace references to newly aliased projections by their definitions
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window function in the subquery and reference it by alias
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A column used only in QUALIFY must still be selected in the subquery
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
def remove_precision_parameterized_types(expression: exp.Expr) -> exp.Expr:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and
    not in other expressions. This transform removes the precision from parameterized types in
    expressions.
    """
    for data_type in expression.find_all(exp.DataType):
        kept = [child for child in data_type.expressions if not isinstance(child, exp.DataTypeParam)]
        data_type.set("expressions", kept)

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
def unqualify_unnest(expression: exp.Expr) -> exp.Expr:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    # Collect the aliases of unnests that act as table sources in this scope
    unnest_aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            unnest_aliases.add(unnest.alias)

    if unnest_aliases:
        for column in expression.find_all(exp.Column):
            first_part = column.parts[0]
            # Only strip a *qualifier* part (never the column identifier itself)
            if first_part.arg_key != "this" and first_part.this in unnest_aliases:
                first_part.pop()

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(
    expression: exp.Expr,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expr:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: list[exp.Expr], has_multi_expr: bool
    ) -> list[exp.Expr]:
        # Collapse multiple unnested arrays into a single ARRAYS_ZIP argument
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: list[exp.Expr] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> type[exp.Func]:
        # Pick the table-generating function: POSEXPLODE when an offset column is
        # requested, INLINE for zipped multi-array input, EXPLODE otherwise
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from_")

        # Case 1: the unnest is the FROM source itself -> rewrite in place as a UDTF table
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest: exp.Unnest = from_.this
            alias: exp.TableAlias | None = unnest.args.get("alias")
            exprs: list[exp.Expr] = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *_ = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            columns: list[exp.Identifier] = alias.columns if alias else []
            offset: exp.Expr | None = unnest.args.get("offset")
            if offset:
                # The position column comes first in the UDTF's output
                columns.insert(
                    0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")
                )

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(this=this),
                    alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None,
                )
            )

        # Case 2: unnests appearing as (possibly lateral) joins -> move them to LATERAL VIEWs
        joins: list[exp.Join] = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")

                if alias is None:
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires an alias"
                    )

                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols: list[exp.Identifier] = alias.columns

                # Handle UNNEST to LATERAL VIEW EXPLODE: an exception is raised when there are 0 or > 2 aliases.
                # Spark LATERAL VIEW EXPLODE requires a single alias for array/struct and two for Map type
                # columns, unlike unnest in Trino/Presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html
                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                offset = unnest.args.get("offset")
                if offset:
                    alias_cols.insert(
                        0,
                        offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos"),
                    )

                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=alias_cols),
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expr], exp.Expr]:
    """
    Convert explode/posexplode projections into unnests.

    Args:
        index_offset: the first index of the dialect's generated series / array offsets
            (0 or 1), used to align positions between the series and the unnested arrays.

    Returns:
        The transform function to be used with a generator.
    """

    def _explode_projection_to_unnest(expression: exp.Expr) -> exp.Expr:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: set[str], name: str) -> str:
                # Generate a fresh, unused name and reserve it
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One shared GENERATE_SERIES source provides the positions for all explodes
            arrays: list[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias: t.Any = ""
                    explode_alias: t.Any = ""

                    # Normalize the projection into an exp.Alias so it can be renamed below
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias: exp.Expr = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # EXPLODE_OUTER keeps a row even for empty/NULL arrays: substitute
                        # a single-element array (a safe, offset bracket access) in that case
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Emit the exploded value only on the row where the shared series
                    # position matches this unnest's own position
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Also project the position column, right after the value column
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        # Attach the shared series source once, on the first explode seen
                        if expression.args.get("from_"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Pair series positions with array positions; the second disjunct keeps
                    # the last element paired when this array is shorter than the longest one
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must run to the length of the longest exploded array
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest
Convert explode/posexplode projections into unnests.
def add_within_group_for_percentiles(expression: exp.Expr) -> exp.Expr:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    is_bare_percentile = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not is_bare_percentile:
        return expression

    # Swap the arguments: the quantile becomes the function's argument, while the
    # original input column moves into the WITHIN GROUP's ORDER BY.
    input_column = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=input_column)])
    return exp.WithinGroup(this=expression, expression=ordering)
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expr) -> exp.Expr:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    is_percentile_within_group = (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    )
    if not is_percentile_within_group:
        return expression

    # Collapse `PERCENTILE(q) WITHIN GROUP (ORDER BY col)` into `APPROX_QUANTILE(col, q)`
    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expr) -> exp.Expr:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    generate_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            # For UNION-style recursive CTEs, the anchor member defines the columns
            query = query.this

        cte_alias.set(
            "columns",
            [
                exp.to_identifier(select.alias_or_name or generate_name())
                for select in query.selects
            ],
        )

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
def epoch_cast_to_ts(expression: exp.Expr) -> exp.Expr:
    """Replace 'epoch' in casts by the equivalent date literal."""
    is_epoch_cast = (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    )
    if is_epoch_cast:
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
Replace 'epoch' in casts by the equivalent date literal.
def eliminate_semi_and_anti_joins(expression: exp.Expr) -> exp.Expr:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if not isinstance(expression, exp.Select):
        return expression

    for join in list[exp.Join](expression.args.get("joins") or []):
        on: exp.Expr | None = join.args.get("on")
        if not on or join.kind not in ("SEMI", "ANTI"):
            continue

        # SEMI join -> WHERE EXISTS(...); ANTI join -> WHERE NOT EXISTS(...)
        subquery = exp.select("1").from_(join.this).where(on)
        predicate: exp.Exists | exp.Not = exp.Exists(this=subquery)
        if join.kind == "ANTI":
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
def eliminate_full_outer_join(expression: exp.Expr) -> exp.Expr:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins: list[tuple[int, exp.Join]] = [
            (index, join)
            for index, join in enumerate[exp.Join](expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # Without an ON clause, rebuild equality predicates from the USING columns
            tables = (expression.args["from_"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in t.cast(list[exp.Identifier], full_outer_join.args.get("using"))
                ]
            )

            # LEFT side keeps all left rows; RIGHT side keeps only right rows with no
            # left match (NOT EXISTS anti-join) so the union produces no duplicates
            full_outer_join.set("side", "left")
            anti_join_clause = (
                exp.select("1").from_(expression.args["from_"]).where(join_conditions)
            )
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.set("with_", None)  # remove CTEs from RIGHT side
            expression.set("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with: exp.With | None = expression.args.get("with_")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote this inner one wholesale
            top_level_with = inner_with.pop()
            expression.set("with_", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the moved CTEs *before* the CTE that contained them, so that
                # definitions still precede their uses
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
def ensure_bools(expression: exp.Expr) -> exp.Expr:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _to_boolean(node: exp.Expr) -> None:
        # Rewrite the condition operand as `node <> 0`; the three checks below
        # are ordered to preserve the original short-circuit evaluation.
        if node.is_number:
            node.replace(node.neq(0))
            return

        if not isinstance(node, exp.SubqueryPredicate) and node.is_type(
            exp.DType.UNKNOWN, *exp.DataType.NUMERIC_TYPES
        ):
            node.replace(node.neq(0))
            return

        if isinstance(node, exp.Column) and not node.type:
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _to_boolean)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expr,
    tmp_storage_provider: t.Callable[[exp.Expr], exp.Expr] = lambda e: e,
) -> exp.Expr:
    """Maps CTAS on temporary tables to CREATE TEMPORARY VIEW statements."""
    assert isinstance(expression, exp.Create)

    properties: exp.Properties | None = expression.args.get("properties")
    prop_list = properties.expressions if properties is not None else []
    is_temporary = any(isinstance(prop, exp.TemporaryProperty) for prop in prop_list)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.expression:
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    # No CTAS query: delegate storage handling to the provided hook
    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if prop and prop.this and not isinstance(prop.this, exp.Schema):
        partition_names: set[str] = {v.name.upper() for v in prop.this.expressions}
        schema_exprs: list[exp.Expr] = schema.expressions

        # Split the schema: partition columns move into the PARTITIONED BY schema
        partitions = [col for col in schema_exprs if col.name.upper() in partition_names]
        schema.set("expressions", [e for e in schema_exprs if e not in partitions])
        prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
        expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expr) -> exp.Expr:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this and isinstance(prop.this, exp.Schema)):
        return expression

    partition_defs = prop.this.expressions
    if not all(isinstance(e, exp.ColumnDef) and e.kind for e in partition_defs):
        return expression

    # Keep only the column identifiers in PARTITIONED BY; the typed
    # definitions move into the table's schema below
    identifiers = exp.Tuple(
        expressions=[exp.to_identifier(e.this) for e in partition_defs]
    )

    schema: exp.Schema = expression.this
    for column_def in partition_defs:
        schema.append("expressions", column_def)

    prop.set("this", identifiers)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
def struct_kv_to_alias(expression: exp.Expr) -> exp.Expr:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        converted: list[exp.Expr] = []
        for arg in expression.expressions:
            if isinstance(arg, exp.PropertyEQ):
                # key := value becomes value AS key
                converted.append(exp.alias_(arg.expression, arg.this))
            else:
                converted.append(arg)

        expression.set("expressions", converted)

    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def eliminate_join_marks(expression: exp.Expr) -> exp.Expr:
    """https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

    -- example with WHERE
    SELECT d.department_name, sum(e.salary) as total_salary
    FROM departments d, employees e
    WHERE e.department_id(+) = d.department_id
    group by department_name

    -- example of left correlation in select
    SELECT d.department_name, (
        SELECT SUM(e.salary)
            FROM employees e
            WHERE e.department_id(+) = d.department_id) AS total_salary
    FROM departments d;

    -- example of left correlation in from
    SELECT d.department_name, t.total_salary
    FROM departments d, (
            SELECT SUM(e.salary) AS total_salary
            FROM employees e
            WHERE e.department_id(+) = d.department_id
    ) t
    """

    from sqlglot.optimizer.scope import traverse_scope
    from sqlglot.optimizer.normalize import normalize, normalized
    from collections import defaultdict

    # we go in reverse to check the main query for left correlation
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        where: exp.Expr | None = query.args.get("where")
        joins: list[exp.Join] = query.args.get("joins", [])

        # only process queries that actually use the (+) join mark
        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
            continue

        # knockout: we do not support left correlation (see point 2)
        assert not scope.is_correlated_subquery, "Correlated queries are not supported"

        # make sure we have AND of ORs to have clear join terms
        where = normalize(where.this)
        assert normalized(where), "Cannot normalize JOIN predicates"

        # dict of {name: list of join AND conditions}
        joins_ons: defaultdict[str, list[exp.Expr]] = defaultdict(list)
        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]

            # idiomatic set comprehension instead of set(generator)
            left_join_table = {col.table for col in join_cols}
            if not left_join_table:
                continue

            # a single AND term may only reference one marked table
            assert len(left_join_table) == 1, (
                "Cannot combine JOIN predicates from different tables"
            )

            for col in join_cols:
                col.set("join_mark", False)

            joins_ons[left_join_table.pop()].append(cond)

        old_joins = {join.alias_or_name: join for join in joins}
        new_joins: dict[str, exp.Join] = {}
        query_from = query.args["from_"]

        for table, predicates in joins_ons.items():
            join_what = old_joins.get(table, query_from).this.copy()
            new_joins[join_what.alias_or_name] = exp.Join(
                this=join_what, on=exp.and_(*predicates), kind="LEFT"
            )

            # detach the predicates from the WHERE clause, collapsing any
            # parent binary/paren nodes they leave behind
            for p in predicates:
                while isinstance(p.parent, exp.Paren):
                    p.parent.replace(p)

                parent = p.parent
                p.pop()
                if isinstance(parent, exp.Binary):
                    left = parent.args.get("this")
                    parent.replace(parent.right if left is None else left)
                elif isinstance(parent, exp.Where):
                    parent.pop()

        if query_from.alias_or_name in new_joins:
            only_old_joins: set[str] = old_joins.keys() - new_joins.keys()
            assert len(only_old_joins) >= 1, (
                "Cannot determine which table to use in the new FROM clause"
            )

            # plain list(...) — list[str](...) calls a GenericAlias for no benefit
            new_from_name = list(only_old_joins)[0]
            query.set("from_", exp.From(this=old_joins[new_from_name].this))

        if new_joins:
            for n, j in old_joins.items():  # preserve any other joins
                if n not in new_joins and n != query.args["from_"].name:
                    if not j.kind:
                        j.set("kind", "CROSS")
                    new_joins[n] = j
            query.set("joins", list(new_joins.values()))

    return expression
https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178
You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.
The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.
The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.
You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.
The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.
A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.
A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.
A WHERE condition cannot compare any column marked with the (+) operator with a subquery.
-- example with WHERE SELECT d.department_name, sum(e.salary) as total_salary FROM departments d, employees e WHERE e.department_id(+) = d.department_id group by department_name
-- example of left correlation in select SELECT d.department_name, ( SELECT SUM(e.salary) FROM employees e WHERE e.department_id(+) = d.department_id) AS total_salary FROM departments d;
-- example of left correlation in from SELECT d.department_name, t.total_salary FROM departments d, ( SELECT SUM(e.salary) AS total_salary FROM employees e WHERE e.department_id(+) = d.department_id ) t
def any_to_exists(expression: exp.Expr) -> exp.Expr:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if not isinstance(expression, exp.Select):
        return expression

    for any_expr in expression.find_all(exp.Any):
        target: exp.Expr = any_expr.this

        # Skip subqueries and pattern-matching (LIKE/ILIKE) contexts
        if isinstance(target, exp.Query) or isinstance(any_expr.parent, (exp.Like, exp.ILike)):
            continue

        comparison = any_expr.parent
        if isinstance(comparison, exp.Binary):
            # Turn `expr op ANY(arr)` into `EXISTS(arr, x -> expr op x)`
            lambda_arg = exp.to_identifier("x")
            any_expr.replace(lambda_arg)
            lambda_expr = exp.Lambda(this=comparison.copy(), expressions=[lambda_arg])
            comparison.replace(exp.Exists(this=target.unnest(), expression=lambda_expr))

    return expression
Transform ANY operator to Spark's EXISTS
For example, - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col) - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)
Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation
def eliminate_window_clause(expression: exp.Expr) -> exp.Expr:
    """Eliminates the `WINDOW` query clause by inlining each named window."""
    windows: list[exp.Expr] | None = expression.args.get("windows")
    if not isinstance(expression, exp.Select) or windows is None:
        return expression

    from sqlglot.optimizer.scope import find_all_in_scope

    expression.set("windows", None)

    named_windows: dict[str, exp.Expr] = {}

    def _inline(window: exp.Expr) -> None:
        # Copy the partition/order/frame spec from the named window that
        # this window's alias refers to, if any
        source = named_windows.get(window.alias.lower())
        if not source:
            return

        window.set("alias", None)
        for key in ("partition_by", "order", "spec"):
            value: exp.Expr | None = source.args.get(key)
            if value is not None:
                window.set(key, value.copy())

    # Named windows may themselves inherit from earlier named windows,
    # so resolve them in declaration order before registering each one
    for window in windows:
        _inline(window)
        named_windows[window.name.lower()] = window

    for window in find_all_in_scope(expression, exp.Window):
        _inline(window)

    return expression
Eliminates the WINDOW query clause by inlining each named window.
def inherit_struct_field_names(expression: exp.Expr) -> exp.Expr:
    """
    Inherit field names from the first struct in an array.

    BigQuery supports implicitly inheriting names from the first STRUCT in an array:

    Example:
        ARRAY[
          STRUCT('Alice' AS name, 85 AS score), -- defines names
          STRUCT('Bob', 92),                    -- inherits names
          STRUCT('Diana', 95)                   -- inherits names
        ]

    This transformation makes the field names explicit on all structs by adding
    PropertyEQ nodes, in order to facilitate transpilation to other dialects.

    Args:
        expression: The expression tree to transform

    Returns:
        The modified expression with field names inherited in all structs
    """
    if not (isinstance(expression, exp.Array) and expression.args.get("struct_name_inheritance")):
        return expression

    template = seq_get(expression.expressions, 0)
    if not isinstance(template, exp.Struct) or not all(
        isinstance(fld, exp.PropertyEQ) for fld in template.expressions
    ):
        return expression

    field_names: list[exp.Identifier] = [fld.this for fld in template.expressions]

    # Apply field names to subsequent structs that don't have them
    for struct in expression.expressions[1:]:
        if not isinstance(struct, exp.Struct) or len(struct.expressions) != len(field_names):
            continue

        named_fields: list[exp.Expr] = []
        for name, value in zip(field_names, struct.expressions):
            if isinstance(value, exp.PropertyEQ):
                named_fields.append(value)
            else:
                # Create PropertyEQ: field_name := value, preserving the type
                # from the inner expression
                prop = exp.PropertyEQ(this=name.copy(), expression=value)
                prop.type = value.type
                named_fields.append(prop)

        struct.set("expressions", named_fields)

    return expression
Inherit field names from the first struct in an array.
BigQuery supports implicitly inheriting names from the first STRUCT in an array:
Example:
ARRAY[ STRUCT('Alice' AS name, 85 AS score), -- defines names STRUCT('Bob', 92), -- inherits names STRUCT('Diana', 95) -- inherits names ]
This transformation makes the field names explicit on all structs by adding PropertyEQ nodes, in order to facilitate transpilation to other dialects.
Arguments:
- expression: The expression tree to transform
Returns:
The modified expression with field names inherited in all structs