|  | 
|  | 1 | +#include <thread> | 
| 1 | 2 | #include <regex> | 
| 2 | 3 | 
 | 
| 3 | 4 | #include <nlohmann/json.hpp> | 
|  | 
| 6 | 7 | #include "serve-protocol.hh" | 
| 7 | 8 | #include "serve-protocol-impl.hh" | 
| 8 | 9 | #include "build-result.hh" | 
|  | 10 | +#include "file-descriptor.hh" | 
| 9 | 11 | #include "tests/protocol.hh" | 
| 10 | 12 | #include "tests/characterization.hh" | 
| 11 | 13 | 
 | 
| @@ -401,4 +403,112 @@ VERSIONED_CHARACTERIZATION_TEST( | 
| 401 | 403 |         }, | 
| 402 | 404 |     })) | 
| 403 | 405 | 
 | 
|  | 406 | +TEST_F(ServeProtoTest, handshake_log) | 
|  | 407 | +{ | 
|  | 408 | +    CharacterizationTest::writeTest("handshake-to-client", [&]() -> std::string { | 
|  | 409 | +        StringSink toClientLog; | 
|  | 410 | + | 
|  | 411 | +        Pipe toClient, toServer; | 
|  | 412 | +        toClient.create(); | 
|  | 413 | +        toServer.create(); | 
|  | 414 | + | 
|  | 415 | +        ServeProto::Version clientResult, serverResult; | 
|  | 416 | + | 
|  | 417 | +        auto thread = std::thread([&]() { | 
|  | 418 | +            FdSink out { toServer.writeSide.get() }; | 
|  | 419 | +            FdSource in0 { toClient.readSide.get() }; | 
|  | 420 | +            TeeSource in { in0, toClientLog }; | 
|  | 421 | +            clientResult = ServeProto::BasicClientConnection::handshake( | 
|  | 422 | +                out, in, defaultVersion, "blah"); | 
|  | 423 | +        }); | 
|  | 424 | + | 
|  | 425 | +        { | 
|  | 426 | +            FdSink out { toClient.writeSide.get() }; | 
|  | 427 | +            FdSource in { toServer.readSide.get() }; | 
|  | 428 | +            serverResult = ServeProto::BasicServerConnection::handshake( | 
|  | 429 | +                out, in, defaultVersion); | 
|  | 430 | +        }; | 
|  | 431 | + | 
|  | 432 | +        thread.join(); | 
|  | 433 | + | 
|  | 434 | +        return std::move(toClientLog.s); | 
|  | 435 | +    }); | 
|  | 436 | +} | 
|  | 437 | + | 
|  | 438 | +/// Has to be a `BufferedSink` for handshake. | 
|  | 439 | +struct NullBufferedSink : BufferedSink { | 
|  | 440 | +    void writeUnbuffered(std::string_view data) override { } | 
|  | 441 | +}; | 
|  | 442 | + | 
|  | 443 | +TEST_F(ServeProtoTest, handshake_client_replay) | 
|  | 444 | +{ | 
|  | 445 | +    CharacterizationTest::readTest("handshake-to-client", [&](std::string toClientLog) { | 
|  | 446 | +        NullBufferedSink nullSink; | 
|  | 447 | + | 
|  | 448 | +        StringSource in { toClientLog }; | 
|  | 449 | +        auto clientResult = ServeProto::BasicClientConnection::handshake( | 
|  | 450 | +            nullSink, in, defaultVersion, "blah"); | 
|  | 451 | + | 
|  | 452 | +        EXPECT_EQ(clientResult, defaultVersion); | 
|  | 453 | +    }); | 
|  | 454 | +} | 
|  | 455 | + | 
|  | 456 | +TEST_F(ServeProtoTest, handshake_client_trunated_replay_throws) | 
|  | 457 | +{ | 
|  | 458 | +    CharacterizationTest::readTest("handshake-to-client", [&](std::string toClientLog) { | 
|  | 459 | +        for (size_t len = 0; len < toClientLog.size(); ++len) { | 
|  | 460 | +            NullBufferedSink nullSink; | 
|  | 461 | +            StringSource in { | 
|  | 462 | +                // truncate | 
|  | 463 | +                toClientLog.substr(0, len) | 
|  | 464 | +            }; | 
|  | 465 | +            if (len < 8) { | 
|  | 466 | +                EXPECT_THROW( | 
|  | 467 | +                    ServeProto::BasicClientConnection::handshake( | 
|  | 468 | +                        nullSink, in, defaultVersion, "blah"), | 
|  | 469 | +                    EndOfFile); | 
|  | 470 | +            } else { | 
|  | 471 | +                // Not sure why cannot keep on checking for `EndOfFile`. | 
|  | 472 | +                EXPECT_THROW( | 
|  | 473 | +                    ServeProto::BasicClientConnection::handshake( | 
|  | 474 | +                        nullSink, in, defaultVersion, "blah"), | 
|  | 475 | +                    Error); | 
|  | 476 | +            } | 
|  | 477 | +        } | 
|  | 478 | +    }); | 
|  | 479 | +} | 
|  | 480 | + | 
|  | 481 | +TEST_F(ServeProtoTest, handshake_client_corrupted_throws) | 
|  | 482 | +{ | 
|  | 483 | +    CharacterizationTest::readTest("handshake-to-client", [&](const std::string toClientLog) { | 
|  | 484 | +        for (size_t idx = 0; idx < toClientLog.size(); ++idx) { | 
|  | 485 | +            // corrupt a copy | 
|  | 486 | +            std::string toClientLogCorrupt = toClientLog; | 
|  | 487 | +            toClientLogCorrupt[idx] *= 4; | 
|  | 488 | +            ++toClientLogCorrupt[idx]; | 
|  | 489 | + | 
|  | 490 | +            NullBufferedSink nullSink; | 
|  | 491 | +            StringSource in { toClientLogCorrupt }; | 
|  | 492 | + | 
|  | 493 | +            if (idx < 4 || idx == 9) { | 
|  | 494 | +                // magic bytes don't match | 
|  | 495 | +                EXPECT_THROW( | 
|  | 496 | +                    ServeProto::BasicClientConnection::handshake( | 
|  | 497 | +                        nullSink, in, defaultVersion, "blah"), | 
|  | 498 | +                    Error); | 
|  | 499 | +            } else if (idx < 8 || idx >= 12) { | 
|  | 500 | +                // Number out of bounds | 
|  | 501 | +                EXPECT_THROW( | 
|  | 502 | +                    ServeProto::BasicClientConnection::handshake( | 
|  | 503 | +                        nullSink, in, defaultVersion, "blah"), | 
|  | 504 | +                    SerialisationError); | 
|  | 505 | +            } else { | 
|  | 506 | +                auto ver = ServeProto::BasicClientConnection::handshake( | 
|  | 507 | +                    nullSink, in, defaultVersion, "blah"); | 
|  | 508 | +                EXPECT_NE(ver, defaultVersion); | 
|  | 509 | +            } | 
|  | 510 | +        } | 
|  | 511 | +    }); | 
|  | 512 | +} | 
|  | 513 | + | 
| 404 | 514 | } | 
0 commit comments