Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions awkernel_drivers/src/pcie/virtio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub enum VirtioDriverErr {
NoVirtqueue,
InvalidQueueSize,
DMAPool,
InvalidPacket,
}

impl From<VirtioDriverErr> for PCIeDeviceErr {
Expand All @@ -46,6 +47,7 @@ impl From<VirtioDriverErr> for PCIeDeviceErr {
VirtioDriverErr::NoVirtqueue => PCIeDeviceErr::InitFailure,
VirtioDriverErr::InvalidQueueSize => PCIeDeviceErr::InitFailure,
VirtioDriverErr::DMAPool => PCIeDeviceErr::InitFailure,
VirtioDriverErr::InvalidPacket => PCIeDeviceErr::CommandFailure,
}
}
}
105 changes: 87 additions & 18 deletions awkernel_drivers/src/pcie/virtio/virtio_net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,23 @@ use awkernel_lib::{
addr::Addr,
dma_pool::DMAPool,
interrupt::IRQ,
net::net_device::{
EtherFrameBuf, EtherFrameRef, LinkStatus, NetCapabilities, NetDevError, NetDevice, NetFlags,
net::{
ether::{extract_headers, EtherHeader, EtherVlanHeader, NetworkHdr, TransportHdr},
ipv6::Ip6Hdr,
net_device::{
EtherFrameBuf, EtherFrameRef, LinkStatus, NetCapabilities, NetDevError, NetDevice,
NetFlags, PacketHeaderFlags,
},
tcp::TCPHdr,
udp::UDPHdr,
},
paging::PAGESIZE,
sync::{
mutex::{MCSNode, Mutex},
rwlock::RwLock,
},
};
use memoffset::offset_of;

const DEVICE_SHORT_NAME: &str = "virtio-net";

Expand All @@ -38,6 +46,7 @@ const RECV_QUEUE_SIZE: usize = 32; // To Be Determined
const VIRTIO_NET_ID: u16 = 0x1041;

// device-specific feature bits
const VIRTIO_NET_F_CSUM: u64 = 1 << 0;
const VIRTIO_NET_F_MAC: u64 = 1 << 5;
const VIRTIO_NET_F_STATUS: u64 = 1 << 16;
const VIRTIO_NET_F_SPEED_DUPLEX: u64 = 1 << 63;
Expand Down Expand Up @@ -274,6 +283,59 @@ impl Virtq {
}
}

fn vio_tx_offload(&mut self, frame: &EtherFrameRef) -> Result<VirtioNetHdr, VirtioDriverErr> {
let mut hdr = VirtioNetHdr::default();

let has_tcp_csum_out = frame.csum_flags.contains(PacketHeaderFlags::TCP_CSUM_OUT);
let has_udp_csum_out = frame.csum_flags.contains(PacketHeaderFlags::UDP_CSUM_OUT);
if !has_tcp_csum_out && !has_udp_csum_out {
return Ok(hdr);
}

let ext = extract_headers(frame.data).or(Err(VirtioDriverErr::InvalidPacket))?;

// Consistency check
match ext.network {
NetworkHdr::Ipv4(_) => (),
NetworkHdr::Ipv6(_) => (),
_ => return Ok(hdr),
}
match ext.transport {
TransportHdr::Tcp(_) => {
if !has_tcp_csum_out {
return Ok(hdr);
}
}
TransportHdr::Udp(_) => {
if !has_udp_csum_out {
return Ok(hdr);
}
}
_ => return Ok(hdr),
}

hdr.csum_start = match ext.ether_vlan {
Some(_) => core::mem::size_of::<EtherVlanHeader>() as u16,
None => core::mem::size_of::<EtherHeader>() as u16,
};

hdr.csum_start += match ext.network {
NetworkHdr::Ipv4(ip) => ip.header_len() as u16,
NetworkHdr::Ipv6(_) => core::mem::size_of::<Ip6Hdr>() as u16,
_ => 0,
};

hdr.csum_offset = match ext.transport {
TransportHdr::Tcp(_) => offset_of!(TCPHdr, th_sum) as u16,
TransportHdr::Udp(_) => offset_of!(UDPHdr, uh_sum) as u16,
_ => 0,
};

hdr.flags = VIRTIO_NET_HDR_F_NEEDS_CSUM;

Ok(hdr)
}

fn vio_tx_dequeue(&mut self) -> u16 {
let mut freed = 0;
while let Some((slot, _len)) = self.virtio_dequeue() {
Expand All @@ -295,25 +357,25 @@ impl Virtq {
self.vio_txeof();
}

fn vio_encap(&mut self, slot: usize, frame: &EtherFrameRef) -> usize {
let len = frame.data.len();
fn vio_encap(&mut self, slot: usize, frame: &EtherFrameRef, header: &VirtioNetHdr) -> usize {
let buf = self.data_buf.as_mut();
let dst = &mut buf[slot].as_mut_ptr();
let header_len = core::mem::size_of::<VirtioNetHdr>();
let data_len = frame.data.len();
unsafe {
// TODO: handle VirtIO-net header
// For now, we just skip the header by dst.add(header_len)
core::ptr::copy_nonoverlapping(frame.data.as_ptr(), dst.add(header_len), len);
core::ptr::copy_nonoverlapping(header as *const _ as *const u8, dst.add(0), header_len);
core::ptr::copy_nonoverlapping(frame.data.as_ptr(), dst.add(header_len), data_len);
}

header_len + len
header_len + data_len
}

fn vio_start(&mut self, frame: &EtherFrameRef) {
fn vio_start(&mut self, frame: &EtherFrameRef) -> Result<(), VirtioDriverErr> {
self.vio_tx_dequeue();

if let Some(slot) = self.virtio_enqueue_prep() {
let len = self.vio_encap(slot, frame);
let header = self.vio_tx_offload(frame)?;
let len = self.vio_encap(slot, frame, &header);
self.virtio_enqueue_reserve(slot);
self.virtio_enqueue(slot, len, true);
self.virtio_enqueue_commit(slot);
Expand All @@ -322,10 +384,13 @@ impl Virtq {
if self.virtio_start_vq_intr() {
self.vio_tx_dequeue();
}

Ok(())
}
}

/// Packet header structure
#[derive(Default)]
#[repr(C, packed)]
struct VirtioNetHdr {
flags: u8,
Expand All @@ -337,13 +402,7 @@ struct VirtioNetHdr {
num_buffers: u16, // only present if VIRTIO_NET_F_MRG_RXBUF is negotiated
}

const _VIRTIO_NET_HDR_F_NEEDS_CSUM: u8 = 1;
const _VIRTIO_NET_HDR_F_DATA_VALID: u8 = 2;
const _VIRTIO_NET_HDR_GSO_NONE: u8 = 0;
const _VIRTIO_NET_HDR_GSO_TCPV4: u8 = 1;
const _VIRTIO_NET_HDR_GSO_UDP: u8 = 3;
const _VIRTIO_NET_HDR_GSO_TCPV6: u8 = 4;
const _VIRTIO_NET_HDR_GSO_ECN: u8 = 0x80;
const VIRTIO_NET_HDR_F_NEEDS_CSUM: u8 = 1;

pub fn match_device(vendor: u16, id: u16) -> bool {
vendor == pcie_id::VIRTIO_VENDOR_ID && id == VIRTIO_NET_ID
Expand Down Expand Up @@ -456,6 +515,7 @@ impl VirtioNetInner {
self.driver_features |= VIRTIO_NET_F_MAC;
self.driver_features |= VIRTIO_NET_F_STATUS;
self.driver_features |= VIRTIO_NET_F_SPEED_DUPLEX;
self.driver_features |= VIRTIO_NET_F_CSUM;

self.virtio_pci_negotiate_features()?;

Expand All @@ -468,6 +528,15 @@ impl VirtioNetInner {
self.capabilities = NetCapabilities::empty();
self.flags = NetFlags::BROADCAST | NetFlags::SIMPLEX | NetFlags::MULTICAST;

if self.virtio_has_feature(VIRTIO_NET_F_CSUM) {
self.capabilities |= NetCapabilities::CSUM_UDPv4;

// NOTE: we currently only support UDPv4
// self.capabilities |= NetCapabilities::CSUM_TCPv4;
// self.capabilities |= NetCapabilities::CSUM_TCPv6;
// self.capabilities |= NetCapabilities::CSUM_UDPv6;
}

let num_queues = 1; // TODO: support multiple queues
for i in 0..num_queues {
let mut rx = self.virtio_alloc_vq(2 * i)?;
Expand Down Expand Up @@ -955,7 +1024,7 @@ impl NetDevice for VirtioNet {
let inner = self.inner.read();
let mut node = MCSNode::new();
let mut tx = inner.virtqueues[que_id].tx.lock(&mut node);
tx.vio_start(&data);
tx.vio_start(&data).or(Err(NetDevError::DeviceError))?;
}

let tx_vq_index = (que_id * 2 + 1) as u16;
Expand Down