@@ -199,3 +199,58 @@ TEST_F(TrainingModuleTest, DataExternalConstantsTest) {
199
199
ASSERT_EQ (attributes.find (" b" )->second .sizes ()[0 ], 2 );
200
200
ASSERT_EQ (attributes.find (" b" )->second .dim (), 2 );
201
201
}
202
+
203
+ TEST_F (TrainingModuleTest, UnloadMethodTest) {
204
+ const char * ptd_path = std::getenv (" ET_MODULE_TRAIN_DATA_PATH" );
205
+ Result<FileDataLoader> data_map_loader_res = FileDataLoader::from (ptd_path);
206
+ ASSERT_EQ (data_map_loader_res.error (), Error::Ok);
207
+
208
+ auto data_map_loader =
209
+ std::make_unique<torch::executor::util::FileDataLoader>(
210
+ std::move (data_map_loader_res.get ()));
211
+
212
+ const char * pte_path = std::getenv (" ET_MODULE_TRAIN_PROGRAM_PATH" );
213
+ Result<FileDataLoader> pte_loader_res = FileDataLoader::from (pte_path);
214
+ ASSERT_EQ (pte_loader_res.error (), Error::Ok);
215
+
216
+ auto pte_loader = std::make_unique<torch::executor::util::FileDataLoader>(
217
+ std::move (pte_loader_res.get ()));
218
+
219
+ auto mod = executorch::extension::training::TrainingModule (
220
+ std::move (pte_loader),
221
+ nullptr ,
222
+ nullptr ,
223
+ nullptr ,
224
+ std::move (data_map_loader));
225
+
226
+ auto parameters_res = mod.named_parameters (" forward" );
227
+ ASSERT_EQ (parameters_res.error (), Error::Ok);
228
+ auto & parameters = parameters_res.get ();
229
+
230
+ ASSERT_NEAR (
231
+ parameters_res.get ()
232
+ .find (" linear.bias" )
233
+ ->second .const_data_ptr <float >()[0 ],
234
+ 0.1528 ,
235
+ 0.0001 );
236
+
237
+ // mock training
238
+ auto linear_bias_ptr =
239
+ parameters.find (" linear.bias" )->second .mutable_data_ptr <float >();
240
+ linear_bias_ptr[0 ] += 0.5 ;
241
+ ASSERT_NEAR (
242
+ parameters.find (" linear.bias" )->second .const_data_ptr <float >()[0 ],
243
+ 0.6528 ,
244
+ 0.0001 );
245
+
246
+ mod.unload_method (" forward" );
247
+
248
+ auto new_parameters_res = mod.named_parameters (" forward" );
249
+ ASSERT_EQ (new_parameters_res.error (), Error::Ok);
250
+ ASSERT_NEAR (
251
+ new_parameters_res.get ()
252
+ .find (" linear.bias" )
253
+ ->second .const_data_ptr <float >()[0 ],
254
+ 0.1528 ,
255
+ 0.0001 );
256
+ }
0 commit comments