Add Android client support and unify cross-platform builds

This commit is contained in:
2026-03-22 21:25:09 +08:00
parent 6558d1acdb
commit 4210ab7675
44 changed files with 2241 additions and 328 deletions

View File

@@ -6,14 +6,19 @@ import (
"gopkg.in/yaml.v3"
)
// ClientConfig 客户端配置
// ClientConfig defines client runtime configuration.
type ClientConfig struct {
Server string `yaml:"server"` // 服务器地址
Token string `yaml:"token"` // 认证 Token
NoTLS bool `yaml:"no_tls"` // 禁用 TLS
Server string `yaml:"server"`
Token string `yaml:"token"`
NoTLS bool `yaml:"no_tls"`
DataDir string `yaml:"data_dir"`
Name string `yaml:"name"`
ClientID string `yaml:"client_id"`
ReconnectMinSec int `yaml:"reconnect_min_sec"`
ReconnectMaxSec int `yaml:"reconnect_max_sec"`
}
// LoadClientConfig 加载客户端配置
// LoadClientConfig loads client configuration from YAML.
func LoadClientConfig(path string) (*ClientConfig, error) {
data, err := os.ReadFile(path)
if err != nil {

View File

@@ -22,82 +22,90 @@ import (
"github.com/hashicorp/yamux"
)
// 客户端常量
const (
dialTimeout = 10 * time.Second
localDialTimeout = 5 * time.Second
udpTimeout = 10 * time.Second
reconnectDelay = 5 * time.Second
disconnectDelay = 3 * time.Second
udpBufferSize = 65535
dialTimeout = 10 * time.Second
localDialTimeout = 5 * time.Second
udpTimeout = 10 * time.Second
reconnectDelay = 5 * time.Second
maxReconnectDelay = 30 * time.Second
disconnectDelay = 3 * time.Second
tcpKeepAlive = 30 * time.Second
udpBufferSize = 65535
)
// Client 隧道客户端
// Client is the tunnel client runtime.
type Client struct {
ServerAddr string
Token string
ID string
Name string // 客户端名称(主机名)
Name string
TLSEnabled bool
TLSConfig *tls.Config
DataDir string // 数据目录
session *yamux.Session
rules []protocol.ProxyRule
mu sync.RWMutex
logger *Logger // 日志收集器
DataDir string
features PlatformFeatures
reconnectDelay time.Duration
reconnectMaxDelay time.Duration
session *yamux.Session
rules []protocol.ProxyRule
mu sync.RWMutex
logger *Logger
}
// NewClient 创建客户端
// NewClient creates a client with default desktop options.
func NewClient(serverAddr, token string) *Client {
// 默认数据目录:优先使用用户主目录,失败时回退到当前工作目录
var dataDir string
if home, err := os.UserHomeDir(); err == nil && home != "" {
dataDir = filepath.Join(home, ".gotunnel")
} else {
// UserHomeDir 失败(如 Android adb shell 环境),使用当前工作目录
if cwd, err := os.Getwd(); err == nil {
dataDir = filepath.Join(cwd, ".gotunnel")
log.Printf("[Client] UserHomeDir unavailable, using current directory: %s", dataDir)
} else {
// 最后回退到相对路径
dataDir = ".gotunnel"
log.Printf("[Client] Warning: using relative path for data directory")
}
}
return NewClientWithOptions(serverAddr, token, ClientOptions{})
}
// 确保数据目录存在
// NewClientWithOptions creates a client with explicit runtime options.
func NewClientWithOptions(serverAddr, token string, opts ClientOptions) *Client {
dataDir := resolveDataDir(opts.DataDir)
if err := os.MkdirAll(dataDir, 0755); err != nil {
log.Printf("Failed to create data dir: %v", err)
}
// ID 优先级:命令行参数 > 机器ID
id := getMachineID()
// 获取主机名作为客户端名称
hostname, _ := os.Hostname()
// 初始化日志收集器
logger, err := NewLogger(dataDir)
if err != nil {
log.Printf("Failed to initialize logger: %v", err)
}
features := DefaultPlatformFeatures()
if opts.Features != nil {
features = *opts.Features
}
delay := opts.ReconnectDelay
if delay <= 0 {
delay = reconnectDelay
}
maxDelay := opts.ReconnectMaxDelay
if maxDelay <= 0 {
maxDelay = maxReconnectDelay
}
if maxDelay < delay {
maxDelay = delay
}
return &Client{
ServerAddr: serverAddr,
Token: token,
ID: id,
Name: hostname,
DataDir: dataDir,
logger: logger,
ServerAddr: serverAddr,
Token: token,
ID: resolveClientID(dataDir, opts.ClientID),
Name: resolveClientName(opts.ClientName),
DataDir: dataDir,
features: features,
reconnectDelay: delay,
reconnectMaxDelay: maxDelay,
logger: logger,
}
}
// InitVersionStore 初始化版本存储
// InitVersionStore is kept for compatibility with older callers.
func (c *Client) InitVersionStore() error {
return nil
}
// logf 安全地记录日志(同时输出到标准日志和日志收集器)
func (c *Client) logf(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
log.Print(msg)
@@ -106,7 +114,6 @@ func (c *Client) logf(format string, args ...interface{}) {
}
}
// logErrorf 安全地记录错误日志
func (c *Client) logErrorf(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
log.Print(msg)
@@ -115,7 +122,6 @@ func (c *Client) logErrorf(format string, args ...interface{}) {
}
}
// logWarnf 安全地记录警告日志
func (c *Client) logWarnf(format string, args ...interface{}) {
msg := fmt.Sprintf(format, args...)
log.Print(msg)
@@ -124,32 +130,82 @@ func (c *Client) logWarnf(format string, args ...interface{}) {
}
}
// Run 启动客户端(带断线重连)
// Run starts the reconnect loop until the process exits.
func (c *Client) Run() error {
return c.RunContext(context.Background())
}
// RunContext starts the reconnect loop and exits when ctx is cancelled.
func (c *Client) RunContext(ctx context.Context) error {
backoff := c.reconnectDelay
for {
if err := c.connect(); err != nil {
if ctx.Err() != nil {
return nil
}
if err := c.connect(ctx); err != nil {
if ctx.Err() != nil {
return nil
}
c.logErrorf("Connect error: %v", err)
c.logf("Reconnecting in %v...", reconnectDelay)
time.Sleep(reconnectDelay)
c.logf("Reconnecting in %v...", backoff)
if !sleepWithContext(ctx, backoff) {
return nil
}
backoff *= 2
if backoff > c.reconnectMaxDelay {
backoff = c.reconnectMaxDelay
}
continue
}
c.handleSession()
backoff = c.reconnectDelay
c.handleSession(ctx)
if ctx.Err() != nil {
return nil
}
c.logWarnf("Disconnected, reconnecting...")
time.Sleep(disconnectDelay)
if !sleepWithContext(ctx, disconnectDelay) {
return nil
}
}
}
// connect 连接到服务端并认证
func (c *Client) connect() error {
func sleepWithContext(ctx context.Context, wait time.Duration) bool {
timer := time.NewTimer(wait)
defer timer.Stop()
select {
case <-ctx.Done():
return false
case <-timer.C:
return true
}
}
func (c *Client) connect(ctx context.Context) error {
var conn net.Conn
var err error
dialer := &net.Dialer{
Timeout: dialTimeout,
KeepAlive: tcpKeepAlive,
}
if c.TLSEnabled && c.TLSConfig != nil {
dialer := &net.Dialer{Timeout: dialTimeout}
conn, err = tls.DialWithDialer(dialer, "tcp", c.ServerAddr, c.TLSConfig)
rawConn, dialErr := dialer.DialContext(ctx, "tcp", c.ServerAddr)
if dialErr != nil {
return dialErr
}
tlsConn := tls.Client(rawConn, c.TLSConfig)
if handshakeErr := tlsConn.HandshakeContext(ctx); handshakeErr != nil {
rawConn.Close()
return handshakeErr
}
conn = tlsConn
} else {
conn, err = net.DialTimeout("tcp", c.ServerAddr, dialTimeout)
conn, err = dialer.DialContext(ctx, "tcp", c.ServerAddr)
}
if err != nil {
return err
@@ -184,8 +240,6 @@ func (c *Client) connect() error {
conn.Close()
return fmt.Errorf("auth failed: %s", authResp.Message)
}
// 如果服务端分配了新 ID则更新
if authResp.ClientID != "" && authResp.ClientID != c.ID {
conn.Close()
return fmt.Errorf("server returned unexpected client id: %s", authResp.ClientID)
@@ -206,12 +260,31 @@ func (c *Client) connect() error {
return nil
}
// handleSession 处理会话
func (c *Client) handleSession() {
defer c.session.Close()
func (c *Client) currentSession() *yamux.Session {
c.mu.RLock()
defer c.mu.RUnlock()
return c.session
}
func (c *Client) handleSession(ctx context.Context) {
session := c.currentSession()
if session == nil {
return
}
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
session.Close()
case <-done:
}
}()
defer close(done)
defer session.Close()
for {
stream, err := c.session.Accept()
stream, err := session.Accept()
if err != nil {
return
}
@@ -219,7 +292,6 @@ func (c *Client) handleSession() {
}
}
// handleStream 处理流
func (c *Client) handleStream(stream net.Conn) {
msg, err := protocol.ReadMessage(stream)
if err != nil {
@@ -254,10 +326,11 @@ func (c *Client) handleStream(stream net.Conn) {
c.handleScreenshotRequest(stream, msg)
case protocol.MsgTypeShellExecuteRequest:
c.handleShellExecuteRequest(stream, msg)
default:
stream.Close()
}
}
// handleProxyConfig 处理代理配置
func (c *Client) handleProxyConfig(stream net.Conn, msg *protocol.Message) {
defer stream.Close()
@@ -276,12 +349,10 @@ func (c *Client) handleProxyConfig(stream net.Conn, msg *protocol.Message) {
c.logf(" %s: %s:%d", r.Name, r.LocalIP, r.LocalPort)
}
// 发送配置确认
ack := &protocol.Message{Type: protocol.MsgTypeProxyReady}
protocol.WriteMessage(stream, ack)
}
// handleNewProxy 处理新代理请求
func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) {
var req protocol.NewProxyRequest
if err := msg.ParsePayload(&req); err != nil {
@@ -291,9 +362,9 @@ func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) {
var rule *protocol.ProxyRule
c.mu.RLock()
for _, r := range c.rules {
if r.RemotePort == req.RemotePort {
rule = &r
for i := range c.rules {
if c.rules[i].RemotePort == req.RemotePort {
rule = &c.rules[i]
break
}
}
@@ -314,13 +385,11 @@ func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) {
relay.Relay(stream, localConn)
}
// handleHeartbeat 处理心跳
func (c *Client) handleHeartbeat(stream net.Conn) {
msg := &protocol.Message{Type: protocol.MsgTypeHeartbeatAck}
protocol.WriteMessage(stream, msg)
}
// handleProxyConnect 处理代理连接请求 (SOCKS5/HTTP)
func (c *Client) handleProxyConnect(stream net.Conn, msg *protocol.Message) {
defer stream.Close()
@@ -330,7 +399,6 @@ func (c *Client) handleProxyConnect(stream net.Conn, msg *protocol.Message) {
return
}
// 连接目标地址
targetConn, err := net.DialTimeout("tcp", req.Target, dialTimeout)
if err != nil {
c.sendProxyResult(stream, false, err.Error())
@@ -338,23 +406,19 @@ func (c *Client) handleProxyConnect(stream net.Conn, msg *protocol.Message) {
}
defer targetConn.Close()
// 发送成功响应
if err := c.sendProxyResult(stream, true, ""); err != nil {
return
}
// 双向转发数据
relay.Relay(stream, targetConn)
}
// sendProxyResult 发送代理连接结果
func (c *Client) sendProxyResult(stream net.Conn, success bool, message string) error {
result := protocol.ProxyConnectResult{Success: success, Message: message}
msg, _ := protocol.NewMessage(protocol.MsgTypeProxyResult, result)
return protocol.WriteMessage(stream, msg)
}
// handleUDPData 处理 UDP 数据
func (c *Client) handleUDPData(stream net.Conn, msg *protocol.Message) {
defer stream.Close()
@@ -363,13 +427,11 @@ func (c *Client) handleUDPData(stream net.Conn, msg *protocol.Message) {
return
}
// 查找对应的规则
rule := c.findRuleByPort(packet.RemotePort)
if rule == nil {
return
}
// 连接本地 UDP 服务
target := fmt.Sprintf("%s:%d", rule.LocalIP, rule.LocalPort)
conn, err := net.DialTimeout("udp", target, localDialTimeout)
if err != nil {
@@ -377,20 +439,17 @@ func (c *Client) handleUDPData(stream net.Conn, msg *protocol.Message) {
}
defer conn.Close()
// 发送数据到本地服务
conn.SetDeadline(time.Now().Add(udpTimeout))
if _, err := conn.Write(packet.Data); err != nil {
return
}
// 读取响应
buf := make([]byte, udpBufferSize)
n, err := conn.Read(buf)
if err != nil {
return
}
// 发送响应回服务端
respPacket := protocol.UDPPacket{
RemotePort: packet.RemotePort,
ClientAddr: packet.ClientAddr,
@@ -400,7 +459,6 @@ func (c *Client) handleUDPData(stream net.Conn, msg *protocol.Message) {
protocol.WriteMessage(stream, respMsg)
}
// findRuleByPort 根据端口查找规则
func (c *Client) findRuleByPort(port int) *protocol.ProxyRule {
c.mu.RLock()
defer c.mu.RUnlock()
@@ -413,7 +471,6 @@ func (c *Client) findRuleByPort(port int) *protocol.ProxyRule {
return nil
}
// handleClientRestart 处理客户端重启请求
func (c *Client) handleClientRestart(stream net.Conn, msg *protocol.Message) {
defer stream.Close()
@@ -422,7 +479,6 @@ func (c *Client) handleClientRestart(stream net.Conn, msg *protocol.Message) {
c.logf("Restart requested: %s", req.Reason)
// 发送响应
resp := protocol.ClientRestartResponse{
Success: true,
Message: "restarting",
@@ -430,17 +486,19 @@ func (c *Client) handleClientRestart(stream net.Conn, msg *protocol.Message) {
respMsg, _ := protocol.NewMessage(protocol.MsgTypeClientRestart, resp)
protocol.WriteMessage(stream, respMsg)
// 停止所有运行中的插件
// 关闭会话(会触发重连)
if c.session != nil {
c.session.Close()
if session := c.currentSession(); session != nil {
session.Close()
}
}
// handleUpdateDownload 处理更新下载请求
func (c *Client) handleUpdateDownload(stream net.Conn, msg *protocol.Message) {
defer stream.Close()
if !c.features.AllowSelfUpdate {
c.sendUpdateResult(stream, false, "self-update not supported on this platform")
return
}
var req protocol.UpdateDownloadRequest
if err := msg.ParsePayload(&req); err != nil {
c.logErrorf("Parse update request error: %v", err)
@@ -450,7 +508,6 @@ func (c *Client) handleUpdateDownload(stream net.Conn, msg *protocol.Message) {
c.logf("Update download requested: %s", req.DownloadURL)
// 异步执行更新
go func() {
if err := c.performSelfUpdate(req.DownloadURL); err != nil {
c.logErrorf("Update failed: %v", err)
@@ -460,7 +517,6 @@ func (c *Client) handleUpdateDownload(stream net.Conn, msg *protocol.Message) {
c.sendUpdateResult(stream, true, "update started")
}
// sendUpdateResult 发送更新结果
func (c *Client) sendUpdateResult(stream net.Conn, success bool, message string) {
result := protocol.UpdateResultResponse{
Success: success,
@@ -470,11 +526,13 @@ func (c *Client) sendUpdateResult(stream net.Conn, success bool, message string)
protocol.WriteMessage(stream, msg)
}
// performSelfUpdate 执行自更新
func (c *Client) performSelfUpdate(downloadURL string) error {
if runtime.GOOS == "android" {
return fmt.Errorf("self-update must be handled by the Android host app")
}
c.logf("Starting self-update from: %s", downloadURL)
// 获取当前可执行文件路径
currentPath, err := os.Executable()
if err != nil {
c.logErrorf("Update failed: cannot get executable path: %v", err)
@@ -482,17 +540,12 @@ func (c *Client) performSelfUpdate(downloadURL string) error {
}
currentPath, _ = filepath.EvalSymlinks(currentPath)
// 预检查:验证是否有写权限(在下载前检查,避免浪费带宽)
// Windows 跳过预检查,因为 Windows 更新通过 batch 脚本以提升权限执行
// 非 Windows原始路径 → DataDir → 临时目录,逐级回退
fallbackDir := ""
if runtime.GOOS != "windows" {
if err := c.checkUpdatePermissions(currentPath); err != nil {
// 尝试 DataDir
fallbackDir = c.DataDir
testFile := filepath.Join(fallbackDir, ".gotunnel_update_test")
if f, err := os.Create(testFile); err != nil {
// DataDir 也不可写,回退到临时目录
fallbackDir = os.TempDir()
c.logf("DataDir not writable, falling back to temp directory: %s", fallbackDir)
} else {
@@ -503,7 +556,6 @@ func (c *Client) performSelfUpdate(downloadURL string) error {
}
}
// 使用共享的下载和解压逻辑
c.logf("Downloading update package...")
binaryPath, cleanup, err := update.DownloadAndExtract(downloadURL, "client")
if err != nil {
@@ -512,12 +564,10 @@ func (c *Client) performSelfUpdate(downloadURL string) error {
}
defer cleanup()
// Windows 需要特殊处理
if runtime.GOOS == "windows" {
return performWindowsClientUpdate(binaryPath, currentPath, c.ServerAddr, c.Token)
}
// 确定目标路径
targetPath := currentPath
if fallbackDir != "" {
targetPath = filepath.Join(fallbackDir, filepath.Base(currentPath))
@@ -525,7 +575,6 @@ func (c *Client) performSelfUpdate(downloadURL string) error {
}
if fallbackDir == "" {
// 原地替换:备份 → 复制 → 清理
backupPath := currentPath + ".bak"
c.logf("Backing up current binary...")
@@ -549,7 +598,6 @@ func (c *Client) performSelfUpdate(downloadURL string) error {
os.Remove(backupPath)
} else {
// 回退路径:直接复制到回退目录
c.logf("Installing new binary to data directory...")
if err := update.CopyFile(binaryPath, targetPath); err != nil {
c.logErrorf("Update failed: cannot install new binary: %v", err)
@@ -563,15 +611,11 @@ func (c *Client) performSelfUpdate(downloadURL string) error {
}
c.logf("Update completed successfully, restarting...")
// 重启进程(从新路径启动)
restartClientProcess(targetPath, c.ServerAddr, c.Token)
return nil
}
// checkUpdatePermissions 检查是否有更新权限
func (c *Client) checkUpdatePermissions(execPath string) error {
// 检查可执行文件所在目录是否可写
dir := filepath.Dir(execPath)
testFile := filepath.Join(dir, ".gotunnel_update_test")
@@ -583,7 +627,6 @@ func (c *Client) checkUpdatePermissions(execPath string) error {
f.Close()
os.Remove(testFile)
// 检查可执行文件本身是否可写
f, err = os.OpenFile(execPath, os.O_WRONLY, 0)
if err != nil {
c.logErrorf("No write permission to executable: %s", execPath)
@@ -594,9 +637,7 @@ func (c *Client) checkUpdatePermissions(execPath string) error {
return nil
}
// performWindowsClientUpdate Windows 平台更新
func performWindowsClientUpdate(newFile, currentPath, serverAddr, token string) error {
// 创建批处理脚本
args := fmt.Sprintf(`-s "%s" -t "%s"`, serverAddr, token)
batchScript := fmt.Sprintf(`@echo off
:: Check for admin rights, request UAC elevation if needed
@@ -622,12 +663,10 @@ del "%%~f0"
return fmt.Errorf("start batch: %w", err)
}
// 退出当前进程
os.Exit(0)
return nil
}
// restartClientProcess 重启客户端进程
func restartClientProcess(path, serverAddr, token string) {
args := []string{"-s", serverAddr, "-t", token}
@@ -638,7 +677,6 @@ func restartClientProcess(path, serverAddr, token string) {
os.Exit(0)
}
// handleLogRequest 处理日志请求
func (c *Client) handleLogRequest(stream net.Conn, msg *protocol.Message) {
if c.logger == nil {
stream.Close()
@@ -653,7 +691,6 @@ func (c *Client) handleLogRequest(stream net.Conn, msg *protocol.Message) {
c.logger.Printf("Log request received: session=%s, follow=%v", req.SessionID, req.Follow)
// 发送历史日志
entries := c.logger.GetRecentLogs(req.Lines, req.Level)
if len(entries) > 0 {
data := protocol.LogData{
@@ -668,20 +705,16 @@ func (c *Client) handleLogRequest(stream net.Conn, msg *protocol.Message) {
}
}
// 如果不需要持续推送,关闭流
if !req.Follow {
stream.Close()
return
}
// 订阅新日志
ch := c.logger.Subscribe(req.SessionID)
defer c.logger.Unsubscribe(req.SessionID)
defer stream.Close()
// 持续推送新日志
for entry := range ch {
// 应用级别过滤
if req.Level != "" && entry.Level != req.Level {
continue
}
@@ -698,7 +731,6 @@ func (c *Client) handleLogRequest(stream net.Conn, msg *protocol.Message) {
}
}
// handleLogStop 处理停止日志流请求
func (c *Client) handleLogStop(stream net.Conn, msg *protocol.Message) {
defer stream.Close()
@@ -714,13 +746,20 @@ func (c *Client) handleLogStop(stream net.Conn, msg *protocol.Message) {
c.logger.Unsubscribe(req.SessionID)
}
// handleSystemStatsRequest 处理系统状态请求
func (c *Client) handleSystemStatsRequest(stream net.Conn, msg *protocol.Message) {
defer stream.Close()
if !c.features.AllowSystemStats {
respMsg, _ := protocol.NewMessage(protocol.MsgTypeSystemStatsResponse, protocol.SystemStatsResponse{})
protocol.WriteMessage(stream, respMsg)
return
}
stats, err := utils.GetSystemStats()
if err != nil {
log.Printf("Failed to get system stats: %v", err)
respMsg, _ := protocol.NewMessage(protocol.MsgTypeSystemStatsResponse, protocol.SystemStatsResponse{})
protocol.WriteMessage(stream, respMsg)
return
}
@@ -738,14 +777,19 @@ func (c *Client) handleSystemStatsRequest(stream net.Conn, msg *protocol.Message
protocol.WriteMessage(stream, respMsg)
}
// handleScreenshotRequest 处理截图请求
func (c *Client) handleScreenshotRequest(stream net.Conn, msg *protocol.Message) {
defer stream.Close()
var req protocol.ScreenshotRequest
msg.ParsePayload(&req)
// 捕获截图
if !c.features.AllowScreenshot {
resp := protocol.ScreenshotResponse{Error: "screenshot not supported on this platform"}
respMsg, _ := protocol.NewMessage(protocol.MsgTypeScreenshotResponse, resp)
protocol.WriteMessage(stream, respMsg)
return
}
data, width, height, err := utils.CaptureScreenshot(req.Quality)
if err != nil {
c.logErrorf("Screenshot capture failed: %v", err)
@@ -755,9 +799,7 @@ func (c *Client) handleScreenshotRequest(stream net.Conn, msg *protocol.Message)
return
}
// 编码为 Base64
base64Data := base64.StdEncoding.EncodeToString(data)
resp := protocol.ScreenshotResponse{
Data: base64Data,
Width: width,
@@ -769,10 +811,16 @@ func (c *Client) handleScreenshotRequest(stream net.Conn, msg *protocol.Message)
protocol.WriteMessage(stream, respMsg)
}
// handleShellExecuteRequest 处理 Shell 执行请求
func (c *Client) handleShellExecuteRequest(stream net.Conn, msg *protocol.Message) {
defer stream.Close()
if !c.features.AllowShellExecute {
resp := protocol.ShellExecuteResponse{ExitCode: -1, Error: "remote shell execution not supported on this platform"}
respMsg, _ := protocol.NewMessage(protocol.MsgTypeShellExecuteResponse, resp)
protocol.WriteMessage(stream, respMsg)
return
}
var req protocol.ShellExecuteRequest
if err := msg.ParsePayload(&req); err != nil {
resp := protocol.ShellExecuteResponse{Error: err.Error(), ExitCode: -1}
@@ -781,7 +829,6 @@ func (c *Client) handleShellExecuteRequest(stream net.Conn, msg *protocol.Messag
return
}
// 设置默认超时
timeout := req.Timeout
if timeout <= 0 {
timeout = 30
@@ -789,7 +836,6 @@ func (c *Client) handleShellExecuteRequest(stream net.Conn, msg *protocol.Messag
c.logf("Executing shell command: %s", req.Command)
// 根据操作系统选择 shell
var cmd *exec.Cmd
if runtime.GOOS == "windows" {
cmd = exec.Command("cmd", "/C", req.Command)
@@ -797,12 +843,10 @@ func (c *Client) handleShellExecuteRequest(stream net.Conn, msg *protocol.Messag
cmd = exec.Command("sh", "-c", req.Command)
}
// 设置超时上下文
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
defer cancel()
cmd = exec.CommandContext(ctx, cmd.Path, cmd.Args[1:]...)
// 执行命令并获取输出
output, err := cmd.CombinedOutput()
exitCode := 0

View File

@@ -0,0 +1,90 @@
package tunnel
import (
"os"
"path/filepath"
"runtime"
"strings"
"github.com/google/uuid"
)
const clientIDFileName = "client.id"
func resolveDataDir(explicit string) string {
if explicit != "" {
return explicit
}
if envDir := strings.TrimSpace(os.Getenv("GOTUNNEL_DATA_DIR")); envDir != "" {
return envDir
}
if configDir, err := os.UserConfigDir(); err == nil && configDir != "" {
return filepath.Join(configDir, "gotunnel")
}
if home, err := os.UserHomeDir(); err == nil && home != "" {
return filepath.Join(home, ".gotunnel")
}
if cwd, err := os.Getwd(); err == nil && cwd != "" {
return filepath.Join(cwd, ".gotunnel")
}
return ".gotunnel"
}
func resolveClientName(explicit string) string {
if explicit != "" {
return explicit
}
if hostname, err := os.Hostname(); err == nil && hostname != "" {
return hostname
}
if runtime.GOOS == "android" {
return "android-device"
}
return "gotunnel-client"
}
func resolveClientID(dataDir, explicit string) string {
if explicit != "" {
_ = persistClientID(dataDir, explicit)
return explicit
}
if id := loadClientID(dataDir); id != "" {
return id
}
if id := getMachineID(); id != "" {
_ = persistClientID(dataDir, id)
return id
}
id := strings.ReplaceAll(uuid.NewString(), "-", "")[:16]
_ = persistClientID(dataDir, id)
return id
}
func loadClientID(dataDir string) string {
data, err := os.ReadFile(filepath.Join(dataDir, clientIDFileName))
if err != nil {
return ""
}
return strings.TrimSpace(string(data))
}
func persistClientID(dataDir, id string) error {
if id == "" {
return nil
}
if err := os.MkdirAll(dataDir, 0755); err != nil {
return err
}
return os.WriteFile(filepath.Join(dataDir, clientIDFileName), []byte(id+"\n"), 0600)
}

View File

@@ -14,11 +14,15 @@ import (
// getMachineID builds a stable fingerprint from multiple host identifiers
// and hashes the combined result into the client ID we expose externally.
func getMachineID() string {
return hashID(strings.Join(collectMachineIDParts(), "|"))
parts := collectMachineIDParts()
if len(parts) == 0 {
return ""
}
return hashID(strings.Join(parts, "|"))
}
func collectMachineIDParts() []string {
parts := []string{"os=" + runtime.GOOS, "arch=" + runtime.GOARCH}
parts := make([]string, 0, 6)
if id := getSystemMachineID(); id != "" {
parts = append(parts, "system="+id)
@@ -36,6 +40,11 @@ func collectMachineIDParts() []string {
parts = append(parts, "ifaces="+strings.Join(names, ","))
}
if len(parts) == 0 {
return nil
}
parts = append(parts, "os="+runtime.GOOS, "arch="+runtime.GOARCH)
return parts
}
@@ -47,6 +56,8 @@ func getSystemMachineID() string {
return getDarwinMachineID()
case "windows":
return getWindowsMachineID()
case "android":
return ""
default:
return ""
}

View File

@@ -0,0 +1,41 @@
package tunnel
import "time"
// PlatformFeatures controls which platform-specific capabilities the client may use.
type PlatformFeatures struct {
AllowSelfUpdate bool
AllowScreenshot bool
AllowShellExecute bool
AllowSystemStats bool
}
// ClientOptions controls optional client runtime settings.
type ClientOptions struct {
DataDir string
ClientID string
ClientName string
Features *PlatformFeatures
ReconnectDelay time.Duration
ReconnectMaxDelay time.Duration
}
// DefaultPlatformFeatures enables the desktop feature set.
func DefaultPlatformFeatures() PlatformFeatures {
return PlatformFeatures{
AllowSelfUpdate: true,
AllowScreenshot: true,
AllowShellExecute: true,
AllowSystemStats: true,
}
}
// MobilePlatformFeatures disables capabilities that are unsuitable for a mobile sandbox.
func MobilePlatformFeatures() PlatformFeatures {
return PlatformFeatures{
AllowSelfUpdate: false,
AllowScreenshot: false,
AllowShellExecute: false,
AllowSystemStats: true,
}
}

View File

@@ -3,7 +3,10 @@ package handler
import (
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
@@ -14,6 +17,11 @@ type InstallHandler struct {
app AppInterface
}
const (
installTokenHeader = "X-GoTunnel-Install-Token"
installTokenTTL = 3600
)
func NewInstallHandler(app AppInterface) *InstallHandler {
return &InstallHandler{app: app}
}
@@ -61,7 +69,115 @@ func (h *InstallHandler) GenerateInstallCommand(c *gin.Context) {
c.JSON(http.StatusOK, InstallCommandResponse{
Token: token,
ExpiresAt: now + 3600,
ExpiresAt: now + installTokenTTL,
TunnelPort: h.app.GetServer().GetBindPort(),
})
}
func (h *InstallHandler) ServeShellScript(c *gin.Context) {
if !h.validateInstallToken(c) {
return
}
applyInstallSecurityHeaders(c)
c.Header("Content-Type", "text/x-shellscript; charset=utf-8")
c.String(http.StatusOK, shellInstallScript)
}
func (h *InstallHandler) ServePowerShellScript(c *gin.Context) {
if !h.validateInstallToken(c) {
return
}
applyInstallSecurityHeaders(c)
c.Header("Content-Type", "text/plain; charset=utf-8")
c.String(http.StatusOK, powerShellInstallScript)
}
func (h *InstallHandler) DownloadClient(c *gin.Context) {
if !h.validateInstallToken(c) {
return
}
osName := c.Query("os")
arch := c.Query("arch")
updateInfo, err := checkClientUpdateForPlatform(osName, arch)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to resolve client package"})
return
}
if updateInfo.DownloadURL == "" {
c.JSON(http.StatusNotFound, gin.H{"error": "no client package found for this platform"})
return
}
req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, updateInfo.DownloadURL, nil)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to create download request"})
return
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
c.JSON(http.StatusBadGateway, gin.H{"error": "failed to download client package"})
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
c.JSON(http.StatusBadGateway, gin.H{"error": fmt.Sprintf("upstream returned %s", resp.Status)})
return
}
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/octet-stream"
}
applyInstallSecurityHeaders(c)
c.Header("Content-Type", contentType)
if contentLength := resp.Header.Get("Content-Length"); contentLength != "" {
c.Header("Content-Length", contentLength)
}
if updateInfo.AssetName != "" {
c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, updateInfo.AssetName))
}
c.Status(http.StatusOK)
_, _ = io.Copy(c.Writer, resp.Body)
}
func (h *InstallHandler) validateInstallToken(c *gin.Context) bool {
token := strings.TrimSpace(c.GetHeader(installTokenHeader))
if token == "" {
c.AbortWithStatus(http.StatusNotFound)
return false
}
store, ok := h.app.GetClientStore().(db.InstallTokenStore)
if !ok {
c.AbortWithStatus(http.StatusNotFound)
return false
}
installToken, err := store.GetInstallToken(token)
if err != nil {
c.AbortWithStatus(http.StatusNotFound)
return false
}
if installToken.Used || time.Now().Unix()-installToken.CreatedAt >= installTokenTTL {
c.AbortWithStatus(http.StatusNotFound)
return false
}
return true
}
func applyInstallSecurityHeaders(c *gin.Context) {
c.Header("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0")
c.Header("Pragma", "no-cache")
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Robots-Tag", "noindex, nofollow, noarchive")
}

View File

@@ -0,0 +1,185 @@
package handler
const shellInstallScript = `#!/usr/bin/env bash
set -euo pipefail
usage() {
cat <<'EOF'
Usage: bash install.sh -s <server:port> -t <token> -b <web-base-url>
Options:
-s Tunnel server address, for example 10.0.0.2:7000
-t One-time install token generated by the server
-b Web console base URL, for example https://example.com:7500
EOF
}
require_cmd() {
if ! command -v "$1" >/dev/null 2>&1; then
echo "missing required command: $1" >&2
exit 1
fi
}
detect_os() {
case "$(uname -s)" in
Linux) echo "linux" ;;
Darwin) echo "darwin" ;;
*)
echo "unsupported operating system: $(uname -s)" >&2
exit 1
;;
esac
}
detect_arch() {
case "$(uname -m)" in
x86_64|amd64) echo "amd64" ;;
aarch64|arm64) echo "arm64" ;;
i386|i686) echo "386" ;;
armv7l|armv6l|arm) echo "arm" ;;
*)
echo "unsupported architecture: $(uname -m)" >&2
exit 1
;;
esac
}
SERVER_ADDR=""
INSTALL_TOKEN=""
BASE_URL=""
while getopts ":s:t:b:h" opt; do
case "$opt" in
s) SERVER_ADDR="$OPTARG" ;;
t) INSTALL_TOKEN="$OPTARG" ;;
b) BASE_URL="$OPTARG" ;;
h)
usage
exit 0
;;
:)
echo "option -$OPTARG requires a value" >&2
usage
exit 1
;;
\?)
echo "unknown option: -$OPTARG" >&2
usage
exit 1
;;
esac
done
if [[ -z "$SERVER_ADDR" || -z "$INSTALL_TOKEN" || -z "$BASE_URL" ]]; then
usage
exit 1
fi
require_cmd curl
require_cmd tar
require_cmd mktemp
OS_NAME="$(detect_os)"
ARCH_NAME="$(detect_arch)"
BASE_URL="${BASE_URL%/}"
INSTALL_ROOT="${HOME:-$(pwd)}/.gotunnel"
BIN_DIR="$INSTALL_ROOT/bin"
TARGET_BIN="$BIN_DIR/gotunnel-client"
LOG_FILE="$INSTALL_ROOT/client.log"
PID_FILE="$INSTALL_ROOT/client.pid"
TMP_DIR="$(mktemp -d)"
ARCHIVE_PATH="$TMP_DIR/gotunnel-client.tar.gz"
DOWNLOAD_URL="$BASE_URL/install/client?os=$OS_NAME&arch=$ARCH_NAME"
cleanup() {
rm -rf "$TMP_DIR"
}
trap cleanup EXIT
mkdir -p "$BIN_DIR"
echo "Downloading GoTunnel client from $DOWNLOAD_URL"
curl -fsSL -H "X-GoTunnel-Install-Token: $INSTALL_TOKEN" "$DOWNLOAD_URL" -o "$ARCHIVE_PATH"
tar -xzf "$ARCHIVE_PATH" -C "$TMP_DIR"
EXTRACTED_BIN="$(find "$TMP_DIR" -type f -name 'gotunnel-client*' ! -name '*.tar.gz' ! -name '*.zip' | head -n 1)"
if [[ -z "$EXTRACTED_BIN" ]]; then
echo "failed to find extracted client binary" >&2
exit 1
fi
cp "$EXTRACTED_BIN" "$TARGET_BIN"
chmod 0755 "$TARGET_BIN"
if [[ -f "$PID_FILE" ]]; then
OLD_PID="$(cat "$PID_FILE" 2>/dev/null || true)"
if [[ -n "$OLD_PID" ]]; then
kill "$OLD_PID" >/dev/null 2>&1 || true
fi
fi
nohup "$TARGET_BIN" -s "$SERVER_ADDR" -t "$INSTALL_TOKEN" >>"$LOG_FILE" 2>&1 &
NEW_PID=$!
echo "$NEW_PID" >"$PID_FILE"
echo "GoTunnel client installed to $TARGET_BIN"
echo "Client started in background with PID $NEW_PID"
echo "Logs: $LOG_FILE"
`
const powerShellInstallScript = `function Get-GoTunnelArch {
switch ([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture.ToString().ToLowerInvariant()) {
'x64' { return 'amd64' }
'arm64' { return 'arm64' }
'x86' { return '386' }
default { throw "Unsupported architecture: $([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture)" }
}
}
function Install-GoTunnel {
param(
[Parameter(Mandatory = $true)][string]$Server,
[Parameter(Mandatory = $true)][string]$Token,
[Parameter(Mandatory = $true)][string]$BaseUrl
)
$BaseUrl = $BaseUrl.TrimEnd('/')
$Arch = Get-GoTunnelArch
$InstallRoot = Join-Path $env:LOCALAPPDATA 'GoTunnel'
$ExtractDir = Join-Path $InstallRoot 'extract'
$ArchivePath = Join-Path $InstallRoot 'gotunnel-client.zip'
$TargetPath = Join-Path $InstallRoot 'gotunnel-client.exe'
$DownloadUrl = "$BaseUrl/install/client?os=windows&arch=$Arch"
New-Item -ItemType Directory -Force -Path $InstallRoot | Out-Null
Write-Host "Downloading GoTunnel client from $DownloadUrl"
$Headers = @{ 'X-GoTunnel-Install-Token' = $Token }
Invoke-WebRequest -Uri $DownloadUrl -Headers $Headers -OutFile $ArchivePath -MaximumRedirection 5
if (Test-Path $ExtractDir) {
Remove-Item -Path $ExtractDir -Recurse -Force
}
Expand-Archive -Path $ArchivePath -DestinationPath $ExtractDir -Force
$Binary = Get-ChildItem -Path $ExtractDir -Recurse -File |
Where-Object { $_.Name -eq 'gotunnel-client.exe' } |
Select-Object -First 1
if (-not $Binary) {
throw 'Failed to find extracted client binary.'
}
Copy-Item -Path $Binary.FullName -Destination $TargetPath -Force
Get-Process |
Where-Object { $_.Path -eq $TargetPath } |
Stop-Process -Force -ErrorAction SilentlyContinue
Start-Process -FilePath $TargetPath -ArgumentList @('-s', $Server, '-t', $Token) -WindowStyle Hidden
Write-Host "GoTunnel client installed to $TargetPath"
Write-Host 'Client started in background.'
}
`

View File

@@ -48,6 +48,11 @@ func (r *GinRouter) SetupRoutes(app handler.AppInterface, jwtAuth *auth.JWTAuth,
engine.POST("/api/auth/login", authHandler.Login)
engine.GET("/api/auth/check", authHandler.Check)
installHandler := handler.NewInstallHandler(app)
engine.GET("/install.sh", installHandler.ServeShellScript)
engine.GET("/install.ps1", installHandler.ServePowerShellScript)
engine.GET("/install/client", installHandler.DownloadClient)
// API 路由 (需要 JWT)
api := engine.Group("/api")
api.Use(middleware.JWTAuth(jwtAuth))
@@ -94,7 +99,6 @@ func (r *GinRouter) SetupRoutes(app handler.AppInterface, jwtAuth *auth.JWTAuth,
api.GET("/traffic/hourly", trafficHandler.GetHourly)
// 安装命令生成
installHandler := handler.NewInstallHandler(app)
api.POST("/install/generate", installHandler.GenerateInstallCommand)
}
}