Test round trip

This commit is contained in:
asonix 2019-09-11 01:24:51 -05:00
parent aefb08e627
commit c73da59045
3 changed files with 226 additions and 52 deletions

View file

@ -5,6 +5,7 @@ use crate::{
SIGNATURE_FIELD, SIGNATURE_FIELD,
}; };
#[derive(Debug)]
pub struct Signed { pub struct Signed {
signature: String, signature: String,
sig_headers: Vec<String>, sig_headers: Vec<String>,
@ -13,6 +14,7 @@ pub struct Signed {
key_id: String, key_id: String,
} }
#[derive(Debug)]
pub struct Unsigned { pub struct Unsigned {
pub(crate) signing_string: String, pub(crate) signing_string: String,
pub(crate) sig_headers: Vec<String>, pub(crate) sig_headers: Vec<String>,

View file

@ -1,16 +1,16 @@
use chrono::{DateTime, Duration, Utc}; use chrono::{DateTime, Duration, Utc};
use std::collections::BTreeMap; use std::{collections::BTreeMap, error::Error, fmt};
pub mod create; pub mod create;
pub mod verify; pub mod verify;
use self::{ use self::{
create::Unsigned, create::Unsigned,
verify::{Unvalidated, ValidateError}, verify::{ParseSignatureError, ParsedHeader, Unverified, ValidateError},
}; };
const REQUEST_TARGET: &'static str = "(request-target)"; const REQUEST_TARGET: &'static str = "(request-target)";
const CREATED: &'static str = "(crated)"; const CREATED: &'static str = "(created)";
const EXPIRES: &'static str = "(expires)"; const EXPIRES: &'static str = "(expires)";
const KEY_ID_FIELD: &'static str = "keyId"; const KEY_ID_FIELD: &'static str = "keyId";
@ -21,22 +21,32 @@ const EXPIRES_FIELD: &'static str = "expires";
const HEADERS_FIELD: &'static str = "headers"; const HEADERS_FIELD: &'static str = "headers";
const SIGNATURE_FIELD: &'static str = "signature"; const SIGNATURE_FIELD: &'static str = "signature";
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct Config { pub struct Config {
pub expires: Duration, pub expires_after: Duration,
}
#[derive(Debug)]
pub enum VerifyError {
Validate(ValidateError),
Parse(ParseSignatureError),
} }
impl Config { impl Config {
pub fn normalize( pub fn begin_sign(
&self, &self,
method: &str, method: &str,
path_and_query: &str, path_and_query: &str,
headers: &mut BTreeMap<String, String>, headers: BTreeMap<String, String>,
) -> Unsigned { ) -> Unsigned {
let sig_headers = build_headers_list(headers); let mut headers = headers
.into_iter()
.map(|(k, v)| (k.to_lowercase(), v))
.collect();
let sig_headers = build_headers_list(&headers);
let created = Utc::now(); let created = Utc::now();
let expires = created + self.expires; let expires = created + self.expires_after;
let signing_string = build_signing_string( let signing_string = build_signing_string(
method, method,
@ -44,7 +54,7 @@ impl Config {
Some(created), Some(created),
Some(expires), Some(expires),
&sig_headers, &sig_headers,
headers, &mut headers,
); );
Unsigned { Unsigned {
@ -55,24 +65,26 @@ impl Config {
} }
} }
pub fn validate<F, T>(&self, unvalidated: Unvalidated, f: F) -> Result<T, ValidateError> pub fn begin_verify(
where &self,
F: FnOnce(&[u8], &str) -> T, method: &str,
{ path_and_query: &str,
if let Some(expires) = unvalidated.expires { headers: BTreeMap<String, String>,
if expires < unvalidated.parsed_at { ) -> Result<Unverified, VerifyError> {
return Err(ValidateError::Expired); let mut headers: BTreeMap<String, String> = headers
} .into_iter()
} .map(|(k, v)| (k.to_lowercase().to_owned(), v))
if let Some(created) = unvalidated.created { .collect();
if created + self.expires < unvalidated.parsed_at {
return Err(ValidateError::Expired);
}
}
let v = base64::decode(&unvalidated.signature).map_err(|_| ValidateError::Decode)?; let header = headers
.remove("authorization")
.or_else(|| headers.remove("signature"))
.ok_or(ValidateError::Missing)?;
Ok((f)(&v, &unvalidated.signing_string)) let parsed_header: ParsedHeader = header.parse()?;
let unvalidated = parsed_header.into_unvalidated(method, path_and_query, &mut headers);
Ok(unvalidated.validate(self.expires_after)?)
} }
} }
@ -117,18 +129,86 @@ fn build_signing_string(
signing_string signing_string
} }
impl fmt::Display for VerifyError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
VerifyError::Validate(ref e) => fmt::Display::fmt(e, f),
VerifyError::Parse(ref e) => fmt::Display::fmt(e, f),
}
}
}
impl Error for VerifyError {
fn description(&self) -> &str {
match *self {
VerifyError::Validate(ref e) => e.description(),
VerifyError::Parse(ref e) => e.description(),
}
}
fn source(&self) -> Option<&(dyn Error + 'static)> {
match *self {
VerifyError::Validate(ref e) => Some(e),
VerifyError::Parse(ref e) => Some(e),
}
}
}
impl From<ValidateError> for VerifyError {
fn from(v: ValidateError) -> Self {
VerifyError::Validate(v)
}
}
impl From<ParseSignatureError> for VerifyError {
fn from(p: ParseSignatureError) -> Self {
VerifyError::Parse(p)
}
}
impl Default for Config { impl Default for Config {
fn default() -> Self { fn default() -> Self {
Config { Config {
expires: Duration::seconds(10), expires_after: Duration::seconds(10),
} }
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::Config;
use std::collections::BTreeMap;
fn prepare_headers() -> BTreeMap<String, String> {
let mut headers = BTreeMap::new();
headers.insert(
"Content-Type".to_owned(),
"application/activity+json".to_owned(),
);
headers
}
#[test] #[test]
fn it_works() { fn round_trip() {
assert_eq!(2 + 2, 4); let headers = prepare_headers();
let config = Config::default();
let authorization_header = config
.begin_sign("GET", "/foo?bar=baz", headers)
.sign("hi".to_owned(), |s| {
Ok(s.as_bytes().to_vec()) as Result<_, std::io::Error>
})
.unwrap()
.authorization_header();
let mut headers = prepare_headers();
headers.insert("Authorization".to_owned(), authorization_header);
let verified = config
.begin_verify("GET", "/foo?bar=baz", headers)
.unwrap()
.verify(|bytes, string| string.as_bytes() == bytes);
assert!(verified);
} }
} }

View file

@ -1,4 +1,4 @@
use chrono::{DateTime, TimeZone, Utc}; use chrono::{DateTime, Duration, TimeZone, Utc};
use std::{ use std::{
collections::{BTreeMap, HashMap}, collections::{BTreeMap, HashMap},
error::Error, error::Error,
@ -11,6 +11,15 @@ use crate::{
KEY_ID_FIELD, SIGNATURE_FIELD, KEY_ID_FIELD, SIGNATURE_FIELD,
}; };
#[derive(Debug)]
pub struct Unverified {
key_id: String,
signature: Vec<u8>,
algorithm: Option<Algorithm>,
signing_string: String,
}
#[derive(Debug)]
pub struct Unvalidated { pub struct Unvalidated {
pub(crate) key_id: String, pub(crate) key_id: String,
pub(crate) signature: String, pub(crate) signature: String,
@ -21,6 +30,7 @@ pub struct Unvalidated {
pub(crate) signing_string: String, pub(crate) signing_string: String,
} }
#[derive(Debug)]
pub struct ParsedHeader { pub struct ParsedHeader {
signature: String, signature: String,
key_id: String, key_id: String,
@ -56,6 +66,7 @@ pub enum Algorithm {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub enum ValidateError { pub enum ValidateError {
Missing,
Expired, Expired,
Decode, Decode,
} }
@ -63,7 +74,7 @@ pub enum ValidateError {
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct ParseSignatureError(&'static str); pub struct ParseSignatureError(&'static str);
impl Unvalidated { impl Unverified {
pub fn key_id(&self) -> &str { pub fn key_id(&self) -> &str {
&self.key_id &self.key_id
} }
@ -71,10 +82,41 @@ impl Unvalidated {
pub fn algorithm(&self) -> Option<&Algorithm> { pub fn algorithm(&self) -> Option<&Algorithm> {
self.algorithm.as_ref() self.algorithm.as_ref()
} }
pub fn verify<F, T>(&self, f: F) -> T
where
F: FnOnce(&[u8], &str) -> T,
{
(f)(&self.signature, &self.signing_string)
}
}
impl Unvalidated {
pub fn validate(self, expires_after: Duration) -> Result<Unverified, ValidateError> {
if let Some(expires) = self.expires {
if expires < self.parsed_at {
return Err(ValidateError::Expired);
}
}
if let Some(created) = self.created {
if created + expires_after < self.parsed_at {
return Err(ValidateError::Expired);
}
}
let signature = base64::decode(&self.signature).map_err(|_| ValidateError::Decode)?;
Ok(Unverified {
key_id: self.key_id,
algorithm: self.algorithm,
signing_string: self.signing_string,
signature,
})
}
} }
impl ParsedHeader { impl ParsedHeader {
pub fn to_unvalidated( pub fn into_unvalidated(
self, self,
method: &str, method: &str,
path_and_query: &str, path_and_query: &str,
@ -113,7 +155,7 @@ impl FromStr for ParsedHeader {
if let Some(key) = i.next() { if let Some(key) = i.next() {
if let Some(value) = i.next() { if let Some(value) = i.next() {
return Some((key.to_owned(), value.to_owned())); return Some((key.to_owned(), value.trim_matches('"').to_owned()));
} }
} }
None None
@ -131,7 +173,7 @@ impl FromStr for ParsedHeader {
.remove(HEADERS_FIELD) .remove(HEADERS_FIELD)
.map(|h| h.split_whitespace().map(|s| s.to_owned()).collect()) .map(|h| h.split_whitespace().map(|s| s.to_owned()).collect())
.unwrap_or_else(|| vec![CREATED.to_owned()]), .unwrap_or_else(|| vec![CREATED.to_owned()]),
algorithm: hm.remove(ALGORITHM_FIELD).map(Algorithm::from), algorithm: hm.remove(ALGORITHM_FIELD).map(|s| Algorithm::from(s)),
created: parse_time(&mut hm, CREATED_FIELD)?, created: parse_time(&mut hm, CREATED_FIELD)?,
expires: parse_time(&mut hm, EXPIRES_FIELD)?, expires: parse_time(&mut hm, EXPIRES_FIELD)?,
parsed_at: Utc::now(), parsed_at: Utc::now(),
@ -188,32 +230,82 @@ impl From<&str> for Algorithm {
} }
} }
impl fmt::Display for ValidateError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
ValidateError::Expired => write!(f, "Http Signature is expired"),
ValidateError::Decode => write!(f, "Http Signature could not be decoded"),
}
}
}
impl fmt::Display for ParseSignatureError { impl fmt::Display for ParseSignatureError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "Error when parsing {} from Http Signature", self.0) write!(f, "Error when parsing {} from Http Signature", self.0)
} }
} }
impl Error for ValidateError {
fn description(&self) -> &'static str {
match *self {
ValidateError::Expired => "Http Signature is expired",
ValidateError::Decode => "Http Signature could not be decoded",
}
}
}
impl Error for ParseSignatureError { impl Error for ParseSignatureError {
fn description(&self) -> &'static str { fn description(&self) -> &'static str {
"There was an error parsing the Http Signature" "There was an error parsing the Http Signature"
} }
} }
impl fmt::Display for ValidateError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
ValidateError::Missing => write!(f, "Http Signature is missing"),
ValidateError::Expired => write!(f, "Http Signature is expired"),
ValidateError::Decode => write!(f, "Http Signature could not be decoded"),
}
}
}
impl Error for ValidateError {
fn description(&self) -> &'static str {
match *self {
ValidateError::Missing => "Http Signature is missing",
ValidateError::Expired => "Http Signature is expired",
ValidateError::Decode => "Http Signature could not be decoded",
}
}
}
#[cfg(test)]
mod tests {
use chrono::Utc;
use super::ParsedHeader;
#[test]
fn parses_header_succesfully_1() {
let time1 = Utc::now().timestamp();
let time2 = Utc::now().timestamp();
let h = format!(r#"Signature keyId="my-key-id",algorithm="hs2019",created="{}",expires="{}",headers="(request-target) (created) (expires) date content-type",signature="blah blah blah""#, time1, time2);
parse_signature(&h)
}
#[test]
fn parses_header_succesfully_2() {
let time1 = Utc::now().timestamp();
let time2 = Utc::now().timestamp();
let h = format!(r#"Signature keyId="my-key-id",algorithm="rsa-sha256",created="{}",expires="{}",signature="blah blah blah""#, time1, time2);
parse_signature(&h)
}
#[test]
fn parses_header_succesfully_3() {
let time1 = Utc::now().timestamp();
let h = format!(r#"Signature keyId="my-key-id",algorithm="rsa-sha256",created="{}",headers="(request-target) (created) date content-type",signature="blah blah blah""#, time1);
parse_signature(&h)
}
#[test]
fn parses_header_succesfully_4() {
let h = r#"Signature keyId="my-key-id",algorithm="rsa-sha256",headers="(request-target) date content-type",signature="blah blah blah""#;
parse_signature(h)
}
fn parse_signature(s: &str) {
let ph: ParsedHeader = s.parse().unwrap();
println!("{:?}", ph);
}
}