diff --git a/internal/client/tunnel/client.go b/internal/client/tunnel/client.go index c63c070..f3437a8 100644 --- a/internal/client/tunnel/client.go +++ b/internal/client/tunnel/client.go @@ -61,7 +61,7 @@ func NewClient(serverAddr, token, id string) *Client { // 确保数据目录存在 if err := os.MkdirAll(dataDir, 0755); err != nil { - log.Printf("[Client] Failed to create data dir: %v", err) + log.Printf("Failed to create data dir: %v", err) } if id == "" { @@ -71,7 +71,7 @@ func NewClient(serverAddr, token, id string) *Client { // 初始化日志收集器 logger, err := NewLogger(dataDir) if err != nil { - log.Printf("[Client] Failed to initialize logger: %v", err) + log.Printf("Failed to initialize logger: %v", err) } return &Client{ @@ -111,7 +111,7 @@ func loadClientID(dataDir string) string { // saveClientID 保存客户端 ID 到本地文件 func saveClientID(dataDir, id string) { if err := os.WriteFile(getIDFilePath(dataDir), []byte(id), 0600); err != nil { - log.Printf("[Client] Failed to save client ID: %v", err) + log.Printf("Failed to save client ID: %v", err) } } @@ -151,14 +151,14 @@ func (c *Client) logWarnf(format string, args ...interface{}) { func (c *Client) Run() error { for { if err := c.connect(); err != nil { - c.logErrorf("[Client] Connect error: %v", err) - c.logf("[Client] Reconnecting in %v...", reconnectDelay) + c.logErrorf("Connect error: %v", err) + c.logf("Reconnecting in %v...", reconnectDelay) time.Sleep(reconnectDelay) continue } c.handleSession() - c.logWarnf("[Client] Disconnected, reconnecting...") + c.logWarnf("Disconnected, reconnecting...") time.Sleep(disconnectDelay) } } @@ -211,10 +211,10 @@ func (c *Client) connect() error { if authResp.ClientID != "" && authResp.ClientID != c.ID { c.ID = authResp.ClientID saveClientID(c.DataDir, c.ID) - c.logf("[Client] New ID assigned and saved: %s", c.ID) + c.logf("New ID assigned and saved: %s", c.ID) } - c.logf("[Client] Authenticated as %s", c.ID) + c.logf("Authenticated as %s", c.ID) session, err := yamux.Client(conn, nil) if err != nil { @@ -252,8 +252,7 @@ func (c *Client) handleStream(stream net.Conn) { switch msg.Type { case protocol.MsgTypeProxyConfig: - defer stream.Close() - c.handleProxyConfig(msg) + c.handleProxyConfig(stream, msg) case protocol.MsgTypeNewProxy: defer stream.Close() c.handleNewProxy(stream, msg) @@ -295,10 +294,12 @@ func (c *Client) handleStream(stream net.Conn) { } // handleProxyConfig 处理代理配置 -func (c *Client) handleProxyConfig(msg *protocol.Message) { +func (c *Client) handleProxyConfig(stream net.Conn, msg *protocol.Message) { + defer stream.Close() + var cfg protocol.ProxyConfig if err := msg.ParsePayload(&cfg); err != nil { - c.logErrorf("[Client] Parse proxy config error: %v", err) + c.logErrorf("Parse proxy config error: %v", err) return } @@ -306,17 +307,21 @@ func (c *Client) handleProxyConfig(msg *protocol.Message) { c.rules = cfg.Rules c.mu.Unlock() - c.logf("[Client] Received %d proxy rules", len(cfg.Rules)) + c.logf("Received %d proxy rules", len(cfg.Rules)) for _, r := range cfg.Rules { - c.logf("[Client] %s: %s:%d", r.Name, r.LocalIP, r.LocalPort) + c.logf(" %s: %s:%d", r.Name, r.LocalIP, r.LocalPort) } + + // 发送配置确认 + ack := &protocol.Message{Type: protocol.MsgTypeProxyReady} + protocol.WriteMessage(stream, ack) } // handleNewProxy 处理新代理请求 func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) { var req protocol.NewProxyRequest if err := msg.ParsePayload(&req); err != nil { - c.logErrorf("[Client] Parse new proxy request error: %v", err) + c.logErrorf("Parse new proxy request error: %v", err) return } @@ -331,14 +336,14 @@ func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) { c.mu.RUnlock() if rule == nil { - c.logWarnf("[Client] Unknown port %d", req.RemotePort) + c.logWarnf("Unknown port %d", req.RemotePort) return } localAddr := fmt.Sprintf("%s:%d", rule.LocalIP, rule.LocalPort) localConn, err := net.DialTimeout("tcp", localAddr, localDialTimeout) if err != nil { - c.logErrorf("[Client] Connect %s error: %v", localAddr, err) + c.logErrorf("Connect %s error: %v", localAddr, err) return } @@ -448,24 +453,24 @@ func (c *Client) findRuleByPort(port int) *protocol.ProxyRule { func (c *Client) handlePluginConfig(msg *protocol.Message) { var cfg protocol.PluginConfigSync if err := msg.ParsePayload(&cfg); err != nil { - c.logErrorf("[Client] Parse plugin config error: %v", err) + c.logErrorf("Parse plugin config error: %v", err) return } - c.logf("[Client] Received config for plugin: %s", cfg.PluginName) + c.logf("Received config for plugin: %s", cfg.PluginName) // 应用配置到插件 if c.pluginRegistry != nil { handler, err := c.pluginRegistry.GetClient(cfg.PluginName) if err != nil { - c.logWarnf("[Client] Plugin %s not found: %v", cfg.PluginName, err) + c.logWarnf("Plugin %s not found: %v", cfg.PluginName, err) return } if err := handler.Init(cfg.Config); err != nil { - c.logErrorf("[Client] Plugin %s init error: %v", cfg.PluginName, err) + c.logErrorf("Plugin %s init error: %v", cfg.PluginName, err) return } - c.logf("[Client] Plugin %s config applied", cfg.PluginName) + c.logf("Plugin %s config applied", cfg.PluginName) } } @@ -479,7 +484,7 @@ func (c *Client) handleClientPluginStart(stream net.Conn, msg *protocol.Message) return } - c.logf("[Client] Starting plugin %s for rule %s", req.PluginName, req.RuleName) + c.logf("Starting plugin %s for rule %s", req.PluginName, req.RuleName) // 获取插件 if c.pluginRegistry == nil { @@ -511,7 +516,7 @@ func (c *Client) handleClientPluginStart(stream net.Conn, msg *protocol.Message) c.runningPlugins[key] = handler c.pluginMu.Unlock() - c.logf("[Client] Plugin %s started at %s", req.PluginName, localAddr) + c.logf("Plugin %s started at %s", req.PluginName, localAddr) c.sendPluginStatus(stream, req.PluginName, req.RuleName, true, localAddr, "") } @@ -553,7 +558,7 @@ func (c *Client) handleClientPluginConn(stream net.Conn, msg *protocol.Message) c.pluginMu.RUnlock() if !ok { - c.logWarnf("[Client] Plugin %s (ID: %s) not running", req.PluginName, req.PluginID) + c.logWarnf("Plugin %s (ID: %s) not running", req.PluginName, req.PluginID) stream.Close() return } @@ -572,7 +577,7 @@ func (c *Client) handleJSPluginInstall(stream net.Conn, msg *protocol.Message) { return } - c.logf("[Client] Installing JS plugin: %s (ID: %s)", req.PluginName, req.PluginID) + c.logf("Installing JS plugin: %s (ID: %s)", req.PluginName, req.PluginID) // 使用 PluginID 作为 key(如果有),否则回退到 pluginName:ruleName key := req.PluginID @@ -583,9 +588,9 @@ func (c *Client) handleJSPluginInstall(stream net.Conn, msg *protocol.Message) { // 如果插件已经在运行,先停止它 c.pluginMu.Lock() if existingHandler, ok := c.runningPlugins[key]; ok { - c.logf("[Client] Stopping existing plugin %s before reinstall", key) + c.logf("Stopping existing plugin %s before reinstall", key) if err := existingHandler.Stop(); err != nil { - c.logErrorf("[Client] Stop existing plugin error: %v", err) + c.logErrorf("Stop existing plugin error: %v", err) } delete(c.runningPlugins, key) } @@ -593,11 +598,11 @@ func (c *Client) handleJSPluginInstall(stream net.Conn, msg *protocol.Message) { // 验证官方签名 if err := c.verifyJSPluginSignature(req.PluginName, req.Source, req.Signature); err != nil { - c.logErrorf("[Client] JS plugin %s signature verification failed: %v", req.PluginName, err) + c.logErrorf("JS plugin %s signature verification failed: %v", req.PluginName, err) c.sendJSPluginResult(stream, req.PluginName, false, "signature verification failed: "+err.Error()) return } - c.logf("[Client] JS plugin %s signature verified", req.PluginName) + c.logf("JS plugin %s signature verified", req.PluginName) // 创建 JS 插件 jsPlugin, err := script.NewJSPlugin(req.PluginName, req.Source) @@ -611,7 +616,7 @@ func (c *Client) handleJSPluginInstall(stream net.Conn, msg *protocol.Message) { c.pluginRegistry.RegisterClient(jsPlugin) } - c.logf("[Client] JS plugin %s installed", req.PluginName) + c.logf("JS plugin %s installed", req.PluginName) // 保存版本信息(防止降级攻击) if c.versionStore != nil { @@ -644,13 +649,13 @@ func (c *Client) sendJSPluginResult(stream net.Conn, name string, success bool, // startJSPlugin 启动 JS 插件 func (c *Client) startJSPlugin(handler plugin.ClientPlugin, req protocol.JSPluginInstallRequest) { if err := handler.Init(req.Config); err != nil { - c.logErrorf("[Client] JS plugin %s init error: %v", req.PluginName, err) + c.logErrorf("JS plugin %s init error: %v", req.PluginName, err) return } localAddr, err := handler.Start() if err != nil { - c.logErrorf("[Client] JS plugin %s start error: %v", req.PluginName, err) + c.logErrorf("JS plugin %s start error: %v", req.PluginName, err) return } @@ -663,7 +668,7 @@ func (c *Client) startJSPlugin(handler plugin.ClientPlugin, req protocol.JSPlugi c.runningPlugins[key] = handler c.pluginMu.Unlock() - c.logf("[Client] JS plugin %s (ID: %s) started at %s", req.PluginName, req.PluginID, localAddr) + c.logf("JS plugin %s (ID: %s) started at %s", req.PluginName, req.PluginID, localAddr) } // verifyJSPluginSignature 验证 JS 插件签名 @@ -741,13 +746,13 @@ func (c *Client) handleClientPluginStop(stream net.Conn, msg *protocol.Message) if ok { if err := handler.Stop(); err != nil { - c.logErrorf("[Client] Plugin %s stop error: %v", key, err) + c.logErrorf("Plugin %s stop error: %v", key, err) } delete(c.runningPlugins, key) } c.pluginMu.Unlock() - c.logf("[Client] Plugin %s stopped", key) + c.logf("Plugin %s stopped", key) c.sendPluginStatus(stream, req.PluginName, req.RuleName, false, "", "") } @@ -758,7 +763,7 @@ func (c *Client) handleClientRestart(stream net.Conn, msg *protocol.Message) { var req protocol.ClientRestartRequest msg.ParsePayload(&req) - c.logf("[Client] Restart requested: %s", req.Reason) + c.logf("Restart requested: %s", req.Reason) // 发送响应 resp := protocol.ClientRestartResponse{ @@ -771,7 +776,7 @@ func (c *Client) handleClientRestart(stream net.Conn, msg *protocol.Message) { // 停止所有运行中的插件 c.pluginMu.Lock() for key, handler := range c.runningPlugins { - c.logf("[Client] Stopping plugin %s for restart", key) + c.logf("Stopping plugin %s for restart", key) handler.Stop() } c.runningPlugins = make(map[string]plugin.ClientPlugin) @@ -813,7 +818,7 @@ func (c *Client) handlePluginConfigUpdate(stream net.Conn, msg *protocol.Message } c.pluginMu.RUnlock() - c.logf("[Client] Config update for plugin %s", key) + c.logf("Config update for plugin %s", key) if !ok { c.sendPluginConfigUpdateResult(stream, req.PluginName, req.RuleName, false, "plugin not running") @@ -824,7 +829,7 @@ func (c *Client) handlePluginConfigUpdate(stream net.Conn, msg *protocol.Message // 停止并重启插件 c.pluginMu.Lock() if err := handler.Stop(); err != nil { - c.logErrorf("[Client] Plugin %s stop error: %v", key, err) + c.logErrorf("Plugin %s stop error: %v", key, err) } delete(c.runningPlugins, key) c.pluginMu.Unlock() @@ -845,7 +850,7 @@ func (c *Client) handlePluginConfigUpdate(stream net.Conn, msg *protocol.Message c.runningPlugins[key] = handler c.pluginMu.Unlock() - c.logf("[Client] Plugin %s restarted at %s with new config", key, localAddr) + c.logf("Plugin %s restarted at %s with new config", key, localAddr) } c.sendPluginConfigUpdateResult(stream, req.PluginName, req.RuleName, true, "") @@ -869,17 +874,17 @@ func (c *Client) handleUpdateDownload(stream net.Conn, msg *protocol.Message) { var req protocol.UpdateDownloadRequest if err := msg.ParsePayload(&req); err != nil { - c.logErrorf("[Client] Parse update request error: %v", err) + c.logErrorf("Parse update request error: %v", err) c.sendUpdateResult(stream, false, "invalid request") return } - c.logf("[Client] Update download requested: %s", req.DownloadURL) + c.logf("Update download requested: %s", req.DownloadURL) // 异步执行更新 go func() { if err := c.performSelfUpdate(req.DownloadURL); err != nil { - c.logErrorf("[Client] Update failed: %v", err) + c.logErrorf("Update failed: %v", err) } }() @@ -898,7 +903,7 @@ func (c *Client) sendUpdateResult(stream net.Conn, success bool, message string) // performSelfUpdate 执行自更新 func (c *Client) performSelfUpdate(downloadURL string) error { - c.logf("[Client] Starting self-update from: %s", downloadURL) + c.logf("Starting self-update from: %s", downloadURL) // 使用共享的下载和解压逻辑 binaryPath, cleanup, err := update.DownloadAndExtract(downloadURL, "client") @@ -945,7 +950,7 @@ func (c *Client) performSelfUpdate(downloadURL string) error { // 删除备份 os.Remove(backupPath) - c.logf("[Client] Update completed, restarting...") + c.logf("Update completed, restarting...") // 重启进程 restartClientProcess(currentPath, c.ServerAddr, c.Token, c.ID) @@ -956,7 +961,7 @@ func (c *Client) performSelfUpdate(downloadURL string) error { func (c *Client) stopAllPlugins() { c.pluginMu.Lock() for key, handler := range c.runningPlugins { - c.logf("[Client] Stopping plugin %s for update", key) + c.logf("Stopping plugin %s for update", key) handler.Stop() } c.runningPlugins = make(map[string]plugin.ClientPlugin) @@ -1122,7 +1127,7 @@ func (c *Client) handlePluginAPIRequest(stream net.Conn, msg *protocol.Message) return } - c.logf("[Client] Plugin API request: %s %s for plugin %s (ID: %s)", req.Method, req.Path, req.PluginName, req.PluginID) + c.logf("Plugin API request: %s %s for plugin %s (ID: %s)", req.Method, req.Path, req.PluginName, req.PluginID) // 查找运行中的插件 c.pluginMu.RLock() @@ -1185,7 +1190,7 @@ func (c *Client) handleSystemStatsRequest(stream net.Conn, msg *protocol.Message stats, err := utils.GetSystemStats() if err != nil { - log.Printf("[Client] Failed to get system stats: %v", err) + log.Printf("Failed to get system stats: %v", err) return } diff --git a/internal/server/tunnel/server.go b/internal/server/tunnel/server.go index 52dbd3f..827c055 100644 --- a/internal/server/tunnel/server.go +++ b/internal/server/tunnel/server.go @@ -348,7 +348,7 @@ func (s *Server) sendAuthResponse(conn net.Conn, success bool, message, clientID return protocol.WriteMessage(conn, msg) } -// sendProxyConfig 发送代理配置 +// sendProxyConfig 发送代理配置并等待客户端确认 func (s *Server) sendProxyConfig(session *yamux.Session, rules []protocol.ProxyRule) error { stream, err := session.Open() if err != nil { @@ -361,7 +361,20 @@ func (s *Server) sendProxyConfig(session *yamux.Session, rules []protocol.ProxyR if err != nil { return err } - return protocol.WriteMessage(stream, msg) + if err := protocol.WriteMessage(stream, msg); err != nil { + return err + } + + // 等待客户端确认 + ack, err := protocol.ReadMessage(stream) + if err != nil { + return fmt.Errorf("wait config ack: %w", err) + } + if ack.Type != protocol.MsgTypeProxyReady { + return fmt.Errorf("unexpected ack type: %d", ack.Type) + } + + return nil } // registerClient 注册客户端