diff options
Diffstat (limited to 'wpa-psk/src')
| -rw-r--r-- | wpa-psk/src/lib.rs | 181 | 
1 files changed, 181 insertions, 0 deletions
diff --git a/wpa-psk/src/lib.rs b/wpa-psk/src/lib.rs new file mode 100644 index 0000000..e99418c --- /dev/null +++ b/wpa-psk/src/lib.rs @@ -0,0 +1,181 @@ +//! Compute the WPA-PSK of a Wi-Fi SSID and passphrase. +//! +//! # Example +//! +//! Compute and print the WPA-PSK of a valid SSID and passphrase: +//! +//! ``` +//! # use wpa_psk::{Ssid, Passphrase, wpa_psk, bytes_to_hex}; +//! # fn main() -> Result<(), Box<dyn std::error::Error>> { +//! let ssid = Ssid::try_from("home")?; +//! let passphrase = Passphrase::try_from("0123-4567-89")?; +//! let psk = wpa_psk(&ssid, &passphrase); +//! assert_eq!(bytes_to_hex(&psk), "150c047b6fad724512a17fa431687048ee503d14c1ea87681d4f241beb04f5ee"); +//! # Ok(()) +//! # } +//! ``` +//! +//! Compute the WPA-PSK of possibly invalid raw bytes: +//! +//! ``` +//! # use wpa_psk::{wpa_psk_unchecked, bytes_to_hex}; +//! let ssid = "bar".as_bytes(); +//! let passphrase = "2short".as_bytes(); +//! let psk = wpa_psk_unchecked(&ssid, &passphrase); +//! assert_eq!(bytes_to_hex(&psk), "cb5de4e4d23b2ab0bf5b9ba0fe8132c1e2af3bb52298ec801af8ad520cea3437"); +//! ``` + +use std::{error::Error, fmt::Display}; + +use hmac::Hmac; +use pbkdf2::pbkdf2; +use sha1::Sha1; + +/// An SSID consisting of 1 up to 32 arbitrary bytes. +#[derive(Debug)] +pub struct Ssid<'a>(&'a [u8]); + +impl<'a> TryFrom<&'a [u8]> for Ssid<'a> { +    type Error = ValidateSsidError; + +    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> { +        if value.is_empty() { +            Err(ValidateSsidError::TooShort) +        } else if value.len() > 32 { +            Err(ValidateSsidError::TooLong) +        } else { +            Ok(Ssid(value)) +        } +    } +} + +impl<'a> TryFrom<&'a str> for Ssid<'a> { +    type Error = ValidateSsidError; + +    fn try_from(value: &'a str) -> Result<Self, Self::Error> { +        Self::try_from(value.as_bytes()) +    } +} + +#[derive(Debug, PartialEq)] +pub enum ValidateSsidError { +    TooShort, +    TooLong, +} + +impl Error for ValidateSsidError {} + +impl Display for ValidateSsidError { +    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +        let msg = match self { +            ValidateSsidError::TooShort => "SSID must have at least one byte", +            ValidateSsidError::TooLong => "SSID must have at most 32 bytes", +        }; +        write!(f, "{msg}") +    } +} + +/// A passphrase consisting of 8 up to 63 printable ASCII characters. +#[derive(Debug)] +pub struct Passphrase<'a>(&'a [u8]); + +impl<'a> TryFrom<&'a [u8]> for Passphrase<'a> { +    type Error = ValidatePassphraseError; + +    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> { +        if value.len() < 8 { +            Err(ValidatePassphraseError::TooShort) +        } else if value.len() > 63 { +            Err(ValidatePassphraseError::TooLong) +        } else if value.iter().any(|i| !matches!(i, 32u8..=126)) { +            Err(ValidatePassphraseError::InvalidByte) +        } else { +            Ok(Passphrase(value)) +        } +    } +} + +impl<'a> TryFrom<&'a str> for Passphrase<'a> { +    type Error = ValidatePassphraseError; + +    fn try_from(value: &'a str) -> Result<Self, Self::Error> { +        Self::try_from(value.as_bytes()) +    } +} + +impl Display for Passphrase<'_> { +    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +        write!(f, "{}", std::str::from_utf8(self.0).unwrap()) +    } +} + +#[derive(Debug, PartialEq)] +pub enum ValidatePassphraseError { +    TooShort, +    TooLong, +    InvalidByte, +} + +impl Error for ValidatePassphraseError {} + +impl Display for ValidatePassphraseError { +    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +        let msg = match self { +            ValidatePassphraseError::TooShort => "passphrase must have at least 8 bytes", +            ValidatePassphraseError::TooLong => "passphrase must have at most 63 bytes", +            ValidatePassphraseError::InvalidByte => { +                "passphrase must consist of printable ASCII characters" +            } +        }; +        write!(f, "{msg}") +    } +} + +/// Returns the WPA-PSK of the given SSID and passphrase. +pub fn wpa_psk(ssid: &Ssid, passphrase: &Passphrase) -> [u8; 32] { +    wpa_psk_unchecked(ssid.0, passphrase.0) +} + +/// Unchecked WPA-PSK. +/// See [`wpa_psk`]. +pub fn wpa_psk_unchecked(ssid: &[u8], passphrase: &[u8]) -> [u8; 32] { +    let mut buf = [0u8; 32]; +    pbkdf2::<Hmac<Sha1>>(passphrase, ssid, 4096, &mut buf); +    buf +} + +/// Returns the hexdecimal representation of the given bytes. +pub fn bytes_to_hex(bytes: &[u8]) -> String { +    bytes.iter().map(|b| format!("{:02x}", b)).collect() +} + +#[cfg(test)] +mod tests { +    use super::*; + +    #[test] +    fn special_characters() { +        let ssid = Ssid::try_from("123abcABC.,-").unwrap(); +        let passphrase = Passphrase::try_from("456defDEF *<:D").unwrap(); +        assert_eq!( +            bytes_to_hex(&wpa_psk(&ssid, &passphrase)), +            "8a366e5bc51cd5d8fbbeffacc5f1af23fac30e3ac93cdcc368fafbbf63a1085c" +        ); +    } + +    #[test] +    fn passphrase_too_short() { +        assert_eq!( +            Passphrase::try_from("foobar").unwrap_err(), +            ValidatePassphraseError::TooShort +        ); +    } + +    #[test] +    fn display_passphrase() { +        assert_eq!( +            format!("{}", Passphrase::try_from("foobarbuzz").unwrap()), +            "foobarbuzz" +        ); +    } +}  |