diff --git a/src/parser.rs b/src/parser.rs index ce75f90..13f9007 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -15,23 +15,26 @@ use nom::{branch::alt, bytes::complete::take_while1, character::char, combinator use crate::tree::{Alternative, Edge, Tree, Vertex}; +fn alternative(input: &str) -> IResult<&str, Tree> { + map( + take_while1::<_, &str, Error<&str>>(AsChar::is_alpha), + |a| Tree:: { + root: Vertex::Terminal(Alternative{name: a.to_owned()}) + } + ).parse(input) +} + fn weighted_edge(input: &str) -> IResult<&str, Edge> { let weight = delimited( char::<&str, Error<&str>>('['), double, char(']') ); - let alternative = map( - take_while1::<_, &str, Error<&str>>(AsChar::is_alpha), - |a| Tree:: { - root: Vertex::Terminal(Alternative{name: a.to_owned()}) - } - ); map( - pair(weight, alt((alternative, subtree))), + pair(weight, tree), |(a, b)| Edge{ weight: a, - destination: b + destination: Box::new(b) } ).parse(input) } @@ -40,8 +43,12 @@ pub fn subtree(input: &str) -> IResult<&str, Tree> { let inner = map( pair(weighted_edge, weighted_edge), |(a, b)| Tree{ - root: Vertex::NonTerminal(Box::new(a), Box::new(b)) + root: Vertex::NonTerminal(a, b) } ); delimited(char('('), inner, char(')')).parse(input) } + +pub fn tree(input: &str) -> IResult<&str, Tree> { + alt((alternative, subtree)).parse(input) +} diff --git a/src/test.rs b/src/test.rs index d3dd396..5d6128b 100644 --- a/src/test.rs +++ b/src/test.rs @@ -12,13 +12,13 @@ fn simple_symmetric_tree() -> Tree { }; let edge_a = Edge { weight: 1.0, - destination: Tree {root: Vertex::Terminal(a)} + destination: Box::new(Tree {root: Vertex::Terminal(a)}) }; let edge_b = Edge { weight: 1.0, - destination: Tree {root: Vertex::Terminal(b)} + destination: Box::new(Tree {root: Vertex::Terminal(b)}) }; - let root = Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b)); + let root = Vertex::NonTerminal(edge_a, edge_b); Tree {root: root} } @@ -32,13 +32,13 @@ fn simple_asymmetric_tree() -> Tree { }; let edge_a = Edge { weight: 3.0, - destination: Tree {root: Vertex::Terminal(a)} + destination: Box::new(Tree {root: Vertex::Terminal(a)}) }; let edge_b = Edge { weight: 1.0, - destination: Tree {root: Vertex::Terminal(b)} + destination: Box::new(Tree {root: Vertex::Terminal(b)}) }; - let root = Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b)); + let root = Vertex::NonTerminal(edge_a, edge_b); Tree {root: root} } @@ -55,23 +55,23 @@ fn complex_tree() -> Tree { }; let edge_a = Edge { weight: 2.5, - destination: Tree {root: Vertex::Terminal(a)} + destination: Box::new(Tree {root: Vertex::Terminal(a)}) }; let edge_b = Edge { weight: 1.0, - destination: Tree {root: Vertex::Terminal(b)} + destination: Box::new(Tree {root: Vertex::Terminal(b)}) }; let edge_ab = Edge{ weight: 0.5, - destination: Tree { - root: Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b)) - } + destination: Box::new(Tree { + root: Vertex::NonTerminal(edge_a, edge_b) + }) }; let edge_c = Edge{ weight: 1.0, - destination: Tree {root: Vertex::Terminal(c)} + destination: Box::new(Tree {root: Vertex::Terminal(c)}) }; - let root = Vertex::NonTerminal(Box::new(edge_ab), Box::new(edge_c)); + let root = Vertex::NonTerminal(edge_ab, edge_c); Tree {root: root} } @@ -102,7 +102,7 @@ fn choice_probability_test_3() { fn parser_test_1() { assert_eq!( { - let (_, b) = parser::subtree("([1.0]A[1.0]B)").unwrap(); + let (_, b) = parser::tree("([1.0]A[1.0]B)").unwrap(); b }, simple_symmetric_tree() @@ -113,7 +113,7 @@ fn parser_test_1() { fn parser_test_2() { assert_eq!( { - let (_, b) = parser::subtree("([3.0]A[1.0]B)").unwrap(); + let (_, b) = parser::tree("([3.0]A[1.0]B)").unwrap(); b }, simple_asymmetric_tree() @@ -124,7 +124,7 @@ fn parser_test_2() { fn parser_test_3() { assert_eq!( { - let (_, b) = parser::subtree("([0.5]([2.5]A[1.0]B)[1.0]C)").unwrap(); + let (_, b) = parser::tree("([0.5]([2.5]A[1.0]B)[1.0]C)").unwrap(); b }, complex_tree() @@ -133,21 +133,44 @@ fn parser_test_3() { #[test] fn test_symmetry_positive() { - let (_, a) = parser::subtree("([3.0]A[1.0]B)").unwrap(); - let (_, b) = parser::subtree("([1.0]B[3.0]A)").unwrap(); + let (_, a) = parser::tree("([3.0]A[1.0]B)").unwrap(); + let (_, b) = parser::tree("([1.0]B[3.0]A)").unwrap(); assert_eq!(a, b) } #[test] fn test_symmetry_negative() { - let (_, a) = parser::subtree("([3.0]A[1.0]B)").unwrap(); - let (_, b) = parser::subtree("([1.0]A[3.0]B)").unwrap(); + let (_, a) = parser::tree("([3.0]A[1.0]B)").unwrap(); + let (_, b) = parser::tree("([1.0]A[3.0]B)").unwrap(); assert_ne!(a, b) } #[test] fn test_symmetry_complex() { - let (_, a) = parser::subtree("([1.0]C[0.5]([1.0]B[2.5]A))").unwrap(); - let (_, b) = parser::subtree("([0.5]([2.5]A[1.0]B)[1.0]C)").unwrap(); + let (_, a) = parser::tree("([1.0]C[0.5]([1.0]B[2.5]A))").unwrap(); + let (_, b) = parser::tree("([0.5]([2.5]A[1.0]B)[1.0]C)").unwrap(); assert_eq!(a, b) } + +#[test] +fn test_menu_simple() { + let (_, a) = parser::tree("([3.0]A[1.0]B)").unwrap(); + let (_, b) = parser::tree("A").unwrap(); + let Some((a_menu_tree, _)) = a.menu_tree(&["A".to_owned()]) else { + panic!("Failed to get tree from menu_tree output") + }; + assert_eq!(a_menu_tree, b) +} + +#[test] +fn test_menu_complex() { + let (_, a) = parser::tree("([1.0]C[0.5]([1.0]B[2.5]A))").unwrap(); + let (_, b) = parser::tree("([1.0]C[1.5]B)").unwrap(); + let Some((a_menu_tree, _)) = a.menu_tree(&[ + "B".to_owned(), "C".to_owned() + ]) else { + panic!("Failed to get tree from menu_tree output") + }; + assert_eq!(a_menu_tree, b) +} + diff --git a/src/tree.rs b/src/tree.rs index 3cb29f7..13db604 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -1,24 +1,24 @@ use num_traits::Float; -#[derive(PartialEq, Eq, Debug)] -pub struct Alternative { +#[derive(PartialEq, Eq, Clone, Debug)] +pub struct Alternative { pub name: T } -#[derive(PartialEq, Debug)] -pub struct Edge { +#[derive(PartialEq, Clone, Debug)] +pub struct Edge { pub weight: U, - pub destination: Tree + pub destination: Box> } -#[derive(Debug)] -pub enum Vertex { - NonTerminal(Box>, Box>), +#[derive(Clone, Debug)] +pub enum Vertex { + NonTerminal(Edge, Edge), Terminal(Alternative), } // implement symmetry -impl PartialEq for Vertex{ +impl PartialEq for Vertex{ fn eq(&self, other: &Self) -> bool { match self { Vertex::NonTerminal(a, b) => { @@ -39,12 +39,16 @@ impl PartialEq for Vertex{ } } -#[derive(PartialEq, Debug)] -pub struct Tree { +#[derive(PartialEq, Clone, Debug)] +pub struct Tree { pub root: Vertex } -impl Tree { +pub enum TreeError { + EmptyTree, +} + +impl Tree { fn left_edge_weights(&self) -> U { match &self.root { Vertex::NonTerminal(a, _) => { @@ -91,16 +95,61 @@ impl Tree { } } } + + pub fn menu_tree(&self, alternatives: &[T]) -> Option<(Tree, U)> { + match &self.root { + Vertex::Terminal(alt) => { + if alternatives.contains(&alt.name) { + Some((self.clone(), U::zero())) + } else { + None + } + }, + Vertex::NonTerminal(a, b) => { + let a_menu = a.destination.menu_tree(alternatives); + let b_menu = b.destination.menu_tree(alternatives); + match (a_menu, b_menu) { + (Some((a_tree, a_weight)), Some((b_tree, b_weight))) => { + Some( + (Tree { + root: Vertex::NonTerminal( + Edge { + weight: a.weight + a_weight, + destination: Box::new(a_tree) + }, + Edge { + weight: b.weight + b_weight, + destination: Box::new(b_tree) + } + ) + }, U::zero()) + ) + }, + (Some((a_tree, a_weight)), None) => { + Some( + (a_tree, a.weight + a_weight) + ) + }, + (None, Some((b_tree, b_weight))) => { + Some( + (b_tree, b.weight + b_weight) + ) + }, + (None, None) => None + } + }, + } + } } -pub struct SimilarityRelation { +pub struct SimilarityRelation { left: T, right: T, with_respect_to: T } // implement symmetry -impl PartialEq for SimilarityRelation { +impl PartialEq for SimilarityRelation { fn eq(&self, other: &Self) -> bool { ( (self.left == other.left && self.right == other.right)