From 646e8cdd2630aed13c72765e7aed49c189dcea3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20Sza=C5=82omski?= Date: Sat, 28 May 2022 23:32:32 +0200 Subject: [PATCH] #154: Implement `renderUpdate` for Oracle --- .../main/scala/zio/sql/driver/Renderer.scala | 3 +- .../zio/sql/oracle/OracleRenderModule.scala | 29 ++++++++++--------- .../zio/sql/oracle/OracleSqlModuleSpec.scala | 23 +++++++++++++-- .../scala/zio/sql/oracle/ShopSchema.scala | 8 +++-- ...Spec.scala => PostgresSqlModuleSpec.scala} | 2 +- .../sql/sqlserver/SqlServerRenderModule.scala | 3 +- 6 files changed, 46 insertions(+), 22 deletions(-) rename postgres/src/test/scala/zio/sql/postgresql/{PostgresModuleSpec.scala => PostgresSqlModuleSpec.scala} (99%) diff --git a/driver/src/main/scala/zio/sql/driver/Renderer.scala b/driver/src/main/scala/zio/sql/driver/Renderer.scala index f616001b8..a0e04821a 100644 --- a/driver/src/main/scala/zio/sql/driver/Renderer.scala +++ b/driver/src/main/scala/zio/sql/driver/Renderer.scala @@ -22,6 +22,7 @@ private[sql] object Renderer { def apply(): Renderer = new Renderer(new StringBuilder) implicit class Extensions(val value: String) { - def quoted: String = s"\"$value\"" + def doubleQuoted: String = s"\"$value\"" + def singleQuoted: String = s"'$value'" } } diff --git a/oracle/src/main/scala/zio/sql/oracle/OracleRenderModule.scala b/oracle/src/main/scala/zio/sql/oracle/OracleRenderModule.scala index 9137a1e35..37a64dfbc 100644 --- a/oracle/src/main/scala/zio/sql/oracle/OracleRenderModule.scala +++ b/oracle/src/main/scala/zio/sql/oracle/OracleRenderModule.scala @@ -4,6 +4,8 @@ import zio.schema.Schema import zio.sql.driver.Renderer import zio.sql.driver.Renderer.Extensions +import scala.collection.mutable + trait OracleRenderModule extends OracleSqlModule { self => override def renderDelete(delete: self.Delete[_]): String = { @@ -56,16 +58,20 @@ trait OracleRenderModule extends OracleSqlModule { self => case Expr.Relational(left, right, op) => buildExpr(left, builder) builder.append(" ").append(op.symbol).append(" ") - buildExpr(right, builder) + right.asInstanceOf[Expr[_, A, B]] match { + case Expr.Literal(true) => val _ = builder.append("1") + case Expr.Literal(false) => val _ = builder.append("0") + case otherValue => buildExpr(otherValue, builder) + } case Expr.In(value, set) => buildExpr(value, builder) buildReadString(set, builder) case Expr.Literal(true) => - val _ = builder.append("1") + val _ = builder.append("1 = 1") case Expr.Literal(false) => - val _ = builder.append("0") + val _ = builder.append("0 = 1") case Expr.Literal(value) => - val _ = builder.append(value.toString) // todo fix escaping + val _ = builder.append(value.toString.singleQuoted) case Expr.AggregationCall(param, aggregation) => builder.append(aggregation.name.name) builder.append("(") @@ -320,7 +326,7 @@ trait OracleRenderModule extends OracleSqlModule { self => val _ = builder.append(" ") } - private def buildDeleteString(delete: Delete[_], builder: StringBuilder) = { + private def buildDeleteString(delete: Delete[_], builder: mutable.StringBuilder): Unit = { builder.append("DELETE FROM ") buildTable(delete.table, builder) delete.whereExpr match { @@ -331,7 +337,6 @@ trait OracleRenderModule extends OracleSqlModule { self => } } - private[oracle] object OracleRender { def renderUpdateImpl(update: Update[_])(implicit render: Renderer): Unit = @@ -362,14 +367,12 @@ trait OracleRenderModule extends OracleSqlModule { self => private[zio] def renderSetLhs[A, B](expr: self.Expr[_, A, B])(implicit render: Renderer): Unit = expr match { - // TODO: to check if Oracle allows for `tableName.columnName = value` format in update statement, - // or it requires `columnName = value` instead? - case Expr.Source(_, column) => - column.name match { - case Some(columnName) => render(columnName.quoted) - case _ => () + case Expr.Source(table, column) => + (table, column.name) match { + case (tableName, Some(columnName)) => val _ = render(tableName, ".", columnName) + case _ => () } - case _ => () + case _ => () } } } diff --git a/oracle/src/test/scala/zio/sql/oracle/OracleSqlModuleSpec.scala b/oracle/src/test/scala/zio/sql/oracle/OracleSqlModuleSpec.scala index 5d84fa831..89c4da84a 100644 --- a/oracle/src/test/scala/zio/sql/oracle/OracleSqlModuleSpec.scala +++ b/oracle/src/test/scala/zio/sql/oracle/OracleSqlModuleSpec.scala @@ -10,10 +10,25 @@ object OracleSqlModuleSpec extends OracleRunnableSpec with ShopSchema { import Customers._ - override def specLayered = suite("Postgres module")( + override def specLayered: Spec[SqlDriver, Exception] = suite("Oracle module")( + test("Can update rows") { + /** + * UPDATE customers SET customers.first_name = 'Jaroslav' + * WHERE 1 = 1 and customers.verified = 0 and customers.verified <> 1 + */ + val query = + update(customers) + .set(fName, "Jaroslav") + .where(verified isNotTrue) + .where(verified <> true) // we intentionally verify two syntax variants + + assertZIO(execute(query))(equalTo(1)) + }, test("Can delete from single table with a condition") { + /** + * DELETE FROM customers WHERE customers.verified = 0 + */ val query = deleteFrom(customers) where (verified isNotTrue) - println(renderDelete(query)) val expected = 1 val result = execute(query) @@ -21,8 +36,10 @@ object OracleSqlModuleSpec extends OracleRunnableSpec with ShopSchema { assertZIO(result)(equalTo(expected)) }, test("Can delete all from a single table") { + /** + * DELETE FROM customers + */ val query = deleteFrom(customers) - println(renderDelete(query)) val expected = 4 val result = execute(query) diff --git a/oracle/src/test/scala/zio/sql/oracle/ShopSchema.scala b/oracle/src/test/scala/zio/sql/oracle/ShopSchema.scala index 52076f321..177831047 100644 --- a/oracle/src/test/scala/zio/sql/oracle/ShopSchema.scala +++ b/oracle/src/test/scala/zio/sql/oracle/ShopSchema.scala @@ -5,12 +5,14 @@ import zio.sql.Jdbc trait ShopSchema extends Jdbc { self => import self.ColumnSet._ - object Customers { + object Customers { + val customers = - (uuid("id") ++ localDate("dob") ++ string("first_name") ++ string("last_name") ++ boolean("verified")) + (uuid("id") ++ localDate("dob") ++ string("first_name") ++ string("last_name") ++ + boolean("verified") ++ zonedDateTime("Created_timestamp")) .table("customers") - val (customerId, dob, fName, lName, verified) = customers.columns + val (customerId, dob, fName, lName, verified, createdTimestamp) = customers.columns } object Orders { val orders = (uuid("id") ++ uuid("customer_id") ++ localDate("order_date")).table("orders") diff --git a/postgres/src/test/scala/zio/sql/postgresql/PostgresModuleSpec.scala b/postgres/src/test/scala/zio/sql/postgresql/PostgresSqlModuleSpec.scala similarity index 99% rename from postgres/src/test/scala/zio/sql/postgresql/PostgresModuleSpec.scala rename to postgres/src/test/scala/zio/sql/postgresql/PostgresSqlModuleSpec.scala index 2e6e46842..eb78ed3f1 100644 --- a/postgres/src/test/scala/zio/sql/postgresql/PostgresModuleSpec.scala +++ b/postgres/src/test/scala/zio/sql/postgresql/PostgresSqlModuleSpec.scala @@ -11,7 +11,7 @@ import scala.language.postfixOps import zio.schema.Schema import java.time.format.DateTimeFormatter -object PostgresModuleSpec extends PostgresRunnableSpec with DbSchema { +object PostgresSqlModuleSpec extends PostgresRunnableSpec with DbSchema { import AggregationDef._ import Customers._ diff --git a/sqlserver/src/main/scala/zio/sql/sqlserver/SqlServerRenderModule.scala b/sqlserver/src/main/scala/zio/sql/sqlserver/SqlServerRenderModule.scala index 30cccdd67..29cf49d5c 100644 --- a/sqlserver/src/main/scala/zio/sql/sqlserver/SqlServerRenderModule.scala +++ b/sqlserver/src/main/scala/zio/sql/sqlserver/SqlServerRenderModule.scala @@ -2,6 +2,7 @@ package zio.sql.sqlserver import zio.schema.Schema import zio.sql.driver.Renderer +import zio.sql.driver.Renderer.Extensions trait SqlServerRenderModule extends SqlServerSqlModule { self => @@ -153,7 +154,7 @@ trait SqlServerRenderModule extends SqlServerSqlModule { self => .asInstanceOf[java.time.OffsetDateTime] .format(java.time.format.DateTimeFormatter.ofPattern("YYYY-MM-dd HH:mm:ss")) s"'$x'" - case _ => s"'${value.toString}'" + case _ => value.toString.singleQuoted } render(lit) case Expr.AggregationCall(param, aggregation) =>