Add dataframe PrintSchema and schema.TreeString

Adds a skeleton for dataframe.PrintSchema() and schema.TreeString() that can be extended with support for nested data types once they become available.

The feature was tested by writing a simple job.


Closes #80 from magpierre/Add-PrintSchema-method-to-dataframe.

Lead-authored-by: Magnus Pierre <magnus.pierre@icloud.com>
Co-authored-by: Martin Grund <martin.grund@databricks.com>
Signed-off-by: Martin Grund <martin.grund@databricks.com>
diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go
index 503b859..d383ca1 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -1204,3 +1204,23 @@
 		assert.Error(t, err)
 	}
 }
+
+// TestDataFrame_PrintSchema checks that printing a range schema returns no error.
+func TestDataFrame_PrintSchema(t *testing.T) {
+	ctx, session := connect()
+	rangeDF, err := session.Sql(ctx, "select * from range(10)")
+	assert.NoError(t, err)
+	assert.NoError(t, rangeDF.PrintSchema(ctx))
+}
+
+// TestDataFrame_SchemaTreeString checks that map and array columns appear in the tree output.
+func TestDataFrame_SchemaTreeString(t *testing.T) {
+	ctx, session := connect()
+	df, err := session.Sql(ctx, "select map('a', 1) as first, array(1,2,3) as second, map('a', map('b', 2)) as third")
+	assert.NoError(t, err)
+	schema, err := df.Schema(ctx)
+	assert.NoError(t, err)
+	for _, want := range []string{"|-- first: map", "|-- second: array", "|-- third: map"} {
+		assert.Contains(t, schema.TreeString(), want)
+	}
+}
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 02eb10d..4c7a563 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -156,6 +156,7 @@
 	Offset(ctx context.Context, offset int32) DataFrame
 	// OrderBy is an alias for Sort
 	OrderBy(ctx context.Context, columns ...column.Convertible) (DataFrame, error)
+	PrintSchema(ctx context.Context) error
 	Persist(ctx context.Context, storageLevel utils.StorageLevel) error
 	RandomSplit(ctx context.Context, weights []float64) ([]DataFrame, error)
 	// Repartition re-partitions a data frame.
@@ -1716,3 +1717,12 @@
 		}
 	}
 }
+
+// PrintSchema fetches the schema and prints its tree representation to standard output.
+func (df *dataFrameImpl) PrintSchema(ctx context.Context) error {
+	schema, err := df.Schema(ctx)
+	if err == nil {
+		fmt.Print(schema.TreeString())
+	}
+	return err
+}
diff --git a/spark/sql/types/structtype.go b/spark/sql/types/structtype.go
index d55d02c..17becd5 100644
--- a/spark/sql/types/structtype.go
+++ b/spark/sql/types/structtype.go
@@ -16,7 +16,12 @@
 
 package types
 
-import "github.com/apache/arrow-go/v18/arrow"
+import (
+	"fmt"
+	"strings"
+
+	"github.com/apache/arrow-go/v18/arrow"
+)
 
 // StructField represents a field in a StructType.
 type StructField struct {
@@ -34,6 +39,36 @@
 	}
 }
 
+// buildFormattedString appends this field's printSchema-style tree lines to
+// *target (no-op when target is nil). prefix is the indentation written before
+// each "-- name" entry; nested entries are indented one level deeper.
+func (t *StructField) buildFormattedString(prefix string, target *string) {
+	if target == nil {
+		return
+	}
+	header := func(typeName string) {
+		*target += fmt.Sprintf("%s-- %s: %s (nullable = %t)\n", prefix, t.Name, typeName, t.Nullable)
+	}
+	child := prefix + "    |"
+	switch dt := t.DataType.(type) {
+	case ArrayType: // element nullability is the array's ContainsNull, not the field's Nullable
+		header("array")
+		*target += fmt.Sprintf("%s-- element: %s (containsNull = %t)\n", child, strings.ToLower(dt.ElementType.TypeName()), dt.ContainsNull)
+	case MapType: // likewise, value nullability comes from ValueContainsNull
+		header("map")
+		*target += fmt.Sprintf("%s-- key: %s\n", child, strings.ToLower(dt.KeyType.TypeName()))
+		*target += fmt.Sprintf("%s-- value: %s (valueContainsNull = %t)\n",
+			child, strings.ToLower(dt.ValueType.TypeName()), dt.ValueContainsNull)
+	case StructType:
+		header("structtype")
+		for _, field := range dt.Fields {
+			field.buildFormattedString(child, target)
+		}
+	default:
+		header(strings.ToLower(t.DataType.TypeName()))
+	}
+}
+
 // StructType represents a struct type.
 type StructType struct {
 	Fields []StructField
@@ -55,6 +90,15 @@
 	return arrow.StructOf(fields...)
 }
 
+// TreeString renders the schema as a tree: "root", the lines of every field, then a blank line.
+func (t *StructType) TreeString() string {
+	out := "root\n"
+	for i := range t.Fields {
+		t.Fields[i].buildFormattedString(" |", &out)
+	}
+	return out + "\n"
+}
+
 func StructOf(fields ...StructField) *StructType {
 	return &StructType{Fields: fields}
 }
diff --git a/spark/sql/types/structtype_test.go b/spark/sql/types/structtype_test.go
index ccb3a98..482300c 100644
--- a/spark/sql/types/structtype_test.go
+++ b/spark/sql/types/structtype_test.go
@@ -17,6 +17,7 @@
 package types
 
 import (
+	"strings"
 	"testing"
 
 	"github.com/apache/arrow-go/v18/arrow"
@@ -125,3 +126,81 @@
 	assert.True(t, ok)
 	assert.Equal(t, 1, concreteType.NumFields())
 }
+
+// TestTreeString checks the rendered line for each primitive field, including
+// a field explicitly marked non-nullable.
+func TestTreeString(t *testing.T) {
+	c := NewStructField("col1", STRING)
+	c.Nullable = false
+	s := StructOf(c, NewStructField("col2", INTEGER), NewStructField("col3", DATE))
+	assert.Len(t, s.Fields, 3)
+	ts := s.TreeString()
+
+	// Each expectation includes the closing parenthesis so a truncated line fails.
+	assert.Contains(t, ts, "|-- col1: string (nullable = false)")
+	assert.Contains(t, ts, "|-- col2: integer (nullable = true)")
+	assert.Contains(t, ts, "|-- col3: date (nullable = true)")
+}
+
+// TestTreeString_ComplexNestedTypes renders a struct mixing primitive, array,
+// map, and nested-struct fields and checks the overall shape of the output.
+func TestTreeString_ComplexNestedTypes(t *testing.T) {
+	profile := StructOf(
+		NewStructField("nested_id", INTEGER),
+		NewStructField("nested_name", STRING),
+	)
+
+	// array<string>
+	tags := ArrayType{ElementType: STRING, ContainsNull: true}
+
+	// map<integer, string>
+	metadata := MapType{KeyType: INTEGER, ValueType: STRING, ValueContainsNull: true}
+
+	// array<map<string, double>>
+	scores := ArrayType{
+		ElementType:  MapType{KeyType: STRING, ValueType: DOUBLE, ValueContainsNull: false},
+		ContainsNull: true,
+	}
+
+	schema := StructOf(
+		NewStructField("id", INTEGER),
+		NewStructField("name", STRING),
+		NewStructField("tags", tags),
+		NewStructField("metadata", metadata),
+		NewStructField("scores", scores),
+		NewStructField("profile", *profile),
+		NewStructField("active", BOOLEAN),
+	)
+
+	rendered := schema.TreeString()
+
+	// The "root" header plus one top-level entry per field must be present.
+	wanted := []string{
+		"root",
+		"|-- id: integer (nullable = true)",
+		"|-- name: string (nullable = true)",
+		"|-- tags: array (nullable = true)",
+		"|-- metadata: map (nullable = true)",
+		"|-- scores: array (nullable = true)",
+		"|-- profile: structtype (nullable = true)",
+		"|-- active: boolean (nullable = true)",
+	}
+	for _, want := range wanted {
+		assert.Contains(t, rendered, want)
+	}
+
+	// The output opens with the "root" line and is newline-terminated.
+	assert.True(t, strings.HasPrefix(rendered, "root\n"))
+	assert.True(t, strings.HasSuffix(rendered, "\n"))
+
+	// Besides one line per top-level field, the output has the root line (1),
+	// the two array element lines (2), the map key/value lines (2), and the two
+	// nested struct fields (2): seven extra lines in total.
+	lines := strings.Split(strings.TrimSpace(rendered), "\n")
+	assert.Equal(t, len(schema.Fields)+7, len(lines))
+
+	// Every line after "root" carries the tree prefix.
+	for _, line := range lines[1:] {
+		assert.True(t, strings.HasPrefix(line, " |"))
+	}
+}