diff --git a/internal/client/tunnel/client.go b/internal/client/tunnel/client.go index 8d9d82d..23cd804 100644 --- a/internal/client/tunnel/client.go +++ b/internal/client/tunnel/client.go @@ -3,10 +3,8 @@ package tunnel import ( "crypto/tls" "fmt" - "io" "log" "net" - "net/http" "os" "os/exec" "path/filepath" @@ -19,6 +17,7 @@ import ( "github.com/gotunnel/pkg/plugin/sign" "github.com/gotunnel/pkg/protocol" "github.com/gotunnel/pkg/relay" + "github.com/gotunnel/pkg/update" "github.com/hashicorp/yamux" ) @@ -794,38 +793,23 @@ func (c *Client) sendUpdateResult(stream net.Conn, success bool, message string) func (c *Client) performSelfUpdate(downloadURL string) error { log.Printf("[Client] Starting self-update from: %s", downloadURL) - // 创建临时文件 - tempDir := os.TempDir() - tempFile := filepath.Join(tempDir, "gotunnel_client_update") - - if runtime.GOOS == "windows" { - tempFile += ".exe" - } - - // 下载新版本 - if err := downloadUpdateFile(downloadURL, tempFile); err != nil { - return fmt.Errorf("download update: %w", err) - } - - // 设置执行权限 - if runtime.GOOS != "windows" { - if err := os.Chmod(tempFile, 0755); err != nil { - os.Remove(tempFile) - return fmt.Errorf("chmod: %w", err) - } + // 使用共享的下载和解压逻辑 + binaryPath, cleanup, err := update.DownloadAndExtract(downloadURL, "client") + if err != nil { + return err } + defer cleanup() // 获取当前可执行文件路径 currentPath, err := os.Executable() if err != nil { - os.Remove(tempFile) return fmt.Errorf("get executable: %w", err) } currentPath, _ = filepath.EvalSymlinks(currentPath) // Windows 需要特殊处理 if runtime.GOOS == "windows" { - return performWindowsClientUpdate(tempFile, currentPath, c.ServerAddr, c.Token, c.ID) + return performWindowsClientUpdate(binaryPath, currentPath, c.ServerAddr, c.Token, c.ID) } // Linux/Mac: 直接替换 @@ -836,16 +820,21 @@ func (c *Client) performSelfUpdate(downloadURL string) error { // 备份当前文件 if err := os.Rename(currentPath, backupPath); err != nil { - os.Remove(tempFile) return fmt.Errorf("backup current: %w", err) } - // 移动新文件 - if err := os.Rename(tempFile, currentPath); err != nil { + // 复制新文件(不能用 rename,可能跨文件系统) + if err := update.CopyFile(binaryPath, currentPath); err != nil { os.Rename(backupPath, currentPath) return fmt.Errorf("replace binary: %w", err) } + // 设置执行权限 + if err := os.Chmod(currentPath, 0755); err != nil { + os.Rename(backupPath, currentPath) + return fmt.Errorf("chmod: %w", err) + } + // 删除备份 os.Remove(backupPath) @@ -867,29 +856,6 @@ func (c *Client) stopAllPlugins() { c.pluginMu.Unlock() } -// downloadUpdateFile 下载更新文件 -func downloadUpdateFile(url, dest string) error { - client := &http.Client{Timeout: 10 * time.Minute} - resp, err := client.Get(url) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("download failed: %s", resp.Status) - } - - out, err := os.Create(dest) - if err != nil { - return err - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - return err -} - // performWindowsClientUpdate Windows 平台更新 func performWindowsClientUpdate(newFile, currentPath, serverAddr, token, id string) error { // 创建批处理脚本 diff --git a/internal/server/router/handler/helpers.go b/internal/server/router/handler/helpers.go index e960f0f..5ed0bc1 100644 --- a/internal/server/router/handler/helpers.go +++ b/internal/server/router/handler/helpers.go @@ -2,15 +2,13 @@ package handler import ( "fmt" - "io" - "net/http" "os" "os/exec" "path/filepath" "runtime" "strings" - "time" + "github.com/gotunnel/pkg/update" "github.com/gotunnel/pkg/version" ) @@ -115,37 +113,23 @@ func findAssetForPlatform(assets []version.ReleaseAsset, component, osName, arch // performSelfUpdate 执行自更新 func performSelfUpdate(downloadURL string, restart bool) error { - // 下载新版本 - tempDir := os.TempDir() - tempFile := filepath.Join(tempDir, "gotunnel_update_"+time.Now().Format("20060102150405")) - - if runtime.GOOS == "windows" { - tempFile += ".exe" - } - - if err := downloadFile(downloadURL, tempFile); err != nil { - return fmt.Errorf("download update: %w", err) - } - - // 设置执行权限 - if runtime.GOOS != "windows" { - if err := os.Chmod(tempFile, 0755); err != nil { - os.Remove(tempFile) - return fmt.Errorf("chmod: %w", err) - } + // 使用共享的下载和解压逻辑 + binaryPath, cleanup, err := update.DownloadAndExtract(downloadURL, "server") + if err != nil { + return err } + defer cleanup() // 获取当前可执行文件路径 currentPath, err := os.Executable() if err != nil { - os.Remove(tempFile) return fmt.Errorf("get executable: %w", err) } currentPath, _ = filepath.EvalSymlinks(currentPath) // Windows 需要特殊处理(运行中的文件无法直接替换) if runtime.GOOS == "windows" { - return performWindowsUpdate(tempFile, currentPath, restart) + return performWindowsUpdate(binaryPath, currentPath, restart) } // Linux/Mac: 直接替换 @@ -153,16 +137,21 @@ func performSelfUpdate(downloadURL string, restart bool) error { // 备份当前文件 if err := os.Rename(currentPath, backupPath); err != nil { - os.Remove(tempFile) return fmt.Errorf("backup current: %w", err) } - // 移动新文件 - if err := os.Rename(tempFile, currentPath); err != nil { + // 复制新文件(不能用 rename,可能跨文件系统) + if err := update.CopyFile(binaryPath, currentPath); err != nil { os.Rename(backupPath, currentPath) return fmt.Errorf("replace binary: %w", err) } + // 设置执行权限 + if err := os.Chmod(currentPath, 0755); err != nil { + os.Rename(backupPath, currentPath) + return fmt.Errorf("chmod new binary: %w", err) + } + // 删除备份 os.Remove(backupPath) @@ -210,26 +199,3 @@ func restartProcess(path string) { cmd.Start() os.Exit(0) } - -// downloadFile 下载文件 -func downloadFile(url, dest string) error { - client := &http.Client{Timeout: 10 * time.Minute} - resp, err := client.Get(url) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("download failed: %s", resp.Status) - } - - out, err := os.Create(dest) - if err != nil { - return err - } - defer out.Close() - - _, err = io.Copy(out, resp.Body) - return err -} diff --git a/pkg/update/update.go b/pkg/update/update.go new file mode 100644 index 0000000..c7145b6 --- /dev/null +++ b/pkg/update/update.go @@ -0,0 +1,253 @@ +package update + +import ( + "archive/tar" + "archive/zip" + "compress/gzip" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime" + "strings" + "time" +) + +// GetArchiveExt 根据 URL 获取压缩包扩展名 +func GetArchiveExt(url string) string { + if strings.HasSuffix(url, ".tar.gz") { + return ".tar.gz" + } + if strings.HasSuffix(url, ".zip") { + return ".zip" + } + // 默认根据平台 + if runtime.GOOS == "windows" { + return ".zip" + } + return ".tar.gz" +} + +// ExtractArchive 解压压缩包 +func ExtractArchive(archivePath, destDir string) error { + if strings.HasSuffix(archivePath, ".tar.gz") { + return ExtractTarGz(archivePath, destDir) + } + if strings.HasSuffix(archivePath, ".zip") { + return ExtractZip(archivePath, destDir) + } + return fmt.Errorf("unsupported archive format") +} + +// ExtractTarGz 解压 tar.gz 文件 +func ExtractTarGz(archivePath, destDir string) error { + file, err := os.Open(archivePath) + if err != nil { + return err + } + defer file.Close() + + gzReader, err := gzip.NewReader(file) + if err != nil { + return err + } + defer gzReader.Close() + + tarReader := tar.NewReader(gzReader) + + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + + targetPath := filepath.Join(destDir, header.Name) + + switch header.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(targetPath, 0755); err != nil { + return err + } + case tar.TypeReg: + outFile, err := os.Create(targetPath) + if err != nil { + return err + } + if _, err := io.Copy(outFile, tarReader); err != nil { + outFile.Close() + return err + } + outFile.Close() + } + } + + return nil +} + +// ExtractZip 解压 zip 文件 +func ExtractZip(archivePath, destDir string) error { + reader, err := zip.OpenReader(archivePath) + if err != nil { + return err + } + defer reader.Close() + + for _, file := range reader.File { + targetPath := filepath.Join(destDir, file.Name) + + if file.FileInfo().IsDir() { + if err := os.MkdirAll(targetPath, 0755); err != nil { + return err + } + continue + } + + if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + return err + } + + srcFile, err := file.Open() + if err != nil { + return err + } + + dstFile, err := os.Create(targetPath) + if err != nil { + srcFile.Close() + return err + } + + _, err = io.Copy(dstFile, srcFile) + srcFile.Close() + dstFile.Close() + if err != nil { + return err + } + } + + return nil +} + +// FindExtractedBinary 在解压目录中查找可执行文件 +func FindExtractedBinary(extractDir, component string) (string, error) { + var binaryPath string + prefix := "gotunnel-" + component + + err := filepath.Walk(extractDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + return nil + } + + name := info.Name() + // 匹配 gotunnel-server-* 或 gotunnel-client-* + if strings.HasPrefix(name, prefix) { + // 排除压缩包本身 + if !strings.HasSuffix(name, ".tar.gz") && !strings.HasSuffix(name, ".zip") { + binaryPath = path + return filepath.SkipAll + } + } + return nil + }) + + if err != nil && err != filepath.SkipAll { + return "", err + } + + if binaryPath == "" { + return "", fmt.Errorf("binary not found in archive") + } + + return binaryPath, nil +} + +// CopyFile 复制文件 +func CopyFile(src, dst string) error { + srcFile, err := os.Open(src) + if err != nil { + return err + } + defer srcFile.Close() + + dstFile, err := os.Create(dst) + if err != nil { + return err + } + defer dstFile.Close() + + _, err = io.Copy(dstFile, srcFile) + return err +} + +// DownloadFile 下载文件 +func DownloadFile(url, dest string) error { + client := &http.Client{Timeout: 10 * time.Minute} + resp, err := client.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("download failed: %s", resp.Status) + } + + out, err := os.Create(dest) + if err != nil { + return err + } + defer out.Close() + + _, err = io.Copy(out, resp.Body) + return err +} + +// DownloadAndExtract 下载并解压压缩包,返回解压后的可执行文件路径 +func DownloadAndExtract(downloadURL, component string) (binaryPath string, cleanup func(), err error) { + tempDir := os.TempDir() + timestamp := time.Now().Format("20060102150405") + archivePath := filepath.Join(tempDir, "gotunnel_update_"+timestamp+GetArchiveExt(downloadURL)) + + if err := DownloadFile(downloadURL, archivePath); err != nil { + return "", nil, fmt.Errorf("download update: %w", err) + } + + extractDir := filepath.Join(tempDir, "gotunnel_extract_"+timestamp) + if err := os.MkdirAll(extractDir, 0755); err != nil { + os.Remove(archivePath) + return "", nil, fmt.Errorf("create extract dir: %w", err) + } + + cleanup = func() { + os.Remove(archivePath) + os.RemoveAll(extractDir) + } + + if err := ExtractArchive(archivePath, extractDir); err != nil { + cleanup() + return "", nil, fmt.Errorf("extract archive: %w", err) + } + + binaryPath, err = FindExtractedBinary(extractDir, component) + if err != nil { + cleanup() + return "", nil, fmt.Errorf("find binary: %w", err) + } + + // 设置执行权限 + if runtime.GOOS != "windows" { + if err := os.Chmod(binaryPath, 0755); err != nil { + cleanup() + return "", nil, fmt.Errorf("chmod: %w", err) + } + } + + return binaryPath, cleanup, nil +}