
[Feat] TaylorSeer Cache#12648

Merged
sayakpaul merged 39 commits into huggingface:main from toilaluan:feat-taylorseer
Dec 6, 2025

Conversation

@toilaluan
Contributor

What does this PR do?

Adds the TaylorSeer caching method to accelerate inference, as requested in #12569.

Author's codebase: https://github.com/Shenyi-Z/TaylorSeer

This PR's structure heavily mimics the FasterCache behaviour (https://github.com/huggingface/diffusers/pull/10163/files).
I prioritized making it work on image model pipelines (Flux, Qwen Image) for ease of evaluation.

Expected Output

A 4-5x speedup with these settings, while keeping the output images at acceptable quality.

image

State Design

The core of this algorithm is predicting the features of step t from real computed features of previous steps, using a Taylor expansion approximation.
We design a State class with predict and update methods and taylor_factors: Tensor to maintain iteration information. Each feature tensor is bound to a state instance (in the double-stream attention class in Flux & Qwen Image, the output of this module is image_features & txt_features, so we create two state instances for them).

  • The update method is called on real-compute timesteps and updates taylor_factors using the math from the original implementation
  • The predict method is called to predict features from the current taylor_factors using the math from the original implementation
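The predict/update cycle described above can be sketched as follows. This is a minimal illustration with hypothetical names, not the PR's actual State class; it uses finite differences for the Taylor factors and a truncated Taylor series for prediction, following the general idea of the original implementation:

```python
import math
import torch

# Sketch of a TaylorSeer-style state: `update` refreshes finite-difference
# Taylor factors from a real forward pass; `predict` extrapolates the feature
# tensor for skipped steps via a truncated Taylor series.
class TaylorSeerState:
    def __init__(self, max_order: int = 1):
        self.max_order = max_order
        self.taylor_factors: list[torch.Tensor] = []  # [f, f', f''/2!, ...]
        self.last_step = None

    def update(self, features: torch.Tensor, step: int) -> None:
        # new[0] = f_t; new[i] = (new[i-1] - old[i-1]) / dt  (finite differences)
        dt = 1 if self.last_step is None else step - self.last_step
        new_factors = [features]
        for i in range(1, self.max_order + 1):
            if len(self.taylor_factors) < i:
                break  # not enough history yet for higher orders
            new_factors.append((new_factors[i - 1] - self.taylor_factors[i - 1]) / dt)
        self.taylor_factors = new_factors
        self.last_step = step

    def predict(self, step: int) -> torch.Tensor:
        # f(t + k) ≈ sum_i factors[i] * k^i / i!
        k = step - self.last_step
        out = torch.zeros_like(self.taylor_factors[0])
        for i, factor in enumerate(self.taylor_factors):
            out = out + factor * (k ** i) / math.factorial(i)
        return out
```

With max_order=1 this reduces to linear extrapolation from the last two real computes; higher orders reuse more history at the cost of extra cached tensors.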

@seed93

seed93 commented Nov 14, 2025

Will you adapt this great PR for flux kontext controlnet or flux controlnet? It would be nice if it is implemented and I am very eager to try it out.

@toilaluan
Contributor Author

@seed93 yes, I am prioritizing the Flux series and Qwen Image

@toilaluan
Contributor Author

Here is an analysis of TaylorSeer for Flux.
Compared with the baseline, the output image is different, although the PAB method gives a fairly close result.
This result matches the author's implementation.

| model_id | cache_method | compute_dtype | compile | time | model_memory | model_max_memory_reserved | inference_memory | inference_max_memory_reserved |
|---|---|---|---|---|---|---|---|---|
| flux | none | fp16 | False | 22.318 | 33.313 | 33.322 | 33.322 | 34.305 |
| flux | pyramid_attention_broadcast | fp16 | False | 18.394 | 33.313 | 33.322 | 33.322 | 35.789 |
| flux | taylorseer_cache | fp16 | False | 6.457 | 33.313 | 33.322 | 33.322 | 38.18 |

Flux visual results

Baseline

image

Pyramid Attention Broadcast

image

TaylorSeer Cache (this implementation)

image

TaylorSeer Original (https://github.com/Shenyi-Z/TaylorSeer/blob/main/TaylorSeers-Diffusers/taylorseer_flux/diffusers_taylorseer_flux.py)

image
Details

Benchmark code is based on #10163

import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
from diffusers import (
    AllegroPipeline,
    CogVideoXPipeline,
    FluxPipeline,
    HunyuanVideoPipeline,
    LattePipeline,
    MochiPipeline,
)
from diffusers.models import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate


repo = git.Repo(path="/root/diffusers")
branch = repo.active_branch

from diffusers import (
    apply_taylorseer_cache, 
    TaylorSeerCacheConfig, 
    apply_faster_cache, 
    FasterCacheConfig, 
    apply_pyramid_attention_broadcast, 
    PyramidAttentionBroadcastConfig,
)

def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output

def prepare_flux(dtype: torch.dtype) -> None:
    model_id = "black-forest-labs/FLUX.1-dev"
    print(f"Loading {model_id} with {dtype} dtype")
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
    pipe.to("cuda")
    generation_kwargs = {
        "prompt": "A cat holding a sign that says hello world",
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
    }

    return pipe, generation_kwargs

def prepare_flux_config(cache_method: str, pipe: FluxPipeline):
    if cache_method == "pyramid_attention_broadcast":
        return PyramidAttentionBroadcastConfig(
            spatial_attention_block_skip_range=2,
            spatial_attention_timestep_skip_range=(100, 950),
            spatial_attention_block_identifiers=["transformer_blocks", "single_transformer_blocks"],
            current_timestep_callback=lambda: pipe.current_timestep,
        )
    elif cache_method == "taylorseer_cache":
        return TaylorSeerCacheConfig(predict_steps=5, max_order=1, warmup_steps=3, taylor_factors_dtype=torch.float16, architecture="flux")
    elif cache_method == "fastercache":
        return FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 681),
        low_frequency_weight_update_timestep_range=(99, 641),
        high_frequency_weight_update_timestep_range=(-1, 301),
        spatial_attention_block_identifiers=["transformer_blocks"],
        attention_weight_callback=lambda _: 0.3,
        tensor_format="BFCHW",
    )
    elif cache_method == "none":
        return None


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


MODEL_MAPPING = {
    "flux": {
        "prepare": prepare_flux,
        "config": prepare_flux_config,
        "decode": decode_flux,
    },
}

STR_TO_COMPUTE_DTYPE = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator(device="cuda").manual_seed(181201)
    print(f"Generator: {generator}")
    print(f"Generation kwargs: {generation_kwargs}")
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


@torch.no_grad()
def main(model_id: str, cache_method: str, output_dir: str, dtype: str):
    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    compute_dtype = STR_TO_COMPUTE_DTYPE[dtype]
    model = MODEL_MAPPING[model_id]

    try:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=compute_dtype)

        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        model_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 2. Apply attention approximation technique
        config = model["config"](cache_method, pipe)
        if cache_method == "pyramid_attention_broadcast":
            apply_pyramid_attention_broadcast(pipe.transformer, config)
        elif cache_method == "fastercache":
            apply_faster_cache(pipe.transformer, config)
        elif cache_method == "taylorseer_cache":
            apply_taylorseer_cache(pipe.transformer, config)
        elif cache_method == "none":
            pass
        else:
            raise ValueError(f"Invalid {cache_method=} provided.")

        # 4. Benchmark
        time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        inference_max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)

        # 5. Decode latents
        filename = output_dir / f"{model_id}---dtype-{dtype}---cache_method-{cache_method}"
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "time": time,
            "model_memory": model_memory,
            "model_max_memory_reserved": model_max_memory_reserved,
            "inference_memory": inference_memory,
            "inference_max_memory_reserved": inference_max_memory_reserved,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "cache_method": cache_method,
            "compute_dtype": dtype,
            "time": None,
            "model_memory": None,
            "model_max_memory_reserved": None,
            "inference_memory": None,
            "inference_max_memory_reserved": None,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux"],
        help="Model to run benchmark for.",
    )
    parser.add_argument(
        "--cache_method",
        type=str,
        default="pyramid_attention_broadcast",
        choices=["pyramid_attention_broadcast", "fastercache", "taylorseer_cache", "none"],
        help="Cache method to use.",
    )
    parser.add_argument(
        "--output_dir", type=str, help="Path where the benchmark artifacts and outputs are to be saved."
    )
    parser.add_argument("--dtype", type=str, help="torch.dtype to use for inference")
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(args.model_id, args.cache_method, args.output_dir, args.dtype)
    

@toilaluan
Contributor Author

More comparisons between this implementation, the baseline, and the author's implementation

image

@toilaluan
Contributor Author

I think the current implementation is unified for every model that has attention modules, but to achieve full optimization we have to configure regexes for which layers to cache or skip computing.
For example, in a sequence Linear1, Act1, Linear2, Act2: we need to add hooks for Linear1, Act1, Linear2 that do nothing (return an empty tensor), but cache the output of Act2.
I already provide a fixed template for Flux, but for other models users have to write their own and pass it to the config init.
@sayakpaul what do you think about this mechanism? I need some advice here
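The regex mechanism above could look roughly like this. This is a hypothetical sketch (the function name and config field names are illustrative; the PR's actual API may differ), showing how module names are partitioned into skip vs. cache sets:

```python
import re
import torch.nn as nn

# Hypothetical sketch: given regex lists, partition submodules into those whose
# forward should be skipped (hook returns a dummy tensor) and those whose
# output is cached for Taylor prediction.
def select_modules(model: nn.Module, skip_identifiers, cache_identifiers):
    skip, cache = [], []
    for name, _ in model.named_modules():
        if not name:
            continue  # skip the root module itself
        if any(re.search(p, name) for p in cache_identifiers):
            cache.append(name)
        elif any(re.search(p, name) for p in skip_identifiers):
            skip.append(name)
    return skip, cache

# Toy module mirroring the Linear1 -> Act1 -> Linear2 -> Act2 example
class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)
        self.act1 = nn.GELU()
        self.linear2 = nn.Linear(4, 4)
        self.act2 = nn.GELU()

skip, cache = select_modules(TinyBlock(), [r"linear\d", r"act1"], [r"act2"])
# skip == ["linear1", "act1", "linear2"], cache == ["act2"]
```

Checking cache patterns first means a layer matched by both lists is cached rather than skipped, which is the behavior the Act2 example needs.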

@sayakpaul sayakpaul requested a review from DN6 November 14, 2025 17:41
@toilaluan
Contributor Author

toilaluan commented Nov 15, 2025

Tuning cache config really helps!

TaylorSeer cache configuration comparison

In the original code, they use 3 warmup steps and no cooldown. The output image differs significantly from the baseline, as shown in the report above.

As suggested in Shenyi-Z/TaylorSeer#12, increasing the warmup steps to 10 helps narrow the gap, but the cached output still has noticeable artifacts. This naturally suggested adding a cooldown phase (running the last steps without caching).

All runs below use the same prompt and 50 inference steps.

Visual comparison

Baseline vs. 3 warmup / 0 cooldown

Baseline (no cache) 3 warmup steps, 0 cooldown (cache)
Baseline output 3 warmup, 0 cooldown output

With only 3 warmup steps and 0 cooldown steps, the image content is not very close to the baseline.

10 warmup / 0 cooldown vs. 10 warmup / 5 cooldown

10 warmup steps, 0 cooldown (cache) 10 warmup steps, 5 cooldown (cache)
10 warmup, 0 cooldown output 10 warmup, 5 cooldown output

With 10 warmup steps, the content is closer to the baseline, but there are still many artifacts and noise.
By running the last 5 steps without caching (cooldown), most of these issues are resolved.
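The warmup/cooldown behavior described above amounts to a simple per-step rule. A sketch (hypothetical helper, not the PR's API), assuming one real compute every `predict_steps` steps between the warmup and cooldown phases:

```python
# Decide whether step `step` (0-indexed, out of `num_steps`) runs a real
# forward pass or reuses the Taylor prediction. Warmup and cooldown always
# run the model; in between, one real compute happens every `predict_steps`.
def should_compute(step: int, num_steps: int, warmup_steps: int,
                   cooldown_steps: int, predict_steps: int) -> bool:
    if step < warmup_steps:
        return True  # warmup: build accurate Taylor factors first
    if step >= num_steps - cooldown_steps:
        return True  # cooldown: run the final steps without caching
    return (step - warmup_steps) % predict_steps == 0

schedule = [should_compute(s, 50, 10, 5, 5) for s in range(50)]
# steps 0-9 are real computes (warmup), steps 45-49 are real computes (cooldown)
```

With 50 steps, 10 warmup, 5 cooldown, and predict_steps=5, only the steps at 10, 15, 20, ... in the middle run the real model; the rest are predicted.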


Hardware usage comparison

The table below shows the hardware usage comparison:

| cache_method | predict_steps | max_order | warmup_steps | stop_predicts | time (s) | model_memory_gb | inference_memory_gb | max_memory_reserved_gb | compute_dtype |
|---|---|---|---|---|---|---|---|---|---|
| none | - | - | - | - | 22.781 | 33.313 | 33.321 | 37.943 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 3.0 | - | 7.099 | 55.492 | 55.492 | 70.283 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 3.0 | 45.0 | 9.024 | 55.490 | 55.490 | 70.283 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 10.0 | - | 9.451 | 55.492 | 55.492 | 70.283 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 10.0 | 45.0 | 11.000 | 55.490 | 55.490 | 70.283 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 3.0 | - | 6.701 | 55.492 | 55.492 | 70.285 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 3.0 | 45.0 | 8.651 | 55.490 | 55.490 | 70.285 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 10.0 | - | 9.053 | 55.492 | 55.492 | 70.283 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 10.0 | 45.0 | 11.001 | 55.490 | 55.490 | 70.283 | fp16 |
image

Code

Details
import gc
import pathlib
import pandas as pd
import torch
from itertools import product

from diffusers import FluxPipeline
from diffusers.utils.logging import set_verbosity_info

from diffusers import apply_taylorseer_cache, TaylorSeerCacheConfig

def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output

def prepare_flux(dtype: torch.dtype):
    model_id = "black-forest-labs/FLUX.1-dev"
    print(f"Loading {model_id} with {dtype} dtype")
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
    pipe.to("cuda")
    prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang"
    generation_kwargs = {
        "prompt": prompt,
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
    }

    return pipe, generation_kwargs

def run_inference(pipe, generation_kwargs):
    generator = torch.Generator(device="cuda").manual_seed(181201)
    output = pipe(generator=generator, output_type="pil", **generation_kwargs).images[0]
    torch.cuda.synchronize()
    return output

def main(output_dir: str):
    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    compute_dtype = torch.float16
    taylor_factors_dtype = torch.float16

    param_grid = {
        'predict_steps': [5, 6],
        'max_order': [1],
        'warmup_steps': [3, 10],
        'stop_predicts': [None, 45]
    }
    combinations = list(product(*param_grid.values()))
    param_keys = list(param_grid.keys())

    results = []

    # Reset before each run
    def reset_cuda():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

    # Baseline (no cache)
    print("Running baseline...")
    reset_cuda()
    pipe, generation_kwargs = prepare_flux(compute_dtype)
    model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
    inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
    image_filename = output_dir / "baseline.png"
    image.save(image_filename)
    print(f"Baseline image saved to {image_filename}")

    info = {
        'cache_method': 'none',
        'predict_steps': None,
        'max_order': None,
        'warmup_steps': None,
        'stop_predicts': None,
        'time': time,
        'model_memory_gb': model_memory,
        'inference_memory_gb': inference_memory,
        'max_memory_reserved_gb': max_memory_reserved,
        'compute_dtype': 'fp16'
    }
    results.append(info)

    # TaylorSeer cache configurations
    for combo in combinations:
        ps, mo, ws, sp = combo
        sp_str = 'None' if sp is None else str(sp)
        print(f"Running TaylorSeer with predict_steps={ps}, max_order={mo}, warmup_steps={ws}, stop_predicts={sp}...")
        reset_cuda()
        pipe, generation_kwargs = prepare_flux(compute_dtype)
        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        config = TaylorSeerCacheConfig(
            predict_steps=ps,
            max_order=mo,
            warmup_steps=ws,
            stop_predicts=sp,
            taylor_factors_dtype=taylor_factors_dtype,
            architecture="flux"
        )
        apply_taylorseer_cache(pipe.transformer, config)
        time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
        image_filename = output_dir / f"taylorseer_p{ps}_o{mo}_w{ws}_s{sp_str}.jpg"
        image.save(image_filename)
        print(f"TaylorSeer image saved to {image_filename}")

        info = {
            'cache_method': 'taylorseer_cache',
            'predict_steps': ps,
            'max_order': mo,
            'warmup_steps': ws,
            'stop_predicts': sp,
            'time': time,
            'model_memory_gb': model_memory,
            'inference_memory_gb': inference_memory,
            'max_memory_reserved_gb': max_memory_reserved,
            'compute_dtype': 'fp16'
        }
        results.append(info)

    # Save CSV
    df = pd.DataFrame(results)
    csv_path = output_dir / 'benchmark_results.csv'
    df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    # Plot latency
    import matplotlib.pyplot as plt
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(20, 8))

    baseline_row = df[df['cache_method'] == 'none'].iloc[0]
    baseline_time = baseline_row['time']

    labels = ['baseline']
    times = [baseline_time]

    taylor_df = df[df['cache_method'] == 'taylorseer_cache']
    for _, row in taylor_df.iterrows():
        sp_str = 'None' if pd.isna(row['stop_predicts']) else str(int(row['stop_predicts']))
        label = f"p{row['predict_steps']}-o{row['max_order']}-w{row['warmup_steps']}-s{sp_str}"
        labels.append(label)
        times.append(row['time'])

    bars = ax.bar(labels, times)
    ax.set_xlabel('Configuration')
    ax.set_ylabel('Latency (s)')
    ax.set_title('Inference Latency: Baseline vs TaylorSeer Cache Configurations')
    ax.tick_params(axis='x', rotation=90)
    plt.tight_layout()

    plot_path = output_dir / 'latency_comparison.png'
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Plot saved to {plot_path}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save CSV, plot, and images.")
    args = parser.parse_args()

    set_verbosity_info()
    main(args.output_dir)

@toilaluan
Contributor Author

Similar behavior with Qwen Image

| cache_method | predict_steps | max_order | warmup_steps | stop_predicts | time | model_memory_gb | inference_memory_gb | max_memory_reserved_gb | compute_dtype |
|---|---|---|---|---|---|---|---|---|---|
| none | - | - | - | - | 23.01 | 53.791 | 53.807 | 64.359 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 3.0 | - | 13.457 | 53.807 | 53.813 | 67.303 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 3.0 | 45.0 | 14.562 | 53.813 | 53.819 | 67.303 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 10.0 | - | 14.775 | 53.819 | 53.825 | 67.303 | fp16 |
| taylorseer_cache | 5.0 | 1.0 | 10.0 | 45.0 | 15.628 | 53.825 | 53.832 | 67.322 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 3.0 | - | 13.214 | 53.832 | 53.839 | 67.322 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 3.0 | 45.0 | 14.349 | 53.838 | 53.845 | 67.322 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 10.0 | - | 14.595 | 53.844 | 53.851 | 67.322 | fp16 |
| taylorseer_cache | 6.0 | 1.0 | 10.0 | 45.0 | 15.707 | 53.851 | 53.858 | 67.342 | fp16 |

@toilaluan
Contributor Author

toilaluan commented Nov 16, 2025

Flux Kontext – Cache vs Baseline Comparison

@seed93
I tested Flux Kontext using several cache configurations, and the results look promising. Below is a comparison of the baseline output and the cached versions.


Original Image

original image

Image Comparison (Side-by-Side)

Baseline — “add a hat to the cat” predict_steps=7, O=1, warmup=10, cooldown=5 predict_steps=8

Processing Times (CSV → Markdown Table)

| cache_method     | predict_steps | max_order | warmup_steps | stop_predicts | time    | model_memory_gb | inference_memory_gb | max_memory_reserved_gb | compute_dtype |
|------------------|---------------|-----------|--------------|---------------|---------|------------------|----------------------|--------------------------|---------------|
| none             |               |           |              |               | 48.391  | 31.438           | 31.446               | 36.209                   | fp16          |
| taylorseer_cache | 7.0           | 1.0       | 10.0         | 45.0          | 21.468  | 31.447           | 31.447               | 44.625                   | fp16          |
| taylorseer_cache | 8.0           | 1.0       | 10.0         | 45.0          | 20.633  | 31.447           | 31.447               | 44.625                   | fp16          |

Reproduce Code

Details
import gc
import pathlib
import pandas as pd
import torch
from itertools import product

from diffusers import DiffusionPipeline
from diffusers.utils.logging import set_verbosity_info

from diffusers import apply_taylorseer_cache, TaylorSeerCacheConfig

def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output

def prepare_flux(dtype: torch.dtype):
    from diffusers.utils import load_image
    model_id = "black-forest-labs/FLUX.1-Kontext-dev"
    print(f"Loading {model_id} with {dtype} dtype")
    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype, use_safetensors=True)
    pipe.to("cuda")
    prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang, Ultra HD, 4K, cinematic composition."
    edit_prompt = "Add a hat to the cat"
    input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
    generation_kwargs = {
        "prompt": edit_prompt,
        "num_inference_steps": 50,
        "guidance_scale": 2.5,
        "image": input_image,
    }

    return pipe, generation_kwargs

def run_inference(pipe, generation_kwargs):
    generator = torch.Generator(device="cuda").manual_seed(181201)
    output = pipe(generator=generator, output_type="pil", **generation_kwargs).images[0]
    torch.cuda.synchronize()
    return output

def main(output_dir: str):
    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    compute_dtype = torch.bfloat16
    taylor_factors_dtype = torch.bfloat16

    param_grid = {
        'predict_steps': [7, 8],
        'max_order': [1],
        'warmup_steps': [10],
        'stop_predicts': [45]
    }
    combinations = list(product(*param_grid.values()))
    param_keys = list(param_grid.keys())

    results = []

    # Reset before each run
    def reset_cuda():
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.reset_accumulated_memory_stats()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        torch.cuda.synchronize()

    # Baseline (no cache)
    print("Running baseline...")
    reset_cuda()
    pipe, generation_kwargs = prepare_flux(compute_dtype)
    model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
    inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
    max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
    image_filename = output_dir / "baseline.png"
    image.save(image_filename)
    print(f"Baseline image saved to {image_filename}")

    info = {
        'cache_method': 'none',
        'predict_steps': None,
        'max_order': None,
        'warmup_steps': None,
        'stop_predicts': None,
        'time': time,
        'model_memory_gb': model_memory,
        'inference_memory_gb': inference_memory,
        'max_memory_reserved_gb': max_memory_reserved,
        'compute_dtype': 'fp16'
    }
    results.append(info)

    # TaylorSeer cache configurations
    for combo in combinations:
        ps, mo, ws, sp = combo
        sp_str = 'None' if sp is None else str(sp)
        print(f"Running TaylorSeer with predict_steps={ps}, max_order={mo}, warmup_steps={ws}, stop_predicts={sp}...")
        del pipe
        reset_cuda()
        pipe, generation_kwargs = prepare_flux(compute_dtype)
        model_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        config = TaylorSeerCacheConfig(
            predict_steps=ps,
            max_order=mo,
            warmup_steps=ws,
            stop_predicts=sp,
            taylor_factors_dtype=taylor_factors_dtype,
            architecture="flux"
        )
        apply_taylorseer_cache(pipe.transformer, config)
        time, image = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_memory = round(torch.cuda.memory_allocated() / 1024**3, 3)
        max_memory_reserved = round(torch.cuda.max_memory_reserved() / 1024**3, 3)
        image_filename = output_dir / f"taylorseer_p{ps}_o{mo}_w{ws}_s{sp_str}.jpg"
        image.save(image_filename)
        print(f"TaylorSeer image saved to {image_filename}")

        info = {
            'cache_method': 'taylorseer_cache',
            'predict_steps': ps,
            'max_order': mo,
            'warmup_steps': ws,
            'stop_predicts': sp,
            'time': time,
            'model_memory_gb': model_memory,
            'inference_memory_gb': inference_memory,
            'max_memory_reserved_gb': max_memory_reserved,
            'compute_dtype': 'fp16'
        }
        results.append(info)

    # Save CSV
    df = pd.DataFrame(results)
    csv_path = output_dir / 'benchmark_results.csv'
    df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    # Plot latency
    import matplotlib.pyplot as plt
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(20, 8))

    baseline_row = df[df['cache_method'] == 'none'].iloc[0]
    baseline_time = baseline_row['time']

    labels = ['baseline']
    times = [baseline_time]

    taylor_df = df[df['cache_method'] == 'taylorseer_cache']
    for _, row in taylor_df.iterrows():
        sp_str = 'None' if pd.isna(row['stop_predicts']) else str(int(row['stop_predicts']))
        label = f"p{row['predict_steps']}-o{row['max_order']}-w{row['warmup_steps']}-s{sp_str}"
        labels.append(label)
        times.append(row['time'])

    bars = ax.bar(labels, times)
    ax.set_xlabel('Configuration')
    ax.set_ylabel('Latency (s)')
    ax.set_title('Inference Latency: Baseline vs TaylorSeer Cache Configurations')
    ax.tick_params(axis='x', rotation=90)
    plt.tight_layout()

    plot_path = output_dir / 'latency_comparison.png'
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"Plot saved to {plot_path}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save CSV, plot, and images.")
    args = parser.parse_args()

    # set_verbosity_info()
    main(args.output_dir)

@toilaluan toilaluan marked this pull request as ready for review November 17, 2025 06:21
@seed93

seed93 commented Nov 17, 2025

Flux Kontext – Cache vs Baseline Comparison

@seed93 I tested Flux Kontext using several cache configurations, and the results look promising. Below is a comparison of the baseline output and the cached versions.

Original Image


Reproduce Code

This is amazing!

@seed93

seed93 commented Nov 19, 2025

I am not sure why it uses so much GPU memory; I have only 24 GB of GPU memory.

@seed93

seed93 commented Nov 19, 2025

Could you please try using the taylorseer-lite as an option? refer to Shenyi-Z/TaylorSeer#5

@toilaluan
Contributor Author

@seed93 yeah, it seems not too complicated; I will try it and post a report here

@toilaluan
Contributor Author

TaylorSeer-lite

@seed93, you can use the lite version with minimal extra memory by following this script, but it works for the Hunyuan model, not Flux.
Flux TS-lite's output is pure noise.

  • Hunyuan output
image
  • Flux output
image
import torch
from diffusers import FluxPipeline, HunyuanImagePipeline
from diffusers import TaylorSeerCacheConfig


model = "hunyuanimage"  # or "flux"
if model == "flux":
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
elif model == "hunyuanimage":
    pipeline = HunyuanImagePipeline.from_pretrained(
        "hunyuanvideo-community/HunyuanImage-2.1-Diffusers",
        torch_dtype=torch.bfloat16,
    ).to("cuda")
print(pipeline)

cache_config = TaylorSeerCacheConfig(
    skip_identifiers=[r"^(?!proj_out$)[^.]+\.[^.]+$"],
    cache_identifiers=[r"proj_out"],
    predict_steps=5,
    max_order=2,
    warmup_steps=10,
    stop_predicts=48,
    taylor_factors_dtype=torch.bfloat16,
)

pipeline.transformer.enable_cache(cache_config)

prompt = "A fluffy teddy bear sits on a bed of soft pillows surrounded by children's toys."
image = pipeline(prompt=prompt, width=1024, height=1024, num_inference_steps=50, generator=torch.Generator(device="cuda").manual_seed(181201)).images[0]

image.save("teddy_bear.jpg")
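For anyone puzzled by the regexes in the lite-mode config above: my reading (an assumption from the patterns themselves, not documented behaviour) is that the hook matches them against fully qualified module names, so `skip_identifiers` skips whole depth-two block modules during predicted steps while `cache_identifiers` caches only the final `proj_out` output. A quick sanity check of what the patterns match:

```python
import re

skip_pattern = r"^(?!proj_out$)[^.]+\.[^.]+$"

# Matches exactly-two-component dotted names, i.e. individual top-level blocks
print(bool(re.match(skip_pattern, "transformer_blocks.0")))       # True
print(bool(re.match(skip_pattern, "transformer_blocks.0.attn")))  # False: three components
print(bool(re.match(skip_pattern, "proj_out")))                   # False: no dot at all
print(bool(re.search(r"proj_out", "proj_out")))                   # True: cached instead of skipped
```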

@toilaluan
Contributor Author

@DN6 This feature is ready for review, could you take a look 🙇

@toilaluan
Contributor Author

@DN6, @sayakpaul I added a similar compiler disable to the TaylorSeer cache, but I have an observation that applies to both FBCache and TaylorSeer:

  • The graph breaks at `torch.compiler.disable`, which then requires recompiling every block, since the hook is applied at the block level.
    The default recompile limit is 8, while the transformer has far more blocks (57 in Flux), so we have to raise this limit above the number of blocks to get the best performance and graphs similar to regional compiling. Check the code below:
import torch
from diffusers import FluxPipeline, HunyuanImagePipeline
from diffusers import TaylorSeerCacheConfig, FirstBlockCacheConfig

# torch._logging.set_logs(graph_code=True)

import torch._dynamo as dynamo
dynamo.config.recompile_limit = 100

pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")
print(pipeline)

cache_config = TaylorSeerCacheConfig(
    cache_interval=5,
    max_order=1,
    disable_cache_before_step=50, # assume we will run full compute to see compile effect
    disable_cache_after_step=48,
    taylor_factors_dtype=torch.bfloat16,
    use_lite_mode=True
)
fbconfig = FirstBlockCacheConfig(
    threshold=1e-6, # set this value to very small so cache will not be applied to see compile effect
)

pipeline.transformer.enable_cache(fbconfig) # or cache_config

pipeline.transformer.compile(fullgraph=False, dynamic=True)

prompt = "A laptop on top of a teddy bear, realistic, high quality, 4k"
# warmup
image = pipeline(prompt=prompt, width=1024, height=1024, num_inference_steps=50, generator=torch.Generator(device="cuda").manual_seed(181201)).images[0]
# monitor this call
image = pipeline(prompt=prompt, width=1024, height=1024, num_inference_steps=50, generator=torch.Generator(device="cuda").manual_seed(181201)).images[0]
image.save("teddy_bear.jpg")

@toilaluan
Contributor Author

toilaluan commented Dec 4, 2025

Comparison: Baseline, Baseline-22steps, FBCache, TaylorSeer Cache

FLUX.1-dev

Memory & Speed Metrics (GPU: H100, 50 steps, compiled)

| Prompt Index | Variant | Load Time (s) | Load Memory (GB) | Compile Time (s) | Warmup Time (s) | Main Time (s) | Peak Memory (GB) |
|---|---|---|---|---|---|---|---|
| 0 | baseline | 5.812418 | 31.437537 | 1.162642 | 3.940589 | 11.744348 | 33.851628 |
| 0 | baseline(steps=22) | 5.445316 | 31.469763 | 0.054121 | 2.662160 | 5.271367 | 33.851628 |
| 0 | firstblock(threshold=0.05) | 5.618118 | 31.469763 | 0.053683 | 30.769011 | 8.686095 | 33.928777 |
| 0 | taylorseer(max_order=1, cache_interval=5, disable_cache_before_step=10) | 5.487217 | 31.469763 | 0.051885 | 59.841501 | 4.865684 | 33.852117 |

Visual Outputs

  • Women’s Health magazine cover, April 2025 issue, ‘Spring forward’ headline, woman in green outfit sitting on orange blocks, white sneakers, ‘Covid: five years on’ feature text, ‘15 skincare habits’ callout, professional editorial photography, magazine layout with multiple text elements
Baseline · Baseline-22steps · FBCache · TaylorSeer Cache
  • Soaking wet tiger cub taking shelter under a banana leaf in the rainy jungle, close up photo
Baseline · Baseline-22steps · FBCache · TaylorSeer Cache

Analyze TaylorSeer Configurations

Memory & Speed (no compile)

| Prompt Index | Variant | Steps | Max Order | Cache Interval | Load Time (s) | Load Memory (GB) | Compile Time (s) | Warmup Time (s) | Main Time (s) | Peak Memory (GB) | Speedup vs baseline_50 | Speedup vs baseline_22 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | baseline_50 | 50 | N/A | N/A | 5.68 | 31.44 | 0.00 | 2.38 | 15.32 | 33.85 | 1.00x | 1.00x |
| 0 | baseline_22 | 22 | N/A | N/A | 5.39 | 31.47 | 0.00 | 1.70 | 6.86 | 33.85 | 2.23x | 1.00x |
| 0 | taylor_o0_ci5 | 50 | 0 | 5 | 5.37 | 31.47 | 0.00 | 1.71 | 5.70 | 33.85 | 2.69x | 1.20x |
| 0 | taylor_o0_ci10 | 50 | 0 | 10 | 5.38 | 31.47 | 0.00 | 1.71 | 4.64 | 33.85 | 3.30x | 1.48x |
| 0 | taylor_o0_ci15 | 50 | 0 | 15 | 5.44 | 31.47 | 0.00 | 1.71 | 4.40 | 33.85 | 3.48x | 1.56x |
| 0 | taylor_o1_ci5 | 50 | 1 | 5 | 5.67 | 31.47 | 0.00 | 1.71 | 5.70 | 33.85 | 2.69x | 1.20x |
| 0 | taylor_o1_ci8 | 50 | 1 | 8 | 5.68 | 31.47 | 0.00 | 1.71 | 4.88 | 33.85 | 3.14x | 1.41x |
| 0 | taylor_o1_ci10 | 50 | 1 | 10 | 5.68 | 31.47 | 0.00 | 1.71 | 4.64 | 33.85 | 3.30x | 1.48x |
| 0 | taylor_o1_ci15 | 50 | 1 | 15 | 5.68 | 31.47 | 0.00 | 1.71 | 4.40 | 33.85 | 3.48x | 1.56x |
| 0 | taylor_o2_ci5 | 50 | 2 | 5 | 5.66 | 31.47 | 0.00 | 1.71 | 5.69 | 33.85 | 2.69x | 1.21x |
| 0 | taylor_o2_ci10 | 50 | 2 | 10 | 5.73 | 31.47 | 0.00 | 1.70 | 4.64 | 33.85 | 3.30x | 1.48x |
| 0 | taylor_o2_ci15 | 50 | 2 | 15 | 5.36 | 31.47 | 0.00 | 1.70 | 4.40 | 33.85 | 3.48x | 1.56x |
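As a rough cross-check (my own back-of-envelope accounting, not the exact hook behaviour), the speedups above track the number of fully computed steps: all warmup steps run in full, and afterwards roughly one real compute happens per cache interval:

```python
import math

def approx_full_computes(total_steps: int, warmup: int, cache_interval: int) -> int:
    """Rough count of fully computed steps: every warmup step, plus one
    real compute per cache interval over the remaining steps."""
    return warmup + math.ceil((total_steps - warmup) / cache_interval)

for ci in (5, 10, 15):
    full = approx_full_computes(50, 10, ci)
    print(f"cache_interval={ci}: ~{full} full computes, ideal speedup {50 / full:.2f}x")
```

This gives ideal speedups slightly above the measured 2.69x / 3.30x / 3.48x, which makes sense since predicted steps are cheap but not free.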

Visual Comparison (o1,ci5 means max_order=1, cache_interval=5)

image

Reproduce Code

  1. Baselines Vs. TaylorSeer variants
Details
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig
import time
import os
import matplotlib.pyplot as plt
import pandas as pd
import gc

# Set dynamo config
import torch._dynamo as dynamo
dynamo.config.recompile_limit = 200

prompts = [
    "Black cat hiding behind a watermelon slice, professional studio shot, bright red and turquoise background with summer mystery vibe",
]

# Create output folder
os.makedirs("outputs", exist_ok=True)

# ============================================================================
# CONFIGURATION SECTION - Easily modify these parameters
# ============================================================================

# Fixed config parameters (applied to all TaylorSeer configs)
FIXED_CONFIG = {
    'disable_cache_before_step': 10,
    'taylor_factors_dtype': torch.bfloat16,
    'use_lite_mode': True
}

# Variable parameters to test - modify these as needed
# Format: (max_order, cache_interval)
TAYLOR_CONFIGS = [
    (0, 5),
    (0, 10),
    (0, 15),
    (1, 5),
    (1, 8),
    (1, 10),
    (1, 15),
    (2, 5),
    (2, 10),
    (2, 15),
]

# Baseline configurations
BASELINES = [
    {'name': 'baseline_50', 'steps': 50},
    {'name': 'baseline_22', 'steps': 22},
]

# Main inference steps for TaylorSeer variants
MAIN_STEPS = 50
WARMUP_STEPS = 5

# ============================================================================

# Build TaylorSeer configs
taylor_configs = {}
for max_order, cache_interval in TAYLOR_CONFIGS:
    config_name = f'taylor_o{max_order}_ci{cache_interval}'
    taylor_configs[config_name] = TaylorSeerCacheConfig(
        max_order=max_order,
        cache_interval=cache_interval,
        **FIXED_CONFIG
    )

# Collect results
results = []

for i, prompt in enumerate(prompts):
    print(f"\n{'='*80}")
    print(f"Processing Prompt {i}: {prompt[:50]}...")
    print(f"{'='*80}\n")
    
    images = {}
    baseline_times = {}
    
    # Run all baseline variants first
    for baseline_config in BASELINES:
        variant = baseline_config['name']
        num_steps = baseline_config['steps']
        
        print(f"Running {variant} (steps={num_steps})...")
        
        # Clear cache before loading
        gc.collect()
        torch.cuda.empty_cache()
        
        # Load pipeline with timing
        start_load = time.time()
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        load_time = time.time() - start_load
        load_mem_gb = torch.cuda.memory_allocated() / (1024 ** 3)
        
        # Compile with timing
        start_compile = time.time()
        # pipeline.transformer.compile_repeated_blocks(fullgraph=False)
        compile_time = time.time() - start_compile
        
        # Warmup with 5 steps
        gen_warmup = torch.Generator(device="cuda").manual_seed(181201)
        start_warmup = time.time()
        _ = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=WARMUP_STEPS,
            guidance_scale=3.0,
            generator=gen_warmup
        ).images[0]
        warmup_time = time.time() - start_warmup
        
        # Main run
        gen_main = torch.Generator(device="cuda").manual_seed(181201)
        
        torch.cuda.reset_peak_memory_stats()
        start_main = time.time()
        image = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=num_steps,
            guidance_scale=3.0,
            generator=gen_main
        ).images[0]
        end_main = time.time()
        
        peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
        main_time = end_main - start_main
        
        # Save image
        image_path = f"outputs/{variant}_prompt{i}.jpg"
        image.save(image_path)
        images[variant] = image
        
        # Store baseline time
        baseline_times[variant] = main_time
        
        # Record results
        results.append({
            'Prompt Index': i,
            'Variant': variant,
            'Steps': num_steps,
            'Max Order': 'N/A',
            'Cache Interval': 'N/A',
            'Load Time (s)': f"{load_time:.2f}",
            'Load Memory (GB)': f"{load_mem_gb:.2f}",
            'Compile Time (s)': f"{compile_time:.2f}",
            'Warmup Time (s)': f"{warmup_time:.2f}",
            'Main Time (s)': f"{main_time:.2f}",
            'Peak Memory (GB)': f"{peak_mem_gb:.2f}",
            'Speedup vs baseline_50': '1.00x' if variant == 'baseline_50' else f"{baseline_times['baseline_50']/main_time:.2f}x",
            'Speedup vs baseline_22': '1.00x' if variant == 'baseline_22' else f"{baseline_times.get('baseline_22', main_time)/main_time:.2f}x"
        })
        
        print(f"  Load: {load_time:.2f}s, Compile: {compile_time:.2f}s, Warmup: {warmup_time:.2f}s")
        print(f"  Main: {main_time:.2f}s, Peak Memory: {peak_mem_gb:.2f} GB\n")
        
        # Clean up
        pipeline.to("cpu")
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()
        dynamo.reset()
    
    # TaylorSeer variants with different configurations
    for config_name, tsconfig in taylor_configs.items():
        variant = config_name
        max_order = tsconfig.max_order
        cache_interval = tsconfig.cache_interval
        print(f"Running {variant} (max_order={max_order}, cache_interval={cache_interval})...")
        
        # Clear cache before loading
        gc.collect()
        torch.cuda.empty_cache()
        
        # Load pipeline with timing
        start_load = time.time()
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        load_time = time.time() - start_load
        load_mem_gb = torch.cuda.memory_allocated() / (1024 ** 3)
        
        # Enable TaylorSeer cache
        pipeline.transformer.enable_cache(tsconfig)
        
        # Compile with timing
        start_compile = time.time()
        # pipeline.transformer.compile_repeated_blocks(fullgraph=False)
        compile_time = time.time() - start_compile
        
        # Warmup with 5 steps
        gen_warmup = torch.Generator(device="cuda").manual_seed(181201)
        start_warmup = time.time()
        _ = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=WARMUP_STEPS,
            guidance_scale=3.0,
            generator=gen_warmup
        ).images[0]
        warmup_time = time.time() - start_warmup
        
        # Main run
        gen_main = torch.Generator(device="cuda").manual_seed(181201)
        
        torch.cuda.reset_peak_memory_stats()
        start_main = time.time()
        image = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=MAIN_STEPS,
            guidance_scale=3.0,
            generator=gen_main
        ).images[0]
        end_main = time.time()
        
        peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
        main_time = end_main - start_main
        speedup_50 = baseline_times['baseline_50'] / main_time
        speedup_22 = baseline_times['baseline_22'] / main_time
        
        # Save image
        image_path = f"outputs/{variant}_prompt{i}.jpg"
        image.save(image_path)
        images[variant] = image
        
        # Record results
        results.append({
            'Prompt Index': i,
            'Variant': variant,
            'Steps': MAIN_STEPS,
            'Max Order': max_order,
            'Cache Interval': cache_interval,
            'Load Time (s)': f"{load_time:.2f}",
            'Load Memory (GB)': f"{load_mem_gb:.2f}",
            'Compile Time (s)': f"{compile_time:.2f}",
            'Warmup Time (s)': f"{warmup_time:.2f}",
            'Main Time (s)': f"{main_time:.2f}",
            'Peak Memory (GB)': f"{peak_mem_gb:.2f}",
            'Speedup vs baseline_50': f"{speedup_50:.2f}x",
            'Speedup vs baseline_22': f"{speedup_22:.2f}x"
        })
        
        print(f"  Load: {load_time:.2f}s, Compile: {compile_time:.2f}s, Warmup: {warmup_time:.2f}s")
        print(f"  Main: {main_time:.2f}s, Peak Memory: {peak_mem_gb:.2f} GB")
        print(f"  Speedup vs baseline_50: {speedup_50:.2f}x, vs baseline_22: {speedup_22:.2f}x\n")
        
        # Clean up
        pipeline.to("cpu")
        del pipeline
        gc.collect()
        torch.cuda.empty_cache()
        dynamo.reset()
    
    # Plot image comparison for this prompt (select key variants)
    key_variants = ['baseline_50', 'baseline_22'] + [list(taylor_configs.keys())[j] for j in range(min(4, len(taylor_configs)))]
    num_variants = len(key_variants)
    
    fig, axs = plt.subplots(1, num_variants, figsize=(10*num_variants, 10))
    if num_variants == 1:
        axs = [axs]
    
    for j, var in enumerate(key_variants):
        if var in images:
            axs[j].imshow(images[var])
            axs[j].set_title(f"{var}", fontsize=24)
            axs[j].axis('off')
    
    plt.tight_layout()
    plt.savefig(f"outputs/comparison_prompt{i}.png", dpi=100)
    plt.close()

# Print results table
print("\n" + "="*140)
print("BENCHMARK RESULTS")
print("="*140 + "\n")

df = pd.DataFrame(results)
print(df.to_string(index=False))

# Save results to CSV
df.to_csv("outputs/benchmark_results.csv", index=False)
print("\nResults saved to outputs/benchmark_results.csv")

# Calculate and display averages per variant
print("\n" + "="*140)
print("AVERAGE METRICS BY VARIANT")
print("="*140 + "\n")

# Convert numeric columns back to float for averaging
numeric_cols = ['Load Time (s)', 'Load Memory (GB)', 'Compile Time (s)', 
                'Warmup Time (s)', 'Main Time (s)', 'Peak Memory (GB)']

df_numeric = df.copy()
for col in numeric_cols:
    df_numeric[col] = df_numeric[col].astype(float)

# Group by variant and calculate means
avg_df = df_numeric.groupby('Variant')[numeric_cols + ['Steps']].mean()

# Add configuration info
avg_df['Max Order'] = df.groupby('Variant')['Max Order'].first()
avg_df['Cache Interval'] = df.groupby('Variant')['Cache Interval'].first()

# Calculate average speedups
speedup_50_df = df.groupby('Variant')['Speedup vs baseline_50'].apply(
    lambda x: f"{sum(float(v.rstrip('x')) for v in x) / len(x):.2f}x"
)
speedup_22_df = df.groupby('Variant')['Speedup vs baseline_22'].apply(
    lambda x: f"{sum(float(v.rstrip('x')) for v in x) / len(x):.2f}x"
)
avg_df['Avg Speedup vs baseline_50'] = speedup_50_df
avg_df['Avg Speedup vs baseline_22'] = speedup_22_df

# Reorder columns
avg_df = avg_df[['Steps', 'Max Order', 'Cache Interval'] + numeric_cols + 
                ['Avg Speedup vs baseline_50', 'Avg Speedup vs baseline_22']]

# Format numeric columns
avg_df['Steps'] = avg_df['Steps'].apply(lambda x: f"{x:.0f}")
for col in numeric_cols:
    avg_df[col] = avg_df[col].apply(lambda x: f"{x:.2f}")

print(avg_df.to_string())

# Create comprehensive visualizations
fig, axes = plt.subplots(2, 2, figsize=(20, 16))

# Extract data for plotting
variants = []
main_times = []
peak_memories = []
speedups_50 = []
speedups_22 = []
labels = []

for variant in df['Variant'].unique():
    variant_data = df_numeric[df_numeric['Variant'] == variant]
    variants.append(variant)
    main_times.append(variant_data['Main Time (s)'].mean())
    peak_memories.append(variant_data['Peak Memory (GB)'].mean())
    
    # Calculate average speedups
    speedup_50_values = df[df['Variant'] == variant]['Speedup vs baseline_50'].apply(
        lambda x: float(x.rstrip('x'))
    )
    speedup_22_values = df[df['Variant'] == variant]['Speedup vs baseline_22'].apply(
        lambda x: float(x.rstrip('x'))
    )
    speedups_50.append(speedup_50_values.mean())
    speedups_22.append(speedup_22_values.mean())
    
    # Create readable labels
    if 'baseline' in variant:
        labels.append(variant)
    else:
        parts = variant.split('_')
        order = parts[1].replace('o', 'O')
        ci = parts[2].replace('ci', 'CI')
        labels.append(f"{order}_{ci}")

# Assign colors
colors = ['#1f77b4', '#ff7f0e'] + ['#2ca02c', '#d62728', '#9467bd', '#8c564b', 
                                     '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'] * 3
colors = colors[:len(variants)]

# Plot 1: Main Time Comparison
ax1 = axes[0, 0]
bars1 = ax1.bar(labels, main_times, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax1.set_ylabel('Main Generation Time (seconds)', fontsize=12, fontweight='bold')
ax1.set_xlabel('Configuration', fontsize=12, fontweight='bold')
ax1.set_title('Average Generation Time Comparison', fontsize=14, fontweight='bold')
ax1.grid(axis='y', alpha=0.3, linestyle='--')
ax1.tick_params(axis='x', rotation=45)

# Add value labels on bars
for bar, time in zip(bars1, main_times):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height,
             f'{time:.2f}s', ha='center', va='bottom', fontsize=9, fontweight='bold')

# Plot 2: Peak Memory Comparison
ax2 = axes[0, 1]
bars2 = ax2.bar(labels, peak_memories, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax2.set_ylabel('Peak Memory Usage (GB)', fontsize=12, fontweight='bold')
ax2.set_xlabel('Configuration', fontsize=12, fontweight='bold')
ax2.set_title('Average Peak Memory Comparison', fontsize=14, fontweight='bold')
ax2.grid(axis='y', alpha=0.3, linestyle='--')
ax2.tick_params(axis='x', rotation=45)

# Add value labels on bars
for bar, mem in zip(bars2, peak_memories):
    height = bar.get_height()
    ax2.text(bar.get_x() + bar.get_width()/2., height,
             f'{mem:.2f}', ha='center', va='bottom', fontsize=9, fontweight='bold')

# Plot 3: Speedup vs baseline_50
ax3 = axes[1, 0]
bars3 = ax3.bar(labels, speedups_50, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax3.axhline(y=1.0, color='red', linestyle='--', linewidth=2, label='Baseline (50 steps)')
ax3.set_ylabel('Speedup Factor', fontsize=12, fontweight='bold')
ax3.set_xlabel('Configuration', fontsize=12, fontweight='bold')
ax3.set_title('Speedup vs Baseline (50 steps)', fontsize=14, fontweight='bold')
ax3.grid(axis='y', alpha=0.3, linestyle='--')
ax3.tick_params(axis='x', rotation=45)
ax3.legend()

# Add value labels on bars
for bar, speedup in zip(bars3, speedups_50):
    height = bar.get_height()
    ax3.text(bar.get_x() + bar.get_width()/2., height,
             f'{speedup:.2f}x', ha='center', va='bottom', fontsize=9, fontweight='bold')

# Plot 4: Speedup vs baseline_22
ax4 = axes[1, 1]
bars4 = ax4.bar(labels, speedups_22, color=colors, alpha=0.8, edgecolor='black', linewidth=1.5)
ax4.axhline(y=1.0, color='orange', linestyle='--', linewidth=2, label='Baseline (22 steps)')
ax4.set_ylabel('Speedup Factor', fontsize=12, fontweight='bold')
ax4.set_xlabel('Configuration', fontsize=12, fontweight='bold')
ax4.set_title('Speedup vs Baseline (22 steps)', fontsize=14, fontweight='bold')
ax4.grid(axis='y', alpha=0.3, linestyle='--')
ax4.tick_params(axis='x', rotation=45)
ax4.legend()

# Add value labels on bars
for bar, speedup in zip(bars4, speedups_22):
    height = bar.get_height()
    ax4.text(bar.get_x() + bar.get_width()/2., height,
             f'{speedup:.2f}x', ha='center', va='bottom', fontsize=9, fontweight='bold')

plt.tight_layout()
plt.savefig("outputs/metrics_comparison.png", dpi=150, bbox_inches='tight')
plt.close()

print("\n" + "="*140)
print("Benchmark completed! Check the outputs/ folder for results and visualizations.")
print("="*140)
  2. Baseline, TaylorSeer, FirstBlockCache
Details
import torch
from diffusers import FluxPipeline, TaylorSeerCacheConfig, FirstBlockCacheConfig, FasterCacheConfig
import time
import os
import matplotlib.pyplot as plt
import pandas as pd
import gc  # Added for explicit garbage collection

# Set dynamo config
import torch._dynamo as dynamo
dynamo.config.recompile_limit = 200

prompts = [
    "Soaking wet tiger cub taking shelter under a banana leaf in the rainy jungle, close up photo",
]

# Create output folder
os.makedirs("outputs", exist_ok=True)

# Define cache configs
fbconfig = FirstBlockCacheConfig(
    threshold=0.05
)

tsconfig = TaylorSeerCacheConfig(
    cache_interval=5,
    max_order=1,
    disable_cache_before_step=10,
    disable_cache_after_step=48,
    taylor_factors_dtype=torch.bfloat16,
    use_lite_mode=True
)

# Collect results
results = []

for i, prompt in enumerate(prompts):
    images = {}
    for variant in ['baseline', 'baseline_reduce', 'firstblock', 'taylor']:
        # Clear cache before loading
        gc.collect()
        torch.cuda.empty_cache()
        
        # Load pipeline with timing
        start_load = time.time()
        pipeline = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        ).to("cuda")
        load_time = time.time() - start_load
        load_mem_gb = torch.cuda.memory_allocated() / (1024 ** 3)  # GB
        
        # Enable cache if applicable
        if variant == 'firstblock':
            pipeline.transformer.enable_cache(fbconfig)
        elif variant == 'taylor':
            pipeline.transformer.enable_cache(tsconfig)
        # No cache for baseline and baseline_reduce
        
        # Compile with timing
        start_compile = time.time()
        pipeline.transformer.compile_repeated_blocks(fullgraph=False)
        compile_time = time.time() - start_compile
        
        # Warmup with 5 steps
        gen_warmup = torch.Generator(device="cuda").manual_seed(181201)
        start_warmup = time.time()
        _ = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=5,
            guidance_scale=3.0,
            generator=gen_warmup
        ).images[0]
        warmup_time = time.time() - start_warmup
        
        # Main run
        steps = 22 if variant == 'baseline_reduce' else 50
        
        gen_main = torch.Generator(device="cuda").manual_seed(181201)
        
        torch.cuda.reset_peak_memory_stats()
        start_main = time.time()
        image = pipeline(
            prompt=prompt,
            width=1024,
            height=1024,
            num_inference_steps=steps,
            guidance_scale=3.0,
            generator=gen_main
        ).images[0]
        end_main = time.time()
        
        peak_mem_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)  # GB
        main_time = end_main - start_main
        
        # Save image
        image_path = f"outputs/{variant}_prompt{i}.jpg"
        image.save(image_path)
        images[variant] = image
        
        # Record results
        results.append({
            'Prompt Index': i,
            'Variant': variant,
            'Load Time (s)': load_time,
            'Load Memory (GB)': load_mem_gb,
            'Compile Time (s)': compile_time,
            'Warmup Time (s)': warmup_time,
            'Main Time (s)': main_time,
            'Peak Memory (GB)': peak_mem_gb
        })
        
        # Clean up
        pipeline.to("cpu")
        del pipeline
        gc.collect()  # Force garbage collection
        torch.cuda.empty_cache()  # Empty CUDA cache after GC
        dynamo.reset()  # Reset Dynamo cache (harmless even if not compiling)

    # Plot image comparison for this prompt
    fig, axs = plt.subplots(1, 4, figsize=(40, 10))
    variants_order = ['baseline', 'baseline_reduce', 'firstblock', 'taylor']
    for j, var in enumerate(variants_order):
        axs[j].imshow(images[var])
        axs[j].set_title(var)
        axs[j].axis('off')
    plt.tight_layout()
    plt.savefig(f"outputs/comparison_prompt{i}.png")
    plt.close()

# Print speed and memory comparison as a table
df = pd.DataFrame(results)
print("Speed and Memory Comparison:")
print(df.to_string(index=False))

# Optionally, plot bar charts for averages
avg_df = df.groupby('Variant').mean().reset_index()
fig, ax1 = plt.subplots(figsize=(10, 6))
ax1.bar(avg_df['Variant'], avg_df['Main Time (s)'], color='b', label='Main Time (s)')
ax1.set_ylabel('Main Time (s)')
ax2 = ax1.twinx()
ax2.plot(avg_df['Variant'], avg_df['Peak Memory (GB)'], color='r', marker='o', label='Peak Memory (GB)')
ax2.set_ylabel('Peak Memory (GB)')
fig.suptitle('Average Speed and Memory Comparison')
fig.legend()
plt.savefig("outputs/metrics_comparison.png")
plt.close()

@sayakpaul
Member

The graph breaks at `torch.compiler.disable`, which then requires recompiling every block, since the hook is applied at the block level. The default recompile limit is 8, while the transformer has far more blocks (57 in Flux), so we have to raise this limit above the number of blocks to get the best performance and graphs similar to regional compiling.

Yes, increasing the compile limit is fine here.

Some questions / notes:

  • In the code snippet provided in [Feat] TaylorSeer Cache #12648 (comment), why do we need dynamic=True?
  • Could we also add the compilation timing in here to see if that helps at all (especially with the recompilations)?
  • Let's try to add this comparison (just a link to your comment is fine) to the docs? I think this is golden information!

@sayakpaul
Member

@bot /style

@github-actions
Contributor

github-actions bot commented Dec 4, 2025

Style bot fixed some files and pushed the changes.

Collaborator

@DN6 DN6 left a comment


Great work @toilaluan. Small update to docstring and we should be good to merge once tests pass.

@toilaluan
Contributor Author

@DN6 @sayakpaul (cc @Shenyi-Z) I added a more intensive comparison for TaylorSeer in #12648 (comment)

An interesting finding is that increasing max_order doesn't give better results; max_order=0 (meaning we reuse the value computed at timestep i for the next N steps) still works well, but I think max_order=1 is optimal.

It also includes compilation timing.

@Shenyi-Z

Shenyi-Z commented Dec 5, 2025

@DN6 @sayakpaul (cc @Shenyi-Z) I added a more intensive comparison for TaylorSeer in #12648 (comment)

An interesting finding is that increasing max_order doesn't give better results; max_order=0 (meaning we reuse the value computed at timestep i for the next N steps) still works well, but I think max_order=1 is optimal.

It also includes compilation timing.

This is basically consistent with our recent experimental results. TaylorSeer essentially only simulates derivatives from differences; it is really just a simple prediction method. Excessively high orders are not stable enough and naturally introduce more numerical error, making the predictions erratic, while zeroth-order simple reuse leads to a noticeable lack of detail. In practical applications, TaylorSeer with order=1 often achieves relatively stable performance.
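The order trade-off can be seen in a toy sketch of the finite-difference predictor (hypothetical helper names, simplified from the actual implementation in `src/diffusers/hooks/taylorseer_cache.py`):

```python
import math
import torch

def update_taylor_factors(prev_factors, new_feature, dt, max_order=1):
    """Refresh finite-difference Taylor factors after a real compute.

    prev_factors[k] approximates the k-th derivative of the cached feature
    with respect to the timestep index; new_feature is the freshly computed value.
    """
    factors = [new_feature]
    for k in range(min(len(prev_factors), max_order)):
        # k-th order forward difference: (f_new^(k) - f_old^(k)) / dt
        factors.append((factors[k] - prev_factors[k]) / dt)
    return factors

def predict(factors, dt):
    """Extrapolate the feature dt steps past the last real compute."""
    out = torch.zeros_like(factors[0])
    for k, f in enumerate(factors):
        out = out + f * dt**k / math.factorial(k)
    return out

# For a feature that drifts linearly, a first-order predictor is exact:
f_prev, f_new = torch.tensor([0.0]), torch.tensor([2.0])  # slope 2 per step
factors = update_taylor_factors([f_prev], f_new, dt=1.0, max_order=1)
print(predict(factors, dt=1.0))
```

Each extra order adds another difference-of-differences term, which is why high orders amplify noise between real computes, while order 0 simply freezes the last computed feature.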

@sayakpaul
Member

@toilaluan I opened a PR here: toilaluan#1. I believe merging that should help us fix the CI issues.

@toilaluan
Contributor Author

I merged it

@sayakpaul
Member

Tests failing seem related. @toilaluan anything we're missing?

@toilaluan
Contributor Author

@sayakpaul Sorry, I'm not sure where to look; any suggestions?

@sayakpaul
Member

@toilaluan
Contributor Author

@sayakpaul It fails due to this issue: #12648 (comment)

Removing the test on Flux Kontext will help

@sayakpaul sayakpaul merged commit 6290fdf into huggingface:main Dec 6, 2025
26 of 28 checks passed
@sayakpaul
Member

An immense amount of thanks for shipping this! We will get back to you for the MVP stuff!

@toilaluan
Contributor Author

🤗

@Trgtuan10
Contributor

I love you @toilaluan

@sayakpaul sayakpaul mentioned this pull request Dec 8, 2025
onurxtasar added a commit to initml/diffusers that referenced this pull request Mar 24, 2026
* Fix broken group offloading with block_level for models with standalone layers (#12692)

* fix: group offloading to support standalone computational layers in block-level offloading

* test: for models with standalone and deeply nested layers in block-level offloading

* feat: support for block-level offloading in group offloading config

* fix: group offload block modules to AutoencoderKL and AutoencoderKLWan

* fix: update group offloading tests to use AutoencoderKL and adjust input dimensions

* refactor: streamline block offloading logic

* Apply style fixes

* update tests

* update

* fix for failing tests

* clean up

* revert to use skip_keys

* clean up

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <[email protected]>

* [Docs] Add Z-Image docs (#12775)

* initial

* toctree

* fix

* apply review and fix

* Update docs/source/en/api/pipelines/z_image.md

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/api/pipelines/z_image.md

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/api/pipelines/z_image.md

Co-authored-by: Steven Liu <[email protected]>

---------

Co-authored-by: Steven Liu <[email protected]>

* move kandisnky docs.

* [docs] minor fixes to kandinsky docs (#12797)

up

* Improve docstrings and type hints in scheduling_deis_multistep.py (#12796)

* feat: Add `flow_prediction` to `prediction_type`, introduce `use_flow_sigmas`, `flow_shift`, `use_dynamic_shifting`, and `time_shift_type` parameters, and refine type hints for various arguments.

* style: reformat argument wrapping in `_convert_to_beta` and `index_for_timestep` method signatures.

* [Feat] TaylorSeer Cache (#12648)

* init taylor_seer cache

* make compatible with any tuple size returned

* use logger for printing, add warmup feature

* still update in warmup steps

* refractor, add docs

* add configurable cache, skip compute module

* allow special cache ids only

* add stop_predicts (cooldown)

* update docs

* apply ruff

* update to handle multple calls per timestep

* refractor to use state manager

* fix format & doc

* chores: naming, remove redundancy

* add docs

* quality & style

* fix taylor precision

* Apply style fixes

* add tests

* Apply style fixes

* Remove TaylorSeerCacheTesterMixin from flux2 tests

* rename identifiers, use more expressive taylor predict loop

* torch compile compatible

* Apply style fixes

* Update src/diffusers/hooks/taylorseer_cache.py

Co-authored-by: Dhruv Nair <[email protected]>

* update docs

* make fix-copies

* fix example usage.

* remove tests on flux kontext

---------

Co-authored-by: toilaluan <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Dhruv Nair <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

* Update the TensorRT-ModelOPT to Nvidia-ModelOPT (#12793)

Update the naming

Co-authored-by: Sayak Paul <[email protected]>

* add post init for safety checker (#12794)

* add post init for safety checker

Signed-off-by: jiqing-feng <[email protected]>

* check transformers version before post init

Signed-off-by: jiqing-feng <[email protected]>

* Apply style fixes

---------

Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* [HunyuanVideo1.5] support step-distilled (#12802)

* support step-distilled

* style

* Add ZImageImg2ImgPipeline (#12751)

* Add ZImageImg2ImgPipeline

Updated the pipeline structure to include ZImageImg2ImgPipeline
    alongside ZImagePipeline.
Implemented the ZImageImg2ImgPipeline class for image-to-image
    transformations, including necessary methods for
    encoding prompts, preparing latents, and denoising.
Enhanced the auto_pipeline to map the new ZImageImg2ImgPipeline
    for image generation tasks.
Added unit tests for ZImageImg2ImgPipeline to ensure
    functionality and performance.
Updated dummy objects to include ZImageImg2ImgPipeline for
    testing purposes.

* Address review comments for ZImageImg2ImgPipeline

- Add `# Copied from` annotations to encode_prompt and _encode_prompt
- Add ZImagePipeline to auto_pipeline.py for AutoPipeline support

* Add ZImage pipeline documentation

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Álvaro Somoza <[email protected]>

* [PRX] Improve model compilation (#12787)

* Reimplement img2seq & seq2img in PRX to enable ONNX build without Col2Im (incompatible with TensorRT).

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <[email protected]>

* Improve docstrings and type hints in scheduling_dpmsolver_singlestep.py (#12798)

feat: add flow sigmas, dynamic shifting, and refine type hints in DPMSolverSinglestepScheduler

* [Modular]z-image (#12808)

* initial

* up up

* fix: z_image -> z-image

* style

* copy

* fix more

* some docstring fix

* Fix Qwen Edit Plus modular for multi-image input (#12601)

* try to fix qwen edit plus multi images (modular)

* up

* up

* test

* up

* up

* [WIP] Add Flux2 modular (#12763)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* [docs] improve distributed inference cp docs. (#12810)

* improve distributed inference cp docs.

* Apply suggestions from code review

Co-authored-by: Steven Liu <[email protected]>

---------

Co-authored-by: Steven Liu <[email protected]>

* post release 0.36.0 (#12804)

* post release 0.36.0

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Update distributed_inference.md to correct syntax (#12827)

* [lora] Remove lora docs unneeded and add " # Copied from ..." (#12824)

* remove unneeded docs on load_lora_weights().

* remove more.

* up

* up

* up

* support CP in native flash attention (#12829)

Signed-off-by: Wang, Yi <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

* [qwen-image] edit 2511 support (#12839)

* [qwen-image] edit 2511 support

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* fix pytest tests/pipelines/pixart_sigma/test_pixart.py::PixArtSigmaPi… (#12842)

fix pytest tests/pipelines/pixart_sigma/test_pixart.py::PixArtSigmaPipelineIntegrationTests::test_pixart_512 in xpu

Signed-off-by: Wang, Yi <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

* Support for control-lora (#10686)

* run control-lora on diffusers

* cannot load lora adapter

* test

* 1

* add control-lora

* 1

* 1

* 1

* fix PeftAdapterMixin

* fix module_to_save bug

* delete json print

* resolve conflicts

* merged but bug

* change peft.py

* 1

* delete state_dict print

* fix alpha

* Create control_lora.py

* Add files via upload

* rename

* no need modify as peft updated

* add doc

* fix code style

* styling isn't that hard 😉

* empty

---------

Co-authored-by: Sayak Paul <[email protected]>

* Add support for LongCat-Image (#12828)

* Add LongCat-Image

* Update src/diffusers/models/transformers/transformer_longcat_image.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/models/transformers/transformer_longcat_image.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/models/transformers/transformer_longcat_image.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/models/transformers/transformer_longcat_image.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py

Co-authored-by: YiYi Xu <[email protected]>

* fix code

* add doc

* Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image_edit.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/pipelines/longcat_image/pipeline_longcat_image.py

Co-authored-by: YiYi Xu <[email protected]>

* fix code & make style & fix-copies

* Apply style fixes

* fix single input rewrite error

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: hadoop-imagen <hadoop-imagen@psxfb7pxrbvmh3oq-worker-0.psxfb7pxrbvmh3oq.hadoop-aipnlp.svc.cluster.local>

* fix the prefix_token_len bug (#12845)

* extend TorchAoTest::test_model_memory_usage to other platform (#12768)

* extend TorchAoTest::test_model_memory_usage to other platform

Signed-off-by: Wang, Yi <[email protected]>

* add some comments

Signed-off-by: Wang, Yi <[email protected]>

---------

Signed-off-by: Wang, Yi <[email protected]>

* Qwen Image Layered Support (#12853)

* [qwen-image] qwen image layered support

* [qwen-image] update doc

* [qwen-image] fix pr comments

* Apply style fixes

* make fix-copies

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <[email protected]>

* Z-Image-Turbo ControlNet (#12792)

* init

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Cosmos Predict2.5 Base: inference pipeline, scheduler & chkpt conversion (#12852)

* cosmos predict2.5 base: convert chkpt & pipeline
- New scheduler: scheduling_flow_unipc_multistep.py
- Changes to TransformerCosmos for text embeddings via crossattn_proj

* scheduler cleanup

* simplify inference pipeline

* cleanup scheduler + tests

* Basic tests for flow unipc

* working b2b inference

* Rename everything

* Tests for pipeline present, but not working (predict2 also not working)

* docstring update

* wrapper pipelines + make style

* remove unnecessary files

* UniPCMultistep: support use_karras_sigmas=True and use_flow_sigmas=True

* use UniPCMultistepScheduler + fix tests for pipeline

* Remove FlowUniPCMultistepScheduler

* UniPCMultistepScheduler for use_flow_sigmas=True & use_karras_sigmas=True

* num_inference_steps=36 due to bug in scheduler used by predict2.5

* Address comments

* make style + make fix-copies

* fix tests + remove references to old pipelines

* address comments

* add revision in from_pretrained call

* fix tests

* more update in modular  (#12560)

* move node registry to mellon

* up

* fix

* modular pipeline update: filter out None for input_names, fix default blocks for pipe.init(), and allow users to pass additional kwargs_type in a dict

* qwen modular refactor, unpack before decode

* update mellon node config, adding* to required_inputs and required_model_inputs

* modularpipeline.from_pretrained: error out if no config found

* add a component_names property to modular blocks to be consistent!

* flux image_encoder -> vae_encoder

* controlnet_bundle

* refactor MellonNodeConfig MellonPipelineConfig

* refactor & simplify mellon utils

* vae_image_encoder -> vae_encoder

* mellon config save keep key order

* style + copies

* add kwargs input for zimage

* Feature: Add Mambo-G Guidance as Guider (#12862)

* Feature: Add Mambo-G Guidance to Qwen-Image Pipeline

* change to guider implementation

* fix copied code residual

* Update src/diffusers/guiders/magnitude_aware_guidance.py

* Apply style fixes

---------

Co-authored-by: Pscgylotti <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Add `OvisImagePipeline` in `AUTO_TEXT2IMAGE_PIPELINES_MAPPING` (#12876)

* Cosmos Predict2.5 14b Conversion (#12863)

14b conversion

* Use `T5Tokenizer` instead of `MT5Tokenizer` (removed in Transformers v5.0+) (#12877)

Use `T5Tokenizer` instead of `MT5Tokenizer`

Given that the `MT5Tokenizer` in `transformers` is just a re-export of
`T5Tokenizer` (see
https://github.com/huggingface/transformers/blob/v4.57.3/src/transformers/models/mt5/tokenization_mt5.py
on the latest available stable Transformers, i.e. v4.57.3), this commit
updates the imports to point to `T5Tokenizer` instead, so that they
still work from Transformers v5.0.0rc0 onwards.

* Add z-image-omni-base implementation (#12857)

* Add z-image-omni-base implementation

* Merged into one transformer for Z-Image.

* Fix bugs for controlnet after merging the main branch new feature.

* Fix for auto_pipeline, Add Styling.

* Refactor noise handling and modulation

- Add select_per_token function for per-token value selection
- Separate adaptive modulation logic
- Clean up t_noisy/clean variable naming
- Move image_noise_mask handler from forward to pipeline

* Styling & Formatting.

* Rewrite code with more non-forward func & clean forward.

1.Change to one forward with shorter code with omni code (None).
2.Split out non-forward funcs: _build_unified_sequence, _prepare_sequence, patchify, pad.

* Styling & Formatting.

* Manual check fix-copies in controlnet, Add select_per_token, _patchify_image, _pad_with_ids; Styling.

* Add Import in pipeline __init__.py.

---------

Co-authored-by: Jerry Qilong Wu <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>

* fix torchao quantizer for new torchao versions (#12901)

* fix torchao quantizer for new torchao versions

Summary:

`torchao==0.16.0` (not yet released) has some bc-breaking changes, this
PR fixes the diffusers repo with those changes. Specifics on the
changes:
1. `UInt4Tensor` is removed: https://github.com/pytorch/ao/pull/3536
2. old float8 tensors v1 are removed: https://github.com/pytorch/ao/pull/3510

In this PR:
1. move the logger variable up (not sure why it was in the middle of the
   file before) to get better error messages
2. gate the old torchao objects by torchao version

Test Plan:

import diffusers objects with new versions of torchao works:

```bash
> python -c "import torchao; print(torchao.__version__); from diffusers import StableDiffusionPipeline"
0.16.0.dev20251229+cu129
```
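
The version gating described above (parse the installed torchao version, then expose or hide the removed objects) can be sketched without depending on torchao itself; `release_tuple` and `supports_legacy_float8` below are illustrative names, not the PR's actual identifiers:

```python
def release_tuple(version_str):
    """Parse only the numeric release segment of a version string.

    "0.16.0.dev20251229+cu129" -> (0, 16, 0), so dev/local suffixes
    don't break the comparison used to gate old torchao objects.
    """
    parts = []
    for piece in version_str.split("+")[0].split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def supports_legacy_float8(torchao_version):
    # Old float8 tensors (v1) were removed as of torchao 0.16.0.
    return release_tuple(torchao_version) < (0, 16, 0)
```

Truncating at the first non-numeric segment means a `0.16.0.dev…` pre-release is treated like `0.16.0`, which is the desired behavior here since the removals already landed in the dev builds.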

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* fix Qwen Image Transformer single file loading mapping function to be consistent with other loader APIs (#12894)

fix Qwen single file loading to be consistent with other loader API

* Z-Image-Turbo from_single_file fix (#12888)

* chore: fix dev version in setup.py (#12904)

* Community Pipeline: Add z-image differential img2img (#12882)

* Community Pipeline: Add z-image differential img2img

* add pipeline for z-image differential img2img diffusion examples: run make style, make quality, and fix whitespace in the example docstring.

---------

Co-authored-by: r4inm4ker <[email protected]>

* Fix typo in src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py (#12914)

* Fix wan 2.1 i2v context parallel (#12909)

* fix wan 2.1 i2v context parallel

* fix wan 2.1 i2v context parallel

* fix wan 2.1 i2v context parallel

* format

* fix the use of device_map in CP docs (#12902)

up

* [core] remove unneeded autoencoder methods when subclassing from `AutoencoderMixin` (#12873)

up

* Detect 2.0 vs 2.1 ZImageControlNetModel (#12861)

* Detect 2.0 vs 2.1 ZImageControlNetModel

* Possibility of control_noise_refiner being removed

* Refactor environment variable assignments in workflow (#12916)

* Add codeQL workflow (#12917)

Updated CodeQL workflow to use reusable workflow from Hugging Face and simplified language matrix.

* Delete .github/workflows/codeql.yml

* CodeQL workflow for security analysis

* Check for attention mask in backends that don't support it (#12892)

* check attention mask

* Apply style fixes

* bugfix

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <[email protected]>

* [Flux.1] improve pos embed for ascend npu by computing on npu (#12897)

* [Flux.1] improve pos embed for ascend npu by setting it back to npu computation.

* [Flux.2] improve pos embed for ascend npu by setting it back to npu computation.

* [LongCat-Image] improve pos embed for ascend npu by setting it back to npu computation.

* [Ovis-Image] improve pos embed for ascend npu by setting it back to npu computation.

* Remove unused import of is_torch_npu_available

---------

Co-authored-by: zhangtao <[email protected]>

* LTX Video 0.9.8  long multi prompt (#12614)

* LTX Video 0.9.8  long multi prompt

* Further align comfyui

- Added the “LTXEulerAncestralRFScheduler” scheduler, aligned with [sample_euler_ancestral_RF](https://github.com/comfyanonymous/ComfyUI/blob/7d6103325e1c97aa54f963253e3e7f1d6da6947f/comfy/k_diffusion/sampling.py#L234)

- Updated the LTXI2VLongMultiPromptPipeline.from_pretrained() method:
  - Now uses LTXEulerAncestralRFScheduler by default, for better compatibility with the ComfyUI LTXV workflow.

- Changed the default value of cond_strength from 1.0 to 0.5, aligning with ComfyUI’s default.

- Optimized cross-window overlap blending: moved the latent-space guidance injection to before the UNet and after each step, aligned with [KSamplerX0Inpaint](https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/samplers.py#L391)

- Adjusted the default value of skip_steps_sigma_threshold to 1.

* align with diffusers contribute rule

* Add new pipelines and update imports

* Enhance LTXI2VLongMultiPromptPipeline with noise rescaling

Refactor LTXI2VLongMultiPromptPipeline to improve documentation and add noise rescaling functionality.

* Clean up comments in scheduling_ltx_euler_ancestral_rf.py

Removed design notes and limitations from the implementation.

* Enhance video generation example with scheduler

Updated LTXI2VLongMultiPromptPipeline example to include LTXEulerAncestralRFScheduler for ComfyUI parity.

* clean up

* style

* copies

* import ltx scheduler

* copies

* fix

* fix more

* up up

* up up up

* up upup

* Apply suggestions from code review

* Update docs/source/en/api/pipelines/ltx_video.md

* Update docs/source/en/api/pipelines/ltx_video.md

---------

Co-authored-by: yiyixuxu <[email protected]>

* Add FSDP option for Flux2 (#12860)

* Add FSDP option for Flux2

* Apply style fixes

* Add FSDP option for Flux2

* Add FSDP option for Flux2

* Add FSDP option for Flux2

* Add FSDP option for Flux2

* Add FSDP option for Flux2

* Update examples/dreambooth/README_flux2.md

* guard accelerate import.

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Add transformer cache context for SkyReels-V2 pipelines & Update docs (#12837)

* feat: Add transformer cache context for conditional and unconditional predictions for skyreels-v2 pipes.

* docs: Remove SkyReels-V2 FLF2V model link and add contributor attribution.

* [docs] fix torchao typo. (#12883)

fix torchao typo.

* Update wan.md to remove unneeded hfoptions (#12890)

* Improve docstrings and type hints in scheduling_edm_euler.py (#12871)

* docs: add comprehensive docstrings and refine type hints for EDM scheduler methods and config parameters.

* refactor: Add type hints to DPM-Solver scheduler methods.

* [Modular] Video for Mellon (#12924)

num_frames and videos

* Add LTX 2.0 Video Pipelines (#12915)

* Initial LTX 2.0 transformer implementation

* Add tests for LTX 2 transformer model

* Get LTX 2 transformer tests working

* Rename LTX 2 compile test class to have LTX2

* Remove RoPE debug print statements

* Get LTX 2 transformer compile tests passing

* Fix LTX 2 transformer shape errors

* Initial script to convert LTX 2 transformer to diffusers

* Add more LTX 2 transformer audio arguments

* Allow LTX 2 transformer to be loaded from local path for conversion

* Improve dummy inputs and add test for LTX 2 transformer consistency

* Fix LTX 2 transformer bugs so consistency test passes

* Initial implementation of LTX 2.0 video VAE

* Explicitly specify temporal and spatial VAE scale factors when converting

* Add initial LTX 2.0 video VAE tests

* Add initial LTX 2.0 video VAE tests (part 2)

* Get diffusers implementation on par with official LTX 2.0 video VAE implementation

* Initial LTX 2.0 vocoder implementation

* Use RMSNorm implementation closer to original for LTX 2.0 video VAE

* start audio decoder.

* init registration.

* up

* simplify and clean up

* up

* Initial LTX 2.0 text encoder implementation

* Rough initial LTX 2.0 pipeline implementation

* up

* up

* up

* up

* Add imports for LTX 2.0 Audio VAE

* Conversion script for LTX 2.0 Audio VAE Decoder

* Add Audio VAE logic to T2V pipeline

* Duplicate scheduler for audio latents

* Support num_videos_per_prompt for prompt embeddings

* LTX 2.0 scheduler and full pipeline conversion

* Add script to test full LTX2Pipeline T2V inference

* Fix pipeline return bugs

* Add LTX 2 text encoder and vocoder to ltx2 subdirectory __init__

* Fix more bugs in LTX2Pipeline.__call__

* Improve CPU offload support

* Fix pipeline audio VAE decoding dtype bug

* Fix video shape error in full pipeline test script

* Get LTX 2 T2V pipeline to produce reasonable outputs

* Make LTX 2.0 scheduler more consistent with original code

* Fix typo when applying scheduler fix in T2V inference script

* Refactor Audio VAE to be simpler and remove helpers (#7)

* remove resolve causality axes stuff.

* remove a bunch of helpers.

* remove adjust output shape helper.

* remove the use of audiolatentshape.

* move normalization and patchify out of pipeline.

* fix

* up

* up

* Remove unpatchify and patchify ops before audio latents denormalization (#9)

---------

Co-authored-by: dg845 <[email protected]>

* Add support for I2V (#8)

* start i2v.

* up

* up

* up

* up

* up

* remove uniform strategy code.

* remove unneeded code.

* Denormalize audio latents in I2V pipeline (analogous to T2V change) (#11)

* test i2v.

* Move Video and Audio Text Encoder Connectors to Transformer (#12)

* Denormalize audio latents in I2V pipeline (analogous to T2V change)

* Initial refactor to put video and audio text encoder connectors in transformer

* Get LTX 2 transformer tests working after connector refactor

* precompute run_connectors.

* fixes

* Address review comments

* Calculate RoPE double precisions freqs using torch instead of np

* Further simplify LTX 2 RoPE freq calc

* Make connectors a separate module (#18)

* remove text_encoder.py

* address yiyi's comments.

* up

* up

* up

* up

---------

Co-authored-by: sayakpaul <[email protected]>

* up (#19)

* address initial feedback from lightricks team (#16)

* cross_attn_timestep_scale_multiplier to 1000

* implement split rope type.

* up

* propagate rope_type to rope embed classes as well.

* up

* When using split RoPE, make sure that the output dtype is same as input dtype

* Fix apply split RoPE shape error when reshaping x to 4D

* Add export_utils file for exporting LTX 2.0 videos with audio

* Tests for T2V and I2V (#6)

* add ltx2 pipeline tests.

* up

* up

* up

* up

* remove content

* style

* Denormalize audio latents in I2V pipeline (analogous to T2V change)

* Initial refactor to put video and audio text encoder connectors in transformer

* Get LTX 2 transformer tests working after connector refactor

* up

* up

* i2v tests.

* up

* Address review comments

* Calculate RoPE double precisions freqs using torch instead of np

* Further simplify LTX 2 RoPE freq calc

* revert unneeded changes.

* up

* up

* update to split style rope.

* up

---------

Co-authored-by: Daniel Gu <[email protected]>

* up

* use export util funcs.

* Point original checkpoint to LTX 2.0 official checkpoint

* Allow the I2V pipeline to accept image URLs

* make style and make quality

* remove function map.

* remove args.

* update docs.

* update doc entries.

* disable ltx2_consistency test

* Simplify LTX 2 RoPE forward by removing coords is None logic

* make style and make quality

* Support LTX 2.0 audio VAE encoder

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* Remove print statement in audio VAE

* up

* Fix bug when calculating audio RoPE coords

* Ltx 2 latent upsample pipeline (#12922)

* Initial implementation of LTX 2.0 latent upsampling pipeline

* Add new LTX 2.0 spatial latent upsampler logic

* Add test script for LTX 2.0 latent upsampling

* Add option to enable VAE tiling in upsampling test script

* Get latent upsampler working with video latents

* Fix typo in BlurDownsample

* Add latent upsample pipeline docstring and example

* Remove deprecated pipeline VAE slicing/tiling methods

* make style and make quality

* When returning latents, return unpacked and denormalized latents for T2V and I2V

* Add model_cpu_offload_seq for latent upsampling pipeline

---------

Co-authored-by: Daniel Gu <[email protected]>

* Fix latent upsampler filename in LTX 2 conversion script

* Add latent upsample pipeline to LTX 2 docs

* Add dummy objects for LTX 2 latent upsample pipeline

* Set default FPS to official LTX 2 ckpt default of 24.0

* Set default CFG scale to official LTX 2 ckpt default of 4.0

* Update LTX 2 pipeline example docstrings

* make style and make quality

* Remove LTX 2 test scripts

* Fix LTX 2 upsample pipeline example docstring

* Add logic to convert and save a LTX 2 upsampling pipeline

* Document LTX2VideoTransformer3DModel forward pass

---------

Co-authored-by: sayakpaul <[email protected]>

* Add environment variables to checkout step (#12927)

* Improve docstrings and type hints in scheduling_consistency_decoder.py (#12928)

docs: improve docstring scheduling_consistency_decoder.py

* Fix: Remove hardcoded CUDA autocast in Kandinsky 5 to fix import warning (#12814)

* Fix: Remove hardcoded CUDA autocast in Kandinsky 5 to fix import warning

* Apply style fixes

* Fix: Remove import-time autocast in Kandinsky to prevent warnings

- Removed @torch.autocast decorator from Kandinsky classes.
- Implemented manual F.linear casting to ensure numerical parity with FP32.
- Verified bit-exact output matches main branch.

Co-authored-by: hlky <[email protected]>

* Used _keep_in_fp32_modules to align with standards

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: hlky <[email protected]>

* Upgrade GitHub Actions for Node 24 compatibility (#12865)

Signed-off-by: Salman Muin Kayser Chishti <[email protected]>

* fix the warning torch_dtype is deprecated (#12841)

* fix the warning torch_dtype is deprecated

* Add transformers version check (>= 4.56.0) for dtype parameter

* Fix linting errors

* [NPU] npu attention enable ulysses (#12919)

* npu attention enable ulysses

* clean the format

* register _native_npu_attention to _supports_context_parallel

Signed-off-by: yyt <[email protected]>

* change npu_fusion_attention's input_layout to BSND to eliminate redundant transpose

Signed-off-by: yyt <[email protected]>

* Update format

---------

Signed-off-by: yyt <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

* Torchao floatx version guard (#12923)

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

---------

Co-authored-by: Sayak Paul <[email protected]>

* Bugfix for dreambooth flux2 img2img (#12825)

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>

* [Modular] qwen refactor (#12872)

* 3 files

* add conditoinal pipeline

* refactor qwen modular

* add layered

* up up

* up

* add to import

* more refactor, make layer work

* clean up a bit git add src

* more

* style

* style

* [modular] Tests for custom blocks in modular diffusers (#12557)

* start custom block testing.

* simplify modular workflow ci.

* up

* style.

* up

* up

* up

* up

* up

* up

* Apply suggestions from code review

* up

* [chore] remove controlnet implementations outside controlnet module. (#12152)

* remove controlnet implementations outside controlnet module.

* fix

* fix

* fix

* [core] Handle progress bar and logging in distributed environments (#12806)

* disable progressbar in distributed.

* up

* up

* up

* up

* up

* up

* Improve docstrings and type hints in scheduling_consistency_models.py (#12931)

docs: improve docstring scheduling_consistency_models.py

* [Feature] MultiControlNet support for SD3Impainting (#11251)

* update

* update

* addressed PR comments

* update

* Apply suggestions from code review

---------

Co-authored-by: YiYi Xu <[email protected]>

* Laplace Scheduler for DDPM (#11320)

* Add Laplace scheduler that samples more around mid-range noise levels (around log SNR = 0), improving performance (lower FID) with faster convergence, and remaining robust across resolutions and objectives. Reference: https://arxiv.org/pdf/2407.03297.

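As a rough illustration of the idea in that commit (not the PR's actual implementation), sampling can be done by inverse-CDF sampling a Laplace distribution over log-SNR and then mapping the value back to a discrete timestep; `laplace_timestep`, `mu`, and `b` are hypothetical names, and the logistic squash is an illustrative mapping only:

```python
import math
import random


def laplace_timestep(mu=0.0, b=1.0, num_train_timesteps=1000, rng=None):
    """Sample a timestep whose log-SNR is Laplace(mu, b)-distributed.

    Inverse-CDF sampling of the Laplace distribution concentrates draws
    around log SNR = mu; a logistic squash then maps the continuous
    log-SNR onto a discrete timestep index (high SNR -> early timestep).
    """
    rng = rng or random
    u = rng.random() - 0.5
    inner = max(1.0 - 2.0 * abs(u), 1e-12)  # guard against log(0)
    log_snr = mu - b * math.copysign(1.0, u) * math.log(inner)
    frac = 1.0 / (1.0 + math.exp(log_snr))  # sigmoid(-log_snr) in (0, 1)
    return min(int(frac * num_train_timesteps), num_train_timesteps - 1)
```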
* Fix copies.

* Apply style fixes

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Store vae.config.scaling_factor to prevent missing attr reference (sdxl advanced dreambooth training script) (#12346)

Store vae.config.scaling_factor to prevent missing attr reference

In the SDXL advanced DreamBooth training script,
`vae.config.scaling_factor` becomes inaccessible after `del vae`
when `--cache_latents` is used and no `--validation_prompt` is given.

Co-authored-by: Teriks <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
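
The fix pattern in that commit amounts to caching the config scalar before the model is freed; the classes below are toy stand-ins for the real VAE, kept only to show the ordering:

```python
class ToyConfig:
    scaling_factor = 0.13025  # the SDXL VAE's scaling factor


class ToyVAE:
    config = ToyConfig()


vae = ToyVAE()

# Cache the scalar BEFORE deleting the VAE (the training script deletes
# it when --cache_latents is set and no --validation_prompt keeps it alive).
vae_scaling_factor = vae.config.scaling_factor
del vae

# Later, latents can still be scaled without the (now deleted) VAE.
latents = [0.5, -1.0]
scaled = [x * vae_scaling_factor for x in latents]
```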

* Add thread-safe wrappers for components in pipeline (examples/server-async/utils/requestscopedpipeline.py) (#12515)

* Basic implementation of request scheduling

* Basic editing in SD and Flux Pipelines

* Small Fix

* Fix

* Update for more pipelines

* Add examples/server-async

* Add examples/server-async

* Updated RequestScopedPipeline to handle a single tokenizer lock to avoid race conditions

* Fix

* Fix _TokenizerLockWrapper

* Fix _TokenizerLockWrapper

* Delete _TokenizerLockWrapper

* Fix tokenizer

* Update examples/server-async

* Fix server-async

* Optimizations in examples/server-async

* We keep the implementation simple in examples/server-async

* Update examples/server-async/README.md

* Update examples/server-async/README.md for changes to tokenizer locks and backward-compatible retrieve_timesteps

* The changes to the diffusers core have been undone and all logic is being moved to examples/server-async

* Update examples/server-async/utils/*

* Fix BaseAsyncScheduler

* Rollback in the core of the diffusers

* Update examples/server-async/README.md

* Complete rollback of diffusers core files

* Simple implementation of an asynchronous server compatible with SD3-3.5 and Flux Pipelines

* Update examples/server-async/README.md

* Fixed import errors in 'examples/server-async/serverasync.py'

* Flux Pipeline Discard

* Update examples/server-async/README.md

* Apply style fixes

* Add thread-safe wrappers for components in pipeline

Refactor requestscopedpipeline.py to add thread-safe wrappers for tokenizer, VAE, and image processor. Introduce locking mechanisms to ensure thread safety during concurrent access.

* Add wrappers.py

* Apply style fixes

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* [Research] Latent Perceptual Loss (LPL) for Stable Diffusion XL (#11573)

* initial

* added readme

* fix formatting

* added logging

* formatting

* use config

* debug

* better

* handle SNR

* floats have no item()

* remove debug

* formatting

* add paper link

* acknowledge reference source

* rename script

---------

Co-authored-by: Sayak Paul <[email protected]>

* Change timestep device to cpu for xla (#11501)

* Change timestep device to cpu for xla

* Add all pipelines

* ruff format

* Apply style fixes

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* [LoRA] add lora_alpha to sana README (#11780)

add lora alpha to readme

* Fix wrong param types, docs, and handles noise=None in scale_noise of FlowMatching schedulers (#11669)

* Bug: Fix wrong params, docs, and handles noise=None

* make noise a required arg

---------

Co-authored-by: YiYi Xu <[email protected]>

* [docs] Remote inference (#12372)

* init

* fix

* Align HunyuanVideoConditionEmbedding with CombinedTimestepGuidanceTextProjEmbeddings (#12316)

conditioning additions inline with CombinedTimestepGuidanceTextProjEmbeddings

Co-authored-by: Samu Tamminen <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>

* [Fix] syntax in QwenImageEditPlusPipeline (#12371)

* Fixes syntax for consistency among pipelines

* Update test_qwenimage_edit_plus.py

* Fix ftfy name error in Wan pipeline (#12314)

Signed-off-by: Daniel Socek <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>

* [modular] error early in `enable_auto_cpu_offload` (#12578)

error early in auto_cpu_offload

* [ChronoEdit] support multiple loras (#12679)

Co-authored-by: wjay <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>

* fix how `is_fsdp` is determined (#12960)

up

* [LoRA] add LoRA support to LTX-2 (#12933)

* up

* fixes

* tests

* docs.

* fix

* change loading info.

* up

* up

* Fix: typo in autoencoder_dc.py (#12687)

Fix typo in autoencoder_dc.py

Fixing typo in `get_block` function's parameter name:
"qkv_mutliscales" -> "qkv_multiscales"

Co-authored-by: YiYi Xu <[email protected]>

* [Modular] better docstring (#12932)

add output to auto blocks + core denoising block for better doc string

* [docs] polish caching docs. (#12684)

* polish caching docs.

* Update docs/source/en/optimization/cache.md

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/optimization/cache.md

Co-authored-by: Steven Liu <[email protected]>

* up

---------

Co-authored-by: Steven Liu <[email protected]>

* Fix typos (#12705)

* Fix link to diffedit implementation reference (#12708)

* Fix QwenImage txt_seq_lens handling (#12702)

* Fix QwenImage txt_seq_lens handling

* formatting

* formatting

* remove txt_seq_lens and use bool  mask

* use compute_text_seq_len_from_mask

* add seq_lens to dispatch_attention_fn

* use joint_seq_lens

* remove unused index_block

* WIP: Remove seq_lens parameter and use mask-based approach

- Remove seq_lens parameter from dispatch_attention_fn
- Update varlen backends to extract seqlens from masks
- Update QwenImage to pass 2D joint_attention_mask
- Fix native backend to handle 2D boolean masks
- Fix sage_varlen seqlens_q to match seqlens_k for self-attention

Note: sage_varlen still producing black images, needs further investigation
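The mask-based approach above hinges on backends recovering per-sequence lengths from a 2D boolean padding mask instead of taking an explicit `seq_lens` argument. A plain-Python sketch of that extraction (illustrative only, not the actual diffusers kernel code; assumes right-padding, i.e. each row is valid tokens followed by padding):

```python
# Illustrative sketch: derive varlen-style (seqlens, cu_seqlens) from a 2D
# boolean attention mask, where True marks a valid (non-padded) token.
from itertools import accumulate


def seqlens_from_mask(mask):
    """mask: list of per-sample rows of bools -> (seqlens, cu_seqlens)."""
    # Valid length of each sequence is just the count of True entries per row.
    seqlens = [sum(row) for row in mask]
    # Cumulative offsets, the form varlen attention kernels typically expect.
    cu_seqlens = [0] + list(accumulate(seqlens))
    return seqlens, cu_seqlens


mask = [
    [True, True, True, False],
    [True, True, False, False],
]
seqlens, cu_seqlens = seqlens_from_mask(mask)
# seqlens == [3, 2], cu_seqlens == [0, 3, 5]
```

For self-attention (as the sage_varlen fix above notes), the query-side lengths must match the key-side lengths derived this way.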

* fix formatting

* undo sage changes

* xformers support

* hub fix

* fix torch compile issues

* fix tests

* use _prepare_attn_mask_native

* proper deprecation notice

* add deprecate to txt_seq_lens

* Update src/diffusers/models/transformers/transformer_qwenimage.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/models/transformers/transformer_qwenimage.py

Co-authored-by: YiYi Xu <[email protected]>

* Only create the mask if there's actual padding

* fix order of docstrings

* Adds performance benchmarks and optimization details for QwenImage

Enhances documentation with comprehensive performance insights for QwenImage pipeline:

* rope_text_seq_len = text_seq_len

* rename to max_txt_seq_len

* removed deprecated args

* undo unrelated change

* Updates QwenImage performance documentation

Removes detailed attention backend benchmarks and simplifies torch.compile performance description

Focuses on key performance improvement with torch.compile, highlighting the specific speedup from 4.70s to 1.93s on an A100 GPU

Streamlines the documentation to provide more concise and actionable performance insights

* Updates deprecation warnings for txt_seq_lens parameter

Extends deprecation timeline for txt_seq_lens from version 0.37.0 to 0.39.0 across multiple Qwen image-related models

Adds a new unit test to verify the deprecation warning behavior for the txt_seq_lens parameter
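The deprecation pattern described above can be sketched like this. It is a hedged illustration, not the actual diffusers `deprecate` helper: the old `txt_seq_lens` argument is still accepted but triggers a warning naming the removal version, and a test can assert the warning fires.

```python
# Hypothetical sketch of deprecating a keyword argument: warn when the
# soon-to-be-removed `txt_seq_lens` is passed, while still accepting it.
import warnings


def forward(hidden_states, txt_seq_lens=None, attention_mask=None):
    if txt_seq_lens is not None:
        warnings.warn(
            "`txt_seq_lens` is deprecated and will be removed in version 0.39.0; "
            "pass an `attention_mask` instead.",
            FutureWarning,
        )
    return hidden_states


# Unit-test-style check that the warning is actually emitted.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    forward("x", txt_seq_lens=[4])
assert any("txt_seq_lens" in str(w.message) for w in caught)
```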

* fix compile

* formatting

* fix compile tests

* rename helper

* remove duplicate

* smaller values

* removed

* use torch.cond for torch compile

* Construct joint attention mask once

* test different backends

* construct joint attention mask once to avoid reconstructing in every block

* Update src/diffusers/models/attention_dispatch.py

Co-authored-by: YiYi Xu <[email protected]>

* formatting

* raising an error from the EditPlus pipeline when batch_size > 1

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: cdutr <[email protected]>

* Bugfix for flux2 img2img2 prediction (#12855)

* Bugfix for dreambooth flux2 img2img2

* Bugfix for dreambooth flux2 img2img2

* Bugfix for dreambooth flux2 img2img2

* Bugfix for dreambooth flux2 img2img2

* Bugfix for dreambooth flux2 img2img2

* Bugfix for dreambooth flux2 img2img2

Co-authored-by: tcaimm <[email protected]>

---------

Co-authored-by: tcaimm <[email protected]>

* Add Flag to `PeftLoraLoaderMixinTests` to Enable/Disable Text Encoder LoRA Tests (#12962)

* Improve incorrect LoRA format error message

* Add flag in PeftLoraLoaderMixinTests to disable text encoder LoRA tests

* Apply changes to LTX2LoraTests

* Further improve incorrect LoRA format error msg following review

---------

Co-authored-by: Sayak Paul <[email protected]>

* Add Unified Sequence Parallel attention (#12693)

* initial scheme of unified-sp

* initial all_to_all_double

* bug fixes, added cmnts

* unified attention prototype done

* remove raising value error in contextParallelConfig to enable unified attention

* bug fix

* feat: Adds Test for Unified SP Attention and Fixes a bug in Template Ring Attention

* bug fix, lse calculation, testing

bug fixes, lse calculation

-

switched to _all_to_all_single helper in _all_to_all_dim_exchange due to contiguity issues

bug fix

bug fix

bug fix

* addressing comments

* sequence parallelsim bug fixes

* code format fixes

* Apply style fixes

* code formatting fix

* added unified attention docs and removed test file

* Apply style fixes

* tip for unified attention in docs at distributed_inference.md

Co-authored-by: Sayak Paul <[email protected]>

* Update distributed_inference.md, adding benchmarks

Co-authored-by: Sayak Paul <[email protected]>

* Update docs/source/en/training/distributed_inference.md

Co-authored-by: Sayak Paul <[email protected]>

* function name fix

* fixed benchmark in docs

---------

Co-authored-by: KarthikSundar2002 <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <[email protected]>

* [Modular] Changes for using WAN I2V (#12959)

* initial

* add layers

* Z rz rz rz rz rz rz r cogview (#12973)

* init

* add

* add 1

* Update __init__.py

* rename

* 2

* update

* init with encoder

* merge2pipeline

* Update pipeline_glm_image.py

* remove sop

* remove useless func

* Update pipeline_glm_image.py

* up

(cherry picked from commit cfe19a31b9bc14c090c7259c09f3532dfafcd059)

* review for work only

* change place

* Update pipeline_glm_image.py

* update

* Update transformer_glm_image.py

* 1

* no  negative_prompt for GLM-Image

* remove CogView4LoraLoaderMixin

* refactor attention processor.

* update

* fix

* use staticmethod

* update

* up

* up

* update

* Update glm_image.md

* 1

* Update pipeline_glm_image.py

* Update transformer_glm_image.py

* using new transformers impl

* support

* resolution change

* fix-copies

* Update src/diffusers/pipelines/glm_image/pipeline_glm_image.py

Co-authored-by: YiYi Xu <[email protected]>

* Update pipeline_glm_image.py

* use cogview4

* Update pipeline_glm_image.py

* Update pipeline_glm_image.py

* revert

* update

* batch support

* update

* version guard glm image pipeline

* validate prompt_embeds and prior_token_ids

* try docs.

* 4

* up

* up

* skip properly

* fix tests

* up

* up

---------

Co-authored-by: zRzRzRzRzRzRzR <[email protected]>
Co-authored-by: yiyixuxu <[email protected]>

* Update distributed_inference.md to reposition sections (#12971)

* [chore] make transformers version check stricter for glm image. (#12974)

* make transformers version check stricter for glm image.

* public checkpoint.

* Remove 8bit device restriction (#12972)

* allow to

* update version

* fix version again

* again

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Copilot <[email protected]>

* style

* xfail

* add pr

---------

Co-authored-by: Copilot <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

* `disable_mmap` in pipeline `from_pretrained` (#12854)

* update

* `disable_mmap` in `from_pretrained`

---------

Co-authored-by: DN6 <[email protected]>

* [Modular] mellon utils (#12978)

* up

* style

---------

Co-authored-by: [email protected] <[email protected]>

* LongCat Image pipeline: Allow offloading/quantization of text_encoder component (#12963)

* Don't attempt to move the text_encoder. Just move the generated_ids.

* The inputs to the text_encoder should be on its device

* Add `ChromaInpaintPipeline` (#12848)

* Add `ChromaInpaintPipeline`

* Set `attention_mask` to `dtype=torch.bool` for `ChromaInpaintPipeline`.

* Revert `.gitignore`.

* fix Qwen-Image series context parallel (#12970)

* fix qwen-image cp

* relax attn_mask limit for cp

* CP plan compatible with zero_cond_t

* move modulate_index plan to top level

* Flux2 klein (#12982)

* flux2-klein

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* Klein tests (#2)

* tests

* up

* tests

* up

* support step-distilled

* Apply suggestions from code review

Co-authored-by: dg845 <[email protected]>

* Apply suggestions from code review

Co-authored-by: dg845 <[email protected]>

* doc string etc

* style

* more

* copies

* klein lora training scripts (#3)

* initial commit

* initial commit

* remove remote text encoder

* initial commit

* initial commit

* initial commit

* revert

* img2img fix

* text encoder + tokenizer

* text encoder + tokenizer

* update readme

* guidance

* guidance

* guidance

* test

* test

* revert changes not needed for the non klein model

* Update examples/dreambooth/train_dreambooth_lora_flux2_klein.py

Co-authored-by: Sayak Paul <[email protected]>

* fix guidance

* fix validation

* fix validation

* fix validation

* fix path

* space

---------

Co-authored-by: Sayak Paul <[email protected]>

* style

* Update src/diffusers/pipelines/flux2/pipeline_flux2_klein.py

* Apply style fixes

* auto pipeline

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: dg845 <[email protected]>
Co-authored-by: Linoy Tsaban <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* [modular] fix a bug in mellon param & improve docstrings (#12980)

* update mellonparams docstring to include the actual param definition rendered in mellon

* style

---------

Co-authored-by: [email protected] <[email protected]>

* add klein docs. (#12984)

* LTX 2 Single File Support (#12983)

* LTX 2 transformer single file support

* LTX 2 video VAE single file support

* LTX 2 audio VAE single file support

* Make it easier to distinguish LTX 1 and 2 models

* [core] gracefully error out when attn-backend x cp combo isn't supported. (#12832)

* gracefully error out when attn-backend x cp combo isn't supported.

* Revert "gracefully error out when attn-backend x cp combo isn't supported."

This reverts commit c8abb5d7c01ca6a7c0bf82c27c91a326155a5e43.

* gracefully error out when attn-backend x cp combo isn't supported.

* up

* address PR feedback.

* up

* Update src/diffusers/models/modeling_utils.py

Co-authored-by: Dhruv Nair <[email protected]>

* dot.

---------

Co-authored-by: Dhruv Nair <[email protected]>

* Improve docstrings and type hints in scheduling_cosine_dpmsolver_multistep.py (#12936)

* docs: improve docstring scheduling_cosine_dpmsolver_multistep.py

* Update src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py

Co-authored-by: Steven Liu <[email protected]>

* Update src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py

Co-authored-by: Steven Liu <[email protected]>

* fix

---------

Co-authored-by: Steven Liu <[email protected]>

* [Docs] Replace root CONTRIBUTING.md with symlink to source docs (#12986)

Chore: Replace CONTRIBUTING.md with a symlink to documentation

* make style && make quality

* Revert "make style && make quality"

This reverts commit 76f51a5e927b369b5f708583e8f18b49c350f52a.

* [chore] make style to push new changes. (#12998)

make style to push new changes.

* Fibo edit pipeline (#12930)

* Feature: Add BriaFiboEditPipeline to diffusers

* Introduced BriaFiboEditPipeline class with necessary backend requirements.
* Updated import structures in relevant modules to include BriaFiboEditPipeline.
* Ensured compatibility with existing pipelines and type checking.

* Feature: Introduce Bria Fibo Edit Pipeline

* Added BriaFiboEditPipeline class for structured JSON-native image editing.
* Created documentation for the new pipeline in bria_fibo_edit.md.
* Updated import structures to include the new pipeline and its components.
* Added unit tests for the BriaFiboEditPipeline to ensure functionality and correctness.

* Enhancement: Update Bria Fibo Edit Pipeline and Documentation

* Refined the Bria Fibo Edit model description for clarity and detail.
* Added usage instructions for model authentication and login.
* Implemented mask handling functions in the BriaFiboEditPipeline for improved image editing capabilities.
* Updated unit tests to cover new mask functionalities and ensure input validation.
* Adjusted example code in documentation to reflect changes in the pipeline's usage.

* Update Bria Fibo Edit documentation with corrected Hugging Face page link

* add dreambooth training script

* style and quality

* Delete temp.py

* Enhancement: Improve JSON caption validation in DreamBoothDataset

* Updated the clean_json_caption function to handle both string and dictionary inputs for captions.
* Added error handling to raise a ValueError for invalid caption types, ensuring better input validation.
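The caption validation described above might look like the following minimal sketch (the function name matches the commit message, but the body is an assumption about its behavior, not the actual training-script code): accept either a raw JSON string or an already-parsed dict, and reject anything else.

```python
# Hedged sketch of JSON caption validation: normalize str/dict inputs,
# raise ValueError for anything else. Behavior beyond the commit message
# (e.g. actual "cleaning" steps) is omitted.
import json


def clean_json_caption(caption):
    if isinstance(caption, str):
        # A string caption is assumed to be serialized JSON.
        return json.loads(caption)
    if isinstance(caption, dict):
        return caption
    raise ValueError(
        f"Caption must be a JSON string or dict, got {type(caption).__name__}"
    )
```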

* Add datasets dependency to requirements_fibo_edit.txt

* Add bria_fibo_edit to docs table of contents

* Fix dummy objects ordering

* Fix BriaFiboEditPipeline to use passed generator parameter

The pipeline was ignoring the generator parameter and only using
the seed parameter. This caused non-deterministic outputs in tests
that pass a seeded generator.
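The fix described above amounts to a precedence rule: a caller-supplied generator must win over the `seed` parameter. An illustrative sketch (using stdlib `random` in place of `torch.Generator`; names are hypothetical, not the actual pipeline code):

```python
# Hypothetical sketch: prefer the passed generator; only fall back to
# constructing one from `seed`. The original bug ignored `generator`.
import random


def resolve_generator(generator=None, seed=None):
    if generator is not None:
        return generator  # respect the caller's seeded generator
    if seed is not None:
        return random.Random(seed)
    return random.Random()  # non-deterministic fallback


# Same seed, whichever path supplied it, yields the same stream.
g = resolve_generator(seed=42)
h = resolve_generator(generator=random.Random(42))
assert g.random() == h.random()
```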

* Remove fibo_edit training script and related files

---------

Co-authored-by: kfirbria <[email protected]>

* Fix variable name in docstring for PeftAdapterMixin.set_adapters (#13003)

Co-authored-by: Sayak Paul <[email protected]>

* Improve docstrings and type hints in scheduling_ddim_cogvideox.py (#12992)

docs: improve docstring scheduling_ddim_cogvideox.py

* [scheduler] Support custom sigmas in UniPCMultistepScheduler (#12109)

* update

* fix tests

* Apply suggestions from code review

* Revert default flow sigmas change so that tests relying on UniPC multistep still pass

* Remove custom timesteps for UniPC multistep set_timesteps

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: Daniel Gu <[email protected]>
Co-authored-by: dg845 <[email protected]>

* feat: accelerate longcat-image with regional compile (#13019)

* Improve docstrings and type hints in scheduling_ddim_flax.py (#13010)

* docs: improve docstring scheduling_ddim_flax.py

* docs: improve docstring scheduling_ddim_flax.py

* docs: improve docstring scheduling_ddim_flax.py

* Improve docstrings and type hints in scheduling_ddim_inverse.py (#13020)

docs: improve docstring scheduling_ddim_inverse.py

* fix Dockerfiles for cuda and xformers. (#13022)

* Resnet only use contiguous in training mode. (#12977)

* fix contiguous

Signed-off-by: jiqing-feng <[email protected]>

* update tol

Signed-off-by: jiqing-feng <[email protected]>

* bigger tol

Signed-off-by: jiqing-feng <[email protected]>

* fix tests

Signed-off-by: jiqing-feng <[email protected]>

* update tol

Signed-off-by: jiqing-feng <[email protected]>

---------

Signed-off-by: jiqing-feng <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

* feat: add qkv projection fuse for longcat transformers (#13021)

feat: add qkv fuse for longcat transformers

Co-authored-by: Sayak Paul <[email protected]>

* Improve docstrings and type hints in scheduling_ddim_parallel.py (#13023)

* docs: improve docstring scheduling_ddim_parallel.py

* docs: improve docstring scheduling_ddim_parallel.py

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <[email protected]>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <[email protected]>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <[email protected]>

* Update src/diffusers/schedulers/scheduling_ddim_parallel.py

Co-authored-by: Steven Liu <[email protected]>

* fix style

---------

Co-authored-by: Steven Liu <[email protected]>

* Improve docstrings and type hints in scheduling_ddpm_flax.py (#13024)

docs: improve docstring scheduling_ddpm_flax.py

* Improve docstrings and type hints in scheduling_ddpm_parallel.py (#13027)

* docs: improve docstring scheduling_ddpm_parallel.py

* Update scheduling_ddpm_parallel.py

Co-authored-by: Steven Liu <[email protected]>

---------

Co-authored-by: Steven Liu <[email protected]>

* Remove `*pooled_*` mentions from Chroma inpaint (#13026)

Remove `*pooled_*` mentions from Chroma as it has just one TE.

* Flag Flax schedulers as deprecated (#13031)

flag flax schedulers as deprecated

* [modular] add auto_docstring & more doc related refactors  (#12958)

* up

* up up

* update outputs

* style

* add modular_auto_docstring!

* more auto docstring

* style

* up up up

* more more

* up

* address feedbacks

* add TODO in the description for empty docstring

* refactor based on dhruv's feedback: remove the class method

* add template method

* up

* up up up

* apply auto docstring

* make style

* rmove space in make docstring

* Apply suggestions from code review

* revert change in z

* fix

* Apply style fixes

* include auto-docstring check in the modular ci. (#13004)

* Run ruff format after auto docstring generation

* up

* upup

* upup

* style

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <[email protected]>

* Upgrade GitHub Actions to latest versions (#12866)

* Upgrade GitHub Actions to latest versions

Signed-off-by: Salman Muin Kayser Chishti <[email protected]>

* fix: Correct GitHub Actions upgrade (fix branch refs and version formats)

* fix: Correct GitHub Actions upgrade (fix branch refs and version formats)

* fix: Correct GitHub Actions upgrade (fix branch refs and version formats)

---------

Signed-off-by: Salman Muin Kayser Chishti <[email protected]>

* [From Single File] support `from_single_file` method for `WanAnimateTransformer3DModel` (#12691)

* Add `WanAnimateTransformer3DModel` to `SINGLE_FILE_LOADABLE_CLASSES`

* Fixed dtype mismatch when loading a single file

* Fixed a bug that results in white noise for generation

* Update dtype check for time embedder - caused white noise output

* Improve code readability

* Optimize dtype handling

Removed unnecessary dtype conversions for timestep and weight.

* Apply style fixes

* Refactor time embedding dtype handling

Adjust time embedding type conversion for compatibility.

* Apply style fixes

* Modify comment for WanTimeTextImageEmbedding class

---------

Co-authored-by: Sam Edwards <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* Fix: Cosmos2.5 Video2World frame extraction and add default negative prompt (#13018)

* fix: Extract last frames for conditioning in Cosmos Video2World

* Added default negative prompt

* Apply style fixes

* Added default negative prompt in cosmos2 text2image pipeline

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* [GLM-Image] Add batch support for GlmImagePipeline (#13007)

* init

Signed-off-by: JaredforReal <[email protected]>

* change from right padding to left padding

Signed-off-by: JaredforReal <[email protected]>

* try i2i batch

Signed-off-by: JaredforReal <[email protected]>

* fix: revert i2i prior_token_image_ids to original 1D tensor format

* refactor KVCache for per prompt batching

Signed-off-by: JaredforReal <[email protected]>

* fix KVCache

Signed-off-by: JaredforReal <[email protected]>

* fix shape error

Signed-off-by: JaredforReal <[email protected]>

* refactor pipeline

Signed-off-by: JaredforReal <[email protected]>

* fix for left padding

Signed-off-by: JaredforReal <[email protected]>

* insert seed to AR model

Signed-off-by: JaredforReal <[email protected]>

* delete generator, use torch manual_seed

Signed-off-by: JaredforReal <[email protected]>

* add batch processing unit tests for GlmImagePipeline

Signed-off-by: JaredforReal <[email protected]>

* simplify normalize images method

Signed-off-by: JaredforReal <[email protected]>

* fix grids_per_sample

Signed-off-by: JaredforReal <[email protected]>

* fix t2i

Signed-off-by: JaredforReal <[email protected]>

* delete comments, simplify condition statement

Signed-off-by: JaredforReal <[email protected]>

* change generate_prior_tokens outputs

Signed-off-by: JaredforReal <[email protected]>

* simplify if logic

Signed-off-by: JaredforReal <[email protected]>

* support user provided prior_token_ids directly

Signed-off-by: JaredforReal <[email protected]>

* remove blank lines

Signed-off-by: JaredforReal <[email protected]>

* align with transformers

Signed-off-by: JaredforReal <[email protected]>

* Apply style fixes

---------

Signed-off-by: JaredforReal <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* [Qwen] avoid creating attention masks when there is no padding (#12987)

* avoid creating attention masks when there is no padding

* make fix-copies

* torch compile tests

* set all ones mask to none

* fix positional encoding from becoming > 4096

* fix from review

* slice freqs_cis to match the input sequence length

* keep only attenton masking change

---------

Co-authored-by: Sayak Paul <[email protected]>

* [modular]support klein (#13002)

* support klein

* style

* copies

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Álvaro Somoza <[email protected]>

* Update src/diffusers/modular_pipelines/flux2/encoders.py

* a few fix: unpack latents before decoder etc

* style

* remove guidannce to its own block

* style

* flux2-dev work in modular setting

* up

* up up

* add tests

---------

Co-authored-by: [email protected] <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Álvaro Somoza <[email protected]>

* [QwenImage] fix prompt isolation tests (#13042)

* up

* up

* up

* fix

* fast tok update (#13036)

* v5 tok update

* ruff

* keep pre v5 slow code path

* Apply style fixes

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* change to CUDA 12.9. (#13045)

* change to CUDA 12.9.

* up

* change runtime base

* FROM

* remove torchao autoquant from diffusers docs (#13048)

Summary:

Context: https://github.com/pytorch/ao/issues/3739

Test Plan: CI, since this does not change any Python code

* docs: improve docstring scheduling_dpm_cogvideox.py (#13044)

* Fix Wan/WanI2V patchification (#13038)

* Fix Wan/WanI2V patchification

* Apply style fixes

* Apply suggestions from code review

I agree with you for the idea of using `patch_size` instead. Thanks!😊

Co-authored-by: dg845 <[email protected]>

* Fix logger warning

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: dg845 <[email protected]>

* LTX2 distilled checkpoint support (#12934)

* add constants for distill sigmas values and allow ltx pipeline to pass in sigmas

* add time conditioning conversion and token packing for latents

* make style & quality

* remove prenorm

* add sigma param to ltx2 i2v

* fix copies and add pack latents to i2v

* Apply suggestions from code review

Co-authored-by: dg845 <[email protected]>

* Infer latent dims if latents/audio_latents is supplied

* add note for predefined sigmas

* run make style and quality

* revert distill timesteps & set original_state_dict_repo_id to default None

* add latent normalize

* add create noised state, delete last sigmas

* remove normalize step in latent upsample pipeline and move it to ltx2 pipeline

* add create noise latent to i2v pipeline

* fix copies

* parse none value in weight conversion script

* explicit shape handling

* Apply suggestions from code review

Co-authored-by: dg845 <[email protected]>

* make style

* add two stage inference tests

* add ltx2 documentation

* update i2v expected_audio_slice

* Apply suggestions from code review

Co-authored-by: dg845 <[email protected]>

* Apply suggestion from @dg845

Co-authored-by: dg845 <[email protected]>

* Update ltx2.md to remove one-stage example

Removed one-stage generation example code and added comments for noise scale in two-stage generation.

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: dg845 <[email protected]>
Co-authored-by: Daniel Gu <[email protected]>

* [wan] fix layerwise upcasting tests on CPU (#13039)

up

* [ci] uniform run times and wheels for pytorch cuda. (#13047)

* uniform run times and wheels for pytorch cuda.

* 12.9

* change to 24.04.

* change to 24.04.

* docs: fix grammar in fp16_safetensors CLI warning (#13040)

* docs: fix grammar in fp16_safetensors CLI warning

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

* [wan] fix wan 2.2 when either of the transformers isn't present. (#13055)

fix wan 2.2 when either of the transformers isn't present.

* [bug fix] GLM-Image fit new `get_image_features` API (#13052)

change get_image_features API

Signed-off-by: JaredforReal <[email protected]>
Co-authored-by: YiYi Xu <[email protected]>

* Fix aiter availability check (#13059)

Update import_utils.py

* [Modular]add a real quick start guide (#13029)

* add a real quick start guide

* Update docs/source/en/modular_diffusers/quickstart.md

* update a bit more

* fix

* Apply suggestions from code review

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/modular_diffusers/quickstart.md

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/modular_diffusers/quickstart.md

Co-authored-by: Steven Liu <[email protected]>

* update more

* Apply suggestions from code review

Co-authored-by: Sayak Paul <[email protected]>

* address more feedback: move components manager earlier, explain blocks vs sub-blocks etc

* more

* remove the link to mellon guide, not exist in this PR yet

---------

Co-authored-by: Steven Liu <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>

* feat: support Ulysses Anything Attention (#12996)

* feat: support Ulysses Anything Attention

* feat: support Ulysses Anything Attention

* feat: support Ulysses Anything Attention

* feat: support Ulysses Anything Attention

* fix UAA broken while using joint attn

* update

* post check

* add docs

* add docs

* remove lru cache

* move codes

* update

* Refactor Model Tests (#12822)

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: Sayak Paul <[email protected]>

* [Flux2] Fix LoRA loading for Flux2 Klein by adaptively enumerating transformer blocks (#13030)

* Resolve Flux2 Klein 4B/9B LoRA loading errors

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <[email protected]>

* [Modular] loader related (#13025)

* tag loader_id from Automodel

* style

* load_components by default only load components that are not already loaded

* by default, skip loading the components that do not have the repo id

* [Modular] mellon doc etc (#13051)

* add metadata field to input/output param

* refactor mellonparam: move the template outside, add metaclass, define some generic template for custom node

* add from_custom_block

* style

* up up fix

* add mellon guide

* add to toctree

* style

* add mellon_types

* style

* mellon_type -> inpnt_types + output_types

* update doc

* add quant info to components manager

* fix more

* up up

* fix components manager

* update custom block guide

* update

* style

* add a warn for mellon and add new guides to overview

* Apply suggestions from code review

Co-authored-by: Steven Liu <[email protected]>

* Update docs/source/en/modular_diffusers/mellon.md

Co-authored-by: Ste…