From e40d079f7aa5ecd9aebd1cb141e7528d2ad0c93c Mon Sep 17 00:00:00 2001 From: Flik Date: Thu, 29 Jan 2026 15:38:27 +0800 Subject: [PATCH] feat(server): add traffic storage and statistics tracking for improved traffic management --- cmd/server/main.go | 1 + internal/server/db/sqlite.go | 28 +++++++++++++++++++++++----- internal/server/tunnel/server.go | 22 +++++++++++++++++++--- internal/server/tunnel/websocket.go | 2 +- 4 files changed, 44 insertions(+), 9 deletions(-) diff --git a/cmd/server/main.go b/cmd/server/main.go index 4abc304..46b3243 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -84,6 +84,7 @@ func main() { registry := plugin.NewRegistry() server.SetPluginRegistry(registry) server.SetJSPluginStore(clientStore) // 设置 JS 插件存储,用于客户端重连时恢复插件 + server.SetTrafficStore(clientStore) // 设置流量存储,用于记录流量统计 // 启动 Web 控制台 if cfg.Server.Web.Enabled { diff --git a/internal/server/db/sqlite.go b/internal/server/db/sqlite.go index eabe82c..b5d3e86 100644 --- a/internal/server/db/sqlite.go +++ b/internal/server/db/sqlite.go @@ -397,14 +397,19 @@ func (s *SQLiteStore) Get24HourTraffic() (inbound, outbound int64, err error) { return } -// GetHourlyTraffic 获取每小时流量记录 +// GetHourlyTraffic 获取每小时流量记录(始终返回完整的 hours 小时数据) func (s *SQLiteStore) GetHourlyTraffic(hours int) ([]TrafficRecord, error) { s.mu.RLock() defer s.mu.RUnlock() - cutoff := time.Now().Add(-time.Duration(hours) * time.Hour).Unix() + // 计算当前小时的起始时间戳 + now := time.Now() + currentHour := time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location()) + + // 查询数据库中的记录 + cutoff := currentHour.Add(-time.Duration(hours-1) * time.Hour).Unix() rows, err := s.db.Query(` - SELECT hour_ts, inbound, outbound FROM traffic_stats + SELECT hour_ts, inbound, outbound FROM traffic_stats WHERE hour_ts >= ? ORDER BY hour_ts ASC `, cutoff) if err != nil { @@ -412,13 +417,26 @@ func (s *SQLiteStore) GetHourlyTraffic(hours int) ([]TrafficRecord, error) { } defer rows.Close() - var records []TrafficRecord + // 将数据库记录放入 map 以便快速查找 + dbRecords := make(map[int64]TrafficRecord) for rows.Next() { var r TrafficRecord if err := rows.Scan(&r.Timestamp, &r.Inbound, &r.Outbound); err != nil { return nil, err } - records = append(records, r) + dbRecords[r.Timestamp] = r } + + // 生成完整的 hours 小时数据 + records := make([]TrafficRecord, hours) + for i := 0; i < hours; i++ { + ts := currentHour.Add(-time.Duration(hours-1-i) * time.Hour).Unix() + if r, ok := dbRecords[ts]; ok { + records[i] = r + } else { + records[i] = TrafficRecord{Timestamp: ts, Inbound: 0, Outbound: 0} + } + } + return records, nil } diff --git a/internal/server/tunnel/server.go b/internal/server/tunnel/server.go index 5d1692c..3e58130 100644 --- a/internal/server/tunnel/server.go +++ b/internal/server/tunnel/server.go @@ -50,7 +50,8 @@ func generateClientID() string { // Server 隧道服务端 type Server struct { clientStore db.ClientStore - jsPluginStore db.JSPluginStore // JS 插件存储 + jsPluginStore db.JSPluginStore // JS 插件存储 + trafficStore db.TrafficStore // 流量存储 bindAddr string bindPort int token string @@ -161,6 +162,11 @@ func (s *Server) SetJSPluginStore(store db.JSPluginStore) { s.jsPluginStore = store } +// SetTrafficStore 设置流量存储 +func (s *Server) SetTrafficStore(store db.TrafficStore) { + s.trafficStore = store +} + // LoadJSPlugins 加载 JS 插件配置 func (s *Server) LoadJSPlugins(plugins []JSPluginEntry) { s.jsPlugins = plugins @@ -536,7 +542,7 @@ func (s *Server) handleProxyConn(cs *ClientSession, conn net.Conn, rule protocol return } - relay.Relay(conn, stream) + relay.RelayWithStats(conn, stream, s.recordTraffic) } // heartbeatLoop 心跳检测循环 @@ -1224,7 +1230,7 @@ func (s *Server) handleClientPluginConn(cs *ClientSession, conn net.Conn, rule p } } - relay.Relay(conn, stream) + relay.RelayWithStats(conn, stream, s.recordTraffic) } // checkHTTPBasicAuth 检查 HTTP Basic Auth @@ -1907,6 +1913,16 @@ func (s *Server) StopClientLogStream(sessionID string) { s.logSessions.RemoveSession(sessionID) } +// recordTraffic 记录流量统计 +func (s *Server) recordTraffic(inbound, outbound int64) { + if s.trafficStore == nil { + return + } + if err := s.trafficStore.AddTraffic(inbound, outbound); err != nil { + log.Printf("[Server] Record traffic error: %v", err) + } +} + // boolPtr 返回 bool 值的指针 func boolPtr(b bool) *bool { return &b diff --git a/internal/server/tunnel/websocket.go b/internal/server/tunnel/websocket.go index 77ac48c..2d834b0 100644 --- a/internal/server/tunnel/websocket.go +++ b/internal/server/tunnel/websocket.go @@ -142,5 +142,5 @@ func (s *Server) handleWebsocketProxyConn(cs *ClientSession, conn net.Conn, rule return } - relay.Relay(conn, stream) + relay.RelayWithStats(conn, stream, s.recordTraffic) }