use num_traits::Float; mod parser; /* * NOTES * - each alternative should appear only once in a tree. need to figure out a * way to enforce this. * - looks like this is like quite difficult to do through the type system */ #[derive(PartialEq, Eq, Debug)] pub struct Alternative { name: T } #[derive(PartialEq, Eq, Debug)] pub struct Edge { weight: U, destination: Tree } #[derive(PartialEq, Eq, Debug)] pub enum Vertex { NonTerminal(Box>, Box>), Terminal(Alternative), } #[derive(PartialEq, Eq, Debug)] pub struct Tree { root: Vertex } impl Tree { fn left_edge_weights(&self) -> U { match &self.root { Vertex::NonTerminal(a, _) => { let left = &*a; left.weight + left.destination.left_edge_weights() + left.destination.right_edge_weights() }, Vertex::Terminal(_) => U::zero() } } fn right_edge_weights(&self) -> U { match &self.root { Vertex::NonTerminal(_, b) => { let right = &*b; right.weight + right.destination.left_edge_weights() + right.destination.right_edge_weights() }, Vertex::Terminal(_) => U::zero() } } pub fn choice_probability(&self, alternative: &T) -> U { match &self.root { Vertex::Terminal(alt) => { if alt.name == *alternative { U::one() } else { U::zero() } }, Vertex::NonTerminal(a, b) => { let left = &*a; let right = &*b; let left_edge_weights = self.left_edge_weights(); let right_edge_weights = self.right_edge_weights(); let left_choice_probability = left.destination.choice_probability(alternative); let right_choice_probability = right.destination.choice_probability(alternative); ( left_edge_weights * left_choice_probability + right_edge_weights * right_choice_probability ) /(left_edge_weights + right_edge_weights) } } } } fn main() { unimplemented!(); } #[cfg(test)] mod tests { use super::*; use assert_approx_eq::assert_approx_eq; fn simple_symmetric_tree() -> Tree { let a = Alternative { name: "A".to_owned() }; let b = Alternative { name: "B".to_owned() }; let edge_a = Edge { weight: 1.0, destination: Tree {root: Vertex::Terminal(a)} }; let edge_b = Edge { weight: 1.0, destination: Tree {root: Vertex::Terminal(b)} }; let root = Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b)); Tree {root: root} } fn simple_asymmetric_tree() -> Tree { let a = Alternative { name: "A".to_owned() }; let b = Alternative { name: "B".to_owned() }; let edge_a = Edge { weight: 3.0, destination: Tree {root: Vertex::Terminal(a)} }; let edge_b = Edge { weight: 1.0, destination: Tree {root: Vertex::Terminal(b)} }; let root = Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b)); Tree {root: root} } fn complex_tree() -> Tree { let a = Alternative { name: "A".to_owned() }; let b = Alternative { name: "B".to_owned() }; let c = Alternative { name: "C".to_owned() }; let edge_a = Edge { weight: 2.5, destination: Tree {root: Vertex::Terminal(a)} }; let edge_b = Edge { weight: 1.0, destination: Tree {root: Vertex::Terminal(b)} }; let edge_ab = Edge{ weight: 0.5, destination: Tree { root: Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b)) } }; let edge_c = Edge{ weight: 1.0, destination: Tree {root: Vertex::Terminal(c)} }; let root = Vertex::NonTerminal(Box::new(edge_ab), Box::new(edge_c)); Tree {root: root} } // A test for the simplest symmetric case #[test] fn choice_probability_test_1() { assert_eq!(simple_symmetric_tree().choice_probability(&("A".to_owned())), 0.5); } // A test for the simplest asymmetric case #[test] fn choice_probability_test_2() { assert_eq!(simple_asymmetric_tree().choice_probability(&("A".to_owned())), 0.75); } // A test for depth higher than 1 #[test] fn choice_probability_test_3() { assert_approx_eq!( complex_tree().choice_probability(&("A".to_owned())), (2.5/(2.5+1.0))*((2.5+1.0+0.5)/(2.5+1.0+0.5+1.0)) ); } // A test for parsing the simplest symmetric case #[test] fn parser_test_1() { assert_eq!( { let (_, b) = parser::subtree("([1.0]A[1.0]B)").unwrap(); b }, simple_symmetric_tree() ) } #[test] fn parser_test_2() { assert_eq!( { let (_, b) = parser::subtree("([3.0]A[1.0]B)").unwrap(); b }, simple_asymmetric_tree() ) } #[test] fn parser_test_3() { assert_eq!( { let (_, b) = parser::subtree("([0.5]([2.5]A[1.0]B)[1.0]C)").unwrap(); b }, complex_tree() ) } }