blob: d5f95c8251e727397330bca22a2a26a07a1d289f [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package bundlev1
import (
"context"
"log/slog"
"testing"
"github.com/stretchr/testify/suite"
"github.com/apache/airflow/go-sdk/pkg/logging"
"github.com/apache/airflow/go-sdk/sdk"
)
type TaskSuite struct {
suite.Suite
}
func TestTaskSuite(t *testing.T) {
suite.Run(t, &TaskSuite{})
}
func (s *TaskSuite) TestReturnValidation() {
cases := map[string]struct {
fn any
errContains string
}{
"no-ret-values": {
func() {},
`func\d+ has 0 return values, must be`,
},
"too-many-ret-values": {
func() (a, b, c int) { return },
`func\d+ has 3 return values, must be`,
},
"invalid-ret": {
func() (c chan int) { return },
`func\d+ last return value to return error but found chan`,
},
}
for name, tt := range cases {
s.Run(name, func() {
_, err := NewTaskFunction(tt.fn)
if s.Assert().Error(err) {
s.Assert().Regexp(tt.errContains, err.Error())
}
})
}
}
func (s *TaskSuite) TestArgumentBinding() {
cases := map[string]struct {
fn any
}{
"no-args": {
func() error { return nil },
},
"context": {
func(ctx context.Context) error {
s.Equal("def", ctx.Value("abc"))
return nil
},
},
"context-and-logger": {
func(ctx context.Context, logger *slog.Logger) error {
s.Equal("def", ctx.Value("abc"))
s.NotNil(logger)
return nil
},
},
"client": {
func(client sdk.Client) error {
s.NotNil(client)
return nil
},
},
"var-client": {
func(client sdk.VariableClient) error {
s.NotNil(client)
return nil
},
},
"conn-client": {
func(client sdk.ConnectionClient) error {
s.NotNil(client)
return nil
},
},
}
for name, tt := range cases {
s.Run(name, func() {
task, err := NewTaskFunction(tt.fn)
s.Require().NoError(err)
ctx := context.WithValue(context.Background(), "abc", "def")
logger := slog.New(logging.NewTeeLogger())
task.Execute(ctx, logger)
})
}
}