Skip to content

Commit

Permalink
Merge pull request #249 from jamesmcm/port_forwarding_refactoring
Browse files Browse the repository at this point in the history
Refactor shared code in port forwarding into traits
  • Loading branch information
jamesmcm committed Feb 29, 2024
2 parents 100b9a7 + 5c4cbdd commit 61a5658
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 92 deletions.
15 changes: 7 additions & 8 deletions src/exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ use vopono_core::config::providers::{UiClient, VpnProvider};
use vopono_core::config::vpn::{verify_auth, Protocol};
use vopono_core::network::application_wrapper::ApplicationWrapper;
use vopono_core::network::firewall::Firewall;
use vopono_core::network::natpmpc::Natpmpc;
use vopono_core::network::netns::NetworkNamespace;
use vopono_core::network::network_interface::{get_active_interfaces, NetworkInterface};
use vopono_core::network::piapf::Piapf;
use vopono_core::network::port_forwarding::natpmpc::Natpmpc;
use vopono_core::network::port_forwarding::piapf::Piapf;
use vopono_core::network::port_forwarding::Forwarder;
use vopono_core::network::shadowsocks::uses_shadowsocks;
use vopono_core::network::sysctl::SysCtl;
use vopono_core::network::Forwarder;
use vopono_core::util::vopono_dir;
use vopono_core::util::{get_config_file_protocol, get_config_from_alias};
use vopono_core::util::{get_existing_namespaces, get_target_subnet};
Expand Down Expand Up @@ -154,7 +154,6 @@ pub fn exec(command: ExecCommand, uiclient: &dyn UiClient) -> anyhow::Result<()>
};

// Custom port forwarding (implementation to use for --custom-config)
// TODO: Allow fully custom handling separate callback script?
let custom_port_forwarding: Option<VpnProvider> = command
.custom_port_forwarding
.map(|x| x.to_variant())
Expand All @@ -165,7 +164,7 @@ pub fn exec(command: ExecCommand, uiclient: &dyn UiClient) -> anyhow::Result<()>
.ok()
});
if custom_port_forwarding.is_some() && custom_config.is_none() {
warn!("Custom port forwarding implementation is set, but not using custom provider config file. custom-port-forwarding setting will be ignored");
error!("Custom port forwarding implementation is set, but not using custom provider config file. custom-port-forwarding setting will be ignored");
}

// Create netns only
Expand Down Expand Up @@ -622,17 +621,17 @@ pub fn exec(command: ExecCommand, uiclient: &dyn UiClient) -> anyhow::Result<()>
Some(VpnProvider::ProtonVPN) => {
vopono_core::util::open_hosts(
&ns,
vec![vopono_core::network::natpmpc::PROTONVPN_GATEWAY],
vec![vopono_core::network::port_forwarding::natpmpc::PROTONVPN_GATEWAY],
firewall,
)?;
Some(Box::new(Natpmpc::new(&ns, callback.as_ref())?))
}
Some(p) => {
warn!("Port forwarding not supported for the selected provider: {} - ignoring --port-forwarding", p);
error!("Port forwarding not supported for the selected provider: {} - ignoring --port-forwarding", p);
None
}
None => {
warn!("--port-forwarding set but --custom-port-forwarding provider not provided for --custom-config usage. Ignoring --port-forwarding");
error!("--port-forwarding set but --custom-port-forwarding provider not provided for --custom-config usage. Ignoring --port-forwarding");
None
}
}
Expand Down
2 changes: 1 addition & 1 deletion vopono_core/src/network/application_wrapper.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::path::PathBuf;

use super::{netns::NetworkNamespace, Forwarder};
use super::{netns::NetworkNamespace, port_forwarding::Forwarder};
use crate::util::get_all_running_process_names;
use log::warn;

Expand Down
7 changes: 1 addition & 6 deletions vopono_core/src/network/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,14 @@ pub mod application_wrapper;
pub mod dns_config;
pub mod firewall;
pub mod host_masquerade;
pub mod natpmpc;
pub mod netns;
pub mod network_interface;
pub mod openconnect;
pub mod openfortivpn;
pub mod openvpn;
pub mod piapf;
pub mod port_forwarding;
pub mod shadowsocks;
pub mod sysctl;
pub mod veth_pair;
pub mod warp;
pub mod wireguard;

pub trait Forwarder {
fn forwarded_port(&self) -> u16;
}
73 changes: 73 additions & 0 deletions vopono_core/src/network/port_forwarding/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use std::sync::mpsc::Receiver;

use super::netns::NetworkNamespace;

pub mod natpmpc;
pub mod piapf;

pub trait Forwarder {
fn forwarded_port(&self) -> u16;
}

/// ThreadParams must implement these methods
pub trait ThreadParameters {
fn get_callback_command(&self) -> Option<String>;
fn get_loop_delay(&self) -> u64;
fn get_netns_name(&self) -> String;
}

pub trait ThreadLoopForwarder: Forwarder {
/// Implementation defines parameter struct passed to loop on thread
type ThreadParams: ThreadParameters;

/// Implementation defines how to refresh port
fn refresh_port(params: &Self::ThreadParams) -> anyhow::Result<u16>;

/// Provided common implementation for thread loop
fn thread_loop(params: Self::ThreadParams, recv: Receiver<bool>) {
loop {
let resp = recv.recv_timeout(std::time::Duration::from_secs(params.get_loop_delay()));
if resp.is_ok() {
log::debug!("Thread exiting...");
return;
} else {
let port = Self::refresh_port(&params);
match port {
Err(e) => {
log::error!("Thread failed to refresh port: {e:?}");
return;
}
Ok(p) => {
log::debug!("Thread refreshed port: {p}");
Self::callback_command(&params, p);
}
}
}
}
}

fn callback_command(params: &Self::ThreadParams, port: u16) -> Option<anyhow::Result<String>> {
params.get_callback_command().map(|callback_command|
{
let refresh_response = NetworkNamespace::exec_with_output(
&params.get_netns_name(),
&[&callback_command, &port.to_string()],
)?;
if !refresh_response.status.success() {
log::error!(
"Port forwarding callback script was unsuccessful!: stdout: {:?}, stderr: {:?}, exit code: {}",
String::from_utf8(refresh_response.stdout),
String::from_utf8(refresh_response.stderr),
refresh_response.status
);
Err(anyhow::anyhow!("Port forwarding callback script failed"))
} else if let Ok(out) = String::from_utf8(refresh_response.stdout) {
println!("{}", out);
Ok(out)
} else {
Ok("Callback script succeeded but stdout was not valid UTF8".to_string())
}
}
)
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use anyhow::Context;
use regex::Regex;
use std::sync::mpsc::{self, Receiver};
use std::sync::mpsc;
use std::{
net::{IpAddr, Ipv4Addr},
sync::mpsc::Sender,
thread::JoinHandle,
};

use super::netns::NetworkNamespace;
use super::Forwarder;
use super::{Forwarder, ThreadLoopForwarder, ThreadParameters};
use crate::network::netns::NetworkNamespace;

// TODO: Move this to ProtonVPN provider
pub const PROTONVPN_GATEWAY: IpAddr = IpAddr::V4(Ipv4Addr::new(10, 2, 0, 1));
Expand All @@ -20,11 +20,23 @@ pub struct Natpmpc {
send_channel: Sender<bool>,
}

struct ThreadParams {
pub struct ThreadParamsImpl {
pub netns_name: String,
pub callback: Option<String>,
}

impl ThreadParameters for ThreadParamsImpl {
fn get_callback_command(&self) -> Option<String> {
self.callback.clone()
}
fn get_loop_delay(&self) -> u64 {
45
}
fn get_netns_name(&self) -> String {
self.netns_name.clone()
}
}

impl Natpmpc {
pub fn new(ns: &NetworkNamespace, callback: Option<&String>) -> anyhow::Result<Self> {
let gateway_str = PROTONVPN_GATEWAY.to_string();
Expand All @@ -49,11 +61,13 @@ impl Natpmpc {
anyhow::bail!("natpmpc failed - likely that this server does not support port forwarding, please choose another server")
}

let params = ThreadParams {
let params = ThreadParamsImpl {
netns_name: ns.name.clone(),
callback: callback.cloned(),
};

let port = Self::refresh_port(&params)?;
Self::callback_command(&params, port);

let (send, recv) = mpsc::channel::<bool>();

Expand All @@ -66,9 +80,12 @@ impl Natpmpc {
send_channel: send,
})
}
}

// TODO: Refactor these two methods into Trait shared with piapf.rs
fn refresh_port(params: &ThreadParams) -> anyhow::Result<u16> {
impl ThreadLoopForwarder for Natpmpc {
type ThreadParams = ThreadParamsImpl;

fn refresh_port(params: &Self::ThreadParams) -> anyhow::Result<u16> {
let gateway_str = PROTONVPN_GATEWAY.to_string();
// TODO: Cache regex
let re = Regex::new(r"Mapped public port (?P<port>\d{1,5}) protocol").unwrap();
Expand Down Expand Up @@ -102,48 +119,8 @@ impl Natpmpc {
"natpmpc assigned UDP port: {udp_port} did not equal TCP port: {tcp_port}"
)
}

if let Some(cb) = &params.callback {
let refresh_response = NetworkNamespace::exec_with_output(
&params.netns_name,
&[cb, &udp_port.to_string()],
)?;
if !refresh_response.status.success() {
log::error!(
"Port forwarding callback script was unsuccessful!: stdout: {:?}, stderr: {:?}, exit code: {}",
String::from_utf8(refresh_response.stdout),
String::from_utf8(refresh_response.stderr),
refresh_response.status
);
} else if let Ok(out) = String::from_utf8(refresh_response.stdout) {
println!("{}", out);
}
}

Ok(udp_port)
}

// Spawn thread to repeat above every 45 seconds
fn thread_loop(params: ThreadParams, recv: Receiver<bool>) {
loop {
let resp = recv.recv_timeout(std::time::Duration::from_secs(45));
if resp.is_ok() {
log::debug!("Thread exiting...");
return;
} else {
let port = Self::refresh_port(&params);
match port {
Err(e) => {
log::error!("Thread failed to refresh port: {e:?}");
return;
}
Ok(p) => log::debug!("Thread refreshed port: {p}"),
}

// TODO: Communicate port change via channel?
}
}
}
}

impl Drop for Natpmpc {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use base64::prelude::*;
use regex::Regex;
use std::sync::mpsc::{self, Receiver};
use std::sync::mpsc::{self};
use std::{sync::mpsc::Sender, thread::JoinHandle};
use which::which;

use super::netns::NetworkNamespace;
use super::Forwarder;
use super::{Forwarder, ThreadLoopForwarder, ThreadParameters};
use crate::network::netns::NetworkNamespace;

use crate::config::providers::pia::PrivateInternetAccess;
use crate::config::providers::OpenVpnProvider;
Expand All @@ -18,7 +18,7 @@ pub struct Piapf {
send_channel: Sender<bool>,
}

struct ThreadParams {
pub struct ThreadParamsImpl {
pub port: u16,
pub netns_name: String,
pub signature: String,
Expand All @@ -29,6 +29,20 @@ struct ThreadParams {
pub callback: Option<String>,
}

impl ThreadParameters for ThreadParamsImpl {
fn get_callback_command(&self) -> Option<String> {
self.callback.clone()
}

fn get_loop_delay(&self) -> u64 {
60 * 15
}

fn get_netns_name(&self) -> String {
self.netns_name.clone()
}
}

impl Piapf {
pub fn new(
ns: &NetworkNamespace,
Expand Down Expand Up @@ -147,7 +161,7 @@ impl Piapf {
.as_u16()
.expect("getSignature response missing port");

let params = ThreadParams {
let params = ThreadParamsImpl {
netns_name: ns.name.clone(),
hostname: vpn_hostname,
gateway: vpn_gateway,
Expand All @@ -157,7 +171,8 @@ impl Piapf {
port,
callback: callback.cloned(),
};
Self::refresh_port(&params)?;
let port = Self::refresh_port(&params)?;
Self::callback_command(&params, port);
let (send, recv) = mpsc::channel::<bool>();
let handle = std::thread::spawn(move || Self::thread_loop(params, recv));

Expand All @@ -168,9 +183,12 @@ impl Piapf {
send_channel: send,
})
}
}

impl ThreadLoopForwarder for Piapf {
type ThreadParams = ThreadParamsImpl;

// TODO: Refactor methods below into Trait
fn refresh_port(params: &ThreadParams) -> anyhow::Result<u16> {
fn refresh_port(params: &Self::ThreadParams) -> anyhow::Result<u16> {
let bind_response = NetworkNamespace::exec_with_output(
&params.netns_name,
&[
Expand Down Expand Up @@ -222,28 +240,6 @@ impl Piapf {

Ok(params.port)
}

// Spawn thread to repeat above every 15 minutes
fn thread_loop(params: ThreadParams, recv: Receiver<bool>) {
loop {
let resp = recv.recv_timeout(std::time::Duration::from_secs(60 * 15));
if resp.is_ok() {
log::debug!("Thread exiting...");
return;
} else {
let port = Self::refresh_port(&params);
match port {
Err(e) => {
log::error!("Thread failed to refresh port: {e:?}");
return;
}
Ok(p) => log::debug!("Thread refreshed port: {p}"),
}

// TODO: Communicate port change via channel?
}
}
}
}

impl Drop for Piapf {
Expand Down

0 comments on commit 61a5658

Please sign in to comment.