diff --git a/libraries/SE05X/src/WiFiSSLSE050Client.cpp b/libraries/SE05X/src/WiFiSSLSE050Client.cpp index 7a3b88555..b5c43852e 100644 --- a/libraries/SE05X/src/WiFiSSLSE050Client.cpp +++ b/libraries/SE05X/src/WiFiSSLSE050Client.cpp @@ -26,8 +26,8 @@ arduino::MbedSSLSE050Client::MbedSSLSE050Client() { void arduino::MbedSSLSE050Client::setEccSlot(int KeySlot, const byte cert[], int certLen) { _keySlot = KeySlot; - _client_cert_len = certLen; - _client_cert = cert; + _certLen = certLen; + _cert = cert; } void WiFiSSLSE050Client::setEccSlot(int KeySlot, const byte cert[], int certLen) { diff --git a/libraries/SE05X/src/WiFiSSLSE050Client.h b/libraries/SE05X/src/WiFiSSLSE050Client.h index 6d3409652..3aed4f4e3 100644 --- a/libraries/SE05X/src/WiFiSSLSE050Client.h +++ b/libraries/SE05X/src/WiFiSSLSE050Client.h @@ -37,37 +37,28 @@ class MbedSSLSE050Client : public arduino::MbedSSLClient { void setEccSlot(int KeySlot, const byte cert[], int certLen); private: - const byte* _client_cert; - const char* _ca_cert; - int _client_cert_len; + const byte* _cert; + int _certLen; int _keySlot; sss_object_t _keyObject; int setRootCAClientCertKey() { - if( NSAPI_ERROR_OK != ((TLSSocket*)sock)->set_root_ca_cert_path("/wlan/")) { - return 0; + int err = setRootCA(); + if (err != NSAPI_ERROR_OK) { + return err; } - if(_hostname && !_disableSNI) { - ((TLSSocket*)sock)->set_hostname(_hostname); + if(SE05X.getObjectHandle(_keySlot, &_keyObject) != NSAPI_ERROR_OK) { + return NSAPI_ERROR_DEVICE_ERROR; } - if( NSAPI_ERROR_OK != ((TLSSocket*)sock)->append_root_ca_cert(_ca_cert_custom)) { - return 0; + if(((TLSSocket*)sock)->set_client_cert_key((void*)_cert, + (size_t)_certLen, + &_keyObject, + SE05X.getDeviceCtx()) != NSAPI_ERROR_OK) { + return NSAPI_ERROR_DEVICE_ERROR; } - - if(!SE05X.getObjectHandle(_keySlot, &_keyObject)) { - return 0; - } - - if( NSAPI_ERROR_OK != ((TLSSocket*)sock)->set_client_cert_key((void*)_client_cert, - (size_t)_client_cert_len, - &_keyObject, - SE05X.getDeviceCtx())) { - return 0; - } - - return 1; + return NSAPI_ERROR_OK; } }; diff --git a/libraries/SocketWrapper/src/AClient.cpp b/libraries/SocketWrapper/src/AClient.cpp index 9ffa9137a..96a22585b 100644 --- a/libraries/SocketWrapper/src/AClient.cpp +++ b/libraries/SocketWrapper/src/AClient.cpp @@ -143,3 +143,24 @@ void arduino::ASslClient::appendCustomCACert(const char* ca_cert) { } static_cast(client.get())->appendCustomCACert(ca_cert); } + +void arduino::ASslClient::setCACert(const char* rootCA) { + if (!client) { + newMbedClient(); + } + static_cast(client.get())->setCACert(rootCA); +} + +void arduino::ASslClient::setCertificate(const char* clientCert) { + if (!client) { + newMbedClient(); + } + static_cast(client.get())->setCertificate(clientCert); +} + +void arduino::ASslClient::setPrivateKey(const char* privateKey) { + if (!client) { + newMbedClient(); + } + static_cast(client.get())->setPrivateKey(privateKey); +} diff --git a/libraries/SocketWrapper/src/AClient.h b/libraries/SocketWrapper/src/AClient.h index 4f72020ee..9671cc25b 100644 --- a/libraries/SocketWrapper/src/AClient.h +++ b/libraries/SocketWrapper/src/AClient.h @@ -74,6 +74,9 @@ class ASslClient : public AClient { void disableSNI(bool statusSNI); void appendCustomCACert(const char* ca_cert); + void setCACert(const char* rootCA); + void setCertificate(const char* clientCert); + void setPrivateKey(const char* privateKey); protected: virtual void newMbedClient(); diff --git a/libraries/SocketWrapper/src/MbedSSLClient.cpp b/libraries/SocketWrapper/src/MbedSSLClient.cpp index ce4cc9fca..0823bf781 100644 --- a/libraries/SocketWrapper/src/MbedSSLClient.cpp +++ b/libraries/SocketWrapper/src/MbedSSLClient.cpp @@ -1,9 +1,12 @@ #include "MbedSSLClient.h" arduino::MbedSSLClient::MbedSSLClient() - : _ca_cert_custom(nullptr), + : _rootCA(nullptr), _hostname(nullptr), - _disableSNI(false) { + _clientCert(nullptr), + _privateKey(nullptr), + _disableSNI(false), + _appendCA(true) { onBeforeConnect(mbed::callback(this, &MbedSSLClient::setRootCA)); }; diff --git a/libraries/SocketWrapper/src/MbedSSLClient.h b/libraries/SocketWrapper/src/MbedSSLClient.h index d4c48fc82..372dbf5bf 100644 --- a/libraries/SocketWrapper/src/MbedSSLClient.h +++ b/libraries/SocketWrapper/src/MbedSSLClient.h @@ -48,19 +48,48 @@ class MbedSSLClient : public arduino::MbedClient { _disableSNI = statusSNI; } - void appendCustomCACert(const char* ca_cert) { - _ca_cert_custom = ca_cert; + void appendCustomCACert(const char* rootCA) { + _rootCA = rootCA; + _appendCA = true; + } + void setCACert(const char* rootCA) { + _rootCA = rootCA; + _appendCA = false; + } + void setCertificate(const char* clientCert) { + _clientCert = clientCert; + } + void setPrivateKey(const char* privateKey) { + _privateKey = privateKey; } -protected: - const char* _ca_cert_custom; +private: + const char* _rootCA; const char* _hostname; + const char* _clientCert; + const char* _privateKey; bool _disableSNI; + bool _appendCA; -private: +protected: int setRootCA() { int err = 0; + if(_hostname && !_disableSNI) { + ((TLSSocket*)sock)->set_hostname(_hostname); + } + + if(_clientCert && _privateKey) { + err = ((TLSSocket*)sock)->set_client_cert_key(_clientCert, _privateKey); + if( err != NSAPI_ERROR_OK) { + return err; + } + } + + if(!_appendCA && _rootCA) { + return ((TLSSocket*)sock)->set_root_ca_cert(_rootCA); + } + #if defined(MBEDTLS_FS_IO) mbed::BlockDevice* root = mbed::BlockDevice::get_default_instance(); err = root->init(); @@ -82,12 +111,8 @@ class MbedSSLClient : public arduino::MbedClient { } #endif - if(_hostname && !_disableSNI) { - ((TLSSocket*)sock)->set_hostname(_hostname); - } - - if(_ca_cert_custom != NULL) { - err = ((TLSSocket*)sock)->append_root_ca_cert(_ca_cert_custom); + if(_rootCA != NULL) { + err = ((TLSSocket*)sock)->append_root_ca_cert(_rootCA); } return err; }