# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
| |
| import sys |
| import os |
| import mxnet as mx |
| import numpy as np |
| import unittest |
| import ctypes |
| from common import with_seed |
| import pytest |
| |
def test_float64_fallback():
    """Run a float64 FullyConnected forward pass on CPU and verify its result.

    Builds a two-unit FullyConnected symbol, binds it with float64 input,
    weight and bias arrays, executes the forward pass and checks that the
    output keeps the float64 dtype and matches ``in . w^T + b``.

    NOTE(review): the name suggests this guards the code path taken when an
    accelerated CPU backend does not handle float64 -- confirm against the
    build configuration this test is run under.
    """
    sym = mx.sym.FullyConnected(
        mx.sym.Variable('in'),
        mx.sym.Variable('w'),
        mx.sym.Variable('b'),
        num_hidden=2)

    dtype = 'float64'
    args = {
        'in': mx.nd.array([[2, 3, 4]], dtype=dtype),
        'w': mx.nd.array([[1, 2, 3], [4, 5, 6]], dtype=dtype),
        'b': mx.nd.array([7, 8], dtype=dtype),
    }
    ex = sym._bind(mx.cpu(), args, args_grad=None, grad_req='write')
    ex.forward()
    # Block until the computation has finished so that any failure raised
    # during execution surfaces here rather than being silently dropped.
    ex.outputs[0].wait_to_read()

    # Completing without an error is not enough: also make sure the float64
    # path yields the right numbers and does not silently change the dtype.
    #   [2, 3, 4] . [1, 2, 3] + 7 = 27
    #   [2, 3, 4] . [4, 5, 6] + 8 = 55
    result = ex.outputs[0].asnumpy()
    assert result.dtype == np.float64
    np.testing.assert_allclose(result, np.array([[27.0, 55.0]]))