blob: 6fa751547939cff98d260984dfb6c0d0782e19dd [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sql
import (
proto ""
// DataFrame is a wrapper for data frame, representing a distributed collection of data row.
type DataFrame interface {
// Show prints out data frame data.
Show(numRows int, truncate bool) error
// Schema returns the schema for the current data frame.
Schema() (*StructType, error)
// Collect returns the data rows of the current data frame.
Collect() ([]Row, error)
// Write returns a data frame writer, which could be used to save data frame to supported storage.
Write() DataFrameWriter
// CreateTempView creates or replaces a temporary view.
CreateTempView(viewName string, replace bool, global bool) error
// Repartition re-partitions a data frame.
Repartition(numPartitions int, columns []string) (DataFrame, error)
// RepartitionByRange re-partitions a data frame by range partition.
RepartitionByRange(numPartitions int, columns []RangePartitionColumn) (DataFrame, error)
// RangePartitionColumn names one column used by DataFrame.RepartitionByRange
// together with the direction in which that column is ordered.
type RangePartitionColumn struct {
	// Name is the column name; it is sent as an unparsed identifier.
	Name string
	// Descending selects descending order; the zero value means ascending.
	Descending bool
}
// dataFrameImpl is an implementation of DataFrame interface.
type dataFrameImpl struct {
sparkSession *sparkSessionImpl
relation *proto.Relation // TODO change to proto.Plan?
func (df *dataFrameImpl) Show(numRows int, truncate bool) error {
truncateValue := 0
if truncate {
truncateValue = 20
vertical := false
plan := &proto.Plan{
OpType: &proto.Plan_Root{
Root: &proto.Relation{
Common: &proto.RelationCommon{
PlanId: newPlanId(),
RelType: &proto.Relation_ShowString{
ShowString: &proto.ShowString{
Input: df.relation,
NumRows: int32(numRows),
Truncate: int32(truncateValue),
Vertical: vertical,
responseClient, err := df.sparkSession.executePlan(plan)
if err != nil {
return fmt.Errorf("failed to show dataframe: %w", err)
for {
response, err := responseClient.Recv()
if err != nil {
return fmt.Errorf("failed to receive show response: %w", err)
arrowBatch := response.GetArrowBatch()
if arrowBatch == nil {
err = showArrowBatch(arrowBatch)
if err != nil {
return err
return nil
return fmt.Errorf("did not get arrow batch in response")
func (df *dataFrameImpl) Schema() (*StructType, error) {
response, err := df.sparkSession.analyzePlan(df.createPlan())
if err != nil {
return nil, fmt.Errorf("failed to analyze plan: %w", err)
responseSchema := response.GetSchema().Schema
result := convertProtoDataTypeToStructType(responseSchema)
return result, nil
func (df *dataFrameImpl) Collect() ([]Row, error) {
responseClient, err := df.sparkSession.executePlan(df.createPlan())
if err != nil {
return nil, fmt.Errorf("failed to execute plan: %w", err)
var schema *StructType
var allRows []Row
for {
response, err := responseClient.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
return allRows, nil
} else {
return nil, fmt.Errorf("failed to receive plan execution response: %w", err)
dataType := response.GetSchema()
if dataType != nil {
schema = convertProtoDataTypeToStructType(dataType)
arrowBatch := response.GetArrowBatch()
if arrowBatch == nil {
rowBatch, err := readArrowBatchData(arrowBatch.Data, schema)
if err != nil {
return nil, err
if allRows == nil {
allRows = make([]Row, 0, len(rowBatch))
allRows = append(allRows, rowBatch...)
return allRows, nil
func (df *dataFrameImpl) Write() DataFrameWriter {
writer := dataFrameWriterImpl{
sparkSession: df.sparkSession,
relation: df.relation,
return &writer
func (df *dataFrameImpl) CreateTempView(viewName string, replace bool, global bool) error {
plan := &proto.Plan{
OpType: &proto.Plan_Command{
Command: &proto.Command{
CommandType: &proto.Command_CreateDataframeView{
CreateDataframeView: &proto.CreateDataFrameViewCommand{
Input: df.relation,
Name: viewName,
Replace: replace,
IsGlobal: global,
responseClient, err := df.sparkSession.executePlan(plan)
if err != nil {
return fmt.Errorf("failed to create temp view %s: %w", viewName, err)
return consumeExecutePlanClient(responseClient)
func (df *dataFrameImpl) Repartition(numPartitions int, columns []string) (DataFrame, error) {
var partitionExpressions []*proto.Expression
if columns != nil {
partitionExpressions = make([]*proto.Expression, 0, len(columns))
for _, c := range columns {
expr := &proto.Expression{
ExprType: &proto.Expression_UnresolvedAttribute_{
UnresolvedAttribute: &proto.Expression_UnresolvedAttribute{
UnparsedIdentifier: c,
partitionExpressions = append(partitionExpressions, expr)
return df.repartitionByExpressions(numPartitions, partitionExpressions)
func (df *dataFrameImpl) RepartitionByRange(numPartitions int, columns []RangePartitionColumn) (DataFrame, error) {
var partitionExpressions []*proto.Expression
if columns != nil {
partitionExpressions = make([]*proto.Expression, 0, len(columns))
for _, c := range columns {
columnExpr := &proto.Expression{
ExprType: &proto.Expression_UnresolvedAttribute_{
UnresolvedAttribute: &proto.Expression_UnresolvedAttribute{
UnparsedIdentifier: c.Name,
direction := proto.Expression_SortOrder_SORT_DIRECTION_ASCENDING
if c.Descending {
direction = proto.Expression_SortOrder_SORT_DIRECTION_DESCENDING
sortExpr := &proto.Expression{
ExprType: &proto.Expression_SortOrder_{
SortOrder: &proto.Expression_SortOrder{
Child: columnExpr,
Direction: direction,
partitionExpressions = append(partitionExpressions, sortExpr)
return df.repartitionByExpressions(numPartitions, partitionExpressions)
func (df *dataFrameImpl) createPlan() *proto.Plan {
return &proto.Plan{
OpType: &proto.Plan_Root{
Root: &proto.Relation{
Common: &proto.RelationCommon{
PlanId: newPlanId(),
RelType: df.relation.RelType,
func (df *dataFrameImpl) repartitionByExpressions(numPartitions int, partitionExpressions []*proto.Expression) (DataFrame, error) {
var numPartitionsPointerValue *int32
if numPartitions != 0 {
int32Value := int32(numPartitions)
numPartitionsPointerValue = &int32Value
newRelation := &proto.Relation{
Common: &proto.RelationCommon{
PlanId: newPlanId(),
RelType: &proto.Relation_RepartitionByExpression{
RepartitionByExpression: &proto.RepartitionByExpression{
Input: df.relation,
NumPartitions: numPartitionsPointerValue,
PartitionExprs: partitionExpressions,
return &dataFrameImpl{
sparkSession: df.sparkSession,
relation: newRelation,
}, nil
func showArrowBatch(arrowBatch *proto.ExecutePlanResponse_ArrowBatch) error {
return showArrowBatchData(arrowBatch.Data)
func showArrowBatchData(data []byte) error {
rows, err := readArrowBatchData(data, nil)
if err != nil {
return err
for _, row := range rows {
values, err := row.Values()
if err != nil {
return fmt.Errorf("failed to get values in the row: %w", err)
return nil
func readArrowBatchData(data []byte, schema *StructType) ([]Row, error) {
reader := bytes.NewReader(data)
arrowReader, err := ipc.NewReader(reader)
if err != nil {
return nil, fmt.Errorf("failed to create arrow reader: %w", err)
defer arrowReader.Release()
var rows []Row
for {
record, err := arrowReader.Read()
if err != nil {
if errors.Is(err, io.EOF) {
return rows, nil
} else {
return nil, fmt.Errorf("failed to read arrow: %w", err)
values, err := readArrowRecord(record)
if err != nil {
return nil, err
numRows := int(record.NumRows())
if rows == nil {
rows = make([]Row, 0, numRows)
for _, v := range values {
row := &GenericRowWithSchema{
schema: schema,
values: v,
rows = append(rows, row)
hasNext := arrowReader.Next()
if !hasNext {
return rows, nil
// readArrowRecordColumn reads all values from arrow record and return [][]any
func readArrowRecord(record arrow.Record) ([][]any, error) {
numRows := record.NumRows()
numColumns := int(record.NumCols())
values := make([][]any, numRows)
for i := range values {
values[i] = make([]any, numColumns)
for columnIndex := 0; columnIndex < numColumns; columnIndex++ {
err := readArrowRecordColumn(record, columnIndex, values)
if err != nil {
return nil, err
return values, nil
// readArrowRecordColumn reads all values in a column and stores them in values
func readArrowRecordColumn(record arrow.Record, columnIndex int, values [][]any) error {
numRows := int(record.NumRows())
columnData := record.Column(columnIndex).Data()
dataTypeId := columnData.DataType().ID()
switch dataTypeId {
case arrow.BOOL:
vector := array.NewBooleanData(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
case arrow.INT8:
vector := array.NewInt8Data(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
case arrow.INT16:
vector := array.NewInt16Data(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
case arrow.INT32:
vector := array.NewInt32Data(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
case arrow.INT64:
vector := array.NewInt64Data(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
case arrow.FLOAT16:
vector := array.NewFloat16Data(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
case arrow.FLOAT32:
vector := array.NewFloat32Data(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
case arrow.FLOAT64:
vector := array.NewFloat64Data(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
case arrow.DECIMAL | arrow.DECIMAL128:
vector := array.NewDecimal128Data(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
case arrow.DECIMAL256:
vector := array.NewDecimal256Data(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
case arrow.STRING:
vector := array.NewStringData(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
case arrow.BINARY:
vector := array.NewBinaryData(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
case arrow.TIMESTAMP:
vector := array.NewTimestampData(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
case arrow.DATE64:
vector := array.NewDate64Data(columnData)
for rowIndex := 0; rowIndex < numRows; rowIndex++ {
values[rowIndex][columnIndex] = vector.Value(rowIndex)
return fmt.Errorf("unsupported arrow data type %s in column %d", dataTypeId.String(), columnIndex)
return nil
func convertProtoDataTypeToStructType(input *proto.DataType) *StructType {
dataTypeStruct := input.GetStruct()
if dataTypeStruct == nil {
panic("dataType.GetStruct() is nil")
return &StructType{
Fields: convertProtoStructFields(dataTypeStruct.Fields),
func convertProtoStructFields(input []*proto.DataType_StructField) []StructField {
result := make([]StructField, len(input))
for i, f := range input {
result[i] = convertProtoStructField(f)
return result
func convertProtoStructField(field *proto.DataType_StructField) StructField {
return StructField{
Name: field.Name,
DataType: convertProtoDataTypeToDataType(field.DataType),
// convertProtoDataTypeToDataType converts protobuf data type to Spark connect sql data type
func convertProtoDataTypeToDataType(input *proto.DataType) DataType {
switch v := input.GetKind().(type) {
case *proto.DataType_Boolean_:
return BooleanType{}
case *proto.DataType_Byte_:
return ByteType{}
case *proto.DataType_Short_:
return ShortType{}
case *proto.DataType_Integer_:
return IntegerType{}
case *proto.DataType_Long_:
return LongType{}
case *proto.DataType_Float_:
return FloatType{}
case *proto.DataType_Double_:
return DoubleType{}
case *proto.DataType_Decimal_:
return DecimalType{}
case *proto.DataType_String_:
return StringType{}
case *proto.DataType_Binary_:
return BinaryType{}
case *proto.DataType_Timestamp_:
return TimestampType{}
case *proto.DataType_TimestampNtz:
return TimestampNtzType{}
case *proto.DataType_Date_:
return DateType{}
return UnsupportedType{
TypeInfo: v,