diff --git a/ext/openssl/ossl.c b/ext/openssl/ossl.c index 60780790b..d148447dd 100644 --- a/ext/openssl/ossl.c +++ b/ext/openssl/ossl.c @@ -1033,6 +1033,7 @@ Init_openssl(void) * Init components */ Init_ossl_asn1(); + Init_ossl_bio(); Init_ossl_bn(); Init_ossl_cipher(); Init_ossl_config(); diff --git a/ext/openssl/ossl_bio.c b/ext/openssl/ossl_bio.c index 2ef208050..aad7a4908 100644 --- a/ext/openssl/ossl_bio.c +++ b/ext/openssl/ossl_bio.c @@ -40,3 +40,324 @@ ossl_membio2str(BIO *bio) return ret; } + +static BIO_METHOD *ossl_bio_meth; +static VALUE nonblock_kwargs, sym_wait_readable, sym_wait_writable; + +struct ossl_bio_ctx { + VALUE io; + int state; + int eof; +}; + +static void +bio_free(void *ptr) +{ + BIO_free(ptr); +} + +static void +bio_mark(void *ptr) +{ + struct ossl_bio_ctx *ctx = BIO_get_data(ptr); + rb_gc_mark_movable(ctx->io); +} + +static void +bio_compact(void *ptr) +{ + struct ossl_bio_ctx *ctx = BIO_get_data(ptr); + ctx->io = rb_gc_location(ctx->io); +} + +static const rb_data_type_t ossl_bio_type = { + "OpenSSL/BIO", + { + .dmark = bio_mark, + .dfree = bio_free, + .dcompact = bio_compact, + }, + 0, 0, RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED, +}; + +VALUE +ossl_bio_new(VALUE io) +{ + VALUE obj = TypedData_Wrap_Struct(rb_cObject, &ossl_bio_type, NULL); + BIO *bio = BIO_new(ossl_bio_meth); + if (!bio) + ossl_raise(eOSSLError, "BIO_new"); + + struct ossl_bio_ctx *ctx = BIO_get_data(bio); + ctx->io = io; + BIO_set_init(bio, 1); + RTYPEDDATA_DATA(obj) = bio; + return obj; +} + +BIO * +ossl_bio_get(VALUE obj) +{ + BIO *bio; + TypedData_Get_Struct(obj, BIO, &ossl_bio_type, bio); + return bio; +} + +int +ossl_bio_state(VALUE obj) +{ + BIO *bio; + TypedData_Get_Struct(obj, BIO, &ossl_bio_type, bio); + + struct ossl_bio_ctx *ctx = BIO_get_data(bio); + int state = ctx->state; + ctx->state = 0; + return state; +} + +static int +bio_create(BIO *bio) +{ + struct ossl_bio_ctx *ctx = OPENSSL_malloc(sizeof(*ctx)); + if (!ctx) + return 0; + memset(ctx, 0, sizeof(*ctx)); + BIO_set_data(bio, ctx); + + return 1; +} + +static int +bio_destroy(BIO *bio) +{ + struct ossl_bio_ctx *ctx = BIO_get_data(bio); + if (ctx) { + OPENSSL_free(ctx); + BIO_set_data(bio, NULL); + } + + return 1; +} + +struct bwrite_args { + BIO *bio; + struct ossl_bio_ctx *ctx; + const char *data; + int dlen; + int written; +}; + +static VALUE +bio_bwrite0(VALUE args) +{ + struct bwrite_args *p = (void *)args; + BIO_clear_retry_flags(p->bio); + + VALUE fargs[] = { rb_str_new_static(p->data, p->dlen), nonblock_kwargs }; + VALUE ret = rb_funcallv_kw(p->ctx->io, rb_intern("write_nonblock"), + 2, fargs, RB_PASS_KEYWORDS); + + if (RB_INTEGER_TYPE_P(ret)) { + p->written = NUM2INT(ret); + return Qtrue; + } + else if (ret == sym_wait_readable) { + BIO_set_retry_read(p->bio); + return Qfalse; + } + else if (ret == sym_wait_writable) { + BIO_set_retry_write(p->bio); + return Qfalse; + } + else { + rb_raise(rb_eTypeError, "write_nonblock must return an Integer, " + ":wait_readable, or :wait_writable"); + } +} + +struct call0_args { + VALUE (*func)(VALUE); + VALUE args; + VALUE ret; +}; + +static VALUE +do_nothing(VALUE _) +{ + return Qnil; +} + +static VALUE +call_protect1(VALUE args_) +{ + struct call0_args *args = (void *)args_; + rb_set_errinfo(Qnil); + args->ret = args->func(args->args); + return Qnil; +} + +static VALUE +call_protect0(VALUE args_) +{ + /* + * At this point rb_errinfo() may be set by another callback called from + * the same OpenSSL function (e.g., SSL_accept()). + * + * Abusing rb_ensure() to temporarily save errinfo and restore it after + * the BIO callback successfully returns. + */ + rb_ensure(do_nothing, Qnil, call_protect1, args_); + return Qnil; +} + +static VALUE +call_protect(VALUE (*func)(VALUE), VALUE args, int *state) +{ + /* + * FIXME: should check !NIL_P(rb_ivar_get(ssl_obj, ID_callback_state)) + * instead to see if a tag jump is pending or not. + */ + int pending = !NIL_P(rb_errinfo()); + struct call0_args call0_args = { func, args, Qfalse }; + rb_protect(call_protect0, (VALUE)&call0_args, state); + if (pending && *state) + rb_warn("exception ignored in BIO callback: pending=%d", pending); + return call0_args.ret; +} + +static int +bio_bwrite(BIO *bio, const char *data, int dlen) +{ + struct ossl_bio_ctx *ctx = BIO_get_data(bio); + struct bwrite_args args = { bio, ctx, data, dlen, 0 }; + int state; + + if (ctx->state) + return -1; + + VALUE ok = call_protect(bio_bwrite0, (VALUE)&args, &state); + if (state) { + ctx->state = state; + return -1; + } + if (RTEST(ok)) + return args.written; + return -1; +} + +struct bread_args { + BIO *bio; + struct ossl_bio_ctx *ctx; + char *data; + int dlen; + int readbytes; +}; + +static VALUE +bio_bread0(VALUE args) +{ + struct bread_args *p = (void *)args; + BIO_clear_retry_flags(p->bio); + + VALUE fargs[] = { INT2NUM(p->dlen), nonblock_kwargs }; + VALUE ret = rb_funcallv_kw(p->ctx->io, rb_intern("read_nonblock"), + 2, fargs, RB_PASS_KEYWORDS); + + if (RB_TYPE_P(ret, T_STRING)) { + int len = RSTRING_LENINT(ret); + if (len > p->dlen) + rb_raise(rb_eTypeError, "read_nonblock returned too much data"); + memcpy(p->data, RSTRING_PTR(ret), len); + p->readbytes = len; + return Qtrue; + } + else if (NIL_P(ret)) { + // In OpenSSL 3.0 or later: BIO_set_flags(p->bio, BIO_FLAGS_IN_EOF); + p->ctx->eof = 1; + return Qtrue; + } + else if (ret == sym_wait_readable) { + BIO_set_retry_read(p->bio); + return Qfalse; + } + else if (ret == sym_wait_writable) { + BIO_set_retry_write(p->bio); + return Qfalse; + } + else { + rb_raise(rb_eTypeError, "write_nonblock must return an Integer, " + ":wait_readable, or :wait_writable"); + } +} + +static int +bio_bread(BIO *bio, char *data, int dlen) +{ + struct ossl_bio_ctx *ctx = BIO_get_data(bio); + struct bread_args args = { bio, ctx, data, dlen, 0 }; + int state; + + if (ctx->state) + return -1; + + VALUE ok = call_protect(bio_bread0, (VALUE)&args, &state); + if (state) { + ctx->state = state; + return -1; + } + if (RTEST(ok)) + return args.readbytes; + return -1; +} + +static VALUE +bio_flush0(VALUE vctx) +{ + struct ossl_bio_ctx *ctx = (void *)vctx; + return rb_funcallv(ctx->io, rb_intern("flush"), 0, NULL); +} + +static long +bio_ctrl(BIO *bio, int cmd, long larg, void *parg) +{ + struct ossl_bio_ctx *ctx = BIO_get_data(bio); + int state; + + if (ctx->state) + return 0; + + switch (cmd) { + case BIO_CTRL_EOF: + return ctx->eof; + case BIO_CTRL_FLUSH: + call_protect(bio_flush0, (VALUE)ctx, &state); + ctx->state = state; + return !state; + default: + return 0; + } +} + +void +Init_ossl_bio(void) +{ + ossl_bio_meth = BIO_meth_new(BIO_TYPE_SOURCE_SINK, "Ruby IO-like object"); + if (!ossl_bio_meth) + ossl_raise(eOSSLError, "BIO_meth_new"); + if (!BIO_meth_set_create(ossl_bio_meth, bio_create) || + !BIO_meth_set_destroy(ossl_bio_meth, bio_destroy) || + !BIO_meth_set_write(ossl_bio_meth, bio_bwrite) || + !BIO_meth_set_read(ossl_bio_meth, bio_bread) || + !BIO_meth_set_ctrl(ossl_bio_meth, bio_ctrl)) { + BIO_meth_free(ossl_bio_meth); + ossl_bio_meth = NULL; + ossl_raise(eOSSLError, "BIO_meth_set_*"); + } + + nonblock_kwargs = rb_hash_new(); + rb_hash_aset(nonblock_kwargs, ID2SYM(rb_intern_const("exception")), Qfalse); + rb_global_variable(&nonblock_kwargs); + + sym_wait_readable = ID2SYM(rb_intern_const("wait_readable")); + sym_wait_writable = ID2SYM(rb_intern_const("wait_writable")); +} diff --git a/ext/openssl/ossl_bio.h b/ext/openssl/ossl_bio.h index 1b871f1cd..634f99fae 100644 --- a/ext/openssl/ossl_bio.h +++ b/ext/openssl/ossl_bio.h @@ -13,4 +13,10 @@ BIO *ossl_obj2bio(volatile VALUE *); VALUE ossl_membio2str(BIO*); +VALUE ossl_bio_new(VALUE io); +BIO *ossl_bio_get(VALUE obj); +int ossl_bio_state(VALUE obj); + +void Init_ossl_bio(void); + #endif diff --git a/ext/openssl/ossl_ssl.c b/ext/openssl/ossl_ssl.c index a5b25e14d..3a3fceef1 100644 --- a/ext/openssl/ossl_ssl.c +++ b/ext/openssl/ossl_ssl.c @@ -37,7 +37,7 @@ static VALUE eSSLErrorWaitReadable; static VALUE eSSLErrorWaitWritable; static ID id_call, ID_callback_state, id_tmp_dh_callback, - id_npn_protocols_encoded, id_each; + id_npn_protocols_encoded, id_each, id_bio; static VALUE sym_exception, sym_wait_readable, sym_wait_writable; static ID id_i_cert_store, id_i_ca_file, id_i_ca_path, id_i_verify_mode, @@ -1536,7 +1536,10 @@ ossl_ssl_s_alloc(VALUE klass) static VALUE peer_ip_address(VALUE self) { - VALUE remote_address = rb_funcall(rb_attr_get(self, id_i_io), rb_intern("remote_address"), 0); + VALUE io = rb_attr_get(self, id_i_io); + VALUE remote_address = rb_check_funcall(io, rb_intern("remote_address"), 0, NULL); + if (remote_address == Qundef) + return rb_str_new_cstr("(unsupported)"); return rb_funcall(remote_address, rb_intern("inspect_sockaddr"), 0); } @@ -1550,10 +1553,16 @@ fallback_peer_ip_address(VALUE self, VALUE args) static VALUE peeraddr_ip_str(VALUE self) { - VALUE rb_mErrno = rb_const_get(rb_cObject, rb_intern("Errno")); - VALUE rb_eSystemCallError = rb_const_get(rb_mErrno, rb_intern("SystemCallError")); + return rb_rescue2(peer_ip_address, self, fallback_peer_ip_address, Qnil, + rb_eSystemCallError, (VALUE)0); +} - return rb_rescue2(peer_ip_address, self, fallback_peer_ip_address, (VALUE)0, rb_eSystemCallError, NULL); +static int +is_real_socket(VALUE io) +{ + // FIXME: DO NOT MERGE + return 0; + return RB_TYPE_P(io, T_FILE); } /* @@ -1561,8 +1570,10 @@ peeraddr_ip_str(VALUE self) * SSLSocket.new(io) => aSSLSocket * SSLSocket.new(io, ctx) => aSSLSocket * - * Creates a new SSL socket from _io_ which must be a real IO object (not an - * IO-like object that responds to read/write). + * Creates a new SSL socket from the underlying socket _io_ and _ctx_. + * + * _io_ must be an IO object, typically a TCPSocket or Socket from the socket + * library, or an IO-like object that supports the typical IO methods. * * If _ctx_ is provided the SSL Sockets initial params will be taken from * the context. @@ -1571,6 +1582,22 @@ peeraddr_ip_str(VALUE self) * * This method will freeze the SSLContext if one is provided; * however, session management is still allowed in the frozen SSLContext. + * + * == Support for IO-like objects + * + * Support for IO-like objects was added in version 3.3 and is experimental. + * + * As of version 3.3, SSLSocket uses the following methods: + * + * - write_nonblock with the exception: false option + * - read_nonblock with the exception: false option + * - wait_readable + * - wait_writable + * - flush + * - close + * - closed? + * + * Note that future versions may require additional methods to be implemented. */ static VALUE ossl_ssl_initialize(int argc, VALUE *argv, VALUE self) @@ -1590,9 +1617,18 @@ ossl_ssl_initialize(int argc, VALUE *argv, VALUE self) rb_ivar_set(self, id_i_context, v_ctx); ossl_sslctx_setup(v_ctx); - if (rb_respond_to(io, rb_intern("nonblock="))) - rb_funcall(io, rb_intern("nonblock="), 1, Qtrue); - Check_Type(io, T_FILE); + if (is_real_socket(io)) { + rb_io_t *fptr; + GetOpenFile(io, fptr); + rb_io_set_nonblock(fptr); + } + else { + // Not meant to be a comprehensive check + if (!rb_respond_to(io, rb_intern("read_nonblock")) || + !rb_respond_to(io, rb_intern("write_nonblock"))) + rb_raise(rb_eTypeError, "io must be a real IO object or an IO-like " + "object that responds to read_nonblock and write_nonblock"); + } rb_ivar_set(self, id_i_io, io); ssl = SSL_new(ctx); @@ -1624,27 +1660,57 @@ ossl_ssl_setup(VALUE self) { VALUE io; SSL *ssl; - rb_io_t *fptr; GetSSL(self, ssl); if (ssl_started(ssl)) return Qtrue; io = rb_attr_get(self, id_i_io); - GetOpenFile(io, fptr); - rb_io_check_readable(fptr); - rb_io_check_writable(fptr); - if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io)))) - ossl_raise(eSSLError, "SSL_set_fd"); + if (is_real_socket(io)) { + rb_io_t *fptr; + GetOpenFile(io, fptr); + rb_io_check_readable(fptr); + rb_io_check_writable(fptr); + if (!SSL_set_fd(ssl, TO_SOCKET(rb_io_descriptor(io)))) + ossl_raise(eSSLError, "SSL_set_fd"); + } + else { + VALUE bobj = ossl_bio_new(io); + rb_ivar_set(self, id_bio, bobj); + + BIO *bio = ossl_bio_get(bobj); + if (!BIO_up_ref(bio)) + ossl_raise(eSSLError, "BIO_up_ref"); + SSL_set_bio(ssl, bio, bio); + } return Qtrue; } +static void +check_bio_error(VALUE self, SSL *ssl, VALUE bobj, int ret) +{ + if (NIL_P(bobj)) { #ifdef _WIN32 -#define ssl_get_error(ssl, ret) (errno = rb_w32_map_errno(WSAGetLastError()), SSL_get_error((ssl), (ret))) -#else -#define ssl_get_error(ssl, ret) SSL_get_error((ssl), (ret)) + errno = rb_w32_map_errno(WSAGetLastError()); #endif + } + else { + int state = ossl_bio_state(bobj); + if (state) { + ossl_clear_error(); + rb_jump_tag(state); + } + errno = 0; + } + + VALUE cb_state = rb_attr_get(self, ID_callback_state); + if (!NIL_P(cb_state)) { + /* must cleanup OpenSSL error stack before re-raising */ + ossl_clear_error(); + rb_jump_tag(NUM2INT(cb_state)); + } +} static void write_would_block(int nonblock) @@ -1684,6 +1750,11 @@ no_exception_p(VALUE opts) static void io_wait_writable(VALUE io) { + if (!is_real_socket(io)) { + if (!RTEST(rb_funcallv(io, rb_intern("wait_writable"), 0, NULL))) + rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!"); + return; + } #ifdef HAVE_RB_IO_MAYBE_WAIT if (!rb_io_maybe_wait_writable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) { rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become writable!"); @@ -1698,6 +1769,11 @@ io_wait_writable(VALUE io) static void io_wait_readable(VALUE io) { + if (!is_real_socket(io)) { + if (!RTEST(rb_funcallv(io, rb_intern("wait_readable"), 0, NULL))) + rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!"); + return; + } #ifdef HAVE_RB_IO_MAYBE_WAIT if (!rb_io_maybe_wait_readable(errno, io, RUBY_IO_TIMEOUT_DEFAULT)) { rb_raise(IO_TIMEOUT_ERROR, "Timed out while waiting to become readable!"); @@ -1714,28 +1790,22 @@ ossl_start_ssl(VALUE self, int (*func)(SSL *), const char *funcname, VALUE opts) { SSL *ssl; int ret, ret2; - VALUE cb_state; int nonblock = opts != Qfalse; - rb_ivar_set(self, ID_callback_state, Qnil); - GetSSL(self, ssl); - VALUE io = rb_attr_get(self, id_i_io); + VALUE io = rb_attr_get(self, id_i_io), + bobj = rb_attr_get(self, id_bio); + + rb_ivar_set(self, ID_callback_state, Qnil); for (;;) { ret = func(ssl); - - cb_state = rb_attr_get(self, ID_callback_state); - if (!NIL_P(cb_state)) { - /* must cleanup OpenSSL error stack before re-raising */ - ossl_clear_error(); - rb_jump_tag(NUM2INT(cb_state)); - } + check_bio_error(self, ssl, bobj, ret); if (ret > 0) break; - switch ((ret2 = ssl_get_error(ssl, ret))) { + switch ((ret2 = SSL_get_error(ssl, ret))) { case SSL_ERROR_WANT_WRITE: if (no_exception_p(opts)) { return sym_wait_writable; } write_would_block(nonblock); @@ -1885,7 +1955,7 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) { SSL *ssl; int ilen; - VALUE len, str, cb_state; + VALUE len, str; VALUE opts = Qnil; if (nonblock) { @@ -1913,21 +1983,17 @@ ossl_ssl_read_internal(int argc, VALUE *argv, VALUE self, int nonblock) return str; } - VALUE io = rb_attr_get(self, id_i_io); + VALUE io = rb_attr_get(self, id_i_io), + bobj = rb_attr_get(self, id_bio); + rb_ivar_set(self, ID_callback_state, Qnil); for (;;) { rb_str_locktmp(str); int nread = SSL_read(ssl, RSTRING_PTR(str), ilen); rb_str_unlocktmp(str); + check_bio_error(self, ssl, bobj, nread); - cb_state = rb_attr_get(self, ID_callback_state); - if (!NIL_P(cb_state)) { - rb_ivar_set(self, ID_callback_state, Qnil); - ossl_clear_error(); - rb_jump_tag(NUM2INT(cb_state)); - } - - switch (ssl_get_error(ssl, nread)) { + switch (SSL_get_error(ssl, nread)) { case SSL_ERROR_NONE: rb_str_set_len(str, nread); return str; @@ -2018,33 +2084,26 @@ ossl_ssl_write_internal_safe(VALUE _args) VALUE opts = args[2]; SSL *ssl; - rb_io_t *fptr; int num, nonblock = opts != Qfalse; - VALUE cb_state; GetSSL(self, ssl); if (!ssl_started(ssl)) rb_raise(eSSLError, "SSL session is not started yet"); - VALUE io = rb_attr_get(self, id_i_io); - GetOpenFile(io, fptr); - /* SSL_write(3ssl) manpage states num == 0 is undefined */ num = RSTRING_LENINT(str); if (num == 0) return INT2FIX(0); + VALUE io = rb_attr_get(self, id_i_io), + bobj = rb_attr_get(self, id_bio); + + rb_ivar_set(self, ID_callback_state, Qnil); for (;;) { int nwritten = SSL_write(ssl, RSTRING_PTR(str), num); + check_bio_error(self, ssl, bobj, nwritten); - cb_state = rb_attr_get(self, ID_callback_state); - if (!NIL_P(cb_state)) { - rb_ivar_set(self, ID_callback_state, Qnil); - ossl_clear_error(); - rb_jump_tag(NUM2INT(cb_state)); - } - - switch (ssl_get_error(ssl, nwritten)) { + switch (SSL_get_error(ssl, nwritten)) { case SSL_ERROR_NONE: return INT2NUM(nwritten); case SSL_ERROR_WANT_WRITE: @@ -2143,7 +2202,14 @@ ossl_ssl_stop(VALUE self) GetSSL(self, ssl); if (!ssl_started(ssl)) return Qnil; + ret = SSL_shutdown(ssl); + + /* XXX: Suppressing errors from the underlying socket */ + VALUE bobj = rb_attr_get(self, id_bio); + if (!NIL_P(bobj) && ossl_bio_state(bobj)) + rb_set_errinfo(Qnil); + if (ret == 1) /* Have already received close_notify */ return Qnil; if (ret == 0) /* Sent close_notify, but we don't wait for reply */ @@ -3121,6 +3187,7 @@ Init_ossl_ssl(void) id_tmp_dh_callback = rb_intern_const("tmp_dh_callback"); id_npn_protocols_encoded = rb_intern_const("npn_protocols_encoded"); id_each = rb_intern_const("each"); + id_bio = rb_intern_const("bio"); #define DefIVarID(name) do \ id_i_##name = rb_intern_const("@"#name); while (0) diff --git a/lib/openssl/buffering.rb b/lib/openssl/buffering.rb index 1464a4292..aa39105f9 100644 --- a/lib/openssl/buffering.rb +++ b/lib/openssl/buffering.rb @@ -60,7 +60,7 @@ def initialize(*) super @eof = false @rbuffer = Buffer.new - @sync = @io.sync + @sync = @io.respond_to?(:sync) ? @io.sync : true end # diff --git a/test/openssl/test_pair.rb b/test/openssl/test_pair.rb index 10942191d..1664c00c8 100644 --- a/test/openssl/test_pair.rb +++ b/test/openssl/test_pair.rb @@ -67,6 +67,32 @@ def create_tcp_client(host, port) end end +module OpenSSL::SSLPairIOish + include OpenSSL::SSLPairM + + def create_tcp_server(host, port) + Addrinfo.tcp(host, port).listen + end + + class TCPSocketWrapper + def initialize(io) @io = io end + def read_nonblock(*args, **kwargs) @io.read_nonblock(*args, **kwargs) end + def write_nonblock(*args, **kwargs) @io.write_nonblock(*args, **kwargs) end + def wait_readable() @io.wait_readable end + def wait_writable() @io.wait_writable end + def flush() @io.flush end + def close() @io.close end + def closed?() @io.closed? end + + # Only used within test_pair.rb + def write(*args) @io.write(*args) end + end + + def create_tcp_client(host, port) + TCPSocketWrapper.new(Addrinfo.tcp(host, port).connect) + end +end + module OpenSSL::TestEOF1M def open_file(content) ssl_pair { |s1, s2| @@ -518,6 +544,12 @@ class OpenSSL::TestEOF1LowlevelSocket < OpenSSL::TestCase include OpenSSL::TestEOF1M end +class OpenSSL::TestEOF1IOish < OpenSSL::TestCase + include OpenSSL::TestEOF + include OpenSSL::SSLPairIOish + include OpenSSL::TestEOF1M +end + class OpenSSL::TestEOF2 < OpenSSL::TestCase include OpenSSL::TestEOF include OpenSSL::SSLPair @@ -530,6 +562,12 @@ class OpenSSL::TestEOF2LowlevelSocket < OpenSSL::TestCase include OpenSSL::TestEOF2M end +class OpenSSL::TestEOF2IOish < OpenSSL::TestCase + include OpenSSL::TestEOF + include OpenSSL::SSLPairIOish + include OpenSSL::TestEOF2M +end + class OpenSSL::TestPair < OpenSSL::TestCase include OpenSSL::SSLPair include OpenSSL::TestPairM @@ -540,4 +578,9 @@ class OpenSSL::TestPairLowlevelSocket < OpenSSL::TestCase include OpenSSL::TestPairM end +class OpenSSL::TestPairIOish < OpenSSL::TestCase + include OpenSSL::SSLPairIOish + include OpenSSL::TestPairM +end + end diff --git a/test/openssl/test_ssl.rb b/test/openssl/test_ssl.rb index 4642063f4..a5e0c0443 100644 --- a/test/openssl/test_ssl.rb +++ b/test/openssl/test_ssl.rb @@ -4,17 +4,6 @@ if defined?(OpenSSL::SSL) class OpenSSL::TestSSL < OpenSSL::SSLTestCase - def test_bad_socket - bad_socket = Struct.new(:sync).new - assert_raise TypeError do - socket = OpenSSL::SSL::SSLSocket.new bad_socket - # if the socket is not a T_FILE, `connect` will segv because it tries - # to get the underlying file descriptor but the API it calls assumes - # the object type is T_FILE - socket.connect - end - end - def test_ctx_setup ctx = OpenSSL::SSL::SSLContext.new assert_equal true, ctx.setup @@ -170,6 +159,155 @@ def test_socket_close_write end end + def test_synthetic_io_sanity_check + obj = Object.new + assert_raise_with_message(TypeError, /read_nonblock/) { OpenSSL::SSL::SSLSocket.new(obj) } + + obj = Object.new + obj.define_singleton_method(:read_nonblock) { |*args, **kwargs| } + obj.define_singleton_method(:write_nonblock) { |*args, **kwargs| } + assert_nothing_raised { OpenSSL::SSL::SSLSocket.new(obj) } + end + + def test_synthetic_io + start_server do |port| + tcp = TCPSocket.new("127.0.0.1", port) + obj = Object.new + obj.define_singleton_method(:read_nonblock) { |maxlen, exception:| + tcp.read_nonblock(maxlen, exception: exception) } + obj.define_singleton_method(:write_nonblock) { |str, exception:| + tcp.write_nonblock(str, exception: exception) } + obj.define_singleton_method(:wait_readable) { tcp.wait_readable } + obj.define_singleton_method(:wait_writable) { tcp.wait_writable } + obj.define_singleton_method(:flush) { tcp.flush } + obj.define_singleton_method(:closed?) { tcp.closed? } + + ssl = OpenSSL::SSL::SSLSocket.new(obj) + assert_same obj, ssl.to_io + + ssl.connect + ssl.puts "abc"; assert_equal "abc\n", ssl.gets + ensure + ssl&.close + tcp&.close + end + end + + def test_synthetic_io_write_nonblock_exception + start_server(ignore_listener_error: true) do |port| + tcp = TCPSocket.new("127.0.0.1", port) + obj = Object.new + [:read_nonblock, :wait_readable, :wait_writable, :flush, :closed?].each do |name| + obj.define_singleton_method(name) { |*args, **kwargs| + tcp.__send__(name, *args, **kwargs) } + end + + # SSLSocket#connect calls write_nonblock at least twice: to write + # ClientHello and Finished. Let's raise an exception in the 2nd call. + called = 0 + obj.define_singleton_method(:write_nonblock) { |*args, **kwargs| + raise "foo" if (called += 1) == 2 + tcp.write_nonblock(*args, **kwargs) + } + + ssl = OpenSSL::SSL::SSLSocket.new(obj) + assert_raise_with_message(RuntimeError, "foo") { ssl.connect } + ensure + ssl&.close + tcp&.close + end + end + + def test_synthetic_io_errors_in_servername_cb + assert_separately(["-ropenssl"], <<~"end;") + begin + #sock1, sock2 = socketpair + sock1, sock2 = if defined? UNIXSocket + UNIXSocket.pair + else + Socket.pair(Socket::AF_INET, Socket::SOCK_STREAM, 0) + end + + t = Thread.new { + s1 = OpenSSL::SSL::SSLSocket.new(sock1) + s1.hostname = "localhost" + assert_raise_with_message(OpenSSL::SSL::SSLError, /unrecognized.name/i) { + s1.connect + } + } + + ctx2 = OpenSSL::SSL::SSLContext.new + ctx2.servername_cb = lambda { |args| raise RuntimeError, "exc in servername_cb" } + obj = Object.new + obj.define_singleton_method(:method_missing) { |name, *args, **kwargs| sock2.__send__(name, *args, **kwargs) } + obj.define_singleton_method(:respond_to_missing?) { |name, *args, **kwargs| sock2.respond_to?(name, *args, **kwargs) } + obj.define_singleton_method(:write_nonblock) { |*args, **kwargs| + begin + raise "exc in write_nonblock" + rescue + p $! + end + p $! + sock2.write_nonblock(*args, **kwargs) + } + s2 = OpenSSL::SSL::SSLSocket.new(obj, ctx2) + assert_raise_with_message(RuntimeError, "exc in servername_cb") { s2.accept } + assert t.join + ensure + sock1.close + sock2.close + end + end; + end + + def test_synthetic_io_errors_in_callback_and_socket + assert_separately(["-ropenssl"], <<~"end;", ignore_stderr: true) + begin + #sock1, sock2 = socketpair + sock1, sock2 = if defined? UNIXSocket + UNIXSocket.pair + else + Socket.pair(Socket::AF_INET, Socket::SOCK_STREAM, 0) + end + + t = Thread.new { + s1 = OpenSSL::SSL::SSLSocket.new(sock1) + s1.hostname = "localhost" + begin + s1.connect + rescue + end + } + + called = [] + ctx2 = OpenSSL::SSL::SSLContext.new + ctx2.servername_cb = lambda { |args| + called << :servername_cb + raise "servername_cb" + } + obj = Object.new + obj.define_singleton_method(:method_missing) { |name, *args, **kwargs| sock2.__send__(name, *args, **kwargs) } + obj.define_singleton_method(:respond_to_missing?) { |name, *args, **kwargs| sock2.respond_to?(name, *args, **kwargs) } + obj.define_singleton_method(:write_nonblock) { |*args, **kwargs| + called << :write_nonblock + throw :throw_from, :write_nonblock + } + s2 = OpenSSL::SSL::SSLSocket.new(obj, ctx2) + + ret = assert_warning(/exception ignored/) { + catch(:throw_from) { s2.accept } + } + assert_equal(:write_nonblock, ret) + assert_equal([:servername_cb, :write_nonblock], called) + sock2.close + assert t.join + ensure + sock1.close + sock2.close + end + end; + end + def test_add_certificate ctx_proc = -> ctx { # Unset values set by start_server @@ -1062,36 +1200,44 @@ def test_tlsext_hostname end end - def test_servername_cb_raises_an_exception_on_unknown_objects - hostname = 'example.org' - - ctx2 = OpenSSL::SSL::SSLContext.new - ctx2.cert = @svr_cert - ctx2.key = @svr_key - ctx2.servername_cb = lambda { |args| Object.new } - + def test_servername_cb_exception sock1, sock2 = socketpair + t = Thread.new { + s1 = OpenSSL::SSL::SSLSocket.new(sock1) + s1.hostname = "localhost" + assert_raise_with_message(OpenSSL::SSL::SSLError, /unrecognized.name/i) { + s1.connect + } + } + + ctx2 = OpenSSL::SSL::SSLContext.new + ctx2.servername_cb = lambda { |args| raise RuntimeError, "foo" } s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2) + assert_raise_with_message(RuntimeError, "foo") { s2.accept } + assert t.join + ensure + sock1.close + sock2.close + end - ctx1 = OpenSSL::SSL::SSLContext.new + def test_servername_cb_raises_an_exception_on_unknown_objects + sock1, sock2 = socketpair - s1 = OpenSSL::SSL::SSLSocket.new(sock1, ctx1) - s1.hostname = hostname t = Thread.new { - assert_raise(OpenSSL::SSL::SSLError) do - s1.connect - end + s1 = OpenSSL::SSL::SSLSocket.new(sock1) + s1.hostname = "localhost" + assert_raise(OpenSSL::SSL::SSLError) { s1.connect } } - assert_raise(ArgumentError) do - s2.accept - end - + ctx2 = OpenSSL::SSL::SSLContext.new + ctx2.servername_cb = lambda { |args| Object.new } + s2 = OpenSSL::SSL::SSLSocket.new(sock2, ctx2) + assert_raise(ArgumentError) { s2.accept } assert t.join ensure - sock1.close if sock1 - sock2.close if sock2 + sock1.close + sock2.close end def test_accept_errors_include_peeraddr @@ -1555,7 +1701,12 @@ def test_options_disable_versions # Client only supports TLS 1.3 ctx2 = OpenSSL::SSL::SSLContext.new ctx2.min_version = ctx2.max_version = OpenSSL::SSL::TLS1_3_VERSION - assert_nothing_raised { server_connect(port, ctx2) { } } + assert_nothing_raised { + server_connect(port, ctx2) { |ssl| + # Ensure SSL_accept() finishes successfully + ssl.puts("abc"); ssl.gets + } + } } # Server only supports TLS 1.2