Initial commit.
This commit is contained in:
commit
6ae8385f7d
32
Cargo.lock
generated
Normal file
32
Cargo.lock
generated
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "assert_approx_eq"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3c07dab4369547dbe5114677b33fbbf724971019f3818172d59a97a61c774ffd"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
||||
|
||||
[[package]]
|
||||
name = "code"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"assert_approx_eq",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
8
Cargo.toml
Normal file
8
Cargo.toml
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
[package]
|
||||
name = "code"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
assert_approx_eq = "1.1.0"
|
||||
num-traits = "0.2"
|
||||
61
flake.lock
Normal file
61
flake.lock
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
{
|
||||
"nodes": {
|
||||
"flake-utils": {
|
||||
"inputs": {
|
||||
"systems": "systems"
|
||||
},
|
||||
"locked": {
|
||||
"lastModified": 1731533236,
|
||||
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "numtide",
|
||||
"repo": "flake-utils",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1753694789,
|
||||
"narHash": "sha256-cKgvtz6fKuK1Xr5LQW/zOUiAC0oSQoA9nOISB0pJZqM=",
|
||||
"owner": "nixos",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "dc9637876d0dcc8c9e5e22986b857632effeb727",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nixos",
|
||||
"ref": "nixos-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"flake-utils": "flake-utils",
|
||||
"nixpkgs": "nixpkgs"
|
||||
}
|
||||
},
|
||||
"systems": {
|
||||
"locked": {
|
||||
"lastModified": 1681028828,
|
||||
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "nix-systems",
|
||||
"repo": "default",
|
||||
"type": "github"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
||||
29
flake.nix
Normal file
29
flake.nix
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
{
|
||||
description = "A very basic flake";
|
||||
|
||||
inputs = {
|
||||
nixpkgs.url = "github:nixos/nixpkgs?ref=nixos-unstable";
|
||||
flake-utils.url = "github:numtide/flake-utils";
|
||||
};
|
||||
|
||||
outputs =
|
||||
{
|
||||
self,
|
||||
nixpkgs,
|
||||
flake-utils,
|
||||
}:
|
||||
flake-utils.lib.eachDefaultSystem (
|
||||
system:
|
||||
let
|
||||
pkgs = nixpkgs.legacyPackages.${system};
|
||||
in
|
||||
{
|
||||
devShell = pkgs.mkShell {
|
||||
buildInputs = with pkgs; [
|
||||
cargo rustc rust-analyzer clippy
|
||||
];
|
||||
};
|
||||
formatter = pkgs.nixfmt-rfc-style;
|
||||
}
|
||||
);
|
||||
}
|
||||
148
src/main.rs
Normal file
148
src/main.rs
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
use num_traits::Float;
|
||||
|
||||
/*
|
||||
* NOTES
|
||||
* - each alternative should appear only once in a tree. need to figure out a
|
||||
* way to enforce this.
|
||||
* - looks like this is like quite difficult to do through the type system
|
||||
*/
|
||||
|
||||
#[derive(PartialEq, Eq)]
|
||||
struct Alternative<T: Eq> {
|
||||
name: T
|
||||
}
|
||||
|
||||
struct Edge<T: Eq, U: Float> {
|
||||
weight: U,
|
||||
destination: Tree<T, U>
|
||||
}
|
||||
|
||||
enum Vertex<T: Eq, U: Float> {
|
||||
NonTerminal(Box<Edge<T, U>>, Box<Edge<T, U>>),
|
||||
Terminal(Alternative<T>),
|
||||
}
|
||||
|
||||
struct Tree<T: Eq, U: Float> {
|
||||
root: Vertex<T, U>
|
||||
}
|
||||
|
||||
impl<T: Eq, U: Float> Tree<T, U> {
|
||||
fn left_edge_weights(&self) -> U {
|
||||
match &self.root {
|
||||
Vertex::NonTerminal(a, _) => {
|
||||
let left = &*a;
|
||||
left.weight
|
||||
+ left.destination.left_edge_weights()
|
||||
+ left.destination.right_edge_weights()
|
||||
},
|
||||
Vertex::Terminal(_) => U::zero()
|
||||
}
|
||||
}
|
||||
|
||||
fn right_edge_weights(&self) -> U {
|
||||
match &self.root {
|
||||
Vertex::NonTerminal(_, b) => {
|
||||
let right = &*b;
|
||||
right.weight
|
||||
+ right.destination.left_edge_weights()
|
||||
+ right.destination.right_edge_weights()
|
||||
},
|
||||
Vertex::Terminal(_) => U::zero()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn choice_probability(&self, alternative: &T) -> U {
|
||||
match &self.root {
|
||||
Vertex::Terminal(alt) => {
|
||||
if alt.name == *alternative { U::one() } else { U::zero() }
|
||||
},
|
||||
Vertex::NonTerminal(a, b) => {
|
||||
let left = &*a;
|
||||
let right = &*b;
|
||||
let left_edge_weights = self.left_edge_weights();
|
||||
let right_edge_weights = self.right_edge_weights();
|
||||
let left_choice_probability =
|
||||
left.destination.choice_probability(alternative);
|
||||
let right_choice_probability =
|
||||
right.destination.choice_probability(alternative);
|
||||
(
|
||||
left_edge_weights * left_choice_probability
|
||||
+ right_edge_weights * right_choice_probability
|
||||
)
|
||||
/(left_edge_weights + right_edge_weights)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
unimplemented!();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use assert_approx_eq::assert_approx_eq;
|
||||
|
||||
// A test for the simplest symmetric case
|
||||
#[test]
|
||||
fn choice_probability_test_1() {
|
||||
let a = Alternative {name: "A"};
|
||||
let b = Alternative {name: "B"};
|
||||
let edge_a = Edge {
|
||||
weight: 1.0,
|
||||
destination: Tree {root: Vertex::Terminal(a)}
|
||||
};
|
||||
let edge_b = Edge {
|
||||
weight: 1.0,
|
||||
destination: Tree {root: Vertex::Terminal(b)}
|
||||
};
|
||||
let root = Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b));
|
||||
let tree = Tree {root: root};
|
||||
assert_eq!(tree.choice_probability(&"A"), 0.5);
|
||||
}
|
||||
|
||||
// A test for the simplest asymmetric case
|
||||
#[test]
|
||||
fn choice_probability_test_2() {
|
||||
let a = Alternative {name: "A"};
|
||||
let b = Alternative {name: "B"};
|
||||
let edge_a = Edge {weight: 3.0, destination: Tree {root: Vertex::Terminal(a)}};
|
||||
let edge_b = Edge {weight: 1.0, destination: Tree {root: Vertex::Terminal(b)}};
|
||||
let root = Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b));
|
||||
let tree = Tree {root: root};
|
||||
assert_eq!(tree.choice_probability(&"A"), 0.75);
|
||||
}
|
||||
|
||||
// A test for depth higher than 1
|
||||
#[test]
|
||||
fn choice_probability_test_3() {
|
||||
let a = Alternative {name: "A"};
|
||||
let b = Alternative {name: "B"};
|
||||
let c = Alternative {name: "C"};
|
||||
let edge_a = Edge {
|
||||
weight: 2.5,
|
||||
destination: Tree {root: Vertex::Terminal(a)}
|
||||
};
|
||||
let edge_b = Edge {
|
||||
weight: 1.0,
|
||||
destination: Tree {root: Vertex::Terminal(b)}
|
||||
};
|
||||
let edge_ab = Edge{
|
||||
weight: 0.5,
|
||||
destination: Tree {
|
||||
root: Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b))
|
||||
}
|
||||
};
|
||||
let edge_c = Edge{
|
||||
weight: 1.0,
|
||||
destination: Tree {root: Vertex::Terminal(c)}
|
||||
};
|
||||
let root = Vertex::NonTerminal(Box::new(edge_ab), Box::new(edge_c));
|
||||
let tree = Tree {root: root};
|
||||
assert_approx_eq!(
|
||||
tree.choice_probability(&"A"),
|
||||
(2.5/(2.5+1.0))*((2.5+1.0+0.5)/(2.5+1.0+0.5+1.0))
|
||||
);
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue