# blob: b3c613b9a905b5c5df6d3dc4fcf59f91286a05b3 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
import os
import netCDF4
import numpy as np
from ocw.dataset import Dataset
from ocw.dataset_loader import DatasetLoader
class TestDatasetLoader(unittest.TestCase):
    '''
    Tests for DatasetLoader using a temporary netCDF file on disk and a
    custom in-memory data source named 'foo'.
    '''

    def setUp(self):
        '''
        Create the temporary netCDF file, read its variables back into
        memory and build the two loader configurations used by the tests.
        '''
        # Read netCDF file
        self.file_path = create_netcdf_object()
        self.netCDF_file = netCDF4.Dataset(self.file_path, 'r')
        variables = self.netCDF_file.variables
        self.latitudes = variables['latitude'][:]
        self.longitudes = variables['longitude'][:]
        self.times = variables['time'][:]
        self.alt_lats = variables['alt_lat'][:]
        self.alt_lons = variables['alt_lon'][:]
        self.values = variables['value'][:]
        # Second set of values so the two data sources can be told apart
        self.values2 = self.values + 1

        # Set up config
        # No 'loader_name' here; the tests expect origin source 'local'
        self.config = {'file_path': self.file_path, 'variable_name': 'value'}
        # Config for the custom 'foo' loader registered inside the tests
        self.new_data_source_config = {'loader_name': 'foo',
                                       'lats': self.latitudes,
                                       'lons': self.longitudes,
                                       'times': self.times,
                                       'values': self.values2,
                                       'variable': 'value'}

    def tearDown(self):
        '''
        Release the netCDF file handle and delete the temporary file.
        '''
        # Close the handle opened in setUp before deleting the file;
        # otherwise one handle is leaked per test and the removal can fail
        # on platforms that refuse to delete files that are still open.
        self.netCDF_file.close()
        os.remove(self.file_path)

    def testNewDataSource(self):
        '''
        Ensures that custom data source loaders can be added
        '''
        self.loader = DatasetLoader(self.new_data_source_config)

        # Here the data_source "foo" represents the Dataset constructor
        self.loader.add_source_loader('foo', build_dataset)
        self.loader.load_datasets()
        self.assertEqual(self.loader.datasets[0].origin['source'], 'foo')
        np.testing.assert_array_equal(self.loader.datasets[0].values,
                                      self.values2)

    def testExistingDataSource(self):
        '''
        Ensures that a config without a loader name is loaded with origin
        source 'local' and yields the values stored in the netCDF file
        '''
        self.loader = DatasetLoader(self.config)
        self.loader.load_datasets()
        self.assertEqual(self.loader.datasets[0].origin['source'], 'local')
        np.testing.assert_array_equal(self.loader.datasets[0].values,
                                      self.values)

    def testMultipleDataSources(self):
        '''
        Test for when multiple dataset configs are specified
        '''
        self.loader = DatasetLoader(self.config, self.new_data_source_config)

        # Here the data_source "foo" represents the Dataset constructor
        self.loader.add_source_loader('foo', build_dataset)
        self.loader.load_datasets()

        # Datasets are expected in the same order as the configs
        self.assertEqual(self.loader.datasets[0].origin['source'],
                         'local')
        self.assertEqual(self.loader.datasets[1].origin['source'],
                         'foo')
        np.testing.assert_array_equal(self.loader.datasets[0].values,
                                      self.values)
        np.testing.assert_array_equal(self.loader.datasets[1].values,
                                      self.values2)
def build_dataset(*args, **kwargs):
    '''
    Stand-in loader for the fictitious 'foo' data_source: forwards every
    argument to the Dataset constructor and tags the result with an origin
    of {'source': 'foo'}.
    '''
    return Dataset(*args, origin={'source': 'foo'}, **kwargs)
def create_netcdf_object(file_path='/tmp/temporaryNetcdf.nc'):
    '''
    Write a small netCDF4 fixture file and return its path.

    The file contains:

    * 'latitude' (5 values, 0..4) and 'longitude' (5 values, 150..154)
    * 'time' (values 0, 1, 2; units 'months since 2001-01-01 00:00:00')
    * 'alt_lat' (latitude + 10), 'alt_lon' (longitude - 10) and 'alt_time'
      (same values as 'time'; units 'months since 2001-04-01 00:00:00')
    * 'value' (the integers 0..74 arranged as (time, lat, lon); units
      'foo_units')

    :param file_path: Where to write the file. Defaults to the fixed
        location this helper has always used.
    :returns: file_path, unchanged.
    '''
    netCDF_file = netCDF4.Dataset(file_path, 'w', format='NETCDF4')
    # Close the file even if one of the writes below raises, so a failure
    # does not leave an open handle on a half-written file.
    try:
        # To create dimensions
        netCDF_file.createDimension('lat_dim', 5)
        netCDF_file.createDimension('lon_dim', 5)
        netCDF_file.createDimension('time_dim', 3)

        # To create variables
        latitudes = netCDF_file.createVariable('latitude', 'd', ('lat_dim',))
        longitudes = netCDF_file.createVariable('longitude', 'd',
                                                ('lon_dim',))
        times = netCDF_file.createVariable('time', 'd', ('time_dim',))
        # unusual variable names to test optional arguments for Dataset
        # constructor
        alt_lats = netCDF_file.createVariable('alt_lat', 'd', ('lat_dim',))
        alt_lons = netCDF_file.createVariable('alt_lon', 'd', ('lon_dim',))
        alt_times = netCDF_file.createVariable('alt_time', 'd',
                                               ('time_dim',))
        values = netCDF_file.createVariable('value', 'd',
                                            ('time_dim',
                                             'lat_dim',
                                             'lon_dim'))

        # Five latitudes and five longitudes
        latitudes_data = np.arange(5.)
        longitudes_data = np.arange(150., 155.)
        # Three months of data.
        times_data = np.arange(3)
        # One value per (time, lat, lon) cell: 3 * 5 * 5 = 75 values,
        # arranged as a 3D array (time, lats, lons)
        shape = (len(times_data), len(latitudes_data), len(longitudes_data))
        values_data = np.arange(shape[0] * shape[1] * shape[2]).reshape(shape)

        # Ingest values to netCDF file
        latitudes[:] = latitudes_data
        longitudes[:] = longitudes_data
        times[:] = times_data
        alt_lats[:] = latitudes_data + 10
        alt_lons[:] = longitudes_data - 10
        alt_times[:] = times_data
        values[:] = values_data

        # Assign time info to time variable
        times.units = 'months since 2001-01-01 00:00:00'
        alt_times.units = 'months since 2001-04-01 00:00:00'
        values.units = 'foo_units'
    finally:
        netCDF_file.close()
    return file_path
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()