package db

import (
	"database/sql"
	"encoding/json"
	"fmt"
	"sync"
	"time"

	"github.com/gotunnel/pkg/protocol"
	_ "modernc.org/sqlite"
)

// SQLiteStore is a storage implementation backed by a SQLite database.
type SQLiteStore struct {
	db *sql.DB
	// mu is taken in read mode by query methods and in write mode by
	// methods that modify data.
	mu sync.RWMutex
}

// NewSQLiteStore opens the SQLite database at dbPath and creates the schema
// if it does not exist yet. The database handle is closed again when schema
// initialization fails.
func NewSQLiteStore(dbPath string) (*SQLiteStore, error) {
	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return nil, err
	}

	s := &SQLiteStore{db: db}
	if err := s.init(); err != nil {
		// The initialization error is the one worth reporting; a failure
		// to close the handle adds nothing for the caller.
		_ = db.Close()
		return nil, err
	}
	return s, nil
}

// init creates the database tables and seeds the single traffic_total row.
func (s *SQLiteStore) init() error {
	// Clients table.
	if _, err := s.db.Exec(`
		CREATE TABLE IF NOT EXISTS clients (
			id TEXT PRIMARY KEY,
			nickname TEXT NOT NULL DEFAULT '',
			rules TEXT NOT NULL DEFAULT '[]'
		)
	`); err != nil {
		return fmt.Errorf("create clients table: %w", err)
	}

	// Migration: add the nickname column to a clients table that predates
	// it. Deliberately best-effort: the statement fails whenever the column
	// is already present, which is the case for every table created by the
	// statement above.
	_, _ = s.db.Exec(`ALTER TABLE clients ADD COLUMN nickname TEXT NOT NULL DEFAULT ''`)

	// Hourly traffic statistics table.
	if _, err := s.db.Exec(`
		CREATE TABLE IF NOT EXISTS traffic_stats (
			hour_ts INTEGER PRIMARY KEY,
			inbound INTEGER NOT NULL DEFAULT 0,
			outbound INTEGER NOT NULL DEFAULT 0
		)
	`); err != nil {
		return fmt.Errorf("create traffic_stats table: %w", err)
	}

	// Total traffic table; the CHECK constraint limits it to one row (id = 1).
	if _, err := s.db.Exec(`
		CREATE TABLE IF NOT EXISTS traffic_total (
			id INTEGER PRIMARY KEY CHECK (id = 1),
			inbound INTEGER NOT NULL DEFAULT 0,
			outbound INTEGER NOT NULL DEFAULT 0
		)
	`); err != nil {
		return fmt.Errorf("create traffic_total table: %w", err)
	}

	// Seed the total traffic row. OR IGNORE already covers the case where
	// the row exists, so an error here is a real failure; without the row
	// the UPDATE in AddTraffic would silently change nothing.
	if _, err := s.db.Exec(`INSERT OR IGNORE INTO traffic_total (id, inbound, outbound) VALUES (1, 0, 0)`); err != nil {
		return fmt.Errorf("seed traffic_total row: %w", err)
	}

	// Install token table.
	if _, err := s.db.Exec(`
		CREATE TABLE IF NOT EXISTS install_tokens (
			token TEXT PRIMARY KEY,
			client_id TEXT NOT NULL,
			created_at INTEGER NOT NULL,
			used INTEGER NOT NULL DEFAULT 0
		)
	`); err != nil {
		return fmt.Errorf("create install_tokens table: %w", err)
	}

	return nil
}

// Close closes the database connection.
func (s *SQLiteStore) Close() error {
	return s.db.Close()
}

// GetAllClients returns every row of the clients table. The returned slice
// is nil when the table is empty. Rules that cannot be decoded from their
// stored JSON are replaced by an empty rule list instead of failing the call.
func (s *SQLiteStore) GetAllClients() ([]Client, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	rows, err := s.db.Query(`SELECT id, nickname, rules FROM clients`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var clients []Client
	for rows.Next() {
		var c Client
		var rulesJSON string
		if err := rows.Scan(&c.ID, &c.Nickname, &rulesJSON); err != nil {
			return nil, err
		}
		if err := json.Unmarshal([]byte(rulesJSON), &c.Rules); err != nil {
			c.Rules = []protocol.ProxyRule{}
		}
		clients = append(clients, c)
	}
	// rows.Next also returns false when iteration aborts on an error; report
	// that instead of handing back a silently truncated list.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return clients, nil
}

// GetClient returns the client with the given id. The error from the row
// scan is returned unchanged, which is sql.ErrNoRows when no such client
// exists. Rules that cannot be decoded are replaced by an empty rule list.
func (s *SQLiteStore) GetClient(id string) (*Client, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	var c Client
	var rulesJSON string
	err := s.db.QueryRow(`SELECT id, nickname, rules FROM clients WHERE id = ?`, id).
		Scan(&c.ID, &c.Nickname, &rulesJSON)
	if err != nil {
		return nil, err
	}
	if err := json.Unmarshal([]byte(rulesJSON), &c.Rules); err != nil {
		c.Rules = []protocol.ProxyRule{}
	}
	return &c, nil
}

// CreateClient inserts c into the clients table, storing its rules as JSON.
func (s *SQLiteStore) CreateClient(c *Client) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	rulesJSON, err := json.Marshal(c.Rules)
	if err != nil {
		return err
	}
	_, err = s.db.Exec(`INSERT INTO clients (id, nickname, rules) VALUES (?, ?, ?)`,
		c.ID, c.Nickname, string(rulesJSON))
	return err
}

// UpdateClient overwrites the nickname and rules of the client whose id is
// c.ID. No error is reported when no row matches.
func (s *SQLiteStore) UpdateClient(c *Client) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	rulesJSON, err := json.Marshal(c.Rules)
	if err != nil {
		return err
	}
	_, err = s.db.Exec(`UPDATE clients SET nickname = ?, rules = ? WHERE id = ?`,
		c.Nickname, string(rulesJSON), c.ID)
	return err
}

// DeleteClient removes the client with the given id. No error is reported
// when no row matches.
func (s *SQLiteStore) DeleteClient(id string) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	_, err := s.db.Exec(`DELETE FROM clients WHERE id = ?`, id)
	return err
}

// ClientExists reports whether a client with the given id is stored.
func (s *SQLiteStore) ClientExists(id string) (bool, error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	var count int
	if err := s.db.QueryRow(`SELECT COUNT(*) FROM clients WHERE id = ?`, id).Scan(&count); err != nil {
		return false, err
	}
	return count > 0, nil
}

// GetClientRules returns the proxy rules of the client with the given id.
// It fails with the same errors as GetClient.
func (s *SQLiteStore) GetClientRules(id string) ([]protocol.ProxyRule, error) {
	c, err := s.GetClient(id)
	if err != nil {
		return nil, err
	}
	return c.Rules, nil
}

// ========== Traffic statistics methods ==========

// getHourTimestamp returns the Unix timestamp of the start of the current
// hour. The truncation is done in the local time zone (now.Location()).
func getHourTimestamp() int64 {
	now := time.Now()
	return time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location()).Unix()
}

// AddTraffic adds inbound and outbound to the bucket of the current hour and
// to the running total. Both updates run in one transaction so that a
// failure cannot leave the hourly buckets and the total out of step.
func (s *SQLiteStore) AddTraffic(inbound, outbound int64) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	tx, err := s.db.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit this is a no-op that returns sql.ErrTxDone;
	// on every early return it undoes the partial work.
	defer func() { _ = tx.Rollback() }()

	// Update the hourly statistics.
	if _, err := tx.Exec(`
		INSERT INTO traffic_stats (hour_ts, inbound, outbound) VALUES (?, ?, ?)
		ON CONFLICT(hour_ts) DO UPDATE SET inbound = inbound + ?, outbound = outbound + ?
	`, getHourTimestamp(), inbound, outbound, inbound, outbound); err != nil {
		return err
	}

	// Update the total traffic.
	if _, err := tx.Exec(`
		UPDATE traffic_total SET inbound = inbound + ?, outbound = outbound + ?
		WHERE id = 1
	`, inbound, outbound); err != nil {
		return err
	}

	return tx.Commit()
}

// GetTotalTraffic returns the accumulated inbound and outbound totals read
// from the single traffic_total row.
func (s *SQLiteStore) GetTotalTraffic() (inbound, outbound int64, err error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	err = s.db.QueryRow(`SELECT inbound, outbound FROM traffic_total WHERE id = 1`).
		Scan(&inbound, &outbound)
	return inbound, outbound, err
}

// Get24HourTraffic returns the inbound and outbound sums over all hourly
// buckets whose start timestamp is no older than 24 hours before now. Both
// sums are zero when no bucket qualifies.
func (s *SQLiteStore) Get24HourTraffic() (inbound, outbound int64, err error) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	cutoff := time.Now().Add(-24 * time.Hour).Unix()
	err = s.db.QueryRow(`
		SELECT COALESCE(SUM(inbound), 0), COALESCE(SUM(outbound), 0)
		FROM traffic_stats WHERE hour_ts >= ?
	`, cutoff).Scan(&inbound, &outbound)
	return inbound, outbound, err
}

// GetHourlyTraffic returns one record per hour for the most recent hours
// hours, oldest first and ending with the current hour. The result always
// has exactly hours entries: hours without stored data are filled with zero
// traffic. A non-positive hours yields an empty, non-nil slice.
func (s *SQLiteStore) GetHourlyTraffic(hours int) ([]TrafficRecord, error) {
	// make would panic on a negative length, and there is nothing to report
	// for a zero-length window.
	if hours <= 0 {
		return []TrafficRecord{}, nil
	}

	s.mu.RLock()
	defer s.mu.RUnlock()

	// Start of the current hour, using the same truncation as the
	// timestamps AddTraffic stores.
	currentHour := time.Unix(getHourTimestamp(), 0)

	// Fetch the stored records inside the window.
	cutoff := currentHour.Add(-time.Duration(hours-1) * time.Hour).Unix()
	rows, err := s.db.Query(`
		SELECT hour_ts, inbound, outbound FROM traffic_stats
		WHERE hour_ts >= ? ORDER BY hour_ts ASC
	`, cutoff)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// Index the stored records by timestamp for direct lookup.
	dbRecords := make(map[int64]TrafficRecord, hours)
	for rows.Next() {
		var r TrafficRecord
		if err := rows.Scan(&r.Timestamp, &r.Inbound, &r.Outbound); err != nil {
			return nil, err
		}
		dbRecords[r.Timestamp] = r
	}
	// An aborted iteration must not be mistaken for hours without traffic.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	// Build the complete series, zero-filling the gaps.
	records := make([]TrafficRecord, hours)
	for i := range records {
		ts := currentHour.Add(-time.Duration(hours-1-i) * time.Hour).Unix()
		if r, ok := dbRecords[ts]; ok {
			records[i] = r
			continue
		}
		records[i] = TrafficRecord{Timestamp: ts, Inbound: 0, Outbound: 0}
	}
	return records, nil
}