Skip to content

Commit

Permalink
fixed tests for mssql server and added test for correlated subquery i…
Browse files Browse the repository at this point in the history
…n where clause
  • Loading branch information
sviezypan committed Nov 14, 2021
1 parent 709be62 commit ea19a1b
Show file tree
Hide file tree
Showing 12 changed files with 223 additions and 125 deletions.
10 changes: 6 additions & 4 deletions core/jvm/src/main/scala/zio/sql/expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,14 @@ trait ExprModule extends NewtypesModule with FeaturesModule with OpsModule {
): Selection[F, A, SelectionSet.Cons[A, B, SelectionSet.Empty]] =
Selection.computedOption(expr, exprName(expr))

//TODO needed by suqueries in where clauses
implicit def selectionToExpr[F, A, B](subselect: Read.Subselect[F, _, A, A, B, SelectionSet.Empty]): Expr[F, A, B] =
implicit def selectionToExpr[F, Repr, Source, Subsource, Head](
subselect: Read.Subselect[F, Repr, _ <: Source, Subsource, Head, SelectionSet.Empty]
): Expr[F, Source, Head] =
Expr.Subselect(subselect)

sealed case class Subselect[F, A, B](subselect: Read.Subselect[F, _, A, A, B, SelectionSet.Empty])
extends Expr[F, A, B]
sealed case class Subselect[F, Repr, Source, Subsource, Head](
subselect: Read.Subselect[F, Repr, _ <: Source, Subsource, Head, SelectionSet.Empty]
) extends Expr[F, Source, Head]

sealed case class Source[A, B] private[sql] (tableName: TableName, column: Column[B])
extends InvariantExpr[Features.Source, A, B] {
Expand Down
2 changes: 1 addition & 1 deletion core/jvm/src/main/scala/zio/sql/newtypes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ trait NewtypesModule {

type ColumnName = String
//TODO we could use zio-prelude new types
type TableName = String
type TableName = String

sealed case class FunctionName(name: String)
}
2 changes: 0 additions & 2 deletions core/jvm/src/main/scala/zio/sql/select.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ trait SelectModule { self: ExprModule with TableModule =>
*/
sealed trait Read[+Out] { self =>
type ResultType
type DerivedTableType

val mapper: ResultType => Out

Expand Down Expand Up @@ -327,7 +326,6 @@ trait SelectModule { self: ExprModule with TableModule =>
limit: Option[Long] = None
) extends Read[Repr] { self =>

//TODO check if I need copy(whereExpr = self.whereExpr && whereExpr2)
def where(whereExpr2: Expr[_, Source, Boolean]): Subselect[F, Repr, Source, Subsource, Head, Tail] =
copy(whereExpr = whereExpr2)

Expand Down
2 changes: 1 addition & 1 deletion core/jvm/src/main/scala/zio/sql/table.scala
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ trait TableModule { self: ExprModule with SelectModule =>

sealed case class DerivedTable[+R <: Read[_]](read: R, name: TableName) extends Table { self =>

override type TableType = read.DerivedTableType
//override type TableType = read.DerivedTableType

override type ColumnHead = read.ColumnHead
override type ColumnTail = read.ColumnTail
Expand Down
4 changes: 2 additions & 2 deletions jdbc/src/main/scala/zio/sql/JdbcInternalModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ trait JdbcInternalModule { self: Jdbc =>
case t @ ColumnSelection.Constant(_, _) => t.typeTag
case t @ ColumnSelection.Computed(_, _) => t.typeTag
}
case Read.Union(left, _, _) => getColumns(left)
case v @ Read.Literal(_) => scala.collection.immutable.Vector(v.typeTag)
case Read.Union(left, _, _) => getColumns(left)
case v @ Read.Literal(_) => scala.collection.immutable.Vector(v.typeTag)
}

private[sql] def unsafeExtractRow[A](
Expand Down
18 changes: 9 additions & 9 deletions mysql/src/test/scala/zio/sql/mysql/FunctionDefSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,19 @@ object FunctionDefSpec extends MysqlRunnableSpec with ShopSchema {

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
testM("bit_length") {
val query = select(BitLength("hello"))
testM("bit_length") {
val query = select(BitLength("hello"))

val expected = 40
val expected = 40

val testResult = execute(query.to[Int, Int](identity))
val testResult = execute(query.to[Int, Int](identity))

val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))
val assertion = for {
r <- testResult.runCollect
} yield assert(r.head)(equalTo(expected))

assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
assertion.mapErrorCause(cause => Cause.stackless(cause.untraced))
},
testM("pi") {
val query = select(Pi) from customers

Expand Down
31 changes: 15 additions & 16 deletions oracle/src/main/scala/zio/sql/oracle/OracleModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ trait OracleModule extends Jdbc { self =>
}

def buildExpr[A, B](expr: self.Expr[_, A, B], builder: StringBuilder): Unit = expr match {
case Expr.Subselect(subselect) =>
case Expr.Source(table, column) => {
(table, column) match {
case (tableName: TableName, Column.Named(columnName)) =>
val _ = builder.append(tableName).append(".").append(columnName)
case _ => ()
}
case Expr.Subselect(subselect) =>
case Expr.Source(table, column) =>
(table, column) match {
case (tableName: TableName, Column.Named(columnName)) =>
val _ = builder.append(tableName).append(".").append(columnName)
case _ => ()
}
case Expr.Unary(base, op) =>
val _ = builder.append(" ").append(op.symbol)
Expand Down Expand Up @@ -132,12 +131,12 @@ trait OracleModule extends Jdbc { self =>

case read0 @ Read.Subselect(_, _, _, _, _, _, _, _) =>
object Dummy {
type F
type Repr
type Source
type Head
type Tail <: SelectionSet[Source]
}
type F
type Repr
type Source
type Head
type Tail <: SelectionSet[Source]
}
val read = read0.asInstanceOf[Read.Select[Dummy.F, Dummy.Repr, Dummy.Source, Dummy.Head, Dummy.Tail]]
import read._

Expand Down Expand Up @@ -269,10 +268,10 @@ trait OracleModule extends Jdbc { self =>
table match {
case Table.DialectSpecificTable(tableExtension) => ???
//The outer reference in this type test cannot be checked at run time?!
case sourceTable: self.Table.Source =>
case sourceTable: self.Table.Source =>
val _ = builder.append(sourceTable.name)
case Table.DerivedTable(read, name) => ???
case Table.Joined(joinType, left, right, on) =>
case Table.DerivedTable(read, name) => ???
case Table.Joined(joinType, left, right, on) =>
buildTable(left, builder)
builder.append(joinType match {
case JoinType.Inner => " INNER JOIN "
Expand Down
120 changes: 60 additions & 60 deletions sqlserver/src/main/scala/zio/sql/sqlserver/SqlServerModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ trait SqlServerModule extends Jdbc { self =>
override type ColumnTail =
left.columnSet.tail.Append[ColumnSet.Cons[right.ColumnHead, right.ColumnTail]]

override val columnSet: ColumnSet.Cons[ColumnHead, ColumnTail] =
override val columnSet: ColumnSet.Cons[ColumnHead, ColumnTail] =
left.columnSet ++ right.columnSet

override val columnToExpr: ColumnToExpr[A with B] = new ColumnToExpr[A with B] {
def toExpr[C](column: Column[C]): Expr[Features.Source, A with B, C] =
def toExpr[C](column: Column[C]): Expr[Features.Source, A with B, C] =
if (left.columnSet.contains(column))
left.columnToExpr.toExpr(column)
else
Expand Down Expand Up @@ -92,7 +92,8 @@ trait SqlServerModule extends Jdbc { self =>

val orderId :*: fkCustomerId :*: orderDate :*: _ = orders.columns

val orderDetails = (uuid("orderId") ++ uuid("product_id") ++ int("quantity") ++ double("unit_price")).table("order_details")
val orderDetails =
(uuid("orderId") ++ uuid("product_id") ++ int("quantity") ++ double("unit_price")).table("order_details")

val orderDetailsId :*: productId :*: quantity :*: unitPrice :*: _ = orderDetails.columns

Expand All @@ -101,7 +102,7 @@ trait SqlServerModule extends Jdbc { self =>
val derived =
select(customerId ++ fName).from(customers).asTable("derived")

val derivedId :*: derivedName :*: _ = derived.columns
val derivedId :*: derivedName :*: _ = derived.columns

//AS TABLE example
val e = select(fName ++ lName).from(customers).asTable("derived")
Expand All @@ -124,44 +125,36 @@ trait SqlServerModule extends Jdbc { self =>

val orderDateColumn :*: _ = ordersTable.columns

/*
select customers.id, customers.first_name, customers.last_name, ooo.order_date, ooo.id
from
customers
cross apply (
select order_date
from orders
where orders.customer_id = customers.id
) ooo
*/

// Cross Apply example
import SqlServerTable._

import SqlServerTable._

/*TODO
1. which one do we need ?
* table.subquery(select)
* subselect[TableType](select)
* subselectFrom(parentTable)(query)
* table.subquery(select)
* subselect[TableType](select)
* subselectFrom(parentTable)(query)
2. rename those subqueries to correlated subqueries, add suppost for normal subqueries (ones which does not access parent table)
3. add support for correlated subquery in where clause when accessing the same table
4. correlated subqueries in selections / where clauses
5. translate DerivedTable also for postgres, Oracle, Mysql
6. add test for outer apply and real cross apply
*/
*/

select(customerId ++ fName ++ lName)
.from(customers
.crossApply(
subselect[customers.TableType](orderDate).from(orders).where(customerId === fkCustomerId).asTable("derived")
)
.from(
customers
.crossApply(
subselect[customers.TableType](orderDate)
.from(orders)
.where(customerId === fkCustomerId)
.asTable("derived")
)
)

val newOrdersTable = subselect[customers.TableType](orderDate).from(orders).where(customerId === fkCustomerId).asTable("derived")
val newOrdersTable =
subselect[customers.TableType](orderDate).from(orders).where(customerId === fkCustomerId).asTable("derived")

val localdate :*: _ = newOrdersTable.columns

val crossApplyExample2 = select(customerId ++ fName ++ lName ++ localdate)
.from(
customers
Expand All @@ -186,7 +179,7 @@ trait SqlServerModule extends Jdbc { self =>
)
)

val newtable =
val newtable =
customers
.subselect(customerId ++ fName ++ lName ++ orderDate)
.from(orders)
Expand All @@ -198,7 +191,6 @@ trait SqlServerModule extends Jdbc { self =>
val crossApplyExample4 = select(dCustomerId ++ dfName ++ dflName ++ dOrderDate)
.from(newtable)


// //JOIN example
val joinQuery = select(fName ++ lName ++ orderDate).from(customers.join(orders).on(customerId === fkCustomerId))

Expand All @@ -212,20 +204,6 @@ trait SqlServerModule extends Jdbc { self =>
val qqqqq =
subselect[customers.TableType](orderDate).from(orders).where(customerId === fkCustomerId).asTable("ooo")


//TODO
// // ========= CORRELATED SUBQUERY

// // select order_id, product_id, unit_price from order_details od
// // where unit_price > (select avg(unit_price) from order_details where od.product_id = product_id)

// val customUUID = java.util.UUID.fromString("48ce2e6e-7258-413f-942f-01eb21acd979")

// val re = select(AggregationDef.Avg(unitPrice)).from(orderDetails).where(productId === customUUID)

// select(orderDetailsId ++ productId ++ unitPrice)
// .from(orderDetails)
// .where(unitPrice > select(AggregationDef.Avg(unitPrice)).from(orderDetails).where(productId === ???))
}
}

Expand All @@ -237,7 +215,10 @@ trait SqlServerModule extends Jdbc { self =>
val builder = new StringBuilder

def buildExpr[A, B](expr: self.Expr[_, A, B]): Unit = expr match {
case Expr.Subselect(subselect) => ???
case Expr.Subselect(subselect) =>
builder.append(" (")
builder.append(renderRead(subselect))
val _ = builder.append(") ")
case Expr.Source(table, column) =>
(table, column) match {
case (tableName: TableName, Column.Named(columnName)) =>
Expand All @@ -249,7 +230,13 @@ trait SqlServerModule extends Jdbc { self =>
buildExpr(base)
case Expr.Property(base, op) =>
buildExpr(base)
val _ = builder.append(" ").append(op.symbol)
val symbol = op match {
case PropertyOp.IsNull => "is null"
case PropertyOp.IsNotNull => "is not null"
case PropertyOp.IsTrue => "= 1"
case PropertyOp.IsNotTrue => "= 0"
}
val _ = builder.append(" ").append(symbol)
case Expr.Binary(left, right, op) =>
buildExpr(left)
builder.append(" ").append(op.symbol).append(" ")
Expand All @@ -261,8 +248,23 @@ trait SqlServerModule extends Jdbc { self =>
case Expr.In(value, set) =>
buildExpr(value)
buildReadString(set)
case Expr.Literal(value) =>
val _ = builder.append(value.toString) //todo fix escaping
case literal @ Expr.Literal(value) =>
val lit = literal.typeTag match {
case TypeTag.TLocalDateTime =>
value
.asInstanceOf[java.time.LocalDateTime]
.format(java.time.format.DateTimeFormatter.ofPattern("YYYY-MM-dd HH:mm:ss"))
case TypeTag.TZonedDateTime =>
value
.asInstanceOf[java.time.ZonedDateTime]
.format(java.time.format.DateTimeFormatter.ofPattern("YYYY-MM-dd HH:mm:ss"))
case TypeTag.TOffsetDateTime =>
value
.asInstanceOf[java.time.OffsetDateTime]
.format(java.time.format.DateTimeFormatter.ofPattern("YYYY-MM-dd HH:mm:ss"))
case _ => value.toString
}
val _ = builder.append(s"'${lit}'")
case Expr.AggregationCall(param, aggregation) =>
builder.append(aggregation.name.name)
builder.append("(")
Expand Down Expand Up @@ -489,28 +491,26 @@ trait SqlServerModule extends Jdbc { self =>
def buildTable(table: Table): Unit =
table match {

case Table.DerivedTable(read, name) => {
builder.append(renderRead(read))
case Table.DerivedTable(read, name) =>
builder.append(" ( ")
builder.append(renderRead(read))
builder.append(" ) ")
val _ = builder.append(name)

builder.append(" ) ")
val _ = builder.append(name)
}

case sourceTable: self.Table.Source =>
case sourceTable: self.Table.Source =>
val _ = builder.append(sourceTable.name)

case Table.DialectSpecificTable(tableExtension) =>
tableExtension match {
case SqlServerSpecific.SqlServerTable.CrossOuterApplyTable(crossType, left, derivedTable) => {
case SqlServerSpecific.SqlServerTable.CrossOuterApplyTable(crossType, left, derivedTable) =>
buildTable(left)

crossType match {
case SqlServerSpecific.SqlServerTable.CrossType.CrossApply => builder.append(" cross apply ( ")
case SqlServerSpecific.SqlServerTable.CrossType.OuterApply => builder.append(" outer apply ( ")
case SqlServerSpecific.SqlServerTable.CrossType.CrossApply => builder.append(" cross apply ")
case SqlServerSpecific.SqlServerTable.CrossType.OuterApply => builder.append(" outer apply ")
}

val _ = buildTable(derivedTable)
}
}

case Table.Joined(joinType, left, right, on) =>
Expand Down
10 changes: 5 additions & 5 deletions sqlserver/src/test/resources/db_schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,11 @@ create table order_details
insert into customers
(id, first_name, last_name, verified, dob)
values
('60b01fc9-c902-4468-8d49-3c0f989def37', 'Ronald', 'Russell', 0, '1983-01-05'),
('f76c9ace-be07-4bf3-bd4c-4a9c62882e64', 'Terrence', 'Noel', 0, '1999-11-02'),
('784426a5-b90a-4759-afbb-571b7a0ba35e', 'Mila', 'Paterso', 0, '1990-11-16'),
('df8215a2-d5fd-4c6c-9984-801a1b3a2a0b', 'Alana', 'Murray', 0, '1995-11-12'),
('636ae137-5b1a-4c8c-b11f-c47c624d9cdc', 'Jose', 'Wiggins', 1, '1987-03-23');
('60b01fc9-c902-4468-8d49-3c0f989def37', 'Ronald', 'Russell', 1, '1983-01-05'),
('f76c9ace-be07-4bf3-bd4c-4a9c62882e64', 'Terrence', 'Noel', 1, '1999-11-02'),
('784426a5-b90a-4759-afbb-571b7a0ba35e', 'Mila', 'Paterso', 1, '1990-11-16'),
('df8215a2-d5fd-4c6c-9984-801a1b3a2a0b', 'Alana', 'Murray', 1, '1995-11-12'),
('636ae137-5b1a-4c8c-b11f-c47c624d9cdc', 'Jose', 'Wiggins', 0, '1987-03-23');

insert into orders
(id, customer_id, order_date)
Expand Down
4 changes: 3 additions & 1 deletion sqlserver/src/test/scala/zio/sql/TestContainer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ object TestContainer {
}
}(container => effectBlocking(container.stop()).orDie).toLayer

def postgres(imageName: String = "mcr.microsoft.com/mssql/server:2017-latest"): ZLayer[Blocking, Throwable, Has[MSSQLServerContainer]] =
def postgres(
imageName: String = "mcr.microsoft.com/mssql/server:2017-latest"
): ZLayer[Blocking, Throwable, Has[MSSQLServerContainer]] =
ZManaged.make {
effectBlocking {
val c = new MSSQLServerContainer(
Expand Down
Loading

0 comments on commit ea19a1b

Please sign in to comment.