From a50836fb74355898daee35fb5e9602453b5e3726 Mon Sep 17 00:00:00 2001 From: Jason Fowler Date: Sun, 29 Sep 2024 10:23:38 +0800 Subject: [PATCH] add burn operation to the client, server, and protocol this provides a method for removing files from the server remotely without needing to restart the server example use case for this is if your server is publicly accessible but you don't expose SSH publicly and you're transferring data between two cloud servers and don't want the data to be stored on the server any longer than it has to be --- client.go | 51 ++++++++++++++++---- main.go | 64 ++++++++++++++----------- secure/secure.go | 26 ++++++++-- secure/secure_test.go | 31 +++++++----- server.go | 109 ++++++++++++++++++++++++++++++++++-------- 5 files changed, 207 insertions(+), 74 deletions(-) diff --git a/client.go b/client.go index 349b5e9..2d42373 100644 --- a/client.go +++ b/client.go @@ -22,6 +22,7 @@ type Client struct { port int list bool send bool + burnNum int receiveNum int authToken string } @@ -50,7 +51,8 @@ func (c *Client) Connect() error { enc := gob.NewEncoder(&secureConnection) dec := gob.NewDecoder(&secureConnection) - if c.list { + switch { + case c.list: log.Debugf("requesting file list") err := c.connectToServer(secure.OperationTypeList, enc, dec) @@ -75,8 +77,7 @@ func (c *Client) Connect() error { fmt.Printf("total: %d files\n", numFiles) conn.Close() log.Debugf("done listing") - - } else if c.receiveNum >= 0 { + case c.receiveNum >= 0: log.Debugf("receiving file %d", c.receiveNum) err := c.connectToServer(secure.OperationTypeReceive, enc, dec) @@ -98,7 +99,8 @@ func (c *Client) Connect() error { panic(err) } - if res.Status == secure.ReceiveDataStartResponseOK { + switch res.Status { + case secure.ReceiveDataStartResponseOK: for { res := secure.PacketReceiveDataNext{} err = dec.Decode(&res) @@ -111,14 +113,14 @@ func (c *Client) Connect() error { } } log.Debugf("finished") - } else if res.Status == secure.ReceiveDataStartResponseNotFound { + case secure.ReceiveDataStartResponseNotFound: log.Error("ngf not found") - } else { + default: panic("unknown status") } conn.Close() - } else if c.send { + case c.send: // send mode err := c.connectToServer(secure.OperationTypeSend, enc, dec) @@ -169,16 +171,45 @@ func (c *Client) Connect() error { log.Debugf("Sent %s in %d chunks", humanize.Bytes(uint64(nBytes)), nChunks) conn.Close() + case c.burnNum >= 0: + log.Debugf("burning file %d", c.burnNum) + + err := c.connectToServer(secure.OperationTypeBurn, enc, dec) + if err != nil { + return fmt.Errorf("could not connect and auth: %v", err) + } + + req := secure.PacketBurnRequest{ + Id: uint32(c.burnNum), + } + err = enc.Encode(req) + if err != nil { + panic(err) + } + // expect a response telling us if we can go ahead + res := secure.PacketBurnResponse{} + err = dec.Decode(&res) + if err != nil { + panic(err) + } + + switch res.Status { + case secure.BurnResponseOK: + log.Debugf("finished") + case secure.BurnResponseNotFound: + log.Error("ngf not found") + default: + panic("unknown status") + } - } else { + conn.Close() + default: panic("no client mode set") } return nil - } func (c *Client) connectToServer(op secure.OperationTypeEnum, enc *gob.Encoder, dec *gob.Decoder) error { - // list mode startPacket := secure.PacketStartRequest{ OperationType: op, diff --git a/main.go b/main.go index e8d86fc..977b1a2 100644 --- a/main.go +++ b/main.go @@ -13,41 +13,39 @@ import ( "github.com/spf13/viper" ) -var CurrentVersion = "v0.0.4" +var CurrentVersion = "v0.0.5" -const ProtocolVersion = "1.1" +const ProtocolVersion = "1.2" -type PasteValue struct { - PasteRequired bool - PasteNumber uint +type ListValue struct { + Required bool + Number uint } -func (v *PasteValue) String() string { - if v.PasteRequired { - return fmt.Sprintf("YES: %d", v.PasteNumber) +func (v *ListValue) String() string { + if v.Required { + return fmt.Sprintf("YES: %d", v.Number) } return "0" } -func (v *PasteValue) Set(s string) error { - v.PasteRequired = true +func (v *ListValue) Set(s string) error { + v.Required = true num, err := strconv.ParseUint(s, 10, 64) if err != nil { return err } - v.PasteNumber = uint(num) + v.Number = uint(num) return nil } -func (v *PasteValue) Type() string { +func (v *ListValue) Type() string { return "int" - } func getAuthTokenFromTerminal() string { - tty, err := os.OpenFile("/dev/tty", os.O_RDWR, 0755) - + tty, err := os.OpenFile("/dev/tty", os.O_RDWR, 0o755) if err != nil { log.Printf("cannot open /dev/tty to read authtoken: %v", err) return "" @@ -59,7 +57,9 @@ func getAuthTokenFromTerminal() string { log.Printf("cannot set /dev/tty to raw mode: %v", err) return "" } - defer term.Restore(fd, oldState) + defer func() { + _ = term.Restore(fd, oldState) + }() t := term.NewTerminal(tty, "") pass, err := t.ReadPassword("Enter auth token: ") @@ -76,12 +76,16 @@ func main() { // client mode flags isList := flag.BoolP("list", "l", false, "Returns a list of current items on the server") - isSend := flag.BoolP("copy", "c", false, "sending stdin to netgiv server (copy)") + isSend := flag.BoolP("copy", "c", false, "send stdin to netgiv server (copy)") - pasteFlag := PasteValue{} - flag.VarP(&pasteFlag, "paste", "p", "receive from netgiv server to stdout (paste), with optional number (see --list)") + pasteFlag := ListValue{} + flag.VarP(&pasteFlag, "paste", "p", "receive from netgiv server to stdout (paste), with optional id (see --list)") flag.Lookup("paste").NoOptDefVal = "0" + burnFlag := ListValue{} + flag.VarP(&burnFlag, "burn", "b", "burn (remove/delete) the item on the netgiv server, with optional id (see --list)") + flag.Lookup("burn").NoOptDefVal = "0" + debug := flag.Bool("debug", false, "turn on debug logging") flag.String("address", "", "IP address/hostname of the netgiv server") @@ -93,12 +97,18 @@ func main() { flag.Parse() - receiveNum := int(pasteFlag.PasteNumber) - if !pasteFlag.PasteRequired { + receiveNum := int(pasteFlag.Number) + if !pasteFlag.Required { receiveNum = -1 } - viper.AddConfigPath("$HOME/.netgiv/") // call multiple times to add many search paths + burnNum := int(burnFlag.Number) + if !burnFlag.Required { + burnNum = -1 + } + + viper.AddConfigPath("$HOME/.netgiv/") + viper.AddConfigPath("$HOME/.config/netgiv/") // calling multiple times adds to search paths viper.SetConfigType("yaml") viper.SetDefault("port", 4512) @@ -112,11 +122,10 @@ func main() { } } - flag.Parse() - viper.BindPFlags(flag.CommandLine) + _ = viper.BindPFlags(flag.CommandLine) viper.SetEnvPrefix("NETGIV") - viper.BindEnv("authtoken") + _ = viper.BindEnv("authtoken") // pull the various things into local variables port := viper.GetInt("port") // retrieve value from viper @@ -170,11 +179,12 @@ environment variable. This may be preferable in some environments. log.Fatal("an address must be provided on the command line, or configuration") } + log.Debugf("protocol version: %s", ProtocolVersion) if *isServer { s := Server{port: port, authToken: authtoken} s.Run() } else { - if !*isList && !*isSend && receiveNum == -1 { + if !*isList && !*isSend && burnNum == -1 && receiveNum == -1 { // try to work out the intent based on whether or not stdin/stdout // are ttys stdinTTY := isatty.IsTerminal(os.Stdin.Fd()) @@ -193,7 +203,7 @@ environment variable. This may be preferable in some environments. } - c := Client{port: port, address: address, list: *isList, send: *isSend, receiveNum: receiveNum, authToken: authtoken} + c := Client{port: port, address: address, list: *isList, send: *isSend, burnNum: burnNum, receiveNum: receiveNum, authToken: authtoken} err := c.Connect() if err != nil { fmt.Print(err) diff --git a/secure/secure.go b/secure/secure.go index c7fbb15..a589910 100644 --- a/secure/secure.go +++ b/secure/secure.go @@ -129,7 +129,7 @@ func (s *SecureConnection) Write(p []byte) (int, error) { var nonce [24]byte // Create a new nonce for each message sent - rand.Read(nonce[:]) + _, _ = rand.Read(nonce[:]) encryptedMessage := box.SealAfterPrecomputation(nil, p, &nonce, s.SharedKey) sm := SecureMessage{Msg: encryptedMessage, Nonce: nonce} @@ -145,10 +145,10 @@ func Handshake(conn *net.TCPConn) *[32]byte { publicKey, privateKey, _ := box.GenerateKey(rand.Reader) - conn.Write(publicKey[:]) + _, _ = conn.Write(publicKey[:]) peerKeyArray := make([]byte, 32) - conn.Read(peerKeyArray) + _, _ = conn.Read(peerKeyArray) copy(peerKey[:], peerKeyArray) box.Precompute(&sharedKey, &peerKey, privateKey) @@ -162,10 +162,11 @@ const ( OperationTypeSend OperationTypeEnum = iota OperationTypeList OperationTypeReceive + OperationTypeBurn ) // PacketStartRequest is sent from the client to the server at the beginning -// to authenticate and annonce the requested particular operation +// to authenticate and announce the requested particular operation type PacketStartRequest struct { OperationType OperationTypeEnum ClientName string @@ -233,3 +234,20 @@ type PacketListData struct { Timestamp time.Time Kind string } + +type PacketBurnRequest struct { + Id uint32 +} + +type PacketBurnResponse struct { + Status PacketBurnResponseEnum +} + +type PacketBurnResponseEnum byte + +const ( + // File has been deleted + BurnResponseOK PacketBurnResponseEnum = iota + // No such file by index + BurnResponseNotFound +) diff --git a/secure/secure_test.go b/secure/secure_test.go index 0d8c620..825ea1c 100644 --- a/secure/secure_test.go +++ b/secure/secure_test.go @@ -13,7 +13,8 @@ func TestBasic(t *testing.T) { srcSecConn := SecureConnection{ Conn: srcConn, - SharedKey: &[32]byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + SharedKey: &[32]byte{ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, @@ -23,7 +24,8 @@ func TestBasic(t *testing.T) { dstSecConn := SecureConnection{ Conn: dstConn, - SharedKey: &[32]byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + SharedKey: &[32]byte{ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, @@ -45,7 +47,7 @@ func TestBasic(t *testing.T) { for _, b := range testData { go func() { - srcSecConn.Write(b) + _, _ = srcSecConn.Write(b) }() time.Sleep(time.Second) @@ -70,7 +72,8 @@ func TestPacketBasic(t *testing.T) { srcSecConn := SecureConnection{ Conn: srcConn, - SharedKey: &[32]byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + SharedKey: &[32]byte{ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, @@ -80,7 +83,8 @@ func TestPacketBasic(t *testing.T) { dstSecConn := SecureConnection{ Conn: dstConn, - SharedKey: &[32]byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + SharedKey: &[32]byte{ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, @@ -97,10 +101,12 @@ func TestPacketBasic(t *testing.T) { ProtocolVersion: "1.1", AuthToken: "abc123", } - go func() { enc.Encode(packet) }() + go func() { + _ = enc.Encode(packet) + }() recvPacket := PacketStartRequest{} - dec.Decode(&recvPacket) + _ = dec.Decode(&recvPacket) if recvPacket.OperationType != OperationTypeReceive { t.Error("bad OperationType") @@ -117,7 +123,6 @@ func TestPacketBasic(t *testing.T) { if recvPacket.ProtocolVersion != "1.1" { t.Error("bad ProtocolVersion") } - } func BenchmarkPPS(b *testing.B) { @@ -125,7 +130,8 @@ func BenchmarkPPS(b *testing.B) { srcSecConn := SecureConnection{ Conn: srcConn, - SharedKey: &[32]byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + SharedKey: &[32]byte{ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, @@ -135,7 +141,8 @@ func BenchmarkPPS(b *testing.B) { dstSecConn := SecureConnection{ Conn: dstConn, - SharedKey: &[32]byte{0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, + SharedKey: &[32]byte{ + 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x0, 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, @@ -153,12 +160,11 @@ func BenchmarkPPS(b *testing.B) { for i := 0; i < b.N; i++ { go func() { - srcSecConn.Write(testdata) + _, _ = srcSecConn.Write(testdata) }() out := make([]byte, 16384) n, err := dstSecConn.Read(out) - if err != nil { b.Errorf("got error %v", err) } @@ -169,5 +175,4 @@ func BenchmarkPPS(b *testing.B) { b.Errorf("%v not equal to %v", out[:n], testdata) } } - } diff --git a/server.go b/server.go index 055c9b5..ef238bf 100644 --- a/server.go +++ b/server.go @@ -34,8 +34,10 @@ type NGF struct { Timestamp time.Time } -var ngfs []NGF -var globalId uint32 +var ( + ngfs []NGF + globalId uint32 +) func (s *Server) Run() { log.Infof("starting server on :%d", s.port) @@ -58,7 +60,7 @@ func (s *Server) Run() { log.Printf("removing file: %s", ngf.StorePath) err := os.Remove(ngf.StorePath) if err != nil { - log.Printf("could not remove %s: %v", ngf.StorePath, err) + log.Errorf("could not remove %s: %v", ngf.StorePath, err) } } os.Exit(0) @@ -68,7 +70,6 @@ func (s *Server) Run() { for { conn, err := listener.AcceptTCP() - if err != nil { fmt.Print(err) } @@ -80,7 +81,7 @@ func (s *Server) Run() { func (s *Server) handleConnection(conn *net.TCPConn) { defer conn.Close() - conn.SetDeadline(time.Now().Add(time.Second * 5)) + _ = conn.SetDeadline(time.Now().Add(time.Second * 5)) sharedKey := secure.Handshake(conn) secureConnection := secure.SecureConnection{Conn: conn, SharedKey: sharedKey, Buffer: &bytes.Buffer{}} @@ -111,24 +112,25 @@ func (s *Server) handleConnection(conn *net.TCPConn) { if start.ProtocolVersion != ProtocolVersion { log.Errorf("bad protocol version") startResponse.Response = secure.PacketStartResponseEnumWrongProtocol - enc.Encode(startResponse) + _ = enc.Encode(startResponse) return } if start.AuthToken != s.authToken { log.Errorf("bad authtoken") startResponse.Response = secure.PacketStartResponseEnumBadAuthToken - enc.Encode(startResponse) + _ = enc.Encode(startResponse) return } // otherwise we are good to continue, tell the client that startResponse.Response = secure.PacketStartResponseEnumOK - enc.Encode(startResponse) + _ = enc.Encode(startResponse) - conn.SetDeadline(time.Now().Add(time.Second * 5)) + _ = conn.SetDeadline(time.Now().Add(time.Second * 5)) - if start.OperationType == secure.OperationTypeSend { + switch start.OperationType { + case secure.OperationTypeSend: log.Debugf("file incoming") sendStart := secure.PacketSendDataStart{} @@ -160,7 +162,7 @@ func (s *Server) handleConnection(conn *net.TCPConn) { sendData := secure.PacketSendDataNext{} determinedKind := false for { - conn.SetDeadline(time.Now().Add(time.Second * 5)) + _ = conn.SetDeadline(time.Now().Add(time.Second * 5)) err = dec.Decode(&sendData) if err == io.EOF { break @@ -190,7 +192,7 @@ func (s *Server) handleConnection(conn *net.TCPConn) { determinedKind = true } - file.Write(sendData.Data) + _, _ = file.Write(sendData.Data) } info, err := file.Stat() if err != nil { @@ -204,13 +206,13 @@ func (s *Server) handleConnection(conn *net.TCPConn) { log.Printf("done receiving file: %v", ngf) return - } else if start.OperationType == secure.OperationTypeReceive { + case secure.OperationTypeReceive: log.Printf("client requesting file receive") // wait for them to send the request req := secure.PacketReceiveDataStartRequest{} err := dec.Decode(&req) if err != nil { - log.Printf("error expecting PacketReceiveDataStartRequest: %v", err) + log.Errorf("error expecting PacketReceiveDataStartRequest: %v", err) return } @@ -242,7 +244,7 @@ func (s *Server) handleConnection(conn *net.TCPConn) { } err = enc.Encode(res) if err != nil { - log.Printf("could not send NotFound: %v", err) + log.Errorf("could not send NotFound: %v", err) } return @@ -297,8 +299,7 @@ func (s *Server) handleConnection(conn *net.TCPConn) { } log.Printf("sending done") return - - } else if start.OperationType == secure.OperationTypeList { + case secure.OperationTypeList: log.Debugf("client requesting file list") for _, ngf := range ngfs { @@ -308,15 +309,83 @@ func (s *Server) handleConnection(conn *net.TCPConn) { p.Id = ngf.Id p.Filename = ngf.Filename p.Timestamp = ngf.Timestamp - enc.Encode(p) + _ = enc.Encode(p) } log.Debugf("done sending list, closing connection") return + case secure.OperationTypeBurn: + log.Debugf("client requesting burn") + // wait for them to send the request + req := secure.PacketBurnRequest{} + err := dec.Decode(&req) + if err != nil { + log.Errorf("error expecting PacketBurnRequest: %v", err) + return + } + + log.Debugf("The client asked for %v to be burned", req) + + // do we have this ngf by id? + var requestedNGF NGF + + if len(ngfs) > 0 { + if req.Id == 0 { + // they want the most recent one + requestedNGF = ngfs[len(ngfs)-1] + } else { + for _, ngf := range ngfs { + if ngf.Id == req.Id { + requestedNGF = ngf + } + } + } + } + + log.Debugf("going to burn %v", requestedNGF) - } else { + if requestedNGF.Id == 0 { + // not found + log.Errorf("user requested burning %d, not found", req.Id) + res := secure.PacketBurnResponse{ + Status: secure.BurnResponseNotFound, + } + err = enc.Encode(res) + if err != nil { + log.Errorf("could not send NotFound: %v", err) + } + + return + } + + // remove the file + err = os.Remove(requestedNGF.StorePath) + if err != nil { + log.Errorf("could not remove file %s: %v", requestedNGF.StorePath, err) + return + } + + // remove the ngf from the list + for i, ngf := range ngfs { + if ngf.Id == requestedNGF.Id { + ngfs = append(ngfs[:i], ngfs[i+1:]...) + break + } + } + + res := secure.PacketBurnResponse{ + Status: secure.BurnResponseOK, + } + err = enc.Encode(res) + if err != nil { + log.Errorf("error sending PacketBurnResponse: %v", err) + return + } + + log.Printf("burn complete") + return + default: log.Errorf("bad operation") return } - }