diff --git a/cmd/client/main.go b/cmd/client/main.go index ebec689..c63b8b5 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -4,6 +4,7 @@ import ( "flag" "log" + "github.com/gotunnel/internal/client/config" "github.com/gotunnel/internal/client/tunnel" "github.com/gotunnel/pkg/crypto" "github.com/gotunnel/pkg/plugin" @@ -14,16 +15,43 @@ func main() { token := flag.String("t", "", "auth token") id := flag.String("id", "", "client id (optional, auto-assigned if empty)") noTLS := flag.Bool("no-tls", false, "disable TLS") + configPath := flag.String("c", "", "config file path") flag.Parse() - if *server == "" || *token == "" { - log.Fatal("Usage: client -s -t [-id ] [-no-tls]") + // 优先加载配置文件 + var cfg *config.ClientConfig + if *configPath != "" { + var err error + cfg, err = config.LoadClientConfig(*configPath) + if err != nil { + log.Fatalf("Failed to load config: %v", err) + } + } else { + cfg = &config.ClientConfig{} } - client := tunnel.NewClient(*server, *token, *id) + // 命令行参数覆盖配置文件 + if *server != "" { + cfg.Server = *server + } + if *token != "" { + cfg.Token = *token + } + if *id != "" { + cfg.ID = *id + } + if *noTLS { + cfg.NoTLS = *noTLS + } + + if cfg.Server == "" || cfg.Token == "" { + log.Fatal("Usage: client [-c config.yaml] | [-s -t [-id ] [-no-tls]]") + } + + client := tunnel.NewClient(cfg.Server, cfg.Token, cfg.ID) // TLS 默认启用,默认跳过证书验证(类似 frp) - if !*noTLS { + if !cfg.NoTLS { client.TLSEnabled = true client.TLSConfig = crypto.ClientTLSConfig() log.Printf("[Client] TLS enabled") @@ -33,5 +61,10 @@ func main() { registry := plugin.NewRegistry() client.SetPluginRegistry(registry) + // 初始化版本存储 + if err := client.InitVersionStore(); err != nil { + log.Printf("[Client] Warning: failed to init version store: %v", err) + } + client.Run() } diff --git a/go.mod b/go.mod index 2b6e070..5dc273e 100644 --- a/go.mod +++ b/go.mod @@ -43,6 +43,7 @@ require ( github.com/goccy/go-yaml v1.19.1 // indirect github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e // indirect github.com/google/uuid v1.6.0 // indirect + github.com/gorilla/websocket v1.5.3 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect diff --git a/go.sum b/go.sum index bb3011e..6efdfed 100644 --- a/go.sum +++ b/go.sum @@ -81,6 +81,8 @@ github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17k github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE= github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= diff --git a/internal/client/config/config.go b/internal/client/config/config.go new file mode 100644 index 0000000..543991d --- /dev/null +++ b/internal/client/config/config.go @@ -0,0 +1,30 @@ +package config + +import ( + "os" + + "gopkg.in/yaml.v3" +) + +// ClientConfig 客户端配置 +type ClientConfig struct { + Server string `yaml:"server"` // 服务器地址 + Token string `yaml:"token"` // 认证 Token + ID string `yaml:"id"` // 客户端 ID + NoTLS bool `yaml:"no_tls"` // 禁用 TLS +} + +// LoadClientConfig 加载客户端配置 +func LoadClientConfig(path string) (*ClientConfig, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + var cfg ClientConfig + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, err + } + + return &cfg, nil +} diff --git a/internal/client/tunnel/client.go b/internal/client/tunnel/client.go index 196ca3a..9522c22 100644 --- a/internal/client/tunnel/client.go +++ b/internal/client/tunnel/client.go @@ -30,37 +30,42 @@ const ( reconnectDelay = 5 * time.Second disconnectDelay = 3 * time.Second udpBufferSize = 65535 - idFileName = ".gotunnel_id" + idFileName = "id" ) // Client 隧道客户端 type Client struct { - ServerAddr string - Token string - ID string - TLSEnabled bool - TLSConfig *tls.Config - DataDir string // 数据目录 - session *yamux.Session - rules []protocol.ProxyRule - mu sync.RWMutex - pluginRegistry *plugin.Registry - runningPlugins map[string]plugin.ClientPlugin - versionStore *PluginVersionStore - pluginMu sync.RWMutex - logger *Logger // 日志收集器 + ServerAddr string + Token string + ID string + TLSEnabled bool + TLSConfig *tls.Config + DataDir string // 数据目录 + session *yamux.Session + rules []protocol.ProxyRule + mu sync.RWMutex + pluginRegistry *plugin.Registry + runningPlugins map[string]plugin.ClientPlugin + versionStore *PluginVersionStore + pluginMu sync.RWMutex + logger *Logger // 日志收集器 } // NewClient 创建客户端 func NewClient(serverAddr, token, id string) *Client { - if id == "" { - id = loadClientID() - } - // 默认数据目录 home, _ := os.UserHomeDir() dataDir := filepath.Join(home, ".gotunnel") + // 确保数据目录存在 + if err := os.MkdirAll(dataDir, 0755); err != nil { + log.Printf("[Client] Failed to create data dir: %v", err) + } + + if id == "" { + id = loadClientID(dataDir) + } + // 初始化日志收集器 logger, err := NewLogger(dataDir) if err != nil { @@ -88,17 +93,13 @@ func (c *Client) InitVersionStore() error { } // getIDFilePath 获取 ID 文件路径 -func getIDFilePath() string { - home, err := os.UserHomeDir() - if err != nil { - return idFileName - } - return filepath.Join(home, idFileName) +func getIDFilePath(dataDir string) string { + return filepath.Join(dataDir, idFileName) } // loadClientID 从本地文件加载客户端 ID -func loadClientID() string { - data, err := os.ReadFile(getIDFilePath()) +func loadClientID(dataDir string) string { + data, err := os.ReadFile(getIDFilePath(dataDir)) if err != nil { return "" } @@ -106,8 +107,8 @@ func loadClientID() string { } // saveClientID 保存客户端 ID 到本地文件 -func saveClientID(id string) { - if err := os.WriteFile(getIDFilePath(), []byte(id), 0600); err != nil { +func saveClientID(dataDir, id string) { + if err := os.WriteFile(getIDFilePath(dataDir), []byte(id), 0600); err != nil { log.Printf("[Client] Failed to save client ID: %v", err) } } @@ -201,7 +202,7 @@ func (c *Client) connect() error { // 如果服务端分配了新 ID,则更新并保存 if authResp.ClientID != "" && authResp.ClientID != c.ID { c.ID = authResp.ClientID - saveClientID(c.ID) + saveClientID(c.DataDir, c.ID) c.logf("[Client] New ID assigned and saved: %s", c.ID) } diff --git a/internal/server/tunnel/server.go b/internal/server/tunnel/server.go index 65c130d..609efaa 100644 --- a/internal/server/tunnel/server.go +++ b/internal/server/tunnel/server.go @@ -61,12 +61,12 @@ type Server struct { mu sync.RWMutex tlsConfig *tls.Config pluginRegistry *plugin.Registry - jsPlugins []JSPluginEntry // 配置的 JS 插件 - connSem chan struct{} // 连接数信号量 - activeConns int64 // 当前活跃连接数 - listener net.Listener // 主监听器 - shutdown chan struct{} // 关闭信号 - wg sync.WaitGroup // 等待所有连接关闭 + jsPlugins []JSPluginEntry // 配置的 JS 插件 + connSem chan struct{} // 连接数信号量 + activeConns int64 // 当前活跃连接数 + listener net.Listener // 主监听器 + shutdown chan struct{} // 关闭信号 + wg sync.WaitGroup // 等待所有连接关闭 logSessions *LogSessionManager // 日志会话管理器 } @@ -82,14 +82,14 @@ type JSPluginEntry struct { // ClientSession 客户端会话 type ClientSession struct { - ID string - RemoteAddr string // 客户端 IP 地址 - Session *yamux.Session - Rules []protocol.ProxyRule - Listeners map[int]net.Listener - UDPConns map[int]*net.UDPConn // UDP 连接 - LastPing time.Time - mu sync.Mutex + ID string + RemoteAddr string // 客户端 IP 地址 + Session *yamux.Session + Rules []protocol.ProxyRule + Listeners map[int]net.Listener + UDPConns map[int]*net.UDPConn // UDP 连接 + LastPing time.Time + mu sync.Mutex } // NewServer 创建服务端 @@ -452,6 +452,9 @@ func (s *Server) startProxyListeners(cs *ClientSession) { case "http", "https": log.Printf("[Server] HTTP proxy %s on :%d", rule.Name, rule.RemotePort) go s.acceptProxyServerConns(cs, ln, rule) + case "websocket": + log.Printf("[Server] Websocket proxy %s on :%d", rule.Name, rule.RemotePort) + go s.acceptWebsocketConns(cs, ln, rule) default: log.Printf("[Server] TCP proxy %s: :%d -> %s:%d", rule.Name, rule.RemotePort, rule.LocalIP, rule.LocalPort) diff --git a/internal/server/tunnel/websocket.go b/internal/server/tunnel/websocket.go new file mode 100644 index 0000000..77ac48c --- /dev/null +++ b/internal/server/tunnel/websocket.go @@ -0,0 +1,146 @@ +package tunnel + +import ( + "io" + "log" + "net" + "net/http" + "time" + + "github.com/gorilla/websocket" + "github.com/gotunnel/pkg/protocol" + "github.com/gotunnel/pkg/relay" +) + +var upgrader = websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // 允许所有跨域请求 + }, +} + +// WSConnAdapter 适配器:将 websocket.Conn 适配为 io.ReadWriter +type WSConnAdapter struct { + conn *websocket.Conn + // 读缓冲 + reader io.Reader +} + +func NewWSConnAdapter(conn *websocket.Conn) *WSConnAdapter { + return &WSConnAdapter{ + conn: conn, + } +} + +func (a *WSConnAdapter) Read(p []byte) (n int, err error) { + if a.reader == nil { + messageType, reader, err := a.conn.NextReader() + if err != nil { + return 0, err + } + if messageType != websocket.BinaryMessage && messageType != websocket.TextMessage { + // 忽略非数据消息 + return 0, nil + } + a.reader = reader + } + n, err = a.reader.Read(p) + if err == io.EOF { + a.reader = nil + err = nil // 当前消息读完,不代表连接断开 + // 如果读到了0字节,尝试读下一个消息,避免因为返回 (0, nil) 导致调用方以为无数据空转 + if n == 0 { + return a.Read(p) + } + } + return n, err +} + +func (a *WSConnAdapter) Write(p []byte) (n int, err error) { + err = a.conn.WriteMessage(websocket.BinaryMessage, p) + if err != nil { + return 0, err + } + return len(p), nil +} + +func (a *WSConnAdapter) Close() error { + return a.conn.Close() +} + +func (a *WSConnAdapter) LocalAddr() net.Addr { + return a.conn.LocalAddr() +} + +func (a *WSConnAdapter) RemoteAddr() net.Addr { + return a.conn.RemoteAddr() +} + +func (a *WSConnAdapter) SetDeadline(t time.Time) error { + if err := a.conn.SetReadDeadline(t); err != nil { + return err + } + return a.conn.SetWriteDeadline(t) +} + +func (a *WSConnAdapter) SetReadDeadline(t time.Time) error { + return a.conn.SetReadDeadline(t) +} + +func (a *WSConnAdapter) SetWriteDeadline(t time.Time) error { + return a.conn.SetWriteDeadline(t) +} + +// acceptWebsocketConns 接受 Websocket 连接 +func (s *Server) acceptWebsocketConns(cs *ClientSession, ln net.Listener, rule protocol.ProxyRule) { + mux := http.NewServeMux() + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + wsConn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("[Server] Websocket upgrade error: %v", err) + return + } + + conn := NewWSConnAdapter(wsConn) + // 这里的 conn 并没有实现 net.Conn 接口的全部方法 (LocalAddr, RemoteAddr 等), + // Relay 函数如果需要 net.Conn,可能需要更完整的适配器。 + // 查看 relay.Relay 签名:func Relay(c1, c2 io.ReadWriteCloser) + // 假设 relay.Relay 接受 io.ReadWriteCloser。 + + go s.handleWebsocketProxyConn(cs, conn, rule) + }) + + server := &http.Server{ + Handler: mux, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + + // 这里不需要协程,因为 startProxyListeners 中已经是 go s.acceptWebsocketConns(...) 调用了? + // 不,startProxyListeners 中 iterate rules。如果是 acceptWebsocketConns,应该是在那里 go。 + // 检查 caller 逻辑。 + + if err := server.Serve(ln); err != nil && err != http.ErrServerClosed { + log.Printf("[Server] Websocket server error: %v", err) + } +} + +// handleWebsocketProxyConn 处理 Websocket 代理连接 +func (s *Server) handleWebsocketProxyConn(cs *ClientSession, conn net.Conn, rule protocol.ProxyRule) { + defer conn.Close() + + stream, err := cs.Session.Open() + if err != nil { + log.Printf("[Server] Open stream error: %v", err) + return + } + defer stream.Close() + + // 发送新代理连接请求,告知客户端连接到哪里 + req := protocol.NewProxyRequest{RemotePort: rule.RemotePort} + msg, _ := protocol.NewMessage(protocol.MsgTypeNewProxy, req) + if err := protocol.WriteMessage(stream, msg); err != nil { + return + } + + relay.Relay(conn, stream) +} diff --git a/internal/server/tunnel/websocket_test.go b/internal/server/tunnel/websocket_test.go new file mode 100644 index 0000000..94e4938 --- /dev/null +++ b/internal/server/tunnel/websocket_test.go @@ -0,0 +1,120 @@ +package tunnel + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" +) + +func TestWSConnAdapter(t *testing.T) { + // 1. 设置测试服务器 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade error: %v", err) + return + } + defer c.Close() + + adapter := NewWSConnAdapter(c) + defer adapter.Close() + + // Echo server + buf := make([]byte, 1024) + for { + n, err := adapter.Read(buf) + if err != nil { + if err != io.EOF { + // websocket close might cause normal error locally + } + break + } + _, err = adapter.Write(buf[:n]) + if err != nil { + t.Errorf("write error: %v", err) + break + } + } + })) + defer server.Close() + + // 2. 客户端连接 + u := "ws" + strings.TrimPrefix(server.URL, "http") + ws, _, err := websocket.DefaultDialer.Dial(u, nil) + if err != nil { + t.Fatalf("dial error: %v", err) + } + defer ws.Close() + + // 3. 发送数据 + message := []byte("hello websocket") + err = ws.WriteMessage(websocket.BinaryMessage, message) + if err != nil { + t.Fatalf("write message error: %v", err) + } + + // 4. 接收响应 + _, p, err := ws.ReadMessage() + if err != nil { + t.Fatalf("read message error: %v", err) + } + + if !bytes.Equal(message, p) { + t.Errorf("expected %s, got %s", message, p) + } +} + +func TestWSConnAdapter_ReadMultiFrame(t *testing.T) { + // 测试多次 Read 调用读取一个 frame,或者一个 Read 读取多个 frame (net.Conn 语义) + // WSConnAdapter 实现是 Read 对应 NextReader,如果 buffer 小,可能一部分一部分读。 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgrader := websocket.Upgrader{} + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer c.Close() + adapter := NewWSConnAdapter(c) + + // 只要收到数据就这就验证通过 + buf := make([]byte, 10) + n, err := adapter.Read(buf) + if err != nil { + t.Errorf("read error: %v", err) + } + if n != 5 { // "hello" + t.Errorf("expected 5 bytes, got %d", n) + } + + // 读剩下的 "world" + n, err = adapter.Read(buf) + if err != nil { + t.Errorf("read 2 error: %v", err) + } + if n != 5 { + t.Errorf("expected 5 bytes, got %d", n) + } + })) + defer server.Close() + + u := "ws" + strings.TrimPrefix(server.URL, "http") + ws, _, err := websocket.DefaultDialer.Dial(u, nil) + if err != nil { + t.Fatalf("dial error: %v", err) + } + defer ws.Close() + + // 发送两个 BinaryMessage + ws.WriteMessage(websocket.BinaryMessage, []byte("hello")) + ws.WriteMessage(websocket.BinaryMessage, []byte("world")) + + time.Sleep(100 * time.Millisecond) +} diff --git a/pkg/protocol/message.go b/pkg/protocol/message.go index 91af99d..19430f4 100644 --- a/pkg/protocol/message.go +++ b/pkg/protocol/message.go @@ -95,10 +95,10 @@ type AuthResponse struct { // ProxyRule 代理规则 type ProxyRule struct { Name string `json:"name" yaml:"name"` - Type string `json:"type" yaml:"type"` // 内置: tcp, udp, http, https; 插件: socks5 等 - LocalIP string `json:"local_ip" yaml:"local_ip"` // tcp/udp 模式使用 - LocalPort int `json:"local_port" yaml:"local_port"` // tcp/udp 模式使用 - RemotePort int `json:"remote_port" yaml:"remote_port"` // 服务端监听端口 + Type string `json:"type" yaml:"type"` // 内置: tcp, udp, http, https, websocket; 插件: socks5 等 + LocalIP string `json:"local_ip" yaml:"local_ip"` // tcp/udp 模式使用 + LocalPort int `json:"local_port" yaml:"local_port"` // tcp/udp 模式使用 + RemotePort int `json:"remote_port" yaml:"remote_port"` // 服务端监听端口 Enabled *bool `json:"enabled,omitempty" yaml:"enabled"` // 是否启用,默认为 true // Plugin 支持字段 PluginID string `json:"plugin_id,omitempty" yaml:"plugin_id"` // 插件实例ID @@ -150,11 +150,11 @@ type ProxyConnectResult struct { // PluginMetadata Plugin 元数据(协议层) type PluginMetadata struct { - Name string `json:"name"` - Version string `json:"version"` - Checksum string `json:"checksum"` - Size int64 `json:"size"` - Description string `json:"description,omitempty"` + Name string `json:"name"` + Version string `json:"version"` + Checksum string `json:"checksum"` + Size int64 `json:"size"` + Description string `json:"description,omitempty"` } // PluginListRequest 请求可用 plugins diff --git a/test_config.yaml b/test_config.yaml new file mode 100644 index 0000000..efc658e --- /dev/null +++ b/test_config.yaml @@ -0,0 +1,3 @@ +server: "127.0.0.1:7000" +token: "testtoken" +id: "testclient"