Add Android client support and unify cross-platform builds
This commit is contained in:
@@ -6,14 +6,19 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// ClientConfig 客户端配置
|
||||
// ClientConfig defines client runtime configuration.
|
||||
type ClientConfig struct {
|
||||
Server string `yaml:"server"` // 服务器地址
|
||||
Token string `yaml:"token"` // 认证 Token
|
||||
NoTLS bool `yaml:"no_tls"` // 禁用 TLS
|
||||
Server string `yaml:"server"`
|
||||
Token string `yaml:"token"`
|
||||
NoTLS bool `yaml:"no_tls"`
|
||||
DataDir string `yaml:"data_dir"`
|
||||
Name string `yaml:"name"`
|
||||
ClientID string `yaml:"client_id"`
|
||||
ReconnectMinSec int `yaml:"reconnect_min_sec"`
|
||||
ReconnectMaxSec int `yaml:"reconnect_max_sec"`
|
||||
}
|
||||
|
||||
// LoadClientConfig 加载客户端配置
|
||||
// LoadClientConfig loads client configuration from YAML.
|
||||
func LoadClientConfig(path string) (*ClientConfig, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
|
||||
@@ -22,82 +22,90 @@ import (
|
||||
"github.com/hashicorp/yamux"
|
||||
)
|
||||
|
||||
// 客户端常量
|
||||
const (
|
||||
dialTimeout = 10 * time.Second
|
||||
localDialTimeout = 5 * time.Second
|
||||
udpTimeout = 10 * time.Second
|
||||
reconnectDelay = 5 * time.Second
|
||||
disconnectDelay = 3 * time.Second
|
||||
udpBufferSize = 65535
|
||||
dialTimeout = 10 * time.Second
|
||||
localDialTimeout = 5 * time.Second
|
||||
udpTimeout = 10 * time.Second
|
||||
reconnectDelay = 5 * time.Second
|
||||
maxReconnectDelay = 30 * time.Second
|
||||
disconnectDelay = 3 * time.Second
|
||||
tcpKeepAlive = 30 * time.Second
|
||||
udpBufferSize = 65535
|
||||
)
|
||||
|
||||
// Client 隧道客户端
|
||||
// Client is the tunnel client runtime.
|
||||
type Client struct {
|
||||
ServerAddr string
|
||||
Token string
|
||||
ID string
|
||||
Name string // 客户端名称(主机名)
|
||||
Name string
|
||||
TLSEnabled bool
|
||||
TLSConfig *tls.Config
|
||||
DataDir string // 数据目录
|
||||
session *yamux.Session
|
||||
rules []protocol.ProxyRule
|
||||
mu sync.RWMutex
|
||||
logger *Logger // 日志收集器
|
||||
DataDir string
|
||||
|
||||
features PlatformFeatures
|
||||
reconnectDelay time.Duration
|
||||
reconnectMaxDelay time.Duration
|
||||
|
||||
session *yamux.Session
|
||||
rules []protocol.ProxyRule
|
||||
mu sync.RWMutex
|
||||
logger *Logger
|
||||
}
|
||||
|
||||
// NewClient 创建客户端
|
||||
// NewClient creates a client with default desktop options.
|
||||
func NewClient(serverAddr, token string) *Client {
|
||||
// 默认数据目录:优先使用用户主目录,失败时回退到当前工作目录
|
||||
var dataDir string
|
||||
if home, err := os.UserHomeDir(); err == nil && home != "" {
|
||||
dataDir = filepath.Join(home, ".gotunnel")
|
||||
} else {
|
||||
// UserHomeDir 失败(如 Android adb shell 环境),使用当前工作目录
|
||||
if cwd, err := os.Getwd(); err == nil {
|
||||
dataDir = filepath.Join(cwd, ".gotunnel")
|
||||
log.Printf("[Client] UserHomeDir unavailable, using current directory: %s", dataDir)
|
||||
} else {
|
||||
// 最后回退到相对路径
|
||||
dataDir = ".gotunnel"
|
||||
log.Printf("[Client] Warning: using relative path for data directory")
|
||||
}
|
||||
}
|
||||
return NewClientWithOptions(serverAddr, token, ClientOptions{})
|
||||
}
|
||||
|
||||
// 确保数据目录存在
|
||||
// NewClientWithOptions creates a client with explicit runtime options.
|
||||
func NewClientWithOptions(serverAddr, token string, opts ClientOptions) *Client {
|
||||
dataDir := resolveDataDir(opts.DataDir)
|
||||
if err := os.MkdirAll(dataDir, 0755); err != nil {
|
||||
log.Printf("Failed to create data dir: %v", err)
|
||||
}
|
||||
|
||||
// ID 优先级:命令行参数 > 机器ID
|
||||
id := getMachineID()
|
||||
|
||||
// 获取主机名作为客户端名称
|
||||
hostname, _ := os.Hostname()
|
||||
|
||||
// 初始化日志收集器
|
||||
logger, err := NewLogger(dataDir)
|
||||
if err != nil {
|
||||
log.Printf("Failed to initialize logger: %v", err)
|
||||
}
|
||||
|
||||
features := DefaultPlatformFeatures()
|
||||
if opts.Features != nil {
|
||||
features = *opts.Features
|
||||
}
|
||||
|
||||
delay := opts.ReconnectDelay
|
||||
if delay <= 0 {
|
||||
delay = reconnectDelay
|
||||
}
|
||||
|
||||
maxDelay := opts.ReconnectMaxDelay
|
||||
if maxDelay <= 0 {
|
||||
maxDelay = maxReconnectDelay
|
||||
}
|
||||
if maxDelay < delay {
|
||||
maxDelay = delay
|
||||
}
|
||||
|
||||
return &Client{
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
ID: id,
|
||||
Name: hostname,
|
||||
DataDir: dataDir,
|
||||
logger: logger,
|
||||
ServerAddr: serverAddr,
|
||||
Token: token,
|
||||
ID: resolveClientID(dataDir, opts.ClientID),
|
||||
Name: resolveClientName(opts.ClientName),
|
||||
DataDir: dataDir,
|
||||
features: features,
|
||||
reconnectDelay: delay,
|
||||
reconnectMaxDelay: maxDelay,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// InitVersionStore 初始化版本存储
|
||||
// InitVersionStore is kept for compatibility with older callers.
|
||||
func (c *Client) InitVersionStore() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// logf 安全地记录日志(同时输出到标准日志和日志收集器)
|
||||
func (c *Client) logf(format string, args ...interface{}) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
log.Print(msg)
|
||||
@@ -106,7 +114,6 @@ func (c *Client) logf(format string, args ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// logErrorf 安全地记录错误日志
|
||||
func (c *Client) logErrorf(format string, args ...interface{}) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
log.Print(msg)
|
||||
@@ -115,7 +122,6 @@ func (c *Client) logErrorf(format string, args ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// logWarnf 安全地记录警告日志
|
||||
func (c *Client) logWarnf(format string, args ...interface{}) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
log.Print(msg)
|
||||
@@ -124,32 +130,82 @@ func (c *Client) logWarnf(format string, args ...interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
// Run 启动客户端(带断线重连)
|
||||
// Run starts the reconnect loop until the process exits.
|
||||
func (c *Client) Run() error {
|
||||
return c.RunContext(context.Background())
|
||||
}
|
||||
|
||||
// RunContext starts the reconnect loop and exits when ctx is cancelled.
|
||||
func (c *Client) RunContext(ctx context.Context) error {
|
||||
backoff := c.reconnectDelay
|
||||
|
||||
for {
|
||||
if err := c.connect(); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := c.connect(ctx); err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
c.logErrorf("Connect error: %v", err)
|
||||
c.logf("Reconnecting in %v...", reconnectDelay)
|
||||
time.Sleep(reconnectDelay)
|
||||
c.logf("Reconnecting in %v...", backoff)
|
||||
if !sleepWithContext(ctx, backoff) {
|
||||
return nil
|
||||
}
|
||||
backoff *= 2
|
||||
if backoff > c.reconnectMaxDelay {
|
||||
backoff = c.reconnectMaxDelay
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
c.handleSession()
|
||||
backoff = c.reconnectDelay
|
||||
c.handleSession(ctx)
|
||||
if ctx.Err() != nil {
|
||||
return nil
|
||||
}
|
||||
c.logWarnf("Disconnected, reconnecting...")
|
||||
time.Sleep(disconnectDelay)
|
||||
if !sleepWithContext(ctx, disconnectDelay) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// connect 连接到服务端并认证
|
||||
func (c *Client) connect() error {
|
||||
func sleepWithContext(ctx context.Context, wait time.Duration) bool {
|
||||
timer := time.NewTimer(wait)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-timer.C:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) connect(ctx context.Context) error {
|
||||
var conn net.Conn
|
||||
var err error
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: dialTimeout,
|
||||
KeepAlive: tcpKeepAlive,
|
||||
}
|
||||
|
||||
if c.TLSEnabled && c.TLSConfig != nil {
|
||||
dialer := &net.Dialer{Timeout: dialTimeout}
|
||||
conn, err = tls.DialWithDialer(dialer, "tcp", c.ServerAddr, c.TLSConfig)
|
||||
rawConn, dialErr := dialer.DialContext(ctx, "tcp", c.ServerAddr)
|
||||
if dialErr != nil {
|
||||
return dialErr
|
||||
}
|
||||
tlsConn := tls.Client(rawConn, c.TLSConfig)
|
||||
if handshakeErr := tlsConn.HandshakeContext(ctx); handshakeErr != nil {
|
||||
rawConn.Close()
|
||||
return handshakeErr
|
||||
}
|
||||
conn = tlsConn
|
||||
} else {
|
||||
conn, err = net.DialTimeout("tcp", c.ServerAddr, dialTimeout)
|
||||
conn, err = dialer.DialContext(ctx, "tcp", c.ServerAddr)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -184,8 +240,6 @@ func (c *Client) connect() error {
|
||||
conn.Close()
|
||||
return fmt.Errorf("auth failed: %s", authResp.Message)
|
||||
}
|
||||
|
||||
// 如果服务端分配了新 ID,则更新
|
||||
if authResp.ClientID != "" && authResp.ClientID != c.ID {
|
||||
conn.Close()
|
||||
return fmt.Errorf("server returned unexpected client id: %s", authResp.ClientID)
|
||||
@@ -206,12 +260,31 @@ func (c *Client) connect() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleSession 处理会话
|
||||
func (c *Client) handleSession() {
|
||||
defer c.session.Close()
|
||||
func (c *Client) currentSession() *yamux.Session {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.session
|
||||
}
|
||||
|
||||
func (c *Client) handleSession(ctx context.Context) {
|
||||
session := c.currentSession()
|
||||
if session == nil {
|
||||
return
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
session.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
defer session.Close()
|
||||
|
||||
for {
|
||||
stream, err := c.session.Accept()
|
||||
stream, err := session.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -219,7 +292,6 @@ func (c *Client) handleSession() {
|
||||
}
|
||||
}
|
||||
|
||||
// handleStream 处理流
|
||||
func (c *Client) handleStream(stream net.Conn) {
|
||||
msg, err := protocol.ReadMessage(stream)
|
||||
if err != nil {
|
||||
@@ -254,10 +326,11 @@ func (c *Client) handleStream(stream net.Conn) {
|
||||
c.handleScreenshotRequest(stream, msg)
|
||||
case protocol.MsgTypeShellExecuteRequest:
|
||||
c.handleShellExecuteRequest(stream, msg)
|
||||
default:
|
||||
stream.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// handleProxyConfig 处理代理配置
|
||||
func (c *Client) handleProxyConfig(stream net.Conn, msg *protocol.Message) {
|
||||
defer stream.Close()
|
||||
|
||||
@@ -276,12 +349,10 @@ func (c *Client) handleProxyConfig(stream net.Conn, msg *protocol.Message) {
|
||||
c.logf(" %s: %s:%d", r.Name, r.LocalIP, r.LocalPort)
|
||||
}
|
||||
|
||||
// 发送配置确认
|
||||
ack := &protocol.Message{Type: protocol.MsgTypeProxyReady}
|
||||
protocol.WriteMessage(stream, ack)
|
||||
}
|
||||
|
||||
// handleNewProxy 处理新代理请求
|
||||
func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) {
|
||||
var req protocol.NewProxyRequest
|
||||
if err := msg.ParsePayload(&req); err != nil {
|
||||
@@ -291,9 +362,9 @@ func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) {
|
||||
|
||||
var rule *protocol.ProxyRule
|
||||
c.mu.RLock()
|
||||
for _, r := range c.rules {
|
||||
if r.RemotePort == req.RemotePort {
|
||||
rule = &r
|
||||
for i := range c.rules {
|
||||
if c.rules[i].RemotePort == req.RemotePort {
|
||||
rule = &c.rules[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -314,13 +385,11 @@ func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) {
|
||||
relay.Relay(stream, localConn)
|
||||
}
|
||||
|
||||
// handleHeartbeat 处理心跳
|
||||
func (c *Client) handleHeartbeat(stream net.Conn) {
|
||||
msg := &protocol.Message{Type: protocol.MsgTypeHeartbeatAck}
|
||||
protocol.WriteMessage(stream, msg)
|
||||
}
|
||||
|
||||
// handleProxyConnect 处理代理连接请求 (SOCKS5/HTTP)
|
||||
func (c *Client) handleProxyConnect(stream net.Conn, msg *protocol.Message) {
|
||||
defer stream.Close()
|
||||
|
||||
@@ -330,7 +399,6 @@ func (c *Client) handleProxyConnect(stream net.Conn, msg *protocol.Message) {
|
||||
return
|
||||
}
|
||||
|
||||
// 连接目标地址
|
||||
targetConn, err := net.DialTimeout("tcp", req.Target, dialTimeout)
|
||||
if err != nil {
|
||||
c.sendProxyResult(stream, false, err.Error())
|
||||
@@ -338,23 +406,19 @@ func (c *Client) handleProxyConnect(stream net.Conn, msg *protocol.Message) {
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
// 发送成功响应
|
||||
if err := c.sendProxyResult(stream, true, ""); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 双向转发数据
|
||||
relay.Relay(stream, targetConn)
|
||||
}
|
||||
|
||||
// sendProxyResult 发送代理连接结果
|
||||
func (c *Client) sendProxyResult(stream net.Conn, success bool, message string) error {
|
||||
result := protocol.ProxyConnectResult{Success: success, Message: message}
|
||||
msg, _ := protocol.NewMessage(protocol.MsgTypeProxyResult, result)
|
||||
return protocol.WriteMessage(stream, msg)
|
||||
}
|
||||
|
||||
// handleUDPData 处理 UDP 数据
|
||||
func (c *Client) handleUDPData(stream net.Conn, msg *protocol.Message) {
|
||||
defer stream.Close()
|
||||
|
||||
@@ -363,13 +427,11 @@ func (c *Client) handleUDPData(stream net.Conn, msg *protocol.Message) {
|
||||
return
|
||||
}
|
||||
|
||||
// 查找对应的规则
|
||||
rule := c.findRuleByPort(packet.RemotePort)
|
||||
if rule == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 连接本地 UDP 服务
|
||||
target := fmt.Sprintf("%s:%d", rule.LocalIP, rule.LocalPort)
|
||||
conn, err := net.DialTimeout("udp", target, localDialTimeout)
|
||||
if err != nil {
|
||||
@@ -377,20 +439,17 @@ func (c *Client) handleUDPData(stream net.Conn, msg *protocol.Message) {
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// 发送数据到本地服务
|
||||
conn.SetDeadline(time.Now().Add(udpTimeout))
|
||||
if _, err := conn.Write(packet.Data); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 读取响应
|
||||
buf := make([]byte, udpBufferSize)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 发送响应回服务端
|
||||
respPacket := protocol.UDPPacket{
|
||||
RemotePort: packet.RemotePort,
|
||||
ClientAddr: packet.ClientAddr,
|
||||
@@ -400,7 +459,6 @@ func (c *Client) handleUDPData(stream net.Conn, msg *protocol.Message) {
|
||||
protocol.WriteMessage(stream, respMsg)
|
||||
}
|
||||
|
||||
// findRuleByPort 根据端口查找规则
|
||||
func (c *Client) findRuleByPort(port int) *protocol.ProxyRule {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
@@ -413,7 +471,6 @@ func (c *Client) findRuleByPort(port int) *protocol.ProxyRule {
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleClientRestart 处理客户端重启请求
|
||||
func (c *Client) handleClientRestart(stream net.Conn, msg *protocol.Message) {
|
||||
defer stream.Close()
|
||||
|
||||
@@ -422,7 +479,6 @@ func (c *Client) handleClientRestart(stream net.Conn, msg *protocol.Message) {
|
||||
|
||||
c.logf("Restart requested: %s", req.Reason)
|
||||
|
||||
// 发送响应
|
||||
resp := protocol.ClientRestartResponse{
|
||||
Success: true,
|
||||
Message: "restarting",
|
||||
@@ -430,17 +486,19 @@ func (c *Client) handleClientRestart(stream net.Conn, msg *protocol.Message) {
|
||||
respMsg, _ := protocol.NewMessage(protocol.MsgTypeClientRestart, resp)
|
||||
protocol.WriteMessage(stream, respMsg)
|
||||
|
||||
// 停止所有运行中的插件
|
||||
// 关闭会话(会触发重连)
|
||||
if c.session != nil {
|
||||
c.session.Close()
|
||||
if session := c.currentSession(); session != nil {
|
||||
session.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// handleUpdateDownload 处理更新下载请求
|
||||
func (c *Client) handleUpdateDownload(stream net.Conn, msg *protocol.Message) {
|
||||
defer stream.Close()
|
||||
|
||||
if !c.features.AllowSelfUpdate {
|
||||
c.sendUpdateResult(stream, false, "self-update not supported on this platform")
|
||||
return
|
||||
}
|
||||
|
||||
var req protocol.UpdateDownloadRequest
|
||||
if err := msg.ParsePayload(&req); err != nil {
|
||||
c.logErrorf("Parse update request error: %v", err)
|
||||
@@ -450,7 +508,6 @@ func (c *Client) handleUpdateDownload(stream net.Conn, msg *protocol.Message) {
|
||||
|
||||
c.logf("Update download requested: %s", req.DownloadURL)
|
||||
|
||||
// 异步执行更新
|
||||
go func() {
|
||||
if err := c.performSelfUpdate(req.DownloadURL); err != nil {
|
||||
c.logErrorf("Update failed: %v", err)
|
||||
@@ -460,7 +517,6 @@ func (c *Client) handleUpdateDownload(stream net.Conn, msg *protocol.Message) {
|
||||
c.sendUpdateResult(stream, true, "update started")
|
||||
}
|
||||
|
||||
// sendUpdateResult 发送更新结果
|
||||
func (c *Client) sendUpdateResult(stream net.Conn, success bool, message string) {
|
||||
result := protocol.UpdateResultResponse{
|
||||
Success: success,
|
||||
@@ -470,11 +526,13 @@ func (c *Client) sendUpdateResult(stream net.Conn, success bool, message string)
|
||||
protocol.WriteMessage(stream, msg)
|
||||
}
|
||||
|
||||
// performSelfUpdate 执行自更新
|
||||
func (c *Client) performSelfUpdate(downloadURL string) error {
|
||||
if runtime.GOOS == "android" {
|
||||
return fmt.Errorf("self-update must be handled by the Android host app")
|
||||
}
|
||||
|
||||
c.logf("Starting self-update from: %s", downloadURL)
|
||||
|
||||
// 获取当前可执行文件路径
|
||||
currentPath, err := os.Executable()
|
||||
if err != nil {
|
||||
c.logErrorf("Update failed: cannot get executable path: %v", err)
|
||||
@@ -482,17 +540,12 @@ func (c *Client) performSelfUpdate(downloadURL string) error {
|
||||
}
|
||||
currentPath, _ = filepath.EvalSymlinks(currentPath)
|
||||
|
||||
// 预检查:验证是否有写权限(在下载前检查,避免浪费带宽)
|
||||
// Windows 跳过预检查,因为 Windows 更新通过 batch 脚本以提升权限执行
|
||||
// 非 Windows:原始路径 → DataDir → 临时目录,逐级回退
|
||||
fallbackDir := ""
|
||||
if runtime.GOOS != "windows" {
|
||||
if err := c.checkUpdatePermissions(currentPath); err != nil {
|
||||
// 尝试 DataDir
|
||||
fallbackDir = c.DataDir
|
||||
testFile := filepath.Join(fallbackDir, ".gotunnel_update_test")
|
||||
if f, err := os.Create(testFile); err != nil {
|
||||
// DataDir 也不可写,回退到临时目录
|
||||
fallbackDir = os.TempDir()
|
||||
c.logf("DataDir not writable, falling back to temp directory: %s", fallbackDir)
|
||||
} else {
|
||||
@@ -503,7 +556,6 @@ func (c *Client) performSelfUpdate(downloadURL string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// 使用共享的下载和解压逻辑
|
||||
c.logf("Downloading update package...")
|
||||
binaryPath, cleanup, err := update.DownloadAndExtract(downloadURL, "client")
|
||||
if err != nil {
|
||||
@@ -512,12 +564,10 @@ func (c *Client) performSelfUpdate(downloadURL string) error {
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
// Windows 需要特殊处理
|
||||
if runtime.GOOS == "windows" {
|
||||
return performWindowsClientUpdate(binaryPath, currentPath, c.ServerAddr, c.Token)
|
||||
}
|
||||
|
||||
// 确定目标路径
|
||||
targetPath := currentPath
|
||||
if fallbackDir != "" {
|
||||
targetPath = filepath.Join(fallbackDir, filepath.Base(currentPath))
|
||||
@@ -525,7 +575,6 @@ func (c *Client) performSelfUpdate(downloadURL string) error {
|
||||
}
|
||||
|
||||
if fallbackDir == "" {
|
||||
// 原地替换:备份 → 复制 → 清理
|
||||
backupPath := currentPath + ".bak"
|
||||
|
||||
c.logf("Backing up current binary...")
|
||||
@@ -549,7 +598,6 @@ func (c *Client) performSelfUpdate(downloadURL string) error {
|
||||
|
||||
os.Remove(backupPath)
|
||||
} else {
|
||||
// 回退路径:直接复制到回退目录
|
||||
c.logf("Installing new binary to data directory...")
|
||||
if err := update.CopyFile(binaryPath, targetPath); err != nil {
|
||||
c.logErrorf("Update failed: cannot install new binary: %v", err)
|
||||
@@ -563,15 +611,11 @@ func (c *Client) performSelfUpdate(downloadURL string) error {
|
||||
}
|
||||
|
||||
c.logf("Update completed successfully, restarting...")
|
||||
|
||||
// 重启进程(从新路径启动)
|
||||
restartClientProcess(targetPath, c.ServerAddr, c.Token)
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkUpdatePermissions 检查是否有更新权限
|
||||
func (c *Client) checkUpdatePermissions(execPath string) error {
|
||||
// 检查可执行文件所在目录是否可写
|
||||
dir := filepath.Dir(execPath)
|
||||
testFile := filepath.Join(dir, ".gotunnel_update_test")
|
||||
|
||||
@@ -583,7 +627,6 @@ func (c *Client) checkUpdatePermissions(execPath string) error {
|
||||
f.Close()
|
||||
os.Remove(testFile)
|
||||
|
||||
// 检查可执行文件本身是否可写
|
||||
f, err = os.OpenFile(execPath, os.O_WRONLY, 0)
|
||||
if err != nil {
|
||||
c.logErrorf("No write permission to executable: %s", execPath)
|
||||
@@ -594,9 +637,7 @@ func (c *Client) checkUpdatePermissions(execPath string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// performWindowsClientUpdate Windows 平台更新
|
||||
func performWindowsClientUpdate(newFile, currentPath, serverAddr, token string) error {
|
||||
// 创建批处理脚本
|
||||
args := fmt.Sprintf(`-s "%s" -t "%s"`, serverAddr, token)
|
||||
batchScript := fmt.Sprintf(`@echo off
|
||||
:: Check for admin rights, request UAC elevation if needed
|
||||
@@ -622,12 +663,10 @@ del "%%~f0"
|
||||
return fmt.Errorf("start batch: %w", err)
|
||||
}
|
||||
|
||||
// 退出当前进程
|
||||
os.Exit(0)
|
||||
return nil
|
||||
}
|
||||
|
||||
// restartClientProcess 重启客户端进程
|
||||
func restartClientProcess(path, serverAddr, token string) {
|
||||
args := []string{"-s", serverAddr, "-t", token}
|
||||
|
||||
@@ -638,7 +677,6 @@ func restartClientProcess(path, serverAddr, token string) {
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// handleLogRequest 处理日志请求
|
||||
func (c *Client) handleLogRequest(stream net.Conn, msg *protocol.Message) {
|
||||
if c.logger == nil {
|
||||
stream.Close()
|
||||
@@ -653,7 +691,6 @@ func (c *Client) handleLogRequest(stream net.Conn, msg *protocol.Message) {
|
||||
|
||||
c.logger.Printf("Log request received: session=%s, follow=%v", req.SessionID, req.Follow)
|
||||
|
||||
// 发送历史日志
|
||||
entries := c.logger.GetRecentLogs(req.Lines, req.Level)
|
||||
if len(entries) > 0 {
|
||||
data := protocol.LogData{
|
||||
@@ -668,20 +705,16 @@ func (c *Client) handleLogRequest(stream net.Conn, msg *protocol.Message) {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果不需要持续推送,关闭流
|
||||
if !req.Follow {
|
||||
stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// 订阅新日志
|
||||
ch := c.logger.Subscribe(req.SessionID)
|
||||
defer c.logger.Unsubscribe(req.SessionID)
|
||||
defer stream.Close()
|
||||
|
||||
// 持续推送新日志
|
||||
for entry := range ch {
|
||||
// 应用级别过滤
|
||||
if req.Level != "" && entry.Level != req.Level {
|
||||
continue
|
||||
}
|
||||
@@ -698,7 +731,6 @@ func (c *Client) handleLogRequest(stream net.Conn, msg *protocol.Message) {
|
||||
}
|
||||
}
|
||||
|
||||
// handleLogStop 处理停止日志流请求
|
||||
func (c *Client) handleLogStop(stream net.Conn, msg *protocol.Message) {
|
||||
defer stream.Close()
|
||||
|
||||
@@ -714,13 +746,20 @@ func (c *Client) handleLogStop(stream net.Conn, msg *protocol.Message) {
|
||||
c.logger.Unsubscribe(req.SessionID)
|
||||
}
|
||||
|
||||
// handleSystemStatsRequest 处理系统状态请求
|
||||
func (c *Client) handleSystemStatsRequest(stream net.Conn, msg *protocol.Message) {
|
||||
defer stream.Close()
|
||||
|
||||
if !c.features.AllowSystemStats {
|
||||
respMsg, _ := protocol.NewMessage(protocol.MsgTypeSystemStatsResponse, protocol.SystemStatsResponse{})
|
||||
protocol.WriteMessage(stream, respMsg)
|
||||
return
|
||||
}
|
||||
|
||||
stats, err := utils.GetSystemStats()
|
||||
if err != nil {
|
||||
log.Printf("Failed to get system stats: %v", err)
|
||||
respMsg, _ := protocol.NewMessage(protocol.MsgTypeSystemStatsResponse, protocol.SystemStatsResponse{})
|
||||
protocol.WriteMessage(stream, respMsg)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -738,14 +777,19 @@ func (c *Client) handleSystemStatsRequest(stream net.Conn, msg *protocol.Message
|
||||
protocol.WriteMessage(stream, respMsg)
|
||||
}
|
||||
|
||||
// handleScreenshotRequest 处理截图请求
|
||||
func (c *Client) handleScreenshotRequest(stream net.Conn, msg *protocol.Message) {
|
||||
defer stream.Close()
|
||||
|
||||
var req protocol.ScreenshotRequest
|
||||
msg.ParsePayload(&req)
|
||||
|
||||
// 捕获截图
|
||||
if !c.features.AllowScreenshot {
|
||||
resp := protocol.ScreenshotResponse{Error: "screenshot not supported on this platform"}
|
||||
respMsg, _ := protocol.NewMessage(protocol.MsgTypeScreenshotResponse, resp)
|
||||
protocol.WriteMessage(stream, respMsg)
|
||||
return
|
||||
}
|
||||
|
||||
data, width, height, err := utils.CaptureScreenshot(req.Quality)
|
||||
if err != nil {
|
||||
c.logErrorf("Screenshot capture failed: %v", err)
|
||||
@@ -755,9 +799,7 @@ func (c *Client) handleScreenshotRequest(stream net.Conn, msg *protocol.Message)
|
||||
return
|
||||
}
|
||||
|
||||
// 编码为 Base64
|
||||
base64Data := base64.StdEncoding.EncodeToString(data)
|
||||
|
||||
resp := protocol.ScreenshotResponse{
|
||||
Data: base64Data,
|
||||
Width: width,
|
||||
@@ -769,10 +811,16 @@ func (c *Client) handleScreenshotRequest(stream net.Conn, msg *protocol.Message)
|
||||
protocol.WriteMessage(stream, respMsg)
|
||||
}
|
||||
|
||||
// handleShellExecuteRequest 处理 Shell 执行请求
|
||||
func (c *Client) handleShellExecuteRequest(stream net.Conn, msg *protocol.Message) {
|
||||
defer stream.Close()
|
||||
|
||||
if !c.features.AllowShellExecute {
|
||||
resp := protocol.ShellExecuteResponse{ExitCode: -1, Error: "remote shell execution not supported on this platform"}
|
||||
respMsg, _ := protocol.NewMessage(protocol.MsgTypeShellExecuteResponse, resp)
|
||||
protocol.WriteMessage(stream, respMsg)
|
||||
return
|
||||
}
|
||||
|
||||
var req protocol.ShellExecuteRequest
|
||||
if err := msg.ParsePayload(&req); err != nil {
|
||||
resp := protocol.ShellExecuteResponse{Error: err.Error(), ExitCode: -1}
|
||||
@@ -781,7 +829,6 @@ func (c *Client) handleShellExecuteRequest(stream net.Conn, msg *protocol.Messag
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认超时
|
||||
timeout := req.Timeout
|
||||
if timeout <= 0 {
|
||||
timeout = 30
|
||||
@@ -789,7 +836,6 @@ func (c *Client) handleShellExecuteRequest(stream net.Conn, msg *protocol.Messag
|
||||
|
||||
c.logf("Executing shell command: %s", req.Command)
|
||||
|
||||
// 根据操作系统选择 shell
|
||||
var cmd *exec.Cmd
|
||||
if runtime.GOOS == "windows" {
|
||||
cmd = exec.Command("cmd", "/C", req.Command)
|
||||
@@ -797,12 +843,10 @@ func (c *Client) handleShellExecuteRequest(stream net.Conn, msg *protocol.Messag
|
||||
cmd = exec.Command("sh", "-c", req.Command)
|
||||
}
|
||||
|
||||
// 设置超时上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
|
||||
defer cancel()
|
||||
cmd = exec.CommandContext(ctx, cmd.Path, cmd.Args[1:]...)
|
||||
|
||||
// 执行命令并获取输出
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
exitCode := 0
|
||||
|
||||
90
internal/client/tunnel/identity.go
Normal file
90
internal/client/tunnel/identity.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const clientIDFileName = "client.id"
|
||||
|
||||
func resolveDataDir(explicit string) string {
|
||||
if explicit != "" {
|
||||
return explicit
|
||||
}
|
||||
|
||||
if envDir := strings.TrimSpace(os.Getenv("GOTUNNEL_DATA_DIR")); envDir != "" {
|
||||
return envDir
|
||||
}
|
||||
|
||||
if configDir, err := os.UserConfigDir(); err == nil && configDir != "" {
|
||||
return filepath.Join(configDir, "gotunnel")
|
||||
}
|
||||
|
||||
if home, err := os.UserHomeDir(); err == nil && home != "" {
|
||||
return filepath.Join(home, ".gotunnel")
|
||||
}
|
||||
|
||||
if cwd, err := os.Getwd(); err == nil && cwd != "" {
|
||||
return filepath.Join(cwd, ".gotunnel")
|
||||
}
|
||||
|
||||
return ".gotunnel"
|
||||
}
|
||||
|
||||
func resolveClientName(explicit string) string {
|
||||
if explicit != "" {
|
||||
return explicit
|
||||
}
|
||||
|
||||
if hostname, err := os.Hostname(); err == nil && hostname != "" {
|
||||
return hostname
|
||||
}
|
||||
|
||||
if runtime.GOOS == "android" {
|
||||
return "android-device"
|
||||
}
|
||||
|
||||
return "gotunnel-client"
|
||||
}
|
||||
|
||||
func resolveClientID(dataDir, explicit string) string {
|
||||
if explicit != "" {
|
||||
_ = persistClientID(dataDir, explicit)
|
||||
return explicit
|
||||
}
|
||||
|
||||
if id := loadClientID(dataDir); id != "" {
|
||||
return id
|
||||
}
|
||||
|
||||
if id := getMachineID(); id != "" {
|
||||
_ = persistClientID(dataDir, id)
|
||||
return id
|
||||
}
|
||||
|
||||
id := strings.ReplaceAll(uuid.NewString(), "-", "")[:16]
|
||||
_ = persistClientID(dataDir, id)
|
||||
return id
|
||||
}
|
||||
|
||||
func loadClientID(dataDir string) string {
|
||||
data, err := os.ReadFile(filepath.Join(dataDir, clientIDFileName))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(string(data))
|
||||
}
|
||||
|
||||
func persistClientID(dataDir, id string) error {
|
||||
if id == "" {
|
||||
return nil
|
||||
}
|
||||
if err := os.MkdirAll(dataDir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(filepath.Join(dataDir, clientIDFileName), []byte(id+"\n"), 0600)
|
||||
}
|
||||
@@ -14,11 +14,15 @@ import (
|
||||
// getMachineID builds a stable fingerprint from multiple host identifiers
|
||||
// and hashes the combined result into the client ID we expose externally.
|
||||
func getMachineID() string {
|
||||
return hashID(strings.Join(collectMachineIDParts(), "|"))
|
||||
parts := collectMachineIDParts()
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
return hashID(strings.Join(parts, "|"))
|
||||
}
|
||||
|
||||
func collectMachineIDParts() []string {
|
||||
parts := []string{"os=" + runtime.GOOS, "arch=" + runtime.GOARCH}
|
||||
parts := make([]string, 0, 6)
|
||||
|
||||
if id := getSystemMachineID(); id != "" {
|
||||
parts = append(parts, "system="+id)
|
||||
@@ -36,6 +40,11 @@ func collectMachineIDParts() []string {
|
||||
parts = append(parts, "ifaces="+strings.Join(names, ","))
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts = append(parts, "os="+runtime.GOOS, "arch="+runtime.GOARCH)
|
||||
return parts
|
||||
}
|
||||
|
||||
@@ -47,6 +56,8 @@ func getSystemMachineID() string {
|
||||
return getDarwinMachineID()
|
||||
case "windows":
|
||||
return getWindowsMachineID()
|
||||
case "android":
|
||||
return ""
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
41
internal/client/tunnel/options.go
Normal file
41
internal/client/tunnel/options.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package tunnel
|
||||
|
||||
import "time"
|
||||
|
||||
// PlatformFeatures controls which platform-specific capabilities the client may use.
|
||||
type PlatformFeatures struct {
|
||||
AllowSelfUpdate bool
|
||||
AllowScreenshot bool
|
||||
AllowShellExecute bool
|
||||
AllowSystemStats bool
|
||||
}
|
||||
|
||||
// ClientOptions controls optional client runtime settings.
|
||||
type ClientOptions struct {
|
||||
DataDir string
|
||||
ClientID string
|
||||
ClientName string
|
||||
Features *PlatformFeatures
|
||||
ReconnectDelay time.Duration
|
||||
ReconnectMaxDelay time.Duration
|
||||
}
|
||||
|
||||
// DefaultPlatformFeatures enables the desktop feature set.
|
||||
func DefaultPlatformFeatures() PlatformFeatures {
|
||||
return PlatformFeatures{
|
||||
AllowSelfUpdate: true,
|
||||
AllowScreenshot: true,
|
||||
AllowShellExecute: true,
|
||||
AllowSystemStats: true,
|
||||
}
|
||||
}
|
||||
|
||||
// MobilePlatformFeatures disables capabilities that are unsuitable for a mobile sandbox.
|
||||
func MobilePlatformFeatures() PlatformFeatures {
|
||||
return PlatformFeatures{
|
||||
AllowSelfUpdate: false,
|
||||
AllowScreenshot: false,
|
||||
AllowShellExecute: false,
|
||||
AllowSystemStats: true,
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user