add tls
All checks were successful
Build Multi-Platform Binaries / build (push) Successful in 11m8s

This commit is contained in:
Flik
2025-12-26 00:00:29 +08:00
parent f1038a132b
commit 1870b59214
6 changed files with 80 additions and 36 deletions

View File

@@ -5,18 +5,28 @@ import (
"log" "log"
"github.com/gotunnel/internal/client/tunnel" "github.com/gotunnel/internal/client/tunnel"
"github.com/gotunnel/pkg/crypto"
) )
func main() { func main() {
server := flag.String("s", "", "server address (ip:port)") server := flag.String("s", "", "server address (ip:port)")
token := flag.String("t", "", "auth token") token := flag.String("t", "", "auth token")
id := flag.String("id", "", "client id (optional)") id := flag.String("id", "", "client id (optional)")
noTLS := flag.Bool("no-tls", false, "disable TLS")
flag.Parse() flag.Parse()
if *server == "" || *token == "" { if *server == "" || *token == "" {
log.Fatal("Usage: client -s <server:port> -t <token> [-id <client_id>]") log.Fatal("Usage: client -s <server:port> -t <token> [-id <client_id>] [-no-tls]")
} }
client := tunnel.NewClient(*server, *token, *id) client := tunnel.NewClient(*server, *token, *id)
// TLS 默认启用
if !*noTLS {
client.TLSEnabled = true
client.TLSConfig = crypto.ClientTLSConfig()
log.Printf("[Client] TLS enabled")
}
client.Run() client.Run()
} }

View File

@@ -9,6 +9,7 @@ import (
"github.com/gotunnel/internal/server/config" "github.com/gotunnel/internal/server/config"
"github.com/gotunnel/internal/server/db" "github.com/gotunnel/internal/server/db"
"github.com/gotunnel/internal/server/tunnel" "github.com/gotunnel/internal/server/tunnel"
"github.com/gotunnel/pkg/crypto"
) )
func main() { func main() {
@@ -38,6 +39,16 @@ func main() {
cfg.Server.HeartbeatTimeout, cfg.Server.HeartbeatTimeout,
) )
// 配置 TLS默认启用
if !cfg.Server.TLSDisabled {
tlsConfig, err := crypto.GenerateTLSConfig()
if err != nil {
log.Fatalf("Generate TLS config error: %v", err)
}
server.SetTLSConfig(tlsConfig)
log.Printf("[Server] TLS enabled")
}
// 启动 Web 控制台 // 启动 Web 控制台
if cfg.Web.Enabled { if cfg.Web.Enabled {
ws := app.NewWebServer(clientStore, server, cfg, *configPath) ws := app.NewWebServer(clientStore, server, cfg, *configPath)

View File

@@ -1,15 +1,16 @@
package tunnel package tunnel
import ( import (
"crypto/tls"
"fmt" "fmt"
"log" "log"
"net" "net"
"sync" "sync"
"time" "time"
"github.com/google/uuid"
"github.com/gotunnel/pkg/protocol" "github.com/gotunnel/pkg/protocol"
"github.com/gotunnel/pkg/relay" "github.com/gotunnel/pkg/relay"
"github.com/google/uuid"
"github.com/hashicorp/yamux" "github.com/hashicorp/yamux"
) )
@@ -18,6 +19,8 @@ type Client struct {
ServerAddr string ServerAddr string
Token string Token string
ID string ID string
TLSEnabled bool
TLSConfig *tls.Config
session *yamux.Session session *yamux.Session
rules []protocol.ProxyRule rules []protocol.ProxyRule
mu sync.RWMutex mu sync.RWMutex
@@ -53,7 +56,15 @@ func (c *Client) Run() error {
// connect 连接到服务端并认证 // connect 连接到服务端并认证
func (c *Client) connect() error { func (c *Client) connect() error {
conn, err := net.DialTimeout("tcp", c.ServerAddr, 10*time.Second) var conn net.Conn
var err error
if c.TLSEnabled && c.TLSConfig != nil {
dialer := &net.Dialer{Timeout: 10 * time.Second}
conn, err = tls.DialWithDialer(dialer, "tcp", c.ServerAddr, c.TLSConfig)
} else {
conn, err = net.DialTimeout("tcp", c.ServerAddr, 10*time.Second)
}
if err != nil { if err != nil {
return err return err
} }

View File

@@ -22,6 +22,7 @@ type ServerSettings struct {
HeartbeatSec int `yaml:"heartbeat_sec"` HeartbeatSec int `yaml:"heartbeat_sec"`
HeartbeatTimeout int `yaml:"heartbeat_timeout"` HeartbeatTimeout int `yaml:"heartbeat_timeout"`
DBPath string `yaml:"db_path"` DBPath string `yaml:"db_path"`
TLSDisabled bool `yaml:"tls_disabled"` // 默认启用 TLS设置为 true 禁用
} }
// WebSettings Web控制台设置 // WebSettings Web控制台设置

View File

@@ -1,6 +1,7 @@
package tunnel package tunnel
import ( import (
"crypto/tls"
"fmt" "fmt"
"log" "log"
"net" "net"
@@ -26,6 +27,7 @@ type Server struct {
portManager *utils.PortManager portManager *utils.PortManager
clients map[string]*ClientSession clients map[string]*ClientSession
mu sync.RWMutex mu sync.RWMutex
tlsConfig *tls.Config
} }
// ClientSession 客户端会话 // ClientSession 客户端会话
@@ -52,17 +54,33 @@ func NewServer(cs db.ClientStore, bindAddr string, bindPort int, token string, h
} }
} }
// SetTLSConfig 设置 TLS 配置
func (s *Server) SetTLSConfig(config *tls.Config) {
s.tlsConfig = config
}
// Run 启动服务端 // Run 启动服务端
func (s *Server) Run() error { func (s *Server) Run() error {
addr := fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort) addr := fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort)
ln, err := net.Listen("tcp", addr)
if err != nil { var ln net.Listener
return fmt.Errorf("failed to listen on %s: %v", addr, err) var err error
if s.tlsConfig != nil {
ln, err = tls.Listen("tcp", addr, s.tlsConfig)
if err != nil {
return fmt.Errorf("failed to listen TLS on %s: %v", addr, err)
}
log.Printf("[Server] TLS listening on %s", addr)
} else {
ln, err = net.Listen("tcp", addr)
if err != nil {
return fmt.Errorf("failed to listen on %s: %v", addr, err)
}
log.Printf("[Server] Listening on %s (no TLS)", addr)
} }
defer ln.Close() defer ln.Close()
log.Printf("[Server] Listening on %s", addr)
for { for {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if err != nil {

View File

@@ -7,23 +7,21 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/pem"
"math/big" "math/big"
"net" "net"
"os"
"time" "time"
) )
// GenerateSelfSignedCert 生成自签名证书 // GenerateTLSConfig 生成内存中的自签名证书并返回 TLS 配置
func GenerateSelfSignedCert(certFile, keyFile string) error { func GenerateTLSConfig() (*tls.Config, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {
return err return nil, err
} }
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil { if err != nil {
return err return nil, err
} }
template := x509.Certificate{ template := x509.Certificate{
@@ -33,7 +31,7 @@ func GenerateSelfSignedCert(certFile, keyFile string) error {
CommonName: "GoTunnel Server", CommonName: "GoTunnel Server",
}, },
NotBefore: time.Now(), NotBefore: time.Now(),
NotAfter: time.Now().AddDate(10, 0, 0), // 10年有效期 NotAfter: time.Now().AddDate(10, 0, 0),
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true, BasicConstraintsValid: true,
@@ -43,29 +41,24 @@ func GenerateSelfSignedCert(certFile, keyFile string) error {
certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
if err != nil { if err != nil {
return err return nil, err
} }
// 写入证书文件 cert := tls.Certificate{
certOut, err := os.Create(certFile) Certificate: [][]byte{certDER},
if err != nil { PrivateKey: priv,
return err
} }
defer certOut.Close()
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
// 写入私钥文件 return &tls.Config{
keyOut, err := os.Create(keyFile) Certificates: []tls.Certificate{cert},
if err != nil { MinVersion: tls.VersionTLS12,
return err }, nil
} }
defer keyOut.Close()
// ClientTLSConfig 创建客户端 TLS 配置
privBytes, err := x509.MarshalECPrivateKey(priv) func ClientTLSConfig() *tls.Config {
if err != nil { return &tls.Config{
return err InsecureSkipVerify: true,
} MinVersion: tls.VersionTLS12,
pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privBytes}) }
return nil
} }