@@ -96,7 +96,7 @@ def temp_float32_matmul_precision(precision: str):
 
 def skip_on_cpu(test_func):
     """Decorator to skip tests that are not supported on CPU."""
-    decorated_func = skipCPUIf(True, "Not supported on CUDA")(test_func)
+    decorated_func = skipCPUIf(True, "Not supported on CPU")(test_func)
     return decorated_func
 
 
@@ -2842,6 +2842,7 @@ def test_strided_backwards(self):
             (1, 0, 2, 3),  # Reverse order
             (0, 2, 1, 3),  # Mixed order
             (2, 0, 1, 3),  # Another mixed order
+            (0, 1, 3, 2),  # Non contiguous last dim
         ],
     )
     @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
@@ -2890,12 +2891,7 @@ def test_flex_attention_stride_ordering(self, device, mode, permute_order, shape
     @common_utils.parametrize("mode", ["eager", "inductor"])
     @common_utils.parametrize(
         "permute_order",
-        [
-            (0, 1, 2, 3),
-            (1, 0, 2, 3),
-            (0, 2, 1, 3),
-            (2, 0, 1, 3),
-        ],
+        [(0, 1, 2, 3), (1, 0, 2, 3), (0, 2, 1, 3), (2, 0, 1, 3), (0, 1, 3, 2)],
     )
     @common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)])
     def test_flex_attention_backward_stride_ordering(
@@ -2939,6 +2935,69 @@ def test_flex_attention_backward_stride_ordering(
                 f"Mode: {mode}, Stride order mismatch for {name}: grad {input_stride_order}, input {orig_stride_order}.",
             )
 
+    @supported_platform
+    def test_non_contiguous_last_dim(self, device):
+        """Test flex_attention with tensors having a non-contiguous last dimension."""
+        B, H, D = 4, 8, 64
+        dtype = torch.float16 if device == "cuda" else torch.float32
+        for S in [16, 64]:
+
+            def column_major_tensor():
+                tensor = torch.randn(
+                    (B, H, S, D),
+                    dtype=dtype,
+                    device=device,
+                )
+                # Column major in last 2 dims
+                return tensor.transpose(-1, -2).contiguous().transpose(-1, -2)
+
+            q = column_major_tensor()
+            k = column_major_tensor()
+            v = column_major_tensor()
+
+            requires_grad = device in DEVICE_SUPPORTS_BACKWARDS
+            if requires_grad:
+                q.requires_grad_(True)
+                k.requires_grad_(True)
+                v.requires_grad_(True)
+
+            self.assertNotEqual(q.stride()[-1], 1)
+            self.assertNotEqual(k.stride()[-1], 1)
+            self.assertNotEqual(v.stride()[-1], 1)
+
+            q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
+            q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
+
+            golden_out = flex_attention(q_gold, k_gold, v_gold)
+            ref_out = flex_attention(q_ref, k_ref, v_ref)
+
+            flex_compiled = torch.compile(flex_attention, fullgraph=True, dynamic=True)
+            compiled_out = flex_compiled(q, k, v)
+
+            self._check_out(golden_out, ref_out, compiled_out)
+
+            if requires_grad:
+                backward_grad = torch.randn_like(ref_out)
+
+                golden_out.backward(backward_grad.to(torch.float64))
+                ref_out.backward(backward_grad)
+                compiled_out.backward(backward_grad)
+
+                self._check_out_and_grad(
+                    golden_out,
+                    ref_out,
+                    compiled_out,
+                    q_gold,
+                    q_ref,
+                    q,
+                    k_gold,
+                    k_ref,
+                    k,
+                    v_gold,
+                    v_ref,
+                    v,
+                )
+
     @supported_platform
     @common_utils.parametrize("compile", [True, False])
     def test_fully_masked_out_rows_0_check(self, device, compile: bool):
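
For reference, the layout that the new test exercises can be reproduced outside the test suite with the short standalone sketch below. It is illustrative rather than part of the PR: it assumes a PyTorch build that ships torch.nn.attention.flex_attention, a working torch.compile backend for the chosen device, and hand-picked loose tolerances.

# Standalone sketch (not part of the diff above): build q/k/v whose last
# dimension is non contiguous and compare eager vs. compiled flex_attention.
import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, D = 2, 4, 64, 32

def column_major(shape, device="cpu", dtype=torch.float32):
    # transpose -> contiguous -> transpose back keeps the shape, but the last
    # dim now has stride S instead of 1 (column major in the last two dims).
    t = torch.randn(shape, device=device, dtype=dtype)
    return t.transpose(-1, -2).contiguous().transpose(-1, -2)

q, k, v = (column_major((B, H, S, D)) for _ in range(3))
assert q.stride()[-1] != 1  # last dimension is genuinely non contiguous

eager_out = flex_attention(q, k, v)
compiled_out = torch.compile(flex_attention, fullgraph=True)(q, k, v)
torch.testing.assert_close(eager_out, compiled_out, rtol=2e-2, atol=2e-2)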