update
All checks were successful
Build Multi-Platform Binaries / build-frontend (push) Successful in 30s
Build Multi-Platform Binaries / build-binaries (amd64, darwin, server, false) (push) Successful in 58s
Build Multi-Platform Binaries / build-binaries (amd64, linux, client, true) (push) Successful in 48s
Build Multi-Platform Binaries / build-binaries (amd64, linux, server, true) (push) Successful in 1m23s
Build Multi-Platform Binaries / build-binaries (amd64, windows, client, true) (push) Successful in 56s
Build Multi-Platform Binaries / build-binaries (amd64, windows, server, true) (push) Successful in 58s
Build Multi-Platform Binaries / build-binaries (arm, 7, linux, client, true) (push) Successful in 52s
Build Multi-Platform Binaries / build-binaries (arm, 7, linux, server, true) (push) Successful in 1m42s
Build Multi-Platform Binaries / build-binaries (arm64, darwin, server, false) (push) Successful in 1m19s
Build Multi-Platform Binaries / build-binaries (arm64, linux, client, true) (push) Successful in 54s
Build Multi-Platform Binaries / build-binaries (arm64, linux, server, true) (push) Successful in 2m3s
Build Multi-Platform Binaries / build-binaries (arm64, windows, server, false) (push) Successful in 1m1s

This commit is contained in:
Flik
2025-12-29 23:08:15 +08:00
parent d4984c8d78
commit 4d2a2a7117
10 changed files with 1429 additions and 35 deletions

View File

@@ -108,10 +108,26 @@ func setDefaults(cfg *ServerConfig) {
// generateToken 生成随机 token
func generateToken(length int) string {
bytes := make([]byte, length/2)
rand.Read(bytes)
n, err := rand.Read(bytes)
if err != nil || n != len(bytes) {
// 安全关键:随机数生成失败时 panic
panic("crypto/rand failed: unable to generate secure token")
}
return hex.EncodeToString(bytes)
}
// GenerateWebCredentials 生成 Web 控制台凭据
func GenerateWebCredentials(cfg *ServerConfig) bool {
if cfg.Web.Username == "" {
cfg.Web.Username = "admin"
}
if cfg.Web.Password == "" {
cfg.Web.Password = generateToken(16)
return true // 表示生成了新密码
}
return false
}
// SaveServerConfig 保存服务端配置
func SaveServerConfig(path string, cfg *ServerConfig) error {
data, err := yaml.Marshal(cfg)

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"github.com/gotunnel/pkg/auth"
"github.com/gotunnel/pkg/security"
)
// AuthHandler 认证处理器
@@ -51,6 +52,7 @@ func (h *AuthHandler) handleLogin(w http.ResponseWriter, r *http.Request) {
passMatch := subtle.ConstantTimeCompare([]byte(req.Password), []byte(h.password)) == 1
if !userMatch || !passMatch {
security.LogWebLogin(r.RemoteAddr, req.Username, false)
http.Error(w, `{"error":"invalid credentials"}`, http.StatusUnauthorized)
return
}
@@ -62,6 +64,7 @@ func (h *AuthHandler) handleLogin(w http.ResponseWriter, r *http.Request) {
return
}
security.LogWebLogin(r.RemoteAddr, req.Username, true)
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]string{
"token": token,
@@ -75,6 +78,31 @@ func (h *AuthHandler) handleCheck(w http.ResponseWriter, r *http.Request) {
return
}
// 从 Authorization header 获取 token
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, `{"error":"missing authorization header"}`, http.StatusUnauthorized)
return
}
// 解析 Bearer token
const prefix = "Bearer "
if len(authHeader) < len(prefix) || authHeader[:len(prefix)] != prefix {
http.Error(w, `{"error":"invalid authorization format"}`, http.StatusUnauthorized)
return
}
tokenStr := authHeader[len(prefix):]
// 验证 token
claims, err := h.jwtAuth.ValidateToken(tokenStr)
if err != nil {
http.Error(w, `{"error":"invalid token"}`, http.StatusUnauthorized)
return
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(map[string]bool{"valid": true})
json.NewEncoder(w).Encode(map[string]interface{}{
"valid": true,
"username": claims.Username,
})
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"log"
"net"
"regexp"
"sync"
"time"
@@ -16,6 +17,7 @@ import (
"github.com/gotunnel/pkg/protocol"
"github.com/gotunnel/pkg/proxy"
"github.com/gotunnel/pkg/relay"
"github.com/gotunnel/pkg/security"
"github.com/gotunnel/pkg/utils"
"github.com/hashicorp/yamux"
)
@@ -25,8 +27,17 @@ const (
authTimeout = 10 * time.Second
heartbeatTimeout = 10 * time.Second
udpBufferSize = 65535
maxConnections = 10000 // 最大连接数
)
// 客户端 ID 验证正则
var clientIDRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]{1,64}$`)
// isValidClientID 验证客户端 ID 格式
func isValidClientID(id string) bool {
return clientIDRegex.MatchString(id)
}
// generateClientID 生成随机客户端 ID
func generateClientID() string {
bytes := make([]byte, 8)
@@ -48,6 +59,11 @@ type Server struct {
tlsConfig *tls.Config
pluginRegistry *plugin.Registry
jsPlugins []JSPluginEntry // 配置的 JS 插件
connSem chan struct{} // 连接数信号量
activeConns int64 // 当前活跃连接数
listener net.Listener // 主监听器
shutdown chan struct{} // 关闭信号
wg sync.WaitGroup // 等待所有连接关闭
}
// JSPluginEntry JS 插件条目
@@ -83,6 +99,8 @@ func NewServer(cs db.ClientStore, bindAddr string, bindPort int, token string, h
hbTimeout: hbTimeout,
portManager: utils.NewPortManager(),
clients: make(map[string]*ClientSession),
connSem: make(chan struct{}, maxConnections),
shutdown: make(chan struct{}),
}
}
@@ -91,6 +109,39 @@ func (s *Server) SetTLSConfig(config *tls.Config) {
s.tlsConfig = config
}
// Shutdown 优雅关闭服务端
func (s *Server) Shutdown(timeout time.Duration) error {
log.Printf("[Server] Initiating graceful shutdown...")
close(s.shutdown)
if s.listener != nil {
s.listener.Close()
}
// 关闭所有客户端会话
s.mu.Lock()
for _, cs := range s.clients {
cs.Session.Close()
}
s.mu.Unlock()
// 等待所有连接关闭,带超时
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
log.Printf("[Server] All connections closed gracefully")
return nil
case <-time.After(timeout):
log.Printf("[Server] Shutdown timeout, forcing close")
return fmt.Errorf("shutdown timeout")
}
}
// SetPluginRegistry 设置插件注册表
func (s *Server) SetPluginRegistry(registry *plugin.Registry) {
s.pluginRegistry = registry
@@ -122,20 +173,49 @@ func (s *Server) Run() error {
}
log.Printf("[Server] Listening on %s (no TLS)", addr)
}
defer ln.Close()
s.listener = ln
for {
select {
case <-s.shutdown:
log.Printf("[Server] Shutdown signal received, stopping accept loop")
ln.Close()
return nil
default:
}
conn, err := ln.Accept()
if err != nil {
log.Printf("[Server] Accept error: %v", err)
continue
select {
case <-s.shutdown:
return nil
default:
log.Printf("[Server] Accept error: %v", err)
continue
}
}
go s.handleConnection(conn)
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.handleConnection(conn)
}()
}
}
// handleConnection 处理客户端连接
func (s *Server) handleConnection(conn net.Conn) {
clientIP := conn.RemoteAddr().String()
// 连接数限制检查
select {
case s.connSem <- struct{}{}:
defer func() { <-s.connSem }()
default:
security.LogConnRejected(clientIP, "max connections reached")
conn.Close()
return
}
defer conn.Close()
conn.SetReadDeadline(time.Now().Add(authTimeout))
@@ -158,6 +238,7 @@ func (s *Server) handleConnection(conn net.Conn) {
}
if authReq.Token != s.token {
security.LogInvalidToken(clientIP)
s.sendAuthResponse(conn, false, "invalid token", "")
return
}
@@ -166,6 +247,10 @@ func (s *Server) handleConnection(conn net.Conn) {
clientID := authReq.ClientID
if clientID == "" {
clientID = generateClientID()
} else if !isValidClientID(clientID) {
security.LogInvalidClientID(clientIP, clientID)
s.sendAuthResponse(conn, false, "invalid client id format", "")
return
}
// 检查客户端是否存在,不存在则自动创建
@@ -191,7 +276,7 @@ func (s *Server) handleConnection(conn net.Conn) {
return
}
log.Printf("[Server] Client %s authenticated", clientID)
security.LogAuthSuccess(clientIP, clientID)
s.setupClientSession(conn, clientID, rules)
}