Skip to content

Commit 0eed3f7

Browse files
[SM6.10] Codegen for attributed LinAlg Matrix types (#8132)
Adds codegen for Attributed LinAlgMatrix types. Each type is translated into its corresponding DXIL type `%dx.types.LinAlgMatrix<mangling>` where `<mangling>` encodes the matrix attributes. For example this type: ``` __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::B, MatrixScope::ThreadGroup)]] ``` Will be traslated to `%dx.types.LinAlgMatrixC4M4N5U1S2`. Also adds `!dx.targetTypes` metadata node that lists all LinAlgMatrix types that were ever used in the module. For the example above that would be: ``` !dx.targetTypes = !{!1} !1 ! = = !{%dx.types.LinAlgMatrixC4M4N5U1S2 undef, i32 4, i32 4, i32 5, i32 1, i32 2} ``` Later after optimizations these metadata nodes should be pruned to contain only the types actually used in the module (task #8133). Fixes #8123 Fixes #8172 --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 37338c2 commit 0eed3f7

File tree

11 files changed

+181
-4
lines changed

11 files changed

+181
-4
lines changed

include/dxc/DXIL/DxilMetadataHelper.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,9 @@ class DxilMDHelper {
292292
// DXR Payload Annotations metadata.
293293
static const char kDxilDxrPayloadAnnotationsMDName[];
294294

295+
// LinAlg Matrix Target Types metadata.
296+
static const char kDxilTargetTypesMDName[];
297+
295298
// Extended shader property tags.
296299
static const unsigned kDxilShaderFlagsTag = 0;
297300
static const unsigned kDxilGSStateTag = 1;

lib/DXIL/DxilMetadataHelper.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ const char DxilMDHelper::kDxilNonUniformAttributeMDName[] = "dx.nonuniform";
9191
const char DxilMDHelper::kDxilValidatorVersionMDName[] = "dx.valver";
9292
const char DxilMDHelper::kDxilDxrPayloadAnnotationsMDName[] =
9393
"dx.dxrPayloadAnnotations";
94+
const char DxilMDHelper::kDxilTargetTypesMDName[] = "dx.targetTypes";
9495

9596
// This named metadata is not valid in final module (should be moved to
9697
// DxilContainer)
@@ -117,7 +118,7 @@ const char DxilMDHelper::kDxilSourceArgsOldMDName[] = "llvm.dbg.args";
117118
// This is reflection-only metadata
118119
const char DxilMDHelper::kDxilCountersMDName[] = "dx.counters";
119120

120-
static std::array<const char *, 8> DxilMDNames = {{
121+
static std::array<const char *, 9> DxilMDNames = {{
121122
DxilMDHelper::kDxilVersionMDName,
122123
DxilMDHelper::kDxilShaderModelMDName,
123124
DxilMDHelper::kDxilEntryPointsMDName,
@@ -126,6 +127,7 @@ static std::array<const char *, 8> DxilMDNames = {{
126127
DxilMDHelper::kDxilValidatorVersionMDName,
127128
DxilMDHelper::kDxilViewIdStateMDName,
128129
DxilMDHelper::kDxilDxrPayloadAnnotationsMDName,
130+
DxilMDHelper::kDxilTargetTypesMDName,
129131
}};
130132

131133
DxilMDHelper::DxilMDHelper(Module *pModule,

tools/clang/lib/CodeGen/CGDebugInfo.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2122,6 +2122,11 @@ static QualType UnwrapTypeForDebugInfo(QualType T, const ASTContext &C) {
21222122
case Type::Attributed:
21232123
T = cast<AttributedType>(T)->getEquivalentType();
21242124
break;
2125+
// HLSL Change Start
2126+
case Type::AttributedLinAlgMatrix:
2127+
T = cast<AttributedLinAlgMatrixType>(T)->getWrappedType();
2128+
break;
2129+
// HLSL Change End
21252130
case Type::Elaborated:
21262131
T = cast<ElaboratedType>(T)->getNamedType();
21272132
break;

tools/clang/lib/CodeGen/CGHLSLMS.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,17 @@
2929
#include "clang/Sema/SemaDiagnostic.h"
3030
#include "llvm/ADT/STLExtras.h"
3131
#include "llvm/ADT/SmallPtrSet.h"
32+
#include "llvm/ADT/SmallString.h"
33+
#include "llvm/ADT/SmallVector.h"
3234
#include "llvm/ADT/StringSet.h"
3335
#include "llvm/ADT/StringSwitch.h"
3436
#include "llvm/IR/Constants.h"
3537
#include "llvm/IR/GetElementPtrTypeIterator.h"
3638
#include "llvm/IR/IRBuilder.h"
3739
#include "llvm/IR/InstIterator.h"
40+
#include "llvm/IR/LLVMContext.h"
41+
#include "llvm/IR/Metadata.h"
42+
#include "llvm/Support/raw_ostream.h"
3843
#include "llvm/Transforms/Utils/Cloning.h"
3944
#include <memory>
4045
#include <set>
@@ -323,6 +328,9 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
323328
void EmitHLSLMartrixCastForStoreOp(
324329
CodeGenFunction &CGF, SmallVector<llvm::Value *, 16> &IRCallArgs,
325330
llvm::SmallVector<clang::QualType, 16> &ArgTys) override;
331+
llvm::Type *ConvertAttributedLinAlgMatrixType(
332+
const clang::AttributedLinAlgMatrixType *T) override;
333+
326334
/// Get or add constant to the program
327335
HLCBuffer &GetOrCreateCBuffer(HLSLBufferDecl *D);
328336
};
@@ -1298,6 +1306,10 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty,
12981306
} else if (IsStringType(Ty)) {
12991307
// string won't be included in cbuffer
13001308
return 0;
1309+
} else if (Ty->getUnqualifiedDesugaredType()
1310+
->isAttributedLinAlgMatrixType()) {
1311+
// LinAlg Matrix type does not count towards cbuffer size.
1312+
return 0;
13011313
} else {
13021314
unsigned arraySize = 0;
13031315
QualType arrayElementTy = Ty;
@@ -6588,6 +6600,54 @@ Scope *CGMSHLSLRuntime::MarkScopeEnd(CodeGenFunction &CGF) {
65886600
return nullptr;
65896601
}
65906602

6603+
static MDNode *
6604+
createLinAlgMatrixTypeMetadata(LLVMContext &Ctx,
6605+
const clang::AttributedLinAlgMatrixType *T,
6606+
llvm::StructType *ST) {
6607+
auto Createi32MD = [&](int32_t Val) {
6608+
return ConstantAsMetadata::get(
6609+
ConstantInt::get(llvm::Type::getInt32Ty(Ctx), Val));
6610+
};
6611+
6612+
return MDTuple::get(
6613+
Ctx, {ConstantAsMetadata::get(UndefValue::get(ST)),
6614+
Createi32MD(static_cast<uint32_t>(T->getComponentType())),
6615+
Createi32MD(T->getRows()), Createi32MD(T->getCols()),
6616+
Createi32MD(static_cast<uint32_t>(T->getUse())),
6617+
Createi32MD(static_cast<uint32_t>(T->getScope()))});
6618+
}
6619+
6620+
llvm::Type *CGMSHLSLRuntime::ConvertAttributedLinAlgMatrixType(
6621+
const clang::AttributedLinAlgMatrixType *T) {
6622+
6623+
llvm::LLVMContext &Ctx = CGM.getLLVMContext();
6624+
llvm::Type *Int8Ptr = llvm::Type::getInt8PtrTy(Ctx);
6625+
llvm::Type *StructElemTypes[] = {Int8Ptr};
6626+
6627+
llvm::SmallString<64> Buf;
6628+
llvm::raw_svector_ostream OS(Buf);
6629+
OS << "dx.types.LinAlgMatrix";
6630+
T->appendMangledAttributes(OS);
6631+
StringRef TypeName = OS.str();
6632+
6633+
llvm::StructType *ST = CGM.getModule().getTypeByName(TypeName);
6634+
if (ST) {
6635+
assert(ST->getNumElements() == 1 && ST->getElementType(0) == Int8Ptr &&
6636+
"Unexpected existing dx.types.LinAlgMatrix type");
6637+
return ST;
6638+
}
6639+
6640+
ST = StructType::create(Ctx, StructElemTypes, TypeName);
6641+
6642+
// Add metadata node for the new target type.
6643+
NamedMDNode *DxTypesMD =
6644+
CGM.getModule().getOrInsertNamedMetadata("dx.targetTypes");
6645+
MDNode *NewTyNode = createLinAlgMatrixTypeMetadata(Ctx, T, ST);
6646+
DxTypesMD->addOperand(NewTyNode);
6647+
6648+
return ST;
6649+
}
6650+
65916651
CGHLSLRuntime *CodeGen::CreateMSHLSLRuntime(CodeGenModule &CGM) {
65926652
return new CGMSHLSLRuntime(CGM);
65936653
}

tools/clang/lib/CodeGen/CGHLSLRuntime.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class ReturnStmt;
4545
class Attr;
4646
class VarDecl;
4747
class HLSLRootSignatureAttr;
48+
class AttributedLinAlgMatrixType;
4849

4950
namespace CodeGen {
5051
class CodeGenModule;
@@ -171,6 +172,9 @@ class CGHLSLRuntime {
171172
virtual void EmitHLSLMartrixCastForStoreOp(
172173
CodeGenFunction &CGF, llvm::SmallVector<llvm::Value *, 16> &IRCallArgs,
173174
llvm::SmallVector<clang::QualType, 16> &ArgTys) = 0;
175+
176+
virtual llvm::Type *ConvertAttributedLinAlgMatrixType(
177+
const clang::AttributedLinAlgMatrixType *T) = 0;
174178
};
175179

176180
/// Create an instance of a HLSL runtime class.

tools/clang/lib/CodeGen/CodeGenFunction.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,9 @@ TypeEvaluationKind CodeGenFunction::getEvaluationKind(QualType type) {
172172
case Type::Atomic:
173173
type = cast<AtomicType>(type)->getValueType();
174174
continue;
175+
case Type::AttributedLinAlgMatrix:
176+
type = cast<AttributedLinAlgMatrixType>(type)->getWrappedType();
177+
continue;
175178
}
176179
llvm_unreachable("unknown type kind!");
177180
}
@@ -1672,6 +1675,7 @@ void CodeGenFunction::EmitVariablyModifiedType(QualType type) {
16721675
case Type::TypeOf:
16731676
case Type::UnaryTransform:
16741677
case Type::Attributed:
1678+
case Type::AttributedLinAlgMatrix: // HLSL Change
16751679
case Type::SubstTemplateTypeParm:
16761680
case Type::PackExpansion:
16771681
// Keep walking after single level desugaring.

tools/clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,13 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
708708
}
709709
break;
710710
}
711+
// HLSL Change Starts
712+
case Type::AttributedLinAlgMatrix: {
713+
ResultType = CGM.getHLSLRuntime().ConvertAttributedLinAlgMatrixType(
714+
cast<AttributedLinAlgMatrixType>(Ty));
715+
break;
716+
}
717+
// HLSL Change Ends
711718
}
712719

713720
assert(ResultType && "Didn't convert a type?");

tools/clang/lib/CodeGen/TargetInfo.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6220,6 +6220,10 @@ ABIArgInfo MSDXILABIInfo::classifyArgumentType(QualType Ty) const {
62206220
if (isAggregateTypeForABI(Ty))
62216221
return ABIArgInfo::getIndirect(0, /* byval */ false);
62226222

6223+
// Pass LinAlg Matrix types directly
6224+
if (Ty->isAttributedLinAlgMatrixType())
6225+
return ABIArgInfo::getDirect();
6226+
62236227
return (Ty->isPromotableIntegerType() ? ABIArgInfo::getExtend()
62246228
: ABIArgInfo::getDirect());
62256229
}
@@ -6237,8 +6241,9 @@ void MSDXILABIInfo::computeInfo(CGFunctionInfo &FI) const {
62376241
}
62386242
for (auto &I : FI.arguments()) {
62396243
I.info = classifyArgumentType(I.type);
6240-
// Do not flat matrix
6241-
if (hlsl::IsHLSLMatType(I.type))
6244+
// Do not flatten matrix types.
6245+
if (hlsl::IsHLSLMatType(I.type) ||
6246+
I.type.getTypePtr()->isAttributedLinAlgMatrixType())
62426247
I.info.setCanBeFlattened(false);
62436248
}
62446249
// TODO: set calling convention
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -I %hlsl_headers -T lib_6_10 -enable-16bit-types -fcgl %s | FileCheck %s
3+
// RUN: %dxc -I %hlsl_headers -T lib_6_10 -enable-16bit-types %s | FileCheck %s --check-prefix=CHECKVAL
4+
5+
#include <dx/linalg.h>
6+
using namespace dx::linalg;
7+
8+
// CHECK: %dx.types.LinAlgMatrixC4M4N5U1S2 = type { i8* }
9+
// CHECK: %dx.types.LinAlgMatrixC17M3N3U0S1 = type { i8* }
10+
// CHECK: %dx.types.LinAlgMatrixC9M10N20U0S0 = type { i8* }
11+
// CHECK: %dx.types.LinAlgMatrixC2M3N4U2S2 = type { i8* }
12+
13+
// CHECK: define internal void @"\01?f1@@YAXXZ"()
14+
// CHECK: %{{.*}} = alloca %dx.types.LinAlgMatrixC4M4N5U1S2, align 4
15+
void f1() {
16+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::I32, 4, 5, MatrixUse::B, MatrixScope::ThreadGroup)]] mat1;
17+
}
18+
19+
// CHECK: define internal void @"\01?f2@@YAX$linalg_matrixC17M3N3U0S1@@Z"(%dx.types.LinAlgMatrixC17M3N3U0S1 %mat2.coerce)
20+
// CHECK: %{{.*}} = alloca %dx.types.LinAlgMatrixC17M3N3U0S1, align 4
21+
void f2(__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::PackedS8x32, 3, 3, MatrixUse::A, MatrixScope::Wave)]] mat2) {
22+
}
23+
24+
// CHECK: define internal %dx.types.LinAlgMatrixC9M10N20U0S0 @"\01?f3@@YA$linalg_matrixC9M10N20U0S0@XZ"()
25+
// CHECK: %{{.*}} = alloca %dx.types.LinAlgMatrixC9M10N20U0S0, align 4
26+
typedef __builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::F32, 10, 20, MatrixUse::A, MatrixScope::Thread)]] Mat10by20;
27+
28+
Mat10by20 f3() {
29+
Mat10by20 mat3;
30+
return mat3;
31+
}
32+
33+
// CHECK: define internal void @"\01?f4@@YAXXZ"()
34+
// CHECK: call void @"\01??$fTemplate@$02$03@@YAXXZ"()
35+
36+
// CHECK: define linkonce_odr void @"\01??$fTemplate@$02$03@@YAXXZ"()
37+
// CHECK: %mat4 = alloca %dx.types.LinAlgMatrixC2M3N4U2S2, align 4
38+
39+
template <uint M, uint N>
40+
void fTemplate() {
41+
__builtin_LinAlgMatrix [[__LinAlgMatrix_Attributes(ComponentType::I16, M, N, MatrixUse::Accumulator, MatrixScope::ThreadGroup)]] mat4;
42+
}
43+
44+
void f4() {
45+
fTemplate<3, 4>();
46+
}
47+
48+
// CHECK: !dx.targetTypes = !{![[T0:.*]], ![[T1:.*]], ![[T2:.*]], ![[T3:.*]]}
49+
// CHECK: ![[T0:.*]] = !{%dx.types.LinAlgMatrixC4M4N5U1S2 undef, i32 4, i32 4, i32 5, i32 1, i32 2}
50+
// CHECK: ![[T1:.*]] = !{%dx.types.LinAlgMatrixC17M3N3U0S1 undef, i32 17, i32 3, i32 3, i32 0, i32 1}
51+
// CHECK: ![[T2:.*]] = !{%dx.types.LinAlgMatrixC9M10N20U0S0 undef, i32 9, i32 10, i32 20, i32 0, i32 0}
52+
// CHECK: ![[T3:.*]] = !{%dx.types.LinAlgMatrixC2M3N4U2S2 undef, i32 2, i32 3, i32 4, i32 2, i32 2}
53+
54+
// CHECKVAL-NOT: error: validation errors
55+
// CHECKVAL-NOT: error: Named metadata 'dx.targetTypes' is unknown.
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// REQUIRES: dxil-1-10
2+
// RUN: %dxc -I %hlsl_headers -T lib_6_10 -fcgl %s | FileCheck %s
3+
// RUN: %dxc -I %hlsl_headers -T lib_6_10 %s | FileCheck %s --check-prefix=CHECKVAL
4+
5+
#include <dx/linalg.h>
6+
using namespace dx::linalg;
7+
8+
// CHECK: %"class.dx::linalg::Matrix<dx::linalg::ComponentType::ComponentEnum::I32, 4, 5, dx::linalg::MatrixUse::MatrixUseEnum::B,
9+
// CHECK-SAME: dx::linalg::MatrixScope::MatrixScopeEnum::ThreadGroup>" = type { %dx.types.LinAlgMatrixC4M4N5U1S2 }
10+
// CHECK: %dx.types.LinAlgMatrixC4M4N5U1S2 = type { i8* }
11+
12+
13+
// CHECK: %"class.dx::linalg::Matrix<dx::linalg::ComponentType::ComponentEnum::PackedS8x32, 100, 100, dx::linalg::MatrixUse::MatrixUseEnum::A,
14+
// CHECK-SAME: dx::linalg::MatrixScope::MatrixScopeEnum::Wave>" = type { %dx.types.LinAlgMatrixC17M100N100U0S1 }
15+
// CHECK: %dx.types.LinAlgMatrixC17M100N100U0S1 = type { i8* }
16+
17+
// CHECK: define internal void @"\01?f@@YAXXZ"()
18+
void f() {
19+
// CHECK: %mat1 = alloca %"class.dx::linalg::Matrix<dx::linalg::ComponentType::ComponentEnum::I32, 4, 5,
20+
// CHECK-SAME: dx::linalg::MatrixUse::MatrixUseEnum::B, dx::linalg::MatrixScope::MatrixScopeEnum::ThreadGroup>", align 4
21+
Matrix<ComponentType::I32, 4, 5, MatrixUse::B, MatrixScope::ThreadGroup> mat1;
22+
// CHECK: %mat2 = alloca %"class.dx::linalg::Matrix<dx::linalg::ComponentType::ComponentEnum::PackedS8x32, 100, 100,
23+
// CHECK-SAME: dx::linalg::MatrixUse::MatrixUseEnum::A, dx::linalg::MatrixScope::MatrixScopeEnum::Wave>", align 4
24+
Matrix<ComponentType::PackedS8x32, 100, 100, MatrixUse::A, MatrixScope::Wave> mat2;
25+
}
26+
27+
// CHECK: !dx.targetTypes = !{![[T0:.*]], ![[T0:.*]]}
28+
// CHECK: ![[T0:.*]] = !{%dx.types.LinAlgMatrixC4M4N5U1S2 undef, i32 4, i32 4, i32 5, i32 1, i32 2}
29+
// CHECK: ![[T1:.*]] = !{%dx.types.LinAlgMatrixC17M100N100U0S1 undef, i32 17, i32 100, i32 100, i32 0, i32 1}
30+
31+
// CHECKVAL-NOT: error: validation errors
32+
// CHECKVAL-NOT: error: Named metadata 'dx.targetTypes' is unknown.

0 commit comments

Comments
 (0)