Skip to content

Commit ae89bc8

Browse files
silverguofacebook-github-bot
authored andcommitted
Override unload_method in training_module to erase the tensors pointing to the released memory (#13590)
Summary: as title Differential Revision: D80754181
1 parent 335de46 commit ae89bc8

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

extension/training/module/test/training_module_test.cpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,3 +199,58 @@ TEST_F(TrainingModuleTest, DataExternalConstantsTest) {
199199
ASSERT_EQ(attributes.find("b")->second.sizes()[0], 2);
200200
ASSERT_EQ(attributes.find("b")->second.dim(), 2);
201201
}
202+
203+
TEST_F(TrainingModuleTest, UnloadMethodTest) {
204+
const char* ptd_path = std::getenv("ET_MODULE_TRAIN_DATA_PATH");
205+
Result<FileDataLoader> data_map_loader_res = FileDataLoader::from(ptd_path);
206+
ASSERT_EQ(data_map_loader_res.error(), Error::Ok);
207+
208+
auto data_map_loader =
209+
std::make_unique<torch::executor::util::FileDataLoader>(
210+
std::move(data_map_loader_res.get()));
211+
212+
const char* pte_path = std::getenv("ET_MODULE_TRAIN_PROGRAM_PATH");
213+
Result<FileDataLoader> pte_loader_res = FileDataLoader::from(pte_path);
214+
ASSERT_EQ(pte_loader_res.error(), Error::Ok);
215+
216+
auto pte_loader = std::make_unique<torch::executor::util::FileDataLoader>(
217+
std::move(pte_loader_res.get()));
218+
219+
auto mod = executorch::extension::training::TrainingModule(
220+
std::move(pte_loader),
221+
nullptr,
222+
nullptr,
223+
nullptr,
224+
std::move(data_map_loader));
225+
226+
auto parameters_res = mod.named_parameters("forward");
227+
ASSERT_EQ(parameters_res.error(), Error::Ok);
228+
auto& parameters = parameters_res.get();
229+
230+
ASSERT_NEAR(
231+
parameters_res.get()
232+
.find("linear.bias")
233+
->second.const_data_ptr<float>()[0],
234+
0.1528,
235+
0.0001);
236+
237+
// mock training
238+
auto linear_bias_ptr =
239+
parameters.find("linear.bias")->second.mutable_data_ptr<float>();
240+
linear_bias_ptr[0] += 0.5;
241+
ASSERT_NEAR(
242+
parameters.find("linear.bias")->second.const_data_ptr<float>()[0],
243+
0.6528,
244+
0.0001);
245+
246+
mod.unload_method("forward");
247+
248+
auto new_parameters_res = mod.named_parameters("forward");
249+
ASSERT_EQ(new_parameters_res.error(), Error::Ok);
250+
ASSERT_NEAR(
251+
new_parameters_res.get()
252+
.find("linear.bias")
253+
->second.const_data_ptr<float>()[0],
254+
0.1528,
255+
0.0001);
256+
}

extension/training/module/training_module.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@ class ET_EXPERIMENTAL TrainingModule final
4949
explicit TrainingModule(Module&&) = delete;
5050
TrainingModule& operator=(Module&&) = delete;
5151

52+
// Overriding to erase the tensors pointing to the released memory.
53+
inline bool unload_method(const std::string& method_name) {
54+
method_named_gradients_.erase(method_name);
55+
method_named_parameters_.erase(method_name);
56+
method_named_attributes_.erase(method_name);
57+
58+
return methods_.erase(method_name);
59+
}
60+
61+
inline bool unload_forward() {
62+
return unload_method("forward");
63+
}
64+
5265
/**
5366
* Execute a specific method with the given input and retrieve output. Only
5467
* valid if the specified method is a joint graph. Loads the program and

0 commit comments

Comments
 (0)