
Commit

Limit grad recursion depth by not recursing through non-grad inputs (#1764)

* limit grad recursion depth

* add grad of module test
awni authored Jan 14, 2025
1 parent 5cc5201 commit 33421c1
Showing 6 changed files with 136 additions and 100 deletions.
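
The commit message maps onto the diffs below: the C++ vjp gains an explicit argnums parameter so the backward pass marks only the selected primals for gradients, and the Python bindings flatten unselected arguments non-strictly instead of tracing through them. A minimal sketch of the failure mode being addressed, written against mlx's public C++ API (mlx/mlx.h, namespace mlx::core); the deep-chain setup is illustrative and is not the test added in this PR:

#include <vector>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  // A long lazy graph behind `c`; no gradient is ever requested for it.
  mx::array c = mx::array(1.0f);
  for (int i = 0; i < 100000; ++i) {
    c = c + mx::array(1.0f); // unevaluated 100k-node chain of additions
  }
  auto fn = [](const std::vector<mx::array>& in) {
    return std::vector<mx::array>{in[0] * in[1]};
  };
  // Differentiate with respect to argument 0 only. With this change the
  // VJP graph walk stops at `c` rather than recursing through its history.
  auto value_and_grad_fn = mx::value_and_grad(fn, std::vector<int>{0});
  auto [values, grads] = value_and_grad_fn({mx::array(2.0f), c});
  mx::eval(grads); // grads[0] is the value of c
  return 0;
}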
40 changes: 23 additions & 17 deletions mlx/transforms.cpp
@@ -278,7 +278,8 @@ void eval(std::vector<array> outputs) {
 std::pair<std::vector<array>, std::vector<array>> vjp(
     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
     const std::vector<array>& primals,
-    const std::vector<array>& cotans) {
+    const std::vector<array>& cotans,
+    const std::vector<int>& argnums) {
   // Set the global tracing flag.
   detail::InTracing in_tracing;
 
@@ -330,10 +331,14 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   // to the tape which need a gradient.
   std::unordered_set<std::uintptr_t> cache;
   std::unordered_set<std::uintptr_t> calc_grad;
-  for (auto& primal : primals_) {
+  for (int i = 0, j = 0; i < primals_.size(); ++i) {
+    auto& primal = primals_[i];
     primal.set_tracer(false);
-    calc_grad.insert(primal.id());
     cache.insert(primal.id());
+    if (j < argnums.size() && argnums[j] == i) {
+      j++;
+      calc_grad.insert(primal.id());
+    }
   }
 
   std::vector<array> tape;
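
The new loop above (and its twin in the Python bindings below) merges the input list with a sorted argnums list using two cursors, advancing j only when index i is one of the selected arguments. A self-contained illustration of the pattern, with invented names:

#include <cstdio>
#include <vector>

int main() {
  const char* inputs[] = {"w", "x", "b", "y"};
  std::vector<int> argnums = {0, 2}; // sorted indices of grad inputs
  for (int i = 0, j = 0; i < 4; ++i) {
    // Advance j only when inputs[i] is one of the selected indices.
    bool needs_grad = (j < (int)argnums.size() && argnums[j] == i);
    if (needs_grad) {
      j++;
    }
    std::printf("%s -> %s\n", inputs[i], needs_grad ? "grad" : "no grad");
  }
  return 0;
}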
@@ -435,7 +440,8 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
     }
   }
   std::vector<array> vjps;
-  for (auto& primal : primals_) {
+  for (auto arg : argnums) {
+    auto& primal = primals_[arg];
     if (auto cotan_it = cotan_map.find(primal.id());
         cotan_it != cotan_map.end()) {
       vjps.push_back(cotan_it->second);
@@ -448,6 +454,15 @@ std::pair<std::vector<array>, std::vector<array>> vjp(
   return {outputs, vjps};
 }
 
+std::pair<std::vector<array>, std::vector<array>> vjp(
+    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+    const std::vector<array>& primals,
+    const std::vector<array>& cotans) {
+  std::vector<int> argnums(primals.size());
+  std::iota(argnums.begin(), argnums.end(), 0);
+  return vjp(fun, primals, cotans, argnums);
+}
+
 std::pair<array, array> vjp(
     const std::function<array(const array&)>& fun,
     const array& primal,
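
The old three-argument vjp is kept as a wrapper that defaults argnums to every primal index via std::iota. An illustrative call of the new four-argument form (assuming the overload added here is declared in mlx's public headers, which this page does not show); it requests a cotangent for primal 0 only:

#include <vector>

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto fun = [](const std::vector<mx::array>& p) {
    return std::vector<mx::array>{p[0] * p[1]};
  };
  std::vector<mx::array> primals = {mx::array(2.0f), mx::array(3.0f)};
  // Only index 0 gets a cotangent back: vjps = {d(x*y)/dx} = {3.0}.
  auto [outs, vjps] = mx::vjp(fun, primals, {mx::array(1.0f)}, {0});
  mx::eval(vjps);
  return 0;
}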
@@ -606,15 +621,10 @@ ValueAndGradFn value_and_grad(
           << inputs.size() << " inputs.";
       throw std::invalid_argument(msg.str());
     }
+    std::vector<int> sorted_argnums(args.begin(), args.end());
 
-    auto gfun = [&fun, &inputs, &args](const std::vector<array>& ginputs) {
-      std::vector<array> inputs_(inputs);
-      auto argit = args.begin();
-      for (int i = 0; i < ginputs.size(); ++i) {
-        inputs_[*argit] = ginputs[i];
-        ++argit;
-      }
-      auto outputs = fun(inputs_);
+    auto gfun = [&fun](const std::vector<array>& inputs) {
+      auto outputs = fun(inputs);
       for (int i = 1; i < outputs.size(); i++) {
         auto& out = outputs[i];
         auto s = out.has_primitive() ? out.primitive().stream()
@@ -624,12 +634,8 @@ ValueAndGradFn value_and_grad(
       return outputs;
     };
 
-    std::vector<array> ginputs;
-    for (auto arg : args) {
-      ginputs.push_back(inputs[arg]);
-    }
     // Set the incoming gradient to float32, vjp will cast it to the output type
-    auto [outputs, grads] = vjp(gfun, ginputs, {array(1.0f)});
+    auto [outputs, grads] = vjp(gfun, inputs, {array(1.0f)}, sorted_argnums);
     return std::make_pair(outputs, grads);
   };
 }

165 changes: 89 additions & 76 deletions python/src/transforms.cpp
@@ -1,16 +1,18 @@
 // Copyright © 2023-2024 Apple Inc.
 
+#include <algorithm>
+#include <numeric>
+#include <sstream>
+#include <unordered_set>
+
 #include <nanobind/nanobind.h>
 #include <nanobind/stl/optional.h>
 #include <nanobind/stl/pair.h>
 #include <nanobind/stl/string.h>
+#include <nanobind/stl/unordered_set.h>
 #include <nanobind/stl/variant.h>
 #include <nanobind/stl/vector.h>
 
-#include <algorithm>
-#include <numeric>
-#include <sstream>
-
 #include "mlx/array.h"
 #include "mlx/compile.h"
 #include "mlx/compile_impl.h"
@@ -27,52 +29,53 @@ using namespace nb::literals;
 using mx::operator<<;
 
 using IntOrVec = std::variant<int, std::vector<int>>;
-using StrOrVec = std::variant<std::string, std::vector<std::string>>;
+using StrOrSet = std::variant<std::string, std::unordered_set<std::string>>;
 
 inline std::string type_name_str(const nb::handle& o) {
   return nb::cast<std::string>(nb::type_name(o.type()));
 }
 
-template <typename T>
-std::vector<T> to_vector(const std::variant<T, std::vector<T>>& v) {
-  std::vector<T> vals;
-  if (auto pv = std::get_if<T>(&v); pv) {
-    vals.push_back(*pv);
-  } else {
-    vals = std::get<std::vector<T>>(v);
-  }
-  return vals;
-}
-
 auto validate_argnums_argnames(
     const std::optional<IntOrVec>& argnums,
-    const StrOrVec& argnames) {
-  auto vec_names = to_vector(argnames);
+    const StrOrSet& argnames) {
+  std::unordered_set<std::string> setnames;
+  if (auto pv = std::get_if<std::string>(&argnames); pv) {
+    setnames = {*pv};
+  } else {
+    setnames = std::get<std::unordered_set<std::string>>(argnames);
+  }
 
   if (!argnums.has_value()) {
     // argnums was not provided and argnames was empty
-    if (vec_names.empty()) {
-      return std::make_pair(std::vector<int>{0}, vec_names);
+    if (setnames.empty()) {
+      return std::make_pair(std::vector<int>{0}, setnames);
     } else {
-      return std::make_pair(std::vector<int>{}, vec_names);
+      return std::make_pair(std::vector<int>{}, setnames);
     }
   }
 
-  return std::make_pair(to_vector(*argnums), vec_names);
+  std::vector<int> vecnums;
+  if (auto pv = std::get_if<int>(&(*argnums)); pv) {
+    vecnums = {*pv};
+  } else {
+    vecnums = std::get<std::vector<int>>(*argnums);
+  }
+
+  return std::make_pair(vecnums, setnames);
 }
 
 auto py_value_and_grad(
     const nb::callable& fun,
     std::vector<int> argnums,
-    std::vector<std::string> argnames,
+    std::unordered_set<std::string> argnames,
     const std::string& error_msg_tag,
     bool scalar_func_only) {
   // Sanitize argnums
   if (argnums.size() == 0 && argnames.size() == 0) {
     throw std::invalid_argument(
         error_msg_tag + " Gradient wrt no argument requested");
   }
   if (argnums.size() > 0) {
-    for (auto arg : argnums) {
+    std::sort(argnums.begin(), argnums.end());
     if (argnums[0] < 0) {
       std::ostringstream msg;
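
The removed to_vector helper is replaced by inline std::get_if dispatch over the variant. A minimal standalone version of that normalization, with illustrative names:

#include <string>
#include <unordered_set>
#include <variant>

using StrOrSet = std::variant<std::string, std::unordered_set<std::string>>;

// Collapse either a single name or a set of names into a set.
std::unordered_set<std::string> to_set(const StrOrSet& v) {
  if (auto pv = std::get_if<std::string>(&v); pv) {
    return {*pv};
  }
  return std::get<std::unordered_set<std::string>>(v);
}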
@@ -81,10 +84,18 @@ auto py_value_and_grad(
           << argnums[0];
       throw std::invalid_argument(msg.str());
     }
+    for (int i = 1; i < argnums.size(); ++i) {
+      if (argnums[i] == argnums[i - 1]) {
+        std::ostringstream msg;
+        msg << error_msg_tag << " Duplicate argument index " << argnums[0]
+            << " is not allowed.";
+        throw std::invalid_argument(msg.str());
+      }
+    }
   }
 
   return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
-             const nb::args& args, const nb::kwargs& kwargs) {
+             nb::args& args, nb::kwargs& kwargs) {
     // Sanitize the input
     if (argnums.size() > 0 && argnums.back() >= args.size()) {
       std::ostringstream msg;
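
Because argnums was sorted just above, any duplicates end up adjacent, so a single linear scan finds them. The same check can be written with the standard library's std::adjacent_find; this is an equivalent sketch, not the code in the commit:

#include <algorithm>
#include <vector>

bool has_duplicates(std::vector<int> v) {
  std::sort(v.begin(), v.end());
  // adjacent_find returns the first position where v[i] == v[i + 1].
  return std::adjacent_find(v.begin(), v.end()) != v.end();
}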
@@ -112,59 +123,59 @@
     // Collect the arrays
     std::vector<mx::array> arrays;
     std::vector<int> counts(1, 0);
-    for (auto i : argnums) {
-      auto argsi = tree_flatten(args[i]);
+    std::vector<int> gradient_indices;
+    for (int i = 0, j = 0; i < args.size(); ++i) {
+      bool needs_grad = (j < argnums.size() && argnums[j] == i);
+      auto argsi = tree_flatten(args[i], /* strict = */ needs_grad);
+      if (needs_grad) {
+        auto old_size = gradient_indices.size();
+        gradient_indices.resize(old_size + argsi.size());
+        std::iota(
+            gradient_indices.begin() + old_size,
+            gradient_indices.end(),
+            arrays.size());
+        j++;
+        counts.push_back(argsi.size());
+      }
       arrays.insert(arrays.end(), argsi.begin(), argsi.end());
-      counts.push_back(argsi.size());
     }
-    for (auto& key : argnames) {
-      auto argsk = tree_flatten(kwargs[key.c_str()]);
+    for (auto item : kwargs) {
+      bool needs_grad =
+          (argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
+      auto argsk = tree_flatten(item.second, /* strict = */ needs_grad);
+      if (needs_grad) {
+        auto old_size = gradient_indices.size();
+        gradient_indices.resize(old_size + argsk.size());
+        std::iota(
+            gradient_indices.begin() + old_size,
+            gradient_indices.end(),
+            arrays.size());
+        counts.push_back(argsk.size());
+      }
       arrays.insert(arrays.end(), argsk.begin(), argsk.end());
-      counts.push_back(argsk.size());
     }
     std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
-    std::vector<int> gradient_indices(arrays.size());
-    std::iota(gradient_indices.begin(), gradient_indices.end(), 0);
 
     // value_out will hold the output of the python function in order to be
     // able to reconstruct the python tree of extra return values
     nb::object py_value_out;
     auto value_and_grads = mx::value_and_grad(
         [&fun,
+         &arrays,
          &args,
          &kwargs,
-         &argnums,
-         &argnames,
-         &counts,
          &py_value_out,
          &error_msg_tag,
          scalar_func_only](const std::vector<mx::array>& a) {
-          // Copy the arguments
-          nb::list args_cpy;
-          nb::kwargs kwargs_cpy = nb::kwargs();
-          int j = 0;
-          for (int i = 0; i < args.size(); ++i) {
-            if (j < argnums.size() && i == argnums[j]) {
-              args_cpy.append(tree_unflatten(args[i], a, counts[j]));
-              j++;
-            } else {
-              args_cpy.append(args[i]);
-            }
-          }
-          for (auto& key : argnames) {
-            kwargs_cpy[key.c_str()] =
-                tree_unflatten(kwargs[key.c_str()], a, counts[j]);
-            j++;
-          }
-          for (auto item : kwargs) {
-            if (kwargs_cpy.contains(item.first)) {
-              continue;
-            }
-            kwargs_cpy[item.first] = item.second;
-          }
+          nb::list tree;
+          tree.append(args);
+          tree.append(kwargs);
+          tree_replace(tree, arrays, a);
 
           // Call the python function
-          py_value_out = fun(*args_cpy, **kwargs_cpy);
+          py_value_out = fun(*tree[0], **tree[1]);
 
+          tree_replace(tree, arrays, a);
+
           // Validate the return value of the python function
           if (!nb::isinstance<mx::array>(py_value_out)) {

@@ -247,10 +258,13 @@ auto py_value_and_grad(
       py_grads = positional_grads;
     } else {
       nb::dict grads_;
-      for (int i = 0; i < argnames.size(); i++) {
-        auto& k = argnames[i];
-        grads_[k.c_str()] = tree_unflatten(
-            kwargs[k.c_str()], gradients, counts[i + argnums.size()]);
+      int i = 0;
+      for (auto item : kwargs) {
+        auto k = nb::cast<std::string>(item.first);
+        if (argnames.find(k) != argnames.end()) {
+          grads_[k.c_str()] = tree_unflatten(
+              nb::borrow(item.second), gradients, counts[i++ + argnums.size()]);
+        }
       }
       keyword_grads = grads_;
 
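The counts vector starts with a leading zero and collects the flattened size of each gradient argument; std::partial_sum then turns those sizes into start offsets, so counts[j] is where argument j's arrays begin in the flat buffer. A standalone sketch of that bookkeeping, with made-up sizes:

#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<int> counts = {0, 3, 1, 2}; // leading 0, then per-argument sizes
  std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
  // counts is now {0, 3, 4, 6}: start offsets 0, 3, 4 and total length 6.
  for (int c : counts) {
    std::printf("%d ", c);
  }
  std::printf("\n");
  return 0;
}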
@@ -1207,17 +1221,17 @@ void init_transforms(nb::module_& m) {
   m.def(
       "value_and_grad",
       [](const nb::callable& fun,
          const std::optional<IntOrVec>& argnums,
-         const StrOrVec& argnames) {
-        auto [argnums_vec, argnames_vec] =
+         const StrOrSet& argnames) {
+        auto [argnums_vec, argnames_set] =
             validate_argnums_argnames(argnums, argnames);
         return nb::cpp_function(py_value_and_grad(
-            fun, argnums_vec, argnames_vec, "[value_and_grad]", false));
+            fun, argnums_vec, argnames_set, "[value_and_grad]", false));
       },
       "fun"_a,
       "argnums"_a = nb::none(),
       "argnames"_a = std::vector<std::string>{},
       nb::sig(
-          "def value_and_grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
+          "def value_and_grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"),
       R"pbdoc(
         Returns a function which computes the value and gradient of ``fun``.

@@ -1271,21 +1285,20 @@ void init_transforms(nb::module_& m) {
   m.def(
       "grad",
       [](const nb::callable& fun,
          const std::optional<IntOrVec>& argnums,
-         const StrOrVec& argnames) {
-        auto [argnums_vec, argnames_vec] =
+         const StrOrSet& argnames) {
+        auto [argnums_vec, argnames_set] =
             validate_argnums_argnames(argnums, argnames);
         auto fn =
-            py_value_and_grad(fun, argnums_vec, argnames_vec, "[grad]", true);
-        return nb::cpp_function(
-            [fn](const nb::args& args, const nb::kwargs& kwargs) {
-              return fn(args, kwargs).second;
-            });
+            py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true);
+        return nb::cpp_function([fn](nb::args& args, nb::kwargs& kwargs) {
+          return fn(args, kwargs).second;
+        });
       },
       "fun"_a,
       "argnums"_a = nb::none(),
       "argnames"_a = std::vector<std::string>{},
       nb::sig(
-          "def grad(fun: Callable, argnums: Optional[Union[int, list[int]]] = None, argnames: Union[str, list[str]] = []) -> Callable"),
+          "def grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"),
       R"pbdoc(
         Returns a function which computes the gradient of ``fun``.

11 changes: 6 additions & 5 deletions python/src/trees.cpp
@@ -146,7 +146,7 @@ void tree_visit(
   return recurse(trees);
 }
 
-void tree_visit(nb::object tree, std::function<void(nb::handle)> visitor) {
+void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) {
   std::function<void(nb::handle)> recurse;
   recurse = [&](nb::handle subtree) {
     if (nb::isinstance<nb::list>(subtree) ||
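
Both tree_visit and tree_flatten (below) switch their tree parameter from nb::object to nb::handle. A handle is a non-owning view of a Python object, so the call does not bump the reference count and borrowed items, such as values pulled out of a kwargs dict during iteration, can be passed directly. A tiny illustration of the convention; the function itself is invented for the example:

#include <cstddef>

#include <nanobind/nanobind.h>

namespace nb = nanobind;

// Taking nb::handle borrows the object: no incref/decref at the call site.
std::size_t sequence_length(nb::handle obj) {
  return nb::len(obj);
}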
@@ -178,10 +178,11 @@ void tree_visit_update(
       }
       return nb::cast<nb::object>(l);
     } else if (nb::isinstance<nb::tuple>(subtree)) {
-      for (auto item : subtree) {
-        recurse(item);
+      nb::list l(subtree);
+      for (int i = 0; i < l.size(); ++i) {
+        l[i] = recurse(l[i]);
       }
-      return nb::cast<nb::object>(subtree);
+      return nb::cast<nb::object>(nb::tuple(l));
     } else if (nb::isinstance<nb::dict>(subtree)) {
       auto d = nb::cast<nb::dict>(subtree);
       for (auto item : d) {
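
The old tuple branch recursed but discarded the results, so updates inside tuples were dropped. Since Python tuples are immutable, the fix copies the items into a list, updates the list, and freezes it back into a new tuple. A rough standalone nanobind sketch of that pattern, with a hypothetical element-wise update function f:

#include <functional>

#include <nanobind/nanobind.h>

namespace nb = nanobind;

nb::object update_tuple(
    nb::handle t,
    const std::function<nb::object(nb::handle)>& f) {
  nb::list l(t); // tuples are immutable: copy items into a mutable list
  for (size_t i = 0; i < l.size(); ++i) {
    l[i] = f(l[i]);
  }
  return nb::tuple(l); // freeze the updated items back into a new tuple
}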
@@ -224,7 +225,7 @@ void tree_replace(
       });
 }
 
-std::vector<mx::array> tree_flatten(nb::object tree, bool strict /* = true */) {
+std::vector<mx::array> tree_flatten(nb::handle tree, bool strict /* = true */) {
  std::vector<mx::array> flat_tree;
 
  tree_visit(tree, [&](nb::handle obj) {

(The diffs for the remaining three changed files did not load in this view.)
