All checks were successful
Build Multi-Platform Binaries / build-frontend (push) Successful in 29s
Build Multi-Platform Binaries / build-binaries (amd64, darwin, server, false) (push) Successful in 1m3s
Build Multi-Platform Binaries / build-binaries (amd64, linux, client, true) (push) Successful in 49s
Build Multi-Platform Binaries / build-binaries (amd64, linux, server, true) (push) Successful in 1m30s
Build Multi-Platform Binaries / build-binaries (amd64, windows, client, true) (push) Successful in 46s
Build Multi-Platform Binaries / build-binaries (amd64, windows, server, true) (push) Successful in 1m29s
Build Multi-Platform Binaries / build-binaries (arm, 7, linux, client, true) (push) Successful in 51s
Build Multi-Platform Binaries / build-binaries (arm, 7, linux, server, true) (push) Successful in 1m44s
Build Multi-Platform Binaries / build-binaries (arm64, darwin, server, false) (push) Successful in 1m5s
Build Multi-Platform Binaries / build-binaries (arm64, linux, client, true) (push) Successful in 1m15s
Build Multi-Platform Binaries / build-binaries (arm64, linux, server, true) (push) Successful in 1m35s
Build Multi-Platform Binaries / build-binaries (arm64, windows, server, false) (push) Successful in 1m3s
937 lines
24 KiB
Go
937 lines
24 KiB
Go
package tunnel
|
||
|
||
import (
	"crypto/tls"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strings"
	"sync"
	"time"

	"github.com/gotunnel/pkg/plugin"
	"github.com/gotunnel/pkg/plugin/script"
	"github.com/gotunnel/pkg/plugin/sign"
	"github.com/gotunnel/pkg/protocol"
	"github.com/gotunnel/pkg/relay"
	"github.com/hashicorp/yamux"
)
|
||
|
||
// Client-side tuning constants.
const (
	// Network timeouts.
	dialTimeout      = 10 * time.Second // dialing the tunnel server and proxy targets
	localDialTimeout = 5 * time.Second  // dialing local services named by proxy rules
	udpTimeout       = 10 * time.Second // deadline for one UDP request/response exchange

	// Reconnect pacing used by Run.
	reconnectDelay  = 5 * time.Second // wait after a failed connect
	disconnectDelay = 3 * time.Second // wait after an established session ends

	// udpBufferSize is the read buffer for UDP replies; large enough for any UDP datagram.
	udpBufferSize = 65535

	// idFileName is the file that persists the client ID (see getIDFilePath).
	idFileName = ".gotunnel_id"
)
|
||
|
||
// Client 隧道客户端
|
||
type Client struct {
|
||
ServerAddr string
|
||
Token string
|
||
ID string
|
||
TLSEnabled bool
|
||
TLSConfig *tls.Config
|
||
DataDir string // 数据目录
|
||
session *yamux.Session
|
||
rules []protocol.ProxyRule
|
||
mu sync.RWMutex
|
||
pluginRegistry *plugin.Registry
|
||
runningPlugins map[string]plugin.ClientPlugin
|
||
versionStore *PluginVersionStore
|
||
pluginMu sync.RWMutex
|
||
}
|
||
|
||
// NewClient 创建客户端
|
||
func NewClient(serverAddr, token, id string) *Client {
|
||
if id == "" {
|
||
id = loadClientID()
|
||
}
|
||
|
||
// 默认数据目录
|
||
home, _ := os.UserHomeDir()
|
||
dataDir := filepath.Join(home, ".gotunnel")
|
||
|
||
return &Client{
|
||
ServerAddr: serverAddr,
|
||
Token: token,
|
||
ID: id,
|
||
DataDir: dataDir,
|
||
runningPlugins: make(map[string]plugin.ClientPlugin),
|
||
}
|
||
}
|
||
|
||
// InitVersionStore 初始化版本存储
|
||
func (c *Client) InitVersionStore() error {
|
||
store, err := NewPluginVersionStore(c.DataDir)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
c.versionStore = store
|
||
return nil
|
||
}
|
||
|
||
// getIDFilePath 获取 ID 文件路径
|
||
func getIDFilePath() string {
|
||
home, err := os.UserHomeDir()
|
||
if err != nil {
|
||
return idFileName
|
||
}
|
||
return filepath.Join(home, idFileName)
|
||
}
|
||
|
||
// loadClientID 从本地文件加载客户端 ID
|
||
func loadClientID() string {
|
||
data, err := os.ReadFile(getIDFilePath())
|
||
if err != nil {
|
||
return ""
|
||
}
|
||
return string(data)
|
||
}
|
||
|
||
// saveClientID 保存客户端 ID 到本地文件
|
||
func saveClientID(id string) {
|
||
if err := os.WriteFile(getIDFilePath(), []byte(id), 0600); err != nil {
|
||
log.Printf("[Client] Failed to save client ID: %v", err)
|
||
}
|
||
}
|
||
|
||
// SetPluginRegistry 设置插件注册表
|
||
func (c *Client) SetPluginRegistry(registry *plugin.Registry) {
|
||
c.pluginRegistry = registry
|
||
}
|
||
|
||
// Run 启动客户端(带断线重连)
|
||
func (c *Client) Run() error {
|
||
for {
|
||
if err := c.connect(); err != nil {
|
||
log.Printf("[Client] Connect error: %v", err)
|
||
log.Printf("[Client] Reconnecting in %v...", reconnectDelay)
|
||
time.Sleep(reconnectDelay)
|
||
continue
|
||
}
|
||
|
||
c.handleSession()
|
||
log.Printf("[Client] Disconnected, reconnecting...")
|
||
time.Sleep(disconnectDelay)
|
||
}
|
||
}
|
||
|
||
// connect 连接到服务端并认证
|
||
func (c *Client) connect() error {
|
||
var conn net.Conn
|
||
var err error
|
||
|
||
if c.TLSEnabled && c.TLSConfig != nil {
|
||
dialer := &net.Dialer{Timeout: dialTimeout}
|
||
conn, err = tls.DialWithDialer(dialer, "tcp", c.ServerAddr, c.TLSConfig)
|
||
} else {
|
||
conn, err = net.DialTimeout("tcp", c.ServerAddr, dialTimeout)
|
||
}
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
authReq := protocol.AuthRequest{ClientID: c.ID, Token: c.Token}
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypeAuth, authReq)
|
||
if err := protocol.WriteMessage(conn, msg); err != nil {
|
||
conn.Close()
|
||
return err
|
||
}
|
||
|
||
resp, err := protocol.ReadMessage(conn)
|
||
if err != nil {
|
||
conn.Close()
|
||
return err
|
||
}
|
||
|
||
var authResp protocol.AuthResponse
|
||
if err := resp.ParsePayload(&authResp); err != nil {
|
||
conn.Close()
|
||
return fmt.Errorf("parse auth response: %w", err)
|
||
}
|
||
if !authResp.Success {
|
||
conn.Close()
|
||
return fmt.Errorf("auth failed: %s", authResp.Message)
|
||
}
|
||
|
||
// 如果服务端分配了新 ID,则更新并保存
|
||
if authResp.ClientID != "" && authResp.ClientID != c.ID {
|
||
c.ID = authResp.ClientID
|
||
saveClientID(c.ID)
|
||
log.Printf("[Client] New ID assigned and saved: %s", c.ID)
|
||
}
|
||
|
||
log.Printf("[Client] Authenticated as %s", c.ID)
|
||
|
||
session, err := yamux.Client(conn, nil)
|
||
if err != nil {
|
||
conn.Close()
|
||
return err
|
||
}
|
||
|
||
c.mu.Lock()
|
||
c.session = session
|
||
c.mu.Unlock()
|
||
|
||
return nil
|
||
}
|
||
|
||
// handleSession 处理会话
|
||
func (c *Client) handleSession() {
|
||
defer c.session.Close()
|
||
|
||
for {
|
||
stream, err := c.session.Accept()
|
||
if err != nil {
|
||
return
|
||
}
|
||
go c.handleStream(stream)
|
||
}
|
||
}
|
||
|
||
// handleStream 处理流
|
||
func (c *Client) handleStream(stream net.Conn) {
|
||
msg, err := protocol.ReadMessage(stream)
|
||
if err != nil {
|
||
stream.Close()
|
||
return
|
||
}
|
||
|
||
switch msg.Type {
|
||
case protocol.MsgTypeProxyConfig:
|
||
defer stream.Close()
|
||
c.handleProxyConfig(msg)
|
||
case protocol.MsgTypeNewProxy:
|
||
defer stream.Close()
|
||
c.handleNewProxy(stream, msg)
|
||
case protocol.MsgTypeHeartbeat:
|
||
defer stream.Close()
|
||
c.handleHeartbeat(stream)
|
||
case protocol.MsgTypeProxyConnect:
|
||
c.handleProxyConnect(stream, msg)
|
||
case protocol.MsgTypeUDPData:
|
||
c.handleUDPData(stream, msg)
|
||
case protocol.MsgTypePluginConfig:
|
||
defer stream.Close()
|
||
c.handlePluginConfig(msg)
|
||
case protocol.MsgTypeClientPluginStart:
|
||
c.handleClientPluginStart(stream, msg)
|
||
case protocol.MsgTypeClientPluginStop:
|
||
c.handleClientPluginStop(stream, msg)
|
||
case protocol.MsgTypeClientPluginConn:
|
||
c.handleClientPluginConn(stream, msg)
|
||
case protocol.MsgTypeJSPluginInstall:
|
||
c.handleJSPluginInstall(stream, msg)
|
||
case protocol.MsgTypeClientRestart:
|
||
c.handleClientRestart(stream, msg)
|
||
case protocol.MsgTypePluginConfigUpdate:
|
||
c.handlePluginConfigUpdate(stream, msg)
|
||
case protocol.MsgTypeUpdateDownload:
|
||
c.handleUpdateDownload(stream, msg)
|
||
}
|
||
}
|
||
|
||
// handleProxyConfig 处理代理配置
|
||
func (c *Client) handleProxyConfig(msg *protocol.Message) {
|
||
var cfg protocol.ProxyConfig
|
||
if err := msg.ParsePayload(&cfg); err != nil {
|
||
log.Printf("[Client] Parse proxy config error: %v", err)
|
||
return
|
||
}
|
||
|
||
c.mu.Lock()
|
||
c.rules = cfg.Rules
|
||
c.mu.Unlock()
|
||
|
||
log.Printf("[Client] Received %d proxy rules", len(cfg.Rules))
|
||
for _, r := range cfg.Rules {
|
||
log.Printf("[Client] %s: %s:%d", r.Name, r.LocalIP, r.LocalPort)
|
||
}
|
||
}
|
||
|
||
// handleNewProxy 处理新代理请求
|
||
func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) {
|
||
var req protocol.NewProxyRequest
|
||
if err := msg.ParsePayload(&req); err != nil {
|
||
log.Printf("[Client] Parse new proxy request error: %v", err)
|
||
return
|
||
}
|
||
|
||
var rule *protocol.ProxyRule
|
||
c.mu.RLock()
|
||
for _, r := range c.rules {
|
||
if r.RemotePort == req.RemotePort {
|
||
rule = &r
|
||
break
|
||
}
|
||
}
|
||
c.mu.RUnlock()
|
||
|
||
if rule == nil {
|
||
log.Printf("[Client] Unknown port %d", req.RemotePort)
|
||
return
|
||
}
|
||
|
||
localAddr := fmt.Sprintf("%s:%d", rule.LocalIP, rule.LocalPort)
|
||
localConn, err := net.DialTimeout("tcp", localAddr, localDialTimeout)
|
||
if err != nil {
|
||
log.Printf("[Client] Connect %s error: %v", localAddr, err)
|
||
return
|
||
}
|
||
|
||
relay.Relay(stream, localConn)
|
||
}
|
||
|
||
// handleHeartbeat 处理心跳
|
||
func (c *Client) handleHeartbeat(stream net.Conn) {
|
||
msg := &protocol.Message{Type: protocol.MsgTypeHeartbeatAck}
|
||
protocol.WriteMessage(stream, msg)
|
||
}
|
||
|
||
// handleProxyConnect 处理代理连接请求 (SOCKS5/HTTP)
|
||
func (c *Client) handleProxyConnect(stream net.Conn, msg *protocol.Message) {
|
||
defer stream.Close()
|
||
|
||
var req protocol.ProxyConnectRequest
|
||
if err := msg.ParsePayload(&req); err != nil {
|
||
c.sendProxyResult(stream, false, "invalid request")
|
||
return
|
||
}
|
||
|
||
// 连接目标地址
|
||
targetConn, err := net.DialTimeout("tcp", req.Target, dialTimeout)
|
||
if err != nil {
|
||
c.sendProxyResult(stream, false, err.Error())
|
||
return
|
||
}
|
||
defer targetConn.Close()
|
||
|
||
// 发送成功响应
|
||
if err := c.sendProxyResult(stream, true, ""); err != nil {
|
||
return
|
||
}
|
||
|
||
// 双向转发数据
|
||
relay.Relay(stream, targetConn)
|
||
}
|
||
|
||
// sendProxyResult 发送代理连接结果
|
||
func (c *Client) sendProxyResult(stream net.Conn, success bool, message string) error {
|
||
result := protocol.ProxyConnectResult{Success: success, Message: message}
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypeProxyResult, result)
|
||
return protocol.WriteMessage(stream, msg)
|
||
}
|
||
|
||
// handleUDPData 处理 UDP 数据
|
||
func (c *Client) handleUDPData(stream net.Conn, msg *protocol.Message) {
|
||
defer stream.Close()
|
||
|
||
var packet protocol.UDPPacket
|
||
if err := msg.ParsePayload(&packet); err != nil {
|
||
return
|
||
}
|
||
|
||
// 查找对应的规则
|
||
rule := c.findRuleByPort(packet.RemotePort)
|
||
if rule == nil {
|
||
return
|
||
}
|
||
|
||
// 连接本地 UDP 服务
|
||
target := fmt.Sprintf("%s:%d", rule.LocalIP, rule.LocalPort)
|
||
conn, err := net.DialTimeout("udp", target, localDialTimeout)
|
||
if err != nil {
|
||
return
|
||
}
|
||
defer conn.Close()
|
||
|
||
// 发送数据到本地服务
|
||
conn.SetDeadline(time.Now().Add(udpTimeout))
|
||
if _, err := conn.Write(packet.Data); err != nil {
|
||
return
|
||
}
|
||
|
||
// 读取响应
|
||
buf := make([]byte, udpBufferSize)
|
||
n, err := conn.Read(buf)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 发送响应回服务端
|
||
respPacket := protocol.UDPPacket{
|
||
RemotePort: packet.RemotePort,
|
||
ClientAddr: packet.ClientAddr,
|
||
Data: buf[:n],
|
||
}
|
||
respMsg, _ := protocol.NewMessage(protocol.MsgTypeUDPData, respPacket)
|
||
protocol.WriteMessage(stream, respMsg)
|
||
}
|
||
|
||
// findRuleByPort 根据端口查找规则
|
||
func (c *Client) findRuleByPort(port int) *protocol.ProxyRule {
|
||
c.mu.RLock()
|
||
defer c.mu.RUnlock()
|
||
|
||
for i := range c.rules {
|
||
if c.rules[i].RemotePort == port {
|
||
return &c.rules[i]
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// handlePluginConfig 处理插件配置同步
|
||
func (c *Client) handlePluginConfig(msg *protocol.Message) {
|
||
var cfg protocol.PluginConfigSync
|
||
if err := msg.ParsePayload(&cfg); err != nil {
|
||
log.Printf("[Client] Parse plugin config error: %v", err)
|
||
return
|
||
}
|
||
|
||
log.Printf("[Client] Received config for plugin: %s", cfg.PluginName)
|
||
|
||
// 应用配置到插件
|
||
if c.pluginRegistry != nil {
|
||
handler, err := c.pluginRegistry.GetClient(cfg.PluginName)
|
||
if err != nil {
|
||
log.Printf("[Client] Plugin %s not found: %v", cfg.PluginName, err)
|
||
return
|
||
}
|
||
if err := handler.Init(cfg.Config); err != nil {
|
||
log.Printf("[Client] Plugin %s init error: %v", cfg.PluginName, err)
|
||
return
|
||
}
|
||
log.Printf("[Client] Plugin %s config applied", cfg.PluginName)
|
||
}
|
||
}
|
||
|
||
// handleClientPluginStart 处理客户端插件启动请求
|
||
func (c *Client) handleClientPluginStart(stream net.Conn, msg *protocol.Message) {
|
||
defer stream.Close()
|
||
|
||
var req protocol.ClientPluginStartRequest
|
||
if err := msg.ParsePayload(&req); err != nil {
|
||
c.sendPluginStatus(stream, req.PluginName, req.RuleName, false, "", err.Error())
|
||
return
|
||
}
|
||
|
||
log.Printf("[Client] Starting plugin %s for rule %s", req.PluginName, req.RuleName)
|
||
|
||
// 获取插件
|
||
if c.pluginRegistry == nil {
|
||
c.sendPluginStatus(stream, req.PluginName, req.RuleName, false, "", "plugin registry not set")
|
||
return
|
||
}
|
||
|
||
handler, err := c.pluginRegistry.GetClient(req.PluginName)
|
||
if err != nil {
|
||
c.sendPluginStatus(stream, req.PluginName, req.RuleName, false, "", err.Error())
|
||
return
|
||
}
|
||
|
||
// 初始化并启动
|
||
if err := handler.Init(req.Config); err != nil {
|
||
c.sendPluginStatus(stream, req.PluginName, req.RuleName, false, "", err.Error())
|
||
return
|
||
}
|
||
|
||
localAddr, err := handler.Start()
|
||
if err != nil {
|
||
c.sendPluginStatus(stream, req.PluginName, req.RuleName, false, "", err.Error())
|
||
return
|
||
}
|
||
|
||
// 保存运行中的插件
|
||
key := req.PluginName + ":" + req.RuleName
|
||
c.pluginMu.Lock()
|
||
c.runningPlugins[key] = handler
|
||
c.pluginMu.Unlock()
|
||
|
||
log.Printf("[Client] Plugin %s started at %s", req.PluginName, localAddr)
|
||
c.sendPluginStatus(stream, req.PluginName, req.RuleName, true, localAddr, "")
|
||
}
|
||
|
||
// sendPluginStatus 发送插件状态响应
|
||
func (c *Client) sendPluginStatus(stream net.Conn, pluginName, ruleName string, running bool, localAddr, errMsg string) {
|
||
resp := protocol.ClientPluginStatusResponse{
|
||
PluginName: pluginName,
|
||
RuleName: ruleName,
|
||
Running: running,
|
||
LocalAddr: localAddr,
|
||
Error: errMsg,
|
||
}
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypeClientPluginStatus, resp)
|
||
protocol.WriteMessage(stream, msg)
|
||
}
|
||
|
||
// handleClientPluginConn 处理客户端插件连接
|
||
func (c *Client) handleClientPluginConn(stream net.Conn, msg *protocol.Message) {
|
||
var req protocol.ClientPluginConnRequest
|
||
if err := msg.ParsePayload(&req); err != nil {
|
||
stream.Close()
|
||
return
|
||
}
|
||
|
||
key := req.PluginName + ":" + req.RuleName
|
||
c.pluginMu.RLock()
|
||
handler, ok := c.runningPlugins[key]
|
||
c.pluginMu.RUnlock()
|
||
|
||
if !ok {
|
||
log.Printf("[Client] Plugin %s not running", key)
|
||
stream.Close()
|
||
return
|
||
}
|
||
|
||
// 让插件处理连接
|
||
handler.HandleConn(stream)
|
||
}
|
||
|
||
// handleJSPluginInstall 处理 JS 插件安装请求
|
||
func (c *Client) handleJSPluginInstall(stream net.Conn, msg *protocol.Message) {
|
||
defer stream.Close()
|
||
|
||
var req protocol.JSPluginInstallRequest
|
||
if err := msg.ParsePayload(&req); err != nil {
|
||
c.sendJSPluginResult(stream, "", false, err.Error())
|
||
return
|
||
}
|
||
|
||
log.Printf("[Client] Installing JS plugin: %s", req.PluginName)
|
||
|
||
// 如果插件已经在运行,先停止它
|
||
key := req.PluginName + ":" + req.RuleName
|
||
c.pluginMu.Lock()
|
||
if existingHandler, ok := c.runningPlugins[key]; ok {
|
||
log.Printf("[Client] Stopping existing plugin %s before reinstall", key)
|
||
if err := existingHandler.Stop(); err != nil {
|
||
log.Printf("[Client] Stop existing plugin error: %v", err)
|
||
}
|
||
delete(c.runningPlugins, key)
|
||
}
|
||
c.pluginMu.Unlock()
|
||
|
||
// 验证官方签名
|
||
if err := c.verifyJSPluginSignature(req.PluginName, req.Source, req.Signature); err != nil {
|
||
log.Printf("[Client] JS plugin %s signature verification failed: %v", req.PluginName, err)
|
||
c.sendJSPluginResult(stream, req.PluginName, false, "signature verification failed: "+err.Error())
|
||
return
|
||
}
|
||
log.Printf("[Client] JS plugin %s signature verified", req.PluginName)
|
||
|
||
// 创建 JS 插件
|
||
jsPlugin, err := script.NewJSPlugin(req.PluginName, req.Source)
|
||
if err != nil {
|
||
c.sendJSPluginResult(stream, req.PluginName, false, err.Error())
|
||
return
|
||
}
|
||
|
||
// 注册到 registry
|
||
if c.pluginRegistry != nil {
|
||
c.pluginRegistry.RegisterClient(jsPlugin)
|
||
}
|
||
|
||
log.Printf("[Client] JS plugin %s installed", req.PluginName)
|
||
c.sendJSPluginResult(stream, req.PluginName, true, "")
|
||
|
||
// 保存版本信息(防止降级攻击)
|
||
if c.versionStore != nil {
|
||
signed, _ := sign.DecodeSignedPlugin(req.Signature)
|
||
if signed != nil {
|
||
c.versionStore.SetVersion(req.PluginName, signed.Payload.Version)
|
||
}
|
||
}
|
||
|
||
// 自动启动
|
||
if req.AutoStart {
|
||
c.startJSPlugin(jsPlugin, req)
|
||
}
|
||
}
|
||
|
||
// sendJSPluginResult 发送 JS 插件安装结果
|
||
func (c *Client) sendJSPluginResult(stream net.Conn, name string, success bool, errMsg string) {
|
||
result := protocol.JSPluginInstallResult{
|
||
PluginName: name,
|
||
Success: success,
|
||
Error: errMsg,
|
||
}
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypeJSPluginResult, result)
|
||
protocol.WriteMessage(stream, msg)
|
||
}
|
||
|
||
// startJSPlugin 启动 JS 插件
|
||
func (c *Client) startJSPlugin(handler plugin.ClientPlugin, req protocol.JSPluginInstallRequest) {
|
||
if err := handler.Init(req.Config); err != nil {
|
||
log.Printf("[Client] JS plugin %s init error: %v", req.PluginName, err)
|
||
return
|
||
}
|
||
|
||
localAddr, err := handler.Start()
|
||
if err != nil {
|
||
log.Printf("[Client] JS plugin %s start error: %v", req.PluginName, err)
|
||
return
|
||
}
|
||
|
||
key := req.PluginName + ":" + req.RuleName
|
||
c.pluginMu.Lock()
|
||
c.runningPlugins[key] = handler
|
||
c.pluginMu.Unlock()
|
||
|
||
log.Printf("[Client] JS plugin %s started at %s", req.PluginName, localAddr)
|
||
}
|
||
|
||
// verifyJSPluginSignature 验证 JS 插件签名
|
||
func (c *Client) verifyJSPluginSignature(pluginName, source, signature string) error {
|
||
if signature == "" {
|
||
return fmt.Errorf("missing signature")
|
||
}
|
||
|
||
// 解码签名
|
||
signed, err := sign.DecodeSignedPlugin(signature)
|
||
if err != nil {
|
||
return fmt.Errorf("decode signature: %w", err)
|
||
}
|
||
|
||
// 根据 KeyID 获取对应公钥
|
||
pubKey, err := sign.GetPublicKeyByID(signed.Payload.KeyID)
|
||
if err != nil {
|
||
return fmt.Errorf("get public key: %w", err)
|
||
}
|
||
|
||
// 验证插件名称匹配
|
||
if signed.Payload.Name != pluginName {
|
||
return fmt.Errorf("plugin name mismatch: expected %s, got %s",
|
||
pluginName, signed.Payload.Name)
|
||
}
|
||
|
||
// 验证签名和源码哈希
|
||
if err := sign.VerifyPlugin(pubKey, signed, source); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 检查版本降级攻击
|
||
if c.versionStore != nil {
|
||
currentVer := c.versionStore.GetVersion(pluginName)
|
||
if currentVer != "" {
|
||
cmp := sign.CompareVersions(signed.Payload.Version, currentVer)
|
||
if cmp < 0 {
|
||
return fmt.Errorf("version downgrade rejected: %s < %s",
|
||
signed.Payload.Version, currentVer)
|
||
}
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// handleClientPluginStop 处理客户端插件停止请求
|
||
func (c *Client) handleClientPluginStop(stream net.Conn, msg *protocol.Message) {
|
||
defer stream.Close()
|
||
|
||
var req protocol.ClientPluginStopRequest
|
||
if err := msg.ParsePayload(&req); err != nil {
|
||
c.sendPluginStatus(stream, req.PluginName, req.RuleName, true, "", err.Error())
|
||
return
|
||
}
|
||
|
||
key := req.PluginName + ":" + req.RuleName
|
||
|
||
c.pluginMu.Lock()
|
||
handler, ok := c.runningPlugins[key]
|
||
if ok {
|
||
if err := handler.Stop(); err != nil {
|
||
log.Printf("[Client] Plugin %s stop error: %v", key, err)
|
||
}
|
||
delete(c.runningPlugins, key)
|
||
}
|
||
c.pluginMu.Unlock()
|
||
|
||
log.Printf("[Client] Plugin %s stopped", key)
|
||
c.sendPluginStatus(stream, req.PluginName, req.RuleName, false, "", "")
|
||
}
|
||
|
||
// handleClientRestart 处理客户端重启请求
|
||
func (c *Client) handleClientRestart(stream net.Conn, msg *protocol.Message) {
|
||
defer stream.Close()
|
||
|
||
var req protocol.ClientRestartRequest
|
||
msg.ParsePayload(&req)
|
||
|
||
log.Printf("[Client] Restart requested: %s", req.Reason)
|
||
|
||
// 发送响应
|
||
resp := protocol.ClientRestartResponse{
|
||
Success: true,
|
||
Message: "restarting",
|
||
}
|
||
respMsg, _ := protocol.NewMessage(protocol.MsgTypeClientRestart, resp)
|
||
protocol.WriteMessage(stream, respMsg)
|
||
|
||
// 停止所有运行中的插件
|
||
c.pluginMu.Lock()
|
||
for key, handler := range c.runningPlugins {
|
||
log.Printf("[Client] Stopping plugin %s for restart", key)
|
||
handler.Stop()
|
||
}
|
||
c.runningPlugins = make(map[string]plugin.ClientPlugin)
|
||
c.pluginMu.Unlock()
|
||
|
||
// 关闭会话(会触发重连)
|
||
if c.session != nil {
|
||
c.session.Close()
|
||
}
|
||
}
|
||
|
||
// handlePluginConfigUpdate 处理插件配置更新请求
|
||
func (c *Client) handlePluginConfigUpdate(stream net.Conn, msg *protocol.Message) {
|
||
defer stream.Close()
|
||
|
||
var req protocol.PluginConfigUpdateRequest
|
||
if err := msg.ParsePayload(&req); err != nil {
|
||
c.sendPluginConfigUpdateResult(stream, req.PluginName, req.RuleName, false, err.Error())
|
||
return
|
||
}
|
||
|
||
key := req.PluginName + ":" + req.RuleName
|
||
log.Printf("[Client] Config update for plugin %s", key)
|
||
|
||
c.pluginMu.RLock()
|
||
handler, ok := c.runningPlugins[key]
|
||
c.pluginMu.RUnlock()
|
||
|
||
if !ok {
|
||
c.sendPluginConfigUpdateResult(stream, req.PluginName, req.RuleName, false, "plugin not running")
|
||
return
|
||
}
|
||
|
||
if req.Restart {
|
||
// 停止并重启插件
|
||
c.pluginMu.Lock()
|
||
if err := handler.Stop(); err != nil {
|
||
log.Printf("[Client] Plugin %s stop error: %v", key, err)
|
||
}
|
||
delete(c.runningPlugins, key)
|
||
c.pluginMu.Unlock()
|
||
|
||
// 重新初始化和启动
|
||
if err := handler.Init(req.Config); err != nil {
|
||
c.sendPluginConfigUpdateResult(stream, req.PluginName, req.RuleName, false, err.Error())
|
||
return
|
||
}
|
||
|
||
localAddr, err := handler.Start()
|
||
if err != nil {
|
||
c.sendPluginConfigUpdateResult(stream, req.PluginName, req.RuleName, false, err.Error())
|
||
return
|
||
}
|
||
|
||
c.pluginMu.Lock()
|
||
c.runningPlugins[key] = handler
|
||
c.pluginMu.Unlock()
|
||
|
||
log.Printf("[Client] Plugin %s restarted at %s with new config", key, localAddr)
|
||
}
|
||
|
||
c.sendPluginConfigUpdateResult(stream, req.PluginName, req.RuleName, true, "")
|
||
}
|
||
|
||
// sendPluginConfigUpdateResult 发送插件配置更新结果
|
||
func (c *Client) sendPluginConfigUpdateResult(stream net.Conn, pluginName, ruleName string, success bool, errMsg string) {
|
||
result := protocol.PluginConfigUpdateResponse{
|
||
PluginName: pluginName,
|
||
RuleName: ruleName,
|
||
Success: success,
|
||
Error: errMsg,
|
||
}
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypePluginConfigUpdate, result)
|
||
protocol.WriteMessage(stream, msg)
|
||
}
|
||
|
||
// handleUpdateDownload 处理更新下载请求
|
||
func (c *Client) handleUpdateDownload(stream net.Conn, msg *protocol.Message) {
|
||
defer stream.Close()
|
||
|
||
var req protocol.UpdateDownloadRequest
|
||
if err := msg.ParsePayload(&req); err != nil {
|
||
log.Printf("[Client] Parse update request error: %v", err)
|
||
c.sendUpdateResult(stream, false, "invalid request")
|
||
return
|
||
}
|
||
|
||
log.Printf("[Client] Update download requested: %s", req.DownloadURL)
|
||
|
||
// 异步执行更新
|
||
go func() {
|
||
if err := c.performSelfUpdate(req.DownloadURL); err != nil {
|
||
log.Printf("[Client] Update failed: %v", err)
|
||
}
|
||
}()
|
||
|
||
c.sendUpdateResult(stream, true, "update started")
|
||
}
|
||
|
||
// sendUpdateResult 发送更新结果
|
||
func (c *Client) sendUpdateResult(stream net.Conn, success bool, message string) {
|
||
result := protocol.UpdateResultResponse{
|
||
Success: success,
|
||
Message: message,
|
||
}
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypeUpdateResult, result)
|
||
protocol.WriteMessage(stream, msg)
|
||
}
|
||
|
||
// performSelfUpdate 执行自更新
|
||
func (c *Client) performSelfUpdate(downloadURL string) error {
|
||
log.Printf("[Client] Starting self-update from: %s", downloadURL)
|
||
|
||
// 创建临时文件
|
||
tempDir := os.TempDir()
|
||
tempFile := filepath.Join(tempDir, "gotunnel_client_update")
|
||
|
||
if runtime.GOOS == "windows" {
|
||
tempFile += ".exe"
|
||
}
|
||
|
||
// 下载新版本
|
||
if err := downloadUpdateFile(downloadURL, tempFile); err != nil {
|
||
return fmt.Errorf("download update: %w", err)
|
||
}
|
||
|
||
// 设置执行权限
|
||
if runtime.GOOS != "windows" {
|
||
if err := os.Chmod(tempFile, 0755); err != nil {
|
||
os.Remove(tempFile)
|
||
return fmt.Errorf("chmod: %w", err)
|
||
}
|
||
}
|
||
|
||
// 获取当前可执行文件路径
|
||
currentPath, err := os.Executable()
|
||
if err != nil {
|
||
os.Remove(tempFile)
|
||
return fmt.Errorf("get executable: %w", err)
|
||
}
|
||
currentPath, _ = filepath.EvalSymlinks(currentPath)
|
||
|
||
// Windows 需要特殊处理
|
||
if runtime.GOOS == "windows" {
|
||
return performWindowsClientUpdate(tempFile, currentPath, c.ServerAddr, c.Token, c.ID)
|
||
}
|
||
|
||
// Linux/Mac: 直接替换
|
||
backupPath := currentPath + ".bak"
|
||
|
||
// 停止所有插件
|
||
c.stopAllPlugins()
|
||
|
||
// 备份当前文件
|
||
if err := os.Rename(currentPath, backupPath); err != nil {
|
||
os.Remove(tempFile)
|
||
return fmt.Errorf("backup current: %w", err)
|
||
}
|
||
|
||
// 移动新文件
|
||
if err := os.Rename(tempFile, currentPath); err != nil {
|
||
os.Rename(backupPath, currentPath)
|
||
return fmt.Errorf("replace binary: %w", err)
|
||
}
|
||
|
||
// 删除备份
|
||
os.Remove(backupPath)
|
||
|
||
log.Printf("[Client] Update completed, restarting...")
|
||
|
||
// 重启进程
|
||
restartClientProcess(currentPath, c.ServerAddr, c.Token, c.ID)
|
||
return nil
|
||
}
|
||
|
||
// stopAllPlugins 停止所有运行中的插件
|
||
func (c *Client) stopAllPlugins() {
|
||
c.pluginMu.Lock()
|
||
for key, handler := range c.runningPlugins {
|
||
log.Printf("[Client] Stopping plugin %s for update", key)
|
||
handler.Stop()
|
||
}
|
||
c.runningPlugins = make(map[string]plugin.ClientPlugin)
|
||
c.pluginMu.Unlock()
|
||
}
|
||
|
||
// downloadUpdateFile downloads url into dest, with a 10-minute overall timeout.
//
// Any status other than 200 OK is an error and dest is not created. If the
// body cannot be copied completely, for example because the connection drops
// mid-transfer, or if closing the file fails (a deferred Close would hide
// late write errors), dest is removed. A truncated binary is therefore never
// left behind looking like a successful download.
func downloadUpdateFile(url, dest string) error {
	client := &http.Client{Timeout: 10 * time.Minute}
	resp, err := client.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("download failed: %s", resp.Status)
	}

	out, err := os.Create(dest)
	if err != nil {
		return err
	}

	if _, err := io.Copy(out, resp.Body); err != nil {
		out.Close()
		os.Remove(dest)
		return err
	}
	if err := out.Close(); err != nil {
		os.Remove(dest)
		return err
	}
	return nil
}
|
||
|
||
// performWindowsClientUpdate swaps in the downloaded binary on Windows, where a
// running executable cannot be overwritten.
//
// It writes a helper batch file to os.TempDir() that:
//   - pauses briefly using `ping -n 2` (presumably so this process has exited
//     first);
//   - deletes currentPath and moves newFile into its place;
//   - relaunches the binary with -s/-t and, if id is non-empty, -id;
//   - deletes itself.
//
// The batch file is launched minimized via cmd and the current process exits
// with status 0. This function only returns if writing or launching the batch
// file fails.
//
// NOTE(review): arguments are wrapped in double quotes without escaping;
// values containing `"` or `%` would break the script.
func performWindowsClientUpdate(newFile, currentPath, serverAddr, token, id string) error {
	relaunchArgs := fmt.Sprintf(`-s "%s" -t "%s"`, serverAddr, token)
	if id != "" {
		relaunchArgs += fmt.Sprintf(` -id "%s"`, id)
	}

	// %[1]s = current binary, %[2]s = downloaded binary, %[3]s = relaunch args.
	// %% produces a literal %, so the script can delete itself via %~f0.
	const scriptTemplate = "@echo off\n" +
		"ping 127.0.0.1 -n 2 > nul\n" +
		"del \"%[1]s\"\n" +
		"move \"%[2]s\" \"%[1]s\"\n" +
		"start \"\" \"%[1]s\" %[3]s\n" +
		"del \"%%~f0\"\n"
	batchScript := fmt.Sprintf(scriptTemplate, currentPath, newFile, relaunchArgs)

	batchPath := filepath.Join(os.TempDir(), "gotunnel_client_update.bat")
	if err := os.WriteFile(batchPath, []byte(batchScript), 0755); err != nil {
		return fmt.Errorf("write batch: %w", err)
	}

	launcher := exec.Command("cmd", "/C", "start", "/MIN", batchPath)
	if err := launcher.Start(); err != nil {
		return fmt.Errorf("start batch: %w", err)
	}

	os.Exit(0)
	return nil
}
|
||
|
||
// restartClientProcess launches a new client process from path with -s
// serverAddr, -t token and (if non-empty) -id id. Stdout and stderr are
// inherited.
//
// When the new process starts, the current process exits with status 0. If
// starting it fails, the error is logged and the function returns without
// exiting. Exiting anyway would leave no client running at all.
func restartClientProcess(path, serverAddr, token, id string) {
	args := []string{"-s", serverAddr, "-t", token}
	if id != "" {
		args = append(args, "-id", id)
	}

	cmd := exec.Command(path, args...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Start(); err != nil {
		log.Printf("[Client] Restart failed, keeping current process: %v", err)
		return
	}
	os.Exit(0)
}
|