
Commit 155a653

[FlexAttention] Remove Old Constraint on last dim strides
ghstack-source-id: 287e517
Pull Request resolved: #151959

1 parent 3699c86, commit 155a653

File tree: 2 files changed, +80 -12 lines
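In short: the flex_attention lowering previously asserted that the query's last dimension was contiguous (stride 1). This commit removes that assert and instead enforces last-dim contiguity only on the CPU path, which still requires it. A minimal sketch of the call pattern this now allows, assuming a CUDA device and illustrative shapes (not taken from the commit):

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 128, 64

def non_contiguous_last_dim(shape, device="cuda", dtype=torch.float16):
    # transpose -> contiguous -> transpose back: the last dim ends up strided (stride != 1)
    t = torch.randn(shape, dtype=dtype, device=device)
    return t.transpose(-1, -2).contiguous().transpose(-1, -2)

q = non_contiguous_last_dim((B, H, S, D))
k = non_contiguous_last_dim((B, H, S, D))
v = non_contiguous_last_dim((B, H, S, D))
assert q.stride()[-1] != 1  # last dim is not contiguous

compiled = torch.compile(flex_attention, fullgraph=True)
out = compiled(q, k, v)  # previously rejected by the last-dim stride assert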

test/inductor/test_flex_attention.py

Lines changed: 66 additions & 7 deletions
@@ -96,7 +96,7 @@ def temp_float32_matmul_precision(precision: str):
 
 def skip_on_cpu(test_func):
     """Decorator to skip tests that are not supported on CPU."""
-    decorated_func = skipCPUIf(True, "Not supported on CUDA")(test_func)
+    decorated_func = skipCPUIf(True, "Not supported on CPU")(test_func)
     return decorated_func
 
 
@@ -2842,6 +2842,7 @@ def test_strided_backwards(self):
             (1, 0, 2, 3),  # Reverse order
             (0, 2, 1, 3),  # Mixed order
             (2, 0, 1, 3),  # Another mixed order
+            (0, 1, 3, 2),  # Non contiguous last dim
         ],
     )
     @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
@@ -2890,12 +2891,7 @@ def test_flex_attention_stride_ordering(self, device, mode, permute_order, shape
     @common_utils.parametrize("mode", ["eager", "inductor"])
     @common_utils.parametrize(
         "permute_order",
-        [
-            (0, 1, 2, 3),
-            (1, 0, 2, 3),
-            (0, 2, 1, 3),
-            (2, 0, 1, 3),
-        ],
+        [(0, 1, 2, 3), (1, 0, 2, 3), (0, 2, 1, 3), (2, 0, 1, 3), (0, 1, 3, 2)],
     )
     @common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)])
     def test_flex_attention_backward_stride_ordering(
@@ -2939,6 +2935,69 @@ def test_flex_attention_backward_stride_ordering(
                 f"Mode: {mode}, Stride order mismatch for {name}: grad {input_stride_order}, input {orig_stride_order}.",
             )
 
+    @supported_platform
+    def test_non_contiguous_last_dim(self, device):
+        """Test flex_attention with tensors having non contiguous last dimension."""
+        B, H, D = 4, 8, 64
+        dtype = torch.float16 if device == "cuda" else torch.float32
+        for S in [16, 64]:
+
+            def column_major_tensor():
+                tensor = torch.randn(
+                    (B, H, S, D),
+                    dtype=dtype,
+                    device=device,
+                )
+                # Column major in last 2 dims
+                return tensor.transpose(-1, -2).contiguous().transpose(-1, -2)
+
+            q = column_major_tensor()
+            k = column_major_tensor()
+            v = column_major_tensor()
+
+            requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
+            if requires_grad:
+                q.requires_grad_(True)
+                k.requires_grad_(True)
+                v.requires_grad_(True)
+
+            self.assertNotEqual(q.stride()[-1], 1)
+            self.assertNotEqual(k.stride()[-1], 1)
+            self.assertNotEqual(v.stride()[-1], 1)
+
+            q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
+            q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
+
+            golden_out = flex_attention(q_gold, k_gold, v_gold)
+            ref_out = flex_attention(q_ref, k_ref, v_ref)
+
+            flex_compiled = torch.compile(flex_attention, fullgraph=True, dynamic=True)
+            compiled_out = flex_compiled(q, k, v)
+
+            self._check_out(golden_out, ref_out, compiled_out)
+
+            if requires_grad:
+                backward_grad = torch.randn_like(ref_out)
+
+                golden_out.backward(backward_grad.to(torch.float64))
+                ref_out.backward(backward_grad)
+                compiled_out.backward(backward_grad)
+
+                self._check_out_and_grad(
+                    golden_out,
+                    ref_out,
+                    compiled_out,
+                    q_gold,
+                    q_ref,
+                    q,
+                    k_gold,
+                    k_ref,
+                    k,
+                    v_gold,
+                    v_ref,
+                    v,
+                )
+
     @supported_platform
     @common_utils.parametrize("compile", [True, False])
     def test_fully_masked_out_rows_0_check(self, device, compile: bool):
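For reference, the transpose-contiguous-transpose pattern in `column_major_tensor` above yields a tensor whose last dimension is strided rather than contiguous. A small sketch of the resulting strides (shape values are illustrative):

import torch

t = torch.randn(4, 8, 16, 64)
cm = t.transpose(-1, -2).contiguous().transpose(-1, -2)

print(t.stride())   # (8192, 1024, 64, 1) -> row-major, last dim contiguous
print(cm.stride())  # (8192, 1024, 1, 16) -> last dim stride is 16, not 1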

torch/_inductor/kernel/flex_attention.py

Lines changed: 14 additions & 5 deletions
@@ -930,6 +930,15 @@ def check_cpu_supported():
     return supported
 
 
+def contiguous_last_dim(x):
+    """Ensure that the realized IR node has a contiguous stride in the last dimension."""
+    strides = x.maybe_get_stride()
+    if strides and strides[-1] != 1:
+        contiguous_stride_order = list(reversed(range(len(x.get_size()))))
+        return ExternKernel.require_stride_order(x, contiguous_stride_order)
+    return x
+
+
 def lower_cpu(
     query,
     key,
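For context, the stride order passed to ExternKernel.require_stride_order above is simply the row-major order, in which the last dimension varies fastest. A tiny illustration, assuming a 4-D input:

ndim = 4
contiguous_stride_order = list(reversed(range(ndim)))
print(contiguous_stride_order)  # [3, 2, 1, 0] -> last dim fastest-varying (stride 1)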
@@ -1092,6 +1101,9 @@ def convert_mask_graph_module(mask_graph):
         if isinstance(item, TensorBox):
            fake_buffers.append(item.data.data)  # type: ignore[attr-defined]
 
+    # CPU kernel requires last dim to be contiguous
+    query, key, value = map(contiguous_last_dim, [query, key, value])
+
     (
         query,
         key,
@@ -1258,7 +1270,6 @@ def set_head_dim_values(
     )
 
 
-# TODO: We probably also need a layout constraint?
 @register_lowering(torch.ops.higher_order.flex_attention, type_promotion_kind=None)
 def flex_attention(
     query,
@@ -1413,11 +1424,9 @@ def flex_attention(
     else:
         kernel_options.setdefault("IS_DIVISIBLE", True)
 
-    # Reuse query strides for output layout despite different last dimension.
-    # This works because only the last dim differs and we check it is contiguous.
+    # NB it is okay that the v_head_dim is different
+    # We are using these to match fill order of the output.
     q_strides = query.get_stride()
-    assert q_strides[-1] == 1, "Query must be contiguous in the last dimension"
-
     # Construct output layout with strides matching the query.
     out_size = [B, Hq, seq_len_q, v_head_dim]
     out_strides = infer_dense_strides(out_size, q_strides)
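With the assert gone, the output layout no longer depends on the query having a stride-1 last dimension: infer_dense_strides only uses the query strides to pick a fill order for a dense output. A rough sketch of that idea (a hypothetical helper, not the actual infer_dense_strides implementation):

def dense_strides_like(size, reference_strides):
    # Lay out a dense (gap-free) tensor whose dimensions fill memory in the
    # same fastest-to-slowest order as the reference strides.
    order = sorted(range(len(size)), key=lambda d: reference_strides[d])
    strides = [0] * len(size)
    running = 1
    for d in order:
        strides[d] = running
        running *= size[d]
    return strides

# A (4, 8, 16, 64) query that is column major in its last two dims has strides
# (8192, 1024, 1, 16); the dense output layout follows the same fill order:
print(dense_strides_like([4, 8, 16, 64], [8192, 1024, 1, 16]))  # [8192, 1024, 1, 16]
print(dense_strides_like([4, 8, 16, 64], [8192, 1024, 64, 1]))  # row-major query -> [8192, 1024, 64, 1]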
