Skip to content

Commit

Permalink
added macros to verify having and where
Browse files Browse the repository at this point in the history
  • Loading branch information
sviezypan committed Jan 6, 2023
1 parent 1340a4f commit 16091f9
Show file tree
Hide file tree
Showing 10 changed files with 269 additions and 190 deletions.
34 changes: 5 additions & 29 deletions core/jvm/src/main/scala/zio/sql/select.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package zio.sql

import zio.sql.Features._
import zio.sql.macros.GroupByLike
import zio.sql.macros._
import scala.language.implicitConversions

trait SelectModule { self: ExprModule with TableModule with UtilsModule =>
Expand Down Expand Up @@ -181,24 +181,6 @@ trait SelectModule { self: ExprModule with TableModule with UtilsModule =>
)
}

/**
* `HAVING` can only be called:
*
* 1. If its called with an aggregated function returning boolean like `Having Count(id) > 5`,
* while all the previously selected columns appeared in group by clause.
* 2. If its called with a normal expression returning boolean like `having customer_id = '636ae137-5b1a-4c8c-b11f-c47c624d9cdc``
* and all the previously selected columns appeared in group by clause.
*/
//TODO replace with macro
sealed trait HavingIsSound[F, GroupByF]

object HavingIsSound {
implicit def havingWasGroupedBy[F, GroupByF, Remainder](implicit
i: Features.IsPartiallyAggregated.WithRemainder[F, Remainder],
ev: GroupByF <:< Remainder
): HavingIsSound[F, GroupByF] = new HavingIsSound[F, GroupByF] {}
}

type Select[F, Repr, Source, Head, Tail <: SelectionSet[Source]] = Subselect[F, Repr, Source, Source, Head, Tail]

sealed case class Subselect[F, Repr, Source, Subsource, Head, Tail <: SelectionSet[Source]](
Expand All @@ -214,13 +196,9 @@ trait SelectModule { self: ExprModule with TableModule with UtilsModule =>

type GroupByF <: Any

// TODO FIX
// F2 should be not aggregated ->
// we cannot call where by partial aggregation on F
// however we can call where by full aggregation -> select Count(id) where x = ''
def where[F2](
whereExpr2: Expr[F2, Source, Boolean]
): Subselect.WithGroupByF[F, Repr, Source, Subsource, Head, Tail, self.GroupByF] =
)(implicit ev: WhereIsSound[F2, self.GroupByF]): Subselect.WithGroupByF[F, Repr, Source, Subsource, Head, Tail, self.GroupByF] =
new Subselect(selection, table, self.whereExpr && whereExpr2, groupByExprs, havingExpr, orderByExprs, offset, limit) {
override type GroupByF = self.GroupByF
}
Expand All @@ -243,12 +221,10 @@ trait SelectModule { self: ExprModule with TableModule with UtilsModule =>
override type GroupByF = self.GroupByF
}

def having[F2, Remainder](
def having[F2](
havingExpr2: Expr[F2, Source, Boolean]
)(implicit
i: Features.IsPartiallyAggregated.WithRemainder[F, Remainder],
ev: GroupByF <:< Remainder,
i2: HavingIsSound[F2, GroupByF]
ev: HavingIsSound[F, self.GroupByF, F2]
): Subselect.WithGroupByF[F, Repr, Source, Subsource, Head, Tail, self.GroupByF] =
new Subselect(selection, table, whereExpr, groupByExprs, self.havingExpr && havingExpr2, orderByExprs, offset, limit) {
override type GroupByF = self.GroupByF
Expand Down Expand Up @@ -285,7 +261,7 @@ trait SelectModule { self: ExprModule with TableModule with UtilsModule =>
override type GroupByF = self.GroupByF with F1 with F2 with F3 with F4 with F5 with F6
}

//TODO add arities up to 22 when needed
//TODO add arities up to 22 if needed
def groupBy[F1, F2, F3, F4, F5, F6, F7](expr1: Expr[F1, Source, Any], expr2: Expr[F2, Source, Any], expr3: Expr[F3, Source, Any], expr4: Expr[F4, Source, Any], expr5: Expr[F5, Source, Any], expr6: Expr[F6, Source, Any], expr7: Expr[F7, Source, Any])(implicit verify: GroupByLike[F, F1 with F2 with F3 with F4 with F5 with F6 with F7]): Subselect.WithGroupByF[F, Repr, Source, Subsource, Head, Tail, self.GroupByF with F1 with F2 with F3 with F4 with F5 with F6 with F7] =
new Subselect(selection, table, whereExpr, self.groupByExprs ++ expr1 ++ expr2 ++ expr3 ++ expr4 ++ expr5 ++ expr6 ++ expr7, havingExpr, orderByExprs, offset, limit) {
override type GroupByF = self.GroupByF with F1 with F2 with F3 with F4 with F5 with F6 with F7
Expand Down
47 changes: 33 additions & 14 deletions examples/src/main/scala/zio/sql/GroupByExamples.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,14 @@ object GroupByExamples extends App with PostgresJdbcModule {

val e = Sum(price) > 10

def testF[F, A, B](value: Expr[F, A, B])(implicit in: Features.IsFullyAggregated[F]) = ???

def test2[F, A, B](value: Expr[F, A, B])(implicit i: Features.IsPartiallyAggregated[F]): i.Unaggregated = ???

val orderValue = select(name, Sum(price))
.from(productTable)
.groupBy(name, price)
.having(Sum(price) > 10)

execute(orderValue)

//this
select(Sum(price))
.from(productTable)
.groupBy(name)
Expand All @@ -45,8 +42,8 @@ object GroupByExamples extends App with PostgresJdbcModule {
.from(productTable)
.groupBy(amount)
.having(amount > 10)
.where(amount > 10)

// this
select(Sum(price))
.from(productTable)
.groupBy(name)
Expand Down Expand Up @@ -75,16 +72,38 @@ object GroupByExamples extends App with PostgresJdbcModule {

//execute(select(name, Sum(price)).from(productTable))


// TODO better error message by having + remove isPartialAggregation
// select(price)
// .from(productTable)
// .having(Count(price) > 10)
select(price)
.from(productTable)
.groupBy(price)
.having(Count(price) > 10)

//TODO make the following not to compile
select(amount)
select(Sum(price))
.from(productTable)
.having(Sum(price) > 10)

select(price)
.from(productTable)
.groupBy(price, amount)
.having(amount > 200)

select(amount)
.from(productTable)
.groupBy(amount)
.having(amount > 10)
.where(amount > 10)
.having(Sum(price) > 200)


// select(price)
// .from(productTable)
// .groupBy(price)
// .having(amount > 200)

// select(amount)
// .from(productTable)
// .having(Sum(price) > 200)

// select(amount)
// .from(productTable)
// .groupBy(amount)
// .having(amount > 10)
// .where(amount > 10)
}
98 changes: 46 additions & 52 deletions macros/src/main/scala-2/zio/sql/macros/groupbylike.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,26 @@ package zio.sql.macros
import scala.reflect.macros.blackbox
import scala.language.experimental.macros

sealed trait GroupByLike[All, Grouped]
/**
* select Count(id)
from orders
group by customer_id
select customer_id
from orders
select Count(id)
from orders
select customer_id
from orders
group by customer_id
select customer_id, Count(id)
from orders
group by customer_id
*/
sealed trait GroupByLike[All, Grouped]

object GroupByLike {

Expand All @@ -16,82 +35,52 @@ object GroupByLike {
): c.Expr[GroupByLike[All, Grouped]] = {
import c.universe._

val allType = weakTypeOf[All]
val allType = weakTypeOf[All]
val groupedType = weakTypeOf[Grouped]

def splitIntersection(t: Type): List[Type] =
t.dealias match {
case t: RefinedType =>
t.parents.flatMap(s => splitIntersection(s))
case TypeRef(_, sym, _) if sym.info.isInstanceOf[RefinedTypeApi] =>
splitIntersection(sym.info)
case t: TypeRef => {
case t: TypeRef =>
t.args.headOption match {
case Some(value) => List(value.dealias)
case None => Nil
}
}
case _ => Nil
}

def isThereAggregation(t: Type): Boolean =
def isThereAggregation(t: Type): Boolean =
t.dealias match {
case TypeRef(_, typeSymbol, args) if typeSymbol == symbolOf[zio.sql.Features.Union[_, _]] =>
case TypeRef(_, typeSymbol, args) if typeSymbol == symbolOf[zio.sql.Features.Union[_, _]] =>
args.find(t => isThereAggregation(t)) match {
case None => false
case None => false
case Some(_) => true
}
case TypeRef(_, typeSymbol, _) if typeSymbol == symbolOf[zio.sql.Features.Aggregated[_]] =>
true
case _ => false
case TypeRef(_, typeSymbol, _) if typeSymbol == symbolOf[zio.sql.Features.Aggregated[_]] =>
true
case _ => false
}



def extractFromFeatures(f: Type): List[Type] =
f.dealias match {
case TypeRef(_, typeSymbol, args) if typeSymbol == symbolOf[zio.sql.Features.Source[_, _]] =>
List(args.head.dealias)
case TypeRef(_, typeSymbol, args) if typeSymbol == symbolOf[zio.sql.Features.Union[_, _]] =>
args.flatMap(f => extractFromFeatures(f))
case _ =>
case _ =>
Nil
}

// c.info(c.enclosingPosition,
// s"F -> ${allType} \nGrouped -> ${groupedType.dealias} \nNot aggregated F -> ${notAggregatedF} \nGroupedByF -> ${groupedByF}",
// true)
// EXAMPLE
// select(name, Sum(price))
// .from(productTable)
// .groupBy(name, price)

/**
* TO SUPPORT
*
select Count(id)
from orders
group by customer_id
select customer_id
from orders
select Count(id)
from orders
select customer_id
from orders
group by customer_id
select customer_id, Count(id)
from orders
group by customer_id
EXAMPLE
select(name, Sum(price))
.from(productTable)
.groupBy(name, price)
*/

// name & price
val groupedByF = splitIntersection(groupedType)
// name & price
val groupedByF = splitIntersection(groupedType)

// name
val notAggregatedF = extractFromFeatures(allType)
Expand All @@ -106,26 +95,31 @@ object GroupByLike {
val partialAggregation = aggregateFunctionExists && !notAggregatedF.isEmpty

// price
//val _ = groupedByF diff notAggregatedF
// val _ = groupedByF diff notAggregatedF

// nil
// Nil
val missing = notAggregatedF diff groupedByF

// group by not called
if (groupedByF.isEmpty) {
if (partialAggregation) {
c.abort(c.enclosingPosition, s"Column(s) ${missing.distinct.mkString(" and ")} must appear in the GROUP BY clause or be used in an aggregate function")
c.abort(
c.enclosingPosition,
s"Column(s) ${missing.distinct.mkString(" and ")} must appear in the GROUP BY clause or be used in an aggregate function"
)
} else {
result
}
// group by called
} else {
if (!missing.isEmpty) {
c.abort(c.enclosingPosition, s"Column(s) ${missing.distinct.mkString(" and ")} must appear in the GROUP BY clause or be used in an aggregate function")
c.abort(
c.enclosingPosition,
s"Column(s) ${missing.distinct.mkString(" and ")} must appear in the GROUP BY clause or be used in an aggregate function"
)
} else {
result
}
}
}

}
Loading

0 comments on commit 16091f9

Please sign in to comment.