Factored out tree, made Vertex (and thus tree) equal over symmetry, added tests.
This commit is contained in:
parent
2a002bd286
commit
d1dece78fe
81
src/main.rs
81
src/main.rs
|
|
@ -1,88 +1,9 @@
|
||||||
use num_traits::Float;
|
|
||||||
|
|
||||||
mod parser;
|
mod parser;
|
||||||
|
mod tree;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
|
|
||||||
/*
|
|
||||||
* NOTES
|
|
||||||
* - each alternative should appear only once in a tree. need to figure out a
|
|
||||||
* way to enforce this.
|
|
||||||
* - looks like this is like quite difficult to do through the type system
|
|
||||||
*/
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Debug)]
|
|
||||||
pub struct Alternative<T: Eq> {
|
|
||||||
name: T
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Debug)]
|
|
||||||
pub struct Edge<T: Eq, U: Float> {
|
|
||||||
weight: U,
|
|
||||||
destination: Tree<T, U>
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Debug)]
|
|
||||||
pub enum Vertex<T: Eq, U: Float> {
|
|
||||||
NonTerminal(Box<Edge<T, U>>, Box<Edge<T, U>>),
|
|
||||||
Terminal(Alternative<T>),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(PartialEq, Eq, Debug)]
|
|
||||||
pub struct Tree<T: Eq, U: Float> {
|
|
||||||
root: Vertex<T, U>
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: Eq, U: Float> Tree<T, U> {
|
|
||||||
fn left_edge_weights(&self) -> U {
|
|
||||||
match &self.root {
|
|
||||||
Vertex::NonTerminal(a, _) => {
|
|
||||||
let left = &*a;
|
|
||||||
left.weight
|
|
||||||
+ left.destination.left_edge_weights()
|
|
||||||
+ left.destination.right_edge_weights()
|
|
||||||
},
|
|
||||||
Vertex::Terminal(_) => U::zero()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn right_edge_weights(&self) -> U {
|
|
||||||
match &self.root {
|
|
||||||
Vertex::NonTerminal(_, b) => {
|
|
||||||
let right = &*b;
|
|
||||||
right.weight
|
|
||||||
+ right.destination.left_edge_weights()
|
|
||||||
+ right.destination.right_edge_weights()
|
|
||||||
},
|
|
||||||
Vertex::Terminal(_) => U::zero()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn choice_probability(&self, alternative: &T) -> U {
|
|
||||||
match &self.root {
|
|
||||||
Vertex::Terminal(alt) => {
|
|
||||||
if alt.name == *alternative { U::one() } else { U::zero() }
|
|
||||||
},
|
|
||||||
Vertex::NonTerminal(a, b) => {
|
|
||||||
let left = &*a;
|
|
||||||
let right = &*b;
|
|
||||||
let left_edge_weights = self.left_edge_weights();
|
|
||||||
let right_edge_weights = self.right_edge_weights();
|
|
||||||
let left_choice_probability =
|
|
||||||
left.destination.choice_probability(alternative);
|
|
||||||
let right_choice_probability =
|
|
||||||
right.destination.choice_probability(alternative);
|
|
||||||
(
|
|
||||||
left_edge_weights * left_choice_probability
|
|
||||||
+ right_edge_weights * right_choice_probability
|
|
||||||
)
|
|
||||||
/(left_edge_weights + right_edge_weights)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
unimplemented!();
|
unimplemented!();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@
|
||||||
|
|
||||||
use nom::{branch::alt, bytes::complete::take_while1, character::char, combinator::map, error::Error, number::complete::double, sequence::{delimited, pair}, AsChar, IResult, Parser};
|
use nom::{branch::alt, bytes::complete::take_while1, character::char, combinator::map, error::Error, number::complete::double, sequence::{delimited, pair}, AsChar, IResult, Parser};
|
||||||
|
|
||||||
use crate::{Alternative, Edge, Tree, Vertex};
|
use crate::tree::{Alternative, Edge, Tree, Vertex};
|
||||||
|
|
||||||
fn weighted_edge(input: &str) -> IResult<&str, Edge<String, f64>> {
|
fn weighted_edge(input: &str) -> IResult<&str, Edge<String, f64>> {
|
||||||
let weight = delimited(
|
let weight = delimited(
|
||||||
|
|
|
||||||
20
src/test.rs
20
src/test.rs
|
|
@ -1,6 +1,8 @@
|
||||||
use super::*;
|
use super::tree::*;
|
||||||
|
use super::parser;
|
||||||
use assert_approx_eq::assert_approx_eq;
|
use assert_approx_eq::assert_approx_eq;
|
||||||
|
|
||||||
|
// helper methods for testing parser, return constant heap-allocated value
|
||||||
fn simple_symmetric_tree() -> Tree<String, f64> {
|
fn simple_symmetric_tree() -> Tree<String, f64> {
|
||||||
let a = Alternative {
|
let a = Alternative {
|
||||||
name: "A".to_owned()
|
name: "A".to_owned()
|
||||||
|
|
@ -20,6 +22,7 @@ fn simple_symmetric_tree() -> Tree<String, f64> {
|
||||||
Tree {root: root}
|
Tree {root: root}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// helper methods for testing parser, return constant heap-allocated value
|
||||||
fn simple_asymmetric_tree() -> Tree<String, f64> {
|
fn simple_asymmetric_tree() -> Tree<String, f64> {
|
||||||
let a = Alternative {
|
let a = Alternative {
|
||||||
name: "A".to_owned()
|
name: "A".to_owned()
|
||||||
|
|
@ -39,6 +42,7 @@ fn simple_asymmetric_tree() -> Tree<String, f64> {
|
||||||
Tree {root: root}
|
Tree {root: root}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// helper methods for testing parser, return constant heap-allocated value
|
||||||
fn complex_tree() -> Tree<String, f64> {
|
fn complex_tree() -> Tree<String, f64> {
|
||||||
let a = Alternative {
|
let a = Alternative {
|
||||||
name: "A".to_owned()
|
name: "A".to_owned()
|
||||||
|
|
@ -126,3 +130,17 @@ fn parser_test_3() {
|
||||||
complex_tree()
|
complex_tree()
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_symmetry_positive() {
|
||||||
|
let (_, a) = parser::subtree("([3.0]A[1.0]B)").unwrap();
|
||||||
|
let (_, b) = parser::subtree("([1.0]B[3.0]A)").unwrap();
|
||||||
|
assert_eq!(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_symmetry_negative() {
|
||||||
|
let (_, a) = parser::subtree("([3.0]A[1.0]B)").unwrap();
|
||||||
|
let (_, b) = parser::subtree("([1.0]A[3.0]B)").unwrap();
|
||||||
|
assert_ne!(a, b)
|
||||||
|
}
|
||||||
|
|
|
||||||
95
src/tree.rs
Normal file
95
src/tree.rs
Normal file
|
|
@ -0,0 +1,95 @@
|
||||||
|
use num_traits::Float;
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Debug)]
|
||||||
|
pub struct Alternative<T: Eq> {
|
||||||
|
pub name: T
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Debug)]
|
||||||
|
pub struct Edge<T: Eq, U: Float> {
|
||||||
|
pub weight: U,
|
||||||
|
pub destination: Tree<T, U>
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum Vertex<T: Eq, U: Float> {
|
||||||
|
NonTerminal(Box<Edge<T, U>>, Box<Edge<T, U>>),
|
||||||
|
Terminal(Alternative<T>),
|
||||||
|
}
|
||||||
|
|
||||||
|
// implement symmetry
|
||||||
|
impl<T: Eq, U: Float> PartialEq for Vertex<T, U>{
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
match self {
|
||||||
|
Vertex::NonTerminal(a, b) => {
|
||||||
|
if let Vertex::NonTerminal(c, d) = other {
|
||||||
|
(a == c && b == d) || (a == d && b == c)
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Vertex::Terminal(a) => {
|
||||||
|
if let Vertex::Terminal(b) = other {
|
||||||
|
a == b
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(PartialEq, Debug)]
|
||||||
|
pub struct Tree<T: Eq, U: Float> {
|
||||||
|
pub root: Vertex<T, U>
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Eq, U: Float> Tree<T, U> {
|
||||||
|
fn left_edge_weights(&self) -> U {
|
||||||
|
match &self.root {
|
||||||
|
Vertex::NonTerminal(a, _) => {
|
||||||
|
let left = &*a;
|
||||||
|
left.weight
|
||||||
|
+ left.destination.left_edge_weights()
|
||||||
|
+ left.destination.right_edge_weights()
|
||||||
|
},
|
||||||
|
Vertex::Terminal(_) => U::zero()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn right_edge_weights(&self) -> U {
|
||||||
|
match &self.root {
|
||||||
|
Vertex::NonTerminal(_, b) => {
|
||||||
|
let right = &*b;
|
||||||
|
right.weight
|
||||||
|
+ right.destination.left_edge_weights()
|
||||||
|
+ right.destination.right_edge_weights()
|
||||||
|
},
|
||||||
|
Vertex::Terminal(_) => U::zero()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn choice_probability(&self, alternative: &T) -> U {
|
||||||
|
match &self.root {
|
||||||
|
Vertex::Terminal(alt) => {
|
||||||
|
if alt.name == *alternative { U::one() } else { U::zero() }
|
||||||
|
},
|
||||||
|
Vertex::NonTerminal(a, b) => {
|
||||||
|
let left = &*a;
|
||||||
|
let right = &*b;
|
||||||
|
let left_edge_weights = self.left_edge_weights();
|
||||||
|
let right_edge_weights = self.right_edge_weights();
|
||||||
|
let left_choice_probability =
|
||||||
|
left.destination.choice_probability(alternative);
|
||||||
|
let right_choice_probability =
|
||||||
|
right.destination.choice_probability(alternative);
|
||||||
|
(
|
||||||
|
left_edge_weights * left_choice_probability
|
||||||
|
+ right_edge_weights * right_choice_probability
|
||||||
|
)
|
||||||
|
/(left_edge_weights + right_edge_weights)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Loading…
Reference in a new issue