add plugins
All checks were successful
Build Multi-Platform Binaries / build (push) Successful in 11m9s

This commit is contained in:
Flik
2025-12-26 11:24:23 +08:00
parent d56fdafc1e
commit 4623a7f031
27 changed files with 2090 additions and 97 deletions

View File

@@ -0,0 +1,114 @@
package plugin
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"os"
"path/filepath"
"sync"
"time"
"github.com/gotunnel/pkg/plugin"
)
// CachedPlugin 缓存的 plugin 信息
type CachedPlugin struct {
Metadata plugin.PluginMetadata
Path string
LoadedAt time.Time
}
// Cache 管理本地 plugin 存储
type Cache struct {
dir string
plugins map[string]*CachedPlugin
mu sync.RWMutex
}
// NewCache 创建 plugin 缓存
func NewCache(cacheDir string) (*Cache, error) {
if err := os.MkdirAll(cacheDir, 0755); err != nil {
return nil, err
}
return &Cache{
dir: cacheDir,
plugins: make(map[string]*CachedPlugin),
}, nil
}
// Get 返回缓存的 plugin如果有效
func (c *Cache) Get(name, version, checksum string) (*CachedPlugin, error) {
c.mu.RLock()
defer c.mu.RUnlock()
cached, ok := c.plugins[name]
if !ok {
return nil, nil
}
// 验证版本和 checksum
if cached.Metadata.Version != version {
return nil, nil
}
if checksum != "" && cached.Metadata.Checksum != checksum {
return nil, nil
}
return cached, nil
}
// Store 保存 plugin 到缓存
func (c *Cache) Store(meta plugin.PluginMetadata, wasmData []byte) error {
c.mu.Lock()
defer c.mu.Unlock()
// 验证 checksum
hash := sha256.Sum256(wasmData)
checksum := hex.EncodeToString(hash[:])
if meta.Checksum != "" && meta.Checksum != checksum {
return fmt.Errorf("checksum mismatch")
}
meta.Checksum = checksum
// 写入文件
path := filepath.Join(c.dir, meta.Name+".wasm")
if err := os.WriteFile(path, wasmData, 0644); err != nil {
return err
}
c.plugins[meta.Name] = &CachedPlugin{
Metadata: meta,
Path: path,
LoadedAt: time.Now(),
}
return nil
}
// Remove 删除缓存的 plugin
func (c *Cache) Remove(name string) error {
c.mu.Lock()
defer c.mu.Unlock()
cached, ok := c.plugins[name]
if !ok {
return nil
}
os.Remove(cached.Path)
delete(c.plugins, name)
return nil
}
// List 返回所有缓存的 plugins
func (c *Cache) List() []plugin.PluginMetadata {
c.mu.RLock()
defer c.mu.RUnlock()
var result []plugin.PluginMetadata
for _, cached := range c.plugins {
result = append(result, cached.Metadata)
}
return result
}

View File

@@ -0,0 +1,70 @@
package plugin
import (
"context"
"log"
"sync"
"github.com/gotunnel/pkg/plugin"
"github.com/gotunnel/pkg/plugin/builtin"
"github.com/gotunnel/pkg/plugin/wasm"
)
// Manager 客户端 plugin 管理器
type Manager struct {
registry *plugin.Registry
cache *Cache
runtime *wasm.Runtime
mu sync.RWMutex
}
// NewManager 创建客户端 plugin 管理器
func NewManager(cacheDir string) (*Manager, error) {
ctx := context.Background()
cache, err := NewCache(cacheDir)
if err != nil {
return nil, err
}
runtime, err := wasm.NewRuntime(ctx)
if err != nil {
return nil, err
}
registry := plugin.NewRegistry()
m := &Manager{
registry: registry,
cache: cache,
runtime: runtime,
}
// 注册内置 plugins
if err := m.registerBuiltins(); err != nil {
return nil, err
}
return m, nil
}
// registerBuiltins 注册内置 plugins
// 注意: tcp, udp, http, https 是内置类型,直接在 tunnel 中处理
func (m *Manager) registerBuiltins() error {
// 注册 SOCKS5 plugin
if err := m.registry.RegisterBuiltin(builtin.NewSOCKS5Plugin()); err != nil {
return err
}
log.Println("[Plugin] Builtin plugins registered: socks5")
return nil
}
// GetHandler 返回指定代理类型的 handler
func (m *Manager) GetHandler(proxyType string) (plugin.ProxyHandler, error) {
return m.registry.Get(proxyType)
}
// Close 关闭管理器
func (m *Manager) Close(ctx context.Context) error {
return m.runtime.Close(ctx)
}

View File

@@ -14,6 +14,16 @@ import (
"github.com/hashicorp/yamux"
)
// 客户端常量
const (
dialTimeout = 10 * time.Second
localDialTimeout = 5 * time.Second
udpTimeout = 10 * time.Second
reconnectDelay = 5 * time.Second
disconnectDelay = 3 * time.Second
udpBufferSize = 65535
)
// Client 隧道客户端
type Client struct {
ServerAddr string
@@ -43,14 +53,14 @@ func (c *Client) Run() error {
for {
if err := c.connect(); err != nil {
log.Printf("[Client] Connect error: %v", err)
log.Printf("[Client] Reconnecting in 5s...")
time.Sleep(5 * time.Second)
log.Printf("[Client] Reconnecting in %v...", reconnectDelay)
time.Sleep(reconnectDelay)
continue
}
c.handleSession()
log.Printf("[Client] Disconnected, reconnecting...")
time.Sleep(3 * time.Second)
time.Sleep(disconnectDelay)
}
}
@@ -60,10 +70,10 @@ func (c *Client) connect() error {
var err error
if c.TLSEnabled && c.TLSConfig != nil {
dialer := &net.Dialer{Timeout: 10 * time.Second}
dialer := &net.Dialer{Timeout: dialTimeout}
conn, err = tls.DialWithDialer(dialer, "tcp", c.ServerAddr, c.TLSConfig)
} else {
conn, err = net.DialTimeout("tcp", c.ServerAddr, 10*time.Second)
conn, err = net.DialTimeout("tcp", c.ServerAddr, dialTimeout)
}
if err != nil {
return err
@@ -83,7 +93,10 @@ func (c *Client) connect() error {
}
var authResp protocol.AuthResponse
resp.ParsePayload(&authResp)
if err := resp.ParsePayload(&authResp); err != nil {
conn.Close()
return fmt.Errorf("parse auth response: %w", err)
}
if !authResp.Success {
conn.Close()
return fmt.Errorf("auth failed: %s", authResp.Message)
@@ -137,13 +150,18 @@ func (c *Client) handleStream(stream net.Conn) {
c.handleHeartbeat(stream)
case protocol.MsgTypeProxyConnect:
c.handleProxyConnect(stream, msg)
case protocol.MsgTypeUDPData:
c.handleUDPData(stream, msg)
}
}
// handleProxyConfig 处理代理配置
func (c *Client) handleProxyConfig(msg *protocol.Message) {
var cfg protocol.ProxyConfig
msg.ParsePayload(&cfg)
if err := msg.ParsePayload(&cfg); err != nil {
log.Printf("[Client] Parse proxy config error: %v", err)
return
}
c.mu.Lock()
c.rules = cfg.Rules
@@ -158,7 +176,10 @@ func (c *Client) handleProxyConfig(msg *protocol.Message) {
// handleNewProxy 处理新代理请求
func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) {
var req protocol.NewProxyRequest
msg.ParsePayload(&req)
if err := msg.ParsePayload(&req); err != nil {
log.Printf("[Client] Parse new proxy request error: %v", err)
return
}
var rule *protocol.ProxyRule
c.mu.RLock()
@@ -176,7 +197,7 @@ func (c *Client) handleNewProxy(stream net.Conn, msg *protocol.Message) {
}
localAddr := fmt.Sprintf("%s:%d", rule.LocalIP, rule.LocalPort)
localConn, err := net.DialTimeout("tcp", localAddr, 5*time.Second)
localConn, err := net.DialTimeout("tcp", localAddr, localDialTimeout)
if err != nil {
log.Printf("[Client] Connect %s error: %v", localAddr, err)
return
@@ -202,7 +223,7 @@ func (c *Client) handleProxyConnect(stream net.Conn, msg *protocol.Message) {
}
// 连接目标地址
targetConn, err := net.DialTimeout("tcp", req.Target, 10*time.Second)
targetConn, err := net.DialTimeout("tcp", req.Target, dialTimeout)
if err != nil {
c.sendProxyResult(stream, false, err.Error())
return
@@ -224,3 +245,62 @@ func (c *Client) sendProxyResult(stream net.Conn, success bool, message string)
msg, _ := protocol.NewMessage(protocol.MsgTypeProxyResult, result)
return protocol.WriteMessage(stream, msg)
}
// handleUDPData 处理 UDP 数据
func (c *Client) handleUDPData(stream net.Conn, msg *protocol.Message) {
defer stream.Close()
var packet protocol.UDPPacket
if err := msg.ParsePayload(&packet); err != nil {
return
}
// 查找对应的规则
rule := c.findRuleByPort(packet.RemotePort)
if rule == nil {
return
}
// 连接本地 UDP 服务
target := fmt.Sprintf("%s:%d", rule.LocalIP, rule.LocalPort)
conn, err := net.DialTimeout("udp", target, localDialTimeout)
if err != nil {
return
}
defer conn.Close()
// 发送数据到本地服务
conn.SetDeadline(time.Now().Add(udpTimeout))
if _, err := conn.Write(packet.Data); err != nil {
return
}
// 读取响应
buf := make([]byte, udpBufferSize)
n, err := conn.Read(buf)
if err != nil {
return
}
// 发送响应回服务端
respPacket := protocol.UDPPacket{
RemotePort: packet.RemotePort,
ClientAddr: packet.ClientAddr,
Data: buf[:n],
}
respMsg, _ := protocol.NewMessage(protocol.MsgTypeUDPData, respPacket)
protocol.WriteMessage(stream, respMsg)
}
// findRuleByPort 根据端口查找规则
func (c *Client) findRuleByPort(port int) *protocol.ProxyRule {
c.mu.RLock()
defer c.mu.RUnlock()
for i := range c.rules {
if c.rules[i].RemotePort == port {
return &c.rules[i]
}
}
return nil
}