blob: c054cc2f1701c7bee50995a3620225cafcec5c37 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterator, List, Optional, Union
from unittest.mock import create_autospec, MagicMock
from pyspark.sql import DataFrame
class MockItem:
"""Mock STAC Item"""
def __init__(self, item_id: str = "test_item"):
self.id = item_id
self.properties = {
"datetime": "2006-12-26T18:03:22Z",
"collection": "aster-l1t",
}
self.geometry = {
"type": "Polygon",
"coordinates": [[[90, -73], [105, -73], [105, -69], [90, -69], [90, -73]]],
}
self.bbox = [90, -73, 105, -69]
self.assets = {}
class MockIterator:
"""Mock iterator for STAC items"""
def __init__(self, items: List[MockItem] = None):
self.items = items or [MockItem(f"item_{i}") for i in range(5)]
self.index = 0
def __iter__(self):
return self
def __next__(self):
if self.index < len(self.items):
item = self.items[self.index]
self.index += 1
return item
raise StopIteration
def create_mock_dataframe(data=None):
"""Create a mock DataFrame that passes isinstance checks"""
mock_df = create_autospec(DataFrame, instance=True)
mock_df.data = data or []
mock_df._count = len(mock_df.data) if data else 5
mock_df.count.return_value = mock_df._count
mock_df.collect.return_value = mock_df.data
mock_df.show.return_value = None
return mock_df
class MockCollectionClient:
"""Mock CollectionClient"""
def __init__(self, url: str, collection_id: str):
self.url = url
self.collection_id = collection_id
def __str__(self):
return f"<CollectionClient id={self.collection_id}>"
def get_items(self, *ids, **kwargs) -> Iterator:
"""Return mock iterator of items"""
if ids:
items = [MockItem(item_id) for item_id in ids if isinstance(item_id, str)]
else:
items = [MockItem(f"item_{i}") for i in range(5)]
return MockIterator(items)
def get_dataframe(self, **kwargs) -> DataFrame:
"""Return mock DataFrame"""
# Return a mock DataFrame instead of creating a real Spark DataFrame
return create_mock_dataframe()
def save_to_geoparquet(self, output_path: str, **kwargs):
"""Mock save to geoparquet - just create an empty file"""
import os
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w") as f:
f.write("")
class MockClient:
"""Mock STAC Client"""
def __init__(self, url: str):
self.url = url
@classmethod
def open(cls, url: str):
"""Create mock client instance"""
return cls(url)
def get_collection(self, collection_id: str):
"""Return mock collection client"""
return MockCollectionClient(self.url, collection_id)
def search(
self,
*ids,
collection_id: Optional[str] = None,
bbox: Optional[list] = None,
geometry: Optional[Union[str, object, List]] = None,
datetime: Optional[Union[str, object, list]] = None,
max_items: Optional[int] = None,
return_dataframe: bool = True,
) -> Union[Iterator, DataFrame]:
"""Mock search method"""
if return_dataframe:
# Return mock DataFrame instead of creating a real Spark DataFrame
return create_mock_dataframe()
else:
# Return mock iterator
if ids and len(ids) > 0:
if isinstance(ids[0], str):
items = [MockItem(ids[0])]
else:
items = [
MockItem(item_id)
for item_id in ids[0]
if isinstance(item_id, str)
]
else:
num_items = min(max_items, 5) if max_items else 5
items = [MockItem(f"item_{i}") for i in range(num_items)]
return MockIterator(items)
def create_mock_client(url: str) -> MockClient:
"""Factory function to create a mock client"""
return MockClient(url)
def mock_client_open(monkeypatch):
"""Pytest fixture to mock Client.open"""
def _mock_open(url: str):
return MockClient(url)
return _mock_open