feat: authentication support fallback (#431)

This commit is contained in:
hstyi
2025-03-29 20:37:19 +08:00
committed by GitHub
parent 827d814c7b
commit 30fe047e5c
8 changed files with 245 additions and 127 deletions

View File

@@ -0,0 +1,69 @@
package app.termora;
import org.apache.sshd.common.keyprovider.KeyIdentityProvider;
import org.apache.sshd.common.session.SessionContext;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.KeyPair;
import java.util.*;
public class CombinedKeyIdentityProvider implements KeyIdentityProvider {
private final List<KeyIdentityProvider> providers = new ArrayList<>();
@Override
public Iterable<KeyPair> loadKeys(SessionContext context) {
return () -> new Iterator<>() {
private final Iterator<KeyIdentityProvider> factories = providers
.iterator();
private Iterator<KeyPair> current;
private Boolean hasElement;
@Override
public boolean hasNext() {
if (hasElement != null) {
return hasElement;
}
while (current == null || !current.hasNext()) {
if (factories.hasNext()) {
try {
current = factories.next().loadKeys(context)
.iterator();
} catch (IOException | GeneralSecurityException e) {
throw new RuntimeException(e);
}
} else {
current = null;
hasElement = Boolean.FALSE;
return false;
}
}
hasElement = Boolean.TRUE;
return true;
}
@Override
public KeyPair next() {
if ((hasElement == null && !hasNext()) || !hasElement) {
throw new NoSuchElementException();
}
hasElement = null;
KeyPair result;
try {
result = current.next();
} catch (NoSuchElementException e) {
result = null;
}
return result;
}
};
}
public void addKeyKeyIdentityProvider(KeyIdentityProvider provider) {
providers.add(Objects.requireNonNull(provider));
}
}

View File

@@ -2,7 +2,6 @@ package app.termora
import app.termora.actions.AnAction import app.termora.actions.AnAction
import app.termora.actions.AnActionEvent import app.termora.actions.AnActionEvent
import app.termora.keyboardinteractive.TerminalUserInteraction
import kotlinx.coroutines.* import kotlinx.coroutines.*
import kotlinx.coroutines.swing.Swing import kotlinx.coroutines.swing.Swing
import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.commons.lang3.exception.ExceptionUtils
@@ -103,8 +102,7 @@ class HostDialog(owner: Window, host: Host? = null) : DialogWrapper(owner) {
var client: SshClient? = null var client: SshClient? = null
var session: ClientSession? = null var session: ClientSession? = null
try { try {
client = SshClients.openClient(host) client = SshClients.openClient(host, this)
client.userInteraction = TerminalUserInteraction(owner)
session = SshClients.openSession(host, client) session = SshClients.openSession(host, client)
} finally { } finally {
session?.close() session?.close()

View File

@@ -34,8 +34,8 @@ class RequestAuthenticationDialog(owner: Window, host: Host) : DialogWrapper(own
pack() pack()
size = Dimension(max(380, size.width), size.height) size = Dimension(max(380, size.width), size.height)
preferredSize = size
setLocationRelativeTo(null) minimumSize = size
publicKeyComboBox.renderer = object : DefaultListCellRenderer() { publicKeyComboBox.renderer = object : DefaultListCellRenderer() {
override fun getListCellRendererComponent( override fun getListCellRendererComponent(
@@ -65,6 +65,10 @@ class RequestAuthenticationDialog(owner: Window, host: Host) : DialogWrapper(own
} }
} }
if (host.authentication.type != AuthenticationType.No) {
authenticationTypeComboBox.selectedItem = host.authentication.type
}
usernameTextField.text = host.username usernameTextField.text = host.username
} }

View File

@@ -30,6 +30,7 @@ class SFTPPtyTerminalTab(windowScope: WindowScope, host: Host) : PtyHostTerminal
private var lastPasswordReporterDataListener: PasswordReporterDataListener? = null private var lastPasswordReporterDataListener: PasswordReporterDataListener? = null
private val sftpCommand get() = Database.getDatabase().sftp.sftpCommand private val sftpCommand get() = Database.getDatabase().sftp.sftpCommand
private val defaultDirectory get() = Database.getDatabase().sftp.defaultDirectory private val defaultDirectory get() = Database.getDatabase().sftp.defaultDirectory
private val owner get() = SwingUtilities.getWindowAncestor(terminalPanel)
init { init {
terminalPanel.dropFiles = true terminalPanel.dropFiles = true
@@ -67,7 +68,7 @@ class SFTPPtyTerminalTab(windowScope: WindowScope, host: Host) : PtyHostTerminal
) )
) )
val sshClient = SshClients.openClient(host).apply { sshClient = this } val sshClient = SshClients.openClient(host, owner).apply { sshClient = this }
val sshSession = SshClients.openSession(host, sshClient).apply { sshSession = this } val sshSession = SshClients.openSession(host, sshClient).apply { sshSession = this }
// 打开通道 // 打开通道

View File

@@ -4,7 +4,6 @@ import app.termora.actions.AnActionEvent
import app.termora.actions.DataProviders import app.termora.actions.DataProviders
import app.termora.actions.TabReconnectAction import app.termora.actions.TabReconnectAction
import app.termora.addons.zmodem.ZModemPtyConnectorAdaptor import app.termora.addons.zmodem.ZModemPtyConnectorAdaptor
import app.termora.keyboardinteractive.TerminalUserInteraction
import app.termora.keymap.KeyShortcut import app.termora.keymap.KeyShortcut
import app.termora.keymap.KeymapManager import app.termora.keymap.KeymapManager
import app.termora.terminal.ControlCharacters import app.termora.terminal.ControlCharacters
@@ -89,35 +88,8 @@ class SSHTerminalTab(windowScope: WindowScope, host: Host) :
terminal.write("SSH client is opening...\r\n") terminal.write("SSH client is opening...\r\n")
} }
var host =
this.host.copy(authentication = this.host.authentication.copy(), updateDate = System.currentTimeMillis())
val owner = SwingUtilities.getWindowAncestor(terminalPanel) val owner = SwingUtilities.getWindowAncestor(terminalPanel)
val client = SshClients.openClient(host).also { sshClient = it } val client = SshClients.openClient(host, owner).also { sshClient = it }
client.serverKeyVerifier = DialogServerKeyVerifier(owner)
// keyboard interactive
client.userInteraction = TerminalUserInteraction(owner)
if (host.authentication.type == AuthenticationType.No) {
withContext(Dispatchers.Swing) {
val dialog = RequestAuthenticationDialog(owner, host)
val authentication = dialog.getAuthentication()
host = host.copy(
authentication = authentication,
username = dialog.getUsername(),
updateDate = System.currentTimeMillis(),
)
// save
if (dialog.isRemembered()) {
HostManager.getInstance().addHost(
tab.host.copy(
authentication = authentication,
username = dialog.getUsername(), updateDate = System.currentTimeMillis(),
)
)
}
}
}
val sessionListener = MySessionListener() val sessionListener = MySessionListener()
val channelListener = MyChannelListener() val channelListener = MyChannelListener()

View File

@@ -1,15 +1,15 @@
package app.termora package app.termora
import app.termora.keyboardinteractive.TerminalUserInteraction import app.termora.keyboardinteractive.TerminalUserInteraction
import app.termora.keymgr.KeyManager
import app.termora.keymgr.OhKeyPairKeyPairProvider import app.termora.keymgr.OhKeyPairKeyPairProvider
import app.termora.terminal.TerminalSize import app.termora.terminal.TerminalSize
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.swing.Swing
import kotlinx.coroutines.withContext
import org.apache.commons.io.IOUtils import org.apache.commons.io.IOUtils
import org.apache.commons.lang3.StringUtils import org.apache.commons.lang3.StringUtils
import org.apache.sshd.client.ClientBuilder import org.apache.sshd.client.ClientBuilder
import org.apache.sshd.client.SshClient import org.apache.sshd.client.SshClient
import org.apache.sshd.client.auth.password.PasswordIdentityProvider
import org.apache.sshd.client.auth.password.UserAuthPasswordFactory
import org.apache.sshd.client.channel.ChannelShell import org.apache.sshd.client.channel.ChannelShell
import org.apache.sshd.client.channel.ClientChannelEvent import org.apache.sshd.client.channel.ClientChannelEvent
import org.apache.sshd.client.config.hosts.HostConfigEntry import org.apache.sshd.client.config.hosts.HostConfigEntry
@@ -27,6 +27,7 @@ import org.apache.sshd.common.config.keys.KeyUtils
import org.apache.sshd.common.global.KeepAliveHandler import org.apache.sshd.common.global.KeepAliveHandler
import org.apache.sshd.common.kex.BuiltinDHFactories import org.apache.sshd.common.kex.BuiltinDHFactories
import org.apache.sshd.common.keyprovider.KeyIdentityProvider import org.apache.sshd.common.keyprovider.KeyIdentityProvider
import org.apache.sshd.common.session.SessionContext
import org.apache.sshd.common.util.net.SshdSocketAddress import org.apache.sshd.common.util.net.SshdSocketAddress
import org.apache.sshd.core.CoreModuleProperties import org.apache.sshd.core.CoreModuleProperties
import org.apache.sshd.server.forward.AcceptAllForwardingFilter import org.apache.sshd.server.forward.AcceptAllForwardingFilter
@@ -44,6 +45,7 @@ import java.net.Proxy
import java.net.SocketAddress import java.net.SocketAddress
import java.nio.file.Path import java.nio.file.Path
import java.nio.file.Paths import java.nio.file.Paths
import java.security.KeyPair
import java.security.PublicKey import java.security.PublicKey
import java.time.Duration import java.time.Duration
import java.util.* import java.util.*
@@ -57,6 +59,7 @@ object SshClients {
val HOST_KEY = AttributeRepository.AttributeKey<Host>() val HOST_KEY = AttributeRepository.AttributeKey<Host>()
private val timeout = Duration.ofSeconds(30) private val timeout = Duration.ofSeconds(30)
private val hostManager get() = HostManager.getInstance()
private val log by lazy { LoggerFactory.getLogger(SshClients::class.java) } private val log by lazy { LoggerFactory.getLogger(SshClients::class.java) }
/** /**
@@ -119,16 +122,16 @@ object SshClients {
* 打开一个会话 * 打开一个会话
*/ */
fun openSession(host: Host, client: SshClient): ClientSession { fun openSession(host: Host, client: SshClient): ClientSession {
val h = hostManager.getHost(host.id) ?: host
// 如果没有跳板机直接连接 // 如果没有跳板机直接连接
if (host.options.jumpHosts.isEmpty()) { if (h.options.jumpHosts.isEmpty()) {
return doOpenSession(host, client) return doOpenSession(h, client)
} }
val jumpHosts = mutableListOf<Host>() val jumpHosts = mutableListOf<Host>()
val hosts = HostManager.getInstance().hosts().associateBy { it.id } val hosts = HostManager.getInstance().hosts().associateBy { it.id }
for (jumpHostId in host.options.jumpHosts) { for (jumpHostId in h.options.jumpHosts) {
val e = hosts[jumpHostId] val e = hosts[jumpHostId]
if (e == null) { if (e == null) {
if (log.isWarnEnabled) { if (log.isWarnEnabled) {
@@ -140,7 +143,7 @@ object SshClients {
} }
// 最后一跳是目标机器 // 最后一跳是目标机器
jumpHosts.add(host) jumpHosts.add(h)
val sessions = mutableListOf<ClientSession>() val sessions = mutableListOf<ClientSession>()
for (i in 0 until jumpHosts.size) { for (i in 0 until jumpHosts.size) {
@@ -187,14 +190,25 @@ object SshClients {
entry.hostName = host.host entry.hostName = host.host
entry.setProperty("Middleware", middleware.toString()) entry.setProperty("Middleware", middleware.toString())
val session = client.connect(entry) val session = client.connect(entry).verify(timeout).session
.verify(timeout).session
if (host.authentication.type == AuthenticationType.Password) { if (host.authentication.type == AuthenticationType.Password) {
session.addPasswordIdentity(host.authentication.password) session.addPasswordIdentity(host.authentication.password)
} else if (host.authentication.type == AuthenticationType.PublicKey) { } else if (host.authentication.type == AuthenticationType.PublicKey) {
session.keyIdentityProvider = OhKeyPairKeyPairProvider(host.authentication.password) session.keyIdentityProvider = OhKeyPairKeyPairProvider(host.authentication.password)
} }
val owner = client.properties["owner"] as Window?
if (owner != null) {
val identityProvider = IdentityProvider(host, owner)
session.passwordIdentityProvider = identityProvider
val combinedKeyIdentityProvider = CombinedKeyIdentityProvider()
if (session.keyIdentityProvider != null) {
combinedKeyIdentityProvider.addKeyKeyIdentityProvider(session.keyIdentityProvider)
}
combinedKeyIdentityProvider.addKeyKeyIdentityProvider(identityProvider)
session.keyIdentityProvider = combinedKeyIdentityProvider
}
val verifyTimeout = Duration.ofSeconds(timeout.seconds * 5) val verifyTimeout = Duration.ofSeconds(timeout.seconds * 5)
if (!session.auth().verify(verifyTimeout).await(verifyTimeout)) { if (!session.auth().verify(verifyTimeout).await(verifyTimeout)) {
throw SshException("Authentication failed") throw SshException("Authentication failed")
@@ -241,27 +255,13 @@ object SshClients {
return sshdSocketAddress return sshdSocketAddress
} }
suspend fun openClient(host: Host, owner: Window): Pair<SshClient, Host> { fun openClient(host: Host, owner: Window): SshClient {
val client = openClient(host) val h = hostManager.getHost(host.id) ?: host
var myHost = host val client = openClient(h)
withContext(Dispatchers.Swing) {
client.userInteraction = TerminalUserInteraction(owner) client.userInteraction = TerminalUserInteraction(owner)
client.serverKeyVerifier = DialogServerKeyVerifier(owner) client.serverKeyVerifier = DialogServerKeyVerifier(owner)
// 弹出授权框 client.properties["owner"] = owner
if (host.authentication.type == AuthenticationType.No) { return client
val dialog = RequestAuthenticationDialog(owner, host)
val authentication = dialog.getAuthentication()
myHost = myHost.copy(
authentication = authentication,
username = dialog.getUsername(), updateDate = System.currentTimeMillis(),
)
// save
if (dialog.isRemembered()) {
HostManager.getInstance().addHost(myHost)
}
}
}
return client to myHost
} }
/** /**
@@ -298,6 +298,28 @@ object SshClients {
// JGit 会尝试读取本地的私钥或缓存的私钥 // JGit 会尝试读取本地的私钥或缓存的私钥
sshClient.keyIdentityProvider = KeyIdentityProvider { mutableListOf() } sshClient.keyIdentityProvider = KeyIdentityProvider { mutableListOf() }
// 设置优先级
if (host.authentication.type == AuthenticationType.PublicKey) {
CoreModuleProperties.PREFERRED_AUTHS.set(
sshClient,
listOf(
UserAuthPasswordFactory.PUBLIC_KEY,
UserAuthPasswordFactory.PASSWORD,
UserAuthPasswordFactory.KB_INTERACTIVE
).joinToString(",")
)
} else {
CoreModuleProperties.PREFERRED_AUTHS.set(
sshClient,
listOf(
UserAuthPasswordFactory.PASSWORD,
UserAuthPasswordFactory.PUBLIC_KEY,
UserAuthPasswordFactory.KB_INTERACTIVE
).joinToString(",")
)
}
val heartbeatInterval = max(host.options.heartbeatInterval, 3) val heartbeatInterval = max(host.options.heartbeatInterval, 3)
CoreModuleProperties.HEARTBEAT_INTERVAL.set(sshClient, Duration.ofSeconds(heartbeatInterval.toLong())) CoreModuleProperties.HEARTBEAT_INTERVAL.set(sshClient, Duration.ofSeconds(heartbeatInterval.toLong()))
CoreModuleProperties.ALLOW_DHG1_KEX_FALLBACK.set(sshClient, true) CoreModuleProperties.ALLOW_DHG1_KEX_FALLBACK.set(sshClient, true)
@@ -327,10 +349,9 @@ object SshClients {
sshClient.start() sshClient.start()
return sshClient return sshClient
} }
}
private class MyDialogServerKeyVerifier(private val owner: Window) : ServerKeyVerifier, ModifiedServerKeyAcceptor { private class MyDialogServerKeyVerifier(private val owner: Window) : ServerKeyVerifier, ModifiedServerKeyAcceptor {
override fun verifyServerKey( override fun verifyServerKey(
clientSession: ClientSession, clientSession: ClientSession,
remoteAddress: SocketAddress, remoteAddress: SocketAddress,
@@ -368,14 +389,14 @@ private class MyDialogServerKeyVerifier(private val owner: Window) : ServerKeyVe
return result.get() return result.get()
} }
} }
class DialogServerKeyVerifier( private class DialogServerKeyVerifier(
owner: Window, owner: Window,
) : KnownHostsServerKeyVerifier( ) : KnownHostsServerKeyVerifier(
MyDialogServerKeyVerifier(owner), MyDialogServerKeyVerifier(owner),
Paths.get(Application.getBaseDataDir().absolutePath, "known_hosts") Paths.get(Application.getBaseDataDir().absolutePath, "known_hosts")
) { ) {
init { init {
modifiedServerKeyAcceptor = delegateVerifier as ModifiedServerKeyAcceptor modifiedServerKeyAcceptor = delegateVerifier as ModifiedServerKeyAcceptor
} }
@@ -388,10 +409,63 @@ class DialogServerKeyVerifier(
knownHosts: Collection<HostEntryPair?>? knownHosts: Collection<HostEntryPair?>?
): KnownHostEntry? { ): KnownHostEntry? {
if (clientSession is JGitClientSession) { if (clientSession is JGitClientSession) {
if (SshClients.isMiddleware(clientSession)) { if (isMiddleware(clientSession)) {
return null return null
} }
} }
return super.updateKnownHostsFile(clientSession, remoteAddress, serverKey, file, knownHosts) return super.updateKnownHostsFile(clientSession, remoteAddress, serverKey, file, knownHosts)
} }
}
private class IdentityProvider(private val host: Host, private val owner: Window) : PasswordIdentityProvider,
KeyIdentityProvider {
private val asked = AtomicBoolean(false)
private val hostManager get() = HostManager.getInstance()
private val keyManager get() = KeyManager.getInstance()
private var authentication = Authentication.No
override fun loadPasswords(session: SessionContext): MutableIterable<String> {
val authentication = ask()
if (authentication.type != AuthenticationType.Password) {
return mutableListOf()
}
return mutableListOf(authentication.password)
}
override fun loadKeys(session: SessionContext): MutableIterable<KeyPair> {
val authentication = ask()
if (authentication.type != AuthenticationType.PublicKey) {
return mutableListOf()
}
val ohKeyPair = keyManager.getOhKeyPair(authentication.password) ?: return mutableListOf()
return mutableListOf(OhKeyPairKeyPairProvider.generateKeyPair(ohKeyPair))
}
private fun ask(): Authentication {
if (asked.compareAndSet(false, true)) {
askNow()
}
return authentication
}
private fun askNow() {
if (SwingUtilities.isEventDispatchThread()) {
val dialog = RequestAuthenticationDialog(owner, host)
dialog.setLocationRelativeTo(owner)
authentication = dialog.getAuthentication()
// save
if (dialog.isRemembered()) {
val host = host.copy(
authentication = authentication,
username = dialog.getUsername(), updateDate = System.currentTimeMillis(),
)
hostManager.addHost(host)
}
} else {
SwingUtilities.invokeAndWait { askNow() }
}
}
}
} }

View File

@@ -144,7 +144,7 @@ class SSHCopyIdDialog(
} }
try { try {
val client = SshClients.openClient(host).apply { myClient = this } val client = SshClients.openClient(host, this).apply { myClient = this }
client.userInteraction = TerminalUserInteraction(owner) client.userInteraction = TerminalUserInteraction(owner)
val session = SshClients.openSession(host, client).apply { mySession = this } val session = SshClients.openSession(host, client).apply { mySession = this }
val channel = val channel =

View File

@@ -112,9 +112,9 @@ class SFTPFileSystemViewPanel(
closeIO() closeIO()
try { try {
val (client, host) = SshClients.openClient(thisHost, SwingUtilities.getWindowAncestor(that)) val owner = SwingUtilities.getWindowAncestor(that)
this.client = client val client = SshClients.openClient(thisHost, owner).apply { client = this }
val session = SshClients.openSession(host, client).apply { session = this } val session = SshClients.openSession(thisHost, client).apply { session = this }
fileSystem = SftpClientFactory.instance().createSftpFileSystem(session) fileSystem = SftpClientFactory.instance().createSftpFileSystem(session)
session.addCloseFutureListener { onClose() } session.addCloseFutureListener { onClose() }
} catch (e: Exception) { } catch (e: Exception) {