diff --git a/crates/burn-import/src/burn/ty.rs b/crates/burn-import/src/burn/ty.rs
index f5fa41e774..963ee81abf 100644
--- a/crates/burn-import/src/burn/ty.rs
+++ b/crates/burn-import/src/burn/ty.rs
@@ -72,6 +72,9 @@ impl Type {
 impl ScalarType {
     pub fn new<S: AsRef<str>>(name: S, kind: ScalarKind) -> Self {
+        if name.as_ref().is_empty() {
+            panic!("Scalar of Type {:?} was passed with empty name", kind);
+        }
         Self {
             name: Ident::new(name.as_ref(), Span::call_site()),
             kind,
@@ -95,6 +98,12 @@ impl TensorType {
         kind: TensorKind,
         shape: Option<Vec<usize>>,
     ) -> Self {
+        if name.as_ref().is_empty() {
+            panic!(
+                "Tensor of Kind {:?} with dim shape {:?} was passed with empty name",
+                kind, shape
+            );
+        }
         Self {
             name: Ident::new(name.as_ref(), Span::call_site()),
             dim,
@@ -141,6 +150,12 @@ impl TensorType {
 impl OtherType {
     pub fn new<S: AsRef<str>>(name: S, tokens: TokenStream) -> Self {
+        if name.as_ref().is_empty() {
+            panic!(
+                "Other type with tokens {:?} was passed with empty name",
+                tokens
+            );
+        }
         Self {
             name: Ident::new(name.as_ref(), Span::call_site()),
             ty: tokens,
diff --git a/crates/burn-import/src/onnx/coalesce.rs b/crates/burn-import/src/onnx/coalesce.rs
index 623d7584f2..ddf102ca5e 100644
--- a/crates/burn-import/src/onnx/coalesce.rs
+++ b/crates/burn-import/src/onnx/coalesce.rs
@@ -1,32 +1,32 @@
-use std::{iter::Peekable, slice::IterMut};
-
-use super::ir::{AttributeValue, Node, NodeType};
+use std::{iter::Peekable, slice::Iter};
+
+use super::{
+    from_onnx::OnnxGraphIO,
+    ir::{AttributeValue, Node, NodeType},
+    proto_conversion::convert_node_proto,
+    protos::NodeProto,
+};
 use crate::onnx::ir::{ArgType, Data, TensorType};
 
 /// The function transforms the graph into a new one where the nodes are coalesced into a single node.
-pub fn coalesce(nodes: &mut Vec<Node>) {
-    let mut iter_mut = nodes.iter_mut().peekable();
-    let mut nodes_to_remove: Vec<String> = vec![];
-    while let Some(node) = iter_mut.next() {
-        match node.node_type {
-            NodeType::Gemm => convert_gemm_to_linear(node),
-            NodeType::MatMul => {
-                convert_matmul_to_linear(node, &mut iter_mut, &mut nodes_to_remove);
-            }
-            _ => {}
+pub fn coalesce(
+    node: &mut Node,
+    nodes_iter: &mut Peekable<Iter<NodeProto>>,
+    graph_io: &OnnxGraphIO,
+) {
+    match node.node_type {
+        NodeType::Gemm => convert_gemm_to_linear(node),
+        NodeType::MatMul => {
+            convert_matmul_to_linear(node, nodes_iter, graph_io);
         }
-    }
-
-    // Remove nodes instructed by conversation functions
-    for node_to_remove in nodes_to_remove {
-        nodes.retain(|n| n.name != node_to_remove);
+        _ => {}
     }
 }
 
 /// This function converts a Gemm node into a Linear node
 ///
 /// PyTorch and other frameworks use Gemm node to represent Linear layer.
-fn convert_gemm_to_linear(node: &mut Node) {
+pub(crate) fn convert_gemm_to_linear(node: &mut Node) {
     if node.outputs.len() != 1 {
         panic!("Gemm node must have 1 output");
     }
@@ -117,10 +117,10 @@ fn transpose_flattened<T: Copy>(matrix: Vec<T>, rows: usize, cols: usize) -> Vec<T> {
 ///
 /// This function also converts the following Add node into a Linear node if possible.
 /// Add node is used to represent bias in PyTorch.
-fn convert_matmul_to_linear(
+pub(crate) fn convert_matmul_to_linear(
     node: &mut Node,
-    iter_mut: &mut Peekable<IterMut<Node>>,
-    nodes_to_remove: &mut Vec<String>,
+    iter_mut: &mut Peekable<Iter<NodeProto>>,
+    graph_io: &OnnxGraphIO,
 ) {
     if node.inputs.len() != 2 {
         panic!("MatMul node must have 2 inputs");
@@ -143,8 +143,12 @@ fn convert_matmul_to_linear(
 
     // Check the next node for potential conversion
     if let Some(peek_node) = iter_mut.peek() {
-        if is_add_node_with_bias(peek_node, node) {
-            convert_and_remove_add_node(iter_mut, nodes_to_remove, node);
+        let peek_node = convert_node_proto(peek_node, graph_io).clone();
+        if is_add_node_with_bias(&peek_node, node) {
+            convert_and_remove_add_node(&peek_node, node);
+
+            // You don't have to remove it if it's never stored in the first place
+            let _ = iter_mut.next();
         }
     }
 }
@@ -160,13 +164,7 @@ fn is_add_node_with_bias(peek_node: &Node, current_node: &Node) -> bool {
 }
 
 /// Helper function to convert and remove the Add node
-fn convert_and_remove_add_node(
-    iter_mut: &mut Peekable<IterMut<Node>>,
-    nodes_to_remove: &mut Vec<String>,
-    current_node: &mut Node,
-) {
-    let bias_node = iter_mut.next().unwrap();
-
+fn convert_and_remove_add_node(bias_node: &Node, current_node: &mut Node) {
     let bias_input = if bias_node.inputs[0].value.is_some() {
         bias_node.inputs[0].clone()
     } else {
@@ -176,7 +174,4 @@ fn convert_and_remove_add_node(
     // Push the bias input and update the output name
     current_node.inputs.push(bias_input);
     current_node.outputs[0].name = bias_node.outputs[0].name.clone();
-
-    // Remove the Add node
-    nodes_to_remove.push(bias_node.name.clone());
 }
diff --git a/crates/burn-import/src/onnx/dim_inference.rs b/crates/burn-import/src/onnx/dim_inference.rs
index f7d7b72663..4c07002f83 100644
--- a/crates/burn-import/src/onnx/dim_inference.rs
+++ b/crates/burn-import/src/onnx/dim_inference.rs
@@ -1,115 +1,61 @@
 use core::panic;
-use std::collections::HashMap;
 
 use protobuf::Enum;
 
 use super::{
-    ir::{ArgType, Argument, AttributeValue, Data, ElementType, Node, NodeType, TensorType},
+    from_onnx::OnnxGraphIO,
+    ir::{ArgType, AttributeValue, Data, ElementType, Node, NodeType, TensorType},
     op_configuration::flatten_config,
     protos::tensor_proto::DataType,
 };
 
-struct TensorDimUpdater {
-    arguments: HashMap<String, Argument>,
-}
-
-impl TensorDimUpdater {
-    fn new(inputs: &[Argument]) -> Self {
-        let mut arguments: HashMap<String, Argument> = HashMap::with_capacity(inputs.len());
-
-        inputs.iter().for_each(|input| {
-            arguments.insert(input.name.clone(), input.clone());
-        });
-
-        Self { arguments }
-    }
-    /// Update tensor inputs from the registered arguments and returns the number of input
-    /// updated.
-    fn update_tensor_inputs(&self, node: &mut Node) -> usize {
-        self.update_arguments(&mut node.inputs)
-    }
-
-    /// Update the arguments struct from the node output tensors and return the number of output
-    /// updated.
-    fn update_tensor_outputs(&mut self, node: &Node) -> usize {
-        node.outputs
-            .iter()
-            .map(|arg| {
-                self.arguments.insert(arg.name.clone(), arg.clone());
-            })
-            .count()
-    }
-
-    fn update_arguments(&self, arguments: &mut [Argument]) -> usize {
-        arguments
-            .iter_mut()
-            .filter_map(|input| self.arguments.get(&input.name).map(|arg| (arg, input)))
-            .map(|(arg, input)| {
-                input.ty = arg.ty.clone();
-            })
-            .count()
-    }
-}
-
 /// Infer the dimension of each output tensor and update them.
-pub fn dim_inference(
-    nodes: &mut Vec<Node>,
-    graph_inputs: &Vec<Argument>,
-    graph_outputs: &mut Vec<Argument>,
-) {
-    let mut updater = TensorDimUpdater::new(graph_inputs);
-
-    for node in nodes.iter_mut() {
-        updater.update_tensor_inputs(node);
-
-        match node.node_type {
-            NodeType::Add => same_as_input(node),
-            NodeType::AveragePool2d => same_as_input(node),
-            NodeType::BatchNormalization => same_as_input(node),
-            NodeType::Cast => cast_update_outputs(node),
-            NodeType::Clip => same_as_input(node),
-            NodeType::Concat => concat_update_outputs(node),
-            NodeType::Constant => constant_update_outputs(node),
-            NodeType::Conv1d => conv1d_update_outputs(node),
-            NodeType::Conv2d => conv2d_update_outputs(node),
-            NodeType::Cos => same_as_input(node),
-            NodeType::Div => same_as_input(node),
-            NodeType::Dropout => same_as_input(node),
-            NodeType::Equal => equal_update_outputs(node),
-            NodeType::Erf => same_as_input(node),
-            NodeType::Exp => same_as_input(node),
-            NodeType::Flatten => flatten_update_outputs(node),
-            NodeType::Gelu => same_as_input(node),
-            NodeType::GatherElements => same_as_input(node),
-            NodeType::GlobalAveragePool => same_as_input(node),
-            NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
-            NodeType::Linear => linear_update_outputs(node),
-            NodeType::Log => same_as_input(node),
-            NodeType::LogSoftmax => same_as_input(node),
-            NodeType::MaxPool2d => same_as_input(node),
-            NodeType::Mul => same_as_input(node),
-            NodeType::Neg => same_as_input(node),
-            NodeType::Reciprocal => same_as_input(node),
-            NodeType::ReduceMean => mean_update_outputs(node),
-            NodeType::Relu => same_as_input(node),
-            NodeType::Reshape => reshape_update_outputs(node),
-            NodeType::Shape => shape_update_outputs(node),
-            NodeType::Sigmoid => same_as_input(node),
-            NodeType::Softmax => same_as_input(node),
-            NodeType::Sqrt => same_as_input(node),
-            NodeType::Sub => same_as_input(node),
-            NodeType::Tanh => same_as_input(node),
-            NodeType::Transpose => same_as_input(node),
-            NodeType::Unsqueeze => unsqueeze_update_output_or_node(node),
-            NodeType::Pow => same_as_input(node),
-            // Intentionally letting outputs leave unchanged but issue a warning so IR file can be generated.
-            _ => temporary_pass_through_stub(node),
-        }
-
-        updater.update_tensor_outputs(node);
+pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
+    match node.node_type {
+        NodeType::Add => same_as_input(node),
+        NodeType::AveragePool2d => same_as_input(node),
+        NodeType::BatchNormalization => same_as_input(node),
+        NodeType::Cast => cast_update_outputs(node),
+        NodeType::Clip => same_as_input(node),
+        NodeType::Concat => concat_update_outputs(node),
+        NodeType::Constant => constant_update_outputs(node),
+        NodeType::Conv1d => conv1d_update_outputs(node),
+        NodeType::Conv2d => conv2d_update_outputs(node),
+        NodeType::Cos => same_as_input(node),
+        NodeType::Div => same_as_input(node),
+        NodeType::Dropout => same_as_input(node),
+        NodeType::Equal => equal_update_outputs(node),
+        NodeType::Erf => same_as_input(node),
+        NodeType::Exp => same_as_input(node),
+        NodeType::Flatten => flatten_update_outputs(node),
+        NodeType::Gelu => same_as_input(node),
+        NodeType::GatherElements => same_as_input(node),
+        NodeType::GlobalAveragePool => same_as_input(node),
+        NodeType::ConvTranspose2d => conv_transpose2d_update_outputs(node),
+        NodeType::Linear => linear_update_outputs(node),
+        NodeType::Log => same_as_input(node),
+        NodeType::LogSoftmax => same_as_input(node),
+        NodeType::MaxPool2d => same_as_input(node),
+        NodeType::Mul => same_as_input(node),
+        NodeType::Neg => same_as_input(node),
+        NodeType::Reciprocal => same_as_input(node),
+        NodeType::ReduceMean => mean_update_outputs(node),
+        NodeType::Relu => same_as_input(node),
+        NodeType::Reshape => reshape_update_outputs(node),
+        NodeType::Shape => shape_update_outputs(node),
+        NodeType::Sigmoid => same_as_input(node),
+        NodeType::Softmax => same_as_input(node),
+        NodeType::Sqrt => same_as_input(node),
+        NodeType::Sub => same_as_input(node),
+        NodeType::Tanh => same_as_input(node),
+        NodeType::Transpose => same_as_input(node),
+        NodeType::Unsqueeze => unsqueeze_update_output(node),
+        NodeType::Pow => same_as_input(node),
+        // Intentionally leave outputs unchanged, but issue a warning so the IR file can still be generated.
+        _ => temporary_pass_through_stub(node),
     }
-
-    updater.update_arguments(graph_outputs);
+    graph_io.update_tensor_output(node);
 }
 
 fn constant_update_outputs(node: &mut Node) {
@@ -287,9 +233,9 @@ fn mean_update_outputs(node: &mut Node) {
 }
 
 //fn __unsqueeze_shape
-/// Either it Infers the shape of the output of an Unsqueeze node, or it remaps the node to a reshape if the output is static
-/// providing an arg and inferring the dimensions at runtime isn't currently supported.
-fn unsqueeze_update_output_or_node(node: &mut Node) {
+/// Infers the shape of the output from the input and axes
+/// Right now, this should only be called if the rhs is a constant
+fn unsqueeze_update_output(node: &mut Node) {
     if node.inputs.len() != 2 {
         panic!("Unsqueeze: wrong number of inputs");
     }
@@ -308,14 +254,13 @@ fn unsqueeze_update_output_or_node(node: &mut Node) {
         _ => panic!("Unsqueeze: invalid input types"),
     };
     //need output way up here to avoid borrowing issues
-    let (mut tensor, output_shape) = match &node.outputs[0].ty {
-        ArgType::Tensor(tensor) => (tensor.clone(), tensor.shape.clone()),
+    let mut tensor = match &node.outputs[0].ty {
+        ArgType::Tensor(tensor) => tensor.clone(),
         _ => panic!("Unsqueeze: invalid output types"),
     };
-    let mut remap_node = false;
-    match (&axes, tensor.shape) {
+    match &axes {
         //case 1: axes is constant -> output shape is input shape with 1s inserted at the axes
-        (Some(dim_indices), _) => {
+        Some(dim_indices) => {
             let output_rank = (dim_indices.len() + input.dim) as i64;
             let mut dim_indices = dim_indices
                 .to_vec()
@@ -368,42 +313,12 @@ fn unsqueeze_update_output_or_node(node: &mut Node) {
             tensor.shape = Some(new_dims);
             node.outputs[0].ty = ArgType::Tensor(tensor.clone());
         }
-        //case 2: output shape isn't dynamic -> map the node to a reshape
-        (None, Some(_)) => {
-            remap_node = true;
-        }
+
         //case 3: output shape is dynamic -> black magic or unsupported
-        (None, None) => {
+        None => {
             panic!("Unsqueeze: dynamic output shape is not currently supported");
         }
     }
-    //need to move out of the match to avoid borrowing issues
-    if remap_node {
-        let mut new_node = node.clone();
-        let rhs_name = node.inputs[1].name.clone();
-        new_node.node_type = NodeType::Reshape;
-        let rhs_arg = Argument {
-            name: rhs_name, //need name to remain the same
-            ty: ArgType::Tensor(TensorType {
-                elem_type: ElementType::Int64,
-                dim: 1,
-                shape: Some(vec![output_shape.clone().unwrap().len()]),
-            }),
-            value: Some(Data::Int64s(
-                output_shape
-                    .unwrap()
-                    .into_iter()
-                    .map(|ax_len| ax_len as i64)
-                    .collect::<Vec<i64>>(),
-            )),
-
-            passed: false,
-        };
-        new_node.inputs = vec![node.inputs[0].clone(), rhs_arg];
-        new_node.outputs = vec![node.outputs[0].clone()];
-        reshape_update_outputs(&mut new_node);
-        *node = new_node;
-    }
 }
 
 fn same_as_input(node: &mut Node) {
diff --git a/crates/burn-import/src/onnx/from_onnx.rs b/crates/burn-import/src/onnx/from_onnx.rs
index 9c3e9564a2..5c7cb841e1 100644
--- a/crates/burn-import/src/onnx/from_onnx.rs
+++ b/crates/burn-import/src/onnx/from_onnx.rs
@@ -4,14 +4,16 @@ use std::{
     path::Path,
 };
 
-use crate::onnx::{
-    coalesce::coalesce, ir::TensorType, node_remap::remap_node_type,
-    proto_conversion::convert_node_proto,
+use crate::onnx::{node_remap::remap_node_type, proto_conversion::convert_node_proto};
+
+use super::{
+    coalesce::coalesce,
+    ir::{Data, OnnxGraph, TensorType},
+    protos::{ModelProto, NodeProto, TensorProto, ValueInfoProto},
 };
-use super::ir::{ArgType, Argument, Node, NodeType, OnnxGraph, Tensor};
-use super::protos::{ModelProto, TensorProto};
-use super::{dim_inference::dim_inference, protos::ValueInfoProto};
+use super::dim_inference::dim_inference;
+use super::ir::{ArgType, Argument, Node, NodeType};
 
 use protobuf::Message;
 
@@ -25,339 +27,458 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 7] = [
     NodeType::Unsqueeze,
 ];
 
-/// Open an onnx file and convert it to a Graph (intermediate representation)
-///
-/// # Arguments
-///
-/// * `onnx_path` - Path to the onnx file
-///
-/// # Returns
-///
-/// * `OnnxGraph` - The graph representation of the onnx file
-///
-/// # Panics
-///
-/// * If the file cannot be opened
-/// * If the file cannot be parsed
-/// * If the nodes are not topologically sorted
-pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
-    log::info!("Parsing ONNX file: {}", onnx_path.display());
+#[derive(Debug)]
+pub(crate) enum IOEntry {
+    In(usize),
+    Out(usize),
+    Node(usize),
+}
 
-    // Open the file
-    let mut file = File::open(onnx_path).expect("Unable to open file");
-    let onnx_model: ModelProto =
-        Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file");
+pub(crate) struct OnnxGraphIO {
+    /// The inputs for the Graph
+    pub(crate) inputs: Vec<Argument>,
+    /// The outputs for the Graph
+    pub(crate) outputs: Vec<Argument>,
+    /// Initializers
+    pub(crate) initializers: HashMap<String, Argument>,
+    /// Updated names of outputs for nodes not stored in the graph
+    node_out: Vec<Argument>,
+    pub(crate) old_io_names: HashMap<String, IOEntry>,
+}
 
-    log::debug!("Number of nodes: {:?}", onnx_model.graph.node.len());
-    log::debug!("Number of inputs: {:?}", onnx_model.graph.input.len());
+impl OnnxGraphIO {
+    pub(crate) fn new(
+        inputs: &Vec<ValueInfoProto>,
+        outputs: &Vec<ValueInfoProto>,
+        initializers: &Vec<TensorProto>,
+    ) -> Self {
+        let mut old_io_names = HashMap::new();
+        let mut in_count = 1;
+        let constants = initializers
+            .iter()
+            .map(|x| (x.name.clone(), Argument::from_initializer(x)))
+            .collect::<HashMap<String, Argument>>();
 
-    log::debug!(
-        "Number of initializers: {:?}",
-        onnx_model.graph.initializer.len()
-    );
+        let inputs = inputs
+            .iter()
+            .enumerate()
+            .map(|(i, x)| {
+                let in_name = format!("input{}", in_count);
+                old_io_names.insert(x.name.clone(), IOEntry::In(i));
+                let mut arg = Argument::try_from(x.clone()).unwrap();
+                if let Some(initial_arg) = constants.get(&x.name) {
+                    if arg.value.is_none() {
+                        arg.copy_value(initial_arg);
+                    }
+                }
 
-    log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len());
+                in_count += 1;
+                arg.name = in_name;
+                arg
+            })
+            .collect::<Vec<Argument>>();
 
-    // Convert the nodes
-    let mut nodes: Vec<Node> = vec![];
-    for onnx_node in onnx_model.graph.node.iter() {
-        let mut node = convert_node_proto(onnx_node);
-        if onnx_node.op_type == "Unsqueeze" {
-            move_output_for_unsqueeze(
-                &mut node,
-                onnx_model.graph.input.clone(),
-                onnx_model.graph.output.clone(),
-            );
+        let outputs = outputs
+            .iter()
+            .enumerate()
+            .map(|(i, x)| {
+                old_io_names.insert(x.name.clone(), IOEntry::Out(i));
+                Argument::try_from(x.clone()).unwrap()
+            })
+            .collect::<Vec<Argument>>();
+
+        let constants = initializers
+            .iter()
+            .map(|x| (x.name.clone(), Argument::from_initializer(x)))
+            .collect::<HashMap<String, Argument>>();
+
+        Self {
+            inputs,
+            outputs,
+            initializers: constants,
+            node_out: Vec::new(),
+            old_io_names,
         }
-        remap_node_type(&mut node);
-        nodes.push(node);
     }
 
-    // ONNX nodes must be topologically sorted per spec:
-    // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
-    assert!(nodes.is_top_sorted(), "Nodes are not topologically sorted");
+    fn update_name(&mut self, arg: &Argument, new_name: &str) {
+        match self.old_io_names.get(&arg.name) {
+            Some(IOEntry::In(_)) => {
+                panic!("input names are set from the beginning");
+            }
+            Some(IOEntry::Out(i)) => {
+                let arg = self.outputs.get_mut(*i).unwrap();
+                arg.name = new_name.to_string();
+            }
+            Some(IOEntry::Node(i)) => {
+                let arg = self.node_out.get_mut(*i).unwrap();
+                arg.name = new_name.to_string();
+            }
+            None => {
+                //Constants, Casts wound up here before API changes
+                panic!(
+                    "Tried to update the name of {} to {} but entry doesn't exist in the map",
+                    arg.name, new_name
+                )
+            }
+        }
+    }
+
+    /// Used to initialize the input arguments for nodes. Names need to remain the same because
+    /// currently the old names are the key for accessing the Argument
+    pub fn init_in(&self, proto_str: String) -> Argument {
+        match self.old_io_names.get(&proto_str) {
+            None => {
+                if let Some(init_arg) = self.initializers.get(&proto_str) {
+                    init_arg.clone()
+                } else {
+                    Argument::new(proto_str)
+                }
+            }
 
-    // Move inputs with initializers to states
-    move_inputs_to_state(&mut nodes, &onnx_model.graph.initializer);
+            Some(IOEntry::In(i)) => {
+                let mut arg = self.inputs[*i].clone();
 
-    // Handle Identity nodes (expects inputs to be moved to states)
-    handle_identity(&mut nodes);
+                arg.name = proto_str;
+                arg.passed = true;
+                arg
+            }
+            Some(IOEntry::Node(i)) => {
+                let mut arg = self.node_out[*i].clone();
+                arg.name = proto_str;
+                arg
+            }
+            Some(IOEntry::Out(_)) => {
+                panic!("graph output {} can't be a Node input", &proto_str)
+            }
+        }
+    }
 
-    // Lift constants to initializers (expects inputs to be moved to states)
-    lift_constants(&mut nodes);
+    fn insert(&mut self, arg: &Argument, new_name: &str) {
+        if let Some(idx) = self.old_io_names.get(&arg.name) {
+            if let IOEntry::Node(idx) = idx {
+                if self.node_out[*idx].name == arg.name {
+                    self.node_out[*idx].name = new_name.to_string();
+                    return;
+                }
+            } else {
+                panic!("arg entry with old name {} is a graph IO", &arg.name);
+            }
+        }
 
-    // Coalesce and transform nodes
-    coalesce(&mut nodes);
+        let idx = self.node_out.len();
+        self.old_io_names
+            .insert(arg.name.clone(), IOEntry::Node(idx));
+        self.node_out.push(arg.clone());
+        self.node_out[idx].name = new_name.to_string();
+    }
 
-    // Rename nodes and inputs, save the mapping for later
-    let old_node_names = rename_nodes(&mut nodes);
+    /// Copies node outputs to graph IO. Used at the end of dim inference.
+    pub(crate) fn update_tensor_output(&mut self, node: &Node) {
+        for node_output in node.outputs.iter() {
+            match self.old_io_names.get(&node_output.name) {
+                Some(IOEntry::In(i)) => {
+                    let arg = self.inputs.get_mut(*i).unwrap();
+                    arg.copy_value(node_output);
+                }
+                Some(IOEntry::Out(i)) => {
+                    let arg = self.outputs.get_mut(*i).unwrap();
+                    arg.copy_value(node_output);
+                    //Set the output to passed since it's been altered by a Node
+                    arg.passed = true;
+                }
+                Some(IOEntry::Node(_)) => {
+                    panic!("This output is from another node");
+                }
+                None => {
+                    log::debug!("inserting with name {:?}", &node_output.name);
+                    let idx = self.node_out.len();
+                    self.old_io_names
+                        .insert(node_output.name.clone(), IOEntry::Node(idx));
+                    self.node_out.push(node_output.clone());
+                }
+            }
+        }
+    }
 
-    // This function collects the inputs of an ONNX model and returns them as a vector of Arguments.
-    let mut inputs = onnx_model
-        .graph
-        .input
-        .iter()
-        .map(|x| Argument::try_from(x.clone()).unwrap())
-        .collect();
+    /// Used by handle_unsqueeze to remap the output of a node to a new name.
+    /// The expected match, if it exists, is either a graph input or a graph output.
+    pub(crate) fn get_node_output(&self, old_name: &str) -> Option<&Argument> {
+        match self.old_io_names.get(old_name) {
+            Some(IOEntry::In(i)) => self.inputs.get(*i),
+            Some(IOEntry::Out(i)) => self.outputs.get(*i),
+            Some(IOEntry::Node(_)) => panic!("This is a node output"),
+            None => None,
+        }
+    }
 
-    // Map each output in the model's graph to an Argument and collect them into a vector.
-    let mut outputs = onnx_model
-        .graph
-        .output
-        .iter()
-        .map(|x| Argument::try_from(x.clone()).unwrap())
-        .collect();
+    /// Get the updated name of a Node Input, which should be
+    /// either a graph input or a node output.
+    /// Returns None if it isn't a graph input or a node output (e.g. an initializer).
+    /// Panics if it's a graph output.
+    fn get_new_name(&mut self, old_name: &str) -> Option<String> {
+        match self.old_io_names.get(old_name) {
+            Some(IOEntry::In(i)) => {
+                //FIXME: technically in the spec, initializers are default values
+                //for optional inputs, but implementing that would require reworking
+                //the way the graph is built, and it's not clear burn users are using initializers
+                //in that way
+                // see https://github.com/onnx/onnx/issues/2660
+                if self.initializers.contains_key(old_name) {
+                    None
+                } else {
+                    //set the input as passed since a node is referencing it
+                    self.inputs[*i].passed = true;
+                    Some(self.inputs[*i].name.clone())
+                }
+            }
+            Some(IOEntry::Out(_)) => {
+                panic!(
+                    "you just tried to get an updated name on a graph output: {}",
+                    old_name
+                )
+            }
+            Some(IOEntry::Node(i)) => Some(self.node_out[*i].name.clone()),
+            None => None,
+        }
+    }
+}
 
-    let old_input_names = rename_inputs(&mut nodes, &mut inputs, &mut outputs);
+#[derive(Default)]
+pub(crate) struct OnnxGraphBuilder {
+    nodes: Vec<Node>,
+    inputs: Vec<Argument>,
+    outputs: Vec<Argument>,
+    /// Counter for node names, used for renaming nodes
+    node_name_counter: HashMap<NodeType, usize>,
+    /// Nodes to remove
+    nodes_to_remove: HashSet<usize>,
+    /// Map from constant node output names to indices of constant nodes
+    constants_map: HashMap<String, usize>,
+    constants_types: HashSet<NodeType>,
+    /// Map from identity node output names to indices of identity nodes
+    identity_idx: HashMap<String, usize>,
+}
 
-    // Infer shapes and update the inputs and outputs
-    dim_inference(&mut nodes, &inputs, &mut outputs);
+impl OnnxGraphBuilder {
+    pub(crate) fn node_gen(&mut self, model_proto: &ModelProto) {
+        self.constants_types = LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect();
 
-    // Remove the graph inputs/output that are not used by any node
-    remove_unused_graph_inputs(&mut inputs, &mut outputs, &nodes);
+        let mut graph_io = OnnxGraphIO::new(
+            &model_proto.graph.input,
+            &model_proto.graph.output,
+            &model_proto.graph.initializer,
+        );
 
-    log::info!("Finished parsing ONNX file: {}", onnx_path.display());
+        self.nodes = Vec::with_capacity(model_proto.graph.node.len());
+        let mut and_idx = 0;
+        let mut node_iter = model_proto.graph.node.iter().peekable();
 
-    OnnxGraph {
-        nodes,
-        inputs,
-        outputs,
-        old_node_names,
-        old_input_names,
-    }
-}
+        while let Some(node_proto) = node_iter.next() {
+            let mut node = convert_node_proto(node_proto, &graph_io);
 
-/// This function moves inputs that are also present
-/// in the initializer to the node's states vector.
-/// It also removes inputs that are already present in the states vector.
-///
-/// # Arguments
-///
-/// * `nodes` - A mutable reference to a vector of nodes
-/// * `initializers` - A vector of TensorProto
-fn move_inputs_to_state(nodes: &mut Vec<Node>, initializers: &[TensorProto]) {
-    // Convert initializers to hashmap for faster lookup
-    let initializers = initializers
-        .iter()
-        .map(|x| (x.name.clone(), x.clone()))
-        .collect::<HashMap<String, TensorProto>>();
-
-    // Iterate over each node in the graph
-    nodes.iter_mut().for_each(|node| {
-        for input in node.inputs.iter_mut() {
-            // If there is a corresponding initializer for the input, then move the data to the input value
-            if let Some(initializer) = initializers.get(&input.name) {
-                move_initializer_data(initializer, input);
-            }
-        }
-    });
-}
+            remap_node_type(&mut node);
 
-fn move_initializer_data(initializer: &TensorProto, input: &mut Argument) {
-    // If the input name matches the tensor name in the initializer
-    // Convert the initializer to a tensor
-    let tensor = Tensor::try_from(initializer.clone()).expect("Invalid tensor");
+            coalesce(&mut node, &mut node_iter, &graph_io);
+            self.handle_node_renaming(&mut node);
+            self.handle_identity(&mut node, and_idx);
+            self.check_constants(&mut node, and_idx, &mut graph_io);
+            self.handle_unsqueeze(&mut node, &graph_io);
 
-    if tensor.dim == 0 {
-        // Convert zero dim tensor to scalar
-        if let Some(data) = tensor.data {
-            input.value = Some(data.into_scalar());
-        } else {
-            input.value = None;
+            dim_inference(&mut node, &mut graph_io);
+
+            rename_io(&mut node, &mut graph_io);
+
+            self.nodes.push(node);
+            and_idx += 1;
         }
-        // Update the input type
-        input.ty = ArgType::Scalar(tensor.elem_type);
-    } else {
-        // Move the tensor data to the input value
-        input.value = tensor.data.clone();
-
-        // Update the input type
-        input.ty = ArgType::Tensor(TensorType {
-            dim: tensor.dim,
-            elem_type: tensor.elem_type,
-            shape: tensor.shape,
+        let mut i = 0;
+        self.nodes.retain(|_x| {
+            let res = !self.nodes_to_remove.contains(&i);
+            i += 1;
+            res
         });
+        let OnnxGraphIO {
+            mut inputs,
+            mut outputs,
+            ..
+        } = graph_io;
+
+        // Remove the graph inputs/output that are not used by any node
+        remove_unused_graph_inputs(&mut inputs, &mut outputs);
+        self.inputs = inputs;
+        self.outputs = outputs;
     }
-}
 
-/// Stores the output shape in the unsqueeze node for the situation where
-/// the axes value isn't constant.
-fn move_output_for_unsqueeze(
-    node: &mut Node,
-    outputs: Vec<ValueInfoProto>,
-    inputs: Vec<ValueInfoProto>,
-) {
-    let output_name = node.outputs[0].name.clone();
-    // check outputs first, as it's shorter
-    for output in outputs.iter() {
-        if output.name == output_name {
-            //copy the shape
-            match node.outputs[0].ty {
-                ArgType::Tensor(ref mut tensor_type) => {
-                    if let Some(shape) = output.type_.as_ref().unwrap().tensor_type().shape.as_ref()
-                    {
-                        tensor_type.shape =
-                            Some(shape.dim.iter().map(|x| x.dim_value() as usize).collect());
-                        return;
+    fn handle_node_renaming(&mut self, node: &mut Node) {
+        log::debug!("renaming node {:?}", &node.name);
+        self.node_name_counter
+            .entry(node.node_type.clone())
+            .and_modify(|e| *e += 1)
+            .or_insert(1);
+        let new_name = format!(
+            "{}{}",
+            node.node_type, self.node_name_counter[&node.node_type]
+        )
+        .to_lowercase();
+        node.name = new_name.clone();
+    }
+
+    fn check_constants(&mut self, node: &mut Node, i: usize, _graph_io: &mut OnnxGraphIO) {
+        if node.node_type == NodeType::Constant
+            || (node.node_type == NodeType::Identity && node.inputs[0].value.is_some())
+        {
+            self.constants_map.insert(node.outputs[0].name.clone(), i);
+        } else if self.constants_types.contains(&node.node_type) {
+            log::debug!("checking node {} for constants", &node.name);
+            for input in node.inputs.iter_mut().skip(1) {
+                log::debug!("checking input {:?} for const", input);
+                if let Some(const_idx) = self.constants_map.get(&input.name) {
+                    let constant = &self.nodes[*const_idx];
+                    log::debug!(
+                        "input {} matched constant node {}",
+                        &input.name,
+                        &constant.name
+                    );
+                    if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() {
+                        // The value comes from Identity inputs
+                        input.value = constant.inputs[0].value.clone();
+                        input.ty = constant.inputs[0].ty.clone();
+                    } else {
+                        let arg = convert_constant_value(constant);
+                        input.value = arg.value;
+                        input.ty = arg.ty;
                     }
+                    self.nodes_to_remove.insert(*const_idx);
                 }
-                _ => return,
             }
         }
     }
-    for input in inputs.iter() {
-        if input.name == output_name {
-            //copy the shape
-            match node.outputs[0].ty {
-                ArgType::Tensor(ref mut tensor_type) => {
-                    if let Some(shape) = input.type_.as_ref().unwrap().tensor_type().shape.as_ref()
-                    {
-                        tensor_type.shape =
-                            Some(shape.dim.iter().map(|x| x.dim_value() as usize).collect());
-                        return;
-                    }
-                }
-                _ => return,
+
+    /// Check if the unsqueeze node has a rhs value (i.e. the rhs is constant); if not, remap it to a reshape.
+    /// Needs to be called after node renaming to ensure that the rhs name is correct.
+    /// Needs to be called after constant lifting to ensure that the rhs value exists.
+    fn handle_unsqueeze(&mut self, node: &mut Node, graph_io: &OnnxGraphIO) {
+        if node.node_type == NodeType::Unsqueeze && node.inputs[1].value.is_none() {
+            if let Some(in_arg) = graph_io.get_node_output(&node.outputs[0].name) {
+                remap_unsqueeze_to_reshape(node, in_arg);
             }
         }
     }
+
+    fn handle_identity(&mut self, node: &mut Node, i: usize) {
+        if node.node_type == NodeType::Identity && node.inputs[0].value.is_none() {
+            log::debug!("\nfound identity node:\n{:?}\n", &node);
+            //map the output name to check for pass through values
+            self.identity_idx.insert(node.outputs[0].name.clone(), i);
+            self.nodes_to_remove.insert(i);
+        } else {
+            //NOTE: it might be possible to rework the API to handle all "per input" operations
+            //in a new function that operates on each input.
+            node.inputs.iter_mut().for_each(|x| {
+                if let Some(identity_idx) = self.identity_idx.get(&x.name) {
+                    let input_name = &self.nodes[*identity_idx].inputs[0].name;
+
+                    x.name = input_name.clone();
+                }
+            });
+        }
+    }
+}
 
-/// Lift constants from the graph into the states vector for known node types.
-///
-/// The primary reason to move constants into the states vector is to reduce the number of nodes in the graph,
-/// and consistently utilize the same interface for all nodes (constant inputs and inputs with initializers are
-/// treated the same way). This simplification aids code generation.
-///
-/// For example, if we have a graph ([Const1, Const2, Conv2d1]) where the Conv2d node has 3 inputs
-/// (graph_input, const2_out1, const_out2), we can lift the constants into the states of the Conv2d node.
-/// const2_out1 and const_out2 are used for the weights and bias of the Conv2d node.
-/// After lifting, we will have a graph ([Conv2d1]) where the Conv2d node has 1 input (graph_input) and 2 states.
+/// Open an onnx file and convert it to a Graph (intermediate representation)
 ///
-/// Also note that often times, Conv2d node's inputs are not constants, but they are initializers. Initializers
-/// move to the states vector as well, using the `move_inputs_to_state` function.
+/// # Arguments
 ///
+/// * `onnx_path` - Path to the onnx file
 ///
-/// # Arguments
+/// # Returns
 ///
-/// * `nodes` - A mutable reference to a vector of nodes
+/// * `OnnxGraph` - The graph representation of the onnx file
 ///
 /// # Panics
 ///
-/// Panics if the node's output is not a constant.
-fn lift_constants(nodes: &mut Vec<Node>) {
-    log::info!("Lifting constants into the states");
-
-    // create a set to hold the node types to process
-    let node_types_to_process: HashSet<NodeType> =
-        LIFT_CONSTANTS_FOR_NODE_TYPES.into_iter().collect();
-
-    // create a new vector to hold the graph's constants (index by the node's name)
-    let constants = nodes
-        .iter()
-        .filter(|node| node.node_type == NodeType::Constant || node.node_type == NodeType::Identity)
-        .map(|node| (node.outputs[0].name.clone(), node.clone()))
-        .collect::<HashMap<String, Node>>();
-
-    // create a set to hold the IDs of constants to be removed
-    let mut constant_to_removed = HashSet::<String>::new();
-
-    for node in nodes.iter_mut() {
-        // Skip the node if it is not in the set of node types to process
-        if !node_types_to_process.contains(&node.node_type) {
-            continue;
-        }
-
-        // Skip the first input because it is the node's true input and not a constant/state
-        node.inputs
-            .iter_mut()
-            .skip(1) // TODO make configurable
-            .for_each(|input| {
-                if let Some(constant) = constants.get(&input.name) {
-                    if !constant.inputs.is_empty() && constant.inputs[0].value.is_some() {
-                        // The value comes from Identity inputs
-                        if let Some(constant_input) = constant.inputs.first() {
-                            input.ty = constant_input.ty.clone();
-                            input.value = constant_input.value.clone();
-                        }
-                    } else {
-                        // The value comes from an attribute
-                        let arg = convert_constant_value(constant); // get the value of the constant
-
-                        input.value = arg.value; // set the input's value to the constant's value
-                        input.ty = arg.ty; // set the input's type to the constant's type
-                        // remove the constant from the graph
-                    }
-                    constant_to_removed.insert(constant.name.clone());
-                }
-            });
-    }
+/// * If the file cannot be opened
+/// * If the file cannot be parsed
+/// * If the nodes are not topologically sorted
+pub fn parse_onnx(onnx_path: &Path) -> OnnxGraph {
+    log::info!("Parsing ONNX file: {}", onnx_path.display());
 
-    // remove the constants that were moved to the states vector
-    nodes.retain(|node| !constant_to_removed.contains(&node.name));
+    // Open the file
+    let mut file = File::open(onnx_path).expect("Unable to open file");
+    let onnx_model: ModelProto =
+        Message::parse_from_reader(&mut file).expect("Unable to parse ONNX file");
 
-    log::debug!(
-        "The number of constants lifted: {}",
-        constant_to_removed.len()
-    );
-}
-
-fn handle_identity(nodes: &mut Vec<Node>) {
-    log::info!("Handling identity nodes");
-
-    let mut nodes_to_remove = HashSet::new();
+    // ONNX nodes must be topologically sorted per spec:
+    // https://github.com/onnx/onnx/blob/main/docs/IR.md#graphs
+    debug_assert!(
+        onnx_model.graph.node.is_top_sorted(),
+        "Nodes are not topologically sorted"
+    );
 
-    let identity_nodes = nodes
-        .iter()
-        .filter(|node| node.node_type == NodeType::Identity)
-        .cloned()
-        .collect::<Vec<Node>>();
+    log::debug!("Number of nodes: {:?}", onnx_model.graph.node.len());
+    log::debug!("Number of inputs: {:?}", onnx_model.graph.input.len());
+    log::debug!(
+        "Number of initializers: {:?}",
+        onnx_model.graph.initializer.len()
+    );
+    log::debug!("Number of outputs: {:?}", onnx_model.graph.output.len());
 
-    // Handle pass-through nodes.
-    for identity_node in identity_nodes {
-        if identity_node.node_type == NodeType::Identity && identity_node.inputs[0].value.is_none()
-        {
-            let input_name = &identity_node.inputs[0].name;
-            let output_name = &identity_node.outputs[0].name;
-
-            // Replace the identity node's output with its input in the connected nodes.
-            for node in nodes.iter_mut() {
-                if let Some(matched_input) = node.inputs.iter_mut().find(|x| x.name == *output_name)
-                {
-                    matched_input.name = input_name.clone();
-                }
-            }
+    let mut builder = OnnxGraphBuilder::default();
+    builder.node_gen(&onnx_model);
 
-            nodes_to_remove.insert(identity_node);
-        }
+    let OnnxGraphBuilder {
+        nodes,
+        inputs: inner_inputs,
+        outputs: inner_outputs,
+        ..
+    } = builder;
 
-    // Remove the identity nodes.
-    nodes.retain(|node| !nodes_to_remove.contains(node));
+    log::info!("Finished parsing ONNX file: {}", onnx_path.display());
+
+    OnnxGraph {
+        nodes,
+        inputs: inner_inputs,
+        outputs: inner_outputs,
     }
 }
 
-/// Rename the nodes in the graph to be unique and return a map of the old names to the new names.
-fn rename_nodes(nodes: &mut Vec<Node>) -> HashMap<String, String> {
-    let mut old_names = HashMap::new();
-    let mut counter: HashMap<NodeType, usize> = HashMap::new();
-
-    for node in nodes.iter_mut() {
-        // keep track of the number of nodes of each type
-        counter
-            .entry(node.node_type.clone())
-            .and_modify(|e| *e += 1)
-            .or_insert(1);
-
-        let old_name = node.name.clone();
-        let new_name = format!("{}{}", node.node_type, counter[&node.node_type]).to_lowercase();
-
-        node.name = new_name.clone();
-
-        old_names.insert(old_name, new_name);
+/// Remap the unsqueeze node to a reshape node. Should only be called after
+/// node renaming has been done. Avoids marking the rhs as passed so that it can be
+/// properly deleted if nothing else uses it.
+fn remap_unsqueeze_to_reshape(node: &mut Node, out_arg: &Argument) {
+    match node.outputs[0].ty {
+        ArgType::Tensor(ref mut tensor_type) => {
+            if let ArgType::Tensor(arg_tensor) = &out_arg.ty {
+                tensor_type.shape = arg_tensor.shape.clone();
+                let inner = arg_tensor
+                    .shape
+                    .clone()
+                    .unwrap()
+                    .into_iter()
+                    .map(|x| x as i64)
+                    .collect::<Vec<i64>>();
+                let shape_len = inner.len();
+                let new_rhs_value = Some(Data::Int64s(inner));
+                //moving the remap to here
+                let rhs_arg = Argument {
+                    name: format!("{}_generated_const", node.name),
+                    ty: ArgType::Tensor(TensorType {
+                        elem_type: super::ir::ElementType::Int64,
+                        dim: 1,
+                        shape: Some(vec![shape_len]),
+                    }),
+                    value: new_rhs_value,
+                    passed: false,
+                };
+                node.inputs[1] = rhs_arg;
+                node.outputs[0] = out_arg.clone();
+                node.node_type = NodeType::Reshape;
+            }
+        }
+        _ => {}
     }
-
-    old_names
 }
 
 /// Rename the inputs and output in the graph and return a map of
@@ -366,60 +487,38 @@ fn rename_nodes(nodes: &mut Vec<Node>) -> HashMap<String, String> {
 /// The inputs are renamed to be unique and to be in the format of
 /// conv2_in1, conv2_in2, etc. This is done to be consistent with
 /// the naming convention of the nodes and allow to be used as rust identifiers.
-fn rename_inputs(
-    nodes: &mut Vec<Node>,
-    inputs: &mut Vec<Argument>,
-    outputs: &mut Vec<Argument>,
-) -> HashMap<String, String> {
-    let mut old_names = HashMap::new();
-
-    // rename all graph input names to follow input1, input2, input3, etc.
-    // (assumes the input names are already unique)
-    let mut counter = 1;
-    for input in inputs.iter_mut() {
-        let old_name = input.name.clone();
-        let new_name = format!("input{}", counter);
-        input.name = new_name.clone();
-        old_names.insert(old_name, new_name);
-        counter += 1;
-    }
-
-    for node in nodes.iter_mut() {
-        let mut counter = 1;
-        // loop through node outputs and rename them and store the new name <-> old name mapping
-        for output in node.outputs.iter_mut() {
-            let old_name = output.name.clone();
-            let new_name = format!("{}_out{}", node.name, counter);
-            output.name = new_name.clone();
-            old_names.insert(old_name, new_name);
-            counter += 1;
+/// Rename the inputs and output in the graph and return a map of
+/// the old names to the new names.
+fn rename_io(node: &mut Node, graph_io: &mut OnnxGraphIO) {
+    log::debug!("checking inputs for node {:?}", &node.name);
+    for node_input in node.inputs.iter_mut() {
+        if let Some(input_name) = graph_io.get_new_name(&node_input.name) {
+            node_input.passed = true;
+            node_input.name = input_name.clone();
+        } else {
+            node_input.name = "".to_string();
+            node_input.passed = false;
        }
     }
+    log::debug!("\n\nchecking outputs");
+    let mut out_count = 1;
+    if node.node_type == NodeType::Constant || node.node_type == NodeType::Identity {
+        log::debug!("it's a constant");
+        let new_name = format!("{}_out{}", node.name, out_count);
+        graph_io.insert(&node.outputs[0], &new_name);
+        node.outputs[0].name = new_name;
+    } else {
+        for output in node.outputs.iter_mut() {
+            log::debug!("output name: {}", &output.name);
 
-    for node in nodes.iter_mut() {
-        // loop through node inputs and rename them with previously replaced names
-        // and mark them as passed if they are in the old_names map (i.e. they are node outputs)
-        for input in node.inputs.iter_mut() {
-            if let Some(new_name) = old_names.get(&input.name) {
-                input.name = new_name.clone();
-                input.passed = true;
-            } else {
-                input.name = "".to_string(); // Rename to a placeholder
-                input.passed = false;
-            }
-        }
-    }
+            let new_name = format!("{}_out{}", node.name, out_count);
+
+            graph_io.update_name(output, &new_name);
 
-    // Rename the graph outputs
-    for output in outputs.iter_mut() {
-        if let Some(new_name) = old_names.get(&output.name) {
             output.name = new_name.clone();
-        } else {
-            log::warn!("Output {:?} not found in old_names", output.name);
+            out_count += 1;
         }
     }
-
-    old_names
 }
 
 /// Removes the graph inputs/output that are not used by any node.
@@ -431,34 +530,12 @@ fn rename_inputs(
 ///
 /// Generally, it's a good idea to remove unused inputs/outputs because it makes the
 /// generated code cleaner and easier to read.
-fn remove_unused_graph_inputs(
-    inputs: &mut Vec<Argument>,
-    outputs: &mut Vec<Argument>,
-    nodes: &Vec<Node>,
-) {
+fn remove_unused_graph_inputs(inputs: &mut Vec<Argument>, outputs: &mut Vec<Argument>) {
     // Remove inputs that are not used by any node
-    inputs.retain(|input| {
-        for node in nodes.iter() {
-            if node
-                .inputs
-                .iter()
-                .any(|x| x.name == input.name && x.value.is_none())
-            {
-                return true;
-            }
-        }
-        false
-    });
+    inputs.retain(|input| input.passed);
 
     // Remove outputs that are not used by any node
-    outputs.retain(|output| {
-        for node in nodes.iter() {
-            if node.outputs.iter().any(|x| x.name == output.name) {
-                return true;
-            }
-        }
-        false
-    });
+    outputs.retain(|output| output.passed);
 }
 
 // Define a trait for topological sorting
@@ -466,7 +543,7 @@ trait TopologicalSortable {
     fn is_top_sorted(&self) -> bool;
 }
 
-impl TopologicalSortable for Vec<Node> {
+impl TopologicalSortable for Vec<NodeProto> {
     fn is_top_sorted(&self) -> bool {
         // Create a hashmap to store the position of each node in the vector
        let position: HashMap<String, usize> = self
@@ -478,11 +555,11 @@ impl TopologicalSortable for Vec<NodeProto> {
         // Iterate over each node in the vector
         for node in self {
             // Iterate over each output of the node
-            for output in &node.outputs {
+            for output in &node.output {
                 // Iterate over each other node in the vector
                 for other_node in self {
                     // If the other node has an input that matches the current output
-                    if other_node.inputs.contains(output) {
+                    if other_node.input.contains(output) {
                         // If the position of the current node is greater than the position of the other node
                         if position[&node.name] > position[&other_node.name] {
                             // The vector is not topologically sorted
diff --git a/crates/burn-import/src/onnx/ir.rs b/crates/burn-import/src/onnx/ir.rs
index be27a70dd7..eb229ea7ce 100644
--- a/crates/burn-import/src/onnx/ir.rs
+++ b/crates/burn-import/src/onnx/ir.rs
@@ -3,6 +3,8 @@ use half::f16;
 use std::{collections::HashMap, fmt::Formatter};
 use strum_macros::{Display, EnumString};
 
+use super::protos::TensorProto;
+
 pub type Dim = usize;
 pub type Shape = Vec<Dim>;
 
@@ -23,6 +25,48 @@ pub struct Argument {
     pub passed: bool,
 }
 
+impl Argument {
+    /// Copy everything except the name from the other argument
+    pub fn copy_value(&mut self, other_arg: &Argument) {
+        self.ty = other_arg.ty.clone();
+        self.value = other_arg.value.clone();
+    }
+
+    pub fn from_initializer(initializer: &TensorProto) -> Argument {
+        let name = initializer.name.clone();
+        let tensor = Tensor::try_from(initializer.clone())
+            .unwrap_or_else(|_| panic!("invalid tensor {}", &initializer.name));
+
+        if tensor.dim == 0 {
+            // Convert zero dim tensor to scalar
+            let value = if tensor.data.is_some() {
+                Some(tensor.data.clone().unwrap().into_scalar())
+            } else {
+                None
+            };
+            let ty = ArgType::Scalar(tensor.elem_type);
+
+            Self {
+                name,
+                ty,
+                value,
+                passed: false,
+            }
+        } else {
+            Self {
+                name,
+                ty: ArgType::Tensor(TensorType {
+                    elem_type: tensor.elem_type,
+                    dim: tensor.dim,
+                    shape: tensor.shape,
+                }),
+                value: tensor.data.clone(),
+                passed: false,
+            }
+        }
+    }
+}
+
 /// The type of an argument.
 #[derive(Debug, Clone)]
 pub enum ArgType {
@@ -138,12 +182,6 @@ pub struct OnnxGraph {
 
     /// The outputs of the graph.
     pub outputs: Vec<Argument>,
-
-    /// The original node names.
-    pub old_node_names: HashMap<String, String>,
-
-    /// The original input names.
-    pub old_input_names: HashMap<String, String>,
 }
 
 /// Nodes produced by the ONNX parser
diff --git a/crates/burn-import/src/onnx/proto_conversion.rs b/crates/burn-import/src/onnx/proto_conversion.rs
index 1fe5aac54a..2f1481281f 100644
--- a/crates/burn-import/src/onnx/proto_conversion.rs
+++ b/crates/burn-import/src/onnx/proto_conversion.rs
@@ -2,6 +2,7 @@ use std::str::{from_utf8, FromStr};
 
 use crate::onnx::ir::TensorType;
 
+use super::from_onnx::OnnxGraphIO;
 use super::ir::Dim;
 use super::ir::{
     ArgType, Argument, AttributeValue, Attributes, Data, ElementType, Node, NodeType, Tensor,
@@ -179,12 +180,17 @@ pub fn convert_vec_attrs_proto(attrs: Vec<AttributeProto>) -> Attributes {
     result
 }
 
-pub fn convert_node_proto(node: &NodeProto) -> Node {
+pub fn convert_node_proto(node: &NodeProto, graph_io: &OnnxGraphIO) -> Node {
     let name = node.name.clone();
 
     log::debug!("Converting ONNX node with type {:?}", node.op_type.as_str());
 
-    let inputs = node.input.clone().into_iter().map(Argument::new).collect();
+    let inputs = node
+        .input
+        .clone()
+        .into_iter()
+        .map(|x| graph_io.init_in(x))
+        .collect();
 
     let outputs = node.output.clone().into_iter().map(Argument::new).collect();
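Reviewer note (not part of the diff): a minimal sketch of how the refactored entry point is consumed. The public `parse_onnx` signature is unchanged; only the `OnnxGraph` fields changed (`old_node_names`/`old_input_names` are gone), and all renaming now happens inside the builder. The `burn_import::onnx` re-export path and the `model.onnx` file name are assumptions for illustration.

```rust
use std::path::Path;

use burn_import::onnx::parse_onnx; // assumed re-export of from_onnx::parse_onnx

fn main() {
    // Still panics if the file can't be opened or parsed, or if nodes aren't
    // topologically sorted, per the doc comment on parse_onnx.
    let graph = parse_onnx(Path::new("model.onnx"));

    // Nodes arrive already renamed (e.g. conv2d1, matmul1, ...) and the graph
    // inputs/outputs are filtered to arguments the builder marked as `passed`.
    for node in &graph.nodes {
        println!("node: {}", node.name);
    }
    println!("inputs: {}, outputs: {}", graph.inputs.len(), graph.outputs.len());
}
```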