Open shenweijiekdel opened 1 year ago
@shenweijiekdel Too late, but better than never.
import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"sync"

	"github.com/google/uuid"
)
// contextKey is a private key type for values stored in a context.Context.
// Using an unexported type keeps these keys from colliding with keys defined
// by other packages.
type contextKey string

const (
	// txKey is the context key under which RunTransaction stores the open transaction.
	txKey contextKey = "sql_tx"
	// txIDKey is the context key under which NewTxContext stores the transaction-context ID.
	txIDKey contextKey = "tx_id"
)
// Transactor queues database operations against a transaction context and
// later runs them all inside a single *sql.Tx.
//
// Typical use:
//
//	txCtx := t.NewTxContext(ctx)
//	_ = t.InTransaction(txCtx, op1)
//	_ = t.InTransaction(txCtx, op2)
//	err := t.RunTransaction(txCtx)
//
// Construct it with NewTransactor; the zero value is not usable.
type Transactor struct {
	conn *sql.DB

	// mu guards wraps. It is a pointer so that copies of a Transactor
	// (NewTransactor returns a value) share one lock, just as they share
	// the same underlying map.
	mu *sync.Mutex
	// wraps holds the functions queued per transaction context, keyed by the
	// exact context value passed to InTransaction.
	wraps map[context.Context][]func(ctx context.Context) error
}

// DBTX is the set of query methods shared by *sql.DB and *sql.Tx, so code
// can run statements without caring whether a transaction is active.
type DBTX interface {
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

// Compile-time checks that both connection kinds satisfy DBTX.
var (
	_ DBTX = (*sql.DB)(nil)
	_ DBTX = (*sql.Tx)(nil)
)

// NewTransactor returns a Transactor that opens its transactions on c.
func NewTransactor(c *sql.DB) Transactor {
	return Transactor{
		conn:  c,
		mu:    &sync.Mutex{},
		wraps: make(map[context.Context][]func(ctx context.Context) error),
	}
}

// NewTxContext derives a transaction context from ctx by attaching a new
// UUID string under txIDKey. Pass the returned context, unchanged, to both
// InTransaction and RunTransaction.
func (t *Transactor) NewTxContext(ctx context.Context) context.Context {
	return context.WithValue(ctx, txIDKey, uuid.NewString())
}

// hasTxID reports whether ctx carries a non-empty string ID under txIDKey,
// as set by NewTxContext.
func (t *Transactor) hasTxID(ctx context.Context) bool {
	id, ok := ctx.Value(txIDKey).(string)
	return ok && id != ""
}

// InTransaction queues txFunc to run inside the transaction that a later
// RunTransaction(ctx) call will open; it does not touch the database itself.
// It returns an error if ctx was not created by NewTxContext or if txFunc is nil.
func (t *Transactor) InTransaction(ctx context.Context, txFunc func(ctx context.Context) error) error {
	if !t.hasTxID(ctx) {
		return errors.New("not transaction context. Please create it with NewTxContext")
	}
	if txFunc == nil {
		// Reject early: a nil func would otherwise panic in RunTransaction.
		return errors.New("nil transaction function")
	}

	t.mu.Lock()
	defer t.mu.Unlock()
	// append on a missing key starts from a nil slice; no pre-initialization needed.
	t.wraps[ctx] = append(t.wraps[ctx], txFunc)
	return nil
}

// GetConn returns the *sql.Tx that RunTransaction stored in ctx, or the
// underlying *sql.DB when ctx carries no transaction. Queued functions must
// call it with the ctx they receive so their statements join the transaction.
//
// The result is a DBTX rather than *sql.DB because a *sql.DB can never
// represent an open transaction.
func (t *Transactor) GetConn(ctx context.Context) DBTX {
	if tx, ok := ctx.Value(txKey).(*sql.Tx); ok {
		return tx
	}
	return t.conn
}

// RunTransaction opens a transaction and runs every function queued for ctx,
// in the order they were added, with a context carrying the *sql.Tx. It then
// commits.
//
// If a function fails, the transaction is rolled back, the remaining functions
// are skipped, and that function's error is returned, together with the
// rollback error if the rollback also fails.
//
// The queue for ctx is cleared in all cases.
func (t *Transactor) RunTransaction(ctx context.Context) error {
	defer t.reset(ctx)

	// Snapshot the queue so the lock is not held while caller code runs
	// (a queued function might itself call InTransaction).
	t.mu.Lock()
	funcs := t.wraps[ctx]
	t.mu.Unlock()

	tx, err := t.conn.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Safety net for panics in queued functions. After a successful Commit or
	// an explicit Rollback this returns sql.ErrTxDone, which is ignored.
	defer func() { _ = tx.Rollback() }()

	txCtx := context.WithValue(ctx, txKey, tx)
	for _, fn := range funcs {
		if err := fn(txCtx); err != nil {
			// Stop at the first failure: continuing and then committing a
			// rolled-back tx would only yield sql.ErrTxDone and hide the cause.
			if rbErr := tx.Rollback(); rbErr != nil {
				return fmt.Errorf("%w (rollback failed: %v)", err, rbErr)
			}
			return err
		}
	}
	return tx.Commit()
}

// reset discards the functions queued for ctx.
func (t *Transactor) reset(ctx context.Context) {
	t.mu.Lock()
	defer t.mu.Unlock()
	delete(t.wraps, ctx)
}
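Inject it into any repository as a transactor. Example author repository (it assumes the code above lives in a package named `transactor`, which the repository file must import):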
package mysql
import (
"context"
"database/sql"
"fmt"
"github.com/bxcodec/go-clean-arch/domain"
)
// AuthorRepository accesses author data through a *sql.DB and holds a
// Transactor used to group several statements into one transaction.
type AuthorRepository struct {
	DB         *sql.DB
	transactor transactor.Transactor
}
// NewMysqlAuthorRepository will create an implementation of author.Repository
func NewAuthorRepository(db *sql.DB) *AuthorRepository {
return &AuthorRepository{
DB: db,
transactor:transactor.New(db)
}
}
// ExampleOfUsage shows how to group several statements into one transaction:
// create a transaction context, queue each operation with InTransaction, then
// execute them all with RunTransaction. Any failure is returned, wrapped.
func (m *AuthorRepository) ExampleOfUsage() error {
	// Create a new transaction context; queued functions are keyed by it.
	ctx := context.Background()
	txCtx := m.transactor.NewTxContext(ctx)

	// Queue the first operation. Inside the callback, use the callback's own
	// ctx (built by RunTransaction) with GetConn so the statement targets the
	// transaction rather than the plain connection.
	err := m.transactor.InTransaction(txCtx, func(ctx context.Context) error {
		// Example database operation 1
		_, err := m.transactor.GetConn(ctx).ExecContext(ctx, "query1")
		return err
	})
	if err != nil {
		return fmt.Errorf("Failed to add transaction function: %w", err)
	}

	err = m.transactor.InTransaction(txCtx, func(ctx context.Context) error {
		// Example database operation 2
		_, err := m.transactor.GetConn(ctx).ExecContext(ctx, "query2")
		return err
	})
	if err != nil {
		return fmt.Errorf("Failed to add transaction function: %w", err)
	}

	// Run the queued operations as a single transaction.
	if err := m.transactor.RunTransaction(txCtx); err != nil {
		return fmt.Errorf("Transaction failed: %w", err)
	}
	return nil
}
If the transaction only involves one repository, you can simply use Begin and Commit inside it. If a transaction has to span more than one repository function, you can handle it in the use case by defining a new repository-level transaction, as shown above.
I don't know if this method is clean or not, but at least it solves the problem.