34 #include <cudnn_backend.h> 62 ss <<
"CUDNN_BACKEND_POINTWISE_DESCRIPTOR :" 70 case CUDNN_POINTWISE_ADD:
71 case CUDNN_POINTWISE_MUL:
72 #if (CUDNN_VERSION >= 8300) 73 case CUDNN_POINTWISE_DIV:
74 case CUDNN_POINTWISE_ADD_SQUARE:
75 case CUDNN_POINTWISE_SUB:
76 case CUDNN_POINTWISE_CMP_EQ:
77 case CUDNN_POINTWISE_CMP_NEQ:
78 case CUDNN_POINTWISE_CMP_GT:
79 case CUDNN_POINTWISE_CMP_GE:
80 case CUDNN_POINTWISE_CMP_LT:
81 case CUDNN_POINTWISE_CMP_LE:
82 case CUDNN_POINTWISE_LOGICAL_AND:
83 case CUDNN_POINTWISE_LOGICAL_OR:
85 case CUDNN_POINTWISE_MIN:
86 case CUDNN_POINTWISE_MAX:
87 case CUDNN_POINTWISE_RELU_BWD:
88 case CUDNN_POINTWISE_TANH_BWD:
89 case CUDNN_POINTWISE_SIGMOID_BWD:
90 case CUDNN_POINTWISE_ELU_BWD:
91 case CUDNN_POINTWISE_GELU_BWD:
92 case CUDNN_POINTWISE_SOFTPLUS_BWD:
93 case CUDNN_POINTWISE_SWISH_BWD:
95 case CUDNN_POINTWISE_SQRT:
96 case CUDNN_POINTWISE_RELU_FWD:
97 case CUDNN_POINTWISE_TANH_FWD:
98 case CUDNN_POINTWISE_SIGMOID_FWD:
99 case CUDNN_POINTWISE_ELU_FWD:
100 case CUDNN_POINTWISE_GELU_FWD:
101 case CUDNN_POINTWISE_SOFTPLUS_FWD:
102 case CUDNN_POINTWISE_SWISH_FWD:
103 #if (CUDNN_VERSION >= 8300) 104 case CUDNN_POINTWISE_EXP:
105 case CUDNN_POINTWISE_LOG:
106 case CUDNN_POINTWISE_NEG:
107 case CUDNN_POINTWISE_MOD:
108 case CUDNN_POINTWISE_POW:
109 case CUDNN_POINTWISE_ABS:
110 case CUDNN_POINTWISE_CEIL:
111 case CUDNN_POINTWISE_FLOOR:
112 case CUDNN_POINTWISE_COS:
113 case CUDNN_POINTWISE_TAN:
114 case CUDNN_POINTWISE_SIN:
115 case CUDNN_POINTWISE_RSQRT:
116 case CUDNN_POINTWISE_LOGICAL_NOT:
140 cudnnPointwiseMode_t
mode = CUDNN_POINTWISE_ADD;
162 m_pointWiseDesc.math_precision = data_type_;
168 m_pointWiseDesc.upper_clip = u;
169 m_pointWiseDesc.lower_clip = l;
175 m_pointWiseDesc.mode = mode_;
181 m_pointWiseDesc.nan_propagation = nan_mode_;
188 m_pointWiseDesc.lower_clip = lower_clip_;
194 m_pointWiseDesc.upper_clip = upper_clip_;
200 m_pointWiseDesc.lower_clip_slope = lower_clip_slope_;
206 m_pointWiseDesc.elu_alpha = elu_alpha_;
212 m_pointWiseDesc.softplus_beta = softplus_beta_;
218 m_pointWiseDesc.swish_beta = swish_beta_;
227 auto status = m_pointWiseDesc.initialize_managed_backend_pointer(CUDNN_BACKEND_POINTWISE_DESCRIPTOR);
228 if (
status != CUDNN_STATUS_SUCCESS) {
230 &m_pointWiseDesc,
status,
"CUDNN_BACKEND_POINTWISE_DESCRIPTOR: cudnnCreate Failed");
231 return std::move(m_pointWiseDesc);
235 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
236 CUDNN_ATTR_POINTWISE_MODE,
237 CUDNN_TYPE_POINTWISE_MODE,
239 &m_pointWiseDesc.mode);
240 if (
status != CUDNN_STATUS_SUCCESS) {
244 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: CUDNN_TYPE_POINTWISE_MODE SetAttribute Failed");
245 return std::move(m_pointWiseDesc);
248 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
249 CUDNN_ATTR_POINTWISE_MATH_PREC,
250 CUDNN_TYPE_DATA_TYPE,
252 &m_pointWiseDesc.math_precision);
253 if (
status != CUDNN_STATUS_SUCCESS) {
257 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_MATH_PREC Failed");
258 return std::move(m_pointWiseDesc);
261 if (m_pointWiseDesc.mode == CUDNN_POINTWISE_RELU_FWD || m_pointWiseDesc.mode == CUDNN_POINTWISE_RELU_BWD) {
262 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
263 CUDNN_ATTR_POINTWISE_NAN_PROPAGATION,
264 CUDNN_TYPE_NAN_PROPOGATION,
266 &m_pointWiseDesc.nan_propagation);
267 if (
status != CUDNN_STATUS_SUCCESS) {
271 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_NAN_PROPAGATION Failed");
272 return std::move(m_pointWiseDesc);
275 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
276 CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP,
279 &m_pointWiseDesc.lower_clip);
280 if (
status != CUDNN_STATUS_SUCCESS) {
284 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP, Failed");
285 return std::move(m_pointWiseDesc);
288 if (m_pointWiseDesc.math_precision == CUDNN_DATA_FLOAT) {
289 double clamped_upper_clip =
290 std::min<double>(m_pointWiseDesc.upper_clip, std::numeric_limits<float>::max());
291 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
292 CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP,
295 &clamped_upper_clip);
298 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
299 CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP,
302 &m_pointWiseDesc.upper_clip);
304 if (
status != CUDNN_STATUS_SUCCESS) {
308 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_RELU_UPPER_CLIP, Failed");
309 return std::move(m_pointWiseDesc);
312 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
313 CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE,
316 &m_pointWiseDesc.lower_clip_slope);
317 if (
status != CUDNN_STATUS_SUCCESS) {
320 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute " 321 "CUDNN_ATTR_POINTWISE_RELU_LOWER_CLIP_SLOPE, Failed");
322 return std::move(m_pointWiseDesc);
324 }
else if (m_pointWiseDesc.mode == CUDNN_POINTWISE_ELU_FWD || m_pointWiseDesc.mode == CUDNN_POINTWISE_ELU_BWD) {
325 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
326 CUDNN_ATTR_POINTWISE_ELU_ALPHA,
329 &m_pointWiseDesc.elu_alpha);
330 if (
status != CUDNN_STATUS_SUCCESS) {
334 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_ELU_ALPHA, Failed");
335 return std::move(m_pointWiseDesc);
337 }
else if (m_pointWiseDesc.mode == CUDNN_POINTWISE_SOFTPLUS_FWD ||
338 m_pointWiseDesc.mode == CUDNN_POINTWISE_SOFTPLUS_BWD) {
339 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
340 CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA,
343 &m_pointWiseDesc.softplus_beta);
344 if (
status != CUDNN_STATUS_SUCCESS) {
348 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_SOFTPLUS_BETA, Failed");
349 return std::move(m_pointWiseDesc);
351 }
else if (m_pointWiseDesc.mode == CUDNN_POINTWISE_SWISH_FWD ||
352 m_pointWiseDesc.mode == CUDNN_POINTWISE_SWISH_BWD) {
353 status = cudnnBackendSetAttribute(m_pointWiseDesc.pointer->get_backend_descriptor(),
354 CUDNN_ATTR_POINTWISE_SWISH_BETA,
357 &m_pointWiseDesc.swish_beta);
358 if (
status != CUDNN_STATUS_SUCCESS) {
362 "CUDNN_BACKEND_POINTWISE_DESCRIPTOR: SetAttribute CUDNN_ATTR_POINTWISE_SWISH_BETA, Failed");
363 return std::move(m_pointWiseDesc);
368 status = cudnnBackendFinalize(m_pointWiseDesc.pointer->get_backend_descriptor());
369 if (
status != CUDNN_STATUS_SUCCESS) {
371 &m_pointWiseDesc,
status,
"CUDNN_BACKEND_POINTWISE_DESCRIPTOR: cudnnFinalize Failed");
372 return std::move(m_pointWiseDesc);
375 getLogger() <<
"[cudnn_frontend] " << m_pointWiseDesc << std::endl;
376 return std::move(m_pointWiseDesc);
PointWiseDesc_v8()=default
ConditionalStreamer & getLogger()
static void set_error_and_throw_exception(BackendDescriptor const *desc, cudnnStatus_t status, const char *message)
auto setClipping(double l, double u) -> PointWiseDescBuilder_v8 &
Set upper and lower limits for the RELU activation.
auto setMode(cudnnNanPropagation_t nan_mode_) -> PointWiseDescBuilder_v8 &
Set NaN propagation mode.
auto setSwishBeta(double swish_beta_) -> PointWiseDescBuilder_v8 &
cudnnPointwiseMode_t getPointWiseMode() const
cudnnNanPropagation_t nan_propagation
~PointWiseDesc_v8()=default
auto setReluLowerClip(double lower_clip_) -> PointWiseDescBuilder_v8 &
auto setSoftplusBeta(double softplus_beta_) -> PointWiseDescBuilder_v8 &
friend class PointWiseDescBuilder_v8
std::string describe() const override
Return a string describing the backend Descriptor.
auto setReluLowerClipSlope(double lower_clip_slope_) -> PointWiseDescBuilder_v8 &
int64_t getPortCount() const
PointWiseDesc_v8 m_pointWiseDesc
cudnnDataType_t math_precision
auto setMathPrecision(cudnnDataType_t data_type_) -> PointWiseDescBuilder_v8 &
Set Math Precision Data Type for the Pointwise Operation.
auto setMode(cudnnPointwiseMode_t mode_) -> PointWiseDescBuilder_v8 &
Set the pointwise operation mode (e.g. ADD, MUL, RELU_FWD).
auto setEluAlpha(double elu_alpha_) -> PointWiseDescBuilder_v8 &
cudnnPointwiseMode_t mode
PointWiseDesc_v8 & operator=(PointWiseDesc_v8 &&from)=default
auto setReluUpperClip(double upper_clip_) -> PointWiseDescBuilder_v8 &
cudnnStatus_t status
Shared pointer of the OpaqueBackendPointer.
PointWiseDesc_v8 && build()