diff --git a/cmd/doctor.go b/cmd/doctor.go index ba17e06de2..3807623ccd 100644 --- a/cmd/doctor.go +++ b/cmd/doctor.go @@ -20,7 +20,6 @@ import ( "forgejo.org/services/doctor" "github.com/urfave/cli/v2" - "xorm.io/xorm" ) // CmdDoctor represents the available doctor sub-command. @@ -120,7 +119,7 @@ func runRecreateTable(ctx *cli.Context) error { args := ctx.Args() names := make([]string, 0, ctx.NArg()) - for i := 0; i < ctx.NArg(); i++ { + for i := range ctx.NArg() { names = append(names, args.Get(i)) } @@ -130,11 +129,17 @@ func runRecreateTable(ctx *cli.Context) error { } recreateTables := migrate_base.RecreateTables(beans...) - return db.InitEngineWithMigration(stdCtx, func(x *xorm.Engine) error { - if err := migrations.EnsureUpToDate(x); err != nil { + return db.InitEngineWithMigration(stdCtx, func(x db.Engine) error { + engine, err := db.GetMasterEngine(x) + if err != nil { return err } - return recreateTables(x) + + if err := migrations.EnsureUpToDate(engine); err != nil { + return err + } + + return recreateTables(engine) }) } diff --git a/cmd/migrate.go b/cmd/migrate.go index ab291cfb66..c192ca1966 100644 --- a/cmd/migrate.go +++ b/cmd/migrate.go @@ -36,7 +36,13 @@ func runMigrate(ctx *cli.Context) error { log.Info("Log path: %s", setting.Log.RootPath) log.Info("Configuration file: %s", setting.CustomConf) - if err := db.InitEngineWithMigration(context.Background(), migrations.Migrate); err != nil { + if err := db.InitEngineWithMigration(context.Background(), func(dbEngine db.Engine) error { + masterEngine, err := db.GetMasterEngine(dbEngine) + if err != nil { + return err + } + return migrations.Migrate(masterEngine) + }); err != nil { log.Fatal("Failed to initialize ORM engine: %v", err) return err } diff --git a/cmd/migrate_storage.go b/cmd/migrate_storage.go index 1b839e7169..6a04dd48ae 100644 --- a/cmd/migrate_storage.go +++ b/cmd/migrate_storage.go @@ -23,6 +23,7 @@ import ( "forgejo.org/modules/storage" "github.com/urfave/cli/v2" + "xorm.io/xorm" ) // CmdMigrateStorage represents the available migrate storage sub-command. @@ -195,7 +196,9 @@ func runMigrateStorage(ctx *cli.Context) error { log.Info("Log path: %s", setting.Log.RootPath) log.Info("Configuration file: %s", setting.CustomConf) - if err := db.InitEngineWithMigration(context.Background(), migrations.Migrate); err != nil { + if err := db.InitEngineWithMigration(context.Background(), func(e db.Engine) error { + return migrations.Migrate(e.(*xorm.Engine)) + }); err != nil { log.Fatal("Failed to initialize ORM engine: %v", err) return err } diff --git a/models/db/engine.go b/models/db/engine.go index ca6576da8a..7283b1d516 100755 --- a/models/db/engine.go +++ b/models/db/engine.go @@ -95,34 +95,70 @@ func init() { } } -// newXORMEngine returns a new XORM engine from the configuration -func newXORMEngine() (*xorm.Engine, error) { - connStr, err := setting.DBConnStr() +// newXORMEngineGroup creates an xorm.EngineGroup (with one master and one or more slaves). +// It assumes you have separate master and slave DSNs defined via the settings package. +func newXORMEngineGroup() (Engine, error) { + // Retrieve master DSN from settings. + masterConnStr, err := setting.DBMasterConnStr() if err != nil { - return nil, err + return nil, fmt.Errorf("failed to determine master DSN: %w", err) } - var engine *xorm.Engine - + var masterEngine *xorm.Engine + // For PostgreSQL: if a schema is provided, we use the special "postgresschema" driver. if setting.Database.Type.IsPostgreSQL() && len(setting.Database.Schema) > 0 { - // OK whilst we sort out our schema issues - create a schema aware postgres registerPostgresSchemaDriver() - engine, err = xorm.NewEngine("postgresschema", connStr) + masterEngine, err = xorm.NewEngine("postgresschema", masterConnStr) } else { - engine, err = xorm.NewEngine(setting.Database.Type.String(), connStr) + masterEngine, err = xorm.NewEngine(setting.Database.Type.String(), masterConnStr) } - if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create master engine: %w", err) } if setting.Database.Type.IsMySQL() { - engine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"}) + masterEngine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"}) } - engine.SetSchema(setting.Database.Schema) - return engine, nil + masterEngine.SetSchema(setting.Database.Schema) + + slaveConnStrs, err := setting.DBSlaveConnStrs() + if err != nil { + return nil, fmt.Errorf("failed to load slave DSNs: %w", err) + } + + var slaveEngines []*xorm.Engine + // Iterate over all slave DSNs and create engines + for _, dsn := range slaveConnStrs { + slaveEngine, err := xorm.NewEngine(setting.Database.Type.String(), dsn) + if err != nil { + return nil, fmt.Errorf("failed to create slave engine for dsn %q: %w", dsn, err) + } + if setting.Database.Type.IsMySQL() { + slaveEngine.Dialect().SetParams(map[string]string{"rowFormat": "DYNAMIC"}) + } + slaveEngine.SetSchema(setting.Database.Schema) + slaveEngines = append(slaveEngines, slaveEngine) + } + + policy := setting.BuildLoadBalancePolicy(&setting.Database, slaveEngines) + + // Create the EngineGroup using the selected policy + group, err := xorm.NewEngineGroup(masterEngine, slaveEngines, policy) + if err != nil { + return nil, fmt.Errorf("failed to create engine group: %w", err) + } + return engineGroupWrapper{group}, nil } -// SyncAllTables sync the schemas of all tables, is required by unit test code +type engineGroupWrapper struct { + *xorm.EngineGroup +} + +func (w engineGroupWrapper) AddHook(hook contexts.Hook) bool { + w.EngineGroup.AddHook(hook) + return true +} + +// SyncAllTables sync the schemas of all tables func SyncAllTables() error { _, err := x.StoreEngine("InnoDB").SyncWithOptions(xorm.SyncOptions{ WarnIfDatabaseColumnMissed: true, @@ -130,52 +166,61 @@ func SyncAllTables() error { return err } -// InitEngine initializes the xorm.Engine and sets it as db.DefaultContext +// InitEngine initializes the xorm EngineGroup and sets it as db.DefaultContext func InitEngine(ctx context.Context) error { - xormEngine, err := newXORMEngine() + xormEngine, err := newXORMEngineGroup() if err != nil { return fmt.Errorf("failed to connect to database: %w", err) } + // Try to cast to the concrete type to access diagnostic methods + if eng, ok := xormEngine.(engineGroupWrapper); ok { + eng.SetMapper(names.GonicMapper{}) + // WARNING: for serv command, MUST remove the output to os.Stdout, + // so use a log file instead of printing to stdout. + eng.SetLogger(NewXORMLogger(setting.Database.LogSQL)) + eng.ShowSQL(setting.Database.LogSQL) + eng.SetMaxOpenConns(setting.Database.MaxOpenConns) + eng.SetMaxIdleConns(setting.Database.MaxIdleConns) + eng.SetConnMaxLifetime(setting.Database.ConnMaxLifetime) + eng.SetConnMaxIdleTime(setting.Database.ConnMaxIdleTime) + eng.SetDefaultContext(ctx) - xormEngine.SetMapper(names.GonicMapper{}) - // WARNING: for serv command, MUST remove the output to os.stdout, - // so use log file to instead print to stdout. - xormEngine.SetLogger(NewXORMLogger(setting.Database.LogSQL)) - xormEngine.ShowSQL(setting.Database.LogSQL) - xormEngine.SetMaxOpenConns(setting.Database.MaxOpenConns) - xormEngine.SetMaxIdleConns(setting.Database.MaxIdleConns) - xormEngine.SetConnMaxLifetime(setting.Database.ConnMaxLifetime) - xormEngine.SetConnMaxIdleTime(setting.Database.ConnMaxIdleTime) - xormEngine.SetDefaultContext(ctx) + if setting.Database.SlowQueryThreshold > 0 { + eng.AddHook(&SlowQueryHook{ + Treshold: setting.Database.SlowQueryThreshold, + Logger: log.GetLogger("xorm"), + }) + } - if setting.Database.SlowQueryThreshold > 0 { - xormEngine.AddHook(&SlowQueryHook{ - Treshold: setting.Database.SlowQueryThreshold, - Logger: log.GetLogger("xorm"), + errorLogger := log.GetLogger("xorm") + if setting.IsInTesting { + errorLogger = log.GetLogger(log.DEFAULT) + } + + eng.AddHook(&ErrorQueryHook{ + Logger: errorLogger, }) + + eng.AddHook(&TracingHook{}) + + SetDefaultEngine(ctx, eng) + } else { + // Fallback: if type assertion fails, set default engine without extended diagnostics + SetDefaultEngine(ctx, xormEngine) } - - errorLogger := log.GetLogger("xorm") - if setting.IsInTesting { - errorLogger = log.GetLogger(log.DEFAULT) - } - - xormEngine.AddHook(&ErrorQueryHook{ - Logger: errorLogger, - }) - - xormEngine.AddHook(&TracingHook{}) - - SetDefaultEngine(ctx, xormEngine) return nil } -// SetDefaultEngine sets the default engine for db -func SetDefaultEngine(ctx context.Context, eng *xorm.Engine) { - x = eng +// SetDefaultEngine sets the default engine for db. +func SetDefaultEngine(ctx context.Context, eng Engine) { + masterEngine, err := GetMasterEngine(eng) + if err == nil { + x = masterEngine + } + DefaultContext = &Context{ Context: ctx, - e: x, + e: eng, } } @@ -191,12 +236,12 @@ func UnsetDefaultEngine() { DefaultContext = nil } -// InitEngineWithMigration initializes a new xorm.Engine and sets it as the db.DefaultContext +// InitEngineWithMigration initializes a new xorm EngineGroup, runs migrations, and sets it as db.DefaultContext // This function must never call .Sync() if the provided migration function fails. // When called from the "doctor" command, the migration function is a version check // that prevents the doctor from fixing anything in the database if the migration level // is different from the expected value. -func InitEngineWithMigration(ctx context.Context, migrateFunc func(*xorm.Engine) error) (err error) { +func InitEngineWithMigration(ctx context.Context, migrateFunc func(Engine) error) (err error) { if err = InitEngine(ctx); err != nil { return err } @@ -230,14 +275,14 @@ func InitEngineWithMigration(ctx context.Context, migrateFunc func(*xorm.Engine) return nil } -// NamesToBean return a list of beans or an error +// NamesToBean returns a list of beans given names func NamesToBean(names ...string) ([]any, error) { beans := []any{} if len(names) == 0 { beans = append(beans, tables...) return beans, nil } - // Need to map provided names to beans... + // Map provided names to beans beanMap := make(map[string]any) for _, bean := range tables { beanMap[strings.ToLower(reflect.Indirect(reflect.ValueOf(bean)).Type().Name())] = bean @@ -259,7 +304,7 @@ func NamesToBean(names ...string) ([]any, error) { return beans, nil } -// DumpDatabase dumps all data from database according the special database SQL syntax to file system. +// DumpDatabase dumps all data from database using special SQL syntax to the file system. func DumpDatabase(filePath, dbType string) error { var tbs []*schemas.Table for _, t := range tables { @@ -295,29 +340,33 @@ func MaxBatchInsertSize(bean any) int { return 999 / len(t.ColumnsSeq()) } -// IsTableNotEmpty returns true if table has at least one record +// IsTableNotEmpty returns true if the table has at least one record func IsTableNotEmpty(beanOrTableName any) (bool, error) { return x.Table(beanOrTableName).Exist() } -// DeleteAllRecords will delete all the records of this table +// DeleteAllRecords deletes all records in the given table. func DeleteAllRecords(tableName string) error { _, err := x.Exec(fmt.Sprintf("DELETE FROM %s", tableName)) return err } -// GetMaxID will return max id of the table +// GetMaxID returns the maximum id in the table func GetMaxID(beanOrTableName any) (maxID int64, err error) { _, err = x.Select("MAX(id)").Table(beanOrTableName).Get(&maxID) return maxID, err } func SetLogSQL(ctx context.Context, on bool) { - e := GetEngine(ctx) - if x, ok := e.(*xorm.Engine); ok { - x.ShowSQL(on) - } else if sess, ok := e.(*xorm.Session); ok { + ctxEngine := GetEngine(ctx) + + if sess, ok := ctxEngine.(*xorm.Session); ok { sess.Engine().ShowSQL(on) + } else if wrapper, ok := ctxEngine.(engineGroupWrapper); ok { + // Handle engineGroupWrapper directly + wrapper.ShowSQL(on) + } else if masterEngine, err := GetMasterEngine(ctxEngine); err == nil { + masterEngine.ShowSQL(on) } } @@ -374,3 +423,18 @@ func (h *ErrorQueryHook) AfterProcess(c *contexts.ContextHook) error { } return nil } + +// GetMasterEngine extracts the master xorm.Engine from the provided xorm.Engine. +// This handles both direct xorm.Engine cases and engines that implement a Master() method. +func GetMasterEngine(x Engine) (*xorm.Engine, error) { + if getter, ok := x.(interface{ Master() *xorm.Engine }); ok { + return getter.Master(), nil + } + + engine, ok := x.(*xorm.Engine) + if !ok { + return nil, fmt.Errorf("unsupported engine type: %T", x) + } + + return engine, nil +} diff --git a/models/db/index_test.go b/models/db/index_test.go index 929e514329..b64a816bd2 100644 --- a/models/db/index_test.go +++ b/models/db/index_test.go @@ -33,10 +33,11 @@ func getCurrentResourceIndex(ctx context.Context, tableName string, groupID int6 func TestSyncMaxResourceIndex(t *testing.T) { require.NoError(t, unittest.PrepareTestDatabase()) - xe := unittest.GetXORMEngine() + xe, err := unittest.GetXORMEngine() + require.NoError(t, err) require.NoError(t, xe.Sync(&TestIndex{})) - err := db.SyncMaxResourceIndex(db.DefaultContext, "test_index", 10, 51) + err = db.SyncMaxResourceIndex(db.DefaultContext, "test_index", 10, 51) require.NoError(t, err) // sync new max index @@ -88,7 +89,8 @@ func TestSyncMaxResourceIndex(t *testing.T) { func TestGetNextResourceIndex(t *testing.T) { require.NoError(t, unittest.PrepareTestDatabase()) - xe := unittest.GetXORMEngine() + xe, err := unittest.GetXORMEngine() + require.NoError(t, err) require.NoError(t, xe.Sync(&TestIndex{})) // create a new record diff --git a/models/db/iterate_test.go b/models/db/iterate_test.go index 47b6a956f4..bdeaa876d5 100644 --- a/models/db/iterate_test.go +++ b/models/db/iterate_test.go @@ -17,7 +17,8 @@ import ( func TestIterate(t *testing.T) { require.NoError(t, unittest.PrepareTestDatabase()) - xe := unittest.GetXORMEngine() + xe, err := unittest.GetXORMEngine() + require.NoError(t, err) require.NoError(t, xe.Sync(&repo_model.RepoUnit{})) cnt, err := db.GetEngine(db.DefaultContext).Count(&repo_model.RepoUnit{}) diff --git a/models/db/list_test.go b/models/db/list_test.go index 7108b64ead..502372782d 100644 --- a/models/db/list_test.go +++ b/models/db/list_test.go @@ -29,11 +29,12 @@ func (opts mockListOptions) ToConds() builder.Cond { func TestFind(t *testing.T) { require.NoError(t, unittest.PrepareTestDatabase()) - xe := unittest.GetXORMEngine() + xe, err := unittest.GetXORMEngine() + require.NoError(t, err) require.NoError(t, xe.Sync(&repo_model.RepoUnit{})) var repoUnitCount int - _, err := db.GetEngine(db.DefaultContext).SQL("SELECT COUNT(*) FROM repo_unit").Get(&repoUnitCount) + _, err = db.GetEngine(db.DefaultContext).SQL("SELECT COUNT(*) FROM repo_unit").Get(&repoUnitCount) require.NoError(t, err) assert.NotEmpty(t, repoUnitCount) diff --git a/models/migrations/migrations.go b/models/migrations/migrations.go index 11933014d7..aea9b593bd 100644 --- a/models/migrations/migrations.go +++ b/models/migrations/migrations.go @@ -8,6 +8,7 @@ import ( "context" "fmt" + "forgejo.org/models/db" "forgejo.org/models/forgejo_migrations" "forgejo.org/models/migrations/v1_10" "forgejo.org/models/migrations/v1_11" @@ -510,3 +511,12 @@ Please try upgrading to a lower version first (suggested v1.6.4), then upgrade t // Execute Forgejo specific migrations. return forgejo_migrations.Migrate(x) } + +// WrapperMigrate is a wrapper for Migrate to be called in diagnostics +func WrapperMigrate(e db.Engine) error { + engine, err := db.GetMasterEngine(e) + if err != nil { + return err + } + return Migrate(engine) +} diff --git a/models/migrations/test/tests.go b/models/migrations/test/tests.go index 07487cf58a..c1f0caf19b 100644 --- a/models/migrations/test/tests.go +++ b/models/migrations/test/tests.go @@ -175,7 +175,10 @@ func newXORMEngine() (*xorm.Engine, error) { if err := db.InitEngine(context.Background()); err != nil { return nil, err } - x := unittest.GetXORMEngine() + x, err := unittest.GetXORMEngine() + if err != nil { + return nil, err + } return x, nil } diff --git a/models/unittest/fixtures.go b/models/unittest/fixtures.go index 6402fd9466..940830d20f 100644 --- a/models/unittest/fixtures.go +++ b/models/unittest/fixtures.go @@ -22,11 +22,11 @@ import ( var fixturesLoader *testfixtures.Loader // GetXORMEngine gets the XORM engine -func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine) { +func GetXORMEngine(engine ...*xorm.Engine) (x *xorm.Engine, err error) { if len(engine) == 1 { - return engine[0] + return engine[0], nil } - return db.DefaultContext.(*db.Context).Engine().(*xorm.Engine) + return db.GetMasterEngine(db.DefaultContext.(*db.Context).Engine()) } func OverrideFixtures(opts FixturesOptions, engine ...*xorm.Engine) func() { @@ -41,7 +41,10 @@ func OverrideFixtures(opts FixturesOptions, engine ...*xorm.Engine) func() { // InitFixtures initialize test fixtures for a test database func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) { - e := GetXORMEngine(engine...) + e, err := GetXORMEngine(engine...) + if err != nil { + return err + } var fixtureOptionFiles func(*testfixtures.Loader) error if opts.Dir != "" { fixtureOptionFiles = testfixtures.Directory(opts.Dir) @@ -93,10 +96,12 @@ func InitFixtures(opts FixturesOptions, engine ...*xorm.Engine) (err error) { // LoadFixtures load fixtures for a test database func LoadFixtures(engine ...*xorm.Engine) error { - e := GetXORMEngine(engine...) - var err error + e, err := GetXORMEngine(engine...) + if err != nil { + return err + } // (doubt) database transaction conflicts could occur and result in ROLLBACK? just try for a few times. - for i := 0; i < 5; i++ { + for range 5 { if err = fixturesLoader.Load(); err == nil { break } diff --git a/modules/setting/database.go b/modules/setting/database.go index 76fae27164..b5131d3782 100644 --- a/modules/setting/database.go +++ b/modules/setting/database.go @@ -10,8 +10,13 @@ import ( "net/url" "os" "path/filepath" + "strconv" "strings" "time" + + "forgejo.org/modules/log" + + "xorm.io/xorm" ) var ( @@ -24,35 +29,41 @@ var ( EnableSQLite3 bool // Database holds the database settings - Database = struct { - Type DatabaseType - Host string - Name string - User string - Passwd string - Schema string - SSLMode string - Path string - LogSQL bool - MysqlCharset string - CharsetCollation string - Timeout int // seconds - SQLiteJournalMode string - DBConnectRetries int - DBConnectBackoff time.Duration - MaxIdleConns int - MaxOpenConns int - ConnMaxIdleTime time.Duration - ConnMaxLifetime time.Duration - IterateBufferSize int - AutoMigration bool - SlowQueryThreshold time.Duration - }{ + Database = DatabaseSettings{ Timeout: 500, IterateBufferSize: 50, } ) +type DatabaseSettings struct { + Type DatabaseType + Host string + HostPrimary string + HostReplica string + LoadBalancePolicy string + LoadBalanceWeights string + Name string + User string + Passwd string + Schema string + SSLMode string + Path string + LogSQL bool + MysqlCharset string + CharsetCollation string + Timeout int // seconds + SQLiteJournalMode string + DBConnectRetries int + DBConnectBackoff time.Duration + MaxIdleConns int + MaxOpenConns int + ConnMaxIdleTime time.Duration + ConnMaxLifetime time.Duration + IterateBufferSize int + AutoMigration bool + SlowQueryThreshold time.Duration +} + // LoadDBSetting loads the database settings func LoadDBSetting() { loadDBSetting(CfgProvider) @@ -63,6 +74,10 @@ func loadDBSetting(rootCfg ConfigProvider) { Database.Type = DatabaseType(sec.Key("DB_TYPE").String()) Database.Host = sec.Key("HOST").String() + Database.HostPrimary = sec.Key("HOST_PRIMARY").String() + Database.HostReplica = sec.Key("HOST_REPLICA").String() + Database.LoadBalancePolicy = sec.Key("LOAD_BALANCE_POLICY").String() + Database.LoadBalanceWeights = sec.Key("LOAD_BALANCE_WEIGHTS").String() Database.Name = sec.Key("NAME").String() Database.User = sec.Key("USER").String() if len(Database.Passwd) == 0 { @@ -99,8 +114,93 @@ func loadDBSetting(rootCfg ConfigProvider) { } } -// DBConnStr returns database connection string -func DBConnStr() (string, error) { +// DBMasterConnStr returns the connection string for the master (primary) database. +// If a primary host is defined in the configuration, it is used; +// otherwise, it falls back to Database.Host. +// Returns an error if no master host is provided but a slave is defined. +func DBMasterConnStr() (string, error) { + var host string + if Database.HostPrimary != "" { + host = Database.HostPrimary + } else { + host = Database.Host + } + if host == "" && Database.HostReplica != "" { + return "", errors.New("master host is not defined while slave is defined; cannot proceed") + } + + // For SQLite, no host is needed + if host == "" && !Database.Type.IsSQLite3() { + return "", errors.New("no database host defined") + } + + return dbConnStrWithHost(host) +} + +// DBSlaveConnStrs returns one or more connection strings for the replica databases. +// If a replica host is defined (possibly as a comma-separated list) then those DSNs are returned. +// Otherwise, this function falls back to the master DSN (with a warning log). +func DBSlaveConnStrs() ([]string, error) { + var dsns []string + if Database.HostReplica != "" { + // support multiple replica hosts separated by commas + replicas := strings.SplitSeq(Database.HostReplica, ",") + for r := range replicas { + trimmed := strings.TrimSpace(r) + if trimmed == "" { + continue + } + dsn, err := dbConnStrWithHost(trimmed) + if err != nil { + return nil, err + } + dsns = append(dsns, dsn) + } + } + // Fall back to master if no slave DSN was provided. + if len(dsns) == 0 { + master, err := DBMasterConnStr() + if err != nil { + return nil, err + } + log.Debug("Database: No dedicated replica host defined; falling back to primary DSN for replica connections") + dsns = append(dsns, master) + } + return dsns, nil +} + +func BuildLoadBalancePolicy(settings *DatabaseSettings, slaveEngines []*xorm.Engine) xorm.GroupPolicy { + var policy xorm.GroupPolicy + switch settings.LoadBalancePolicy { // Use the settings parameter directly + case "WeightRandom": + var weights []int + if settings.LoadBalanceWeights != "" { // Use the settings parameter directly + for part := range strings.SplitSeq(settings.LoadBalanceWeights, ",") { + w, err := strconv.Atoi(strings.TrimSpace(part)) + if err != nil { + w = 1 // use a default weight if conversion fails + } + weights = append(weights, w) + } + } + // If no valid weights were provided, default each slave to weight 1 + if len(weights) == 0 { + weights = make([]int, len(slaveEngines)) + for i := range weights { + weights[i] = 1 + } + } + policy = xorm.WeightRandomPolicy(weights) + case "RoundRobin": + policy = xorm.RoundRobinPolicy() + default: + policy = xorm.RandomPolicy() + } + return policy +} + +// dbConnStrWithHost constructs the connection string, given a host value. +func dbConnStrWithHost(host string) (string, error) { var connStr string paramSep := "?" if strings.Contains(Database.Name, paramSep) { @@ -109,23 +209,25 @@ func DBConnStr() (string, error) { switch Database.Type { case "mysql": connType := "tcp" - if len(Database.Host) > 0 && Database.Host[0] == '/' { // looks like a unix socket + // if the host starts with '/' it is assumed to be a unix socket path + if len(host) > 0 && host[0] == '/' { connType = "unix" } tls := Database.SSLMode - if tls == "disable" { // allow (Postgres-inspired) default value to work in MySQL + // allow the "disable" value (borrowed from Postgres defaults) to behave as false + if tls == "disable" { tls = "false" } connStr = fmt.Sprintf("%s:%s@%s(%s)/%s%sparseTime=true&tls=%s", - Database.User, Database.Passwd, connType, Database.Host, Database.Name, paramSep, tls) + Database.User, Database.Passwd, connType, host, Database.Name, paramSep, tls) case "postgres": - connStr = getPostgreSQLConnectionString(Database.Host, Database.User, Database.Passwd, Database.Name, Database.SSLMode) + connStr = getPostgreSQLConnectionString(host, Database.User, Database.Passwd, Database.Name, Database.SSLMode) case "sqlite3": if !EnableSQLite3 { return "", errors.New("this Gitea binary was not built with SQLite3 support") } if err := os.MkdirAll(filepath.Dir(Database.Path), os.ModePerm); err != nil { - return "", fmt.Errorf("Failed to create directories: %w", err) + return "", fmt.Errorf("failed to create directories: %w", err) } journalMode := "" if Database.SQLiteJournalMode != "" { @@ -136,7 +238,6 @@ func DBConnStr() (string, error) { default: return "", fmt.Errorf("unknown database type: %s", Database.Type) } - return connStr, nil } diff --git a/modules/setting/database_test.go b/modules/setting/database_test.go index a742d54f8c..ce816d53e8 100644 --- a/modules/setting/database_test.go +++ b/modules/setting/database_test.go @@ -4,6 +4,7 @@ package setting import ( + "strings" "testing" "github.com/stretchr/testify/assert" @@ -107,3 +108,104 @@ func Test_getPostgreSQLConnectionString(t *testing.T) { assert.Equal(t, test.Output, connStr) } } + +func getPostgreSQLEngineGroupConnectionStrings(primaryHost, replicaHosts, user, passwd, name, sslmode string) (string, []string) { + // Determine the primary connection string. + primary := primaryHost + if strings.TrimSpace(primary) == "" { + primary = "127.0.0.1:5432" + } + primaryConn := getPostgreSQLConnectionString(primary, user, passwd, name, sslmode) + + // Build the replica connection strings. + replicaConns := []string{} + if strings.TrimSpace(replicaHosts) != "" { + // Split comma-separated replica host values. + hosts := strings.Split(replicaHosts, ",") + for _, h := range hosts { + trimmed := strings.TrimSpace(h) + if trimmed != "" { + replicaConns = append(replicaConns, + getPostgreSQLConnectionString(trimmed, user, passwd, name, sslmode)) + } + } + } + + return primaryConn, replicaConns +} + +func Test_getPostgreSQLEngineGroupConnectionStrings(t *testing.T) { + tests := []struct { + primaryHost string // primary host setting (e.g. "localhost" or "[::1]:1234") + replicaHosts string // comma-separated replica hosts (e.g. "replica1,replica2:2345") + user string + passwd string + name string + sslmode string + outputPrimary string + outputReplicas []string + }{ + { + // No primary override (empty => default) and no replicas. + primaryHost: "", + replicaHosts: "", + user: "", + passwd: "", + name: "", + sslmode: "", + outputPrimary: "postgres://:@127.0.0.1:5432?sslmode=", + outputReplicas: []string{}, + }, + { + // Primary set and one replica. + primaryHost: "localhost", + replicaHosts: "replicahost", + user: "user", + passwd: "pass", + name: "gitea", + sslmode: "disable", + outputPrimary: "postgres://user:pass@localhost:5432/gitea?sslmode=disable", + outputReplicas: []string{"postgres://user:pass@replicahost:5432/gitea?sslmode=disable"}, + }, + { + // Primary with explicit port; multiple replicas (one without and one with an explicit port). + primaryHost: "localhost:5433", + replicaHosts: "replica1,replica2:5434", + user: "test", + passwd: "secret", + name: "db", + sslmode: "require", + outputPrimary: "postgres://test:secret@localhost:5433/db?sslmode=require", + outputReplicas: []string{ + "postgres://test:secret@replica1:5432/db?sslmode=require", + "postgres://test:secret@replica2:5434/db?sslmode=require", + }, + }, + { + // IPv6 addresses for primary and replica. + primaryHost: "[::1]:1234", + replicaHosts: "[::2]:2345", + user: "ipv6", + passwd: "ipv6pass", + name: "ipv6db", + sslmode: "disable", + outputPrimary: "postgres://ipv6:ipv6pass@[::1]:1234/ipv6db?sslmode=disable", + outputReplicas: []string{ + "postgres://ipv6:ipv6pass@[::2]:2345/ipv6db?sslmode=disable", + }, + }, + } + + for _, test := range tests { + primary, replicas := getPostgreSQLEngineGroupConnectionStrings( + test.primaryHost, + test.replicaHosts, + test.user, + test.passwd, + test.name, + test.sslmode, + ) + assert.Equal(t, test.outputPrimary, primary) + assert.Equal(t, test.outputReplicas, replicas) + } +} diff --git a/modules/testlogger/testlogger.go b/modules/testlogger/testlogger.go index 5567ea433e..772ae47e71 100644 --- a/modules/testlogger/testlogger.go +++ b/modules/testlogger/testlogger.go @@ -364,6 +364,9 @@ var ignoredErrorMessage = []string{ // TestDatabaseCollation `[E] [Error SQL Query] INSERT INTO test_collation_tbl (txt) VALUES ('main') []`, + // Test_CmdForgejo_Actions + `DB: No dedicated replica host defined; falling back to primary DSN for replica connections`, + // TestDevtestErrorpages `ErrorPage() [E] Example error: Example error`, } diff --git a/routers/common/db.go b/routers/common/db.go index 0646071264..0f78d8debc 100644 --- a/routers/common/db.go +++ b/routers/common/db.go @@ -28,7 +28,7 @@ func InitDBEngine(ctx context.Context) (err error) { default: } log.Info("ORM engine initialization attempt #%d/%d...", i+1, setting.Database.DBConnectRetries) - if err = db.InitEngineWithMigration(ctx, migrateWithSetting); err == nil { + if err = db.InitEngineWithMigration(ctx, func(eng db.Engine) error { return migrateWithSetting(eng.(*xorm.Engine)) }); err == nil { break } else if i == setting.Database.DBConnectRetries-1 { return err diff --git a/routers/install/install.go b/routers/install/install.go index b9333a9e16..ace0b2c8ed 100644 --- a/routers/install/install.go +++ b/routers/install/install.go @@ -361,7 +361,8 @@ func SubmitInstall(ctx *context.Context) { } // Init the engine with migration - if err = db.InitEngineWithMigration(ctx, migrations.Migrate); err != nil { + // Wrap migrations.Migrate into a function of type func(db.Engine) error to fix diagnostics. + if err = db.InitEngineWithMigration(ctx, migrations.WrapperMigrate); err != nil { db.UnsetDefaultEngine() ctx.Data["Err_DbSetting"] = true ctx.RenderWithErr(ctx.Tr("install.invalid_db_setting", err), tplInstall, &form) @@ -587,7 +588,7 @@ func SubmitInstall(ctx *context.Context) { go func() { // Sleep for a while to make sure the user's browser has loaded the post-install page and its assets (images, css, js) - // What if this duration is not long enough? That's impossible -- if the user can't load the simple page in time, how could they install or use Gitea in the future .... + // What if this duration is not long enough? That's impossible -- if the user can't load the simple page in time, how could they install or use Forgejo in the future .... time.Sleep(3 * time.Second) // Now get the http.Server from this request and shut it down diff --git a/services/doctor/dbconsistency.go b/services/doctor/dbconsistency.go index 6fcbd90940..6fe4c9c5e6 100644 --- a/services/doctor/dbconsistency.go +++ b/services/doctor/dbconsistency.go @@ -78,7 +78,14 @@ func genericOrphanCheck(name, subject, refobject, joincond string) consistencyCh func checkDBConsistency(ctx context.Context, logger log.Logger, autofix bool) error { // make sure DB version is up-to-date - if err := db.InitEngineWithMigration(ctx, migrations.EnsureUpToDate); err != nil { + ensureUpToDateWrapper := func(e db.Engine) error { + engine, err := db.GetMasterEngine(e) + if err != nil { + return err + } + return migrations.EnsureUpToDate(engine) + } + if err := db.InitEngineWithMigration(ctx, ensureUpToDateWrapper); err != nil { logger.Critical("Model version on the database does not match the current Gitea version. Model consistency will not be checked until the database is upgraded") return err } diff --git a/services/doctor/dbversion.go b/services/doctor/dbversion.go index 9c02c732e5..c0ff22915d 100644 --- a/services/doctor/dbversion.go +++ b/services/doctor/dbversion.go @@ -9,11 +9,15 @@ import ( "forgejo.org/models/db" "forgejo.org/models/migrations" "forgejo.org/modules/log" + + "xorm.io/xorm" ) func checkDBVersion(ctx context.Context, logger log.Logger, autofix bool) error { logger.Info("Expected database version: %d", migrations.ExpectedDBVersion()) - if err := db.InitEngineWithMigration(ctx, migrations.EnsureUpToDate); err != nil { + if err := db.InitEngineWithMigration(ctx, func(eng db.Engine) error { + return migrations.EnsureUpToDate(eng.(*xorm.Engine)) + }); err != nil { if !autofix { logger.Critical("Error: %v during ensure up to date", err) return err @@ -21,7 +25,9 @@ func checkDBVersion(ctx context.Context, logger log.Logger, autofix bool) error logger.Warn("Got Error: %v during ensure up to date", err) logger.Warn("Attempting to migrate to the latest DB version to fix this.") - err = db.InitEngineWithMigration(ctx, migrations.Migrate) + err = db.InitEngineWithMigration(ctx, func(eng db.Engine) error { + return migrations.Migrate(eng.(*xorm.Engine)) + }) if err != nil { logger.Critical("Error: %v during migration", err) } diff --git a/tests/integration/db_collation_test.go b/tests/integration/db_collation_test.go index 6bfe656b9b..5b84dae823 100644 --- a/tests/integration/db_collation_test.go +++ b/tests/integration/db_collation_test.go @@ -16,7 +16,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "xorm.io/xorm" ) type TestCollationTbl struct { @@ -48,11 +47,13 @@ func TestDatabaseCollationSelfCheckUI(t *testing.T) { } func TestDatabaseCollation(t *testing.T) { - x := db.GetEngine(db.DefaultContext).(*xorm.Engine) + engine, err := db.GetMasterEngine(db.GetEngine(db.DefaultContext)) + require.NoError(t, err) + x := engine // all created tables should use case-sensitive collation by default _, _ = x.Exec("DROP TABLE IF EXISTS test_collation_tbl") - err := x.Sync(&TestCollationTbl{}) + err = x.Sync(&TestCollationTbl{}) require.NoError(t, err) _, _ = x.Exec("INSERT INTO test_collation_tbl (txt) VALUES ('main')") _, _ = x.Exec("INSERT INTO test_collation_tbl (txt) VALUES ('Main')") // case-sensitive, so it inserts a new row diff --git a/tests/integration/migration-test/migration_test.go b/tests/integration/migration-test/migration_test.go index 8076dfa452..798161a560 100644 --- a/tests/integration/migration-test/migration_test.go +++ b/tests/integration/migration-test/migration_test.go @@ -278,23 +278,36 @@ func doMigrationTest(t *testing.T, version string) { setting.InitSQLLoggersForCli(log.INFO) - err := db.InitEngineWithMigration(t.Context(), wrappedMigrate) + err := db.InitEngineWithMigration(t.Context(), func(e db.Engine) error { + engine, err := db.GetMasterEngine(e) + if err != nil { + return err + } + currentEngine = engine + return wrappedMigrate(engine) + }) require.NoError(t, err) currentEngine.Close() beans, _ := db.NamesToBean() - err = db.InitEngineWithMigration(t.Context(), func(x *xorm.Engine) error { - currentEngine = x - return migrate_base.RecreateTables(beans...)(x) + err = db.InitEngineWithMigration(t.Context(), func(e db.Engine) error { + currentEngine, err = db.GetMasterEngine(e) + if err != nil { + return err + } + return migrate_base.RecreateTables(beans...)(currentEngine) }) require.NoError(t, err) currentEngine.Close() // We do this a second time to ensure that there is not a problem with retained indices - err = db.InitEngineWithMigration(t.Context(), func(x *xorm.Engine) error { - currentEngine = x - return migrate_base.RecreateTables(beans...)(x) + err = db.InitEngineWithMigration(t.Context(), func(e db.Engine) error { + currentEngine, err = db.GetMasterEngine(e) + if err != nil { + return err + } + return migrate_base.RecreateTables(beans...)(currentEngine) }) require.NoError(t, err)