blob: 2106e558065708773b841e72ffc79641bcf23356 [file] [log] [blame]
/*
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
*/
package gremlingo
import (
"fmt"
gremlingo "github.com/apache/tinkerpop/gremlin-go/v3/driver"
"github.com/cucumber/godog"
"os"
"reflect"
"strconv"
)
// CucumberWorld holds the state used while running the Gremlin feature
// scenarios with godog: the loaded test graphs (graphDataMap) plus, presumably,
// the traversal and results of the scenario in progress (those fields are
// populated outside this chunk — confirm against the step definitions).
type CucumberWorld struct {
// Scenario currently being executed; nil until set (set outside this chunk).
scenario *godog.Scenario
// Traversal source for the graph the current scenario targets; nil until set.
g *gremlingo.GraphTraversalSource
// Name of the graph in use — presumably a key of graphDataMap; verify with callers.
graphName string
// Traversal under test; nil until a step builds it (outside this chunk).
traversal *gremlingo.GraphTraversal
// Results gathered from traversal; populated outside this chunk.
result []interface{}
// Loaded test graphs keyed by graph name, filled by loadAllDataGraph and
// loadEmptyDataGraph; initialised empty by NewCucumberWorld.
graphDataMap map[string]*DataGraph
// Scenario parameters keyed by name; initialised empty by NewCucumberWorld.
parameters map[string]interface{}
}
// DataGraph bundles one remote test graph: its connection plus snapshots of
// its vertices and edges taken when it was loaded or reloaded.
type DataGraph struct {
// Graph name as listed in graphNames (set by loadAllDataGraph).
name string
// Remote connection bound to the graph's traversal source: "g"+name, or
// "ggraph" for the "empty" graph.
connection *gremlingo.DriverRemoteConnection
// Vertices keyed by their "name" property, as built by getVertices.
vertices map[string]*gremlingo.Vertex
// Edges keyed by "<out name>-<label>-><in name>", as built by getEdges/getEdgeKey.
edges map[string]*gremlingo.Edge
}
// getEnvOrDefaultString returns the value of the environment variable key,
// or defaultValue when that variable is unset or set to the empty string
// (os.Getenv reports both cases as "").
func getEnvOrDefaultString(key string, defaultValue string) string {
	if value := os.Getenv(key); value != "" {
		return value
	}
	return defaultValue
}
// getEnvOrDefaultInt returns the environment variable key parsed as a base-10
// int. defaultValue is returned when the variable is unset, empty, or does not
// parse with strconv.Atoi.
func getEnvOrDefaultInt(key string, defaultValue int) int {
	raw := getEnvOrDefaultString(key, "")
	if raw == "" {
		return defaultValue
	}
	parsed, err := strconv.Atoi(raw)
	if err != nil {
		return defaultValue
	}
	return parsed
}
// scenarioUrl returns the Gremlin Server endpoint the scenarios connect to:
// GREMLIN_SERVER_URL when it is set and non-empty, otherwise the local test
// server address.
func scenarioUrl() string {
	const (
		urlEnvVar  = "GREMLIN_SERVER_URL"
		defaultURL = "ws://localhost:45940/gremlin"
	)
	return getEnvOrDefaultString(urlEnvVar, defaultURL)
}
// NewCucumberWorld returns a CucumberWorld with no scenario, traversal source,
// traversal, graph name or results yet. Its graphDataMap and parameters maps
// are allocated empty so they can be written to immediately.
func NewCucumberWorld() *CucumberWorld {
	world := new(CucumberWorld)
	world.graphDataMap = make(map[string]*DataGraph)
	world.parameters = make(map[string]interface{})
	return world
}
// graphNames lists every test graph the world manages, in the order they are
// loaded, recreated and closed. Each graph other than "empty" is served by
// the traversal source "g"+name; "empty" is special-cased to "ggraph".
var graphNames = []string{
	"modern",
	"classic",
	"crew",
	"grateful",
	"sink",
	"empty",
}
// getDataGraphFromMap returns the DataGraph registered under name, or nil when
// no graph with that name has been loaded.
func (t *CucumberWorld) getDataGraphFromMap(name string) *DataGraph {
	// Indexing a map with a missing key yields the value type's zero value,
	// which for a pointer is nil, so no comma-ok check or else branch is needed.
	return t.graphDataMap[name]
}
// loadAllDataGraph registers every graph in graphNames in t.graphDataMap.
// The "empty" graph is delegated to loadEmptyDataGraph. Every other graph gets
// a connection bound to the traversal source "g"+name, plus vertex and edge
// snapshots taken with getVertices and getEdges.
// It panics if a connection for a non-empty graph cannot be created.
func (t *CucumberWorld) loadAllDataGraph() {
	for _, name := range graphNames {
		if name == "empty" {
			t.loadEmptyDataGraph()
			continue
		}

		source := "g" + name
		connection, err := gremlingo.NewDriverRemoteConnection(scenarioUrl(),
			func(settings *gremlingo.DriverRemoteConnectionSettings) {
				settings.TraversalSource = source
			})
		if err != nil {
			panic(fmt.Sprintf("Failed to create connection '%v'", err))
		}

		g := gremlingo.Traversal_().WithRemote(connection)
		vertices := getVertices(g)
		edges := getEdges(g)
		t.graphDataMap[name] = &DataGraph{
			name:       name,
			connection: connection,
			vertices:   vertices,
			edges:      edges,
		}
	}
}
// loadEmptyDataGraph opens a connection bound to the "ggraph" traversal source
// and registers it in t.graphDataMap under "empty". No vertex or edge snapshot
// is taken here; reloadEmptyData populates those.
// It panics if the connection cannot be created, matching loadAllDataGraph.
func (t *CucumberWorld) loadEmptyDataGraph() {
	connection, err := gremlingo.NewDriverRemoteConnection(scenarioUrl(),
		func(settings *gremlingo.DriverRemoteConnectionSettings) {
			settings.TraversalSource = "ggraph"
		})
	if err != nil {
		// This error used to be discarded. The graph was then stored with an
		// unusable connection, and the failure only surfaced later, far from
		// its cause.
		panic(fmt.Sprintf("Failed to create connection '%v'", err))
	}
	t.graphDataMap["empty"] = &DataGraph{name: "empty", connection: connection}
}
// reloadEmptyData refreshes the vertex and edge snapshots of the "empty" graph
// by querying it through its existing connection. The graph must already be
// registered by loadEmptyDataGraph; otherwise getDataGraphFromMap returns nil
// and the connection lookup dereferences a nil pointer.
func (t *CucumberWorld) reloadEmptyData() {
	empty := t.getDataGraphFromMap("empty")
	source := gremlingo.Traversal_().WithRemote(empty.connection)
	empty.vertices, empty.edges = getVertices(source), getEdges(source)
}
// cleanEmptyDataGraph runs g.V().drop() and blocks until the traversal reports
// completion, returning the error value received from Iterate's channel.
// The receiver is not used.
func (t *CucumberWorld) cleanEmptyDataGraph(g *gremlingo.GraphTraversalSource) error {
	return <-g.V().Drop().Iterate()
}
// getVertices snapshots the vertices reachable from g, keyed by their "name"
// property. The query groups by name and reduces each group with __.tail(),
// so a single vertex is kept per name.
//
// It returns nil when the query fails, yields no result, or yields something
// other than a map (a diagnostic is printed in that last case). It panics if a
// group key is not a string or a group value is not a *gremlingo.Vertex.
func getVertices(g *gremlingo.GraphTraversalSource) map[string]*gremlingo.Vertex {
	res, err := g.V().Group().By("name").By(gremlingo.T__.Tail()).Next()
	if err != nil || res == nil {
		return nil
	}

	grouped := reflect.ValueOf(res.GetInterface())
	if grouped.Kind() != reflect.Map {
		fmt.Printf("Expecting to get a map as a result, got %v instead.\n", grouped.Kind())
		return nil
	}

	vertexMap := make(map[string]*gremlingo.Vertex, grouped.Len())
	// MapRange yields each key/value pair directly. This avoids the earlier
	// no-op Convert of every key and the second MapIndex lookup per key.
	iter := grouped.MapRange()
	for iter.Next() {
		vertexMap[iter.Key().Interface().(string)] = iter.Value().Interface().(*gremlingo.Vertex)
	}
	return vertexMap
}
// getEdges snapshots the edges reachable from g. Each edge is keyed by
// getEdgeKey applied to the projection {"o": out-vertex name, "l": label,
// "i": in-vertex name}. Each group is reduced with __.tail(), so one edge is
// kept per key.
//
// It returns nil when the query fails, yields no result, or yields something
// other than a map (a diagnostic is printed in that last case). It panics if a
// group key is not a pointer to a map[interface{}]interface{} or a group value
// is not a *gremlingo.Edge.
func getEdges(g *gremlingo.GraphTraversalSource) map[string]*gremlingo.Edge {
	endpoints := gremlingo.T__.Project("o", "l", "i").
		By(gremlingo.T__.OutV().Values("name")).
		By(gremlingo.T__.Label()).
		By(gremlingo.T__.InV().Values("name"))
	resE, err := g.E().Group().By(endpoints).By(gremlingo.T__.Tail()).Next()
	// This nil-result guard matches getVertices. Without it, a nil result
	// would reach GetInterface (presumably a nil dereference).
	if err != nil || resE == nil {
		return nil
	}

	grouped := reflect.ValueOf(resE.GetInterface())
	if grouped.Kind() != reflect.Map {
		fmt.Printf("Expecting to get a map as a result, got %v instead.\n", grouped.Kind())
		return nil
	}

	edgeMap := make(map[string]*gremlingo.Edge, grouped.Len())
	iter := grouped.MapRange()
	for iter.Next() {
		// The projected group keys arrive as pointers to maps, presumably
		// because Go maps cannot themselves be map keys. Dereference the
		// pointer before reading the projection.
		keyMap := reflect.ValueOf(iter.Key().Interface()).Elem().Interface().(map[interface{}]interface{})
		edgeMap[getEdgeKey(keyMap)] = iter.Value().Interface().(*gremlingo.Edge)
	}
	return edgeMap
}
// getEdgeKey renders an edge projection as "<o>-<l>-><i>". For the maps built
// in getEdges this is "outVertexName-label->inVertexName". Each part is
// formatted with %v, so a missing entry appears as "<nil>".
func getEdgeKey(edgeKeyMap map[interface{}]interface{}) string {
	out, label, in := edgeKeyMap["o"], edgeKeyMap["l"], edgeKeyMap["i"]
	return fmt.Sprintf("%v-%v->%v", out, label, in)
}
// recreateAllDataGraphConnection replaces the connection of every graph in
// graphNames with a freshly opened one. The "empty" graph is bound to "ggraph"
// and every other graph to "g"+name.
//
// It is meant for the Before hook, so that each scenario gets its own
// connections. Otherwise a failing scenario that closes a shared connection
// would also break every scenario after it. It can be removed once all pending
// tests pass.
//
// Every graph is attempted even if an earlier one fails. The first failure is
// returned, wrapped with the graph's name, so a later success cannot mask it.
// Every name in graphNames must already be registered in graphDataMap.
func (t *CucumberWorld) recreateAllDataGraphConnection() error {
	var firstErr error
	for _, name := range graphNames {
		source := "g" + name
		if name == "empty" {
			source = "ggraph"
		}
		connection, err := gremlingo.NewDriverRemoteConnection(scenarioUrl(),
			func(settings *gremlingo.DriverRemoteConnectionSettings) {
				settings.TraversalSource = source
			})
		t.getDataGraphFromMap(name).connection = connection
		if err != nil && firstErr == nil {
			firstErr = fmt.Errorf("recreating connection for graph %q: %w", name, err)
		}
	}
	return firstErr
}
// closeAllDataGraphConnection closes the connection of every graph in
// graphNames. A graph that was never registered, or whose connection is nil
// (for example because opening it failed), is skipped. Without that check,
// the call would dereference a nil pointer and abort the remaining closes.
// It always returns nil.
func (t *CucumberWorld) closeAllDataGraphConnection() error {
	for _, name := range graphNames {
		graph := t.getDataGraphFromMap(name)
		if graph == nil || graph.connection == nil {
			continue
		}
		graph.connection.Close()
	}
	return nil
}
// strategyFactory builds the gremlin-go traversal strategy named strategyName,
// configured from params. It returns nil for any name it does not recognise.
//
// Recognised names and the params each one reads:
//   - "VertexProgramStrategy": "graphComputer" (string; missing or non-string
//     leaves GraphComputer empty).
//   - "ProductiveByStrategy": "productiveKeys" ([]interface{} of strings;
//     missing or of another type yields an empty key list).
//   - "ReadOnlyStrategy": none.
//   - "SubgraphStrategy": "edges", "vertices", "vertexProperties"
//     (*gremlingo.GraphTraversal; missing or mistyped values stay nil) and
//     "checkAdjacentVertices" (passed through unchanged, nil if missing).
//   - "SeedStrategy": "seed" (int, int32 or int64; required).
//
// It panics if a productive key is not a string, or if "seed" is missing or
// not one of the supported integer types.
func strategyFactory(strategyName string, params map[string]interface{}) interface{} {
	switch strategyName {
	case "VertexProgramStrategy":
		graphComputer, _ := params["graphComputer"].(string)
		config := gremlingo.VertexProgramStrategyConfig{
			GraphComputer: graphComputer,
			Workers:       0,
			Persist:       "",
			Result:        "",
			Vertices:      nil,
			Edges:         nil,
			Configuration: nil,
		}
		return gremlingo.VertexProgramStrategy(config)
	case "ProductiveByStrategy":
		// The comma-ok assertion makes a missing entry produce an empty
		// (non-nil) key list rather than a type-assertion panic.
		rawKeys, _ := params["productiveKeys"].([]interface{})
		productiveKeys := make([]string, len(rawKeys))
		for i, key := range rawKeys {
			productiveKeys[i] = key.(string)
		}
		config := gremlingo.ProductiveByStrategyConfig{
			ProductiveKeys: productiveKeys,
		}
		return gremlingo.ProductiveByStrategy(config)
	case "ReadOnlyStrategy":
		return gremlingo.ReadOnlyStrategy()
	case "SubgraphStrategy":
		edges, _ := params["edges"].(*gremlingo.GraphTraversal)
		vertices, _ := params["vertices"].(*gremlingo.GraphTraversal)
		vertexProperties, _ := params["vertexProperties"].(*gremlingo.GraphTraversal)
		config := gremlingo.SubgraphStrategyConfig{
			Edges:                 edges,
			Vertices:              vertices,
			VertexProperties:      vertexProperties,
			CheckAdjacentVertices: params["checkAdjacentVertices"],
		}
		return gremlingo.SubgraphStrategy(config)
	case "SeedStrategy":
		// Accept the common integer widths rather than int only. Anything
		// else is a misconfigured scenario, so fail loudly with the actual type.
		var seed int64
		switch s := params["seed"].(type) {
		case int:
			seed = int64(s)
		case int32:
			seed = int64(s)
		case int64:
			seed = s
		default:
			panic(fmt.Sprintf("SeedStrategy requires an integer \"seed\" parameter, got %T", s))
		}
		config := gremlingo.SeedStrategyConfig{
			Seed: seed,
		}
		return gremlingo.SeedStrategy(config)
	}
	return nil
}