Files
GoTunnel/internal/client/tunnel/client.go
Flik e10736e05e
All checks were successful
Build Multi-Platform Binaries / build-frontend (push) Successful in 29s
Build Multi-Platform Binaries / build-binaries (amd64, darwin, server, false) (push) Successful in 49s
Build Multi-Platform Binaries / build-binaries (amd64, linux, client, true) (push) Successful in 34s
Build Multi-Platform Binaries / build-binaries (amd64, linux, server, true) (push) Successful in 59s
Build Multi-Platform Binaries / build-binaries (amd64, windows, client, true) (push) Successful in 34s
Build Multi-Platform Binaries / build-binaries (amd64, windows, server, true) (push) Successful in 55s
Build Multi-Platform Binaries / build-binaries (arm, 7, linux, client, true) (push) Successful in 37s
Build Multi-Platform Binaries / build-binaries (arm, 7, linux, server, true) (push) Successful in 1m7s
Build Multi-Platform Binaries / build-binaries (arm64, darwin, server, false) (push) Successful in 50s
Build Multi-Platform Binaries / build-binaries (arm64, linux, client, true) (push) Successful in 33s
Build Multi-Platform Binaries / build-binaries (arm64, linux, server, true) (push) Successful in 59s
Build Multi-Platform Binaries / build-binaries (arm64, windows, server, false) (push) Successful in 52s
update
2025-12-29 14:24:46 +08:00

465 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package tunnel
import (
	"crypto/tls"
	"fmt"
	"log"
	"net"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/gotunnel/pkg/plugin"
	"github.com/gotunnel/pkg/protocol"
	"github.com/gotunnel/pkg/relay"
	"github.com/hashicorp/yamux"
)
// Client tuning constants.
const (
	dialTimeout      = 10 * time.Second // dial to the server and to proxy-connect targets
	localDialTimeout = 5 * time.Second  // dial to a rule's local TCP/UDP service
	udpTimeout       = 10 * time.Second // deadline for one local UDP write+read round trip
	reconnectDelay   = 5 * time.Second  // pause after a failed connect attempt
	disconnectDelay  = 3 * time.Second  // pause after an established session ends
	udpBufferSize    = 1<<16 - 1        // read buffer for one local UDP response (65535 bytes)
	idFileName       = ".gotunnel_id"   // client ID file name (placed in the home directory when known)
)
// Client is a tunnel client. It authenticates to a server, holds the
// resulting yamux session, and serves the streams the server opens on it.
type Client struct {
	ServerAddr string      // server address to dial
	Token      string      // token sent in the auth request
	ID         string      // client ID; may be replaced by a server-assigned ID during connect
	TLSEnabled bool        // whether TLS was requested for the server connection
	TLSConfig  *tls.Config // TLS configuration for the server connection

	// mu is held when session is assigned and whenever rules is read or
	// written; rules holds the proxy rules last received from the server.
	session *yamux.Session
	rules   []protocol.ProxyRule
	mu      sync.RWMutex

	pluginRegistry *plugin.Registry                // plugin lookup; set via SetPluginRegistry
	runningPlugins map[string]plugin.ClientHandler // started client plugins, keyed by "pluginName:ruleName"
	pluginMu       sync.RWMutex                    // guards runningPlugins
}
// NewClient returns a Client that connects to serverAddr and authenticates
// with token. When id is empty, the ID persisted by a previous run is used
// (see loadClientID); if nothing can be read, the ID stays empty.
func NewClient(serverAddr, token, id string) *Client {
	c := &Client{
		ServerAddr:     serverAddr,
		Token:          token,
		runningPlugins: make(map[string]plugin.ClientHandler),
	}
	if id != "" {
		c.ID = id
	} else {
		c.ID = loadClientID()
	}
	return c
}
// getIDFilePath returns where the client ID is persisted: idFileName joined
// onto the user's home directory, or the bare idFileName (relative to the
// working directory) when the home directory cannot be determined.
func getIDFilePath() string {
	if home, err := os.UserHomeDir(); err == nil {
		return filepath.Join(home, idFileName)
	}
	return idFileName
}
// loadClientID returns the client ID stored in the ID file (see
// getIDFilePath), or "" when the file cannot be read.
//
// Surrounding whitespace is stripped: a file that was edited by hand usually
// gains a trailing newline, which would otherwise become part of the ID sent
// to the server.
func loadClientID() string {
	data, err := os.ReadFile(getIDFilePath())
	if err != nil {
		return ""
	}
	return strings.TrimSpace(string(data))
}
// saveClientID writes id to the path from getIDFilePath with owner-only
// read/write permissions. A write failure is logged, not returned.
func saveClientID(id string) {
	path := getIDFilePath()
	err := os.WriteFile(path, []byte(id), 0o600)
	if err == nil {
		return
	}
	log.Printf("[Client] Failed to save client ID: %v", err)
}
// SetPluginRegistry installs the registry used to look up plugins named in
// plugin-config and client-plugin-start messages. The field is read without
// locking, so set it before Run starts serving streams.
func (c *Client) SetPluginRegistry(r *plugin.Registry) {
	c.pluginRegistry = r
}
// Run connects to the server and serves the session, reconnecting forever.
// It waits reconnectDelay after a failed connect and disconnectDelay after a
// session ends. The loop has no exit, so Run never returns.
func (c *Client) Run() error {
	for {
		err := c.connect()
		if err == nil {
			c.handleSession()
			log.Printf("[Client] Disconnected, reconnecting...")
			time.Sleep(disconnectDelay)
			continue
		}
		log.Printf("[Client] Connect error: %v", err)
		log.Printf("[Client] Reconnecting in %v...", reconnectDelay)
		time.Sleep(reconnectDelay)
	}
}
// connect dials the server, performs the auth handshake and, on success,
// stores a new yamux client session in c.session (under c.mu).
//
// When TLSEnabled is set the connection always uses TLS. A nil TLSConfig
// selects crypto/tls defaults instead of silently falling back to a
// plaintext connection.
//
// The handshake is bounded by dialTimeout, so a server that accepts the
// connection but never answers cannot stall the reconnect loop in Run.
// The connection is closed on every failure path.
func (c *Client) connect() error {
	conn, err := c.dialServer()
	if err != nil {
		return fmt.Errorf("dial server: %w", err)
	}

	if err := conn.SetDeadline(time.Now().Add(dialTimeout)); err != nil {
		conn.Close()
		return fmt.Errorf("set handshake deadline: %w", err)
	}
	if err := c.authenticate(conn); err != nil {
		conn.Close()
		return err
	}
	// The deadline only covers the handshake; the session is long-lived.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		conn.Close()
		return fmt.Errorf("clear handshake deadline: %w", err)
	}

	session, err := yamux.Client(conn, nil)
	if err != nil {
		conn.Close()
		return fmt.Errorf("create yamux session: %w", err)
	}

	c.mu.Lock()
	c.session = session
	c.mu.Unlock()
	return nil
}

// dialServer opens the transport connection to c.ServerAddr, using TLS when
// c.TLSEnabled is set. dialTimeout bounds the dial (and, for TLS, the
// handshake as well).
func (c *Client) dialServer() (net.Conn, error) {
	if !c.TLSEnabled {
		return net.DialTimeout("tcp", c.ServerAddr, dialTimeout)
	}
	// A nil c.TLSConfig is accepted by crypto/tls and means default settings.
	dialer := &net.Dialer{Timeout: dialTimeout}
	tlsConn, err := tls.DialWithDialer(dialer, "tcp", c.ServerAddr, c.TLSConfig)
	if err != nil {
		// Return a literal nil so callers never see a typed-nil net.Conn.
		return nil, err
	}
	return tlsConn, nil
}

// authenticate sends the auth request on conn, checks the response, and
// adopts (and persists) a server-assigned client ID when it differs from
// c.ID. It does not close conn.
func (c *Client) authenticate(conn net.Conn) error {
	req := protocol.AuthRequest{ClientID: c.ID, Token: c.Token}
	msg, err := protocol.NewMessage(protocol.MsgTypeAuth, req)
	if err != nil {
		return fmt.Errorf("build auth request: %w", err)
	}
	if err := protocol.WriteMessage(conn, msg); err != nil {
		return fmt.Errorf("send auth request: %w", err)
	}

	resp, err := protocol.ReadMessage(conn)
	if err != nil {
		return fmt.Errorf("read auth response: %w", err)
	}
	var authResp protocol.AuthResponse
	if err := resp.ParsePayload(&authResp); err != nil {
		return fmt.Errorf("parse auth response: %w", err)
	}
	if !authResp.Success {
		return fmt.Errorf("auth failed: %s", authResp.Message)
	}

	if authResp.ClientID != "" && authResp.ClientID != c.ID {
		c.ID = authResp.ClientID
		saveClientID(c.ID)
		log.Printf("[Client] New ID assigned and saved: %s", c.ID)
	}
	log.Printf("[Client] Authenticated as %s", c.ID)
	return nil
}
// handleSession accepts streams from the current session until Accept
// returns an error, serving each stream on its own goroutine. The session is
// closed before returning.
func (c *Client) handleSession() {
	sess := c.session
	defer sess.Close()
	for {
		stream, err := sess.Accept()
		if err != nil {
			return
		}
		go c.handleStream(stream)
	}
}
// handleStream reads the first message on a server-opened stream and
// dispatches on its type.
//
// Stream ownership depends on the type:
//   - proxy config, new proxy, heartbeat and plugin config: the stream is
//     closed here after the handler returns;
//   - proxy connect, UDP data and client plugin start: the handler closes it;
//   - client plugin conn: handleClientPluginConn closes it on failure or
//     hands it to the plugin.
//
// A stream whose first message cannot be read, or whose type is not
// recognised, is closed immediately so it does not leak.
func (c *Client) handleStream(stream net.Conn) {
	msg, err := protocol.ReadMessage(stream)
	if err != nil {
		stream.Close()
		return
	}

	switch msg.Type {
	case protocol.MsgTypeProxyConfig:
		defer stream.Close()
		c.handleProxyConfig(msg)
	case protocol.MsgTypeNewProxy:
		defer stream.Close()
		c.handleNewProxy(stream, msg)
	case protocol.MsgTypeHeartbeat:
		defer stream.Close()
		c.handleHeartbeat(stream)
	case protocol.MsgTypeProxyConnect:
		c.handleProxyConnect(stream, msg)
	case protocol.MsgTypeUDPData:
		c.handleUDPData(stream, msg)
	case protocol.MsgTypePluginConfig:
		defer stream.Close()
		c.handlePluginConfig(msg)
	case protocol.MsgTypeClientPluginStart:
		c.handleClientPluginStart(stream, msg)
	case protocol.MsgTypeClientPluginConn:
		c.handleClientPluginConn(stream, msg)
	default:
		// No handler owns this stream; without closing it here it would stay
		// open until the whole session ends.
		log.Printf("[Client] Unknown message type %v, closing stream", msg.Type)
		stream.Close()
	}
}
// handleProxyConfig replaces the client's proxy rules with the set carried in
// msg and logs each rule. If the payload fails to parse, the error is logged
// and the existing rules are kept.
func (c *Client) handleProxyConfig(msg *protocol.Message) {
	var cfg protocol.ProxyConfig
	err := msg.ParsePayload(&cfg)
	if err != nil {
		log.Printf("[Client] Parse proxy config error: %v", err)
		return
	}

	rules := cfg.Rules
	c.mu.Lock()
	c.rules = rules
	c.mu.Unlock()

	log.Printf("[Client] Received %d proxy rules", len(rules))
	for i := range rules {
		log.Printf("[Client] %s: %s:%d", rules[i].Name, rules[i].LocalIP, rules[i].LocalPort)
	}
}
// handleNewProxy serves a TCP proxy stream. It finds the rule whose
// RemotePort matches the request, dials that rule's local address and relays
// data between stream and the local connection.
//
// The caller owns and closes stream; the local connection is closed here once
// relaying finishes. Parse, lookup and dial failures are logged and the
// stream is left for the caller to close.
func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) {
	var req protocol.NewProxyRequest
	if err := msg.ParsePayload(&req); err != nil {
		log.Printf("[Client] Parse new proxy request error: %v", err)
		return
	}

	rule := c.findRuleByPort(req.RemotePort)
	if rule == nil {
		log.Printf("[Client] Unknown port %d", req.RemotePort)
		return
	}

	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:port"), which a
	// plain "%s:%d" format would not.
	localAddr := net.JoinHostPort(rule.LocalIP, fmt.Sprint(rule.LocalPort))
	localConn, err := net.DialTimeout("tcp", localAddr, localDialTimeout)
	if err != nil {
		log.Printf("[Client] Connect %s error: %v", localAddr, err)
		return
	}
	defer localConn.Close()

	relay.Relay(stream, localConn)
}
// handleHeartbeat answers a heartbeat by writing a MsgTypeHeartbeatAck
// message (no payload set) back on the same stream. Write errors are
// ignored; the caller closes the stream.
func (c *Client) handleHeartbeat(stream net.Conn) {
	ack := protocol.Message{Type: protocol.MsgTypeHeartbeatAck}
	_ = protocol.WriteMessage(stream, &ack)
}
// handleProxyConnect serves a proxy (SOCKS5/HTTP) connect request. It dials
// req.Target and reports the outcome to the server with sendProxyResult. On
// success, it relays data in both directions. Both stream and the target
// connection are closed on return.
func (c *Client) handleProxyConnect(stream net.Conn, msg *protocol.Message) {
	defer stream.Close()

	var req protocol.ProxyConnectRequest
	if parseErr := msg.ParsePayload(&req); parseErr != nil {
		c.sendProxyResult(stream, false, "invalid request")
		return
	}

	target, dialErr := net.DialTimeout("tcp", req.Target, dialTimeout)
	if dialErr != nil {
		c.sendProxyResult(stream, false, dialErr.Error())
		return
	}
	defer target.Close()

	// Only start relaying once the server has been told the tunnel is up.
	if c.sendProxyResult(stream, true, "") != nil {
		return
	}
	relay.Relay(stream, target)
}
// sendProxyResult writes a MsgTypeProxyResult message carrying success and
// message to stream. It returns an error if the message cannot be encoded
// (nothing is written in that case) or if the write fails.
func (c *Client) sendProxyResult(stream net.Conn, success bool, message string) error {
	result := protocol.ProxyConnectResult{Success: success, Message: message}
	msg, err := protocol.NewMessage(protocol.MsgTypeProxyResult, result)
	if err != nil {
		return fmt.Errorf("encode proxy result: %w", err)
	}
	return protocol.WriteMessage(stream, msg)
}
// handleUDPData forwards one UDP datagram to the local service of the rule
// whose RemotePort matches the packet, then sends a single response
// datagram back to the server on the same stream.
//
// Exactly one write and one read are performed, both bounded by udpTimeout.
// Parse, lookup, dial, deadline, write and read failures drop the datagram
// without a response. A failure to encode the response is logged. The
// stream is closed on return.
func (c *Client) handleUDPData(stream net.Conn, msg *protocol.Message) {
	defer stream.Close()

	var packet protocol.UDPPacket
	if err := msg.ParsePayload(&packet); err != nil {
		return
	}

	rule := c.findRuleByPort(packet.RemotePort)
	if rule == nil {
		return
	}

	// JoinHostPort keeps IPv6 literals dialable ("[::1]:port").
	target := net.JoinHostPort(rule.LocalIP, fmt.Sprint(rule.LocalPort))
	conn, err := net.DialTimeout("udp", target, localDialTimeout)
	if err != nil {
		return
	}
	defer conn.Close()

	// Without a deadline the Read below could block indefinitely.
	if err := conn.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
		return
	}
	if _, err := conn.Write(packet.Data); err != nil {
		return
	}

	buf := make([]byte, udpBufferSize)
	n, err := conn.Read(buf)
	if err != nil {
		return
	}

	respPacket := protocol.UDPPacket{
		RemotePort: packet.RemotePort,
		ClientAddr: packet.ClientAddr,
		Data:       buf[:n],
	}
	respMsg, err := protocol.NewMessage(protocol.MsgTypeUDPData, respPacket)
	if err != nil {
		log.Printf("[Client] Encode UDP response error: %v", err)
		return
	}
	protocol.WriteMessage(stream, respMsg)
}
// findRuleByPort returns a pointer to the first proxy rule whose RemotePort
// equals port, or nil if none matches. The search runs under the read lock;
// the pointer refers to an element of the rules slice current at call time.
func (c *Client) findRuleByPort(port int) *protocol.ProxyRule {
	c.mu.RLock()
	defer c.mu.RUnlock()
	for i, n := 0, len(c.rules); i < n; i++ {
		if r := &c.rules[i]; r.RemotePort == port {
			return r
		}
	}
	return nil
}
// handlePluginConfig applies a plugin configuration pushed by the server. It
// looks up the named plugin in the registry and calls its Init with the
// received config. When no registry is set, the config is only logged as
// received. Parse, lookup and Init failures are logged.
func (c *Client) handlePluginConfig(msg *protocol.Message) {
	var cfg protocol.PluginConfigSync
	if err := msg.ParsePayload(&cfg); err != nil {
		log.Printf("[Client] Parse plugin config error: %v", err)
		return
	}

	name := cfg.PluginName
	log.Printf("[Client] Received config for plugin: %s", name)

	registry := c.pluginRegistry
	if registry == nil {
		return
	}
	handler, err := registry.Get(name)
	if err != nil {
		log.Printf("[Client] Plugin %s not found: %v", name, err)
		return
	}
	if err = handler.Init(cfg.Config); err != nil {
		log.Printf("[Client] Plugin %s init error: %v", name, err)
		return
	}
	log.Printf("[Client] Plugin %s config applied", name)
}
// handleClientPluginStart starts a client plugin on the server's request.
//
// It fetches the plugin from the registry, calls Init with the request's
// config and then Start. The handler is recorded in runningPlugins under
// "pluginName:ruleName". A status response (running or the error text) is
// always written back on the stream, which is closed on return.
func (c *Client) handleClientPluginStart(stream net.Conn, msg *protocol.Message) {
	defer stream.Close()

	var req protocol.ClientPluginStartRequest
	// fail reports a failed start for req to the server.
	fail := func(reason string) {
		c.sendPluginStatus(stream, req.PluginName, req.RuleName, false, "", reason)
	}

	if err := msg.ParsePayload(&req); err != nil {
		fail(err.Error())
		return
	}
	log.Printf("[Client] Starting plugin %s for rule %s", req.PluginName, req.RuleName)

	if c.pluginRegistry == nil {
		fail("plugin registry not set")
		return
	}
	handler, err := c.pluginRegistry.GetClientPlugin(req.PluginName)
	if err != nil {
		fail(err.Error())
		return
	}
	if err := handler.Init(req.Config); err != nil {
		fail(err.Error())
		return
	}
	localAddr, err := handler.Start()
	if err != nil {
		fail(err.Error())
		return
	}

	// NOTE(review): an entry already stored under this key is replaced and
	// the previous handler is not stopped here — confirm whether the plugin
	// API expects that.
	key := req.PluginName + ":" + req.RuleName
	c.pluginMu.Lock()
	c.runningPlugins[key] = handler
	c.pluginMu.Unlock()

	log.Printf("[Client] Plugin %s started at %s", req.PluginName, localAddr)
	c.sendPluginStatus(stream, req.PluginName, req.RuleName, true, localAddr, "")
}
// sendPluginStatus writes a MsgTypeClientPluginStatus response describing
// the plugin/rule pair to stream. running and localAddr describe a
// successful start, and errMsg carries the failure reason otherwise.
//
// The function has no error result, so encoding and write failures are
// logged. On an encoding failure nothing is written, rather than writing a
// nil message.
func (c *Client) sendPluginStatus(stream net.Conn, pluginName, ruleName string, running bool, localAddr, errMsg string) {
	resp := protocol.ClientPluginStatusResponse{
		PluginName: pluginName,
		RuleName:   ruleName,
		Running:    running,
		LocalAddr:  localAddr,
		Error:      errMsg,
	}
	msg, err := protocol.NewMessage(protocol.MsgTypeClientPluginStatus, resp)
	if err != nil {
		log.Printf("[Client] Encode plugin status for %s:%s error: %v", pluginName, ruleName, err)
		return
	}
	if err := protocol.WriteMessage(stream, msg); err != nil {
		log.Printf("[Client] Send plugin status for %s:%s error: %v", pluginName, ruleName, err)
	}
}
// handleClientPluginConn hands stream to the running client plugin
// identified by the request's plugin and rule names (key
// "pluginName:ruleName").
//
// If the payload cannot be parsed or no such plugin is running, the stream
// is closed here. Otherwise it is passed to the plugin's HandleConn and is
// not closed by this function.
func (c *Client) handleClientPluginConn(stream net.Conn, msg *protocol.Message) {
	var req protocol.ClientPluginConnRequest
	if msg.ParsePayload(&req) != nil {
		stream.Close()
		return
	}

	key := req.PluginName + ":" + req.RuleName
	c.pluginMu.RLock()
	handler, running := c.runningPlugins[key]
	c.pluginMu.RUnlock()

	if running {
		handler.HandleConn(stream)
		return
	}
	log.Printf("[Client] Plugin %s not running", key)
	stream.Close()
}