diff --git a/_examples/sql/main.go b/_examples/sql/main.go new file mode 100644 index 00000000..289d1fcc --- /dev/null +++ b/_examples/sql/main.go @@ -0,0 +1,197 @@ +package main + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/getsentry/sentry-go" + "github.com/getsentry/sentry-go/sentrysql" + "github.com/lib/pq" +) + +func init() { + // Registering a custom database driver that's wrapped by sentrysql. + // Later, we can call `sql.Open("sentrysql-postgres", databaseDSN)` to use it. + sql.Register("sentrysql-postgres", sentrysql.NewSentrySQL(&pq.Driver{}, sentrysql.WithDatabaseSystem(sentrysql.PostgreSQL), sentrysql.WithDatabaseName("postgres"), sentrysql.WithServerAddress("write.postgres.internal", "5432"))) +} + +func main() { + err := sentry.Init(sentry.ClientOptions{ + // Either set your DSN here or set the SENTRY_DSN environment variable. + Dsn: "", + // Enable printing of SDK debug messages. + // Useful when getting started or trying to figure something out. + Debug: true, + // EnableTracing must be set to true if you want the SQL queries to be traced. + EnableTracing: true, + TracesSampleRate: 1.0, + }) + if err != nil { + fmt.Printf("failed to initialize sentry: %s\n", err.Error()) + return + } + + // We are going to emulate a scenario where an application requires a read database and a write database. + // This is also to show how to use each `sentrysql.NewSentrySQLConnector` and `sentrysql.NewSentrySQL`. + + // Create a database connection for read database. + connector, err := pq.NewConnector("postgres://postgres:password@read.postgres.internal:5432/postgres") + if err != nil { + fmt.Printf("failed to create a postgres connector: %s\n", err.Error()) + return + } + + sentryWrappedConnector := sentrysql.NewSentrySQLConnector( + connector, + sentrysql.WithDatabaseSystem(sentrysql.PostgreSQL), // required if you want to see the queries on the Queries Insights page + sentrysql.WithDatabaseName("postgres"), + sentrysql.WithServerAddress("read.postgres.internal", "5432"), + ) + + readDatabase := sql.OpenDB(sentryWrappedConnector) + defer func() { + err := readDatabase.Close() + if err != nil { + sentry.CaptureException(err) + } + }() + + // Create a database connection for write database. + writeDatabase, err := sql.Open("sentrysql-postgres", "postgres://postgres:password@write.postgres.internal:5432/postgres") + if err != nil { + fmt.Printf("failed to open write postgres database: %s\n", err.Error()) + return + } + defer func() { + err := writeDatabase.Close() + if err != nil { + sentry.CaptureException(err) + } + }() + + ctx, cancel := context.WithTimeout( + sentry.SetHubOnContext(context.Background(), sentry.CurrentHub().Clone()), + time.Minute, + ) + defer cancel() + + err = ScaffoldDatabase(ctx, writeDatabase) + if err != nil { + fmt.Printf("failed to scaffold database: %s\n", err.Error()) + return + } + + users, err := GetAllUsers(ctx, readDatabase) + if err != nil { + fmt.Printf("failed to get users: %s\n", err.Error()) + return + } + + for _, user := range users { + fmt.Printf("User: %+v\n", user) + } +} + +// ScaffoldDatabase prepares the database to have the users table. +func ScaffoldDatabase(ctx context.Context, db *sql.DB) error { + // A parent span is required to have the queries to be traced. + // Make sure to override the `context.Context` with the parent span's context. + span := sentry.StartSpan(ctx, "ScaffoldDatabase") + ctx = span.Context() + defer span.Finish() + + conn, err := db.Conn(ctx) + if err != nil { + return fmt.Errorf("acquiring connection from pool: %w", err) + } + defer func() { + err := conn.Close() + if err != nil && !errors.Is(err, sql.ErrConnDone) { + if hub := sentry.GetHubFromContext(ctx); hub != nil { + hub.CaptureException(err) + } + } + }() + + tx, err := conn.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable, ReadOnly: false}) + if err != nil { + return fmt.Errorf("beginning transaction: %w", err) + } + defer func() { + err := tx.Rollback() + if err != nil && !errors.Is(err, sql.ErrTxDone) { + if hub := sentry.GetHubFromContext(ctx); hub != nil { + hub.CaptureException(err) + } + } + }() + + _, err = tx.ExecContext(ctx, "CREATE TABLE users (id INTEGER GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, name VARCHAR(255), email VARCHAR(255), active BOOLEAN)") + if err != nil { + return fmt.Errorf("creating users table: %w", err) + } + + err = tx.Commit() + if err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + return nil +} + +// User represents a user in the database. +type User struct { + ID int + Name string + Email string +} + +// GetAllUsers returns all the users from the database. +func GetAllUsers(ctx context.Context, db *sql.DB) ([]User, error) { + // A parent span is required to have the queries to be traced. + // Make sure to override the `context.Context` with the parent span's context. + span := sentry.StartSpan(ctx, "GetAllUsers") + ctx = span.Context() + defer span.Finish() + + conn, err := db.Conn(ctx) + if err != nil { + return nil, fmt.Errorf("acquiring connection from pool: %w", err) + } + defer func() { + err := conn.Close() + if err != nil && !errors.Is(err, sql.ErrConnDone) { + if hub := sentry.GetHubFromContext(ctx); hub != nil { + hub.CaptureException(err) + } + } + }() + + rows, err := conn.QueryContext(ctx, "SELECT id, name, email FROM users WHERE active = $1", true) + if err != nil { + return nil, fmt.Errorf("querying users: %w", err) + } + defer func() { + err := rows.Close() + if err != nil { + if hub := sentry.GetHubFromContext(ctx); hub != nil { + hub.CaptureException(err) + } + } + }() + + var users []User + for rows.Next() { + var user User + err := rows.Scan(&user.ID, &user.Name, &user.Email) + if err != nil { + return nil, fmt.Errorf("scanning user: %w", err) + } + users = append(users, user) + } + + return users, nil +} diff --git a/go.mod b/go.mod index c06650d8..d7e1869a 100644 --- a/go.mod +++ b/go.mod @@ -4,11 +4,14 @@ go 1.21 require ( github.com/gin-gonic/gin v1.8.1 + github.com/glebarez/go-sqlite v1.21.1 github.com/go-errors/errors v1.4.2 + github.com/go-sql-driver/mysql v1.8.1 github.com/gofiber/fiber/v2 v2.52.2 github.com/google/go-cmp v0.5.9 github.com/kataras/iris/v12 v12.2.0 github.com/labstack/echo/v4 v4.10.0 + github.com/lib/pq v1.10.9 github.com/pingcap/errors v0.11.4 github.com/pkg/errors v0.9.1 github.com/sirupsen/logrus v1.9.0 @@ -20,6 +23,7 @@ require ( ) require ( + filippo.io/edwards25519 v1.1.0 // indirect github.com/BurntSushi/toml v1.2.1 // indirect github.com/CloudyKit/fastprinter v0.0.0-20200109182630-33d98a066a53 // indirect github.com/CloudyKit/jet/v6 v6.2.0 // indirect @@ -29,6 +33,7 @@ require ( github.com/andybalholm/brotli v1.1.0 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385 // indirect github.com/fatih/structs v1.1.0 // indirect github.com/flosch/pongo2/v4 v4.0.2 // indirect @@ -67,6 +72,7 @@ require ( github.com/nxadm/tail v1.4.11 // indirect github.com/pelletier/go-toml/v2 v2.0.5 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sanity-io/litter v1.5.5 // indirect @@ -94,5 +100,9 @@ require ( gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + modernc.org/libc v1.22.3 // indirect + modernc.org/mathutil v1.5.0 // indirect + modernc.org/memory v1.5.0 // indirect + modernc.org/sqlite v1.21.1 // indirect moul.io/http2curl/v2 v2.3.0 // indirect ) diff --git a/go.sum b/go.sum index a01095fe..56d04606 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v1.2.1 h1:9F2/+DoOYIOksmaJFPw1tGFy1eDnIJXg+UHjuD8lTak= github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= github.com/CloudyKit/fastprinter v0.0.0-20200109182630-33d98a066a53 h1:sR+/8Yb4slttB4vD+b9btVEnWgL3Q00OBTzVT8B9C0c= @@ -24,6 +26,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/djherbis/atime v1.1.0/go.mod h1:28OF6Y8s3NQWwacXc5eZTsEsiMzp7LF8MbXE+XJPdBE= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385 h1:clC1lXBpe2kTj2VHdaIu9ajZQe4kcEY9j0NsnDDBZ3o= github.com/eknkc/amber v0.0.0-20171010120322-cdade1c07385/go.mod h1:0vRUJqYpeSZifjYj7uP3BG/gKcuzL9xWVV/Y+cK33KM= github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= @@ -37,6 +41,8 @@ github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.8.1 h1:4+fr/el88TOO3ewCmQr8cx/CtZ/umlIRIs5M4NTNjf8= github.com/gin-gonic/gin v1.8.1/go.mod h1:ji8BvRH1azfM+SYow9zQ6SZMvR8qOMZHmsCuWR9tTTk= +github.com/glebarez/go-sqlite v1.21.1 h1:7MZyUPh2XTrHS7xNEHQbrhfMZuPSzhkm2A1qgg0y5NY= +github.com/glebarez/go-sqlite v1.21.1/go.mod h1:ISs8MF6yk5cL4n/43rSOmVMGJJjHYr7L2MbZZ5Q4E2E= github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxIA= github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= @@ -47,6 +53,8 @@ github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/j github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ= github.com/go-playground/validator/v10 v10.11.1/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk= github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/gofiber/fiber/v2 v2.52.2 h1:b0rYH6b06Df+4NyrbdptQL8ifuxw/Tf2DgfkZkDaxEo= @@ -59,6 +67,7 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= @@ -103,6 +112,8 @@ github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8 github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mailgun/raymond/v2 v2.0.48 h1:5dmlB680ZkFG2RN/0lvTAghrSxIESeu9/2aeDqACtjw= github.com/mailgun/raymond/v2 v2.0.48/go.mod h1:lsgvL50kgt1ylcFJYZiULi5fjPBkkhNfj4KA0W54Z18= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= @@ -144,6 +155,9 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v0.0.0-20151028094244-d8ed2627bdf0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= @@ -292,5 +306,13 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/libc v1.22.3 h1:D/g6O5ftAfavceqlLOFwaZuA5KYafKwmr30A6iSqoyY= +modernc.org/libc v1.22.3/go.mod h1:MQrloYP209xa2zHome2a8HLiLm6k0UT8CoHpV74tOFw= +modernc.org/mathutil v1.5.0 h1:rV0Ko/6SfM+8G+yKiyI830l3Wuz1zRutdslNoQ0kfiQ= +modernc.org/mathutil v1.5.0/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= +modernc.org/memory v1.5.0 h1:N+/8c5rE6EqugZwHii4IFsaJ7MUhoWX07J5tC/iI5Ds= +modernc.org/memory v1.5.0/go.mod h1:PkUhL0Mugw21sHPeskwZW4D6VscE/GQJOnIpCnW6pSU= +modernc.org/sqlite v1.21.1 h1:GyDFqNnESLOhwwDRaHGdp2jKLDzpyT/rNLglX3ZkMSU= +modernc.org/sqlite v1.21.1/go.mod h1:XwQ0wZPIh1iKb5mkvCJ3szzbhk+tykC8ZWqTRTgYRwI= moul.io/http2curl/v2 v2.3.0 h1:9r3JfDzWPcbIklMOs2TnIFzDYvfAZvjeavG6EzP7jYs= moul.io/http2curl/v2 v2.3.0/go.mod h1:RW4hyBjTWSYDOxapodpNEtX0g5Eb16sxklBqmd2RHcE= diff --git a/sentrysql/conn.go b/sentrysql/conn.go new file mode 100644 index 00000000..e04cef3a --- /dev/null +++ b/sentrysql/conn.go @@ -0,0 +1,249 @@ +package sentrysql + +import ( + "context" + "database/sql/driver" + + "github.com/getsentry/sentry-go" +) + +// sentryConn wraps the original driver.Conn. +// As per the driver's documentation: +// - All Conn implementations should implement the following interfaces: +// Pinger, SessionResetter, and Validator. +// - If named parameters or context are supported, the driver's Conn should +// implement: ExecerContext, QueryerContext, ConnPrepareContext, +// and ConnBeginTx. +// +// On this specific Sentry wrapper, we are not going to implement the Validator +// interface because it does not support ErrSkip, since returning ErrSkip +// is only possible when it's explicitly stated on the driver documentation. +type sentryConn struct { + originalConn driver.Conn + ctx context.Context + config *sentrySQLConfig +} + +// Make sure that sentryConn implements the driver.Conn interface. +var _ driver.Conn = (*sentryConn)(nil) +var _ driver.Pinger = (*sentryConn)(nil) +var _ driver.SessionResetter = (*sentryConn)(nil) +var _ driver.Validator = (*sentryConn)(nil) +var _ driver.ExecerContext = (*sentryConn)(nil) +var _ driver.QueryerContext = (*sentryConn)(nil) +var _ driver.ConnPrepareContext = (*sentryConn)(nil) +var _ driver.ConnBeginTx = (*sentryConn)(nil) +var _ driver.NamedValueChecker = (*sentryConn)(nil) + +func (s *sentryConn) Prepare(query string) (driver.Stmt, error) { + stmt, err := s.originalConn.Prepare(query) + if err != nil { + return nil, err + } + + return &sentryStmt{ + originalStmt: stmt, + query: query, + ctx: s.ctx, + config: s.config, + }, nil +} + +func (s *sentryConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + // should only be executed if the original driver implements ConnPrepareContext + connPrepareContext, ok := s.originalConn.(driver.ConnPrepareContext) + if !ok { + // We can't return driver.ErrSkip here. We should fall back to Prepare without context. + return s.Prepare(query) + } + + stmt, err := connPrepareContext.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + + return &sentryStmt{ + originalStmt: stmt, + query: query, + ctx: ctx, + config: s.config, + }, nil +} + +func (s *sentryConn) Close() error { + return s.originalConn.Close() +} + +func (s *sentryConn) Begin() (driver.Tx, error) { + tx, err := s.originalConn.Begin() //nolint:staticcheck // We must support legacy clients + if err != nil { + return nil, err + } + + return &sentryTx{originalTx: tx, ctx: s.ctx, config: s.config}, nil +} + +func (s *sentryConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + // should only be executed if the original driver implements ConnBeginTx + connBeginTx, ok := s.originalConn.(driver.ConnBeginTx) + if !ok { + // We can't return driver.ErrSkip here. We should fall back to Begin without context. + return s.Begin() + } + + tx, err := connBeginTx.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + + return &sentryTx{originalTx: tx, ctx: s.ctx, config: s.config}, nil +} + +//nolint:dupl +func (s *sentryConn) Query(query string, args []driver.Value) (driver.Rows, error) { + // should only be executed if the original driver implements Queryer + queryer, ok := s.originalConn.(driver.Queryer) //nolint:staticcheck // We must support legacy clients + if !ok { + return nil, driver.ErrSkip + } + + parentSpan := sentry.SpanFromContext(s.ctx) + if parentSpan == nil { + return queryer.Query(query, args) + } + + span := parentSpan.StartChild("db.sql.query", sentry.WithDescription(query)) + s.config.SetData(span, query) + defer span.Finish() + + rows, err := queryer.Query(query, args) + if err != nil { + span.Status = sentry.SpanStatusInternalError + return nil, err + } + + span.Status = sentry.SpanStatusOK + return rows, nil +} + +//nolint:dupl +func (s *sentryConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + // should only be executed if the original driver implements QueryerContext + queryerContext, ok := s.originalConn.(driver.QueryerContext) + if !ok { + return nil, driver.ErrSkip + } + + parentSpan := sentry.SpanFromContext(ctx) + if parentSpan == nil { + return queryerContext.QueryContext(ctx, query, args) + } + + span := parentSpan.StartChild("db.sql.query", sentry.WithDescription(query)) + s.config.SetData(span, query) + defer span.Finish() + + rows, err := queryerContext.QueryContext(ctx, query, args) + if err != nil { + span.Status = sentry.SpanStatusInternalError + return nil, err + } + + span.Status = sentry.SpanStatusOK + return rows, nil +} + +//nolint:dupl +func (s *sentryConn) Exec(query string, args []driver.Value) (driver.Result, error) { + // should only be executed if the original driver implements Execer + execer, ok := s.originalConn.(driver.Execer) //nolint:staticcheck // We must support legacy clients + if !ok { + return nil, driver.ErrSkip + } + + parentSpan := sentry.SpanFromContext(s.ctx) + if parentSpan == nil { + return execer.Exec(query, args) + } + + span := parentSpan.StartChild("db.sql.exec", sentry.WithDescription(query)) + s.config.SetData(span, query) + defer span.Finish() + + rows, err := execer.Exec(query, args) + if err != nil { + span.Status = sentry.SpanStatusInternalError + return nil, err + } + + span.Status = sentry.SpanStatusOK + return rows, nil +} + +//nolint:dupl +func (s *sentryConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + // should only be executed if the original driver implements ExecerContext { + execerContext, ok := s.originalConn.(driver.ExecerContext) + if !ok { + // ExecContext may return ErrSkip. + return nil, driver.ErrSkip + } + + parentSpan := sentry.SpanFromContext(ctx) + if parentSpan == nil { + return execerContext.ExecContext(ctx, query, args) + } + + span := parentSpan.StartChild("db.sql.exec", sentry.WithDescription(query)) + s.config.SetData(span, query) + defer span.Finish() + + rows, err := execerContext.ExecContext(ctx, query, args) + if err != nil { + span.Status = sentry.SpanStatusInternalError + return nil, err + } + + span.Status = sentry.SpanStatusOK + return rows, nil +} + +func (s *sentryConn) Ping(ctx context.Context) error { + pinger, ok := s.originalConn.(driver.Pinger) + if !ok { + // We may not return ErrSkip. We should return nil. + return nil + } + + return pinger.Ping(ctx) +} + +func (s *sentryConn) ResetSession(ctx context.Context) error { + sessionResetter, ok := s.originalConn.(driver.SessionResetter) + if !ok { + // We may not return ErrSkip. We should return nil. + return nil + } + + return sessionResetter.ResetSession(ctx) +} + +func (s *sentryConn) CheckNamedValue(namedValue *driver.NamedValue) error { + namedValueChecker, ok := s.originalConn.(driver.NamedValueChecker) + if !ok { + // We may return ErrSkip. + return driver.ErrSkip + } + + return namedValueChecker.CheckNamedValue(namedValue) +} + +// IsValid implements driver.Validator. +func (s *sentryConn) IsValid() bool { + validator, ok := s.originalConn.(driver.Validator) + if !ok { + return true + } + + return validator.IsValid() +} diff --git a/sentrysql/driver.go b/sentrysql/driver.go new file mode 100644 index 00000000..6808d2bb --- /dev/null +++ b/sentrysql/driver.go @@ -0,0 +1,96 @@ +package sentrysql + +import ( + "context" + "database/sql/driver" + "io" +) + +// sentrySQLDriver wraps the original driver.Driver. +// As per the driver's documentation: +// Drivers should implement driver.Connector and driver.DriverContext interfaces. +type sentrySQLDriver struct { + originalDriver driver.Driver + config *sentrySQLConfig +} + +// Make sure that sentrySQLDriver implements the driver.Driver interface. +var _ driver.Driver = (*sentrySQLDriver)(nil) +var _ driver.DriverContext = (*sentrySQLDriver)(nil) + +func (s *sentrySQLDriver) OpenConnector(name string) (driver.Connector, error) { + driverContext, ok := s.originalDriver.(driver.DriverContext) + if !ok { + return &sentrySQLConnector{ + originalConnector: dsnConnector{dsn: name, driver: s.originalDriver, config: s.config}, + config: s.config, + }, nil + } + + connector, err := driverContext.OpenConnector(name) + if err != nil { + return nil, err + } + + return &sentrySQLConnector{originalConnector: connector, config: s.config}, nil +} + +func (s *sentrySQLDriver) Open(name string) (driver.Conn, error) { + conn, err := s.originalDriver.Open(name) + if err != nil { + return nil, err + } + + return &sentryConn{originalConn: conn, config: s.config}, nil +} + +type sentrySQLConnector struct { + originalConnector driver.Connector + config *sentrySQLConfig +} + +// Make sure that sentrySQLConnector implements the driver.Connector interface. +var _ driver.Connector = (*sentrySQLConnector)(nil) +var _ io.Closer = (*sentrySQLConnector)(nil) + +func (s *sentrySQLConnector) Connect(ctx context.Context) (driver.Conn, error) { + conn, err := s.originalConnector.Connect(ctx) + if err != nil { + return nil, err + } + + return &sentryConn{originalConn: conn, ctx: ctx, config: s.config}, nil +} + +func (s *sentrySQLConnector) Driver() driver.Driver { + return s.originalConnector.Driver() +} + +func (s *sentrySQLConnector) Close() error { + // driver.Connector should optionally implements io.Closer + closer, ok := s.originalConnector.(io.Closer) + if !ok { + return nil + } + + return closer.Close() +} + +// dsnConnector is copied from +// https://cs.opensource.google/go/go/+/refs/tags/go1.23.2:src/database/sql/sql.go;l=795-806 +type dsnConnector struct { + dsn string + driver driver.Driver + config *sentrySQLConfig +} + +// Make sure dsnConnector implements driver.Connector. +var _ driver.Connector = (*dsnConnector)(nil) + +func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) { + return t.driver.Open(t.dsn) +} + +func (t dsnConnector) Driver() driver.Driver { + return t.driver +} diff --git a/sentrysql/example_test.go b/sentrysql/example_test.go new file mode 100644 index 00000000..4a0907af --- /dev/null +++ b/sentrysql/example_test.go @@ -0,0 +1,110 @@ +package sentrysql_test + +import ( + "database/sql" + "fmt" + "net" + + "github.com/getsentry/sentry-go/sentrysql" + sqlite "github.com/glebarez/go-sqlite" + "github.com/go-sql-driver/mysql" + "github.com/lib/pq" +) + +func ExampleNewSentrySQL() { + sql.Register("sentrysql-sqlite", sentrysql.NewSentrySQL( + &sqlite.Driver{}, + sentrysql.WithDatabaseName(":memory:"), + sentrysql.WithDatabaseSystem(sentrysql.DatabaseSystem("sqlite")), + )) + + db, err := sql.Open("sentrysql-sqlite", ":memory:") + if err != nil { + panic(err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE test (id INT)") + if err != nil { + panic(err) + } + + _, err = db.Exec("INSERT INTO test (id) VALUES (1)") + if err != nil { + panic(err) + } + + rows, err := db.Query("SELECT * FROM test") + if err != nil { + panic(err) + } + defer rows.Close() + + for rows.Next() { + var id int + err = rows.Scan(&id) + if err != nil { + panic(err) + } + + fmt.Println(id) + } +} + +func ExampleNewSentrySQLConnector_postgres() { + // Create a new PostgreSQL connector that utilizes the `github.com/lib/pq` package. + pqConnector, err := pq.NewConnector("postgres://user:password@localhost:5432/db") + if err != nil { + fmt.Println("creating postgres connector:", err.Error()) + return + } + + // `db` here is an instance of *sql.DB. + db := sql.OpenDB(sentrysql.NewSentrySQLConnector( + pqConnector, + sentrysql.WithDatabaseName("db"), + sentrysql.WithDatabaseSystem(sentrysql.PostgreSQL), + sentrysql.WithServerAddress("localhost", "5432"), + )) + defer func() { + err := db.Close() + if err != nil { + fmt.Println("closing postgres connection:", err.Error()) + } + }() + + // Use the db connection as usual. +} + +func ExampleNewSentrySQLConnector_mysql() { + // Create a new MySQL connector that utilizes the `github.com/go-sql-driver/mysql` package. + config, err := mysql.ParseDSN("user:password@tcp(localhost:3306)/test?parseTime=true") + if err != nil { + fmt.Println("parsing mysql dsn:", err.Error()) + return + } + + mysqlHost, mysqlPort, _ := net.SplitHostPort(config.Addr) + + connector, err := mysql.NewConnector(config) + if err != nil { + fmt.Println("creating mysql connector:", err.Error()) + return + } + + // `db` here is an instance of *sql.DB. + db := sql.OpenDB(sentrysql.NewSentrySQLConnector( + connector, + sentrysql.WithDatabaseName(config.DBName), + sentrysql.WithDatabaseSystem(sentrysql.MySQL), + sentrysql.WithServerAddress(mysqlHost, mysqlPort), + )) + defer func() { + err := db.Close() + if err != nil { + fmt.Println("closing mysql connection:", err.Error()) + } + }() + + // Use the db connection as usual. +} diff --git a/sentrysql/fakedb_test.go b/sentrysql/fakedb_test.go new file mode 100644 index 00000000..c431e2bc --- /dev/null +++ b/sentrysql/fakedb_test.go @@ -0,0 +1,1286 @@ +//nolint:all +package sentrysql_test + +// This file is a fork of +// https://cs.opensource.google/go/go/+/refs/tags/go1.23.2:src/database/sql/fakedb_test.go +// +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// Copyright (c) 2009 The Go Authors. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "io" + "reflect" + "strconv" + "strings" + "sync" + "testing" + "time" +) // fakeDriver is a fake database that implements Go's driver.Driver +// interface, just for testing. +// +// It speaks a query language that's semantically similar to but +// syntactically different and simpler than SQL. The syntax is as +// follows: +// +// WIPE +// CREATE||=,=,... +// where types are: "string", [u]int{8,16,32,64}, "bool" +// INSERT||col=val,col2=val2,col3=? +// SELECT||projectcol1,projectcol2|filtercol=?,filtercol2=? +// SELECT||projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2 +// +// Any of these can be preceded by PANIC||, to cause the +// named method on fakeStmt to panic. +// +// Any of these can be proceeded by WAIT||, to cause the +// named method on fakeStmt to sleep for the specified duration. +// +// Multiple of these can be combined when separated with a semicolon. +// +// When opening a fakeDriver's database, it starts empty with no +// tables. All tables and data are stored in memory only. +type fakeDriver struct { + mu sync.Mutex // guards 3 following fields + openCount int // conn opens + closeCount int // conn closes + waitCh chan struct{} + waitingCh chan struct{} + dbs map[string]*fakeDB +} + +type fakeConnector struct { + name string + + waiter func(context.Context) + closed bool +} + +func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { + conn, err := fdriver.Open(c.name) + conn.(*fakeConn).waiter = c.waiter + return conn, err +} + +func (c *fakeConnector) Driver() driver.Driver { + return fdriver +} + +func (c *fakeConnector) Close() error { + if c.closed { + return errors.New("fakedb: connector is closed") + } + c.closed = true + return nil +} + +type fakeDriverCtx struct { + fakeDriver +} + +var _ driver.DriverContext = &fakeDriverCtx{} + +func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { + return &fakeConnector{name: name}, nil +} + +type fakeDB struct { + name string + + mu sync.Mutex + tables map[string]*table + badConn bool + allowAny bool +} + +type fakeError struct { + Message string + Wrapped error +} + +func (err fakeError) Error() string { + return err.Message +} + +func (err fakeError) Unwrap() error { + return err.Wrapped +} + +type table struct { + mu sync.Mutex + colname []string + coltype []string + rows []*row +} + +func (t *table) columnIndex(name string) int { + for n, nname := range t.colname { + if name == nname { + return n + } + } + return -1 +} + +type row struct { + cols []any // must be same size as its table colname + coltype +} + +type memToucher interface { + // touchMem reads & writes some memory, to help find data races. + touchMem() +} + +type fakeConn struct { + db *fakeDB // where to return ourselves to + + currTx *fakeTx + + // Every operation writes to line to enable the race detector + // check for data races. + line int64 + + // Stats for tests: + mu sync.Mutex + stmtsMade int + stmtsClosed int + numPrepare int + + // bad connection tests; see isBad() + bad bool + stickyBad bool + + skipDirtySession bool // tests that use Conn should set this to true. + + // dirtySession tests ResetSession, true if a query has executed + // until ResetSession is called. + dirtySession bool + + // The waiter is called before each query. May be used in place of the "WAIT" + // directive. + waiter func(context.Context) +} + +func (c *fakeConn) touchMem() { + c.line++ +} + +func (c *fakeConn) incrStat(v *int) { + c.mu.Lock() + *v++ + c.mu.Unlock() +} + +type fakeTx struct { + c *fakeConn +} + +type boundCol struct { + Column string + Placeholder string + Ordinal int +} + +type fakeStmt struct { + memToucher + c *fakeConn + q string // just for debugging + + cmd string + table string + panic string + wait time.Duration + + next *fakeStmt // used for returning multiple results. + + closed bool + + colName []string // used by CREATE, INSERT, SELECT (selected columns) + colType []string // used by CREATE + colValue []any // used by INSERT (mix of strings and "?" for bound params) + placeholders int // used by INSERT/SELECT: number of ? params + + whereCol []boundCol // used by SELECT (all placeholders) + + placeholderConverter []driver.ValueConverter // used by INSERT +} + +var fdriver driver.Driver = &fakeDriver{} + +func contains(list []string, y string) bool { + for _, x := range list { + if x == y { + return true + } + } + return false +} + +type Dummy struct { + driver.Driver +} + +// hook to simulate connection failures +var hookOpenErr struct { + sync.Mutex + fn func() error +} + +func setHookOpenErr(fn func() error) { + hookOpenErr.Lock() + defer hookOpenErr.Unlock() + hookOpenErr.fn = fn +} + +// Supports dsn forms: +// +// +// ; (only currently supported option is `badConn`, +// which causes driver.ErrBadConn to be returned on +// every other conn.Begin()) +func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { + hookOpenErr.Lock() + fn := hookOpenErr.fn + hookOpenErr.Unlock() + if fn != nil { + if err := fn(); err != nil { + return nil, err + } + } + parts := strings.Split(dsn, ";") + if len(parts) < 1 { + return nil, errors.New("fakedb: no database name") + } + name := parts[0] + + db := d.getDB(name) + + d.mu.Lock() + d.openCount++ + d.mu.Unlock() + conn := &fakeConn{db: db} + + if len(parts) >= 2 && parts[1] == "badConn" { + conn.bad = true + } + if d.waitCh != nil { + d.waitingCh <- struct{}{} + <-d.waitCh + d.waitCh = nil + d.waitingCh = nil + } + return conn, nil +} + +func (d *fakeDriver) getDB(name string) *fakeDB { + d.mu.Lock() + defer d.mu.Unlock() + if d.dbs == nil { + d.dbs = make(map[string]*fakeDB) + } + db, ok := d.dbs[name] + if !ok { + db = &fakeDB{name: name} + d.dbs[name] = db + } + return db +} + +func (db *fakeDB) wipe() { + db.mu.Lock() + defer db.mu.Unlock() + db.tables = nil +} + +func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { + db.mu.Lock() + defer db.mu.Unlock() + if db.tables == nil { + db.tables = make(map[string]*table) + } + if _, exist := db.tables[name]; exist { + return fmt.Errorf("fakedb: table %q already exists", name) + } + if len(columnNames) != len(columnTypes) { + return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d", + name, len(columnNames), len(columnTypes)) + } + db.tables[name] = &table{colname: columnNames, coltype: columnTypes} + return nil +} + +// must be called with db.mu lock held +func (db *fakeDB) table(table string) (*table, bool) { + if db.tables == nil { + return nil, false + } + t, ok := db.tables[table] + return t, ok +} + +func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { + db.mu.Lock() + defer db.mu.Unlock() + t, ok := db.table(table) + if !ok { + return + } + for n, cname := range t.colname { + if cname == column { + return t.coltype[n], true + } + } + return "", false +} + +func (c *fakeConn) isBad() bool { + if c.stickyBad { + return true + } else if c.bad { + if c.db == nil { + return false + } + // alternate between bad conn and not bad conn + c.db.badConn = !c.db.badConn + return c.db.badConn + } else { + return false + } +} + +func (c *fakeConn) isDirtyAndMark() bool { + if c.skipDirtySession { + return false + } + if c.currTx != nil { + c.dirtySession = true + return false + } + if c.dirtySession { + return true + } + c.dirtySession = true + return false +} + +func (c *fakeConn) Begin() (driver.Tx, error) { + if c.isBad() { + return nil, fakeError{Wrapped: driver.ErrBadConn} + } + if c.currTx != nil { + return nil, errors.New("fakedb: already in a transaction") + } + c.touchMem() + c.currTx = &fakeTx{c: c} + return c.currTx, nil +} + +var hookPostCloseConn struct { + sync.Mutex + fn func(*fakeConn, error) +} + +func setHookpostCloseConn(fn func(*fakeConn, error)) { + hookPostCloseConn.Lock() + defer hookPostCloseConn.Unlock() + hookPostCloseConn.fn = fn +} + +var testStrictClose *testing.T + +// setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close +// fails to close. If nil, the check is disabled. +func setStrictFakeConnClose(t *testing.T) { + testStrictClose = t +} + +func (c *fakeConn) ResetSession(ctx context.Context) error { + c.dirtySession = false + c.currTx = nil + if c.isBad() { + return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn} + } + return nil +} + +var _ driver.Validator = (*fakeConn)(nil) + +func (c *fakeConn) IsValid() bool { + return !c.isBad() +} + +func (c *fakeConn) Close() (err error) { + drv := fdriver.(*fakeDriver) + defer func() { + if err != nil && testStrictClose != nil { + testStrictClose.Errorf("failed to close a test fakeConn: %v", err) + } + hookPostCloseConn.Lock() + fn := hookPostCloseConn.fn + hookPostCloseConn.Unlock() + if fn != nil { + fn(c, err) + } + if err == nil { + drv.mu.Lock() + drv.closeCount++ + drv.mu.Unlock() + } + }() + c.touchMem() + if c.currTx != nil { + return errors.New("fakedb: can't close fakeConn; in a Transaction") + } + if c.db == nil { + return errors.New("fakedb: can't close fakeConn; already closed") + } + if c.stmtsMade > c.stmtsClosed { + return errors.New("fakedb: can't close; dangling statement(s)") + } + c.db = nil + return nil +} + +func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { + for _, arg := range args { + switch arg.Value.(type) { + case int64, float64, bool, nil, []byte, string, time.Time: + default: + if !allowAny { + return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) + } + } + } + return nil +} + +func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { + // Ensure that ExecContext is called if available. + panic("ExecContext was not called.") +} + +func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + // This is an optional interface, but it's implemented here + // just to check that all the args are of the proper types. + // ErrSkip is returned so the caller acts as if we didn't + // implement this at all. + err := checkSubsetTypes(c.db.allowAny, args) + if err != nil { + return nil, err + } + return nil, driver.ErrSkip +} + +func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { + // Ensure that ExecContext is called if available. + panic("QueryContext was not called.") +} + +func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + // This is an optional interface, but it's implemented here + // just to check that all the args are of the proper types. + // ErrSkip is returned so the caller acts as if we didn't + // implement this at all. + err := checkSubsetTypes(c.db.allowAny, args) + if err != nil { + return nil, err + } + return nil, driver.ErrSkip +} + +func errf(msg string, args ...any) error { + return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) +} + +// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? +// (note that where columns must always contain ? marks, +// +// just a limitation for fakedb) +func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { + if len(parts) != 3 { + stmt.Close() + return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) + } + stmt.table = parts[0] + + stmt.colName = strings.Split(parts[1], ",") + for n, colspec := range strings.Split(parts[2], ",") { + if colspec == "" { + continue + } + nameVal := strings.Split(colspec, "=") + if len(nameVal) != 2 { + stmt.Close() + return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + column, value := nameVal[0], nameVal[1] + _, ok := c.db.columnType(stmt.table, column) + if !ok { + stmt.Close() + return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) + } + if !strings.HasPrefix(value, "?") { + stmt.Close() + return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", + stmt.table, column) + } + stmt.placeholders++ + stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) + } + return stmt, nil +} + +// parts are table|col=type,col2=type2 +func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { + if len(parts) != 2 { + stmt.Close() + return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) + } + stmt.table = parts[0] + for n, colspec := range strings.Split(parts[1], ",") { + nameType := strings.Split(colspec, "=") + if len(nameType) != 2 { + stmt.Close() + return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + stmt.colName = append(stmt.colName, nameType[0]) + stmt.colType = append(stmt.colType, nameType[1]) + } + return stmt, nil +} + +// parts are table|col=?,col2=val +func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) { + if len(parts) != 2 { + stmt.Close() + return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) + } + stmt.table = parts[0] + for n, colspec := range strings.Split(parts[1], ",") { + nameVal := strings.Split(colspec, "=") + if len(nameVal) != 2 { + stmt.Close() + return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + column, value := nameVal[0], nameVal[1] + ctype, ok := c.db.columnType(stmt.table, column) + if !ok { + stmt.Close() + return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) + } + stmt.colName = append(stmt.colName, column) + + if !strings.HasPrefix(value, "?") { + var subsetVal any + // Convert to driver subset type + switch ctype { + case "string": + subsetVal = []byte(value) + case "blob": + subsetVal = []byte(value) + case "int32": + i, err := strconv.Atoi(value) + if err != nil { + stmt.Close() + return nil, errf("invalid conversion to int32 from %q", value) + } + subsetVal = int64(i) // int64 is a subset type, but not int32 + case "table": // For testing cursor reads. + c.skipDirtySession = true + vparts := strings.Split(value, "!") + + substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ","))) + if err != nil { + return nil, err + } + cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{}) + substmt.Close() + if err != nil { + return nil, err + } + subsetVal = cursor + default: + stmt.Close() + return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) + } + stmt.colValue = append(stmt.colValue, subsetVal) + } else { + stmt.placeholders++ + stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) + stmt.colValue = append(stmt.colValue, value) + } + } + return stmt, nil +} + +// hook to simulate broken connections +var hookPrepareBadConn func() bool + +func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { + panic("use PrepareContext") +} + +func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + c.numPrepare++ + if c.db == nil { + panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) + } + + if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { + return nil, fakeError{Message: "Preapre: Sticky Bad", Wrapped: driver.ErrBadConn} + } + + c.touchMem() + var firstStmt, prev *fakeStmt + for _, query := range strings.Split(query, ";") { + parts := strings.Split(query, "|") + if len(parts) < 1 { + return nil, errf("empty query") + } + stmt := &fakeStmt{q: query, c: c, memToucher: c} + if firstStmt == nil { + firstStmt = stmt + } + if len(parts) >= 3 { + switch parts[0] { + case "PANIC": + stmt.panic = parts[1] + parts = parts[2:] + case "WAIT": + wait, err := time.ParseDuration(parts[1]) + if err != nil { + return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) + } + parts = parts[2:] + stmt.wait = wait + } + } + cmd := parts[0] + stmt.cmd = cmd + parts = parts[1:] + + if c.waiter != nil { + c.waiter(ctx) + if err := ctx.Err(); err != nil { + return nil, err + } + } + + if stmt.wait > 0 { + wait := time.NewTimer(stmt.wait) + select { + case <-wait.C: + case <-ctx.Done(): + wait.Stop() + return nil, ctx.Err() + } + } + + c.incrStat(&c.stmtsMade) + var err error + switch cmd { + case "WIPE": + // Nothing + case "SELECT": + stmt, err = c.prepareSelect(stmt, parts) + case "CREATE": + stmt, err = c.prepareCreate(stmt, parts) + case "INSERT": + stmt, err = c.prepareInsert(ctx, stmt, parts) + case "NOSERT": + // Do all the prep-work like for an INSERT but don't actually insert the row. + // Used for some of the concurrent tests. + stmt, err = c.prepareInsert(ctx, stmt, parts) + default: + stmt.Close() + return nil, errf("unsupported command type %q", cmd) + } + if err != nil { + return nil, err + } + if prev != nil { + prev.next = stmt + } + prev = stmt + } + return firstStmt, nil +} + +func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { + if s.panic == "ColumnConverter" { + panic(s.panic) + } + if len(s.placeholderConverter) == 0 { + return driver.DefaultParameterConverter + } + return s.placeholderConverter[idx] +} + +func (s *fakeStmt) Close() error { + if s.panic == "Close" { + panic(s.panic) + } + if s.c == nil { + panic("nil conn in fakeStmt.Close") + } + if s.c.db == nil { + panic("in fakeStmt.Close, conn's db is nil (already closed)") + } + s.touchMem() + if !s.closed { + s.c.incrStat(&s.c.stmtsClosed) + s.closed = true + } + if s.next != nil { + s.next.Close() + } + return nil +} + +var errClosed = errors.New("fakedb: statement has been closed") + +// hook to simulate broken connections +var hookExecBadConn func() bool + +func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { + panic("Using ExecContext") +} + +var errFakeConnSessionDirty = errors.New("fakedb: session is dirty") + +func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + if s.panic == "Exec" { + panic(s.panic) + } + if s.closed { + return nil, errClosed + } + + if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { + return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn} + } + if s.c.isDirtyAndMark() { + return nil, errFakeConnSessionDirty + } + + err := checkSubsetTypes(s.c.db.allowAny, args) + if err != nil { + return nil, err + } + s.touchMem() + + if s.wait > 0 { + time.Sleep(s.wait) + } + + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + + db := s.c.db + switch s.cmd { + case "WIPE": + db.wipe() + return driver.ResultNoRows, nil + case "CREATE": + if err := db.createTable(s.table, s.colName, s.colType); err != nil { + return nil, err + } + return driver.ResultNoRows, nil + case "INSERT": + return s.execInsert(args, true) + case "NOSERT": + // Do all the prep-work like for an INSERT but don't actually insert the row. + // Used for some of the concurrent tests. + return s.execInsert(args, false) + } + return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd) +} + +// When doInsert is true, add the row to the table. +// When doInsert is false do prep-work and error checking, but don't +// actually add the row to the table. +func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { + db := s.c.db + if len(args) != s.placeholders { + panic("error in pkg db; should only get here if size is correct") + } + db.mu.Lock() + t, ok := db.table(s.table) + db.mu.Unlock() + if !ok { + return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) + } + + t.mu.Lock() + defer t.mu.Unlock() + + var cols []any + if doInsert { + cols = make([]any, len(t.colname)) + } + argPos := 0 + for n, colname := range s.colName { + colidx := t.columnIndex(colname) + if colidx == -1 { + return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) + } + var val any + if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { + if strvalue == "?" { + val = args[argPos].Value + } else { + // Assign value from argument placeholder name. + for _, a := range args { + if a.Name == strvalue[1:] { + val = a.Value + break + } + } + } + argPos++ + } else { + val = s.colValue[n] + } + if doInsert { + cols[colidx] = val + } + } + + if doInsert { + t.rows = append(t.rows, &row{cols: cols}) + } + return driver.RowsAffected(1), nil +} + +// hook to simulate broken connections +var hookQueryBadConn func() bool + +func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { + panic("Use QueryContext") +} + +func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + if s.panic == "Query" { + panic(s.panic) + } + if s.closed { + return nil, errClosed + } + + if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { + return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn} + } + if s.c.isDirtyAndMark() { + return nil, errFakeConnSessionDirty + } + + err := checkSubsetTypes(s.c.db.allowAny, args) + if err != nil { + return nil, err + } + + s.touchMem() + db := s.c.db + if len(args) != s.placeholders { + panic("error in pkg db; should only get here if size is correct") + } + + setMRows := make([][]*row, 0, 1) + setColumns := make([][]string, 0, 1) + setColType := make([][]string, 0, 1) + + for { + db.mu.Lock() + t, ok := db.table(s.table) + db.mu.Unlock() + if !ok { + return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) + } + + if s.table == "magicquery" { + if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { + if args[0].Value == "sleep" { + time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) + } + } + } + if s.table == "tx_status" && s.colName[0] == "tx_status" { + txStatus := "autocommit" + if s.c.currTx != nil { + txStatus = "transaction" + } + cursor := &rowsCursor{ + parentMem: s.c, + posRow: -1, + rows: [][]*row{ + { + { + cols: []any{ + txStatus, + }, + }, + }, + }, + cols: [][]string{ + { + "tx_status", + }, + }, + colType: [][]string{ + { + "string", + }, + }, + errPos: -1, + } + return cursor, nil + } + + t.mu.Lock() + + colIdx := make(map[string]int) // select column name -> column index in table + for _, name := range s.colName { + idx := t.columnIndex(name) + if idx == -1 { + t.mu.Unlock() + return nil, fmt.Errorf("fakedb: unknown column name %q", name) + } + colIdx[name] = idx + } + + mrows := []*row{} + rows: + for _, trow := range t.rows { + // Process the where clause, skipping non-match rows. This is lazy + // and just uses fmt.Sprintf("%v") to test equality. Good enough + // for test code. + for _, wcol := range s.whereCol { + idx := t.columnIndex(wcol.Column) + if idx == -1 { + t.mu.Unlock() + return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol) + } + tcol := trow.cols[idx] + if bs, ok := tcol.([]byte); ok { + // lazy hack to avoid sprintf %v on a []byte + tcol = string(bs) + } + var argValue any + if wcol.Placeholder == "?" { + argValue = args[wcol.Ordinal-1].Value + } else { + // Assign arg value from placeholder name. + for _, a := range args { + if a.Name == wcol.Placeholder[1:] { + argValue = a.Value + break + } + } + } + if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { + continue rows + } + } + mrow := &row{cols: make([]any, len(s.colName))} + for seli, name := range s.colName { + mrow.cols[seli] = trow.cols[colIdx[name]] + } + mrows = append(mrows, mrow) + } + + var colType []string + for _, column := range s.colName { + colType = append(colType, t.coltype[t.columnIndex(column)]) + } + + t.mu.Unlock() + + setMRows = append(setMRows, mrows) + setColumns = append(setColumns, s.colName) + setColType = append(setColType, colType) + + if s.next == nil { + break + } + s = s.next + } + + cursor := &rowsCursor{ + parentMem: s.c, + posRow: -1, + rows: setMRows, + cols: setColumns, + colType: setColType, + errPos: -1, + } + return cursor, nil +} + +func (s *fakeStmt) NumInput() int { + if s.panic == "NumInput" { + panic(s.panic) + } + return s.placeholders +} + +// hook to simulate broken connections +var hookCommitBadConn func() bool + +func (tx *fakeTx) Commit() error { + tx.c.currTx = nil + if hookCommitBadConn != nil && hookCommitBadConn() { + return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn} + } + tx.c.touchMem() + return nil +} + +// hook to simulate broken connections +var hookRollbackBadConn func() bool + +func (tx *fakeTx) Rollback() error { + tx.c.currTx = nil + if hookRollbackBadConn != nil && hookRollbackBadConn() { + return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn} + } + tx.c.touchMem() + return nil +} + +type rowsCursor struct { + parentMem memToucher + cols [][]string + colType [][]string + posSet int + posRow int + rows [][]*row + closed bool + + // errPos and err are for making Next return early with error. + errPos int + err error + + // a clone of slices to give out to clients, indexed by the + // original slice's first byte address. we clone them + // just so we're able to corrupt them on close. + bytesClone map[*byte][]byte + + // Every operation writes to line to enable the race detector + // check for data races. + // This is separate from the fakeConn.line to allow for drivers that + // can start multiple queries on the same transaction at the same time. + line int64 +} + +func (rc *rowsCursor) touchMem() { + rc.parentMem.touchMem() + rc.line++ +} + +func (rc *rowsCursor) Close() error { + rc.touchMem() + rc.parentMem.touchMem() + rc.closed = true + return nil +} + +func (rc *rowsCursor) Columns() []string { + return rc.cols[rc.posSet] +} + +func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { + return colTypeToReflectType(rc.colType[rc.posSet][index]) +} + +var rowsCursorNextHook func(dest []driver.Value) error + +func (rc *rowsCursor) Next(dest []driver.Value) error { + if rowsCursorNextHook != nil { + return rowsCursorNextHook(dest) + } + + if rc.closed { + return errors.New("fakedb: cursor is closed") + } + rc.touchMem() + rc.posRow++ + if rc.posRow == rc.errPos { + return rc.err + } + if rc.posRow >= len(rc.rows[rc.posSet]) { + return io.EOF // per interface spec + } + for i, v := range rc.rows[rc.posSet][rc.posRow].cols { + // TODO(bradfitz): convert to subset types? naah, I + // think the subset types should only be input to + // driver, but the sql package should be able to handle + // a wider range of types coming out of drivers. all + // for ease of drivers, and to prevent drivers from + // messing up conversions or doing them differently. + dest[i] = v + + if bs, ok := v.([]byte); ok { + if rc.bytesClone == nil { + rc.bytesClone = make(map[*byte][]byte) + } + clone, ok := rc.bytesClone[&bs[0]] + if !ok { + clone = make([]byte, len(bs)) + copy(clone, bs) + rc.bytesClone[&bs[0]] = clone + } + dest[i] = clone + } + } + return nil +} + +func (rc *rowsCursor) HasNextResultSet() bool { + rc.touchMem() + return rc.posSet < len(rc.rows)-1 +} + +func (rc *rowsCursor) NextResultSet() error { + rc.touchMem() + if rc.HasNextResultSet() { + rc.posSet++ + rc.posRow = -1 + return nil + } + return io.EOF // Per interface spec. +} + +// fakeDriverString is like driver.String, but indirects pointers like +// DefaultValueConverter. +// +// This could be surprising behavior to retroactively apply to +// driver.String now that Go1 is out, but this is convenient for +// our TestPointerParamsAndScans. +type fakeDriverString struct{} + +func (fakeDriverString) ConvertValue(v any) (driver.Value, error) { + switch c := v.(type) { + case string, []byte: + return v, nil + case *string: + if c == nil { + return nil, nil + } + return *c, nil + } + return fmt.Sprintf("%v", v), nil +} + +type anyTypeConverter struct{} + +func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) { + return v, nil +} + +func converterForType(typ string) driver.ValueConverter { + switch typ { + case "bool": + return driver.Bool + case "nullbool": + return driver.Null{Converter: driver.Bool} + case "byte", "int16": + return driver.NotNull{Converter: driver.DefaultParameterConverter} + case "int32": + return driver.Int32 + case "nullbyte", "nullint32", "nullint16": + return driver.Null{Converter: driver.DefaultParameterConverter} + case "string": + return driver.NotNull{Converter: fakeDriverString{}} + case "nullstring": + return driver.Null{Converter: fakeDriverString{}} + case "int64": + // TODO(coopernurse): add type-specific converter + return driver.NotNull{Converter: driver.DefaultParameterConverter} + case "nullint64": + // TODO(coopernurse): add type-specific converter + return driver.Null{Converter: driver.DefaultParameterConverter} + case "float64": + // TODO(coopernurse): add type-specific converter + return driver.NotNull{Converter: driver.DefaultParameterConverter} + case "nullfloat64": + // TODO(coopernurse): add type-specific converter + return driver.Null{Converter: driver.DefaultParameterConverter} + case "datetime": + return driver.NotNull{Converter: driver.DefaultParameterConverter} + case "nulldatetime": + return driver.Null{Converter: driver.DefaultParameterConverter} + case "any": + return anyTypeConverter{} + } + panic("invalid fakedb column type of " + typ) +} + +func colTypeToReflectType(typ string) reflect.Type { + switch typ { + case "bool": + return reflect.TypeOf(false) + case "nullbool": + return reflect.TypeOf(sql.NullBool{}) + case "int16": + return reflect.TypeOf(int16(0)) + case "nullint16": + return reflect.TypeOf(sql.NullInt16{}) + case "int32": + return reflect.TypeOf(int32(0)) + case "nullint32": + return reflect.TypeOf(sql.NullInt32{}) + case "string": + return reflect.TypeOf("") + case "nullstring": + return reflect.TypeOf(sql.NullString{}) + case "int64": + return reflect.TypeOf(int64(0)) + case "nullint64": + return reflect.TypeOf(sql.NullInt64{}) + case "float64": + return reflect.TypeOf(float64(0)) + case "nullfloat64": + return reflect.TypeOf(sql.NullFloat64{}) + case "datetime": + return reflect.TypeOf(time.Time{}) + case "any": + return reflect.TypeOf(new(any)).Elem() + } + panic("invalid fakedb column type of " + typ) +} diff --git a/sentrysql/legacydb_test.go b/sentrysql/legacydb_test.go new file mode 100644 index 00000000..96e7b898 --- /dev/null +++ b/sentrysql/legacydb_test.go @@ -0,0 +1,835 @@ +//nolint:all +package sentrysql_test + +// This file is a fork of +// https://cs.opensource.google/go/go/+/refs/tags/go1.7.6:src/database/sql/fakedb_test.go +// +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// Copyright (c) 2009 The Go Authors. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// - Redistributions of source code must retain the above copyright +// +// notice, this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above +// +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// - Neither the name of Google Inc. nor the names of its +// +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import ( + "database/sql/driver" + "errors" + "fmt" + "io" + "log" + "strconv" + "strings" + "sync" + "time" +) + +var _ = log.Printf + +// legacyDriver is a fake database that implements Go's driver.Driver +// interface, just for testing. +// +// It speaks a query language that's semantically similar to but +// syntactically different and simpler than SQL. The syntax is as +// follows: +// +// WIPE +// CREATE||=,=,... +// where types are: "string", [u]int{8,16,32,64}, "bool" +// INSERT||col=val,col2=val2,col3=? +// SELECT||projectcol1,projectcol2|filtercol=?,filtercol2=? +// +// Any of these can be preceded by PANIC||, to cause the +// named method on fakeStmt to panic. +// +// When opening a fakeDriver's database, it starts empty with no +// tables. All tables and data are stored in memory only. +type legacyDriver struct { + mu sync.Mutex // guards 3 following fields + openCount int // conn opens + closeCount int // conn closes + waitCh chan struct{} + waitingCh chan struct{} + dbs map[string]*legacyDB +} + +type legacyDB struct { + name string + + mu sync.Mutex + legacyTables map[string]*legacyTable + badConn bool +} + +type legacyTable struct { + mu sync.Mutex + colname []string + coltype []string + legacyRows []*legacyRow +} + +func (t *legacyTable) columnIndex(name string) int { + for n, nname := range t.colname { + if name == nname { + return n + } + } + return -1 +} + +type legacyRow struct { + cols []interface{} // must be same size as its legacyTable colname + coltype +} + +type legacyConn struct { + db *legacyDB // where to return ourselves to + + currTx *legacyTx + + // Stats for tests: + mu sync.Mutex + stmtsMade int + stmtsClosed int + numPrepare int + + // bad connection tests; see isBad() + bad bool + stickyBad bool +} + +func (c *legacyConn) incrStat(v *int) { + c.mu.Lock() + *v++ + c.mu.Unlock() +} + +type legacyTx struct { + c *legacyConn +} + +type legacyStmt struct { + c *legacyConn + q string // just for debugging + + cmd string + legacyTable string + panic string + + closed bool + + colName []string // used by CREATE, INSERT, SELECT (selected columns) + colType []string // used by CREATE + colValue []interface{} // used by INSERT (mix of strings and "?" for bound params) + placeholders int // used by INSERT/SELECT: number of ? params + + whereCol []string // used by SELECT (all placeholders) + + placeholderConverter []driver.ValueConverter // used by INSERT +} + +var ldriver driver.Driver = &legacyDriver{} + +// hook to simulate connection failures +var legacyHookOpenErr struct { + sync.Mutex + fn func() error +} + +// Supports dsn forms: +// +// +// ; (only currently supported option is `badConn`, +// which causes driver.ErrBadConn to be returned on +// every other conn.Begin()) +func (d *legacyDriver) Open(dsn string) (driver.Conn, error) { + legacyHookOpenErr.Lock() + fn := legacyHookOpenErr.fn + legacyHookOpenErr.Unlock() + if fn != nil { + if err := fn(); err != nil { + return nil, err + } + } + parts := strings.Split(dsn, ";") + if len(parts) < 1 { + return nil, errors.New("fakedb: no database name") + } + name := parts[0] + + db := d.getDB(name) + + d.mu.Lock() + d.openCount++ + d.mu.Unlock() + conn := &legacyConn{db: db} + + if len(parts) >= 2 && parts[1] == "badConn" { + conn.bad = true + } + if d.waitCh != nil { + d.waitingCh <- struct{}{} + <-d.waitCh + d.waitCh = nil + d.waitingCh = nil + } + return conn, nil +} + +func (d *legacyDriver) getDB(name string) *legacyDB { + d.mu.Lock() + defer d.mu.Unlock() + if d.dbs == nil { + d.dbs = make(map[string]*legacyDB) + } + db, ok := d.dbs[name] + if !ok { + db = &legacyDB{name: name} + d.dbs[name] = db + } + return db +} + +func (db *legacyDB) wipe() { + db.mu.Lock() + defer db.mu.Unlock() + db.legacyTables = nil +} + +func (db *legacyDB) createTable(name string, columnNames, columnTypes []string) error { + db.mu.Lock() + defer db.mu.Unlock() + if db.legacyTables == nil { + db.legacyTables = make(map[string]*legacyTable) + } + if _, exist := db.legacyTables[name]; exist { + return fmt.Errorf("legacyTable %q already exists", name) + } + if len(columnNames) != len(columnTypes) { + return fmt.Errorf("create legacyTable of %q len(names) != len(types): %d vs %d", + name, len(columnNames), len(columnTypes)) + } + db.legacyTables[name] = &legacyTable{colname: columnNames, coltype: columnTypes} + return nil +} + +// must be called with db.mu lock held +func (db *legacyDB) legacyTable(legacyTable string) (*legacyTable, bool) { + if db.legacyTables == nil { + return nil, false + } + t, ok := db.legacyTables[legacyTable] + return t, ok +} + +func (db *legacyDB) columnType(legacyTable, column string) (typ string, ok bool) { + db.mu.Lock() + defer db.mu.Unlock() + t, ok := db.legacyTable(legacyTable) + if !ok { + return + } + for n, cname := range t.colname { + if cname == column { + return t.coltype[n], true + } + } + return "", false +} + +func (c *legacyConn) isBad() bool { + if c.stickyBad { + return true + } else if c.bad { + // alternate between bad conn and not bad conn + c.db.badConn = !c.db.badConn + return c.db.badConn + } else { + return false + } +} + +func (c *legacyConn) Begin() (driver.Tx, error) { + if c.isBad() { + return nil, driver.ErrBadConn + } + if c.currTx != nil { + return nil, errors.New("already in a transaction") + } + c.currTx = &legacyTx{c: c} + return c.currTx, nil +} + +var legacyHookPostCloseConn struct { + sync.Mutex + fn func(*legacyConn, error) +} + +func (c *legacyConn) Close() (err error) { + drv := ldriver.(*legacyDriver) + defer func() { + if err != nil && testStrictClose != nil { + testStrictClose.Errorf("failed to close a test legacyConn: %v", err) + } + legacyHookPostCloseConn.Lock() + fn := legacyHookPostCloseConn.fn + legacyHookPostCloseConn.Unlock() + if fn != nil { + fn(c, err) + } + if err == nil { + drv.mu.Lock() + drv.closeCount++ + drv.mu.Unlock() + } + }() + if c.currTx != nil { + return errors.New("can't close legacyConn; in a Transaction") + } + if c.db == nil { + return errors.New("can't close legacyConn; already closed") + } + if c.stmtsMade > c.stmtsClosed { + return errors.New("can't close; dangling statement(s)") + } + c.db = nil + return nil +} + +func legacyCheckSubsetTypes(args []driver.Value) error { + for n, arg := range args { + switch arg.(type) { + case int64, float64, bool, nil, []byte, string, time.Time: + default: + return fmt.Errorf("fakedb_test: invalid argument #%d: %v, type %T", n+1, arg, arg) + } + } + return nil +} + +func (c *legacyConn) Exec(query string, args []driver.Value) (driver.Result, error) { + // This is an optional interface, but it's implemented here + // just to check that all the args are of the proper types. + // ErrSkip is returned so the caller acts as if we didn't + // implement this at all. + err := legacyCheckSubsetTypes(args) + if err != nil { + return nil, err + } + return nil, driver.ErrSkip +} + +func (c *legacyConn) Query(query string, args []driver.Value) (driver.Rows, error) { + // This is an optional interface, but it's implemented here + // just to check that all the args are of the proper types. + // ErrSkip is returned so the caller acts as if we didn't + // implement this at all. + err := legacyCheckSubsetTypes(args) + if err != nil { + return nil, err + } + return nil, driver.ErrSkip +} + +// parts are legacyTable|selectCol1,selectCol2|whereCol=?,whereCol2=? +// (note that where columns must always contain ? marks, +// +// just a limitation for fakedb) +func (c *legacyConn) prepareSelect(stmt *legacyStmt, parts []string) (driver.Stmt, error) { + if len(parts) != 3 { + stmt.Close() + return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) + } + stmt.legacyTable = parts[0] + stmt.colName = strings.Split(parts[1], ",") + for n, colspec := range strings.Split(parts[2], ",") { + if colspec == "" { + continue + } + nameVal := strings.Split(colspec, "=") + if len(nameVal) != 2 { + stmt.Close() + return nil, errf("SELECT on legacyTable %q has invalid column spec of %q (index %d)", stmt.legacyTable, colspec, n) + } + column, value := nameVal[0], nameVal[1] + _, ok := c.db.columnType(stmt.legacyTable, column) + if !ok { + stmt.Close() + return nil, errf("SELECT on legacyTable %q references non-existent column %q", stmt.legacyTable, column) + } + if value != "?" { + stmt.Close() + return nil, errf("SELECT on legacyTable %q has pre-bound value for where column %q; need a question mark", + stmt.legacyTable, column) + } + stmt.whereCol = append(stmt.whereCol, column) + stmt.placeholders++ + } + return stmt, nil +} + +// parts are legacyTable|col=type,col2=type2 +func (c *legacyConn) prepareCreate(stmt *legacyStmt, parts []string) (driver.Stmt, error) { + if len(parts) != 2 { + stmt.Close() + return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) + } + stmt.legacyTable = parts[0] + for n, colspec := range strings.Split(parts[1], ",") { + nameType := strings.Split(colspec, "=") + if len(nameType) != 2 { + stmt.Close() + return nil, errf("CREATE legacyTable %q has invalid column spec of %q (index %d)", stmt.legacyTable, colspec, n) + } + stmt.colName = append(stmt.colName, nameType[0]) + stmt.colType = append(stmt.colType, nameType[1]) + } + return stmt, nil +} + +// parts are legacyTable|col=?,col2=val +func (c *legacyConn) prepareInsert(stmt *legacyStmt, parts []string) (driver.Stmt, error) { + if len(parts) != 2 { + stmt.Close() + return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) + } + stmt.legacyTable = parts[0] + for n, colspec := range strings.Split(parts[1], ",") { + nameVal := strings.Split(colspec, "=") + if len(nameVal) != 2 { + stmt.Close() + return nil, errf("INSERT legacyTable %q has invalid column spec of %q (index %d)", stmt.legacyTable, colspec, n) + } + column, value := nameVal[0], nameVal[1] + ctype, ok := c.db.columnType(stmt.legacyTable, column) + if !ok { + stmt.Close() + return nil, errf("INSERT legacyTable %q references non-existent column %q", stmt.legacyTable, column) + } + stmt.colName = append(stmt.colName, column) + + if value != "?" { + var subsetVal interface{} + // Convert to driver subset type + switch ctype { + case "string": + subsetVal = []byte(value) + case "blob": + subsetVal = []byte(value) + case "int32": + i, err := strconv.Atoi(value) + if err != nil { + stmt.Close() + return nil, errf("invalid conversion to int32 from %q", value) + } + subsetVal = int64(i) // int64 is a subset type, but not int32 + default: + stmt.Close() + return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) + } + stmt.colValue = append(stmt.colValue, subsetVal) + } else { + stmt.placeholders++ + stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) + stmt.colValue = append(stmt.colValue, "?") + } + } + return stmt, nil +} + +// hook to simulate broken connections +var legacyHookPrepareBadConn func() bool + +func (c *legacyConn) Prepare(query string) (driver.Stmt, error) { + c.numPrepare++ + if c.db == nil { + panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) + } + + if c.stickyBad || (legacyHookPrepareBadConn != nil && legacyHookPrepareBadConn()) { + return nil, driver.ErrBadConn + } + + parts := strings.Split(query, "|") + if len(parts) < 1 { + return nil, errf("empty query") + } + stmt := &legacyStmt{q: query, c: c} + if len(parts) >= 3 && parts[0] == "PANIC" { + stmt.panic = parts[1] + parts = parts[2:] + } + cmd := parts[0] + stmt.cmd = cmd + parts = parts[1:] + + c.incrStat(&c.stmtsMade) + switch cmd { + case "WIPE": + // Nothing + case "SELECT": + return c.prepareSelect(stmt, parts) + case "CREATE": + return c.prepareCreate(stmt, parts) + case "INSERT": + return c.prepareInsert(stmt, parts) + case "NOSERT": + // Do all the prep-work like for an INSERT but don't actually insert the legacyRow. + // Used for some of the concurrent tests. + return c.prepareInsert(stmt, parts) + default: + stmt.Close() + return nil, errf("unsupported command type %q", cmd) + } + return stmt, nil +} + +func (s *legacyStmt) ColumnConverter(idx int) driver.ValueConverter { + if s.panic == "ColumnConverter" { + panic(s.panic) + } + if len(s.placeholderConverter) == 0 { + return driver.DefaultParameterConverter + } + return s.placeholderConverter[idx] +} + +func (s *legacyStmt) Close() error { + if s.panic == "Close" { + panic(s.panic) + } + if s.c == nil { + panic("nil conn in legacyStmt.Close") + } + if s.c.db == nil { + panic("in legacyStmt.Close, conn's db is nil (already closed)") + } + if !s.closed { + s.c.incrStat(&s.c.stmtsClosed) + s.closed = true + } + return nil +} + +// hook to simulate broken connections +var legacyHookExecBadConn func() bool + +func (s *legacyStmt) Exec(args []driver.Value) (driver.Result, error) { + if s.panic == "Exec" { + panic(s.panic) + } + if s.closed { + return nil, errClosed + } + + if s.c.stickyBad || (legacyHookExecBadConn != nil && legacyHookExecBadConn()) { + return nil, driver.ErrBadConn + } + + err := legacyCheckSubsetTypes(args) + if err != nil { + return nil, err + } + + db := s.c.db + switch s.cmd { + case "WIPE": + db.wipe() + return driver.ResultNoRows, nil + case "CREATE": + if err := db.createTable(s.legacyTable, s.colName, s.colType); err != nil { + return nil, err + } + return driver.ResultNoRows, nil + case "INSERT": + return s.execInsert(args, true) + case "NOSERT": + // Do all the prep-work like for an INSERT but don't actually insert the legacyRow. + // Used for some of the concurrent tests. + return s.execInsert(args, false) + } + fmt.Printf("EXEC statement, cmd=%q: %#v\n", s.cmd, s) + return nil, fmt.Errorf("unimplemented statement Exec command type of %q", s.cmd) +} + +// When doInsert is true, add the legacyRow to the legacyTable. +// When doInsert is false do prep-work and error checking, but don't +// actually add the legacyRow to the legacyTable. +func (s *legacyStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result, error) { + db := s.c.db + if len(args) != s.placeholders { + panic("error in pkg db; should only get here if size is correct") + } + db.mu.Lock() + t, ok := db.legacyTable(s.legacyTable) + db.mu.Unlock() + if !ok { + return nil, fmt.Errorf("fakedb: legacyTable %q doesn't exist", s.legacyTable) + } + + t.mu.Lock() + defer t.mu.Unlock() + + var cols []interface{} + if doInsert { + cols = make([]interface{}, len(t.colname)) + } + argPos := 0 + for n, colname := range s.colName { + colidx := t.columnIndex(colname) + if colidx == -1 { + return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) + } + var val interface{} + if strvalue, ok := s.colValue[n].(string); ok && strvalue == "?" { + val = args[argPos] + argPos++ + } else { + val = s.colValue[n] + } + if doInsert { + cols[colidx] = val + } + } + + if doInsert { + t.legacyRows = append(t.legacyRows, &legacyRow{cols: cols}) + } + return driver.RowsAffected(1), nil +} + +// hook to simulate broken connections +var legacyHookQueryBadConn func() bool + +func (s *legacyStmt) Query(args []driver.Value) (driver.Rows, error) { + if s.panic == "Query" { + panic(s.panic) + } + if s.closed { + return nil, errClosed + } + + if s.c.stickyBad || (legacyHookQueryBadConn != nil && legacyHookQueryBadConn()) { + return nil, driver.ErrBadConn + } + + err := legacyCheckSubsetTypes(args) + if err != nil { + return nil, err + } + + db := s.c.db + if len(args) != s.placeholders { + panic("error in pkg db; should only get here if size is correct") + } + + db.mu.Lock() + t, ok := db.legacyTable(s.legacyTable) + db.mu.Unlock() + if !ok { + return nil, fmt.Errorf("fakedb: legacyTable %q doesn't exist", s.legacyTable) + } + + if s.legacyTable == "magicquery" { + if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" { + if args[0] == "sleep" { + time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond) + } + } + } + + t.mu.Lock() + defer t.mu.Unlock() + + colIdx := make(map[string]int) // select column name -> column index in legacyTable + for _, name := range s.colName { + idx := t.columnIndex(name) + if idx == -1 { + return nil, fmt.Errorf("fakedb: unknown column name %q", name) + } + colIdx[name] = idx + } + + mlegacyRows := []*legacyRow{} +legacyRows: + for _, tlegacyRow := range t.legacyRows { + // Process the where clause, skipping non-match legacyRows. This is lazy + // and just uses fmt.Sprintf("%v") to test equality. Good enough + // for test code. + for widx, wcol := range s.whereCol { + idx := t.columnIndex(wcol) + if idx == -1 { + return nil, fmt.Errorf("db: invalid where clause column %q", wcol) + } + tcol := tlegacyRow.cols[idx] + if bs, ok := tcol.([]byte); ok { + // lazy hack to avoid sprintf %v on a []byte + tcol = string(bs) + } + if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) { + continue legacyRows + } + } + mlegacyRow := &legacyRow{cols: make([]interface{}, len(s.colName))} + for seli, name := range s.colName { + mlegacyRow.cols[seli] = tlegacyRow.cols[colIdx[name]] + } + mlegacyRows = append(mlegacyRows, mlegacyRow) + } + + cursor := &legacyRowsCursor{ + pos: -1, + legacyRows: mlegacyRows, + cols: s.colName, + errPos: -1, + } + return cursor, nil +} + +func (s *legacyStmt) NumInput() int { + if s.panic == "NumInput" { + panic(s.panic) + } + return s.placeholders +} + +// hook to simulate broken connections +var legacyHookCommitBadConn func() bool + +func (tx *legacyTx) Commit() error { + tx.c.currTx = nil + if legacyHookCommitBadConn != nil && legacyHookCommitBadConn() { + return driver.ErrBadConn + } + return nil +} + +// hook to simulate broken connections +var legacyHookRollbackBadConn func() bool + +func (tx *legacyTx) Rollback() error { + tx.c.currTx = nil + if legacyHookRollbackBadConn != nil && legacyHookRollbackBadConn() { + return driver.ErrBadConn + } + return nil +} + +type legacyRowsCursor struct { + cols []string + pos int + legacyRows []*legacyRow + closed bool + + // errPos and err are for making Next return early with error. + errPos int + err error + + // a clone of slices to give out to clients, indexed by the + // the original slice's first byte address. we clone them + // just so we're able to corrupt them on close. + bytesClone map[*byte][]byte +} + +func (rc *legacyRowsCursor) Close() error { + if !rc.closed { + for _, bs := range rc.bytesClone { + bs[0] = 255 // first byte corrupted + } + } + rc.closed = true + return nil +} + +func (rc *legacyRowsCursor) Columns() []string { + return rc.cols +} + +var legacyRowsCursorNextHook func(dest []driver.Value) error + +func (rc *legacyRowsCursor) Next(dest []driver.Value) error { + if legacyRowsCursorNextHook != nil { + return legacyRowsCursorNextHook(dest) + } + + if rc.closed { + return errors.New("fakedb: cursor is closed") + } + rc.pos++ + if rc.pos == rc.errPos { + return rc.err + } + if rc.pos >= len(rc.legacyRows) { + return io.EOF // per interface spec + } + for i, v := range rc.legacyRows[rc.pos].cols { + // TODO(bradfitz): convert to subset types? naah, I + // think the subset types should only be input to + // driver, but the sql package should be able to handle + // a wider range of types coming out of drivers. all + // for ease of drivers, and to prevent drivers from + // messing up conversions or doing them differently. + dest[i] = v + + if bs, ok := v.([]byte); ok { + if rc.bytesClone == nil { + rc.bytesClone = make(map[*byte][]byte) + } + clone, ok := rc.bytesClone[&bs[0]] + if !ok { + clone = make([]byte, len(bs)) + copy(clone, bs) + rc.bytesClone[&bs[0]] = clone + } + dest[i] = clone + } + } + return nil +} + +// legacyDriverString is like driver.String, but indirects pointers like +// DefaultValueConverter. +// +// This could be surprising behavior to retroactively apply to +// driver.String now that Go1 is out, but this is convenient for +// our TestPointerParamsAndScans. +type legacyDriverString struct{} + +func (legacyDriverString) ConvertValue(v interface{}) (driver.Value, error) { + switch c := v.(type) { + case string, []byte: + return v, nil + case *string: + if c == nil { + return nil, nil + } + return *c, nil + } + return fmt.Sprintf("%v", v), nil +} diff --git a/sentrysql/operation.go b/sentrysql/operation.go new file mode 100644 index 00000000..4b9591de --- /dev/null +++ b/sentrysql/operation.go @@ -0,0 +1,25 @@ +package sentrysql + +import "strings" + +var knownDatabaseOperations = map[string]struct{}{ + "SELECT": {}, + "INSERT": {}, + "DELETE": {}, + "UPDATE": {}, +} + +func parseDatabaseOperation(query string) string { + // The operation is the first word of the query. + operation := query + if i := strings.Index(query, " "); i >= 0 { + operation = strings.ToUpper(query[:i]) + } + + // Only returns known words. + if _, ok := knownDatabaseOperations[operation]; !ok { + return "" + } + + return operation +} diff --git a/sentrysql/operation_test.go b/sentrysql/operation_test.go new file mode 100644 index 00000000..7aae625f --- /dev/null +++ b/sentrysql/operation_test.go @@ -0,0 +1,50 @@ +package sentrysql + +import "testing" + +func TestParseDatabaseOperation(t *testing.T) { + tests := []struct { + name string + query string + want string + }{ + { + name: "SELECT", + query: "SELECT * FROM users", + want: "SELECT", + }, + { + name: "INSERT", + query: "INSERT INTO users (id, name) VALUES (1, 'John')", + want: "INSERT", + }, + { + name: "DELETE", + query: "DELETE FROM users WHERE id = 1", + want: "DELETE", + }, + { + name: "UPDATE", + query: "UPDATE users SET name = 'John' WHERE id = 1", + want: "UPDATE", + }, + { + name: "findById", + query: "findById", + want: "", + }, + { + name: "Empty", + query: "", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseDatabaseOperation(tt.query); got != tt.want { + t.Errorf("parseDatabaseOperation() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sentrysql/options.go b/sentrysql/options.go new file mode 100644 index 00000000..7772f109 --- /dev/null +++ b/sentrysql/options.go @@ -0,0 +1,25 @@ +package sentrysql + +type Option func(*sentrySQLConfig) + +// WithDatabaseSystem specifies the current database system. +func WithDatabaseSystem(system DatabaseSystem) Option { + return func(config *sentrySQLConfig) { + config.databaseSystem = system + } +} + +// WithDatabaseName specifies the name of the current database. +func WithDatabaseName(name string) Option { + return func(config *sentrySQLConfig) { + config.databaseName = name + } +} + +// WithServerAddress specifies the address and port of the current database server. +func WithServerAddress(address string, port string) Option { + return func(config *sentrySQLConfig) { + config.serverAddress = address + config.serverPort = port + } +} diff --git a/sentrysql/sentrysql.go b/sentrysql/sentrysql.go new file mode 100644 index 00000000..24363451 --- /dev/null +++ b/sentrysql/sentrysql.go @@ -0,0 +1,81 @@ +package sentrysql + +import ( + "database/sql/driver" + + "github.com/getsentry/sentry-go" +) + +// DatabaseSystem points to the list of accepted OpenTelemetry database system. +// The ones defined here are not exhaustive, but are the ones that are supported by Sentry. +// Although you can override the value by creating your own, it will still be sent to Sentry, +// but it most likely will not appear on the Queries Insights page. +type DatabaseSystem string + +const ( + // PostgreSQL specifies the PostgreSQL database system. + PostgreSQL DatabaseSystem = "postgresql" + // MySQL specifies the MySQL database system. + MySQL DatabaseSystem = "mysql" + // SQLite specifies the SQLite database system. + SQLite DatabaseSystem = "sqlite" + // Oracle specifies the Oracle database system. + Oracle DatabaseSystem = "oracle" + // MSSQL specifies the Microsoft SQL Server database system. + MSSQL DatabaseSystem = "mssql" +) + +type sentrySQLConfig struct { + databaseSystem DatabaseSystem + databaseName string + serverAddress string + serverPort string +} + +func (s *sentrySQLConfig) SetData(span *sentry.Span, query string) { + if span == nil { + return + } + + if s.databaseSystem != "" { + span.SetData("db.system", s.databaseSystem) + } + if s.databaseName != "" { + span.SetData("db.name", s.databaseName) + } + if s.serverAddress != "" { + span.SetData("server.address", s.serverAddress) + } + if s.serverPort != "" { + span.SetData("server.port", s.serverPort) + } + + if query != "" { + databaseOperation := parseDatabaseOperation(query) + if databaseOperation != "" { + span.SetData("db.operation", databaseOperation) + } + } +} + +// NewSentrySQL is a wrapper for driver.Driver that provides tracing for SQL queries. +// The span will only be created if the parent span is available. +func NewSentrySQL(driver driver.Driver, options ...Option) driver.Driver { + var config sentrySQLConfig + for _, option := range options { + option(&config) + } + + return &sentrySQLDriver{originalDriver: driver, config: &config} +} + +// NewSentrySQLConnector is a wrapper for driver.Connector that provides tracing for SQL queries. +// The span will only be created if the parent span is available. +func NewSentrySQLConnector(connector driver.Connector, options ...Option) driver.Connector { + var config sentrySQLConfig + for _, option := range options { + option(&config) + } + + return &sentrySQLConnector{originalConnector: connector, config: &config} +} diff --git a/sentrysql/sentrysql_connector_test.go b/sentrysql/sentrysql_connector_test.go new file mode 100644 index 00000000..9fbe0a19 --- /dev/null +++ b/sentrysql/sentrysql_connector_test.go @@ -0,0 +1,1289 @@ +package sentrysql_test + +import ( + "context" + "database/sql" + "strings" + "testing" + "time" + + "github.com/getsentry/sentry-go" + "github.com/getsentry/sentry-go/internal/testutils" + "github.com/getsentry/sentry-go/sentrysql" + "github.com/google/go-cmp/cmp" +) + +//nolint:dupl +func TestNewSentrySQLConnector_Integration(t *testing.T) { + db := sql.OpenDB(sentrysql.NewSentrySQLConnector(&fakeConnector{}, sentrysql.WithDatabaseSystem(sentrysql.DatabaseSystem("fakedb")), sentrysql.WithDatabaseName("fake"))) + t.Cleanup(func() { + _, _ = db.Exec("WIPE") + _ = db.Close() + }) + setupQueries := []string{ + "CREATE|exec_test|id=int32,name=string", + "CREATE|query_test|id=int32,name=string,age=int32,created_at=string", + "INSERT|query_test|id=1,name=John,age=30,created_at=2023-01-01", + "INSERT|query_test|id=2,name=Jane,age=25,created_at=2023-01-02", + "INSERT|query_test|id=3,name=Bob,age=35,created_at=2023-01-03", + } + + setupCtx, cancelCtx := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelCtx() + + for _, query := range setupQueries { + _, err := db.ExecContext(setupCtx, query) + if err != nil { + t.Fatalf("initializing table on fakedb: %v", err) + } + } + + t.Run("QueryContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "SELECT|query_test|id|id=?", + Parameters: []interface{}{1}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "SELECT|query_test|id|id=?", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "SELECT FROM query_test", + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + "db.operation": "SELECT", + }, + Description: "SELECT FROM query_test", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + rows, err := db.QueryContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + if rows != nil { + _ = rows.Close() + } + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) + + t.Run("ExecContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "INSERT|exec_test|id=?,name=?", + Parameters: []interface{}{1, "John"}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "INSERT|exec_test|id=?,name=?", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "CREATE|temporary_test|id=int32,name=string", + WantError: false, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "CREATE|temporary_test|id=int32,name=string", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + _, err := db.ExecContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) + + t.Run("Ping", func(t *testing.T) { + // Just checking if this works and doesn't panic + err := db.Ping() + if err != nil { + t.Fatal(err) + } + }) + + t.Run("PingContext", func(t *testing.T) { + // Just checking if this works and doesn't panic + err := db.PingContext(context.Background()) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("Driver", func(t *testing.T) { + // Just checking if this works and doesn't panic + driver := db.Driver() + if driver == nil { + t.Fatal("driver is nil") + } + }) +} + +//nolint:dupl +func TestNewSentrySQLConnector_Conn(t *testing.T) { + db := sql.OpenDB(sentrysql.NewSentrySQLConnector(&fakeConnector{}, sentrysql.WithDatabaseSystem(sentrysql.DatabaseSystem("fakedb")), sentrysql.WithDatabaseName("fake"))) + t.Cleanup(func() { + _, _ = db.Exec("WIPE") + _ = db.Close() + }) + + setupQueries := []string{ + "CREATE|exec_test|id=int32,name=string", + "CREATE|query_test|id=int32,name=string,age=int32,created_at=string", + "INSERT|query_test|id=1,name=John,age=30,created_at=2023-01-01", + "INSERT|query_test|id=2,name=Jane,age=25,created_at=2023-01-02", + "INSERT|query_test|id=3,name=Bob,age=35,created_at=2023-01-03", + } + setupCtx, cancelCtx := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelCtx() + + for _, query := range setupQueries { + _, err := db.ExecContext(setupCtx, query) + if err != nil { + t.Fatalf("initializing table on fakedb: %v", err) + } + } + + t.Run("QueryContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "SELECT|query_test|id|id=?", + Parameters: []interface{}{1}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "SELECT|query_test|id|id=?", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "SELECT FROM query_test", + Parameters: []interface{}{1}, + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + "db.operation": "SELECT", + }, + Description: "SELECT FROM query_test", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + conn, err := db.Conn(ctx) + if err != nil { + cancel() + t.Fatal(err) + } + + rows, err := conn.QueryContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + if rows != nil { + _ = rows.Close() + } + + _ = conn.Close() + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) + + t.Run("ExecContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "INSERT|exec_test|id=?,name=?", + Parameters: []interface{}{2, "Peter"}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "INSERT|exec_test|id=?,name=?", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "INSERT|exec_test|id=?,name=?", + Parameters: []interface{}{4, "John", "Doe", "John Doe"}, + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "INSERT|exec_test|id=?,name=?", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + conn, err := db.Conn(ctx) + if err != nil { + cancel() + t.Fatal(err) + } + + _, err = conn.ExecContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + _ = conn.Close() + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) +} + +//nolint:dupl,gocyclo +func TestNewSentrySQLConnector_BeginTx(t *testing.T) { + t.Skip("fakedb does not implement transactions") + + db := sql.OpenDB(sentrysql.NewSentrySQLConnector(&fakeConnector{}, sentrysql.WithDatabaseSystem(sentrysql.DatabaseSystem("fakedb")), sentrysql.WithDatabaseName("fake"))) + t.Cleanup(func() { + _, _ = db.Exec("WIPE") + _ = db.Close() + }) + + setupQueries := []string{ + "CREATE|exec_test|id=int32,name=string", + "CREATE|query_test|id=int32,name=string,age=int32,created_at=string", + "INSERT|query_test|id=1,name=John,age=30,created_at=2023-01-01", + "INSERT|query_test|id=2,name=Jane,age=25,created_at=2023-01-02", + "INSERT|query_test|id=3,name=Bob,age=35,created_at=2023-01-03", + } + + setupCtx, cancelCtx := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelCtx() + + for _, query := range setupQueries { + _, err := db.ExecContext(setupCtx, query) + if err != nil { + t.Fatalf("initializing table on fakedb: %v", err) + } + } + + t.Run("Singles", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "INSERT|exec_test|id=?,name=?", + Parameters: []interface{}{2, "Peter"}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "INSERT|exec_test|id=?,name=?", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "INSERT|exec_test|id=?,name=?", + Parameters: []interface{}{4, "John", "Doe", "John Doe"}, + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "INSERT|exec_test|id=?,name=?", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + conn, err := db.Conn(ctx) + if err != nil { + cancel() + t.Fatal(err) + } + + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + cancel() + t.Fatal(err) + } + + _, err = tx.ExecContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + err = tx.Commit() + if err != nil && !tt.WantError { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + _ = tx.Rollback() + + _ = conn.Close() + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) + + t.Run("Multiple Queries", func(t *testing.T) { + spansCh := make(chan []*sentry.Span, 2) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + defer cancel() + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + defer func() { + _ = tx.Rollback() + }() + + var name string + err = tx.QueryRowContext(ctx, "SELECT|query_test|name|id=?", 1).Scan(&name) + if err != nil { + _ = tx.Rollback() + _ = conn.Close() + cancel() + t.Fatal(err) + } + + _, err = tx.ExecContext(ctx, "INSERT|exec_test|id=?,name=?", 5, "Catherine") + if err != nil { + _ = tx.Rollback() + _ = conn.Close() + cancel() + t.Fatal(err) + } + + err = tx.Commit() + if err != nil { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + _ = conn.Close() + + span.Finish() + + cancel() + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got []*sentry.Span + for e := range spansCh { + got = append(got, e...) + } + + want := []*sentry.Span{ + { + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "SELECT|query_test|name|id=?", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + { + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "INSERT|exec_test|id=?,name=?", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + } + + if diff := cmp.Diff(want, got, optstrans); diff != "" { + t.Errorf("Span mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("Rollback", func(t *testing.T) { + spansCh := make(chan []*sentry.Span, 2) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + defer cancel() + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + defer func() { + _ = tx.Rollback() + }() + + var name string + err = tx.QueryRowContext(ctx, "SELECT|query_test|name|id=?", 1).Scan(&name) + if err != nil { + _ = tx.Rollback() + _ = conn.Close() + cancel() + t.Fatal(err) + } + + _, err = tx.ExecContext(ctx, "INSERT|exec_test|id=?,name=?", 5, "Catherine") + if err != nil { + _ = tx.Rollback() + _ = conn.Close() + cancel() + t.Fatal(err) + } + + err = tx.Rollback() + if err != nil { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + _ = conn.Close() + + span.Finish() + + cancel() + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got []*sentry.Span + for e := range spansCh { + got = append(got, e...) + } + + want := []*sentry.Span{ + { + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "SELECT|query_test|name|id=?", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + { + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "INSERT|exec_test|id=?,name=?", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + } + + if diff := cmp.Diff(want, got, optstrans); diff != "" { + t.Errorf("Span mismatch (-want +got):\n%s", diff) + } + }) +} + +//nolint:dupl +func TestNewSentrySQLConnector_PrepareContext(t *testing.T) { + db := sql.OpenDB(sentrysql.NewSentrySQLConnector(&fakeConnector{}, sentrysql.WithDatabaseSystem(sentrysql.DatabaseSystem("fakedb")), sentrysql.WithDatabaseName("fake"))) + t.Cleanup(func() { + _, _ = db.Exec("WIPE") + _ = db.Close() + }) + + setupQueries := []string{ + "CREATE|exec_test|id=int32,name=string", + "CREATE|query_test|id=int32,name=string,age=int32,created_at=string", + "INSERT|query_test|id=1,name=John,age=30,created_at=2023-01-01", + "INSERT|query_test|id=2,name=Jane,age=25,created_at=2023-01-02", + "INSERT|query_test|id=3,name=Bob,age=35,created_at=2023-01-03", + } + setupCtx, cancelCtx := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelCtx() + + for _, query := range setupQueries { + _, err := db.ExecContext(setupCtx, query) + if err != nil { + t.Fatalf("initializing table on fakedb: %v", err) + } + } + + t.Run("Exec", func(t *testing.T) { + t.Skip("fakedb does not implement Exec") + + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "INSERT|exec_test|id=?,name=?", + Parameters: []interface{}{3, "Sarah"}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "INSERT|exec_test|id=?,name=?", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "INSERT exec_test (id, name) VALUES (?, ?, ?, ?)", + Parameters: []interface{}{4, "John", "Doe", "John Doe"}, + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + "db.operation": "INSERT", + }, + Description: "INSERT INTO exec_test (id, name) VALUES (?, ?, ?, ?)", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + stmt, err := db.PrepareContext(ctx, tt.Query) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + _, err = stmt.Exec(tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) + + t.Run("Query", func(t *testing.T) { + t.Skip("fakedb does not implement Query") + + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "SELECT|query_test|id,name,age|id=?", + Parameters: []interface{}{2}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + }, + Description: "SELECT|query_test|id,name,age|id=?", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "SELECT * FROM query_test WHERE id =", + Parameters: []interface{}{1}, + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("fakedb"), + "db.name": "fake", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "SELECT", + }, + Description: "SELECT * FROM query_test WHERE id =", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + stmt, err := db.PrepareContext(ctx, tt.Query) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + rows, err := stmt.Query(tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + if rows != nil { + _ = rows.Close() + } + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) +} + +//nolint:dupl +func TestNewSentrySQLConnector_NoParentSpan(t *testing.T) { + db := sql.OpenDB(sentrysql.NewSentrySQLConnector(&fakeConnector{}, sentrysql.WithDatabaseSystem(sentrysql.DatabaseSystem("fakedb")), sentrysql.WithDatabaseName("fake"))) + t.Cleanup(func() { + _, _ = db.Exec("WIPE") + _ = db.Close() + }) + + setupQueries := []string{ + "CREATE|exec_test|id=int32,name=string", + "CREATE|query_test|id=int32,name=string,age=int32,created_at=string", + "INSERT|query_test|id=1,name=John,age=30,created_at=2023-01-01", + "INSERT|query_test|id=2,name=Jane,age=25,created_at=2023-01-02", + "INSERT|query_test|id=3,name=Bob,age=35,created_at=2023-01-03", + } + setupCtx, cancelCtx := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelCtx() + + for _, query := range setupQueries { + _, err := db.ExecContext(setupCtx, query) + if err != nil { + t.Fatalf("initializing table on fakedb: %v", err) + } + } + + t.Run("QueryContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "SELECT|query_test|id,name,age|id=?", + Parameters: []interface{}{1}, + WantSpan: nil, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + + rows, err := db.QueryContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + if rows != nil { + _ = rows.Close() + } + + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + // `got` should be empty + if len(got) != 0 { + t.Errorf("got %d spans, want 0", len(got)) + } + }) + + t.Run("ExecContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "INSERT|exec_test|id=?,name=?", + Parameters: []interface{}{1, "John"}, + WantSpan: nil, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + + _, err := db.ExecContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + // `got` should be empty + if len(got) != 0 { + t.Errorf("got %d spans, want 0", len(got)) + } + }) +} diff --git a/sentrysql/sentrysql_legacy_test.go b/sentrysql/sentrysql_legacy_test.go new file mode 100644 index 00000000..e0cb1ae9 --- /dev/null +++ b/sentrysql/sentrysql_legacy_test.go @@ -0,0 +1,697 @@ +package sentrysql_test + +import ( + "context" + "database/sql" + "strings" + "testing" + "time" + + "github.com/getsentry/sentry-go" + "github.com/getsentry/sentry-go/internal/testutils" + "github.com/getsentry/sentry-go/sentrysql" + "github.com/google/go-cmp/cmp" +) + +//nolint:dupl,gocyclo +func TestNewSentrySQLLegacy_Integration(t *testing.T) { + db, err := sql.Open("sentrysql-legacy", "fake") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + _, _ = db.Exec("WIPE") + _ = db.Close() + }) + setupQueries := []string{ + "CREATE|exec_test|id=int32,name=string", + "CREATE|query_test|id=int32,name=string,age=int32,created_at=string", + "INSERT|query_test|id=1,name=John,age=30,created_at=2023-01-01", + "INSERT|query_test|id=2,name=Jane,age=25,created_at=2023-01-02", + "INSERT|query_test|id=3,name=Bob,age=35,created_at=2023-01-03", + } + + setupCtx, cancelCtx := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelCtx() + + for _, query := range setupQueries { + _, err := db.ExecContext(setupCtx, query) + if err != nil { + t.Fatalf("initializing table on legacydb: %v", err) + } + } + + t.Run("QueryContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "SELECT|query_test|id|id=?", + Parameters: []interface{}{1}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("legacydb"), + "db.name": "fake", + }, + Description: "SELECT|query_test|id|id=?", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "SELECT FROM query_test", + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("legacydb"), + "db.name": "fake", + "db.operation": "SELECT", + }, + Description: "SELECT FROM query_test", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + rows, err := db.QueryContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + if rows != nil { + _ = rows.Close() + } + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if len(diffs) == 0 && !foundMatch { + t.Logf("No span was found for query: %s", tt.Query) + return + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) + + t.Run("ExecContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "INSERT|exec_test|id=1,name=John", + Parameters: nil, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("legacydb"), + "db.name": "fake", + }, + Description: "INSERT|exec_test|id=1,name=John", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "CREATE|temporary_test|id=int32,name=string", + WantError: false, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("legacydb"), + "db.name": "fake", + }, + Description: "CREATE|temporary_test|id=int32,name=string", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + _, err := db.ExecContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if len(diffs) == 0 && !foundMatch { + t.Logf("No span was found for query: %s", tt.Query) + return + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) + + t.Run("Ping", func(t *testing.T) { + // Just checking if this works and doesn't panic + err := db.Ping() + if err != nil { + t.Fatal(err) + } + }) + + t.Run("PingContext", func(t *testing.T) { + // Just checking if this works and doesn't panic + err := db.PingContext(context.Background()) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("Driver", func(t *testing.T) { + // Just checking if this works and doesn't panic + driver := db.Driver() + if driver == nil { + t.Fatal("driver is nil") + } + }) +} + +//nolint:dupl,gocyclo +func TestNewSentrySQLLegacy_Conn(t *testing.T) { + db, err := sql.Open("sentrysql-legacy", "fake") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + _, _ = db.Exec("WIPE") + _ = db.Close() + }) + + setupQueries := []string{ + "CREATE|exec_test|id=int32,name=string", + "CREATE|query_test|id=int32,name=string,age=int32,created_at=string", + "INSERT|query_test|id=1,name=John,age=30,created_at=2023-01-01", + "INSERT|query_test|id=2,name=Jane,age=25,created_at=2023-01-02", + "INSERT|query_test|id=3,name=Bob,age=35,created_at=2023-01-03", + } + setupCtx, cancelCtx := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelCtx() + + for _, query := range setupQueries { + _, err := db.ExecContext(setupCtx, query) + if err != nil { + t.Fatalf("initializing table on legacydb: %v", err) + } + } + + t.Run("QueryContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "SELECT|query_test|id|id=?", + Parameters: []interface{}{1}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("legacydb"), + "db.name": "fake", + }, + Description: "SELECT|query_test|id|id=?", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "SELECT FROM query_test", + Parameters: []interface{}{1}, + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("legacydb"), + "db.name": "fake", + "db.operation": "SELECT", + }, + Description: "SELECT FROM query_test", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + conn, err := db.Conn(ctx) + if err != nil { + cancel() + t.Fatal(err) + } + + rows, err := conn.QueryContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + if rows != nil { + _ = rows.Close() + } + + _ = conn.Close() + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if len(diffs) == 0 && !foundMatch { + t.Logf("No span was found for query: %s", tt.Query) + return + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) + + t.Run("ExecContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "INSERT|exec_test|id=?,name=?", + Parameters: []interface{}{2, "Peter"}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("legacydb"), + "db.name": "fake", + }, + Description: "INSERT|exec_test|id=?,name=?", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "INSERT|exec_test|id=?,name=?", + Parameters: []interface{}{4, "John", "Doe", "John Doe"}, + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("legacydb"), + "db.name": "fake", + }, + Description: "INSERT|exec_test|id=?,name=?", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + conn, err := db.Conn(ctx) + if err != nil { + cancel() + t.Fatal(err) + } + + _, err = conn.ExecContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + _ = conn.Close() + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if len(diffs) == 0 && !foundMatch { + t.Logf("No span was found for query: %s", tt.Query) + return + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) +} + +//nolint:dupl +func TestNewSentrySQLLegacy_NoParentSpan(t *testing.T) { + db, err := sql.Open("sentrysql-legacy", "fake") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + _, _ = db.Exec("WIPE") + _ = db.Close() + }) + + setupQueries := []string{ + "CREATE|exec_test|id=int32,name=string", + "CREATE|query_test|id=int32,name=string,age=int32,created_at=string", + "INSERT|query_test|id=1,name=John,age=30,created_at=2023-01-01", + "INSERT|query_test|id=2,name=Jane,age=25,created_at=2023-01-02", + "INSERT|query_test|id=3,name=Bob,age=35,created_at=2023-01-03", + } + setupCtx, cancelCtx := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelCtx() + + for _, query := range setupQueries { + _, err := db.ExecContext(setupCtx, query) + if err != nil { + t.Fatalf("initializing table on legacydb: %v", err) + } + } + + t.Run("QueryContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "SELECT|query_test|id,name,age|id=?", + Parameters: []interface{}{1}, + WantSpan: nil, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + + rows, err := db.QueryContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + if rows != nil { + _ = rows.Close() + } + + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + // `got` should be empty + if len(got) != 0 { + t.Errorf("got %d spans, want 0", len(got)) + } + }) + + t.Run("ExecContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "INSERT|exec_test|id=?,name=?", + Parameters: []interface{}{1, "John"}, + WantSpan: nil, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + + _, err := db.ExecContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + // `got` should be empty + if len(got) != 0 { + t.Errorf("got %d spans, want 0", len(got)) + } + }) +} diff --git a/sentrysql/sentrysql_test.go b/sentrysql/sentrysql_test.go new file mode 100644 index 00000000..bef2f2d9 --- /dev/null +++ b/sentrysql/sentrysql_test.go @@ -0,0 +1,1418 @@ +package sentrysql_test + +import ( + "context" + "database/sql" + "os" + "strings" + "testing" + "time" + + "github.com/getsentry/sentry-go" + "github.com/getsentry/sentry-go/internal/testutils" + "github.com/getsentry/sentry-go/sentrysql" + sqlite "github.com/glebarez/go-sqlite" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" +) + +var optstrans = cmp.Options{ + cmpopts.IgnoreFields( + sentry.Span{}, + "TraceID", "SpanID", "ParentSpanID", "StartTime", "EndTime", + "mu", "parent", "sampleRate", "ctx", "dynamicSamplingContext", "recorder", "finishOnce", "collectProfile", "contexts", + ), +} + +func TestMain(m *testing.M) { + sql.Register("sentrysql-sqlite", sentrysql.NewSentrySQL(&sqlite.Driver{}, sentrysql.WithDatabaseName("memory"), sentrysql.WithDatabaseSystem(sentrysql.DatabaseSystem("sqlite")), sentrysql.WithServerAddress("localhost", "5432"))) + // sentrysql-legacy is used by `sentrysql_legacy_test.go` + sql.Register("sentrysql-legacy", sentrysql.NewSentrySQL(ldriver, sentrysql.WithDatabaseSystem(sentrysql.DatabaseSystem("legacydb")), sentrysql.WithDatabaseName("fake"))) + + os.Exit(m.Run()) +} + +//nolint:dupl +func TestNewSentrySQL_Integration(t *testing.T) { + db, err := sql.Open("sentrysql-sqlite", ":memory:") + if err != nil { + t.Fatalf("opening sqlite: %v", err) + } + db.SetMaxOpenConns(1) + defer db.Close() + + setupQueries := []string{ + "CREATE TABLE exec_test (id INT, name TEXT)", + "CREATE TABLE query_test (id INT, name TEXT, age INT, created_at TEXT)", + "INSERT INTO query_test (id, name, age, created_at) VALUES (1, 'John', 30, '2023-01-01')", + "INSERT INTO query_test (id, name, age, created_at) VALUES (2, 'Jane', 25, '2023-01-02')", + "INSERT INTO query_test (id, name, age, created_at) VALUES (3, 'Bob', 35, '2023-01-03')", + } + + setupCtx, cancelCtx := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelCtx() + + for _, query := range setupQueries { + _, err = db.ExecContext(setupCtx, query) + if err != nil { + t.Fatalf("initializing table on sqlite: %v", err) + } + } + + t.Run("QueryContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "SELECT * FROM query_test WHERE id = ?", + Parameters: []interface{}{1}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "SELECT", + }, + Description: "SELECT * FROM query_test WHERE id = ?", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "SELECT FROM query_test", + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "SELECT", + }, + Description: "SELECT FROM query_test", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + rows, err := db.QueryContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + if rows != nil { + _ = rows.Close() + } + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) + + t.Run("ExecContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "INSERT INTO exec_test (id, name) VALUES (?, ?)", + Parameters: []interface{}{1, "John"}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "INSERT", + }, + Description: "INSERT INTO exec_test (id, name) VALUES (?, ?)", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "UPDATE exec_test SET name = ? WHERE id = ?", + Parameters: []interface{}{"Bob", 1}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "UPDATE", + }, + Description: "UPDATE exec_test SET name = ? WHERE id = ?", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "DELETE FROM exec_test WHERE name = ?", + Parameters: []interface{}{"Nolan"}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "DELETE", + }, + Description: "DELETE FROM exec_test WHERE name = ?", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "INSERT INTO exec_test (id, name) VALUES (?, ?, ?, ?)", + Parameters: []interface{}{ + 1, "John", "Doe", 1, + }, + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "INSERT", + }, + Description: "INSERT INTO exec_test (id, name) VALUES (?, ?, ?, ?)", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + { + Query: "CREATE TABLE temporary_test (id INT, name TEXT)", + WantError: false, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + }, + Description: "CREATE TABLE temporary_test (id INT, name TEXT)", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + _, err := db.ExecContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) + + t.Run("Ping", func(t *testing.T) { + // Just checking if this works and doesn't panic + err := db.Ping() + if err != nil { + t.Fatal(err) + } + }) + + t.Run("PingContext", func(t *testing.T) { + // Just checking if this works and doesn't panic + err := db.PingContext(context.Background()) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("Driver", func(t *testing.T) { + // Just checking if this works and doesn't panic + driver := db.Driver() + if driver == nil { + t.Fatal("driver is nil") + } + }) +} + +//nolint:dupl +func TestNewSentrySQL_Conn(t *testing.T) { + db, err := sql.Open("sentrysql-sqlite", ":memory:") + if err != nil { + t.Fatalf("opening sqlite: %v", err) + } + db.SetMaxOpenConns(1) + defer db.Close() + + setupQueries := []string{ + "CREATE TABLE exec_test (id INT, name TEXT)", + "CREATE TABLE query_test (id INT, name TEXT, age INT, created_at TEXT)", + "INSERT INTO query_test (id, name, age, created_at) VALUES (1, 'John', 30, '2023-01-01')", + "INSERT INTO query_test (id, name, age, created_at) VALUES (2, 'Jane', 25, '2023-01-02')", + "INSERT INTO query_test (id, name, age, created_at) VALUES (3, 'Bob', 35, '2023-01-03')", + } + + setupCtx, cancelCtx := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelCtx() + + for _, query := range setupQueries { + _, err = db.ExecContext(setupCtx, query) + if err != nil { + t.Fatalf("initializing table on sqlite: %v", err) + } + } + + t.Run("QueryContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "SELECT * FROM query_test WHERE id = ?", + Parameters: []interface{}{1}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "SELECT", + }, + Description: "SELECT * FROM query_test WHERE id = ?", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "SELECT FROM query_test", + Parameters: []interface{}{1}, + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "SELECT", + }, + Description: "SELECT FROM query_test", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + conn, err := db.Conn(ctx) + if err != nil { + cancel() + t.Fatal(err) + } + + rows, err := conn.QueryContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + if rows != nil { + _ = rows.Close() + } + + _ = conn.Close() + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) + + t.Run("ExecContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "INSERT INTO exec_test (id, name) VALUES (?, ?)", + Parameters: []interface{}{2, "Peter"}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "INSERT", + }, + Description: "INSERT INTO exec_test (id, name) VALUES (?, ?)", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "INSERT INTO exec_test (id, name) VALUES (?, ?, ?, ?)", + Parameters: []interface{}{4, "John", "Doe", "John Doe"}, + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "INSERT", + }, + Description: "INSERT INTO exec_test (id, name) VALUES (?, ?, ?, ?)", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + conn, err := db.Conn(ctx) + if err != nil { + cancel() + t.Fatal(err) + } + + _, err = conn.ExecContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + _ = conn.Close() + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) +} + +//nolint:dupl,gocyclo +func TestNewSentrySQL_BeginTx(t *testing.T) { + db, err := sql.Open("sentrysql-sqlite", ":memory:") + if err != nil { + t.Fatalf("opening sqlite: %v", err) + } + db.SetMaxOpenConns(1) + defer db.Close() + + setupQueries := []string{ + "CREATE TABLE exec_test (id INT, name TEXT)", + "CREATE TABLE query_test (id INT, name TEXT, age INT, created_at TEXT)", + "INSERT INTO query_test (id, name, age, created_at) VALUES (1, 'John', 30, '2023-01-01')", + "INSERT INTO query_test (id, name, age, created_at) VALUES (2, 'Jane', 25, '2023-01-02')", + "INSERT INTO query_test (id, name, age, created_at) VALUES (3, 'Bob', 35, '2023-01-03')", + } + + setupCtx, cancelCtx := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelCtx() + + for _, query := range setupQueries { + _, err = db.ExecContext(setupCtx, query) + if err != nil { + t.Fatalf("initializing table on sqlite: %v", err) + } + } + + t.Run("Singles", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "INSERT INTO exec_test (id, name) VALUES (?, ?)", + Parameters: []interface{}{2, "Peter"}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "INSERT", + }, + Description: "INSERT INTO exec_test (id, name) VALUES (?, ?)", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "INSERT INTO exec_test (id, name) VALUES (?, ?, ?, ?)", + Parameters: []interface{}{4, "John", "Doe", "John Doe"}, + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "INSERT", + }, + Description: "INSERT INTO exec_test (id, name) VALUES (?, ?, ?, ?)", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + conn, err := db.Conn(ctx) + if err != nil { + cancel() + t.Fatal(err) + } + + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + cancel() + t.Fatal(err) + } + + _, err = tx.ExecContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + err = tx.Commit() + if err != nil && !tt.WantError { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + _ = tx.Rollback() + + _ = conn.Close() + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) + + t.Run("Multiple Queries", func(t *testing.T) { + spansCh := make(chan []*sentry.Span, 2) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + defer cancel() + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + defer func() { + _ = tx.Rollback() + }() + + var name string + err = tx.QueryRowContext(ctx, "SELECT name FROM query_test WHERE id = ?", 1).Scan(&name) + if err != nil { + _ = tx.Rollback() + _ = conn.Close() + cancel() + t.Fatal(err) + } + + _, err = tx.ExecContext(ctx, "INSERT INTO exec_test (id, name) VALUES (?, ?)", 5, "Catherine") + if err != nil { + _ = tx.Rollback() + _ = conn.Close() + cancel() + t.Fatal(err) + } + + err = tx.Commit() + if err != nil { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + _ = conn.Close() + + span.Finish() + + cancel() + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got []*sentry.Span + for e := range spansCh { + got = append(got, e...) + } + + want := []*sentry.Span{ + { + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "SELECT", + }, + Description: "SELECT name FROM query_test WHERE id = ?", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + { + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "INSERT", + }, + Description: "INSERT INTO exec_test (id, name) VALUES (?, ?)", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + } + + if diff := cmp.Diff(want, got, optstrans); diff != "" { + t.Errorf("Span mismatch (-want +got):\n%s", diff) + } + }) + + t.Run("Rollback", func(t *testing.T) { + spansCh := make(chan []*sentry.Span, 2) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + defer cancel() + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + defer func() { + _ = tx.Rollback() + }() + + var name string + err = tx.QueryRowContext(ctx, "SELECT name FROM query_test WHERE id = ?", 1).Scan(&name) + if err != nil { + _ = tx.Rollback() + _ = conn.Close() + cancel() + t.Fatal(err) + } + + _, err = tx.ExecContext(ctx, "INSERT INTO exec_test (id, name) VALUES (?, ?)", 5, "Catherine") + if err != nil { + _ = tx.Rollback() + _ = conn.Close() + cancel() + t.Fatal(err) + } + + err = tx.Rollback() + if err != nil { + _ = conn.Close() + cancel() + t.Fatal(err) + } + + _ = conn.Close() + + span.Finish() + + cancel() + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got []*sentry.Span + for e := range spansCh { + got = append(got, e...) + } + + want := []*sentry.Span{ + { + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "SELECT", + }, + Description: "SELECT name FROM query_test WHERE id = ?", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + { + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "INSERT", + }, + Description: "INSERT INTO exec_test (id, name) VALUES (?, ?)", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + } + + if diff := cmp.Diff(want, got, optstrans); diff != "" { + t.Errorf("Span mismatch (-want +got):\n%s", diff) + } + }) +} + +//nolint:dupl +func TestNewSentrySQL_PrepareContext(t *testing.T) { + db, err := sql.Open("sentrysql-sqlite", ":memory:") + if err != nil { + t.Fatalf("opening sqlite: %v", err) + } + db.SetMaxOpenConns(1) + defer db.Close() + + setupQueries := []string{ + "CREATE TABLE exec_test (id INT, name TEXT)", + "CREATE TABLE query_test (id INT, name TEXT, age INT, created_at TEXT)", + "INSERT INTO query_test (id, name, age, created_at) VALUES (1, 'John', 30, '2023-01-01')", + "INSERT INTO query_test (id, name, age, created_at) VALUES (2, 'Jane', 25, '2023-01-02')", + "INSERT INTO query_test (id, name, age, created_at) VALUES (3, 'Bob', 35, '2023-01-03')", + } + + setupCtx, cancelCtx := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelCtx() + + for _, query := range setupQueries { + _, err = db.ExecContext(setupCtx, query) + if err != nil { + t.Fatalf("initializing table on sqlite: %v", err) + } + } + + t.Run("Exec", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "INSERT INTO exec_test (id, name) VALUES (?, ?)", + Parameters: []interface{}{3, "Sarah"}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "INSERT", + }, + Description: "INSERT INTO exec_test (id, name) VALUES (?, ?)", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "INSERT INTO exec_test (id, name) VALUES (?, ?, ?, ?)", + Parameters: []interface{}{4, "John", "Doe", "John Doe"}, + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "INSERT", + }, + Description: "INSERT INTO exec_test (id, name) VALUES (?, ?, ?, ?)", + Op: "db.sql.exec", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + stmt, err := db.PrepareContext(ctx, tt.Query) + if err != nil { + cancel() + t.Fatal(err) + } + + _, err = stmt.Exec(tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) + + t.Run("Query", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "SELECT * FROM query_test WHERE id = ?", + Parameters: []interface{}{2}, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "SELECT", + }, + Description: "SELECT * FROM query_test WHERE id = ?", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusOK, + }, + }, + { + Query: "SELECT * FROM query_test WHERE id =", + Parameters: []interface{}{1}, + WantError: true, + WantSpan: &sentry.Span{ + Data: map[string]interface{}{ + "db.system": sentrysql.DatabaseSystem("sqlite"), + "db.name": "memory", + "server.address": "localhost", + "server.port": "5432", + "db.operation": "SELECT", + }, + Description: "SELECT * FROM query_test WHERE id =", + Op: "db.sql.query", + Tags: nil, + Origin: "manual", + Sampled: sentry.SampledTrue, + Status: sentry.SpanStatusInternalError, + }, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + span := sentry.StartSpan(ctx, "fake_parent", sentry.WithTransactionName("Fake Parent")) + ctx = span.Context() + + stmt, err := db.PrepareContext(ctx, tt.Query) + if err != nil { + cancel() + t.Fatal(err) + } + + rows, err := stmt.Query(tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + if rows != nil { + _ = rows.Close() + } + + span.Finish() + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + for i, tt := range tests { + var foundMatch = false + gotSpans := got[i] + + var diffs []string + for _, gotSpan := range gotSpans { + if diff := cmp.Diff(tt.WantSpan, gotSpan, optstrans); diff != "" { + diffs = append(diffs, diff) + } else { + foundMatch = true + break + } + } + + if !foundMatch { + t.Errorf("Span mismatch (-want +got):\n%s", strings.Join(diffs, "\n")) + } + } + }) +} + +//nolint:dupl +func TestNewSentrySQL_NoParentSpan(t *testing.T) { + db, err := sql.Open("sentrysql-sqlite", ":memory:") + if err != nil { + t.Fatalf("opening sqlite: %v", err) + } + db.SetMaxOpenConns(1) + defer db.Close() + + setupQueries := []string{ + "CREATE TABLE exec_test (id INT, name TEXT)", + "CREATE TABLE query_test (id INT, name TEXT, age INT, created_at TEXT)", + "INSERT INTO query_test (id, name, age, created_at) VALUES (1, 'John', 30, '2023-01-01')", + "INSERT INTO query_test (id, name, age, created_at) VALUES (2, 'Jane', 25, '2023-01-02')", + "INSERT INTO query_test (id, name, age, created_at) VALUES (3, 'Bob', 35, '2023-01-03')", + } + + setupCtx, cancelCtx := context.WithTimeout(context.Background(), 30*time.Second) + defer cancelCtx() + + for _, query := range setupQueries { + _, err = db.ExecContext(setupCtx, query) + if err != nil { + t.Fatalf("initializing table on sqlite: %v", err) + } + } + + t.Run("QueryContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "SELECT * FROM query_test WHERE id = ?", + Parameters: []interface{}{1}, + WantSpan: nil, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + + rows, err := db.QueryContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + if rows != nil { + _ = rows.Close() + } + + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + // `got` should be empty + if len(got) != 0 { + t.Errorf("got %d spans, want 0", len(got)) + } + }) + + t.Run("ExecContext", func(t *testing.T) { + tests := []struct { + Query string + Parameters []interface{} + WantSpan *sentry.Span + WantError bool + }{ + { + Query: "INSERT INTO exec_test (id, name) VALUES (?, ?)", + Parameters: []interface{}{1, "John"}, + WantSpan: nil, + }, + } + + spansCh := make(chan []*sentry.Span, len(tests)) + + sentryClient, err := sentry.NewClient(sentry.ClientOptions{ + EnableTracing: true, + TracesSampleRate: 1.0, + BeforeSendTransaction: func(event *sentry.Event, hint *sentry.EventHint) *sentry.Event { + spansCh <- event.Spans + return event + }, + }) + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + hub := sentry.NewHub(sentryClient, sentry.NewScope()) + ctx, cancel := context.WithTimeout(sentry.SetHubOnContext(context.Background(), hub), 10*time.Second) + + _, err := db.ExecContext(ctx, tt.Query, tt.Parameters...) + if err != nil && !tt.WantError { + cancel() + t.Fatal(err) + } + + cancel() + } + + if ok := sentryClient.Flush(testutils.FlushTimeout()); !ok { + t.Fatal("sentry.Flush timed out") + } + close(spansCh) + + var got [][]*sentry.Span + for e := range spansCh { + got = append(got, e) + } + + // `got` should be empty + if len(got) != 0 { + t.Errorf("got %d spans, want 0", len(got)) + } + }) +} diff --git a/sentrysql/stmt.go b/sentrysql/stmt.go new file mode 100644 index 00000000..2b986b1c --- /dev/null +++ b/sentrysql/stmt.go @@ -0,0 +1,162 @@ +package sentrysql + +import ( + "context" + "database/sql/driver" + "errors" + + "github.com/getsentry/sentry-go" +) + +type sentryStmt struct { + originalStmt driver.Stmt + query string + ctx context.Context + config *sentrySQLConfig +} + +// Make sure sentryStmt implements driver.Stmt interface. +var _ driver.Stmt = (*sentryStmt)(nil) +var _ driver.StmtExecContext = (*sentryStmt)(nil) +var _ driver.StmtQueryContext = (*sentryStmt)(nil) +var _ driver.NamedValueChecker = (*sentryStmt)(nil) + +func (s *sentryStmt) Close() error { + return s.originalStmt.Close() +} + +func (s *sentryStmt) NumInput() int { + return s.originalStmt.NumInput() +} + +//nolint:dupl +func (s *sentryStmt) Exec(args []driver.Value) (driver.Result, error) { + parentSpan := sentry.SpanFromContext(s.ctx) + if parentSpan == nil { + return s.originalStmt.Exec(args) //nolint:staticcheck // We must support legacy clients + } + + span := parentSpan.StartChild("db.sql.exec", sentry.WithDescription(s.query)) + s.config.SetData(span, s.query) + defer span.Finish() + + result, err := s.originalStmt.Exec(args) //nolint:staticcheck // We must support legacy clients + if err != nil { + span.Status = sentry.SpanStatusInternalError + return nil, err + } + + span.Status = sentry.SpanStatusOK + + return result, nil +} + +//nolint:dupl +func (s *sentryStmt) Query(args []driver.Value) (driver.Rows, error) { + parentSpan := sentry.SpanFromContext(s.ctx) + if parentSpan == nil { + return s.originalStmt.Query(args) //nolint:staticcheck // We must support legacy clients + } + + span := parentSpan.StartChild("db.sql.query", sentry.WithDescription(s.query)) + s.config.SetData(span, s.query) + defer span.Finish() + + rows, err := s.originalStmt.Query(args) //nolint:staticcheck // We must support legacy clients + if err != nil { + span.Status = sentry.SpanStatusInternalError + return nil, err + } + + span.Status = sentry.SpanStatusOK + return rows, nil +} + +func (s *sentryStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + // should only be executed if the original driver implements StmtExecContext + stmtExecContext, ok := s.originalStmt.(driver.StmtExecContext) + if !ok { + // We may not return driver.ErrSkip. We should fallback to Exec without context. + values, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + return s.Exec(values) + } + + parentSpan := sentry.SpanFromContext(s.ctx) + if parentSpan == nil { + return stmtExecContext.ExecContext(ctx, args) + } + + span := parentSpan.StartChild("db.sql.exec", sentry.WithDescription(s.query)) + s.config.SetData(span, s.query) + defer span.Finish() + + result, err := stmtExecContext.ExecContext(ctx, args) + if err != nil { + span.Status = sentry.SpanStatusInternalError + return nil, err + } + + span.Status = sentry.SpanStatusOK + + return result, nil +} + +func (s *sentryStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + // should only be executed if the original driver implements StmtQueryContext + stmtQueryContext, ok := s.originalStmt.(driver.StmtQueryContext) + if !ok { + // We may not return driver.ErrSkip. We should fallback to Exec without context. + values, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + return s.Query(values) + } + + parentSpan := sentry.SpanFromContext(s.ctx) + if parentSpan == nil { + return stmtQueryContext.QueryContext(ctx, args) + } + + span := parentSpan.StartChild("db.sql.query", sentry.WithDescription(s.query)) + s.config.SetData(span, s.query) + defer span.Finish() + + rows, err := stmtQueryContext.QueryContext(ctx, args) + if err != nil { + span.Status = sentry.SpanStatusInternalError + return nil, err + } + + span.Status = sentry.SpanStatusOK + return rows, nil +} + +func (s *sentryStmt) CheckNamedValue(namedValue *driver.NamedValue) error { + // It is allowed to return driver.ErrSkip if the original driver does not + // implement driver.NamedValueChecker. + namedValueChecker, ok := s.originalStmt.(driver.NamedValueChecker) + if !ok { + return driver.ErrSkip + } + + return namedValueChecker.CheckNamedValue(namedValue) +} + +// namedValueToValue is an exact copy of +// https://cs.opensource.google/go/go/+/refs/tags/go1.23.2:src/database/sql/ctxutil.go;l=137-146 +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + return nil, errors.New("sql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} diff --git a/sentrysql/tx.go b/sentrysql/tx.go new file mode 100644 index 00000000..51478669 --- /dev/null +++ b/sentrysql/tx.go @@ -0,0 +1,22 @@ +package sentrysql + +import ( + "context" + "database/sql/driver" +) + +type sentryTx struct { + originalTx driver.Tx + ctx context.Context + config *sentrySQLConfig +} + +// Commit implements driver.Tx. +func (s *sentryTx) Commit() error { + return s.originalTx.Commit() +} + +// Rollback implements driver.Tx. +func (s *sentryTx) Rollback() error { + return s.originalTx.Rollback() +}