31 auto major_version = cudnnGetVersion() / 1000;
33 auto minor_version = (cudnnGetVersion() / 100) % 10;
34 if (major_version >= 8) {
35 if (minor_version <= 2) {
40 if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) {
41 if (opGraphTag.find(
"bias") == std::string::npos) {
42 std::vector<int> engine_list(50);
43 std::iota(engine_list.begin(), engine_list.end(), 0);
48 }
else if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
49 std::vector<int> engine_list(61);
50 std::iota(engine_list.begin(), engine_list.end(), 0);
52 }
else if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) {
58 if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR) {
59 if (opGraphTag.find(
"bias") == std::string::npos) {
64 }
else if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR) {
66 }
else if (mode == CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR) {
84 ss <<
"CUDNN_BACKEND_FALLBACK ENGINES :";
111 cudnnBackendDescriptorType_t
mode;
129 m_fallback_list.opGraph = opGraph_.get_desc();
130 m_fallback_list.opGraphTag = opGraph_.getTag();
131 m_fallback_list.num_ops = opGraph_.getOpCount();
136 m_fallback_list.mode =
mode;
145 if (m_fallback_list.opGraph ==
nullptr) {
147 CUDNN_STATUS_BAD_PARAM,
148 "CUDNN_BACKEND_ENGINEHEUR_DESCRIPTOR: Check and Set the " 149 "CUDNN_ATTR_ENGINEHEUR_OPERATION_GRAPH field for heuristic");
150 return std::move(m_fallback_list);
153 for (std::uint32_t i = 0; i < fallback_engine_list.size(); i++) {
154 #ifndef NV_CUDNN_DISABLE_EXCEPTION 159 .setOperationGraph(m_fallback_list.opGraph)
162 m_fallback_list.m_engine_configs.emplace_back(engine_config.get_desc());
163 #ifndef NV_CUDNN_DISABLE_EXCEPTION 169 getLogger() <<
"[cudnn_frontend] " << m_fallback_list << std::endl;
170 return std::move(m_fallback_list);
static auto get_fallback_engine_list(cudnnBackendDescriptorType_t mode, const std::string &opGraphTag) -> std::vector< int >
ConditionalStreamer & getLogger()
static void set_error_and_throw_exception(BackendDescriptor const *desc, cudnnStatus_t status, const char *message)
auto setGlobalEngineIdx(int64_t idx_) -> EngineBuilder_v8 &
Set engine index for the engine.
std::string describe() const override
Return a string describing the backend Descriptor.
EngineFallbackList_v8(EngineFallbackList_v8 &&from)
cudnnBackendDescriptorType_t mode
ManagedOpaqueDescriptor get_desc() const
Returns a copy of the underlying managed descriptor.
ManagedOpaqueDescriptor opGraph
auto setEngine(Engine_v8 const &engine_) -> EngineConfigBuilder_v8 &
Set engine for the EngineConfig_v8.
auto setOperation(cudnnBackendDescriptorType_t mode) -> EngineFallbackListBuilder_v8 &
cudnnStatus_t get_status() const
Current status of the descriptor.
~EngineFallbackList_v8()=default
friend class EngineFallbackListBuilder_v8
EngineFallbackList_v8()=default
std::shared_ptr< OpaqueBackendPointer > ManagedOpaqueDescriptor
const char * get_error() const
Diagnostic error message, if any.
auto setOperationGraph(OperationGraph_v8 &opGraph_) -> EngineFallbackListBuilder_v8 &
Set operationGraph for the engine (opGraph is not destroyed)
EngineFallbackList_v8 && build()
EngineFallbackList_v8 m_fallback_list
EngineFallbackList_v8 & operator=(EngineFallbackList_v8 const &)=delete
auto getFallbackList() -> std::vector< ManagedOpaqueDescriptor > &
std::vector< ManagedOpaqueDescriptor > m_engine_configs