From 1870b59214896276d93411e96c10a33201dcc3d3 Mon Sep 17 00:00:00 2001 From: Flik Date: Fri, 26 Dec 2025 00:00:29 +0800 Subject: [PATCH] add tls --- cmd/client/main.go | 12 +++++++- cmd/server/main.go | 11 +++++++ internal/client/tunnel/client.go | 15 ++++++++-- internal/server/config/config.go | 1 + internal/server/tunnel/server.go | 28 ++++++++++++++---- pkg/crypto/tls.go | 49 ++++++++++++++------------------ 6 files changed, 80 insertions(+), 36 deletions(-) diff --git a/cmd/client/main.go b/cmd/client/main.go index f698dfa..6332ba5 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -5,18 +5,28 @@ import ( "log" "github.com/gotunnel/internal/client/tunnel" + "github.com/gotunnel/pkg/crypto" ) func main() { server := flag.String("s", "", "server address (ip:port)") token := flag.String("t", "", "auth token") id := flag.String("id", "", "client id (optional)") + noTLS := flag.Bool("no-tls", false, "disable TLS") flag.Parse() if *server == "" || *token == "" { - log.Fatal("Usage: client -s -t [-id ]") + log.Fatal("Usage: client -s -t [-id ] [-no-tls]") } client := tunnel.NewClient(*server, *token, *id) + + // TLS 默认启用 + if !*noTLS { + client.TLSEnabled = true + client.TLSConfig = crypto.ClientTLSConfig() + log.Printf("[Client] TLS enabled") + } + client.Run() } diff --git a/cmd/server/main.go b/cmd/server/main.go index 8495b90..81612b3 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -9,6 +9,7 @@ import ( "github.com/gotunnel/internal/server/config" "github.com/gotunnel/internal/server/db" "github.com/gotunnel/internal/server/tunnel" + "github.com/gotunnel/pkg/crypto" ) func main() { @@ -38,6 +39,16 @@ func main() { cfg.Server.HeartbeatTimeout, ) + // 配置 TLS(默认启用) + if !cfg.Server.TLSDisabled { + tlsConfig, err := crypto.GenerateTLSConfig() + if err != nil { + log.Fatalf("Generate TLS config error: %v", err) + } + server.SetTLSConfig(tlsConfig) + log.Printf("[Server] TLS enabled") + } + // 启动 Web 控制台 if cfg.Web.Enabled { ws := app.NewWebServer(clientStore, server, cfg, *configPath) diff --git a/internal/client/tunnel/client.go b/internal/client/tunnel/client.go index 7bb5698..daafa7c 100644 --- a/internal/client/tunnel/client.go +++ b/internal/client/tunnel/client.go @@ -1,15 +1,16 @@ package tunnel import ( + "crypto/tls" "fmt" "log" "net" "sync" "time" + "github.com/google/uuid" "github.com/gotunnel/pkg/protocol" "github.com/gotunnel/pkg/relay" - "github.com/google/uuid" "github.com/hashicorp/yamux" ) @@ -18,6 +19,8 @@ type Client struct { ServerAddr string Token string ID string + TLSEnabled bool + TLSConfig *tls.Config session *yamux.Session rules []protocol.ProxyRule mu sync.RWMutex @@ -53,7 +56,15 @@ func (c *Client) Run() error { // connect 连接到服务端并认证 func (c *Client) connect() error { - conn, err := net.DialTimeout("tcp", c.ServerAddr, 10*time.Second) + var conn net.Conn + var err error + + if c.TLSEnabled && c.TLSConfig != nil { + dialer := &net.Dialer{Timeout: 10 * time.Second} + conn, err = tls.DialWithDialer(dialer, "tcp", c.ServerAddr, c.TLSConfig) + } else { + conn, err = net.DialTimeout("tcp", c.ServerAddr, 10*time.Second) + } if err != nil { return err } diff --git a/internal/server/config/config.go b/internal/server/config/config.go index c13c07c..9fbf860 100644 --- a/internal/server/config/config.go +++ b/internal/server/config/config.go @@ -22,6 +22,7 @@ type ServerSettings struct { HeartbeatSec int `yaml:"heartbeat_sec"` HeartbeatTimeout int `yaml:"heartbeat_timeout"` DBPath string `yaml:"db_path"` + TLSDisabled bool `yaml:"tls_disabled"` // 默认启用 TLS,设置为 true 禁用 } // WebSettings Web控制台设置 diff --git a/internal/server/tunnel/server.go b/internal/server/tunnel/server.go index 3173558..3dc23f1 100644 --- a/internal/server/tunnel/server.go +++ b/internal/server/tunnel/server.go @@ -1,6 +1,7 @@ package tunnel import ( + "crypto/tls" "fmt" "log" "net" @@ -26,6 +27,7 @@ type Server struct { portManager *utils.PortManager clients map[string]*ClientSession mu sync.RWMutex + tlsConfig *tls.Config } // ClientSession 客户端会话 @@ -52,17 +54,33 @@ func NewServer(cs db.ClientStore, bindAddr string, bindPort int, token string, h } } +// SetTLSConfig 设置 TLS 配置 +func (s *Server) SetTLSConfig(config *tls.Config) { + s.tlsConfig = config +} + // Run 启动服务端 func (s *Server) Run() error { addr := fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort) - ln, err := net.Listen("tcp", addr) - if err != nil { - return fmt.Errorf("failed to listen on %s: %v", addr, err) + + var ln net.Listener + var err error + + if s.tlsConfig != nil { + ln, err = tls.Listen("tcp", addr, s.tlsConfig) + if err != nil { + return fmt.Errorf("failed to listen TLS on %s: %v", addr, err) + } + log.Printf("[Server] TLS listening on %s", addr) + } else { + ln, err = net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("failed to listen on %s: %v", addr, err) + } + log.Printf("[Server] Listening on %s (no TLS)", addr) } defer ln.Close() - log.Printf("[Server] Listening on %s", addr) - for { conn, err := ln.Accept() if err != nil { diff --git a/pkg/crypto/tls.go b/pkg/crypto/tls.go index e5dcf76..cd8c2a1 100644 --- a/pkg/crypto/tls.go +++ b/pkg/crypto/tls.go @@ -7,23 +7,21 @@ import ( "crypto/tls" "crypto/x509" "crypto/x509/pkix" - "encoding/pem" "math/big" "net" - "os" "time" ) -// GenerateSelfSignedCert 生成自签名证书 -func GenerateSelfSignedCert(certFile, keyFile string) error { +// GenerateTLSConfig 生成内存中的自签名证书并返回 TLS 配置 +func GenerateTLSConfig() (*tls.Config, error) { priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { - return err + return nil, err } serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) if err != nil { - return err + return nil, err } template := x509.Certificate{ @@ -33,7 +31,7 @@ func GenerateSelfSignedCert(certFile, keyFile string) error { CommonName: "GoTunnel Server", }, NotBefore: time.Now(), - NotAfter: time.Now().AddDate(10, 0, 0), // 10年有效期 + NotAfter: time.Now().AddDate(10, 0, 0), KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, @@ -43,29 +41,24 @@ func GenerateSelfSignedCert(certFile, keyFile string) error { certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) if err != nil { - return err + return nil, err } - // 写入证书文件 - certOut, err := os.Create(certFile) - if err != nil { - return err + cert := tls.Certificate{ + Certificate: [][]byte{certDER}, + PrivateKey: priv, } - defer certOut.Close() - pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - // 写入私钥文件 - keyOut, err := os.Create(keyFile) - if err != nil { - return err - } - defer keyOut.Close() - - privBytes, err := x509.MarshalECPrivateKey(priv) - if err != nil { - return err - } - pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes}) - - return nil + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + }, nil +} + +// ClientTLSConfig 创建客户端 TLS 配置 +func ClientTLSConfig() *tls.Config { + return &tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS12, + } }