diff --git a/awkernel_lib/src/net.rs b/awkernel_lib/src/net.rs index b8446a086..e8eef3453 100644 --- a/awkernel_lib/src/net.rs +++ b/awkernel_lib/src/net.rs @@ -14,12 +14,6 @@ use self::{ net_device::{LinkStatus, NetCapabilities, NetDevice}, }; -#[cfg(not(feature = "std"))] -use self::tcp::TcpPort; - -#[cfg(not(feature = "std"))] -use alloc::collections::BTreeSet; - #[cfg(not(feature = "std"))] use alloc::{string::String, vec::Vec}; @@ -34,6 +28,7 @@ pub mod ip_addr; pub mod ipv6; pub mod multicast; pub mod net_device; +mod port_alloc; pub mod tcp; pub mod tcp_listener; pub mod tcp_stream; @@ -132,30 +127,6 @@ impl Display for IfStatus { static NET_MANAGER: RwLock = RwLock::new(NetManager { interfaces: BTreeMap::new(), interface_id: 0, - - #[cfg(not(feature = "std"))] - udp_ports_ipv4: BTreeSet::new(), - - #[cfg(not(feature = "std"))] - udp_port_ipv4_ephemeral: u16::MAX >> 2, - - #[cfg(not(feature = "std"))] - udp_ports_ipv6: BTreeSet::new(), - - #[cfg(not(feature = "std"))] - udp_port_ipv6_ephemeral: u16::MAX >> 2, - - #[cfg(not(feature = "std"))] - tcp_ports_ipv4: BTreeMap::new(), - - #[cfg(not(feature = "std"))] - tcp_port_ipv4_ephemeral: u16::MAX >> 2, - - #[cfg(not(feature = "std"))] - tcp_ports_ipv6: BTreeMap::new(), - - #[cfg(not(feature = "std"))] - tcp_port_ipv6_ephemeral: u16::MAX >> 2, }); static IRQ_WAKERS: Mutex> = Mutex::new(BTreeMap::new()); @@ -164,208 +135,6 @@ static POLL_WAKERS: Mutex> = Mutex::new(BTreeMap::new()) pub struct NetManager { interfaces: BTreeMap>, interface_id: u64, - - #[cfg(not(feature = "std"))] - udp_ports_ipv4: BTreeSet, - - #[cfg(not(feature = "std"))] - udp_port_ipv4_ephemeral: u16, - - #[cfg(not(feature = "std"))] - udp_ports_ipv6: BTreeSet, - - #[cfg(not(feature = "std"))] - udp_port_ipv6_ephemeral: u16, - - #[cfg(not(feature = "std"))] - tcp_ports_ipv4: BTreeMap, - - #[cfg(not(feature = "std"))] - tcp_port_ipv4_ephemeral: u16, - - #[cfg(not(feature = "std"))] - tcp_ports_ipv6: BTreeMap, - - #[cfg(not(feature = "std"))] - tcp_port_ipv6_ephemeral: u16, -} - -impl NetManager { - #[cfg(not(feature = "std"))] - fn get_ephemeral_port_udp_ipv4(&mut self) -> Option { - let mut ephemeral_port = None; - for i in 0..(u16::MAX >> 2) { - let port = self.udp_port_ipv4_ephemeral.wrapping_add(i); - let port = if port == 0 { u16::MAX >> 2 } else { port }; - - if !self.udp_ports_ipv4.contains(&port) { - self.udp_ports_ipv4.insert(port); - self.udp_port_ipv4_ephemeral = port; - ephemeral_port = Some(port); - break; - } - } - - ephemeral_port - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn set_port_in_use_udp_ipv4(&mut self, port: u16) { - self.udp_ports_ipv4.insert(port); - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn is_port_in_use_udp_ipv4(&mut self, port: u16) -> bool { - self.udp_ports_ipv4.contains(&port) - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn free_port_udp_ipv4(&mut self, port: u16) { - self.udp_ports_ipv4.remove(&port); - } - - #[cfg(not(feature = "std"))] - fn get_ephemeral_port_udp_ipv6(&mut self) -> Option { - let mut ephemeral_port = None; - for i in 0..(u16::MAX >> 2) { - let port = self.udp_port_ipv6_ephemeral.wrapping_add(i); - let port = if port == 0 { u16::MAX >> 2 } else { port }; - - if !self.udp_ports_ipv6.contains(&port) { - self.udp_ports_ipv6.insert(port); - self.udp_port_ipv4_ephemeral = port; - ephemeral_port = Some(port); - break; - } - } - - ephemeral_port - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn set_port_in_use_udp_ipv6(&mut self, port: u16) { - self.udp_ports_ipv6.insert(port); - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn is_port_in_use_udp_ipv6(&mut self, port: u16) -> bool { - self.udp_ports_ipv6.contains(&port) - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn free_port_udp_ipv6(&mut self, port: u16) { - self.udp_ports_ipv6.remove(&port); - } - - #[cfg(not(feature = "std"))] - fn get_ephemeral_port_tcp_ipv4(&mut self) -> Option { - let mut ephemeral_port = None; - for i in 0..(u16::MAX >> 2) { - let port = self.tcp_port_ipv4_ephemeral.wrapping_add(i); - let port = if port == 0 { u16::MAX >> 2 } else { port }; - - let entry = self.tcp_ports_ipv4.entry(i); - - match entry { - Entry::Occupied(_) => (), - Entry::Vacant(e) => { - e.insert(1); - ephemeral_port = Some(TcpPort::new(port, true)); - self.tcp_port_ipv4_ephemeral = port; - break; - } - } - } - - ephemeral_port - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn is_port_in_use_tcp_ipv4(&mut self, port: u16) -> bool { - self.tcp_ports_ipv4.contains_key(&port) - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn port_in_use_tcp_ipv4(&mut self, port: u16) -> TcpPort { - if let Some(e) = self.tcp_ports_ipv4.get_mut(&port) { - *e += 1; - } else { - self.tcp_ports_ipv4.insert(port, 1); - } - - TcpPort::new(port, true) - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn decrement_port_in_use_tcp_ipv4(&mut self, port: u16) { - if let Some(e) = self.tcp_ports_ipv4.get_mut(&port) { - *e -= 1; - if *e == 0 { - self.tcp_ports_ipv4.remove(&port); - } - } - } - - #[cfg(not(feature = "std"))] - fn get_ephemeral_port_tcp_ipv6(&mut self) -> Option { - let mut ephemeral_port = None; - for i in 0..(u16::MAX >> 2) { - let port = self.tcp_port_ipv6_ephemeral.wrapping_add(i); - let port = if port == 0 { u16::MAX >> 2 } else { port }; - - let entry = self.tcp_ports_ipv6.entry(i); - - match entry { - Entry::Occupied(_) => (), - Entry::Vacant(e) => { - e.insert(1); - ephemeral_port = Some(TcpPort::new(port, false)); - self.tcp_port_ipv6_ephemeral = port; - break; - } - } - } - - ephemeral_port - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn is_port_in_use_tcp_ipv6(&mut self, port: u16) -> bool { - self.tcp_ports_ipv6.contains_key(&port) - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn port_in_use_tcp_ipv6(&mut self, port: u16) -> TcpPort { - if let Some(e) = self.tcp_ports_ipv6.get_mut(&port) { - *e += 1; - } else { - self.tcp_ports_ipv6.insert(port, 1); - } - - TcpPort::new(port, true) - } - - #[cfg(not(feature = "std"))] - #[inline(always)] - fn decrement_port_in_use_tcp_ipv6(&mut self, port: u16) { - if let Some(e) = self.tcp_ports_ipv6.get_mut(&port) { - *e -= 1; - if *e == 0 { - self.tcp_ports_ipv6.remove(&port); - } - } - } } pub fn get_interface(interface_id: u64) -> Result { diff --git a/awkernel_lib/src/net/port_alloc.rs b/awkernel_lib/src/net/port_alloc.rs new file mode 100644 index 000000000..57d427bde --- /dev/null +++ b/awkernel_lib/src/net/port_alloc.rs @@ -0,0 +1,223 @@ +#![cfg(not(feature = "std"))] + +use alloc::collections::{btree_map::Entry, BTreeMap, BTreeSet}; + +use crate::sync::{mcs::MCSNode, mutex::Mutex}; + +use super::tcp::TcpPort; + +struct TcpPortsInner { + map: BTreeMap, + cursor: u16, +} + +struct UdpPortsInner { + set: BTreeSet, + cursor: u16, +} + +pub(super) struct PortAllocator { + tcp_ipv4: Mutex, + tcp_ipv6: Mutex, + udp_ipv4: Mutex, + udp_ipv6: Mutex, +} + +pub(super) static PORT_ALLOC: PortAllocator = PortAllocator::new(); + +impl PortAllocator { + pub(super) const fn new() -> Self { + Self { + tcp_ipv4: Mutex::new(TcpPortsInner { + map: BTreeMap::new(), + cursor: u16::MAX >> 2, + }), + tcp_ipv6: Mutex::new(TcpPortsInner { + map: BTreeMap::new(), + cursor: u16::MAX >> 2, + }), + udp_ipv4: Mutex::new(UdpPortsInner { + set: BTreeSet::new(), + cursor: u16::MAX >> 2, + }), + udp_ipv6: Mutex::new(UdpPortsInner { + set: BTreeSet::new(), + cursor: u16::MAX >> 2, + }), + } + } + + /// Allocate an ephemeral TCP IPv4 port. + pub(super) fn get_ephemeral_tcp_ipv4(&self) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv4.lock(&mut node); + for _ in 0..(u16::MAX >> 2) { + ports.cursor = ports.cursor.wrapping_add(1); + let port = if ports.cursor == 0 { + u16::MAX >> 2 + } else { + ports.cursor + }; + if let Entry::Vacant(e) = ports.map.entry(port) { + e.insert(1); + return Some(TcpPort::new(port, true)); + } + } + None + } + + /// Claim a specific TCP IPv4 port. Returns `None` if the port is already in use. + pub(super) fn try_claim_tcp_ipv4(&self, port: u16) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv4.lock(&mut node); + if let Entry::Vacant(e) = ports.map.entry(port) { + e.insert(1); + Some(TcpPort::new(port, true)) + } else { + None + } + } + + /// Increment the reference count for a TCP IPv4 port (used by `TcpListener::accept`). + pub(super) fn increment_ref_tcp_ipv4(&self, port: u16) -> TcpPort { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv4.lock(&mut node); + if let Some(e) = ports.map.get_mut(&port) { + *e += 1; + } else { + ports.map.insert(port, 1); + } + TcpPort::new(port, true) + } + + /// Decrement the reference count for a TCP IPv4 port, freeing it when it reaches zero. + pub(super) fn decrement_ref_tcp_ipv4(&self, port: u16) { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv4.lock(&mut node); + if let Some(e) = ports.map.get_mut(&port) { + *e -= 1; + if *e == 0 { + ports.map.remove(&port); + } + } + } + + /// Allocate an ephemeral TCP IPv6 port. + pub(super) fn get_ephemeral_tcp_ipv6(&self) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv6.lock(&mut node); + for _ in 0..(u16::MAX >> 2) { + ports.cursor = ports.cursor.wrapping_add(1); + let port = if ports.cursor == 0 { + u16::MAX >> 2 + } else { + ports.cursor + }; + if let Entry::Vacant(e) = ports.map.entry(port) { + e.insert(1); + return Some(TcpPort::new(port, false)); + } + } + None + } + + /// Claim a specific TCP IPv6 port. Returns `None` if the port is already in use. + pub(super) fn try_claim_tcp_ipv6(&self, port: u16) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv6.lock(&mut node); + if let Entry::Vacant(e) = ports.map.entry(port) { + e.insert(1); + Some(TcpPort::new(port, false)) + } else { + None + } + } + + /// Increment the reference count for a TCP IPv6 port. + pub(super) fn increment_ref_tcp_ipv6(&self, port: u16) -> TcpPort { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv6.lock(&mut node); + if let Some(e) = ports.map.get_mut(&port) { + *e += 1; + } else { + ports.map.insert(port, 1); + } + TcpPort::new(port, false) + } + + /// Decrement the reference count for a TCP IPv6 port, freeing it when it reaches zero. + pub(super) fn decrement_ref_tcp_ipv6(&self, port: u16) { + let mut node = MCSNode::new(); + let mut ports = self.tcp_ipv6.lock(&mut node); + if let Some(e) = ports.map.get_mut(&port) { + *e -= 1; + if *e == 0 { + ports.map.remove(&port); + } + } + } + + /// Allocate an ephemeral UDP IPv4 port. + pub(super) fn get_ephemeral_udp_ipv4(&self) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.udp_ipv4.lock(&mut node); + for _ in 0..(u16::MAX >> 2) { + ports.cursor = ports.cursor.wrapping_add(1); + let port = if ports.cursor == 0 { + u16::MAX >> 2 + } else { + ports.cursor + }; + if ports.set.insert(port) { + return Some(port); + } + } + None + } + + /// Claim a specific UDP IPv4 port. Returns `false` if the port is already in use. + pub(super) fn try_claim_udp_ipv4(&self, port: u16) -> bool { + let mut node = MCSNode::new(); + let mut ports = self.udp_ipv4.lock(&mut node); + ports.set.insert(port) + } + + /// Free a UDP IPv4 port. + pub(super) fn free_udp_ipv4(&self, port: u16) { + let mut node = MCSNode::new(); + let mut ports = self.udp_ipv4.lock(&mut node); + ports.set.remove(&port); + } + + /// Allocate an ephemeral UDP IPv6 port. + pub(super) fn get_ephemeral_udp_ipv6(&self) -> Option { + let mut node = MCSNode::new(); + let mut ports = self.udp_ipv6.lock(&mut node); + for _ in 0..(u16::MAX >> 2) { + ports.cursor = ports.cursor.wrapping_add(1); + let port = if ports.cursor == 0 { + u16::MAX >> 2 + } else { + ports.cursor + }; + if ports.set.insert(port) { + return Some(port); + } + } + None + } + + /// Claim a specific UDP IPv6 port. Returns `false` if the port is already in use. + pub(super) fn try_claim_udp_ipv6(&self, port: u16) -> bool { + let mut node = MCSNode::new(); + let mut ports = self.udp_ipv6.lock(&mut node); + ports.set.insert(port) + } + + /// Free a UDP IPv6 port. + pub(super) fn free_udp_ipv6(&self, port: u16) { + let mut node = MCSNode::new(); + let mut ports = self.udp_ipv6.lock(&mut node); + ports.set.remove(&port); + } +} diff --git a/awkernel_lib/src/net/tcp.rs b/awkernel_lib/src/net/tcp.rs index 0fc0f5cff..2aed89cd4 100644 --- a/awkernel_lib/src/net/tcp.rs +++ b/awkernel_lib/src/net/tcp.rs @@ -19,13 +19,14 @@ impl TCPHdr { } } -#[allow(dead_code)] +#[cfg(not(feature = "std"))] #[derive(Debug)] pub struct TcpPort { port: u16, is_ipv4: bool, } +#[cfg(not(feature = "std"))] impl TcpPort { pub fn new(port: u16, is_ipv4: bool) -> Self { Self { port, is_ipv4 } @@ -37,16 +38,13 @@ impl TcpPort { } } +#[cfg(not(feature = "std"))] impl Drop for TcpPort { fn drop(&mut self) { - #[cfg(not(feature = "std"))] - { - let mut net_manager = super::NET_MANAGER.write(); - if self.is_ipv4 { - net_manager.decrement_port_in_use_tcp_ipv4(self.port); - } else { - net_manager.decrement_port_in_use_tcp_ipv6(self.port); - } + if self.is_ipv4 { + super::port_alloc::PORT_ALLOC.decrement_ref_tcp_ipv4(self.port); + } else { + super::port_alloc::PORT_ALLOC.decrement_ref_tcp_ipv6(self.port); } } } diff --git a/awkernel_lib/src/net/tcp_listener/tcp_listener_no_std.rs b/awkernel_lib/src/net/tcp_listener/tcp_listener_no_std.rs index 8d548e368..030c7a6b5 100644 --- a/awkernel_lib/src/net/tcp_listener/tcp_listener_no_std.rs +++ b/awkernel_lib/src/net/tcp_listener/tcp_listener_no_std.rs @@ -6,7 +6,8 @@ use crate::sync::mcs::MCSNode; use alloc::{vec, vec::Vec}; use crate::net::{ - ip_addr::IpAddr, tcp::TcpPort, tcp_stream::TcpStream, NetManagerError, NET_MANAGER, + ip_addr::IpAddr, port_alloc::PORT_ALLOC, tcp::TcpPort, tcp_stream::TcpStream, NetManagerError, + NET_MANAGER, }; use super::SockTcpListener; @@ -30,14 +31,15 @@ impl SockTcpListener for TcpListener { tx_buffer_size: usize, backlogs: usize, ) -> Result { - let mut net_manager = NET_MANAGER.write(); - // Find the interface that has the specified address. - let if_net = net_manager - .interfaces - .get(&interface_id) - .ok_or(NetManagerError::InvalidInterfaceID)? - .clone(); + let if_net = { + let net_manager = NET_MANAGER.read(); + net_manager + .interfaces + .get(&interface_id) + .ok_or(NetManagerError::InvalidInterfaceID)? + .clone() + }; let port = if let Some(port) = port { if port == 0 { @@ -45,34 +47,26 @@ impl SockTcpListener for TcpListener { } if addr.is_ipv4() { - // Check if the specified port is available. - if net_manager.is_port_in_use_tcp_ipv4(port) { - return Err(NetManagerError::PortInUse); - } - - net_manager.port_in_use_tcp_ipv4(port) + PORT_ALLOC + .try_claim_tcp_ipv4(port) + .ok_or(NetManagerError::PortInUse)? } else { - // Check if the specified port is available. - if net_manager.is_port_in_use_tcp_ipv6(port) { - return Err(NetManagerError::PortInUse); - } - - net_manager.port_in_use_tcp_ipv6(port) + PORT_ALLOC + .try_claim_tcp_ipv6(port) + .ok_or(NetManagerError::PortInUse)? } } else if addr.is_ipv4() { // Find an ephemeral port. - net_manager - .get_ephemeral_port_tcp_ipv4() + PORT_ALLOC + .get_ephemeral_tcp_ipv4() .ok_or(NetManagerError::NoAvailablePort)? } else { // Find an ephemeral port. - net_manager - .get_ephemeral_port_tcp_ipv6() + PORT_ALLOC + .get_ephemeral_tcp_ipv6() .ok_or(NetManagerError::NoAvailablePort)? }; - drop(net_manager); - let mut handles = Vec::new(); for _ in 0..backlogs { @@ -98,13 +92,10 @@ impl SockTcpListener for TcpListener { fn accept(&mut self, waker: &core::task::Waker) -> Result, NetManagerError> { // If there is a connected socket, return it. if let Some(handle) = self.connected_sockets.pop_front() { - let port = { - let mut net_manager = NET_MANAGER.write(); - if self.addr.is_ipv4() { - net_manager.port_in_use_tcp_ipv4(self.port.port()) - } else { - net_manager.port_in_use_tcp_ipv6(self.port.port()) - } + let port = if self.addr.is_ipv4() { + PORT_ALLOC.increment_ref_tcp_ipv4(self.port.port()) + } else { + PORT_ALLOC.increment_ref_tcp_ipv6(self.port.port()) }; return Ok(Some(TcpStream { handle, @@ -171,13 +162,10 @@ impl SockTcpListener for TcpListener { // If there is a connected socket, return it. if let Some(handle) = self.connected_sockets.pop_front() { - let port = { - let mut net_manager = NET_MANAGER.write(); - if self.addr.is_ipv4() { - net_manager.port_in_use_tcp_ipv4(self.port.port()) - } else { - net_manager.port_in_use_tcp_ipv6(self.port.port()) - } + let port = if self.addr.is_ipv4() { + PORT_ALLOC.increment_ref_tcp_ipv4(self.port.port()) + } else { + PORT_ALLOC.increment_ref_tcp_ipv6(self.port.port()) }; if_net.poll_tx_only(crate::cpu::raw_cpu_id() & (if_net.net_device.num_queues() - 1)); diff --git a/awkernel_lib/src/net/tcp_stream/tcp_stream_no_std.rs b/awkernel_lib/src/net/tcp_stream/tcp_stream_no_std.rs index f0e1dd6d3..3011f3099 100644 --- a/awkernel_lib/src/net/tcp_stream/tcp_stream_no_std.rs +++ b/awkernel_lib/src/net/tcp_stream/tcp_stream_no_std.rs @@ -1,4 +1,6 @@ -use crate::net::{ip_addr::IpAddr, tcp::TcpPort, NetManagerError, NET_MANAGER}; +use crate::net::{ + ip_addr::IpAddr, port_alloc::PORT_ALLOC, tcp::TcpPort, NetManagerError, NET_MANAGER, +}; use super::{SockTcpStream, TcpResult}; @@ -90,26 +92,25 @@ impl SockTcpStream for TcpStream { tx_buffer_size: usize, waker: &core::task::Waker, ) -> Result { - let mut net_manager = NET_MANAGER.write(); - - let if_net = net_manager - .interfaces - .get(&interface_id) - .ok_or(NetManagerError::InvalidInterfaceID)?; - let if_net = if_net.clone(); + let if_net = { + let net_manager = NET_MANAGER.read(); + net_manager + .interfaces + .get(&interface_id) + .ok_or(NetManagerError::InvalidInterfaceID)? + .clone() + }; let local_port = if remote_addr.is_ipv4() { - net_manager - .get_ephemeral_port_tcp_ipv4() + PORT_ALLOC + .get_ephemeral_tcp_ipv4() .ok_or(NetManagerError::NoAvailablePort)? } else { - net_manager - .get_ephemeral_port_tcp_ipv6() + PORT_ALLOC + .get_ephemeral_tcp_ipv6() .ok_or(NetManagerError::NoAvailablePort)? }; - drop(net_manager); - let rx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0; rx_buffer_size]); let tx_buffer = smoltcp::socket::tcp::SocketBuffer::new(vec![0; tx_buffer_size]); diff --git a/awkernel_lib/src/net/udp_socket/udp_socket_no_std.rs b/awkernel_lib/src/net/udp_socket/udp_socket_no_std.rs index bf332ac7d..b2fca5937 100644 --- a/awkernel_lib/src/net/udp_socket/udp_socket_no_std.rs +++ b/awkernel_lib/src/net/udp_socket/udp_socket_no_std.rs @@ -1,6 +1,6 @@ use core::net::Ipv4Addr; -use crate::net::{ip_addr::IpAddr, NET_MANAGER}; +use crate::net::{ip_addr::IpAddr, port_alloc::PORT_ALLOC, NET_MANAGER}; use awkernel_sync::{mcs::MCSNode, mutex::Mutex}; use super::{NetManagerError, SockUdp}; @@ -32,7 +32,15 @@ impl super::SockUdp for UdpSocket { rx_buffer_size: usize, tx_buffer_size: usize, ) -> Result { - let mut net_manager = NET_MANAGER.write(); + // Find the interface that has the specified address. + let if_net = { + let net_manager = NET_MANAGER.read(); + net_manager + .interfaces + .get(&interface_id) + .ok_or(NetManagerError::InvalidInterfaceID)? + .clone() + }; let is_ipv4; let port = if let Some(port) = port { @@ -40,48 +48,35 @@ impl super::SockUdp for UdpSocket { return Err(NetManagerError::InvalidPort); } - // Check if the specified port is available. + // Check if the specified port is available and claim it atomically. if addr.is_ipv4() { - if net_manager.is_port_in_use_udp_ipv4(port) { + if !PORT_ALLOC.try_claim_udp_ipv4(port) { return Err(NetManagerError::PortInUse); } - is_ipv4 = true; - net_manager.set_port_in_use_udp_ipv4(port); port } else { - if net_manager.is_port_in_use_udp_ipv6(port) { + if !PORT_ALLOC.try_claim_udp_ipv6(port) { return Err(NetManagerError::PortInUse); } - is_ipv4 = false; - net_manager.set_port_in_use_udp_ipv6(port); port } } else { // Find an ephemeral port. if addr.is_ipv4() { is_ipv4 = true; - net_manager - .get_ephemeral_port_udp_ipv4() + PORT_ALLOC + .get_ephemeral_udp_ipv4() .ok_or(NetManagerError::PortInUse)? } else { is_ipv4 = false; - net_manager - .get_ephemeral_port_udp_ipv6() + PORT_ALLOC + .get_ephemeral_udp_ipv6() .ok_or(NetManagerError::PortInUse)? } }; - // Find the interface that has the specified address. - let if_net = net_manager - .interfaces - .get(&interface_id) - .ok_or(NetManagerError::InvalidInterfaceID)? - .clone(); - - drop(net_manager); - // Create a UDP socket. use smoltcp::socket::udp; let udp_rx_buffer = udp::PacketBuffer::new( @@ -331,16 +326,17 @@ impl Drop for UdpSocket { } } - let mut net_manager = NET_MANAGER.write(); + { + let net_manager = NET_MANAGER.read(); + if let Some(if_net) = net_manager.interfaces.get(&self.interface_id) { + if_net.socket_set.write().remove(self.handle); + } + } if self.is_ipv4 { - net_manager.free_port_udp_ipv4(self.port); + PORT_ALLOC.free_udp_ipv4(self.port); } else { - net_manager.free_port_udp_ipv6(self.port); - } - - if let Some(if_net) = net_manager.interfaces.get(&self.interface_id) { - if_net.socket_set.write().remove(self.handle); + PORT_ALLOC.free_udp_ipv6(self.port); } } }