golang之sftp小工具

golang之sftp小工具


前言

使用Xshell在几台服务器之间传输数据时,如果需要传输大量带目录层级的文件,用Xshell很不方便,所以用golang写了一个小工具,可以直接在两台服务器之间传输文件(保留原有的目录结构),也支持在本地和服务器之间上传、下载。


一、代码

package main

import (
	"flag"
	"fmt"
	"io/ioutil"
	"os"
	"path"
	"path/filepath"
	"strconv"
	"strings"
	// NOTE(review): "sync" is not referenced anywhere in this file; Go rejects
	// unused imports, so it must be removed (or used) for the file to compile.
	"sync"
	"time"

	"github.com/pkg/sftp"
	"golang.org/x/crypto/ssh"
)

// Command-line flags.
//
// Remote targets (-s, -d) use the form user:passwd:ip:port:dirPath.
// When -L (a local path) is set, -s means "download the remote -s into -L" and
// -d means "upload -L to the remote -d". Without -L, both -s and -d are remote
// and files are copied server to server.
var (
	source = flag.String("s", "", "")
	dest   = flag.String("d", "", "")
	help   = flag.Bool("h", false, "")
	local  = flag.String("L", "", "")

	// NOTE(review): -c, -e and -i are registered but never read by the visible
	// code; they currently have no effect.
	contiuneIfError = flag.Bool("c", true, "")
	exclude         = flag.String("e", "", "")
	include         = flag.String("i", "", "")
)

// Transfer statistics: incremented by transfer/download/upload and printed at
// the end of main. They are plain package-level variables with no locking, so
// they are only safe while the copy functions run on a single goroutine.
var (
	successCount int
	errorCount   int
	errorPaths   = make([]string, 0)
)

// main parses the flags and runs one of three modes (priority -L > -s > -d):
//   - no -L:          copy remote -s to remote -d (transfer);
//   - -L and -s:      download remote -s into the local path -L (download);
//   - -L without -s:  upload the local path -L to remote -d (upload).
//
// Remote targets have the form user:passwd:ip:port:dirPath; everything after the
// fourth ':' is the path. Malformed input or a failed connection exits with
// status 1. At the end it prints the number of succeeded and failed files,
// followed by the failed paths.
func main() {
	flag.Parse()
	if *help {
		fmt.Println("示例: ./sftp.exe -s 用户名:密码:10.X.X.X:22:/dev/test -d 用户名:密码:10.X.X.X:22:/opt")
		fmt.Println("-s user:passwd:ip:port:dirPath", "源目标服务器信息")
		fmt.Println("-d user:passwd:ip:port:dirPath", "目的地服务器信息")
		fmt.Println("-L 路径 注:-L不为空时,当有-s时表示从远程下载到本地 当有-d时表示上传")
		fmt.Println("优先级:-L>-s>-d")
		os.Exit(0)
	}

	switch {
	case *local == "":
		src, srcOK := parseTarget(*source)
		dst, dstOK := parseTarget(*dest)
		if !srcOK || !dstOK {
			exitWithUsageError()
		}
		sourceClient, err := connect(src.user, src.password, src.host, src.port)
		exitOnConnectError(src.host, err)
		defer sourceClient.Close()
		destClient, err := connect(dst.user, dst.password, dst.host, dst.port)
		exitOnConnectError(dst.host, err)
		defer destClient.Close()
		fmt.Println("开始传输")
		transfer(sourceClient, destClient, src.dir, dst.dir)

	case *source != "":
		src, ok := parseTarget(*source)
		if !ok {
			exitWithUsageError()
		}
		sourceClient, err := connect(src.user, src.password, src.host, src.port)
		exitOnConnectError(src.host, err)
		defer sourceClient.Close()
		fmt.Println("开始下载")
		download(sourceClient, src.dir, *local)

	default:
		// -L is set and -s is empty: upload. An empty -d is rejected by parseTarget
		// instead of panicking on a short field slice.
		dst, ok := parseTarget(*dest)
		if !ok {
			exitWithUsageError()
		}
		destClient, err := connect(dst.user, dst.password, dst.host, dst.port)
		exitOnConnectError(dst.host, err)
		defer destClient.Close()
		fmt.Println("开始上传")
		upload(*local, dst.dir, destClient)
	}

	fmt.Println("本次传输成功文件数量:", successCount, " 失败文件数量:", errorCount)
	if errorCount != 0 {
		fmt.Println("错误文件路径:")
		for _, p := range errorPaths {
			fmt.Println(p)
		}
	}
}

// sftpTarget holds the fields of a "user:passwd:ip:port:dirPath" argument.
type sftpTarget struct {
	user, password, host, dir string
	port                      int
}

// parseTarget splits spec into its five fields. SplitN keeps any ':' inside the
// path. ok is false when a field is missing or empty, or the port is not a number.
func parseTarget(spec string) (t sftpTarget, ok bool) {
	parts := strings.SplitN(spec, ":", 5)
	if len(parts) != 5 {
		return sftpTarget{}, false
	}
	for _, part := range parts {
		if part == "" {
			return sftpTarget{}, false
		}
	}
	port, err := strconv.Atoi(parts[3])
	if err != nil {
		return sftpTarget{}, false
	}
	return sftpTarget{user: parts[0], password: parts[1], host: parts[2], port: port, dir: parts[4]}, true
}

// exitWithUsageError prints the standard bad-input hint and exits with a
// non-zero status so scripts can detect the failure.
func exitWithUsageError() {
	fmt.Println("输入格式不规范,使用-h查看使用说明")
	os.Exit(1)
}

// exitOnConnectError aborts with status 1 when connecting to host failed. A bad
// address or password is then reported up front, instead of as a nil-client
// panic inside the copy functions.
func exitOnConnectError(host string, err error) {
	if err != nil {
		fmt.Println("连接失败", host, err)
		os.Exit(1)
	}
}

// connect dials host:port over SSH with password authentication (30s timeout)
// and starts an SFTP session on that connection.
//
// NOTE(security): host keys are NOT verified (ssh.InsecureIgnoreHostKey), so the
// connection is open to man-in-the-middle attacks. Use ssh.FixedHostKey or a
// known_hosts-based callback on untrusted networks.
//
// If the SFTP session cannot be started, the SSH connection is closed before
// returning, so no connection is leaked. Errors are wrapped with the address.
func connect(user, password, host string, port int) (*sftp.Client, error) {
	clientConfig := &ssh.ClientConfig{
		User:            user,
		Auth:            []ssh.AuthMethod{ssh.Password(password)},
		Timeout:         30 * time.Second,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(), // ssh.FixedHostKey(hostKey),
	}

	addr := fmt.Sprintf("%s:%d", host, port)
	sshClient, err := ssh.Dial("tcp", addr, clientConfig)
	if err != nil {
		return nil, fmt.Errorf("ssh dial %s: %w", addr, err)
	}

	sftpClient, err := sftp.NewClient(sshClient)
	if err != nil {
		// Nothing else holds sshClient; close it so the TCP/SSH connection doesn't linger.
		sshClient.Close()
		return nil, fmt.Errorf("start sftp session on %s: %w", addr, err)
	}
	return sftpClient, nil
}

// transfer copies sourcePath from sourceClient into the directory destPath on
// destClient (remote -> remote).
//
// A directory is recreated as destPath/<base name> and its entries are copied
// recursively. A regular file is copied to destPath/<base name>.
//
// Each file is counted in successCount or errorCount (failed paths are appended
// to errorPaths). A file or directory that cannot be read is recorded as a
// failure and skipped, instead of terminating the whole run.
func transfer(sourceClient, destClient *sftp.Client, sourcePath, destPath string) {
	defer func() {
		if err := recover(); err != nil {
			fmt.Println("程序异常退出", err)
		}
	}()

	info, err := sourceClient.Stat(sourcePath)
	if err != nil {
		fmt.Println("源文件|目录打开失败", err)
		errorCount++
		errorPaths = append(errorPaths, sourcePath)
		return
	}

	if info.IsDir() {
		nextDestPath := path.Join(destPath, info.Name())
		// Best effort: Mkdir also fails when the directory already exists, which is
		// fine. Any real failure surfaces below as per-file errors.
		_ = destClient.Mkdir(nextDestPath)
		children, err := sourceClient.ReadDir(sourcePath)
		if err != nil {
			fmt.Println("源文件|目录打开失败", err)
			errorCount++
			errorPaths = append(errorPaths, sourcePath)
			return
		}
		for _, child := range children {
			transfer(sourceClient, destClient, path.Join(sourcePath, child.Name()), nextDestPath)
		}
		return
	}

	destFileName := path.Join(destPath, info.Name())
	if err := transferFile(sourceClient, destClient, sourcePath, destFileName); err != nil {
		errorCount++
		errorPaths = append(errorPaths, sourcePath)
		fmt.Println(sourcePath, " -> ", destFileName, "失败", err)
		return
	}
	successCount++
	fmt.Println(sourcePath, " -> ", destFileName, "成功")
}

// transferFile streams one regular file from sourceClient to destFileName on
// destClient. It returns the first error from opening the source, creating the
// destination, copying, or closing the destination.
func transferFile(sourceClient, destClient *sftp.Client, sourcePath, destFileName string) (err error) {
	src, err := sourceClient.Open(sourcePath)
	if err != nil {
		return err
	}
	defer src.Close()

	dst, err := destClient.Create(destFileName)
	if err != nil {
		return err
	}
	defer func() {
		// If Close fails the copy may be incomplete, so don't report success.
		if cerr := dst.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	_, err = src.WriteTo(dst)
	return err
}

// download copies sourcePath from the remote server into the local directory
// destPath (remote -> local).
//
// A directory is recreated locally as destPath/<base name> (mode 0755 before
// umask) and its entries are downloaded recursively. A regular file is written
// to destPath/<base name>.
//
// Each file is counted in successCount or errorCount (failed paths are appended
// to errorPaths). A remote file or directory that cannot be read is recorded as
// a failure and skipped, instead of terminating the whole run.
func download(sourceClient *sftp.Client, sourcePath string, destPath string) {
	defer func() {
		if err := recover(); err != nil {
			fmt.Println("程序异常", err)
		}
	}()

	info, err := sourceClient.Stat(sourcePath)
	if err != nil {
		fmt.Println("源文件|目录打开失败", err)
		errorCount++
		errorPaths = append(errorPaths, sourcePath)
		return
	}

	if info.IsDir() {
		// destPath is a local path, so join with the OS separator.
		nextDestPath := filepath.Join(destPath, info.Name())
		// 0755 (octal), not 755: the decimal literal 755 is 0o1363, which sets the
		// sticky bit and leaves the owner without read permission on the directory.
		if err := os.Mkdir(nextDestPath, 0755); err != nil && !os.IsExist(err) {
			fmt.Println("目录创建失败", nextDestPath, err)
		}
		children, err := sourceClient.ReadDir(sourcePath)
		if err != nil {
			fmt.Println("源文件|目录打开失败", err)
			errorCount++
			errorPaths = append(errorPaths, sourcePath)
			return
		}
		for _, child := range children {
			// sourcePath is remote, so it keeps slash-separated path.Join.
			download(sourceClient, path.Join(sourcePath, child.Name()), nextDestPath)
		}
		return
	}

	destFileName := filepath.Join(destPath, info.Name())
	if err := downloadFile(sourceClient, sourcePath, destFileName); err != nil {
		errorCount++
		errorPaths = append(errorPaths, sourcePath)
		fmt.Println(sourcePath, " -> ", destFileName, "失败", err)
		return
	}
	successCount++
	fmt.Println(sourcePath, " -> ", destFileName, "成功")
}

// downloadFile streams one remote file to the local file destFileName, which is
// created or truncated. It returns the first error from opening, creating,
// copying, or closing the local file.
func downloadFile(sourceClient *sftp.Client, sourcePath, destFileName string) (err error) {
	src, err := sourceClient.Open(sourcePath)
	if err != nil {
		return err
	}
	defer src.Close()

	dst, err := os.Create(destFileName)
	if err != nil {
		return err
	}
	defer func() {
		// If Close fails the local file may be incomplete, so don't report success.
		if cerr := dst.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	_, err = src.WriteTo(dst)
	return err
}

// upload copies the local file or directory soucePath into the remote directory
// destPath (local -> remote).
//
// A directory is recreated remotely as destPath/<base name> and its entries are
// uploaded recursively. A regular file is streamed to destPath/<base name>
// without loading it into memory.
//
// Each file is counted in successCount or errorCount (failed paths are appended
// to errorPaths). A local file or directory that cannot be read is recorded as a
// failure and skipped.
func upload(soucePath, destPath string, destSftpClient *sftp.Client) {
	defer func() {
		if err := recover(); err != nil {
			fmt.Println("程序异常", err)
		}
	}()

	info, err := os.Stat(soucePath)
	if err != nil {
		fmt.Println("源文件|目录打开失败", err)
		errorCount++
		errorPaths = append(errorPaths, soucePath)
		return
	}

	if info.IsDir() {
		nextDestPath := path.Join(destPath, info.Name())
		// Best effort: Mkdir also fails when the directory already exists. Real
		// failures surface below as per-file errors.
		_ = destSftpClient.Mkdir(nextDestPath)
		localFiles, err := ioutil.ReadDir(soucePath)
		if err != nil {
			fmt.Println("源文件|目录打开失败", err)
			errorCount++
			errorPaths = append(errorPaths, soucePath)
			return
		}
		for _, localFile := range localFiles {
			// soucePath is a local path, so join with the OS separator.
			upload(filepath.Join(soucePath, localFile.Name()), nextDestPath, destSftpClient)
		}
		return
	}

	destFileName := path.Join(destPath, info.Name())
	if err := uploadFile(soucePath, destFileName, destSftpClient); err != nil {
		errorCount++
		errorPaths = append(errorPaths, soucePath)
		fmt.Println(soucePath, " -> ", destFileName, "失败", err)
		return
	}
	successCount++
	fmt.Println(soucePath, " -> ", destFileName, "成功")
}

// uploadFile streams the local file localPath to destFileName on client. It
// returns the first error from opening, creating, copying, or closing the
// remote file.
func uploadFile(localPath, destFileName string, client *sftp.Client) (err error) {
	src, err := os.Open(localPath)
	if err != nil {
		return err
	}
	defer src.Close()

	dst, err := client.Create(destFileName)
	if err != nil {
		return err
	}
	defer func() {
		// If Close fails the upload may be incomplete, so don't report success.
		if cerr := dst.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	// ReadFrom streams in chunks, unlike the previous ioutil.ReadAll, which
	// buffered the whole file in memory.
	_, err = dst.ReadFrom(src)
	return err
}

可以在Windows上使用,也可以在Linux上使用。执行 ./sftp.exe -h(Linux下如 ./sftp -h)可以查看使用帮助。可执行文件请自行使用 go build 编译。

总结

可以优化的地方:开启多个协程并发传输以提高速度;补充打印更详细的错误信息;实现文件过滤功能(代码中已声明的 -e/-i 参数目前尚未生效)。

上一篇:异步复制文件


下一篇:批量替换MarkDown文档中指定的字符串