blob: cedd4e24d228ed54f72cf05d9e16cb4253d8f980 [file] [log] [blame]
import functools
import os
from unittest import TestCase
import casbin
from graphql import (
GraphQLArgument,
GraphQLField,
GraphQLInt,
GraphQLList,
GraphQLNonNull,
GraphQLObjectType,
GraphQLSchema,
GraphQLString,
graphql_sync,
)
from authz.middleware import enforcer_middleware
def get_examples(path):
    """Return the absolute path of *path* inside the ``examples`` directory.

    The directory is resolved relative to this module's real location as
    ``<module dir>/../examples/``; *path* is appended verbatim and the
    result is normalised with ``os.path.abspath``.
    """
    module_dir = os.path.dirname(os.path.realpath(__file__))
    examples_dir = module_dir + "/../examples/"
    return os.path.abspath(examples_dir + path)
def sub_test(param_list):
    """Decorates a test case to run it as a set of subtests.

    Each entry of *param_list* is a dict of keyword arguments. For every
    entry, the decorated test runs inside ``self.subTest(**entry)``. The same
    entry is passed to the test as keyword arguments, so a failure in one
    entry is reported separately from the others.
    """

    def decorator(test_method):
        def run_all(self):
            for kwargs in param_list:
                with self.subTest(**kwargs):
                    test_method(self, **kwargs)

        # Copy __name__/__doc__/__wrapped__ so unittest discovery and
        # reporting still see the original test method.
        return functools.wraps(test_method)(run_all)

    return decorator
class TestConfig(TestCase):
    """End-to-end tests of ``enforcer_middleware`` on a small in-memory schema.

    The schema is ``Query.project(id)`` -> ``ProjectType.members`` ->
    ``MemberType.tickets``. The casbin model and policy files are loaded from
    the examples directory via ``get_examples``.
    """

    def given_a_graphql_schema(self) -> GraphQLSchema:
        """Build the schema shared by the tests.

        Every project has two members and every member has four tickets. The
        resolvers derive names and messages from the parent object's ``id``,
        so each value in a response can be traced back to its path.
        """
        ticket_type = GraphQLObjectType(
            name="TicketType",
            fields={
                "id": GraphQLField(GraphQLInt),
                "message": GraphQLField(GraphQLString),
            },
        )
        member_type = GraphQLObjectType(
            name="MemberType",
            fields={
                "id": GraphQLField(GraphQLInt),
                "name": GraphQLField(GraphQLString),
                "tickets": GraphQLField(
                    GraphQLList(ticket_type),
                    resolve=lambda member, _info: [
                        {"id": 1, "message": f"Member {member['id']}, Ticket: 1"},
                        {"id": 2, "message": f"Member {member['id']}, Ticket: 2"},
                        {"id": 3, "message": f"Member {member['id']}, Ticket: 3"},
                        {"id": 4, "message": f"Member {member['id']}, Ticket: 4"},
                    ],
                ),
            },
        )
        project_type = GraphQLObjectType(
            name="ProjectType",
            fields={
                # Non-null on purpose: when the middleware denies it, the whole
                # ``project`` becomes null (see the non-nullable-fields test).
                "id": GraphQLField(GraphQLNonNull(GraphQLInt)),
                "name": GraphQLField(GraphQLString),
                "members": GraphQLField(
                    GraphQLList(member_type),
                    resolve=lambda project, _info: [
                        {"id": 1, "name": f"Project {project['id']}, Member: 1"},
                        {"id": 2, "name": f"Project {project['id']}, Member: 2"},
                    ],
                ),
            },
        )
        query_type = GraphQLObjectType(
            name="Query",
            fields={
                "project": GraphQLField(
                    project_type,
                    args={"id": GraphQLArgument(GraphQLInt)},
                    # Field arguments reach the resolver as keyword arguments,
                    # so the parameter must be named ``id`` (shadowing the builtin).
                    resolve=lambda _source, _info, id: {
                        "id": id,
                        "name": f"Project {id}",
                    },
                ),
                "projects": GraphQLField(
                    GraphQLList(project_type),
                    # ``projects`` declares no arguments, so its resolver must
                    # accept only (source, info). A required ``id`` parameter
                    # would make every ``projects`` query fail.
                    resolve=lambda _source, _info: [
                        {"id": 1, "name": "Project 1"},
                        {"id": 2, "name": "Project 2"},
                    ],
                ),
            },
        )
        return GraphQLSchema(query_type)

    def given_an_enforcer(self) -> casbin.Enforcer:
        """Return an enforcer built from the example model and default policy."""
        return casbin.Enforcer(get_examples("model.conf"), get_examples("policy.csv"))

    def test_graphql_middleware(self):
        """A ``user`` is denied ``project.name`` and every ticket ``message``.

        Each denied field resolves to null and produces its own error. All
        other fields resolve normally.
        """
        schema = self.given_a_graphql_schema()
        enforcer = self.given_an_enforcer()
        # The expected error locations are 1-based line/column positions inside
        # this literal, so its indentation is significant.
        query = """{
            project(id: 2) {
                id name members {
                    id name tickets {
                        id message
                    }
                }
            }
        }"""
        casbin_middleware = enforcer_middleware(enforcer)
        response = graphql_sync(
            schema,
            query,
            middleware=[casbin_middleware],
            context_value={"role": "user"},
        )

        # One error for project.name, then one per ticket message in execution
        # order (all of member 0's tickets before member 1's).
        ticket_message_error = "user can not query project.members.tickets.message"
        expected_errors = [
            {
                "message": "user can not query project.name",
                "path": ["project", "name"],
                "locations": [{"line": 3, "column": 20}],
            }
        ]
        for member_index in range(2):
            for ticket_index in range(4):
                expected_errors.append(
                    {
                        "message": ticket_message_error,
                        "path": [
                            "project",
                            "members",
                            member_index,
                            "tickets",
                            ticket_index,
                            "message",
                        ],
                        "locations": [{"line": 5, "column": 28}],
                    }
                )
        self.assertIsNotNone(response.errors)
        # Comparing the whole list also catches unexpected extra errors, which
        # asserting individual indices would silently ignore.
        self.assertEqual(
            [error.formatted for error in response.errors], expected_errors
        )
        self.assertEqual(
            response.data,
            {
                "project": {
                    "id": 2,
                    "name": None,
                    "members": [
                        {
                            "id": 1,
                            "name": "Project 2, Member: 1",
                            "tickets": [
                                {"id": 1, "message": None},
                                {"id": 2, "message": None},
                                {"id": 3, "message": None},
                                {"id": 4, "message": None},
                            ],
                        },
                        {
                            "id": 2,
                            "name": "Project 2, Member: 2",
                            "tickets": [
                                {"id": 1, "message": None},
                                {"id": 2, "message": None},
                                {"id": 3, "message": None},
                                {"id": 4, "message": None},
                            ],
                        },
                    ],
                }
            },
        )

    @sub_test([dict(context={"role": "*"}), dict(context={})])
    def test_graphql_middleware_as_anonymous(self, context):
        """A ``"*"`` role and a missing role are both reported as ``anonymous``.

        Such requests are denied ``project.name`` but may read ``project.id``.
        """
        schema = self.given_a_graphql_schema()
        enforcer = self.given_an_enforcer()
        # Error locations are positions inside this literal; keep its indentation.
        query = """{
            project(id: 2) {
                id name
            }
        }"""
        casbin_middleware = enforcer_middleware(enforcer)
        response = graphql_sync(
            schema, query, middleware=[casbin_middleware], context_value=context
        )
        self.assertIsNotNone(response.errors)
        self.assertEqual(
            [error.formatted for error in response.errors],
            [
                {
                    "message": "anonymous can not query project.name",
                    "path": ["project", "name"],
                    "locations": [{"line": 3, "column": 20}],
                }
            ],
        )
        self.assertEqual(response.data, {"project": {"id": 2, "name": None}})

    def test_graphql_middleware_unauthorized_querying_non_nullable_fields(self):
        """Denying the non-null ``project.id`` nulls out the whole ``project``."""
        schema = self.given_a_graphql_schema()
        enforcer = casbin.Enforcer(
            get_examples("model.conf"),
            get_examples("policy_with_project_id_restricted.csv"),
        )
        # Error locations are positions inside this literal; keep its indentation.
        query = """{
            project(id: 2) {
                id name
            }
        }"""
        casbin_middleware = enforcer_middleware(enforcer)
        response = graphql_sync(
            schema,
            query,
            middleware=[casbin_middleware],
            # NOTE(review): "unathorized_user" is misspelled. It presumably only
            # needs to be a role the restricted policy does not grant
            # project.id. Confirm against policy_with_project_id_restricted.csv
            # before renaming it (the expected message below must change too).
            context_value={"role": "unathorized_user"},
        )
        self.assertIsNotNone(response.errors)
        self.assertEqual(
            response.errors[0].formatted,
            {
                "message": "unathorized_user can not query project.id",
                "path": ["project", "id"],
                "locations": [{"line": 3, "column": 17}],
            },
        )
        self.assertEqual(response.data, {"project": None})