From 4623a7f0310ec0d3a9c11e6edfa894b53bd8676a Mon Sep 17 00:00:00 2001 From: Flik Date: Fri, 26 Dec 2025 11:24:23 +0800 Subject: [PATCH] add plugins --- CLAUDE.md | 53 ++++++- README.md | 145 +++++++++++++++++- cmd/server/main.go | 13 ++ go.mod | 3 +- go.sum | 4 + internal/client/plugin/cache.go | 114 ++++++++++++++ internal/client/plugin/manager.go | 70 +++++++++ internal/client/tunnel/client.go | 100 +++++++++++-- internal/server/app/app.go | 18 +-- internal/server/plugin/manager.go | 137 +++++++++++++++++ internal/server/router/api.go | 22 ++- internal/server/router/router.go | 35 +++++ internal/server/tunnel/server.go | 240 +++++++++++++++++++++++++----- pkg/plugin/builtin.go | 16 ++ pkg/plugin/builtin/http.go | 116 +++++++++++++++ pkg/plugin/builtin/socks5.go | 167 +++++++++++++++++++++ pkg/plugin/registry.go | 93 ++++++++++++ pkg/plugin/store/interface.go | 29 ++++ pkg/plugin/store/sqlite.go | 168 +++++++++++++++++++++ pkg/plugin/types.go | 99 ++++++++++++ pkg/plugin/wasm/host.go | 146 ++++++++++++++++++ pkg/plugin/wasm/memory.go | 29 ++++ pkg/plugin/wasm/module.go | 148 ++++++++++++++++++ pkg/plugin/wasm/runtime.go | 116 +++++++++++++++ pkg/protocol/message.go | 81 +++++++++- pkg/proxy/server.go | 2 +- pkg/relay/relay.go | 23 ++- 27 files changed, 2090 insertions(+), 97 deletions(-) create mode 100644 internal/client/plugin/cache.go create mode 100644 internal/client/plugin/manager.go create mode 100644 internal/server/plugin/manager.go create mode 100644 pkg/plugin/builtin.go create mode 100644 pkg/plugin/builtin/http.go create mode 100644 pkg/plugin/builtin/socks5.go create mode 100644 pkg/plugin/registry.go create mode 100644 pkg/plugin/store/interface.go create mode 100644 pkg/plugin/store/sqlite.go create mode 100644 pkg/plugin/types.go create mode 100644 pkg/plugin/wasm/host.go create mode 100644 pkg/plugin/wasm/memory.go create mode 100644 pkg/plugin/wasm/module.go create mode 100644 pkg/plugin/wasm/runtime.go diff --git a/CLAUDE.md b/CLAUDE.md index f5747cb..c40ebc4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -43,15 +43,23 @@ internal/server/ ├── config/ # YAML configuration loading ├── db/ # SQLite storage (ClientStore interface) ├── app/ # Web server, SPA handler - └── router/ # REST API endpoints + ├── router/ # REST API endpoints + └── plugin/ # Server-side plugin manager internal/client/ - └── tunnel/ # Client tunnel logic, auto-reconnect + ├── tunnel/ # Client tunnel logic, auto-reconnect + └── plugin/ # Client-side plugin manager and cache pkg/ ├── protocol/ # Message types and serialization ├── crypto/ # TLS certificate generation - ├── proxy/ # SOCKS5 and HTTP proxy implementations + ├── proxy/ # Legacy proxy implementations ├── relay/ # Bidirectional data relay (32KB buffers) - └── utils/ # Port availability checking + ├── utils/ # Port availability checking + └── plugin/ # Plugin system core + ├── types.go # ProxyHandler interface, PluginMetadata + ├── registry.go # Plugin registry + ├── builtin/ # Built-in plugins (socks5, http) + ├── wasm/ # WASM runtime (wazero) + └── store/ # Plugin persistence (SQLite) web/ # Vue 3 + TypeScript frontend (Vite) ``` @@ -59,12 +67,19 @@ web/ # Vue 3 + TypeScript frontend (Vite) - `ClientStore` (internal/server/db/): Database abstraction for client rules storage - `ServerInterface` (internal/server/router/): API handler interface +- `ProxyHandler` (pkg/plugin/): Plugin interface for proxy handlers +- `PluginStore` (pkg/plugin/store/): Plugin persistence interface ### Proxy Types +**内置类型** (直接在 tunnel 中处理): 1. **TCP** (default): Direct port forwarding (remote_port → local_ip:local_port) -2. **SOCKS5**: Full SOCKS5 protocol via `TunnelDialer` -3. **HTTP**: HTTP/HTTPS proxy through client network +2. **UDP**: UDP port forwarding +3. **HTTP**: HTTP proxy through client network +4. **HTTPS**: HTTPS proxy through client network + +**插件类型** (通过 plugin 系统提供): +- **SOCKS5**: Full SOCKS5 protocol (official plugin) ### Data Flow @@ -75,3 +90,29 @@ External User → Server Port → Yamux Stream → Client → Local Service - Server: YAML config + SQLite database for client rules - Client: Command-line flags only (server address, token, client ID) - Default ports: 7000 (tunnel), 7500 (web console) + +## Plugin System + +GoTunnel supports a WASM-based plugin system for extensible proxy handlers. + +### Plugin Architecture + +- **内置类型**: tcp, udp, http, https 直接在 tunnel 代码中处理 +- **Official Plugin**: SOCKS5 作为官方 plugin 提供 +- **WASM Plugins**: 自定义 plugins 可通过 wazero 运行时动态加载 +- **Hybrid Distribution**: 内置 plugins 离线可用;WASM plugins 可从服务端下载 + +### ProxyHandler Interface + +```go +type ProxyHandler interface { + Metadata() PluginMetadata + Init(config map[string]string) error + HandleConn(conn net.Conn, dialer Dialer) error + Close() error +} +``` + +### Creating a Built-in Plugin + +See `pkg/plugin/builtin/socks5.go` as reference implementation. diff --git a/README.md b/README.md index e21a16c..181f1b0 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,23 @@ # GoTunnel +[![Go Version](https://img.shields.io/badge/Go-1.21+-00ADD8?style=flat&logo=go)](https://go.dev/) +[![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE) + 一个轻量级、高性能的内网穿透工具,采用服务端集中化管理模式,支持 TLS 加密通信。 ## 项目简介 GoTunnel 是一个类似 frp 的内网穿透解决方案,核心特点是**服务端集中管理配置**和**零配置 TLS 加密**。客户端只需提供认证信息即可自动获取映射规则,无需在客户端维护复杂配置。 +### 与 frp 的主要区别 + +| 特性 | GoTunnel | frp | +|------|----------|-----| +| 配置管理 | 服务端集中管理 | 客户端各自配置 | +| TLS 证书 | 自动生成,零配置 | 需手动配置 | +| 管理界面 | 内置 Web 控制台 | 需额外部署 Dashboard | +| 客户端部署 | 仅需 3 个参数 | 需配置文件 | + ### 架构设计 ``` @@ -157,11 +169,20 @@ web: 通过 Web 控制台配置客户端规则时,支持以下类型: +### 内置类型 + | 类型 | 说明 | 示例用途 | |------|------|----------| | `tcp` | TCP 端口转发(默认) | SSH、MySQL、Web 服务 | -| `socks5` | SOCKS5 代理 | 通过客户端网络访问任意地址 | +| `udp` | UDP 端口转发 | DNS、游戏服务器、VoIP | | `http` | HTTP 代理 | 通过客户端网络访问 HTTP/HTTPS | +| `https` | HTTPS 代理 | 同 HTTP,支持 CONNECT 方法 | + +### 插件类型 + +| 类型 | 说明 | 示例用途 | +|------|------|----------| +| `socks5` | SOCKS5 代理(官方插件) | 通过客户端网络访问任意地址 | **规则配置示例(通过 Web API):** @@ -170,6 +191,7 @@ web: "id": "client-a", "rules": [ {"name": "web", "type": "tcp", "local_ip": "127.0.0.1", "local_port": 80, "remote_port": 8080}, + {"name": "dns", "type": "udp", "local_ip": "127.0.0.1", "local_port": 53, "remote_port": 5353}, {"name": "socks5-proxy", "type": "socks5", "remote_port": 1080}, {"name": "http-proxy", "type": "http", "remote_port": 8888} ] @@ -189,18 +211,131 @@ GoTunnel/ │ │ ├── config/ # 配置管理 │ │ ├── db/ # 数据库存储 │ │ ├── app/ # Web 服务 -│ │ └── router/ # API 路由 +│ │ ├── router/ # API 路由 +│ │ └── plugin/ # 服务端插件管理 │ └── client/ -│ └── tunnel/ # 客户端隧道 +│ ├── tunnel/ # 客户端隧道 +│ └── plugin/ # 客户端插件管理和缓存 ├── pkg/ │ ├── protocol/ # 通信协议 │ ├── crypto/ # TLS 加密 -│ ├── proxy/ # SOCKS5/HTTP 代理 +│ ├── proxy/ # 代理服务器 │ ├── relay/ # 数据转发 -│ └── utils/ # 工具函数 +│ ├── utils/ # 工具函数 +│ └── plugin/ # 插件系统核心 +│ ├── builtin/ # 内置插件 (socks5) +│ ├── wasm/ # WASM 运行时 (wazero) +│ └── store/ # 插件持久化 (SQLite) +├── web/ # Vue 3 前端 └── go.mod ``` +## 插件系统 + +GoTunnel 支持基于 WASM 的插件系统,可扩展代理协议支持。 + +### 架构设计 + +- **内置类型**: tcp, udp, http, https 直接在 tunnel 代码中处理,无需插件 +- **官方插件**: SOCKS5 作为官方插件提供 +- **WASM 插件**: 自定义插件可通过 wazero 运行时动态加载 +- **混合分发**: 内置插件离线可用;WASM 插件可从服务端下载 + +### 开发自定义插件 + +插件需实现 `ProxyHandler` 接口: + +```go +type ProxyHandler interface { + Metadata() PluginMetadata + Init(config map[string]string) error + HandleConn(conn net.Conn, dialer Dialer) error + Close() error +} +``` + +参考实现:`pkg/plugin/builtin/socks5.go` + +## Web API + +Web 控制台提供 RESTful API 用于管理客户端和配置。 + +### 客户端管理 + +```bash +# 获取所有客户端 +GET /api/clients + +# 添加客户端 +POST /api/clients +Content-Type: application/json +{"id": "client-a", "rules": [...]} + +# 获取单个客户端 +GET /api/client/{id} + +# 更新客户端规则 +PUT /api/client/{id} +Content-Type: application/json +{"rules": [...]} + +# 删除客户端 +DELETE /api/client/{id} +``` + +### 服务状态 + +```bash +# 获取服务状态 +GET /api/status + +# 获取配置 +GET /api/config + +# 重载配置 +POST /api/config/reload +``` + +## 使用场景 + +### 场景一:暴露内网 Web 服务 + +```bash +# 服务端配置客户端规则(通过 Web 控制台或 API) +curl -X POST http://server:7500/api/clients \ + -H "Content-Type: application/json" \ + -d '{"id":"home","rules":[{"name":"web","type":"tcp","local_ip":"127.0.0.1","local_port":80,"remote_port":8080}]}' + +# 客户端连接 +./client -s server:7000 -t -id home + +# 访问:http://server:8080 -> 内网 127.0.0.1:80 +``` + +### 场景二:SOCKS5 代理访问内网 + +```bash +# 配置 SOCKS5 代理规则 +{"name":"proxy","type":"socks5","remote_port":1080} + +# 使用代理 +curl --socks5 server:1080 http://internal-service/ +``` + +## 常见问题 + +**Q: 客户端连接失败,提示 "client not configured"** + +A: 需要先在服务端 Web 控制台添加对应的客户端 ID。 + +**Q: 如何禁用 TLS?** + +A: 服务端配置 `tls_disabled: true`,客户端使用 `-no-tls` 参数。 + +**Q: 端口被占用怎么办?** + +A: 服务端会自动检测端口冲突,请检查日志并更换端口。 + ## 许可证 本项目采用 [MIT License](LICENSE) 开源许可证。 diff --git a/cmd/server/main.go b/cmd/server/main.go index 81612b3..4d8d4f3 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -10,6 +10,8 @@ import ( "github.com/gotunnel/internal/server/db" "github.com/gotunnel/internal/server/tunnel" "github.com/gotunnel/pkg/crypto" + "github.com/gotunnel/pkg/plugin" + "github.com/gotunnel/pkg/plugin/builtin" ) func main() { @@ -49,6 +51,17 @@ func main() { log.Printf("[Server] TLS enabled") } + // 初始化插件系统 + registry := plugin.NewRegistry() + if err := registry.RegisterBuiltin(builtin.NewSOCKS5Plugin()); err != nil { + log.Printf("[Plugin] Register socks5 error: %v", err) + } + if err := registry.RegisterBuiltin(builtin.NewHTTPPlugin()); err != nil { + log.Printf("[Plugin] Register http error: %v", err) + } + server.SetPluginRegistry(registry) + log.Printf("[Plugin] Plugins registered: socks5, http") + // 启动 Web 控制台 if cfg.Web.Enabled { ws := app.NewWebServer(clientStore, server, cfg, *configPath) diff --git a/go.mod b/go.mod index b3196ac..28a6dbd 100644 --- a/go.mod +++ b/go.mod @@ -14,8 +14,9 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + github.com/tetratelabs/wazero v1.11.0 // indirect golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b // indirect - golang.org/x/sys v0.36.0 // indirect + golang.org/x/sys v0.38.0 // indirect modernc.org/libc v1.66.10 // indirect modernc.org/mathutil v1.7.1 // indirect modernc.org/memory v1.11.0 // indirect diff --git a/go.sum b/go.sum index 0ac7b58..ec2537e 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdh github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/tetratelabs/wazero v1.11.0 h1:+gKemEuKCTevU4d7ZTzlsvgd1uaToIDtlQlmNbwqYhA= +github.com/tetratelabs/wazero v1.11.0/go.mod h1:eV28rsN8Q+xwjogd7f4/Pp4xFxO7uOGbLcD/LzB1wiU= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/mod v0.27.0 h1:kb+q2PyFnEADO2IEF935ehFUXlWiNjJWtRNgBLSfbxQ= @@ -21,6 +23,8 @@ golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/tools v0.36.0 h1:kWS0uv/zsvHEle1LbV5LE8QujrxB3wfQyxHfhOk0Qkg= golang.org/x/tools v0.36.0/go.mod h1:WBDiHKJK8YgLHlcQPYQzNCkUxUypCaa5ZegCVutKm+s= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/internal/client/plugin/cache.go b/internal/client/plugin/cache.go new file mode 100644 index 0000000..4d4c742 --- /dev/null +++ b/internal/client/plugin/cache.go @@ -0,0 +1,114 @@ +package plugin + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "os" + "path/filepath" + "sync" + "time" + + "github.com/gotunnel/pkg/plugin" +) + +// CachedPlugin 缓存的 plugin 信息 +type CachedPlugin struct { + Metadata plugin.PluginMetadata + Path string + LoadedAt time.Time +} + +// Cache 管理本地 plugin 存储 +type Cache struct { + dir string + plugins map[string]*CachedPlugin + mu sync.RWMutex +} + +// NewCache 创建 plugin 缓存 +func NewCache(cacheDir string) (*Cache, error) { + if err := os.MkdirAll(cacheDir, 0755); err != nil { + return nil, err + } + + return &Cache{ + dir: cacheDir, + plugins: make(map[string]*CachedPlugin), + }, nil +} + +// Get 返回缓存的 plugin(如果有效) +func (c *Cache) Get(name, version, checksum string) (*CachedPlugin, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + cached, ok := c.plugins[name] + if !ok { + return nil, nil + } + + // 验证版本和 checksum + if cached.Metadata.Version != version { + return nil, nil + } + if checksum != "" && cached.Metadata.Checksum != checksum { + return nil, nil + } + + return cached, nil +} + +// Store 保存 plugin 到缓存 +func (c *Cache) Store(meta plugin.PluginMetadata, wasmData []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + + // 验证 checksum + hash := sha256.Sum256(wasmData) + checksum := hex.EncodeToString(hash[:]) + if meta.Checksum != "" && meta.Checksum != checksum { + return fmt.Errorf("checksum mismatch") + } + meta.Checksum = checksum + + // 写入文件 + path := filepath.Join(c.dir, meta.Name+".wasm") + if err := os.WriteFile(path, wasmData, 0644); err != nil { + return err + } + + c.plugins[meta.Name] = &CachedPlugin{ + Metadata: meta, + Path: path, + LoadedAt: time.Now(), + } + return nil +} + +// Remove 删除缓存的 plugin +func (c *Cache) Remove(name string) error { + c.mu.Lock() + defer c.mu.Unlock() + + cached, ok := c.plugins[name] + if !ok { + return nil + } + + os.Remove(cached.Path) + delete(c.plugins, name) + return nil +} + +// List 返回所有缓存的 plugins +func (c *Cache) List() []plugin.PluginMetadata { + c.mu.RLock() + defer c.mu.RUnlock() + + var result []plugin.PluginMetadata + for _, cached := range c.plugins { + result = append(result, cached.Metadata) + } + return result +} diff --git a/internal/client/plugin/manager.go b/internal/client/plugin/manager.go new file mode 100644 index 0000000..17e514f --- /dev/null +++ b/internal/client/plugin/manager.go @@ -0,0 +1,70 @@ +package plugin + +import ( + "context" + "log" + "sync" + + "github.com/gotunnel/pkg/plugin" + "github.com/gotunnel/pkg/plugin/builtin" + "github.com/gotunnel/pkg/plugin/wasm" +) + +// Manager 客户端 plugin 管理器 +type Manager struct { + registry *plugin.Registry + cache *Cache + runtime *wasm.Runtime + mu sync.RWMutex +} + +// NewManager 创建客户端 plugin 管理器 +func NewManager(cacheDir string) (*Manager, error) { + ctx := context.Background() + + cache, err := NewCache(cacheDir) + if err != nil { + return nil, err + } + + runtime, err := wasm.NewRuntime(ctx) + if err != nil { + return nil, err + } + + registry := plugin.NewRegistry() + + m := &Manager{ + registry: registry, + cache: cache, + runtime: runtime, + } + + // 注册内置 plugins + if err := m.registerBuiltins(); err != nil { + return nil, err + } + + return m, nil +} + +// registerBuiltins 注册内置 plugins +// 注意: tcp, udp, http, https 是内置类型,直接在 tunnel 中处理 +func (m *Manager) registerBuiltins() error { + // 注册 SOCKS5 plugin + if err := m.registry.RegisterBuiltin(builtin.NewSOCKS5Plugin()); err != nil { + return err + } + log.Println("[Plugin] Builtin plugins registered: socks5") + return nil +} + +// GetHandler 返回指定代理类型的 handler +func (m *Manager) GetHandler(proxyType string) (plugin.ProxyHandler, error) { + return m.registry.Get(proxyType) +} + +// Close 关闭管理器 +func (m *Manager) Close(ctx context.Context) error { + return m.runtime.Close(ctx) +} diff --git a/internal/client/tunnel/client.go b/internal/client/tunnel/client.go index daafa7c..5916eba 100644 --- a/internal/client/tunnel/client.go +++ b/internal/client/tunnel/client.go @@ -14,6 +14,16 @@ import ( "github.com/hashicorp/yamux" ) +// 客户端常量 +const ( + dialTimeout = 10 * time.Second + localDialTimeout = 5 * time.Second + udpTimeout = 10 * time.Second + reconnectDelay = 5 * time.Second + disconnectDelay = 3 * time.Second + udpBufferSize = 65535 +) + // Client 隧道客户端 type Client struct { ServerAddr string @@ -43,14 +53,14 @@ func (c *Client) Run() error { for { if err := c.connect(); err != nil { log.Printf("[Client] Connect error: %v", err) - log.Printf("[Client] Reconnecting in 5s...") - time.Sleep(5 * time.Second) + log.Printf("[Client] Reconnecting in %v...", reconnectDelay) + time.Sleep(reconnectDelay) continue } c.handleSession() log.Printf("[Client] Disconnected, reconnecting...") - time.Sleep(3 * time.Second) + time.Sleep(disconnectDelay) } } @@ -60,10 +70,10 @@ func (c *Client) connect() error { var err error if c.TLSEnabled && c.TLSConfig != nil { - dialer := &net.Dialer{Timeout: 10 * time.Second} + dialer := &net.Dialer{Timeout: dialTimeout} conn, err = tls.DialWithDialer(dialer, "tcp", c.ServerAddr, c.TLSConfig) } else { - conn, err = net.DialTimeout("tcp", c.ServerAddr, 10*time.Second) + conn, err = net.DialTimeout("tcp", c.ServerAddr, dialTimeout) } if err != nil { return err @@ -83,7 +93,10 @@ func (c *Client) connect() error { } var authResp protocol.AuthResponse - resp.ParsePayload(&authResp) + if err := resp.ParsePayload(&authResp); err != nil { + conn.Close() + return fmt.Errorf("parse auth response: %w", err) + } if !authResp.Success { conn.Close() return fmt.Errorf("auth failed: %s", authResp.Message) @@ -137,13 +150,18 @@ func (c *Client) handleStream(stream net.Conn) { c.handleHeartbeat(stream) case protocol.MsgTypeProxyConnect: c.handleProxyConnect(stream, msg) + case protocol.MsgTypeUDPData: + c.handleUDPData(stream, msg) } } // handleProxyConfig 处理代理配置 func (c *Client) handleProxyConfig(msg *protocol.Message) { var cfg protocol.ProxyConfig - msg.ParsePayload(&cfg) + if err := msg.ParsePayload(&cfg); err != nil { + log.Printf("[Client] Parse proxy config error: %v", err) + return + } c.mu.Lock() c.rules = cfg.Rules @@ -158,7 +176,10 @@ func (c *Client) handleProxyConfig(msg *protocol.Message) { // handleNewProxy 处理新代理请求 func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) { var req protocol.NewProxyRequest - msg.ParsePayload(&req) + if err := msg.ParsePayload(&req); err != nil { + log.Printf("[Client] Parse new proxy request error: %v", err) + return + } var rule *protocol.ProxyRule c.mu.RLock() @@ -176,7 +197,7 @@ func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) { } localAddr := fmt.Sprintf("%s:%d", rule.LocalIP, rule.LocalPort) - localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second) + localConn, err := net.DialTimeout("tcp", localAddr, localDialTimeout) if err != nil { log.Printf("[Client] Connect %s error: %v", localAddr, err) return @@ -202,7 +223,7 @@ func (c *Client) handleProxyConnect(stream net.Conn, msg *protocol.Message) { } // 连接目标地址 - targetConn, err := net.DialTimeout("tcp", req.Target, 10*time.Second) + targetConn, err := net.DialTimeout("tcp", req.Target, dialTimeout) if err != nil { c.sendProxyResult(stream, false, err.Error()) return @@ -224,3 +245,62 @@ func (c *Client) sendProxyResult(stream net.Conn, success bool, message string) msg, _ := protocol.NewMessage(protocol.MsgTypeProxyResult, result) return protocol.WriteMessage(stream, msg) } + +// handleUDPData 处理 UDP 数据 +func (c *Client) handleUDPData(stream net.Conn, msg *protocol.Message) { + defer stream.Close() + + var packet protocol.UDPPacket + if err := msg.ParsePayload(&packet); err != nil { + return + } + + // 查找对应的规则 + rule := c.findRuleByPort(packet.RemotePort) + if rule == nil { + return + } + + // 连接本地 UDP 服务 + target := fmt.Sprintf("%s:%d", rule.LocalIP, rule.LocalPort) + conn, err := net.DialTimeout("udp", target, localDialTimeout) + if err != nil { + return + } + defer conn.Close() + + // 发送数据到本地服务 + conn.SetDeadline(time.Now().Add(udpTimeout)) + if _, err := conn.Write(packet.Data); err != nil { + return + } + + // 读取响应 + buf := make([]byte, udpBufferSize) + n, err := conn.Read(buf) + if err != nil { + return + } + + // 发送响应回服务端 + respPacket := protocol.UDPPacket{ + RemotePort: packet.RemotePort, + ClientAddr: packet.ClientAddr, + Data: buf[:n], + } + respMsg, _ := protocol.NewMessage(protocol.MsgTypeUDPData, respPacket) + protocol.WriteMessage(stream, respMsg) +} + +// findRuleByPort 根据端口查找规则 +func (c *Client) findRuleByPort(port int) *protocol.ProxyRule { + c.mu.RLock() + defer c.mu.RUnlock() + + for i := range c.rules { + if c.rules[i].RemotePort == port { + return &c.rules[i] + } + } + return nil +} diff --git a/internal/server/app/app.go b/internal/server/app/app.go index dace5b7..4c50d5f 100644 --- a/internal/server/app/app.go +++ b/internal/server/app/app.go @@ -86,26 +86,12 @@ func (w *WebServer) RunWithAuth(addr, username, password string) error { } r.Handle("/", spaHandler{fs: http.FS(staticFS)}) - handler := &authMiddleware{username, password, r.Handler()} + auth := &router.AuthConfig{Username: username, Password: password} + handler := router.BasicAuthMiddleware(auth, r.Handler()) log.Printf("[Web] Console listening on %s (auth enabled)", addr) return http.ListenAndServe(addr, handler) } -type authMiddleware struct { - username, password string - handler http.Handler -} - -func (a *authMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { - user, pass, ok := r.BasicAuth() - if !ok || user != a.username || pass != a.password { - w.Header().Set("WWW-Authenticate", `Basic realm="GoTunnel"`) - http.Error(w, "Unauthorized", http.StatusUnauthorized) - return - } - a.handler.ServeHTTP(w, r) -} - // GetClientStore 获取客户端存储 func (w *WebServer) GetClientStore() db.ClientStore { return w.ClientStore diff --git a/internal/server/plugin/manager.go b/internal/server/plugin/manager.go new file mode 100644 index 0000000..6bd2217 --- /dev/null +++ b/internal/server/plugin/manager.go @@ -0,0 +1,137 @@ +package plugin + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "log" + "sync" + + "github.com/gotunnel/pkg/plugin" + "github.com/gotunnel/pkg/plugin/builtin" + "github.com/gotunnel/pkg/plugin/store" + "github.com/gotunnel/pkg/plugin/wasm" +) + +// Manager 服务端 plugin 管理器 +type Manager struct { + registry *plugin.Registry + store store.PluginStore + runtime *wasm.Runtime + mu sync.RWMutex +} + +// NewManager 创建 plugin 管理器 +func NewManager(pluginStore store.PluginStore) (*Manager, error) { + ctx := context.Background() + + runtime, err := wasm.NewRuntime(ctx) + if err != nil { + return nil, fmt.Errorf("create wasm runtime: %w", err) + } + + registry := plugin.NewRegistry() + + m := &Manager{ + registry: registry, + store: pluginStore, + runtime: runtime, + } + + // 注册内置 plugins + if err := m.registerBuiltins(); err != nil { + return nil, err + } + + return m, nil +} + +// registerBuiltins 注册内置 plugins +// 注意: tcp, udp, http, https 是内置类型,直接在 tunnel 中处理 +// 这里只注册需要通过 plugin 系统提供的协议 +func (m *Manager) registerBuiltins() error { + // 注册 SOCKS5 plugin + if err := m.registry.RegisterBuiltin(builtin.NewSOCKS5Plugin()); err != nil { + return fmt.Errorf("register socks5: %w", err) + } + + log.Println("[Plugin] Builtin plugins registered: socks5") + return nil +} + +// LoadStoredPlugins 从数据库加载所有 plugins +func (m *Manager) LoadStoredPlugins(ctx context.Context) error { + if m.store == nil { + return nil + } + + plugins, err := m.store.GetAllPlugins() + if err != nil { + return err + } + + for _, meta := range plugins { + data, err := m.store.GetPluginData(meta.Name) + if err != nil { + log.Printf("[Plugin] Failed to load %s: %v", meta.Name, err) + continue + } + + if err := m.loadWASMPlugin(ctx, meta.Name, data); err != nil { + log.Printf("[Plugin] Failed to init %s: %v", meta.Name, err) + } + } + + return nil +} + +// loadWASMPlugin 加载 WASM plugin +func (m *Manager) loadWASMPlugin(ctx context.Context, name string, data []byte) error { + _, err := m.runtime.LoadModule(ctx, name, data) + if err != nil { + return err + } + log.Printf("[Plugin] WASM plugin loaded: %s", name) + return nil +} + +// InstallPlugin 安装新的 WASM plugin +func (m *Manager) InstallPlugin(ctx context.Context, meta plugin.PluginMetadata, wasmData []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + + // 验证 checksum + hash := sha256.Sum256(wasmData) + checksum := hex.EncodeToString(hash[:]) + if meta.Checksum != "" && meta.Checksum != checksum { + return fmt.Errorf("checksum mismatch") + } + meta.Checksum = checksum + meta.Size = int64(len(wasmData)) + + // 存储到数据库 + if m.store != nil { + if err := m.store.SavePlugin(meta, wasmData); err != nil { + return err + } + } + + // 加载到运行时 + return m.loadWASMPlugin(ctx, meta.Name, wasmData) +} + +// GetHandler 返回指定代理类型的 handler +func (m *Manager) GetHandler(proxyType string) (plugin.ProxyHandler, error) { + return m.registry.Get(proxyType) +} + +// ListPlugins 返回所有可用的 plugins +func (m *Manager) ListPlugins() []plugin.PluginInfo { + return m.registry.List() +} + +// Close 关闭管理器 +func (m *Manager) Close(ctx context.Context) error { + return m.runtime.Close(ctx) +} diff --git a/internal/server/router/api.go b/internal/server/router/api.go index a3e2bd6..24d9639 100644 --- a/internal/server/router/api.go +++ b/internal/server/router/api.go @@ -3,12 +3,21 @@ package router import ( "encoding/json" "net/http" + "regexp" "github.com/gotunnel/internal/server/config" "github.com/gotunnel/internal/server/db" "github.com/gotunnel/pkg/protocol" ) +// 客户端 ID 验证规则 +var clientIDRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]{1,64}$`) + +// validateClientID 验证客户端 ID 格式 +func validateClientID(id string) bool { + return clientIDRegex.MatchString(id) +} + // ClientStatus 客户端状态 type ClientStatus struct { ID string `json:"id"` @@ -122,6 +131,10 @@ func (h *APIHandler) addClient(rw http.ResponseWriter, r *http.Request) { http.Error(rw, "client id required", http.StatusBadRequest) return } + if !validateClientID(req.ID) { + http.Error(rw, "invalid client id: must be 1-64 alphanumeric characters, underscore or hyphen", http.StatusBadRequest) + return + } exists, _ := h.clientStore.ClientExists(req.ID) if exists { @@ -218,11 +231,16 @@ func (h *APIHandler) handleConfig(rw http.ResponseWriter, r *http.Request) { func (h *APIHandler) getConfig(rw http.ResponseWriter) { cfg := h.app.GetConfig() + // Token 脱敏处理,只显示前4位 + maskedToken := cfg.Server.Token + if len(maskedToken) > 4 { + maskedToken = maskedToken[:4] + "****" + } h.jsonResponse(rw, map[string]interface{}{ "server": map[string]interface{}{ "bind_addr": cfg.Server.BindAddr, "bind_port": cfg.Server.BindPort, - "token": cfg.Server.Token, + "token": maskedToken, "heartbeat_sec": cfg.Server.HeartbeatSec, "heartbeat_timeout": cfg.Server.HeartbeatTimeout, }, @@ -231,7 +249,7 @@ func (h *APIHandler) getConfig(rw http.ResponseWriter) { "bind_addr": cfg.Web.BindAddr, "bind_port": cfg.Web.BindPort, "username": cfg.Web.Username, - "password": cfg.Web.Password, + "password": "****", }, }) } diff --git a/internal/server/router/router.go b/internal/server/router/router.go index 2ac30d6..c70cdaf 100644 --- a/internal/server/router/router.go +++ b/internal/server/router/router.go @@ -1,6 +1,7 @@ package router import ( + "crypto/subtle" "net/http" ) @@ -9,6 +10,12 @@ type Router struct { mux *http.ServeMux } +// AuthConfig 认证配置 +type AuthConfig struct { + Username string + Password string +} + // New 创建路由管理器 func New() *Router { return &Router{ @@ -49,3 +56,31 @@ func (g *RouteGroup) HandleFunc(pattern string, handler http.HandlerFunc) { func (r *Router) Handler() http.Handler { return r.mux } + +// BasicAuthMiddleware 基础认证中间件 +func BasicAuthMiddleware(auth *AuthConfig, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if auth == nil || (auth.Username == "" && auth.Password == "") { + next.ServeHTTP(w, r) + return + } + + user, pass, ok := r.BasicAuth() + if !ok { + w.Header().Set("WWW-Authenticate", `Basic realm="GoTunnel"`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + userMatch := subtle.ConstantTimeCompare([]byte(user), []byte(auth.Username)) == 1 + passMatch := subtle.ConstantTimeCompare([]byte(pass), []byte(auth.Password)) == 1 + + if !userMatch || !passMatch { + w.Header().Set("WWW-Authenticate", `Basic realm="GoTunnel"`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + next.ServeHTTP(w, r) + }) +} diff --git a/internal/server/tunnel/server.go b/internal/server/tunnel/server.go index 3dc23f1..3b39a29 100644 --- a/internal/server/tunnel/server.go +++ b/internal/server/tunnel/server.go @@ -9,6 +9,7 @@ import ( "time" "github.com/gotunnel/internal/server/db" + "github.com/gotunnel/pkg/plugin" "github.com/gotunnel/pkg/protocol" "github.com/gotunnel/pkg/proxy" "github.com/gotunnel/pkg/relay" @@ -16,28 +17,37 @@ import ( "github.com/hashicorp/yamux" ) +// 服务端常量 +const ( + authTimeout = 10 * time.Second + heartbeatTimeout = 10 * time.Second + udpBufferSize = 65535 +) + // Server 隧道服务端 type Server struct { - clientStore db.ClientStore - bindAddr string - bindPort int - token string - heartbeat int - hbTimeout int - portManager *utils.PortManager - clients map[string]*ClientSession - mu sync.RWMutex - tlsConfig *tls.Config + clientStore db.ClientStore + bindAddr string + bindPort int + token string + heartbeat int + hbTimeout int + portManager *utils.PortManager + clients map[string]*ClientSession + mu sync.RWMutex + tlsConfig *tls.Config + pluginRegistry *plugin.Registry } // ClientSession 客户端会话 type ClientSession struct { - ID string - Session *yamux.Session - Rules []protocol.ProxyRule - Listeners map[int]net.Listener - LastPing time.Time - mu sync.Mutex + ID string + Session *yamux.Session + Rules []protocol.ProxyRule + Listeners map[int]net.Listener + UDPConns map[int]*net.UDPConn // UDP 连接 + LastPing time.Time + mu sync.Mutex } // NewServer 创建服务端 @@ -59,6 +69,11 @@ func (s *Server) SetTLSConfig(config *tls.Config) { s.tlsConfig = config } +// SetPluginRegistry 设置插件注册表 +func (s *Server) SetPluginRegistry(registry *plugin.Registry) { + s.pluginRegistry = registry +} + // Run 启动服务端 func (s *Server) Run() error { addr := fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort) @@ -95,7 +110,7 @@ func (s *Server) Run() error { func (s *Server) handleConnection(conn net.Conn) { defer conn.Close() - conn.SetReadDeadline(time.Now().Add(10 * time.Second)) + conn.SetReadDeadline(time.Now().Add(authTimeout)) msg, err := protocol.ReadMessage(conn) if err != nil { @@ -148,6 +163,7 @@ func (s *Server) setupClientSession(conn net.Conn, clientID string, rules []prot Session: session, Rules: rules, Listeners: make(map[int]net.Listener), + UDPConns: make(map[int]*net.UDPConn), LastPing: time.Now(), } @@ -169,7 +185,10 @@ func (s *Server) setupClientSession(conn net.Conn, clientID string, rules []prot // sendAuthResponse 发送认证响应 func (s *Server) sendAuthResponse(conn net.Conn, success bool, message string) error { resp := protocol.AuthResponse{Success: success, Message: message} - msg, _ := protocol.NewMessage(protocol.MsgTypeAuthResp, resp) + msg, err := protocol.NewMessage(protocol.MsgTypeAuthResp, resp) + if err != nil { + return err + } return protocol.WriteMessage(conn, msg) } @@ -182,7 +201,10 @@ func (s *Server) sendProxyConfig(session *yamux.Session, rules []protocol.ProxyR defer stream.Close() cfg := protocol.ProxyConfig{Rules: rules} - msg, _ := protocol.NewMessage(protocol.MsgTypeProxyConfig, cfg) + msg, err := protocol.NewMessage(protocol.MsgTypeProxyConfig, cfg) + if err != nil { + return err + } return protocol.WriteMessage(stream, msg) } @@ -203,6 +225,10 @@ func (s *Server) unregisterClient(cs *ClientSession) { ln.Close() s.portManager.Release(port) } + for port, conn := range cs.UDPConns { + conn.Close() + s.portManager.Release(port) + } cs.mu.Unlock() delete(s.clients, cs.ID) @@ -211,6 +237,18 @@ func (s *Server) unregisterClient(cs *ClientSession) { // startProxyListeners 启动代理监听 func (s *Server) startProxyListeners(cs *ClientSession) { for _, rule := range cs.Rules { + ruleType := rule.Type + if ruleType == "" { + ruleType = "tcp" + } + + // UDP 单独处理 + if ruleType == "udp" { + s.startUDPListener(cs, rule) + continue + } + + // TCP 类型 if err := s.portManager.Reserve(rule.RemotePort, cs.ID); err != nil { log.Printf("[Server] Port %d error: %v", rule.RemotePort, err) continue @@ -227,15 +265,12 @@ func (s *Server) startProxyListeners(cs *ClientSession) { cs.Listeners[rule.RemotePort] = ln cs.mu.Unlock() - ruleType := rule.Type - if ruleType == "" { - ruleType = "tcp" - } - switch ruleType { - case "socks5", "http": - log.Printf("[Server] %s proxy %s on :%d", - ruleType, rule.Name, rule.RemotePort) + case "socks5": + log.Printf("[Server] SOCKS5 proxy %s on :%d", rule.Name, rule.RemotePort) + go s.acceptProxyServerConns(cs, ln, rule) + case "http", "https": + log.Printf("[Server] HTTP proxy %s on :%d", rule.Name, rule.RemotePort) go s.acceptProxyServerConns(cs, ln, rule) default: log.Printf("[Server] TCP proxy %s: :%d -> %s:%d", @@ -259,8 +294,23 @@ func (s *Server) acceptProxyConns(cs *ClientSession, ln net.Listener, rule proto // acceptProxyServerConns 接受 SOCKS5/HTTP 代理连接 func (s *Server) acceptProxyServerConns(cs *ClientSession, ln net.Listener, rule protocol.ProxyRule) { dialer := proxy.NewTunnelDialer(cs.Session) - proxyServer := proxy.NewServer(rule.Type, dialer) + // 优先使用插件系统 + if s.pluginRegistry != nil { + if handler, err := s.pluginRegistry.Get(rule.Type); err == nil { + handler.Init(rule.PluginConfig) + for { + conn, err := ln.Accept() + if err != nil { + return + } + go handler.HandleConn(conn, dialer) + } + } + } + + // 回退到内置 proxy 实现 + proxyServer := proxy.NewServer(rule.Type, dialer) for { conn, err := ln.Accept() if err != nil { @@ -309,13 +359,12 @@ func (s *Server) heartbeatLoop(cs *ClientSession) { } cs.mu.Unlock() - stream, err := cs.Session.Open() - if err != nil { - return + // 发送心跳并等待响应 + if s.sendHeartbeat(cs) { + cs.mu.Lock() + cs.LastPing = time.Now() + cs.mu.Unlock() } - msg := &protocol.Message{Type: protocol.MsgTypeHeartbeat} - protocol.WriteMessage(stream, msg) - stream.Close() case <-cs.Session.CloseChan(): return @@ -323,6 +372,31 @@ func (s *Server) heartbeatLoop(cs *ClientSession) { } } +// sendHeartbeat 发送心跳并等待响应 +func (s *Server) sendHeartbeat(cs *ClientSession) bool { + stream, err := cs.Session.Open() + if err != nil { + return false + } + defer stream.Close() + + // 设置读写超时 + stream.SetDeadline(time.Now().Add(heartbeatTimeout)) + + msg := &protocol.Message{Type: protocol.MsgTypeHeartbeat} + if err := protocol.WriteMessage(stream, msg); err != nil { + return false + } + + // 等待心跳响应 + resp, err := protocol.ReadMessage(stream) + if err != nil { + return false + } + + return resp.Type == protocol.MsgTypeHeartbeatAck +} + // GetClientStatus 获取客户端状态 func (s *Server) GetClientStatus(clientID string) (online bool, lastPing string) { s.mu.RLock() @@ -341,17 +415,22 @@ func (s *Server) GetAllClientStatus() map[string]struct { Online bool LastPing string } { + // 先复制客户端引用,避免嵌套锁 s.mu.RLock() - defer s.mu.RUnlock() + clients := make([]*ClientSession, 0, len(s.clients)) + for _, cs := range s.clients { + clients = append(clients, cs) + } + s.mu.RUnlock() result := make(map[string]struct { Online bool LastPing string }) - for id, cs := range s.clients { + for _, cs := range clients { cs.mu.Lock() - result[id] = struct { + result[cs.ID] = struct { Online bool LastPing string }{ @@ -364,8 +443,9 @@ func (s *Server) GetAllClientStatus() map[string]struct { } // ReloadConfig 重新加载配置 +// 注意: 当前版本不支持热重载,需要重启服务 func (s *Server) ReloadConfig() error { - return nil + return fmt.Errorf("hot reload not supported, please restart the server") } // GetBindAddr 获取绑定地址 @@ -377,3 +457,87 @@ func (s *Server) GetBindAddr() string { func (s *Server) GetBindPort() int { return s.bindPort } + +// startUDPListener 启动 UDP 监听 +func (s *Server) startUDPListener(cs *ClientSession, rule protocol.ProxyRule) { + if err := s.portManager.Reserve(rule.RemotePort, cs.ID); err != nil { + log.Printf("[Server] UDP port %d error: %v", rule.RemotePort, err) + return + } + + addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf(":%d", rule.RemotePort)) + if err != nil { + log.Printf("[Server] UDP resolve error: %v", err) + s.portManager.Release(rule.RemotePort) + return + } + + conn, err := net.ListenUDP("udp", addr) + if err != nil { + log.Printf("[Server] UDP listen %d error: %v", rule.RemotePort, err) + s.portManager.Release(rule.RemotePort) + return + } + + cs.mu.Lock() + cs.UDPConns[rule.RemotePort] = conn + cs.mu.Unlock() + + log.Printf("[Server] UDP proxy %s: :%d -> %s:%d", + rule.Name, rule.RemotePort, rule.LocalIP, rule.LocalPort) + + go s.handleUDPConn(cs, conn, rule) +} + +// handleUDPConn 处理 UDP 连接 +func (s *Server) handleUDPConn(cs *ClientSession, conn *net.UDPConn, rule protocol.ProxyRule) { + buf := make([]byte, udpBufferSize) + + for { + n, clientAddr, err := conn.ReadFromUDP(buf) + if err != nil { + return + } + + // 封装 UDP 数据包发送到客户端 + packet := protocol.UDPPacket{ + RemotePort: rule.RemotePort, + ClientAddr: clientAddr.String(), + Data: buf[:n], + } + + go s.sendUDPPacket(cs, conn, clientAddr, packet) + } +} + +// sendUDPPacket 发送 UDP 数据包到客户端 +func (s *Server) sendUDPPacket(cs *ClientSession, conn *net.UDPConn, clientAddr *net.UDPAddr, packet protocol.UDPPacket) { + stream, err := cs.Session.Open() + if err != nil { + return + } + defer stream.Close() + + msg, err := protocol.NewMessage(protocol.MsgTypeUDPData, packet) + if err != nil { + return + } + + if err := protocol.WriteMessage(stream, msg); err != nil { + return + } + + // 等待客户端响应 + respMsg, err := protocol.ReadMessage(stream) + if err != nil { + return + } + + if respMsg.Type == protocol.MsgTypeUDPData { + var respPacket protocol.UDPPacket + if err := respMsg.ParsePayload(&respPacket); err != nil { + return + } + conn.WriteToUDP(respPacket.Data, clientAddr) + } +} diff --git a/pkg/plugin/builtin.go b/pkg/plugin/builtin.go new file mode 100644 index 0000000..eb45d1a --- /dev/null +++ b/pkg/plugin/builtin.go @@ -0,0 +1,16 @@ +package plugin + +// RegisterBuiltins 注册所有内置 plugins +// 注意:此函数需要在调用方导入 builtin 包并手动注册 +// 示例: +// registry := plugin.NewRegistry() +// registry.RegisterBuiltin(builtin.NewSOCKS5Plugin()) +// registry.RegisterBuiltin(builtin.NewHTTPPlugin()) +func RegisterBuiltins(registry *Registry, handlers ...ProxyHandler) error { + for _, handler := range handlers { + if err := registry.RegisterBuiltin(handler); err != nil { + return err + } + } + return nil +} diff --git a/pkg/plugin/builtin/http.go b/pkg/plugin/builtin/http.go new file mode 100644 index 0000000..41e47f6 --- /dev/null +++ b/pkg/plugin/builtin/http.go @@ -0,0 +1,116 @@ +package builtin + +import ( + "bufio" + "io" + "net" + "net/http" + "strings" + + "github.com/gotunnel/pkg/plugin" +) + +// HTTPPlugin 将现有 HTTP 代理实现封装为 plugin +type HTTPPlugin struct { + config map[string]string +} + +// NewHTTPPlugin 创建 HTTP plugin +func NewHTTPPlugin() *HTTPPlugin { + return &HTTPPlugin{} +} + +// Metadata 返回 plugin 信息 +func (p *HTTPPlugin) Metadata() plugin.PluginMetadata { + return plugin.PluginMetadata{ + Name: "http", + Version: "1.0.0", + Type: plugin.PluginTypeProxy, + Source: plugin.PluginSourceBuiltin, + Description: "HTTP/HTTPS proxy protocol handler", + Author: "GoTunnel", + Capabilities: []string{ + "dial", "read", "write", "close", + }, + } +} + +// Init 初始化 plugin +func (p *HTTPPlugin) Init(config map[string]string) error { + p.config = config + return nil +} + +// HandleConn 处理 HTTP 代理连接 +func (p *HTTPPlugin) HandleConn(conn net.Conn, dialer plugin.Dialer) error { + defer conn.Close() + + reader := bufio.NewReader(conn) + req, err := http.ReadRequest(reader) + if err != nil { + return err + } + + if req.Method == http.MethodConnect { + return p.handleConnect(conn, req, dialer) + } + return p.handleHTTP(conn, req, dialer) +} + +// Close 释放资源 +func (p *HTTPPlugin) Close() error { + return nil +} + +// handleConnect 处理 CONNECT 方法 (HTTPS) +func (p *HTTPPlugin) handleConnect(conn net.Conn, req *http.Request, dialer plugin.Dialer) error { + target := req.Host + if !strings.Contains(target, ":") { + target = target + ":443" + } + + remote, err := dialer.Dial("tcp", target) + if err != nil { + conn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")) + return err + } + defer remote.Close() + + conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) + + go io.Copy(remote, conn) + io.Copy(conn, remote) + return nil +} + +// handleHTTP 处理普通 HTTP 请求 +func (p *HTTPPlugin) handleHTTP(conn net.Conn, req *http.Request, dialer plugin.Dialer) error { + target := req.Host + if !strings.Contains(target, ":") { + target = target + ":80" + } + + remote, err := dialer.Dial("tcp", target) + if err != nil { + conn.Write([]byte("HTTP/1.1 502 Bad Gateway\r\n\r\n")) + return err + } + defer remote.Close() + + // 修改请求路径为相对路径 + req.URL.Scheme = "" + req.URL.Host = "" + req.RequestURI = req.URL.Path + if req.URL.RawQuery != "" { + req.RequestURI += "?" + req.URL.RawQuery + } + + // 发送请求到目标 + if err := req.Write(remote); err != nil { + return err + } + + // 转发响应 + _, err = io.Copy(conn, remote) + return err +} diff --git a/pkg/plugin/builtin/socks5.go b/pkg/plugin/builtin/socks5.go new file mode 100644 index 0000000..1109304 --- /dev/null +++ b/pkg/plugin/builtin/socks5.go @@ -0,0 +1,167 @@ +package builtin + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" + + "github.com/gotunnel/pkg/plugin" +) + +const ( + socks5Version = 0x05 + noAuth = 0x00 + cmdConnect = 0x01 + atypIPv4 = 0x01 + atypDomain = 0x03 + atypIPv6 = 0x04 +) + +// SOCKS5Plugin 将现有 SOCKS5 实现封装为 plugin +type SOCKS5Plugin struct { + config map[string]string +} + +// NewSOCKS5Plugin 创建 SOCKS5 plugin +func NewSOCKS5Plugin() *SOCKS5Plugin { + return &SOCKS5Plugin{} +} + +// Metadata 返回 plugin 信息 +func (p *SOCKS5Plugin) Metadata() plugin.PluginMetadata { + return plugin.PluginMetadata{ + Name: "socks5", + Version: "1.0.0", + Type: plugin.PluginTypeProxy, + Source: plugin.PluginSourceBuiltin, + Description: "SOCKS5 proxy protocol handler (official plugin)", + Author: "GoTunnel", + Capabilities: []string{ + "dial", "read", "write", "close", + }, + } +} + +// Init 初始化 plugin +func (p *SOCKS5Plugin) Init(config map[string]string) error { + p.config = config + return nil +} + +// HandleConn 处理 SOCKS5 连接 +func (p *SOCKS5Plugin) HandleConn(conn net.Conn, dialer plugin.Dialer) error { + defer conn.Close() + + // 握手阶段 + if err := p.handshake(conn); err != nil { + return err + } + + // 获取请求 + target, err := p.readRequest(conn) + if err != nil { + return err + } + + // 连接目标 + remote, err := dialer.Dial("tcp", target) + if err != nil { + p.sendReply(conn, 0x05) // Connection refused + return err + } + defer remote.Close() + + // 发送成功响应 + if err := p.sendReply(conn, 0x00); err != nil { + return err + } + + // 双向转发 + go io.Copy(remote, conn) + io.Copy(conn, remote) + + return nil +} + +// Close 释放资源 +func (p *SOCKS5Plugin) Close() error { + return nil +} + +// handshake 处理握手 +func (p *SOCKS5Plugin) handshake(conn net.Conn) error { + buf := make([]byte, 2) + if _, err := io.ReadFull(conn, buf); err != nil { + return err + } + if buf[0] != socks5Version { + return errors.New("unsupported SOCKS version") + } + + nmethods := int(buf[1]) + methods := make([]byte, nmethods) + if _, err := io.ReadFull(conn, methods); err != nil { + return err + } + + // 响应:使用无认证 + _, err := conn.Write([]byte{socks5Version, noAuth}) + return err +} + +// readRequest 读取请求 +func (p *SOCKS5Plugin) readRequest(conn net.Conn) (string, error) { + buf := make([]byte, 4) + if _, err := io.ReadFull(conn, buf); err != nil { + return "", err + } + + if buf[0] != socks5Version || buf[1] != cmdConnect { + return "", errors.New("unsupported command") + } + + var host string + switch buf[3] { + case atypIPv4: + ip := make([]byte, 4) + if _, err := io.ReadFull(conn, ip); err != nil { + return "", err + } + host = net.IP(ip).String() + case atypDomain: + lenBuf := make([]byte, 1) + if _, err := io.ReadFull(conn, lenBuf); err != nil { + return "", err + } + domain := make([]byte, lenBuf[0]) + if _, err := io.ReadFull(conn, domain); err != nil { + return "", err + } + host = string(domain) + case atypIPv6: + ip := make([]byte, 16) + if _, err := io.ReadFull(conn, ip); err != nil { + return "", err + } + host = net.IP(ip).String() + default: + return "", errors.New("unsupported address type") + } + + portBuf := make([]byte, 2) + if _, err := io.ReadFull(conn, portBuf); err != nil { + return "", err + } + port := binary.BigEndian.Uint16(portBuf) + + return fmt.Sprintf("%s:%d", host, port), nil +} + +// sendReply 发送响应 +func (p *SOCKS5Plugin) sendReply(conn net.Conn, rep byte) error { + reply := []byte{socks5Version, rep, 0x00, atypIPv4, 0, 0, 0, 0, 0, 0} + _, err := conn.Write(reply) + return err +} diff --git a/pkg/plugin/registry.go b/pkg/plugin/registry.go new file mode 100644 index 0000000..c592d1c --- /dev/null +++ b/pkg/plugin/registry.go @@ -0,0 +1,93 @@ +package plugin + +import ( + "context" + "fmt" + "sync" +) + +// Registry 管理可用的 plugins +type Registry struct { + builtin map[string]ProxyHandler // 内置 Go 实现 + mu sync.RWMutex +} + +// NewRegistry 创建 plugin 注册表 +func NewRegistry() *Registry { + return &Registry{ + builtin: make(map[string]ProxyHandler), + } +} + +// RegisterBuiltin 注册内置 plugin +func (r *Registry) RegisterBuiltin(handler ProxyHandler) error { + r.mu.Lock() + defer r.mu.Unlock() + + meta := handler.Metadata() + if meta.Name == "" { + return fmt.Errorf("plugin name cannot be empty") + } + + if _, exists := r.builtin[meta.Name]; exists { + return fmt.Errorf("plugin %s already registered", meta.Name) + } + + r.builtin[meta.Name] = handler + return nil +} + +// Get 返回指定代理类型的 handler +func (r *Registry) Get(proxyType string) (ProxyHandler, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + // 先查找内置 plugin + if handler, ok := r.builtin[proxyType]; ok { + return handler, nil + } + + return nil, fmt.Errorf("plugin %s not found", proxyType) +} + +// List 返回所有可用的 plugins +func (r *Registry) List() []PluginInfo { + r.mu.RLock() + defer r.mu.RUnlock() + + var plugins []PluginInfo + + // 内置 plugins + for _, handler := range r.builtin { + plugins = append(plugins, PluginInfo{ + Metadata: handler.Metadata(), + Loaded: true, + }) + } + + return plugins +} + +// Has 检查 plugin 是否存在 +func (r *Registry) Has(name string) bool { + r.mu.RLock() + defer r.mu.RUnlock() + + _, ok := r.builtin[name] + return ok +} + +// Close 关闭所有 plugins +func (r *Registry) Close(ctx context.Context) error { + r.mu.Lock() + defer r.mu.Unlock() + + var lastErr error + for name, handler := range r.builtin { + if err := handler.Close(); err != nil { + lastErr = fmt.Errorf("failed to close plugin %s: %w", name, err) + } + } + + return lastErr +} diff --git a/pkg/plugin/store/interface.go b/pkg/plugin/store/interface.go new file mode 100644 index 0000000..c3d9578 --- /dev/null +++ b/pkg/plugin/store/interface.go @@ -0,0 +1,29 @@ +package store + +import ( + "github.com/gotunnel/pkg/plugin" +) + +// PluginStore 管理 plugin 持久化 +type PluginStore interface { + // GetAllPlugins 返回所有存储的 plugins + GetAllPlugins() ([]plugin.PluginMetadata, error) + + // GetPlugin 返回指定 plugin 的元数据 + GetPlugin(name string) (*plugin.PluginMetadata, error) + + // GetPluginData 返回 WASM 二进制 + GetPluginData(name string) ([]byte, error) + + // SavePlugin 存储 plugin + SavePlugin(metadata plugin.PluginMetadata, wasmData []byte) error + + // DeletePlugin 删除 plugin + DeletePlugin(name string) error + + // PluginExists 检查 plugin 是否存在 + PluginExists(name string) (bool, error) + + // Close 关闭存储 + Close() error +} diff --git a/pkg/plugin/store/sqlite.go b/pkg/plugin/store/sqlite.go new file mode 100644 index 0000000..137f938 --- /dev/null +++ b/pkg/plugin/store/sqlite.go @@ -0,0 +1,168 @@ +package store + +import ( + "database/sql" + "encoding/json" + "fmt" + "sync" + "time" + + "github.com/gotunnel/pkg/plugin" + _ "modernc.org/sqlite" +) + +// SQLiteStore SQLite 实现的 PluginStore +type SQLiteStore struct { + db *sql.DB + mu sync.RWMutex +} + +// NewSQLiteStore 创建 SQLite plugin 存储 +func NewSQLiteStore(dbPath string) (*SQLiteStore, error) { + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return nil, err + } + + store := &SQLiteStore{db: db} + if err := store.init(); err != nil { + db.Close() + return nil, err + } + + return store, nil +} + +// init 初始化数据库表 +func (s *SQLiteStore) init() error { + query := ` + CREATE TABLE IF NOT EXISTS plugins ( + name TEXT PRIMARY KEY, + version TEXT NOT NULL, + type TEXT NOT NULL DEFAULT 'proxy', + source TEXT NOT NULL DEFAULT 'wasm', + description TEXT, + author TEXT, + checksum TEXT NOT NULL, + size INTEGER NOT NULL, + capabilities TEXT NOT NULL DEFAULT '[]', + config_schema TEXT NOT NULL DEFAULT '{}', + wasm_data BLOB NOT NULL, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME DEFAULT CURRENT_TIMESTAMP + )` + _, err := s.db.Exec(query) + return err +} + +// GetAllPlugins 返回所有存储的 plugins +func (s *SQLiteStore) GetAllPlugins() ([]plugin.PluginMetadata, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + rows, err := s.db.Query(` + SELECT name, version, type, source, description, author, + checksum, size, capabilities, config_schema + FROM plugins`) + if err != nil { + return nil, err + } + defer rows.Close() + + var plugins []plugin.PluginMetadata + for rows.Next() { + var m plugin.PluginMetadata + var capJSON, configJSON string + err := rows.Scan(&m.Name, &m.Version, &m.Type, &m.Source, + &m.Description, &m.Author, &m.Checksum, &m.Size, + &capJSON, &configJSON) + if err != nil { + return nil, err + } + json.Unmarshal([]byte(capJSON), &m.Capabilities) + json.Unmarshal([]byte(configJSON), &m.ConfigSchema) + plugins = append(plugins, m) + } + return plugins, rows.Err() +} + +// GetPlugin 返回指定 plugin 的元数据 +func (s *SQLiteStore) GetPlugin(name string) (*plugin.PluginMetadata, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + var m plugin.PluginMetadata + var capJSON, configJSON string + err := s.db.QueryRow(` + SELECT name, version, type, source, description, author, + checksum, size, capabilities, config_schema + FROM plugins WHERE name = ?`, name).Scan( + &m.Name, &m.Version, &m.Type, &m.Source, + &m.Description, &m.Author, &m.Checksum, &m.Size, + &capJSON, &configJSON) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + json.Unmarshal([]byte(capJSON), &m.Capabilities) + json.Unmarshal([]byte(configJSON), &m.ConfigSchema) + return &m, nil +} + +// GetPluginData 返回 WASM 二进制 +func (s *SQLiteStore) GetPluginData(name string) ([]byte, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + var data []byte + err := s.db.QueryRow(`SELECT wasm_data FROM plugins WHERE name = ?`, name).Scan(&data) + if err == sql.ErrNoRows { + return nil, fmt.Errorf("plugin %s not found", name) + } + return data, err +} + +// SavePlugin 存储 plugin +func (s *SQLiteStore) SavePlugin(metadata plugin.PluginMetadata, wasmData []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + + capJSON, _ := json.Marshal(metadata.Capabilities) + configJSON, _ := json.Marshal(metadata.ConfigSchema) + + _, err := s.db.Exec(` + INSERT OR REPLACE INTO plugins + (name, version, type, source, description, author, checksum, size, + capabilities, config_schema, wasm_data, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + metadata.Name, metadata.Version, metadata.Type, metadata.Source, + metadata.Description, metadata.Author, metadata.Checksum, metadata.Size, + string(capJSON), string(configJSON), wasmData, time.Now()) + return err +} + +// DeletePlugin 删除 plugin +func (s *SQLiteStore) DeletePlugin(name string) error { + s.mu.Lock() + defer s.mu.Unlock() + + _, err := s.db.Exec(`DELETE FROM plugins WHERE name = ?`, name) + return err +} + +// PluginExists 检查 plugin 是否存在 +func (s *SQLiteStore) PluginExists(name string) (bool, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + var count int + err := s.db.QueryRow(`SELECT COUNT(*) FROM plugins WHERE name = ?`, name).Scan(&count) + return count > 0, err +} + +// Close 关闭存储 +func (s *SQLiteStore) Close() error { + return s.db.Close() +} diff --git a/pkg/plugin/types.go b/pkg/plugin/types.go new file mode 100644 index 0000000..5b91f88 --- /dev/null +++ b/pkg/plugin/types.go @@ -0,0 +1,99 @@ +package plugin + +import ( + "net" + "time" +) + +// PluginType 定义 plugin 类别 +type PluginType string + +const ( + PluginTypeProxy PluginType = "proxy" // 代理处理器 (SOCKS5, HTTP 等) +) + +// PluginSource 表示 plugin 来源 +type PluginSource string + +const ( + PluginSourceBuiltin PluginSource = "builtin" // 内置编译 + PluginSourceWASM PluginSource = "wasm" // WASM 模块 +) + +// PluginMetadata 描述一个 plugin +type PluginMetadata struct { + Name string `json:"name"` // 唯一标识符 (如 "socks5") + Version string `json:"version"` // 语义化版本 + Type PluginType `json:"type"` // Plugin 类别 + Source PluginSource `json:"source"` // builtin 或 wasm + Description string `json:"description"` // 人类可读描述 + Author string `json:"author"` // Plugin 作者 + Checksum string `json:"checksum,omitempty"` // WASM 二进制的 SHA256 + Size int64 `json:"size,omitempty"` // WASM 二进制大小 + Capabilities []string `json:"capabilities,omitempty"` // 所需 host functions + ConfigSchema map[string]string `json:"config_schema,omitempty"` +} + +// PluginInfo 组合元数据和运行时状态 +type PluginInfo struct { + Metadata PluginMetadata `json:"metadata"` + Loaded bool `json:"loaded"` + LoadedAt time.Time `json:"loaded_at,omitempty"` + Error string `json:"error,omitempty"` +} + +// Dialer 用于建立连接的接口 +type Dialer interface { + Dial(network, address string) (net.Conn, error) +} + +// ProxyHandler 是所有 proxy plugin 必须实现的接口 +type ProxyHandler interface { + // Metadata 返回 plugin 信息 + Metadata() PluginMetadata + + // Init 使用配置初始化 plugin + Init(config map[string]string) error + + // HandleConn 处理传入连接 + // dialer 用于通过隧道建立连接 + HandleConn(conn net.Conn, dialer Dialer) error + + // Close 释放 plugin 资源 + Close() error +} + +// LogLevel 日志级别 +type LogLevel uint8 + +const ( + LogDebug LogLevel = iota + LogInfo + LogWarn + LogError +) + +// ConnHandle WASM 连接句柄 +type ConnHandle uint32 + +// HostContext 提供给 WASM plugin 的 host functions +type HostContext interface { + // 网络操作 + Dial(network, address string) (ConnHandle, error) + Read(handle ConnHandle, buf []byte) (int, error) + Write(handle ConnHandle, buf []byte) (int, error) + CloseConn(handle ConnHandle) error + + // 客户端连接操作 + ClientRead(buf []byte) (int, error) + ClientWrite(buf []byte) (int, error) + + // 日志 + Log(level LogLevel, message string) + + // 时间 + Now() int64 + + // 配置 + GetConfig(key string) string +} diff --git a/pkg/plugin/wasm/host.go b/pkg/plugin/wasm/host.go new file mode 100644 index 0000000..80097aa --- /dev/null +++ b/pkg/plugin/wasm/host.go @@ -0,0 +1,146 @@ +package wasm + +import ( + "errors" + "log" + "net" + "sync" + "time" + + "github.com/gotunnel/pkg/plugin" +) + +// ErrInvalidHandle 无效的连接句柄 +var ErrInvalidHandle = errors.New("invalid connection handle") + +// HostContextImpl 实现 HostContext 接口 +type HostContextImpl struct { + dialer plugin.Dialer + clientConn net.Conn + config map[string]string + + // 连接管理 + conns map[plugin.ConnHandle]net.Conn + nextHandle plugin.ConnHandle + mu sync.Mutex +} + +// NewHostContext 创建 host context +func NewHostContext(dialer plugin.Dialer, clientConn net.Conn, config map[string]string) *HostContextImpl { + return &HostContextImpl{ + dialer: dialer, + clientConn: clientConn, + config: config, + conns: make(map[plugin.ConnHandle]net.Conn), + nextHandle: 1, + } +} + +// Dial 通过隧道建立连接 +func (h *HostContextImpl) Dial(network, address string) (plugin.ConnHandle, error) { + conn, err := h.dialer.Dial(network, address) + if err != nil { + return 0, err + } + + h.mu.Lock() + handle := h.nextHandle + h.nextHandle++ + h.conns[handle] = conn + h.mu.Unlock() + + return handle, nil +} + +// Read 从连接读取数据 +func (h *HostContextImpl) Read(handle plugin.ConnHandle, buf []byte) (int, error) { + h.mu.Lock() + conn, ok := h.conns[handle] + h.mu.Unlock() + + if !ok { + return 0, ErrInvalidHandle + } + + return conn.Read(buf) +} + +// Write 向连接写入数据 +func (h *HostContextImpl) Write(handle plugin.ConnHandle, buf []byte) (int, error) { + h.mu.Lock() + conn, ok := h.conns[handle] + h.mu.Unlock() + + if !ok { + return 0, ErrInvalidHandle + } + + return conn.Write(buf) +} + +// CloseConn 关闭连接 +func (h *HostContextImpl) CloseConn(handle plugin.ConnHandle) error { + h.mu.Lock() + conn, ok := h.conns[handle] + if ok { + delete(h.conns, handle) + } + h.mu.Unlock() + + if !ok { + return ErrInvalidHandle + } + + return conn.Close() +} + +// ClientRead 从客户端连接读取数据 +func (h *HostContextImpl) ClientRead(buf []byte) (int, error) { + return h.clientConn.Read(buf) +} + +// ClientWrite 向客户端连接写入数据 +func (h *HostContextImpl) ClientWrite(buf []byte) (int, error) { + return h.clientConn.Write(buf) +} + +// Log 记录日志 +func (h *HostContextImpl) Log(level plugin.LogLevel, message string) { + prefix := "[WASM]" + switch level { + case plugin.LogDebug: + prefix = "[WASM DEBUG]" + case plugin.LogInfo: + prefix = "[WASM INFO]" + case plugin.LogWarn: + prefix = "[WASM WARN]" + case plugin.LogError: + prefix = "[WASM ERROR]" + } + log.Printf("%s %s", prefix, message) +} + +// Now 返回当前 Unix 时间戳 +func (h *HostContextImpl) Now() int64 { + return time.Now().Unix() +} + +// GetConfig 获取配置值 +func (h *HostContextImpl) GetConfig(key string) string { + if h.config == nil { + return "" + } + return h.config[key] +} + +// Close 关闭所有连接 +func (h *HostContextImpl) Close() error { + h.mu.Lock() + defer h.mu.Unlock() + + for handle, conn := range h.conns { + conn.Close() + delete(h.conns, handle) + } + return nil +} diff --git a/pkg/plugin/wasm/memory.go b/pkg/plugin/wasm/memory.go new file mode 100644 index 0000000..072bd4c --- /dev/null +++ b/pkg/plugin/wasm/memory.go @@ -0,0 +1,29 @@ +package wasm + +import ( + "github.com/tetratelabs/wazero/api" +) + +// ReadString 从 WASM 内存读取字符串 +func ReadString(mem api.Memory, ptr, len uint32) (string, bool) { + data, ok := mem.Read(ptr, len) + if !ok { + return "", false + } + return string(data), true +} + +// WriteString 向 WASM 内存写入字符串 +func WriteString(mem api.Memory, ptr uint32, s string) bool { + return mem.Write(ptr, []byte(s)) +} + +// ReadBytes 从 WASM 内存读取字节 +func ReadBytes(mem api.Memory, ptr, len uint32) ([]byte, bool) { + return mem.Read(ptr, len) +} + +// WriteBytes 向 WASM 内存写入字节 +func WriteBytes(mem api.Memory, ptr uint32, data []byte) bool { + return mem.Write(ptr, data) +} diff --git a/pkg/plugin/wasm/module.go b/pkg/plugin/wasm/module.go new file mode 100644 index 0000000..3673ea7 --- /dev/null +++ b/pkg/plugin/wasm/module.go @@ -0,0 +1,148 @@ +package wasm + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/gotunnel/pkg/plugin" + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" +) + +// WASMPlugin 封装 WASM 模块作为 ProxyHandler +type WASMPlugin struct { + name string + metadata plugin.PluginMetadata + runtime *Runtime + compiled wazero.CompiledModule + config map[string]string +} + +// NewWASMPlugin 从 WASM 字节创建 plugin +func NewWASMPlugin(ctx context.Context, rt *Runtime, name string, wasmBytes []byte) (*WASMPlugin, error) { + compiled, err := rt.runtime.CompileModule(ctx, wasmBytes) + if err != nil { + return nil, fmt.Errorf("compile module: %w", err) + } + + p := &WASMPlugin{ + name: name, + runtime: rt, + compiled: compiled, + } + + // 尝试获取元数据 + if err := p.loadMetadata(ctx); err != nil { + // 使用默认元数据 + p.metadata = plugin.PluginMetadata{ + Name: name, + Type: plugin.PluginTypeProxy, + Source: plugin.PluginSourceWASM, + } + } + + return p, nil +} + +// loadMetadata 从 WASM 模块加载元数据 +func (p *WASMPlugin) loadMetadata(ctx context.Context) error { + // 创建临时实例获取元数据 + inst, err := p.runtime.runtime.InstantiateModule(ctx, p.compiled, wazero.NewModuleConfig()) + if err != nil { + return err + } + defer inst.Close(ctx) + + metadataFn := inst.ExportedFunction("metadata") + if metadataFn == nil { + return fmt.Errorf("metadata function not exported") + } + + allocFn := inst.ExportedFunction("alloc") + if allocFn == nil { + return fmt.Errorf("alloc function not exported") + } + + // 分配缓冲区 + results, err := allocFn.Call(ctx, 1024) + if err != nil { + return err + } + bufPtr := uint32(results[0]) + + // 调用 metadata 函数 + results, err = metadataFn.Call(ctx, uint64(bufPtr), 1024) + if err != nil { + return err + } + actualLen := uint32(results[0]) + + // 读取元数据 + mem := inst.Memory() + data, ok := mem.Read(bufPtr, actualLen) + if !ok { + return fmt.Errorf("failed to read metadata") + } + + return json.Unmarshal(data, &p.metadata) +} + +// Metadata 返回 plugin 信息 +func (p *WASMPlugin) Metadata() plugin.PluginMetadata { + return p.metadata +} + +// Init 初始化 plugin +func (p *WASMPlugin) Init(config map[string]string) error { + p.config = config + return nil +} + +// HandleConn 处理连接 +func (p *WASMPlugin) HandleConn(conn interface{}, dialer plugin.Dialer) error { + // WASM plugin 的连接处理需要更复杂的实现 + // 这里提供基础框架,实际实现需要注册 host functions + return fmt.Errorf("WASM plugin HandleConn not fully implemented") +} + +// Close 关闭 plugin +func (p *WASMPlugin) Close() error { + return p.compiled.Close(context.Background()) +} + +// RegisterHostFunctions 注册 host functions 到 wazero 运行时 +func RegisterHostFunctions(ctx context.Context, r wazero.Runtime) (wazero.CompiledModule, error) { + return r.NewHostModuleBuilder("env"). + NewFunctionBuilder(). + WithFunc(hostLog). + Export("log"). + NewFunctionBuilder(). + WithFunc(hostNow). + Export("now"). + Compile(ctx) +} + +// host function 实现 +func hostLog(ctx context.Context, m api.Module, level uint32, msgPtr, msgLen uint32) { + data, ok := m.Memory().Read(msgPtr, msgLen) + if !ok { + return + } + prefix := "[WASM]" + switch plugin.LogLevel(level) { + case plugin.LogDebug: + prefix = "[WASM DEBUG]" + case plugin.LogInfo: + prefix = "[WASM INFO]" + case plugin.LogWarn: + prefix = "[WASM WARN]" + case plugin.LogError: + prefix = "[WASM ERROR]" + } + fmt.Printf("%s %s\n", prefix, string(data)) +} + +func hostNow(ctx context.Context) int64 { + return ctx.Value("now").(func() int64)() +} diff --git a/pkg/plugin/wasm/runtime.go b/pkg/plugin/wasm/runtime.go new file mode 100644 index 0000000..4a3bf0f --- /dev/null +++ b/pkg/plugin/wasm/runtime.go @@ -0,0 +1,116 @@ +package wasm + +import ( + "context" + "fmt" + "sync" + + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/api" +) + +// Runtime 管理 wazero WASM 运行时 +type Runtime struct { + runtime wazero.Runtime + modules map[string]*Module + mu sync.RWMutex +} + +// NewRuntime 创建新的 WASM 运行时 +func NewRuntime(ctx context.Context) (*Runtime, error) { + r := wazero.NewRuntime(ctx) + return &Runtime{ + runtime: r, + modules: make(map[string]*Module), + }, nil +} + +// GetWazeroRuntime 返回底层 wazero 运行时 +func (r *Runtime) GetWazeroRuntime() wazero.Runtime { + return r.runtime +} + +// LoadModule 从字节加载 WASM 模块 +func (r *Runtime) LoadModule(ctx context.Context, name string, wasmBytes []byte) (*Module, error) { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.modules[name]; exists { + return nil, fmt.Errorf("module %s already loaded", name) + } + + compiled, err := r.runtime.CompileModule(ctx, wasmBytes) + if err != nil { + return nil, fmt.Errorf("failed to compile module: %w", err) + } + + module := &Module{ + name: name, + compiled: compiled, + } + + r.modules[name] = module + return module, nil +} + +// GetModule 获取已加载的模块 +func (r *Runtime) GetModule(name string) (*Module, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + m, ok := r.modules[name] + return m, ok +} + +// UnloadModule 卸载 WASM 模块 +func (r *Runtime) UnloadModule(ctx context.Context, name string) error { + r.mu.Lock() + defer r.mu.Unlock() + + module, exists := r.modules[name] + if !exists { + return fmt.Errorf("module %s not found", name) + } + + if err := module.Close(ctx); err != nil { + return err + } + + delete(r.modules, name) + return nil +} + +// Close 关闭运行时 +func (r *Runtime) Close(ctx context.Context) error { + r.mu.Lock() + defer r.mu.Unlock() + + for name, module := range r.modules { + if err := module.Close(ctx); err != nil { + return fmt.Errorf("failed to close module %s: %w", name, err) + } + } + + return r.runtime.Close(ctx) +} + +// Module WASM 模块封装 +type Module struct { + name string + compiled wazero.CompiledModule + instance api.Module +} + +// Name 返回模块名称 +func (m *Module) Name() string { + return m.name +} + +// Close 关闭模块 +func (m *Module) Close(ctx context.Context) error { + if m.instance != nil { + if err := m.instance.Close(ctx); err != nil { + return err + } + } + return m.compiled.Close(ctx) +} diff --git a/pkg/protocol/message.go b/pkg/protocol/message.go index 8054210..044e1ae 100644 --- a/pkg/protocol/message.go +++ b/pkg/protocol/message.go @@ -7,6 +7,12 @@ import ( "io" ) +// 协议常量 +const ( + MaxMessageSize = 1024 * 1024 // 最大消息大小 1MB + HeaderSize = 5 // 消息头大小 +) + // 消息类型定义 const ( MsgTypeAuth uint8 = 1 // 认证请求 @@ -19,6 +25,15 @@ const ( MsgTypeError uint8 = 8 // 错误消息 MsgTypeProxyConnect uint8 = 9 // 代理连接请求 (SOCKS5/HTTP) MsgTypeProxyResult uint8 = 10 // 代理连接结果 + + // Plugin 相关消息 + MsgTypePluginList uint8 = 20 // 请求/响应可用 plugins + MsgTypePluginDownload uint8 = 21 // 请求下载 plugin + MsgTypePluginData uint8 = 22 // Plugin 二进制数据(分块) + MsgTypePluginReady uint8 = 23 // Plugin 加载确认 + + // UDP 相关消息 + MsgTypeUDPData uint8 = 30 // UDP 数据包 ) // Message 基础消息结构 @@ -42,10 +57,14 @@ type AuthResponse struct { // ProxyRule 代理规则 type ProxyRule struct { Name string `json:"name" yaml:"name"` - Type string `json:"type" yaml:"type"` // tcp, socks5, http - LocalIP string `json:"local_ip" yaml:"local_ip"` // tcp 模式使用 - LocalPort int `json:"local_port" yaml:"local_port"` // tcp 模式使用 + Type string `json:"type" yaml:"type"` // 内置: tcp, udp, http, https; 插件: socks5 等 + LocalIP string `json:"local_ip" yaml:"local_ip"` // tcp/udp 模式使用 + LocalPort int `json:"local_port" yaml:"local_port"` // tcp/udp 模式使用 RemotePort int `json:"remote_port" yaml:"remote_port"` // 服务端监听端口 + // Plugin 支持字段 + PluginName string `json:"plugin_name,omitempty" yaml:"plugin_name"` + PluginVersion string `json:"plugin_version,omitempty" yaml:"plugin_version"` + PluginConfig map[string]string `json:"plugin_config,omitempty" yaml:"plugin_config"` } // ProxyConfig 代理配置下发 @@ -75,9 +94,59 @@ type ProxyConnectResult struct { Message string `json:"message,omitempty"` } +// PluginMetadata Plugin 元数据(协议层) +type PluginMetadata struct { + Name string `json:"name"` + Version string `json:"version"` + Checksum string `json:"checksum"` + Size int64 `json:"size"` + Description string `json:"description,omitempty"` +} + +// PluginListRequest 请求可用 plugins +type PluginListRequest struct { + ClientVersion string `json:"client_version"` +} + +// PluginListResponse 返回可用 plugins +type PluginListResponse struct { + Plugins []PluginMetadata `json:"plugins"` +} + +// PluginDownloadRequest 请求下载 plugin +type PluginDownloadRequest struct { + Name string `json:"name"` + Version string `json:"version"` +} + +// PluginDataChunk Plugin 二进制数据块 +type PluginDataChunk struct { + Name string `json:"name"` + Version string `json:"version"` + ChunkIndex int `json:"chunk_index"` + TotalChunks int `json:"total_chunks"` + Data []byte `json:"data"` + Checksum string `json:"checksum,omitempty"` +} + +// PluginReadyNotification Plugin 加载确认 +type PluginReadyNotification struct { + Name string `json:"name"` + Version string `json:"version"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + +// UDPPacket UDP 数据包 +type UDPPacket struct { + RemotePort int `json:"remote_port"` // 服务端监听端口 + ClientAddr string `json:"client_addr"` // 客户端地址 (用于回复) + Data []byte `json:"data"` // UDP 数据 +} + // WriteMessage 写入消息到 writer func WriteMessage(w io.Writer, msg *Message) error { - header := make([]byte, 5) + header := make([]byte, HeaderSize) header[0] = msg.Type binary.BigEndian.PutUint32(header[1:], uint32(len(msg.Payload))) @@ -94,7 +163,7 @@ func WriteMessage(w io.Writer, msg *Message) error { // ReadMessage 从 reader 读取消息 func ReadMessage(r io.Reader) (*Message, error) { - header := make([]byte, 5) + header := make([]byte, HeaderSize) if _, err := io.ReadFull(r, header); err != nil { return nil, err } @@ -102,7 +171,7 @@ func ReadMessage(r io.Reader) (*Message, error) { msgType := header[0] length := binary.BigEndian.Uint32(header[1:]) - if length > 1024*1024 { + if length > MaxMessageSize { return nil, errors.New("message too large") } diff --git a/pkg/proxy/server.go b/pkg/proxy/server.go index b6b71b9..7ec37bc 100644 --- a/pkg/proxy/server.go +++ b/pkg/proxy/server.go @@ -45,7 +45,7 @@ func (s *Server) HandleConn(conn net.Conn) { switch s.typ { case "socks5": err = s.socks5.HandleConn(conn) - case "http": + case "http", "https": err = s.http.HandleConn(conn) } if err != nil { diff --git a/pkg/relay/relay.go b/pkg/relay/relay.go index 40ca8a7..1af86b6 100644 --- a/pkg/relay/relay.go +++ b/pkg/relay/relay.go @@ -1,30 +1,29 @@ package relay import ( + "io" "net" "sync" ) +const bufferSize = 32 * 1024 + // Relay 双向数据转发 func Relay(c1, c2 net.Conn) { var wg sync.WaitGroup wg.Add(2) - copy := func(dst, src net.Conn) { + copyConn := func(dst, src net.Conn) { defer wg.Done() - buf := make([]byte, 32*1024) - for { - n, err := src.Read(buf) - if n > 0 { - dst.Write(buf[:n]) - } - if err != nil { - return - } + buf := make([]byte, bufferSize) + _, _ = io.CopyBuffer(dst, src, buf) + // 关闭写端,通知对方数据传输完成 + if tc, ok := dst.(*net.TCPConn); ok { + tc.CloseWrite() } } - go copy(c1, c2) - go copy(c2, c1) + go copyConn(c1, c2) + go copyConn(c2, c1) wg.Wait() }