diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 0000000..3eaaada --- /dev/null +++ b/AUTHORS @@ -0,0 +1,7 @@ +# This is the official list of zephyr-go authors for copyright purposes. + +# Names should be added to this file as: +# Name or Organization +# The email address is not required for organizations. + +David Benjamin diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..7a4a3ea --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/clock.go b/clock.go new file mode 100644 index 0000000..b26281e --- /dev/null +++ b/clock.go @@ -0,0 +1,38 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "time" +) + +// Clock is a mockable interface for components which involve time. +type Clock interface { + Now() time.Time + After(d time.Duration) <-chan time.Time +} + +type systemClock struct{} + +func (systemClock) Now() time.Time { + return time.Now() +} + +func (systemClock) After(d time.Duration) <-chan time.Time { + return time.After(d) +} + +// SystemClock is the real implementation of the Clock interface. +var SystemClock systemClock diff --git a/cmd/subscriber/subscriber.go b/cmd/subscriber/subscriber.go new file mode 100644 index 0000000..93d1d10 --- /dev/null +++ b/cmd/subscriber/subscriber.go @@ -0,0 +1,67 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "flag" + "log" + "os" + "os/signal" + "syscall" + + "github.com/zephyr-im/krb5-go" + "github.com/zephyr-im/zephyr-go" +) + +func main() { + flag.Parse() + if flag.NArg() != 2 { + log.Fatal("Need 2 arguments") + } + subs := []zephyr.Subscription{ + {"", flag.Arg(0), flag.Arg(1)}, + } + + // Open a session. + session, err := zephyr.DialSystemDefault() + if err != nil { + log.Fatal(err) + } + defer session.Close() + go func() { + for r := range session.Messages() { + log.Printf("Received message %v %v", r.AuthStatus, r.Message) + } + }() + + log.Printf("Subscribing to %v", subs) + ctx, err := krb5.NewContext() + if err != nil { + log.Fatal(err) + } + defer ctx.Free() + ack, err := session.SendSubscribeNoDefaults(ctx, subs) + log.Printf(" -> %v %v", ack, err) + defer func() { + log.Printf("Canceling subscriptions") + ack, err := session.SendCancelSubscriptions(ctx) + log.Printf(" -> %v %v", ack, err) + }() + + // Keep listening until a SIGINT or SIGTERM. + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + <-c +} diff --git a/cmd/zwrite/zwrite.go b/cmd/zwrite/zwrite.go new file mode 100644 index 0000000..ab4cfd6 --- /dev/null +++ b/cmd/zwrite/zwrite.go @@ -0,0 +1,252 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "bufio" + "fmt" + "io/ioutil" + "log" + "os" + "strings" + "time" + + "github.com/zephyr-im/krb5-go" + "github.com/zephyr-im/zephyr-go" +) + +var class = "message" +var instance = "personal" +var opcode = "" +var signature = "" +var message = "" +var sender = "" +var haveSender = false +var auth = true +var expandTabs = true +var eofOnly = false +var realm = "" +var haveRealm = false +var recipients = []string{} + +func printUsage() { + fmt.Fprintln(os.Stderr, "Usage: zwrite [-a] [-d] [-t] [-l] [-u]") + fmt.Fprintln(os.Stderr, "\t[-c class] [-i instance] [-O opcode] [-s signature] [-S sender]") + fmt.Fprintln(os.Stderr, "\t[user...] [-r realm] [-m message]") +} + +func parseFlagArg(flag, value string) { + switch flag { + case "-s": + signature = value + case "-c": + class = value + case "-i": + instance = value + case "-r": + realm = value + haveRealm = true + case "-S": + sender = value + haveSender = true + case "-O": + opcode = value + default: + panic(flag) + } +} + +func readMessage(eofOnly bool) (string, error) { + if eofOnly { + fmt.Fprintln(os.Stderr, "Type your message now. "+ + "End with the end-of-file character.") + message, err := ioutil.ReadAll(os.Stdin) + if err != nil { + return "", err + } + return string(message), nil + } + + fmt.Fprintln(os.Stderr, "Type your message now. "+ + "End with control-D or a dot on a line by itself.") + scanner := bufio.NewScanner(os.Stdin) + message := "" + for scanner.Scan() { + line := scanner.Text() + if line == "." { + break + } + message = message + line + "\n" + } + if err := scanner.Err(); err != nil { + return "", err + } + return message, nil +} + +func parseFlags() { + haveMessage := false + // TODO(davidben): To really be true to zwrite, check isatty + // and, if not, set eofOnly to true. + var i int +argLoop: + for i = 1; i < len(os.Args); i++ { + switch arg := os.Args[i]; arg { + case "-a": + auth = true + case "-d": + auth = false + case "-t": + expandTabs = false + case "-l": + eofOnly = true + case "-u": + instance = "URGENT" + case "-m": + haveMessage = true + message = strings.Join(os.Args[i+1:], " ") + break argLoop + case "-s", "-c", "-i", "-r", "-S", "-O": + if i+1 >= len(os.Args) { + printUsage() + os.Exit(1) + } + i++ + parseFlagArg(arg, os.Args[i]) + default: + if len(arg) >= 1 && arg[0] == '-' { + printUsage() + os.Exit(1) + } + recipients = append(recipients, arg) + } + } + + // Normalize receipients. + if len(recipients) == 0 { + if class == "message" && + (instance == "personal" || instance == "URGENT") { + fmt.Fprintln(os.Stderr, "No recipients specified.") + printUsage() + os.Exit(1) + } + recipients = []string{""} + } + if haveRealm { + for i := range recipients { + recipients[i] = recipients[i] + "@" + realm + } + } + + if !haveMessage { + // Read message from stdin. + var err error + message, err = readMessage(eofOnly) + if err != nil { + fmt.Fprintf(os.Stderr, "Error reading stdin: %s\n", err) + os.Exit(1) + } + } + + if expandTabs { + newMsg := []byte{} + spaces := [8]byte{ + ' ', ' ', ' ', ' ', + ' ', ' ', ' ', ' ', + } + off := 0 + for _, b := range []byte(message) { + if b == '\t' { + newMsg = append(newMsg, spaces[:8-off]...) + off = 0 + } else { + newMsg = append(newMsg, b) + if b == '\n' { + off = 0 + } else { + off = (off + 1) % 8 + } + } + } + message = string(newMsg) + } +} + +func main() { + parseFlags() + + // Open a session. + session, err := zephyr.DialSystemDefault() + if err != nil { + log.Fatal(err) + } + defer session.Close() + // Make sure the notice sink doesn't get stuck. + // TODO(davidben): This is silly. + go func() { + for _ = range session.Messages() { + } + }() + + // Further normalize receipients. + if !haveSender { + sender = session.Sender() + } + for i := range recipients { + if len(recipients[i]) != 0 && strings.Index(recipients[i], "@") < 0 { + recipients[i] = recipients[i] + "@" + session.Realm() + } + } + + // Get tickets. + ctx, err := krb5.NewContext() + if err != nil { + log.Fatal(err) + } + defer ctx.Free() + for _, recipient := range recipients { + // Construct the message. + uid := session.MakeUID(time.Now()) + msg := &zephyr.Message{ + Header: zephyr.Header{ + Kind: zephyr.ACKED, + UID: uid, + Port: session.Port(), + Class: class, Instance: instance, OpCode: opcode, + Sender: sender, + Recipient: recipient, + DefaultFormat: "http://mit.edu/df/", + SenderAddress: session.LocalAddr().IP, + Charset: zephyr.CharsetUTF8, + OtherFields: nil, + }, + Body: []string{signature, message}, + } + sendTime := time.Now() + var ack *zephyr.Notice + var err error + if auth { + ack, err = session.SendMessage(ctx, msg) + } else { + ack, err = session.SendMessageUnauth(msg) + } + if err != nil { + log.Printf("Send error: %v", err) + } else { + log.Printf("Received ack in %v: %v", + time.Now().Sub(sendTime), ack) + } + } +} diff --git a/connection.go b/connection.go new file mode 100644 index 0000000..2f35c6a --- /dev/null +++ b/connection.go @@ -0,0 +1,389 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "errors" + "log" + "net" + "sync" + "time" + + "github.com/zephyr-im/krb5-go" +) + +func localIPForUDPAddr(addr *net.UDPAddr) (net.IP, error) { + bogus, err := net.DialUDP("udp", nil, addr) + if err != nil { + return nil, err + } + defer bogus.Close() + return bogus.LocalAddr().(*net.UDPAddr).IP, nil +} + +func udpAddrsEqual(a, b *net.UDPAddr) bool { + return a.IP.Equal(b.IP) && a.Port == b.Port && a.Zone == b.Zone +} + +// How frequently we query for new servers. +const serverRefreshInterval = 10 * time.Minute + +// A Connection represents a low-level connection to the Zephyr +// servers. It handles server discovery and sending and receiving +// Notices. It does not provide high-level constructs like subscribing +// or message sharding. It also does not automatically send +// CLIENTACKs. +type Connection struct { + // Properties of the connection. + conn net.PacketConn + server ServerConfig + cred *krb5.Credential + clock Clock + localIP net.IP + + // Incoming notices from the connection. + allNotices <-chan NoticeReaderResult + + // Where non-ACK notices get dumped. + notices chan NoticeReaderResult + + // Table of pending ACKs. + ackTable map[UID]chan NoticeReaderResult + ackTableLock sync.Mutex + + // Current server send schedule. + sched []*net.UDPAddr + schedIdx int + schedLock sync.Mutex + + stopRefreshing chan int +} + +// NewConnection creates a new Connection wrapping a given +// net.PacketConn. The ServerConfig argument instructs the connection +// on how to locate the remote servers. The Credential is used to +// authenticate incoming and outgoing packets. The connection takes +// ownership of the PacketConn and will close it when Close is +// called. +func NewConnection( + conn net.PacketConn, + server ServerConfig, + cred *krb5.Credential, + logger *log.Logger, +) (*Connection, error) { + return NewConnectionFull(conn, server, cred, logger, SystemClock) +} + +// NewConnectionFull does the same as NewConnection but takes an +// additional Clock argument for testing. +func NewConnectionFull( + conn net.PacketConn, + server ServerConfig, + cred *krb5.Credential, + logger *log.Logger, + clock Clock, +) (*Connection, error) { + c := new(Connection) + c.conn = conn + c.server = server + c.cred = cred + c.clock = clock + var key *krb5.KeyBlock + if c.cred != nil { + key = c.cred.KeyBlock + } + c.allNotices = ReadNoticesFromServer(conn, key, logger) + c.notices = make(chan NoticeReaderResult) + c.ackTable = make(map[UID]chan NoticeReaderResult) + + c.stopRefreshing = make(chan int, 1) + + if _, err := c.RefreshServer(); err != nil { + return nil, err + } + localIP, err := localIPForUDPAddr(c.sched[0]) + if err != nil { + return nil, err + } + c.localIP = localIP + + go c.readLoop() + // This is kinda screwy. Purely for testing purposes, ensure + // the first query on the clock happens by the time + // NewConnectionFull returns. MockClock is a little messy. + go c.refreshLoop(c.clock.After(serverRefreshInterval)) + return c, nil +} + +// Notices returns the incoming notices from the connection. +func (c *Connection) Notices() <-chan NoticeReaderResult { + return c.notices +} + +// LocalAddr returns the local UDP address for the client when +// communicating with the Zephyr servers. +func (c *Connection) LocalAddr() *net.UDPAddr { + addr := c.conn.LocalAddr().(*net.UDPAddr) + addr.IP = c.localIP + return addr +} + +// Credential returns the credential for this connection. +func (c *Connection) Credential() *krb5.Credential { + return c.cred +} + +// Close closes the underlying connection. +func (c *Connection) Close() error { + c.stopRefreshing <- 0 + return c.conn.Close() +} + +func (c *Connection) readLoop() { + for r := range c.allNotices { + if r.Notice.Kind.IsServerACK() { + c.processServAck(r) + } else { + c.notices <- r + } + } + close(c.notices) +} + +func (c *Connection) refreshLoop(after <-chan time.Time) { + for { + select { + case <-after: + c.RefreshServer() + after = c.clock.After(serverRefreshInterval) + case <-c.stopRefreshing: + return + } + } +} + +func (c *Connection) findPendingSend(uid UID) chan NoticeReaderResult { + c.ackTableLock.Lock() + defer c.ackTableLock.Unlock() + if ps, ok := c.ackTable[uid]; ok { + delete(c.ackTable, uid) + return ps + } + return nil +} + +func (c *Connection) addPendingSend(uid UID) <-chan NoticeReaderResult { + // Buffer one entry; if the ACK and timeout race, the + // sending thread should not lock up. + ackChan := make(chan NoticeReaderResult, 1) + c.ackTableLock.Lock() + defer c.ackTableLock.Unlock() + c.ackTable[uid] = ackChan + return ackChan +} + +func (c *Connection) clearPendingSend(uid UID) { + c.ackTableLock.Lock() + defer c.ackTableLock.Unlock() + delete(c.ackTable, uid) +} + +func (c *Connection) processServAck(r NoticeReaderResult) { + ps := c.findPendingSend(r.Notice.UID) + if ps != nil { + ps <- r + } +} + +func (c *Connection) schedule() ([]*net.UDPAddr, int) { + c.schedLock.Lock() + defer c.schedLock.Unlock() + return c.sched, c.schedIdx +} + +func (c *Connection) setSchedule(sched []*net.UDPAddr, schedIdx int) { + c.schedLock.Lock() + defer c.schedLock.Unlock() + c.sched = sched + c.schedIdx = schedIdx +} + +func (c *Connection) goodServer(good *net.UDPAddr) { + c.schedLock.Lock() + defer c.schedLock.Unlock() + + // Find the good server in the schedule and use it + // preferentially next time. + for i, addr := range c.sched { + if udpAddrsEqual(addr, good) { + c.schedIdx = i + return + } + } +} + +// RefreshServer forces a manual refresh of the server schedule from +// the ServerConfig. This will be called periodically and when +// outgoing messages time out, so there should be little need to call +// this manually. +func (c *Connection) RefreshServer() ([]*net.UDPAddr, error) { + sched, err := c.server.ResolveServer() + if err != nil { + return nil, err + } + if len(sched) == 0 { + panic(sched) + } + c.setSchedule(sched, 0) + return sched, nil +} + +// SendNotice sends an authenticated notice to the servers. If the +// notice expects an acknowledgement, it returns the SERVACK or +// SERVNAK notice from the server on success. +func (c *Connection) SendNotice(ctx *krb5.Context, n *Notice) (*Notice, error) { + pkt, err := n.EncodePacketForServer(ctx, c.cred) + if err != nil { + return nil, err + } + return c.SendPacket(pkt, n.Kind, n.UID) +} + +// SendNoticeUnauth sends an unauthenticated notice to the servers. If +// the notice expects an acknowledgement, it returns the SERVACK or +// SERVNAK notice from the server on success. +func (c *Connection) SendNoticeUnauth(n *Notice) (*Notice, error) { + pkt := n.EncodePacketUnauth() + return c.SendPacket(pkt, n.Kind, n.UID) +} + +// SendNoticeUnackedTo sends an unauthenticated and unacked notice to +// a given destination. This is used to send a CLIENTACK to a received +// notice. +func (c *Connection) SendNoticeUnackedTo(n *Notice, addr net.Addr) error { + pkt := n.EncodePacketUnauth() + return c.SendPacketUnackedTo(pkt, addr) +} + +// ErrPacketTooLong is returned when a notice or packet exceeds the +// maximum Zephyr packet size. +var ErrPacketTooLong = errors.New("packet too long") + +// ErrSendTimeout is returned if a send times out without +// acknowledgement from the server. +var ErrSendTimeout = errors.New("send timeout") + +// SendPacketUnackedTo sends a raw packet to a given destination. +func (c *Connection) SendPacketUnackedTo(pkt []byte, addr net.Addr) error { + if len(pkt) > MaxPacketLength { + return ErrPacketTooLong + } + _, err := c.conn.WriteTo(pkt, addr) + return err +} + +// TODO(davidben): We probably want to be more cleverer later. For +// now, follow a similar strategy to the real zhm, but use a much more +// aggressive rexmit schedule. +// +// Empirically, it seems to take 15-20ms for the zephyrds to ACK a +// notice. +var retrySchedule = []time.Duration{ + 100 * time.Millisecond, + 100 * time.Millisecond, + 250 * time.Millisecond, + 500 * time.Millisecond, + 1 * time.Second, + 2 * time.Second, + 4 * time.Second, +} + +// If we've timed out 4 times, get a new server schedule. +const timeoutsBeforeRefresh = 4 + +// SendPacket sends a raw packet to the Zephyr servers. Based on kind +// and uid, it may wait for an acknowledgement. In that case, the +// SERVACK or SERVNAK notice will be returned. SendPacket rotates +// between the server instances and refreshes server list as necessary. +func (c *Connection) SendPacket(pkt []byte, kind Kind, uid UID) (*Notice, error) { + // TODO(davidben): Should we limit the number of packets + // in-flight as an ad-hoc congestion control? + if len(pkt) > MaxPacketLength { + return nil, ErrPacketTooLong + } + retryIdx := -1 + timeout := c.clock.After(0) + + // Listen for ACKs. + var ackChan <-chan NoticeReaderResult + var shouldClear bool + if kind.ExpectsServerACK() { + ackChan = c.addPendingSend(uid) + shouldClear = true + defer func() { + if shouldClear { + c.clearPendingSend(uid) + } + }() + } + + // Get the remote server schedule. + sched, schedIdx := c.schedule() + if len(sched) == 0 { + panic(sched) + } + + for { + select { + case ack := <-ackChan: + shouldClear = false // Already taken care of. + // Record the good server so next time we + // start at that one. + c.goodServer(ack.Addr.(*net.UDPAddr)) + return ack.Notice, nil + case <-timeout: + retryIdx++ + if retryIdx >= len(retrySchedule) { + return nil, ErrSendTimeout + } + + // Partway through the re-xmit schedule, if we + // still haven't heard back from any server, + // get a fresh set of remote addresses. + if retryIdx == timeoutsBeforeRefresh { + var err error + sched, err = c.RefreshServer() + if err != nil { + return nil, err + } + schedIdx = 0 + } + + addr := sched[schedIdx] + if err := c.SendPacketUnackedTo(pkt, addr); err != nil { + // TODO(davidben): Keep going on + // temporary errors? + return nil, err + } + if !kind.ExpectsServerACK() { + return nil, nil + } + // Schedule the next timeout and move on to + // the next server. + timeout = c.clock.After(retrySchedule[retryIdx]) + schedIdx = (schedIdx + 1) % len(sched) + } + } +} diff --git a/connection_test.go b/connection_test.go new file mode 100644 index 0000000..9929990 --- /dev/null +++ b/connection_test.go @@ -0,0 +1,541 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "errors" + "net" + "reflect" + "sync" + "testing" + "time" + + "github.com/zephyr-im/krb5-go" + "github.com/zephyr-im/krb5-go/krb5test" + "github.com/zephyr-im/zephyr-go/zephyrtest" +) + +var clientAddr = &net.UDPAddr{IP: net.IPv4(1, 1, 1, 1), Port: 1111} +var serverAddr1 = &net.UDPAddr{IP: net.IPv4(2, 2, 2, 2), Port: 2222} +var serverAddr2 = &net.UDPAddr{IP: net.IPv4(3, 3, 3, 3), Port: 3333} +var serverConfig = NewStaticServer([]*net.UDPAddr{serverAddr1}) +var serverConfigFull = NewStaticServer([]*net.UDPAddr{serverAddr1, serverAddr2}) + +func mockNetwork1() (net.PacketConn, net.PacketConn) { + n := zephyrtest.NewMockPacketNetwork( + []net.Addr{clientAddr, serverAddr1}) + return n[0], n[1] +} + +func mockNetwork2() (net.PacketConn, net.PacketConn, net.PacketConn) { + n := zephyrtest.NewMockPacketNetwork( + []net.Addr{clientAddr, serverAddr1, serverAddr2}) + return n[0], n[1], n[2] +} + +func mockServer(t *testing.T, + conn net.PacketConn, + expectedNotice *Notice, + expectedAuthStatus AuthStatus, + clock *zephyrtest.MockClock, + numDrop int) { + l, lc := expectNoLogs(t) + defer lc.Close() + + // Set some stuff up. + ctx, keytab := makeServerContextAndKeyTab(t) + defer ctx.Free() + defer keytab.Close() + for r := range ReadRawNotices(conn, l) { + authStatus, _, err := r.RawNotice.CheckAuthFromClient( + ctx, krb5test.Service(), keytab) + if err != nil { + t.Fatalf("CheckAuthFromClient failed: %v", err) + return + } + notice, err := DecodeRawNotice(r.RawNotice) + if err != nil { + t.Fatalf("DecodeRawNotice failed: %v", err) + return + } + + if notice != nil { + if authStatus != expectedAuthStatus { + t.Errorf("Bad authStatus %v; want %v", + authStatus, expectedAuthStatus) + } + expectNoticesEqual(t, notice, expectedNotice) + } + + // Drop the first few packets. + if numDrop > 0 { + clock.Advance(time.Minute) + numDrop-- + continue + } + if numDrop == -1 { + clock.Advance(time.Minute) + continue + } + + // Finally, ACK it. + conn.WriteTo(notice.MakeACK(SERVACK, "SENT").EncodePacketUnauth(), r.Addr) + } +} + +// Tests that a Connection forwards received packets out and doesn't +// send SERVACKs out. +func TestConnectionReceive(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + notice := sampleNotice() + pkt, err := notice.EncodePacketForClient( + ctx, AuthYes, krb5test.SessionKey()) + if err != nil { + t.Fatal(err) + } + ack := notice.MakeACK(SERVACK, "SENT") + + readChan := make(chan zephyrtest.PacketRead, 2) + readChan <- zephyrtest.PacketRead{Packet: pkt} + readChan <- zephyrtest.PacketRead{Packet: ack.EncodePacketUnauth()} + close(readChan) + mock := zephyrtest.NewMockPacketConn(clientAddr, readChan) + clock := zephyrtest.NewMockClock() + conn, err := NewConnectionFull(mock, serverConfig, krb5test.Credential(), l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Check that the credential is the same. + if !reflect.DeepEqual(conn.Credential(), krb5test.Credential()) { + t.Errorf("conn.Credential() = %v; want %v", conn.Credential(), + krb5test.Credential()) + } + + // Read notices out of the connection. + result := <-conn.Notices() + expected := NoticeReaderResult{notice, AuthYes, nil} + if !reflect.DeepEqual(result, expected) { + t.Errorf("<-conn.Notices() = %v; want %v", result, expected) + } + + result, ok := <-conn.Notices() + if ok { + t.Errorf("conn.Notices() did not end: %v", result) + } +} + +func TestConnectionSendNotice(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + notice := sampleNotice() + + clock := zephyrtest.NewMockClock() + client, server := mockNetwork1() + defer server.Close() + conn, err := NewConnectionFull(client, serverConfig, krb5test.Credential(), + l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Set up a "server" to SERVACK notices as they come in. + go mockServer(t, server, notice, AuthYes, clock, 0) + + ack, err := conn.SendNotice(ctx, notice) + if err != nil { + t.Fatalf("SendNotice failed: %v", err) + } + if ack.Kind != SERVACK { + t.Errorf("ack.Kind = %v; want %v", ack.Kind, SERVACK) + } + if string(ack.RawBody) != "SENT" { + t.Errorf("ack.RawBody = %v; want %v", string(ack.RawBody), "SENT") + } +} + +func TestConnectionSendNoticeRetransmit(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + notice := sampleNotice() + + clock := zephyrtest.NewMockClock() + client, server := mockNetwork1() + defer server.Close() + conn, err := NewConnectionFull(client, serverConfig, krb5test.Credential(), + l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Set up a "server" to SERVACK notices as they come in. + go mockServer(t, server, notice, AuthYes, clock, 2) + + ack, err := conn.SendNotice(ctx, notice) + if err != nil { + t.Fatalf("SendNotice failed: %v", err) + } + if ack.Kind != SERVACK { + t.Errorf("ack.Kind = %v; want %v", ack.Kind, SERVACK) + } + if string(ack.RawBody) != "SENT" { + t.Errorf("ack.RawBody = %v; want %v", string(ack.RawBody), "SENT") + } +} + +func TestConnectionSendNoticeTimeout(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + notice := sampleNotice() + + clock := zephyrtest.NewMockClock() + client, server := mockNetwork1() + defer server.Close() + conn, err := NewConnectionFull(client, serverConfig, krb5test.Credential(), + l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Server never responds to anything. + go mockServer(t, server, notice, AuthYes, clock, -1) + + _, err = conn.SendNotice(ctx, notice) + if err != ErrSendTimeout { + t.Fatalf("SendNoticeUnauth did not fail as expected: %v", err) + } +} + +func TestConnectionSendNoticeUnauth(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + notice := sampleNotice() + clock := zephyrtest.NewMockClock() + client, server := mockNetwork1() + defer server.Close() + conn, err := NewConnectionFull(client, serverConfig, krb5test.Credential(), + l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Set up a "server" to SERVACK notices as they come in. + go mockServer(t, server, notice, AuthNo, clock, 0) + + ack, err := conn.SendNoticeUnauth(notice) + if err != nil { + t.Fatalf("SendNoticeUnauth failed: %v", err) + } + if ack.Kind != SERVACK { + t.Errorf("ack.Kind = %v; want %v", ack.Kind, SERVACK) + } + if string(ack.RawBody) != "SENT" { + t.Errorf("ack.RawBody = %v; want %v", string(ack.RawBody), "SENT") + } +} + +func TestConnectionSendNoticeUnacked(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + notice := sampleNotice() + notice.Kind = UNACKED + + clock := zephyrtest.NewMockClock() + client, server := mockNetwork1() + defer server.Close() + conn, err := NewConnectionFull(client, serverConfig, krb5test.Credential(), + l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Set up a "server" to never ACK anything. + go mockServer(t, server, notice, AuthYes, clock, -1) + + ack, err := conn.SendNotice(ctx, notice) + if err != nil { + t.Fatalf("SendNotice failed: %v", err) + } + if ack != nil { + t.Errorf("ack = %v; want nil", ack) + } +} + +func TestConnectionSendNoticeRoundRobin(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + notice := sampleNotice() + + clock := zephyrtest.NewMockClock() + client, server1, server2 := mockNetwork2() + defer server1.Close() + defer server2.Close() + conn, err := NewConnectionFull(client, serverConfigFull, krb5test.Credential(), + l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // server1 never responds to anything. server2 does after the first. + go mockServer(t, server1, notice, AuthYes, clock, -1) + go mockServer(t, server2, notice, AuthYes, clock, 1) + + ack, err := conn.SendNotice(ctx, notice) + if err != nil { + t.Fatalf("SendNotice failed: %v", err) + } + if ack.Kind != SERVACK { + t.Errorf("ack.Kind = %v; want %v", ack.Kind, SERVACK) + } + if string(ack.RawBody) != "SENT" { + t.Errorf("ack.RawBody = %v; want %v", string(ack.RawBody), "SENT") + } + + // We should prefer server2 now. + if sched, idx := conn.schedule(); sched[idx].String() != serverAddr2.String() { + t.Errorf("Client prefers %v; want %v", + sched[idx].String(), serverAddr2.String()) + } +} + +type needRefresh bool + +func (nr *needRefresh) ResolveServer() ([]*net.UDPAddr, error) { + if !*nr { + *nr = true + return []*net.UDPAddr{serverAddr1}, nil + } + return []*net.UDPAddr{serverAddr2}, nil +} +func serverConfigNeedRefresh() ServerConfig { + nr := needRefresh(false) + return &nr +} + +func TestConnectionSendNoticeNeedRefresh(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + notice := sampleNotice() + + clock := zephyrtest.NewMockClock() + client, server1, server2 := mockNetwork2() + defer server1.Close() + defer server2.Close() + conn, err := NewConnectionFull(client, serverConfigNeedRefresh(), + krb5test.Credential(), l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // server1 never responds to anything. server2 is good. + go mockServer(t, server1, notice, AuthYes, clock, -1) + go mockServer(t, server2, notice, AuthYes, clock, 0) + + ack, err := conn.SendNotice(ctx, notice) + if err != nil { + t.Fatalf("SendNotice failed: %v", err) + } + if ack.Kind != SERVACK { + t.Errorf("ack.Kind = %v; want %v", ack.Kind, SERVACK) + } + if string(ack.RawBody) != "SENT" { + t.Errorf("ack.RawBody = %v; want %v", string(ack.RawBody), "SENT") + } + + // We should prefer server2 now. + if sched, idx := conn.schedule(); sched[idx].String() != serverAddr2.String() { + t.Errorf("Client prefers %v; want %v", + sched[idx].String(), serverAddr2.String()) + } +} + +func TestConnectionSendNoticeWriteFailure(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + notice := sampleNotice() + clock := zephyrtest.NewMockClock() + readChan := make(chan zephyrtest.PacketRead) + close(readChan) + mock := zephyrtest.NewMockPacketConn(clientAddr, readChan) + conn, err := NewConnectionFull(mock, serverConfig, krb5test.Credential(), + l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Set up a "server" to fail all writes. + expectedErr := errors.New("failed") + go func() { + for write := range mock.Writes() { + write.Result <- expectedErr + } + }() + + _, err = conn.SendNoticeUnauth(notice) + if err != expectedErr { + t.Fatalf("SendNoticeUnauth did not fail as expected: %v", err) + } +} + +func TestConnectionSendGiantNotices(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + notice := sampleNotice() + notice.RawBody = make([]byte, 99999) + + clock := zephyrtest.NewMockClock() + client, server := mockNetwork1() + defer server.Close() + conn, err := NewConnectionFull(client, serverConfig, krb5test.Credential(), + l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + _, err = conn.SendNoticeUnauth(notice) + if err != ErrPacketTooLong { + t.Fatalf("SendNoticeUnauth did not fail as expected: %v", err) + } +} + +type signalingConfig chan int + +func (sc signalingConfig) ResolveServer() ([]*net.UDPAddr, error) { + sc <- 0 + return []*net.UDPAddr{serverAddr1}, nil +} + +func TestConnectionPeriodicRefresh(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + config := signalingConfig(make(chan int, 2)) + + clock := zephyrtest.NewMockClock() + client, server := mockNetwork1() + defer server.Close() + conn, err := NewConnectionFull(client, config, krb5test.Credential(), l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Initial query. + <-config + + // Periodic refresh. + clock.Advance(serverRefreshInterval) + <-config +} + +var errFailingConfig = errors.New("hesiod fell over") + +type failingConfig struct { + lock sync.Mutex + goodRuns int + config ServerConfig +} + +func newFailingConfig(goodRuns int, config ServerConfig) ServerConfig { + return &failingConfig{goodRuns: goodRuns, config: config} +} + +func (f *failingConfig) ResolveServer() ([]*net.UDPAddr, error) { + f.lock.Lock() + if f.goodRuns == 0 { + f.lock.Unlock() + return nil, errFailingConfig + } + f.goodRuns-- + f.lock.Unlock() + return f.config.ResolveServer() +} + +func TestConnectionFailingConfigMidSend(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + notice := sampleNotice() + + clock := zephyrtest.NewMockClock() + client, server1, server2 := mockNetwork2() + defer server1.Close() + defer server2.Close() + conn, err := NewConnectionFull(client, + newFailingConfig(1, serverConfigNeedRefresh()), + krb5test.Credential(), l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // server1 never responds to anything. server2 is good. + go mockServer(t, server1, notice, AuthYes, clock, -1) + go mockServer(t, server2, notice, AuthYes, clock, 0) + + ack, err := conn.SendNotice(ctx, notice) + if err != errFailingConfig { + t.Fatalf("SendNotice didn't fail as expected: %v %v", ack, err) + } +} diff --git a/data_for_test.go b/data_for_test.go new file mode 100644 index 0000000..6db6e9d --- /dev/null +++ b/data_for_test.go @@ -0,0 +1,169 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "bytes" + "encoding/base64" + "net" + "strings" + + "github.com/zephyr-im/krb5-go" +) + +func stringsToByteSlices(ss []string) [][]byte { + bs := make([][]byte, len(ss)) + for i := range ss { + bs[i] = []byte(ss[i]) + } + return bs +} + +// Authenticated packets taken from a libzephyr session. (Session was +// since canceled and the ticket associate with the key has expired.) + +func sampleKeyBlock() *krb5.KeyBlock { + data, err := base64.StdEncoding.DecodeString( + "2PgONWKpPuAyFwRRIe1Ex5bR4kLNkI9beX4NGl7mkIA=") + if err != nil { + panic(err) + } + return &krb5.KeyBlock{krb5.ENCTYPE_AES256_CTS_HMAC_SHA1_96, data} +} + +func stringToUID(s string) UID { + var uid UID + if len(s) != 12 { + panic(s) + } + copy(uid[:], []byte(s)) + return uid +} + +func sampleChecksum() []byte { + return []byte("\x39\x04\x48\x83\x3f\xa5\x59\xf2\x0f\x39\x88\x00") +} + +func sampleChecksumZcode() []byte { + return []byte("Z\x39\x04\x48\x83\x3f\xa5\x59\xf2\x0f\x39\x88\xff\xf0") +} + +func samplePacket() []byte { + return []byte("ZEPH0.2\x00" + + "0x00000013\x00" + + "0x00000002\x00" + + "0x1265189F 0x532DE3FC 0x0003AC0E\x00" + + "0xC0CA\x00" + + "0x00000001\x00" + + "0x00000000\x00" + + "\x00" + + "davidben-test-class\x00" + + "test\x00" + + "\x00" + + "davidben@ATHENA.MIT.EDU\x00" + + "\x00" + + "http://zephyr.1ts.org/wiki/df\x00" + + string(sampleChecksumZcode()) + "\x00" + + "0/23\x00" + + "0x1265189F 0x532DE3FC 0x0003AC0E\x00" + + "Z\x12\x65\x18\x9f\x00" + + "0x6A00\x00" + + "David Benjamin\x00" + + "Message\n") + +} + +func sampleFailPacket() []byte { + return makeTestPacket(14, + "Z\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c") +} + +func sampleMalformedChecksumPacket() []byte { + return makeTestPacket(14, "invalid checksum") +} + +func sampleMalformedPortPacket() []byte { + return makeTestPacket(4, "invalid port") +} + +func sampleRawNotice() *RawNotice { + return &RawNotice{ + stringsToByteSlices([]string{ + "ZEPH0.2", + "0x00000013", + "0x00000002", + "0x1265189F 0x532DE3FC 0x0003AC0E", + "0xC0CA", + "0x00000001", + "0x00000000", + "", + "davidben-test-class", + "test", + "", + "davidben@ATHENA.MIT.EDU", + "", + "http://zephyr.1ts.org/wiki/df", + string(sampleChecksumZcode()), + "0/23", + "0x1265189F 0x532DE3FC 0x0003AC0E", + "Z\x12\x65\x18\x9f", + "0x6A00", + }), + []byte("David Benjamin\x00Message\n")} + +} + +func sampleNotice() *Notice { + uid := stringToUID("\x12\x65\x18\x9F\x53\x2D\xE3\xFC\x00\x03\xAC\x0E") + return &Notice{ + Header: Header{ + Kind: ACKED, + UID: uid, + Port: 49354, + Class: "davidben-test-class", + Instance: "test", + OpCode: "", + Sender: "davidben@ATHENA.MIT.EDU", + Recipient: "", + DefaultFormat: "http://zephyr.1ts.org/wiki/df", + SenderAddress: net.ParseIP("18.101.24.159").To4(), + Charset: CharsetUTF8, + OtherFields: [][]byte{}, + }, + Multipart: "0/23", + MultiUID: uid, + RawBody: []byte("David Benjamin\x00Message\n")} +} + +func sampleNoticeWithUID(uid UID) *Notice { + notice := sampleNotice() + notice.UID = uid + return notice +} + +func sampleMessage(uid UID, rawBody []byte) *Message { + return &Message{ + sampleNoticeWithUID(uid).Header, + strings.Split(string(rawBody), "\x00")} +} + +func makeTestPacket(index int, replace string) []byte { + raw := sampleRawNotice() + fields := make([][]byte, len(raw.HeaderFields)+1) + copy(fields, raw.HeaderFields) + fields[len(fields)-1] = raw.Body + fields[index] = []byte(replace) + return bytes.Join(fields, []byte{0}) +} diff --git a/krb_util.go b/krb_util.go new file mode 100644 index 0000000..4861b66 --- /dev/null +++ b/krb_util.go @@ -0,0 +1,54 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "errors" + + "github.com/zephyr-im/krb5-go" +) + +const ( + keyUsageClientCksum = 1027 + keyUsageServerCksum = 1029 +) + +// ErrUnknownEncType is returned when attempting to use a key with an +// unknown enctype. +var ErrUnknownEncType = errors.New("unknown enctype") + +// Why is this not exported from MIT Kerberos? +func defaultSumTypeForEncType(enctype krb5.EncType) (krb5.SumType, error) { + switch enctype { + case krb5.ENCTYPE_DES_CBC_CRC: + return krb5.SUMTYPE_RSA_MD5_DES, nil + case krb5.ENCTYPE_DES_CBC_MD4: + return krb5.SUMTYPE_RSA_MD4_DES, nil + case krb5.ENCTYPE_DES_CBC_MD5: + return krb5.SUMTYPE_RSA_MD5_DES, nil + case krb5.ENCTYPE_DES3_CBC_SHA1: + return krb5.SUMTYPE_HMAC_SHA1_DES3, nil + case krb5.ENCTYPE_AES128_CTS_HMAC_SHA1_96: + return krb5.SUMTYPE_HMAC_SHA1_96_AES128, nil + case krb5.ENCTYPE_AES256_CTS_HMAC_SHA1_96: + return krb5.SUMTYPE_HMAC_SHA1_96_AES256, nil + case krb5.ENCTYPE_ARCFOUR_HMAC: + return krb5.SUMTYPE_HMAC_MD5_ARCFOUR, nil + case krb5.ENCTYPE_ARCFOUR_HMAC_EXP: + return krb5.SUMTYPE_HMAC_MD5_ARCFOUR, nil + default: + return krb5.SUMTYPE_DEFAULT, ErrUnknownEncType + } +} diff --git a/krb_util_test.go b/krb_util_test.go new file mode 100644 index 0000000..f32d2f9 --- /dev/null +++ b/krb_util_test.go @@ -0,0 +1,67 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "testing" + + "github.com/zephyr-im/krb5-go" +) + +func TestDefaultSumTypeForEncType(t *testing.T) { + ctx, err := krb5.NewContext() + if err != nil { + t.Fatalf("Could not create context: %v", err) + } + defer ctx.Free() + + enctypes := []krb5.EncType{ + krb5.ENCTYPE_DES_CBC_CRC, + krb5.ENCTYPE_DES_CBC_MD4, + krb5.ENCTYPE_DES_CBC_MD5, + krb5.ENCTYPE_DES3_CBC_SHA1, + krb5.ENCTYPE_AES128_CTS_HMAC_SHA1_96, + krb5.ENCTYPE_AES256_CTS_HMAC_SHA1_96, + krb5.ENCTYPE_ARCFOUR_HMAC, + krb5.ENCTYPE_ARCFOUR_HMAC_EXP, + } + usage := int32(0) + data := []byte("Hello") + for _, enctype := range enctypes { + key, err := ctx.MakeRandomKey(enctype) + if err != nil { + t.Errorf("ctx.MakeRandomKey(%v) failed: %v", enctype, err) + continue + } + + cksum, err := ctx.MakeChecksum(krb5.SUMTYPE_DEFAULT, key, usage, data) + if err != nil { + t.Errorf("ctx.MakeCheckum(%v, %v) failed: %v", key, data, err) + continue + } + + if sumtype, err := defaultSumTypeForEncType(enctype); err != nil { + t.Errorf("defaultSumTypeForEncType(%v) failed: %v", enctype, err) + } else if sumtype != cksum.SumType { + t.Errorf("defaultSumTypeForEncType(%v) = %v; want %v", + enctype, sumtype, cksum.SumType) + } + } + + // Error-handling for some random unknown checksum. + if _, err := defaultSumTypeForEncType(krb5.ENCTYPE_DES_CBC_RAW); err == nil { + t.Errorf("defaultSumTypeForEncType(krb5.ENCTYPE_DES_CBC_RAW) did not fail") + } +} diff --git a/log_util.go b/log_util.go new file mode 100644 index 0000000..a7bf68f --- /dev/null +++ b/log_util.go @@ -0,0 +1,46 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "log" +) + +// This is silly. Why doesn't the log package expose the standard +// logger? + +func logPrint(l *log.Logger, v ...interface{}) { + if l != nil { + l.Print(v...) + } else { + log.Print(v...) + } +} + +func logPrintf(l *log.Logger, format string, v ...interface{}) { + if l != nil { + l.Printf(format, v...) + } else { + log.Printf(format, v...) + } +} + +func logPrintln(l *log.Logger, v ...interface{}) { + if l != nil { + l.Println(v...) + } else { + log.Println(v...) + } +} diff --git a/message.go b/message.go new file mode 100644 index 0000000..426100d --- /dev/null +++ b/message.go @@ -0,0 +1,125 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "strings" + "time" + + "github.com/zephyr-im/krb5-go" +) + +// A Message is a high-level reassembled zepyr message. This is the +// final stage of the messaging pipeline. +type Message struct { + Header + Body []string +} + +func sendMessage(conn *Connection, msg *Message, slop int, + encodeFn func(*Notice) ([]byte, error)) (*Notice, error) { + // Determine the body to send. + rawBody := []byte(strings.Join(msg.Body, "\x00")) + rawBodyLen := len(rawBody) + + // Special-case: if the body is empty, send one packet. + if rawBodyLen == 0 { + notice := &Notice{ + Header: msg.Header, + MultiUID: msg.UID, + Multipart: "0/0", + } + pkt, err := encodeFn(notice) + if err != nil { + return nil, err + } + return conn.SendPacket(pkt, notice.Kind, notice.UID) + } + + // First, compute how much space we have for the body. + notice := &Notice{Header: msg.Header, MultiUID: msg.UID} + var headerLen int + pkt, err := encodeFn(notice) + if err != nil { + return nil, err + } + headerLen = len(pkt) + + var ack *Notice + uid := msg.UID + offset := 0 + for len(rawBody) != 0 { + // Compute multipart field. + multipart := EncodeMultipart(offset, rawBodyLen) + // Put as much of the body in as we can. + remaining := MaxPacketLength - headerLen - len(multipart) - slop + if len(rawBody) < remaining { + remaining = len(rawBody) + } + // The header was too long to include the body. + if remaining <= 0 { + return nil, ErrPacketTooLong + } + + // Prepare the next notice. + notice.UID = uid + notice.Multipart = multipart + notice.RawBody = rawBody[:remaining] + pkt, err := encodeFn(notice) + if err != nil { + return nil, err + } + + // Send the notice. Stop on error or SERVNAK. (The + // notice might not be ACKED, so it's possible for ack + // to be nil.) + ack, err = conn.SendPacket(pkt, notice.Kind, notice.UID) + if err != nil { + return nil, err + } else if ack != nil && ack.Kind != SERVACK { + return ack, nil + } + + // Next packet gets a new uid. + uid = MakeUID(conn.LocalAddr().IP, time.Now()) + rawBody = rawBody[remaining:] + offset += remaining + } + + // Return the last ACK we saw. + return ack, nil +} + +// SendMessage sends an authenticated message across a connection, +// sharding into multiple notices as needed. It returns the ACK from +// the server if the message is ACKED. +func SendMessage(ctx *krb5.Context, conn *Connection, msg *Message) (*Notice, error) { + // Leave some 13 bytes of slop because, if we're unlucky, + // zcode may blow up the input. 13 was chosen because it's + // what libzephyr uses and is above 1024 / 128. + return sendMessage(conn, msg, 13, func(n *Notice) ([]byte, error) { + return n.EncodePacketForServer(ctx, conn.Credential()) + }) +} + +// SendMessageUnauth sends an unauthenticated message across a +// connection, sharding into multiple notices as needed. It returns +// the ACK from the server if the message is ACKED. +func SendMessageUnauth(conn *Connection, msg *Message) (*Notice, error) { + // No slop needed because there isn't a checksum in this notice. + return sendMessage(conn, msg, 0, func(n *Notice) ([]byte, error) { + return n.EncodePacketUnauth(), nil + }) +} diff --git a/message_test.go b/message_test.go new file mode 100644 index 0000000..755b764 --- /dev/null +++ b/message_test.go @@ -0,0 +1,204 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "errors" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/zephyr-im/krb5-go" + "github.com/zephyr-im/krb5-go/krb5test" + "github.com/zephyr-im/zephyr-go/zephyrtest" +) + +func testSendMessage(t *testing.T, l int, auth AuthStatus) { + hdr := fmt.Sprintf("(%d, %v)", l, auth) + + b := []byte{} + for i := 0; i < l; i++ { + b = append(b, byte(i)) + } + body := strings.Split(string(b), "\x00") + + msg := &Message{sampleNotice().Header, body} + + logger, lc := expectNoLogs(t) + defer lc.Close() + ctx, err := krb5.NewContext() + if err != nil { + t.Error(err) + return + } + defer ctx.Free() + + clock := zephyrtest.NewMockClock() + client, server := mockNetwork1() + conn, err := NewConnectionFull(client, serverConfig, + krb5test.Credential(), logger, clock) + if err != nil { + t.Error(err) + return + } + defer conn.Close() + + // Set up a "server" to SERVACK notices as they come in. + notices := make(chan *Notice, l+1) + go ackAndDumpNotices(t, server, auth, notices) + + // Send the message. + var ack *Notice + if auth == AuthYes { + ack, err = SendMessage(ctx, conn, msg) + } else { + ack, err = SendMessageUnauth(conn, msg) + } + server.Close() + + // Check the ACK and whatnot. + if err != nil { + t.Errorf("%s Error sending message: %v", hdr, err) + return + } + if ack.Kind != SERVACK { + t.Errorf("%s Received %v; want SERVACK", hdr, ack) + } + + r := NewReassembler(l) + for n := range notices { + if !n.MultiUID.Equal(msg.UID) { + t.Errorf("%s n.MultiUID = %v; want %v", + hdr, n.MultiUID, msg.UID) + } + if r.Done() { + t.Errorf("%s r.Done() = true; want false", hdr) + } + if err := r.AddNotice(n, AuthYes); err != nil { + t.Errorf("%s r.AddNotice(n) failed: %v", hdr, err) + } + } + if !r.Done() { + t.Errorf("%s r.Done() = false; want true", hdr) + return + } + m, _ := r.Message() + expectHeadersEqual(t, &m.Header, &msg.Header) + if !reflect.DeepEqual(m.Body, msg.Body) { + t.Errorf("%s m.Body = %v; want %v", hdr, m.Body, msg.Body) + } +} + +func TestSendMessage(t *testing.T) { + // Test 0 and all powers of 2. + ls := []int{ + 0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, + 2048, 4096, 8192, 16384, + } + as := []AuthStatus{AuthYes, AuthNo} + for _, l := range ls { + for _, a := range as { + testSendMessage(t, l, a) + } + } +} + +func TestSendMessageLongHeader(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + clock := zephyrtest.NewMockClock() + client, server := mockNetwork1() + defer server.Close() + conn, err := NewConnectionFull(client, serverConfig, + krb5test.Credential(), l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + msg := &Message{sampleNotice().Header, []string{"moo"}} + msg.Class = "a" + for i := 0; i < 1000; i++ { + msg.Class += "-really" + } + msg.Class += "-long-class" + + _, err = SendMessageUnauth(conn, msg) + if err != ErrPacketTooLong { + t.Errorf("SendMessageUnauth(conn, msg) did not fail as expected: %v", err) + } +} + +func TestSendMessageNack(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + clock := zephyrtest.NewMockClock() + client, server := mockNetwork1() + defer server.Close() + conn, err := NewConnectionFull(client, serverConfig, + krb5test.Credential(), l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Set up a "server" to SERVNAK notices as they come in. + go nackNotices(t, server) + + // Send the message. + msg := &Message{sampleNotice().Header, []string{"moo"}} + ack, err := SendMessageUnauth(conn, msg) + + // Check the ACK and whatnot. + if err != nil { + t.Fatalf("Error sending message: %v", err) + } + if ack.Kind != SERVNAK { + t.Errorf("ack.Kind = %v; want SERVNAK", ack.Kind) + } + if string(ack.RawBody) != "LOST" { + t.Errorf("ack.RawBody = %v; want 'LOST'", string(ack.RawBody)) + } +} + +func TestSendMessageSendError(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + clock := zephyrtest.NewMockClock() + readChan := make(chan zephyrtest.PacketRead) + close(readChan) + mock := zephyrtest.NewMockPacketConn(clientAddr, readChan) + conn, err := NewConnectionFull(mock, serverConfig, + krb5test.Credential(), l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Set up a "server" to fail all writes. + expectedErr := errors.New("failed") + go func() { + for write := range mock.Writes() { + write.Result <- expectedErr + } + }() + + // Send the message. + msg := &Message{sampleNotice().Header, []string{"moo"}} + if _, err := SendMessageUnauth(conn, msg); err != expectedErr { + t.Errorf("SendMessageUnauth didn't fail as expected: %v", err) + } +} diff --git a/notice.go b/notice.go new file mode 100644 index 0000000..af1c032 --- /dev/null +++ b/notice.go @@ -0,0 +1,441 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "bytes" + "encoding/binary" + "errors" + "net" + "strconv" + "time" + + "github.com/zephyr-im/krb5-go" +) + +// A Kind is the first field a zephyr notice. +type Kind uint32 + +const ( + // UNSAFE notices are acknowledged by neither the server nor + // the zhm (in implementations which use one). + UNSAFE Kind = 0 + // UNACKED notices are not acknowledged by the server. + UNACKED Kind = 1 + // ACKED notices are acknowledged by the server with a + // SERVACK. + ACKED Kind = 2 + // HMACK notices are zhm acknowledgements. + HMACK Kind = 3 + // HMCTL notices are used in communications between the server + // and the zhm. + HMCTL Kind = 4 + // SERVACK notices are acknowledgements from the server that a + // notice was accepted for delivery. + SERVACK Kind = 5 + // SERVNAK notices are acknowledgements from the server that a + // notice was received but not delivered for some reason. + SERVNAK Kind = 6 + // CLIENTACK notices are sent by the client to acknowledgement + // a notice from the server. + CLIENTACK Kind = 7 + // STAT notices are used to request statistics from the zhm. + STAT Kind = 8 +) + +func (k Kind) String() string { + switch k { + case UNSAFE: + return "UNSAFE" + case UNACKED: + return "UNACKED" + case ACKED: + return "ACKED" + case HMACK: + return "HMACK" + case HMCTL: + return "HMCTL" + case SERVACK: + return "SERVACK" + case SERVNAK: + return "SERVNAK" + case CLIENTACK: + return "CLIENTACK" + case STAT: + return "STAT" + default: + return strconv.FormatUint(uint64(k), 10) + } +} + +// IsACK returns whether this is a server or client +// acknowledgement. +func (k Kind) IsACK() bool { + return k.IsServerACK() || k == CLIENTACK +} + +// IsServerACK returns whether this is a SERVACK or SERVNAK. +func (k Kind) IsServerACK() bool { + return k == SERVACK || k == SERVNAK +} + +// ExpectsServerACK returns whether the client expects a server +// acknowledgement in response to this packet. +func (k Kind) ExpectsServerACK() bool { + return k == ACKED +} + +// ExpectsClientACK returns whether the server expects a client +// acknowledgement in response to this packet. +func (k Kind) ExpectsClientACK() bool { + return k != HMACK && k != SERVACK && k != SERVNAK && k != CLIENTACK +} + +// A Charset is the value of the "charset" field of a zephyr +// notice. Note that this field has a number of historical +// quirks. Only CharsetUTF8 or the byte-swapped version of it are +// really meaningful. +type Charset uint16 + +// The set of Charset values defined as well as one for byte-swapped +// UTF-8 for compatibility with a zephyr bug. +const ( + CharsetUnknown Charset = 0x0000 + CharsetISO8859_1 Charset = 0x0004 + CharsetUTF8 Charset = 0x006a + CharsetUTF8Swapped Charset = 0x6a00 +) + +func (cs Charset) String() string { + switch cs { + case CharsetUnknown: + return "Unknown charset" + case CharsetISO8859_1: + return "ISO 8859-1" + case CharsetUTF8: + return "UTF-8" + case CharsetUTF8Swapped: + return "UTF-8 (byte-swapped)" + default: + return strconv.FormatUint(uint64(cs), 10) + } +} + +func byteSwap16(n uint16) uint16 { + return ((n & 0xff) << 8) | (n >> 8) +} + +// ErrBadField is returned if a field is in the wrong format. +var ErrBadField = errors.New("bad field") + +// A UID is the identifier of a zephyr notice. It includes the +// sender's IP address and the send time. +type UID [12]byte + +func decodeUID(inp []byte) (*UID, error) { + if DecodedZAsciiLength(len(inp)) != 12 { + return nil, ErrBadField + } + var out UID + if l, err := DecodeZAsciiInto(out[:], inp); err != nil { + return nil, err + } else if l != 12 { + panic(l) + } + return &out, nil +} + +// MakeUID creates a uid out of an IP address and time. If ip is not +// an IPv4 address, the last four bytes are taken. +func MakeUID(ip net.IP, time time.Time) UID { + var uid UID + copy(uid[:4], ip[len(ip)-4:]) + binary.BigEndian.PutUint32(uid[4:8], uint32(time.Unix())) + binary.BigEndian.PutUint32(uid[8:12], uint32(time.Nanosecond()/1000)) + return uid +} + +// IP returns the IP address portion of the uid. +func (uid UID) IP() net.IP { + return net.IP(uid[0:4]) +} + +// Time returns the time portion of the uid. +func (uid UID) Time() time.Time { + seconds := int64(binary.BigEndian.Uint32(uid[4:8])) + useconds := int64(binary.BigEndian.Uint32(uid[8:12])) + return time.Unix(seconds, useconds*1000) +} + +// Equal returns true iff two uids are equal. +func (uid UID) Equal(x UID) bool { + return bytes.Equal(uid[:], x[:]) +} + +// A Header is the set of common metadata between a Notice and a +// reassembled Message. +type Header struct { + Kind Kind + UID UID + Port uint16 + Class string + Instance string + OpCode string + Sender string + Recipient string + DefaultFormat string + SenderAddress net.IP + Charset Charset + OtherFields [][]byte +} + +// A Notice is a raw message sent and received over zephyr. +type Notice struct { + Header + Multipart string + MultiUID UID + RawBody []byte +} + +// DecodeRawNotice decodes a RawNotice into a Notice. The underlying +// byte slices on the result do not share backing stores with the +// RawNotice. +func DecodeRawNotice(r *RawNotice) (*Notice, error) { + kind, err := DecodeZAscii32(r.HeaderFields[kindIndex]) + if err != nil { + return nil, err + } + + uid, err := decodeUID(r.HeaderFields[uidIndex]) + if err != nil { + return nil, err + } + + port, err := DecodeZAscii16(r.HeaderFields[portIndex]) + if err != nil { + return nil, err + } + + class := string(r.HeaderFields[classIndex]) + instance := string(r.HeaderFields[instanceIndex]) + opcode := string(r.HeaderFields[opcodeIndex]) + sender := string(r.HeaderFields[senderIndex]) + recipient := string(r.HeaderFields[recipientIndex]) + defaultFormat := string(r.HeaderFields[defaultformatIndex]) + + multipart := "" + multiuid := uid + if len(r.HeaderFields) > multiuidIndex { + multipart = string(r.HeaderFields[multipartIndex]) + + multiuid, err = decodeUID(r.HeaderFields[multiuidIndex]) + if err != nil { + return nil, err + } + } + + var senderAddress net.IP + if len(r.HeaderFields) > senderSockaddrIndex { + ipBytes, err := DecodeZcode(r.HeaderFields[senderSockaddrIndex]) + if err != nil { + return nil, err + } + if len(ipBytes) != net.IPv4len && len(ipBytes) != net.IPv6len { + return nil, ErrBadField + } + senderAddress = net.IP(ipBytes) + } else { + senderAddress = uid.IP() + } + + charset := CharsetUnknown + otherfields := [][]byte{} + if len(r.HeaderFields) > charsetIndex { + charsetRaw, err := DecodeZAscii16(r.HeaderFields[charsetIndex]) + if err != nil { + return nil, err + } + // The charset gets byte-swapped. Why? Good question. + charset = Charset(byteSwap16(charsetRaw)) + // This really should be a static assert... + if numKnownFields != charsetIndex+1 { + panic(numKnownFields) + } + otherfields = copyByteSlices(r.HeaderFields[numKnownFields:]) + } + + return &Notice{ + Header: Header{ + Kind: Kind(kind), + UID: *uid, + Port: port, + Class: class, + Instance: instance, + OpCode: opcode, + Sender: sender, + Recipient: recipient, + DefaultFormat: defaultFormat, + SenderAddress: senderAddress, + Charset: charset, + OtherFields: otherfields, + }, + Multipart: multipart, + MultiUID: *multiuid, + RawBody: copyByteSlice(r.Body)}, nil +} + +func (n *Notice) encodeRawNotice( + authstatus AuthStatus, authenticator []byte) *RawNotice { + fields := make([][]byte, numKnownFields, numKnownFields+len(n.OtherFields)) + + fields[versionIndex] = []byte(formatZephyrVersion( + ProtocolVersionMajor, ProtocolVersionMinor)) + fields[numfieldsIndex] = EncodeZAscii32(uint32(len(fields))) + fields[kindIndex] = EncodeZAscii32(uint32(n.Kind)) + fields[uidIndex] = EncodeZAscii(n.UID[:]) + fields[portIndex] = EncodeZAscii16(n.Port) + fields[authstatusIndex] = EncodeZAscii32(uint32(authstatus)) + fields[authlenIndex] = EncodeZAscii32(uint32(len(authenticator))) + if authenticator != nil { + fields[authenticatorIndex] = EncodeZcode(authenticator) + } else { + fields[authenticatorIndex] = []byte{} + } + fields[classIndex] = []byte(n.Class) + fields[instanceIndex] = []byte(n.Instance) + fields[opcodeIndex] = []byte(n.OpCode) + fields[senderIndex] = []byte(n.Sender) + fields[recipientIndex] = []byte(n.Recipient) + fields[defaultformatIndex] = []byte(n.DefaultFormat) + // Checksum gets filled in by the caller. + fields[checksumIndex] = nil + fields[multipartIndex] = []byte(n.Multipart) + fields[multiuidIndex] = EncodeZAscii(n.MultiUID[:]) + if ipv4 := n.SenderAddress.To4(); ipv4 != nil { + fields[senderSockaddrIndex] = EncodeZcode(ipv4) + } else { + fields[senderSockaddrIndex] = EncodeZcode(n.SenderAddress) + } + fields[charsetIndex] = EncodeZAscii16(byteSwap16(uint16(n.Charset))) + + fields = append(fields, n.OtherFields...) + + return &RawNotice{fields, n.RawBody} +} + +func (n *Notice) encodeRawNoticeWithKey( + ctx *krb5.Context, + authstatus AuthStatus, + authent []byte, + key *krb5.KeyBlock, + usage int32) (*RawNotice, error) { + + raw := n.encodeRawNotice(authstatus, authent) + cksum, err := ctx.MakeChecksum(krb5.SUMTYPE_DEFAULT, key, usage, + raw.ChecksumPayload()) + if err != nil { + return nil, err + } + raw.HeaderFields[checksumIndex] = EncodeZcode(cksum.Contents) + return raw, nil +} + +// EncodeRawNoticeForServer encodes a Notice into an authenticated +// RawNotice to send to the server. Returns the RawNotice. The +// authenticator always negotiates the key in the credential. +// +// TODO(davidben): Go back to returning the krb5.KeyBlock output from the API? +// Unless we negotiate a subkey, it's guaranteed to be the key in the +// credential anyway, and shared subscriptions require this to be true. +func (n *Notice) EncodeRawNoticeForServer( + ctx *krb5.Context, cred *krb5.Credential) (*RawNotice, error) { + authcon, err := ctx.NewAuthContext() + if err != nil { + return nil, err + } + defer authcon.Free() + + authent, err := authcon.MakeRequest(cred, 0, nil) + if err != nil { + return nil, err + } + + raw, err := n.encodeRawNoticeWithKey( + ctx, AuthYes, authent, cred.KeyBlock, keyUsageClientCksum) + if err != nil { + return nil, err + } + return raw, nil +} + +// EncodeRawNoticeForClient encodes and authenticates a Notice to send +// to a client using a pre-negotiated session key. Returns a RawNotice. +func (n *Notice) EncodeRawNoticeForClient( + ctx *krb5.Context, authstatus AuthStatus, key *krb5.KeyBlock) (*RawNotice, error) { + return n.encodeRawNoticeWithKey( + ctx, authstatus, nil, key, keyUsageServerCksum) +} + +// EncodeRawNoticeUnauth encodes a Notice into an unauthenticated +// RawNotice. +func (n *Notice) EncodeRawNoticeUnauth() *RawNotice { + // No authenticator or checksum. + raw := n.encodeRawNotice(AuthNo, nil) + raw.HeaderFields[checksumIndex] = []byte{} + return raw +} + +// EncodePacketForServer encodes and authenticates a Notice to be sent +// to a server. Returns a packet. +func (n *Notice) EncodePacketForServer( + ctx *krb5.Context, cred *krb5.Credential) ([]byte, error) { + raw, err := n.EncodeRawNoticeForServer(ctx, cred) + if err != nil { + return nil, err + } + return raw.EncodePacket(), nil +} + +// EncodePacketForClient encodes and authenticates a Notice to be sent +// to a client. Returns a packet. +func (n *Notice) EncodePacketForClient( + ctx *krb5.Context, authstatus AuthStatus, key *krb5.KeyBlock) ([]byte, error) { + raw, err := n.EncodeRawNoticeForClient(ctx, authstatus, key) + if err != nil { + return nil, err + } + return raw.EncodePacket(), nil +} + +// EncodePacketUnauth encodes a Notice into an unauthenticated packet. +func (n *Notice) EncodePacketUnauth() []byte { + return n.EncodeRawNoticeUnauth().EncodePacket() +} + +// MakeACK creates an acknowledgement notice for this message. Like +// libzephyr, this preserves the non-body header fields in the +// notice. This is not necessary for a CLIENTACK but is necessary for +// a SERVACK; clients like BarnOwl don't bother associating SERVACKs +// with outgoing messages and just trust the values in the SERVACK. +// +// TODO(davidben): Unlike libzephyr, this doesn't preserve the checksum and +// authenticator. This should be fine. In fact, we can even authenticate ACKs +// without breaking anything. +func (n *Notice) MakeACK(kind Kind, body string) *Notice { + ack := *n + ack.Kind = kind + ack.RawBody = []byte(body) + return &ack +} diff --git a/notice_test.go b/notice_test.go new file mode 100644 index 0000000..29f574e --- /dev/null +++ b/notice_test.go @@ -0,0 +1,236 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "bytes" + "net" + "reflect" + "testing" + "time" +) + +func expectHeadersEqual(t *testing.T, a *Header, b *Header) { + if a.Kind != b.Kind { + t.Errorf("Kind = %v; want %v", a.Kind, b.Kind) + } + if !a.UID.Equal(b.UID) { + t.Errorf("UID = %v; want %v", a.UID, b.UID) + } + if a.Port != b.Port { + t.Errorf("Port = %v; want %v", a.Port, b.Port) + } + if a.Class != b.Class { + t.Errorf("Class = %v; want %v", a.Class, b.Class) + } + if a.Instance != b.Instance { + t.Errorf("Instance = %v; want %v", a.Instance, b.Instance) + } + if a.OpCode != b.OpCode { + t.Errorf("OpCode = %v; want %v", a.OpCode, b.OpCode) + } + if a.Sender != b.Sender { + t.Errorf("Sender = %v; want %v", a.Sender, b.Sender) + } + if a.Recipient != b.Recipient { + t.Errorf("Recipient = %v; want %v", a.Recipient, b.Recipient) + } + if a.DefaultFormat != b.DefaultFormat { + t.Errorf("DefaultFormat = %v; want %v", a.DefaultFormat, b.DefaultFormat) + } + if !a.SenderAddress.Equal(b.SenderAddress) { + t.Errorf("SenderAddress = %v; want %v", a.SenderAddress, b.SenderAddress) + } + if a.Charset != b.Charset { + t.Errorf("Charset = %v; want %v", a.Charset, b.Charset) + } + if !reflect.DeepEqual(a.OtherFields, b.OtherFields) { + t.Errorf("OtherFields = %v; want %v", a.OtherFields, b.OtherFields) + } +} + +func expectNoticesEqual(t *testing.T, a *Notice, b *Notice) { + expectHeadersEqual(t, &a.Header, &b.Header) + if a.Multipart != b.Multipart { + t.Errorf("Multipart = %v; want %v", a.Multipart, b.Multipart) + } + if !a.MultiUID.Equal(b.MultiUID) { + t.Errorf("MultiUID = %v; want %v", a.MultiUID, b.MultiUID) + } + if !bytes.Equal(a.RawBody, b.RawBody) { + t.Errorf("RawBody = %q; want %q", a.RawBody, b.RawBody) + } +} + +func TestDecodeUID(t *testing.T) { + field := "0x1209400D 0x532A6AC4 0x0005385B" + uid := stringToUID("\x12\x09\x40\x0D\x53\x2A\x6A\xC4\x00\x05\x38\x5B") + if out, err := decodeUID([]byte(field)); err != nil { + t.Errorf("decodeUID(%q) failed: %v", field, err) + } else if string(out[:]) != string(uid[:]) { + t.Errorf("decodeUID(%q) = %v; want %v", field, &uid, out) + } + + if _, err := decodeUID([]byte("0x1209400D")); err != ErrBadField { + t.Errorf("decodeUID(%q) gave bad error: %v", "0x1209400D", err) + } + + if _, err := decodeUID([]byte("?" + field[1:])); err == nil { + t.Errorf("decodeUID(%q) unexpected succeeded", "bogus") + } +} + +func TestUID(t *testing.T) { + uid := stringToUID("\x12\x09\x40\x0D\x53\x2A\x6A\xC4\x00\x05\x38\x5B") + + if ip := uid.IP(); !ip.Equal(net.ParseIP("18.9.64.13")) { + t.Errorf("uid.IP() = %v; want 18.9.64.13", ip) + } + + expectedTime := time.Unix(0x532A6AC4, 0x0005385B*1000) + if time := uid.Time(); !time.Equal(expectedTime) { + t.Errorf("uid.Time() = %v; want %v", time, expectedTime) + } + + if uid2 := MakeUID(uid.IP(), uid.Time()); uid2 != uid { + t.Errorf("MakeUID() = %v; want %v", uid2, uid) + } +} + +func TestDecodeNotice(t *testing.T) { + // Test that the raw notice decodes as expected. + raw := sampleRawNotice() + expected := sampleNotice() + if notice, err := DecodeRawNotice(raw); err != nil { + t.Errorf("DecodeRawNotice(%v) failed: %v", raw, err) + } else { + expectNoticesEqual(t, notice, expected) + } + + // Value in sender address takes precedence over UID value. + raw.HeaderFields[17] = []byte("Z\x08\x08\x08\x08") + expected.SenderAddress = net.ParseIP("8.8.8.8").To4() + if notice, err := DecodeRawNotice(raw); err != nil { + t.Errorf("DecodeRawNotice(%v) failed: %v", raw, err) + } else { + expectNoticesEqual(t, notice, expected) + } + + raw = sampleRawNotice() + expected = sampleNotice() + + // No charset. + raw.HeaderFields = raw.HeaderFields[0:18] + expected.Charset = CharsetUnknown + if notice, err := DecodeRawNotice(raw); err != nil { + t.Errorf("DecodeRawNotice(%v) failed: %v", raw, err) + } else { + expectNoticesEqual(t, notice, expected) + } + + // No sender address specified. Still get an IP address from the UID. + raw.HeaderFields = raw.HeaderFields[0:17] + if notice, err := DecodeRawNotice(raw); err != nil { + t.Errorf("DecodeRawNotice(%v) failed: %v", raw, err) + } else { + expectNoticesEqual(t, notice, expected) + } + + // No multiuid. + raw.HeaderFields = raw.HeaderFields[0:16] + expected.MultiUID = expected.UID + expected.Multipart = "" + if notice, err := DecodeRawNotice(raw); err != nil { + t.Errorf("DecodeRawNotice(%v) failed: %v", raw, err) + } else { + expectNoticesEqual(t, notice, expected) + } + + // No multipart. + raw.HeaderFields = raw.HeaderFields[0:15] + if notice, err := DecodeRawNotice(raw); err != nil { + t.Errorf("DecodeRawNotice(%v) failed: %v", raw, err) + } else { + expectNoticesEqual(t, notice, expected) + } + + // Extra fields. + raw = sampleRawNotice() + expected = sampleNotice() + raw.HeaderFields = append(raw.HeaderFields, []byte("extra")) + expected.OtherFields = [][]byte{[]byte("extra")} + if notice, err := DecodeRawNotice(raw); err != nil { + t.Errorf("DecodeRawNotice(%v) failed: %v", raw, err) + } else { + expectNoticesEqual(t, notice, expected) + } + + // Test some bad packets. + indices := []int{2, 3, 4, 16, 17, 18} + for _, idx := range indices { + raw := sampleRawNotice() + raw.HeaderFields[idx] = []byte("bogus") + if _, err := DecodeRawNotice(raw); err == nil { + t.Errorf("Bad header %d unexpectedly succeeded", idx) + } + } + + // IP parses but has a bad length. + raw = sampleRawNotice() + raw.HeaderFields[17] = []byte("Zabc") + if _, err := DecodeRawNotice(raw); err == nil { + t.Errorf("Short IP unexpectedly succeeded") + } +} + +func zeroByteSlice(b []byte) { + for i := range b { + b[i] = 0 + } +} + +func TestDecodeNoticeAliasing(t *testing.T) { + // Test that DecodeRawNotice's result doesn't alias the input. + raw := sampleRawNotice() + expected := sampleNotice() + notice, err := DecodeRawNotice(raw) + if err != nil { + t.Errorf("DecodeRawNotice(%v) failed: %v", raw, err) + } else { + expectNoticesEqual(t, notice, expected) + } + + for _, h := range raw.HeaderFields { + zeroByteSlice(h) + } + zeroByteSlice(raw.Body) + + expectNoticesEqual(t, notice, expected) +} + +func TestEncRawNoticeUnauth(t *testing.T) { + raw := sampleRawNotice() + // AuthNo = 0 + raw.HeaderFields[5] = []byte("0x00000000") + // No authenticator. + raw.HeaderFields[6] = []byte("0x00000000") + raw.HeaderFields[7] = []byte("") + // No checksum. + raw.HeaderFields[14] = []byte("") + + if enc := sampleNotice().EncodeRawNoticeUnauth(); !reflect.DeepEqual(enc, raw) { + t.Errorf("EncodeRawNoticeUnauth()\n = %v\nwant %v", enc, raw) + } +} diff --git a/raw_notice.go b/raw_notice.go new file mode 100644 index 0000000..7c37603 --- /dev/null +++ b/raw_notice.go @@ -0,0 +1,386 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "bytes" + "errors" + "strconv" + "strings" + + "github.com/zephyr-im/krb5-go" +) + +// AuthStatus is the result of authenticating a notice. +type AuthStatus uint32 + +const ( + // AuthFailed describes a notice which failed authentication + // for some reason. + AuthFailed AuthStatus = 0xffffffff + // AuthYes describes an authenticated notice. + AuthYes AuthStatus = 1 + // AuthNo describes a notice which did not claim to be + // authenticated. + AuthNo AuthStatus = 0 +) + +func (as AuthStatus) String() string { + switch as { + case AuthFailed: + return "AuthFailed" + case AuthYes: + return "AuthYes" + case AuthNo: + return "AuthNo" + default: + return strconv.FormatUint(uint64(as), 10) + } +} + +// ErrBadVersionFormat is returned when a zephyr version field cannot +// be parsed. +var ErrBadVersionFormat = errors.New("bad version format") + +const zephyrVersionHeader = "ZEPH" + +// ProtocolVersionMajor and ProtocolVersionMinor are the version of +// the zephyr protocol implemented by this library. +const ( + ProtocolVersionMajor = 0 + ProtocolVersionMinor = 2 +) + +func parseZephyrVersion(version string) (uint, uint, error) { + if !strings.HasPrefix(version, zephyrVersionHeader) { + return 0, 0, ErrBadVersionFormat + } + split := strings.SplitN(version[len(zephyrVersionHeader):], ".", 2) + if len(split) != 2 { + return 0, 0, ErrBadVersionFormat + } + major, err := strconv.ParseUint(split[0], 10, 0) + if err != nil { + return 0, 0, err + } + minor, err := strconv.ParseUint(split[1], 10, 0) + if err != nil { + return 0, 0, err + } + return uint(major), uint(minor), nil +} + +func formatZephyrVersion(major, minor uint) string { + return zephyrVersionHeader + + strconv.FormatUint(uint64(major), 10) + + "." + + strconv.FormatUint(uint64(minor), 10) +} + +// Processing a notice is done in three stages: +// +// - First, we split it up into raw fields and do only basic validation. Just +// enough to extract the checksum, authenticator, and validate things. This stage +// gives the RawNotice type. +// +// - Second, we decode the various fields and give back a logical notice. That +// gives a Notice. +// +// - Third, we process MultiUID and reassemble sharded notices. Tentatively, this +// will reuse the Notice struct as the only real difference is MultiUID, but +// we'll see. +// +// Serializing a notice goes in reverse. +// +// TODO(davidben): When serializing, who does UID allocation, the +// library or the user? If the library, it's awkward that the type is +// in there. Perhaps we want a couple more types. The reassembly logic +// could return a tuple of uid, message or so. + +// Field layout +const ( + versionIndex = iota // string + numfieldsIndex // zascii32 + kindIndex // zascii32 + uidIndex // 12-byte zascii + portIndex // zascii16 + authstatusIndex // zascii32 + authlenIndex // zascii32 + authenticatorIndex // zcode + classIndex // string + instanceIndex // string + opcodeIndex // string + senderIndex // string + recipientIndex // string + defaultformatIndex // string + checksumIndex // zcode + // Added in 1988; ZEPHYR0.2 + multipartIndex // string + multiuidIndex // 12-byte zascii + // Added in 2009; no version bump + senderSockaddrIndex // zcode + charsetIndex // zascii16 little-endian + // Other fields + numKnownFields +) + +const numRequiredFields = checksumIndex + 1 + +// libzephyr does this awkward thing where it, for purposes of +// authentication checking, it assumes that everything is pointers +// into the z_packet field and does C-style pointer dancing. I'd kinda +// like the intermediate formats to not make assumptions like that, so +// instead we'll use Split/Join being reversible. Zephyr isn't +// terribly well-layered. +// +// Note that this does have one subtlety: we do NOT allow a missing +// body. libzephyr never produces this, but when parsing, it doesn't +// distinguish between +// +// ZEPH0.2 NUL 0x0000003 NUL blahotherfield +// ZEPH0.2 NUL 0x0000003 NUL blahotherfield NUL +// +// (Ignore that this notice doesn't pass our minimum field count +// rules.) In the former, we have three header fields and a missing +// body. In the latter, we have no body. This is relevant because we +// need to be able to reconstruct the concatenation of 0-13 and 15-end +// for checksumming. If this becomes an issue, do something inane like +// treat a nil Body as different. +// +// Delimiter-based serializations. They're the worst. + +// A RawNotice is the first stage of processing a packet. The +// individual header fields are parsed out to extract a checksum and +// authenticator. The other fields are uninterpreted. +type RawNotice struct { + HeaderFields [][]byte + Body []byte +} + +// ErrBadPacketFormat is returned when parsing a malformed packet. +var ErrBadPacketFormat = errors.New("bad packet format") + +// ErrBadPacketFieldCount is returned when parsing a packet with a +// field count that does not match the content. +var ErrBadPacketFieldCount = errors.New("bad field count") + +// ErrBadPacketVersion is returned when parsing a packet with an +// incompatible version field. +var ErrBadPacketVersion = errors.New("incompatible packet version") + +// DecodePacket records a packet into a RawNotice. +func DecodePacket(packet []byte) (*RawNotice, error) { + // First, split out the version and field count. + fs := bytes.SplitN(packet, []byte{0}, 3) + + // We better have at least those fields... + if len(fs) < 3 { + return nil, ErrBadPacketFormat + } + vers, numFieldsRaw, rest := fs[0], fs[1], fs[2] + + // Like libzephyr, the minor version is ignored in parsing. + if major, _, err := parseZephyrVersion(string(vers)); err != nil { + return nil, err + } else if major != ProtocolVersionMajor { + return nil, ErrBadPacketVersion + } + + // Decode the field count. + numFields, err := DecodeZAscii32(numFieldsRaw) + if err != nil { + return nil, err + } + // Pfft. + numFieldsInt := int(numFields) + + // Sanity check; just so we can't be made to allocate giant things or + // something? Meh. Also require at least 15 fields (ZEPH0.1) so there's + // a checksum. + if numFieldsInt > len(packet) || numFieldsInt < numRequiredFields { + return nil, ErrBadPacketFieldCount + } + + fields := make([][]byte, 0, numFields) + fields = append(fields, vers) + fields = append(fields, numFieldsRaw) + + // Parse the remaining fields. Subtract 2 for version and numfields. Add + // 1 for the remainder (the body). + rs := bytes.SplitN(rest, []byte{0}, numFieldsInt-2+1) + if len(rs) != numFieldsInt-2+1 { + return nil, ErrBadPacketFieldCount + } + + // And assemble the RawNotice. + fields = append(fields, rs[0:len(rs)-1]...) + if len(fields) != numFieldsInt { + panic(len(fields)) + } + body := rs[len(rs)-1] + return &RawNotice{fields, body}, nil +} + +// ErrAuthenticatorLengthMismatch is returned when processing the +// authenticator on a RawNotice where the authlen field does not match +// the length of the decoded authenticator. +var ErrAuthenticatorLengthMismatch = errors.New("authenticator length mismatch") + +// DecodeAuthenticator decodes the authenticator field of a RawNotice. +func (r *RawNotice) DecodeAuthenticator() ([]byte, error) { + // There's this length field. It's completely bogus, but may + // as well assert that it's right? Be lenient if this causes + // trouble. + authlen, err := DecodeZAscii32(r.HeaderFields[authlenIndex]) + if err != nil { + return nil, err + } + + // This used to be zephyrascii, but krb4 zephyr stopped + // working ages ago. + auth, err := DecodeZcode(r.HeaderFields[authenticatorIndex]) + if err != nil { + return nil, err + } + if len(auth) != int(authlen) { + return nil, ErrAuthenticatorLengthMismatch + } + return auth, nil +} + +// DecodeChecksum decodes the checksum field of a RawNotice. +func (r *RawNotice) DecodeChecksum() ([]byte, error) { + return DecodeZcode(r.HeaderFields[checksumIndex]) +} + +// DecodeAuthStatus decodes the authstate field of a RawNotice. +func (r *RawNotice) DecodeAuthStatus() (AuthStatus, error) { + authStatus, err := DecodeZAscii32(r.HeaderFields[authstatusIndex]) + if err != nil { + return AuthFailed, err + } + return AuthStatus(authStatus), nil +} + +// ChecksumPayload returns the portion of the packet that is +// checksumed. (The checksum itself is removed and the remainder is +// concatenated.) +func (r *RawNotice) ChecksumPayload() []byte { + // The part of the packet that's checksummed is really quite + // absurd, but here we go. + parts := make([][]byte, 0, len(r.HeaderFields)) + // Fields before the checkum. + parts = append(parts, r.HeaderFields[0:checksumIndex]...) + // Fields after the checksum. + parts = append(parts, r.HeaderFields[checksumIndex+1:]...) + // Body. + parts = append(parts, r.Body) + return bytes.Join(parts, []byte{0}) +} + +func (r *RawNotice) checkAuth( + ctx *krb5.Context, + key *krb5.KeyBlock, + usage int32, +) (AuthStatus, error) { + sumtype, err := defaultSumTypeForEncType(key.EncType) + if err != nil { + return AuthFailed, err + } + + checksumData, err := r.DecodeChecksum() + if err != nil { + return AuthFailed, err + } + + checksum := &krb5.Checksum{sumtype, checksumData} + + result, err := ctx.VerifyChecksum(key, usage, r.ChecksumPayload(), checksum) + if err != nil { + return AuthFailed, err + } else if !result { + return AuthFailed, nil + } else { + return AuthYes, nil + } +} + +// CheckAuthFromServer is called by a client to check a packet from +// the server using a previously negotiated key. +func (r *RawNotice) CheckAuthFromServer( + ctx *krb5.Context, + key *krb5.KeyBlock, +) (AuthStatus, error) { + if authStatus, err := r.DecodeAuthStatus(); err != nil { + return AuthFailed, err + } else if authStatus != AuthYes { + return authStatus, nil + } + + return r.checkAuth(ctx, key, keyUsageServerCksum) +} + +// CheckAuthFromClient is called by a server to check a packet from a +// client using a server KeyTab. If successful, the session key from +// the client's authenticator is returned. +func (r *RawNotice) CheckAuthFromClient( + ctx *krb5.Context, + service *krb5.Principal, + keytab *krb5.KeyTab, +) (AuthStatus, *krb5.KeyBlock, error) { + if authStatus, err := r.DecodeAuthStatus(); err != nil { + return AuthFailed, nil, err + } else if authStatus != AuthYes { + return authStatus, nil, nil + } + + authcon, err := ctx.NewAuthContext() + if err != nil { + return AuthFailed, nil, err + } + defer authcon.Free() + authcon.SetUseTimestamps(false) + + authent, err := r.DecodeAuthenticator() + if err != nil { + return AuthFailed, nil, err + } + + if err := authcon.ReadRequest(authent, service, keytab); err != nil { + return AuthFailed, nil, err + } + key, err := authcon.SessionKey() + if err != nil { + return AuthFailed, nil, err + } + + auth, err := r.checkAuth(ctx, key, keyUsageClientCksum) + if err != nil { + return AuthFailed, nil, err + } + return auth, key, nil +} + +// EncodePacket encodes a RawNotice as a packet. If it is to be +// authenticated, the checksum and (if a client) authenticator fields +// must be already populated. +func (r *RawNotice) EncodePacket() []byte { + // This function does not check that r.HeaderFields[1] is correct or + // anything. The caller is expected to provide a legal RawNotice. + parts := make([][]byte, 0, len(r.HeaderFields)+1) + parts = append(parts, r.HeaderFields...) + parts = append(parts, r.Body) + return bytes.Join(parts, []byte{0}) +} diff --git a/raw_notice_test.go b/raw_notice_test.go new file mode 100644 index 0000000..c2d6b97 --- /dev/null +++ b/raw_notice_test.go @@ -0,0 +1,344 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "bytes" + "reflect" + "strings" + "testing" + + "github.com/zephyr-im/krb5-go" + "github.com/zephyr-im/krb5-go/krb5test" +) + +func makeServerContextAndKeyTab(t *testing.T) (*krb5.Context, *krb5.KeyTab) { + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + keytab, err := krb5test.MakeServerKeyTab(ctx) + if err != nil { + t.Fatal(err) + } + return ctx, keytab +} + +func TestParseZephyrVersion(t *testing.T) { + if major, minor, err := parseZephyrVersion("ZEPH0.2"); err != nil { + t.Errorf("parseZephyrVersion(\"ZEPH0.2\") failed: %v", err) + } else if major != 0 || minor != 2 { + t.Errorf("parseZephyrVersion(\"ZEPH0.2\") = %v, %v; want %v, %v", + major, minor, 0, 2) + } + + if vers := formatZephyrVersion(0, 2); vers != "ZEPH0.2" { + t.Errorf("formatZephyrVersion(0, 2) = %v; want \"ZEPH0.2\"", vers) + } + + badVersion := []string{ + "BOGUS4.2", + "ZEPHYR0.2", + "ZEPH0.1.2", + "zeph0.2", + "ZEPHa.b", + "ZEPH", + "ZEPH-1.-2", + } + for _, version := range badVersion { + if major, minor, err := parseZephyrVersion(version); err == nil { + t.Errorf("parseZephyrVersion(%q) unexpectedly succeeded: %v, %v", + version, major, minor) + } + } +} + +// Test some completely bogus packets. +func TestDecodeBadPackets(t *testing.T) { + type bad struct { + pkt []byte + err error + } + badPackets := []bad{ + // Bogus. + bad{[]byte(""), ErrBadPacketFormat}, + // Bad version. + bad{[]byte("Blah\x00Blah\x00Blah"), ErrBadVersionFormat}, + // Major version mismatch. + bad{[]byte("ZEPH1.2\x000x00000002\x00hi"), ErrBadPacketVersion}, + // Bad field count zephyrascii. + bad{[]byte("ZEPH0.2\x00two\x00hi"), nil}, + // Too few fields. + bad{[]byte("ZEPH0.2\x000x00000002\x00hi"), ErrBadPacketFieldCount}, + // Too few fields. + bad{makeTestPacket(1, "0x0000000E"), ErrBadPacketFieldCount}, + // Field count too high. + bad{makeTestPacket(1, "0x00000020"), ErrBadPacketFieldCount}, + // Giant field count. + bad{makeTestPacket(1, "0xFFFFFFFF"), ErrBadPacketFieldCount}, + } + for _, test := range badPackets { + if raw, err := DecodePacket([]byte(test.pkt)); err == nil { + t.Errorf("DecodePacket(%q) = %v; want error", test.pkt, raw) + } else if test.err != nil && err != test.err { + t.Errorf("DecodePacket(%q) failed with %v; want %v", + test.pkt, err, test.err) + } + } +} + +func TestDecodePacket(t *testing.T) { + pkt := samplePacket() + expectedRaw := sampleRawNotice() + if raw, err := DecodePacket(pkt); err != nil { + t.Errorf("DecodePacket(%q) failed: %v", string(pkt), err) + } else if !reflect.DeepEqual(raw, expectedRaw) { + t.Errorf("DecodePacket(%q) = %v\nwant %v", string(pkt), raw, expectedRaw) + } + + // Packets with 15 fields are okay. + pkt = makeTestPacket(1, "0x0000000F") + if _, err := DecodePacket(pkt); err != nil { + t.Errorf("DecodePacket(%q) failed: %v", string(pkt), err) + } + + // Packets with extra fields are okay. + pkt = makeTestPacket(1, "0x00000014") + if _, err := DecodePacket(pkt); err != nil { + t.Errorf("DecodePacket(%q) failed: %v", string(pkt), err) + } +} + +func TestDecodeAuthenticator(t *testing.T) { + raw := sampleRawNotice() + raw.HeaderFields[6] = []byte("0x00000004") + raw.HeaderFields[7] = []byte("Ztest") + + if auth, err := raw.DecodeAuthenticator(); err != nil { + t.Errorf("raw.DecodeAuthenticator() failed: %v", err) + } else if string(auth) != "test" { + t.Errorf("raw.DecodeAuthenticator() = %q; want %q", auth, "test") + } + + // Bogus length. + raw.HeaderFields[6] = []byte("bogus") + if _, err := raw.DecodeAuthenticator(); err == nil { + t.Errorf("raw.DecodeAuthenticator() unexpected succeeded") + } + + // Mismatch length. + raw.HeaderFields[6] = []byte("0xDEADBEEF") + if _, err := raw.DecodeAuthenticator(); err == nil { + t.Errorf("raw.DecodeAuthenticator() unexpected succeeded") + } else if err != ErrAuthenticatorLengthMismatch { + t.Errorf("raw.DecodeAuthenticator() gave the wrong error: %v", err) + } + + // Bad zcode. + raw.HeaderFields[7] = []byte("notvalidzcode") + if _, err := raw.DecodeAuthenticator(); err == nil { + t.Errorf("raw.DecodeAuthenticator() unexpected succeeded") + } +} + +func TestDecodeChecksum(t *testing.T) { + raw := sampleRawNotice() + + if cksum, err := raw.DecodeChecksum(); err != nil { + t.Errorf("raw.DecodeChecksum() failed: %v", err) + } else if !bytes.Equal(cksum, sampleChecksum()) { + t.Errorf("raw.DecodeChecksum() = %q; want %q", cksum, sampleChecksum()) + } + + // Bad zcode. + raw.HeaderFields[14] = []byte("notvalidzcode") + if _, err := raw.DecodeChecksum(); err == nil { + t.Errorf("raw.DecodeChecksum() unexpected succeeded") + } +} + +func TestDecodeAuthStatus(t *testing.T) { + raw := sampleRawNotice() + + if auth, err := raw.DecodeAuthStatus(); err != nil { + t.Errorf("raw.DecodeAuthStatus() failed: %v", err) + } else if auth != AuthYes { + t.Errorf("raw.DecodeAuthStatus() = %q; want %q", auth, AuthYes) + } + + // Bad authstatus. + raw.HeaderFields[5] = []byte("notvalidzascii") + if _, err := raw.DecodeAuthStatus(); err == nil { + t.Errorf("raw.DecodeAuthStatus() unexpected succeeded") + } +} + +func TestEncodePacket(t *testing.T) { + raw := sampleRawNotice() + expected := string(samplePacket()) + if enc := string(raw.EncodePacket()); enc != expected { + t.Errorf("raw.EncodePacket() = %q; want %q", enc, expected) + } +} + +func TestChecksumPayload(t *testing.T) { + raw := sampleRawNotice() + expected := strings.Replace(string(samplePacket()), + string(sampleChecksumZcode())+"\x00", "", 1) + if enc := string(raw.ChecksumPayload()); enc != expected { + t.Errorf("raw.ChecksumPayload() = %q; want %q", enc, expected) + } +} + +func TestCheckAuthFromServer(t *testing.T) { + raw := sampleRawNotice() + key := sampleKeyBlock() + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + + if auth, err := raw.CheckAuthFromServer(ctx, key); err != nil { + t.Errorf("raw.CheckAuthFromServer() failed: %v", err) + } else if auth != AuthYes { + t.Errorf("raw.CheckAuthFromServer() = %v; want %v", auth, AuthYes) + } + + // Break some random thing. + raw.HeaderFields[0] = []byte("moooo") + if auth, err := raw.CheckAuthFromServer(ctx, key); err != nil { + t.Errorf("raw.CheckAuthFromServer() failed: %v", err) + } else if auth != AuthFailed { + t.Errorf("raw.CheckAuthFromServer() = %v; want %v", auth, AuthFailed) + } + + // Doesn't claim authentication. + raw = sampleRawNotice() + raw.HeaderFields[5] = []byte("0x00000000") + if auth, err := raw.CheckAuthFromServer(ctx, key); err != nil { + t.Errorf("raw.CheckAuthFromServer() failed: %v", err) + } else if auth != AuthNo { + t.Errorf("raw.CheckAuthFromServer() = %v; want %v", auth, AuthNo) + } + + // Bogus authstatus. + raw.HeaderFields[5] = []byte("bogus") + if _, err := raw.CheckAuthFromServer(ctx, key); err == nil { + t.Errorf("raw.CheckAuthFromServer() unexpectedly ran") + } + + // Checksum length error. + raw = sampleRawNotice() + raw.HeaderFields[14] = []byte("Zasdf") + if _, err := raw.CheckAuthFromServer(ctx, key); err == nil { + t.Errorf("raw.CheckAuthFromServer() unexpectedly ran") + } + + // Bogus checksum. + raw.HeaderFields[14] = []byte("bogus") + if _, err := raw.CheckAuthFromServer(ctx, key); err == nil { + t.Errorf("raw.CheckAuthFromServer() unexpectedly ran") + } +} + +func TestClientToServerAuth(t *testing.T) { + // The client makes the notice. + clientCtx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer clientCtx.Free() + notice := sampleNotice() + raw, err := notice.EncodeRawNoticeForServer( + clientCtx, krb5test.Credential()) + if err != nil { + t.Fatalf("notice.EncodeRawNoticeForServer failed: %v", err) + } + + // Server checks the notice. + serverCtx, keytab := makeServerContextAndKeyTab(t) + defer serverCtx.Free() + defer keytab.Close() + auth, key, err := raw.CheckAuthFromClient(serverCtx, krb5test.Service(), keytab) + if err != nil { + t.Errorf("CheckAuthFromClient failed: %v", err) + } else { + if auth != AuthYes { + t.Errorf("CheckAuthFromClient returned %v; want AuthYes", auth) + } + if !reflect.DeepEqual(key, krb5test.SessionKey()) { + t.Errorf("CheckAuthFromClient return %v; want %v", + key, krb5test.SessionKey()) + } + } + + // Perturb the checksum a little. It should fail now. + cksum, err := raw.DecodeChecksum() + if err != nil { + t.Fatal(err) + } + cksum[0]++ + raw.HeaderFields[14] = EncodeZcode(cksum) + auth, _, err = raw.CheckAuthFromClient(serverCtx, krb5test.Service(), keytab) + if err != nil { + t.Errorf("CheckAuthFromClient failed: %v", err) + } else if auth != AuthFailed { + t.Errorf("CheckAuthFromClient returned %v; want AuthFailed", auth) + } + + // Malformed checksums also fail. + raw.HeaderFields[14] = []byte("bogus") + auth, _, _ = raw.CheckAuthFromClient(serverCtx, krb5test.Service(), keytab) + if auth != AuthFailed { + t.Errorf("CheckAuthFromClient returned %v; want AuthFailed", auth) + } + + // An unauthenticated packet should return unauthenticated. + raw = notice.EncodeRawNoticeUnauth() + auth, _, err = raw.CheckAuthFromClient(serverCtx, krb5test.Service(), keytab) + if err != nil { + t.Errorf("CheckAuthFromClient failed: %v", err) + } else if auth != AuthNo { + t.Errorf("CheckAuthFromClient returned %v; want AuthNo", auth) + } + + // Malformed authstatus fields error. + raw.HeaderFields[5] = []byte("bogus") + auth, _, _ = raw.CheckAuthFromClient(serverCtx, krb5test.Service(), keytab) + if auth != AuthFailed { + t.Errorf("CheckAuthFromClient returned %v; want AuthFailed", auth) + } + + // Bad authenticator fails. + raw, err = notice.EncodeRawNoticeForServer( + clientCtx, krb5test.Credential()) + if err != nil { + t.Fatalf("notice.EncodeRawNoticeForServer failed: %v", err) + } + raw.HeaderFields[6] = []byte("0x00000005") + raw.HeaderFields[7] = []byte("Z12345") + auth, _, _ = raw.CheckAuthFromClient(serverCtx, krb5test.Service(), keytab) + if auth != AuthFailed { + t.Errorf("CheckAuthFromClient returned %v; want AuthFailed", auth) + } + + // Malformed authenticator fails. + raw.HeaderFields[7] = []byte("Bogus") + auth, _, _ = raw.CheckAuthFromClient(serverCtx, krb5test.Service(), keytab) + if auth != AuthFailed { + t.Errorf("CheckAuthFromClient returned %v; want AuthFailed", auth) + } +} diff --git a/reader.go b/reader.go new file mode 100644 index 0000000..700f2e1 --- /dev/null +++ b/reader.go @@ -0,0 +1,144 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "io" + "log" + "net" + "time" + + "github.com/zephyr-im/krb5-go" +) + +// MaxPacketLength is the maximum size of a zephyr notice on the wire. +const MaxPacketLength = 1024 + +// A RawReaderResult is an output of a ReadRawNotices call. It either +// contains a RawNotice and a source address or an error. +type RawReaderResult struct { + RawNotice *RawNotice + Addr net.Addr +} + +// ReadRawNotices decodes packets from a PacketConn into RawNotices +// and returns a stream of them. Non-fatal errors are returned through +// the stream. On a fatal error or EOF, the channel is closed. +func ReadRawNotices(conn net.PacketConn, logger *log.Logger) <-chan RawReaderResult { + sink := make(chan RawReaderResult) + go readRawNoticeLoop(conn, logger, sink) + return sink +} + +func readRawNoticeLoop( + conn net.PacketConn, + logger *log.Logger, + sink chan<- RawReaderResult, +) { + defer close(sink) + var buf [MaxPacketLength]byte + var tempDelay time.Duration + for { + n, addr, err := conn.ReadFrom(buf[:]) + if err != nil { + // Send the error out to the consumer. + if err != io.EOF { + logPrintf(logger, "Error reading packet: %v\n", err) + } + if ne, ok := err.(net.Error); ok && ne.Temporary() { + // Delay logic from net/http.Serve. + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + time.Sleep(tempDelay) + continue + } + break + } + tempDelay = 0 + + // Copy the packet so we can reuse the buffer. + raw, err := DecodePacket(copyByteSlice(buf[0:n])) + if err != nil { + logPrintf(logger, "Error decoding notice: %v\n", err) + continue + } + sink <- RawReaderResult{raw, addr} + } +} + +// A NoticeReaderResult is an output of a ReadNoticesFromServer +// call. It either contains a notice with authentication status and +// source address or an error. +type NoticeReaderResult struct { + Notice *Notice + AuthStatus AuthStatus + Addr net.Addr +} + +// ReadNoticesFromServer decodes and authenticates notices sent from +// the server. Returns a channel containing authenticated notices and +// errors. The channel is closed on fatal errors. If key is nil, all +// notices appear as AuthFailed. +func ReadNoticesFromServer( + conn net.PacketConn, + key *krb5.KeyBlock, + logger *log.Logger, +) <-chan NoticeReaderResult { + // TODO(davidben): Should this channel be buffered a little? + sink := make(chan NoticeReaderResult) + go readNoticeLoop(ReadRawNotices(conn, logger), key, logger, sink) + return sink +} + +func readNoticeLoop( + rawReader <-chan RawReaderResult, + key *krb5.KeyBlock, + logger *log.Logger, + sink chan<- NoticeReaderResult, +) { + defer close(sink) + ctx, err := krb5.NewContext() + if err != nil { + logPrintf(logger, "Error creating krb5 context: %v", err) + return + } + defer ctx.Free() + for r := range rawReader { + notice, err := DecodeRawNotice(r.RawNotice) + if err != nil { + logPrintf(logger, "Error parsing notice: %v", err) + continue + } + + authStatus := AuthFailed + if notice.Kind.IsACK() { + // Don't bother; ACKs' auth bits are always lies. + authStatus = AuthNo + } else if key != nil { + authStatus, err = r.RawNotice.CheckAuthFromServer(ctx, key) + if err != nil { + logPrintf(logger, "Error authenticating notice: %v", err) + authStatus = AuthFailed + } + } + sink <- NoticeReaderResult{notice, authStatus, r.Addr} + } +} diff --git a/reader_test.go b/reader_test.go new file mode 100644 index 0000000..89152e6 --- /dev/null +++ b/reader_test.go @@ -0,0 +1,199 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "bufio" + "errors" + "io" + "log" + "net" + "reflect" + "testing" + + "github.com/zephyr-im/krb5-go" + "github.com/zephyr-im/zephyr-go/zephyrtest" +) + +func newTestLogger() (*log.Logger, io.Closer, <-chan string) { + pr, pw := io.Pipe() + l := log.New(pw, "", 0) + c := make(chan string) + go func() { + s := bufio.NewScanner(pr) + for s.Scan() { + c <- s.Text() + } + // Meh. + if err := s.Err(); err != nil { + c <- "Error scanning: " + err.Error() + } + close(c) + }() + return l, pw, c +} + +func expectNoLogs(t *testing.T) (*log.Logger, io.Closer) { + l, closer, c := newTestLogger() + go func() { + for line := range c { + t.Error(line) + } + }() + return l, closer +} + +func TestReadNoticesFromServer(t *testing.T) { + addr1 := &net.UDPAddr{IP: net.IPv4(1, 1, 1, 1), Port: 1111} + addr2 := &net.UDPAddr{IP: net.IPv4(2, 2, 2, 2), Port: 2222} + clientAddr := &net.UDPAddr{IP: net.IPv4(3, 3, 3, 3), Port: 3333} + fatalErr := errors.New("polarity insufficiently reversed") + + type resultOrErr struct { + r *NoticeReaderResult + line string + } + result := func(n *Notice, as AuthStatus, addr net.Addr) resultOrErr { + return resultOrErr{r: &NoticeReaderResult{n, as, addr}} + } + err := func(line string) resultOrErr { + return resultOrErr{line: line} + } + + type test struct { + keyblock *krb5.KeyBlock + reads []zephyrtest.PacketRead + expected []resultOrErr + } + tests := []test{ + // Basic case + { + sampleKeyBlock(), + []zephyrtest.PacketRead{ + {samplePacket(), addr1, nil}, + }, []resultOrErr{ + result(sampleNotice(), AuthYes, addr1), + }, + }, + // Various non-fatal errors. + { + sampleKeyBlock(), + []zephyrtest.PacketRead{ + {nil, nil, zephyrtest.TemporaryError}, + {samplePacket(), addr1, nil}, + {nil, nil, zephyrtest.TemporaryError}, + {nil, nil, zephyrtest.TemporaryError}, + {samplePacket(), addr2, nil}, + {sampleFailPacket(), addr1, nil}, + {sampleMalformedChecksumPacket(), addr2, nil}, + {sampleMalformedPortPacket(), addr1, nil}, + {[]byte("bogus"), addr2, nil}, + {samplePacket(), addr1, nil}, + }, + []resultOrErr{ + err("Error reading packet: Temporary error"), + result(sampleNotice(), AuthYes, addr1), + err("Error reading packet: Temporary error"), + err("Error reading packet: Temporary error"), + // samplePacket + result(sampleNotice(), AuthYes, addr2), + // sampleFailPacket + result(sampleNotice(), AuthFailed, addr1), + // sampleMalformedChecksumPacket + err("Error authenticating notice: invalid zcode"), + result(sampleNotice(), AuthFailed, addr2), + // sampleMalformedPortPacket + err("Error parsing notice: bad length for " + + "uint16 zephyrascii"), + // bogus + err("Error decoding notice: bad packet format"), + // samplePacket + result(sampleNotice(), AuthYes, addr1), + }, + }, + // Stop after fatal error. + { + sampleKeyBlock(), + []zephyrtest.PacketRead{ + {nil, nil, fatalErr}, + {samplePacket(), addr1, nil}, + }, + []resultOrErr{ + err("Error reading packet: polarity " + + "insufficiently reversed"), + }, + }, + // nil key. + { + nil, + []zephyrtest.PacketRead{ + {samplePacket(), addr1, nil}, + }, []resultOrErr{ + result(sampleNotice(), AuthFailed, addr1), + }, + }, + } + for ti, test := range tests { + // Buffer of 1 because one of the tests intentionally + // has an extra read. + readChan := make(chan zephyrtest.PacketRead, 1) + go func() { + for _, read := range test.reads { + readChan <- read + } + close(readChan) + }() + mock := zephyrtest.NewMockPacketConn(clientAddr, readChan) + l, closer, lines := newTestLogger() + out := ReadNoticesFromServer(mock, test.keyblock, l) + for ei, expect := range test.expected { + if expect.r != nil { + if r, ok := <-out; !ok { + t.Errorf("%d.%d. Expected notice: %v", + ti, ei, expect.r) + } else { + expectNoticesEqual(t, r.Notice, expect.r.Notice) + if r.AuthStatus != expect.r.AuthStatus { + t.Errorf("%d.%d. AuthStatus = %v; want %v", + ti, ei, + r.AuthStatus, + expect.r.AuthStatus) + } + if !reflect.DeepEqual(r.Addr, expect.r.Addr) { + t.Errorf("%d.%d. Addr = %v; want %v", + ti, ei, + r.Addr, + expect.r.Addr) + } + } + } else { + if line, ok := <-lines; !ok { + t.Errorf("%d.%d. Expected error: %v", + ti, ei, expect.line) + } else if line != expect.line { + t.Errorf("%d.%d. line = %v; wanted %v", + ti, ei, line, expect.line) + } + } + } + closer.Close() + for line := range lines { + t.Errorf("%d. unexpected line: %v", ti, line) + } + for r := range out { + t.Errorf("%d. unexpected notice: %v", ti, r) + } + } +} diff --git a/reassembly.go b/reassembly.go new file mode 100644 index 0000000..b43892b --- /dev/null +++ b/reassembly.go @@ -0,0 +1,199 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "errors" + "fmt" + "strconv" + "strings" +) + +// MaxMessageBodyLength is the maximum length of a reassembled message +// body. +const MaxMessageBodyLength = 400000 + +// ErrBodyTooLong is returned if the reassembled body is too long. +var ErrBodyTooLong = errors.New("reassembled body too long") + +// ErrBodyLengthMismatch is returned if a notice's body length field +// is incompatible the reassembler being used. +var ErrBodyLengthMismatch = errors.New("reassembled body length mismatch") + +// ErrBodyFragmentOutOfBounds is returned if a notice's body is out of +// bounds of the reassembled message body. +var ErrBodyFragmentOutOfBounds = errors.New("message fragment out of bounds") + +// ParseMultipart parses the multipart field of a notice. It returns +// the part and partof parts of the field. On parse error, it returns +// 0 and the length of the body. +func ParseMultipart(n *Notice) (int, int) { + ss := strings.Split(n.Multipart, "/") + if len(ss) != 2 { + return 0, len(n.RawBody) + } + part, err := strconv.ParseInt(ss[0], 10, 0) + if err != nil || part < 0 { + return 0, len(n.RawBody) + } + partof, err := strconv.ParseInt(ss[1], 10, 0) + if err != nil || partof < 0 { + return 0, len(n.RawBody) + } + if part >= partof { + return 0, len(n.RawBody) + } + return int(part), int(partof) +} + +// EncodeMultipart encodes a pair of integers for the multipart field. +func EncodeMultipart(part, partof int) string { + return fmt.Sprintf("%d/%d", part, partof) +} + +type chunk struct { + offset int + buf []byte +} + +func (c chunk) end() int { + return c.offset + len(c.buf) +} + +// Reassembler maintains state for a reassembled notice. +type Reassembler struct { + length int + // We maintain a list of chunks that are ordered and separated + // by gaps. When completed, there is exactly one chunk. This + // differs from the libzephyr strategy of allocating a buffer + // ahead of time to be slightly less of a DoS vector. + chunks []chunk + header Header + haveHeader bool + // TODO(davidben): This is using libzephyr's behavior. After + // this is working, experiment with just including the + // AuthStatus into the key. The main concern is problems with + // the zhm retransmit bug. + authStatus AuthStatus +} + +// NewReassembler creates a Reassembler for a message with a given +// body length. +func NewReassembler(length int) *Reassembler { + return &Reassembler{length, []chunk{}, Header{}, false, AuthYes} +} + +// NewReassemblerFromMultipartField creates a Reassembler for a given +// notice's multipart field. Note that this does not call AddNotice. +func NewReassemblerFromMultipartField(n *Notice) (*Reassembler, error) { + _, partof := ParseMultipart(n) + if partof > MaxMessageBodyLength { + return nil, ErrBodyTooLong + } + return NewReassembler(partof), nil +} + +// TODO(davidben): Add serialization/deserialization methods for crazy +// fault-tolerant Roost version. + +// Done returns true when the message body has been reassembled. +func (r *Reassembler) Done() bool { + if !r.haveHeader { + return false + } + if r.length == 0 { + return true + } + return len(r.chunks) == 1 && r.chunks[0].offset == 0 && + len(r.chunks[0].buf) == r.length +} + +// Message returns the reassembled message once it is done. +func (r *Reassembler) Message() (*Message, AuthStatus) { + if !r.Done() { + return nil, AuthFailed + } + if r.length == 0 { + return &Message{r.header, []string{""}}, r.authStatus + } + return &Message{r.header, strings.Split(string(r.chunks[0].buf), "\x00")}, r.authStatus +} + +// AddNotice adds a notice into the reassembler state. If the notice +// is incompatible and discarded, it returns an error. +func (r *Reassembler) AddNotice(n *Notice, authStatus AuthStatus) error { + if r.Done() { + return nil + } + + // Check if this notice is compatible. + part, partof := ParseMultipart(n) + if partof != r.length { + return ErrBodyLengthMismatch + } + if part+len(n.RawBody) > r.length { + return ErrBodyFragmentOutOfBounds + } + + // Incorporate the AuthStatus. + if authStatus == AuthFailed { + r.authStatus = AuthFailed + } else if authStatus == AuthNo && r.authStatus != AuthFailed { + r.authStatus = AuthNo + } + + // Copy the header over. + if part == 0 { + r.header = n.Header + r.haveHeader = true + } + + if len(n.RawBody) == 0 { + return nil + } + + // Fill in the new chunks. First we insert our new chunk in order. + ordered := []chunk{} + added := false + for _, c := range r.chunks { + if c.offset > part && !added { + added = true + ordered = append(ordered, chunk{part, n.RawBody}) + } + ordered = append(ordered, c) + } + if !added { + ordered = append(ordered, chunk{part, n.RawBody}) + } + + // Now collapse chunks that touch. + dedup := []chunk{} + for _, c := range ordered { + if len(dedup) == 0 || dedup[len(dedup)-1].end() < c.offset { + dedup = append(dedup, c) + } else { + // Merge c into last by appending the last n + // bytes of c. + last := dedup[len(dedup)-1] + if n := c.end() - last.end(); n > 0 { + dedup[len(dedup)-1] = chunk{ + last.offset, + append(last.buf, c.buf[len(c.buf)-n:]...)} + } + } + } + r.chunks = dedup + return nil +} diff --git a/reassembly_test.go b/reassembly_test.go new file mode 100644 index 0000000..5e35d2d --- /dev/null +++ b/reassembly_test.go @@ -0,0 +1,195 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "fmt" + "reflect" + "testing" +) + +var longMessage []byte + +func init() { + longMessage = []byte{} + for i := 0; i < 256; i++ { + longMessage = append(longMessage, byte(i)) + } +} + +func longMessageChunk(uid, multiuid UID, off, length int) *Notice { + notice := sampleNoticeWithUID(uid) + notice.MultiUID = multiuid + notice.Multipart = fmt.Sprintf("%d/%d", off, len(longMessage)) + notice.RawBody = longMessage[off : off+length] + return notice +} + +func TestParseMultipart(t *testing.T) { + length := len(sampleNotice().RawBody) + tests := []struct { + in string + part int + partof int + }{ + {"0/10", 0, 10}, + {"4/5", 4, 5}, + {"-1/5", 0, length}, + {"bogus", 0, length}, + {"5", 0, length}, + {"10/5", 0, length}, + {"5/5", 0, length}, + {"bogus/5", 0, length}, + {"5/bogus", 0, length}, + {"5/-1", 0, length}, + } + for i, tt := range tests { + notice := sampleNotice() + notice.Multipart = tt.in + part, partof := ParseMultipart(notice) + if part != tt.part || partof != tt.partof { + t.Errorf("%d. ParseMultipart(%q) => %d %d, want %d %d", + i, tt.in, part, partof, tt.part, tt.partof) + } + } +} + +func TestReassembler(t *testing.T) { + expected := sampleMessage(sampleNotice().UID, longMessage) + + type chunkTest struct { + off int + length int + auth AuthStatus + } + tests := []struct { + chunks []chunkTest + auth AuthStatus + }{ + {[]chunkTest{{0, 256, AuthYes}}, AuthYes}, + {[]chunkTest{{0, 128, AuthYes}, {128, 128, AuthYes}}, AuthYes}, + {[]chunkTest{{128, 128, AuthYes}, {0, 128, AuthYes}}, AuthYes}, + {[]chunkTest{{100, 156, AuthYes}, {0, 128, AuthYes}}, AuthYes}, + {[]chunkTest{{0, 128, AuthYes}, {100, 156, AuthYes}}, AuthYes}, + { + []chunkTest{ + {0, 1, AuthYes}, {7, 1, AuthYes}, + {5, 1, AuthYes}, {3, 1, AuthYes}, + {0, 256, AuthYes}, + }, + AuthYes, + }, + { + []chunkTest{ + {5, 0, AuthYes}, {0, 128, AuthYes}, + {128, 0, AuthYes}, {128, 128, AuthYes}, + }, + AuthYes, + }, + {[]chunkTest{{0, 128, AuthYes}, {128, 128, AuthNo}}, AuthNo}, + {[]chunkTest{{0, 128, AuthNo}, {128, 128, AuthYes}}, AuthNo}, + {[]chunkTest{{0, 128, AuthNo}, {128, 128, AuthNo}}, AuthNo}, + {[]chunkTest{{0, 128, AuthYes}, {128, 128, AuthFailed}}, AuthFailed}, + {[]chunkTest{{0, 128, AuthFailed}, {128, 128, AuthYes}}, AuthFailed}, + {[]chunkTest{{0, 128, AuthNo}, {128, 128, AuthFailed}}, AuthFailed}, + {[]chunkTest{{0, 128, AuthFailed}, {128, 128, AuthNo}}, AuthFailed}, + {[]chunkTest{{0, 128, AuthFailed}, {128, 128, AuthFailed}}, AuthFailed}, + } +TestLoop: + for i, tt := range tests { + r := NewReassembler(len(longMessage)) + for _, c := range tt.chunks { + if r.Done() { + t.Errorf("%d. r.Done() was true; want false", i) + continue TestLoop + } + + // Make sure that the header comes from the + // first packet. + var uid UID + multiuid := sampleNotice().UID + if c.off == 0 { + uid = multiuid + } + notice := longMessageChunk(uid, multiuid, c.off, c.length) + if err := r.AddNotice(notice, c.auth); err != nil { + t.Errorf("%d. r.AddNotice(chunk %d, %d) failed: %v", + i, c.off, c.off+c.length, err) + continue TestLoop + } + } + if !r.Done() { + t.Errorf("%d. r.Done() was false; want true", i) + continue TestLoop + } + m, auth := r.Message() + expectHeadersEqual(t, &m.Header, &expected.Header) + if !reflect.DeepEqual(m.Body, expected.Body) { + t.Errorf("%d. m.Body = %v; want %v", i, m.Body, expected.Body) + } + if auth != tt.auth { + t.Errorf("auth = %v; want %v", auth, tt.auth) + } + } +} + +func TestReassemblerLengthMismatch(t *testing.T) { + r := NewReassembler(5) + if err := r.AddNotice(sampleNotice(), AuthYes); err != ErrBodyLengthMismatch { + t.Errorf("r.AddNotice did not fail as expected: %v", err) + } +} + +func TestReassemblerOutOfBounds(t *testing.T) { + r := NewReassembler(5) + notice := sampleNotice() + notice.RawBody = []byte("1234567890") + notice.Multipart = "0/5" + if err := r.AddNotice(notice, AuthYes); err != ErrBodyFragmentOutOfBounds { + t.Errorf("r.AddNotice did not fail as expected: %v", err) + } +} + +func TestReassemblerMaxBodyLength(t *testing.T) { + notice := sampleNotice() + notice.Multipart = "0/500000" + if _, err := NewReassemblerFromMultipartField(notice); err != ErrBodyTooLong { + t.Errorf("NewReassemblerFromMultipartField did not fail as expected: %v", err) + } +} + +func TestReassemblerZeroLength(t *testing.T) { + r := NewReassembler(0) + if r.Done() { + t.Errorf("r.Done() = true; want false") + } + notice := sampleNotice() + notice.RawBody = []byte{} + notice.Multipart = "0/0" + if err := r.AddNotice(notice, AuthYes); err != nil { + t.Errorf("r.AddNotice failed: %v", err) + } + if !r.Done() { + t.Errorf("r.Done() = false; want true") + } + m, auth := r.Message() + expectHeadersEqual(t, &m.Header, &sampleNotice().Header) + if len(m.Body) != 1 || m.Body[0] != "" { + t.Errorf("m.Body = %v; want []", m.Body) + } + if auth != AuthYes { + t.Errorf("auth = %v; want AuthYes", auth) + } +} diff --git a/server_config.go b/server_config.go new file mode 100644 index 0000000..9e7740e --- /dev/null +++ b/server_config.go @@ -0,0 +1,83 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "errors" + "net" + + "github.com/zephyr-im/hesiod-go" +) + +// A ServerConfig describes how to connect to a given zephyrd +// instance. +type ServerConfig interface { + // ResolveServer returns a list of zephyrd addresses to + // round-robin connect to. This may be repeatedly to refresh + // this list. ResolveServer cannot return a list of length 0. + ResolveServer() ([]*net.UDPAddr, error) +} + +type staticConfig []*net.UDPAddr + +// NewStaticServer returns a ServerConfig which returns a static list +// of zephyr server addresses. +func NewStaticServer(addrs []*net.UDPAddr) ServerConfig { + if len(addrs) == 0 { + panic("no addresses supplied") + } + return staticConfig(addrs) +} + +func (s staticConfig) ResolveServer() ([]*net.UDPAddr, error) { + return []*net.UDPAddr(s), nil +} + +type hesiodConfig struct { + hs *hesiod.Hesiod + port int +} + +// NewServerFromHesiod creates the server configuration for a zephyr +// installation at a Hesiod realm. +func NewServerFromHesiod(hs *hesiod.Hesiod) (ServerConfig, error) { + // Get the port. This is shared among all instances. + svc, err := hs.GetServiceByName("zephyr-clt", "udp") + if err != nil { + return nil, err + } + return &hesiodConfig{hs, svc.Port}, nil +} + +func (hc *hesiodConfig) ResolveServer() ([]*net.UDPAddr, error) { + zephyrds, err := hc.hs.Resolve("zephyr", "sloc") + if err != nil { + return nil, err + } + if len(zephyrds) == 0 { + return nil, errors.New("no zephyrds found") + } + // Go ahead resolve them all now. Not much use in being lazy + // here. + addrs := make([]*net.UDPAddr, len(zephyrds)) + for i, zephyrd := range zephyrds { + addr, err := net.ResolveIPAddr("ip", zephyrd) + if err != nil { + return nil, err + } + addrs[i] = &net.UDPAddr{IP: addr.IP, Port: hc.port, Zone: addr.Zone} + } + return addrs, nil +} diff --git a/session.go b/session.go new file mode 100644 index 0000000..78c76c1 --- /dev/null +++ b/session.go @@ -0,0 +1,210 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "log" + "net" + "time" + + "github.com/zephyr-im/krb5-go" +) + +const dedupLifetime = 900 * time.Second +const fragmentLifetime = 30 * time.Second + +// A MessageReaderResult is an output of a Session's incoming Message +// channel. It either contains a Message with accompanying AuthStatus +// or an error. +type MessageReaderResult struct { + Message *Message + AuthStatus AuthStatus +} + +// A Session is a high-level connection to the zephyr servers. It +// handles reassembly of sharded messages, ACKs, and deduplicating +// notices by UID. This API should be used by most zephyr clients. +type Session struct { + conn *Connection + logger *log.Logger + clock Clock + messages chan MessageReaderResult +} + +// NewSession creates a new Session attached to a given connection +// with server configuration and credential. +func NewSession( + conn net.PacketConn, + server ServerConfig, + cred *krb5.Credential, + logger *log.Logger, +) (*Session, error) { + return NewSessionFull(conn, server, cred, logger, SystemClock) +} + +// NewSessionFull creates a new Session attached to a given connection +// with server configuration and credential. This variant allows +// passing a custom clock for testing. +func NewSessionFull( + conn net.PacketConn, + server ServerConfig, + cred *krb5.Credential, + logger *log.Logger, + clock Clock, +) (*Session, error) { + zconn, err := NewConnectionFull(conn, server, cred, logger, clock) + if err != nil { + return nil, err + } + + s := &Session{zconn, logger, clock, make(chan MessageReaderResult)} + go s.noticeLoop() + return s, nil +} + +func (s *Session) noticeLoop() { + dedup := NewWindowedMapFull(dedupLifetime, s.clock) + reassemble := NewWindowedMapFull(fragmentLifetime, s.clock) + + for r := range s.conn.Notices() { + // ACK if appropriate. + if r.Notice.Kind.ExpectsClientACK() { + s.conn.SendNoticeUnackedTo( + r.Notice.MakeACK(CLIENTACK, ""), r.Addr) + } + + // Deduplicate. + if _, ok := dedup.Lookup(r.Notice.UID); ok { + continue + } + dedup.Put(r.Notice.UID, nil) + + // Reassemble. + var msg *Reassembler + if t, ok := reassemble.Lookup(r.Notice.MultiUID); ok { + msg = t.(*Reassembler) + } else { + var err error + msg, err = NewReassemblerFromMultipartField(r.Notice) + if err != nil { + logPrintf(s.logger, "Error parsing multipart: %v", err) + continue + } + reassemble.Put(r.Notice.MultiUID, msg) + } + if err := msg.AddNotice(r.Notice, r.AuthStatus); err != nil { + logPrintf(s.logger, "Error reassembling notice: %v", err) + continue + } + if msg.Done() { + reassemble.Remove(r.Notice.MultiUID) + m, a := msg.Message() + s.messages <- MessageReaderResult{m, a} + } + } + close(s.messages) +} + +// Messages returns the incoming messages from the session. +func (s *Session) Messages() <-chan MessageReaderResult { + return s.messages +} + +// LocalAddr returns the local UDP address for the client when +// communicating with the Zephyr servers. +func (s *Session) LocalAddr() *net.UDPAddr { + return s.conn.LocalAddr() +} + +// Port returns the local UDP port for the client. +func (s *Session) Port() uint16 { + return uint16(s.conn.LocalAddr().Port) +} + +// Credential returns the credential for this session. +func (s *Session) Credential() *krb5.Credential { + return s.conn.Credential() +} + +// Sender returns the user the session is authenticated as. +func (s *Session) Sender() string { + return s.Credential().Client.String() +} + +// Realm returns the realm of the server being connected to. +func (s *Session) Realm() string { + // TODO(davidben): This really should come from the server + // config or something. + return s.Credential().Server.Realm +} + +// Close closes the session. +func (s *Session) Close() error { + return s.conn.Close() +} + +// MakeUID creates a new UID for a given time and this sessions local +// IP address. +func (s *Session) MakeUID(t time.Time) UID { + return MakeUID(s.LocalAddr().IP, t) +} + +// SendSubscribe sends a subscription notice for a list of triples, +// along with any defaults configured on the server. +func (s *Session) SendSubscribe( + ctx *krb5.Context, + subs []Subscription, +) (*Notice, error) { + return SendSubscribe(ctx, s.conn, 0, subs) +} + +// SendSubscribeNoDefaults sends a subscription notice for a list of +// triples. It returns the ACK from the server. +func (s *Session) SendSubscribeNoDefaults( + ctx *krb5.Context, + subs []Subscription, +) (*Notice, error) { + return SendSubscribeNoDefaults(ctx, s.conn, 0, subs) +} + +// SendUnsubscribe unsubscribes from a list of triples. It returns the +// ACK from the server. +func (s *Session) SendUnsubscribe( + ctx *krb5.Context, + subs []Subscription, +) (*Notice, error) { + return SendUnsubscribe(ctx, s.conn, 0, subs) +} + +// SendCancelSubscriptions closes the session. This should be called +// before exit to release the port on the zephyrd. It returns the ACK +// from the server. +func (s *Session) SendCancelSubscriptions(ctx *krb5.Context) (*Notice, error) { + return SendCancelSubscriptions(ctx, s.conn, 0) +} + +// SendMessage sends an authenticated message over the session, +// sharding into multiple notices as needed. It returns the ACK from +// the server if the message is ACKED. +func (s *Session) SendMessage(ctx *krb5.Context, msg *Message) (*Notice, error) { + return SendMessage(ctx, s.conn, msg) +} + +// SendMessageUnauth sends an unauthenticated message over the +// session, sharding into multiple notices as needed. It returns the +// ACK from the server if the message is ACKED. +func (s *Session) SendMessageUnauth(msg *Message) (*Notice, error) { + return SendMessageUnauth(s.conn, msg) +} diff --git a/session_test.go b/session_test.go new file mode 100644 index 0000000..2a988c2 --- /dev/null +++ b/session_test.go @@ -0,0 +1,167 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "net" + "reflect" + "testing" + "time" + + "github.com/zephyr-im/krb5-go" + "github.com/zephyr-im/krb5-go/krb5test" + "github.com/zephyr-im/zephyr-go/zephyrtest" +) + +type noticeWithAuth struct { + notice *Notice + authStatus AuthStatus +} + +func mockSendingServer( + t *testing.T, + client net.Addr, + key *krb5.KeyBlock, + conn net.PacketConn, + notices []noticeWithAuth, +) { + // Set some stuff up. + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, keytab := makeServerContextAndKeyTab(t) + defer ctx.Free() + defer keytab.Close() + + rawNotices := ReadRawNotices(conn, l) + for _, n := range notices { + // Assemble the notice to send out. + pkt, err := n.notice.EncodePacketForClient(ctx, n.authStatus, key) + if err != nil { + t.Fatalf("Error encoding notice: %v", err) + } + + // Send to the client. + _, err = conn.WriteTo(pkt, client) + if err != nil { + t.Fatalf("Failed to send packet") + } + + // Expect a CLIENTACK. + reply, ok := <-rawNotices + if !ok { + t.Errorf("Did not receive ACK from client") + } + ack, err := DecodeRawNotice(reply.RawNotice) + if err != nil { + t.Errorf("Failed to record ACK: %v", err) + } + if ack.Kind != CLIENTACK || !ack.UID.Equal(n.notice.UID) { + t.Errorf("Expected CLIENTACK; got %v", ack) + } + } + conn.Close() +} + +func TestSession(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + + // Set up out network. + clock := zephyrtest.NewMockClock() + client, server1, server2 := mockNetwork2() + session, err := NewSessionFull(client, serverConfigFull, krb5test.Credential(), + l, clock) + if err != nil { + t.Fatal(err) + } + + uid1 := MakeUID(clientAddr.IP, time.Unix(1, 0)) + uid2 := MakeUID(clientAddr.IP, time.Unix(2, 0)) + uid3 := MakeUID(clientAddr.IP, time.Unix(3, 0)) + uid4 := MakeUID(clientAddr.IP, time.Unix(4, 0)) + uid5 := MakeUID(clientAddr.IP, time.Unix(5, 0)) + uid6 := MakeUID(clientAddr.IP, time.Unix(6, 0)) + uid7 := MakeUID(clientAddr.IP, time.Unix(7, 0)) + uid8 := MakeUID(clientAddr.IP, time.Unix(8, 0)) + + // Put together two servers. Each server sends some set of + // messages. Use two to assert that ACKs go to the server that + // expects them. This test will test reassembly and + // deduplicating and ACK behavior. + notices1 := []noticeWithAuth{ + // Two copies of a notice. Only receive one, but ACK both. + {sampleNoticeWithUID(uid1), AuthYes}, + {sampleNoticeWithUID(uid1), AuthYes}, + // Dedup works even when auth changes; this is needed + // for now because of a historical zhm bug on legacy + // clients. + {sampleNoticeWithUID(uid1), AuthFailed}, + // A sharded notice. + {longMessageChunk(uid2, uid2, 0, 128), AuthYes}, + {longMessageChunk(uid3, uid2, 128, 128), AuthYes}, + // Another sharded notice. Half comes from the other server. + {longMessageChunk(uid4, uid4, 0, 128), AuthYes}, + // Another copy of the first; still only receive one. + {longMessageChunk(uid2, uid2, 0, 128), AuthYes}, + {longMessageChunk(uid3, uid2, 128, 128), AuthYes}, + // To ensure that everything else ACKed properly, add + // a redundant test at the end. + {longMessageChunk(uid7, uid7, 0, 128), AuthYes}, + } + notices2 := []noticeWithAuth{ + // Sharded notice that never completes. + {longMessageChunk(uid5, uid5, 0, 128), AuthYes}, + // Other half of the last sharded notice. This comes + // from the other server, so ACKs should go there. + {longMessageChunk(uid6, uid4, 128, 128), AuthFailed}, + // Dedup works across servers. + {sampleNoticeWithUID(uid1), AuthYes}, + // To ensure that everything else ACKed properly, add + // a redundant test at the end. + {longMessageChunk(uid8, uid7, 128, 128), AuthYes}, + } + + // Spin up our two servers. + go mockSendingServer(t, clientAddr, krb5test.SessionKey(), server1, notices1) + go mockSendingServer(t, clientAddr, krb5test.SessionKey(), server2, notices2) + + messages := []MessageReaderResult{ + {sampleMessage(uid1, sampleNotice().RawBody), AuthYes}, + {sampleMessage(uid2, longMessage), AuthYes}, + {sampleMessage(uid4, longMessage), AuthFailed}, + {sampleMessage(uid7, longMessage), AuthYes}, + } + for _, expected := range messages { + m := <-session.Messages() + if m.AuthStatus != expected.AuthStatus { + t.Errorf("AuthStatus = %v; want %v", + m.AuthStatus, expected.AuthStatus) + } + expectHeadersEqual(t, &m.Message.Header, &expected.Message.Header) + if !reflect.DeepEqual(m.Message.Body, expected.Message.Body) { + t.Errorf("m.Body = %v; want %v", + m.Message.Body, expected.Message.Body) + } + } + session.Close() + for m := range session.Messages() { + t.Errorf("Unexpected message: %v", m) + } +} diff --git a/subscription.go b/subscription.go new file mode 100644 index 0000000..0ad1da4 --- /dev/null +++ b/subscription.go @@ -0,0 +1,185 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "bytes" + "time" + + "github.com/zephyr-im/krb5-go" +) + +// WildcardInstance is the instance value used to subscribe to all +// instances for a given recipient, class pair. +const WildcardInstance = "*" + +// A Subscription represents a triple to subscribe to. +type Subscription struct { + Recipient string + Class string + Instance string +} + +func tripleLength(s *Subscription) int { + return len(s.Recipient) + len(s.Class) + len(s.Instance) + 3 +} + +const ( + clientSubscribe = "SUBSCRIBE" + clientSubscribeNoDefs = "SUBSCRIBE_NODEFS" + clientUnsubscribe = "UNSUBSCRIBE" + clientCancelSub = "CLEARSUB" + clientGimmeSubs = "GIMME" + clientGimmeDefs = "GIMMEDEFS" + clientFlushSubs = "FLUSHSUBS" +) + +func sendControlNotice(ctx *krb5.Context, conn *Connection, + port uint16, opcode string, subs []Subscription) (*Notice, error) { + if port == 0 { + port = uint16(conn.LocalAddr().Port) + } + + uid := MakeUID(conn.LocalAddr().IP, time.Now()) + notice := &Notice{ + Header: Header{ + Kind: ACKED, + UID: uid, + Port: port, + Class: "ZEPHYR_CTL", + Instance: "CLIENT", + OpCode: opcode, + Sender: conn.Credential().Client.String(), + SenderAddress: conn.LocalAddr().IP, + }, + } + // Edge case: no subs to send. Still send at least one packet. + if len(subs) == 0 { + return conn.SendNotice(ctx, notice) + } + + // Otherwise, this needs to be sharded. First, compute how + // much space we have in the body. + pkt, err := notice.EncodePacketForServer(ctx, conn.Credential()) + if err != nil { + return nil, err + } + headerLen := len(pkt) + + var ack *Notice + for len(subs) > 0 { + // Compute how many triples will fit in one + // packet. Leave some 13 bytes of slop because, if + // we're unlucky, zcode may blow up the input. 13 + // chosen because it's what libzephyr uses and is + // above 1024 / 128. See also: message.go. + remaining := MaxPacketLength - headerLen - 13 + i := 0 + for ; i < len(subs); i++ { + l := tripleLength(&subs[i]) + if l > remaining { + break + } else { + remaining -= l + } + } + + if i == 0 { + return nil, ErrPacketTooLong + } + + shard := subs[0:i] + subs = subs[i:] + + // Send this shard. Should end with a trailing NUL, so + // add an empty field at the end. + fields := make([][]byte, len(shard)*3+1) + for j, sub := range shard { + fields[3*j] = []byte(sub.Class) + fields[3*j+1] = []byte(sub.Instance) + fields[3*j+2] = []byte(sub.Recipient) + } + uid := MakeUID(conn.LocalAddr().IP, time.Now()) + notice.UID = uid + notice.MultiUID = uid + notice.RawBody = bytes.Join(fields, []byte{0}) + pkt, err := notice.EncodePacketForServer(ctx, conn.Credential()) + if err != nil { + return nil, err + } + + ack, err = conn.SendPacket(pkt, ACKED, uid) + if err != nil { + return nil, err + } else if ack.Kind != SERVACK { + return ack, nil + } + } + + // Return the last ACK we saw. + return ack, nil +} + +// SendSubscribe uses a connection to subscribe to a list of +// subscriptions, along with any defaults configured on the +// server. The passed krb5 context is used to authenticate the +// request. The port number is the port to subscribe, or 0 to use the +// port of the connection. +func SendSubscribe( + ctx *krb5.Context, + conn *Connection, + port uint16, + subs []Subscription, +) (*Notice, error) { + return sendControlNotice(ctx, conn, port, clientSubscribe, subs) +} + +// SendSubscribeNoDefaults uses a connection to subscribe to a list of +// subscriptions. The passed krb5 context is used to authenticate the +// request. The port number is the port to subscribe, or 0 to use the +// port of the connection. +func SendSubscribeNoDefaults( + ctx *krb5.Context, + conn *Connection, + port uint16, + subs []Subscription, +) (*Notice, error) { + return sendControlNotice(ctx, conn, port, clientSubscribeNoDefs, subs) +} + +// SendUnsubscribe uses a connection to unsubscribe from a list of +// subscriptions. The passed krb5 context is used to authenticate the +// request. The port number is the port to subscribe, or 0 to use the +// port of the connection. +func SendUnsubscribe( + ctx *krb5.Context, + conn *Connection, + port uint16, + subs []Subscription, +) (*Notice, error) { + return sendControlNotice(ctx, conn, port, clientUnsubscribe, subs) +} + +// SendCancelSubscriptions uses a connection to close a session. The +// passed krb5 context is used to authenticate the request. The port +// number is the port to subscribe, or 0 to use the port of the +// connection. +func SendCancelSubscriptions( + ctx *krb5.Context, + conn *Connection, + port uint16, +) (*Notice, error) { + return sendControlNotice(ctx, conn, port, clientCancelSub, nil) +} diff --git a/subscription_test.go b/subscription_test.go new file mode 100644 index 0000000..e00b101 --- /dev/null +++ b/subscription_test.go @@ -0,0 +1,297 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "bytes" + "errors" + "fmt" + "net" + "reflect" + "testing" + + "github.com/zephyr-im/krb5-go" + "github.com/zephyr-im/krb5-go/krb5test" + "github.com/zephyr-im/zephyr-go/zephyrtest" +) + +func ackAndDumpNotices(t *testing.T, conn net.PacketConn, auth AuthStatus, sink chan<- *Notice) { + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, keytab := makeServerContextAndKeyTab(t) + defer ctx.Free() + defer keytab.Close() + for r := range ReadRawNotices(conn, l) { + authStatus, _, err := r.RawNotice.CheckAuthFromClient( + ctx, krb5test.Service(), keytab) + if err != nil { + t.Fatalf("CheckAuthFromClient failed: %v", err) + return + } + notice, err := DecodeRawNotice(r.RawNotice) + if err != nil { + t.Fatalf("DecodeRawNotice failed: %v", err) + return + } + if authStatus != auth { + t.Errorf("Bad authStatus %v; want %v", authStatus, auth) + } + if notice.Kind.ExpectsServerACK() { + conn.WriteTo(notice.MakeACK(SERVACK, "SENT").EncodePacketUnauth(), + r.Addr) + } + sink <- notice + } + close(sink) +} + +func nackNotices(t *testing.T, conn net.PacketConn) { + l, lc := expectNoLogs(t) + defer lc.Close() + for r := range ReadRawNotices(conn, l) { + notice, err := DecodeRawNotice(r.RawNotice) + if err != nil { + t.Fatalf("DecodeRawNotice failed: %v", err) + return + } + if notice.Kind.ExpectsServerACK() { + conn.WriteTo(notice.MakeACK(SERVNAK, "LOST").EncodePacketUnauth(), + r.Addr) + } + } +} + +func checkSubscriptionNotice( + t *testing.T, n *Notice, port uint16, opcode string) []Subscription { + if n.Kind != ACKED { + t.Errorf("n.Kind = %v; want ACKED", n.Kind) + } + if n.Port != port { + t.Errorf("n.Port = %v; want %v", n.Port, port) + } + if n.Class != "ZEPHYR_CTL" { + t.Errorf("n.Class = %q; want 'ZEPHYR_CTL'", n.Class) + } + if n.Instance != "CLIENT" { + t.Errorf("n.Instance = %q; want 'CLIENT'", n.Instance) + } + if n.OpCode != opcode { + t.Errorf("n.OpCode = %q; want %q", n.OpCode, opcode) + } + chunks := bytes.Split(n.RawBody, []byte{0}) + if len(chunks) == 0 { + t.Errorf("chunks is empty.") + return nil + } + if len(chunks[len(chunks)-1]) != 0 { + t.Errorf("End of chunks should be empty.") + return nil + } + chunks = chunks[0 : len(chunks)-1] + if len(chunks)%3 != 0 { + t.Errorf("Body length not a multiple of 3: %d", len(chunks)) + return nil + } + out := []Subscription{} + for i := 0; i+2 < len(chunks); i += 3 { + out = append(out, Subscription{ + Recipient: string(chunks[i+2]), + Class: string(chunks[i]), + Instance: string(chunks[i+1]), + }) + } + return out +} + +func TestSendSubscribe(t *testing.T) { + tests := [][]Subscription{ + // Basic case. + []Subscription{ + {"", "davidben", WildcardInstance}, + {"davidben@ATHENA.MIT.EDU", "message", "personal"}, + }, + // Empty subs list. + []Subscription{}, + } + // Too many subs. + long := []Subscription{} + for i := 0; i < 1000; i++ { + long = append(long, Subscription{ + "", fmt.Sprintf("class%d", i), WildcardInstance}) + } + tests = append(tests, long) + + for i, subs := range tests { + l, lc := expectNoLogs(t) + defer lc.Close() + + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + + clock := zephyrtest.NewMockClock() + client, server := mockNetwork1() + conn, err := NewConnectionFull(client, serverConfig, + krb5test.Credential(), l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Set up a "server" to SERVACK notices as they come in. + notices := make(chan *Notice, len(subs)) + go ackAndDumpNotices(t, server, AuthYes, notices) + + ack, err := SendSubscribeNoDefaults(ctx, conn, 0, subs) + if err != nil { + t.Fatalf("[%d] SendSubscribeNoDefaults failed: %v", i, err) + } + if ack.Kind != SERVACK { + t.Errorf("[%d] ack.Kind = %v; want SERVACK", i, ack.Kind) + } + if string(ack.RawBody) != "SENT" { + t.Errorf("[%d] ack.RawBody = %v; want 'SENT'", + i, string(ack.RawBody)) + } + server.Close() + + // Find out what was sent. + sentSubs := []Subscription{} + gotNotice := false + for notice := range notices { + gotNotice = true + s := checkSubscriptionNotice(t, notice, + uint16(conn.LocalAddr().Port), + clientSubscribeNoDefs) + sentSubs = append(sentSubs, s...) + } + if !reflect.DeepEqual(sentSubs, subs) { + t.Errorf("[%d] Sent %v; want %v", i, sentSubs, subs) + } + if !gotNotice { + t.Errorf("Received no notices") + } + } +} + +func TestSendSubscribeTooLong(t *testing.T) { + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + + clock := zephyrtest.NewMockClock() + client, server := mockNetwork1() + defer server.Close() + conn, err := NewConnectionFull(client, serverConfig, + krb5test.Credential(), l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + class := "" + for i := 0; i < MaxPacketLength; i++ { + class += "z" + } + subs := []Subscription{{"", class, WildcardInstance}} + _, err = SendSubscribeNoDefaults(ctx, conn, 0, subs) + if err != ErrPacketTooLong { + t.Fatalf("SendSubscribeNoDefaults didn't fail as expected: %v", err) + } +} + +func TestSendSubscribeNack(t *testing.T) { + subs := []Subscription{ + {"", "davidben", WildcardInstance}, + {"davidben@ATHENA.MIT.EDU", "message", "personal"}, + } + + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + + clock := zephyrtest.NewMockClock() + client, server := mockNetwork1() + defer server.Close() + conn, err := NewConnectionFull(client, serverConfig, + krb5test.Credential(), l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Set up a "server" to SERVNAK notices as they come in. + go nackNotices(t, server) + + ack, err := SendSubscribeNoDefaults(ctx, conn, 0, subs) + if err != nil { + t.Fatalf("SendSubscribeNoDefaults failed: %v", err) + } + if ack.Kind != SERVNAK { + t.Errorf("ack.Kind = %v; want SERVNAK", ack.Kind) + } + if string(ack.RawBody) != "LOST" { + t.Errorf("ack.RawBody = %v; want 'LOST'", string(ack.RawBody)) + } +} + +func TestSendSubscribeSendError(t *testing.T) { + subs := []Subscription{ + {"", "davidben", WildcardInstance}, + {"davidben@ATHENA.MIT.EDU", "message", "personal"}, + } + + l, lc := expectNoLogs(t) + defer lc.Close() + ctx, err := krb5.NewContext() + if err != nil { + t.Fatal(err) + } + defer ctx.Free() + + clock := zephyrtest.NewMockClock() + readChan := make(chan zephyrtest.PacketRead) + close(readChan) + mock := zephyrtest.NewMockPacketConn(clientAddr, readChan) + conn, err := NewConnectionFull(mock, serverConfig, + krb5test.Credential(), l, clock) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + // Set up a "server" to fail all writes. + expectedErr := errors.New("failed") + go func() { + for write := range mock.Writes() { + write.Result <- expectedErr + } + }() + + // Send the message. + if _, err := SendSubscribeNoDefaults(ctx, conn, 0, subs); err != expectedErr { + t.Errorf("SendMessageUnauth didn't fail as expected: %v", err) + } +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..49068b9 --- /dev/null +++ b/util.go @@ -0,0 +1,31 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +// Creates a copy of a byte slice so it doesn't share an underlying +// backing store. +func copyByteSlice(in []byte) []byte { + out := make([]byte, len(in)) + copy(out, in) + return out +} + +func copyByteSlices(in [][]byte) [][]byte { + out := make([][]byte, len(in)) + for i := range in { + out[i] = copyByteSlice(in[i]) + } + return out +} diff --git a/windowed_map.go b/windowed_map.go new file mode 100644 index 0000000..be9b886 --- /dev/null +++ b/windowed_map.go @@ -0,0 +1,142 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "container/heap" + "time" +) + +// TODO(davidben): The key should include the sender. Probably also +// the AuthStatus. + +type uidItem struct { + uid UID + expire time.Time + index int + value interface{} +} + +type uidList []*uidItem + +func (ul uidList) Len() int { return len(ul) } + +func (ul uidList) Less(i, j int) bool { + return ul[i].expire.Before(ul[j].expire) +} + +func (ul uidList) Swap(i, j int) { + ul[i], ul[j] = ul[j], ul[i] + ul[i].index = i + ul[j].index = j +} + +func (ul *uidList) Push(x interface{}) { + n := len(*ul) + item := x.(*uidItem) + item.index = n + *ul = append(*ul, item) +} + +func (ul *uidList) Pop() interface{} { + old := *ul + n := len(old) + item := old[n-1] + item.index = -1 // for safety + *ul = old[0 : n-1] + return item +} + +// A WindowedMap is a map from UID to some value type where keys +// expire if not accessed for some time. +type WindowedMap struct { + clock Clock + lifetime time.Duration + uidMap map[UID]*uidItem + uidList uidList +} + +// NewWindowedMap creates a new WindowedMap with the specified +// lifetime for key entries. +func NewWindowedMap(lifetime time.Duration) *WindowedMap { + return NewWindowedMapFull(lifetime, SystemClock) +} + +// NewWindowedMapFull creates a new WindowedMap with the specified +// lifetime for key entries. It allows passing a custom Clock +// implementation for testing purposes. +func NewWindowedMapFull(lifetime time.Duration, clock Clock) *WindowedMap { + w := &WindowedMap{clock, lifetime, map[UID]*uidItem{}, uidList{}} + heap.Init(&w.uidList) + return w +} + +// Len returns the number of entries in the windowed map. Note that +// this does not get updated for any entries that map have expired +// since the last lookup or modification. +func (w *WindowedMap) Len() int { + return len(w.uidList) +} + +// ExpireOldEntries removes entries from the windowed map which have +// since expired. +func (w *WindowedMap) ExpireOldEntries() { + if len(w.uidList) != len(w.uidMap) { + panic(*w) + } + now := w.clock.Now() + for len(w.uidList) > 0 && !w.uidList[0].expire.After(now) { + delete(w.uidMap, w.uidList[0].uid) + heap.Remove(&w.uidList, 0) + } +} + +// Lookup looks up the value for a given UID. The second return value +// is false if the key is not in the map. Looking up a UID updates the +// access time for that key. +func (w *WindowedMap) Lookup(uid UID) (interface{}, bool) { + w.ExpireOldEntries() + item, ok := w.uidMap[uid] + if !ok { + return nil, false + } + w.Put(uid, item.value) + return item.value, true +} + +// Remove removes the entry for a given UID and returns the value +// associated with it. The second return value is false if the key is +// not in the map. +func (w *WindowedMap) Remove(uid UID) (interface{}, bool) { + w.ExpireOldEntries() + item, ok := w.uidMap[uid] + if !ok { + return nil, false + } + heap.Remove(&w.uidList, item.index) + delete(w.uidMap, uid) + return item.value, true +} + +// Put inserts a new value into the windowed map, overriding the +// existing value for UID if it already exists. +func (w *WindowedMap) Put(uid UID, value interface{}) { + if _, ok := w.uidMap[uid]; ok { + w.Remove(uid) + } + item := &uidItem{uid, w.clock.Now().Add(w.lifetime), -1, value} + heap.Push(&w.uidList, item) + w.uidMap[uid] = item +} diff --git a/windowed_map_test.go b/windowed_map_test.go new file mode 100644 index 0000000..a1e273d --- /dev/null +++ b/windowed_map_test.go @@ -0,0 +1,96 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "testing" + "time" + + "github.com/zephyr-im/zephyr-go/zephyrtest" +) + +func TestWindowedMap(t *testing.T) { + clock := zephyrtest.NewMockClock() + uid1 := MakeUID(clientAddr.IP, time.Unix(1, 0)) + uid2 := MakeUID(clientAddr.IP, time.Unix(2, 0)) + + w := NewWindowedMapFull(5*time.Second, clock) + + // Initially empty. + if w.Len() != 0 { + t.Errorf("w.Len() = %d; want 0", w.Len()) + } + if val, ok := w.Lookup(uid1); ok { + t.Errorf("w.Lookup(uid1) returned %v, true; want false", val) + } + if val, ok := w.Remove(uid1); ok { + t.Errorf("w.Remove(uid1) returned %v, true; want false", val) + } + + // Insert something and remove it. + w.Put(uid1, 1) + if val, ok := w.Lookup(uid1); !ok || val.(int) != 1 { + t.Errorf("w.Lookup(uid1) returned %v, %v; want 1", val, ok) + } + if w.Len() != 1 { + t.Errorf("w.Len() = %d; want 1", w.Len()) + } + if val, ok := w.Remove(uid1); !ok || val.(int) != 1 { + t.Errorf("w.Remove(uid1) returned %v, %v; want 1", val, ok) + } + + // Insert something and let it expire. + w.Put(uid1, 1) + clock.Advance(5 * time.Second) + if val, ok := w.Lookup(uid1); ok { + t.Errorf("w.Lookup(uid1) returned %v, true; want false", val) + } + if w.Len() != 0 { + t.Errorf("w.Len() = %d; want 0", w.Len()) + } + + // Looking up a value updates the expiration time. + w.Put(uid1, 1) + w.Put(uid2, 2) + clock.Advance(2 * time.Second) + w.Lookup(uid1) + clock.Advance(3 * time.Second) + if val, ok := w.Lookup(uid2); ok { + t.Errorf("w.Lookup(uid2) returned %v, true; want false", val) + } + if val, ok := w.Lookup(uid1); !ok || val.(int) != 1 { + t.Errorf("w.Lookup(uid1) returned %v, %v; want 1", val, ok) + } + if w.Len() != 1 { + t.Errorf("w.Len() = %d; want 1", w.Len()) + } + clock.Advance(5 * time.Second) + if val, ok := w.Lookup(uid1); ok { + t.Errorf("w.Lookup(uid1) returned %v, true; want false", val) + } + if w.Len() != 0 { + t.Errorf("w.Len() = %d; want 0", w.Len()) + } + + // Overriding values works. + w.Put(uid1, 1) + w.Put(uid1, 2) + if val, ok := w.Lookup(uid1); !ok || val.(int) != 2 { + t.Errorf("w.Lookup(uid1) returned %v, %v; want 2", val, ok) + } + if w.Len() != 1 { + t.Errorf("w.Len() = %d; want 1", w.Len()) + } +} diff --git a/zascii.go b/zascii.go new file mode 100644 index 0000000..c25fdb7 --- /dev/null +++ b/zascii.go @@ -0,0 +1,197 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "encoding/binary" + "errors" +) + +// The zephyrascii implementation is going to intentionally be +// stricter than libzephyr's for now. Byte sequences MUST be bracketed +// at every word and MUST be uppercase. + +const upperHexTable = "0123456789ABCDEF" + +// The encoded length of a zephyrascii word. +const zephyrasciiWordLength = 2 + 2*4 + +// DecodedZAsciiLength returns the decoded length of a zephyrascii +// input of length n. +func DecodedZAsciiLength(n int) int { + if n == 0 { + return 0 + } + result := 4 * (n / (zephyrasciiWordLength + 1)) + result += ((n % (zephyrasciiWordLength + 1)) - 2) / 2 + return result +} + +func decodeUpperHexChar(c byte) (byte, error) { + if '0' <= c && c <= '9' { + return c - '0', nil + } else if 'A' <= c && c <= 'F' { + return c - 'A' + 10, nil + } else { + return 0, errors.New("bad hex character") + } +} + +func decodeUpperHex(dst, src []byte) (int, error) { + if len(src)%2 != 0 { + return 0, errors.New("bad hex length") + } + for i := 0; i < len(src)/2; i++ { + hi, err := decodeUpperHexChar(src[2*i]) + if err != nil { + return 0, err + } + lo, err := decodeUpperHexChar(src[2*i+1]) + if err != nil { + return 0, err + } + dst[i] = (hi<<4 | lo) + } + return len(src) / 2, nil +} + +// DecodeZAsciiInto decodes zephyrascii from src and writes the output +// into dst. It returns the number of bytes written. +func DecodeZAsciiInto(dst, src []byte) (int, error) { + if len(src) == 0 { + return 0, nil + } + + // Check the lengths ahead of time. + lastLen := len(src) % (zephyrasciiWordLength + 1) + if lastLen < 2+2 || lastLen%2 != 0 { + return 0, errors.New("bad zephyrascii field length") + } + + j := 0 + for i := 0; i < len(src); i += zephyrasciiWordLength + 1 { + if i > 0 && src[i-1] != ' ' { + return 0, errors.New("expected ' '") + } + if src[i] != '0' || src[i+1] != 'x' { + return 0, errors.New("expected '0x'") + } + limit := i + zephyrasciiWordLength + if limit > len(src) { + limit = len(src) + } + decoded, err := decodeUpperHex(dst[j:], src[i+2:limit]) + if err != nil { + return 0, err + } + j += decoded + } + return j, nil +} + +// DecodeZAscii decodes zephyrascii and returns the decoded result as +// a byte slice. +func DecodeZAscii(src []byte) ([]byte, error) { + dst := make([]byte, DecodedZAsciiLength(len(src))) + if l, err := DecodeZAsciiInto(dst, src); err != nil { + return nil, err + } else if l != len(dst) { + panic(l) + } + return dst, nil +} + +// DecodeZAscii16 decodes the zephyrascii encoding of a 16-bit integer +// and returns the result as a uint16. +func DecodeZAscii16(src []byte) (uint16, error) { + if len(src) != 2+2*2 { + return 0, errors.New("bad length for uint16 zephyrascii") + } + dst, err := DecodeZAscii(src) + if err != nil { + return 0, err + } + if len(dst) != 2 { + panic(dst) + } + return binary.BigEndian.Uint16(dst), nil +} + +// DecodeZAscii32 decodes the zephyrascii encoding of a 32-bit integer +// and returns the result as a uint32. +func DecodeZAscii32(src []byte) (uint32, error) { + if len(src) != 2+2*4 { + return 0, errors.New("bad length for uint32 zephyrascii") + } + dst, err := DecodeZAscii(src) + if err != nil { + return 0, err + } + if len(dst) != 4 { + panic(dst) + } + return binary.BigEndian.Uint32(dst), nil +} + +// EncodedZAsciiLength returns the length of the zephyrascii encoding +// of a byte slice of length n. +func EncodedZAsciiLength(n int) int { + if n == 0 { + return 0 + } + fullWords := (n / 4) * (zephyrasciiWordLength + 1) + rest := n % 4 + if rest == 0 { + // Remove trailing space. + return fullWords - 1 + } + // Account for 0x and remainder. + return fullWords + 2 + rest*2 +} + +// EncodeZAscii encodes a byte slice as zephyrascii. +func EncodeZAscii(src []byte) []byte { + dst := make([]byte, EncodedZAsciiLength(len(src))) + j := 0 + for i, v := range src { + if i%4 == 0 { + if i > 0 { + dst[j] = ' ' + j++ + } + dst[j] = '0' + dst[j+1] = 'x' + j += 2 + } + dst[j] = upperHexTable[v>>4] + dst[j+1] = upperHexTable[v&0xf] + j += 2 + } + return dst +} + +// EncodeZAscii16 encodes a 16-bit integer as zephyrascii. +func EncodeZAscii16(val uint16) []byte { + src := make([]byte, 2) + binary.BigEndian.PutUint16(src, val) + return EncodeZAscii(src) +} + +// EncodeZAscii32 encodes a 32-bit integer as zephyrascii. +func EncodeZAscii32(val uint32) []byte { + src := make([]byte, 4) + binary.BigEndian.PutUint32(src, val) + return EncodeZAscii(src) +} diff --git a/zascii_test.go b/zascii_test.go new file mode 100644 index 0000000..0ff82ff --- /dev/null +++ b/zascii_test.go @@ -0,0 +1,167 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "bytes" + "testing" +) + +func testZAsciiPair(t *testing.T, encodedStr, decodedStr string) { + encoded, decoded := []byte(encodedStr), []byte(decodedStr) + + if l := DecodedZAsciiLength(len(encoded)); l != len(decoded) { + t.Errorf("DecodedZAsciiLength(%v) = %v, want %v", + len(encoded), l, len(decoded)) + } + if l := EncodedZAsciiLength(len(decoded)); l != len(encoded) { + t.Errorf("EncodedZAsciiLength(%v) = %v, want %v", + len(decoded), l, len(encoded)) + } + + dst := make([]byte, len(decoded)) + l, err := DecodeZAsciiInto(dst, encoded) + if err != nil { + t.Errorf("DecodeZAsciiInto(%v) failed unexpectedly: %v", + encodedStr, err) + } else if l != len(decoded) { + t.Errorf("DecodeZAsciiInto(%v) wrote %v bytes, want %v", + encodedStr, l, len(decoded)) + } else if !bytes.Equal(dst, decoded) { + t.Errorf("DecodeZAsciiInto(%v) = %v, want %v", + encodedStr, dst, decoded) + } + + if dst, err := DecodeZAscii(encoded); err != nil { + t.Errorf("DecodeZAscii(%v) failed unexpectedly: %v", + encodedStr, err) + } else if !bytes.Equal(dst, decoded) { + t.Errorf("DecodeZAscii(%v) = %v, want %v", + encodedStr, dst, decoded) + } + + if dst := EncodeZAscii(decoded); !bytes.Equal(dst, encoded) { + t.Errorf("EncodeZAscii(%v) = %v, want %v", + decoded, string(dst), encodedStr) + } +} + +func testZAscii16Pair(t *testing.T, encodedStr string, val uint16) { + encoded := []byte(encodedStr) + + v, err := DecodeZAscii16(encoded) + if err != nil { + t.Errorf("DecodeZAscii16(%v) failed unexpectedly: %v", + encodedStr, err) + } else if v != val { + t.Errorf("DecodeZAscii16(%v) = %v, want %v", + encodedStr, v, val) + } + + if dst := EncodeZAscii16(val); !bytes.Equal(dst, encoded) { + t.Errorf("EncodeZAscii16(%v) = %v, want %v", + val, string(dst), encodedStr) + } +} + +func testZAscii32Pair(t *testing.T, encodedStr string, val uint32) { + encoded := []byte(encodedStr) + + v, err := DecodeZAscii32(encoded) + if err != nil { + t.Errorf("DecodeZAscii32(%v) failed unexpectedly: %v", + encodedStr, err) + } else if v != val { + t.Errorf("DecodeZAscii32(%v) = %v, want %v", + encodedStr, v, val) + } + + if dst := EncodeZAscii32(val); !bytes.Equal(dst, encoded) { + t.Errorf("EncodeZAscii32(%v) = %v, want %v", + val, string(dst), encodedStr) + } +} + +func TestZAscii(t *testing.T) { + testZAsciiPair(t, "", "") + testZAsciiPair(t, "0xA5", "\xa5") + testZAsciiPair(t, "0x0102", "\x01\x02") + testZAsciiPair(t, "0x1A2B3C", "\x1a\x2b\x3c") + testZAsciiPair(t, "0xDEADBEEF", "\xde\xad\xbe\xef") + testZAsciiPair(t, "0xDEADBEEF 0xABAD1DEA", + "\xde\xad\xbe\xef\xab\xad\x1d\xea") + testZAsciiPair(t, "0xDEADBEEF 0xABAD1DEA 0x1234", + "\xde\xad\xbe\xef\xab\xad\x1d\xea\x12\x34") +} + +func TestZAscii16(t *testing.T) { + testZAscii16Pair(t, "0x0001", 1) + testZAscii16Pair(t, "0xFACE", 0xface) +} + +func TestZAscii32(t *testing.T) { + testZAscii32Pair(t, "0x00000001", 1) + testZAscii32Pair(t, "0xDEADBEEF", 0xdeadbeef) +} + +func TestZDecodeZAsciiErrors(t *testing.T) { + bad := []string{ + "0x", + "0102", + "0xdeadbeef", + "0x0102 0x0304", + "0xGG", + "0x123", + "0xAABBCCDD.0x11223344", + "0xDEADBEEF ", + "0xDEADBEEF 0", + "0xDEADBEEF 0x", + } + for _, v := range bad { + if _, err := DecodeZAscii([]byte(v)); err == nil { + t.Errorf("DecodeZAscii(%v) unexpectedly succeeded", v) + } + } +} + +func TestZDecodeZAscii16Errors(t *testing.T) { + bad := []string{ + "", + "0x12", + "0x12 0x34", + "0x12345678", + "0xface", + } + for _, v := range bad { + if _, err := DecodeZAscii16([]byte(v)); err == nil { + t.Errorf("DecodeZAscii16(%v) unexpectedly succeeded", v) + } + } +} + +func TestZDecodeZAscii32Errors(t *testing.T) { + bad := []string{ + "", + "0x1234", + "0x12 0x34 0x56 0x78", + "0x123456", + "0xDEADBEEF 0x00", + } + for _, v := range bad { + if _, err := DecodeZAscii32([]byte(v)); err == nil { + t.Errorf("DecodeZAscii32(%v) unexpectedly succeeded", v) + } + } +} diff --git a/zcode.go b/zcode.go new file mode 100644 index 0000000..064bc1d --- /dev/null +++ b/zcode.go @@ -0,0 +1,108 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "errors" +) + +// Zcode is used to encode large arbitrary byte strings (namely the +// authenticator and checksum) in zephyr when zascii would be too +// wasteful. It is a simple escaping scheme to remove NUL bytes: 00 +// is replaced with FF F0, FF is replaced with FF F1. + +// ErrInvalidZcode is returned when decoding invalid zcode input. +var ErrInvalidZcode = errors.New("invalid zcode") + +// DecodeZcode decodes an input byte slice as zcode. +func DecodeZcode(in []byte) ([]byte, error) { + if len(in) == 0 { + return nil, ErrInvalidZcode + } + if in[0] != 'Z' { + return nil, ErrInvalidZcode + } + + // Compute the length. + l := len(in) - 1 + for _, v := range in { + if v == '\xff' { + l-- + } + } + + // Decode + out := make([]byte, l) + j := 0 + for i := 1; i < len(in); i++ { + if in[i] == '\x00' { + return nil, ErrInvalidZcode + } else if in[i] == '\xff' { + if i+1 >= len(in) { + return nil, ErrInvalidZcode + } + switch in[i+1] { + case '\xf0': + out[j] = '\x00' + case '\xf1': + out[j] = '\xff' + default: + return nil, ErrInvalidZcode + } + i++ + } else { + out[j] = in[i] + } + j++ + } + if j != l { + panic(j) + } + return out, nil +} + +// EncodeZcode encodes an input byte slice as zcode. +func EncodeZcode(in []byte) []byte { + // Compute the length. + l := len(in) + 1 + for _, v := range in { + if v == '\xff' || v == '\x00' { + l++ + } + } + + // Encode + out := make([]byte, l) + out[0] = 'Z' + j := 1 + for _, v := range in { + if v == '\x00' { + out[j] = '\xff' + out[j+1] = '\xf0' + j += 2 + } else if v == '\xff' { + out[j] = '\xff' + out[j+1] = '\xf1' + j += 2 + } else { + out[j] = v + j++ + } + } + if j != l { + panic(j) + } + return out +} diff --git a/zcode_test.go b/zcode_test.go new file mode 100644 index 0000000..b042021 --- /dev/null +++ b/zcode_test.go @@ -0,0 +1,64 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyr + +import ( + "bytes" + "testing" +) + +func testZcodePair(t *testing.T, encodedStr, decodedStr string) { + encoded, decoded := []byte(encodedStr), []byte(decodedStr) + + if ret := EncodeZcode(decoded); !bytes.Equal(ret, encoded) { + t.Errorf("EncodeZcode(%v) = %v, want %v", + decoded, ret, encoded) + } + + ret, err := DecodeZcode(encoded) + if err != nil { + t.Errorf("DecodeZcode(%v) failed unexpectedly: %v", + encoded, err) + } else if !bytes.Equal(ret, decoded) { + t.Errorf("DecodeZcode(%v) = %v, want %v", + encoded, ret, decoded) + } +} + +func TestZcode(t *testing.T) { + testZcodePair(t, "Z", "") + testZcodePair(t, "Zabcdef", "abcdef") + testZcodePair(t, "Z\xff\xf0", "\x00") + testZcodePair(t, "Z\xff\xf1", "\xff") + testZcodePair(t, "Z\xff\xf1abc\xff\xf0def", "\xffabc\x00def") +} + +func TestDecodeZcodeErrors(t *testing.T) { + bad := []string{ + "", + "abc", + "zabc", + "Zab\x00cd", + "Z\xff", + "Z\xff\x80", + "Z\xff\xf2", + } + for _, v := range bad { + if _, err := DecodeZcode([]byte(v)); err == nil { + t.Errorf("DecodeZcode(%v) unexpectedly succeeded", v) + } + } + +} diff --git a/zephyr.go b/zephyr.go new file mode 100644 index 0000000..49db643 --- /dev/null +++ b/zephyr.go @@ -0,0 +1,86 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package zephyr provides a Go zephyr protocol implementation. +// +// The Session API provides a high-level interface to the protocol and +// is one most users will likely use. The Dial functions create +// Sessions with default non-mock parameters and system-default +// parameters. +// +// More specialized clients can use the lower-level Connection API +// which allows for more specialized handling of ACKs and message +// reassembly while implementing server discovery and retransmit +// schedule. There are also even lower-level notice-parsing functions +// for even lower-level clients. +package zephyr + +import ( + "log" + "net" + + "github.com/zephyr-im/hesiod-go" + "github.com/zephyr-im/krb5-go" +) + +// DialSystemDefault opens a new Session using credentials from the +// default ccache and using the system-wide Hesiod config to find the +// zephyrds. +func DialSystemDefault() (*Session, error) { + ctx, err := krb5.NewContext() + if err != nil { + return nil, err + } + defer ctx.Free() + ccache, err := ctx.DefaultCCache() + if err != nil { + return nil, err + } + defer ccache.Close() + client, err := ccache.Principal() + if err != nil { + return nil, err + } + service, err := ctx.ParseName("zephyr/zephyr") + if err != nil { + return nil, err + } + cred, err := ctx.GetCredential(ccache, client, service) + if err != nil { + return nil, err + } + return Dial(hesiod.NewHesiod(), cred, nil) +} + +// Dial creates a new Session using the given Hesiod object and +// credential. +func Dial( + hesiod *hesiod.Hesiod, + cred *krb5.Credential, + logger *log.Logger, +) (*Session, error) { + // Create a server config from Hesiod. + server, err := NewServerFromHesiod(hesiod) + if err != nil { + return nil, err + } + + // Listen on a socket. + udp, err := net.ListenUDP("udp4", nil) + if err != nil { + return nil, err + } + + return NewSession(udp, server, cred, logger) +} diff --git a/zephyrtest/error.go b/zephyrtest/error.go new file mode 100644 index 0000000..9363f10 --- /dev/null +++ b/zephyrtest/error.go @@ -0,0 +1,32 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyrtest + +type temporaryError struct{} + +func (temporaryError) Error() string { + return "Temporary error" +} + +func (temporaryError) Timeout() bool { + return false +} + +func (temporaryError) Temporary() bool { + return true +} + +// TemporaryError is a mock temporary net.Error. +var TemporaryError temporaryError diff --git a/zephyrtest/mock_clock.go b/zephyrtest/mock_clock.go new file mode 100644 index 0000000..5c02e2a --- /dev/null +++ b/zephyrtest/mock_clock.go @@ -0,0 +1,106 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyrtest + +import ( + "container/heap" + "sync" + "time" +) + +type timer struct { + when time.Time + signal chan<- time.Time +} + +// Implements heap.Interface +type timerList []*timer + +func (tl timerList) Len() int { + return len(tl) +} + +func (tl timerList) Less(i, j int) bool { + return tl[i].when.Before(tl[j].when) +} + +func (tl timerList) Swap(i, j int) { + tl[i], tl[j] = tl[j], tl[i] +} + +func (tl *timerList) Push(x interface{}) { + *tl = append(*tl, x.(*timer)) +} + +func (tl *timerList) Pop() interface{} { + old := *tl + n := len(old) + timer := old[n-1] + *tl = old[0 : n-1] + return timer +} + +// A MockClock is a mocked Clock implementation for use with zephyr. +type MockClock struct { + lock sync.Mutex + now time.Time + timerList timerList +} + +// NewMockClockAt creates a new MockClock at a specific current time. +func NewMockClockAt(now time.Time) *MockClock { + mc := &MockClock{now: now} + heap.Init(&mc.timerList) + return mc +} + +// NewMockClock creates a new MockClock. +func NewMockClock() *MockClock { + return NewMockClockAt(time.Date(1990, 8, 3, 12, 0, 0, 0, time.UTC)) +} + +// Now is a mocked implementation of time.Now. +func (mc *MockClock) Now() time.Time { + mc.lock.Lock() + defer mc.lock.Unlock() + return mc.now +} + +// After is a mocked implementation of time.After. +func (mc *MockClock) After(d time.Duration) <-chan time.Time { + signal := make(chan time.Time, 1) + + mc.lock.Lock() + heap.Push(&mc.timerList, &timer{mc.now.Add(d), signal}) + mc.lock.Unlock() + + mc.Advance(0) + return signal +} + +// Advance advances the MockClock by some duration, resolving any +// channels returned by After. +func (mc *MockClock) Advance(d time.Duration) { + mc.lock.Lock() + defer mc.lock.Unlock() + // This is a little wonky. Really it'd be nice to be able to + // loop and advance + RunUntilIdle or something. Maybe put in + // a real sleep. + mc.now = mc.now.Add(d) + for len(mc.timerList) > 0 && !mc.timerList[0].when.After(mc.now) { + mc.timerList[0].signal <- mc.timerList[0].when + heap.Pop(&mc.timerList) + } +} diff --git a/zephyrtest/mock_packet_conn.go b/zephyrtest/mock_packet_conn.go new file mode 100644 index 0000000..5ef3206 --- /dev/null +++ b/zephyrtest/mock_packet_conn.go @@ -0,0 +1,170 @@ +// Copyright 2014 The zephyr-go authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package zephyrtest + +import ( + "io" + "net" + "time" +) + +// A PacketRead is the result of a ReadFrom call on a MockPacketConn. +type PacketRead struct { + Packet []byte + Addr net.Addr + Err error +} + +// A PacketWrite is WriteTo call on a MockPacketConn. The call blocks +// until an error or nil is written into Result. +type PacketWrite struct { + Packet []byte + Addr net.Addr + Result chan error +} + +// A MockPacketConn is a mocked PacketConn implementation for testing +// purposes. +type MockPacketConn struct { + localAddr net.Addr + reads <-chan PacketRead + writes chan PacketWrite + closed chan int +} + +// NewMockPacketConn creates a new mock packet connection for a given +// local address and read channel. +func NewMockPacketConn(localAddr net.Addr, reads <-chan PacketRead) *MockPacketConn { + c := new(MockPacketConn) + c.localAddr = localAddr + c.reads = reads + c.writes = make(chan PacketWrite) + c.closed = make(chan int, 1) + return c +} + +// Writes returns a channel of PacketWrites for each write. +func (c *MockPacketConn) Writes() <-chan PacketWrite { + return c.writes +} + +// ReadFrom consumes the next read in the mock connection's reads +// channel. +func (c *MockPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { + select { + case read, ok := <-c.reads: + if !ok { + return 0, nil, io.EOF + } + if read.Err != nil { + return 0, nil, read.Err + } + n := copy(b, read.Packet) + return n, read.Addr, nil + case <-c.closed: + c.closed <- 0 + return 0, nil, io.EOF + } +} + +// WriteTo sends a PacketWrite to the mock connection's writes +// channel. +func (c *MockPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { + bcopy := make([]byte, len(b)) + copy(bcopy, b) + ret := make(chan error) + c.writes <- PacketWrite{bcopy, addr, ret} + err := <-ret + if err != nil { + return 0, err + } + return len(b), nil +} + +// Close closes this mock connection, interrupting any ReadFrom in +// progress. +func (c *MockPacketConn) Close() error { + c.closed <- 0 + close(c.writes) + return nil +} + +// LocalAddr returns the mocked local address of this connection. +func (c *MockPacketConn) LocalAddr() net.Addr { + return c.localAddr +} + +// SetDeadline is not implemented. +func (c *MockPacketConn) SetDeadline(t time.Time) error { + panic("Not implemented") +} + +// SetReadDeadline is not implemented. +func (c *MockPacketConn) SetReadDeadline(t time.Time) error { + panic("Not implemented") +} + +// SetWriteDeadline is not implemented. +func (c *MockPacketConn) SetWriteDeadline(t time.Time) error { + panic("Not implemented") +} + +func indexOfAddr(haystack []net.Addr, needle net.Addr) int { + for i, straw := range haystack { + // Bah. + if straw.String() == needle.String() { + return i + } + } + return -1 +} + +// NewMockPacketNetwork returns a list of mock PacketConn +// implementations, one for each input address. A WriteTo call in one +// PacketConn will be routed to the appropriate destination PacketConn +// and read back. +func NewMockPacketNetwork(addrs []net.Addr) []net.PacketConn { + sinks := make([]chan<- PacketRead, len(addrs)) + conns := make([]*MockPacketConn, len(addrs)) + ret := make([]net.PacketConn, len(addrs)) + + // Create connections and link them all together. + for i := range addrs { + readChan := make(chan PacketRead) + conn := NewMockPacketConn(addrs[i], readChan) + sinks[i] = readChan + conns[i] = conn + ret[i] = conn + + go func(i int) { + for write := range conns[i].Writes() { + // Writes in a mock network succeed. + write.Result <- nil + + // Direct the write to the appropriate node. + dst := indexOfAddr(addrs, write.Addr) + if dst < 0 { + // Don't allow this. + panic(write) + } + pkt := make([]byte, len(write.Packet)) + copy(pkt, write.Packet) + sinks[dst] <- PacketRead{write.Packet, addrs[i], nil} + } + }(i) + } + + return ret +}