40分钟学 Go 语言高并发:Pipeline模式(一)

Pipeline模式

一、课程概述

学习要点 重要程度 掌握目标
流水线设计 ★★★★★ 掌握Pipeline基本结构和设计原则
扇入扇出 ★★★★☆ 理解并实现多输入多输出的Pipeline
错误传播 ★★★★★ 掌握Pipeline中的错误处理机制
吞吐量优化 ★★★★☆ 学会优化Pipeline的性能和吞吐量

二、Pipeline模式基础

让我们首先实现一个基础的Pipeline框架:

package pipeline

import (
    "context"
    "fmt"
    "sync"
)

// Stage 代表Pipeline中的一个阶段
type Stage func(ctx context.Context, in <-chan interface{}) (<-chan interface{}, error)

// Pipeline 代表一个完整的处理流水线
type Pipeline struct {
    stages []Stage
    errCh  chan error
}

// New 创建新的Pipeline
func New(stages ...Stage) *Pipeline {
    return &Pipeline{
        stages: stages,
        errCh:  make(chan error, len(stages)),
    }
}

// Run 运行Pipeline
func (p *Pipeline) Run(ctx context.Context, in <-chan interface{}) (<-chan interface{}, <-chan error) {
    out := in
    var err error

    // 按顺序执行每个Stage
    for i, stage := range p.stages {
        out, err = stage(ctx, out)
        if err != nil {
            p.errCh <- fmt.Errorf("stage %d failed: %v", i, err)
            close(p.errCh)
            return nil, p.errCh
        }
    }

    return out, p.errCh
}

// Merge 合并多个channel的数据(扇入)
func Merge(ctx context.Context, channels ...<-chan interface{}) <-chan interface{} {
    var wg sync.WaitGroup
    out := make(chan interface{})

    // 为每个输入channel启动一个goroutine
    output := func(c <-chan interface{}) {
        defer wg.Done()
        for n := range c {
            select {
            case out <- n:
            case <-ctx.Done():
                return
            }
        }
    }

    wg.Add(len(channels))
    for _, c := range channels {
        go output(c)
    }

    // 当所有输入channel都关闭后,关闭输出channel
    go func() {
        wg.Wait()
        close(out)
    }()

    return out
}

// Split 将一个channel的数据分配给多个处理goroutine(扇出)
func Split(ctx context.Context, in <-chan interface{}, n int) []<-chan interface{} {
    outs := make([]<-chan interface{}, n)
    for i := 0; i < n; i++ {
        outs[i] = make(chan interface{})
    }

    distribute := func(ch chan<- interface{}) {
        defer close(ch)
        for n := range in {
            select {
            case ch <- n:
            case <-ctx.Done():
                return
            }
        }
    }

    for i := 0; i < n; i++ {
        go distribute(outs[i].(chan interface{}))
    }

    return outs
}

让我们实现一个具体的示例 - 数字处理Pipeline:

package main

import (
    "context"
    "fmt"
    "log"
    "time"
)

// 生成器,生成1到n的数字
func generator(ctx context.Context, n int) (<-chan interface{}, error) {
    out := make(chan interface{})
    go func() {
        defer close(out)
        for i := 1; i <= n; i++ {
            select {
            case out <- i:
            case <-ctx.Done():
                return
            }
        }
    }()
    return out, nil
}

// 平方计算Stage
func square(ctx context.Context, in <-chan interface{}) (<-chan interface{}, error) {
    out := make(chan interface{})
    go func() {
        defer close(out)
        for n := range in {
            num, ok := n.(int)
            if !ok {
                continue
            }
            select {
            case out <- num * num:
            case <-ctx.Done():
                return
            }
        }
    }()
    return out, nil
}

// 过滤Stage:只保留能被3整除的数
func filter(ctx context.Context, in <-chan interface{}) (<-chan interface{}, error) {
    out := make(chan interface{})
    go func() {
        defer close(out)
        for n := range in {
            num, ok := n.(int)
            if !ok {
                continue
            }
            if num%3 == 0 {
                select {
                case out <- num:
                case <-ctx.Done():
                    return
                }
            }
        }
    }()
    return out, nil
}

func main() {
    // 创建Context
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    // 创建Pipeline
    p := New(
        func(ctx context.Context, in <-chan interface{}) (<-chan interface{}, error) {
            return square(ctx, in)
        },
        func(ctx context.Context, in <-chan interface{}) (<-chan interface{}, error) {
            return filter(ctx, in)
        },
    )

    // 生成输入数据
    input, err := generator(ctx, 10)
    if err != nil {
        log.Fatalf("Generator failed: %v", err)
    }

    // 运行Pipeline
    output, errCh := p.Run(ctx, input)

    // 处理输出和错误
    for {
        select {
        case n, ok := <-output:
            if !ok {
                return
            }
            fmt.Printf("Output: %v\n", n)
        case err := <-errCh:
            if err != nil {
                log.Printf("Pipeline error: %v", err)
                return
            }
        case <-ctx.Done():
            fmt.Println("Pipeline cancelled")
            return
        }
    }
}

三、Pipeline流程图

在这里插入图片描述

四、高级Pipeline实现

让我们实现一个更复杂的Pipeline,包含错误处理和性能优化:

package pipeline

import (
    "context"
    "fmt"
    "runtime"
    "sync"
    "time"
)

// Result 包含处理结果和错误信息
type Result struct {
    Value interface{}
    Err   error
}

// StageFunc 定义处理函数类型
type StageFunc func(interface{}) (interface{}, error)

// Options Pipeline配置选项
type Options struct {
    BufferSize  int           // channel缓冲区大小
    NumWorkers  int           // 工作goroutine数量
    Timeout     time.Duration // 处理超时时间
    RetryCount  int           // 重试次数
    RetryDelay  time.Duration // 重试延迟
}

// AdvancedPipeline 高级Pipeline实现
type AdvancedPipeline struct {
    stages   []StageFunc
    options  Options
    metrics  *Metrics
    errHandler func(error) error
}

// Metrics 性能指标
type Metrics struct {
    mu            sync.RWMutex
    processedItems int64
    errorCount     int64
    avgProcessTime time.Duration
}

// NewAdvanced 创建高级Pipeline
func NewAdvanced(opts Options, stages ...StageFunc) *AdvancedPipeline {
    if opts.NumWorkers <= 0 {
        opts.NumWorkers = runtime.NumCPU()
    }

    return &AdvancedPipeline{
        stages:  stages,
        options: opts,
        metrics: &Metrics{},
    }
}

// SetErrorHandler 设置错误处理函数
func (p *AdvancedPipeline) SetErrorHandler(handler func(error) error) {
    p.errHandler = handler
}

// Process 处理数据
func (p *AdvancedPipeline) Process(ctx context.Context, input <-chan interface{}) (<-chan Result, error) {
    if len(p.stages) == 0 {
        return nil, fmt.Errorf("no stages defined")
    }

    output := make(chan Result, p.options.BufferSize)
    var wg sync.WaitGroup

    // 创建工作池
    for i := 0; i < p.options.NumWorkers; i++ {
        wg.Add(1)
        go func(workerID int) {
            defer wg.Done()
            p.worker(ctx, workerID, input, output)
        }(i)
    }

    // 等待所有工作完成后关闭输出channel
    go func() {
        wg.Wait()
        close(output)
    }()

    return output, nil
}

// worker 工作goroutine
func (p *AdvancedPipeline) worker(ctx context.Context, id int, input <-chan interface{}, output chan<- Result) {
    for data := range input {
        // 处理每个输入项
        startTime := time.Now()
        result := p.processItem(ctx, data)

        // 更新指标
        p.updateMetrics(startTime, result.Err != nil)

        // 发送结果
        select {
        case output <- result:
        case <-ctx.Done():
            return
        }
    }
}

// processItem 处理单个数据项
func (p *AdvancedPipeline) processItem(ctx context.Context, data interface{}) Result {
    var value interface{} = data
    var err error

    // 执行每个阶段
    for i, stage := range p.stages {
        value, err = p.executeStageWithRetry(ctx, stage, value)
        if err != nil {
            if p.errHandler != nil {
                if handlerErr := p.errHandler(err); handlerErr != nil {
                    err = fmt.Errorf("stage %d failed: %v (handler error: %v)", i, err, handlerErr)
                }
            }
            return Result{Err: err}
        }
    }

    return Result{Value: value}
}

// executeStageWithRetry 带重试的阶段执行
func (p *AdvancedPipeline) executeStageWithRetry(ctx context.Context, stage StageFunc, data interface{}) (interface{}, error) {
    var lastErr error

    for attempt := 0; attempt <= p.options.RetryCount; attempt++ {
        // 创建带超时的Context
        timeoutCtx, cancel := context.WithTimeout(ctx, p.options.Timeout)
        
        // 执行阶段处理
        done := make(chan struct{})
        var result interface{}
        var err error

        go func() {
            result, err = stage(data)
            close(done)
        }()

        // 等待处理完成或超时
        select {
        case <-done:
            cancel()
            if err == nil {
                return result, nil
            }
            lastErr = err
        case <-timeoutCtx.Done():
            cancel()
            lastErr = fmt.Errorf("stage timeout")
        }

        // 如果不是最后一次重试,则等待后继续
        if attempt < p.options.RetryCount {
            select {
            case <-time.After(p.options.RetryDelay):
            case <-ctx.Done():
                return nil, ctx.Err()
            }
        }
    }

    return nil, fmt.Errorf("all retry attempts failed: %v", lastErr)
}

// updateMetrics 更新性能指标
func (p *AdvancedPipeline) updateMetrics(startTime time.Time, hasError bool) {
    p.metrics.mu.Lock()
    defer p.metrics.mu.Unlock()

    p.metrics.processedItems++
    if hasError {
        p.metrics.errorCount++
    }

    // 更新平均处理时间
    processingTime := time.Since(startTime)
    if p.metrics.avgProcessTime == 0 {
        p.metrics.avgProcessTime = processingTime
    } else {
        p.metrics.avgProcessTime = (p.metrics.avgProcessTime + processingTime) / 2
    }
}

// GetMetrics 获取性能指标
func (p *AdvancedPipeline) GetMetrics() (int64, int64, time.Duration) {
    p.metrics.mu.RLock()
    defer p.metrics.mu.RUnlock()
    return p.metrics.processedItems, p.metrics.errorCount, p.metrics.avgProcessTime
}

// Reset 重置性能指标
func (p *AdvancedPipeline) Reset() {
    p.metrics.mu.Lock()
    defer p.metrics.mu.Unlock()
    p.metrics.processedItems = 0
    p.metrics.errorCount = 0
    p.metrics.avgProcessTime = 0
}

// WithBufferSize 设置缓冲区大小
func (p *AdvancedPipeline) WithBufferSize(size int
上一篇:Ubuntu20.04下配置Cuda12.1+Cudnn


下一篇:音视频基础扫盲之认识PCM(Pulse Code Modulation,脉冲编码调制)