ecodeclub / eorm

简单 ORM 框架
Apache License 2.0
191 stars 64 forks source link

分库分表:ShardingUpdater 实现 #201

Closed flycash closed 1 year ago

flycash commented 1 year ago

仅限中文

使用场景

现在可以考虑支持 Updater了。差不多可以直接参考 Select 语句和 Inserter 的实现。

在 UPDATE 语句里面,要考虑这个情况:

在集成测试里面要考虑到:

你设置的的 Go 环境?

上传 go env 的结果

Stone-afk commented 1 year ago

+1

Stone-afk commented 1 year ago

ShardingKeys() []string 是否改为 ShardingKey() string

Stone-afk commented 1 year ago

难道出现多 ShardingKeys 的情况?

Stone-afk commented 1 year ago
// Build returns UPDATE []sharding.Query
func (s *ShardingUpdater[T]) Build(ctx context.Context) ([]sharding.Query, error) {
    //t := new(T)
    if s.table == nil {
        s.table = new(T)
    }
    var err error
    if s.meta == nil {
        s.meta, err = s.metaRegistry.Get(s.table)
        if err != nil {
            return nil, err
        }
    }

    shardingRes, err := s.findDst(ctx)
    if err != nil {
        return nil, err
    }

    dsDBMap, err := mapx.NewTreeMap[sharding.Dst, *mapx.TreeMap[sharding.Dst, []*T]](sharding.CompareDSDB)
    if err != nil {
        return nil, err
    }

    for _, dst := range shardingRes.Dsts {
        dsDBVal, ok := dsDBMap.Get(dst)
        if !ok {
            dsDBVal, err = mapx.NewTreeMap[sharding.Dst, []*T](sharding.CompareDSDBTab)
            if err != nil {
                return nil, err
            }
            err = dsDBVal.Put(dst, []*T{s.table})
            if err != nil {
                return nil, err
            }
        } else {
            valList, _ := dsDBVal.Get(dst)
            valList = append(valList, s.table)
            err = dsDBVal.Put(dst, valList)
            if err != nil {
                return nil, err
            }
        }
        err = dsDBMap.Put(dst, dsDBVal)
        if err != nil {
            return nil, err
        }
    }

    // 针对每一个目标库,生成一个 update 语句
    dsDBKeys := dsDBMap.Keys()
    res := make([]sharding.Query, 0, len(dsDBKeys))
    defer bytebufferpool.Put(s.buffer)
    for _, dsDBKey := range dsDBKeys {
        ds := dsDBKey.Name
        db := dsDBKey.DB
        dsDBVal, _ := dsDBMap.Get(dsDBKey)
        for _, dsDBTabKey := range dsDBVal.Keys() {
            //dsDBTabVals, _ := dsDBVal.Get(dsDBTabKey)
            err = s.buildQuery(db, dsDBTabKey.Table)
            if err != nil {
                return nil, err
            }
        }
        res = append(res, sharding.Query{
            SQL:        s.buffer.String(),
            Args:       s.args,
            DB:         db,
            Datasource: ds,
        })
        s.args = nil
        s.buffer.Reset()
    }

    return res, nil

}

func (s *ShardingUpdater[T]) buildQuery(db, tbl string) error {
    var err error

    s.val = s.valCreator.NewPrimitiveValue(s.table, s.meta)
    s.args = make([]interface{}, 0, len(s.meta.Columns))

    s.writeString("UPDATE ")
    s.quote(db)
    s.writeByte('.')
    s.quote(tbl)
    s.writeString(" SET ")
    if len(s.assigns) == 0 {
        err = s.buildDefaultColumns()
    } else {
        err = s.buildAssigns()
    }
    if err != nil {
        return err
    }

    if len(s.where) > 0 {
        s.writeString(" WHERE ")
        err = s.buildPredicates(s.where)
        if err != nil {
            return err
        }
    }
    s.end()

    return nil
}

func (s *ShardingUpdater[T]) findDst(ctx context.Context) (sharding.Result, error) {
    if len(s.where) > 0 {
        pre := s.where[0]
        for i := 1; i < len(s.where)-1; i++ {
            pre = pre.And(s.where[i])
        }
        return s.findDstByPredicate(ctx, pre)
    }
    res := sharding.Result{
        Dsts: s.meta.ShardingAlgorithm.Broadcast(ctx),
    }
    return res, nil
}

Build 参考 insert, findDst 参考 select