Files
GoTunnel/internal/client/tunnel/client.go
Flik 46f912423e
All checks were successful
Build Multi-Platform Binaries / build-frontend (push) Successful in 27s
Build Multi-Platform Binaries / build-binaries (amd64, darwin, server, false) (push) Successful in 1m1s
Build Multi-Platform Binaries / build-binaries (amd64, linux, client, true) (push) Successful in 44s
Build Multi-Platform Binaries / build-binaries (amd64, linux, server, true) (push) Successful in 1m27s
Build Multi-Platform Binaries / build-binaries (amd64, windows, client, true) (push) Successful in 44s
Build Multi-Platform Binaries / build-binaries (amd64, windows, server, true) (push) Successful in 1m26s
Build Multi-Platform Binaries / build-binaries (arm, 7, linux, client, true) (push) Successful in 50s
Build Multi-Platform Binaries / build-binaries (arm, 7, linux, server, true) (push) Successful in 1m43s
Build Multi-Platform Binaries / build-binaries (arm64, darwin, server, false) (push) Successful in 1m2s
Build Multi-Platform Binaries / build-binaries (arm64, linux, client, true) (push) Successful in 44s
Build Multi-Platform Binaries / build-binaries (arm64, linux, server, true) (push) Successful in 1m25s
Build Multi-Platform Binaries / build-binaries (arm64, windows, server, false) (push) Successful in 1m0s
11
2026-01-03 03:07:59 +08:00

903 lines
23 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package tunnel
import (
	"crypto/tls"
	"fmt"
	"log"
	"net"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strings"
	"sync"
	"time"

	"github.com/gotunnel/pkg/plugin"
	"github.com/gotunnel/pkg/plugin/script"
	"github.com/gotunnel/pkg/plugin/sign"
	"github.com/gotunnel/pkg/protocol"
	"github.com/gotunnel/pkg/relay"
	"github.com/gotunnel/pkg/update"
	"github.com/hashicorp/yamux"
)
// Client-side constants.
const (
// dialTimeout bounds dialing the tunnel server and proxy-connect targets.
dialTimeout = 10 * time.Second
// localDialTimeout bounds dialing the local TCP/UDP service behind a rule.
localDialTimeout = 5 * time.Second
// udpTimeout is the read/write deadline applied to a local UDP exchange.
udpTimeout = 10 * time.Second
// reconnectDelay is the pause in Run after a failed connect attempt.
reconnectDelay = 5 * time.Second
// disconnectDelay is the pause in Run after an established session ends.
disconnectDelay = 3 * time.Second
// udpBufferSize is the size of the buffer used to read a local UDP reply.
udpBufferSize = 65535
// idFileName is the name of the file (in the user's home directory when
// available) that stores the client ID.
idFileName = ".gotunnel_id"
)
// Client is the tunnel client. It connects to the server, authenticates,
// and serves the streams the server opens over a yamux session.
type Client struct {
// ServerAddr is the TCP address dialed by connect.
ServerAddr string
// Token is sent to the server in the authentication request.
Token string
// ID is the client identifier; connect replaces it when the server
// assigns a different one.
ID string
// TLSEnabled and TLSConfig together select TLS: connect only uses TLS
// when TLSEnabled is true and TLSConfig is non-nil.
TLSEnabled bool
TLSConfig *tls.Config
DataDir string // data directory
// session is the current yamux session, set by connect.
session *yamux.Session
// rules holds the proxy rules most recently received from the server.
rules []protocol.ProxyRule
// mu is held when writing session in connect and when reading or
// writing rules.
mu sync.RWMutex
// pluginRegistry is set via SetPluginRegistry and may be nil.
pluginRegistry *plugin.Registry
// runningPlugins maps "PluginName:RuleName" to the started plugin.
runningPlugins map[string]plugin.ClientPlugin
// versionStore is set by InitVersionStore and may be nil; it is used to
// reject plugin version downgrades.
versionStore *PluginVersionStore
// pluginMu guards runningPlugins.
pluginMu sync.RWMutex
}
// NewClient builds a Client for the given server address and token.
//
// When id is empty, the ID is loaded from the local ID file via loadClientID
// (which yields "" if the file cannot be read). DataDir defaults to
// ".gotunnel" under the user's home directory; a failed home-directory lookup
// is ignored, leaving a relative path.
func NewClient(serverAddr, token, id string) *Client {
	clientID := id
	if clientID == "" {
		clientID = loadClientID()
	}

	// Best effort: on error homeDir stays empty and Join yields ".gotunnel".
	homeDir, _ := os.UserHomeDir()

	client := &Client{
		ServerAddr:     serverAddr,
		Token:          token,
		ID:             clientID,
		DataDir:        filepath.Join(homeDir, ".gotunnel"),
		runningPlugins: make(map[string]plugin.ClientPlugin),
	}
	return client
}
// InitVersionStore creates the plugin version store rooted at c.DataDir and
// attaches it to the client. On failure the client is left unchanged and the
// error from NewPluginVersionStore is returned.
func (c *Client) InitVersionStore() error {
	versions, err := NewPluginVersionStore(c.DataDir)
	if err == nil {
		c.versionStore = versions
	}
	return err
}
// getIDFilePath returns the path of the client ID file: idFileName inside the
// user's home directory, or the bare idFileName (relative to the working
// directory) when the home directory cannot be determined.
func getIDFilePath() string {
	if homeDir, err := os.UserHomeDir(); err == nil {
		return filepath.Join(homeDir, idFileName)
	}
	return idFileName
}
// loadClientID reads the client ID from the local ID file.
//
// It returns "" when the file cannot be read. Surrounding whitespace is
// stripped so that a trailing newline (for example after the file was edited
// or created by hand) does not become part of the ID sent to the server.
func loadClientID() string {
	data, err := os.ReadFile(getIDFilePath())
	if err != nil {
		return ""
	}
	return strings.TrimSpace(string(data))
}
// saveClientID writes id to the local ID file with mode 0600. A write failure
// is logged and otherwise ignored.
func saveClientID(id string) {
	target := getIDFilePath()
	writeErr := os.WriteFile(target, []byte(id), 0600)
	if writeErr == nil {
		return
	}
	log.Printf("[Client] Failed to save client ID: %v", writeErr)
}
// SetPluginRegistry sets the registry used to look up client plugins.
// The assignment is not guarded by a lock.
func (c *Client) SetPluginRegistry(registry *plugin.Registry) {
c.pluginRegistry = registry
}
// Run starts the client and keeps it connected, reconnecting after failures.
//
// It loops forever: a failed connect is logged and retried after
// reconnectDelay; a session that ends is followed by a disconnectDelay pause
// before reconnecting. The function never returns.
func (c *Client) Run() error {
	for {
		connectErr := c.connect()
		if connectErr == nil {
			// Blocks until the session stops accepting streams.
			c.handleSession()
			log.Printf("[Client] Disconnected, reconnecting...")
			time.Sleep(disconnectDelay)
			continue
		}

		log.Printf("[Client] Connect error: %v", connectErr)
		log.Printf("[Client] Reconnecting in %v...", reconnectDelay)
		time.Sleep(reconnectDelay)
	}
}
// connect dials the server, authenticates, and installs a yamux session.
//
// TLS is used only when TLSEnabled is set and TLSConfig is non-nil. If the
// server assigns a client ID different from c.ID, the new ID is stored on the
// client and persisted with saveClientID. On success c.session holds the new
// session; on any failure the connection is closed and an error is returned.
func (c *Client) connect() error {
	var (
		conn net.Conn
		err  error
	)
	if c.TLSEnabled && c.TLSConfig != nil {
		dialer := &net.Dialer{Timeout: dialTimeout}
		conn, err = tls.DialWithDialer(dialer, "tcp", c.ServerAddr, c.TLSConfig)
	} else {
		conn, err = net.DialTimeout("tcp", c.ServerAddr, dialTimeout)
	}
	if err != nil {
		return err
	}

	// The connection is handed to the yamux session on success; every other
	// exit path must close it.
	established := false
	defer func() {
		if !established {
			conn.Close()
		}
	}()

	// Bound the auth exchange so an unresponsive server cannot block the
	// reconnect loop indefinitely.
	if err := conn.SetDeadline(time.Now().Add(dialTimeout)); err != nil {
		return fmt.Errorf("set auth deadline: %w", err)
	}

	authReq := protocol.AuthRequest{ClientID: c.ID, Token: c.Token}
	msg, err := protocol.NewMessage(protocol.MsgTypeAuth, authReq)
	if err != nil {
		return fmt.Errorf("build auth request: %w", err)
	}
	if err := protocol.WriteMessage(conn, msg); err != nil {
		return err
	}

	resp, err := protocol.ReadMessage(conn)
	if err != nil {
		return err
	}
	var authResp protocol.AuthResponse
	if err := resp.ParsePayload(&authResp); err != nil {
		return fmt.Errorf("parse auth response: %w", err)
	}
	if !authResp.Success {
		return fmt.Errorf("auth failed: %s", authResp.Message)
	}

	// Adopt and persist a server-assigned ID.
	if authResp.ClientID != "" && authResp.ClientID != c.ID {
		c.ID = authResp.ClientID
		saveClientID(c.ID)
		log.Printf("[Client] New ID assigned and saved: %s", c.ID)
	}
	log.Printf("[Client] Authenticated as %s", c.ID)

	// Clear the handshake deadline before the long-lived session takes over.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		return fmt.Errorf("clear auth deadline: %w", err)
	}

	session, err := yamux.Client(conn, nil)
	if err != nil {
		return err
	}
	established = true

	c.mu.Lock()
	c.session = session
	c.mu.Unlock()
	return nil
}
// handleSession accepts streams from the current session until Accept fails,
// handling each stream in its own goroutine, and closes the session on return.
//
// The session is read under c.mu, matching how connect writes it, and a nil
// session (no successful connect yet) is a no-op instead of a panic.
func (c *Client) handleSession() {
	c.mu.RLock()
	session := c.session
	c.mu.RUnlock()
	if session == nil {
		return
	}
	defer session.Close()

	for {
		stream, err := session.Accept()
		if err != nil {
			return
		}
		go c.handleStream(stream)
	}
}
// handleStream reads the first message from a server-opened stream and
// dispatches on its type.
//
// For config, new-proxy, heartbeat and plugin-config messages the stream is
// closed here once the handler returns; the remaining handlers receive the
// stream and are responsible for it themselves. A stream carrying an
// unrecognised message type is logged and closed so it does not leak.
func (c *Client) handleStream(stream net.Conn) {
	msg, err := protocol.ReadMessage(stream)
	if err != nil {
		stream.Close()
		return
	}

	switch msg.Type {
	case protocol.MsgTypeProxyConfig:
		defer stream.Close()
		c.handleProxyConfig(msg)
	case protocol.MsgTypeNewProxy:
		defer stream.Close()
		c.handleNewProxy(stream, msg)
	case protocol.MsgTypeHeartbeat:
		defer stream.Close()
		c.handleHeartbeat(stream)
	case protocol.MsgTypeProxyConnect:
		c.handleProxyConnect(stream, msg)
	case protocol.MsgTypeUDPData:
		c.handleUDPData(stream, msg)
	case protocol.MsgTypePluginConfig:
		defer stream.Close()
		c.handlePluginConfig(msg)
	case protocol.MsgTypeClientPluginStart:
		c.handleClientPluginStart(stream, msg)
	case protocol.MsgTypeClientPluginStop:
		c.handleClientPluginStop(stream, msg)
	case protocol.MsgTypeClientPluginConn:
		c.handleClientPluginConn(stream, msg)
	case protocol.MsgTypeJSPluginInstall:
		c.handleJSPluginInstall(stream, msg)
	case protocol.MsgTypeClientRestart:
		c.handleClientRestart(stream, msg)
	case protocol.MsgTypePluginConfigUpdate:
		c.handlePluginConfigUpdate(stream, msg)
	case protocol.MsgTypeUpdateDownload:
		c.handleUpdateDownload(stream, msg)
	default:
		// No handler takes ownership of the stream, so close it here.
		log.Printf("[Client] Unknown message type: %v", msg.Type)
		stream.Close()
	}
}
// handleProxyConfig replaces the client's proxy rules with those carried in
// msg and logs each received rule. A payload that fails to parse is logged
// and leaves the current rules untouched.
func (c *Client) handleProxyConfig(msg *protocol.Message) {
	var proxyCfg protocol.ProxyConfig
	parseErr := msg.ParsePayload(&proxyCfg)
	if parseErr != nil {
		log.Printf("[Client] Parse proxy config error: %v", parseErr)
		return
	}

	received := proxyCfg.Rules

	c.mu.Lock()
	c.rules = received
	c.mu.Unlock()

	log.Printf("[Client] Received %d proxy rules", len(received))
	for idx := range received {
		log.Printf("[Client] %s: %s:%d", received[idx].Name, received[idx].LocalIP, received[idx].LocalPort)
	}
}
// handleNewProxy serves a new proxied TCP connection: it looks up the rule
// whose RemotePort matches the request, dials the rule's local address, and
// relays data between the stream and the local connection.
//
// The stream itself is closed by the caller (handleStream). The local
// connection is closed here when the relay returns, mirroring
// handleProxyConnect; a second Close on an already-closed net.Conn only
// returns an error.
func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) {
	var req protocol.NewProxyRequest
	if err := msg.ParsePayload(&req); err != nil {
		log.Printf("[Client] Parse new proxy request error: %v", err)
		return
	}

	// Copy the matching rule while holding the lock instead of keeping a
	// pointer to a loop variable.
	var (
		rule  protocol.ProxyRule
		found bool
	)
	c.mu.RLock()
	for i := range c.rules {
		if c.rules[i].RemotePort == req.RemotePort {
			rule = c.rules[i]
			found = true
			break
		}
	}
	c.mu.RUnlock()
	if !found {
		log.Printf("[Client] Unknown port %d", req.RemotePort)
		return
	}

	localAddr := fmt.Sprintf("%s:%d", rule.LocalIP, rule.LocalPort)
	localConn, err := net.DialTimeout("tcp", localAddr, localDialTimeout)
	if err != nil {
		log.Printf("[Client] Connect %s error: %v", localAddr, err)
		return
	}
	defer localConn.Close()

	relay.Relay(stream, localConn)
}
// handleHeartbeat answers a heartbeat by writing a MsgTypeHeartbeatAck message
// to the stream. The result of the write is not checked.
func (c *Client) handleHeartbeat(stream net.Conn) {
	ack := protocol.Message{Type: protocol.MsgTypeHeartbeatAck}
	protocol.WriteMessage(stream, &ack)
}
// handleProxyConnect serves a proxy connect request (SOCKS5/HTTP): it dials
// the requested target over TCP, reports the outcome to the server, and on
// success relays data between the stream and the target.
//
// Both the stream and the target connection are closed when it returns.
func (c *Client) handleProxyConnect(stream net.Conn, msg *protocol.Message) {
	defer stream.Close()

	var connectReq protocol.ProxyConnectRequest
	if parseErr := msg.ParsePayload(&connectReq); parseErr != nil {
		c.sendProxyResult(stream, false, "invalid request")
		return
	}

	// Dial the requested target.
	upstream, dialErr := net.DialTimeout("tcp", connectReq.Target, dialTimeout)
	if dialErr != nil {
		c.sendProxyResult(stream, false, dialErr.Error())
		return
	}
	defer upstream.Close()

	// Only start forwarding once the success response has been delivered.
	if ackErr := c.sendProxyResult(stream, true, ""); ackErr == nil {
		relay.Relay(stream, upstream)
	}
}
// sendProxyResult writes a MsgTypeProxyResult message describing the outcome
// of a proxy connect attempt.
//
// It returns an error if the message cannot be built or written, so callers
// never proceed to relay after a result that was not delivered.
func (c *Client) sendProxyResult(stream net.Conn, success bool, message string) error {
	result := protocol.ProxyConnectResult{Success: success, Message: message}
	msg, err := protocol.NewMessage(protocol.MsgTypeProxyResult, result)
	if err != nil {
		return fmt.Errorf("build proxy result: %w", err)
	}
	return protocol.WriteMessage(stream, msg)
}
// handleUDPData forwards one UDP packet to the local service of the matching
// rule, waits for a single reply, and writes that reply back to the stream as
// a MsgTypeUDPData message.
//
// Any failure (bad payload, unknown port, dial/write/read error or timeout)
// ends the exchange silently. The stream is closed on return.
func (c *Client) handleUDPData(stream net.Conn, msg *protocol.Message) {
	defer stream.Close()

	var inbound protocol.UDPPacket
	if msg.ParsePayload(&inbound) != nil {
		return
	}

	// Find the rule for the remote port the packet arrived on.
	rule := c.findRuleByPort(inbound.RemotePort)
	if rule == nil {
		return
	}

	// Connect to the local UDP service.
	localAddr := fmt.Sprintf("%s:%d", rule.LocalIP, rule.LocalPort)
	udpConn, err := net.DialTimeout("udp", localAddr, localDialTimeout)
	if err != nil {
		return
	}
	defer udpConn.Close()

	// One deadline covers both the write and the read of the reply.
	udpConn.SetDeadline(time.Now().Add(udpTimeout))
	if _, err = udpConn.Write(inbound.Data); err != nil {
		return
	}

	reply := make([]byte, udpBufferSize)
	size, err := udpConn.Read(reply)
	if err != nil {
		return
	}

	// Send the reply back to the server, echoing the packet's addressing.
	outbound, _ := protocol.NewMessage(protocol.MsgTypeUDPData, protocol.UDPPacket{
		RemotePort: inbound.RemotePort,
		ClientAddr: inbound.ClientAddr,
		Data:       reply[:size],
	})
	protocol.WriteMessage(stream, outbound)
}
// findRuleByPort returns a pointer to the first rule whose RemotePort equals
// port, or nil when no rule matches. The returned pointer refers to the
// element inside c.rules; the read lock is held only for the search.
func (c *Client) findRuleByPort(port int) *protocol.ProxyRule {
	c.mu.RLock()
	defer c.mu.RUnlock()

	for idx, candidate := range c.rules {
		if candidate.RemotePort != port {
			continue
		}
		return &c.rules[idx]
	}
	return nil
}
// handlePluginConfig applies a plugin configuration pushed by the server: it
// looks the plugin up in the registry and calls Init with the new config.
//
// Nothing is applied when no registry is set. Parse, lookup and Init failures
// are logged.
func (c *Client) handlePluginConfig(msg *protocol.Message) {
	var syncCfg protocol.PluginConfigSync
	if parseErr := msg.ParsePayload(&syncCfg); parseErr != nil {
		log.Printf("[Client] Parse plugin config error: %v", parseErr)
		return
	}
	log.Printf("[Client] Received config for plugin: %s", syncCfg.PluginName)

	// Without a registry there is no plugin to configure.
	if c.pluginRegistry == nil {
		return
	}

	target, lookupErr := c.pluginRegistry.GetClient(syncCfg.PluginName)
	if lookupErr != nil {
		log.Printf("[Client] Plugin %s not found: %v", syncCfg.PluginName, lookupErr)
		return
	}
	if initErr := target.Init(syncCfg.Config); initErr != nil {
		log.Printf("[Client] Plugin %s init error: %v", syncCfg.PluginName, initErr)
		return
	}
	log.Printf("[Client] Plugin %s config applied", syncCfg.PluginName)
}
// handleClientPluginStart starts a registered client plugin for a rule: it
// looks the plugin up, calls Init with the request's config, calls Start, and
// records the plugin under "PluginName:RuleName" in runningPlugins.
//
// The outcome is reported to the server with sendPluginStatus; every failure
// is reported as not running with the error text. The stream is closed on
// return.
func (c *Client) handleClientPluginStart(stream net.Conn, msg *protocol.Message) {
	defer stream.Close()

	var startReq protocol.ClientPluginStartRequest
	// fail reports a failed start using whatever names the request carries.
	fail := func(reason string) {
		c.sendPluginStatus(stream, startReq.PluginName, startReq.RuleName, false, "", reason)
	}

	if parseErr := msg.ParsePayload(&startReq); parseErr != nil {
		fail(parseErr.Error())
		return
	}
	log.Printf("[Client] Starting plugin %s for rule %s", startReq.PluginName, startReq.RuleName)

	// Look the plugin up.
	if c.pluginRegistry == nil {
		fail("plugin registry not set")
		return
	}
	target, lookupErr := c.pluginRegistry.GetClient(startReq.PluginName)
	if lookupErr != nil {
		fail(lookupErr.Error())
		return
	}

	// Initialise, then start.
	if initErr := target.Init(startReq.Config); initErr != nil {
		fail(initErr.Error())
		return
	}
	addr, startErr := target.Start()
	if startErr != nil {
		fail(startErr.Error())
		return
	}

	// Remember the running plugin.
	pluginKey := startReq.PluginName + ":" + startReq.RuleName
	c.pluginMu.Lock()
	c.runningPlugins[pluginKey] = target
	c.pluginMu.Unlock()

	log.Printf("[Client] Plugin %s started at %s", startReq.PluginName, addr)
	c.sendPluginStatus(stream, startReq.PluginName, startReq.RuleName, true, addr, "")
}
// sendPluginStatus writes a MsgTypeClientPluginStatus message carrying the
// given plugin state to the stream. Build and write errors are not checked.
func (c *Client) sendPluginStatus(stream net.Conn, pluginName, ruleName string, running bool, localAddr, errMsg string) {
	statusMsg, _ := protocol.NewMessage(protocol.MsgTypeClientPluginStatus, protocol.ClientPluginStatusResponse{
		PluginName: pluginName,
		RuleName:   ruleName,
		Running:    running,
		LocalAddr:  localAddr,
		Error:      errMsg,
	})
	protocol.WriteMessage(stream, statusMsg)
}
// handleClientPluginConn hands a stream to the running plugin identified by
// "PluginName:RuleName".
//
// The stream is closed here when the request cannot be parsed or the plugin
// is not running; otherwise it is passed to the plugin's HandleConn and not
// closed by this function.
func (c *Client) handleClientPluginConn(stream net.Conn, msg *protocol.Message) {
	var connReq protocol.ClientPluginConnRequest
	if parseErr := msg.ParsePayload(&connReq); parseErr != nil {
		stream.Close()
		return
	}

	pluginKey := connReq.PluginName + ":" + connReq.RuleName

	c.pluginMu.RLock()
	running, found := c.runningPlugins[pluginKey]
	c.pluginMu.RUnlock()

	if found {
		// Let the plugin serve the connection.
		running.HandleConn(stream)
		return
	}

	log.Printf("[Client] Plugin %s not running", pluginKey)
	stream.Close()
}
// handleJSPluginInstall installs a JS plugin sent by the server.
//
// The request is validated before anything running is touched: the signature
// is verified and the plugin object is built first, and only then is an
// already-running instance under the same "PluginName:RuleName" key stopped.
// A request that fails verification or construction therefore no longer
// takes down a working plugin. After a successful install the plugin is
// registered (when a registry is set), the result is sent, the signed version
// is recorded for downgrade protection, and the plugin is started if
// AutoStart is set. The stream is closed on return.
func (c *Client) handleJSPluginInstall(stream net.Conn, msg *protocol.Message) {
	defer stream.Close()

	var req protocol.JSPluginInstallRequest
	if err := msg.ParsePayload(&req); err != nil {
		c.sendJSPluginResult(stream, "", false, err.Error())
		return
	}
	log.Printf("[Client] Installing JS plugin: %s", req.PluginName)

	// Verify the signature before making any change.
	if err := c.verifyJSPluginSignature(req.PluginName, req.Source, req.Signature); err != nil {
		log.Printf("[Client] JS plugin %s signature verification failed: %v", req.PluginName, err)
		c.sendJSPluginResult(stream, req.PluginName, false, "signature verification failed: "+err.Error())
		return
	}
	log.Printf("[Client] JS plugin %s signature verified", req.PluginName)

	// Build the new plugin while the old one (if any) is still running.
	jsPlugin, err := script.NewJSPlugin(req.PluginName, req.Source)
	if err != nil {
		c.sendJSPluginResult(stream, req.PluginName, false, err.Error())
		return
	}

	// The replacement is known to be valid; stop the existing instance.
	key := req.PluginName + ":" + req.RuleName
	c.pluginMu.Lock()
	if existingHandler, ok := c.runningPlugins[key]; ok {
		log.Printf("[Client] Stopping existing plugin %s before reinstall", key)
		if err := existingHandler.Stop(); err != nil {
			log.Printf("[Client] Stop existing plugin error: %v", err)
		}
		delete(c.runningPlugins, key)
	}
	c.pluginMu.Unlock()

	// Register with the registry.
	if c.pluginRegistry != nil {
		c.pluginRegistry.RegisterClient(jsPlugin)
	}
	log.Printf("[Client] JS plugin %s installed", req.PluginName)
	c.sendJSPluginResult(stream, req.PluginName, true, "")

	// Record the installed version (guards against downgrade attacks). The
	// signature already decoded successfully during verification, so a decode
	// failure here is unexpected and only skips the bookkeeping.
	if c.versionStore != nil {
		if signed, err := sign.DecodeSignedPlugin(req.Signature); err == nil && signed != nil {
			c.versionStore.SetVersion(req.PluginName, signed.Payload.Version)
		}
	}

	// Start automatically when requested.
	if req.AutoStart {
		c.startJSPlugin(jsPlugin, req)
	}
}
// sendJSPluginResult writes a MsgTypeJSPluginResult message with the install
// outcome to the stream. Build and write errors are not checked.
func (c *Client) sendJSPluginResult(stream net.Conn, name string, success bool, errMsg string) {
	resultMsg, _ := protocol.NewMessage(protocol.MsgTypeJSPluginResult, protocol.JSPluginInstallResult{
		PluginName: name,
		Success:    success,
		Error:      errMsg,
	})
	protocol.WriteMessage(stream, resultMsg)
}
// startJSPlugin initialises and starts a freshly installed JS plugin and
// records it in runningPlugins under "PluginName:RuleName". Init and Start
// failures are logged and leave runningPlugins unchanged.
func (c *Client) startJSPlugin(handler plugin.ClientPlugin, req protocol.JSPluginInstallRequest) {
	name := req.PluginName

	if initErr := handler.Init(req.Config); initErr != nil {
		log.Printf("[Client] JS plugin %s init error: %v", name, initErr)
		return
	}

	addr, startErr := handler.Start()
	if startErr != nil {
		log.Printf("[Client] JS plugin %s start error: %v", name, startErr)
		return
	}

	c.pluginMu.Lock()
	c.runningPlugins[name+":"+req.RuleName] = handler
	c.pluginMu.Unlock()

	log.Printf("[Client] JS plugin %s started at %s", name, addr)
}
// verifyJSPluginSignature checks that a JS plugin may be installed.
//
// In order, it requires: a non-empty signature that decodes; a public key
// known for the payload's KeyID; a payload name equal to pluginName; a
// signature that sign.VerifyPlugin accepts for source; and, when a version
// store is set and already holds a version for the plugin, a signed version
// that is not lower than the stored one. The first failed check is returned
// as an error.
func (c *Client) verifyJSPluginSignature(pluginName, source, signature string) error {
	if signature == "" {
		return fmt.Errorf("missing signature")
	}

	// Decode the signature envelope.
	decoded, err := sign.DecodeSignedPlugin(signature)
	if err != nil {
		return fmt.Errorf("decode signature: %w", err)
	}

	// Select the public key by KeyID.
	key, err := sign.GetPublicKeyByID(decoded.Payload.KeyID)
	if err != nil {
		return fmt.Errorf("get public key: %w", err)
	}

	// The signed name must match the plugin being installed.
	if got := decoded.Payload.Name; got != pluginName {
		return fmt.Errorf("plugin name mismatch: expected %s, got %s", pluginName, got)
	}

	// Verify the signature and source hash.
	if verifyErr := sign.VerifyPlugin(key, decoded, source); verifyErr != nil {
		return verifyErr
	}

	// Downgrade protection only applies when a prior version is recorded.
	if c.versionStore == nil {
		return nil
	}
	installed := c.versionStore.GetVersion(pluginName)
	if installed == "" {
		return nil
	}
	if sign.CompareVersions(decoded.Payload.Version, installed) < 0 {
		return fmt.Errorf("version downgrade rejected: %s < %s", decoded.Payload.Version, installed)
	}
	return nil
}
// handleClientPluginStop stops the running plugin identified by
// "PluginName:RuleName", removes it from runningPlugins, and reports the
// plugin as not running. A Stop error is logged but the plugin is removed
// regardless. The "stopped" log line and status are emitted even when no such
// plugin was running. The stream is closed on return.
func (c *Client) handleClientPluginStop(stream net.Conn, msg *protocol.Message) {
	defer stream.Close()

	var stopReq protocol.ClientPluginStopRequest
	if parseErr := msg.ParsePayload(&stopReq); parseErr != nil {
		// NOTE(review): a malformed request is reported with Running=true;
		// confirm that is the intended status.
		c.sendPluginStatus(stream, stopReq.PluginName, stopReq.RuleName, true, "", parseErr.Error())
		return
	}

	pluginKey := stopReq.PluginName + ":" + stopReq.RuleName

	c.pluginMu.Lock()
	if running, found := c.runningPlugins[pluginKey]; found {
		if stopErr := running.Stop(); stopErr != nil {
			log.Printf("[Client] Plugin %s stop error: %v", pluginKey, stopErr)
		}
		delete(c.runningPlugins, pluginKey)
	}
	c.pluginMu.Unlock()

	log.Printf("[Client] Plugin %s stopped", pluginKey)
	c.sendPluginStatus(stream, stopReq.PluginName, stopReq.RuleName, false, "", "")
}
// handleClientRestart handles a restart request from the server: it
// acknowledges the request, stops every running plugin, and closes the
// current session so that Run reconnects.
//
// The request payload only carries an informational reason, so a parse
// failure is logged and the restart proceeds. Failures to build or send the
// acknowledgement are logged as well. The session is read under c.mu,
// matching how connect writes it. The stream is closed on return.
func (c *Client) handleClientRestart(stream net.Conn, msg *protocol.Message) {
	defer stream.Close()

	var req protocol.ClientRestartRequest
	if err := msg.ParsePayload(&req); err != nil {
		log.Printf("[Client] Parse restart request error: %v", err)
	}
	log.Printf("[Client] Restart requested: %s", req.Reason)

	// Acknowledge before tearing anything down.
	resp := protocol.ClientRestartResponse{
		Success: true,
		Message: "restarting",
	}
	if respMsg, err := protocol.NewMessage(protocol.MsgTypeClientRestart, resp); err != nil {
		log.Printf("[Client] Build restart response error: %v", err)
	} else if err := protocol.WriteMessage(stream, respMsg); err != nil {
		log.Printf("[Client] Send restart response error: %v", err)
	}

	// Stop all running plugins.
	c.pluginMu.Lock()
	for key, handler := range c.runningPlugins {
		log.Printf("[Client] Stopping plugin %s for restart", key)
		handler.Stop()
	}
	c.runningPlugins = make(map[string]plugin.ClientPlugin)
	c.pluginMu.Unlock()

	// Close the session (this triggers a reconnect in Run).
	c.mu.RLock()
	session := c.session
	c.mu.RUnlock()
	if session != nil {
		session.Close()
	}
}
// handlePluginConfigUpdate handles a config update for a running plugin.
//
// The plugin must be present in runningPlugins, otherwise the update fails
// with "plugin not running". When the request asks for a restart the plugin
// is stopped, removed from runningPlugins, re-initialised with the new config,
// started, and recorded again; if Init or Start fails it stays removed. The
// result is sent back on the stream, which is closed on return.
func (c *Client) handlePluginConfigUpdate(stream net.Conn, msg *protocol.Message) {
	defer stream.Close()

	var updReq protocol.PluginConfigUpdateRequest
	// reply sends the outcome using the names carried by the request.
	reply := func(ok bool, reason string) {
		c.sendPluginConfigUpdateResult(stream, updReq.PluginName, updReq.RuleName, ok, reason)
	}

	if parseErr := msg.ParsePayload(&updReq); parseErr != nil {
		reply(false, parseErr.Error())
		return
	}

	pluginKey := updReq.PluginName + ":" + updReq.RuleName
	log.Printf("[Client] Config update for plugin %s", pluginKey)

	c.pluginMu.RLock()
	running, found := c.runningPlugins[pluginKey]
	c.pluginMu.RUnlock()
	if !found {
		reply(false, "plugin not running")
		return
	}

	// NOTE(review): without Restart the new config is not passed to the
	// plugin here and success is still reported — confirm this is intended.
	if !updReq.Restart {
		reply(true, "")
		return
	}

	// Stop the plugin and drop it from the running set.
	c.pluginMu.Lock()
	if stopErr := running.Stop(); stopErr != nil {
		log.Printf("[Client] Plugin %s stop error: %v", pluginKey, stopErr)
	}
	delete(c.runningPlugins, pluginKey)
	c.pluginMu.Unlock()

	// Re-initialise with the new config and start again.
	if initErr := running.Init(updReq.Config); initErr != nil {
		reply(false, initErr.Error())
		return
	}
	addr, startErr := running.Start()
	if startErr != nil {
		reply(false, startErr.Error())
		return
	}

	c.pluginMu.Lock()
	c.runningPlugins[pluginKey] = running
	c.pluginMu.Unlock()
	log.Printf("[Client] Plugin %s restarted at %s with new config", pluginKey, addr)

	reply(true, "")
}
// sendPluginConfigUpdateResult writes a MsgTypePluginConfigUpdate message with
// the outcome of a config update to the stream. Build and write errors are
// not checked.
func (c *Client) sendPluginConfigUpdateResult(stream net.Conn, pluginName, ruleName string, success bool, errMsg string) {
	resultMsg, _ := protocol.NewMessage(protocol.MsgTypePluginConfigUpdate, protocol.PluginConfigUpdateResponse{
		PluginName: pluginName,
		RuleName:   ruleName,
		Success:    success,
		Error:      errMsg,
	})
	protocol.WriteMessage(stream, resultMsg)
}
// handleUpdateDownload handles a self-update request: it launches
// performSelfUpdate for the requested URL in a background goroutine and
// immediately reports "update started". A failed update is only logged. A
// request that cannot be parsed is answered with "invalid request". The
// stream is closed on return.
func (c *Client) handleUpdateDownload(stream net.Conn, msg *protocol.Message) {
	defer stream.Close()

	var dlReq protocol.UpdateDownloadRequest
	if parseErr := msg.ParsePayload(&dlReq); parseErr != nil {
		log.Printf("[Client] Parse update request error: %v", parseErr)
		c.sendUpdateResult(stream, false, "invalid request")
		return
	}
	log.Printf("[Client] Update download requested: %s", dlReq.DownloadURL)

	// Run the update asynchronously so the response is not held up.
	go func(src string) {
		if updateErr := c.performSelfUpdate(src); updateErr != nil {
			log.Printf("[Client] Update failed: %v", updateErr)
		}
	}(dlReq.DownloadURL)

	c.sendUpdateResult(stream, true, "update started")
}
// sendUpdateResult writes a MsgTypeUpdateResult message with the given outcome
// to the stream. Build and write errors are not checked.
func (c *Client) sendUpdateResult(stream net.Conn, success bool, message string) {
	resultMsg, _ := protocol.NewMessage(protocol.MsgTypeUpdateResult, protocol.UpdateResultResponse{
		Success: success,
		Message: message,
	})
	protocol.WriteMessage(stream, resultMsg)
}
// performSelfUpdate downloads a new client binary from downloadURL and
// replaces the running executable with it.
//
// On Windows the replacement is delegated to performWindowsClientUpdate. On
// other systems all plugins are stopped, the current executable is renamed to
// "<path>.bak", the new binary is copied into place and made executable, the
// backup is removed, and the process is restarted via restartClientProcess.
// If copying or chmod fails the backup is renamed back into place.
//
// It returns an error when the download, executable lookup, backup, copy or
// chmod step fails.
func (c *Client) performSelfUpdate(downloadURL string) error {
	log.Printf("[Client] Starting self-update from: %s", downloadURL)

	// Shared download-and-extract logic.
	binaryPath, cleanup, err := update.DownloadAndExtract(downloadURL, "client")
	if err != nil {
		return err
	}
	defer cleanup()

	// Locate the running executable.
	currentPath, err := os.Executable()
	if err != nil {
		return fmt.Errorf("get executable: %w", err)
	}
	// EvalSymlinks returns "" together with an error; keep the unresolved
	// path in that case instead of overwriting it with an empty string.
	if resolved, err := filepath.EvalSymlinks(currentPath); err != nil {
		log.Printf("[Client] Resolve executable symlinks: %v (using %s)", err, currentPath)
	} else {
		currentPath = resolved
	}

	// Windows needs special handling.
	if runtime.GOOS == "windows" {
		return performWindowsClientUpdate(binaryPath, currentPath, c.ServerAddr, c.Token, c.ID)
	}

	// Linux/macOS: replace the file directly.
	backupPath := currentPath + ".bak"
	// rollback restores the backed-up executable after a failed replacement.
	rollback := func() {
		if err := os.Rename(backupPath, currentPath); err != nil {
			log.Printf("[Client] Restore backup failed: %v", err)
		}
	}

	// Stop all plugins.
	c.stopAllPlugins()

	// Back up the current file.
	if err := os.Rename(currentPath, backupPath); err != nil {
		return fmt.Errorf("backup current: %w", err)
	}
	// Copy rather than rename: the new file may be on another filesystem.
	if err := update.CopyFile(binaryPath, currentPath); err != nil {
		rollback()
		return fmt.Errorf("replace binary: %w", err)
	}
	// Make the new binary executable.
	if err := os.Chmod(currentPath, 0755); err != nil {
		rollback()
		return fmt.Errorf("chmod: %w", err)
	}

	// Best effort: a leftover backup is harmless.
	os.Remove(backupPath)
	log.Printf("[Client] Update completed, restarting...")

	// Restart the process.
	restartClientProcess(currentPath, c.ServerAddr, c.Token, c.ID)
	return nil
}
// stopAllPlugins calls Stop on every running plugin and resets runningPlugins
// to an empty map. Errors returned by Stop are ignored.
func (c *Client) stopAllPlugins() {
	c.pluginMu.Lock()
	defer c.pluginMu.Unlock()

	for name := range c.runningPlugins {
		log.Printf("[Client] Stopping plugin %s for update", name)
		c.runningPlugins[name].Stop()
	}
	c.runningPlugins = make(map[string]plugin.ClientPlugin)
}
// performWindowsClientUpdate replaces the running executable on Windows.
//
// A running executable cannot be overwritten there, so the work is delegated
// to a batch script written to the temp directory: it waits briefly, deletes
// the current executable, moves newFile into its place, starts the new binary
// with the same -s/-t/-id arguments, and deletes itself. After the script has
// been launched this process exits.
//
// Every value is embedded inside double quotes in the script, so a value
// containing a double quote or a line break is rejected with an error before
// anything is written; '%' is doubled so cmd does not treat it as a variable
// reference.
func performWindowsClientUpdate(newFile, currentPath, serverAddr, token, id string) error {
	for _, v := range []string{newFile, currentPath, serverAddr, token, id} {
		if strings.ContainsAny(v, "\"\r\n") {
			// The value is not echoed: it may be the token.
			return fmt.Errorf("write batch: argument contains a double quote or line break")
		}
	}
	// esc makes a value literal inside a batch file.
	esc := func(s string) string { return strings.ReplaceAll(s, "%", "%%") }

	// Build the batch script.
	args := fmt.Sprintf(`-s "%s" -t "%s"`, esc(serverAddr), esc(token))
	if id != "" {
		args += fmt.Sprintf(` -id "%s"`, esc(id))
	}
	batchScript := fmt.Sprintf(`@echo off
ping 127.0.0.1 -n 2 > nul
del "%s"
move "%s" "%s"
start "" "%s" %s
del "%%~f0"
`, esc(currentPath), esc(newFile), esc(currentPath), esc(currentPath), args)

	batchPath := filepath.Join(os.TempDir(), "gotunnel_client_update.bat")
	if err := os.WriteFile(batchPath, []byte(batchScript), 0755); err != nil {
		return fmt.Errorf("write batch: %w", err)
	}

	// `start` treats its first quoted argument as the window title, so an
	// explicit empty title is passed; otherwise a temp path containing spaces
	// (which gets quoted) would be taken as the title and never executed.
	cmd := exec.Command("cmd", "/C", "start", "", "/MIN", batchPath)
	if err := cmd.Start(); err != nil {
		return fmt.Errorf("start batch: %w", err)
	}

	// Exit so the script can replace the executable.
	os.Exit(0)
	return nil
}
// restartClientProcess launches the executable at path with the client's
// -s/-t (and, when id is non-empty, -id) arguments, wiring its stdout and
// stderr to this process's, and then exits the current process.
//
// If the new process cannot be started, the failure is logged and the
// function returns without exiting, so the current process keeps running
// instead of leaving no client at all.
func restartClientProcess(path, serverAddr, token, id string) {
	args := []string{"-s", serverAddr, "-t", token}
	if id != "" {
		args = append(args, "-id", id)
	}

	cmd := exec.Command(path, args...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Start(); err != nil {
		log.Printf("[Client] Restart failed, keeping current process: %v", err)
		return
	}
	os.Exit(0)
}