-
Notifications
You must be signed in to change notification settings - Fork 81
Expand file tree
/
Copy pathcutlass.cpp
More file actions
121 lines (110 loc) · 4.34 KB
/
cutlass.cpp
File metadata and controls
121 lines (110 loc) · 4.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#ifdef NVFUSER_ENABLE_CUTLASS
#include <bindings.h>
#include <nvf_cutlass.h>
namespace nvfuser::python {
namespace {
// Registers the dense (non-grouped) scaled-GEMM entry points on the given
// cutlass submodule:
//   - mxfp8_scaled_mm
//   - nvfp4_scaled_mm
//   - nvfp4_scaled_mm_blockscale
// Each binding is a thin lambda that forwards its arguments unchanged to the
// matching cutlass_kernels function.
//
// Every binding names its arguments with py::arg. This lets Python callers
// pass arguments by keyword (e.g. out_dtype=...). It also makes the
// pybind11-generated signature use the same parameter names as the
// hand-written docstring instead of arg0..arg5. Positional calls behave
// exactly as before.
void bindGemm(py::module_& cutlass) {
  cutlass.def(
      "mxfp8_scaled_mm",
      [](const torch::Tensor& a,
         const torch::Tensor& b,
         const torch::Tensor& scales_a,
         const torch::Tensor& scales_b,
         const torch::Tensor& alpha,
         at::ScalarType out_dtype) -> torch::Tensor {
        return cutlass_kernels::mxfp8_scaled_mm(
            a, b, scales_a, scales_b, alpha, out_dtype);
      },
      py::arg("a"),
      py::arg("b"),
      py::arg("scales_a"),
      py::arg("scales_b"),
      py::arg("alpha"),
      py::arg("out_dtype"),
      R"(Computes mxfp8 matmul and returns bf16 or fp16 output tensor.
mxfp8_scaled_mm(Tensor a,
Tensor b,
Tensor scales_a,
Tensor scales_b,
Tensor alpha,
DataType out_dtype)
-> Tensor output)");
  cutlass.def(
      "nvfp4_scaled_mm",
      [](const torch::Tensor& a,
         const torch::Tensor& b,
         const torch::Tensor& scales_a,
         const torch::Tensor& scales_b,
         const torch::Tensor& alpha,
         at::ScalarType out_dtype) -> torch::Tensor {
        return cutlass_kernels::nvfp4_scaled_mm(
            a, b, scales_a, scales_b, alpha, out_dtype);
      },
      py::arg("a"),
      py::arg("b"),
      py::arg("scales_a"),
      py::arg("scales_b"),
      py::arg("alpha"),
      py::arg("out_dtype"),
      R"(Computes nvfp4 matmul and returns bf16 or fp16 output tensor.
nvfp4_scaled_mm(Tensor a,
Tensor b,
Tensor scales_a,
Tensor scales_b,
Tensor alpha,
DataType out_dtype)
-> Tensor output)");
  cutlass.def(
      "nvfp4_scaled_mm_blockscale",
      [](const torch::Tensor& a_nvfp4,
         const torch::Tensor& b_nvfp4,
         const torch::Tensor& scales_a,
         const torch::Tensor& scales_b,
         const torch::Tensor& alpha,
         const torch::Tensor& global_normconst) -> py::tuple {
        // The kernel returns a (output, blockscale) std::pair; expose it to
        // Python as a 2-tuple in the same order.
        auto [out_nvfp4, blockscale] =
            cutlass_kernels::nvfp4_scaled_mm_blockscale(
                a_nvfp4, b_nvfp4, scales_a, scales_b, alpha, global_normconst);
        return py::make_tuple(out_nvfp4, blockscale);
      },
      py::arg("a_nvfp4"),
      py::arg("b_nvfp4"),
      py::arg("scales_a"),
      py::arg("scales_b"),
      py::arg("alpha"),
      py::arg("global_normconst"),
      R"(Computes nvfp4 matmul and blockscale quantization. It returns nvfp4
output tensor and its blockscale factor.
nvfp4_scaled_mm_blockscale(Tensor a_nvfp4,
Tensor b_nvfp4,
Tensor scales_a,
Tensor scales_b,
Tensor alpha,
Tensor global_normconst)
-> tuple(Tensor out_nvfp4, Tensor blockscale))");
}
// Registers the grouped-GEMM entry points on the given cutlass submodule:
//   - grouped_mm
//   - nvfp4_scaled_grouped_mm
// Unlike the dense GEMMs, these bind the cutlass_kernels functions directly
// (no wrapper lambda). The Python-side argument list therefore follows the
// C++ signatures of those functions.
// NOTE(review): no py::arg names are declared here. Before adding them,
// confirm that the kernel arity matches the parameter lists in the docstrings
// below.
void bindGroupedGemm(py::module_& cutlass) {
  constexpr const char* kGroupedMmDoc =
      R"(Computes grouped matmul and returns bf16 or fp16 output tensor.
grouped_mm(Tensor a,
Tensor b,
Tensor ab_strides,
Tensor c_strides,
Tensor problem_sizes,
Tensor expert_offsets) -> Tensor output)";

  constexpr const char* kNvfp4ScaledGroupedMmDoc =
      R"(Computes nvfp4 grouped matmul and returns bf16 or fp16 output tensor.
nvfp4_scaled_grouped_mm(Tensor a,
Tensor b,
Tensor a_blockscale,
Tensor b_blockscale,
Tensor alphas,
Tensor ab_strides,
Tensor c_strides,
Tensor problem_sizes,
Tensor expert_offsets,
Tensor sf_offsets,
DataType out_dtype) -> Tensor output)";

  // Registration order is kept: grouped_mm first, then the nvfp4 variant.
  cutlass.def("grouped_mm", &cutlass_kernels::grouped_mm, kGroupedMmDoc);
  cutlass.def(
      "nvfp4_scaled_grouped_mm",
      &cutlass_kernels::nvfp4_scaled_grouped_mm,
      kNvfp4ScaledGroupedMmDoc);
}
} // namespace
// Entry point for this translation unit.
// Creates the `nvf_cutlass` submodule under the given `nvfuser` module and
// registers every CUTLASS GEMM binding defined in this file on it:
// the dense scaled GEMMs first, then the grouped GEMMs.
void bindCutlass(py::module& nvfuser) {
  constexpr const char* kSubmoduleDoc =
      "This submodule contains all cutlass gemms for NvFuser.";
  auto cutlass_submodule = nvfuser.def_submodule("nvf_cutlass", kSubmoduleDoc);

  bindGemm(cutlass_submodule);
  bindGroupedGemm(cutlass_submodule);
}
} // namespace nvfuser::python
#endif