Skip to content

Commit

Permalink
organized the API and made the documentation friendly
Browse files Browse the repository at this point in the history
  • Loading branch information
comath committed Feb 10, 2021
1 parent c246b94 commit 94b71d9
Show file tree
Hide file tree
Showing 15 changed files with 449 additions and 233 deletions.
14 changes: 7 additions & 7 deletions goko/src/covertree/builders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ impl BuilderNode {
}
}

/// A construction object for a covertree. See [`crate::covertree::tree::CoverTreeParameters`] for docs
/// A construction object for a covertree. See [`crate::covertree::CoverTreeParameters`] for docs
#[derive(Debug)]
pub struct CoverTreeBuilder {
pub(crate) scale_base: f32,
Expand Down Expand Up @@ -347,32 +347,32 @@ impl CoverTreeBuilder {
}
}

/// See [`crate::covertree::tree::CoverTreeParameters`] for docs
/// See [`crate::covertree::CoverTreeParameters`] for docs
pub fn set_scale_base(&mut self, x: f32) -> &mut Self {
// Overwrite the builder's stored scale base with the caller's value.
self.scale_base = x;
// Hand the builder back so further setter calls can be chained.
self
}
/// See [`crate::covertree::tree::CoverTreeParameters`] for docs
/// See [`crate::covertree::CoverTreeParameters`] for docs
pub fn set_leaf_cutoff(&mut self, x: usize) -> &mut Self {
// Overwrite the builder's stored leaf cutoff with the caller's value.
self.leaf_cutoff = x;
// Hand the builder back so further setter calls can be chained.
self
}
/// See [`crate::covertree::tree::CoverTreeParameters`] for docs
/// See [`crate::covertree::CoverTreeParameters`] for docs
pub fn set_min_res_index(&mut self, x: i32) -> &mut Self {
// Overwrite the builder's stored minimum resolution index (signed, so it may be negative).
self.min_res_index = x;
// Hand the builder back so further setter calls can be chained.
self
}
/// See [`crate::covertree::tree::CoverTreeParameters`] for docs
/// See [`crate::covertree::CoverTreeParameters`] for docs
pub fn set_use_singletons(&mut self, x: bool) -> &mut Self {
// Overwrite the builder's stored `use_singletons` flag with the caller's value.
self.use_singletons = x;
// Hand the builder back so further setter calls can be chained.
self
}
/// See [`crate::covertree::tree::CoverTreeParameters`] for docs
/// See [`crate::covertree::CoverTreeParameters`] for docs
pub fn set_verbosity(&mut self, x: u32) -> &mut Self {
// Overwrite the builder's stored verbosity level with the caller's value.
self.verbosity = x;
// Hand the builder back so further setter calls can be chained.
self
}
/// See [`crate::covertree::tree::CoverTreeParameters`] for docs
/// See [`crate::covertree::CoverTreeParameters`] for docs
pub fn set_rng_seed(&mut self, x: u64) -> &mut Self {
self.rng_seed = Some(x);
self
Expand Down
7 changes: 4 additions & 3 deletions goko/src/covertree/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,14 @@ use std::iter::Rev;
use std::ops::Range;
use std::slice::Iter;
use std::ops::Deref;
use serde::{Deserialize, Serialize};

use plugins::labels::*;

/// When 2 spheres overlap under a node, and there is a point in the overlap we have to decide
/// to which sphere it belongs. As we create the nodes in a particular sequence, we can assign them
/// to the first to be created or we can assign it to the nearest.
#[derive(Debug, Copy, Clone)]
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
pub enum PartitionType {
/// Conflicts assigning a point to several eligible nodes are assigned to the nearest node.
Nearest,
Expand All @@ -80,14 +81,14 @@ pub struct CoverTreeParameters<D: PointCloud> {
pub use_singletons: bool,
/// The partition type of the tree
pub partition_type: PartitionType,
/// The point cloud this tree references
pub point_cloud: Arc<D>,
/// This should be replaced by a logging solution
pub verbosity: u32,
/// The seed to use for deterministic trees. This is xor-ed with the point index to create a seed for `rand::rngs::SmallRng`.
///
/// Pass in None if you want to use the host os's entropy instead.
pub rng_seed: Option<u64>,
/// The point cloud this tree references
pub point_cloud: Arc<D>,
/// This is where the base plugins are stored.
pub plugins: RwLock<TreePluginSet>,
}
Expand Down
2 changes: 1 addition & 1 deletion goko/src/plugins/discrete/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ impl Dirichlet {
None
}

/// from https://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/
/// from <https://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/>
/// We assume that the Dirichlet distribution passed into this one is conditioned on this one! It assumes they have the same keys!
pub fn kl_divergence(&self, other: &Dirichlet) -> Option<f64> {
let my_total = self.total();
Expand Down
9 changes: 4 additions & 5 deletions pointcloud/src/base_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@ pub trait PointRef: Send + Sync {
fn dense_iter(&self) -> Self::DenseIter;
}

/// Metric trait. Done as a trait so that it's easy to switch out.
/// Metric trait. Done as a trait so that it's easy to switch out.
///
/// Use a specific T. Don't implement it for a generic [S], but for [f32] or [u8], as you can use SIMD.
/// This library uses `packed_simd`.
/// Implement this then benchmark it to hell, this is the core loop of everything.
pub trait Metric<T: ?Sized>: Send + Sync + 'static {
/// Distance calculator. Optimize the hell out of this if you're implementing it.
fn dist(x: &T, y: &T) -> f32;
Expand All @@ -54,7 +53,7 @@ pub trait PointCloud: Send + Sync + 'static {
/// The metric this pointcloud is bound to. Think L2
type Metric: Metric<Self::Point>;
/// Name type, could be a string or any other type meeting the bounds below
type Name: Sized + Clone + Eq;
type Name: Sized + Clone + Eq + Serialize;
/// The label type.
/// Summary of a set of labels
type Label: ?Sized;
Expand Down Expand Up @@ -513,7 +512,7 @@ impl<D: PointCloud, L: LabelSet> PointCloud for SimpleLabeledCloud<D, L> {
/// Enables the points in the underlying cloud to be named with strings.
pub trait NamedSet: Send + Sync + 'static {
/// Name type, could be a string or any other type meeting the bounds below
type Name: Sized + Clone + Eq;
type Name: Sized + Clone + Eq + Serialize;
/// Number of elements in this name set
fn len(&self) -> usize;
/// If there are no elements left in this name set
Expand Down
12 changes: 6 additions & 6 deletions pointcloud/src/summaries/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ impl Summary for CategorySummary {
/// Summary of vectors
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct VecSummary {
/// First moment, see https://en.wikipedia.org/wiki/Moment_(mathematics)
/// First moment, see <https://en.wikipedia.org/wiki/Moment_(mathematics)>
pub moment1: Vec<f32>,
/// Second moment, see https://en.wikipedia.org/wiki/Moment_(mathematics)
/// Second moment, see <https://en.wikipedia.org/wiki/Moment_(mathematics)>
pub moment2: Vec<f32>,
/// The count of the number of labels included
pub count: usize,
Expand Down Expand Up @@ -117,9 +117,9 @@ impl Summary for VecSummary {
/// Summary of a bunch of underlying floats
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct FloatSummary {
/// First moment, see https://en.wikipedia.org/wiki/Moment_(mathematics)
/// First moment, see <https://en.wikipedia.org/wiki/Moment_(mathematics)>
pub moment1: f64,
/// Second moment, see https://en.wikipedia.org/wiki/Moment_(mathematics)
/// Second moment, see <https://en.wikipedia.org/wiki/Moment_(mathematics)>
pub moment2: f64,
/// The count of the number of labels included
pub count: usize,
Expand Down Expand Up @@ -147,9 +147,9 @@ impl Summary for FloatSummary {
/// Summary of a bunch of underlying integers, more accurate for int than the float summary
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
pub struct IntSummary {
/// First moment, see https://en.wikipedia.org/wiki/Moment_(mathematics)
/// First moment, see <https://en.wikipedia.org/wiki/Moment_(mathematics)>
pub moment1: i64,
/// Second moment, see https://en.wikipedia.org/wiki/Moment_(mathematics)
/// Second moment, see <https://en.wikipedia.org/wiki/Moment_(mathematics)>
pub moment2: i64,
/// The count of the number of labels included
pub count: usize,
Expand Down
1 change: 1 addition & 0 deletions serve_goko/Readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
hi
6 changes: 3 additions & 3 deletions serve_goko/examples/mnist_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use pointcloud::*;
use std::path::Path;
extern crate serve_goko;
use serve_goko::*;
use serve_goko::parser::MsgPackDense;

use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server};
use hyper::Server;

fn build_tree() -> CoverTreeWriter<DefaultLabeledCloud<L2>> {
let file_name = "../data/mnist_complex.yml";
Expand All @@ -23,7 +23,7 @@ pub async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
pretty_env_logger::init();
let ct_writer = build_tree();

let goko_server = MakeGokoService::new(MsgPackDense::new(), ct_writer);
let goko_server = MakeGokoHttpService::new(MsgPackDense::new(), ct_writer);

let addr = ([127, 0, 0, 1], 3030).into();

Expand Down
75 changes: 75 additions & 0 deletions serve_goko/src/api/knn.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use goko::CoverTreeReader;
use pointcloud::*;

use serde::{Deserialize, Serialize};
use std::ops::Deref;

use goko::errors::GokoError;

use super::{Process,NamedDistance};

/// A k-nearest-neighbors query against a cover tree.
///
/// Processing this request calls `CoverTreeReader::knn` with `point` and `k`.
///
/// Response: [`KnnResponse`]
#[derive(Deserialize, Serialize)]
pub struct KnnRequest<T> {
/// How many neighbors to ask the tree for.
pub k: usize,
/// The query point. The `Process` impl requires it to dereference to the
/// point cloud's `Point` type.
pub point: T,
}

/// The answer to a k-nearest-neighbors query, with neighbors identified by name.
///
/// Request: [`KnnRequest`]
#[derive(Deserialize, Serialize)]
pub struct KnnResponse<N> {
/// One entry per neighbor returned by the tree, in the order the tree
/// returned them. Each entry pairs the neighbor's name with its distance.
pub knn: Vec<NamedDistance<N>>,
}

/// Lets a [`KnnRequest`] be answered by a cover tree reader.
impl<D: PointCloud, T: Deref<Target = D::Point> + Send + Sync> Process<D> for KnnRequest<T> {
    type Response = KnnResponse<D::Name>;
    type Error = GokoError;

    /// Queries the tree for the `k` nearest neighbors of `point`, then swaps
    /// each returned point index for that point's name in the tree's point
    /// cloud.
    ///
    /// # Errors
    /// Returns a [`GokoError`] if the query itself fails or if any name lookup
    /// fails; the first failing lookup aborts the conversion.
    fn process(self, reader: &CoverTreeReader<D>) -> Result<Self::Response, Self::Error> {
        let neighbors = reader.knn(&self.point, self.k)?;
        let point_cloud = &reader.parameters().point_cloud;

        let named = neighbors
            .iter()
            .map(|(dist, index)| -> Result<NamedDistance<D::Name>, GokoError> {
                // Look the name up first so a failed lookup short-circuits the collect.
                let name = point_cloud.name(*index)?;
                Ok(NamedDistance {
                    name,
                    distance: *dist,
                })
            })
            .collect::<Result<Vec<_>, GokoError>>()?;

        Ok(KnnResponse { knn: named })
    }
}

/// A routing k-nearest-neighbors query against a cover tree.
///
/// Processing this request calls `CoverTreeReader::routing_knn` with `point`
/// and `k`.
/// NOTE(review): presumably this restricts the search to routing nodes —
/// confirm against the goko documentation.
///
/// Response: [`RoutingKnnResponse`]
#[derive(Deserialize, Serialize)]
pub struct RoutingKnnRequest<T> {
/// How many neighbors to ask the tree for.
pub k: usize,
/// The query point. The `Process` impl requires it to dereference to the
/// point cloud's `Point` type.
pub point: T,
}

/// The answer to a routing k-nearest-neighbors query, with neighbors
/// identified by name.
///
/// Request: [`RoutingKnnRequest`]
#[derive(Deserialize, Serialize)]
pub struct RoutingKnnResponse<N> {
/// One entry per neighbor returned by the tree, in the order the tree
/// returned them. Each entry pairs the neighbor's name with its distance.
pub routing_knn: Vec<NamedDistance<N>>,
}

/// Lets a [`RoutingKnnRequest`] be answered by a cover tree reader.
impl<D: PointCloud, T: Deref<Target = D::Point> + Send + Sync> Process<D> for RoutingKnnRequest<T> {
    type Response = RoutingKnnResponse<D::Name>;
    type Error = GokoError;

    /// Queries the tree via `routing_knn` for the `k` nearest neighbors of
    /// `point`, then swaps each returned point index for that point's name in
    /// the tree's point cloud.
    ///
    /// # Errors
    /// Returns a [`GokoError`] if the query itself fails or if any name lookup
    /// fails; the first failing lookup aborts the conversion.
    fn process(self, reader: &CoverTreeReader<D>) -> Result<Self::Response, Self::Error> {
        let neighbors = reader.routing_knn(&self.point, self.k)?;
        let point_cloud = &reader.parameters().point_cloud;

        let named = neighbors
            .iter()
            .map(|(dist, index)| -> Result<NamedDistance<D::Name>, GokoError> {
                // Look the name up first so a failed lookup short-circuits the collect.
                let name = point_cloud.name(*index)?;
                Ok(NamedDistance {
                    name,
                    distance: *dist,
                })
            })
            .collect::<Result<Vec<_>, GokoError>>()?;

        Ok(RoutingKnnResponse { routing_knn: named })
    }
}
Loading

0 comments on commit 94b71d9

Please sign in to comment.