- Deleted version comparison logic from `pkg/plugin/sign/version.go`. - Removed unused types and constants from `pkg/plugin/types.go`. - Updated `pkg/protocol/message.go` to remove plugin-related message types. - Enhanced `pkg/proxy/http.go` and `pkg/proxy/socks5.go` to include username/password authentication for HTTP and SOCKS5 proxies. - Modified `pkg/proxy/server.go` to pass authentication parameters to server constructors. - Added new API endpoint to generate installation commands with a token for clients. - Created database functions to manage installation tokens in `internal/server/db/install_token.go`. - Implemented the installation command generation logic in `internal/server/router/handler/install.go`. - Updated web frontend to support installation command generation and display in `web/src/views/ClientsView.vue`.
1293 lines
30 KiB
Go
1293 lines
30 KiB
Go
package tunnel
|
||
|
||
import (
|
||
"crypto/rand"
|
||
"crypto/tls"
|
||
"encoding/base64"
|
||
"encoding/hex"
|
||
"fmt"
|
||
"log"
|
||
"net"
|
||
"regexp"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
"github.com/gotunnel/internal/server/db"
|
||
"github.com/gotunnel/pkg/protocol"
|
||
"github.com/gotunnel/pkg/proxy"
|
||
"github.com/gotunnel/pkg/relay"
|
||
"github.com/gotunnel/pkg/security"
|
||
"github.com/gotunnel/pkg/utils"
|
||
"github.com/hashicorp/yamux"
|
||
)
|
||
|
||
// 服务端常量
|
||
const (
|
||
authTimeout = 10 * time.Second
|
||
heartbeatTimeout = 10 * time.Second
|
||
udpBufferSize = 65535
|
||
maxConnections = 10000 // 最大连接数
|
||
)
|
||
|
||
// 客户端 ID 验证正则
|
||
var clientIDRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]{1,64}$`)
|
||
|
||
// isValidClientID 验证客户端 ID 格式
|
||
func isValidClientID(id string) bool {
|
||
return clientIDRegex.MatchString(id)
|
||
}
|
||
|
||
// generateClientID 生成随机客户端 ID
|
||
func generateClientID() string {
|
||
bytes := make([]byte, 8)
|
||
rand.Read(bytes)
|
||
return hex.EncodeToString(bytes)
|
||
}
|
||
|
||
// Server 隧道服务端
|
||
type Server struct {
|
||
clientStore db.ClientStore
|
||
trafficStore db.TrafficStore // 流量存储
|
||
bindAddr string
|
||
bindPort int
|
||
token string
|
||
heartbeat int
|
||
hbTimeout int
|
||
portManager *utils.PortManager
|
||
clients map[string]*ClientSession
|
||
mu sync.RWMutex
|
||
tlsConfig *tls.Config
|
||
connSem chan struct{} // 连接数信号量
|
||
activeConns int64 // 当前活跃连接数
|
||
listener net.Listener // 主监听器
|
||
shutdown chan struct{} // 关闭信号
|
||
wg sync.WaitGroup // 等待所有连接关闭
|
||
logSessions *LogSessionManager // 日志会话管理器
|
||
}
|
||
|
||
// ClientSession 客户端会话
|
||
type ClientSession struct {
|
||
ID string
|
||
Name string // 客户端名称(主机名)
|
||
RemoteAddr string // 客户端 IP 地址
|
||
OS string // 客户端操作系统
|
||
Arch string // 客户端架构
|
||
Version string // 客户端版本
|
||
Session *yamux.Session
|
||
Rules []protocol.ProxyRule
|
||
Listeners map[int]net.Listener
|
||
UDPConns map[int]*net.UDPConn // UDP 连接
|
||
LastPing time.Time
|
||
mu sync.Mutex
|
||
}
|
||
|
||
// NewServer 创建服务端
|
||
func NewServer(cs db.ClientStore, bindAddr string, bindPort int, token string, heartbeat, hbTimeout int) *Server {
|
||
return &Server{
|
||
clientStore: cs,
|
||
bindAddr: bindAddr,
|
||
bindPort: bindPort,
|
||
token: token,
|
||
heartbeat: heartbeat,
|
||
hbTimeout: hbTimeout,
|
||
portManager: utils.NewPortManager(),
|
||
clients: make(map[string]*ClientSession),
|
||
connSem: make(chan struct{}, maxConnections),
|
||
shutdown: make(chan struct{}),
|
||
logSessions: NewLogSessionManager(),
|
||
}
|
||
}
|
||
|
||
// SetTLSConfig 设置 TLS 配置
|
||
func (s *Server) SetTLSConfig(config *tls.Config) {
|
||
s.tlsConfig = config
|
||
}
|
||
|
||
// Shutdown 优雅关闭服务端
|
||
func (s *Server) Shutdown(timeout time.Duration) error {
|
||
log.Printf("[Server] Initiating graceful shutdown...")
|
||
close(s.shutdown)
|
||
|
||
if s.listener != nil {
|
||
s.listener.Close()
|
||
}
|
||
|
||
// 关闭所有客户端会话
|
||
s.mu.Lock()
|
||
for _, cs := range s.clients {
|
||
cs.Session.Close()
|
||
}
|
||
s.mu.Unlock()
|
||
|
||
// 等待所有连接关闭,带超时
|
||
done := make(chan struct{})
|
||
go func() {
|
||
s.wg.Wait()
|
||
close(done)
|
||
}()
|
||
|
||
select {
|
||
case <-done:
|
||
log.Printf("[Server] All connections closed gracefully")
|
||
return nil
|
||
case <-time.After(timeout):
|
||
log.Printf("[Server] Shutdown timeout, forcing close")
|
||
return fmt.Errorf("shutdown timeout")
|
||
}
|
||
}
|
||
|
||
// SetTrafficStore 设置流量存储
|
||
func (s *Server) SetTrafficStore(store db.TrafficStore) {
|
||
s.trafficStore = store
|
||
}
|
||
|
||
// Run 启动服务端
|
||
func (s *Server) Run() error {
|
||
addr := fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort)
|
||
|
||
var ln net.Listener
|
||
var err error
|
||
|
||
if s.tlsConfig != nil {
|
||
ln, err = tls.Listen("tcp", addr, s.tlsConfig)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to listen TLS on %s: %v", addr, err)
|
||
}
|
||
log.Printf("[Server] TLS listening on %s", addr)
|
||
} else {
|
||
ln, err = net.Listen("tcp", addr)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to listen on %s: %v", addr, err)
|
||
}
|
||
log.Printf("[Server] Listening on %s (no TLS)", addr)
|
||
}
|
||
s.listener = ln
|
||
|
||
for {
|
||
select {
|
||
case <-s.shutdown:
|
||
log.Printf("[Server] Shutdown signal received, stopping accept loop")
|
||
ln.Close()
|
||
return nil
|
||
default:
|
||
}
|
||
|
||
conn, err := ln.Accept()
|
||
if err != nil {
|
||
select {
|
||
case <-s.shutdown:
|
||
return nil
|
||
default:
|
||
log.Printf("[Server] Accept error: %v", err)
|
||
continue
|
||
}
|
||
}
|
||
s.wg.Add(1)
|
||
go func() {
|
||
defer s.wg.Done()
|
||
s.handleConnection(conn)
|
||
}()
|
||
}
|
||
}
|
||
|
||
// handleConnection 处理客户端连接
|
||
func (s *Server) handleConnection(conn net.Conn) {
|
||
clientIP := conn.RemoteAddr().String()
|
||
|
||
// 连接数限制检查
|
||
select {
|
||
case s.connSem <- struct{}{}:
|
||
defer func() { <-s.connSem }()
|
||
default:
|
||
security.LogConnRejected(clientIP, "max connections reached")
|
||
conn.Close()
|
||
return
|
||
}
|
||
|
||
defer conn.Close()
|
||
|
||
conn.SetReadDeadline(time.Now().Add(authTimeout))
|
||
|
||
msg, err := protocol.ReadMessage(conn)
|
||
if err != nil {
|
||
log.Printf("[Server] Read auth error: %v", err)
|
||
return
|
||
}
|
||
|
||
if msg.Type != protocol.MsgTypeAuth {
|
||
log.Printf("[Server] Expected auth, got %d", msg.Type)
|
||
return
|
||
}
|
||
|
||
var authReq protocol.AuthRequest
|
||
if err := msg.ParsePayload(&authReq); err != nil {
|
||
log.Printf("[Server] Parse auth error: %v", err)
|
||
return
|
||
}
|
||
|
||
// 验证token:支持常规token和一次性安装token
|
||
validToken := authReq.Token == s.token
|
||
var isInstallToken bool
|
||
|
||
if !validToken {
|
||
// 尝试验证安装token
|
||
if tokenStore, ok := s.clientStore.(db.InstallTokenStore); ok {
|
||
if installToken, err := tokenStore.GetInstallToken(authReq.Token); err == nil {
|
||
if !installToken.Used && time.Now().Unix()-installToken.CreatedAt < 3600 {
|
||
// token有效且未过期
|
||
validToken = true
|
||
isInstallToken = true
|
||
// 验证客户端ID匹配
|
||
if authReq.ClientID != "" && authReq.ClientID != installToken.ClientID {
|
||
security.LogInvalidClientID(clientIP, authReq.ClientID)
|
||
s.sendAuthResponse(conn, false, "client id mismatch", "")
|
||
return
|
||
}
|
||
// 使用token中的客户端ID
|
||
authReq.ClientID = installToken.ClientID
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if !validToken {
|
||
security.LogInvalidToken(clientIP)
|
||
s.sendAuthResponse(conn, false, "invalid token", "")
|
||
return
|
||
}
|
||
|
||
// 处理客户端 ID
|
||
clientID := authReq.ClientID
|
||
if clientID == "" {
|
||
clientID = generateClientID()
|
||
} else if !isValidClientID(clientID) {
|
||
security.LogInvalidClientID(clientIP, clientID)
|
||
s.sendAuthResponse(conn, false, "invalid client id format", "")
|
||
return
|
||
}
|
||
|
||
// 检查客户端是否存在,不存在则自动创建
|
||
exists, err := s.clientStore.ClientExists(clientID)
|
||
if err != nil || !exists {
|
||
newClient := &db.Client{ID: clientID, Nickname: authReq.Name, Rules: []protocol.ProxyRule{}}
|
||
if err := s.clientStore.CreateClient(newClient); err != nil {
|
||
log.Printf("[Server] Create client error: %v", err)
|
||
s.sendAuthResponse(conn, false, "failed to create client", "")
|
||
return
|
||
}
|
||
log.Printf("[Server] New client registered: %s (%s)", clientID, authReq.Name)
|
||
} else if authReq.Name != "" {
|
||
// 客户端已存在,仅当 Nickname 为空时才用客户端名称更新
|
||
// 这样服务端手动设置的名称不会被客户端覆盖
|
||
if client, err := s.clientStore.GetClient(clientID); err == nil {
|
||
if client.Nickname == "" {
|
||
client.Nickname = authReq.Name
|
||
s.clientStore.UpdateClient(client)
|
||
}
|
||
}
|
||
}
|
||
|
||
rules, _ := s.clientStore.GetClientRules(clientID)
|
||
if rules == nil {
|
||
rules = []protocol.ProxyRule{}
|
||
}
|
||
|
||
// 如果使用安装token,标记为已使用
|
||
if isInstallToken {
|
||
if tokenStore, ok := s.clientStore.(db.InstallTokenStore); ok {
|
||
tokenStore.MarkTokenUsed(authReq.Token)
|
||
}
|
||
}
|
||
|
||
conn.SetReadDeadline(time.Time{})
|
||
|
||
if err := s.sendAuthResponse(conn, true, "ok", clientID); err != nil {
|
||
return
|
||
}
|
||
|
||
security.LogAuthSuccess(clientIP, clientID)
|
||
s.setupClientSession(conn, clientID, authReq.Name, authReq.OS, authReq.Arch, authReq.Version, rules)
|
||
}
|
||
|
||
// setupClientSession 建立客户端会话
|
||
func (s *Server) setupClientSession(conn net.Conn, clientID, clientName, clientOS, clientArch, clientVersion string, rules []protocol.ProxyRule) {
|
||
session, err := yamux.Server(conn, nil)
|
||
if err != nil {
|
||
log.Printf("[Server] Yamux error: %v", err)
|
||
return
|
||
}
|
||
|
||
// 提取客户端 IP(去掉端口)
|
||
remoteAddr := conn.RemoteAddr().String()
|
||
if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
|
||
remoteAddr = host
|
||
}
|
||
|
||
cs := &ClientSession{
|
||
ID: clientID,
|
||
Name: clientName,
|
||
RemoteAddr: remoteAddr,
|
||
OS: clientOS,
|
||
Arch: clientArch,
|
||
Version: clientVersion,
|
||
Session: session,
|
||
Rules: rules,
|
||
Listeners: make(map[int]net.Listener),
|
||
UDPConns: make(map[int]*net.UDPConn),
|
||
LastPing: time.Now(),
|
||
}
|
||
|
||
s.registerClient(cs)
|
||
defer s.unregisterClient(cs)
|
||
|
||
// 启动代理监听器(会更新 rules 的 PortStatus)
|
||
s.startProxyListeners(cs)
|
||
|
||
// 发送配置到客户端(包含端口状态)
|
||
if err := s.sendProxyConfig(session, cs.Rules); err != nil {
|
||
log.Printf("[Server] Send config error: %v", err)
|
||
return
|
||
}
|
||
|
||
go s.heartbeatLoop(cs)
|
||
|
||
<-session.CloseChan()
|
||
log.Printf("[Server] Client %s disconnected", clientID)
|
||
}
|
||
|
||
// sendAuthResponse 发送认证响应
|
||
func (s *Server) sendAuthResponse(conn net.Conn, success bool, message, clientID string) error {
|
||
resp := protocol.AuthResponse{Success: success, Message: message, ClientID: clientID}
|
||
msg, err := protocol.NewMessage(protocol.MsgTypeAuthResp, resp)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return protocol.WriteMessage(conn, msg)
|
||
}
|
||
|
||
// sendProxyConfig 发送代理配置并等待客户端确认
|
||
func (s *Server) sendProxyConfig(session *yamux.Session, rules []protocol.ProxyRule) error {
|
||
stream, err := session.Open()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer stream.Close()
|
||
|
||
cfg := protocol.ProxyConfig{Rules: rules}
|
||
msg, err := protocol.NewMessage(protocol.MsgTypeProxyConfig, cfg)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if err := protocol.WriteMessage(stream, msg); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 等待客户端确认
|
||
ack, err := protocol.ReadMessage(stream)
|
||
if err != nil {
|
||
return fmt.Errorf("wait config ack: %w", err)
|
||
}
|
||
if ack.Type != protocol.MsgTypeProxyReady {
|
||
return fmt.Errorf("unexpected ack type: %d", ack.Type)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// registerClient 注册客户端
|
||
func (s *Server) registerClient(cs *ClientSession) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
s.clients[cs.ID] = cs
|
||
}
|
||
|
||
// unregisterClient 注销客户端
|
||
func (s *Server) unregisterClient(cs *ClientSession) {
|
||
s.mu.Lock()
|
||
defer s.mu.Unlock()
|
||
|
||
cs.mu.Lock()
|
||
for port, ln := range cs.Listeners {
|
||
ln.Close()
|
||
s.portManager.Release(port)
|
||
}
|
||
for port, conn := range cs.UDPConns {
|
||
conn.Close()
|
||
s.portManager.Release(port)
|
||
}
|
||
cs.mu.Unlock()
|
||
|
||
delete(s.clients, cs.ID)
|
||
}
|
||
|
||
// stopProxyListeners 停止代理监听
|
||
func (s *Server) stopProxyListeners(cs *ClientSession) {
|
||
cs.mu.Lock()
|
||
defer cs.mu.Unlock()
|
||
|
||
// 关闭 TCP 监听器
|
||
for port, ln := range cs.Listeners {
|
||
ln.Close()
|
||
s.portManager.Release(port)
|
||
}
|
||
cs.Listeners = make(map[int]net.Listener)
|
||
|
||
// 关闭 UDP 连接
|
||
for port, conn := range cs.UDPConns {
|
||
conn.Close()
|
||
s.portManager.Release(port)
|
||
}
|
||
cs.UDPConns = make(map[int]*net.UDPConn)
|
||
}
|
||
|
||
// startProxyListeners 启动代理监听
|
||
func (s *Server) startProxyListeners(cs *ClientSession) {
|
||
for i := range cs.Rules {
|
||
rule := &cs.Rules[i]
|
||
if !rule.IsEnabled() {
|
||
continue
|
||
}
|
||
|
||
ruleType := rule.Type
|
||
if ruleType == "" {
|
||
ruleType = "tcp"
|
||
}
|
||
|
||
// UDP 单独处理
|
||
if ruleType == "udp" {
|
||
s.startUDPListener(cs, rule)
|
||
continue
|
||
}
|
||
|
||
// TCP 类型
|
||
if err := s.portManager.Reserve(rule.RemotePort, cs.ID); err != nil {
|
||
log.Printf("[Server] Port %d error: %v", rule.RemotePort, err)
|
||
rule.PortStatus = "failed: " + err.Error()
|
||
continue
|
||
}
|
||
|
||
ln, err := net.Listen("tcp", fmt.Sprintf(":%d", rule.RemotePort))
|
||
if err != nil {
|
||
log.Printf("[Server] Listen %d error: %v", rule.RemotePort, err)
|
||
s.portManager.Release(rule.RemotePort)
|
||
rule.PortStatus = "failed: " + err.Error()
|
||
continue
|
||
}
|
||
|
||
rule.PortStatus = "listening"
|
||
|
||
cs.mu.Lock()
|
||
cs.Listeners[rule.RemotePort] = ln
|
||
cs.mu.Unlock()
|
||
|
||
switch ruleType {
|
||
case "socks5":
|
||
log.Printf("[Server] SOCKS5 proxy %s on :%d", rule.Name, rule.RemotePort)
|
||
go s.acceptProxyServerConns(cs, ln, *rule)
|
||
case "http", "https":
|
||
log.Printf("[Server] HTTP proxy %s on :%d", rule.Name, rule.RemotePort)
|
||
go s.acceptProxyServerConns(cs, ln, *rule)
|
||
case "websocket":
|
||
log.Printf("[Server] Websocket proxy %s on :%d", rule.Name, rule.RemotePort)
|
||
go s.acceptWebsocketConns(cs, ln, *rule)
|
||
default:
|
||
log.Printf("[Server] TCP proxy %s: :%d -> %s:%d",
|
||
rule.Name, rule.RemotePort, rule.LocalIP, rule.LocalPort)
|
||
go s.acceptProxyConns(cs, ln, *rule)
|
||
}
|
||
}
|
||
}
|
||
|
||
// acceptProxyConns 接受代理连接
|
||
func (s *Server) acceptProxyConns(cs *ClientSession, ln net.Listener, rule protocol.ProxyRule) {
|
||
for {
|
||
conn, err := ln.Accept()
|
||
if err != nil {
|
||
return
|
||
}
|
||
go s.handleProxyConn(cs, conn, rule)
|
||
}
|
||
}
|
||
|
||
// acceptProxyServerConns 接受 SOCKS5/HTTP 代理连接
|
||
func (s *Server) acceptProxyServerConns(cs *ClientSession, ln net.Listener, rule protocol.ProxyRule) {
|
||
dialer := proxy.NewTunnelDialer(cs.Session)
|
||
|
||
// 使用内置 proxy 实现 (带流量统计和认证)
|
||
username := ""
|
||
password := ""
|
||
if rule.AuthEnabled {
|
||
username = rule.AuthUsername
|
||
password = rule.AuthPassword
|
||
}
|
||
proxyServer := proxy.NewServer(rule.Type, dialer, s.recordTraffic, username, password)
|
||
for {
|
||
conn, err := ln.Accept()
|
||
if err != nil {
|
||
return
|
||
}
|
||
go proxyServer.HandleConn(conn)
|
||
}
|
||
}
|
||
|
||
// handleProxyConn 处理代理连接
|
||
func (s *Server) handleProxyConn(cs *ClientSession, conn net.Conn, rule protocol.ProxyRule) {
|
||
defer conn.Close()
|
||
|
||
stream, err := cs.Session.Open()
|
||
if err != nil {
|
||
log.Printf("[Server] Open stream error: %v", err)
|
||
return
|
||
}
|
||
defer stream.Close()
|
||
|
||
req := protocol.NewProxyRequest{RemotePort: rule.RemotePort}
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypeNewProxy, req)
|
||
if err := protocol.WriteMessage(stream, msg); err != nil {
|
||
return
|
||
}
|
||
|
||
relay.RelayWithStats(conn, stream, s.recordTraffic)
|
||
}
|
||
|
||
// heartbeatLoop 心跳检测循环
|
||
func (s *Server) heartbeatLoop(cs *ClientSession) {
|
||
ticker := time.NewTicker(time.Duration(s.heartbeat) * time.Second)
|
||
defer ticker.Stop()
|
||
|
||
timeout := time.Duration(s.hbTimeout) * time.Second
|
||
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
cs.mu.Lock()
|
||
if time.Since(cs.LastPing) > timeout {
|
||
cs.mu.Unlock()
|
||
log.Printf("[Server] Client %s heartbeat timeout", cs.ID)
|
||
cs.Session.Close()
|
||
return
|
||
}
|
||
cs.mu.Unlock()
|
||
|
||
// 发送心跳并等待响应
|
||
if s.sendHeartbeat(cs) {
|
||
cs.mu.Lock()
|
||
cs.LastPing = time.Now()
|
||
cs.mu.Unlock()
|
||
}
|
||
|
||
case <-cs.Session.CloseChan():
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// sendHeartbeat 发送心跳并等待响应
|
||
func (s *Server) sendHeartbeat(cs *ClientSession) bool {
|
||
stream, err := cs.Session.Open()
|
||
if err != nil {
|
||
return false
|
||
}
|
||
defer stream.Close()
|
||
|
||
// 设置读写超时
|
||
stream.SetDeadline(time.Now().Add(heartbeatTimeout))
|
||
|
||
msg := &protocol.Message{Type: protocol.MsgTypeHeartbeat}
|
||
if err := protocol.WriteMessage(stream, msg); err != nil {
|
||
return false
|
||
}
|
||
|
||
// 等待心跳响应
|
||
resp, err := protocol.ReadMessage(stream)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
|
||
return resp.Type == protocol.MsgTypeHeartbeatAck
|
||
}
|
||
|
||
// GetClientStatus 获取客户端状态
|
||
func (s *Server) GetClientStatus(clientID string) (online bool, lastPing, remoteAddr, clientName, clientOS, clientArch, clientVersion string) {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
|
||
if cs, ok := s.clients[clientID]; ok {
|
||
cs.mu.Lock()
|
||
defer cs.mu.Unlock()
|
||
return true, cs.LastPing.Format(time.RFC3339), cs.RemoteAddr, cs.Name, cs.OS, cs.Arch, cs.Version
|
||
}
|
||
return false, "", "", "", "", "", ""
|
||
}
|
||
|
||
// IsClientOnline 检查客户端是否在线
|
||
func (s *Server) IsClientOnline(clientID string) bool {
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
_, ok := s.clients[clientID]
|
||
return ok
|
||
}
|
||
|
||
// GetAllClientStatus 获取所有客户端状态
|
||
func (s *Server) GetAllClientStatus() map[string]struct {
|
||
Online bool
|
||
LastPing string
|
||
RemoteAddr string
|
||
Name string
|
||
OS string
|
||
Arch string
|
||
Version string
|
||
} {
|
||
// 先复制客户端引用,避免嵌套锁
|
||
s.mu.RLock()
|
||
clients := make([]*ClientSession, 0, len(s.clients))
|
||
for _, cs := range s.clients {
|
||
clients = append(clients, cs)
|
||
}
|
||
s.mu.RUnlock()
|
||
|
||
result := make(map[string]struct {
|
||
Online bool
|
||
LastPing string
|
||
RemoteAddr string
|
||
Name string
|
||
OS string
|
||
Arch string
|
||
Version string
|
||
})
|
||
|
||
for _, cs := range clients {
|
||
cs.mu.Lock()
|
||
result[cs.ID] = struct {
|
||
Online bool
|
||
LastPing string
|
||
RemoteAddr string
|
||
Name string
|
||
OS string
|
||
Arch string
|
||
Version string
|
||
}{
|
||
Online: true,
|
||
LastPing: cs.LastPing.Format(time.RFC3339),
|
||
RemoteAddr: cs.RemoteAddr,
|
||
Name: cs.Name,
|
||
OS: cs.OS,
|
||
Arch: cs.Arch,
|
||
Version: cs.Version,
|
||
}
|
||
cs.mu.Unlock()
|
||
}
|
||
return result
|
||
}
|
||
|
||
// ReloadConfig 重新加载配置
|
||
// 注意: 当前版本不支持热重载,需要重启服务
|
||
func (s *Server) ReloadConfig() error {
|
||
return fmt.Errorf("hot reload not supported, please restart the server")
|
||
}
|
||
|
||
// GetBindAddr 获取绑定地址
|
||
func (s *Server) GetBindAddr() string {
|
||
return s.bindAddr
|
||
}
|
||
|
||
// GetBindPort 获取绑定端口
|
||
func (s *Server) GetBindPort() int {
|
||
return s.bindPort
|
||
}
|
||
|
||
// PushConfigToClient 推送配置到客户端
|
||
func (s *Server) PushConfigToClient(clientID string) error {
|
||
s.mu.RLock()
|
||
cs, ok := s.clients[clientID]
|
||
s.mu.RUnlock()
|
||
|
||
if !ok {
|
||
return fmt.Errorf("client %s not found", clientID)
|
||
}
|
||
|
||
rules, err := s.clientStore.GetClientRules(clientID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 停止旧的监听器
|
||
s.stopProxyListeners(cs)
|
||
|
||
// 更新规则
|
||
cs.mu.Lock()
|
||
cs.Rules = rules
|
||
cs.mu.Unlock()
|
||
|
||
// 启动新的监听器
|
||
s.startProxyListeners(cs)
|
||
|
||
// 检查是否有端口启动失败
|
||
var failedPorts []string
|
||
for _, rule := range cs.Rules {
|
||
if rule.IsEnabled() && strings.HasPrefix(rule.PortStatus, "failed:") {
|
||
failedPorts = append(failedPorts, fmt.Sprintf("port %d: %s", rule.RemotePort, strings.TrimPrefix(rule.PortStatus, "failed: ")))
|
||
}
|
||
}
|
||
|
||
// 发送配置到客户端
|
||
if err := s.sendProxyConfig(cs.Session, cs.Rules); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 如果有端口失败,返回错误
|
||
if len(failedPorts) > 0 {
|
||
return fmt.Errorf("some ports failed to start: %s", strings.Join(failedPorts, "; "))
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// DisconnectClient 断开客户端连接
|
||
func (s *Server) DisconnectClient(clientID string) error {
|
||
s.mu.RLock()
|
||
cs, ok := s.clients[clientID]
|
||
s.mu.RUnlock()
|
||
|
||
if !ok {
|
||
return fmt.Errorf("client %s not found", clientID)
|
||
}
|
||
|
||
return cs.Session.Close()
|
||
}
|
||
|
||
|
||
|
||
|
||
|
||
|
||
// startUDPListener 启动 UDP 监听
|
||
func (s *Server) startUDPListener(cs *ClientSession, rule *protocol.ProxyRule) {
|
||
if err := s.portManager.Reserve(rule.RemotePort, cs.ID); err != nil {
|
||
log.Printf("[Server] UDP port %d error: %v", rule.RemotePort, err)
|
||
rule.PortStatus = "failed: " + err.Error()
|
||
return
|
||
}
|
||
|
||
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", rule.RemotePort))
|
||
if err != nil {
|
||
log.Printf("[Server] UDP resolve error: %v", err)
|
||
s.portManager.Release(rule.RemotePort)
|
||
rule.PortStatus = "failed: " + err.Error()
|
||
return
|
||
}
|
||
|
||
conn, err := net.ListenUDP("udp", addr)
|
||
if err != nil {
|
||
log.Printf("[Server] UDP listen %d error: %v", rule.RemotePort, err)
|
||
s.portManager.Release(rule.RemotePort)
|
||
rule.PortStatus = "failed: " + err.Error()
|
||
return
|
||
}
|
||
|
||
rule.PortStatus = "listening"
|
||
|
||
cs.mu.Lock()
|
||
cs.UDPConns[rule.RemotePort] = conn
|
||
cs.mu.Unlock()
|
||
|
||
log.Printf("[Server] UDP proxy %s: :%d -> %s:%d",
|
||
rule.Name, rule.RemotePort, rule.LocalIP, rule.LocalPort)
|
||
|
||
go s.handleUDPConn(cs, conn, *rule)
|
||
}
|
||
|
||
// handleUDPConn 处理 UDP 连接
|
||
func (s *Server) handleUDPConn(cs *ClientSession, conn *net.UDPConn, rule protocol.ProxyRule) {
|
||
buf := make([]byte, udpBufferSize)
|
||
|
||
for {
|
||
n, clientAddr, err := conn.ReadFromUDP(buf)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 封装 UDP 数据包发送到客户端
|
||
packet := protocol.UDPPacket{
|
||
RemotePort: rule.RemotePort,
|
||
ClientAddr: clientAddr.String(),
|
||
Data: buf[:n],
|
||
}
|
||
|
||
go s.sendUDPPacket(cs, conn, clientAddr, packet)
|
||
}
|
||
}
|
||
|
||
// sendUDPPacket 发送 UDP 数据包到客户端
|
||
func (s *Server) sendUDPPacket(cs *ClientSession, conn *net.UDPConn, clientAddr *net.UDPAddr, packet protocol.UDPPacket) {
|
||
stream, err := cs.Session.Open()
|
||
if err != nil {
|
||
return
|
||
}
|
||
defer stream.Close()
|
||
|
||
msg, err := protocol.NewMessage(protocol.MsgTypeUDPData, packet)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
if err := protocol.WriteMessage(stream, msg); err != nil {
|
||
return
|
||
}
|
||
|
||
// 记录入站流量 (从外部接收的数据)
|
||
s.recordTraffic(int64(len(packet.Data)), 0)
|
||
|
||
// 等待客户端响应
|
||
respMsg, err := protocol.ReadMessage(stream)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
if respMsg.Type == protocol.MsgTypeUDPData {
|
||
var respPacket protocol.UDPPacket
|
||
if err := respMsg.ParsePayload(&respPacket); err != nil {
|
||
return
|
||
}
|
||
conn.WriteToUDP(respPacket.Data, clientAddr)
|
||
// 记录出站流量 (发送回外部的数据)
|
||
s.recordTraffic(0, int64(len(respPacket.Data)))
|
||
}
|
||
}
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
|
||
// checkHTTPBasicAuth 检查 HTTP Basic Auth
|
||
// 返回 (认证成功, 已读取的数据)
|
||
func (s *Server) checkHTTPBasicAuth(conn net.Conn, username, password string) (bool, []byte) {
|
||
// 设置读取超时
|
||
conn.SetReadDeadline(time.Now().Add(10 * time.Second))
|
||
defer conn.SetReadDeadline(time.Time{}) // 重置超时
|
||
|
||
// 读取 HTTP 请求头
|
||
buf := make([]byte, 8192) // 增大缓冲区以处理更大的请求头
|
||
n, err := conn.Read(buf)
|
||
if err != nil {
|
||
return false, nil
|
||
}
|
||
|
||
data := buf[:n]
|
||
request := string(data)
|
||
|
||
// 解析 Authorization 头
|
||
authHeader := ""
|
||
lines := strings.Split(request, "\r\n")
|
||
for _, line := range lines {
|
||
if strings.HasPrefix(strings.ToLower(line), "authorization:") {
|
||
authHeader = strings.TrimSpace(line[14:])
|
||
break
|
||
}
|
||
}
|
||
|
||
// 检查 Basic Auth
|
||
if authHeader == "" || !strings.HasPrefix(authHeader, "Basic ") {
|
||
s.sendHTTPUnauthorized(conn)
|
||
return false, nil
|
||
}
|
||
|
||
// 解码 Base64
|
||
encoded := authHeader[6:]
|
||
decoded, err := base64.StdEncoding.DecodeString(encoded)
|
||
if err != nil {
|
||
s.sendHTTPUnauthorized(conn)
|
||
return false, nil
|
||
}
|
||
|
||
// 解析 username:password
|
||
credentials := string(decoded)
|
||
parts := strings.SplitN(credentials, ":", 2)
|
||
if len(parts) != 2 {
|
||
s.sendHTTPUnauthorized(conn)
|
||
return false, nil
|
||
}
|
||
|
||
if parts[0] != username || parts[1] != password {
|
||
s.sendHTTPUnauthorized(conn)
|
||
return false, nil
|
||
}
|
||
|
||
return true, data
|
||
}
|
||
|
||
// sendHTTPUnauthorized 发送 401 未授权响应
|
||
func (s *Server) sendHTTPUnauthorized(conn net.Conn) {
|
||
response := "HTTP/1.1 401 Unauthorized\r\n" +
|
||
"WWW-Authenticate: Basic realm=\"GoTunnel Plugin\"\r\n" +
|
||
"Content-Type: text/plain\r\n" +
|
||
"Content-Length: 12\r\n" +
|
||
"\r\n" +
|
||
"Unauthorized"
|
||
conn.Write([]byte(response))
|
||
}
|
||
|
||
|
||
|
||
// shouldPushToClient 检查是否应推送到指定客户端
|
||
func (s *Server) shouldPushToClient(autoPush []string, clientID string) bool {
|
||
if len(autoPush) == 0 {
|
||
return true
|
||
}
|
||
for _, id := range autoPush {
|
||
if id == clientID || id == "*" {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// RestartClient 重启客户端(通过断开连接,让客户端自动重连)
|
||
func (s *Server) RestartClient(clientID string) error {
|
||
s.mu.RLock()
|
||
cs, ok := s.clients[clientID]
|
||
s.mu.RUnlock()
|
||
|
||
if !ok {
|
||
return fmt.Errorf("client %s not found or not online", clientID)
|
||
}
|
||
|
||
// 发送重启消息
|
||
stream, err := cs.Session.Open()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
req := protocol.ClientRestartRequest{
|
||
Reason: "server requested restart",
|
||
}
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypeClientRestart, req)
|
||
protocol.WriteMessage(stream, msg)
|
||
stream.Close()
|
||
|
||
// 等待一小段时间后断开连接
|
||
time.AfterFunc(100*time.Millisecond, func() {
|
||
cs.Session.Close()
|
||
})
|
||
|
||
log.Printf("[Server] Restart initiated for client %s", clientID)
|
||
return nil
|
||
}
|
||
|
||
|
||
|
||
|
||
|
||
|
||
// IsPortAvailable 检查端口是否可用
|
||
func (s *Server) IsPortAvailable(port int, excludeClientID string) bool {
|
||
// 检查系统端口
|
||
if !utils.IsPortAvailable(port) {
|
||
return false
|
||
}
|
||
// 检查是否被其他客户端占用
|
||
s.mu.RLock()
|
||
defer s.mu.RUnlock()
|
||
for clientID, cs := range s.clients {
|
||
if clientID == excludeClientID {
|
||
continue
|
||
}
|
||
cs.mu.Lock()
|
||
_, occupied := cs.Listeners[port]
|
||
cs.mu.Unlock()
|
||
if occupied {
|
||
return false
|
||
}
|
||
}
|
||
return true
|
||
}
|
||
|
||
|
||
|
||
|
||
|
||
|
||
// SendUpdateToClient 发送更新命令到客户端
|
||
func (s *Server) SendUpdateToClient(clientID, downloadURL string) error {
|
||
s.mu.RLock()
|
||
cs, ok := s.clients[clientID]
|
||
s.mu.RUnlock()
|
||
|
||
if !ok {
|
||
return fmt.Errorf("client %s not found or not online", clientID)
|
||
}
|
||
|
||
// 发送更新消息
|
||
stream, err := cs.Session.Open()
|
||
if err != nil {
|
||
return err
|
||
}
|
||
defer stream.Close()
|
||
|
||
req := protocol.UpdateDownloadRequest{
|
||
DownloadURL: downloadURL,
|
||
}
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypeUpdateDownload, req)
|
||
if err := protocol.WriteMessage(stream, msg); err != nil {
|
||
return err
|
||
}
|
||
|
||
log.Printf("[Server] Update command sent to client %s: %s", clientID, downloadURL)
|
||
return nil
|
||
}
|
||
|
||
// StartClientLogStream 启动客户端日志流
|
||
func (s *Server) StartClientLogStream(clientID, sessionID string, lines int, follow bool, level string) (<-chan protocol.LogEntry, error) {
|
||
s.mu.RLock()
|
||
cs, ok := s.clients[clientID]
|
||
s.mu.RUnlock()
|
||
|
||
if !ok {
|
||
return nil, fmt.Errorf("client %s not found or not online", clientID)
|
||
}
|
||
|
||
// 打开到客户端的流
|
||
stream, err := cs.Session.Open()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 发送日志请求
|
||
req := protocol.LogRequest{
|
||
SessionID: sessionID,
|
||
Lines: lines,
|
||
Follow: follow,
|
||
Level: level,
|
||
}
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypeLogRequest, req)
|
||
if err := protocol.WriteMessage(stream, msg); err != nil {
|
||
stream.Close()
|
||
return nil, err
|
||
}
|
||
|
||
// 创建会话
|
||
session := s.logSessions.CreateSession(clientID, sessionID, stream)
|
||
listener := session.AddListener()
|
||
|
||
// 启动 goroutine 读取客户端日志
|
||
go s.readClientLogs(session, stream)
|
||
|
||
return listener, nil
|
||
}
|
||
|
||
// readClientLogs 读取客户端日志并广播到监听器
|
||
func (s *Server) readClientLogs(session *LogSession, stream net.Conn) {
|
||
defer s.logSessions.RemoveSession(session.ID)
|
||
|
||
for {
|
||
msg, err := protocol.ReadMessage(stream)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
if msg.Type != protocol.MsgTypeLogData {
|
||
continue
|
||
}
|
||
|
||
var data protocol.LogData
|
||
if err := msg.ParsePayload(&data); err != nil {
|
||
continue
|
||
}
|
||
|
||
for _, entry := range data.Entries {
|
||
session.Broadcast(entry)
|
||
}
|
||
|
||
if data.EOF {
|
||
return
|
||
}
|
||
}
|
||
}
|
||
|
||
// StopClientLogStream 停止客户端日志流
|
||
func (s *Server) StopClientLogStream(sessionID string) {
|
||
session := s.logSessions.GetSession(sessionID)
|
||
if session == nil {
|
||
return
|
||
}
|
||
|
||
// 发送停止请求到客户端
|
||
s.mu.RLock()
|
||
cs, ok := s.clients[session.ClientID]
|
||
s.mu.RUnlock()
|
||
|
||
if ok {
|
||
stream, err := cs.Session.Open()
|
||
if err == nil {
|
||
req := protocol.LogStopRequest{SessionID: sessionID}
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypeLogStop, req)
|
||
protocol.WriteMessage(stream, msg)
|
||
stream.Close()
|
||
}
|
||
}
|
||
|
||
s.logSessions.RemoveSession(sessionID)
|
||
}
|
||
|
||
// recordTraffic 记录流量统计
|
||
func (s *Server) recordTraffic(inbound, outbound int64) {
|
||
if s.trafficStore == nil {
|
||
return
|
||
}
|
||
if err := s.trafficStore.AddTraffic(inbound, outbound); err != nil {
|
||
log.Printf("[Server] Record traffic error: %v", err)
|
||
}
|
||
}
|
||
|
||
// boolPtr 返回 bool 值的指针
|
||
func boolPtr(b bool) *bool {
|
||
return &b
|
||
}
|
||
|
||
// GetClientSystemStats 获取客户端系统状态
|
||
func (s *Server) GetClientSystemStats(clientID string) (*protocol.SystemStatsResponse, error) {
|
||
s.mu.RLock()
|
||
cs, ok := s.clients[clientID]
|
||
s.mu.RUnlock()
|
||
|
||
if !ok {
|
||
return nil, fmt.Errorf("client %s not online", clientID)
|
||
}
|
||
|
||
stream, err := cs.Session.Open()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer stream.Close()
|
||
|
||
// 设置超时
|
||
stream.SetDeadline(time.Now().Add(10 * time.Second))
|
||
|
||
// 发送请求
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypeSystemStatsRequest, nil)
|
||
if err := protocol.WriteMessage(stream, msg); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 读取响应
|
||
resp, err := protocol.ReadMessage(stream)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if resp.Type != protocol.MsgTypeSystemStatsResponse {
|
||
return nil, fmt.Errorf("unexpected response type: %d", resp.Type)
|
||
}
|
||
|
||
var stats protocol.SystemStatsResponse
|
||
if err := resp.ParsePayload(&stats); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &stats, nil
|
||
}
|
||
|
||
// GetClientScreenshot 获取客户端截图
|
||
func (s *Server) GetClientScreenshot(clientID string, quality int) (*protocol.ScreenshotResponse, error) {
|
||
s.mu.RLock()
|
||
cs, ok := s.clients[clientID]
|
||
s.mu.RUnlock()
|
||
|
||
if !ok {
|
||
return nil, fmt.Errorf("client %s not online", clientID)
|
||
}
|
||
|
||
stream, err := cs.Session.Open()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer stream.Close()
|
||
|
||
// 设置超时
|
||
stream.SetDeadline(time.Now().Add(15 * time.Second))
|
||
|
||
// 发送请求
|
||
req := protocol.ScreenshotRequest{Quality: quality}
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypeScreenshotRequest, req)
|
||
if err := protocol.WriteMessage(stream, msg); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 读取响应
|
||
resp, err := protocol.ReadMessage(stream)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if resp.Type != protocol.MsgTypeScreenshotResponse {
|
||
return nil, fmt.Errorf("unexpected response type: %d", resp.Type)
|
||
}
|
||
|
||
var screenshot protocol.ScreenshotResponse
|
||
if err := resp.ParsePayload(&screenshot); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if screenshot.Error != "" {
|
||
return nil, fmt.Errorf("screenshot failed: %s", screenshot.Error)
|
||
}
|
||
|
||
return &screenshot, nil
|
||
}
|
||
|
||
// ExecuteClientShell 执行客户端 Shell 命令
|
||
func (s *Server) ExecuteClientShell(clientID, command string, timeout int) (*protocol.ShellExecuteResponse, error) {
|
||
s.mu.RLock()
|
||
cs, ok := s.clients[clientID]
|
||
s.mu.RUnlock()
|
||
|
||
if !ok {
|
||
return nil, fmt.Errorf("client %s not online", clientID)
|
||
}
|
||
|
||
stream, err := cs.Session.Open()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
defer stream.Close()
|
||
|
||
// 设置超时 (比命令超时长一点)
|
||
if timeout <= 0 {
|
||
timeout = 30
|
||
}
|
||
stream.SetDeadline(time.Now().Add(time.Duration(timeout+5) * time.Second))
|
||
|
||
// 发送请求
|
||
req := protocol.ShellExecuteRequest{
|
||
Command: command,
|
||
Timeout: timeout,
|
||
}
|
||
msg, _ := protocol.NewMessage(protocol.MsgTypeShellExecuteRequest, req)
|
||
if err := protocol.WriteMessage(stream, msg); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
// 读取响应
|
||
resp, err := protocol.ReadMessage(stream)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if resp.Type != protocol.MsgTypeShellExecuteResponse {
|
||
return nil, fmt.Errorf("unexpected response type: %d", resp.Type)
|
||
}
|
||
|
||
var result protocol.ShellExecuteResponse
|
||
if err := resp.ParsePayload(&result); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &result, nil
|
||
}
|