feat: SFTP command support for Jump Hosts and Proxy (#236)

This commit is contained in:
hstyi
2025-02-15 13:15:02 +08:00
committed by GitHub
parent 7f8573ec4c
commit 1e8c617a85
3 changed files with 156 additions and 88 deletions

View File

@@ -3,29 +3,144 @@ package app.termora
import app.termora.keymgr.KeyManager import app.termora.keymgr.KeyManager
import app.termora.keymgr.OhKeyPairKeyPairProvider import app.termora.keymgr.OhKeyPairKeyPairProvider
import app.termora.terminal.* import app.termora.terminal.*
import kotlinx.coroutines.Dispatchers import com.formdev.flatlaf.util.SystemInfo
import kotlinx.coroutines.swing.Swing
import kotlinx.coroutines.withContext
import org.apache.commons.io.Charsets import org.apache.commons.io.Charsets
import org.apache.commons.io.FileUtils import org.apache.commons.io.FileUtils
import org.apache.commons.io.IOUtils
import org.apache.commons.lang3.StringUtils import org.apache.commons.lang3.StringUtils
import org.apache.sshd.client.SshClient
import org.apache.sshd.client.session.ClientSession
import org.apache.sshd.common.config.keys.writer.openssh.OpenSSHKeyPairResourceWriter import org.apache.sshd.common.config.keys.writer.openssh.OpenSSHKeyPairResourceWriter
import org.apache.sshd.common.util.net.SshdSocketAddress
import java.awt.event.KeyEvent import java.awt.event.KeyEvent
import java.io.File import java.io.File
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import java.nio.file.Files import java.nio.file.Files
import java.nio.file.Path import java.nio.file.Path
import javax.swing.SwingUtilities
class SFTPPtyTerminalTab(windowScope: WindowScope, host: Host) : PtyHostTerminalTab(windowScope, host) { class SFTPPtyTerminalTab(windowScope: WindowScope, host: Host) : PtyHostTerminalTab(windowScope, host) {
private val keyManager by lazy { KeyManager.getInstance() } private val keyManager by lazy { KeyManager.getInstance() }
private val tempFiles = mutableListOf<Path>() private val tempFiles = mutableListOf<Path>()
private val passwordDataListener = object : DataListener { private var sshClient: SshClient? = null
private var sshSession: ClientSession? = null
private var lastPasswordReporterDataListener: PasswordReporterDataListener? = null
override suspend fun openPtyConnector(): PtyConnector {
val useJumpHosts = host.options.jumpHosts.isNotEmpty() || host.proxy.type != ProxyType.No
val commands = mutableListOf("sftp")
var host = this.host
// 如果配置了跳板机或者代理,那么通过 SSH 的端口转发到本地
if (useJumpHosts) {
host = host.copy(
tunnelings = listOf(
Tunneling(
type = TunnelingType.Local,
sourceHost = SshdSocketAddress.LOCALHOST_NAME,
destinationHost = SshdSocketAddress.LOCALHOST_NAME,
destinationPort = host.port,
)
)
)
val sshClient = SshClients.openClient(host).apply { sshClient = this }
val sshSession = SshClients.openSession(host, sshClient).apply { sshSession = this }
// 打开通道
for (tunneling in host.tunnelings) {
val address = SshClients.openTunneling(sshSession, host, tunneling)
host = host.copy(host = address.hostName, port = address.port)
}
}
if (useJumpHosts) {
// 打开通道后忽略 key 检查
commands.add("-o")
commands.add("StrictHostKeyChecking=no")
// 不保存 known_hosts
commands.add("-o")
commands.add("UserKnownHostsFile=${if (SystemInfo.isWindows) "NUL" else "/dev/null"}")
} else {
// known_hosts
commands.add("-o")
commands.add("UserKnownHostsFile=${File(Application.getBaseDataDir(), "known_hosts").absolutePath}")
}
// Compression
commands.add("-o")
commands.add("Compression=yes")
// port
commands.add("-P")
commands.add(host.port.toString())
// 设置认证信息
setAuthentication(commands, host)
commands.add("${host.username}@${host.host}")
val winSize = terminalPanel.winSize()
val ptyConnector = ptyConnectorFactory.createPtyConnector(
commands.toTypedArray(),
winSize.rows, winSize.cols,
host.options.envs(),
Charsets.toCharset(host.options.encoding, StandardCharsets.UTF_8),
)
return ptyConnector
}
private fun setAuthentication(commands: MutableList<String>, host: Host) {
// 如果通过公钥连接
if (host.authentication.type == AuthenticationType.PublicKey) {
val keyPair = keyManager.getOhKeyPair(host.authentication.password)
if (keyPair != null) {
val keyPair = OhKeyPairKeyPairProvider.generateKeyPair(keyPair)
val privateKeyPath = Application.createSubTemporaryDir()
val privateKeyFile = Files.createTempFile(privateKeyPath, Application.getName(), StringUtils.EMPTY)
Files.newOutputStream(privateKeyFile)
.use { OpenSSHKeyPairResourceWriter.INSTANCE.writePrivateKey(keyPair, null, null, it) }
commands.add("-i")
commands.add(privateKeyFile.toFile().absolutePath)
tempFiles.add(privateKeyPath)
}
} else if (host.authentication.type == AuthenticationType.Password) {
terminal.getTerminalModel().addDataListener(PasswordReporterDataListener(host).apply {
lastPasswordReporterDataListener = this
})
}
}
override fun stop() {
// 删除密码监听
lastPasswordReporterDataListener?.let { listener ->
SwingUtilities.invokeLater { terminal.getTerminalModel().removeDataListener(listener) }
}
IOUtils.closeQuietly(sshSession)
IOUtils.closeQuietly(sshClient)
tempFiles.removeIf {
FileUtils.deleteQuietly(it.toFile())
true
}
super.stop()
}
private inner class PasswordReporterDataListener(private val host: Host) : DataListener {
override fun onChanged(key: DataKey<*>, data: Any) { override fun onChanged(key: DataKey<*>, data: Any) {
if (key == VisualTerminal.Written && data is String) { if (key == VisualTerminal.Written && data is String) {
// 要求输入密码 // 要求输入密码
val line = terminal.getDocument().getScreenLine(terminal.getCursorModel().getPosition().y) val line = terminal.getDocument().getScreenLine(terminal.getCursorModel().getPosition().y)
if (line.getText().startsWith("${host.username}@${host.host}'s password:")) { if (line.getText().trim().trimIndent().startsWith("${host.username}@${host.host}'s password:")) {
// 删除密码监听 // 删除密码监听
terminal.getTerminalModel().removeDataListener(this) terminal.getTerminalModel().removeDataListener(this)
@@ -45,65 +160,4 @@ class SFTPPtyTerminalTab(windowScope: WindowScope, host: Host) : PtyHostTerminal
} }
} }
} }
override suspend fun openPtyConnector(): PtyConnector {
// 删除密码监听
withContext(Dispatchers.Swing) { terminal.getTerminalModel().removeDataListener(passwordDataListener) }
val winSize = terminalPanel.winSize()
val commands = mutableListOf("sftp")
// known_hosts
commands.add("-o")
commands.add("UserKnownHostsFile=${File(Application.getBaseDataDir(), "known_hosts").absolutePath}")
// Compression
commands.add("-o")
commands.add("Compression=yes")
// port
commands.add("-P")
commands.add(host.port.toString())
// 设置认证信息
setAuthentication(commands)
commands.add("${host.username}@${host.host}")
val ptyConnector = ptyConnectorFactory.createPtyConnector(
commands.toTypedArray(),
winSize.rows, winSize.cols,
host.options.envs(),
Charsets.toCharset(host.options.encoding, StandardCharsets.UTF_8),
)
return ptyConnector
}
private fun setAuthentication(commands: MutableList<String>) {
// 如果通过公钥连接
if (host.authentication.type == AuthenticationType.PublicKey) {
val keyPair = keyManager.getOhKeyPair(host.authentication.password)
if (keyPair != null) {
val keyPair = OhKeyPairKeyPairProvider.generateKeyPair(keyPair)
val privateKeyPath = Application.createSubTemporaryDir()
val privateKeyFile = Files.createTempFile(privateKeyPath, Application.getName(), StringUtils.EMPTY)
Files.newOutputStream(privateKeyFile)
.use { OpenSSHKeyPairResourceWriter.INSTANCE.writePrivateKey(keyPair, null, null, it) }
commands.add("-i")
commands.add(privateKeyFile.toFile().absolutePath)
tempFiles.add(privateKeyPath)
}
} else if (host.authentication.type == AuthenticationType.Password) {
terminal.getTerminalModel().addDataListener(passwordDataListener)
}
}
override fun stop() {
for (path in tempFiles) {
FileUtils.deleteQuietly(path.toFile())
}
tempFiles.clear()
super.stop()
}
} }

View File

@@ -26,7 +26,6 @@ import org.apache.sshd.common.channel.ChannelListener
import org.apache.sshd.common.session.Session import org.apache.sshd.common.session.Session
import org.apache.sshd.common.session.SessionListener import org.apache.sshd.common.session.SessionListener
import org.apache.sshd.common.session.SessionListener.Event import org.apache.sshd.common.session.SessionListener.Event
import org.apache.sshd.common.util.net.SshdSocketAddress
import org.slf4j.LoggerFactory import org.slf4j.LoggerFactory
import java.nio.charset.StandardCharsets import java.nio.charset.StandardCharsets
import java.util.* import java.util.*
@@ -193,28 +192,8 @@ class SSHTerminalTab(windowScope: WindowScope, host: Host) :
} }
for (tunneling in host.tunnelings) { for (tunneling in host.tunnelings) {
if (tunneling.type == TunnelingType.Local) {
session.startLocalPortForwarding(
SshdSocketAddress(tunneling.sourceHost, tunneling.sourcePort),
SshdSocketAddress(tunneling.destinationHost, tunneling.destinationPort)
)
} else if (tunneling.type == TunnelingType.Remote) {
session.startRemotePortForwarding(
SshdSocketAddress(tunneling.sourceHost, tunneling.sourcePort),
SshdSocketAddress(tunneling.destinationHost, tunneling.destinationPort),
)
} else if (tunneling.type == TunnelingType.Dynamic) {
session.startDynamicPortForwarding(
SshdSocketAddress(
tunneling.sourceHost,
tunneling.sourcePort
)
)
}
if (log.isInfoEnabled) { SshClients.openTunneling(session, host, tunneling)
log.info("SSH [{}] started {} port forwarding.", host.name, tunneling.name)
}
withContext(Dispatchers.Swing) { withContext(Dispatchers.Swing) {
terminal.write("Start [${tunneling.name}] port forwarding successfully.\r\n") terminal.write("Start [${tunneling.name}] port forwarding successfully.\r\n")

View File

@@ -162,6 +162,41 @@ object SshClients {
return session return session
} }
fun openTunneling(session: ClientSession, host: Host, tunneling: Tunneling): SshdSocketAddress {
val sshdSocketAddress = if (tunneling.type == TunnelingType.Local) {
session.startLocalPortForwarding(
SshdSocketAddress(tunneling.sourceHost, tunneling.sourcePort),
SshdSocketAddress(tunneling.destinationHost, tunneling.destinationPort)
)
} else if (tunneling.type == TunnelingType.Remote) {
session.startRemotePortForwarding(
SshdSocketAddress(tunneling.sourceHost, tunneling.sourcePort),
SshdSocketAddress(tunneling.destinationHost, tunneling.destinationPort),
)
} else if (tunneling.type == TunnelingType.Dynamic) {
session.startDynamicPortForwarding(
SshdSocketAddress(
tunneling.sourceHost,
tunneling.sourcePort
)
)
} else {
SshdSocketAddress.LOCALHOST_ADDRESS
}
if (log.isInfoEnabled) {
log.info(
"SSH [{}] started {} port forwarding. host: {} , port: {}",
host.name,
tunneling.name,
sshdSocketAddress.hostName,
sshdSocketAddress.port
)
}
return sshdSocketAddress
}
/** /**
* 打开一个客户端 * 打开一个客户端