blob: 0ef785dd3dc7c7cea44c6f4a7fcad6b19dd23f5c [file] [log] [blame]
//
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sql
import (
"bytes"
"context"
"io"
"testing"
"github.com/apache/arrow/go/v17/arrow"
"github.com/apache/arrow/go/v17/arrow/array"
"github.com/apache/arrow/go/v17/arrow/ipc"
"github.com/apache/arrow/go/v17/arrow/memory"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
proto "github.com/apache/spark-connect-go/v35/internal/generated"
"github.com/apache/spark-connect-go/v35/spark/client"
"github.com/apache/spark-connect-go/v35/spark/client/testutils"
"github.com/apache/spark-connect-go/v35/spark/mocks"
"github.com/apache/spark-connect-go/v35/spark/sparkerrors"
)
// TestSparkSessionTable verifies that session.Table builds the same read
// relation as newReadTableRelation for the same table name.
func TestSparkSessionTable(t *testing.T) {
	// Reset the plan-id counter before building each relation so the
	// expected and actual plans are assigned identical ids.
	resetPlanIdForTesting()
	plan := newReadTableRelation("table")
	resetPlanIdForTesting()
	s := testutils.NewConnectServiceClientMock(nil, nil, nil, t)
	c := client.NewSparkExecutorFromClient(s, nil, "")
	session := &sparkSessionImpl{client: c}
	df, err := session.Table("table")
	// Fail fast before dereferencing df: a nil df would otherwise panic
	// and hide the underlying error.
	require.NoError(t, err)
	dfPlan := df.(*dataFrameImpl).relation
	assert.Equal(t, plan, dfPlan)
}
func TestSQLCallsExecutePlanWithSQLOnClient(t *testing.T) {
ctx := context.Background()
query := "select * from bla"
// Create the responses:
responses := []*mocks.MockResponse{
{
Resp: &proto.ExecutePlanResponse{
ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{
SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{},
},
},
Err: nil,
},
{
Resp: &proto.ExecutePlanResponse{
ResponseType: &proto.ExecutePlanResponse_ResultComplete_{
ResultComplete: &proto.ExecutePlanResponse_ResultComplete{},
},
},
Err: nil,
},
{
Err: io.EOF,
},
}
s := testutils.NewConnectServiceClientMock(&mocks.ProtoClient{
RecvResponse: responses,
}, nil, nil, t)
c := client.NewSparkExecutorFromClient(s, nil, "")
session := &sparkSessionImpl{
client: c,
}
resp, err := session.Sql(ctx, query)
assert.NoError(t, err)
assert.NotNil(t, resp)
}
// TestNewSessionBuilderCreatesASession ensures a well-formed "sc://" remote
// connection string produces a session without error.
func TestNewSessionBuilderCreatesASession(t *testing.T) {
	session, err := NewSessionBuilder().
		Remote("sc://connection").
		Build(context.Background())
	assert.NoError(t, err)
	assert.NotNil(t, session)
}
// TestNewSessionBuilderFailsIfConnectionStringIsInvalid ensures a malformed
// connection string is rejected with InvalidInputError and no session.
func TestNewSessionBuilderFailsIfConnectionStringIsInvalid(t *testing.T) {
	session, err := NewSessionBuilder().
		Remote("invalid").
		Build(context.Background())
	assert.Error(t, err)
	assert.ErrorIs(t, err, sparkerrors.InvalidInputError)
	assert.Nil(t, session)
}
// TestWriteResultStreamsArrowResultToCollector builds a real Arrow IPC batch,
// feeds it through a mocked ExecutePlan stream, and verifies that Collect
// decodes the rows back out of the Arrow payload.
func TestWriteResultStreamsArrowResultToCollector(t *testing.T) {
	ctx := context.Background()

	// Build a one-column ("show_string") Arrow record with two rows and
	// serialize it to IPC bytes, which is the wire format the server would
	// send inside an ArrowBatch response.
	arrowFields := []arrow.Field{
		{
			Name: "show_string",
			Type: &arrow.StringType{},
		},
	}
	arrowSchema := arrow.NewSchema(arrowFields, nil)
	var buf bytes.Buffer
	arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema))
	defer arrowWriter.Close()
	alloc := memory.NewGoAllocator()
	recordBuilder := array.NewRecordBuilder(alloc, arrowSchema)
	defer recordBuilder.Release()
	recordBuilder.Field(0).(*array.StringBuilder).Append("str1a\nstr1b")
	recordBuilder.Field(0).(*array.StringBuilder).Append("str2")
	record := recordBuilder.NewRecord()
	defer record.Release()
	err := arrowWriter.Write(record)
	require.NoError(t, err)

	query := "select * from bla"
	// Create the responses:
	responses := []*mocks.MockResponse{
		// The first stream of responses satisfies the SQL command.
		{
			Resp: &proto.ExecutePlanResponse{
				ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{
					SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{},
				},
			},
			Err: nil,
		},
		{
			Resp: &proto.ExecutePlanResponse{
				ResponseType: &proto.ExecutePlanResponse_ResultComplete_{
					ResultComplete: &proto.ExecutePlanResponse_ResultComplete{},
				},
			},
			Err: nil,
		},
		{
			Err: io.EOF,
		},
		// The second stream of responses is for the actual execution.
		{
			Resp: &proto.ExecutePlanResponse{
				ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{
					ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{
						RowCount: 2,
						Data:     buf.Bytes(),
					},
				},
			},
		},
		{
			Err: io.EOF,
		},
	}
	s := testutils.NewConnectServiceClientMock(&mocks.ProtoClient{
		RecvResponse: responses,
	}, nil, nil, t)
	c := client.NewSparkExecutorFromClient(s, nil, "")
	session := &sparkSessionImpl{
		client: c,
	}

	// Use require for every step whose result is dereferenced next, so a
	// failure stops the test instead of cascading into a panic.
	resp, err := session.Sql(ctx, query)
	require.NoError(t, err)
	require.NotNil(t, resp)
	df, err := resp.Repartition(1, []string{"1"})
	require.NoError(t, err)
	rows, err := df.Collect(ctx)
	require.NoError(t, err)
	// Guard the index below: a short result should fail cleanly, not panic.
	require.Len(t, rows, 2)
	vals, err := rows[1].Values()
	require.NoError(t, err)
	assert.Equal(t, []any{"str2"}, vals)
}