diff --git a/client.go b/client.go index 349b5e9..2d42373 100644 --- a/client.go +++ b/client.go @@ -22,6 +22,7 @@ type Client struct { port int list bool send bool + burnNum int receiveNum int authToken string } @@ -50,7 +51,8 @@ func (c *Client) Connect() error { enc := gob.NewEncoder(&secureConnection) dec := gob.NewDecoder(&secureConnection) - if c.list { + switch { + case c.list: log.Debugf("requesting file list") err := c.connectToServer(secure.OperationTypeList, enc, dec) @@ -75,8 +77,7 @@ func (c *Client) Connect() error { fmt.Printf("total: %d files\n", numFiles) conn.Close() log.Debugf("done listing") - - } else if c.receiveNum >= 0 { + case c.receiveNum >= 0: log.Debugf("receiving file %d", c.receiveNum) err := c.connectToServer(secure.OperationTypeReceive, enc, dec) @@ -98,7 +99,8 @@ func (c *Client) Connect() error { panic(err) } - if res.Status == secure.ReceiveDataStartResponseOK { + switch res.Status { + case secure.ReceiveDataStartResponseOK: for { res := secure.PacketReceiveDataNext{} err = dec.Decode(&res) @@ -111,14 +113,14 @@ func (c *Client) Connect() error { } } log.Debugf("finished") - } else if res.Status == secure.ReceiveDataStartResponseNotFound { + case secure.ReceiveDataStartResponseNotFound: log.Error("ngf not found") - } else { + default: panic("unknown status") } conn.Close() - } else if c.send { + case c.send: // send mode err := c.connectToServer(secure.OperationTypeSend, enc, dec) @@ -169,16 +171,45 @@ func (c *Client) Connect() error { log.Debugf("Sent %s in %d chunks", humanize.Bytes(uint64(nBytes)), nChunks) conn.Close() + case c.burnNum >= 0: + log.Debugf("burning file %d", c.burnNum) + + err := c.connectToServer(secure.OperationTypeBurn, enc, dec) + if err != nil { + return fmt.Errorf("could not connect and auth: %v", err) + } + + req := secure.PacketBurnRequest{ + Id: uint32(c.burnNum), + } + err = enc.Encode(req) + if err != nil { + panic(err) + } + // expect a response telling us if we can go ahead + res := secure.PacketBurnResponse{} + err = dec.Decode(&res) + if err != nil { + panic(err) + } + + switch res.Status { + case secure.BurnResponseOK: + log.Debugf("finished") + case secure.BurnResponseNotFound: + log.Error("ngf not found") + default: + panic("unknown status") + } - } else { + conn.Close() + default: panic("no client mode set") } return nil - } func (c *Client) connectToServer(op secure.OperationTypeEnum, enc *gob.Encoder, dec *gob.Decoder) error { - // list mode startPacket := secure.PacketStartRequest{ OperationType: op, diff --git a/main.go b/main.go index e8d86fc..977b1a2 100644 --- a/main.go +++ b/main.go @@ -13,41 +13,39 @@ import ( "github.com/spf13/viper" ) -var CurrentVersion = "v0.0.4" +var CurrentVersion = "v0.0.5" -const ProtocolVersion = "1.1" +const ProtocolVersion = "1.2" -type PasteValue struct { - PasteRequired bool - PasteNumber uint +type ListValue struct { + Required bool + Number uint } -func (v *PasteValue) String() string { - if v.PasteRequired { - return fmt.Sprintf("YES: %d", v.PasteNumber) +func (v *ListValue) String() string { + if v.Required { + return fmt.Sprintf("YES: %d", v.Number) } return "0" } -func (v *PasteValue) Set(s string) error { - v.PasteRequired = true +func (v *ListValue) Set(s string) error { + v.Required = true num, err := strconv.ParseUint(s, 10, 64) if err != nil { return err } - v.PasteNumber = uint(num) + v.Number = uint(num) return nil } -func (v *PasteValue) Type() string { +func (v *ListValue) Type() string { return "int" - } func getAuthTokenFromTerminal() string { - tty, err := os.OpenFile("/dev/tty", os.O_RDWR, 0755) - + tty, err := os.OpenFile("/dev/tty", os.O_RDWR, 0o755) if err != nil { log.Printf("cannot open /dev/tty to read authtoken: %v", err) return "" @@ -59,7 +57,9 @@ func getAuthTokenFromTerminal() string { log.Printf("cannot set /dev/tty to raw mode: %v", err) return "" } - defer term.Restore(fd, oldState) + defer func() { + _ = term.Restore(fd, oldState) + }() t := term.NewTerminal(tty, "") pass, err := t.ReadPassword("Enter auth token: ") @@ -76,12 +76,16 @@ func main() { // client mode flags isList := flag.BoolP("list", "l", false, "Returns a list of current items on the server") - isSend := flag.BoolP("copy", "c", false, "sending stdin to netgiv server (copy)") + isSend := flag.BoolP("copy", "c", false, "send stdin to netgiv server (copy)") - pasteFlag := PasteValue{} - flag.VarP(&pasteFlag, "paste", "p", "receive from netgiv server to stdout (paste), with optional number (see --list)") + pasteFlag := ListValue{} + flag.VarP(&pasteFlag, "paste", "p", "receive from netgiv server to stdout (paste), with optional id (see --list)") flag.Lookup("paste").NoOptDefVal = "0" + burnFlag := ListValue{} + flag.VarP(&burnFlag, "burn", "b", "burn (remove/delete) the item on the netgiv server, with optional id (see --list)") + flag.Lookup("burn").NoOptDefVal = "0" + debug := flag.Bool("debug", false, "turn on debug logging") flag.String("address", "", "IP address/hostname of the netgiv server") @@ -93,12 +97,18 @@ func main() { flag.Parse() - receiveNum := int(pasteFlag.PasteNumber) - if !pasteFlag.PasteRequired { + receiveNum := int(pasteFlag.Number) + if !pasteFlag.Required { receiveNum = -1 } - viper.AddConfigPath("$HOME/.netgiv/") // call multiple times to add many search paths + burnNum := int(burnFlag.Number) + if !burnFlag.Required { + burnNum = -1 + } + + viper.AddConfigPath("$HOME/.netgiv/") + viper.AddConfigPath("$HOME/.config/netgiv/") // calling multiple times adds to search paths viper.SetConfigType("yaml") viper.SetDefault("port", 4512) @@ -112,11 +122,10 @@ func main() { } } - flag.Parse() - viper.BindPFlags(flag.CommandLine) + _ = viper.BindPFlags(flag.CommandLine) viper.SetEnvPrefix("NETGIV") - viper.BindEnv("authtoken") + _ = viper.BindEnv("authtoken") // pull the various things into local variables port := viper.GetInt("port") // retrieve value from viper @@ -170,11 +179,12 @@ environment variable. This may be preferable in some environments. log.Fatal("an address must be provided on the command line, or configuration") } + log.Debugf("protocol version: %s", ProtocolVersion) if *isServer { s := Server{port: port, authToken: authtoken} s.Run() } else { - if !*isList && !*isSend && receiveNum == -1 { + if !*isList && !*isSend && burnNum == -1 && receiveNum == -1 { // try to work out the intent based on whether or not stdin/stdout // are ttys stdinTTY := isatty.IsTerminal(os.Stdin.Fd()) @@ -193,7 +203,7 @@ environment variable. This may be preferable in some environments. } - c := Client{port: port, address: address, list: *isList, send: *isSend, receiveNum: receiveNum, authToken: authtoken} + c := Client{port: port, address: address, list: *isList, send: *isSend, burnNum: burnNum, receiveNum: receiveNum, authToken: authtoken} err := c.Connect() if err != nil { fmt.Print(err) diff --git a/secure/secure.go b/secure/secure.go index c7fbb15..a589910 100644 --- a/secure/secure.go +++ b/secure/secure.go @@ -129,7 +129,7 @@ func (s *SecureConnection) Write(p []byte) (int, error) { var nonce [24]byte // Create a new nonce for each message sent - rand.Read(nonce[:]) + _, _ = rand.Read(nonce[:]) encryptedMessage := box.SealAfterPrecomputation(nil, p, &nonce, s.SharedKey) sm := SecureMessage{Msg: encryptedMessage, Nonce: nonce} @@ -145,10 +145,10 @@ func Handshake(conn *net.TCPConn) *[32]byte { publicKey, privateKey, _ := box.GenerateKey(rand.Reader) - conn.Write(publicKey[:]) + _, _ = conn.Write(publicKey[:]) peerKeyArray := make([]byte, 32) - conn.Read(peerKeyArray) + _, _ = conn.Read(peerKeyArray) copy(peerKey[:], peerKeyArray) box.Precompute(&sharedKey, &peerKey, privateKey) @@ -162,10 +162,11 @@ const ( OperationTypeSend OperationTypeEnum = iota OperationTypeList OperationTypeReceive + OperationTypeBurn ) // PacketStartRequest is sent from the client to the server at the beginning -// to authenticate and annonce the requested particular operation +// to authenticate and announce the requested particular operation type PacketStartRequest struct { OperationType OperationTypeEnum ClientName string @@ -233,3 +234,20 @@ type PacketListData struct { Timestamp time.Time Kind string } + +type PacketBurnRequest struct { + Id uint32 +} + +type PacketBurnResponse struct { + Status PacketBurnResponseEnum +} + +type PacketBurnResponseEnum byte + +const ( + // File has been deleted + BurnResponseOK PacketBurnResponseEnum = iota + // No such file by index + BurnResponseNotFound +) diff --git a/secure/secure_test.go b/secure/secure_test.go index 0d8c620..825ea1c 100644 --- a/secure/secure_test.go +++ b/secure/secure_test.go @@ -13,7 +13,8 @@ func TestBasic(t *testing.T) { srcSecConn := SecureConnection{ Conn: srcConn, - SharedKey: &[32]byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + SharedKey: &[32]byte{ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, @@ -23,7 +24,8 @@ func TestBasic(t *testing.T) { dstSecConn := SecureConnection{ Conn: dstConn, - SharedKey: &[32]byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + SharedKey: &[32]byte{ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, @@ -45,7 +47,7 @@ func TestBasic(t *testing.T) { for _, b := range testData { go func() { - srcSecConn.Write(b) + _, _ = srcSecConn.Write(b) }() time.Sleep(time.Second) @@ -70,7 +72,8 @@ func TestPacketBasic(t *testing.T) { srcSecConn := SecureConnection{ Conn: srcConn, - SharedKey: &[32]byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + SharedKey: &[32]byte{ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, @@ -80,7 +83,8 @@ func TestPacketBasic(t *testing.T) { dstSecConn := SecureConnection{ Conn: dstConn, - SharedKey: &[32]byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + SharedKey: &[32]byte{ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, @@ -97,10 +101,12 @@ func TestPacketBasic(t *testing.T) { ProtocolVersion: "1.1", AuthToken: "abc123", } - go func() { enc.Encode(packet) }() + go func() { + _ = enc.Encode(packet) + }() recvPacket := PacketStartRequest{} - dec.Decode(&recvPacket) + _ = dec.Decode(&recvPacket) if recvPacket.OperationType != OperationTypeReceive { t.Error("bad OperationType") @@ -117,7 +123,6 @@ func TestPacketBasic(t *testing.T) { if recvPacket.ProtocolVersion != "1.1" { t.Error("bad ProtocolVersion") } - } func BenchmarkPPS(b *testing.B) { @@ -125,7 +130,8 @@ func BenchmarkPPS(b *testing.B) { srcSecConn := SecureConnection{ Conn: srcConn, - SharedKey: &[32]byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + SharedKey: &[32]byte{ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, @@ -135,7 +141,8 @@ func BenchmarkPPS(b *testing.B) { dstSecConn := SecureConnection{ Conn: dstConn, - SharedKey: &[32]byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + SharedKey: &[32]byte{ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, @@ -153,12 +160,11 @@ func BenchmarkPPS(b *testing.B) { for i := 0; i < b.N; i++ { go func() { - srcSecConn.Write(testdata) + _, _ = srcSecConn.Write(testdata) }() out := make([]byte, 16384) n, err := dstSecConn.Read(out) - if err != nil { b.Errorf("got error %v", err) } @@ -169,5 +175,4 @@ func BenchmarkPPS(b *testing.B) { b.Errorf("%v not equal to %v", out[:n], testdata) } } - } diff --git a/server.go b/server.go index 055c9b5..ef238bf 100644 --- a/server.go +++ b/server.go @@ -34,8 +34,10 @@ type NGF struct { Timestamp time.Time } -var ngfs []NGF -var globalId uint32 +var ( + ngfs []NGF + globalId uint32 +) func (s *Server) Run() { log.Infof("starting server on :%d", s.port) @@ -58,7 +60,7 @@ func (s *Server) Run() { log.Printf("removing file: %s", ngf.StorePath) err := os.Remove(ngf.StorePath) if err != nil { - log.Printf("could not remove %s: %v", ngf.StorePath, err) + log.Errorf("could not remove %s: %v", ngf.StorePath, err) } } os.Exit(0) @@ -68,7 +70,6 @@ func (s *Server) Run() { for { conn, err := listener.AcceptTCP() - if err != nil { fmt.Print(err) } @@ -80,7 +81,7 @@ func (s *Server) Run() { func (s *Server) handleConnection(conn *net.TCPConn) { defer conn.Close() - conn.SetDeadline(time.Now().Add(time.Second * 5)) + _ = conn.SetDeadline(time.Now().Add(time.Second * 5)) sharedKey := secure.Handshake(conn) secureConnection := secure.SecureConnection{Conn: conn, SharedKey: sharedKey, Buffer: &bytes.Buffer{}} @@ -111,24 +112,25 @@ func (s *Server) handleConnection(conn *net.TCPConn) { if start.ProtocolVersion != ProtocolVersion { log.Errorf("bad protocol version") startResponse.Response = secure.PacketStartResponseEnumWrongProtocol - enc.Encode(startResponse) + _ = enc.Encode(startResponse) return } if start.AuthToken != s.authToken { log.Errorf("bad authtoken") startResponse.Response = secure.PacketStartResponseEnumBadAuthToken - enc.Encode(startResponse) + _ = enc.Encode(startResponse) return } // otherwise we are good to continue, tell the client that startResponse.Response = secure.PacketStartResponseEnumOK - enc.Encode(startResponse) + _ = enc.Encode(startResponse) - conn.SetDeadline(time.Now().Add(time.Second * 5)) + _ = conn.SetDeadline(time.Now().Add(time.Second * 5)) - if start.OperationType == secure.OperationTypeSend { + switch start.OperationType { + case secure.OperationTypeSend: log.Debugf("file incoming") sendStart := secure.PacketSendDataStart{} @@ -160,7 +162,7 @@ func (s *Server) handleConnection(conn *net.TCPConn) { sendData := secure.PacketSendDataNext{} determinedKind := false for { - conn.SetDeadline(time.Now().Add(time.Second * 5)) + _ = conn.SetDeadline(time.Now().Add(time.Second * 5)) err = dec.Decode(&sendData) if err == io.EOF { break @@ -190,7 +192,7 @@ func (s *Server) handleConnection(conn *net.TCPConn) { determinedKind = true } - file.Write(sendData.Data) + _, _ = file.Write(sendData.Data) } info, err := file.Stat() if err != nil { @@ -204,13 +206,13 @@ func (s *Server) handleConnection(conn *net.TCPConn) { log.Printf("done receiving file: %v", ngf) return - } else if start.OperationType == secure.OperationTypeReceive { + case secure.OperationTypeReceive: log.Printf("client requesting file receive") // wait for them to send the request req := secure.PacketReceiveDataStartRequest{} err := dec.Decode(&req) if err != nil { - log.Printf("error expecting PacketReceiveDataStartRequest: %v", err) + log.Errorf("error expecting PacketReceiveDataStartRequest: %v", err) return } @@ -242,7 +244,7 @@ func (s *Server) handleConnection(conn *net.TCPConn) { } err = enc.Encode(res) if err != nil { - log.Printf("could not send NotFound: %v", err) + log.Errorf("could not send NotFound: %v", err) } return @@ -297,8 +299,7 @@ func (s *Server) handleConnection(conn *net.TCPConn) { } log.Printf("sending done") return - - } else if start.OperationType == secure.OperationTypeList { + case secure.OperationTypeList: log.Debugf("client requesting file list") for _, ngf := range ngfs { @@ -308,15 +309,83 @@ func (s *Server) handleConnection(conn *net.TCPConn) { p.Id = ngf.Id p.Filename = ngf.Filename p.Timestamp = ngf.Timestamp - enc.Encode(p) + _ = enc.Encode(p) } log.Debugf("done sending list, closing connection") return + case secure.OperationTypeBurn: + log.Debugf("client requesting burn") + // wait for them to send the request + req := secure.PacketBurnRequest{} + err := dec.Decode(&req) + if err != nil { + log.Errorf("error expecting PacketBurnRequest: %v", err) + return + } + + log.Debugf("The client asked for %v to be burned", req) + + // do we have this ngf by id? + var requestedNGF NGF + + if len(ngfs) > 0 { + if req.Id == 0 { + // they want the most recent one + requestedNGF = ngfs[len(ngfs)-1] + } else { + for _, ngf := range ngfs { + if ngf.Id == req.Id { + requestedNGF = ngf + } + } + } + } + + log.Debugf("going to burn %v", requestedNGF) - } else { + if requestedNGF.Id == 0 { + // not found + log.Errorf("user requested burning %d, not found", req.Id) + res := secure.PacketBurnResponse{ + Status: secure.BurnResponseNotFound, + } + err = enc.Encode(res) + if err != nil { + log.Errorf("could not send NotFound: %v", err) + } + + return + } + + // remove the file + err = os.Remove(requestedNGF.StorePath) + if err != nil { + log.Errorf("could not remove file %s: %v", requestedNGF.StorePath, err) + return + } + + // remove the ngf from the list + for i, ngf := range ngfs { + if ngf.Id == requestedNGF.Id { + ngfs = append(ngfs[:i], ngfs[i+1:]...) + break + } + } + + res := secure.PacketBurnResponse{ + Status: secure.BurnResponseOK, + } + err = enc.Encode(res) + if err != nil { + log.Errorf("error sending PacketBurnResponse: %v", err) + return + } + + log.Printf("burn complete") + return + default: log.Errorf("bad operation") return } - }