From bfebc6428a6fcb227e3d2b60034c3d8ec122e78b Mon Sep 17 00:00:00 2001
From: Joshua Ferguson
Date: Sun, 4 Feb 2024 16:06:32 -0600
Subject: [PATCH 01/21] Running into issues with identity nodes

---
 burn-import/src/onnx/coalesce.rs  |  56 ++++-
 burn-import/src/onnx/from_onnx.rs | 377 +++++++++++++++++++++++++++---
 2 files changed, 399 insertions(+), 34 deletions(-)

diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs
index 623d7584f2..0de0b58ebe 100644
--- a/burn-import/src/onnx/coalesce.rs
+++ b/burn-import/src/onnx/coalesce.rs
@@ -1,4 +1,4 @@
-use std::{iter::Peekable, slice::IterMut};
+use std::{collections::HashSet, iter::Peekable, slice::IterMut};
 
 use super::ir::{AttributeValue, Node, NodeType};
 use crate::onnx::ir::{ArgType, Data, TensorType};
@@ -26,7 +26,7 @@ pub fn coalesce(nodes: &mut Vec<Node>) {
 /// This function converts a Gemm node into a Linear node
 ///
 /// PyTorch and other frameworks use Gemm node to represent Linear layer.
-fn convert_gemm_to_linear(node: &mut Node) {
+pub(crate) fn convert_gemm_to_linear(node: &mut Node) {
     if node.outputs.len() != 1 {
         panic!("Gemm node must have 1 output");
     }
@@ -149,6 +149,41 @@ fn convert_matmul_to_linear(
     }
 }
 
+pub(crate) fn convert_matmul_to_linear2(
+    node_vec: &mut Vec<Node>,
+    node_index: usize,
+    nodes_to_remove: &mut HashSet<usize>,
+) {
+    let mut iter_mut = node_vec.iter_mut().peekable();
+    let node = iter_mut.nth(node_index).unwrap();
+    if node.inputs.len() != 2 {
+        panic!("MatMul node must have 2 inputs");
+    }
+
+    // If the second input does not have a value, it is not a weight; return and proceed to the next node
+    if node.inputs[1].value.is_none() {
+        return;
+    }
+
+    // Check if the second input is a 2D tensor
+    if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty {
+        assert_eq!(tensor_type.dim, 2, "Weight must be a 2D tensor");
+    } else {
+        panic!("Tensor input is expected");
+    }
+
+    // Convert the node to Linear
+    node.node_type = NodeType::Linear;
+
+    // Check the next node for potential conversion
+
+    if let Some(next_node) = iter_mut.peek() {
+        if is_add_node_with_bias(&next_node, node) {
+            convert_node2(next_node, nodes_to_remove, node);
+            nodes_to_remove.insert(node_index + 1);
+        }
+    }
+}
 /// Helper function to check if the peeked node is an Add node with bias
 fn is_add_node_with_bias(peek_node: &Node, current_node: &Node) -> bool {
     peek_node.node_type == NodeType::Add
@@ -180,3 +215,20 @@ fn convert_and_remove_add_node(
     // Remove the Add node
     nodes_to_remove.push(bias_node.name.clone());
 }
+
+/// Helper function to convert and remove the Add node
+pub(crate) fn convert_node2(
+    bias_node: &Node,
+    nodes_to_remove: &mut HashSet<usize>,
+    current_node: &mut Node,
+) {
+    let bias_input = if bias_node.inputs[0].value.is_some() {
+        bias_node.inputs[0].clone()
+    } else {
+        bias_node.inputs[1].clone()
+    };
+
+    // Push the bias input and update the output name
+    current_node.inputs.push(bias_input);
+    current_node.outputs[0].name = bias_node.outputs[0].name.clone();
+}
diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs
index 6b9db5b01b..30d83867b5 100644
--- a/burn-import/src/onnx/from_onnx.rs
+++ b/burn-import/src/onnx/from_onnx.rs
@@ -9,9 +9,15 @@ use crate::onnx::{
     proto_conversion::convert_node_proto,
 };
 
-use super::dim_inference::dim_inference;
-use super::ir::{ArgType, Argument, Node, NodeType, ONNXGraph, Tensor};
-use super::protos::{ModelProto, TensorProto};
+use super::{
+    coalesce::convert_gemm_to_linear,
+    protos::{ModelProto, TensorProto},
+};
+use super::{
+    coalesce::convert_matmul_to_linear2,
+    ir::{ArgType, Argument, Node, NodeType, ONNXGraph, Tensor},
+};
+use super::{dim_inference::dim_inference, protos::ValueInfoProto};
 
 use protobuf::Message;
 
@@ -22,7 +28,212 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 6] = [
     NodeType::Conv2d,
     NodeType::Dropout,
     NodeType::Reshape,
+    NodeType::Unsqueeze,
 ];
+#[derive(Default)]
+pub(crate) struct ONNXGraphBuilder {
+    nodes: Vec<Node>,
+    node_name_counter: HashMap<NodeType, usize>,
+    old_node_names: HashMap<String, String>,
+    //map of input names to a vec of indices of nodes that use it
+    input_of: HashMap<String, Vec<usize>>,
+    outputs_to_move: HashMap<String, usize>,
+    //map of output names to the index of the node that produces them
+    output_of: HashMap<String, usize>,
+    //nodes to remove
+    nodes_to_remove: HashSet<usize>,
+    //constants to lift
+    constants: HashMap<String, Node>,
+    postprocess_for_constants: Vec<usize>,
+    constants_types: HashSet<NodeType>,
+    //identity_nodes
+    identity_nodes: Vec<String>,
+    //matmul nodes
+    matmul_nodes: Vec<usize>,
+}
+
+impl ONNXGraphBuilder {
+    pub(crate) fn node_gen(&mut self, model_proto: &ModelProto) {
+        self.constants_types = LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect();
+        // Convert initializers to hashmap for faster lookup
+        let initializers = model_proto
+            .graph
+            .initializer
+            .iter()
+            .map(|x| (x.name.clone(), x.clone()))
+            .collect::<HashMap<String, TensorProto>>();
+
+        let inputs = model_proto
+            .graph
+            .input
+            .iter()
+            .map(|x| (x.name.clone(), x.clone()))
+            .collect::<HashMap<String, ValueInfoProto>>();
+
+        for (i, node_proto) in model_proto.graph.node.iter().enumerate() {
+            let mut node = convert_node_proto(node_proto);
+            for node_input in node.inputs.iter_mut() {
+                self.input_of
+                    .entry(node_input.name.clone())
+                    .and_modify(|f| f.push(i))
+                    .or_insert(vec![i]);
+                if let Some(initializer) = initializers.get(&node_input.name) {
+                    move_initializer_data(initializer, node_input);
+                }
+            }
+            remap_node_type(&mut node);
+            for node_output in node.outputs.iter() {
+                self.output_of.insert(node_output.name.clone(), i);
+            }
+
+            let node_type = node.node_type.clone();
+            if self.handle_identity(&node_type, &node, i)
+                || self.check_constants(&node, &node_type, i)
+            {
+                self.constants.insert(node.name.clone(), node);
+            } else {
+                //name stuff
+                self.handle_node_renaming(&node_type, &mut node);
+
+                self.handle_unsqueeze(&node_type, &node, i);
+                //NOTE: still not done with this one
+
+                self.handle_coalesce(node_type, &mut node, i);
+
+                self.nodes.push(node);
+            }
+        }
+        self.postprocess_unsqueeze(inputs, model_proto.graph.output.clone());
+        self.postprocess_identity();
+        self.postprocess_constants();
+        self.postprocess_coalesce();
+    }
+
+    fn handle_node_renaming(&mut self, node_type: &NodeType, node: &mut Node) {
+        self.node_name_counter
+            .entry(node_type.clone())
+            .and_modify(|e| *e += 1)
+            .or_insert(1);
+        let old_name = node.name.clone();
+        let new_name =
+            format!("{}{}", node.node_type, self.node_name_counter[&node_type]).to_lowercase();
+        node.name = new_name.clone();
+        println!("old name {:?} new name {:?}", old_name, new_name);
+        self.old_node_names.insert(old_name, new_name);
+    }
+
+    fn postprocess_rename_inputs(&mut self, inputs: &mut Vec<Argument>) {
+        for input in inputs.iter_mut() {
+            if let Some(new_name) = self.old_node_names.get(&input.name) {
+                input.name = new_name.clone();
+            }
+        }
+    }
+
+    fn check_constants(&mut self, node: &Node, node_type: &NodeType, i: usize) -> bool {
+        if node_type == &NodeType::Constant {
+            return true;
+        } else if self.constants_types.contains(node_type) {
+            self.postprocess_for_constants.push(i);
+        }
+        false
+    }
+
+    fn postprocess_constants(&mut self) {
+        for i in self.postprocess_for_constants.iter() {
+            let mut node = &mut self.nodes[*i];
+
+            for input in node.inputs.iter_mut().skip(1) {
+                if let Some(constant) = self.constants.get(&input.name) {
+                    if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() {
+                        input.value = constant.outputs[0].value.clone();
+                        input.ty = constant.outputs[0].ty.clone();
+                    } else {
+                        let arg = convert_constant_value(&constant);
+                        input.value = arg.value;
+                        input.ty = arg.ty;
+                    }
+                }
+            }
+        }
+    }
+
+    //fn get_mult_ref(&self, node_name: String, node_index, )
+
+    fn handle_unsqueeze(&mut self, node_type: &NodeType, node: &Node, i: usize) {
+        if *node_type == NodeType::Unsqueeze {
+            self.outputs_to_move.insert(node.outputs[0].name.clone(), i);
+        }
+    }
+
+    fn postprocess_unsqueeze(
+        &mut self,
+        node_inputs: HashMap<String, ValueInfoProto>,
+        graph_outputs: Vec<ValueInfoProto>,
+    ) {
+        for (output_name, i) in self.outputs_to_move.iter() {
+            let node = &mut self.nodes[*i];
+            if let Some(val_proto) = node_inputs.get(output_name) {
+                move_output_shape(node, val_proto);
+            } else {
+                for output in graph_outputs.iter() {
+                    if output.name == *output_name {
+                        move_output_shape(node, output);
+                    }
+                }
+            }
+        }
+    }
+
+    fn handle_identity(&mut self, node_type: &NodeType, node: &Node, i: usize) -> bool {
+        if node_type == &NodeType::Identity && node.inputs[0].value.is_none() {
+            self.identity_nodes.push(node.name.clone());
+            // self.nodes_to_remove.insert(i);
+            return true;
+        }
+        false
+    }
+
+    fn postprocess_identity(&mut self) {
+        for identity in self.identity_nodes.iter() {
+            let identity_node = self
+                .constants
+                .get(identity)
+                .expect(&format!("Identity node {} not found", identity));
+
+            let input_name = &identity_node.inputs[0].name;
+            let identity_output = &identity_node.outputs[0].name;
+
+            // Replace the identity node's output with its input in the connected nodes.
+            if let Some(node_index) = self.input_of.get(identity_output) {
+                for node_index in node_index {
+                    let node = &mut self.nodes[*node_index];
+                    if let Some(matched_input) =
+                        node.inputs.iter_mut().find(|x| x.name == *identity_output)
+                    {
+                        matched_input.name = input_name.clone();
+                    }
+                }
+            }
+        }
+    }
+
+    fn handle_coalesce(&mut self, node_type: NodeType, node: &mut Node, i: usize) {
+        match node_type {
+            NodeType::Gemm => convert_gemm_to_linear(node),
+            NodeType::MatMul => {
+                self.matmul_nodes.push(i);
+            }
+            _ => {}
+        }
+    }
+
+    fn postprocess_coalesce(&mut self) {
+        for matmul_index in self.matmul_nodes.clone() {
+            convert_matmul_to_linear2(&mut self.nodes, matmul_index, &mut self.nodes_to_remove);
+        }
+    }
+}
 
 /// Open an onnx file and convert it to a Graph (intermediate representation)
 ///
@@ -56,35 +267,78 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
     );
     log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len());
 
+    let mut builder = ONNXGraphBuilder::default();
+    builder.node_gen(&onnx_model);
+    let ONNXGraphBuilder {
+        mut nodes,
+        old_node_names,
+        nodes_to_remove,
+        ..
+    } = builder;
     // Convert the nodes
-    let mut nodes: Vec<Node> = vec![];
-    for onnx_node in onnx_model.graph.node.iter() {
-        let mut node = convert_node_proto(onnx_node);
-        remap_node_type(&mut node);
-        nodes.push(node);
-    }
-
-    // ONNX nodes must be topologically sorted per spec:
-    // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
-    assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted");
-
-    // Move inputs with initializers to states
-    move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer);
-
-    // Handle Identity nodes (expects inputs to be moved to states)
-    handle_identity(&mut nodes);
-
-    // Lift constants to initializers (expects inputs to be moved to states)
-    lift_constants(&mut nodes);
-
-    // Coalesce and transform nodes
-    coalesce(&mut nodes);
-
-    // Rename nodes and inputs, save the mapping for later
-    let old_node_names = rename_nodes(&mut nodes);
-
-    // This function collects the inputs of an ONNX model and returns them as a vector of Arguments.
+    // let mut nodes: Vec<Node> = vec![];
+    // println!("onnx_model.graph.node: {:#?}", onnx_model.graph.node);
+    // for onnx_node in onnx_model.graph.node.iter() {
+    //     let mut node = convert_node_proto(onnx_node);
+    //     if onnx_node.op_type == "Unsqueeze" {
+    //         move_output_for_unsqueeze(
+    //             &mut node,
+    //             onnx_model.graph.input.clone(),
+    //             onnx_model.graph.output.clone(),
+    //         );
+    //     }
+    //     remap_node_type(&mut node);
+    //     nodes.push(node);
+    // }
+
+    // // ONNX nodes must be topologically sorted per spec:
+    // // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
+    // assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted");
+
+    // //TMP: converts initializers to hashmap, iterates over nodes
+    // // Move inputs with initializers to states
+    // move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer);
+    // //TMP: loops over all nodes, filters for identity nodes,
+    // //for each identity node, loop over all nodes and if the identity is in the inputs, replace
+    // // Handle Identity nodes (expects inputs to be moved to states)
+    //handle_identity(&mut nodes);
+    // //TMP: loops over all nodes, filters for constants, collects them into a hashmap
+    // //loop over all nodes, filter for types in constant lifting, for each input
+    // //if it's in the hashmap, move the data to appropriate input
+    // // Lift constants to initializers (expects inputs to be moved to states)
+    //lift_constants(&mut nodes);
+    // //TMP: loops over all nodes, if type matches one of the current coalesce types, call the function
+    // //loop
+    // // Coalesce and transform nodes
+    // coalesce(&mut nodes);
+
+    // //TMP: loop over all nodes, rename each with its name and a counter
+    // //keep a map of old names to new names
+    // // Rename nodes and inputs, save the mapping for later
+    // let old_node_names = rename_nodes(&mut nodes);
+
+    // // This function collects the inputs of an ONNX model and returns them as a vector of Arguments.
+    // let mut inputs = onnx_model
+    //     .graph
+    //     .input
+    //     .iter()
+    //     .map(|x| Argument::try_from(x.clone()).unwrap())
+    //     .collect();
+
+    // // Map each output in the model's graph to an Argument and collect them into a vector.
+    // let mut outputs = onnx_model
+    //     .graph
+    //     .output
+    //     .iter()
+    //     .map(|x| Argument::try_from(x.clone()).unwrap())
+    //     .collect();
+    let mut i = 0;
+    nodes.retain(|node| {
+        let res = !nodes_to_remove.contains(&i);
+        i += 1;
+        res
+    });
     let mut inputs = onnx_model
         .graph
         .input
@@ -92,7 +346,6 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
         .map(|x| Argument::try_from(x.clone()).unwrap())
         .collect();
 
-    // Map each output in the model's graph to an Argument and collect them into a vector.
     let mut outputs = onnx_model
         .graph
         .output
@@ -101,7 +354,6 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
         .collect();
 
     let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs);
-
     // Infer shapes and update the inputs and outputs
     dim_inference(&mut nodes, &inputs, &mut outputs);
 
@@ -173,6 +425,67 @@ fn move_initializer_data(initializer: &TensorProto, input: &mut Argument) {
     }
 }
 
+//this is an extremely hacky temporary solution while I figure out how to properly handle this
+//situation
+fn move_output_for_unsqueeze(
+    node: &mut Node,
+    outputs: Vec<ValueInfoProto>,
+    inputs: Vec<ValueInfoProto>,
+) {
+    let output_name = node.outputs[0].name.clone();
+    // check outputs first, as it's shorter
+    for output in outputs.iter() {
+        if output.name == output_name {
+            match node.outputs[0].ty {
+                ArgType::Tensor(ref mut tensor_type) => {
+                    if let Some(shape) = output.type_.as_ref().unwrap().tensor_type().shape.as_ref()
+                    {
+                        tensor_type.shape =
+                            Some(shape.dim.iter().map(|x| x.dim_value() as usize).collect());
+                        return;
+                    }
+                }
+                _ => return,
+            }
+        }
+    }
+    for input in inputs.iter() {
+        if input.name == output_name {
+            //copy the shape
+            match node.outputs[0].ty {
+                ArgType::Tensor(ref mut tensor_type) => {
+                    if let Some(shape) = input.type_.as_ref().unwrap().tensor_type().shape.as_ref()
+                    {
+                        tensor_type.shape =
+                            Some(shape.dim.iter().map(|x| x.dim_value() as usize).collect());
+                        return;
+                    }
+                }
+                _ => return,
+            }
+        }
+    }
+}
+
+fn move_output_shape(node: &mut Node, output_tensor: &ValueInfoProto) {
+    match node.outputs[0].ty {
+        ArgType::Tensor(ref mut tensor_type) => {
+            if let Some(shape) = output_tensor
+                .type_
+                .as_ref()
+                .unwrap()
+                .tensor_type()
+                .shape
+                .as_ref()
+            {
+                tensor_type.shape =
+                    Some(shape.dim.iter().map(|x| x.dim_value() as usize).collect());
+            }
+        }
+        _ => {}
+    }
+}
+
 /// Lift constants from the graph into the states vector for known node types.
 ///
 /// The primary reason to move constants into the states vector is to reduce the number of nodes in the graph,

From 5df97789a74597a8a17c06501b4979da29378b52 Mon Sep 17 00:00:00 2001
From: Joshua Ferguson
Date: Sun, 11 Feb 2024 17:44:15 -0600
Subject: [PATCH 02/21] Vec<RefCell<Node>> seems to work for this

---
 burn-import/src/onnx/coalesce.rs  |  25 ++++---
 burn-import/src/onnx/from_onnx.rs | 118 ++++++++++++++++--------------
 2 files changed, 78 insertions(+), 65 deletions(-)

diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs
index 0de0b58ebe..1ea4c236fb 100644
--- a/burn-import/src/onnx/coalesce.rs
+++ b/burn-import/src/onnx/coalesce.rs
@@ -1,4 +1,9 @@
-use std::{collections::HashSet, iter::Peekable, slice::IterMut};
+use std::{
+    cell::{RefCell, RefMut},
+    collections::HashSet,
+    iter::Peekable,
+    slice::IterMut,
+};
 
 use super::ir::{AttributeValue, Node, NodeType};
 use crate::onnx::ir::{ArgType, Data, TensorType};
@@ -150,12 +155,11 @@ fn convert_matmul_to_linear(
 }
 
 pub(crate) fn convert_matmul_to_linear2(
-    node_vec: &mut Vec<Node>,
+    node_vec: &Vec<RefCell<Node>>,
     node_index: usize,
     nodes_to_remove: &mut HashSet<usize>,
 ) {
-    let mut iter_mut = node_vec.iter_mut().peekable();
-    let node = iter_mut.nth(node_index).unwrap();
+    let mut node = node_vec[node_index].borrow_mut();
     if node.inputs.len() != 2 {
         panic!("MatMul node must have 2 inputs");
     }
@@ -177,9 +181,10 @@ pub(crate) fn convert_matmul_to_linear2(
 
     // Check the next node for potential conversion
 
-    if let Some(next_node) = iter_mut.peek() {
-        if is_add_node_with_bias(&next_node, node) {
-            convert_node2(next_node, nodes_to_remove, node);
+    if node_index + 1 < node_vec.len() {
+        let next_node = node_vec[node_index + 1].borrow();
+        if is_add_node_with_bias(&next_node, &node) {
+            convert_node2(&next_node, node);
             nodes_to_remove.insert(node_index + 1);
         }
     }
@@ -217,11 +222,7 @@ fn convert_and_remove_add_node(
 }
 
 /// Helper function to convert and remove the Add node
-pub(crate) fn convert_node2(
-    bias_node: &Node,
-    nodes_to_remove: &mut HashSet<usize>,
-    current_node: &mut Node,
-) {
+pub(crate) fn convert_node2<'parser>(bias_node: &Node, mut current_node: RefMut<'parser, Node>) {
     let bias_input = if bias_node.inputs[0].value.is_some() {
         bias_node.inputs[0].clone()
     } else {
diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs
index 30d83867b5..fc221e919d 100644
--- a/burn-import/src/onnx/from_onnx.rs
+++ b/burn-import/src/onnx/from_onnx.rs
@@ -1,4 +1,6 @@
 use std::{
+    borrow::BorrowMut,
+    cell::{RefCell, RefMut},
     collections::{HashMap, HashSet},
     fs::File,
     path::Path,
@@ -13,6 +15,7 @@ use super::{
     coalesce::convert_gemm_to_linear,
     protos::{ModelProto, TensorProto},
 };
+
 use super::{
     coalesce::convert_matmul_to_linear2,
     ir::{ArgType, Argument, Node, NodeType, ONNXGraph, Tensor},
@@ -21,7 +24,7 @@ use super::{dim_inference::dim_inference, protos::ValueInfoProto};
 
 use protobuf::Message;
 
-const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 6] = [
+const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 7] = [
     NodeType::BatchNormalization,
     NodeType::Clip,
     NodeType::Conv1d,
@@ -42,12 +45,12 @@ pub(crate) struct ONNXGraphBuilder {
     output_of: HashMap<String, usize>,
     //nodes to remove
     nodes_to_remove: HashSet<usize>,
+    constants_map: HashMap<String, usize>,
     //constants to lift
-    constants: HashMap<String, Node>,
     postprocess_for_constants: Vec<usize>,
     constants_types: HashSet<NodeType>,
     //identity_nodes
-    identity_nodes: Vec<String>,
+    identity_idx: Vec<usize>,
     //matmul nodes
     matmul_nodes: Vec<usize>,
 }
@@ -66,21 +72,19 @@ impl ONNXGraphBuilder {
             .iter()
             .map(|x| (x.name.clone(), x.clone()))
             .collect::<HashMap<String, ValueInfoProto>>();
-
+        let mut nodes = Vec::with_capacity(model_proto.graph.node.len());
         for (i, node_proto) in model_proto.graph.node.iter().enumerate() {
             let mut node = convert_node_proto(node_proto);
             for node_input in node.inputs.iter_mut() {
@@ -87,26 +90,35 @@ impl ONNXGraphBuilder {
             }
 
             let node_type = node.node_type.clone();
-            if self.handle_identity(&node_type, &node, i)
-                || self.check_constants(&node, &node_type, i)
-            {
-                self.constants.insert(node.name.clone(), node);
-            } else {
-                //name stuff
-                self.handle_node_renaming(&node_type, &mut node);
-
-                self.handle_unsqueeze(&node_type, &node, i);
-                //NOTE: still not done with this one
+            _ = self.handle_identity(&node_type, &node, i);
+            self.check_constants(&node, &node_type, i);
 
-                self.handle_coalesce(node_type, &mut node, i);
+            self.handle_unsqueeze(&node_type, &node, i);
+            //NOTE: still not done with this one
 
-                self.nodes.push(node);
+            self.handle_coalesce(&node_type, &mut node, i);
+            if !self.nodes_to_remove.contains(&i) {
+                //name stuff
+                self.handle_node_renaming(&node_type, &mut node);
             }
+
+            nodes.push(RefCell::new(node));
         }
-        self.postprocess_unsqueeze(inputs, model_proto.graph.output.clone());
-        self.postprocess_identity();
-        self.postprocess_constants();
-        self.postprocess_coalesce();
+        self.postprocess_unsqueeze(&nodes, inputs, model_proto.graph.output.clone());
+        self.postprocess_identity(&nodes);
+        self.postprocess_constants(&nodes);
+        self.postprocess_coalesce(&mut nodes);
+        self.nodes = nodes
+            .into_iter()
+            .enumerate()
+            .filter_map(|(i, x)| {
+                if !self.nodes_to_remove.contains(&i) {
+                    Some(x.into_inner())
+                } else {
+                    None
+                }
+            })
+            .collect();
     }
@@ -122,7 +134,11 @@ impl ONNXGraphBuilder {
         self.old_node_names.insert(old_name, new_name);
     }
 
-    fn postprocess_rename_inputs(&mut self, inputs: &mut Vec<Argument>) {
+    fn postprocess_rename_inputs(
+        &mut self,
+        //nodes: &mut Vec<RefCell<Node>>,
+        inputs: &mut Vec<Argument>,
+    ) {
         for input in inputs.iter_mut() {
             if let Some(new_name) = self.old_node_names.get(&input.name) {
                 input.name = new_name.clone();
@@ -130,21 +146,20 @@ impl ONNXGraphBuilder {
         }
     }
 
-    fn check_constants(&mut self, node: &Node, node_type: &NodeType, i: usize) -> bool {
-        if node_type == &NodeType::Constant {
-            return true;
-        } else if self.constants_types.contains(node_type) {
+    fn check_constants(&mut self, node: &Node, node_type: &NodeType, i: usize) {
+        if node_type == &NodeType::Constant || self.constants_types.contains(node_type) {
             self.postprocess_for_constants.push(i);
+            self.nodes_to_remove.insert(i);
         }
-        false
     }
 
-    fn postprocess_constants(&mut self) {
+    fn postprocess_constants(&mut self, nodes: &Vec<RefCell<Node>>) {
         for i in self.postprocess_for_constants.iter() {
-            let mut node = &mut self.nodes[*i];
+            let mut node = nodes[*i].borrow_mut();
 
             for input in node.inputs.iter_mut().skip(1) {
-                if let Some(constant) = self.constants.get(&input.name) {
+                if let Some(idx) = self.constants_map.get(&input.name) {
+                    let constant = nodes[*idx].borrow();
                     if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() {
                         input.value = constant.outputs[0].value.clone();
                         input.ty = constant.outputs[0].ty.clone();
@@ -168,16 +183,18 @@ impl ONNXGraphBuilder {
 
     fn handle_unsqueeze(&mut self, node_type: &NodeType, node: &Node, i: usize) {
         if *node_type == NodeType::Unsqueeze {
             self.outputs_to_move.insert(node.outputs[0].name.clone(), i);
         }
     }
 
     fn postprocess_unsqueeze(
         &mut self,
+        nodes: &Vec<RefCell<Node>>,
         node_inputs: HashMap<String, ValueInfoProto>,
         graph_outputs: Vec<ValueInfoProto>,
     ) {
         for (output_name, i) in self.outputs_to_move.iter() {
-            let node = &mut self.nodes[*i];
             if let Some(val_proto) = node_inputs.get(output_name) {
+                let node = nodes[*i].borrow_mut();
                 move_output_shape(node, val_proto);
             } else {
                 for output in graph_outputs.iter() {
                     if output.name == *output_name {
+                        let node = nodes[*i].borrow_mut();
                         move_output_shape(node, output);
                     }
                 }
             }
         }
     }
 
     fn handle_identity(&mut self, node_type: &NodeType, node: &Node, i: usize) -> bool {
         if node_type == &NodeType::Identity && node.inputs[0].value.is_none() {
-            self.identity_nodes.push(node.name.clone());
-            // self.nodes_to_remove.insert(i);
+            self.identity_idx.push(i);
+            self.nodes_to_remove.insert(i);
             return true;
         }
         false
     }
 
-    fn postprocess_identity(&mut self) {
-        for identity in self.identity_nodes.iter() {
-            let identity_node = self
-                .constants
-                .get(identity)
-                .expect(&format!("Identity node {} not found", identity));
+    fn postprocess_identity(&mut self, nodes: &Vec<RefCell<Node>>) {
+        for identity_idx in self.identity_idx.iter() {
+            let identity_node = nodes[*identity_idx].borrow();
 
             let input_name = &identity_node.inputs[0].name;
             let identity_output = &identity_node.outputs[0].name;
 
             // Replace the identity node's output with its input in the connected nodes.
-            if let Some(node_index) = self.input_of.get(identity_output) {
-                for node_index in node_index {
-                    let node = &mut self.nodes[*node_index];
+            if let Some(indices) = self.input_of.get(identity_output) {
+                for node_index in indices {
+                    let mut node = nodes[*node_index].borrow_mut();
                     if let Some(matched_input) =
                         node.inputs.iter_mut().find(|x| x.name == *identity_output)
                     {
                         matched_input.name = input_name.clone();
                     }
                 }
             }
         }
     }
@@ -218,7 +232,7 @@ impl ONNXGraphBuilder {
         }
     }
 
-    fn handle_coalesce(&mut self, node_type: NodeType, node: &mut Node, i: usize) {
+    fn handle_coalesce(&mut self, node_type: &NodeType, node: &mut Node, i: usize) {
         match node_type {
             NodeType::Gemm => convert_gemm_to_linear(node),
             NodeType::MatMul => {
@@ -228,9 +242,9 @@ impl ONNXGraphBuilder {
         }
     }
 
-    fn postprocess_coalesce(&mut self) {
+    fn postprocess_coalesce(&mut self, nodes: &mut Vec<RefCell<Node>>) {
         for matmul_index in self.matmul_nodes.clone() {
-            convert_matmul_to_linear2(&mut self.nodes, matmul_index, &mut self.nodes_to_remove);
+            convert_matmul_to_linear2(nodes, matmul_index, &mut self.nodes_to_remove);
         }
     }
 }
@@ -273,9 +287,9 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
     let ONNXGraphBuilder {
         mut nodes,
         old_node_names,
-        nodes_to_remove,
         ..
     } = builder;
+    println!("nodes: {:#?}", nodes);
     // Convert the nodes
     // let mut nodes: Vec<Node> = vec![];
     // println!("onnx_model.graph.node: {:#?}", onnx_model.graph.node);
@@ -333,12 +347,7 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
     //     .iter()
     //     .map(|x| Argument::try_from(x.clone()).unwrap())
     //     .collect();
-    let mut i = 0;
-    nodes.retain(|node| {
-        let res = !nodes_to_remove.contains(&i);
-        i += 1;
-        res
-    });
+
     let mut inputs = onnx_model
         .graph
         .input
@@ -467,7 +476,7 @@ fn move_output_for_unsqueeze(
     }
 }
 
-fn move_output_shape(node: &mut Node, output_tensor: &ValueInfoProto) {
+fn move_output_shape<'parser>(mut node: RefMut<'parser, Node>, output_tensor: &ValueInfoProto) {
     match node.outputs[0].ty {
         ArgType::Tensor(ref mut tensor_type) => {
             if let Some(shape) = output_tensor
@@ -634,7 +643,9 @@ fn rename_inputs(
     outputs: &mut Vec<Argument>,
 ) -> HashMap<String, String> {
     let mut old_names = HashMap::new();
-
+    println!("inputs: {:#?}", inputs);
+    println!("outputs: {:#?}", outputs);
+    println!("nodes: {:#?}", nodes);
     // rename all graph input names to follow input1, input2, input3, etc.
     // (assumes the input names are already unique)
     let mut counter = 1;
@@ -648,6 +659,7 @@ fn rename_inputs(
 
     for node in nodes.iter_mut() {
         let mut counter = 1;
+        println!("node: {:#?}", node);
 
         // loop through node outputs and rename them and store the new name <-> old name mapping
         for output in node.outputs.iter_mut() {

From 9e4a50398f96ee3c24513f8b3497902b865b99b0 Mon Sep 17 00:00:00 2001
From: Joshua Ferguson
Date: Mon, 12 Feb 2024 10:51:31 -0600
Subject: [PATCH 03/21] back to passing tests

---
 burn-import/src/onnx/from_onnx.rs | 107 +++++++++---------------------
 1 file changed, 31 insertions(+), 76 deletions(-)

diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs
index fc221e919d..c39afcaa58 100644
--- a/burn-import/src/onnx/from_onnx.rs
+++ b/burn-import/src/onnx/from_onnx.rs
@@ -72,6 +72,7 @@ impl ONNXGraphBuilder {
             .iter()
             .map(|x| (x.name.clone(), x.clone()))
             .collect::<HashMap<String, ValueInfoProto>>();
+
         let mut nodes = Vec::with_capacity(model_proto.graph.node.len());
         for (i, node_proto) in model_proto.graph.node.iter().enumerate() {
             let mut node = convert_node_proto(node_proto);
@@ -90,6 +91,7 @@ impl ONNXGraphBuilder {
             }
 
             let node_type = node.node_type.clone();
+            self.handle_node_renaming(&node_type, &mut node);
             _ = self.handle_identity(&node_type, &node, i);
             self.check_constants(&node, &node_type, i);
 
@@ -97,10 +99,10 @@ impl ONNXGraphBuilder {
             //NOTE: still not done with this one
 
             self.handle_coalesce(&node_type, &mut node, i);
-            if !self.nodes_to_remove.contains(&i) {
-                //name stuff
-                self.handle_node_renaming(&node_type, &mut node);
-            }
+            // if !self.nodes_to_remove.contains(&i) && !self.constants_map.contains_key(&node.name) {
+            //     //name stuff
+            //     self.handle_node_renaming(&node_type, &mut node);
+            // }
 
             nodes.push(RefCell::new(node));
         }
@@ -108,6 +110,7 @@ impl ONNXGraphBuilder {
         self.postprocess_identity(&nodes);
         self.postprocess_constants(&nodes);
         self.postprocess_coalesce(&mut nodes);
+
         self.nodes = nodes
             .into_iter()
             .enumerate()
@@ -130,8 +133,8 @@ impl ONNXGraphBuilder {
         let new_name =
             format!("{}{}", node.node_type, self.node_name_counter[&node_type]).to_lowercase();
         node.name = new_name.clone();
-        println!("old name {:?} new name {:?}", old_name, new_name);
-        self.old_node_names.insert(old_name, new_name);
+        self.old_node_names
+            .insert(old_name.clone(), new_name.clone());
     }
 
     fn postprocess_rename_inputs(
@@ -147,27 +150,32 @@ impl ONNXGraphBuilder {
         }
     }
 
     fn check_constants(&mut self, node: &Node, node_type: &NodeType, i: usize) {
-        if node_type == &NodeType::Constant || self.constants_types.contains(node_type) {
+        if node_type == &NodeType::Constant
+            || (node_type == &NodeType::Identity && node.inputs[0].value.is_some())
+        {
+            self.constants_map.insert(node.outputs[0].name.clone(), i);
+        } else if self.constants_types.contains(node_type) {
             self.postprocess_for_constants.push(i);
-            self.nodes_to_remove.insert(i);
         }
     }
 
     fn postprocess_constants(&mut self, nodes: &Vec<RefCell<Node>>) {
-        for i in self.postprocess_for_constants.iter() {
-            let mut node = nodes[*i].borrow_mut();
+        for check_idx in self.postprocess_for_constants.iter() {
+            let mut node = nodes[*check_idx].borrow_mut();
 
             for input in node.inputs.iter_mut().skip(1) {
-                if let Some(idx) = self.constants_map.get(&input.name) {
-                    let constant = nodes[*idx].borrow();
+                if let Some(const_idx) = self.constants_map.get(&input.name) {
+                    let constant = nodes[*const_idx].borrow();
                     if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() {
-                        input.value = constant.outputs[0].value.clone();
-                        input.ty = constant.outputs[0].ty.clone();
+                        // The value comes from Identity inputs
+                        input.value = constant.inputs[0].value.clone();
+                        input.ty = constant.inputs[0].ty.clone();
                     } else {
                         let arg = convert_constant_value(&constant);
                         input.value = arg.value;
                         input.ty = arg.ty;
                     }
+                    self.nodes_to_remove.insert(*const_idx);
                 }
             }
         }
     }
@@ -289,64 +297,11 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph {
         old_node_names,
         ..
     } = builder;
-    println!("nodes: {:#?}", nodes);
-    // Convert the nodes
-    // let mut nodes: Vec<Node> = vec![];
-    // println!("onnx_model.graph.node: {:#?}", onnx_model.graph.node);
-    // for onnx_node in onnx_model.graph.node.iter() {
-    //     let mut node = convert_node_proto(onnx_node);
-    //     if onnx_node.op_type == "Unsqueeze" {
-    //         move_output_for_unsqueeze(
-    //             &mut node,
-    //             onnx_model.graph.input.clone(),
-    //             onnx_model.graph.output.clone(),
-    //         );
-    //     }
-    //     remap_node_type(&mut node);
-    //     nodes.push(node);
-    // }
-
-    // // ONNX nodes must be topologically sorted per spec:
-    // // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
-    // assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted");
-
-    // //TMP: converts initializers to hashmap, iterates over nodes
-    // // Move inputs with initializers to states
-    // move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer);
-    // //TMP: loops over all nodes, filters for identity nodes,
-    // //for each identity node, loop over all nodes and if the identity is in the inputs, replace
-    // // Handle Identity nodes (expects inputs to be moved to states)
-    //handle_identity(&mut nodes);
-    // //TMP: loops over all nodes, filters for constants, collects them into a hashmap
-    // //loop over all nodes, filter for types in constant lifting, for each input
-    // //if it's in the hashmap, move the data to appropriate input
-    // // Lift constants to initializers (expects inputs to be moved to states)
-    //lift_constants(&mut nodes);
-    // //TMP: loops over all nodes, if type matches one of the current coalesce types, call the function
-    // //loop
-    // // Coalesce and transform nodes
-    // coalesce(&mut nodes);
-
-    // //TMP: loop over all nodes, rename each with its name and a counter
-    // //keep a map of old names to new names
-    // // Rename nodes and inputs, save the mapping for later
-    // let old_node_names = rename_nodes(&mut nodes);
-
-    // // This function collects the inputs of an ONNX model and returns them as a vector of Arguments.
-    // let mut inputs = onnx_model
-    //     .graph
-    //     .input
-    //     .iter()
-    //     .map(|x| Argument::try_from(x.clone()).unwrap())
-    //     .collect();
-
-    // // Map each output in the model's graph to an Argument and collect them into a vector.
-    // let mut outputs = onnx_model
-    //     .graph
-    //     .output
-    //     .iter()
-    //     .map(|x| Argument::try_from(x.clone()).unwrap())
-    //     .collect();
+    //println!("nodes: {:#?}", nodes);
+
+    // ONNX nodes must be topologically sorted per spec:
+    // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
+    assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted");
 
     let mut inputs = onnx_model
         .graph
@@ -598,9 +553,9 @@ fn rename_inputs(
     outputs: &mut Vec<Argument>,
 ) -> HashMap<String, String> {
     let mut old_names = HashMap::new();
-    println!("inputs: {:#?}", inputs);
-    println!("outputs: {:#?}", outputs);
-    println!("nodes: {:#?}", nodes);
+    //println!("inputs: {:#?}", inputs);
+    //println!("outputs: {:#?}", outputs);
+    //println!("nodes: {:#?}", nodes);
     // rename all graph input names to follow input1, input2, input3, etc.
     // (assumes the input names are already unique)
     let mut counter = 1;
 
     for node in nodes.iter_mut() {
         let mut counter = 1;
-        println!("node: {:#?}", node);
+        //println!("node: {:#?}", node);
 
         // loop through node outputs and rename them and store the new name <-> old name mapping
         for output in node.outputs.iter_mut() {

From d6ed1d53357f3c8f6ad49942340029fc48839af4 Mon Sep 17 00:00:00 2001
From: Joshua Ferguson
Date: Fri, 16 Feb 2024 11:15:42 -0600
Subject: [PATCH 04/21] Reworked IO into separate struct

---
 burn-import/src/onnx/coalesce.rs  |   6 +
 burn-import/src/onnx/from_onnx.rs | 528 +++++++++++++-----------------
 burn-import/src/onnx/ir.rs        |   6 -
 3 files changed, 232 insertions(+), 308 deletions(-)

diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs
index 1ea4c236fb..797fb6f24c 100644
--- a/burn-import/src/onnx/coalesce.rs
+++ b/burn-import/src/onnx/coalesce.rs
@@ -154,6 +154,12 @@ fn convert_matmul_to_linear(
     }
 }
 
+/// This function converts a MatMul node into a Linear node if possible.
+///
+/// PyTorch and other frameworks use MatMul node to represent Linear layer.
+///
+/// This function also converts the following Add node into a Linear node if possible.
+/// Add node is used to represent bias in PyTorch.
 pub(crate) fn convert_matmul_to_linear2(
     node_vec: &Vec<RefCell<Node>>,
     node_index: usize,
     nodes_to_remove: &mut HashSet<usize>,
 ) {
diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs
index c39afcaa58..9715e28d92 100644
--- a/burn-import/src/onnx/from_onnx.rs
+++ b/burn-import/src/onnx/from_onnx.rs
@@ -7,20 +7,19 @@ use std::{
 };
 
 use crate::onnx::{
-    coalesce::coalesce, ir::TensorType, node_remap::remap_node_type,
-    proto_conversion::convert_node_proto,
+    ir::TensorType, node_remap::remap_node_type, proto_conversion::convert_node_proto,
 };
 
 use super::{
     coalesce::convert_gemm_to_linear,
-    protos::{ModelProto, TensorProto},
+    protos::{ModelProto, TensorProto, ValueInfoProto},
 };
 
+use super::dim_inference::dim_inference;
 use super::{
     coalesce::convert_matmul_to_linear2,
     ir::{ArgType, Argument, Node, NodeType, ONNXGraph, Tensor},
 };
-use super::{dim_inference::dim_inference, protos::ValueInfoProto};
 
 use protobuf::Message;
 
     NodeType::Reshape,
     NodeType::Unsqueeze,
 ];
+
+#[derive(Debug)]
+pub(crate) enum IOEntry {
+    In(usize),
+    Out(usize),
+    Node(usize),
+}
+
+pub(crate) struct OnnxGraphIO {
+    pub(crate) inputs: Vec<Argument>,
+    pub(crate) outputs: Vec<Argument>,
+    ///updated names of node outputs that are not stored in the graph
+    node_out: Vec<Box<String>>,
+    ///map of old input names to a vec of indices of nodes that use it
+    input_of: HashMap<String, Vec<usize>>,
+    pub(crate) old_io_names: HashMap<String, IOEntry>,
+}
+
+impl OnnxGraphIO {
+    pub(crate) fn new(inputs: Vec<ValueInfoProto>, outputs: Vec<ValueInfoProto>) -> Self {
+        let mut old_io_names = HashMap::new();
+        let mut in_count = 1;
+        let inputs = inputs
+            .iter()
+            .enumerate()
+            .map(|(i, x)| {
+                let in_name = format!("input{}", in_count);
+                old_io_names.insert(x.name.clone(), IOEntry::In(i));
+                let mut arg = Argument::try_from(x.clone()).unwrap();
+                in_count += 1;
+                arg.name = in_name;
+                arg
+            })
+            .collect::<Vec<Argument>>();
+
+        let outputs = outputs
+            .iter()
+            .enumerate()
+            .map(|(i, x)| {
+                old_io_names.insert(x.name.clone(), IOEntry::Out(i));
+                Argument::try_from(x.clone()).unwrap()
+            })
+            .collect::<Vec<Argument>>();
+        let in_len = inputs.len();
+        Self {
+            inputs,
+            outputs,
+            node_out: Vec::new(),
+            old_io_names,
+            input_of: HashMap::with_capacity(in_len),
+        }
+    }
+
+    fn get_mut(&mut self, old_name: &str) -> Option<&mut Argument> {
+        match self.old_io_names.get(old_name) {
+            Some(IOEntry::In(i)) => self.inputs.get_mut(*i),
+            Some(IOEntry::Out(i)) => self.outputs.get_mut(*i),
+            Some(IOEntry::Node(i)) => panic!("This is a node output"),
+            None => None,
+        }
+    }
+
+    fn update(&mut self, old_name: &str, new_name: &str) {
+        match self.old_io_names.get(old_name) {
+            Some(IOEntry::In(i)) => {
+                let arg = self.inputs.get_mut(*i).unwrap();
+                arg.name = new_name.to_string();
+            }
+            Some(IOEntry::Out(i)) => {
+                let arg = self.outputs.get_mut(*i).unwrap();
+                arg.name = new_name.to_string();
+            }
+            Some(IOEntry::Node(i)) => {
+                panic!("This output is from another node");
+            }
+            None => {
+                let idx = self.node_out.len();
+                self.node_out.push(Box::new(new_name.to_string()));
+                self.old_io_names
+                    .insert(old_name.to_string(), IOEntry::Node(idx));
+            }
+        }
+    }
+
+    fn add_input(&mut self, old_name: &str, node_idx: usize) {
+        self.input_of
+            .entry(old_name.to_string())
+            .and_modify(|f| f.push(node_idx))
+            .or_insert(vec![node_idx]);
+    }
+
+    fn get(&self, old_name: &str) -> Option<&Argument> {
+        match self.old_io_names.get(old_name) {
+            Some(IOEntry::In(i)) => self.inputs.get(*i),
+            Some(IOEntry::Out(i)) => self.outputs.get(*i),
+            Some(IOEntry::Node(i)) => panic!("This is a node output"),
+            None => None,
+        }
+    }
+
+    fn get_new_name(&self, old_name: &str) -> Option<String> {
+        let new_name = match self.old_io_names.get(old_name) {
+            Some(IOEntry::In(i)) => Some(self.inputs[*i].name.clone()),
+            Some(IOEntry::Out(i)) => Some(self.outputs[*i].name.clone()),
+            Some(IOEntry::Node(i)) => Some(*self.node_out[*i].clone()),
+            None => None,
+        };
+
+        println!("new name value {:?}", &new_name);
+        if Some(old_name.to_string()) == new_name {
+            println!("old name hasn't changed: {}", old_name);
+            None
+        } else {
+            new_name
+        }
+    }
+
+    fn get_node_indices(&self, old_input_name: &str) -> Option<&Vec<usize>> {
+        self.input_of.get(old_input_name)
+    }
+}
+
 #[derive(Default)]
 pub(crate) struct ONNXGraphBuilder {
     nodes: Vec<Node>,
+    inputs: Vec<Argument>,
+    outputs: Vec<Argument>,
+    // old_io_names: HashMap<String, IOEntry>,
     node_name_counter: HashMap<NodeType, usize>,
-    old_node_names: HashMap<String, String>,
-    //map of input names to a vec of indices of nodes that use it
-    input_of: HashMap<String, Vec<usize>>,
     outputs_to_move: HashMap<String, usize>,
     //map of output names to the index of the node that produces them
     output_of: HashMap<String, usize>,
     //nodes to remove
     nodes_to_remove: HashSet<usize>,
     constants_map: HashMap<String, usize>,
     //constants to lift
     postprocess_for_constants: Vec<usize>,
     constants_types: HashSet<NodeType>,
     //identity_nodes
     identity_idx: Vec<usize>,
     //matmul nodes
     matmul_nodes: Vec<usize>,
 }
 
 impl ONNXGraphBuilder {
     pub(crate) fn node_gen(&mut self, model_proto: &ModelProto) {
         self.constants_types = LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect();
         // Convert initializers to hashmap for faster lookup
         let initializers = model_proto
             .graph
             .initializer
             .iter()
             .map(|x| (x.name.clone(), x.clone()))
             .collect::<HashMap<String, TensorProto>>();
 
-        let inputs = model_proto
-            .graph
-            .input
-            .iter()
-            .map(|x| (x.name.clone(), x.clone()))
-            .collect::<HashMap<String, ValueInfoProto>>();
+        let mut graph_io = OnnxGraphIO::new(
+            model_proto.graph.input.clone(),
+            model_proto.graph.output.clone(),
+        );
 
         let mut nodes = Vec::with_capacity(model_proto.graph.node.len());
         for (i, node_proto) in model_proto.graph.node.iter().enumerate() {
             let mut node = convert_node_proto(node_proto);
             for node_input in node.inputs.iter_mut() {
-                self.input_of
-                    .entry(node_input.name.clone())
-                    .and_modify(|f| f.push(i))
-                    .or_insert(vec![i]);
+                // self.input_of
+                //     .entry(node_input.name.clone())
+                //     .and_modify(|f| f.push(i))
+                //     .or_insert(vec![i]);
                 if let Some(initializer) = initializers.get(&node_input.name) {
                     move_initializer_data(initializer, node_input);
                 }
             }
             remap_node_type(&mut node);
-            for node_output in node.outputs.iter() {
-                self.output_of.insert(node_output.name.clone(), i);
-            }
+            // for node_output in node.outputs.iter() {
+            //     self.output_of.insert(node_output.name.clone(), i);
+            // }
 
             let node_type = node.node_type.clone();
             self.handle_node_renaming(&node_type, &mut node);
+            self.handle_coalesce(&node_type, &mut node, i);
 
             self.handle_unsqueeze(&node_type, &node, i);
 
             _ = self.handle_identity(&node_type, &node, i);
 
+            self.handle_rename_io(&mut node, i, &mut graph_io);
+
             self.check_constants(&node, &node_type, i);
             //NOTE: still not done with this one
 
             // if !self.nodes_to_remove.contains(&i) && !self.constants_map.contains_key(&node.name) {
             //     //name stuff
             //     self.handle_node_renaming(&node_type, &mut node);
             // }
 
             nodes.push(RefCell::new(node));
         }
-        self.postprocess_unsqueeze(&nodes, inputs, model_proto.graph.output.clone());
-        self.postprocess_identity(&nodes);
+        self.postprocess_unsqueeze(&nodes, &graph_io);
+        self.postprocess_identity(&nodes, &graph_io);
         self.postprocess_constants(&nodes);
         self.postprocess_coalesce(&mut nodes);
 
         self.nodes = nodes
             .into_iter()
             .enumerate()
             .filter_map(|(i, x)| {
                 if !self.nodes_to_remove.contains(&i) {
                     Some(x.into_inner())
                 } else {
                     None
                 }
             })
             .collect();
+        let OnnxGraphIO {
+            inputs, outputs, ..
+        } = graph_io;
+        self.inputs = inputs;
+        self.outputs = outputs;
     }
 
     fn handle_node_renaming(&mut self, node_type: &NodeType, node: &mut Node) {
         self.node_name_counter
             .entry(node_type.clone())
             .and_modify(|e| *e += 1)
             .or_insert(1);
-        let old_name = node.name.clone();
         let new_name =
             format!("{}{}", node.node_type, self.node_name_counter[&node_type]).to_lowercase();
         node.name = new_name.clone();
-        self.old_node_names
-            .insert(old_name.clone(), new_name.clone());
     }
 
+    fn handle_rename_io(&mut self, node: &mut Node, i: usize, graph_io: &mut OnnxGraphIO) {
+        for node_input in node.inputs.iter_mut() {
+            println!("old output names {:?}", &graph_io.old_io_names);
+            //println!("out_args{:?}", outputs);
+            graph_io.add_input(&node_input.name, i);
+            if let Some(input_name) = graph_io.get_new_name(&node_input.name) {
+                println!("yeet");
+                node_input.passed = true;
+                node_input.name = input_name.clone();
+            } else {
+                node_input.name = "".to_string();
+                node_input.passed = false;
+            }
+        }
+        println!("\n\nchecking outputs");
+        let mut out_count = 1;
+        for output in node.outputs.iter_mut() {
+            println!("output name: {}", &output.name);
+
+            let new_name = format!("{}_out{}", node.name, out_count);
+
+            graph_io.update(&output.name, &new_name);
+
+            // self.node_output_names
+            //     .insert(output.name.clone(), new_name.clone());
+
+            output.name = new_name.clone();
+            out_count += 1;
+        }
+    }
 
     fn check_constants(&mut self, node: &Node, node_type: &NodeType, i: usize) {
         if node_type == &NodeType::Constant
             || (node_type == &NodeType::Identity && node.inputs[0].value.is_some())
         {
             self.constants_map.insert(node.outputs[0].name.clone(), i);
         } else if self.constants_types.contains(node_type) {
             self.postprocess_for_constants.push(i);
         }
     }
 
     fn postprocess_constants(&mut self, nodes: &Vec<RefCell<Node>>) {
         for check_idx in self.postprocess_for_constants.iter() {
             let mut node = nodes[*check_idx].borrow_mut();
 
             for input in node.inputs.iter_mut().skip(1) {
+                println!("checking input {:?} for const", input);
+
                 if let Some(const_idx) = self.constants_map.get(&input.name) {
                     let constant = nodes[*const_idx].borrow();
                     if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() {
                         // The value comes from Identity inputs
                         input.value = constant.inputs[0].value.clone();
                         input.ty = constant.inputs[0].ty.clone();
                     } else {
                         let arg = convert_constant_value(&constant);
                         input.value = arg.value;
                         input.ty = arg.ty;
                     }
                     self.nodes_to_remove.insert(*const_idx);
                 }
             }
         }
     }
 
-    fn postprocess_unsqueeze(
-        &mut self,
-        nodes: &Vec<RefCell<Node>>,
-        node_inputs: HashMap<String, ValueInfoProto>,
-        graph_outputs: Vec<ValueInfoProto>,
-    ) {
-        for (output_name, i) in self.outputs_to_move.iter() {
-            if let Some(val_proto) = node_inputs.get(output_name) {
-                let node = nodes[*i].borrow_mut();
-                move_output_shape(node, val_proto);
-            } else {
-                for output in graph_outputs.iter() {
-                    if output.name == *output_name {
-                        let node = nodes[*i].borrow_mut();
-                        move_output_shape(node, output);
-                    }
-                }
-            }
-        }
-    }
+    fn postprocess_unsqueeze(&mut self, nodes: &Vec<RefCell<Node>>, graph_io: &OnnxGraphIO) {
+        for (old_output_name, i) in self.outputs_to_move.iter() {
+            if let Some(in_arg) = graph_io.get(old_output_name) {
+                let node = nodes[*i].borrow_mut();
+                move_output_shape(node, in_arg);
+            }
+        }
+    }
 
     fn handle_identity(&mut self, node_type: &NodeType, node: &Node, i: usize) -> bool {
         if node_type == &NodeType::Identity && node.inputs[0].value.is_none() {
             self.identity_idx.push(i);
             self.nodes_to_remove.insert(i);
             return true;
         }
         false
     }
 
-    fn postprocess_identity(&mut self, nodes: &Vec<RefCell<Node>>) {
+    fn postprocess_identity(&mut self, nodes: &Vec<RefCell<Node>>, graph_io: &OnnxGraphIO) {
         for identity_idx in self.identity_idx.iter() {
             let identity_node = nodes[*identity_idx].borrow();
 
             let input_name = &identity_node.inputs[0].name;
             let identity_output = &identity_node.outputs[0].name;
 
             // Replace the identity node's output with its input in the connected nodes.
-            if let Some(indices) = self.input_of.get(identity_output) {
+            if let Some(indices) = graph_io.get_node_indices(identity_output) {
                 for node_index in indices {
                     let mut node = nodes[*node_index].borrow_mut();
                     if let Some(matched_input) =
                         node.inputs.iter_mut().find(|x| x.name == *identity_output)
                     {
                         matched_input.name = input_name.clone();
                     }
                 }
             }
         }
     }
 
+    /// The function transforms the graph into a new one where the nodes are coalesced into a single node.
     fn handle_coalesce(&mut self, node_type: &NodeType, node: &mut Node, i: usize) {
         match node_type {
-            NodeType::Gemm => convert_gemm_to_linear(node),
+            NodeType::Gemm => {
+                println!("Gemm before {:?}\n", node);
+                convert_gemm_to_linear(node);
+                self.handle_node_renaming(&node.node_type.clone(), node);
+                println!("Gemm after {:?}\n", node);
+            }
             NodeType::MatMul => {
                 self.matmul_nodes.push(i);
             }
             _ => {}
         }
     }
 
     fn postprocess_coalesce(&mut self, nodes: &mut Vec<RefCell<Node>>) {
+        println!("{:?}", self.node_name_counter);
         for matmul_index in self.matmul_nodes.clone() {
             convert_matmul_to_linear2(nodes, matmul_index, &mut self.nodes_to_remove);
+            let mut node = nodes[matmul_index].borrow_mut();
+            self.handle_node_renaming(&node.node_type.clone(), &mut node)
         }
     }
 }
 
     let ONNXGraphBuilder {
         mut nodes,
-        old_node_names,
+        inputs: mut inner_inputs,
+        outputs: mut inner_outputs,
         ..
     } = builder;
-    //println!("nodes: {:#?}", nodes);
 
     // ONNX nodes must be topologically sorted per spec:
     // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
     assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted");
 
+    let my_nodes = nodes.clone();
+
+    for i in 0..nodes.len() {
+        if nodes[i] != my_nodes[i] {
+            println!("{} != {}", nodes[i].name, my_nodes[i].name);
+        }
+    }
+    println!("nodes: {:#?}", nodes);
+    println!("inner inputs: {:#?}", inner_inputs);
+    println!("inner outputs: {:#?}", inner_outputs);
 
-    let mut inputs = onnx_model
-        .graph
-        .input
-        .iter()
-        .map(|x| Argument::try_from(x.clone()).unwrap())
-        .collect();
-
-    let mut outputs = onnx_model
-        .graph
-        .output
-        .iter()
-        .map(|x| Argument::try_from(x.clone()).unwrap())
-        .collect();
-
-    let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs);
     // Infer shapes and update the inputs and outputs
-    dim_inference(&mut nodes, &inputs, &mut outputs);
-
+    dim_inference(&mut nodes, &inner_inputs, &mut inner_outputs);
+    println!("inner outputs after dim inference: {:?}", inner_outputs);
     // Remove the graph inputs/output that are not used by any node
-    remove_unused_graph_inputs(&mut inputs, &mut outputs, &nodes);
+    remove_unused_graph_inputs(&mut inner_inputs, &mut inner_outputs, &nodes);
 
     log::info!("Finished parsing ONNX file: {}", onnx_path.display());
 
     ONNXGraph {
         nodes,
-        inputs,
-        outputs,
-        old_node_names,
-        old_input_names,
+        inputs: inner_inputs,
+        outputs: inner_outputs,
     }
 }
 
-/// This function moves inputs that are also present
-/// in the initializer to the node's states vector.
-/// It also removes inputs that are already present in the states vector.
-///
-/// # Arguments
-///
-/// * `nodes` - A mutable reference to a vector of nodes
-/// * `initializers` - A vector of TensorProto
-fn move_inputs_to_state(nodes: &mut Vec<Node>, initializers: &[TensorProto]) {
-    // Convert initializers to hashmap for faster lookup
-    let initializers = initializers
-        .iter()
-        .map(|x| (x.name.clone(), x.clone()))
-        .collect::<HashMap<String, TensorProto>>();
-
-    // Iterate over each node in the graph
-    nodes.iter_mut().for_each(|node| {
-        for input in node.inputs.iter_mut() {
-            // If there is a corresponding initializer for the input, then move the data to the input value
-            if let Some(initializer) = initializers.get(&input.name) {
-                move_initializer_data(initializer, input);
-            }
-        }
-    });
-}
-
 fn move_initializer_data(initializer: &TensorProto, input: &mut Argument) {
     // If the input name matches the tensor name in the initializer
     // Convert the initializer to a tensor
@@ -425,67 +498,17 @@ fn move_initializer_data(initializer: &TensorProto, input: &mut Argument) {
     }
 }
 
-//this is an extremely hacky temporary solution while I figure out how to properly handle this
-//situation
-fn move_output_for_unsqueeze(
-    node: &mut Node,
-    outputs: Vec<ValueInfoProto>,
-    inputs: Vec<ValueInfoProto>,
-) {
-    let output_name = node.outputs[0].name.clone();
-    // check outputs first, as it's shorter
-    for output in outputs.iter() {
-        if output.name == output_name {
-            match node.outputs[0].ty {
-                ArgType::Tensor(ref mut tensor_type) => {
-                    if let Some(shape) = output.type_.as_ref().unwrap().tensor_type().shape.as_ref()
-                    {
-                        tensor_type.shape =
-                            Some(shape.dim.iter().map(|x| x.dim_value() as usize).collect());
-                        return;
-                    }
-                }
-                _ => return,
-            }
-        }
-    }
-    for input in inputs.iter() {
-        if input.name == output_name {
-            //copy the shape
-            match node.outputs[0].ty {
-                ArgType::Tensor(ref mut tensor_type) => {
-                    if let Some(shape) = input.type_.as_ref().unwrap().tensor_type().shape.as_ref()
-                    {
-                        tensor_type.shape =
-                            Some(shape.dim.iter().map(|x| x.dim_value() as usize).collect());
-                        return;
-                    }
-                }
-                _ => return,
-            }
-        }
-    }
-}
-
-fn move_output_shape<'parser>(mut node: RefMut<'parser, Node>, output_tensor: &ValueInfoProto) {
+fn move_output_shape<'parser>(mut node: RefMut<'parser, Node>, out_arg: &Argument) {
     match node.outputs[0].ty {
         ArgType::Tensor(ref mut tensor_type) => {
-            if let Some(shape) = output_tensor
-                .type_
-                .as_ref()
-                .unwrap()
-                .tensor_type()
-                .shape
-                .as_ref()
-            {
-                tensor_type.shape =
-                    Some(shape.dim.iter().map(|x| x.dim_value() as usize).collect());
-            }
+            if let ArgType::Tensor(arg_tensor) = &out_arg.ty {
+                tensor_type.shape = arg_tensor.shape.clone();
+            }
         }
         _ => {}
     }
 }
 
-/// Lift constants from the graph into the states vector for known node types.
-///
-/// The primary reason to move constants into the states vector is to reduce the number of nodes in the graph,
-/// and consistently utilize the same interface for all nodes (constant inputs and inputs with initializers are
-/// treated the same way). This simplification aids code generation.
-///
-/// For example, if we have a graph ([Const1, Const2, Conv2d1]) where the Conv2d node has 3 inputs
-/// (graph_input, const2_out1, const_out2), we can lift the constants into the states of the Conv2d node.
-/// const2_out1 and const_out2 are used for the weights and bias of the Conv2d node.
-/// After lifting, we will have a graph ([Conv2d1]) where the Conv2d node has 1 input (graph_input) and 2 states.
-///
-/// Also note that often times, Conv2d node's inputs are not constants, but they are initializers. Initializers
-/// move to the states vector as well, using the `move_inputs_to_state` function.
-///
-///
-/// # Arguments
-///
-/// * `nodes` - A mutable reference to a vector of nodes
-///
-/// # Panics
-///
-/// Panics if the node's output is not a constant.
-fn lift_constants(nodes: &mut Vec<Node>) {
-    log::info!("Lifting constants into the states");
-
-    // create a set to hold the node types to process
-    let node_types_to_process: HashSet<NodeType> =
-        LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect();
-
-    // create a new vector to hold the graph's constants (index by the node's name)
-    let constants = nodes
-        .iter()
-        .filter(|node| node.node_type == NodeType::Constant || node.node_type == NodeType::Identity)
-        .map(|node| (node.outputs[0].name.clone(), node.clone()))
-        .collect::<HashMap<String, Node>>();
-
-    // create a set to hold the IDs of constants to be removed
-    let mut constant_to_removed = HashSet::<String>::new();
-
-    for node in nodes.iter_mut() {
-        // Skip the node if it is not in the set of node types to process
-        if !node_types_to_process.contains(&node.node_type) {
-            continue;
-        }
-
-        // Skip the first input because it is the node's true input and not a constant/state
-        node.inputs
-            .iter_mut()
-            .skip(1) // TODO make configurable
-            .for_each(|input| {
-                if let Some(constant) = constants.get(&input.name) {
-                    if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() {
-                        // The value comes from Identity inputs
-                        if let Some(constant_input) = constant.inputs.first() {
-                            input.ty = constant_input.ty.clone();
-                            input.value = constant_input.value.clone();
-                        }
-                    } else {
-                        // The value comes from an attribute
-                        let arg = convert_constant_value(constant); // get the value of the constant
-
-                        input.value = arg.value; // set the input's value to the constant's value
-                        input.ty = arg.ty; // set the input's type to the constant's type
-                                           // remove the constant from the graph
-                    }
-                    constant_to_removed.insert(constant.name.clone());
-                }
-            });
-    }
-
-    // remove the constants that were moved to the states vector
-    nodes.retain(|node| !constant_to_removed.contains(&node.name));
-
-    log::debug!(
-        "The number of constants lifted: {}",
-        constant_to_removed.len()
-    );
-}
-
-fn handle_identity(nodes: &mut Vec<Node>) {
-    log::info!("Handling identity nodes");
-
-    let mut nodes_to_remove = HashSet::new();
-
-    let identity_nodes = nodes
-        .iter()
-        .filter(|node| node.node_type == NodeType::Identity)
-        .cloned()
-        .collect::<Vec<Node>>();
-
-    // Handle pass-through nodes.
-    for identity_node in identity_nodes {
-        if identity_node.node_type == NodeType::Identity && identity_node.inputs[0].value.is_none()
-        {
-            let input_name = &identity_node.inputs[0].name;
-            let output_name = &identity_node.outputs[0].name;
-
-            // Replace the identity node's output with its input in the connected nodes.
-            for node in nodes.iter_mut() {
-                if let Some(matched_input) = node.inputs.iter_mut().find(|x| x.name == *output_name)
-                {
-                    matched_input.name = input_name.clone();
-                }
-            }
-
-            nodes_to_remove.insert(identity_node);
-        }
-    }
-
-    // Remove the identity nodes.
-    nodes.retain(|node| !nodes_to_remove.contains(node));
-}
-
-/// Rename the nodes in the graph to be unique and return a map of the old names to the new names.
-fn rename_nodes(nodes: &mut Vec<Node>) -> HashMap<String, String> {
-    let mut old_names = HashMap::new();
-    let mut counter: HashMap<NodeType, usize> = HashMap::new();
-
-    for node in nodes.iter_mut() {
-        // keep track of the number of nodes of each type
-        counter
-            .entry(node.node_type.clone())
-            .and_modify(|e| *e += 1)
-            .or_insert(1);
-
-        let old_name = node.name.clone();
-        let new_name = format!("{}{}", node.node_type, counter[&node.node_type]).to_lowercase();
-
-        node.name = new_name.clone();
-
-        old_names.insert(old_name, new_name);
-    }
-
-    old_names
-}
-
 /// Rename the inputs and output in the graph and return a map of
 /// the old names to the new names.
 ///
             let new_name = format!("{}_out{}", node.name, counter);
             output.name = new_name.clone();
             old_names.insert(old_name, new_name);
+            //old_names.insert(old_name, new_name);
             counter += 1;
         }
     }
 
-    for node in nodes.iter_mut() {
-        // loop through node inputs and rename them with previously replaced names
-        // and mark them as passed if they are in the old_names map (i.e. they are node outputs)
-        for input in node.inputs.iter_mut() {
-            if let Some(new_name) = old_names.get(&input.name) {
-                input.name = new_name.clone();
-                input.passed = true;
-            } else {
-                input.name = "".to_string(); // Rename to a placeholder
-                input.passed = false;
-            }
-        }
-    }
+    // for node in nodes.iter_mut() {
+    //     // loop through node inputs and rename them with previously replaced names
+    //     // and mark them as passed if they are in the old_names map (i.e. they are node outputs)
+    //     for input in node.inputs.iter_mut() {
+    //         if let Some(new_name) = old_names.get(&input.name) {
+    //             input.name = new_name.clone();
+    //             input.passed = true;
+    //         } else {
+    //             input.name = "".to_string(); // Rename to a placeholder
+    //             input.passed = false;
+    //         }
+    //     }
+    // }
 
     // Rename the graph outputs
     for output in outputs.iter_mut() {
diff --git a/burn-import/src/onnx/ir.rs b/burn-import/src/onnx/ir.rs
index 65f03bc102..8489cfdba1 100644
--- a/burn-import/src/onnx/ir.rs
+++ b/burn-import/src/onnx/ir.rs
@@ -138,12 +138,6 @@ pub struct ONNXGraph {
 
     /// The outputs of the graph.
     pub outputs: Vec<Argument>,
-
-    /// The original node names.
-    pub old_node_names: HashMap<String, String>,
-
-    /// The original input names.
-    pub old_input_names: HashMap<String, String>,
 }
 
 /// Nodes produced by the ONNX parser

From c350d0372b125cae20f81f0c628b431ee4fb06b2 Mon Sep 17 00:00:00 2001
From: Joshua Ferguson
Date: Fri, 16 Feb 2024 13:52:01 -0600
Subject: [PATCH 05/21] working towards exploiting topological ordering and
 more informative ident errors

---
 burn-import/src/burn/ty.rs        | 15 +++++++
 burn-import/src/onnx/coalesce.rs  | 68 ++++++++++++++-----------------
 burn-import/src/onnx/from_onnx.rs | 59 +++++++++++++++------------
 3 files changed, 78 insertions(+), 64 deletions(-)

diff --git a/burn-import/src/burn/ty.rs b/burn-import/src/burn/ty.rs
index f5fa41e774..963ee81abf 100644
--- a/burn-import/src/burn/ty.rs
+++ b/burn-import/src/burn/ty.rs
@@ -72,6 +72,9 @@ impl Type {
 
 impl ScalarType {
     pub fn new<S: AsRef<str>>(name: S, kind: ScalarKind) -> Self {
+        if name.as_ref().is_empty() {
+            panic!("Scalar of Type {:?} was passed with empty name", kind);
+        }
         Self {
             name: Ident::new(name.as_ref(), Span::call_site()),
             kind,
@@ -95,6 +98,12 @@ impl TensorType {
         kind: TensorKind,
         shape: Option<Vec<usize>>,
     ) -> Self {
+        if name.as_ref().is_empty() {
+            panic!(
+                "Tensor of Kind {:?} with dim shape {:?} was passed with empty name",
+                kind, shape
+            );
+        }
         Self {
             name: Ident::new(name.as_ref(), Span::call_site()),
             dim,
@@ -141,6 +150,12 @@ impl TensorType {
 
 impl OtherType {
     pub fn new<S: AsRef<str>>(name: S, tokens: TokenStream) -> Self {
+        if name.as_ref().is_empty() {
+            panic!(
+                "Other type with tokens {:?} was passed with empty name",
+                tokens
+            );
+        }
         Self {
             name: Ident::new(name.as_ref(), Span::call_site()),
             ty: tokens,
diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs
index 797fb6f24c..94f4df35b2 100644
--- a/burn-import/src/onnx/coalesce.rs
+++ b/burn-import/src/onnx/coalesce.rs
@@ -2,29 +2,24 @@ use std::{
     cell::{RefCell, RefMut},
     collections::HashSet,
     iter::Peekable,
-    slice::IterMut,
+    slice::{Iter, IterMut},
 };
 
-use super::ir::{AttributeValue, Node, NodeType};
+use super::{
+    ir::{AttributeValue, Node, NodeType},
+    proto_conversion::convert_node_proto,
+    protos::NodeProto,
+};
 use crate::onnx::ir::{ArgType, Data, TensorType};
 
 /// The function transforms the graph into a new one where the nodes are coalesced into a single node.
-pub fn coalesce(nodes: &mut Vec<Node>) {
-    let mut iter_mut = nodes.iter_mut().peekable();
-    let mut nodes_to_remove: Vec<String> = vec![];
-    while let Some(node) = iter_mut.next() {
-        match node.node_type {
-            NodeType::Gemm => convert_gemm_to_linear(node),
-            NodeType::MatMul => {
-                convert_matmul_to_linear(node, &mut iter_mut, &mut nodes_to_remove);
-            }
-            _ => {}
+pub fn coalesce(node: &mut Node, nodes_iter: &mut Peekable<Iter<NodeProto>>) {
+    match node.node_type {
+        NodeType::Gemm => convert_gemm_to_linear(node),
+        NodeType::MatMul => {
+            convert_matmul_to_linear(node, nodes_iter);
         }
-    }
-
-    // Remove nodes instructed by conversation functions
-    for node_to_remove in nodes_to_remove {
-        nodes.retain(|n| n.name != node_to_remove);
+        _ => {}
     }
 }
 
@@ -122,11 +117,7 @@ fn transpose_flattened<T: Copy>(matrix: Vec<T>, rows: usize, cols: usize) -> Vec<T> {
 ///
 /// This function also converts the following Add node into a Linear node if possible.
 /// Add node is used to represent bias in PyTorch.
-fn convert_matmul_to_linear( - node: &mut Node, - iter_mut: &mut Peekable>, - nodes_to_remove: &mut Vec, -) { +pub(crate) fn convert_matmul_to_linear(node: &mut Node, iter_mut: &mut Peekable>) { if node.inputs.len() != 2 { panic!("MatMul node must have 2 inputs"); } @@ -148,8 +139,13 @@ fn convert_matmul_to_linear( // Check the next node for potential conversion if let Some(peek_node) = iter_mut.peek() { + let peek_node = &convert_node_proto(peek_node); + println!("next node is {:?}", peek_node); if is_add_node_with_bias(peek_node, node) { - convert_and_remove_add_node(iter_mut, nodes_to_remove, node); + convert_and_remove_add_node(peek_node, node); + // You don't have to remove it if it's never stored in the first place + let _ = iter_mut.next(); + println!("\n\nskipping add node\n\n"); } } } @@ -197,22 +193,23 @@ pub(crate) fn convert_matmul_to_linear2( } /// Helper function to check if the peeked node is an Add node with bias fn is_add_node_with_bias(peek_node: &Node, current_node: &Node) -> bool { - peek_node.node_type == NodeType::Add - && peek_node.inputs.len() == 2 - && ((peek_node.inputs[0].name == current_node.outputs[0].name + if (peek_node.node_type == NodeType::Add && peek_node.inputs.len() == 2) { + println!("\n\ntwo matches"); + println!("peek_node.inputs[0].name: {:?}", peek_node.inputs[0].name); + println!( + "current_node.outputs[0].name: {:?}", + current_node.outputs[0].name + ); + return ((peek_node.inputs[0].name == current_node.outputs[0].name && peek_node.inputs[1].value.is_some()) || (peek_node.inputs[1].name == current_node.outputs[0].name - && peek_node.inputs[0].value.is_some())) + && peek_node.inputs[0].value.is_some())); + } + false } /// Helper function to convert and remove the Add node -fn convert_and_remove_add_node( - iter_mut: &mut Peekable>, - nodes_to_remove: &mut Vec, - current_node: &mut Node, -) { - let bias_node = iter_mut.next().unwrap(); - +fn convert_and_remove_add_node(bias_node: &Node, current_node: &mut Node) { let bias_input = if bias_node.inputs[0].value.is_some() { bias_node.inputs[0].clone() } else { @@ -222,9 +219,6 @@ fn convert_and_remove_add_node( // Push the bias input and update the output name current_node.inputs.push(bias_input); current_node.outputs[0].name = bias_node.outputs[0].name.clone(); - - // Remove the Add node - nodes_to_remove.push(bias_node.name.clone()); } /// Helper function to convert and remove the Add node diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs index 9715e28d92..7f6f888407 100644 --- a/burn-import/src/onnx/from_onnx.rs +++ b/burn-import/src/onnx/from_onnx.rs @@ -3,7 +3,9 @@ use std::{ cell::{RefCell, RefMut}, collections::{HashMap, HashSet}, fs::File, + iter::Peekable, path::Path, + slice::Iter, }; use crate::onnx::{ @@ -11,8 +13,8 @@ use crate::onnx::{ }; use super::{ - coalesce::convert_gemm_to_linear, - protos::{ModelProto, TensorProto, ValueInfoProto}, + coalesce::{coalesce, convert_gemm_to_linear, convert_matmul_to_linear}, + protos::{ModelProto, NodeProto, TensorProto, ValueInfoProto}, }; use super::dim_inference::dim_inference; @@ -85,15 +87,6 @@ impl OnnxGraphIO { } } - fn get_mut(&mut self, old_name: &str) -> Option<&mut Argument> { - match self.old_io_names.get(old_name) { - Some(IOEntry::In(i)) => self.inputs.get_mut(*i), - Some(IOEntry::Out(i)) => self.outputs.get_mut(*i), - Some(IOEntry::Node(i)) => panic!("This is a node output"), - None => None, - } - } - fn update(&mut self, old_name: &str, new_name: &str) { match self.old_io_names.get(old_name) { 
Some(IOEntry::In(i)) => { @@ -126,7 +119,7 @@ impl OnnxGraphIO { match self.old_io_names.get(old_name) { Some(IOEntry::In(i)) => self.inputs.get(*i), Some(IOEntry::Out(i)) => self.outputs.get(*i), - Some(IOEntry::Node(i)) => panic!("This is a node output"), + Some(IOEntry::Node(_)) => panic!("This is a node output"), None => None, } } @@ -192,8 +185,12 @@ impl ONNXGraphBuilder { ); let mut nodes = Vec::with_capacity(model_proto.graph.node.len()); - for (i, node_proto) in model_proto.graph.node.iter().enumerate() { + let mut nd_idx = 0; + let mut node_iter = model_proto.graph.node.iter().peekable(); + + while let Some(node_proto) = node_iter.next() { let mut node = convert_node_proto(node_proto); + println!("current_node {:?}", node); for node_input in node.inputs.iter_mut() { // self.input_of // .entry(node_input.name.clone()) @@ -204,20 +201,22 @@ impl ONNXGraphBuilder { } } remap_node_type(&mut node); - for node_output in node.outputs.iter() { - self.output_of.insert(node_output.name.clone(), i); - } + // for node_output in node.outputs.iter() { + // self.output_of.insert(node_output.name.clone(), nd_idx); + // } let node_type = node.node_type.clone(); - + //coalesce(&mut node, &mut node_iter); self.handle_node_renaming(&node_type, &mut node); - self.handle_coalesce(&node_type, &mut node, i); - self.handle_unsqueeze(&node_type, &node, i); - _ = self.handle_identity(&node_type, &node, i); + //coalesce(&mut node, &mut node_iter); + + self.handle_unsqueeze(&node_type, &node, nd_idx); - self.handle_rename_io(&mut node, i, &mut graph_io); - self.check_constants(&node, &node_type, i); + _ = self.handle_identity(&node_type, &node, nd_idx); + self.handle_coalesce(&mut node, &mut node_iter, nd_idx); + self.handle_rename_io(&mut node, nd_idx, &mut graph_io); + self.check_constants(&node, &node_type, nd_idx); //NOTE: still not done with this one // if !self.nodes_to_remove.contains(&i) && !self.constants_map.contains_key(&node.name) { @@ -226,6 +225,7 @@ impl ONNXGraphBuilder { // } nodes.push(RefCell::new(node)); + nd_idx += 1; } self.postprocess_unsqueeze(&nodes, &graph_io); self.postprocess_identity(&nodes, &graph_io); @@ -373,8 +373,13 @@ impl ONNXGraphBuilder { } /// The function transforms the graph into a new one where the nodes are coalesced into a single node. 
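A note for the coalescing hunks around this point: the shape the series keeps reaching for is "peek at the next node, fuse it into the current one, advance past it", which is only sound because ONNX serializes nodes in topological order. Below is a minimal sketch of that control flow; the types are simplified stand-ins, not the crate's Node/NodeType IR, and the predicate fields are assumptions standing in for the real constant-value checks.

use std::iter::Peekable;

// Simplified stand-ins for the IR (assumption: real code checks Argument values).
#[derive(Clone, PartialEq)]
enum Kind {
    MatMul,
    Add,
    Linear,
}

struct Node {
    kind: Kind,
    has_const_weight: bool,
    has_const_bias: bool,
}

// One forward pass; the bias Add is consumed and simply never pushed, so
// there is no nodes_to_remove bookkeeping afterwards.
fn fuse(nodes: Vec<Node>) -> Vec<Node> {
    let mut out = Vec::with_capacity(nodes.len());
    let mut iter: Peekable<_> = nodes.into_iter().peekable();
    while let Some(mut node) = iter.next() {
        if node.kind == Kind::MatMul && node.has_const_weight {
            node.kind = Kind::Linear;
            // Topological order guarantees the bias Add comes after its MatMul.
            if matches!(iter.peek(), Some(n) if n.kind == Kind::Add && n.has_const_bias) {
                let _ = iter.next(); // absorb the bias into the Linear
            }
        }
        out.push(node);
    }
    out
}

fn main() {
    let graph = vec![
        Node { kind: Kind::MatMul, has_const_weight: true, has_const_bias: false },
        Node { kind: Kind::Add, has_const_weight: false, has_const_bias: true },
    ];
    assert_eq!(fuse(graph).len(), 1); // MatMul + Add fused into one Linear
}

The real conversion additionally rewires argument names and transposes the 2-D weight; the sketch keeps only the iterator discipline, which is the part these patches keep reworking.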
- fn handle_coalesce(&mut self, node_type: &NodeType, node: &mut Node, i: usize) { - match node_type { + fn handle_coalesce( + &mut self, + node: &mut Node, + _nodes_iter: &mut Peekable>, + i: usize, + ) { + match node.node_type { NodeType::Gemm => { println!("Gemm before {:?}\n", node); convert_gemm_to_linear(node); @@ -451,9 +456,9 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph { println!("{} != {}", nodes[i].name, my_nodes[i].name); } } - println!("nodes: {:#?}", nodes); - println!("inner inputs: {:#?}", inner_inputs); - println!("inner outputs: {:#?}", inner_outputs); + // println!("nodes: {:#?}", nodes); + // println!("inner inputs: {:#?}", inner_inputs); + // println!("inner outputs: {:#?}", inner_outputs); // Infer shapes and update the inputs and outputs dim_inference(&mut nodes, &inner_inputs, &mut inner_outputs); From 6ce84229968e9fd7ce7b7dff497882e060ceac97 Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Fri, 16 Feb 2024 16:45:14 -0600 Subject: [PATCH 06/21] the passing of an initializer to coalesce is temporary --- burn-import/src/onnx/coalesce.rs | 39 ++++-- burn-import/src/onnx/from_onnx.rs | 216 +++++++++++++----------------- 2 files changed, 126 insertions(+), 129 deletions(-) diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs index 94f4df35b2..5f66bd196a 100644 --- a/burn-import/src/onnx/coalesce.rs +++ b/burn-import/src/onnx/coalesce.rs @@ -1,6 +1,6 @@ use std::{ cell::{RefCell, RefMut}, - collections::HashSet, + collections::{HashMap, HashSet}, iter::Peekable, slice::{Iter, IterMut}, }; @@ -8,16 +8,23 @@ use std::{ use super::{ ir::{AttributeValue, Node, NodeType}, proto_conversion::convert_node_proto, - protos::NodeProto, + protos::{NodeProto, TensorProto}, +}; +use crate::onnx::{ + from_onnx::move_initializer_data, + ir::{ArgType, Data, TensorType}, }; -use crate::onnx::ir::{ArgType, Data, TensorType}; /// The function transforms the graph into a new one where the nodes are coalesced into a single node. -pub fn coalesce(node: &mut Node, nodes_iter: &mut Peekable>) { +pub fn coalesce( + node: &mut Node, + nodes_iter: &mut Peekable>, + initializers: &HashMap, +) { match node.node_type { NodeType::Gemm => convert_gemm_to_linear(node), NodeType::MatMul => { - convert_matmul_to_linear(node, nodes_iter); + convert_matmul_to_linear(node, nodes_iter, initializers); } _ => {} } @@ -117,7 +124,11 @@ fn transpose_flattened(matrix: Vec, rows: usize, cols: usize) -> Vec /// /// This function also converts the following Add node into a Linear node if possible. /// Add node is used to represent bias in PyTorch. 
-pub(crate) fn convert_matmul_to_linear(node: &mut Node, iter_mut: &mut Peekable>) { +pub(crate) fn convert_matmul_to_linear( + node: &mut Node, + iter_mut: &mut Peekable>, + initializers: &HashMap, +) { if node.inputs.len() != 2 { panic!("MatMul node must have 2 inputs"); } @@ -139,10 +150,20 @@ pub(crate) fn convert_matmul_to_linear(node: &mut Node, iter_mut: &mut Peekable< // Check the next node for potential conversion if let Some(peek_node) = iter_mut.peek() { - let peek_node = &convert_node_proto(peek_node); + let mut peek_node = convert_node_proto(peek_node).clone(); + for node_input in peek_node.inputs.iter_mut() { + // self.input_of + // .entry(node_input.name.clone()) + // .and_modify(|f| f.push(i)) + // .or_insert(vec![i]); + if let Some(initializer) = initializers.get(&node_input.name) { + move_initializer_data(initializer, node_input); + } + } println!("next node is {:?}", peek_node); - if is_add_node_with_bias(peek_node, node) { - convert_and_remove_add_node(peek_node, node); + if is_add_node_with_bias(&peek_node, node) { + convert_and_remove_add_node(&peek_node, node); + // You don't have to remove it if it's never stored in the first place let _ = iter_mut.next(); println!("\n\nskipping add node\n\n"); diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs index 7f6f888407..60c5935f93 100644 --- a/burn-import/src/onnx/from_onnx.rs +++ b/burn-import/src/onnx/from_onnx.rs @@ -151,21 +151,16 @@ pub(crate) struct ONNXGraphBuilder { nodes: Vec, inputs: Vec, outputs: Vec, - // old_io_names: HashMap, + node_name_counter: HashMap, outputs_to_move: HashMap, - //map of output names to - output_of: HashMap, //nodes to remove nodes_to_remove: HashSet, constants_map: HashMap, - //constants to lift - postprocess_for_constants: Vec, + constants_types: HashSet, - //identity_nodes - identity_idx: Vec, - //matmul nodes - matmul_nodes: Vec, + ///map from old node name to indices of identity nodes + identity_idx: HashMap, } impl ONNXGraphBuilder { @@ -184,7 +179,7 @@ impl ONNXGraphBuilder { model_proto.graph.output.clone(), ); - let mut nodes = Vec::with_capacity(model_proto.graph.node.len()); + self.nodes = Vec::with_capacity(model_proto.graph.node.len()); let mut nd_idx = 0; let mut node_iter = model_proto.graph.node.iter().peekable(); @@ -192,57 +187,36 @@ impl ONNXGraphBuilder { let mut node = convert_node_proto(node_proto); println!("current_node {:?}", node); for node_input in node.inputs.iter_mut() { - // self.input_of - // .entry(node_input.name.clone()) - // .and_modify(|f| f.push(i)) - // .or_insert(vec![i]); if let Some(initializer) = initializers.get(&node_input.name) { move_initializer_data(initializer, node_input); } } remap_node_type(&mut node); - // for node_output in node.outputs.iter() { - // self.output_of.insert(node_output.name.clone(), nd_idx); - // } - let node_type = node.node_type.clone(); //coalesce(&mut node, &mut node_iter); - self.handle_node_renaming(&node_type, &mut node); + coalesce(&mut node, &mut node_iter, &initializers); + self.handle_node_renaming(&mut node); - //coalesce(&mut node, &mut node_iter); + //self.handle_unsqueeze(&node, nd_idx); - self.handle_unsqueeze(&node_type, &node, nd_idx); - - _ = self.handle_identity(&node_type, &node, nd_idx); - self.handle_coalesce(&mut node, &mut node_iter, nd_idx); + _ = self.handle_identity(&mut node, nd_idx); + self.check_constants(&mut node, nd_idx); + //self.handle_coalesce(&mut node, &mut node_iter, nd_idx); self.handle_rename_io(&mut node, nd_idx, &mut graph_io); - 
self.check_constants(&node, &node_type, nd_idx); - //NOTE: still not done with this one - - // if !self.nodes_to_remove.contains(&i) && !self.constants_map.contains_key(&node.name) { - // //name stuff - // self.handle_node_renaming(&node_type, &mut node); - // } - nodes.push(RefCell::new(node)); + self.nodes.push(node); nd_idx += 1; } - self.postprocess_unsqueeze(&nodes, &graph_io); - self.postprocess_identity(&nodes, &graph_io); - self.postprocess_constants(&nodes); - self.postprocess_coalesce(&mut nodes); - - self.nodes = nodes - .into_iter() - .enumerate() - .filter_map(|(i, x)| { - if !self.nodes_to_remove.contains(&i) { - Some(x.into_inner()) - } else { - None - } - }) - .collect(); + //self.postprocess_unsqueeze(&nodes, &graph_io); + //self.postprocess_identity(&nodes, &graph_io); + //self.postprocess_constants(&nodes); + //self.postprocess_coalesce(&mut nodes); + let mut i = 0; + self.nodes.retain(|x| { + let res = !self.nodes_to_remove.contains(&i); + i += 1; + res + }); let OnnxGraphIO { inputs, outputs, .. } = graph_io; @@ -250,13 +224,19 @@ impl ONNXGraphBuilder { self.outputs = outputs; } - fn handle_node_renaming(&mut self, node_type: &NodeType, node: &mut Node) { + fn handle_node_renaming(&mut self, node: &mut Node) { + if &node.node_type == &NodeType::Linear { + println!("rename linear node {:?}", node); + } self.node_name_counter - .entry(node_type.clone()) + .entry(node.node_type.clone()) .and_modify(|e| *e += 1) .or_insert(1); - let new_name = - format!("{}{}", node.node_type, self.node_name_counter[&node_type]).to_lowercase(); + let new_name = format!( + "{}{}", + node.node_type, self.node_name_counter[&node.node_type] + ) + .to_lowercase(); node.name = new_name.clone(); } @@ -291,25 +271,17 @@ impl ONNXGraphBuilder { } } - fn check_constants(&mut self, node: &Node, node_type: &NodeType, i: usize) { - if node_type == &NodeType::Constant - || (node_type == &NodeType::Identity && node.inputs[0].value.is_some()) + fn check_constants(&mut self, node: &mut Node, i: usize) { + if &node.node_type == &NodeType::Constant + || (&node.node_type == &NodeType::Identity && node.inputs[0].value.is_some()) { self.constants_map.insert(node.outputs[0].name.clone(), i); - } else if self.constants_types.contains(node_type) { - self.postprocess_for_constants.push(i); - } - } - - fn postprocess_constants(&mut self, nodes: &Vec>) { - for check_idx in self.postprocess_for_constants.iter() { - let mut node = nodes[*check_idx].borrow_mut(); - + } else if self.constants_types.contains(&node.node_type) { for input in node.inputs.iter_mut().skip(1) { println!("checking input {:?} for const", input); if let Some(const_idx) = self.constants_map.get(&input.name) { - let constant = nodes[*const_idx].borrow(); + let constant = &self.nodes[*const_idx]; if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() { // The value comes from Identity inputs input.value = constant.inputs[0].value.clone(); @@ -325,8 +297,6 @@ impl ONNXGraphBuilder { } } - //fn get_mult_ref(&self, node_name: String, node_index, ) - fn handle_unsqueeze(&mut self, node_type: &NodeType, node: &Node, i: usize) { if *node_type == NodeType::Unsqueeze { self.outputs_to_move.insert(node.outputs[0].name.clone(), i); @@ -342,65 +312,71 @@ impl ONNXGraphBuilder { } } - fn handle_identity(&mut self, node_type: &NodeType, node: &Node, i: usize) -> bool { - if node_type == &NodeType::Identity && node.inputs[0].value.is_none() { - self.identity_idx.push(i); + fn handle_identity(&mut self, node: &mut Node, i: usize) { + if 
&node.node_type == &NodeType::Identity && node.inputs[0].value.is_none() { + self.identity_idx.insert(node.outputs[0].name.clone(), i); self.nodes_to_remove.insert(i); - return true; - } - false - } + } else { + node.inputs.iter_mut().for_each(|x| { + if let Some(identity_idx) = self.identity_idx.get(&x.name) { + let input_name = &self.nodes[*identity_idx].inputs[0].name; - fn postprocess_identity(&mut self, nodes: &Vec>, graph_io: &OnnxGraphIO) { - for identity_idx in self.identity_idx.iter() { - let identity_node = nodes[*identity_idx].borrow(); - - let input_name = &identity_node.inputs[0].name; - let identity_output = &identity_node.outputs[0].name; - - // Replace the identity node's output with its input in the connected nodes. - if let Some(indices) = graph_io.get_node_indices(identity_output) { - for node_index in indices { - let mut node = nodes[*node_index].borrow_mut(); - if let Some(matched_input) = - node.inputs.iter_mut().find(|x| x.name == *identity_output) - { - matched_input.name = input_name.clone(); - } + x.name = input_name.clone(); } - } + }); } } - /// The function transforms the graph into a new one where the nodes are coalesced into a single node. - fn handle_coalesce( - &mut self, - node: &mut Node, - _nodes_iter: &mut Peekable>, - i: usize, - ) { - match node.node_type { - NodeType::Gemm => { - println!("Gemm before {:?}\n", node); - convert_gemm_to_linear(node); - self.handle_node_renaming(&node.node_type.clone(), node); - println!("Gemm after {:?}\n", node); - } - NodeType::MatMul => { - self.matmul_nodes.push(i); - } - _ => {} - } - } + // fn postprocess_identity(&mut self, nodes: &Vec>, graph_io: &OnnxGraphIO) { + // for identity_idx in self.identity_idx.iter() { + // let identity_node = nodes[*identity_idx].borrow(); + + // let input_name = &identity_node.inputs[0].name; + // let identity_output = &identity_node.outputs[0].name; + + // // Replace the identity node's output with its input in the connected nodes. + // if let Some(indices) = graph_io.get_node_indices(identity_output) { + // for node_index in indices { + // let mut node = nodes[*node_index].borrow_mut(); + // if let Some(matched_input) = + // node.inputs.iter_mut().find(|x| x.name == *identity_output) + // { + // matched_input.name = input_name.clone(); + // } + // } + // } + // } + // } - fn postprocess_coalesce(&mut self, nodes: &mut Vec>) { - println!("{:?}", self.node_name_counter); - for matmul_index in self.matmul_nodes.clone() { - convert_matmul_to_linear2(nodes, matmul_index, &mut self.nodes_to_remove); - let mut node = nodes[matmul_index].borrow_mut(); - self.handle_node_renaming(&node.node_type.clone(), &mut node) - } - } + //// The function transforms the graph into a new one where the nodes are coalesced into a single node. 
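The handle_identity rewrite above relies on the same ordering guarantee: every consumer of an Identity output is visited after the Identity itself, so a single forward map replaces the old postprocess walk. A stripped-down sketch of that forwarding follows; plain strings stand in for the IR's Argument type, and the chain resolution in the loop is an addition for illustration (the patch itself resolves through node indices into self.nodes).

use std::collections::HashMap;

// `forwards` maps an Identity's output name to the name it merely forwards.
fn elide_identity(
    forwards: &mut HashMap<String, String>,
    is_identity: bool,
    inputs: &mut [String],
    outputs: &[String],
) {
    if is_identity {
        // Resolve chains (Identity feeding Identity) before recording.
        let mut src = inputs[0].clone();
        while let Some(earlier) = forwards.get(&src) {
            src = earlier.clone();
        }
        forwards.insert(outputs[0].clone(), src);
    } else {
        for input in inputs.iter_mut() {
            if let Some(src) = forwards.get(input) {
                *input = src.clone(); // read straight from the real producer
            }
        }
    }
}

fn main() {
    let mut forwards = HashMap::new();
    // An Identity node copying "a" to "b" is recorded, not emitted.
    elide_identity(&mut forwards, true, &mut ["a".to_string()], &["b".to_string()]);
    // A later consumer of "b" is rewired to "a".
    let mut inputs = ["b".to_string()];
    elide_identity(&mut forwards, false, &mut inputs, &[]);
    assert_eq!(inputs[0], "a");
}

The patch keys its map by output name but stores a node index, so the Identity's original input can be looked up in self.nodes when a consumer is rewired.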
+ // fn handle_coalesce( + // &mut self, + // node: &mut Node, + // _nodes_iter: &mut Peekable>, + // i: usize, + // ) { + // match node.node_type { + // NodeType::Gemm => { + // println!("Gemm before {:?}\n", node); + // convert_gemm_to_linear(node); + // self.handle_node_renaming(&node.node_type.clone(), node); + // println!("Gemm after {:?}\n", node); + // } + // NodeType::MatMul => { + // self.matmul_nodes.push(i); + // } + // _ => {} + // } + // } + + // fn postprocess_coalesce(&mut self, nodes: &mut Vec>) { + // println!("{:?}", self.node_name_counter); + // for matmul_index in self.matmul_nodes.clone() { + // convert_matmul_to_linear2(nodes, matmul_index, &mut self.nodes_to_remove); + // let mut node = nodes[matmul_index].borrow_mut(); + // self.handle_node_renaming(&node.node_type.clone(), &mut node) + // } + // } } /// Open an onnx file and convert it to a Graph (intermediate representation) @@ -475,7 +451,7 @@ pub fn parse_onnx(onnx_path: &Path) -> ONNXGraph { } } -fn move_initializer_data(initializer: &TensorProto, input: &mut Argument) { +pub(crate) fn move_initializer_data(initializer: &TensorProto, input: &mut Argument) { // If the input name matches the tensor name in the initializer // Convert the initializer to a tensor let tensor = Tensor::try_from(initializer.clone()).expect("Invalid tensor"); From 7456ce79b66354f2bf2b49f0c12157829dd51987 Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Fri, 16 Feb 2024 17:30:27 -0600 Subject: [PATCH 07/21] cleaning up dead code --- burn-import/src/onnx/coalesce.rs | 14 +-- burn-import/src/onnx/from_onnx.rs | 189 ++++++------------------------ 2 files changed, 43 insertions(+), 160 deletions(-) diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs index 5f66bd196a..331867bd44 100644 --- a/burn-import/src/onnx/coalesce.rs +++ b/burn-import/src/onnx/coalesce.rs @@ -2,7 +2,7 @@ use std::{ cell::{RefCell, RefMut}, collections::{HashMap, HashSet}, iter::Peekable, - slice::{Iter, IterMut}, + slice::Iter, }; use super::{ @@ -152,10 +152,6 @@ pub(crate) fn convert_matmul_to_linear( if let Some(peek_node) = iter_mut.peek() { let mut peek_node = convert_node_proto(peek_node).clone(); for node_input in peek_node.inputs.iter_mut() { - // self.input_of - // .entry(node_input.name.clone()) - // .and_modify(|f| f.push(i)) - // .or_insert(vec![i]); if let Some(initializer) = initializers.get(&node_input.name) { move_initializer_data(initializer, node_input); } @@ -214,17 +210,17 @@ pub(crate) fn convert_matmul_to_linear2( } /// Helper function to check if the peeked node is an Add node with bias fn is_add_node_with_bias(peek_node: &Node, current_node: &Node) -> bool { - if (peek_node.node_type == NodeType::Add && peek_node.inputs.len() == 2) { + if peek_node.node_type == NodeType::Add && peek_node.inputs.len() == 2 { println!("\n\ntwo matches"); println!("peek_node.inputs[0].name: {:?}", peek_node.inputs[0].name); println!( "current_node.outputs[0].name: {:?}", current_node.outputs[0].name ); - return ((peek_node.inputs[0].name == current_node.outputs[0].name + return (peek_node.inputs[0].name == current_node.outputs[0].name && peek_node.inputs[1].value.is_some()) || (peek_node.inputs[1].name == current_node.outputs[0].name - && peek_node.inputs[0].value.is_some())); + && peek_node.inputs[0].value.is_some()); } false } @@ -243,7 +239,7 @@ fn convert_and_remove_add_node(bias_node: &Node, current_node: &mut Node) { } /// Helper function to convert and remove the Add node -pub(crate) fn convert_node2<'parser>(bias_node: 
&Node, mut current_node: RefMut<'parser, Node>) { +pub(crate) fn convert_node2(bias_node: &Node, mut current_node: RefMut<'_, Node>) { let bias_input = if bias_node.inputs[0].value.is_some() { bias_node.inputs[0].clone() } else { diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs index 60c5935f93..027901d9da 100644 --- a/burn-import/src/onnx/from_onnx.rs +++ b/burn-import/src/onnx/from_onnx.rs @@ -3,9 +3,7 @@ use std::{ cell::{RefCell, RefMut}, collections::{HashMap, HashSet}, fs::File, - iter::Peekable, path::Path, - slice::Iter, }; use crate::onnx::{ @@ -13,15 +11,12 @@ use crate::onnx::{ }; use super::{ - coalesce::{coalesce, convert_gemm_to_linear, convert_matmul_to_linear}, - protos::{ModelProto, NodeProto, TensorProto, ValueInfoProto}, + coalesce::coalesce, + protos::{ModelProto, TensorProto, ValueInfoProto}, }; use super::dim_inference::dim_inference; -use super::{ - coalesce::convert_matmul_to_linear2, - ir::{ArgType, Argument, Node, NodeType, ONNXGraph, Tensor}, -}; +use super::ir::{ArgType, Argument, Node, NodeType, ONNXGraph, Tensor}; use protobuf::Message; @@ -97,7 +92,7 @@ impl OnnxGraphIO { let arg = self.outputs.get_mut(*i).unwrap(); arg.name = new_name.to_string(); } - Some(IOEntry::Node(i)) => { + Some(IOEntry::Node(_i)) => { panic!("This output is from another node"); } None => { @@ -199,10 +194,10 @@ impl ONNXGraphBuilder { //self.handle_unsqueeze(&node, nd_idx); - _ = self.handle_identity(&mut node, nd_idx); + self.handle_identity(&mut node, nd_idx); self.check_constants(&mut node, nd_idx); //self.handle_coalesce(&mut node, &mut node_iter, nd_idx); - self.handle_rename_io(&mut node, nd_idx, &mut graph_io); + rename_io(&mut node, nd_idx, &mut graph_io); self.nodes.push(node); nd_idx += 1; @@ -212,7 +207,7 @@ impl ONNXGraphBuilder { //self.postprocess_constants(&nodes); //self.postprocess_coalesce(&mut nodes); let mut i = 0; - self.nodes.retain(|x| { + self.nodes.retain(|_x| { let res = !self.nodes_to_remove.contains(&i); i += 1; res @@ -240,37 +235,6 @@ impl ONNXGraphBuilder { node.name = new_name.clone(); } - fn handle_rename_io(&mut self, node: &mut Node, i: usize, graph_io: &mut OnnxGraphIO) { - for node_input in node.inputs.iter_mut() { - println!("old output names {:?}", &graph_io.old_io_names); - //println!("out_args{:?}", outputs); - graph_io.add_input(&node_input.name, i); - if let Some(input_name) = graph_io.get_new_name(&node_input.name) { - println!("yeet"); - node_input.passed = true; - node_input.name = input_name.clone(); - } else { - node_input.name = "".to_string(); - node_input.passed = false; - } - } - println!("\n\nchecking outputs"); - let mut out_count = 1; - for output in node.outputs.iter_mut() { - println!("output name: {}", &output.name); - - let new_name = format!("{}_out{}", node.name, out_count); - - graph_io.update(&output.name, &new_name); - - // self.node_output_names - // .insert(output.name.clone(), new_name.clone()); - - output.name = new_name.clone(); - out_count += 1; - } - } - fn check_constants(&mut self, node: &mut Node, i: usize) { if &node.node_type == &NodeType::Constant || (&node.node_type == &NodeType::Identity && node.inputs[0].value.is_some()) @@ -287,7 +251,7 @@ impl ONNXGraphBuilder { input.value = constant.inputs[0].value.clone(); input.ty = constant.inputs[0].ty.clone(); } else { - let arg = convert_constant_value(&constant); + let arg = convert_constant_value(constant); input.value = arg.value; input.ty = arg.ty; } @@ -326,57 +290,6 @@ impl ONNXGraphBuilder { }); } } - - // fn 
postprocess_identity(&mut self, nodes: &Vec>, graph_io: &OnnxGraphIO) { - // for identity_idx in self.identity_idx.iter() { - // let identity_node = nodes[*identity_idx].borrow(); - - // let input_name = &identity_node.inputs[0].name; - // let identity_output = &identity_node.outputs[0].name; - - // // Replace the identity node's output with its input in the connected nodes. - // if let Some(indices) = graph_io.get_node_indices(identity_output) { - // for node_index in indices { - // let mut node = nodes[*node_index].borrow_mut(); - // if let Some(matched_input) = - // node.inputs.iter_mut().find(|x| x.name == *identity_output) - // { - // matched_input.name = input_name.clone(); - // } - // } - // } - // } - // } - - //// The function transforms the graph into a new one where the nodes are coalesced into a single node. - // fn handle_coalesce( - // &mut self, - // node: &mut Node, - // _nodes_iter: &mut Peekable>, - // i: usize, - // ) { - // match node.node_type { - // NodeType::Gemm => { - // println!("Gemm before {:?}\n", node); - // convert_gemm_to_linear(node); - // self.handle_node_renaming(&node.node_type.clone(), node); - // println!("Gemm after {:?}\n", node); - // } - // NodeType::MatMul => { - // self.matmul_nodes.push(i); - // } - // _ => {} - // } - // } - - // fn postprocess_coalesce(&mut self, nodes: &mut Vec>) { - // println!("{:?}", self.node_name_counter); - // for matmul_index in self.matmul_nodes.clone() { - // convert_matmul_to_linear2(nodes, matmul_index, &mut self.nodes_to_remove); - // let mut node = nodes[matmul_index].borrow_mut(); - // self.handle_node_renaming(&node.node_type.clone(), &mut node) - // } - // } } /// Open an onnx file and convert it to a Graph (intermediate representation) @@ -479,7 +392,7 @@ pub(crate) fn move_initializer_data(initializer: &TensorProto, input: &mut Argum } } -fn move_output_shape<'parser>(mut node: RefMut<'parser, Node>, out_arg: &Argument) { +fn move_output_shape(mut node: RefMut<'_, Node>, out_arg: &Argument) { match node.outputs[0].ty { ArgType::Tensor(ref mut tensor_type) => { if let ArgType::Tensor(arg_tensor) = &out_arg.ty { @@ -496,65 +409,39 @@ fn move_output_shape<'parser>(mut node: RefMut<'parser, Node>, out_arg: &Argumen /// The inputs are renamed to be unique and to be in the format of /// conv2_in1, conv2_in2, etc. This is done to be consistent with /// the naming convention of the nodes and allow to be used as rust identifiers. -fn rename_inputs( - nodes: &mut Vec, - inputs: &mut Vec, - outputs: &mut Vec, -) -> HashMap { - let mut old_names = HashMap::new(); - //println!("inputs: {:#?}", inputs); - //println!("outputs: {:#?}", outputs); - //println!("nodes: {:#?}", nodes); - // rename all graph input names to follow input1, input2, input3, etc. 
- // (assumes the input names are already unique) - let mut counter = 1; - for input in inputs.iter_mut() { - let old_name = input.name.clone(); - let new_name = format!("input{}", counter); - input.name = new_name.clone(); - old_names.insert(old_name, new_name); - counter += 1; - } - - for node in nodes.iter_mut() { - let mut counter = 1; - //println!("node: {:#?}", node); - - // loop through node outputs and rename them and store the new name <-> old name mapping - for output in node.outputs.iter_mut() { - let old_name = output.name.clone(); - let new_name = format!("{}_out{}", node.name, counter); - output.name = new_name.clone(); - old_names.insert(old_name, new_name); - //old_names.insert(old_name, new_name); - counter += 1; - } - } - - // for node in nodes.iter_mut() { - // // loop through node inputs and rename them with previously replaced names - // // and mark them as passed if they are in the old_names map (i.e. they are node outputs) - // for input in node.inputs.iter_mut() { - // if let Some(new_name) = old_names.get(&input.name) { - // input.name = new_name.clone(); - // input.passed = true; - // } else { - // input.name = "".to_string(); // Rename to a placeholder - // input.passed = false; - // } - // } - // } - - // Rename the graph outputs - for output in outputs.iter_mut() { - if let Some(new_name) = old_names.get(&output.name) { - output.name = new_name.clone(); +/// Rename the inputs and output in the graph and return a map of +/// the old names to the new names. +/// +/// The inputs are renamed to be unique and to be in the format of +/// conv2_in1, conv2_in2, etc. This is done to be consistent with +/// the naming convention of the nodes and allow to be used as rust identifiers. +fn rename_io(node: &mut Node, i: usize, graph_io: &mut OnnxGraphIO) { + for node_input in node.inputs.iter_mut() { + println!("old output names {:?}", &graph_io.old_io_names); + graph_io.add_input(&node_input.name, i); + if let Some(input_name) = graph_io.get_new_name(&node_input.name) { + node_input.passed = true; + node_input.name = input_name.clone(); } else { - log::warn!("Output {:?} not found in old_names", output.name); + node_input.name = "".to_string(); + node_input.passed = false; } } + println!("\n\nchecking outputs"); + let mut out_count = 1; + for output in node.outputs.iter_mut() { + println!("output name: {}", &output.name); + + let new_name = format!("{}_out{}", node.name, out_count); - old_names + graph_io.update(&output.name, &new_name); + + // self.node_output_names + // .insert(output.name.clone(), new_name.clone()); + + output.name = new_name.clone(); + out_count += 1; + } } /// Removes the graph inputs/output that are not used by any node. 
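Between these two patches it is worth pausing on the structure from_onnx.rs has settled into: one forward pass in which a single registry owns every name seen so far, and each node first resolves its inputs out of the registry, then renames its outputs into it. Below is a minimal sketch of that registry under simplified types; the names are stand-ins, and the real OnnxGraphIO additionally tracks graph inputs/outputs, constants, and the IOEntry indirection.

use std::collections::HashMap;

#[derive(Default)]
struct NameRegistry {
    renamed: HashMap<String, String>, // old tensor name -> new tensor name
}

impl NameRegistry {
    // Inputs first: they can only refer to outputs registered by earlier
    // nodes, because the node list is topologically sorted.
    fn resolve_inputs(&self, inputs: &mut [String]) {
        for input in inputs.iter_mut() {
            match self.renamed.get(input) {
                Some(new_name) => *input = new_name.clone(),
                None => input.clear(), // same placeholder the old pass used
            }
        }
    }

    fn rename_outputs(&mut self, node_name: &str, outputs: &mut [String]) {
        for (i, output) in outputs.iter_mut().enumerate() {
            let new_name = format!("{}_out{}", node_name, i + 1);
            self.renamed.insert(output.clone(), new_name.clone());
            *output = new_name;
        }
    }
}

fn main() {
    let mut reg = NameRegistry::default();
    let mut out = ["raw_out".to_string()];
    reg.rename_outputs("conv2d1", &mut out);

    let mut next_in = ["raw_out".to_string()];
    reg.resolve_inputs(&mut next_in);
    assert_eq!(next_in[0], "conv2d1_out1");
}

The empty-name placeholder mirrors the "" rename in the old pass, and it is the sort of condition the new panics in burn-import/src/burn/ty.rs (patch 05) were added to surface early.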
From 77827bb7fb741192eef66bcc6b41a0bf5d878e7c Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Sat, 17 Feb 2024 11:29:02 -0600 Subject: [PATCH 08/21] handled unsqueeze --- burn-import/src/onnx/from_onnx.rs | 58 +++++++++---------------------- 1 file changed, 17 insertions(+), 41 deletions(-) diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs index 6c1c1cc37c..12b2a05d37 100644 --- a/burn-import/src/onnx/from_onnx.rs +++ b/burn-import/src/onnx/from_onnx.rs @@ -12,15 +12,15 @@ use crate::onnx::{ use super::{ coalesce::coalesce, + ir::OnnxGraph, protos::{ModelProto, TensorProto, ValueInfoProto}, }; use super::dim_inference::dim_inference; -use super::ir::{ArgType, Argument, Node, NodeType, ONNXGraph, Tensor}; +use super::ir::{ArgType, Argument, Node, NodeType, Tensor}; use protobuf::Message; -const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 7] = [ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 7] = [ NodeType::BatchNormalization, NodeType::Clip, @@ -29,7 +29,6 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 7] = [ NodeType::Dropout, NodeType::Reshape, NodeType::Unsqueeze, - NodeType::Unsqueeze, ]; #[derive(Debug)] @@ -40,6 +39,8 @@ pub(crate) enum IOEntry { } pub(crate) struct OnnxGraphIO { + ///Per Onnx spec "Inputs represent graph inputs or values computed elsewhere in the graph..." + /// Thus all computed inputs are in the list of inputs in a valid Onnx file pub(crate) inputs: Vec, pub(crate) outputs: Vec, ///updated names of outputs of node not stored in the graph @@ -150,7 +151,6 @@ pub(crate) struct ONNXGraphBuilder { outputs: Vec, node_name_counter: HashMap, - outputs_to_move: HashMap, //nodes to remove nodes_to_remove: HashSet, constants_map: HashMap, @@ -177,12 +177,12 @@ impl ONNXGraphBuilder { ); self.nodes = Vec::with_capacity(model_proto.graph.node.len()); - let mut nd_idx = 0; + let mut and_idx = 0; let mut node_iter = model_proto.graph.node.iter().peekable(); while let Some(node_proto) = node_iter.next() { let mut node = convert_node_proto(node_proto); - println!("current_node {:?}", node); + for node_input in node.inputs.iter_mut() { if let Some(initializer) = initializers.get(&node_input.name) { move_initializer_data(initializer, node_input); @@ -194,20 +194,17 @@ impl ONNXGraphBuilder { coalesce(&mut node, &mut node_iter, &initializers); self.handle_node_renaming(&mut node); - //self.handle_unsqueeze(&node, nd_idx); + self.handle_unsqueeze(&mut node, &graph_io); - self.handle_identity(&mut node, nd_idx); - self.check_constants(&mut node, nd_idx); - //self.handle_coalesce(&mut node, &mut node_iter, nd_idx); - rename_io(&mut node, nd_idx, &mut graph_io); + self.handle_identity(&mut node, and_idx); + self.check_constants(&mut node, and_idx); + //self.handle_coalesce(&mut node, &mut node_iter, and_idx); + rename_io(&mut node, and_idx, &mut graph_io); self.nodes.push(node); - nd_idx += 1; + and_idx += 1; } - //self.postprocess_unsqueeze(&nodes, &graph_io); - //self.postprocess_identity(&nodes, &graph_io); - //self.postprocess_constants(&nodes); - //self.postprocess_coalesce(&mut nodes); + let mut i = 0; self.nodes.retain(|_x| { let res = !self.nodes_to_remove.contains(&i); @@ -263,16 +260,9 @@ impl ONNXGraphBuilder { } } - fn handle_unsqueeze(&mut self, node_type: &NodeType, node: &Node, i: usize) { - if *node_type == NodeType::Unsqueeze { - self.outputs_to_move.insert(node.outputs[0].name.clone(), i); - } - } - - fn postprocess_unsqueeze(&mut self, nodes: &Vec>, graph_io: &OnnxGraphIO) { - for (old_output_name, i) in 
self.outputs_to_move.iter() { - if let Some(in_arg) = graph_io.get(old_output_name) { - let node = nodes[*i].borrow_mut(); + fn handle_unsqueeze(&mut self, node: &mut Node, graph_io: &OnnxGraphIO) { + if node.node_type == NodeType::Unsqueeze { + if let Some(in_arg) = graph_io.get(&node.outputs[0].name) { move_output_shape(node, in_arg); } } @@ -340,20 +330,8 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph { // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted"); - let my_nodes = nodes.clone(); - - for i in 0..nodes.len() { - if nodes[i] != my_nodes[i] { - println!("{} != {}", nodes[i].name, my_nodes[i].name); - } - } - // println!("nodes: {:#?}", nodes); - // println!("inner inputs: {:#?}", inner_inputs); - // println!("inner outputs: {:#?}", inner_outputs); - // Infer shapes and update the inputs and outputs dim_inference(&mut nodes, &inner_inputs, &mut inner_outputs); - println!("inner outputs after dim inference: {:?}", inner_outputs); // Remove the graph inputs/output that are not used by any node remove_unused_graph_inputs(&mut inner_inputs, &mut inner_outputs, &nodes); @@ -394,9 +372,7 @@ pub(crate) fn move_initializer_data(initializer: &TensorProto, input: &mut Argum } } - - -fn move_output_shape(mut node: RefMut<'_, Node>, out_arg: &Argument) { +fn move_output_shape(mut node: &mut Node, out_arg: &Argument) { match node.outputs[0].ty { ArgType::Tensor(ref mut tensor_type) => { if let ArgType::Tensor(arg_tensor) = &out_arg.ty { From b3a6ebc615aefe8a84e58ace06fcbe66f44a14b9 Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Sun, 18 Feb 2024 13:35:51 -0600 Subject: [PATCH 09/21] reworked node initialization and dim inference --- burn-import/src/onnx/coalesce.rs | 61 +----- burn-import/src/onnx/dim_inference.rs | 146 ++++--------- burn-import/src/onnx/from_onnx.rs | 265 +++++++++++++++++------ burn-import/src/onnx/ir.rs | 45 ++++ burn-import/src/onnx/proto_conversion.rs | 10 +- 5 files changed, 306 insertions(+), 221 deletions(-) diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs index 331867bd44..a2acd87d87 100644 --- a/burn-import/src/onnx/coalesce.rs +++ b/burn-import/src/onnx/coalesce.rs @@ -1,17 +1,16 @@ use std::{ - cell::{RefCell, RefMut}, - collections::{HashMap, HashSet}, + cell::{RefMut}, iter::Peekable, slice::Iter, }; use super::{ + from_onnx::OnnxGraphIO, ir::{AttributeValue, Node, NodeType}, proto_conversion::convert_node_proto, - protos::{NodeProto, TensorProto}, + protos::{NodeProto}, }; use crate::onnx::{ - from_onnx::move_initializer_data, ir::{ArgType, Data, TensorType}, }; @@ -19,12 +18,12 @@ use crate::onnx::{ pub fn coalesce( node: &mut Node, nodes_iter: &mut Peekable>, - initializers: &HashMap, + graph_io: &OnnxGraphIO, ) { match node.node_type { NodeType::Gemm => convert_gemm_to_linear(node), NodeType::MatMul => { - convert_matmul_to_linear(node, nodes_iter, initializers); + convert_matmul_to_linear(node, nodes_iter, graph_io); } _ => {} } @@ -127,7 +126,7 @@ fn transpose_flattened(matrix: Vec, rows: usize, cols: usize) -> Vec pub(crate) fn convert_matmul_to_linear( node: &mut Node, iter_mut: &mut Peekable>, - initializers: &HashMap, + graph_io: &OnnxGraphIO, ) { if node.inputs.len() != 2 { panic!("MatMul node must have 2 inputs"); @@ -150,12 +149,7 @@ pub(crate) fn convert_matmul_to_linear( // Check the next node for potential conversion if let Some(peek_node) = iter_mut.peek() { - let mut peek_node = convert_node_proto(peek_node).clone(); - for 
node_input in peek_node.inputs.iter_mut() { - if let Some(initializer) = initializers.get(&node_input.name) { - move_initializer_data(initializer, node_input); - } - } + let peek_node = convert_node_proto(peek_node, graph_io).clone(); println!("next node is {:?}", peek_node); if is_add_node_with_bias(&peek_node, node) { convert_and_remove_add_node(&peek_node, node); @@ -167,47 +161,6 @@ pub(crate) fn convert_matmul_to_linear( } } -/// This function converts a MatMul node into a Linear node if possible. -/// -/// PyTorch and other frameworks use MatMul node to represent Linear layer. -/// -/// This function also converts the following Add node into a Linear node if possible. -/// Add node is used to represent bias in PyTorch. -pub(crate) fn convert_matmul_to_linear2( - node_vec: &Vec>, - node_index: usize, - nodes_to_remove: &mut HashSet, -) { - let mut node = node_vec[node_index].borrow_mut(); - if node.inputs.len() != 2 { - panic!("MatMul node must have 2 inputs"); - } - - // if the second input does not have a value, it is not a weight, then proceed to the next node - if node.inputs[1].value.is_none() { - return; - } - - // Check if the second input is a 2D tensor - if let ArgType::Tensor(ref tensor_type) = node.inputs[1].ty { - assert_eq!(tensor_type.dim, 2, "Weight must be a 2D tensor"); - } else { - panic!("Tensor input is expected"); - } - - // Convert the node to Linear - node.node_type = NodeType::Linear; - - // Check the next node for potential conversion - - if node_index + 1 < node_vec.len() { - let next_node = node_vec[node_index + 1].borrow(); - if is_add_node_with_bias(&next_node, &node) { - convert_node2(&next_node, node); - nodes_to_remove.insert(node_index + 1); - } - } -} /// Helper function to check if the peeked node is an Add node with bias fn is_add_node_with_bias(peek_node: &Node, current_node: &Node) -> bool { if peek_node.node_type == NodeType::Add && peek_node.inputs.len() == 2 { diff --git a/burn-import/src/onnx/dim_inference.rs b/burn-import/src/onnx/dim_inference.rs index f7d7b72663..4893ad5112 100644 --- a/burn-import/src/onnx/dim_inference.rs +++ b/burn-import/src/onnx/dim_inference.rs @@ -1,115 +1,63 @@ use core::panic; -use std::collections::HashMap; use protobuf::Enum; use super::{ + from_onnx::OnnxGraphIO, ir::{ArgType, Argument, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, op_configuration::flatten_config, protos::tensor_proto::DataType, }; -struct TensorDimUpdater { - arguments: HashMap, -} - -impl TensorDimUpdater { - fn new(inputs: &[Argument]) -> Self { - let mut arguments: HashMap = HashMap::with_capacity(inputs.len()); - - inputs.iter().for_each(|input| { - arguments.insert(input.name.clone(), input.clone()); - }); - - Self { arguments } - } - /// Update tensor inputs from the registered arguments and returns the number of input - /// updated. - fn update_tensor_inputs(&self, node: &mut Node) -> usize { - self.update_arguments(&mut node.inputs) - } - - /// Update the arguments struct from the node output tensors and return the number of output - /// updated. 
- fn update_tensor_outputs(&mut self, node: &Node) -> usize { - node.outputs - .iter() - .map(|arg| { - self.arguments.insert(arg.name.clone(), arg.clone()); - }) - .count() - } - - fn update_arguments(&self, arguments: &mut [Argument]) -> usize { - arguments - .iter_mut() - .filter_map(|input| self.arguments.get(&input.name).map(|arg| (arg, input))) - .map(|(arg, input)| { - input.ty = arg.ty.clone(); - }) - .count() - } -} - /// Infer the dimension of each output tensor and update them. -pub fn dim_inference( - nodes: &mut Vec, - graph_inputs: &Vec, - graph_outputs: &mut Vec, -) { - let mut updater = TensorDimUpdater::new(graph_inputs); - - for node in nodes.iter_mut() { - updater.update_tensor_inputs(node); - - match node.node_type { - NodeType::Add => same_as_input(node), - NodeType::AveragePool2d => same_as_input(node), - NodeType::BatchNormalization => same_as_input(node), - NodeType::Cast => cast_update_outputs(node), - NodeType::Clip => same_as_input(node), - NodeType::Concat => concat_update_outputs(node), - NodeType::Constant => constant_update_outputs(node), - NodeType::Conv1d => conv1d_update_outputs(node), - NodeType::Conv2d => conv2d_update_outputs(node), - NodeType::Cos => same_as_input(node), - NodeType::Div => same_as_input(node), - NodeType::Dropout => same_as_input(node), - NodeType::Equal => equal_update_outputs(node), - NodeType::Erf => same_as_input(node), - NodeType::Exp => same_as_input(node), - NodeType::Flatten => flatten_update_outputs(node), - NodeType::Gelu => same_as_input(node), - NodeType::GatherElements => same_as_input(node), - NodeType::GlobalAveragePool => same_as_input(node), - NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node), - NodeType::Linear => linear_update_outputs(node), - NodeType::Log => same_as_input(node), - NodeType::LogSoftmax => same_as_input(node), - NodeType::MaxPool2d => same_as_input(node), - NodeType::Mul => same_as_input(node), - NodeType::Neg => same_as_input(node), - NodeType::Reciprocal => same_as_input(node), - NodeType::ReduceMean => mean_update_outputs(node), - NodeType::Relu => same_as_input(node), - NodeType::Reshape => reshape_update_outputs(node), - NodeType::Shape => shape_update_outputs(node), - NodeType::Sigmoid => same_as_input(node), - NodeType::Softmax => same_as_input(node), - NodeType::Sqrt => same_as_input(node), - NodeType::Sub => same_as_input(node), - NodeType::Tanh => same_as_input(node), - NodeType::Transpose => same_as_input(node), - NodeType::Unsqueeze => unsqueeze_update_output_or_node(node), - NodeType::Pow => same_as_input(node), - // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated. 
- _ => temporary_pass_through_stub(node), - } - - updater.update_tensor_outputs(node); +pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) { + //graph_io.copy_to_node_inputs(node); + + match node.node_type { + NodeType::Add => same_as_input(node), + NodeType::AveragePool2d => same_as_input(node), + NodeType::BatchNormalization => same_as_input(node), + NodeType::Cast => cast_update_outputs(node), + NodeType::Clip => same_as_input(node), + NodeType::Concat => concat_update_outputs(node), + NodeType::Constant => constant_update_outputs(node), + NodeType::Conv1d => conv1d_update_outputs(node), + NodeType::Conv2d => conv2d_update_outputs(node), + NodeType::Cos => same_as_input(node), + NodeType::Div => same_as_input(node), + NodeType::Dropout => same_as_input(node), + NodeType::Equal => equal_update_outputs(node), + NodeType::Erf => same_as_input(node), + NodeType::Exp => same_as_input(node), + NodeType::Flatten => flatten_update_outputs(node), + NodeType::Gelu => same_as_input(node), + NodeType::GatherElements => same_as_input(node), + NodeType::GlobalAveragePool => same_as_input(node), + NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node), + NodeType::Linear => linear_update_outputs(node), + NodeType::Log => same_as_input(node), + NodeType::LogSoftmax => same_as_input(node), + NodeType::MaxPool2d => same_as_input(node), + NodeType::Mul => same_as_input(node), + NodeType::Neg => same_as_input(node), + NodeType::Reciprocal => same_as_input(node), + NodeType::ReduceMean => mean_update_outputs(node), + NodeType::Relu => same_as_input(node), + NodeType::Reshape => reshape_update_outputs(node), + NodeType::Shape => shape_update_outputs(node), + NodeType::Sigmoid => same_as_input(node), + NodeType::Softmax => same_as_input(node), + NodeType::Sqrt => same_as_input(node), + NodeType::Sub => same_as_input(node), + NodeType::Tanh => same_as_input(node), + NodeType::Transpose => same_as_input(node), + NodeType::Unsqueeze => unsqueeze_update_output_or_node(node), + NodeType::Pow => same_as_input(node), + // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated. 
+ _ => temporary_pass_through_stub(node), } - updater.update_arguments(graph_outputs); + graph_io.update_tensor_output(node); } fn constant_update_outputs(node: &mut Node) { diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs index 12b2a05d37..173e496018 100644 --- a/burn-import/src/onnx/from_onnx.rs +++ b/burn-import/src/onnx/from_onnx.rs @@ -1,6 +1,4 @@ use std::{ - borrow::BorrowMut, - cell::{RefCell, RefMut}, collections::{HashMap, HashSet}, fs::File, path::Path, @@ -43,15 +41,21 @@ pub(crate) struct OnnxGraphIO { /// Thus all computed inputs are in the list of inputs in a valid Onnx file pub(crate) inputs: Vec, pub(crate) outputs: Vec, + /// Initializers or constants, for more information on why these are treated the + /// same, please see: https://github.com/onnx/onnx/issues/4677 + pub(crate) constants: HashMap, + //pub(crate) initializers: Vec, ///updated names of outputs of node not stored in the graph - node_out: Vec>, - ///map of old input names to a vec of indices of nodes that use it - input_of: HashMap>, + node_out: Vec, pub(crate) old_io_names: HashMap, } impl OnnxGraphIO { - pub(crate) fn new(inputs: Vec, outputs: Vec) -> Self { + pub(crate) fn new( + inputs: &Vec, + outputs: &Vec, + initializers: &Vec, + ) -> Self { let mut old_io_names = HashMap::new(); let mut in_count = 1; let inputs = inputs @@ -75,45 +79,164 @@ impl OnnxGraphIO { Argument::try_from(x.clone()).unwrap() }) .collect::>(); - let in_len = inputs.len(); + + let constants = initializers + .iter() + .map(|x| (x.name.clone(), Argument::from_initializer(x))) + .collect::>(); + Self { inputs, outputs, + constants, node_out: Vec::new(), old_io_names, - input_of: HashMap::with_capacity(in_len), } } - fn update(&mut self, old_name: &str, new_name: &str) { - match self.old_io_names.get(old_name) { - Some(IOEntry::In(i)) => { - let arg = self.inputs.get_mut(*i).unwrap(); - arg.name = new_name.to_string(); + fn update_name(&mut self, arg: &Argument, new_name: &str) { + match self.old_io_names.get(&arg.name) { + Some(IOEntry::In(_)) => { + // let arg = self.inputs.get_mut(*i).unwrap(); + // arg.name = new_name.to_string(); + panic!("input names are set from the beginning"); } Some(IOEntry::Out(i)) => { let arg = self.outputs.get_mut(*i).unwrap(); arg.name = new_name.to_string(); } - Some(IOEntry::Node(_i)) => { - panic!("This output is from another node"); + Some(IOEntry::Node(i)) => { + let arg = self.node_out.get_mut(*i).unwrap(); + arg.name = new_name.to_string(); } + None => { + //Constants, Casts let idx = self.node_out.len(); - self.node_out.push(Box::new(new_name.to_string())); self.old_io_names - .insert(old_name.to_string(), IOEntry::Node(idx)); + .insert(arg.name.clone(), IOEntry::Node(idx)); + self.node_out.push(arg.clone()); + self.node_out[idx].name = new_name.to_string(); } } } - fn add_input(&mut self, old_name: &str, node_idx: usize) { - self.input_of - .entry(old_name.to_string()) - .and_modify(|f| f.push(node_idx)) - .or_insert(vec![node_idx]); + fn update_value(&mut self, updated_arg: &Argument) { + match self.old_io_names.get(&updated_arg.name) { + Some(IOEntry::Node(i)) => { + let arg = self.node_out.get_mut(*i).unwrap(); + arg.copy_all_but_name(updated_arg); + } + _ => panic!( + "Tried to update the value of {:?} which was the output from another node", + &updated_arg.name + ), + } } - fn get(&self, old_name: &str) -> Option<&Argument> { + pub fn init_in(&self, proto_str: &str) -> Argument { + match self.old_io_names.get(proto_str) { + None => { + if let 
Some(init_arg) = self.constants.get(proto_str) { + init_arg.clone() + } else { + Argument::new(proto_str.to_string()) + } + } + + Some(IOEntry::In(i)) => { + let mut arg = self.inputs[*i].clone(); + arg.name = proto_str.to_string(); + arg + } + Some(IOEntry::Node(i)) => { + let mut arg = self.node_out[*i].clone(); + arg.name = proto_str.to_string(); + arg + } + Some(IOEntry::Out(_)) => { + panic!("graph out {} can't be an input", &proto_str) + } + } + } + + fn insert(&mut self, arg: &Argument, new_name: &str) { + if let Some(idx) = self.old_io_names.get(&arg.name) { + if let IOEntry::Node(idx) = idx { + if self.node_out[*idx].name == arg.name { + self.node_out[*idx].name = new_name.to_string(); + return; + } + } else { + panic!("arg entry with old name {} is a graph IO", &arg.name); + } + } + + let idx = self.node_out.len(); + self.old_io_names + .insert(arg.name.clone(), IOEntry::Node(idx)); + self.node_out.push(arg.clone()); + self.node_out[idx].name = new_name.to_string(); + } + ///Copy data from the graph inputs to the nodes inputs + pub(crate) fn copy_to_node_inputs(&self, node: &mut Node) { + for input in node.inputs.iter_mut() { + if input.name.is_empty() { + continue; + } + match self.old_io_names.get(&input.name) { + Some(IOEntry::In(i)) => { + let arg = self.inputs.get(*i).unwrap(); + input.copy_all_but_name(arg); + } + Some(IOEntry::Out(_i)) => { + panic!("Output should only contain final outputs"); + } + Some(IOEntry::Node(i)) => { + let arg = self.node_out.get(*i).unwrap(); + input.copy_all_but_name(arg); + } + None => { + //happens with initializers + // println!("io names: {:?}", &self.old_io_names); + + // panic!("Failure when copying nonexistent io to input {} for node {}\nShouldn't happen", &input.name, &node.name); + } + } + } + } + ///iterate over the nodes output and copy them to the graph IO + pub(crate) fn update_tensor_output(&mut self, node: &Node) { + for node_output in node.outputs.iter() { + match self.old_io_names.get(&node_output.name) { + Some(IOEntry::In(i)) => { + let arg = self.inputs.get_mut(*i).unwrap(); + arg.copy_all_but_name(node_output); + } + Some(IOEntry::Out(i)) => { + let arg = self.outputs.get_mut(*i).unwrap(); + arg.copy_all_but_name(node_output); + } + Some(IOEntry::Node(_)) => { + panic!("This output is from another node"); + } + None => { + println!("inserting with name {:?}", &node_output.name); + let idx = self.node_out.len(); + self.old_io_names + .insert(node_output.name.clone(), IOEntry::Node(idx)); + self.node_out.push(node_output.clone()); + } + } + } + } + // fn add_input(&mut self, old_name: &str, node_idx: usize) { + // self.input_of + // .entry(old_name.to_string()) + // .and_modify(|f| f.push(node_idx)) + // .or_insert(vec![node_idx]); + // } + + pub(crate) fn get(&self, old_name: &str) -> Option<&Argument> { match self.old_io_names.get(old_name) { Some(IOEntry::In(i)) => self.inputs.get(*i), Some(IOEntry::Out(i)) => self.outputs.get(*i), @@ -126,10 +249,10 @@ impl OnnxGraphIO { let new_name = match self.old_io_names.get(old_name) { Some(IOEntry::In(i)) => Some(self.inputs[*i].name.clone()), Some(IOEntry::Out(i)) => Some(self.outputs[*i].name.clone()), - Some(IOEntry::Node(i)) => Some(*self.node_out[*i].clone()), + Some(IOEntry::Node(i)) => Some(self.node_out[*i].name.clone()), None => None, }; - + println!("old name {:?}", &old_name); println!("new name value {:?}", &new_name); if Some(old_name.to_string()) == new_name { println!("old name hasn't changed: {}", old_name); @@ -139,9 +262,9 @@ impl OnnxGraphIO { } } - fn 
get_node_indices(&self, old_input_name: &str) -> Option<&Vec> { - self.input_of.get(old_input_name) - } + // fn get_node_indices(&self, old_input_name: &str) -> Option<&Vec> { + // self.input_of.get(old_input_name) + // } } #[derive(Default)] @@ -164,16 +287,17 @@ impl ONNXGraphBuilder { pub(crate) fn node_gen(&mut self, model_proto: &ModelProto) { self.constants_types = LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect(); // Convert initializers to hashmap for faster lookup - let initializers = model_proto - .graph - .initializer - .iter() - .map(|x| (x.name.clone(), x.clone())) - .collect::>(); + // let initializers = model_proto + // .graph + // .initializer + // .iter() + // .map(|x| (x.name.clone(), x.clone())) + // .collect::>(); let mut graph_io = OnnxGraphIO::new( - model_proto.graph.input.clone(), - model_proto.graph.output.clone(), + &model_proto.graph.input, + &model_proto.graph.output, + &model_proto.graph.initializer, ); self.nodes = Vec::with_capacity(model_proto.graph.node.len()); @@ -181,25 +305,24 @@ impl ONNXGraphBuilder { let mut node_iter = model_proto.graph.node.iter().peekable(); while let Some(node_proto) = node_iter.next() { - let mut node = convert_node_proto(node_proto); + let mut node = convert_node_proto(node_proto, &graph_io); - for node_input in node.inputs.iter_mut() { - if let Some(initializer) = initializers.get(&node_input.name) { - move_initializer_data(initializer, node_input); - } - } remap_node_type(&mut node); - //coalesce(&mut node, &mut node_iter); - coalesce(&mut node, &mut node_iter, &initializers); + coalesce(&mut node, &mut node_iter, &graph_io); self.handle_node_renaming(&mut node); self.handle_unsqueeze(&mut node, &graph_io); self.handle_identity(&mut node, and_idx); - self.check_constants(&mut node, and_idx); + self.check_constants(&mut node, and_idx, &mut graph_io); + + if node.node_type != NodeType::Identity { + dim_inference(&mut node, &mut graph_io); + } + //self.handle_coalesce(&mut node, &mut node_iter, and_idx); - rename_io(&mut node, and_idx, &mut graph_io); + rename_io(&mut node, &mut graph_io); self.nodes.push(node); and_idx += 1; @@ -234,28 +357,35 @@ impl ONNXGraphBuilder { node.name = new_name.clone(); } - fn check_constants(&mut self, node: &mut Node, i: usize) { + fn check_constants(&mut self, node: &mut Node, i: usize, graph_io: &mut OnnxGraphIO) { if &node.node_type == &NodeType::Constant || (&node.node_type == &NodeType::Identity && node.inputs[0].value.is_some()) { self.constants_map.insert(node.outputs[0].name.clone(), i); } else if self.constants_types.contains(&node.node_type) { + println!("lift const type match for node {:?}\n\n", &node); for input in node.inputs.iter_mut().skip(1) { println!("checking input {:?} for const", input); - + println!("constants map {:?}", &self.constants_map); if let Some(const_idx) = self.constants_map.get(&input.name) { + println!("\nMATCH\n"); let constant = &self.nodes[*const_idx]; + println!("constant node {:?}", constant); if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() { // The value comes from Identity inputs input.value = constant.inputs[0].value.clone(); input.ty = constant.inputs[0].ty.clone(); + graph_io.update_value(input); } else { let arg = convert_constant_value(constant); input.value = arg.value; input.ty = arg.ty; + graph_io.update_value(input); } self.nodes_to_remove.insert(*const_idx); + println! 
{"\nupdated input {:?}", input}; } + //TODO: Future me, right now if the constant is written to the } } } @@ -270,6 +400,7 @@ impl ONNXGraphBuilder { fn handle_identity(&mut self, node: &mut Node, i: usize) { if &node.node_type == &NodeType::Identity && node.inputs[0].value.is_none() { + println!("\nfound identity node:\n{:?}\n", &node); self.identity_idx.insert(node.outputs[0].name.clone(), i); self.nodes_to_remove.insert(i); } else { @@ -320,7 +451,7 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph { builder.node_gen(&onnx_model); let ONNXGraphBuilder { - mut nodes, + nodes, inputs: mut inner_inputs, outputs: mut inner_outputs, .. @@ -330,8 +461,6 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph { // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted"); - // Infer shapes and update the inputs and outputs - dim_inference(&mut nodes, &inner_inputs, &mut inner_outputs); // Remove the graph inputs/output that are not used by any node remove_unused_graph_inputs(&mut inner_inputs, &mut inner_outputs, &nodes); @@ -372,7 +501,7 @@ pub(crate) fn move_initializer_data(initializer: &TensorProto, input: &mut Argum } } -fn move_output_shape(mut node: &mut Node, out_arg: &Argument) { +fn move_output_shape(node: &mut Node, out_arg: &Argument) { match node.outputs[0].ty { ArgType::Tensor(ref mut tensor_type) => { if let ArgType::Tensor(arg_tensor) = &out_arg.ty { @@ -391,14 +520,11 @@ fn move_output_shape(mut node: &mut Node, out_arg: &Argument) { /// the naming convention of the nodes and allow to be used as rust identifiers. /// Rename the inputs and output in the graph and return a map of /// the old names to the new names. -/// -/// The inputs are renamed to be unique and to be in the format of -/// conv2_in1, conv2_in2, etc. This is done to be consistent with -/// the naming convention of the nodes and allow to be used as rust identifiers. 
-fn rename_io(node: &mut Node, i: usize, graph_io: &mut OnnxGraphIO) { +fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) { + println!("checking inputs for node {:?}", &node.name); for node_input in node.inputs.iter_mut() { println!("old output names {:?}", &graph_io.old_io_names); - graph_io.add_input(&node_input.name, i); + //graph_io.add_input(&node_input.name, i); if let Some(input_name) = graph_io.get_new_name(&node_input.name) { node_input.passed = true; node_input.name = input_name.clone(); @@ -409,18 +535,25 @@ fn rename_io(node: &mut Node, i: usize, graph_io: &mut OnnxGraphIO) { } println!("\n\nchecking outputs"); let mut out_count = 1; - for output in node.outputs.iter_mut() { - println!("output name: {}", &output.name); - + if node.node_type == NodeType::Constant || node.node_type == NodeType::Identity { + println!("it's a constant"); let new_name = format!("{}_out{}", node.name, out_count); + graph_io.insert(&node.outputs[0], &new_name); + node.outputs[0].name = new_name; + } else { + for output in node.outputs.iter_mut() { + println!("output name: {}", &output.name); + + let new_name = format!("{}_out{}", node.name, out_count); - graph_io.update(&output.name, &new_name); + graph_io.update_name(output, &new_name); - // self.node_output_names - // .insert(output.name.clone(), new_name.clone()); + // self.node_output_names + // .insert(output.name.clone(), new_name.clone()); - output.name = new_name.clone(); - out_count += 1; + output.name = new_name.clone(); + out_count += 1; + } } } diff --git a/burn-import/src/onnx/ir.rs b/burn-import/src/onnx/ir.rs index d6d97f3549..1b5ff22755 100644 --- a/burn-import/src/onnx/ir.rs +++ b/burn-import/src/onnx/ir.rs @@ -3,6 +3,8 @@ use half::f16; use std::{collections::HashMap, fmt::Formatter}; use strum_macros::{Display, EnumString}; +use super::protos::TensorProto; + pub type Dim = usize; pub type Shape = Vec; @@ -23,6 +25,49 @@ pub struct Argument { pub passed: bool, } +impl Argument { + ///Copy everything except the name from the other argument + pub fn copy_all_but_name(&mut self, other_arg: &Argument) { + self.ty = other_arg.ty.clone(); + self.value = other_arg.value.clone(); + self.passed = other_arg.passed; + } + + pub fn from_initializer(initializer: &TensorProto) -> Argument { + let name = initializer.name.clone(); + let tensor = Tensor::try_from(initializer.clone()) + .unwrap_or_else(|_| panic!("invalid tensor {}", &initializer.name)); + + if tensor.dim == 0 { + // Convert zero dim tensor to scalar + let value = if tensor.data.is_some() { + Some(tensor.data.clone().unwrap().into_scalar()) + } else { + None + }; + let ty = ArgType::Scalar(tensor.elem_type); + + Self { + name, + ty, + value, + passed: false, + } + } else { + Self { + name, + ty: ArgType::Tensor(TensorType { + elem_type: tensor.elem_type, + dim: tensor.dim, + shape: tensor.shape, + }), + value: tensor.data.clone(), + passed: false, + } + } + } +} + /// The type of an argument. 
#[derive(Debug, Clone)] pub enum ArgType { diff --git a/burn-import/src/onnx/proto_conversion.rs b/burn-import/src/onnx/proto_conversion.rs index 1fe5aac54a..e4c324e245 100644 --- a/burn-import/src/onnx/proto_conversion.rs +++ b/burn-import/src/onnx/proto_conversion.rs @@ -2,6 +2,7 @@ use std::str::{from_utf8, FromStr}; use crate::onnx::ir::TensorType; +use super::from_onnx::OnnxGraphIO; use super::ir::Dim; use super::ir::{ ArgType, Argument, AttributeValue, Attributes, Data, ElementType, Node, NodeType, Tensor, @@ -179,12 +180,17 @@ pub fn convert_vec_attrs_proto(attrs: Vec) -> Attributes { result } -pub fn convert_node_proto(node: &NodeProto) -> Node { +pub fn convert_node_proto(node: &NodeProto, graph_io: &OnnxGraphIO) -> Node { let name = node.name.clone(); log::debug!("Converting ONNX node with type {:?}", node.op_type.as_str()); - let inputs = node.input.clone().into_iter().map(Argument::new).collect(); + let inputs = node + .input + .clone() + .into_iter() + .map(|x| graph_io.init_in(&x)) + .collect(); let outputs = node.output.clone().into_iter().map(Argument::new).collect(); From 462445111f0e81f26ae6fa4bc1b23ec30dc83398 Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Sun, 18 Feb 2024 15:11:04 -0600 Subject: [PATCH 10/21] mainly cleanup --- burn-import/src/onnx/coalesce.rs | 12 +-- burn-import/src/onnx/from_onnx.rs | 170 +++++++----------------------- burn-import/src/onnx/ir.rs | 3 +- 3 files changed, 43 insertions(+), 142 deletions(-) diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs index a2acd87d87..7255ba5ed7 100644 --- a/burn-import/src/onnx/coalesce.rs +++ b/burn-import/src/onnx/coalesce.rs @@ -1,18 +1,12 @@ -use std::{ - cell::{RefMut}, - iter::Peekable, - slice::Iter, -}; +use std::{cell::RefMut, iter::Peekable, slice::Iter}; use super::{ from_onnx::OnnxGraphIO, ir::{AttributeValue, Node, NodeType}, proto_conversion::convert_node_proto, - protos::{NodeProto}, -}; -use crate::onnx::{ - ir::{ArgType, Data, TensorType}, + protos::NodeProto, }; +use crate::onnx::ir::{ArgType, Data, TensorType}; /// The function transforms the graph into a new one where the nodes are coalesced into a single node. pub fn coalesce( diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs index 173e496018..01e8373406 100644 --- a/burn-import/src/onnx/from_onnx.rs +++ b/burn-import/src/onnx/from_onnx.rs @@ -37,13 +37,12 @@ pub(crate) enum IOEntry { } pub(crate) struct OnnxGraphIO { - ///Per Onnx spec "Inputs represent graph inputs or values computed elsewhere in the graph..." 
- /// Thus all computed inputs are in the list of inputs in a valid Onnx file + /// The inputs for the Graph pub(crate) inputs: Vec, + /// The outputs for the Graph pub(crate) outputs: Vec, - /// Initializers or constants, for more information on why these are treated the - /// same, please see: https://github.com/onnx/onnx/issues/4677 - pub(crate) constants: HashMap, + /// Initializers + pub(crate) initializers: HashMap, //pub(crate) initializers: Vec, ///updated names of outputs of node not stored in the graph node_out: Vec, @@ -88,7 +87,7 @@ impl OnnxGraphIO { Self { inputs, outputs, - constants, + initializers: constants, node_out: Vec::new(), old_io_names, } @@ -97,8 +96,6 @@ impl OnnxGraphIO { fn update_name(&mut self, arg: &Argument, new_name: &str) { match self.old_io_names.get(&arg.name) { Some(IOEntry::In(_)) => { - // let arg = self.inputs.get_mut(*i).unwrap(); - // arg.name = new_name.to_string(); panic!("input names are set from the beginning"); } Some(IOEntry::Out(i)) => { @@ -109,34 +106,22 @@ impl OnnxGraphIO { let arg = self.node_out.get_mut(*i).unwrap(); arg.name = new_name.to_string(); } - None => { - //Constants, Casts - let idx = self.node_out.len(); - self.old_io_names - .insert(arg.name.clone(), IOEntry::Node(idx)); - self.node_out.push(arg.clone()); - self.node_out[idx].name = new_name.to_string(); - } - } - } - fn update_value(&mut self, updated_arg: &Argument) { - match self.old_io_names.get(&updated_arg.name) { - Some(IOEntry::Node(i)) => { - let arg = self.node_out.get_mut(*i).unwrap(); - arg.copy_all_but_name(updated_arg); + //Constants, Casts wound up here before API changes + panic!( + "Tried to update the name of {} to {} but entry doesn't exist in the map", + arg.name, new_name + ) } - _ => panic!( - "Tried to update the value of {:?} which was the output from another node", - &updated_arg.name - ), } } + ///Used to initialize the input arguments for nodes. 
Names need to remain the same because + /// currently the old names are the key for accessing the Argument pub fn init_in(&self, proto_str: &str) -> Argument { match self.old_io_names.get(proto_str) { None => { - if let Some(init_arg) = self.constants.get(proto_str) { + if let Some(init_arg) = self.initializers.get(proto_str) { init_arg.clone() } else { Argument::new(proto_str.to_string()) @@ -146,6 +131,7 @@ impl OnnxGraphIO { Some(IOEntry::In(i)) => { let mut arg = self.inputs[*i].clone(); arg.name = proto_str.to_string(); + arg.passed = true; arg } Some(IOEntry::Node(i)) => { @@ -177,50 +163,24 @@ impl OnnxGraphIO { self.node_out.push(arg.clone()); self.node_out[idx].name = new_name.to_string(); } - ///Copy data from the graph inputs to the nodes inputs - pub(crate) fn copy_to_node_inputs(&self, node: &mut Node) { - for input in node.inputs.iter_mut() { - if input.name.is_empty() { - continue; - } - match self.old_io_names.get(&input.name) { - Some(IOEntry::In(i)) => { - let arg = self.inputs.get(*i).unwrap(); - input.copy_all_but_name(arg); - } - Some(IOEntry::Out(_i)) => { - panic!("Output should only contain final outputs"); - } - Some(IOEntry::Node(i)) => { - let arg = self.node_out.get(*i).unwrap(); - input.copy_all_but_name(arg); - } - None => { - //happens with initializers - // println!("io names: {:?}", &self.old_io_names); - // panic!("Failure when copying nonexistent io to input {} for node {}\nShouldn't happen", &input.name, &node.name); - } - } - } - } ///iterate over the nodes output and copy them to the graph IO pub(crate) fn update_tensor_output(&mut self, node: &Node) { for node_output in node.outputs.iter() { match self.old_io_names.get(&node_output.name) { Some(IOEntry::In(i)) => { let arg = self.inputs.get_mut(*i).unwrap(); - arg.copy_all_but_name(node_output); + arg.copy_value(node_output); } Some(IOEntry::Out(i)) => { let arg = self.outputs.get_mut(*i).unwrap(); - arg.copy_all_but_name(node_output); + arg.copy_value(node_output); } Some(IOEntry::Node(_)) => { panic!("This output is from another node"); } None => { - println!("inserting with name {:?}", &node_output.name); + log::debug!("inserting with name {:?}", &node_output.name); let idx = self.node_out.len(); self.old_io_names .insert(node_output.name.clone(), IOEntry::Node(idx)); @@ -229,12 +189,6 @@ impl OnnxGraphIO { } } } - // fn add_input(&mut self, old_name: &str, node_idx: usize) { - // self.input_of - // .entry(old_name.to_string()) - // .and_modify(|f| f.push(node_idx)) - // .or_insert(vec![node_idx]); - // } pub(crate) fn get(&self, old_name: &str) -> Option<&Argument> { match self.old_io_names.get(old_name) { @@ -246,25 +200,13 @@ impl OnnxGraphIO { } fn get_new_name(&self, old_name: &str) -> Option { - let new_name = match self.old_io_names.get(old_name) { + match self.old_io_names.get(old_name) { Some(IOEntry::In(i)) => Some(self.inputs[*i].name.clone()), Some(IOEntry::Out(i)) => Some(self.outputs[*i].name.clone()), Some(IOEntry::Node(i)) => Some(self.node_out[*i].name.clone()), None => None, - }; - println!("old name {:?}", &old_name); - println!("new name value {:?}", &new_name); - if Some(old_name.to_string()) == new_name { - println!("old name hasn't changed: {}", old_name); - None - } else { - new_name } } - - // fn get_node_indices(&self, old_input_name: &str) -> Option<&Vec> { - // self.input_of.get(old_input_name) - // } } #[derive(Default)] @@ -286,13 +228,6 @@ pub(crate) struct ONNXGraphBuilder { impl ONNXGraphBuilder { pub(crate) fn node_gen(&mut self, model_proto: &ModelProto) { 
self.constants_types = LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect(); - // Convert initializers to hashmap for faster lookup - // let initializers = model_proto - // .graph - // .initializer - // .iter() - // .map(|x| (x.name.clone(), x.clone())) - // .collect::>(); let mut graph_io = OnnxGraphIO::new( &model_proto.graph.input, @@ -335,15 +270,20 @@ impl ONNXGraphBuilder { res }); let OnnxGraphIO { - inputs, outputs, .. + mut inputs, + mut outputs, + old_io_names, + .. } = graph_io; + + //remove_unused_graph_inputs(&mut inputs, &mut outputs, &old_io_names); self.inputs = inputs; self.outputs = outputs; } fn handle_node_renaming(&mut self, node: &mut Node) { if &node.node_type == &NodeType::Linear { - println!("rename linear node {:?}", node); + log::debug!("rename linear node {:?}", node); } self.node_name_counter .entry(node.node_type.clone()) @@ -363,29 +303,28 @@ impl ONNXGraphBuilder { { self.constants_map.insert(node.outputs[0].name.clone(), i); } else if self.constants_types.contains(&node.node_type) { - println!("lift const type match for node {:?}\n\n", &node); + log::debug!("lift const type match for node {:?}\n\n", &node); for input in node.inputs.iter_mut().skip(1) { - println!("checking input {:?} for const", input); - println!("constants map {:?}", &self.constants_map); + log::debug!("checking input {:?} for const", input); + log::debug!("constants map {:?}", &self.constants_map); if let Some(const_idx) = self.constants_map.get(&input.name) { - println!("\nMATCH\n"); + log::debug!("\nMATCH\n"); let constant = &self.nodes[*const_idx]; - println!("constant node {:?}", constant); + log::debug!("constant node {:?}", constant); if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() { // The value comes from Identity inputs input.value = constant.inputs[0].value.clone(); input.ty = constant.inputs[0].ty.clone(); - graph_io.update_value(input); + //graph_io.update_value(input); } else { let arg = convert_constant_value(constant); input.value = arg.value; input.ty = arg.ty; - graph_io.update_value(input); + //graph_io.update_value(input); } self.nodes_to_remove.insert(*const_idx); - println! {"\nupdated input {:?}", input}; + log::debug! 
{"\nupdated input {:?}", input}; } - //TODO: Future me, right now if the constant is written to the } } } @@ -400,7 +339,7 @@ impl ONNXGraphBuilder { fn handle_identity(&mut self, node: &mut Node, i: usize) { if &node.node_type == &NodeType::Identity && node.inputs[0].value.is_none() { - println!("\nfound identity node:\n{:?}\n", &node); + log::debug!("\nfound identity node:\n{:?}\n", &node); self.identity_idx.insert(node.outputs[0].name.clone(), i); self.nodes_to_remove.insert(i); } else { @@ -473,34 +412,6 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph { } } -pub(crate) fn move_initializer_data(initializer: &TensorProto, input: &mut Argument) { - // If the input name matches the tensor name in the initializer - // Convert the initializer to a tensor - let tensor = Tensor::try_from(initializer.clone()).expect("Invalid tensor"); - - if tensor.dim == 0 { - // Convert zero dim tensor to scalar - if let Some(data) = tensor.data { - input.value = Some(data.into_scalar()); - } else { - input.value = None; - } - - // Update the input type - input.ty = ArgType::Scalar(tensor.elem_type); - } else { - // Move the tensor data to the input value - input.value = tensor.data.clone(); - - // Update the input type - input.ty = ArgType::Tensor(TensorType { - dim: tensor.dim, - elem_type: tensor.elem_type, - shape: tensor.shape, - }); - } -} - fn move_output_shape(node: &mut Node, out_arg: &Argument) { match node.outputs[0].ty { ArgType::Tensor(ref mut tensor_type) => { @@ -521,9 +432,9 @@ fn move_output_shape(node: &mut Node, out_arg: &Argument) { /// Rename the inputs and output in the graph and return a map of /// the old names to the new names. fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) { - println!("checking inputs for node {:?}", &node.name); + log::debug!("checking inputs for node {:?}", &node.name); for node_input in node.inputs.iter_mut() { - println!("old output names {:?}", &graph_io.old_io_names); + log::debug!("old output names {:?}", &graph_io.old_io_names); //graph_io.add_input(&node_input.name, i); if let Some(input_name) = graph_io.get_new_name(&node_input.name) { node_input.passed = true; @@ -533,24 +444,21 @@ fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) { node_input.passed = false; } } - println!("\n\nchecking outputs"); + log::debug!("\n\nchecking outputs"); let mut out_count = 1; if node.node_type == NodeType::Constant || node.node_type == NodeType::Identity { - println!("it's a constant"); + log::debug!("it's a constant"); let new_name = format!("{}_out{}", node.name, out_count); graph_io.insert(&node.outputs[0], &new_name); node.outputs[0].name = new_name; } else { for output in node.outputs.iter_mut() { - println!("output name: {}", &output.name); + log::debug!("output name: {}", &output.name); let new_name = format!("{}_out{}", node.name, out_count); graph_io.update_name(output, &new_name); - // self.node_output_names - // .insert(output.name.clone(), new_name.clone()); - output.name = new_name.clone(); out_count += 1; } diff --git a/burn-import/src/onnx/ir.rs b/burn-import/src/onnx/ir.rs index 1b5ff22755..93e1bea438 100644 --- a/burn-import/src/onnx/ir.rs +++ b/burn-import/src/onnx/ir.rs @@ -27,10 +27,9 @@ pub struct Argument { impl Argument { ///Copy everything except the name from the other argument - pub fn copy_all_but_name(&mut self, other_arg: &Argument) { + pub fn copy_value(&mut self, other_arg: &Argument) { self.ty = other_arg.ty.clone(); self.value = other_arg.value.clone(); - self.passed = other_arg.passed; } pub fn 
from_initializer(initializer: &TensorProto) -> Argument { From 1f9c05120003a5442e3367552e0cb2b0aa0ad89f Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Sun, 18 Feb 2024 16:36:50 -0600 Subject: [PATCH 11/21] changed how io use is tracked, moved unsqueeze remapping out of dim inference --- burn-import/src/onnx/from_onnx.rs | 130 +++++++++++++++++------------- 1 file changed, 72 insertions(+), 58 deletions(-) diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs index 01e8373406..aec3940035 100644 --- a/burn-import/src/onnx/from_onnx.rs +++ b/burn-import/src/onnx/from_onnx.rs @@ -4,18 +4,16 @@ use std::{ path::Path, }; -use crate::onnx::{ - ir::TensorType, node_remap::remap_node_type, proto_conversion::convert_node_proto, -}; +use crate::onnx::{node_remap::remap_node_type, proto_conversion::convert_node_proto}; use super::{ coalesce::coalesce, - ir::OnnxGraph, + ir::{Data, OnnxGraph, TensorType}, protos::{ModelProto, TensorProto, ValueInfoProto}, }; use super::dim_inference::dim_inference; -use super::ir::{ArgType, Argument, Node, NodeType, Tensor}; +use super::ir::{ArgType, Argument, Node, NodeType}; use protobuf::Message; @@ -175,6 +173,8 @@ impl OnnxGraphIO { Some(IOEntry::Out(i)) => { let arg = self.outputs.get_mut(*i).unwrap(); arg.copy_value(node_output); + //Set the output to passed since it's been altered by a Node + arg.passed = true; } Some(IOEntry::Node(_)) => { panic!("This output is from another node"); @@ -199,10 +199,23 @@ impl OnnxGraphIO { } } - fn get_new_name(&self, old_name: &str) -> Option { + /// get the updated name of a Node Input, which obviously should be + /// either a graph input or a node output. + /// will return None if the it isn't a graph input or node output(like an initializer) + /// Will panic if it's a graph output + fn get_new_name(&mut self, old_name: &str) -> Option { match self.old_io_names.get(old_name) { - Some(IOEntry::In(i)) => Some(self.inputs[*i].name.clone()), - Some(IOEntry::Out(i)) => Some(self.outputs[*i].name.clone()), + Some(IOEntry::In(i)) => { + //set the input as passed since a node is referencing it + self.inputs[*i].passed = true; + Some(self.inputs[*i].name.clone()) + } + Some(IOEntry::Out(_)) => { + panic!( + "you just tried to get an updated name on a graph output: {}", + old_name + ) + } Some(IOEntry::Node(i)) => Some(self.node_out[*i].name.clone()), None => None, } @@ -221,8 +234,8 @@ pub(crate) struct ONNXGraphBuilder { constants_map: HashMap, constants_types: HashSet, - ///map from old node name to indices of identity nodes - identity_idx: HashMap, + //map from old node name to indices of identity nodes + //identity_idx: HashMap, } impl ONNXGraphBuilder { @@ -246,9 +259,7 @@ impl ONNXGraphBuilder { coalesce(&mut node, &mut node_iter, &graph_io); self.handle_node_renaming(&mut node); - self.handle_unsqueeze(&mut node, &graph_io); - self.handle_identity(&mut node, and_idx); self.check_constants(&mut node, and_idx, &mut graph_io); @@ -256,7 +267,6 @@ impl ONNXGraphBuilder { dim_inference(&mut node, &mut graph_io); } - //self.handle_coalesce(&mut node, &mut node_iter, and_idx); rename_io(&mut node, &mut graph_io); self.nodes.push(node); @@ -272,10 +282,10 @@ impl ONNXGraphBuilder { let OnnxGraphIO { mut inputs, mut outputs, - old_io_names, .. 
} = graph_io; - + // Remove the graph inputs/output that are not used by any node + remove_unused_graph_inputs(&mut inputs, &mut outputs); //remove_unused_graph_inputs(&mut inputs, &mut outputs, &old_io_names); self.inputs = inputs; self.outputs = outputs; @@ -297,7 +307,7 @@ impl ONNXGraphBuilder { node.name = new_name.clone(); } - fn check_constants(&mut self, node: &mut Node, i: usize, graph_io: &mut OnnxGraphIO) { + fn check_constants(&mut self, node: &mut Node, i: usize, _graph_io: &mut OnnxGraphIO) { if &node.node_type == &NodeType::Constant || (&node.node_type == &NodeType::Identity && node.inputs[0].value.is_some()) { @@ -331,26 +341,29 @@ impl ONNXGraphBuilder { fn handle_unsqueeze(&mut self, node: &mut Node, graph_io: &OnnxGraphIO) { if node.node_type == NodeType::Unsqueeze { - if let Some(in_arg) = graph_io.get(&node.outputs[0].name) { - move_output_shape(node, in_arg); + if node.inputs[1].value.is_none() { + if let Some(in_arg) = graph_io.get(&node.outputs[0].name) { + remap_unsqueeze_to_reshape(node, in_arg); + } } } } fn handle_identity(&mut self, node: &mut Node, i: usize) { - if &node.node_type == &NodeType::Identity && node.inputs[0].value.is_none() { + if node.node_type == NodeType::Identity && node.inputs[0].value.is_none() { log::debug!("\nfound identity node:\n{:?}\n", &node); - self.identity_idx.insert(node.outputs[0].name.clone(), i); + //self.identity_idx.insert(node.outputs[0].name.clone(), i); self.nodes_to_remove.insert(i); - } else { - node.inputs.iter_mut().for_each(|x| { - if let Some(identity_idx) = self.identity_idx.get(&x.name) { - let input_name = &self.nodes[*identity_idx].inputs[0].name; - - x.name = input_name.clone(); - } - }); - } + //apparently the below is no longer necessary + } //else { + // node.inputs.iter_mut().for_each(|x| { + // if let Some(identity_idx) = self.identity_idx.get(&x.name) { + // let input_name = &self.nodes[*identity_idx].inputs[0].name; + + // x.name = input_name.clone(); + // } + // }); + // } } } @@ -400,9 +413,6 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph { // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted"); - // Remove the graph inputs/output that are not used by any node - remove_unused_graph_inputs(&mut inner_inputs, &mut inner_outputs, &nodes); - log::info!("Finished parsing ONNX file: {}", onnx_path.display()); OnnxGraph { @@ -412,11 +422,37 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph { } } -fn move_output_shape(node: &mut Node, out_arg: &Argument) { +/// Remap the unsqueeze node to a reshape node, Should only be called after +/// node renaming has been done. 
avoids marking rhs as passed so that it can be +/// properly deleted if nothing else uses it +fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) { match node.outputs[0].ty { ArgType::Tensor(ref mut tensor_type) => { if let ArgType::Tensor(arg_tensor) = &out_arg.ty { tensor_type.shape = arg_tensor.shape.clone(); + let inner = arg_tensor + .shape + .clone() + .unwrap() + .into_iter() + .map(|x| x as i64) + .collect::>(); + let shape_len = inner.len(); + let new_rhs_value = Some(Data::Int64s(inner)); + //moving the remap to here + let rhs_arg = Argument { + name: format!("{}_generated_const", node.name), + ty: ArgType::Tensor(TensorType { + elem_type: super::ir::ElementType::Int64, + dim: 1, + shape: Some(vec![shape_len]), + }), + value: new_rhs_value, + passed: false, + }; + node.inputs[1] = rhs_arg; + node.outputs[0] = out_arg.clone(); + node.node_type = NodeType::Reshape; } } _ => {} @@ -474,34 +510,12 @@ fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) { /// /// Generally, it's a good idea to remove unused inputs/outputs because it makes the /// generated code cleaner and easier to read. -fn remove_unused_graph_inputs( - inputs: &mut Vec, - outputs: &mut Vec, - nodes: &Vec, -) { +fn remove_unused_graph_inputs(inputs: &mut Vec, outputs: &mut Vec) { // Remove inputs that are not used by any node - inputs.retain(|input| { - for node in nodes.iter() { - if node - .inputs - .iter() - .any(|x| x.name == input.name && x.value.is_none()) - { - return true; - } - } - false - }); + inputs.retain(|input| input.passed); // Remove outputs that are not used by any node - outputs.retain(|output| { - for node in nodes.iter() { - if node.outputs.iter().any(|x| x.name == output.name) { - return true; - } - } - false - }); + outputs.retain(|output| output.passed); } // Define a trait for topological sorting From 82bdf19980dcf4cb6ea5d044e37a87ba370908ed Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Mon, 19 Feb 2024 11:28:48 -0600 Subject: [PATCH 12/21] `cargo xtask run-checks all` now passes --- burn-import/src/onnx/coalesce.rs | 15 +---- burn-import/src/onnx/dim_inference.rs | 53 ++++------------ burn-import/src/onnx/from_onnx.rs | 77 ++++++++++++++---------- burn-import/src/onnx/proto_conversion.rs | 2 +- 4 files changed, 58 insertions(+), 89 deletions(-) diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs index 7255ba5ed7..106002b8fe 100644 --- a/burn-import/src/onnx/coalesce.rs +++ b/burn-import/src/onnx/coalesce.rs @@ -1,4 +1,4 @@ -use std::{cell::RefMut, iter::Peekable, slice::Iter}; +use std::{iter::Peekable, slice::Iter}; use super::{ from_onnx::OnnxGraphIO, @@ -184,16 +184,3 @@ fn convert_and_remove_add_node(bias_node: &Node, current_node: &mut Node) { current_node.inputs.push(bias_input); current_node.outputs[0].name = bias_node.outputs[0].name.clone(); } - -/// Helper function to convert and remove the Add node -pub(crate) fn convert_node2(bias_node: &Node, mut current_node: RefMut<'_, Node>) { - let bias_input = if bias_node.inputs[0].value.is_some() { - bias_node.inputs[0].clone() - } else { - bias_node.inputs[1].clone() - }; - - // Push the bias input and update the output name - current_node.inputs.push(bias_input); - current_node.outputs[0].name = bias_node.outputs[0].name.clone(); -} diff --git a/burn-import/src/onnx/dim_inference.rs b/burn-import/src/onnx/dim_inference.rs index 4893ad5112..f864d762a5 100644 --- a/burn-import/src/onnx/dim_inference.rs +++ b/burn-import/src/onnx/dim_inference.rs @@ -4,7 +4,7 @@ use protobuf::Enum; 
use super::{ from_onnx::OnnxGraphIO, - ir::{ArgType, Argument, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, + ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType}, op_configuration::flatten_config, protos::tensor_proto::DataType, }; @@ -51,7 +51,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) { NodeType::Sub => same_as_input(node), NodeType::Tanh => same_as_input(node), NodeType::Transpose => same_as_input(node), - NodeType::Unsqueeze => unsqueeze_update_output_or_node(node), + NodeType::Unsqueeze => unsqueeze_update_output(node), NodeType::Pow => same_as_input(node), // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated. _ => temporary_pass_through_stub(node), @@ -235,9 +235,9 @@ fn mean_update_outputs(node: &mut Node) { } //fn __unsqueeze_shape -/// Either it Infers the shape of the output of an Unsqueeze node, or it remaps the node to a reshape if the output is static -/// providing an arg and inferring the dimensions at runtime isn't currently supported. -fn unsqueeze_update_output_or_node(node: &mut Node) { +/// Infers the shape of the output from the input and axes +/// Right now, this should only be called if the rhs is a constant +fn unsqueeze_update_output(node: &mut Node) { if node.inputs.len() != 2 { panic!("Unsqueeze: wrong number of inputs"); } @@ -256,14 +256,13 @@ fn unsqueeze_update_output_or_node(node: &mut Node) { _ => panic!("Unsqueeze: invalid input types"), }; //need output way up here to avoid borrowing issues - let (mut tensor, output_shape) = match &node.outputs[0].ty { - ArgType::Tensor(tensor) => (tensor.clone(), tensor.shape.clone()), + let mut tensor = match &node.outputs[0].ty { + ArgType::Tensor(tensor) => tensor.clone(), _ => panic!("Unsqueeze: invalid output types"), }; - let mut remap_node = false; - match (&axes, tensor.shape) { + match &axes { //case 1: axes is constant -> output shape is input shape with 1s inserted at the axes - (Some(dim_indices), _) => { + Some(dim_indices) => { let output_rank = (dim_indices.len() + input.dim) as i64; let mut dim_indices = dim_indices .to_vec() @@ -316,42 +315,12 @@ fn unsqueeze_update_output_or_node(node: &mut Node) { tensor.shape = Some(new_dims); node.outputs[0].ty = ArgType::Tensor(tensor.clone()); } - //case 2: output shape isn't dynamic -> map the node to a reshape - (None, Some(_)) => { - remap_node = true; - } + //case 3: output shape is dynamic -> black magic or unsupported - (None, None) => { + None => { panic!("Unsqueeze: dynamic output shape is not currently supported"); } } - //need to move out of the match to avoid borrowing issues - if remap_node { - let mut new_node = node.clone(); - let rhs_name = node.inputs[1].name.clone(); - new_node.node_type = NodeType::Reshape; - let rhs_arg = Argument { - name: rhs_name, //need name to remain the same - ty: ArgType::Tensor(TensorType { - elem_type: ElementType::Int64, - dim: 1, - shape: Some(vec![output_shape.clone().unwrap().len()]), - }), - value: Some(Data::Int64s( - output_shape - .unwrap() - .into_iter() - .map(|ax_len| ax_len as i64) - .collect::>(), - )), - - passed: false, - }; - new_node.inputs = vec![node.inputs[0].clone(), rhs_arg]; - new_node.outputs = vec![node.outputs[0].clone()]; - reshape_update_outputs(&mut new_node); - *node = new_node; - } } fn same_as_input(node: &mut Node) { diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs index aec3940035..bcd506f1c6 100644 --- a/burn-import/src/onnx/from_onnx.rs 
+++ b/burn-import/src/onnx/from_onnx.rs @@ -55,6 +55,11 @@ impl OnnxGraphIO { ) -> Self { let mut old_io_names = HashMap::new(); let mut in_count = 1; + let constants = initializers + .iter() + .map(|x| (x.name.clone(), Argument::from_initializer(x))) + .collect::>(); + let inputs = inputs .iter() .enumerate() @@ -62,6 +67,12 @@ impl OnnxGraphIO { let in_name = format!("input{}", in_count); old_io_names.insert(x.name.clone(), IOEntry::In(i)); let mut arg = Argument::try_from(x.clone()).unwrap(); + if let Some(initial_arg) = constants.get(&x.name) { + if arg.value.is_none() { + arg.copy_value(initial_arg); + } + } + in_count += 1; arg.name = in_name; arg @@ -116,29 +127,30 @@ impl OnnxGraphIO { ///Used to initialize the input arguments for nodes. Names need to remain the same because /// currently the old names are the key for accessing the Argument - pub fn init_in(&self, proto_str: &str) -> Argument { - match self.old_io_names.get(proto_str) { + pub fn init_in(&self, proto_str: String) -> Argument { + match self.old_io_names.get(&proto_str) { None => { - if let Some(init_arg) = self.initializers.get(proto_str) { + if let Some(init_arg) = self.initializers.get(&proto_str) { init_arg.clone() } else { - Argument::new(proto_str.to_string()) + Argument::new(proto_str) } } Some(IOEntry::In(i)) => { let mut arg = self.inputs[*i].clone(); - arg.name = proto_str.to_string(); + + arg.name = proto_str; arg.passed = true; arg } Some(IOEntry::Node(i)) => { let mut arg = self.node_out[*i].clone(); - arg.name = proto_str.to_string(); + arg.name = proto_str; arg } Some(IOEntry::Out(_)) => { - panic!("graph out {} can't be an input", &proto_str) + panic!("graph output {} can't be a Node input", &proto_str) } } } @@ -206,9 +218,13 @@ impl OnnxGraphIO { fn get_new_name(&mut self, old_name: &str) -> Option { match self.old_io_names.get(old_name) { Some(IOEntry::In(i)) => { - //set the input as passed since a node is referencing it - self.inputs[*i].passed = true; - Some(self.inputs[*i].name.clone()) + if self.initializers.contains_key(old_name) { + None + } else { + //set the input as passed since a node is referencing it + self.inputs[*i].passed = true; + Some(self.inputs[*i].name.clone()) + } } Some(IOEntry::Out(_)) => { panic!( @@ -259,9 +275,9 @@ impl ONNXGraphBuilder { coalesce(&mut node, &mut node_iter, &graph_io); self.handle_node_renaming(&mut node); - self.handle_unsqueeze(&mut node, &graph_io); self.handle_identity(&mut node, and_idx); self.check_constants(&mut node, and_idx, &mut graph_io); + self.handle_unsqueeze(&mut node, &graph_io); if node.node_type != NodeType::Identity { dim_inference(&mut node, &mut graph_io); @@ -284,17 +300,15 @@ impl ONNXGraphBuilder { mut outputs, .. 
} = graph_io; + // Remove the graph inputs/output that are not used by any node remove_unused_graph_inputs(&mut inputs, &mut outputs); - //remove_unused_graph_inputs(&mut inputs, &mut outputs, &old_io_names); self.inputs = inputs; self.outputs = outputs; } fn handle_node_renaming(&mut self, node: &mut Node) { - if &node.node_type == &NodeType::Linear { - log::debug!("rename linear node {:?}", node); - } + log::debug!("renaming node {:?}", &node.name); self.node_name_counter .entry(node.node_type.clone()) .and_modify(|e| *e += 1) @@ -308,43 +322,43 @@ impl ONNXGraphBuilder { } fn check_constants(&mut self, node: &mut Node, i: usize, _graph_io: &mut OnnxGraphIO) { - if &node.node_type == &NodeType::Constant - || (&node.node_type == &NodeType::Identity && node.inputs[0].value.is_some()) + if node.node_type == NodeType::Constant + || (node.node_type == NodeType::Identity && node.inputs[0].value.is_some()) { self.constants_map.insert(node.outputs[0].name.clone(), i); } else if self.constants_types.contains(&node.node_type) { - log::debug!("lift const type match for node {:?}\n\n", &node); + log::debug!("checking node {} for constants", &node.name); for input in node.inputs.iter_mut().skip(1) { log::debug!("checking input {:?} for const", input); - log::debug!("constants map {:?}", &self.constants_map); if let Some(const_idx) = self.constants_map.get(&input.name) { - log::debug!("\nMATCH\n"); let constant = &self.nodes[*const_idx]; - log::debug!("constant node {:?}", constant); + log::debug!( + "input {} matched constant node {}", + &input.name, + &constant.name + ); if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() { // The value comes from Identity inputs input.value = constant.inputs[0].value.clone(); input.ty = constant.inputs[0].ty.clone(); - //graph_io.update_value(input); } else { let arg = convert_constant_value(constant); input.value = arg.value; input.ty = arg.ty; - //graph_io.update_value(input); } self.nodes_to_remove.insert(*const_idx); - log::debug! {"\nupdated input {:?}", input}; } } } } + ///check if the unsqueeze node has a rhs value (rhs is constant) and if not remap it to a reshape + /// Needs to be called after node renaming to ensure that the rhs name is correct + /// Needs to be called after constant lifting to ensure that the rhs value exists fn handle_unsqueeze(&mut self, node: &mut Node, graph_io: &OnnxGraphIO) { - if node.node_type == NodeType::Unsqueeze { - if node.inputs[1].value.is_none() { - if let Some(in_arg) = graph_io.get(&node.outputs[0].name) { - remap_unsqueeze_to_reshape(node, in_arg); - } + if node.node_type == NodeType::Unsqueeze && node.inputs[1].value.is_none() { + if let Some(in_arg) = graph_io.get(&node.outputs[0].name) { + remap_unsqueeze_to_reshape(node, in_arg); } } } @@ -404,8 +418,8 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph { let ONNXGraphBuilder { nodes, - inputs: mut inner_inputs, - outputs: mut inner_outputs, + inputs: inner_inputs, + outputs: inner_outputs, .. 
} = builder; @@ -470,7 +484,6 @@ fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) { fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) { log::debug!("checking inputs for node {:?}", &node.name); for node_input in node.inputs.iter_mut() { - log::debug!("old output names {:?}", &graph_io.old_io_names); //graph_io.add_input(&node_input.name, i); if let Some(input_name) = graph_io.get_new_name(&node_input.name) { node_input.passed = true; diff --git a/burn-import/src/onnx/proto_conversion.rs b/burn-import/src/onnx/proto_conversion.rs index e4c324e245..2f1481281f 100644 --- a/burn-import/src/onnx/proto_conversion.rs +++ b/burn-import/src/onnx/proto_conversion.rs @@ -189,7 +189,7 @@ pub fn convert_node_proto(node: &NodeProto, graph_io: &OnnxGraphIO) -> Node { .input .clone() .into_iter() - .map(|x| graph_io.init_in(&x)) + .map(|x| graph_io.init_in(x)) .collect(); let outputs = node.output.clone().into_iter().map(Argument::new).collect(); From c8ed7debd9387b44b5cf4c563569f8cd9130a8d5 Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Mon, 19 Feb 2024 13:17:46 -0600 Subject: [PATCH 13/21] added a fixme and a few doc strings --- burn-import/src/onnx/from_onnx.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs index bcd506f1c6..6e72be00e9 100644 --- a/burn-import/src/onnx/from_onnx.rs +++ b/burn-import/src/onnx/from_onnx.rs @@ -48,6 +48,7 @@ pub(crate) struct OnnxGraphIO { } impl OnnxGraphIO { + pub(crate) fn new( inputs: &Vec, outputs: &Vec, @@ -202,7 +203,9 @@ impl OnnxGraphIO { } } - pub(crate) fn get(&self, old_name: &str) -> Option<&Argument> { + ///used by handle unsqeeze to remap the output of a node to a new name + ///expected match if it exists is either a graph input or graph output + pub(crate) fn get_node_output(&self, old_name: &str) -> Option<&Argument> { match self.old_io_names.get(old_name) { Some(IOEntry::In(i)) => self.inputs.get(*i), Some(IOEntry::Out(i)) => self.outputs.get(*i), @@ -218,6 +221,11 @@ impl OnnxGraphIO { fn get_new_name(&mut self, old_name: &str) -> Option { match self.old_io_names.get(old_name) { Some(IOEntry::In(i)) => { + //FIXME: technically in the spec, initializers are default values + //for optional inputs, but implementing that would require reworking + //the way the graph is built, and it's not clear burn users are using initializers + //in that way + // see https://github.com/onnx/onnx/issues/2660 if self.initializers.contains_key(old_name) { None } else { @@ -357,7 +365,7 @@ impl ONNXGraphBuilder { /// Needs to be called after constant lifting to ensure that the rhs value exists fn handle_unsqueeze(&mut self, node: &mut Node, graph_io: &OnnxGraphIO) { if node.node_type == NodeType::Unsqueeze && node.inputs[1].value.is_none() { - if let Some(in_arg) = graph_io.get(&node.outputs[0].name) { + if let Some(in_arg) = graph_io.get_node_output(&node.outputs[0].name) { remap_unsqueeze_to_reshape(node, in_arg); } } From 8d495fcef99356cc9fd3e3312c7d073505db7842 Mon Sep 17 00:00:00 2001 From: Joshua Ferguson Date: Tue, 20 Feb 2024 13:13:40 -0600 Subject: [PATCH 14/21] removing println and dead code --- burn-import/src/onnx/coalesce.rs | 17 +++-------- burn-import/src/onnx/dim_inference.rs | 2 -- burn-import/src/onnx/from_onnx.rs | 42 +++++++++++++-------------- 3 files changed, 24 insertions(+), 37 deletions(-) diff --git a/burn-import/src/onnx/coalesce.rs b/burn-import/src/onnx/coalesce.rs index 106002b8fe..ddf102ca5e 100644 --- 
a/burn-import/src/onnx/coalesce.rs +++ b/burn-import/src/onnx/coalesce.rs @@ -144,32 +144,23 @@ pub(crate) fn convert_matmul_to_linear( // Check the next node for potential conversion if let Some(peek_node) = iter_mut.peek() { let peek_node = convert_node_proto(peek_node, graph_io).clone(); - println!("next node is {:?}", peek_node); if is_add_node_with_bias(&peek_node, node) { convert_and_remove_add_node(&peek_node, node); // You don't have to remove it if it's never stored in the first place let _ = iter_mut.next(); - println!("\n\nskipping add node\n\n"); } } } /// Helper function to check if the peeked node is an Add node with bias fn is_add_node_with_bias(peek_node: &Node, current_node: &Node) -> bool { - if peek_node.node_type == NodeType::Add && peek_node.inputs.len() == 2 { - println!("\n\ntwo matches"); - println!("peek_node.inputs[0].name: {:?}", peek_node.inputs[0].name); - println!( - "current_node.outputs[0].name: {:?}", - current_node.outputs[0].name - ); - return (peek_node.inputs[0].name == current_node.outputs[0].name + peek_node.node_type == NodeType::Add + && peek_node.inputs.len() == 2 + && ((peek_node.inputs[0].name == current_node.outputs[0].name && peek_node.inputs[1].value.is_some()) || (peek_node.inputs[1].name == current_node.outputs[0].name - && peek_node.inputs[0].value.is_some()); - } - false + && peek_node.inputs[0].value.is_some())) } /// Helper function to convert and remove the Add node diff --git a/burn-import/src/onnx/dim_inference.rs b/burn-import/src/onnx/dim_inference.rs index f864d762a5..4c07002f83 100644 --- a/burn-import/src/onnx/dim_inference.rs +++ b/burn-import/src/onnx/dim_inference.rs @@ -11,8 +11,6 @@ use super::{ /// Infer the dimension of each output tensor and update them. pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) { - //graph_io.copy_to_node_inputs(node); - match node.node_type { NodeType::Add => same_as_input(node), NodeType::AveragePool2d => same_as_input(node), diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs index 6e72be00e9..757e44cb96 100644 --- a/burn-import/src/onnx/from_onnx.rs +++ b/burn-import/src/onnx/from_onnx.rs @@ -41,14 +41,12 @@ pub(crate) struct OnnxGraphIO { pub(crate) outputs: Vec, /// Initializers pub(crate) initializers: HashMap, - //pub(crate) initializers: Vec, ///updated names of outputs of node not stored in the graph node_out: Vec, pub(crate) old_io_names: HashMap, } impl OnnxGraphIO { - pub(crate) fn new( inputs: &Vec, outputs: &Vec, @@ -251,15 +249,15 @@ pub(crate) struct ONNXGraphBuilder { nodes: Vec, inputs: Vec, outputs: Vec, - + /// Counter for node names, used for renaming nodes node_name_counter: HashMap, - //nodes to remove + /// Nodes to remove nodes_to_remove: HashSet, + /// Map from constant node output names to indices of constant nodes constants_map: HashMap, - constants_types: HashSet, - //map from old node name to indices of identity nodes - //identity_idx: HashMap, + /// Map from identity node output names to indices of identity nodes + identity_idx: HashMap, } impl ONNXGraphBuilder { @@ -287,9 +285,7 @@ impl ONNXGraphBuilder { self.check_constants(&mut node, and_idx, &mut graph_io); self.handle_unsqueeze(&mut node, &graph_io); - if node.node_type != NodeType::Identity { - dim_inference(&mut node, &mut graph_io); - } + dim_inference(&mut node, &mut graph_io); rename_io(&mut node, &mut graph_io); @@ -360,7 +356,7 @@ impl ONNXGraphBuilder { } } - ///check if the unsqueeze node has a rhs value (rhs is constant) and if not remap it to a reshape 
+    /// Check if the unsqueeze node has a rhs value (rhs is constant) and if not, remap it to a reshape.
+    /// Needs to be called after node renaming to ensure that the rhs name is correct
+    /// Needs to be called after constant lifting to ensure that the rhs value exists
     fn handle_unsqueeze(&mut self, node: &mut Node, graph_io: &OnnxGraphIO) {
         if node.node_type == NodeType::Unsqueeze && node.inputs[1].value.is_none() {
             if let Some(in_arg) = graph_io.get_node_output(&node.outputs[0].name) {
                 remap_unsqueeze_to_reshape(node, in_arg);
             }
         }
     }

     fn handle_identity(&mut self, node: &mut Node, i: usize) {
         if node.node_type == NodeType::Identity && node.inputs[0].value.is_none() {
             log::debug!("\nfound identity node:\n{:?}\n", &node);
-            //self.identity_idx.insert(node.outputs[0].name.clone(), i);
+            //map the output name to check for pass-through values
+            self.identity_idx.insert(node.outputs[0].name.clone(), i);
             self.nodes_to_remove.insert(i);
-            //apparently the below is no longer necessary
-        } //else {
-          //    node.inputs.iter_mut().for_each(|x| {
-          //        if let Some(identity_idx) = self.identity_idx.get(&x.name) {
-          //            let input_name = &self.nodes[*identity_idx].inputs[0].name;
-
-          //            x.name = input_name.clone();
-          //        }
-          //    });
-          // }
+        } else {
+            //NOTE: it might be possible to rework the API to handle all "per input" operations
+            //in a new function that operates on each input.
+            node.inputs.iter_mut().for_each(|x| {
+                if let Some(identity_idx) = self.identity_idx.get(&x.name) {
+                    let input_name = &self.nodes[*identity_idx].inputs[0].name;
+
+                    x.name = input_name.clone();
+                }
+            });
+        }
     }
 }

From b7fb819101999b876c309d15aadf513e6f07daee Mon Sep 17 00:00:00 2001
From: Joshua Ferguson
Date: Wed, 21 Feb 2024 13:51:23 -0600
Subject: [PATCH 15/21] spaces in doc strings

---
 burn-import/src/onnx/from_onnx.rs | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs
index 757e44cb96..bdf10fdfa4 100644
--- a/burn-import/src/onnx/from_onnx.rs
+++ b/burn-import/src/onnx/from_onnx.rs
@@ -124,7 +124,7 @@ impl OnnxGraphIO {
         }
     }

-    ///Used to initialize the input arguments for nodes. Names need to remain the same because
+    /// Used to initialize the input arguments for nodes. Names need to remain the same because
     /// currently the old names are the key for accessing the Argument
     pub fn init_in(&self, proto_str: String) -> Argument {
         match self.old_io_names.get(&proto_str) {
@@ -173,7 +173,7 @@ impl OnnxGraphIO {
         self.node_out[idx].name = new_name.to_string();
     }

-    ///iterate over the nodes output and copy them to the graph IO
+    /// iterate over the nodes output and copy them to the graph IO
     pub(crate) fn update_tensor_output(&mut self, node: &Node) {
         for node_output in node.outputs.iter() {
@@ -201,8 +201,8 @@ impl OnnxGraphIO {
         }
     }

-    ///used by handle unsqeeze to remap the output of a node to a new name
-    ///expected match if it exists is either a graph input or graph output
+    /// Used by handle_unsqueeze to remap the output of a node to a new name.
+    /// The expected match, if it exists, is either a graph input or a graph output
     pub(crate) fn get_node_output(&self, old_name: &str) -> Option<&Argument> {
         match self.old_io_names.get(old_name) {
             Some(IOEntry::In(i)) => self.inputs.get(*i),
@@ -212,9 +212,9 @@ impl OnnxGraphIO {
         }
     }

-    /// get the updated name of a Node Input, which obviously should be
+    /// Get the updated name of a Node Input, which should be
     /// either a graph input or a node output.
-    /// will return None if the it isn't a graph input or node output(like an initializer)
+    /// Will return None if it isn't a graph input or node output (like an initializer)
     /// Will panic if it's a graph output
     fn get_new_name(&mut self, old_name: &str) -> Option<String> {
         match self.old_io_names.get(old_name) {
             Some(IOEntry::In(i)) => {
@@ -431,7 +431,7 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {

     // ONNX nodes must be topologically sorted per spec:
     // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
-    assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted");
+    debug_assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted");

     log::info!("Finished parsing ONNX file: {}", onnx_path.display());

From 6fb94fad11f801d6b6d232b5f0e82d36516d21db Mon Sep 17 00:00:00 2001
From: Joshua Ferguson
Date: Wed, 21 Feb 2024 13:59:28 -0600
Subject: [PATCH 16/21] altered top sort to work on node proto, moved prior to
 node gen

---
 burn-import/src/onnx/from_onnx.rs | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs
index bdf10fdfa4..63c4721f3f 100644
--- a/burn-import/src/onnx/from_onnx.rs
+++ b/burn-import/src/onnx/from_onnx.rs
@@ -4,12 +4,12 @@ use std::{
     path::Path,
 };

-use crate::onnx::{node_remap::remap_node_type, proto_conversion::convert_node_proto};
+use crate::onnx::{self, node_remap::remap_node_type, proto_conversion::convert_node_proto};

 use super::{
     coalesce::coalesce,
     ir::{Data, OnnxGraph, TensorType},
-    protos::{ModelProto, TensorProto, ValueInfoProto},
+    protos::{ModelProto, NodeProto, TensorProto, ValueInfoProto},
 };

 use super::dim_inference::dim_inference;
@@ -410,6 +410,12 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
     let onnx_model: ModelProto =
         Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file");

+    // ONNX nodes must be topologically sorted per spec:
+    // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
+    debug_assert!(
+        onnx_model.graph.node.is_top_sorted(),
+        "Nodes are not topologically sorted"
+    );
     log::debug!("Number of nodes: {:?}", onnx_model.graph.node.len());
     log::debug!("Number of inputs: {:?}", onnx_model.graph.input.len());

     let ONNXGraphBuilder {
         nodes,
         inputs: inner_inputs,
         outputs: inner_outputs,
         ..
     } = builder;

-    // ONNX nodes must be topologically sorted per spec:
-    // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
-    debug_assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted");
-
     log::info!("Finished parsing ONNX file: {}", onnx_path.display());

     OnnxGraph {
@@ -542,7 +544,7 @@ trait TopologicalSortable {
     fn is_top_sorted(&self) -> bool;
 }

-impl TopologicalSortable for Vec<Node> {
+impl TopologicalSortable for Vec<NodeProto> {
     fn is_top_sorted(&self) -> bool {
         // Create a hashmap to store the position of each node in the vector
         let position: HashMap<String, usize> = self
@@ -554,11 +556,11 @@ impl TopologicalSortable for Vec<NodeProto> {
         // Iterate over each node in the vector
         for node in self {
             // Iterate over each output of the node
-            for output in &node.outputs {
+            for output in &node.output {
                 // Iterate over each other node in the vector
                 for other_node in self {
                     // If the other node has an input that matches the current output
-                    if other_node.inputs.contains(output) {
+                    if other_node.input.contains(output) {
                         // If the position of the current node is greater than the position of the other node
                         if position[&node.name] > position[&other_node.name] {
                             // The vector is not topologically sorted
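The `is_top_sorted` check that patch 16 moves ahead of node generation is a quadratic scan: record each node's position in the list, then verify that no producer appears after one of its consumers. A minimal standalone sketch of the same idea, using a hypothetical `ProtoNode` stand-in (illustrative only, not the crate's `NodeProto`) and assuming node names are unique:

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for an ONNX graph node; the real check runs on `NodeProto`.
struct ProtoNode {
    name: String,
    input: Vec<String>,
    output: Vec<String>,
}

/// Returns true if every producer appears before all of its consumers.
fn is_top_sorted(nodes: &[ProtoNode]) -> bool {
    // Position of each node in the list, keyed by its (assumed unique) name.
    let position: HashMap<&str, usize> = nodes
        .iter()
        .enumerate()
        .map(|(idx, node)| (node.name.as_str(), idx))
        .collect();

    for node in nodes {
        for output in &node.output {
            for consumer in nodes {
                // A node that consumes this output must not come before the producer.
                let out_of_order = consumer.input.contains(output)
                    && position[node.name.as_str()] > position[consumer.name.as_str()];
                if out_of_order {
                    return false;
                }
            }
        }
    }
    true
}

fn main() {
    let nodes = vec![
        ProtoNode { name: "a".into(), input: vec!["x".into()], output: vec!["y".into()] },
        ProtoNode { name: "b".into(), input: vec!["y".into()], output: vec!["z".into()] },
    ];
    assert!(is_top_sorted(&nodes));
}
```

The scan is O(n² · m) in the node and edge counts, which is acceptable here because a failure is a hard error: the ONNX spec already requires graphs to be serialized in topological order, so the check only guards against malformed files.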
From ab7c9c225eabcbce6d96a66f35c6385ac1b045d1 Mon Sep 17 00:00:00 2001
From: Joshua Ferguson
Date: Wed, 21 Feb 2024 15:25:36 -0600
Subject: [PATCH 17/21] Update ir.rs

---
 burn-import/src/onnx/ir.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/burn-import/src/onnx/ir.rs b/burn-import/src/onnx/ir.rs
index 93e1bea438..eb229ea7ce 100644
--- a/burn-import/src/onnx/ir.rs
+++ b/burn-import/src/onnx/ir.rs
@@ -26,7 +26,7 @@ pub struct Argument {
 }

 impl Argument {
-    ///Copy everything except the name from the other argument
+    /// Copy the type and value from the other argument
     pub fn copy_value(&mut self, other_arg: &Argument) {
         self.ty = other_arg.ty.clone();
         self.value = other_arg.value.clone();

From f740b9e12d543809f6accad232e9f70ab103fbc7 Mon Sep 17 00:00:00 2001
From: Joshua Ferguson
Date: Wed, 21 Feb 2024 15:26:28 -0600
Subject: [PATCH 18/21] Update from_onnx.rs removed dead code

---
 burn-import/src/onnx/from_onnx.rs | 1 -
 1 file changed, 1 deletion(-)

diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs
index 63c4721f3f..75e01f9113 100644
--- a/burn-import/src/onnx/from_onnx.rs
+++ b/burn-import/src/onnx/from_onnx.rs
@@ -492,7 +484,6 @@ fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) {
 fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) {
     log::debug!("checking inputs for node {:?}", &node.name);
     for node_input in node.inputs.iter_mut() {
-        //graph_io.add_input(&node_input.name, i);
         if let Some(input_name) = graph_io.get_new_name(&node_input.name) {
             node_input.passed = true;
             node_input.name = input_name.clone();

From 0f33517f163b445ca8e3ec6b9f22542ca89c89ac Mon Sep 17 00:00:00 2001
From: Joshua Ferguson
Date: Wed, 21 Feb 2024 15:31:14 -0600
Subject: [PATCH 19/21] updated doc string

---
 burn-import/src/onnx/from_onnx.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs
index 75e01f9113..13f14e9529 100644
--- a/burn-import/src/onnx/from_onnx.rs
+++ b/burn-import/src/onnx/from_onnx.rs
@@ -173,7 +173,7 @@ impl OnnxGraphIO {
         self.node_out[idx].name = new_name.to_string();
     }

-    /// iterate over the nodes output and copy them to the graph IO
+    /// Copies node outputs to graph IO. Used at the end of dim inference.
     pub(crate) fn update_tensor_output(&mut self, node: &Node) {
         for node_output in node.outputs.iter() {
             match self.old_io_names.get(&node_output.name) {
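Patches 13 through 19 keep refining one piece of bookkeeping: a single `old_io_names` table that maps every original ONNX name to wherever its argument now lives, whether that is a graph input, a graph output, or an intermediate node output. A reduced sketch of that lookup scheme, with hypothetical `Value` and `GraphIo` types standing in for the crate's `Argument` and `OnnxGraphIO`:

```rust
use std::collections::HashMap;

/// Where a named value lives; mirrors the shape of the series' `IOEntry`.
enum Entry {
    In(usize),   // index into `inputs`
    Out(usize),  // index into `outputs`
    Node(usize), // index into `node_out`
}

/// Hypothetical reduced value type; the real code stores full `Argument`s.
struct Value {
    name: String,
    passed: bool,
}

struct GraphIo {
    inputs: Vec<Value>,
    outputs: Vec<Value>,
    node_out: Vec<Value>,
    old_io_names: HashMap<String, Entry>,
}

impl GraphIo {
    /// Resolve an old ONNX name to its current value, wherever it lives.
    fn get(&self, old_name: &str) -> Option<&Value> {
        match self.old_io_names.get(old_name)? {
            Entry::In(i) => self.inputs.get(*i),
            Entry::Out(i) => self.outputs.get(*i),
            Entry::Node(i) => self.node_out.get(*i),
        }
    }

    /// Record a renamed node output under its old name, like the series' `insert`.
    fn insert(&mut self, old_name: &str, value: Value) {
        let idx = self.node_out.len();
        self.old_io_names
            .insert(old_name.to_string(), Entry::Node(idx));
        self.node_out.push(value);
    }
}
```

Storing indices instead of references is the design choice that makes this workable in Rust: the table never holds borrows into the three vectors, so any one of them can be mutated while the map stays alive.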
From be57a67afa3834ea23528c479d026ec3a8b24dae Mon Sep 17 00:00:00 2001
From: Joshua Ferguson
Date: Wed, 21 Feb 2024 15:36:03 -0600
Subject: [PATCH 20/21] camelCased Onnx Graph Builder

---
 burn-import/src/onnx/from_onnx.rs | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/burn-import/src/onnx/from_onnx.rs b/burn-import/src/onnx/from_onnx.rs
index 13f14e9529..6743c26303 100644
--- a/burn-import/src/onnx/from_onnx.rs
+++ b/burn-import/src/onnx/from_onnx.rs
@@ -245,7 +245,7 @@ impl OnnxGraphIO {
 }

 #[derive(Default)]
-pub(crate) struct ONNXGraphBuilder {
+pub(crate) struct OnnxGraphBuilder {
     nodes: Vec<Node>,
     inputs: Vec<Argument>,
     outputs: Vec<Argument>,
@@ -260,7 +260,7 @@ pub(crate) struct OnnxGraphBuilder {
     identity_idx: HashMap<String, usize>,
 }

-impl ONNXGraphBuilder {
+impl OnnxGraphBuilder {
     pub(crate) fn node_gen(&mut self, model_proto: &ModelProto) {
         self.constants_types = LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect();

@@ -425,10 +425,10 @@ pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
     );
     log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len());

-    let mut builder = ONNXGraphBuilder::default();
+    let mut builder = OnnxGraphBuilder::default();
     builder.node_gen(&onnx_model);

-    let ONNXGraphBuilder {
+    let OnnxGraphBuilder {
         nodes,
         inputs: inner_inputs,
         outputs: inner_outputs,

From 3e921d6e7165302a3b7c51080400eb1324e0b780 Mon Sep 17 00:00:00 2001
From: Joshua Ferguson
Date: Wed, 21 Feb 2024 16:11:49 -0600
Subject: [PATCH 21/21] removed self import

---
 crates/burn-import/src/onnx/from_onnx.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/crates/burn-import/src/onnx/from_onnx.rs b/crates/burn-import/src/onnx/from_onnx.rs
index 6743c26303..5c7cb841e1 100644
--- a/crates/burn-import/src/onnx/from_onnx.rs
+++ b/crates/burn-import/src/onnx/from_onnx.rs
@@ -4,7 +4,7 @@ use std::{
     path::Path,
 };

-use crate::onnx::{self, node_remap::remap_node_type, proto_conversion::convert_node_proto};
+use crate::onnx::{node_remap::remap_node_type, proto_conversion::convert_node_proto};

 use super::{
     coalesce::coalesce,
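The series ends mid-refactor, but the constant-lifting pass it has been building is easy to state: remember each Constant (or value-carrying Identity) node by the name of its output, and when a later node of a liftable type (Conv, Reshape, Unsqueeze, ...) references that name, move the value onto the node's own input and mark the constant for removal. A condensed sketch of that flow under those assumptions, over simplified stand-in types rather than the crate's `Node` and `Argument`:

```rust
use std::collections::{HashMap, HashSet};

/// Hypothetical reduced input type; the real pass works on `Argument`s.
struct Input {
    name: String,
    value: Option<Vec<i64>>,
}

/// Hypothetical reduced node type; `liftable` stands in for the
/// LIFT_CONSTANTS_FOR_NODE_TYPES membership test.
struct LiteNode {
    is_constant: bool,
    liftable: bool,
    inputs: Vec<Input>,
    outputs: Vec<String>,
}

/// Lift referenced constant values onto each liftable node's non-primary
/// inputs, then report which constant nodes can be deleted.
fn lift_constants(nodes: &mut [LiteNode]) -> HashSet<usize> {
    let mut constants_map: HashMap<String, usize> = HashMap::new();
    let mut to_remove = HashSet::new();

    for i in 0..nodes.len() {
        if nodes[i].is_constant {
            // Remember the constant by the name of its (single) output.
            constants_map.insert(nodes[i].outputs[0].clone(), i);
        } else if nodes[i].liftable {
            // Skip input 0: that is the actual data input, not a weight/shape.
            for j in 1..nodes[i].inputs.len() {
                let key = nodes[i].inputs[j].name.clone();
                if let Some(&c) = constants_map.get(&key) {
                    // Clone the constant's value onto the consuming input.
                    let value = nodes[c].inputs[0].value.clone();
                    nodes[i].inputs[j].value = value;
                    to_remove.insert(c);
                }
            }
        }
    }
    to_remove
}
```

Because the node list is already topologically sorted, a single forward pass suffices: every constant has been recorded before any node that consumes it is visited, which is exactly why `check_constants` can run inline inside `node_gen` instead of as a separate fixup phase.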