Edit on GitHub

sqlglot.transforms

   1from __future__ import annotations
   2
   3import typing as t
   4
   5from sqlglot import expressions as exp
   6from sqlglot.errors import UnsupportedError
   7from sqlglot.helper import find_new_name, name_sequence, seq_get
   8
   9
  10if t.TYPE_CHECKING:
  11    from sqlglot._typing import E
  12    from sqlglot.generator import Generator
  13
  14
class SqlHandler(t.Protocol):
    """Structural type for a generator's per-expression `<key>_sql` method.

    Matches bound methods looked up via ``getattr(self, expression.key + "_sql")``
    in `preprocess`: they take the expression to render (plus any extra
    positional/keyword arguments) and return the generated SQL string.
    """

    def __call__(self, expression: exp.Expr, *args: t.Any, **kwargs: t.Any) -> str: ...
  17
  18
  19def preprocess(
  20    transforms: list[t.Callable[[exp.Expr], exp.Expr]],
  21    generator: t.Callable[[Generator, exp.Expr], str] | None = None,
  22) -> t.Callable[[Generator, exp.Expr], str]:
  23    """
  24    Creates a new transform by chaining a sequence of transformations and converts the resulting
  25    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
  26    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
  27
  28    Args:
  29        transforms: sequence of transform functions. These will be called in order.
  30
  31    Returns:
  32        Function that can be used as a generator transform.
  33    """
  34
  35    def _to_sql(self: Generator, expression: exp.Expr) -> str:
  36        expression_type = type(expression)
  37
  38        try:
  39            expression = transforms[0](expression)
  40            for transform in transforms[1:]:
  41                expression = transform(expression)
  42        except UnsupportedError as unsupported_error:
  43            self.unsupported(str(unsupported_error))
  44
  45        if generator:
  46            return generator(self, expression)
  47
  48        _sql_handler: SqlHandler | None = getattr(self, expression.key + "_sql", None)
  49        if _sql_handler:
  50            return _sql_handler(expression)
  51
  52        transforms_handler = self.TRANSFORMS.get(type(expression))
  53        if transforms_handler:
  54            if expression_type is type(expression):
  55                if isinstance(expression, exp.Func):
  56                    return self.function_fallback_sql(expression)
  57
  58                # Ensures we don't enter an infinite loop. This can happen when the original expression
  59                # has the same type as the final expression and there's no _sql method available for it,
  60                # because then it'd re-enter _to_sql.
  61                raise ValueError(
  62                    f"Expr type {expression.__class__.__name__} requires a _sql method in order to be transformed."
  63                )
  64
  65            return transforms_handler(self, expression)
  66
  67        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
  68
  69    return _to_sql
  70
  71
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expr) -> exp.Expr:
    """Rewrite `UNNEST(GENERATE_DATE_ARRAY(...))` table sources as recursive CTEs.

    For dialects without a native GENERATE_DATE_ARRAY, each qualifying unnest in a
    SELECT's FROM/JOIN is replaced by a subquery over a recursive CTE that starts
    at the start date and repeatedly applies DATE_ADD until the end date is reached.
    """
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes: list[exp.Expr] = []

        for unnest in expression.find_all(exp.Unnest):
            # Only rewrite unnests used as table sources whose single operand
            # is a GENERATE_DATE_ARRAY call
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start: exp.Expr | None = generate_date_array.args.get("start")
            end: exp.Expr | None = generate_date_array.args.get("end")
            step: exp.Expr | None = generate_date_array.args.get("step")

            # The rewrite needs both bounds and an INTERVAL step to build the DATE_ADD call
            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias: exp.TableAlias | None = unnest.args.get("alias")
            # Use the unnest's declared column alias if present, else a default name
            column_name: str = (
                alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
            )

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            # Suffix later CTE names so multiple rewrites in one query don't clash
            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            # Anchor member: the start date; recursive member: previous date + step,
            # bounded above by the end date
            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            # Swap the unnest for a subquery selecting from the recursive CTE
            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            # Prepend the generated CTEs to any existing WITH and mark it RECURSIVE
            with_expression: exp.With = expression.args.get("with_") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with_", with_expression)

    return expression
 129
 130
 131def unnest_generate_series(expression: exp.Expr) -> exp.Expr:
 132    """Unnests GENERATE_SERIES or SEQUENCE table references."""
 133    this = expression.this
 134    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
 135        unnest = exp.Unnest(expressions=[this])
 136        if expression.alias:
 137            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
 138
 139        return unnest
 140
 141    return expression
 142
 143
def eliminate_distinct_on(expression: exp.Expr) -> exp.Expr:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
    ):
        # Alias for the ROW_NUMBER() projection, guaranteed not to clash with existing names
        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")

        # Drop the DISTINCT clause; its ON columns become the window's PARTITION BY
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)

        # Reuse the query's ORDER BY for the window (popping it off the SELECT);
        # otherwise order by the DISTINCT ON columns themselves for determinism
        order: exp.Order | None = expression.args.get("order")
        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        expression.select(exp.alias_(window, row_number_window_alias), copy=False)

        # We add aliases to the projections so that we can safely reference them in the outer query
        new_selects: list[exp.Expr] = []
        taken_names = {row_number_window_alias}
        # [:-1] skips the row-number projection that was just appended above
        for select in expression.selects[:-1]:
            if select.is_star:
                new_selects = [exp.Star()]
                break

            if not isinstance(select, exp.Alias):
                alias = find_new_name(taken_names, select.output_name or "_col")
                # Preserve identifier quoting when aliasing a plain column reference
                quoted: bool | None = (
                    select.this.args.get("quoted") if isinstance(select, exp.Column) else None
                )
                select = select.replace(exp.alias_(select, alias, quoted=quoted))

            taken_names.add(select.output_name)
            new_selects.append(select.args["alias"])

        # Outer query keeps only the first row of each partition
        return (
            exp.select(*new_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number_window_alias).eq(1), copy=False)
        )

    return expression
 199
 200
def eliminate_qualify(expression: exp.Expr) -> exp.Expr:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        # Make sure every projection has a name the outer query can reference
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expr) -> str | exp.Column:
            # Build the outer projection, carrying over identifier quoting when known
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*map(_select_alias_or_name, expression.selects))
        # Detach the QUALIFY predicate; it becomes the outer query's WHERE filter
        qualify_filters: exp.Expr = expression.args["qualify"].pop().this
        expression_by_alias: dict[str, exp.Expr] = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        # With a star projection every column is exposed, so only windows need hoisting
        select_candidates = (exp.Window,) if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in list(qualify_filters.find_all(*select_candidates)):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    # Inline aliased projections referenced inside the window, since the
                    # aliases don't exist at the subquery's evaluation time
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                # Project the window under a fresh alias and reference it by name outside
                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    # The window *is* the whole filter, so the filter becomes the column
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                # A column used only in QUALIFY must still be projected to be filterable
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
 263
 264
 265def remove_precision_parameterized_types(expression: exp.Expr) -> exp.Expr:
 266    """
 267    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
 268    other expressions. This transforms removes the precision from parameterized types in expressions.
 269    """
 270    for node in expression.find_all(exp.DataType):
 271        node.set(
 272            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
 273        )
 274
 275    return expression
 276
 277
 278def unqualify_unnest(expression: exp.Expr) -> exp.Expr:
 279    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
 280    from sqlglot.optimizer.scope import find_all_in_scope
 281
 282    if isinstance(expression, exp.Select):
 283        unnest_aliases = {
 284            unnest.alias
 285            for unnest in find_all_in_scope(expression, exp.Unnest)
 286            if isinstance(unnest.parent, (exp.From, exp.Join))
 287        }
 288        if unnest_aliases:
 289            for column in expression.find_all(exp.Column):
 290                leftmost_part = column.parts[0]
 291                if leftmost_part.arg_key != "this" and leftmost_part.this in unnest_aliases:
 292                    leftmost_part.pop()
 293
 294    return expression
 295
 296
def unnest_to_explode(
    expression: exp.Expr,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expr:
    """Convert cross join unnest into lateral view explode.

    Args:
        expression: the expression to transform.
        unnest_using_arrays_zip: whether multiple unnest operands may be collapsed
            into a single ARRAYS_ZIP(...) call; if False, multi-array unnests raise.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: list[exp.Expr], has_multi_expr: bool
    ) -> list[exp.Expr]:
        # LATERAL VIEW EXPLODE takes a single input, so multiple operands must be zipped
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: list[exp.Expr] = [exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> type[exp.Func]:
        # POSEXPLODE when an offset column was requested, INLINE for zipped
        # multi-array input, plain EXPLODE otherwise
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from_")

        # Case 1: UNNEST is the FROM source itself -> rewrite it as a table-valued UDTF
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest: exp.Unnest = from_.this
            alias: exp.TableAlias | None = unnest.args.get("alias")
            exprs: list[exp.Expr] = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *_ = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            columns: list[exp.Identifier] = alias.columns if alias else []
            offset: exp.Expr | None = unnest.args.get("offset")
            if offset:
                # POSEXPLODE emits the position first, so its column alias leads
                columns.insert(
                    0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")
                )

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(this=this),
                    alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None,
                )
            )

        # Case 2: UNNEST appears as a join (possibly wrapped in LATERAL) -> replace the
        # join with a LATERAL VIEW EXPLODE entry on the select
        joins: list[exp.Join] = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # The alias lives on the LATERAL wrapper when present, else on the unnest
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")

                if alias is None:
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires an alias"
                    )

                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols: list[exp.Identifier] = alias.columns

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                offset = unnest.args.get("offset")
                if offset:
                    # Prepend the position column alias for POSEXPLODE output
                    alias_cols.insert(
                        0,
                        offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos"),
                    )

                # One LATERAL VIEW per zipped expression (zip yields a single expression)
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=alias_cols),
                        ),
                    )

    return expression
 400
 401
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expr], exp.Expr]:
    """Convert explode/posexplode projections into unnests.

    Args:
        index_offset: the dialect's array index base (0 or 1); used to align the
            generated position series and array-size bounds.

    Returns:
        A transform that rewrites EXPLODE/POSEXPLODE select projections into
        cross-joined UNNESTs driven by a generated position series.
    """

    def _explode_projection_to_unnest(expression: exp.Expr) -> exp.Expr:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: set[str], name: str) -> str:
                # Reserve and return a fresh name that doesn't clash with `names`
                name = find_new_name(names, name)
                names.add(name)
                return name

            # One shared position series drives all exploded projections; its `end`
            # bound is filled in at the bottom once all array sizes are known
            arrays: list[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias: t.Any = ""
                    explode_alias: t.Any = ""

                    # Normalize the projection into an exp.Alias so we can rename it later
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias: exp.Expr = select
                    elif isinstance(select, exp.Aliases):
                        # POSEXPLODE with two aliases: (position, value)
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        # `select` was replaced, so re-find the explode in the new node
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        # EXPLODE_OUTER must yield one row even for empty/NULL arrays:
                        # substitute a single-element array in that case
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Project the exploded value only on the row where the series
                    # position matches the unnest's offset (NULL otherwise)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        # Emit the matching position projection right after the value one
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # Attach the series source once, on the first exploded projection
                    if not arrays:
                        if expression.args.get("from_"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        # 0-based arrays: last valid position is size - 1
                        size = size - 1

                    # Keep rows where positions line up, or pin shorter arrays to their
                    # last element once the series has run past their size
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # Bound the series by the largest exploded array
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest
 553
 554
 555def add_within_group_for_percentiles(expression: exp.Expr) -> exp.Expr:
 556    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
 557    if (
 558        isinstance(expression, exp.PERCENTILES)
 559        and not isinstance(expression.parent, exp.WithinGroup)
 560        and expression.expression
 561    ):
 562        column = expression.this.pop()
 563        expression.set("this", expression.expression.pop())
 564        order = exp.Order(expressions=[exp.Ordered(this=column)])
 565        expression = exp.WithinGroup(this=expression, expression=order)
 566
 567    return expression
 568
 569
 570def remove_within_group_for_percentiles(expression: exp.Expr) -> exp.Expr:
 571    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
 572    if (
 573        isinstance(expression, exp.WithinGroup)
 574        and isinstance(expression.this, exp.PERCENTILES)
 575        and isinstance(expression.expression, exp.Order)
 576    ):
 577        quantile = expression.this.this
 578        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
 579        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))
 580
 581    return expression
 582
 583
 584def add_recursive_cte_column_names(expression: exp.Expr) -> exp.Expr:
 585    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
 586    if isinstance(expression, exp.With) and expression.recursive:
 587        next_name = name_sequence("_c_")
 588
 589        for cte in expression.expressions:
 590            if not cte.args["alias"].columns:
 591                query = cte.this
 592                if isinstance(query, exp.SetOperation):
 593                    query = query.this
 594
 595                cte.args["alias"].set(
 596                    "columns",
 597                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
 598                )
 599
 600    return expression
 601
 602
 603def epoch_cast_to_ts(expression: exp.Expr) -> exp.Expr:
 604    """Replace 'epoch' in casts by the equivalent date literal."""
 605    if (
 606        isinstance(expression, (exp.Cast, exp.TryCast))
 607        and expression.name.lower() == "epoch"
 608        and expression.to.this in exp.DataType.TEMPORAL_TYPES
 609    ):
 610        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))
 611
 612    return expression
 613
 614
 615def eliminate_semi_and_anti_joins(expression: exp.Expr) -> exp.Expr:
 616    """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead."""
 617    if isinstance(expression, exp.Select):
 618        for join in list[exp.Join](expression.args.get("joins") or []):
 619            on: exp.Expr | None = join.args.get("on")
 620            if on and join.kind in ("SEMI", "ANTI"):
 621                subquery = exp.select("1").from_(join.this).where(on)
 622                exists: exp.Exists | exp.Not = exp.Exists(this=subquery)
 623                if join.kind == "ANTI":
 624                    exists = exists.not_(copy=False)
 625
 626                join.pop()
 627                expression.where(exists, copy=False)
 628
 629    return expression
 630
 631
def eliminate_full_outer_join(expression: exp.Expr) -> exp.Expr:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins: list[tuple[int, exp.Join]] = [
            (index, join)
            for index, join in enumerate[exp.Join](expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # Copy before mutating: the original becomes the LEFT arm of the union,
            # the copy becomes the RIGHT arm
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # ON condition, or build one from USING(col, ...) equalities
            tables = (expression.args["from_"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in t.cast(list[exp.Identifier], full_outer_join.args.get("using"))
                ]
            )

            full_outer_join.set("side", "left")
            # The RIGHT arm excludes rows already matched by the LEFT arm to avoid
            # duplicating matched rows in the UNION ALL
            anti_join_clause = (
                exp.select("1").from_(expression.args["from_"]).where(join_conditions)
            )
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.set("with_", None)  # remove CTEs from RIGHT side
            expression.set("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
 670
 671
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with: exp.With | None = expression.args.get("with_")
    for inner_with in expression.find_all(exp.With):
        # Skip the WITH that is already attached at the top level
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote this nested one wholesale
            top_level_with = inner_with.pop()
            expression.set("with_", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                # Insert the moved CTEs *before* the CTE that referenced them,
                # so definitions still precede their uses
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
 709
 710
 711def ensure_bools(expression: exp.Expr) -> exp.Expr:
 712    """Converts numeric values used in conditions into explicit boolean expressions."""
 713    from sqlglot.optimizer.canonicalize import ensure_bools
 714
 715    def _ensure_bool(node: exp.Expr) -> None:
 716        if (
 717            node.is_number
 718            or (
 719                not isinstance(node, exp.SubqueryPredicate)
 720                and node.is_type(exp.DType.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
 721            )
 722            or (isinstance(node, exp.Column) and not node.type)
 723        ):
 724            node.replace(node.neq(0))
 725
 726    for node in expression.walk():
 727        ensure_bools(node, _ensure_bool)
 728
 729    return expression
 730
 731
 732def unqualify_columns(expression: exp.Expr) -> exp.Expr:
 733    for column in expression.find_all(exp.Column):
 734        # We only wanna pop off the table, db, catalog args
 735        for part in column.parts[:-1]:
 736            part.pop()
 737
 738    return expression
 739
 740
 741def remove_unique_constraints(expression: exp.Expr) -> exp.Expr:
 742    assert isinstance(expression, exp.Create)
 743    for constraint in expression.find_all(exp.UniqueColumnConstraint):
 744        if constraint.parent:
 745            constraint.parent.pop()
 746
 747    return expression
 748
 749
 750def ctas_with_tmp_tables_to_create_tmp_view(
 751    expression: exp.Expr,
 752    tmp_storage_provider: t.Callable[[exp.Expr], exp.Expr] = lambda e: e,
 753) -> exp.Expr:
 754    assert isinstance(expression, exp.Create)
 755    properties: exp.Properties | None = expression.args.get("properties")
 756    temporary = any(
 757        isinstance(prop, exp.TemporaryProperty)
 758        for prop in (properties.expressions if properties is not None else [])
 759    )
 760
 761    # CTAS with temp tables map to CREATE TEMPORARY VIEW
 762    if expression.kind == "TABLE" and temporary:
 763        if expression.expression:
 764            return exp.Create(
 765                kind="TEMPORARY VIEW",
 766                this=expression.this,
 767                expression=expression.expression,
 768            )
 769        return tmp_storage_provider(expression)
 770
 771    return expression
 772
 773
def move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    schema = expression.this
    # Only TABLE and VIEW kinds can carry a PARTITIONED BY clause here
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if isinstance(schema, exp.Schema) and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        # Only handle the "array of column names" form; if prop.this is already
        # a Schema, the columns have nothing to be moved from
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            # Partition columns are matched by name, case-insensitively
            columns: set[str] = {v.name.upper() for v in prop.this.expressions}
            schema_exprs: list[exp.Expr] = schema.expressions
            partitions = [col for col in schema_exprs if col.name.upper() in columns]
            # Remove the partition columns from the table schema...
            schema.set("expressions", [e for e in schema_exprs if e not in partitions])
            # ...and re-home their full definitions under PARTITIONED BY
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression
 795
 796
def move_partitioned_by_to_schema_columns(expression: exp.Expr) -> exp.Expr:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    # Only act on the Hive-style form, i.e. PARTITIONED BY (col TYPE, ...)
    # where every entry is a typed column definition
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        # The DATASOURCE form only lists column *names* in PARTITIONED BY
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema: exp.Schema = expression.this
        # Move the full column definitions into the table schema; the name-only
        # tuple built above replaces them under PARTITIONED BY afterwards
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression
 820
 821
 822def struct_kv_to_alias(expression: exp.Expr) -> exp.Expr:
 823    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
 824    if isinstance(expression, exp.Struct):
 825        expression.set(
 826            "expressions",
 827            [
 828                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
 829                for e in expression.expressions
 830            ],
 831        )
 832
 833    return expression
 834
 835
def eliminate_join_marks(expression: exp.Expr) -> exp.Expr:
    """https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

    -- example with WHERE
    SELECT d.department_name, sum(e.salary) as total_salary
    FROM departments d, employees e
    WHERE e.department_id(+) = d.department_id
    group by department_name

    -- example of left correlation in select
    SELECT d.department_name, (
        SELECT SUM(e.salary)
            FROM employees e
            WHERE e.department_id(+) = d.department_id) AS total_salary
    FROM departments d;

    -- example of left correlation in from
    SELECT d.department_name, t.total_salary
    FROM departments d, (
            SELECT SUM(e.salary) AS total_salary
            FROM employees e
            WHERE e.department_id(+) = d.department_id
        ) t
    """

    from sqlglot.optimizer.scope import traverse_scope
    from sqlglot.optimizer.normalize import normalize, normalized
    from collections import defaultdict

    # we go in reverse to check the main query for left correlation
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        where: exp.Expr | None = query.args.get("where")
        joins: list[exp.Join] = query.args.get("joins", [])

        # Nothing to do unless some WHERE column carries the (+) join mark
        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
            continue

        # knockout: we do not support left correlation (see point 2)
        assert not scope.is_correlated_subquery, "Correlated queries are not supported"

        # make sure we have AND of ORs to have clear join terms
        where = normalize(where.this)
        assert normalized(where), "Cannot normalize JOIN predicates"
        # dict of {name: list of join AND conditions}
        joins_ons: defaultdict[str, list[exp.Expr]] = defaultdict(list)
        # Each top-level AND term is examined independently: a term either
        # belongs entirely to one marked table's join condition, or stays put
        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]

            left_join_table = set(col.table for col in join_cols)
            if not left_join_table:
                continue

            assert not (len(left_join_table) > 1), (
                "Cannot combine JOIN predicates from different tables"
            )

            # Clear the marks; the LEFT JOIN built below encodes their meaning
            for col in join_cols:
                col.set("join_mark", False)

            joins_ons[left_join_table.pop()].append(cond)

        old_joins = {join.alias_or_name: join for join in joins}
        new_joins: dict[str, exp.Join] = {}
        query_from = query.args["from_"]

        for table, predicates in joins_ons.items():
            # The marked table becomes the right side of a LEFT JOIN; fall back
            # to the FROM source when it wasn't already joined explicitly
            join_what = old_joins.get(table, query_from).this.copy()
            new_joins[join_what.alias_or_name] = exp.Join(
                this=join_what, on=exp.and_(*predicates), kind="LEFT"
            )

            # Detach each predicate from the WHERE tree, repairing the tree as
            # we go so the remaining conditions stay well-formed
            for p in predicates:
                while isinstance(p.parent, exp.Paren):
                    p.parent.replace(p)

                parent = p.parent
                p.pop()
                if isinstance(parent, exp.Binary):
                    # Collapse the now one-armed binary node onto its survivor
                    left = parent.args.get("this")
                    parent.replace(parent.right if left is None else left)
                elif isinstance(parent, exp.Where):
                    # WHERE became empty -- drop the clause entirely
                    parent.pop()

        # If the FROM source itself became a LEFT JOIN target, promote one of
        # the remaining (unconverted) joins to be the new FROM source
        if query_from.alias_or_name in new_joins:
            only_old_joins: set[str] = old_joins.keys() - new_joins.keys()
            assert len(only_old_joins) >= 1, (
                "Cannot determine which table to use in the new FROM clause"
            )

            new_from_name = list[str](only_old_joins)[0]
            query.set("from_", exp.From(this=old_joins[new_from_name].this))

        if new_joins:
            for n, j in old_joins.items():  # preserve any other joins
                if n not in new_joins and n != query.args["from_"].name:
                    if not j.kind:
                        j.set("kind", "CROSS")
                    new_joins[n] = j
            query.set("joins", list(new_joins.values()))

    return expression
 955
 956
def any_to_exists(expression: exp.Expr) -> exp.Expr:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if isinstance(expression, exp.Select):
        for any_expr in expression.find_all(exp.Any):
            this: exp.Expr = any_expr.this
            # Skip the subquery form of ANY, as well as LIKE ANY / ILIKE ANY
            if isinstance(this, exp.Query) or isinstance(any_expr.parent, (exp.Like, exp.ILike)):
                continue

            binop = any_expr.parent
            if isinstance(binop, exp.Binary):
                # Swap ANY(...) for the lambda variable *before* copying, so the
                # copied comparison below already references `x`
                lambda_arg = exp.to_identifier("x")
                any_expr.replace(lambda_arg)
                lambda_expr = exp.Lambda(this=binop.copy(), expressions=[lambda_arg])
                binop.replace(exp.Exists(this=this.unnest(), expression=lambda_expr))

    return expression
 982
 983
def eliminate_window_clause(expression: exp.Expr) -> exp.Expr:
    """Eliminates the `WINDOW` query clause by inlining each named window."""
    windows: list[exp.Expr] | None = expression.args.get("windows")
    if isinstance(expression, exp.Select) and windows is not None:
        from sqlglot.optimizer.scope import find_all_in_scope

        # Drop the WINDOW clause itself; its definitions are inlined below
        expression.set("windows", None)

        # Maps a named window's (lowercased) name to its definition
        window_expression: dict[str, exp.Expr] = {}

        def _inline_inherited_window(window: exp.Expr) -> None:
            # Copies partition/order/frame args from the named window that
            # `window` references via its alias, if there is one
            inherited_window = window_expression.get(window.alias.lower())
            if not inherited_window:
                return

            window.set("alias", None)
            for key in ("partition_by", "order", "spec"):
                arg: exp.Expr | None = inherited_window.args.get(key)
                if arg is not None:
                    window.set(key, arg.copy())

        # Named windows may themselves reference previously named windows, so
        # resolve them in declaration order before registering each one
        for window in windows:
            _inline_inherited_window(window)
            window_expression[window.name.lower()] = window

        # Finally, inline into the window functions used in this scope
        for window in find_all_in_scope(expression, exp.Window):
            _inline_inherited_window(window)

    return expression
1013
1014
def inherit_struct_field_names(expression: exp.Expr) -> exp.Expr:
    """
    Inherit field names from the first struct in an array.

    BigQuery supports implicitly inheriting names from the first STRUCT in an array:

    Example:
        ARRAY[
          STRUCT('Alice' AS name, 85 AS score),  -- defines names
          STRUCT('Bob', 92),                     -- inherits names
          STRUCT('Diana', 95)                    -- inherits names
        ]

    This transformation makes the field names explicit on all structs by adding
    PropertyEQ nodes, in order to facilitate transpilation to other dialects.

    Args:
        expression: The expression tree to transform

    Returns:
        The modified expression with field names inherited in all structs
    """
    if not isinstance(expression, exp.Array) or not expression.args.get(
        "struct_name_inheritance"
    ):
        return expression

    first_struct = seq_get(expression.expressions, 0)
    if not isinstance(first_struct, exp.Struct) or not all(
        isinstance(fld, exp.PropertyEQ) for fld in first_struct.expressions
    ):
        # Names can only be inherited from a fully-named leading struct
        return expression

    field_names: list[exp.Identifier] = [fld.this for fld in first_struct.expressions]

    # Apply the leading struct's names to subsequent structs that lack them
    for struct in expression.expressions[1:]:
        if not isinstance(struct, exp.Struct) or len(struct.expressions) != len(field_names):
            continue

        named_fields: list[exp.PropertyEQ] = []
        for name, value in zip(field_names, struct.expressions):
            if isinstance(value, exp.PropertyEQ):
                named_fields.append(value)
            else:
                # Build `field_name := value`, preserving the value's type
                named = exp.PropertyEQ(this=name.copy(), expression=value)
                named.type = value.type
                named_fields.append(named)

        struct.set("expressions", named_fields)

    return expression
class SqlHandler(typing.Protocol):
16class SqlHandler(t.Protocol):
17    def __call__(self, expression: exp.Expr, *args: t.Any, **kwargs: t.Any) -> str: ...

Base class for protocol classes.

Protocol classes are defined as::

class Proto(Protocol):
    def meth(self) -> int:
        ...

Such classes are primarily used with static type checkers that recognize structural subtyping (static duck-typing), for example::

class C:
    def meth(self) -> int:
        return 0

def func(x: Proto) -> int:
    return x.meth()

func(C())  # Passes static type check

See PEP 544 for details. Protocol classes decorated with @typing.runtime_checkable act as simple-minded runtime protocols that check only the presence of given attributes, ignoring their type signatures. Protocol classes can be generic, they are defined as::

class GenProto(Protocol[T]):
    def meth(self) -> T:
        ...
SqlHandler(*args, **kwargs)
1431def _no_init_or_replace_init(self, *args, **kwargs):
1432    cls = type(self)
1433
1434    if cls._is_protocol:
1435        raise TypeError('Protocols cannot be instantiated')
1436
1437    # Already using a custom `__init__`. No need to calculate correct
1438    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
1439    if cls.__init__ is not _no_init_or_replace_init:
1440        return
1441
1442    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
1443    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
1444    # searches for a proper new `__init__` in the MRO. The new `__init__`
1445    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
1446    # instantiation of the protocol subclass will thus use the new
1447    # `__init__` and no longer call `_no_init_or_replace_init`.
1448    for base in cls.__mro__:
1449        init = base.__dict__.get('__init__', _no_init_or_replace_init)
1450        if init is not _no_init_or_replace_init:
1451            cls.__init__ = init
1452            break
1453    else:
1454        # should not happen
1455        cls.__init__ = object.__init__
1456
1457    cls.__init__(self, *args, **kwargs)
def preprocess( transforms: list[typing.Callable[[sqlglot.expressions.core.Expr], sqlglot.expressions.core.Expr]], generator: Optional[Callable[[sqlglot.generator.Generator, sqlglot.expressions.core.Expr], str]] = None) -> Callable[[sqlglot.generator.Generator, sqlglot.expressions.core.Expr], str]:
20def preprocess(
21    transforms: list[t.Callable[[exp.Expr], exp.Expr]],
22    generator: t.Callable[[Generator, exp.Expr], str] | None = None,
23) -> t.Callable[[Generator, exp.Expr], str]:
24    """
25    Creates a new transform by chaining a sequence of transformations and converts the resulting
26    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
27    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).
28
29    Args:
30        transforms: sequence of transform functions. These will be called in order.
31
32    Returns:
33        Function that can be used as a generator transform.
34    """
35
36    def _to_sql(self: Generator, expression: exp.Expr) -> str:
37        expression_type = type(expression)
38
39        try:
40            expression = transforms[0](expression)
41            for transform in transforms[1:]:
42                expression = transform(expression)
43        except UnsupportedError as unsupported_error:
44            self.unsupported(str(unsupported_error))
45
46        if generator:
47            return generator(self, expression)
48
49        _sql_handler: SqlHandler | None = getattr(self, expression.key + "_sql", None)
50        if _sql_handler:
51            return _sql_handler(expression)
52
53        transforms_handler = self.TRANSFORMS.get(type(expression))
54        if transforms_handler:
55            if expression_type is type(expression):
56                if isinstance(expression, exp.Func):
57                    return self.function_fallback_sql(expression)
58
59                # Ensures we don't enter an infinite loop. This can happen when the original expression
60                # has the same type as the final expression and there's no _sql method available for it,
61                # because then it'd re-enter _to_sql.
62                raise ValueError(
63                    f"Expr type {expression.__class__.__name__} requires a _sql method in order to be transformed."
64                )
65
66            return transforms_handler(self, expression)
67
68        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")
69
70    return _to_sql

Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).

Arguments:
  • transforms: sequence of transform functions. These will be called in order.
Returns:

Function that can be used as a generator transform.

def unnest_generate_date_array_using_recursive_cte( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
 73def unnest_generate_date_array_using_recursive_cte(expression: exp.Expr) -> exp.Expr:
 74    if isinstance(expression, exp.Select):
 75        count = 0
 76        recursive_ctes: list[exp.Expr] = []
 77
 78        for unnest in expression.find_all(exp.Unnest):
 79            if (
 80                not isinstance(unnest.parent, (exp.From, exp.Join))
 81                or len(unnest.expressions) != 1
 82                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
 83            ):
 84                continue
 85
 86            generate_date_array = unnest.expressions[0]
 87            start: exp.Expr | None = generate_date_array.args.get("start")
 88            end: exp.Expr | None = generate_date_array.args.get("end")
 89            step: exp.Expr | None = generate_date_array.args.get("step")
 90
 91            if not start or not end or not isinstance(step, exp.Interval):
 92                continue
 93
 94            alias: exp.TableAlias | None = unnest.args.get("alias")
 95            column_name: str = (
 96                alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"
 97            )
 98
 99            start = exp.cast(start, "date")
100            date_add = exp.func(
101                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
102            )
103            cast_date_add = exp.cast(date_add, "date")
104
105            cte_name = "_generated_dates" + (f"_{count}" if count else "")
106
107            base_query = exp.select(start.as_(column_name))
108            recursive_query = (
109                exp.select(cast_date_add)
110                .from_(cte_name)
111                .where(cast_date_add <= exp.cast(end, "date"))
112            )
113            cte_query = base_query.union(recursive_query, distinct=False)
114
115            generate_dates_query = exp.select(column_name).from_(cte_name)
116            unnest.replace(generate_dates_query.subquery(cte_name))
117
118            recursive_ctes.append(
119                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
120            )
121            count += 1
122
123        if recursive_ctes:
124            with_expression: exp.With = expression.args.get("with_") or exp.With()
125            with_expression.set("recursive", True)
126            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
127            expression.set("with_", with_expression)
128
129    return expression
def unnest_generate_series( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
132def unnest_generate_series(expression: exp.Expr) -> exp.Expr:
133    """Unnests GENERATE_SERIES or SEQUENCE table references."""
134    this = expression.this
135    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
136        unnest = exp.Unnest(expressions=[this])
137        if expression.alias:
138            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)
139
140        return unnest
141
142    return expression

Unnests GENERATE_SERIES or SEQUENCE table references.

def eliminate_distinct_on( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
145def eliminate_distinct_on(expression: exp.Expr) -> exp.Expr:
146    """
147    Convert SELECT DISTINCT ON statements to a subquery with a window function.
148
149    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
150
151    Args:
152        expression: the expression that will be transformed.
153
154    Returns:
155        The transformed expression.
156    """
157    if (
158        isinstance(expression, exp.Select)
159        and expression.args.get("distinct")
160        and isinstance(expression.args["distinct"].args.get("on"), exp.Tuple)
161    ):
162        row_number_window_alias = find_new_name(expression.named_selects, "_row_number")
163
164        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
165        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
166
167        order: exp.Order | None = expression.args.get("order")
168        if order:
169            window.set("order", order.pop())
170        else:
171            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))
172
173        expression.select(exp.alias_(window, row_number_window_alias), copy=False)
174
175        # We add aliases to the projections so that we can safely reference them in the outer query
176        new_selects: list[exp.Expr] = []
177        taken_names = {row_number_window_alias}
178        for select in expression.selects[:-1]:
179            if select.is_star:
180                new_selects = [exp.Star()]
181                break
182
183            if not isinstance(select, exp.Alias):
184                alias = find_new_name(taken_names, select.output_name or "_col")
185                quoted: bool | None = (
186                    select.this.args.get("quoted") if isinstance(select, exp.Column) else None
187                )
188                select = select.replace(exp.alias_(select, alias, quoted=quoted))
189
190            taken_names.add(select.output_name)
191            new_selects.append(select.args["alias"])
192
193        return (
194            exp.select(*new_selects, copy=False)
195            .from_(expression.subquery("_t", copy=False), copy=False)
196            .where(exp.column(row_number_window_alias).eq(1), copy=False)
197        )
198
199    return expression

Convert SELECT DISTINCT ON statements to a subquery with a window function.

This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

Arguments:
  • expression: the expression that will be transformed.
Returns:

The transformed expression.

def eliminate_qualify( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
202def eliminate_qualify(expression: exp.Expr) -> exp.Expr:
203    """
204    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
205
206    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
207    https://docs.snowflake.com/en/sql-reference/constructs/qualify
208
209    Some dialects don't support window functions in the WHERE clause, so we need to include them as
210    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
211    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
212    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
213    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
214    corresponding expression to avoid creating invalid column references.
215    """
216    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
217        taken = set(expression.named_selects)
218        for select in expression.selects:
219            if not select.alias_or_name:
220                alias = find_new_name(taken, "_c")
221                select.replace(exp.alias_(select, alias))
222                taken.add(alias)
223
224        def _select_alias_or_name(select: exp.Expr) -> str | exp.Column:
225            alias_or_name = select.alias_or_name
226            identifier = select.args.get("alias") or select.this
227            if isinstance(identifier, exp.Identifier):
228                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
229            return alias_or_name
230
231        outer_selects = exp.select(*map(_select_alias_or_name, expression.selects))
232        qualify_filters: exp.Expr = expression.args["qualify"].pop().this
233        expression_by_alias: dict[str, exp.Expr] = {
234            select.alias: select.this
235            for select in expression.selects
236            if isinstance(select, exp.Alias)
237        }
238
239        select_candidates = (exp.Window,) if expression.is_star else (exp.Window, exp.Column)
240        for select_candidate in list(qualify_filters.find_all(*select_candidates)):
241            if isinstance(select_candidate, exp.Window):
242                if expression_by_alias:
243                    for column in select_candidate.find_all(exp.Column):
244                        expr = expression_by_alias.get(column.name)
245                        if expr:
246                            column.replace(expr)
247
248                alias = find_new_name(expression.named_selects, "_w")
249                expression.select(exp.alias_(select_candidate, alias), copy=False)
250                column = exp.column(alias)
251
252                if isinstance(select_candidate.parent, exp.Qualify):
253                    qualify_filters = column
254                else:
255                    select_candidate.replace(column)
256            elif select_candidate.name not in expression.named_selects:
257                expression.select(select_candidate.copy(), copy=False)
258
259        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
260            qualify_filters, copy=False
261        )
262
263    return expression

Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify

Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.

def remove_precision_parameterized_types( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
266def remove_precision_parameterized_types(expression: exp.Expr) -> exp.Expr:
267    """
268    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
269    other expressions. This transforms removes the precision from parameterized types in expressions.
270    """
271    for node in expression.find_all(exp.DataType):
272        node.set(
273            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
274        )
275
276    return expression

Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.

def unqualify_unnest( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def unqualify_unnest(expression: exp.Expr) -> exp.Expr:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if not isinstance(expression, exp.Select):
        return expression

    aliases = set()
    for unnest in find_all_in_scope(expression, exp.Unnest):
        if isinstance(unnest.parent, (exp.From, exp.Join)):
            aliases.add(unnest.alias)

    if aliases:
        for column in expression.find_all(exp.Column):
            head = column.parts[0]
            # Only pop qualifier parts (arg_key != "this"); never the column name itself
            if head.arg_key != "this" and head.this in aliases:
                head.pop()

    return expression

Remove references to unnest table aliases, added by the optimizer's qualify_columns step.

def unnest_to_explode( expression: sqlglot.expressions.core.Expr, unnest_using_arrays_zip: bool = True) -> sqlglot.expressions.core.Expr:
def unnest_to_explode(
    expression: exp.Expr,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expr:
    """
    Convert cross join unnest into lateral view explode.

    Args:
        expression: the expression to transform; only `exp.Select` nodes are rewritten.
        unnest_using_arrays_zip: whether an UNNEST over multiple arrays may be collapsed
            into INLINE(ARRAYS_ZIP(...)). When False, such UNNESTs raise `UnsupportedError`.
    """

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: list[exp.Expr], has_multi_expr: bool
    ) -> list[exp.Expr]:
        # Collapse multiple UNNEST arrays into one ARRAYS_ZIP(...) argument,
        # mutating `u` in place; a single-array UNNEST is returned untouched.
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: list[exp.Expr] = [exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> type[exp.Func]:
        # Pick the table-generating function: POSEXPLODE when an offset (position)
        # is requested, INLINE for zipped multi-arrays, plain EXPLODE otherwise.
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from_")

        # Case 1: UNNEST appears directly in FROM -> rewrite it as a table-valued
        # function call (EXPLODE/POSEXPLODE/INLINE).
        if from_ and isinstance(from_.this, exp.Unnest):
            unnest: exp.Unnest = from_.this
            alias: exp.TableAlias | None = unnest.args.get("alias")
            exprs: list[exp.Expr] = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *_ = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            columns: list[exp.Identifier] = alias.columns if alias else []
            offset: exp.Expr | None = unnest.args.get("offset")
            if offset:
                # Position column goes first; default its name to "pos"
                columns.insert(
                    0, offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos")
                )

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(this=this),
                    alias=exp.TableAlias(this=alias.this, columns=columns) if alias else None,
                )
            )

        # Case 2: UNNEST appears as a join (possibly wrapped in LATERAL) ->
        # replace the join with LATERAL VIEW [POS]EXPLODE/INLINE entries.
        joins: list[exp.Join] = expression.args.get("joins") or []
        for join in list(joins):
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                # The alias lives on the LATERAL wrapper when present, else on the UNNEST
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")

                if alias is None:
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires an alias"
                    )

                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                joins.remove(join)

                alias_cols: list[exp.Identifier] = alias.columns

                # Handle UNNEST to LATERAL VIEW EXPLODE: Exception is raised when there are 0 or > 2 aliases
                # Spark LATERAL VIEW EXPLODE requires single alias for array/struct and two for Map type column unlike unnest in trino/presto which can take an arbitrary amount.
                # Refs: https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-lateral-view.html

                if not has_multi_expr and len(alias_cols) not in (1, 2):
                    raise UnsupportedError(
                        "CROSS JOIN UNNEST to LATERAL VIEW EXPLODE transformation requires explicit column aliases"
                    )

                offset = unnest.args.get("offset")
                if offset:
                    # Prepend the position-column alias, defaulting to "pos"
                    alias_cols.insert(
                        0,
                        offset if isinstance(offset, exp.Identifier) else exp.to_identifier("pos"),
                    )

                # One LATERAL VIEW per (expression, alias column) pair
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(this=alias.this, columns=alias_cols),
                        ),
                    )

    return expression

Convert cross join unnest into lateral view explode.

def explode_projection_to_unnest( index_offset: int = 0) -> Callable[[sqlglot.expressions.core.Expr], sqlglot.expressions.core.Expr]:
def explode_projection_to_unnest(
    index_offset: int = 0,
) -> t.Callable[[exp.Expr], exp.Expr]:
    """
    Convert explode/posexplode projections into unnests.

    Args:
        index_offset: the dialect's array index base (0 or 1), used to align the
            generated position series with UNNEST's offsets.

    Returns:
        A transform that rewrites EXPLODE/POSEXPLODE select projections into a
        CROSS JOIN UNNEST plus a generated position series.
    """

    def _explode_projection_to_unnest(expression: exp.Expr) -> exp.Expr:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            # Track names already in use so generated aliases never collide
            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: set[str], name: str) -> str:
                # Reserve and return a fresh name not yet present in `names`
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: list[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            # Position series (GENERATE_SERIES) shared by all exploded projections;
            # its "end" is filled in at the bottom once all array sizes are known
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias: t.Any = ""
                    explode_alias: t.Any = ""

                    # Normalize the projection into a single aliased expression so
                    # its output name can be (re)assigned below
                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias: exp.Expr = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    # EXPLODE_OUTER must still produce a row for empty/NULL arrays,
                    # so substitute a one-element array in that case
                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    # Project the value only where the series position matches this
                    # unnest's offset (NULL otherwise)
                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    # POSEXPLODE also projects the position itself, right after the value
                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    # Attach the shared series source once, on the first explode found
                    if not arrays:
                        if expression.args.get("from_"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    # Keep rows where positions line up, or clamp to the last element
                    # for arrays shorter than the longest exploded array
                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                # The series must span the longest exploded array
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_projection_to_unnest

Convert explode/posexplode projections into unnests.

def add_within_group_for_percentiles( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def add_within_group_for_percentiles(expression: exp.Expr) -> exp.Expr:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    is_bare_percentile = (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    )
    if not is_bare_percentile:
        return expression

    # Swap the args: the quantile becomes the function input and the original
    # input column moves into the ORDER BY of the WITHIN GROUP clause.
    order_column = expression.this.pop()
    expression.set("this", expression.expression.pop())
    ordering = exp.Order(expressions=[exp.Ordered(this=order_column)])
    return exp.WithinGroup(this=expression, expression=ordering)

Transforms percentiles by adding a WITHIN GROUP clause to them.

def remove_within_group_for_percentiles( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def remove_within_group_for_percentiles(expression: exp.Expr) -> exp.Expr:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if not (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        return expression

    # The percentile's arg is the quantile; the ordered expression is the input value
    quantile = expression.this.this
    ordered = t.cast(exp.Ordered, expression.find(exp.Ordered))
    return expression.replace(exp.ApproxQuantile(this=ordered.this, quantile=quantile))

Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.

def add_recursive_cte_column_names( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def add_recursive_cte_column_names(expression: exp.Expr) -> exp.Expr:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if not (isinstance(expression, exp.With) and expression.recursive):
        return expression

    generate_name = name_sequence("_c_")

    for cte in expression.expressions:
        cte_alias = cte.args["alias"]
        if cte_alias.columns:
            continue

        query = cte.this
        if isinstance(query, exp.SetOperation):
            # Use the left (anchor) side's projections for the column names
            query = query.this

        columns = []
        for projection in query.selects:
            # Unnamed projections get a generated "_c_<n>" column name
            columns.append(exp.to_identifier(projection.alias_or_name or generate_name()))

        cte_alias.set("columns", columns)

    return expression

Uses projection output names in recursive CTE definitions to define the CTEs' columns.

def epoch_cast_to_ts( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def epoch_cast_to_ts(expression: exp.Expr) -> exp.Expr:
    """Replace 'epoch' in casts by the equivalent date literal."""
    is_epoch_cast = (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    )
    if is_epoch_cast:
        # 'epoch' is shorthand for the Unix epoch timestamp
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression

Replace 'epoch' in casts by the equivalent date literal.

def eliminate_semi_and_anti_joins( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def eliminate_semi_and_anti_joins(expression: exp.Expr) -> exp.Expr:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if not isinstance(expression, exp.Select):
        return expression

    # Iterate over a copy because joins are popped from the tree inside the loop
    for join in list(expression.args.get("joins") or []):
        condition: exp.Expr | None = join.args.get("on")
        if not condition or join.kind not in ("SEMI", "ANTI"):
            continue

        predicate: exp.Exists | exp.Not = exp.Exists(
            this=exp.select("1").from_(join.this).where(condition)
        )
        if join.kind == "ANTI":
            # ANTI join keeps the rows that do NOT match
            predicate = predicate.not_(copy=False)

        join.pop()
        expression.where(predicate, copy=False)

    return expression

Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.

def eliminate_full_outer_join( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def eliminate_full_outer_join(expression: exp.Expr) -> exp.Expr:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins: list[tuple[int, exp.Join]] = [
            (index, join)
            for index, join in enumerate[exp.Join](expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            # Copy (the RIGHT branch) is taken before the limit is dropped from the LEFT branch
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            # Join condition comes from ON, or is reconstructed from USING columns
            tables = (expression.args["from_"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in t.cast(list[exp.Identifier], full_outer_join.args.get("using"))
                ]
            )

            full_outer_join.set("side", "left")
            # Anti-join filter so the RIGHT branch only contributes rows the LEFT branch missed
            anti_join_clause = (
                exp.select("1").from_(expression.args["from_"]).where(join_conditions)
            )
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.set("with_", None)  # remove CTEs from RIGHT side
            expression.set("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression

Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.

def move_ctes_to_top_level(expression: ~E) -> ~E:
def move_ctes_to_top_level(expression: E) -> E:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with: exp.With | None = expression.args.get("with_")
    for inner_with in expression.find_all(exp.With):
        # Skip the WITH that is already attached directly to `expression`
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            # No top-level WITH yet: promote this nested one wholesale
            top_level_with = inner_with.pop()
            expression.set("with_", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            # If the nested WITH lives inside a top-level CTE, its definitions are
            # spliced in *before* that CTE; otherwise they are appended at the end
            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression

Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:

SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.

TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).

def ensure_bools( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def ensure_bools(expression: exp.Expr) -> exp.Expr:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _coerce_to_bool(node: exp.Expr) -> None:
        # A node needs coercion when it is a numeric literal, a numeric/unknown-typed
        # non-subquery expression, or an untyped column
        needs_coercion = (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DType.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        )
        if needs_coercion:
            # Rewrite `x` as `x <> 0`
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _coerce_to_bool)

    return expression

Converts numeric values used in conditions into explicit boolean expressions.

def unqualify_columns( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def unqualify_columns(expression: exp.Expr) -> exp.Expr:
    """Strip table/db/catalog qualifiers from every column reference."""
    for column in expression.find_all(exp.Column):
        # Everything except the last part is a qualifier (table, db, catalog)
        for qualifier in column.parts[:-1]:
            qualifier.pop()

    return expression
def remove_unique_constraints( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def remove_unique_constraints(expression: exp.Expr) -> exp.Expr:
    """Drop UNIQUE column constraints from a CREATE statement."""
    assert isinstance(expression, exp.Create)

    for unique in expression.find_all(exp.UniqueColumnConstraint):
        parent = unique.parent
        if parent:
            # Pop the wrapping constraint node, not just the UNIQUE marker
            parent.pop()

    return expression
def ctas_with_tmp_tables_to_create_tmp_view( expression: sqlglot.expressions.core.Expr, tmp_storage_provider: Callable[[sqlglot.expressions.core.Expr], sqlglot.expressions.core.Expr] = <function <lambda>>) -> sqlglot.expressions.core.Expr:
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expr,
    tmp_storage_provider: t.Callable[[exp.Expr], exp.Expr] = lambda e: e,
) -> exp.Expr:
    """Rewrite CREATE TEMPORARY TABLE ... AS <query> into CREATE TEMPORARY VIEW."""
    assert isinstance(expression, exp.Create)

    props: exp.Properties | None = expression.args.get("properties")
    prop_exprs = props.expressions if props is not None else []
    is_temporary = any(isinstance(p, exp.TemporaryProperty) for p in prop_exprs)

    if expression.kind != "TABLE" or not is_temporary:
        return expression

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.expression:
        return exp.Create(
            kind="TEMPORARY VIEW",
            this=expression.this,
            expression=expression.expression,
        )

    # No source query: delegate storage handling to the provided hook
    return tmp_storage_provider(expression)
def move_schema_columns_to_partitioned_by( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def move_schema_columns_to_partitioned_by(expression: exp.Expr) -> exp.Expr:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)

    schema = expression.this
    if not isinstance(schema, exp.Schema) or expression.kind not in {"TABLE", "VIEW"}:
        return expression

    prop = expression.find(exp.PartitionedByProperty)
    if not (prop and prop.this) or isinstance(prop.this, exp.Schema):
        return expression

    # Split the schema's columns into partition columns (by name) and the rest
    partition_names: set[str] = {v.name.upper() for v in prop.this.expressions}
    remaining: list[exp.Expr] = []
    partitions: list[exp.Expr] = []
    for col in schema.expressions:
        (partitions if col.name.upper() in partition_names else remaining).append(col)

    schema.set("expressions", remaining)
    prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
    expression.set("this", schema)

    return expression

In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.

def move_partitioned_by_to_schema_columns( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def move_partitioned_by_to_schema_columns(expression: exp.Expr) -> exp.Expr:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)

    prop = expression.find(exp.PartitionedByProperty)
    has_partition_column_defs = bool(
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    )
    if not has_partition_column_defs:
        return expression

    # Capture the identifiers first, then move the column defs back into the table
    # schema; the property keeps only the identifiers.
    identifiers = exp.Tuple(
        expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
    )
    schema: exp.Schema = expression.this
    for column_def in prop.this.expressions:
        schema.append("expressions", column_def)
    prop.set("this", identifiers)

    return expression

Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

Currently, SQLGlot uses the DATASOURCE format for Spark 3.

def struct_kv_to_alias( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def struct_kv_to_alias(expression: exp.Expr) -> exp.Expr:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        converted = []
        for arg in expression.expressions:
            if isinstance(arg, exp.PropertyEQ):
                # key := value becomes value AS key
                converted.append(exp.alias_(arg.expression, arg.this))
            else:
                converted.append(arg)

        expression.set("expressions", converted)

    return expression

Converts struct arguments to aliases, e.g. STRUCT(1 AS y).

def eliminate_join_marks( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def eliminate_join_marks(expression: exp.Expr) -> exp.Expr:
    """https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

    1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

    2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

    The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

    You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

    The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

    A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

    A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

    A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

    -- example with WHERE
    SELECT d.department_name, sum(e.salary) as total_salary
    FROM departments d, employees e
    WHERE e.department_id(+) = d.department_id
    group by department_name

    -- example of left correlation in select
    SELECT d.department_name, (
        SELECT SUM(e.salary)
            FROM employees e
            WHERE e.department_id(+) = d.department_id) AS total_salary
    FROM departments d;

    -- example of left correlation in from
    SELECT d.department_name, t.total_salary
    FROM departments d, (
            SELECT SUM(e.salary) AS total_salary
            FROM employees e
            WHERE e.department_id(+) = d.department_id
        ) t
    """

    from sqlglot.optimizer.scope import traverse_scope
    from sqlglot.optimizer.normalize import normalize, normalized
    from collections import defaultdict

    # we go in reverse to check the main query for left correlation
    for scope in reversed(traverse_scope(expression)):
        query = scope.expression

        where: exp.Expr | None = query.args.get("where")
        joins: list[exp.Join] = query.args.get("joins", [])

        # Skip scopes whose WHERE clause contains no (+)-marked columns
        if not where or not any(c.args.get("join_mark") for c in where.find_all(exp.Column)):
            continue

        # knockout: we do not support left correlation (see point 2)
        assert not scope.is_correlated_subquery, "Correlated queries are not supported"

        # make sure we have AND of ORs to have clear join terms
        where = normalize(where.this)
        assert normalized(where), "Cannot normalize JOIN predicates"
        # dict of {name: list of join AND conditions}
        joins_ons: defaultdict[str, list[exp.Expr]] = defaultdict(list)
        for cond in [where] if not isinstance(where, exp.And) else where.flatten():
            join_cols = [col for col in cond.find_all(exp.Column) if col.args.get("join_mark")]

            # Each conjunct's (+)-marked columns must all come from a single table
            left_join_table = set(col.table for col in join_cols)
            if not left_join_table:
                continue

            assert not (len(left_join_table) > 1), (
                "Cannot combine JOIN predicates from different tables"
            )

            # Drop the (+) marks; the condition becomes a regular join predicate
            for col in join_cols:
                col.set("join_mark", False)

            joins_ons[left_join_table.pop()].append(cond)

        old_joins = {join.alias_or_name: join for join in joins}
        new_joins: dict[str, exp.Join] = {}
        query_from = query.args["from_"]

        # Turn each group of predicates into an explicit LEFT JOIN and detach
        # the predicates from the WHERE clause
        for table, predicates in joins_ons.items():
            join_what = old_joins.get(table, query_from).this.copy()
            new_joins[join_what.alias_or_name] = exp.Join(
                this=join_what, on=exp.and_(*predicates), kind="LEFT"
            )

            for p in predicates:
                # Unwrap enclosing parentheses before removing the predicate
                while isinstance(p.parent, exp.Paren):
                    p.parent.replace(p)

                parent = p.parent
                p.pop()
                if isinstance(parent, exp.Binary):
                    # Collapse the dangling AND/OR onto its remaining operand
                    left = parent.args.get("this")
                    parent.replace(parent.right if left is None else left)
                elif isinstance(parent, exp.Where):
                    parent.pop()

        # If the FROM table itself became a join target, pick a remaining old
        # join to serve as the new FROM clause
        if query_from.alias_or_name in new_joins:
            only_old_joins: set[str] = old_joins.keys() - new_joins.keys()
            assert len(only_old_joins) >= 1, (
                "Cannot determine which table to use in the new FROM clause"
            )

            new_from_name = list[str](only_old_joins)[0]
            query.set("from_", exp.From(this=old_joins[new_from_name].this))

        if new_joins:
            for n, j in old_joins.items():  # preserve any other joins
                if n not in new_joins and n != query.args["from_"].name:
                    if not j.kind:
                        j.set("kind", "CROSS")
                    new_joins[n] = j
            query.set("joins", list(new_joins.values()))

    return expression

https://docs.oracle.com/cd/B19306_01/server.102/b14200/queries006.htm#sthref3178

  1. You cannot specify the (+) operator in a query block that also contains FROM clause join syntax.

  2. The (+) operator can appear only in the WHERE clause or, in the context of left-correlation (that is, when specifying the TABLE clause) in the FROM clause, and can be applied only to a column of a table or view.

The (+) operator does not produce an outer join if you specify one table in the outer query and the other table in an inner query.

You cannot use the (+) operator to outer-join a table to itself, although self joins are valid.

The (+) operator can be applied only to a column, not to an arbitrary expression. However, an arbitrary expression can contain one or more columns marked with the (+) operator.

A WHERE condition containing the (+) operator cannot be combined with another condition using the OR logical operator.

A WHERE condition cannot use the IN comparison condition to compare a column marked with the (+) operator with an expression.

A WHERE condition cannot compare any column marked with the (+) operator with a subquery.

-- example with WHERE SELECT d.department_name, sum(e.salary) as total_salary FROM departments d, employees e WHERE e.department_id(+) = d.department_id group by department_name

-- example of left correlation in select SELECT d.department_name, ( SELECT SUM(e.salary) FROM employees e WHERE e.department_id(+) = d.department_id) AS total_salary FROM departments d;

-- example of left correlation in from SELECT d.department_name, t.total_salary FROM departments d, ( SELECT SUM(e.salary) AS total_salary FROM employees e WHERE e.department_id(+) = d.department_id ) t

def any_to_exists( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def any_to_exists(expression: exp.Expr) -> exp.Expr:
    """
    Transform ANY operator to Spark's EXISTS

    For example,
        - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col)
        - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

    Both ANY and EXISTS accept queries but currently only array expressions are supported for this
    transformation
    """
    if not isinstance(expression, exp.Select):
        return expression

    for any_node in expression.find_all(exp.Any):
        operand: exp.Expr = any_node.this

        # Subquery operands and LIKE/ILIKE ANY are left untouched; only
        # array-valued operands are rewritten into an EXISTS lambda.
        if isinstance(operand, exp.Query) or isinstance(any_node.parent, (exp.Like, exp.ILike)):
            continue

        comparison = any_node.parent
        if not isinstance(comparison, exp.Binary):
            continue

        # Build `x -> <comparison with ANY(...) replaced by x>` and wrap the
        # (unnested) array operand in EXISTS.
        lambda_arg = exp.to_identifier("x")
        any_node.replace(lambda_arg)
        lambda_expr = exp.Lambda(this=comparison.copy(), expressions=[lambda_arg])
        comparison.replace(exp.Exists(this=operand.unnest(), expression=lambda_expr))

    return expression

Transform ANY operator to Spark's EXISTS

For example, - Postgres: SELECT * FROM tbl WHERE 5 > ANY(tbl.col) - Spark: SELECT * FROM tbl WHERE EXISTS(tbl.col, x -> x < 5)

Both ANY and EXISTS accept queries but currently only array expressions are supported for this transformation

def eliminate_window_clause( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
 985def eliminate_window_clause(expression: exp.Expr) -> exp.Expr:
 986    """Eliminates the `WINDOW` query clause by inling each named window."""
 987    windows: list[exp.Expr] | None = expression.args.get("windows")
 988    if isinstance(expression, exp.Select) and windows is not None:
 989        from sqlglot.optimizer.scope import find_all_in_scope
 990
 991        expression.set("windows", None)
 992
 993        window_expression: dict[str, exp.Expr] = {}
 994
 995        def _inline_inherited_window(window: exp.Expr) -> None:
 996            inherited_window = window_expression.get(window.alias.lower())
 997            if not inherited_window:
 998                return
 999
1000            window.set("alias", None)
1001            for key in ("partition_by", "order", "spec"):
1002                arg: exp.Expr | None = inherited_window.args.get(key)
1003                if arg is not None:
1004                    window.set(key, arg.copy())
1005
1006        for window in windows:
1007            _inline_inherited_window(window)
1008            window_expression[window.name.lower()] = window
1009
1010        for window in find_all_in_scope(expression, exp.Window):
1011            _inline_inherited_window(window)
1012
1013    return expression

Eliminates the WINDOW query clause by inlining each named window.

def inherit_struct_field_names( expression: sqlglot.expressions.core.Expr) -> sqlglot.expressions.core.Expr:
def inherit_struct_field_names(expression: exp.Expr) -> exp.Expr:
    """
    Inherit field names from the first struct in an array.

    BigQuery supports implicitly inheriting names from the first STRUCT in an array:

    Example:
        ARRAY[
          STRUCT('Alice' AS name, 85 AS score),  -- defines names
          STRUCT('Bob', 92),                     -- inherits names
          STRUCT('Diana', 95)                    -- inherits names
        ]

    This transformation makes the field names explicit on all structs by adding
    PropertyEQ nodes, in order to facilitate transpilation to other dialects.

    Args:
        expression: The expression tree to transform

    Returns:
        The modified expression with field names inherited in all structs
    """
    if not (
        isinstance(expression, exp.Array) and expression.args.get("struct_name_inheritance")
    ):
        return expression

    first_item = seq_get(expression.expressions, 0)
    if not isinstance(first_item, exp.Struct) or not all(
        isinstance(fld, exp.PropertyEQ) for fld in first_item.expressions
    ):
        return expression

    field_names: list[exp.Identifier] = [fld.this for fld in first_item.expressions]

    # Propagate the names onto each subsequent struct that lacks them
    for struct in expression.expressions[1:]:
        # Only structs whose arity matches the name-defining struct are rewritten
        if not isinstance(struct, exp.Struct) or len(struct.expressions) != len(field_names):
            continue

        rewritten: list[exp.PropertyEQ] = []
        for field_name, value in zip(field_names, struct.expressions):
            if isinstance(value, exp.PropertyEQ):
                rewritten.append(value)
            else:
                # field_name := value, carrying over the type inferred for the value
                prop = exp.PropertyEQ(this=field_name.copy(), expression=value)
                prop.type = value.type
                rewritten.append(prop)

        struct.set("expressions", rewritten)

    return expression

Inherit field names from the first struct in an array.

BigQuery supports implicitly inheriting names from the first STRUCT in an array:

Example:

ARRAY[ STRUCT('Alice' AS name, 85 AS score), -- defines names STRUCT('Bob', 92), -- inherits names STRUCT('Diana', 95) -- inherits names ]

This transformation makes the field names explicit on all structs by adding PropertyEQ nodes, in order to facilitate transpilation to other dialects.

Arguments:
  • expression: The expression tree to transform
Returns:

The modified expression with field names inherited in all structs