//! Compute the WPA-PSK of a Wi-Fi SSID and passphrase. //! //! # Example //! //! Compute and print the WPA-PSK of a valid SSID and passphrase: //! //! ``` //! # use wpa_psk::{Ssid, Passphrase, wpa_psk, bytes_to_hex}; //! # fn main() -> Result<(), Box> { //! let ssid = Ssid::try_from("home")?; //! let passphrase = Passphrase::try_from("0123-4567-89")?; //! let psk = wpa_psk(&ssid, &passphrase); //! assert_eq!(bytes_to_hex(&psk), "150c047b6fad724512a17fa431687048ee503d14c1ea87681d4f241beb04f5ee"); //! # Ok(()) //! # } //! ``` //! //! Compute the WPA-PSK of possibly invalid raw bytes: //! //! ``` //! # use wpa_psk::{wpa_psk_unchecked, bytes_to_hex}; //! let ssid = "bar".as_bytes(); //! let passphrase = "2short".as_bytes(); //! let psk = wpa_psk_unchecked(&ssid, &passphrase); //! assert_eq!(bytes_to_hex(&psk), "cb5de4e4d23b2ab0bf5b9ba0fe8132c1e2af3bb52298ec801af8ad520cea3437"); //! ``` use std::{error::Error, fmt::Display}; use hmac::Hmac; use pbkdf2::pbkdf2; use sha1::Sha1; /// An SSID consisting of 1 up to 32 arbitrary bytes. #[derive(Debug)] pub struct Ssid<'a>(&'a [u8]); impl<'a> TryFrom<&'a [u8]> for Ssid<'a> { type Error = ValidateSsidError; fn try_from(value: &'a [u8]) -> Result { if value.is_empty() { Err(ValidateSsidError::TooShort) } else if value.len() > 32 { Err(ValidateSsidError::TooLong) } else { Ok(Ssid(value)) } } } impl<'a> TryFrom<&'a str> for Ssid<'a> { type Error = ValidateSsidError; fn try_from(value: &'a str) -> Result { Self::try_from(value.as_bytes()) } } #[derive(Debug, PartialEq)] pub enum ValidateSsidError { TooShort, TooLong, } impl Error for ValidateSsidError {} impl Display for ValidateSsidError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let msg = match self { ValidateSsidError::TooShort => "SSID must have at least one byte", ValidateSsidError::TooLong => "SSID must have at most 32 bytes", }; write!(f, "{msg}") } } /// A passphrase consisting of 8 up to 63 printable ASCII characters. 
///
/// Construct one with [`TryFrom`] from a byte slice or a string slice; every
/// byte must lie in the printable ASCII range `0x20..=0x7E` (space to `~`).
#[derive(Debug, Clone, Copy)]
pub struct Passphrase<'a>(&'a [u8]);

impl<'a> TryFrom<&'a [u8]> for Passphrase<'a> {
    type Error = ValidatePassphraseError;

    /// Accepts `value` if it has 8 to 63 bytes, all of them printable ASCII.
    ///
    /// # Errors
    ///
    /// Length is checked before content: [`ValidatePassphraseError::TooShort`]
    /// below 8 bytes, [`ValidatePassphraseError::TooLong`] above 63 bytes, and
    /// [`ValidatePassphraseError::InvalidByte`] if any byte is outside
    /// `0x20..=0x7E`.
    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
        if value.len() < 8 {
            Err(ValidatePassphraseError::TooShort)
        } else if value.len() > 63 {
            Err(ValidatePassphraseError::TooLong)
        } else if !value.iter().all(|b| (b' '..=b'~').contains(b)) {
            Err(ValidatePassphraseError::InvalidByte)
        } else {
            Ok(Passphrase(value))
        }
    }
}

impl<'a> TryFrom<&'a str> for Passphrase<'a> {
    type Error = ValidatePassphraseError;

    /// Validates the UTF-8 bytes of `value`; see the `&[u8]` conversion.
    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
        Self::try_from(value.as_bytes())
    }
}

impl Display for Passphrase<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Construction only admits printable ASCII, which is always valid UTF-8.
        let text = std::str::from_utf8(self.0)
            .expect("passphrase bytes were validated as printable ASCII");
        f.write_str(text)
    }
}

/// Reason a byte sequence was rejected as a [`Passphrase`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ValidatePassphraseError {
    /// Fewer than 8 bytes.
    TooShort,
    /// More than 63 bytes.
    TooLong,
    /// At least one byte is not printable ASCII (`0x20..=0x7E`).
    InvalidByte,
}

impl Error for ValidatePassphraseError {}

impl Display for ValidatePassphraseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            ValidatePassphraseError::TooShort => "passphrase must have at least 8 bytes",
            ValidatePassphraseError::TooLong => "passphrase must have at most 63 bytes",
            ValidatePassphraseError::InvalidByte => {
                "passphrase must consist of printable ASCII characters"
            }
        })
    }
}

/// Returns the WPA-PSK of the given SSID and passphrase.
///
/// Both inputs were validated when they were constructed, so this simply
/// forwards their bytes to [`wpa_psk_unchecked`].
pub fn wpa_psk(ssid: &Ssid, passphrase: &Passphrase) -> [u8; 32] {
    wpa_psk_unchecked(ssid.0, passphrase.0)
}

/// Unchecked WPA-PSK.
///
/// Derives the 32-byte key with PBKDF2-HMAC-SHA1 over 4096 iterations, using
/// `passphrase` as the password and `ssid` as the salt. Unlike [`wpa_psk`], no
/// length or character checks are applied to either input.
///
/// See [`wpa_psk`].
pub fn wpa_psk_unchecked(ssid: &[u8], passphrase: &[u8]) -> [u8; 32] {
    let mut buf = [0u8; 32];
    pbkdf2::<Hmac<Sha1>>(passphrase, ssid, 4096, &mut buf);
    buf
}

/// Returns the hexadecimal representation of the given bytes.
///
/// Each byte becomes exactly two lowercase hex digits, so the result has
/// `2 * bytes.len()` characters; an empty slice yields an empty string.
pub fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    // One up-front allocation instead of a temporary `String` per byte.
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn special_characters() {
        let ssid = Ssid::try_from("123abcABC.,-").unwrap();
        let passphrase = Passphrase::try_from("456defDEF *<:D").unwrap();
        assert_eq!(
            bytes_to_hex(&wpa_psk(&ssid, &passphrase)),
            "8a366e5bc51cd5d8fbbeffacc5f1af23fac30e3ac93cdcc368fafbbf63a1085c"
        );
    }

    #[test]
    fn ieee_802_11i_reference_vector() {
        let ssid = Ssid::try_from("IEEE").unwrap();
        let passphrase = Passphrase::try_from("password").unwrap();
        assert_eq!(
            bytes_to_hex(&wpa_psk(&ssid, &passphrase)),
            "f42c6fc52df0ebef9ebb4b90b38a5f902e83fe1b135a70e23aed762e9710a12e"
        );
    }

    #[test]
    fn ssid_length_bounds() {
        assert_eq!(Ssid::try_from("").unwrap_err(), ValidateSsidError::TooShort);
        assert!(Ssid::try_from("a".repeat(32).as_str()).is_ok());
        assert_eq!(
            Ssid::try_from("a".repeat(33).as_str()).unwrap_err(),
            ValidateSsidError::TooLong
        );
    }

    #[test]
    fn passphrase_too_short() {
        assert_eq!(
            Passphrase::try_from("foobar").unwrap_err(),
            ValidatePassphraseError::TooShort
        );
    }

    #[test]
    fn passphrase_too_long_and_invalid_byte() {
        assert!(Passphrase::try_from("a".repeat(63).as_str()).is_ok());
        assert_eq!(
            Passphrase::try_from("a".repeat(64).as_str()).unwrap_err(),
            ValidatePassphraseError::TooLong
        );
        assert_eq!(
            Passphrase::try_from("tab\tinside").unwrap_err(),
            ValidatePassphraseError::InvalidByte
        );
    }

    #[test]
    fn display_passphrase() {
        assert_eq!(
            format!("{}", Passphrase::try_from("foobarbuzz").unwrap()),
            "foobarbuzz"
        );
    }

    #[test]
    fn hex_encoding() {
        assert_eq!(bytes_to_hex(b""), "");
        assert_eq!(bytes_to_hex(&[0x00, 0x0f, 0xa0, 0xff]), "000fa0ff");
    }
}