Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parser rewrite #1296

Merged
merged 24 commits into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
bfebc64
Running into issues with identity nodes
skewballfox Feb 4, 2024
4bfc0f2
Merge branch 'tracel-ai:main' into parser_rewrite
skewballfox Feb 11, 2024
5df9778
Vec<RefCell<Node>> seems to work for this
skewballfox Feb 11, 2024
9e4a503
back to passing tests
skewballfox Feb 12, 2024
d6ed1d5
Reworked IO into separate struct
skewballfox Feb 16, 2024
c350d03
working towards exploiting topological ordering and more informative …
skewballfox Feb 16, 2024
6ce8422
the passing of an initializer to coalesce is temporary
skewballfox Feb 16, 2024
7456ce7
cleaning up dead code
skewballfox Feb 16, 2024
7faf8f3
Merge branch 'main' into parser_rewrite
skewballfox Feb 17, 2024
77827bb
handled unsqueeze
skewballfox Feb 17, 2024
b3a6ebc
reworked node initialization and dim inference
skewballfox Feb 18, 2024
4624451
mainly cleanup
skewballfox Feb 18, 2024
1f9c051
changed how io use is tracked, moved unsqueeze remapping out of dim i…
skewballfox Feb 18, 2024
82bdf19
`cargo xtask run-checks all` now passes
skewballfox Feb 19, 2024
c8ed7de
added a fixme and a few doc strings
skewballfox Feb 19, 2024
8d495fc
removing println and dead code
skewballfox Feb 20, 2024
b7fb819
spaces in doc strings
skewballfox Feb 21, 2024
6fb94fa
altered top sort to work on node proto, moved prior to node gen
skewballfox Feb 21, 2024
ab7c9c2
Update ir.rs
skewballfox Feb 21, 2024
f740b9e
Update from_onnx.rs
skewballfox Feb 21, 2024
0f33517
updated doc string
skewballfox Feb 21, 2024
be57a67
camalcased Onnx Graph Builder
skewballfox Feb 21, 2024
f3b5b07
Merge branch 'main' into parser_rewrite
skewballfox Feb 21, 2024
3e921d6
removed self import?
skewballfox Feb 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions burn-import/src/burn/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ impl Type {

impl ScalarType {
pub fn new<S: AsRef<str>>(name: S, kind: ScalarKind) -> Self {
if name.as_ref().is_empty() {
panic!("Scalar of Type {:?} was passed with empty name", kind);
}
Self {
name: Ident::new(name.as_ref(), Span::call_site()),
kind,
Expand All @@ -95,6 +98,12 @@ impl TensorType {
kind: TensorKind,
shape: Option<Vec<usize>>,
) -> Self {
if name.as_ref().is_empty() {
panic!(
"Tensor of Kind {:?} with dim shape {:?} was passed with empty name",
kind, shape
);
}
Self {
name: Ident::new(name.as_ref(), Span::call_site()),
dim,
Expand Down Expand Up @@ -141,6 +150,12 @@ impl TensorType {

impl OtherType {
pub fn new<S: AsRef<str>>(name: S, tokens: TokenStream) -> Self {
if name.as_ref().is_empty() {
panic!(
"Other type with tokens {:?} was passed with empty name",
tokens
);
}
Self {
name: Ident::new(name.as_ref(), Span::call_site()),
ty: tokens,
Expand Down
80 changes: 42 additions & 38 deletions burn-import/src/onnx/coalesce.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
use std::{iter::Peekable, slice::IterMut};

use super::ir::{AttributeValue, Node, NodeType};
use std::{iter::Peekable, slice::Iter};

use super::{
from_onnx::OnnxGraphIO,
ir::{AttributeValue, Node, NodeType},
proto_conversion::convert_node_proto,
protos::NodeProto,
};
use crate::onnx::ir::{ArgType, Data, TensorType};

/// The function transforms the graph into a new one where the nodes are coalesced into a single node.
pub fn coalesce(nodes: &mut Vec<Node>) {
let mut iter_mut = nodes.iter_mut().peekable();
let mut nodes_to_remove: Vec<String> = vec![];
while let Some(node) = iter_mut.next() {
match node.node_type {
NodeType::Gemm => convert_gemm_to_linear(node),
NodeType::MatMul => {
convert_matmul_to_linear(node, &mut iter_mut, &mut nodes_to_remove);
}
_ => {}
pub fn coalesce(
node: &mut Node,
nodes_iter: &mut Peekable<Iter<NodeProto>>,
graph_io: &OnnxGraphIO,
) {
match node.node_type {
NodeType::Gemm => convert_gemm_to_linear(node),
NodeType::MatMul => {
convert_matmul_to_linear(node, nodes_iter, graph_io);
}
}

// Remove nodes instructed by conversation functions
for node_to_remove in nodes_to_remove {
nodes.retain(|n| n.name != node_to_remove);
_ => {}
}
}

/// This function converts a Gemm node into a Linear node
///
/// PyTorch and other frameworks use Gemm node to represent Linear layer.
fn convert_gemm_to_linear(node: &mut Node) {
pub(crate) fn convert_gemm_to_linear(node: &mut Node) {
if node.outputs.len() != 1 {
panic!("Gemm node must have 1 output");
}
Expand Down Expand Up @@ -117,10 +117,10 @@ fn transpose_flattened<T: Copy>(matrix: Vec<T>, rows: usize, cols: usize) -> Vec
///
/// This function also converts the following Add node into a Linear node if possible.
/// Add node is used to represent bias in PyTorch.
fn convert_matmul_to_linear(
pub(crate) fn convert_matmul_to_linear(
node: &mut Node,
iter_mut: &mut Peekable<IterMut<Node>>,
nodes_to_remove: &mut Vec<String>,
iter_mut: &mut Peekable<Iter<NodeProto>>,
graph_io: &OnnxGraphIO,
) {
if node.inputs.len() != 2 {
panic!("MatMul node must have 2 inputs");
Expand All @@ -143,30 +143,37 @@ fn convert_matmul_to_linear(

// Check the next node for potential conversion
if let Some(peek_node) = iter_mut.peek() {
if is_add_node_with_bias(peek_node, node) {
convert_and_remove_add_node(iter_mut, nodes_to_remove, node);
let peek_node = convert_node_proto(peek_node, graph_io).clone();
println!("next node is {:?}", peek_node);
if is_add_node_with_bias(&peek_node, node) {
skewballfox marked this conversation as resolved.
Show resolved Hide resolved
convert_and_remove_add_node(&peek_node, node);

// You don't have to remove it if it's never stored in the first place
let _ = iter_mut.next();
println!("\n\nskipping add node\n\n");
}
skewballfox marked this conversation as resolved.
Show resolved Hide resolved
}
}

/// Helper function to check if the peeked node is an Add node with bias
fn is_add_node_with_bias(peek_node: &Node, current_node: &Node) -> bool {
peek_node.node_type == NodeType::Add
&& peek_node.inputs.len() == 2
&& ((peek_node.inputs[0].name == current_node.outputs[0].name
if peek_node.node_type == NodeType::Add && peek_node.inputs.len() == 2 {
println!("\n\ntwo matches");
println!("peek_node.inputs[0].name: {:?}", peek_node.inputs[0].name);
println!(
"current_node.outputs[0].name: {:?}",
current_node.outputs[0].name
);
skewballfox marked this conversation as resolved.
Show resolved Hide resolved
return (peek_node.inputs[0].name == current_node.outputs[0].name
&& peek_node.inputs[1].value.is_some())
|| (peek_node.inputs[1].name == current_node.outputs[0].name
&& peek_node.inputs[0].value.is_some()))
&& peek_node.inputs[0].value.is_some());
}
false
}

/// Helper function to convert and remove the Add node
fn convert_and_remove_add_node(
iter_mut: &mut Peekable<IterMut<Node>>,
nodes_to_remove: &mut Vec<String>,
current_node: &mut Node,
) {
let bias_node = iter_mut.next().unwrap();

fn convert_and_remove_add_node(bias_node: &Node, current_node: &mut Node) {
let bias_input = if bias_node.inputs[0].value.is_some() {
bias_node.inputs[0].clone()
} else {
Expand All @@ -176,7 +183,4 @@ fn convert_and_remove_add_node(
// Push the bias input and update the output name
current_node.inputs.push(bias_input);
current_node.outputs[0].name = bias_node.outputs[0].name.clone();

// Remove the Add node
nodes_to_remove.push(bias_node.name.clone());
}
Loading
Loading