Skip to content

Commit

Permalink
tokens: Make check_token async
Browse files Browse the repository at this point in the history
This will allow us to check the database for revocations.
  • Loading branch information
jameswestman authored and barthalion committed Jul 19, 2023
1 parent d7e4bf8 commit 5ba76ec
Showing 1 changed file with 93 additions and 77 deletions.
170 changes: 93 additions & 77 deletions src/tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ use actix_web::error::Error;
use actix_web::http::header::{HeaderValue, AUTHORIZATION};
use actix_web::{HttpMessage, HttpRequest, Result};
use futures::future::{ok, Either, FutureResult};
use futures::{Future, Poll};
use futures::{Future, IntoFuture, Poll};
use futures3::TryFutureExt;
use jwt::{decode, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::fmt::Display;
use std::rc::Rc;

Expand Down Expand Up @@ -196,49 +198,47 @@ pub struct Inner {
optional: bool,
}

impl Inner {
fn parse_authorization(&self, header: &HeaderValue) -> Result<String, ApiError> {
// "Bearer *" length
if header.len() < 8 {
return Err(ApiError::InvalidToken(
"Header length too short".to_string(),
));
}
fn parse_authorization(header: &HeaderValue) -> Result<String, ApiError> {
// "Bearer *" length
if header.len() < 8 {
return Err(ApiError::InvalidToken(
"Header length too short".to_string(),
));
}

let mut parts = header
.to_str()
.map_err(|_| ApiError::InvalidToken("Cannot convert header to string".to_string()))?
.splitn(2, ' ');
match parts.next() {
Some(scheme) if scheme == "Bearer" => (),
_ => {
return Err(ApiError::InvalidToken(
"Token scheme is not Bearer".to_string(),
))
}
let mut parts = header
.to_str()
.map_err(|_| ApiError::InvalidToken("Cannot convert header to string".to_string()))?
.splitn(2, ' ');
match parts.next() {
Some(scheme) if scheme == "Bearer" => (),
_ => {
return Err(ApiError::InvalidToken(
"Token scheme is not Bearer".to_string(),
))
}
}

let token = parts
.next()
.ok_or_else(|| ApiError::InvalidToken("No token value in header".to_string()))?;
let token = parts
.next()
.ok_or_else(|| ApiError::InvalidToken("No token value in header".to_string()))?;

Ok(token.to_string())
}
Ok(token.to_string())
}

fn validate_claims(&self, token: String) -> Result<Claims, ApiError> {
let validation = Validation::default();
fn validate_claims(secret: Vec<u8>, token: String) -> Result<Claims, ApiError> {
let validation = Validation::default();

let token_data = match decode::<Claims>(
&token,
&DecodingKey::from_secret(self.secret.as_ref()),
&validation,
) {
Ok(c) => c,
Err(_err) => return Err(ApiError::InvalidToken("Invalid token claims".to_string())),
};
let token_data = match decode::<Claims>(
&token,
&DecodingKey::from_secret(secret.as_ref()),
&validation,
) {
Ok(c) => c,
Err(_err) => return Err(ApiError::InvalidToken("Invalid token claims".to_string())),
};

Ok(token_data.claims)
}
Ok(token_data.claims)
}

pub struct TokenParser(Rc<Inner>);
Expand Down Expand Up @@ -273,76 +273,92 @@ where

fn new_transform(&self, service: S) -> Self::Future {
ok(TokenParserMiddleware {
service,
service: Rc::new(RefCell::new(service)),
inner: self.0.clone(),
})
}
}

/// TokenParser middleware
pub struct TokenParserMiddleware<S> {
service: S,
service: Rc<RefCell<S>>,
inner: Rc<Inner>,
}

impl<S> TokenParserMiddleware<S> {
fn check_token(&self, req: &ServiceRequest) -> Result<Option<Claims>, ApiError> {
let header = match req.headers().get(AUTHORIZATION) {
Some(h) => h,
None => {
if self.inner.optional {
return Ok(None);
}
return Err(ApiError::InvalidToken(
"No Authorization header".to_string(),
));
fn get_token(optional: bool, req: &ServiceRequest) -> Result<Option<String>, ApiError> {
let header = match req.headers().get(AUTHORIZATION) {
Some(h) => h,
None => {
if optional {
return Ok(None);
}
};
let token = self.inner.parse_authorization(header)?;
let claims = self.inner.validate_claims(token)?;
Ok(Some(claims))
}
return Err(ApiError::InvalidToken(
"No Authorization header".to_string(),
));
}
};
let token = parse_authorization(header)?;
Ok(Some(token))
}

async fn check_token_async(secret: Vec<u8>, token: String) -> Result<Claims, ApiError> {
let claims = validate_claims(secret, token)?;
Ok(claims)
}

fn check_token(
secret: Vec<u8>,
token: String,
) -> impl futures::Future<Item = Claims, Error = ApiError> {
Box::pin(check_token_async(secret, token)).compat()
}

impl<S, B> Service for TokenParserMiddleware<S>
where
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S: Service<Request = ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
S::Future: 'static,
B: 'static,
{
type Request = ServiceRequest;
type Response = ServiceResponse<B>;
type Error = Error;
#[allow(clippy::type_complexity)]
type Future = Either<
//S::Future,
Box<dyn Future<Item = Self::Response, Error = Self::Error>>,
FutureResult<Self::Response, Self::Error>,
>;
type Future = Box<dyn Future<Item = Self::Response, Error = Self::Error>>;

fn poll_ready(&mut self) -> Poll<(), Self::Error> {
self.service.poll_ready()
self.service.borrow_mut().poll_ready()
}

fn call(&mut self, req: ServiceRequest) -> Self::Future {
let maybe_claims = match self.check_token(&req) {
Err(e) => return Either::B(ok(req.error_response(e))),
Ok(c) => c,
};
let srv = self.service.clone();
let secret = self.inner.secret.clone();

let c = maybe_claims.clone();
let token = get_token(self.inner.optional, &req)
.into_future()
.and_then(|token| token.map(|t| check_token(secret, t)));

if let Some(claims) = maybe_claims {
req.extensions_mut().insert(claims);
}
let fut = token.then(move |maybe_claims| {
let maybe_claims = match maybe_claims {
Err(e) => return Either::B(ok(req.error_response(e))),
Ok(c) => c,
};

Either::A(Box::new(self.service.call(req).and_then(move |resp| {
if resp.status() == 401 || resp.status() == 403 {
if let Some(ref claims) = c {
log::info!("Presented claims: {:?}", claims);
}
let c = maybe_claims.clone();

if let Some(claims) = maybe_claims {
req.extensions_mut().insert(claims);
}
Ok(resp)
})))

Either::A(Box::new(srv.borrow_mut().call(req).and_then(move |resp| {
if resp.status() == 401 || resp.status() == 403 {
if let Some(ref claims) = c {
log::info!("Presented claims: {:?}", claims);
}
}
Ok(resp)
})))
});

Box::new(fut)
}
}

0 comments on commit 5ba76ec

Please sign in to comment.