aboutsummaryrefslogtreecommitdiffstats
path: root/lib/Target/NVPTX
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Target/NVPTX')
-rw-r--r--lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp10
-rw-r--r--lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.h5
-rw-r--r--lib/Target/NVPTX/MCTargetDesc/NVPTXMCTargetDesc.cpp7
-rw-r--r--lib/Target/NVPTX/NVPTX.td22
-rw-r--r--lib/Target/NVPTX/NVPTXAsmPrinter.cpp77
-rw-r--r--lib/Target/NVPTX/NVPTXFavorNonGenericAddrSpaces.cpp5
-rw-r--r--lib/Target/NVPTX/NVPTXGenericToNVVM.cpp1
-rw-r--r--lib/Target/NVPTX/NVPTXISelLowering.cpp2
-rw-r--r--lib/Target/NVPTX/NVPTXLowerAggrCopies.cpp14
-rw-r--r--lib/Target/NVPTX/NVPTXTargetMachine.h2
-rw-r--r--lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp70
-rw-r--r--lib/Target/NVPTX/NVPTXTargetTransformInfo.h2
12 files changed, 146 insertions, 71 deletions
diff --git a/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp b/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp
index 80b2f62..ac92df9 100644
--- a/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp
+++ b/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.cpp
@@ -28,13 +28,9 @@ using namespace llvm;
#include "NVPTXGenAsmWriter.inc"
-
NVPTXInstPrinter::NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII,
- const MCRegisterInfo &MRI,
- const MCSubtargetInfo &STI)
- : MCInstPrinter(MAI, MII, MRI) {
- setAvailableFeatures(STI.getFeatureBits());
-}
+ const MCRegisterInfo &MRI)
+ : MCInstPrinter(MAI, MII, MRI) {}
void NVPTXInstPrinter::printRegName(raw_ostream &OS, unsigned RegNo) const {
// Decode the virtual register
@@ -72,7 +68,7 @@ void NVPTXInstPrinter::printRegName(raw_ostream &OS, unsigned RegNo) const {
}
void NVPTXInstPrinter::printInst(const MCInst *MI, raw_ostream &OS,
- StringRef Annot) {
+ StringRef Annot, const MCSubtargetInfo &STI) {
printInstruction(MI, OS);
// Next always print the annotation.
diff --git a/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.h b/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.h
index 0496964..02c5a21 100644
--- a/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.h
+++ b/lib/Target/NVPTX/InstPrinter/NVPTXInstPrinter.h
@@ -25,10 +25,11 @@ class MCSubtargetInfo;
class NVPTXInstPrinter : public MCInstPrinter {
public:
NVPTXInstPrinter(const MCAsmInfo &MAI, const MCInstrInfo &MII,
- const MCRegisterInfo &MRI, const MCSubtargetInfo &STI);
+ const MCRegisterInfo &MRI);
void printRegName(raw_ostream &OS, unsigned RegNo) const override;
- void printInst(const MCInst *MI, raw_ostream &OS, StringRef Annot) override;
+ void printInst(const MCInst *MI, raw_ostream &OS, StringRef Annot,
+ const MCSubtargetInfo &STI) override;
// Autogenerated by tblgen.
void printInstruction(const MCInst *MI, raw_ostream &O);
diff --git a/lib/Target/NVPTX/MCTargetDesc/NVPTXMCTargetDesc.cpp b/lib/Target/NVPTX/MCTargetDesc/NVPTXMCTargetDesc.cpp
index 2b4d864..f9e4324 100644
--- a/lib/Target/NVPTX/MCTargetDesc/NVPTXMCTargetDesc.cpp
+++ b/lib/Target/NVPTX/MCTargetDesc/NVPTXMCTargetDesc.cpp
@@ -58,14 +58,13 @@ static MCCodeGenInfo *createNVPTXMCCodeGenInfo(
return X;
}
-static MCInstPrinter *createNVPTXMCInstPrinter(const Target &T,
+static MCInstPrinter *createNVPTXMCInstPrinter(const Triple &T,
unsigned SyntaxVariant,
const MCAsmInfo &MAI,
const MCInstrInfo &MII,
- const MCRegisterInfo &MRI,
- const MCSubtargetInfo &STI) {
+ const MCRegisterInfo &MRI) {
if (SyntaxVariant == 0)
- return new NVPTXInstPrinter(MAI, MII, MRI, STI);
+ return new NVPTXInstPrinter(MAI, MII, MRI);
return nullptr;
}
diff --git a/lib/Target/NVPTX/NVPTX.td b/lib/Target/NVPTX/NVPTX.td
index 93fabf6..96abfa8 100644
--- a/lib/Target/NVPTX/NVPTX.td
+++ b/lib/Target/NVPTX/NVPTX.td
@@ -32,20 +32,28 @@ def SM21 : SubtargetFeature<"sm_21", "SmVersion", "21",
"Target SM 2.1">;
def SM30 : SubtargetFeature<"sm_30", "SmVersion", "30",
"Target SM 3.0">;
+def SM32 : SubtargetFeature<"sm_32", "SmVersion", "32",
+ "Target SM 3.2">;
def SM35 : SubtargetFeature<"sm_35", "SmVersion", "35",
"Target SM 3.5">;
+def SM37 : SubtargetFeature<"sm_37", "SmVersion", "37",
+ "Target SM 3.7">;
def SM50 : SubtargetFeature<"sm_50", "SmVersion", "50",
"Target SM 5.0">;
+def SM52 : SubtargetFeature<"sm_52", "SmVersion", "52",
+ "Target SM 5.2">;
+def SM53 : SubtargetFeature<"sm_53", "SmVersion", "53",
+ "Target SM 5.3">;
// PTX Versions
-def PTX30 : SubtargetFeature<"ptx30", "PTXVersion", "30",
- "Use PTX version 3.0">;
-def PTX31 : SubtargetFeature<"ptx31", "PTXVersion", "31",
- "Use PTX version 3.1">;
def PTX32 : SubtargetFeature<"ptx32", "PTXVersion", "32",
"Use PTX version 3.2">;
def PTX40 : SubtargetFeature<"ptx40", "PTXVersion", "40",
"Use PTX version 4.0">;
+def PTX41 : SubtargetFeature<"ptx41", "PTXVersion", "41",
+ "Use PTX version 4.1">;
+def PTX42 : SubtargetFeature<"ptx42", "PTXVersion", "42",
+ "Use PTX version 4.2">;
//===----------------------------------------------------------------------===//
// NVPTX supported processors.
@@ -57,8 +65,12 @@ class Proc<string Name, list<SubtargetFeature> Features>
def : Proc<"sm_20", [SM20]>;
def : Proc<"sm_21", [SM21]>;
def : Proc<"sm_30", [SM30]>;
+def : Proc<"sm_32", [SM32, PTX40]>;
def : Proc<"sm_35", [SM35]>;
-def : Proc<"sm_50", [SM50]>;
+def : Proc<"sm_37", [SM37, PTX41]>;
+def : Proc<"sm_50", [SM50, PTX40]>;
+def : Proc<"sm_52", [SM52, PTX41]>;
+def : Proc<"sm_53", [SM53, PTX42]>;
def NVPTXInstrInfo : InstrInfo {
diff --git a/lib/Target/NVPTX/NVPTXAsmPrinter.cpp b/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
index cc58b07..9a71964 100644
--- a/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ b/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -118,7 +118,7 @@ void NVPTXAsmPrinter::emitLineNumberAsDotLoc(const MachineInstr &MI) {
DebugLoc curLoc = MI.getDebugLoc();
- if (prevDebugLoc.isUnknown() && curLoc.isUnknown())
+ if (!prevDebugLoc && !curLoc)
return;
if (prevDebugLoc == curLoc)
@@ -126,39 +126,32 @@ void NVPTXAsmPrinter::emitLineNumberAsDotLoc(const MachineInstr &MI) {
prevDebugLoc = curLoc;
- if (curLoc.isUnknown())
+ if (!curLoc)
return;
- const MachineFunction *MF = MI.getParent()->getParent();
- //const TargetMachine &TM = MF->getTarget();
-
- const LLVMContext &ctx = MF->getFunction()->getContext();
- DIScope Scope(curLoc.getScope(ctx));
-
- assert((!Scope || Scope.isScope()) &&
- "Scope of a DebugLoc should be null or a DIScope.");
+ auto *Scope = cast_or_null<MDScope>(curLoc.getScope());
if (!Scope)
return;
- StringRef fileName(Scope.getFilename());
- StringRef dirName(Scope.getDirectory());
+ StringRef fileName(Scope->getFilename());
+ StringRef dirName(Scope->getDirectory());
SmallString<128> FullPathName = dirName;
if (!dirName.empty() && !sys::path::is_absolute(fileName)) {
sys::path::append(FullPathName, fileName);
- fileName = FullPathName.str();
+ fileName = FullPathName;
}
- if (filenameMap.find(fileName.str()) == filenameMap.end())
+ if (filenameMap.find(fileName) == filenameMap.end())
return;
// Emit the line from the source file.
if (InterleaveSrc)
- this->emitSrcInText(fileName.str(), curLoc.getLine());
+ this->emitSrcInText(fileName, curLoc.getLine());
std::stringstream temp;
- temp << "\t.loc " << filenameMap[fileName.str()] << " " << curLoc.getLine()
+ temp << "\t.loc " << filenameMap[fileName] << " " << curLoc.getLine()
<< " " << curLoc.getCol();
- OutStreamer.EmitRawText(Twine(temp.str().c_str()));
+ OutStreamer.EmitRawText(temp.str());
}
void NVPTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
@@ -641,7 +634,7 @@ static bool usedInGlobalVarDef(const Constant *C) {
return false;
if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
- if (GV->getName().str() == "llvm.used")
+ if (GV->getName() == "llvm.used")
return false;
return true;
}
@@ -656,7 +649,7 @@ static bool usedInGlobalVarDef(const Constant *C) {
static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
- if (othergv->getName().str() == "llvm.used")
+ if (othergv->getName() == "llvm.used")
return true;
}
@@ -780,32 +773,32 @@ void NVPTXAsmPrinter::recordAndEmitFilenames(Module &M) {
DbgFinder.processModule(M);
unsigned i = 1;
- for (DICompileUnit DIUnit : DbgFinder.compile_units()) {
- StringRef Filename(DIUnit.getFilename());
- StringRef Dirname(DIUnit.getDirectory());
+ for (const MDCompileUnit *DIUnit : DbgFinder.compile_units()) {
+ StringRef Filename = DIUnit->getFilename();
+ StringRef Dirname = DIUnit->getDirectory();
SmallString<128> FullPathName = Dirname;
if (!Dirname.empty() && !sys::path::is_absolute(Filename)) {
sys::path::append(FullPathName, Filename);
- Filename = FullPathName.str();
+ Filename = FullPathName;
}
- if (filenameMap.find(Filename.str()) != filenameMap.end())
+ if (filenameMap.find(Filename) != filenameMap.end())
continue;
- filenameMap[Filename.str()] = i;
- OutStreamer.EmitDwarfFileDirective(i, "", Filename.str());
+ filenameMap[Filename] = i;
+ OutStreamer.EmitDwarfFileDirective(i, "", Filename);
++i;
}
- for (DISubprogram SP : DbgFinder.subprograms()) {
- StringRef Filename(SP.getFilename());
- StringRef Dirname(SP.getDirectory());
+ for (MDSubprogram *SP : DbgFinder.subprograms()) {
+ StringRef Filename = SP->getFilename();
+ StringRef Dirname = SP->getDirectory();
SmallString<128> FullPathName = Dirname;
if (!Dirname.empty() && !sys::path::is_absolute(Filename)) {
sys::path::append(FullPathName, Filename);
- Filename = FullPathName.str();
+ Filename = FullPathName;
}
- if (filenameMap.find(Filename.str()) != filenameMap.end())
+ if (filenameMap.find(Filename) != filenameMap.end())
continue;
- filenameMap[Filename.str()] = i;
+ filenameMap[Filename] = i;
++i;
}
}
@@ -1011,7 +1004,7 @@ void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
msg.append("Error: ");
msg.append("Symbol ");
if (V->hasName())
- msg.append(V->getName().str());
+ msg.append(V->getName());
msg.append("has unsupported appending linkage type");
llvm_unreachable(msg.c_str());
} else if (!V->hasInternalLinkage() &&
@@ -1147,7 +1140,7 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
const Function *demotedFunc = nullptr;
if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
- O << "// " << GVar->getName().str() << " has been demoted\n";
+ O << "// " << GVar->getName() << " has been demoted\n";
if (localDecls.find(demotedFunc) != localDecls.end())
localDecls[demotedFunc].push_back(GVar);
else {
@@ -1195,9 +1188,10 @@ void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
// The frontend adds zero-initializer to variables that don't have an
// initial value, so skip warning for this case.
if (!GVar->getInitializer()->isNullValue()) {
- std::string warnMsg = "initial value of '" + GVar->getName().str() +
- "' is not allowed in addrspace(" +
- llvm::utostr_32(PTy->getAddressSpace()) + ")";
+ std::string warnMsg =
+ ("initial value of '" + GVar->getName() +
+ "' is not allowed in addrspace(" +
+ Twine(llvm::utostr_32(PTy->getAddressSpace())) + ")").str();
report_fatal_error(warnMsg.c_str());
}
}
@@ -1771,12 +1765,11 @@ void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
case Type::IntegerTyID: {
const Type *ETy = CPV->getType();
if (ETy == Type::getInt8Ty(CPV->getContext())) {
- unsigned char c =
- (unsigned char)(dyn_cast<ConstantInt>(CPV))->getZExtValue();
+ unsigned char c = (unsigned char)cast<ConstantInt>(CPV)->getZExtValue();
ptr = &c;
aggBuffer->addBytes(ptr, 1, Bytes);
} else if (ETy == Type::getInt16Ty(CPV->getContext())) {
- short int16 = (short)(dyn_cast<ConstantInt>(CPV))->getZExtValue();
+ short int16 = (short)cast<ConstantInt>(CPV)->getZExtValue();
ptr = (unsigned char *)&int16;
aggBuffer->addBytes(ptr, 2, Bytes);
} else if (ETy == Type::getInt32Ty(CPV->getContext())) {
@@ -2086,7 +2079,7 @@ void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
void NVPTXAsmPrinter::emitSrcInText(StringRef filename, unsigned line) {
std::stringstream temp;
- LineReader *reader = this->getReader(filename.str());
+ LineReader *reader = this->getReader(filename);
temp << "\n//";
temp << filename.str();
temp << ":";
@@ -2094,7 +2087,7 @@ void NVPTXAsmPrinter::emitSrcInText(StringRef filename, unsigned line) {
temp << " ";
temp << reader->readLine(line);
temp << "\n";
- this->OutStreamer.EmitRawText(Twine(temp.str()));
+ this->OutStreamer.EmitRawText(temp.str());
}
LineReader *NVPTXAsmPrinter::getReader(std::string filename) {
diff --git a/lib/Target/NVPTX/NVPTXFavorNonGenericAddrSpaces.cpp b/lib/Target/NVPTX/NVPTXFavorNonGenericAddrSpaces.cpp
index 6d7c99c..ae63cae 100644
--- a/lib/Target/NVPTX/NVPTXFavorNonGenericAddrSpaces.cpp
+++ b/lib/Target/NVPTX/NVPTXFavorNonGenericAddrSpaces.cpp
@@ -132,9 +132,8 @@ bool NVPTXFavorNonGenericAddrSpaces::hoistAddrSpaceCastFromGEP(
} else {
// GEP is a constant expression.
Constant *NewGEPCE = ConstantExpr::getGetElementPtr(
- cast<Constant>(Cast->getOperand(0)),
- Indices,
- GEP->isInBounds());
+ GEP->getSourceElementType(), cast<Constant>(Cast->getOperand(0)),
+ Indices, GEP->isInBounds());
GEP->replaceAllUsesWith(
ConstantExpr::getAddrSpaceCast(NewGEPCE, GEP->getType()));
}
diff --git a/lib/Target/NVPTX/NVPTXGenericToNVVM.cpp b/lib/Target/NVPTX/NVPTXGenericToNVVM.cpp
index 850c020..6fd09c4 100644
--- a/lib/Target/NVPTX/NVPTXGenericToNVVM.cpp
+++ b/lib/Target/NVPTX/NVPTXGenericToNVVM.cpp
@@ -347,6 +347,7 @@ Value *GenericToNVVM::remapConstantExpr(Module *M, Function *F, ConstantExpr *C,
NewOperands[0],
makeArrayRef(&NewOperands[1], NumOperands - 1))
: Builder.CreateInBoundsGEP(
+ cast<GEPOperator>(C)->getSourceElementType(),
NewOperands[0],
makeArrayRef(&NewOperands[1], NumOperands - 1));
case Instruction::Select:
diff --git a/lib/Target/NVPTX/NVPTXISelLowering.cpp b/lib/Target/NVPTX/NVPTXISelLowering.cpp
index ff74e6e..8b06657 100644
--- a/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3893,7 +3893,7 @@ static SDValue PerformADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
const SDNode *left = N0.getOperand(0).getNode();
const SDNode *right = N0.getOperand(1).getNode();
- if (dyn_cast<ConstantSDNode>(left) || dyn_cast<ConstantSDNode>(right))
+ if (isa<ConstantSDNode>(left) || isa<ConstantSDNode>(right))
opIsLive = true;
if (!opIsLive)
diff --git a/lib/Target/NVPTX/NVPTXLowerAggrCopies.cpp b/lib/Target/NVPTX/NVPTXLowerAggrCopies.cpp
index 578401a..6ab0fad 100644
--- a/lib/Target/NVPTX/NVPTXLowerAggrCopies.cpp
+++ b/lib/Target/NVPTX/NVPTXLowerAggrCopies.cpp
@@ -70,8 +70,8 @@ static void convertTransferToLoop(
// srcAddr and dstAddr are expected to be pointer types,
// so no check is made here.
- unsigned srcAS = dyn_cast<PointerType>(srcAddr->getType())->getAddressSpace();
- unsigned dstAS = dyn_cast<PointerType>(dstAddr->getType())->getAddressSpace();
+ unsigned srcAS = cast<PointerType>(srcAddr->getType())->getAddressSpace();
+ unsigned dstAS = cast<PointerType>(dstAddr->getType())->getAddressSpace();
// Cast pointers to (char *)
srcAddr = builder.CreateBitCast(srcAddr, Type::getInt8PtrTy(Context, srcAS));
@@ -84,9 +84,11 @@ static void convertTransferToLoop(
ind->addIncoming(ConstantInt::get(indType, 0), origBB);
// load from srcAddr+ind
- Value *val = loop.CreateLoad(loop.CreateGEP(srcAddr, ind), srcVolatile);
+ Value *val = loop.CreateLoad(loop.CreateGEP(loop.getInt8Ty(), srcAddr, ind),
+ srcVolatile);
// store at dstAddr+ind
- loop.CreateStore(val, loop.CreateGEP(dstAddr, ind), dstVolatile);
+ loop.CreateStore(val, loop.CreateGEP(loop.getInt8Ty(), dstAddr, ind),
+ dstVolatile);
// The value for ind coming from backedge is (ind + 1)
Value *newind = loop.CreateAdd(ind, ConstantInt::get(indType, 1));
@@ -106,7 +108,7 @@ static void convertMemSetToLoop(Instruction *splitAt, Value *dstAddr,
origBB->getTerminator()->setSuccessor(0, loopBB);
IRBuilder<> builder(origBB, origBB->getTerminator());
- unsigned dstAS = dyn_cast<PointerType>(dstAddr->getType())->getAddressSpace();
+ unsigned dstAS = cast<PointerType>(dstAddr->getType())->getAddressSpace();
// Cast pointer to the type of value getting stored
dstAddr =
@@ -116,7 +118,7 @@ static void convertMemSetToLoop(Instruction *splitAt, Value *dstAddr,
PHINode *ind = loop.CreatePHI(len->getType(), 0);
ind->addIncoming(ConstantInt::get(len->getType(), 0), origBB);
- loop.CreateStore(val, loop.CreateGEP(dstAddr, ind), false);
+ loop.CreateStore(val, loop.CreateGEP(val->getType(), dstAddr, ind), false);
Value *newind = loop.CreateAdd(ind, ConstantInt::get(len->getType(), 1));
ind->addIncoming(newind, loopBB);
diff --git a/lib/Target/NVPTX/NVPTXTargetMachine.h b/lib/Target/NVPTX/NVPTXTargetMachine.h
index b8df5af..2cd10e8 100644
--- a/lib/Target/NVPTX/NVPTXTargetMachine.h
+++ b/lib/Target/NVPTX/NVPTXTargetMachine.h
@@ -52,7 +52,7 @@ public:
TargetPassConfig *createPassConfig(PassManagerBase &PM) override;
// Emission of machine code through MCJIT is not supported.
- bool addPassesToEmitMC(PassManagerBase &, MCContext *&, raw_ostream &,
+ bool addPassesToEmitMC(PassManagerBase &, MCContext *&, raw_pwrite_stream &,
bool = true) override {
return true;
}
diff --git a/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index b8af04d..dc81802 100644
--- a/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -8,6 +8,7 @@
//===----------------------------------------------------------------------===//
#include "NVPTXTargetTransformInfo.h"
+#include "NVPTXUtilities.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
@@ -19,6 +20,75 @@ using namespace llvm;
#define DEBUG_TYPE "NVPTXtti"
+// Whether the given intrinsic reads threadIdx.x/y/z.
+static bool readsThreadIndex(const IntrinsicInst *II) {
+ switch (II->getIntrinsicID()) {
+ default: return false;
+ case Intrinsic::nvvm_read_ptx_sreg_tid_x:
+ case Intrinsic::nvvm_read_ptx_sreg_tid_y:
+ case Intrinsic::nvvm_read_ptx_sreg_tid_z:
+ return true;
+ }
+}
+
+static bool readsLaneId(const IntrinsicInst *II) {
+ return II->getIntrinsicID() == Intrinsic::ptx_read_laneid;
+}
+
+// Whether the given intrinsic is an atomic instruction in PTX.
+static bool isNVVMAtomic(const IntrinsicInst *II) {
+ switch (II->getIntrinsicID()) {
+ default: return false;
+ case Intrinsic::nvvm_atomic_load_add_f32:
+ case Intrinsic::nvvm_atomic_load_inc_32:
+ case Intrinsic::nvvm_atomic_load_dec_32:
+ return true;
+ }
+}
+
+bool NVPTXTTIImpl::isSourceOfDivergence(const Value *V) {
+ // Without inter-procedural analysis, we conservatively assume that arguments
+ // to __device__ functions are divergent.
+ if (const Argument *Arg = dyn_cast<Argument>(V))
+ return !isKernelFunction(*Arg->getParent());
+
+ if (const Instruction *I = dyn_cast<Instruction>(V)) {
+ // Without pointer analysis, we conservatively assume values loaded from
+ // generic or local address space are divergent.
+ if (const LoadInst *LI = dyn_cast<LoadInst>(I)) {
+ unsigned AS = LI->getPointerAddressSpace();
+ return AS == ADDRESS_SPACE_GENERIC || AS == ADDRESS_SPACE_LOCAL;
+ }
+ // Atomic instructions may cause divergence. Atomic instructions are
+ // executed sequentially across all threads in a warp. Therefore, an earlier
+ // executed thread may see different memory inputs than a later executed
+ // thread. For example, suppose *a = 0 initially.
+ //
+ // atom.global.add.s32 d, [a], 1
+ //
+ // returns 0 for the first thread that enters the critical region, and 1 for
+ // the second thread.
+ if (I->isAtomic())
+ return true;
+ if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+ // Instructions that read threadIdx are obviously divergent.
+ if (readsThreadIndex(II) || readsLaneId(II))
+ return true;
+ // Handle the NVPTX atomic instrinsics that cannot be represented as an
+ // atomic IR instruction.
+ if (isNVVMAtomic(II))
+ return true;
+ }
+ // Conservatively consider the return value of function calls as divergent.
+ // We could analyze callees with bodies more precisely using
+ // inter-procedural analysis.
+ if (isa<CallInst>(I))
+ return true;
+ }
+
+ return false;
+}
+
unsigned NVPTXTTIImpl::getArithmeticInstrCost(
unsigned Opcode, Type *Ty, TTI::OperandValueKind Opd1Info,
TTI::OperandValueKind Opd2Info, TTI::OperandValueProperties Opd1PropInfo,
diff --git a/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index bf21e88..4280888 100644
--- a/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -61,6 +61,8 @@ public:
bool hasBranchDivergence() { return true; }
+ bool isSourceOfDivergence(const Value *V);
+
unsigned getArithmeticInstrCost(
unsigned Opcode, Type *Ty,
TTI::OperandValueKind Opd1Info = TTI::OK_AnyValue,