From 6ae8385f7d9948874eebf9b9dd04c6ab92f6b4d2 Mon Sep 17 00:00:00 2001 From: vorboyvo Date: Wed, 10 Sep 2025 19:22:29 -0400 Subject: [PATCH] Initial commit. --- Cargo.lock | 32 ++++++++++++ Cargo.toml | 8 +++ flake.lock | 61 ++++++++++++++++++++++ flake.nix | 29 ++++++++++ src/main.rs | 148 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 278 insertions(+) create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 flake.lock create mode 100644 flake.nix create mode 100644 src/main.rs diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..fb83e0f --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,32 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "assert_approx_eq" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c07dab4369547dbe5114677b33fbbf724971019f3818172d59a97a61c774ffd" + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "code" +version = "0.1.0" +dependencies = [ + "assert_approx_eq", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..062a56c --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "code" +version = "0.1.0" +edition = "2024" + +[dependencies] +assert_approx_eq = "1.1.0" +num-traits = "0.2" diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..c3165d8 --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1753694789, + "narHash": "sha256-cKgvtz6fKuK1Xr5LQW/zOUiAC0oSQoA9nOISB0pJZqM=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "dc9637876d0dcc8c9e5e22986b857632effeb727", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..b86984a --- /dev/null +++ b/flake.nix @@ -0,0 +1,29 @@ +{ + description = "A very basic flake"; + + inputs = { + nixpkgs.url = "github:nixos/nixpkgs?ref=nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = + { + self, + nixpkgs, + flake-utils, + }: + flake-utils.lib.eachDefaultSystem ( + system: + let + pkgs = nixpkgs.legacyPackages.${system}; + in + { + devShell = pkgs.mkShell { + buildInputs = with pkgs; [ + cargo rustc rust-analyzer clippy + ]; + }; + formatter = pkgs.nixfmt-rfc-style; + } + ); +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..c803791 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,148 @@ +use num_traits::Float; + +/* + * NOTES + * - each alternative should appear only once in a tree. need to figure out a + * way to enforce this. + * - looks like this is like quite difficult to do through the type system + */ + +#[derive(PartialEq, Eq)] +struct Alternative { + name: T +} + +struct Edge { + weight: U, + destination: Tree +} + +enum Vertex { + NonTerminal(Box>, Box>), + Terminal(Alternative), +} + +struct Tree { + root: Vertex +} + +impl Tree { + fn left_edge_weights(&self) -> U { + match &self.root { + Vertex::NonTerminal(a, _) => { + let left = &*a; + left.weight + + left.destination.left_edge_weights() + + left.destination.right_edge_weights() + }, + Vertex::Terminal(_) => U::zero() + } + } + + fn right_edge_weights(&self) -> U { + match &self.root { + Vertex::NonTerminal(_, b) => { + let right = &*b; + right.weight + + right.destination.left_edge_weights() + + right.destination.right_edge_weights() + }, + Vertex::Terminal(_) => U::zero() + } + } + + pub fn choice_probability(&self, alternative: &T) -> U { + match &self.root { + Vertex::Terminal(alt) => { + if alt.name == *alternative { U::one() } else { U::zero() } + }, + Vertex::NonTerminal(a, b) => { + let left = &*a; + let right = &*b; + let left_edge_weights = self.left_edge_weights(); + let right_edge_weights = self.right_edge_weights(); + let left_choice_probability = + left.destination.choice_probability(alternative); + let right_choice_probability = + right.destination.choice_probability(alternative); + ( + left_edge_weights * left_choice_probability + + right_edge_weights * right_choice_probability + ) + /(left_edge_weights + right_edge_weights) + } + } + } +} + +fn main() { + unimplemented!(); +} + +#[cfg(test)] +mod tests { + use super::*; + use assert_approx_eq::assert_approx_eq; + + // A test for the simplest symmetric case + #[test] + fn choice_probability_test_1() { + let a = Alternative {name: "A"}; + let b = Alternative {name: "B"}; + let edge_a = Edge { + weight: 1.0, + destination: Tree {root: Vertex::Terminal(a)} + }; + let edge_b = Edge { + weight: 1.0, + destination: Tree {root: Vertex::Terminal(b)} + }; + let root = Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b)); + let tree = Tree {root: root}; + assert_eq!(tree.choice_probability(&"A"), 0.5); + } + + // A test for the simplest asymmetric case + #[test] + fn choice_probability_test_2() { + let a = Alternative {name: "A"}; + let b = Alternative {name: "B"}; + let edge_a = Edge {weight: 3.0, destination: Tree {root: Vertex::Terminal(a)}}; + let edge_b = Edge {weight: 1.0, destination: Tree {root: Vertex::Terminal(b)}}; + let root = Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b)); + let tree = Tree {root: root}; + assert_eq!(tree.choice_probability(&"A"), 0.75); + } + + // A test for depth higher than 1 + #[test] + fn choice_probability_test_3() { + let a = Alternative {name: "A"}; + let b = Alternative {name: "B"}; + let c = Alternative {name: "C"}; + let edge_a = Edge { + weight: 2.5, + destination: Tree {root: Vertex::Terminal(a)} + }; + let edge_b = Edge { + weight: 1.0, + destination: Tree {root: Vertex::Terminal(b)} + }; + let edge_ab = Edge{ + weight: 0.5, + destination: Tree { + root: Vertex::NonTerminal(Box::new(edge_a), Box::new(edge_b)) + } + }; + let edge_c = Edge{ + weight: 1.0, + destination: Tree {root: Vertex::Terminal(c)} + }; + let root = Vertex::NonTerminal(Box::new(edge_ab), Box::new(edge_c)); + let tree = Tree {root: root}; + assert_approx_eq!( + tree.choice_probability(&"A"), + (2.5/(2.5+1.0))*((2.5+1.0+0.5)/(2.5+1.0+0.5+1.0)) + ); + } +}