Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,82 +1,84 @@
import XCTest
import Testing
import NIOCore
@testable import PostgresNIO

class AuthenticationStateMachineTests: XCTestCase {
func testAuthenticatePlaintext() {
@Suite struct AuthenticationStateMachineTests {

@Test func testAuthenticatePlaintext() {
let authContext = AuthContext(username: "test", password: "abc123", database: "test")

var state = ConnectionStateMachine(requireBackendKeyData: true)
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
XCTAssertEqual(state.authenticationMessageReceived(.plaintext), .sendPasswordMessage(.cleartext, authContext))
XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait)
#expect(state.connected(tls: .disable) == .provideAuthenticationContext)

#expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext))
#expect(state.authenticationMessageReceived(.plaintext) == .sendPasswordMessage(.cleartext, authContext))
#expect(state.authenticationMessageReceived(.ok) == .wait)
}

func testAuthenticateMD5() {
@Test func testAuthenticateMD5() {
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
var state = ConnectionStateMachine(requireBackendKeyData: true)
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
#expect(state.connected(tls: .disable) == .provideAuthenticationContext)
let salt: UInt32 = 0x00_01_02_03

XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext))
XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait)
#expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext))
#expect(state.authenticationMessageReceived(.md5(salt: salt)) == .sendPasswordMessage(.md5(salt: salt), authContext))
#expect(state.authenticationMessageReceived(.ok) == .wait)
}

func testAuthenticateMD5WithoutPassword() {
@Test func testAuthenticateMD5WithoutPassword() {
let authContext = AuthContext(username: "test", password: nil, database: "test")
var state = ConnectionStateMachine(requireBackendKeyData: true)
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
#expect(state.connected(tls: .disable) == .provideAuthenticationContext)
let salt: UInt32 = 0x00_01_02_03

XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)),
#expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext))
#expect(state.authenticationMessageReceived(.md5(salt: salt)) ==
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .authMechanismRequiresPassword, closePromise: nil)))
}

func testAuthenticateOkAfterStartUpWithoutAuthChallenge() {
@Test func testAuthenticateOkAfterStartUpWithoutAuthChallenge() {
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
var state = ConnectionStateMachine(requireBackendKeyData: true)
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
XCTAssertEqual(state.authenticationMessageReceived(.ok), .wait)
#expect(state.connected(tls: .disable) == .provideAuthenticationContext)
#expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext))
#expect(state.authenticationMessageReceived(.ok) == .wait)
}

func testAuthenticateSCRAMSHA256WithAtypicalEncoding() {
@Test func testAuthenticateSCRAMSHA256WithAtypicalEncoding() {
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
var state = ConnectionStateMachine(requireBackendKeyData: true)
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
#expect(state.connected(tls: .disable) == .provideAuthenticationContext)
#expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext))

let saslResponse = state.authenticationMessageReceived(.sasl(names: ["SCRAM-SHA-256"]))
guard case .sendSaslInitialResponse(name: let name, initialResponse: let responseData) = saslResponse else {
return XCTFail("\(saslResponse) is not .sendSaslInitialResponse")
Issue.record("\(saslResponse) is not .sendSaslInitialResponse")
return
}
let responseString = String(decoding: responseData, as: UTF8.self)
XCTAssertEqual(name, "SCRAM-SHA-256")
XCTAssert(responseString.starts(with: "n,,n=test,r="))
#expect(name == "SCRAM-SHA-256")
#expect(responseString.starts(with: "n,,n=test,r="))

let saslContinueResponse = state.authenticationMessageReceived(.saslContinue(data: .init(bytes:
"r=\(responseString.dropFirst(12))RUJSZHhkeUVFNzRLNERKMkxmU05ITU1NZWcxaQ==,s=ijgUVaWgCDLRJyF963BKNA==,i=4096".utf8
)))
guard case .sendSaslResponse(let responseData2) = saslContinueResponse else {
return XCTFail("\(saslContinueResponse) is not .sendSaslResponse")
Issue.record("\(saslContinueResponse) is not .sendSaslResponse")
return
}
let response2String = String(decoding: responseData2, as: UTF8.self)
XCTAssertEqual(response2String.prefix(76), "c=biws,r=\(responseString.dropFirst(12))RUJSZHhkeUVFNzRLNERKMkxmU05ITU1NZWcxaQ==,p=")
#expect(response2String.prefix(76) == "c=biws,r=\(responseString.dropFirst(12))RUJSZHhkeUVFNzRLNERKMkxmU05ITU1NZWcxaQ==,p=")
}

func testAuthenticationFailure() {
@Test func testAuthenticationFailure() {
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
var state = ConnectionStateMachine(requireBackendKeyData: true)
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
#expect(state.connected(tls: .disable) == .provideAuthenticationContext)
let salt: UInt32 = 0x00_01_02_03

XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext))
#expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext))
#expect(state.authenticationMessageReceived(.md5(salt: salt)) == .sendPasswordMessage(.md5(salt: salt), authContext))
let fields: [PostgresBackendMessage.Field: String] = [
.message: "password authentication failed for user \"postgres\"",
.severity: "FATAL",
Expand All @@ -86,13 +88,13 @@ class AuthenticationStateMachineTests: XCTestCase {
.line: "334",
.file: "auth.c"
]
XCTAssertEqual(state.errorReceived(.init(fields: fields)),
#expect(state.errorReceived(.init(fields: fields)) ==
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .server(.init(fields: fields)), closePromise: nil)))
}

// MARK: Test unsupported messages

func testUnsupportedAuthMechanism() {
@Test func testUnsupportedAuthMechanism() {
let unsupported: [(PostgresBackendMessage.Authentication, PSQLError.UnsupportedAuthScheme)] = [
(.kerberosV5, .kerberosV5),
(.scmCredential, .scmCredential),
Expand All @@ -104,14 +106,14 @@ class AuthenticationStateMachineTests: XCTestCase {
for (message, mechanism) in unsupported {
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
var state = ConnectionStateMachine(requireBackendKeyData: true)
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
XCTAssertEqual(state.authenticationMessageReceived(message),
#expect(state.connected(tls: .disable) == .provideAuthenticationContext)
#expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext))
#expect(state.authenticationMessageReceived(message) ==
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unsupportedAuthMechanism(mechanism), closePromise: nil)))
}
}

func testUnexpectedMessagesAfterStartUp() {
@Test func testUnexpectedMessagesAfterStartUp() {
var buffer = ByteBuffer()
buffer.writeBytes([0, 1, 2, 3, 4, 5, 6, 7, 8])
let unexpected: [PostgresBackendMessage.Authentication] = [
Expand All @@ -123,14 +125,14 @@ class AuthenticationStateMachineTests: XCTestCase {
for message in unexpected {
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
var state = ConnectionStateMachine(requireBackendKeyData: true)
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
XCTAssertEqual(state.authenticationMessageReceived(message),
#expect(state.connected(tls: .disable) == .provideAuthenticationContext)
#expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext))
#expect(state.authenticationMessageReceived(message) ==
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.authentication(message)), closePromise: nil)))
}
}

func testUnexpectedMessagesAfterPasswordSent() {
@Test func testUnexpectedMessagesAfterPasswordSent() {
let salt: UInt32 = 0x00_01_02_03
var buffer = ByteBuffer()
buffer.writeBytes([0, 1, 2, 3, 4, 5, 6, 7, 8])
Expand All @@ -150,10 +152,10 @@ class AuthenticationStateMachineTests: XCTestCase {
for message in unexpected {
let authContext = AuthContext(username: "test", password: "abc123", database: "test")
var state = ConnectionStateMachine(requireBackendKeyData: true)
XCTAssertEqual(state.connected(tls: .disable), .provideAuthenticationContext)
XCTAssertEqual(state.provideAuthenticationContext(authContext), .sendStartupMessage(authContext))
XCTAssertEqual(state.authenticationMessageReceived(.md5(salt: salt)), .sendPasswordMessage(.md5(salt: salt), authContext))
XCTAssertEqual(state.authenticationMessageReceived(message),
#expect(state.connected(tls: .disable) == .provideAuthenticationContext)
#expect(state.provideAuthenticationContext(authContext) == .sendStartupMessage(authContext))
#expect(state.authenticationMessageReceived(.md5(salt: salt)) == .sendPasswordMessage(.md5(salt: salt), authContext))
#expect(state.authenticationMessageReceived(message) ==
.closeConnectionAndCleanup(.init(action: .close, tasks: [], error: .unexpectedBackendMessage(.authentication(message)), closePromise: nil)))
}
}
Expand Down
Loading