Skip to content

Commit 697db52

Browse files
committed
graph: interface: op: softmax supports mixed data types
1 parent 333ef80 commit 697db52

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

src/graph/interface/op_def.hpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -1032,12 +1032,15 @@ DNNL_GRAPH_OP_SCHEMA(SoftMax, 1,
10321032
op_schema_t()
10331033
.set_num_inputs(1)
10341034
.set_num_outputs(1)
1035-
.set_input(0, "src", "T")
1036-
.set_output(0, "dst", "T")
1035+
.set_input(0, "src", "T1")
1036+
.set_output(0, "dst", "T2")
10371037
.set_attr(op_attr::axis, false, attribute_kind::i, (int64_t)1)
10381038
.set_type_constraints(
1039-
"T", {data_type::f32, data_type::bf16, data_type::f16})
1040-
.set_shape_inference_function(infer_identity_output_shape))
1039+
"T1", {data_type::f32, data_type::bf16, data_type::f16})
1040+
.set_type_constraints(
1041+
"T2", {data_type::f32, data_type::bf16, data_type::f16})
1042+
.set_shape_inference_function(infer_identity_output_shape)
1043+
.set_op_def_constraint_function(check_softmax_dtype))
10411044

10421045
DNNL_GRAPH_OP_SCHEMA(SoftMaxBackward, 1,
10431046
op_schema_t()

src/graph/interface/op_def_constraint.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,25 @@ bool check_matmul_dtype(const op_t *mm) {
105105
return true;
106106
}
107107

108+
// For SoftMax, if the src is f32, dst can be xf16. Otherwise, src and dst
109+
// should have the same data type.
//
// Op-definition constraint check comparing the data type of the first input
// (src) with that of the first output (dst) of op `n`.
// Accepted combinations, as implemented below:
//   - src and dst have the same data type;
//   - src is f32 and dst is any other data type (this function does not
//     restrict dst further).
// Any other combination (src is not f32 and dst differs from src) goes
// through VCHECK_SHAPE_INFER with a message naming the op kind and both
// data types.
// Returns true when the combination is accepted.
// NOTE(review): VCHECK_SHAPE_INFER is defined elsewhere; presumably it logs
// the message and returns false from this function -- confirm.
// NOTE(review): indexes inputs[0] and outputs[0] without a size check;
// assumes the op has at least one input and one output -- confirm that the
// schema guarantees this before the check runs.
110+
bool check_softmax_dtype(const op_t *n) {
111+
const auto &inputs = n->get_input_values();
112+
const auto &outputs = n->get_output_values();
113+
114+
const logical_tensor_t &src = inputs[0]->get_logical_tensor();
115+
const logical_tensor_t &dst = outputs[0]->get_logical_tensor();
116+
// Only a src/dst mismatch needs checking; equal data types always pass.
if (src.data_type != dst.data_type) {
117+
// A mismatch is tolerated only when src is f32.
if (src.data_type != data_type::f32) {
118+
VCHECK_SHAPE_INFER(false, "%s, %s src + %s dst is not supported",
119+
op_t::kind2str(n->get_kind()).c_str(),
120+
dnnl_dt2str(src.data_type), dnnl_dt2str(dst.data_type));
121+
}
122+
}
123+
124+
return true;
125+
}
126+
108127
// check function for data_type of LayerNorm and GroupNorm.
109128
// only when data is bf16, gamma/beta/mean/var can be bf16.
110129
// If data is bf16, gamma/beta/mean/var can be f32 or bf16.

src/graph/interface/op_def_constraint.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ bool check_bn_data_type(const op_t *n);
3030

3131
bool check_matmul_dtype(const op_t *n);
3232

33+
// Data type check for SoftMax: src and dst must have the same data type
// unless src is f32, in which case dst may differ.
bool check_softmax_dtype(const op_t *n);
34+
3335
bool check_ln_gn_data_type(const op_t *n);
3436

3537
bool check_typecast_data_type(const op_t *n);

0 commit comments

Comments
 (0)