From c4e3a29c56fa4c920b1c526d1406b0ea73294eb9 Mon Sep 17 00:00:00 2001
From: Luke Wagner <mail@lukewagner.name>
Date: Fri, 24 Jan 2025 14:41:04 -0600
Subject: [PATCH] Allow stream.{read,write}s of length 0 to query/signal
 readiness

---
 design/mvp/Async.md                     |  6 ++-
 design/mvp/CanonicalABI.md              | 65 ++++++++++++++++---------
 design/mvp/canonical-abi/definitions.py | 39 +++++++++------
 design/mvp/canonical-abi/run_tests.py   | 27 ++++++++--
 4 files changed, 95 insertions(+), 42 deletions(-)

diff --git a/design/mvp/Async.md b/design/mvp/Async.md
index 2e98fce1..7686dd75 100644
--- a/design/mvp/Async.md
+++ b/design/mvp/Async.md
@@ -356,7 +356,11 @@ These built-ins can either return immediately if >0 elements were able to be
 written or read immediately (without blocking) or return a sentinel "blocked"
 value indicating that the read or write will execute concurrently. The readable
 and writable ends of streams and futures can then be [waited](#waiting) on to
-make progress.
+make progress. Notification of progress signals *completion* of a read or write
+(i.e., the bytes have already been copied into the buffer). Additionally,
+*readiness* (to perform a read or write in the future) can be queried and
+signalled by performing a `0`-length read or write (see the [Stream State]
+section in the Canonical ABI explainer for details).
 
 As a temporary limitation, if a `read` and `write` for a single stream or
 future occur from within the same component, there is a trap. In the future
diff --git a/design/mvp/CanonicalABI.md b/design/mvp/CanonicalABI.md
index 1c918589..ae1916e1 100644
--- a/design/mvp/CanonicalABI.md
+++ b/design/mvp/CanonicalABI.md
@@ -415,8 +415,8 @@ class BufferGuestImpl(Buffer):
   length: int
 
   def __init__(self, t, cx, ptr, length):
-    trap_if(length == 0 or length > Buffer.MAX_LENGTH)
-    if t:
+    trap_if(length > Buffer.MAX_LENGTH)
+    if t and length > 0:
       trap_if(ptr != align_to(ptr, alignment(t)))
       trap_if(ptr + length * elem_size(t) > len(cx.opts.memory))
     self.cx = cx
@@ -1299,10 +1299,13 @@ class ReadableStreamGuestImpl(ReadableStream):
     self.reset_pending()
 
   def reset_pending(self):
-    self.pending_inst = None
-    self.pending_buffer = None
-    self.pending_on_partial_copy = None
-    self.pending_on_copy_done = None
+    self.set_pending(None, None, None, None)
+
+  def set_pending(self, inst, buffer, on_partial_copy, on_copy_done):
+    self.pending_inst = inst
+    self.pending_buffer = buffer
+    self.pending_on_partial_copy = on_partial_copy
+    self.pending_on_copy_done = on_copy_done
 ```
 If set, the `pending_*` fields record the `Buffer` and `On*` callbacks of a
 `read` or `write` that is waiting to rendezvous with a complementary `write` or
@@ -1356,27 +1359,45 @@ but in the opposite direction. Both are implemented by a single underlying
     if self.closed_:
       return 'done'
     elif not self.pending_buffer:
-      self.pending_inst = inst
-      self.pending_buffer = buffer
-      self.pending_on_partial_copy = on_partial_copy
-      self.pending_on_copy_done = on_copy_done
+      self.set_pending(inst, buffer, on_partial_copy, on_copy_done)
       return 'blocked'
     else:
       trap_if(inst is self.pending_inst) # temporary
-      ncopy = min(src.remain(), dst.remain())
-      assert(ncopy > 0)
-      dst.write(src.read(ncopy))
       if self.pending_buffer.remain() > 0:
-        self.pending_on_partial_copy(self.reset_pending)
+        if buffer.remain() > 0:
+          dst.write(src.read(min(src.remain(), dst.remain())))
+          if self.pending_buffer.remain() > 0:
+            self.pending_on_partial_copy(self.reset_pending)
+          else:
+            self.reset_and_notify_pending('completed')
+        return 'done'
       else:
-        self.reset_and_notify_pending('completed')
-      return 'done'
-```
-Currently, there is a trap when both the `read` and `write` come from the same
-component instance, but this trapping condition will be removed in a subsequent
-release. The reason for this trap is that when lifting and lowering can alias
-the same memory, interleaving must be handled carefully. Future improvements to
-the Canonical ABI ([lazy lowering]) can greatly simplify this interleaving.
+        if buffer.remain() > 0 or buffer is dst:
+          self.reset_and_notify_pending('completed')
+          self.set_pending(inst, buffer, on_partial_copy, on_copy_done)
+          return 'blocked'
+        else:
+          return 'done'
+```
+The meaning of a `read` or `write` when the length is `0` is that the caller is
+querying the "readiness" of the other side. When a `0`-length read/write
+rendezvous with a non-`0`-length read/write, only the `0`-length read/write
+completes; the non-`0`-length read/write is kept pending (and ready for a
+subsequent rendezvous).
+
+In the corner case where a `0`-length read *and* write rendezvous, only the
+*writer* is notified of readiness. To avoid livelock, the Canonical ABI
+requires that a writer *must* (eventually) follow a completed `0`-length write
+with a non-`0`-length write that is allowed to block (allowing the reader end
+to run and rendezvous with its own non-`0`-length read). To implement a
+traditional `O_NONBLOCK` `write()` or `sendmsg()` API, a writer can use a
+buffering scheme in which, after `select()` (or a similar API) signals a file
+descriptor is ready to write, the next `O_NONBLOCK` `write()`/`sendmsg()` on
+that file descriptor copies to an internal buffer and suceeds, issuing an
+`async` `stream.write` in the background and waiting for completion before
+signalling readiness again. Note that buffering only occurs when streaming
+between two components using non-blocking I/O; if either side is the host or a
+component using blocking or completion-based I/O, no buffering is necessary.
 
 Given the above, we can define the `{Readable,Writable}StreamEnd` classes that
 are actually stored in the `waitables` table. The classes are almost entirely
diff --git a/design/mvp/canonical-abi/definitions.py b/design/mvp/canonical-abi/definitions.py
index 6f015aa1..680aaf99 100644
--- a/design/mvp/canonical-abi/definitions.py
+++ b/design/mvp/canonical-abi/definitions.py
@@ -327,8 +327,8 @@ class BufferGuestImpl(Buffer):
   length: int
 
   def __init__(self, t, cx, ptr, length):
-    trap_if(length == 0 or length > Buffer.MAX_LENGTH)
-    if t:
+    trap_if(length > Buffer.MAX_LENGTH)
+    if t and length > 0:
       trap_if(ptr != align_to(ptr, alignment(t)))
       trap_if(ptr + length * elem_size(t) > len(cx.opts.memory))
     self.cx = cx
@@ -772,10 +772,13 @@ def __init__(self, t):
     self.reset_pending()
 
   def reset_pending(self):
-    self.pending_inst = None
-    self.pending_buffer = None
-    self.pending_on_partial_copy = None
-    self.pending_on_copy_done = None
+    self.set_pending(None, None, None, None)
+
+  def set_pending(self, inst, buffer, on_partial_copy, on_copy_done):
+    self.pending_inst = inst
+    self.pending_buffer = buffer
+    self.pending_on_partial_copy = on_partial_copy
+    self.pending_on_copy_done = on_copy_done
 
   def reset_and_notify_pending(self, why):
     pending_on_copy_done = self.pending_on_copy_done
@@ -804,21 +807,25 @@ def copy(self, inst, buffer, on_partial_copy, on_copy_done, src, dst):
     if self.closed_:
       return 'done'
     elif not self.pending_buffer:
-      self.pending_inst = inst
-      self.pending_buffer = buffer
-      self.pending_on_partial_copy = on_partial_copy
-      self.pending_on_copy_done = on_copy_done
+      self.set_pending(inst, buffer, on_partial_copy, on_copy_done)
       return 'blocked'
     else:
       trap_if(inst is self.pending_inst) # temporary
-      ncopy = min(src.remain(), dst.remain())
-      assert(ncopy > 0)
-      dst.write(src.read(ncopy))
       if self.pending_buffer.remain() > 0:
-        self.pending_on_partial_copy(self.reset_pending)
+        if buffer.remain() > 0:
+          dst.write(src.read(min(src.remain(), dst.remain())))
+          if self.pending_buffer.remain() > 0:
+            self.pending_on_partial_copy(self.reset_pending)
+          else:
+            self.reset_and_notify_pending('completed')
+        return 'done'
       else:
-        self.reset_and_notify_pending('completed')
-      return 'done'
+        if buffer.remain() > 0 or buffer is dst:
+          self.reset_and_notify_pending('completed')
+          self.set_pending(inst, buffer, on_partial_copy, on_copy_done)
+          return 'blocked'
+        else:
+          return 'done'
 
 class StreamEnd(Waitable):
   stream: ReadableStream
diff --git a/design/mvp/canonical-abi/run_tests.py b/design/mvp/canonical-abi/run_tests.py
index 25a63df3..6a6ae08c 100644
--- a/design/mvp/canonical-abi/run_tests.py
+++ b/design/mvp/canonical-abi/run_tests.py
@@ -1503,8 +1503,19 @@ async def core_func1(task, args):
     result,n = unpack_result(mem1[retp+4])
     assert(n == 4 and result == definitions.COMPLETED)
 
+    [ret] = await canon_stream_write(U8Type(), opts1, task, wsi, 12345, 0)
+    assert(ret == definitions.BLOCKED)
+
     fut4.set_result(None)
 
+    [event] = await canon_waitable_set_wait(False, mem1, task, seti, retp)
+    assert(event == EventCode.STREAM_WRITE)
+    assert(mem1[retp+0] == wsi)
+    assert(mem1[retp+4] == 0)
+
+    [ret] = await canon_stream_write(U8Type(), opts1, task, wsi, 12345, 0)
+    assert(ret == 0)
+
     [errctxi] = await canon_error_context_new(opts1, task, 0, 0)
     [] = await canon_stream_close_writable(U8Type(), task, wsi)
     [] = await canon_waitable_set_drop(task, seti)
@@ -1545,6 +1556,9 @@ async def core_func2(task, args):
     fut2.set_result(None)
     await task.on_block(fut3)
 
+    [ret] = await canon_stream_read(U8Type(), opts2, task, rsi, 12345, 0)
+    assert(ret == 0)
+
     mem2[0:8] = bytes(8)
     [ret] = await canon_stream_read(U8Type(), opts2, task, rsi, 0, 2)
     result,n = unpack_result(ret)
@@ -1557,9 +1571,16 @@ async def core_func2(task, args):
 
     await task.on_block(fut4)
 
-    [ret] = await canon_stream_read(U8Type(), opts2, task, rsi, 0, 2)
-    result,n = unpack_result(ret)
-    assert(n == 0 and result == definitions.CLOSED)
+    [ret] = await canon_stream_read(U8Type(), opts2, task, rsi, 12345, 0)
+    assert(ret == definitions.BLOCKED)
+
+    [event] = await canon_waitable_set_wait(False, mem2, task, seti, retp)
+    assert(event == EventCode.STREAM_READ)
+    assert(mem2[retp+0] == rsi)
+    p2 = int.from_bytes(mem2[retp+4 : retp+8], 'little', signed=False)
+    errctxi = 1
+    assert(p2 == (definitions.CLOSED | errctxi))
+
     [] = await canon_stream_close_readable(U8Type(), task, rsi)
     [] = await canon_waitable_set_drop(task, seti)
     return []