summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStefan Kreutz <mail@skreutz.com>2022-05-28 13:43:53 +0200
committerStefan Kreutz <mail@skreutz.com>2022-05-28 13:43:53 +0200
commit7b9bff64d6c13cc8ea592c2ef6fec96d7062f0c5 (patch)
tree6122359664bb5d59c29ced39ab58018265de1be0
parent9994ec96fa65f8c9f177fef5522a20d9a4b41547 (diff)
downloadwpa-psk-7b9bff64d6c13cc8ea592c2ef6fec96d7062f0c5.tar
Return typed validation errors
-rw-r--r--src/lib.rs65
1 files changed, 54 insertions, 11 deletions
diff --git a/src/lib.rs b/src/lib.rs
index f50dea0..e99418c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -25,7 +25,7 @@
//! assert_eq!(bytes_to_hex(&psk), "cb5de4e4d23b2ab0bf5b9ba0fe8132c1e2af3bb52298ec801af8ad520cea3437");
//! ```
-use std::fmt::Display;
+use std::{error::Error, fmt::Display};
use hmac::Hmac;
use pbkdf2::pbkdf2;
@@ -36,13 +36,13 @@ use sha1::Sha1;
pub struct Ssid<'a>(&'a [u8]);
impl<'a> TryFrom<&'a [u8]> for Ssid<'a> {
- type Error = &'static str;
+ type Error = ValidateSsidError;
fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
if value.is_empty() {
- Err("SSID must have at least one byte")
+ Err(ValidateSsidError::TooShort)
} else if value.len() > 32 {
- Err("SSID must have at most 32 bytes")
+ Err(ValidateSsidError::TooLong)
} else {
Ok(Ssid(value))
}
@@ -50,27 +50,45 @@ impl<'a> TryFrom<&'a [u8]> for Ssid<'a> {
}
impl<'a> TryFrom<&'a str> for Ssid<'a> {
- type Error = &'static str;
+ type Error = ValidateSsidError;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
Self::try_from(value.as_bytes())
}
}
+#[derive(Debug, PartialEq)]
+pub enum ValidateSsidError {
+ TooShort,
+ TooLong,
+}
+
+impl Error for ValidateSsidError {}
+
+impl Display for ValidateSsidError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let msg = match self {
+ ValidateSsidError::TooShort => "SSID must have at least one byte",
+ ValidateSsidError::TooLong => "SSID must have at most 32 bytes",
+ };
+ write!(f, "{msg}")
+ }
+}
+
/// A passphrase consisting of 8 up to 63 printable ASCII characters.
#[derive(Debug)]
pub struct Passphrase<'a>(&'a [u8]);
impl<'a> TryFrom<&'a [u8]> for Passphrase<'a> {
- type Error = &'static str;
+ type Error = ValidatePassphraseError;
fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
if value.len() < 8 {
- Err("passphrase must have at least 8 bytes")
+ Err(ValidatePassphraseError::TooShort)
} else if value.len() > 63 {
- Err("passphrase must have at most 63 bytes")
+ Err(ValidatePassphraseError::TooLong)
} else if value.iter().any(|i| !matches!(i, 32u8..=126)) {
- Err("passphrase must consist of printable ASCII characters")
+ Err(ValidatePassphraseError::InvalidByte)
} else {
Ok(Passphrase(value))
}
@@ -78,7 +96,7 @@ impl<'a> TryFrom<&'a [u8]> for Passphrase<'a> {
}
impl<'a> TryFrom<&'a str> for Passphrase<'a> {
- type Error = &'static str;
+ type Error = ValidatePassphraseError;
fn try_from(value: &'a str) -> Result<Self, Self::Error> {
Self::try_from(value.as_bytes())
@@ -91,6 +109,28 @@ impl Display for Passphrase<'_> {
}
}
+#[derive(Debug, PartialEq)]
+pub enum ValidatePassphraseError {
+ TooShort,
+ TooLong,
+ InvalidByte,
+}
+
+impl Error for ValidatePassphraseError {}
+
+impl Display for ValidatePassphraseError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ let msg = match self {
+ ValidatePassphraseError::TooShort => "passphrase must have at least 8 bytes",
+ ValidatePassphraseError::TooLong => "passphrase must have at most 63 bytes",
+ ValidatePassphraseError::InvalidByte => {
+ "passphrase must consist of printable ASCII characters"
+ }
+ };
+ write!(f, "{msg}")
+ }
+}
+
/// Returns the WPA-PSK of the given SSID and passphrase.
pub fn wpa_psk(ssid: &Ssid, passphrase: &Passphrase) -> [u8; 32] {
wpa_psk_unchecked(ssid.0, passphrase.0)
@@ -125,7 +165,10 @@ mod tests {
#[test]
fn passphrase_too_short() {
- Passphrase::try_from("foobar").unwrap_err();
+ assert_eq!(
+ Passphrase::try_from("foobar").unwrap_err(),
+ ValidatePassphraseError::TooShort
+ );
}
#[test]
Generated by cgit. See skreutz.com for my tech blog and contact information.