Files
GoTunnel/internal/server/tunnel/server.go
Flik 549f9aaf26
All checks were successful
Build Multi-Platform Binaries / build (push) Successful in 11m54s
update
2025-12-26 17:14:54 +08:00

671 lines
16 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package tunnel
import (
	"crypto/rand"
	"crypto/subtle"
	"crypto/tls"
	"encoding/hex"
	"fmt"
	"log"
	"net"
	"sync"
	"time"

	"github.com/gotunnel/internal/server/db"
	"github.com/gotunnel/internal/server/router"
	"github.com/gotunnel/pkg/plugin"
	"github.com/gotunnel/pkg/protocol"
	"github.com/gotunnel/pkg/proxy"
	"github.com/gotunnel/pkg/relay"
	"github.com/gotunnel/pkg/utils"
	"github.com/hashicorp/yamux"
)
// Server-side tunables.
const (
	// udpBufferSize is the size of the per-listener buffer that inbound UDP
	// datagrams are read into (65535 bytes).
	udpBufferSize = 65535

	// authTimeout bounds how long a freshly accepted connection may take to
	// deliver its authentication message.
	authTimeout = time.Second * 10
	// heartbeatTimeout bounds one heartbeat round trip: writing the probe and
	// reading the client's acknowledgement.
	heartbeatTimeout = time.Second * 10
)
// generateClientID returns a new client ID: 8 random bytes from crypto/rand,
// hex-encoded to 16 lowercase characters.
//
// If crypto/rand fails, the buffer could be left partly or fully zeroed, and
// every client registered during the failure would get the same predictable
// ID. In that case the error is logged and a time-derived ID of the same
// length and alphabet is returned instead.
func generateClientID() string {
	b := make([]byte, 8)
	if _, err := rand.Read(b); err != nil {
		log.Printf("[Server] crypto/rand failed, using time-based client ID: %v", err)
		return fmt.Sprintf("%016x", uint64(time.Now().UnixNano()))
	}
	return hex.EncodeToString(b)
}
// Server is the tunnel server. It accepts client control connections on
// bindAddr:bindPort, authenticates them with token, and opens public
// listeners for each connected client's proxy rules.
type Server struct {
clientStore db.ClientStore // store of known clients and their proxy rules (CreateClient / ClientExists / GetClientRules)
bindAddr string // host part of the control listener address
bindPort int // port of the control listener
token string // shared secret that every client must present when authenticating
heartbeat int // interval between heartbeat probes, in seconds
hbTimeout int // seconds without an acknowledged heartbeat before a client's session is closed
portManager *utils.PortManager // reservations of public ports, recorded per client ID
clients map[string]*ClientSession // connected clients keyed by client ID; guarded by mu
mu sync.RWMutex // guards clients
tlsConfig *tls.Config // TLS settings for the control listener; nil means plain TCP
pluginRegistry *plugin.Registry // optional proxy plugins; nil means only the built-in proxy handlers are used
}
// ClientSession is the server-side state of one connected, authenticated client.
type ClientSession struct {
ID string // client ID the client authenticated as
Session *yamux.Session // multiplexed session running over the client's control connection
Rules []protocol.ProxyRule // proxy rules loaded from the store when the client connected
Listeners map[int]net.Listener // TCP-based listeners (raw TCP, SOCKS5, HTTP) keyed by public port; guarded by mu
UDPConns map[int]*net.UDPConn // UDP sockets keyed by public port; guarded by mu
LastPing time.Time // time of the last acknowledged heartbeat (initially the connect time); guarded by mu
mu sync.Mutex // guards Listeners, UDPConns and LastPing
}
// NewServer builds a Server that will listen on bindAddr:bindPort and accept
// clients presenting token. heartbeat is the probe interval and hbTimeout the
// allowed silence before a client is dropped, both in seconds. TLS and plugins
// stay disabled until SetTLSConfig / SetPluginRegistry are called.
func NewServer(cs db.ClientStore, bindAddr string, bindPort int, token string, heartbeat, hbTimeout int) *Server {
	srv := &Server{
		clients:     make(map[string]*ClientSession),
		portManager: utils.NewPortManager(),
		clientStore: cs,
		token:       token,
		bindAddr:    bindAddr,
		bindPort:    bindPort,
		heartbeat:   heartbeat,
		hbTimeout:   hbTimeout,
	}
	return srv
}
// SetTLSConfig installs the TLS configuration that Run uses for the control
// listener. A nil config (the default) makes Run listen on plain TCP.
// The field is read without locking, so call this before Run.
func (s *Server) SetTLSConfig(config *tls.Config) {
	s.tlsConfig = config
}
// SetPluginRegistry attaches the plugin registry. SOCKS5/HTTP proxy rules use a
// handler from it when one is registered for the rule type, and the plugin
// management methods require it. The field is read without locking, so call
// this before Run.
func (s *Server) SetPluginRegistry(registry *plugin.Registry) {
	s.pluginRegistry = registry
}
// Run opens the control listener on bindAddr:bindPort and serves each accepted
// client connection on its own goroutine. The listener uses TLS when a config
// was installed with SetTLSConfig, and plain TCP otherwise.
//
// Run only returns if the listener cannot be created. The returned error wraps
// the underlying listen error.
func (s *Server) Run() error {
	host := s.bindAddr
	// Accept both an already-bracketed IPv6 literal such as "[::1]" and a bare
	// one such as "::". JoinHostPort re-adds brackets where required; plain
	// "%s:%d" formatting would yield ":::port", which net.Listen rejects.
	if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
		host = host[1 : len(host)-1]
	}
	addr := net.JoinHostPort(host, fmt.Sprint(s.bindPort))

	var ln net.Listener
	var err error
	if s.tlsConfig != nil {
		ln, err = tls.Listen("tcp", addr, s.tlsConfig)
		if err != nil {
			return fmt.Errorf("failed to listen TLS on %s: %w", addr, err)
		}
		log.Printf("[Server] TLS listening on %s", addr)
	} else {
		ln, err = net.Listen("tcp", addr)
		if err != nil {
			return fmt.Errorf("failed to listen on %s: %w", addr, err)
		}
		log.Printf("[Server] Listening on %s (no TLS)", addr)
	}
	defer ln.Close()

	// Accept errors such as running out of file descriptors can persist.
	// Retrying immediately would pin a CPU and flood the log, so back off
	// exponentially (5ms doubling up to 1s) and reset after a successful
	// Accept. ln is private to Run, so nothing closes it while this loop runs.
	const (
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)
	var backoff time.Duration
	for {
		conn, err := ln.Accept()
		if err != nil {
			if backoff == 0 {
				backoff = minAcceptBackoff
			} else {
				backoff *= 2
			}
			if backoff > maxAcceptBackoff {
				backoff = maxAcceptBackoff
			}
			log.Printf("[Server] Accept error: %v (retrying in %v)", err, backoff)
			time.Sleep(backoff)
			continue
		}
		backoff = 0
		go s.handleConnection(conn)
	}
}
// handleConnection authenticates one control connection and, on success, hands
// it to setupClientSession for the rest of the client's lifetime. The
// connection is closed when this function returns.
//
// The first message must be MsgTypeAuth and must arrive within authTimeout.
// A client that sends an empty ClientID is assigned a freshly generated one,
// which is persisted in the client store. A non-empty ClientID must already
// exist in the store.
func (s *Server) handleConnection(conn net.Conn) {
	defer conn.Close()

	// Unauthenticated peers get only authTimeout to identify themselves.
	conn.SetReadDeadline(time.Now().Add(authTimeout))
	msg, err := protocol.ReadMessage(conn)
	if err != nil {
		log.Printf("[Server] Read auth error: %v", err)
		return
	}
	if msg.Type != protocol.MsgTypeAuth {
		log.Printf("[Server] Expected auth, got %d", msg.Type)
		return
	}
	var authReq protocol.AuthRequest
	if err := msg.ParsePayload(&authReq); err != nil {
		log.Printf("[Server] Parse auth error: %v", err)
		return
	}

	// The token comes from an untrusted peer. Compare in constant time so that
	// response timing does not reveal how much of a guessed token matched.
	// Only the token length can still be inferred.
	if subtle.ConstantTimeCompare([]byte(authReq.Token), []byte(s.token)) != 1 {
		s.sendAuthResponse(conn, false, "invalid token", "")
		return
	}

	// A client without an ID is new: generate one and create its record.
	clientID := authReq.ClientID
	if clientID == "" {
		clientID = generateClientID()
		newClient := &db.Client{ID: clientID, Rules: []protocol.ProxyRule{}}
		if err := s.clientStore.CreateClient(newClient); err != nil {
			log.Printf("[Server] Create client error: %v", err)
			s.sendAuthResponse(conn, false, "failed to create client", "")
			return
		}
		log.Printf("[Server] New client registered: %s", clientID)
	}

	// The client must be known to the store. A store failure is logged so it
	// is not mistaken for an unknown ID, but the client is still rejected.
	exists, err := s.clientStore.ClientExists(clientID)
	if err != nil {
		log.Printf("[Server] Check client %s error: %v", clientID, err)
	}
	if err != nil || !exists {
		s.sendAuthResponse(conn, false, "client not found", "")
		return
	}

	// If rules cannot be loaded, the client is still served, with whatever
	// the store returned (no rules at all when that is nil). The failure is
	// logged rather than dropped silently.
	rules, err := s.clientStore.GetClientRules(clientID)
	if err != nil {
		log.Printf("[Server] Load rules for client %s error: %v", clientID, err)
	}
	if rules == nil {
		rules = []protocol.ProxyRule{}
	}

	// Authentication is complete; the long-lived session must not time out on reads.
	conn.SetReadDeadline(time.Time{})
	if err := s.sendAuthResponse(conn, true, "ok", clientID); err != nil {
		return
	}
	log.Printf("[Server] Client %s authenticated", clientID)
	s.setupClientSession(conn, clientID, rules)
}
// setupClientSession runs the multiplexed session of an authenticated client.
// It performs these steps in order:
//   - wrap conn in a yamux server session;
//   - register the client;
//   - push the client's proxy rules;
//   - open the public listeners;
//   - start heartbeat monitoring;
//   - block until the session closes.
// The registration (and the listeners it owns) is torn down on return.
func (s *Server) setupClientSession(conn net.Conn, clientID string, rules []protocol.ProxyRule) {
	mux, err := yamux.Server(conn, nil)
	if err != nil {
		log.Printf("[Server] Yamux error: %v", err)
		return
	}

	cs := &ClientSession{
		ID:        clientID,
		Session:   mux,
		Rules:     rules,
		LastPing:  time.Now(),
		Listeners: make(map[int]net.Listener),
		UDPConns:  make(map[int]*net.UDPConn),
	}
	s.registerClient(cs)
	defer s.unregisterClient(cs)

	if err = s.sendProxyConfig(mux, rules); err != nil {
		log.Printf("[Server] Send config error: %v", err)
		return
	}

	s.startProxyListeners(cs)
	go s.heartbeatLoop(cs)

	// Park here for the lifetime of the session.
	<-mux.CloseChan()
	log.Printf("[Server] Client %s disconnected", clientID)
}
// sendAuthResponse writes a MsgTypeAuthResp message to conn. The message
// carries the authentication outcome, a human-readable message, and the client
// ID to use from now on (callers pass "" on failure).
func (s *Server) sendAuthResponse(conn net.Conn, success bool, message, clientID string) error {
	msg, err := protocol.NewMessage(protocol.MsgTypeAuthResp, protocol.AuthResponse{
		Success:  success,
		Message:  message,
		ClientID: clientID,
	})
	if err != nil {
		return err
	}
	return protocol.WriteMessage(conn, msg)
}
// sendProxyConfig opens a fresh stream on session and writes one
// MsgTypeProxyConfig message containing rules. No reply is awaited, and the
// stream is closed before returning.
func (s *Server) sendProxyConfig(session *yamux.Session, rules []protocol.ProxyRule) error {
	stream, err := session.Open()
	if err != nil {
		return err
	}
	defer stream.Close()

	payload := protocol.ProxyConfig{Rules: rules}
	msg, err := protocol.NewMessage(protocol.MsgTypeProxyConfig, payload)
	if err != nil {
		return err
	}
	return protocol.WriteMessage(stream, msg)
}
// registerClient publishes cs in s.clients under its ID. Any entry already
// stored for that ID is overwritten.
func (s *Server) registerClient(cs *ClientSession) {
	s.mu.Lock()
	s.clients[cs.ID] = cs
	s.mu.Unlock()
}
// unregisterClient tears down cs after its session has ended. It removes cs
// from s.clients, closes every TCP listener and UDP socket it opened, and
// releases their ports in the port manager.
//
// The registry entry is removed only if it still points at cs. registerClient
// overwrites the entry when a client reconnects with the same ID, which can
// happen before the old session is torn down. Deleting unconditionally would
// then drop the newer, live session from s.clients.
//
// The server lock and the session lock are taken one after the other and never
// nested, so closing sockets does not block other users of s.mu.
func (s *Server) unregisterClient(cs *ClientSession) {
	s.mu.Lock()
	if s.clients[cs.ID] == cs {
		delete(s.clients, cs.ID)
	}
	s.mu.Unlock()

	cs.mu.Lock()
	defer cs.mu.Unlock()
	for port, ln := range cs.Listeners {
		ln.Close()
		s.portManager.Release(port)
	}
	for port, conn := range cs.UDPConns {
		conn.Close()
		s.portManager.Release(port)
	}
	// Empty the maps so a repeated call cannot release ports that another
	// session may have reserved in the meantime.
	cs.Listeners = make(map[int]net.Listener)
	cs.UDPConns = make(map[int]*net.UDPConn)
}
// startProxyListeners opens a public listener for every rule in cs.Rules.
//
// A rule with an empty Type is treated as "tcp", and UDP rules are delegated
// to startUDPListener. For every other type, the rule's RemotePort is reserved
// in the port manager, a TCP listener is opened on that port, and an accept
// loop is started:
//   - SOCKS5 and HTTP/HTTPS rules go to acceptProxyServerConns;
//   - anything else is forwarded as raw TCP by acceptProxyConns.
// A failure is logged and skips only the affected rule.
func (s *Server) startProxyListeners(cs *ClientSession) {
	for _, rule := range cs.Rules {
		kind := rule.Type
		if kind == "" {
			kind = "tcp"
		}
		if kind == "udp" {
			s.startUDPListener(cs, rule)
			continue
		}

		port := rule.RemotePort
		if err := s.portManager.Reserve(port, cs.ID); err != nil {
			log.Printf("[Server] Port %d error: %v", port, err)
			continue
		}
		ln, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
		if err != nil {
			log.Printf("[Server] Listen %d error: %v", port, err)
			s.portManager.Release(port)
			continue
		}

		cs.mu.Lock()
		cs.Listeners[port] = ln
		cs.mu.Unlock()

		switch kind {
		case "socks5":
			log.Printf("[Server] SOCKS5 proxy %s on :%d", rule.Name, port)
		case "http", "https":
			log.Printf("[Server] HTTP proxy %s on :%d", rule.Name, port)
		default:
			log.Printf("[Server] TCP proxy %s: :%d -> %s:%d",
				rule.Name, port, rule.LocalIP, rule.LocalPort)
			go s.acceptProxyConns(cs, ln, rule)
			continue
		}
		go s.acceptProxyServerConns(cs, ln, rule)
	}
}
// acceptProxyConns serves a raw TCP forwarding rule. Each connection accepted
// on ln is handed to handleProxyConn on its own goroutine. The loop ends at
// the first Accept error, which includes ln being closed.
func (s *Server) acceptProxyConns(cs *ClientSession, ln net.Listener, rule protocol.ProxyRule) {
	conn, err := ln.Accept()
	for err == nil {
		go s.handleProxyConn(cs, conn, rule)
		conn, err = ln.Accept()
	}
}
// acceptProxyServerConns serves a SOCKS5/HTTP proxy rule on ln. Outbound
// connections are dialed through the client's tunnel (proxy.NewTunnelDialer).
//
// If a plugin registry is configured and has a handler registered under
// rule.Type, that handler is initialised with rule.PluginConfig and serves
// the connections. Otherwise the built-in proxy.NewServer implementation
// serves them. Each accepted connection runs on its own goroutine, and the
// loop ends at the first Accept error (e.g. when ln is closed).
func (s *Server) acceptProxyServerConns(cs *ClientSession, ln net.Listener, rule protocol.ProxyRule) {
	dialer := proxy.NewTunnelDialer(cs.Session)

	var serve func(net.Conn)
	if s.pluginRegistry != nil {
		if handler, err := s.pluginRegistry.Get(rule.Type); err == nil {
			handler.Init(rule.PluginConfig)
			serve = func(c net.Conn) { handler.HandleConn(c, dialer) }
		}
	}
	if serve == nil {
		builtin := proxy.NewServer(rule.Type, dialer)
		serve = func(c net.Conn) { builtin.HandleConn(c) }
	}

	for {
		conn, err := ln.Accept()
		if err != nil {
			return
		}
		go serve(conn)
	}
}
// handleProxyConn forwards one accepted public TCP connection through the
// client's tunnel. It opens a new yamux stream and sends a MsgTypeNewProxy
// message naming the rule's RemotePort; presumably the client uses this to pick
// the local target. It then relays bytes between conn and the stream until
// relay.Relay returns. Both conn and the stream are closed on return.
func (s *Server) handleProxyConn(cs *ClientSession, conn net.Conn, rule protocol.ProxyRule) {
	defer conn.Close()
	stream, err := cs.Session.Open()
	if err != nil {
		log.Printf("[Server] Open stream error: %v", err)
		return
	}
	defer stream.Close()

	req := protocol.NewProxyRequest{RemotePort: rule.RemotePort}
	msg, err := protocol.NewMessage(protocol.MsgTypeNewProxy, req)
	if err != nil {
		// The returned message is not usable when NewMessage fails; never
		// hand it to WriteMessage.
		log.Printf("[Server] Build new-proxy message error: %v", err)
		return
	}
	if err := protocol.WriteMessage(stream, msg); err != nil {
		return
	}
	relay.Relay(conn, stream)
}
// heartbeatLoop monitors cs's liveness until its session closes. On every tick
// (every s.heartbeat seconds) it does the following:
//   - if the last acknowledged heartbeat is older than s.hbTimeout seconds, it
//     closes the session and returns;
//   - otherwise it sends a new heartbeat and, if the heartbeat is
//     acknowledged, records the current time as cs.LastPing.
//
// Monitoring is disabled (the function logs and returns) when either setting
// is not positive. time.NewTicker panics on a non-positive interval, which in
// this goroutine would crash the process. A non-positive timeout would drop
// every client on the first tick.
func (s *Server) heartbeatLoop(cs *ClientSession) {
	if s.heartbeat <= 0 || s.hbTimeout <= 0 {
		log.Printf("[Server] Heartbeat disabled for client %s (interval %ds, timeout %ds)",
			cs.ID, s.heartbeat, s.hbTimeout)
		return
	}

	ticker := time.NewTicker(time.Duration(s.heartbeat) * time.Second)
	defer ticker.Stop()
	timeout := time.Duration(s.hbTimeout) * time.Second

	for {
		select {
		case <-ticker.C:
			cs.mu.Lock()
			stale := time.Since(cs.LastPing) > timeout
			cs.mu.Unlock()
			if stale {
				log.Printf("[Server] Client %s heartbeat timeout", cs.ID)
				cs.Session.Close()
				return
			}
			// Probe without holding cs.mu: sendHeartbeat does network I/O.
			if s.sendHeartbeat(cs) {
				cs.mu.Lock()
				cs.LastPing = time.Now()
				cs.mu.Unlock()
			}
		case <-cs.Session.CloseChan():
			return
		}
	}
}
// sendHeartbeat opens a stream, sends a MsgTypeHeartbeat message, and waits for
// the reply. It reports true only if a MsgTypeHeartbeatAck arrives within
// heartbeatTimeout. Any error or other reply type yields false.
func (s *Server) sendHeartbeat(cs *ClientSession) bool {
	stream, err := cs.Session.Open()
	if err != nil {
		return false
	}
	defer stream.Close()

	// A single deadline covers both the write and the wait for the ack.
	stream.SetDeadline(time.Now().Add(heartbeatTimeout))

	probe := &protocol.Message{Type: protocol.MsgTypeHeartbeat}
	if err := protocol.WriteMessage(stream, probe); err != nil {
		return false
	}
	ack, err := protocol.ReadMessage(stream)
	return err == nil && ack.Type == protocol.MsgTypeHeartbeatAck
}
// GetClientStatus reports whether clientID currently has a registered session.
// If it does, it also returns the session's last heartbeat time formatted as
// RFC 3339. Unknown clients yield (false, "").
//
// The session pointer is looked up under the server lock, which is released
// before the session's own lock is taken. This keeps the two locks un-nested,
// as GetAllClientStatus does.
func (s *Server) GetClientStatus(clientID string) (online bool, lastPing string) {
	s.mu.RLock()
	cs, ok := s.clients[clientID]
	s.mu.RUnlock()
	if !ok {
		return false, ""
	}

	cs.mu.Lock()
	defer cs.mu.Unlock()
	return true, cs.LastPing.Format(time.RFC3339)
}
// GetAllClientStatus returns the status of every client with a registered
// session, keyed by client ID. Each entry is marked Online and carries the
// last heartbeat time formatted as RFC 3339.
func (s *Server) GetAllClientStatus() map[string]struct {
	Online   bool
	LastPing string
} {
	// Alias of the anonymous result element type (identical type, shorter to write).
	type status = struct {
		Online   bool
		LastPing string
	}

	// Snapshot the session pointers under the server lock first, then read
	// each session under its own lock, so the two locks are never held together.
	s.mu.RLock()
	snapshot := make([]*ClientSession, 0, len(s.clients))
	for _, cs := range s.clients {
		snapshot = append(snapshot, cs)
	}
	s.mu.RUnlock()

	out := make(map[string]status, len(snapshot))
	for _, cs := range snapshot {
		cs.mu.Lock()
		last := cs.LastPing.Format(time.RFC3339)
		cs.mu.Unlock()
		out[cs.ID] = status{Online: true, LastPing: last}
	}
	return out
}
// ReloadConfig always returns an error. Hot reloading is not implemented in
// this version, so configuration changes require restarting the server.
func (s *Server) ReloadConfig() error {
	const reason = "hot reload not supported, please restart the server"
	return fmt.Errorf(reason)
}
// GetBindAddr returns the host part of the control listener address, as
// passed to NewServer.
func (s *Server) GetBindAddr() string {
	return s.bindAddr
}
// GetBindPort returns the control listener port, as passed to NewServer.
func (s *Server) GetBindPort() int {
	return s.bindPort
}
// PushConfigToClient re-reads clientID's rules from the client store and sends
// them to the connected client as a MsgTypeProxyConfig message. It fails if the
// client has no registered session or the store lookup fails. Only the message
// is sent; listeners already open on the server are left unchanged.
func (s *Server) PushConfigToClient(clientID string) error {
	s.mu.RLock()
	cs, online := s.clients[clientID]
	s.mu.RUnlock()
	if !online {
		return fmt.Errorf("client %s not found", clientID)
	}

	rules, err := s.clientStore.GetClientRules(clientID)
	if err != nil {
		return err
	}
	return s.sendProxyConfig(cs.Session, rules)
}
// DisconnectClient closes the yamux session of a connected client. Teardown
// (deregistration and listener cleanup) then happens in the goroutine that was
// serving that session. It returns an error if the client is not connected,
// and otherwise whatever closing the session returns.
func (s *Server) DisconnectClient(clientID string) error {
	s.mu.RLock()
	cs, online := s.clients[clientID]
	s.mu.RUnlock()

	if online {
		return cs.Session.Close()
	}
	return fmt.Errorf("client %s not found", clientID)
}
// GetPluginList converts the registry's plugin list into router.PluginInfo
// values. It returns nil when no registry is configured or the registry lists
// no plugins.
func (s *Server) GetPluginList() []router.PluginInfo {
	if s.pluginRegistry == nil {
		return nil
	}

	var infos []router.PluginInfo
	for _, p := range s.pluginRegistry.List() {
		meta := p.Metadata
		infos = append(infos, router.PluginInfo{
			Name:        meta.Name,
			Version:     meta.Version,
			Type:        string(meta.Type),
			Description: meta.Description,
			Source:      string(meta.Source),
			Enabled:     p.Enabled,
		})
	}
	return infos
}
// EnablePlugin enables the named plugin in the registry. It fails if no
// registry has been configured.
func (s *Server) EnablePlugin(name string) error {
	if reg := s.pluginRegistry; reg != nil {
		return reg.Enable(name)
	}
	return fmt.Errorf("plugin registry not initialized")
}
// DisablePlugin disables the named plugin in the registry. It fails if no
// registry has been configured.
func (s *Server) DisablePlugin(name string) error {
	if reg := s.pluginRegistry; reg != nil {
		return reg.Disable(name)
	}
	return fmt.Errorf("plugin registry not initialized")
}
// InstallPluginsToClient asks a connected client to install the named plugins.
// It fails if the client has no registered session.
func (s *Server) InstallPluginsToClient(clientID string, plugins []string) error {
	s.mu.RLock()
	cs, online := s.clients[clientID]
	s.mu.RUnlock()

	if online {
		return s.sendInstallPlugins(cs.Session, plugins)
	}
	return fmt.Errorf("client %s not found", clientID)
}
// sendInstallPlugins opens a fresh stream on session and writes one
// MsgTypeInstallPlugins message listing plugins. No reply is awaited, and the
// stream is closed before returning.
func (s *Server) sendInstallPlugins(session *yamux.Session, plugins []string) error {
	stream, err := session.Open()
	if err != nil {
		return err
	}
	defer stream.Close()

	payload := protocol.InstallPluginsRequest{Plugins: plugins}
	msg, err := protocol.NewMessage(protocol.MsgTypeInstallPlugins, payload)
	if err != nil {
		return err
	}
	return protocol.WriteMessage(stream, msg)
}
// startUDPListener serves a UDP rule. It reserves rule.RemotePort, binds a UDP
// socket on that port, records the socket in cs.UDPConns, and starts
// handleUDPConn for it. On any failure it logs the error and releases the
// reservation.
func (s *Server) startUDPListener(cs *ClientSession, rule protocol.ProxyRule) {
	port := rule.RemotePort
	if err := s.portManager.Reserve(port, cs.ID); err != nil {
		log.Printf("[Server] UDP port %d error: %v", port, err)
		return
	}

	udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", port))
	if err != nil {
		log.Printf("[Server] UDP resolve error: %v", err)
		s.portManager.Release(port)
		return
	}
	pc, err := net.ListenUDP("udp", udpAddr)
	if err != nil {
		log.Printf("[Server] UDP listen %d error: %v", port, err)
		s.portManager.Release(port)
		return
	}

	cs.mu.Lock()
	cs.UDPConns[port] = pc
	cs.mu.Unlock()

	log.Printf("[Server] UDP proxy %s: :%d -> %s:%d",
		rule.Name, port, rule.LocalIP, rule.LocalPort)
	go s.handleUDPConn(cs, pc, rule)
}
// handleUDPConn reads datagrams from conn until it fails (e.g. when the socket
// is closed). Each datagram is wrapped in a protocol.UDPPacket, tagged with
// the rule's RemotePort and the sender's address, and forwarded to the client
// by sendUDPPacket on its own goroutine.
func (s *Server) handleUDPConn(cs *ClientSession, conn *net.UDPConn, rule protocol.ProxyRule) {
	buf := make([]byte, udpBufferSize)
	for {
		n, clientAddr, err := conn.ReadFromUDP(buf)
		if err != nil {
			return
		}
		// buf is overwritten by the next ReadFromUDP while sendUDPPacket is
		// still running concurrently. Give each goroutine its own copy of the
		// datagram instead of a slice of the shared buffer.
		data := make([]byte, n)
		copy(data, buf[:n])

		packet := protocol.UDPPacket{
			RemotePort: rule.RemotePort,
			ClientAddr: clientAddr.String(),
			Data:       data,
		}
		go s.sendUDPPacket(cs, conn, clientAddr, packet)
	}
}
// sendUDPPacket forwards one datagram to the client over a new stream and
// waits for a single reply. If the reply is a MsgTypeUDPData packet, its
// payload is written back to clientAddr through conn. Any failure silently
// drops the datagram, in keeping with UDP's best-effort semantics.
func (s *Server) sendUDPPacket(cs *ClientSession, conn *net.UDPConn, clientAddr *net.UDPAddr, packet protocol.UDPPacket) {
	// Bound the lifetime of this goroutine and its stream. Without a deadline,
	// every datagram that never gets a reply (one-way UDP traffic is common)
	// would keep both alive indefinitely.
	const udpReplyTimeout = 10 * time.Second

	stream, err := cs.Session.Open()
	if err != nil {
		return
	}
	defer stream.Close()
	stream.SetDeadline(time.Now().Add(udpReplyTimeout))

	msg, err := protocol.NewMessage(protocol.MsgTypeUDPData, packet)
	if err != nil {
		return
	}
	if err := protocol.WriteMessage(stream, msg); err != nil {
		return
	}

	// Wait for the client's reply (or the deadline).
	respMsg, err := protocol.ReadMessage(stream)
	if err != nil {
		return
	}
	if respMsg.Type != protocol.MsgTypeUDPData {
		return
	}
	var respPacket protocol.UDPPacket
	if err := respMsg.ParsePayload(&respPacket); err != nil {
		return
	}
	// Best effort: a failed write back to the UDP peer is dropped.
	conn.WriteToUDP(respPacket.Data, clientAddr)
}