|
29 | 29 | #include "clang/Sema/SemaDiagnostic.h" |
30 | 30 | #include "llvm/ADT/STLExtras.h" |
31 | 31 | #include "llvm/ADT/SmallPtrSet.h" |
| 32 | +#include "llvm/ADT/SmallString.h" |
| 33 | +#include "llvm/ADT/SmallVector.h" |
32 | 34 | #include "llvm/ADT/StringSet.h" |
33 | 35 | #include "llvm/ADT/StringSwitch.h" |
34 | 36 | #include "llvm/IR/Constants.h" |
35 | 37 | #include "llvm/IR/GetElementPtrTypeIterator.h" |
36 | 38 | #include "llvm/IR/IRBuilder.h" |
37 | 39 | #include "llvm/IR/InstIterator.h" |
| 40 | +#include "llvm/IR/LLVMContext.h" |
| 41 | +#include "llvm/IR/Metadata.h" |
| 42 | +#include "llvm/Support/raw_ostream.h" |
38 | 43 | #include "llvm/Transforms/Utils/Cloning.h" |
39 | 44 | #include <memory> |
40 | 45 | #include <set> |
@@ -323,6 +328,9 @@ class CGMSHLSLRuntime : public CGHLSLRuntime { |
323 | 328 | void EmitHLSLMartrixCastForStoreOp( |
324 | 329 | CodeGenFunction &CGF, SmallVector<llvm::Value *, 16> &IRCallArgs, |
325 | 330 | llvm::SmallVector<clang::QualType, 16> &ArgTys) override; |
| 331 | + llvm::Type *ConvertAttributedLinAlgMatrixType( |
| 332 | + const clang::AttributedLinAlgMatrixType *T) override; |
| 333 | + |
326 | 334 | /// Get or add constant to the program |
327 | 335 | HLCBuffer &GetOrCreateCBuffer(HLSLBufferDecl *D); |
328 | 336 | }; |
@@ -1298,6 +1306,10 @@ unsigned CGMSHLSLRuntime::AddTypeAnnotation(QualType Ty, |
1298 | 1306 | } else if (IsStringType(Ty)) { |
1299 | 1307 | // string won't be included in cbuffer |
1300 | 1308 | return 0; |
| 1309 | + } else if (Ty->getUnqualifiedDesugaredType() |
| 1310 | + ->isAttributedLinAlgMatrixType()) { |
| 1311 | + // LinAlg Matrix type does not count towards cbuffer size. |
| 1312 | + return 0; |
1301 | 1313 | } else { |
1302 | 1314 | unsigned arraySize = 0; |
1303 | 1315 | QualType arrayElementTy = Ty; |
@@ -6588,6 +6600,54 @@ Scope *CGMSHLSLRuntime::MarkScopeEnd(CodeGenFunction &CGF) { |
6588 | 6600 | return nullptr; |
6589 | 6601 | } |
6590 | 6602 |
|
| 6603 | +static MDNode * |
| 6604 | +createLinAlgMatrixTypeMetadata(LLVMContext &Ctx, |
| 6605 | + const clang::AttributedLinAlgMatrixType *T, |
| 6606 | + llvm::StructType *ST) { |
| 6607 | + auto Createi32MD = [&](int32_t Val) { |
| 6608 | + return ConstantAsMetadata::get( |
| 6609 | + ConstantInt::get(llvm::Type::getInt32Ty(Ctx), Val)); |
| 6610 | + }; |
| 6611 | + |
| 6612 | + return MDTuple::get( |
| 6613 | + Ctx, {ConstantAsMetadata::get(UndefValue::get(ST)), |
| 6614 | + Createi32MD(static_cast<uint32_t>(T->getComponentType())), |
| 6615 | + Createi32MD(T->getRows()), Createi32MD(T->getCols()), |
| 6616 | + Createi32MD(static_cast<uint32_t>(T->getUse())), |
| 6617 | + Createi32MD(static_cast<uint32_t>(T->getScope()))}); |
| 6618 | +} |
| 6619 | + |
| 6620 | +llvm::Type *CGMSHLSLRuntime::ConvertAttributedLinAlgMatrixType( |
| 6621 | + const clang::AttributedLinAlgMatrixType *T) { |
| 6622 | + |
| 6623 | + llvm::LLVMContext &Ctx = CGM.getLLVMContext(); |
| 6624 | + llvm::Type *Int8Ptr = llvm::Type::getInt8PtrTy(Ctx); |
| 6625 | + llvm::Type *StructElemTypes[] = {Int8Ptr}; |
| 6626 | + |
| 6627 | + llvm::SmallString<64> Buf; |
| 6628 | + llvm::raw_svector_ostream OS(Buf); |
| 6629 | + OS << "dx.types.LinAlgMatrix"; |
| 6630 | + T->appendMangledAttributes(OS); |
| 6631 | + StringRef TypeName = OS.str(); |
| 6632 | + |
| 6633 | + llvm::StructType *ST = CGM.getModule().getTypeByName(TypeName); |
| 6634 | + if (ST) { |
| 6635 | + assert(ST->getNumElements() == 1 && ST->getElementType(0) == Int8Ptr && |
| 6636 | + "Unexpected existing dx.types.LinAlgMatrix type"); |
| 6637 | + return ST; |
| 6638 | + } |
| 6639 | + |
| 6640 | + ST = StructType::create(Ctx, StructElemTypes, TypeName); |
| 6641 | + |
| 6642 | + // Add metadata node for the new target type. |
| 6643 | + NamedMDNode *DxTypesMD = |
| 6644 | + CGM.getModule().getOrInsertNamedMetadata("dx.targetTypes"); |
| 6645 | + MDNode *NewTyNode = createLinAlgMatrixTypeMetadata(Ctx, T, ST); |
| 6646 | + DxTypesMD->addOperand(NewTyNode); |
| 6647 | + |
| 6648 | + return ST; |
| 6649 | +} |
| 6650 | + |
6591 | 6651 | CGHLSLRuntime *CodeGen::CreateMSHLSLRuntime(CodeGenModule &CGM) { |
6592 | 6652 | return new CGMSHLSLRuntime(CGM); |
6593 | 6653 | } |
0 commit comments