From 8ab1b5dbf0360c71895dc6ce24875ea156b79999 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 15 Apr 2025 15:58:19 +0100 Subject: [PATCH 01/21] feat!: Allow generic Nodes in HugrMut insert operations (#2075) `insert_hugr`, `insert_from_view`, and `insert_subgraph` were written before we made `Node` a type generic, and incorrectly assumed the return type maps were always `hugr::Node`s. The methods were either unusable or incorrect when using generic `HugrView`s source/targets with non-base node types. This PR fixes that, and additionally allows us us to have `SiblingSubgraph::extract_subgraph` work for generic `HugrViews`. BREAKING CHANGE: Added Node type parameters to extraction operations in `HugrMut`. --- hugr-core/src/builder/build_traits.rs | 2 +- hugr-core/src/hugr/hugrmut.rs | 114 ++++++++++++------- hugr-core/src/hugr/views/sibling_subgraph.rs | 4 +- 3 files changed, 75 insertions(+), 45 deletions(-) diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index f1613895d..e17d172ca 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -119,7 +119,7 @@ pub trait Container { } /// Insert a copy of a HUGR as a child of the container. - fn add_hugr_view(&mut self, child: &impl HugrView) -> InsertionResult { + fn add_hugr_view(&mut self, child: &H) -> InsertionResult { let parent = self.container_node(); self.hugr_mut().insert_from_view(parent, child) } diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index f3ef094be..38eb59222 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -1,13 +1,15 @@ //! Low-level interface for modifying a HUGR. use core::panic; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::Arc; use portgraph::view::{NodeFilter, NodeFiltered}; use portgraph::{LinkMut, PortMut, PortView, SecondaryMap}; +use crate::core::HugrNode; use crate::extension::ExtensionRegistry; +use crate::hugr::internal::HugrInternals; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrView, Node, OpType, RootTagged}; use crate::hugr::{NodeMetadata, Rewrite}; @@ -162,10 +164,10 @@ pub trait HugrMut: HugrMutInternals { /// correspondingly for `Dom` edges) fn copy_descendants( &mut self, - root: Node, - new_parent: Node, + root: Self::Node, + new_parent: Self::Node, subst: Option, - ) -> BTreeMap { + ) -> BTreeMap { panic_invalid_node(self, root); panic_invalid_node(self, new_parent); self.hugr_mut().copy_descendants(root, new_parent, subst) @@ -225,7 +227,7 @@ pub trait HugrMut: HugrMutInternals { /// /// If the root node is not in the graph. #[inline] - fn insert_hugr(&mut self, root: Node, other: Hugr) -> InsertionResult { + fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult { panic_invalid_node(self, root); self.hugr_mut().insert_hugr(root, other) } @@ -236,7 +238,11 @@ pub trait HugrMut: HugrMutInternals { /// /// If the root node is not in the graph. #[inline] - fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult { + fn insert_from_view( + &mut self, + root: Self::Node, + other: &H, + ) -> InsertionResult { panic_invalid_node(self, root); self.hugr_mut().insert_from_view(root, other) } @@ -255,12 +261,12 @@ pub trait HugrMut: HugrMutInternals { // TODO: Try to preserve the order when possible? 
We cannot always ensure // it, since the subgraph may have arbitrary nodes without including their // parent. - fn insert_subgraph( + fn insert_subgraph( &mut self, - root: Node, - other: &impl HugrView, - subgraph: &SiblingSubgraph, - ) -> HashMap { + root: Self::Node, + other: &H, + subgraph: &SiblingSubgraph, + ) -> HashMap { panic_invalid_node(self, root); self.hugr_mut().insert_subgraph(root, other, subgraph) } @@ -307,20 +313,32 @@ pub trait HugrMut: HugrMutInternals { /// Records the result of inserting a Hugr or view /// via [HugrMut::insert_hugr] or [HugrMut::insert_from_view]. -pub struct InsertionResult { +/// +/// Contains a map from the nodes in the source HUGR to the nodes in the +/// target HUGR, using their respective `Node` types. +pub struct InsertionResult { /// The node, after insertion, that was the root of the inserted Hugr. /// /// That is, the value in [InsertionResult::node_map] under the key that was the [HugrView::root] - pub new_root: Node, + pub new_root: TargetN, /// Map from nodes in the Hugr/view that was inserted, to their new /// positions in the Hugr into which said was inserted. - pub node_map: HashMap, + pub node_map: HashMap, } -fn translate_indices( +/// Translate a portgraph node index map into a map from nodes in the source +/// HUGR to nodes in the target HUGR. +/// +/// This is as a helper in `insert_hugr` and `insert_subgraph`, where the source +/// HUGR may be an arbitrary `HugrView` with generic node types. +fn translate_indices( + mut source_node: impl FnMut(portgraph::NodeIndex) -> N, + mut target_node: impl FnMut(portgraph::NodeIndex) -> Node, node_map: HashMap, -) -> impl Iterator { - node_map.into_iter().map(|(k, v)| (k.into(), v.into())) +) -> impl Iterator { + node_map + .into_iter() + .map(move |(k, v)| (source_node(k), target_node(v))) } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. @@ -406,7 +424,11 @@ impl + AsMut> HugrMut for T (src_port, dst_port) } - fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> InsertionResult { + fn insert_hugr( + &mut self, + root: Self::Node, + mut other: Hugr, + ) -> InsertionResult { let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other); // Update the optypes and metadata, taking them from the other graph. // @@ -423,11 +445,16 @@ impl + AsMut> HugrMut for T ); InsertionResult { new_root, - node_map: translate_indices(node_map).collect(), + node_map: translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map) + .collect(), } } - fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult { + fn insert_from_view( + &mut self, + root: Self::Node, + other: &H, + ) -> InsertionResult { let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, other); // Update the optypes and metadata, copying them from the other graph. // @@ -444,22 +471,28 @@ impl + AsMut> HugrMut for T ); InsertionResult { new_root, - node_map: translate_indices(node_map).collect(), + node_map: translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map) + .collect(), } } - fn insert_subgraph( + fn insert_subgraph( &mut self, - root: Node, - other: &impl HugrView, - subgraph: &SiblingSubgraph, - ) -> HashMap { + root: Self::Node, + other: &H, + subgraph: &SiblingSubgraph, + ) -> HashMap { // Create a portgraph view with the explicit list of nodes defined by the subgraph. 
- let portgraph: NodeFiltered<_, NodeFilter<&[Node]>, &[Node]> = + let context: HashSet = subgraph + .nodes() + .iter() + .map(|&n| other.get_pg_index(n)) + .collect(); + let portgraph: NodeFiltered<_, NodeFilter>, _> = NodeFiltered::new_node_filtered( other.portgraph(), - |node, ctx| ctx.contains(&node.into()), - subgraph.nodes(), + |node, ctx| ctx.contains(&node), + context, ); let node_map = insert_subgraph_internal(self.as_mut(), root, other, &portgraph); // Update the optypes and metadata, copying them from the other graph. @@ -473,25 +506,24 @@ impl + AsMut> HugrMut for T self.use_extensions(exts); } } - translate_indices(node_map).collect() + translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map).collect() } fn copy_descendants( &mut self, - root: Node, - new_parent: Node, + root: Self::Node, + new_parent: Self::Node, subst: Option, - ) -> BTreeMap { + ) -> BTreeMap { let mut descendants = self.base_hugr().hierarchy.descendants(root.pg_index()); let root2 = descendants.next(); debug_assert_eq!(root2, Some(root.pg_index())); let nodes = Vec::from_iter(descendants); - let node_map = translate_indices( - portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes) - .copy_in_parent() - .expect("Is a MultiPortGraph"), - ) - .collect::>(); + let node_map = portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes) + .copy_in_parent() + .expect("Is a MultiPortGraph"); + let node_map = translate_indices(|n| self.get_node(n), |n| self.get_node(n), node_map) + .collect::>(); for node in self.children(root).collect::>() { self.set_parent(*node_map.get(&node).unwrap(), new_parent); @@ -563,10 +595,10 @@ fn insert_hugr_internal( /// sibling order in the hierarchy. This is due to the subgraph not necessarily /// having a single root, so the logic for reconstructing the hierarchy is not /// able to just do a BFS. -fn insert_subgraph_internal( +fn insert_subgraph_internal( hugr: &mut Hugr, root: Node, - other: &impl HugrView, + other: &impl HugrView, portgraph: &impl portgraph::LinkView, ) -> HashMap { let node_map = hugr diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index a0bf1a3da..c681fafc9 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -446,16 +446,14 @@ impl SiblingSubgraph { nu_out, )) } -} -impl SiblingSubgraph { /// Create a new Hugr containing only the subgraph. /// /// The new Hugr will contain a [FuncDefn][crate::ops::FuncDefn] root /// with the same signature as the subgraph and the specified `name` pub fn extract_subgraph( &self, - hugr: &impl HugrView, + hugr: &impl HugrView, name: impl Into, ) -> Hugr { let mut builder = FunctionBuilder::new(name, self.signature(hugr)).unwrap(); From 7d9c650d94cabaa93a455b376133be2c1cdce4ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 15 Apr 2025 16:02:17 +0100 Subject: [PATCH 02/21] fix!: Don't expose `HugrMutInternals` (#2071) `HugrMutInternals` is part of the semi-private traits defined in `hugr-core`. While most things get re-exported in `hugr`, we `*Internal` traits require you to explicitly declare a dependency on the `-core` package (as we don't want most users to have to interact with them). For some reason there was a public re-export of the trait in a re-exported module, so it ended up appearing in `hugr` anyways. BREAKING CHANGE: Removed public re-export of `HugrMutInternals` from `hugr`. 
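For downstream users that relied on the accidental re-export, the fix is to depend on `hugr-core` explicitly and import the trait from its real home (a minimal sketch; the exact old import path depends on which re-exported module exposed it):

```rust
// Add `hugr-core` to Cargo.toml, then import the internals trait directly:
use hugr_core::hugr::internal::HugrMutInternals;
```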
--- hugr-core/src/hugr/rewrite/simple_replace.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index cf7f2922a..b4ec37db1 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -4,7 +4,6 @@ use std::collections::HashMap; use crate::core::HugrNode; use crate::hugr::hugrmut::InsertionResult; -pub use crate::hugr::internal::HugrMutInternals; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrMut, HugrView, Rewrite}; use crate::ops::{OpTag, OpTrait, OpType}; From 4818a02177bd821a001ceb8c88313fdcbeb3b8cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:10:32 +0100 Subject: [PATCH 03/21] feat!: Mark all Error enums as non_exhaustive (#2056) #2027 ended up being breaking due to adding a new variant to an error enum missing the `non_exhaustive` marker. This (breaking) PR makes sure all error enums have the flag. BREAKING CHANGE: Marked all Error enums as `non_exhaustive` --- hugr-core/src/extension.rs | 1 + hugr-core/src/hugr/serialize/upgrade.rs | 1 + hugr-core/src/import.rs | 2 ++ hugr-model/src/v0/ast/resolve.rs | 1 + hugr-model/src/v0/table/mod.rs | 1 + hugr-passes/src/lower.rs | 1 + hugr-passes/src/non_local.rs | 1 + hugr-passes/src/replace_types/linearize.rs | 1 + hugr-passes/src/validation.rs | 1 + 9 files changed, 10 insertions(+) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 408c88e15..b6e059050 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -378,6 +378,7 @@ pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry { /// TODO: decide on failure modes #[derive(Debug, Clone, Error, PartialEq, Eq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum SignatureError { /// Name mismatch #[error("Definition name ({0}) and instantiation name ({1}) do not match.")] diff --git a/hugr-core/src/hugr/serialize/upgrade.rs b/hugr-core/src/hugr/serialize/upgrade.rs index 2741b6175..ac1ac1eea 100644 --- a/hugr-core/src/hugr/serialize/upgrade.rs +++ b/hugr-core/src/hugr/serialize/upgrade.rs @@ -1,6 +1,7 @@ use thiserror::Error; #[derive(Debug, Error)] +#[non_exhaustive] pub enum UpgradeError { #[error(transparent)] Deserialize(#[from] serde_json::Error), diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 642c84c41..899deb17d 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -35,6 +35,7 @@ use thiserror::Error; /// Error during import. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ImportError { /// The model contains a feature that is not supported by the importer yet. /// Errors of this kind are expected to be removed as the model format and @@ -75,6 +76,7 @@ pub enum ImportError { /// Import error caused by incorrect order hints. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum OrderHintError { /// Duplicate order hint key in the same region. #[error("duplicate order hint key {0}")] diff --git a/hugr-model/src/v0/ast/resolve.rs b/hugr-model/src/v0/ast/resolve.rs index 2f8a5ba6e..c9be8896b 100644 --- a/hugr-model/src/v0/ast/resolve.rs +++ b/hugr-model/src/v0/ast/resolve.rs @@ -362,6 +362,7 @@ impl<'a> Context<'a> { /// Error that may occur in [`Module::resolve`]. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ResolveError { /// Unknown variable. 
#[error("unknown var: {0}")] diff --git a/hugr-model/src/v0/table/mod.rs b/hugr-model/src/v0/table/mod.rs index 756a52c1e..55a4b9889 100644 --- a/hugr-model/src/v0/table/mod.rs +++ b/hugr-model/src/v0/table/mod.rs @@ -456,6 +456,7 @@ pub struct VarId(pub NodeId, pub VarIndex); /// Errors that can occur when traversing and interpreting the model. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ModelError { /// There is a reference to a node that does not exist. #[error("node not found: {0}")] diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 09e02c41d..8f8920967 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -35,6 +35,7 @@ pub fn replace_many_ops>( /// Errors produced by the [`lower_ops`] function. #[derive(Debug, Error)] #[error(transparent)] +#[non_exhaustive] pub enum LowerError { /// Invalid subgraph. #[error("Subgraph formed by node is invalid: {0}")] diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index fca74657b..180e9d6fc 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -23,6 +23,7 @@ pub fn nonlocal_edges(hugr: &H) -> impl Iterator { #[error("Found {} nonlocal edges", .0.len())] Edges(Vec<(N, IncomingPort)>), diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 7b83717d0..5b4da7184 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -135,6 +135,7 @@ pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); #[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum LinearizeError { #[error("Need copy/discard op for {_0}")] NeedCopyDiscard(Type), diff --git a/hugr-passes/src/validation.rs b/hugr-passes/src/validation.rs index 5f53f403c..6c3e61fb4 100644 --- a/hugr-passes/src/validation.rs +++ b/hugr-passes/src/validation.rs @@ -25,6 +25,7 @@ pub enum ValidationLevel { #[derive(Error, Debug, PartialEq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum ValidatePassError { #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] InputError { From d4747cead11184cb4b1974f0929fa98ec081e59f Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 16 Apr 2025 11:11:13 +0100 Subject: [PATCH 04/21] feat!: Handle CallIndirect in Dataflow Analysis (#2059) * PartialValue now has a LoadedFunction variant, created by LoadFunction nodes (only, although other analyses are able to create PartialValues if they want) * This requires adding a type parameter to PartialValue for the type of Node, which gets everywhere :-(. * Use this to handle CallIndirects *with known targets* (it'll be a single known target or none at all) just like other Calls to the same function * deprecate (and ignore) `value_from_function` * Add a new trait `AsConcrete` for the result type of `PartialValue::try_into_concrete` and `PartialSum::try_into_sum` Note almost no change to constant folding (only to drop impl of `value_from_function`) BREAKING CHANGE: in dataflow framework, PartialValue now has additional variant; `try_into_concrete` requires the target type to implement AsConcrete. 
--- hugr-passes/src/const_fold.rs | 63 ++--- hugr-passes/src/const_fold/test.rs | 6 +- hugr-passes/src/const_fold/value_handle.rs | 23 +- hugr-passes/src/dataflow.rs | 17 +- hugr-passes/src/dataflow/datalog.rs | 171 +++++++++++-- hugr-passes/src/dataflow/partial_value.rs | 267 +++++++++++++-------- hugr-passes/src/dataflow/results.rs | 22 +- hugr-passes/src/dataflow/test.rs | 108 ++++++++- hugr-passes/src/dataflow/value_row.rs | 38 +-- 9 files changed, 492 insertions(+), 223 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 7552ed36f..e73e3cd0e 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -7,15 +7,11 @@ use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use hugr_core::{ - hugr::{ - hugrmut::HugrMut, - views::{DescendantsGraph, ExtractHugr, HierarchyView}, - }, + hugr::hugrmut::HugrMut, ops::{ - constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, - OpType, Value, + constant::OpaqueValue, Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value, }, - types::{EdgeKind, TypeArg}, + types::EdgeKind, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire, }; use value_handle::ValueHandle; @@ -102,7 +98,7 @@ impl ConstantFoldPass { n, in_vals.iter().map(|(p, v)| { let const_with_dummy_loc = partial_from_const( - &ConstFoldContext(hugr), + &ConstFoldContext, ConstLocation::Field(p.index(), &fresh_node.into()), v, ); @@ -112,7 +108,7 @@ impl ConstantFoldPass { .map_err(|opty| ConstFoldError::InvalidEntryPoint(n, opty))?; } - let results = m.run(ConstFoldContext(hugr), []); + let results = m.run(ConstFoldContext, []); let mb_root_inp = hugr.get_io(hugr.root()).map(|[i, _]| i); let wires_to_break = hugr @@ -131,7 +127,7 @@ impl ConstantFoldPass { n, ip, results - .try_read_wire_concrete::(Wire::new(src, outp)) + .try_read_wire_concrete::(Wire::new(src, outp)) .ok()?, )) }) @@ -205,60 +201,35 @@ pub fn constant_fold_pass(h: &mut H) { c.run(h).unwrap() } -struct ConstFoldContext<'a, H>(&'a H); - -impl std::ops::Deref for ConstFoldContext<'_, H> { - type Target = H; - fn deref(&self) -> &H { - self.0 - } -} +struct ConstFoldContext; -impl> ConstLoader> for ConstFoldContext<'_, H> { - type Node = H::Node; +impl ConstLoader> for ConstFoldContext { + type Node = Node; fn value_from_opaque( &self, - loc: ConstLocation, + loc: ConstLocation, val: &OpaqueValue, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_opaque(loc, val.clone())) } fn value_from_const_hugr( &self, - loc: ConstLocation, + loc: ConstLocation, h: &hugr_core::Hugr, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_const_hugr(loc, Box::new(h.clone()))) } - - fn value_from_function( - &self, - node: H::Node, - type_args: &[TypeArg], - ) -> Option> { - if !type_args.is_empty() { - // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) - return None; - }; - // Returning the function body as a value, here, would be sufficient for inlining IndirectCall - // but not for transforming to a direct Call. 
- let func = DescendantsGraph::>::try_new(&**self, node).ok()?; - Some(ValueHandle::new_const_hugr( - ConstLocation::Node(node), - Box::new(func.extract_hugr()), - )) - } } -impl> DFContext> for ConstFoldContext<'_, H> { +impl DFContext> for ConstFoldContext { fn interpret_leaf_op( &mut self, - node: H::Node, + node: Node, op: &ExtensionOp, - ins: &[PartialValue>], - outs: &mut [PartialValue>], + ins: &[PartialValue>], + outs: &mut [PartialValue>], ) { let sig = op.signature(); let known_ins = sig diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index b84d65d7d..58e69c568 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -42,8 +42,7 @@ fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { let n = Node::from(portgraph::NodeIndex::new(7)); let st = SumType::new([vec![k.get_type()], vec![]]); let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); - let temp = Hugr::default(); - let ctx: ConstFoldContext = ConstFoldContext(&temp); + let ctx = ConstFoldContext; let v1 = partial_from_const(&ctx, n, &subject_val); let v1_subfield = { @@ -114,8 +113,7 @@ fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { v.get_custom_value::().unwrap().value() } let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); - let temp = Hugr::default(); - let mut ctx = ConstFoldContext(&temp); + let mut ctx = ConstFoldContext; let v_a = partial_from_const(&ctx, n_a, &f2c(a)); let v_b = partial_from_const(&ctx, n_b, &f2c(b)); assert_eq!(unwrap_float(v_a.clone()), a); diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index bda7bffd2..e5c99a8e7 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -1,16 +1,18 @@ //! Total equality (and hence [AbstractValue] support for [Value]s //! (by adding a source-Node and part unhashable constants) use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. +use std::convert::Infallible; use std::hash::{Hash, Hasher}; use std::sync::Arc; use hugr_core::core::HugrNode; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::Value; +use hugr_core::types::ConstTypeError; use hugr_core::{Hugr, Node}; use itertools::Either; -use crate::dataflow::{AbstractValue, ConstLocation}; +use crate::dataflow::{AbstractValue, AsConcrete, ConstLocation, LoadedFunction, Sum}; /// A custom constant that has been successfully hashed via [TryHash](hugr_core::ops::constant::TryHash) #[derive(Clone, Debug)] @@ -153,9 +155,12 @@ impl Hash for ValueHandle { // Unfortunately we need From for Value to be able to pass // Value's into interpret_leaf_op. So that probably doesn't make sense... -impl From> for Value { - fn from(value: ValueHandle) -> Self { - match value { +impl AsConcrete, N> for Value { + type ValErr = Infallible; + type SumErr = ConstTypeError; + + fn from_value(value: ValueHandle) -> Result { + Ok(match value { ValueHandle::Hashable(HashedConst { val, .. 
}) | ValueHandle::Unhashable { leaf: Either::Left(val), @@ -169,7 +174,15 @@ impl From> for Value { } => Value::function(Arc::try_unwrap(hugr).unwrap_or_else(|a| a.as_ref().clone())) .map_err(|e| e.to_string()) .unwrap(), - } + }) + } + + fn from_sum(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } + + fn from_func(func: LoadedFunction) -> Result> { + Err(func) } } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 43caa9c94..1f7c1ae5a 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -9,7 +9,7 @@ mod results; pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; -pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; +pub use partial_value::{AbstractValue, AsConcrete, LoadedFunction, PartialSum, PartialValue, Sum}; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, Value}; @@ -31,8 +31,8 @@ pub trait DFContext: ConstLoader { &mut self, _node: Self::Node, _e: &ExtensionOp, - _ins: &[PartialValue], - _outs: &mut [PartialValue], + _ins: &[PartialValue], + _outs: &mut [PartialValue], ) { } } @@ -55,8 +55,8 @@ impl From for ConstLocation<'_, N> { } /// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. -/// Implementors will likely want to override some/all of [Self::value_from_opaque], -/// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults +/// Implementors will likely want to override either/both of [Self::value_from_opaque] +/// and [Self::value_from_const_hugr]: the defaults /// are "correct" but maximally conservative (minimally informative). pub trait ConstLoader { /// The type of nodes in the Hugr. @@ -81,6 +81,7 @@ pub trait ConstLoader { /// [FuncDefn]: hugr_core::ops::FuncDefn /// [FuncDecl]: hugr_core::ops::FuncDecl /// [LoadFunction]: hugr_core::ops::LoadFunction + #[deprecated(note = "Automatically handled by Datalog, implementation will be ignored")] fn value_from_function(&self, _node: Self::Node, _type_args: &[TypeArg]) -> Option { None } @@ -94,7 +95,7 @@ pub fn partial_from_const<'a, V, CL: ConstLoader>( cl: &CL, loc: impl Into>, cst: &Value, -) -> PartialValue +) -> PartialValue where CL::Node: 'a, { @@ -120,8 +121,8 @@ where /// A row of inputs to a node contains bottom (can't happen, the node /// can't execute) if any element [contains_bottom](PartialValue::contains_bottom). 
-pub fn row_contains_bottom<'a, V: AbstractValue + 'a>( - elements: impl IntoIterator>, +pub fn row_contains_bottom<'a, V: 'a, N: 'a>( + elements: impl IntoIterator>, ) -> bool { elements.into_iter().any(PartialValue::contains_bottom) } diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 13e510daf..ad1a99345 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -3,19 +3,22 @@ use std::collections::HashMap; use ascent::lattice::BoundedLattice; +use ascent::Lattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{OpTrait, OpType, TailLoop}; +use hugr_core::ops::{DataflowOpTrait, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; use super::{ partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext, - PartialValue, + LoadedFunction, PartialValue, }; -type PV = PartialValue; +type PV = PartialValue; + +type NodeInputs = Vec<(IncomingPort, PV)>; /// Basic structure for performing an analysis. Usage: /// 1. Make a new instance via [Self::new()] @@ -25,10 +28,7 @@ type PV = PartialValue; /// [Self::prepopulate_inputs] can be used on each externally-callable /// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top]. /// 3. Call [Self::run] to produce [AnalysisResults] -pub struct Machine( - H, - HashMap)>>, -); +pub struct Machine(H, HashMap>); impl Machine { /// Create a new Machine to analyse the given Hugr(View) @@ -40,7 +40,7 @@ impl Machine { impl Machine { /// Provide initial values for a wire - these will be `join`d with any computed /// or any value previously prepopulated for the same Wire. - pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { + pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { for (n, inp) in self.0.linked_inputs(w.node(), w.source()) { self.1.entry(n).or_default().push((inp, v.clone())); } @@ -54,7 +54,7 @@ impl Machine { pub fn prepopulate_inputs( &mut self, parent: H::Node, - in_values: impl IntoIterator)>, + in_values: impl IntoIterator)>, ) -> Result<(), OpType> { match self.0.get_optype(parent) { OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => { @@ -102,7 +102,7 @@ impl Machine { pub fn run( mut self, context: impl DFContext, - in_values: impl IntoIterator)>, + in_values: impl IntoIterator)>, ) -> AnalysisResults { let root = self.0.root(); if self.0.get_optype(root).is_module() { @@ -135,10 +135,12 @@ impl Machine { } } +pub(super) type InWire = (N, IncomingPort, PartialValue); + pub(super) fn run_datalog( mut ctx: impl DFContext, hugr: H, - in_wire_value_proto: Vec<(H::Node, IncomingPort, PV)>, + in_wire_value_proto: Vec>, ) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. 
@@ -155,9 +157,9 @@ pub(super) fn run_datalog( relation parent_of_node(H::Node, H::Node); // is parent of relation input_child(H::Node, H::Node); // has 1st child that is its `Input` relation output_child(H::Node, H::Node); // has 2nd child that is its `Output` - lattice out_wire_value(H::Node, OutgoingPort, PV); // produces, on , the value - lattice in_wire_value(H::Node, IncomingPort, PV); // receives, on , the value - lattice node_in_value_row(H::Node, ValueRow); // 's inputs are + lattice out_wire_value(H::Node, OutgoingPort, PV); // produces, on , the value + lattice in_wire_value(H::Node, IncomingPort, PV); // receives, on , the value + lattice node_in_value_row(H::Node, ValueRow); // 's inputs are node(n) <-- for n in hugr.nodes(); @@ -322,6 +324,37 @@ pub(super) fn run_datalog( func_call(call, func), output_child(func, outp), in_wire_value(outp, p, v); + + // CallIndirect -------------------- + lattice indirect_call(H::Node, LatticeWrapper); // is an `IndirectCall` to `FuncDefn` + indirect_call(call, tgt) <-- + node(call), + if let OpType::CallIndirect(_) = hugr.get_optype(*call), + in_wire_value(call, IncomingPort::from(0), v), + let tgt = load_func(v); + + out_wire_value(inp, OutgoingPort::from(p.index()-1), v) <-- + indirect_call(call, lv), + if let LatticeWrapper::Value(func) = lv, + input_child(func, inp), + in_wire_value(call, p, v) + if p.index() > 0; + + out_wire_value(call, OutgoingPort::from(p.index()), v) <-- + indirect_call(call, lv), + if let LatticeWrapper::Value(func) = lv, + output_child(func, outp), + in_wire_value(outp, p, v); + + // Default out-value is Bottom, but if we can't determine the called function, + // assign everything to Top + out_wire_value(call, p, PV::Top) <-- + node(call), + if let OpType::CallIndirect(ci) = hugr.get_optype(*call), + in_wire_value(call, IncomingPort::from(0), v), + // Second alternative below addresses function::Value's: + if matches!(v, PartialValue::Top | PartialValue::Value(_)), + for p in ci.signature().output_ports(); }; let out_wire_values = all_results .out_wire_value @@ -337,13 +370,58 @@ pub(super) fn run_datalog( } } +#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd)] +enum LatticeWrapper { + Bottom, + Value(T), + Top, +} + +impl Lattice for LatticeWrapper { + fn meet_mut(&mut self, other: Self) -> bool { + if *self == other || *self == LatticeWrapper::Bottom || other == LatticeWrapper::Top { + return false; + }; + if *self == LatticeWrapper::Top || other == LatticeWrapper::Bottom { + *self = other; + return true; + }; + // Both are `Value`s and not equal + *self = LatticeWrapper::Bottom; + true + } + + fn join_mut(&mut self, other: Self) -> bool { + if *self == other || *self == LatticeWrapper::Top || other == LatticeWrapper::Bottom { + return false; + }; + if *self == LatticeWrapper::Bottom || other == LatticeWrapper::Top { + *self = other; + return true; + }; + // Both are `Value`s and are not equal + *self = LatticeWrapper::Top; + true + } +} + +fn load_func(v: &PV) -> LatticeWrapper { + match v { + PartialValue::Bottom | PartialValue::PartialSum(_) => LatticeWrapper::Bottom, + PartialValue::LoadedFunction(LoadedFunction { func_node, .. }) => { + LatticeWrapper::Value(*func_node) + } + PartialValue::Value(_) | PartialValue::Top => LatticeWrapper::Top, + } +} + fn propagate_leaf_op( ctx: &mut impl DFContext, hugr: &H, n: H::Node, - ins: &[PV], + ins: &[PV], num_outs: usize, -) -> Option> { +) -> Option> { match hugr.get_optype(n) { // Handle basics here. 
We could instead leave these to DFContext, // but at least we'd want these impls to be easily reusable. @@ -362,8 +440,7 @@ fn propagate_leaf_op( ins.iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent - OpType::Call(_) => None, // handled via Input/Output of FuncDefn - OpType::Const(_) => None, // handled by LoadConstant: + OpType::Call(_) | OpType::CallIndirect(_) => None, // handled via Input/Output of FuncDefn OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant let const_node = hugr @@ -380,10 +457,10 @@ fn propagate_leaf_op( .unwrap() .0; // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself - Some(ValueRow::singleton( - ctx.value_from_function(func_node, &load_op.type_args) - .map_or(PV::Top, PV::Value), - )) + Some(ValueRow::singleton(PartialValue::new_load( + func_node, + load_op.type_args.clone(), + ))) } OpType::ExtensionOp(e) => { Some(ValueRow::from_iter(if row_contains_bottom(ins) { @@ -401,6 +478,54 @@ fn propagate_leaf_op( outs })) } - o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" + // We only call propagate_leaf_op for dataflow op non-containers, + o => todo!("Unhandled: {:?}", o), // and OpType is non-exhaustive + } +} + +#[cfg(test)] +mod test { + use ascent::Lattice; + + use super::LatticeWrapper; + + #[test] + fn latwrap_join() { + for lv in [ + LatticeWrapper::Value(3), + LatticeWrapper::Value(5), + LatticeWrapper::Top, + ] { + let mut subject = LatticeWrapper::Bottom; + assert!(subject.join_mut(lv.clone())); + assert_eq!(subject, lv); + assert!(!subject.join_mut(lv.clone())); + assert_eq!(subject, lv); + assert_eq!( + subject.join_mut(LatticeWrapper::Value(11)), + lv != LatticeWrapper::Top + ); + assert_eq!(subject, LatticeWrapper::Top); + } + } + + #[test] + fn latwrap_meet() { + for lv in [ + LatticeWrapper::Bottom, + LatticeWrapper::Value(3), + LatticeWrapper::Value(5), + ] { + let mut subject = LatticeWrapper::Top; + assert!(subject.meet_mut(lv.clone())); + assert_eq!(subject, lv); + assert!(!subject.meet_mut(lv.clone())); + assert_eq!(subject, lv); + assert_eq!( + subject.meet_mut(LatticeWrapper::Value(11)), + lv != LatticeWrapper::Bottom + ); + assert_eq!(subject, LatticeWrapper::Bottom); + } } } diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index f2a497806..240f4f2d6 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,7 +1,7 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; -use hugr_core::ops::Value; -use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::types::{SumType, Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::Node; use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; @@ -51,15 +51,25 @@ pub struct Sum { pub st: SumType, } +/// The output of an [LoadFunction](hugr_core::ops::LoadFunction) - a "pointer" +/// to a function at a specific node, instantiated with the provided type-args. 
+#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct LoadedFunction { + /// The [FuncDefn](hugr_core::ops::FuncDefn) or `FuncDecl`` that was loaded + pub func_node: N, + /// The type arguments provided when loading + pub args: Vec, +} + /// A representation of a value of [SumType], that may have one or more possible tags, /// with a [PartialValue] representation of each element-value of each possible tag. #[derive(PartialEq, Clone, Eq)] -pub struct PartialSum(pub HashMap>>); +pub struct PartialSum(pub HashMap>>); -impl PartialSum { +impl PartialSum { /// New instance for a single known tag. /// (Multi-tag instances can be created via [Self::try_join_mut].) - pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { + pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } @@ -75,9 +85,21 @@ impl PartialSum { pv.assert_invariants(); } } + + /// Whether this sum might have the specified tag + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.contains_key(&tag) + } + + /// Can this ever occur at runtime? See [PartialValue::contains_bottom] + pub fn contains_bottom(&self) -> bool { + self.0 + .iter() + .all(|(_tag, elements)| row_contains_bottom(elements)) + } } -impl PartialSum { +impl PartialSum { /// Joins (towards `Top`) self with another [PartialSum]. If successful, returns /// whether `self` has changed. /// @@ -141,12 +163,33 @@ impl PartialSum { } Ok(changed) } +} - /// Whether this sum might have the specified tag - pub fn supports_tag(&self, tag: usize) -> bool { - self.0.contains_key(&tag) - } +/// Trait implemented by value types into which [PartialValue]s can be converted, +/// so long as the PV has no [Top](PartialValue::Top), [Bottom](PartialValue::Bottom) +/// or [PartialSum]s with more than one possible tag. See [PartialSum::try_into_sum] +/// and [PartialValue::try_into_concrete]. +/// +/// `V` is the type of [AbstractValue] from which `Self` can (fallibly) be constructed, +/// `N` is the type of [HugrNode](hugr_core::core::HugrNode) for function pointers +pub trait AsConcrete: Sized { + /// Kind of error raised when creating `Self` from a value `V`, see [Self::from_value] + type ValErr: std::error::Error; + /// Kind of error that may be raised when creating `Self` from a [Sum] of `Self`s, + /// see [Self::from_sum] + type SumErr: std::error::Error; + + /// Convert an abstract value into concrete + fn from_value(val: V) -> Result; + + /// Convert a sum (of concrete values, already recursively converted) into concrete + fn from_sum(sum: Sum) -> Result; + + /// Convert a function pointer into a concrete value + fn from_func(func: LoadedFunction) -> Result>; +} +impl PartialSum { /// Turns this instance into a [Sum] of some "concrete" value type `C`, /// *if* this PartialSum has exactly one possible tag. /// @@ -155,11 +198,11 @@ impl PartialSum { /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [PartialValue::try_into_concrete]. 
- pub fn try_into_sum(self, typ: &Type) -> Result, ExtractValueError> - where - V: TryInto, - Sum: TryInto, - { + #[allow(clippy::type_complexity)] // Since C is a parameter, can't declare type aliases + pub fn try_into_sum>( + self, + typ: &Type, + ) -> Result, ExtractValueError> { if self.0.len() != 1 { return Err(ExtractValueError::MultipleVariants(self)); } @@ -185,22 +228,15 @@ impl PartialSum { num_elements: v.len(), }) } - - /// Can this ever occur at runtime? See [PartialValue::contains_bottom] - pub fn contains_bottom(&self) -> bool { - self.0 - .iter() - .all(|(_tag, elements)| row_contains_bottom(elements)) - } } /// An error converting a [PartialValue] or [PartialSum] into a concrete value type /// via [PartialValue::try_into_concrete] or [PartialSum::try_into_sum] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[allow(missing_docs)] -pub enum ExtractValueError { +pub enum ExtractValueError { #[error("PartialSum value had multiple possible tags: {0}")] - MultipleVariants(PartialSum), + MultipleVariants(PartialSum), #[error("Value contained `Bottom`")] ValueIsBottom, #[error("Value contained `Top`")] @@ -209,6 +245,8 @@ pub enum ExtractValueError { CouldNotConvert(V, #[source] VE), #[error("Could not build Sum from concrete element values")] CouldNotBuildSum(#[source] SE), + #[error("Could not convert into concrete function pointer {0}")] + CouldNotLoadFunction(LoadedFunction), #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] BadSumType { typ: Type, @@ -217,14 +255,14 @@ pub enum ExtractValueError { }, } -impl PartialSum { +impl PartialSum { /// If this Sum might have the specified `tag`, get the elements inside that tag. - pub fn variant_values(&self, variant: usize) -> Option>> { + pub fn variant_values(&self, variant: usize) -> Option>> { self.0.get(&variant).cloned() } } -impl PartialOrd for PartialSum { +impl PartialOrd for PartialSum { fn partial_cmp(&self, other: &Self) -> Option { let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); @@ -254,13 +292,13 @@ impl PartialOrd for PartialSum { } } -impl std::fmt::Debug for PartialSum { +impl std::fmt::Debug for PartialSum { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } -impl Hash for PartialSum { +impl Hash for PartialSum { fn hash(&self, state: &mut H) { for (k, v) in &self.0 { k.hash(state); @@ -273,30 +311,32 @@ impl Hash for PartialSum { /// for use in dataflow analysis, including that an instance may be a [PartialSum] /// of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PartialValue { +pub enum PartialValue { /// No possibilities known (so far) Bottom, + /// The output of an [LoadFunction](hugr_core::ops::LoadFunction) + LoadedFunction(LoadedFunction), /// A single value (of the underlying representation) Value(V), /// Sum (with at least one, perhaps several, possible tags) of underlying values - PartialSum(PartialSum), + PartialSum(PartialSum), /// Might be more than one distinct value of the underlying type `V` Top, } -impl From for PartialValue { +impl From for PartialValue { fn from(v: V) -> Self { Self::Value(v) } } -impl From> for PartialValue { - fn from(v: PartialSum) -> Self { +impl From> for PartialValue { + fn from(v: PartialSum) -> Self { Self::PartialSum(v) } } -impl PartialValue { +impl PartialValue { fn assert_invariants(&self) { if let Self::PartialSum(ps) = self { 
ps.assert_invariants(); @@ -312,33 +352,59 @@ impl PartialValue { pub fn new_unit() -> Self { Self::new_variant(0, []) } + + /// New instance of self for a [LoadFunction](hugr_core::ops::LoadFunction) + pub fn new_load(func_node: N, args: impl Into>) -> Self { + Self::LoadedFunction(LoadedFunction { + func_node, + args: args.into(), + }) + } + + /// Tells us whether this value might be a Sum with the specified `tag` + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => { + false + } + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } + + /// A value contains bottom means that it cannot occur during execution: + /// it may be an artefact during bootstrapping of the analysis, or else + /// the value depends upon a `panic` or a loop that + /// [never terminates](super::TailLoopTermination::NeverBreaks). + pub fn contains_bottom(&self) -> bool { + match self { + PartialValue::Bottom => true, + PartialValue::Top | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => false, + PartialValue::PartialSum(ps) => ps.contains_bottom(), + } + } } -impl PartialValue { +impl PartialValue { /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. /// /// # Panics /// /// if the value is believed, for that tag, to have a number of values other than `len` - pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match self { - PartialValue::Bottom | PartialValue::Value(_) => return None, + PartialValue::Bottom | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => { + return None + } PartialValue::PartialSum(ps) => ps.variant_values(tag)?, PartialValue::Top => vec![PartialValue::Top; len], }; assert_eq!(vals.len(), len); Some(vals) } +} - /// Tells us whether this value might be a Sum with the specified `tag` - pub fn supports_tag(&self, tag: usize) -> bool { - match self { - PartialValue::Bottom | PartialValue::Value(_) => false, - PartialValue::PartialSum(ps) => ps.supports_tag(tag), - PartialValue::Top => true, - } - } - +impl PartialValue { /// Turns this instance into some "concrete" value type `C`, *if* it is a single value, /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by /// [PartialSum::try_into_sum]. @@ -348,47 +414,27 @@ impl PartialValue { /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) /// that could not be converted into a [Sum] by [PartialSum::try_into_sum] (e.g. if `typ` is /// incorrect), or if that [Sum] could not be converted into a `V2`. - pub fn try_into_concrete(self, typ: &Type) -> Result> - where - V: TryInto, - Sum: TryInto, - { + pub fn try_into_concrete>( + self, + typ: &Type, + ) -> Result> { match self { - Self::Value(v) => v - .clone() - .try_into() - .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), - Self::PartialSum(ps) => ps - .try_into_sum(typ)? 
- .try_into() - .map_err(ExtractValueError::CouldNotBuildSum), + Self::Value(v) => { + C::from_value(v.clone()).map_err(|e| ExtractValueError::CouldNotConvert(v, e)) + } + Self::LoadedFunction(lf) => { + C::from_func(lf).map_err(ExtractValueError::CouldNotLoadFunction) + } + Self::PartialSum(ps) => { + C::from_sum(ps.try_into_sum(typ)?).map_err(ExtractValueError::CouldNotBuildSum) + } Self::Top => Err(ExtractValueError::ValueIsTop), Self::Bottom => Err(ExtractValueError::ValueIsBottom), } } - - /// A value contains bottom means that it cannot occur during execution: - /// it may be an artefact during bootstrapping of the analysis, or else - /// the value depends upon a `panic` or a loop that - /// [never terminates](super::TailLoopTermination::NeverBreaks). - pub fn contains_bottom(&self) -> bool { - match self { - PartialValue::Bottom => true, - PartialValue::Top | PartialValue::Value(_) => false, - PartialValue::PartialSum(ps) => ps.contains_bottom(), - } - } } -impl TryFrom> for Value { - type Error = ConstTypeError; - - fn try_from(value: Sum) -> Result { - Self::sum(value.tag, value.values, value.st) - } -} - -impl Lattice for PartialValue { +impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); let mut old_self = Self::Top; @@ -400,13 +446,17 @@ impl Lattice for PartialValue { Some((h3, b)) => (Self::Value(h3), b), None => (Self::Top, true), }, + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) + if lf1.func_node == lf2.func_node => + { + // TODO we should also join the TypeArgs but at the moment these are ignored + (Self::LoadedFunction(lf1), false) + } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_join_mut(ps2) { Ok(ch) => (Self::PartialSum(ps1), ch), Err(_) => (Self::Top, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - (Self::Top, true) - } + _ => (Self::Top, true), }; *self = res; ch @@ -423,20 +473,24 @@ impl Lattice for PartialValue { Some((h3, ch)) => (Self::Value(h3), ch), None => (Self::Bottom, true), }, + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) + if lf1.func_node == lf2.func_node => + { + // TODO we should also meet the TypeArgs but at the moment these are ignored + (Self::LoadedFunction(lf1), false) + } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_meet_mut(ps2) { Ok(ch) => (Self::PartialSum(ps1), ch), Err(_) => (Self::Bottom, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - (Self::Bottom, true) - } + _ => (Self::Bottom, true), }; *self = res; ch } } -impl BoundedLattice for PartialValue { +impl BoundedLattice for PartialValue { fn top() -> Self { Self::Top } @@ -446,7 +500,7 @@ impl BoundedLattice for PartialValue { } } -impl PartialOrd for PartialValue { +impl PartialOrd for PartialValue { fn partial_cmp(&self, other: &Self) -> Option { use std::cmp::Ordering; match (self, other) { @@ -457,6 +511,9 @@ impl PartialOrd for PartialValue { (Self::Top, _) => Some(Ordering::Greater), (_, Self::Top) => Some(Ordering::Less), (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) => { + (lf1 == lf2).then_some(Ordering::Equal) + } (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), _ => None, } @@ -468,19 +525,20 @@ mod test { use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; + use hugr_core::NodeIndex; use itertools::{zip_eq, Itertools as _}; use 
prop::sample::subsequence; use proptest::prelude::*; use proptest_recurse::{StrategyExt, StrategySet}; - use super::{AbstractValue, PartialSum, PartialValue}; + use super::{AbstractValue, LoadedFunction, PartialSum, PartialValue}; #[derive(Debug, PartialEq, Eq, Clone)] enum TestSumType { Branch(Vec>>), - /// None => unit, Some => TestValue <= this *usize* - Leaf(Option), + LeafVal(usize), // contains a TestValue <= this usize + LeafPtr(usize), // contains a LoadedFunction with node <= this *usize* } #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -509,8 +567,11 @@ mod test { fn check_value(&self, pv: &PartialValue) -> bool { match (self, pv) { (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, - (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), - (Self::Leaf(Some(max)), PartialValue::Value(TestValue(val))) => val <= max, + (Self::LeafVal(max), PartialValue::Value(TestValue(val))) => val <= max, + ( + Self::LeafPtr(max), + PartialValue::LoadedFunction(LoadedFunction { func_node, args }), + ) => args.is_empty() && func_node.index() <= *max, (Self::Branch(sop), PartialValue::PartialSum(ps)) => { for (k, v) in &ps.0 { if *k >= sop.len() { @@ -537,8 +598,11 @@ mod test { fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { fn arb(params: SumTypeParams, set: &mut StrategySet) -> SBoxedStrategy { use proptest::collection::vec; - let int_strat = (0..usize::MAX).prop_map(|i| TestSumType::Leaf(Some(i))); - let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat]; + let leaf_strat = prop_oneof![ + (0..usize::MAX).prop_map(TestSumType::LeafVal), + // This is the maximum value accepted by portgraph::NodeIndex::new + (0..((2usize ^ 31) - 2)).prop_map(TestSumType::LeafPtr) + ]; leaf_strat.prop_mutually_recursive( params.depth as u32, params.desired_size as u32, @@ -605,11 +669,18 @@ mod test { ust: &TestSumType, ) -> impl Strategy> { match ust { - TestSumType::Leaf(None) => Just(PartialValue::new_unit()).boxed(), - TestSumType::Leaf(Some(i)) => (0..*i) + TestSumType::LeafVal(i) => (0..=*i) .prop_map(TestValue) .prop_map(PartialValue::from) .boxed(), + TestSumType::LeafPtr(i) => (0..=*i) + .prop_map(|i| { + PartialValue::LoadedFunction(LoadedFunction { + func_node: portgraph::NodeIndex::new(i).into(), + args: vec![], + }) + }) + .boxed(), TestSumType::Branch(sop) => partial_sum_strat(sop).prop_map(PartialValue::from).boxed(), } } diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index c40f1d87f..c4a94a9e7 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -1,17 +1,19 @@ use std::collections::HashMap; -use hugr_core::{HugrView, IncomingPort, PortIndex, Wire}; +use hugr_core::{HugrView, PortIndex, Wire}; -use super::{partial_value::ExtractValueError, AbstractValue, PartialValue, Sum}; +use super::{ + datalog::InWire, partial_value::ExtractValueError, AbstractValue, AsConcrete, PartialValue, +}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). 
pub struct AnalysisResults { pub(super) hugr: H, - pub(super) in_wire_value: Vec<(H::Node, IncomingPort, PartialValue)>, + pub(super) in_wire_value: Vec>, pub(super) case_reachable: Vec<(H::Node, H::Node)>, pub(super) bb_reachable: Vec<(H::Node, H::Node)>, - pub(super) out_wire_values: HashMap, PartialValue>, + pub(super) out_wire_values: HashMap, PartialValue>, } impl AnalysisResults { @@ -21,7 +23,7 @@ impl AnalysisResults { } /// Gets the lattice value computed for the given wire - pub fn read_out_wire(&self, w: Wire) -> Option> { + pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() } @@ -84,13 +86,11 @@ impl AnalysisResults { /// `None` if the analysis did not produce a result for that wire, or if /// the Hugr did not have a [Type](hugr_core::types::Type) for the specified wire /// `Some(e)` if [conversion to a concrete value](PartialValue::try_into_concrete) failed with error `e` - pub fn try_read_wire_concrete( + #[allow(clippy::type_complexity)] + pub fn try_read_wire_concrete>( &self, w: Wire, - ) -> Result>> - where - V2: TryFrom + TryFrom, Error = SE>, - { + ) -> Result>> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self .hugr @@ -116,7 +116,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - fn from_control_value(v: &PartialValue) -> Self { + fn from_control_value(v: &PartialValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break { if may_continue { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 3af0097f7..1c4b4e439 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,10 +1,12 @@ +use std::convert::Infallible; + use ascent::{lattice::BoundedLattice, Lattice}; -use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder}; +use hugr_core::builder::{inout_sig, CFGBuilder, Container, DataflowHugr, ModuleBuilder}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::handle::DfgID; -use hugr_core::ops::TailLoop; -use hugr_core::types::TypeRow; +use hugr_core::ops::{CallIndirect, TailLoop}; +use hugr_core::types::{ConstTypeError, TypeRow}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -19,7 +21,10 @@ use hugr_core::{ use hugr_core::{Hugr, Node, Wire}; use rstest::{fixture, rstest}; -use super::{AbstractValue, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination}; +use super::{ + AbstractValue, AsConcrete, ConstLoader, DFContext, LoadedFunction, Machine, PartialValue, Sum, + TailLoopTermination, +}; // ------- Minimal implementation of DFContext and AbstractValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -35,10 +40,22 @@ impl ConstLoader for TestContext { impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) -impl From for Value { - fn from(v: Void) -> Self { +impl AsConcrete for Value { + type ValErr = Infallible; + + type SumErr = ConstTypeError; + + fn from_value(v: Void) -> Result { match v {} } + + fn from_sum(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } + + fn from_func(func: LoadedFunction) -> Result> { + Err(func) + } } fn pv_false() -> PartialValue { @@ -295,9 +312,7 @@ fn test_conditional() { let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(results - .try_read_wire_concrete::(cond_o2) - .is_err()); 
+ assert!(results.try_read_wire_concrete::(cond_o2).is_err()); assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(results.case_reachable(case2.node()), Some(true)); @@ -547,3 +562,78 @@ fn test_module() { ); } } + +#[rstest] +#[case(pv_false(), pv_false())] +#[case(pv_false(), pv_true())] +#[case(pv_true(), pv_false())] +#[case(pv_true(), pv_true())] +fn call_indirect(#[case] inp1: PartialValue, #[case] inp2: PartialValue) { + let b2b = || Signature::new_endo(bool_t()); + let mut dfb = DFGBuilder::new(inout_sig(vec![bool_t(); 3], vec![bool_t(); 2])).unwrap(); + + let [id1, id2] = ["id1", "[id2]"].map(|name| { + let fb = dfb.define_function(name, b2b()).unwrap(); + let [inp] = fb.input_wires_arr(); + fb.finish_with_outputs([inp]).unwrap() + }); + + let [inp_direct, which, inp_indirect] = dfb.input_wires_arr(); + let [res1] = dfb + .call(id1.handle(), &[], [inp_direct]) + .unwrap() + .outputs_arr(); + + // We'll unconditionally load both functions, to demonstrate that it's + // the CallIndirect that matters, not just which functions are loaded. + let lf1 = dfb.load_func(id1.handle(), &[]).unwrap(); + let lf2 = dfb.load_func(id2.handle(), &[]).unwrap(); + let bool_func = || Type::new_function(b2b()); + let mut cond = dfb + .conditional_builder( + (vec![type_row![]; 2], which), + [(bool_func(), lf1), (bool_func(), lf2)], + bool_func().into(), + ) + .unwrap(); + let case_false = cond.case_builder(0).unwrap(); + let [f0, _f1] = case_false.input_wires_arr(); + case_false.finish_with_outputs([f0]).unwrap(); + let case_true = cond.case_builder(1).unwrap(); + let [_f0, f1] = case_true.input_wires_arr(); + case_true.finish_with_outputs([f1]).unwrap(); + let [tgt] = cond.finish_sub_container().unwrap().outputs_arr(); + let [res2] = dfb + .add_dataflow_op(CallIndirect { signature: b2b() }, [tgt, inp_indirect]) + .unwrap() + .outputs_arr(); + let h = dfb.finish_hugr_with_outputs([res1, res2]).unwrap(); + + let run = |which| { + Machine::new(&h).run( + TestContext, + [ + (0.into(), inp1.clone()), + (1.into(), which), + (2.into(), inp2.clone()), + ], + ) + }; + let (w1, w2) = (Wire::new(h.root(), 0), Wire::new(h.root(), 1)); + + // 1. Test with `which` unknown -> second output unknown + let results = run(PartialValue::Top); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(PartialValue::Top)); + + // 2. Test with `which` selecting second function -> both passthrough + let results = run(pv_true()); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(inp2.clone())); + + //3. 
Test with `which` selecting first function -> alias + let results = run(pv_false()); + let out = Some(inp1.join(inp2)); + assert_eq!(results.read_out_wire(w1), out); + assert_eq!(results.read_out_wire(w2), out); +} diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index 50cf10318..43c842d91 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -5,25 +5,25 @@ use std::{ ops::{Index, IndexMut}, }; -use ascent::{lattice::BoundedLattice, Lattice}; +use ascent::Lattice; use itertools::zip_eq; use super::{AbstractValue, PartialValue}; #[derive(PartialEq, Clone, Debug, Eq, Hash)] -pub(super) struct ValueRow(Vec>); +pub(super) struct ValueRow(Vec>); -impl ValueRow { +impl ValueRow { pub fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) + Self(vec![PartialValue::Bottom; len]) } - pub fn set(mut self, idx: usize, v: PartialValue) -> Self { + pub fn set(mut self, idx: usize, v: PartialValue) -> Self { *self.0.get_mut(idx).unwrap() = v; self } - pub fn singleton(v: PartialValue) -> Self { + pub fn singleton(v: PartialValue) -> Self { Self(vec![v]) } @@ -34,25 +34,25 @@ impl ValueRow { &self, variant: usize, len: usize, - ) -> Option>> { + ) -> Option>> { let vals = self[0].variant_values(variant, len)?; Some(vals.into_iter().chain(self.0[1..].to_owned())) } } -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { Self(iter.into_iter().collect()) } } -impl PartialOrd for ValueRow { +impl PartialOrd for ValueRow { fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) } } -impl Lattice for ValueRow { +impl Lattice for ValueRow { fn join_mut(&mut self, other: Self) -> bool { assert_eq!(self.0.len(), other.0.len()); let mut changed = false; @@ -72,30 +72,30 @@ impl Lattice for ValueRow { } } -impl IntoIterator for ValueRow { - type Item = PartialValue; +impl IntoIterator for ValueRow { + type Item = PartialValue; - type IntoIter = > as IntoIterator>::IntoIter; + type IntoIter = > as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } -impl Index for ValueRow +impl Index for ValueRow where - Vec>: Index, + Vec>: Index, { - type Output = > as Index>::Output; + type Output = > as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) } } -impl IndexMut for ValueRow +impl IndexMut for ValueRow where - Vec>: IndexMut, + Vec>: IndexMut, { fn index_mut(&mut self, index: Idx) -> &mut Self::Output { self.0.index_mut(index) From 7e0444d336af9f216b444630675d4b298a523729 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Wed, 16 Apr 2025 14:52:21 +0100 Subject: [PATCH 05/21] feat: Make NodeHandle generic (#2092) Adds a generic node type to the `NodeHandle` type. This is a required change for #2029. 
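For orientation, a minimal sketch of the new trait shape (mirroring the `hugr-core/src/ops/handle.rs` diff below; the `N = Node` default shown here is an assumption where the hunk elides it, and the sketch is illustrative rather than the exact definition):

```rust
use hugr_core::ops::OpTag;
use hugr_core::Node;

/// Handles are now generic over the node type of the Hugr they point into,
/// defaulting to the base `Node` so existing code keeps compiling.
pub trait NodeHandle<N = Node>: Clone {
    /// The most specific operation tag associated with the handle.
    const TAG: OpTag;

    /// The underlying node, in the node type of the source Hugr.
    fn node(&self) -> N;
}

/// Concrete handle types gain the same parameter, e.g. the DFG handle:
#[derive(Clone, Copy, Debug)]
pub struct DfgID<N = Node>(N);
```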
drive-by: Implement the "Link the NodeHandles to the OpType" TODO --- hugr-core/src/ops.rs | 16 +++++++- hugr-core/src/ops/handle.rs | 73 ++++++++++++++++++++----------------- 2 files changed, 54 insertions(+), 35 deletions(-) diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index 0c7d3bb3f..ce0d44de0 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -9,6 +9,7 @@ pub mod module; pub mod sum; pub mod tag; pub mod validate; +use crate::core::HugrNode; use crate::extension::resolution::{ collect_op_extension, collect_op_types_extensions, ExtensionCollectionError, }; @@ -20,6 +21,7 @@ use crate::types::{EdgeKind, Signature, Substitution}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; use derive_more::Display; +use handle::NodeHandle; use paste::paste; use portgraph::NodeIndex; @@ -41,7 +43,6 @@ pub use tag::OpTag; #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(proptest_derive::Arbitrary))] /// The concrete operation types for a node in the HUGR. -// TODO: Link the NodeHandles to the OpType. #[non_exhaustive] #[allow(missing_docs)] #[serde(tag = "op")] @@ -377,6 +378,19 @@ pub trait OpTrait: Sized + Clone { /// Tag identifying the operation. fn tag(&self) -> OpTag; + /// Tries to create a specific [`NodeHandle`] for a node with this operation + /// type. + /// + /// Fails if the operation's [`OpTrait::tag`] does not match the + /// [`NodeHandle::TAG`] of the requested handle. + fn try_node_handle(&self, node: N) -> Option + where + N: HugrNode, + H: NodeHandle + From, + { + H::TAG.is_superset(self.tag()).then(|| node.into()) + } + /// The signature of the operation. /// /// Only dataflow operations have a signature, otherwise returns None. diff --git a/hugr-core/src/ops/handle.rs b/hugr-core/src/ops/handle.rs index d7fe16419..a5a3c294a 100644 --- a/hugr-core/src/ops/handle.rs +++ b/hugr-core/src/ops/handle.rs @@ -1,4 +1,5 @@ //! Handles to nodes in HUGR. +use crate::core::HugrNode; use crate::types::{Type, TypeBound}; use crate::Node; @@ -9,12 +10,12 @@ use super::{AliasDecl, OpTag}; /// Common trait for handles to a node. /// Typically wrappers around [`Node`]. -pub trait NodeHandle: Clone { +pub trait NodeHandle: Clone { /// The most specific operation tag associated with the handle. const TAG: OpTag; /// Index of underlying node. - fn node(&self) -> Node; + fn node(&self) -> N; /// Operation tag for the handle. #[inline] @@ -23,7 +24,7 @@ pub trait NodeHandle: Clone { } /// Cast the handle to a different more general tag. - fn try_cast>(&self) -> Option { + fn try_cast + From>(&self) -> Option { T::TAG.is_superset(Self::TAG).then(|| self.node().into()) } @@ -36,30 +37,30 @@ pub trait NodeHandle: Clone { /// Trait for handles that contain children. /// /// The allowed children handles are defined by the associated type. -pub trait ContainerHandle: NodeHandle { +pub trait ContainerHandle: NodeHandle { /// Handle type for the children of this node. - type ChildrenHandle: NodeHandle; + type ChildrenHandle: NodeHandle; } #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DataflowOp](crate::ops::dataflow). -pub struct DataflowOpID(Node); +pub struct DataflowOpID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DFG](crate::ops::DFG) node. 
-pub struct DfgID(Node); +pub struct DfgID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [CFG](crate::ops::CFG) node. -pub struct CfgID(Node); +pub struct CfgID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a module [Module](crate::ops::Module) node. -pub struct ModuleRootID(Node); +pub struct ModuleRootID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [module op](crate::ops::module) node. -pub struct ModuleID(Node); +pub struct ModuleID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [def](crate::ops::OpType::FuncDefn) @@ -67,7 +68,7 @@ pub struct ModuleID(Node); /// /// The `DEF` const generic is used to indicate whether the function is /// defined or just declared. -pub struct FuncID(Node); +pub struct FuncID(N); #[derive(Debug, Clone, PartialEq, Eq)] /// Handle to an [AliasDefn](crate::ops::OpType::AliasDefn) @@ -75,15 +76,15 @@ pub struct FuncID(Node); /// /// The `DEF` const generic is used to indicate whether the function is /// defined or just declared. -pub struct AliasID { - node: Node, +pub struct AliasID { + node: N, name: SmolStr, bound: TypeBound, } -impl AliasID { +impl AliasID { /// Construct new AliasID - pub fn new(node: Node, name: SmolStr, bound: TypeBound) -> Self { + pub fn new(node: N, name: SmolStr, bound: TypeBound) -> Self { Self { node, name, bound } } @@ -99,27 +100,27 @@ impl AliasID { #[derive(DerFrom, Debug, Clone, PartialEq, Eq)] /// Handle to a [Const](crate::ops::OpType::Const) node. -pub struct ConstID(Node); +pub struct ConstID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DataflowBlock](crate::ops::DataflowBlock) or [Exit](crate::ops::ExitBlock) node. -pub struct BasicBlockID(Node); +pub struct BasicBlockID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [Case](crate::ops::Case) node. -pub struct CaseID(Node); +pub struct CaseID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [TailLoop](crate::ops::TailLoop) node. -pub struct TailLoopID(Node); +pub struct TailLoopID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [Conditional](crate::ops::Conditional) node. -pub struct ConditionalID(Node); +pub struct ConditionalID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a dataflow container node. -pub struct DataflowParentID(Node); +pub struct DataflowParentID(N); /// Implements the `NodeHandle` trait for a tuple struct that contains just a /// NodeIndex. Takes the name of the struct, and the corresponding OpTag. @@ -131,11 +132,11 @@ macro_rules! 
impl_nodehandle { impl_nodehandle!($name, $tag, 0); }; ($name:ident, $tag:expr, $node_attr:tt) => { - impl NodeHandle for $name { + impl NodeHandle for $name { const TAG: OpTag = $tag; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.$node_attr } } @@ -156,35 +157,35 @@ impl_nodehandle!(ConstID, OpTag::Const); impl_nodehandle!(BasicBlockID, OpTag::DataflowBlock); -impl NodeHandle for FuncID { +impl NodeHandle for FuncID { const TAG: OpTag = OpTag::Function; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.0 } } -impl NodeHandle for AliasID { +impl NodeHandle for AliasID { const TAG: OpTag = OpTag::Alias; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.node } } -impl NodeHandle for Node { +impl NodeHandle for N { const TAG: OpTag = OpTag::Any; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { *self } } /// Implements the `ContainerHandle` trait, with the given child handle type. macro_rules! impl_containerHandle { - ($name:path, $children:ident) => { - impl ContainerHandle for $name { - type ChildrenHandle = $children; + ($name:ident, $children:ident) => { + impl ContainerHandle for $name { + type ChildrenHandle = $children; } }; } @@ -197,5 +198,9 @@ impl_containerHandle!(CaseID, DataflowOpID); impl_containerHandle!(ModuleRootID, ModuleID); impl_containerHandle!(CfgID, BasicBlockID); impl_containerHandle!(BasicBlockID, DataflowOpID); -impl_containerHandle!(FuncID, DataflowOpID); -impl_containerHandle!(AliasID, DataflowOpID); +impl ContainerHandle for FuncID { + type ChildrenHandle = DataflowOpID; +} +impl ContainerHandle for AliasID { + type ChildrenHandle = DataflowOpID; +} From 70881b7c5a55613f0304f41ee7cae8236a8bd668 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 17 Apr 2025 10:38:12 +0100 Subject: [PATCH 06/21] feat!: remove ExtensionValue (#2093) Closes #1595 BREAKING CHANGE: `values` field in `Extension` and `ExtensionValue` struct/class removed in rust and python. Use 0-input ops that return constant values. 
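As a migration sketch (hedged; it mirrors the change made to `hugr-core/src/hugr/validate/test.rs` further down, and the `true_const` helper name is just for illustration), constants previously fetched from an extension's value table are now built directly as ordinary `Value`s:

```rust
use hugr_core::ops::{self, Value};

// Before (removed API): look the constant up in the extension's value registry,
// e.g. logic::EXTENSION.get_value(&logic::TRUE_NAME).unwrap().typed_value().clone().into()

// After: construct the constant value directly.
fn true_const() -> ops::Const {
    Value::from_bool(true).into()
}
```

Values that genuinely need extension-specific behaviour can instead be exposed as 0-input operations whose single output is the constant.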
--- hugr-core/src/extension.rs | 64 +------------------ .../src/extension/resolution/extension.rs | 11 +--- hugr-core/src/hugr/validate/test.rs | 8 +-- hugr-core/src/std_extensions/logic.rs | 26 +------- hugr-py/src/hugr/_serialization/extension.py | 21 ------ hugr-py/src/hugr/ext.py | 42 +----------- .../_json_defs/arithmetic/conversions.json | 1 - .../hugr/std/_json_defs/arithmetic/float.json | 1 - .../_json_defs/arithmetic/float/types.json | 1 - .../hugr/std/_json_defs/arithmetic/int.json | 1 - .../std/_json_defs/arithmetic/int/types.json | 1 - .../std/_json_defs/collections/array.json | 1 - .../hugr/std/_json_defs/collections/list.json | 1 - .../_json_defs/collections/static_array.json | 1 - hugr-py/src/hugr/std/_json_defs/logic.json | 28 -------- hugr-py/src/hugr/std/_json_defs/prelude.json | 1 - hugr-py/src/hugr/std/_json_defs/ptr.json | 1 - specification/schema/hugr_schema_live.json | 30 --------- .../schema/hugr_schema_strict_live.json | 30 --------- .../schema/testing_hugr_schema_live.json | 30 --------- .../testing_hugr_schema_strict_live.json | 30 --------- .../arithmetic/conversions.json | 1 - .../std_extensions/arithmetic/float.json | 1 - .../arithmetic/float/types.json | 1 - .../std_extensions/arithmetic/int.json | 1 - .../std_extensions/arithmetic/int/types.json | 1 - .../std_extensions/collections/array.json | 1 - .../std_extensions/collections/list.json | 1 - .../collections/static_array.json | 1 - specification/std_extensions/logic.json | 28 -------- specification/std_extensions/prelude.json | 1 - specification/std_extensions/ptr.json | 1 - 32 files changed, 7 insertions(+), 361 deletions(-) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index b6e059050..23238ccfd 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -19,9 +19,8 @@ use derive_more::Display; use thiserror::Error; use crate::hugr::IdentList; -use crate::ops::constant::{ValueName, ValueNameRef}; use crate::ops::custom::{ExtensionOp, OpaqueOp}; -use crate::ops::{self, OpName, OpNameRef}; +use crate::ops::{OpName, OpNameRef}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::RowVariable; use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName}; @@ -497,37 +496,6 @@ impl CustomConcrete for CustomType { } } -/// A constant value provided by a extension. -/// Must be an instance of a type available to the extension. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct ExtensionValue { - extension: ExtensionId, - name: ValueName, - typed_value: ops::Value, -} - -impl ExtensionValue { - /// Returns a reference to the typed value of this [`ExtensionValue`]. - pub fn typed_value(&self) -> &ops::Value { - &self.typed_value - } - - /// Returns a mutable reference to the typed value of this [`ExtensionValue`]. - pub(super) fn typed_value_mut(&mut self) -> &mut ops::Value { - &mut self.typed_value - } - - /// Returns a reference to the name of this [`ExtensionValue`]. - pub fn name(&self) -> &str { - self.name.as_str() - } - - /// Returns a reference to the extension this [`ExtensionValue`] belongs to. - pub fn extension(&self) -> &ExtensionId { - &self.extension - } -} - /// A unique identifier for a extension. /// /// The actual [`Extension`] is stored externally. @@ -583,8 +551,6 @@ pub struct Extension { pub runtime_reqs: ExtensionSet, /// Types defined by this extension. types: BTreeMap, - /// Static values defined by this extension. 
- values: BTreeMap, /// Operation declarations with serializable definitions. // Note: serde will serialize this because we configure with `features=["rc"]`. // That will clone anything that has multiple references, but each @@ -608,7 +574,6 @@ impl Extension { version, runtime_reqs: Default::default(), types: Default::default(), - values: Default::default(), operations: Default::default(), } } @@ -680,11 +645,6 @@ impl Extension { self.types.get(type_name) } - /// Allows read-only access to the values in this Extension - pub fn get_value(&self, value_name: &ValueNameRef) -> Option<&ExtensionValue> { - self.values.get(value_name) - } - /// Returns the name of the extension. pub fn name(&self) -> &ExtensionId { &self.name @@ -705,25 +665,6 @@ impl Extension { self.types.iter() } - /// Add a named static value to the extension. - pub fn add_value( - &mut self, - name: impl Into, - typed_value: ops::Value, - ) -> Result<&mut ExtensionValue, ExtensionBuildError> { - let extension_value = ExtensionValue { - extension: self.name.clone(), - name: name.into(), - typed_value, - }; - match self.values.entry(extension_value.name.clone()) { - btree_map::Entry::Occupied(_) => { - Err(ExtensionBuildError::ValueExists(extension_value.name)) - } - btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension_value)), - } - } - /// Instantiate an [`ExtensionOp`] which references an [`OpDef`] in this extension. pub fn instantiate_extension_op( &self, @@ -784,9 +725,6 @@ pub enum ExtensionBuildError { /// Existing [`TypeDef`] #[error("Extension already has an type called {0}.")] TypeDefExists(TypeName), - /// Existing [`ExtensionValue`] - #[error("Extension already has an extension value called {0}.")] - ValueExists(ValueName), } /// A set of extensions identified by their unique [`ExtensionId`]. 
diff --git a/hugr-core/src/extension/resolution/extension.rs b/hugr-core/src/extension/resolution/extension.rs index 61adc1dea..05c0faf69 100644 --- a/hugr-core/src/extension/resolution/extension.rs +++ b/hugr-core/src/extension/resolution/extension.rs @@ -9,7 +9,7 @@ use std::sync::Arc; use crate::extension::{Extension, ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, TypeDef}; -use super::types_mut::{resolve_signature_exts, resolve_value_exts}; +use super::types_mut::resolve_signature_exts; use super::{ExtensionResolutionError, WeakExtensionRegistry}; impl ExtensionRegistry { @@ -59,14 +59,7 @@ impl Extension { for type_def in self.types.values_mut() { resolve_typedef_exts(&self.name, type_def, extensions, &mut used_extensions)?; } - for val in self.values.values_mut() { - resolve_value_exts( - None, - val.typed_value_mut(), - extensions, - &mut used_extensions, - )?; - } + let ops = mem::take(&mut self.operations); for (op_id, mut op_def) in ops { // TODO: We should be able to clone the definition if needed by using `make_mut`, diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index ecb417ec5..37157020d 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -20,7 +20,6 @@ use crate::ops::handle::NodeHandle; use crate::ops::{self, OpType, Value}; use crate::std_extensions::logic::test::{and_op, or_op}; use crate::std_extensions::logic::LogicOp; -use crate::std_extensions::logic::{self}; use crate::types::type_param::{TypeArg, TypeArgError}; use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, TypeRV, @@ -307,12 +306,7 @@ fn test_local_const() { port_kind: EdgeKind::Value(bool_t()) }) ); - let const_op: ops::Const = logic::EXTENSION - .get_value(&logic::TRUE_NAME) - .unwrap() - .typed_value() - .clone() - .into(); + let const_op: ops::Const = ops::Value::from_bool(true).into(); // Second input of Xor from a constant let cst = h.add_node_with_parent(h.root(), const_op); let lcst = h.add_node_with_parent(h.root(), ops::LoadConstant { datatype: bool_t() }); diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index fcc8be9d3..20977cb51 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -124,13 +124,6 @@ pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); fn extension() -> Arc { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { LogicOp::load_all_ops(extension, extension_ref).unwrap(); - - extension - .add_value(FALSE_NAME, ops::Value::false_val()) - .unwrap(); - extension - .add_value(TRUE_NAME, ops::Value::true_val()) - .unwrap(); }) } @@ -172,12 +165,9 @@ fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option> { pub(crate) mod test { use std::sync::Arc; - use super::{extension, LogicOp, FALSE_NAME, TRUE_NAME}; + use super::{extension, LogicOp}; use crate::{ - extension::{ - prelude::bool_t, - simple_op::{MakeOpDef, MakeRegisteredOp}, - }, + extension::simple_op::{MakeOpDef, MakeRegisteredOp}, ops::{NamedOp, Value}, Extension, }; @@ -207,18 +197,6 @@ pub(crate) mod test { } } - #[test] - fn test_values() { - let r: Arc = extension(); - let false_val = r.get_value(&FALSE_NAME).unwrap(); - let true_val = r.get_value(&TRUE_NAME).unwrap(); - - for v in [false_val, true_val] { - let simpl = v.typed_value().get_type(); - assert_eq!(simpl, bool_t()); - } - } - /// Generate a logic extension "and" operation over [`crate::prelude::bool_t()`] 
pub(crate) fn and_op() -> LogicOp { LogicOp::And diff --git a/hugr-py/src/hugr/_serialization/extension.py b/hugr-py/src/hugr/_serialization/extension.py index 429bdd785..95e59754e 100644 --- a/hugr-py/src/hugr/_serialization/extension.py +++ b/hugr-py/src/hugr/_serialization/extension.py @@ -8,7 +8,6 @@ from hugr.hugr.base import Hugr from hugr.utils import deser_it -from .ops import Value from .serial_hugr import SerialHugr, serialization_version from .tys import ( ConfiguredBaseModel, @@ -20,7 +19,6 @@ ) if TYPE_CHECKING: - from .ops import Value from .serial_hugr import SerialHugr @@ -62,20 +60,6 @@ def deserialize(self, extension: ext.Extension) -> ext.TypeDef: ) -class ExtensionValue(ConfiguredBaseModel): - extension: ExtensionId - name: str - typed_value: Value - - def deserialize(self, extension: ext.Extension) -> ext.ExtensionValue: - return extension.add_extension_value( - ext.ExtensionValue( - name=self.name, - val=self.typed_value.deserialize(), - ) - ) - - # -------------------------------------- # --------------- OpDef ---------------- # -------------------------------------- @@ -124,7 +108,6 @@ class Extension(ConfiguredBaseModel): name: ExtensionId runtime_reqs: set[ExtensionId] types: dict[str, TypeDef] - values: dict[str, ExtensionValue] operations: dict[str, OpDef] @classmethod @@ -146,10 +129,6 @@ def deserialize(self) -> ext.Extension: assert k == o.name, "Operation name must match key" e.add_op_def(o.deserialize(e)) - for k, v in self.values.items(): - assert k == v.name, "Value name must match key" - e.add_extension_value(v.deserialize(e)) - return e diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 494ea3c69..7bd02f982 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -8,7 +8,7 @@ from semver import Version import hugr._serialization.extension as ext_s -from hugr import ops, tys, val +from hugr import ops, tys from hugr.utils import ser_it __all__ = [ @@ -18,7 +18,6 @@ "FixedHugr", "OpDefSig", "OpDef", - "ExtensionValue", "Extension", "Version", ] @@ -246,23 +245,6 @@ def instantiate( return ops.ExtOp(self, concrete_signature, list(args or [])) -@dataclass -class ExtensionValue(ExtensionObject): - """A value defined in an :class:`Extension`.""" - - #: The name of the value. - name: str - #: Value payload. - val: val.Value - - def _to_serial(self) -> ext_s.ExtensionValue: - return ext_s.ExtensionValue( - extension=self.get_extension().name, - name=self.name, - typed_value=self.val._to_serial_root(), - ) - - T = TypeVar("T", bound=ops.RegisteredOp) @@ -278,8 +260,6 @@ class Extension: runtime_reqs: set[ExtensionId] = field(default_factory=set) #: Type definitions in the extension. types: dict[str, TypeDef] = field(default_factory=dict) - #: Values defined in the extension. - values: dict[str, ExtensionValue] = field(default_factory=dict) #: Operation definitions in the extension. operations: dict[str, OpDef] = field(default_factory=dict) @@ -295,7 +275,6 @@ def _to_serial(self) -> ext_s.Extension: version=self.version, # type: ignore[arg-type] runtime_reqs=self.runtime_reqs, types={k: v._to_serial() for k, v in self.types.items()}, - values={k: v._to_serial() for k, v in self.values.items()}, operations={k: v._to_serial() for k, v in self.operations.items()}, ) @@ -347,19 +326,6 @@ def add_type_def(self, type_def: TypeDef) -> TypeDef: self.types[type_def.name] = type_def return self.types[type_def.name] - def add_extension_value(self, extension_value: ExtensionValue) -> ExtensionValue: - """Add a value to the extension. 
- - Args: - extension_value: The value to add. - - Returns: - The added value, now associated with the extension. - """ - extension_value._extension = self - self.values[extension_value.name] = extension_value - return self.values[extension_value.name] - @dataclass class OperationNotFound(NotFound): """Operation not found in extension.""" @@ -406,12 +372,6 @@ def get_type(self, name: str) -> TypeDef: class ValueNotFound(NotFound): """Value not found in extension.""" - def get_value(self, name: str) -> ExtensionValue: - try: - return self.values[name] - except KeyError as e: - raise self.ValueNotFound(name) from e - T = TypeVar("T", bound=ops.RegisteredOp) def register_op( diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json index 9c0054354..1d310df25 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json @@ -6,7 +6,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "bytecast_float64_to_int64": { "extension": "arithmetic.conversions", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json index 31ccaaa59..8da056772 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "fabs": { "extension": "arithmetic.float", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json index 56e35c50b..0c563c474 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json @@ -14,6 +14,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json index 62d0a6663..5b1a81250 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "iabs": { "extension": "arithmetic.int", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json index 60cf69f63..36df125a6 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json @@ -19,6 +19,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/hugr-py/src/hugr/std/_json_defs/collections/array.json b/hugr-py/src/hugr/std/_json_defs/collections/array.json index 21e405151..375e13c72 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/array.json @@ -25,7 +25,6 @@ } } }, - "values": {}, "operations": { "discard_empty": { "extension": "collections.array", diff --git a/hugr-py/src/hugr/std/_json_defs/collections/list.json b/hugr-py/src/hugr/std/_json_defs/collections/list.json index 0fbafc638..8a60d3544 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/list.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/list.json @@ -21,7 +21,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.list", diff --git a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json 
b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json index e4669f671..53b8e61c7 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.static_array", diff --git a/hugr-py/src/hugr/std/_json_defs/logic.json b/hugr-py/src/hugr/std/_json_defs/logic.json index ad9f02019..ff29d2c21 100644 --- a/hugr-py/src/hugr/std/_json_defs/logic.json +++ b/hugr-py/src/hugr/std/_json_defs/logic.json @@ -3,34 +3,6 @@ "name": "logic", "runtime_reqs": [], "types": {}, - "values": { - "FALSE": { - "extension": "logic", - "name": "FALSE", - "typed_value": { - "v": "Sum", - "tag": 0, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - }, - "TRUE": { - "extension": "logic", - "name": "TRUE", - "typed_value": { - "v": "Sum", - "tag": 1, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - } - }, "operations": { "And": { "extension": "logic", diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index e11ba2388..ec392b155 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -44,7 +44,6 @@ } } }, - "values": {}, "operations": { "Barrier": { "extension": "prelude", diff --git a/hugr-py/src/hugr/std/_json_defs/ptr.json b/hugr-py/src/hugr/std/_json_defs/ptr.json index 18b1f26b6..614b6aecf 100644 --- a/hugr-py/src/hugr/std/_json_defs/ptr.json +++ b/hugr-py/src/hugr/std/_json_defs/ptr.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "New": { "extension": "ptr", diff --git a/specification/schema/hugr_schema_live.json b/specification/schema/hugr_schema_live.json index 9e7d8c40c..ea08dff5b 100644 --- a/specification/schema/hugr_schema_live.json +++ b/specification/schema/hugr_schema_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": true, "properties": { diff --git a/specification/schema/hugr_schema_strict_live.json b/specification/schema/hugr_schema_strict_live.json index 6f436f969..8b65bae94 100644 --- a/specification/schema/hugr_schema_strict_live.json +++ b/specification/schema/hugr_schema_strict_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": 
"string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": false, "properties": { diff --git a/specification/schema/testing_hugr_schema_live.json b/specification/schema/testing_hugr_schema_live.json index bc067d40e..91b121da6 100644 --- a/specification/schema/testing_hugr_schema_live.json +++ b/specification/schema/testing_hugr_schema_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": true, "properties": { diff --git a/specification/schema/testing_hugr_schema_strict_live.json b/specification/schema/testing_hugr_schema_strict_live.json index 47c9778d3..eae6a13a7 100644 --- a/specification/schema/testing_hugr_schema_strict_live.json +++ b/specification/schema/testing_hugr_schema_strict_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": false, "properties": { diff --git a/specification/std_extensions/arithmetic/conversions.json b/specification/std_extensions/arithmetic/conversions.json index 9c0054354..1d310df25 100644 --- a/specification/std_extensions/arithmetic/conversions.json +++ b/specification/std_extensions/arithmetic/conversions.json @@ -6,7 +6,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "bytecast_float64_to_int64": { "extension": "arithmetic.conversions", diff --git a/specification/std_extensions/arithmetic/float.json b/specification/std_extensions/arithmetic/float.json index 31ccaaa59..8da056772 100644 --- a/specification/std_extensions/arithmetic/float.json +++ b/specification/std_extensions/arithmetic/float.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "fabs": { "extension": "arithmetic.float", diff --git a/specification/std_extensions/arithmetic/float/types.json b/specification/std_extensions/arithmetic/float/types.json index 56e35c50b..0c563c474 100644 --- a/specification/std_extensions/arithmetic/float/types.json +++ 
b/specification/std_extensions/arithmetic/float/types.json @@ -14,6 +14,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/specification/std_extensions/arithmetic/int.json b/specification/std_extensions/arithmetic/int.json index 62d0a6663..5b1a81250 100644 --- a/specification/std_extensions/arithmetic/int.json +++ b/specification/std_extensions/arithmetic/int.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "iabs": { "extension": "arithmetic.int", diff --git a/specification/std_extensions/arithmetic/int/types.json b/specification/std_extensions/arithmetic/int/types.json index 60cf69f63..36df125a6 100644 --- a/specification/std_extensions/arithmetic/int/types.json +++ b/specification/std_extensions/arithmetic/int/types.json @@ -19,6 +19,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/specification/std_extensions/collections/array.json b/specification/std_extensions/collections/array.json index 21e405151..375e13c72 100644 --- a/specification/std_extensions/collections/array.json +++ b/specification/std_extensions/collections/array.json @@ -25,7 +25,6 @@ } } }, - "values": {}, "operations": { "discard_empty": { "extension": "collections.array", diff --git a/specification/std_extensions/collections/list.json b/specification/std_extensions/collections/list.json index 0fbafc638..8a60d3544 100644 --- a/specification/std_extensions/collections/list.json +++ b/specification/std_extensions/collections/list.json @@ -21,7 +21,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.list", diff --git a/specification/std_extensions/collections/static_array.json b/specification/std_extensions/collections/static_array.json index e4669f671..53b8e61c7 100644 --- a/specification/std_extensions/collections/static_array.json +++ b/specification/std_extensions/collections/static_array.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.static_array", diff --git a/specification/std_extensions/logic.json b/specification/std_extensions/logic.json index ad9f02019..ff29d2c21 100644 --- a/specification/std_extensions/logic.json +++ b/specification/std_extensions/logic.json @@ -3,34 +3,6 @@ "name": "logic", "runtime_reqs": [], "types": {}, - "values": { - "FALSE": { - "extension": "logic", - "name": "FALSE", - "typed_value": { - "v": "Sum", - "tag": 0, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - }, - "TRUE": { - "extension": "logic", - "name": "TRUE", - "typed_value": { - "v": "Sum", - "tag": 1, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - } - }, "operations": { "And": { "extension": "logic", diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index e11ba2388..ec392b155 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -44,7 +44,6 @@ } } }, - "values": {}, "operations": { "Barrier": { "extension": "prelude", diff --git a/specification/std_extensions/ptr.json b/specification/std_extensions/ptr.json index 18b1f26b6..614b6aecf 100644 --- a/specification/std_extensions/ptr.json +++ b/specification/std_extensions/ptr.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "New": { "extension": "ptr", From 1121fb0fd5ef808fe7eb1ae7b3b13f143024d243 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Apr 2025 17:16:31 +0100 Subject: [PATCH 07/21] feat!: ComposablePass trait allowing sequencing and validation (#1895) Currently We have several "passes": 
monomorphization, dead function removal, constant folding. Each has its own code to allow setting a validation level (before and after that pass). This PR adds the ability chain (sequence) passes;, and to add validation before+after any pass or sequence; and commons up validation code. The top-level `constant_fold_pass` (etc.) functions are left as wrappers that do a single pass with validation only in test. I've left ConstFoldPass as always including DCE, but an alternative could be to return a sequence of the two - ATM that means a tuple `(ConstFoldPass, DeadCodeElimPass)`. I also wondered about including a method `add_entry_point` in ComposablePass (e.g. for ConstFoldPass, that means `with_inputs` but no inputs, i.e. all Top). I feel this is not applicable to *all* passes, but near enough. This could be done in a later PR but `add_entry_point` would need a no-op default for that to be a non-breaking change. So if we wouldn't be happy with the no-op default then I could just add it here... Finally...docs are extremely minimal ATM (this is hugr-passes), I am hoping that most of this is reasonably obvious (it doesn't really do a lot!), but please flag anything you think is particularly in need of a doc comment! BREAKING CHANGE: quite a lot of calls to current pass routines will break, specific cases include (a) `with_validation_level` should be done by wrapping a ValidatingPass around the receiver; (b) XXXPass::run() requires `use ...ComposablePass` (however, such calls will cease to do any validation). closes #1832 --- hugr-passes/src/composable.rs | 361 +++++++++++++++++++++ hugr-passes/src/const_fold.rs | 45 +-- hugr-passes/src/const_fold/test.rs | 1 + hugr-passes/src/dead_code.rs | 50 ++- hugr-passes/src/dead_funcs.rs | 77 ++--- hugr-passes/src/lib.rs | 12 +- hugr-passes/src/monomorphize.rs | 92 ++---- hugr-passes/src/replace_types.rs | 105 +++--- hugr-passes/src/replace_types/linearize.rs | 2 +- hugr-passes/src/untuple.rs | 70 ++-- 10 files changed, 550 insertions(+), 265 deletions(-) create mode 100644 hugr-passes/src/composable.rs diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs new file mode 100644 index 000000000..fb3319155 --- /dev/null +++ b/hugr-passes/src/composable.rs @@ -0,0 +1,361 @@ +//! Compiler passes and utilities for composing them + +use std::{error::Error, marker::PhantomData}; + +use hugr_core::hugr::{hugrmut::HugrMut, ValidationError}; +use hugr_core::HugrView; +use itertools::Either; + +/// An optimization pass that can be sequenced with another and/or wrapped +/// e.g. by [ValidatingPass] +pub trait ComposablePass: Sized { + type Error: Error; + type Result; // Would like to default to () but currently unstable + + fn run(&self, hugr: &mut impl HugrMut) -> Result; + + fn map_err( + self, + f: impl Fn(Self::Error) -> E2, + ) -> impl ComposablePass { + ErrMapper::new(self, f) + } + + /// Returns a [ComposablePass] that does "`self` then `other`", so long as + /// `other::Err` can be combined with ours. 
+ fn then>( + self, + other: P, + ) -> impl ComposablePass { + struct Sequence(P1, P2, PhantomData); + impl ComposablePass for Sequence + where + P1: ComposablePass, + P2: ComposablePass, + E: ErrorCombiner, + { + type Error = E; + + type Result = (P1::Result, P2::Result); + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let res1 = self.0.run(hugr).map_err(E::from_first)?; + let res2 = self.1.run(hugr).map_err(E::from_second)?; + Ok((res1, res2)) + } + } + + Sequence(self, other, PhantomData) + } +} + +/// Trait for combining the error types from two different passes +/// into a single error. +pub trait ErrorCombiner: Error { + fn from_first(a: A) -> Self; + fn from_second(b: B) -> Self; +} + +impl> ErrorCombiner for A { + fn from_first(a: A) -> Self { + a + } + + fn from_second(b: B) -> Self { + b.into() + } +} + +impl ErrorCombiner for Either { + fn from_first(a: A) -> Self { + Either::Left(a) + } + + fn from_second(b: B) -> Self { + Either::Right(b) + } +} + +// Note: in the short term we could wish for two more impls: +// impl ErrorCombiner for E +// impl ErrorCombiner for E +// however, these aren't possible as they conflict with +// impl> ErrorCombiner for A +// when A=E=Infallible, boo :-(. +// However this will become possible, indeed automatic, when Infallible is replaced +// by ! (never_type) as (unlike Infallible) ! converts Into anything + +// ErrMapper ------------------------------ +struct ErrMapper(P, F, PhantomData); + +impl E> ErrMapper { + fn new(pass: P, err_fn: F) -> Self { + Self(pass, err_fn, PhantomData) + } +} + +impl E> ComposablePass for ErrMapper { + type Error = E; + type Result = P::Result; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + self.0.run(hugr).map_err(&self.1) + } +} + +// ValidatingPass ------------------------------ + +/// Error from a [ValidatingPass] +#[derive(thiserror::Error, Debug)] +pub enum ValidatePassError { + #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] + Input { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")] + Output { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error(transparent)] + Underlying(#[from] E), +} + +/// Runs an underlying pass, but with validation of the Hugr +/// both before and afterwards. +pub struct ValidatingPass

<P>(P, bool); + +impl<P: ComposablePass> ValidatingPass<P>

{ + pub fn new_default(underlying: P) -> Self { + // Self(underlying, cfg!(feature = "extension_inference")) + // Sadly, many tests fail with extension inference, hence: + Self(underlying, false) + } + + pub fn new_validating_extensions(underlying: P) -> Self { + Self(underlying, true) + } + + pub fn new(underlying: P, validate_extensions: bool) -> Self { + Self(underlying, validate_extensions) + } + + fn validation_impl( + &self, + hugr: &impl HugrView, + mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError, + ) -> Result<(), ValidatePassError> { + match self.1 { + false => hugr.validate_no_extensions(), + true => hugr.validate(), + } + .map_err(|err| mk_err(err, hugr.mermaid_string())) + } +} + +impl ComposablePass for ValidatingPass

{ + type Error = ValidatePassError; + type Result = P::Result; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input { + err, + pretty_hugr, + })?; + let res = self.0.run(hugr).map_err(ValidatePassError::Underlying)?; + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output { + err, + pretty_hugr, + })?; + Ok(res) + } +} + +// IfThen ------------------------------ +/// [ComposablePass] that executes a first pass that returns a `bool` +/// result; and then, if-and-only-if that first result was true, +/// executes a second pass +pub struct IfThen(A, B, PhantomData); + +impl, B: ComposablePass, E: ErrorCombiner> + IfThen +{ + /// Make a new instance given the [ComposablePass] to run first + /// and (maybe) second + pub fn new(fst: A, opt_snd: B) -> Self { + Self(fst, opt_snd, PhantomData) + } +} + +impl, B: ComposablePass, E: ErrorCombiner> + ComposablePass for IfThen +{ + type Error = E; + + type Result = Option; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?; + res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second)) + .transpose() + } +} + +pub(crate) fn validate_if_test( + pass: P, + hugr: &mut impl HugrMut, +) -> Result> { + if cfg!(test) { + ValidatingPass::new_default(pass).run(hugr) + } else { + pass.run(hugr).map_err(ValidatePassError::Underlying) + } +} + +#[cfg(test)] +mod test { + use itertools::{Either, Itertools}; + use std::convert::Infallible; + + use hugr_core::builder::{ + Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, + ModuleBuilder, + }; + use hugr_core::extension::prelude::{ + bool_t, usize_t, ConstUsize, MakeTuple, UnpackTuple, PRELUDE_ID, + }; + use hugr_core::hugr::hugrmut::HugrMut; + use hugr_core::ops::{handle::NodeHandle, Input, OpType, Output, DEFAULT_OPTYPE, DFG}; + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; + use hugr_core::types::{Signature, TypeRow}; + use hugr_core::{Hugr, HugrView, IncomingPort}; + + use crate::const_fold::{ConstFoldError, ConstantFoldPass}; + use crate::untuple::{UntupleRecursive, UntupleResult}; + use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass}; + + use super::{validate_if_test, ComposablePass, IfThen, ValidatePassError, ValidatingPass}; + + #[test] + fn test_then() { + let mut mb = ModuleBuilder::new(); + let id1 = mb + .define_function("id1", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id1.input_wires(); + let id1 = id1.finish_with_outputs(inps).unwrap(); + let id2 = mb + .define_function("id2", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id2.input_wires(); + let id2 = id2.finish_with_outputs(inps).unwrap(); + let hugr = mb.finish_hugr().unwrap(); + + let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]); + let cfold = + ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]); + + cfold.run(&mut hugr.clone()).unwrap(); + + let exp_err = ConstFoldError::InvalidEntryPoint(id2.node(), DEFAULT_OPTYPE); + let r: Result<_, Either> = + dce.clone().then(cfold.clone()).run(&mut hugr.clone()); + assert_eq!(r, Err(Either::Right(exp_err.clone()))); + + let r = dce + .clone() + .map_err(|inf| match inf {}) + .then(cfold.clone()) + .run(&mut hugr.clone()); + assert_eq!(r, Err(exp_err)); + + let r2: Result<_, Either<_, _>> = cfold.then(dce).run(&mut hugr.clone()); + r2.unwrap(); + } + + #[test] + fn test_validation() { + let mut h = 
Hugr::new(DFG { + signature: Signature::new(usize_t(), bool_t()), + }); + let inp = h.add_node_with_parent( + h.root(), + Input { + types: usize_t().into(), + }, + ); + let outp = h.add_node_with_parent( + h.root(), + Output { + types: bool_t().into(), + }, + ); + h.connect(inp, 0, outp, 0); + let backup = h.clone(); + let err = backup.validate().unwrap_err(); + + let no_inputs: [(IncomingPort, _); 0] = []; + let cfold = ConstantFoldPass::default().with_inputs(backup.root(), no_inputs); + cfold.run(&mut h).unwrap(); + assert_eq!(h, backup); // Did nothing + + let r = ValidatingPass(cfold, false).run(&mut h); + assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err)); + } + + #[test] + fn test_if_then() { + let tr = TypeRow::from(vec![usize_t(); 2]); + + let h = { + let sig = Signature::new_endo(tr.clone()).with_extension_delta(PRELUDE_ID); + let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap(); + let [a, b] = fb.input_wires_arr(); + let tup = fb + .add_dataflow_op(MakeTuple::new(tr.clone()), [a, b]) + .unwrap(); + let untup = fb + .add_dataflow_op(UnpackTuple::new(tr.clone()), tup.outputs()) + .unwrap(); + fb.finish_hugr_with_outputs(untup.outputs()).unwrap() + }; + + let untup = UntuplePass::new(UntupleRecursive::Recursive); + { + // Change usize_t to INT_TYPES[6], and if that did anything (it will!), then Untuple + let mut repl = ReplaceTypes::default(); + let usize_custom_t = usize_t().as_extension().unwrap().clone(); + repl.replace_type(usize_custom_t, INT_TYPES[6].clone()); + let ifthen = IfThen::, _, _>::new(repl, untup.clone()); + + let mut h = h.clone(); + let r = validate_if_test(ifthen, &mut h).unwrap(); + assert_eq!( + r, + Some(UntupleResult { + rewrites_applied: 1 + }) + ); + let [tuple_in, tuple_out] = h.children(h.root()).collect_array().unwrap(); + assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]); + } + + // Change INT_TYPES[5] to INT_TYPES[6]; that won't do anything, so don't Untuple + let mut repl = ReplaceTypes::default(); + let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone(); + repl.replace_type(i32_custom_t, INT_TYPES[6].clone()); + let ifthen = IfThen::, _, _>::new(repl, untup); + let mut h = h; + let r = validate_if_test(ifthen, &mut h).unwrap(); + assert_eq!(r, None); + assert_eq!(h.children(h.root()).count(), 4); + let mktup = h + .output_neighbours(h.first_child(h.root()).unwrap()) + .next() + .unwrap(); + assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(tr))); + } +} diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index e73e3cd0e..99ccc180c 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -21,12 +21,11 @@ use crate::dataflow::{ TailLoopTermination, }; use crate::dead_code::{DeadCodeElimPass, PreserveNode}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{composable::validate_if_test, ComposablePass}; #[derive(Debug, Clone, Default)] /// A configuration for the Constant Folding pass. pub struct ConstantFoldPass { - validation: ValidationLevel, allow_increase_termination: bool, /// Each outer key Node must be either: /// - a FuncDefn child of the root, if the root is a module; or @@ -34,13 +33,10 @@ pub struct ConstantFoldPass { inputs: HashMap>, } -#[derive(Debug, Error)] +#[derive(Clone, Debug, Error, PartialEq)] #[non_exhaustive] /// Errors produced by [ConstantFoldPass]. 
pub enum ConstFoldError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), /// Error raised when a Node is specified as an entry-point but /// is neither a dataflow parent, nor a [CFG](OpType::CFG), nor /// a [Conditional](OpType::Conditional). @@ -49,12 +45,6 @@ pub enum ConstFoldError { } impl ConstantFoldPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their /// result (if/when they do terminate) is either known or not needed. /// @@ -86,9 +76,19 @@ impl ConstantFoldPass { .extend(inputs.into_iter().map(|(p, v)| (p.into(), v))); self } +} + +impl ComposablePass for ConstantFoldPass { + type Error = ConstFoldError; + type Result = (); /// Run the Constant Folding pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { + /// + /// # Errors + /// + /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] + /// was of an invalid [OpType] + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { let fresh_node = Node::from(portgraph::NodeIndex::new( hugr.nodes().max().map_or(0, |n| n.index() + 1), )); @@ -164,23 +164,10 @@ impl ConstantFoldPass { } }) }) - .run(hugr)?; + .run(hugr) + .map_err(|inf| match inf {})?; // TODO use into_ok when available Ok(()) } - - /// Run the pass using this configuration. - /// - /// # Errors - /// - /// [ConstFoldError::ValidationError] if the Hugr does not validate before/afnerwards - /// (if [Self::validation_level] is set, or in tests) - /// - /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] - /// was of an invalid OpType - pub fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } /// Exhaustively apply constant folding to a HUGR. @@ -198,7 +185,7 @@ pub fn constant_fold_pass(h: &mut H) { } else { c }; - c.run(h).unwrap() + validate_if_test(c, h).unwrap() } struct ConstFoldContext; diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 58e69c568..ff5cd93a5 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -32,6 +32,7 @@ use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV}; use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; use crate::dataflow::{partial_from_const, DFContext, PartialValue}; +use crate::ComposablePass as _; use super::{constant_fold_pass, ConstFoldContext, ConstantFoldPass, ValueHandle}; diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index b714dd6fd..899e30243 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -1,13 +1,14 @@ //! Pass for removing dead code, i.e. 
that computes values that are then discarded use hugr_core::{hugr::hugrmut::HugrMut, ops::OpType, Hugr, HugrView, Node}; +use std::convert::Infallible; use std::fmt::{Debug, Formatter}; use std::{ collections::{HashMap, HashSet, VecDeque}, sync::Arc, }; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; /// Configuration for Dead Code Elimination pass #[derive(Clone)] @@ -18,7 +19,6 @@ pub struct DeadCodeElimPass { /// Callback identifying nodes that must be preserved even if their /// results are not used. Defaults to [PreserveNode::default_for]. preserve_callback: Arc, - validation: ValidationLevel, } impl Default for DeadCodeElimPass { @@ -26,7 +26,6 @@ impl Default for DeadCodeElimPass { Self { entry_points: Default::default(), preserve_callback: Arc::new(PreserveNode::default_for), - validation: ValidationLevel::default(), } } } @@ -39,13 +38,11 @@ impl Debug for DeadCodeElimPass { #[derive(Debug)] struct DCEDebug<'a> { entry_points: &'a Vec, - validation: ValidationLevel, } Debug::fmt( &DCEDebug { entry_points: &self.entry_points, - validation: self.validation, }, f, ) @@ -86,13 +83,6 @@ impl PreserveNode { } impl DeadCodeElimPass { - /// Sets the validation level used before and after the pass is run - #[allow(unused)] - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows setting a callback that determines whether a node must be preserved /// (even when its result is not used) pub fn set_preserve_callback(mut self, cb: Arc) -> Self { @@ -146,24 +136,6 @@ impl DeadCodeElimPass { needed } - pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { - self.validation.run_validated_pass(hugr, |h, _| { - self.run_no_validate(h); - Ok(()) - }) - } - - fn run_no_validate(&self, hugr: &mut impl HugrMut) { - let needed = self.find_needed_nodes(&*hugr); - let remove = hugr - .nodes() - .filter(|n| !needed.contains(n)) - .collect::>(); - for n in remove { - hugr.remove_node(n); - } - } - fn must_preserve( &self, h: &impl HugrView, @@ -185,6 +157,22 @@ impl DeadCodeElimPass { } } +impl ComposablePass for DeadCodeElimPass { + type Error = Infallible; + type Result = (); + + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Infallible> { + let needed = self.find_needed_nodes(&*hugr); + let remove = hugr + .nodes() + .filter(|n| !needed.contains(n)) + .collect::>(); + for n in remove { + hugr.remove_node(n); + } + Ok(()) + } +} #[cfg(test)] mod test { use std::sync::Arc; @@ -196,6 +184,8 @@ mod test { use hugr_core::{ops::Value, type_row, HugrView}; use itertools::Itertools; + use crate::ComposablePass; + use super::{DeadCodeElimPass, PreserveNode}; #[test] diff --git a/hugr-passes/src/dead_funcs.rs b/hugr-passes/src/dead_funcs.rs index b114a9e42..7071d5335 100644 --- a/hugr-passes/src/dead_funcs.rs +++ b/hugr-passes/src/dead_funcs.rs @@ -10,7 +10,10 @@ use hugr_core::{ }; use petgraph::visit::{Dfs, Walker}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{ + composable::{validate_if_test, ValidatePassError}, + ComposablePass, +}; use super::call_graph::{CallGraph, CallGraphNode}; @@ -26,9 +29,6 @@ pub enum RemoveDeadFuncsError { /// The invalid node. node: N, }, - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), } fn reachable_funcs<'a, H: HugrView>( @@ -64,17 +64,10 @@ fn reachable_funcs<'a, H: HugrView>( #[derive(Debug, Clone, Default)] /// A configuration for the Dead Function Removal pass. 
pub struct RemoveDeadFuncsPass { - validation: ValidationLevel, entry_points: Vec, } impl RemoveDeadFuncsPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Adds new entry points - these must be [FuncDefn] nodes /// that are children of the [Module] at the root of the Hugr. /// @@ -87,16 +80,32 @@ impl RemoveDeadFuncsPass { self.entry_points.extend(entry_points); self } +} - /// Runs the pass (see [remove_dead_funcs]) with this configuration - pub fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> { - self.validation.run_validated_pass(hugr, |hugr: &mut H, _| { - remove_dead_funcs(hugr, self.entry_points.iter().cloned()) - }) +impl ComposablePass for RemoveDeadFuncsPass { + type Error = RemoveDeadFuncsError; + type Result = (); + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), RemoveDeadFuncsError> { + let reachable = reachable_funcs( + &CallGraph::new(hugr), + hugr, + self.entry_points.iter().cloned(), + )? + .collect::>(); + let unreachable = hugr + .nodes() + .filter(|n| { + OpTag::Function.is_superset(hugr.get_optype(*n).tag()) && !reachable.contains(n) + }) + .collect::>(); + for n in unreachable { + hugr.remove_subtree(n); + } + Ok(()) } } -/// Delete from the Hugr any functions that are not used by either [Call] or +/// Deletes from the Hugr any functions that are not used by either [Call] or /// [LoadFunction] nodes in reachable parts. /// /// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points, @@ -118,16 +127,11 @@ impl RemoveDeadFuncsPass { pub fn remove_dead_funcs( h: &mut impl HugrMut, entry_points: impl IntoIterator, -) -> Result<(), RemoveDeadFuncsError> { - let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points)?.collect::>(); - let unreachable = h - .nodes() - .filter(|n| OpTag::Function.is_superset(h.get_optype(*n).tag()) && !reachable.contains(n)) - .collect::>(); - for n in unreachable { - h.remove_subtree(n); - } - Ok(()) +) -> Result<(), ValidatePassError> { + validate_if_test( + RemoveDeadFuncsPass::default().with_module_entry_points(entry_points), + h, + ) } #[cfg(test)] @@ -142,7 +146,7 @@ mod test { }; use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView}; - use super::RemoveDeadFuncsPass; + use super::remove_dead_funcs; #[rstest] #[case([], vec![])] // No entry_points removes everything! @@ -182,15 +186,14 @@ mod test { }) .collect::>(); - RemoveDeadFuncsPass::default() - .with_module_entry_points( - entry_points - .into_iter() - .map(|name| *avail_funcs.get(name).unwrap()) - .collect::>(), - ) - .run(&mut hugr) - .unwrap(); + remove_dead_funcs( + &mut hugr, + entry_points + .into_iter() + .map(|name| *avail_funcs.get(name).unwrap()) + .collect::>(), + ) + .unwrap(); let remaining_funcs = hugr .nodes() diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 961c4da47..83ff71b67 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,8 @@ //! Compilation passes acting on the HUGR program representation. pub mod call_graph; +pub mod composable; +pub use composable::ComposablePass; pub mod const_fold; pub mod dataflow; pub mod dead_code; @@ -21,19 +23,11 @@ pub mod untuple; )] #[allow(deprecated)] pub use monomorphize::remove_polyfuncs; -// TODO: Deprecated re-export. Remove on a breaking release. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." 
-)] -#[allow(deprecated)] -pub use monomorphize::monomorphize; -pub use monomorphize::{MonomorphizeError, MonomorphizePass}; +pub use monomorphize::{monomorphize, MonomorphizePass}; pub mod replace_types; pub use replace_types::ReplaceTypes; pub mod nest_cfgs; pub mod non_local; -pub mod validation; pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 4f4e9bda2..875ee9355 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -1,5 +1,6 @@ use std::{ collections::{hash_map::Entry, HashMap}, + convert::Infallible, fmt::Write, ops::Deref, }; @@ -12,7 +13,9 @@ use hugr_core::{ use hugr_core::hugr::{hugrmut::HugrMut, Hugr, HugrView, OpType}; use itertools::Itertools as _; -use thiserror::Error; + +use crate::composable::{validate_if_test, ValidatePassError}; +use crate::ComposablePass; /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. @@ -30,26 +33,8 @@ use thiserror::Error; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." -)] -// TODO: Deprecated. Remove on a breaking release and rename private `monomorphize_ref` to `monomorphize`. -pub fn monomorphize(mut h: Hugr) -> Hugr { - monomorphize_ref(&mut h); - h -} - -fn monomorphize_ref(h: &mut impl HugrMut) { - let root = h.root(); - // If the root is a polymorphic function, then there are no external calls, so nothing to do - if !is_polymorphic_funcdefn(h.get_optype(root)) { - mono_scan(h, root, None, &mut HashMap::new()); - if !h.get_optype(root).is_module() { - #[allow(deprecated)] // TODO remove in next breaking release and update docs - remove_polyfuncs_ref(h); - } - } +pub fn monomorphize(hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { + validate_if_test(MonomorphizePass, hugr) } /// Removes any polymorphic [FuncDefn]s from the Hugr. Note that if these have @@ -254,8 +239,6 @@ fn instantiate( mono_tgt } -use crate::validation::{ValidatePassError, ValidationLevel}; - /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. /// @@ -271,38 +254,25 @@ use crate::validation::{ValidatePassError, ValidationLevel}; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[derive(Debug, Clone, Default)] -pub struct MonomorphizePass { - validation: ValidationLevel, -} - -#[derive(Debug, Error)] -#[non_exhaustive] -/// Errors produced by [MonomorphizePass]. -pub enum MonomorphizeError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), -} - -impl MonomorphizePass { - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - - /// Run the Monomorphization pass. 
- fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), MonomorphizeError> { - monomorphize_ref(hugr); +#[derive(Debug, Clone)] +pub struct MonomorphizePass; + +impl ComposablePass for MonomorphizePass { + type Error = Infallible; + type Result = (); + + fn run(&self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + let root = h.root(); + // If the root is a polymorphic function, then there are no external calls, so nothing to do + if !is_polymorphic_funcdefn(h.get_optype(root)) { + mono_scan(h, root, None, &mut HashMap::new()); + if !h.get_optype(root).is_module() { + #[allow(deprecated)] // TODO remove in next breaking release and update docs + remove_polyfuncs_ref(h); + } + } Ok(()) } - - /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result<(), MonomorphizeError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } struct TypeArgsList<'a>(&'a [TypeArg]); @@ -387,9 +357,9 @@ mod test { use hugr_core::{Hugr, HugrView, Node}; use rstest::rstest; - use crate::remove_dead_funcs; + use crate::{monomorphize, remove_dead_funcs}; - use super::{is_polymorphic, mangle_inner_func, mangle_name, MonomorphizePass}; + use super::{is_polymorphic, mangle_inner_func, mangle_name}; fn pair_type(ty: Type) -> Type { Type::new_tuple(vec![ty.clone(), ty]) @@ -410,7 +380,7 @@ mod test { let [i1] = dfg_builder.input_wires_arr(); let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap(); let mut hugr2 = hugr.clone(); - MonomorphizePass::default().run(&mut hugr2).unwrap(); + monomorphize(&mut hugr2).unwrap(); assert_eq!(hugr, hugr2); } @@ -472,7 +442,7 @@ mod test { .count(), 3 ); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono = hugr; mono.validate()?; @@ -493,7 +463,7 @@ mod test { ["double", "main", "triple"] ); let mut mono2 = mono.clone(); - MonomorphizePass::default().run(&mut mono2)?; + monomorphize(&mut mono2)?; assert_eq!(mono2, mono); // Idempotent @@ -601,7 +571,7 @@ mod test { .outputs_arr(); let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); @@ -662,7 +632,7 @@ mod test { let mono = mono.finish_with_outputs([a, b]).unwrap(); let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap(); let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono_hugr = hugr; let mut funcs = list_funcs(&mono_hugr); @@ -719,7 +689,7 @@ mod test { module_builder.finish_hugr().unwrap() }; - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); remove_dead_funcs(&mut hugr, []).unwrap(); let funcs = list_funcs(&hugr); diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 3ed7337a9..e81a640e3 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -26,7 +26,7 @@ use hugr_core::types::{ }; use hugr_core::{Hugr, HugrView, Node, Wire}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; mod linearize; pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError, Linearizer}; @@ -143,7 +143,6 @@ pub struct ReplaceTypes { ParametricType, Arc Result, ReplaceTypesError>>, >, - validation: ValidationLevel, } impl Default for ReplaceTypes { @@ -184,8 +183,6 @@ 
pub enum ReplaceTypesError { #[error(transparent)] SignatureError(#[from] SignatureError), #[error(transparent)] - ValidationError(#[from] ValidatePassError), - #[error(transparent)] ConstError(#[from] ConstTypeError), #[error(transparent)] LinearizeError(#[from] LinearizeError), @@ -203,16 +200,9 @@ impl ReplaceTypes { param_ops: Default::default(), consts: Default::default(), param_consts: Default::default(), - validation: Default::default(), } } - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Configures this instance to replace occurrences of type `src` with `dest`. /// Note that if `src` is an instance of a *parametrized* [TypeDef], this takes /// precedence over [Self::replace_parametrized_type] where the `src`s overlap. Thus, this @@ -323,36 +313,6 @@ impl ReplaceTypes { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } - /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } - - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { - let mut changed = false; - for n in hugr.nodes().collect::>() { - changed |= self.change_node(hugr, n)?; - let new_dfsig = hugr.get_optype(n).dataflow_signature(); - if let Some(new_sig) = new_dfsig - .filter(|_| changed && n != hugr.root()) - .map(Cow::into_owned) - { - for outp in new_sig.output_ports() { - if !new_sig.out_port_type(outp).unwrap().copyable() { - let targets = hugr.linked_inputs(n, outp).collect::>(); - if targets.len() != 1 { - hugr.disconnect(n, outp); - let src = Wire::new(n, outp); - self.linearize.insert_copy_discard(hugr, src, &targets)?; - } - } - } - } - } - Ok(changed) - } - fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. 
}) @@ -472,11 +432,40 @@ impl ReplaceTypes { false } }), - Value::Function { hugr } => self.run_no_validate(&mut **hugr), + Value::Function { hugr } => self.run(&mut **hugr), } } } +impl ComposablePass for ReplaceTypes { + type Error = ReplaceTypesError; + type Result = bool; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let mut changed = false; + for n in hugr.nodes().collect::>() { + changed |= self.change_node(hugr, n)?; + let new_dfsig = hugr.get_optype(n).dataflow_signature(); + if let Some(new_sig) = new_dfsig + .filter(|_| changed && n != hugr.root()) + .map(Cow::into_owned) + { + for outp in new_sig.output_ports() { + if !new_sig.out_port_type(outp).unwrap().copyable() { + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() != 1 { + hugr.disconnect(n, outp); + let src = Wire::new(n, outp); + self.linearize.insert_copy_discard(hugr, src, &targets)?; + } + } + } + } + } + Ok(changed) + } +} + pub mod handlers; #[derive(Clone, Hash, PartialEq, Eq)] @@ -532,29 +521,26 @@ mod test { use hugr_core::extension::prelude::{ bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, }; - use hugr_core::extension::simple_op::MakeExtensionOp; - use hugr_core::extension::{TypeDefBound, Version}; - - use hugr_core::ops::constant::OpaqueValue; - use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; - use hugr_core::std_extensions::arithmetic::int_types::ConstInt; - use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; + use hugr_core::extension::{simple_op::MakeExtensionOp, TypeDefBound, Version}; + use hugr_core::hugr::{IdentList, ValidationError}; + use hugr_core::ops::{ + constant::OpaqueValue, ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value, + }; + use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; + use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use hugr_core::std_extensions::collections::array::{ array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue, }; use hugr_core::std_extensions::collections::list::{ list_type, list_type_def, ListOp, ListValue, }; - - use hugr_core::hugr::ValidationError; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; - use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; + use hugr_core::{type_row, Extension, HugrView}; use itertools::Itertools; use rstest::rstest; - use crate::validation::ValidatePassError; + use crate::ComposablePass; - use super::ReplaceTypesError; use super::{handlers::list_const, NodeTemplate, ReplaceTypes}; const PACKED_VEC: &str = "PackedVec"; @@ -979,13 +965,16 @@ mod test { let cu = cst.value().downcast_ref::().unwrap(); Ok(ConstInt::new_u(6, cu.value())?.into()) }); + + let mut h = backup.clone(); + repl.run(&mut h).unwrap(); // No validation here assert!( - matches!(repl.run(&mut backup.clone()), Err(ReplaceTypesError::ValidationError(ValidatePassError::OutputError { - err: ValidationError::IncompatiblePorts {from, to, ..}, .. 
- })) if backup.get_optype(from).is_const() && to == c.node()) + matches!(h.validate(), Err(ValidationError::IncompatiblePorts {from, to, ..}) + if backup.get_optype(from).is_const() && to == c.node()) ); repl.replace_consts_parametrized(array_type_def(), array_const); let mut h = backup; - repl.run(&mut h).unwrap(); // Includes validation + repl.run(&mut h).unwrap(); + h.validate_no_extensions().unwrap(); } } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 5b4da7184..bc508bd53 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -377,7 +377,7 @@ mod test { use crate::replace_types::handlers::linearize_array; use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; - use crate::ReplaceTypes; + use crate::{ComposablePass, ReplaceTypes}; const LIN_T: &str = "Lin"; diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index dbe04edd1..874fd9ec3 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -10,19 +10,19 @@ use hugr_core::hugr::views::SiblingSubgraph; use hugr_core::hugr::SimpleReplacementError; use hugr_core::ops::{NamedOp, OpTrait, OpType}; use hugr_core::types::Type; -use hugr_core::{HugrView, SimpleReplacement}; +use hugr_core::{HugrView, Node, SimpleReplacement}; use itertools::Itertools; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; /// Configuration enum for the untuple rewrite pass. /// /// Indicates whether the pattern match should traverse the HUGR recursively. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum UntupleRecursive { - /// Traverse the HUGR recursively. + /// Traverse the HUGR recursively, i.e. consider the entire subtree Recursive, - /// Do not traverse the HUGR recursively. + /// Do not traverse the HUGR recursively, i.e. consider only the sibling subgraph #[default] NonRecursive, } @@ -48,22 +48,20 @@ pub enum UntupleRecursive { pub struct UntuplePass { /// Whether to traverse the HUGR recursively. recursive: UntupleRecursive, - /// The level of validation to perform on the rewrite. - validation: ValidationLevel, + /// Parent node under which to operate; None indicates the Hugr root + parent: Option, } #[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)] #[non_exhaustive] /// Errors produced by [UntuplePass]. pub enum UntupleError { - /// An error occurred while validating the rewrite. - ValidationError(ValidatePassError), /// Rewriting the circuit failed. RewriteError(SimpleReplacementError), } /// Result type for the untuple pass. -#[derive(Debug, Clone, Copy, Default)] +#[derive(Debug, Clone, Copy, Default, PartialEq)] pub struct UntupleResult { /// Number of `MakeTuple` rewrites applied. pub rewrites_applied: usize, @@ -71,16 +69,16 @@ pub struct UntupleResult { impl UntuplePass { /// Create a new untuple pass with the given configuration. - pub fn new(recursive: UntupleRecursive, validation: ValidationLevel) -> Self { + pub fn new(recursive: UntupleRecursive) -> Self { Self { recursive, - validation, + parent: None, } } - /// Sets the validation level used before and after the pass is run. 
- pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; + /// Sets the parent node to optimize (overwrites any previous setting) + pub fn set_parent(mut self, parent: impl Into>) -> Self { + self.parent = parent.into(); self } @@ -90,31 +88,6 @@ impl UntuplePass { self } - /// Run the pass using specified configuration. - pub fn run( - &self, - hugr: &mut H, - parent: H::Node, - ) -> Result { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr, parent)) - } - - /// Run the Monomorphization pass. - fn run_no_validate( - &self, - hugr: &mut H, - parent: H::Node, - ) -> Result { - let rewrites = self.find_rewrites(hugr, parent); - let rewrites_applied = rewrites.len(); - // The rewrites are independent, so we can always apply them all. - for rewrite in rewrites { - hugr.apply_rewrite(rewrite)?; - } - Ok(UntupleResult { rewrites_applied }) - } - /// Find tuple pack operations followed by tuple unpack operations /// and generate rewrites to remove them. /// @@ -148,6 +121,22 @@ impl UntuplePass { } } +impl ComposablePass for UntuplePass { + type Error = UntupleError; + + type Result = UntupleResult; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let rewrites = self.find_rewrites(hugr, self.parent.unwrap_or(hugr.root())); + let rewrites_applied = rewrites.len(); + // The rewrites are independent, so we can always apply them all. + for rewrite in rewrites { + hugr.apply_rewrite(rewrite)?; + } + Ok(UntupleResult { rewrites_applied }) + } +} + /// Returns true if the given optype is a MakeTuple operation. /// /// Boilerplate required due to https://github.com/CQCL/hugr/issues/1496 @@ -421,7 +410,8 @@ mod test { let parent = hugr.root(); let res = pass - .run(&mut hugr, parent) + .set_parent(parent) + .run(&mut hugr) .unwrap_or_else(|e| panic!("{e}")); assert_eq!(res.rewrites_applied, expected_rewrites); assert_eq!(hugr.children(parent).count(), remaining_nodes); From 4660a1172d11cff0809de058528b322aa0bc9736 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Apr 2025 10:55:43 +0100 Subject: [PATCH 08/21] feat!: ReplaceTypes: allow lowering ops into a Call to a function already in the Hugr (#2094) There are two issues: * Errors. The previous NodeTemplates still always work, but the Call one can fail if the Hugr doesn't contain the target function node. ATM there is no channel for reporting that error so I've had to panic. Otherwise it's an even-more-breaking change to add an error type to `NodeTemplate::add()` and `NodeTemplate::add_hugr()`. Should we? (I note `HugrMut::connect` panics if the node isn't there, but could make the `NodeTemplate::add` builder method return a BuildError...and propagate that everywhere of course) * There's a big limitation in `linearize_array` that it'll break if the *element* says it should be copied/discarded via a NodeTemplate::Call, as `linearize_array` puts the elementwise copy/discard function into a *nested Hugr* (`Value::Function`) that won't contain the function. This could be fixed via lifting those to toplevel FuncDefns with name-mangling, but I'd rather leave that for #2086 .... BREAKING CHANGE: Add new variant NodeTemplate::Call; LinearizeError no longer derives Eq. 
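As a usage sketch (not prescriptive, and based on the `op_to_call` test added further down in this patch): the intended pattern is to insert the target function into the Hugr first, register a `NodeTemplate::Call` to it for the op being lowered, run the pass, and then clean up any functions left unused. The names `h`, `e`, `READ` and `lowered_read_hugr` below are placeholders taken from (or standing in for) the test helpers; the calls to `insert_hugr`, `replace_parametrized_op`, `run` and `remove_dead_funcs` are the real API exercised by the test.

```rust
// Sketch, assuming `h: Hugr`, a test extension `e` defining a parametric `READ` op,
// and `lowered_read_hugr`: a FuncDefn-rooted Hugr implementing the lowered read.

// 1. Insert the replacement function into the Hugr and keep its node handle.
let read_func = h.insert_hugr(h.root(), lowered_read_hugr).new_root;

// 2. Configure the pass to lower the extension op to a Call to that function.
let mut lw = ReplaceTypes::default();
lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args| {
    Some(NodeTemplate::Call(read_func, args.to_owned()))
});
lw.run(&mut h).unwrap();

// 3. Remove any functions left unreachable after the rewrite.
remove_dead_funcs(&mut h, []).unwrap();
```

Note that step 1 must happen before `run`, since a `NodeTemplate::Call` to a node not present in the Hugr is exactly the error case discussed above.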
--- hugr-passes/src/replace_types.rs | 234 ++++++++++++++++----- hugr-passes/src/replace_types/handlers.rs | 4 +- hugr-passes/src/replace_types/linearize.rs | 104 +++++++-- 3 files changed, 268 insertions(+), 74 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index e81a640e3..df4c14075 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -15,16 +15,17 @@ use hugr_core::builder::{BuildError, BuildHandle, Dataflow}; use hugr_core::extension::{ExtensionId, OpDef, SignatureError, TypeDef}; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::constant::{OpaqueValue, Sum}; -use hugr_core::ops::handle::DataflowOpID; +use hugr_core::ops::handle::{DataflowOpID, FuncID}; use hugr_core::ops::{ AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, ExtensionOp, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }; use hugr_core::types::{ - ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, + ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeRow, + TypeTransformer, }; -use hugr_core::{Hugr, HugrView, Node, Wire}; +use hugr_core::{Direction, Hugr, HugrView, Node, PortIndex, Wire}; use crate::ComposablePass; @@ -45,21 +46,37 @@ pub enum NodeTemplate { /// Note this will be of limited use before [monomorphization](super::monomorphize()) /// because the new subtree will not be able to use type variables present in the /// parent Hugr or previous op. - // TODO: store also a vec, and update Hugr::validate to take &[TypeParam]s - // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 CompoundOp(Box), - // TODO allow also Call to a Node in the existing Hugr - // (can't see any other way to achieve multiple calls to the same decl. - // So client should add the functions before replacement, then remove unused ones afterwards.) + /// A Call to an existing function. 
+ Call(Node, Vec), } impl NodeTemplate { /// Adds this instance to the specified [HugrMut] as a new node or subtree under a /// given parent, returning the unique new child (of that parent) thus created - pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node { + /// + /// # Panics + /// + /// * If `parent` is not in the `hugr` + /// + /// # Errors + /// + /// * If `self` is a [Self::Call] and the target Node either + /// * is neither a [FuncDefn] nor a [FuncDecl] + /// * has a [`signature`] which the type-args of the [Self::Call] do not match + /// + /// [`signature`]: hugr_core::types::PolyFuncType + pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Result { match self { - NodeTemplate::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), - NodeTemplate::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, + NodeTemplate::SingleOp(op_type) => Ok(hugr.add_node_with_parent(parent, op_type)), + NodeTemplate::CompoundOp(new_h) => Ok(hugr.insert_hugr(parent, *new_h).new_root), + NodeTemplate::Call(target, type_args) => { + let c = call(hugr, target, type_args)?; + let tgt_port = c.called_function_port(); + let n = hugr.add_node_with_parent(parent, c); + hugr.connect(target, 0, n, tgt_port); + Ok(n) + } } } @@ -72,10 +89,15 @@ impl NodeTemplate { match self { NodeTemplate::SingleOp(opty) => dfb.add_dataflow_op(opty, inputs), NodeTemplate::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs), + // Really we should check whether func points at a FuncDecl or FuncDefn and create + // the appropriate variety of FuncID but it doesn't matter for the purpose of making a Call. + NodeTemplate::Call(func, type_args) => { + dfb.call(&FuncID::::from(func), &type_args, inputs) + } } } - fn replace(&self, hugr: &mut impl HugrMut, n: Node) { + fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { NodeTemplate::SingleOp(op_type) => op_type, @@ -88,19 +110,57 @@ impl NodeTemplate { } root_opty } + NodeTemplate::Call(func, type_args) => { + let c = call(hugr, func, type_args)?; + let static_inport = c.called_function_port(); + // insert an input for the Call static input + hugr.insert_ports(n, Direction::Incoming, static_inport.index(), 1); + // connect the function to (what will be) the call + hugr.connect(func, 0, n, static_inport); + c.into() + } }; *hugr.optype_mut(n) = new_optype; + Ok(()) } - fn signature(&self) -> Option> { - match self { + fn check_signature( + &self, + inputs: &TypeRow, + outputs: &TypeRow, + ) -> Result<(), Option> { + let sig = match self { NodeTemplate::SingleOp(op_type) => op_type, NodeTemplate::CompoundOp(hugr) => hugr.root_type(), + NodeTemplate::Call(_, _) => return Ok(()), // no way to tell + } + .dataflow_signature(); + if sig.as_deref().map(Signature::io) == Some((inputs, outputs)) { + Ok(()) + } else { + Err(sig.map(Cow::into_owned)) } - .dataflow_signature() } } +fn call>( + h: &H, + func: Node, + type_args: Vec, +) -> Result { + let func_sig = match h.get_optype(func) { + OpType::FuncDecl(fd) => fd.signature.clone(), + OpType::FuncDefn(fd) => fd.signature.clone(), + _ => { + return Err(BuildError::UnexpectedType { + node: func, + op_desc: "func defn/decl", + }) + } + }; + Ok(Call::try_new(func_sig, type_args)?) +} + /// A configuration of what types, ops, and constants should be replaced with what. /// May be applied to a Hugr via [Self::run]. 
/// @@ -186,6 +246,8 @@ pub enum ReplaceTypesError { ConstError(#[from] ConstTypeError), #[error(transparent)] LinearizeError(#[from] LinearizeError), + #[error("Replacement op for {0} could not be added because {1}")] + AddTemplateError(Node, BuildError), } impl ReplaceTypes { @@ -370,8 +432,11 @@ impl ReplaceTypes { OpType::Const(Const { value, .. }) => self.change_value(value), OpType::ExtensionOp(ext_op) => Ok( + // Copy/discard insertion done by caller if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) { - replacement.replace(hugr, n); // Copy/discard insertion done by caller + replacement + .replace(hugr, n) + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { let def = ext_op.def_arc(); @@ -382,7 +447,9 @@ impl ReplaceTypes { .get(&def.as_ref().into()) .and_then(|rep_fn| rep_fn(&args)) { - replacement.replace(hugr, n); + replacement + .replace(hugr, n) + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { if ch { @@ -515,24 +582,22 @@ mod test { use std::sync::Arc; use hugr_core::builder::{ - inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, + inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; use hugr_core::extension::prelude::{ - bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, + bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, PRELUDE_ID, }; - use hugr_core::extension::{simple_op::MakeExtensionOp, TypeDefBound, Version}; + use hugr_core::extension::{simple_op::MakeExtensionOp, ExtensionSet, TypeDefBound, Version}; + use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::hugr::{IdentList, ValidationError}; - use hugr_core::ops::{ - constant::OpaqueValue, ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value, - }; - use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; + use hugr_core::ops::constant::OpaqueValue; + use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; + use hugr_core::std_extensions::arithmetic::conversions::{self, ConvertOpDef}; use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; - use hugr_core::std_extensions::collections::array::{ - array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue, - }; - use hugr_core::std_extensions::collections::list::{ - list_type, list_type_def, ListOp, ListValue, + use hugr_core::std_extensions::collections::{ + array::{self, array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue}, + list::{list_type, list_type_def, ListOp, ListValue}, }; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; use hugr_core::{type_row, Extension, HugrView}; @@ -601,30 +666,37 @@ mod test { ) } - fn lowerer(ext: &Arc) -> ReplaceTypes { - fn lowered_read(args: &[TypeArg]) -> Option { - let ty = just_elem_type(args); - let mut dfb = DFGBuilder::new(inout_sig( - vec![array_type(64, ty.clone()), i64_t()], - ty.clone(), - )) + fn lowered_read( + elem_ty: Type, + new: impl Fn(Signature) -> Result, + ) -> T { + let mut dfb = new(Signature::new( + vec![array_type(64, elem_ty.clone()), i64_t()], + elem_ty.clone(), + ) + .with_extension_delta(ExtensionSet::from_iter([ + PRELUDE_ID, + array::EXTENSION_ID, + conversions::EXTENSION_ID, + ]))) + .unwrap(); + let [val, idx] = dfb.input_wires_arr(); + let [idx] = dfb + 
.add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) + .unwrap() + .outputs_arr(); + let [opt] = dfb + .add_dataflow_op(ArrayOpDef::get.to_concrete(elem_ty.clone(), 64), [val, idx]) + .unwrap() + .outputs_arr(); + let [res] = dfb + .build_unwrap_sum(1, option_type(Type::from(elem_ty)), opt) .unwrap(); - let [val, idx] = dfb.input_wires_arr(); - let [idx] = dfb - .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) - .unwrap() - .outputs_arr(); - let [opt] = dfb - .add_dataflow_op(ArrayOpDef::get.to_concrete(ty.clone(), 64), [val, idx]) - .unwrap() - .outputs_arr(); - let [res] = dfb - .build_unwrap_sum(1, option_type(Type::from(ty.clone())), opt) - .unwrap(); - Some(NodeTemplate::CompoundOp(Box::new( - dfb.finish_hugr_with_outputs([res]).unwrap(), - ))) - } + dfb.set_outputs([res]).unwrap(); + dfb + } + + fn lowerer(ext: &Arc) -> ReplaceTypes { let pv = ext.get_type(PACKED_VEC).unwrap(); let mut lw = ReplaceTypes::default(); lw.replace_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); @@ -640,7 +712,13 @@ mod test { .into(), ), ); - lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), Box::new(lowered_read)); + lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), |type_args| { + Some(NodeTemplate::CompoundOp(Box::new( + lowered_read(just_elem_type(type_args).clone(), DFGBuilder::new) + .finish_hugr() + .unwrap(), + ))) + }); lw } @@ -977,4 +1055,52 @@ mod test { repl.run(&mut h).unwrap(); h.validate_no_extensions().unwrap(); } + + #[test] + fn op_to_call() { + let e = ext(); + let pv = e.get_type(PACKED_VEC).unwrap(); + let inner = pv.instantiate([usize_t().into()]).unwrap(); + let outer = pv + .instantiate([Type::new_extension(inner.clone()).into()]) + .unwrap(); + let mut dfb = DFGBuilder::new(inout_sig(vec![outer.into(), i64_t()], usize_t())).unwrap(); + let [outer, idx] = dfb.input_wires_arr(); + let [inner] = dfb + .add_dataflow_op(read_op(&e, inner.clone().into()), [outer, idx]) + .unwrap() + .outputs_arr(); + let res = dfb + .add_dataflow_op(read_op(&e, usize_t()), [inner, idx]) + .unwrap(); + let mut h = dfb.finish_hugr_with_outputs(res.outputs()).unwrap(); + let read_func = h + .insert_hugr( + h.root(), + lowered_read(Type::new_var_use(0, TypeBound::Copyable), |sig| { + FunctionBuilder::new( + "lowered_read", + PolyFuncType::new([TypeBound::Copyable.into()], sig), + ) + }) + .finish_hugr() + .unwrap(), + ) + .new_root; + + let mut lw = lowerer(&e); + lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args| { + Some(NodeTemplate::Call(read_func, args.to_owned())) + }); + lw.run(&mut h).unwrap(); + + assert_eq!(h.output_neighbours(read_func).count(), 2); + let ext_op_names = h + .nodes() + .filter_map(|n| h.get_optype(n).as_extension_op()) + .map(|e| e.def().name()) + .sorted() + .collect_vec(); + assert_eq!(ext_op_names, ["get", "itousize", "panic",]); + } } diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index e835a2d9b..b6e6e6780 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -92,7 +92,7 @@ pub fn linearize_array( let [to_discard] = dfb.input_wires_arr(); lin.copy_discard_op(ty, 0)? .add(&mut dfb, [to_discard]) - .unwrap(); + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))?; let ret = dfb.add_load_value(Value::unary_unit_sum()); dfb.finish_hugr_with_outputs([ret]).unwrap() }; @@ -162,7 +162,7 @@ pub fn linearize_array( let mut copies = lin .copy_discard_op(ty, num_outports)? 
.add(&mut dfb, [elem]) - .unwrap() + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))? .outputs(); let copy0 = copies.next().unwrap(); // We'll return this directly diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index bc508bd53..5c4a4a707 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -1,10 +1,9 @@ -use std::borrow::Cow; use std::iter::repeat; use std::{collections::HashMap, sync::Arc}; use hugr_core::builder::{ - inout_sig, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, + inout_sig, BuildError, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, + DataflowSubContainer, HugrBuilder, }; use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::std_extensions::collections::array::array_type_def; @@ -76,9 +75,11 @@ pub trait Linearizer { tgt_parent, }); } + let typ = typ.clone(); // Stop borrowing hugr in order to add_hugr to it let copy_discard_op = self - .copy_discard_op(typ, targets.len())? - .add_hugr(hugr, src_parent); + .copy_discard_op(&typ, targets.len())? + .add_hugr(hugr, src_parent) + .map_err(|e| LinearizeError::NestedTemplateError(typ, e))?; for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); } @@ -133,7 +134,7 @@ impl Default for DelegatingLinearizer { // rather than passing a &DelegatingLinearizer directly) pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); -#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] +#[derive(Clone, Debug, thiserror::Error, PartialEq)] #[allow(missing_docs)] #[non_exhaustive] pub enum LinearizeError { @@ -163,6 +164,10 @@ pub enum LinearizeError { /// Neither does linearization make sense for copyable types #[error("Type {_0} is copyable")] CopyableType(Type), + /// Error may be returned by a callback for e.g. a container because it could + /// not generate a [NodeTemplate] because of a problem with an element + #[error("Could not generate NodeTemplate for contained type {0} because {1}")] + NestedTemplateError(Type, BuildError), } impl DelegatingLinearizer { @@ -185,8 +190,10 @@ impl DelegatingLinearizer { /// /// * [LinearizeError::CopyableType] If `typ` is /// [Copyable](hugr_core::types::TypeBound::Copyable) - /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the - /// expected inputs or outputs + /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the expected + /// inputs or outputs (for [NodeTemplate::SingleOp] and [NodeTemplate::CompoundOp] + /// only: the signature for a [NodeTemplate::Call] cannot be checked until it is used + /// in a Hugr). 
pub fn register_simple( &mut self, cty: CustomType, @@ -230,18 +237,12 @@ impl DelegatingLinearizer { } fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), LinearizeError> { - let sig = tmpl.signature(); - if sig.as_ref().is_some_and(|sig| { - sig.io() == (&typ.clone().into(), &vec![typ.clone(); num_outports].into()) - }) { - Ok(()) - } else { - Err(LinearizeError::WrongSignature { + tmpl.check_signature(&typ.clone().into(), &vec![typ.clone(); num_outports].into()) + .map_err(|sig| LinearizeError::WrongSignature { typ: typ.clone(), num_outports, - sig: sig.map(Cow::into_owned), + sig, }) - } } impl Linearizer for DelegatingLinearizer { @@ -353,7 +354,10 @@ mod test { use std::iter::successors; use std::sync::Arc; - use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}; + use hugr_core::builder::{ + inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + HugrBuilder, + }; use hugr_core::extension::prelude::{option_type, usize_t}; use hugr_core::extension::simple_op::MakeExtensionOp; @@ -768,4 +772,68 @@ mod test { )); assert_eq!(copy_sig.input[2..], copy_sig.output[1..]); } + + #[test] + fn call_ok_except_in_array() { + let (e, _) = ext_lowerer(); + let lin_ct = e.get_type(LIN_T).unwrap().instantiate([]).unwrap(); + let lin_t: Type = lin_ct.clone().into(); + + // A simple Hugr that discards a usize_t, with a "drop" function + let mut dfb = DFGBuilder::new(inout_sig(usize_t(), type_row![])).unwrap(); + let discard_fn = { + let mut fb = dfb + .define_function( + "drop", + Signature::new(lin_t.clone(), type_row![]) + .with_extension_delta(e.name().clone()), + ) + .unwrap(); + let ins = fb.input_wires(); + fb.add_dataflow_op( + ExtensionOp::new(e.get_op("discard").unwrap().clone(), []).unwrap(), + ins, + ) + .unwrap(); + fb.finish_with_outputs([]).unwrap() + } + .node(); + let backup = dfb.finish_hugr().unwrap(); + + let mut lower_discard_to_call = ReplaceTypes::default(); + // The `copy_fn` here will break completely, but we don't use it + lower_discard_to_call + .linearizer() + .register_simple( + lin_ct.clone(), + NodeTemplate::Call(backup.root(), vec![]), + NodeTemplate::Call(discard_fn, vec![]), + ) + .unwrap(); + + // Ok to lower usize_t to lin_t and call that function + { + let mut lowerer = lower_discard_to_call.clone(); + lowerer.replace_type(usize_t().as_extension().unwrap().clone(), lin_t.clone()); + let mut h = backup.clone(); + lowerer.run(&mut h).unwrap(); + assert_eq!(h.output_neighbours(discard_fn).count(), 1); + } + + // But if we lower usize_t to array, the call will fail + lower_discard_to_call.replace_type( + usize_t().as_extension().unwrap().clone(), + array_type(4, lin_ct.into()), + ); + let r = lower_discard_to_call.run(&mut backup.clone()); + assert!(matches!( + r, + Err(ReplaceTypesError::LinearizeError( + LinearizeError::NestedTemplateError( + nested_t, + BuildError::UnexpectedType { node, .. } + ) + )) if nested_t == lin_t && node == discard_fn + )); + } } From a6c52548107c4edfee2e22e4ba5a89caeeb0cc46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Wed, 23 Apr 2025 17:32:27 +0100 Subject: [PATCH 09/21] feat!: Hugrmut on generic nodes (#2111) - Allows `HugrMut` to be implemented for `HugrView`s with arbitrary node types - Removes `HugrMutInternals::hugr_mut(&mut self) -> &mut Hugr`, it can be implemented for more complex types. 
This is required for #1926, but I haven't touched the read-only side yet. - Added a `Node` associated type to `Rewrite`. All existing rewrites only implement `Rewrite` for now, expanding their type is left for a separate PR. drive-by: Fix a couple bugs in rewrite implementations that assumed that `SiblingMut` contained transitive children. BREAKING CHANGE: `HugrMut` is now implemented generically for any `HugrView::Node` type. BREAKING CHANGE: `SiblingMut` has a new type parameter for the wrapped hugr type. --- hugr-core/src/builder/build_traits.rs | 2 +- hugr-core/src/hugr/hugrmut.rs | 212 ++++-------- hugr-core/src/hugr/internal.rs | 320 +++++++----------- hugr-core/src/hugr/rewrite.rs | 21 +- hugr-core/src/hugr/rewrite/consts.rs | 6 +- hugr-core/src/hugr/rewrite/inline_call.rs | 3 +- hugr-core/src/hugr/rewrite/inline_dfg.rs | 6 +- hugr-core/src/hugr/rewrite/insert_identity.rs | 6 +- hugr-core/src/hugr/rewrite/outline_cfg.rs | 38 +-- hugr-core/src/hugr/rewrite/replace.rs | 5 +- hugr-core/src/hugr/rewrite/simple_replace.rs | 3 +- hugr-core/src/hugr/views.rs | 17 +- hugr-core/src/hugr/views/descendants.rs | 6 +- hugr-core/src/hugr/views/impls.rs | 266 ++++++++++++--- hugr-core/src/hugr/views/root_checked.rs | 73 ++-- hugr-core/src/hugr/views/sibling.rs | 134 +++++--- .../src/utils/inline_constant_functions.rs | 4 +- hugr-passes/src/composable.rs | 45 ++- hugr-passes/src/const_fold.rs | 5 +- hugr-passes/src/dataflow/partial_value.rs | 2 +- hugr-passes/src/dead_code.rs | 3 +- hugr-passes/src/dead_funcs.rs | 5 +- hugr-passes/src/force_order.rs | 4 +- hugr-passes/src/lower.rs | 4 +- hugr-passes/src/merge_bbs.rs | 17 +- hugr-passes/src/monomorphize.rs | 13 +- hugr-passes/src/nest_cfgs.rs | 8 +- hugr-passes/src/replace_types.rs | 17 +- hugr-passes/src/replace_types/linearize.rs | 2 +- hugr-passes/src/untuple.rs | 4 +- 30 files changed, 676 insertions(+), 575 deletions(-) diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index e17d172ca..58c15c54a 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -153,7 +153,7 @@ pub trait Container { where ExtensionRegistry: Extend, { - self.hugr_mut().extensions_mut().extend(registry); + self.hugr_mut().use_extensions(registry); } } diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 38eb59222..bf9a4cad0 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -1,7 +1,7 @@ //! Low-level interface for modifying a HUGR. use core::panic; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; use std::sync::Arc; use portgraph::view::{NodeFilter, NodeFiltered}; @@ -11,14 +11,13 @@ use crate::core::HugrNode; use crate::extension::ExtensionRegistry; use crate::hugr::internal::HugrInternals; use crate::hugr::views::SiblingSubgraph; -use crate::hugr::{HugrView, Node, OpType, RootTagged}; +use crate::hugr::{HugrView, Node, OpType}; use crate::hugr::{NodeMetadata, Rewrite}; use crate::ops::OpTrait; use crate::types::Substitution; use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; use super::internal::HugrMutInternals; -use super::NodeMetadataMap; /// Functions for low-level building of a HUGR. pub trait HugrMut: HugrMutInternals { @@ -27,14 +26,9 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph. 
- fn get_metadata_mut(&mut self, node: Node, key: impl AsRef) -> &mut NodeMetadata { + fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef) -> &mut NodeMetadata { panic_invalid_node(self, node); - let node_meta = self - .hugr_mut() - .metadata - .get_mut(node.pg_index()) - .get_or_insert_with(Default::default); - node_meta + self.node_metadata_map_mut(node) .entry(key.as_ref()) .or_insert(serde_json::Value::Null) } @@ -46,7 +40,7 @@ pub trait HugrMut: HugrMutInternals { /// If the node is not in the graph. fn set_metadata( &mut self, - node: Node, + node: Self::Node, key: impl AsRef, metadata: impl Into, ) { @@ -59,30 +53,10 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph. - fn remove_metadata(&mut self, node: Node, key: impl AsRef) { + fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef) { panic_invalid_node(self, node); - let node_meta = self.hugr_mut().metadata.get_mut(node.pg_index()); - if let Some(node_meta) = node_meta { - node_meta.remove(key.as_ref()); - } - } - - /// Retrieve the complete metadata map for a node. - fn take_node_metadata(&mut self, node: Self::Node) -> Option { - if !self.valid_node(node) { - return None; - } - self.hugr_mut().metadata.take(node.pg_index()) - } - - /// Overwrite the complete metadata map for a node. - /// - /// # Panics - /// - /// If the node is not in the graph. - fn overwrite_node_metadata(&mut self, node: Node, metadata: Option) { - panic_invalid_node(self, node); - self.hugr_mut().metadata.set(node.pg_index(), metadata); + let node_meta = self.node_metadata_map_mut(node); + node_meta.remove(key.as_ref()); } /// Add a node to the graph with a parent in the hierarchy. @@ -92,11 +66,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the parent is not in the graph. - #[inline] - fn add_node_with_parent(&mut self, parent: Node, op: impl Into) -> Node { - panic_invalid_node(self, parent); - self.hugr_mut().add_node_with_parent(parent, op) - } + fn add_node_with_parent(&mut self, parent: Self::Node, op: impl Into) -> Self::Node; /// Add a node to the graph as the previous sibling of another node. /// @@ -105,11 +75,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the sibling is not in the graph, or if the sibling is the root node. - #[inline] - fn add_node_before(&mut self, sibling: Node, nodetype: impl Into) -> Node { - panic_invalid_non_root(self, sibling); - self.hugr_mut().add_node_before(sibling, nodetype) - } + fn add_node_before(&mut self, sibling: Self::Node, nodetype: impl Into) -> Self::Node; /// Add a node to the graph as the next sibling of another node. /// @@ -118,11 +84,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the sibling is not in the graph, or if the sibling is the root node. - #[inline] - fn add_node_after(&mut self, sibling: Node, op: impl Into) -> Node { - panic_invalid_non_root(self, sibling); - self.hugr_mut().add_node_after(sibling, op) - } + fn add_node_after(&mut self, sibling: Self::Node, op: impl Into) -> Self::Node; /// Remove a node from the graph and return the node weight. /// Note that if the node has children, they are not removed; this leaves @@ -131,24 +93,14 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph, or if the node is the root node. 
- #[inline] - fn remove_node(&mut self, node: Node) -> OpType { - panic_invalid_non_root(self, node); - self.hugr_mut().remove_node(node) - } + fn remove_node(&mut self, node: Self::Node) -> OpType; /// Remove a node from the graph, along with all its descendants in the hierarchy. /// /// # Panics /// /// If the node is not in the graph, or is the root (this would leave an empty Hugr). - fn remove_subtree(&mut self, node: Node) { - panic_invalid_non_root(self, node); - while let Some(ch) = self.first_child(node) { - self.remove_subtree(ch) - } - self.hugr_mut().remove_node(node); - } + fn remove_subtree(&mut self, node: Self::Node); /// Copies the strict descendants of `root` to under the `new_parent`, optionally applying a /// [Substitution] to the [OpType]s of the copied nodes. @@ -167,29 +119,20 @@ pub trait HugrMut: HugrMutInternals { root: Self::Node, new_parent: Self::Node, subst: Option, - ) -> BTreeMap { - panic_invalid_node(self, root); - panic_invalid_node(self, new_parent); - self.hugr_mut().copy_descendants(root, new_parent, subst) - } + ) -> BTreeMap; /// Connect two nodes at the given ports. /// /// # Panics /// /// If either node is not in the graph or if the ports are invalid. - #[inline] fn connect( &mut self, - src: Node, + src: Self::Node, src_port: impl Into, - dst: Node, + dst: Self::Node, dst_port: impl Into, - ) { - panic_invalid_node(self, src); - panic_invalid_node(self, dst); - self.hugr_mut().connect(src, src_port, dst, dst_port); - } + ); /// Disconnects all edges from the given port. /// @@ -198,11 +141,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph, or if the port is invalid. - #[inline] - fn disconnect(&mut self, node: Node, port: impl Into) { - panic_invalid_node(self, node); - self.hugr_mut().disconnect(node, port); - } + fn disconnect(&mut self, node: Self::Node, port: impl Into); /// Adds a non-dataflow edge between two nodes. The kind is given by the /// operation's [`OpTrait::other_input`] or [`OpTrait::other_output`]. @@ -215,37 +154,25 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph, or if the port is invalid. - fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) { - panic_invalid_node(self, src); - panic_invalid_node(self, dst); - self.hugr_mut().add_other_edge(src, dst) - } + fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (OutgoingPort, IncomingPort); /// Insert another hugr into this one, under a given root node. /// /// # Panics /// /// If the root node is not in the graph. - #[inline] - fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult { - panic_invalid_node(self, root); - self.hugr_mut().insert_hugr(root, other) - } + fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult; /// Copy another hugr into this one, under a given root node. /// /// # Panics /// /// If the root node is not in the graph. - #[inline] fn insert_from_view( &mut self, root: Self::Node, other: &H, - ) -> InsertionResult { - panic_invalid_node(self, root); - self.hugr_mut().insert_from_view(root, other) - } + ) -> InsertionResult; /// Copy a subgraph from another hugr into this one, under a given root node. /// @@ -266,13 +193,13 @@ pub trait HugrMut: HugrMutInternals { root: Self::Node, other: &H, subgraph: &SiblingSubgraph, - ) -> HashMap { - panic_invalid_node(self, root); - self.hugr_mut().insert_subgraph(root, other, subgraph) - } + ) -> HashMap; /// Applies a rewrite to the graph. 
- fn apply_rewrite(&mut self, rw: impl Rewrite) -> Result + fn apply_rewrite( + &mut self, + rw: impl Rewrite, + ) -> Result where Self: Sized, { @@ -286,7 +213,7 @@ pub trait HugrMut: HugrMutInternals { /// /// See [`ExtensionRegistry::register_updated`] for more information. fn use_extension(&mut self, extension: impl Into>) { - self.hugr_mut().extensions.register_updated(extension); + self.extensions_mut().register_updated(extension); } /// Extend the set of extensions used by the hugr with the extensions in the @@ -302,12 +229,7 @@ pub trait HugrMut: HugrMutInternals { where ExtensionRegistry: Extend, { - self.hugr_mut().extensions.extend(registry); - } - - /// Returns a mutable reference to the extension registry for this hugr. - fn extensions_mut(&mut self) -> &mut ExtensionRegistry { - &mut self.hugr_mut().extensions + self.extensions_mut().extend(registry); } } @@ -342,11 +264,10 @@ fn translate_indices( } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. -impl + AsMut> HugrMut for T { +impl HugrMut for Hugr { fn add_node_with_parent(&mut self, parent: Node, node: impl Into) -> Node { let node = self.as_mut().add_node(node.into()); - self.as_mut() - .hierarchy + self.hierarchy .push_child(node.pg_index(), parent.pg_index()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node @@ -354,8 +275,7 @@ impl + AsMut> HugrMut for T fn add_node_before(&mut self, sibling: Node, nodetype: impl Into) -> Node { let node = self.as_mut().add_node(nodetype.into()); - self.as_mut() - .hierarchy + self.hierarchy .insert_before(node.pg_index(), sibling.pg_index()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node @@ -363,8 +283,7 @@ impl + AsMut> HugrMut for T fn add_node_after(&mut self, sibling: Node, op: impl Into) -> Node { let node = self.as_mut().add_node(op.into()); - self.as_mut() - .hierarchy + self.hierarchy .insert_after(node.pg_index(), sibling.pg_index()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node @@ -372,9 +291,19 @@ impl + AsMut> HugrMut for T fn remove_node(&mut self, node: Node) -> OpType { panic_invalid_non_root(self, node); - self.as_mut().hierarchy.remove(node.pg_index()); - self.as_mut().graph.remove_node(node.pg_index()); - self.as_mut().op_types.take(node.pg_index()) + self.hierarchy.remove(node.pg_index()); + self.graph.remove_node(node.pg_index()); + self.op_types.take(node.pg_index()) + } + + fn remove_subtree(&mut self, node: Node) { + panic_invalid_non_root(self, node); + let mut queue = VecDeque::new(); + queue.push_back(node); + while let Some(n) = queue.pop_front() { + queue.extend(self.children(n)); + self.remove_node(n); + } } fn connect( @@ -388,8 +317,7 @@ impl + AsMut> HugrMut for T let dst_port = dst_port.into(); panic_invalid_port(self, src, src_port); panic_invalid_port(self, dst, dst_port); - self.as_mut() - .graph + self.graph .link_nodes( src.pg_index(), src_port.index(), @@ -404,11 +332,10 @@ impl + AsMut> HugrMut for T let offset = port.pg_offset(); panic_invalid_port(self, node, port); let port = self - .as_mut() .graph .port_index(node.pg_index(), offset) .expect("The port should exist at this point."); - self.as_mut().graph.unlink_port(port); + self.graph.unlink_port(port); } fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) { @@ -429,15 +356,15 @@ impl + AsMut> HugrMut for T root: Self::Node, mut other: Hugr, ) -> InsertionResult { - let (new_root, node_map) = 
insert_hugr_internal(self.as_mut(), root, &other); + let (new_root, node_map) = insert_hugr_internal(self, root, &other); // Update the optypes and metadata, taking them from the other graph. // // No need to compute each node's extensions here, as we merge `other.extensions` directly. for (&node, &new_node) in node_map.iter() { let optype = other.op_types.take(node); - self.as_mut().op_types.set(new_node, optype); + self.op_types.set(new_node, optype); let meta = other.metadata.take(node); - self.as_mut().metadata.set(new_node, meta); + self.metadata.set(new_node, meta); } debug_assert_eq!( Some(&new_root.pg_index()), @@ -455,15 +382,15 @@ impl + AsMut> HugrMut for T root: Self::Node, other: &H, ) -> InsertionResult { - let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, other); + let (new_root, node_map) = insert_hugr_internal(self, root, other); // Update the optypes and metadata, copying them from the other graph. // // No need to compute each node's extensions here, as we merge `other.extensions` directly. for (&node, &new_node) in node_map.iter() { let nodetype = other.get_optype(other.get_node(node)); - self.as_mut().op_types.set(new_node, nodetype.clone()); + self.op_types.set(new_node, nodetype.clone()); let meta = other.base_hugr().metadata.get(node); - self.as_mut().metadata.set(new_node, meta.clone()); + self.metadata.set(new_node, meta.clone()); } debug_assert_eq!( Some(&new_root.pg_index()), @@ -494,13 +421,13 @@ impl + AsMut> HugrMut for T |node, ctx| ctx.contains(&node), context, ); - let node_map = insert_subgraph_internal(self.as_mut(), root, other, &portgraph); + let node_map = insert_subgraph_internal(self, root, other, &portgraph); // Update the optypes and metadata, copying them from the other graph. for (&node, &new_node) in node_map.iter() { let nodetype = other.get_optype(other.get_node(node)); - self.as_mut().op_types.set(new_node, nodetype.clone()); + self.op_types.set(new_node, nodetype.clone()); let meta = other.base_hugr().metadata.get(node); - self.as_mut().metadata.set(new_node, meta.clone()); + self.metadata.set(new_node, meta.clone()); // Add the required extensions to the registry. if let Ok(exts) = nodetype.used_extensions() { self.use_extensions(exts); @@ -519,7 +446,7 @@ impl + AsMut> HugrMut for T let root2 = descendants.next(); debug_assert_eq!(root2, Some(root.pg_index())); let nodes = Vec::from_iter(descendants); - let node_map = portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes) + let node_map = portgraph::view::Subgraph::with_nodes(&mut self.graph, nodes) .copy_in_parent() .expect("Is a MultiPortGraph"); let node_map = translate_indices(|n| self.get_node(n), |n| self.get_node(n), node_map) @@ -538,9 +465,9 @@ impl + AsMut> HugrMut for T (None, op) => op.clone(), (Some(subst), op) => op.substitute(subst), }; - self.as_mut().op_types.set(new_node.pg_index(), new_optype); + self.op_types.set(new_node.pg_index(), new_optype); let meta = self.base_hugr().metadata.get(node.pg_index()).clone(); - self.as_mut().metadata.set(new_node.pg_index(), meta); + self.metadata.set(new_node.pg_index(), meta); } node_map } @@ -624,22 +551,20 @@ fn insert_subgraph_internal( /// Panic if [`HugrView::valid_node`] fails. #[track_caller] pub(super) fn panic_invalid_node(hugr: &H, node: H::Node) { + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. 
if !hugr.valid_node(node) { - panic!( - "Received an invalid node {node} while mutating a HUGR:\n\n {}", - hugr.mermaid_string() - ); + panic!("Received an invalid node {node} while mutating a HUGR.",); } } /// Panic if [`HugrView::valid_non_root`] fails. #[track_caller] pub(super) fn panic_invalid_non_root(hugr: &H, node: H::Node) { + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. if !hugr.valid_non_root(node) { - panic!( - "Received an invalid non-root node {node} while mutating a HUGR:\n\n {}", - hugr.mermaid_string() - ); + panic!("Received an invalid non-root node {node} while mutating a HUGR.",); } } @@ -651,15 +576,14 @@ pub(super) fn panic_invalid_port( port: impl Into, ) { let port = port.into(); + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. if hugr .portgraph() .port_index(node.pg_index(), port.pg_offset()) .is_none() { - panic!( - "Received an invalid port {port} for node {node} while mutating a HUGR:\n\n {}", - hugr.mermaid_string() - ); + panic!("Received an invalid port {port} for node {node} while mutating a HUGR"); } } diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 6dab3adc0..8892c3b11 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -2,19 +2,17 @@ use std::borrow::Cow; use std::ops::Range; -use std::rc::Rc; -use std::sync::Arc; +use std::sync::OnceLock; -use delegate::delegate; use itertools::Itertools; use portgraph::{LinkMut, LinkView, MultiPortGraph, PortMut, PortOffset, PortView}; +use crate::extension::ExtensionRegistry; use crate::ops::handle::NodeHandle; -use crate::ops::{OpTag, OpTrait}; use crate::{Direction, Hugr, Node}; use super::hugrmut::{panic_invalid_node, panic_invalid_non_root}; -use super::{HugrError, OpType, RootTagged}; +use super::{HugrError, NodeMetadataMap, OpType, RootTagged}; /// Trait for accessing the internals of a Hugr(View). /// @@ -46,10 +44,17 @@ pub trait HugrInternals { fn root_node(&self) -> Self::Node; /// Convert a node to a portgraph node index. - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; + fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex; /// Convert a portgraph node index to a node. fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; + + /// Returns a metadata entry associated with a node. + /// + /// # Panics + /// + /// If the node is not in the graph. + fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap; } impl HugrInternals for Hugr { @@ -80,145 +85,41 @@ impl HugrInternals for Hugr { self.root.into() } - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex { - node.pg_index() + #[inline] + fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + node.node().pg_index() } + #[inline] fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node { index.into() } -} - -impl HugrInternals for &T { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! 
{ - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} -impl HugrInternals for &mut T { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} - -impl HugrInternals for Rc { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} - -impl HugrInternals for Arc { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} - -impl HugrInternals for Box { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } + fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap { + static EMPTY: OnceLock = OnceLock::new(); + panic_invalid_node(self, node); + let map = self.metadata.get(node.pg_index()).as_ref(); + map.unwrap_or(EMPTY.get_or_init(Default::default)) } } -impl HugrInternals for Cow<'_, T> { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to self.as_ref() { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} /// Trait for accessing the mutable internals of a Hugr(Mut). /// /// Specifically, this trait lets you apply arbitrary modifications that may /// invalidate the HUGR. -pub trait HugrMutInternals: RootTagged { - /// Returns the Hugr at the base of a chain of views. - fn hugr_mut(&mut self) -> &mut Hugr; +pub trait HugrMutInternals: RootTagged { + /// Set root node of the HUGR. + /// + /// This should be an existing node in the HUGR. Most operations use the + /// root node as a starting point for traversal. + fn set_root(&mut self, root: Self::Node); /// Set the number of ports on a node. 
This may invalidate the node's `PortIndex`. /// /// # Panics /// /// If the node is not in the graph. - fn set_num_ports(&mut self, node: Node, incoming: usize, outgoing: usize) { - panic_invalid_node(self, node); - self.hugr_mut().set_num_ports(node, incoming, outgoing) - } + fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize); /// Alter the number of ports on a node and returns a range with the new /// port offsets, if any. This may invalidate the node's `PortIndex`. @@ -231,10 +132,7 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If the node is not in the graph. - fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { - panic_invalid_node(self, node); - self.hugr_mut().add_ports(node, direction, amount) - } + fn add_ports(&mut self, node: Self::Node, direction: Direction, amount: isize) -> Range; /// Insert `amount` new ports for a node, starting at `index`. The /// `direction` parameter specifies whether to add ports to the incoming or @@ -247,14 +145,11 @@ pub trait HugrMutInternals: RootTagged { /// If the node is not in the graph. fn insert_ports( &mut self, - node: Node, + node: Self::Node, direction: Direction, index: usize, amount: usize, - ) -> Range { - panic_invalid_node(self, node); - self.hugr_mut().insert_ports(node, direction, index, amount) - } + ) -> Range; /// Sets the parent of a node. /// @@ -263,11 +158,7 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If either the node or the parent is not in the graph. - fn set_parent(&mut self, node: Node, parent: Node) { - panic_invalid_node(self, parent); - panic_invalid_non_root(self, node); - self.hugr_mut().set_parent(node, parent); - } + fn set_parent(&mut self, node: Self::Node, parent: Self::Node); /// Move a node in the hierarchy to be the subsequent sibling of another /// node. @@ -279,11 +170,7 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If either node is not in the graph, or if it is a root. - fn move_after_sibling(&mut self, node: Node, after: Node) { - panic_invalid_non_root(self, node); - panic_invalid_non_root(self, after); - self.hugr_mut().move_after_sibling(node, after); - } + fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); /// Move a node in the hierarchy to be the prior sibling of another node. /// @@ -294,11 +181,7 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If either node is not in the graph, or if it is a root. - fn move_before_sibling(&mut self, node: Node, before: Node) { - panic_invalid_non_root(self, node); - panic_invalid_non_root(self, before); - self.hugr_mut().move_before_sibling(node, before) - } + fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); /// Replace the OpType at node and return the old OpType. /// In general this invalidates the ports, which may need to be resized to @@ -306,7 +189,8 @@ pub trait HugrMutInternals: RootTagged { /// /// Returns the old OpType. /// - /// TODO: Add a version which ignores input extensions + /// If the module root is set to a non-module operation the hugr will + /// become invalid. /// /// # Errors /// @@ -316,48 +200,68 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If the node is not in the graph. 
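Editor's note: an illustrative sketch (not part of the patch) of the low-level port-resizing methods documented above. The import path for the internal traits is an assumption based on this diff's module layout; `widen_inputs` and `prepend_inputs` are placeholder names.

use hugr_core::hugr::internal::HugrMutInternals;
use hugr_core::{Direction, Hugr, Node};

// Append two incoming ports to `node`; the returned range names the offsets of
// the newly created ports.
fn widen_inputs(hugr: &mut Hugr, node: Node) -> std::ops::Range<usize> {
    hugr.add_ports(node, Direction::Incoming, 2)
}

// Splice two new incoming ports in at offset 0, shifting existing links up.
fn prepend_inputs(hugr: &mut Hugr, node: Node) -> std::ops::Range<usize> {
    hugr.insert_ports(node, Direction::Incoming, 0, 2)
}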
- fn replace_op(&mut self, node: Node, op: impl Into) -> Result { - panic_invalid_node(self, node); - let op = op.into(); - if node == self.root() && !Self::RootHandle::TAG.is_superset(op.tag()) { - return Err(HugrError::InvalidTag { - required: Self::RootHandle::TAG, - actual: op.tag(), - }); - } - self.hugr_mut().replace_op(node, op) - } + fn replace_op(&mut self, node: Self::Node, op: impl Into) -> Result; /// Gets a mutable reference to the optype. /// /// Changing this may invalidate the ports, which may need to be resized to /// match the OpType signature. /// - /// Will panic for the root node unless [`Self::RootHandle`](RootTagged::RootHandle) - /// is [OpTag::Any], as mutation could invalidate the bound. - fn optype_mut(&mut self, node: Node) -> &mut OpType { - if Self::RootHandle::TAG.is_superset(OpTag::Any) { - panic_invalid_node(self, node); - } else { - panic_invalid_non_root(self, node); - } - self.hugr_mut().op_types.get_mut(node.pg_index()) - } + /// Mutating the root node operation may invalidate the root tag. + /// + /// Mutating the module root into a non-module operation will invalidate the hugr. + /// + /// # Panics + /// + /// If the node is not in the graph. + fn optype_mut(&mut self, node: Self::Node) -> &mut OpType; + + /// Returns a metadata entry associated with a node. + /// + /// # Panics + /// + /// If the node is not in the graph. + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut NodeMetadataMap; + + /// Returns a mutable reference to the extension registry for this hugr, + /// containing all extensions required to define the operations and types in + /// the hugr. + fn extensions_mut(&mut self) -> &mut ExtensionRegistry; } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. 
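Editor's note: a sketch of the two new required accessors declared above. It assumes the internal traits are importable as `hugr_core::hugr::internal` and that `NodeMetadataMap` is a JSON object map; both are assumptions, and the function names are placeholders.

use hugr_core::hugr::internal::HugrMutInternals;
use hugr_core::{Hugr, Node};

fn mark_visited(hugr: &mut Hugr, node: Node) {
    // `node_metadata_map_mut` creates the per-node map on first access.
    hugr.node_metadata_map_mut(node)
        .insert("my_pass/visited".to_string(), serde_json::Value::Bool(true));
}

// `extensions_mut` exposes the registry that `use_extension` and
// `extend_extensions` write into.
fn extension_count(hugr: &mut Hugr) -> usize {
    hugr.extensions_mut().len()
}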
-impl + AsMut> HugrMutInternals for T { - fn hugr_mut(&mut self) -> &mut Hugr { - self.as_mut() +impl HugrMutInternals for Hugr { + fn set_root(&mut self, root: Node) { + panic_invalid_node(self, root); + self.root = self.get_pg_index(root); } #[inline] fn set_num_ports(&mut self, node: Node, incoming: usize, outgoing: usize) { - self.hugr_mut() - .graph + panic_invalid_node(self, node); + self.graph .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}) } + fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { + panic_invalid_node(self, node); + let mut incoming = self.graph.num_inputs(node.pg_index()); + let mut outgoing = self.graph.num_outputs(node.pg_index()); + let increment = |num: &mut usize| { + let new = num.saturating_add_signed(amount); + let range = *num..new; + *num = new; + range + }; + let range = match direction { + Direction::Incoming => increment(&mut incoming), + Direction::Outgoing => increment(&mut outgoing), + }; + self.graph + .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}); + range + } + fn insert_ports( &mut self, node: Node, @@ -365,6 +269,7 @@ impl + AsMut> HugrMutInterna index: usize, amount: usize, ) -> Range { + panic_invalid_node(self, node); let old_num_ports = self.base_hugr().graph.num_ports(node.pg_index(), direction); self.add_ports(node, direction, amount as isize); @@ -383,10 +288,9 @@ impl + AsMut> HugrMutInterna .port_links(from_port_index) .map(|(_, to_subport)| to_subport.port()) .collect_vec(); - self.hugr_mut().graph.unlink_port(from_port_index); + self.graph.unlink_port(from_port_index); for linked_port_index in linked_ports { let _ = self - .hugr_mut() .graph .link_ports(to_port_index, linked_port_index) .expect("Ports exist"); @@ -395,53 +299,55 @@ impl + AsMut> HugrMutInterna index..index + amount } - fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { - let mut incoming = self.hugr_mut().graph.num_inputs(node.pg_index()); - let mut outgoing = self.hugr_mut().graph.num_outputs(node.pg_index()); - let increment = |num: &mut usize| { - let new = num.saturating_add_signed(amount); - let range = *num..new; - *num = new; - range - }; - let range = match direction { - Direction::Incoming => increment(&mut incoming), - Direction::Outgoing => increment(&mut outgoing), - }; - self.hugr_mut() - .graph - .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}); - range - } - fn set_parent(&mut self, node: Node, parent: Node) { - self.hugr_mut().hierarchy.detach(node.pg_index()); - self.hugr_mut() - .hierarchy + panic_invalid_node(self, parent); + panic_invalid_node(self, node); + self.hierarchy.detach(node.pg_index()); + self.hierarchy .push_child(node.pg_index(), parent.pg_index()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } fn move_after_sibling(&mut self, node: Node, after: Node) { - self.hugr_mut().hierarchy.detach(node.pg_index()); - self.hugr_mut() - .hierarchy + panic_invalid_non_root(self, node); + panic_invalid_non_root(self, after); + self.hierarchy.detach(node.pg_index()); + self.hierarchy .insert_after(node.pg_index(), after.pg_index()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } fn move_before_sibling(&mut self, node: Node, before: Node) { - self.hugr_mut().hierarchy.detach(node.pg_index()); - self.hugr_mut() - .hierarchy + panic_invalid_non_root(self, node); + panic_invalid_non_root(self, before); + self.hierarchy.detach(node.pg_index()); + self.hierarchy 
.insert_before(node.pg_index(), before.pg_index()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } fn replace_op(&mut self, node: Node, op: impl Into) -> Result { + panic_invalid_node(self, node); // We know RootHandle=Node here so no need to check Ok(std::mem::replace(self.optype_mut(node), op.into())) } + + fn optype_mut(&mut self, node: Self::Node) -> &mut OpType { + panic_invalid_node(self, node); + let node = self.get_pg_index(node); + self.op_types.get_mut(node) + } + + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut NodeMetadataMap { + panic_invalid_node(self, node); + self.metadata + .get_mut(node.pg_index()) + .get_or_insert_with(Default::default) + } + + fn extensions_mut(&mut self) -> &mut ExtensionRegistry { + &mut self.extensions + } } #[cfg(test)] diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs index 7c4374b65..d2b0fe14d 100644 --- a/hugr-core/src/hugr/rewrite.rs +++ b/hugr-core/src/hugr/rewrite.rs @@ -9,7 +9,8 @@ mod port_types; pub mod replace; pub mod simple_replace; -use crate::{Hugr, HugrView, Node}; +use crate::core::HugrNode; +use crate::{Hugr, HugrView}; pub use port_types::{BoundaryPort, HostPort, ReplacementPort}; pub use simple_replace::{SimpleReplacement, SimpleReplacementError}; @@ -17,6 +18,8 @@ use super::HugrMut; /// An operation that can be applied to mutate a Hugr pub trait Rewrite { + /// The node type used by the target Hugr. + type Node: HugrNode; /// The type of Error with which this Rewrite may fail type Error: std::error::Error; /// The type returned on successful application of the rewrite. @@ -29,7 +32,7 @@ pub trait Rewrite { /// Checks whether the rewrite would succeed on the specified Hugr. /// If this call succeeds, [self.apply] should also succeed on the same `h` /// If this calls fails, [self.apply] would fail with the same error. - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error>; + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error>; /// Mutate the specified Hugr, or fail with an error. /// Returns [`Self::ApplyResult`] if successful. @@ -39,14 +42,17 @@ pub trait Rewrite { /// May panic if-and-only-if `h` would have failed [Hugr::validate]; that is, /// implementations may begin with `assert!(h.validate())`, with `debug_assert!(h.validate())` /// being preferred. - fn apply(self, h: &mut impl HugrMut) -> Result; + fn apply( + self, + h: &mut impl HugrMut, + ) -> Result; /// Returns a set of nodes referenced by the rewrite. Modifying any of these /// nodes will invalidate it. /// /// Two `impl Rewrite`s can be composed if their invalidation sets are /// disjoint. - fn invalidation_set(&self) -> impl Iterator; + fn invalidation_set(&self) -> impl Iterator; } /// Wraps any rewrite into a transaction (i.e. 
that has no effect upon failure) @@ -57,15 +63,16 @@ pub struct Transactional { // Note we might like to constrain R to Rewrite but this // is not yet supported, https://github.com/rust-lang/rust/issues/92827 impl Rewrite for Transactional { + type Node = R::Node; type Error = R::Error; type ApplyResult = R::ApplyResult; const UNCHANGED_ON_FAILURE: bool = true; - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { self.underlying.verify(h) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn apply(self, h: &mut impl HugrMut) -> Result { if R::UNCHANGED_ON_FAILURE { return self.underlying.apply(h); } @@ -86,7 +93,7 @@ impl Rewrite for Transactional { } #[inline] - fn invalidation_set(&self) -> impl Iterator { + fn invalidation_set(&self) -> impl Iterator { self.underlying.invalidation_set() } } diff --git a/hugr-core/src/hugr/rewrite/consts.rs b/hugr-core/src/hugr/rewrite/consts.rs index c112dfc57..ac657bf91 100644 --- a/hugr-core/src/hugr/rewrite/consts.rs +++ b/hugr-core/src/hugr/rewrite/consts.rs @@ -25,6 +25,7 @@ pub enum RemoveError { } impl Rewrite for RemoveLoadConstant { + type Node = Node; type Error = RemoveError; // The Const node the LoadConstant was connected to. @@ -50,7 +51,7 @@ impl Rewrite for RemoveLoadConstant { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn apply(self, h: &mut impl HugrMut) -> Result { self.verify(h)?; let node = self.0; let source = h @@ -73,6 +74,7 @@ impl Rewrite for RemoveLoadConstant { pub struct RemoveConst(pub Node); impl Rewrite for RemoveConst { + type Node = Node; type Error = RemoveError; // The parent of the Const node. @@ -94,7 +96,7 @@ impl Rewrite for RemoveConst { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn apply(self, h: &mut impl HugrMut) -> Result { self.verify(h)?; let node = self.0; let parent = h diff --git a/hugr-core/src/hugr/rewrite/inline_call.rs b/hugr-core/src/hugr/rewrite/inline_call.rs index 9af9cd70a..6b1e7a958 100644 --- a/hugr-core/src/hugr/rewrite/inline_call.rs +++ b/hugr-core/src/hugr/rewrite/inline_call.rs @@ -33,6 +33,7 @@ impl InlineCall { } impl Rewrite for InlineCall { + type Node = Node; type ApplyResult = (); type Error = InlineCallError; fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { @@ -51,7 +52,7 @@ impl Rewrite for InlineCall { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> { self.verify(h)?; // Now we know we have a Call to a FuncDefn. let orig_func = h.static_source(self.0).unwrap(); diff --git a/hugr-core/src/hugr/rewrite/inline_dfg.rs b/hugr-core/src/hugr/rewrite/inline_dfg.rs index a8a09e0cc..8988df170 100644 --- a/hugr-core/src/hugr/rewrite/inline_dfg.rs +++ b/hugr-core/src/hugr/rewrite/inline_dfg.rs @@ -23,6 +23,7 @@ pub enum InlineDFGError { impl Rewrite for InlineDFG { /// Returns the removed nodes: the DFG, and its Input and Output children. 
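Editor's note: a hypothetical downstream implementation of the updated `Rewrite` trait, showing the new `Node` associated type and the node-parameterised `verify`/`apply` signatures. This is a sketch only; the import paths are assumptions and `DeleteNode` is not part of the patch.

use std::convert::Infallible;

use hugr_core::hugr::rewrite::Rewrite;
use hugr_core::hugr::HugrMut;
use hugr_core::{HugrView, Node};

/// Toy rewrite that removes a single non-root node.
struct DeleteNode(Node);

impl Rewrite for DeleteNode {
    type Node = Node;
    type Error = Infallible;
    type ApplyResult = ();
    const UNCHANGED_ON_FAILURE: bool = true;

    fn verify(&self, _h: &impl HugrView<Node = Node>) -> Result<(), Infallible> {
        Ok(())
    }

    fn apply(self, h: &mut impl HugrMut<Node = Node>) -> Result<(), Infallible> {
        h.remove_node(self.0);
        Ok(())
    }

    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
        std::iter::once(self.0)
    }
}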
+ type Node = Node; type ApplyResult = [Node; 3]; type Error = InlineDFGError; @@ -39,7 +40,10 @@ impl Rewrite for InlineDFG { Ok(()) } - fn apply(self, h: &mut impl crate::hugr::HugrMut) -> Result { + fn apply( + self, + h: &mut impl crate::hugr::HugrMut, + ) -> Result { self.verify(h)?; let n = self.0.node(); let (oth_in, oth_out) = { diff --git a/hugr-core/src/hugr/rewrite/insert_identity.rs b/hugr-core/src/hugr/rewrite/insert_identity.rs index 2114be8fd..bde43413b 100644 --- a/hugr-core/src/hugr/rewrite/insert_identity.rs +++ b/hugr-core/src/hugr/rewrite/insert_identity.rs @@ -48,6 +48,7 @@ pub enum IdentityInsertionError { } impl Rewrite for IdentityInsertion { + type Node = Node; type Error = IdentityInsertionError; /// The inserted node. type ApplyResult = Node; @@ -65,7 +66,10 @@ impl Rewrite for IdentityInsertion { unimplemented!() } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn apply( + self, + h: &mut impl HugrMut, + ) -> Result { let kind = h.get_optype(self.post_node).port_kind(self.post_port); let Some(EdgeKind::Value(ty)) = kind else { return Err(IdentityInsertionError::InvalidPortKind(kind)); diff --git a/hugr-core/src/hugr/rewrite/outline_cfg.rs b/hugr-core/src/hugr/rewrite/outline_cfg.rs index 7294bfcad..a76dbc6ee 100644 --- a/hugr-core/src/hugr/rewrite/outline_cfg.rs +++ b/hugr-core/src/hugr/rewrite/outline_cfg.rs @@ -6,14 +6,12 @@ use thiserror::Error; use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer}; use crate::extension::ExtensionSet; -use crate::hugr::internal::HugrMutInternals; use crate::hugr::rewrite::Rewrite; -use crate::hugr::views::sibling::SiblingMut; use crate::hugr::{HugrMut, HugrView}; use crate::ops; use crate::ops::controlflow::BasicBlock; use crate::ops::dataflow::DataflowOpTrait; -use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle}; +use crate::ops::handle::NodeHandle; use crate::ops::{DataflowBlock, OpType}; use crate::PortIndex; use crate::{type_row, Node}; @@ -95,6 +93,7 @@ impl OutlineCfg { } impl Rewrite for OutlineCfg { + type Node = Node; type Error = OutlineCfgError; /// The newly-created basic block, and the [CFG] node inside it /// @@ -185,8 +184,19 @@ impl Rewrite for OutlineCfg { let inner_exit = { // These operations do not fit within any CSG/SiblingMut // so we need to access the Hugr directly. - let h = h.hugr_mut(); - let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap(); + // + // TODO: This is a temporary hack that won't be needed once Hugr Root Pointers get implemented. + // The commented line below are the correct ones, but they don't work yet. + // https://github.com/CQCL/hugr/issues/2029 + let hierarchy = h.hierarchy(); + let inner_exit = hierarchy + .children(h.get_pg_index(cfg_node)) + .exactly_one() + .ok() + .unwrap(); + let inner_exit = h.get_node(inner_exit); + //let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap(); + // Entry node must be first h.move_before_sibling(entry, inner_exit); // And remaining nodes @@ -200,12 +210,7 @@ impl Rewrite for OutlineCfg { }; // 4(b). Reconnect exit edge to the new exit node within the inner CFG - // Use nested SiblingMut's in case the outer `h` is only a SiblingMut itself. 
- let mut in_bb_view: SiblingMut<'_, BasicBlockID> = - SiblingMut::try_new(h, new_block).unwrap(); - let mut in_cfg_view: SiblingMut<'_, CfgID> = - SiblingMut::try_new(&mut in_bb_view, cfg_node).unwrap(); - in_cfg_view.connect(exit, exit_port, inner_exit, 0); + h.connect(exit, exit_port, inner_exit, 0); Ok((new_block, cfg_node)) } @@ -252,10 +257,9 @@ mod test { HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::usize_t; - use crate::hugr::views::sibling::SiblingMut; use crate::hugr::HugrMut; use crate::ops::constant::Value; - use crate::ops::handle::{BasicBlockID, CfgID, ConstID, NodeHandle}; + use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle}; use crate::types::Signature; use crate::{Hugr, HugrView, Node}; use cool_asserts::assert_matches; @@ -457,11 +461,7 @@ mod test { h.output_neighbours(tail).collect::>(), HashSet::from([head, exit_node]) ); - outline_cfg_check_parents( - &mut SiblingMut::<'_, CfgID>::try_new(&mut h, cfg).unwrap(), - cfg, - vec![head, tail], - ); + outline_cfg_check_parents(&mut h, cfg, vec![head, tail]); h.validate().unwrap(); } @@ -491,7 +491,7 @@ mod test { } fn outline_cfg_check_parents( - h: &mut impl HugrMut, + h: &mut impl HugrMut, cfg: Node, blocks: Vec, ) -> (Node, Node, Node) { diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index 55c07d680..c2659cc5a 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -222,6 +222,7 @@ impl Replacement { } } impl Rewrite for Replacement { + type Node = Node; type Error = ReplaceError; /// Map from Node in replacement to corresponding Node in the result Hugr @@ -282,7 +283,7 @@ impl Rewrite for Replacement { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn apply(self, h: &mut impl HugrMut) -> Result { let parent = self.check_parent(h)?; // Calculate removed nodes here. (Does not include transfers, so enumerates only // nodes we are going to remove, individually, anyway; so no *asymptotic* speed penalty) @@ -343,7 +344,7 @@ impl Rewrite for Replacement { } fn transfer_edges<'a>( - h: &mut impl HugrMut, + h: &mut impl HugrMut, edges: impl Iterator, trans_src: impl Fn(Node) -> Result, trans_tgt: impl Fn(Node) -> Result, diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index b4ec37db1..5d3716dc0 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -275,6 +275,7 @@ impl SimpleReplacement { } impl Rewrite for SimpleReplacement { + type Node = Node; type Error = SimpleReplacementError; type ApplyResult = Vec<(Node, OpType)>; const UNCHANGED_ON_FAILURE: bool = true; @@ -283,7 +284,7 @@ impl Rewrite for SimpleReplacement { self.is_valid_rewrite(h) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn apply(self, h: &mut impl HugrMut) -> Result { self.is_valid_rewrite(h)?; let parent = self.subgraph.get_parent(h); diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index 09805d1f8..eb8059577 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -485,7 +485,7 @@ pub trait RootTagged: HugrView { /// /// The handle is guaranteed to be able to contain the operation returned by /// [`HugrView::root_type`]. - type RootHandle: NodeHandle; + type RootHandle: NodeHandle; } /// A common trait for views of a HUGR hierarchical subgraph. 
@@ -515,7 +515,8 @@ pub trait ExtractHugr: HugrView + Sized { } } -fn check_tag( +/// Check that the node in a HUGR can be represented by the required tag. +fn check_tag, N>( hugr: &impl HugrView, node: N, ) -> Result<(), HugrError> { @@ -527,18 +528,6 @@ fn check_tag( Ok(()) } -impl RootTagged for Hugr { - type RootHandle = Node; -} - -impl RootTagged for &Hugr { - type RootHandle = Node; -} - -impl RootTagged for &mut Hugr { - type RootHandle = Node; -} - // Explicit implementation to avoid cloning the Hugr. impl ExtractHugr for Hugr { fn extract_hugr(self) -> Hugr { diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index 6f87027ef..28a7d9f2d 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -179,7 +179,7 @@ where } #[inline] - fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex { + fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { self.hugr.get_pg_index(node) } @@ -187,6 +187,10 @@ where fn get_node(&self, index: portgraph::NodeIndex) -> Node { self.hugr.get_node(index) } + + fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap { + self.hugr.node_metadata_map(node) + } } #[cfg(test)] diff --git a/hugr-core/src/hugr/views/impls.rs b/hugr-core/src/hugr/views/impls.rs index 2cfc70104..928acba20 100644 --- a/hugr-core/src/hugr/views/impls.rs +++ b/hugr-core/src/hugr/views/impls.rs @@ -1,119 +1,285 @@ +//! Implementation of the core hugr traits for different wrappers of a `Hugr`. + use std::{borrow::Cow, rc::Rc, sync::Arc}; -use delegate::delegate; -use itertools::Either; +use super::HugrView; +use super::RootTagged; +use crate::hugr::internal::{HugrInternals, HugrMutInternals}; +use crate::hugr::HugrMut; +use crate::Hugr; +use crate::Node; -use super::{render::RenderConfig, HugrView, RootChecked}; -use crate::{ - extension::ExtensionRegistry, - hugr::{NodeMetadata, NodeMetadataMap, ValidationError}, - ops::OpType, - types::{PolyFuncType, Signature, Type}, - Direction, Hugr, IncomingPort, OutgoingPort, Port, -}; +macro_rules! hugr_internal_methods { + // The extra ident here is because invocations of the macro cannot pass `self` as argument + ($arg:ident, $e:expr) => { + delegate::delegate! { + to ({let $arg=self; $e}) { + fn portgraph(&self) -> Self::Portgraph<'_>; + fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; + fn base_hugr(&self) -> &crate::Hugr; + fn root_node(&self) -> Self::Node; + fn get_pg_index(&self, node: impl crate::ops::handle::NodeHandle) -> portgraph::NodeIndex; + fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; + fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap; + } + } + }; +} +pub(crate) use hugr_internal_methods; macro_rules! hugr_view_methods { // The extra ident here is because invocations of the macro cannot pass `self` as argument ($arg:ident, $e:expr) => { - delegate! { + delegate::delegate! 
{ to ({let $arg=self; $e}) { fn root(&self) -> Self::Node; - fn root_type(&self) -> &OpType; + fn root_type(&self) -> &crate::ops::OpType; fn contains_node(&self, node: Self::Node) -> bool; fn valid_node(&self, node: Self::Node) -> bool; fn valid_non_root(&self, node: Self::Node) -> bool; fn get_parent(&self, node: Self::Node) -> Option; - fn get_optype(&self, node: Self::Node) -> &OpType; - fn get_metadata(&self, node: Self::Node, key: impl AsRef) -> Option<&NodeMetadata>; - fn get_node_metadata(&self, node: Self::Node) -> Option<&NodeMetadataMap>; + fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType; + fn get_metadata(&self, node: Self::Node, key: impl AsRef) -> Option<&crate::hugr::NodeMetadata>; + fn get_node_metadata(&self, node: Self::Node) -> Option<&crate::hugr::NodeMetadataMap>; fn node_count(&self) -> usize; fn edge_count(&self) -> usize; fn nodes(&self) -> impl Iterator + Clone; - fn node_ports(&self, node: Self::Node, dir: Direction) -> impl Iterator + Clone; - fn node_outputs(&self, node: Self::Node) -> impl Iterator + Clone; - fn node_inputs(&self, node: Self::Node) -> impl Iterator + Clone; - fn all_node_ports(&self, node: Self::Node) -> impl Iterator + Clone; + fn node_ports(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator + Clone; + fn node_outputs(&self, node: Self::Node) -> impl Iterator + Clone; + fn node_inputs(&self, node: Self::Node) -> impl Iterator + Clone; + fn all_node_ports(&self, node: Self::Node) -> impl Iterator + Clone; fn linked_ports( &self, node: Self::Node, - port: impl Into, - ) -> impl Iterator + Clone; + port: impl Into, + ) -> impl Iterator + Clone; fn all_linked_ports( &self, node: Self::Node, - dir: Direction, - ) -> Either< - impl Iterator, - impl Iterator, + dir: crate::Direction, + ) -> itertools::Either< + impl Iterator, + impl Iterator, >; - fn all_linked_outputs(&self, node: Self::Node) -> impl Iterator; - fn all_linked_inputs(&self, node: Self::Node) -> impl Iterator; - fn single_linked_port(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, Port)>; - fn single_linked_output(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, OutgoingPort)>; - fn single_linked_input(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, IncomingPort)>; - fn linked_outputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; - fn linked_inputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; - fn node_connections(&self, node: Self::Node, other: Self::Node) -> impl Iterator + Clone; - fn is_linked(&self, node: Self::Node, port: impl Into) -> bool; - fn num_ports(&self, node: Self::Node, dir: Direction) -> usize; + fn all_linked_outputs(&self, node: Self::Node) -> impl Iterator; + fn all_linked_inputs(&self, node: Self::Node) -> impl Iterator; + fn single_linked_port(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, crate::Port)>; + fn single_linked_output(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, crate::OutgoingPort)>; + fn single_linked_input(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, crate::IncomingPort)>; + fn linked_outputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; + fn linked_inputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; + fn node_connections(&self, node: Self::Node, other: Self::Node) -> impl Iterator + Clone; + fn is_linked(&self, node: Self::Node, port: impl Into) -> bool; + fn num_ports(&self, node: Self::Node, dir: crate::Direction) -> usize; fn num_inputs(&self, node: 
Self::Node) -> usize; fn num_outputs(&self, node: Self::Node) -> usize; fn children(&self, node: Self::Node) -> impl DoubleEndedIterator + Clone; fn first_child(&self, node: Self::Node) -> Option; - fn neighbours(&self, node: Self::Node, dir: Direction) -> impl Iterator + Clone; + fn neighbours(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator + Clone; fn input_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; fn output_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; fn all_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; fn get_io(&self, node: Self::Node) -> Option<[Self::Node; 2]>; - fn inner_function_type(&self) -> Option>; - fn poly_func_type(&self) -> Option; + fn inner_function_type(&self) -> Option>; + fn poly_func_type(&self) -> Option; // TODO: cannot use delegate here. `PetgraphWrapper` is a thin // wrapper around `Self`, so falling back to the default impl // should be harmless. // fn as_petgraph(&self) -> PetgraphWrapper<'_, Self>; fn mermaid_string(&self) -> String; - fn mermaid_string_with_config(&self, config: RenderConfig) -> String; + fn mermaid_string_with_config(&self, config: crate::hugr::views::render::RenderConfig) -> String; fn dot_string(&self) -> String; fn static_source(&self, node: Self::Node) -> Option; - fn static_targets(&self, node: Self::Node) -> Option>; - fn signature(&self, node: Self::Node) -> Option>; - fn value_types(&self, node: Self::Node, dir: Direction) -> impl Iterator; - fn in_value_types(&self, node: Self::Node) -> impl Iterator; - fn out_value_types(&self, node: Self::Node) -> impl Iterator; - fn extensions(&self) -> &ExtensionRegistry; - fn validate(&self) -> Result<(), ValidationError>; - fn validate_no_extensions(&self) -> Result<(), ValidationError>; + fn static_targets(&self, node: Self::Node) -> Option>; + fn signature(&self, node: Self::Node) -> Option>; + fn value_types(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator; + fn in_value_types(&self, node: Self::Node) -> impl Iterator; + fn out_value_types(&self, node: Self::Node) -> impl Iterator; + fn extensions(&self) -> &crate::extension::ExtensionRegistry; + fn validate(&self) -> Result<(), crate::hugr::ValidationError>; + fn validate_no_extensions(&self) -> Result<(), crate::hugr::ValidationError>; } } } } +pub(crate) use hugr_view_methods; + +macro_rules! hugr_mut_internal_methods { + // The extra ident here is because invocations of the macro cannot pass `self` as argument + ($arg:ident, $e:expr) => { + delegate::delegate! { + to ({let $arg=self; $e}) { + fn set_root(&mut self, root: Self::Node); + fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize); + fn add_ports(&mut self, node: Self::Node, direction: crate::Direction, amount: isize) -> std::ops::Range; + fn insert_ports(&mut self, node: Self::Node, direction: crate::Direction, index: usize, amount: usize) -> std::ops::Range; + fn set_parent(&mut self, node: Self::Node, parent: Self::Node); + fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); + fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); + fn replace_op(&mut self, node: Self::Node, op: impl Into) -> Result; + fn optype_mut(&mut self, node: Self::Node) -> &mut crate::ops::OpType; + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut crate::hugr::NodeMetadataMap; + fn extensions_mut(&mut self) -> &mut crate::extension::ExtensionRegistry; + } + } + }; +} +pub(crate) use hugr_mut_internal_methods; + +macro_rules! 
hugr_mut_methods { + // The extra ident here is because invocations of the macro cannot pass `self` as argument + ($arg:ident, $e:expr) => { + delegate::delegate! { + to ({let $arg=self; $e}) { + fn add_node_with_parent(&mut self, parent: Self::Node, op: impl Into) -> Self::Node; + fn add_node_before(&mut self, sibling: Self::Node, nodetype: impl Into) -> Self::Node; + fn add_node_after(&mut self, sibling: Self::Node, op: impl Into) -> Self::Node; + fn remove_node(&mut self, node: Self::Node) -> crate::ops::OpType; + fn remove_subtree(&mut self, node: Self::Node); + fn copy_descendants(&mut self, root: Self::Node, new_parent: Self::Node, subst: Option) -> std::collections::BTreeMap; + fn connect(&mut self, src: Self::Node, src_port: impl Into, dst: Self::Node, dst_port: impl Into); + fn disconnect(&mut self, node: Self::Node, port: impl Into); + fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (crate::OutgoingPort, crate::IncomingPort); + fn insert_hugr(&mut self, root: Self::Node, other: crate::Hugr) -> crate::hugr::hugrmut::InsertionResult; + fn insert_from_view(&mut self, root: Self::Node, other: &Other) -> crate::hugr::hugrmut::InsertionResult; + fn insert_subgraph(&mut self, root: Self::Node, other: &Other, subgraph: &crate::hugr::views::SiblingSubgraph) -> std::collections::HashMap; + } + } + }; +} +pub(crate) use hugr_mut_methods; + +// -------- Base Hugr implementation +impl RootTagged for Hugr { + type RootHandle = Node; +} + +// -------- Immutable borrow +impl HugrInternals for &T { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + hugr_internal_methods! {this, *this} +} impl HugrView for &T { hugr_view_methods! {this, *this} } +impl RootTagged for &T { + type RootHandle = T::RootHandle; +} + +// -------- Mutable borrow +impl HugrInternals for &mut T { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + hugr_internal_methods! {this, &**this} +} impl HugrView for &mut T { hugr_view_methods! {this, &**this} } +impl RootTagged for &mut T { + type RootHandle = T::RootHandle; +} +impl HugrMutInternals for &mut T { + hugr_mut_internal_methods! {this, &mut **this} +} +impl HugrMut for &mut T { + hugr_mut_methods! {this, &mut **this} +} + +// -------- Rc +impl HugrInternals for Rc { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + hugr_internal_methods! {this, this.as_ref()} +} impl HugrView for Rc { hugr_view_methods! {this, this.as_ref()} } +impl RootTagged for Rc { + type RootHandle = T::RootHandle; +} + +// -------- Arc +impl HugrInternals for Arc { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + hugr_internal_methods! {this, this.as_ref()} +} impl HugrView for Arc { hugr_view_methods! {this, this.as_ref()} } +impl RootTagged for Arc { + type RootHandle = T::RootHandle; +} + +// -------- Box +impl HugrInternals for Box { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + hugr_internal_methods! {this, this.as_ref()} +} impl HugrView for Box { hugr_view_methods! {this, this.as_ref()} } +impl RootTagged for Box { + type RootHandle = T::RootHandle; +} +impl HugrMutInternals for Box { + hugr_mut_internal_methods! {this, this.as_mut()} +} +impl HugrMut for Box { + hugr_mut_methods! {this, this.as_mut()} +} +// -------- Cow +impl HugrInternals for Cow<'_, T> { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + + hugr_internal_methods! 
{this, this.as_ref()} +} impl HugrView for Cow<'_, T> { hugr_view_methods! {this, this.as_ref()} } - -impl, Root> HugrView for RootChecked { - hugr_view_methods! {this, this.as_ref()} +impl RootTagged for Cow<'_, T> { + type RootHandle = T::RootHandle; +} +impl HugrMutInternals for Cow<'_, T> +where + T: HugrMutInternals + ToOwned, + ::Owned: HugrMutInternals, +{ + hugr_mut_internal_methods! {this, this.to_mut()} +} +impl HugrMut for Cow<'_, T> +where + T: HugrMut + ToOwned, + ::Owned: HugrMut, +{ + hugr_mut_methods! {this, this.to_mut()} } #[cfg(test)] diff --git a/hugr-core/src/hugr/views/root_checked.rs b/hugr-core/src/hugr/views/root_checked.rs index ba214241a..e0dcf3eb7 100644 --- a/hugr-core/src/hugr/views/root_checked.rs +++ b/hugr-core/src/hugr/views/root_checked.rs @@ -1,22 +1,20 @@ use std::borrow::Cow; use std::marker::PhantomData; -use delegate::delegate; -use portgraph::MultiPortGraph; - use crate::hugr::internal::{HugrInternals, HugrMutInternals}; use crate::hugr::{HugrError, HugrMut}; use crate::ops::handle::NodeHandle; +use crate::ops::OpTrait; use crate::{Hugr, Node}; -use super::{check_tag, RootTagged}; +use super::{check_tag, HugrView, RootTagged}; /// A view of the whole Hugr. /// (Just provides static checking of the type of the root node) #[derive(Clone)] pub struct RootChecked(H, PhantomData); -impl, Root: NodeHandle> RootChecked { +impl> RootChecked { /// Create a hierarchical view of a whole HUGR /// /// # Errors @@ -49,26 +47,21 @@ impl RootChecked<&mut Hugr, Root> { } } -impl, Root> HugrInternals for RootChecked { +impl HugrInternals for RootChecked { type Portgraph<'p> - = &'p MultiPortGraph + = H::Portgraph<'p> where Self: 'p; - type Node = Node; - - delegate! { - to self.as_ref() { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Node; - fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Node; - } - } + type Node = H::Node; + + super::impls::hugr_internal_methods! {this, &this.0} } -impl, Root: NodeHandle> RootTagged for RootChecked { +impl HugrView for RootChecked { + super::impls::hugr_view_methods! {this, &this.0} +} + +impl> RootTagged for RootChecked { type RootHandle = Root; } @@ -78,17 +71,41 @@ impl, Root> AsRef for RootChecked { } } -impl, Root> HugrMutInternals for RootChecked -where - Root: NodeHandle, -{ - #[inline(always)] - fn hugr_mut(&mut self) -> &mut Hugr { - self.0.hugr_mut() +impl> HugrMutInternals for RootChecked { + fn replace_op( + &mut self, + node: Self::Node, + op: impl Into, + ) -> Result { + let op = op.into(); + if node == self.root() && !Root::TAG.is_superset(op.tag()) { + return Err(HugrError::InvalidTag { + required: Root::TAG, + actual: op.tag(), + }); + } + self.0.replace_op(node, op) + } + + delegate::delegate! 
{ + to (&mut self.0) { + fn set_root(&mut self, root: Self::Node); + fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize); + fn add_ports(&mut self, node: Self::Node, direction: crate::Direction, amount: isize) -> std::ops::Range; + fn insert_ports(&mut self, node: Self::Node, direction: crate::Direction, index: usize, amount: usize) -> std::ops::Range; + fn set_parent(&mut self, node: Self::Node, parent: Self::Node); + fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); + fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); + fn optype_mut(&mut self, node: Self::Node) -> &mut crate::ops::OpType; + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut crate::hugr::NodeMetadataMap; + fn extensions_mut(&mut self) -> &mut crate::extension::ExtensionRegistry; + } } } -impl, Root: NodeHandle> HugrMut for RootChecked {} +impl> HugrMut for RootChecked { + super::impls::hugr_mut_methods! {this, &mut this.0} +} #[cfg(test)] mod test { diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index f93b14cb4..4d15a9c48 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -6,8 +6,9 @@ use itertools::{Either, Itertools}; use portgraph::{LinkView, MultiPortGraph, PortView}; use crate::hugr::internal::HugrMutInternals; -use crate::hugr::{HugrError, HugrMut}; +use crate::hugr::{HugrError, HugrMut, NodeMetadataMap}; use crate::ops::handle::NodeHandle; +use crate::ops::OpTrait; use crate::{Direction, Hugr, Node, Port}; use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged}; @@ -212,7 +213,7 @@ where } #[inline] - fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex { + fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { self.hugr.get_pg_index(node) } @@ -220,6 +221,11 @@ where fn get_node(&self, index: portgraph::NodeIndex) -> Node { self.hugr.get_node(index) } + + #[inline] + fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap { + self.hugr.node_metadata_map(node) + } } /// Mutable view onto a HUGR sibling graph. @@ -233,101 +239,113 @@ where /// [HugrView] methods may be slower than for an immutable [SiblingGraph] /// as the latter may cache information about the graph connectivity, /// whereas (in order to ease mutation) this does not. -pub struct SiblingMut<'g, Root = Node> { +pub struct SiblingMut<'g, H: HugrView, Root = Node> { /// The chosen root node. - root: Node, + root: H::Node, /// The rest of the HUGR. - hugr: &'g mut Hugr, + hugr: &'g mut H, /// The operation type of the root node. _phantom: std::marker::PhantomData, } -impl<'g, Root: NodeHandle> SiblingMut<'g, Root> { +impl<'g, H: HugrMut, Root: NodeHandle> SiblingMut<'g, H, Root> { /// Create a new SiblingMut from a base. /// Equivalent to [HierarchyView::try_new] but takes a *mutable* reference. 
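Editor's note: a sketch of using `RootChecked` now that it implements `HugrView` and `HugrMut` directly instead of exposing `hugr_mut`. The import path and the exact bounds of `try_new` are assumptions; `demo` is a placeholder name.

use hugr_core::hugr::views::RootChecked;
use hugr_core::ops::handle::DfgID;
use hugr_core::{Hugr, HugrView};

fn demo(hugr: &mut Hugr) -> usize {
    // Fails if the root operation does not satisfy the `DfgID` tag.
    let checked: RootChecked<&mut Hugr, DfgID> =
        RootChecked::try_new(hugr).expect("root is a DFG");
    // The wrapper is itself a full HugrView (and HugrMut), so no unwrapping is needed.
    checked.node_count()
}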
- pub fn try_new(hugr: &'g mut Base, root: Node) -> Result { - if root == hugr.root() && !Base::RootHandle::TAG.is_superset(Root::TAG) { + pub fn try_new(hugr: &'g mut H, root: H::Node) -> Result { + if root == hugr.root() && !H::RootHandle::TAG.is_superset(Root::TAG) { return Err(HugrError::InvalidTag { - required: Base::RootHandle::TAG, + required: H::RootHandle::TAG, actual: Root::TAG, }); } check_tag::(hugr, root)?; Ok(Self { - hugr: hugr.hugr_mut(), + hugr, root, _phantom: std::marker::PhantomData, }) } } -impl ExtractHugr for SiblingMut<'_, Root> {} +impl> ExtractHugr for SiblingMut<'_, H, Root> {} -impl<'g, Root: NodeHandle> HugrInternals for SiblingMut<'g, Root> { +impl<'g, H: HugrMut, Root: NodeHandle> HugrInternals for SiblingMut<'g, H, Root> { type Portgraph<'p> = FlatRegionGraph<'p> where 'g: 'p, Root: 'p; - type Node = Node; + type Node = H::Node; + #[inline] fn portgraph(&self) -> Self::Portgraph<'_> { FlatRegionGraph::new( &self.base_hugr().graph, &self.base_hugr().hierarchy, - self.root.pg_index(), + self.get_pg_index(self.root), ) } + #[inline] fn base_hugr(&self) -> &Hugr { - self.hugr + self.hugr.base_hugr() } - fn root_node(&self) -> Node { + #[inline] + fn root_node(&self) -> Self::Node { self.root } #[inline] - fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex { + fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { self.hugr.get_pg_index(node) } #[inline] - fn get_node(&self, index: portgraph::NodeIndex) -> Node { + fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node { self.hugr.get_node(index) } + + #[inline] + fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap { + self.hugr.node_metadata_map(node) + } } -impl HugrView for SiblingMut<'_, Root> { +impl> HugrView for SiblingMut<'_, H, Root> { impl_base_members! {} - fn contains_node(&self, node: Node) -> bool { + fn contains_node(&self, node: H::Node) -> bool { // Don't call self.get_parent(). That requires valid_node(node) // which infinitely-recurses back here. 
- node == self.root || self.base_hugr().get_parent(node) == Some(self.root) + node == self.root || self.hugr.get_parent(node) == Some(self.root) } - fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { - self.base_hugr().node_ports(node, dir) + fn node_ports(&self, node: Self::Node, dir: Direction) -> impl Iterator + Clone { + self.hugr.node_ports(node, dir) } - fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { - self.base_hugr().all_node_ports(node) + fn all_node_ports(&self, node: Self::Node) -> impl Iterator + Clone { + self.hugr.all_node_ports(node) } fn linked_ports( &self, - node: Node, + node: Self::Node, port: impl Into, - ) -> impl Iterator + Clone { + ) -> impl Iterator + Clone { self.hugr .linked_ports(node, port) .filter(|(n, _)| self.contains_node(*n)) } - fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { + fn node_connections( + &self, + node: Self::Node, + other: Self::Node, + ) -> impl Iterator + Clone { match self.contains_node(node) && self.contains_node(other) { // The nodes are not in the sibling graph false => Either::Left(iter::empty()), @@ -336,34 +354,66 @@ impl HugrView for SiblingMut<'_, Root> { } } - fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.base_hugr().num_ports(node, dir) + fn num_ports(&self, node: Self::Node, dir: Direction) -> usize { + self.hugr.num_ports(node, dir) } - fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { + fn neighbours( + &self, + node: Self::Node, + dir: Direction, + ) -> impl Iterator + Clone { self.hugr .neighbours(node, dir) .filter(|n| self.contains_node(*n)) } - fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { + fn all_neighbours(&self, node: Self::Node) -> impl Iterator + Clone { self.hugr .all_neighbours(node) .filter(|n| self.contains_node(*n)) } } -impl RootTagged for SiblingMut<'_, Root> { +impl> RootTagged for SiblingMut<'_, H, Root> { type RootHandle = Root; } -impl HugrMutInternals for SiblingMut<'_, Root> { - fn hugr_mut(&mut self) -> &mut Hugr { - self.hugr +impl> HugrMutInternals for SiblingMut<'_, H, Root> { + fn replace_op( + &mut self, + node: Self::Node, + op: impl Into, + ) -> Result { + let op = op.into(); + if node == self.root() && !Root::TAG.is_superset(op.tag()) { + return Err(HugrError::InvalidTag { + required: Root::TAG, + actual: op.tag(), + }); + } + self.hugr.replace_op(node, op) + } + + delegate::delegate! { + to (&mut *self.hugr) { + fn set_root(&mut self, root: Self::Node); + fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize); + fn add_ports(&mut self, node: Self::Node, direction: crate::Direction, amount: isize) -> std::ops::Range; + fn insert_ports(&mut self, node: Self::Node, direction: crate::Direction, index: usize, amount: usize) -> std::ops::Range; + fn set_parent(&mut self, node: Self::Node, parent: Self::Node); + fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); + fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); + fn optype_mut(&mut self, node: Self::Node) -> &mut crate::ops::OpType; + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut crate::hugr::NodeMetadataMap; + fn extensions_mut(&mut self) -> &mut crate::extension::ExtensionRegistry; + } } } -impl HugrMut for SiblingMut<'_, Root> {} +impl> HugrMut for SiblingMut<'_, H, Root> { + super::impls::hugr_mut_methods! 
{this, &mut *this.hugr} +} #[cfg(test)] mod test { @@ -475,7 +525,7 @@ mod test { let mut def_region_hugr = hugr.clone(); let mut inner_region_hugr = hugr.clone(); - test_properties::( + test_properties::>( &hugr, def, inner, @@ -526,7 +576,7 @@ mod test { let root = simple_dfg_hugr.root(); let signature = simple_dfg_hugr.inner_function_type().unwrap().into_owned(); - let sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root); + let sib_mut = SiblingMut::<_, CfgID>::try_new(&mut simple_dfg_hugr, root); assert_eq!( sib_mut.err(), Some(HugrError::InvalidTag { @@ -535,7 +585,7 @@ mod test { }) ); - let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); + let mut sib_mut = SiblingMut::<_, DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap(); let bad_nodetype: OpType = crate::ops::CFG { signature }.into(); assert_eq!( sib_mut.replace_op(sib_mut.root(), bad_nodetype.clone()), @@ -560,7 +610,7 @@ mod test { .unwrap() .into_owned(), }; - let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); + let mut sib_mut = SiblingMut::<_, DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap(); // As expected, we cannot replace the root with a Case assert_eq!( sib_mut.replace_op(root, case_nodetype), @@ -570,7 +620,7 @@ mod test { }) ); - let nested_sib_mut = SiblingMut::::try_new(&mut sib_mut, root); + let nested_sib_mut = SiblingMut::<_, DataflowParentID>::try_new(&mut sib_mut, root); assert!(nested_sib_mut.is_err()); } diff --git a/hugr-llvm/src/utils/inline_constant_functions.rs b/hugr-llvm/src/utils/inline_constant_functions.rs index 28e664b97..a55072a99 100644 --- a/hugr-llvm/src/utils/inline_constant_functions.rs +++ b/hugr-llvm/src/utils/inline_constant_functions.rs @@ -11,12 +11,12 @@ fn const_fn_name(konst_n: Node) -> String { format!("const_fun_{}", konst_n.index()) } -pub fn inline_constant_functions(hugr: &mut impl HugrMut) -> Result<()> { +pub fn inline_constant_functions(hugr: &mut impl HugrMut) -> Result<()> { while inline_constant_functions_impl(hugr)? {} Ok(()) } -fn inline_constant_functions_impl(hugr: &mut impl HugrMut) -> Result { +fn inline_constant_functions_impl(hugr: &mut impl HugrMut) -> Result { let mut const_funs = vec![]; for n in hugr.nodes() { diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs index fb3319155..ad8ff1ec0 100644 --- a/hugr-passes/src/composable.rs +++ b/hugr-passes/src/composable.rs @@ -2,6 +2,7 @@ use std::{error::Error, marker::PhantomData}; +use hugr_core::core::HugrNode; use hugr_core::hugr::{hugrmut::HugrMut, ValidationError}; use hugr_core::HugrView; use itertools::Either; @@ -9,36 +10,40 @@ use itertools::Either; /// An optimization pass that can be sequenced with another and/or wrapped /// e.g. by [ValidatingPass] pub trait ComposablePass: Sized { + type Node: HugrNode; type Error: Error; type Result; // Would like to default to () but currently unstable - fn run(&self, hugr: &mut impl HugrMut) -> Result; + fn run(&self, hugr: &mut impl HugrMut) -> Result; fn map_err( self, f: impl Fn(Self::Error) -> E2, - ) -> impl ComposablePass { + ) -> impl ComposablePass { ErrMapper::new(self, f) } /// Returns a [ComposablePass] that does "`self` then `other`", so long as /// `other::Err` can be combined with ours. 
- fn then>( + fn then, E: ErrorCombiner>( self, other: P, - ) -> impl ComposablePass { + ) -> impl ComposablePass { struct Sequence(P1, P2, PhantomData); impl ComposablePass for Sequence where P1: ComposablePass, - P2: ComposablePass, + P2: ComposablePass, E: ErrorCombiner, { + type Node = P1::Node; type Error = E; - type Result = (P1::Result, P2::Result); - fn run(&self, hugr: &mut impl HugrMut) -> Result { + fn run( + &self, + hugr: &mut impl HugrMut, + ) -> Result { let res1 = self.0.run(hugr).map_err(E::from_first)?; let res2 = self.1.run(hugr).map_err(E::from_second)?; Ok((res1, res2)) @@ -95,10 +100,11 @@ impl E> ErrMapper { } impl E> ComposablePass for ErrMapper { + type Node = P::Node; type Error = E; type Result = P::Result; - fn run(&self, hugr: &mut impl HugrMut) -> Result { + fn run(&self, hugr: &mut impl HugrMut) -> Result { self.0.run(hugr).map_err(&self.1) } } @@ -157,10 +163,11 @@ impl ValidatingPass

{ } impl<P: ComposablePass> ComposablePass for ValidatingPass<P>

{ + type Node = P::Node; type Error = ValidatePassError; type Result = P::Result; - fn run(&self, hugr: &mut impl HugrMut) -> Result { + fn run(&self, hugr: &mut impl HugrMut) -> Result { self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input { err, pretty_hugr, @@ -180,8 +187,11 @@ impl ComposablePass for ValidatingPass
{ /// executes a second pass pub struct IfThen(A, B, PhantomData); -impl, B: ComposablePass, E: ErrorCombiner> - IfThen +impl< + A: ComposablePass, + B: ComposablePass, + E: ErrorCombiner, + > IfThen { /// Make a new instance given the [ComposablePass] to run first /// and (maybe) second @@ -190,14 +200,17 @@ impl, B: ComposablePass, E: ErrorCombiner, B: ComposablePass, E: ErrorCombiner> - ComposablePass for IfThen +impl< + A: ComposablePass, + B: ComposablePass, + E: ErrorCombiner, + > ComposablePass for IfThen { + type Node = A::Node; type Error = E; - type Result = Option; - fn run(&self, hugr: &mut impl HugrMut) -> Result { + fn run(&self, hugr: &mut impl HugrMut) -> Result { let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?; res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second)) .transpose() @@ -206,7 +219,7 @@ impl, B: ComposablePass, E: ErrorCombiner( pass: P, - hugr: &mut impl HugrMut, + hugr: &mut impl HugrMut, ) -> Result> { if cfg!(test) { ValidatingPass::new_default(pass).run(hugr) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 99ccc180c..b406ae894 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -79,6 +79,7 @@ impl ConstantFoldPass { } impl ComposablePass for ConstantFoldPass { + type Node = Node; type Error = ConstFoldError; type Result = (); @@ -88,7 +89,7 @@ impl ComposablePass for ConstantFoldPass { /// /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] /// was of an invalid [OpType] - fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { let fresh_node = Node::from(portgraph::NodeIndex::new( hugr.nodes().max().map_or(0, |n| n.index() + 1), )); @@ -175,7 +176,7 @@ impl ComposablePass for ConstantFoldPass { /// /// [FuncDefn]: hugr_core::ops::OpType::FuncDefn /// [Module]: hugr_core::ops::OpType::Module -pub fn constant_fold_pass(h: &mut H) { +pub fn constant_fold_pass>(h: &mut H) { let c = ConstantFoldPass::default(); let c = if h.get_optype(h.root()).is_module() { let no_inputs: [(IncomingPort, _); 0] = []; diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index 240f4f2d6..f7b8a171c 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -52,7 +52,7 @@ pub struct Sum { } /// The output of an [LoadFunction](hugr_core::ops::LoadFunction) - a "pointer" -/// to a function at a specific node, instantiated with the provided type-args. +/// to a function at a specific node, instantiated with the provided type-args. 
#[derive(Clone, Debug, Hash, PartialEq, Eq)] pub struct LoadedFunction { /// The [FuncDefn](hugr_core::ops::FuncDefn) or `FuncDecl`` that was loaded diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index 899e30243..d92fed134 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -158,10 +158,11 @@ impl DeadCodeElimPass { } impl ComposablePass for DeadCodeElimPass { + type Node = Node; type Error = Infallible; type Result = (); - fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Infallible> { + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Infallible> { let needed = self.find_needed_nodes(&*hugr); let remove = hugr .nodes() diff --git a/hugr-passes/src/dead_funcs.rs b/hugr-passes/src/dead_funcs.rs index 7071d5335..d1714eac9 100644 --- a/hugr-passes/src/dead_funcs.rs +++ b/hugr-passes/src/dead_funcs.rs @@ -83,9 +83,10 @@ impl RemoveDeadFuncsPass { } impl ComposablePass for RemoveDeadFuncsPass { + type Node = Node; type Error = RemoveDeadFuncsError; type Result = (); - fn run(&self, hugr: &mut impl HugrMut) -> Result<(), RemoveDeadFuncsError> { + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), RemoveDeadFuncsError> { let reachable = reachable_funcs( &CallGraph::new(hugr), hugr, @@ -125,7 +126,7 @@ impl ComposablePass for RemoveDeadFuncsPass { /// [LoadFunction]: hugr_core::ops::OpType::LoadFunction /// [Module]: hugr_core::ops::OpType::Module pub fn remove_dead_funcs( - h: &mut impl HugrMut, + h: &mut impl HugrMut, entry_points: impl IntoIterator, ) -> Result<(), ValidatePassError> { validate_if_test( diff --git a/hugr-passes/src/force_order.rs b/hugr-passes/src/force_order.rs index 689479b95..ad40e2164 100644 --- a/hugr-passes/src/force_order.rs +++ b/hugr-passes/src/force_order.rs @@ -36,7 +36,7 @@ use petgraph::{ /// there is no path from `n2` to `n1` (otherwise this would invalidate `hugr`). /// Nodes of equal rank will be ordered arbitrarily, although that arbitrary /// order is deterministic. -pub fn force_order( +pub fn force_order>( hugr: &mut H, root: Node, rank: impl Fn(&H, Node) -> i64, @@ -46,7 +46,7 @@ pub fn force_order( /// As [force_order], but allows a generic [Ord] choice for the result of the /// `rank` function. -pub fn force_order_by_key( +pub fn force_order_by_key, K: Ord>( hugr: &mut H, root: Node, rank: impl Fn(&H, Node) -> K, diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 8f8920967..8de6c00a2 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -15,7 +15,7 @@ use thiserror::Error; /// /// Returns a [`HugrError`] if any replacement fails. pub fn replace_many_ops>( - hugr: &mut impl HugrMut, + hugr: &mut impl HugrMut, mapping: impl Fn(&OpType) -> Option, ) -> Result, HugrError> { let replacements = hugr @@ -54,7 +54,7 @@ pub enum LowerError { /// /// Returns a [`LowerError`] if the lowered HUGR is invalid or if any rewrite fails. pub fn lower_ops( - hugr: &mut impl HugrMut, + hugr: &mut impl HugrMut, lowering: impl Fn(&OpType) -> Option, ) -> Result, LowerError> { let replacements = hugr diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index aeabc26ce..d1731107d 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -16,8 +16,8 @@ use hugr_core::{Hugr, HugrView, Node}; /// Merge any basic blocks that are direct children of the specified CFG /// i.e. where a basic block B has a single successor B' whose only predecessor /// is B, B and B' can be combined. 
-pub fn merge_basic_blocks(cfg: &mut impl HugrMut) { - let mut worklist = cfg.nodes().collect::>(); +pub fn merge_basic_blocks(cfg: &mut impl HugrMut) { + let mut worklist = cfg.children(cfg.root()).collect::>(); while let Some(n) = worklist.pop() { // Consider merging n with its successor let Ok(succ) = cfg.output_neighbours(n).exactly_one() else { @@ -33,13 +33,11 @@ pub fn merge_basic_blocks(cfg: &mut impl HugrMut) { continue; }; let (rep, merge_bb, dfgs) = mk_rep(cfg, n, succ); - let node_map = cfg.hugr_mut().apply_rewrite(rep).unwrap(); + let node_map = cfg.apply_rewrite(rep).unwrap(); let merged_bb = *node_map.get(&merge_bb).unwrap(); for dfg_id in dfgs { let n_id = *node_map.get(&dfg_id).unwrap(); - cfg.hugr_mut() - .apply_rewrite(InlineDFG(n_id.into())) - .unwrap(); + cfg.apply_rewrite(InlineDFG(n_id.into())).unwrap(); } worklist.push(merged_bb); } @@ -160,12 +158,12 @@ mod test { use std::sync::Arc; use hugr_core::extension::prelude::PRELUDE_ID; + use hugr_core::hugr::views::RootChecked; use itertools::Itertools; use rstest::rstest; use hugr_core::builder::{endo_sig, inout_sig, CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; use hugr_core::extension::prelude::{qb_t, usize_t, ConstUsize}; - use hugr_core::hugr::views::sibling::SiblingMut; use hugr_core::ops::constant::Value; use hugr_core::ops::handle::CfgID; use hugr_core::ops::{LoadConstant, OpTrait, OpType}; @@ -254,7 +252,7 @@ mod test { let mut h = h.finish_hugr()?; let r = h.root(); - merge_basic_blocks(&mut SiblingMut::::try_new(&mut h, r)?); + merge_basic_blocks(&mut RootChecked::<_, CfgID>::try_new(&mut h).unwrap()); h.validate().unwrap(); assert_eq!(r, h.root()); assert!(matches!(h.get_optype(r), OpType::CFG(_))); @@ -348,8 +346,7 @@ mod test { h.branch(&bb3, 0, &h.exit_block())?; let mut h = h.finish_hugr()?; - let root = h.root(); - merge_basic_blocks(&mut SiblingMut::try_new(&mut h, root)?); + merge_basic_blocks(&mut RootChecked::<_, CfgID>::try_new(&mut h).unwrap()); h.validate()?; // Should only be one BB left diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 875ee9355..3164702d8 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -33,7 +33,9 @@ use crate::ComposablePass; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -pub fn monomorphize(hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { +pub fn monomorphize( + hugr: &mut impl HugrMut, +) -> Result<(), ValidatePassError> { validate_if_test(MonomorphizePass, hugr) } @@ -56,7 +58,7 @@ pub fn remove_polyfuncs(mut h: Hugr) -> Hugr { since = "0.14.1", note = "Use hugr_passes::RemoveDeadFuncsPass instead" )] -fn remove_polyfuncs_ref(h: &mut impl HugrMut) { +fn remove_polyfuncs_ref(h: &mut impl HugrMut) { let mut pfs_to_delete = Vec::new(); let mut to_scan = Vec::from_iter(h.children(h.root())); while let Some(n) = to_scan.pop() { @@ -92,7 +94,7 @@ type Instantiations = HashMap, Node>>; /// Optionally copies the subtree into a new location whilst applying a substitution. /// The subtree should be monomorphic after the substitution (if provided) has been applied. 
fn mono_scan( - h: &mut impl HugrMut, + h: &mut impl HugrMut, parent: Node, mut subst_into: Option<&mut Instantiating>, cache: &mut Instantiations, @@ -160,7 +162,7 @@ fn mono_scan( } fn instantiate( - h: &mut impl HugrMut, + h: &mut impl HugrMut, poly_func: Node, type_args: Vec, mono_sig: Signature, @@ -258,10 +260,11 @@ fn instantiate( pub struct MonomorphizePass; impl ComposablePass for MonomorphizePass { + type Node = Node; type Error = Infallible; type Result = (); - fn run(&self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + fn run(&self, h: &mut impl HugrMut) -> Result<(), Self::Error> { let root = h.root(); // If the root is a polymorphic function, then there are no external calls, so nothing to do if !is_polymorphic_funcdefn(h.get_optype(root)) { diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index 9baf250f9..1c4928e12 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -51,7 +51,7 @@ use hugr_core::hugr::{hugrmut::HugrMut, Rewrite, RootTagged}; use hugr_core::ops::handle::{BasicBlockID, CfgID}; use hugr_core::ops::OpTag; use hugr_core::ops::OpTrait; -use hugr_core::{Direction, Hugr}; +use hugr_core::{Direction, Hugr, Node}; /// A "view" of a CFG in a Hugr which allows basic blocks in the underlying CFG to be split into /// multiple blocks in the view (or merged together). @@ -155,7 +155,7 @@ pub fn transform_cfg_to_nested( pub fn transform_all_cfgs(h: &mut Hugr) { let mut node_stack = Vec::from([h.root()]); while let Some(n) = node_stack.pop() { - if let Ok(s) = SiblingMut::::try_new(h, n) { + if let Ok(s) = SiblingMut::<_, CfgID>::try_new(h, n) { transform_cfg_to_nested(&mut IdentityCfgMap::new(s)); } node_stack.extend(h.children(n)) @@ -246,7 +246,7 @@ impl CfgNodeMap for IdentityCfgMap { } } -impl CfgNester for IdentityCfgMap { +impl> CfgNester for IdentityCfgMap { fn nest_sese_region( &mut self, entry_edge: (H::Node, H::Node), @@ -760,7 +760,7 @@ pub(crate) mod test { // Again, there's no need for a view of a region here, but check that the // transformation still works when we can only directly mutate the top level let root = h.root(); - let m = SiblingMut::::try_new(&mut h, root).unwrap(); + let m = SiblingMut::<_, CfgID>::try_new(&mut h, root).unwrap(); transform_cfg_to_nested(&mut IdentityCfgMap::new(m)); h.validate().unwrap(); assert_eq!(1, depth(&h, entry)); diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index df4c14075..d33234126 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -66,7 +66,11 @@ impl NodeTemplate { /// * has a [`signature`] which the type-args of the [Self::Call] do not match /// /// [`signature`]: hugr_core::types::PolyFuncType - pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Result { + pub fn add_hugr( + self, + hugr: &mut impl HugrMut, + parent: Node, + ) -> Result { match self { NodeTemplate::SingleOp(op_type) => Ok(hugr.add_node_with_parent(parent, op_type)), NodeTemplate::CompoundOp(new_h) => Ok(hugr.insert_hugr(parent, *new_h).new_root), @@ -97,7 +101,7 @@ impl NodeTemplate { } } - fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { + fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { NodeTemplate::SingleOp(op_type) => op_type, @@ -375,7 +379,11 @@ impl ReplaceTypes { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } - fn change_node(&self, hugr: &mut 
impl HugrMut, n: Node) -> Result { + fn change_node( + &self, + hugr: &mut impl HugrMut, + n: Node, + ) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) | OpType::FuncDecl(FuncDecl { signature, .. }) => signature.body_mut().transform(self), @@ -505,10 +513,11 @@ impl ReplaceTypes { } impl ComposablePass for ReplaceTypes { + type Node = Node; type Error = ReplaceTypesError; type Result = bool; - fn run(&self, hugr: &mut impl HugrMut) -> Result { + fn run(&self, hugr: &mut impl HugrMut) -> Result { let mut changed = false; for n in hugr.nodes().collect::>() { changed |= self.change_node(hugr, n)?; diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 5c4a4a707..321ec194f 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -49,7 +49,7 @@ pub trait Linearizer { /// if `src` is not a valid Wire (does not identify a dataflow out-port) fn insert_copy_discard( &self, - hugr: &mut impl HugrMut, + hugr: &mut impl HugrMut, src: Wire, targets: &[(Node, IncomingPort)], ) -> Result<(), LinearizeError> { diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index 874fd9ec3..d074bed0f 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -122,11 +122,11 @@ impl UntuplePass { } impl ComposablePass for UntuplePass { + type Node = Node; type Error = UntupleError; - type Result = UntupleResult; - fn run(&self, hugr: &mut impl HugrMut) -> Result { + fn run(&self, hugr: &mut impl HugrMut) -> Result { let rewrites = self.find_rewrites(hugr, self.parent.unwrap_or(hugr.root())); let rewrites_applied = rewrites.len(); // The rewrites are independent, so we can always apply them all. From 2e32c58e52c9de2f5b78a6e975e82c02d04c755d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Mon, 28 Apr 2025 13:55:33 +0100 Subject: [PATCH 10/21] feat!: Removed model_unstable feature flag (#2120) Re-created from #2113, targeting the release branch instead. BREAKING CHANGE: Downstream crates need to remove the model_unstable feature flag when referencing hugr or hugr-core. --------- Co-authored-by: Lukas Heidemann --- .github/workflows/ci-rs.yml | 43 +++++++++---------- hugr-core/Cargo.toml | 4 +- hugr-core/README.md | 20 ++++----- hugr-core/src/envelope.rs | 33 +------------- hugr-core/src/lib.rs | 2 - .../std_extensions/arithmetic/float_types.rs | 1 - .../std_extensions/arithmetic/int_types.rs | 1 - .../src/std_extensions/collections/array.rs | 1 - hugr/Cargo.toml | 3 +- hugr/benches/benchmarks/hugr.rs | 23 ++++------ release-plz.toml | 4 +- 11 files changed, 42 insertions(+), 93 deletions(-) diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index 824291a8b..56c093e8b 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -6,7 +6,7 @@ on: - main pull_request: branches: - - '**' + - "**" merge_group: types: [checks_requested] workflow_dispatch: {} @@ -25,7 +25,6 @@ env: LLVM_VERSION: "14.0" LLVM_FEATURE_NAME: "14-0" - jobs: # Check if changes were made to the relevant files. # Always returns true if running on the default branch, to ensure all changes are thoroughly checked. 
@@ -43,25 +42,25 @@ jobs: model: ${{ steps.filter.outputs.model == 'true' || steps.override.outputs.out == 'true' }} llvm: ${{ steps.filter.outputs.llvm == 'true' || steps.override.outputs.out == 'true' }} steps: - - uses: actions/checkout@v4 - - name: Override label - id: override - run: | - echo "Label contains run-ci-checks: $OVERRIDE_LABEL" - if [ "$OVERRIDE_LABEL" == "true" ]; then - echo "Overriding due to label 'run-ci-checks'" - echo "out=true" >> $GITHUB_OUTPUT - elif [ "$DEFAULT_BRANCH" == "true" ]; then - echo "Overriding due to running on the default branch" - echo "out=true" >> $GITHUB_OUTPUT - fi - env: - OVERRIDE_LABEL: ${{ github.event_name == 'pull_request' && contains( github.event.pull_request.labels.*.name, 'run-ci-checks') }} - DEFAULT_BRANCH: ${{ github.ref_name == github.event.repository.default_branch }} - - uses: dorny/paths-filter@v3 - id: filter - with: - filters: .github/change-filters.yml + - uses: actions/checkout@v4 + - name: Override label + id: override + run: | + echo "Label contains run-ci-checks: $OVERRIDE_LABEL" + if [ "$OVERRIDE_LABEL" == "true" ]; then + echo "Overriding due to label 'run-ci-checks'" + echo "out=true" >> $GITHUB_OUTPUT + elif [ "$DEFAULT_BRANCH" == "true" ]; then + echo "Overriding due to running on the default branch" + echo "out=true" >> $GITHUB_OUTPUT + fi + env: + OVERRIDE_LABEL: ${{ github.event_name == 'pull_request' && contains( github.event.pull_request.labels.*.name, 'run-ci-checks') }} + DEFAULT_BRANCH: ${{ github.ref_name == github.event.repository.default_branch }} + - uses: dorny/paths-filter@v3 + id: filter + with: + filters: .github/change-filters.yml check: needs: changes @@ -109,7 +108,7 @@ jobs: - name: Override criterion with the CodSpeed harness run: cargo add --dev codspeed-criterion-compat --rename criterion --package hugr - name: Build benchmarks - run: cargo codspeed build --profile bench --features extension_inference,declarative,model_unstable,llvm,llvm-test + run: cargo codspeed build --profile bench --features extension_inference,declarative,llvm,llvm-test - name: Run benchmarks uses: CodSpeedHQ/action@v3 with: diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index b5a7908ef..49417097a 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -19,7 +19,6 @@ workspace = true [features] extension_inference = [] declarative = ["serde_yaml"] -model_unstable = ["hugr-model"] zstd = ["dep:zstd"] [lib] @@ -27,10 +26,9 @@ bench = false [[test]] name = "model" -required-features = ["model_unstable"] [dependencies] -hugr-model = { version = "0.19.1", path = "../hugr-model", optional = true } +hugr-model = { version = "0.19.1", path = "../hugr-model" } cgmath = { workspace = true, features = ["serde"] } delegate = { workspace = true } diff --git a/hugr-core/README.md b/hugr-core/README.md index 46cafe16f..379041a5b 100644 --- a/hugr-core/README.md +++ b/hugr-core/README.md @@ -1,7 +1,6 @@ ![](/hugr/assets/hugr_logo.svg) -hugr-core -=============== +# hugr-core [![build_status][]](https://github.com/CQCL/hugr/actions) [![crates][]](https://crates.io/crates/hugr-core) @@ -21,9 +20,6 @@ Please read the [API documentation here][]. Not enabled by default. - `declarative`: Experimental support for declaring extensions in YAML files, support is limited. -- `model_unstable` - Import and export from the representation defined in the `hugr-model` crate. - Unstable and subject to change. Not enabled by default. 
## Recent Changes @@ -38,10 +34,10 @@ See [DEVELOPMENT.md](https://github.com/CQCL/hugr/blob/main/DEVELOPMENT.md) for This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). - [API documentation here]: https://docs.rs/hugr-core/ - [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg - [crates]: https://img.shields.io/crates/v/hugr-core - [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov - [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE - [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-core/CHANGELOG.md +[API documentation here]: https://docs.rs/hugr-core/ +[build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main +[msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg +[crates]: https://img.shields.io/crates/v/hugr-core +[codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov +[LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE +[CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-core/CHANGELOG.md diff --git a/hugr-core/src/envelope.rs b/hugr-core/src/envelope.rs index 35ea9c85f..24c348b78 100644 --- a/hugr-core/src/envelope.rs +++ b/hugr-core/src/envelope.rs @@ -55,7 +55,6 @@ use std::io::Write; #[allow(unused_imports)] use itertools::Itertools as _; -#[cfg(feature = "model_unstable")] use crate::import::ImportError; /// Read a HUGR envelope from a reader. @@ -197,19 +196,16 @@ pub enum EnvelopeError { source: PackageEncodingError, }, /// Error importing a HUGR from a hugr-model payload. - #[cfg(feature = "model_unstable")] ModelImport { /// The source error. source: ImportError, }, /// Error reading a HUGR model payload. - #[cfg(feature = "model_unstable")] ModelRead { /// The source error. source: hugr_model::v0::binary::ReadError, }, /// Error writing a HUGR model payload. - #[cfg(feature = "model_unstable")] ModelWrite { /// The source error. source: hugr_model::v0::binary::WriteError, @@ -225,17 +221,9 @@ fn read_impl( match header.format { #[allow(deprecated)] EnvelopeFormat::PackageJson => Ok(Package::from_json_reader(payload, registry)?), - #[cfg(feature = "model_unstable")] EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => { decode_model(payload, registry, header.format) } - #[cfg(not(feature = "model_unstable"))] - EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => { - Err(EnvelopeError::FormatUnsupported { - format: header.format, - feature: Some("model_unstable"), - }) - } } } @@ -246,7 +234,6 @@ fn read_impl( /// - `extension_registry`: An extension registry with additional extensions to use when /// decoding the HUGR, if they are not already included in the package. /// - `format`: The format of the payload. -#[cfg(feature = "model_unstable")] fn decode_model( mut stream: impl BufRead, extension_registry: &ExtensionRegistry, @@ -286,22 +273,13 @@ fn write_impl( match config.format { #[allow(deprecated)] EnvelopeFormat::PackageJson => package.to_json_writer(writer)?, - #[cfg(feature = "model_unstable")] EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => { encode_model(writer, package, config.format)? 
} - #[cfg(not(feature = "model_unstable"))] - EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => { - return Err(EnvelopeError::FormatUnsupported { - format: config.format, - feature: Some("model_unstable"), - }) - } } Ok(()) } -#[cfg(feature = "model_unstable")] fn encode_model( mut writer: impl Write, package: &Package, @@ -391,7 +369,6 @@ mod tests { //#[case::empty(Package::default())] // Not currently supported #[case::simple(simple_package())] //#[case::multi(multi_module_package())] // Not currently supported - #[cfg(feature = "model_unstable")] fn module_exts_roundtrip(#[case] package: Package) { let mut buffer = Vec::new(); let config = EnvelopeConfig { @@ -417,15 +394,7 @@ mod tests { format: EnvelopeFormat::Model, zstd: None, }; - let res = package.store(&mut buffer, config); - - match cfg!(feature = "model_unstable") { - true => res.unwrap(), - false => { - assert_matches!(res, Err(EnvelopeError::FormatUnsupported { .. })); - return; - } - } + package.store(&mut buffer, config).unwrap(); let (decoded_config, new_package) = read_envelope(BufReader::new(buffer.as_slice()), &PRELUDE_REGISTRY).unwrap(); diff --git a/hugr-core/src/lib.rs b/hugr-core/src/lib.rs index e32b623f2..e5f57d2a8 100644 --- a/hugr-core/src/lib.rs +++ b/hugr-core/src/lib.rs @@ -12,11 +12,9 @@ pub mod builder; pub mod core; pub mod envelope; -#[cfg(feature = "model_unstable")] pub mod export; pub mod extension; pub mod hugr; -#[cfg(feature = "model_unstable")] pub mod import; pub mod macros; pub mod ops; diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index 3122bf30f..200e9dcbf 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -65,7 +65,6 @@ impl std::ops::Deref for ConstF64 { impl ConstF64 { /// Name of the constructor for creating constant 64bit floats. - #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "arithmetic.float.const_f64"; /// Create a new [`ConstF64`] diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index e5d625695..1342dd932 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -105,7 +105,6 @@ pub struct ConstInt { impl ConstInt { /// Name of the constructor for creating constant integers. - #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "arithmetic.int.const"; /// Create a new [`ConstInt`] with a given width and unsigned value diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index 0332ff351..fac12b1bf 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -45,7 +45,6 @@ pub struct ArrayValue { impl ArrayValue { /// Name of the constructor for creating constant arrays. - #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "collections.array.const"; /// Create a new [CustomConst] for an array of values of type `typ`. 
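Taken together, the envelope changes above mean the model formats are always available to downstream code. A minimal round-trip sketch follows, modelled on the test above; the import paths, the public visibility of `EnvelopeConfig`'s fields, and the location of `PRELUDE_REGISTRY` are assumptions for illustration rather than a verified API surface.

```rust
// Illustrative sketch only: round-trip a Package through the (now always-enabled)
// model envelope format. Paths and field visibility are assumed from the diff above.
use std::io::BufReader;

use hugr_core::envelope::{read_envelope, EnvelopeConfig, EnvelopeError, EnvelopeFormat};
use hugr_core::extension::prelude::PRELUDE_REGISTRY;
use hugr_core::package::Package;

fn model_roundtrip(package: &Package) -> Result<Package, EnvelopeError> {
    let mut buffer = Vec::new();
    // No `#[cfg(feature = "model_unstable")]` gate is needed any more.
    let config = EnvelopeConfig {
        format: EnvelopeFormat::Model,
        zstd: None,
    };
    package.store(&mut buffer, config)?;
    let (_config, decoded) =
        read_envelope(BufReader::new(buffer.as_slice()), &PRELUDE_REGISTRY)?;
    Ok(decoded)
}
```

With the feature flag gone, the `FormatUnsupported` fallback branches for the model formats are removed as well, since those code paths are now always compiled in.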
diff --git a/hugr/Cargo.toml b/hugr/Cargo.toml index a3f347ad9..d96ac8fe5 100644 --- a/hugr/Cargo.toml +++ b/hugr/Cargo.toml @@ -26,13 +26,12 @@ default = ["zstd"] extension_inference = ["hugr-core/extension_inference"] declarative = ["hugr-core/declarative"] -model_unstable = ["hugr-core/model_unstable", "hugr-model"] llvm = ["hugr-llvm/llvm14-0"] llvm-test = ["hugr-llvm/llvm14-0", "hugr-llvm/test-utils"] zstd = ["hugr-core/zstd"] [dependencies] -hugr-model = { path = "../hugr-model", optional = true, version = "0.19.1" } +hugr-model = { path = "../hugr-model", version = "0.19.1" } hugr-core = { path = "../hugr-core", version = "0.15.4" } hugr-passes = { path = "../hugr-passes", version = "0.15.4" } hugr-llvm = { path = "../hugr-llvm", version = "0.15.4", optional = true } diff --git a/hugr/benches/benchmarks/hugr.rs b/hugr/benches/benchmarks/hugr.rs index 49d73d58e..3635c8d09 100644 --- a/hugr/benches/benchmarks/hugr.rs +++ b/hugr/benches/benchmarks/hugr.rs @@ -24,10 +24,8 @@ impl Serializer for JsonSer { } } -#[cfg(feature = "model_unstable")] struct CapnpSer; -#[cfg(feature = "model_unstable")] impl Serializer for CapnpSer { fn serialize(&self, hugr: &Hugr) -> Vec { let bump = bumpalo::Bump::new(); @@ -90,20 +88,17 @@ fn bench_serialization(c: &mut Criterion) { } group.finish(); - #[cfg(feature = "model_unstable")] - { - let mut group = c.benchmark_group("circuit_roundtrip/capnp"); - group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); - for size in [0, 1, 10, 100, 1000].iter() { - group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { - let h = circuit(size).0; - b.iter(|| { - black_box(roundtrip(&h, CapnpSer)); - }); + let mut group = c.benchmark_group("circuit_roundtrip/capnp"); + group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + for size in [0, 1, 10, 100, 1000].iter() { + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let h = circuit(size).0; + b.iter(|| { + black_box(roundtrip(&h, CapnpSer)); }); - } - group.finish(); + }); } + group.finish(); } criterion_group! { diff --git a/release-plz.toml b/release-plz.toml index 091ca3795..4bc9f7104 100644 --- a/release-plz.toml +++ b/release-plz.toml @@ -63,9 +63,7 @@ version_group = "hugr" [[package]] name = "hugr-model" release = true -# Use a separate version group while the dependency is `-unstable`, -# to avoid breaking releases of the main package. -version_group = "hugr-model" +version_group = "hugr" [[package]] name = "hugr-llvm" From fd522c782ee59797f49e436dc874ba3082ea1a40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Mon, 28 Apr 2025 14:06:39 +0100 Subject: [PATCH 11/21] feat!: Remove `RootTagged` from the hugr view trait hierarchy (#2122) Closes #2077 BREAKING CHANGE: Removed `RootTagged` trait. Now `RootChecked` is a non-hugrview wrapper only used to verify inputs where appropriate. 
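A minimal migration sketch for downstream code (illustrative only: import paths are assumed, and the generic bounds shown in the diff below are elided):

```rust
// Sketch of the new API introduced by this commit; see root_checked.rs in the
// diff below for the real definitions. Paths and bounds are assumed here.
use hugr_core::hugr::views::{RootCheckable, RootChecked};
use hugr_core::hugr::HugrError;
use hugr_core::ops::handle::DfgID;
use hugr_core::ops::OpTag;
use hugr_core::{Hugr, HugrView};

// `RootChecked` is now just a checked wrapper: build it up front, then take the
// plain Hugr back out instead of threading a `RootTagged` bound around.
fn with_checked_dfg(hugr: &mut Hugr) -> Result<(), HugrError> {
    let checked = RootChecked::<&mut Hugr, DfgID>::try_new(hugr)?;
    let hugr: &mut Hugr = checked.into_hugr();
    let _root = hugr.root();
    Ok(())
}

// Functions that used to bound their input by `RootTagged` can instead accept
// anything implementing `RootCheckable` and verify the root tag on entry.
fn root_tag(hugr: impl RootCheckable<Hugr, DfgID>) -> Result<OpTag, HugrError> {
    Ok(hugr.try_into_checked()?.tag())
}
```

The second function plays the role previously filled by a `RootTagged` bound: callers can pass either a plain `Hugr` or an already-checked `RootChecked` wrapper, and the tag is verified once at the boundary.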
--- hugr-core/src/builder/dataflow.rs | 24 +-- hugr-core/src/builder/module.rs | 3 +- hugr-core/src/hugr.rs | 9 +- hugr-core/src/hugr/hugrmut.rs | 6 +- hugr-core/src/hugr/internal.rs | 26 +-- hugr-core/src/hugr/rewrite.rs | 3 +- hugr-core/src/hugr/rewrite/inline_call.rs | 9 +- hugr-core/src/hugr/rewrite/replace.rs | 3 +- hugr-core/src/hugr/validate/test.rs | 26 +-- hugr-core/src/hugr/views.rs | 27 +-- hugr-core/src/hugr/views/descendants.rs | 7 +- hugr-core/src/hugr/views/impls.rs | 28 +-- hugr-core/src/hugr/views/root_checked.rs | 194 +++++++----------- hugr-core/src/hugr/views/sibling.rs | 63 +----- hugr-core/src/hugr/views/sibling_subgraph.rs | 34 ++- hugr-core/src/package.rs | 3 +- .../src/utils/inline_constant_functions.rs | 2 +- hugr-passes/src/half_node.rs | 18 +- hugr-passes/src/lower.rs | 16 +- hugr-passes/src/merge_bbs.rs | 18 +- hugr-passes/src/monomorphize.rs | 4 +- hugr-passes/src/nest_cfgs.rs | 15 +- hugr/src/hugr.rs | 2 +- uv.lock | 2 +- 24 files changed, 191 insertions(+), 351 deletions(-) diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index ebad52085..64c5f5c84 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -174,9 +174,7 @@ impl FunctionBuilder { // Update the inner input node let types = new_optype.signature.body().input.clone(); - self.hugr_mut() - .replace_op(inp_node, Input { types }) - .unwrap(); + self.hugr_mut().replace_op(inp_node, Input { types }); let mut new_port = self.hugr_mut().add_ports(inp_node, Direction::Outgoing, 1); let new_port = new_port.next().unwrap(); @@ -211,9 +209,7 @@ impl FunctionBuilder { // Update the inner input node let types = new_optype.signature.body().output.clone(); - self.hugr_mut() - .replace_op(out_node, Output { types }) - .unwrap(); + self.hugr_mut().replace_op(out_node, Output { types }); let mut new_port = self.hugr_mut().add_ports(out_node, Direction::Incoming, 1); let new_port = new_port.next().unwrap(); @@ -250,15 +246,13 @@ impl FunctionBuilder { .expect("FunctionBuilder node must be a FuncDefn"); let signature = old_optype.inner_signature().into_owned(); let name = old_optype.name.clone(); - self.hugr_mut() - .replace_op( - parent, - ops::FuncDefn { - signature: f(signature).into(), - name, - }, - ) - .expect("Could not replace FunctionBuilder operation"); + self.hugr_mut().replace_op( + parent, + ops::FuncDefn { + signature: f(signature).into(), + name, + }, + ); self.hugr().get_optype(parent).as_func_defn().unwrap() } diff --git a/hugr-core/src/builder/module.rs b/hugr-core/src/builder/module.rs index 18390926e..1387c1ec5 100644 --- a/hugr-core/src/builder/module.rs +++ b/hugr-core/src/builder/module.rs @@ -83,8 +83,7 @@ impl + AsRef> ModuleBuilder { .clone(); let body = signature.body().clone(); self.hugr_mut() - .replace_op(f_node, ops::FuncDefn { name, signature }) - .expect("Replacing a FuncDecl node with a FuncDefn should always be valid"); + .replace_op(f_node, ops::FuncDefn { name, signature }); let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; Ok(FunctionBuilder::from_dfg_builder(db)) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 2789ae056..d708f15cd 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -23,7 +23,7 @@ use portgraph::multiportgraph::MultiPortGraph; use portgraph::{Hierarchy, PortMut, PortView, UnmanagedDenseMap}; use thiserror::Error; -pub use self::views::{HugrView, RootTagged}; +pub use self::views::HugrView; use crate::core::NodeIndex; use 
crate::extension::resolution::{ resolve_op_extensions, resolve_op_types_extensions, ExtensionResolutionError, @@ -367,13 +367,10 @@ pub struct ExtensionError { } /// Errors that can occur while manipulating a Hugr. -/// -/// TODO: Better descriptions, not just re-exporting portgraph errors. #[derive(Debug, Clone, PartialEq, Eq, Error)] #[non_exhaustive] pub enum HugrError { /// The node was not of the required [OpTag] - /// (e.g. to conform to the [RootTagged::RootHandle] of a [HugrView]) #[error("Invalid tag: required a tag in {required} but found {actual}")] #[allow(missing_docs)] InvalidTag { required: OpTag, actual: OpTag }, @@ -671,12 +668,12 @@ mod test { signature: Signature::new_endo(ty).with_extension_delta(result.clone()), }; let mut expected = backup; - expected.replace_op(p, expected_p).unwrap(); + expected.replace_op(p, expected_p); let expected_gp = ops::Conditional { extension_delta: result, ..root_ty }; - expected.replace_op(h.root(), expected_gp).unwrap(); + expected.replace_op(h.root(), expected_gp); assert_eq!(h, expected); } else { diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index bf9a4cad0..51e92f342 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -156,14 +156,14 @@ pub trait HugrMut: HugrMutInternals { /// If the node is not in the graph, or if the port is invalid. fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (OutgoingPort, IncomingPort); - /// Insert another hugr into this one, under a given root node. + /// Insert another hugr into this one, under a given parent node. /// /// # Panics /// /// If the root node is not in the graph. fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult; - /// Copy another hugr into this one, under a given root node. + /// Copy another hugr into this one, under a given parent node. /// /// # Panics /// @@ -174,7 +174,7 @@ pub trait HugrMut: HugrMutInternals { other: &H, ) -> InsertionResult; - /// Copy a subgraph from another hugr into this one, under a given root node. + /// Copy a subgraph from another hugr into this one, under a given parent node. /// /// Sibling order is not preserved. /// diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 8892c3b11..58ce066c0 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -12,7 +12,7 @@ use crate::ops::handle::NodeHandle; use crate::{Direction, Hugr, Node}; use super::hugrmut::{panic_invalid_node, panic_invalid_non_root}; -use super::{HugrError, NodeMetadataMap, OpType, RootTagged}; +use super::{HugrView, NodeMetadataMap, OpType}; /// Trait for accessing the internals of a Hugr(View). /// @@ -107,7 +107,7 @@ impl HugrInternals for Hugr { /// /// Specifically, this trait lets you apply arbitrary modifications that may /// invalidate the HUGR. -pub trait HugrMutInternals: RootTagged { +pub trait HugrMutInternals: HugrView { /// Set root node of the HUGR. /// /// This should be an existing node in the HUGR. Most operations use the @@ -189,18 +189,10 @@ pub trait HugrMutInternals: RootTagged { /// /// Returns the old OpType. /// - /// If the module root is set to a non-module operation the hugr will - /// become invalid. - /// - /// # Errors - /// - /// Returns a [`HugrError::InvalidTag`] if this would break the bound - /// (`Self::RootHandle`) on the root node's OpTag. - /// /// # Panics /// /// If the node is not in the graph. 
- fn replace_op(&mut self, node: Self::Node, op: impl Into) -> Result; + fn replace_op(&mut self, node: Self::Node, op: impl Into) -> OpType; /// Gets a mutable reference to the optype. /// @@ -223,9 +215,10 @@ pub trait HugrMutInternals: RootTagged { /// If the node is not in the graph. fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut NodeMetadataMap; - /// Returns a mutable reference to the extension registry for this hugr, - /// containing all extensions required to define the operations and types in - /// the hugr. + /// Returns a mutable reference to the extension registry for this HUGR. + /// + /// This set contains all extensions required to define the operations and + /// types in the HUGR. fn extensions_mut(&mut self) -> &mut ExtensionRegistry; } @@ -326,10 +319,9 @@ impl HugrMutInternals for Hugr { .expect("Inserting a newly-created node into the hierarchy should never fail."); } - fn replace_op(&mut self, node: Node, op: impl Into) -> Result { + fn replace_op(&mut self, node: Node, op: impl Into) -> OpType { panic_invalid_node(self, node); - // We know RootHandle=Node here so no need to check - Ok(std::mem::replace(self.optype_mut(node), op.into())) + std::mem::replace(self.optype_mut(node), op.into()) } fn optype_mut(&mut self, node: Self::Node) -> &mut OpType { diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs index d2b0fe14d..e220864a7 100644 --- a/hugr-core/src/hugr/rewrite.rs +++ b/hugr-core/src/hugr/rewrite.rs @@ -82,8 +82,7 @@ impl Rewrite for Transactional { let r = self.underlying.apply(h); if r.is_err() { // Try to restore backup. - h.replace_op(h.root(), backup.root_type().clone()) - .expect("The root replacement should always match the old root type"); + h.replace_op(h.root(), backup.root_type().clone()); while let Some(child) = h.first_child(h.root()) { h.remove_node(child); } diff --git a/hugr-core/src/hugr/rewrite/inline_call.rs b/hugr-core/src/hugr/rewrite/inline_call.rs index 6b1e7a958..e32373507 100644 --- a/hugr-core/src/hugr/rewrite/inline_call.rs +++ b/hugr-core/src/hugr/rewrite/inline_call.rs @@ -76,7 +76,6 @@ impl Rewrite for InlineCall { let ty_args = h .replace_op(self.0, new_op) - .unwrap() .as_call() .unwrap() .type_args @@ -117,8 +116,7 @@ mod test { ModuleBuilder, }; use crate::extension::prelude::usize_t; - use crate::hugr::views::RootChecked; - use crate::ops::handle::{FuncID, ModuleRootID, NodeHandle}; + use crate::ops::handle::{FuncID, NodeHandle}; use crate::ops::{Input, OpType, Value}; use crate::std_extensions::arithmetic::{ int_ops::{self, IntOpDef}, @@ -179,10 +177,7 @@ mod test { .count(), 1 ); - RootChecked::<_, ModuleRootID>::try_new(&mut hugr) - .unwrap() - .apply_rewrite(InlineCall(call1.node())) - .unwrap(); + hugr.apply_rewrite(InlineCall(call1.node())).unwrap(); hugr.validate().unwrap(); assert_eq!(hugr.output_neighbours(func.node()).collect_vec(), [call2]); assert_eq!(calls(&hugr), [call2]); diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index c2659cc5a..0316f9d5b 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -732,8 +732,7 @@ mod test { // Root node type needs to be that of common parent of the removed nodes: let mut rep2 = rep.clone(); rep2.replacement - .replace_op(rep2.replacement.root(), h.root_type().clone()) - .unwrap(); + .replace_op(rep2.replacement.root(), h.root_type().clone()); assert_eq!( check_same_errors(rep2), ReplaceError::WrongRootNodeTag { diff --git 
a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 37157020d..a66296c35 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -191,26 +191,23 @@ fn df_children_restrictions() { .unwrap(); // Replace the output operation of the df subgraph with a copy - b.replace_op(output, Noop(usize_t())).unwrap(); + b.replace_op(output, Noop(usize_t())); assert_matches!( b.validate(), Err(ValidationError::InvalidInitialChild { parent, .. }) => assert_eq!(parent, def) ); // Revert it back to an output, but with the wrong number of ports - b.replace_op(output, ops::Output::new(vec![bool_t()])) - .unwrap(); + b.replace_op(output, ops::Output::new(vec![bool_t()])); assert_matches!( b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. }) => {assert_eq!(parent, def); assert_eq!(child, output.pg_index())} ); - b.replace_op(output, ops::Output::new(vec![bool_t(), bool_t()])) - .unwrap(); + b.replace_op(output, ops::Output::new(vec![bool_t(), bool_t()])); // After fixing the output back, replace the copy with an output op - b.replace_op(copy, ops::Output::new(vec![bool_t(), bool_t()])) - .unwrap(); + b.replace_op(copy, ops::Output::new(vec![bool_t(), bool_t()])); assert_matches!( b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalIOChildren { child, .. }, .. }) @@ -806,8 +803,7 @@ fn cfg_children_restrictions() { ops::CFG { signature: Signature::new(vec![bool_t()], vec![bool_t()]), }, - ) - .unwrap(); + ); assert_matches!( b.validate(), Err(ValidationError::ContainerWithoutChildren { .. }) @@ -869,8 +865,7 @@ fn cfg_children_restrictions() { ops::CFG { signature: Signature::new(vec![qb_t()], vec![bool_t()]), }, - ) - .unwrap(); + ); b.replace_op( block, ops::DataflowBlock { @@ -879,18 +874,15 @@ fn cfg_children_restrictions() { other_outputs: vec![qb_t()].into(), extension_delta: ExtensionSet::new(), }, - ) - .unwrap(); + ); let mut block_children = b.hierarchy.children(block.pg_index()); let block_input = block_children.next().unwrap().into(); let block_output = block_children.next_back().unwrap().into(); - b.replace_op(block_input, ops::Input::new(vec![qb_t()])) - .unwrap(); + b.replace_op(block_input, ops::Input::new(vec![qb_t()])); b.replace_op( block_output, ops::Output::new(vec![Type::new_unit_sum(1), qb_t()]), - ) - .unwrap(); + ); assert_matches!( b.validate(), Err(ValidationError::InvalidEdges { parent, source: EdgeValidationError::CFGEdgeSignatureMismatch { .. }, .. 
}) diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index eb8059577..a154a956f 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -16,7 +16,7 @@ use std::borrow::Cow; pub use self::petgraph::PetgraphWrapper; use self::render::RenderConfig; pub use descendants::DescendantsGraph; -pub use root_checked::RootChecked; +pub use root_checked::{check_tag, RootCheckable, RootChecked}; pub use sibling::SiblingGraph; pub use sibling_subgraph::SiblingSubgraph; @@ -29,7 +29,6 @@ use super::{ Hugr, HugrError, HugrMut, Node, NodeMetadata, NodeMetadataMap, ValidationError, DEFAULT_OPTYPE, }; use crate::extension::ExtensionRegistry; -use crate::ops::handle::NodeHandle; use crate::ops::{OpParent, OpTag, OpTrait, OpType}; use crate::types::{EdgeKind, PolyFuncType, Signature, Type}; @@ -479,17 +478,8 @@ pub trait HugrView: HugrInternals { } } -/// Trait for views that provides a guaranteed bound on the type of the root node. -pub trait RootTagged: HugrView { - /// The kind of handle that can be used to refer to the root node. - /// - /// The handle is guaranteed to be able to contain the operation returned by - /// [`HugrView::root_type`]. - type RootHandle: NodeHandle; -} - /// A common trait for views of a HUGR hierarchical subgraph. -pub trait HierarchyView<'a>: RootTagged + Sized { +pub trait HierarchyView<'a>: HugrView + Sized { /// Create a hierarchical view of a HUGR given a root node. /// /// # Errors @@ -515,19 +505,6 @@ pub trait ExtractHugr: HugrView + Sized { } } -/// Check that the node in a HUGR can be represented by the required tag. -fn check_tag, N>( - hugr: &impl HugrView, - node: N, -) -> Result<(), HugrError> { - let actual = hugr.get_optype(node).tag(); - let required = Required::TAG; - if !required.is_superset(actual) { - return Err(HugrError::InvalidTag { required, actual }); - } - Ok(()) -} - // Explicit implementation to avoid cloning the Hugr. impl ExtractHugr for Hugr { fn extract_hugr(self) -> Hugr { diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index 28a7d9f2d..906dea3e4 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -8,7 +8,7 @@ use crate::hugr::HugrError; use crate::ops::handle::NodeHandle; use crate::{Direction, Hugr, Node, Port}; -use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged}; +use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView}; type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>; @@ -131,16 +131,13 @@ impl HugrView for DescendantsGraph<'_, Root> { .map(|index| self.get_node(index)) } } -impl RootTagged for DescendantsGraph<'_, Root> { - type RootHandle = Root; -} impl<'a, Root> HierarchyView<'a> for DescendantsGraph<'a, Root> where Root: NodeHandle, { fn try_new(hugr: &'a impl HugrView, root: Node) -> Result { - check_tag::(hugr, root)?; + check_tag::(hugr, root)?; let hugr = hugr.base_hugr(); Ok(Self { root, diff --git a/hugr-core/src/hugr/views/impls.rs b/hugr-core/src/hugr/views/impls.rs index 928acba20..440df9480 100644 --- a/hugr-core/src/hugr/views/impls.rs +++ b/hugr-core/src/hugr/views/impls.rs @@ -3,11 +3,8 @@ use std::{borrow::Cow, rc::Rc, sync::Arc}; use super::HugrView; -use super::RootTagged; use crate::hugr::internal::{HugrInternals, HugrMutInternals}; use crate::hugr::HugrMut; -use crate::Hugr; -use crate::Node; macro_rules! 
hugr_internal_methods { // The extra ident here is because invocations of the macro cannot pass `self` as argument @@ -116,7 +113,7 @@ macro_rules! hugr_mut_internal_methods { fn set_parent(&mut self, node: Self::Node, parent: Self::Node); fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); - fn replace_op(&mut self, node: Self::Node, op: impl Into) -> Result; + fn replace_op(&mut self, node: Self::Node, op: impl Into) -> crate::ops::OpType; fn optype_mut(&mut self, node: Self::Node) -> &mut crate::ops::OpType; fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut crate::hugr::NodeMetadataMap; fn extensions_mut(&mut self) -> &mut crate::extension::ExtensionRegistry; @@ -149,11 +146,6 @@ macro_rules! hugr_mut_methods { } pub(crate) use hugr_mut_methods; -// -------- Base Hugr implementation -impl RootTagged for Hugr { - type RootHandle = Node; -} - // -------- Immutable borrow impl HugrInternals for &T { type Portgraph<'p> @@ -167,9 +159,6 @@ impl HugrInternals for &T { impl HugrView for &T { hugr_view_methods! {this, *this} } -impl RootTagged for &T { - type RootHandle = T::RootHandle; -} // -------- Mutable borrow impl HugrInternals for &mut T { @@ -184,9 +173,6 @@ impl HugrInternals for &mut T { impl HugrView for &mut T { hugr_view_methods! {this, &**this} } -impl RootTagged for &mut T { - type RootHandle = T::RootHandle; -} impl HugrMutInternals for &mut T { hugr_mut_internal_methods! {this, &mut **this} } @@ -207,9 +193,6 @@ impl HugrInternals for Rc { impl HugrView for Rc { hugr_view_methods! {this, this.as_ref()} } -impl RootTagged for Rc { - type RootHandle = T::RootHandle; -} // -------- Arc impl HugrInternals for Arc { @@ -224,9 +207,6 @@ impl HugrInternals for Arc { impl HugrView for Arc { hugr_view_methods! {this, this.as_ref()} } -impl RootTagged for Arc { - type RootHandle = T::RootHandle; -} // -------- Box impl HugrInternals for Box { @@ -241,9 +221,6 @@ impl HugrInternals for Box { impl HugrView for Box { hugr_view_methods! {this, this.as_ref()} } -impl RootTagged for Box { - type RootHandle = T::RootHandle; -} impl HugrMutInternals for Box { hugr_mut_internal_methods! {this, this.as_mut()} } @@ -264,9 +241,6 @@ impl HugrInternals for Cow<'_, T> { impl HugrView for Cow<'_, T> { hugr_view_methods! {this, this.as_ref()} } -impl RootTagged for Cow<'_, T> { - type RootHandle = T::RootHandle; -} impl HugrMutInternals for Cow<'_, T> where T: HugrMutInternals + ToOwned, diff --git a/hugr-core/src/hugr/views/root_checked.rs b/hugr-core/src/hugr/views/root_checked.rs index e0dcf3eb7..50c9bcf44 100644 --- a/hugr-core/src/hugr/views/root_checked.rs +++ b/hugr-core/src/hugr/views/root_checked.rs @@ -1,20 +1,28 @@ -use std::borrow::Cow; use std::marker::PhantomData; -use crate::hugr::internal::{HugrInternals, HugrMutInternals}; -use crate::hugr::{HugrError, HugrMut}; +use crate::hugr::HugrError; use crate::ops::handle::NodeHandle; -use crate::ops::OpTrait; +use crate::ops::{OpTag, OpTrait}; use crate::{Hugr, Node}; -use super::{check_tag, HugrView, RootTagged}; +use super::HugrView; -/// A view of the whole Hugr. -/// (Just provides static checking of the type of the root node) +/// A wrapper over a Hugr that ensures the root node optype is of the required +/// [`OpTag`]. #[derive(Clone)] -pub struct RootChecked(H, PhantomData); +pub struct RootChecked(H, PhantomData); + +impl> RootChecked { + /// A tag that can contain the operation of the hugr root node. 
+ const TAG: OpTag = Handle::TAG; + + /// Returns the most specific tag that can be applied to the root node. + pub fn tag(&self) -> OpTag { + let tag = self.0.get_optype(self.0.root()).tag(); + debug_assert!(Self::TAG.is_superset(tag)); + tag + } -impl> RootChecked { /// Create a hierarchical view of a whole HUGR /// /// # Errors @@ -22,101 +30,80 @@ impl> RootChecked { /// /// [`OpTag`]: crate::ops::OpTag pub fn try_new(hugr: H) -> Result { - if !H::RootHandle::TAG.is_superset(Root::TAG) { - return Err(HugrError::InvalidTag { - required: H::RootHandle::TAG, - actual: Root::TAG, - }); - } - check_tag::(&hugr, hugr.root())?; + Self::check(&hugr)?; Ok(Self(hugr, PhantomData)) } -} -impl RootChecked { - /// Extracts the underlying (owned) Hugr - pub fn into_hugr(self) -> Hugr { - self.0 + /// Check if a Hugr is valid for the given [`OpTag`]. + /// + /// To check arbitrary nodes, use [`check_tag`]. + pub fn check(hugr: &H) -> Result<(), HugrError> { + check_tag::(hugr, hugr.root())?; + Ok(()) } -} -impl RootChecked<&mut Hugr, Root> { - /// Allows immutably borrowing the underlying mutable reference - pub fn borrow(&self) -> RootChecked<&Hugr, Root> { - RootChecked(&*self.0, PhantomData) + /// Returns a reference to the underlying Hugr. + pub fn hugr(&self) -> &H { + &self.0 } -} -impl HugrInternals for RootChecked { - type Portgraph<'p> - = H::Portgraph<'p> - where - Self: 'p; - type Node = H::Node; - - super::impls::hugr_internal_methods! {this, &this.0} -} - -impl HugrView for RootChecked { - super::impls::hugr_view_methods! {this, &this.0} -} + /// Extracts the underlying Hugr + pub fn into_hugr(self) -> H { + self.0 + } -impl> RootTagged for RootChecked { - type RootHandle = Root; + /// Returns a wrapper over a reference to the underlying Hugr. + pub fn as_ref(&self) -> RootChecked<&H, Handle> { + RootChecked(&self.0, PhantomData) + } } -impl, Root> AsRef for RootChecked { +impl, Handle> AsRef for RootChecked { fn as_ref(&self) -> &Hugr { self.0.as_ref() } } -impl> HugrMutInternals for RootChecked { - fn replace_op( - &mut self, - node: Self::Node, - op: impl Into, - ) -> Result { - let op = op.into(); - if node == self.root() && !Root::TAG.is_superset(op.tag()) { - return Err(HugrError::InvalidTag { - required: Root::TAG, - actual: op.tag(), - }); - } - self.0.replace_op(node, op) +/// A trait for types that can be checked for a specific [`OpTag`] at their root node. +/// +/// This is used mainly specifying function inputs that may either be a [`HugrView`] or an already checked [`RootChecked`]. +pub trait RootCheckable>: Sized { + /// Wrap the Hugr in a [`RootChecked`] if it is valid for the required [`OpTag`]. + /// + /// If `Self` is already a [`RootChecked`], it is a no-op. + fn try_into_checked(self) -> Result, HugrError>; +} +impl> RootCheckable for H { + fn try_into_checked(self) -> Result, HugrError> { + RootChecked::try_new(self) } - - delegate::delegate! 
{ - to (&mut self.0) { - fn set_root(&mut self, root: Self::Node); - fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize); - fn add_ports(&mut self, node: Self::Node, direction: crate::Direction, amount: isize) -> std::ops::Range; - fn insert_ports(&mut self, node: Self::Node, direction: crate::Direction, index: usize, amount: usize) -> std::ops::Range; - fn set_parent(&mut self, node: Self::Node, parent: Self::Node); - fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); - fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); - fn optype_mut(&mut self, node: Self::Node) -> &mut crate::ops::OpType; - fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut crate::hugr::NodeMetadataMap; - fn extensions_mut(&mut self) -> &mut crate::extension::ExtensionRegistry; - } +} +impl> RootCheckable for RootChecked { + fn try_into_checked(self) -> Result, HugrError> { + Ok(self) } } -impl> HugrMut for RootChecked { - super::impls::hugr_mut_methods! {this, &mut this.0} +/// Check that the node in a HUGR can be represented by the required tag. +pub fn check_tag, N>( + hugr: &impl HugrView, + node: N, +) -> Result<(), HugrError> { + let actual = hugr.get_optype(node).tag(); + let required = Required::TAG; + if !required.is_superset(actual) { + return Err(HugrError::InvalidTag { required, actual }); + } + Ok(()) } #[cfg(test)] mod test { use super::RootChecked; - use crate::extension::prelude::MakeTuple; - use crate::extension::ExtensionSet; - use crate::hugr::internal::HugrMutInternals; - use crate::hugr::{HugrError, HugrMut}; - use crate::ops::handle::{BasicBlockID, CfgID, DataflowParentID, DfgID}; - use crate::ops::{DataflowBlock, OpTag, OpType}; - use crate::{ops, type_row, types::Signature, Hugr, HugrView}; + use crate::hugr::HugrError; + use crate::ops::handle::{CfgID, DfgID}; + use crate::ops::{OpTag, OpType}; + use crate::{ops, types::Signature, Hugr}; #[test] fn root_checked() { @@ -125,7 +112,7 @@ mod test { } .into(); let mut h = Hugr::new(root_type.clone()); - let cfg_v = RootChecked::<&Hugr, CfgID>::try_new(&h); + let cfg_v = RootChecked::<_, CfgID>::check(&h); assert_eq!( cfg_v.err(), Some(HugrError::InvalidTag { @@ -133,46 +120,9 @@ mod test { actual: OpTag::Dfg }) ); - let mut dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap(); - // That is a HugrMutInternal, so we can try: - let root = dfg_v.root(); - let bb: OpType = DataflowBlock { - inputs: type_row![], - other_outputs: type_row![], - sum_rows: vec![type_row![]], - extension_delta: ExtensionSet::new(), - } - .into(); - let r = dfg_v.replace_op(root, bb.clone()); - assert_eq!( - r, - Err(HugrError::InvalidTag { - required: OpTag::Dfg, - actual: ops::OpTag::DataflowBlock - }) - ); - // That didn't do anything: - assert_eq!(dfg_v.get_optype(root), &root_type); - - // Make a RootChecked that allows any DataflowParent - // We won't be able to do this by widening the bound: - assert_eq!( - RootChecked::<_, DataflowParentID>::try_new(dfg_v).err(), - Some(HugrError::InvalidTag { - required: OpTag::Dfg, - actual: OpTag::DataflowParent - }) - ); - - let mut dfp_v = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut h).unwrap(); - let r = dfp_v.replace_op(root, bb.clone()); - assert_eq!(r, Ok(root_type)); - assert_eq!(dfp_v.get_optype(root), &bb); - // Just check we can create a nested instance (narrowing the bound) - let mut bb_v = RootChecked::<_, BasicBlockID>::try_new(dfp_v).unwrap(); - - // And it's a HugrMut: - let nodetype = MakeTuple(type_row![]); - 
bb_v.add_node_with_parent(bb_v.root(), nodetype); + // This should succeed + let dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap(); + assert!(OpTag::Dfg.is_superset(dfg_v.tag())); + assert_eq!(dfg_v.as_ref().tag(), dfg_v.tag()); } } diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index 4d15a9c48..ac31d2695 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -11,7 +11,7 @@ use crate::ops::handle::NodeHandle; use crate::ops::OpTrait; use crate::{Direction, Hugr, Node, Port}; -use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged}; +use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView}; type FlatRegionGraph<'g> = portgraph::view::FlatRegion<'g, &'g MultiPortGraph>; @@ -154,9 +154,6 @@ impl HugrView for SiblingGraph<'_, Root> { .map(|n| self.get_node(n)) } } -impl RootTagged for SiblingGraph<'_, Root> { - type RootHandle = Root; -} impl<'a, Root: NodeHandle> SiblingGraph<'a, Root> { fn new_unchecked(hugr: &'a impl HugrView, root: Node) -> Self { @@ -254,12 +251,6 @@ impl<'g, H: HugrMut, Root: NodeHandle> SiblingMut<'g, H, Root> { /// Create a new SiblingMut from a base. /// Equivalent to [HierarchyView::try_new] but takes a *mutable* reference. pub fn try_new(hugr: &'g mut H, root: H::Node) -> Result { - if root == hugr.root() && !H::RootHandle::TAG.is_superset(Root::TAG) { - return Err(HugrError::InvalidTag { - required: H::RootHandle::TAG, - actual: Root::TAG, - }); - } check_tag::(hugr, root)?; Ok(Self { hugr, @@ -375,22 +366,20 @@ impl> HugrView for SiblingMut<'_, H, Root> } } -impl> RootTagged for SiblingMut<'_, H, Root> { - type RootHandle = Root; -} - impl> HugrMutInternals for SiblingMut<'_, H, Root> { fn replace_op( &mut self, node: Self::Node, op: impl Into, - ) -> Result { + ) -> crate::ops::OpType { let op = op.into(); + // Note: `SiblingMut` will be removed in a subsequent PR, so we just panic here for now. 
if node == self.root() && !Root::TAG.is_superset(op.tag()) { - return Err(HugrError::InvalidTag { + let err = HugrError::InvalidTag { required: Root::TAG, actual: op.tag(), - }); + }; + panic!("{err}"); } self.hugr.replace_op(node, op) } @@ -424,9 +413,9 @@ mod test { use crate::builder::test::simple_dfg_hugr; use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; use crate::extension::prelude::{qb_t, usize_t}; - use crate::ops::handle::{CfgID, DataflowParentID, DfgID, FuncID}; + use crate::ops::handle::{CfgID, DfgID, FuncID}; + use crate::ops::OpType; use crate::ops::{dataflow::IOTrait, Input, OpTag, Output}; - use crate::ops::{OpTrait, OpType}; use crate::types::Signature; use crate::utils::test_quantum_extension::EXTENSION_ID; use crate::IncomingPort; @@ -585,45 +574,13 @@ mod test { }) ); - let mut sib_mut = SiblingMut::<_, DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap(); let bad_nodetype: OpType = crate::ops::CFG { signature }.into(); - assert_eq!( - sib_mut.replace_op(sib_mut.root(), bad_nodetype.clone()), - Err(HugrError::InvalidTag { - required: OpTag::Dfg, - actual: OpTag::Cfg - }) - ); - // In contrast, performing this on the Hugr (where the allowed root type is 'Any') is only detected by validation - simple_dfg_hugr.replace_op(root, bad_nodetype).unwrap(); + // Performing this on the Hugr (where the allowed root type is 'Any') is only detected by validation + simple_dfg_hugr.replace_op(root, bad_nodetype); assert!(simple_dfg_hugr.validate().is_err()); } - #[rstest] - fn sibling_mut_covariance(mut simple_dfg_hugr: Hugr) { - let root = simple_dfg_hugr.root(); - let case_nodetype = crate::ops::Case { - signature: simple_dfg_hugr - .root_type() - .dataflow_signature() - .unwrap() - .into_owned(), - }; - let mut sib_mut = SiblingMut::<_, DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap(); - // As expected, we cannot replace the root with a Case - assert_eq!( - sib_mut.replace_op(root, case_nodetype), - Err(HugrError::InvalidTag { - required: OpTag::Dfg, - actual: OpTag::Case - }) - ); - - let nested_sib_mut = SiblingMut::<_, DataflowParentID>::try_new(&mut sib_mut, root); - assert!(nested_sib_mut.is_err()); - } - #[rstest] fn extract_hugr() -> Result<(), Box> { let (hugr, _def, inner) = make_module_hgr()?; diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index c681fafc9..9502d9f6b 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -22,13 +22,15 @@ use thiserror::Error; use crate::builder::{Container, FunctionBuilder}; use crate::core::HugrNode; use crate::extension::ExtensionSet; -use crate::hugr::{HugrMut, HugrView, RootTagged}; +use crate::hugr::{HugrMut, HugrView}; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{ContainerHandle, DataflowOpID}; use crate::ops::{NamedOp, OpTag, OpTrait, OpType}; use crate::types::{Signature, Type}; use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement}; +use super::root_checked::RootCheckable; + /// A non-empty convex subgraph of a HUGR sibling graph. /// /// A HUGR region in which all nodes share the same parent. Unlike @@ -95,11 +97,18 @@ impl SiblingSubgraph { /// /// This will return an [`InvalidSubgraph::EmptySubgraph`] error if the /// subgraph is empty. 
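For orientation before the next hunk: `try_new_dataflow_subgraph` now accepts anything implementing `RootCheckable`, so a plain `&H` still works, but the root-handle type generally has to be spelled out at the call site (the updated tests further down use the same turbofish form). A rough sketch, assuming a `hugr: Hugr` whose root is a DFG; the `expect` message is illustrative only:

```rust
use hugr_core::hugr::views::SiblingSubgraph;
use hugr_core::ops::handle::DfgID;

// The argument is checked against `DfgID` internally; a non-dataflow root now
// yields `InvalidSubgraph::NonDataflowRegion` rather than requiring a
// pre-checked view type.
let subgraph = SiblingSubgraph::try_new_dataflow_subgraph::<_, DfgID>(&hugr)
    .expect("root must be a dataflow region");
```
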
- pub fn try_new_dataflow_subgraph(dfg_graph: &H) -> Result> + pub fn try_new_dataflow_subgraph<'h, H, Root>( + dfg_graph: impl RootCheckable<&'h H, Root>, + ) -> Result> where - H: Clone + RootTagged, - Root: ContainerHandle, + H: 'h + Clone + HugrView, + Root: ContainerHandle, { + let Ok(dfg_graph) = dfg_graph.try_into_checked() else { + return Err(InvalidSubgraph::NonDataflowRegion); + }; + let dfg_graph = dfg_graph.into_hugr(); + let parent = dfg_graph.root(); let nodes = dfg_graph.children(parent).skip(2).collect_vec(); let (inputs, outputs) = get_input_output_ports(dfg_graph); @@ -798,6 +807,9 @@ pub enum InvalidSubgraph { /// An invalid boundary port was found. #[error("Invalid boundary port.")] InvalidBoundary(#[from] InvalidSubgraphBoundary), + /// The hugr region is not a dataflow graph. + #[error("SiblingSubgraphs may only be defined on dataflow regions.")] + NonDataflowRegion, } /// Errors that can occur while constructing a [`SiblingSubgraph`]. @@ -985,7 +997,7 @@ mod tests { fn construct_simple_replacement() -> Result<(), InvalidSubgraph> { let (mut hugr, func_root) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); - let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; + let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func)?; let empty_dfg = { let builder = @@ -1009,7 +1021,7 @@ mod tests { fn test_signature() -> Result<(), InvalidSubgraph> { let (hugr, dfg) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, dfg).unwrap(); - let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; + let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func)?; assert_eq!( sub.signature(&func), Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).with_extension_delta( @@ -1046,7 +1058,7 @@ mod tests { let (hugr, func_root) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); assert_eq!( - SiblingSubgraph::try_new_dataflow_subgraph(&func) + SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func) .unwrap() .nodes() .len(), @@ -1162,7 +1174,8 @@ mod tests { let (hugr, func_root) = build_hugr_classical().unwrap(); let func_graph: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); - let func = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); + let func = + SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func_graph).unwrap(); let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap(); assert_eq!(func_defn.signature, func.signature(&func_graph).into()); } @@ -1172,7 +1185,8 @@ mod tests { let (hugr, func_root) = build_hugr().unwrap(); let func_graph: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); - let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); + let subgraph = + SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func_graph).unwrap(); let extracted = subgraph.extract_subgraph(&hugr, "region"); extracted.validate().unwrap(); @@ -1197,7 +1211,7 @@ mod tests { let outw = [outw1].into_iter().chain(outw2); let h = builder.finish_hugr_with_outputs(outw).unwrap(); let view = SiblingGraph::::try_new(&h, h.root()).unwrap(); - let subg = SiblingSubgraph::try_new_dataflow_subgraph(&view).unwrap(); + let subg = SiblingSubgraph::try_new_dataflow_subgraph::<_, DfgID>(&view).unwrap(); assert_eq!(subg.nodes().len(), 2); } diff --git a/hugr-core/src/package.rs 
b/hugr-core/src/package.rs index 1b96f1ebd..5e1fecdb6 100644 --- a/hugr-core/src/package.rs +++ b/hugr-core/src/package.rs @@ -353,8 +353,7 @@ fn to_module_hugr(mut hugr: Hugr) -> Result { name: "main".to_string(), signature: signature.into_owned().into(), }, - ) - .expect("Hugr accepts any root node"); + ); // Wrap it in a module. let new_root = hugr.add_node(Module::new().into()); diff --git a/hugr-llvm/src/utils/inline_constant_functions.rs b/hugr-llvm/src/utils/inline_constant_functions.rs index a55072a99..1b0931bd2 100644 --- a/hugr-llvm/src/utils/inline_constant_functions.rs +++ b/hugr-llvm/src/utils/inline_constant_functions.rs @@ -69,7 +69,7 @@ fn inline_constant_functions_impl(hugr: &mut impl HugrMut) -> Resul hugr.insert_hugr(func_node, func_hugr); for lcn in load_constant_ns { - hugr.replace_op(lcn, LoadFunction::try_new(polysignature.clone(), [])?)?; + hugr.replace_op(lcn, LoadFunction::try_new(polysignature.clone(), [])?); } any_changes = true; } diff --git a/hugr-passes/src/half_node.rs b/hugr-passes/src/half_node.rs index ca0d9880e..7f332209f 100644 --- a/hugr-passes/src/half_node.rs +++ b/hugr-passes/src/half_node.rs @@ -3,12 +3,10 @@ use std::hash::Hash; use super::nest_cfgs::CfgNodeMap; use hugr_core::hugr::internal::HugrInternals; -use hugr_core::hugr::RootTagged; - +use hugr_core::hugr::views::RootCheckable; use hugr_core::ops::handle::CfgID; use hugr_core::ops::{OpTag, OpTrait}; - -use hugr_core::{Direction, Node}; +use hugr_core::{Direction, HugrView, Node}; /// We provide a view of a cfg where every node has at most one of /// (multiple predecessors, multiple successors). @@ -32,9 +30,12 @@ struct HalfNodeView { exit: H::Node, } -impl> HalfNodeView { +impl HalfNodeView { #[allow(unused)] - pub(crate) fn new(h: H) -> Self { + pub(crate) fn new(h: impl RootCheckable>) -> Self { + let checked = h.try_into_checked().expect("Hugr must be a CFG region"); + let h = checked.into_hugr(); + let (entry, exit) = { let mut children = h.children(h.root()); (children.next().unwrap(), children.next().unwrap()) @@ -64,7 +65,7 @@ impl> HalfNodeView { } } -impl> CfgNodeMap> for HalfNodeView { +impl CfgNodeMap> for HalfNodeView { fn entry_node(&self) -> HalfNode { HalfNode::N(self.entry) } @@ -98,7 +99,6 @@ mod test { use super::super::nest_cfgs::{test::*, EdgeClassifier}; use super::{HalfNode, HalfNodeView}; use hugr_core::builder::BuildError; - use hugr_core::hugr::views::RootChecked; use hugr_core::ops::handle::NodeHandle; use itertools::Itertools; @@ -118,7 +118,7 @@ mod test { // \---<---<---<---<---<---<---<---<---<---/ // Allowing to identify two nested regions (and fixing the problem with an IdentityCfgMap on the same example) - let v = HalfNodeView::new(RootChecked::try_new(&h).unwrap()); + let v = HalfNodeView::new(&h); let edge_classes = EdgeClassifier::get_edge_classes(&v); let HalfNodeView { h: _, entry, exit } = v; diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 8de6c00a2..3a3bd5e91 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -1,5 +1,5 @@ use hugr_core::{ - hugr::{hugrmut::HugrMut, views::SiblingSubgraph, HugrError}, + hugr::{hugrmut::HugrMut, views::SiblingSubgraph}, ops::OpType, Hugr, Node, }; @@ -10,14 +10,10 @@ use thiserror::Error; /// New operations must match the signature of the old operations. /// /// Returns a list of the replaced nodes and their old operations. -/// -/// # Errors -/// -/// Returns a [`HugrError`] if any replacement fails. 
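A quick sketch of the now-infallible `replace_many_ops` call shape (the `hugr_passes::lower` path and the `lower_op` closure are assumptions for illustration, not code from this PR): the pass applies the replacements directly and hands back the replaced nodes with their previous op types.

```rust
use hugr_core::ops::OpType;
use hugr_core::{Hugr, Node};
use hugr_passes::lower::replace_many_ops;

/// Hypothetical driver; `lower_op` stands in for whatever same-signature
/// rewrite the caller wants (returning `None` leaves a node untouched).
fn lower_all(hugr: &mut Hugr, lower_op: impl Fn(&OpType) -> Option<OpType>) {
    // No `Result` to unwrap any more: the replaced nodes come back directly,
    // paired with the op types they previously held.
    let replaced: Vec<(Node, OpType)> = replace_many_ops(hugr, lower_op);
    for (node, old_op) in replaced {
        println!("replaced {node:?}: was {old_op:?}");
    }
}
```
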
pub fn replace_many_ops>( hugr: &mut impl HugrMut, mapping: impl Fn(&OpType) -> Option, -) -> Result, HugrError> { +) -> Vec<(Node, OpType)> { let replacements = hugr .nodes() .filter_map(|node| { @@ -28,7 +24,10 @@ pub fn replace_many_ops>( replacements .into_iter() - .map(|(node, new_op)| hugr.replace_op(node, new_op).map(|old_op| (node, old_op))) + .map(|(node, new_op)| { + let old_op = hugr.replace_op(node, new_op); + (node, old_op) + }) .collect() } @@ -117,8 +116,7 @@ mod test { } else { None } - }) - .unwrap(); + }); assert_eq!(replaced.len(), 1); let (n, op) = replaced.remove(0); diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index d1731107d..a5de5eb57 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -4,11 +4,11 @@ use std::collections::HashMap; use hugr_core::extension::prelude::UnpackTuple; use hugr_core::hugr::hugrmut::HugrMut; +use hugr_core::hugr::views::RootCheckable; use itertools::Itertools; use hugr_core::hugr::rewrite::inline_dfg::InlineDFG; use hugr_core::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; -use hugr_core::hugr::RootTagged; use hugr_core::ops::handle::CfgID; use hugr_core::ops::{DataflowBlock, DataflowParent, Input, Output, DFG}; use hugr_core::{Hugr, HugrView, Node}; @@ -16,7 +16,13 @@ use hugr_core::{Hugr, HugrView, Node}; /// Merge any basic blocks that are direct children of the specified CFG /// i.e. where a basic block B has a single successor B' whose only predecessor /// is B, B and B' can be combined. -pub fn merge_basic_blocks(cfg: &mut impl HugrMut) { +pub fn merge_basic_blocks<'h, H>(cfg: impl RootCheckable<&'h mut H, CfgID>) +where + H: 'h + HugrMut, +{ + let checked = cfg.try_into_checked().expect("Hugr must be a CFG region"); + let cfg = checked.into_hugr(); + let mut worklist = cfg.children(cfg.root()).collect::>(); while let Some(n) = worklist.pop() { // Consider merging n with its successor @@ -44,7 +50,7 @@ pub fn merge_basic_blocks(cfg: &mut impl HugrMut, + cfg: &impl HugrView, pred: Node, succ: Node, ) -> (Replacement, Node, [Node; 2]) { @@ -158,14 +164,12 @@ mod test { use std::sync::Arc; use hugr_core::extension::prelude::PRELUDE_ID; - use hugr_core::hugr::views::RootChecked; use itertools::Itertools; use rstest::rstest; use hugr_core::builder::{endo_sig, inout_sig, CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; use hugr_core::extension::prelude::{qb_t, usize_t, ConstUsize}; use hugr_core::ops::constant::Value; - use hugr_core::ops::handle::CfgID; use hugr_core::ops::{LoadConstant, OpTrait, OpType}; use hugr_core::types::{Signature, Type, TypeRow}; use hugr_core::{const_extension_ids, type_row, Extension, Hugr, HugrView, Wire}; @@ -252,7 +256,7 @@ mod test { let mut h = h.finish_hugr()?; let r = h.root(); - merge_basic_blocks(&mut RootChecked::<_, CfgID>::try_new(&mut h).unwrap()); + merge_basic_blocks(&mut h); h.validate().unwrap(); assert_eq!(r, h.root()); assert!(matches!(h.get_optype(r), OpType::CFG(_))); @@ -346,7 +350,7 @@ mod test { h.branch(&bb3, 0, &h.exit_block())?; let mut h = h.finish_hugr()?; - merge_basic_blocks(&mut RootChecked::<_, CfgID>::try_new(&mut h).unwrap()); + merge_basic_blocks(&mut h); h.validate()?; // Should only be one BB left diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 3164702d8..3ac85a020 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -157,7 +157,7 @@ fn mono_scan( h.disconnect(ch, fn_inp); // No-op if copying+substituting h.connect(new_tgt, fn_out, ch, 
fn_inp); - h.replace_op(ch, new_op).unwrap(); + h.replace_op(ch, new_op); } } @@ -178,7 +178,7 @@ fn instantiate( name: mangle_inner_func(&outer_name, &fd.name), signature: fd.signature.clone(), }; - h.replace_op(n, fd).unwrap(); + h.replace_op(n, fd); h.move_after_sibling(n, poly_func); } else { to_scan.extend(h.children(n)) diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index 1c4928e12..b98d4fb23 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -46,8 +46,8 @@ use thiserror::Error; use hugr_core::hugr::rewrite::outline_cfg::OutlineCfg; use hugr_core::hugr::views::sibling::SiblingMut; -use hugr_core::hugr::views::{HierarchyView, HugrView, SiblingGraph}; -use hugr_core::hugr::{hugrmut::HugrMut, Rewrite, RootTagged}; +use hugr_core::hugr::views::{HierarchyView, HugrView, RootCheckable, SiblingGraph}; +use hugr_core::hugr::{hugrmut::HugrMut, Rewrite}; use hugr_core::ops::handle::{BasicBlockID, CfgID}; use hugr_core::ops::OpTag; use hugr_core::ops::OpTrait; @@ -219,9 +219,12 @@ pub struct IdentityCfgMap { entry: H::Node, exit: H::Node, } -impl> IdentityCfgMap { +impl IdentityCfgMap { /// Creates an [IdentityCfgMap] for the specified CFG - pub fn new(h: H) -> Self { + pub fn new(h: impl RootCheckable>) -> Self { + let h = h.try_into_checked().expect("Hugr must be a CFG region"); + let h = h.into_hugr(); + // Panic if malformed enough not to have two children let (entry, exit) = h.children(h.root()).take(2).collect_tuple().unwrap(); debug_assert_eq!(h.get_optype(exit).tag(), OpTag::BasicBlockExit); @@ -636,7 +639,7 @@ pub(crate) mod test { let rc = RootChecked::<_, CfgID>::try_new(&mut h).unwrap(); let (entry, exit) = (entry.node(), exit.node()); let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node()); - let edge_classes = EdgeClassifier::get_edge_classes(&IdentityCfgMap::new(rc.borrow())); + let edge_classes = EdgeClassifier::get_edge_classes(&IdentityCfgMap::new(rc.as_ref())); let [&left, &right] = edge_classes .keys() .filter(|(s, _)| *s == split) @@ -734,7 +737,7 @@ pub(crate) mod test { // There's no need to use a view of a region here but we do so just to check // that we *can* (as we'll need to for "real" module Hugr's) - let v = IdentityCfgMap::new(SiblingGraph::try_new(&h, h.root()).unwrap()); + let v = IdentityCfgMap::new(SiblingGraph::::try_new(&h, h.root()).unwrap()); let edge_classes = EdgeClassifier::get_edge_classes(&v); let IdentityCfgMap { h: _, entry, exit } = v; let [&left, &right] = edge_classes diff --git a/hugr/src/hugr.rs b/hugr/src/hugr.rs index 88c8c8df0..a66de8315 100644 --- a/hugr/src/hugr.rs +++ b/hugr/src/hugr.rs @@ -3,6 +3,6 @@ // Exports everything except the `internal` module. 
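The pass entry points updated above (`merge_basic_blocks`, `IdentityCfgMap::new`, `HalfNodeView::new`) now take `impl RootCheckable<_, CfgID>`, so callers can pass the HUGR directly and let the pass check the root tag itself, or pass a pre-checked `RootChecked` to keep the static guarantee. A minimal sketch, assuming the `hugr_passes::merge_bbs` path and a CFG-rooted HUGR built elsewhere:

```rust
use hugr_core::Hugr;
use hugr_passes::merge_bbs::merge_basic_blocks;

/// Hypothetical helper: `h` is assumed to be a CFG-rooted HUGR.
fn merge_and_check(h: &mut Hugr) {
    // The pass checks the root tag and panics if the root is not a CFG region;
    // a pre-checked `RootChecked` view would be accepted here as well.
    merge_basic_blocks(&mut *h);
    h.validate().expect("merging basic blocks keeps the HUGR valid");
}
```
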
pub use hugr_core::hugr::{ hugrmut, rewrite, serialize, validate, views, Hugr, HugrError, HugrView, IdentList, - InvalidIdentifier, LoadHugrError, NodeMetadata, NodeMetadataMap, OpType, Rewrite, RootTagged, + InvalidIdentifier, LoadHugrError, NodeMetadata, NodeMetadataMap, OpType, Rewrite, SimpleReplacement, SimpleReplacementError, ValidationError, DEFAULT_OPTYPE, }; diff --git a/uv.lock b/uv.lock index 130657231..4f7d6012a 100644 --- a/uv.lock +++ b/uv.lock @@ -277,7 +277,7 @@ wheels = [ [[package]] name = "hugr" -version = "0.11.4" +version = "0.11.5" source = { editable = "hugr-py" } dependencies = [ { name = "graphviz" }, From 4c9cce36d340287f2575b5b68eaa81407ed6a83a Mon Sep 17 00:00:00 2001 From: Luca Mondada <72734770+lmondada@users.noreply.github.com> Date: Tue, 29 Apr 2025 15:15:13 +0200 Subject: [PATCH 12/21] feat!: Split Rewrite trait into VerifyPatch and ApplyPatch (#2070) This PR splits the `Rewrite` trait into two (three) traits: - a `VerifyPatch` trait that has the `fn verify` and `fn invalidation_set` functions - a `ApplyPatch` trait that has the `fn apply` function. This inherits `VerifyPatch` and is the "rewriting" trait that should be used in most scenarios. In addition, there is a third trait `ApplyPatchHugrMut` that can be implemented by any patches that can be applied to _any_ `HugrMut` (as opposed to a specific type `H`). This is strictly stronger than `ApplyPatch` and should be implemented instead of `ApplyPatch` where possible (see the docs of the traits). closes #588 closes #2052 BREAKING CHANGE: Replaced the `Rewrite` trait with `Patch`. `Rewrite::ApplyResult` is now `Patch::Outcome`. `Rewrite::verify` was split into a separate trait, and is now `PatchVerification::verify`. BREAKING CHANGE: Renamed `hugr.rewrite` module to `hugr.patch`. BREAKING CHANGE: Changed the type `OutlineCfg::ApplyResult` (now `OutlineCfg::Outcome`) from `(Node, Node)` to `[Node; 2]`. 
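For orientation, this is roughly what a patch looks like against the new trait split. It is a toy sketch, not code from this PR: `DeleteNode` is hypothetical, and the generic bounds on `HugrView`/`HugrMut` are elided in the hunks below, so the ones written here are assumptions.

```rust
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::hugr::patch::{PatchHugrMut, PatchVerification};
use hugr_core::{HugrView, Node};

/// Toy patch that unconditionally deletes a node (no real verification).
struct DeleteNode(Node);

impl PatchVerification for DeleteNode {
    type Error = std::convert::Infallible;
    type Node = Node;

    fn verify(&self, _h: &impl HugrView<Node = Node>) -> Result<(), Self::Error> {
        // A real patch would check the node here, so that `apply` cannot fail.
        Ok(())
    }

    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
        // The only node this patch touches.
        std::iter::once(self.0)
    }
}

// Implementing `PatchHugrMut` (rather than `Patch` directly) makes the patch
// applicable to any `HugrMut`; the blanket impl then provides `Patch` too.
impl PatchHugrMut for DeleteNode {
    type Outcome = ();
    const UNCHANGED_ON_FAILURE: bool = true;

    fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = Node>) -> Result<(), Self::Error> {
        h.remove_node(self.0);
        Ok(())
    }
}
```

With the blanket impl added in the new `patch.rs` below, this single `PatchHugrMut` impl is enough for call sites to keep writing `h.apply_patch(DeleteNode(n))` on any concrete `HugrMut` type.
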
--------- Co-authored-by: Alan Lawrence Co-authored-by: Alan Lawrence --- hugr-core/src/hugr.rs | 4 +- hugr-core/src/hugr/hugrmut.rs | 9 +- hugr-core/src/hugr/patch.rs | 169 ++++++++ .../src/hugr/{rewrite => patch}/consts.rs | 86 ++-- .../hugr/{rewrite => patch}/inline_call.rs | 49 +-- .../src/hugr/{rewrite => patch}/inline_dfg.rs | 38 +- .../{rewrite => patch}/insert_identity.rs | 45 ++- .../hugr/{rewrite => patch}/outline_cfg.rs | 48 ++- .../src/hugr/{rewrite => patch}/port_types.rs | 0 .../src/hugr/{rewrite => patch}/replace.rs | 375 +++++++++++------- .../hugr/{rewrite => patch}/simple_replace.rs | 141 ++++--- hugr-core/src/hugr/views/sibling_subgraph.rs | 4 +- hugr-passes/src/lower.rs | 9 +- hugr-passes/src/merge_bbs.rs | 8 +- hugr-passes/src/nest_cfgs.rs | 10 +- hugr-passes/src/untuple.rs | 2 +- hugr/src/hugr.rs | 4 +- 17 files changed, 650 insertions(+), 351 deletions(-) create mode 100644 hugr-core/src/hugr/patch.rs rename hugr-core/src/hugr/{rewrite => patch}/consts.rs (74%) rename hugr-core/src/hugr/{rewrite => patch}/inline_call.rs (91%) rename hugr-core/src/hugr/{rewrite => patch}/inline_dfg.rs (96%) rename hugr-core/src/hugr/{rewrite => patch}/insert_identity.rs (84%) rename hugr-core/src/hugr/{rewrite => patch}/outline_cfg.rs (96%) rename hugr-core/src/hugr/{rewrite => patch}/port_types.rs (100%) rename hugr-core/src/hugr/{rewrite => patch}/replace.rs (73%) rename hugr-core/src/hugr/{rewrite => patch}/simple_replace.rs (92%) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index d708f15cd..7a74b4070 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -4,7 +4,7 @@ pub mod hugrmut; pub(crate) mod ident; pub mod internal; -pub mod rewrite; +pub mod patch; pub mod serialize; pub mod validate; pub mod views; @@ -17,7 +17,7 @@ pub(crate) use self::hugrmut::HugrMut; pub use self::validate::ValidationError; pub use ident::{IdentList, InvalidIdentifier}; -pub use rewrite::{Rewrite, SimpleReplacement, SimpleReplacementError}; +pub use patch::{Patch, SimpleReplacement, SimpleReplacementError}; use portgraph::multiportgraph::MultiPortGraph; use portgraph::{Hierarchy, PortMut, PortView, UnmanagedDenseMap}; diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 51e92f342..c58ccbdbc 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -12,7 +12,7 @@ use crate::extension::ExtensionRegistry; use crate::hugr::internal::HugrInternals; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrView, Node, OpType}; -use crate::hugr::{NodeMetadata, Rewrite}; +use crate::hugr::{NodeMetadata, Patch}; use crate::ops::OpTrait; use crate::types::Substitution; use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; @@ -195,11 +195,8 @@ pub trait HugrMut: HugrMutInternals { subgraph: &SiblingSubgraph, ) -> HashMap; - /// Applies a rewrite to the graph. - fn apply_rewrite( - &mut self, - rw: impl Rewrite, - ) -> Result + /// Applies a patch to the graph. + fn apply_patch(&mut self, rw: impl Patch) -> Result where Self: Sized, { diff --git a/hugr-core/src/hugr/patch.rs b/hugr-core/src/hugr/patch.rs new file mode 100644 index 000000000..bc6195eba --- /dev/null +++ b/hugr-core/src/hugr/patch.rs @@ -0,0 +1,169 @@ +//! Rewrite operations on the HUGR - replacement, outlining, etc. 
+ +pub mod consts; +pub mod inline_call; +pub mod inline_dfg; +pub mod insert_identity; +pub mod outline_cfg; +mod port_types; +pub mod replace; +pub mod simple_replace; + +use crate::{Hugr, HugrView}; +pub use port_types::{BoundaryPort, HostPort, ReplacementPort}; +pub use simple_replace::{SimpleReplacement, SimpleReplacementError}; + +use super::HugrMut; + +/// Verify that a patch application would succeed. +pub trait PatchVerification { + /// The type of Error with which this Rewrite may fail + type Error: std::error::Error; + + /// The node type of the HugrView that this patch applies to. + type Node; + + /// Checks whether the rewrite would succeed on the specified Hugr. + /// If this call succeeds, [Patch::apply] should also succeed on the same + /// `h` If this calls fails, [Patch::apply] would fail with the same + /// error. + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error>; + + /// Returns a set of nodes referenced by the rewrite. Modifying any of these + /// nodes will invalidate it. + /// + /// Two `impl Rewrite`s can be composed if their invalidation sets are + /// disjoint. + fn invalidation_set(&self) -> impl Iterator; +} + +/// A patch that can be applied to a mutable Hugr of type `H`. +/// +/// ### When to use +/// +/// Use this trait whenever possible in bounds for the most generality. Note +/// that this will require specifying which type `H` the patch applies to. +/// +/// ### When to implement +/// +/// For patches that work on any `H: HugrMut`, prefer implementing [`PatchHugrMut`] instead. This +/// will automatically implement this trait. +pub trait Patch: PatchVerification { + /// The type returned on successful application of the rewrite. + type Outcome; + + /// If `true`, [Patch::apply]'s of this rewrite guarantee that they do not + /// mutate the Hugr when they return an Err. If `false`, there is no + /// guarantee; the Hugr should be assumed invalid when Err is returned. + const UNCHANGED_ON_FAILURE: bool; + + /// Mutate the specified Hugr, or fail with an error. + /// + /// Returns [`Self::Outcome`] if successful. + /// If [Patch::UNCHANGED_ON_FAILURE] is true, then `h` must be unchanged if + /// Err is returned. See also [PatchVerification::verify] + /// + /// # Panics + /// + /// May panic if-and-only-if `h` would have failed [Hugr::validate]; that + /// is, implementations may begin with `assert!(h.validate())`, with + /// `debug_assert!(h.validate())` being preferred. + fn apply(self, h: &mut H) -> Result; +} + +/// A patch that can be applied to any [`HugrMut`]. +/// +/// This trait is a generalisation of [`Patch`] in that it guarantees that +/// the patch can be applied to any type implementing [`HugrMut`]. +/// +/// ### When to use +/// +/// Prefer using the more general [`Patch`] trait in bounds where the +/// type `H` is known. Resort to this trait if patches must be applicable to +/// any [`HugrMut`] instance. +/// +/// ### When to implement +/// +/// Always implement this trait when possible, to define how a patch is applied +/// to any type implementing [`HugrMut`]. A blanket implementation ensures that +/// any type implementing this trait also implements [`Patch`]. +pub trait PatchHugrMut: PatchVerification { + /// The type returned on successful application of the rewrite. + type Outcome; + + /// If `true`, [self.apply]'s of this rewrite guarantee that they do not + /// mutate the Hugr when they return an Err. If `false`, there is no + /// guarantee; the Hugr should be assumed invalid when Err is returned. 
+ const UNCHANGED_ON_FAILURE: bool; + + /// Mutate the specified Hugr, or fail with an error. + /// + /// Returns [`Self::Outcome`] if successful. + /// If [self.unchanged_on_failure] is true, then `h` must be unchanged if + /// Err is returned. See also [self.verify] + /// # Panics + /// May panic if-and-only-if `h` would have failed [Hugr::validate]; that + /// is, implementations may begin with `assert!(h.validate())`, with + /// `debug_assert!(h.validate())` being preferred. + fn apply_hugr_mut( + self, + h: &mut impl HugrMut, + ) -> Result; +} + +impl> Patch for R { + type Outcome = R::Outcome; + const UNCHANGED_ON_FAILURE: bool = R::UNCHANGED_ON_FAILURE; + + fn apply(self, h: &mut H) -> Result { + self.apply_hugr_mut(h) + } +} + +/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure) +pub struct Transactional { + underlying: R, +} + +impl PatchVerification for Transactional { + type Error = R::Error; + type Node = R::Node; + + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + self.underlying.verify(h) + } + + #[inline] + fn invalidation_set(&self) -> impl Iterator { + self.underlying.invalidation_set() + } +} + +// Note we might like to constrain R to Rewrite but +// this is not yet supported, https://github.com/rust-lang/rust/issues/92827 +impl PatchHugrMut for Transactional { + type Outcome = R::Outcome; + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut( + self, + h: &mut impl HugrMut, + ) -> Result { + if R::UNCHANGED_ON_FAILURE { + return self.underlying.apply_hugr_mut(h); + } + // Try to backup just the contents of this HugrMut. + let mut backup = Hugr::new(h.root_type().clone()); + backup.insert_from_view(backup.root(), h); + let r = self.underlying.apply_hugr_mut(h); + if r.is_err() { + // Try to restore backup. + h.replace_op(h.root(), backup.root_type().clone()); + while let Some(child) = h.first_child(h.root()) { + h.remove_node(child); + } + h.insert_hugr(h.root(), backup); + } + r + } +} diff --git a/hugr-core/src/hugr/rewrite/consts.rs b/hugr-core/src/hugr/patch/consts.rs similarity index 74% rename from hugr-core/src/hugr/rewrite/consts.rs rename to hugr-core/src/hugr/patch/consts.rs index ac657bf91..6d0c011fe 100644 --- a/hugr-core/src/hugr/rewrite/consts.rs +++ b/hugr-core/src/hugr/patch/consts.rs @@ -2,11 +2,11 @@ use std::iter; -use crate::{hugr::HugrMut, HugrView, Node}; +use crate::{core::HugrNode, hugr::HugrMut, HugrView, Node}; use itertools::Itertools; use thiserror::Error; -use super::Rewrite; +use super::{PatchHugrMut, PatchVerification}; /// Remove a [`crate::ops::LoadConstant`] node with no consumers. #[derive(Debug, Clone)] @@ -15,25 +15,20 @@ pub struct RemoveLoadConstant(pub N); /// Error from an [`RemoveConst`] or [`RemoveLoadConstant`] operation. #[derive(Debug, Clone, Error, PartialEq, Eq)] #[non_exhaustive] -pub enum RemoveError { +pub enum RemoveError { /// Invalid node. #[error("Node is invalid (either not in HUGR or not correct operation).")] - InvalidNode(Node), + InvalidNode(N), /// Node in use. #[error("Node: {0} has non-zero outgoing connections.")] - ValueUsed(Node), + ValueUsed(N), } -impl Rewrite for RemoveLoadConstant { - type Node = Node; - type Error = RemoveError; +impl PatchVerification for RemoveLoadConstant { + type Error = RemoveError; + type Node = N; - // The Const node the LoadConstant was connected to. 
- type ApplyResult = Node; - - const UNCHANGED_ON_FAILURE: bool = true; - - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { let node = self.0; if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) { @@ -51,7 +46,18 @@ impl Rewrite for RemoveLoadConstant { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn invalidation_set(&self) -> impl Iterator { + iter::once(self.0) + } +} + +impl PatchHugrMut for RemoveLoadConstant { + /// The [`Const`](crate::ops::Const) node the [`LoadConstant`](crate::ops::LoadConstant) was + /// connected to. + type Outcome = N; + + const UNCHANGED_ON_FAILURE: bool = true; + fn apply_hugr_mut(self, h: &mut impl HugrMut) -> Result { self.verify(h)?; let node = self.0; let source = h @@ -63,26 +69,17 @@ impl Rewrite for RemoveLoadConstant { Ok(source) } - - fn invalidation_set(&self) -> impl Iterator { - iter::once(self.0) - } } /// Remove a [`crate::ops::Const`] node with no outputs. #[derive(Debug, Clone)] -pub struct RemoveConst(pub Node); - -impl Rewrite for RemoveConst { - type Node = Node; - type Error = RemoveError; +pub struct RemoveConst(pub N); - // The parent of the Const node. - type ApplyResult = Node; +impl PatchVerification for RemoveConst { + type Node = N; + type Error = RemoveError; - const UNCHANGED_ON_FAILURE: bool = true; - - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { let node = self.0; if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) { @@ -96,7 +93,18 @@ impl Rewrite for RemoveConst { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn invalidation_set(&self) -> impl Iterator { + iter::once(self.0) + } +} + +impl PatchHugrMut for RemoveConst { + // The parent of the Const node. 
+ type Outcome = N; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut(self, h: &mut impl HugrMut) -> Result { self.verify(h)?; let node = self.0; let parent = h @@ -106,10 +114,6 @@ impl Rewrite for RemoveConst { Ok(parent) } - - fn invalidation_set(&self) -> impl Iterator { - iter::once(self.0) - } } #[cfg(test)] @@ -144,12 +148,12 @@ mod test { let tup_node = tup.node(); // can't remove invalid node assert_eq!( - h.apply_rewrite(RemoveConst(tup_node)), + h.apply_patch(RemoveConst(tup_node)), Err(RemoveError::InvalidNode(tup_node)) ); assert_eq!( - h.apply_rewrite(RemoveLoadConstant(tup_node)), + h.apply_patch(RemoveLoadConstant(tup_node)), Err(RemoveError::InvalidNode(tup_node)) ); let load_1_node = load_1.node(); @@ -172,7 +176,7 @@ mod test { // can't remove nodes in use assert_eq!( - h.apply_rewrite(remove_1.clone()), + h.apply_patch(remove_1.clone()), Err(RemoveError::ValueUsed(load_1_node)) ); @@ -180,20 +184,20 @@ mod test { h.remove_node(tup_node); // remove first load - let reported_con_node = h.apply_rewrite(remove_1)?; + let reported_con_node = h.apply_patch(remove_1)?; assert_eq!(reported_con_node, con_node); // still can't remove const, in use by second load assert_eq!( - h.apply_rewrite(remove_con.clone()), + h.apply_patch(remove_con.clone()), Err(RemoveError::ValueUsed(con_node)) ); // remove second use - let reported_con_node = h.apply_rewrite(remove_2)?; + let reported_con_node = h.apply_patch(remove_2)?; assert_eq!(reported_con_node, con_node); // remove const - assert_eq!(h.apply_rewrite(remove_con)?, h.root()); + assert_eq!(h.apply_patch(remove_con)?, h.root()); assert_eq!(h.node_count(), 4); assert!(h.validate().is_ok()); diff --git a/hugr-core/src/hugr/rewrite/inline_call.rs b/hugr-core/src/hugr/patch/inline_call.rs similarity index 91% rename from hugr-core/src/hugr/rewrite/inline_call.rs rename to hugr-core/src/hugr/patch/inline_call.rs index e32373507..0619d373e 100644 --- a/hugr-core/src/hugr/rewrite/inline_call.rs +++ b/hugr-core/src/hugr/patch/inline_call.rs @@ -2,41 +2,41 @@ //! into a DFG which replaces the Call node. use derive_more::{Display, Error}; +use crate::core::HugrNode; use crate::ops::{DataflowParent, OpType, DFG}; use crate::types::Substitution; use crate::{Direction, HugrView, Node}; -use super::{HugrMut, Rewrite}; +use super::{HugrMut, PatchHugrMut, PatchVerification}; /// Rewrite to inline a [Call](OpType::Call) to a known [FuncDefn](OpType::FuncDefn) -pub struct InlineCall(Node); +pub struct InlineCall(N); /// Error in performing [InlineCall] rewrite. #[derive(Clone, Debug, Display, Error, PartialEq)] #[non_exhaustive] -pub enum InlineCallError { +pub enum InlineCallError { /// The specified Node was not a [Call](OpType::Call) #[display("Node to inline {_0} expected to be a Call but actually {_1}")] - NotCallNode(Node, OpType), + NotCallNode(N, OpType), /// The node was a Call, but the target was not a [FuncDefn](OpType::FuncDefn) /// - presumably a [FuncDecl](OpType::FuncDecl), if the Hugr is valid. #[display("Call targetted node {_0} which must be a FuncDefn but was {_1}")] - CallTargetNotFuncDefn(Node, OpType), + CallTargetNotFuncDefn(N, OpType), } -impl InlineCall { +impl InlineCall { /// Create a new instance that will inline the specified node /// (i.e. 
that should be a [Call](OpType::Call)) - pub fn new(node: Node) -> Self { + pub fn new(node: N) -> Self { Self(node) } } -impl Rewrite for InlineCall { - type Node = Node; - type ApplyResult = (); - type Error = InlineCallError; - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { +impl PatchVerification for InlineCall { + type Error = InlineCallError; + type Node = N; + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { let call_ty = h.get_optype(self.0); if !call_ty.is_call() { return Err(InlineCallError::NotCallNode(self.0, call_ty.clone())); @@ -52,7 +52,14 @@ impl Rewrite for InlineCall { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + fn invalidation_set(&self) -> impl Iterator { + Some(self.0).into_iter() + } +} + +impl PatchHugrMut for InlineCall { + type Outcome = (); + fn apply_hugr_mut(self, h: &mut impl HugrMut) -> Result<(), Self::Error> { self.verify(h)?; // Now we know we have a Call to a FuncDefn. let orig_func = h.static_source(self.0).unwrap(); @@ -99,10 +106,6 @@ impl Rewrite for InlineCall { /// Failure only occurs if the node is not a Call, or the target not a FuncDefn. /// (Any later failure means an invalid Hugr and `panic`.) const UNCHANGED_ON_FAILURE: bool = true; - - fn invalidation_set(&self) -> impl Iterator { - Some(self.0).into_iter() - } } #[cfg(test)] @@ -177,7 +180,7 @@ mod test { .count(), 1 ); - hugr.apply_rewrite(InlineCall(call1.node())).unwrap(); + hugr.apply_patch(InlineCall(call1.node())).unwrap(); hugr.validate().unwrap(); assert_eq!(hugr.output_neighbours(func.node()).collect_vec(), [call2]); assert_eq!(calls(&hugr), [call2]); @@ -190,7 +193,7 @@ mod test { .count(), 1 ); - hugr.apply_rewrite(InlineCall(call2.node())).unwrap(); + hugr.apply_patch(InlineCall(call2.node())).unwrap(); hugr.validate().unwrap(); assert_eq!(hugr.output_neighbours(func.node()).next(), None); assert_eq!(calls(&hugr), []); @@ -225,7 +228,7 @@ mod test { let func = func.node(); let mut call = call.node(); for i in 2..10 { - hugr.apply_rewrite(InlineCall(call))?; + hugr.apply_patch(InlineCall(call))?; hugr.validate().unwrap(); assert_eq!(extension_ops(&hugr).len(), i); let v = calls(&hugr); @@ -264,7 +267,7 @@ mod test { let h = modb.finish_hugr().unwrap(); let mut h2 = h.clone(); assert_eq!( - h2.apply_rewrite(InlineCall(call.node())), + h2.apply_patch(InlineCall(call.node())), Err(InlineCallError::CallTargetNotFuncDefn( decl.node(), h.get_optype(decl.node()).clone() @@ -277,7 +280,7 @@ mod test { .try_into() .unwrap(); assert_eq!( - h2.apply_rewrite(InlineCall(inp)), + h2.apply_patch(InlineCall(inp)), Err(InlineCallError::NotCallNode( inp, Input { @@ -314,7 +317,7 @@ mod test { hugr.output_neighbours(inner.node()).collect::>(), [call1.node(), call2.node()] ); - hugr.apply_rewrite(InlineCall::new(call1.node()))?; + hugr.apply_patch(InlineCall::new(call1.node()))?; assert_eq!( hugr.output_neighbours(inner.node()).collect::>(), diff --git a/hugr-core/src/hugr/rewrite/inline_dfg.rs b/hugr-core/src/hugr/patch/inline_dfg.rs similarity index 96% rename from hugr-core/src/hugr/rewrite/inline_dfg.rs rename to hugr-core/src/hugr/patch/inline_dfg.rs index 8988df170..58fd51cbb 100644 --- a/hugr-core/src/hugr/rewrite/inline_dfg.rs +++ b/hugr-core/src/hugr/patch/inline_dfg.rs @@ -2,7 +2,7 @@ //! of the DFG except Input+Output into the DFG's parent, //! 
and deleting the DFG along with its Input + Output -use super::Rewrite; +use super::{PatchHugrMut, PatchVerification}; use crate::ops::handle::{DfgID, NodeHandle}; use crate::{IncomingPort, Node, OutgoingPort, PortIndex}; @@ -21,13 +21,10 @@ pub enum InlineDFGError { NoParent, } -impl Rewrite for InlineDFG { - /// Returns the removed nodes: the DFG, and its Input and Output children. - type Node = Node; - type ApplyResult = [Node; 3]; +impl PatchVerification for InlineDFG { type Error = InlineDFGError; - const UNCHANGED_ON_FAILURE: bool = true; + type Node = Node; fn verify(&self, h: &impl crate::HugrView) -> Result<(), Self::Error> { let n = self.0.node(); @@ -40,10 +37,21 @@ impl Rewrite for InlineDFG { Ok(()) } - fn apply( + fn invalidation_set(&self) -> impl Iterator { + [self.0.node()].into_iter() + } +} + +impl PatchHugrMut for InlineDFG { + /// The removed nodes: the DFG, and its Input and Output children. + type Outcome = [Node; 3]; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut( self, h: &mut impl crate::hugr::HugrMut, - ) -> Result { + ) -> Result { self.verify(h)?; let n = self.0.node(); let (oth_in, oth_out) = { @@ -124,10 +132,6 @@ impl Rewrite for InlineDFG { h.remove_node(n); Ok([n, input, output]) } - - fn invalidation_set(&self) -> impl Iterator { - [self.0.node()].into_iter() - } } #[cfg(test)] @@ -142,7 +146,7 @@ mod test { }; use crate::extension::prelude::qb_t; use crate::extension::ExtensionSet; - use crate::hugr::rewrite::inline_dfg::InlineDFGError; + use crate::hugr::patch::inline_dfg::InlineDFGError; use crate::hugr::HugrMut; use crate::ops::handle::{DfgID, NodeHandle}; use crate::ops::{OpType, Value}; @@ -212,13 +216,13 @@ mod test { // Check we can't inline the outer DFG let mut h = outer.clone(); assert_eq!( - h.apply_rewrite(InlineDFG(DfgID::from(h.root()))), + h.apply_patch(InlineDFG(DfgID::from(h.root()))), Err(InlineDFGError::NoParent) ); assert_eq!(h, outer); // unchanged } - outer.apply_rewrite(InlineDFG(*inner.handle()))?; + outer.apply_patch(InlineDFG(*inner.handle()))?; outer.validate()?; assert_eq!(outer.nodes().count(), 7); assert_eq!(find_dfgs(&outer), vec![outer.root()]); @@ -274,7 +278,7 @@ mod test { ] ); - h.apply_rewrite(InlineDFG(*swap.handle()))?; + h.apply_patch(InlineDFG(*swap.handle()))?; assert_eq!(find_dfgs(&h), vec![h.root()]); assert_eq!(h.nodes().count(), 5); // Dfg+I+O let mut ops = extension_ops(&h); @@ -350,7 +354,7 @@ mod test { )?; let mut outer = outer.finish_hugr_with_outputs(cx.outputs())?; - outer.apply_rewrite(InlineDFG(*inner.handle()))?; + outer.apply_patch(InlineDFG(*inner.handle()))?; outer.validate()?; let order_neighbours = |n, d| { let p = outer.get_optype(n).other_port(d).unwrap(); diff --git a/hugr-core/src/hugr/rewrite/insert_identity.rs b/hugr-core/src/hugr/patch/insert_identity.rs similarity index 84% rename from hugr-core/src/hugr/rewrite/insert_identity.rs rename to hugr-core/src/hugr/patch/insert_identity.rs index bde43413b..98ab0ff02 100644 --- a/hugr-core/src/hugr/rewrite/insert_identity.rs +++ b/hugr-core/src/hugr/patch/insert_identity.rs @@ -2,6 +2,7 @@ use std::iter; +use crate::core::HugrNode; use crate::extension::prelude::Noop; use crate::hugr::{HugrMut, Node}; use crate::ops::{OpTag, OpTrait}; @@ -9,22 +10,22 @@ use crate::ops::{OpTag, OpTrait}; use crate::types::EdgeKind; use crate::{HugrView, IncomingPort}; -use super::Rewrite; +use super::{PatchHugrMut, PatchVerification}; use thiserror::Error; /// Specification of a identity-insertion operation. 
#[derive(Debug, Clone)] -pub struct IdentityInsertion { +pub struct IdentityInsertion { /// The node following the identity to be inserted. - pub post_node: Node, + pub post_node: N, /// The port following the identity to be inserted. pub post_port: IncomingPort, } -impl IdentityInsertion { +impl IdentityInsertion { /// Create a new [`IdentityInsertion`] specification. - pub fn new(post_node: Node, post_port: IncomingPort) -> Self { + pub fn new(post_node: N, post_port: IncomingPort) -> Self { Self { post_node, post_port, @@ -47,12 +48,10 @@ pub enum IdentityInsertionError { InvalidPortKind(Option), } -impl Rewrite for IdentityInsertion { - type Node = Node; +impl PatchVerification for IdentityInsertion { type Error = IdentityInsertionError; - /// The inserted node. - type ApplyResult = Node; - const UNCHANGED_ON_FAILURE: bool = true; + type Node = N; + fn verify(&self, _h: &impl HugrView) -> Result<(), IdentityInsertionError> { /* Assumptions: @@ -66,10 +65,23 @@ impl Rewrite for IdentityInsertion { unimplemented!() } - fn apply( + + #[inline] + fn invalidation_set(&self) -> impl Iterator { + iter::once(self.post_node) + } +} + +impl PatchHugrMut for IdentityInsertion { + /// The inserted node. + type Outcome = N; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut( self, - h: &mut impl HugrMut, - ) -> Result { + h: &mut impl HugrMut, + ) -> Result { let kind = h.get_optype(self.post_node).port_kind(self.post_port); let Some(EdgeKind::Value(ty)) = kind else { return Err(IdentityInsertionError::InvalidPortKind(kind)); @@ -92,11 +104,6 @@ impl Rewrite for IdentityInsertion { h.connect(new_node, 0, self.post_node, self.post_port); Ok(new_node) } - - #[inline] - fn invalidation_set(&self) -> impl Iterator { - iter::once(self.post_node) - } } #[cfg(test)] @@ -122,7 +129,7 @@ mod tests { let rw = IdentityInsertion::new(final_node, final_node_port); - let noop_node = h.apply_rewrite(rw).unwrap(); + let noop_node = h.apply_patch(rw).unwrap(); assert_eq!(h.node_count(), 7); diff --git a/hugr-core/src/hugr/rewrite/outline_cfg.rs b/hugr-core/src/hugr/patch/outline_cfg.rs similarity index 96% rename from hugr-core/src/hugr/rewrite/outline_cfg.rs rename to hugr-core/src/hugr/patch/outline_cfg.rs index a76dbc6ee..0f40615a9 100644 --- a/hugr-core/src/hugr/rewrite/outline_cfg.rs +++ b/hugr-core/src/hugr/patch/outline_cfg.rs @@ -1,4 +1,5 @@ -//! Rewrite for inserting a CFG-node into the hierarchy containing a subsection of an existing CFG +//! Rewrite for inserting a CFG-node into the hierarchy containing a subsection +//! of an existing CFG use std::collections::HashSet; use itertools::Itertools; @@ -6,7 +7,6 @@ use thiserror::Error; use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer}; use crate::extension::ExtensionSet; -use crate::hugr::rewrite::Rewrite; use crate::hugr::{HugrMut, HugrView}; use crate::ops; use crate::ops::controlflow::BasicBlock; @@ -16,6 +16,8 @@ use crate::ops::{DataflowBlock, OpType}; use crate::PortIndex; use crate::{type_row, Node}; +use super::{PatchHugrMut, PatchVerification}; + /// Moves part of a Control-flow Sibling Graph into a new CFG-node /// that is the only child of a new Basic Block in the original CSG. 
pub struct OutlineCfg { @@ -92,20 +94,30 @@ impl OutlineCfg { } } -impl Rewrite for OutlineCfg { - type Node = Node; +impl PatchVerification for OutlineCfg { type Error = OutlineCfgError; + type Node = Node; + fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> { + self.compute_entry_exit_outside_extensions(h)?; + Ok(()) + } + + fn invalidation_set(&self) -> impl Iterator { + self.blocks.iter().copied() + } +} + +impl PatchHugrMut for OutlineCfg { /// The newly-created basic block, and the [CFG] node inside it /// /// [CFG]: OpType::CFG - type ApplyResult = (Node, Node); + type Outcome = [Node; 2]; const UNCHANGED_ON_FAILURE: bool = true; - fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> { - self.compute_entry_exit_outside_extensions(h)?; - Ok(()) - } - fn apply(self, h: &mut impl HugrMut) -> Result<(Node, Node), OutlineCfgError> { + fn apply_hugr_mut( + self, + h: &mut impl HugrMut, + ) -> Result<[Node; 2], OutlineCfgError> { let (entry, exit, outside, extension_delta) = self.compute_entry_exit_outside_extensions(h)?; // 1. Compute signature @@ -212,11 +224,7 @@ impl Rewrite for OutlineCfg { // 4(b). Reconnect exit edge to the new exit node within the inner CFG h.connect(exit, exit_port, inner_exit, 0); - Ok((new_block, cfg_node)) - } - - fn invalidation_set(&self) -> impl Iterator { - self.blocks.iter().copied() + Ok([new_block, cfg_node]) } } @@ -361,22 +369,22 @@ mod test { } = cond_then_loop_cfg; let backup = h.clone(); - let r = h.apply_rewrite(OutlineCfg::new([tail])); + let r = h.apply_patch(OutlineCfg::new([tail])); assert_matches!(r, Err(OutlineCfgError::MultipleExitEdges(_, _))); assert_eq!(h, backup); - let r = h.apply_rewrite(OutlineCfg::new([entry, left, right])); + let r = h.apply_patch(OutlineCfg::new([entry, left, right])); assert_matches!(r, Err(OutlineCfgError::MultipleExitNodes(a,b)) => assert_eq!(HashSet::from([a,b]), HashSet::from_iter([left, right]))); assert_eq!(h, backup); - let r = h.apply_rewrite(OutlineCfg::new([left, right, merge])); + let r = h.apply_patch(OutlineCfg::new([left, right, merge])); assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b)) => assert_eq!(HashSet::from([a,b]), HashSet::from([left, right]))); assert_eq!(h, backup); // The entry node implicitly has an extra incoming edge - let r = h.apply_rewrite(OutlineCfg::new([entry, left, right, merge, head])); + let r = h.apply_patch(OutlineCfg::new([entry, left, right, merge, head])); assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b)) => assert_eq!(HashSet::from([a,b]), HashSet::from([entry, head]))); assert_eq!(h, backup); @@ -497,7 +505,7 @@ mod test { ) -> (Node, Node, Node) { let mut other_blocks = h.children(cfg).collect::>(); assert!(blocks.iter().all(|b| other_blocks.remove(b))); - let (new_block, new_cfg) = h.apply_rewrite(OutlineCfg::new(blocks.clone())).unwrap(); + let [new_block, new_cfg] = h.apply_patch(OutlineCfg::new(blocks.clone())).unwrap(); for n in other_blocks { assert_eq!(h.get_parent(n), Some(cfg)) diff --git a/hugr-core/src/hugr/rewrite/port_types.rs b/hugr-core/src/hugr/patch/port_types.rs similarity index 100% rename from hugr-core/src/hugr/rewrite/port_types.rs rename to hugr-core/src/hugr/patch/port_types.rs diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/patch/replace.rs similarity index 73% rename from hugr-core/src/hugr/rewrite/replace.rs rename to hugr-core/src/hugr/patch/replace.rs index 0316f9d5b..6f0b0ed65 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ 
b/hugr-core/src/hugr/patch/replace.rs @@ -5,29 +5,32 @@ use std::collections::{HashMap, HashSet, VecDeque}; use itertools::Itertools; use thiserror::Error; +use crate::core::HugrNode; use crate::hugr::hugrmut::InsertionResult; use crate::hugr::HugrMut; use crate::ops::{OpTag, OpTrait}; use crate::types::EdgeKind; use crate::{Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort}; -use super::Rewrite; +use super::{PatchHugrMut, PatchVerification}; /// Specifies how to create a new edge. #[derive(Clone, Debug, PartialEq, Eq)] -pub struct NewEdgeSpec { - /// The source of the new edge. For [Replacement::mu_inp] and [Replacement::mu_new], this is in the - /// existing Hugr; for edges in [Replacement::mu_out] this is in the [Replacement::replacement] - pub src: Node, - /// The target of the new edge. For [Replacement::mu_inp], this is in the [Replacement::replacement]; - /// for edges in [Replacement::mu_out] and [Replacement::mu_new], this is in the existing Hugr. - pub tgt: Node, +pub struct NewEdgeSpec { + /// The source of the new edge. For [Replacement::mu_inp] and + /// [Replacement::mu_new], this is in the existing Hugr; for edges in + /// [Replacement::mu_out] this is in the [Replacement::replacement] + pub src: SrcNode, + /// The target of the new edge. For [Replacement::mu_inp], this is in the + /// [Replacement::replacement]; for edges in [Replacement::mu_out] and + /// [Replacement::mu_new], this is in the existing Hugr. + pub tgt: TgtNode, /// The kind of edge to create, and any port specifiers required pub kind: NewEdgeKind, } /// Describes an edge that should be created between two nodes already given -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum NewEdgeKind { /// An [EdgeKind::StateOrder] edge (between DFG nodes only) Order, @@ -54,40 +57,47 @@ pub enum NewEdgeKind { /// Specification of a `Replace` operation #[derive(Debug, Clone, PartialEq)] -pub struct Replacement { +pub struct Replacement { /// The nodes to remove from the existing Hugr (known as Gamma). - /// These must all have a common parent (i.e. be siblings). Called "S" in the spec. - /// Must be non-empty - otherwise there is no parent under which to place [Self::replacement], - /// and there would be no possible [Self::mu_inp], [Self::mu_out] or [Self::adoptions]. - pub removal: Vec, - /// A hugr (not necessarily valid, as it may be missing edges and/or nodes), whose root - /// is the same type as the root of [Self::replacement]. "G" in the spec. + /// These must all have a common parent (i.e. be siblings). Called "S" in + /// the spec. Must be non-empty - otherwise there is no parent under + /// which to place [Self::replacement], and there would be no possible + /// [Self::mu_inp], [Self::mu_out] or [Self::adoptions]. + pub removal: Vec, + /// A hugr (not necessarily valid, as it may be missing edges and/or nodes), + /// whose root is the same type as the root of [Self::replacement]. "G" + /// in the spec. pub replacement: Hugr, - /// Describes how parts of the Hugr that would otherwise be removed should instead be preserved but - /// with new parents amongst the newly-inserted nodes. This is a Map from container nodes in - /// [Self::replacement] that have no children, to container nodes that are descended from [Self::removal]. - /// The keys are the new parents for the children of the values. Note no value may be ancestor or - /// descendant of another. 
This is "B" in the spec; "R" is the set of descendants of [Self::removal] - /// that are not descendants of values here. - pub adoptions: HashMap, - /// Edges from nodes in the existing Hugr that are not removed ([NewEdgeSpec::src] in Gamma\R) - /// to inserted nodes ([NewEdgeSpec::tgt] in [Self::replacement]). - pub mu_inp: Vec, - /// Edges from inserted nodes ([NewEdgeSpec::src] in [Self::replacement]) to existing nodes not removed - /// ([NewEdgeSpec::tgt] in Gamma \ R). - pub mu_out: Vec, - /// Edges to add between existing nodes (both [NewEdgeSpec::src] and [NewEdgeSpec::tgt] in Gamma \ R). - /// For example, in cases where the source had an edge to a removed node, and the target had an - /// edge from a removed node, this would allow source to be directly connected to target. - pub mu_new: Vec, + /// Describes how parts of the Hugr that would otherwise be removed should + /// instead be preserved but with new parents amongst the newly-inserted + /// nodes. This is a Map from container nodes in [Self::replacement] + /// that have no children, to container nodes that are descended from + /// [Self::removal]. The keys are the new parents for the children of + /// the values. Note no value may be ancestor or descendant of another. + /// This is "B" in the spec; "R" is the set of descendants of + /// [Self::removal] that are not descendants of values here. + pub adoptions: HashMap, + /// Edges from nodes in the existing Hugr that are not removed + /// ([NewEdgeSpec::src] in Gamma\R) to inserted nodes + /// ([NewEdgeSpec::tgt] in [Self::replacement]). + pub mu_inp: Vec>, + /// Edges from inserted nodes ([NewEdgeSpec::src] in [Self::replacement]) to + /// existing nodes not removed ([NewEdgeSpec::tgt] in Gamma \ R). + pub mu_out: Vec>, + /// Edges to add between existing nodes (both [NewEdgeSpec::src] and + /// [NewEdgeSpec::tgt] in Gamma \ R). For example, in cases where the + /// source had an edge to a removed node, and the target had an + /// edge from a removed node, this would allow source to be directly + /// connected to target. 
+ pub mu_new: Vec>, } -impl NewEdgeSpec { - fn check_src( +impl NewEdgeSpec { + fn check_src( &self, - h: &impl HugrView, - err_spec: &NewEdgeSpec, - ) -> Result<(), ReplaceError> { + h: &impl HugrView, + err_spec: impl Fn(Self) -> WhichEdgeSpec, + ) -> Result<(), ReplaceError> { let optype = h.get_optype(self.src); let ok = match self.kind { NewEdgeKind::Order => optype.other_output() == Some(EdgeKind::StateOrder), @@ -103,13 +113,14 @@ impl NewEdgeSpec { } }; ok.then_some(()) - .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Outgoing, err_spec.clone())) + .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Outgoing, err_spec(self.clone()))) } - fn check_tgt( + + fn check_tgt( &self, - h: &impl HugrView, - err_spec: &NewEdgeSpec, - ) -> Result<(), ReplaceError> { + h: &impl HugrView, + err_spec: impl Fn(Self) -> WhichEdgeSpec, + ) -> Result<(), ReplaceError> { let optype = h.get_optype(self.tgt); let ok = match self.kind { NewEdgeKind::Order => optype.other_input() == Some(EdgeKind::StateOrder), @@ -126,18 +137,20 @@ impl NewEdgeSpec { ), }; ok.then_some(()) - .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Incoming, err_spec.clone())) + .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Incoming, err_spec(self.clone()))) } +} +impl NewEdgeSpec { fn check_existing_edge( &self, - h: &impl HugrView, - legal_src_ancestors: &HashSet, - err_edge: impl Fn() -> NewEdgeSpec, - ) -> Result<(), ReplaceError> { + h: &impl HugrView, + legal_src_ancestors: &HashSet, + err_edge: impl Fn(Self) -> WhichEdgeSpec, + ) -> Result<(), ReplaceError> { if let NewEdgeKind::Static { tgt_pos, .. } | NewEdgeKind::Value { tgt_pos, .. } = self.kind { - let descends_from_legal = |mut descendant: Node| -> bool { + let descends_from_legal = |mut descendant: HostNode| -> bool { while !legal_src_ancestors.contains(&descendant) { let Some(p) = h.get_parent(descendant) else { return false; @@ -150,15 +163,18 @@ impl NewEdgeSpec { .single_linked_output(self.tgt, tgt_pos) .is_some_and(|(src_n, _)| descends_from_legal(src_n)); if !found_incoming { - return Err(ReplaceError::NoRemovedEdge(err_edge())); + return Err(ReplaceError::NoRemovedEdge(err_edge(self.clone()))); }; }; Ok(()) } } -impl Replacement { - fn check_parent(&self, h: &impl HugrView) -> Result { +impl Replacement { + fn check_parent( + &self, + h: &impl HugrView, + ) -> Result> { let parent = self .removal .iter() @@ -168,8 +184,9 @@ impl Replacement { .map_err(|ex_one| ReplaceError::MultipleParents(ex_one.flatten().collect()))? .ok_or(ReplaceError::CantReplaceRoot)?; // If no parent - // Check replacement parent is of same tag. Note we do not require exact equality - // of OpType/Signature, e.g. to ease changing of Input/Output node signatures too. + // Check replacement parent is of same tag. Note we do not require exact + // equality of OpType/Signature, e.g. to ease changing of Input/Output + // node signatures too. 
let removed = h.get_optype(parent).tag(); let replacement = self.replacement.root_type().tag(); if removed != replacement { @@ -183,8 +200,8 @@ impl Replacement { fn get_removed_nodes( &self, - h: &impl HugrView, - ) -> Result, ReplaceError> { + h: &impl HugrView, + ) -> Result, ReplaceError> { // Check the keys of the transfer map too, the values we'll use imminently self.adoptions.keys().try_for_each(|&n| { (self.replacement.contains_node(n) @@ -193,7 +210,7 @@ impl Replacement { .then_some(()) .ok_or(ReplaceError::InvalidAdoptingParent(n)) })?; - let mut transferred: HashSet = self.adoptions.values().copied().collect(); + let mut transferred: HashSet = self.adoptions.values().copied().collect(); if transferred.len() != self.adoptions.values().len() { return Err(ReplaceError::AdopteesNotSeparateDescendants( self.adoptions @@ -221,98 +238,149 @@ impl Replacement { Ok(removed) } } -impl Rewrite for Replacement { - type Node = Node; - type Error = ReplaceError; - - /// Map from Node in replacement to corresponding Node in the result Hugr - type ApplyResult = HashMap; - const UNCHANGED_ON_FAILURE: bool = false; +impl PatchVerification for Replacement { + type Error = ReplaceError; + type Node = HostNode; - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { self.check_parent(h)?; let removed = self.get_removed_nodes(h)?; // Edge sources... - for e in self.mu_inp.iter().chain(self.mu_new.iter()) { + for e in self.mu_inp.iter() { if !h.contains_node(e.src) || removed.contains(&e.src) { return Err(ReplaceError::BadEdgeSpec( Direction::Outgoing, - WhichHugr::Retained, - e.clone(), + WhichEdgeSpec::HostToRepl(e.clone()), )); } - e.check_src(h, e)?; + e.check_src(h, WhichEdgeSpec::HostToRepl)?; + } + for e in self.mu_new.iter() { + if !h.contains_node(e.src) || removed.contains(&e.src) { + return Err(ReplaceError::BadEdgeSpec( + Direction::Outgoing, + WhichEdgeSpec::HostToHost(e.clone()), + )); + } + e.check_src(h, WhichEdgeSpec::HostToHost)?; } self.mu_out .iter() .try_for_each(|e| match self.replacement.valid_non_root(e.src) { - true => e.check_src(&self.replacement, e), + true => e.check_src(&self.replacement, WhichEdgeSpec::ReplToHost), false => Err(ReplaceError::BadEdgeSpec( Direction::Outgoing, - WhichHugr::Replacement, - e.clone(), + WhichEdgeSpec::ReplToHost(e.clone()), )), })?; // Edge targets... self.mu_inp .iter() .try_for_each(|e| match self.replacement.valid_non_root(e.tgt) { - true => e.check_tgt(&self.replacement, e), + true => e.check_tgt(&self.replacement, WhichEdgeSpec::HostToRepl), false => Err(ReplaceError::BadEdgeSpec( Direction::Incoming, - WhichHugr::Replacement, - e.clone(), + WhichEdgeSpec::HostToRepl(e.clone()), )), })?; - for e in self.mu_out.iter().chain(self.mu_new.iter()) { + for e in self.mu_out.iter() { if !h.contains_node(e.tgt) || removed.contains(&e.tgt) { return Err(ReplaceError::BadEdgeSpec( Direction::Incoming, - WhichHugr::Retained, - e.clone(), + WhichEdgeSpec::ReplToHost(e.clone()), )); } - e.check_tgt(h, e)?; + e.check_tgt(h, WhichEdgeSpec::ReplToHost)?; // The descendant check is to allow the case where the old edge is nonlocal // from a part of the Hugr being moved (which may require changing source, // depending on where the transplanted portion ends up). While this subsumes - // the first "removed.contains" check, we'll keep that as a common-case fast-path. 
- e.check_existing_edge(h, &removed, || e.clone())?; + // the first "removed.contains" check, we'll keep that as a common-case + // fast-path. + e.check_existing_edge(h, &removed, WhichEdgeSpec::ReplToHost)?; + } + for e in self.mu_new.iter() { + if !h.contains_node(e.tgt) || removed.contains(&e.tgt) { + return Err(ReplaceError::BadEdgeSpec( + Direction::Incoming, + WhichEdgeSpec::HostToHost(e.clone()), + )); + } + e.check_tgt(h, WhichEdgeSpec::HostToHost)?; + // The descendant check is to allow the case where the old edge is nonlocal + // from a part of the Hugr being moved (which may require changing source, + // depending on where the transplanted portion ends up). While this subsumes + // the first "removed.contains" check, we'll keep that as a common-case + // fast-path. + e.check_existing_edge(h, &removed, WhichEdgeSpec::HostToHost)?; } Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn invalidation_set(&self) -> impl Iterator { + self.removal.iter().copied() + } +} + +impl PatchHugrMut for Replacement { + /// Map from Node in replacement to corresponding Node in the result Hugr + type Outcome = HashMap; + + const UNCHANGED_ON_FAILURE: bool = false; + + fn apply_hugr_mut( + self, + h: &mut impl HugrMut, + ) -> Result { let parent = self.check_parent(h)?; // Calculate removed nodes here. (Does not include transfers, so enumerates only - // nodes we are going to remove, individually, anyway; so no *asymptotic* speed penalty) + // nodes we are going to remove, individually, anyway; so no *asymptotic* speed + // penalty) let to_remove = self.get_removed_nodes(h)?; - // 1. Add all the new nodes. Note this includes replacement.root(), which we don't want. + // 1. Add all the new nodes. Note this includes replacement.root(), which we + // don't want. // TODO what would an error here mean? e.g. malformed self.replacement?? let InsertionResult { new_root, node_map } = h.insert_hugr(parent, self.replacement); // 2. Add new edges from existing to copied nodes according to mu_in - let translate_idx = |n| node_map.get(&n).copied().ok_or(WhichHugr::Replacement); - let kept = |n| { - let keep = !to_remove.contains(&n); - keep.then_some(n).ok_or(WhichHugr::Retained) - }; - transfer_edges(h, self.mu_inp.iter(), kept, translate_idx, None)?; + let translate_idx = |n| node_map.get(&n).copied(); + let kept = |n| (!to_remove.contains(&n)).then_some(n); + transfer_edges( + h, + self.mu_inp.iter(), + kept, + translate_idx, + WhichEdgeSpec::HostToRepl, + None, + )?; // 3. Add new edges from copied to existing nodes according to mu_out, // replacing existing value/static edges incoming to targets - transfer_edges(h, self.mu_out.iter(), translate_idx, kept, Some(&to_remove))?; + transfer_edges( + h, + self.mu_out.iter(), + translate_idx, + kept, + WhichEdgeSpec::ReplToHost, + Some(&to_remove), + )?; // 4. Add new edges between existing nodes according to mu_new, // replacing existing value/static edges incoming to targets. - transfer_edges(h, self.mu_new.iter(), kept, kept, Some(&to_remove))?; + transfer_edges( + h, + self.mu_new.iter(), + kept, + kept, + WhichEdgeSpec::HostToHost, + Some(&to_remove), + )?; // 5. 
Put newly-added copies into correct places in hierarchy // (these will be correct places after removing nodes) let mut remove_top_sibs = self.removal.iter(); - for new_node in h.children(new_root).collect::>().into_iter() { + for new_node in h.children(new_root).collect::>().into_iter() { if let Some(top_sib) = remove_top_sibs.next() { h.move_before_sibling(new_node, *top_sib); } else { @@ -337,51 +405,53 @@ impl Rewrite for Replacement { }); Ok(node_map) } - - fn invalidation_set(&self) -> impl Iterator { - self.removal.iter().copied() - } } -fn transfer_edges<'a>( - h: &mut impl HugrMut, - edges: impl Iterator, - trans_src: impl Fn(Node) -> Result, - trans_tgt: impl Fn(Node) -> Result, - legal_src_ancestors: Option<&HashSet>, -) -> Result<(), ReplaceError> { +fn transfer_edges<'a, SrcNode, TgtNode, HostNode>( + h: &mut impl HugrMut, + edges: impl Iterator>, + trans_src: impl Fn(SrcNode) -> Option, + trans_tgt: impl Fn(TgtNode) -> Option, + err_spec: impl Fn(NewEdgeSpec) -> WhichEdgeSpec, + legal_src_ancestors: Option<&HashSet>, +) -> Result<(), ReplaceError> +where + SrcNode: 'a + HugrNode, + TgtNode: 'a + HugrNode, + HostNode: 'a + HugrNode, +{ for oe in edges { + let err_spec = err_spec(oe.clone()); let e = NewEdgeSpec { // Translation can only fail for Nodes that are supposed to be in the replacement src: trans_src(oe.src) - .map_err(|h| ReplaceError::BadEdgeSpec(Direction::Outgoing, h, oe.clone()))?, + .ok_or_else(|| ReplaceError::BadEdgeSpec(Direction::Outgoing, err_spec.clone()))?, tgt: trans_tgt(oe.tgt) - .map_err(|h| ReplaceError::BadEdgeSpec(Direction::Incoming, h, oe.clone()))?, - ..oe.clone() + .ok_or_else(|| ReplaceError::BadEdgeSpec(Direction::Incoming, err_spec.clone()))?, + kind: oe.kind, }; if !h.valid_node(e.src) { return Err(ReplaceError::BadEdgeSpec( Direction::Outgoing, - WhichHugr::Retained, - oe.clone(), + err_spec.clone(), )); } if !h.valid_node(e.tgt) { return Err(ReplaceError::BadEdgeSpec( Direction::Incoming, - WhichHugr::Retained, - oe.clone(), + err_spec.clone(), )); }; - e.check_src(h, oe)?; - e.check_tgt(h, oe)?; + let err_spec = |_| err_spec.clone(); + e.check_src(h, err_spec)?; + e.check_tgt(h, err_spec)?; match e.kind { NewEdgeKind::Order => { h.add_other_edge(e.src, e.tgt); } NewEdgeKind::Value { src_pos, tgt_pos } | NewEdgeKind::Static { src_pos, tgt_pos } => { if let Some(legal_src_ancestors) = legal_src_ancestors { - e.check_existing_edge(h, legal_src_ancestors, || oe.clone())?; + e.check_existing_edge(h, legal_src_ancestors, err_spec)?; h.disconnect(e.tgt, tgt_pos); } h.connect(e.src, src_pos, e.tgt, tgt_pos); @@ -395,14 +465,14 @@ fn transfer_edges<'a>( /// Error in a [`Replacement`] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[non_exhaustive] -pub enum ReplaceError { +pub enum ReplaceError { /// The node(s) to replace had no parent i.e. were root(s). // (Perhaps if there is only one node to replace we should be able to?) 
#[error("Cannot replace the root node of the Hugr")] CantReplaceRoot, /// The nodes to replace did not have a unique common parent #[error("Removed nodes had different parents {0:?}")] - MultipleParents(Vec), + MultipleParents(Vec), /// Replacement root node had different tag from parent of removed nodes #[error("Expected replacement root with tag {removed} but found {replacement}")] WrongRootNodeTag { @@ -411,40 +481,47 @@ pub enum ReplaceError { /// The tag of the root in the replacement Hugr replacement: OpTag, }, - /// Keys in [Replacement::adoptions] were not valid container nodes in [Replacement::replacement] + /// Keys in [Replacement::adoptions] were not valid container nodes in + /// [Replacement::replacement] #[error("Node {0} was not an empty container node in the replacement")] InvalidAdoptingParent(Node), - /// Some values in [Replacement::adoptions] were either descendants of other values, or not - /// descendants of the [Replacement::removal]. The nodes are indicated on a best-effort basis. + /// Some values in [Replacement::adoptions] were either descendants of other + /// values, or not descendants of the [Replacement::removal]. The nodes + /// are indicated on a best-effort basis. #[error("Nodes not free to be moved into new locations: {0:?}")] - AdopteesNotSeparateDescendants(Vec), + AdopteesNotSeparateDescendants(Vec), /// A node at one end of a [NewEdgeSpec] was not found - #[error("{0:?} end of edge {2:?} not found in {1}")] - BadEdgeSpec(Direction, WhichHugr, NewEdgeSpec), - /// The target of the edge was found, but there was no existing edge to replace + #[error("{0:?} end of edge {1:?} not found in {which_hugr}", which_hugr = .1.which_hugr(*.0))] + BadEdgeSpec(Direction, WhichEdgeSpec), + /// The target of the edge was found, but there was no existing edge to + /// replace #[error("Target of edge {0:?} did not have a corresponding incoming edge being removed")] - NoRemovedEdge(NewEdgeSpec), + NoRemovedEdge(WhichEdgeSpec), /// The [NewEdgeKind] was not applicable for the source/target node(s) #[error("The edge kind was not applicable to the {0:?} node: {1:?}")] - BadEdgeKind(Direction, NewEdgeSpec), + BadEdgeKind(Direction, WhichEdgeSpec), } -/// A Hugr or portion thereof that is part of the [Replacement] +/// The three kinds of [NewEdgeSpec] that may appear in a [ReplaceError] #[derive(Clone, Debug, PartialEq, Eq)] -pub enum WhichHugr { - /// The newly-inserted nodes, i.e. the [Replacement::replacement] - Replacement, - /// Nodes in the existing Hugr that are not [Replacement::removal] - /// (or are on the RHS of an entry in [Replacement::adoptions]) - Retained, +pub enum WhichEdgeSpec { + /// An edge from the host Hugr into the replacement, i.e. + /// [Replacement::mu_inp] + HostToRepl(NewEdgeSpec), + /// An edge from the replacement to the host, i.e. [Replacement::mu_out] + ReplToHost(NewEdgeSpec), + /// An edge between two nodes in the host (bypassing the replacement), + /// i.e. 
[Replacement::mu_new] + HostToHost(NewEdgeSpec), } -impl std::fmt::Display for WhichHugr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(match self { - Self::Replacement => "replacement Hugr", - Self::Retained => "retained portion of Hugr", - }) +impl WhichEdgeSpec { + fn which_hugr(&self, d: Direction) -> &str { + match (self, d) { + (Self::HostToRepl(_), Direction::Incoming) + | (Self::ReplToHost(_), Direction::Outgoing) => "replacement Hugr", + _ => "retained portion of Hugr", + } } } @@ -462,8 +539,8 @@ mod test { use crate::extension::prelude::{bool_t, usize_t}; use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::hugr::internal::HugrMutInternals; - use crate::hugr::rewrite::replace::WhichHugr; - use crate::hugr::{HugrMut, Rewrite}; + use crate::hugr::patch::PatchVerification; + use crate::hugr::{HugrMut, Patch}; use crate::ops::custom::ExtensionOp; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle}; @@ -473,7 +550,7 @@ mod test { use crate::utils::{depth, test_quantum_extension}; use crate::{type_row, Direction, Extension, Hugr, HugrView, OutgoingPort}; - use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement}; + use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement, WhichEdgeSpec}; #[test] #[ignore] // FIXME: This needs a rewrite now that `pop` returns an optional value -.-' @@ -520,7 +597,8 @@ mod test { } // Replacement: one BB with two DFGs inside. - // Use Hugr rather than Builder because DFGs must be empty (not even Input/Output). + // Use Hugr rather than Builder because it must be empty (not even + // Input/Output). let mut replacement = Hugr::new(ops::CFG { signature: Signature::new_endo(just_list.clone()), }); @@ -569,7 +647,7 @@ mod test { replacement.connect(r_df2, 1, out, 1); } - h.apply_rewrite(Replacement { + h.apply_patch(Replacement { removal: vec![entry.node(), bb2.node()], replacement, adoptions: HashMap::from([(r_df1.node(), entry.node()), (r_df2.node(), bb2.node())]), @@ -788,7 +866,10 @@ mod test { mu_inp: vec![edge_from_removed.clone()], ..rep.clone() }), - ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Retained, edge_from_removed) + ReplaceError::BadEdgeSpec( + Direction::Outgoing, + WhichEdgeSpec::HostToRepl(edge_from_removed) + ) ); let bad_out_edge = NewEdgeSpec { src: h.nodes().max().unwrap(), // not valid in replacement @@ -800,7 +881,7 @@ mod test { mu_out: vec![bad_out_edge.clone()], ..rep.clone() }), - ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Replacement, bad_out_edge) + ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichEdgeSpec::ReplToHost(bad_out_edge),) ); let bad_order_edge = NewEdgeSpec { src: cond.node(), @@ -812,7 +893,7 @@ mod test { mu_new: vec![bad_order_edge.clone()], ..rep.clone() }), - ReplaceError::BadEdgeKind(_, e) => assert_eq!(e, bad_order_edge) + ReplaceError::BadEdgeKind(_, e) => assert_eq!(e, WhichEdgeSpec::HostToHost(bad_order_edge)) ); let op = OutgoingPort::from(0); let (tgt, ip) = h.linked_inputs(cond.node(), op).next().unwrap(); @@ -829,7 +910,7 @@ mod test { mu_out: vec![new_out_edge.clone()], ..rep.clone() }), - ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge) + ReplaceError::BadEdgeKind(Direction::Outgoing, WhichEdgeSpec::ReplToHost(new_out_edge)) ); } } diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs similarity index 92% rename from hugr-core/src/hugr/rewrite/simple_replace.rs rename to 
hugr-core/src/hugr/patch/simple_replace.rs index 5d3716dc0..e9283644d 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use crate::core::HugrNode; use crate::hugr::hugrmut::InsertionResult; use crate::hugr::views::SiblingSubgraph; -use crate::hugr::{HugrMut, HugrView, Rewrite}; +use crate::hugr::{HugrMut, HugrView}; use crate::ops::{OpTag, OpTrait, OpType}; use crate::{Hugr, IncomingPort, Node, OutgoingPort}; @@ -14,7 +14,7 @@ use itertools::Itertools; use thiserror::Error; use super::inline_dfg::InlineDFGError; -use super::{BoundaryPort, HostPort, ReplacementPort}; +use super::{BoundaryPort, HostPort, PatchHugrMut, PatchVerification, ReplacementPort}; /// Specification of a simple replacement operation. /// @@ -28,7 +28,8 @@ pub struct SimpleReplacement { /// A hugr with DFG root (consisting of replacement nodes). replacement: Hugr, /// A map from (target ports of edges from the Input node of `replacement`) - /// to (target ports of edges from nodes not in `subgraph` to nodes in `subgraph`). + /// to (target ports of edges from nodes not in `subgraph` to nodes in + /// `subgraph`). nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>, /// A map from (target ports of edges from nodes in `subgraph` to nodes not /// in `subgraph`) to (input ports of the Output node of `replacement`). @@ -125,7 +126,8 @@ impl SimpleReplacement { }) .map( |(&(rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| { - // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port) + // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, + // n_inp_port) let (rem_inp_pred_node, rem_inp_pred_port) = host .single_linked_output(*rem_inp_node, *rem_inp_port) .unwrap(); @@ -158,8 +160,9 @@ impl SimpleReplacement { > + 'a { let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG"); - // For each q = self.nu_out[p] such that the predecessor of q is not an Input port, - // there will be an edge from (the new copy of) the predecessor of q to p. + // For each q = self.nu_out[p] such that the predecessor of q is not an Input + // port, there will be an edge from (the new copy of) the predecessor of + // q to p. self.nu_out .iter() .filter_map(move |(&(rem_out_node, rem_out_port), rep_out_port)| { @@ -196,8 +199,8 @@ impl SimpleReplacement { > + 'a { let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG"); - // For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0 - // to p1. + // For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the + // predecessor of p0 to p1. self.nu_out .iter() .filter_map(move |(&(rem_out_node, rem_out_port), &rep_out_port)| { @@ -245,8 +248,9 @@ impl SimpleReplacement { /// Get all edges that the replacement would add between `host` and /// `self.replacement`. /// - /// This is equivalent to chaining the results of [`Self::incoming_boundary`], - /// [`Self::outgoing_boundary`], and [`Self::host_to_host_boundary`]. + /// This is equivalent to chaining the results of + /// [`Self::incoming_boundary`], [`Self::outgoing_boundary`], and + /// [`Self::host_to_host_boundary`]. /// /// This panics if self.replacement is not a DFG. 
pub fn all_boundary_edges<'a>( @@ -274,17 +278,35 @@ impl SimpleReplacement { } } -impl Rewrite for SimpleReplacement { - type Node = Node; +impl PatchVerification for SimpleReplacement { type Error = SimpleReplacementError; - type ApplyResult = Vec<(Node, OpType)>; - const UNCHANGED_ON_FAILURE: bool = true; + type Node = HostNode; - fn verify(&self, h: &impl HugrView) -> Result<(), SimpleReplacementError> { + fn verify(&self, h: &impl HugrView) -> Result<(), SimpleReplacementError> { self.is_valid_rewrite(h) } - fn apply(self, h: &mut impl HugrMut) -> Result { + #[inline] + fn invalidation_set(&self) -> impl Iterator { + let subcirc = self.subgraph.nodes().iter().copied(); + let out_neighs = self.nu_out.keys().map(|key| key.0); + subcirc.chain(out_neighs) + } +} + +/// Result of applying a [`SimpleReplacement`]. +pub struct Outcome { + /// Map from Node in replacement to corresponding Node in the result Hugr + pub node_map: HashMap, + /// Nodes removed from the result Hugr and their weights + pub removed_nodes: HashMap, +} + +impl PatchHugrMut for SimpleReplacement { + type Outcome = Outcome; + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut(self, h: &mut impl HugrMut) -> Result { self.is_valid_rewrite(h)?; let parent = self.subgraph.get_parent(h); @@ -305,13 +327,10 @@ impl Rewrite for SimpleReplacement { } = self; // 2. Insert the replacement as a whole. - let InsertionResult { - new_root, - node_map: index_map, - } = h.insert_hugr(parent, replacement); + let InsertionResult { new_root, node_map } = h.insert_hugr(parent, replacement); // remove the Input and Output nodes from the replacement graph - let replace_children = h.children(new_root).collect::>(); + let replace_children = h.children(new_root).collect::>(); for &io in &replace_children[..2] { h.remove_node(io); } @@ -324,24 +343,22 @@ impl Rewrite for SimpleReplacement { // 3. Insert all boundary edges. for (src, tgt) in boundary_edges { - let (src_node, src_port) = src.map_replacement(&index_map); - let (tgt_node, tgt_port) = tgt.map_replacement(&index_map); + let (src_node, src_port) = src.map_replacement(&node_map); + let (tgt_node, tgt_port) = tgt.map_replacement(&node_map); h.connect(src_node, src_port, tgt_node, tgt_port); } // 4. Remove all nodes in subgraph and edges between them. 
- Ok(subgraph + let removed_nodes = subgraph .nodes() .iter() .map(|&node| (node, h.remove_node(node))) - .collect()) - } + .collect(); - #[inline] - fn invalidation_set(&self) -> impl Iterator { - let subcirc = self.subgraph.nodes().iter().copied(); - let out_neighs = self.nu_out.keys().map(|key| key.0); - subcirc.chain(out_neighs) + Ok(Outcome { + node_map, + removed_nodes, + }) } } @@ -364,9 +381,10 @@ pub enum SimpleReplacementError { } #[cfg(test)] -pub(in crate::hugr::rewrite) mod test { +pub(in crate::hugr::patch) mod test { use itertools::Itertools; use rstest::{fixture, rstest}; + use std::collections::{HashMap, HashSet}; use crate::builder::test::n_identity; @@ -376,8 +394,9 @@ pub(in crate::hugr::rewrite) mod test { }; use crate::extension::prelude::{bool_t, qb_t}; use crate::extension::ExtensionSet; + use crate::hugr::patch::PatchVerification; use crate::hugr::views::{HugrView, SiblingSubgraph}; - use crate::hugr::{Hugr, HugrMut, Rewrite}; + use crate::hugr::{Hugr, HugrMut, Patch}; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::NodeHandle; use crate::ops::OpTag; @@ -433,7 +452,7 @@ pub(in crate::hugr::rewrite) mod test { } #[fixture] - pub(in crate::hugr::rewrite) fn simple_hugr() -> Hugr { + pub(in crate::hugr::patch) fn simple_hugr() -> Hugr { make_hugr().unwrap() } /// Creates a hugr with a DFG root like the following: @@ -453,7 +472,7 @@ pub(in crate::hugr::rewrite) mod test { } #[fixture] - pub(in crate::hugr::rewrite) fn dfg_hugr() -> Hugr { + pub(in crate::hugr::patch) fn dfg_hugr() -> Hugr { make_dfg_hugr().unwrap() } @@ -473,7 +492,7 @@ pub(in crate::hugr::rewrite) mod test { } #[fixture] - pub(in crate::hugr::rewrite) fn dfg_hugr2() -> Hugr { + pub(in crate::hugr::patch) fn dfg_hugr2() -> Hugr { make_dfg_hugr2().unwrap() } @@ -485,11 +504,12 @@ pub(in crate::hugr::rewrite) mod test { /// └─────────┘ │ ┌─────────┐ /// └────┤ (2) NOT ├── /// └─────────┘ - /// This can be replaced with an empty hugr coping the input to both outputs. + /// This can be replaced with an empty hugr coping the input to both + /// outputs. /// /// Returns the hugr and the nodes of the NOT gates, in order. #[fixture] - pub(in crate::hugr::rewrite) fn dfg_hugr_copy_bools() -> (Hugr, Vec) { + pub(in crate::hugr::patch) fn dfg_hugr_copy_bools() -> (Hugr, Vec) { let mut dfg_builder = DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap(); let [b] = dfg_builder.input_wires_arr(); @@ -516,11 +536,12 @@ pub(in crate::hugr::rewrite) mod test { /// └─────────┘ │ /// └───────────────── /// - /// This can be replaced with a single NOT op, coping the input to the first output. + /// This can be replaced with a single NOT op, coping the input to the first + /// output. /// /// Returns the hugr and the nodes of the NOT ops, in order. 
#[fixture] - pub(in crate::hugr::rewrite) fn dfg_hugr_half_not_bools() -> (Hugr, Vec) { + pub(in crate::hugr::patch) fn dfg_hugr_half_not_bools() -> (Hugr, Vec) { let mut dfg_builder = DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap(); let [b] = dfg_builder.input_wires_arr(); @@ -682,7 +703,7 @@ pub(in crate::hugr::rewrite) mod test { nu_inp, nu_out, }; - h.apply_rewrite(r).unwrap(); + h.apply_patch(r).unwrap(); // Expect [DFG] to be replaced with: // ┌───┐┌───┐ // ┤ H ├┤ H ├ @@ -736,7 +757,7 @@ pub(in crate::hugr::rewrite) mod test { }) .map(|p| ((output, p), p)) .collect(); - h.apply_rewrite(SimpleReplacement::new( + h.apply_patch(SimpleReplacement::new( SiblingSubgraph::try_from_nodes(removal, &h).unwrap(), replacement, inputs, @@ -788,7 +809,7 @@ pub(in crate::hugr::rewrite) mod test { .map(|p| ((repl_output, p), p)) .collect(); - h.apply_rewrite(SimpleReplacement::new( + h.apply_patch(SimpleReplacement::new( SiblingSubgraph::try_from_nodes(removal, &h).unwrap(), repl, inputs, @@ -800,8 +821,8 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(h.node_count(), orig.node_count()); } - /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the input - /// directly to the outputs. + /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the + /// input directly to the outputs. /// /// https://github.com/CQCL/hugr/issues/1190 #[rstest] @@ -822,8 +843,9 @@ pub(in crate::hugr::rewrite) mod test { let subgraph = SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr) .unwrap(); - // A map from (target ports of edges from the Input node of `replacement`) to (target ports of - // edges from nodes not in `removal` to nodes in `removal`). + // A map from (target ports of edges from the Input node of `replacement`) to + // (target ports of edges from nodes not in `removal` to nodes in + // `removal`). let nu_inp = [ ( (repl_output, IncomingPort::from(0)), @@ -836,8 +858,8 @@ pub(in crate::hugr::rewrite) mod test { ] .into_iter() .collect(); - // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to - // (input ports of the Output node of `replacement`). + // A map from (target ports of edges from nodes in `removal` to nodes not in + // `removal`) to (input ports of the Output node of `replacement`). let nu_out = [ ((output, IncomingPort::from(0)), IncomingPort::from(0)), ((output, IncomingPort::from(1)), IncomingPort::from(1)), @@ -857,8 +879,8 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(hugr.node_count(), 3); } - /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting the input - /// directly to the output. + /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting + /// the input directly to the output. /// /// https://github.com/CQCL/hugr/issues/1323 #[rstest] @@ -880,8 +902,9 @@ pub(in crate::hugr::rewrite) mod test { let subgraph = SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap(); - // A map from (target ports of edges from the Input node of `replacement`) to (target ports of - // edges from nodes not in `removal` to nodes in `removal`). + // A map from (target ports of edges from the Input node of `replacement`) to + // (target ports of edges from nodes not in `removal` to nodes in + // `removal`). 
let nu_inp = [ ( (repl_output, IncomingPort::from(0)), @@ -894,8 +917,8 @@ pub(in crate::hugr::rewrite) mod test { ] .into_iter() .collect(); - // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to - // (input ports of the Output node of `replacement`). + // A map from (target ports of edges from nodes in `removal` to nodes not in + // `removal`) to (input ports of the Output node of `replacement`). let nu_out = [ ((output, IncomingPort::from(0)), IncomingPort::from(0)), ((output, IncomingPort::from(1)), IncomingPort::from(1)), @@ -959,9 +982,9 @@ pub(in crate::hugr::rewrite) mod test { assert_eq!(h.node_count(), 6); } - use crate::hugr::rewrite::replace::Replacement; + use crate::hugr::patch::replace::Replacement; fn to_replace(h: &impl HugrView, s: SimpleReplacement) -> Replacement { - use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec}; + use crate::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec}; let mut replacement = s.replacement; let (in_, out) = replacement @@ -1018,10 +1041,10 @@ pub(in crate::hugr::rewrite) mod test { } fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) { - h.apply_rewrite(rw).unwrap(); + h.apply_patch(rw).unwrap(); } fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) { - h.apply_rewrite(to_replace(h, rw)).unwrap(); + h.apply_patch(to_replace(h, rw)).unwrap(); } } diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index 9502d9f6b..680d58a03 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -838,7 +838,7 @@ mod tests { use cool_asserts::assert_matches; use crate::builder::inout_sig; - use crate::hugr::Rewrite; + use crate::hugr::Patch; use crate::ops::Const; use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; use crate::std_extensions::logic::{self, LogicOp}; @@ -1011,7 +1011,7 @@ mod tests { assert_eq!(rep.subgraph().nodes().len(), 4); assert_eq!(hugr.node_count(), 8); // Module + Def + In + CX + Rz + Const + LoadConst + Out - hugr.apply_rewrite(rep).unwrap(); + hugr.apply_patch(rep).unwrap(); assert_eq!(hugr.node_count(), 4); // Module + Def + In + Out Ok(()) diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 3a3bd5e91..7e68e600a 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -4,6 +4,7 @@ use hugr_core::{ Hugr, Node, }; +use itertools::Itertools; use thiserror::Error; /// Replace all operations in a HUGR according to a mapping. 
@@ -69,9 +70,11 @@ pub fn lower_ops( .map(|(node, replacement)| { let subcirc = SiblingSubgraph::from_node(node, hugr); let rw = subcirc.create_simple_replacement(hugr, replacement)?; - let mut repls = hugr.apply_rewrite(rw)?; - debug_assert_eq!(repls.len(), 1); - Ok(repls.remove(0)) + let removed_nodes = hugr.apply_patch(rw)?.removed_nodes; + Ok(removed_nodes + .into_iter() + .exactly_one() + .expect("removed exactly one node")) }) .collect() } diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index a5de5eb57..5c76ba51d 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -7,8 +7,8 @@ use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::hugr::views::RootCheckable; use itertools::Itertools; -use hugr_core::hugr::rewrite::inline_dfg::InlineDFG; -use hugr_core::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; +use hugr_core::hugr::patch::inline_dfg::InlineDFG; +use hugr_core::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; use hugr_core::ops::handle::CfgID; use hugr_core::ops::{DataflowBlock, DataflowParent, Input, Output, DFG}; use hugr_core::{Hugr, HugrView, Node}; @@ -39,11 +39,11 @@ where continue; }; let (rep, merge_bb, dfgs) = mk_rep(cfg, n, succ); - let node_map = cfg.apply_rewrite(rep).unwrap(); + let node_map = cfg.apply_patch(rep).unwrap(); let merged_bb = *node_map.get(&merge_bb).unwrap(); for dfg_id in dfgs { let n_id = *node_map.get(&dfg_id).unwrap(); - cfg.apply_rewrite(InlineDFG(n_id.into())).unwrap(); + cfg.apply_patch(InlineDFG(n_id.into())).unwrap(); } worklist.push(merged_bb); } diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index b98d4fb23..6e9df7f1a 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -44,10 +44,10 @@ use std::hash::Hash; use itertools::Itertools; use thiserror::Error; -use hugr_core::hugr::rewrite::outline_cfg::OutlineCfg; +use hugr_core::hugr::patch::outline_cfg::OutlineCfg; use hugr_core::hugr::views::sibling::SiblingMut; use hugr_core::hugr::views::{HierarchyView, HugrView, RootCheckable, SiblingGraph}; -use hugr_core::hugr::{hugrmut::HugrMut, Rewrite}; +use hugr_core::hugr::{hugrmut::HugrMut, Patch}; use hugr_core::ops::handle::{BasicBlockID, CfgID}; use hugr_core::ops::OpTag; use hugr_core::ops::OpTrait; @@ -260,7 +260,7 @@ impl> CfgNester for IdentityCfgMap { assert!([entry_edge.0, entry_edge.1, exit_edge.0, exit_edge.1] .iter() .all(|n| self.h.get_parent(*n) == Some(self.h.root()))); - let (new_block, new_cfg) = OutlineCfg::new(blocks).apply(&mut self.h).unwrap(); + let [new_block, new_cfg] = OutlineCfg::new(blocks).apply(&mut self.h).unwrap(); debug_assert!([entry_edge.0, exit_edge.1] .iter() .all(|n| self.h.get_parent(*n) == Some(self.h.root()))); @@ -579,7 +579,7 @@ pub(crate) mod test { }; use hugr_core::extension::{prelude::usize_t, ExtensionSet}; - use hugr_core::hugr::rewrite::insert_identity::{IdentityInsertion, IdentityInsertionError}; + use hugr_core::hugr::patch::insert_identity::{IdentityInsertion, IdentityInsertionError}; use hugr_core::hugr::views::RootChecked; use hugr_core::ops::handle::{ConstID, NodeHandle}; use hugr_core::ops::Value; @@ -830,7 +830,7 @@ pub(crate) mod test { let rw = IdentityInsertion::new(final_node, final_node_input); - let apply_result = h.apply_rewrite(rw); + let apply_result = h.apply_patch(rw); assert_eq!( apply_result, Err(IdentityInsertionError::InvalidPortKind(Some( diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index d074bed0f..00af101dc 100644 --- 
a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -131,7 +131,7 @@ impl ComposablePass for UntuplePass { let rewrites_applied = rewrites.len(); // The rewrites are independent, so we can always apply them all. for rewrite in rewrites { - hugr.apply_rewrite(rewrite)?; + hugr.apply_patch(rewrite)?; } Ok(UntupleResult { rewrites_applied }) } diff --git a/hugr/src/hugr.rs b/hugr/src/hugr.rs index a66de8315..0bd8f64ff 100644 --- a/hugr/src/hugr.rs +++ b/hugr/src/hugr.rs @@ -2,7 +2,7 @@ // Exports everything except the `internal` module. pub use hugr_core::hugr::{ - hugrmut, rewrite, serialize, validate, views, Hugr, HugrError, HugrView, IdentList, - InvalidIdentifier, LoadHugrError, NodeMetadata, NodeMetadataMap, OpType, Rewrite, + hugrmut, patch, serialize, validate, views, Hugr, HugrError, HugrView, IdentList, + InvalidIdentifier, LoadHugrError, NodeMetadata, NodeMetadataMap, OpType, Patch, SimpleReplacement, SimpleReplacementError, ValidationError, DEFAULT_OPTYPE, }; From ca302d62e870a1af21bf0c0e523098efd79a4bba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 29 Apr 2025 15:10:37 +0100 Subject: [PATCH 13/21] feat!: Bump MSRV to 1.85 (#2136) Bumps the minimum supported rust version to 1.85, ~and migrates to the 2024 edition~. Most of the changes here are automated. The diff is quite noisy due to some changes in formatting, and me using this opportunity to auto--fix some optional clippy lints. It may be easier to check the changes per-commit. EDIT: Edition change has been left for a separate PR, as it's quite noisy. BREAKING CHANGE: Bumped MSRV to 1.85 --- .github/workflows/ci-rs.yml | 10 ++++------ .pre-commit-config.yaml | 2 +- Cargo.toml | 2 +- DEVELOPMENT.md | 2 +- hugr-cli/README.md | 2 +- hugr-core/README.md | 2 +- hugr-llvm/README.md | 2 +- hugr-model/README.md | 2 +- hugr-passes/README.md | 2 +- hugr-passes/src/replace_types/linearize.rs | 3 +-- hugr-passes/src/untuple.rs | 2 +- hugr/README.md | 2 +- justfile | 2 +- 13 files changed, 16 insertions(+), 19 deletions(-) diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index 56c093e8b..4fe5d244f 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -233,7 +233,7 @@ jobs: id: toolchain uses: dtolnay/rust-toolchain@master with: - toolchain: "1.75" + toolchain: "1.85" - name: Install nightly toolchain uses: dtolnay/rust-toolchain@master with: @@ -252,12 +252,10 @@ jobs: cargo binstall cargo-minimal-versions --force - name: Pin transitive dependencies not compatible with our MSRV # Add new dependencies as needed if the check fails due to - # "package `XXX` cannot be built because it requires rustc YYY or newer, while the currently active rustc version is 1.75.0" + # "package `XXX` cannot be built because it requires rustc YYY or newer, while the currently active rustc version is 1.85.0" run: | - rm Cargo.lock - cargo add -p hugr half@2.4.1 - cargo add -p hugr litemap@0.7.4 - cargo add -p hugr zerofrom@0.1.5 + # rm Cargo.lock + # cargo add -p hugr half@2.4.1 - name: Build with no features run: cargo minimal-versions --direct test --verbose --no-default-features --no-run - name: Tests with no features diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 569ccfec1..4fe582d93 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -79,7 +79,7 @@ repos: # built into a binary build (without using `maturin`) # # This feature list should be kept in sync with the `hugr-py/pyproject.toml` - 
entry: cargo test --workspace --exclude 'hugr-py' --features 'hugr/extension_inference hugr/declarative hugr/model_unstable hugr/llvm hugr/llvm-test hugr/zstd' + entry: cargo test --workspace --exclude 'hugr-py' --features 'hugr/extension_inference hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' language: system files: \.rs$ pass_filenames: false diff --git a/Cargo.toml b/Cargo.toml index 3031df1e7..c72326a33 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ members = [ default-members = ["hugr", "hugr-core", "hugr-passes", "hugr-cli", "hugr-model"] [workspace.package] -rust-version = "1.75" +rust-version = "1.85" edition = "2021" homepage = "https://github.com/CQCL/hugr" repository = "https://github.com/CQCL/hugr" diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 4659f96c7..6d9465140 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -29,7 +29,7 @@ shell by setting up [direnv](https://devenv.sh/automatic-shell-activation/). To setup the environment manually you will need: - Just: https://just.systems/ -- Rust `>=1.75`: https://www.rust-lang.org/tools/install +- Rust `>=1.85`: https://www.rust-lang.org/tools/install - uv `>=0.3`: docs.astral.sh/uv/getting-started/installation - Optional: capnproto `>=1.0`: https://capnproto.org/install.html Required when modifying the `hugr-model` serialization schema. diff --git a/hugr-cli/README.md b/hugr-cli/README.md index 277628d2b..dba9900e2 100644 --- a/hugr-cli/README.md +++ b/hugr-cli/README.md @@ -64,7 +64,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr-cli/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr-cli [crates]: https://img.shields.io/crates/v/hugr-cli [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE diff --git a/hugr-core/README.md b/hugr-core/README.md index 379041a5b..765d4577b 100644 --- a/hugr-core/README.md +++ b/hugr-core/README.md @@ -36,7 +36,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr-core/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main -[msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg +[msrv]: https://img.shields.io/crates/msrv/hugr-core [crates]: https://img.shields.io/crates/v/hugr-core [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE diff --git a/hugr-llvm/README.md b/hugr-llvm/README.md index 5fd2d3239..6d81cd35d 100644 --- a/hugr-llvm/README.md +++ b/hugr-llvm/README.md @@ -32,7 +32,7 @@ See [DEVELOPMENT](DEVELOPMENT.md) for instructions on setting up the development This project is licensed under Apache License, Version 2.0 ([LICENCE](LICENCE) or ). 
[build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr-llvm [hugr]: https://lib.rs/crates/hugr [inkwell]: https://thedan64.github.io/inkwell/inkwell/index.html [llvm-sys]: https://crates.io/crates/llvm-sys diff --git a/hugr-model/README.md b/hugr-model/README.md index 0ea6fdf8f..be93253eb 100644 --- a/hugr-model/README.md +++ b/hugr-model/README.md @@ -30,7 +30,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr-model/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr-model [crates]: https://img.shields.io/crates/v/hugr-core [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE diff --git a/hugr-passes/README.md b/hugr-passes/README.md index b9552fe75..b441ed5e7 100644 --- a/hugr-passes/README.md +++ b/hugr-passes/README.md @@ -51,7 +51,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr-passes/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr-passes [crates]: https://img.shields.io/crates/v/hugr-passes [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 321ec194f..2788a2379 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -1,4 +1,3 @@ -use std::iter::repeat; use std::{collections::HashMap, sync::Arc}; use hugr_core::builder::{ @@ -273,7 +272,7 @@ impl Linearizer for DelegatingLinearizer { let mut elems_for_copy = vec![vec![]; num_outports]; for (inp, ty) in case_b.input_wires().zip_eq(variant.iter()) { let inp_copies = if ty.copyable() { - repeat(inp).take(num_outports).collect::>() + std::iter::repeat_n(inp, num_outports).collect::>() } else { self.copy_discard_op(ty, num_outports)? .add(&mut case_b, [inp]) diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index 00af101dc..b2782e8d9 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -247,7 +247,7 @@ fn remove_pack_unpack<'h, T: HugrView>( .add_dataflow_op(op, replacement.input_wires()) .unwrap() .outputs_arr(); - outputs.extend(std::iter::repeat(tuple).take(num_other_outputs)) + outputs.extend(std::iter::repeat_n(tuple, num_other_outputs)) } // These should never fail, as we are defining the replacement ourselves. 
diff --git a/hugr/README.md b/hugr/README.md index 6ecfc405b..b54d4f62d 100644 --- a/hugr/README.md +++ b/hugr/README.md @@ -51,7 +51,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr [crates]: https://img.shields.io/crates/v/hugr [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE diff --git a/justfile b/justfile index 61173375b..7b8075f94 100644 --- a/justfile +++ b/justfile @@ -23,7 +23,7 @@ test-rust: HUGR_TEST_SCHEMA=1 cargo test \ --workspace \ --exclude 'hugr-py' \ - --features 'hugr/extension_inference hugr/declarative hugr/model_unstable hugr/llvm hugr/llvm-test hugr/zstd' + --features 'hugr/extension_inference hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' # Run all python tests. test-python: uv run maturin develop --uv
From 320b81e93049ca8481cf8114fe8a15d77376db9a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com>
Date: Tue, 29 Apr 2025 15:32:27 +0100
Subject: [PATCH 14/21] feat!: Cleanup core trait definitions (#2126)

Moves some methods around in `HugrInternals` / `HugrView` that we'll require for #1926 and #2029.

Notable changes:
- Adds a `hierarchy` method to `HugrInternals` that replaces most calls to `base_hugr`.
- `HugrInternals::base_hugr` is now deprecated. It's only used by `Sibling/DescendantGraph`, `validate`, and one stray call in `DeadCodeElimPass`.
- Adds a `HugrInternals::region_portgraph` method that returns a `FlatRegion` portgraph wrapper, and a `HugrView::descendants` call. These let us replace most uses of `SiblingGraph` and `DescendantGraph`.
- Renamed `HugrInternals::{get_pg_index,get_node}` to `to_portgraph_node` and `from_portgraph_node`. This requires some new changes in `portgraph`. I'll make a minor release and update it here before merging.

We should be able to remove `base_hugr` after #2029. The deprecation warning here is only temporary.

BREAKING CHANGE: Modified multiple core `HugrView` and `HugrInternals` trait methods. See #2126.
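As a rough illustration of the renames for downstream code (a minimal sketch, not taken from this PR: the import path and the exact signatures of `to_portgraph_node`, `from_portgraph_node`, and `descendants` are assumed from the names above and may differ):

```rust
use hugr_core::hugr::internal::HugrInternals; // assumed path for the internals trait
use hugr_core::{Hugr, HugrView};

fn portgraph_roundtrip(hugr: &Hugr) {
    // Previously `get_pg_index`: map a HUGR node to its portgraph index.
    let pg_index = hugr.to_portgraph_node(hugr.root());
    // Previously `get_node`: map a portgraph index back to a HUGR node.
    let node = hugr.from_portgraph_node(pg_index);
    assert_eq!(node, hugr.root());

    // `HugrView::descendants` iterates a region directly, replacing most
    // simple-traversal uses of the `DescendantGraph` view.
    for n in hugr.descendants(hugr.root()) {
        let _op = hugr.get_optype(n);
    }
}
```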
--- Cargo.lock | 96 ++--- Cargo.toml | 6 +- hugr-core/src/builder/dataflow.rs | 6 +- hugr-core/src/core.rs | 2 +- hugr-core/src/export.rs | 15 +- hugr-core/src/hugr.rs | 52 ++- hugr-core/src/hugr/hugrmut.rs | 218 ++++++------ hugr-core/src/hugr/internal.rs | 114 +++--- hugr-core/src/hugr/patch.rs | 4 +- hugr-core/src/hugr/patch/consts.rs | 4 +- hugr-core/src/hugr/patch/insert_identity.rs | 4 +- hugr-core/src/hugr/patch/outline_cfg.rs | 5 +- hugr-core/src/hugr/patch/replace.rs | 25 +- hugr-core/src/hugr/patch/simple_replace.rs | 12 +- hugr-core/src/hugr/rewrite.rs | 4 +- hugr-core/src/hugr/serialize.rs | 10 +- hugr-core/src/hugr/validate.rs | 44 ++- hugr-core/src/hugr/validate/test.rs | 16 +- hugr-core/src/hugr/views.rs | 329 +++++++++++------- hugr-core/src/hugr/views/descendants.rs | 97 ++++-- hugr-core/src/hugr/views/impls.rs | 59 ++-- hugr-core/src/hugr/views/petgraph.rs | 12 +- hugr-core/src/hugr/views/render.rs | 10 +- hugr-core/src/hugr/views/sibling.rs | 159 ++++++--- hugr-core/src/hugr/views/sibling_subgraph.rs | 8 +- hugr-core/src/ops/constant.rs | 2 +- hugr-llvm/src/emit/ops.rs | 54 ++- ...hugr_call_indirect@pre-mem2reg@llvm14.snap | 8 +- hugr-llvm/src/utils/fat.rs | 6 +- hugr-passes/src/const_fold/test.rs | 8 +- hugr-passes/src/dead_code.rs | 1 + hugr-passes/src/force_order.rs | 60 ++-- hugr-passes/src/lower.rs | 2 +- hugr-passes/src/merge_bbs.rs | 2 +- hugr-passes/src/replace_types.rs | 4 +- 35 files changed, 800 insertions(+), 658 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ea3053aa9..6b18c3101 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,7 +24,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom 0.2.15", + "getrandom 0.2.16", "once_cell", "serde", "version_check", @@ -301,9 +301,9 @@ checksum = "38c99613cb3cd7429889a08dfcf651721ca971c86afa30798461f8eee994de47" [[package]] name = "bstr" -version = "1.11.3" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" +checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" dependencies = [ "memchr", "regex-automata", @@ -351,9 +351,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.18" +version = "1.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525046617d8376e3db1deffb079e91cef90a89fc3ca5c185bbf8c9ecdd15cd5c" +checksum = "04da6a0d40b948dfc4fa8f5bbf402b0fc1a64a28dbf7d12ffd683550f2c1b63a" dependencies = [ "jobserver", "libc", @@ -936,9 +936,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "libc", @@ -1483,14 +1483,12 @@ dependencies = [ [[package]] name = "insta" -version = "1.42.2" +version = "1.43.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50259abbaa67d11d2bcafc7ba1d094ed7a0c70e3ce893f0d0997f73558cb3084" +checksum = "ab2d11b2f17a45095b8c3603928ba29d7d918d7129d0d0641a36ba73cf07daa6" dependencies = [ "console", - "linked-hash-map", "once_cell", - "pin-project", "serde", "similar", ] @@ -1622,21 +1620,15 @@ checksum = 
"bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.171" +version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" - -[[package]] -name = "linked-hash-map" -version = "0.5.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" [[package]] name = "linux-raw-sys" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" [[package]] name = "litemap" @@ -1696,9 +1688,9 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "miniz_oxide" -version = "0.8.7" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff70ce3e48ae43fa075863cef62e8b43b71a4f2382229920e0df362592919430" +checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", ] @@ -1937,26 +1929,6 @@ dependencies = [ "serde", ] -[[package]] -name = "pin-project" -version = "1.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "pin-project-lite" version = "0.2.16" @@ -2011,9 +1983,9 @@ checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" [[package]] name = "portgraph" -version = "0.14.0" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a9ea69cfb011d5f17af28813ec37a0a9668a063090e14ad75dc5fc07ba01b47" +checksum = "5fdce52d51ec359351ff3c209fafb6f133562abf52d951ce5821c0184798d979" dependencies = [ "bitvec", "delegate", @@ -2029,7 +2001,7 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ - "zerocopy 0.8.24", + "zerocopy 0.8.25", ] [[package]] @@ -2094,9 +2066,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] @@ -2259,7 +2231,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", ] [[package]] @@ -2694,9 +2666,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.100" +version = "2.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = 
"8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" dependencies = [ "proc-macro2", "quote", @@ -2830,15 +2802,15 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" [[package]] name = "toml_edit" -version = "0.22.24" +version = "0.22.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" +checksum = "10558ed0bd2a1562e630926a2d1f0b98c827da99fabd3fe20920a59642504485" dependencies = [ "indexmap", "toml_datetime", @@ -3418,9 +3390,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" [[package]] name = "winnow" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63d3fcd9bba44b03821e7d699eeee959f3126dcc4aa8e4ae18ec617c2a5cea10" +checksum = "6cb8234a863ea0e8cd7284fcdd4f145233eb00fee02bbdd9861aec44e6477bc5" dependencies = [ "memchr", ] @@ -3496,11 +3468,11 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.24" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879" +checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" dependencies = [ - "zerocopy-derive 0.8.24", + "zerocopy-derive 0.8.25", ] [[package]] @@ -3516,9 +3488,9 @@ dependencies = [ [[package]] name = "zerocopy-derive" -version = "0.8.24" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be" +checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index c72326a33..97dad7dea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,7 +58,7 @@ regex = "1.10.6" regex-syntax = "0.8.3" rstest = "0.24.0" semver = "1.0.26" -serde = "1.0.195" +serde = "1.0.219" serde_json = "1.0.140" serde_yaml = "0.9.34" smol_str = "0.3.1" @@ -87,8 +87,8 @@ zstd = "0.13.2" # These public dependencies usually require breaking changes downstream, so we # try to be as permissive as possible. 
pyo3 = ">= 0.23.4, < 0.25" -portgraph = { version = ">= 0.13.3, < 0.15" } -petgraph = { version = ">= 0.7.1, < 0.9", default-features = false } +portgraph = { version = "0.14.1" } +petgraph = { version = ">= 0.8.1, < 0.9", default-features = false } [profile.dev.package] insta.opt-level = 3 diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index 64c5f5c84..b84f3a05a 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -506,8 +506,8 @@ pub(crate) mod test { #[rstest] fn dfg_hugr(simple_dfg_hugr: Hugr) { - assert_eq!(simple_dfg_hugr.node_count(), 3); - assert_matches!(simple_dfg_hugr.root_type().tag(), OpTag::Dfg); + assert_eq!(simple_dfg_hugr.num_nodes(), 3); + assert_matches!(simple_dfg_hugr.root_optype().tag(), OpTag::Dfg); } #[test] @@ -533,7 +533,7 @@ pub(crate) mod test { }; let hugr = module_builder.finish_hugr()?; - assert_eq!(hugr.node_count(), 7); + assert_eq!(hugr.num_nodes(), 7); assert_eq!(hugr.get_metadata(hugr.root(), "x"), None); assert_eq!(hugr.get_metadata(dfg_node, "x").cloned(), Some(json!(42))); diff --git a/hugr-core/src/core.rs b/hugr-core/src/core.rs index 03e009bef..cc9da77ab 100644 --- a/hugr-core/src/core.rs +++ b/hugr-core/src/core.rs @@ -83,7 +83,7 @@ pub struct Wire(N, OutgoingPort); impl Node { /// Returns the node as a portgraph `NodeIndex`. #[inline] - pub(crate) fn pg_index(self) -> portgraph::NodeIndex { + pub(crate) fn into_portgraph(self) -> portgraph::NodeIndex { self.index } } diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 09ccf944c..078fe3c27 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -1,4 +1,5 @@ //! Exporting HUGR graphs to their `hugr-model` representation. +use crate::hugr::internal::HugrInternals; use crate::{ extension::{ExtensionId, OpDef, SignatureFunc}, hugr::IdentList, @@ -94,7 +95,7 @@ struct Context<'a> { impl<'a> Context<'a> { pub fn new(hugr: &'a Hugr, bump: &'a Bump) -> Self { let mut module = table::Module::default(); - module.nodes.reserve(hugr.node_count()); + module.nodes.reserve(hugr.num_nodes()); let links = Links::new(hugr); Self { @@ -999,7 +1000,7 @@ impl<'a> Context<'a> { let outer_hugr = std::mem::replace(&mut self.hugr, hugr); let outer_node_to_id = std::mem::take(&mut self.node_to_id); - let region = match hugr.root_type() { + let region = match hugr.root_optype() { OpType::DFG(_) => self.export_dfg(hugr.root(), model::ScopeClosure::Closed), _ => panic!("Value::Function root must be a DFG"), }; @@ -1031,7 +1032,7 @@ impl<'a> Context<'a> { } pub fn export_node_metadata(&mut self, node: Node) -> &'a [table::TermId] { - let metadata_map = self.hugr.get_node_metadata(node); + let metadata_map = self.hugr.node_metadata_map(node); let has_order_edges = { fn is_relevant_node(hugr: &Hugr, node: Node) -> bool { @@ -1049,13 +1050,11 @@ impl<'a> Context<'a> { .any(|(other, _)| is_relevant_node(self.hugr, other)) }; - let meta_capacity = metadata_map.map_or(0, |map| map.len()) + has_order_edges as usize; + let meta_capacity = metadata_map.len() + has_order_edges as usize; let mut meta = BumpVec::with_capacity_in(meta_capacity, self.bump); - if let Some(metadata_map) = metadata_map { - for (name, value) in metadata_map { - meta.push(self.export_json_meta(name, value)); - } + for (name, value) in metadata_map { + meta.push(self.export_json_meta(name, value)); } if has_order_edges { diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 7a74b4070..93250b8e3 100644 --- a/hugr-core/src/hugr.rs +++ 
b/hugr-core/src/hugr.rs @@ -89,6 +89,25 @@ impl Hugr { Self::with_capacity(root_node.into(), 0, 0) } + /// Create a new Hugr, with a single root node and preallocated capacity. + pub fn with_capacity(root_node: OpType, nodes: usize, ports: usize) -> Self { + let mut graph = MultiPortGraph::with_capacity(nodes, ports); + let hierarchy = Hierarchy::new(); + let mut op_types = UnmanagedDenseMap::with_capacity(nodes); + let root = graph.add_node(root_node.input_count(), root_node.output_count()); + let extensions = root_node.used_extensions(); + op_types[root] = root_node; + + Self { + graph, + hierarchy, + root, + op_types, + metadata: UnmanagedDenseMap::with_capacity(nodes), + extensions: extensions.unwrap_or_default(), + } + } + /// Load a Hugr from a json reader. /// /// Validates the Hugr against the provided extension registry, ensuring all @@ -154,7 +173,7 @@ impl Hugr { .map(|ch| Ok((ch, infer(h, ch, remove)?))) .collect::, _>>()?; - let Some(es) = delta_mut(h.op_types.get_mut(node.pg_index())) else { + let Some(es) = delta_mut(h.op_types.get_mut(node.into_portgraph())) else { return Ok(h.get_optype(node).extension_delta()); }; if es.contains(&TO_BE_INFERRED) { @@ -260,31 +279,6 @@ impl Hugr { /// Internal API for HUGRs, not intended for use by users. impl Hugr { - /// Create a new Hugr, with a single root node and preallocated capacity. - pub(crate) fn with_capacity(root_node: OpType, nodes: usize, ports: usize) -> Self { - let mut graph = MultiPortGraph::with_capacity(nodes, ports); - let hierarchy = Hierarchy::new(); - let mut op_types = UnmanagedDenseMap::with_capacity(nodes); - let root = graph.add_node(root_node.input_count(), root_node.output_count()); - let extensions = root_node.used_extensions(); - op_types[root] = root_node; - - Self { - graph, - hierarchy, - root, - op_types, - metadata: UnmanagedDenseMap::with_capacity(nodes), - extensions: extensions.unwrap_or_default(), - } - } - - /// Set the root node of the hugr. - pub(crate) fn set_root(&mut self, root: Node) { - self.hierarchy.detach(self.root); - self.root = root.pg_index(); - } - /// Add a node to the graph. pub(crate) fn add_node(&mut self, nodetype: OpType) -> Node { let node = self @@ -322,7 +316,7 @@ impl Hugr { /// preserve the indices. pub fn canonicalize_nodes(&mut self, mut rekey: impl FnMut(Node, Node)) { // Generate the ordered list of nodes - let mut ordered = Vec::with_capacity(self.node_count()); + let mut ordered = Vec::with_capacity(self.num_nodes()); let root = self.root(); ordered.extend(self.as_mut().canonical_order(root)); @@ -339,8 +333,8 @@ impl Hugr { let target: Node = portgraph::NodeIndex::new(position).into(); if target != source { - let pg_target = target.pg_index(); - let pg_source = source.pg_index(); + let pg_target = target.into_portgraph(); + let pg_source = source.into_portgraph(); self.graph.swap_nodes(pg_target, pg_source); self.op_types.swap(pg_target, pg_source); self.hierarchy.swap_nodes(pg_target, pg_source); diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index c58ccbdbc..6353820f4 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -18,6 +18,7 @@ use crate::types::Substitution; use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; use super::internal::HugrMutInternals; +use super::views::{panic_invalid_node, panic_invalid_non_root, panic_invalid_port}; /// Functions for low-level building of a HUGR. 
pub trait HugrMut: HugrMutInternals { @@ -26,12 +27,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph. - fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef) -> &mut NodeMetadata { - panic_invalid_node(self, node); - self.node_metadata_map_mut(node) - .entry(key.as_ref()) - .or_insert(serde_json::Value::Null) - } + fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef) -> &mut NodeMetadata; /// Sets a metadata value associated with a node. /// @@ -43,21 +39,14 @@ pub trait HugrMut: HugrMutInternals { node: Self::Node, key: impl AsRef, metadata: impl Into, - ) { - let entry = self.get_metadata_mut(node, key); - *entry = metadata.into(); - } + ); /// Remove a metadata entry associated with a node. /// /// # Panics /// /// If the node is not in the graph. - fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef) { - panic_invalid_node(self, node); - let node_meta = self.node_metadata_map_mut(node); - node_meta.remove(key.as_ref()); - } + fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef); /// Add a node to the graph with a parent in the hierarchy. /// @@ -209,9 +198,7 @@ pub trait HugrMut: HugrMutInternals { /// These can be queried using [`HugrView::extensions`]. /// /// See [`ExtensionRegistry::register_updated`] for more information. - fn use_extension(&mut self, extension: impl Into>) { - self.extensions_mut().register_updated(extension); - } + fn use_extension(&mut self, extension: impl Into>); /// Extend the set of extensions used by the hugr with the extensions in the /// registry. @@ -224,10 +211,7 @@ pub trait HugrMut: HugrMutInternals { /// See [`ExtensionRegistry::register_updated`] for more information. fn use_extensions(&mut self, registry: impl IntoIterator) where - ExtensionRegistry: Extend, - { - self.extensions_mut().extend(registry); - } + ExtensionRegistry: Extend; } /// Records the result of inserting a Hugr or view @@ -262,10 +246,33 @@ fn translate_indices( /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. 
impl HugrMut for Hugr { + fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef) -> &mut NodeMetadata { + panic_invalid_node(self, node); + self.node_metadata_map_mut(node) + .entry(key.as_ref()) + .or_insert(serde_json::Value::Null) + } + + fn set_metadata( + &mut self, + node: Self::Node, + key: impl AsRef, + metadata: impl Into, + ) { + let entry = self.get_metadata_mut(node, key); + *entry = metadata.into(); + } + + fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef) { + panic_invalid_node(self, node); + let node_meta = self.node_metadata_map_mut(node); + node_meta.remove(key.as_ref()); + } + fn add_node_with_parent(&mut self, parent: Node, node: impl Into) -> Node { let node = self.as_mut().add_node(node.into()); self.hierarchy - .push_child(node.pg_index(), parent.pg_index()) + .push_child(node.into_portgraph(), parent.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node } @@ -273,7 +280,7 @@ impl HugrMut for Hugr { fn add_node_before(&mut self, sibling: Node, nodetype: impl Into) -> Node { let node = self.as_mut().add_node(nodetype.into()); self.hierarchy - .insert_before(node.pg_index(), sibling.pg_index()) + .insert_before(node.into_portgraph(), sibling.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node } @@ -281,16 +288,16 @@ impl HugrMut for Hugr { fn add_node_after(&mut self, sibling: Node, op: impl Into) -> Node { let node = self.as_mut().add_node(op.into()); self.hierarchy - .insert_after(node.pg_index(), sibling.pg_index()) + .insert_after(node.into_portgraph(), sibling.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node } fn remove_node(&mut self, node: Node) -> OpType { panic_invalid_non_root(self, node); - self.hierarchy.remove(node.pg_index()); - self.graph.remove_node(node.pg_index()); - self.op_types.take(node.pg_index()) + self.hierarchy.remove(node.into_portgraph()); + self.graph.remove_node(node.into_portgraph()); + self.op_types.take(node.into_portgraph()) } fn remove_subtree(&mut self, node: Node) { @@ -316,9 +323,9 @@ impl HugrMut for Hugr { panic_invalid_port(self, dst, dst_port); self.graph .link_nodes( - src.pg_index(), + src.into_portgraph(), src_port.index(), - dst.pg_index(), + dst.into_portgraph(), dst_port.index(), ) .expect("The ports should exist at this point."); @@ -330,7 +337,7 @@ impl HugrMut for Hugr { panic_invalid_port(self, node, port); let port = self .graph - .port_index(node.pg_index(), offset) + .port_index(node.into_portgraph(), offset) .expect("The port should exist at this point."); self.graph.unlink_port(port); } @@ -364,13 +371,17 @@ impl HugrMut for Hugr { self.metadata.set(new_node, meta); } debug_assert_eq!( - Some(&new_root.pg_index()), - node_map.get(&other.root().pg_index()) + Some(&new_root.into_portgraph()), + node_map.get(&other.root().into_portgraph()) ); InsertionResult { new_root, - node_map: translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map) - .collect(), + node_map: translate_indices( + |n| other.from_portgraph_node(n), + |n| self.from_portgraph_node(n), + node_map, + ) + .collect(), } } @@ -384,19 +395,26 @@ impl HugrMut for Hugr { // // No need to compute each node's extensions here, as we merge `other.extensions` directly. 
for (&node, &new_node) in node_map.iter() { - let nodetype = other.get_optype(other.get_node(node)); + let node = other.from_portgraph_node(node); + let nodetype = other.get_optype(node); self.op_types.set(new_node, nodetype.clone()); - let meta = other.base_hugr().metadata.get(node); - self.metadata.set(new_node, meta.clone()); + let meta = other.node_metadata_map(node); + if !meta.is_empty() { + self.metadata.set(new_node, Some(meta.clone())); + } } debug_assert_eq!( - Some(&new_root.pg_index()), - node_map.get(&other.get_pg_index(other.root())) + Some(&new_root.into_portgraph()), + node_map.get(&other.to_portgraph_node(other.root())) ); InsertionResult { new_root, - node_map: translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map) - .collect(), + node_map: translate_indices( + |n| other.from_portgraph_node(n), + |n| self.from_portgraph_node(n), + node_map, + ) + .collect(), } } @@ -410,7 +428,7 @@ impl HugrMut for Hugr { let context: HashSet = subgraph .nodes() .iter() - .map(|&n| other.get_pg_index(n)) + .map(|&n| other.to_portgraph_node(n)) .collect(); let portgraph: NodeFiltered<_, NodeFilter>, _> = NodeFiltered::new_node_filtered( @@ -421,16 +439,24 @@ impl HugrMut for Hugr { let node_map = insert_subgraph_internal(self, root, other, &portgraph); // Update the optypes and metadata, copying them from the other graph. for (&node, &new_node) in node_map.iter() { - let nodetype = other.get_optype(other.get_node(node)); + let node = other.from_portgraph_node(node); + let nodetype = other.get_optype(node); self.op_types.set(new_node, nodetype.clone()); - let meta = other.base_hugr().metadata.get(node); - self.metadata.set(new_node, meta.clone()); + let meta = other.node_metadata_map(node); + if !meta.is_empty() { + self.metadata.set(new_node, Some(meta.clone())); + } // Add the required extensions to the registry. 
if let Ok(exts) = nodetype.used_extensions() { self.use_extensions(exts); } } - translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map).collect() + translate_indices( + |n| other.from_portgraph_node(n), + |n| self.from_portgraph_node(n), + node_map, + ) + .collect() } fn copy_descendants( @@ -439,15 +465,19 @@ impl HugrMut for Hugr { new_parent: Self::Node, subst: Option, ) -> BTreeMap { - let mut descendants = self.base_hugr().hierarchy.descendants(root.pg_index()); + let mut descendants = self.hierarchy.descendants(root.into_portgraph()); let root2 = descendants.next(); - debug_assert_eq!(root2, Some(root.pg_index())); + debug_assert_eq!(root2, Some(root.into_portgraph())); let nodes = Vec::from_iter(descendants); let node_map = portgraph::view::Subgraph::with_nodes(&mut self.graph, nodes) .copy_in_parent() .expect("Is a MultiPortGraph"); - let node_map = translate_indices(|n| self.get_node(n), |n| self.get_node(n), node_map) - .collect::>(); + let node_map = translate_indices( + |n| self.from_portgraph_node(n), + |n| self.from_portgraph_node(n), + node_map, + ) + .collect::>(); for node in self.children(root).collect::>() { self.set_parent(*node_map.get(&node).unwrap(), new_parent); @@ -462,12 +492,25 @@ impl HugrMut for Hugr { (None, op) => op.clone(), (Some(subst), op) => op.substitute(subst), }; - self.op_types.set(new_node.pg_index(), new_optype); - let meta = self.base_hugr().metadata.get(node.pg_index()).clone(); - self.metadata.set(new_node.pg_index(), meta); + self.op_types.set(new_node.into_portgraph(), new_optype); + let meta = self.metadata.get(node.into_portgraph()).clone(); + self.metadata.set(new_node.into_portgraph(), meta); } node_map } + + #[inline] + fn use_extension(&mut self, extension: impl Into>) { + self.extensions_mut().register_updated(extension); + } + + #[inline] + fn use_extensions(&mut self, registry: impl IntoIterator) + where + ExtensionRegistry: Extend, + { + self.extensions_mut().extend(registry); + } } /// Internal implementation of `insert_hugr` and `insert_view` methods for @@ -487,18 +530,20 @@ fn insert_hugr_internal( .graph .insert_graph(&other.portgraph()) .unwrap_or_else(|e| panic!("Internal error while inserting a hugr into another: {e}")); - let other_root = node_map[&other.get_pg_index(other.root())]; + let other_root = node_map[&other.to_portgraph_node(other.root())]; // Update hierarchy and optypes hugr.hierarchy - .push_child(other_root, root.pg_index()) + .push_child(other_root, root.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); for (&node, &new_node) in node_map.iter() { - other.children(other.get_node(node)).for_each(|child| { - hugr.hierarchy - .push_child(node_map[&other.get_pg_index(child)], new_node) - .expect("Inserting a newly-created node into the hierarchy should never fail."); - }); + other + .children(other.from_portgraph_node(node)) + .for_each(|child| { + hugr.hierarchy + .push_child(node_map[&other.to_portgraph_node(child)], new_node) + .expect("Inserting a newly-created node into the hierarchy should never fail."); + }); } // Merge the extension sets. @@ -534,9 +579,9 @@ fn insert_subgraph_internal( // update the hierarchy with their new id. 
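The metadata accessors that become required `HugrMut` methods above (with their former default bodies now living on the concrete `Hugr` impl) are used as below. This is a sketch using the crate-internal paths seen elsewhere in this patch; `node` is assumed to be a valid node of `hugr`:

```rust
use crate::hugr::{HugrMut, NodeMetadata};
use crate::{Hugr, Node};

fn annotate(hugr: &mut Hugr, node: Node) {
    // Insert or overwrite a metadata entry (any `Into<NodeMetadata>` value works).
    hugr.set_metadata(node, "doc", "lowered from the frontend");
    // Mutate an entry in place; a `Null` entry is created if the key is missing.
    *hugr.get_metadata_mut(node, "pass_count") = NodeMetadata::from(1);
    // Remove an entry again.
    hugr.remove_metadata(node, "doc");
}
```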
for (&node, &new_node) in node_map.iter() { let new_parent = other - .get_parent(other.get_node(node)) - .and_then(|parent| node_map.get(&other.get_pg_index(parent)).copied()) - .unwrap_or(root.pg_index()); + .get_parent(other.from_portgraph_node(node)) + .and_then(|parent| node_map.get(&other.to_portgraph_node(parent)).copied()) + .unwrap_or(root.into_portgraph()); hugr.hierarchy .push_child(new_node, new_parent) .expect("Inserting a newly-created node into the hierarchy should never fail."); @@ -545,45 +590,6 @@ fn insert_subgraph_internal( node_map } -/// Panic if [`HugrView::valid_node`] fails. -#[track_caller] -pub(super) fn panic_invalid_node(hugr: &H, node: H::Node) { - // TODO: When stacking hugr wrappers, this gets called for every layer. - // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. - if !hugr.valid_node(node) { - panic!("Received an invalid node {node} while mutating a HUGR.",); - } -} - -/// Panic if [`HugrView::valid_non_root`] fails. -#[track_caller] -pub(super) fn panic_invalid_non_root(hugr: &H, node: H::Node) { - // TODO: When stacking hugr wrappers, this gets called for every layer. - // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. - if !hugr.valid_non_root(node) { - panic!("Received an invalid non-root node {node} while mutating a HUGR.",); - } -} - -/// Panic if [`HugrView::valid_node`] fails. -#[track_caller] -pub(super) fn panic_invalid_port( - hugr: &H, - node: Node, - port: impl Into, -) { - let port = port.into(); - // TODO: When stacking hugr wrappers, this gets called for every layer. - // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. - if hugr - .portgraph() - .port_index(node.pg_index(), port.pg_offset()) - .is_none() - { - panic!("Received an invalid port {port} for node {node} while mutating a HUGR"); - } -} - #[cfg(test)] mod test { use crate::extension::PRELUDE; @@ -667,14 +673,14 @@ mod test { fd }); hugr.validate().unwrap(); - assert_eq!(hugr.node_count(), 7); + assert_eq!(hugr.num_nodes(), 7); hugr.remove_subtree(foo); hugr.validate().unwrap(); - assert_eq!(hugr.node_count(), 4); + assert_eq!(hugr.num_nodes(), 4); hugr.remove_subtree(bar); hugr.validate().unwrap(); - assert_eq!(hugr.node_count(), 1); + assert_eq!(hugr.num_nodes(), 1); } } diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 58ce066c0..f69d2ad39 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -1,6 +1,5 @@ //! Internal traits, not exposed in the public `hugr` API. -use std::borrow::Cow; use std::ops::Range; use std::sync::OnceLock; @@ -8,11 +7,12 @@ use itertools::Itertools; use portgraph::{LinkMut, LinkView, MultiPortGraph, PortMut, PortOffset, PortView}; use crate::extension::ExtensionRegistry; -use crate::ops::handle::NodeHandle; use crate::{Direction, Hugr, Node}; -use super::hugrmut::{panic_invalid_node, panic_invalid_non_root}; -use super::{HugrView, NodeMetadataMap, OpType}; +use super::views::{panic_invalid_node, panic_invalid_non_root}; +use super::HugrView; +use super::{NodeMetadataMap, OpType}; +use crate::ops::handle::NodeHandle; /// Trait for accessing the internals of a Hugr(View). /// @@ -20,7 +20,7 @@ use super::{HugrView, NodeMetadataMap, OpType}; /// view. pub trait HugrInternals { /// The underlying portgraph view type. 
- type Portgraph<'p>: LinkView + Clone + 'p + type Portgraph<'p>: LinkView + Clone + 'p where Self: 'p; @@ -30,24 +30,24 @@ pub trait HugrInternals { /// Returns a reference to the underlying portgraph. fn portgraph(&self) -> Self::Portgraph<'_>; + /// Returns a flat portgraph view of a region in the HUGR. + /// + /// This is a subgraph of [`HugrInternals::portgraph`], with a flat hierarchy. + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion<'_, impl LinkView + Clone + '_>; + /// Returns the portgraph [Hierarchy](portgraph::Hierarchy) of the graph /// returned by [`HugrInternals::portgraph`]. - #[inline] - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy> { - Cow::Borrowed(&self.base_hugr().hierarchy) - } - - /// Returns the Hugr at the base of a chain of views. - fn base_hugr(&self) -> &Hugr; - - /// Return the root node of this view. - fn root_node(&self) -> Self::Node; + fn hierarchy(&self) -> &portgraph::Hierarchy; /// Convert a node to a portgraph node index. - fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex; + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex; /// Convert a portgraph node index to a node. - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; + #[allow(clippy::wrong_self_convention)] + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node; /// Returns a metadata entry associated with a node. /// @@ -55,6 +55,14 @@ pub trait HugrInternals { /// /// If the node is not in the graph. fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap; + + /// Returns the Hugr at the base of a chain of views. + // TODO: This will be removed in a future PR. + #[deprecated( + since = "0.16.0", + note = "This method will be removed in a future PR. Use the individual HugrInternals methods instead." + )] + fn base_hugr(&self) -> &Hugr; } impl HugrInternals for Hugr { @@ -71,34 +79,40 @@ impl HugrInternals for Hugr { } #[inline] - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy> { - Cow::Borrowed(&self.hierarchy) + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion<'_, impl LinkView + Clone + '_> { + let pg = self.portgraph(); + let root = self.to_portgraph_node(parent); + portgraph::view::FlatRegion::new_without_root(pg, &self.hierarchy, root) } #[inline] - fn base_hugr(&self) -> &Hugr { - self + fn hierarchy(&self) -> &portgraph::Hierarchy { + &self.hierarchy } #[inline] - fn root_node(&self) -> Self::Node { - self.root.into() + fn base_hugr(&self) -> &Hugr { + self } #[inline] - fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { - node.node().pg_index() + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + node.node().into_portgraph() } #[inline] - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node { + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node { index.into() } + #[inline] fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap { static EMPTY: OnceLock = OnceLock::new(); panic_invalid_node(self, node); - let map = self.metadata.get(node.pg_index()).as_ref(); + let map = self.metadata.get(node.into_portgraph()).as_ref(); map.unwrap_or(EMPTY.get_or_init(Default::default)) } } @@ -108,10 +122,14 @@ impl HugrInternals for Hugr { /// Specifically, this trait lets you apply arbitrary modifications that may /// invalidate the HUGR. pub trait HugrMutInternals: HugrView { - /// Set root node of the HUGR. 
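`get_pg_index` and `get_node` are renamed here to `to_portgraph_node` and `from_portgraph_node`; the conversion is still a cheap index round-trip. A sketch in crate-internal terms, assuming `hugr` and `node` are valid:

```rust
use crate::hugr::internal::HugrInternals;
use crate::{Hugr, Node};

fn roundtrip(hugr: &Hugr, node: Node) {
    // HUGR node -> raw portgraph index -> HUGR node.
    let index: portgraph::NodeIndex = hugr.to_portgraph_node(node);
    let back: Node = hugr.from_portgraph_node(index);
    debug_assert_eq!(node, back);
}
```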
+ /// Set the node at the root of the HUGR hierarchy. /// - /// This should be an existing node in the HUGR. Most operations use the - /// root node as a starting point for traversal. + /// Any node not reachable from this root should be deleted from the HUGR + /// after this call. + /// + /// # Panics + /// + /// If the node is not in the graph. fn set_root(&mut self, root: Self::Node); /// Set the number of ports on a node. This may invalidate the node's `PortIndex`. @@ -225,21 +243,21 @@ pub trait HugrMutInternals: HugrView { /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. impl HugrMutInternals for Hugr { fn set_root(&mut self, root: Node) { - panic_invalid_node(self, root); - self.root = self.get_pg_index(root); + self.hierarchy.detach(self.root); + self.root = root.into_portgraph(); } #[inline] fn set_num_ports(&mut self, node: Node, incoming: usize, outgoing: usize) { panic_invalid_node(self, node); self.graph - .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}) + .set_num_ports(node.into_portgraph(), incoming, outgoing, |_, _| {}) } fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { panic_invalid_node(self, node); - let mut incoming = self.graph.num_inputs(node.pg_index()); - let mut outgoing = self.graph.num_outputs(node.pg_index()); + let mut incoming = self.graph.num_inputs(node.into_portgraph()); + let mut outgoing = self.graph.num_outputs(node.into_portgraph()); let increment = |num: &mut usize| { let new = num.saturating_add_signed(amount); let range = *num..new; @@ -251,7 +269,7 @@ impl HugrMutInternals for Hugr { Direction::Outgoing => increment(&mut outgoing), }; self.graph - .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}); + .set_num_ports(node.into_portgraph(), incoming, outgoing, |_, _| {}); range } @@ -263,20 +281,18 @@ impl HugrMutInternals for Hugr { amount: usize, ) -> Range { panic_invalid_node(self, node); - let old_num_ports = self.base_hugr().graph.num_ports(node.pg_index(), direction); + let old_num_ports = self.graph.num_ports(node.into_portgraph(), direction); self.add_ports(node, direction, amount as isize); for swap_from_port in (index..old_num_ports).rev() { let swap_to_port = swap_from_port + amount; let [from_port_index, to_port_index] = [swap_from_port, swap_to_port].map(|p| { - self.base_hugr() - .graph - .port_index(node.pg_index(), PortOffset::new(direction, p)) + self.graph + .port_index(node.into_portgraph(), PortOffset::new(direction, p)) .unwrap() }); let linked_ports = self - .base_hugr() .graph .port_links(from_port_index) .map(|(_, to_subport)| to_subport.port()) @@ -295,27 +311,27 @@ impl HugrMutInternals for Hugr { fn set_parent(&mut self, node: Node, parent: Node) { panic_invalid_node(self, parent); panic_invalid_node(self, node); - self.hierarchy.detach(node.pg_index()); + self.hierarchy.detach(node.into_portgraph()); self.hierarchy - .push_child(node.pg_index(), parent.pg_index()) + .push_child(node.into_portgraph(), parent.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } fn move_after_sibling(&mut self, node: Node, after: Node) { panic_invalid_non_root(self, node); panic_invalid_non_root(self, after); - self.hierarchy.detach(node.pg_index()); + self.hierarchy.detach(node.into_portgraph()); self.hierarchy - .insert_after(node.pg_index(), after.pg_index()) + .insert_after(node.into_portgraph(), after.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should 
never fail."); } fn move_before_sibling(&mut self, node: Node, before: Node) { panic_invalid_non_root(self, node); panic_invalid_non_root(self, before); - self.hierarchy.detach(node.pg_index()); + self.hierarchy.detach(node.into_portgraph()); self.hierarchy - .insert_before(node.pg_index(), before.pg_index()) + .insert_before(node.into_portgraph(), before.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } @@ -326,14 +342,14 @@ impl HugrMutInternals for Hugr { fn optype_mut(&mut self, node: Self::Node) -> &mut OpType { panic_invalid_node(self, node); - let node = self.get_pg_index(node); + let node = self.to_portgraph_node(node); self.op_types.get_mut(node) } fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut NodeMetadataMap { panic_invalid_node(self, node); self.metadata - .get_mut(node.pg_index()) + .get_mut(node.into_portgraph()) .get_or_insert_with(Default::default) } diff --git a/hugr-core/src/hugr/patch.rs b/hugr-core/src/hugr/patch.rs index bc6195eba..1744ce760 100644 --- a/hugr-core/src/hugr/patch.rs +++ b/hugr-core/src/hugr/patch.rs @@ -153,12 +153,12 @@ impl PatchHugrMut for Transactional { return self.underlying.apply_hugr_mut(h); } // Try to backup just the contents of this HugrMut. - let mut backup = Hugr::new(h.root_type().clone()); + let mut backup = Hugr::new(h.root_optype().clone()); backup.insert_from_view(backup.root(), h); let r = self.underlying.apply_hugr_mut(h); if r.is_err() { // Try to restore backup. - h.replace_op(h.root(), backup.root_type().clone()); + h.replace_op(h.root(), backup.root_optype().clone()); while let Some(child) = h.first_child(h.root()) { h.remove_node(child); } diff --git a/hugr-core/src/hugr/patch/consts.rs b/hugr-core/src/hugr/patch/consts.rs index 6d0c011fe..eb9142f85 100644 --- a/hugr-core/src/hugr/patch/consts.rs +++ b/hugr-core/src/hugr/patch/consts.rs @@ -144,7 +144,7 @@ mod test { let mut h = build.finish_hugr()?; // nodes are Module, Function, Input, Output, Const, LoadConstant*2, MakeTuple - assert_eq!(h.node_count(), 8); + assert_eq!(h.num_nodes(), 8); let tup_node = tup.node(); // can't remove invalid node assert_eq!( @@ -199,7 +199,7 @@ mod test { // remove const assert_eq!(h.apply_patch(remove_con)?, h.root()); - assert_eq!(h.node_count(), 4); + assert_eq!(h.num_nodes(), 4); assert!(h.validate().is_ok()); Ok(()) } diff --git a/hugr-core/src/hugr/patch/insert_identity.rs b/hugr-core/src/hugr/patch/insert_identity.rs index 98ab0ff02..c1f959ccd 100644 --- a/hugr-core/src/hugr/patch/insert_identity.rs +++ b/hugr-core/src/hugr/patch/insert_identity.rs @@ -118,7 +118,7 @@ mod tests { fn correct_insertion(dfg_hugr: Hugr) { let mut h = dfg_hugr; - assert_eq!(h.node_count(), 6); + assert_eq!(h.num_nodes(), 6); let final_node = h .input_neighbours(h.get_io(h.root()).unwrap()[1]) @@ -131,7 +131,7 @@ mod tests { let noop_node = h.apply_patch(rw).unwrap(); - assert_eq!(h.node_count(), 7); + assert_eq!(h.num_nodes(), 7); let noop: Noop = h.get_optype(noop_node).cast().unwrap(); diff --git a/hugr-core/src/hugr/patch/outline_cfg.rs b/hugr-core/src/hugr/patch/outline_cfg.rs index 0f40615a9..b43b6b4e3 100644 --- a/hugr-core/src/hugr/patch/outline_cfg.rs +++ b/hugr-core/src/hugr/patch/outline_cfg.rs @@ -202,11 +202,11 @@ impl PatchHugrMut for OutlineCfg { // https://github.com/CQCL/hugr/issues/2029 let hierarchy = h.hierarchy(); let inner_exit = hierarchy - .children(h.get_pg_index(cfg_node)) + .children(h.to_portgraph_node(cfg_node)) .exactly_one() .ok() .unwrap(); - let inner_exit = 
h.get_node(inner_exit); + let inner_exit = h.from_portgraph_node(inner_exit); //let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap(); // Entry node must be first @@ -512,6 +512,7 @@ mod test { } assert_eq!(h.get_parent(new_block), Some(cfg)); assert!(h.get_optype(new_block).is_dataflow_block()); + #[allow(deprecated)] let b = h.base_hugr(); // To cope with `h` potentially being a SiblingMut assert_eq!(b.get_parent(new_cfg), Some(new_block)); for n in blocks { diff --git a/hugr-core/src/hugr/patch/replace.rs b/hugr-core/src/hugr/patch/replace.rs index 6f0b0ed65..183200751 100644 --- a/hugr-core/src/hugr/patch/replace.rs +++ b/hugr-core/src/hugr/patch/replace.rs @@ -7,6 +7,7 @@ use thiserror::Error; use crate::core::HugrNode; use crate::hugr::hugrmut::InsertionResult; +use crate::hugr::views::check_valid_non_root; use crate::hugr::HugrMut; use crate::ops::{OpTag, OpTrait}; use crate::types::EdgeKind; @@ -188,7 +189,7 @@ impl Replacement { // equality of OpType/Signature, e.g. to ease changing of Input/Output // node signatures too. let removed = h.get_optype(parent).tag(); - let replacement = self.replacement.root_type().tag(); + let replacement = self.replacement.root_optype().tag(); if removed != replacement { return Err(ReplaceError::WrongRootNodeTag { removed, @@ -265,25 +266,25 @@ impl PatchVerification for Replacement { } e.check_src(h, WhichEdgeSpec::HostToHost)?; } - self.mu_out - .iter() - .try_for_each(|e| match self.replacement.valid_non_root(e.src) { + self.mu_out.iter().try_for_each(|e| { + match check_valid_non_root(&self.replacement, e.src) { true => e.check_src(&self.replacement, WhichEdgeSpec::ReplToHost), false => Err(ReplaceError::BadEdgeSpec( Direction::Outgoing, WhichEdgeSpec::ReplToHost(e.clone()), )), - })?; + } + })?; // Edge targets... 
- self.mu_inp - .iter() - .try_for_each(|e| match self.replacement.valid_non_root(e.tgt) { + self.mu_inp.iter().try_for_each(|e| { + match check_valid_non_root(&self.replacement, e.tgt) { true => e.check_tgt(&self.replacement, WhichEdgeSpec::HostToRepl), false => Err(ReplaceError::BadEdgeSpec( Direction::Incoming, WhichEdgeSpec::HostToRepl(e.clone()), )), - })?; + } + })?; for e in self.mu_out.iter() { if !h.contains_node(e.tgt) || removed.contains(&e.tgt) { return Err(ReplaceError::BadEdgeSpec( @@ -430,13 +431,13 @@ where .ok_or_else(|| ReplaceError::BadEdgeSpec(Direction::Incoming, err_spec.clone()))?, kind: oe.kind, }; - if !h.valid_node(e.src) { + if !h.contains_node(e.src) { return Err(ReplaceError::BadEdgeSpec( Direction::Outgoing, err_spec.clone(), )); } - if !h.valid_node(e.tgt) { + if !h.contains_node(e.tgt) { return Err(ReplaceError::BadEdgeSpec( Direction::Incoming, err_spec.clone(), @@ -810,7 +811,7 @@ mod test { // Root node type needs to be that of common parent of the removed nodes: let mut rep2 = rep.clone(); rep2.replacement - .replace_op(rep2.replacement.root(), h.root_type().clone()); + .replace_op(rep2.replacement.root(), h.root_optype().clone()); assert_eq!( check_same_errors(rep2), ReplaceError::WrongRootNodeTag { diff --git a/hugr-core/src/hugr/patch/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs index e9283644d..3908ba58e 100644 --- a/hugr-core/src/hugr/patch/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -766,7 +766,7 @@ pub(in crate::hugr::patch) mod test { .unwrap(); // They should be the same, up to node indices - assert_eq!(h.edge_count(), orig.edge_count()); + assert_eq!(h.num_edges(), orig.num_edges()); } #[test] @@ -818,7 +818,7 @@ pub(in crate::hugr::patch) mod test { .unwrap(); // Nothing changed - assert_eq!(h.node_count(), orig.node_count()); + assert_eq!(h.num_nodes(), orig.num_nodes()); } /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the @@ -876,7 +876,7 @@ pub(in crate::hugr::patch) mod test { rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}")); assert_eq!(hugr.validate(), Ok(())); - assert_eq!(hugr.node_count(), 3); + assert_eq!(hugr.num_nodes(), 3); } /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting @@ -935,7 +935,7 @@ pub(in crate::hugr::patch) mod test { rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}")); assert_eq!(hugr.validate(), Ok(())); - assert_eq!(hugr.node_count(), 4); + assert_eq!(hugr.num_nodes(), 4); } #[rstest] @@ -974,12 +974,12 @@ pub(in crate::hugr::patch) mod test { let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out); - assert_eq!(h.node_count(), 4); + assert_eq!(h.num_nodes(), 4); rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}")); h.validate().unwrap_or_else(|e| panic!("{e}")); - assert_eq!(h.node_count(), 6); + assert_eq!(h.num_nodes(), 6); } use crate::hugr::patch::replace::Replacement; diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs index e220864a7..76dc93ab1 100644 --- a/hugr-core/src/hugr/rewrite.rs +++ b/hugr-core/src/hugr/rewrite.rs @@ -77,12 +77,12 @@ impl Rewrite for Transactional { return self.underlying.apply(h); } // Try to backup just the contents of this HugrMut. - let mut backup = Hugr::new(h.root_type().clone()); + let mut backup = Hugr::new(h.root_optype().clone()); backup.insert_from_view(backup.root(), h); let r = self.underlying.apply(h); if r.is_err() { // Try to restore backup. 
- h.replace_op(h.root(), backup.root_type().clone()); + h.replace_op(h.root(), backup.root_optype().clone()); while let Some(child) = h.first_child(h.root()) { h.remove_node(child); } diff --git a/hugr-core/src/hugr/serialize.rs b/hugr-core/src/hugr/serialize.rs index 906084d55..5e4922157 100644 --- a/hugr-core/src/hugr/serialize.rs +++ b/hugr-core/src/hugr/serialize.rs @@ -157,13 +157,13 @@ impl TryFrom<&Hugr> for SerHugrLatest { fn try_from(hugr: &Hugr) -> Result { // We compact the operation nodes during the serialization process, // and ignore the copy nodes. - let mut node_rekey: HashMap = HashMap::with_capacity(hugr.node_count()); + let mut node_rekey: HashMap = HashMap::with_capacity(hugr.num_nodes()); for (order, node) in hugr.canonical_order(hugr.root()).enumerate() { node_rekey.insert(node, portgraph::NodeIndex::new(order).into()); } - let mut nodes = vec![None; hugr.node_count()]; - let mut metadata = vec![None; hugr.node_count()]; + let mut nodes = vec![None; hugr.num_nodes()]; + let mut metadata = vec![None; hugr.num_nodes()]; for n in hugr.nodes() { let parent = node_rekey[&hugr.get_parent(n).unwrap_or(n)]; let opt = hugr.get_optype(n); @@ -172,7 +172,7 @@ impl TryFrom<&Hugr> for SerHugrLatest { parent, op: opt.clone(), }); - metadata[new_node].clone_from(hugr.metadata.get(n.pg_index())); + metadata[new_node].clone_from(hugr.metadata.get(n.into_portgraph())); } let nodes = nodes .into_iter() @@ -251,7 +251,7 @@ impl TryFrom for Hugr { } let unwrap_offset = |node: Node, offset, dir, hugr: &Hugr| -> Result { - if !hugr.graph.contains_node(node.pg_index()) { + if !hugr.graph.contains_node(node.into_portgraph()) { return Err(HUGRSerializationError::UnknownEdgeNode { node }); } let offset = match offset { diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 3b04ccd86..3690ec947 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -20,7 +20,7 @@ use crate::types::EdgeKind; use crate::{Direction, Hugr, Node, Port}; use super::internal::HugrInternals; -use super::views::{HierarchyView, HugrView, SiblingGraph}; +use super::views::HugrView; use super::ExtensionError; /// Structure keeping track of pre-computed information used in the validation @@ -31,7 +31,7 @@ use super::ExtensionError; struct ValidationContext<'a> { hugr: &'a Hugr, /// Dominator tree for each CFG region, using the container node as index. - dominators: HashMap>, + dominators: HashMap>, } impl Hugr { @@ -138,10 +138,10 @@ impl<'a> ValidationContext<'a> { /// /// The results of this computation should be cached in `self.dominators`. /// We don't do it here to avoid mutable borrows. - fn compute_dominator(&self, parent: Node) -> Dominators { - let region: SiblingGraph = SiblingGraph::try_new(self.hugr, parent).unwrap(); + fn compute_dominator(&self, parent: Node) -> Dominators { + let region = self.hugr.region_portgraph(parent); let entry_node = self.hugr.children(parent).next().unwrap(); - dominators::simple_fast(®ion.as_petgraph(), entry_node) + dominators::simple_fast(®ion, entry_node.into_portgraph()) } /// Check the constraints on a single node. @@ -163,7 +163,7 @@ impl<'a> ValidationContext<'a> { for dir in Direction::BOTH { // Check that we have the correct amount of ports and edges. 
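The backup step in `Transactional` above (build a fresh `Hugr` from the root optype, then `insert_from_view`) doubles as a general way to clone the contents of any `HugrView`, including views with non-`Node` node types. A hedged sketch with crate-internal paths:

```rust
use crate::hugr::{HugrMut, HugrView};
use crate::Hugr;

/// Clone `view`'s contents into a standalone `Hugr`, mirroring the backup above.
fn clone_into_hugr<H: HugrView>(view: &H) -> Hugr {
    let mut out = Hugr::new(view.root_optype().clone());
    // Inserts `view`'s root and all of its descendants beneath `out`'s root,
    // returning a source-node -> new-node map that this sketch ignores.
    out.insert_from_view(out.root(), view);
    out
}
```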
- let num_ports = self.hugr.graph.num_ports(node.pg_index(), dir); + let num_ports = self.hugr.graph.num_ports(node.into_portgraph(), dir); if num_ports != op_type.port_count(dir) { return Err(ValidationError::WrongNumberOfPorts { node, @@ -316,7 +316,7 @@ impl<'a> ValidationContext<'a> { fn validate_children(&self, node: Node, op_type: &OpType) -> Result<(), ValidationError> { let flags = op_type.validity_flags(); - if self.hugr.hierarchy().child_count(node.pg_index()) > 0 { + if self.hugr.hierarchy().child_count(node.into_portgraph()) > 0 { if flags.allowed_children.is_empty() { return Err(ValidationError::NonContainerWithChildren { node, @@ -352,7 +352,8 @@ impl<'a> ValidationContext<'a> { } } // Additional validations running over the full list of children optypes - let children_optypes = all_children.map(|c| (c.pg_index(), self.hugr.get_optype(c))); + let children_optypes = + all_children.map(|c| (c.into_portgraph(), self.hugr.get_optype(c))); if let Err(source) = op_type.validate_op_children(children_optypes) { return Err(ValidationError::InvalidChildren { parent: node, @@ -363,9 +364,9 @@ impl<'a> ValidationContext<'a> { // Additional validations running over the edges of the contained graph if let Some(edge_check) = flags.edge_check { - for source in self.hugr.hierarchy().children(node.pg_index()) { + for source in self.hugr.hierarchy().children(node.into_portgraph()) { for target in self.hugr.graph.output_neighbours(source) { - if self.hugr.hierarchy.parent(target) != Some(node.pg_index()) { + if self.hugr.hierarchy.parent(target) != Some(node.into_portgraph()) { continue; } let source_op = self.hugr.get_optype(source.into()); @@ -411,16 +412,16 @@ impl<'a> ValidationContext<'a> { /// Inter-graph edges are ignored. Only internal dataflow, constant, or /// state order edges are considered. fn validate_children_dag(&self, parent: Node, op_type: &OpType) -> Result<(), ValidationError> { - if !self.hugr.hierarchy.has_children(parent.pg_index()) { + if !self.hugr.hierarchy.has_children(parent.into_portgraph()) { // No children, nothing to do return Ok(()); }; - let region: SiblingGraph = SiblingGraph::try_new(self.hugr, parent).unwrap(); - let postorder = Topo::new(®ion.as_petgraph()); + let region = self.hugr.region_portgraph(parent); + let postorder = Topo::new(®ion); let nodes_visited = postorder - .iter(®ion.as_petgraph()) - .filter(|n| *n != parent) + .iter(®ion) + .filter(|n| *n != parent.into_portgraph()) .count(); let node_count = self.hugr.children(parent).count(); if nodes_visited != node_count { @@ -500,7 +501,7 @@ impl<'a> ValidationContext<'a> { // Must have an order edge. self.hugr .graph - .get_connections(from.pg_index(), ancestor.pg_index()) + .get_connections(from.into_portgraph(), ancestor.into_portgraph()) .find(|&(p, _)| { let offset = self.hugr.graph.port_offset(p).unwrap(); from_optype.port_kind(offset) == Some(EdgeKind::StateOrder) @@ -537,8 +538,8 @@ impl<'a> ValidationContext<'a> { } }; if !dominator_tree - .dominators(ancestor) - .is_some_and(|mut ds| ds.any(|n| n == from_parent)) + .dominators(ancestor.into_portgraph()) + .is_some_and(|mut ds| ds.any(|n| n == from_parent.into_portgraph())) { return Err(InterGraphEdgeError::NonDominatedAncestor { from, @@ -616,7 +617,12 @@ impl<'a> ValidationContext<'a> { // Root nodes are ignored, as they cannot have connected edges. 
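`compute_dominator` and `validate_children_dag` above now run petgraph algorithms directly over the flat region view returned by `region_portgraph`, instead of building a `SiblingGraph`. Other crate-internal passes can reuse the same pattern; a sketch mirroring the traversal above:

```rust
use petgraph::visit::{Topo, Walker};

use crate::hugr::internal::HugrInternals;
use crate::{Hugr, Node};

/// Visit the children of `parent` in topological order and count them,
/// skipping the parent node that anchors the flat region.
fn children_in_topo_order(hugr: &Hugr, parent: Node) -> usize {
    let region = hugr.region_portgraph(parent);
    Topo::new(&region)
        .iter(&region)
        .filter(|&n| n != hugr.to_portgraph_node(parent))
        .count()
}
```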
if node != self.hugr.root() { for dir in Direction::BOTH { - for (i, port_index) in self.hugr.graph.ports(node.pg_index(), dir).enumerate() { + for (i, port_index) in self + .hugr + .graph + .ports(node.into_portgraph(), dir) + .enumerate() + { let port = Port::new(dir, i); self.validate_port(node, port, port_index, op_type, var_decls)?; } diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index a66296c35..236f40e3f 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -103,7 +103,7 @@ fn invalid_root() { ); // Fix the root - b.root = module.pg_index(); + b.root = module.into_portgraph(); b.remove_node(root); assert_eq!(b.validate(), Ok(())); } @@ -142,7 +142,7 @@ fn children_restrictions() { let root = b.root(); let (_input, copy, _output) = b .hierarchy - .children(def.pg_index()) + .children(def.into_portgraph()) .map_into() .collect_tuple() .unwrap(); @@ -185,7 +185,7 @@ fn df_children_restrictions() { let (mut b, def) = make_simple_hugr(2); let (_input, output, copy) = b .hierarchy - .children(def.pg_index()) + .children(def.into_portgraph()) .map_into() .collect_tuple() .unwrap(); @@ -202,7 +202,7 @@ fn df_children_restrictions() { assert_matches!( b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. }) - => {assert_eq!(parent, def); assert_eq!(child, output.pg_index())} + => {assert_eq!(parent, def); assert_eq!(child, output.into_portgraph())} ); b.replace_op(output, ops::Output::new(vec![bool_t(), bool_t()])); @@ -211,7 +211,7 @@ fn df_children_restrictions() { assert_matches!( b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalIOChildren { child, .. }, .. }) - => {assert_eq!(parent, def); assert_eq!(child, copy.pg_index())} + => {assert_eq!(parent, def); assert_eq!(child, copy.into_portgraph())} ); } @@ -791,7 +791,7 @@ fn cfg_children_restrictions() { let (mut b, def) = make_simple_hugr(1); let (_input, _output, copy) = b .hierarchy - .children(def.pg_index()) + .children(def.into_portgraph()) .map_into() .collect_tuple() .unwrap(); @@ -855,7 +855,7 @@ fn cfg_children_restrictions() { assert_matches!( b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalExitChildren { child, .. }, .. 
}) - => {assert_eq!(parent, cfg); assert_eq!(child, exit2.pg_index())} + => {assert_eq!(parent, cfg); assert_eq!(child, exit2.into_portgraph())} ); b.remove_node(exit2); @@ -875,7 +875,7 @@ fn cfg_children_restrictions() { extension_delta: ExtensionSet::new(), }, ); - let mut block_children = b.hierarchy.children(block.pg_index()); + let mut block_children = b.hierarchy.children(block.into_portgraph()); let block_input = block_children.next().unwrap().into(); let block_output = block_children.next_back().unwrap().into(); b.replace_op(block_input, ops::Input::new(vec![qb_t()])); diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index a154a956f..f9eedd548 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -24,10 +24,8 @@ use itertools::Itertools; use portgraph::render::{DotFormat, MermaidFormat}; use portgraph::{LinkView, PortView}; -use super::internal::HugrInternals; -use super::{ - Hugr, HugrError, HugrMut, Node, NodeMetadata, NodeMetadataMap, ValidationError, DEFAULT_OPTYPE, -}; +use super::internal::{HugrInternals, HugrMutInternals}; +use super::{Hugr, HugrError, HugrMut, Node, NodeMetadata, ValidationError}; use crate::extension::ExtensionRegistry; use crate::ops::{OpParent, OpTag, OpTrait, OpType}; @@ -40,85 +38,67 @@ use itertools::Either; /// For end users we intend this to be superseded by region-specific APIs. pub trait HugrView: HugrInternals { /// Return the root node of this view. - #[inline] - fn root(&self) -> Self::Node { - self.root_node() - } + fn root(&self) -> Self::Node; - /// Return the type of the HUGR root node. + /// Return the optype of the HUGR root node. #[inline] - fn root_type(&self) -> &OpType { + fn root_optype(&self) -> &OpType { let node_type = self.get_optype(self.root()); - // Sadly no way to do this at present - // debug_assert!(Self::RootHandle::can_hold(node_type.tag())); node_type } - /// Returns whether the node exists. + /// Returns `true` if the node exists in the HUGR. fn contains_node(&self, node: Self::Node) -> bool; - /// Validates that a node is valid in the graph. - #[inline] - fn valid_node(&self, node: Self::Node) -> bool { - self.contains_node(node) - } - - /// Validates that a node is a valid root descendant in the graph. - /// - /// To include the root node use [`HugrView::valid_node`] instead. - #[inline] - fn valid_non_root(&self, node: Self::Node) -> bool { - self.root() != node && self.valid_node(node) - } - /// Returns the parent of a node. - #[inline] - fn get_parent(&self, node: Self::Node) -> Option { - if !self.valid_non_root(node) { - return None; - }; - self.base_hugr() - .hierarchy - .parent(self.get_pg_index(node)) - .map(|index| self.get_node(index)) - } - - /// Returns the operation type of a node. - #[inline] - fn get_optype(&self, node: Self::Node) -> &OpType { - match self.contains_node(node) { - true => self.base_hugr().op_types.get(self.get_pg_index(node)), - false => &DEFAULT_OPTYPE, - } - } + fn get_parent(&self, node: Self::Node) -> Option; /// Returns the metadata associated with a node. #[inline] fn get_metadata(&self, node: Self::Node, key: impl AsRef) -> Option<&NodeMetadata> { match self.contains_node(node) { - true => self.get_node_metadata(node)?.get(key.as_ref()), + true => self.node_metadata_map(node).get(key.as_ref()), false => None, } } - /// Retrieve the complete metadata map for a node. 
- fn get_node_metadata(&self, node: Self::Node) -> Option<&NodeMetadataMap> { - if !self.valid_node(node) { - return None; - } - self.base_hugr() - .metadata - .get(self.get_pg_index(node)) - .as_ref() - } + /// Returns the operation type of a node. + /// + /// # Panics + /// + /// If the node is not in the graph. + fn get_optype(&self, node: Self::Node) -> &OpType; + + /// Returns the number of nodes in the HUGR. + fn num_nodes(&self) -> usize; - /// Returns the number of nodes in the hugr. - fn node_count(&self) -> usize; + /// Returns the number of edges in the HUGR. + fn num_edges(&self) -> usize; + + /// Number of ports in node for a given direction. + fn num_ports(&self, node: Self::Node, dir: Direction) -> usize; - /// Returns the number of edges in the hugr. - fn edge_count(&self) -> usize; + /// Number of inputs to a node. + /// Shorthand for [`num_ports`][HugrView::num_ports]`(node, Direction::Incoming)`. + #[inline] + fn num_inputs(&self, node: Self::Node) -> usize { + self.num_ports(node, Direction::Incoming) + } - /// Iterates over the nodes in the port graph. + /// Number of outputs from a node. + /// Shorthand for [`num_ports`][HugrView::num_ports]`(node, Direction::Outgoing)`. + #[inline] + fn num_outputs(&self, node: Self::Node) -> usize { + self.num_ports(node, Direction::Outgoing) + } + + /// Iterates over the all the nodes in the HUGR. + /// + /// This iterator returns every node in the HUGR, including those that are + /// not descendants from the root node. + /// + /// See [`HugrView::descendants`] and [`HugrView::children`] for more specific + /// iterators. fn nodes(&self) -> impl Iterator + Clone; /// Iterator over ports of node in a given direction. @@ -260,26 +240,15 @@ pub trait HugrView: HugrInternals { self.linked_ports(node, port).next().is_some() } - /// Number of ports in node for a given direction. - fn num_ports(&self, node: Self::Node, dir: Direction) -> usize; - - /// Number of inputs to a node. - /// Shorthand for [`num_ports`][HugrView::num_ports]`(node, Direction::Incoming)`. - #[inline] - fn num_inputs(&self, node: Self::Node) -> usize { - self.num_ports(node, Direction::Incoming) - } - - /// Number of outputs from a node. - /// Shorthand for [`num_ports`][HugrView::num_ports]`(node, Direction::Outgoing)`. - #[inline] - fn num_outputs(&self, node: Self::Node) -> usize { - self.num_ports(node, Direction::Outgoing) - } - - /// Return iterator over the direct children of node. + /// Returns an iterator over the direct children of node. fn children(&self, node: Self::Node) -> impl DoubleEndedIterator + Clone; + /// Returns an iterator over all the descendants of a node, + /// including the node itself. + /// + /// Yields the node itself first, followed by its children in breath-first order. + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone; + /// Returns the first child of the specified node (if it is a parent). /// Useful because `x.children().next()` leaves x borrowed. fn first_child(&self, node: Self::Node) -> Option { @@ -334,13 +303,13 @@ pub trait HugrView: HugrInternals { /// In contrast to [`poly_func_type`][HugrView::poly_func_type], this /// method always return a concrete [`Signature`]. fn inner_function_type(&self) -> Option> { - self.root_type().inner_function_type() + self.root_optype().inner_function_type() } /// Returns the function type defined by this HUGR, i.e. `Some` iff the root is /// a [`FuncDecl`][crate::ops::FuncDecl] or [`FuncDefn`][crate::ops::FuncDefn]. 
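The new `descendants` iterator declared above yields a node followed by everything beneath it, which makes simple subtree queries possible without constructing a `DescendantsGraph`. A small sketch using only `HugrView` methods, with crate-internal paths:

```rust
use crate::hugr::HugrView;
use crate::{Hugr, Node};

/// Number of nodes in the subtree rooted at `parent`, counting `parent` itself.
fn subtree_size(hugr: &Hugr, parent: Node) -> usize {
    debug_assert!(hugr.contains_node(parent));
    hugr.descendants(parent).count()
}
```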
fn poly_func_type(&self) -> Option { - match self.root_type() { + match self.root_optype() { OpType::FuncDecl(decl) => Some(decl.signature.clone()), OpType::FuncDefn(defn) => Some(defn.signature.clone()), _ => None, @@ -363,13 +332,7 @@ pub trait HugrView: HugrInternals { /// /// For a more detailed representation, use the [`HugrView::dot_string`] /// format instead. - fn mermaid_string(&self) -> String { - self.mermaid_string_with_config(RenderConfig { - node_indices: true, - port_offsets_in_edges: true, - type_labels_in_edges: true, - }) - } + fn mermaid_string(&self) -> String; /// Return the mermaid representation of the underlying hierarchical graph. /// @@ -378,35 +341,14 @@ pub trait HugrView: HugrInternals { /// /// For a more detailed representation, use the [`HugrView::dot_string`] /// format instead. - fn mermaid_string_with_config(&self, config: RenderConfig) -> String { - let hugr = self.base_hugr(); - let graph = self.portgraph(); - graph - .mermaid_format() - .with_hierarchy(&hugr.hierarchy) - .with_node_style(render::node_style(self, config)) - .with_edge_style(render::edge_style(self, config)) - .finish() - } + fn mermaid_string_with_config(&self, config: RenderConfig) -> String; /// Return the graphviz representation of the underlying graph and hierarchy side by side. /// /// For a simpler representation, use the [`HugrView::mermaid_string`] format instead. fn dot_string(&self) -> String where - Self: Sized, - { - let hugr = self.base_hugr(); - let graph = self.portgraph(); - let config = RenderConfig::default(); - graph - .dot_format() - .with_hierarchy(&hugr.hierarchy) - .with_node_style(render::node_style(self, config)) - .with_port_style(render::port_style(self, config)) - .with_edge_style(render::edge_style(self, config)) - .finish() - } + Self: Sized; /// If a node has a static input, return the source node. fn static_source(&self, node: Self::Node) -> Option { @@ -453,10 +395,9 @@ pub trait HugrView: HugrInternals { /// Returns the set of extensions used by the HUGR. /// - /// This set may contain extensions that are no longer required by the HUGR. - fn extensions(&self) -> &ExtensionRegistry { - &self.base_hugr().extensions - } + /// This set contains all extensions required to define the operations and + /// types in the HUGR. + fn extensions(&self) -> &ExtensionRegistry; /// Check the validity of the underlying HUGR. /// @@ -465,6 +406,7 @@ pub trait HugrView: HugrInternals { /// See [`HugrView::validate_no_extensions`] for a version that doesn't check /// extension requirements. fn validate(&self) -> Result<(), ValidationError> { + #[allow(deprecated)] self.base_hugr().validate() } @@ -474,6 +416,7 @@ pub trait HugrView: HugrInternals { /// /// For a more thorough check, use [`HugrView::validate`]. 
fn validate_no_extensions(&self) -> Result<(), ValidationError> { + #[allow(deprecated)] self.base_hugr().validate_no_extensions() } } @@ -526,18 +469,48 @@ impl ExtractHugr for &mut Hugr { impl HugrView for Hugr { #[inline] - fn contains_node(&self, node: Node) -> bool { - self.graph.contains_node(node.pg_index()) + fn root(&self) -> Self::Node { + self.root.into() + } + + #[inline] + fn contains_node(&self, node: Self::Node) -> bool { + self.graph.contains_node(node.into_portgraph()) + } + + #[inline] + fn get_parent(&self, node: Self::Node) -> Option { + if !check_valid_non_root(self, node) { + return None; + }; + self.hierarchy + .parent(self.to_portgraph_node(node)) + .map(|index| self.from_portgraph_node(index)) + } + + #[inline] + fn get_optype(&self, node: Node) -> &OpType { + // TODO: This currently fails because some methods get the optype of + // e.g. a parent outside a region view. We should be able to re-enable + // this once we add hugr entrypoints. + //panic_invalid_node(self, node); + self.op_types.get(self.to_portgraph_node(node)) + } + + #[inline] + fn num_nodes(&self) -> usize { + self.portgraph().node_count() } #[inline] - fn node_count(&self) -> usize { - self.graph.node_count() + fn num_edges(&self) -> usize { + self.portgraph().link_count() } #[inline] - fn edge_count(&self) -> usize { - self.graph.link_count() + fn num_ports(&self, node: Self::Node, dir: Direction) -> usize { + self.portgraph() + .num_ports(self.to_portgraph_node(node), dir) } #[inline] @@ -547,12 +520,16 @@ impl HugrView for Hugr { #[inline] fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { - self.graph.port_offsets(node.pg_index(), dir).map_into() + self.graph + .port_offsets(node.into_portgraph(), dir) + .map_into() } #[inline] fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { - self.graph.all_port_offsets(node.pg_index()).map_into() + self.graph + .all_port_offsets(node.into_portgraph()) + .map_into() } #[inline] @@ -565,7 +542,7 @@ impl HugrView for Hugr { let port = self .graph - .port_index(node.pg_index(), port.pg_offset()) + .port_index(node.into_portgraph(), port.pg_offset()) .unwrap(); self.graph.port_links(port).map(|(_, link)| { let port = link.port(); @@ -578,30 +555,72 @@ impl HugrView for Hugr { #[inline] fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { self.graph - .get_connections(node.pg_index(), other.pg_index()) + .get_connections(node.into_portgraph(), other.into_portgraph()) .map(|(p1, p2)| { [p1, p2].map(|link| self.graph.port_offset(link.port()).unwrap().into()) }) } #[inline] - fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.graph.num_ports(node.pg_index(), dir) + fn children(&self, node: Self::Node) -> impl DoubleEndedIterator + Clone { + self.hierarchy + .children(self.to_portgraph_node(node)) + .map(|n| self.from_portgraph_node(n)) } #[inline] - fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone { - self.hierarchy.children(node.pg_index()).map_into() + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone { + self.hierarchy + .descendants(self.to_portgraph_node(node)) + .map(|n| self.from_portgraph_node(n)) } #[inline] fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { - self.graph.neighbours(node.pg_index(), dir).map_into() + self.graph.neighbours(node.into_portgraph(), dir).map_into() } #[inline] fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { - self.graph.all_neighbours(node.pg_index()).map_into() + 
self.graph.all_neighbours(node.into_portgraph()).map_into() + } + + fn mermaid_string(&self) -> String { + self.mermaid_string_with_config(RenderConfig { + node_indices: true, + port_offsets_in_edges: true, + type_labels_in_edges: true, + }) + } + + fn mermaid_string_with_config(&self, config: RenderConfig) -> String { + let graph = self.portgraph(); + graph + .mermaid_format() + .with_hierarchy(&self.hierarchy) + .with_node_style(render::node_style(self, config)) + .with_edge_style(render::edge_style(self, config)) + .finish() + } + + fn dot_string(&self) -> String + where + Self: Sized, + { + let graph = self.portgraph(); + let config = RenderConfig::default(); + graph + .dot_format() + .with_hierarchy(&self.hierarchy) + .with_node_style(render::node_style(self, config)) + .with_port_style(render::port_style(self, config)) + .with_edge_style(render::edge_style(self, config)) + .finish() + } + + #[inline] + fn extensions(&self) -> &ExtensionRegistry { + &self.extensions } } @@ -630,7 +649,7 @@ where hugr: &impl HugrView, ) -> impl Iterator { self.filter(move |(n, p)| { - let kind = hugr.get_optype(*n).port_kind(*p); + let kind = HugrView::get_optype(hugr, *n).port_kind(*p); predicate(kind) }) } @@ -642,3 +661,47 @@ where P: Into + Copy, { } + +/// Returns `true` if the node exists in the graph and is not the module at the hierarchy root. +pub(super) fn check_valid_non_root(hugr: &H, node: H::Node) -> bool { + hugr.contains_node(node) && node != hugr.root() +} + +/// Panic if [`HugrView::contains_node`] fails. +#[track_caller] +pub(super) fn panic_invalid_node(hugr: &H, node: H::Node) { + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. + if !hugr.contains_node(node) { + panic!("Received an invalid node {node}.",); + } +} + +/// Panic if [`check_valid_non_root`] fails. +#[track_caller] +pub(super) fn panic_invalid_non_root(hugr: &H, node: H::Node) { + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. + if !check_valid_non_root(hugr, node) { + panic!("Received an invalid non-root node {node}.",); + } +} + +/// Panic if [`HugrView::valid_node`] fails. +#[track_caller] +pub(super) fn panic_invalid_port( + hugr: &H, + node: Node, + port: impl Into, +) { + let port = port.into(); + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. 
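`mermaid_string` and `mermaid_string_with_config` now have their concrete bodies on `Hugr` (just above); callers that want different output only need to pass their own `RenderConfig`. A sketch, with the `RenderConfig` fields taken from the default shown above and the crate-internal import path assumed:

```rust
use crate::hugr::views::render::RenderConfig;
use crate::hugr::HugrView;
use crate::Hugr;

/// Render a mermaid diagram without node indices in the labels, keeping the
/// default edge annotations.
fn compact_mermaid(hugr: &Hugr) -> String {
    hugr.mermaid_string_with_config(RenderConfig {
        node_indices: false,
        port_offsets_in_edges: true,
        type_labels_in_edges: true,
    })
}
```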
+ if hugr + .portgraph() + .port_index(node.into_portgraph(), port.pg_offset()) + .is_none() + { + panic!("Received an invalid port {port} for node {node} while mutating a HUGR"); + } +} diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index 906dea3e4..e3ba29e2c 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -41,37 +41,44 @@ pub struct DescendantsGraph<'g, Root = Node> { _phantom: std::marker::PhantomData, } impl HugrView for DescendantsGraph<'_, Root> { + #[inline] + fn root(&self) -> Self::Node { + self.root + } + #[inline] fn contains_node(&self, node: Node) -> bool { - self.graph.contains_node(self.get_pg_index(node)) + self.graph.contains_node(self.to_portgraph_node(node)) } #[inline] - fn node_count(&self) -> usize { + fn num_nodes(&self) -> usize { self.graph.node_count() } #[inline] - fn edge_count(&self) -> usize { + fn num_edges(&self) -> usize { self.graph.link_count() } #[inline] fn nodes(&self) -> impl Iterator + Clone { - self.graph.nodes_iter().map(|index| self.get_node(index)) + self.graph + .nodes_iter() + .map(|index| self.from_portgraph_node(index)) } #[inline] fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph - .port_offsets(self.get_pg_index(node), dir) + .port_offsets(self.to_portgraph_node(node), dir) .map_into() } #[inline] fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { self.graph - .all_port_offsets(self.get_pg_index(node)) + .all_port_offsets(self.to_portgraph_node(node)) .map_into() } @@ -82,19 +89,19 @@ impl HugrView for DescendantsGraph<'_, Root> { ) -> impl Iterator + Clone { let port = self .graph - .port_index(self.get_pg_index(node), port.into().pg_offset()) + .port_index(self.to_portgraph_node(node), port.into().pg_offset()) .unwrap(); self.graph.port_links(port).map(|(_, link)| { let port: PortIndex = link.into(); let node = self.graph.port_node(port).unwrap(); let offset = self.graph.port_offset(port).unwrap(); - (self.get_node(node), offset.into()) + (self.from_portgraph_node(node), offset.into()) }) } fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { self.graph - .get_connections(self.get_pg_index(node), self.get_pg_index(other)) + .get_connections(self.to_portgraph_node(node), self.to_portgraph_node(other)) .map(|(p1, p2)| { [p1, p2].map(|link| { let offset = self.graph.port_offset(link).unwrap(); @@ -105,30 +112,46 @@ impl HugrView for DescendantsGraph<'_, Root> { #[inline] fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.graph.num_ports(self.get_pg_index(node), dir) + self.graph.num_ports(self.to_portgraph_node(node), dir) } #[inline] fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone { - let children = match self.graph.contains_node(self.get_pg_index(node)) { - true => self.base_hugr().hierarchy.children(self.get_pg_index(node)), + let hierarchy = self.hierarchy(); + let children = match self.graph.contains_node(self.to_portgraph_node(node)) { + true => hierarchy.children(self.to_portgraph_node(node)), false => portgraph::hierarchy::Children::default(), }; - children.map(|index| self.get_node(index)) + children.map(move |index| { + let _ = hierarchy; + self.from_portgraph_node(index) + }) } #[inline] fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph - .neighbours(self.get_pg_index(node), dir) - .map(|index| self.get_node(index)) + .neighbours(self.to_portgraph_node(node), dir) + .map(|index| 
self.from_portgraph_node(index)) } #[inline] fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { self.graph - .all_neighbours(self.get_pg_index(node)) - .map(|index| self.get_node(index)) + .all_neighbours(self.to_portgraph_node(node)) + .map(|index| self.from_portgraph_node(index)) + } + + delegate::delegate! { + to (&self.hugr) { + fn get_parent(&self, node: Self::Node) -> Option; + fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType; + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone; + fn mermaid_string(&self) -> String; + fn mermaid_string_with_config(&self, config: crate::hugr::views::render::RenderConfig) -> String; + fn dot_string(&self) -> String; + fn extensions(&self) -> &crate::extension::ExtensionRegistry; + } } } @@ -138,10 +161,11 @@ where { fn try_new(hugr: &'a impl HugrView, root: Node) -> Result { check_tag::(hugr, root)?; + #[allow(deprecated)] let hugr = hugr.base_hugr(); Ok(Self { root, - graph: RegionGraph::new(&hugr.graph, &hugr.hierarchy, hugr.get_pg_index(root)), + graph: RegionGraph::new(&hugr.graph, &hugr.hierarchy, hugr.to_portgraph_node(root)), hugr, _phantom: std::marker::PhantomData, }) @@ -166,28 +190,39 @@ where &self.graph } - fn base_hugr(&self) -> &Hugr { - self.hugr + #[inline] + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion< + '_, + impl portgraph::view::LinkView + Clone + '_, + > { + self.hugr.region_portgraph(parent) } #[inline] - fn root_node(&self) -> Node { - self.root + fn hierarchy(&self) -> &portgraph::Hierarchy { + self.hugr.hierarchy() } #[inline] - fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { - self.hugr.get_pg_index(node) + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + self.hugr.to_portgraph_node(node) } #[inline] - fn get_node(&self, index: portgraph::NodeIndex) -> Node { - self.hugr.get_node(index) + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Node { + self.hugr.from_portgraph_node(index) } fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap { self.hugr.node_metadata_map(node) } + + fn base_hugr(&self) -> &Hugr { + self.hugr + } } #[cfg(test)] @@ -245,7 +280,7 @@ pub(super) mod test { let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; let def_io = region.get_io(def).unwrap(); - assert_eq!(region.node_count(), 7); + assert_eq!(region.num_nodes(), 7); assert!(region.nodes().all(|n| n == def || hugr.get_parent(n) == Some(def) || hugr.get_parent(n) == Some(inner))); @@ -265,8 +300,8 @@ pub(super) mod test { inner_region.inner_function_type().map(Cow::into_owned), Some(Signature::new(vec![usize_t()], vec![usize_t()])) ); - assert_eq!(inner_region.node_count(), 3); - assert_eq!(inner_region.edge_count(), 1); + assert_eq!(inner_region.num_nodes(), 3); + assert_eq!(inner_region.num_edges(), 1); assert_eq!(inner_region.children(inner).count(), 2); assert_eq!(inner_region.children(hugr.root()).count(), 0); assert_eq!( @@ -315,8 +350,8 @@ pub(super) mod test { let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; - assert_eq!(region.node_count(), extracted.node_count()); - assert_eq!(region.root_type(), extracted.root_type()); + assert_eq!(region.num_nodes(), extracted.num_nodes()); + assert_eq!(region.root_optype(), extracted.root_optype()); Ok(()) } diff --git a/hugr-core/src/hugr/views/impls.rs b/hugr-core/src/hugr/views/impls.rs index 440df9480..6cd1d7631 100644 --- a/hugr-core/src/hugr/views/impls.rs +++ 
b/hugr-core/src/hugr/views/impls.rs @@ -12,12 +12,13 @@ macro_rules! hugr_internal_methods { delegate::delegate! { to ({let $arg=self; $e}) { fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &crate::Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: impl crate::ops::handle::NodeHandle) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; + fn region_portgraph(&self, parent: Self::Node) -> portgraph::view::FlatRegion<'_, impl portgraph::view::LinkView + Clone + '_>; + fn hierarchy(&self) -> &portgraph::Hierarchy; + fn to_portgraph_node(&self, node: impl crate::ops::handle::NodeHandle) -> portgraph::NodeIndex; + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node; fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap; + #[allow(deprecated)] + fn base_hugr(&self) -> &crate::Hugr; } } }; @@ -30,34 +31,23 @@ macro_rules! hugr_view_methods { delegate::delegate! { to ({let $arg=self; $e}) { fn root(&self) -> Self::Node; - fn root_type(&self) -> &crate::ops::OpType; + fn root_optype(&self) -> &crate::ops::OpType; fn contains_node(&self, node: Self::Node) -> bool; - fn valid_node(&self, node: Self::Node) -> bool; - fn valid_non_root(&self, node: Self::Node) -> bool; fn get_parent(&self, node: Self::Node) -> Option; - fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType; fn get_metadata(&self, node: Self::Node, key: impl AsRef) -> Option<&crate::hugr::NodeMetadata>; - fn get_node_metadata(&self, node: Self::Node) -> Option<&crate::hugr::NodeMetadataMap>; - fn node_count(&self) -> usize; - fn edge_count(&self) -> usize; + fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType; + fn num_nodes(&self) -> usize; + fn num_edges(&self) -> usize; + fn num_ports(&self, node: Self::Node, dir: crate::Direction) -> usize; + fn num_inputs(&self, node: Self::Node) -> usize; + fn num_outputs(&self, node: Self::Node) -> usize; fn nodes(&self) -> impl Iterator + Clone; fn node_ports(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator + Clone; fn node_outputs(&self, node: Self::Node) -> impl Iterator + Clone; fn node_inputs(&self, node: Self::Node) -> impl Iterator + Clone; fn all_node_ports(&self, node: Self::Node) -> impl Iterator + Clone; - fn linked_ports( - &self, - node: Self::Node, - port: impl Into, - ) -> impl Iterator + Clone; - fn all_linked_ports( - &self, - node: Self::Node, - dir: crate::Direction, - ) -> itertools::Either< - impl Iterator, - impl Iterator, - >; + fn linked_ports(&self, node: Self::Node, port: impl Into) -> impl Iterator + Clone; + fn all_linked_ports(&self, node: Self::Node, dir: crate::Direction) -> itertools::Either, impl Iterator>; fn all_linked_outputs(&self, node: Self::Node) -> impl Iterator; fn all_linked_inputs(&self, node: Self::Node) -> impl Iterator; fn single_linked_port(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, crate::Port)>; @@ -67,31 +57,19 @@ macro_rules! 
hugr_view_methods { fn linked_inputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; fn node_connections(&self, node: Self::Node, other: Self::Node) -> impl Iterator + Clone; fn is_linked(&self, node: Self::Node, port: impl Into) -> bool; - fn num_ports(&self, node: Self::Node, dir: crate::Direction) -> usize; - fn num_inputs(&self, node: Self::Node) -> usize; - fn num_outputs(&self, node: Self::Node) -> usize; fn children(&self, node: Self::Node) -> impl DoubleEndedIterator + Clone; + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone; fn first_child(&self, node: Self::Node) -> Option; fn neighbours(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator + Clone; fn input_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; fn output_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; fn all_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; - fn get_io(&self, node: Self::Node) -> Option<[Self::Node; 2]>; - fn inner_function_type(&self) -> Option>; - fn poly_func_type(&self) -> Option; - // TODO: cannot use delegate here. `PetgraphWrapper` is a thin - // wrapper around `Self`, so falling back to the default impl - // should be harmless. - // fn as_petgraph(&self) -> PetgraphWrapper<'_, Self>; fn mermaid_string(&self) -> String; fn mermaid_string_with_config(&self, config: crate::hugr::views::render::RenderConfig) -> String; fn dot_string(&self) -> String; fn static_source(&self, node: Self::Node) -> Option; fn static_targets(&self, node: Self::Node) -> Option>; - fn signature(&self, node: Self::Node) -> Option>; fn value_types(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator; - fn in_value_types(&self, node: Self::Node) -> impl Iterator; - fn out_value_types(&self, node: Self::Node) -> impl Iterator; fn extensions(&self) -> &crate::extension::ExtensionRegistry; fn validate(&self) -> Result<(), crate::hugr::ValidationError>; fn validate_no_extensions(&self) -> Result<(), crate::hugr::ValidationError>; @@ -128,6 +106,9 @@ macro_rules! hugr_mut_methods { ($arg:ident, $e:expr) => { delegate::delegate! { to ({let $arg=self; $e}) { + fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef) -> &mut crate::hugr::NodeMetadata; + fn set_metadata(&mut self, node: Self::Node, key: impl AsRef, metadata: impl Into); + fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef); fn add_node_with_parent(&mut self, parent: Self::Node, op: impl Into) -> Self::Node; fn add_node_before(&mut self, sibling: Self::Node, nodetype: impl Into) -> Self::Node; fn add_node_after(&mut self, sibling: Self::Node, op: impl Into) -> Self::Node; @@ -140,6 +121,8 @@ macro_rules! 
hugr_mut_methods { fn insert_hugr(&mut self, root: Self::Node, other: crate::Hugr) -> crate::hugr::hugrmut::InsertionResult; fn insert_from_view(&mut self, root: Self::Node, other: &Other) -> crate::hugr::hugrmut::InsertionResult; fn insert_subgraph(&mut self, root: Self::Node, other: &Other, subgraph: &crate::hugr::views::SiblingSubgraph) -> std::collections::HashMap; + fn use_extension(&mut self, extension: impl Into>); + fn use_extensions(&mut self, registry: impl IntoIterator) where crate::extension::ExtensionRegistry: Extend; } } }; diff --git a/hugr-core/src/hugr/views/petgraph.rs b/hugr-core/src/hugr/views/petgraph.rs index 17c3e0062..22da47f0a 100644 --- a/hugr-core/src/hugr/views/petgraph.rs +++ b/hugr-core/src/hugr/views/petgraph.rs @@ -55,7 +55,7 @@ where T: HugrView, { fn node_count(&self) -> usize { - HugrView::node_count(self.hugr) + HugrView::num_nodes(self.hugr) } } @@ -64,15 +64,15 @@ where T: HugrView, { fn node_bound(&self) -> usize { - HugrView::node_count(self.hugr) + HugrView::num_nodes(self.hugr) } fn to_index(&self, ix: Self::NodeId) -> usize { - self.hugr.get_pg_index(ix).into() + self.hugr.to_portgraph_node(ix).into() } fn from_index(&self, ix: usize) -> Self::NodeId { - self.hugr.get_node(portgraph::NodeIndex::new(ix)) + self.hugr.from_portgraph_node(portgraph::NodeIndex::new(ix)) } } @@ -81,7 +81,7 @@ where T: HugrView, { fn edge_count(&self) -> usize { - HugrView::edge_count(self.hugr) + HugrView::num_edges(self.hugr) } } @@ -233,7 +233,7 @@ mod test { assert_eq!(wrapper.node_bound(), 5); assert_eq!(wrapper.edge_count(), 7); - let cx1_index = cx1.node().pg_index().index(); + let cx1_index = cx1.node().into_portgraph().index(); assert_eq!(wrapper.to_index(cx1.node()), cx1_index); assert_eq!(wrapper.from_index(cx1_index), cx1.node()); diff --git a/hugr-core/src/hugr/views/render.rs b/hugr-core/src/hugr/views/render.rs index ecb8549c0..43530e4c1 100644 --- a/hugr-core/src/hugr/views/render.rs +++ b/hugr-core/src/hugr/views/render.rs @@ -36,7 +36,7 @@ pub(super) fn node_style( config: RenderConfig, ) -> Box NodeStyle + '_> { fn node_name(h: &H, n: NodeIndex) -> String { - match h.get_optype(h.get_node(n)) { + match h.get_optype(h.from_portgraph_node(n)) { OpType::FuncDecl(f) => format!("FuncDecl: \"{}\"", f.name), OpType::FuncDefn(f) => format!("FuncDefn: \"{}\"", f.name), op => op.name().to_string(), @@ -45,14 +45,14 @@ pub(super) fn node_style( if config.node_indices { Box::new(move |n| { - NodeStyle::Box(format!( + NodeStyle::boxed(format!( "({ni}) {name}", ni = n.index(), name = node_name(h, n) )) }) } else { - Box::new(move |n| NodeStyle::Box(node_name(h, n))) + Box::new(move |n| NodeStyle::boxed(node_name(h, n))) } } @@ -64,7 +64,7 @@ pub(super) fn port_style( let graph = h.portgraph(); Box::new(move |port| { let node = graph.port_node(port).unwrap(); - let optype = h.get_optype(h.get_node(node)); + let optype = h.get_optype(h.from_portgraph_node(node)); let offset = graph.port_offset(port).unwrap(); match optype.port_kind(offset).unwrap() { EdgeKind::Function(pf) => PortStyle::new(html_escape::encode_text(&format!("{}", pf))), @@ -95,7 +95,7 @@ pub(super) fn edge_style( let graph = h.portgraph(); Box::new(move |src, tgt| { let src_node = graph.port_node(src).unwrap(); - let src_optype = h.get_optype(h.get_node(src_node)); + let src_optype = h.get_optype(h.from_portgraph_node(src_node)); let src_offset = graph.port_offset(src).unwrap(); let tgt_offset = graph.port_offset(tgt).unwrap(); diff --git a/hugr-core/src/hugr/views/sibling.rs 
b/hugr-core/src/hugr/views/sibling.rs index ac31d2695..44e29ab1a 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -51,15 +51,19 @@ pub struct SiblingGraph<'g, Root = Node> { macro_rules! impl_base_members { () => { #[inline] - fn node_count(&self) -> usize { - self.base_hugr() - .hierarchy - .child_count(self.get_pg_index(self.root)) + fn root(&self) -> Self::Node { + self.root + } + + #[inline] + fn num_nodes(&self) -> usize { + self.hierarchy() + .child_count(self.to_portgraph_node(self.root)) + 1 } #[inline] - fn edge_count(&self) -> usize { + fn num_edges(&self) -> usize { // Faster implementation than filtering all the nodes in the internal graph. self.nodes() .map(|n| self.output_neighbours(n).count()) @@ -70,10 +74,9 @@ macro_rules! impl_base_members { fn nodes(&self) -> impl Iterator + Clone { // Faster implementation than filtering all the nodes in the internal graph. let children = self - .base_hugr() - .hierarchy - .children(self.get_pg_index(self.root)) - .map(|n| self.get_node(n)); + .hierarchy() + .children(self.to_portgraph_node(self.root)) + .map(|n| self.from_portgraph_node(n)); iter::once(self.root).chain(children) } @@ -83,10 +86,41 @@ macro_rules! impl_base_members { ) -> impl DoubleEndedIterator + Clone { // Same as SiblingGraph let children = match node == self.root { - true => self.base_hugr().hierarchy.children(self.get_pg_index(node)), + true => self.hierarchy().children(self.to_portgraph_node(node)), false => portgraph::hierarchy::Children::default(), }; - children.map(|n| self.get_node(n)) + children.map(|n| self.from_portgraph_node(n)) + } + + fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType { + self.hugr.get_optype(node) + } + + fn extensions(&self) -> &crate::extension::ExtensionRegistry { + self.hugr.extensions() + } + + fn get_parent(&self, node: Self::Node) -> Option { + match self.hugr.get_parent(node) { + Some(parent) if parent == self.root => Some(self.root), + _ => None, + } + } + + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone { + if node == self.root { + Either::Left(self.hugr.descendants(node)) + } else { + Either::Right(iter::empty()) + } + } + + delegate::delegate! 
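// [Illustrative sketch, not part of this patch] The `descendants` override
// above returns different iterator types from its two branches and unifies
// them with `itertools::Either`, which implements `Iterator` whenever both
// sides do. The same pattern in isolation (all names made up):
//
//     use itertools::Either;
//
//     fn evens_up_to(limit: Option<u32>) -> impl Iterator<Item = u32> {
//         match limit {
//             Some(n) => Either::Left((0..n).filter(|x| x % 2 == 0)),
//             None => Either::Right(std::iter::empty()),
//         }
//     }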
{ + to (&self.hugr) { + fn mermaid_string(&self) -> String; + fn mermaid_string_with_config(&self, config: crate::hugr::views::render::RenderConfig) -> String; + fn dot_string(&self) -> String; + } } }; } @@ -96,20 +130,20 @@ impl HugrView for SiblingGraph<'_, Root> { #[inline] fn contains_node(&self, node: Node) -> bool { - self.graph.contains_node(self.get_pg_index(node)) + self.graph.contains_node(self.to_portgraph_node(node)) } #[inline] fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph - .port_offsets(self.get_pg_index(node), dir) + .port_offsets(self.to_portgraph_node(node), dir) .map_into() } #[inline] fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { self.graph - .all_port_offsets(self.get_pg_index(node)) + .all_port_offsets(self.to_portgraph_node(node)) .map_into() } @@ -120,47 +154,52 @@ impl HugrView for SiblingGraph<'_, Root> { ) -> impl Iterator + Clone { let port = self .graph - .port_index(self.get_pg_index(node), port.into().pg_offset()) + .port_index(self.to_portgraph_node(node), port.into().pg_offset()) .unwrap(); self.graph.port_links(port).map(|(_, link)| { let node = self.graph.port_node(link).unwrap(); let offset = self.graph.port_offset(link).unwrap(); - (self.get_node(node), offset.into()) + (self.from_portgraph_node(node), offset.into()) }) } fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { self.graph - .get_connections(self.get_pg_index(node), self.get_pg_index(other)) + .get_connections(self.to_portgraph_node(node), self.to_portgraph_node(other)) .map(|(p1, p2)| [p1, p2].map(|link| self.graph.port_offset(link).unwrap().into())) } #[inline] fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.graph.num_ports(self.get_pg_index(node), dir) + self.graph.num_ports(self.to_portgraph_node(node), dir) } #[inline] fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph - .neighbours(self.get_pg_index(node), dir) - .map(|n| self.get_node(n)) + .neighbours(self.to_portgraph_node(node), dir) + .map(|n| self.from_portgraph_node(n)) } #[inline] fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { self.graph - .all_neighbours(self.get_pg_index(node)) - .map(|n| self.get_node(n)) + .all_neighbours(self.to_portgraph_node(node)) + .map(|n| self.from_portgraph_node(n)) } } impl<'a, Root: NodeHandle> SiblingGraph<'a, Root> { fn new_unchecked(hugr: &'a impl HugrView, root: Node) -> Self { + #[allow(deprecated)] let hugr = hugr.base_hugr(); Self { root, - graph: FlatRegionGraph::new(&hugr.graph, &hugr.hierarchy, hugr.get_pg_index(root)), + graph: FlatRegionGraph::new_with_root( + &hugr.graph, + &hugr.hierarchy, + hugr.to_portgraph_node(root), + ), hugr, _phantom: std::marker::PhantomData, } @@ -173,7 +212,7 @@ where { fn try_new(hugr: &'a impl HugrView, root: Node) -> Result { assert!( - hugr.valid_node(root), + hugr.contains_node(root), "Cannot create a sibling graph from an invalid node {}.", root ); @@ -200,23 +239,34 @@ where } #[inline] - fn base_hugr(&self) -> &Hugr { - self.hugr + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion< + '_, + impl portgraph::view::LinkView + Clone + '_, + > { + self.hugr.region_portgraph(parent) } #[inline] - fn root_node(&self) -> Node { - self.root + fn hierarchy(&self) -> &portgraph::Hierarchy { + self.hugr.hierarchy() } #[inline] - fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { - self.hugr.get_pg_index(node) + fn base_hugr(&self) -> &Hugr { + 
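// [Illustrative sketch, not part of this patch] `to_portgraph_node` and
// `from_portgraph_node` (the renamed `get_pg_index` / `get_node`) translate
// between HUGR nodes and the underlying `portgraph::NodeIndex`, and for any
// node contained in the view they are expected to be inverses. That property
// is what lets the iterator adaptors above map portgraph indices back into
// HUGR nodes. Sketch, with `view` any `HugrInternals` value and `n` a node it
// contains:
//
//     let idx: portgraph::NodeIndex = view.to_portgraph_node(n);
//     assert_eq!(view.from_portgraph_node(idx), n);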
self.hugr } #[inline] - fn get_node(&self, index: portgraph::NodeIndex) -> Node { - self.hugr.get_node(index) + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + self.hugr.to_portgraph_node(node) + } + + #[inline] + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Node { + self.hugr.from_portgraph_node(index) } #[inline] @@ -272,37 +322,50 @@ impl<'g, H: HugrMut, Root: NodeHandle> HugrInternals for SiblingMut<'g, #[inline] fn portgraph(&self) -> Self::Portgraph<'_> { - FlatRegionGraph::new( + FlatRegionGraph::new_with_root( + #[allow(deprecated)] &self.base_hugr().graph, - &self.base_hugr().hierarchy, - self.get_pg_index(self.root), + self.hierarchy(), + self.to_portgraph_node(self.root), ) } #[inline] - fn base_hugr(&self) -> &Hugr { - self.hugr.base_hugr() + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion< + '_, + impl portgraph::view::LinkView + Clone + '_, + > { + self.hugr.region_portgraph(parent) } #[inline] - fn root_node(&self) -> Self::Node { - self.root + fn hierarchy(&self) -> &portgraph::Hierarchy { + self.hugr.hierarchy() } #[inline] - fn get_pg_index(&self, node: impl NodeHandle) -> portgraph::NodeIndex { - self.hugr.get_pg_index(node) + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + self.hugr.to_portgraph_node(node) } #[inline] - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node { - self.hugr.get_node(index) + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node { + self.hugr.from_portgraph_node(index) } #[inline] fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap { self.hugr.node_metadata_map(node) } + + #[inline] + fn base_hugr(&self) -> &Hugr { + #[allow(deprecated)] + self.hugr.base_hugr() + } } impl> HugrView for SiblingMut<'_, H, Root> { @@ -435,7 +498,7 @@ mod test { { let def_io = region.get_io(def).unwrap(); - assert_eq!(region.node_count(), 5); + assert_eq!(region.num_nodes(), 5); assert_eq!(region.portgraph().node_count(), 5); assert!(region.nodes().all(|n| n == def || hugr.get_parent(n) == Some(def) @@ -455,8 +518,8 @@ mod test { inner_region.inner_function_type().map(Cow::into_owned), Some(Signature::new(vec![usize_t()], vec![usize_t()])) ); - assert_eq!(inner_region.node_count(), 3); - assert_eq!(inner_region.edge_count(), 1); + assert_eq!(inner_region.num_nodes(), 3); + assert_eq!(inner_region.num_edges(), 1); assert_eq!(inner_region.children(inner).count(), 2); assert_eq!(inner_region.children(hugr.root()).count(), 0); assert_eq!( @@ -591,8 +654,8 @@ mod test { let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?; - assert_eq!(region.node_count(), extracted.node_count()); - assert_eq!(region.root_type(), extracted.root_type()); + assert_eq!(region.num_nodes(), extracted.num_nodes()); + assert_eq!(region.root_optype(), extracted.root_optype()); Ok(()) } diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index 680d58a03..7fd2b9f54 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -194,7 +194,7 @@ impl SiblingSubgraph { let subpg = Subgraph::new_subgraph(pg.clone(), make_boundary(hugr, &inputs, &outputs)); let nodes = subpg .nodes_iter() - .map(|index| hugr.get_node(index)) + .map(|index| hugr.from_portgraph_node(index)) .collect_vec(); validate_subgraph(hugr, &nodes, &inputs, &outputs)?; @@ -525,7 +525,7 @@ fn make_boundary<'a, N: HugrNode>( ) -> Boundary { let to_pg_index = |n: 
N, p: Port| { hugr.portgraph() - .port_index(hugr.get_pg_index(n), p.pg_offset()) + .port_index(hugr.to_portgraph_node(n), p.pg_offset()) .unwrap() }; Boundary::new( @@ -1010,9 +1010,9 @@ mod tests { assert_eq!(rep.subgraph().nodes().len(), 4); - assert_eq!(hugr.node_count(), 8); // Module + Def + In + CX + Rz + Const + LoadConst + Out + assert_eq!(hugr.num_nodes(), 8); // Module + Def + In + CX + Rz + Const + LoadConst + Out hugr.apply_patch(rep).unwrap(); - assert_eq!(hugr.node_count(), 4); // Module + Def + In + Out + assert_eq!(hugr.num_nodes(), 4); // Module + Def + In + Out Ok(()) } diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 794e6eaaa..6aad904cb 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -364,7 +364,7 @@ pub enum ConstTypeError { /// Hugrs (even functions) inside Consts must be monomorphic fn mono_fn_type(h: &Hugr) -> Result, ConstTypeError> { let err = || ConstTypeError::NotMonomorphicFunction { - hugr_root_type: h.root_type().clone(), + hugr_root_type: h.root_optype().clone(), }; if let Some(pf) = h.poly_func_type() { match pf.try_into() { diff --git a/hugr-llvm/src/emit/ops.rs b/hugr-llvm/src/emit/ops.rs index 9cb6f9b10..76bb2bb09 100644 --- a/hugr-llvm/src/emit/ops.rs +++ b/hugr-llvm/src/emit/ops.rs @@ -5,7 +5,6 @@ use hugr_core::ops::{ }; use hugr_core::Node; use hugr_core::{ - hugr::views::SiblingGraph, types::{SumType, Type, TypeEnum}, HugrView, NodeIndex, }; @@ -71,34 +70,33 @@ where debug_assert!(i.out_value_types().count() == self.inputs.as_ref().unwrap().len()); debug_assert!(o.in_value_types().count() == self.outputs.as_ref().unwrap().len()); - let region: SiblingGraph = node.try_new_hierarchy_view().unwrap(); - Topo::new(®ion.as_petgraph()) - .iter(®ion.as_petgraph()) - .filter(|x| (*x != node.node())) - .map(|x| node.hugr().fat_optype(x)) - .try_for_each(|node| { - let inputs_rmb = context.node_ins_rmb(node)?; - let inputs = inputs_rmb.read(context.builder(), [])?; - let outputs = context.node_outs_rmb(node)?.promise(); - match node.as_ref() { - OpType::Input(_) => { - let i = self.take_input()?; - outputs.finish(context.builder(), i) - } - OpType::Output(_) => { - let o = self.take_output()?; - o.finish(context.builder(), inputs) - } - _ => emit_optype( - context, - EmitOpArgs { - node, - inputs, - outputs, - }, - ), + let region_graph = node.hugr().region_portgraph(node.node()); + let topo = Topo::new(®ion_graph); + for n in topo.iter(®ion_graph) { + let node = node.hugr().fat_optype(node.hugr().from_portgraph_node(n)); + let inputs_rmb = context.node_ins_rmb(node)?; + let inputs = inputs_rmb.read(context.builder(), [])?; + let outputs = context.node_outs_rmb(node)?.promise(); + match node.as_ref() { + OpType::Input(_) => { + let i = self.take_input()?; + outputs.finish(context.builder(), i)?; } - }) + OpType::Output(_) => { + let o = self.take_output()?; + o.finish(context.builder(), inputs)?; + } + _ => emit_optype( + context, + EmitOpArgs { + node, + inputs, + outputs, + }, + )?, + } + } + Ok(()) } } diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__emit_hugr_call_indirect@pre-mem2reg@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__emit_hugr_call_indirect@pre-mem2reg@llvm14.snap index b3283ee1b..124f36b53 100644 --- a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__emit_hugr_call_indirect@pre-mem2reg@llvm14.snap +++ 
b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__emit_hugr_call_indirect@pre-mem2reg@llvm14.snap @@ -20,14 +20,14 @@ entry_block: ; preds = %alloca_block define i1 @_hl.main_unary.6(i1 %0) { alloca_block: %"0" = alloca i1, align 1 - %"7_0" = alloca i1, align 1 %"9_0" = alloca i1 (i1)*, align 8 + %"7_0" = alloca i1, align 1 %"10_0" = alloca i1, align 1 br label %entry_block entry_block: ; preds = %alloca_block - store i1 %0, i1* %"7_0", align 1 store i1 (i1)* @_hl.main_unary.6, i1 (i1)** %"9_0", align 8 + store i1 %0, i1* %"7_0", align 1 %"9_01" = load i1 (i1)*, i1 (i1)** %"9_0", align 8 %"7_02" = load i1, i1* %"7_0", align 1 %1 = call i1 %"9_01"(i1 %"7_02") @@ -42,17 +42,17 @@ define { i1, i1 } @_hl.main_binary.11(i1 %0, i1 %1) { alloca_block: %"0" = alloca i1, align 1 %"1" = alloca i1, align 1 + %"14_0" = alloca { i1, i1 } (i1, i1)*, align 8 %"12_0" = alloca i1, align 1 %"12_1" = alloca i1, align 1 - %"14_0" = alloca { i1, i1 } (i1, i1)*, align 8 %"15_0" = alloca i1, align 1 %"15_1" = alloca i1, align 1 br label %entry_block entry_block: ; preds = %alloca_block + store { i1, i1 } (i1, i1)* @_hl.main_binary.11, { i1, i1 } (i1, i1)** %"14_0", align 8 store i1 %0, i1* %"12_0", align 1 store i1 %1, i1* %"12_1", align 1 - store { i1, i1 } (i1, i1)* @_hl.main_binary.11, { i1, i1 } (i1, i1)** %"14_0", align 8 %"14_01" = load { i1, i1 } (i1, i1)*, { i1, i1 } (i1, i1)** %"14_0", align 8 %"12_02" = load i1, i1* %"12_0", align 1 %"12_13" = load i1, i1* %"12_1", align 1 diff --git a/hugr-llvm/src/utils/fat.rs b/hugr-llvm/src/utils/fat.rs index dec866b4e..5deeb4bf0 100644 --- a/hugr-llvm/src/utils/fat.rs +++ b/hugr-llvm/src/utils/fat.rs @@ -47,7 +47,7 @@ where /// Note that while we do check the type of the node's `get_optype`, we /// do not verify that it is actually equal to `ot`. pub fn new(hugr: &'hugr H, node: H::Node, #[allow(unused)] ot: &OT) -> Self { - assert!(hugr.valid_node(node)); + assert!(hugr.contains_node(node)); assert!(TryInto::<&OT>::try_into(hugr.get_optype(node)).is_ok()); // We don't actually check `ot == hugr.get_optype(node)` so as to not require OT: PartialEq` Self { @@ -63,7 +63,7 @@ where /// If the node is invalid, or if its `get_optype` is not `OT`, returns /// `None`. pub fn try_new(hugr: &'hugr H, node: H::Node) -> Option { - (hugr.valid_node(node)).then_some(())?; + (hugr.contains_node(node)).then_some(())?; Some(Self::new( hugr, node, @@ -99,7 +99,7 @@ impl<'hugr, H: HugrView + ?Sized> FatNode<'hugr, OpType, H, H::Node> { /// /// Panics if the node is not valid in the [Hugr]. 
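// [Illustrative note, not part of this patch] Across this series
// `HugrView::valid_node` / `valid_non_root` are removed in favour of the plain
// membership check `contains_node` (see the hugr_view_methods delegation
// changes above), so assertions like the ones in this file migrate as:
//
//     // before: assert!(hugr.valid_node(node));
//     assert!(hugr.contains_node(node));
//
// with root-exclusion handled separately (e.g. by the `check_valid_non_root`
// helper added in hugr-core) where it is still needed.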
pub fn new_optype(hugr: &'hugr H, node: H::Node) -> Self { - assert!(hugr.valid_node(node)); + assert!(hugr.contains_node(node)); FatNode::new(hugr, node, hugr.get_optype(node)) } diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index ff5cd93a5..3a296fc0b 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -160,7 +160,7 @@ fn test_big() { .unwrap(); let mut h = build.finish_hugr_with_outputs(to_int.outputs()).unwrap(); - assert_eq!(h.node_count(), 8); + assert_eq!(h.num_nodes(), 8); constant_fold_pass(&mut h); @@ -333,7 +333,7 @@ fn test_const_fold_to_nonfinite() { assert_fully_folded_with(&h0, |v| { v.get_custom_value::().unwrap().value() == 1.0 }); - assert_eq!(h0.node_count(), 5); + assert_eq!(h0.num_nodes(), 5); // HUGR computing 1.0 / 0.0 let mut build = DFGBuilder::new(noargfn(vec![float64_type()])).unwrap(); @@ -342,7 +342,7 @@ fn test_const_fold_to_nonfinite() { let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); let mut h1 = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); constant_fold_pass(&mut h1); - assert_eq!(h1.node_count(), 8); + assert_eq!(h1.num_nodes(), 8); } #[test] @@ -1362,7 +1362,7 @@ fn test_tail_loop_unknown() { constant_fold_pass(&mut h); // Must keep the loop, even though we know the output, in case the output doesn't happen - assert_eq!(h.node_count(), 12); + assert_eq!(h.num_nodes(), 12); let tl = h .nodes() .filter(|n| h.get_optype(*n).is_tail_loop()) diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index d92fed134..25f6cf798 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -145,6 +145,7 @@ impl DeadCodeElimPass { if let Some(res) = cache.get(&n) { return *res; } + #[allow(deprecated)] let res = match self.preserve_callback.as_ref()(h.base_hugr(), n) { PreserveNode::MustKeep => true, PreserveNode::CanRemoveIgnoringChildren => false, diff --git a/hugr-passes/src/force_order.rs b/hugr-passes/src/force_order.rs index ad40e2164..ec59ccefd 100644 --- a/hugr-passes/src/force_order.rs +++ b/hugr-passes/src/force_order.rs @@ -2,11 +2,7 @@ use std::{cmp::Reverse, collections::BinaryHeap, iter}; use hugr_core::{ - hugr::{ - hugrmut::HugrMut, - views::{DescendantsGraph, HierarchyView, SiblingGraph}, - HugrError, - }, + hugr::{hugrmut::HugrMut, HugrError}, ops::{NamedOp, OpTag, OpTrait}, types::EdgeKind, HugrView as _, Node, @@ -51,34 +47,42 @@ pub fn force_order_by_key, K: Ord>( root: Node, rank: impl Fn(&H, Node) -> K, ) -> Result<(), HugrError> { - let dataflow_parents = DescendantsGraph::::try_new(hugr, root)? 
- .nodes() + let dataflow_parents = hugr + .descendants(root) .filter(|n| hugr.get_optype(*n).tag() <= OpTag::DataflowParent) .collect_vec(); for dp in dataflow_parents { // we filter out the input and output nodes from the topological sort let [i, o] = hugr.get_io(dp).unwrap(); - let rank = |n| rank(hugr, n); - let sg = SiblingGraph::::try_new(hugr, dp)?; - let petgraph = NodeFiltered::from_fn(sg.as_petgraph(), |x| x != dp && x != i && x != o); - let ordered_nodes = ForceOrder::new(&petgraph, &rank) - .iter(&petgraph) - .filter(|&x| { - let expected_edge = Some(EdgeKind::StateOrder); - let optype = hugr.get_optype(x); - if optype.other_input() == expected_edge || optype.other_output() == expected_edge { - assert_eq!( - optype.other_input(), - optype.other_output(), - "Optype does not have both input and output order edge: {}", - optype.name() - ); - true - } else { - false - } - }) - .collect_vec(); + let ordered_nodes = { + let rank = |n| rank(hugr, hugr.from_portgraph_node(n)); + let sg = hugr.region_portgraph(dp); + let petgraph = NodeFiltered::from_fn(&sg, |x| { + let x = hugr.from_portgraph_node(x); + x != dp && x != i && x != o + }); + ForceOrder::new(&petgraph, &rank) + .iter(&petgraph) + .map(|x| hugr.from_portgraph_node(x)) + .filter(|&x| { + let expected_edge = Some(EdgeKind::StateOrder); + let optype = hugr.get_optype(x); + if optype.other_input() == expected_edge + || optype.other_output() == expected_edge + { + assert_eq!( + optype.other_input(), + optype.other_output(), + "Optype does not have both input and output order edge: {}", + optype.name() + ); + true + } else { + false + } + }) + .collect_vec() + }; // we iterate over the topologically sorted nodes, prepending the input // node and suffixing the output node. diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 7e68e600a..403e3d84b 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -141,6 +141,6 @@ mod test { }); assert_eq!(lowered.unwrap().len(), 1); - assert_eq!(h.node_count(), 3); // DFG, input, output + assert_eq!(h.num_nodes(), 3); // DFG, input, output } } diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index 5c76ba51d..170ff3789 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -59,7 +59,7 @@ fn mk_rep( let succ_sig = succ_ty.inner_signature(); // Make a Hugr with just a single CFG root node having the same signature. 
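// [Illustrative note, not part of this patch] `HugrView::root_type` is renamed
// to `root_optype` in this series (it returns the root node's `OpType`, as in
// the hugr_view_methods delegation above), so call sites like the one below
// change mechanically:
//
//     // before: Hugr::new(cfg.root_type().clone())
//     let replacement: Hugr = Hugr::new(cfg.root_optype().clone());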
- let mut replacement: Hugr = Hugr::new(cfg.root_type().clone()); + let mut replacement: Hugr = Hugr::new(cfg.root_optype().clone()); let merged = replacement.add_node_with_parent(replacement.root(), { let mut merged_block = DataflowBlock { diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index d33234126..25249f5ae 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -135,7 +135,7 @@ impl NodeTemplate { ) -> Result<(), Option> { let sig = match self { NodeTemplate::SingleOp(op_type) => op_type, - NodeTemplate::CompoundOp(hugr) => hugr.root_type(), + NodeTemplate::CompoundOp(hugr) => hugr.root_optype(), NodeTemplate::Call(_, _) => return Ok(()), // no way to tell } .dataflow_signature(); @@ -1012,7 +1012,7 @@ mod test { // list -> read -> usz just becomes list -> read -> qb // list> -> read> -> opt becomes list -> get -> opt assert_eq!( - h.root_type().dataflow_signature().unwrap().io(), + h.root_optype().dataflow_signature().unwrap().io(), ( &vec![list_type(qb_t()); 2].into(), &vec![qb_t(), option_type(qb_t()).into()].into() From 2b8686b59bce58e8bea8d6411613b92b6d97488e Mon Sep 17 00:00:00 2001 From: Luca Mondada <72734770+lmondada@users.noreply.github.com> Date: Tue, 29 Apr 2025 17:16:58 +0200 Subject: [PATCH 15/21] chore: Remove stray rewrite.rs file (#2142) Oupsie, during one of the merge conflict resolutions I must have forgotten to remove the old `rewrite.rs` file. As you can see in the `hugr-core/src/hugr.rs` file, this is no longer a module and thus the file should be deleted. It has been renamed to `patch.rs` in #2070 --- hugr-core/src/hugr/rewrite.rs | 98 ----------------------------------- 1 file changed, 98 deletions(-) delete mode 100644 hugr-core/src/hugr/rewrite.rs diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs deleted file mode 100644 index 76dc93ab1..000000000 --- a/hugr-core/src/hugr/rewrite.rs +++ /dev/null @@ -1,98 +0,0 @@ -//! Rewrite operations on the HUGR - replacement, outlining, etc. - -pub mod consts; -pub mod inline_call; -pub mod inline_dfg; -pub mod insert_identity; -pub mod outline_cfg; -mod port_types; -pub mod replace; -pub mod simple_replace; - -use crate::core::HugrNode; -use crate::{Hugr, HugrView}; -pub use port_types::{BoundaryPort, HostPort, ReplacementPort}; -pub use simple_replace::{SimpleReplacement, SimpleReplacementError}; - -use super::HugrMut; - -/// An operation that can be applied to mutate a Hugr -pub trait Rewrite { - /// The node type used by the target Hugr. - type Node: HugrNode; - /// The type of Error with which this Rewrite may fail - type Error: std::error::Error; - /// The type returned on successful application of the rewrite. - type ApplyResult; - - /// If `true`, [self.apply]'s of this rewrite guarantee that they do not mutate the Hugr when they return an Err. - /// If `false`, there is no guarantee; the Hugr should be assumed invalid when Err is returned. - const UNCHANGED_ON_FAILURE: bool; - - /// Checks whether the rewrite would succeed on the specified Hugr. - /// If this call succeeds, [self.apply] should also succeed on the same `h` - /// If this calls fails, [self.apply] would fail with the same error. - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error>; - - /// Mutate the specified Hugr, or fail with an error. - /// Returns [`Self::ApplyResult`] if successful. - /// If [self.unchanged_on_failure] is true, then `h` must be unchanged if Err is returned. 
- /// See also [self.verify] - /// # Panics - /// May panic if-and-only-if `h` would have failed [Hugr::validate]; that is, - /// implementations may begin with `assert!(h.validate())`, with `debug_assert!(h.validate())` - /// being preferred. - fn apply( - self, - h: &mut impl HugrMut, - ) -> Result; - - /// Returns a set of nodes referenced by the rewrite. Modifying any of these - /// nodes will invalidate it. - /// - /// Two `impl Rewrite`s can be composed if their invalidation sets are - /// disjoint. - fn invalidation_set(&self) -> impl Iterator; -} - -/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure) -pub struct Transactional { - underlying: R, -} - -// Note we might like to constrain R to Rewrite but this -// is not yet supported, https://github.com/rust-lang/rust/issues/92827 -impl Rewrite for Transactional { - type Node = R::Node; - type Error = R::Error; - type ApplyResult = R::ApplyResult; - const UNCHANGED_ON_FAILURE: bool = true; - - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { - self.underlying.verify(h) - } - - fn apply(self, h: &mut impl HugrMut) -> Result { - if R::UNCHANGED_ON_FAILURE { - return self.underlying.apply(h); - } - // Try to backup just the contents of this HugrMut. - let mut backup = Hugr::new(h.root_optype().clone()); - backup.insert_from_view(backup.root(), h); - let r = self.underlying.apply(h); - if r.is_err() { - // Try to restore backup. - h.replace_op(h.root(), backup.root_optype().clone()); - while let Some(child) = h.first_child(h.root()) { - h.remove_node(child); - } - h.insert_from_view(h.root(), &backup); - } - r - } - - #[inline] - fn invalidation_set(&self) -> impl Iterator { - self.underlying.invalidation_set() - } -} From 3d82a1bb818491606608b5b0a39e0b34083fdb64 Mon Sep 17 00:00:00 2001 From: Kartik Singhal Date: Wed, 30 Apr 2025 07:51:35 -0500 Subject: [PATCH 16/21] chore(hugr-llvm): upgrade to inkwell 0.6.0 (#2128) Part of https://github.com/quantinuum-dev/hugrverse/issues/158 --- Cargo.lock | 8 ++++---- DEVELOPMENT.md | 8 ++++---- hugr-llvm/Cargo.toml | 2 +- hugr-llvm/README.md | 2 +- hugr-llvm/src/extension/collections/static_array.rs | 8 ++++++++ hugr-llvm/src/sum.rs | 2 ++ hugr-llvm/src/sum/layout.rs | 3 +++ 7 files changed, 23 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6b18c3101..f9ba0b100 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1458,9 +1458,9 @@ checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" [[package]] name = "inkwell" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40fb405537710d51f6bdbc8471365ddd4cd6d3a3c3ad6e0c8291691031ba94b2" +checksum = "e67349bd7578d4afebbe15eaa642a80b884e8623db74b1716611b131feb1deef" dependencies = [ "either", "inkwell_internals", @@ -1472,9 +1472,9 @@ dependencies = [ [[package]] name = "inkwell_internals" -version = "0.10.0" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dd28cfd4cfba665d47d31c08a6ba637eed16770abca2eccbbc3ca831fef1e44" +checksum = "f365c8de536236cfdebd0ba2130de22acefed18b1fb99c32783b3840aec5fb46" dependencies = [ "proc-macro2", "quote", diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 6d9465140..d9f19ed64 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -28,10 +28,10 @@ shell by setting up [direnv](https://devenv.sh/automatic-shell-activation/). 
To setup the environment manually you will need: -- Just: https://just.systems/ -- Rust `>=1.85`: https://www.rust-lang.org/tools/install -- uv `>=0.3`: docs.astral.sh/uv/getting-started/installation -- Optional: capnproto `>=1.0`: https://capnproto.org/install.html +- Just: +- Rust `>=1.85`: +- uv `>=0.3`: +- Optional: capnproto `>=1.0`: Required when modifying the `hugr-model` serialization schema. - Optional: llvm `== 14.0`. The "llvm" feature (backed by the sub-crate `hugr-llvm`) requires LLVM installed. We use the rust bindings diff --git a/hugr-llvm/Cargo.toml b/hugr-llvm/Cargo.toml index 1b288aa82..8f823c52a 100644 --- a/hugr-llvm/Cargo.toml +++ b/hugr-llvm/Cargo.toml @@ -23,7 +23,7 @@ llvm14-0 = ["inkwell/llvm14-0"] [dependencies] -inkwell = { version = "0.5.0", default-features = false } +inkwell = { version = "0.6.0", default-features = false } hugr-core = { path = "../hugr-core", version = "0.15.4" } anyhow = "1.0.98" itertools.workspace = true diff --git a/hugr-llvm/README.md b/hugr-llvm/README.md index 6d81cd35d..988a650dd 100644 --- a/hugr-llvm/README.md +++ b/hugr-llvm/README.md @@ -25,7 +25,7 @@ version will only change on major releases. ## Developing hugr-llvm -See [DEVELOPMENT](DEVELOPMENT.md) for instructions on setting up the development environment. +See [DEVELOPMENT](../DEVELOPMENT.md) for instructions on setting up the development environment. ## License diff --git a/hugr-llvm/src/extension/collections/static_array.rs b/hugr-llvm/src/extension/collections/static_array.rs index 7d3ac5f5c..7f59bff82 100644 --- a/hugr-llvm/src/extension/collections/static_array.rs +++ b/hugr-llvm/src/extension/collections/static_array.rs @@ -58,6 +58,7 @@ fn value_is_const<'c>(value: impl BasicValue<'c>) -> bool { BasicValueEnum::PointerValue(v) => v.is_const(), BasicValueEnum::StructValue(v) => v.is_const(), BasicValueEnum::VectorValue(v) => v.is_const(), + BasicValueEnum::ScalableVectorValue(v) => v.is_const(), } } @@ -109,6 +110,13 @@ fn const_array<'c>( .collect_vec() .as_slice(), ), + BasicTypeEnum::ScalableVectorType(t) => t.const_array( + values + .into_iter() + .map(|x| x.as_basic_value_enum().into_scalable_vector_value()) + .collect_vec() + .as_slice(), + ), } } diff --git a/hugr-llvm/src/sum.rs b/hugr-llvm/src/sum.rs index c2b9a0475..381e09469 100644 --- a/hugr-llvm/src/sum.rs +++ b/hugr-llvm/src/sum.rs @@ -47,6 +47,7 @@ fn basic_type_undef<'c>(t: impl BasicType<'c>) -> BasicValueEnum<'c> { BasicTypeEnum::PointerType(t) => t.get_undef().as_basic_value_enum(), BasicTypeEnum::StructType(t) => t.get_undef().as_basic_value_enum(), BasicTypeEnum::VectorType(t) => t.get_undef().as_basic_value_enum(), + BasicTypeEnum::ScalableVectorType(t) => t.get_undef().as_basic_value_enum(), } } @@ -60,6 +61,7 @@ fn basic_type_poison<'c>(t: impl BasicType<'c>) -> BasicValueEnum<'c> { BasicTypeEnum::PointerType(t) => t.get_poison().as_basic_value_enum(), BasicTypeEnum::StructType(t) => t.get_poison().as_basic_value_enum(), BasicTypeEnum::VectorType(t) => t.get_poison().as_basic_value_enum(), + BasicTypeEnum::ScalableVectorType(t) => t.get_poison().as_basic_value_enum(), } } diff --git a/hugr-llvm/src/sum/layout.rs b/hugr-llvm/src/sum/layout.rs index fd67a3240..d016de851 100644 --- a/hugr-llvm/src/sum/layout.rs +++ b/hugr-llvm/src/sum/layout.rs @@ -45,6 +45,9 @@ fn size_of_type<'c>(t: impl BasicType<'c>) -> Option { BasicTypeEnum::PointerType(t) => t.size_of().get_zero_extended_constant(), BasicTypeEnum::StructType(t) => t.size_of().and_then(|x| x.get_zero_extended_constant()), 
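// [Illustrative note, not part of this patch] inkwell 0.6 models LLVM scalable
// vectors as separate `ScalableVectorType` / `ScalableVectorValue` enum
// variants, so every exhaustive match over `BasicTypeEnum` / `BasicValueEnum`
// touched by this commit gains one more arm, mirroring the fixed-width vector
// arm, e.g.:
//
//     BasicTypeEnum::ScalableVectorType(t) => {
//         t.size_of().and_then(|x| x.get_zero_extended_constant())
//     }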
BasicTypeEnum::VectorType(t) => t.size_of().and_then(|x| x.get_zero_extended_constant()), + BasicTypeEnum::ScalableVectorType(t) => { + t.size_of().and_then(|x| x.get_zero_extended_constant()) + } } } From cd7ef68120b5b903b12ac2fcbbf5fae812e3e70f Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Thu, 1 May 2025 09:43:21 +0100 Subject: [PATCH 17/21] feat!: Removed runtime extension sets. (#2145) This PR removes (runtime) extension sets from hugr-core and hugr-py (see https://github.com/CQCL/hugr/issues/1906). BREAKING CHANGE: Functions that manipulate runtime extension sets have been removed from the Rust and Python code. Extension set parameters were removed from operations. Closes https://github.com/CQCL/hugr/issues/1906 --- .github/workflows/ci-rs.yml | 2 +- .pre-commit-config.yaml | 7 +- hugr-core/Cargo.toml | 1 - hugr-core/README.md | 4 - hugr-core/src/builder.rs | 17 +- hugr-core/src/builder/build_traits.rs | 82 +---- hugr-core/src/builder/cfg.rs | 115 +------ hugr-core/src/builder/circuit.rs | 56 ++- hugr-core/src/builder/conditional.rs | 38 +-- hugr-core/src/builder/dataflow.rs | 23 +- hugr-core/src/builder/module.rs | 5 +- hugr-core/src/builder/tail_loop.rs | 26 +- hugr-core/src/export.rs | 10 +- hugr-core/src/extension.rs | 73 +--- hugr-core/src/extension/declarative.rs | 8 +- .../src/extension/declarative/signature.rs | 7 +- hugr-core/src/extension/op_def.rs | 56 +-- hugr-core/src/extension/prelude.rs | 30 +- .../src/extension/prelude/unwrap_builder.rs | 6 +- hugr-core/src/extension/resolution.rs | 4 - hugr-core/src/extension/resolution/test.rs | 46 +-- hugr-core/src/extension/resolution/types.rs | 2 - .../src/extension/resolution/types_mut.rs | 2 - hugr-core/src/hugr.rs | 318 +----------------- hugr-core/src/hugr/hugrmut.rs | 4 +- hugr-core/src/hugr/internal.rs | 3 +- hugr-core/src/hugr/patch/consts.rs | 6 +- hugr-core/src/hugr/patch/inline_call.rs | 19 +- hugr-core/src/hugr/patch/inline_dfg.rs | 12 +- hugr-core/src/hugr/patch/outline_cfg.rs | 33 +- hugr-core/src/hugr/patch/replace.rs | 27 +- hugr-core/src/hugr/patch/simple_replace.rs | 13 +- hugr-core/src/hugr/serialize/test.rs | 20 +- hugr-core/src/hugr/serialize/upgrade/test.rs | 1 - hugr-core/src/hugr/validate.rs | 59 +--- hugr-core/src/hugr/validate/test.rs | 315 ++--------------- hugr-core/src/hugr/views.rs | 15 - hugr-core/src/hugr/views/descendants.rs | 14 +- hugr-core/src/hugr/views/impls.rs | 1 - hugr-core/src/hugr/views/sibling.rs | 7 +- hugr-core/src/hugr/views/sibling_subgraph.rs | 64 +--- hugr-core/src/import.rs | 8 +- hugr-core/src/ops.rs | 8 +- hugr-core/src/ops/constant.rs | 29 +- hugr-core/src/ops/constant/custom.rs | 48 +-- hugr-core/src/ops/controlflow.rs | 66 +--- hugr-core/src/ops/custom.rs | 6 +- hugr-core/src/ops/dataflow.rs | 8 +- hugr-core/src/package.rs | 3 - .../std_extensions/arithmetic/conversions.rs | 8 +- .../std_extensions/arithmetic/float_ops.rs | 3 +- .../std_extensions/arithmetic/float_types.rs | 6 +- .../src/std_extensions/arithmetic/int_ops.rs | 9 +- .../std_extensions/arithmetic/int_types.rs | 6 +- .../src/std_extensions/collections/array.rs | 7 +- .../collections/array/array_repeat.rs | 41 +-- .../collections/array/array_scan.rs | 69 +--- .../collections/array/op_builder.rs | 10 +- .../src/std_extensions/collections/list.rs | 7 +- .../collections/static_array.rs | 17 +- hugr-core/src/std_extensions/ptr.rs | 5 +- hugr-core/src/types/poly_func.rs | 1 - hugr-core/src/types/signature.rs | 37 +- hugr-core/src/types/type_param.rs | 47 +-- hugr-llvm/src/emit/ops/cfg.rs | 9 +- 
hugr-llvm/src/emit/test.rs | 27 +- hugr-llvm/src/extension/collections/array.rs | 41 +-- hugr-passes/Cargo.toml | 3 - hugr-passes/README.md | 12 +- hugr-passes/src/composable.rs | 33 +- hugr-passes/src/const_fold/test.rs | 5 +- hugr-passes/src/dataflow/test.rs | 12 +- hugr-passes/src/dead_code.rs | 6 +- hugr-passes/src/force_order.rs | 2 +- hugr-passes/src/lower.rs | 2 +- hugr-passes/src/merge_bbs.rs | 10 +- hugr-passes/src/monomorphize.rs | 43 +-- hugr-passes/src/nest_cfgs.rs | 8 +- hugr-passes/src/non_local.rs | 8 +- hugr-passes/src/replace_types.rs | 20 +- hugr-passes/src/replace_types/handlers.rs | 21 +- hugr-passes/src/replace_types/linearize.rs | 11 +- hugr-passes/src/untuple.rs | 40 +-- hugr-passes/src/validation.rs | 11 +- hugr-py/src/hugr/_serialization/extension.py | 6 +- hugr-py/src/hugr/_serialization/ops.py | 24 +- hugr-py/src/hugr/_serialization/tys.py | 34 +- hugr-py/src/hugr/ext.py | 16 - hugr-py/src/hugr/ops.py | 11 +- .../_json_defs/arithmetic/conversions.json | 40 +-- .../hugr/std/_json_defs/arithmetic/float.json | 60 ++-- .../_json_defs/arithmetic/float/types.json | 1 - .../hugr/std/_json_defs/arithmetic/int.json | 141 +++----- .../std/_json_defs/arithmetic/int/types.json | 1 - .../std/_json_defs/collections/array.json | 31 +- .../hugr/std/_json_defs/collections/list.json | 19 +- .../_json_defs/collections/static_array.json | 7 +- hugr-py/src/hugr/std/_json_defs/logic.json | 16 +- hugr-py/src/hugr/std/_json_defs/prelude.json | 25 +- hugr-py/src/hugr/std/_json_defs/ptr.json | 10 +- hugr-py/src/hugr/std/int.py | 2 +- hugr-py/src/hugr/tys.py | 58 +--- hugr-py/tests/serialization/test_extension.py | 6 +- hugr-py/tests/test_custom.py | 2 +- hugr-py/tests/test_tys.py | 4 - hugr/Cargo.toml | 1 - hugr/README.md | 7 +- hugr/benches/benchmarks/hugr/examples.rs | 9 +- justfile | 2 +- specification/hugr.md | 65 ---- specification/schema/hugr_schema_live.json | 81 ----- .../schema/hugr_schema_strict_live.json | 81 ----- .../schema/testing_hugr_schema_live.json | 81 ----- .../testing_hugr_schema_strict_live.json | 81 ----- .../arithmetic/conversions.json | 40 +-- .../std_extensions/arithmetic/float.json | 60 ++-- .../arithmetic/float/types.json | 1 - .../std_extensions/arithmetic/int.json | 141 +++----- .../std_extensions/arithmetic/int/types.json | 1 - .../std_extensions/collections/array.json | 31 +- .../std_extensions/collections/list.json | 19 +- .../collections/static_array.json | 7 +- specification/std_extensions/logic.json | 16 +- specification/std_extensions/prelude.json | 25 +- specification/std_extensions/ptr.json | 10 +- 125 files changed, 597 insertions(+), 3000 deletions(-) diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index 4fe5d244f..c6814fc60 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -108,7 +108,7 @@ jobs: - name: Override criterion with the CodSpeed harness run: cargo add --dev codspeed-criterion-compat --rename criterion --package hugr - name: Build benchmarks - run: cargo codspeed build --profile bench --features extension_inference,declarative,llvm,llvm-test + run: cargo codspeed build --profile bench --features declarative,llvm,llvm-test - name: Run benchmarks uses: CodSpeedHQ/action@v3 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4fe582d93..b6e481bd5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -79,7 +79,7 @@ repos: # built into a binary build (without using `maturin`) # # This feature list should be kept in sync with the `hugr-py/pyproject.toml` - 
entry: cargo test --workspace --exclude 'hugr-py' --features 'hugr/extension_inference hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' + entry: cargo test --workspace --exclude 'hugr-py' --features 'hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' language: system files: \.rs$ pass_filenames: false @@ -100,10 +100,7 @@ repos: - id: py-test name: pytest description: Run python tests - # We need to rebuild `hugr-cli` without the `extension_inference` feature - # to avoid test errors. - # TODO: Remove this once the issue is fixed. - entry: sh -c "cargo build -p hugr-cli && uv run pytest" + entry: sh -c "uv run pytest" language: system files: \.py$ pass_filenames: false diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 49417097a..653e47b73 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -17,7 +17,6 @@ categories = ["compilers"] workspace = true [features] -extension_inference = [] declarative = ["serde_yaml"] zstd = ["dep:zstd"] diff --git a/hugr-core/README.md b/hugr-core/README.md index 765d4577b..0e15305f1 100644 --- a/hugr-core/README.md +++ b/hugr-core/README.md @@ -14,10 +14,6 @@ Please read the [API documentation here][]. ## Experimental Features -- `extension_inference`: - Experimental feature which allows automatic inference of which extra extensions - are required at runtime by a HUGR when validating it. - Not enabled by default. - `declarative`: Experimental support for declaring extensions in YAML files, support is limited. diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index 056690e0a..9f7a219a7 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -42,7 +42,7 @@ //! let _dfg_handle = { //! let mut dfg = module_builder.define_function( //! "main", -//! Signature::new_endo(bool_t()).with_extension_delta(logic::EXTENSION_ID), +//! Signature::new_endo(bool_t()), //! )?; //! //! // Get the wires from the function inputs. @@ -59,8 +59,7 @@ //! let _circuit_handle = { //! let mut dfg = module_builder.define_function( //! "circuit", -//! Signature::new_endo(vec![bool_t(), bool_t()]) -//! .with_extension_delta(logic::EXTENSION_ID), +//! Signature::new_endo(vec![bool_t(), bool_t()]), //! )?; //! let mut circuit = dfg.as_circuit(dfg.input_wires()); //! @@ -89,7 +88,7 @@ use thiserror::Error; use crate::extension::simple_op::OpLoadError; -use crate::extension::{SignatureError, TO_BE_INFERRED}; +use crate::extension::SignatureError; use crate::hugr::ValidationError; use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID}; use crate::ops::{NamedOp, OpType}; @@ -123,16 +122,14 @@ pub use conditional::{CaseBuilder, ConditionalBuilder}; mod circuit; pub use circuit::{CircuitBuildError, CircuitBuilder}; -/// Return a FunctionType with the same input and output types (specified) -/// whose extension delta, when used in a non-FuncDefn container, will be inferred. +/// Return a FunctionType with the same input and output types (specified). pub fn endo_sig(types: impl Into) -> Signature { - Signature::new_endo(types).with_extension_delta(TO_BE_INFERRED) + Signature::new_endo(types) } -/// Return a FunctionType with the specified input and output types -/// whose extension delta, when used in a non-FuncDefn container, will be inferred. +/// Return a FunctionType with the specified input and output types. 
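// [Illustrative sketch, not part of this patch] With extension deltas gone,
// `endo_sig` and `inout_sig` are now thin wrappers over `Signature::new_endo`
// and `Signature::new`. Usage, reusing the prelude's `bool_t` from the module
// example above:
//
//     let same = endo_sig(bool_t());                      // bool_t -> bool_t
//     let conv = inout_sig(vec![bool_t(); 2], bool_t());  // (bool_t, bool_t) -> bool_t
//     assert_eq!(conv, Signature::new(vec![bool_t(); 2], bool_t()));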
pub fn inout_sig(inputs: impl Into, outputs: impl Into) -> Signature { - Signature::new(inputs, outputs).with_extension_delta(TO_BE_INFERRED) + Signature::new(inputs, outputs) } #[derive(Debug, Clone, PartialEq, Error)] diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index 58c15c54a..ba366c117 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -20,7 +20,7 @@ use crate::{ types::EdgeKind, }; -use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; +use crate::extension::ExtensionRegistry; use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeRow}; use itertools::Itertools; @@ -319,10 +319,7 @@ pub trait Dataflow: Container { inputs: impl IntoIterator, ) -> Result, BuildError> { let (types, input_wires): (Vec, Vec) = inputs.into_iter().unzip(); - self.dfg_builder( - Signature::new_endo(types).with_extension_delta(TO_BE_INFERRED), - input_wires, - ) + self.dfg_builder(Signature::new_endo(types), input_wires) } /// Return a builder for a [`crate::ops::CFG`] node, @@ -330,7 +327,6 @@ pub trait Dataflow: Container { /// The `inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. /// The `output_types` are the types of the outputs. - /// The Extension delta will be inferred. /// /// # Errors /// @@ -340,27 +336,6 @@ pub trait Dataflow: Container { &mut self, inputs: impl IntoIterator, output_types: TypeRow, - ) -> Result, BuildError> { - self.cfg_builder_exts(inputs, output_types, TO_BE_INFERRED) - } - - /// Return a builder for a [`crate::ops::CFG`] node, - /// i.e. a nested controlflow subgraph. - /// The `inputs` must be an iterable over pairs of the type of the input and - /// the corresponding wire. - /// The `output_types` are the types of the outputs. - /// `extension_delta` is explicitly specified. Alternatively - /// [cfg_builder](Self::cfg_builder) may be used to infer it. - /// - /// # Errors - /// - /// This function will return an error if there is an error when building - /// the CFG node. - fn cfg_builder_exts( - &mut self, - inputs: impl IntoIterator, - output_types: TypeRow, - extension_delta: impl Into, ) -> Result, BuildError> { let (input_types, input_wires): (Vec, Vec) = inputs.into_iter().unzip(); @@ -369,8 +344,7 @@ pub trait Dataflow: Container { let (cfg_node, _) = add_node_with_wires( self, ops::CFG { - signature: Signature::new(inputs.clone(), output_types.clone()) - .with_extension_delta(extension_delta), + signature: Signature::new(inputs.clone(), output_types.clone()), }, input_wires, )?; @@ -449,7 +423,6 @@ pub trait Dataflow: Container { /// The `inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. /// The `output_types` are the types of the outputs. - /// The extension delta will be inferred. /// /// # Errors /// @@ -461,27 +434,6 @@ pub trait Dataflow: Container { just_inputs: impl IntoIterator, inputs_outputs: impl IntoIterator, just_out_types: TypeRow, - ) -> Result, BuildError> { - self.tail_loop_builder_exts(just_inputs, inputs_outputs, just_out_types, TO_BE_INFERRED) - } - - /// Return a builder for a [`crate::ops::TailLoop`] node. - /// The `inputs` must be an iterable over pairs of the type of the input and - /// the corresponding wire. - /// The `output_types` are the types of the outputs. - /// `extension_delta` explicitly specified. Alternatively - /// [tail_loop_builder](Self::tail_loop_builder) may be used to infer it. 
- /// - /// # Errors - /// - /// This function will return an error if there is an error when building - /// the [`ops::TailLoop`] node. - fn tail_loop_builder_exts( - &mut self, - just_inputs: impl IntoIterator, - inputs_outputs: impl IntoIterator, - just_out_types: TypeRow, - extension_delta: impl Into, ) -> Result, BuildError> { let (input_types, mut input_wires): (Vec, Vec) = just_inputs.into_iter().unzip(); @@ -493,7 +445,6 @@ pub trait Dataflow: Container { just_inputs: input_types.into(), just_outputs: just_out_types, rest: rest_types.into(), - extension_delta: extension_delta.into(), }; // TODO: Make input extensions a parameter let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?; @@ -507,41 +458,17 @@ pub trait Dataflow: Container { /// /// The `other_inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. - /// The `output_types` are the types of the outputs. Extension delta will be inferred. - /// - /// # Errors - /// - /// This function will return an error if there is an error when building - /// the Conditional node. - fn conditional_builder( - &mut self, - sum_input: (impl IntoIterator, Wire), - other_inputs: impl IntoIterator, - output_types: TypeRow, - ) -> Result, BuildError> { - self.conditional_builder_exts(sum_input, other_inputs, output_types, TO_BE_INFERRED) - } - - /// Return a builder for a [`crate::ops::Conditional`] node. - /// `sum_rows` and `sum_wire` define the type of the Sum - /// variants and the wire carrying the Sum respectively. - /// - /// The `other_inputs` must be an iterable over pairs of the type of the input and - /// the corresponding wire. /// The `output_types` are the types of the outputs. - /// `extension_delta` is explicitly specified. Alternatively - /// [conditional_builder](Self::conditional_builder) may be used to infer it. /// /// # Errors /// /// This function will return an error if there is an error when building /// the Conditional node. 
- fn conditional_builder_exts( + fn conditional_builder( &mut self, (sum_rows, sum_wire): (impl IntoIterator, Wire), other_inputs: impl IntoIterator, output_types: TypeRow, - extension_delta: impl Into, ) -> Result, BuildError> { let mut input_wires = vec![sum_wire]; let (input_types, rest_input_wires): (Vec, Vec) = @@ -558,7 +485,6 @@ pub trait Dataflow: Container { sum_rows, other_inputs: inputs, outputs: output_types, - extension_delta: extension_delta.into(), }, input_wires, )?; diff --git a/hugr-core/src/builder/cfg.rs b/hugr-core/src/builder/cfg.rs index 81c7d7269..0aadc047b 100644 --- a/hugr-core/src/builder/cfg.rs +++ b/hugr-core/src/builder/cfg.rs @@ -5,9 +5,8 @@ use super::{ BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire, }; -use crate::extension::TO_BE_INFERRED; use crate::ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType}; -use crate::{extension::ExtensionSet, types::Signature}; +use crate::types::Signature; use crate::{hugr::views::HugrView, types::TypeRow}; use crate::Node; @@ -106,7 +105,6 @@ use crate::{hugr::HugrMut, type_row, Hugr}; /// let hugr = cfg_builder.finish_hugr()?; /// Ok(hugr) /// }; -/// #[cfg(not(feature = "extension_inference"))] /// assert!(make_cfg().is_ok()); /// ``` #[derive(Debug, PartialEq)] @@ -157,10 +155,7 @@ impl CFGBuilder { } impl HugrBuilder for CFGBuilder { - fn finish_hugr(mut self) -> Result { - if cfg!(feature = "extension_inference") { - self.base.infer_extensions(false)?; - } + fn finish_hugr(self) -> Result { self.base.validate()?; Ok(self.base) } @@ -192,7 +187,7 @@ impl + AsRef> CFGBuilder { /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` /// and `outputs` and the variants of the branching Sum value - /// specified by `sum_rows`. Extension delta will be inferred. + /// specified by `sum_rows`. /// /// # Errors /// @@ -203,36 +198,12 @@ impl + AsRef> CFGBuilder { sum_rows: impl IntoIterator, other_outputs: TypeRow, ) -> Result, BuildError> { - self.block_builder_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED) - } - - /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` - /// and `outputs` and the variants of the branching Sum value - /// specified by `sum_rows`. Extension delta will be inferred. - /// - /// # Errors - /// - /// This function will return an error if there is an error adding the node. - pub fn block_builder_exts( - &mut self, - inputs: TypeRow, - sum_rows: impl IntoIterator, - other_outputs: TypeRow, - extension_delta: impl Into, - ) -> Result, BuildError> { - self.any_block_builder( - inputs, - extension_delta.into(), - sum_rows, - other_outputs, - false, - ) + self.any_block_builder(inputs, sum_rows, other_outputs, false) } fn any_block_builder( &mut self, inputs: TypeRow, - extension_delta: ExtensionSet, sum_rows: impl IntoIterator, other_outputs: TypeRow, entry: bool, @@ -242,7 +213,6 @@ impl + AsRef> CFGBuilder { inputs: inputs.clone(), other_outputs: other_outputs.clone(), sum_rows, - extension_delta, }); let parent = self.container_node(); let block_n = if entry { @@ -257,9 +227,9 @@ impl + AsRef> CFGBuilder { BlockBuilder::create(self.hugr_mut(), block_n) } - /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` - /// and `outputs` and `extension_delta` explicitly specified, plus a UnitSum type - /// (a Sum of `n_cases` unit types) to select the successor. 
+ /// Return a builder for a non-entry [`DataflowBlock`] child graph with + /// `inputs` and `outputs` , plus a UnitSum type (a Sum of `n_cases` unit + /// types) to select the successor. /// /// # Errors /// @@ -269,17 +239,15 @@ impl + AsRef> CFGBuilder { signature: Signature, n_cases: usize, ) -> Result, BuildError> { - self.block_builder_exts( + self.block_builder( signature.input, vec![type_row![]; n_cases], signature.output, - signature.runtime_reqs, ) } /// Return a builder for the entry [`DataflowBlock`] child graph with `outputs` /// and the variants of the branching Sum value specified by `sum_rows`. - /// Extension delta will be inferred. /// /// # Errors /// @@ -288,35 +256,12 @@ impl + AsRef> CFGBuilder { &mut self, sum_rows: impl IntoIterator, other_outputs: TypeRow, - ) -> Result, BuildError> { - self.entry_builder_exts(sum_rows, other_outputs, TO_BE_INFERRED) - } - - /// Return a builder for the entry [`DataflowBlock`] child graph with `outputs`, - /// the variants of the branching Sum value specified by `sum_rows`, and - /// `extension_delta` explicitly specified. ([entry_builder](Self::entry_builder) - /// may be used to infer.) - /// - /// # Errors - /// - /// This function will return an error if an entry block has already been built. - pub fn entry_builder_exts( - &mut self, - sum_rows: impl IntoIterator, - other_outputs: TypeRow, - extension_delta: impl Into, ) -> Result, BuildError> { let inputs = self .inputs .take() .ok_or(BuildError::EntryBuiltError(self.cfg_node))?; - self.any_block_builder( - inputs, - extension_delta.into(), - sum_rows, - other_outputs, - true, - ) + self.any_block_builder(inputs, sum_rows, other_outputs, true) } /// Return a builder for the entry [`DataflowBlock`] child graph with @@ -333,22 +278,6 @@ impl + AsRef> CFGBuilder { self.entry_builder(vec![type_row![]; n_cases], outputs) } - /// Return a builder for the entry [`DataflowBlock`] child graph with - /// `outputs` and a Sum of `n_cases` unit types, and explicit `extension_delta`. - /// ([simple_entry_builder](Self::simple_entry_builder) may be used to infer.) - /// - /// # Errors - /// - /// This function will return an error if there is an error adding the node. - pub fn simple_entry_builder_exts( - &mut self, - outputs: TypeRow, - n_cases: usize, - extension_delta: impl Into, - ) -> Result, BuildError> { - self.entry_builder_exts(vec![type_row![]; n_cases], outputs, extension_delta) - } - /// Returns the exit block of this [`CFGBuilder`]. pub fn exit_block(&self) -> BasicBlockID { self.exit_node.into() @@ -412,23 +341,10 @@ impl + AsRef> BlockBuilder { impl BlockBuilder { /// Initialize a [`DataflowBlock`] rooted HUGR builder. - /// Extension delta will be inferred. pub fn new( inputs: impl Into, sum_rows: impl IntoIterator, other_outputs: impl Into, - ) -> Result { - Self::new_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED) - } - - /// Initialize a [`DataflowBlock`] rooted HUGR builder. - /// `extension_delta` is explicitly specified; alternatively, [new](Self::new) - /// may be used to infer it. 
- pub fn new_exts( - inputs: impl Into, - sum_rows: impl IntoIterator, - other_outputs: impl Into, - extension_delta: impl Into, ) -> Result { let inputs = inputs.into(); let sum_rows: Vec<_> = sum_rows.into_iter().collect(); @@ -437,7 +353,6 @@ impl BlockBuilder { inputs: inputs.clone(), other_outputs: other_outputs.clone(), sum_rows, - extension_delta: extension_delta.into(), }; let base = Hugr::new(op); @@ -507,11 +422,7 @@ pub(crate) mod test { ) -> Result<(), BuildError> { let usize_row: TypeRow = vec![usize_t()].into(); let sum2_variants = vec![usize_row.clone(), usize_row]; - let mut entry_b = cfg_builder.entry_builder_exts( - sum2_variants.clone(), - type_row![], - ExtensionSet::new(), - )?; + let mut entry_b = cfg_builder.entry_builder(sum2_variants.clone(), type_row![])?; let entry = { let [inw] = entry_b.input_wires_arr(); @@ -537,11 +448,7 @@ pub(crate) mod test { let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum()); let sum_variants = vec![type_row![]]; - let mut entry_b = cfg_builder.entry_builder_exts( - sum_variants.clone(), - type_row![], - ExtensionSet::new(), - )?; + let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![])?; let [inw] = entry_b.input_wires_arr(); let entry = { let sum = entry_b.load_const(&sum_tuple_const); diff --git a/hugr-core/src/builder/circuit.rs b/hugr-core/src/builder/circuit.rs index 01f5e3e45..eb48b4fbc 100644 --- a/hugr-core/src/builder/circuit.rs +++ b/hugr-core/src/builder/circuit.rs @@ -245,8 +245,8 @@ mod test { use crate::builder::{Container, HugrBuilder, ModuleBuilder}; use crate::extension::prelude::{qb_t, usize_t}; - use crate::extension::{ExtensionId, ExtensionSet}; - use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; + use crate::extension::ExtensionId; + use crate::std_extensions::arithmetic::float_types::ConstF64; use crate::utils::test_quantum_extension::{ self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64, }; @@ -260,10 +260,7 @@ mod test { #[test] fn simple_linear() { let build_res = build_main( - Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]) - .with_extension_delta(test_quantum_extension::EXTENSION_ID) - .with_extension_delta(float_types::EXTENSION_ID) - .into(), + Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]).into(), |mut f_build| { let wires = f_build.input_wires().map(Some).collect(); @@ -314,11 +311,7 @@ mod test { Signature::new( vec![qb_t(), qb_t(), usize_t()], vec![qb_t(), qb_t(), bool_t()], - ) - .with_extension_delta(ExtensionSet::from_iter([ - test_quantum_extension::EXTENSION_ID, - my_ext_name, - ])), + ), ) .unwrap(); @@ -351,38 +344,33 @@ mod test { #[test] fn ancillae() { - let build_res = build_main( - Signature::new_endo(qb_t()) - .with_extension_delta(test_quantum_extension::EXTENSION_ID) - .into(), - |mut f_build| { - let mut circ = f_build.as_circuit(f_build.input_wires()); - assert_eq!(circ.n_wires(), 1); + let build_res = build_main(Signature::new_endo(qb_t()).into(), |mut f_build| { + let mut circ = f_build.as_circuit(f_build.input_wires()); + assert_eq!(circ.n_wires(), 1); - let [q0] = circ.tracked_units_arr(); - let [ancilla] = circ.append_with_outputs_arr(q_alloc(), [] as [CircuitUnit; 0])?; - let ancilla = circ.track_wire(ancilla); + let [q0] = circ.tracked_units_arr(); + let [ancilla] = circ.append_with_outputs_arr(q_alloc(), [] as [CircuitUnit; 0])?; + let ancilla = circ.track_wire(ancilla); - assert_ne!(ancilla, 0); - assert_eq!(circ.n_wires(), 2); - assert_eq!(circ.tracked_units_arr(), [q0, ancilla]); 
+ assert_ne!(ancilla, 0); + assert_eq!(circ.n_wires(), 2); + assert_eq!(circ.tracked_units_arr(), [q0, ancilla]); - circ.append(cx_gate(), [q0, ancilla])?; - let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?; + circ.append(cx_gate(), [q0, ancilla])?; + let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?; - let q0 = circ.untrack_wire(q0)?; + let q0 = circ.untrack_wire(q0)?; - assert_eq!(circ.tracked_units_arr(), [ancilla]); + assert_eq!(circ.tracked_units_arr(), [ancilla]); - circ.append_and_consume(q_discard(), [q0])?; + circ.append_and_consume(q_discard(), [q0])?; - let outs = circ.finish(); + let outs = circ.finish(); - assert_eq!(outs.len(), 1); + assert_eq!(outs.len(), 1); - f_build.finish_with_outputs(outs) - }, - ); + f_build.finish_with_outputs(outs) + }); assert_matches!(build_res, Ok(_)); } diff --git a/hugr-core/src/builder/conditional.rs b/hugr-core/src/builder/conditional.rs index 0404abaf3..73670526c 100644 --- a/hugr-core/src/builder/conditional.rs +++ b/hugr-core/src/builder/conditional.rs @@ -1,6 +1,4 @@ -use crate::extension::TO_BE_INFERRED; use crate::hugr::views::HugrView; -use crate::ops::dataflow::DataflowOpTrait; use crate::types::{Signature, TypeRow}; use crate::ops; @@ -16,7 +14,7 @@ use super::{ }; use crate::Node; -use crate::{extension::ExtensionSet, hugr::HugrMut, Hugr}; +use crate::{hugr::HugrMut, Hugr}; use std::collections::HashSet; @@ -107,7 +105,6 @@ impl + AsRef> ConditionalBuilder { .clone() .try_into() .expect("Parent node does not have Conditional optype."); - let extension_delta = cond.signature().runtime_reqs.clone(); let inputs = cond .case_input_row(case) .ok_or(ConditionalBuildError::NotCase { conditional, case })?; @@ -118,8 +115,7 @@ impl + AsRef> ConditionalBuilder { let outputs = cond.outputs; let case_op = ops::Case { - signature: Signature::new(inputs.clone(), outputs.clone()) - .with_extension_delta(extension_delta.clone()), + signature: Signature::new(inputs.clone(), outputs.clone()), }; let case_node = // add case before any existing subsequent cases @@ -134,7 +130,7 @@ impl + AsRef> ConditionalBuilder { let dfg_builder = DFGBuilder::create_with_io( self.hugr_mut(), case_node, - Signature::new(inputs, outputs).with_extension_delta(extension_delta), + Signature::new(inputs, outputs), )?; Ok(CaseBuilder::from_dfg_builder(dfg_builder)) @@ -142,33 +138,18 @@ impl + AsRef> ConditionalBuilder { } impl HugrBuilder for ConditionalBuilder { - fn finish_hugr(mut self) -> Result { - if cfg!(feature = "extension_inference") { - self.base.infer_extensions(false)?; - } + fn finish_hugr(self) -> Result { self.base.validate()?; Ok(self.base) } } impl ConditionalBuilder { - /// Initialize a Conditional rooted HUGR builder, extension delta will be inferred. + /// Initialize a Conditional rooted HUGR builder. pub fn new( sum_rows: impl IntoIterator, other_inputs: impl Into, outputs: impl Into, - ) -> Result { - Self::new_exts(sum_rows, other_inputs, outputs, TO_BE_INFERRED) - } - - /// Initialize a Conditional rooted HUGR builder, - /// `extension_delta` explicitly specified. Alternatively, - /// [new](Self::new) may be used to infer it. 
- pub fn new_exts( - sum_rows: impl IntoIterator, - other_inputs: impl Into, - outputs: impl Into, - extension_delta: impl Into, ) -> Result { let sum_rows: Vec<_> = sum_rows.into_iter().collect(); let other_inputs = other_inputs.into(); @@ -181,7 +162,6 @@ impl ConditionalBuilder { sum_rows, other_inputs, outputs, - extension_delta: extension_delta.into(), }; let base = Hugr::new(op); let conditional_node = base.root(); @@ -225,12 +205,8 @@ mod test { #[test] fn basic_conditional() -> Result<(), BuildError> { - let mut conditional_b = ConditionalBuilder::new_exts( - [type_row![], type_row![]], - vec![usize_t()], - vec![usize_t()], - ExtensionSet::new(), - )?; + let mut conditional_b = + ConditionalBuilder::new([type_row![], type_row![]], vec![usize_t()], vec![usize_t()])?; n_identity(conditional_b.case_builder(1)?)?; n_identity(conditional_b.case_builder(0)?)?; diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index b84f3a05a..4e66f857f 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -82,10 +82,7 @@ impl DFGBuilder { } impl HugrBuilder for DFGBuilder { - fn finish_hugr(mut self) -> Result { - if cfg!(feature = "extension_inference") { - self.base.infer_extensions(false)?; - } + fn finish_hugr(self) -> Result { self.base.validate()?; Ok(self.base) } @@ -418,19 +415,15 @@ pub(crate) mod test { #[test] fn simple_inter_graph_edge() { let builder = || -> Result { - let mut f_build = FunctionBuilder::new( - "main", - Signature::new(vec![bool_t()], vec![bool_t()]).with_prelude(), - )?; + let mut f_build = + FunctionBuilder::new("main", Signature::new(vec![bool_t()], vec![bool_t()]))?; let [i1] = f_build.input_wires_arr(); let noop = f_build.add_dataflow_op(Noop(bool_t()), [i1])?; let i1 = noop.out_wire(0); - let mut nested = f_build.dfg_builder( - Signature::new(type_row![], vec![bool_t()]).with_prelude(), - [], - )?; + let mut nested = + f_build.dfg_builder(Signature::new(type_row![], vec![bool_t()]), [])?; let id = nested.add_dataflow_op(Noop(bool_t()), [i1])?; @@ -445,10 +438,8 @@ pub(crate) mod test { #[test] fn add_inputs_outputs() { let builder = || -> Result<(Hugr, Node), BuildError> { - let mut f_build = FunctionBuilder::new( - "main", - Signature::new(vec![bool_t()], vec![bool_t()]).with_prelude(), - )?; + let mut f_build = + FunctionBuilder::new("main", Signature::new(vec![bool_t()], vec![bool_t()]))?; let f_node = f_build.container_node(); let [i0] = f_build.input_wires_arr(); diff --git a/hugr-core/src/builder/module.rs b/hugr-core/src/builder/module.rs index 1387c1ec5..a77f01e5f 100644 --- a/hugr-core/src/builder/module.rs +++ b/hugr-core/src/builder/module.rs @@ -50,10 +50,7 @@ impl Default for ModuleBuilder { } impl HugrBuilder for ModuleBuilder { - fn finish_hugr(mut self) -> Result { - if cfg!(feature = "extension_inference") { - self.0.infer_extensions(false)?; - } + fn finish_hugr(self) -> Result { self.0.validate()?; Ok(self.0) } diff --git a/hugr-core/src/builder/tail_loop.rs b/hugr-core/src/builder/tail_loop.rs index fd6fb03b8..2baa0bcd5 100644 --- a/hugr-core/src/builder/tail_loop.rs +++ b/hugr-core/src/builder/tail_loop.rs @@ -1,4 +1,3 @@ -use crate::extension::{ExtensionSet, TO_BE_INFERRED}; use crate::ops::{self, DataflowOpTrait}; use crate::hugr::views::HugrView; @@ -72,29 +71,15 @@ impl + AsRef> TailLoopBuilder { impl TailLoopBuilder { /// Initialize new builder for a [`ops::TailLoop`] rooted HUGR. - /// Extension delta will be inferred. 
pub fn new( just_inputs: impl Into, inputs_outputs: impl Into, just_outputs: impl Into, - ) -> Result { - Self::new_exts(just_inputs, inputs_outputs, just_outputs, TO_BE_INFERRED) - } - - /// Initialize new builder for a [`ops::TailLoop`] rooted HUGR. - /// `extension_delta` is explicitly specified; alternatively, [new](Self::new) - /// may be used to infer it. - pub fn new_exts( - just_inputs: impl Into, - inputs_outputs: impl Into, - just_outputs: impl Into, - extension_delta: impl Into, ) -> Result { let tail_loop = ops::TailLoop { just_inputs: just_inputs.into(), just_outputs: just_outputs.into(), rest: inputs_outputs.into(), - extension_delta: extension_delta.into(), }; let base = Hugr::new(tail_loop.clone()); let root = base.root(); @@ -109,7 +94,7 @@ mod test { use crate::extension::prelude::bool_t; use crate::{ builder::{DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer}, - extension::prelude::{usize_t, ConstUsize, PRELUDE_ID}, + extension::prelude::{usize_t, ConstUsize}, hugr::ValidationError, ops::Value, type_row, @@ -120,8 +105,7 @@ mod test { #[test] fn basic_loop() -> Result<(), BuildError> { let build_result: Result = { - let mut loop_b = - TailLoopBuilder::new_exts(vec![], vec![bool_t()], vec![usize_t()], PRELUDE_ID)?; + let mut loop_b = TailLoopBuilder::new(vec![], vec![bool_t()], vec![usize_t()])?; let [i1] = loop_b.input_wires_arr(); let const_wire = loop_b.add_load_value(ConstUsize::new(1)); @@ -138,10 +122,8 @@ mod test { fn loop_with_conditional() -> Result<(), BuildError> { let build_result = { let mut module_builder = ModuleBuilder::new(); - let mut fbuild = module_builder.define_function( - "main", - Signature::new(vec![bool_t()], vec![usize_t()]).with_prelude(), - )?; + let mut fbuild = module_builder + .define_function("main", Signature::new(vec![bool_t()], vec![usize_t()]))?; let _fdef = { let [b1] = fbuild.input_wires_arr(); let loop_id = { diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 078fe3c27..42e04629b 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -832,7 +832,6 @@ impl<'a> Context<'a> { ); self.make_term(table::Term::List(parts)) } - TypeArg::Extensions { .. 
} => self.make_term_apply("compat.ext_set", &[]), TypeArg::Variable { v } => self.export_type_arg_var(v), } } @@ -939,7 +938,6 @@ impl<'a> Context<'a> { let types = self.make_term(table::Term::List(parts)); self.make_term_apply(model::CORE_TUPLE_TYPE, &[types]) } - TypeParam::Extensions => self.make_term_apply("compat.ext_set_type", &[]), } } @@ -1175,19 +1173,15 @@ mod test { use crate::{ builder::{Dataflow, DataflowSubContainer}, extension::prelude::qb_t, - std_extensions::arithmetic::float_types, types::Signature, - utils::test_quantum_extension::{self, cx_gate, h_gate}, + utils::test_quantum_extension::{cx_gate, h_gate}, Hugr, }; #[fixture] fn test_simple_circuit() -> Hugr { crate::builder::test::build_main( - Signature::new_endo(vec![qb_t(), qb_t()]) - .with_extension_delta(test_quantum_extension::EXTENSION_ID) - .with_extension_delta(float_types::EXTENSION_ID) - .into(), + Signature::new_endo(vec![qb_t(), qb_t()]).into(), |mut f_build| { let wires: Vec<_> = f_build.input_wires().collect(); let mut linear = f_build.as_circuit(wires); diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 23238ccfd..4300c74ad 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -23,7 +23,7 @@ use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{OpName, OpNameRef}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::RowVariable; -use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName}; +use crate::types::{CustomType, TypeBound, TypeName}; use crate::types::{Signature, TypeNameRef}; mod const_fold; @@ -547,8 +547,6 @@ pub struct Extension { pub version: Version, /// Unique identifier for the extension. pub name: ExtensionId, - /// Runtime dependencies this extension has on other extensions. - pub runtime_reqs: ExtensionSet, /// Types defined by this extension. types: BTreeMap, /// Operation declarations with serializable definitions. @@ -572,7 +570,6 @@ impl Extension { Self { name, version, - runtime_reqs: Default::default(), types: Default::default(), operations: Default::default(), } @@ -629,12 +626,6 @@ impl Extension { } } - /// Extend the runtime requirements of this extension with another set of extensions. - pub fn add_requirements(&mut self, runtime_reqs: impl Into) { - let reqs = mem::take(&mut self.runtime_reqs); - self.runtime_reqs = reqs.union(runtime_reqs.into()); - } - /// Allows read-only access to the operations in this Extension pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc> { self.operations.get(name) @@ -734,14 +725,6 @@ pub enum ExtensionBuildError { #[display("[{}]", _0.iter().join(", "))] pub struct ExtensionSet(BTreeSet); -/// A special ExtensionId which indicates that the delta of a non-Function -/// container node should be computed by extension inference. -/// -/// See [`infer_extensions`] which lists the container nodes to which this can be applied. -/// -/// [`infer_extensions`]: crate::hugr::Hugr::infer_extensions -pub const TO_BE_INFERRED: ExtensionId = ExtensionId::new_unchecked(".TO_BE_INFERRED"); - impl ExtensionSet { /// Creates a new empty extension set. pub const fn new() -> Self { @@ -753,14 +736,6 @@ impl ExtensionSet { self.0.insert(extension.clone()); } - /// Adds a type var (which must have been declared as a [TypeParam::Extensions]) to this set - pub fn insert_type_var(&mut self, idx: usize) { - // Represent type vars as string representation of variable index. - // This is not a legal IdentList or ExtensionId so should not conflict. 
- self.0 - .insert(ExtensionId::new_unchecked(idx.to_string().as_str())); - } - /// Returns `true` if the set contains the given extension. pub fn contains(&self, extension: &ExtensionId) -> bool { self.0.contains(extension) @@ -783,14 +758,6 @@ impl ExtensionSet { set } - /// An ExtensionSet containing a single type variable - /// (which must have been declared as a [TypeParam::Extensions]) - pub fn type_var(idx: usize) -> Self { - let mut set = Self::new(); - set.insert_type_var(idx); - set - } - /// Returns the union of two extension sets. pub fn union(mut self, other: Self) -> Self { self.0.extend(other.0); @@ -821,22 +788,6 @@ impl ExtensionSet { pub fn is_empty(&self) -> bool { self.0.is_empty() } - - pub(crate) fn validate(&self, params: &[TypeParam]) -> Result<(), SignatureError> { - self.iter() - .filter_map(as_typevar) - .try_for_each(|var_idx| check_typevar_decl(params, var_idx, &TypeParam::Extensions)) - } - - pub(crate) fn substitute(&self, t: &Substitution) -> Self { - Self::from_iter(self.0.iter().flat_map(|e| match as_typevar(e) { - None => vec![e.clone()], - Some(i) => match t.apply_var(i, &TypeParam::Extensions) { - TypeArg::Extensions{es} => es.iter().cloned().collect::>(), - _ => panic!("value for type var was not extension set - type scheme should be validated first"), - }, - })) - } } impl From for ExtensionSet { @@ -863,16 +814,6 @@ impl<'a> IntoIterator for &'a ExtensionSet { } } -fn as_typevar(e: &ExtensionId) -> Option { - // Type variables are represented as radix-10 numbers, which are illegal - // as standard ExtensionIds. Hence if an ExtensionId starts with a digit, - // we assume it must be a type variable, and fail fast if it isn't. - match e.chars().next() { - Some(c) if c.is_ascii_digit() => Some(str::parse(e).unwrap()), - _ => None, - } -} - impl FromIterator for ExtensionSet { fn from_iter>(iter: I) -> Self { Self(BTreeSet::from_iter(iter)) @@ -967,16 +908,8 @@ pub mod test { type Strategy = BoxedStrategy; fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { - ( - hash_set(0..10usize, 0..3), - hash_set(any::(), 0..3), - ) - .prop_map(|(vars, extensions)| { - ExtensionSet::union_over( - std::iter::once(extensions.into_iter().collect::()) - .chain(vars.into_iter().map(ExtensionSet::type_var)), - ) - }) + hash_set(any::(), 0..3) + .prop_map(|extensions| extensions.into_iter().collect::()) .boxed() } } diff --git a/hugr-core/src/extension/declarative.rs b/hugr-core/src/extension/declarative.rs index 64092981f..14995db27 100644 --- a/hugr-core/src/extension/declarative.rs +++ b/hugr-core/src/extension/declarative.rs @@ -149,9 +149,14 @@ impl ExtensionDeclaration { /// Create an [`Extension`] from this declaration. pub fn make_extension( &self, - imports: &ExtensionSet, + _imports: &ExtensionSet, ctx: DeclarationContext<'_>, ) -> Result, ExtensionDeclarationError> { + // TODO: The imports were previously used as runtime extension + // requirements for the constructed extension. Now that runtime + // extension requirements are removed, they are no longer recorded + // anywhere in the `Extension`. + Extension::try_new_arc( self.name.clone(), // TODO: Get the version as a parameter. 
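A rough sketch of what is left of `ExtensionSet` after the removals above: only the plain set operations (`new`, `union`, `contains`, `is_empty`, `FromIterator`) survive, while the type-variable encoding, `validate`, and `substitute` are gone. Import paths assume the `hugr_core` crate root and are illustrative, not part of this patch:

```rust
use hugr_core::extension::prelude::PRELUDE_ID;
use hugr_core::extension::ExtensionSet;

fn main() {
    // ExtensionSet is now just a set of `ExtensionId`s: the type-variable
    // helpers, `validate`, and `substitute` removed above no longer exist.
    let prelude_only: ExtensionSet = [PRELUDE_ID].into_iter().collect();
    let combined = ExtensionSet::new().union(prelude_only);
    assert!(combined.contains(&PRELUDE_ID));
    assert!(!combined.is_empty());
}
```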
@@ -164,7 +169,6 @@ impl ExtensionDeclaration { for o in &self.operations { o.register(ext, ctx, extension_ref)?; } - ext.add_requirements(imports.clone()); Ok(()) }, diff --git a/hugr-core/src/extension/declarative/signature.rs b/hugr-core/src/extension/declarative/signature.rs index b84d56853..e2300956b 100644 --- a/hugr-core/src/extension/declarative/signature.rs +++ b/hugr-core/src/extension/declarative/signature.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use smol_str::SmolStr; use crate::extension::prelude::PRELUDE_ID; -use crate::extension::{ExtensionSet, SignatureFunc, TypeDef}; +use crate::extension::{SignatureFunc, TypeDef}; use crate::types::type_param::TypeParam; use crate::types::{CustomType, FuncValueType, PolyFuncTypeRV, Type, TypeRowRV}; use crate::Extension; @@ -26,10 +26,6 @@ pub(super) struct SignatureDeclaration { inputs: Vec, /// The outputs of the operation. outputs: Vec, - /// A set of extensions invoked while running this operation. - #[serde(default)] - #[serde(skip_serializing_if = "crate::utils::is_default")] - extensions: ExtensionSet, } impl SignatureDeclaration { @@ -53,7 +49,6 @@ impl SignatureDeclaration { let body = FuncValueType { input: make_type_row(&self.inputs)?, output: make_type_row(&self.outputs)?, - runtime_reqs: self.extensions.clone(), }; let poly_func = PolyFuncTypeRV::new(op_params, body); diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index d5c9a5b5d..48eef663f 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -244,11 +244,7 @@ impl SignatureFunc { // TODO raise warning: https://github.com/CQCL/hugr/issues/1432 SignatureFunc::MissingValidateFunc(ts) => (ts, args), }; - let mut res = pf.instantiate(args)?; - - // Automatically add the extensions where the operation is defined to - // the runtime requirements of the op. 
- res.runtime_reqs.insert(def.extension.clone()); + let res = pf.instantiate(args)?; // If there are any row variables left, this will fail with an error: res.try_into() @@ -722,8 +718,7 @@ pub(super) mod test { Ok(Signature::new( vec![usize_t(); 3], vec![Type::new_tuple(vec![usize_t(); 3])] - ) - .with_extension_delta(EXT_ID)) + )) ); assert_eq!(def.validate_args(&args, &[]), Ok(())); @@ -733,10 +728,10 @@ pub(super) mod test { let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; assert_eq!( def.compute_signature(&args), - Ok( - Signature::new(tyvars.clone(), vec![Type::new_tuple(tyvars)]) - .with_extension_delta(EXT_ID) - ) + Ok(Signature::new( + tyvars.clone(), + vec![Type::new_tuple(tyvars)] + )) ); def.validate_args(&args, &[TypeBound::Copyable.into()]) .unwrap(); @@ -787,14 +782,11 @@ pub(super) mod test { ), extension_ref, )?; - let tv = Type::new_var_use(1, TypeBound::Copyable); + let tv = Type::new_var_use(0, TypeBound::Copyable); let args = [TypeArg::Type { ty: tv.clone() }]; - let decls = [TypeParam::Extensions, TypeBound::Copyable.into()]; + let decls = [TypeBound::Copyable.into()]; def.validate_args(&args, &decls).unwrap(); - assert_eq!( - def.compute_signature(&args), - Ok(Signature::new_endo(tv).with_extension_delta(EXT_ID)) - ); + assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo(tv))); // But not with an external row variable let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); assert_eq!( @@ -811,36 +803,6 @@ pub(super) mod test { Ok(()) } - #[test] - fn instantiate_extension_delta() -> Result<(), Box> { - use crate::extension::prelude::bool_t; - - let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { - let params: Vec = vec![TypeParam::Extensions]; - let db_set = ExtensionSet::type_var(0); - let fun_ty = Signature::new_endo(bool_t()).with_extension_delta(db_set); - - let def = ext.add_op( - "SimpleOp".into(), - "".into(), - PolyFuncTypeRV::new(params.clone(), fun_ty), - extension_ref, - )?; - - // Concrete extension set - let es = ExtensionSet::singleton(EXT_ID); - let exp_fun_ty = Signature::new_endo(bool_t()).with_extension_delta(es.clone()); - let args = [TypeArg::Extensions { es }]; - - def.validate_args(&args, ¶ms).unwrap(); - assert_eq!(def.compute_signature(&args), Ok(exp_fun_ty)); - - Ok(()) - })?; - - Ok(()) - } - mod proptest { use std::sync::Weak; diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index f88a84a0d..b1e78baf8 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -11,7 +11,7 @@ use crate::extension::simple_op::{ try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; use crate::extension::{ - ConstFold, ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDefBound, + ConstFold, ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDefBound, }; use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName}; use crate::ops::OpName; @@ -245,10 +245,6 @@ impl CustomConst for ConstString { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(PRELUDE_ID) - } - fn get_type(&self) -> Type { string_type() } @@ -438,10 +434,6 @@ impl CustomConst for ConstUsize { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(PRELUDE_ID) - } - fn get_type(&self) -> Type { usize_t() } @@ -495,9 +487,6 @@ impl CustomConst for ConstError { 
crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(PRELUDE_ID) - } fn get_type(&self) -> Type { error_type() } @@ -555,9 +544,6 @@ impl CustomConst for ConstExternalSymbol { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(PRELUDE_ID) - } fn get_type(&self) -> Type { self.typ.clone() } @@ -1068,7 +1054,7 @@ mod test { let optype: OpType = op.clone().into(); assert_eq!( optype.dataflow_signature().unwrap().as_ref(), - &Signature::new_endo(type_row![Type::UNIT]).with_prelude() + &Signature::new_endo(type_row![Type::UNIT]) ); let new_op = Barrier::from_extension_op(optype.as_extension_op().unwrap()).unwrap(); @@ -1121,10 +1107,6 @@ mod test { assert!(error_val.validate().is_ok()); - assert_eq!( - error_val.extension_reqs(), - ExtensionSet::singleton(PRELUDE_ID) - ); assert!(error_val.equal_consts(&ConstError::new(2, "my message"))); assert!(!error_val.equal_consts(&ConstError::new(3, "my message"))); @@ -1181,10 +1163,6 @@ mod test { let string_const: ConstString = ConstString::new("Lorem ipsum".into()); assert_eq!(string_const.name(), "ConstString(\"Lorem ipsum\")"); assert!(string_const.validate().is_ok()); - assert_eq!( - string_const.extension_reqs(), - ExtensionSet::singleton(PRELUDE_ID) - ); assert!(string_const.equal_consts(&ConstString::new("Lorem ipsum".into()))); assert!(!string_const.equal_consts(&ConstString::new("Lorem ispum".into()))); } @@ -1206,10 +1184,6 @@ mod test { assert_eq!(subject.get_type(), Type::UNIT); assert_eq!(subject.name(), "@foo"); assert!(subject.validate().is_ok()); - assert_eq!( - subject.extension_reqs(), - ExtensionSet::singleton(PRELUDE_ID) - ); assert!(subject.equal_consts(&ConstExternalSymbol::new("foo", Type::UNIT, false))); assert!(!subject.equal_consts(&ConstExternalSymbol::new("bar", Type::UNIT, false))); assert!(!subject.equal_consts(&ConstExternalSymbol::new("foo", string_type(), false))); diff --git a/hugr-core/src/extension/prelude/unwrap_builder.rs b/hugr-core/src/extension/prelude/unwrap_builder.rs index 06b4e3939..3817d65c8 100644 --- a/hugr-core/src/extension/prelude/unwrap_builder.rs +++ b/hugr-core/src/extension/prelude/unwrap_builder.rs @@ -111,10 +111,8 @@ mod tests { #[test] fn test_build_unwrap() { - let mut builder = DFGBuilder::new( - Signature::new(Type::from(option_type(bool_t())), bool_t()).with_prelude(), - ) - .unwrap(); + let mut builder = + DFGBuilder::new(Signature::new(Type::from(option_type(bool_t())), bool_t())).unwrap(); let [opt] = builder.input_wires_arr(); diff --git a/hugr-core/src/extension/resolution.rs b/hugr-core/src/extension/resolution.rs index 90eae9422..a08cbfb38 100644 --- a/hugr-core/src/extension/resolution.rs +++ b/hugr-core/src/extension/resolution.rs @@ -9,10 +9,6 @@ //! HUGR nodes and wire types. This is computed from the union of all extension //! required across the HUGR. //! -//! This is distinct from _runtime_ extension requirements, which are defined -//! more granularly in each function signature by the `runtime_reqs` -//! field. See the `extension_inference` feature and related modules for that. -//! //! Note: These procedures are only temporary until `hugr-model` is stabilized. //! Once that happens, hugrs will no longer be directly deserialized using serde //! but instead will be created by the methods in `crate::import`. 
As these diff --git a/hugr-core/src/extension/resolution/test.rs b/hugr-core/src/extension/resolution/test.rs index 19373b04c..f3ae229ec 100644 --- a/hugr-core/src/extension/resolution/test.rs +++ b/hugr-core/src/extension/resolution/test.rs @@ -11,7 +11,7 @@ use crate::builder::{ Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, }; -use crate::extension::prelude::{bool_t, usize_custom_t, usize_t, ConstUsize, PRELUDE_ID}; +use crate::extension::prelude::{bool_t, usize_custom_t, usize_t, ConstUsize}; use crate::extension::resolution::WeakExtensionRegistry; use crate::extension::resolution::{ resolve_op_extensions, resolve_op_types_extensions, ExtensionCollectionError, @@ -28,7 +28,7 @@ use crate::std_extensions::arithmetic::int_types::{self, int_type}; use crate::std_extensions::collections::list::ListValue; use crate::types::type_param::TypeParam; use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}; -use crate::{std_extensions, type_row, Extension, Hugr, HugrView}; +use crate::{type_row, Extension, Hugr, HugrView}; #[rstest] #[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())] @@ -158,17 +158,7 @@ fn check_extension_resolution(mut hugr: Hugr) { /// Build a small hugr using the float types extension and check that the extensions are resolved. #[rstest] fn resolve_hugr_extensions_simple() { - let mut build = DFGBuilder::new( - Signature::new(vec![], vec![float64_type()]).with_extension_delta( - [ - PRELUDE_ID.to_owned(), - std_extensions::arithmetic::float_types::EXTENSION_ID.to_owned(), - ] - .into_iter() - .collect::(), - ), - ) - .unwrap(); + let mut build = DFGBuilder::new(Signature::new(vec![], vec![float64_type()])).unwrap(); // A constant op using a non-prelude extension. let f_const = build.add_load_const(Value::extension(ConstF64::new(f64::consts::PI))); @@ -218,7 +208,7 @@ fn resolve_hugr_extensions() { let (ext_b, op_b) = make_extension("dummy.b", "op_b"); let (ext_c, op_c) = make_extension("dummy.c", "op_c"); let (ext_d, op_d) = make_extension("dummy.d", "op_d"); - let (ext_e, op_e) = make_extension("dummy.e", "op_e"); + let (_ext_e, op_e) = make_extension("dummy.e", "op_e"); let mut module = ModuleBuilder::new(); @@ -234,18 +224,7 @@ fn resolve_hugr_extensions() { let mut func = module .define_function( "dummy_fn", - Signature::new(vec![float64_type(), bool_t()], vec![]).with_extension_delta( - [ - ext_a.name(), - ext_b.name(), - ext_c.name(), - ext_d.name(), - ext_e.name(), - ] - .into_iter() - .cloned() - .collect::(), - ), + Signature::new(vec![float64_type(), bool_t()], vec![]), ) .unwrap(); let [func_i0, func_i1] = func.input_wires_arr(); @@ -368,11 +347,7 @@ fn resolve_call() { let dummy_fn = module.declare("called_fn", dummy_fn_sig).unwrap(); let mut func = module - .define_function( - "caller_fn", - Signature::new(vec![], vec![bool_t()]) - .with_extension_delta(ExtensionSet::from_iter(expected_exts.clone())), - ) + .define_function("caller_fn", Signature::new(vec![], vec![bool_t()])) .unwrap(); let _load_func = func.load_func(&dummy_fn, &[generic_type_1]).unwrap(); let call = func.call(&dummy_fn, &[generic_type_2], vec![]).unwrap(); @@ -390,15 +365,10 @@ fn resolve_call() { /// Fail when collecting extensions but the weak pointers are not resolved. 
#[rstest] fn dropped_weak_extensions() { - let (ext_a, op_a) = make_extension("dummy.a", "op_a"); + let (_ext_a, op_a) = make_extension("dummy.a", "op_a"); let mut func = FunctionBuilder::new( "dummy_fn", - Signature::new(vec![float64_type(), bool_t()], vec![]).with_extension_delta( - [ext_a.name()] - .into_iter() - .cloned() - .collect::(), - ), + Signature::new(vec![float64_type(), bool_t()], vec![]), ) .unwrap(); let [_func_i0, func_i1] = func.input_wires_arr(); diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index 6094f0aee..28bd6a12b 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -131,8 +131,6 @@ pub(crate) fn collect_signature_exts( used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { - // Note that we do not include the signature's `runtime_reqs` here, as those refer - // to _runtime_ requirements that we do not be require to be defined. collect_type_row_exts(&signature.input, used_extensions, missing_extensions); collect_type_row_exts(&signature.output, used_extensions, missing_extensions); } diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index d70d6b861..af5803eff 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -124,8 +124,6 @@ pub(super) fn resolve_signature_exts( extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - // Note that we do not include the signature's `runtime_reqs` here, as those refer - // to _runtime_ requirements that may not be currently present. resolve_type_row_exts(node, &mut signature.input, extensions, used_extensions)?; resolve_type_row_exts(node, &mut signature.output, extensions, used_extensions)?; Ok(()) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 93250b8e3..16152b298 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -29,8 +29,8 @@ use crate::extension::resolution::{ resolve_op_extensions, resolve_op_types_extensions, ExtensionResolutionError, WeakExtensionRegistry, }; -use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; -use crate::ops::{OpTag, OpTrait}; +use crate::extension::{ExtensionRegistry, ExtensionSet}; +use crate::ops::OpTag; pub use crate::ops::{OpType, DEFAULT_OPTYPE}; use crate::{Direction, Node}; @@ -112,9 +112,6 @@ impl Hugr { /// /// Validates the Hugr against the provided extension registry, ensuring all /// operations are resolved. - /// - /// If the feature `extension_inference` is enabled, we will ensure every function - /// correctly specifies the extensions required by its contained ops. pub fn load_json( reader: impl Read, extension_registry: &ExtensionRegistry, @@ -122,87 +119,11 @@ impl Hugr { let mut hugr: Hugr = serde_json::from_reader(reader)?; hugr.resolve_extension_defs(extension_registry)?; - hugr.validate_no_extensions()?; - - if cfg!(feature = "extension_inference") { - hugr.infer_extensions(false)?; - hugr.validate_extensions()?; - } + hugr.validate()?; Ok(hugr) } - /// Infers an extension-delta for any non-function container node - /// whose current [extension_delta] contains [TO_BE_INFERRED]. The inferred delta - /// will be the smallest delta compatible with its children and that includes any - /// other [ExtensionId]s in the current delta. 
- /// - /// If `remove` is true, for such container nodes *without* [TO_BE_INFERRED], - /// ExtensionIds are removed from the delta if they are *not* used by any child node. - /// - /// The non-function container nodes are: - /// [Case], [CFG], [Conditional], [DataflowBlock], [DFG], [TailLoop] - /// - /// [Case]: crate::ops::Case - /// [CFG]: crate::ops::CFG - /// [Conditional]: crate::ops::Conditional - /// [DataflowBlock]: crate::ops::DataflowBlock - /// [DFG]: crate::ops::DFG - /// [TailLoop]: crate::ops::TailLoop - /// [extension_delta]: crate::ops::OpType::extension_delta - /// [ExtensionId]: crate::extension::ExtensionId - pub fn infer_extensions(&mut self, remove: bool) -> Result<(), ExtensionError> { - fn delta_mut(optype: &mut OpType) -> Option<&mut ExtensionSet> { - match optype { - OpType::DFG(dfg) => Some(&mut dfg.signature.runtime_reqs), - OpType::DataflowBlock(dfb) => Some(&mut dfb.extension_delta), - OpType::TailLoop(tl) => Some(&mut tl.extension_delta), - OpType::CFG(cfg) => Some(&mut cfg.signature.runtime_reqs), - OpType::Conditional(c) => Some(&mut c.extension_delta), - OpType::Case(c) => Some(&mut c.signature.runtime_reqs), - //OpType::Lift(_) // Not ATM: only a single element, and we expect Lift to be removed - //OpType::FuncDefn(_) // Not at present due to the possibility of recursion - _ => None, - } - } - fn infer(h: &mut Hugr, node: Node, remove: bool) -> Result { - let mut child_sets = h - .children(node) - .collect::>() // Avoid borrowing h over recursive call - .into_iter() - .map(|ch| Ok((ch, infer(h, ch, remove)?))) - .collect::, _>>()?; - - let Some(es) = delta_mut(h.op_types.get_mut(node.into_portgraph())) else { - return Ok(h.get_optype(node).extension_delta()); - }; - if es.contains(&TO_BE_INFERRED) { - // Do not remove anything from current delta - any other elements are a lower bound - child_sets.push((node, es.clone())); // "child_sets" now misnamed but we discard fst - } else if remove { - child_sets.iter().try_for_each(|(ch, ch_exts)| { - if !es.is_superset(ch_exts) { - return Err(ExtensionError { - parent: node, - parent_extensions: es.clone(), - child: *ch, - child_extensions: ch_exts.clone(), - }); - } - Ok(()) - })?; - } else { - return Ok(es.clone()); // Can't neither add nor remove, so nothing to do - } - let merged = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e)); - *es = ExtensionSet::singleton(TO_BE_INFERRED).missing_from(&merged); - - Ok(es.clone()) - } - infer(self, self.root(), remove)?; - Ok(()) - } - /// Given a Hugr that has been deserialized, collect all extensions used to /// define the HUGR while resolving all [`OpType::OpaqueOp`] operations into /// [`OpType::ExtensionOp`]s and updating the extension pointer in all @@ -214,11 +135,6 @@ impl Hugr { /// to define the HUGR nodes and wire types. This is computed from the union /// of all extension required across the HUGR. /// - /// This is distinct from _runtime_ extension requirements computed in - /// [`Hugr::infer_extensions`], which are computed more granularly in each - /// function signature by the `runtime_reqs` field and define the set - /// of capabilities required by the runtime to execute each function. - /// /// Updates the internal extension registry with the extensions used in the /// definition. 
/// @@ -393,73 +309,13 @@ pub enum LoadHugrError { #[cfg(test)] mod test { - use std::sync::Arc; use std::{fs::File, io::BufReader}; - use super::internal::HugrMutInternals; - #[cfg(feature = "extension_inference")] - use super::ValidationError; - use super::{ExtensionError, Hugr, HugrMut, HugrView, Node}; - use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY, TO_BE_INFERRED}; - use crate::ops::{ExtensionOp, OpName}; - use crate::types::type_param::TypeParam; - use crate::types::{ - FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV, TypeRow, - }; - - use crate::{const_extension_ids, ops, test_file, type_row, Extension}; - use cool_asserts::assert_matches; - use lazy_static::lazy_static; - use rstest::rstest; + use super::{Hugr, HugrView}; + use crate::extension::PRELUDE_REGISTRY; - const_extension_ids! { - pub(crate) const LIFT_EXT_ID: ExtensionId = "LIFT_EXT_ID"; - } - lazy_static! { - /// Tests only extension holding an Op that can add arbitrary extensions to a row. - pub(crate) static ref LIFT_EXT: Arc = { - Extension::new_arc( - LIFT_EXT_ID, - hugr::extension::Version::new(0, 0, 0), - |ext, extension_ref| { - ext.add_op( - OpName::new_inline("Lift"), - "".into(), - PolyFuncTypeRV::new( - vec![TypeParam::Extensions, TypeParam::new_list(TypeBound::Any)], - FuncValueType::new_endo(TypeRV::new_row_var_use(1, TypeBound::Any)) - .with_extension_delta(ExtensionSet::type_var(0)), - ), - extension_ref, - ) - .unwrap(); - }, - ) - }; - } - - pub(crate) fn lift_op( - type_row: impl Into, - extensions: impl Into, - ) -> ExtensionOp { - LIFT_EXT - .instantiate_extension_op( - "Lift", - [ - TypeArg::Extensions { - es: extensions.into(), - }, - TypeArg::Sequence { - elems: type_row - .into() - .iter() - .map(|t| TypeArg::Type { ty: t.clone() }) - .collect(), - }, - ], - ) - .unwrap() - } + use crate::test_file; + use cool_asserts::assert_matches; #[test] fn impls_send_and_sync() { @@ -522,164 +378,4 @@ mod test { ); assert_matches!(&hugr, Ok(_)); } - - const_extension_ids! 
{ - const XA: ExtensionId = "EXT_A"; - const XB: ExtensionId = "EXT_B"; - } - - #[rstest] - #[case([], XA.into())] - #[case([XA], XA.into())] - #[case([XB], ExtensionSet::from_iter([XA, XB]))] - - fn infer_single_delta( - #[case] parent: impl IntoIterator, - #[values(true, false)] remove: bool, // makes no difference when inferring - #[case] result: ExtensionSet, - ) { - let parent = ExtensionSet::from_iter(parent).union(TO_BE_INFERRED.into()); - let (mut h, _) = build_ext_dfg(parent); - h.infer_extensions(remove).unwrap(); - assert_eq!(h, build_ext_dfg(result.union(LIFT_EXT_ID.into())).0); - } - - #[test] - fn infer_removes_from_delta() { - let parent = ExtensionSet::from_iter([XA, XB, LIFT_EXT_ID]); - let mut h = build_ext_dfg(parent.clone()).0; - let backup = h.clone(); - h.infer_extensions(false).unwrap(); - assert_eq!(h, backup); // did nothing - h.infer_extensions(true).unwrap(); - assert_eq!( - h, - build_ext_dfg(ExtensionSet::from_iter([XA, LIFT_EXT_ID])).0 - ); - } - - #[test] - fn infer_bad_remove() { - let (mut h, mid) = build_ext_dfg(XB.into()); - let backup = h.clone(); - h.infer_extensions(false).unwrap(); - assert_eq!(h, backup); // did nothing - let val_res = h.validate(); - let expected_err = ExtensionError { - parent: h.root(), - parent_extensions: XB.into(), - child: mid, - child_extensions: ExtensionSet::from_iter([XA, LIFT_EXT_ID]), - }; - #[cfg(feature = "extension_inference")] - assert_eq!( - val_res, - Err(ValidationError::ExtensionError(expected_err.clone())) - ); - #[cfg(not(feature = "extension_inference"))] - assert!(val_res.is_ok()); - - let inf_res = h.infer_extensions(true); - assert_eq!(inf_res, Err(expected_err)); - } - - fn build_ext_dfg(parent: ExtensionSet) -> (Hugr, Node) { - let ty = Type::new_function(Signature::new_endo(type_row![])); - let mut h = Hugr::new(ops::DFG { - signature: Signature::new_endo(ty.clone()).with_extension_delta(parent.clone()), - }); - let root = h.root(); - let mid = add_inliftout(&mut h, root, ty); - (h, mid) - } - - fn add_inliftout(h: &mut Hugr, p: Node, ty: Type) -> Node { - let inp = h.add_node_with_parent( - p, - ops::Input { - types: ty.clone().into(), - }, - ); - let out = h.add_node_with_parent( - p, - ops::Output { - types: ty.clone().into(), - }, - ); - let mid = h.add_node_with_parent(p, lift_op(ty, XA)); - h.connect(inp, 0, mid, 0); - h.connect(mid, 0, out, 0); - mid - } - - #[rstest] - // Base case success: delta inferred for parent equals grandparent. - #[case([XA], [TO_BE_INFERRED], true, [XA])] - // Success: delta inferred for parent is subset of grandparent - #[case([XA, XB], [TO_BE_INFERRED], true, [XA])] - // Base case failure: infers [XA] for parent but grandparent has disjoint set - #[case([XB], [TO_BE_INFERRED], false, [XA])] - // Failure: as previous, but extra "lower bound" on parent that has no effect - #[case([XB], [XA, TO_BE_INFERRED], false, [XA])] - // Failure: grandparent ok wrt. child but parent specifies extra lower-bound XB - #[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])] - // Success: grandparent includes extra XB required for parent's "lower bound" - #[case([XA, XB], [XB, TO_BE_INFERRED], true, [XA, XB])] - // Success: grandparent is also inferred so can include 'extra' XB from parent - #[case([TO_BE_INFERRED], [TO_BE_INFERRED, XB], true, [XA, XB])] - // No inference: extraneous XB in parent is removed so all become [XA]. 
- #[case([XA], [XA, XB], true, [XA])] - fn infer_three_generations( - #[case] grandparent: impl IntoIterator, - #[case] parent: impl IntoIterator, - #[case] success: bool, - #[case] result: impl IntoIterator, - ) { - let ty = Type::new_function(Signature::new_endo(type_row![])); - let grandparent = ExtensionSet::from_iter(grandparent).union(LIFT_EXT_ID.into()); - let parent = ExtensionSet::from_iter(parent).union(LIFT_EXT_ID.into()); - let result = ExtensionSet::from_iter(result).union(LIFT_EXT_ID.into()); - let root_ty = ops::Conditional { - sum_rows: vec![type_row![]], - other_inputs: ty.clone().into(), - outputs: ty.clone().into(), - extension_delta: grandparent.clone(), - }; - let mut h = Hugr::new(root_ty.clone()); - let p = h.add_node_with_parent( - h.root(), - ops::Case { - signature: Signature::new_endo(ty.clone()).with_extension_delta(parent), - }, - ); - add_inliftout(&mut h, p, ty.clone()); - assert!(h.validate_extensions().is_err()); - let backup = h.clone(); - let inf_res = h.infer_extensions(true); - if success { - assert!(inf_res.is_ok()); - let expected_p = ops::Case { - signature: Signature::new_endo(ty).with_extension_delta(result.clone()), - }; - let mut expected = backup; - expected.replace_op(p, expected_p); - let expected_gp = ops::Conditional { - extension_delta: result, - ..root_ty - }; - expected.replace_op(h.root(), expected_gp); - - assert_eq!(h, expected); - } else { - assert_eq!( - inf_res, - Err(ExtensionError { - parent: h.root(), - parent_extensions: grandparent, - child: p, - child_extensions: result - }) - ); - } - } } diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index 6353820f4..7805d3c67 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -614,9 +614,7 @@ mod test { module, ops::FuncDefn { name: "main".into(), - signature: Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]) - .with_prelude() - .into(), + signature: Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]).into(), }, ); diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index f69d2ad39..09f234de0 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -372,8 +372,7 @@ mod test { #[test] fn insert_ports() { let (nop, mut hugr) = { - let mut builder = - DFGBuilder::new(Signature::new_endo(Type::UNIT).with_prelude()).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(Type::UNIT)).unwrap(); let [nop_in] = builder.input_wires_arr(); let nop = builder .add_dataflow_op(Noop::new(Type::UNIT), [nop_in]) diff --git a/hugr-core/src/hugr/patch/consts.rs b/hugr-core/src/hugr/patch/consts.rs index eb9142f85..4ddd0b476 100644 --- a/hugr-core/src/hugr/patch/consts.rs +++ b/hugr-core/src/hugr/patch/consts.rs @@ -120,7 +120,6 @@ impl PatchHugrMut for RemoveConst { mod test { use super::*; - use crate::extension::prelude::PRELUDE_ID; use crate::{ builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}, extension::prelude::ConstUsize, @@ -133,10 +132,7 @@ mod test { let mut build = ModuleBuilder::new(); let con_node = build.add_constant(Value::extension(ConstUsize::new(2))); - let mut dfg_build = build.define_function( - "main", - Signature::new_endo(type_row![]).with_extension_delta(PRELUDE_ID.clone()), - )?; + let mut dfg_build = build.define_function("main", Signature::new_endo(type_row![]))?; let load_1 = dfg_build.load_const(&con_node); let load_2 = dfg_build.load_const(&con_node); let tup = dfg_build.make_tuple([load_1, load_2])?; diff --git 
a/hugr-core/src/hugr/patch/inline_call.rs b/hugr-core/src/hugr/patch/inline_call.rs index 0619d373e..5f31fbc79 100644 --- a/hugr-core/src/hugr/patch/inline_call.rs +++ b/hugr-core/src/hugr/patch/inline_call.rs @@ -121,10 +121,8 @@ mod test { use crate::extension::prelude::usize_t; use crate::ops::handle::{FuncID, NodeHandle}; use crate::ops::{Input, OpType, Value}; - use crate::std_extensions::arithmetic::{ - int_ops::{self, IntOpDef}, - int_types::{self, ConstInt, INT_TYPES}, - }; + use crate::std_extensions::arithmetic::int_types::INT_TYPES; + use crate::std_extensions::arithmetic::{int_ops::IntOpDef, int_types::ConstInt}; use crate::types::{PolyFuncType, Signature, Type, TypeBound}; use crate::{HugrView, Node}; @@ -145,9 +143,7 @@ mod test { fn test_inline() -> Result<(), Box> { let mut mb = ModuleBuilder::new(); let cst3 = mb.add_constant(Value::from(ConstInt::new_u(4, 3)?)); - let sig = Signature::new_endo(INT_TYPES[4].clone()) - .with_extension_delta(int_ops::EXTENSION_ID) - .with_extension_delta(int_types::EXTENSION_ID); + let sig = Signature::new_endo(INT_TYPES[4].clone()); let func = { let mut fb = mb.define_function("foo", sig.clone())?; let c1 = fb.load_const(&cst3); @@ -205,9 +201,7 @@ mod test { #[test] fn test_recursion() -> Result<(), Box> { let mut mb = ModuleBuilder::new(); - let sig = Signature::new_endo(INT_TYPES[5].clone()) - .with_extension_delta(int_ops::EXTENSION_ID) - .with_extension_delta(int_types::EXTENSION_ID); + let sig = Signature::new_endo(INT_TYPES[5].clone()); let (func, rec_call) = { let mut fb = mb.define_function("foo", sig.clone())?; let cst1 = fb.add_load_value(ConstInt::new_u(5, 1)?); @@ -294,10 +288,7 @@ mod test { #[test] fn test_polymorphic() -> Result<(), Box> { let tuple_ty = Type::new_tuple(vec![usize_t(); 2]); - let mut fb = FunctionBuilder::new( - "mkpair", - Signature::new(usize_t(), tuple_ty.clone()).with_prelude(), - )?; + let mut fb = FunctionBuilder::new("mkpair", Signature::new(usize_t(), tuple_ty.clone()))?; let inner = fb.define_function( "id", PolyFuncType::new( diff --git a/hugr-core/src/hugr/patch/inline_dfg.rs b/hugr-core/src/hugr/patch/inline_dfg.rs index 58fd51cbb..c7356f8e0 100644 --- a/hugr-core/src/hugr/patch/inline_dfg.rs +++ b/hugr-core/src/hugr/patch/inline_dfg.rs @@ -145,8 +145,6 @@ mod test { SubContainer, }; use crate::extension::prelude::qb_t; - use crate::extension::ExtensionSet; - use crate::hugr::patch::inline_dfg::InlineDFGError; use crate::hugr::HugrMut; use crate::ops::handle::{DfgID, NodeHandle}; use crate::ops::{OpType, Value}; @@ -175,6 +173,8 @@ mod test { #[case(true)] #[case(false)] fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box> { + use crate::hugr::patch::inline_dfg::InlineDFGError; + let int_ty = &int_types::INT_TYPES[6]; let mut outer = DFGBuilder::new(inout_sig(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?; @@ -334,12 +334,8 @@ mod test { .add_dataflow_op(test_quantum_extension::measure(), r.outputs())? .outputs_arr(); // Node using the boolean. Here we just select between two empty computations. 
- let mut if_n = inner.conditional_builder_exts( - ([type_row![], type_row![]], b), - [], - type_row![], - ExtensionSet::new(), - )?; + let mut if_n = + inner.conditional_builder(([type_row![], type_row![]], b), [], type_row![])?; if_n.case_builder(0)?.finish_with_outputs([])?; if_n.case_builder(1)?.finish_with_outputs([])?; let if_n = if_n.finish_sub_container()?; diff --git a/hugr-core/src/hugr/patch/outline_cfg.rs b/hugr-core/src/hugr/patch/outline_cfg.rs index b43b6b4e3..b9cafed9e 100644 --- a/hugr-core/src/hugr/patch/outline_cfg.rs +++ b/hugr-core/src/hugr/patch/outline_cfg.rs @@ -6,11 +6,9 @@ use itertools::Itertools; use thiserror::Error; use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer}; -use crate::extension::ExtensionSet; use crate::hugr::{HugrMut, HugrView}; use crate::ops; use crate::ops::controlflow::BasicBlock; -use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::NodeHandle; use crate::ops::{DataflowBlock, OpType}; use crate::PortIndex; @@ -33,12 +31,11 @@ impl OutlineCfg { } /// Compute the entry and exit nodes of the CFG which contains - /// [`self.blocks`], along with the output neighbour its parent graph and - /// the combined extension_deltas of all of the blocks. - fn compute_entry_exit_outside_extensions( + /// [`self.blocks`], along with the output neighbour its parent graph. + fn compute_entry_exit( &self, h: &impl HugrView, - ) -> Result<(Node, Node, Node, ExtensionSet), OutlineCfgError> { + ) -> Result<(Node, Node, Node), OutlineCfgError> { let cfg_n = match self .blocks .iter() @@ -50,13 +47,12 @@ impl OutlineCfg { _ => return Err(OutlineCfgError::NotSiblings), }; let o = h.get_optype(cfg_n); - let OpType::CFG(o) = o else { + let OpType::CFG(_) = o else { return Err(OutlineCfgError::ParentNotCfg(cfg_n, o.clone())); }; let cfg_entry = h.children(cfg_n).next().unwrap(); let mut entry = None; let mut exit_succ = None; - let mut extension_delta = ExtensionSet::new(); for &n in self.blocks.iter() { if n == cfg_entry || h.input_neighbours(n) @@ -71,7 +67,6 @@ impl OutlineCfg { } } } - extension_delta = extension_delta.union(o.signature().runtime_reqs.clone()); let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s)); match external_succs.at_most_one() { Ok(None) => (), // No external successors @@ -87,7 +82,7 @@ impl OutlineCfg { }; } match (entry, exit_succ) { - (Some(e), Some((x, o))) => Ok((e, x, o, extension_delta)), + (Some(e), Some((x, o))) => Ok((e, x, o)), (None, _) => Err(OutlineCfgError::NoEntryNode), (_, None) => Err(OutlineCfgError::NoExitNode), } @@ -98,7 +93,7 @@ impl PatchVerification for OutlineCfg { type Error = OutlineCfgError; type Node = Node; fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> { - self.compute_entry_exit_outside_extensions(h)?; + self.compute_entry_exit(h)?; Ok(()) } @@ -118,8 +113,7 @@ impl PatchHugrMut for OutlineCfg { self, h: &mut impl HugrMut, ) -> Result<[Node; 2], OutlineCfgError> { - let (entry, exit, outside, extension_delta) = - self.compute_entry_exit_outside_extensions(h)?; + let (entry, exit, outside) = self.compute_entry_exit(h)?; // 1. Compute signature // These panic()s only happen if the Hugr would not have passed validate() let OpType::DataflowBlock(DataflowBlock { inputs, .. }) = h.get_optype(entry) else { @@ -136,17 +130,10 @@ impl PatchHugrMut for OutlineCfg { // 2. 
new_block contains input node, sub-cfg, exit node all connected let (new_block, cfg_node) = { - let mut new_block_bldr = BlockBuilder::new_exts( - inputs.clone(), - vec![type_row![]], - outputs.clone(), - extension_delta.clone(), - ) - .unwrap(); + let mut new_block_bldr = + BlockBuilder::new(inputs.clone(), vec![type_row![]], outputs.clone()).unwrap(); let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires()); - let cfg = new_block_bldr - .cfg_builder_exts(wires_in, outputs, extension_delta) - .unwrap(); + let cfg = new_block_bldr.cfg_builder(wires_in, outputs).unwrap(); let cfg = cfg.finish_sub_container().unwrap(); let unit_sum = new_block_bldr.add_constant(ops::Value::unary_unit_sum()); let pred_wire = new_block_bldr.load_const(&unit_sum); diff --git a/hugr-core/src/hugr/patch/replace.rs b/hugr-core/src/hugr/patch/replace.rs index 183200751..606733543 100644 --- a/hugr-core/src/hugr/patch/replace.rs +++ b/hugr-core/src/hugr/patch/replace.rs @@ -609,21 +609,18 @@ mod test { inputs: vec![listy.clone()].into(), sum_rows: vec![type_row![]], other_outputs: vec![listy.clone()].into(), - extension_delta: list::EXTENSION_ID.into(), }, ); let r_df1 = replacement.add_node_with_parent( r_bb, DFG { - signature: Signature::new(vec![listy.clone()], simple_unary_plus(intermed.clone())) - .with_extension_delta(list::EXTENSION_ID), + signature: Signature::new(vec![listy.clone()], simple_unary_plus(intermed.clone())), }, ); let r_df2 = replacement.add_node_with_parent( r_bb, DFG { - signature: Signature::new(intermed, simple_unary_plus(just_list.clone())) - .with_extension_delta(list::EXTENSION_ID), + signature: Signature::new(intermed, simple_unary_plus(just_list.clone())), }, ); [0, 1] @@ -706,7 +703,7 @@ mod test { }, op_sig.input() ); - h.simple_entry_builder_exts(op_sig.output.clone(), 1, op_sig.runtime_reqs.clone())? + h.simple_entry_builder(op_sig.output.clone(), 1)? } else { h.simple_block_builder(op_sig.into_owned(), 1)? 
}; @@ -733,25 +730,20 @@ mod test { ext.add_op("baz".into(), "".to_string(), utou.clone(), extension_ref) .unwrap(); }); - let ext_name = ext.name().clone(); let foo = ext.instantiate_extension_op("foo", []).unwrap(); let bar = ext.instantiate_extension_op("bar", []).unwrap(); let baz = ext.instantiate_extension_op("baz", []).unwrap(); let mut registry = test_quantum_extension::REG.clone(); registry.register(ext).unwrap(); - let mut h = DFGBuilder::new( - Signature::new(vec![usize_t(), bool_t()], vec![usize_t()]) - .with_extension_delta(ext_name.clone()), - ) - .unwrap(); + let mut h = + DFGBuilder::new(Signature::new(vec![usize_t(), bool_t()], vec![usize_t()])).unwrap(); let [i, b] = h.input_wires_arr(); let mut cond = h - .conditional_builder_exts( + .conditional_builder( (vec![type_row![]; 2], b), [(usize_t(), i)], vec![usize_t()].into(), - ext_name.clone(), ) .unwrap(); let mut case1 = cond.case_builder(0).unwrap(); @@ -759,12 +751,7 @@ mod test { let case1 = case1.finish_with_outputs(foo.outputs()).unwrap().node(); let mut case2 = cond.case_builder(1).unwrap(); let bar = case2.add_dataflow_op(bar, case2.input_wires()).unwrap(); - let mut baz_dfg = case2 - .dfg_builder( - utou.clone().with_extension_delta(ext_name.clone()), - bar.outputs(), - ) - .unwrap(); + let mut baz_dfg = case2.dfg_builder(utou.clone(), bar.outputs()).unwrap(); let baz = baz_dfg.add_dataflow_op(baz, baz_dfg.input_wires()).unwrap(); let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs()).unwrap(); let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node(); diff --git a/hugr-core/src/hugr/patch/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs index 3908ba58e..245a3cdc0 100644 --- a/hugr-core/src/hugr/patch/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -393,7 +393,6 @@ pub(in crate::hugr::patch) mod test { DataflowSubContainer, HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::{bool_t, qb_t}; - use crate::extension::ExtensionSet; use crate::hugr::patch::PatchVerification; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Patch}; @@ -404,7 +403,7 @@ pub(in crate::hugr::patch) mod test { use crate::std_extensions::logic::test::and_op; use crate::std_extensions::logic::LogicOp; use crate::types::{Signature, Type}; - use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID}; + use crate::utils::test_quantum_extension::{cx_gate, h_gate}; use crate::{IncomingPort, Node}; use super::SimpleReplacement; @@ -421,12 +420,8 @@ pub(in crate::hugr::patch) mod test { fn make_hugr() -> Result { let mut module_builder = ModuleBuilder::new(); let _f_id = { - let just_q: ExtensionSet = EXTENSION_ID.into(); - let mut func_builder = module_builder.define_function( - "main", - Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]) - .with_extension_delta(just_q.clone()), - )?; + let mut func_builder = module_builder + .define_function("main", Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]))?; let [qb0, qb1, qb2] = func_builder.input_wires_arr(); @@ -462,7 +457,7 @@ pub(in crate::hugr::patch) mod test { /// ┤ H ├┤ X ├ /// └───┘└───┘ fn make_dfg_hugr() -> Result { - let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]).with_prelude())?; + let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?; let [wire0, wire1] = dfg_builder.input_wires_arr(); let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?; let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; diff --git 
a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 49a7b9321..6848062b7 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -6,10 +6,10 @@ use crate::builder::{ DataflowSubContainer, HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::Noop; -use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE_ID}; +use crate::extension::prelude::{bool_t, qb_t, usize_t}; use crate::extension::simple_op::MakeRegisteredOp; +use crate::extension::test::SimpleOpDef; use crate::extension::ExtensionRegistry; -use crate::extension::{test::SimpleOpDef, ExtensionSet}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::validate::ValidationError; use crate::ops::custom::{ExtensionOp, OpaqueOp, OpaqueOpError}; @@ -300,7 +300,7 @@ fn weighted_hugr_ser() { let t_row = vec![Type::new_sum([vec![usize_t()], vec![qb_t()]])]; let mut f_build = module_builder - .define_function("main", Signature::new(t_row.clone(), t_row).with_prelude()) + .define_function("main", Signature::new(t_row.clone(), t_row)) .unwrap(); let outputs = f_build @@ -324,7 +324,7 @@ fn weighted_hugr_ser() { #[test] fn dfg_roundtrip() -> Result<(), Box> { let tp: Vec = vec![bool_t(); 2]; - let mut dfg = DFGBuilder::new(Signature::new(tp.clone(), tp).with_prelude())?; + let mut dfg = DFGBuilder::new(Signature::new(tp.clone(), tp))?; let mut params: [_; 2] = dfg.input_wires_arr(); for p in params.iter_mut() { *p = dfg @@ -390,8 +390,8 @@ fn opaque_ops() -> Result<(), Box> { #[test] fn function_type() -> Result<(), Box> { - let fn_ty = Type::new_function(Signature::new_endo(vec![bool_t()]).with_prelude()); - let mut bldr = DFGBuilder::new(Signature::new_endo(vec![fn_ty.clone()]).with_prelude())?; + let fn_ty = Type::new_function(Signature::new_endo(vec![bool_t()])); + let mut bldr = DFGBuilder::new(Signature::new_endo(vec![fn_ty.clone()]))?; let op = bldr.add_dataflow_op(Noop(fn_ty), bldr.input_wires())?; let h = bldr.finish_hugr_with_outputs(op.outputs())?; @@ -482,10 +482,8 @@ fn roundtrip_value(#[case] value: Value) { } fn polyfunctype1() -> PolyFuncType { - let mut extension_set = ExtensionSet::new(); - extension_set.insert_type_var(1); - let function_type = Signature::new_endo(type_row![]).with_extension_delta(extension_set); - PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type) + let function_type = Signature::new_endo(type_row![]); + PolyFuncType::new([TypeParam::max_nat()], function_type) } fn polyfunctype2() -> PolyFuncTypeRV { @@ -541,7 +539,7 @@ fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { #[case(ops::Const::new(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap()))] #[case(ops::Input::new(vec![Type::new_var_use(3,TypeBound::Copyable)]))] #[case(ops::Output::new(vec![Type::new_function(FuncValueType::new_endo(type_row![]))]))] -#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}, TypeArg::Extensions{ es: ExtensionSet::singleton(PRELUDE_ID)} ]).unwrap())] +#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}]).unwrap())] #[case(ops::CallIndirect { signature : Signature::new_endo(vec![bool_t()]) })] fn roundtrip_optype(#[case] optype: impl Into + std::fmt::Debug) { check_testing_roundtrip(NodeSer { diff --git a/hugr-core/src/hugr/serialize/upgrade/test.rs b/hugr-core/src/hugr/serialize/upgrade/test.rs index 5e1d3ee51..e3aa4740b 100644 --- a/hugr-core/src/hugr/serialize/upgrade/test.rs +++ b/hugr-core/src/hugr/serialize/upgrade/test.rs @@ 
-55,7 +55,6 @@ pub fn hugr_with_named_op() -> Hugr { #[rstest] #[case("empty_hugr", empty_hugr())] #[case("hugr_with_named_op", hugr_with_named_op())] -#[cfg_attr(feature = "extension_inference", ignore = "Fails extension inference")] fn serial_upgrade(#[case] name: String, #[case] hugr: Hugr) { let path = TEST_CASE_DIR.join(format!("{}.json", name)); if !path.exists() { diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 3690ec947..31b34b044 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -9,12 +9,12 @@ use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; -use crate::extension::{SignatureError, TO_BE_INFERRED}; +use crate::extension::SignatureError; use crate::ops::constant::ConstTypeError; use crate::ops::custom::{ExtensionOp, OpaqueOpError}; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; -use crate::ops::{FuncDefn, NamedOp, OpName, OpParent, OpTag, OpTrait, OpType, ValidateOp}; +use crate::ops::{FuncDefn, NamedOp, OpName, OpTag, OpTrait, OpType, ValidateOp}; use crate::types::type_param::TypeParam; use crate::types::EdgeKind; use crate::{Direction, Hugr, Node, Port}; @@ -35,68 +35,15 @@ struct ValidationContext<'a> { } impl Hugr { - /// Check the validity of the HUGR, assuming that it has no open extension - /// variables. - /// TODO: Add a version of validation which allows for open extension - /// variables (see github issue #457) + /// Check the validity of the HUGR. pub fn validate(&self) -> Result<(), ValidationError> { - self.validate_no_extensions()?; - if cfg!(feature = "extension_inference") { - self.validate_extensions()?; - } - Ok(()) - } - - /// Check the validity of the HUGR, but don't check consistency of extension - /// requirements between connected nodes or between parents and children. - pub fn validate_no_extensions(&self) -> Result<(), ValidationError> { let mut validator = ValidationContext::new(self); validator.validate() } - - /// Validate extensions, i.e. that extension deltas from parent nodes are reflected in their children. - pub fn validate_extensions(&self) -> Result<(), ValidationError> { - for parent in self.nodes() { - let parent_op = self.get_optype(parent); - if parent_op.extension_delta().contains(&TO_BE_INFERRED) { - return Err(ValidationError::ExtensionsNotInferred { node: parent }); - } - let parent_extensions = match parent_op.inner_function_type() { - Some(s) => s.runtime_reqs.clone(), - None => match parent_op.tag() { - OpTag::Cfg | OpTag::Conditional => parent_op.extension_delta(), - // ModuleRoot holds but does not execute its children, so allow any extensions - OpTag::ModuleRoot => continue, - _ => { - assert!(self.children(parent).next().is_none(), - "Unknown parent node type {} - not a DataflowParent, Module, Cfg or Conditional", - parent_op); - continue; - } - }, - }; - for child in self.children(parent) { - let child_extensions = self.get_optype(child).extension_delta(); - if !parent_extensions.is_superset(&child_extensions) { - return Err(ExtensionError { - parent, - parent_extensions, - child, - child_extensions, - } - .into()); - } - } - } - Ok(()) - } } impl<'a> ValidationContext<'a> { /// Create a new validation context. - // Allow unused "extension_closure" variable for when - // the "extension_inference" feature is disabled. 
- #[allow(unused_variables)] pub fn new(hugr: &'a Hugr) -> Self { let dominators = HashMap::new(); Self { hugr, dominators } diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 236f40e3f..7fec75bce 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -11,8 +11,8 @@ use crate::builder::{ FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, }; use crate::extension::prelude::Noop; -use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE_ID}; -use crate::extension::{Extension, ExtensionRegistry, ExtensionSet, TypeDefBound, PRELUDE}; +use crate::extension::prelude::{bool_t, qb_t, usize_t}; +use crate::extension::{Extension, ExtensionRegistry, TypeDefBound, PRELUDE}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::HugrMut; use crate::ops::dataflow::IOTrait; @@ -35,9 +35,7 @@ use crate::{ fn make_simple_hugr(copies: usize) -> (Hugr, Node) { let def_op: OpType = ops::FuncDefn { name: "main".into(), - signature: Signature::new(vec![bool_t()], vec![bool_t(); copies]) - .with_prelude() - .into(), + signature: Signature::new(vec![bool_t()], vec![bool_t(); copies]).into(), } .into(); @@ -119,7 +117,7 @@ fn leaf_root() { #[test] fn dfg_root() { let dfg_op: OpType = ops::DFG { - signature: Signature::new_endo(vec![bool_t()]).with_prelude(), + signature: Signature::new_endo(vec![bool_t()]), } .into(); @@ -217,17 +215,14 @@ fn df_children_restrictions() { #[test] fn test_ext_edge() { - let mut h = closed_dfg_root_hugr( - Signature::new(vec![bool_t(), bool_t()], vec![bool_t()]) - .with_extension_delta(TO_BE_INFERRED), - ); + let mut h = closed_dfg_root_hugr(Signature::new(vec![bool_t(), bool_t()], vec![bool_t()])); let [input, output] = h.get_io(h.root()).unwrap(); // Nested DFG bool_t() -> bool_t() let sub_dfg = h.add_node_with_parent( h.root(), ops::DFG { - signature: Signature::new_endo(vec![bool_t()]).with_extension_delta(TO_BE_INFERRED), + signature: Signature::new_endo(vec![bool_t()]), }, ); // this Xor has its 2nd input unconnected @@ -254,7 +249,6 @@ fn test_ext_edge() { ); //Order edge. This will need metadata indicating its purpose. 
h.add_other_edge(input, sub_dfg); - h.infer_extensions(false).unwrap(); h.validate().unwrap(); } @@ -289,8 +283,7 @@ fn no_ext_edge_into_func() -> Result<(), Box> { #[test] fn test_local_const() { - let mut h = - closed_dfg_root_hugr(Signature::new_endo(bool_t()).with_extension_delta(TO_BE_INFERRED)); + let mut h = closed_dfg_root_hugr(Signature::new_endo(bool_t())); let [input, output] = h.get_io(h.root()).unwrap(); let and = h.add_node_with_parent(h.root(), and_op()); h.connect(input, 0, and, 0); @@ -312,7 +305,6 @@ fn test_local_const() { h.connect(lcst, 0, and, 1); assert_eq!(h.static_source(lcst), Some(cst)); // There is no edge from Input to LoadConstant, but that's OK: - h.infer_extensions(false).unwrap(); h.validate().unwrap(); } @@ -549,11 +541,7 @@ fn no_polymorphic_consts() -> Result<(), Box> { reg.validate()?; let mut def = FunctionBuilder::new( "myfunc", - PolyFuncType::new( - [BOUND], - Signature::new(vec![], vec![list_of_var.clone()]) - .with_extension_delta(list::EXTENSION_ID), - ), + PolyFuncType::new([BOUND], Signature::new(vec![], vec![list_of_var.clone()])), )?; let empty_list = Value::extension(list::ListValue::new_empty(Type::new_var_use( 0, @@ -646,7 +634,7 @@ fn row_variables() -> Result<(), Box> { "id", PolyFuncType::new( [TypeParam::new_list(TypeBound::Any)], - Signature::new(inner_ft.clone(), ft_usz).with_extension_delta(e.name.clone()), + Signature::new(inner_ft.clone(), ft_usz), ), )?; // All the wires here are carrying higher-order Function values @@ -668,19 +656,15 @@ fn row_variables() -> Result<(), Box> { #[test] fn test_polymorphic_call() -> Result<(), Box> { + // TODO: This tests a function call that is polymorphic in an extension set. + // Should this be rewritten to be polymorphic in something else or removed? + let e = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { - let params: Vec = vec![ - TypeBound::Any.into(), - TypeParam::Extensions, - TypeBound::Any.into(), - ]; - let evaled_fn = Type::new_function( - Signature::new( - Type::new_var_use(0, TypeBound::Any), - Type::new_var_use(2, TypeBound::Any), - ) - .with_extension_delta(ExtensionSet::type_var(1)), - ); + let params: Vec = vec![TypeBound::Any.into(), TypeBound::Any.into()]; + let evaled_fn = Type::new_function(Signature::new( + Type::new_var_use(0, TypeBound::Any), + Type::new_var_use(1, TypeBound::Any), + )); // Single-input/output version of the higher-order "eval" operation, with extension param. // Note the extension-delta of the eval node includes that of the input function. 
ext.add_op( @@ -690,9 +674,8 @@ fn test_polymorphic_call() -> Result<(), Box> { params.clone(), Signature::new( vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], - Type::new_var_use(2, TypeBound::Any), - ) - .with_extension_delta(ExtensionSet::type_var(1)), + Type::new_var_use(1, TypeBound::Any), + ), ), extension_ref, )?; @@ -700,27 +683,23 @@ fn test_polymorphic_call() -> Result<(), Box> { Ok(()) })?; - fn utou(e: impl Into) -> Type { - Type::new_function(Signature::new_endo(usize_t()).with_extension_delta(e.into())) + fn utou() -> Type { + Type::new_function(Signature::new_endo(usize_t())) } let int_pair = Type::new_tuple(vec![usize_t(); 2]); - // Root DFG: applies a function int--PRELUDE-->int to each element of a pair of two ints + // Root DFG: applies a function int-->int to each element of a pair of two ints let mut d = DFGBuilder::new(inout_sig( - vec![utou(PRELUDE_ID), int_pair.clone()], + vec![utou(), int_pair.clone()], vec![int_pair.clone()], ))?; - // ....by calling a function parametrized (int--e-->int, int_pair) -> int_pair + // ....by calling a function (int-->int, int_pair) -> int_pair let f = { - let es = ExtensionSet::type_var(0); let mut f = d.define_function( "two_ints", PolyFuncType::new( - vec![TypeParam::Extensions], - Signature::new(vec![utou(es.clone()), int_pair.clone()], int_pair.clone()) - .with_extension_delta(EXT_ID) - .with_prelude() - .with_extension_delta(es.clone()), + vec![], + Signature::new(vec![utou(), int_pair.clone()], int_pair.clone()), ), )?; let [func, tup] = f.input_wires_arr(); @@ -731,14 +710,7 @@ fn test_polymorphic_call() -> Result<(), Box> { )?; let mut cc = c.case_builder(0)?; let [i1, i2] = cc.input_wires_arr(); - let op = e.instantiate_extension_op( - "eval", - vec![ - usize_t().into(), - TypeArg::Extensions { es }, - usize_t().into(), - ], - )?; + let op = e.instantiate_extension_op("eval", vec![usize_t().into(), usize_t().into()])?; let [f1] = cc.add_dataflow_op(op.clone(), [func, i1])?.outputs_arr(); let [f2] = cc.add_dataflow_op(op, [func, i2])?.outputs_arr(); cc.finish_with_outputs([f1, f2])?; @@ -748,18 +720,10 @@ fn test_polymorphic_call() -> Result<(), Box> { }; let [func, tup] = d.input_wires_arr(); - let call = d.call( - f.handle(), - &[TypeArg::Extensions { - es: ExtensionSet::singleton(PRELUDE_ID), - }], - [func, tup], - )?; + let call = d.call(f.handle(), &[], [func, tup])?; let h = d.finish_hugr_with_outputs(call.outputs())?; let call_ty = h.get_optype(call.node()).dataflow_signature().unwrap(); - let exp_fun_ty = Signature::new(vec![utou(PRELUDE_ID), int_pair.clone()], int_pair) - .with_extension_delta(EXT_ID) - .with_prelude(); + let exp_fun_ty = Signature::new(vec![utou(), int_pair.clone()], int_pair); assert_eq!(call_ty.as_ref(), &exp_fun_ty); Ok(()) } @@ -817,7 +781,6 @@ fn cfg_children_restrictions() { inputs: vec![bool_t()].into(), sum_rows: vec![type_row![]], other_outputs: vec![bool_t()].into(), - extension_delta: ExtensionSet::new(), }, ); let const_op: ops::Const = ops::Value::unit_sum(0, 1).unwrap().into(); @@ -872,7 +835,6 @@ fn cfg_children_restrictions() { inputs: vec![qb_t()].into(), sum_rows: vec![type_row![]], other_outputs: vec![qb_t()].into(), - extension_delta: ExtensionSet::new(), }, ); let mut block_children = b.hierarchy.children(block.into_portgraph()); @@ -899,8 +861,7 @@ fn cfg_connections() -> Result<(), Box> { let mut hugr = CFGBuilder::new(Signature::new_endo(usize_t()))?; let unary_pred = hugr.add_constant(Value::unary_unit_sum()); - let mut entry = - 
hugr.simple_entry_builder_exts(vec![usize_t()].into(), 1, ExtensionSet::new())?; + let mut entry = hugr.simple_entry_builder(vec![usize_t()].into(), 1)?; let p = entry.load_const(&unary_pred); let ins = entry.input_wires(); let entry = entry.finish_with_outputs(p, ins)?; @@ -944,219 +905,3 @@ fn cfg_entry_io_bug() -> Result<(), Box> { Ok(()) } - -#[cfg(feature = "extension_inference")] -mod extension_tests { - use self::ops::handle::{BasicBlockID, TailLoopID}; - use rstest::rstest; - - use super::*; - use crate::builder::handle::Outputs; - use crate::builder::{BlockBuilder, BuildHandle, CFGBuilder, DFGWrapper, TailLoopBuilder}; - use crate::extension::prelude::PRELUDE_ID; - use crate::extension::ExtensionSet; - use crate::hugr::test::{lift_op, LIFT_EXT_ID}; - use crate::macros::const_extension_ids; - use crate::Wire; - const_extension_ids! { - const XA: ExtensionId = "A"; - const XB: ExtensionId = "BOOL_EXT"; - } - - #[rstest] - #[case::d1(|signature| ops::DFG {signature}.into())] - #[case::f1(|sig: Signature| ops::FuncDefn {name: "foo".to_string(), signature: sig.into()}.into())] - #[case::c1(|signature| ops::Case {signature}.into())] - fn parent_extension_mismatch( - #[case] parent_f: impl Fn(Signature) -> OpType, - #[values(ExtensionSet::new(), XA.into())] parent_extensions: ExtensionSet, - ) { - // Child graph adds extension "XB", but the parent (in all cases) - // declares a different delta, causing a mismatch. - - let parent = parent_f( - Signature::new_endo(usize_t()).with_extension_delta(parent_extensions.clone()), - ); - let mut hugr = Hugr::new(parent); - - let input = hugr.add_node_with_parent( - hugr.root(), - ops::Input { - types: vec![usize_t()].into(), - }, - ); - let output = hugr.add_node_with_parent( - hugr.root(), - ops::Output { - types: vec![usize_t()].into(), - }, - ); - - let lift = hugr.add_node_with_parent(hugr.root(), lift_op(usize_t(), XB)); - - hugr.connect(input, 0, lift, 0); - hugr.connect(lift, 0, output, 0); - - let result = hugr.validate(); - assert_eq!( - result, - Err(ValidationError::ExtensionError(ExtensionError { - parent: hugr.root(), - parent_extensions, - child: lift, - child_extensions: ExtensionSet::from_iter([LIFT_EXT_ID, XB]), - })) - ); - } - - #[rstest] - #[case(XA.into(), false)] - #[case(ExtensionSet::new(), false)] - #[case(ExtensionSet::from_iter([XA, XB]), true)] - fn cfg_extension_mismatch( - #[case] parent_extensions: ExtensionSet, - #[case] success: bool, - ) -> Result<(), BuildError> { - let mut cfg = CFGBuilder::new( - Signature::new_endo(usize_t()).with_extension_delta(parent_extensions.clone()), - )?; - let mut bb = cfg.simple_entry_builder_exts(usize_t().into(), 1, XB)?; - let pred = bb.add_load_value(Value::unary_unit_sum()); - let inputs = bb.input_wires(); - let blk = bb.finish_with_outputs(pred, inputs)?; - let exit = cfg.exit_block(); - cfg.branch(&blk, 0, &exit)?; - let root = cfg.hugr().root(); - let res = cfg.finish_hugr(); - if success { - assert!(res.is_ok()) - } else { - assert_eq!( - res, - Err(ValidationError::ExtensionError(ExtensionError { - parent: root, - parent_extensions, - child: blk.node(), - child_extensions: XB.into() - })) - ); - } - Ok(()) - } - - #[rstest] - #[case(XA.into(), false)] - #[case(ExtensionSet::new(), false)] - #[case(ExtensionSet::from_iter([XA, XB, LIFT_EXT_ID]), true)] - fn conditional_extension_mismatch( - #[case] parent_extensions: ExtensionSet, - #[case] success: bool, - ) { - // Child graph adds extension "XB", but the parent - // declares a different delta, in same cases causing a 
mismatch. - let parent = ops::Conditional { - sum_rows: vec![type_row![], type_row![]], - other_inputs: vec![usize_t()].into(), - outputs: vec![usize_t()].into(), - extension_delta: parent_extensions.clone(), - }; - let mut hugr = Hugr::new(parent); - - // First case with no delta should be ok in all cases. Second one may not be. - let [_, child] = [None, Some(XB)].map(|case_ext| { - let case_exts = if let Some(ex) = &case_ext { - ExtensionSet::from_iter([ex.clone(), LIFT_EXT_ID]) - } else { - ExtensionSet::new() - }; - let case = hugr.add_node_with_parent( - hugr.root(), - ops::Case { - signature: Signature::new_endo(usize_t()).with_extension_delta(case_exts), - }, - ); - - let input = hugr.add_node_with_parent( - case, - ops::Input { - types: vec![usize_t()].into(), - }, - ); - let output = hugr.add_node_with_parent( - case, - ops::Output { - types: vec![usize_t()].into(), - }, - ); - let res = match case_ext { - None => input, - Some(new_ext) => { - let lift = hugr.add_node_with_parent(case, lift_op(usize_t(), new_ext)); - hugr.connect(input, 0, lift, 0); - lift - } - }; - hugr.connect(res, 0, output, 0); - case - }); - // case is the last-assigned child, i.e. the one that requires 'XB' - let result = hugr.validate(); - let expected = if success { - Ok(()) - } else { - Err(ValidationError::ExtensionError(ExtensionError { - parent: hugr.root(), - parent_extensions, - child, - child_extensions: ExtensionSet::from_iter([XB, LIFT_EXT_ID]), - })) - }; - assert_eq!(result, expected); - } - - #[rstest] - #[case(make_bb, |bb: &mut DFGWrapper<_,_>, outs| bb.make_tuple(outs))] - #[case(make_tailloop, |tl: &mut DFGWrapper<_,_>, outs| tl.make_break(tl.loop_signature().unwrap().clone(), outs))] - fn bb_extension_mismatch( - #[case] dfg_fn: impl Fn(Type, ExtensionSet) -> DFGWrapper, - #[case] make_pred: impl Fn(&mut DFGWrapper, Outputs) -> Result, - // last one includes prelude because `MakeTuple` is in prelude - #[values((ExtensionSet::from_iter([XA,LIFT_EXT_ID]), false), (LIFT_EXT_ID.into(), false), (ExtensionSet::from_iter([XA,XB,LIFT_EXT_ID,PRELUDE_ID]), true))] - parent_exts_success: (ExtensionSet, bool), - ) -> Result<(), BuildError> { - let (parent_extensions, success) = parent_exts_success; - let mut dfg = dfg_fn(usize_t(), parent_extensions.clone()); - let lift = dfg.add_dataflow_op(lift_op(usize_t(), XB), dfg.input_wires())?; - let pred = make_pred(&mut dfg, lift.outputs())?; - let root = dfg.hugr().root(); - let res = dfg.finish_hugr_with_outputs([pred]); - if success { - if res.is_err() { - dbg!(&res); - } - assert!(res.is_ok()) - } else { - assert_eq!( - res, - Err(BuildError::InvalidHUGR(ValidationError::ExtensionError( - ExtensionError { - parent: root, - parent_extensions, - child: lift.node(), - child_extensions: ExtensionSet::from_iter([XB, LIFT_EXT_ID]) - } - ))) - ); - } - Ok(()) - } - - fn make_bb(t: Type, es: ExtensionSet) -> DFGWrapper { - BlockBuilder::new_exts(t.clone(), vec![t.into()], type_row![], es).unwrap() - } - - fn make_tailloop(t: Type, es: ExtensionSet) -> DFGWrapper> { - let row = TypeRow::from(t); - TailLoopBuilder::new_exts(row.clone(), type_row![], row, es).unwrap() - } -} diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index f9eedd548..ea414c376 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -400,25 +400,10 @@ pub trait HugrView: HugrInternals { fn extensions(&self) -> &ExtensionRegistry; /// Check the validity of the underlying HUGR. 
- /// - /// This includes checking consistency of extension requirements between - /// connected nodes and between parents and children. - /// See [`HugrView::validate_no_extensions`] for a version that doesn't check - /// extension requirements. fn validate(&self) -> Result<(), ValidationError> { #[allow(deprecated)] self.base_hugr().validate() } - - /// Check the validity of the underlying HUGR, but don't check consistency - /// of extension requirements between connected nodes or between parents and - /// children. - /// - /// For a more thorough check, use [`HugrView::validate`]. - fn validate_no_extensions(&self) -> Result<(), ValidationError> { - #[allow(deprecated)] - self.base_hugr().validate_no_extensions() - } } /// A common trait for views of a HUGR hierarchical subgraph. diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index e3ba29e2c..13dfde8f7 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -236,7 +236,7 @@ pub(super) mod test { use crate::{ builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, types::Signature, - utils::test_quantum_extension::{h_gate, EXTENSION_ID}, + utils::test_quantum_extension::h_gate, }; use super::*; @@ -249,10 +249,8 @@ pub(super) mod test { let mut module_builder = ModuleBuilder::new(); let (f_id, inner_id) = { - let mut func_builder = module_builder.define_function( - "main", - Signature::new_endo(vec![usize_t(), qb_t()]).with_extension_delta(EXTENSION_ID), - )?; + let mut func_builder = module_builder + .define_function("main", Signature::new_endo(vec![usize_t(), qb_t()]))?; let [int, qb] = func_builder.input_wires_arr(); @@ -288,11 +286,7 @@ pub(super) mod test { assert_eq!( region.poly_func_type(), - Some( - Signature::new_endo(vec![usize_t(), qb_t()]) - .with_extension_delta(EXTENSION_ID) - .into() - ) + Some(Signature::new_endo(vec![usize_t(), qb_t()]).into()) ); let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?; diff --git a/hugr-core/src/hugr/views/impls.rs b/hugr-core/src/hugr/views/impls.rs index 6cd1d7631..9be352b5e 100644 --- a/hugr-core/src/hugr/views/impls.rs +++ b/hugr-core/src/hugr/views/impls.rs @@ -72,7 +72,6 @@ macro_rules! 
hugr_view_methods { fn value_types(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator; fn extensions(&self) -> &crate::extension::ExtensionRegistry; fn validate(&self) -> Result<(), crate::hugr::ValidationError>; - fn validate_no_extensions(&self) -> Result<(), crate::hugr::ValidationError>; } } } diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index 44e29ab1a..fa8378c7a 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -480,7 +480,6 @@ mod test { use crate::ops::OpType; use crate::ops::{dataflow::IOTrait, Input, OpTag, Output}; use crate::types::Signature; - use crate::utils::test_quantum_extension::EXTENSION_ID; use crate::IncomingPort; use super::super::descendants::test::make_module_hgr; @@ -507,11 +506,7 @@ mod test { assert_eq!( region.poly_func_type(), - Some( - Signature::new_endo(vec![usize_t(), qb_t()]) - .with_extension_delta(EXTENSION_ID) - .into() - ) + Some(Signature::new_endo(vec![usize_t(), qb_t()]).into()) ); assert_eq!( diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index 7fd2b9f54..b2eba044e 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -21,7 +21,6 @@ use thiserror::Error; use crate::builder::{Container, FunctionBuilder}; use crate::core::HugrNode; -use crate::extension::ExtensionSet; use crate::hugr::{HugrMut, HugrView}; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{ContainerHandle, DataflowOpID}; @@ -349,11 +348,7 @@ impl SiblingSubgraph { sig.port_type(p).cloned().expect("must be dataflow edge") }) .collect_vec(); - Signature::new(input, output).with_extension_delta(ExtensionSet::union_over( - self.nodes - .iter() - .map(|n| hugr.get_optype(*n).extension_delta()), - )) + Signature::new(input, output) } /// The parent of the sibling subgraph. @@ -840,10 +835,10 @@ mod tests { use crate::builder::inout_sig; use crate::hugr::Patch; use crate::ops::Const; - use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; - use crate::std_extensions::logic::{self, LogicOp}; + use crate::std_extensions::arithmetic::float_types::ConstF64; + use crate::std_extensions::logic::LogicOp; use crate::type_row; - use crate::utils::test_quantum_extension::{self, cx_gate, rz_f64}; + use crate::utils::test_quantum_extension::{cx_gate, rz_f64}; use crate::{ builder::{ BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, @@ -889,12 +884,7 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]) - .with_extension_delta(ExtensionSet::from_iter([ - test_quantum_extension::EXTENSION_ID, - float_types::EXTENSION_ID, - ])) - .into(), + Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).into(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; @@ -913,12 +903,7 @@ mod tests { /// A bool to bool hugr with three subsequent NOT gates. 
fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> { let mut mod_builder = ModuleBuilder::new(); - let func = mod_builder.declare( - "test", - Signature::new_endo(vec![bool_t()]) - .with_extension_delta(logic::EXTENSION_ID) - .into(), - )?; + let func = mod_builder.declare("test", Signature::new_endo(vec![bool_t()]).into())?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; let outs1 = dfg.add_dataflow_op(LogicOp::Not, dfg.input_wires())?; @@ -937,9 +922,7 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - Signature::new(bool_t(), vec![bool_t(), bool_t()]) - .with_extension_delta(logic::EXTENSION_ID) - .into(), + Signature::new(bool_t(), vec![bool_t(), bool_t()]).into(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; @@ -957,12 +940,7 @@ mod tests { /// A HUGR with a copy fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> { let mut mod_builder = ModuleBuilder::new(); - let func = mod_builder.declare( - "test", - Signature::new_endo(bool_t()) - .with_extension_delta(logic::EXTENSION_ID) - .into(), - )?; + let func = mod_builder.declare("test", Signature::new_endo(bool_t()).into())?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; let in_wire = dfg.input_wires().exactly_one().unwrap(); @@ -1024,12 +1002,7 @@ mod tests { let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func)?; assert_eq!( sub.signature(&func), - Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).with_extension_delta( - ExtensionSet::from_iter([ - test_quantum_extension::EXTENSION_ID, - float_types::EXTENSION_ID, - ]) - ) + Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]) ); Ok(()) } @@ -1218,12 +1191,7 @@ mod tests { #[test] fn test_unconnected() { // test a replacement on a subgraph with a discarded output - let mut b = DFGBuilder::new( - Signature::new(bool_t(), type_row![]) - // .with_prelude() - .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID), - ) - .unwrap(); + let mut b = DFGBuilder::new(Signature::new(bool_t(), type_row![])).unwrap(); let inw = b.input_wires().exactly_one().unwrap(); let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap(); // Unconnected output, discarded @@ -1234,11 +1202,7 @@ mod tests { assert_eq!(subg.nodes().len(), 1); // TODO create a valid replacement let replacement = { - let mut rep_b = DFGBuilder::new( - Signature::new_endo(bool_t()) - .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID), - ) - .unwrap(); + let mut rep_b = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let inw = rep_b.input_wires().exactly_one().unwrap(); let not_n = rep_b.add_dataflow_op(LogicOp::Not, [inw]).unwrap(); @@ -1253,11 +1217,7 @@ mod tests { #[test] fn single_node_subgraph() { // A hugr with a single NOT operation, with disconnected output. 
- let mut b = DFGBuilder::new( - Signature::new(bool_t(), type_row![]) - .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID), - ) - .unwrap(); + let mut b = DFGBuilder::new(Signature::new(bool_t(), type_row![])).unwrap(); let inw = b.input_wires().exactly_one().unwrap(); let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap(); // Unconnected output, discarded diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 899deb17d..ce5971364 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use crate::{ - extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError}, + extension::{ExtensionId, ExtensionRegistry, SignatureError}, hugr::{HugrMut, NodeMetadata}, ops::{ constant::{CustomConst, CustomSerialized, OpaqueValue}, @@ -791,7 +791,6 @@ impl<'a> Context<'a> { just_inputs, just_outputs, rest, - extension_delta: ExtensionSet::new(), }); let node = self.make_node(node_id, optype, parent)?; @@ -819,7 +818,6 @@ impl<'a> Context<'a> { sum_rows, other_inputs, outputs, - extension_delta: ExtensionSet::new(), }); let node = self.make_node(node_id, optype, parent)?; @@ -887,7 +885,6 @@ impl<'a> Context<'a> { inputs: types.clone(), other_outputs: TypeRow::default(), sum_rows: vec![types.clone()], - extension_delta: ExtensionSet::default(), }), ); @@ -988,7 +985,6 @@ impl<'a> Context<'a> { inputs, other_outputs, sum_rows, - extension_delta: ExtensionSet::new(), }); let node = self.make_node(node_id, optype, parent)?; @@ -1491,7 +1487,7 @@ impl<'a> Context<'a> { let runtime_type = self.import_type(runtime_type)?; let value: serde_json::Value = serde_json::from_str(json) .map_err(|_| table::ModelError::TypeError(term_id))?; - let custom_const = CustomSerialized::new(runtime_type, value, ExtensionSet::new()); + let custom_const = CustomSerialized::new(runtime_type, value); let opaque_value = OpaqueValue::new(custom_const); return Ok(Value::Extension { e: opaque_value }); } diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index ce0d44de0..5b5dbc420 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -16,7 +16,7 @@ use crate::extension::resolution::{ use std::borrow::Cow; use crate::extension::simple_op::MakeExtensionOp; -use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet}; +use crate::extension::{ExtensionId, ExtensionRegistry}; use crate::types::{EdgeKind, Signature, Substitution}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; @@ -398,12 +398,6 @@ pub trait OpTrait: Sized + Clone { None } - /// The delta between the input extensions specified for a node, - /// and the output extensions calculated for that node - fn extension_delta(&self) -> ExtensionSet { - ExtensionSet::new() - } - /// The edge kind for the non-dataflow inputs of the operation, /// not described by the signature. 
/// diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 6aad904cb..18f3974d4 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -8,7 +8,6 @@ use std::hash::{Hash, Hasher}; use super::{NamedOp, OpName, OpTrait, StaticTag}; use super::{OpTag, OpType}; -use crate::extension::ExtensionSet; use crate::types::{CustomType, EdgeKind, Signature, SumType, SumTypeError, Type, TypeRow}; use crate::{Hugr, HugrView}; @@ -81,10 +80,6 @@ impl OpTrait for Const { "Constant value" } - fn extension_delta(&self) -> ExtensionSet { - self.value().extension_reqs() - } - fn tag(&self) -> OpTag { ::TAG } @@ -251,7 +246,6 @@ pub enum Value { /// use serde_json::json; /// /// let expected_json = json!({ -/// "extensions": ["prelude"], /// "typ": usize_t(), /// "value": {'c': "ConstUsize", 'v': 1} /// }); @@ -259,9 +253,8 @@ pub enum Value { /// assert_eq!(&serde_json::to_value(&ev).unwrap(), &expected_json); /// assert_eq!(ev, serde_json::from_value(expected_json).unwrap()); /// -/// let ev = OpaqueValue::new(CustomSerialized::new(usize_t().clone(), serde_json::Value::Null, ExtensionSet::default())); +/// let ev = OpaqueValue::new(CustomSerialized::new(usize_t().clone(), serde_json::Value::Null)); /// let expected_json = json!({ -/// "extensions": [], /// "typ": usize_t(), /// "value": null /// }); @@ -297,8 +290,6 @@ impl OpaqueValue { pub fn get_type(&self) -> Type; /// An identifier of the internal [`CustomConst`]. pub fn name(&self) -> ValueName; - /// The extension(s) defining the internal [`CustomConst`]. - pub fn extension_reqs(&self) -> ExtensionSet; } } } @@ -523,17 +514,6 @@ impl Value { .into() } - /// The extensions required by a [`Value`] - pub fn extension_reqs(&self) -> ExtensionSet { - match self { - Self::Extension { e } => e.extension_reqs().clone(), - Self::Function { .. } => ExtensionSet::new(), // no extensions required to load Hugr (only to run) - Self::Sum(Sum { values, .. }) => { - ExtensionSet::union_over(values.iter().map(|x| x.extension_reqs())) - } - } - } - /// Check the value. pub fn validate(&self) -> Result<(), ConstTypeError> { match self { @@ -631,10 +611,6 @@ pub(crate) mod test { format!("CustomTestValue({:?})", self.0).into() } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(self.0.extension().clone()) - } - fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, @@ -849,8 +825,7 @@ pub(crate) mod test { // Dummy extension reference. 
&Weak::default(), ); - let json_const: Value = - CustomSerialized::new(typ_int.clone(), 6.into(), ex_id.clone()).into(); + let json_const: Value = CustomSerialized::new(typ_int.clone(), 6.into()).into(); let classic_t = Type::new_extension(typ_int.clone()); assert_matches!(classic_t.least_upper_bound(), TypeBound::Copyable); assert_eq!(json_const.get_type(), classic_t); diff --git a/hugr-core/src/ops/constant/custom.rs b/hugr-core/src/ops/constant/custom.rs index 985e15594..6ff1b67aa 100644 --- a/hugr-core/src/ops/constant/custom.rs +++ b/hugr-core/src/ops/constant/custom.rs @@ -13,7 +13,6 @@ use thiserror::Error; use crate::extension::resolution::{ resolve_type_extensions, ExtensionResolutionError, WeakExtensionRegistry, }; -use crate::extension::ExtensionSet; use crate::macros::impl_box_clone; use crate::types::{CustomCheckFailure, Type}; use crate::IncomingPort; @@ -44,7 +43,6 @@ use super::{Value, ValueName}; /// #[typetag::serde] /// impl CustomConst for CC { /// fn name(&self) -> ValueName { "CC".into() } -/// fn extension_reqs(&self) -> ExtensionSet { ExtensionSet::singleton(int_types::EXTENSION_ID) } /// fn get_type(&self) -> Type { int_types::INT_TYPES[5].clone() } /// } /// @@ -61,13 +59,6 @@ pub trait CustomConst: /// An identifier for the constant. fn name(&self) -> ValueName; - /// The extension(s) defining the custom constant - /// (a set to allow, say, a [List] of [USize]) - /// - /// [List]: crate::std_extensions::collections::list::LIST_TYPENAME - /// [USize]: crate::extension::prelude::usize_t - fn extension_reqs(&self) -> ExtensionSet; - /// Check the value. fn validate(&self) -> Result<(), CustomCheckFailure> { Ok(()) @@ -185,7 +176,6 @@ impl_box_clone!(CustomConst, CustomConstBoxClone); pub struct CustomSerialized { typ: Type, value: serde_json::Value, - extensions: ExtensionSet, } #[derive(Debug, Error)] @@ -206,15 +196,10 @@ pub struct DeserializeError { impl CustomSerialized { /// Creates a new [`CustomSerialized`]. 
- pub fn new( - typ: impl Into, - value: serde_json::Value, - exts: impl Into, - ) -> Self { + pub fn new(typ: impl Into, value: serde_json::Value) -> Self { Self { typ: typ.into(), value, - extensions: exts.into(), } } @@ -240,7 +225,6 @@ impl CustomSerialized { err, payload: cc.clone_box(), })?, - cc.extension_reqs(), ), }) } @@ -259,10 +243,10 @@ impl CustomSerialized { match cc.downcast::() { Ok(x) => Ok(*x), Err(cc) => { - let (typ, extension_reqs) = (cc.get_type(), cc.extension_reqs()); + let typ = cc.get_type(); let value = serialize_custom_const(cc.as_ref()) .map_err(|err| SerializeError { err, payload: cc })?; - Ok(Self::new(typ, value, extension_reqs)) + Ok(Self::new(typ, value)) } } } @@ -313,9 +297,6 @@ impl CustomConst for CustomSerialized { Some(self) == other.downcast_ref() } - fn extension_reqs(&self) -> ExtensionSet { - self.extensions.clone() - } fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, @@ -437,11 +418,8 @@ mod test { // check serialize_custom_const assert_eq!(expected_json, serialize_custom_const(&example.cc).unwrap()); - let expected_custom_serialized = CustomSerialized::new( - example.cc.get_type(), - expected_json, - example.cc.extension_reqs(), - ); + let expected_custom_serialized = + CustomSerialized::new(example.cc.get_type(), expected_json); // check all the try_from/try_into/into variations assert_eq!( @@ -494,11 +472,7 @@ mod test { let inner = example_custom_serialized().1; ( inner.clone(), - CustomSerialized::new( - inner.get_type(), - serialize_custom_const(&inner).unwrap(), - inner.extension_reqs(), - ), + CustomSerialized::new(inner.get_type(), serialize_custom_const(&inner).unwrap()), ) } @@ -545,7 +519,6 @@ mod proptest { use ::proptest::prelude::*; use crate::{ - extension::ExtensionSet, ops::constant::CustomSerialized, proptest::{any_serde_json_value, any_string}, types::Type, @@ -556,7 +529,6 @@ mod proptest { type Strategy = BoxedStrategy; fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { let typ = any::(); - let extensions = any::(); // here we manually construct a serialized `dyn CustomConst`. // The "c" and "v" come from the `typetag::serde` annotation on // `trait CustomConst`. 
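Editor's illustration (not part of the patch): a minimal sketch of the `CustomSerialized` API after this change, which drops the `ExtensionSet` argument from the constructor, mirroring the updated doctest in `hugr-core/src/ops/constant.rs` shown above. The `hugr_core::...` import paths and the helper function name are assumptions made for the sketch and are not confirmed by the diff.

    // Sketch under the assumptions above: build an opaque constant with the
    // new two-argument constructor (a runtime type plus a JSON payload).
    use hugr_core::extension::prelude::usize_t; // assumed re-export path
    use hugr_core::ops::constant::{CustomSerialized, OpaqueValue}; // assumed path

    fn example_opaque_value() -> OpaqueValue {
        // The `ExtensionSet` parameter is gone in this patch; only the type
        // and the serialized value are supplied.
        let cs = CustomSerialized::new(usize_t(), serde_json::Value::Null);
        OpaqueValue::new(cs)
    }
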
@@ -570,12 +542,8 @@ mod proptest { .collect::>() .into() }); - (typ, value, extensions) - .prop_map(|(typ, value, extensions)| CustomSerialized { - typ, - value, - extensions, - }) + (typ, value) + .prop_map(|(typ, value)| CustomSerialized { typ, value }) .boxed() } } diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index 49728980f..07c04f5c4 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -2,7 +2,6 @@ use std::borrow::Cow; -use crate::extension::ExtensionSet; use crate::types::{EdgeKind, Signature, Type, TypeRow}; use crate::Direction; @@ -20,8 +19,6 @@ pub struct TailLoop { pub just_outputs: TypeRow, /// Types that are appended to both input and output pub rest: TypeRow, - /// Extension requirements to execute the body - pub extension_delta: ExtensionSet, } impl_op_name!(TailLoop); @@ -37,9 +34,7 @@ impl DataflowOpTrait for TailLoop { // TODO: Store a cached signature let [inputs, outputs] = [&self.just_inputs, &self.just_outputs].map(|row| row.extend(self.rest.iter())); - Cow::Owned( - Signature::new(inputs, outputs).with_extension_delta(self.extension_delta.clone()), - ) + Cow::Owned(Signature::new(inputs, outputs)) } fn substitute(&self, subst: &crate::types::Substitution) -> Self { @@ -47,7 +42,6 @@ impl DataflowOpTrait for TailLoop { just_inputs: self.just_inputs.substitute(subst), just_outputs: self.just_outputs.substitute(subst), rest: self.rest.substitute(subst), - extension_delta: self.extension_delta.substitute(subst), } } } @@ -80,10 +74,10 @@ impl TailLoop { impl DataflowParent for TailLoop { fn inner_signature(&self) -> Cow<'_, Signature> { // TODO: Store a cached signature - Cow::Owned( - Signature::new(self.body_input_row(), self.body_output_row()) - .with_extension_delta(self.extension_delta.clone()), - ) + Cow::Owned(Signature::new( + self.body_input_row(), + self.body_output_row(), + )) } } @@ -97,8 +91,6 @@ pub struct Conditional { pub other_inputs: TypeRow, /// Output types pub outputs: TypeRow, - /// Extensions used to produce the outputs - pub extension_delta: ExtensionSet, } impl_op_name!(Conditional); @@ -115,10 +107,7 @@ impl DataflowOpTrait for Conditional { inputs .to_mut() .insert(0, Type::new_sum(self.sum_rows.clone())); - Cow::Owned( - Signature::new(inputs, self.outputs.clone()) - .with_extension_delta(self.extension_delta.clone()), - ) + Cow::Owned(Signature::new(inputs, self.outputs.clone())) } fn substitute(&self, subst: &crate::types::Substitution) -> Self { @@ -126,7 +115,6 @@ impl DataflowOpTrait for Conditional { sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(), other_inputs: self.other_inputs.substitute(subst), outputs: self.outputs.substitute(subst), - extension_delta: self.extension_delta.substitute(subst), } } } @@ -174,7 +162,6 @@ pub struct DataflowBlock { pub inputs: TypeRow, pub other_outputs: TypeRow, pub sum_rows: Vec, - pub extension_delta: ExtensionSet, } #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -213,10 +200,10 @@ impl DataflowParent for DataflowBlock { let sum_type = Type::new_sum(self.sum_rows.clone()); let mut node_outputs = vec![sum_type]; node_outputs.extend_from_slice(&self.other_outputs); - Cow::Owned( - Signature::new(self.inputs.clone(), TypeRow::from(node_outputs)) - .with_extension_delta(self.extension_delta.clone()), - ) + Cow::Owned(Signature::new( + self.inputs.clone(), + TypeRow::from(node_outputs), + )) } } @@ -237,10 +224,6 @@ impl OpTrait for DataflowBlock { Some(EdgeKind::ControlFlow) } - 
fn extension_delta(&self) -> ExtensionSet { - self.extension_delta.clone() - } - fn non_df_port_count(&self, dir: Direction) -> usize { match dir { Direction::Incoming => 1, @@ -253,7 +236,6 @@ impl OpTrait for DataflowBlock { inputs: self.inputs.substitute(subst), other_outputs: self.other_outputs.substitute(subst), sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(), - extension_delta: self.extension_delta.substitute(subst), } } } @@ -343,10 +325,6 @@ impl OpTrait for Case { "A case node inside a conditional" } - fn extension_delta(&self) -> ExtensionSet { - self.signature.runtime_reqs.clone() - } - fn tag(&self) -> OpTag { ::TAG } @@ -373,10 +351,7 @@ impl Case { #[cfg(test)] mod test { use crate::{ - extension::{ - prelude::{qb_t, usize_t, PRELUDE_ID}, - ExtensionSet, - }, + extension::prelude::{qb_t, usize_t}, ops::{Conditional, DataflowOpTrait, DataflowParent}, types::{Signature, Substitution, Type, TypeArg, TypeBound, TypeRV}, }; @@ -391,19 +366,12 @@ mod test { inputs: vec![usize_t(), tv0.clone()].into(), other_outputs: vec![tv0.clone()].into(), sum_rows: vec![usize_t().into(), vec![qb_t(), tv0.clone()].into()], - extension_delta: ExtensionSet::type_var(1), }; - let dfb2 = dfb.substitute(&Substitution::new(&[ - qb_t().into(), - TypeArg::Extensions { - es: PRELUDE_ID.into(), - }, - ])); + let dfb2 = dfb.substitute(&Substitution::new(&[qb_t().into()])); let st = Type::new_sum(vec![vec![usize_t()], vec![qb_t(); 2]]); assert_eq!( dfb2.inner_signature(), Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()]) - .with_extension_delta(PRELUDE_ID) ); } @@ -414,7 +382,6 @@ mod test { sum_rows: vec![usize_t().into(), tv1.clone().into()], other_inputs: vec![Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Any))].into(), outputs: vec![usize_t(), tv1].into(), - extension_delta: ExtensionSet::new(), }; let cond2 = cond.substitute(&Substitution::new(&[ TypeArg::Sequence { @@ -439,21 +406,14 @@ mod test { just_inputs: vec![qb_t(), tv0.clone()].into(), just_outputs: vec![tv0.clone(), qb_t()].into(), rest: vec![tv0.clone()].into(), - extension_delta: ExtensionSet::type_var(1), }; - let tail2 = tail_loop.substitute(&Substitution::new(&[ - usize_t().into(), - TypeArg::Extensions { - es: PRELUDE_ID.into(), - }, - ])); + let tail2 = tail_loop.substitute(&Substitution::new(&[usize_t().into()])); assert_eq!( tail2.signature(), Signature::new( vec![qb_t(), usize_t(), usize_t()], vec![usize_t(), qb_t(), usize_t()] ) - .with_extension_delta(PRELUDE_ID) ); } } diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 6b907c947..5f5a13427 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -233,7 +233,6 @@ impl OpaqueOp { args: impl Into>, signature: Signature, ) -> Self { - let signature = signature.with_extension_delta(extension.clone()); Self { extension, name: name.into(), @@ -382,10 +381,7 @@ mod test { assert_eq!(op.name(), "res.op"); assert_eq!(DataflowOpTrait::description(&op), "desc"); assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]); - assert_eq!( - op.signature().as_ref(), - &sig.with_extension_delta(op.extension().clone()) - ); + assert_eq!(op.signature().as_ref(), &sig); } #[test] diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index c63c44b87..ba8f81c0c 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -4,7 +4,7 @@ use std::borrow::Cow; use super::{impl_op_name, OpTag, OpTrait}; -use crate::extension::{ExtensionSet, SignatureError}; +use 
crate::extension::SignatureError; use crate::ops::StaticTag; use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow}; use crate::{type_row, IncomingPort}; @@ -151,15 +151,15 @@ impl OpTrait for T { fn description(&self) -> &str { DataflowOpTrait::description(self) } + fn tag(&self) -> OpTag { T::TAG } + fn dataflow_signature(&self) -> Option> { Some(DataflowOpTrait::signature(self)) } - fn extension_delta(&self) -> ExtensionSet { - DataflowOpTrait::signature(self).runtime_reqs.clone() - } + fn other_input(&self) -> Option { DataflowOpTrait::other_input(self) } diff --git a/hugr-core/src/package.rs b/hugr-core/src/package.rs index 5e1fecdb6..a7c48b3a2 100644 --- a/hugr-core/src/package.rs +++ b/hugr-core/src/package.rs @@ -224,9 +224,6 @@ impl Package { // As a fallback, try to load a hugr json. if let Ok(mut hugr) = serde_json::from_value::(val) { hugr.resolve_extension_defs(extension_registry)?; - if cfg!(feature = "extension_inference") { - hugr.infer_extensions(false)?; - } return Ok(Package::from_hugr(hugr)?); } diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index abeb61ab0..ea1004d92 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -8,7 +8,7 @@ use crate::extension::prelude::sum_with_error; use crate::extension::prelude::{bool_t, string_type, usize_t}; use crate::extension::simple_op::{HasConcrete, HasDef}; use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}; -use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc}; +use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc}; use crate::ops::OpName; use crate::ops::{custom::ExtensionOp, NamedOp}; use crate::std_extensions::arithmetic::int_ops::int_polytype; @@ -167,12 +167,6 @@ lazy_static! { /// Extension for conversions between integers and floats. pub static ref EXTENSION: Arc = { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { - extension.add_requirements( - ExtensionSet::from_iter(vec![ - super::int_types::EXTENSION_ID, - super::float_types::EXTENSION_ID, - ])); - ConvertOpDef::load_all_ops(extension, extension_ref).unwrap(); }) }; diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index 08b478535..f61353528 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -9,7 +9,7 @@ use crate::{ extension::{ prelude::{bool_t, string_type}, simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError}, - ExtensionId, ExtensionSet, OpDef, SignatureFunc, + ExtensionId, OpDef, SignatureFunc, }, types::Signature, Extension, @@ -111,7 +111,6 @@ lazy_static! { /// Extension for basic float operations. 
pub static ref EXTENSION: Arc = { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { - extension.add_requirements(ExtensionSet::singleton(super::int_types::EXTENSION_ID)); FloatOps::load_all_ops(extension, extension_ref).unwrap(); }) }; diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index 200e9dcbf..b5a741953 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -5,7 +5,7 @@ use std::sync::{Arc, Weak}; use crate::ops::constant::{TryHash, ValueName}; use crate::types::TypeName; use crate::{ - extension::{ExtensionId, ExtensionSet}, + extension::ExtensionId, ops::constant::CustomConst, types::{CustomType, Type, TypeBound}, Extension, @@ -97,10 +97,6 @@ impl CustomConst for ConstF64 { fn equal_consts(&self, _: &dyn CustomConst) -> bool { false } - - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(EXTENSION_ID) - } } lazy_static! { diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index d0ae7baa7..69939d4e1 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -14,7 +14,7 @@ use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRowRV}; use crate::utils::collect_array; use crate::{ - extension::{ExtensionId, ExtensionSet, SignatureError}, + extension::{ExtensionId, SignatureError}, types::{type_param::TypeArg, Type}, Extension, }; @@ -252,7 +252,6 @@ lazy_static! { /// Extension for basic integer operations. pub static ref EXTENSION: Arc = { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { - extension.add_requirements(ExtensionSet::singleton(super::int_types::EXTENSION_ID)); IntOpDef::load_all_ops(extension, extension_ref).unwrap(); }) }; @@ -377,7 +376,7 @@ mod test { .unwrap() .signature() .as_ref(), - &Signature::new(int_type(3), int_type(4)).with_extension_delta(EXTENSION_ID) + &Signature::new(int_type(3), int_type(4)) ); assert_eq!( IntOpDef::iwiden_s @@ -386,7 +385,7 @@ mod test { .unwrap() .signature() .as_ref(), - &Signature::new_endo(int_type(3)).with_extension_delta(EXTENSION_ID) + &Signature::new_endo(int_type(3)) ); assert_eq!( IntOpDef::inarrow_s @@ -396,7 +395,6 @@ mod test { .signature() .as_ref(), &Signature::new(int_type(3), sum_ty_with_err(int_type(3))) - .with_extension_delta(EXTENSION_ID) ); assert!( IntOpDef::iwiden_u @@ -414,7 +412,6 @@ mod test { .signature() .as_ref(), &Signature::new(int_type(2), sum_ty_with_err(int_type(1))) - .with_extension_delta(EXTENSION_ID) ); assert!(IntOpDef::inarrow_u diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 1342dd932..022f4d61e 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, Weak}; use crate::ops::constant::ValueName; use crate::types::TypeName; use crate::{ - extension::{ExtensionId, ExtensionSet}, + extension::ExtensionId, ops::constant::CustomConst, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, @@ -184,10 +184,6 @@ impl CustomConst for ConstInt { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(EXTENSION_ID) - } - fn get_type(&self) -> Type { int_type(type_arg(self.log_width)) } diff --git 
a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index fac12b1bf..2e7ee5b75 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -17,7 +17,7 @@ use crate::extension::resolution::{ WeakExtensionRegistry, }; use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; -use crate::extension::{ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound}; +use crate::extension::{ExtensionId, SignatureError, TypeDef, TypeDefBound}; use crate::ops::constant::{maybe_hash_values, CustomConst, TryHash, ValueName}; use crate::ops::{ExtensionOp, OpName, Value}; use crate::types::type_param::{TypeArg, TypeParam}; @@ -143,11 +143,6 @@ impl CustomConst for ArrayValue { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::union_over(self.values.iter().map(Value::extension_reqs)) - .union(EXTENSION_ID.into()) - } - fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, diff --git a/hugr-core/src/std_extensions/collections/array/array_repeat.rs b/hugr-core/src/std_extensions/collections/array/array_repeat.rs index 544866970..a31505cb2 100644 --- a/hugr-core/src/std_extensions/collections/array/array_repeat.rs +++ b/hugr-core/src/std_extensions/collections/array/array_repeat.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, Weak}; use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; -use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDef}; +use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; use crate::ops::{ExtensionOp, NamedOp, OpName}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncValueType, PolyFuncTypeRV, Signature, Type, TypeBound}; @@ -42,16 +42,10 @@ impl FromStr for ArrayRepeatDef { impl ArrayRepeatDef { /// To avoid recursion when defining the extension, take the type definition as an argument. fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![ - TypeParam::max_nat(), - TypeBound::Any.into(), - TypeParam::Extensions, - ]; + let params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; let n = TypeArg::new_var_use(0, TypeParam::max_nat()); let t = Type::new_var_use(1, TypeBound::Any); - let es = ExtensionSet::type_var(2); - let func = - Type::new_function(Signature::new(vec![], vec![t.clone()]).with_extension_delta(es)); + let func = Type::new_function(Signature::new(vec![], vec![t.clone()])); let array_ty = instantiate_array(array_def, n, t).expect("Array type instantiation failed"); PolyFuncTypeRV::new(params, FuncValueType::new(vec![func], array_ty)).into() } @@ -109,18 +103,12 @@ pub struct ArrayRepeat { pub elem_ty: Type, /// Size of the array. pub size: u64, - /// The extensions required by the function that generates the array elements. - pub extension_reqs: ExtensionSet, } impl ArrayRepeat { /// Creates a new array repeat op. 
- pub fn new(elem_ty: Type, size: u64, extension_reqs: ExtensionSet) -> Self { - ArrayRepeat { - elem_ty, - size, - extension_reqs, - } + pub fn new(elem_ty: Type, size: u64) -> Self { + ArrayRepeat { elem_ty, size } } } @@ -143,9 +131,6 @@ impl MakeExtensionOp for ArrayRepeat { vec![ TypeArg::BoundedNat { n: self.size }, self.elem_ty.clone().into(), - TypeArg::Extensions { - es: self.extension_reqs.clone(), - }, ] } } @@ -169,8 +154,8 @@ impl HasConcrete for ArrayRepeatDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }, TypeArg::Extensions { es }] => { - Ok(ArrayRepeat::new(ty.clone(), *n, es.clone())) + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] => { + Ok(ArrayRepeat::new(ty.clone(), *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), } @@ -179,7 +164,7 @@ impl HasConcrete for ArrayRepeatDef { #[cfg(test)] mod tests { - use crate::std_extensions::collections::array::{array_type, EXTENSION_ID}; + use crate::std_extensions::collections::array::array_type; use crate::{ extension::prelude::qb_t, ops::{OpTrait, OpType}, @@ -190,7 +175,7 @@ mod tests { #[test] fn test_repeat_def() { - let op = ArrayRepeat::new(qb_t(), 2, ExtensionSet::singleton(EXTENSION_ID)); + let op = ArrayRepeat::new(qb_t(), 2); let optype: OpType = op.clone().into(); let new_op: ArrayRepeat = optype.cast().unwrap(); assert_eq!(new_op, op); @@ -200,8 +185,7 @@ mod tests { fn test_repeat() { let size = 2; let element_ty = qb_t(); - let es = ExtensionSet::singleton(EXTENSION_ID); - let op = ArrayRepeat::new(element_ty.clone(), size, es.clone()); + let op = ArrayRepeat::new(element_ty.clone(), size); let optype: OpType = op.into(); @@ -210,10 +194,7 @@ mod tests { assert_eq!( sig.io(), ( - &vec![Type::new_function( - Signature::new(vec![], vec![qb_t()]).with_extension_delta(es) - )] - .into(), + &vec![Type::new_function(Signature::new(vec![], vec![qb_t()]))].into(), &vec![array_type(size, element_ty.clone())].into(), ) ); diff --git a/hugr-core/src/std_extensions/collections/array/array_scan.rs b/hugr-core/src/std_extensions/collections/array/array_scan.rs index 86a0fe94e..8064a73d0 100644 --- a/hugr-core/src/std_extensions/collections/array/array_scan.rs +++ b/hugr-core/src/std_extensions/collections/array/array_scan.rs @@ -8,7 +8,7 @@ use itertools::Itertools; use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; -use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDef}; +use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; use crate::ops::{ExtensionOp, NamedOp, OpName}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncTypeBase, PolyFuncTypeRV, RowVariable, Type, TypeBound, TypeRV}; @@ -51,13 +51,11 @@ impl ArrayScanDef { TypeBound::Any.into(), TypeBound::Any.into(), TypeParam::new_list(TypeBound::Any), - TypeParam::Extensions, ]; let n = TypeArg::new_var_use(0, TypeParam::max_nat()); let t1 = Type::new_var_use(1, TypeBound::Any); let t2 = Type::new_var_use(2, TypeBound::Any); let s = TypeRV::new_row_var_use(3, TypeBound::Any); - let es = ExtensionSet::type_var(4); PolyFuncTypeRV::new( params, FuncTypeBase::::new( @@ -65,13 +63,10 @@ impl ArrayScanDef { instantiate_array(array_def, n.clone(), t1.clone()) .expect("Array type instantiation failed") .into(), - Type::new_function( - FuncTypeBase::::new( - vec![t1.into(), s.clone()], - vec![t2.clone().into(), s.clone()], - ) - 
.with_extension_delta(es), - ) + Type::new_function(FuncTypeBase::::new( + vec![t1.into(), s.clone()], + vec![t2.clone().into(), s.clone()], + )) .into(), s.clone(), ], @@ -145,25 +140,16 @@ pub struct ArrayScan { pub acc_tys: Vec, /// Size of the array. pub size: u64, - /// The extensions required by the scan function. - pub extension_reqs: ExtensionSet, } impl ArrayScan { /// Creates a new array scan op. - pub fn new( - src_ty: Type, - tgt_ty: Type, - acc_tys: Vec, - size: u64, - extension_reqs: ExtensionSet, - ) -> Self { + pub fn new(src_ty: Type, tgt_ty: Type, acc_tys: Vec, size: u64) -> Self { ArrayScan { src_ty, tgt_ty, acc_tys, size, - extension_reqs, } } } @@ -191,9 +177,6 @@ impl MakeExtensionOp for ArrayScan { TypeArg::Sequence { elems: self.acc_tys.clone().into_iter().map_into().collect(), }, - TypeArg::Extensions { - es: self.extension_reqs.clone(), - }, ] } } @@ -217,7 +200,7 @@ impl HasConcrete for ArrayScanDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty: src_ty }, TypeArg::Type { ty: tgt_ty }, TypeArg::Sequence { elems: acc_tys }, TypeArg::Extensions { es }] => + [TypeArg::BoundedNat { n }, TypeArg::Type { ty: src_ty }, TypeArg::Type { ty: tgt_ty }, TypeArg::Sequence { elems: acc_tys }] => { let acc_tys: Result<_, OpLoadError> = acc_tys .iter() @@ -226,13 +209,7 @@ impl HasConcrete for ArrayScanDef { _ => Err(SignatureError::InvalidTypeArgs.into()), }) .collect(); - Ok(ArrayScan::new( - src_ty.clone(), - tgt_ty.clone(), - acc_tys?, - *n, - es.clone(), - )) + Ok(ArrayScan::new(src_ty.clone(), tgt_ty.clone(), acc_tys?, *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), } @@ -243,7 +220,7 @@ impl HasConcrete for ArrayScanDef { mod tests { use crate::extension::prelude::usize_t; - use crate::std_extensions::collections::array::{array_type, EXTENSION_ID}; + use crate::std_extensions::collections::array::array_type; use crate::{ extension::prelude::{bool_t, qb_t}, ops::{OpTrait, OpType}, @@ -254,13 +231,7 @@ mod tests { #[test] fn test_scan_def() { - let op = ArrayScan::new( - bool_t(), - qb_t(), - vec![usize_t()], - 2, - ExtensionSet::singleton(EXTENSION_ID), - ); + let op = ArrayScan::new(bool_t(), qb_t(), vec![usize_t()], 2); let optype: OpType = op.clone().into(); let new_op: ArrayScan = optype.cast().unwrap(); assert_eq!(new_op, op); @@ -271,9 +242,8 @@ mod tests { let size = 2; let src_ty = qb_t(); let tgt_ty = bool_t(); - let es = ExtensionSet::singleton(EXTENSION_ID); - let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size, es.clone()); + let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); @@ -282,9 +252,7 @@ mod tests { ( &vec![ array_type(size, src_ty.clone()), - Type::new_function( - Signature::new(vec![src_ty], vec![tgt_ty.clone()]).with_extension_delta(es) - ) + Type::new_function(Signature::new(vec![src_ty], vec![tgt_ty.clone()])) ] .into(), &vec![array_type(size, tgt_ty)].into(), @@ -299,14 +267,12 @@ mod tests { let tgt_ty = bool_t(); let acc_ty1 = usize_t(); let acc_ty2 = qb_t(); - let es = ExtensionSet::singleton(EXTENSION_ID); let op = ArrayScan::new( src_ty.clone(), tgt_ty.clone(), vec![acc_ty1.clone(), acc_ty2.clone()], size, - es.clone(), ); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); @@ -316,13 +282,10 @@ mod tests { ( &vec![ array_type(size, src_ty.clone()), - Type::new_function( - Signature::new( - vec![src_ty, 
acc_ty1.clone(), acc_ty2.clone()], - vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()] - ) - .with_extension_delta(es) - ), + Type::new_function(Signature::new( + vec![src_ty, acc_ty1.clone(), acc_ty2.clone()], + vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()] + )), acc_ty1.clone(), acc_ty2.clone() ] diff --git a/hugr-core/src/std_extensions/collections/array/op_builder.rs b/hugr-core/src/std_extensions/collections/array/op_builder.rs index 46338dd43..623443347 100644 --- a/hugr-core/src/std_extensions/collections/array/op_builder.rs +++ b/hugr-core/src/std_extensions/collections/array/op_builder.rs @@ -213,9 +213,7 @@ impl ArrayOpBuilder for D {} #[cfg(test)] mod test { - use crate::extension::prelude::PRELUDE_ID; - use crate::extension::ExtensionSet; - use crate::std_extensions::collections::array::{self, array_type}; + use crate::std_extensions::collections::array::array_type; use crate::{ builder::{DFGBuilder, HugrBuilder}, extension::prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _}, @@ -229,11 +227,7 @@ mod test { #[rstest::fixture] #[default(DFGBuilder)] fn all_array_ops( - #[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW) - .with_extension_delta(ExtensionSet::from_iter([ - PRELUDE_ID, - array::EXTENSION_ID - ]))).unwrap())] + #[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW)).unwrap())] mut builder: B, ) -> B { let us0 = builder.add_load_value(ConstUsize::new(0)); diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 98804bab0..3ffb4d9a0 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -25,7 +25,7 @@ use crate::types::{TypeName, TypeRowRV}; use crate::{ extension::{ simple_op::{MakeExtensionOp, OpLoadError}, - ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound, + ExtensionId, SignatureError, TypeDef, TypeDefBound, }, ops::constant::CustomConst, ops::{custom::ExtensionOp, NamedOp}, @@ -126,11 +126,6 @@ impl CustomConst for ListValue { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs)) - .union(EXTENSION_ID.into()) - } - fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, diff --git a/hugr-core/src/std_extensions/collections/static_array.rs b/hugr-core/src/std_extensions/collections/static_array.rs index 9d2259e0b..05e5651a1 100644 --- a/hugr-core/src/std_extensions/collections/static_array.rs +++ b/hugr-core/src/std_extensions/collections/static_array.rs @@ -28,7 +28,7 @@ use crate::{ try_from_name, HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }, - ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDef, + ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef, }, ops::{ constant::{maybe_hash_values, CustomConst, TryHash, ValueName}, @@ -128,11 +128,6 @@ impl CustomConst for StaticArrayValue { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::union_over(self.get_contents().iter().map(Value::extension_reqs)) - .union(EXTENSION_ID.into()) - } - fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, @@ -404,7 +399,7 @@ impl StaticArrayOpBuilder for T {} mod test { use crate::{ builder::{DFGBuilder, DataflowHugr as _}, - extension::prelude::{qb_t, ConstUsize, PRELUDE_ID}, + 
extension::prelude::{qb_t, ConstUsize}, type_row, }; @@ -419,10 +414,10 @@ mod test { #[test] fn all_ops() { let _ = { - let mut builder = DFGBuilder::new( - Signature::new(type_row![], Type::from(option_type(usize_t()))) - .with_extension_delta(ExtensionSet::from_iter([PRELUDE_ID, EXTENSION_ID])), - ) + let mut builder = DFGBuilder::new(Signature::new( + type_row![], + Type::from(option_type(usize_t())), + )) .unwrap(); let array = builder.add_load_value( StaticArrayValue::try_new( diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index fc0b1bbb4..6d77ae52d 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -268,10 +268,7 @@ pub(crate) mod test { let in_row = vec![bool_t(), float64_type()]; let hugr = { - let mut builder = DFGBuilder::new( - Signature::new(in_row.clone(), type_row![]).with_extension_delta(EXTENSION_ID), - ) - .unwrap(); + let mut builder = DFGBuilder::new(Signature::new(in_row.clone(), type_row![])).unwrap(); let in_wires: [Wire; 2] = builder.input_wires_arr(); for (ty, w) in in_row.into_iter().zip(in_wires.iter()) { diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 67bc7fbf5..885b6bae8 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -277,7 +277,6 @@ pub(crate) mod test { let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); let body_type = Signature::new_endo(Type::new_extension(list_def.instantiate([tv])?)); for decl in [ - TypeParam::Extensions, TypeParam::List { param: Box::new(TypeParam::max_nat()), }, diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index 28c39fa08..78965f1b6 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -37,8 +37,6 @@ pub struct FuncTypeBase { /// Value outputs of the function. #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] pub output: TypeRowBase, - /// The extensions the function specifies as required at runtime. - pub runtime_reqs: ExtensionSet, } /// The concept of "signature" in the spec - the edges required to/from a node @@ -55,22 +53,10 @@ pub type Signature = FuncTypeBase; pub type FuncValueType = FuncTypeBase; impl FuncTypeBase { - /// Builder method, add runtime_reqs to a FunctionType - pub fn with_extension_delta(mut self, rs: impl Into) -> Self { - self.runtime_reqs = self.runtime_reqs.union(rs.into()); - self - } - - /// Shorthand for adding the prelude extension to a FunctionType. - pub fn with_prelude(self) -> Self { - self.with_extension_delta(crate::extension::prelude::PRELUDE_ID) - } - pub(crate) fn substitute(&self, tr: &Substitution) -> Self { Self { input: self.input.substitute(tr), output: self.output.substitute(tr), - runtime_reqs: self.runtime_reqs.substitute(tr), } } @@ -79,7 +65,6 @@ impl FuncTypeBase { Self { input: input.into(), output: output.into(), - runtime_reqs: ExtensionSet::new(), } } @@ -117,19 +102,10 @@ impl FuncTypeBase { pub(super) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { self.input.validate(var_decls)?; - self.output.validate(var_decls)?; - self.runtime_reqs.validate(var_decls) + self.output.validate(var_decls) } /// Returns a registry with the concrete extensions used by this signature. - /// - /// Note that extension type parameters are not included, as they have not - /// been instantiated yet. - /// - /// This method only returns extensions actually used by the types in the - /// signature. 
The extension deltas added via [`Self::with_extension_delta`] - /// refer to _runtime_ extensions, which may not be in all places that - /// manipulate a HUGR. pub fn used_extensions(&self) -> Result { let mut used = WeakExtensionRegistry::default(); let mut missing = ExtensionSet::new(); @@ -167,7 +143,6 @@ impl Default for FuncTypeBase { Self { input: Default::default(), output: Default::default(), - runtime_reqs: Default::default(), } } } @@ -290,9 +265,6 @@ impl Display for FuncTypeBase { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.input.fmt(f)?; f.write_str(" -> ")?; - if !self.runtime_reqs.is_empty() { - self.runtime_reqs.fmt(f)?; - } self.output.fmt(f) } } @@ -303,7 +275,7 @@ impl TryFrom for Signature { fn try_from(value: FuncValueType) -> Result { let input: TypeRow = value.input.try_into()?; let output: TypeRow = value.output.try_into()?; - Ok(Self::new(input, output).with_extension_delta(value.runtime_reqs)) + Ok(Self::new(input, output)) } } @@ -312,16 +284,13 @@ impl From for FuncValueType { Self { input: value.input.into(), output: value.output.into(), - runtime_reqs: value.runtime_reqs, } } } impl PartialEq> for FuncTypeBase { fn eq(&self, other: &FuncTypeBase) -> bool { - self.input == other.input - && self.output == other.output - && self.runtime_reqs == other.runtime_reqs + self.input == other.input && self.output == other.output } } diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index db2efecc6..e8fa28346 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -15,7 +15,6 @@ use super::{ check_typevar_decl, NoRV, RowVariable, Substitution, Transformable, Type, TypeBase, TypeBound, TypeTransformer, }; -use crate::extension::ExtensionSet; use crate::extension::SignatureError; /// The upper non-inclusive bound of a [`TypeParam::BoundedNat`] @@ -92,10 +91,6 @@ pub enum TypeParam { /// The [TypeParam]s contained in the tuple. params: Vec, }, - /// Argument is a [TypeArg::Extensions]. A set of [ExtensionId]s. - /// - /// [ExtensionId]: crate::extension::ExtensionId - Extensions, } impl TypeParam { @@ -131,7 +126,6 @@ impl TypeParam { (TypeParam::Tuple { params: es1 }, TypeParam::Tuple { params: es2 }) => { es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.contains(e2)) } - (TypeParam::Extensions, TypeParam::Extensions) => true, _ => false, } } @@ -184,18 +178,9 @@ pub enum TypeArg { /// List of element types elems: Vec, }, - /// Instance of [TypeParam::Extensions], providing the extension ids. - #[display("Exts({})", { - use itertools::Itertools as _; - es.iter().map(|t|t.to_string()).join(",") - })] - Extensions { - #[allow(missing_docs)] - es: ExtensionSet, - }, /// Variable (used in type schemes or inside polymorphic functions), /// but not a [TypeArg::Type] (not even a row variable i.e. [TypeParam::List] of type) - /// nor [TypeArg::Extensions] - see [TypeArg::new_var_use] + /// - see [TypeArg::new_var_use] #[display("{v}")] Variable { #[allow(missing_docs)] @@ -239,14 +224,7 @@ impl From> for TypeArg { } } -impl From for TypeArg { - fn from(es: ExtensionSet) -> Self { - Self::Extensions { es } - } -} - -/// Variable in a TypeArg, that is neither a [TypeArg::Extensions] -/// nor a single [TypeArg::Type] (i.e. not a [Type::new_var_use] +/// Variable in a TypeArg, that is not a single [TypeArg::Type] (i.e. not a [Type::new_var_use] /// - it might be a [Type::new_row_var_use]). 
#[derive( Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize, derive_more::Display, @@ -270,10 +248,6 @@ impl TypeArg { // as a TypeArg::Type because the latter stores a Type i.e. only a single type, // not a RowVariable. TypeParam::Type { b } => Type::new_var_use(idx, b).into(), - // Prevent TypeArg::Variable(idx, TypeParam::Extensions) - TypeParam::Extensions => TypeArg::Extensions { - es: ExtensionSet::type_var(idx), - }, _ => TypeArg::Variable { v: TypeArgVariable { idx, @@ -314,7 +288,6 @@ impl TypeArg { TypeArg::Type { ty } => ty.validate(var_decls), TypeArg::BoundedNat { .. } | TypeArg::String { .. } => Ok(()), TypeArg::Sequence { elems } => elems.iter().try_for_each(|a| a.validate(var_decls)), - TypeArg::Extensions { es: _ } => Ok(()), TypeArg::Variable { v: TypeArgVariable { idx, cached_decl }, } => { @@ -362,9 +335,6 @@ impl TypeArg { }; TypeArg::Sequence { elems } } - TypeArg::Extensions { es } => TypeArg::Extensions { - es: es.substitute(t), - }, TypeArg::Variable { v: TypeArgVariable { idx, cached_decl }, } => t.apply_var(*idx, cached_decl), @@ -377,10 +347,9 @@ impl Transformable for TypeArg { match self { TypeArg::Type { ty } => ty.transform(tr), TypeArg::Sequence { elems } => elems.transform(tr), - TypeArg::BoundedNat { .. } - | TypeArg::String { .. } - | TypeArg::Extensions { .. } - | TypeArg::Variable { .. } => Ok(false), + TypeArg::BoundedNat { .. } | TypeArg::String { .. } | TypeArg::Variable { .. } => { + Ok(false) + } } } } @@ -449,7 +418,6 @@ pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgErr } (TypeArg::String { .. }, TypeParam::String) => Ok(()), - (TypeArg::Extensions { .. }, TypeParam::Extensions) => Ok(()), _ => Err(TypeArgError::TypeMismatch { arg: arg.clone(), param: param.clone(), @@ -659,7 +627,6 @@ mod test { use proptest::prelude::*; use super::super::{TypeArg, TypeArgVariable, TypeParam, UpperBound}; - use crate::extension::ExtensionSet; use crate::proptest::RecursionDepth; use crate::types::{Type, TypeBound}; @@ -680,7 +647,6 @@ mod test { use prop::collection::vec; use prop::strategy::Union; let mut strat = Union::new([ - Just(Self::Extensions).boxed(), Just(Self::String).boxed(), any::().prop_map(|b| Self::Type { b }).boxed(), any::() @@ -711,9 +677,6 @@ mod test { let mut strat = Union::new([ any::().prop_map(|n| Self::BoundedNat { n }).boxed(), any::().prop_map(|arg| Self::String { arg }).boxed(), - any::() - .prop_map(|es| Self::Extensions { es }) - .boxed(), any_with::(depth) .prop_map(|ty| Self::Type { ty }) .boxed(), diff --git a/hugr-llvm/src/emit/ops/cfg.rs b/hugr-llvm/src/emit/ops/cfg.rs index 4d62350be..12f22d2f7 100644 --- a/hugr-llvm/src/emit/ops/cfg.rs +++ b/hugr-llvm/src/emit/ops/cfg.rs @@ -219,7 +219,7 @@ impl<'c, 'hugr, H: HugrView> CfgEmitter<'c, 'hugr, H> { mod test { use hugr_core::builder::{Dataflow, DataflowSubContainer, SubContainer}; use hugr_core::extension::prelude::{self, bool_t}; - use hugr_core::extension::{ExtensionRegistry, ExtensionSet}; + use hugr_core::extension::ExtensionRegistry; use hugr_core::ops::Value; use hugr_core::std_extensions::arithmetic::int_types::{self, INT_TYPES}; use hugr_core::type_row; @@ -239,7 +239,6 @@ mod test { llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_int_extensions); let t1 = INT_TYPES[0].clone(); let t2 = INT_TYPES[1].clone(); - let es = ExtensionSet::from_iter([int_types::EXTENSION_ID, prelude::PRELUDE_ID]); let hugr = SimpleHugrConfig::new() .with_ins(vec![t1.clone(), t2.clone()]) .with_outs(t2.clone()) @@ -250,11 +249,7 @@ mod 
test { .finish(|mut builder| { let [in1, in2] = builder.input_wires_arr(); let mut cfg_builder = builder - .cfg_builder_exts( - [(t1.clone(), in1), (t2.clone(), in2)], - t2.clone().into(), - es.clone(), - ) + .cfg_builder([(t1.clone(), in1), (t2.clone(), in2)], t2.clone().into()) .unwrap(); // entry block takes (t1,t2) and unconditionally branches to b1 with no other outputs diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index 3f6977a8c..d53d2ef0c 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -4,13 +4,8 @@ use anyhow::{anyhow, Result}; use hugr_core::builder::{ BuildHandle, Container, DFGWrapper, HugrBuilder, ModuleBuilder, SubContainer, }; -use hugr_core::extension::prelude::PRELUDE_ID; -use hugr_core::extension::{ExtensionRegistry, ExtensionSet}; +use hugr_core::extension::ExtensionRegistry; use hugr_core::ops::handle::FuncID; -use hugr_core::std_extensions::arithmetic::{ - conversions, float_ops, float_types, int_ops, int_types, -}; -use hugr_core::std_extensions::{collections, logic}; use hugr_core::types::TypeRow; use hugr_core::{Hugr, HugrView, Node}; use inkwell::module::Module; @@ -150,23 +145,7 @@ impl SimpleHugrConfig { ) -> Hugr { let mut mod_b = ModuleBuilder::new(); let func_b = mod_b - .define_function( - "main", - HugrFuncType::new(self.ins, self.outs).with_extension_delta( - ExtensionSet::from_iter([ - PRELUDE_ID, - int_types::EXTENSION_ID, - int_ops::EXTENSION_ID, - float_types::EXTENSION_ID, - float_ops::EXTENSION_ID, - conversions::EXTENSION_ID, - logic::EXTENSION_ID, - collections::array::EXTENSION_ID, - collections::list::EXTENSION_ID, - collections::static_array::EXTENSION_ID, - ]), - ), - ) + .define_function("main", HugrFuncType::new(self.ins, self.outs)) .unwrap(); make(func_b, &self.extensions); @@ -265,7 +244,7 @@ mod test_fns { use hugr_core::ops::{CallIndirect, Tag, Value}; use hugr_core::std_extensions::arithmetic::int_ops::{self}; - use hugr_core::std_extensions::arithmetic::int_types::ConstInt; + use hugr_core::std_extensions::arithmetic::int_types::{self, ConstInt}; use hugr_core::std_extensions::STD_REG; use hugr_core::types::{Signature, Type, TypeRow}; use hugr_core::{type_row, Hugr}; diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index 55dcecefc..0216e9014 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -708,7 +708,6 @@ pub fn emit_scan_op<'c, H: HugrView>( mod test { use hugr_core::builder::Container as _; use hugr_core::extension::prelude::either_type; - use hugr_core::extension::ExtensionSet; use hugr_core::ops::Tag; use hugr_core::std_extensions::collections::array::{self, array_type, ArrayRepeat, ArrayScan}; use hugr_core::std_extensions::STD_REG; @@ -854,16 +853,6 @@ mod test { ]) } - fn exec_extension_set() -> ExtensionSet { - ExtensionSet::from_iter([ - int_types::EXTENSION_ID, - int_ops::EXTENSION_ID, - logic::EXTENSION_ID, - prelude::PRELUDE_ID, - array::EXTENSION_ID, - ]) - } - #[rstest] #[case(0, 1)] #[case(1, 2)] @@ -1223,16 +1212,12 @@ mod test { .with_extensions(exec_registry()) .finish(|mut builder| { let mut func = builder - .define_function( - "foo", - Signature::new(vec![], vec![int_ty.clone()]) - .with_extension_delta(exec_extension_set()), - ) + .define_function("foo", Signature::new(vec![], vec![int_ty.clone()])) .unwrap(); let v = func.add_load_value(ConstInt::new_u(6, value).unwrap()); let func_id = func.finish_with_outputs(vec![v]).unwrap(); let func_v = 
builder.load_func(func_id.handle(), &[]).unwrap(); - let repeat = ArrayRepeat::new(int_ty.clone(), size, exec_extension_set()); + let repeat = ArrayRepeat::new(int_ty.clone(), size); let arr = builder .add_dataflow_op(repeat, vec![func_v]) .unwrap() @@ -1280,8 +1265,7 @@ mod test { let mut func = builder .define_function( "foo", - Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]) - .with_extension_delta(exec_extension_set()), + Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]), ) .unwrap(); let [elem] = func.input_wires_arr(); @@ -1289,13 +1273,7 @@ mod test { let out = func.add_iadd(6, elem, delta).unwrap(); let func_id = func.finish_with_outputs(vec![out]).unwrap(); let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); - let scan = ArrayScan::new( - int_ty.clone(), - int_ty.clone(), - vec![], - size, - exec_extension_set(), - ); + let scan = ArrayScan::new(int_ty.clone(), int_ty.clone(), vec![], size); let mut arr = builder .add_dataflow_op(scan, [arr, func_v]) .unwrap() @@ -1357,8 +1335,7 @@ mod test { Signature::new( vec![int_ty.clone(), int_ty.clone()], vec![Type::UNIT, int_ty.clone()], - ) - .with_extension_delta(exec_extension_set()), + ), ) .unwrap(); let [elem, acc] = func.input_wires_arr(); @@ -1369,13 +1346,7 @@ mod test { .out_wire(0); let func_id = func.finish_with_outputs(vec![unit, acc]).unwrap(); let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); - let scan = ArrayScan::new( - int_ty.clone(), - Type::UNIT, - vec![int_ty.clone()], - size, - exec_extension_set(), - ); + let scan = ArrayScan::new(int_ty.clone(), Type::UNIT, vec![int_ty.clone()], size); let zero = builder.add_load_value(ConstInt::new_u(6, 0).unwrap()); let sum = builder .add_dataflow_op(scan, [arr, func_v, zero]) diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 5855b8fb1..241f53ec0 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -26,9 +26,6 @@ paste = { workspace = true } thiserror = { workspace = true } petgraph = { workspace = true } -[features] -extension_inference = ["hugr-core/extension_inference"] - [dev-dependencies] rstest = { workspace = true } proptest = { workspace = true } diff --git a/hugr-passes/README.md b/hugr-passes/README.md index b441ed5e7..c2bca2124 100644 --- a/hugr-passes/README.md +++ b/hugr-passes/README.md @@ -1,7 +1,6 @@ ![](/hugr/assets/hugr_logo.svg) -hugr-passes -=============== +# hugr-passes [![build_status][]](https://github.com/CQCL/hugr/actions) [![crates][]](https://crates.io/crates/hugr-passes) @@ -29,13 +28,6 @@ cargo add hugr-passes Please read the [API documentation here][]. -## Experimental Features - -- `extension_inference`: - Experimental feature which allows automatic inference of which extra extensions - are required at runtime by a HUGR when validating it. - Not enabled by default. - ## Recent Changes See [CHANGELOG][] for a list of changes. 
The minimum supported rust @@ -55,4 +47,4 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [crates]: https://img.shields.io/crates/v/hugr-passes [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE - [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-passes/CHANGELOG.md + [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-passes/CHANGELOG.md \ No newline at end of file diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs index ad8ff1ec0..faf92b8a7 100644 --- a/hugr-passes/src/composable.rs +++ b/hugr-passes/src/composable.rs @@ -132,21 +132,11 @@ pub enum ValidatePassError { /// Runs an underlying pass, but with validation of the Hugr /// both before and afterwards. -pub struct ValidatingPass

<P>(P, bool); +pub struct ValidatingPass<P>(P); impl<P: ComposablePass> ValidatingPass<P>

{ - pub fn new_default(underlying: P) -> Self { - // Self(underlying, cfg!(feature = "extension_inference")) - // Sadly, many tests fail with extension inference, hence: - Self(underlying, false) - } - - pub fn new_validating_extensions(underlying: P) -> Self { - Self(underlying, true) - } - - pub fn new(underlying: P, validate_extensions: bool) -> Self { - Self(underlying, validate_extensions) + pub fn new(underlying: P) -> Self { + Self(underlying) } fn validation_impl( @@ -154,11 +144,8 @@ impl ValidatingPass

{ hugr: &impl HugrView, mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError, ) -> Result<(), ValidatePassError> { - match self.1 { - false => hugr.validate_no_extensions(), - true => hugr.validate(), - } - .map_err(|err| mk_err(err, hugr.mermaid_string())) + hugr.validate() + .map_err(|err| mk_err(err, hugr.mermaid_string())) } } @@ -222,7 +209,7 @@ pub(crate) fn validate_if_test( hugr: &mut impl HugrMut, ) -> Result> { if cfg!(test) { - ValidatingPass::new_default(pass).run(hugr) + ValidatingPass::new(pass).run(hugr) } else { pass.run(hugr).map_err(ValidatePassError::Underlying) } @@ -237,9 +224,7 @@ mod test { Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, }; - use hugr_core::extension::prelude::{ - bool_t, usize_t, ConstUsize, MakeTuple, UnpackTuple, PRELUDE_ID, - }; + use hugr_core::extension::prelude::{bool_t, usize_t, ConstUsize, MakeTuple, UnpackTuple}; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::{handle::NodeHandle, Input, OpType, Output, DEFAULT_OPTYPE, DFG}; use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; @@ -315,7 +300,7 @@ mod test { cfold.run(&mut h).unwrap(); assert_eq!(h, backup); // Did nothing - let r = ValidatingPass(cfold, false).run(&mut h); + let r = ValidatingPass(cfold).run(&mut h); assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err)); } @@ -324,7 +309,7 @@ mod test { let tr = TypeRow::from(vec![usize_t(); 2]); let h = { - let sig = Signature::new_endo(tr.clone()).with_extension_delta(PRELUDE_ID); + let sig = Signature::new_endo(tr.clone()); let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap(); let [a, b] = fb.input_wires_arr(); let tup = fb diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 3a296fc0b..dcdc4df0a 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -3,7 +3,6 @@ use std::collections::HashSet; use hugr_core::ops::handle::NodeHandle; use hugr_core::ops::Const; -use hugr_core::std_extensions::arithmetic::{int_ops, int_types}; use itertools::Itertools; use lazy_static::lazy_static; use rstest::rstest; @@ -1595,9 +1594,7 @@ fn test_module() -> Result<(), Box> { let ad2 = mb.add_alias_def("unused2", INT_TYPES[3].clone())?; let mut main = mb.define_function( "main", - Signature::new(type_row![], vec![INT_TYPES[5].clone(); 2]) - .with_extension_delta(int_types::EXTENSION_ID) - .with_extension_delta(int_ops::EXTENSION_ID), + Signature::new(type_row![], vec![INT_TYPES[5].clone(); 2]), )?; let lc7 = main.load_const(&c7); let lc17 = main.load_const(&c17); diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 1c4b4e439..a67556ce1 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -9,10 +9,7 @@ use hugr_core::ops::{CallIndirect, TailLoop}; use hugr_core::types::{ConstTypeError, TypeRow}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, - extension::{ - prelude::{bool_t, UnpackTuple}, - ExtensionSet, - }, + extension::prelude::{bool_t, UnpackTuple}, ops::{handle::NodeHandle, DataflowOpTrait, Tag, Value}, type_row, types::{Signature, SumType, Type}, @@ -176,12 +173,7 @@ fn test_tail_loop_two_iters() { let false_w = builder.add_load_value(Value::false_val()); let tlb = builder - .tail_loop_builder_exts( - [], - [(bool_t(), false_w), (bool_t(), true_w)], - type_row![], - ExtensionSet::new(), - ) + .tail_loop_builder([], [(bool_t(), 
false_w), (bool_t(), true_w)], type_row![]) .unwrap(); assert_eq!( tlb.loop_signature().unwrap().signature().as_ref(), diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index 25f6cf798..69bcfabf6 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -180,7 +180,7 @@ mod test { use std::sync::Arc; use hugr_core::builder::{CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder}; - use hugr_core::extension::prelude::{usize_t, ConstUsize, PRELUDE_ID}; + use hugr_core::extension::prelude::{usize_t, ConstUsize}; use hugr_core::ops::{handle::NodeHandle, OpTag, OpTrait}; use hugr_core::types::Signature; use hugr_core::{ops::Value, type_row, HugrView}; @@ -192,9 +192,7 @@ mod test { #[test] fn test_cfg_callback() { - let mut cb = - CFGBuilder::new(Signature::new_endo(type_row![]).with_extension_delta(PRELUDE_ID)) - .unwrap(); + let mut cb = CFGBuilder::new(Signature::new_endo(type_row![])).unwrap(); let cst_unused = cb.add_constant(Value::from(ConstUsize::new(3))); let cst_used_in_dfg = cb.add_constant(Value::from(ConstUsize::new(5))); let cst_used = cb.add_constant(Value::unary_unit_sum()); diff --git a/hugr-passes/src/force_order.rs b/hugr-passes/src/force_order.rs index ec59ccefd..cbb637b2a 100644 --- a/hugr-passes/src/force_order.rs +++ b/hugr-passes/src/force_order.rs @@ -279,7 +279,7 @@ mod test { .iter(&hugr.as_petgraph()) .filter(|n| rank_map.contains_key(n)) .collect_vec(); - hugr.validate_no_extensions().unwrap(); + hugr.validate().unwrap(); topo_sorted } diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 403e3d84b..334127bab 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -94,7 +94,7 @@ mod test { #[fixture] fn noop_hugr() -> Hugr { - let mut b = DFGBuilder::new(Signature::new_endo(bool_t()).with_prelude()).unwrap(); + let mut b = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let out = b .add_dataflow_op(Noop::new(bool_t()), [b.input_wires().next().unwrap()]) .unwrap() diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index 170ff3789..2c739c3d7 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -62,14 +62,10 @@ fn mk_rep( let mut replacement: Hugr = Hugr::new(cfg.root_optype().clone()); let merged = replacement.add_node_with_parent(replacement.root(), { - let mut merged_block = DataflowBlock { + DataflowBlock { inputs: pred_ty.inputs.clone(), ..succ_ty.clone() - }; - merged_block.extension_delta = merged_block - .extension_delta - .union(pred_ty.extension_delta.clone()); - merged_block + } }); let input = replacement.add_node_with_parent( merged, @@ -225,7 +221,7 @@ mod test { let e = extension(); let tst_op = e.instantiate_extension_op("Test", [])?; let mut h = CFGBuilder::new(inout_sig(loop_variants.clone(), exit_types.clone()))?; - let mut no_b1 = h.simple_entry_builder_exts(loop_variants.clone(), 1, PRELUDE_ID)?; + let mut no_b1 = h.simple_entry_builder(loop_variants.clone(), 1)?; let n = no_b1.add_dataflow_op(Noop::new(qb_t()), no_b1.input_wires())?; let br = unary_unit_sum(&mut no_b1); let no_b1 = no_b1.finish_with_outputs(br, n.outputs())?; diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 3ac85a020..cfe2c9514 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -2,7 +2,6 @@ use std::{ collections::{hash_map::Entry, HashMap}, convert::Infallible, fmt::Write, - ops::Deref, }; use hugr_core::{ @@ -300,10 +299,6 @@ fn write_type_arg_str(arg: &TypeArg, 
f: &mut std::fmt::Formatter<'_>) -> std::fm TypeArg::BoundedNat { n } => f.write_fmt(format_args!("n({n})")), TypeArg::String { arg } => f.write_fmt(format_args!("s({})", escape_dollar(arg))), TypeArg::Sequence { elems } => f.write_fmt(format_args!("seq({})", TypeArgsList(elems))), - TypeArg::Extensions { es } => f.write_fmt(format_args!( - "es({})", - es.iter().map(|x| x.deref()).join(",") - )), // We are monomorphizing. We will never monomorphize to a signature // containing a variable. TypeArg::Variable { .. } => panic!("type_arg_str variable: {arg}"), @@ -338,6 +333,7 @@ mod test { use std::iter; use hugr_core::extension::simple_op::MakeRegisteredOp as _; + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections; use hugr_core::std_extensions::collections::array::{array_type_parametric, ArrayOpDef}; use hugr_core::types::type_param::TypeParam; @@ -347,16 +343,10 @@ mod test { Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, }; - use hugr_core::extension::prelude::{ - usize_t, ConstUsize, UnpackTuple, UnwrapBuilder, PRELUDE_ID, - }; - use hugr_core::extension::ExtensionSet; + use hugr_core::extension::prelude::{usize_t, ConstUsize, UnpackTuple, UnwrapBuilder}; use hugr_core::ops::handle::{FuncID, NodeHandle}; use hugr_core::ops::{CallIndirect, DataflowOpTrait as _, FuncDefn, Tag}; - use hugr_core::std_extensions::arithmetic::int_types::{self, INT_TYPES}; - use hugr_core::types::{ - PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeEnum, TypeRow, - }; + use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeEnum}; use hugr_core::{Hugr, HugrView, Node}; use rstest::rstest; @@ -372,10 +362,6 @@ mod test { Type::new_tuple(vec![ty.clone(), ty.clone(), ty]) } - fn prelusig(ins: impl Into, outs: impl Into) -> Signature { - Signature::new(ins, outs).with_extension_delta(PRELUDE_ID) - } - #[test] fn test_null() { let dfg_builder = @@ -411,7 +397,7 @@ mod test { }; let tr = { - let sig = prelusig(tv0(), Type::new_tuple(vec![tv0(); 3])); + let sig = Signature::new(tv0(), Type::new_tuple(vec![tv0(); 3])); let mut fb = mb.define_function( "triple", PolyFuncType::new([TypeBound::Copyable.into()], sig), @@ -428,7 +414,7 @@ mod test { }; let mn = { let outs = vec![triple_type(usize_t()), triple_type(pair_type(usize_t()))]; - let mut fb = mb.define_function("main", prelusig(usize_t(), outs))?; + let mut fb = mb.define_function("main", Signature::new(usize_t(), outs))?; let [elem] = fb.input_wires_arr(); let [res1] = fb .call(tr.handle(), &[usize_t().into()], [elem])? 
@@ -493,37 +479,30 @@ mod test { let n: u64 = 5; let mut outer = FunctionBuilder::new( "mainish", - prelusig( + Signature::new( array_type_parametric(sa(n), array_type_parametric(sa(2), usize_t()).unwrap()) .unwrap(), vec![usize_t(); 2], - ) - .with_extension_delta(collections::array::EXTENSION_ID), + ), ) .unwrap(); let arr2u = || array_type_parametric(sa(2), usize_t()).unwrap(); let pf1t = PolyFuncType::new( [TypeParam::max_nat()], - prelusig(array_type_parametric(sv(0), arr2u()).unwrap(), usize_t()) - .with_extension_delta(collections::array::EXTENSION_ID), + Signature::new(array_type_parametric(sv(0), arr2u()).unwrap(), usize_t()), ); let mut pf1 = outer.define_function("pf1", pf1t).unwrap(); let pf2t = PolyFuncType::new( [TypeParam::max_nat(), TypeBound::Copyable.into()], - prelusig(vec![array_type_parametric(sv(0), tv(1)).unwrap()], tv(1)) - .with_extension_delta(collections::array::EXTENSION_ID), + Signature::new(vec![array_type_parametric(sv(0), tv(1)).unwrap()], tv(1)), ); let mut pf2 = pf1.define_function("pf2", pf2t).unwrap(); let mono_func = { let mut fb = pf2 - .define_function( - "get_usz", - prelusig(vec![], usize_t()) - .with_extension_delta(collections::array::EXTENSION_ID), - ) + .define_function("get_usz", Signature::new(vec![], usize_t())) .unwrap(); let cst0 = fb.add_load_value(ConstUsize::new(1)); fb.finish_with_outputs([cst0]).unwrap() @@ -706,8 +685,6 @@ mod test { #[case::string(vec!["arg".into()], "$foo$$s(arg)")] #[case::dollar_string(vec!["$arg".into()], "$foo$$s(\\$arg)")] #[case::sequence(vec![vec![0.into(), Type::UNIT.into()].into()], "$foo$$seq($n(0)$t(Unit))")] - #[case::extensionset(vec![ExtensionSet::from_iter([PRELUDE_ID,int_types::EXTENSION_ID]).into()], - "$foo$$es(arithmetic.int.types,prelude)")] // alphabetic ordering of extension names #[should_panic] #[case::typeargvariable(vec![TypeArg::new_var_use(1, TypeParam::String)], "$foo$$v(1)")] diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index 6e9df7f1a..3c15ca6f2 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -577,7 +577,7 @@ pub(crate) mod test { use hugr_core::builder::{ endo_sig, BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder, }; - use hugr_core::extension::{prelude::usize_t, ExtensionSet}; + use hugr_core::extension::prelude::usize_t; use hugr_core::hugr::patch::insert_identity::{IdentityInsertion, IdentityInsertionError}; use hugr_core::hugr::views::RootChecked; @@ -612,11 +612,7 @@ pub(crate) mod test { let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); let entry = n_identity( - cfg_builder.simple_entry_builder_exts( - vec![usize_t()].into(), - 1, - ExtensionSet::new(), - )?, + cfg_builder.simple_entry_builder(vec![usize_t()].into(), 1)?, &const_unit, )?; let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?; diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index 180e9d6fc..a2219d14f 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -54,8 +54,7 @@ mod test { #[test] fn ensures_no_nonlocal_edges() { let hugr = { - let mut builder = - DFGBuilder::new(Signature::new_endo(bool_t()).with_prelude()).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let [in_w] = builder.input_wires_arr(); let [out_w] = builder .add_dataflow_op(Noop::new(bool_t()), [in_w]) @@ -69,12 +68,11 @@ mod test { #[test] fn find_nonlocal_edges() { let (hugr, edge) = { - let mut builder = - 
DFGBuilder::new(Signature::new_endo(bool_t()).with_prelude()).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let [in_w] = builder.input_wires_arr(); let ([out_w], edge) = { let mut dfg_builder = builder - .dfg_builder(Signature::new(type_row![], bool_t()).with_prelude(), []) + .dfg_builder(Signature::new(type_row![], bool_t()), []) .unwrap(); let noop = dfg_builder .add_dataflow_op(Noop::new(bool_t()), [in_w]) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 25249f5ae..05d0168c8 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -595,17 +595,17 @@ mod test { FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; use hugr_core::extension::prelude::{ - bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, PRELUDE_ID, + bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, }; - use hugr_core::extension::{simple_op::MakeExtensionOp, ExtensionSet, TypeDefBound, Version}; + use hugr_core::extension::{simple_op::MakeExtensionOp, TypeDefBound, Version}; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::hugr::{IdentList, ValidationError}; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; - use hugr_core::std_extensions::arithmetic::conversions::{self, ConvertOpDef}; + use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use hugr_core::std_extensions::collections::{ - array::{self, array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue}, + array::{array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue}, list::{list_type, list_type_def, ListOp, ListValue}, }; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; @@ -682,12 +682,7 @@ mod test { let mut dfb = new(Signature::new( vec![array_type(64, elem_ty.clone()), i64_t()], elem_ty.clone(), - ) - .with_extension_delta(ExtensionSet::from_iter([ - PRELUDE_ID, - array::EXTENSION_ID, - conversions::EXTENSION_ID, - ]))) + )) .unwrap(); let [val, idx] = dfb.input_wires_arr(); let [idx] = dfb @@ -745,8 +740,7 @@ mod test { let inps = fb.input_wires(); let id = fb.finish_with_outputs(inps).unwrap(); - let sig = Signature::new(vec![i64_t(), c_int.clone(), c_bool.clone()], bool_t()) - .with_extension_delta(ext.name.clone()); + let sig = Signature::new(vec![i64_t(), c_int.clone(), c_bool.clone()], bool_t()); let mut fb = mb.define_function("main", sig).unwrap(); let [idx, indices, bools] = fb.input_wires_arr(); let [indices] = fb @@ -1062,7 +1056,7 @@ mod test { repl.replace_consts_parametrized(array_type_def(), array_const); let mut h = backup; repl.run(&mut h).unwrap(); - h.validate_no_extensions().unwrap(); + h.validate().unwrap(); } #[test] diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index b6e6e6780..573188340 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -3,7 +3,6 @@ use hugr_core::builder::{endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr}; use hugr_core::extension::prelude::{option_type, UnwrapBuilder}; -use hugr_core::extension::ExtensionSet; use hugr_core::ops::{constant::OpaqueValue, Value}; use hugr_core::ops::{OpTrait, OpType, Tag}; use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; @@ -13,8 +12,8 @@ use 
hugr_core::std_extensions::collections::array::{ array_type, ArrayOpDef, ArrayRepeat, ArrayScan, ArrayValue, }; use hugr_core::std_extensions::collections::list::ListValue; +use hugr_core::type_row; use hugr_core::types::{SumType, Transformable, Type, TypeArg}; -use hugr_core::{type_row, Hugr, HugrView}; use itertools::Itertools; use super::{ @@ -67,10 +66,6 @@ pub fn array_const( Ok(Some(ArrayValue::new(elem_t, vals).into())) } -fn runtime_reqs(h: &Hugr) -> ExtensionSet { - h.signature(h.root()).unwrap().runtime_reqs.clone() -} - /// Handler for copying/discarding arrays if their elements have become linear. /// Included in [ReplaceTypes::default] and [DelegatingLinearizer::default]. /// @@ -97,7 +92,7 @@ pub fn linearize_array( dfb.finish_hugr_with_outputs([ret]).unwrap() }; // Now array.scan that over the input array to get an array of unit (which can be discarded) - let array_scan = ArrayScan::new(ty.clone(), Type::UNIT, vec![], *n, runtime_reqs(&map_fn)); + let array_scan = ArrayScan::new(ty.clone(), Type::UNIT, vec![], *n); let in_type = array_type(*n, ty.clone()); return Ok(NodeTemplate::CompoundOp(Box::new({ let mut dfb = DFGBuilder::new(inout_sig(in_type, type_row![])).unwrap(); @@ -131,8 +126,7 @@ pub fn linearize_array( .unwrap(); dfb.finish_hugr_with_outputs(none.outputs()).unwrap() }; - let repeats = - vec![ArrayRepeat::new(option_ty.clone(), *n, runtime_reqs(&fn_none)); num_new]; + let repeats = vec![ArrayRepeat::new(option_ty.clone(), *n); num_new]; let fn_none = dfb.add_load_value(Value::function(fn_none).unwrap()); repeats .into_iter() @@ -212,7 +206,6 @@ pub fn linearize_array( .chain(vec![option_array; num_new]) .collect(), *n, - runtime_reqs(©_elem), ); let copy_elem = dfb.add_load_value(Value::function(copy_elem).unwrap()); @@ -240,13 +233,7 @@ pub fn linearize_array( dfb.finish_hugr_with_outputs([val]).unwrap() }; - let unwrap_scan = ArrayScan::new( - option_ty.clone(), - ty.clone(), - vec![], - *n, - runtime_reqs(&unwrap_elem), - ); + let unwrap_scan = ArrayScan::new(option_ty.clone(), ty.clone(), vec![], *n); let unwrap_elem = dfb.add_load_value(Value::function(unwrap_elem).unwrap()); let out_arrays = std::iter::once(out_array1) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 2788a2379..81324dbee 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -622,10 +622,7 @@ mod test { NodeTemplate::SingleOp(copy3.clone()), NodeTemplate::SingleOp(discard.clone().into()), ); - let sig3 = Some( - Signature::new(lin_t.clone(), vec![lin_t.clone(); 3]) - .with_extension_delta(ext.name().clone()), - ); + let sig3 = Some(Signature::new(lin_t.clone(), vec![lin_t.clone(); 3])); assert_eq!( bad_copy, Err(LinearizeError::WrongSignature { @@ -782,11 +779,7 @@ mod test { let mut dfb = DFGBuilder::new(inout_sig(usize_t(), type_row![])).unwrap(); let discard_fn = { let mut fb = dfb - .define_function( - "drop", - Signature::new(lin_t.clone(), type_row![]) - .with_extension_delta(e.name().clone()), - ) + .define_function("drop", Signature::new(lin_t.clone(), type_row![])) .unwrap(); let ins = fb.input_wires(); fb.add_dataflow_op( diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index b2782e8d9..1c9be1c75 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -278,9 +278,7 @@ mod test { /// These can be removed entirely. 
#[fixture] fn unused_pack() -> Hugr { - let mut h = - DFGBuilder::new(Signature::new(vec![bool_t(), bool_t()], vec![]).with_prelude()) - .unwrap(); + let mut h = DFGBuilder::new(Signature::new(vec![bool_t(), bool_t()], vec![])).unwrap(); let mut inps = h.input_wires(); let b1 = inps.next().unwrap(); let b2 = inps.next().unwrap(); @@ -295,8 +293,7 @@ mod test { /// These can be removed entirely. #[fixture] fn simple_pack_unpack() -> Hugr { - let mut h = - DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()]).with_prelude()).unwrap(); + let mut h = DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()])).unwrap(); let mut inps = h.input_wires(); let qb1 = inps.next().unwrap(); let b2 = inps.next().unwrap(); @@ -315,8 +312,7 @@ mod test { /// we just remove everything. #[fixture] fn ordered_pack_unpack() -> Hugr { - let mut h = - DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()]).with_prelude()).unwrap(); + let mut h = DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()])).unwrap(); let mut inps = h.input_wires(); let qb1 = inps.next().unwrap(); let b2 = inps.next().unwrap(); @@ -338,13 +334,10 @@ mod test { /// These can be removed entirely. #[fixture] fn multi_unpack() -> Hugr { - let mut h = DFGBuilder::new( - Signature::new( - vec![bool_t(), bool_t()], - vec![bool_t(), bool_t(), bool_t(), bool_t()], - ) - .with_prelude(), - ) + let mut h = DFGBuilder::new(Signature::new( + vec![bool_t(), bool_t()], + vec![bool_t(), bool_t(), bool_t(), bool_t()], + )) .unwrap(); let mut inps = h.input_wires(); let b1 = inps.next().unwrap(); @@ -369,17 +362,14 @@ mod test { /// The unpack operation can be removed, but the pack operation cannot. #[fixture] fn partial_unpack() -> Hugr { - let mut h = DFGBuilder::new( - Signature::new( - vec![bool_t(), bool_t()], - vec![ - bool_t(), - bool_t(), - Type::new_tuple(vec![bool_t(), bool_t()]), - ], - ) - .with_prelude(), - ) + let mut h = DFGBuilder::new(Signature::new( + vec![bool_t(), bool_t()], + vec![ + bool_t(), + bool_t(), + Type::new_tuple(vec![bool_t(), bool_t()]), + ], + )) .unwrap(); let mut inps = h.input_wires(); let b1 = inps.next().unwrap(); diff --git a/hugr-passes/src/validation.rs b/hugr-passes/src/validation.rs index 6c3e61fb4..90d338faf 100644 --- a/hugr-passes/src/validation.rs +++ b/hugr-passes/src/validation.rs @@ -16,11 +16,8 @@ use hugr_core::HugrView; pub enum ValidationLevel { /// Do no verification. None, - /// Validate using [HugrView::validate_no_extensions]. This is useful when you - /// do not expect valid Extension annotations on Nodes. - WithoutExtensions, /// Validate using [HugrView::validate]. 
- WithExtensions, + Validate, } #[derive(Error, Debug, PartialEq)] @@ -44,8 +41,7 @@ pub enum ValidatePassError { impl Default for ValidationLevel { fn default() -> Self { if cfg!(test) { - // Many tests fail when run with Self::WithExtensions - Self::WithoutExtensions + Self::Validate } else { Self::None } @@ -86,8 +82,7 @@ impl ValidationLevel { { match self { ValidationLevel::None => Ok(()), - ValidationLevel::WithoutExtensions => hugr.validate_no_extensions(), - ValidationLevel::WithExtensions => hugr.validate(), + ValidationLevel::Validate => hugr.validate(), } .map_err(|err| mk_err(err, hugr.mermaid_string()).into()) } diff --git a/hugr-py/src/hugr/_serialization/extension.py b/hugr-py/src/hugr/_serialization/extension.py index 95e59754e..3bb377ed5 100644 --- a/hugr-py/src/hugr/_serialization/extension.py +++ b/hugr-py/src/hugr/_serialization/extension.py @@ -86,9 +86,7 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): def deserialize(self, extension: ext.Extension) -> ext.OpDef: signature = ext.OpDefSig( - self.signature.deserialize().with_runtime_reqs([extension.name]) - if self.signature - else None, + self.signature.deserialize() if self.signature else None, self.binary, ) @@ -106,7 +104,6 @@ def deserialize(self, extension: ext.Extension) -> ext.OpDef: class Extension(ConfiguredBaseModel): version: SemanticVersion name: ExtensionId - runtime_reqs: set[ExtensionId] types: dict[str, TypeDef] operations: dict[str, OpDef] @@ -118,7 +115,6 @@ def deserialize(self) -> ext.Extension: e = ext.Extension( version=self.version, # type: ignore[arg-type] name=self.name, - runtime_reqs=self.runtime_reqs, ) for k, t in self.types.items(): diff --git a/hugr-py/src/hugr/_serialization/ops.py b/hugr-py/src/hugr/_serialization/ops.py index 48b4e6b87..28a1daf5e 100644 --- a/hugr-py/src/hugr/_serialization/ops.py +++ b/hugr-py/src/hugr/_serialization/ops.py @@ -206,7 +206,6 @@ class DataflowBlock(BaseOp): inputs: TypeRow = Field(default_factory=list) other_outputs: TypeRow = Field(default_factory=list) sum_rows: list[TypeRow] - extension_delta: ExtensionSet = Field(default_factory=list) def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: num_cases = len(out_types) @@ -384,13 +383,11 @@ class DFG(DataflowOp): signature: FunctionType = Field(default_factory=FunctionType.empty) def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: - self.signature = FunctionType( - input=list(inputs), output=list(outputs), runtime_reqs=ExtensionSet([]) - ) + self.signature = FunctionType(input=list(inputs), output=list(outputs)) def deserialize(self) -> ops.DFG: sig = self.signature.deserialize() - return ops.DFG(sig.input, sig.output, sig.runtime_reqs) + return ops.DFG(sig.input, sig.output) # ------------------------------------------------ @@ -407,8 +404,6 @@ class Conditional(DataflowOp): sum_rows: list[TypeRow] = Field( description="The possible rows of the Sum input", default_factory=list ) - # Extensions used to produce the outputs - extension_delta: ExtensionSet = Field(default_factory=list) def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: # First port is a predicate, i.e. a sum of tuple types. 
We need to unpack @@ -442,9 +437,7 @@ class Case(BaseOp): signature: FunctionType = Field(default_factory=FunctionType.empty) def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: - self.signature = stys.FunctionType( - input=list(inputs), output=list(outputs), runtime_reqs=ExtensionSet([]) - ) + self.signature = stys.FunctionType(input=list(inputs), output=list(outputs)) def deserialize(self) -> ops.Case: sig = self.signature.deserialize() @@ -455,11 +448,12 @@ class TailLoop(DataflowOp): """Tail-controlled loop.""" op: Literal["TailLoop"] = "TailLoop" - just_inputs: TypeRow = Field(default_factory=list) # Types that are only input - just_outputs: TypeRow = Field(default_factory=list) # Types that are only output + # Types that are only input + just_inputs: TypeRow = Field(default_factory=list) + # Types that are only output + just_outputs: TypeRow = Field(default_factory=list) # Types that are appended to both input and output: rest: TypeRow = Field(default_factory=list) - extension_delta: ExtensionSet = Field(default_factory=list) def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert in_types == out_types @@ -472,7 +466,6 @@ def deserialize(self) -> ops.TailLoop: just_inputs=deser_it(self.just_inputs), _just_outputs=deser_it(self.just_outputs), rest=deser_it(self.rest), - extension_delta=self.extension_delta, ) @@ -484,7 +477,8 @@ class CFG(DataflowOp): def insert_port_types(self, inputs: TypeRow, outputs: TypeRow) -> None: self.signature = FunctionType( - input=list(inputs), output=list(outputs), runtime_reqs=ExtensionSet([]) + input=list(inputs), + output=list(outputs), ) def deserialize(self) -> ops.CFG: diff --git a/hugr-py/src/hugr/_serialization/tys.py b/hugr-py/src/hugr/_serialization/tys.py index 4a0a0e75b..c00a73375 100644 --- a/hugr-py/src/hugr/_serialization/tys.py +++ b/hugr-py/src/hugr/_serialization/tys.py @@ -110,23 +110,11 @@ def deserialize(self) -> tys.TupleParam: return tys.TupleParam(params=deser_it(self.params)) -class ExtensionsParam(BaseTypeParam): - tp: Literal["Extensions"] = "Extensions" - - def deserialize(self) -> tys.ExtensionsParam: - return tys.ExtensionsParam() - - class TypeParam(RootModel): """A type parameter.""" root: Annotated[ - TypeTypeParam - | BoundedNatParam - | StringParam - | ListParam - | TupleParam - | ExtensionsParam, + TypeTypeParam | BoundedNatParam | StringParam | ListParam | TupleParam, WrapValidator(_json_custom_error_validator), ] = Field(discriminator="tp") @@ -178,14 +166,6 @@ def deserialize(self) -> tys.SequenceArg: return tys.SequenceArg(elems=deser_it(self.elems)) -class ExtensionsArg(BaseTypeArg): - tya: Literal["Extensions"] = "Extensions" - es: ExtensionSet - - def deserialize(self) -> tys.ExtensionsArg: - return tys.ExtensionsArg(extensions=self.es) - - class VariableArg(BaseTypeArg): tya: Literal["Variable"] = "Variable" idx: int @@ -199,12 +179,7 @@ class TypeArg(RootModel): """A type argument.""" root: Annotated[ - TypeTypeArg - | BoundedNatArg - | StringArg - | SequenceArg - | ExtensionsArg - | VariableArg, + TypeTypeArg | BoundedNatArg | StringArg | SequenceArg | VariableArg, WrapValidator(_json_custom_error_validator), ] = Field(discriminator="tya") @@ -307,18 +282,15 @@ class FunctionType(BaseType): input: TypeRow # Value inputs of the function. output: TypeRow # Value outputs of the function. 
- # The extension requirements which are added by the operation - runtime_reqs: ExtensionSet = Field(default_factory=ExtensionSet) @classmethod def empty(cls) -> FunctionType: - return FunctionType(input=[], output=[], runtime_reqs=[]) + return FunctionType(input=[], output=[]) def deserialize(self) -> tys.FunctionType: return tys.FunctionType( input=deser_it(self.input), output=deser_it(self.output), - runtime_reqs=self.runtime_reqs, ) model_config = ConfigDict( diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 7bd02f982..fd59da0fc 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -235,13 +235,6 @@ def instantiate( concrete_signature: Concrete function type of the operation, only required if the operation is polymorphic. """ - # Add the extension where the operation is defined as a runtime requirement. - # We don't store this in the json definition as it is redundant information. - if concrete_signature is not None: - concrete_signature = concrete_signature.with_runtime_reqs( - [self.get_extension().name] - ) - return ops.ExtOp(self, concrete_signature, list(args or [])) @@ -256,8 +249,6 @@ class Extension: name: ExtensionId #: The version of the extension. version: Version - #: Extensions required by this extension at runtime, identified by name. - runtime_reqs: set[ExtensionId] = field(default_factory=set) #: Type definitions in the extension. types: dict[str, TypeDef] = field(default_factory=dict) #: Operation definitions in the extension. @@ -273,7 +264,6 @@ def _to_serial(self) -> ext_s.Extension: return ext_s.Extension( name=self.name, version=self.version, # type: ignore[arg-type] - runtime_reqs=self.runtime_reqs, types={k: v._to_serial() for k, v in self.types.items()}, operations={k: v._to_serial() for k, v in self.operations.items()}, ) @@ -303,12 +293,6 @@ def add_op_def(self, op_def: OpDef) -> OpDef: Returns: The added operation definition, now associated with the extension. """ - if op_def.signature.poly_func is not None: - # Ensure the op def signature has the extension as a requirement - op_def.signature.poly_func = op_def.signature.poly_func.with_runtime_reqs( - [self.name] - ) - op_def._extension = self self.operations[op_def.name] = op_def return self.operations[op_def.name] diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index 1555dab4d..b6030b6a0 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -456,7 +456,6 @@ def cached_signature(self) -> tys.FunctionType | None: return tys.FunctionType( input=self.types, output=[tys.Tuple(*self.types)], - runtime_reqs=["prelude"], ) def type_args(self) -> list[tys.TypeArg]: @@ -499,7 +498,6 @@ def cached_signature(self) -> tys.FunctionType | None: return tys.FunctionType( input=[tys.Tuple(*self.types)], output=self.types, - runtime_reqs=["prelude"], ) def type_args(self) -> list[tys.TypeArg]: @@ -632,7 +630,6 @@ class DFG(DfParentOp, DataflowOp): #: Inputs types of the operation. inputs: tys.TypeRow _outputs: tys.TypeRow | None = field(default=None, repr=False) - _extension_delta: tys.ExtensionSet = field(default_factory=list, repr=False) @property def outputs(self) -> tys.TypeRow: @@ -650,7 +647,7 @@ def signature(self) -> tys.FunctionType: Raises: IncompleteOp: If the outputs have not been set. 
""" - return tys.FunctionType(self.inputs, self.outputs, self._extension_delta) + return tys.FunctionType(self.inputs, self.outputs) @property def num_out(self) -> int: @@ -729,7 +726,6 @@ class DataflowBlock(DfParentOp): inputs: tys.TypeRow _sum: tys.Sum | None = None _other_outputs: tys.TypeRow | None = field(default=None, repr=False) - extension_delta: tys.ExtensionSet = field(default_factory=list) @property def sum_ty(self) -> tys.Sum: @@ -762,7 +758,6 @@ def _to_serial(self, parent: Node) -> sops.DataflowBlock: inputs=ser_it(self.inputs), sum_rows=list(map(ser_it, self.sum_ty.variant_rows)), other_outputs=ser_it(self.other_outputs), - extension_delta=self.extension_delta, ) def inner_signature(self) -> tys.FunctionType: @@ -993,7 +988,6 @@ class TailLoop(DfParentOp, DataflowOp): #: Types that are appended to both inputs and outputs of the graph. rest: tys.TypeRow _just_outputs: tys.TypeRow | None = field(default=None, repr=False) - extension_delta: tys.ExtensionSet = field(default_factory=list, repr=False) @property def just_outputs(self) -> tys.TypeRow: @@ -1014,7 +1008,6 @@ def _to_serial(self, parent: Node) -> sops.TailLoop: just_inputs=ser_it(self.just_inputs), just_outputs=ser_it(self.just_outputs), rest=ser_it(self.rest), - extension_delta=self.extension_delta, ) def inner_signature(self) -> tys.FunctionType: @@ -1334,13 +1327,11 @@ def type_args(self) -> list[tys.TypeArg]: def cached_signature(self) -> tys.FunctionType | None: return tys.FunctionType.endo( [self.type_], - runtime_reqs=["prelude"], ) def outer_signature(self) -> tys.FunctionType: return tys.FunctionType.endo( [self.type_], - runtime_reqs=["prelude"], ) def _set_in_types(self, types: tys.TypeRow) -> None: diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json index 1d310df25..3c2fb983c 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json @@ -1,10 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.conversions", - "runtime_reqs": [ - "arithmetic.float.types", - "arithmetic.int.types" - ], "types": {}, "operations": { "bytecast_float64_to_int64": { @@ -36,8 +32,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -71,8 +66,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -115,8 +109,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -159,8 +152,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -192,8 +184,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -223,8 +214,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -256,8 +246,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -300,8 +289,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -344,8 +332,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -375,8 +362,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -436,8 +422,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -497,8 +482,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json index 8da056772..60180ec84 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json +++ 
b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json @@ -1,9 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.float", - "runtime_reqs": [ - "arithmetic.int.types" - ], "types": {}, "operations": { "fabs": { @@ -30,8 +27,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -67,8 +63,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -97,8 +92,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -134,8 +128,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -169,8 +162,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -199,8 +191,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -234,8 +225,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -269,8 +259,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -304,8 +293,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -339,8 +327,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -376,8 +363,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -413,8 +399,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -450,8 +435,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -485,8 +469,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -515,8 +498,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -552,8 +534,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -582,8 +563,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -619,8 +599,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -649,8 +628,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json index 0c563c474..33db43f5b 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.float.types", - "runtime_reqs": [], "types": { "float64": { "extension": "arithmetic.float.types", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json index 5b1a81250..e8e6fdca8 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json @@ -1,9 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.int", - "runtime_reqs": [ - "arithmetic.int.types" - ], "types": {}, "operations": { "iabs": { @@ -53,8 +50,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -122,8 +118,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -191,8 +186,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -277,8 +271,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -363,8 +356,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -432,8 +424,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -501,8 +492,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -611,8 +601,7 @@ ] 
] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -721,8 +710,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -806,8 +794,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -891,8 +878,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -949,8 +935,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1007,8 +992,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1065,8 +1049,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1123,8 +1106,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1181,8 +1163,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1239,8 +1220,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1297,8 +1277,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1355,8 +1334,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1413,8 +1391,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1482,8 +1459,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1551,8 +1527,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1620,8 +1595,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1689,8 +1663,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1775,8 +1748,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1861,8 +1833,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1930,8 +1901,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1999,8 +1969,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2068,8 +2037,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2142,8 +2110,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -2216,8 +2183,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -2274,8 +2240,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2327,8 +2292,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2380,8 +2344,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2449,8 +2412,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2518,8 +2480,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2587,8 +2548,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2656,8 +2616,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2709,8 +2668,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2778,8 +2736,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2847,8 +2804,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2916,8 +2872,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2969,8 +2924,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -3026,8 +2980,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -3083,8 +3036,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -3152,8 +3104,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git 
a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json index 36df125a6..0b77d2e55 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.int.types", - "runtime_reqs": [], "types": { "int": { "extension": "arithmetic.int.types", diff --git a/hugr-py/src/hugr/std/_json_defs/collections/array.json b/hugr-py/src/hugr/std/_json_defs/collections/array.json index 375e13c72..fba222793 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/array.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.array", - "runtime_reqs": [], "types": { "array": { "extension": "collections.array", @@ -60,8 +59,7 @@ "bound": "A" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false @@ -126,8 +124,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -166,9 +163,6 @@ { "tp": "Type", "b": "A" - }, - { - "tp": "Extensions" } ], "body": { @@ -182,9 +176,6 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [ - "2" ] } ], @@ -213,8 +204,7 @@ ], "bound": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -243,9 +233,6 @@ "tp": "Type", "b": "A" } - }, - { - "tp": "Extensions" } ], "body": { @@ -299,9 +286,6 @@ "i": 3, "b": "A" } - ], - "runtime_reqs": [ - "4" ] }, { @@ -340,8 +324,7 @@ "i": 3, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -465,8 +448,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -578,8 +560,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/collections/list.json b/hugr-py/src/hugr/std/_json_defs/collections/list.json index 8a60d3544..de9736e4e 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/list.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/list.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.list", - "runtime_reqs": [], "types": { "List": { "extension": "collections.list", @@ -70,8 +69,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -151,8 +149,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -207,8 +204,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -274,8 +270,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -332,8 +327,7 @@ ], "bound": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -413,8 +407,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json index 53b8e61c7..cde35e063 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.static_array", - "runtime_reqs": [], "types": { "static_array": { "extension": "collections.static_array", @@ -68,8 +67,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -108,8 +106,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/logic.json b/hugr-py/src/hugr/std/_json_defs/logic.json index ff29d2c21..45cd7f606 100644 --- a/hugr-py/src/hugr/std/_json_defs/logic.json +++ b/hugr-py/src/hugr/std/_json_defs/logic.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "logic", - "runtime_reqs": [], 
"types": {}, "operations": { "And": { @@ -29,8 +28,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -60,8 +58,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -86,8 +83,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -117,8 +113,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -148,8 +143,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index ec392b155..7cf1d02c7 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -1,7 +1,6 @@ { "version": "0.2.0", "name": "prelude", - "runtime_reqs": [], "types": { "error": { "extension": "prelude", @@ -73,8 +72,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -115,8 +113,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -146,8 +143,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -188,8 +184,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -236,8 +231,7 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -259,8 +253,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -307,8 +300,7 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -329,8 +321,7 @@ "bound": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/ptr.json b/hugr-py/src/hugr/std/_json_defs/ptr.json index 614b6aecf..d701fff53 100644 --- a/hugr-py/src/hugr/std/_json_defs/ptr.json +++ b/hugr-py/src/hugr/std/_json_defs/ptr.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "ptr", - "runtime_reqs": [], "types": { "ptr": { "extension": "ptr", @@ -56,8 +55,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -98,8 +96,7 @@ "i": 0, "b": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -139,8 +136,7 @@ "b": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 4c1d0cdeb..27432a3d5 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -93,7 +93,7 @@ def type_args(self) -> list[tys.TypeArg]: def cached_signature(self) -> tys.FunctionType | None: row: list[tys.Type] = [int_t(self.width)] * 2 - return tys.FunctionType.endo(row, runtime_reqs=[INT_OPS_EXTENSION.name]) + return tys.FunctionType.endo(row) @classmethod def from_ext(cls, custom: ExtOp) -> Self | None: diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index fbaadf7d3..8411f19bf 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -188,21 +188,6 @@ def to_model(self) -> model.Term: return model.Apply("core.tuple", [item_types]) -@dataclass(frozen=True) -class ExtensionsParam(TypeParam): - """An extension set parameter.""" - - def _to_serial(self) -> stys.ExtensionsParam: - return stys.ExtensionsParam() - - def __str__(self) -> str: - return "Extensions" - - def to_model(self) -> model.Term: - # Since extension sets will be deprecated, this is just a placeholder. 
- return model.Apply("compat.ext_set_type") - - # ------------------------------------------ # --------------- TypeArg ------------------ # ------------------------------------------ @@ -280,23 +265,6 @@ def to_model(self) -> model.Term: return model.List([elem.to_model() for elem in self.elems]) -@dataclass(frozen=True) -class ExtensionsArg(TypeArg): - """Type argument for an :class:`ExtensionsParam`.""" - - extensions: ExtensionSet - - def _to_serial(self) -> stys.ExtensionsArg: - return stys.ExtensionsArg(es=self.extensions) - - def __str__(self) -> str: - return f"Extensions({comma_sep_str(self.extensions)})" - - def to_model(self) -> model.Term: - # Since extension sets will be deprecated, this is just a placeholder. - return model.Apply("compat.ext_set") - - @dataclass(frozen=True) class VariableArg(TypeArg): """A type argument variable.""" @@ -518,7 +486,6 @@ class FunctionType(Type): input: TypeRow output: TypeRow - runtime_reqs: ExtensionSet = field(default_factory=ExtensionSet) def type_bound(self) -> TypeBound: return TypeBound.Copyable @@ -527,7 +494,6 @@ def _to_serial(self) -> stys.FunctionType: return stys.FunctionType( input=ser_it(self.input), output=ser_it(self.output), - runtime_reqs=self.runtime_reqs, ) @classmethod @@ -541,16 +507,14 @@ def empty(cls) -> FunctionType: return cls(input=[], output=[]) @classmethod - def endo( - cls, tys: TypeRow, runtime_reqs: ExtensionSet | None = None - ) -> FunctionType: + def endo(cls, tys: TypeRow) -> FunctionType: """Function type with the same input and output types. Example: >>> FunctionType.endo([Qubit]) FunctionType([Qubit], [Qubit]) """ - return cls(input=tys, output=tys, runtime_reqs=runtime_reqs or ExtensionSet()) + return cls(input=tys, output=tys) def flip(self) -> FunctionType: """Return a new function type with input and output types swapped. @@ -569,17 +533,8 @@ def resolve(self, registry: ext.ExtensionRegistry) -> FunctionType: return FunctionType( input=[ty.resolve(registry) for ty in self.input], output=[ty.resolve(registry) for ty in self.output], - runtime_reqs=self.runtime_reqs, ) - def with_runtime_reqs(self, runtime_reqs: ExtensionSet) -> FunctionType: - """Adds a list of extension requirements to the function type, and - returns the new signature. - """ - exts = set(self.runtime_reqs) - exts = exts.union(runtime_reqs) - return FunctionType(self.input, self.output, [*exts]) - def __str__(self) -> str: return f"{comma_sep_str(self.input)} -> {comma_sep_str(self.output)}" @@ -614,15 +569,6 @@ def resolve(self, registry: ext.ExtensionRegistry) -> PolyFuncType: body=self.body.resolve(registry), ) - def with_runtime_reqs(self, runtime_reqs: ExtensionSet) -> PolyFuncType: - """Adds a list of extension requirements to the function type, and - returns the new signature. - """ - return PolyFuncType( - params=self.params, - body=self.body.with_runtime_reqs(runtime_reqs), - ) - def __str__(self) -> str: return f"∀ {comma_sep_str(self.params)}. 
{self.body!s}" diff --git a/hugr-py/tests/serialization/test_extension.py b/hugr-py/tests/serialization/test_extension.py index cf595319a..7f1ea28bf 100644 --- a/hugr-py/tests/serialization/test_extension.py +++ b/hugr-py/tests/serialization/test_extension.py @@ -25,7 +25,6 @@ { "version": "0.1.0", "name": "ext", - "runtime_reqs": [], "types": { "foo": { "extension": "ext", @@ -64,8 +63,7 @@ "b": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "lower_funcs": [] @@ -99,7 +97,6 @@ def test_extension(): ext = Extension( version=SemanticVersion(0, 1, 0), name="ext", - runtime_reqs=set(), types={"foo": type_def}, values={}, operations={"New": op_def}, @@ -121,7 +118,6 @@ def test_package(): ext = Extension( version=SemanticVersion(0, 1, 0), name="ext", - runtime_reqs=set(), types={}, values={}, operations={}, diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index 3018bf863..48f57de7a 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -37,7 +37,7 @@ def type_args(self) -> list[tys.TypeArg]: return [tys.StringArg(self.tag)] def cached_signature(self) -> tys.FunctionType | None: - return tys.FunctionType.endo([], runtime_reqs=[STRINGLY_EXT.name]) + return tys.FunctionType.endo([]) @classmethod def from_ext(cls, custom: ops.ExtOp) -> "StringlyOp": diff --git a/hugr-py/tests/test_tys.py b/hugr-py/tests/test_tys.py index e2c6d7d51..33bf55561 100644 --- a/hugr-py/tests/test_tys.py +++ b/hugr-py/tests/test_tys.py @@ -14,8 +14,6 @@ BoundedNatArg, BoundedNatParam, Either, - ExtensionsArg, - ExtensionsParam, ExtType, FunctionType, ListParam, @@ -95,7 +93,6 @@ def test_tys_sum_str(ty: Type, string: str, repr_str: str): "(Any, Nat(3))", ), (ListParam(StringParam()), "[String]"), - (ExtensionsParam(), "Extensions"), ], ) def test_params_str(param: TypeParam, string: str): @@ -113,7 +110,6 @@ def test_params_str(param: TypeParam, string: str): "(Type(Qubit), 3)", ), (VariableArg(2, StringParam()), "$2"), - (ExtensionsArg(["A", "B"]), "Extensions(A, B)"), ], ) def test_args_str(arg: TypeArg, string: str): diff --git a/hugr/Cargo.toml b/hugr/Cargo.toml index d96ac8fe5..c2f6e0a22 100644 --- a/hugr/Cargo.toml +++ b/hugr/Cargo.toml @@ -24,7 +24,6 @@ path = "src/lib.rs" [features] default = ["zstd"] -extension_inference = ["hugr-core/extension_inference"] declarative = ["hugr-core/declarative"] llvm = ["hugr-llvm/llvm14-0"] llvm-test = ["hugr-llvm/llvm14-0", "hugr-llvm/test-utils"] diff --git a/hugr/README.md b/hugr/README.md index b54d4f62d..83a2cc501 100644 --- a/hugr/README.md +++ b/hugr/README.md @@ -1,7 +1,6 @@ ![](/hugr/assets/hugr_logo.svg) -hugr -=============== +# hugr [![build_status][]](https://github.com/CQCL/hugr/actions) [![crates][]](https://crates.io/crates/hugr) @@ -29,10 +28,6 @@ Please read the [API documentation here][]. ## Experimental Features -- `extension_inference`: - Experimental feature which allows automatic inference of which extra extensions - are required at runtime by a HUGR when validating it. - Not enabled by default. - `declarative`: Experimental support for declaring extensions in YAML files, support is limited. 
diff --git a/hugr/benches/benchmarks/hugr/examples.rs b/hugr/benches/benchmarks/hugr/examples.rs index 0ece1eefb..2b7676439 100644 --- a/hugr/benches/benchmarks/hugr/examples.rs +++ b/hugr/benches/benchmarks/hugr/examples.rs @@ -7,9 +7,8 @@ use hugr::builder::{ HugrBuilder, ModuleBuilder, }; use hugr::extension::prelude::{bool_t, qb_t, usize_t}; -use hugr::extension::ExtensionSet; use hugr::ops::OpName; -use hugr::std_extensions::arithmetic::float_types::{self, float64_type, ConstF64}; +use hugr::std_extensions::arithmetic::float_types::{float64_type, ConstF64}; use hugr::types::Signature; use hugr::{type_row, CircuitUnit, Extension, Hugr, Node}; use lazy_static::lazy_static; @@ -97,11 +96,7 @@ pub fn circuit(layers: usize) -> (Hugr, Vec) { let h_gate = QUANTUM_EXT.instantiate_extension_op("H", []).unwrap(); let cx_gate = QUANTUM_EXT.instantiate_extension_op("CX", []).unwrap(); let rz = QUANTUM_EXT.instantiate_extension_op("Rz", []).unwrap(); - let signature = - Signature::new_endo(vec![qb_t(), qb_t()]).with_extension_delta(ExtensionSet::from_iter([ - QUANTUM_EXT.name().clone(), - float_types::EXTENSION_ID, - ])); + let signature = Signature::new_endo(vec![qb_t(), qb_t()]); let mut module_builder = ModuleBuilder::new(); let mut f_build = module_builder.define_function("main", signature).unwrap(); diff --git a/justfile b/justfile index 7b8075f94..d7e3f81f2 100644 --- a/justfile +++ b/justfile @@ -23,7 +23,7 @@ test-rust: HUGR_TEST_SCHEMA=1 cargo test \ --workspace \ --exclude 'hugr-py' \ - --features 'hugr/extension_inference hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' + --features 'hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' # Run all python tests. test-python: uv run maturin develop --uv diff --git a/specification/hugr.md b/specification/hugr.md index 6204e0e4f..3bd22b8ef 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -891,71 +891,6 @@ See [Declarative Format](#declarative-format) for more examples. Note that since a row variable does not have kind Type, it cannot be used as the type of an edge. -### Extension Tracking - -The type of `Function` includes a set of [extensions](#extension-system) which are required to execute the graph. -Similarly, every dataflow node in the HUGR has a set of extensions required to execute the node (computed from its operation), -also known as the "delta". The delta of any node must be a subset of its parent's delta, -except for FuncDefn's: -* the delta of any child of a FuncDefn must be a subset of the extensions in the FuncDefn's *type* -* the FuncDefn itself has no delta (trivially a subset of any parent): this reflects that the extensions -are not needed to *know* the FuncDefn, only to *execute* it -(by a Call node, whose delta is taken from the called FuncDefn's *type*). - -Keeping track of the extension requirements like this allows extension designers -and third-party tooling to control how/where a module is run. - -Concretely, if a plugin writer adds an extension -*X*, then some function from -a plugin needs to provide a mechanism to convert the -*X* to some other extension -requirement before it can interface with other plugins which don't know -about *X*. - -A runtime could have access to means of -running different extensions. By the same mechanism, the runtime can reason -about where to run different parts of the graph by inspecting their -extension requirements. 
- -Special operations **lift** and **liftGraph** can add extension requirements: -* `lift>` is a node with input and output rows `R` and extension-delta `{E}` -* `liftGraph, E: ExtensionSet, O: List>` has one input -$ \vec{I}^{\underrightarrow{\;E\;}}\vec{O} $ and one output $ \vec{I}^{\underrightarrow{\;E \cup N\;}}\vec{O}$. -That is, given a graph, it adds extensions $N$ to the requirements of the graph. - -The latter is useful for higher-order operations such as conditionally selecting -one function or another, where the output must have a consistent type (including -the extension-requirements of the function). - -### Rewriting Extension Requirements - -Extension requirements help denote different runtime capabilities. -For example, a quantum computer may not be able to handle arithmetic -while running a circuit, so its use is tracked in the function type so that -rewrites can be performed which remove the arithmetic. - -Simple circuits may look something like: - -```haskell -Function[Quantum](Array(5, Q), (ms: Array(5, Qubit), results: Array(5, Bit))) -``` - -A circuit built using a higher-order extension to manage control flow -could then look like: - -```haskell -Function[Quantum, HigherOrder](Array(5, Qubit), (ms: Array(5, Qubit), results: Array(5, Bit))) -``` - -So the compiler would need to perform some graph transformation pass to turn the -graph-based control flow into a CFG node that a quantum computer could -run, which removes the `HigherOrder` extension requirement. - -```haskell -precompute :: Function[](Function[Quantum,HigherOrder](Array(5, Qubit), (ms: Array(5, Qubit), results: Array(5, Bit))), - Function[Quantum](Array(5, Qubit), (ms: Array(5, Qubit), results: Array(5, Bit)))) -``` - ## Extension System ### Goals and constraints diff --git a/specification/schema/hugr_schema_live.json b/specification/schema/hugr_schema_live.json index ea08dff5b..02889a3f4 100644 --- a/specification/schema/hugr_schema_live.json +++ b/specification/schema/hugr_schema_live.json @@ -277,13 +277,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -428,13 +421,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -502,14 +488,6 @@ "title": "Name", "type": "string" }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array", - "uniqueItems": true - }, "types": { "additionalProperties": { "$ref": "#/$defs/TypeDef" @@ -528,7 +506,6 @@ "required": [ "version", "name", - "runtime_reqs", "types", "operations" ], @@ -581,42 +558,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionsArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "Extensions", - "default": "Extensions", - "title": "Tya", - "type": "string" - }, - "es": { - "items": { - "type": "string" - }, - "title": "Es", - "type": "array" - } - }, - "required": [ - "es" - ], - "title": "ExtensionsArg", - "type": "object" - }, - "ExtensionsParam": { - "additionalProperties": true, - "properties": { - "tp": { - "const": "Extensions", - "default": "Extensions", - "title": "Tp", - "type": "string" - } - }, - "title": "ExtensionsParam", - "type": "object" - }, "FixedHugr": { "properties": { "extensions": { @@ -742,13 +683,6 @@ }, "title": "Output", "type": "array" - }, - "runtime_reqs": { - "items": { - "type": "string" - 
}, - "title": "Runtime Reqs", - "type": "array" } }, "required": [ @@ -1541,13 +1475,6 @@ }, "title": "Rest", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -1654,7 +1581,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Extensions": "#/$defs/ExtensionsArg", "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", "Type": "#/$defs/TypeTypeArg", @@ -1675,9 +1601,6 @@ { "$ref": "#/$defs/SequenceArg" }, - { - "$ref": "#/$defs/ExtensionsArg" - }, { "$ref": "#/$defs/VariableArg" } @@ -1753,7 +1676,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Extensions": "#/$defs/ExtensionsParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1776,9 +1698,6 @@ }, { "$ref": "#/$defs/TupleParam" - }, - { - "$ref": "#/$defs/ExtensionsParam" } ], "required": [ diff --git a/specification/schema/hugr_schema_strict_live.json b/specification/schema/hugr_schema_strict_live.json index 8b65bae94..558f64c57 100644 --- a/specification/schema/hugr_schema_strict_live.json +++ b/specification/schema/hugr_schema_strict_live.json @@ -277,13 +277,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -428,13 +421,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -502,14 +488,6 @@ "title": "Name", "type": "string" }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array", - "uniqueItems": true - }, "types": { "additionalProperties": { "$ref": "#/$defs/TypeDef" @@ -528,7 +506,6 @@ "required": [ "version", "name", - "runtime_reqs", "types", "operations" ], @@ -581,42 +558,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionsArg": { - "additionalProperties": false, - "properties": { - "tya": { - "const": "Extensions", - "default": "Extensions", - "title": "Tya", - "type": "string" - }, - "es": { - "items": { - "type": "string" - }, - "title": "Es", - "type": "array" - } - }, - "required": [ - "es" - ], - "title": "ExtensionsArg", - "type": "object" - }, - "ExtensionsParam": { - "additionalProperties": false, - "properties": { - "tp": { - "const": "Extensions", - "default": "Extensions", - "title": "Tp", - "type": "string" - } - }, - "title": "ExtensionsParam", - "type": "object" - }, "FixedHugr": { "properties": { "extensions": { @@ -742,13 +683,6 @@ }, "title": "Output", "type": "array" - }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array" } }, "required": [ @@ -1541,13 +1475,6 @@ }, "title": "Rest", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -1654,7 +1581,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Extensions": "#/$defs/ExtensionsArg", "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", "Type": "#/$defs/TypeTypeArg", @@ -1675,9 +1601,6 @@ { "$ref": "#/$defs/SequenceArg" }, - { - "$ref": "#/$defs/ExtensionsArg" - }, { "$ref": "#/$defs/VariableArg" } @@ -1753,7 +1676,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Extensions": "#/$defs/ExtensionsParam", "List": 
"#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1776,9 +1698,6 @@ }, { "$ref": "#/$defs/TupleParam" - }, - { - "$ref": "#/$defs/ExtensionsParam" } ], "required": [ diff --git a/specification/schema/testing_hugr_schema_live.json b/specification/schema/testing_hugr_schema_live.json index 91b121da6..f534a3cbd 100644 --- a/specification/schema/testing_hugr_schema_live.json +++ b/specification/schema/testing_hugr_schema_live.json @@ -277,13 +277,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -428,13 +421,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -502,14 +488,6 @@ "title": "Name", "type": "string" }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array", - "uniqueItems": true - }, "types": { "additionalProperties": { "$ref": "#/$defs/TypeDef" @@ -528,7 +506,6 @@ "required": [ "version", "name", - "runtime_reqs", "types", "operations" ], @@ -581,42 +558,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionsArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "Extensions", - "default": "Extensions", - "title": "Tya", - "type": "string" - }, - "es": { - "items": { - "type": "string" - }, - "title": "Es", - "type": "array" - } - }, - "required": [ - "es" - ], - "title": "ExtensionsArg", - "type": "object" - }, - "ExtensionsParam": { - "additionalProperties": true, - "properties": { - "tp": { - "const": "Extensions", - "default": "Extensions", - "title": "Tp", - "type": "string" - } - }, - "title": "ExtensionsParam", - "type": "object" - }, "FixedHugr": { "properties": { "extensions": { @@ -742,13 +683,6 @@ }, "title": "Output", "type": "array" - }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array" } }, "required": [ @@ -1540,13 +1474,6 @@ }, "title": "Rest", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -1732,7 +1659,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Extensions": "#/$defs/ExtensionsArg", "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", "Type": "#/$defs/TypeTypeArg", @@ -1753,9 +1679,6 @@ { "$ref": "#/$defs/SequenceArg" }, - { - "$ref": "#/$defs/ExtensionsArg" - }, { "$ref": "#/$defs/VariableArg" } @@ -1831,7 +1754,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Extensions": "#/$defs/ExtensionsParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1854,9 +1776,6 @@ }, { "$ref": "#/$defs/TupleParam" - }, - { - "$ref": "#/$defs/ExtensionsParam" } ], "required": [ diff --git a/specification/schema/testing_hugr_schema_strict_live.json b/specification/schema/testing_hugr_schema_strict_live.json index eae6a13a7..eb3fcff0f 100644 --- a/specification/schema/testing_hugr_schema_strict_live.json +++ b/specification/schema/testing_hugr_schema_strict_live.json @@ -277,13 +277,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -428,13 +421,6 @@ }, "title": "Sum Rows", "type": "array" - }, - 
"extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -502,14 +488,6 @@ "title": "Name", "type": "string" }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array", - "uniqueItems": true - }, "types": { "additionalProperties": { "$ref": "#/$defs/TypeDef" @@ -528,7 +506,6 @@ "required": [ "version", "name", - "runtime_reqs", "types", "operations" ], @@ -581,42 +558,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionsArg": { - "additionalProperties": false, - "properties": { - "tya": { - "const": "Extensions", - "default": "Extensions", - "title": "Tya", - "type": "string" - }, - "es": { - "items": { - "type": "string" - }, - "title": "Es", - "type": "array" - } - }, - "required": [ - "es" - ], - "title": "ExtensionsArg", - "type": "object" - }, - "ExtensionsParam": { - "additionalProperties": false, - "properties": { - "tp": { - "const": "Extensions", - "default": "Extensions", - "title": "Tp", - "type": "string" - } - }, - "title": "ExtensionsParam", - "type": "object" - }, "FixedHugr": { "properties": { "extensions": { @@ -742,13 +683,6 @@ }, "title": "Output", "type": "array" - }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array" } }, "required": [ @@ -1540,13 +1474,6 @@ }, "title": "Rest", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -1732,7 +1659,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Extensions": "#/$defs/ExtensionsArg", "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", "Type": "#/$defs/TypeTypeArg", @@ -1753,9 +1679,6 @@ { "$ref": "#/$defs/SequenceArg" }, - { - "$ref": "#/$defs/ExtensionsArg" - }, { "$ref": "#/$defs/VariableArg" } @@ -1831,7 +1754,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Extensions": "#/$defs/ExtensionsParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1854,9 +1776,6 @@ }, { "$ref": "#/$defs/TupleParam" - }, - { - "$ref": "#/$defs/ExtensionsParam" } ], "required": [ diff --git a/specification/std_extensions/arithmetic/conversions.json b/specification/std_extensions/arithmetic/conversions.json index 1d310df25..3c2fb983c 100644 --- a/specification/std_extensions/arithmetic/conversions.json +++ b/specification/std_extensions/arithmetic/conversions.json @@ -1,10 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.conversions", - "runtime_reqs": [ - "arithmetic.float.types", - "arithmetic.int.types" - ], "types": {}, "operations": { "bytecast_float64_to_int64": { @@ -36,8 +32,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -71,8 +66,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -115,8 +109,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -159,8 +152,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -192,8 +184,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -223,8 +214,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -256,8 +246,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -300,8 +289,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -344,8 +332,7 @@ "args": [], "bound": "C" 
} - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -375,8 +362,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -436,8 +422,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -497,8 +482,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/arithmetic/float.json b/specification/std_extensions/arithmetic/float.json index 8da056772..60180ec84 100644 --- a/specification/std_extensions/arithmetic/float.json +++ b/specification/std_extensions/arithmetic/float.json @@ -1,9 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.float", - "runtime_reqs": [ - "arithmetic.int.types" - ], "types": {}, "operations": { "fabs": { @@ -30,8 +27,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -67,8 +63,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -97,8 +92,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -134,8 +128,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -169,8 +162,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -199,8 +191,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -234,8 +225,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -269,8 +259,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -304,8 +293,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -339,8 +327,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -376,8 +363,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -413,8 +399,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -450,8 +435,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -485,8 +469,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -515,8 +498,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -552,8 +534,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -582,8 +563,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -619,8 +599,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -649,8 +628,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/arithmetic/float/types.json b/specification/std_extensions/arithmetic/float/types.json index 0c563c474..33db43f5b 100644 --- a/specification/std_extensions/arithmetic/float/types.json +++ b/specification/std_extensions/arithmetic/float/types.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.float.types", - "runtime_reqs": [], "types": { "float64": { "extension": "arithmetic.float.types", diff --git a/specification/std_extensions/arithmetic/int.json b/specification/std_extensions/arithmetic/int.json index 5b1a81250..e8e6fdca8 100644 --- a/specification/std_extensions/arithmetic/int.json +++ b/specification/std_extensions/arithmetic/int.json @@ -1,9 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.int", - "runtime_reqs": [ - "arithmetic.int.types" - ], "types": {}, "operations": { "iabs": { @@ -53,8 +50,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -122,8 +118,7 @@ ], "bound": "C" } - ], - 
"runtime_reqs": [] + ] } }, "binary": false @@ -191,8 +186,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -277,8 +271,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -363,8 +356,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -432,8 +424,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -501,8 +492,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -611,8 +601,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -721,8 +710,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -806,8 +794,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -891,8 +878,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -949,8 +935,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1007,8 +992,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1065,8 +1049,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1123,8 +1106,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1181,8 +1163,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1239,8 +1220,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1297,8 +1277,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1355,8 +1334,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1413,8 +1391,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1482,8 +1459,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1551,8 +1527,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1620,8 +1595,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1689,8 +1663,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1775,8 +1748,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1861,8 +1833,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1930,8 +1901,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1999,8 +1969,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2068,8 +2037,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2142,8 +2110,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -2216,8 +2183,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -2274,8 +2240,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2327,8 +2292,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2380,8 +2344,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2449,8 +2412,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2518,8 +2480,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2587,8 +2548,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2656,8 +2616,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2709,8 +2668,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2778,8 +2736,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2847,8 +2804,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2916,8 
+2872,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2969,8 +2924,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -3026,8 +2980,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -3083,8 +3036,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -3152,8 +3104,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/arithmetic/int/types.json b/specification/std_extensions/arithmetic/int/types.json index 36df125a6..0b77d2e55 100644 --- a/specification/std_extensions/arithmetic/int/types.json +++ b/specification/std_extensions/arithmetic/int/types.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.int.types", - "runtime_reqs": [], "types": { "int": { "extension": "arithmetic.int.types", diff --git a/specification/std_extensions/collections/array.json b/specification/std_extensions/collections/array.json index 375e13c72..fba222793 100644 --- a/specification/std_extensions/collections/array.json +++ b/specification/std_extensions/collections/array.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.array", - "runtime_reqs": [], "types": { "array": { "extension": "collections.array", @@ -60,8 +59,7 @@ "bound": "A" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false @@ -126,8 +124,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -166,9 +163,6 @@ { "tp": "Type", "b": "A" - }, - { - "tp": "Extensions" } ], "body": { @@ -182,9 +176,6 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [ - "2" ] } ], @@ -213,8 +204,7 @@ ], "bound": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -243,9 +233,6 @@ "tp": "Type", "b": "A" } - }, - { - "tp": "Extensions" } ], "body": { @@ -299,9 +286,6 @@ "i": 3, "b": "A" } - ], - "runtime_reqs": [ - "4" ] }, { @@ -340,8 +324,7 @@ "i": 3, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -465,8 +448,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -578,8 +560,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/collections/list.json b/specification/std_extensions/collections/list.json index 8a60d3544..de9736e4e 100644 --- a/specification/std_extensions/collections/list.json +++ b/specification/std_extensions/collections/list.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.list", - "runtime_reqs": [], "types": { "List": { "extension": "collections.list", @@ -70,8 +69,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -151,8 +149,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -207,8 +204,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -274,8 +270,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -332,8 +327,7 @@ ], "bound": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -413,8 +407,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/collections/static_array.json b/specification/std_extensions/collections/static_array.json index 53b8e61c7..cde35e063 100644 --- a/specification/std_extensions/collections/static_array.json +++ b/specification/std_extensions/collections/static_array.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.static_array", - "runtime_reqs": [], "types": { "static_array": { "extension": "collections.static_array", @@ -68,8 +67,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, 
"binary": false @@ -108,8 +106,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/logic.json b/specification/std_extensions/logic.json index ff29d2c21..45cd7f606 100644 --- a/specification/std_extensions/logic.json +++ b/specification/std_extensions/logic.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "logic", - "runtime_reqs": [], "types": {}, "operations": { "And": { @@ -29,8 +28,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -60,8 +58,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -86,8 +83,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -117,8 +113,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -148,8 +143,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index ec392b155..7cf1d02c7 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -1,7 +1,6 @@ { "version": "0.2.0", "name": "prelude", - "runtime_reqs": [], "types": { "error": { "extension": "prelude", @@ -73,8 +72,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -115,8 +113,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -146,8 +143,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -188,8 +184,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -236,8 +231,7 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -259,8 +253,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -307,8 +300,7 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -329,8 +321,7 @@ "bound": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false diff --git a/specification/std_extensions/ptr.json b/specification/std_extensions/ptr.json index 614b6aecf..d701fff53 100644 --- a/specification/std_extensions/ptr.json +++ b/specification/std_extensions/ptr.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "ptr", - "runtime_reqs": [], "types": { "ptr": { "extension": "ptr", @@ -56,8 +55,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -98,8 +96,7 @@ "i": 0, "b": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -139,8 +136,7 @@ "b": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false From 614e43a1d4318a971b7735582a8563ada4e73b05 Mon Sep 17 00:00:00 2001 From: Luca Mondada <72734770+lmondada@users.noreply.github.com> Date: Tue, 6 May 2025 12:55:11 +0200 Subject: [PATCH 18/21] feat!: Accept outgoing ports in SimpleReplacement nu_out (#2151) Currently, SimpleReplacement stores its output boundary map nu_out by referring to nodes outside the deleted subgraph in the host graph. This forces the invalidation set of the replacements to include nodes past the output, forbidding simultaneous adjacent replacements. This PR fixes this by allowing the keys of `nu_out` (i.e. the ports on the output boundary of the subgraph) to be either incoming ports (as before), or outgoing ports (in which case this is equivalent to specifying the map on all incoming ports linked to the given outgoing ports). The latter is less general but covers most use cases and reduces the size of the invalidation set. 
Closes https://github.com/CQCL/hugr/issues/2098 BREAKING CHANGE: Generalised arguments to [`SimpleReplacement::new`] and [`SimpleReplacement::map_host_output`] to allow for outgoing ports. Previous type signatures remain valid, however, type inference may fail. --------- Co-authored-by: Mark Koch <48097969+mark-koch@users.noreply.github.com> --- hugr-core/src/hugr/patch/port_types.rs | 2 +- hugr-core/src/hugr/patch/simple_replace.rs | 325 ++++++++++++++++--- hugr-core/src/hugr/views/sibling_subgraph.rs | 9 +- 3 files changed, 291 insertions(+), 45 deletions(-) diff --git a/hugr-core/src/hugr/patch/port_types.rs b/hugr-core/src/hugr/patch/port_types.rs index 3aeafa4ae..fd9d07cea 100644 --- a/hugr-core/src/hugr/patch/port_types.rs +++ b/hugr-core/src/hugr/patch/port_types.rs @@ -23,7 +23,7 @@ pub enum BoundaryPort { pub struct HostPort(pub N, pub P); /// A port in the replacement graph. -#[derive(Debug, Clone, Copy, From)] +#[derive(Debug, Clone, Copy, From, PartialEq, Eq, PartialOrd, Ord)] pub struct ReplacementPort
<P>
(pub Node, pub P); impl BoundaryPort { diff --git a/hugr-core/src/hugr/patch/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs index 245a3cdc0..8fc7c7617 100644 --- a/hugr-core/src/hugr/patch/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -7,9 +7,10 @@ use crate::hugr::hugrmut::InsertionResult; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrMut, HugrView}; use crate::ops::{OpTag, OpTrait, OpType}; -use crate::{Hugr, IncomingPort, Node, OutgoingPort}; +use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port}; -use itertools::Itertools; +use derive_more::derive::From; +use itertools::{Either, Itertools}; use thiserror::Error; @@ -31,9 +32,121 @@ pub struct SimpleReplacement { /// to (target ports of edges from nodes not in `subgraph` to nodes in /// `subgraph`). nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>, - /// A map from (target ports of edges from nodes in `subgraph` to nodes not - /// in `subgraph`) to (input ports of the Output node of `replacement`). - nu_out: HashMap<(HostNode, IncomingPort), IncomingPort>, + /// The output boundary, mapping the edges of the output boundary of + /// `subgraph` to the incoming ports of the output node of `replacement`. + /// + /// ### Output boundary map + /// + /// The keys of the map, specifying the output boundary edges of `subgraph` + /// can be either: + /// - the outgoing ports as returned by [`SiblingSubgraph::outgoing_ports`], + /// or + /// - the incoming ports linked to the [`SiblingSubgraph::outgoing_ports`] + /// in the host HUGR. + /// + /// Specifying the output boundary map in terms of incoming ports is more + /// general, but will refer to nodes **outside of `subgraph`**. In most + /// cases, it is sufficient to specify the output boundary map using the + /// outgoing ports on the output boundary. + /// + /// ## Invalidation set + /// + /// If using outgoing ports for the output boundary, + /// [`SimpleReplacement::invalidation_set`] will be the set of nodes in the + /// subgraph, as returned by [`SiblingSubgraph::nodes`]. If using incoming + /// ports, the invalidation set will include the nodes of the HUGR past the + /// output boundary of `subgraph`. + nu_out: OutputBoundaryMap, +} + +/// A map from edges in a host HUGR to incoming ports. +/// +/// The edges in the map keys can be specified either as incoming ports, or +/// as outgoing ports (in which case all incoming ports linked to the same +/// outgoing port `o` map to the image of `o` under the map). +#[derive(Debug, Clone, From)] +pub enum OutputBoundaryMap { + /// Express map in terms of incoming ports past the output boundary of the + /// subgraph + ByIncoming(HashMap<(HostNode, IncomingPort), IncomingPort>), + /// Express map in terms of outgoing ports on the output boundary of the + /// subgraph + ByOutgoing(HashMap<(HostNode, OutgoingPort), IncomingPort>), +} + +impl OutputBoundaryMap { + /// Iterate over the boundary map. + /// + /// The keys' ports are either incoming or outgoing, depending on the + /// variant of `self`. + pub fn iter(&self) -> impl Iterator + '_ { + match self { + OutputBoundaryMap::ByIncoming(map) => Either::Left( + map.iter() + .map(|(&(node, in_port), &v)| ((node, in_port.into()), v)), + ), + OutputBoundaryMap::ByOutgoing(map) => Either::Right( + map.iter() + .map(|(&(node, out_port), &v)| ((node, out_port.into()), v)), + ), + } + .into_iter() + } + + /// Iterate over the boundary map with keys resolved as incoming ports. 
+ /// + /// By providing the host HUGR `host`, all ports in the keys are resolved + /// to incoming ports. + pub fn iter_as_incoming<'a>( + &'a self, + host: &'a impl HugrView, + ) -> impl Iterator + 'a { + self.iter() + .flat_map(move |((rem_out_node, rem_out_port), rep_out_port)| { + as_incoming_ports(rem_out_node, rem_out_port, host).map( + move |(rem_out_node, rem_out_port)| { + ((rem_out_node, rem_out_port), rep_out_port) + }, + ) + }) + } + + /// Get the image of a port under the boundary map. + /// + /// The port `port` should be either incoming or outgoing, depending on the + /// variant of `self`, else `None` is returned. + pub fn get>(&self, node: N, port: P) -> Option { + match (self, port.into().as_directed()) { + (OutputBoundaryMap::ByIncoming(map), Either::Left(incoming)) => { + map.get(&(node, incoming)).copied() + } + (OutputBoundaryMap::ByOutgoing(map), Either::Right(outgoing)) => { + map.get(&(node, outgoing)).copied() + } + _ => None, + } + } + + /// Get the image of an incoming port under the boundary map. + /// + /// By providing the host HUGR `host`, all ports in the keys are resolved + /// to incoming ports. + pub fn get_as_incoming( + &self, + node: N, + incoming: IncomingPort, + host: &impl HugrView, + ) -> Option { + match self { + OutputBoundaryMap::ByIncoming(map) => map.get(&(node, incoming)).copied(), + OutputBoundaryMap::ByOutgoing(map) => { + let outgoing = host + .single_linked_output(node, incoming) + .expect("invalid data flow wire"); + map.get(&outgoing).copied() + } + } + } } impl SimpleReplacement { @@ -43,13 +156,13 @@ impl SimpleReplacement { subgraph: SiblingSubgraph, replacement: Hugr, nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>, - nu_out: HashMap<(HostNode, IncomingPort), IncomingPort>, + nu_out: impl Into>, ) -> Self { Self { subgraph, replacement, nu_inp, - nu_out, + nu_out: nu_out.into(), } } @@ -151,7 +264,7 @@ impl SimpleReplacement { /// This panics if self.replacement is not a DFG. pub fn outgoing_boundary<'a>( &'a self, - _host: &'a impl HugrView, + host: &'a impl HugrView, ) -> impl Iterator< Item = ( ReplacementPort, @@ -163,12 +276,11 @@ impl SimpleReplacement { // For each q = self.nu_out[p] such that the predecessor of q is not an Input // port, there will be an edge from (the new copy of) the predecessor of // q to p. - self.nu_out - .iter() - .filter_map(move |(&(rem_out_node, rem_out_port), rep_out_port)| { + self.nu_out.iter_as_incoming(host).filter_map( + move |((rem_out_node, rem_out_port), rep_out_port)| { let (rep_out_pred_node, rep_out_pred_port) = self .replacement - .single_linked_output(replacement_output_node, *rep_out_port) + .single_linked_output(replacement_output_node, rep_out_port) .unwrap(); (self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input).then_some({ ( @@ -177,7 +289,8 @@ impl SimpleReplacement { HostPort(rem_out_node, rem_out_port), ) }) - }) + }, + ) } /// Get all edges that the replacement would add between ports in `host`. @@ -201,9 +314,8 @@ impl SimpleReplacement { // For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the // predecessor of p0 to p1. 
- self.nu_out - .iter() - .filter_map(move |(&(rem_out_node, rem_out_port), &rep_out_port)| { + self.nu_out.iter_as_incoming(host).filter_map( + move |((rem_out_node, rem_out_port), rep_out_port)| { self.nu_inp .get(&(replacement_output_node, rep_out_port)) .map(|&(rem_inp_node, rem_inp_port)| { @@ -215,22 +327,26 @@ impl SimpleReplacement { HostPort(rem_out_node, rem_out_port), ) }) - }) + }, + ) } /// Get the incoming port at the output node of `self.replacement` that /// corresponds to the given host output port. /// + /// If the output boundary map is given as outgoing (incoming) ports, the + /// `port` must be outgoing (incoming). Otherwise, `None` is returned. + /// /// This panics if self.replacement is not a DFG. - pub fn map_host_output( + pub fn map_host_output>( &self, - port: impl Into>, + port: impl Into>, ) -> Option> { let HostPort(node, port) = port.into(); let [_, rep_output] = self.get_replacement_io().expect("replacement is a DFG"); self.nu_out - .get(&(node, port)) - .map(|&rep_out_port| ReplacementPort(rep_output, rep_out_port)) + .get(node, port.into()) + .map(|rep_out_port| ReplacementPort(rep_output, rep_out_port)) } /// Get the incoming port in `subgraph` that corresponds to the given @@ -289,8 +405,13 @@ impl PatchVerification for SimpleReplacement { #[inline] fn invalidation_set(&self) -> impl Iterator { let subcirc = self.subgraph.nodes().iter().copied(); - let out_neighs = self.nu_out.keys().map(|key| key.0); - subcirc.chain(out_neighs) + let nu_out_nodes = match &self.nu_out { + OutputBoundaryMap::ByIncoming(map) => Some(map.keys().map(|key| key.0)), + OutputBoundaryMap::ByOutgoing(_) => None, + } + .into_iter() + .flatten(); + subcirc.chain(nu_out_nodes) } } @@ -380,6 +501,18 @@ pub enum SimpleReplacementError { InliningFailed(#[from] InlineDFGError), } +fn as_incoming_ports<'a, N: HugrNode + 'a>( + node: N, + port: Port, + hugr: &'a impl HugrView, +) -> impl Iterator + 'a { + match port.as_directed() { + Either::Left(incoming) => Either::Left(std::iter::once((node, incoming))), + Either::Right(outgoing) => Either::Right(hugr.linked_inputs(node, outgoing)), + } + .into_iter() +} + #[cfg(test)] pub(in crate::hugr::patch) mod test { use itertools::Itertools; @@ -393,7 +526,8 @@ pub(in crate::hugr::patch) mod test { DataflowSubContainer, HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::{bool_t, qb_t}; - use crate::hugr::patch::PatchVerification; + use crate::hugr::patch::simple_replace::OutputBoundaryMap; + use crate::hugr::patch::{PatchVerification, ReplacementPort}; use crate::hugr::views::{HugrView, SiblingSubgraph}; use crate::hugr::{Hugr, HugrMut, Patch}; use crate::ops::dataflow::DataflowOpTrait; @@ -404,7 +538,7 @@ pub(in crate::hugr::patch) mod test { use crate::std_extensions::logic::LogicOp; use crate::types::{Signature, Type}; use crate::utils::test_quantum_extension::{cx_gate, h_gate}; - use crate::{IncomingPort, Node}; + use crate::{Direction, IncomingPort, Node, OutgoingPort, Port}; use super::SimpleReplacement; @@ -619,8 +753,19 @@ pub(in crate::hugr::patch) mod test { subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(), replacement: n, nu_inp, - nu_out, + nu_out: nu_out.into(), }; + + // Check output boundary + assert_eq!( + r.map_host_output((h_outp_node, h_port_2)).unwrap(), + ReplacementPort::from((r.get_replacement_io().unwrap()[1], n_port_2)) + ); + assert!(r + .map_host_output((h_outp_node, OutgoingPort::from(0))) + .is_none()); + + // Check invalidation set assert_eq!( HashSet::<_>::from_iter(r.invalidation_set()), 
HashSet::<_>::from_iter([h_node_cx, h_node_h0, h_node_h1, h_outp_node]), @@ -696,7 +841,7 @@ pub(in crate::hugr::patch) mod test { subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(), replacement: n, nu_inp, - nu_out, + nu_out: nu_out.into(), }; h.apply_patch(r).unwrap(); // Expect [DFG] to be replaced with: @@ -740,7 +885,7 @@ pub(in crate::hugr::patch) mod test { (link, link) }) .collect(); - let outputs = h + let outputs: HashMap<_, _> = h .node_inputs(output) .filter(|&p| { h.get_optype(output) @@ -798,7 +943,7 @@ pub(in crate::hugr::patch) mod test { .map(|p| repl.linked_inputs(repl_input, p).next().unwrap()); let inputs = embedded_inputs.zip(repl_inputs).collect(); - let outputs = repl + let outputs: HashMap<_, _> = repl .node_inputs(repl_output) .filter(|&p| repl.signature(repl_output).unwrap().port_type(p).is_some()) .map(|p| ((repl_output, p), p)) @@ -855,7 +1000,7 @@ pub(in crate::hugr::patch) mod test { .collect(); // A map from (target ports of edges from nodes in `removal` to nodes not in // `removal`) to (input ports of the Output node of `replacement`). - let nu_out = [ + let nu_out: HashMap<_, _> = [ ((output, IncomingPort::from(0)), IncomingPort::from(0)), ((output, IncomingPort::from(1)), IncomingPort::from(1)), ] @@ -866,7 +1011,7 @@ pub(in crate::hugr::patch) mod test { subgraph, replacement, nu_inp, - nu_out, + nu_out: nu_out.into(), }; rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}")); @@ -914,7 +1059,7 @@ pub(in crate::hugr::patch) mod test { .collect(); // A map from (target ports of edges from nodes in `removal` to nodes not in // `removal`) to (input ports of the Output node of `replacement`). - let nu_out = [ + let nu_out: HashMap<_, _> = [ ((output, IncomingPort::from(0)), IncomingPort::from(0)), ((output, IncomingPort::from(1)), IncomingPort::from(1)), ] @@ -925,7 +1070,7 @@ pub(in crate::hugr::patch) mod test { subgraph, replacement, nu_inp, - nu_out, + nu_out: nu_out.into(), }; rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}")); @@ -960,7 +1105,7 @@ pub(in crate::hugr::patch) mod test { .into_iter() .collect(); - let nu_out = vec![( + let nu_out: HashMap<_, _> = vec![( (h.get_io(h.root()).unwrap()[1], IncomingPort::from(1)), IncomingPort::from(0), )] @@ -977,6 +1122,110 @@ pub(in crate::hugr::patch) mod test { assert_eq!(h.num_nodes(), 6); } + #[rstest] + fn test_simple_replacement_with_empty_wires_using_outgoing_ports( + simple_hugr: Hugr, + dfg_hugr2: Hugr, + ) { + let mut h: Hugr = simple_hugr; + + // 1. Locate the CX in h + let h_node_cx: Node = h + .nodes() + .find(|node: &Node| *h.get_optype(*node) == cx_gate().into()) + .unwrap(); + let s = vec![h_node_cx]; + // 2. Construct a new DFG-rooted hugr for the replacement + let n: Hugr = dfg_hugr2; + // 3. Construct the input and output matchings + // 3.1. Locate the Output and its predecessor H in n + let [_n_node_input, n_node_output] = n.get_io(n.root()).unwrap(); + let n_node_h = n.input_neighbours(n_node_output).nth(1).unwrap(); + // 3.2. Locate the ports we need to specify as "glue" in n + let (n_port_0, n_port_1) = n + .node_inputs(n_node_output) + .take(2) + .collect_tuple() + .unwrap(); + let n_port_2 = n.node_inputs(n_node_h).next().unwrap(); + // 3.3. Locate the ports we need to specify as "glue" in h + let (h_port_0, h_port_1) = h.node_inputs(h_node_cx).take(2).collect_tuple().unwrap(); + // 3.4. 
Construct the maps + let mut nu_inp = HashMap::new(); + let mut nu_out = HashMap::new(); + nu_inp.insert((n_node_output, n_port_0), (h_node_cx, h_port_0)); + nu_inp.insert((n_node_h, n_port_2), (h_node_cx, h_port_1)); + nu_out.insert((h_node_cx, OutgoingPort::from(0)), n_port_0); + nu_out.insert((h_node_cx, OutgoingPort::from(1)), n_port_1); + // 4. Define the replacement + let r = SimpleReplacement { + subgraph: SiblingSubgraph::try_from_nodes(s, &h).unwrap(), + replacement: n, + nu_inp, + nu_out: nu_out.into(), + }; + h.apply_patch(r).unwrap(); + // Expect [DFG] to be replaced with: + // ┌───┐┌───┐ + // ┤ H ├┤ H ├ + // ├───┤├───┤┌───┐ + // ┤ H ├┤ H ├┤ H ├ + // └───┘└───┘└───┘ + assert_eq!(h.validate(), Ok(())); + } + + #[rstest] + fn test_output_boundary_map(dfg_hugr2: Hugr) { + let [inp, out] = dfg_hugr2.get_io(dfg_hugr2.root()).unwrap(); + let map = [ + ((inp, OutgoingPort::from(0)), IncomingPort::from(0)), + ((inp, OutgoingPort::from(1)), IncomingPort::from(1)), + ] + .into_iter() + .collect(); + let map = OutputBoundaryMap::ByOutgoing(map); + + // Basic check: map as just defined + assert_eq!( + map.get(inp, OutgoingPort::from(0)), + Some(IncomingPort::from(0)) + ); + assert_eq!( + map.get(inp, OutgoingPort::from(1)), + Some(IncomingPort::from(1)) + ); + + // Now check the map in terms of incoming ports + assert!(map.get(out, IncomingPort::from(0)).is_none()); + assert_eq!( + map.get_as_incoming(out, IncomingPort::from(0), &dfg_hugr2), + Some(IncomingPort::from(0)) + ); + + // Finally, check iterators + assert_eq!( + map.iter().collect::>(), + HashSet::from_iter([ + ( + (inp, Port::new(Direction::Outgoing, 0)), + IncomingPort::from(0) + ), + ( + (inp, Port::new(Direction::Outgoing, 1)), + IncomingPort::from(1) + ), + ]) + ); + let h_gate = dfg_hugr2.output_neighbours(inp).nth(1).unwrap(); + assert_eq!( + map.iter_as_incoming(&dfg_hugr2).collect::>(), + HashSet::from_iter([ + ((out, IncomingPort::from(0)), IncomingPort::from(0)), + ((h_gate, IncomingPort::from(0)), IncomingPort::from(1)), + ]) + ); + } + use crate::hugr::patch::replace::Replacement; fn to_replace(h: &impl HugrView, s: SimpleReplacement) -> Replacement { use crate::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec}; @@ -1007,18 +1256,18 @@ pub(in crate::hugr::patch) mod test { .collect(); let mu_out = s .nu_out - .iter() + .iter_as_incoming(&h) .map(|((tgt, tgt_port), out_port)| { - let (src, src_port) = replacement.single_linked_output(out, *out_port).unwrap(); + let (src, src_port) = replacement.single_linked_output(out, out_port).unwrap(); if src == in_ { unimplemented!() }; NewEdgeSpec { src, - tgt: *tgt, + tgt, kind: NewEdgeKind::Value { src_pos: src_port, - tgt_pos: *tgt_port, + tgt_pos: tgt_port, }, } }) diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index b2eba044e..5402de65d 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -10,7 +10,7 @@ //! hierarchy. 
use std::cell::OnceCell; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::mem; use itertools::Itertools; @@ -433,14 +433,11 @@ impl SiblingSubgraph { }) }) .collect(); - let nu_out = self + let nu_out: HashMap<_, _> = self .outputs .iter() .zip_eq(rep_outputs) - .flat_map(|(&(self_source_n, self_source_p), (_, rep_target_p))| { - hugr.linked_inputs(self_source_n, self_source_p) - .map(move |self_target| (self_target, rep_target_p)) - }) + .map(|(&self_target, (_, rep_target_p))| (self_target, rep_target_p)) .collect(); Ok(SimpleReplacement::new( From 1bc91c197519f4a81f5fff1bf9df5905a1d1559e Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Tue, 6 May 2025 12:52:38 +0100 Subject: [PATCH 19/21] feat!: Improved array lowering (#2109) Feature branch for improved array lowering. * The old `array` type is now called `value_array` and lives in a separate extension * The default `array` is now a linear type with additional `clone` and `discard` operations * To avoid code duplication, array operations and values are now defined generically over a new `ArrayKind` trait that is instantiated with `Array` (the linear one) and `VArray` (the copyable one) to generate the `array` and `value_array` extensions * An `array` is now lowered to a fat pointer `{ptr, usize}` where `ptr` is a heap allocated pointer of size at least `n * sizeof(T)` and the `usize` is an offset pointing to the first element (i.e. the first element is at `ptr + offset * sizeof(T)`). The rational behind the additional offset is the `pop_left` operation which bumps the offset instead of mutating the pointer. This way, we can still free the original pointer when the array is discarded after a pop. Tracked PRs: * #2097 (closes #2066) * #2100 * #2101 * #2110 * #2112 (closes #2067) * #2119 * #2125 (closes #2124) BREAKING CHANGE: `std.collections.array` is now a linear type, even if the contained elements are copyable. Use the new `std.collections.value_array` for an array with the previous copyable semantics. BREAKING CHANGE: `std.collections.array.get` now also returns the passed array as an extra output BREAKING CHANGE: `ArrayOpBuilder` was moved from `hugr_core::std_extensions::collections::array::op_builder` to `hugr_core::std_extensions::collections::array`. 
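For orientation, a minimal sketch of the new linear-array API from the builder side (not part of this patch; import paths and the `bool_t` element type are assumptions, and the code relies on the `ArrayOpBuilder` trait added in the diff below):

```rust
use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr};
use hugr_core::extension::prelude::{bool_t, ConstUsize};
use hugr_core::std_extensions::collections::array::{array_type, ArrayOpBuilder};

fn linear_array_example() -> Result<hugr_core::Hugr, hugr_core::builder::BuildError> {
    let arr_ty = array_type(2, bool_t());
    let mut b = DFGBuilder::new(inout_sig(vec![arr_ty.clone()], vec![arr_ty]))?;
    let [arr] = b.input_wires_arr();
    // Even though `bool_t` is copyable, the array itself is linear:
    // duplicating it is an explicit `clone` op ...
    let (arr, copy) = b.add_array_clone(bool_t(), 2, arr)?;
    // ... and dropping the extra copy is an explicit `discard`.
    b.add_array_discard(bool_t(), 2, copy)?;
    // `get` now returns the (still linear) array alongside the element.
    let idx = b.add_load_value(ConstUsize::new(0));
    let (_elem, arr) = b.add_array_get(bool_t(), 2, arr, idx)?;
    Ok(b.finish_hugr_with_outputs([arr])?)
}
```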
--- Cargo.lock | 1 + hugr-core/src/ops/constant.rs | 26 + hugr-core/src/std_extensions.rs | 1 + hugr-core/src/std_extensions/collections.rs | 1 + .../src/std_extensions/collections/array.rs | 506 ++++++++---- .../collections/array/array_clone.rs | 234 ++++++ .../collections/array/array_conversion.rs | 318 ++++++++ .../collections/array/array_discard.rs | 220 +++++ .../collections/array/array_kind.rs | 119 +++ .../collections/array/array_op.rs | 217 +++-- .../collections/array/array_repeat.rs | 103 ++- .../collections/array/array_scan.rs | 119 ++- .../collections/array/array_value.rs | 159 ++++ .../collections/array/op_builder.rs | 256 ++++-- .../std_extensions/collections/value_array.rs | 349 ++++++++ hugr-core/src/utils.rs | 12 + hugr-llvm/src/emit/libc.rs | 41 +- hugr-llvm/src/extension/collections/array.rs | 770 +++++++++++------- ...ons__array__test__emit_all_ops@llvm14.snap | 358 ++++---- ...test__emit_all_ops@pre-mem2reg@llvm14.snap | 713 ++++++++-------- ..._array__test__emit_array_value@llvm14.snap | 14 +- ...__emit_array_value@pre-mem2reg@llvm14.snap | 26 +- ...tions__array__test__emit_clone@llvm14.snap | 45 + ...__test__emit_clone@pre-mem2reg@llvm14.snap | 60 ++ ...ections__array__test__emit_get@llvm14.snap | 45 +- ...ay__test__emit_get@pre-mem2reg@llvm14.snap | 58 +- hugr-passes/Cargo.toml | 1 + hugr-passes/README.md | 2 +- hugr-passes/src/lib.rs | 2 + hugr-passes/src/linearize_array.rs | 397 +++++++++ hugr-passes/src/monomorphize.rs | 28 +- hugr-passes/src/replace_types.rs | 75 +- hugr-passes/src/replace_types/handlers.rs | 128 ++- hugr-passes/src/replace_types/linearize.rs | 45 +- .../std/_json_defs/collections/array.json | 175 +++- .../_json_defs/collections/value_array.json | 737 +++++++++++++++++ hugr-py/src/hugr/std/collections/array.py | 2 +- .../src/hugr/std/collections/value_array.py | 83 ++ hugr-py/tests/test_tys.py | 23 + .../std_extensions/collections/array.json | 175 +++- .../collections/value_array.json | 737 +++++++++++++++++ 41 files changed, 6017 insertions(+), 1364 deletions(-) create mode 100644 hugr-core/src/std_extensions/collections/array/array_clone.rs create mode 100644 hugr-core/src/std_extensions/collections/array/array_conversion.rs create mode 100644 hugr-core/src/std_extensions/collections/array/array_discard.rs create mode 100644 hugr-core/src/std_extensions/collections/array/array_kind.rs create mode 100644 hugr-core/src/std_extensions/collections/array/array_value.rs create mode 100644 hugr-core/src/std_extensions/collections/value_array.rs create mode 100644 hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_clone@llvm14.snap create mode 100644 hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_clone@pre-mem2reg@llvm14.snap create mode 100644 hugr-passes/src/linearize_array.rs create mode 100644 hugr-py/src/hugr/std/_json_defs/collections/value_array.json create mode 100644 hugr-py/src/hugr/std/collections/value_array.py create mode 100644 specification/std_extensions/collections/value_array.json diff --git a/Cargo.lock b/Cargo.lock index f9ba0b100..5fe1d43ca 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1212,6 +1212,7 @@ dependencies = [ "proptest", "proptest-recurse", "rstest", + "strum", "thiserror 2.0.12", ] diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 18f3974d4..56e77186d 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -585,6 +585,7 @@ pub(crate) mod test { use 
crate::extension::PRELUDE; use crate::std_extensions::arithmetic::int_types::ConstInt; use crate::std_extensions::collections::array::{array_type, ArrayValue}; + use crate::std_extensions::collections::value_array::{value_array_type, VArrayValue}; use crate::{ builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, extension::{ @@ -754,6 +755,11 @@ pub(crate) mod test { ArrayValue::new(bool_t(), [Value::true_val(), Value::false_val()]).into() } + #[fixture] + fn const_value_array_bool() -> Value { + VArrayValue::new(bool_t(), [Value::true_val(), Value::false_val()]).into() + } + #[fixture] fn const_array_options() -> Value { let some_true = Value::some([Value::true_val()]); @@ -762,17 +768,35 @@ pub(crate) mod test { ArrayValue::new(elem_ty.into(), [some_true, none]).into() } + #[fixture] + fn const_value_array_options() -> Value { + let some_true = Value::some([Value::true_val()]); + let none = Value::none(vec![bool_t()]); + let elem_ty = SumType::new_option(vec![bool_t()]); + VArrayValue::new(elem_ty.into(), [some_true, none]).into() + } + #[rstest] #[case(Value::unit(), Type::UNIT, "const:seq:{}")] #[case(const_usize(), usize_t(), "const:custom:ConstUsize(")] #[case(serialized_float(17.4), float64_type(), "const:custom:json:Object")] #[case(const_tuple(), Type::new_tuple(vec![usize_t(), bool_t()]), "const:seq:{")] #[case(const_array_bool(), array_type(2, bool_t()), "const:custom:array")] + #[case( + const_value_array_bool(), + value_array_type(2, bool_t()), + "const:custom:value_array" + )] #[case( const_array_options(), array_type(2, SumType::new_option(vec![bool_t()]).into()), "const:custom:array" )] + #[case( + const_value_array_options(), + value_array_type(2, SumType::new_option(vec![bool_t()]).into()), + "const:custom:value_array" + )] fn const_type( #[case] const_value: Value, #[case] expected_type: Type, @@ -792,7 +816,9 @@ pub(crate) mod test { #[case(const_serialized_usize(), const_usize())] #[case(const_tuple_serialized(), const_tuple())] #[case(const_array_bool(), const_array_bool())] + #[case(const_value_array_bool(), const_value_array_bool())] #[case(const_array_options(), const_array_options())] + #[case(const_value_array_options(), const_value_array_options())] // Opaque constants don't get resolved into concrete types when running miri, // as the `typetag` machinery is not available. #[cfg_attr(miri, ignore)] diff --git a/hugr-core/src/std_extensions.rs b/hugr-core/src/std_extensions.rs index 7892e8fec..cf582f8a1 100644 --- a/hugr-core/src/std_extensions.rs +++ b/hugr-core/src/std_extensions.rs @@ -21,6 +21,7 @@ pub fn std_reg() -> ExtensionRegistry { collections::array::EXTENSION.to_owned(), collections::list::EXTENSION.to_owned(), collections::static_array::EXTENSION.to_owned(), + collections::value_array::EXTENSION.to_owned(), logic::EXTENSION.to_owned(), ptr::EXTENSION.to_owned(), ]); diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 13f5c007e..efd53c805 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -3,3 +3,4 @@ pub mod array; pub mod list; pub mod static_array; +pub mod value_array; diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index 2e7ee5b75..177d84a1e 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -1,158 +1,94 @@ //! Fixed-length array type and operations extension. 
+mod array_clone; +mod array_conversion; +mod array_discard; +mod array_kind; mod array_op; mod array_repeat; mod array_scan; +mod array_value; pub mod op_builder; use std::sync::Arc; -use itertools::Itertools as _; +use delegate::delegate; use lazy_static::lazy_static; -use serde::{Deserialize, Serialize}; -use std::hash::{Hash, Hasher}; - -use crate::extension::resolution::{ - resolve_type_extensions, resolve_value_extensions, ExtensionResolutionError, - WeakExtensionRegistry, -}; -use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; + +use crate::builder::{BuildError, Dataflow}; +use crate::extension::resolution::{ExtensionResolutionError, WeakExtensionRegistry}; +use crate::extension::simple_op::{HasConcrete, MakeOpDef, MakeRegisteredOp}; use crate::extension::{ExtensionId, SignatureError, TypeDef, TypeDefBound}; -use crate::ops::constant::{maybe_hash_values, CustomConst, TryHash, ValueName}; -use crate::ops::{ExtensionOp, OpName, Value}; +use crate::ops::constant::{CustomConst, ValueName}; +use crate::ops::{ExtensionOp, OpName}; use crate::types::type_param::{TypeArg, TypeParam}; -use crate::types::{CustomCheckFailure, CustomType, Type, TypeBound, TypeName}; -use crate::Extension; +use crate::types::{CustomCheckFailure, Type, TypeBound, TypeName}; +use crate::{Extension, Wire}; + +pub use array_clone::{GenericArrayClone, GenericArrayCloneDef, ARRAY_CLONE_OP_ID}; +pub use array_conversion::{Direction, GenericArrayConvert, GenericArrayConvertDef, FROM, INTO}; +pub use array_discard::{GenericArrayDiscard, GenericArrayDiscardDef, ARRAY_DISCARD_OP_ID}; +pub use array_kind::ArrayKind; +pub use array_op::{GenericArrayOp, GenericArrayOpDef}; +pub use array_repeat::{GenericArrayRepeat, GenericArrayRepeatDef, ARRAY_REPEAT_OP_ID}; +pub use array_scan::{GenericArrayScan, GenericArrayScanDef, ARRAY_SCAN_OP_ID}; +pub use array_value::GenericArrayValue; -pub use array_op::{ArrayOp, ArrayOpDef, ArrayOpDefIter}; -pub use array_repeat::{ArrayRepeat, ArrayRepeatDef, ARRAY_REPEAT_OP_ID}; -pub use array_scan::{ArrayScan, ArrayScanDef, ARRAY_SCAN_OP_ID}; -pub use op_builder::ArrayOpBuilder; +use op_builder::GenericArrayOpBuilder; /// Reported unique name of the array type. pub const ARRAY_TYPENAME: TypeName = TypeName::new_inline("array"); +/// Reported unique name of the array value. +pub const ARRAY_VALUENAME: TypeName = TypeName::new_inline("array"); /// Reported unique name of the extension pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.array"); /// Extension version. pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -/// Statically sized array of values, all of the same type. -pub struct ArrayValue { - values: Vec, - typ: Type, -} - -impl ArrayValue { - /// Name of the constructor for creating constant arrays. - pub(crate) const CTR_NAME: &'static str = "collections.array.const"; - - /// Create a new [CustomConst] for an array of values of type `typ`. - /// That all values are of type `typ` is not checked here. - pub fn new(typ: Type, contents: impl IntoIterator) -> Self { - Self { - values: contents.into_iter().collect_vec(), - typ, - } - } - - /// Create a new [CustomConst] for an empty array of values of type `typ`. - pub fn new_empty(typ: Type) -> Self { - Self { - values: vec![], - typ, - } - } +/// A linear, fixed-length collection of values. +/// +/// Arrays are linear, even if their elements are copyable. 
+#[derive(Clone, Copy, Debug, derive_more::Display, Eq, PartialEq, Default)] +pub struct Array; - /// Returns the type of the `[ArrayValue]` as a `[CustomType]`.` - pub fn custom_type(&self) -> CustomType { - array_custom_type(self.values.len() as u64, self.typ.clone()) - } +impl ArrayKind for Array { + const EXTENSION_ID: ExtensionId = EXTENSION_ID; + const TYPE_NAME: TypeName = ARRAY_TYPENAME; + const VALUE_NAME: ValueName = ARRAY_VALUENAME; - /// Returns the type of values inside the `[ArrayValue]`. - pub fn get_element_type(&self) -> &Type { - &self.typ + fn extension() -> &'static Arc { + &EXTENSION } - /// Returns the values contained inside the `[ArrayValue]`. - pub fn get_contents(&self) -> &[Value] { - &self.values + fn type_def() -> &'static TypeDef { + EXTENSION.get_type(&ARRAY_TYPENAME).unwrap() } } -impl TryHash for ArrayValue { - fn try_hash(&self, mut st: &mut dyn Hasher) -> bool { - maybe_hash_values(&self.values, &mut st) && { - self.typ.hash(&mut st); - true - } - } -} - -#[typetag::serde] -impl CustomConst for ArrayValue { - fn name(&self) -> ValueName { - ValueName::new_inline("array") - } - - fn get_type(&self) -> Type { - self.custom_type().into() - } - - fn validate(&self) -> Result<(), CustomCheckFailure> { - let typ = self.custom_type(); - - EXTENSION - .get_type(&ARRAY_TYPENAME) - .unwrap() - .check_custom(&typ) - .map_err(|_| { - CustomCheckFailure::Message(format!( - "Custom typ {typ} is not a valid instantiation of array." - )) - })?; - - // constant can only hold classic type. - let ty = match typ.args() { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] - if *n as usize == self.values.len() => - { - ty - } - _ => { - return Err(CustomCheckFailure::Message(format!( - "Invalid array type arguments: {:?}", - typ.args() - ))) - } - }; - - // check all values are instances of the element type - for v in &self.values { - if v.get_type() != *ty { - return Err(CustomCheckFailure::Message(format!( - "Array element {v:?} is not of expected type {ty}" - ))); - } - } - - Ok(()) - } - - fn equal_consts(&self, other: &dyn CustomConst) -> bool { - crate::ops::constant::downcast_equal_consts(self, other) - } - - fn update_extensions( - &mut self, - extensions: &WeakExtensionRegistry, - ) -> Result<(), ExtensionResolutionError> { - for val in &mut self.values { - resolve_value_extensions(val, extensions)?; - } - resolve_type_extensions(&mut self.typ, extensions) - } -} +/// Array operation definitions. +pub type ArrayOpDef = GenericArrayOpDef; +/// Array clone operation definition. +pub type ArrayCloneDef = GenericArrayCloneDef; +/// Array discard operation definition. +pub type ArrayDiscardDef = GenericArrayDiscardDef; +/// Array repeat operation definition. +pub type ArrayRepeatDef = GenericArrayRepeatDef; +/// Array scan operation definition. +pub type ArrayScanDef = GenericArrayScanDef; + +/// Array operations. +pub type ArrayOp = GenericArrayOp; +/// The array clone operation. +pub type ArrayClone = GenericArrayClone; +/// The array discard operation. +pub type ArrayDiscard = GenericArrayDiscard; +/// The array repeat operation. +pub type ArrayRepeat = GenericArrayRepeat; +/// The array scan operation. +pub type ArrayScan = GenericArrayScan; + +/// An array extension value. +pub type ArrayValue = GenericArrayValue; lazy_static! { /// Extension for array operations. @@ -162,22 +98,49 @@ lazy_static! 
{ ARRAY_TYPENAME, vec![ TypeParam::max_nat(), TypeBound::Any.into()], "Fixed-length array".into(), - TypeDefBound::from_params(vec![1] ), + // Default array is linear, even if the elements are copyable + TypeDefBound::any(), extension_ref, ) .unwrap(); - array_op::ArrayOpDef::load_all_ops(extension, extension_ref).unwrap(); - array_repeat::ArrayRepeatDef.add_to_extension(extension, extension_ref).unwrap(); - array_scan::ArrayScanDef.add_to_extension(extension, extension_ref).unwrap(); + ArrayOpDef::load_all_ops(extension, extension_ref).unwrap(); + ArrayCloneDef::new().add_to_extension(extension, extension_ref).unwrap(); + ArrayDiscardDef::new().add_to_extension(extension, extension_ref).unwrap(); + ArrayRepeatDef::new().add_to_extension(extension, extension_ref).unwrap(); + ArrayScanDef::new().add_to_extension(extension, extension_ref).unwrap(); }) }; } +impl ArrayValue { + /// Name of the constructor for creating constant arrays. + pub(crate) const CTR_NAME: &'static str = "collections.array.const"; +} + +#[typetag::serde(name = "ArrayValue")] +impl CustomConst for ArrayValue { + delegate! { + to self { + fn name(&self) -> ValueName; + fn validate(&self) -> Result<(), CustomCheckFailure>; + fn update_extensions( + &mut self, + extensions: &WeakExtensionRegistry, + ) -> Result<(), ExtensionResolutionError>; + fn get_type(&self) -> Type; + } + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::ops::constant::downcast_equal_consts(self, other) + } +} + /// Gets the [TypeDef] for arrays. Note that instantiations are more easily /// created via [array_type] and [array_type_parametric] pub fn array_type_def() -> &'static TypeDef { - EXTENSION.get_type(&ARRAY_TYPENAME).unwrap() + Array::type_def() } /// Instantiate a new array type given a size argument and element type. @@ -185,7 +148,7 @@ pub fn array_type_def() -> &'static TypeDef { /// This method is equivalent to [`array_type_parametric`], but uses concrete /// arguments types to ensure no errors are possible. pub fn array_type(size: u64, element_ty: Type) -> Type { - array_custom_type(size, element_ty).into() + Array::ty(size, element_ty) } /// Instantiate a new array type given the size and element type parameters. @@ -195,28 +158,7 @@ pub fn array_type_parametric( size: impl Into, element_ty: impl Into, ) -> Result { - instantiate_array(array_type_def(), size, element_ty) -} - -fn array_custom_type(size: impl Into, element_ty: impl Into) -> CustomType { - instantiate_array_custom(array_type_def(), size, element_ty) - .expect("array parameters are valid") -} - -fn instantiate_array_custom( - array_def: &TypeDef, - size: impl Into, - element_ty: impl Into, -) -> Result { - array_def.instantiate(vec![size.into(), element_ty.into()]) -} - -fn instantiate_array( - array_def: &TypeDef, - size: impl Into, - element_ty: impl Into, -) -> Result { - instantiate_array_custom(array_def, size, element_ty).map(Into::into) + Array::ty_parametric(size, element_ty) } /// Name of the operation in the prelude for creating new arrays. @@ -224,18 +166,246 @@ pub const NEW_ARRAY_OP_ID: OpName = OpName::new_inline("new_array"); /// Initialize a new array op of element type `element_ty` of length `size` pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp { - let op = array_op::ArrayOpDef::new_array.to_concrete(element_ty, size); + let op = ArrayOpDef::new_array.to_concrete(element_ty, size); op.to_extension_op().unwrap() } +/// Trait for building array operations in a dataflow graph. 
+pub trait ArrayOpBuilder: GenericArrayOpBuilder { + /// Adds a new array operation to the dataflow graph and return the wire + /// representing the new array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `values` - An iterator over the values to initialize the array with. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// The wire representing the new array. + fn add_new_array( + &mut self, + elem_ty: Type, + values: impl IntoIterator, + ) -> Result { + self.add_new_generic_array::(elem_ty, values) + } + + /// Adds an array clone operation to the dataflow graph and return the wires + /// representing the originala and cloned array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// The wires representing the original and cloned array. + fn add_array_clone( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result<(Wire, Wire), BuildError> { + self.add_generic_array_clone::(elem_ty, size, input) + } + + /// Adds an array discard operation to the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// If building the operation fails. + fn add_array_discard( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result<(), BuildError> { + self.add_generic_array_discard::(elem_ty, size, input) + } + + /// Adds an array get operation to the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to get. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// * The wire representing the value at the specified index in the array + /// * The wire representing the array + fn add_array_get( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + ) -> Result<(Wire, Wire), BuildError> { + self.add_generic_array_get::(elem_ty, size, input, index) + } + + /// Adds an array set operation to the dataflow graph. + /// + /// This operation sets the value at a specified index in the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to set. + /// * `value` - The wire representing the value to set at the specified index. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the updated array after the set operation. + fn add_array_set( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + value: Wire, + ) -> Result { + self.add_generic_array_set::(elem_ty, size, input, index, value) + } + + /// Adds an array swap operation to the dataflow graph. + /// + /// This operation swaps the values at two specified indices in the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. 
+ /// * `input` - The wire representing the array. + /// * `index1` - The wire representing the first index to swap. + /// * `index2` - The wire representing the second index to swap. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the updated array after the swap operation. + fn add_array_swap( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index1: Wire, + index2: Wire, + ) -> Result { + let op = GenericArrayOpDef::::swap.instantiate(&[size.into(), elem_ty.into()])?; + let [out] = self + .add_dataflow_op(op, vec![input, index1, index2])? + .outputs_arr(); + Ok(out) + } + + /// Adds an array pop-left operation to the dataflow graph. + /// + /// This operation removes the leftmost element from the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the Option> + fn add_array_pop_left( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result { + self.add_generic_array_pop_left::(elem_ty, size, input) + } + + /// Adds an array pop-right operation to the dataflow graph. + /// + /// This operation removes the rightmost element from the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the Option> + fn add_array_pop_right( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result { + self.add_generic_array_pop_right::(elem_ty, size, input) + } + + /// Adds an operation to discard an empty array from the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + fn add_array_discard_empty(&mut self, elem_ty: Type, input: Wire) -> Result<(), BuildError> { + self.add_generic_array_discard_empty::(elem_ty, input) + } +} + +impl ArrayOpBuilder for D {} + #[cfg(test)] mod test { use crate::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}; - use crate::extension::prelude::{qb_t, usize_t, ConstUsize}; - use crate::ops::constant::CustomConst; - use crate::std_extensions::arithmetic::float_types::ConstF64; + use crate::extension::prelude::qb_t; - use super::{array_type, new_array_op, ArrayValue}; + use super::{array_type, new_array_op}; #[test] /// Test building a HUGR involving a new_array operation. 
@@ -251,20 +421,4 @@ mod test { b.finish_hugr_with_outputs(out.outputs()).unwrap(); } - - #[test] - fn test_array_value() { - let array_value = ArrayValue { - values: vec![ConstUsize::new(3).into()], - typ: usize_t(), - }; - - array_value.validate().unwrap(); - - let wrong_array_value = ArrayValue { - values: vec![ConstF64::new(1.2).into()], - typ: usize_t(), - }; - assert!(wrong_array_value.validate().is_err()); - } } diff --git a/hugr-core/src/std_extensions/collections/array/array_clone.rs b/hugr-core/src/std_extensions/collections/array/array_clone.rs new file mode 100644 index 000000000..c532da199 --- /dev/null +++ b/hugr-core/src/std_extensions/collections/array/array_clone.rs @@ -0,0 +1,234 @@ +//! Definition of the array clone operation. + +use std::marker::PhantomData; +use std::str::FromStr; +use std::sync::{Arc, Weak}; + +use crate::extension::simple_op::{ + HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, +}; +use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; +use crate::ops::{ExtensionOp, NamedOp, OpName}; +use crate::types::type_param::{TypeArg, TypeParam}; +use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound}; +use crate::Extension; + +use super::array_kind::ArrayKind; + +/// Name of the operation to clone an array +pub const ARRAY_CLONE_OP_ID: OpName = OpName::new_inline("clone"); + +/// Definition of the array clone operation. Generic over the concrete array implementation. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub struct GenericArrayCloneDef(PhantomData); + +impl GenericArrayCloneDef { + /// Creates a new clone operation definition. + pub fn new() -> Self { + GenericArrayCloneDef(PhantomData) + } +} + +impl Default for GenericArrayCloneDef { + fn default() -> Self { + Self::new() + } +} + +impl NamedOp for GenericArrayCloneDef { + fn name(&self) -> OpName { + ARRAY_CLONE_OP_ID + } +} + +impl FromStr for GenericArrayCloneDef { + type Err = (); + + fn from_str(s: &str) -> Result { + if s == ARRAY_CLONE_OP_ID { + Ok(GenericArrayCloneDef::new()) + } else { + Err(()) + } + } +} + +impl GenericArrayCloneDef { + /// To avoid recursion when defining the extension, take the type definition as an argument. + fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { + let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; + let size = TypeArg::new_var_use(0, TypeParam::max_nat()); + let element_ty = Type::new_var_use(1, TypeBound::Copyable); + let array_ty = AK::instantiate_ty(array_def, size, element_ty) + .expect("Array type instantiation failed"); + PolyFuncTypeRV::new( + params, + FuncValueType::new(array_ty.clone(), vec![array_ty; 2]), + ) + .into() + } +} + +impl MakeOpDef for GenericArrayCloneDef { + fn from_def(op_def: &OpDef) -> Result + where + Self: Sized, + { + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) + } + + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { + self.signature_from_def(AK::type_def()) + } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(AK::extension()) + } + + fn extension(&self) -> ExtensionId { + AK::EXTENSION_ID + } + + fn description(&self) -> String { + "Clones an array with copyable elements".into() + } + + /// Add an operation implemented as a [MakeOpDef], which can provide the data + /// required to define an [OpDef], to an extension. 
+ // + // This method is re-defined here since we need to pass the array type def while + // computing the signature, to avoid recursive loops initializing the extension. + fn add_to_extension( + &self, + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), crate::extension::ExtensionBuildError> { + let sig = self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap()); + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; + self.post_opdef(def); + Ok(()) + } +} + +/// Definition of the array clone op. Generic over the concrete array implementation. +#[derive(Clone, Debug, PartialEq)] +pub struct GenericArrayClone { + /// The element type of the array. + pub elem_ty: Type, + /// Size of the array. + pub size: u64, + _kind: PhantomData, +} + +impl GenericArrayClone { + /// Creates a new array clone op. + /// + /// # Errors + /// + /// If the provided element type is not copyable. + pub fn new(elem_ty: Type, size: u64) -> Result { + elem_ty + .copyable() + .then_some(GenericArrayClone { + elem_ty, + size, + _kind: PhantomData, + }) + .ok_or(SignatureError::InvalidTypeArgs.into()) + } +} + +impl NamedOp for GenericArrayClone { + fn name(&self) -> OpName { + ARRAY_CLONE_OP_ID + } +} + +impl MakeExtensionOp for GenericArrayClone { + fn from_extension_op(ext_op: &ExtensionOp) -> Result + where + Self: Sized, + { + let def = GenericArrayCloneDef::::from_def(ext_op.def())?; + def.instantiate(ext_op.args()) + } + + fn type_args(&self) -> Vec { + vec![ + TypeArg::BoundedNat { n: self.size }, + self.elem_ty.clone().into(), + ] + } +} + +impl MakeRegisteredOp for GenericArrayClone { + fn extension_id(&self) -> ExtensionId { + AK::EXTENSION_ID + } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(AK::extension()) + } +} + +impl HasDef for GenericArrayClone { + type Def = GenericArrayCloneDef; +} + +impl HasConcrete for GenericArrayCloneDef { + type Concrete = GenericArrayClone; + + fn instantiate(&self, type_args: &[TypeArg]) -> Result { + match type_args { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] if ty.copyable() => { + Ok(GenericArrayClone::new(ty.clone(), *n).unwrap()) + } + _ => Err(SignatureError::InvalidTypeArgs.into()), + } + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use crate::extension::prelude::bool_t; + use crate::std_extensions::collections::array::Array; + use crate::{ + extension::prelude::qb_t, + ops::{OpTrait, OpType}, + }; + + use super::*; + + #[rstest] + #[case(Array)] + fn test_clone_def(#[case] _kind: AK) { + let op = GenericArrayClone::::new(bool_t(), 2).unwrap(); + let optype: OpType = op.clone().into(); + let new_op: GenericArrayClone = optype.cast().unwrap(); + assert_eq!(new_op, op); + + assert_eq!( + GenericArrayClone::::new(qb_t(), 2), + Err(OpLoadError::InvalidArgs(SignatureError::InvalidTypeArgs)) + ); + } + + #[rstest] + #[case(Array)] + fn test_clone(#[case] _kind: AK) { + let size = 2; + let element_ty = bool_t(); + let op = GenericArrayClone::::new(element_ty.clone(), size).unwrap(); + let optype: OpType = op.into(); + let sig = optype.dataflow_signature().unwrap(); + assert_eq!( + sig.io(), + ( + &vec![AK::ty(size, element_ty.clone())].into(), + &vec![AK::ty(size, element_ty.clone()); 2].into(), + ) + ); + } +} diff --git a/hugr-core/src/std_extensions/collections/array/array_conversion.rs b/hugr-core/src/std_extensions/collections/array/array_conversion.rs new file mode 100644 index 000000000..3cec4c3fe --- /dev/null +++ 
b/hugr-core/src/std_extensions/collections/array/array_conversion.rs @@ -0,0 +1,318 @@ +//! Operations for converting between the different array extensions + +use std::marker::PhantomData; +use std::str::FromStr; +use std::sync::{Arc, Weak}; + +use crate::extension::simple_op::{ + HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, +}; +use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; +use crate::ops::{ExtensionOp, NamedOp, OpName}; +use crate::types::type_param::{TypeArg, TypeParam}; +use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound}; +use crate::Extension; + +use super::array_kind::ArrayKind; + +/// Array conversion direction. +/// +/// Either the current array type [INTO] the other one, or the current array type [FROM] the +/// other one. +pub type Direction = bool; + +/// Array conversion direction to turn the current array type [INTO] the other one. +pub const INTO: Direction = true; + +/// Array conversion direction to obtain the current array type [FROM] the other one. +pub const FROM: Direction = false; + +/// Definition of array conversion operations. +/// +/// Generic over the concrete array implementation of the extension containing the operation, as +/// well as over another array implementation that should be converted between. Also generic over +/// the conversion [Direction]. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub struct GenericArrayConvertDef( + PhantomData, + PhantomData, +); + +impl + GenericArrayConvertDef +{ + /// Creates a new array conversion definition. + pub fn new() -> Self { + GenericArrayConvertDef(PhantomData, PhantomData) + } +} + +impl Default + for GenericArrayConvertDef +{ + fn default() -> Self { + Self::new() + } +} + +impl NamedOp + for GenericArrayConvertDef +{ + fn name(&self) -> OpName { + match DIR { + INTO => format!("to_{}", OtherAK::TYPE_NAME).into(), + FROM => format!("from_{}", OtherAK::TYPE_NAME).into(), + } + } +} + +impl FromStr + for GenericArrayConvertDef +{ + type Err = (); + + fn from_str(s: &str) -> Result { + let def = GenericArrayConvertDef::new(); + if s == def.name() { + Ok(def) + } else { + Err(()) + } + } +} + +impl + GenericArrayConvertDef +{ + /// To avoid recursion when defining the extension, take the type definition as an argument. 
+ fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { + let params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; + let size = TypeArg::new_var_use(0, TypeParam::max_nat()); + let element_ty = Type::new_var_use(1, TypeBound::Any); + + let this_ty = AK::instantiate_ty(array_def, size.clone(), element_ty.clone()) + .expect("Array type instantiation failed"); + let other_ty = + OtherAK::ty_parametric(size, element_ty).expect("Array type instantiation failed"); + + let sig = match DIR { + INTO => FuncValueType::new(this_ty, other_ty), + FROM => FuncValueType::new(other_ty, this_ty), + }; + PolyFuncTypeRV::new(params, sig).into() + } +} + +impl MakeOpDef + for GenericArrayConvertDef +{ + fn from_def(op_def: &OpDef) -> Result + where + Self: Sized, + { + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) + } + + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { + self.signature_from_def(AK::type_def()) + } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(AK::extension()) + } + + fn extension(&self) -> ExtensionId { + AK::EXTENSION_ID + } + + fn description(&self) -> String { + match DIR { + INTO => format!("Turns `{}` into `{}`", AK::TYPE_NAME, OtherAK::TYPE_NAME), + FROM => format!("Turns `{}` into `{}`", OtherAK::TYPE_NAME, AK::TYPE_NAME), + } + } + + /// Add an operation implemented as a [MakeOpDef], which can provide the data + /// required to define an [OpDef], to an extension. + // + // This method is re-defined here since we need to pass the array type def while + // computing the signature, to avoid recursive loops initializing the extension. + fn add_to_extension( + &self, + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), crate::extension::ExtensionBuildError> { + let sig = self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap()); + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; + self.post_opdef(def); + Ok(()) + } +} + +/// Definition of the array conversion op. +/// +/// Generic over the concrete array implementation of the extension containing the operation, as +/// well as over another array implementation that should be converted between. Also generic over +/// the conversion [Direction]. +#[derive(Clone, Debug, PartialEq)] +pub struct GenericArrayConvert { + /// The element type of the array. + pub elem_ty: Type, + /// Size of the array. + pub size: u64, + _kind: PhantomData, + _other_kind: PhantomData, +} + +impl + GenericArrayConvert +{ + /// Creates a new array conversion op. 
+ pub fn new(elem_ty: Type, size: u64) -> Self { + GenericArrayConvert { + elem_ty, + size, + _kind: PhantomData, + _other_kind: PhantomData, + } + } +} + +impl NamedOp + for GenericArrayConvert +{ + fn name(&self) -> OpName { + match DIR { + INTO => format!("to_{}", OtherAK::TYPE_NAME).into(), + FROM => format!("from_{}", OtherAK::TYPE_NAME).into(), + } + } +} + +impl MakeExtensionOp + for GenericArrayConvert +{ + fn from_extension_op(ext_op: &ExtensionOp) -> Result + where + Self: Sized, + { + let def = GenericArrayConvertDef::::from_def(ext_op.def())?; + def.instantiate(ext_op.args()) + } + + fn type_args(&self) -> Vec { + vec![ + TypeArg::BoundedNat { n: self.size }, + self.elem_ty.clone().into(), + ] + } +} + +impl MakeRegisteredOp + for GenericArrayConvert +{ + fn extension_id(&self) -> ExtensionId { + AK::EXTENSION_ID + } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(AK::extension()) + } +} + +impl HasDef + for GenericArrayConvert +{ + type Def = GenericArrayConvertDef; +} + +impl HasConcrete + for GenericArrayConvertDef +{ + type Concrete = GenericArrayConvert; + + fn instantiate(&self, type_args: &[TypeArg]) -> Result { + match type_args { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] => { + Ok(GenericArrayConvert::new(ty.clone(), *n)) + } + _ => Err(SignatureError::InvalidTypeArgs.into()), + } + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use crate::extension::prelude::bool_t; + use crate::ops::{OpTrait, OpType}; + use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::value_array::ValueArray; + + use super::*; + + #[rstest] + #[case(ValueArray, Array)] + fn test_convert_from_def( + #[case] _kind: AK, + #[case] _other_kind: OtherAK, + ) { + let op = GenericArrayConvert::::new(bool_t(), 2); + let optype: OpType = op.clone().into(); + let new_op: GenericArrayConvert = optype.cast().unwrap(); + assert_eq!(new_op, op); + } + + #[rstest] + #[case(ValueArray, Array)] + fn test_convert_into_def( + #[case] _kind: AK, + #[case] _other_kind: OtherAK, + ) { + let op = GenericArrayConvert::::new(bool_t(), 2); + let optype: OpType = op.clone().into(); + let new_op: GenericArrayConvert = optype.cast().unwrap(); + assert_eq!(new_op, op); + } + + #[rstest] + #[case(ValueArray, Array)] + fn test_convert_from( + #[case] _kind: AK, + #[case] _other_kind: OtherAK, + ) { + let size = 2; + let element_ty = bool_t(); + let op = GenericArrayConvert::::new(element_ty.clone(), size); + let optype: OpType = op.into(); + let sig = optype.dataflow_signature().unwrap(); + assert_eq!( + sig.io(), + ( + &vec![OtherAK::ty(size, element_ty.clone())].into(), + &vec![AK::ty(size, element_ty.clone())].into(), + ) + ); + } + + #[rstest] + #[case(ValueArray, Array)] + fn test_convert_into( + #[case] _kind: AK, + #[case] _other_kind: OtherAK, + ) { + let size = 2; + let element_ty = bool_t(); + let op = GenericArrayConvert::::new(element_ty.clone(), size); + let optype: OpType = op.into(); + let sig = optype.dataflow_signature().unwrap(); + assert_eq!( + sig.io(), + ( + &vec![AK::ty(size, element_ty.clone())].into(), + &vec![OtherAK::ty(size, element_ty.clone())].into(), + ) + ); + } +} diff --git a/hugr-core/src/std_extensions/collections/array/array_discard.rs b/hugr-core/src/std_extensions/collections/array/array_discard.rs new file mode 100644 index 000000000..7cdfc9b4f --- /dev/null +++ b/hugr-core/src/std_extensions/collections/array/array_discard.rs @@ -0,0 +1,220 @@ +//! Definition of the array discard operation. 
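Before the discard operation below, a short illustrative sketch of the conversion ops just defined (not part of the patch). It assumes the `<AK, DIR, OtherAK>` parameter order suggested by the tests above; unlike clone and discard, conversion places no copyability bound on the element type, so construction is infallible.

    // value_array<2, bool>  ->  array<2, bool>; the op is named "to_<OtherAK::TYPE_NAME>"
    let to_array = GenericArrayConvert::<ValueArray, INTO, Array>::new(bool_t(), 2);
    // array<2, bool>  ->  value_array<2, bool>; the op is named "from_<OtherAK::TYPE_NAME>"
    let from_array = GenericArrayConvert::<ValueArray, FROM, Array>::new(bool_t(), 2);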
+ +use std::marker::PhantomData; +use std::str::FromStr; +use std::sync::{Arc, Weak}; + +use crate::extension::simple_op::{ + HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, +}; +use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; +use crate::ops::{ExtensionOp, NamedOp, OpName}; +use crate::types::type_param::{TypeArg, TypeParam}; +use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound}; +use crate::{type_row, Extension}; + +use super::array_kind::ArrayKind; + +/// Name of the operation to discard an array +pub const ARRAY_DISCARD_OP_ID: OpName = OpName::new_inline("discard"); + +/// Definition of the array discard op. Generic over the concrete array implementation. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub struct GenericArrayDiscardDef(PhantomData); + +impl GenericArrayDiscardDef { + /// Creates a new array discard operation definition. + pub fn new() -> Self { + GenericArrayDiscardDef(PhantomData) + } +} + +impl Default for GenericArrayDiscardDef { + fn default() -> Self { + Self::new() + } +} + +impl NamedOp for GenericArrayDiscardDef { + fn name(&self) -> OpName { + ARRAY_DISCARD_OP_ID + } +} + +impl FromStr for GenericArrayDiscardDef { + type Err = (); + + fn from_str(s: &str) -> Result { + if s == ARRAY_DISCARD_OP_ID { + Ok(GenericArrayDiscardDef::new()) + } else { + Err(()) + } + } +} + +impl GenericArrayDiscardDef { + /// To avoid recursion when defining the extension, take the type definition as an argument. + fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { + let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; + let size = TypeArg::new_var_use(0, TypeParam::max_nat()); + let element_ty = Type::new_var_use(1, TypeBound::Copyable); + let array_ty = AK::instantiate_ty(array_def, size, element_ty) + .expect("Array type instantiation failed"); + PolyFuncTypeRV::new(params, FuncValueType::new(array_ty, type_row![])).into() + } +} + +impl MakeOpDef for GenericArrayDiscardDef { + fn from_def(op_def: &OpDef) -> Result + where + Self: Sized, + { + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) + } + + fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { + self.signature_from_def(AK::type_def()) + } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(AK::extension()) + } + + fn extension(&self) -> ExtensionId { + AK::EXTENSION_ID + } + + fn description(&self) -> String { + "Discards an array with copyable elements".into() + } + + /// Add an operation implemented as a [MakeOpDef], which can provide the data + /// required to define an [OpDef], to an extension. + // + // This method is re-defined here since we need to pass the array type def while + // computing the signature, to avoid recursive loops initializing the extension. + fn add_to_extension( + &self, + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), crate::extension::ExtensionBuildError> { + let sig = self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap()); + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; + self.post_opdef(def); + Ok(()) + } +} + +/// Definition of the array discard op. Generic over the concrete array implementation. +#[derive(Clone, Debug, PartialEq)] +pub struct GenericArrayDiscard { + /// The element type of the array. + pub elem_ty: Type, + /// Size of the array. 
+ pub size: u64, + _kind: PhantomData, +} + +impl GenericArrayDiscard { + /// Creates a new array discard op. + pub fn new(elem_ty: Type, size: u64) -> Option { + elem_ty.copyable().then_some(GenericArrayDiscard { + elem_ty, + size, + _kind: PhantomData, + }) + } +} + +impl NamedOp for GenericArrayDiscard { + fn name(&self) -> OpName { + ARRAY_DISCARD_OP_ID + } +} + +impl MakeExtensionOp for GenericArrayDiscard { + fn from_extension_op(ext_op: &ExtensionOp) -> Result + where + Self: Sized, + { + let def = GenericArrayDiscardDef::::from_def(ext_op.def())?; + def.instantiate(ext_op.args()) + } + + fn type_args(&self) -> Vec { + vec![ + TypeArg::BoundedNat { n: self.size }, + self.elem_ty.clone().into(), + ] + } +} + +impl MakeRegisteredOp for GenericArrayDiscard { + fn extension_id(&self) -> ExtensionId { + AK::EXTENSION_ID + } + + fn extension_ref(&self) -> Weak { + Arc::downgrade(AK::extension()) + } +} + +impl HasDef for GenericArrayDiscard { + type Def = GenericArrayDiscardDef; +} + +impl HasConcrete for GenericArrayDiscardDef { + type Concrete = GenericArrayDiscard; + + fn instantiate(&self, type_args: &[TypeArg]) -> Result { + match type_args { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] if ty.copyable() => { + Ok(GenericArrayDiscard::new(ty.clone(), *n).unwrap()) + } + _ => Err(SignatureError::InvalidTypeArgs.into()), + } + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + + use crate::extension::prelude::bool_t; + use crate::std_extensions::collections::array::Array; + use crate::{ + extension::prelude::qb_t, + ops::{OpTrait, OpType}, + }; + + use super::*; + + #[rstest] + #[case(Array)] + fn test_discard_def(#[case] _kind: AK) { + let op = GenericArrayDiscard::::new(bool_t(), 2).unwrap(); + let optype: OpType = op.clone().into(); + let new_op: GenericArrayDiscard = optype.cast().unwrap(); + assert_eq!(new_op, op); + + assert_eq!(GenericArrayDiscard::::new(qb_t(), 2), None); + } + + #[rstest] + #[case(Array)] + fn test_discard(#[case] _kind: AK) { + let size = 2; + let element_ty = bool_t(); + let op = GenericArrayDiscard::::new(element_ty.clone(), size).unwrap(); + let optype: OpType = op.into(); + let sig = optype.dataflow_signature().unwrap(); + assert_eq!( + sig.io(), + ( + &vec![AK::ty(size, element_ty.clone())].into(), + &vec![].into(), + ) + ); + } +} diff --git a/hugr-core/src/std_extensions/collections/array/array_kind.rs b/hugr-core/src/std_extensions/collections/array/array_kind.rs new file mode 100644 index 000000000..88c729d3c --- /dev/null +++ b/hugr-core/src/std_extensions/collections/array/array_kind.rs @@ -0,0 +1,119 @@ +use std::sync::Arc; + +use crate::std_extensions::collections::array::op_builder::GenericArrayOpBuilder; +use crate::{ + builder::{BuildError, Dataflow}, + extension::{ExtensionId, SignatureError, TypeDef}, + ops::constant::ValueName, + types::{CustomType, Type, TypeArg, TypeName}, + Extension, Wire, +}; + +/// Trait capturing a concrete array implementation in an extension. +/// +/// Array operations are generically defined over this trait so the different +/// array extensions can share parts of their implementation. See for example +/// [`GenericArrayOpDef`] or [`GenericArrayValue`] +/// +/// Currently the available kinds of array are [`Array`] (the default one) and +/// [`ValueArray`]. 
+///
+/// [`GenericArrayOpDef`]: super::GenericArrayOpDef
+/// [`GenericArrayValue`]: super::GenericArrayValue
+/// [`Array`]: super::Array
+/// [`ValueArray`]: crate::std_extensions::collections::value_array::ValueArray
+pub trait ArrayKind:
+    Clone
+    + Copy
+    + std::fmt::Debug
+    + std::fmt::Display
+    + Eq
+    + PartialEq
+    + Default
+    + Send
+    + Sync
+    + 'static
+{
+    /// Identifier of the extension containing the array.
+    const EXTENSION_ID: ExtensionId;
+
+    /// Name of the array type.
+    const TYPE_NAME: TypeName;
+
+    /// Name of the array value.
+    const VALUE_NAME: ValueName;
+
+    /// Returns the extension containing the array.
+    fn extension() -> &'static Arc<Extension>;
+
+    /// Returns the definition for the array type.
+    fn type_def() -> &'static TypeDef;
+
+    /// Instantiates an array [CustomType] from its definition given a size and
+    /// element type argument.
+    fn instantiate_custom_ty(
+        array_def: &TypeDef,
+        size: impl Into<TypeArg>,
+        element_ty: impl Into<TypeArg>,
+    ) -> Result<CustomType, SignatureError> {
+        array_def.instantiate(vec![size.into(), element_ty.into()])
+    }
+
+    /// Instantiates an array type from its definition given a size and element
+    /// type argument.
+    fn instantiate_ty(
+        array_def: &TypeDef,
+        size: impl Into<TypeArg>,
+        element_ty: impl Into<TypeArg>,
+    ) -> Result<Type, SignatureError> {
+        Self::instantiate_custom_ty(array_def, size, element_ty).map(Into::into)
+    }
+
+    /// Instantiates an array [CustomType] given a size and element type argument.
+    fn custom_ty(size: impl Into<TypeArg>, element_ty: impl Into<TypeArg>) -> CustomType {
+        Self::instantiate_custom_ty(Self::type_def(), size, element_ty)
+            .expect("array parameters are valid")
+    }
+
+    /// Instantiate a new array type given a size argument and element type.
+    ///
+    /// This method is equivalent to [`ArrayKind::ty_parametric`], but uses concrete
+    /// argument types to ensure no errors are possible.
+    fn ty(size: u64, element_ty: Type) -> Type {
+        Self::custom_ty(size, element_ty).into()
+    }
+
+    /// Instantiate a new array type given the size and element type parameters.
+    ///
+    /// This is a generic version of [`ArrayKind::ty`].
+    fn ty_parametric(
+        size: impl Into<TypeArg>,
+        element_ty: impl Into<TypeArg>,
+    ) -> Result<Type, SignatureError> {
+        Self::instantiate_ty(Self::type_def(), size, element_ty)
+    }
+
+    /// Adds an operation to a dataflow graph that clones an array of copyable values.
+    ///
+    /// The default implementation uses the array clone operation.
+    fn build_clone<D: Dataflow>(
+        builder: &mut D,
+        elem_ty: Type,
+        size: u64,
+        arr: Wire,
+    ) -> Result<(Wire, Wire), BuildError> {
+        builder.add_generic_array_clone::<Self>(elem_ty, size, arr)
+    }
+
+    /// Adds an operation to a dataflow graph that discards an array of copyable values.
+    ///
+    /// The default implementation uses the array discard operation.
+    fn build_discard<D: Dataflow>(
+        builder: &mut D,
+        elem_ty: Type,
+        size: u64,
+        arr: Wire,
+    ) -> Result<(), BuildError> {
+        builder.add_generic_array_discard::<Self>(elem_ty, size, arr)
+    }
+}
diff --git a/hugr-core/src/std_extensions/collections/array/array_op.rs b/hugr-core/src/std_extensions/collections/array/array_op.rs
index 197536032..e5b63f855 100644
--- a/hugr-core/src/std_extensions/collections/array/array_op.rs
+++ b/hugr-core/src/std_extensions/collections/array/array_op.rs
@@ -1,5 +1,6 @@
 //! Definitions of `ArrayOp` and `ArrayOpDef`.
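Before the `ArrayOp` changes below, an illustrative sketch (not part of the patch) of what the `ArrayKind` abstraction buys: code that only mentions `AK` works unchanged for both the default `Array` extension and `ValueArray`.

    // Builds the same types for whichever array implementation is chosen.
    fn two_arrays<AK: ArrayKind>(elem: Type) -> Vec<Type> {
        vec![AK::ty(2, elem.clone()), AK::ty(4, elem)]
    }
    let _std_arrays = two_arrays::<Array>(bool_t());
    let _value_arrays = two_arrays::<ValueArray>(bool_t());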
+use std::marker::PhantomData; use std::sync::{Arc, Weak}; use strum::{EnumIter, EnumString, IntoStaticStr}; @@ -12,25 +13,25 @@ use crate::extension::{ ExtensionId, OpDef, SignatureError, SignatureFromArgs, SignatureFunc, TypeDef, }; use crate::ops::{ExtensionOp, NamedOp, OpName}; -use crate::std_extensions::collections::array::instantiate_array; use crate::type_row; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncValueType, PolyFuncTypeRV, Type, TypeBound}; +use crate::utils::Never; use crate::Extension; -use super::{array_type, array_type_def, ARRAY_TYPENAME}; +use super::array_kind::ArrayKind; -/// Array operation definitions. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +/// Array operation definitions. Generic over the conrete array implementation. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, IntoStaticStr, EnumIter, EnumString)] #[allow(non_camel_case_types)] #[non_exhaustive] -pub enum ArrayOpDef { +pub enum GenericArrayOpDef { /// Makes a new array, given distinct inputs equal to its length: /// `new_array: (elemty)^SIZE -> array` /// where `SIZE` must be statically known (not a variable) new_array, /// Copies an element out of the array ([TypeBound::Copyable] elements only): - /// `get: array, index -> option` + /// `get: array, index -> option, array` get, /// Exchanges an element of the array with an external value: /// `set: array, index, elemty -> either(elemty, array | elemty, array)` @@ -53,26 +54,30 @@ pub enum ArrayOpDef { /// Allows discarding a 0-element array of linear type. /// `discard_empty: array<0, elemty> -> ` (no outputs) discard_empty, + /// Not an actual operation definition, but an unhabitable variant that + /// references `AK` to ensure that the type parameter is used. + #[strum(disabled)] + _phantom(PhantomData, Never), } /// Static parameters for array operations. Includes array size. Type is part of the type scheme. const STATIC_SIZE_PARAM: &[TypeParam; 1] = &[TypeParam::max_nat()]; -impl SignatureFromArgs for ArrayOpDef { +impl SignatureFromArgs for GenericArrayOpDef { fn compute_signature(&self, arg_values: &[TypeArg]) -> Result { let [TypeArg::BoundedNat { n }] = *arg_values else { return Err(SignatureError::InvalidTypeArgs); }; let elem_ty_var = Type::new_var_use(0, TypeBound::Any); - let array_ty = array_type(n, elem_ty_var.clone()); + let array_ty = AK::ty(n, elem_ty_var.clone()); let params = vec![TypeBound::Any.into()]; let poly_func_ty = match self { - ArrayOpDef::new_array => PolyFuncTypeRV::new( + GenericArrayOpDef::new_array => PolyFuncTypeRV::new( params, FuncValueType::new(vec![elem_ty_var.clone(); n as usize], array_ty), ), - ArrayOpDef::pop_left | ArrayOpDef::pop_right => { - let popped_array_ty = array_type(n - 1, elem_ty_var.clone()); + GenericArrayOpDef::pop_left | GenericArrayOpDef::pop_right => { + let popped_array_ty = AK::ty(n - 1, elem_ty_var.clone()); PolyFuncTypeRV::new( params, FuncValueType::new( @@ -81,6 +86,7 @@ impl SignatureFromArgs for ArrayOpDef { ), ) } + GenericArrayOpDef::_phantom(_, never) => match *never {}, _ => unreachable!( "Operation {} should not need custom computation.", self.name() @@ -94,16 +100,16 @@ impl SignatureFromArgs for ArrayOpDef { } } -impl ArrayOpDef { +impl GenericArrayOpDef { /// Instantiate a new array operation with the given element type and array size. 
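An illustrative sketch of the `get` change noted above (not part of the patch): because the array itself cannot simply be copied, `get` now hands the array back alongside the optional element, as `test_get` below checks. It uses the `to_concrete` helper documented just above.

    let op = GenericArrayOpDef::<Array>::get.to_concrete(bool_t(), 2);
    let optype: OpType = op.into();
    let sig = optype.dataflow_signature().unwrap();
    // inputs:  [array<2, bool>, usize]
    // outputs: [option<bool>, array<2, bool>]  (the array is threaded back out)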
- pub fn to_concrete(self, elem_ty: Type, size: u64) -> ArrayOp { - if self == ArrayOpDef::discard_empty { + pub fn to_concrete(self, elem_ty: Type, size: u64) -> GenericArrayOp { + if self == GenericArrayOpDef::discard_empty { debug_assert_eq!( size, 0, "discard_empty should only be called on empty arrays" ); } - ArrayOp { + GenericArrayOp { def: self, elem_ty, size, @@ -116,7 +122,7 @@ impl ArrayOpDef { array_def: &TypeDef, _extension_ref: &Weak, ) -> SignatureFunc { - use ArrayOpDef::*; + use GenericArrayOpDef::*; if let new_array | pop_left | pop_right = self { // implements SignatureFromArgs // signature computed dynamically, so can rely on type definition in extension. @@ -124,7 +130,7 @@ impl ArrayOpDef { } else { let size_var = TypeArg::new_var_use(0, TypeParam::max_nat()); let elem_ty_var = Type::new_var_use(1, TypeBound::Any); - let array_ty = instantiate_array(array_def, size_var.clone(), elem_ty_var.clone()) + let array_ty = AK::instantiate_ty(array_def, size_var.clone(), elem_ty_var.clone()) .expect("Array type instantiation failed"); let standard_params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; @@ -137,12 +143,15 @@ impl ArrayOpDef { let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; let copy_elem_ty = Type::new_var_use(1, TypeBound::Copyable); let copy_array_ty = - instantiate_array(array_def, size_var, copy_elem_ty.clone()) + AK::instantiate_ty(array_def, size_var, copy_elem_ty.clone()) .expect("Array type instantiation failed"); let option_type: Type = option_type(copy_elem_ty).into(); PolyFuncTypeRV::new( params, - FuncValueType::new(vec![copy_array_ty, usize_t], option_type), + FuncValueType::new( + vec![copy_array_ty.clone(), usize_t], + vec![option_type, copy_array_ty], + ), ) } set => { @@ -166,11 +175,12 @@ impl ArrayOpDef { discard_empty => PolyFuncTypeRV::new( vec![TypeBound::Any.into()], FuncValueType::new( - instantiate_array(array_def, 0, Type::new_var_use(0, TypeBound::Any)) + AK::instantiate_ty(array_def, 0, Type::new_var_use(0, TypeBound::Any)) .expect("Array type instantiation failed"), type_row![], ), ), + _phantom(_, never) => match *never {}, new_array | pop_left | pop_right => unreachable!(), } .into() @@ -178,7 +188,7 @@ impl ArrayOpDef { } } -impl MakeOpDef for ArrayOpDef { +impl MakeOpDef for GenericArrayOpDef { fn from_def(op_def: &OpDef) -> Result where Self: Sized, @@ -187,26 +197,27 @@ impl MakeOpDef for ArrayOpDef { } fn init_signature(&self, extension_ref: &Weak) -> SignatureFunc { - self.signature_from_def(array_type_def(), extension_ref) + self.signature_from_def(AK::type_def(), extension_ref) } fn extension_ref(&self) -> Weak { - Arc::downgrade(&super::EXTENSION) + Arc::downgrade(AK::extension()) } fn extension(&self) -> ExtensionId { - super::EXTENSION_ID + AK::EXTENSION_ID } fn description(&self) -> String { match self { - ArrayOpDef::new_array => "Create a new array from elements", - ArrayOpDef::get => "Get an element from an array", - ArrayOpDef::set => "Set an element in an array", - ArrayOpDef::swap => "Swap two elements in an array", - ArrayOpDef::pop_left => "Pop an element from the left of an array", - ArrayOpDef::pop_right => "Pop an element from the right of an array", - ArrayOpDef::discard_empty => "Discard an empty array", + GenericArrayOpDef::new_array => "Create a new array from elements", + GenericArrayOpDef::get => "Get an element from an array", + GenericArrayOpDef::set => "Set an element in an array", + GenericArrayOpDef::swap => "Swap two elements in an array", + GenericArrayOpDef::pop_left => 
"Pop an element from the left of an array", + GenericArrayOpDef::pop_right => "Pop an element from the right of an array", + GenericArrayOpDef::discard_empty => "Discard an empty array", + GenericArrayOpDef::_phantom(_, never) => match *never {}, } .into() } @@ -222,7 +233,7 @@ impl MakeOpDef for ArrayOpDef { extension_ref: &Weak, ) -> Result<(), crate::extension::ExtensionBuildError> { let sig = - self.signature_from_def(extension.get_type(&ARRAY_TYPENAME).unwrap(), extension_ref); + self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap(), extension_ref); let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); @@ -232,33 +243,33 @@ impl MakeOpDef for ArrayOpDef { } #[derive(Clone, Debug, PartialEq)] -/// Concrete array operation. -pub struct ArrayOp { +/// Concrete array operation. Generic over the actual array implemenation. +pub struct GenericArrayOp { /// The operation definition. - pub def: ArrayOpDef, + pub def: GenericArrayOpDef, /// The element type of the array. pub elem_ty: Type, /// The size of the array. pub size: u64, } -impl NamedOp for ArrayOp { +impl NamedOp for GenericArrayOp { fn name(&self) -> OpName { self.def.name() } } -impl MakeExtensionOp for ArrayOp { +impl MakeExtensionOp for GenericArrayOp { fn from_extension_op(ext_op: &ExtensionOp) -> Result where Self: Sized, { - let def = ArrayOpDef::from_def(ext_op.def())?; + let def = GenericArrayOpDef::from_def(ext_op.def())?; def.instantiate(ext_op.args()) } fn type_args(&self) -> Vec { - use ArrayOpDef::*; + use GenericArrayOpDef::*; let ty_arg = TypeArg::Type { ty: self.elem_ty.clone(), }; @@ -273,30 +284,31 @@ impl MakeExtensionOp for ArrayOp { new_array | pop_left | pop_right | get | set | swap => { vec![TypeArg::BoundedNat { n: self.size }, ty_arg] } + _phantom(_, never) => match never {}, } } } -impl MakeRegisteredOp for ArrayOp { +impl MakeRegisteredOp for GenericArrayOp { fn extension_id(&self) -> ExtensionId { - super::EXTENSION_ID + AK::EXTENSION_ID } fn extension_ref(&self) -> Weak { - Arc::downgrade(&super::EXTENSION) + Arc::downgrade(AK::extension()) } } -impl HasDef for ArrayOp { - type Def = ArrayOpDef; +impl HasDef for GenericArrayOp { + type Def = GenericArrayOpDef; } -impl HasConcrete for ArrayOpDef { - type Concrete = ArrayOp; +impl HasConcrete for GenericArrayOpDef { + type Concrete = GenericArrayOp; fn instantiate(&self, type_args: &[TypeArg]) -> Result { let (ty, size) = match (self, type_args) { - (ArrayOpDef::discard_empty, [TypeArg::Type { ty }]) => (ty.clone(), 0), + (GenericArrayOpDef::discard_empty, [TypeArg::Type { ty }]) => (ty.clone(), 0), (_, [TypeArg::BoundedNat { n }, TypeArg::Type { ty }]) => (ty.clone(), *n), _ => return Err(SignatureError::InvalidTypeArgs.into()), }; @@ -307,11 +319,13 @@ impl HasConcrete for ArrayOpDef { #[cfg(test)] mod tests { + use rstest::rstest; use strum::IntoEnumIterator; use crate::extension::prelude::usize_t; use crate::std_extensions::arithmetic::float_types::float64_type; - use crate::std_extensions::collections::array::new_array_op; + use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::value_array::ValueArray; use crate::{ builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}, extension::prelude::{bool_t, qb_t}, @@ -320,46 +334,51 @@ mod tests { use super::*; - #[test] - fn test_array_ops() { - for def in ArrayOpDef::iter() { - let ty = if def == ArrayOpDef::get { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_array_ops(#[case] _kind: 
AK) { + for def in GenericArrayOpDef::::iter() { + let ty = if def == GenericArrayOpDef::get { bool_t() } else { qb_t() }; - let size = if def == ArrayOpDef::discard_empty { + let size = if def == GenericArrayOpDef::discard_empty { 0 } else { 2 }; let op = def.to_concrete(ty, size); let optype: OpType = op.clone().into(); - let new_op: ArrayOp = optype.cast().unwrap(); + let new_op: GenericArrayOp = optype.cast().unwrap(); assert_eq!(new_op, op); } } - #[test] + #[rstest] + #[case(Array)] + #[case(ValueArray)] /// Test building a HUGR involving a new_array operation. - fn test_new_array() { - let mut b = - DFGBuilder::new(inout_sig(vec![qb_t(), qb_t()], array_type(2, qb_t()))).unwrap(); + fn test_new_array(#[case] _kind: AK) { + let mut b = DFGBuilder::new(inout_sig(vec![qb_t(), qb_t()], AK::ty(2, qb_t()))).unwrap(); let [q1, q2] = b.input_wires_arr(); - let op = new_array_op(qb_t(), 2); + let op = GenericArrayOpDef::::new_array.to_concrete(qb_t(), 2); let out = b.add_dataflow_op(op, [q1, q2]).unwrap(); b.finish_hugr_with_outputs(out.outputs()).unwrap(); } - #[test] - fn test_get() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_get(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); - let op = ArrayOpDef::get.to_concrete(element_ty.clone(), size); + let op = GenericArrayOpDef::::get.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); @@ -368,22 +387,28 @@ mod tests { assert_eq!( sig.io(), ( - &vec![array_type(size, element_ty.clone()), usize_t()].into(), - &vec![option_type(element_ty.clone()).into()].into() + &vec![AK::ty(size, element_ty.clone()), usize_t()].into(), + &vec![ + option_type(element_ty.clone()).into(), + AK::ty(size, element_ty.clone()) + ] + .into() ) ); } - #[test] - fn test_set() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_set(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); - let op = ArrayOpDef::set.to_concrete(element_ty.clone(), size); + let op = GenericArrayOpDef::::set.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); - let array_ty = array_type(size, element_ty.clone()); + let array_ty = AK::ty(size, element_ty.clone()); let result_row = vec![element_ty.clone(), array_ty.clone()]; assert_eq!( sig.io(), @@ -394,16 +419,18 @@ mod tests { ); } - #[test] - fn test_swap() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_swap(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); - let op = ArrayOpDef::swap.to_concrete(element_ty.clone(), size); + let op = GenericArrayOpDef::::swap.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); - let array_ty = array_type(size, element_ty.clone()); + let array_ty = AK::ty(size, element_ty.clone()); assert_eq!( sig.io(), ( @@ -413,11 +440,18 @@ mod tests { ); } - #[test] - fn test_pops() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_pops(#[case] _kind: AK) { let size = 2; let element_ty = bool_t(); - for op in [ArrayOpDef::pop_left, ArrayOpDef::pop_right].iter() { + for op in [ + GenericArrayOpDef::::pop_left, + GenericArrayOpDef::::pop_right, + ] + .iter() + { let op = op.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); @@ -426,10 +460,10 @@ mod tests { assert_eq!( sig.io(), ( - &vec![array_type(size, element_ty.clone())].into(), + &vec![AK::ty(size, element_ty.clone())].into(), &vec![option_type(vec![ element_ty.clone(), - array_type(size - 1, element_ty.clone()) 
+ AK::ty(size - 1, element_ty.clone()) ]) .into()] .into() @@ -438,11 +472,13 @@ mod tests { } } - #[test] - fn test_discard_empty() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_discard_empty(#[case] _kind: AK) { let size = 0; let element_ty = bool_t(); - let op = ArrayOpDef::discard_empty.to_concrete(element_ty.clone(), size); + let op = GenericArrayOpDef::::discard_empty.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); @@ -450,19 +486,18 @@ mod tests { assert_eq!( sig.io(), - ( - &vec![array_type(size, element_ty.clone())].into(), - &type_row![] - ) + (&vec![AK::ty(size, element_ty.clone())].into(), &type_row![]) ); } - #[test] + #[rstest] + #[case(Array)] + #[case(ValueArray)] /// Initialize an array operation where the element type is not from the prelude. - fn test_non_prelude_op() { + fn test_non_prelude_op(#[case] _kind: AK) { let size = 2; let element_ty = float64_type(); - let op = ArrayOpDef::get.to_concrete(element_ty.clone(), size); + let op = GenericArrayOpDef::::get.to_concrete(element_ty.clone(), size); let optype: OpType = op.into(); @@ -471,8 +506,12 @@ mod tests { assert_eq!( sig.io(), ( - &vec![array_type(size, element_ty.clone()), usize_t()].into(), - &vec![option_type(element_ty.clone()).into()].into() + &vec![AK::ty(size, element_ty.clone()), usize_t()].into(), + &vec![ + option_type(element_ty.clone()).into(), + AK::ty(size, element_ty.clone()) + ] + .into() ) ); } diff --git a/hugr-core/src/std_extensions/collections/array/array_repeat.rs b/hugr-core/src/std_extensions/collections/array/array_repeat.rs index a31505cb2..10a52308a 100644 --- a/hugr-core/src/std_extensions/collections/array/array_repeat.rs +++ b/hugr-core/src/std_extensions/collections/array/array_repeat.rs @@ -1,5 +1,6 @@ //! Definition of the array repeat operation. +use std::marker::PhantomData; use std::str::FromStr; use std::sync::{Arc, Weak}; @@ -12,46 +13,60 @@ use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncValueType, PolyFuncTypeRV, Signature, Type, TypeBound}; use crate::Extension; -use super::{array_type_def, instantiate_array, ARRAY_TYPENAME}; +use super::array_kind::ArrayKind; /// Name of the operation to repeat a value multiple times pub const ARRAY_REPEAT_OP_ID: OpName = OpName::new_inline("repeat"); -/// Definition of the array repeat op. +/// Definition of the array repeat op. Generic over the concrete array implementation. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] -pub struct ArrayRepeatDef; +pub struct GenericArrayRepeatDef(PhantomData); -impl NamedOp for ArrayRepeatDef { +impl GenericArrayRepeatDef { + /// Creates a new array repeat operation definition. + pub fn new() -> Self { + GenericArrayRepeatDef(PhantomData) + } +} + +impl Default for GenericArrayRepeatDef { + fn default() -> Self { + Self::new() + } +} + +impl NamedOp for GenericArrayRepeatDef { fn name(&self) -> OpName { ARRAY_REPEAT_OP_ID } } -impl FromStr for ArrayRepeatDef { +impl FromStr for GenericArrayRepeatDef { type Err = (); fn from_str(s: &str) -> Result { - if s == ArrayRepeatDef.name() { - Ok(Self) + if s == ARRAY_REPEAT_OP_ID { + Ok(GenericArrayRepeatDef::new()) } else { Err(()) } } } -impl ArrayRepeatDef { +impl GenericArrayRepeatDef { /// To avoid recursion when defining the extension, take the type definition as an argument. 
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { let params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; let n = TypeArg::new_var_use(0, TypeParam::max_nat()); let t = Type::new_var_use(1, TypeBound::Any); let func = Type::new_function(Signature::new(vec![], vec![t.clone()])); - let array_ty = instantiate_array(array_def, n, t).expect("Array type instantiation failed"); + let array_ty = + AK::instantiate_ty(array_def, n, t).expect("Array type instantiation failed"); PolyFuncTypeRV::new(params, FuncValueType::new(vec![func], array_ty)).into() } } -impl MakeOpDef for ArrayRepeatDef { +impl MakeOpDef for GenericArrayRepeatDef { fn from_def(op_def: &OpDef) -> Result where Self: Sized, @@ -60,15 +75,15 @@ impl MakeOpDef for ArrayRepeatDef { } fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { - self.signature_from_def(array_type_def()) + self.signature_from_def(AK::type_def()) } fn extension_ref(&self) -> Weak { - Arc::downgrade(&super::EXTENSION) + Arc::downgrade(AK::extension()) } fn extension(&self) -> ExtensionId { - super::EXTENSION_ID + AK::EXTENSION_ID } fn description(&self) -> String { @@ -87,7 +102,7 @@ impl MakeOpDef for ArrayRepeatDef { extension: &mut Extension, extension_ref: &Weak, ) -> Result<(), crate::extension::ExtensionBuildError> { - let sig = self.signature_from_def(extension.get_type(&ARRAY_TYPENAME).unwrap()); + let sig = self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap()); let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); @@ -96,34 +111,39 @@ impl MakeOpDef for ArrayRepeatDef { } } -/// Definition of the array repeat op. +/// Definition of the array repeat op. Generic over the concrete array implementation. #[derive(Clone, Debug, PartialEq)] -pub struct ArrayRepeat { +pub struct GenericArrayRepeat { /// The element type of the resulting array. pub elem_ty: Type, /// Size of the array. pub size: u64, + _kind: PhantomData, } -impl ArrayRepeat { +impl GenericArrayRepeat { /// Creates a new array repeat op. 
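Illustrative sketch, not part of the patch: `repeat` consumes a nullary function producing the element type and yields an array of `size` results, matching the signature built above and `test_repeat` below.

    let op = GenericArrayRepeat::<Array>::new(qb_t(), 2);
    let optype: OpType = op.into();
    let sig = optype.dataflow_signature().unwrap();
    // inputs:  [fn() -> qubit]
    // outputs: [array<2, qubit>]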
pub fn new(elem_ty: Type, size: u64) -> Self { - ArrayRepeat { elem_ty, size } + GenericArrayRepeat { + elem_ty, + size, + _kind: PhantomData, + } } } -impl NamedOp for ArrayRepeat { +impl NamedOp for GenericArrayRepeat { fn name(&self) -> OpName { ARRAY_REPEAT_OP_ID } } -impl MakeExtensionOp for ArrayRepeat { +impl MakeExtensionOp for GenericArrayRepeat { fn from_extension_op(ext_op: &ExtensionOp) -> Result where Self: Sized, { - let def = ArrayRepeatDef::from_def(ext_op.def())?; + let def = GenericArrayRepeatDef::::from_def(ext_op.def())?; def.instantiate(ext_op.args()) } @@ -135,27 +155,27 @@ impl MakeExtensionOp for ArrayRepeat { } } -impl MakeRegisteredOp for ArrayRepeat { +impl MakeRegisteredOp for GenericArrayRepeat { fn extension_id(&self) -> ExtensionId { - super::EXTENSION_ID + AK::EXTENSION_ID } fn extension_ref(&self) -> Weak { - Arc::downgrade(&super::EXTENSION) + Arc::downgrade(AK::extension()) } } -impl HasDef for ArrayRepeat { - type Def = ArrayRepeatDef; +impl HasDef for GenericArrayRepeat { + type Def = GenericArrayRepeatDef; } -impl HasConcrete for ArrayRepeatDef { - type Concrete = ArrayRepeat; +impl HasConcrete for GenericArrayRepeatDef { + type Concrete = GenericArrayRepeat; fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] => { - Ok(ArrayRepeat::new(ty.clone(), *n)) + Ok(GenericArrayRepeat::new(ty.clone(), *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), } @@ -164,7 +184,10 @@ impl HasConcrete for ArrayRepeatDef { #[cfg(test)] mod tests { - use crate::std_extensions::collections::array::array_type; + use rstest::rstest; + + use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::value_array::ValueArray; use crate::{ extension::prelude::qb_t, ops::{OpTrait, OpType}, @@ -173,19 +196,23 @@ mod tests { use super::*; - #[test] - fn test_repeat_def() { - let op = ArrayRepeat::new(qb_t(), 2); + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_repeat_def(#[case] _kind: AK) { + let op = GenericArrayRepeat::::new(qb_t(), 2); let optype: OpType = op.clone().into(); - let new_op: ArrayRepeat = optype.cast().unwrap(); + let new_op: GenericArrayRepeat = optype.cast().unwrap(); assert_eq!(new_op, op); } - #[test] - fn test_repeat() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_repeat(#[case] _kind: AK) { let size = 2; let element_ty = qb_t(); - let op = ArrayRepeat::new(element_ty.clone(), size); + let op = GenericArrayRepeat::::new(element_ty.clone(), size); let optype: OpType = op.into(); @@ -195,7 +222,7 @@ mod tests { sig.io(), ( &vec![Type::new_function(Signature::new(vec![], vec![qb_t()]))].into(), - &vec![array_type(size, element_ty.clone())].into(), + &vec![AK::ty(size, element_ty.clone())].into(), ) ); } diff --git a/hugr-core/src/std_extensions/collections/array/array_scan.rs b/hugr-core/src/std_extensions/collections/array/array_scan.rs index 8064a73d0..fcd06b628 100644 --- a/hugr-core/src/std_extensions/collections/array/array_scan.rs +++ b/hugr-core/src/std_extensions/collections/array/array_scan.rs @@ -1,5 +1,6 @@ //! 
Array scanning operation +use std::marker::PhantomData; use std::str::FromStr; use std::sync::{Arc, Weak}; @@ -14,34 +15,47 @@ use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncTypeBase, PolyFuncTypeRV, RowVariable, Type, TypeBound, TypeRV}; use crate::Extension; -use super::{array_type_def, instantiate_array, ARRAY_TYPENAME}; +use super::array_kind::ArrayKind; /// Name of the operation for the combined map/fold operation pub const ARRAY_SCAN_OP_ID: OpName = OpName::new_inline("scan"); -/// Definition of the array scan op. +/// Definition of the array scan op. Generic over the concrete array implementation. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] -pub struct ArrayScanDef; +pub struct GenericArrayScanDef(PhantomData); -impl NamedOp for ArrayScanDef { +impl GenericArrayScanDef { + /// Creates a new array scan operation definition. + pub fn new() -> Self { + GenericArrayScanDef(PhantomData) + } +} + +impl Default for GenericArrayScanDef { + fn default() -> Self { + Self::new() + } +} + +impl NamedOp for GenericArrayScanDef { fn name(&self) -> OpName { ARRAY_SCAN_OP_ID } } -impl FromStr for ArrayScanDef { +impl FromStr for GenericArrayScanDef { type Err = (); fn from_str(s: &str) -> Result { - if s == ArrayScanDef.name() { - Ok(Self) + if s == ARRAY_SCAN_OP_ID { + Ok(Self::new()) } else { Err(()) } } } -impl ArrayScanDef { +impl GenericArrayScanDef { /// To avoid recursion when defining the extension, take the type definition /// and a reference to the extension as an argument. fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { @@ -60,7 +74,7 @@ impl ArrayScanDef { params, FuncTypeBase::::new( vec![ - instantiate_array(array_def, n.clone(), t1.clone()) + AK::instantiate_ty(array_def, n.clone(), t1.clone()) .expect("Array type instantiation failed") .into(), Type::new_function(FuncTypeBase::::new( @@ -71,7 +85,7 @@ impl ArrayScanDef { s.clone(), ], vec![ - instantiate_array(array_def, n, t2) + AK::instantiate_ty(array_def, n, t2) .expect("Array type instantiation failed") .into(), s, @@ -82,7 +96,7 @@ impl ArrayScanDef { } } -impl MakeOpDef for ArrayScanDef { +impl MakeOpDef for GenericArrayScanDef { fn from_def(op_def: &OpDef) -> Result where Self: Sized, @@ -91,15 +105,15 @@ impl MakeOpDef for ArrayScanDef { } fn init_signature(&self, _extension_ref: &Weak) -> SignatureFunc { - self.signature_from_def(array_type_def()) + self.signature_from_def(AK::type_def()) } fn extension_ref(&self) -> Weak { - Arc::downgrade(&super::EXTENSION) + Arc::downgrade(AK::extension()) } fn extension(&self) -> ExtensionId { - super::EXTENSION_ID + AK::EXTENSION_ID } fn description(&self) -> String { @@ -120,7 +134,7 @@ impl MakeOpDef for ArrayScanDef { extension: &mut Extension, extension_ref: &Weak, ) -> Result<(), crate::extension::ExtensionBuildError> { - let sig = self.signature_from_def(extension.get_type(&ARRAY_TYPENAME).unwrap()); + let sig = self.signature_from_def(extension.get_type(&AK::TYPE_NAME).unwrap()); let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); @@ -129,9 +143,9 @@ impl MakeOpDef for ArrayScanDef { } } -/// Definition of the array scan op. +/// Definition of the array scan op. Generic over the concrete array implementation. #[derive(Clone, Debug, PartialEq)] -pub struct ArrayScan { +pub struct GenericArrayScan { /// The element type of the input array. pub src_ty: Type, /// The target element type of the output array. 
@@ -140,32 +154,34 @@ pub struct ArrayScan { pub acc_tys: Vec, /// Size of the array. pub size: u64, + _kind: PhantomData, } -impl ArrayScan { +impl GenericArrayScan { /// Creates a new array scan op. pub fn new(src_ty: Type, tgt_ty: Type, acc_tys: Vec, size: u64) -> Self { - ArrayScan { + GenericArrayScan { src_ty, tgt_ty, acc_tys, size, + _kind: PhantomData, } } } -impl NamedOp for ArrayScan { +impl NamedOp for GenericArrayScan { fn name(&self) -> OpName { ARRAY_SCAN_OP_ID } } -impl MakeExtensionOp for ArrayScan { +impl MakeExtensionOp for GenericArrayScan { fn from_extension_op(ext_op: &ExtensionOp) -> Result where Self: Sized, { - let def = ArrayScanDef::from_def(ext_op.def())?; + let def = GenericArrayScanDef::::from_def(ext_op.def())?; def.instantiate(ext_op.args()) } @@ -181,22 +197,22 @@ impl MakeExtensionOp for ArrayScan { } } -impl MakeRegisteredOp for ArrayScan { +impl MakeRegisteredOp for GenericArrayScan { fn extension_id(&self) -> ExtensionId { - super::EXTENSION_ID + AK::EXTENSION_ID } fn extension_ref(&self) -> Weak { - Arc::downgrade(&super::EXTENSION) + Arc::downgrade(AK::extension()) } } -impl HasDef for ArrayScan { - type Def = ArrayScanDef; +impl HasDef for GenericArrayScan { + type Def = GenericArrayScanDef; } -impl HasConcrete for ArrayScanDef { - type Concrete = ArrayScan; +impl HasConcrete for GenericArrayScanDef { + type Concrete = GenericArrayScan; fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { @@ -209,7 +225,12 @@ impl HasConcrete for ArrayScanDef { _ => Err(SignatureError::InvalidTypeArgs.into()), }) .collect(); - Ok(ArrayScan::new(src_ty.clone(), tgt_ty.clone(), acc_tys?, *n)) + Ok(GenericArrayScan::new( + src_ty.clone(), + tgt_ty.clone(), + acc_tys?, + *n, + )) } _ => Err(SignatureError::InvalidTypeArgs.into()), } @@ -218,9 +239,11 @@ impl HasConcrete for ArrayScanDef { #[cfg(test)] mod tests { + use rstest::rstest; use crate::extension::prelude::usize_t; - use crate::std_extensions::collections::array::array_type; + use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::value_array::ValueArray; use crate::{ extension::prelude::{bool_t, qb_t}, ops::{OpTrait, OpType}, @@ -229,21 +252,25 @@ mod tests { use super::*; - #[test] - fn test_scan_def() { - let op = ArrayScan::new(bool_t(), qb_t(), vec![usize_t()], 2); + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_scan_def(#[case] _kind: AK) { + let op = GenericArrayScan::::new(bool_t(), qb_t(), vec![usize_t()], 2); let optype: OpType = op.clone().into(); - let new_op: ArrayScan = optype.cast().unwrap(); + let new_op: GenericArrayScan = optype.cast().unwrap(); assert_eq!(new_op, op); } - #[test] - fn test_scan_map() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_scan_map(#[case] _kind: AK) { let size = 2; let src_ty = qb_t(); let tgt_ty = bool_t(); - let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size); + let op = GenericArrayScan::::new(src_ty.clone(), tgt_ty.clone(), vec![], size); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); @@ -251,24 +278,26 @@ mod tests { sig.io(), ( &vec![ - array_type(size, src_ty.clone()), + AK::ty(size, src_ty.clone()), Type::new_function(Signature::new(vec![src_ty], vec![tgt_ty.clone()])) ] .into(), - &vec![array_type(size, tgt_ty)].into(), + &vec![AK::ty(size, tgt_ty)].into(), ) ); } - #[test] - fn test_scan_accs() { + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_scan_accs(#[case] _kind: AK) { let size = 2; let src_ty = 
qb_t(); let tgt_ty = bool_t(); let acc_ty1 = usize_t(); let acc_ty2 = qb_t(); - let op = ArrayScan::new( + let op = GenericArrayScan::::new( src_ty.clone(), tgt_ty.clone(), vec![acc_ty1.clone(), acc_ty2.clone()], @@ -281,7 +310,7 @@ mod tests { sig.io(), ( &vec![ - array_type(size, src_ty.clone()), + AK::ty(size, src_ty.clone()), Type::new_function(Signature::new( vec![src_ty, acc_ty1.clone(), acc_ty2.clone()], vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()] @@ -290,7 +319,7 @@ mod tests { acc_ty2.clone() ] .into(), - &vec![array_type(size, tgt_ty), acc_ty1, acc_ty2].into(), + &vec![AK::ty(size, tgt_ty), acc_ty1, acc_ty2].into(), ) ); } diff --git a/hugr-core/src/std_extensions/collections/array/array_value.rs b/hugr-core/src/std_extensions/collections/array/array_value.rs new file mode 100644 index 000000000..2909acf29 --- /dev/null +++ b/hugr-core/src/std_extensions/collections/array/array_value.rs @@ -0,0 +1,159 @@ +use itertools::Itertools as _; +use serde::{Deserialize, Serialize}; +use std::hash::{Hash, Hasher}; +use std::marker::PhantomData; + +use crate::extension::resolution::{ + resolve_type_extensions, resolve_value_extensions, ExtensionResolutionError, + WeakExtensionRegistry, +}; +use crate::ops::constant::{maybe_hash_values, TryHash, ValueName}; +use crate::ops::Value; +use crate::types::type_param::TypeArg; +use crate::types::{CustomCheckFailure, CustomType, Type}; + +use super::array_kind::ArrayKind; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +/// Statically sized array of values, all of the same type. +pub struct GenericArrayValue { + values: Vec, + typ: Type, + _kind: PhantomData, +} + +impl GenericArrayValue { + /// Create a new [CustomConst] for an array of values of type `typ`. + /// That all values are of type `typ` is not checked here. + /// + /// [CustomConst]: crate::ops::constant::CustomConst + pub fn new(typ: Type, contents: impl IntoIterator) -> Self { + Self { + values: contents.into_iter().collect_vec(), + typ, + _kind: PhantomData, + } + } + + /// Create a new [CustomConst] for an empty array of values of type `typ`. + /// + /// [CustomConst]: crate::ops::constant::CustomConst + pub fn new_empty(typ: Type) -> Self { + Self { + values: vec![], + typ, + _kind: PhantomData, + } + } + + /// Returns the type of the `[GenericArrayValue]` as a `[CustomType]`.` + pub fn custom_type(&self) -> CustomType { + AK::custom_ty(self.values.len() as u64, self.typ.clone()) + } + + /// Returns the type of the `[GenericArrayValue]`. + pub fn get_type(&self) -> Type { + self.custom_type().into() + } + + /// Returns the type of values inside the `[ArrayValue]`. + pub fn get_element_type(&self) -> &Type { + &self.typ + } + + /// Returns the values contained inside the `[ArrayValue]`. + pub fn get_contents(&self) -> &[Value] { + &self.values + } + + /// Returns the name of the value. + pub fn name(&self) -> ValueName { + AK::VALUE_NAME + } + + /// Validates the array value. + pub fn validate(&self) -> Result<(), CustomCheckFailure> { + let typ = self.custom_type(); + + AK::extension() + .get_type(&AK::TYPE_NAME) + .unwrap() + .check_custom(&typ) + .map_err(|_| { + CustomCheckFailure::Message(format!( + "Custom typ {typ} is not a valid instantiation of array." + )) + })?; + + // constant can only hold classic type. 
+ let ty = match typ.args() { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] + if *n as usize == self.values.len() => + { + ty + } + _ => { + return Err(CustomCheckFailure::Message(format!( + "Invalid array type arguments: {:?}", + typ.args() + ))) + } + }; + + // check all values are instances of the element type + for v in &self.values { + if v.get_type() != *ty { + return Err(CustomCheckFailure::Message(format!( + "Array element {v:?} is not of expected type {ty}" + ))); + } + } + + Ok(()) + } + + /// Update the extensions associated with the internal values. + pub fn update_extensions( + &mut self, + extensions: &WeakExtensionRegistry, + ) -> Result<(), ExtensionResolutionError> { + for val in &mut self.values { + resolve_value_extensions(val, extensions)?; + } + resolve_type_extensions(&mut self.typ, extensions) + } +} + +impl TryHash for GenericArrayValue { + fn try_hash(&self, mut st: &mut dyn Hasher) -> bool { + maybe_hash_values(&self.values, &mut st) && { + self.typ.hash(&mut st); + true + } + } +} + +#[cfg(test)] +mod test { + use rstest::rstest; + + use crate::extension::prelude::{usize_t, ConstUsize}; + use crate::std_extensions::arithmetic::float_types::ConstF64; + + use crate::std_extensions::collections::array::Array; + use crate::std_extensions::collections::value_array::ValueArray; + + use super::*; + + #[rstest] + #[case(Array)] + #[case(ValueArray)] + fn test_array_value(#[case] _kind: AK) { + let array_value = GenericArrayValue::::new(usize_t(), vec![ConstUsize::new(3).into()]); + array_value.validate().unwrap(); + + let wrong_array_value = + GenericArrayValue::::new(usize_t(), vec![ConstF64::new(1.2).into()]); + assert!(wrong_array_value.validate().is_err()); + } +} diff --git a/hugr-core/src/std_extensions/collections/array/op_builder.rs b/hugr-core/src/std_extensions/collections/array/op_builder.rs index 623443347..4d8f7ce4e 100644 --- a/hugr-core/src/std_extensions/collections/array/op_builder.rs +++ b/hugr-core/src/std_extensions/collections/array/op_builder.rs @@ -1,6 +1,7 @@ //! Builder trait for array operations in the dataflow graph. -use crate::std_extensions::collections::array::{new_array_op, ArrayOpDef}; +use crate::std_extensions::collections::array::GenericArrayOpDef; +use crate::std_extensions::collections::value_array::ValueArray; use crate::{ builder::{BuildError, Dataflow}, extension::simple_op::HasConcrete as _, @@ -9,8 +10,15 @@ use crate::{ }; use itertools::Itertools as _; -/// Trait for building array operations in a dataflow graph. -pub trait ArrayOpBuilder: Dataflow { +use super::{Array, ArrayKind, GenericArrayClone, GenericArrayDiscard}; + +use crate::extension::prelude::{ + either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _, +}; + +/// Trait for building array operations in a dataflow graph that are generic +/// over the concrete array implementation. +pub trait GenericArrayOpBuilder: Dataflow { /// Adds a new array operation to the dataflow graph and return the wire /// representing the new array. /// @@ -26,18 +34,70 @@ pub trait ArrayOpBuilder: Dataflow { /// # Returns /// /// The wire representing the new array. - fn add_new_array( + fn add_new_generic_array( &mut self, elem_ty: Type, values: impl IntoIterator, ) -> Result { let inputs = values.into_iter().collect_vec(); let [out] = self - .add_dataflow_op(new_array_op(elem_ty, inputs.len() as u64), inputs)? + .add_dataflow_op( + GenericArrayOpDef::::new_array.to_concrete(elem_ty, inputs.len() as u64), + inputs, + )? 
.outputs_arr(); Ok(out) } + /// Adds an array clone operation to the dataflow graph and return the wires + /// representing the originala and cloned array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// The wires representing the original and cloned array. + fn add_generic_array_clone( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result<(Wire, Wire), BuildError> { + let op = GenericArrayClone::::new(elem_ty, size).unwrap(); + let [arr1, arr2] = self.add_dataflow_op(op, vec![input])?.outputs_arr(); + Ok((arr1, arr2)) + } + + /// Adds an array discard operation to the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// If building the operation fails. + fn add_generic_array_discard( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result<(), BuildError> { + let op = GenericArrayDiscard::::new(elem_ty, size).unwrap(); + let [] = self.add_dataflow_op(op, vec![input])?.outputs_arr(); + Ok(()) + } + /// Adds an array get operation to the dataflow graph. /// /// # Arguments @@ -53,17 +113,18 @@ pub trait ArrayOpBuilder: Dataflow { /// /// # Returns /// - /// The wire representing the value at the specified index in the array. - fn add_array_get( + /// * The wire representing the value at the specified index in the array + /// * The wire representing the array + fn add_generic_array_get( &mut self, elem_ty: Type, size: u64, input: Wire, index: Wire, - ) -> Result { - let op = ArrayOpDef::get.instantiate(&[size.into(), elem_ty.into()])?; - let [out] = self.add_dataflow_op(op, vec![input, index])?.outputs_arr(); - Ok(out) + ) -> Result<(Wire, Wire), BuildError> { + let op = GenericArrayOpDef::::get.instantiate(&[size.into(), elem_ty.into()])?; + let [out, arr] = self.add_dataflow_op(op, vec![input, index])?.outputs_arr(); + Ok((out, arr)) } /// Adds an array set operation to the dataflow graph. @@ -85,7 +146,7 @@ pub trait ArrayOpBuilder: Dataflow { /// # Returns /// /// The wire representing the updated array after the set operation. - fn add_array_set( + fn add_generic_array_set( &mut self, elem_ty: Type, size: u64, @@ -93,7 +154,7 @@ pub trait ArrayOpBuilder: Dataflow { index: Wire, value: Wire, ) -> Result { - let op = ArrayOpDef::set.instantiate(&[size.into(), elem_ty.into()])?; + let op = GenericArrayOpDef::::set.instantiate(&[size.into(), elem_ty.into()])?; let [out] = self .add_dataflow_op(op, vec![input, index, value])? .outputs_arr(); @@ -119,7 +180,7 @@ pub trait ArrayOpBuilder: Dataflow { /// # Returns /// /// The wire representing the updated array after the swap operation. - fn add_array_swap( + fn add_generic_array_swap( &mut self, elem_ty: Type, size: u64, @@ -127,7 +188,7 @@ pub trait ArrayOpBuilder: Dataflow { index1: Wire, index2: Wire, ) -> Result { - let op = ArrayOpDef::swap.instantiate(&[size.into(), elem_ty.into()])?; + let op = GenericArrayOpDef::::swap.instantiate(&[size.into(), elem_ty.into()])?; let [out] = self .add_dataflow_op(op, vec![input, index1, index2])? 
.outputs_arr(); @@ -151,13 +212,13 @@ pub trait ArrayOpBuilder: Dataflow { /// # Returns /// /// The wire representing the Option> - fn add_array_pop_left( + fn add_generic_array_pop_left( &mut self, elem_ty: Type, size: u64, input: Wire, ) -> Result { - let op = ArrayOpDef::pop_left.instantiate(&[size.into(), elem_ty.into()])?; + let op = GenericArrayOpDef::::pop_left.instantiate(&[size.into(), elem_ty.into()])?; Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0)) } @@ -178,13 +239,13 @@ pub trait ArrayOpBuilder: Dataflow { /// # Returns /// /// The wire representing the Option> - fn add_array_pop_right( + fn add_generic_array_pop_right( &mut self, elem_ty: Type, size: u64, input: Wire, ) -> Result { - let op = ArrayOpDef::pop_right.instantiate(&[size.into(), elem_ty.into()])?; + let op = GenericArrayOpDef::::pop_right.instantiate(&[size.into(), elem_ty.into()])?; Ok(self.add_dataflow_op(op, vec![input])?.out_wire(0)) } @@ -198,9 +259,13 @@ pub trait ArrayOpBuilder: Dataflow { /// # Errors /// /// Returns an error if building the operation fails. - fn add_array_discard_empty(&mut self, elem_ty: Type, input: Wire) -> Result<(), BuildError> { + fn add_generic_array_discard_empty( + &mut self, + elem_ty: Type, + input: Wire, + ) -> Result<(), BuildError> { self.add_dataflow_op( - ArrayOpDef::discard_empty + GenericArrayOpDef::::discard_empty .instantiate(&[elem_ty.into()]) .unwrap(), [input], @@ -209,77 +274,104 @@ pub trait ArrayOpBuilder: Dataflow { } } -impl ArrayOpBuilder for D {} - -#[cfg(test)] -mod test { - use crate::std_extensions::collections::array::array_type; - use crate::{ - builder::{DFGBuilder, HugrBuilder}, - extension::prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _}, - types::Signature, - Hugr, - }; - use rstest::rstest; +impl GenericArrayOpBuilder for D {} - use super::*; - - #[rstest::fixture] - #[default(DFGBuilder)] - fn all_array_ops( - #[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW)).unwrap())] - mut builder: B, - ) -> B { - let us0 = builder.add_load_value(ConstUsize::new(0)); - let us1 = builder.add_load_value(ConstUsize::new(1)); - let us2 = builder.add_load_value(ConstUsize::new(2)); - let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); - let [arr] = { - let r = builder.add_array_swap(usize_t(), 2, arr, us0, us1).unwrap(); - let res_sum_ty = { - let array_type = array_type(2, usize_t()); - either_type(array_type.clone(), array_type) - }; - builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() +/// Helper function to build a Hugr that contains all basic array operations. +/// +/// Generic over the concrete array implementation. 
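+///
+/// A minimal usage sketch (assuming an endomorphic `DFGBuilder`; the concrete wrappers
+/// [`build_all_array_ops`] and [`build_all_value_array_ops`] below fix the array kind):
+///
+/// ```ignore
+/// let builder = DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW))?;
+/// let hugr = build_all_array_ops(builder).finish_hugr()?;
+/// ```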
+pub fn build_all_array_ops_generic(mut builder: B) -> B { + let us0 = builder.add_load_value(ConstUsize::new(0)); + let us1 = builder.add_load_value(ConstUsize::new(1)); + let us2 = builder.add_load_value(ConstUsize::new(2)); + let arr = builder + .add_new_generic_array::(usize_t(), [us1, us2]) + .unwrap(); + let [arr] = { + let r = builder + .add_generic_array_swap::(usize_t(), 2, arr, us0, us1) + .unwrap(); + let res_sum_ty = { + let array_type = AK::ty(2, usize_t()); + either_type(array_type.clone(), array_type) }; + builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() + }; - let [elem_0] = { - let r = builder.add_array_get(usize_t(), 2, arr, us0).unwrap(); + let ([elem_0], arr) = { + let (r, arr) = builder + .add_generic_array_get::(usize_t(), 2, arr, us0) + .unwrap(); + ( builder .build_unwrap_sum(1, option_type(usize_t()), r) - .unwrap() - }; - - let [_elem_1, arr] = { - let r = builder - .add_array_set(usize_t(), 2, arr, us1, elem_0) - .unwrap(); - let res_sum_ty = { - let row = vec![usize_t(), array_type(2, usize_t())]; - either_type(row.clone(), row) - }; - builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() - }; + .unwrap(), + arr, + ) + }; - let [_elem_left, arr] = { - let r = builder.add_array_pop_left(usize_t(), 2, arr).unwrap(); - builder - .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(1, usize_t())]), r) - .unwrap() - }; - let [_elem_right, arr] = { - let r = builder.add_array_pop_right(usize_t(), 1, arr).unwrap(); - builder - .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(0, usize_t())]), r) - .unwrap() + let [_elem_1, arr] = { + let r = builder + .add_generic_array_set::(usize_t(), 2, arr, us1, elem_0) + .unwrap(); + let res_sum_ty = { + let row = vec![usize_t(), AK::ty(2, usize_t())]; + either_type(row.clone(), row) }; + builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() + }; - builder.add_array_discard_empty(usize_t(), arr).unwrap(); + let [_elem_left, arr] = { + let r = builder + .add_generic_array_pop_left::(usize_t(), 2, arr) + .unwrap(); + builder + .build_unwrap_sum(1, option_type(vec![usize_t(), AK::ty(1, usize_t())]), r) + .unwrap() + }; + let [_elem_right, arr] = { + let r = builder + .add_generic_array_pop_right::(usize_t(), 1, arr) + .unwrap(); builder + .build_unwrap_sum(1, option_type(vec![usize_t(), AK::ty(0, usize_t())]), r) + .unwrap() + }; + + builder + .add_generic_array_discard_empty::(usize_t(), arr) + .unwrap(); + builder +} + +/// Helper function to build a Hugr that contains all basic array operations. +pub fn build_all_array_ops(builder: B) -> B { + build_all_array_ops_generic::(builder) +} + +/// Helper function to build a Hugr that contains all basic array operations. +pub fn build_all_value_array_ops(builder: B) -> B { + build_all_array_ops_generic::(builder) +} + +/// Testing utilities to generate Hugrs that contain array operations. 
+#[cfg(test)] +mod test { + use crate::builder::{DFGBuilder, HugrBuilder}; + use crate::types::Signature; + + use super::*; + + #[test] + fn all_array_ops() { + let sig = Signature::new_endo(Type::EMPTY_TYPEROW); + let builder = DFGBuilder::new(sig).unwrap(); + build_all_array_ops(builder).finish_hugr().unwrap(); } - #[rstest] - fn build_all_ops(all_array_ops: DFGBuilder) { - all_array_ops.finish_hugr().unwrap(); + #[test] + fn all_value_array_ops() { + let sig = Signature::new_endo(Type::EMPTY_TYPEROW); + let builder = DFGBuilder::new(sig).unwrap(); + build_all_value_array_ops(builder).finish_hugr().unwrap(); } } diff --git a/hugr-core/src/std_extensions/collections/value_array.rs b/hugr-core/src/std_extensions/collections/value_array.rs new file mode 100644 index 000000000..a1731cd7c --- /dev/null +++ b/hugr-core/src/std_extensions/collections/value_array.rs @@ -0,0 +1,349 @@ +//! A version of the standard fixed-length array extension where arrays of copyable types +//! are copyable themselves. +//! +//! Supports all regular array operations apart from `clone` and `discard`. + +use std::sync::Arc; + +use delegate::delegate; +use lazy_static::lazy_static; + +use crate::builder::{BuildError, Dataflow}; +use crate::extension::resolution::{ExtensionResolutionError, WeakExtensionRegistry}; +use crate::extension::simple_op::{HasConcrete, MakeOpDef}; +use crate::extension::{ExtensionId, SignatureError, TypeDef, TypeDefBound}; +use crate::ops::constant::{CustomConst, ValueName}; +use crate::types::type_param::{TypeArg, TypeParam}; +use crate::types::{CustomCheckFailure, Type, TypeBound, TypeName}; +use crate::{Extension, Wire}; + +use super::array::op_builder::GenericArrayOpBuilder; +use super::array::{ + Array, ArrayKind, GenericArrayConvert, GenericArrayConvertDef, GenericArrayOp, + GenericArrayOpDef, GenericArrayRepeat, GenericArrayRepeatDef, GenericArrayScan, + GenericArrayScanDef, GenericArrayValue, FROM, INTO, +}; + +/// Reported unique name of the value array type. +pub const VALUE_ARRAY_TYPENAME: TypeName = TypeName::new_inline("value_array"); +/// Reported unique name of the value array value. +pub const VALUE_ARRAY_VALUENAME: TypeName = TypeName::new_inline("value_array"); +/// Reported unique name of the extension +pub const EXTENSION_ID: ExtensionId = ExtensionId::new_static_unchecked("collections.value_array"); +/// Extension version. +pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); + +/// A fixed-length collection of values. +/// +/// A value array inherits its linearity from its elements. +#[derive(Clone, Copy, Debug, derive_more::Display, Eq, PartialEq, Default)] +pub struct ValueArray; + +impl ArrayKind for ValueArray { + const EXTENSION_ID: ExtensionId = EXTENSION_ID; + const TYPE_NAME: TypeName = VALUE_ARRAY_TYPENAME; + const VALUE_NAME: ValueName = VALUE_ARRAY_VALUENAME; + + fn extension() -> &'static Arc { + &EXTENSION + } + + fn type_def() -> &'static TypeDef { + EXTENSION.get_type(&VALUE_ARRAY_TYPENAME).unwrap() + } + + fn build_clone( + _builder: &mut D, + _elem_ty: Type, + _size: u64, + arr: Wire, + ) -> Result<(Wire, Wire), BuildError> { + Ok((arr, arr)) + } + + fn build_discard( + _builder: &mut D, + _elem_ty: Type, + _size: u64, + _arr: Wire, + ) -> Result<(), BuildError> { + Ok(()) + } +} + +/// Value array operation definitions. +pub type VArrayOpDef = GenericArrayOpDef; +/// Value array repeat operation definition. +pub type VArrayRepeatDef = GenericArrayRepeatDef; +/// Value array scan operation definition. 
+pub type VArrayScanDef = GenericArrayScanDef; +/// Value array to default array conversion operation definition. +pub type VArrayToArrayDef = GenericArrayConvertDef; +/// Value array from default array conversion operation definition. +pub type VArrayFromArrayDef = GenericArrayConvertDef; + +/// Value array operations. +pub type VArrayOp = GenericArrayOp; +/// The value array repeat operation. +pub type VArrayRepeat = GenericArrayRepeat; +/// The value array scan operation. +pub type VArrayScan = GenericArrayScan; +/// The value array to default array conversion operation. +pub type VArrayToArray = GenericArrayConvert; +/// The value array from default array conversion operation. +pub type VArrayFromArray = GenericArrayConvert; + +/// A value array extension value. +pub type VArrayValue = GenericArrayValue; + +lazy_static! { + /// Extension for value array operations. + pub static ref EXTENSION: Arc = { + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.add_type( + VALUE_ARRAY_TYPENAME, + vec![ TypeParam::max_nat(), TypeBound::Any.into()], + "Fixed-length value array".into(), + // Value arrays are copyable iff their elements are + TypeDefBound::from_params(vec![1]), + extension_ref, + ) + .unwrap(); + + VArrayOpDef::load_all_ops(extension, extension_ref).unwrap(); + VArrayRepeatDef::new().add_to_extension(extension, extension_ref).unwrap(); + VArrayScanDef::new().add_to_extension(extension, extension_ref).unwrap(); + VArrayToArrayDef::new().add_to_extension(extension, extension_ref).unwrap(); + VArrayFromArrayDef::new().add_to_extension(extension, extension_ref).unwrap(); + }) + }; +} + +#[typetag::serde(name = "VArrayValue")] +impl CustomConst for VArrayValue { + delegate! { + to self { + fn name(&self) -> ValueName; + fn validate(&self) -> Result<(), CustomCheckFailure>; + fn update_extensions( + &mut self, + extensions: &WeakExtensionRegistry, + ) -> Result<(), ExtensionResolutionError>; + fn get_type(&self) -> Type; + } + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::ops::constant::downcast_equal_consts(self, other) + } +} + +/// Gets the [TypeDef] for value arrays. Note that instantiations are more easily +/// created via [value_array_type] and [value_array_type_parametric] +pub fn value_array_type_def() -> &'static TypeDef { + ValueArray::type_def() +} + +/// Instantiate a new value array type given a size argument and element type. +/// +/// This method is equivalent to [`value_array_type_parametric`], but uses concrete +/// arguments types to ensure no errors are possible. +pub fn value_array_type(size: u64, element_ty: Type) -> Type { + ValueArray::ty(size, element_ty) +} + +/// Instantiate a new value array type given the size and element type parameters. +/// +/// This is a generic version of [`value_array_type`]. +pub fn value_array_type_parametric( + size: impl Into, + element_ty: impl Into, +) -> Result { + ValueArray::ty_parametric(size, element_ty) +} + +/// Trait for building value array operations in a dataflow graph. +pub trait VArrayOpBuilder: GenericArrayOpBuilder { + /// Adds a new array operation to the dataflow graph and return the wire + /// representing the new array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `values` - An iterator over the values to initialize the array with. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// The wire representing the new array. 
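+    ///
+    /// # Example
+    ///
+    /// A minimal sketch, assuming a dataflow `builder` and two wires `w1`, `w2` that
+    /// already carry `usize_t()` values:
+    ///
+    /// ```ignore
+    /// let arr = builder.add_new_value_array(usize_t(), [w1, w2])?;
+    /// ```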
+ fn add_new_value_array( + &mut self, + elem_ty: Type, + values: impl IntoIterator, + ) -> Result { + self.add_new_generic_array::(elem_ty, values) + } + + /// Adds an array get operation to the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to get. + /// + /// # Errors + /// + /// If building the operation fails. + /// + /// # Returns + /// + /// * The wire representing the value at the specified index in the array + /// * The wire representing the array + fn add_value_array_get( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + ) -> Result<(Wire, Wire), BuildError> { + self.add_generic_array_get::(elem_ty, size, input, index) + } + + /// Adds an array set operation to the dataflow graph. + /// + /// This operation sets the value at a specified index in the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index` - The wire representing the index to set. + /// * `value` - The wire representing the value to set at the specified index. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the updated array after the set operation. + fn add_value_array_set( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index: Wire, + value: Wire, + ) -> Result { + self.add_generic_array_set::(elem_ty, size, input, index, value) + } + + /// Adds an array swap operation to the dataflow graph. + /// + /// This operation swaps the values at two specified indices in the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// * `index1` - The wire representing the first index to swap. + /// * `index2` - The wire representing the second index to swap. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the updated array after the swap operation. + fn add_value_array_swap( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + index1: Wire, + index2: Wire, + ) -> Result { + let op = + GenericArrayOpDef::::swap.instantiate(&[size.into(), elem_ty.into()])?; + let [out] = self + .add_dataflow_op(op, vec![input, index1, index2])? + .outputs_arr(); + Ok(out) + } + + /// Adds an array pop-left operation to the dataflow graph. + /// + /// This operation removes the leftmost element from the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the Option> + fn add_array_pop_left( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result { + self.add_generic_array_pop_left::(elem_ty, size, input) + } + + /// Adds an array pop-right operation to the dataflow graph. + /// + /// This operation removes the rightmost element from the array. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. 
+ /// * `size` - The size of the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + /// + /// # Returns + /// + /// The wire representing the Option> + fn add_array_pop_right( + &mut self, + elem_ty: Type, + size: u64, + input: Wire, + ) -> Result { + self.add_generic_array_pop_right::(elem_ty, size, input) + } + + /// Adds an operation to discard an empty array from the dataflow graph. + /// + /// # Arguments + /// + /// * `elem_ty` - The type of the elements in the array. + /// * `input` - The wire representing the array. + /// + /// # Errors + /// + /// Returns an error if building the operation fails. + fn add_array_discard_empty(&mut self, elem_ty: Type, input: Wire) -> Result<(), BuildError> { + self.add_generic_array_discard_empty::(elem_ty, input) + } +} + +impl VArrayOpBuilder for D {} diff --git a/hugr-core/src/utils.rs b/hugr-core/src/utils.rs index f44b075f1..efa0eef84 100644 --- a/hugr-core/src/utils.rs +++ b/hugr-core/src/utils.rs @@ -101,6 +101,18 @@ pub(crate) fn is_default(t: &T) -> bool { *t == Default::default() } +/// An empty type. +/// +/// # Example +/// +/// ```ignore +/// fn foo(never: Never) -> ! { +/// match never {} +/// } +/// ``` +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub enum Never {} + #[cfg(test)] pub(crate) mod test_quantum_extension { use std::sync::Arc; diff --git a/hugr-llvm/src/emit/libc.rs b/hugr-llvm/src/emit/libc.rs index 4eb481b63..69a8754a0 100644 --- a/hugr-llvm/src/emit/libc.rs +++ b/hugr-llvm/src/emit/libc.rs @@ -1,6 +1,9 @@ use anyhow::Result; use hugr_core::{HugrView, Node}; -use inkwell::{values::BasicMetadataValueEnum, AddressSpace}; +use inkwell::{ + values::{BasicMetadataValueEnum, BasicValueEnum}, + AddressSpace, +}; use crate::emit::func::EmitFuncContext; @@ -26,3 +29,39 @@ pub fn emit_libc_printf>( context.builder().build_call(printf, args, "")?; Ok(()) } + +/// Emits a call to the libc `void* malloc(size_t size)` function. +pub fn emit_libc_malloc<'c, H: HugrView>( + context: &mut EmitFuncContext<'c, '_, H>, + size: BasicMetadataValueEnum<'c>, +) -> Result> { + let iw_ctx = context.typing_session().iw_context(); + let malloc_sig = iw_ctx + .i8_type() + .ptr_type(AddressSpace::default()) + .fn_type(&[iw_ctx.i64_type().into()], false); + let malloc = context.get_extern_func("malloc", malloc_sig)?; + let res = context + .builder() + .build_call(malloc, &[size], "")? + .try_as_basic_value() + .unwrap_left(); + Ok(res) +} + +/// Emits a call to the libc `void free(void* ptr)` function. +pub fn emit_libc_free>( + context: &mut EmitFuncContext, + ptr: BasicMetadataValueEnum, +) -> Result<()> { + let iw_ctx = context.typing_session().iw_context(); + let ptr_ty = iw_ctx.i8_type().ptr_type(AddressSpace::default()); + let ptr = context + .builder() + .build_bit_cast(ptr.into_pointer_value(), ptr_ty, "")?; + + let free_sig = iw_ctx.void_type().fn_type(&[ptr_ty.into()], false); + let free = context.get_extern_func("free", free_sig)?; + context.builder().build_call(free, &[ptr.into()], "")?; + Ok(()) +} diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index 0216e9014..d1f7b6e45 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -1,24 +1,42 @@ //! Codegen for prelude array operations. +//! +//! An `array` is now lowered to a fat pointer `{ptr, usize}` that is allocated +//! to at least `n * sizeof(T)` bytes. 
The extra `usize` is an offset pointing to the
+//! first element, i.e. the first element is at address `ptr + offset * sizeof(T)`.
+//!
+//! The rationale behind the additional offset is the `pop_left` operation, which bumps
+//! the offset instead of mutating the pointer. This way, we can still free the original
+//! pointer when the array is discarded after a pop.
+//!
+//! We provide utility functions [array_fat_pointer_ty], [build_array_fat_pointer], and
+//! [decompose_array_fat_pointer] to work with array fat pointers.
+//!
+//! The [DefaultArrayCodegen] extension allocates all arrays on the heap using the
+//! standard libc `malloc` and `free` functions. This behaviour can be customised
+//! by providing a different implementation for [ArrayCodegen::emit_allocate_array]
+//! and [ArrayCodegen::emit_free_array].

 use std::iter;

 use anyhow::{anyhow, Ok, Result};
-use hugr_core::extension::prelude::option_type;
+use hugr_core::extension::prelude::{option_type, usize_t};
 use hugr_core::extension::simple_op::{MakeExtensionOp, MakeRegisteredOp};
 use hugr_core::ops::DataflowOpTrait;
 use hugr_core::std_extensions::collections::array::{
-    self, array_type, ArrayOp, ArrayOpDef, ArrayRepeat, ArrayScan,
+    self, array_type, ArrayClone, ArrayDiscard, ArrayOp, ArrayOpDef, ArrayRepeat, ArrayScan,
 };
 use hugr_core::types::{TypeArg, TypeEnum};
 use hugr_core::{HugrView, Node};
-use inkwell::builder::{Builder, BuilderError};
-use inkwell::types::{BasicType, BasicTypeEnum};
+use inkwell::builder::Builder;
+use inkwell::intrinsics::Intrinsic;
+use inkwell::types::{BasicType, BasicTypeEnum, IntType, StructType};
 use inkwell::values::{
-    ArrayValue, BasicValue as _, BasicValueEnum, CallableValue, IntValue, PointerValue,
+    BasicValue as _, BasicValueEnum, CallableValue, IntValue, PointerValue, StructValue,
 };
-use inkwell::IntPredicate;
+use inkwell::{AddressSpace, IntPredicate};
 use itertools::Itertools;

 use crate::emit::emit_value;
+use crate::emit::libc::{emit_libc_free, emit_libc_malloc};
 use crate::{
     emit::{deaggregate_call_result, EmitFuncContext, RowPromise},
     types::{HugrType, TypingSession},
@@ -41,15 +59,52 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> {

 /// A helper trait for customising the lowering of [hugr_core::std_extensions::collections::array]
 /// types, [hugr_core::ops::constant::CustomConst]s, and ops.
+///
+/// An `array` is now lowered to a fat pointer `{ptr, usize}` that is allocated
+/// to at least `n * sizeof(T)` bytes. The extra `usize` is an offset pointing to the
+/// first element, i.e. the first element is at address `ptr + offset * sizeof(T)`.
+///
+/// The rationale behind the additional offset is the `pop_left` operation, which bumps
+/// the offset instead of mutating the pointer. This way, we can still free the original
+/// pointer when the array is discarded after a pop.
+///
+/// By default, all arrays are allocated on the heap using the standard libc `malloc`
+/// and `free` functions. This behaviour can be customised by providing a different
+/// implementation for [ArrayCodegen::emit_allocate_array] and
+/// [ArrayCodegen::emit_free_array].
 pub trait ArrayCodegen: Clone {
+    /// Emit an allocation of `size` bytes and return the corresponding pointer.
+    ///
+    /// The default implementation allocates on the heap by emitting a call to the
+    /// standard libc `malloc` function.
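+    ///
+    /// A custom implementation can emit a call to a different allocator instead; a
+    /// rough sketch (the extern symbol `my_alloc` is purely illustrative):
+    ///
+    /// ```ignore
+    /// fn emit_allocate_array<'c, H: HugrView>(
+    ///     &self,
+    ///     ctx: &mut EmitFuncContext<'c, '_, H>,
+    ///     size: IntValue<'c>,
+    /// ) -> Result<PointerValue<'c>> {
+    ///     let iw_ctx = ctx.typing_session().iw_context();
+    ///     // same signature shape as libc malloc: i8* my_alloc(i64 size)
+    ///     let sig = iw_ctx
+    ///         .i8_type()
+    ///         .ptr_type(AddressSpace::default())
+    ///         .fn_type(&[iw_ctx.i64_type().into()], false);
+    ///     let alloc = ctx.get_extern_func("my_alloc", sig)?;
+    ///     let ptr = ctx
+    ///         .builder()
+    ///         .build_call(alloc, &[size.into()], "")?
+    ///         .try_as_basic_value()
+    ///         .unwrap_left();
+    ///     Ok(ptr.into_pointer_value())
+    /// }
+    /// ```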
+ fn emit_allocate_array<'c, H: HugrView>( + &self, + ctx: &mut EmitFuncContext<'c, '_, H>, + size: IntValue<'c>, + ) -> Result> { + let ptr = emit_libc_malloc(ctx, size.into())?; + Ok(ptr.into_pointer_value()) + } + + /// Emit an deallocation of a pointer. + /// + /// The default implementation emits a call to the standard libc `free` function. + fn emit_free_array<'c, H: HugrView>( + &self, + ctx: &mut EmitFuncContext<'c, '_, H>, + ptr: PointerValue<'c>, + ) -> Result<()> { + emit_libc_free(ctx, ptr.into()) + } + /// Return the llvm type of [hugr_core::std_extensions::collections::array::ARRAY_TYPENAME]. fn array_type<'c>( &self, - _session: &TypingSession<'c, '_>, + session: &TypingSession<'c, '_>, elem_ty: BasicTypeEnum<'c>, - size: u64, + _size: u64, ) -> impl BasicType<'c> { - elem_ty.array_type(size as u32) + array_fat_pointer_ty(session, elem_ty) } /// Emit a [hugr_core::std_extensions::collections::array::ArrayValue]. @@ -72,6 +127,26 @@ pub trait ArrayCodegen: Clone { emit_array_op(self, ctx, op, inputs, outputs) } + /// Emit a [hugr_core::std_extensions::collections::array::ArrayClone] operation. + fn emit_array_clone<'c, H: HugrView>( + &self, + ctx: &mut EmitFuncContext<'c, '_, H>, + op: ArrayClone, + array_v: BasicValueEnum<'c>, + ) -> Result<(BasicValueEnum<'c>, BasicValueEnum<'c>)> { + emit_clone_op(self, ctx, op, array_v) + } + + /// Emit a [hugr_core::std_extensions::collections::array::ArrayDiscard] operation. + fn emit_array_discard<'c, H: HugrView>( + &self, + ctx: &mut EmitFuncContext<'c, '_, H>, + op: ArrayDiscard, + array_v: BasicValueEnum<'c>, + ) -> Result<()> { + emit_array_discard(self, ctx, op, array_v) + } + /// Emit a [hugr_core::std_extensions::collections::array::ArrayRepeat] op. fn emit_array_repeat<'c, H: HugrView>( &self, @@ -79,7 +154,7 @@ pub trait ArrayCodegen: Clone { op: ArrayRepeat, func: BasicValueEnum<'c>, ) -> Result> { - emit_repeat_op(ctx, op, func) + emit_repeat_op(self, ctx, op, func) } /// Emit a [hugr_core::std_extensions::collections::array::ArrayScan] op. @@ -93,7 +168,14 @@ pub trait ArrayCodegen: Clone { func: BasicValueEnum<'c>, initial_accs: &[BasicValueEnum<'c>], ) -> Result<(BasicValueEnum<'c>, Vec>)> { - emit_scan_op(ctx, op, src_array, func, initial_accs) + emit_scan_op( + self, + ctx, + op, + src_array.into_struct_value(), + func, + initial_accs, + ) } } @@ -153,6 +235,24 @@ impl CodegenExtension for ArrayCodegenExtension { ) } }) + .extension_op(array::EXTENSION_ID, array::ARRAY_CLONE_OP_ID, { + let ccg = self.0.clone(); + move |context, args| { + let arr = args.inputs[0]; + let op = ArrayClone::from_extension_op(args.node().as_ref())?; + let (arr1, arr2) = ccg.emit_array_clone(context, op, arr)?; + args.outputs.finish(context.builder(), [arr1, arr2]) + } + }) + .extension_op(array::EXTENSION_ID, array::ARRAY_DISCARD_OP_ID, { + let ccg = self.0.clone(); + move |context, args| { + let arr = args.inputs[0]; + let op = ArrayDiscard::from_extension_op(args.node().as_ref())?; + ccg.emit_array_discard(context, op, arr)?; + args.outputs.finish(context.builder(), []) + } + }) .extension_op(array::EXTENSION_ID, array::ARRAY_REPEAT_OP_ID, { let ccg = self.0.clone(); move |context, args| { @@ -178,41 +278,85 @@ impl CodegenExtension for ArrayCodegenExtension { } } -/// Helper function to allocate an array on the stack. -/// -/// Returns two pointers: The first one is a pointer to the first element of the -/// array (i.e. it is of type `array.get_element_type().ptr_type()`) whereas the -/// second one points to the whole array value, i.e. 
it is of type `array.ptr_type()`.
-fn build_array_alloca<'c>(
+fn usize_ty<'c>(ts: &TypingSession<'c, '_>) -> IntType<'c> {
+    ts.llvm_type(&usize_t())
+        .expect("Prelude codegen is registered")
+        .into_int_type()
+}
+
+/// Returns the LLVM representation of an array value as a fat pointer.
+pub fn array_fat_pointer_ty<'c>(
+    session: &TypingSession<'c, '_>,
+    elem_ty: BasicTypeEnum<'c>,
+) -> StructType<'c> {
+    let iw_ctx = session.iw_context();
+    iw_ctx.struct_type(
+        &[
+            elem_ty.ptr_type(AddressSpace::default()).into(),
+            usize_ty(session).into(),
+        ],
+        false,
+    )
+}
+
+/// Constructs an array fat pointer value.
+pub fn build_array_fat_pointer<'c, H: HugrView>(
+    ctx: &mut EmitFuncContext<'c, '_, H>,
+    ptr: PointerValue<'c>,
+    offset: IntValue<'c>,
+) -> Result<StructValue<'c>> {
+    let array_ty = array_fat_pointer_ty(
+        &ctx.typing_session(),
+        ptr.get_type().get_element_type().try_into().unwrap(),
+    );
+    let array_v = array_ty.get_poison();
+    let array_v = ctx
+        .builder()
+        .build_insert_value(array_v, ptr.as_basic_value_enum(), 0, "")?;
+    let array_v = ctx
+        .builder()
+        .build_insert_value(array_v, offset.as_basic_value_enum(), 1, "")?;
+    Ok(array_v.into_struct_value())
+}
+
+/// Returns the underlying pointer and offset stored in a fat array pointer.
+pub fn decompose_array_fat_pointer<'c>(
     builder: &Builder<'c>,
-    array: ArrayValue<'c>,
-) -> Result<(PointerValue<'c>, PointerValue<'c>), BuilderError> {
-    let array_ty = array.get_type();
-    let array_len: IntValue<'c> = {
-        let ctx = builder.get_insert_block().unwrap().get_context();
-        ctx.i32_type().const_int(array_ty.len() as u64, false)
-    };
-    let ptr = builder.build_array_alloca(array_ty.get_element_type(), array_len, "")?;
-    let array_ptr = builder
-        .build_bit_cast(ptr, array_ty.ptr_type(Default::default()), "")?
-        .into_pointer_value();
-    builder.build_store(array_ptr, array)?;
-    Result::Ok((ptr, array_ptr))
+    array_v: BasicValueEnum<'c>,
+) -> Result<(PointerValue<'c>, IntValue<'c>)> {
+    let array_v = array_v.into_struct_value();
+    let array_ptr = builder.build_extract_value(array_v, 0, "array_ptr")?;
+    let array_offset = builder.build_extract_value(array_v, 1, "array_offset")?;
+    Ok((
+        array_ptr.into_pointer_value(),
+        array_offset.into_int_value(),
+    ))
 }

-/// Helper function to allocate an array on the stack and pass a pointer to it
-/// to a closure.
+/// Helper function to allocate a fat array pointer.
 ///
-/// The pointer forwarded to the closure is a pointer to the first element of
-/// the array. I.e. it is of type `array.get_element_type().ptr_type()` not
-/// `array.ptr_type()`
-fn with_array_alloca<'c, T, E: From<BuilderError>>(
-    builder: &Builder<'c>,
-    array: ArrayValue<'c>,
-    go: impl FnOnce(PointerValue<'c>) -> Result<T, E>,
-) -> Result<T, E> {
-    let (ptr, _) = build_array_alloca(builder, array)?;
-    go(ptr)
+/// Returns a pointer and a struct: The pointer points to the first element of the array (i.e. it
+/// is of type `elem_ty.ptr_type()`). The struct is the fat pointer of the array, which stores an
+/// additional offset (initialised to 0).
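+///
+/// Callers typically initialise the elements through the returned element pointer and then
+/// pass the struct on as the array value. A rough sketch (mirroring the `new_array` lowering
+/// below; `values` is assumed to be an iterator of LLVM values):
+///
+/// ```ignore
+/// let (elem_ptr, array_v) = build_array_alloc(ctx, ccg, elem_ty, size)?;
+/// for (i, v) in values.into_iter().enumerate() {
+///     let idx = usize_ty(&ctx.typing_session()).const_int(i as u64, false);
+///     let addr = unsafe { ctx.builder().build_in_bounds_gep(elem_ptr, &[idx], "")? };
+///     ctx.builder().build_store(addr, v)?;
+/// }
+/// ```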
+fn build_array_alloc<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, + ccg: &impl ArrayCodegen, + elem_ty: BasicTypeEnum<'c>, + size: u64, +) -> Result<(PointerValue<'c>, StructValue<'c>)> { + let usize_t = usize_ty(&ctx.typing_session()); + let length = usize_t.const_int(size, false); + let size_value = ctx + .builder() + .build_int_mul(length, elem_ty.size_of().unwrap(), "")?; + let ptr = ccg.emit_allocate_array(ctx, size_value)?; + let elem_ptr = ctx + .builder() + .build_bit_cast(ptr, elem_ty.ptr_type(AddressSpace::default()), "")? + .into_pointer_value(); + let offset = usize_t.const_zero(); + let array_v = build_array_fat_pointer(ctx, elem_ptr, offset)?; + Ok((elem_ptr, array_v)) } /// Helper function to build a loop that repeats for a given number of iterations. @@ -225,7 +369,7 @@ fn build_loop<'c, T, H: HugrView>( go: impl FnOnce(&mut EmitFuncContext<'c, '_, H>, IntValue<'c>) -> Result, ) -> Result { let builder = ctx.builder(); - let idx_ty = ctx.iw_context().i32_type(); + let idx_ty = usize_ty(&ctx.typing_session()); let idx_ptr = builder.build_alloca(idx_ty, "")?; builder.build_store(idx_ptr, idx_ty.const_zero())?; @@ -257,31 +401,26 @@ fn build_loop<'c, T, H: HugrView>( Ok(val) } +/// Emits an [array::ArrayValue]. pub fn emit_array_value<'c, H: HugrView>( ccg: &impl ArrayCodegen, ctx: &mut EmitFuncContext<'c, '_, H>, value: &array::ArrayValue, ) -> Result> { let ts = ctx.typing_session(); - let llvm_array_ty = ccg - .array_type( - &ts, - ts.llvm_type(value.get_element_type())?, - value.get_contents().len() as u64, - ) - .as_basic_type_enum() - .into_array_type(); - let mut array_v = llvm_array_ty.get_undef(); + let elem_ty = ts.llvm_type(value.get_element_type())?; + let (elem_ptr, array_v) = + build_array_alloc(ctx, ccg, elem_ty, value.get_contents().len() as u64)?; for (i, v) in value.get_contents().iter().enumerate() { let llvm_v = emit_value(ctx, v)?; - array_v = ctx - .builder() - .build_insert_value(array_v, llvm_v, i as u32, "")? - .into_array_value(); + let idx = ts.iw_context().i32_type().const_int(i as u64, true); + let elem_addr = unsafe { ctx.builder().build_in_bounds_gep(elem_ptr, &[idx], "")? }; + ctx.builder().build_store(elem_addr, llvm_v)?; } Ok(array_v.into()) } +/// Emits an [ArrayOp]. pub fn emit_array_op<'c, H: HugrView>( ccg: &impl ArrayCodegen, ctx: &mut EmitFuncContext<'c, '_, H>, @@ -299,28 +438,26 @@ pub fn emit_array_op<'c, H: HugrView>( .into_owned(); let ArrayOp { def, - ref elem_ty, + elem_ty: ref hugr_elem_ty, size, } = op; - let llvm_array_ty = ccg - .array_type(&ts, ts.llvm_type(elem_ty)?, size) - .as_basic_type_enum() - .into_array_type(); + let elem_ty = ts.llvm_type(hugr_elem_ty)?; match def { ArrayOpDef::new_array => { - let mut array_v = llvm_array_ty.get_undef(); + let (elem_ptr, array_v) = build_array_alloc(ctx, ccg, elem_ty, size)?; + let usize_t = usize_ty(&ctx.typing_session()); for (i, v) in inputs.into_iter().enumerate() { - array_v = builder - .build_insert_value(array_v, v, i as u32, "")? - .into_array_value(); + let idx = usize_t.const_int(i as u64, true); + let elem_addr = unsafe { ctx.builder().build_in_bounds_gep(elem_ptr, &[idx], "")? 
}; + ctx.builder().build_store(elem_addr, v)?; } - outputs.finish(builder, [array_v.as_basic_value_enum()]) + outputs.finish(ctx.builder(), [array_v.into()]) } ArrayOpDef::get => { let [array_v, index_v] = inputs .try_into() .map_err(|_| anyhow!("ArrayOpDef::get expects two arguments"))?; - let array_v = array_v.into_array_value(); + let (array_ptr, array_offset) = decompose_array_fat_pointer(builder, array_v)?; let index_v = index_v.into_int_value(); let res_hugr_ty = sig .output() @@ -334,7 +471,7 @@ pub fn emit_array_op<'c, H: HugrView>( ts.llvm_sum_type(st.clone())? }; - let exit_rmb = ctx.new_row_mail_box([res_hugr_ty], "")?; + let exit_rmb = ctx.new_row_mail_box(sig.output.iter(), "")?; let exit_block = ctx.build_positioned_new_block("", None, |ctx, bb| { outputs.finish(ctx.builder(), exit_rmb.read_vec(ctx.builder(), [])?)?; @@ -344,15 +481,13 @@ pub fn emit_array_op<'c, H: HugrView>( let success_block = ctx.build_positioned_new_block("", Some(exit_block), |ctx, bb| { let builder = ctx.builder(); - let elem_v = with_array_alloca(builder, array_v, |ptr| { - // inside `success_block` we know `index_v` to be in - // bounds. - let elem_addr = - unsafe { builder.build_in_bounds_gep(ptr, &[index_v], "")? }; - builder.build_load(elem_addr, "") - })?; + // inside `success_block` we know `index_v` to be in bounds + let index_v = builder.build_int_add(index_v, array_offset, "")?; + let elem_addr = + unsafe { builder.build_in_bounds_gep(array_ptr, &[index_v], "")? }; + let elem_v = builder.build_load(elem_addr, "")?; let success_v = res_sum_ty.build_tag(builder, 1, vec![elem_v])?; - exit_rmb.write(ctx.builder(), [success_v.into()])?; + exit_rmb.write(ctx.builder(), [success_v.into(), array_v])?; builder.build_unconditional_branch(exit_block)?; Ok(bb) })?; @@ -361,7 +496,7 @@ pub fn emit_array_op<'c, H: HugrView>( ctx.build_positioned_new_block("", Some(success_block), |ctx, bb| { let builder = ctx.builder(); let failure_v = res_sum_ty.build_tag(builder, 0, vec![])?; - exit_rmb.write(ctx.builder(), [failure_v.into()])?; + exit_rmb.write(ctx.builder(), [failure_v.into(), array_v])?; builder.build_unconditional_branch(exit_block)?; Ok(bb) })?; @@ -379,10 +514,10 @@ pub fn emit_array_op<'c, H: HugrView>( Ok(()) } ArrayOpDef::set => { - let [array_v0, index_v, value_v] = inputs + let [array_v, index_v, value_v] = inputs .try_into() .map_err(|_| anyhow!("ArrayOpDef::set expects three arguments"))?; - let array_v = array_v0.into_array_value(); + let (array_ptr, array_offset) = decompose_array_fat_pointer(builder, array_v)?; let index_v = index_v.into_int_value(); let res_hugr_ty = sig @@ -407,23 +542,12 @@ pub fn emit_array_op<'c, H: HugrView>( let success_block = ctx.build_positioned_new_block("", Some(exit_block), |ctx, bb| { let builder = ctx.builder(); - let (elem_v, array_v) = with_array_alloca(builder, array_v, |ptr| { - // inside `success_block` we know `index_v` to be in - // bounds. - let elem_addr = - unsafe { builder.build_in_bounds_gep(ptr, &[index_v], "")? }; - let elem_v = builder.build_load(elem_addr, "")?; - builder.build_store(elem_addr, value_v)?; - let ptr = builder - .build_bit_cast( - ptr, - array_v.get_type().ptr_type(Default::default()), - "", - )? - .into_pointer_value(); - let array_v = builder.build_load(ptr, "")?; - Ok((elem_v, array_v)) - })?; + // inside `success_block` we know `index_v` to be in bounds. + let index_v = builder.build_int_add(index_v, array_offset, "")?; + let elem_addr = + unsafe { builder.build_in_bounds_gep(array_ptr, &[index_v], "")? 
}; + let elem_v = builder.build_load(elem_addr, "")?; + builder.build_store(elem_addr, value_v)?; let success_v = res_sum_ty.build_tag(builder, 1, vec![elem_v, array_v])?; exit_rmb.write(ctx.builder(), [success_v.into()])?; builder.build_unconditional_branch(exit_block)?; @@ -433,8 +557,7 @@ pub fn emit_array_op<'c, H: HugrView>( let failure_block = ctx.build_positioned_new_block("", Some(success_block), |ctx, bb| { let builder = ctx.builder(); - let failure_v = - res_sum_ty.build_tag(builder, 0, vec![value_v, array_v.into()])?; + let failure_v = res_sum_ty.build_tag(builder, 0, vec![value_v, array_v])?; exit_rmb.write(ctx.builder(), [failure_v.into()])?; builder.build_unconditional_branch(exit_block)?; Ok(bb) @@ -452,10 +575,10 @@ pub fn emit_array_op<'c, H: HugrView>( Ok(()) } ArrayOpDef::swap => { - let [array_v0, index1_v, index2_v] = inputs + let [array_v, index1_v, index2_v] = inputs .try_into() .map_err(|_| anyhow!("ArrayOpDef::swap expects three arguments"))?; - let array_v = array_v0.into_array_value(); + let (array_ptr, array_offset) = decompose_array_fat_pointer(builder, array_v)?; let index1_v = index1_v.into_int_value(); let index2_v = index2_v.into_int_value(); @@ -488,26 +611,18 @@ pub fn emit_array_op<'c, H: HugrView>( // the cost of worse code in cases where it cannot. // For now we choose the simpler option of omitting the check. let builder = ctx.builder(); - let array_v = with_array_alloca(builder, array_v, |ptr| { - // inside `success_block` we know `index1_v` and `index2_v` - // to be in bounds. - let elem1_addr = - unsafe { builder.build_in_bounds_gep(ptr, &[index1_v], "")? }; - let elem1_v = builder.build_load(elem1_addr, "")?; - let elem2_addr = - unsafe { builder.build_in_bounds_gep(ptr, &[index2_v], "")? }; - let elem2_v = builder.build_load(elem2_addr, "")?; - builder.build_store(elem1_addr, elem2_v)?; - builder.build_store(elem2_addr, elem1_v)?; - let ptr = builder - .build_bit_cast( - ptr, - array_v.get_type().ptr_type(Default::default()), - "", - )? - .into_pointer_value(); - builder.build_load(ptr, "") - })?; + // inside `success_block` we know `index1_v` and `index2_v` + // to be in bounds. + let index1_v = builder.build_int_add(index1_v, array_offset, "")?; + let index2_v = builder.build_int_add(index2_v, array_offset, "")?; + let elem1_addr = + unsafe { builder.build_in_bounds_gep(array_ptr, &[index1_v], "")? }; + let elem1_v = builder.build_load(elem1_addr, "")?; + let elem2_addr = + unsafe { builder.build_in_bounds_gep(array_ptr, &[index2_v], "")? 
}; + let elem2_v = builder.build_load(elem2_addr, "")?; + builder.build_store(elem1_addr, elem2_v)?; + builder.build_store(elem2_addr, elem1_v)?; let success_v = res_sum_ty.build_tag(builder, 1, vec![array_v])?; exit_rmb.write(ctx.builder(), [success_v.into()])?; builder.build_unconditional_branch(exit_block)?; @@ -517,7 +632,7 @@ pub fn emit_array_op<'c, H: HugrView>( let failure_block = ctx.build_positioned_new_block("", Some(success_block), |ctx, bb| { let builder = ctx.builder(); - let failure_v = res_sum_ty.build_tag(builder, 0, vec![array_v.into()])?; + let failure_v = res_sum_ty.build_tag(builder, 0, vec![array_v])?; exit_rmb.write(ctx.builder(), [failure_v.into()])?; builder.build_unconditional_branch(exit_block)?; Ok(bb) @@ -548,11 +663,10 @@ pub fn emit_array_op<'c, H: HugrView>( .try_into() .map_err(|_| anyhow!("ArrayOpDef::pop_left expects one argument"))?; let r = emit_pop_op( - builder, - &ts, - elem_ty.clone(), + ctx, + hugr_elem_ty.clone(), size, - array_v.into_array_value(), + array_v.into_struct_value(), true, )?; outputs.finish(ctx.builder(), [r]) @@ -562,29 +676,95 @@ pub fn emit_array_op<'c, H: HugrView>( .try_into() .map_err(|_| anyhow!("ArrayOpDef::pop_right expects one argument"))?; let r = emit_pop_op( - builder, - &ts, - elem_ty.clone(), + ctx, + hugr_elem_ty.clone(), size, - array_v.into_array_value(), + array_v.into_struct_value(), false, )?; outputs.finish(ctx.builder(), [r]) } - ArrayOpDef::discard_empty => Ok(()), + ArrayOpDef::discard_empty => { + let [array_v] = inputs + .try_into() + .map_err(|_| anyhow!("ArrayOpDef::discard_empty expects one argument"))?; + let (ptr, _) = decompose_array_fat_pointer(builder, array_v)?; + ccg.emit_free_array(ctx, ptr)?; + outputs.finish(ctx.builder(), []) + } _ => todo!(), } } -/// Helper function to emit the pop operations. -fn emit_pop_op<'c>( - builder: &Builder<'c>, - ts: &TypingSession<'c, '_>, +/// Emits an [ArrayClone] op. +pub fn emit_clone_op<'c, H: HugrView>( + ccg: &impl ArrayCodegen, + ctx: &mut EmitFuncContext<'c, '_, H>, + op: ArrayClone, + array_v: BasicValueEnum<'c>, +) -> Result<(BasicValueEnum<'c>, BasicValueEnum<'c>)> { + let elem_ty = ctx.llvm_type(&op.elem_ty)?; + let (array_ptr, array_offset) = decompose_array_fat_pointer(ctx.builder(), array_v)?; + let (other_ptr, other_array_v) = build_array_alloc(ctx, ccg, elem_ty, op.size)?; + let src_ptr = unsafe { + ctx.builder() + .build_in_bounds_gep(array_ptr, &[array_offset], "")? + }; + let length = usize_ty(&ctx.typing_session()).const_int(op.size, false); + let size_value = ctx + .builder() + .build_int_mul(length, elem_ty.size_of().unwrap(), "")?; + let is_volatile = ctx.iw_context().bool_type().const_zero(); + + let memcpy_intrinsic = Intrinsic::find("llvm.memcpy").unwrap(); + let memcpy = memcpy_intrinsic + .get_declaration( + ctx.get_current_module(), + &[ + other_ptr.get_type().into(), + src_ptr.get_type().into(), + size_value.get_type().into(), + ], + ) + .unwrap(); + ctx.builder().build_call( + memcpy, + &[ + other_ptr.into(), + src_ptr.into(), + size_value.into(), + is_volatile.into(), + ], + "", + )?; + Ok((array_v, other_array_v.into())) +} + +/// Emits an [ArrayDiscard] op. 
+pub fn emit_array_discard<'c, H: HugrView>( + ccg: &impl ArrayCodegen, + ctx: &mut EmitFuncContext<'c, '_, H>, + _op: ArrayDiscard, + array_v: BasicValueEnum<'c>, +) -> Result<()> { + let array_ptr = + ctx.builder() + .build_extract_value(array_v.into_struct_value(), 0, "array_ptr")?; + ccg.emit_free_array(ctx, array_ptr.into_pointer_value())?; + Ok(()) +} + +/// Emits the [ArrayOpDef::pop_left] and [ArrayOpDef::pop_right] operations. +fn emit_pop_op<'c, H: HugrView>( + ctx: &mut EmitFuncContext<'c, '_, H>, elem_ty: HugrType, size: u64, - array_v: ArrayValue<'c>, + array_v: StructValue<'c>, pop_left: bool, ) -> Result> { + let ts = ctx.typing_session(); + let builder = ctx.builder(); + let (array_ptr, array_offset) = decompose_array_fat_pointer(builder, array_v.into())?; let ret_ty = ts.llvm_sum_type(option_type(vec![ elem_ty.clone(), array_type(size.saturating_add_signed(-1), elem_ty), @@ -592,44 +772,43 @@ fn emit_pop_op<'c>( if size == 0 { return Ok(ret_ty.build_tag(builder, 0, vec![])?.into()); } - let ctx = builder.get_insert_block().unwrap().get_context(); - let (elem_v, array_v) = with_array_alloca(builder, array_v, |ptr| { - let (elem_ptr, ptr) = { - if pop_left { - let rest_ptr = - unsafe { builder.build_gep(ptr, &[ctx.i32_type().const_int(1, false)], "") }?; - (ptr, rest_ptr) - } else { - let elem_ptr = unsafe { - builder.build_gep(ptr, &[ctx.i32_type().const_int(size - 1, false)], "") - }?; - (elem_ptr, ptr) - } - }; - let elem_v = builder.build_load(elem_ptr, "")?; - let new_array_ty = array_v - .get_type() - .get_element_type() - .array_type(size as u32 - 1); - let ptr = builder - .build_bit_cast(ptr, new_array_ty.ptr_type(Default::default()), "")? - .into_pointer_value(); - let array_v = builder.build_load(ptr, "")?; - Ok((elem_v, array_v)) - })?; - Ok(ret_ty.build_tag(builder, 1, vec![elem_v, array_v])?.into()) + let (elem_ptr, new_array_offset) = { + if pop_left { + let new_array_offset = builder.build_int_add( + array_offset, + usize_ty(&ts).const_int(1, false), + "new_offset", + )?; + let elem_ptr = unsafe { builder.build_in_bounds_gep(array_ptr, &[array_offset], "") }?; + (elem_ptr, new_array_offset) + } else { + let idx = builder.build_int_add( + array_offset, + usize_ty(&ts).const_int(size - 1, false), + "", + )?; + let elem_ptr = unsafe { builder.build_in_bounds_gep(array_ptr, &[idx], "") }?; + (elem_ptr, array_offset) + } + }; + let elem_v = builder.build_load(elem_ptr, "")?; + let new_array_v = build_array_fat_pointer(ctx, array_ptr, new_array_offset)?; + + Ok(ret_ty + .build_tag(ctx.builder(), 1, vec![elem_v, new_array_v.into()])? + .into()) } /// Emits an [ArrayRepeat] op. 
pub fn emit_repeat_op<'c, H: HugrView>( + ccg: &impl ArrayCodegen, ctx: &mut EmitFuncContext<'c, '_, H>, op: ArrayRepeat, func: BasicValueEnum<'c>, ) -> Result> { - let builder = ctx.builder(); - let array_len = ctx.iw_context().i32_type().const_int(op.size, false); - let array_ty = ctx.llvm_type(&op.elem_ty)?.array_type(op.size as u32); - let (ptr, array_ptr) = build_array_alloca(builder, array_ty.get_undef())?; + let elem_ty = ctx.llvm_type(&op.elem_ty)?; + let (ptr, array_v) = build_array_alloc(ctx, ccg, elem_ty, op.size)?; + let array_len = usize_ty(&ctx.typing_session()).const_int(op.size, false); build_loop(ctx, array_len, |ctx, idx| { let builder = ctx.builder(); let func_ptr = CallableValue::try_from(func.into_pointer_value()) @@ -643,30 +822,32 @@ pub fn emit_repeat_op<'c, H: HugrView>( builder.build_store(elem_addr, v)?; Ok(()) })?; - - let builder = ctx.builder(); - let array_v = builder.build_load(array_ptr, "")?; - Ok(array_v) + Ok(array_v.into()) } /// Emits an [ArrayScan] op. /// /// Returns the resulting array and the final values of the accumulators. pub fn emit_scan_op<'c, H: HugrView>( + ccg: &impl ArrayCodegen, ctx: &mut EmitFuncContext<'c, '_, H>, op: ArrayScan, - src_array: BasicValueEnum<'c>, + src_array_v: StructValue<'c>, func: BasicValueEnum<'c>, initial_accs: &[BasicValueEnum<'c>], ) -> Result<(BasicValueEnum<'c>, Vec>)> { + let (src_ptr, src_offset) = decompose_array_fat_pointer(ctx.builder(), src_array_v.into())?; + let tgt_elem_ty = ctx.llvm_type(&op.tgt_ty)?; + // TODO: If `sizeof(op.src_ty) >= sizeof(op.tgt_ty)`, we could reuse the memory + // from `src` instead of allocating a fresh array + let (tgt_ptr, tgt_array_v) = build_array_alloc(ctx, ccg, tgt_elem_ty, op.size)?; + let array_len = usize_ty(&ctx.typing_session()).const_int(op.size, false); + let acc_tys: Vec<_> = op + .acc_tys + .iter() + .map(|ty| ctx.llvm_type(ty)) + .try_collect()?; let builder = ctx.builder(); - let ts = ctx.typing_session(); - let array_len = ctx.iw_context().i32_type().const_int(op.size, false); - let tgt_array_ty = ts.llvm_type(&op.tgt_ty)?.array_type(op.size as u32); - let (src_ptr, _) = build_array_alloca(builder, src_array.into_array_value())?; - let (tgt_ptr, tgt_array_ptr) = build_array_alloca(builder, tgt_array_ty.get_undef())?; - - let acc_tys: Vec<_> = op.acc_tys.iter().map(|ty| ts.llvm_type(ty)).try_collect()?; let acc_ptrs: Vec<_> = acc_tys .iter() .map(|ty| builder.build_alloca(*ty, "")) @@ -679,7 +860,8 @@ pub fn emit_scan_op<'c, H: HugrView>( let builder = ctx.builder(); let func_ptr = CallableValue::try_from(func.into_pointer_value()) .map_err(|_| anyhow!("ArrayOpDef::scan expects a function pointer"))?; - let src_elem_addr = unsafe { builder.build_in_bounds_gep(src_ptr, &[idx], "")? }; + let src_idx = builder.build_int_add(idx, src_offset, "")?; + let src_elem_addr = unsafe { builder.build_in_bounds_gep(src_ptr, &[src_idx], "")? 
}; let src_elem = builder.build_load(src_elem_addr, "")?; let mut args = vec![src_elem.into()]; for ptr in acc_ptrs.iter() { @@ -695,13 +877,13 @@ pub fn emit_scan_op<'c, H: HugrView>( Ok(()) })?; + ccg.emit_free_array(ctx, src_ptr)?; let builder = ctx.builder(); - let tgt_array_v = builder.build_load(tgt_array_ptr, "")?; let final_accs = acc_ptrs .into_iter() .map(|ptr| builder.build_load(ptr, "")) .try_collect()?; - Ok((tgt_array_v, final_accs)) + Ok((tgt_array_v.into(), final_accs)) } #[cfg(test)] @@ -709,6 +891,7 @@ mod test { use hugr_core::builder::Container as _; use hugr_core::extension::prelude::either_type; use hugr_core::ops::Tag; + use hugr_core::std_extensions::collections::array::op_builder::build_all_array_ops; use hugr_core::std_extensions::collections::array::{self, array_type, ArrayRepeat, ArrayScan}; use hugr_core::std_extensions::STD_REG; use hugr_core::types::Type; @@ -724,7 +907,6 @@ mod test { int_ops::{self}, int_types::{self, int_type, ConstInt}, }, - collections::array::ArrayOpBuilder, logic, }, type_row, @@ -737,66 +919,15 @@ mod test { check_emission, emit::test::SimpleHugrConfig, test::{exec_ctx, llvm_ctx, TestContext}, - utils::{IntOpBuilder, LogicOpBuilder}, + utils::{ArrayOpBuilder, IntOpBuilder, LogicOpBuilder}, }; - /// Build all array ops - /// Copied from `hugr_core::std_extensions::collections::array::builder::test` - fn all_array_ops(mut builder: B) -> B { - let us0 = builder.add_load_value(ConstUsize::new(0)); - let us1 = builder.add_load_value(ConstUsize::new(1)); - let us2 = builder.add_load_value(ConstUsize::new(2)); - let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); - let [arr] = { - let r = builder.add_array_swap(usize_t(), 2, arr, us0, us1).unwrap(); - let res_sum_ty = { - let array_type = array_type(2, usize_t()); - either_type(array_type.clone(), array_type) - }; - builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() - }; - - let [elem_0] = { - let r = builder.add_array_get(usize_t(), 2, arr, us0).unwrap(); - builder - .build_unwrap_sum(1, option_type(usize_t()), r) - .unwrap() - }; - - let [_elem_1, arr] = { - let r = builder - .add_array_set(usize_t(), 2, arr, us1, elem_0) - .unwrap(); - let res_sum_ty = { - let row = vec![usize_t(), array_type(2, usize_t())]; - either_type(row.clone(), row) - }; - builder.build_unwrap_sum(1, res_sum_ty, r).unwrap() - }; - - let [_elem_left, arr] = { - let r = builder.add_array_pop_left(usize_t(), 2, arr).unwrap(); - builder - .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(1, usize_t())]), r) - .unwrap() - }; - let [_elem_right, arr] = { - let r = builder.add_array_pop_right(usize_t(), 1, arr).unwrap(); - builder - .build_unwrap_sum(1, option_type(vec![usize_t(), array_type(0, usize_t())]), r) - .unwrap() - }; - - builder.add_array_discard_empty(usize_t(), arr).unwrap(); - builder - } - #[rstest] fn emit_all_ops(mut llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() .with_extensions(STD_REG.to_owned()) .finish(|mut builder| { - all_array_ops(builder.dfg_builder_endo([]).unwrap()) + build_all_array_ops(builder.dfg_builder_endo([]).unwrap()) .finish_sub_container() .unwrap(); builder.finish_sub_container().unwrap() @@ -816,7 +947,28 @@ mod test { let us1 = builder.add_load_value(ConstUsize::new(1)); let us2 = builder.add_load_value(ConstUsize::new(2)); let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); - builder.add_array_get(usize_t(), 2, arr, us1).unwrap(); + let (_, arr) = builder.add_array_get(usize_t(), 2, arr, us1).unwrap(); + 
builder.add_array_discard(usize_t(), 2, arr).unwrap(); + builder.finish_with_outputs([]).unwrap() + }); + llvm_ctx.add_extensions(|cge| { + cge.add_default_prelude_extensions() + .add_default_array_extensions() + }); + check_emission!(hugr, llvm_ctx); + } + + #[rstest] + fn emit_clone(mut llvm_ctx: TestContext) { + let hugr = SimpleHugrConfig::new() + .with_extensions(STD_REG.to_owned()) + .finish(|mut builder| { + let us1 = builder.add_load_value(ConstUsize::new(1)); + let us2 = builder.add_load_value(ConstUsize::new(2)); + let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); + let (arr1, arr2) = builder.add_array_clone(usize_t(), 2, arr).unwrap(); + builder.add_array_discard(usize_t(), 2, arr1).unwrap(); + builder.add_array_discard(usize_t(), 2, arr2).unwrap(); builder.finish_with_outputs([]).unwrap() }); llvm_ctx.add_extensions(|cge| { @@ -872,7 +1024,8 @@ mod test { let us2 = builder.add_load_value(ConstUsize::new(2)); let arr = builder.add_new_array(usize_t(), [us1, us2]).unwrap(); let i = builder.add_load_value(ConstUsize::new(index)); - let get_r = builder.add_array_get(usize_t(), 2, arr, i).unwrap(); + let (get_r, arr) = builder.add_array_get(usize_t(), 2, arr, i).unwrap(); + builder.add_array_discard(usize_t(), 2, arr).unwrap(); let r = { let ot = option_type(usize_t()); let variants = (0..ot.num_variants()) @@ -963,23 +1116,20 @@ mod test { builder.add_load_value(ConstInt::new_u(3, expected_arr[0]).unwrap()); let expected_arr_1 = builder.add_load_value(ConstInt::new_u(3, expected_arr[1]).unwrap()); - let [arr_0] = { - let r = builder.add_array_get(int_ty.clone(), 2, arr, us0).unwrap(); - builder - .build_unwrap_sum(1, option_type(int_ty.clone()), r) - .unwrap() - }; - let [arr_1] = { - let r = builder.add_array_get(int_ty.clone(), 2, arr, us1).unwrap(); - builder - .build_unwrap_sum(1, option_type(int_ty.clone()), r) - .unwrap() - }; + let (r, arr) = builder.add_array_get(int_ty.clone(), 2, arr, us0).unwrap(); + let [arr_0] = builder + .build_unwrap_sum(1, option_type(int_ty.clone()), r) + .unwrap(); + let (r, arr) = builder.add_array_get(int_ty.clone(), 2, arr, us1).unwrap(); + let [arr_1] = builder + .build_unwrap_sum(1, option_type(int_ty.clone()), r) + .unwrap(); let elem_eq = builder.add_ieq(3, elem, expected_elem).unwrap(); let arr_0_eq = builder.add_ieq(3, arr_0, expected_arr_0).unwrap(); let arr_1_eq = builder.add_ieq(3, arr_1, expected_arr_1).unwrap(); let r = builder.add_and(elem_eq, arr_0_eq).unwrap(); let r = builder.add_and(r, arr_1_eq).unwrap(); + builder.add_array_discard(int_ty.clone(), 2, arr).unwrap(); builder.finish_with_outputs([r]).unwrap(); } builder.finish_sub_container().unwrap().out_wire(0) @@ -1074,18 +1224,14 @@ mod test { } conditional.finish_sub_container().unwrap().outputs_arr() }; - let elem_0 = { - let r = builder.add_array_get(int_ty.clone(), 2, arr, us0).unwrap(); - builder - .build_unwrap_sum::<1>(1, option_type(int_ty.clone()), r) - .unwrap()[0] - }; - let elem_1 = { - let r = builder.add_array_get(int_ty.clone(), 2, arr, us1).unwrap(); - builder - .build_unwrap_sum::<1>(1, option_type(int_ty), r) - .unwrap()[0] - }; + let (r, arr) = builder.add_array_get(int_ty.clone(), 2, arr, us0).unwrap(); + let elem_0 = builder + .build_unwrap_sum::<1>(1, option_type(int_ty.clone()), r) + .unwrap()[0]; + let (r, arr) = builder.add_array_get(int_ty.clone(), 2, arr, us1).unwrap(); + let elem_1 = builder + .build_unwrap_sum::<1>(1, option_type(int_ty.clone()), r) + .unwrap()[0]; let expected_elem_0 = builder.add_load_value(ConstInt::new_u(3, 
expected_arr[0]).unwrap()); let elem_0_ok = builder.add_ieq(3, elem_0, expected_elem_0).unwrap(); @@ -1110,6 +1256,7 @@ mod test { .unwrap(); conditional.finish_sub_container().unwrap().out_wire(0) }; + builder.add_array_discard(int_ty.clone(), 2, arr).unwrap(); builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { @@ -1122,20 +1269,71 @@ mod test { } #[rstest] - #[case(true, 0, 0)] - #[case(true, 1, 1)] - #[case(true, 2, 3)] - #[case(true, 3, 7)] - #[case(false, 0, 0)] - #[case(false, 1, 4)] - #[case(false, 2, 6)] - #[case(false, 3, 7)] - fn exec_pop( - mut exec_ctx: TestContext, - #[case] from_left: bool, - #[case] num: usize, - #[case] expected: u64, - ) { + #[case(0, 5)] + #[case(1, 5)] + fn exec_clone(mut exec_ctx: TestContext, #[case] index: u64, #[case] new_v: u64) { + // We build a HUGR that: + // - Creates an array: [1, 2] + // - Clones the array + // - Mutates the original at the given index + // - Returns the unchanged element of the cloned array + + let int_ty = int_type(3); + let arr_ty = array_type(2, int_ty.clone()); + let hugr = SimpleHugrConfig::new() + .with_outs(int_ty.clone()) + .with_extensions(exec_registry()) + .finish(|mut builder| { + let idx = builder.add_load_value(ConstUsize::new(index)); + let i1 = builder.add_load_value(ConstInt::new_u(3, 1).unwrap()); + let i2 = builder.add_load_value(ConstInt::new_u(3, 2).unwrap()); + let inew = builder.add_load_value(ConstInt::new_u(3, new_v).unwrap()); + let arr = builder.add_new_array(int_ty.clone(), [i1, i2]).unwrap(); + + let (arr, arr_clone) = builder.add_array_clone(int_ty.clone(), 2, arr).unwrap(); + let r = builder + .add_array_set(int_ty.clone(), 2, arr, idx, inew) + .unwrap(); + let [_, arr] = builder + .build_unwrap_sum( + 1, + either_type( + vec![int_ty.clone(), arr_ty.clone()], + vec![int_ty.clone(), arr_ty.clone()], + ), + r, + ) + .unwrap(); + let (r, arr_clone) = builder + .add_array_get(int_ty.clone(), 2, arr_clone, idx) + .unwrap(); + let [elem] = builder + .build_unwrap_sum(1, option_type(int_ty.clone()), r) + .unwrap(); + builder.add_array_discard(int_ty.clone(), 2, arr).unwrap(); + builder + .add_array_discard(int_ty.clone(), 2, arr_clone) + .unwrap(); + builder.finish_with_outputs([elem]).unwrap() + }); + exec_ctx.add_extensions(|cge| { + cge.add_default_prelude_extensions() + .add_default_array_extensions() + .add_default_int_extensions() + .add_logic_extensions() + }); + assert_eq!([1, 2][index as usize], exec_ctx.exec_hugr_u64(hugr, "main")); + } + + #[rstest] + #[case(&[], 0)] + #[case(&[true], 1)] + #[case(&[false], 4)] + #[case(&[true, true], 3)] + #[case(&[false, false], 6)] + #[case(&[true, false, true], 7)] + #[case(&[false, true, false], 7)] + fn exec_pop(mut exec_ctx: TestContext, #[case] from_left: &[bool], #[case] expected: u64) { // We build a HUGR that: // - Creates an array: [1,2,4] // - Pops `num` elements from the left or right @@ -1155,9 +1353,9 @@ mod test { let mut arr = builder .add_new_array(int_ty.clone(), new_array_args) .unwrap(); - for i in 0..num { + for (i, left) in from_left.iter().enumerate() { let array_size = (array_contents.len() - i) as u64; - let pop_res = if from_left { + let pop_res = if *left { builder .add_array_pop_left(int_ty.clone(), array_size, arr) .unwrap() @@ -1179,6 +1377,13 @@ mod test { arr = new_arr; r = builder.add_iadd(6, r, elem).unwrap(); } + builder + .add_array_discard( + int_ty.clone(), + (array_contents.len() - from_left.len()) as u64, + arr, + ) + .unwrap(); builder.finish_with_outputs([r]).unwrap() }); 
exec_ctx.add_extensions(|cge| { @@ -1223,12 +1428,15 @@ mod test { .unwrap() .out_wire(0); let idx_v = builder.add_load_value(ConstUsize::new(idx)); - let get_res = builder + let (get_res, arr) = builder .add_array_get(int_ty.clone(), size, arr, idx_v) .unwrap(); let [elem] = builder .build_unwrap_sum(1, option_type(vec![int_ty.clone()]), get_res) .unwrap(); + builder + .add_array_discard(int_ty.clone(), size, arr) + .unwrap(); builder.finish_with_outputs([elem]).unwrap() }); exec_ctx.add_extensions(|cge| { @@ -1297,6 +1505,9 @@ mod test { arr = new_arr; r = builder.add_iadd(6, r, elem).unwrap(); } + builder + .add_array_discard_empty(int_ty.clone(), arr) + .unwrap(); builder.finish_with_outputs([r]).unwrap() }); exec_ctx.add_extensions(|cge| { @@ -1348,10 +1559,11 @@ mod test { let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); let scan = ArrayScan::new(int_ty.clone(), Type::UNIT, vec![int_ty.clone()], size); let zero = builder.add_load_value(ConstInt::new_u(6, 0).unwrap()); - let sum = builder + let [arr, sum] = builder .add_dataflow_op(scan, [arr, func_v, zero]) .unwrap() - .out_wire(1); + .outputs_arr(); + builder.add_array_discard(Type::UNIT, size, arr).unwrap(); builder.finish_with_outputs([sum]).unwrap() }); exec_ctx.add_extensions(|cge| { diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@llvm14.snap index bc9aa19c6..3f4d8cecb 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@llvm14.snap @@ -21,221 +21,235 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - %0 = insertvalue [2 x i64] undef, i64 1, 0 - %1 = insertvalue [2 x i64] %0, i64 2, 1 - %2 = icmp ult i64 0, 2 - %3 = icmp ult i64 1, 2 - %4 = and i1 %2, %3 - br i1 %4, label %7, label %5 - -5: ; preds = %entry_block - %6 = insertvalue { i1, [2 x i64] } { i1 false, [2 x i64] poison }, [2 x i64] %1, 1 - br label %17 - -7: ; preds = %entry_block - %8 = alloca i64, i32 2, align 8 - %9 = bitcast i64* %8 to [2 x i64]* - store [2 x i64] %1, [2 x i64]* %9, align 4 - %10 = getelementptr inbounds i64, i64* %8, i64 0 - %11 = load i64, i64* %10, align 4 - %12 = getelementptr inbounds i64, i64* %8, i64 1 - %13 = load i64, i64* %12, align 4 - store i64 %13, i64* %10, align 4 - store i64 %11, i64* %12, align 4 - %14 = bitcast i64* %8 to [2 x i64]* - %15 = load [2 x i64], [2 x i64]* %14, align 4 - %16 = insertvalue { i1, [2 x i64] } { i1 true, [2 x i64] poison }, [2 x i64] %15, 1 - br label %17 - -17: ; preds = %5, %7 - %"0.0" = phi { i1, [2 x i64] } [ %16, %7 ], [ %6, %5 ] - %18 = extractvalue { i1, [2 x i64] } %"0.0", 0 - switch i1 %18, label %19 [ - i1 true, label %21 + %0 = call i8* @malloc(i64 mul (i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 2)) + %1 = bitcast i8* %0 to i64* + %2 = insertvalue { i64*, i64 } poison, i64* %1, 0 + %3 = insertvalue { i64*, i64 } %2, i64 0, 1 + %4 = getelementptr inbounds i64, i64* %1, i64 0 + store i64 1, i64* %4, align 4 + %5 = getelementptr inbounds i64, i64* %1, i64 1 + store i64 2, i64* %5, align 4 + %array_ptr = extractvalue { i64*, i64 } %3, 0 + %array_offset = extractvalue { i64*, i64 } %3, 1 + %6 = icmp ult i64 0, 2 + %7 = icmp ult i64 1, 2 + %8 = and i1 %6, %7 
+ br i1 %8, label %11, label %9 + +9: ; preds = %entry_block + %10 = insertvalue { i1, { i64*, i64 } } { i1 false, { i64*, i64 } poison }, { i64*, i64 } %3, 1 + br label %19 + +11: ; preds = %entry_block + %12 = add i64 0, %array_offset + %13 = add i64 1, %array_offset + %14 = getelementptr inbounds i64, i64* %array_ptr, i64 %12 + %15 = load i64, i64* %14, align 4 + %16 = getelementptr inbounds i64, i64* %array_ptr, i64 %13 + %17 = load i64, i64* %16, align 4 + store i64 %17, i64* %14, align 4 + store i64 %15, i64* %16, align 4 + %18 = insertvalue { i1, { i64*, i64 } } { i1 true, { i64*, i64 } poison }, { i64*, i64 } %3, 1 + br label %19 + +19: ; preds = %9, %11 + %"0.0" = phi { i1, { i64*, i64 } } [ %18, %11 ], [ %10, %9 ] + %20 = extractvalue { i1, { i64*, i64 } } %"0.0", 0 + switch i1 %20, label %21 [ + i1 true, label %23 ] -19: ; preds = %17 - %20 = extractvalue { i1, [2 x i64] } %"0.0", 1 +21: ; preds = %19 + %22 = extractvalue { i1, { i64*, i64 } } %"0.0", 1 br label %cond_16_case_0 -21: ; preds = %17 - %22 = extractvalue { i1, [2 x i64] } %"0.0", 1 +23: ; preds = %19 + %24 = extractvalue { i1, { i64*, i64 } } %"0.0", 1 br label %cond_16_case_1 -cond_16_case_0: ; preds = %19 - %23 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @0, i32 0, i32 0) }, 0 - %24 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @0, i32 0, i32 0) }, 1 - %25 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template, i32 0, i32 0), i32 %23, i8* %24) +cond_16_case_0: ; preds = %21 + %25 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @0, i32 0, i32 0) }, 0 + %26 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @0, i32 0, i32 0) }, 1 + %27 = call i32 (i8*, ...) 
@printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template, i32 0, i32 0), i32 %25, i8* %26) call void @abort() br label %cond_exit_16 -cond_16_case_1: ; preds = %21 +cond_16_case_1: ; preds = %23 br label %cond_exit_16 cond_exit_16: ; preds = %cond_16_case_1, %cond_16_case_0 - %"08.0" = phi [2 x i64] [ zeroinitializer, %cond_16_case_0 ], [ %22, %cond_16_case_1 ] - %26 = icmp ult i64 0, 2 - br i1 %26, label %28, label %27 - -27: ; preds = %cond_exit_16 - br label %34 - -28: ; preds = %cond_exit_16 - %29 = alloca i64, i32 2, align 8 - %30 = bitcast i64* %29 to [2 x i64]* - store [2 x i64] %"08.0", [2 x i64]* %30, align 4 - %31 = getelementptr inbounds i64, i64* %29, i64 0 - %32 = load i64, i64* %31, align 4 - %33 = insertvalue { i1, i64 } { i1 true, i64 poison }, i64 %32, 1 - br label %34 - -34: ; preds = %27, %28 - %"020.0" = phi { i1, i64 } [ %33, %28 ], [ { i1 false, i64 poison }, %27 ] - %35 = extractvalue { i1, i64 } %"020.0", 0 - switch i1 %35, label %36 [ - i1 true, label %37 + %"08.0" = phi { i64*, i64 } [ zeroinitializer, %cond_16_case_0 ], [ %24, %cond_16_case_1 ] + %array_ptr20 = extractvalue { i64*, i64 } %"08.0", 0 + %array_offset21 = extractvalue { i64*, i64 } %"08.0", 1 + %28 = icmp ult i64 0, 2 + br i1 %28, label %30, label %29 + +29: ; preds = %cond_exit_16 + br label %35 + +30: ; preds = %cond_exit_16 + %31 = add i64 0, %array_offset21 + %32 = getelementptr inbounds i64, i64* %array_ptr20, i64 %31 + %33 = load i64, i64* %32, align 4 + %34 = insertvalue { i1, i64 } { i1 true, i64 poison }, i64 %33, 1 + br label %35 + +35: ; preds = %29, %30 + %"022.0" = phi { i1, i64 } [ %34, %30 ], [ { i1 false, i64 poison }, %29 ] + %36 = extractvalue { i1, i64 } %"022.0", 0 + switch i1 %36, label %37 [ + i1 true, label %38 ] -36: ; preds = %34 +37: ; preds = %35 br label %cond_28_case_0 -37: ; preds = %34 - %38 = extractvalue { i1, i64 } %"020.0", 1 +38: ; preds = %35 + %39 = extractvalue { i1, i64 } %"022.0", 1 br label %cond_28_case_1 -cond_28_case_0: ; preds = %36 - %39 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, 0 - %40 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, 1 - %41 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.1, i32 0, i32 0), i32 %39, i8* %40) +cond_28_case_0: ; preds = %37 + %40 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, 0 + %41 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, 1 + %42 = call i32 (i8*, ...) 
@printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.1, i32 0, i32 0), i32 %40, i8* %41) call void @abort() br label %cond_exit_28 -cond_28_case_1: ; preds = %37 +cond_28_case_1: ; preds = %38 br label %cond_exit_28 cond_exit_28: ; preds = %cond_28_case_1, %cond_28_case_0 - %"023.0" = phi i64 [ 0, %cond_28_case_0 ], [ %38, %cond_28_case_1 ] - %42 = icmp ult i64 1, 2 - br i1 %42, label %46, label %43 - -43: ; preds = %cond_exit_28 - %44 = insertvalue { i1, i64, [2 x i64] } { i1 false, i64 poison, [2 x i64] poison }, i64 %"023.0", 1 - %45 = insertvalue { i1, i64, [2 x i64] } %44, [2 x i64] %"08.0", 2 - br label %55 - -46: ; preds = %cond_exit_28 - %47 = alloca i64, i32 2, align 8 - %48 = bitcast i64* %47 to [2 x i64]* - store [2 x i64] %"08.0", [2 x i64]* %48, align 4 - %49 = getelementptr inbounds i64, i64* %47, i64 1 + %"026.0" = phi i64 [ 0, %cond_28_case_0 ], [ %39, %cond_28_case_1 ] + %array_ptr36 = extractvalue { i64*, i64 } %"08.0", 0 + %array_offset37 = extractvalue { i64*, i64 } %"08.0", 1 + %43 = icmp ult i64 1, 2 + br i1 %43, label %47, label %44 + +44: ; preds = %cond_exit_28 + %45 = insertvalue { i1, { i64*, i64 }, i64 } { i1 false, { i64*, i64 } poison, i64 poison }, i64 %"026.0", 2 + %46 = insertvalue { i1, { i64*, i64 }, i64 } %45, { i64*, i64 } %"08.0", 1 + br label %53 + +47: ; preds = %cond_exit_28 + %48 = add i64 1, %array_offset37 + %49 = getelementptr inbounds i64, i64* %array_ptr36, i64 %48 %50 = load i64, i64* %49, align 4 - store i64 %"023.0", i64* %49, align 4 - %51 = bitcast i64* %47 to [2 x i64]* - %52 = load [2 x i64], [2 x i64]* %51, align 4 - %53 = insertvalue { i1, i64, [2 x i64] } { i1 true, i64 poison, [2 x i64] poison }, i64 %50, 1 - %54 = insertvalue { i1, i64, [2 x i64] } %53, [2 x i64] %52, 2 - br label %55 - -55: ; preds = %43, %46 - %"033.0" = phi { i1, i64, [2 x i64] } [ %54, %46 ], [ %45, %43 ] - %56 = extractvalue { i1, i64, [2 x i64] } %"033.0", 0 - switch i1 %56, label %57 [ - i1 true, label %60 + store i64 %"026.0", i64* %49, align 4 + %51 = insertvalue { i1, { i64*, i64 }, i64 } { i1 true, { i64*, i64 } poison, i64 poison }, i64 %50, 2 + %52 = insertvalue { i1, { i64*, i64 }, i64 } %51, { i64*, i64 } %"08.0", 1 + br label %53 + +53: ; preds = %44, %47 + %"038.0" = phi { i1, { i64*, i64 }, i64 } [ %52, %47 ], [ %46, %44 ] + %54 = extractvalue { i1, { i64*, i64 }, i64 } %"038.0", 0 + switch i1 %54, label %55 [ + i1 true, label %58 ] -57: ; preds = %55 - %58 = extractvalue { i1, i64, [2 x i64] } %"033.0", 1 - %59 = extractvalue { i1, i64, [2 x i64] } %"033.0", 2 - br label %cond_40_case_0 +55: ; preds = %53 + %56 = extractvalue { i1, { i64*, i64 }, i64 } %"038.0", 2 + %57 = extractvalue { i1, { i64*, i64 }, i64 } %"038.0", 1 + br label %cond_39_case_0 -60: ; preds = %55 - %61 = extractvalue { i1, i64, [2 x i64] } %"033.0", 1 - %62 = extractvalue { i1, i64, [2 x i64] } %"033.0", 2 - br label %cond_40_case_1 +58: ; preds = %53 + %59 = extractvalue { i1, { i64*, i64 }, i64 } %"038.0", 2 + %60 = extractvalue { i1, { i64*, i64 }, i64 } %"038.0", 1 + br label %cond_39_case_1 -cond_40_case_0: ; preds = %57 - %63 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @2, i32 0, i32 0) }, 0 - %64 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @2, i32 0, i32 0) }, 1 - %65 = call i32 (i8*, ...) 
@printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.2, i32 0, i32 0), i32 %63, i8* %64) +cond_39_case_0: ; preds = %55 + %61 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @2, i32 0, i32 0) }, 0 + %62 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @2, i32 0, i32 0) }, 1 + %63 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.2, i32 0, i32 0), i32 %61, i8* %62) call void @abort() - br label %cond_exit_40 - -cond_40_case_1: ; preds = %60 - br label %cond_exit_40 - -cond_exit_40: ; preds = %cond_40_case_1, %cond_40_case_0 - %"036.0" = phi i64 [ 0, %cond_40_case_0 ], [ %61, %cond_40_case_1 ] - %"1.0" = phi [2 x i64] [ zeroinitializer, %cond_40_case_0 ], [ %62, %cond_40_case_1 ] - %66 = alloca i64, i32 2, align 8 - %67 = bitcast i64* %66 to [2 x i64]* - store [2 x i64] %"1.0", [2 x i64]* %67, align 4 - %68 = getelementptr i64, i64* %66, i32 1 - %69 = load i64, i64* %66, align 4 - %70 = bitcast i64* %68 to [1 x i64]* - %71 = load [1 x i64], [1 x i64]* %70, align 4 - %72 = insertvalue { i1, i64, [1 x i64] } { i1 true, i64 poison, [1 x i64] poison }, i64 %69, 1 - %73 = insertvalue { i1, i64, [1 x i64] } %72, [1 x i64] %71, 2 - %74 = extractvalue { i1, i64, [1 x i64] } %73, 0 - switch i1 %74, label %75 [ - i1 true, label %76 + br label %cond_exit_39 + +cond_39_case_1: ; preds = %58 + br label %cond_exit_39 + +cond_exit_39: ; preds = %cond_39_case_1, %cond_39_case_0 + %"041.0" = phi i64 [ 0, %cond_39_case_0 ], [ %59, %cond_39_case_1 ] + %"142.0" = phi { i64*, i64 } [ zeroinitializer, %cond_39_case_0 ], [ %60, %cond_39_case_1 ] + %array_ptr61 = extractvalue { i64*, i64 } %"142.0", 0 + %array_offset62 = extractvalue { i64*, i64 } %"142.0", 1 + %new_offset = add i64 %array_offset62, 1 + %64 = getelementptr inbounds i64, i64* %array_ptr61, i64 %array_offset62 + %65 = load i64, i64* %64, align 4 + %66 = insertvalue { i64*, i64 } poison, i64* %array_ptr61, 0 + %67 = insertvalue { i64*, i64 } %66, i64 %new_offset, 1 + %68 = insertvalue { i1, { i64*, i64 }, i64 } { i1 true, { i64*, i64 } poison, i64 poison }, i64 %65, 2 + %69 = insertvalue { i1, { i64*, i64 }, i64 } %68, { i64*, i64 } %67, 1 + %70 = extractvalue { i1, { i64*, i64 }, i64 } %69, 0 + switch i1 %70, label %71 [ + i1 true, label %72 ] -75: ; preds = %cond_exit_40 - br label %cond_51_case_0 +71: ; preds = %cond_exit_39 + br label %cond_50_case_0 -76: ; preds = %cond_exit_40 - %77 = extractvalue { i1, i64, [1 x i64] } %73, 1 - %78 = extractvalue { i1, i64, [1 x i64] } %73, 2 - br label %cond_51_case_1 +72: ; preds = %cond_exit_39 + %73 = extractvalue { i1, { i64*, i64 }, i64 } %69, 2 + %74 = extractvalue { i1, { i64*, i64 }, i64 } %69, 1 + br label %cond_50_case_1 -cond_51_case_0: ; preds = %75 - %79 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @3, i32 0, i32 0) }, 0 - %80 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @3, i32 0, i32 0) }, 1 - %81 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.3, i32 0, i32 0), i32 %79, i8* %80) +cond_50_case_0: ; preds = %71 + %75 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @3, i32 0, i32 0) }, 0 + %76 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @3, i32 0, i32 0) }, 1 + %77 = call i32 (i8*, ...) 
@printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.3, i32 0, i32 0), i32 %75, i8* %76) call void @abort() - br label %cond_exit_51 - -cond_51_case_1: ; preds = %76 - br label %cond_exit_51 - -cond_exit_51: ; preds = %cond_51_case_1, %cond_51_case_0 - %"056.0" = phi i64 [ 0, %cond_51_case_0 ], [ %77, %cond_51_case_1 ] - %"157.0" = phi [1 x i64] [ zeroinitializer, %cond_51_case_0 ], [ %78, %cond_51_case_1 ] - %82 = alloca i64, align 8 - %83 = bitcast i64* %82 to [1 x i64]* - store [1 x i64] %"157.0", [1 x i64]* %83, align 4 - %84 = getelementptr i64, i64* %82, i32 0 - %85 = load i64, i64* %84, align 4 - %86 = bitcast i64* %82 to [0 x i64]* - %87 = load [0 x i64], [0 x i64]* %86, align 4 - %88 = insertvalue { i1, i64 } { i1 true, i64 poison }, i64 %85, 1 - %89 = extractvalue { i1, i64 } %88, 0 - switch i1 %89, label %90 [ - i1 true, label %91 + br label %cond_exit_50 + +cond_50_case_1: ; preds = %72 + br label %cond_exit_50 + +cond_exit_50: ; preds = %cond_50_case_1, %cond_50_case_0 + %"064.0" = phi i64 [ 0, %cond_50_case_0 ], [ %73, %cond_50_case_1 ] + %"165.0" = phi { i64*, i64 } [ zeroinitializer, %cond_50_case_0 ], [ %74, %cond_50_case_1 ] + %array_ptr78 = extractvalue { i64*, i64 } %"165.0", 0 + %array_offset79 = extractvalue { i64*, i64 } %"165.0", 1 + %78 = add i64 %array_offset79, 0 + %79 = getelementptr inbounds i64, i64* %array_ptr78, i64 %78 + %80 = load i64, i64* %79, align 4 + %81 = insertvalue { i64*, i64 } poison, i64* %array_ptr78, 0 + %82 = insertvalue { i64*, i64 } %81, i64 %array_offset79, 1 + %83 = insertvalue { i1, { i64*, i64 }, i64 } { i1 true, { i64*, i64 } poison, i64 poison }, i64 %80, 2 + %84 = insertvalue { i1, { i64*, i64 }, i64 } %83, { i64*, i64 } %82, 1 + %85 = extractvalue { i1, { i64*, i64 }, i64 } %84, 0 + switch i1 %85, label %86 [ + i1 true, label %87 ] -90: ; preds = %cond_exit_51 - br label %cond_62_case_0 +86: ; preds = %cond_exit_50 + br label %cond_61_case_0 -91: ; preds = %cond_exit_51 - %92 = extractvalue { i1, i64 } %88, 1 - br label %cond_62_case_1 +87: ; preds = %cond_exit_50 + %88 = extractvalue { i1, { i64*, i64 }, i64 } %84, 2 + %89 = extractvalue { i1, { i64*, i64 }, i64 } %84, 1 + br label %cond_61_case_1 -cond_62_case_0: ; preds = %90 - %93 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @4, i32 0, i32 0) }, 0 - %94 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @4, i32 0, i32 0) }, 1 - %95 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.4, i32 0, i32 0), i32 %93, i8* %94) +cond_61_case_0: ; preds = %86 + %90 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @4, i32 0, i32 0) }, 0 + %91 = extractvalue { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @4, i32 0, i32 0) }, 1 + %92 = call i32 (i8*, ...) 
@printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.4, i32 0, i32 0), i32 %90, i8* %91) call void @abort() - br label %cond_exit_62 - -cond_62_case_1: ; preds = %91 - br label %cond_exit_62 - -cond_exit_62: ; preds = %cond_62_case_1, %cond_62_case_0 - %"071.0" = phi i64 [ 0, %cond_62_case_0 ], [ %92, %cond_62_case_1 ] + br label %cond_exit_61 + +cond_61_case_1: ; preds = %87 + br label %cond_exit_61 + +cond_exit_61: ; preds = %cond_61_case_1, %cond_61_case_0 + %"081.0" = phi i64 [ 0, %cond_61_case_0 ], [ %88, %cond_61_case_1 ] + %"182.0" = phi { i64*, i64 } [ zeroinitializer, %cond_61_case_0 ], [ %89, %cond_61_case_1 ] + %array_ptr95 = extractvalue { i64*, i64 } %"182.0", 0 + %array_offset96 = extractvalue { i64*, i64 } %"182.0", 1 + %93 = bitcast i64* %array_ptr95 to i8* + call void @free(i8* %93) ret void } +declare i8* @malloc(i64) + declare i32 @printf(i8*, ...) declare void @abort() + +declare void @free(i8*) diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@pre-mem2reg@llvm14.snap index 9b294486d..c4b46a02a 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_all_ops@pre-mem2reg@llvm14.snap @@ -20,67 +20,69 @@ define void @_hl.main.1() { alloca_block: %"12_0" = alloca i64, align 8 %"10_0" = alloca i64, align 8 - %"13_0" = alloca [2 x i64], align 8 + %"13_0" = alloca { i64*, i64 }, align 8 %"8_0" = alloca i64, align 8 - %"14_0" = alloca { i1, [2 x i64] }, align 8 - %"0" = alloca { i1, [2 x i64] }, align 8 - %"16_0" = alloca [2 x i64], align 8 - %"08" = alloca [2 x i64], align 8 - %"010" = alloca [2 x i64], align 8 + %"14_0" = alloca { i1, { i64*, i64 } }, align 8 + %"0" = alloca { i1, { i64*, i64 } }, align 8 + %"16_0" = alloca { i64*, i64 }, align 8 + %"08" = alloca { i64*, i64 }, align 8 + %"010" = alloca { i64*, i64 }, align 8 %"21_0" = alloca { i32, i8* }, align 8 - %"18_0" = alloca [2 x i64], align 8 - %"22_0" = alloca [2 x i64], align 8 - %"015" = alloca [2 x i64], align 8 - %"24_0" = alloca [2 x i64], align 8 + %"18_0" = alloca { i64*, i64 }, align 8 + %"22_0" = alloca { i64*, i64 }, align 8 + %"015" = alloca { i64*, i64 }, align 8 + %"24_0" = alloca { i64*, i64 }, align 8 %"26_0" = alloca { i1, i64 }, align 8 - %"020" = alloca { i1, i64 }, align 8 + %"26_1" = alloca { i64*, i64 }, align 8 + %"022" = alloca { i1, i64 }, align 8 + %"1" = alloca { i64*, i64 }, align 8 %"28_0" = alloca i64, align 8 - %"023" = alloca i64, align 8 + %"026" = alloca i64, align 8 %"33_0" = alloca { i32, i8* }, align 8 %"34_0" = alloca i64, align 8 - %"027" = alloca i64, align 8 + %"030" = alloca i64, align 8 %"36_0" = alloca i64, align 8 - %"38_0" = alloca { i1, i64, [2 x i64] }, align 8 - %"033" = alloca { i1, i64, [2 x i64] }, align 8 - %"40_0" = alloca i64, align 8 - %"40_1" = alloca [2 x i64], align 8 - %"036" = alloca i64, align 8 - %"1" = alloca [2 x i64], align 8 - %"039" = alloca i64, align 8 - %"140" = alloca [2 x i64], align 8 - %"45_0" = alloca { i32, i8* }, align 8 - %"42_0" = alloca i64, align 8 - %"42_1" = alloca [2 x i64], align 8 - %"46_0" = alloca i64, align 8 - %"46_1" = alloca [2 x i64], align 8 - %"048" = alloca i64, align 8 - %"149" = alloca [2 x i64], align 8 
- %"48_0" = alloca i64, align 8 - %"48_1" = alloca [2 x i64], align 8 - %"50_0" = alloca { i1, i64, [1 x i64] }, align 8 - %"51_0" = alloca i64, align 8 - %"51_1" = alloca [1 x i64], align 8 - %"056" = alloca i64, align 8 - %"157" = alloca [1 x i64], align 8 - %"56_0" = alloca { i32, i8* }, align 8 - %"57_0" = alloca i64, align 8 - %"57_1" = alloca [1 x i64], align 8 - %"063" = alloca i64, align 8 - %"164" = alloca [1 x i64], align 8 - %"59_0" = alloca i64, align 8 - %"59_1" = alloca [1 x i64], align 8 - %"61_0" = alloca { i1, i64 }, align 8 - %"62_0" = alloca i64, align 8 - %"62_1" = alloca [0 x i64], align 8 + %"38_0" = alloca { i1, { i64*, i64 }, i64 }, align 8 + %"038" = alloca { i1, { i64*, i64 }, i64 }, align 8 + %"39_0" = alloca i64, align 8 + %"39_1" = alloca { i64*, i64 }, align 8 + %"041" = alloca i64, align 8 + %"142" = alloca { i64*, i64 }, align 8 + %"045" = alloca i64, align 8 + %"146" = alloca { i64*, i64 }, align 8 + %"44_0" = alloca { i32, i8* }, align 8 + %"41_0" = alloca i64, align 8 + %"41_1" = alloca { i64*, i64 }, align 8 + %"45_0" = alloca i64, align 8 + %"45_1" = alloca { i64*, i64 }, align 8 + %"054" = alloca i64, align 8 + %"155" = alloca { i64*, i64 }, align 8 + %"47_0" = alloca i64, align 8 + %"47_1" = alloca { i64*, i64 }, align 8 + %"49_0" = alloca { i1, { i64*, i64 }, i64 }, align 8 + %"50_0" = alloca i64, align 8 + %"50_1" = alloca { i64*, i64 }, align 8 + %"064" = alloca i64, align 8 + %"165" = alloca { i64*, i64 }, align 8 + %"55_0" = alloca { i32, i8* }, align 8 + %"56_0" = alloca i64, align 8 + %"56_1" = alloca { i64*, i64 }, align 8 %"071" = alloca i64, align 8 - %"172" = alloca [0 x i64], align 8 - %"67_0" = alloca { i32, i8* }, align 8 - %"68_0" = alloca i64, align 8 - %"68_1" = alloca [0 x i64], align 8 - %"078" = alloca i64, align 8 - %"179" = alloca [0 x i64], align 8 - %"70_0" = alloca i64, align 8 - %"70_1" = alloca [0 x i64], align 8 + %"172" = alloca { i64*, i64 }, align 8 + %"58_0" = alloca i64, align 8 + %"58_1" = alloca { i64*, i64 }, align 8 + %"60_0" = alloca { i1, { i64*, i64 }, i64 }, align 8 + %"61_0" = alloca i64, align 8 + %"61_1" = alloca { i64*, i64 }, align 8 + %"081" = alloca i64, align 8 + %"182" = alloca { i64*, i64 }, align 8 + %"66_0" = alloca { i32, i8* }, align 8 + %"67_0" = alloca i64, align 8 + %"67_1" = alloca { i64*, i64 }, align 8 + %"088" = alloca i64, align 8 + %"189" = alloca { i64*, i64 }, align 8 + %"69_0" = alloca i64, align 8 + %"69_1" = alloca { i64*, i64 }, align 8 br label %entry_block entry_block: ; preds = %alloca_block @@ -88,345 +90,362 @@ entry_block: ; preds = %alloca_block store i64 1, i64* %"10_0", align 4 %"10_01" = load i64, i64* %"10_0", align 4 %"12_02" = load i64, i64* %"12_0", align 4 - %0 = insertvalue [2 x i64] undef, i64 %"10_01", 0 - %1 = insertvalue [2 x i64] %0, i64 %"12_02", 1 - store [2 x i64] %1, [2 x i64]* %"13_0", align 4 + %0 = call i8* @malloc(i64 mul (i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 2)) + %1 = bitcast i8* %0 to i64* + %2 = insertvalue { i64*, i64 } poison, i64* %1, 0 + %3 = insertvalue { i64*, i64 } %2, i64 0, 1 + %4 = getelementptr inbounds i64, i64* %1, i64 0 + store i64 %"10_01", i64* %4, align 4 + %5 = getelementptr inbounds i64, i64* %1, i64 1 + store i64 %"12_02", i64* %5, align 4 + store { i64*, i64 } %3, { i64*, i64 }* %"13_0", align 8 store i64 0, i64* %"8_0", align 4 - %"13_03" = load [2 x i64], [2 x i64]* %"13_0", align 4 + %"13_03" = load { i64*, i64 }, { i64*, i64 }* %"13_0", align 8 %"8_04" = load i64, i64* %"8_0", align 4 %"10_05" 
= load i64, i64* %"10_0", align 4 - %2 = icmp ult i64 %"8_04", 2 - %3 = icmp ult i64 %"10_05", 2 - %4 = and i1 %2, %3 - br i1 %4, label %7, label %5 - -5: ; preds = %entry_block - %6 = insertvalue { i1, [2 x i64] } { i1 false, [2 x i64] poison }, [2 x i64] %"13_03", 1 - store { i1, [2 x i64] } %6, { i1, [2 x i64] }* %"0", align 4 - br label %17 - -7: ; preds = %entry_block - %8 = alloca i64, i32 2, align 8 - %9 = bitcast i64* %8 to [2 x i64]* - store [2 x i64] %"13_03", [2 x i64]* %9, align 4 - %10 = getelementptr inbounds i64, i64* %8, i64 %"8_04" - %11 = load i64, i64* %10, align 4 - %12 = getelementptr inbounds i64, i64* %8, i64 %"10_05" - %13 = load i64, i64* %12, align 4 - store i64 %13, i64* %10, align 4 - store i64 %11, i64* %12, align 4 - %14 = bitcast i64* %8 to [2 x i64]* - %15 = load [2 x i64], [2 x i64]* %14, align 4 - %16 = insertvalue { i1, [2 x i64] } { i1 true, [2 x i64] poison }, [2 x i64] %15, 1 - store { i1, [2 x i64] } %16, { i1, [2 x i64] }* %"0", align 4 - br label %17 - -17: ; preds = %5, %7 - %"06" = load { i1, [2 x i64] }, { i1, [2 x i64] }* %"0", align 4 - store { i1, [2 x i64] } %"06", { i1, [2 x i64] }* %"14_0", align 4 - %"14_07" = load { i1, [2 x i64] }, { i1, [2 x i64] }* %"14_0", align 4 - %18 = extractvalue { i1, [2 x i64] } %"14_07", 0 - switch i1 %18, label %19 [ - i1 true, label %21 + %array_ptr = extractvalue { i64*, i64 } %"13_03", 0 + %array_offset = extractvalue { i64*, i64 } %"13_03", 1 + %6 = icmp ult i64 %"8_04", 2 + %7 = icmp ult i64 %"10_05", 2 + %8 = and i1 %6, %7 + br i1 %8, label %11, label %9 + +9: ; preds = %entry_block + %10 = insertvalue { i1, { i64*, i64 } } { i1 false, { i64*, i64 } poison }, { i64*, i64 } %"13_03", 1 + store { i1, { i64*, i64 } } %10, { i1, { i64*, i64 } }* %"0", align 8 + br label %19 + +11: ; preds = %entry_block + %12 = add i64 %"8_04", %array_offset + %13 = add i64 %"10_05", %array_offset + %14 = getelementptr inbounds i64, i64* %array_ptr, i64 %12 + %15 = load i64, i64* %14, align 4 + %16 = getelementptr inbounds i64, i64* %array_ptr, i64 %13 + %17 = load i64, i64* %16, align 4 + store i64 %17, i64* %14, align 4 + store i64 %15, i64* %16, align 4 + %18 = insertvalue { i1, { i64*, i64 } } { i1 true, { i64*, i64 } poison }, { i64*, i64 } %"13_03", 1 + store { i1, { i64*, i64 } } %18, { i1, { i64*, i64 } }* %"0", align 8 + br label %19 + +19: ; preds = %9, %11 + %"06" = load { i1, { i64*, i64 } }, { i1, { i64*, i64 } }* %"0", align 8 + store { i1, { i64*, i64 } } %"06", { i1, { i64*, i64 } }* %"14_0", align 8 + %"14_07" = load { i1, { i64*, i64 } }, { i1, { i64*, i64 } }* %"14_0", align 8 + %20 = extractvalue { i1, { i64*, i64 } } %"14_07", 0 + switch i1 %20, label %21 [ + i1 true, label %23 ] -19: ; preds = %17 - %20 = extractvalue { i1, [2 x i64] } %"14_07", 1 - store [2 x i64] %20, [2 x i64]* %"010", align 4 +21: ; preds = %19 + %22 = extractvalue { i1, { i64*, i64 } } %"14_07", 1 + store { i64*, i64 } %22, { i64*, i64 }* %"010", align 8 br label %cond_16_case_0 -21: ; preds = %17 - %22 = extractvalue { i1, [2 x i64] } %"14_07", 1 - store [2 x i64] %22, [2 x i64]* %"015", align 4 +23: ; preds = %19 + %24 = extractvalue { i1, { i64*, i64 } } %"14_07", 1 + store { i64*, i64 } %24, { i64*, i64 }* %"015", align 8 br label %cond_16_case_1 -cond_16_case_0: ; preds = %19 - %"011" = load [2 x i64], [2 x i64]* %"010", align 4 +cond_16_case_0: ; preds = %21 + %"011" = load { i64*, i64 }, { i64*, i64 }* %"010", align 8 store { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @0, i32 0, i32 0) }, { i32, 
i8* }* %"21_0", align 8 - store [2 x i64] %"011", [2 x i64]* %"18_0", align 4 + store { i64*, i64 } %"011", { i64*, i64 }* %"18_0", align 8 %"21_012" = load { i32, i8* }, { i32, i8* }* %"21_0", align 8 - %"18_013" = load [2 x i64], [2 x i64]* %"18_0", align 4 - %23 = extractvalue { i32, i8* } %"21_012", 0 - %24 = extractvalue { i32, i8* } %"21_012", 1 - %25 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template, i32 0, i32 0), i32 %23, i8* %24) + %"18_013" = load { i64*, i64 }, { i64*, i64 }* %"18_0", align 8 + %25 = extractvalue { i32, i8* } %"21_012", 0 + %26 = extractvalue { i32, i8* } %"21_012", 1 + %27 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template, i32 0, i32 0), i32 %25, i8* %26) call void @abort() - store [2 x i64] zeroinitializer, [2 x i64]* %"22_0", align 4 - %"22_014" = load [2 x i64], [2 x i64]* %"22_0", align 4 - store [2 x i64] %"22_014", [2 x i64]* %"08", align 4 + store { i64*, i64 } zeroinitializer, { i64*, i64 }* %"22_0", align 8 + %"22_014" = load { i64*, i64 }, { i64*, i64 }* %"22_0", align 8 + store { i64*, i64 } %"22_014", { i64*, i64 }* %"08", align 8 br label %cond_exit_16 -cond_16_case_1: ; preds = %21 - %"016" = load [2 x i64], [2 x i64]* %"015", align 4 - store [2 x i64] %"016", [2 x i64]* %"24_0", align 4 - %"24_017" = load [2 x i64], [2 x i64]* %"24_0", align 4 - store [2 x i64] %"24_017", [2 x i64]* %"08", align 4 +cond_16_case_1: ; preds = %23 + %"016" = load { i64*, i64 }, { i64*, i64 }* %"015", align 8 + store { i64*, i64 } %"016", { i64*, i64 }* %"24_0", align 8 + %"24_017" = load { i64*, i64 }, { i64*, i64 }* %"24_0", align 8 + store { i64*, i64 } %"24_017", { i64*, i64 }* %"08", align 8 br label %cond_exit_16 cond_exit_16: ; preds = %cond_16_case_1, %cond_16_case_0 - %"09" = load [2 x i64], [2 x i64]* %"08", align 4 - store [2 x i64] %"09", [2 x i64]* %"16_0", align 4 - %"16_018" = load [2 x i64], [2 x i64]* %"16_0", align 4 + %"09" = load { i64*, i64 }, { i64*, i64 }* %"08", align 8 + store { i64*, i64 } %"09", { i64*, i64 }* %"16_0", align 8 + %"16_018" = load { i64*, i64 }, { i64*, i64 }* %"16_0", align 8 %"8_019" = load i64, i64* %"8_0", align 4 - %26 = icmp ult i64 %"8_019", 2 - br i1 %26, label %28, label %27 - -27: ; preds = %cond_exit_16 - store { i1, i64 } { i1 false, i64 poison }, { i1, i64 }* %"020", align 4 - br label %34 - -28: ; preds = %cond_exit_16 - %29 = alloca i64, i32 2, align 8 - %30 = bitcast i64* %29 to [2 x i64]* - store [2 x i64] %"16_018", [2 x i64]* %30, align 4 - %31 = getelementptr inbounds i64, i64* %29, i64 %"8_019" - %32 = load i64, i64* %31, align 4 - %33 = insertvalue { i1, i64 } { i1 true, i64 poison }, i64 %32, 1 - store { i1, i64 } %33, { i1, i64 }* %"020", align 4 - br label %34 - -34: ; preds = %27, %28 - %"021" = load { i1, i64 }, { i1, i64 }* %"020", align 4 - store { i1, i64 } %"021", { i1, i64 }* %"26_0", align 4 - %"26_022" = load { i1, i64 }, { i1, i64 }* %"26_0", align 4 - %35 = extractvalue { i1, i64 } %"26_022", 0 - switch i1 %35, label %36 [ - i1 true, label %37 + %array_ptr20 = extractvalue { i64*, i64 } %"16_018", 0 + %array_offset21 = extractvalue { i64*, i64 } %"16_018", 1 + %28 = icmp ult i64 %"8_019", 2 + br i1 %28, label %30, label %29 + +29: ; preds = %cond_exit_16 + store { i1, i64 } { i1 false, i64 poison }, { i1, i64 }* %"022", align 4 + store { i64*, i64 } %"16_018", { i64*, i64 }* %"1", align 8 + br label %35 + +30: ; preds = %cond_exit_16 + %31 = add i64 %"8_019", %array_offset21 + %32 = 
getelementptr inbounds i64, i64* %array_ptr20, i64 %31 + %33 = load i64, i64* %32, align 4 + %34 = insertvalue { i1, i64 } { i1 true, i64 poison }, i64 %33, 1 + store { i1, i64 } %34, { i1, i64 }* %"022", align 4 + store { i64*, i64 } %"16_018", { i64*, i64 }* %"1", align 8 + br label %35 + +35: ; preds = %29, %30 + %"023" = load { i1, i64 }, { i1, i64 }* %"022", align 4 + %"124" = load { i64*, i64 }, { i64*, i64 }* %"1", align 8 + store { i1, i64 } %"023", { i1, i64 }* %"26_0", align 4 + store { i64*, i64 } %"124", { i64*, i64 }* %"26_1", align 8 + %"26_025" = load { i1, i64 }, { i1, i64 }* %"26_0", align 4 + %36 = extractvalue { i1, i64 } %"26_025", 0 + switch i1 %36, label %37 [ + i1 true, label %38 ] -36: ; preds = %34 +37: ; preds = %35 br label %cond_28_case_0 -37: ; preds = %34 - %38 = extractvalue { i1, i64 } %"26_022", 1 - store i64 %38, i64* %"027", align 4 +38: ; preds = %35 + %39 = extractvalue { i1, i64 } %"26_025", 1 + store i64 %39, i64* %"030", align 4 br label %cond_28_case_1 -cond_28_case_0: ; preds = %36 +cond_28_case_0: ; preds = %37 store { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @1, i32 0, i32 0) }, { i32, i8* }* %"33_0", align 8 - %"33_025" = load { i32, i8* }, { i32, i8* }* %"33_0", align 8 - %39 = extractvalue { i32, i8* } %"33_025", 0 - %40 = extractvalue { i32, i8* } %"33_025", 1 - %41 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.1, i32 0, i32 0), i32 %39, i8* %40) + %"33_028" = load { i32, i8* }, { i32, i8* }* %"33_0", align 8 + %40 = extractvalue { i32, i8* } %"33_028", 0 + %41 = extractvalue { i32, i8* } %"33_028", 1 + %42 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.1, i32 0, i32 0), i32 %40, i8* %41) call void @abort() store i64 0, i64* %"34_0", align 4 - %"34_026" = load i64, i64* %"34_0", align 4 - store i64 %"34_026", i64* %"023", align 4 + %"34_029" = load i64, i64* %"34_0", align 4 + store i64 %"34_029", i64* %"026", align 4 br label %cond_exit_28 -cond_28_case_1: ; preds = %37 - %"028" = load i64, i64* %"027", align 4 - store i64 %"028", i64* %"36_0", align 4 - %"36_029" = load i64, i64* %"36_0", align 4 - store i64 %"36_029", i64* %"023", align 4 +cond_28_case_1: ; preds = %38 + %"031" = load i64, i64* %"030", align 4 + store i64 %"031", i64* %"36_0", align 4 + %"36_032" = load i64, i64* %"36_0", align 4 + store i64 %"36_032", i64* %"026", align 4 br label %cond_exit_28 cond_exit_28: ; preds = %cond_28_case_1, %cond_28_case_0 - %"024" = load i64, i64* %"023", align 4 - store i64 %"024", i64* %"28_0", align 4 - %"16_030" = load [2 x i64], [2 x i64]* %"16_0", align 4 - %"10_031" = load i64, i64* %"10_0", align 4 - %"28_032" = load i64, i64* %"28_0", align 4 - %42 = icmp ult i64 %"10_031", 2 - br i1 %42, label %46, label %43 - -43: ; preds = %cond_exit_28 - %44 = insertvalue { i1, i64, [2 x i64] } { i1 false, i64 poison, [2 x i64] poison }, i64 %"28_032", 1 - %45 = insertvalue { i1, i64, [2 x i64] } %44, [2 x i64] %"16_030", 2 - store { i1, i64, [2 x i64] } %45, { i1, i64, [2 x i64] }* %"033", align 4 - br label %55 - -46: ; preds = %cond_exit_28 - %47 = alloca i64, i32 2, align 8 - %48 = bitcast i64* %47 to [2 x i64]* - store [2 x i64] %"16_030", [2 x i64]* %48, align 4 - %49 = getelementptr inbounds i64, i64* %47, i64 %"10_031" + %"027" = load i64, i64* %"026", align 4 + store i64 %"027", i64* %"28_0", align 4 + %"26_133" = load { i64*, i64 }, { i64*, i64 }* %"26_1", align 8 + %"10_034" = load i64, i64* 
%"10_0", align 4 + %"28_035" = load i64, i64* %"28_0", align 4 + %array_ptr36 = extractvalue { i64*, i64 } %"26_133", 0 + %array_offset37 = extractvalue { i64*, i64 } %"26_133", 1 + %43 = icmp ult i64 %"10_034", 2 + br i1 %43, label %47, label %44 + +44: ; preds = %cond_exit_28 + %45 = insertvalue { i1, { i64*, i64 }, i64 } { i1 false, { i64*, i64 } poison, i64 poison }, i64 %"28_035", 2 + %46 = insertvalue { i1, { i64*, i64 }, i64 } %45, { i64*, i64 } %"26_133", 1 + store { i1, { i64*, i64 }, i64 } %46, { i1, { i64*, i64 }, i64 }* %"038", align 8 + br label %53 + +47: ; preds = %cond_exit_28 + %48 = add i64 %"10_034", %array_offset37 + %49 = getelementptr inbounds i64, i64* %array_ptr36, i64 %48 %50 = load i64, i64* %49, align 4 - store i64 %"28_032", i64* %49, align 4 - %51 = bitcast i64* %47 to [2 x i64]* - %52 = load [2 x i64], [2 x i64]* %51, align 4 - %53 = insertvalue { i1, i64, [2 x i64] } { i1 true, i64 poison, [2 x i64] poison }, i64 %50, 1 - %54 = insertvalue { i1, i64, [2 x i64] } %53, [2 x i64] %52, 2 - store { i1, i64, [2 x i64] } %54, { i1, i64, [2 x i64] }* %"033", align 4 - br label %55 - -55: ; preds = %43, %46 - %"034" = load { i1, i64, [2 x i64] }, { i1, i64, [2 x i64] }* %"033", align 4 - store { i1, i64, [2 x i64] } %"034", { i1, i64, [2 x i64] }* %"38_0", align 4 - %"38_035" = load { i1, i64, [2 x i64] }, { i1, i64, [2 x i64] }* %"38_0", align 4 - %56 = extractvalue { i1, i64, [2 x i64] } %"38_035", 0 - switch i1 %56, label %57 [ - i1 true, label %60 + store i64 %"28_035", i64* %49, align 4 + %51 = insertvalue { i1, { i64*, i64 }, i64 } { i1 true, { i64*, i64 } poison, i64 poison }, i64 %50, 2 + %52 = insertvalue { i1, { i64*, i64 }, i64 } %51, { i64*, i64 } %"26_133", 1 + store { i1, { i64*, i64 }, i64 } %52, { i1, { i64*, i64 }, i64 }* %"038", align 8 + br label %53 + +53: ; preds = %44, %47 + %"039" = load { i1, { i64*, i64 }, i64 }, { i1, { i64*, i64 }, i64 }* %"038", align 8 + store { i1, { i64*, i64 }, i64 } %"039", { i1, { i64*, i64 }, i64 }* %"38_0", align 8 + %"38_040" = load { i1, { i64*, i64 }, i64 }, { i1, { i64*, i64 }, i64 }* %"38_0", align 8 + %54 = extractvalue { i1, { i64*, i64 }, i64 } %"38_040", 0 + switch i1 %54, label %55 [ + i1 true, label %58 ] -57: ; preds = %55 - %58 = extractvalue { i1, i64, [2 x i64] } %"38_035", 1 - %59 = extractvalue { i1, i64, [2 x i64] } %"38_035", 2 - store i64 %58, i64* %"039", align 4 - store [2 x i64] %59, [2 x i64]* %"140", align 4 - br label %cond_40_case_0 - -60: ; preds = %55 - %61 = extractvalue { i1, i64, [2 x i64] } %"38_035", 1 - %62 = extractvalue { i1, i64, [2 x i64] } %"38_035", 2 - store i64 %61, i64* %"048", align 4 - store [2 x i64] %62, [2 x i64]* %"149", align 4 - br label %cond_40_case_1 - -cond_40_case_0: ; preds = %57 - %"041" = load i64, i64* %"039", align 4 - %"142" = load [2 x i64], [2 x i64]* %"140", align 4 - store { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @2, i32 0, i32 0) }, { i32, i8* }* %"45_0", align 8 - store i64 %"041", i64* %"42_0", align 4 - store [2 x i64] %"142", [2 x i64]* %"42_1", align 4 - %"45_043" = load { i32, i8* }, { i32, i8* }* %"45_0", align 8 - %"42_044" = load i64, i64* %"42_0", align 4 - %"42_145" = load [2 x i64], [2 x i64]* %"42_1", align 4 - %63 = extractvalue { i32, i8* } %"45_043", 0 - %64 = extractvalue { i32, i8* } %"45_043", 1 - %65 = call i32 (i8*, ...) 
@printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.2, i32 0, i32 0), i32 %63, i8* %64) +55: ; preds = %53 + %56 = extractvalue { i1, { i64*, i64 }, i64 } %"38_040", 2 + %57 = extractvalue { i1, { i64*, i64 }, i64 } %"38_040", 1 + store i64 %56, i64* %"045", align 4 + store { i64*, i64 } %57, { i64*, i64 }* %"146", align 8 + br label %cond_39_case_0 + +58: ; preds = %53 + %59 = extractvalue { i1, { i64*, i64 }, i64 } %"38_040", 2 + %60 = extractvalue { i1, { i64*, i64 }, i64 } %"38_040", 1 + store i64 %59, i64* %"054", align 4 + store { i64*, i64 } %60, { i64*, i64 }* %"155", align 8 + br label %cond_39_case_1 + +cond_39_case_0: ; preds = %55 + %"047" = load i64, i64* %"045", align 4 + %"148" = load { i64*, i64 }, { i64*, i64 }* %"146", align 8 + store { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @2, i32 0, i32 0) }, { i32, i8* }* %"44_0", align 8 + store i64 %"047", i64* %"41_0", align 4 + store { i64*, i64 } %"148", { i64*, i64 }* %"41_1", align 8 + %"44_049" = load { i32, i8* }, { i32, i8* }* %"44_0", align 8 + %"41_050" = load i64, i64* %"41_0", align 4 + %"41_151" = load { i64*, i64 }, { i64*, i64 }* %"41_1", align 8 + %61 = extractvalue { i32, i8* } %"44_049", 0 + %62 = extractvalue { i32, i8* } %"44_049", 1 + %63 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.2, i32 0, i32 0), i32 %61, i8* %62) call void @abort() - store i64 0, i64* %"46_0", align 4 - store [2 x i64] zeroinitializer, [2 x i64]* %"46_1", align 4 - %"46_046" = load i64, i64* %"46_0", align 4 - %"46_147" = load [2 x i64], [2 x i64]* %"46_1", align 4 - store i64 %"46_046", i64* %"036", align 4 - store [2 x i64] %"46_147", [2 x i64]* %"1", align 4 - br label %cond_exit_40 - -cond_40_case_1: ; preds = %60 - %"050" = load i64, i64* %"048", align 4 - %"151" = load [2 x i64], [2 x i64]* %"149", align 4 - store i64 %"050", i64* %"48_0", align 4 - store [2 x i64] %"151", [2 x i64]* %"48_1", align 4 - %"48_052" = load i64, i64* %"48_0", align 4 - %"48_153" = load [2 x i64], [2 x i64]* %"48_1", align 4 - store i64 %"48_052", i64* %"036", align 4 - store [2 x i64] %"48_153", [2 x i64]* %"1", align 4 - br label %cond_exit_40 - -cond_exit_40: ; preds = %cond_40_case_1, %cond_40_case_0 - %"037" = load i64, i64* %"036", align 4 - %"138" = load [2 x i64], [2 x i64]* %"1", align 4 - store i64 %"037", i64* %"40_0", align 4 - store [2 x i64] %"138", [2 x i64]* %"40_1", align 4 - %"40_154" = load [2 x i64], [2 x i64]* %"40_1", align 4 - %66 = alloca i64, i32 2, align 8 - %67 = bitcast i64* %66 to [2 x i64]* - store [2 x i64] %"40_154", [2 x i64]* %67, align 4 - %68 = getelementptr i64, i64* %66, i32 1 - %69 = load i64, i64* %66, align 4 - %70 = bitcast i64* %68 to [1 x i64]* - %71 = load [1 x i64], [1 x i64]* %70, align 4 - %72 = insertvalue { i1, i64, [1 x i64] } { i1 true, i64 poison, [1 x i64] poison }, i64 %69, 1 - %73 = insertvalue { i1, i64, [1 x i64] } %72, [1 x i64] %71, 2 - store { i1, i64, [1 x i64] } %73, { i1, i64, [1 x i64] }* %"50_0", align 4 - %"50_055" = load { i1, i64, [1 x i64] }, { i1, i64, [1 x i64] }* %"50_0", align 4 - %74 = extractvalue { i1, i64, [1 x i64] } %"50_055", 0 - switch i1 %74, label %75 [ - i1 true, label %76 + store i64 0, i64* %"45_0", align 4 + store { i64*, i64 } zeroinitializer, { i64*, i64 }* %"45_1", align 8 + %"45_052" = load i64, i64* %"45_0", align 4 + %"45_153" = load { i64*, i64 }, { i64*, i64 }* %"45_1", align 8 + store i64 %"45_052", i64* %"041", align 4 + store { i64*, i64 } 
%"45_153", { i64*, i64 }* %"142", align 8 + br label %cond_exit_39 + +cond_39_case_1: ; preds = %58 + %"056" = load i64, i64* %"054", align 4 + %"157" = load { i64*, i64 }, { i64*, i64 }* %"155", align 8 + store i64 %"056", i64* %"47_0", align 4 + store { i64*, i64 } %"157", { i64*, i64 }* %"47_1", align 8 + %"47_058" = load i64, i64* %"47_0", align 4 + %"47_159" = load { i64*, i64 }, { i64*, i64 }* %"47_1", align 8 + store i64 %"47_058", i64* %"041", align 4 + store { i64*, i64 } %"47_159", { i64*, i64 }* %"142", align 8 + br label %cond_exit_39 + +cond_exit_39: ; preds = %cond_39_case_1, %cond_39_case_0 + %"043" = load i64, i64* %"041", align 4 + %"144" = load { i64*, i64 }, { i64*, i64 }* %"142", align 8 + store i64 %"043", i64* %"39_0", align 4 + store { i64*, i64 } %"144", { i64*, i64 }* %"39_1", align 8 + %"39_160" = load { i64*, i64 }, { i64*, i64 }* %"39_1", align 8 + %array_ptr61 = extractvalue { i64*, i64 } %"39_160", 0 + %array_offset62 = extractvalue { i64*, i64 } %"39_160", 1 + %new_offset = add i64 %array_offset62, 1 + %64 = getelementptr inbounds i64, i64* %array_ptr61, i64 %array_offset62 + %65 = load i64, i64* %64, align 4 + %66 = insertvalue { i64*, i64 } poison, i64* %array_ptr61, 0 + %67 = insertvalue { i64*, i64 } %66, i64 %new_offset, 1 + %68 = insertvalue { i1, { i64*, i64 }, i64 } { i1 true, { i64*, i64 } poison, i64 poison }, i64 %65, 2 + %69 = insertvalue { i1, { i64*, i64 }, i64 } %68, { i64*, i64 } %67, 1 + store { i1, { i64*, i64 }, i64 } %69, { i1, { i64*, i64 }, i64 }* %"49_0", align 8 + %"49_063" = load { i1, { i64*, i64 }, i64 }, { i1, { i64*, i64 }, i64 }* %"49_0", align 8 + %70 = extractvalue { i1, { i64*, i64 }, i64 } %"49_063", 0 + switch i1 %70, label %71 [ + i1 true, label %72 ] -75: ; preds = %cond_exit_40 - br label %cond_51_case_0 - -76: ; preds = %cond_exit_40 - %77 = extractvalue { i1, i64, [1 x i64] } %"50_055", 1 - %78 = extractvalue { i1, i64, [1 x i64] } %"50_055", 2 - store i64 %77, i64* %"063", align 4 - store [1 x i64] %78, [1 x i64]* %"164", align 4 - br label %cond_51_case_1 - -cond_51_case_0: ; preds = %75 - store { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @3, i32 0, i32 0) }, { i32, i8* }* %"56_0", align 8 - %"56_060" = load { i32, i8* }, { i32, i8* }* %"56_0", align 8 - %79 = extractvalue { i32, i8* } %"56_060", 0 - %80 = extractvalue { i32, i8* } %"56_060", 1 - %81 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.3, i32 0, i32 0), i32 %79, i8* %80) +71: ; preds = %cond_exit_39 + br label %cond_50_case_0 + +72: ; preds = %cond_exit_39 + %73 = extractvalue { i1, { i64*, i64 }, i64 } %"49_063", 2 + %74 = extractvalue { i1, { i64*, i64 }, i64 } %"49_063", 1 + store i64 %73, i64* %"071", align 4 + store { i64*, i64 } %74, { i64*, i64 }* %"172", align 8 + br label %cond_50_case_1 + +cond_50_case_0: ; preds = %71 + store { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @3, i32 0, i32 0) }, { i32, i8* }* %"55_0", align 8 + %"55_068" = load { i32, i8* }, { i32, i8* }* %"55_0", align 8 + %75 = extractvalue { i32, i8* } %"55_068", 0 + %76 = extractvalue { i32, i8* } %"55_068", 1 + %77 = call i32 (i8*, ...) 
@printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.3, i32 0, i32 0), i32 %75, i8* %76) call void @abort() - store i64 0, i64* %"57_0", align 4 - store [1 x i64] zeroinitializer, [1 x i64]* %"57_1", align 4 - %"57_061" = load i64, i64* %"57_0", align 4 - %"57_162" = load [1 x i64], [1 x i64]* %"57_1", align 4 - store i64 %"57_061", i64* %"056", align 4 - store [1 x i64] %"57_162", [1 x i64]* %"157", align 4 - br label %cond_exit_51 - -cond_51_case_1: ; preds = %76 - %"065" = load i64, i64* %"063", align 4 - %"166" = load [1 x i64], [1 x i64]* %"164", align 4 - store i64 %"065", i64* %"59_0", align 4 - store [1 x i64] %"166", [1 x i64]* %"59_1", align 4 - %"59_067" = load i64, i64* %"59_0", align 4 - %"59_168" = load [1 x i64], [1 x i64]* %"59_1", align 4 - store i64 %"59_067", i64* %"056", align 4 - store [1 x i64] %"59_168", [1 x i64]* %"157", align 4 - br label %cond_exit_51 - -cond_exit_51: ; preds = %cond_51_case_1, %cond_51_case_0 - %"058" = load i64, i64* %"056", align 4 - %"159" = load [1 x i64], [1 x i64]* %"157", align 4 - store i64 %"058", i64* %"51_0", align 4 - store [1 x i64] %"159", [1 x i64]* %"51_1", align 4 - %"51_169" = load [1 x i64], [1 x i64]* %"51_1", align 4 - %82 = alloca i64, align 8 - %83 = bitcast i64* %82 to [1 x i64]* - store [1 x i64] %"51_169", [1 x i64]* %83, align 4 - %84 = getelementptr i64, i64* %82, i32 0 - %85 = load i64, i64* %84, align 4 - %86 = bitcast i64* %82 to [0 x i64]* - %87 = load [0 x i64], [0 x i64]* %86, align 4 - %88 = insertvalue { i1, i64 } { i1 true, i64 poison }, i64 %85, 1 - store { i1, i64 } %88, { i1, i64 }* %"61_0", align 4 - %"61_070" = load { i1, i64 }, { i1, i64 }* %"61_0", align 4 - %89 = extractvalue { i1, i64 } %"61_070", 0 - switch i1 %89, label %90 [ - i1 true, label %91 + store i64 0, i64* %"56_0", align 4 + store { i64*, i64 } zeroinitializer, { i64*, i64 }* %"56_1", align 8 + %"56_069" = load i64, i64* %"56_0", align 4 + %"56_170" = load { i64*, i64 }, { i64*, i64 }* %"56_1", align 8 + store i64 %"56_069", i64* %"064", align 4 + store { i64*, i64 } %"56_170", { i64*, i64 }* %"165", align 8 + br label %cond_exit_50 + +cond_50_case_1: ; preds = %72 + %"073" = load i64, i64* %"071", align 4 + %"174" = load { i64*, i64 }, { i64*, i64 }* %"172", align 8 + store i64 %"073", i64* %"58_0", align 4 + store { i64*, i64 } %"174", { i64*, i64 }* %"58_1", align 8 + %"58_075" = load i64, i64* %"58_0", align 4 + %"58_176" = load { i64*, i64 }, { i64*, i64 }* %"58_1", align 8 + store i64 %"58_075", i64* %"064", align 4 + store { i64*, i64 } %"58_176", { i64*, i64 }* %"165", align 8 + br label %cond_exit_50 + +cond_exit_50: ; preds = %cond_50_case_1, %cond_50_case_0 + %"066" = load i64, i64* %"064", align 4 + %"167" = load { i64*, i64 }, { i64*, i64 }* %"165", align 8 + store i64 %"066", i64* %"50_0", align 4 + store { i64*, i64 } %"167", { i64*, i64 }* %"50_1", align 8 + %"50_177" = load { i64*, i64 }, { i64*, i64 }* %"50_1", align 8 + %array_ptr78 = extractvalue { i64*, i64 } %"50_177", 0 + %array_offset79 = extractvalue { i64*, i64 } %"50_177", 1 + %78 = add i64 %array_offset79, 0 + %79 = getelementptr inbounds i64, i64* %array_ptr78, i64 %78 + %80 = load i64, i64* %79, align 4 + %81 = insertvalue { i64*, i64 } poison, i64* %array_ptr78, 0 + %82 = insertvalue { i64*, i64 } %81, i64 %array_offset79, 1 + %83 = insertvalue { i1, { i64*, i64 }, i64 } { i1 true, { i64*, i64 } poison, i64 poison }, i64 %80, 2 + %84 = insertvalue { i1, { i64*, i64 }, i64 } %83, { i64*, i64 } %82, 1 + store { i1, { i64*, i64 }, 
i64 } %84, { i1, { i64*, i64 }, i64 }* %"60_0", align 8 + %"60_080" = load { i1, { i64*, i64 }, i64 }, { i1, { i64*, i64 }, i64 }* %"60_0", align 8 + %85 = extractvalue { i1, { i64*, i64 }, i64 } %"60_080", 0 + switch i1 %85, label %86 [ + i1 true, label %87 ] -90: ; preds = %cond_exit_51 - br label %cond_62_case_0 - -91: ; preds = %cond_exit_51 - %92 = extractvalue { i1, i64 } %"61_070", 1 - store i64 %92, i64* %"078", align 4 - store [0 x i64] undef, [0 x i64]* %"179", align 4 - br label %cond_62_case_1 - -cond_62_case_0: ; preds = %90 - store { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @4, i32 0, i32 0) }, { i32, i8* }* %"67_0", align 8 - %"67_075" = load { i32, i8* }, { i32, i8* }* %"67_0", align 8 - %93 = extractvalue { i32, i8* } %"67_075", 0 - %94 = extractvalue { i32, i8* } %"67_075", 1 - %95 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.4, i32 0, i32 0), i32 %93, i8* %94) +86: ; preds = %cond_exit_50 + br label %cond_61_case_0 + +87: ; preds = %cond_exit_50 + %88 = extractvalue { i1, { i64*, i64 }, i64 } %"60_080", 2 + %89 = extractvalue { i1, { i64*, i64 }, i64 } %"60_080", 1 + store i64 %88, i64* %"088", align 4 + store { i64*, i64 } %89, { i64*, i64 }* %"189", align 8 + br label %cond_61_case_1 + +cond_61_case_0: ; preds = %86 + store { i32, i8* } { i32 1, i8* getelementptr inbounds ([37 x i8], [37 x i8]* @4, i32 0, i32 0) }, { i32, i8* }* %"66_0", align 8 + %"66_085" = load { i32, i8* }, { i32, i8* }* %"66_0", align 8 + %90 = extractvalue { i32, i8* } %"66_085", 0 + %91 = extractvalue { i32, i8* } %"66_085", 1 + %92 = call i32 (i8*, ...) @printf(i8* getelementptr inbounds ([34 x i8], [34 x i8]* @prelude.panic_template.4, i32 0, i32 0), i32 %90, i8* %91) call void @abort() - store i64 0, i64* %"68_0", align 4 - store [0 x i64] zeroinitializer, [0 x i64]* %"68_1", align 4 - %"68_076" = load i64, i64* %"68_0", align 4 - %"68_177" = load [0 x i64], [0 x i64]* %"68_1", align 4 - store i64 %"68_076", i64* %"071", align 4 - store [0 x i64] %"68_177", [0 x i64]* %"172", align 4 - br label %cond_exit_62 - -cond_62_case_1: ; preds = %91 - %"080" = load i64, i64* %"078", align 4 - %"181" = load [0 x i64], [0 x i64]* %"179", align 4 - store i64 %"080", i64* %"70_0", align 4 - store [0 x i64] %"181", [0 x i64]* %"70_1", align 4 - %"70_082" = load i64, i64* %"70_0", align 4 - %"70_183" = load [0 x i64], [0 x i64]* %"70_1", align 4 - store i64 %"70_082", i64* %"071", align 4 - store [0 x i64] %"70_183", [0 x i64]* %"172", align 4 - br label %cond_exit_62 - -cond_exit_62: ; preds = %cond_62_case_1, %cond_62_case_0 - %"073" = load i64, i64* %"071", align 4 - %"174" = load [0 x i64], [0 x i64]* %"172", align 4 - store i64 %"073", i64* %"62_0", align 4 - store [0 x i64] %"174", [0 x i64]* %"62_1", align 4 - %"62_184" = load [0 x i64], [0 x i64]* %"62_1", align 4 + store i64 0, i64* %"67_0", align 4 + store { i64*, i64 } zeroinitializer, { i64*, i64 }* %"67_1", align 8 + %"67_086" = load i64, i64* %"67_0", align 4 + %"67_187" = load { i64*, i64 }, { i64*, i64 }* %"67_1", align 8 + store i64 %"67_086", i64* %"081", align 4 + store { i64*, i64 } %"67_187", { i64*, i64 }* %"182", align 8 + br label %cond_exit_61 + +cond_61_case_1: ; preds = %87 + %"090" = load i64, i64* %"088", align 4 + %"191" = load { i64*, i64 }, { i64*, i64 }* %"189", align 8 + store i64 %"090", i64* %"69_0", align 4 + store { i64*, i64 } %"191", { i64*, i64 }* %"69_1", align 8 + %"69_092" = load i64, i64* %"69_0", align 4 + %"69_193" = load { 
i64*, i64 }, { i64*, i64 }* %"69_1", align 8 + store i64 %"69_092", i64* %"081", align 4 + store { i64*, i64 } %"69_193", { i64*, i64 }* %"182", align 8 + br label %cond_exit_61 + +cond_exit_61: ; preds = %cond_61_case_1, %cond_61_case_0 + %"083" = load i64, i64* %"081", align 4 + %"184" = load { i64*, i64 }, { i64*, i64 }* %"182", align 8 + store i64 %"083", i64* %"61_0", align 4 + store { i64*, i64 } %"184", { i64*, i64 }* %"61_1", align 8 + %"61_194" = load { i64*, i64 }, { i64*, i64 }* %"61_1", align 8 + %array_ptr95 = extractvalue { i64*, i64 } %"61_194", 0 + %array_offset96 = extractvalue { i64*, i64 } %"61_194", 1 + %93 = bitcast i64* %array_ptr95 to i8* + call void @free(i8* %93) ret void } +declare i8* @malloc(i64) + declare i32 @printf(i8*, ...) declare void @abort() + +declare void @free(i8*) diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@llvm14.snap index 3a718f7f2..0a84535ed 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@llvm14.snap @@ -5,10 +5,20 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -define [2 x i64] @_hl.main.1() { +define { i64*, i64 } @_hl.main.1() { alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - ret [2 x i64] [i64 1, i64 2] + %0 = call i8* @malloc(i64 mul (i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 2)) + %1 = bitcast i8* %0 to i64* + %2 = insertvalue { i64*, i64 } poison, i64* %1, 0 + %3 = insertvalue { i64*, i64 } %2, i64 0, 1 + %4 = getelementptr inbounds i64, i64* %1, i32 0 + store i64 1, i64* %4, align 4 + %5 = getelementptr inbounds i64, i64* %1, i32 1 + store i64 2, i64* %5, align 4 + ret { i64*, i64 } %3 } + +declare i8* @malloc(i64) diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@pre-mem2reg@llvm14.snap index 5befaf3df..a908cbb3d 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_array_value@pre-mem2reg@llvm14.snap @@ -5,16 +5,26 @@ expression: mod_str ; ModuleID = 'test_context' source_filename = "test_context" -define [2 x i64] @_hl.main.1() { +define { i64*, i64 } @_hl.main.1() { alloca_block: - %"0" = alloca [2 x i64], align 8 - %"5_0" = alloca [2 x i64], align 8 + %"0" = alloca { i64*, i64 }, align 8 + %"5_0" = alloca { i64*, i64 }, align 8 br label %entry_block entry_block: ; preds = %alloca_block - store [2 x i64] [i64 1, i64 2], [2 x i64]* %"5_0", align 4 - %"5_01" = load [2 x i64], [2 x i64]* %"5_0", align 4 - store [2 x i64] %"5_01", [2 x i64]* %"0", align 4 - %"02" = load [2 x i64], [2 x i64]* %"0", align 4 - ret [2 x i64] %"02" + %0 = call i8* @malloc(i64 mul (i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 2)) + %1 = bitcast i8* %0 to i64* + %2 = insertvalue { i64*, i64 } poison, i64* %1, 0 + %3 = insertvalue { i64*, i64 } 
%2, i64 0, 1 + %4 = getelementptr inbounds i64, i64* %1, i32 0 + store i64 1, i64* %4, align 4 + %5 = getelementptr inbounds i64, i64* %1, i32 1 + store i64 2, i64* %5, align 4 + store { i64*, i64 } %3, { i64*, i64 }* %"5_0", align 8 + %"5_01" = load { i64*, i64 }, { i64*, i64 }* %"5_0", align 8 + store { i64*, i64 } %"5_01", { i64*, i64 }* %"0", align 8 + %"02" = load { i64*, i64 }, { i64*, i64 }* %"0", align 8 + ret { i64*, i64 } %"02" } + +declare i8* @malloc(i64) diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_clone@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_clone@llvm14.snap new file mode 100644 index 000000000..d03fdb1c3 --- /dev/null +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_clone@llvm14.snap @@ -0,0 +1,45 @@ +--- +source: hugr-llvm/src/extension/collections/array.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define void @_hl.main.1() { +alloca_block: + br label %entry_block + +entry_block: ; preds = %alloca_block + %0 = call i8* @malloc(i64 mul (i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 2)) + %1 = bitcast i8* %0 to i64* + %2 = insertvalue { i64*, i64 } poison, i64* %1, 0 + %3 = insertvalue { i64*, i64 } %2, i64 0, 1 + %4 = getelementptr inbounds i64, i64* %1, i64 0 + store i64 1, i64* %4, align 4 + %5 = getelementptr inbounds i64, i64* %1, i64 1 + store i64 2, i64* %5, align 4 + %array_ptr = extractvalue { i64*, i64 } %3, 0 + %array_offset = extractvalue { i64*, i64 } %3, 1 + %6 = call i8* @malloc(i64 mul (i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 2)) + %7 = bitcast i8* %6 to i64* + %8 = insertvalue { i64*, i64 } poison, i64* %7, 0 + %9 = insertvalue { i64*, i64 } %8, i64 0, 1 + %10 = getelementptr inbounds i64, i64* %array_ptr, i64 %array_offset + call void @llvm.memcpy.p0i64.p0i64.i64(i64* %7, i64* %10, i64 mul (i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 2), i1 false) + %array_ptr5 = extractvalue { i64*, i64 } %9, 0 + %11 = bitcast i64* %array_ptr5 to i8* + call void @free(i8* %11) + %array_ptr7 = extractvalue { i64*, i64 } %3, 0 + %12 = bitcast i64* %array_ptr7 to i8* + call void @free(i8* %12) + ret void +} + +declare i8* @malloc(i64) + +; Function Attrs: argmemonly nofree nounwind willreturn +declare void @llvm.memcpy.p0i64.p0i64.i64(i64* noalias nocapture writeonly, i64* noalias nocapture readonly, i64, i1 immarg) #0 + +declare void @free(i8*) + +attributes #0 = { argmemonly nofree nounwind willreturn } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_clone@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_clone@pre-mem2reg@llvm14.snap new file mode 100644 index 000000000..3fd1b276f --- /dev/null +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_clone@pre-mem2reg@llvm14.snap @@ -0,0 +1,60 @@ +--- +source: hugr-llvm/src/extension/collections/array.rs +expression: mod_str +--- +; ModuleID = 'test_context' +source_filename = "test_context" + +define void @_hl.main.1() { +alloca_block: + %"7_0" = alloca i64, align 8 + %"5_0" = alloca i64, align 8 + %"8_0" = alloca { i64*, i64 }, align 8 + %"9_0" = alloca { i64*, i64 }, align 8 + %"9_1" = alloca { i64*, i64 }, align 8 
+ br label %entry_block + +entry_block: ; preds = %alloca_block + store i64 2, i64* %"7_0", align 4 + store i64 1, i64* %"5_0", align 4 + %"5_01" = load i64, i64* %"5_0", align 4 + %"7_02" = load i64, i64* %"7_0", align 4 + %0 = call i8* @malloc(i64 mul (i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 2)) + %1 = bitcast i8* %0 to i64* + %2 = insertvalue { i64*, i64 } poison, i64* %1, 0 + %3 = insertvalue { i64*, i64 } %2, i64 0, 1 + %4 = getelementptr inbounds i64, i64* %1, i64 0 + store i64 %"5_01", i64* %4, align 4 + %5 = getelementptr inbounds i64, i64* %1, i64 1 + store i64 %"7_02", i64* %5, align 4 + store { i64*, i64 } %3, { i64*, i64 }* %"8_0", align 8 + %"8_03" = load { i64*, i64 }, { i64*, i64 }* %"8_0", align 8 + %array_ptr = extractvalue { i64*, i64 } %"8_03", 0 + %array_offset = extractvalue { i64*, i64 } %"8_03", 1 + %6 = call i8* @malloc(i64 mul (i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 2)) + %7 = bitcast i8* %6 to i64* + %8 = insertvalue { i64*, i64 } poison, i64* %7, 0 + %9 = insertvalue { i64*, i64 } %8, i64 0, 1 + %10 = getelementptr inbounds i64, i64* %array_ptr, i64 %array_offset + call void @llvm.memcpy.p0i64.p0i64.i64(i64* %7, i64* %10, i64 mul (i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 2), i1 false) + store { i64*, i64 } %"8_03", { i64*, i64 }* %"9_0", align 8 + store { i64*, i64 } %9, { i64*, i64 }* %"9_1", align 8 + %"9_14" = load { i64*, i64 }, { i64*, i64 }* %"9_1", align 8 + %array_ptr5 = extractvalue { i64*, i64 } %"9_14", 0 + %11 = bitcast i64* %array_ptr5 to i8* + call void @free(i8* %11) + %"9_06" = load { i64*, i64 }, { i64*, i64 }* %"9_0", align 8 + %array_ptr7 = extractvalue { i64*, i64 } %"9_06", 0 + %12 = bitcast i64* %array_ptr7 to i8* + call void @free(i8* %12) + ret void +} + +declare i8* @malloc(i64) + +; Function Attrs: argmemonly nofree nounwind willreturn +declare void @llvm.memcpy.p0i64.p0i64.i64(i64* noalias nocapture writeonly, i64* noalias nocapture readonly, i64, i1 immarg) #0 + +declare void @free(i8*) + +attributes #0 = { argmemonly nofree nounwind willreturn } diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@llvm14.snap index 1c638784d..a975ed043 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@llvm14.snap @@ -10,24 +10,37 @@ alloca_block: br label %entry_block entry_block: ; preds = %alloca_block - %0 = insertvalue [2 x i64] undef, i64 1, 0 - %1 = insertvalue [2 x i64] %0, i64 2, 1 - %2 = icmp ult i64 1, 2 - br i1 %2, label %4, label %3 + %0 = call i8* @malloc(i64 mul (i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 2)) + %1 = bitcast i8* %0 to i64* + %2 = insertvalue { i64*, i64 } poison, i64* %1, 0 + %3 = insertvalue { i64*, i64 } %2, i64 0, 1 + %4 = getelementptr inbounds i64, i64* %1, i64 0 + store i64 1, i64* %4, align 4 + %5 = getelementptr inbounds i64, i64* %1, i64 1 + store i64 2, i64* %5, align 4 + %array_ptr = extractvalue { i64*, i64 } %3, 0 + %array_offset = extractvalue { i64*, i64 } %3, 1 + %6 = icmp ult i64 1, 2 + br i1 %6, label %8, label %7 -3: ; preds = %entry_block - br label %10 +7: ; preds = %entry_block + br label %13 -4: ; preds = %entry_block - %5 = 
alloca i64, i32 2, align 8 - %6 = bitcast i64* %5 to [2 x i64]* - store [2 x i64] %1, [2 x i64]* %6, align 4 - %7 = getelementptr inbounds i64, i64* %5, i64 1 - %8 = load i64, i64* %7, align 4 - %9 = insertvalue { i1, i64 } { i1 true, i64 poison }, i64 %8, 1 - br label %10 +8: ; preds = %entry_block + %9 = add i64 1, %array_offset + %10 = getelementptr inbounds i64, i64* %array_ptr, i64 %9 + %11 = load i64, i64* %10, align 4 + %12 = insertvalue { i1, i64 } { i1 true, i64 poison }, i64 %11, 1 + br label %13 -10: ; preds = %3, %4 - %"0.0" = phi { i1, i64 } [ %9, %4 ], [ { i1 false, i64 poison }, %3 ] +13: ; preds = %7, %8 + %"0.0" = phi { i1, i64 } [ %12, %8 ], [ { i1 false, i64 poison }, %7 ] + %array_ptr8 = extractvalue { i64*, i64 } %3, 0 + %14 = bitcast i64* %array_ptr8 to i8* + call void @free(i8* %14) ret void } + +declare i8* @malloc(i64) + +declare void @free(i8*) diff --git a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@pre-mem2reg@llvm14.snap b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@pre-mem2reg@llvm14.snap index 15902b579..6b7dfcffe 100644 --- a/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__array__test__emit_get@pre-mem2reg@llvm14.snap @@ -9,9 +9,11 @@ define void @_hl.main.1() { alloca_block: %"7_0" = alloca i64, align 8 %"5_0" = alloca i64, align 8 - %"8_0" = alloca [2 x i64], align 8 + %"8_0" = alloca { i64*, i64 }, align 8 %"9_0" = alloca { i1, i64 }, align 8 + %"9_1" = alloca { i64*, i64 }, align 8 %"0" = alloca { i1, i64 }, align 8 + %"1" = alloca { i64*, i64 }, align 8 br label %entry_block entry_block: ; preds = %alloca_block @@ -19,30 +21,48 @@ entry_block: ; preds = %alloca_block store i64 1, i64* %"5_0", align 4 %"5_01" = load i64, i64* %"5_0", align 4 %"7_02" = load i64, i64* %"7_0", align 4 - %0 = insertvalue [2 x i64] undef, i64 %"5_01", 0 - %1 = insertvalue [2 x i64] %0, i64 %"7_02", 1 - store [2 x i64] %1, [2 x i64]* %"8_0", align 4 - %"8_03" = load [2 x i64], [2 x i64]* %"8_0", align 4 + %0 = call i8* @malloc(i64 mul (i64 ptrtoint (i64* getelementptr (i64, i64* null, i32 1) to i64), i64 2)) + %1 = bitcast i8* %0 to i64* + %2 = insertvalue { i64*, i64 } poison, i64* %1, 0 + %3 = insertvalue { i64*, i64 } %2, i64 0, 1 + %4 = getelementptr inbounds i64, i64* %1, i64 0 + store i64 %"5_01", i64* %4, align 4 + %5 = getelementptr inbounds i64, i64* %1, i64 1 + store i64 %"7_02", i64* %5, align 4 + store { i64*, i64 } %3, { i64*, i64 }* %"8_0", align 8 + %"8_03" = load { i64*, i64 }, { i64*, i64 }* %"8_0", align 8 %"5_04" = load i64, i64* %"5_0", align 4 - %2 = icmp ult i64 %"5_04", 2 - br i1 %2, label %4, label %3 + %array_ptr = extractvalue { i64*, i64 } %"8_03", 0 + %array_offset = extractvalue { i64*, i64 } %"8_03", 1 + %6 = icmp ult i64 %"5_04", 2 + br i1 %6, label %8, label %7 -3: ; preds = %entry_block +7: ; preds = %entry_block store { i1, i64 } { i1 false, i64 poison }, { i1, i64 }* %"0", align 4 - br label %10 + store { i64*, i64 } %"8_03", { i64*, i64 }* %"1", align 8 + br label %13 -4: ; preds = %entry_block - %5 = alloca i64, i32 2, align 8 - %6 = bitcast i64* %5 to [2 x i64]* - store [2 x i64] %"8_03", [2 x i64]* %6, align 4 - %7 = getelementptr inbounds i64, i64* %5, i64 %"5_04" - %8 = load i64, i64* %7, align 4 - %9 = insertvalue { i1, i64 } { i1 true, i64 poison }, i64 
%8, 1
-  store { i1, i64 } %9, { i1, i64 }* %"0", align 4
-  br label %10
+8:                                                ; preds = %entry_block
+  %9 = add i64 %"5_04", %array_offset
+  %10 = getelementptr inbounds i64, i64* %array_ptr, i64 %9
+  %11 = load i64, i64* %10, align 4
+  %12 = insertvalue { i1, i64 } { i1 true, i64 poison }, i64 %11, 1
+  store { i1, i64 } %12, { i1, i64 }* %"0", align 4
+  store { i64*, i64 } %"8_03", { i64*, i64 }* %"1", align 8
+  br label %13
-10:                                               ; preds = %3, %4
+13:                                               ; preds = %7, %8
   %"05" = load { i1, i64 }, { i1, i64 }* %"0", align 4
+  %"16" = load { i64*, i64 }, { i64*, i64 }* %"1", align 8
   store { i1, i64 } %"05", { i1, i64 }* %"9_0", align 4
+  store { i64*, i64 } %"16", { i64*, i64 }* %"9_1", align 8
+  %"9_17" = load { i64*, i64 }, { i64*, i64 }* %"9_1", align 8
+  %array_ptr8 = extractvalue { i64*, i64 } %"9_17", 0
+  %14 = bitcast i64* %array_ptr8 to i8*
+  call void @free(i8* %14)
   ret void
 }
+
+declare i8* @malloc(i64)
+
+declare void @free(i8*)
diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml
index 241f53ec0..fa0cfa46c 100644
--- a/hugr-passes/Cargo.toml
+++ b/hugr-passes/Cargo.toml
@@ -25,6 +25,7 @@ lazy_static = { workspace = true }
 paste = { workspace = true }
 thiserror = { workspace = true }
 petgraph = { workspace = true }
+strum = { workspace = true }
 
 [dev-dependencies]
 rstest = { workspace = true }
diff --git a/hugr-passes/README.md b/hugr-passes/README.md
index c2bca2124..4aa9ea884 100644
--- a/hugr-passes/README.md
+++ b/hugr-passes/README.md
@@ -47,4 +47,4 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http:
   [crates]: https://img.shields.io/crates/v/hugr-passes
   [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov
   [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE
-  [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-passes/CHANGELOG.md
\ No newline at end of file
+  [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-passes/CHANGELOG.md
diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs
index 83ff71b67..d803b817c 100644
--- a/hugr-passes/src/lib.rs
+++ b/hugr-passes/src/lib.rs
@@ -11,6 +11,8 @@ mod dead_funcs;
 pub use dead_funcs::{remove_dead_funcs, RemoveDeadFuncsError, RemoveDeadFuncsPass};
 pub mod force_order;
 mod half_node;
+pub mod linearize_array;
+pub use linearize_array::LinearizeArrayPass;
 pub mod lower;
 pub mod merge_bbs;
 mod monomorphize;
diff --git a/hugr-passes/src/linearize_array.rs b/hugr-passes/src/linearize_array.rs
new file mode 100644
index 000000000..56ba7a7d0
--- /dev/null
+++ b/hugr-passes/src/linearize_array.rs
@@ -0,0 +1,397 @@
+//! Provides [LinearizeArrayPass] which turns `value_array`s into regular linear `array`s.
+
+use hugr_core::{
+    extension::{
+        prelude::Noop,
+        simple_op::{HasConcrete, MakeRegisteredOp},
+    },
+    hugr::hugrmut::HugrMut,
+    ops::NamedOp,
+    std_extensions::collections::{
+        array::{
+            array_type_def, array_type_parametric, Array, ArrayKind, ArrayOpDef, ArrayRepeatDef,
+            ArrayScanDef, ArrayValue, ARRAY_REPEAT_OP_ID, ARRAY_SCAN_OP_ID,
+        },
+        value_array::{self, VArrayFromArrayDef, VArrayToArrayDef, VArrayValue, ValueArray},
+    },
+    types::Transformable,
+    Node,
+};
+use itertools::Itertools;
+use strum::IntoEnumIterator;
+
+use crate::{
+    replace_types::{
+        handlers::copy_discard_array, DelegatingLinearizer, NodeTemplate, ReplaceTypesError,
+    },
+    ComposablePass, ReplaceTypes,
+};
+
+/// A HUGR -> HUGR pass that turns `value_array`s into regular linear `array`s.
+///
+/// # Panics
+///
+/// - If the Hugr has inter-graph edges whose type contains `value_array`s
+/// - If the Hugr contains [`ArrayOpDef::get`] operations on `value_array`s that
+///   contain nested `value_array`s.
+#[derive(Clone)]
+pub struct LinearizeArrayPass(ReplaceTypes);
+
+impl Default for LinearizeArrayPass {
+    fn default() -> Self {
+        let mut pass = ReplaceTypes::default();
+        pass.replace_parametrized_type(ValueArray::type_def(), |args| {
+            Some(Array::ty_parametric(args[0].clone(), args[1].clone()).unwrap())
+        });
+        pass.replace_consts_parametrized(ValueArray::type_def(), |v, replacer| {
+            let v: &VArrayValue = v.value().downcast_ref().unwrap();
+            let mut ty = v.get_element_type().clone();
+            let mut contents = v.get_contents().iter().cloned().collect_vec();
+            ty.transform(replacer).unwrap();
+            contents.iter_mut().for_each(|v| {
+                replacer.change_value(v).unwrap();
+            });
+            Ok(Some(ArrayValue::new(ty, contents).into()))
+        });
+        for op_def in ArrayOpDef::iter() {
+            pass.replace_parametrized_op(
+                value_array::EXTENSION.get_op(&op_def.name()).unwrap(),
+                move |args| {
+                    // `get` is only allowed for copyable elements. Assuming the Hugr was
+                    // valid when we started, the only way for the element to become linear
+                    // is if it used to contain nested `value_array`s. In that case, we
+                    // have to get rid of the `get`.
+                    // TODO: But what should we replace it with? Can't be a `set` since we
+                    // don't have anything to put in. Maybe we need a new `get_copy` op
+                    // that takes a function ptr to copy the element? For now, let's just
+                    // error out and make sure we're not emitting `get`s for nested value
+                    // arrays.
+                    if op_def == ArrayOpDef::get && !args[1].as_type().unwrap().copyable() {
+                        panic!(
+                            "Cannot linearise arrays in this Hugr: \
+                            Contains a `get` operation on nested value arrays"
+                        );
+                    }
+                    Some(NodeTemplate::SingleOp(
+                        op_def.instantiate(args).unwrap().into(),
+                    ))
+                },
+            );
+        }
+        pass.replace_parametrized_op(
+            value_array::EXTENSION.get_op(&ARRAY_REPEAT_OP_ID).unwrap(),
+            |args| {
+                Some(NodeTemplate::SingleOp(
+                    ArrayRepeatDef::new().instantiate(args).unwrap().into(),
+                ))
+            },
+        );
+        pass.replace_parametrized_op(
+            value_array::EXTENSION.get_op(&ARRAY_SCAN_OP_ID).unwrap(),
+            |args| {
+                Some(NodeTemplate::SingleOp(
+                    ArrayScanDef::new().instantiate(args).unwrap().into(),
+                ))
+            },
+        );
+        pass.replace_parametrized_op(
+            value_array::EXTENSION
+                .get_op(&VArrayFromArrayDef::new().name())
+                .unwrap(),
+            |args| {
+                let array_ty = array_type_parametric(args[0].clone(), args[1].clone()).unwrap();
+                Some(NodeTemplate::SingleOp(
+                    Noop::new(array_ty).to_extension_op().unwrap().into(),
+                ))
+            },
+        );
+        pass.replace_parametrized_op(
+            value_array::EXTENSION
+                .get_op(&VArrayToArrayDef::new().name())
+                .unwrap(),
+            |args| {
+                let array_ty = array_type_parametric(args[0].clone(), args[1].clone()).unwrap();
+                Some(NodeTemplate::SingleOp(
+                    Noop::new(array_ty).to_extension_op().unwrap().into(),
+                ))
+            },
+        );
+        pass.linearizer()
+            .register_callback(array_type_def(), copy_discard_array);
+        Self(pass)
+    }
+}
+
+impl ComposablePass for LinearizeArrayPass {
+    type Node = Node;
+    type Error = ReplaceTypesError;
+    type Result = bool;
+
+    fn run(&self, hugr: &mut impl HugrMut<Node = Self::Node>) -> Result<Self::Result, Self::Error> {
+        self.0.run(hugr)
+    }
+}
+
+impl LinearizeArrayPass {
+    /// Returns a new [`LinearizeArrayPass`] that handles all standard extensions.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Allows configuring how to clone and discard arrays that are nested
+    /// inside opaque extension values.
+ pub fn linearizer(&mut self) -> &mut DelegatingLinearizer { + self.0.linearizer() + } +} + +#[cfg(test)] +mod test { + use hugr_core::builder::ModuleBuilder; + use hugr_core::extension::prelude::{ConstUsize, Noop}; + use hugr_core::ops::handle::NodeHandle; + use hugr_core::ops::{Const, OpType}; + use hugr_core::std_extensions::collections::array::{ + self, array_type, ArrayValue, Direction, FROM, INTO, + }; + use hugr_core::std_extensions::collections::value_array::{ + VArrayFromArray, VArrayRepeat, VArrayScan, VArrayToArray, VArrayValue, + }; + use hugr_core::types::Transformable; + use hugr_core::{ + builder::{Container, DFGBuilder, Dataflow, HugrBuilder}, + extension::prelude::{qb_t, usize_t}, + std_extensions::collections::{ + array::{ + op_builder::{build_all_array_ops, build_all_value_array_ops}, + ArrayRepeat, ArrayScan, + }, + value_array::{self, value_array_type}, + }, + types::{Signature, Type}, + HugrView, + }; + use itertools::Itertools; + use rstest::rstest; + + use crate::{composable::ValidatingPass, ComposablePass}; + + use super::LinearizeArrayPass; + + #[test] + fn all_value_array_ops() { + let sig = Signature::new_endo(Type::EMPTY_TYPEROW); + let mut hugr = build_all_value_array_ops(DFGBuilder::new(sig.clone()).unwrap()) + .finish_hugr() + .unwrap(); + ValidatingPass::new(LinearizeArrayPass::default()) + .run(&mut hugr) + .unwrap(); + + let target_hugr = build_all_array_ops(DFGBuilder::new(sig).unwrap()) + .finish_hugr() + .unwrap(); + for (n1, n2) in hugr.nodes().zip_eq(target_hugr.nodes()) { + assert_eq!(hugr.get_optype(n1), target_hugr.get_optype(n2)); + } + } + + #[rstest] + #[case(usize_t(), 2)] + #[case(qb_t(), 2)] + #[case(value_array_type(4, usize_t()), 2)] + fn repeat(#[case] elem_ty: Type, #[case] size: u64) { + let mut builder = ModuleBuilder::new(); + let repeat_decl = builder + .declare( + "foo", + Signature::new(Type::EMPTY_TYPEROW, elem_ty.clone()).into(), + ) + .unwrap(); + let mut f = builder + .define_function( + "bar", + Signature::new(Type::EMPTY_TYPEROW, value_array_type(size, elem_ty.clone())), + ) + .unwrap(); + let repeat_f = f.load_func(&repeat_decl, &[]).unwrap(); + let repeat = f + .add_dataflow_op(VArrayRepeat::new(elem_ty.clone(), size), [repeat_f]) + .unwrap(); + let [arr] = repeat.outputs_arr(); + f.set_outputs([arr]).unwrap(); + let mut hugr = builder.finish_hugr().unwrap(); + + let pass = LinearizeArrayPass::default(); + ValidatingPass::new(pass.clone()).run(&mut hugr).unwrap(); + let new_repeat: ArrayRepeat = hugr.get_optype(repeat.node()).cast().unwrap(); + let mut new_elem_ty = elem_ty.clone(); + new_elem_ty.transform(&pass.0).unwrap(); + assert_eq!(new_repeat, ArrayRepeat::new(new_elem_ty, size)); + } + + #[rstest] + #[case(usize_t(), qb_t(), 2)] + #[case(usize_t(), value_array_type(4, usize_t()), 2)] + #[case(value_array_type(4, usize_t()), value_array_type(8, usize_t()), 2)] + fn scan(#[case] src_ty: Type, #[case] tgt_ty: Type, #[case] size: u64) { + let mut builder = ModuleBuilder::new(); + let scan_decl = builder + .declare("foo", Signature::new(src_ty.clone(), tgt_ty.clone()).into()) + .unwrap(); + let mut f = builder + .define_function( + "bar", + Signature::new( + value_array_type(size, src_ty.clone()), + value_array_type(size, tgt_ty.clone()), + ), + ) + .unwrap(); + let [arr] = f.input_wires_arr(); + let scan_f = f.load_func(&scan_decl, &[]).unwrap(); + let scan = f + .add_dataflow_op( + VArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size), + [arr, scan_f], + ) + .unwrap(); + let [arr] = scan.outputs_arr(); + 
f.set_outputs([arr]).unwrap(); + let mut hugr = builder.finish_hugr().unwrap(); + + let pass = LinearizeArrayPass::default(); + ValidatingPass::new(pass.clone()).run(&mut hugr).unwrap(); + let new_scan: ArrayScan = hugr.get_optype(scan.node()).cast().unwrap(); + let mut new_src_ty = src_ty.clone(); + let mut new_tgt_ty = tgt_ty.clone(); + new_src_ty.transform(&pass.0).unwrap(); + new_tgt_ty.transform(&pass.0).unwrap(); + + assert_eq!( + new_scan, + ArrayScan::new(new_src_ty, new_tgt_ty, vec![], size) + ); + } + + #[rstest] + #[case(INTO, usize_t(), 2)] + #[case(FROM, usize_t(), 2)] + #[case(INTO, array_type(4, usize_t()), 2)] + #[case(FROM, array_type(4, usize_t()), 2)] + #[case(INTO, value_array_type(4, usize_t()), 2)] + #[case(FROM, value_array_type(4, usize_t()), 2)] + fn convert(#[case] dir: Direction, #[case] elem_ty: Type, #[case] size: u64) { + let (src, tgt) = match dir { + INTO => ( + value_array_type(size, elem_ty.clone()), + array_type(size, elem_ty.clone()), + ), + FROM => ( + array_type(size, elem_ty.clone()), + value_array_type(size, elem_ty.clone()), + ), + }; + let sig = Signature::new(src, tgt); + let mut builder = DFGBuilder::new(sig).unwrap(); + let [arr] = builder.input_wires_arr(); + let op: OpType = match dir { + INTO => VArrayToArray::new(elem_ty.clone(), size).into(), + FROM => VArrayFromArray::new(elem_ty.clone(), size).into(), + }; + let convert = builder.add_dataflow_op(op, [arr]).unwrap(); + let [arr] = convert.outputs_arr(); + builder.set_outputs(vec![arr]).unwrap(); + let mut hugr = builder.finish_hugr().unwrap(); + + let pass = LinearizeArrayPass::default(); + ValidatingPass::new(pass.clone()).run(&mut hugr).unwrap(); + let new_convert: Noop = hugr.get_optype(convert.node()).cast().unwrap(); + let mut new_elem_ty = elem_ty.clone(); + new_elem_ty.transform(&pass.0).unwrap(); + + assert_eq!(new_convert, Noop::new(array_type(size, new_elem_ty))); + } + + #[rstest] + #[case(value_array_type(2, usize_t()))] + #[case(value_array_type(2, value_array_type(4, usize_t())))] + #[case(value_array_type(2, Type::new_tuple(vec![usize_t(), value_array_type(4, usize_t())])))] + fn implicit_clone(#[case] array_ty: Type) { + let sig = Signature::new(array_ty.clone(), vec![array_ty; 2]); + let mut builder = DFGBuilder::new(sig).unwrap(); + let [arr] = builder.input_wires_arr(); + builder.set_outputs(vec![arr, arr]).unwrap(); + + let mut hugr = builder.finish_hugr().unwrap(); + ValidatingPass::new(LinearizeArrayPass::default()) + .run(&mut hugr) + .unwrap(); + } + + #[rstest] + #[case(value_array_type(2, usize_t()))] + #[case(value_array_type(2, value_array_type(4, usize_t())))] + #[case(value_array_type(2, Type::new_tuple(vec![usize_t(), value_array_type(4, usize_t())])))] + fn implicit_discard(#[case] array_ty: Type) { + let sig = Signature::new(array_ty, Type::EMPTY_TYPEROW); + let mut builder = DFGBuilder::new(sig).unwrap(); + builder.set_outputs(vec![]).unwrap(); + + let mut hugr = builder.finish_hugr().unwrap(); + ValidatingPass::new(LinearizeArrayPass::default()) + .run(&mut hugr) + .unwrap(); + } + + #[test] + fn array_value() { + let mut builder = ModuleBuilder::new(); + let array_v = VArrayValue::new(usize_t(), vec![ConstUsize::new(1).into()]); + let c = builder.add_constant(Const::new(array_v.clone().into())); + + let mut hugr = builder.finish_hugr().unwrap(); + ValidatingPass::new(LinearizeArrayPass::default()) + .run(&mut hugr) + .unwrap(); + + let new_array_v: &ArrayValue = hugr + .get_optype(c.node()) + .as_const() + .unwrap() + .get_custom_value() + .unwrap(); + + 
assert_eq!(new_array_v.get_element_type(), array_v.get_element_type()); + assert_eq!(new_array_v.get_contents(), array_v.get_contents()); + } + + #[test] + fn array_value_nested() { + let mut builder = ModuleBuilder::new(); + let array_v_inner = VArrayValue::new(usize_t(), vec![ConstUsize::new(1).into()]); + let array_v: array::GenericArrayValue = VArrayValue::new( + value_array_type(1, usize_t()), + vec![array_v_inner.clone().into()], + ); + let c = builder.add_constant(Const::new(array_v.clone().into())); + + let mut hugr = builder.finish_hugr().unwrap(); + ValidatingPass::new(LinearizeArrayPass::default()) + .run(&mut hugr) + .unwrap(); + + let new_array_v: &ArrayValue = hugr + .get_optype(c.node()) + .as_const() + .unwrap() + .get_custom_value() + .unwrap(); + + assert_eq!(new_array_v.get_element_type(), &array_type(1, usize_t())); + assert_eq!( + new_array_v.get_contents()[0], + ArrayValue::new(usize_t(), vec![ConstUsize::new(1).into()]).into() + ); + } +} diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index cfe2c9514..6505b274e 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -335,7 +335,8 @@ mod test { use hugr_core::extension::simple_op::MakeRegisteredOp as _; use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections; - use hugr_core::std_extensions::collections::array::{array_type_parametric, ArrayOpDef}; + use hugr_core::std_extensions::collections::array::ArrayKind; + use hugr_core::std_extensions::collections::value_array::{VArrayOpDef, ValueArray}; use hugr_core::types::type_param::TypeParam; use itertools::Itertools; @@ -480,23 +481,32 @@ mod test { let mut outer = FunctionBuilder::new( "mainish", Signature::new( - array_type_parametric(sa(n), array_type_parametric(sa(2), usize_t()).unwrap()) - .unwrap(), + ValueArray::ty_parametric( + sa(n), + ValueArray::ty_parametric(sa(2), usize_t()).unwrap(), + ) + .unwrap(), vec![usize_t(); 2], ), ) .unwrap(); - let arr2u = || array_type_parametric(sa(2), usize_t()).unwrap(); + let arr2u = || ValueArray::ty_parametric(sa(2), usize_t()).unwrap(); let pf1t = PolyFuncType::new( [TypeParam::max_nat()], - Signature::new(array_type_parametric(sv(0), arr2u()).unwrap(), usize_t()), + Signature::new( + ValueArray::ty_parametric(sv(0), arr2u()).unwrap(), + usize_t(), + ), ); let mut pf1 = outer.define_function("pf1", pf1t).unwrap(); let pf2t = PolyFuncType::new( [TypeParam::max_nat(), TypeBound::Copyable.into()], - Signature::new(vec![array_type_parametric(sv(0), tv(1)).unwrap()], tv(1)), + Signature::new( + vec![ValueArray::ty_parametric(sv(0), tv(1)).unwrap()], + tv(1), + ), ); let mut pf2 = pf1.define_function("pf2", pf2t).unwrap(); @@ -510,10 +520,10 @@ mod test { let pf2 = { let [inw] = pf2.input_wires_arr(); let [idx] = pf2.call(mono_func.handle(), &[], []).unwrap().outputs_arr(); - let op_def = collections::array::EXTENSION.get_op("get").unwrap(); + let op_def = collections::value_array::EXTENSION.get_op("get").unwrap(); let op = hugr_core::ops::ExtensionOp::new(op_def.clone(), vec![sv(0), tv(1).into()]) .unwrap(); - let [get] = pf2.add_dataflow_op(op, [inw, idx]).unwrap().outputs_arr(); + let [get, _] = pf2.add_dataflow_op(op, [inw, idx]).unwrap().outputs_arr(); let [got] = pf2 .build_unwrap_sum(1, SumType::new([vec![], vec![tv(1)]]), get) .unwrap(); @@ -536,7 +546,7 @@ mod test { .call(pf1.handle(), &[sa(n)], outer.input_wires()) .unwrap() .outputs_arr(); - let popleft = ArrayOpDef::pop_left.to_concrete(arr2u(), n); + 
let popleft = VArrayOpDef::pop_left.to_concrete(arr2u(), n); let ar2 = outer .add_dataflow_op(popleft.clone(), outer.input_wires()) .unwrap(); diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 05d0168c8..b5b98e887 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -9,6 +9,7 @@ use std::sync::Arc; use handlers::list_const; use hugr_core::std_extensions::collections::array::array_type_def; use hugr_core::std_extensions::collections::list::list_type_def; +use hugr_core::std_extensions::collections::value_array::value_array_type_def; use thiserror::Error; use hugr_core::builder::{BuildError, BuildHandle, Dataflow}; @@ -214,6 +215,7 @@ impl Default for ReplaceTypes { let mut res = Self::new_empty(); res.linearize = DelegatingLinearizer::default(); res.replace_consts_parametrized(array_type_def(), handlers::array_const); + res.replace_consts_parametrized(value_array_type_def(), handlers::value_array_const); res.replace_consts_parametrized(list_type_def(), list_const); res } @@ -590,6 +592,7 @@ impl From<&OpDef> for ParametricOp { mod test { use std::sync::Arc; + use crate::replace_types::handlers::generic_array_const; use hugr_core::builder::{ inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, @@ -600,14 +603,20 @@ mod test { use hugr_core::extension::{simple_op::MakeExtensionOp, TypeDefBound, Version}; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::hugr::{IdentList, ValidationError}; + use hugr_core::ops::constant::CustomConst; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; - use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; - use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; - use hugr_core::std_extensions::collections::{ - array::{array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue}, - list::{list_type, list_type_def, ListOp, ListValue}, + use hugr_core::std_extensions::arithmetic::int_types::ConstInt; + use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; + use hugr_core::std_extensions::collections::array::Array; + use hugr_core::std_extensions::collections::array::{ArrayKind, GenericArrayValue}; + use hugr_core::std_extensions::collections::list::{ + list_type, list_type_def, ListOp, ListValue, }; + use hugr_core::std_extensions::collections::value_array::{ + value_array_type, VArrayOp, VArrayOpDef, VArrayValue, ValueArray, + }; + use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; use hugr_core::{type_row, Extension, HugrView}; use itertools::Itertools; @@ -680,7 +689,7 @@ mod test { new: impl Fn(Signature) -> Result, ) -> T { let mut dfb = new(Signature::new( - vec![array_type(64, elem_ty.clone()), i64_t()], + vec![value_array_type(64, elem_ty.clone()), i64_t()], elem_ty.clone(), )) .unwrap(); @@ -689,8 +698,11 @@ mod test { .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) .unwrap() .outputs_arr(); - let [opt] = dfb - .add_dataflow_op(ArrayOpDef::get.to_concrete(elem_ty.clone(), 64), [val, idx]) + let [opt, _] = dfb + .add_dataflow_op( + VArrayOpDef::get.to_concrete(elem_ty.clone(), 64), + [val, idx], + ) .unwrap() .outputs_arr(); let [res] = dfb @@ -706,7 +718,7 @@ mod test { lw.replace_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); 
lw.replace_parametrized_type( pv, - Box::new(|args: &[TypeArg]| Some(array_type(64, just_elem_type(args).clone()))), + Box::new(|args: &[TypeArg]| Some(value_array_type(64, just_elem_type(args).clone()))), ); lw.replace_op( &read_op(ext, bool_t()), @@ -835,10 +847,10 @@ mod test { // The PackedVec> becomes an array let [array_get] = ext_ops .into_iter() - .filter_map(|e| ArrayOp::from_extension_op(e).ok()) + .filter_map(|e| VArrayOp::from_extension_op(e).ok()) .collect_array() .unwrap(); - assert_eq!(array_get, ArrayOpDef::get.to_concrete(i64_t(), 64)); + assert_eq!(array_get, VArrayOpDef::get.to_concrete(i64_t(), 64)); } #[test] @@ -868,7 +880,7 @@ mod test { // 1. Lower List to Array<10, T> UNLESS T is usize_t() or i64_t lowerer.replace_parametrized_type(list_type_def(), |args| { let ty = just_elem_type(args); - (![usize_t(), i64_t()].contains(ty)).then_some(array_type(10, ty.clone())) + (![usize_t(), i64_t()].contains(ty)).then_some(value_array_type(10, ty.clone())) }); { let mut h = backup.clone(); @@ -876,7 +888,7 @@ mod test { let sig = h.signature(h.root()).unwrap(); assert_eq!( sig.input(), - &TypeRow::from(vec![list_type(usize_t()), array_type(10, bool_t())]) + &TypeRow::from(vec![list_type(usize_t()), value_array_type(10, bool_t())]) ); assert_eq!(sig.input(), sig.output()); } @@ -898,7 +910,7 @@ mod test { let sig = h.signature(h.root()).unwrap(); assert_eq!( sig.input(), - &TypeRow::from(vec![list_type(i64_t()), array_type(10, bool_t())]) + &TypeRow::from(vec![list_type(i64_t()), value_array_type(10, bool_t())]) ); assert_eq!(sig.input(), sig.output()); // This will have to update inside the Const @@ -915,7 +927,7 @@ mod test { let mut h = backup; lowerer.replace_parametrized_type( list_type_def(), - Box::new(|args: &[TypeArg]| Some(array_type(4, just_elem_type(args).clone()))), + Box::new(|args: &[TypeArg]| Some(value_array_type(4, just_elem_type(args).clone()))), ); lowerer.replace_consts_parametrized(list_type_def(), |opaq, repl| { // First recursively transform the contents @@ -925,7 +937,7 @@ mod test { let lv = opaq.value().downcast_ref::().unwrap(); Ok(Some( - ArrayValue::new(lv.get_element_type().clone(), lv.get_contents().to_vec()).into(), + VArrayValue::new(lv.get_element_type().clone(), lv.get_contents().to_vec()).into(), )) }); lowerer.run(&mut h).unwrap(); @@ -934,7 +946,10 @@ mod test { h.get_optype(pred.node()) .as_load_constant() .map(|lc| lc.constant_type()), - Some(&Type::new_sum(vec![Type::from(array_type(4, i64_t())); 2])) + Some(&Type::new_sum(vec![ + Type::from(value_array_type(4, i64_t())); + 2 + ])) ); } @@ -1023,17 +1038,19 @@ mod test { } #[rstest] - #[case(&[])] - #[case(&[3])] - #[case(&[5,7,11,13,17,19])] - fn array_const(#[case] vals: &[u64]) { - use super::handlers::array_const; - let mut dfb = DFGBuilder::new(inout_sig( - type_row![], - array_type(vals.len() as _, usize_t()), - )) - .unwrap(); - let c = dfb.add_load_value(ArrayValue::new( + #[case(&[], Array)] + #[case(&[], ValueArray)] + #[case(&[3], Array)] + #[case(&[3], ValueArray)] + #[case(&[5,7,11,13,17,19], Array)] + #[case(&[5,7,11,13,17,19], ValueArray)] + fn array_const(#[case] vals: &[u64], #[case] _kind: AK) + where + GenericArrayValue: CustomConst, + { + let mut dfb = + DFGBuilder::new(inout_sig(type_row![], AK::ty(vals.len() as _, usize_t()))).unwrap(); + let c = dfb.add_load_value(GenericArrayValue::::new( usize_t(), vals.iter().map(|u| ConstUsize::new(*u).into()), )); @@ -1053,7 +1070,7 @@ mod test { matches!(h.validate(), Err(ValidationError::IncompatiblePorts {from, to, ..}) if 
backup.get_optype(from).is_const() && to == c.node()) ); - repl.replace_consts_parametrized(array_type_def(), array_const); + repl.replace_consts_parametrized(AK::type_def(), generic_array_const::); let mut h = backup; repl.run(&mut h).unwrap(); h.validate().unwrap(); diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index 573188340..848c09682 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -3,15 +3,18 @@ use hugr_core::builder::{endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr}; use hugr_core::extension::prelude::{option_type, UnwrapBuilder}; +use hugr_core::ops::constant::CustomConst; use hugr_core::ops::{constant::OpaqueValue, Value}; use hugr_core::ops::{OpTrait, OpType, Tag}; use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; use hugr_core::std_extensions::arithmetic::int_ops::IntOpDef; use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use hugr_core::std_extensions::collections::array::{ - array_type, ArrayOpDef, ArrayRepeat, ArrayScan, ArrayValue, + array_type, Array, ArrayClone, ArrayDiscard, ArrayKind, ArrayOpBuilder, GenericArrayOpDef, + GenericArrayRepeat, GenericArrayScan, GenericArrayValue, }; use hugr_core::std_extensions::collections::list::ListValue; +use hugr_core::std_extensions::collections::value_array::ValueArray; use hugr_core::type_row; use hugr_core::types::{SumType, Transformable, Type, TypeArg}; use itertools::Itertools; @@ -43,14 +46,17 @@ pub fn list_const( Ok(Some(ListValue::new(elem_t, vals).into())) } -/// Handler for [ArrayValue] constants that recursively +/// Handler for [GenericArrayValue] constants that recursively /// [ReplaceTypes::change_value]s the elements of the list. /// Included in [ReplaceTypes::default]. -pub fn array_const( +pub fn generic_array_const( val: &OpaqueValue, repl: &ReplaceTypes, -) -> Result, ReplaceTypesError> { - let Some(av) = val.value().downcast_ref::() else { +) -> Result, ReplaceTypesError> +where + GenericArrayValue: CustomConst, +{ + let Some(av) = val.value().downcast_ref::>() else { return Ok(None); }; let mut elem_t = av.get_element_type().clone(); @@ -63,14 +69,37 @@ pub fn array_const( for v in vals.iter_mut() { repl.change_value(v)?; } - Ok(Some(ArrayValue::new(elem_t, vals).into())) + Ok(Some(GenericArrayValue::::new(elem_t, vals).into())) +} + +/// Handler for [ArrayValue] constants that recursively +/// [ReplaceTypes::change_value]s the elements of the list. +/// Included in [ReplaceTypes::default]. +/// +/// [ArrayValue]: hugr_core::std_extensions::collections::array::ArrayValue +pub fn array_const( + val: &OpaqueValue, + repl: &ReplaceTypes, +) -> Result, ReplaceTypesError> { + generic_array_const::(val, repl) +} + +/// Handler for [VArrayValue] constants that recursively +/// [ReplaceTypes::change_value]s the elements of the list. +/// Included in [ReplaceTypes::default]. +/// +/// [VArrayValue]: hugr_core::std_extensions::collections::value_array::VArrayValue +pub fn value_array_const( + val: &OpaqueValue, + repl: &ReplaceTypes, +) -> Result, ReplaceTypesError> { + generic_array_const::(val, repl) } /// Handler for copying/discarding arrays if their elements have become linear. -/// Included in [ReplaceTypes::default] and [DelegatingLinearizer::default]. /// -/// [DelegatingLinearizer::default]: super::DelegatingLinearizer::default -pub fn linearize_array( +/// Generic over the concrete array implementation. 
+pub fn linearize_generic_array( args: &[TypeArg], num_outports: usize, lin: &CallbackHandler, @@ -92,8 +121,8 @@ pub fn linearize_array( dfb.finish_hugr_with_outputs([ret]).unwrap() }; // Now array.scan that over the input array to get an array of unit (which can be discarded) - let array_scan = ArrayScan::new(ty.clone(), Type::UNIT, vec![], *n); - let in_type = array_type(*n, ty.clone()); + let array_scan = GenericArrayScan::::new(ty.clone(), Type::UNIT, vec![], *n); + let in_type = AK::ty(*n, ty.clone()); return Ok(NodeTemplate::CompoundOp(Box::new({ let mut dfb = DFGBuilder::new(inout_sig(in_type, type_row![])).unwrap(); let [in_array] = dfb.input_wires_arr(); @@ -101,14 +130,18 @@ pub fn linearize_array( hugr: Box::new(map_fn), }); // scan has one output, an array of unit, so just ignore/discard that - dfb.add_dataflow_op(array_scan, [in_array, map_fn]).unwrap(); + let unit_arr = dfb + .add_dataflow_op(array_scan, [in_array, map_fn]) + .unwrap() + .out_wire(0); + AK::build_discard(&mut dfb, Type::UNIT, *n, unit_arr).unwrap(); dfb.finish_hugr_with_outputs([]).unwrap() }))); }; // The num_outports>1 case will simplify, and unify with the previous, when we have a // more general ArrayScan https://github.com/CQCL/hugr/issues/2041. In the meantime: let num_new = num_outports - 1; - let array_ty = array_type(*n, ty.clone()); + let array_ty = AK::ty(*n, ty.clone()); let mut dfb = DFGBuilder::new(inout_sig( array_ty.clone(), vec![array_ty.clone(); num_outports], @@ -126,7 +159,7 @@ pub fn linearize_array( .unwrap(); dfb.finish_hugr_with_outputs(none.outputs()).unwrap() }; - let repeats = vec![ArrayRepeat::new(option_ty.clone(), *n); num_new]; + let repeats = vec![GenericArrayRepeat::::new(option_ty.clone(), *n); num_new]; let fn_none = dfb.add_load_value(Value::function(fn_none).unwrap()); repeats .into_iter() @@ -140,7 +173,7 @@ pub fn linearize_array( // 2. use a scan through the input array, copying the element num_outputs times; // return the first copy, and put each of the other copies into one of the array