@@ -225,3 +225,50 @@ TEST_F(TransformationTestsF, ConvertGatherToCompressedMultiOutput) {
225
225
model_ref = std::make_shared<ov::Model>(ov::NodeVector{gather_compressed}, ov::ParameterVector{input1, input2});
226
226
}
227
227
}
228
+
229
+ // In compressed FP16/BF16 weight case, gather node with constant weight decompression pattern (FP16/BF16 +
230
+ // convert(FP32)) is transformed to gather node with compressed (FP16/BF16) weights, and decompression convert is moved
231
+ // after gather node, so GatherCompressed node should not be generated.
232
+ TEST_F (TransformationTestsF, MoveDecompressionAfterGatherFP16Weight) {
233
+ {
234
+ auto input1 = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{-1 , 16 });
235
+ auto axis_const = ov::op::v0::Constant::create (ov::element::i32, ov::Shape{1 }, {1 });
236
+ auto weights_const = ov::op::v0::Constant::create (ov::element::f16, ov::Shape{32 , 16 }, {1 });
237
+ auto convert = std::make_shared<ov::op::v0::Convert>(weights_const, ov::element::f32);
238
+ auto gather = std::make_shared<ov::op::v8::Gather>(convert, input1, axis_const);
239
+
240
+ model = std::make_shared<ov::Model>(ov::NodeVector{gather}, ov::ParameterVector{input1});
241
+ manager.register_pass <MoveDecompressionAfterGather>();
242
+ }
243
+ {
244
+ auto input1 = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{-1 , 16 });
245
+ auto axis_const = ov::op::v0::Constant::create (ov::element::i32, ov::Shape{1 }, {1 });
246
+ auto weights_const = ov::op::v0::Constant::create (ov::element::f16, ov::Shape{32 , 16 }, {1 });
247
+ auto gather = std::make_shared<ov::op::v8::Gather>(weights_const, input1, axis_const);
248
+ auto convert = std::make_shared<ov::op::v0::Convert>(gather, ov::element::f32);
249
+
250
+ model_ref = std::make_shared<ov::Model>(ov::NodeVector{convert}, ov::ParameterVector{input1});
251
+ }
252
+ }
253
+
254
+ TEST_F (TransformationTestsF, MoveDecompressionAfterGatherBF16Weight) {
255
+ {
256
+ auto input1 = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{-1 , 16 });
257
+ auto axis_const = ov::op::v0::Constant::create (ov::element::i32, ov::Shape{1 }, {1 });
258
+ auto weights_const = ov::op::v0::Constant::create (ov::element::bf16, ov::Shape{32 , 16 }, {1 });
259
+ auto convert = std::make_shared<ov::op::v0::Convert>(weights_const, ov::element::f32);
260
+ auto gather = std::make_shared<ov::op::v8::Gather>(convert, input1, axis_const);
261
+
262
+ model = std::make_shared<ov::Model>(ov::NodeVector{gather}, ov::ParameterVector{input1});
263
+ manager.register_pass <MoveDecompressionAfterGather>();
264
+ }
265
+ {
266
+ auto input1 = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, ov::PartialShape{-1 , 16 });
267
+ auto axis_const = ov::op::v0::Constant::create (ov::element::i32, ov::Shape{1 }, {1 });
268
+ auto weights_const = ov::op::v0::Constant::create (ov::element::bf16, ov::Shape{32 , 16 }, {1 });
269
+ auto gather = std::make_shared<ov::op::v8::Gather>(weights_const, input1, axis_const);
270
+ auto convert = std::make_shared<ov::op::v0::Convert>(gather, ov::element::f32);
271
+
272
+ model_ref = std::make_shared<ov::Model>(ov::NodeVector{convert}, ov::ParameterVector{input1});
273
+ }
274
+ }
0 commit comments