Skip to content

Commit

Permalink
Fix covariance and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vks committed Dec 14, 2023
1 parent 8ec2dd6 commit 339fa05
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 5 deletions.
26 changes: 21 additions & 5 deletions src/covariance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,24 @@ use serde_derive::{Deserialize, Serialize};

use crate::Merge;

/// Estimate the arithmetic mean and the covariance of a sequence of number pairs
/// Estimate the arithmetic means and the covariance of a sequence of number pairs
/// ("population").
///
/// Because the variances are calculated as well, this can be used to calculate the Pearson
/// correlation coefficient.
///
///
/// ## Example
///
/// ```
/// use average::Covariance;
///
/// let a: Covariance = [(1., 5.), (2., 4.), (3., 3.), (4., 2.), (5., 1.)].iter().cloned().collect();
/// assert_eq!(a.mean_x(), 3.);
/// assert_eq!(a.mean_y(), 3.);
/// assert_eq!(a.population_covariance(), -2.5);
/// assert_eq!(a.sample_covariance(), -2.0);
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Covariance {
Expand Down Expand Up @@ -40,13 +56,13 @@ impl Covariance {
let delta_x = x - self.avg_x;
let delta_y = y - self.avg_y;

self.avg_x += delta_x / self.n.to_f64().unwrap();
self.avg_x += delta_x / n;
self.sum_x_2 += delta_x * delta_x * n * (n - 1.);

self.avg_y += delta_y / self.n.to_f64().unwrap();
self.avg_y += delta_y / n;
self.sum_y_2 += delta_y * delta_y * n * (n - 1.);

self.sum_prod += delta_x * delta_y;
self.sum_prod += delta_x * (y - self.avg_y);
}

/// Calculate the population covariance of the sample.
Expand All @@ -56,7 +72,7 @@ impl Covariance {
/// Returns NaN for an empty sample.
#[inline]
pub fn population_covariance(&self) -> f64 {
if self.n < 2 {
if self.n < 1 {
return f64::NAN;
}
self.sum_prod / self.n.to_f64().unwrap()
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
#![forbid(missing_docs)]
#![forbid(missing_debug_implementations)]
#![cfg_attr(feature = "nightly", feature(generic_const_exprs))]
#[cfg(feature = "std")] extern crate std;

#[macro_use]
mod macros;
Expand Down
43 changes: 43 additions & 0 deletions tests/integration/covariance.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use average::Covariance;

#[test]
fn simple() {
let mut cov = Covariance::new();
assert!(cov.mean_x().is_nan());
assert!(cov.mean_y().is_nan());
assert!(cov.population_covariance().is_nan());
assert!(cov.sample_covariance().is_nan());
assert!(cov.population_pearson().is_nan());
assert!(cov.sample_pearson().is_nan());

cov.add(1., 5.);
assert_eq!(cov.mean_x(), 1.);
assert_eq!(cov.mean_y(), 5.);
assert_eq!(cov.population_covariance(), 0.);
assert!(cov.sample_covariance().is_nan());
// TODO: pearson

cov.add(2., 4.);
assert_eq!(cov.mean_x(), 1.5);
assert_eq!(cov.mean_y(), 4.5);
assert_eq!(cov.population_covariance(), -0.25);
assert_eq!(cov.sample_covariance(), -0.5);

cov.add(3., 3.);
assert_eq!(cov.mean_x(), 2.);
assert_eq!(cov.mean_y(), 4.);
assert_eq!(cov.population_covariance(), -2./3.);
assert_eq!(cov.sample_covariance(), -1.);

cov.add(4., 2.);
assert_eq!(cov.mean_x(), 2.5);
assert_eq!(cov.mean_y(), 3.5);
assert_eq!(cov.population_covariance(), -1.25);
assert_eq!(cov.sample_covariance(), -5./3.);

cov.add(5., 1.);
assert_eq!(cov.mean_x(), 3.);
assert_eq!(cov.mean_y(), 3.);
assert_eq!(cov.population_covariance(), -2.0);
assert_eq!(cov.sample_covariance(), -2.5);
}
1 change: 1 addition & 0 deletions tests/integration/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ mod skewness;
#[cfg(feature = "std")]
mod streaming_stats;
mod weighted_mean;
mod covariance;

0 comments on commit 339fa05

Please sign in to comment.