use num_traits::Float; /* * NOTES * - each alternative should appear only once in a tree. need to figure out a * way to enforce this. * - looks like this is like quite difficult to do through the type system */ #[derive(PartialEq, Eq)] struct Alternative { name: T } struct Edge { weight: U, destination: Tree } enum Vertex { NonTerminal(Box>, Box>), Terminal(Alternative), } struct Tree { root: Vertex } impl Tree { fn left_edge_weights(&self) -> U { match &self.root { Vertex::NonTerminal(a, _) => { let left = &*a; left.weight + left.destination.left_edge_weights() + left.destination.right_edge_weights() }, Vertex::Terminal(_) => U::zero() } } fn right_edge_weights(&self) -> U { match &self.root { Vertex::NonTerminal(_, b) => { let right = &*b; right.weight + right.destination.left_edge_weights() + right.destination.right_edge_weights() }, Vertex::Terminal(_) => U::zero() } } pub fn choice_probability(&self, alternative: &T) -> U { match &self.root { Vertex::Terminal(alt) => { if alt.name == *alternative { U::one() } else { U::zero() } }, Vertex::NonTerminal(a, b) => { let left = &*a; let right = &*b; let left_edge_weights = self.left_edge_weights(); let right_edge_weights = self.right_edge_weights(); let left_choice_probability = left.destination.choice_probability(alternative); let right_choice_probability = right.destination.choice_probability(alternative); ( left_edge_weights * left_choice_probability + right_edge_weights * right_choice_probability ) /(left_edge_weights + right_edge_weights) } } } } fn main() { unimplemented!(); } #[cfg(test)] mod tests { use super::*; use assert_approx_eq::assert_approx_eq; // A test for the simplest symmetric case #[test] fn choice_probability_test_1() { let a = Alternative {name: "A"}; let b = Alternative {name: "B"}; let edge_a = Edge { weight: 1.0, destination: Tree {root: Vertex::Terminal(a)} }; let edge_b = Edge { weight: 1.0, destination: Tree {root: Vertex::Terminal(b)} }; let root = Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b)); let tree = Tree {root: root}; assert_eq!(tree.choice_probability(&"A"), 0.5); } // A test for the simplest asymmetric case #[test] fn choice_probability_test_2() { let a = Alternative {name: "A"}; let b = Alternative {name: "B"}; let edge_a = Edge {weight: 3.0, destination: Tree {root: Vertex::Terminal(a)}}; let edge_b = Edge {weight: 1.0, destination: Tree {root: Vertex::Terminal(b)}}; let root = Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b)); let tree = Tree {root: root}; assert_eq!(tree.choice_probability(&"A"), 0.75); } // A test for depth higher than 1 #[test] fn choice_probability_test_3() { let a = Alternative {name: "A"}; let b = Alternative {name: "B"}; let c = Alternative {name: "C"}; let edge_a = Edge { weight: 2.5, destination: Tree {root: Vertex::Terminal(a)} }; let edge_b = Edge { weight: 1.0, destination: Tree {root: Vertex::Terminal(b)} }; let edge_ab = Edge{ weight: 0.5, destination: Tree { root: Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b)) } }; let edge_c = Edge{ weight: 1.0, destination: Tree {root: Vertex::Terminal(c)} }; let root = Vertex::NonTerminal(Box::new(edge_ab), Box::new(edge_c)); let tree = Tree {root: root}; assert_approx_eq!( tree.choice_probability(&"A"), (2.5/(2.5+1.0))*((2.5+1.0+0.5)/(2.5+1.0+0.5+1.0)) ); } }