diff --git a/ext/openssl/ossl.c b/ext/openssl/ossl.c
index 60780790b..d148447dd 100644
--- a/ext/openssl/ossl.c
+++ b/ext/openssl/ossl.c
@@ -1033,6 +1033,7 @@ Init_openssl(void)
* Init components
*/
Init_ossl_asn1();
+ Init_ossl_bio();
Init_ossl_bn();
Init_ossl_cipher();
Init_ossl_config();
diff --git a/ext/openssl/ossl_bio.c b/ext/openssl/ossl_bio.c
index 2ef208050..aad7a4908 100644
--- a/ext/openssl/ossl_bio.c
+++ b/ext/openssl/ossl_bio.c
@@ -40,3 +40,324 @@ ossl_membio2str(BIO *bio)
return ret;
}
+
+static BIO_METHOD *ossl_bio_meth;
+static VALUE nonblock_kwargs, sym_wait_readable, sym_wait_writable;
+
+struct ossl_bio_ctx {
+ VALUE io;
+ int state;
+ int eof;
+};
+
+static void
+bio_free(void *ptr)
+{
+ BIO_free(ptr);
+}
+
+static void
+bio_mark(void *ptr)
+{
+ struct ossl_bio_ctx *ctx = BIO_get_data(ptr);
+ rb_gc_mark_movable(ctx->io);
+}
+
+static void
+bio_compact(void *ptr)
+{
+ struct ossl_bio_ctx *ctx = BIO_get_data(ptr);
+ ctx->io = rb_gc_location(ctx->io);
+}
+
+static const rb_data_type_t ossl_bio_type = {
+ "OpenSSL/BIO",
+ {
+ .dmark = bio_mark,
+ .dfree = bio_free,
+ .dcompact = bio_compact,
+ },
+ 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED,
+};
+
+VALUE
+ossl_bio_new(VALUE io)
+{
+ VALUE obj = TypedData_Wrap_Struct(rb_cObject, &ossl_bio_type, NULL);
+ BIO *bio = BIO_new(ossl_bio_meth);
+ if (!bio)
+ ossl_raise(eOSSLError, "BIO_new");
+
+ struct ossl_bio_ctx *ctx = BIO_get_data(bio);
+ ctx->io = io;
+ BIO_set_init(bio, 1);
+ RTYPEDDATA_DATA(obj) = bio;
+ return obj;
+}
+
+BIO *
+ossl_bio_get(VALUE obj)
+{
+ BIO *bio;
+ TypedData_Get_Struct(obj, BIO, &ossl_bio_type, bio);
+ return bio;
+}
+
+int
+ossl_bio_state(VALUE obj)
+{
+ BIO *bio;
+ TypedData_Get_Struct(obj, BIO, &ossl_bio_type, bio);
+
+ struct ossl_bio_ctx *ctx = BIO_get_data(bio);
+ int state = ctx->state;
+ ctx->state = 0;
+ return state;
+}
+
+static int
+bio_create(BIO *bio)
+{
+ struct ossl_bio_ctx *ctx = OPENSSL_malloc(sizeof(*ctx));
+ if (!ctx)
+ return 0;
+ memset(ctx, 0, sizeof(*ctx));
+ BIO_set_data(bio, ctx);
+
+ return 1;
+}
+
+static int
+bio_destroy(BIO *bio)
+{
+ struct ossl_bio_ctx *ctx = BIO_get_data(bio);
+ if (ctx) {
+ OPENSSL_free(ctx);
+ BIO_set_data(bio, NULL);
+ }
+
+ return 1;
+}
+
+struct bwrite_args {
+ BIO *bio;
+ struct ossl_bio_ctx *ctx;
+ const char *data;
+ int dlen;
+ int written;
+};
+
+static VALUE
+bio_bwrite0(VALUE args)
+{
+ struct bwrite_args *p = (void *)args;
+ BIO_clear_retry_flags(p->bio);
+
+ VALUE fargs[] = { rb_str_new_static(p->data, p->dlen), nonblock_kwargs };
+ VALUE ret = rb_funcallv_kw(p->ctx->io, rb_intern("write_nonblock"),
+ 2, fargs, RB_PASS_KEYWORDS);
+
+ if (RB_INTEGER_TYPE_P(ret)) {
+ p->written = NUM2INT(ret);
+ return Qtrue;
+ }
+ else if (ret == sym_wait_readable) {
+ BIO_set_retry_read(p->bio);
+ return Qfalse;
+ }
+ else if (ret == sym_wait_writable) {
+ BIO_set_retry_write(p->bio);
+ return Qfalse;
+ }
+ else {
+ rb_raise(rb_eTypeError, "write_nonblock must return an Integer, "
+ ":wait_readable, or :wait_writable");
+ }
+}
+
+struct call0_args {
+ VALUE (*func)(VALUE);
+ VALUE args;
+ VALUE ret;
+};
+
+static VALUE
+do_nothing(VALUE _)
+{
+ return Qnil;
+}
+
+static VALUE
+call_protect1(VALUE args_)
+{
+ struct call0_args *args = (void *)args_;
+ rb_set_errinfo(Qnil);
+ args->ret = args->func(args->args);
+ return Qnil;
+}
+
+static VALUE
+call_protect0(VALUE args_)
+{
+ /*
+ * At this point rb_errinfo() may be set by another callback called from
+ * the same OpenSSL function (e.g., SSL_accept()).
+ *
+ * Abusing rb_ensure() to temporarily save errinfo and restore it after
+ * the BIO callback successfully returns.
+ */
+ rb_ensure(do_nothing, Qnil, call_protect1, args_);
+ return Qnil;
+}
+
+static VALUE
+call_protect(VALUE (*func)(VALUE), VALUE args, int *state)
+{
+ /*
+ * FIXME: should check !NIL_P(rb_ivar_get(ssl_obj, ID_callback_state))
+ * instead to see if a tag jump is pending or not.
+ */
+ int pending = !NIL_P(rb_errinfo());
+ struct call0_args call0_args = { func, args, Qfalse };
+ rb_protect(call_protect0, (VALUE)&call0_args, state);
+ if (pending && *state)
+ rb_warn("exception ignored in BIO callback: pending=%d", pending);
+ return call0_args.ret;
+}
+
+static int
+bio_bwrite(BIO *bio, const char *data, int dlen)
+{
+ struct ossl_bio_ctx *ctx = BIO_get_data(bio);
+ struct bwrite_args args = { bio, ctx, data, dlen, 0 };
+ int state;
+
+ if (ctx->state)
+ return -1;
+
+ VALUE ok = call_protect(bio_bwrite0, (VALUE)&args, &state);
+ if (state) {
+ ctx->state = state;
+ return -1;
+ }
+ if (RTEST(ok))
+ return args.written;
+ return -1;
+}
+
+struct bread_args {
+ BIO *bio;
+ struct ossl_bio_ctx *ctx;
+ char *data;
+ int dlen;
+ int readbytes;
+};
+
+static VALUE
+bio_bread0(VALUE args)
+{
+ struct bread_args *p = (void *)args;
+ BIO_clear_retry_flags(p->bio);
+
+ VALUE fargs[] = { INT2NUM(p->dlen), nonblock_kwargs };
+ VALUE ret = rb_funcallv_kw(p->ctx->io, rb_intern("read_nonblock"),
+ 2, fargs, RB_PASS_KEYWORDS);
+
+ if (RB_TYPE_P(ret, T_STRING)) {
+ int len = RSTRING_LENINT(ret);
+ if (len > p->dlen)
+ rb_raise(rb_eTypeError, "read_nonblock returned too much data");
+ memcpy(p->data, RSTRING_PTR(ret), len);
+ p->readbytes = len;
+ return Qtrue;
+ }
+ else if (NIL_P(ret)) {
+ // In OpenSSL 3.0 or later: BIO_set_flags(p->bio, BIO_FLAGS_IN_EOF);
+ p->ctx->eof = 1;
+ return Qtrue;
+ }
+ else if (ret == sym_wait_readable) {
+ BIO_set_retry_read(p->bio);
+ return Qfalse;
+ }
+ else if (ret == sym_wait_writable) {
+ BIO_set_retry_write(p->bio);
+ return Qfalse;
+ }
+ else {
+ rb_raise(rb_eTypeError, "write_nonblock must return an Integer, "
+ ":wait_readable, or :wait_writable");
+ }
+}
+
+static int
+bio_bread(BIO *bio, char *data, int dlen)
+{
+ struct ossl_bio_ctx *ctx = BIO_get_data(bio);
+ struct bread_args args = { bio, ctx, data, dlen, 0 };
+ int state;
+
+ if (ctx->state)
+ return -1;
+
+ VALUE ok = call_protect(bio_bread0, (VALUE)&args, &state);
+ if (state) {
+ ctx->state = state;
+ return -1;
+ }
+ if (RTEST(ok))
+ return args.readbytes;
+ return -1;
+}
+
+static VALUE
+bio_flush0(VALUE vctx)
+{
+ struct ossl_bio_ctx *ctx = (void *)vctx;
+ return rb_funcallv(ctx->io, rb_intern("flush"), 0, NULL);
+}
+
+static long
+bio_ctrl(BIO *bio, int cmd, long larg, void *parg)
+{
+ struct ossl_bio_ctx *ctx = BIO_get_data(bio);
+ int state;
+
+ if (ctx->state)
+ return 0;
+
+ switch (cmd) {
+ case BIO_CTRL_EOF:
+ return ctx->eof;
+ case BIO_CTRL_FLUSH:
+ call_protect(bio_flush0, (VALUE)ctx, &state);
+ ctx->state = state;
+ return !state;
+ default:
+ return 0;
+ }
+}
+
+void
+Init_ossl_bio(void)
+{
+ ossl_bio_meth = BIO_meth_new(BIO_TYPE_SOURCE_SINK, "Ruby IO-like object");
+ if (!ossl_bio_meth)
+ ossl_raise(eOSSLError, "BIO_meth_new");
+ if (!BIO_meth_set_create(ossl_bio_meth, bio_create) ||
+ !BIO_meth_set_destroy(ossl_bio_meth, bio_destroy) ||
+ !BIO_meth_set_write(ossl_bio_meth, bio_bwrite) ||
+ !BIO_meth_set_read(ossl_bio_meth, bio_bread) ||
+ !BIO_meth_set_ctrl(ossl_bio_meth, bio_ctrl)) {
+ BIO_meth_free(ossl_bio_meth);
+ ossl_bio_meth = NULL;
+ ossl_raise(eOSSLError, "BIO_meth_set_*");
+ }
+
+ nonblock_kwargs = rb_hash_new();
+ rb_hash_aset(nonblock_kwargs, ID2SYM(rb_intern_const("exception")), Qfalse);
+ rb_global_variable(&nonblock_kwargs);
+
+ sym_wait_readable = ID2SYM(rb_intern_const("wait_readable"));
+ sym_wait_writable = ID2SYM(rb_intern_const("wait_writable"));
+}
diff --git a/ext/openssl/ossl_bio.h b/ext/openssl/ossl_bio.h
index 1b871f1cd..634f99fae 100644
--- a/ext/openssl/ossl_bio.h
+++ b/ext/openssl/ossl_bio.h
@@ -13,4 +13,10 @@
BIO *ossl_obj2bio(volatile VALUE *);
VALUE ossl_membio2str(BIO*);
+VALUE ossl_bio_new(VALUE io);
+BIO *ossl_bio_get(VALUE obj);
+int ossl_bio_state(VALUE obj);
+
+void Init_ossl_bio(void);
+
#endif
diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c
index a5b25e14d..3a3fceef1 100644
--- a/ext/openssl/ossl_ssl.c
+++ b/ext/openssl/ossl_ssl.c
@@ -37,7 +37,7 @@ static VALUE eSSLErrorWaitReadable;
static VALUE eSSLErrorWaitWritable;
static ID id_call, ID_callback_state, id_tmp_dh_callback,
- id_npn_protocols_encoded, id_each;
+ id_npn_protocols_encoded, id_each, id_bio;
static VALUE sym_exception, sym_wait_readable, sym_wait_writable;
static ID id_i_cert_store, id_i_ca_file, id_i_ca_path, id_i_verify_mode,
@@ -1536,7 +1536,10 @@ ossl_ssl_s_alloc(VALUE klass)
static VALUE
peer_ip_address(VALUE self)
{
- VALUE remote_address = rb_funcall(rb_attr_get(self, id_i_io), rb_intern("remote_address"), 0);
+ VALUE io = rb_attr_get(self, id_i_io);
+ VALUE remote_address = rb_check_funcall(io, rb_intern("remote_address"), 0, NULL);
+ if (remote_address == Qundef)
+ return rb_str_new_cstr("(unsupported)");
return rb_funcall(remote_address, rb_intern("inspect_sockaddr"), 0);
}
@@ -1550,10 +1553,16 @@ fallback_peer_ip_address(VALUE self, VALUE args)
static VALUE
peeraddr_ip_str(VALUE self)
{
- VALUE rb_mErrno = rb_const_get(rb_cObject, rb_intern("Errno"));
- VALUE rb_eSystemCallError = rb_const_get(rb_mErrno, rb_intern("SystemCallError"));
+ return rb_rescue2(peer_ip_address, self, fallback_peer_ip_address, Qnil,
+ rb_eSystemCallError, (VALUE)0);
+}
- return rb_rescue2(peer_ip_address, self, fallback_peer_ip_address, (VALUE)0, rb_eSystemCallError, NULL);
+static int
+is_real_socket(VALUE io)
+{
+ // FIXME: DO NOT MERGE
+ return 0;
+ return RB_TYPE_P(io, T_FILE);
}
/*
@@ -1561,8 +1570,10 @@ peeraddr_ip_str(VALUE self)
* SSLSocket.new(io) => aSSLSocket
* SSLSocket.new(io, ctx) => aSSLSocket
*
- * Creates a new SSL socket from _io_ which must be a real IO object (not an
- * IO-like object that responds to read/write).
+ * Creates a new SSL socket from the underlying socket _io_ and _ctx_.
+ *
+ * _io_ must be an IO object, typically a TCPSocket or Socket from the socket
+ * library, or an IO-like object that supports the typical IO methods.
*
* If _ctx_ is provided the SSL Sockets initial params will be taken from
* the context.
@@ -1571,6 +1582,22 @@ peeraddr_ip_str(VALUE self)
*
* This method will freeze the SSLContext if one is provided;
* however, session management is still allowed in the frozen SSLContext.
+ *
+ * == Support for IO-like objects
+ *
+ * Support for IO-like objects was added in version 3.3 and is experimental.
+ *
+ * As of version 3.3, SSLSocket uses the following methods:
+ *
+ * - write_nonblock with the exception: false option
+ * - read_nonblock with the exception: false option
+ * - wait_readable
+ * - wait_writable
+ * - flush
+ * - close
+ * - closed?
+ *
+ * Note that future versions may require additional methods to be implemented.
*/
static VALUE
ossl_ssl_initialize(int argc, VALUE *argv, VALUE self)
@@ -1590,9 +1617,18 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self)
rb_ivar_set(self, id_i_context, v_ctx);
ossl_sslctx_setup(v_ctx);
- if (rb_respond_to(io, rb_intern("nonblock=")))
- rb_funcall(io, rb_intern("nonblock="), 1, Qtrue);
- Check_Type(io, T_FILE);
+ if (is_real_socket(io)) {
+ rb_io_t *fptr;
+ GetOpenFile(io, fptr);
+ rb_io_set_nonblock(fptr);
+ }
+ else {
+ // Not meant to be a comprehensive check
+ if (!rb_respond_to(io, rb_intern("read_nonblock")) ||
+ !rb_respond_to(io, rb_intern("write_nonblock")))
+ rb_raise(rb_eTypeError, "io must be a real IO object or an IO-like "
+ "object that responds to read_nonblock and write_nonblock");
+ }
rb_ivar_set(self, id_i_io, io);
ssl = SSL_new(ctx);
@@ -1624,27 +1660,57 @@ ossl_ssl_setup(VALUE self)
{
VALUE io;
SSL *ssl;
- rb_io_t *fptr;
GetSSL(self, ssl);
if (ssl_started(ssl))
return Qtrue;
io = rb_attr_get(self, id_i_io);
- GetOpenFile(io, fptr);
- rb_io_check_readable(fptr);
- rb_io_check_writable(fptr);
- if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io))))
- ossl_raise(eSSLError, "SSL_set_fd");
+ if (is_real_socket(io)) {
+ rb_io_t *fptr;
+ GetOpenFile(io, fptr);
+ rb_io_check_readable(fptr);
+ rb_io_check_writable(fptr);
+ if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io))))
+ ossl_raise(eSSLError, "SSL_set_fd");
+ }
+ else {
+ VALUE bobj = ossl_bio_new(io);
+ rb_ivar_set(self, id_bio, bobj);
+
+ BIO *bio = ossl_bio_get(bobj);
+ if (!BIO_up_ref(bio))
+ ossl_raise(eSSLError, "BIO_up_ref");
+ SSL_set_bio(ssl, bio, bio);
+ }
return Qtrue;
}
+static void
+check_bio_error(VALUE self, SSL *ssl, VALUE bobj, int ret)
+{
+ if (NIL_P(bobj)) {
#ifdef _WIN32
-#define ssl_get_error(ssl, ret) (errno = rb_w32_map_errno(WSAGetLastError()), SSL_get_error((ssl), (ret)))
-#else
-#define ssl_get_error(ssl, ret) SSL_get_error((ssl), (ret))
+ errno = rb_w32_map_errno(WSAGetLastError());
#endif
+ }
+ else {
+ int state = ossl_bio_state(bobj);
+ if (state) {
+ ossl_clear_error();
+ rb_jump_tag(state);
+ }
+ errno = 0;
+ }
+
+ VALUE cb_state = rb_attr_get(self, ID_callback_state);
+ if (!NIL_P(cb_state)) {
+ /* must cleanup OpenSSL error stack before re-raising */
+ ossl_clear_error();
+ rb_jump_tag(NUM2INT(cb_state));
+ }
+}
static void
write_would_block(int nonblock)
@@ -1684,6 +1750,11 @@ no_exception_p(VALUE opts)
static void
io_wait_writable(VALUE io)
{
+ if (!is_real_socket(io)) {
+ if (!RTEST(rb_funcallv(io, rb_intern("wait_writable"), 0, NULL)))
+ rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!");
+ return;
+ }
#ifdef HAVE_RB_IO_MAYBE_WAIT
if (!rb_io_maybe_wait_writable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) {
rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!");
@@ -1698,6 +1769,11 @@ io_wait_writable(VALUE io)
static void
io_wait_readable(VALUE io)
{
+ if (!is_real_socket(io)) {
+ if (!RTEST(rb_funcallv(io, rb_intern("wait_readable"), 0, NULL)))
+ rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!");
+ return;
+ }
#ifdef HAVE_RB_IO_MAYBE_WAIT
if (!rb_io_maybe_wait_readable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) {
rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!");
@@ -1714,28 +1790,22 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts)
{
SSL *ssl;
int ret, ret2;
- VALUE cb_state;
int nonblock = opts != Qfalse;
- rb_ivar_set(self, ID_callback_state, Qnil);
-
GetSSL(self, ssl);
- VALUE io = rb_attr_get(self, id_i_io);
+ VALUE io = rb_attr_get(self, id_i_io),
+ bobj = rb_attr_get(self, id_bio);
+
+ rb_ivar_set(self, ID_callback_state, Qnil);
for (;;) {
ret = func(ssl);
-
- cb_state = rb_attr_get(self, ID_callback_state);
- if (!NIL_P(cb_state)) {
- /* must cleanup OpenSSL error stack before re-raising */
- ossl_clear_error();
- rb_jump_tag(NUM2INT(cb_state));
- }
+ check_bio_error(self, ssl, bobj, ret);
if (ret > 0)
break;
- switch ((ret2 = ssl_get_error(ssl, ret))) {
+ switch ((ret2 = SSL_get_error(ssl, ret))) {
case SSL_ERROR_WANT_WRITE:
if (no_exception_p(opts)) { return sym_wait_writable; }
write_would_block(nonblock);
@@ -1885,7 +1955,7 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
{
SSL *ssl;
int ilen;
- VALUE len, str, cb_state;
+ VALUE len, str;
VALUE opts = Qnil;
if (nonblock) {
@@ -1913,21 +1983,17 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock)
return str;
}
- VALUE io = rb_attr_get(self, id_i_io);
+ VALUE io = rb_attr_get(self, id_i_io),
+ bobj = rb_attr_get(self, id_bio);
+ rb_ivar_set(self, ID_callback_state, Qnil);
for (;;) {
rb_str_locktmp(str);
int nread = SSL_read(ssl, RSTRING_PTR(str), ilen);
rb_str_unlocktmp(str);
+ check_bio_error(self, ssl, bobj, nread);
- cb_state = rb_attr_get(self, ID_callback_state);
- if (!NIL_P(cb_state)) {
- rb_ivar_set(self, ID_callback_state, Qnil);
- ossl_clear_error();
- rb_jump_tag(NUM2INT(cb_state));
- }
-
- switch (ssl_get_error(ssl, nread)) {
+ switch (SSL_get_error(ssl, nread)) {
case SSL_ERROR_NONE:
rb_str_set_len(str, nread);
return str;
@@ -2018,33 +2084,26 @@ ossl_ssl_write_internal_safe(VALUE _args)
VALUE opts = args[2];
SSL *ssl;
- rb_io_t *fptr;
int num, nonblock = opts != Qfalse;
- VALUE cb_state;
GetSSL(self, ssl);
if (!ssl_started(ssl))
rb_raise(eSSLError, "SSL session is not started yet");
- VALUE io = rb_attr_get(self, id_i_io);
- GetOpenFile(io, fptr);
-
/* SSL_write(3ssl) manpage states num == 0 is undefined */
num = RSTRING_LENINT(str);
if (num == 0)
return INT2FIX(0);
+ VALUE io = rb_attr_get(self, id_i_io),
+ bobj = rb_attr_get(self, id_bio);
+
+ rb_ivar_set(self, ID_callback_state, Qnil);
for (;;) {
int nwritten = SSL_write(ssl, RSTRING_PTR(str), num);
+ check_bio_error(self, ssl, bobj, nwritten);
- cb_state = rb_attr_get(self, ID_callback_state);
- if (!NIL_P(cb_state)) {
- rb_ivar_set(self, ID_callback_state, Qnil);
- ossl_clear_error();
- rb_jump_tag(NUM2INT(cb_state));
- }
-
- switch (ssl_get_error(ssl, nwritten)) {
+ switch (SSL_get_error(ssl, nwritten)) {
case SSL_ERROR_NONE:
return INT2NUM(nwritten);
case SSL_ERROR_WANT_WRITE:
@@ -2143,7 +2202,14 @@ ossl_ssl_stop(VALUE self)
GetSSL(self, ssl);
if (!ssl_started(ssl))
return Qnil;
+
ret = SSL_shutdown(ssl);
+
+ /* XXX: Suppressing errors from the underlying socket */
+ VALUE bobj = rb_attr_get(self, id_bio);
+ if (!NIL_P(bobj) && ossl_bio_state(bobj))
+ rb_set_errinfo(Qnil);
+
if (ret == 1) /* Have already received close_notify */
return Qnil;
if (ret == 0) /* Sent close_notify, but we don't wait for reply */
@@ -3121,6 +3187,7 @@ Init_ossl_ssl(void)
id_tmp_dh_callback = rb_intern_const("tmp_dh_callback");
id_npn_protocols_encoded = rb_intern_const("npn_protocols_encoded");
id_each = rb_intern_const("each");
+ id_bio = rb_intern_const("bio");
#define DefIVarID(name) do \
id_i_##name = rb_intern_const("@"#name); while (0)
diff --git a/lib/openssl/buffering.rb b/lib/openssl/buffering.rb
index 1464a4292..aa39105f9 100644
--- a/lib/openssl/buffering.rb
+++ b/lib/openssl/buffering.rb
@@ -60,7 +60,7 @@ def initialize(*)
super
@eof = false
@rbuffer = Buffer.new
- @sync = @io.sync
+ @sync = @io.respond_to?(:sync) ? @io.sync : true
end
#
diff --git a/test/openssl/test_pair.rb b/test/openssl/test_pair.rb
index 10942191d..1664c00c8 100644
--- a/test/openssl/test_pair.rb
+++ b/test/openssl/test_pair.rb
@@ -67,6 +67,32 @@ def create_tcp_client(host, port)
end
end
+module OpenSSL::SSLPairIOish
+ include OpenSSL::SSLPairM
+
+ def create_tcp_server(host, port)
+ Addrinfo.tcp(host, port).listen
+ end
+
+ class TCPSocketWrapper
+ def initialize(io) @io = io end
+ def read_nonblock(*args, **kwargs) @io.read_nonblock(*args, **kwargs) end
+ def write_nonblock(*args, **kwargs) @io.write_nonblock(*args, **kwargs) end
+ def wait_readable() @io.wait_readable end
+ def wait_writable() @io.wait_writable end
+ def flush() @io.flush end
+ def close() @io.close end
+ def closed?() @io.closed? end
+
+ # Only used within test_pair.rb
+ def write(*args) @io.write(*args) end
+ end
+
+ def create_tcp_client(host, port)
+ TCPSocketWrapper.new(Addrinfo.tcp(host, port).connect)
+ end
+end
+
module OpenSSL::TestEOF1M
def open_file(content)
ssl_pair { |s1, s2|
@@ -518,6 +544,12 @@ class OpenSSL::TestEOF1LowlevelSocket < OpenSSL::TestCase
include OpenSSL::TestEOF1M
end
+class OpenSSL::TestEOF1IOish < OpenSSL::TestCase
+ include OpenSSL::TestEOF
+ include OpenSSL::SSLPairIOish
+ include OpenSSL::TestEOF1M
+end
+
class OpenSSL::TestEOF2 < OpenSSL::TestCase
include OpenSSL::TestEOF
include OpenSSL::SSLPair
@@ -530,6 +562,12 @@ class OpenSSL::TestEOF2LowlevelSocket < OpenSSL::TestCase
include OpenSSL::TestEOF2M
end
+class OpenSSL::TestEOF2IOish < OpenSSL::TestCase
+ include OpenSSL::TestEOF
+ include OpenSSL::SSLPairIOish
+ include OpenSSL::TestEOF2M
+end
+
class OpenSSL::TestPair < OpenSSL::TestCase
include OpenSSL::SSLPair
include OpenSSL::TestPairM
@@ -540,4 +578,9 @@ class OpenSSL::TestPairLowlevelSocket < OpenSSL::TestCase
include OpenSSL::TestPairM
end
+class OpenSSL::TestPairIOish < OpenSSL::TestCase
+ include OpenSSL::SSLPairIOish
+ include OpenSSL::TestPairM
+end
+
end
diff --git a/test/openssl/test_ssl.rb b/test/openssl/test_ssl.rb
index 4642063f4..a5e0c0443 100644
--- a/test/openssl/test_ssl.rb
+++ b/test/openssl/test_ssl.rb
@@ -4,17 +4,6 @@
if defined?(OpenSSL::SSL)
class OpenSSL::TestSSL < OpenSSL::SSLTestCase
- def test_bad_socket
- bad_socket = Struct.new(:sync).new
- assert_raise TypeError do
- socket = OpenSSL::SSL::SSLSocket.new bad_socket
- # if the socket is not a T_FILE, `connect` will segv because it tries
- # to get the underlying file descriptor but the API it calls assumes
- # the object type is T_FILE
- socket.connect
- end
- end
-
def test_ctx_setup
ctx = OpenSSL::SSL::SSLContext.new
assert_equal true, ctx.setup
@@ -170,6 +159,155 @@ def test_socket_close_write
end
end
+ def test_synthetic_io_sanity_check
+ obj = Object.new
+ assert_raise_with_message(TypeError, /read_nonblock/) { OpenSSL::SSL::SSLSocket.new(obj) }
+
+ obj = Object.new
+ obj.define_singleton_method(:read_nonblock) { |*args, **kwargs| }
+ obj.define_singleton_method(:write_nonblock) { |*args, **kwargs| }
+ assert_nothing_raised { OpenSSL::SSL::SSLSocket.new(obj) }
+ end
+
+ def test_synthetic_io
+ start_server do |port|
+ tcp = TCPSocket.new("127.0.0.1", port)
+ obj = Object.new
+ obj.define_singleton_method(:read_nonblock) { |maxlen, exception:|
+ tcp.read_nonblock(maxlen, exception: exception) }
+ obj.define_singleton_method(:write_nonblock) { |str, exception:|
+ tcp.write_nonblock(str, exception: exception) }
+ obj.define_singleton_method(:wait_readable) { tcp.wait_readable }
+ obj.define_singleton_method(:wait_writable) { tcp.wait_writable }
+ obj.define_singleton_method(:flush) { tcp.flush }
+ obj.define_singleton_method(:closed?) { tcp.closed? }
+
+ ssl = OpenSSL::SSL::SSLSocket.new(obj)
+ assert_same obj, ssl.to_io
+
+ ssl.connect
+ ssl.puts "abc"; assert_equal "abc\n", ssl.gets
+ ensure
+ ssl&.close
+ tcp&.close
+ end
+ end
+
+ def test_synthetic_io_write_nonblock_exception
+ start_server(ignore_listener_error: true) do |port|
+ tcp = TCPSocket.new("127.0.0.1", port)
+ obj = Object.new
+ [:read_nonblock, :wait_readable, :wait_writable, :flush, :closed?].each do |name|
+ obj.define_singleton_method(name) { |*args, **kwargs|
+ tcp.__send__(name, *args, **kwargs) }
+ end
+
+ # SSLSocket#connect calls write_nonblock at least twice: to write
+ # ClientHello and Finished. Let's raise an exception in the 2nd call.
+ called = 0
+ obj.define_singleton_method(:write_nonblock) { |*args, **kwargs|
+ raise "foo" if (called += 1) == 2
+ tcp.write_nonblock(*args, **kwargs)
+ }
+
+ ssl = OpenSSL::SSL::SSLSocket.new(obj)
+ assert_raise_with_message(RuntimeError, "foo") { ssl.connect }
+ ensure
+ ssl&.close
+ tcp&.close
+ end
+ end
+
+ def test_synthetic_io_errors_in_servername_cb
+ assert_separately(["-ropenssl"], <<~"end;")
+ begin
+ #sock1, sock2 = socketpair
+ sock1, sock2 = if defined? UNIXSocket
+ UNIXSocket.pair
+ else
+ Socket.pair(Socket::AF_INET, Socket::SOCK_STREAM, 0)
+ end
+
+ t = Thread.new {
+ s1 = OpenSSL::SSL::SSLSocket.new(sock1)
+ s1.hostname = "localhost"
+ assert_raise_with_message(OpenSSL::SSL::SSLError, /unrecognized.name/i) {
+ s1.connect
+ }
+ }
+
+ ctx2 = OpenSSL::SSL::SSLContext.new
+ ctx2.servername_cb = lambda { |args| raise RuntimeError, "exc in servername_cb" }
+ obj = Object.new
+ obj.define_singleton_method(:method_missing) { |name, *args, **kwargs| sock2.__send__(name, *args, **kwargs) }
+ obj.define_singleton_method(:respond_to_missing?) { |name, *args, **kwargs| sock2.respond_to?(name, *args, **kwargs) }
+ obj.define_singleton_method(:write_nonblock) { |*args, **kwargs|
+ begin
+ raise "exc in write_nonblock"
+ rescue
+ p $!
+ end
+ p $!
+ sock2.write_nonblock(*args, **kwargs)
+ }
+ s2 = OpenSSL::SSL::SSLSocket.new(obj, ctx2)
+ assert_raise_with_message(RuntimeError, "exc in servername_cb") { s2.accept }
+ assert t.join
+ ensure
+ sock1.close
+ sock2.close
+ end
+ end;
+ end
+
+ def test_synthetic_io_errors_in_callback_and_socket
+ assert_separately(["-ropenssl"], <<~"end;", ignore_stderr: true)
+ begin
+ #sock1, sock2 = socketpair
+ sock1, sock2 = if defined? UNIXSocket
+ UNIXSocket.pair
+ else
+ Socket.pair(Socket::AF_INET, Socket::SOCK_STREAM, 0)
+ end
+
+ t = Thread.new {
+ s1 = OpenSSL::SSL::SSLSocket.new(sock1)
+ s1.hostname = "localhost"
+ begin
+ s1.connect
+ rescue
+ end
+ }
+
+ called = []
+ ctx2 = OpenSSL::SSL::SSLContext.new
+ ctx2.servername_cb = lambda { |args|
+ called << :servername_cb
+ raise "servername_cb"
+ }
+ obj = Object.new
+ obj.define_singleton_method(:method_missing) { |name, *args, **kwargs| sock2.__send__(name, *args, **kwargs) }
+ obj.define_singleton_method(:respond_to_missing?) { |name, *args, **kwargs| sock2.respond_to?(name, *args, **kwargs) }
+ obj.define_singleton_method(:write_nonblock) { |*args, **kwargs|
+ called << :write_nonblock
+ throw :throw_from, :write_nonblock
+ }
+ s2 = OpenSSL::SSL::SSLSocket.new(obj, ctx2)
+
+ ret = assert_warning(/exception ignored/) {
+ catch(:throw_from) { s2.accept }
+ }
+ assert_equal(:write_nonblock, ret)
+ assert_equal([:servername_cb, :write_nonblock], called)
+ sock2.close
+ assert t.join
+ ensure
+ sock1.close
+ sock2.close
+ end
+ end;
+ end
+
def test_add_certificate
ctx_proc = -> ctx {
# Unset values set by start_server
@@ -1062,36 +1200,44 @@ def test_tlsext_hostname
end
end
- def test_servername_cb_raises_an_exception_on_unknown_objects
- hostname = 'example.org'
-
- ctx2 = OpenSSL::SSL::SSLContext.new
- ctx2.cert = @svr_cert
- ctx2.key = @svr_key
- ctx2.servername_cb = lambda { |args| Object.new }
-
+ def test_servername_cb_exception
sock1, sock2 = socketpair
+ t = Thread.new {
+ s1 = OpenSSL::SSL::SSLSocket.new(sock1)
+ s1.hostname = "localhost"
+ assert_raise_with_message(OpenSSL::SSL::SSLError, /unrecognized.name/i) {
+ s1.connect
+ }
+ }
+
+ ctx2 = OpenSSL::SSL::SSLContext.new
+ ctx2.servername_cb = lambda { |args| raise RuntimeError, "foo" }
s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2)
+ assert_raise_with_message(RuntimeError, "foo") { s2.accept }
+ assert t.join
+ ensure
+ sock1.close
+ sock2.close
+ end
- ctx1 = OpenSSL::SSL::SSLContext.new
+ def test_servername_cb_raises_an_exception_on_unknown_objects
+ sock1, sock2 = socketpair
- s1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1)
- s1.hostname = hostname
t = Thread.new {
- assert_raise(OpenSSL::SSL::SSLError) do
- s1.connect
- end
+ s1 = OpenSSL::SSL::SSLSocket.new(sock1)
+ s1.hostname = "localhost"
+ assert_raise(OpenSSL::SSL::SSLError) { s1.connect }
}
- assert_raise(ArgumentError) do
- s2.accept
- end
-
+ ctx2 = OpenSSL::SSL::SSLContext.new
+ ctx2.servername_cb = lambda { |args| Object.new }
+ s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2)
+ assert_raise(ArgumentError) { s2.accept }
assert t.join
ensure
- sock1.close if sock1
- sock2.close if sock2
+ sock1.close
+ sock2.close
end
def test_accept_errors_include_peeraddr
@@ -1555,7 +1701,12 @@ def test_options_disable_versions
# Client only supports TLS 1.3
ctx2 = OpenSSL::SSL::SSLContext.new
ctx2.min_version = ctx2.max_version = OpenSSL::SSL::TLS1_3_VERSION
- assert_nothing_raised { server_connect(port, ctx2) { } }
+ assert_nothing_raised {
+ server_connect(port, ctx2) { |ssl|
+ # Ensure SSL_accept() finishes successfully
+ ssl.puts("abc"); ssl.gets
+ }
+ }
}
# Server only supports TLS 1.2