@@ -131,12 +131,19 @@ void BLASKernelGenerator<hw>::binaryOp(BinaryOp op, int simd, const RegData &dst
131
131
132
132
// Apply binary operation to C with a scalar operand.
133
133
template <HW hw>
134
- void BLASKernelGenerator<hw>::gemmScalarBinaryOpC(BinaryOp op, const Subregister &offset ,
135
- const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state)
134
+ void BLASKernelGenerator<hw>::gemmScalarBinaryOpC(BinaryOp op, const GRFMultirange &offsets ,
135
+ const GEMMProblem &problem, const GEMMStrategy &strategy, GEMMState &state, Type Tco )
136
136
{
137
- auto offsetTc = offset.reinterpret (0 , state.Tacc .ngen ());
138
- if (offset != offsetTc)
139
- emov (1 , offsetTc, offset, strategy, state);
137
+ auto subOff = offsets[0 ].sub (0 , Tco.ngen ());
138
+ auto Tacc = state.Tacc ;
139
+ auto offsetTc = subOff.reinterpret (0 , Tacc.ngen ());
140
+ if (subOff != offsetTc && !one_of (Tco, Type::f8_e8m0, Type::hf8)){
141
+ emov (1 , offsetTc, subOff, strategy, state);
142
+ } else {
143
+ vector<RegisterBlock> repackLayout;
144
+ makeUnbackedRegLayout (Tacc, repackLayout, 1 , 1 , false );
145
+ copyRegisters (Tco, Tacc, repackLayout, repackLayout, offsets, offsets, strategy, state);
146
+ }
140
147
if (op == BinaryOp::Div && one_of (state.Tacc , Type::f32, Type::f16)) {
141
148
inv (1 , offsetTc, offsetTc);
142
149
op = BinaryOp::Mul;
@@ -322,7 +329,7 @@ bool BLASKernelGenerator<hw>::gemmBinaryOpC(BinaryOp op, bool row, bool column,
322
329
});
323
330
324
331
if (!row && !column)
325
- gemmScalarBinaryOpC (op, CO_regs[ 0 ]. sub ( 0 , Tco. ngen ()), problem, strategy, state);
332
+ gemmScalarBinaryOpC (op, CO_regs, problem, strategy, state, Tco );
326
333
else
327
334
gemmVectorBinaryOpC (op, column, CO_regs, Subregister (), problem, strategy, state, Tco, CO_layout);
328
335
}
0 commit comments