Files
GoTunnel/internal/server/tunnel/server.go
Flik e4999abf47
Some checks failed
Build Multi-Platform Binaries / build-frontend (push) Successful in 34s
Build Multi-Platform Binaries / build-binaries (amd64, linux, client, true) (push) Successful in 1m20s
Build Multi-Platform Binaries / build-binaries (amd64, darwin, server, false) (push) Successful in 1m33s
Build Multi-Platform Binaries / build-binaries (amd64, windows, client, true) (push) Successful in 1m16s
Build Multi-Platform Binaries / build-binaries (amd64, linux, server, true) (push) Successful in 1m48s
Build Multi-Platform Binaries / build-binaries (arm, 7, linux, client, true) (push) Successful in 1m7s
Build Multi-Platform Binaries / build-binaries (amd64, windows, server, true) (push) Successful in 1m46s
Build Multi-Platform Binaries / build-binaries (arm64, darwin, server, false) (push) Successful in 1m31s
Build Multi-Platform Binaries / build-binaries (arm, 7, linux, server, true) (push) Successful in 1m58s
Build Multi-Platform Binaries / build-binaries (arm64, linux, client, true) (push) Successful in 1m35s
Build Multi-Platform Binaries / build-binaries (arm64, linux, server, true) (push) Has been cancelled
Build Multi-Platform Binaries / build-binaries (arm64, windows, server, false) (push) Has been cancelled
Remove manual client ID and TLS CLI options
2026-03-19 19:32:57 +08:00

1250 lines
30 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package tunnel
import (
	"crypto/subtle"
	"crypto/tls"
	"encoding/base64"
	"fmt"
	"log"
	"net"
	"regexp"
	"strings"
	"sync"
	"time"

	"github.com/gotunnel/internal/server/db"
	"github.com/gotunnel/pkg/protocol"
	"github.com/gotunnel/pkg/proxy"
	"github.com/gotunnel/pkg/relay"
	"github.com/gotunnel/pkg/security"
	"github.com/gotunnel/pkg/utils"
	"github.com/hashicorp/yamux"
)
// Server-side constants.
const (
authTimeout = 10 * time.Second // read deadline applied while waiting for a client's auth message
heartbeatTimeout = 10 * time.Second // read/write deadline set on each heartbeat stream
udpBufferSize = 65535 // size of the receive buffer allocated per UDP listener
maxConnections = 10000 // maximum number of connections; capacity of the connSem semaphore
)
// clientIDRegex matches a well-formed client ID: 1 to 64 characters drawn
// from ASCII letters, digits, underscore and hyphen.
var clientIDRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]{1,64}$`)

// isValidClientID reports whether id is a well-formed client ID as defined
// by clientIDRegex.
func isValidClientID(id string) bool {
	ok := clientIDRegex.MatchString(id)
	return ok
}
// Server is the tunnel server. It accepts client connections, authenticates
// them, and keeps one ClientSession per connected client.
type Server struct {
clientStore db.ClientStore
trafficStore db.TrafficStore // traffic store; may be nil, in which case traffic is not recorded
bindAddr string
bindPort int
token string // shared token compared against the token in each auth request
heartbeat int // heartbeat interval, in seconds
hbTimeout int // heartbeat timeout, in seconds
portManager *utils.PortManager
clients map[string]*ClientSession // connected clients, keyed by client ID
mu sync.RWMutex // guards clients
tlsConfig *tls.Config // when non-nil, Run listens with TLS
connSem chan struct{} // connection-count semaphore
activeConns int64 // current number of active connections
listener net.Listener // main listener
shutdown chan struct{} // shutdown signal
wg sync.WaitGroup // waits for all connection handlers to finish
logSessions *LogSessionManager // log session manager
}
// ClientSession is the server-side state of one connected client.
type ClientSession struct {
ID string
Name string // client name (hostname)
RemoteAddr string // client IP address
OS string // client operating system
Arch string // client architecture
Version string // client version
Session *yamux.Session
Rules []protocol.ProxyRule
Listeners map[int]net.Listener // TCP listeners, keyed by remote port
UDPConns map[int]*net.UDPConn // UDP connections, keyed by remote port
LastPing time.Time // refreshed after each successful heartbeat round trip
mu sync.Mutex // guards Listeners, UDPConns and LastPing
}
// NewServer builds a Server bound to bindAddr:bindPort.
//
// cs is the client store, token is the shared auth token, and heartbeat and
// hbTimeout are the heartbeat interval and timeout in seconds. The returned
// server has a fresh port manager, an empty client table, a connection
// semaphore sized by maxConnections, an open shutdown channel and a new log
// session manager. TLS and traffic recording are off until SetTLSConfig and
// SetTrafficStore are called.
func NewServer(cs db.ClientStore, bindAddr string, bindPort int, token string, heartbeat, hbTimeout int) *Server {
	srv := new(Server)

	// Caller-supplied configuration.
	srv.clientStore = cs
	srv.bindAddr = bindAddr
	srv.bindPort = bindPort
	srv.token = token
	srv.heartbeat = heartbeat
	srv.hbTimeout = hbTimeout

	// Internal state.
	srv.portManager = utils.NewPortManager()
	srv.clients = make(map[string]*ClientSession)
	srv.connSem = make(chan struct{}, maxConnections)
	srv.shutdown = make(chan struct{})
	srv.logSessions = NewLogSessionManager()

	return srv
}
// SetTLSConfig sets the TLS configuration. Run listens with TLS when the
// stored configuration is non-nil and in plain TCP otherwise.
func (s *Server) SetTLSConfig(config *tls.Config) {
s.tlsConfig = config
}
// Shutdown gracefully stops the server.
//
// It signals the accept loop through the shutdown channel, closes the main
// listener, closes every registered client session, and then waits for all
// connection handlers tracked by the wait group to return. If they have not
// finished within timeout it returns a "shutdown timeout" error; otherwise
// it returns nil. Calling Shutdown more than once is safe.
func (s *Server) Shutdown(timeout time.Duration) error {
	log.Printf("[Server] Initiating graceful shutdown...")

	// Closing an already-closed channel panics, so a repeated call must skip
	// the close. s.mu serializes concurrent callers around the check.
	s.mu.Lock()
	select {
	case <-s.shutdown:
		// Already signalled by an earlier call.
	default:
		close(s.shutdown)
	}
	s.mu.Unlock()

	if s.listener != nil {
		// The error is ignored on purpose: closing a listener that is
		// already closed only reports that fact.
		s.listener.Close()
	}

	// Close all client sessions.
	s.mu.Lock()
	for _, cs := range s.clients {
		cs.Session.Close()
	}
	s.mu.Unlock()

	// Wait for all connection handlers to finish, bounded by timeout.
	done := make(chan struct{})
	go func() {
		s.wg.Wait()
		close(done)
	}()

	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case <-done:
		log.Printf("[Server] All connections closed gracefully")
		return nil
	case <-timer.C:
		log.Printf("[Server] Shutdown timeout, forcing close")
		return fmt.Errorf("shutdown timeout")
	}
}
// SetTrafficStore sets the traffic store used by recordTraffic. While it is
// nil, traffic counters are dropped.
func (s *Server) SetTrafficStore(store db.TrafficStore) {
s.trafficStore = store
}
// Run starts the server and blocks in the accept loop.
//
// It listens on bindAddr:bindPort, using TLS when a TLS configuration has
// been set, and handles every accepted connection in its own goroutine
// tracked by the server's wait group. It returns nil once the shutdown
// channel is closed, or an error wrapping the cause if listening fails.
//
// Accept errors that are not caused by shutdown are logged and retried with
// an exponential back-off (5ms doubling up to 1s) so that a persistent
// failure, such as file-descriptor exhaustion, does not spin the loop.
func (s *Server) Run() error {
	addr := fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort)

	var ln net.Listener
	var err error
	if s.tlsConfig != nil {
		ln, err = tls.Listen("tcp", addr, s.tlsConfig)
		if err != nil {
			return fmt.Errorf("failed to listen TLS on %s: %w", addr, err)
		}
		log.Printf("[Server] TLS listening on %s", addr)
	} else {
		ln, err = net.Listen("tcp", addr)
		if err != nil {
			return fmt.Errorf("failed to listen on %s: %w", addr, err)
		}
		log.Printf("[Server] Listening on %s (no TLS)", addr)
	}
	s.listener = ln

	const (
		minAcceptDelay = 5 * time.Millisecond
		maxAcceptDelay = time.Second
	)
	var acceptDelay time.Duration

	for {
		select {
		case <-s.shutdown:
			log.Printf("[Server] Shutdown signal received, stopping accept loop")
			ln.Close()
			return nil
		default:
		}

		conn, err := ln.Accept()
		if err != nil {
			select {
			case <-s.shutdown:
				// The listener was closed by Shutdown; not a real failure.
				return nil
			default:
			}
			log.Printf("[Server] Accept error: %v", err)

			// Back off before retrying, but stay responsive to shutdown.
			if acceptDelay == 0 {
				acceptDelay = minAcceptDelay
			} else {
				acceptDelay *= 2
			}
			if acceptDelay > maxAcceptDelay {
				acceptDelay = maxAcceptDelay
			}
			select {
			case <-s.shutdown:
				return nil
			case <-time.After(acceptDelay):
			}
			continue
		}
		acceptDelay = 0

		s.wg.Add(1)
		go func() {
			defer s.wg.Done()
			s.handleConnection(conn)
		}()
	}
}
// handleConnection authenticates a newly accepted connection and, on
// success, runs the client session on it until the session ends.
//
// The connection first takes a slot in connSem and is rejected immediately
// when none is free. It must then deliver one auth message within
// authTimeout. The token is accepted when it equals the server token or
// when the client store implements db.InstallTokenStore and holds an unused
// install token created less than 3600 seconds ago. Unknown client IDs are
// registered automatically; for known clients the nickname is filled in
// only when it is still empty, so a name set on the server side is never
// overwritten by the client. The connection is always closed on return.
func (s *Server) handleConnection(conn net.Conn) {
	clientIP := conn.RemoteAddr().String()

	// Enforce the connection limit without blocking.
	select {
	case s.connSem <- struct{}{}:
		defer func() { <-s.connSem }()
	default:
		security.LogConnRejected(clientIP, "max connections reached")
		conn.Close()
		return
	}
	defer conn.Close()

	conn.SetReadDeadline(time.Now().Add(authTimeout))
	msg, err := protocol.ReadMessage(conn)
	if err != nil {
		log.Printf("[Server] Read auth error: %v", err)
		return
	}
	if msg.Type != protocol.MsgTypeAuth {
		log.Printf("[Server] Expected auth, got %d", msg.Type)
		return
	}

	var authReq protocol.AuthRequest
	if err := msg.ParsePayload(&authReq); err != nil {
		log.Printf("[Server] Parse auth error: %v", err)
		return
	}

	// Validate the token: either the regular server token or a one-time
	// install token. The comparison is constant-time so that response timing
	// does not reveal how much of a guessed token matched.
	validToken := subtle.ConstantTimeCompare([]byte(authReq.Token), []byte(s.token)) == 1
	var isInstallToken bool
	if !validToken {
		if tokenStore, ok := s.clientStore.(db.InstallTokenStore); ok {
			if installToken, err := tokenStore.GetInstallToken(authReq.Token); err == nil {
				if !installToken.Used && time.Now().Unix()-installToken.CreatedAt < 3600 {
					// Token exists, is unused and has not expired.
					// NOTE(review): the install token is not checked
					// against the requested client ID here — confirm
					// whether it should be bound to one.
					validToken = true
					isInstallToken = true
				}
			}
		}
	}
	if !validToken {
		security.LogInvalidToken(clientIP)
		s.sendAuthResponse(conn, false, "invalid token", "")
		return
	}

	// Validate the client ID.
	clientID := authReq.ClientID
	if clientID == "" || !isValidClientID(clientID) {
		security.LogInvalidClientID(clientIP, clientID)
		s.sendAuthResponse(conn, false, "invalid client id format", "")
		return
	}

	// Create the client record when it does not exist yet. A lookup error
	// is treated like "not found", as before, but is no longer silent.
	exists, err := s.clientStore.ClientExists(clientID)
	if err != nil {
		log.Printf("[Server] Client lookup error for %s: %v", clientID, err)
	}
	if err != nil || !exists {
		newClient := &db.Client{ID: clientID, Nickname: authReq.Name, Rules: []protocol.ProxyRule{}}
		if err := s.clientStore.CreateClient(newClient); err != nil {
			log.Printf("[Server] Create client error: %v", err)
			s.sendAuthResponse(conn, false, "failed to create client", "")
			return
		}
		log.Printf("[Server] New client registered: %s (%s)", clientID, authReq.Name)
	} else if authReq.Name != "" {
		// Existing client: adopt the client-reported name only while the
		// nickname is empty.
		if client, err := s.clientStore.GetClient(clientID); err == nil {
			if client.Nickname == "" {
				client.Nickname = authReq.Name
				s.clientStore.UpdateClient(client)
			}
		}
	}

	rules, err := s.clientStore.GetClientRules(clientID)
	if err != nil {
		// Continue with an empty rule set, but leave a trace of the failure.
		log.Printf("[Server] Load rules error for %s: %v", clientID, err)
	}
	if rules == nil {
		rules = []protocol.ProxyRule{}
	}

	// An install token is single-use: mark it consumed.
	if isInstallToken {
		if tokenStore, ok := s.clientStore.(db.InstallTokenStore); ok {
			tokenStore.MarkTokenUsed(authReq.Token)
		}
	}

	// Authentication is done; lift the auth read deadline.
	conn.SetReadDeadline(time.Time{})
	if err := s.sendAuthResponse(conn, true, "ok", clientID); err != nil {
		return
	}
	security.LogAuthSuccess(clientIP, clientID)
	s.setupClientSession(conn, clientID, authReq.Name, authReq.OS, authReq.Arch, authReq.Version, rules)
}
// setupClientSession runs an authenticated client's session on conn.
//
// It wraps conn in a yamux server session, registers a ClientSession for
// clientID, starts the proxy listeners for rules, sends the resulting
// configuration (including port status) to the client, starts the heartbeat
// loop, and then blocks until the yamux session closes. The client is
// unregistered on return.
func (s *Server) setupClientSession(conn net.Conn, clientID, clientName, clientOS, clientArch, clientVersion string, rules []protocol.ProxyRule) {
	mux, err := yamux.Server(conn, nil)
	if err != nil {
		log.Printf("[Server] Yamux error: %v", err)
		return
	}

	// Keep only the host part of the peer address when it can be split.
	peer := conn.RemoteAddr().String()
	host, _, splitErr := net.SplitHostPort(peer)
	if splitErr == nil {
		peer = host
	}

	client := &ClientSession{
		ID:         clientID,
		Name:       clientName,
		RemoteAddr: peer,
		OS:         clientOS,
		Arch:       clientArch,
		Version:    clientVersion,
		Session:    mux,
		Rules:      rules,
		Listeners:  make(map[int]net.Listener),
		UDPConns:   make(map[int]*net.UDPConn),
		LastPing:   time.Now(),
	}

	s.registerClient(client)
	defer s.unregisterClient(client)

	// Starting the listeners fills in PortStatus on each rule, so the
	// configuration must be sent afterwards.
	s.startProxyListeners(client)
	if err = s.sendProxyConfig(mux, client.Rules); err != nil {
		log.Printf("[Server] Send config error: %v", err)
		return
	}

	go s.heartbeatLoop(client)

	<-mux.CloseChan()
	log.Printf("[Server] Client %s disconnected", clientID)
}
// sendAuthResponse writes an auth response carrying success, message and
// clientID to conn. It returns the error from building or writing the
// message, if any.
func (s *Server) sendAuthResponse(conn net.Conn, success bool, message, clientID string) error {
	payload := protocol.AuthResponse{
		Success:  success,
		Message:  message,
		ClientID: clientID,
	}

	reply, err := protocol.NewMessage(protocol.MsgTypeAuthResp, payload)
	if err != nil {
		return err
	}

	writeErr := protocol.WriteMessage(conn, reply)
	return writeErr
}
// sendProxyConfig sends rules to the client on a new stream and waits for
// the client's acknowledgement.
//
// It returns an error if the stream cannot be opened, the message cannot be
// built or written, the acknowledgement cannot be read, or the reply is not
// of type MsgTypeProxyReady. The whole exchange is bounded by a deadline so
// that a client which never acknowledges cannot block the caller forever.
func (s *Server) sendProxyConfig(session *yamux.Session, rules []protocol.ProxyRule) error {
	// Upper bound for writing the config and receiving the acknowledgement.
	const configAckTimeout = 30 * time.Second

	stream, err := session.Open()
	if err != nil {
		return err
	}
	defer stream.Close()

	stream.SetDeadline(time.Now().Add(configAckTimeout))

	cfg := protocol.ProxyConfig{Rules: rules}
	msg, err := protocol.NewMessage(protocol.MsgTypeProxyConfig, cfg)
	if err != nil {
		return err
	}
	if err := protocol.WriteMessage(stream, msg); err != nil {
		return err
	}

	// Wait for the client to confirm.
	ack, err := protocol.ReadMessage(stream)
	if err != nil {
		return fmt.Errorf("wait config ack: %w", err)
	}
	if ack.Type != protocol.MsgTypeProxyReady {
		return fmt.Errorf("unexpected ack type: %d", ack.Type)
	}
	return nil
}
// registerClient stores cs in the client table under its ID, replacing any
// entry already present for that ID.
func (s *Server) registerClient(cs *ClientSession) {
	s.mu.Lock()
	s.clients[cs.ID] = cs
	s.mu.Unlock()
}
// unregisterClient releases everything cs holds and removes it from the
// client table.
//
// All of the session's TCP listeners and UDP connections are closed and
// their ports are released in the port manager. The table entry is removed
// only when it still refers to cs: if the same client ID has reconnected in
// the meantime, the entry belongs to the newer session and must stay.
func (s *Server) unregisterClient(cs *ClientSession) {
	s.mu.Lock()
	defer s.mu.Unlock()

	cs.mu.Lock()
	for port, ln := range cs.Listeners {
		ln.Close()
		s.portManager.Release(port)
	}
	for port, conn := range cs.UDPConns {
		conn.Close()
		s.portManager.Release(port)
	}
	cs.mu.Unlock()

	if current, ok := s.clients[cs.ID]; ok && current == cs {
		delete(s.clients, cs.ID)
	}
}
// stopProxyListeners closes every TCP listener and UDP connection owned by
// cs, releases their ports in the port manager, and leaves cs with empty
// listener and UDP connection maps.
func (s *Server) stopProxyListeners(cs *ClientSession) {
	cs.mu.Lock()
	defer cs.mu.Unlock()

	// TCP listeners.
	for p, listener := range cs.Listeners {
		listener.Close()
		s.portManager.Release(p)
	}

	// UDP connections.
	for p, udp := range cs.UDPConns {
		udp.Close()
		s.portManager.Release(p)
	}

	cs.Listeners = make(map[int]net.Listener)
	cs.UDPConns = make(map[int]*net.UDPConn)
}
// startProxyListeners starts a listener for every enabled rule of cs and
// records the outcome in each rule's PortStatus ("listening" or
// "failed: <reason>").
//
// A rule with an empty type is treated as "tcp". UDP rules are delegated to
// startUDPListener. For every other type the remote port is reserved in the
// port manager, a TCP listener is opened on it, and an accept goroutine
// matching the rule type is started with a copy of the rule.
func (s *Server) startProxyListeners(cs *ClientSession) {
	for idx := range cs.Rules {
		r := &cs.Rules[idx]
		if !r.IsEnabled() {
			continue
		}

		kind := r.Type
		if kind == "" {
			kind = "tcp"
		}

		// UDP has its own setup path.
		if kind == "udp" {
			s.startUDPListener(cs, r)
			continue
		}

		// Everything else is TCP based.
		port := r.RemotePort
		if err := s.portManager.Reserve(port, cs.ID); err != nil {
			log.Printf("[Server] Port %d error: %v", port, err)
			r.PortStatus = "failed: " + err.Error()
			continue
		}

		listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
		if err != nil {
			log.Printf("[Server] Listen %d error: %v", port, err)
			// Give the reservation back; nothing is listening on it.
			s.portManager.Release(port)
			r.PortStatus = "failed: " + err.Error()
			continue
		}
		r.PortStatus = "listening"

		cs.mu.Lock()
		cs.Listeners[port] = listener
		cs.mu.Unlock()

		switch kind {
		case "socks5":
			log.Printf("[Server] SOCKS5 proxy %s on :%d", r.Name, port)
			go s.acceptProxyServerConns(cs, listener, *r)
		case "http", "https":
			log.Printf("[Server] HTTP proxy %s on :%d", r.Name, port)
			go s.acceptProxyServerConns(cs, listener, *r)
		case "websocket":
			log.Printf("[Server] Websocket proxy %s on :%d", r.Name, port)
			go s.acceptWebsocketConns(cs, listener, *r)
		default:
			log.Printf("[Server] TCP proxy %s: :%d -> %s:%d",
				r.Name, port, r.LocalIP, r.LocalPort)
			go s.acceptProxyConns(cs, listener, *r)
		}
	}
}
// acceptProxyConns accepts connections on ln and handles each one in its
// own goroutine via handleProxyConn. It returns on the first Accept error.
func (s *Server) acceptProxyConns(cs *ClientSession, ln net.Listener, rule protocol.ProxyRule) {
	for {
		inbound, acceptErr := ln.Accept()
		if acceptErr != nil {
			break
		}
		go s.handleProxyConn(cs, inbound, rule)
	}
}
// acceptProxyServerConns serves SOCKS5/HTTP proxy connections arriving on ln.
//
// It builds one proxy server for the rule's type that dials through the
// client's session and reports traffic to recordTraffic. Credentials are
// passed only when the rule has authentication enabled; otherwise both are
// empty. Each accepted connection is handled in its own goroutine, and the
// function returns on the first Accept error.
func (s *Server) acceptProxyServerConns(cs *ClientSession, ln net.Listener, rule protocol.ProxyRule) {
	tunnel := proxy.NewTunnelDialer(cs.Session)

	var user, pass string
	if rule.AuthEnabled {
		user, pass = rule.AuthUsername, rule.AuthPassword
	}

	// Built-in proxy implementation, with traffic accounting and auth.
	srv := proxy.NewServer(rule.Type, tunnel, s.recordTraffic, user, pass)

	for {
		inbound, acceptErr := ln.Accept()
		if acceptErr != nil {
			break
		}
		go srv.HandleConn(inbound)
	}
}
// handleProxyConn relays one external connection to the client.
//
// It opens a stream on the client's session, announces the rule's remote
// port with a MsgTypeNewProxy message, and then relays data between conn and
// the stream while reporting traffic to recordTraffic. Both conn and the
// stream are closed on return.
func (s *Server) handleProxyConn(cs *ClientSession, conn net.Conn, rule protocol.ProxyRule) {
	defer conn.Close()

	stream, err := cs.Session.Open()
	if err != nil {
		log.Printf("[Server] Open stream error: %v", err)
		return
	}
	defer stream.Close()

	req := protocol.NewProxyRequest{RemotePort: rule.RemotePort}
	msg, err := protocol.NewMessage(protocol.MsgTypeNewProxy, req)
	if err != nil {
		// Never hand a message that failed to build to WriteMessage.
		log.Printf("[Server] Build proxy request error: %v", err)
		return
	}
	if err := protocol.WriteMessage(stream, msg); err != nil {
		return
	}

	relay.RelayWithStats(conn, stream, s.recordTraffic)
}
// heartbeatLoop periodically checks that the client behind cs is alive.
//
// On every tick it closes the session and returns if the last successful
// heartbeat is older than the timeout; otherwise it sends a heartbeat and,
// when the client answers, refreshes LastPing. The loop also ends when the
// session closes.
//
// time.NewTicker panics on a non-positive interval, so a configured
// heartbeat of zero or less falls back to 30 seconds. A configured timeout
// of zero or less would disconnect every client on the first tick, so it
// falls back to three heartbeat intervals.
func (s *Server) heartbeatLoop(cs *ClientSession) {
	const fallbackInterval = 30 * time.Second

	interval := time.Duration(s.heartbeat) * time.Second
	if interval <= 0 {
		interval = fallbackInterval
	}
	timeout := time.Duration(s.hbTimeout) * time.Second
	if timeout <= 0 {
		timeout = 3 * interval
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			cs.mu.Lock()
			expired := time.Since(cs.LastPing) > timeout
			cs.mu.Unlock()
			if expired {
				log.Printf("[Server] Client %s heartbeat timeout", cs.ID)
				cs.Session.Close()
				return
			}

			// Send a heartbeat and wait for the reply.
			if s.sendHeartbeat(cs) {
				cs.mu.Lock()
				cs.LastPing = time.Now()
				cs.mu.Unlock()
			}
		case <-cs.Session.CloseChan():
			return
		}
	}
}
// sendHeartbeat performs one heartbeat round trip on a fresh stream and
// reports whether the client answered with MsgTypeHeartbeatAck. Opening the
// stream, writing and reading are all bounded by heartbeatTimeout; any
// failure yields false.
func (s *Server) sendHeartbeat(cs *ClientSession) bool {
	hb, openErr := cs.Session.Open()
	if openErr != nil {
		return false
	}
	defer hb.Close()

	// One deadline covers both the write and the read.
	hb.SetDeadline(time.Now().Add(heartbeatTimeout))

	ping := &protocol.Message{Type: protocol.MsgTypeHeartbeat}
	if err := protocol.WriteMessage(hb, ping); err != nil {
		return false
	}

	reply, readErr := protocol.ReadMessage(hb)
	return readErr == nil && reply.Type == protocol.MsgTypeHeartbeatAck
}
// GetClientStatus returns the live status of the client with the given ID.
//
// For a connected client online is true, lastPing is the last heartbeat
// time formatted as RFC 3339, and the remaining values are the session's
// remote address, name, OS, architecture and version. For an unknown client
// online is false and every string is empty.
func (s *Server) GetClientStatus(clientID string) (online bool, lastPing, remoteAddr, clientName, clientOS, clientArch, clientVersion string) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	cs, found := s.clients[clientID]
	if !found {
		return false, "", "", "", "", "", ""
	}

	cs.mu.Lock()
	defer cs.mu.Unlock()
	return true, cs.LastPing.Format(time.RFC3339), cs.RemoteAddr, cs.Name, cs.OS, cs.Arch, cs.Version
}
// IsClientOnline reports whether a session is currently registered for
// clientID.
func (s *Server) IsClientOnline(clientID string) bool {
	s.mu.RLock()
	_, online := s.clients[clientID]
	s.mu.RUnlock()
	return online
}
// GetAllClientStatus returns the status of every connected client, keyed by
// client ID. Every entry has Online set to true and LastPing formatted as
// RFC 3339.
func (s *Server) GetAllClientStatus() map[string]struct {
	Online     bool
	LastPing   string
	RemoteAddr string
	Name       string
	OS         string
	Arch       string
	Version    string
} {
	// Alias for the anonymous result element type, so it is spelled once.
	type status = struct {
		Online     bool
		LastPing   string
		RemoteAddr string
		Name       string
		OS         string
		Arch       string
		Version    string
	}

	// Snapshot the session pointers first so that the server lock is not
	// held while the per-session locks are taken.
	s.mu.RLock()
	sessions := make([]*ClientSession, 0, len(s.clients))
	for _, session := range s.clients {
		sessions = append(sessions, session)
	}
	s.mu.RUnlock()

	statuses := make(map[string]status)
	for _, session := range sessions {
		session.mu.Lock()
		statuses[session.ID] = status{
			Online:     true,
			LastPing:   session.LastPing.Format(time.RFC3339),
			RemoteAddr: session.RemoteAddr,
			Name:       session.Name,
			OS:         session.OS,
			Arch:       session.Arch,
			Version:    session.Version,
		}
		session.mu.Unlock()
	}
	return statuses
}
// ReloadConfig reloads the configuration.
// Note: hot reload is not supported in this version; the method always
// returns an error asking for a server restart.
func (s *Server) ReloadConfig() error {
return fmt.Errorf("hot reload not supported, please restart the server")
}
// GetBindAddr returns the address the server was configured to bind to.
func (s *Server) GetBindAddr() string {
return s.bindAddr
}
// GetBindPort returns the port the server was configured to bind to.
func (s *Server) GetBindPort() int {
return s.bindPort
}
// PushConfigToClient reloads the stored rules of a connected client and
// applies them.
//
// The client's current listeners are stopped, its rules are replaced with
// the ones from the client store, new listeners are started, and the
// resulting configuration is sent to the client. It returns an error when
// the client is not connected, the rules cannot be loaded, the configuration
// cannot be delivered, or, after a successful delivery, when one or more
// enabled rules failed to start listening.
func (s *Server) PushConfigToClient(clientID string) error {
	s.mu.RLock()
	cs, online := s.clients[clientID]
	s.mu.RUnlock()
	if !online {
		return fmt.Errorf("client %s not found", clientID)
	}

	latest, err := s.clientStore.GetClientRules(clientID)
	if err != nil {
		return err
	}

	// Swap the rule set: stop old listeners, install rules, start new ones.
	s.stopProxyListeners(cs)
	cs.mu.Lock()
	cs.Rules = latest
	cs.mu.Unlock()
	s.startProxyListeners(cs)

	// Collect the enabled rules whose port could not be started.
	var failures []string
	for i := range cs.Rules {
		r := &cs.Rules[i]
		if !r.IsEnabled() || !strings.HasPrefix(r.PortStatus, "failed:") {
			continue
		}
		reason := strings.TrimPrefix(r.PortStatus, "failed: ")
		failures = append(failures, fmt.Sprintf("port %d: %s", r.RemotePort, reason))
	}

	// The client is told about the configuration even if some ports failed.
	if err := s.sendProxyConfig(cs.Session, cs.Rules); err != nil {
		return err
	}

	if len(failures) == 0 {
		return nil
	}
	return fmt.Errorf("some ports failed to start: %s", strings.Join(failures, "; "))
}
// DisconnectClient closes the session of the given client. It returns an
// error when the client is not connected, or the error from closing the
// session.
func (s *Server) DisconnectClient(clientID string) error {
	s.mu.RLock()
	cs, online := s.clients[clientID]
	s.mu.RUnlock()

	if online {
		return cs.Session.Close()
	}
	return fmt.Errorf("client %s not found", clientID)
}
// startUDPListener opens the UDP socket for rule and starts its read loop.
//
// The rule's remote port is reserved in the port manager and a UDP socket
// is bound to it. On success the socket is stored in cs.UDPConns, the rule's
// PortStatus becomes "listening" and handleUDPConn is started with a copy of
// the rule. On any failure the reservation is released (if it was made) and
// PortStatus is set to "failed: <reason>".
func (s *Server) startUDPListener(cs *ClientSession, rule *protocol.ProxyRule) {
	port := rule.RemotePort

	if err := s.portManager.Reserve(port, cs.ID); err != nil {
		log.Printf("[Server] UDP port %d error: %v", port, err)
		rule.PortStatus = "failed: " + err.Error()
		return
	}

	bindAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
	if err != nil {
		log.Printf("[Server] UDP resolve error: %v", err)
		s.portManager.Release(port)
		rule.PortStatus = "failed: " + err.Error()
		return
	}

	sock, err := net.ListenUDP("udp", bindAddr)
	if err != nil {
		log.Printf("[Server] UDP listen %d error: %v", port, err)
		s.portManager.Release(port)
		rule.PortStatus = "failed: " + err.Error()
		return
	}
	rule.PortStatus = "listening"

	cs.mu.Lock()
	cs.UDPConns[port] = sock
	cs.mu.Unlock()

	log.Printf("[Server] UDP proxy %s: :%d -> %s:%d",
		rule.Name, port, rule.LocalIP, rule.LocalPort)
	go s.handleUDPConn(cs, sock, *rule)
}
// handleUDPConn reads datagrams from conn and forwards each one to the
// client in its own goroutine via sendUDPPacket. It returns on the first
// read error.
//
// The read buffer is reused for every datagram, so each payload is copied
// before it is handed to the forwarding goroutine; otherwise the next read
// would overwrite data that is still being sent.
func (s *Server) handleUDPConn(cs *ClientSession, conn *net.UDPConn, rule protocol.ProxyRule) {
	buf := make([]byte, udpBufferSize)
	for {
		n, clientAddr, err := conn.ReadFromUDP(buf)
		if err != nil {
			return
		}

		payload := make([]byte, n)
		copy(payload, buf[:n])

		// Wrap the datagram for delivery to the client.
		packet := protocol.UDPPacket{
			RemotePort: rule.RemotePort,
			ClientAddr: clientAddr.String(),
			Data:       payload,
		}
		go s.sendUDPPacket(cs, conn, clientAddr, packet)
	}
}
// sendUDPPacket forwards one datagram to the client and relays the reply.
//
// The packet is sent on a new stream of the client's session and counted as
// inbound traffic. If the client answers with a MsgTypeUDPData message, its
// payload is written back to clientAddr through conn and counted as outbound
// traffic. Many UDP exchanges never produce a reply, so the stream carries a
// deadline: without it every unanswered datagram would pin a goroutine and a
// stream until the session closes.
func (s *Server) sendUDPPacket(cs *ClientSession, conn *net.UDPConn, clientAddr *net.UDPAddr, packet protocol.UDPPacket) {
	// Upper bound for sending the datagram and waiting for a reply.
	const udpResponseTimeout = 30 * time.Second

	stream, err := cs.Session.Open()
	if err != nil {
		return
	}
	defer stream.Close()

	stream.SetDeadline(time.Now().Add(udpResponseTimeout))

	msg, err := protocol.NewMessage(protocol.MsgTypeUDPData, packet)
	if err != nil {
		return
	}
	if err := protocol.WriteMessage(stream, msg); err != nil {
		return
	}
	// Inbound traffic: data received from the outside.
	s.recordTraffic(int64(len(packet.Data)), 0)

	// Wait for the client's reply.
	respMsg, err := protocol.ReadMessage(stream)
	if err != nil {
		return
	}
	if respMsg.Type != protocol.MsgTypeUDPData {
		return
	}

	var respPacket protocol.UDPPacket
	if err := respMsg.ParsePayload(&respPacket); err != nil {
		return
	}
	if _, err := conn.WriteToUDP(respPacket.Data, clientAddr); err != nil {
		// Nothing reached the peer, so nothing is counted as outbound.
		return
	}
	// Outbound traffic: data sent back to the outside.
	s.recordTraffic(0, int64(len(respPacket.Data)))
}
// checkHTTPBasicAuth checks the HTTP Basic credentials of the request at
// the start of conn.
//
// It performs a single read of up to 8192 bytes under a 10 second read
// deadline (cleared again on return) and looks for an Authorization header
// in the header section of what was read. It returns (true, data) when the
// header carries Basic credentials equal to username and password, where
// data is everything that was read so the caller can replay it. Otherwise
// it returns (false, nil); unless the read itself failed, a 401 response
// has then been written to conn.
//
// The header name and the "Basic" scheme are matched case-insensitively,
// scanning stops at the blank line that ends the headers so body content is
// never mistaken for a header, and the credentials are compared in constant
// time.
func (s *Server) checkHTTPBasicAuth(conn net.Conn, username, password string) (bool, []byte) {
	const (
		headerName = "authorization:"
		scheme     = "basic "
	)

	conn.SetReadDeadline(time.Now().Add(10 * time.Second))
	defer conn.SetReadDeadline(time.Time{}) // clear the deadline again

	// Read the request head. Only what a single read returns is inspected.
	buf := make([]byte, 8192)
	n, err := conn.Read(buf)
	if err != nil {
		return false, nil
	}
	data := buf[:n]

	// Find the Authorization header.
	authHeader := ""
	for _, line := range strings.Split(string(data), "\r\n") {
		if line == "" {
			// End of the header section.
			break
		}
		if len(line) >= len(headerName) && strings.EqualFold(line[:len(headerName)], headerName) {
			authHeader = strings.TrimSpace(line[len(headerName):])
			break
		}
	}

	// Require the Basic scheme.
	if len(authHeader) < len(scheme) || !strings.EqualFold(authHeader[:len(scheme)], scheme) {
		s.sendHTTPUnauthorized(conn)
		return false, nil
	}

	// Decode the Base64 credentials.
	encoded := strings.TrimSpace(authHeader[len(scheme):])
	decoded, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		s.sendHTTPUnauthorized(conn)
		return false, nil
	}

	// Split into username:password.
	parts := strings.SplitN(string(decoded), ":", 2)
	if len(parts) != 2 {
		s.sendHTTPUnauthorized(conn)
		return false, nil
	}

	// Evaluate both comparisons unconditionally so that timing does not
	// reveal which of the two fields was wrong.
	userOK := subtle.ConstantTimeCompare([]byte(parts[0]), []byte(username))
	passOK := subtle.ConstantTimeCompare([]byte(parts[1]), []byte(password))
	if userOK&passOK != 1 {
		s.sendHTTPUnauthorized(conn)
		return false, nil
	}
	return true, data
}
// sendHTTPUnauthorized writes a fixed HTTP 401 response with a Basic
// authentication challenge to conn. The write error is not reported.
func (s *Server) sendHTTPUnauthorized(conn net.Conn) {
	const response = "HTTP/1.1 401 Unauthorized\r\n" +
		"WWW-Authenticate: Basic realm=\"GoTunnel Plugin\"\r\n" +
		"Content-Type: text/plain\r\n" +
		"Content-Length: 12\r\n" +
		"\r\n" +
		"Unauthorized"

	payload := []byte(response)
	conn.Write(payload)
}
// shouldPushToClient reports whether a push applies to clientID. An empty
// autoPush list applies to every client; otherwise the list must contain
// clientID or the wildcard "*".
func (s *Server) shouldPushToClient(autoPush []string, clientID string) bool {
	if len(autoPush) == 0 {
		return true
	}
	for i := range autoPush {
		switch autoPush[i] {
		case clientID, "*":
			return true
		}
	}
	return false
}
// RestartClient asks a connected client to restart.
//
// It sends a MsgTypeClientRestart message on a new stream and, 100ms later,
// closes the client's session so that the client reconnects. It returns an
// error when the client is not connected, the stream cannot be opened, or
// the restart message cannot be built or written. In those cases the
// session is left untouched, so the caller learns that the request never
// reached the client instead of getting a silent nil.
func (s *Server) RestartClient(clientID string) error {
	s.mu.RLock()
	cs, ok := s.clients[clientID]
	s.mu.RUnlock()
	if !ok {
		return fmt.Errorf("client %s not found or not online", clientID)
	}

	// Send the restart message.
	stream, err := cs.Session.Open()
	if err != nil {
		return err
	}

	req := protocol.ClientRestartRequest{
		Reason: "server requested restart",
	}
	msg, err := protocol.NewMessage(protocol.MsgTypeClientRestart, req)
	if err != nil {
		stream.Close()
		return fmt.Errorf("build restart message: %w", err)
	}
	if err := protocol.WriteMessage(stream, msg); err != nil {
		stream.Close()
		return fmt.Errorf("send restart message: %w", err)
	}
	stream.Close()

	// Give the message a moment to go out before dropping the session.
	time.AfterFunc(100*time.Millisecond, func() {
		cs.Session.Close()
	})

	log.Printf("[Server] Restart initiated for client %s", clientID)
	return nil
}
// IsPortAvailable reports whether port can be used for a rule of the client
// identified by excludeClientID.
//
// The port must pass utils.IsPortAvailable and must not be held by any
// other connected client. Both TCP listeners and UDP connections count as
// holding a port, because both kinds of rule reserve their port in the same
// port manager.
func (s *Server) IsPortAvailable(port int, excludeClientID string) bool {
	// System-level check.
	if !utils.IsPortAvailable(port) {
		return false
	}

	// Check whether another client occupies the port.
	s.mu.RLock()
	defer s.mu.RUnlock()
	for clientID, cs := range s.clients {
		if clientID == excludeClientID {
			continue
		}
		cs.mu.Lock()
		_, tcpBusy := cs.Listeners[port]
		_, udpBusy := cs.UDPConns[port]
		cs.mu.Unlock()
		if tcpBusy || udpBusy {
			return false
		}
	}
	return true
}
// SendUpdateToClient sends an update command carrying downloadURL to a
// connected client. It returns an error when the client is not connected,
// the stream cannot be opened, or the message cannot be built or written.
// No reply from the client is awaited.
func (s *Server) SendUpdateToClient(clientID, downloadURL string) error {
	s.mu.RLock()
	cs, ok := s.clients[clientID]
	s.mu.RUnlock()
	if !ok {
		return fmt.Errorf("client %s not found or not online", clientID)
	}

	// Send the update message.
	stream, err := cs.Session.Open()
	if err != nil {
		return err
	}
	defer stream.Close()

	req := protocol.UpdateDownloadRequest{
		DownloadURL: downloadURL,
	}
	msg, err := protocol.NewMessage(protocol.MsgTypeUpdateDownload, req)
	if err != nil {
		return fmt.Errorf("build update message: %w", err)
	}
	if err := protocol.WriteMessage(stream, msg); err != nil {
		return err
	}

	log.Printf("[Server] Update command sent to client %s: %s", clientID, downloadURL)
	return nil
}
// StartClientLogStream asks a connected client to stream its logs and
// returns a channel that receives the entries.
//
// A log request built from sessionID, lines, follow and level is sent on a
// new stream. A log session is then created for that stream, a listener is
// attached to it, and a goroutine starts reading the client's log data. It
// returns an error when the client is not connected, the stream cannot be
// opened, or the request cannot be built or written; the stream is closed
// on those failure paths.
func (s *Server) StartClientLogStream(clientID, sessionID string, lines int, follow bool, level string) (<-chan protocol.LogEntry, error) {
	s.mu.RLock()
	cs, ok := s.clients[clientID]
	s.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("client %s not found or not online", clientID)
	}

	// Open a stream to the client.
	stream, err := cs.Session.Open()
	if err != nil {
		return nil, err
	}

	// Send the log request.
	req := protocol.LogRequest{
		SessionID: sessionID,
		Lines:     lines,
		Follow:    follow,
		Level:     level,
	}
	msg, err := protocol.NewMessage(protocol.MsgTypeLogRequest, req)
	if err != nil {
		stream.Close()
		return nil, fmt.Errorf("build log request: %w", err)
	}
	if err := protocol.WriteMessage(stream, msg); err != nil {
		stream.Close()
		return nil, err
	}

	// Create the session and attach the caller's listener.
	session := s.logSessions.CreateSession(clientID, sessionID, stream)
	listener := session.AddListener()

	// Read the client's logs in the background.
	go s.readClientLogs(session, stream)

	return listener, nil
}
// readClientLogs reads log data messages from stream and broadcasts every
// entry to the listeners of session.
//
// Messages of another type and payloads that fail to parse are skipped. The
// loop ends on a read error or after a batch flagged EOF, and the log
// session is removed from the manager on return.
func (s *Server) readClientLogs(session *LogSession, stream net.Conn) {
	defer s.logSessions.RemoveSession(session.ID)

	finished := false
	for !finished {
		frame, readErr := protocol.ReadMessage(stream)
		if readErr != nil {
			break
		}
		if frame.Type != protocol.MsgTypeLogData {
			continue
		}

		var batch protocol.LogData
		if err := frame.ParsePayload(&batch); err != nil {
			continue
		}

		for i := range batch.Entries {
			session.Broadcast(batch.Entries[i])
		}
		finished = batch.EOF
	}
}
// StopClientLogStream ends the log session with the given ID.
//
// If the owning client is still connected, a stop request is sent to it on
// a best-effort basis: failures to open the stream or to send are ignored.
// The session is then removed from the manager. An unknown session ID is a
// no-op.
func (s *Server) StopClientLogStream(sessionID string) {
	logSession := s.logSessions.GetSession(sessionID)
	if logSession == nil {
		return
	}

	s.mu.RLock()
	cs, online := s.clients[logSession.ClientID]
	s.mu.RUnlock()

	// Tell the client to stop sending, if it is still around.
	if online {
		if stream, openErr := cs.Session.Open(); openErr == nil {
			stop := protocol.LogStopRequest{SessionID: sessionID}
			msg, _ := protocol.NewMessage(protocol.MsgTypeLogStop, stop)
			protocol.WriteMessage(stream, msg)
			stream.Close()
		}
	}

	s.logSessions.RemoveSession(sessionID)
}
// recordTraffic adds inbound and outbound byte counts to the traffic store.
// It does nothing when no traffic store is set; a store error is logged and
// otherwise ignored.
func (s *Server) recordTraffic(inbound, outbound int64) {
	if s.trafficStore != nil {
		err := s.trafficStore.AddTraffic(inbound, outbound)
		if err != nil {
			log.Printf("[Server] Record traffic error: %v", err)
		}
	}
}
// boolPtr returns a pointer to a fresh bool holding the value b.
func boolPtr(b bool) *bool {
	p := new(bool)
	*p = b
	return p
}
// GetClientSystemStats requests system statistics from a connected client.
//
// The request and the reply share a 10 second deadline on a new stream. It
// returns an error when the client is not connected, the stream cannot be
// opened, the request cannot be built or written, the reply cannot be read
// or parsed, or the reply is not of type MsgTypeSystemStatsResponse.
func (s *Server) GetClientSystemStats(clientID string) (*protocol.SystemStatsResponse, error) {
	s.mu.RLock()
	cs, ok := s.clients[clientID]
	s.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("client %s not online", clientID)
	}

	stream, err := cs.Session.Open()
	if err != nil {
		return nil, err
	}
	defer stream.Close()

	// Bound the whole exchange.
	stream.SetDeadline(time.Now().Add(10 * time.Second))

	// Send the request.
	msg, err := protocol.NewMessage(protocol.MsgTypeSystemStatsRequest, nil)
	if err != nil {
		return nil, fmt.Errorf("build system stats request: %w", err)
	}
	if err := protocol.WriteMessage(stream, msg); err != nil {
		return nil, err
	}

	// Read the response.
	resp, err := protocol.ReadMessage(stream)
	if err != nil {
		return nil, err
	}
	if resp.Type != protocol.MsgTypeSystemStatsResponse {
		return nil, fmt.Errorf("unexpected response type: %d", resp.Type)
	}

	var stats protocol.SystemStatsResponse
	if err := resp.ParsePayload(&stats); err != nil {
		return nil, err
	}
	return &stats, nil
}
// GetClientScreenshot requests a screenshot from a connected client.
//
// quality is forwarded to the client unchanged. The request and the reply
// share a 15 second deadline on a new stream. It returns an error when the
// client is not connected, the stream cannot be opened, the request cannot
// be built or written, the reply cannot be read or parsed, the reply is not
// of type MsgTypeScreenshotResponse, or the client reports a failure in the
// response's Error field.
func (s *Server) GetClientScreenshot(clientID string, quality int) (*protocol.ScreenshotResponse, error) {
	s.mu.RLock()
	cs, ok := s.clients[clientID]
	s.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("client %s not online", clientID)
	}

	stream, err := cs.Session.Open()
	if err != nil {
		return nil, err
	}
	defer stream.Close()

	// Bound the whole exchange.
	stream.SetDeadline(time.Now().Add(15 * time.Second))

	// Send the request.
	req := protocol.ScreenshotRequest{Quality: quality}
	msg, err := protocol.NewMessage(protocol.MsgTypeScreenshotRequest, req)
	if err != nil {
		return nil, fmt.Errorf("build screenshot request: %w", err)
	}
	if err := protocol.WriteMessage(stream, msg); err != nil {
		return nil, err
	}

	// Read the response.
	resp, err := protocol.ReadMessage(stream)
	if err != nil {
		return nil, err
	}
	if resp.Type != protocol.MsgTypeScreenshotResponse {
		return nil, fmt.Errorf("unexpected response type: %d", resp.Type)
	}

	var screenshot protocol.ScreenshotResponse
	if err := resp.ParsePayload(&screenshot); err != nil {
		return nil, err
	}
	if screenshot.Error != "" {
		return nil, fmt.Errorf("screenshot failed: %s", screenshot.Error)
	}
	return &screenshot, nil
}
// ExecuteClientShell runs a shell command on a connected client and returns
// the client's response.
//
// timeout is the command timeout in seconds; a value of zero or less
// becomes 30. The stream deadline is five seconds longer than the command
// timeout, so the client's own timeout can fire first. It returns an error
// when the client is not connected, the stream cannot be opened, the
// request cannot be built or written, the reply cannot be read or parsed,
// or the reply is not of type MsgTypeShellExecuteResponse.
func (s *Server) ExecuteClientShell(clientID, command string, timeout int) (*protocol.ShellExecuteResponse, error) {
	s.mu.RLock()
	cs, ok := s.clients[clientID]
	s.mu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("client %s not online", clientID)
	}

	stream, err := cs.Session.Open()
	if err != nil {
		return nil, err
	}
	defer stream.Close()

	// Stream deadline: slightly longer than the command timeout.
	if timeout <= 0 {
		timeout = 30
	}
	stream.SetDeadline(time.Now().Add(time.Duration(timeout+5) * time.Second))

	// Send the request.
	req := protocol.ShellExecuteRequest{
		Command: command,
		Timeout: timeout,
	}
	msg, err := protocol.NewMessage(protocol.MsgTypeShellExecuteRequest, req)
	if err != nil {
		return nil, fmt.Errorf("build shell request: %w", err)
	}
	if err := protocol.WriteMessage(stream, msg); err != nil {
		return nil, err
	}

	// Read the response.
	resp, err := protocol.ReadMessage(stream)
	if err != nil {
		return nil, err
	}
	if resp.Type != protocol.MsgTypeShellExecuteResponse {
		return nil, fmt.Errorf("unexpected response type: %d", resp.Type)
	}

	var result protocol.ShellExecuteResponse
	if err := resp.ParsePayload(&result); err != nil {
		return nil, err
	}
	return &result, nil
}