Skip to content

Commit

Permalink
improve code
Browse files Browse the repository at this point in the history
  • Loading branch information
thetek42 committed Dec 20, 2023
1 parent 54ae7e0 commit 1256108
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 50 deletions.
6 changes: 2 additions & 4 deletions examples/wps_async.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
//! Example of using async wifi.
//!
//! Add your own ssid and password
//! Example of using async WPS.
use embedded_svc::wifi::Configuration;
use embedded_svc::wifi::{AuthMethod, ClientConfiguration, Configuration};

use esp_idf_svc::hal::prelude::Peripherals;
use esp_idf_svc::log::EspLogger;
Expand Down
15 changes: 9 additions & 6 deletions src/private/cstr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use core::str::Utf8Error;
use crate::sys::{EspError, ESP_ERR_INVALID_SIZE};

#[cfg(feature = "alloc")]
pub fn set_str(buf: &mut [u8], s: &str) -> Result<(), crate::sys::EspError> {
pub fn set_str(buf: &mut [u8], s: &str) -> Result<(), EspError> {
assert!(s.len() < buf.len());
let cs = to_cstring_arg(s)?;
let ss: &[u8] = cs.as_bytes_with_nul();
Expand All @@ -19,11 +19,14 @@ pub fn set_str(buf: &mut [u8], s: &str) -> Result<(), crate::sys::EspError> {
Ok(())
}

#[cfg(feature = "alloc")]
pub fn set_cchar_str(buf: &mut [c_char], s: &str) -> Result<(), crate::sys::EspError> {
let buf_u8 = unsafe { buf.as_mut_ptr() as *mut u8 };
let slice = unsafe { core::slice::from_raw_parts_mut(buf_u8, buf.len()) };
set_str(slice, s)
pub fn set_cchar_slice(buf: &mut [c_char], s: &str) -> Result<(), EspError> {
if s.len() > buf.len() {
return Err(EspError::from_infallible::<ESP_ERR_INVALID_SIZE>());
}
let s_cchar = unsafe { s.as_bytes().as_ptr() as *const c_char };
let s_slice = unsafe { core::slice::from_raw_parts(s_cchar, s.len()) };
buf[..s.len()].copy_from_slice(s_slice);
Ok(())
}

pub unsafe fn from_cstr_ptr<'a>(ptr: *const c_char) -> &'a str {
Expand Down
98 changes: 58 additions & 40 deletions src/wifi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,13 +413,19 @@ where
/// layer as well - should be preferred. Using `WifiDriver` directly is beneficial
/// only when one would like to utilize a custom, non-STD network stack like `smoltcp`.
pub struct WifiDriver<'d> {
status: Arc<mutex::Mutex<(WifiEvent, WifiEvent, Option<WpsEvent>)>>,
status: Arc<mutex::Mutex<WifiDriverStatus>>,
_subscription: EspSubscription<'static, System>,
#[cfg(all(feature = "alloc", esp_idf_comp_nvs_flash_enabled))]
_nvs: Option<EspDefaultNvsPartition>,
_p: PhantomData<&'d mut ()>,
}

struct WifiDriverStatus {
pub sta: WifiEvent,
pub ap: WifiEvent,
pub wps: Option<WpsEvent>,
}

impl<'d> WifiDriver<'d> {
#[cfg(all(feature = "alloc", esp_idf_comp_nvs_flash_enabled))]
pub fn new<M: WifiModemPeripheral>(
Expand Down Expand Up @@ -460,33 +466,37 @@ impl<'d> WifiDriver<'d> {
sysloop: &EspEventLoop<System>,
) -> Result<
(
Arc<mutex::Mutex<(WifiEvent, WifiEvent, Option<WpsEvent>)>>,
Arc<mutex::Mutex<WifiDriverStatus>>,
EspSubscription<'static, System>,
),
EspError,
> {
let status = Arc::new(mutex::Mutex::wrap(
mutex::RawMutex::new(),
(WifiEvent::StaStopped, WifiEvent::ApStopped, None),
WifiDriverStatus {
sta: WifiEvent::StaStopped,
ap: WifiEvent::ApStopped,
wps: None,
},
));
let s_status = status.clone();

let subscription = sysloop.subscribe(move |event: &WifiEvent| {
let mut guard = s_status.lock();

match event {
WifiEvent::ApStarted => guard.1 = WifiEvent::ApStarted,
WifiEvent::ApStopped => guard.1 = WifiEvent::ApStopped,
WifiEvent::StaStarted => guard.0 = WifiEvent::StaStarted,
WifiEvent::StaStopped => guard.0 = WifiEvent::StaStopped,
WifiEvent::StaConnected => guard.0 = WifiEvent::StaConnected,
WifiEvent::StaDisconnected => guard.0 = WifiEvent::StaDisconnected,
WifiEvent::ScanDone => guard.0 = WifiEvent::ScanDone,
WifiEvent::ApStarted => guard.ap = WifiEvent::ApStarted,
WifiEvent::ApStopped => guard.ap = WifiEvent::ApStopped,
WifiEvent::StaStarted => guard.sta = WifiEvent::StaStarted,
WifiEvent::StaStopped => guard.sta = WifiEvent::StaStopped,
WifiEvent::StaConnected => guard.sta = WifiEvent::StaConnected,
WifiEvent::StaDisconnected => guard.sta = WifiEvent::StaDisconnected,
WifiEvent::ScanDone => guard.sta = WifiEvent::ScanDone,
WifiEvent::StaWpsSuccess(_)
| WifiEvent::StaWpsFailed
| WifiEvent::StaWpsTimeout
| WifiEvent::StaWpsPin(_)
| WifiEvent::StaWpsPbcOverlap => guard.2 = Some(event.try_into().unwrap()),
| WifiEvent::StaWpsPbcOverlap => guard.wps = Some(event.try_into().unwrap()),
_ => (),
};
})?;
Expand Down Expand Up @@ -621,20 +631,23 @@ impl<'d> WifiDriver<'d> {
}

pub fn is_ap_started(&self) -> Result<bool, EspError> {
Ok(self.status.lock().1 == WifiEvent::ApStarted)
Ok(self.status.lock().ap == WifiEvent::ApStarted)
}

pub fn is_sta_started(&self) -> Result<bool, EspError> {
let guard = self.status.lock();

Ok(guard.0 == WifiEvent::StaStarted
|| guard.0 == WifiEvent::StaConnected
|| guard.0 == WifiEvent::ScanDone
|| guard.0 == WifiEvent::StaDisconnected)
Ok(matches!(
guard.sta,
WifiEvent::StaStarted
| WifiEvent::StaConnected
| WifiEvent::ScanDone
| WifiEvent::StaDisconnected
))
}

pub fn is_sta_connected(&self) -> Result<bool, EspError> {
Ok(self.status.lock().0 == WifiEvent::StaConnected)
Ok(self.status.lock().sta == WifiEvent::StaConnected)
}

pub fn is_started(&self) -> Result<bool, EspError> {
Expand All @@ -660,15 +673,15 @@ impl<'d> WifiDriver<'d> {
} else {
let guard = self.status.lock();

Ok((!ap_enabled || guard.1 == WifiEvent::ApStarted)
&& (!sta_enabled || guard.0 == WifiEvent::StaConnected))
Ok((!ap_enabled || guard.ap == WifiEvent::ApStarted)
&& (!sta_enabled || guard.sta == WifiEvent::StaConnected))
}
}

pub fn is_scan_done(&self) -> Result<bool, EspError> {
let guard = self.status.lock();

Ok(guard.0 == WifiEvent::ScanDone)
Ok(guard.sta == WifiEvent::ScanDone)
}

#[allow(non_upper_case_globals)]
Expand Down Expand Up @@ -1077,32 +1090,33 @@ impl<'d> WifiDriver<'d> {
pub fn start_wps(&mut self, config: &WpsConfig) -> Result<(), EspError> {
let config = Newtype::<esp_wps_config_t>::try_from(config)?;

self.set_configuration(&Configuration::None)?;
esp!(unsafe { esp_wifi_set_mode(wifi_mode_t_WIFI_MODE_STA) })?;

if !self.is_started()? {
self.start()?;
match self.get_configuration()? {
Configuration::None => esp!(unsafe { esp_wifi_set_mode(wifi_mode_t_WIFI_MODE_STA) })?,
Configuration::AccessPoint(_) => {
esp!(unsafe { esp_wifi_set_mode(wifi_mode_t_WIFI_MODE_APSTA) })?
}
_ => (),
}

esp!(unsafe { esp_wifi_wps_enable(&config.0 as *const _) })?;
esp!(unsafe { esp_wifi_wps_start(0) })?;

self.status.lock().2 = Some(WpsEvent::Active);
self.status.lock().wps = Some(WpsEvent::Active);

Ok(())
}

pub fn is_wps_active(&self) -> Result<bool, EspError> {
let status = self.status.lock();
Ok(matches!(status.2, Some(WpsEvent::Active)))
Ok(matches!(status.wps, Some(WpsEvent::Active)))
}

/// Gets the WPS status as a [`WPS Event`] and disables WPS.
pub fn take_wps_event(&mut self) -> Result<WpsEvent, EspError> {
fn take_wps_event(&mut self) -> Result<WpsEvent, EspError> {
let mut status = self.status.lock();
if status.2.is_some() && !matches!(status.2.as_ref().unwrap(), WpsEvent::Active) {
if status.wps.is_some() && !matches!(status.wps.as_ref().unwrap(), WpsEvent::Active) {
esp!(unsafe { esp_wifi_wps_disable() })?;
Ok(status.2.take().unwrap())
Ok(status.wps.take().unwrap())
} else {
Err(EspError::from_infallible::<ESP_ERR_INVALID_STATE>())
}
Expand Down Expand Up @@ -1766,6 +1780,8 @@ pub struct WpsCredentials {
pub passphrase: heapless::String<64>,
}

const MAX_WPS_AP_CRED_USIZE: usize = MAX_WPS_AP_CRED as usize;

#[derive(Clone, Debug, Eq, PartialEq)]
pub enum WifiEvent {
Ready,
Expand All @@ -1779,7 +1795,7 @@ pub enum WifiEvent {
StaAuthmodeChanged,
StaBssRssiLow,
StaBeaconTimeout,
StaWpsSuccess(Option<Arc<heapless::Vec<WpsCredentials, 3>>>),
StaWpsSuccess(Option<Arc<heapless::Vec<WpsCredentials, MAX_WPS_AP_CRED_USIZE>>>),
StaWpsFailed,
StaWpsTimeout,
StaWpsPin(heapless::String<8>),
Expand Down Expand Up @@ -1832,9 +1848,11 @@ impl EspTypedEventDeserializer<WifiEvent> for WifiEvent {
for i in 0..payload.ap_cred_cnt {
let creds = &payload.ap_cred[i as usize];
let Ok(ssid) = core::str::from_utf8(&creds.ssid) else {
log::warn!("Received a non-UTF-8 SSID via WPS");
continue;
};
let Ok(passphrase) = core::str::from_utf8(&creds.passphrase) else {
log::warn!("Received a non-UTF-8 passphrase via WPS");
continue;
};
let creds = WpsCredentials {
Expand All @@ -1852,10 +1870,10 @@ impl EspTypedEventDeserializer<WifiEvent> for WifiEvent {
WifiEvent::StaWpsTimeout
} else if event_id == wifi_event_t_WIFI_EVENT_STA_WPS_ER_PIN {
let payload = unsafe { (data.payload as *const wifi_event_sta_wps_er_pin_t).as_ref() };
let pin = payload.unwrap().pin_code;
let pin_str = core::str::from_utf8(&pin).unwrap_or("bad pin");
let pin_str = heapless::String::from(pin_str);
WifiEvent::StaWpsPin(pin_str)
let pin = payload
.and_then(|x| core::str::from_utf8(&x.pin_code).ok())
.unwrap_or("bad pin");
WifiEvent::StaWpsPin(heapless::String::from(pin))
} else if event_id == wifi_event_t_WIFI_EVENT_STA_WPS_ER_PBC_OVERLAP {
WifiEvent::StaWpsPbcOverlap
} else if event_id == wifi_event_t_WIFI_EVENT_AP_START {
Expand Down Expand Up @@ -2442,10 +2460,10 @@ impl TryFrom<&WpsFactoryInfo<'_>> for Newtype<wps_factory_information_t> {
device_name: [0; 33],
});

set_cchar_str(&mut result.0.manufacturer, info.manufacturer)?;
set_cchar_str(&mut result.0.model_number, info.model_number)?;
set_cchar_str(&mut result.0.model_name, info.model_name)?;
set_cchar_str(&mut result.0.device_name, info.device_name)?;
set_cchar_slice(&mut result.0.manufacturer, info.manufacturer)?;
set_cchar_slice(&mut result.0.model_number, info.model_number)?;
set_cchar_slice(&mut result.0.model_name, info.model_name)?;
set_cchar_slice(&mut result.0.device_name, info.device_name)?;

Ok(result)
}
Expand Down Expand Up @@ -2483,7 +2501,7 @@ impl WpsType {
#[derive(Debug)]
pub enum WpsEvent {
Active,
Success(Option<Arc<heapless::Vec<WpsCredentials, 3>>>),
Success(Option<Arc<heapless::Vec<WpsCredentials, MAX_WPS_AP_CRED_USIZE>>>),
Failure,
Timeout,
Pin(heapless::String<8>),
Expand Down

0 comments on commit 1256108

Please sign in to comment.