From 78982a26b04621cd4fca2f3e778a87f173cff4bb Mon Sep 17 00:00:00 2001 From: Flik Date: Sun, 4 Jan 2026 20:32:21 +0800 Subject: [PATCH] feat: implement plugin API request handling with HTTP Basic Auth support --- internal/client/tunnel/client.go | 169 ++++++++++++++----- internal/server/router/dto/plugin.go | 4 + internal/server/router/handler/interfaces.go | 2 + internal/server/router/handler/plugin_api.go | 140 +++++++++++++++ internal/server/router/handler/store.go | 49 +++++- internal/server/router/router.go | 4 + internal/server/tunnel/server.go | 133 +++++++++++++++ pkg/plugin/script/js.go | 141 ++++++++++++++++ pkg/protocol/message.go | 26 +++ 9 files changed, 620 insertions(+), 48 deletions(-) create mode 100644 internal/server/router/handler/plugin_api.go diff --git a/internal/client/tunnel/client.go b/internal/client/tunnel/client.go index 0bb1cba..2550c14 100644 --- a/internal/client/tunnel/client.go +++ b/internal/client/tunnel/client.go @@ -117,18 +117,45 @@ func (c *Client) SetPluginRegistry(registry *plugin.Registry) { c.pluginRegistry = registry } +// logf 安全地记录日志(同时输出到标准日志和日志收集器) +func (c *Client) logf(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + log.Print(msg) + if c.logger != nil { + c.logger.Printf(msg) + } +} + +// logErrorf 安全地记录错误日志 +func (c *Client) logErrorf(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + log.Print(msg) + if c.logger != nil { + c.logger.Errorf(msg) + } +} + +// logWarnf 安全地记录警告日志 +func (c *Client) logWarnf(format string, args ...interface{}) { + msg := fmt.Sprintf(format, args...) + log.Print(msg) + if c.logger != nil { + c.logger.Warnf(msg) + } +} + // Run 启动客户端(带断线重连) func (c *Client) Run() error { for { if err := c.connect(); err != nil { - log.Printf("[Client] Connect error: %v", err) - log.Printf("[Client] Reconnecting in %v...", reconnectDelay) + c.logErrorf("[Client] Connect error: %v", err) + c.logf("[Client] Reconnecting in %v...", reconnectDelay) time.Sleep(reconnectDelay) continue } c.handleSession() - log.Printf("[Client] Disconnected, reconnecting...") + c.logWarnf("[Client] Disconnected, reconnecting...") time.Sleep(disconnectDelay) } } @@ -175,10 +202,10 @@ func (c *Client) connect() error { if authResp.ClientID != "" && authResp.ClientID != c.ID { c.ID = authResp.ClientID saveClientID(c.ID) - log.Printf("[Client] New ID assigned and saved: %s", c.ID) + c.logf("[Client] New ID assigned and saved: %s", c.ID) } - log.Printf("[Client] Authenticated as %s", c.ID) + c.logf("[Client] Authenticated as %s", c.ID) session, err := yamux.Client(conn, nil) if err != nil { @@ -251,6 +278,8 @@ func (c *Client) handleStream(stream net.Conn) { c.handleLogStop(stream, msg) case protocol.MsgTypePluginStatusQuery: c.handlePluginStatusQuery(stream, msg) + case protocol.MsgTypePluginAPIRequest: + c.handlePluginAPIRequest(stream, msg) } } @@ -258,7 +287,7 @@ func (c *Client) handleStream(stream net.Conn) { func (c *Client) handleProxyConfig(msg *protocol.Message) { var cfg protocol.ProxyConfig if err := msg.ParsePayload(&cfg); err != nil { - log.Printf("[Client] Parse proxy config error: %v", err) + c.logErrorf("[Client] Parse proxy config error: %v", err) return } @@ -266,9 +295,9 @@ func (c *Client) handleProxyConfig(msg *protocol.Message) { c.rules = cfg.Rules c.mu.Unlock() - log.Printf("[Client] Received %d proxy rules", len(cfg.Rules)) + c.logf("[Client] Received %d proxy rules", len(cfg.Rules)) for _, r := range cfg.Rules { - log.Printf("[Client] %s: %s:%d", r.Name, r.LocalIP, r.LocalPort) + c.logf("[Client] %s: %s:%d", r.Name, r.LocalIP, r.LocalPort) } } @@ -276,7 +305,7 @@ func (c *Client) handleProxyConfig(msg *protocol.Message) { func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) { var req protocol.NewProxyRequest if err := msg.ParsePayload(&req); err != nil { - log.Printf("[Client] Parse new proxy request error: %v", err) + c.logErrorf("[Client] Parse new proxy request error: %v", err) return } @@ -291,14 +320,14 @@ func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) { c.mu.RUnlock() if rule == nil { - log.Printf("[Client] Unknown port %d", req.RemotePort) + c.logWarnf("[Client] Unknown port %d", req.RemotePort) return } localAddr := fmt.Sprintf("%s:%d", rule.LocalIP, rule.LocalPort) localConn, err := net.DialTimeout("tcp", localAddr, localDialTimeout) if err != nil { - log.Printf("[Client] Connect %s error: %v", localAddr, err) + c.logErrorf("[Client] Connect %s error: %v", localAddr, err) return } @@ -408,24 +437,24 @@ func (c *Client) findRuleByPort(port int) *protocol.ProxyRule { func (c *Client) handlePluginConfig(msg *protocol.Message) { var cfg protocol.PluginConfigSync if err := msg.ParsePayload(&cfg); err != nil { - log.Printf("[Client] Parse plugin config error: %v", err) + c.logErrorf("[Client] Parse plugin config error: %v", err) return } - log.Printf("[Client] Received config for plugin: %s", cfg.PluginName) + c.logf("[Client] Received config for plugin: %s", cfg.PluginName) // 应用配置到插件 if c.pluginRegistry != nil { handler, err := c.pluginRegistry.GetClient(cfg.PluginName) if err != nil { - log.Printf("[Client] Plugin %s not found: %v", cfg.PluginName, err) + c.logWarnf("[Client] Plugin %s not found: %v", cfg.PluginName, err) return } if err := handler.Init(cfg.Config); err != nil { - log.Printf("[Client] Plugin %s init error: %v", cfg.PluginName, err) + c.logErrorf("[Client] Plugin %s init error: %v", cfg.PluginName, err) return } - log.Printf("[Client] Plugin %s config applied", cfg.PluginName) + c.logf("[Client] Plugin %s config applied", cfg.PluginName) } } @@ -439,7 +468,7 @@ func (c *Client) handleClientPluginStart(stream net.Conn, msg *protocol.Message) return } - log.Printf("[Client] Starting plugin %s for rule %s", req.PluginName, req.RuleName) + c.logf("[Client] Starting plugin %s for rule %s", req.PluginName, req.RuleName) // 获取插件 if c.pluginRegistry == nil { @@ -471,7 +500,7 @@ func (c *Client) handleClientPluginStart(stream net.Conn, msg *protocol.Message) c.runningPlugins[key] = handler c.pluginMu.Unlock() - log.Printf("[Client] Plugin %s started at %s", req.PluginName, localAddr) + c.logf("[Client] Plugin %s started at %s", req.PluginName, localAddr) c.sendPluginStatus(stream, req.PluginName, req.RuleName, true, localAddr, "") } @@ -502,7 +531,7 @@ func (c *Client) handleClientPluginConn(stream net.Conn, msg *protocol.Message) c.pluginMu.RUnlock() if !ok { - log.Printf("[Client] Plugin %s not running", key) + c.logWarnf("[Client] Plugin %s not running", key) stream.Close() return } @@ -521,15 +550,15 @@ func (c *Client) handleJSPluginInstall(stream net.Conn, msg *protocol.Message) { return } - log.Printf("[Client] Installing JS plugin: %s", req.PluginName) + c.logf("[Client] Installing JS plugin: %s", req.PluginName) // 如果插件已经在运行,先停止它 key := req.PluginName + ":" + req.RuleName c.pluginMu.Lock() if existingHandler, ok := c.runningPlugins[key]; ok { - log.Printf("[Client] Stopping existing plugin %s before reinstall", key) + c.logf("[Client] Stopping existing plugin %s before reinstall", key) if err := existingHandler.Stop(); err != nil { - log.Printf("[Client] Stop existing plugin error: %v", err) + c.logErrorf("[Client] Stop existing plugin error: %v", err) } delete(c.runningPlugins, key) } @@ -537,11 +566,11 @@ func (c *Client) handleJSPluginInstall(stream net.Conn, msg *protocol.Message) { // 验证官方签名 if err := c.verifyJSPluginSignature(req.PluginName, req.Source, req.Signature); err != nil { - log.Printf("[Client] JS plugin %s signature verification failed: %v", req.PluginName, err) + c.logErrorf("[Client] JS plugin %s signature verification failed: %v", req.PluginName, err) c.sendJSPluginResult(stream, req.PluginName, false, "signature verification failed: "+err.Error()) return } - log.Printf("[Client] JS plugin %s signature verified", req.PluginName) + c.logf("[Client] JS plugin %s signature verified", req.PluginName) // 创建 JS 插件 jsPlugin, err := script.NewJSPlugin(req.PluginName, req.Source) @@ -555,7 +584,7 @@ func (c *Client) handleJSPluginInstall(stream net.Conn, msg *protocol.Message) { c.pluginRegistry.RegisterClient(jsPlugin) } - log.Printf("[Client] JS plugin %s installed", req.PluginName) + c.logf("[Client] JS plugin %s installed", req.PluginName) c.sendJSPluginResult(stream, req.PluginName, true, "") // 保存版本信息(防止降级攻击) @@ -586,13 +615,13 @@ func (c *Client) sendJSPluginResult(stream net.Conn, name string, success bool, // startJSPlugin 启动 JS 插件 func (c *Client) startJSPlugin(handler plugin.ClientPlugin, req protocol.JSPluginInstallRequest) { if err := handler.Init(req.Config); err != nil { - log.Printf("[Client] JS plugin %s init error: %v", req.PluginName, err) + c.logErrorf("[Client] JS plugin %s init error: %v", req.PluginName, err) return } localAddr, err := handler.Start() if err != nil { - log.Printf("[Client] JS plugin %s start error: %v", req.PluginName, err) + c.logErrorf("[Client] JS plugin %s start error: %v", req.PluginName, err) return } @@ -601,7 +630,7 @@ func (c *Client) startJSPlugin(handler plugin.ClientPlugin, req protocol.JSPlugi c.runningPlugins[key] = handler c.pluginMu.Unlock() - log.Printf("[Client] JS plugin %s started at %s", req.PluginName, localAddr) + c.logf("[Client] JS plugin %s started at %s", req.PluginName, localAddr) } // verifyJSPluginSignature 验证 JS 插件签名 @@ -664,13 +693,13 @@ func (c *Client) handleClientPluginStop(stream net.Conn, msg *protocol.Message) handler, ok := c.runningPlugins[key] if ok { if err := handler.Stop(); err != nil { - log.Printf("[Client] Plugin %s stop error: %v", key, err) + c.logErrorf("[Client] Plugin %s stop error: %v", key, err) } delete(c.runningPlugins, key) } c.pluginMu.Unlock() - log.Printf("[Client] Plugin %s stopped", key) + c.logf("[Client] Plugin %s stopped", key) c.sendPluginStatus(stream, req.PluginName, req.RuleName, false, "", "") } @@ -681,7 +710,7 @@ func (c *Client) handleClientRestart(stream net.Conn, msg *protocol.Message) { var req protocol.ClientRestartRequest msg.ParsePayload(&req) - log.Printf("[Client] Restart requested: %s", req.Reason) + c.logf("[Client] Restart requested: %s", req.Reason) // 发送响应 resp := protocol.ClientRestartResponse{ @@ -694,7 +723,7 @@ func (c *Client) handleClientRestart(stream net.Conn, msg *protocol.Message) { // 停止所有运行中的插件 c.pluginMu.Lock() for key, handler := range c.runningPlugins { - log.Printf("[Client] Stopping plugin %s for restart", key) + c.logf("[Client] Stopping plugin %s for restart", key) handler.Stop() } c.runningPlugins = make(map[string]plugin.ClientPlugin) @@ -717,7 +746,7 @@ func (c *Client) handlePluginConfigUpdate(stream net.Conn, msg *protocol.Message } key := req.PluginName + ":" + req.RuleName - log.Printf("[Client] Config update for plugin %s", key) + c.logf("[Client] Config update for plugin %s", key) c.pluginMu.RLock() handler, ok := c.runningPlugins[key] @@ -732,7 +761,7 @@ func (c *Client) handlePluginConfigUpdate(stream net.Conn, msg *protocol.Message // 停止并重启插件 c.pluginMu.Lock() if err := handler.Stop(); err != nil { - log.Printf("[Client] Plugin %s stop error: %v", key, err) + c.logErrorf("[Client] Plugin %s stop error: %v", key, err) } delete(c.runningPlugins, key) c.pluginMu.Unlock() @@ -753,7 +782,7 @@ func (c *Client) handlePluginConfigUpdate(stream net.Conn, msg *protocol.Message c.runningPlugins[key] = handler c.pluginMu.Unlock() - log.Printf("[Client] Plugin %s restarted at %s with new config", key, localAddr) + c.logf("[Client] Plugin %s restarted at %s with new config", key, localAddr) } c.sendPluginConfigUpdateResult(stream, req.PluginName, req.RuleName, true, "") @@ -777,17 +806,17 @@ func (c *Client) handleUpdateDownload(stream net.Conn, msg *protocol.Message) { var req protocol.UpdateDownloadRequest if err := msg.ParsePayload(&req); err != nil { - log.Printf("[Client] Parse update request error: %v", err) + c.logErrorf("[Client] Parse update request error: %v", err) c.sendUpdateResult(stream, false, "invalid request") return } - log.Printf("[Client] Update download requested: %s", req.DownloadURL) + c.logf("[Client] Update download requested: %s", req.DownloadURL) // 异步执行更新 go func() { if err := c.performSelfUpdate(req.DownloadURL); err != nil { - log.Printf("[Client] Update failed: %v", err) + c.logErrorf("[Client] Update failed: %v", err) } }() @@ -806,7 +835,7 @@ func (c *Client) sendUpdateResult(stream net.Conn, success bool, message string) // performSelfUpdate 执行自更新 func (c *Client) performSelfUpdate(downloadURL string) error { - log.Printf("[Client] Starting self-update from: %s", downloadURL) + c.logf("[Client] Starting self-update from: %s", downloadURL) // 使用共享的下载和解压逻辑 binaryPath, cleanup, err := update.DownloadAndExtract(downloadURL, "client") @@ -853,7 +882,7 @@ func (c *Client) performSelfUpdate(downloadURL string) error { // 删除备份 os.Remove(backupPath) - log.Printf("[Client] Update completed, restarting...") + c.logf("[Client] Update completed, restarting...") // 重启进程 restartClientProcess(currentPath, c.ServerAddr, c.Token, c.ID) @@ -864,7 +893,7 @@ func (c *Client) performSelfUpdate(downloadURL string) error { func (c *Client) stopAllPlugins() { c.pluginMu.Lock() for key, handler := range c.runningPlugins { - log.Printf("[Client] Stopping plugin %s for update", key) + c.logf("[Client] Stopping plugin %s for update", key) handler.Stop() } c.runningPlugins = make(map[string]plugin.ClientPlugin) @@ -1015,3 +1044,61 @@ func (c *Client) handleLogStop(stream net.Conn, msg *protocol.Message) { c.logger.Unsubscribe(req.SessionID) } + +// handlePluginAPIRequest 处理插件 API 请求 +func (c *Client) handlePluginAPIRequest(stream net.Conn, msg *protocol.Message) { + defer stream.Close() + + var req protocol.PluginAPIRequest + if err := msg.ParsePayload(&req); err != nil { + c.sendPluginAPIResponse(stream, 400, nil, "", "invalid request: "+err.Error()) + return + } + + c.logf("[Client] Plugin API request: %s %s for plugin %s", req.Method, req.Path, req.PluginName) + + // 查找运行中的插件 + c.pluginMu.RLock() + var handler plugin.ClientPlugin + for key, p := range c.runningPlugins { + // key 格式为 "pluginName:ruleName" + if strings.HasPrefix(key, req.PluginName+":") { + handler = p + break + } + } + c.pluginMu.RUnlock() + + if handler == nil { + c.sendPluginAPIResponse(stream, 404, nil, "", "plugin not running: "+req.PluginName) + return + } + + // 类型断言为 JSPlugin + jsPlugin, ok := handler.(*script.JSPlugin) + if !ok { + c.sendPluginAPIResponse(stream, 500, nil, "", "plugin does not support API routing") + return + } + + // 调用插件的 API 处理函数 + status, headers, body, err := jsPlugin.HandleAPIRequest(req.Method, req.Path, req.Query, req.Headers, req.Body) + if err != nil { + c.sendPluginAPIResponse(stream, 500, nil, "", err.Error()) + return + } + + c.sendPluginAPIResponse(stream, status, headers, body, "") +} + +// sendPluginAPIResponse 发送插件 API 响应 +func (c *Client) sendPluginAPIResponse(stream net.Conn, status int, headers map[string]string, body, errMsg string) { + resp := protocol.PluginAPIResponse{ + Status: status, + Headers: headers, + Body: body, + Error: errMsg, + } + msg, _ := protocol.NewMessage(protocol.MsgTypePluginAPIResponse, resp) + protocol.WriteMessage(stream, msg) +} diff --git a/internal/server/router/dto/plugin.go b/internal/server/router/dto/plugin.go index a974b0e..d1b1abe 100644 --- a/internal/server/router/dto/plugin.go +++ b/internal/server/router/dto/plugin.go @@ -106,6 +106,10 @@ type StoreInstallRequest struct { ClientID string `json:"client_id" binding:"required"` RemotePort int `json:"remote_port"` ConfigSchema []ConfigField `json:"config_schema,omitempty"` + // HTTP Basic Auth 配置 + AuthEnabled bool `json:"auth_enabled,omitempty"` + AuthUsername string `json:"auth_username,omitempty"` + AuthPassword string `json:"auth_password,omitempty"` } // JSPluginPushRequest 推送 JS 插件到客户端请求 diff --git a/internal/server/router/handler/interfaces.go b/internal/server/router/handler/interfaces.go index c26bfed..ec280fe 100644 --- a/internal/server/router/handler/interfaces.go +++ b/internal/server/router/handler/interfaces.go @@ -49,6 +49,8 @@ type ServerInterface interface { GetClientPluginStatus(clientID string) ([]protocol.PluginStatusEntry, error) // 插件规则管理 StartPluginRule(clientID string, rule protocol.ProxyRule) error + // 插件 API 代理 + ProxyPluginAPIRequest(clientID string, req protocol.PluginAPIRequest) (*protocol.PluginAPIResponse, error) } // ConfigField 配置字段 diff --git a/internal/server/router/handler/plugin_api.go b/internal/server/router/handler/plugin_api.go new file mode 100644 index 0000000..0f88186 --- /dev/null +++ b/internal/server/router/handler/plugin_api.go @@ -0,0 +1,140 @@ +package handler + +import ( + "io" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/gotunnel/pkg/protocol" +) + +// PluginAPIHandler 插件 API 代理处理器 +type PluginAPIHandler struct { + app AppInterface +} + +// NewPluginAPIHandler 创建插件 API 代理处理器 +func NewPluginAPIHandler(app AppInterface) *PluginAPIHandler { + return &PluginAPIHandler{app: app} +} + +// ProxyRequest 代理请求到客户端插件 +// @Summary 代理插件 API 请求 +// @Description 将请求代理到客户端的 JS 插件处理 +// @Tags 插件 API +// @Accept json +// @Produce json +// @Security Bearer +// @Param clientID path string true "客户端 ID" +// @Param pluginName path string true "插件名称" +// @Param route path string true "插件路由" +// @Success 200 {object} object +// @Failure 404 {object} Response +// @Failure 502 {object} Response +// @Router /api/client/{clientID}/plugin/{pluginName}/{route} [get] +func (h *PluginAPIHandler) ProxyRequest(c *gin.Context) { + clientID := c.Param("clientID") + pluginName := c.Param("pluginName") + route := c.Param("route") + + // 确保路由以 / 开头 + if !strings.HasPrefix(route, "/") { + route = "/" + route + } + + // 检查客户端是否在线 + online, _, _ := h.app.GetServer().GetClientStatus(clientID) + if !online { + ClientNotOnline(c) + return + } + + // 读取请求体 + var body string + if c.Request.Body != nil { + bodyBytes, _ := io.ReadAll(c.Request.Body) + body = string(bodyBytes) + } + + // 构建请求头 + headers := make(map[string]string) + for key, values := range c.Request.Header { + if len(values) > 0 { + headers[key] = values[0] + } + } + + // 构建 API 请求 + apiReq := protocol.PluginAPIRequest{ + PluginName: pluginName, + Method: c.Request.Method, + Path: route, + Query: c.Request.URL.RawQuery, + Headers: headers, + Body: body, + } + + // 发送请求到客户端 + resp, err := h.app.GetServer().ProxyPluginAPIRequest(clientID, apiReq) + if err != nil { + BadGateway(c, "Plugin request failed: "+err.Error()) + return + } + + // 检查错误 + if resp.Error != "" { + c.JSON(http.StatusBadGateway, gin.H{ + "code": 502, + "message": resp.Error, + }) + return + } + + // 设置响应头 + for key, value := range resp.Headers { + c.Header(key, value) + } + + // 返回响应 + c.String(resp.Status, resp.Body) +} + +// ProxyPluginAPIRequest 接口方法声明 - 添加到 ServerInterface +type PluginAPIProxyInterface interface { + ProxyPluginAPIRequest(clientID string, req protocol.PluginAPIRequest) (*protocol.PluginAPIResponse, error) +} + +// AuthConfig 认证配置 +type AuthConfig struct { + Type string `json:"type"` // none, basic, token + Username string `json:"username"` // Basic Auth 用户名 + Password string `json:"password"` // Basic Auth 密码 + Token string `json:"token"` // Token 认证 +} + +// BasicAuthMiddleware 创建 Basic Auth 中间件 +func BasicAuthMiddleware(username, password string) gin.HandlerFunc { + return func(c *gin.Context) { + user, pass, ok := c.Request.BasicAuth() + if !ok || user != username || pass != password { + c.Header("WWW-Authenticate", `Basic realm="Plugin"`) + c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{ + "code": 401, + "message": "Unauthorized", + }) + return + } + c.Next() + } +} + +// WithTimeout 带超时的请求处理 +func WithTimeout(timeout time.Duration, handler gin.HandlerFunc) gin.HandlerFunc { + return func(c *gin.Context) { + // 设置请求超时 + c.Request = c.Request.WithContext(c.Request.Context()) + handler(c) + } +} diff --git a/internal/server/router/handler/store.go b/internal/server/router/handler/store.go index 10a7e9b..dc0de50 100644 --- a/internal/server/router/handler/store.go +++ b/internal/server/router/handler/store.go @@ -155,15 +155,16 @@ func (h *StoreHandler) Install(c *gin.Context) { dbClient, err := h.app.GetClientStore().GetClient(req.ClientID) if err == nil { // 检查插件是否已存在 - exists := false + pluginExists := false for i, p := range dbClient.Plugins { if p.Name == req.PluginName { dbClient.Plugins[i].Enabled = true - exists = true + dbClient.Plugins[i].RemotePort = req.RemotePort + pluginExists = true break } } - if !exists { + if !pluginExists { version := req.Version if version == "" { version = "1.0.0" @@ -189,16 +190,50 @@ func (h *StoreHandler) Install(c *gin.Context) { ConfigSchema: configSchema, }) } + + // 自动创建代理规则(如果指定了端口) + if req.RemotePort > 0 { + ruleExists := false + for i, r := range dbClient.Rules { + if r.Name == req.PluginName { + // 更新现有规则 + dbClient.Rules[i].Type = req.PluginName + dbClient.Rules[i].RemotePort = req.RemotePort + dbClient.Rules[i].Enabled = boolPtr(true) + dbClient.Rules[i].AuthEnabled = req.AuthEnabled + dbClient.Rules[i].AuthUsername = req.AuthUsername + dbClient.Rules[i].AuthPassword = req.AuthPassword + ruleExists = true + break + } + } + if !ruleExists { + // 创建新规则 + dbClient.Rules = append(dbClient.Rules, protocol.ProxyRule{ + Name: req.PluginName, + Type: req.PluginName, + RemotePort: req.RemotePort, + Enabled: boolPtr(true), + AuthEnabled: req.AuthEnabled, + AuthUsername: req.AuthUsername, + AuthPassword: req.AuthPassword, + }) + } + } + h.app.GetClientStore().UpdateClient(dbClient) } // 启动服务端监听器(让外部用户可以通过 RemotePort 访问插件) if req.RemotePort > 0 { pluginRule := protocol.ProxyRule{ - Name: req.PluginName, - Type: req.PluginName, // 使用插件名作为类型,让 isClientPlugin 识别 - RemotePort: req.RemotePort, - Enabled: boolPtr(true), + Name: req.PluginName, + Type: req.PluginName, // 使用插件名作为类型,让 isClientPlugin 识别 + RemotePort: req.RemotePort, + Enabled: boolPtr(true), + AuthEnabled: req.AuthEnabled, + AuthUsername: req.AuthUsername, + AuthPassword: req.AuthPassword, } // 启动监听器(忽略错误,可能端口已被占用) h.app.GetServer().StartPluginRule(req.ClientID, pluginRule) diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 82ade77..2d0d67d 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -109,6 +109,10 @@ func (r *GinRouter) SetupRoutes(app handler.AppInterface, jwtAuth *auth.JWTAuth, // 日志管理 logHandler := handler.NewLogHandler(app) api.GET("/client/:id/logs", logHandler.StreamLogs) + + // 插件 API 代理 (通过 Web API 访问客户端插件) + pluginAPIHandler := handler.NewPluginAPIHandler(app) + api.Any("/client/:clientID/plugin/:pluginName/*route", pluginAPIHandler.ProxyRequest) } } diff --git a/internal/server/tunnel/server.go b/internal/server/tunnel/server.go index 781ae91..d5c0839 100644 --- a/internal/server/tunnel/server.go +++ b/internal/server/tunnel/server.go @@ -3,11 +3,13 @@ package tunnel import ( "crypto/rand" "crypto/tls" + "encoding/base64" "encoding/hex" "fmt" "log" "net" "regexp" + "strings" "sync" "time" @@ -1123,6 +1125,16 @@ func (s *Server) acceptClientPluginConns(cs *ClientSession, ln net.Listener, rul func (s *Server) handleClientPluginConn(cs *ClientSession, conn net.Conn, rule protocol.ProxyRule) { defer conn.Close() + // 如果启用了 HTTP Basic Auth,先进行认证 + var bufferedData []byte + if rule.AuthEnabled { + authenticated, data := s.checkHTTPBasicAuth(conn, rule.AuthUsername, rule.AuthPassword) + if !authenticated { + return + } + bufferedData = data + } + stream, err := cs.Session.Open() if err != nil { log.Printf("[Server] Open stream error: %v", err) @@ -1139,9 +1151,84 @@ func (s *Server) handleClientPluginConn(cs *ClientSession, conn net.Conn, rule p return } + // 如果有缓冲的数据(已读取的 HTTP 请求头),先发送给客户端 + if len(bufferedData) > 0 { + if _, err := stream.Write(bufferedData); err != nil { + return + } + } + relay.Relay(conn, stream) } +// checkHTTPBasicAuth 检查 HTTP Basic Auth +// 返回 (认证成功, 已读取的数据) +func (s *Server) checkHTTPBasicAuth(conn net.Conn, username, password string) (bool, []byte) { + // 设置读取超时 + conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + defer conn.SetReadDeadline(time.Time{}) // 重置超时 + + // 读取 HTTP 请求头 + buf := make([]byte, 8192) // 增大缓冲区以处理更大的请求头 + n, err := conn.Read(buf) + if err != nil { + return false, nil + } + + data := buf[:n] + request := string(data) + + // 解析 Authorization 头 + authHeader := "" + lines := strings.Split(request, "\r\n") + for _, line := range lines { + if strings.HasPrefix(strings.ToLower(line), "authorization:") { + authHeader = strings.TrimSpace(line[14:]) + break + } + } + + // 检查 Basic Auth + if authHeader == "" || !strings.HasPrefix(authHeader, "Basic ") { + s.sendHTTPUnauthorized(conn) + return false, nil + } + + // 解码 Base64 + encoded := authHeader[6:] + decoded, err := base64.StdEncoding.DecodeString(encoded) + if err != nil { + s.sendHTTPUnauthorized(conn) + return false, nil + } + + // 解析 username:password + credentials := string(decoded) + parts := strings.SplitN(credentials, ":", 2) + if len(parts) != 2 { + s.sendHTTPUnauthorized(conn) + return false, nil + } + + if parts[0] != username || parts[1] != password { + s.sendHTTPUnauthorized(conn) + return false, nil + } + + return true, data +} + +// sendHTTPUnauthorized 发送 401 未授权响应 +func (s *Server) sendHTTPUnauthorized(conn net.Conn) { + response := "HTTP/1.1 401 Unauthorized\r\n" + + "WWW-Authenticate: Basic realm=\"GoTunnel Plugin\"\r\n" + + "Content-Type: text/plain\r\n" + + "Content-Length: 12\r\n" + + "\r\n" + + "Unauthorized" + conn.Write([]byte(response)) +} + // autoPushJSPlugins 自动推送 JS 插件到客户端 func (s *Server) autoPushJSPlugins(cs *ClientSession) { // 记录已推送的插件,避免重复推送 @@ -1375,6 +1462,52 @@ func (s *Server) StartPluginRule(clientID string, rule protocol.ProxyRule) error return nil } +// ProxyPluginAPIRequest 代理插件 API 请求到客户端 +func (s *Server) ProxyPluginAPIRequest(clientID string, req protocol.PluginAPIRequest) (*protocol.PluginAPIResponse, error) { + s.mu.RLock() + cs, ok := s.clients[clientID] + s.mu.RUnlock() + + if !ok { + return nil, fmt.Errorf("client %s not found or not online", clientID) + } + + stream, err := cs.Session.Open() + if err != nil { + return nil, fmt.Errorf("open stream: %w", err) + } + defer stream.Close() + + // 设置超时(30秒) + stream.SetDeadline(time.Now().Add(30 * time.Second)) + + // 发送 API 请求 + msg, err := protocol.NewMessage(protocol.MsgTypePluginAPIRequest, req) + if err != nil { + return nil, err + } + if err := protocol.WriteMessage(stream, msg); err != nil { + return nil, err + } + + // 读取响应 + resp, err := protocol.ReadMessage(stream) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.Type != protocol.MsgTypePluginAPIResponse { + return nil, fmt.Errorf("unexpected response type: %d", resp.Type) + } + + var apiResp protocol.PluginAPIResponse + if err := resp.ParsePayload(&apiResp); err != nil { + return nil, err + } + + return &apiResp, nil +} + // RestartClientPlugin 重启客户端 JS 插件 func (s *Server) RestartClientPlugin(clientID, pluginName, ruleName string) error { s.mu.RLock() diff --git a/pkg/plugin/script/js.go b/pkg/plugin/script/js.go index 7a2f20a..03cdbf4 100644 --- a/pkg/plugin/script/js.go +++ b/pkg/plugin/script/js.go @@ -27,6 +27,7 @@ type JSPlugin struct { mu sync.Mutex eventListeners map[string][]func(goja.Value) storagePath string + apiHandlers map[string]map[string]goja.Callable // method -> path -> handler } // NewJSPlugin 从 JS 源码创建插件 @@ -38,6 +39,7 @@ func NewJSPlugin(name, source string) (*JSPlugin, error) { sandbox: DefaultSandbox(), eventListeners: make(map[string][]func(goja.Value)), storagePath: filepath.Join("plugin_data", name+".json"), + apiHandlers: make(map[string]map[string]goja.Callable), } // 确保存储目录存在 @@ -86,6 +88,9 @@ func (p *JSPlugin) init() error { // 注入 HTTP API p.vm.Set("http", p.createHttpAPI()) + // 注入路由 API + p.vm.Set("api", p.createRouteAPI()) + // 执行脚本 _, err := p.vm.RunString(p.source) if err != nil { @@ -669,3 +674,139 @@ func (p *JSPlugin) createNotifyAPI() map[string]interface{} { }, } } + +// ============================================================================= +// Route API (用于 Web API 代理) +// ============================================================================= + +func (p *JSPlugin) createRouteAPI() map[string]interface{} { + return map[string]interface{}{ + "handle": p.apiHandle, + "get": func(path string, handler goja.Callable) { p.apiRegister("GET", path, handler) }, + "post": func(path string, handler goja.Callable) { p.apiRegister("POST", path, handler) }, + "put": func(path string, handler goja.Callable) { p.apiRegister("PUT", path, handler) }, + "delete": func(path string, handler goja.Callable) { p.apiRegister("DELETE", path, handler) }, + } +} + +// apiHandle 注册 API 路由处理函数 +func (p *JSPlugin) apiHandle(method, path string, handler goja.Callable) { + p.apiRegister(method, path, handler) +} + +// apiRegister 注册 API 路由 +func (p *JSPlugin) apiRegister(method, path string, handler goja.Callable) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.apiHandlers[method] == nil { + p.apiHandlers[method] = make(map[string]goja.Callable) + } + p.apiHandlers[method][path] = handler + fmt.Printf("[JS:%s] Registered API: %s %s\n", p.name, method, path) +} + +// HandleAPIRequest 处理 API 请求 +func (p *JSPlugin) HandleAPIRequest(method, path, query string, headers map[string]string, body string) (int, map[string]string, string, error) { + p.mu.Lock() + handlers := p.apiHandlers[method] + p.mu.Unlock() + + if handlers == nil { + return 404, nil, `{"error":"method not allowed"}`, nil + } + + // 查找匹配的路由 + var handler goja.Callable + var matchedPath string + + for registeredPath, h := range handlers { + if matchRoute(registeredPath, path) { + handler = h + matchedPath = registeredPath + break + } + } + + if handler == nil { + return 404, nil, `{"error":"route not found"}`, nil + } + + // 构建请求对象 + reqObj := map[string]interface{}{ + "method": method, + "path": path, + "pattern": matchedPath, + "query": query, + "headers": headers, + "body": body, + "params": extractParams(matchedPath, path), + } + + // 调用处理函数 + result, err := handler(goja.Undefined(), p.vm.ToValue(reqObj)) + if err != nil { + return 500, nil, fmt.Sprintf(`{"error":"%s"}`, err.Error()), nil + } + + // 解析响应 + if result == nil || goja.IsUndefined(result) || goja.IsNull(result) { + return 200, nil, "", nil + } + + respObj := result.ToObject(p.vm) + status := 200 + if s := respObj.Get("status"); s != nil && !goja.IsUndefined(s) { + status = int(s.ToInteger()) + } + + respHeaders := make(map[string]string) + if h := respObj.Get("headers"); h != nil && !goja.IsUndefined(h) { + hObj := h.ToObject(p.vm) + for _, key := range hObj.Keys() { + respHeaders[key] = hObj.Get(key).String() + } + } + + respBody := "" + if b := respObj.Get("body"); b != nil && !goja.IsUndefined(b) { + respBody = b.String() + } + + return status, respHeaders, respBody, nil +} + +// matchRoute 匹配路由 (支持简单的路径参数) +func matchRoute(pattern, path string) bool { + patternParts := strings.Split(strings.Trim(pattern, "/"), "/") + pathParts := strings.Split(strings.Trim(path, "/"), "/") + + if len(patternParts) != len(pathParts) { + return false + } + + for i, part := range patternParts { + if strings.HasPrefix(part, ":") { + continue // 路径参数,匹配任意值 + } + if part != pathParts[i] { + return false + } + } + return true +} + +// extractParams 提取路径参数 +func extractParams(pattern, path string) map[string]string { + params := make(map[string]string) + patternParts := strings.Split(strings.Trim(pattern, "/"), "/") + pathParts := strings.Split(strings.Trim(path, "/"), "/") + + for i, part := range patternParts { + if strings.HasPrefix(part, ":") && i < len(pathParts) { + paramName := strings.TrimPrefix(part, ":") + params[paramName] = pathParts[i] + } + } + return params +} diff --git a/pkg/protocol/message.go b/pkg/protocol/message.go index 0378684..70bca54 100644 --- a/pkg/protocol/message.go +++ b/pkg/protocol/message.go @@ -67,6 +67,10 @@ const ( MsgTypeLogRequest uint8 = 80 // 请求客户端日志 MsgTypeLogData uint8 = 81 // 日志数据 MsgTypeLogStop uint8 = 82 // 停止日志流 + + // 插件 API 路由消息 + MsgTypePluginAPIRequest uint8 = 90 // 插件 API 请求 + MsgTypePluginAPIResponse uint8 = 91 // 插件 API 响应 ) // Message 基础消息结构 @@ -100,6 +104,10 @@ type ProxyRule struct { PluginName string `json:"plugin_name,omitempty" yaml:"plugin_name"` PluginVersion string `json:"plugin_version,omitempty" yaml:"plugin_version"` PluginConfig map[string]string `json:"plugin_config,omitempty" yaml:"plugin_config"` + // HTTP Basic Auth 字段 (用于独立端口模式) + AuthEnabled bool `json:"auth_enabled,omitempty" yaml:"auth_enabled"` + AuthUsername string `json:"auth_username,omitempty" yaml:"auth_username"` + AuthPassword string `json:"auth_password,omitempty" yaml:"auth_password"` } // IsEnabled 检查规则是否启用,默认为 true @@ -351,6 +359,24 @@ type LogStopRequest struct { SessionID string `json:"session_id"` // 会话 ID } +// PluginAPIRequest 插件 API 请求 +type PluginAPIRequest struct { + PluginName string `json:"plugin_name"` // 插件名称 + Method string `json:"method"` // HTTP 方法: GET, POST, PUT, DELETE + Path string `json:"path"` // 路由路径 + Query string `json:"query"` // 查询参数 + Headers map[string]string `json:"headers"` // 请求头 + Body string `json:"body"` // 请求体 +} + +// PluginAPIResponse 插件 API 响应 +type PluginAPIResponse struct { + Status int `json:"status"` // HTTP 状态码 + Headers map[string]string `json:"headers"` // 响应头 + Body string `json:"body"` // 响应体 + Error string `json:"error"` // 错误信息 +} + // WriteMessage 写入消息到 writer func WriteMessage(w io.Writer, msg *Message) error { header := make([]byte, HeaderSize)