TensorFlow C++ 使用 ops::BatchMatMul 实现批量矩阵乘法

2022-09-22 13:56:12

本例主要测试 TensorFlow C++ API 中的 ops::BatchMatMul 算子。
整体来说这个算子比较简单,难点在于官网没有给出使用示例,TensorFlow 自带的单元测试覆盖也不充分。
话不多说,上代码。
代码结构如下(原文此处为目录结构截图,图片已丢失;下文依次给出 conanfile.txt、CMakeLists.txt 和 tf_math2_test.cpp 三个文件):

conanfile.txt

 # Conan dependencies for this example (Conan skips lines starting with '#').
 [requires]
 # gtest and glog are used directly by tf_math2_test.cpp.
 gtest/1.10.0
 glog/0.4.0
 # NOTE(review): the packages below are not referenced by tf_math2_test.cpp itself;
 # presumably they serve TensorFlow and the helper library built from ../../include
 # (df/, arr_/, img_util/, ...) - confirm before trimming.
 protobuf/3.9.1
 eigen/3.4.0
 dataframe/1.20.0
 opencv/3.4.17
 boost/1.76.0
 abseil/20210324.0
 xtensor/0.23.10

 [generators]
 # Produces conanbuildinfo.cmake, which CMakeLists.txt includes via conan_basic_setup().
 cmake

CMakeLists.txt

# Build script for the TensorFlow C++ math-op tests in this directory.
#
# Every *.cpp file next to this CMakeLists.txt becomes its own test executable,
# linked against one shared helper library compiled from ../../include.
# Third-party packages come from Conan (conanbuildinfo.cmake),
# find_package(TensorflowCC) and pkg-config (Arrow/Parquet).

# 3.8 is the real minimum: CMAKE_CXX_STANDARD 17 requires CMake 3.8 and
# pkg_search_module(... IMPORTED_TARGET ...) requires 3.6. The previously
# declared 3.3 let older CMake start configuring and then fail obscurely.
cmake_minimum_required(VERSION 3.8)


project(test_math_ops)

# Also search /usr/local for .pc files (used by the Arrow/Parquet lookups below).
set(ENV{PKG_CONFIG_PATH} "$ENV{PKG_CONFIG_PATH}:/usr/local/lib/pkgconfig/")

set(CMAKE_CXX_STANDARD 17)
# Error out instead of silently falling back to an older standard.
set(CMAKE_CXX_STANDARD_REQUIRED ON)
# Always build with debug info. -g is a compiler option, not a preprocessor
# definition, so add_compile_options (not add_definitions) is the right command.
add_compile_options(-g)

# Conan "cmake" generator output; among other things it defines CONAN_LIBS.
include(${CMAKE_BINARY_DIR}/conanbuildinfo.cmake)
conan_basic_setup()

find_package(TensorflowCC REQUIRED)
find_package(PkgConfig REQUIRED)
# Each Arrow/Parquet component is exposed as a PkgConfig::<PREFIX> imported target.
pkg_search_module(PKG_PARQUET REQUIRED IMPORTED_TARGET parquet)
pkg_search_module(PKG_ARROW REQUIRED IMPORTED_TARGET arrow)
pkg_search_module(PKG_ARROW_COMPUTE REQUIRED IMPORTED_TARGET arrow-compute)
pkg_search_module(PKG_ARROW_CSV REQUIRED IMPORTED_TARGET arrow-csv)
pkg_search_module(PKG_ARROW_DATASET REQUIRED IMPORTED_TARGET arrow-dataset)
pkg_search_module(PKG_ARROW_FS REQUIRED IMPORTED_TARGET arrow-filesystem)
pkg_search_module(PKG_ARROW_JSON REQUIRED IMPORTED_TARGET arrow-json)

set(ARROW_INCLUDE_DIRS
    ${PKG_PARQUET_INCLUDE_DIRS}
    ${PKG_ARROW_INCLUDE_DIRS}
    ${PKG_ARROW_COMPUTE_INCLUDE_DIRS}
    ${PKG_ARROW_CSV_INCLUDE_DIRS}
    ${PKG_ARROW_DATASET_INCLUDE_DIRS}
    ${PKG_ARROW_FS_INCLUDE_DIRS}
    ${PKG_ARROW_JSON_INCLUDE_DIRS})

set(INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../include ${ARROW_INCLUDE_DIRS})

set(ARROW_LIBS
    PkgConfig::PKG_PARQUET
    PkgConfig::PKG_ARROW
    PkgConfig::PKG_ARROW_COMPUTE
    PkgConfig::PKG_ARROW_CSV
    PkgConfig::PKG_ARROW_DATASET
    PkgConfig::PKG_ARROW_FS
    PkgConfig::PKG_ARROW_JSON)

include_directories(${INCLUDE_DIRS})


# Test sources: every .cpp directly in this directory.
file(GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)

# Helper sources shared by all tests: selected TF test/runtime utilities plus
# the project's death_handler/df/arr_/img_util implementations.
set(include_root ${CMAKE_CURRENT_SOURCE_DIR}/../../include)
file(GLOB APP_SOURCES
    ${include_root}/tf_/impl/tensor_testutil.cc
    ${include_root}/tf_/impl/queue_runner.cc
    ${include_root}/tf_/impl/coordinator.cc
    ${include_root}/tf_/impl/status.cc
    ${include_root}/death_handler/impl/*.cpp
    ${include_root}/df/impl/*.cpp
    ${include_root}/arr_/impl/*.cpp
    ${include_root}/img_util/impl/*.cpp)

add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib PUBLIC ${CONAN_LIBS} TensorflowCC::TensorflowCC ${ARROW_LIBS})

# One executable per test source, e.g. tf_math2_test.cpp -> tf_math2_test.
foreach(test_file ${test_file_list})
    file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
    string(REPLACE ".cpp" "" test_target ${filename})
    add_executable(${test_target} ${test_file})
    target_link_libraries(${test_target} PUBLIC ${PROJECT_NAME}_lib)
endforeach()

tf_math2_test.cpp

#include <string>
#include <vector>
#include <glog/logging.h>
#include "death_handler/death_handler.h"
#include "tf_/tensor_testutil.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/client/client_session.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/cc/training/coordinator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/queue_runner.pb.h"
#include "tensorflow/core/public/session.h"


using namespace tensorflow;

/**
 * Test entry point: configures glog, installs the project's crash handler,
 * then runs every gtest case registered in this translation unit.
 */
int main(int argc, char** argv) {
    // Write log files into the current directory and mirror them to stderr.
    FLAGS_log_dir = "./";
    FLAGS_alsologtostderr = true;
    // Log levels INFO, WARNING, ERROR, FATAL are 0, 1, 2, 3; 0 keeps everything.
    FLAGS_minloglevel = 0;

    // Project helper kept alive for the whole run; presumably reports
    // fatal signals/backtraces during tests - see death_handler.h.
    Debug::DeathHandler dh;

    // InitGoogleLogging takes the program name (argv0), which glog uses to tag
    // log output and to name log files - it is not a log-file path, so pass
    // argv[0] rather than a hard-coded "./logs.log".
    google::InitGoogleLogging(argv[0]);
    ::testing::InitGoogleTest(&argc, argv);
    return RUN_ALL_TESTS();
}


TEST(TfArthimaticTests, BatchMatMul) {
    // BatchMatMul test: each batch slice is multiplied independently.
    // Refers to: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
    //
    // Shapes: [2, 1, 2] x [2, 2, 3] = [2, 1, 3]
    Scope root = Scope::NewRootScope();
    /**
     * @brief Left param, shape [2, 1, 2]
     * {{{1, 2}},
     *  {{3, 4}}}
     */
    auto left_ = test::AsTensor<int>({1, 2, 3, 4}, {2, 1, 2});
    /**
     * @brief Right param, shape [2, 2, 3]
     * {{{1, 2, 3}, {4, 1, 2}},
     *  {{3, 4, 5}, {6, 7, 8}}}
     */
    auto right_ = test::AsTensor<int>({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8}, {2, 2, 3});

    /**
     * @brief Expected result, shape [2, 1, 3]
     * batch 0: [1, 2] x [[1, 2, 3], [4, 1, 2]] = [9, 4, 7]
     * batch 1: [3, 4] x [[3, 4, 5], [6, 7, 8]] = [33, 40, 47]
     */
    auto batch_op = ops::BatchMatMul(root, left_, right_);

    ClientSession session(root);
    std::vector<Tensor> outputs;
    // Run() reports failure through its Status; if it is ignored, outputs may
    // be unpopulated and outputs[0] below would be undefined behavior.
    TF_ASSERT_OK(session.Run({batch_op.output}, &outputs));
    ASSERT_EQ(outputs.size(), 1u);

    test::PrintTensorValue<int>(std::cout, outputs[0]);
    test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({9, 4, 7, 33, 40, 47}, {2, 1, 3}));
}

TEST(TfArthimaticTests, BatchMatMulAdjXY) {
    // BatchMatMul test with adj_x = adj_y = true: the last two dimensions of
    // each operand are adjointed (for int, simply transposed) before multiplying.
    // Refers to: https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/batch-mat-mul
    //
    // Input shapes:      [2, 2, 1] and [2, 3, 2]
    // After adjoint:     [2, 1, 2] x [2, 2, 3]
    // Result shape:      [2, 1, 3]
    Scope root = Scope::NewRootScope();
    /**
     * @brief Left param, shape [2, 2, 1]
     * {{{1},
     *   {2}},
     *  {{3},
     *   {4}}}
     */
    auto left_ = test::AsTensor<int>({1, 2, 3, 4}, {2, 2, 1});
    /**
     * @brief Right param, shape [2, 3, 2]
     * {{{1, 2},
     *   {3, 4},
     *   {1, 2}},
     *  {{3, 4},
     *   {5, 6},
     *   {7, 8}}}
     */
    auto right_ = test::AsTensor<int>({1, 2, 3, 4, 1, 2, 3, 4, 5, 6, 7, 8}, {2, 3, 2});

    /**
     * @brief Expected result, shape [2, 1, 3]
     * batch 0: [1, 2] x [[1, 3, 1], [2, 4, 2]] = [5, 11, 5]
     * batch 1: [3, 4] x [[3, 5, 7], [4, 6, 8]] = [25, 39, 53]
     */
    auto attrs = ops::BatchMatMul::AdjX(true).AdjY(true);
    auto batch_op = ops::BatchMatMul(root, left_, right_, attrs);

    ClientSession session(root);
    std::vector<Tensor> outputs;
    // Run() reports failure through its Status; if it is ignored, outputs may
    // be unpopulated and outputs[0] below would be undefined behavior.
    TF_ASSERT_OK(session.Run({batch_op.output}, &outputs));
    ASSERT_EQ(outputs.size(), 1u);

    test::PrintTensorValue<int>(std::cout, outputs[0]);
    test::ExpectTensorEqual<int>(outputs[0], test::AsTensor<int>({5, 11, 5, 25, 39, 53}, {2, 1, 3}));
}

程序输出如下,表明两个测试用例(BatchMatMul 与 BatchMatMulAdjXY,测试的是同一个算子的不同属性组合)均已通过。
(原文此处为运行输出截图,图片已丢失。)

  • 作者:zhuge19870104
  • 原文链接:https://blog.csdn.net/zhuge19870104/article/details/124391762
    更新时间:2022-09-22 13:56:12