
Commit 16bb505

Authored by takuma104, patrickvonplaten, and patil-suraj
xFormers attention op arg (huggingface#2049)
* allow passing op to xFormers attention (original code by @patil-suraj, huggingface/diffusers@ae0cc0b)
* correct style by `make style`
* add attention_op arg documents
* add usage example to docstring
* code style correction by `make style`
* Update docstring code to a valid python example
* style correction by `make style`
* Update code example to fully functional

Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
1 parent 7533e3d commit 16bb505

File tree

4 files changed: +73 −16 lines
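In short, the change threads an optional `attention_op` argument from the public `enable_xformers_memory_efficient_attention()` helpers down to `xformers.ops.memory_efficient_attention()`. A minimal sketch of what the `op` argument controls at the xFormers call site (the tensors, shapes, and device below are illustrative, not taken from the diff; it assumes an xFormers build that ships `MemoryEfficientAttentionFlashAttentionOp` and a CUDA GPU):

```py
import torch
import xformers.ops

# Toy query/key/value tensors in the (batch, seq_len, head_dim) layout that
# diffusers hands to xFormers after splitting heads into the batch dimension.
q = torch.randn(2, 64, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 64, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 64, 64, device="cuda", dtype=torch.float16)

# op=None keeps the old behavior: xFormers dispatches to the best available kernel.
out_auto = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)

# An explicit op pins the kernel, e.g. the flash-attention implementation.
out_flash = xformers.ops.memory_efficient_attention(
    q, k, v, attn_bias=None, op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp
)
```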

src/diffusers/models/attention.py

Lines changed: 9 additions & 3 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Optional
+from typing import Callable, Optional
 
 import torch
 import torch.nn.functional as F
@@ -72,6 +72,7 @@ def __init__(
         self.proj_attn = nn.Linear(channels, channels, 1)
 
         self._use_memory_efficient_attention_xformers = False
+        self._attention_op = None
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -87,7 +88,9 @@ def reshape_batch_dim_to_heads(self, tensor):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
 
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ):
         if use_memory_efficient_attention_xformers:
             if not is_xformers_available():
                 raise ModuleNotFoundError(
@@ -113,6 +116,7 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten
             except Exception as e:
                 raise e
         self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        self._attention_op = attention_op
 
     def forward(self, hidden_states):
         residual = hidden_states
@@ -136,7 +140,9 @@ def forward(self, hidden_states):
 
         if self._use_memory_efficient_attention_xformers:
             # Memory efficient attention
-            hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+            hidden_states = xformers.ops.memory_efficient_attention(
+                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
+            )
             hidden_states = hidden_states.to(query_proj.dtype)
         else:
             attention_scores = torch.baddbmm(
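The block-level change above simply stores the op and forwards it at call time. A hedged sketch of exercising the new signature on a single block (`AttentionBlock` is the module this file defines; the constructor arguments, shapes, and dtype below are illustrative):

```py
import torch
from diffusers.models.attention import AttentionBlock
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

block = AttentionBlock(channels=64).half().to("cuda")

# New optional second argument: the op is remembered in _attention_op and later
# passed as `op=` to xformers.ops.memory_efficient_attention inside forward().
block.set_use_memory_efficient_attention_xformers(True, MemoryEfficientAttentionFlashAttentionOp)

hidden_states = torch.randn(1, 64, 32, 32, device="cuda", dtype=torch.float16)
out = block(hidden_states)  # same shape as the input, (1, 64, 32, 32)
```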

src/diffusers/models/cross_attention.py

Lines changed: 11 additions & 4 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -93,7 +93,9 @@ def __init__(
         processor = processor if processor is not None else CrossAttnProcessor()
         self.set_processor(processor)
 
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ):
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
                 # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
@@ -127,7 +129,7 @@ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_atten
             except Exception as e:
                 raise e
 
-            processor = XFormersCrossAttnProcessor()
+            processor = XFormersCrossAttnProcessor(attention_op=attention_op)
         else:
             processor = CrossAttnProcessor()
 
@@ -351,6 +353,9 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
 
 
 class XFormersCrossAttnProcessor:
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
 
@@ -366,7 +371,9 @@ def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=No
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()
 
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
         hidden_states = hidden_states.to(query.dtype)
         hidden_states = attn.batch_to_head_dim(hidden_states)
 
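In the cross-attention path the op lives on the processor object rather than on the module. A minimal sketch (not from the diffusers docs) of wiring the new processor onto every `CrossAttention` module of a loaded UNet, which is essentially what `set_use_memory_efficient_attention_xformers(True, attention_op)` does internally, minus the xFormers availability checks:

```py
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttention, XFormersCrossAttnProcessor
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

# Each CrossAttention module gets a processor that remembers the op and passes
# it as `op=` on every memory_efficient_attention call.
for module in unet.modules():
    if isinstance(module, CrossAttention):
        module.set_processor(
            XFormersCrossAttnProcessor(attention_op=MemoryEfficientAttentionFlashAttentionOp)
        )
```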

src/diffusers/models/modeling_utils.py

Lines changed: 26 additions & 4 deletions
@@ -190,13 +190,15 @@ def disable_gradient_checkpointing(self):
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
 
-    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+    def set_use_memory_efficient_attention_xformers(
+        self, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
         # Recursively walk through all the children.
         # Any children which exposes the set_use_memory_efficient_attention_xformers method
         # gets the message
         def fn_recursive_set_mem_eff(module: torch.nn.Module):
             if hasattr(module, "set_use_memory_efficient_attention_xformers"):
-                module.set_use_memory_efficient_attention_xformers(valid)
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
 
             for child in module.children():
                 fn_recursive_set_mem_eff(child)
@@ -205,7 +207,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         if isinstance(module, torch.nn.Module):
             fn_recursive_set_mem_eff(module)
 
-    def enable_xformers_memory_efficient_attention(self):
+    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         r"""
         Enable memory efficient attention as implemented in xformers.
 
@@ -214,8 +216,28 @@ def enable_xformers_memory_efficient_attention(self):
 
         Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
         is used.
+
+        Parameters:
+            attention_op (`Callable`, *optional*):
+                Override the default `None` operator for use as `op` argument to the
+                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+                function of xFormers.
+
+        Examples:
+
+        ```py
+        >>> import torch
+        >>> from diffusers import UNet2DConditionModel
+        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+        >>> model = UNet2DConditionModel.from_pretrained(
+        ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
+        ... )
+        >>> model = model.to("cuda")
+        >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+        ```
         """
-        self.set_use_memory_efficient_attention_xformers(True)
+        self.set_use_memory_efficient_attention_xformers(True, attention_op)
 
     def disable_xformers_memory_efficient_attention(self):
         r"""

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 27 additions & 5 deletions
@@ -19,7 +19,7 @@
 import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
@@ -842,7 +842,7 @@ def progress_bar(self, iterable=None, total=None):
     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs
 
-    def enable_xformers_memory_efficient_attention(self):
+    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         r"""
         Enable memory efficient attention as implemented in xformers.
 
@@ -851,22 +851,44 @@ def enable_xformers_memory_efficient_attention(self):
 
         Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
         is used.
+
+        Parameters:
+            attention_op (`Callable`, *optional*):
+                Override the default `None` operator for use as `op` argument to the
+                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+                function of xFormers.
+
+        Examples:
+
+        ```py
+        >>> import torch
+        >>> from diffusers import DiffusionPipeline
+        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+        >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
+        >>> pipe = pipe.to("cuda")
+        >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+        >>> # Workaround for not accepting attention shape using VAE for Flash Attention
+        >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
+        ```
         """
-        self.set_use_memory_efficient_attention_xformers(True)
+        self.set_use_memory_efficient_attention_xformers(True, attention_op)
 
     def disable_xformers_memory_efficient_attention(self):
         r"""
         Disable memory efficient attention as implemented in xformers.
         """
         self.set_use_memory_efficient_attention_xformers(False)
 
-    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+    def set_use_memory_efficient_attention_xformers(
+        self, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
         # Recursively walk through all the children.
         # Any children which exposes the set_use_memory_efficient_attention_xformers method
         # gets the message
         def fn_recursive_set_mem_eff(module: torch.nn.Module):
             if hasattr(module, "set_use_memory_efficient_attention_xformers"):
-                module.set_use_memory_efficient_attention_xformers(valid)
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
 
             for child in module.children():
                 fn_recursive_set_mem_eff(child)
