This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
@@ -9,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gotunnel/internal/server/db"
|
||||
"github.com/gotunnel/internal/server/router"
|
||||
"github.com/gotunnel/pkg/plugin"
|
||||
"github.com/gotunnel/pkg/protocol"
|
||||
"github.com/gotunnel/pkg/proxy"
|
||||
@@ -19,11 +22,18 @@ import (
|
||||
|
||||
// 服务端常量
|
||||
const (
|
||||
authTimeout = 10 * time.Second
|
||||
heartbeatTimeout = 10 * time.Second
|
||||
udpBufferSize = 65535
|
||||
authTimeout = 10 * time.Second
|
||||
heartbeatTimeout = 10 * time.Second
|
||||
udpBufferSize = 65535
|
||||
)
|
||||
|
||||
// generateClientID 生成随机客户端 ID
|
||||
func generateClientID() string {
|
||||
bytes := make([]byte, 8)
|
||||
rand.Read(bytes)
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// Server 隧道服务端
|
||||
type Server struct {
|
||||
clientStore db.ClientStore
|
||||
@@ -130,24 +140,44 @@ func (s *Server) handleConnection(conn net.Conn) {
|
||||
}
|
||||
|
||||
if authReq.Token != s.token {
|
||||
s.sendAuthResponse(conn, false, "invalid token")
|
||||
s.sendAuthResponse(conn, false, "invalid token", "")
|
||||
return
|
||||
}
|
||||
|
||||
rules, err := s.clientStore.GetClientRules(authReq.ClientID)
|
||||
if err != nil || rules == nil {
|
||||
s.sendAuthResponse(conn, false, "client not configured")
|
||||
// 如果客户端没有提供 ID,则生成一个新的
|
||||
clientID := authReq.ClientID
|
||||
if clientID == "" {
|
||||
clientID = generateClientID()
|
||||
// 创建新客户端记录
|
||||
newClient := &db.Client{ID: clientID, Rules: []protocol.ProxyRule{}}
|
||||
if err := s.clientStore.CreateClient(newClient); err != nil {
|
||||
log.Printf("[Server] Create client error: %v", err)
|
||||
s.sendAuthResponse(conn, false, "failed to create client", "")
|
||||
return
|
||||
}
|
||||
log.Printf("[Server] New client registered: %s", clientID)
|
||||
}
|
||||
|
||||
// 检查客户端是否存在
|
||||
exists, err := s.clientStore.ClientExists(clientID)
|
||||
if err != nil || !exists {
|
||||
s.sendAuthResponse(conn, false, "client not found", "")
|
||||
return
|
||||
}
|
||||
|
||||
rules, _ := s.clientStore.GetClientRules(clientID)
|
||||
if rules == nil {
|
||||
rules = []protocol.ProxyRule{}
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if err := s.sendAuthResponse(conn, true, "ok"); err != nil {
|
||||
if err := s.sendAuthResponse(conn, true, "ok", clientID); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[Server] Client %s authenticated", authReq.ClientID)
|
||||
s.setupClientSession(conn, authReq.ClientID, rules)
|
||||
log.Printf("[Server] Client %s authenticated", clientID)
|
||||
s.setupClientSession(conn, clientID, rules)
|
||||
}
|
||||
|
||||
// setupClientSession 建立客户端会话
|
||||
@@ -183,8 +213,8 @@ func (s *Server) setupClientSession(conn net.Conn, clientID string, rules []prot
|
||||
}
|
||||
|
||||
// sendAuthResponse 发送认证响应
|
||||
func (s *Server) sendAuthResponse(conn net.Conn, success bool, message string) error {
|
||||
resp := protocol.AuthResponse{Success: success, Message: message}
|
||||
func (s *Server) sendAuthResponse(conn net.Conn, success bool, message, clientID string) error {
|
||||
resp := protocol.AuthResponse{Success: success, Message: message, ClientID: clientID}
|
||||
msg, err := protocol.NewMessage(protocol.MsgTypeAuthResp, resp)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -458,6 +488,103 @@ func (s *Server) GetBindPort() int {
|
||||
return s.bindPort
|
||||
}
|
||||
|
||||
// PushConfigToClient 推送配置到客户端
|
||||
func (s *Server) PushConfigToClient(clientID string) error {
|
||||
s.mu.RLock()
|
||||
cs, ok := s.clients[clientID]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("client %s not found", clientID)
|
||||
}
|
||||
|
||||
rules, err := s.clientStore.GetClientRules(clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.sendProxyConfig(cs.Session, rules)
|
||||
}
|
||||
|
||||
// DisconnectClient 断开客户端连接
|
||||
func (s *Server) DisconnectClient(clientID string) error {
|
||||
s.mu.RLock()
|
||||
cs, ok := s.clients[clientID]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("client %s not found", clientID)
|
||||
}
|
||||
|
||||
return cs.Session.Close()
|
||||
}
|
||||
|
||||
// GetPluginList 获取插件列表
|
||||
func (s *Server) GetPluginList() []router.PluginInfo {
|
||||
var result []router.PluginInfo
|
||||
|
||||
if s.pluginRegistry == nil {
|
||||
return result
|
||||
}
|
||||
|
||||
for _, info := range s.pluginRegistry.List() {
|
||||
result = append(result, router.PluginInfo{
|
||||
Name: info.Metadata.Name,
|
||||
Version: info.Metadata.Version,
|
||||
Type: string(info.Metadata.Type),
|
||||
Description: info.Metadata.Description,
|
||||
Source: string(info.Metadata.Source),
|
||||
Enabled: info.Enabled,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// EnablePlugin 启用插件
|
||||
func (s *Server) EnablePlugin(name string) error {
|
||||
if s.pluginRegistry == nil {
|
||||
return fmt.Errorf("plugin registry not initialized")
|
||||
}
|
||||
return s.pluginRegistry.Enable(name)
|
||||
}
|
||||
|
||||
// DisablePlugin 禁用插件
|
||||
func (s *Server) DisablePlugin(name string) error {
|
||||
if s.pluginRegistry == nil {
|
||||
return fmt.Errorf("plugin registry not initialized")
|
||||
}
|
||||
return s.pluginRegistry.Disable(name)
|
||||
}
|
||||
|
||||
// InstallPluginsToClient 安装插件到客户端
|
||||
func (s *Server) InstallPluginsToClient(clientID string, plugins []string) error {
|
||||
s.mu.RLock()
|
||||
cs, ok := s.clients[clientID]
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !ok {
|
||||
return fmt.Errorf("client %s not found", clientID)
|
||||
}
|
||||
|
||||
return s.sendInstallPlugins(cs.Session, plugins)
|
||||
}
|
||||
|
||||
// sendInstallPlugins 发送安装插件请求
|
||||
func (s *Server) sendInstallPlugins(session *yamux.Session, plugins []string) error {
|
||||
stream, err := session.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
req := protocol.InstallPluginsRequest{Plugins: plugins}
|
||||
msg, err := protocol.NewMessage(protocol.MsgTypeInstallPlugins, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return protocol.WriteMessage(stream, msg)
|
||||
}
|
||||
|
||||
// startUDPListener 启动 UDP 监听
|
||||
func (s *Server) startUDPListener(cs *ClientSession, rule protocol.ProxyRule) {
|
||||
if err := s.portManager.Reserve(rule.RemotePort, cs.ID); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user