From 7b9bff64d6c13cc8ea592c2ef6fec96d7062f0c5 Mon Sep 17 00:00:00 2001 From: Stefan Kreutz Date: Sat, 28 May 2022 13:43:53 +0200 Subject: Return typed validation errors --- src/lib.rs | 65 +++++++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 54 insertions(+), 11 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index f50dea0..e99418c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,7 +25,7 @@ //! assert_eq!(bytes_to_hex(&psk), "cb5de4e4d23b2ab0bf5b9ba0fe8132c1e2af3bb52298ec801af8ad520cea3437"); //! ``` -use std::fmt::Display; +use std::{error::Error, fmt::Display}; use hmac::Hmac; use pbkdf2::pbkdf2; @@ -36,13 +36,13 @@ use sha1::Sha1; pub struct Ssid<'a>(&'a [u8]); impl<'a> TryFrom<&'a [u8]> for Ssid<'a> { - type Error = &'static str; + type Error = ValidateSsidError; fn try_from(value: &'a [u8]) -> Result { if value.is_empty() { - Err("SSID must have at least one byte") + Err(ValidateSsidError::TooShort) } else if value.len() > 32 { - Err("SSID must have at most 32 bytes") + Err(ValidateSsidError::TooLong) } else { Ok(Ssid(value)) } @@ -50,27 +50,45 @@ impl<'a> TryFrom<&'a [u8]> for Ssid<'a> { } impl<'a> TryFrom<&'a str> for Ssid<'a> { - type Error = &'static str; + type Error = ValidateSsidError; fn try_from(value: &'a str) -> Result { Self::try_from(value.as_bytes()) } } +#[derive(Debug, PartialEq)] +pub enum ValidateSsidError { + TooShort, + TooLong, +} + +impl Error for ValidateSsidError {} + +impl Display for ValidateSsidError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let msg = match self { + ValidateSsidError::TooShort => "SSID must have at least one byte", + ValidateSsidError::TooLong => "SSID must have at most 32 bytes", + }; + write!(f, "{msg}") + } +} + /// A passphrase consisting of 8 up to 63 printable ASCII characters. #[derive(Debug)] pub struct Passphrase<'a>(&'a [u8]); impl<'a> TryFrom<&'a [u8]> for Passphrase<'a> { - type Error = &'static str; + type Error = ValidatePassphraseError; fn try_from(value: &'a [u8]) -> Result { if value.len() < 8 { - Err("passphrase must have at least 8 bytes") + Err(ValidatePassphraseError::TooShort) } else if value.len() > 63 { - Err("passphrase must have at most 63 bytes") + Err(ValidatePassphraseError::TooLong) } else if value.iter().any(|i| !matches!(i, 32u8..=126)) { - Err("passphrase must consist of printable ASCII characters") + Err(ValidatePassphraseError::InvalidByte) } else { Ok(Passphrase(value)) } @@ -78,7 +96,7 @@ impl<'a> TryFrom<&'a [u8]> for Passphrase<'a> { } impl<'a> TryFrom<&'a str> for Passphrase<'a> { - type Error = &'static str; + type Error = ValidatePassphraseError; fn try_from(value: &'a str) -> Result { Self::try_from(value.as_bytes()) @@ -91,6 +109,28 @@ impl Display for Passphrase<'_> { } } +#[derive(Debug, PartialEq)] +pub enum ValidatePassphraseError { + TooShort, + TooLong, + InvalidByte, +} + +impl Error for ValidatePassphraseError {} + +impl Display for ValidatePassphraseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let msg = match self { + ValidatePassphraseError::TooShort => "passphrase must have at least 8 bytes", + ValidatePassphraseError::TooLong => "passphrase must have at most 63 bytes", + ValidatePassphraseError::InvalidByte => { + "passphrase must consist of printable ASCII characters" + } + }; + write!(f, "{msg}") + } +} + /// Returns the WPA-PSK of the given SSID and passphrase. pub fn wpa_psk(ssid: &Ssid, passphrase: &Passphrase) -> [u8; 32] { wpa_psk_unchecked(ssid.0, passphrase.0) @@ -125,7 +165,10 @@ mod tests { #[test] fn passphrase_too_short() { - Passphrase::try_from("foobar").unwrap_err(); + assert_eq!( + Passphrase::try_from("foobar").unwrap_err(), + ValidatePassphraseError::TooShort + ); } #[test] -- cgit v1.2.3