Skip to content

Commit

Permalink
Date type with user location (#923)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkaflik committed Mar 3, 2023
1 parent 31d148f commit 08b2788
Show file tree
Hide file tree
Showing 15 changed files with 217 additions and 31 deletions.
11 changes: 9 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ func (c *connect) sendData(block *proto.Block, name string) error {
return nil
}

func (c *connect) readData(packet byte, compressible bool) (*proto.Block, error) {
func (c *connect) readData(ctx context.Context, packet byte, compressible bool) (*proto.Block, error) {
if _, err := c.reader.Str(); err != nil {
c.debugf("[read data] str error: %v", err)
return nil, err
Expand All @@ -260,7 +260,14 @@ func (c *connect) readData(packet byte, compressible bool) (*proto.Block, error)
c.reader.EnableCompression()
defer c.reader.DisableCompression()
}
block := proto.Block{Timezone: c.server.Timezone}

opts := queryOptions(ctx)
location := c.server.Timezone
if opts.userLocation != nil {
location = opts.userLocation
}

block := proto.Block{Timezone: location}
if err := block.Decode(c.reader, c.revision); err != nil {
c.debugf("[read data] decode error: %v", err)
return nil, err
Expand Down
10 changes: 8 additions & 2 deletions conn_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,14 @@ func (h *httpConnect) writeData(block *proto.Block) error {
return nil
}

func (h *httpConnect) readData(reader *chproto.Reader) (*proto.Block, error) {
block := proto.Block{Timezone: h.location}
func (h *httpConnect) readData(ctx context.Context, reader *chproto.Reader) (*proto.Block, error) {
opts := queryOptions(ctx)
location := h.location
if opts.userLocation != nil {
location = opts.userLocation
}

block := proto.Block{Timezone: location}
if h.compression == CompressionLZ4 || h.compression == CompressionZSTD {
reader.EnableCompression()
defer reader.DisableCompression()
Expand Down
4 changes: 2 additions & 2 deletions conn_http_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,14 @@ func (h *httpConnect) query(ctx context.Context, release func(*connect, error),
}
h.compressionPool.Put(rw)
reader := chproto.NewReader(bytes.NewReader(body))
block, err := h.readData(reader)
block, err := h.readData(ctx, reader)
if err != nil {
return nil, err
}

go func() {
for {
block, err := h.readData(reader)
block, err := h.readData(ctx, reader)
if err != nil {
// ch-go wraps EOF errors
if !errors.Is(err, io.EOF) {
Expand Down
5 changes: 3 additions & 2 deletions conn_logs.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package clickhouse

import (
"context"
"time"

"github.com/ClickHouse/clickhouse-go/v2/lib/proto"
Expand All @@ -34,8 +35,8 @@ type Log struct {
Text string
}

func (c *connect) logs() ([]Log, error) {
block, err := c.readData(proto.ServerLog, false)
func (c *connect) logs(ctx context.Context) ([]Log, error) {
block, err := c.readData(ctx, proto.ServerLog, false)
if err != nil {
return nil, err
}
Expand Down
14 changes: 7 additions & 7 deletions conn_process.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ func (c *connect) firstBlock(ctx context.Context, on *onProcess) (*proto.Block,
}
switch packet {
case proto.ServerData:
return c.readData(packet, true)
return c.readData(ctx, packet, true)
case proto.ServerEndOfStream:
c.debugf("[end of stream]")
return nil, io.EOF
default:
if err := c.handle(packet, on); err != nil {
if err := c.handle(ctx, packet, on); err != nil {
return nil, err
}
}
Expand All @@ -75,16 +75,16 @@ func (c *connect) process(ctx context.Context, on *onProcess) error {
c.debugf("[end of stream]")
return nil
}
if err := c.handle(packet, on); err != nil {
if err := c.handle(ctx, packet, on); err != nil {
return err
}
}
}

func (c *connect) handle(packet byte, on *onProcess) error {
func (c *connect) handle(ctx context.Context, packet byte, on *onProcess) error {
switch packet {
case proto.ServerData, proto.ServerTotals, proto.ServerExtremes:
block, err := c.readData(packet, true)
block, err := c.readData(ctx, packet, true)
if err != nil {
return err
}
Expand All @@ -107,13 +107,13 @@ func (c *connect) handle(packet byte, on *onProcess) error {
}
c.debugf("[table columns]")
case proto.ServerProfileEvents:
events, err := c.profileEvents()
events, err := c.profileEvents(ctx)
if err != nil {
return err
}
on.profileEvents(events)
case proto.ServerLog:
logs, err := c.logs()
logs, err := c.logs(ctx)
if err != nil {
return err
}
Expand Down
5 changes: 3 additions & 2 deletions conn_profile_events.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package clickhouse

import (
"context"
"reflect"
"time"

Expand All @@ -33,8 +34,8 @@ type ProfileEvent struct {
Value int64
}

func (c *connect) profileEvents() ([]ProfileEvent, error) {
block, err := c.readData(proto.ServerProfileEvents, false)
func (c *connect) profileEvents(ctx context.Context) ([]ProfileEvent, error) {
block, err := c.readData(ctx, proto.ServerProfileEvents, false)
if err != nil {
return nil, err
}
Expand Down
8 changes: 8 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ type (
parameters Parameters
external []*ext.Table
blockBufferSize uint8
userLocation *time.Location
}
)

Expand Down Expand Up @@ -140,6 +141,13 @@ func WithStdAsync(wait bool) QueryOption {
}
}

func WithUserLocation(location *time.Location) QueryOption {
return func(o *QueryOptions) error {
o.userLocation = location
return nil
}
}

func Context(parent context.Context, options ...QueryOption) context.Context {
opt := queryOptions(parent)
for _, f := range options {
Expand Down
4 changes: 2 additions & 2 deletions lib/column/codegen/column.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ func (t Type) Column(name string, tz *time.Location) (Interface, error) {
case "Bool", "Boolean":
return &Bool{name: name}, nil
case "Date":
return &Date{name: name}, nil
return &Date{name: name, location: tz}, nil
case "Date32":
return &Date32{name: name}, nil
return &Date32{name: name, location: tz}, nil
case "UUID":
return &UUID{name: name}, nil
case "Nothing":
Expand Down
4 changes: 2 additions & 2 deletions lib/column/column_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

29 changes: 23 additions & 6 deletions lib/column/date.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@ const (
)

type Date struct {
col proto.ColDate
name string
col proto.ColDate
name string
location *time.Location
}

func (col *Date) parse(t Type, tz *time.Location) (_ *Date, err error) {
col.location = tz
return col, nil
}

func (col *Date) Reset() {
Expand Down Expand Up @@ -222,7 +228,11 @@ func (col *Date) AppendRow(v interface{}) error {
return nil
}

func parseDate(value string, minDate time.Time, maxDate time.Time) (tv time.Time, err error) {
func parseDate(value string, minDate time.Time, maxDate time.Time, location *time.Location) (tv time.Time, err error) {
if location == nil {
location = time.Local
}

defer func() {
if err == nil {
err = dateOverflow(minDate, maxDate, tv, defaultDateFormatNoZone)
Expand All @@ -233,14 +243,14 @@ func parseDate(value string, minDate time.Time, maxDate time.Time) (tv time.Time
}
if tv, err = time.Parse(defaultDateFormatNoZone, value); err == nil {
return time.Date(
tv.Year(), tv.Month(), tv.Day(), tv.Hour(), tv.Minute(), tv.Second(), tv.Nanosecond(), time.Local,
tv.Year(), tv.Month(), tv.Day(), tv.Hour(), tv.Minute(), tv.Second(), tv.Nanosecond(), location,
), nil
}
return time.Time{}, err
}

func (col *Date) parseDate(value string) (tv time.Time, err error) {
return parseDate(value, minDate, maxDate)
return parseDate(value, minDate, maxDate, col.location)
}

func (col *Date) Decode(reader *proto.Reader, rows int) error {
Expand All @@ -252,7 +262,14 @@ func (col *Date) Encode(buffer *proto.Buffer) {
}

func (col *Date) row(i int) time.Time {
return col.col.Row(i)
t := col.col.Row(i)

if col.location != nil {
// proto.Date is normalized as time.Time with UTC timezone.
// We make sure Date return from ClickHouse matches server timezone or user defined location.
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), col.location)
}
return t
}

var _ Interface = (*Date)(nil)
16 changes: 12 additions & 4 deletions lib/column/date32.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,9 @@ var (
)

type Date32 struct {
col proto.ColDate32
name string
col proto.ColDate32
name string
location *time.Location
}

func (col *Date32) Reset() {
Expand Down Expand Up @@ -218,7 +219,7 @@ func (col *Date32) AppendRow(v interface{}) error {
}

func (col *Date32) parseDate(value string) (datetime time.Time, err error) {
return parseDate(value, minDate32, maxDate32)
return parseDate(value, minDate32, maxDate32, col.location)
}

func (col *Date32) Decode(reader *proto.Reader, rows int) error {
Expand All @@ -230,7 +231,14 @@ func (col *Date32) Encode(buffer *proto.Buffer) {
}

func (col *Date32) row(i int) time.Time {
return col.col.Row(i)
t := col.col.Row(i)

if col.location != nil {
// proto.Date is normalized as time.Time with UTC timezone.
// We make sure Date return from ClickHouse matches server timezone or user defined location.
t = time.Date(t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), col.location)
}
return t
}

var _ Interface = (*Date32)(nil)
31 changes: 31 additions & 0 deletions tests/date32_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,34 @@ func TestCustomDateTime32(t *testing.T) {
require.NoError(t, row.Scan(&col1))
require.Equal(t, now, time.Time(col1))
}

func TestDate32WithUserLocation(t *testing.T) {
t.Skip("Date32 decode is broken in this scenario. row.Scan returns '1977-07-01 00:00:00 +0000' instead of '2022-07-01 00:00:00 +0000'. Needs further investigation.")

ctx := context.Background()

conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{
Method: clickhouse.CompressionLZ4,
})
require.NoError(t, err)

require.NoError(t, conn.Exec(ctx, "DROP TABLE IF EXISTS date_with_user_location"))
require.NoError(t, conn.Exec(ctx, `
CREATE TABLE date_with_user_location (
Col1 Date32
) Engine MergeTree() ORDER BY tuple()
`))
require.NoError(t, conn.Exec(ctx, "INSERT INTO date_with_user_location SELECT toDate32(toStartOfMonth(toDate('2022-07-12')))"))

userLocation, _ := time.LoadLocation("Pacific/Pago_Pago")
queryCtx := clickhouse.Context(ctx, clickhouse.WithUserLocation(userLocation))

var col1 time.Time
row := conn.QueryRow(queryCtx, "SELECT * FROM date_with_user_location")
require.NoError(t, row.Err())
require.NoError(t, row.Scan(&col1))

const dateTimeNoZoneFormat = "2006-01-02T15:04:05"
assert.Equal(t, "2022-07-01T00:00:00", col1.Format(dateTimeNoZoneFormat))
assert.Equal(t, userLocation.String(), col1.Location().String())
}
29 changes: 29 additions & 0 deletions tests/date_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,32 @@ func TestCustomDate(t *testing.T) {
require.NoError(t, row.Scan(&col1))
require.Equal(t, now, time.Time(col1))
}

func TestDateWithUserLocation(t *testing.T) {
ctx := context.Background()

conn, err := GetNativeConnection(nil, nil, &clickhouse.Compression{
Method: clickhouse.CompressionLZ4,
})
require.NoError(t, err)

require.NoError(t, conn.Exec(ctx, "DROP TABLE IF EXISTS date_with_user_location"))
require.NoError(t, conn.Exec(ctx, `
CREATE TABLE date_with_user_location (
Col1 Date
) Engine MergeTree() ORDER BY tuple()
`))
require.NoError(t, conn.Exec(ctx, "INSERT INTO date_with_user_location SELECT toStartOfMonth(toDate('2022-07-12'))"))

userLocation, _ := time.LoadLocation("Pacific/Pago_Pago")
queryCtx := clickhouse.Context(ctx, clickhouse.WithUserLocation(userLocation))

var col1 time.Time
row := conn.QueryRow(queryCtx, "SELECT * FROM date_with_user_location")
require.NoError(t, row.Err())
require.NoError(t, row.Scan(&col1))

const dateTimeNoZoneFormat = "2006-01-02T15:04:05"
assert.Equal(t, "2022-07-01T00:00:00", col1.Format(dateTimeNoZoneFormat))
assert.Equal(t, userLocation.String(), col1.Location().String())
}
Loading

0 comments on commit 08b2788

Please sign in to comment.