| |

| /* SWIG interface file: wraps singa::Loss, MSE and SoftmaxCrossEntropy as module model_loss. */ |

| |

| %module model_loss |

| %include "std_string.i" |

| %{ |

| #include "singa/model/loss.h" |

| using singa::Tensor; |

| %} |

| |

| namespace singa { |

| /// NOTE(review): Tensor is never declared to SWIG in this file; the |

| /// `using singa::Tensor;` inside %{ %} is only copied into the generated |

| /// wrapper source, not parsed by SWIG. Confirm Tensor is made known to |

| /// SWIG elsewhere (e.g. via %import), otherwise it is wrapped as an opaque type. |

| |

| /// Abstract base class for losses. These declarations are what SWIG wraps, |

| /// so they must stay in sync with singa/model/loss.h. Forward and Backward |

| /// are pure virtual, so SWIG treats Loss as abstract (no wrapped constructor). |

| class Loss { |

| public: |

| Loss() = default; |

| virtual ~Loss() {} |

| |

| /// Computes the loss values of `prediction` against `target`. |

| /// `flag` presumably selects the train/eval phase - TODO confirm in loss.h. |

| virtual Tensor Forward(int flag, const Tensor &prediction, |

| const Tensor &target) = 0; |

| |

| /// Non-virtual helper returning a single float for `prediction` vs `target`; |

| /// presumably an aggregate (e.g. mean) of Forward's output - NOTE(review): |

| /// confirm against loss.h. |

| float Evaluate(int flag, const Tensor &prediction, const Tensor &target); |

| |

| /// Compute the gradients of the loss values w.r.t. the prediction. |

| /// Takes no inputs, so it presumably relies on state cached by the |

| /// preceding Forward call - TODO confirm. |

| virtual Tensor Backward() = 0; |

| }; |

| |

| /// Concrete loss; by its name presumably mean squared error (confirm in |

| /// loss.h). Overrides both pure virtuals, so it is not abstract to SWIG. |

| class MSE : public Loss { |

| public: |

| Tensor Forward(int flag, const Tensor &prediction, const Tensor &target) |

| override; |

| |

| Tensor Backward() override; |

| }; |

| |

| /// Concrete loss; by its name presumably softmax followed by cross-entropy |

| /// (confirm in loss.h). Overrides both pure virtuals, so it is not abstract to SWIG. |

| class SoftmaxCrossEntropy : public Loss { |

| public: |

| Tensor Forward(int flag, const Tensor &prediction, const Tensor &target) |

| override; |

| |

| Tensor Backward() override; |

| }; |

| |

| } |