diff --git a/internal/compiler/analyze.go b/internal/compiler/analyze.go index 264dbef6f5..75cd403b4a 100644 --- a/internal/compiler/analyze.go +++ b/internal/compiler/analyze.go @@ -65,6 +65,36 @@ func convertColumn(c *analyzer.Column) *Column { } } +func mergeColumnOrigin(dst, src *Column) { + if dst == nil || src == nil { + return + } + + // Column overrides in the Go generator depend on the column's original + // table identity. For SQLite, the database analyzer is often the most + // accurate source for this information, especially for columns added by + // ALTER TABLE ... ADD COLUMN. + // + // Keep the catalog-inferred name/type/nullability unless the existing + // combine logic below decides to override the type. Only merge origin + // metadata here. + if src.OriginalName != "" { + dst.OriginalName = src.OriginalName + } + if src.Table != nil { + dst.Table = src.Table + } + if src.TableAlias != "" { + dst.TableAlias = src.TableAlias + } + if src.Scope != "" { + dst.Scope = src.Scope + } + if src.EmbedTable != nil { + dst.EmbedTable = src.EmbedTable + } +} + func combineAnalysis(prev *analysis, a *analyzer.Analysis) *analysis { var cols []*Column for _, c := range a.Columns { @@ -79,6 +109,7 @@ func combineAnalysis(prev *analysis, a *analyzer.Analysis) *analysis { } if len(prev.Columns) == len(cols) { for i := range prev.Columns { + mergeColumnOrigin(prev.Columns[i], cols[i]) // Only override column types if the analyzer provides a specific type // (not "any"), since the catalog-based inference may have better info if cols[i].DataType != "any" { diff --git a/internal/endtoend/testdata/overrides_alter_add_column/sqlite/go/db.go b/internal/endtoend/testdata/overrides_alter_add_column/sqlite/go/db.go new file mode 100644 index 0000000000..f43598b1eb --- /dev/null +++ b/internal/endtoend/testdata/overrides_alter_add_column/sqlite/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package db + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/overrides_alter_add_column/sqlite/go/models.go b/internal/endtoend/testdata/overrides_alter_add_column/sqlite/go/models.go new file mode 100644 index 0000000000..69acd5d0b8 --- /dev/null +++ b/internal/endtoend/testdata/overrides_alter_add_column/sqlite/go/models.go @@ -0,0 +1,16 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 + +package db + +import ( + "time" +) + +type ChatMessage struct { + ID string + Body string + CreatedAt time.Time + UpdatedAt time.Time +} diff --git a/internal/endtoend/testdata/overrides_alter_add_column/sqlite/go/query.sql.go b/internal/endtoend/testdata/overrides_alter_add_column/sqlite/go/query.sql.go new file mode 100644 index 0000000000..54fb64f633 --- /dev/null +++ b/internal/endtoend/testdata/overrides_alter_add_column/sqlite/go/query.sql.go @@ -0,0 +1,21 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.31.1 +// source: query.sql + +package db + +import ( + "context" +) + +const test = `-- name: Test :one +SELECT 1 +` + +func (q *Queries) Test(ctx context.Context) (int64, error) { + row := q.db.QueryRowContext(ctx, test) + var column_1 int64 + err := row.Scan(&column_1) + return column_1, err +} diff --git a/internal/endtoend/testdata/overrides_alter_add_column/sqlite/query.sql b/internal/endtoend/testdata/overrides_alter_add_column/sqlite/query.sql new file mode 100644 index 0000000000..9da604b57e --- /dev/null +++ b/internal/endtoend/testdata/overrides_alter_add_column/sqlite/query.sql @@ -0,0 +1,2 @@ +-- name: Test :one +SELECT 1; diff --git a/internal/endtoend/testdata/overrides_alter_add_column/sqlite/schema.sql b/internal/endtoend/testdata/overrides_alter_add_column/sqlite/schema.sql new file mode 100644 index 0000000000..a89bbf6a6a --- /dev/null +++ b/internal/endtoend/testdata/overrides_alter_add_column/sqlite/schema.sql @@ -0,0 +1,7 @@ +CREATE TABLE chat_messages ( + id TEXT PRIMARY KEY NOT NULL, + body TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP +) STRICT; + +ALTER TABLE chat_messages ADD COLUMN "updated_at" TEXT NOT NULL DEFAULT ''; diff --git a/internal/endtoend/testdata/overrides_alter_add_column/sqlite/sqlc.json b/internal/endtoend/testdata/overrides_alter_add_column/sqlite/sqlc.json new file mode 100644 index 0000000000..b30fe1d6d3 --- /dev/null +++ b/internal/endtoend/testdata/overrides_alter_add_column/sqlite/sqlc.json @@ -0,0 +1,22 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "name": "db", + "engine": "sqlite", + "schema": "schema.sql", + "queries": "query.sql", + "overrides": [ + { + "column": "chat_messages.created_at", + "go_type": "time.Time" + }, + { + "column": "chat_messages.updated_at", + "go_type": "time.Time" + } + ] + } + ] +} diff --git a/internal/engine/sqlite/convert.go b/internal/engine/sqlite/convert.go index e9868f5be6..c372def277 100644 --- a/internal/engine/sqlite/convert.go +++ b/internal/engine/sqlite/convert.go @@ -67,14 +67,18 @@ func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) a Table: parseTableName(n), Cmds: &ast.List{}, } - name := def.Column_name().GetText() + name := identifier(def.Column_name().GetText()) + typeName := "any" + if def.Type_name() != nil { + typeName = def.Type_name().GetText() + } stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_AddColumn, Def: &ast.ColumnDef{ Colname: name, TypeName: &ast.TypeName{ - Name: def.Type_name().GetText(), + Name: typeName, }, IsNotNull: hasNotNullConstraint(def.AllColumn_constraint()), }, @@ -88,7 +92,7 @@ func (c *cc) convertAlter_table_stmtContext(n *parser.Alter_table_stmtContext) a Table: parseTableName(n), Cmds: &ast.List{}, } - name := n.Column_name(0).GetText() + name := identifier(n.Column_name(0).GetText()) stmt.Cmds.Items = append(stmt.Cmds.Items, &ast.AlterTableCmd{ Name: &name, Subtype: ast.AT_DropColumn, @@ -826,7 +830,7 @@ func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node { if opCtx.MINUS() != nil { // Negative number: -expr return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}}, + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "-"}}}, Rexpr: expr, } } @@ -837,7 +841,7 @@ func (c *cc) convertUnaryExpr(n *parser.Expr_unaryContext) ast.Node { if opCtx.TILDE() != nil { // Bitwise NOT: ~expr return &ast.A_Expr{ - Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}}, + Name: &ast.List{Items: []ast.Node{&ast.String{Str: "~"}}}, Rexpr: expr, } }