Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: 修复 pg ssl key 无法读取的问题 #39

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
import java.io.FileWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.security.KeyFactory;
import java.security.PrivateKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;

public class SSLCertManager {

@Setter
private String caCert; // CA 证书
@Setter
private String clientCertKey; // 客户端私钥
private String clientCertKey; // 客户端私钥 (PEM 格式)
@Setter
private String clientCert; // 客户端证书

Expand All @@ -22,23 +26,24 @@ public class SSLCertManager {
private File clientCertFile;

// 获取 CA 证书的路径
private String getCaCertPath() throws IOException {
public String getCaCertPath() throws IOException {
if (caCertFile == null) {
caCertFile = createTempFile("ca-cert", caCert);
}
return caCertFile.getAbsolutePath();
}

// 获取客户端私钥的路径
private String getClientCertKeyPath() throws IOException {
// 获取客户端私钥的路径,并将 PEM 格式的私钥转换为 DER 格式
public String getClientCertKeyPath() throws Exception {
if (clientCertKeyFile == null) {
clientCertKeyFile = createTempFile("client-cert-key", clientCertKey);
// 检查 clientCertKey 是否是 PEM 格式并转换为 DER
clientCertKeyFile = createTempFile("client-cert-key", convertPEMToDER(clientCertKey));
}
return clientCertKeyFile.getAbsolutePath();
}

// 获取客户端证书的路径
private String getClientCertPath() throws IOException {
public String getClientCertPath() throws IOException {
if (clientCertFile == null) {
clientCertFile = createTempFile("client-cert", clientCert);
}
Expand All @@ -53,6 +58,14 @@ public void Destroy() {
}

// 辅助方法:创建临时文件并写入内容
private File createTempFile(String prefix, byte[] content) throws IOException {
File tempFile = File.createTempFile(prefix, ".der");
Files.write(tempFile.toPath(), content); // 直接写入二进制数据
tempFile.deleteOnExit(); // JVM 退出时自动删除
return tempFile;
}

// 辅助方法:创建临时文件并写入内容(用于普通字符串内容)
private File createTempFile(String prefix, String content) throws IOException {
File tempFile = File.createTempFile(prefix, ".pem");
try (FileWriter writer = new FileWriter(tempFile)) {
Expand All @@ -73,4 +86,23 @@ private void deleteTempFile(File file) {
}
}
}

// 将 PEM 格式的私钥转换为 DER 格式
private byte[] convertPEMToDER(String pemContent) throws Exception {
// 去掉 PEM 格式的头尾标记,获取 Base64 编码内容
pemContent = pemContent.replace("-----BEGIN PRIVATE KEY-----", "")
.replace("-----END PRIVATE KEY-----", "")
.replaceAll("\\s+", ""); // 去掉空格和换行符

// Base64 解码
byte[] keyBytes = Base64.getDecoder().decode(pemContent);

// 使用 PKCS8EncodedKeySpec 来生成 PrivateKey 对象
PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(keyBytes);
KeyFactory keyFactory = KeyFactory.getInstance("RSA"); // 假设是 RSA 私钥
PrivateKey privateKey = keyFactory.generatePrivate(keySpec);

// 返回 DER 格式的私钥字节数组
return privateKey.getEncoded();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import org.jumpserver.chen.framework.datasource.base.BaseConnectionManager;
import org.jumpserver.chen.framework.datasource.entity.DBConnectInfo;
import org.jumpserver.chen.framework.datasource.sql.SQL;
import org.jumpserver.chen.modules.base.ssl.SSLCertManager;

import java.sql.SQLException;
import java.util.Properties;

public class PostgresqlConnectionManager extends BaseConnectionManager {

private static final String jdbcUrlTemplate = "jdbc:postgresql://${host}:${port}/${db}?useUnicode=true&characterEncoding=UTF-8";
Expand Down Expand Up @@ -35,15 +35,25 @@ protected void setSSLProps(Properties props) {
if (this.getConnectInfo().getOptions().get("useSSL") != null
&& (boolean) this.getConnectInfo().getOptions().get("useSSL")) {

var caCertPath = (String) this.getConnectInfo().getOptions().get("caCert");
var clientCertPath = (String) this.getConnectInfo().getOptions().get("clientCert");
var clientKeyPath = (String) this.getConnectInfo().getOptions().get("clientKey");

props.setProperty("ssl", "true");
props.setProperty("sslmode", "verify-full");
props.setProperty("sslrootcert", caCertPath);
props.setProperty("sslcert", clientCertPath);
props.setProperty("sslkey", clientKeyPath);
var caCert = (String) this.getConnectInfo().getOptions().get("caCert");
var clientCert = (String) this.getConnectInfo().getOptions().get("clientCert");
var clientKey = (String) this.getConnectInfo().getOptions().get("clientKey");

var sslManager = new SSLCertManager();
sslManager.setCaCert(caCert);
sslManager.setClientCert(clientCert);
sslManager.setClientCertKey(clientKey);


try {
props.setProperty("ssl", "true");
props.setProperty("sslmode", "verify-full");
props.setProperty("sslrootcert", sslManager.getCaCertPath());
props.setProperty("sslcert", sslManager.getClientCertPath());
props.setProperty("sslkey", sslManager.getClientCertKeyPath());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}

Expand Down