From d1dece78fe04230ea1ded095085a3e517cac3b0f Mon Sep 17 00:00:00 2001 From: vorboyvo Date: Wed, 24 Sep 2025 13:21:20 -0400 Subject: [PATCH] Factored out tree, made Vertex (and thus tree) equal over symmetry, added tests. --- src/main.rs | 81 +------------------------------------------ src/parser.rs | 2 +- src/test.rs | 20 ++++++++++- src/tree.rs | 95 +++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 116 insertions(+), 82 deletions(-) create mode 100644 src/tree.rs diff --git a/src/main.rs b/src/main.rs index 686957a..fcf4126 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,88 +1,9 @@ -use num_traits::Float; - mod parser; +mod tree; #[cfg(test)] mod test; -/* - * NOTES - * - each alternative should appear only once in a tree. need to figure out a - * way to enforce this. - * - looks like this is like quite difficult to do through the type system - */ - -#[derive(PartialEq, Eq, Debug)] -pub struct Alternative { - name: T -} - -#[derive(PartialEq, Eq, Debug)] -pub struct Edge { - weight: U, - destination: Tree -} - -#[derive(PartialEq, Eq, Debug)] -pub enum Vertex { - NonTerminal(Box>, Box>), - Terminal(Alternative), -} - -#[derive(PartialEq, Eq, Debug)] -pub struct Tree { - root: Vertex -} - -impl Tree { - fn left_edge_weights(&self) -> U { - match &self.root { - Vertex::NonTerminal(a, _) => { - let left = &*a; - left.weight - + left.destination.left_edge_weights() - + left.destination.right_edge_weights() - }, - Vertex::Terminal(_) => U::zero() - } - } - - fn right_edge_weights(&self) -> U { - match &self.root { - Vertex::NonTerminal(_, b) => { - let right = &*b; - right.weight - + right.destination.left_edge_weights() - + right.destination.right_edge_weights() - }, - Vertex::Terminal(_) => U::zero() - } - } - - pub fn choice_probability(&self, alternative: &T) -> U { - match &self.root { - Vertex::Terminal(alt) => { - if alt.name == *alternative { U::one() } else { U::zero() } - }, - Vertex::NonTerminal(a, b) => { - let left = &*a; - let right = &*b; - let left_edge_weights = self.left_edge_weights(); - let right_edge_weights = self.right_edge_weights(); - let left_choice_probability = - left.destination.choice_probability(alternative); - let right_choice_probability = - right.destination.choice_probability(alternative); - ( - left_edge_weights * left_choice_probability - + right_edge_weights * right_choice_probability - ) - /(left_edge_weights + right_edge_weights) - } - } - } -} - fn main() { unimplemented!(); } diff --git a/src/parser.rs b/src/parser.rs index e26bf4a..ce75f90 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -13,7 +13,7 @@ use nom::{branch::alt, bytes::complete::take_while1, character::char, combinator::map, error::Error, number::complete::double, sequence::{delimited, pair}, AsChar, IResult, Parser}; -use crate::{Alternative, Edge, Tree, Vertex}; +use crate::tree::{Alternative, Edge, Tree, Vertex}; fn weighted_edge(input: &str) -> IResult<&str, Edge> { let weight = delimited( diff --git a/src/test.rs b/src/test.rs index 2a0c96e..5102e59 100644 --- a/src/test.rs +++ b/src/test.rs @@ -1,6 +1,8 @@ -use super::*; +use super::tree::*; +use super::parser; use assert_approx_eq::assert_approx_eq; +// helper methods for testing parser, return constant heap-allocated value fn simple_symmetric_tree() -> Tree { let a = Alternative { name: "A".to_owned() @@ -20,6 +22,7 @@ fn simple_symmetric_tree() -> Tree { Tree {root: root} } +// helper methods for testing parser, return constant heap-allocated value fn simple_asymmetric_tree() -> Tree { let a = Alternative { name: "A".to_owned() @@ -39,6 +42,7 @@ fn simple_asymmetric_tree() -> Tree { Tree {root: root} } +// helper methods for testing parser, return constant heap-allocated value fn complex_tree() -> Tree { let a = Alternative { name: "A".to_owned() @@ -126,3 +130,17 @@ fn parser_test_3() { complex_tree() ) } + +#[test] +fn test_symmetry_positive() { + let (_, a) = parser::subtree("([3.0]A[1.0]B)").unwrap(); + let (_, b) = parser::subtree("([1.0]B[3.0]A)").unwrap(); + assert_eq!(a, b) +} + +#[test] +fn test_symmetry_negative() { + let (_, a) = parser::subtree("([3.0]A[1.0]B)").unwrap(); + let (_, b) = parser::subtree("([1.0]A[3.0]B)").unwrap(); + assert_ne!(a, b) +} diff --git a/src/tree.rs b/src/tree.rs new file mode 100644 index 0000000..82559db --- /dev/null +++ b/src/tree.rs @@ -0,0 +1,95 @@ +use num_traits::Float; + +#[derive(PartialEq, Eq, Debug)] +pub struct Alternative { + pub name: T +} + +#[derive(PartialEq, Debug)] +pub struct Edge { + pub weight: U, + pub destination: Tree +} + +#[derive(Debug)] +pub enum Vertex { + NonTerminal(Box>, Box>), + Terminal(Alternative), +} + +// implement symmetry +impl PartialEq for Vertex{ + fn eq(&self, other: &Self) -> bool { + match self { + Vertex::NonTerminal(a, b) => { + if let Vertex::NonTerminal(c, d) = other { + (a == c && b == d) || (a == d && b == c) + } else { + false + } + } + Vertex::Terminal(a) => { + if let Vertex::Terminal(b) = other { + a == b + } else { + false + } + } + } + } +} + +#[derive(PartialEq, Debug)] +pub struct Tree { + pub root: Vertex +} + +impl Tree { + fn left_edge_weights(&self) -> U { + match &self.root { + Vertex::NonTerminal(a, _) => { + let left = &*a; + left.weight + + left.destination.left_edge_weights() + + left.destination.right_edge_weights() + }, + Vertex::Terminal(_) => U::zero() + } + } + + fn right_edge_weights(&self) -> U { + match &self.root { + Vertex::NonTerminal(_, b) => { + let right = &*b; + right.weight + + right.destination.left_edge_weights() + + right.destination.right_edge_weights() + }, + Vertex::Terminal(_) => U::zero() + } + } + + pub fn choice_probability(&self, alternative: &T) -> U { + match &self.root { + Vertex::Terminal(alt) => { + if alt.name == *alternative { U::one() } else { U::zero() } + }, + Vertex::NonTerminal(a, b) => { + let left = &*a; + let right = &*b; + let left_edge_weights = self.left_edge_weights(); + let right_edge_weights = self.right_edge_weights(); + let left_choice_probability = + left.destination.choice_probability(alternative); + let right_choice_probability = + right.destination.choice_probability(alternative); + ( + left_edge_weights * left_choice_probability + + right_edge_weights * right_choice_probability + ) + /(left_edge_weights + right_edge_weights) + } + } + } +} +