diff --git a/src/kb.rs b/src/kb.rs index 5258c135..3b293451 100644 --- a/src/kb.rs +++ b/src/kb.rs @@ -27,3 +27,58 @@ pub enum Key { PageDown, Char(char), } + +/// Converts a slice of `Key` enum values to a UTF-8 encoded `String`. +///Will add newlines for Key::Enter and delete the last char for Key::BackSpace +/// +/// # Arguments +/// +/// * `keys` - A slice of `Key` enum values representing user input keys. +pub fn keys_to_utf8(keys: &[Key]) -> String { + let mut chars = Vec::new(); + for key in keys { + match key { + Key::Char(c) => chars.push(c), + Key::Backspace => { + chars.pop(); + } + #[cfg(not(windows))] + Key::Enter => chars.push(&'\n'), + #[cfg(windows)] + Key::Enter => { + chars.push(&'\r'); + chars.push(&'\n') + } + key => { + // This may be expanded by keeping track of a cursor which is controlled by the ArrowKeys and changes del and backspace + unimplemented!("Cannot convert key: {:?} to utf8", key) + } + } + } + chars.into_iter().collect::() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_keys_to_utf8() { + let keys = vec![ + Key::Char('H'), + Key::Char('e'), + Key::Char('l'), + Key::Char('l'), + Key::Char('o'), + Key::Enter, + Key::Char('W'), + Key::Char('o'), + Key::Char('r'), + Key::Char('l'), + Key::Char('d'), + Key::Backspace, + ]; + let result = keys_to_utf8(&keys); + assert_eq!(result, "Hello\nWorl"); + } +} diff --git a/src/lib.rs b/src/lib.rs index 1b18afc0..8f79779a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -75,6 +75,7 @@ //! * `ansi-parsing`: adds support for parsing ansi codes (this adds support //! for stripping and taking ansi escape codes into account for length //! calculations). +#![feature(io_error_more)] pub use crate::kb::Key; pub use crate::term::{ diff --git a/src/term.rs b/src/term.rs index 0a402585..216547af 100644 --- a/src/term.rs +++ b/src/term.rs @@ -1,5 +1,5 @@ use std::fmt::{Debug, Display}; -use std::io::{self, Read, Write}; +use std::io::{self, Error, Read, Write}; use std::sync::{Arc, Mutex}; #[cfg(unix)] @@ -7,24 +7,23 @@ use std::os::unix::io::{AsRawFd, RawFd}; #[cfg(windows)] use std::os::windows::io::{AsRawHandle, RawHandle}; +use crate::kb::keys_to_utf8; use crate::{kb::Key, utils::Style}; #[cfg(unix)] trait TermWrite: Write + Debug + AsRawFd + Send {} -#[cfg(unix)] impl TermWrite for T {} #[cfg(unix)] -trait TermRead: Read + Debug + AsRawFd + Send {} -#[cfg(unix)] -impl TermRead for T {} +trait TermRead: Iterator + Debug + AsRawFd + Send {} +impl + Debug + AsRawFd + Send> TermRead for T {} #[cfg(unix)] #[derive(Debug, Clone)] pub struct ReadWritePair { #[allow(unused)] - read: Arc>, - write: Arc>, + reader: Arc>, + writer: Arc>, style: Style, } @@ -36,6 +35,112 @@ pub enum TermTarget { #[cfg(unix)] ReadWritePair(ReadWritePair), } +impl TermTarget { + /// Fills the buffer with bytes read from the input source + /// If backspace was pressed the last character is automatically removed from the buffer + /// If enter is pressed, is appended for windows and for other os + /// + /// # Panics + /// If the buffer is not at least 4 bytes large + /// If the input source is a custom ReadWritePair and any keys other than Char, Enter or BackSpace are encountered + /// + /// # Returns Error + /// If there was an error reading from the input source + /// + /// # Returns ok + /// How many bytes have been read + fn read(&self, mut buf: &mut [u8]) -> io::Result { + assert!(buf.len() >= 4, "The buffer must be at least 4 bytes large because a single character may be 4 bytes large"); + if let TermTarget::ReadWritePair(_) = self { + let mut keys = Vec::new(); + while keys_to_utf8(&keys).as_bytes().len() <= buf.len() - 4 { + keys.push(self.read_single_key()?); + } + let utf8 = keys_to_utf8(&keys); + let bytes = utf8.as_bytes(); + buf.write(bytes) + } else { + io::stdin().read(buf) + } + } + + /// Reads chars until Key::Enter was found + /// This may cause an infinite loop if the line is not terminated + /// + /// # Panics + /// If the input source is a custom ReadWritePair and any keys other than Char, Enter or BackSpace are encountered + /// + /// # Returns error + /// If there was an error reading from the input source + /// + /// # Returns ok + /// The length of the modified buffer string + fn read_line(&self, buf: &mut String) -> io::Result { + if let TermTarget::ReadWritePair(_) = self { + let mut keys = Vec::new(); + loop { + let key = self.read_single_key()?; + if key == Key::Enter { + break; + } + keys.push(key); + } + buf.clear(); + buf.push_str(&keys_to_utf8(&keys)); + Ok(buf.len()) + } else { + io::stdin().read_line(buf) + } + } + + /// Read a line from the input source without showing the inserted characters + /// # Panics + /// If the input source is a custom ReadWritePair and any keys other than Char, Enter or BackSpace are encountered + /// + /// # Returns error + /// If there was an error reading from the input source + /// + /// # Returns ok + /// The read line + fn read_secure(&self) -> io::Result { + if let TermTarget::ReadWritePair(_) = self { + let mut s = String::new(); + self.read_line(&mut s)?; + Ok(s) + } else { + read_secure() + } + } + + /// Reads a single key from the input source + /// + /// # Returns Error (only ReadWritePair) + /// If the end of the input source has been reached + /// If the input source lock was already acquired + /// + /// # Returns Ok + /// The key that was read + fn read_single_key(&self) -> io::Result { + if let TermTarget::ReadWritePair(pair) = self { + if let Ok(mut reader) = pair.reader.lock() { + match reader.next() { + Some(key) => Ok(key), + None => io::Result::Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "Reached end of input source", + )), + } + } else { + io::Result::Err(io::Error::new( + io::ErrorKind::ResourceBusy, + "The reader lock was already acquired by this thread", + )) + } + } else { + read_single_key() + } + } +} #[derive(Debug)] pub struct TermInner { @@ -181,7 +286,7 @@ impl Term { #[cfg(unix)] pub fn read_write_pair(read: R, write: W) -> Term where - R: Read + Debug + AsRawFd + Send + 'static, + R: Iterator + Debug + AsRawFd + Send + 'static, W: Write + Debug + AsRawFd + Send + 'static, { Self::read_write_pair_with_style(read, write, Style::new().for_stderr()) @@ -191,13 +296,13 @@ impl Term { #[cfg(unix)] pub fn read_write_pair_with_style(read: R, write: W, style: Style) -> Term where - R: Read + Debug + AsRawFd + Send + 'static, + R: Iterator + Debug + AsRawFd + Send + 'static, W: Write + Debug + AsRawFd + Send + 'static, { Term::with_inner(TermInner { target: TermTarget::ReadWritePair(ReadWritePair { - read: Arc::new(Mutex::new(read)), - write: Arc::new(Mutex::new(write)), + reader: Arc::new(Mutex::new(read)), + writer: Arc::new(Mutex::new(write)), style, }), buffer: None, @@ -210,7 +315,6 @@ impl Term { match self.inner.target { TermTarget::Stderr => Style::new().for_stderr(), TermTarget::Stdout => Style::new().for_stdout(), - #[cfg(unix)] TermTarget::ReadWritePair(ReadWritePair { ref style, .. }) => style.clone(), } } @@ -275,7 +379,7 @@ impl Term { if !self.is_tty { Ok(Key::Unknown) } else { - read_single_key() + self.inner.target.read_single_key() } } @@ -288,7 +392,7 @@ impl Term { return Ok("".into()); } let mut rv = String::new(); - io::stdin().read_line(&mut rv)?; + self.inner.target.read_line(&mut rv)?; let len = rv.trim_end_matches(&['\r', '\n'][..]).len(); rv.truncate(len); Ok(rv) @@ -340,7 +444,7 @@ impl Term { if !self.is_tty { return Ok("".into()); } - match read_secure() { + match self.inner.target.read_secure() { Ok(rv) => { self.write_line("")?; Ok(rv) @@ -525,8 +629,8 @@ impl Term { io::stderr().flush()?; } #[cfg(unix)] - TermTarget::ReadWritePair(ReadWritePair { ref write, .. }) => { - let mut write = write.lock().unwrap(); + TermTarget::ReadWritePair(ReadWritePair { ref writer, .. }) => { + let mut write = writer.lock().unwrap(); write.write_all(bytes)?; write.flush()?; } @@ -561,8 +665,8 @@ impl AsRawFd for Term { match self.inner.target { TermTarget::Stdout => libc::STDOUT_FILENO, TermTarget::Stderr => libc::STDERR_FILENO, - TermTarget::ReadWritePair(ReadWritePair { ref write, .. }) => { - write.lock().unwrap().as_raw_fd() + TermTarget::ReadWritePair(ReadWritePair { ref writer, .. }) => { + writer.lock().unwrap().as_raw_fd() } } } @@ -614,13 +718,13 @@ impl<'a> Write for &'a Term { impl Read for Term { fn read(&mut self, buf: &mut [u8]) -> io::Result { - io::stdin().read(buf) + self.inner.target.read(buf) } } impl<'a> Read for &'a Term { fn read(&mut self, buf: &mut [u8]) -> io::Result { - io::stdin().read(buf) + self.inner.target.read(buf) } }