国产gaysexchina男同gay,japanrcep老熟妇乱子伦视频,吃奶呻吟打开双腿做受动态图,成人色网站,国产av一区二区三区最新精品

PyTorch 使用自定義 C ++類擴展 TorchScript

2025-06-23 10:24 更新

TorchScript 是 PyTorch 的一種中間表示形式,允許開發(fā)者將模型及其執(zhí)行邏輯編譯為高效的序列化格式,便于部署和優(yōu)化。在許多場景下,開發(fā)者可能需要將自定義的 C++ 類集成到 TorchScript 中,以利用 C++ 的高性能特性或調(diào)用第三方庫。本文將詳細講解如何使用自定義 C++ 類擴展 TorchScript,并通過實例演示其在 Python 和 C++ 環(huán)境中的應用。

一、定義自定義 C++ 類

(一)類的基本結(jié)構(gòu)

首先,我們需要定義一個繼承自 torch::jit::CustomClassHolder 的 C++ 類。這個基類確保了自定義類能夠與 PyTorch 的生命周期管理系統(tǒng)兼容。

#include <torch/script.h>
#include <torch/custom_class.h>
#include <string>
#include <vector>


template <class T>
struct Stack : torch::jit::CustomClassHolder {
    std::vector<T> stack_;


    Stack(std::vector<T> init) : stack_(init.begin(), init.end()) {}


    void push(T x) {
        stack_.push_back(x);
    }


    T pop() {
        auto val = stack_.back();
        stack_.pop_back();
        return val;
    }


    c10::intrusive_ptr<Stack> clone() const {
        return c10::make_intrusive<Stack>(stack_);
    }


    void merge(const c10::intrusive_ptr<Stack>& c) {
        for (auto& elem : c->stack_) {
            push(elem);
        }
    }
};

注意:c10::intrusive_ptr 是一個智能指針,用于管理對象的生命周期,類似于 std::shared_ptr。

(二)注冊自定義類

為了使自定義類在 TorchScript 和 Python 中可見,需要使用 torch::jit::class_ 進行注冊。

static auto testStack = torch::jit::class_<Stack<std::string>>("Stack")
    .def(torch::jit::init<std::vector<std::string>>())
    .def("top", [](const c10::intrusive_ptr<Stack<std::string>>& self) {
        return self->stack_.back();
    })
    .def("push", &Stack<std::string>::push)
    .def("pop", &Stack<std::string>::pop)
    .def("clone", &Stack<std::string>::clone)
    .def("merge", &Stack<std::string>::merge);

二、構(gòu)建共享庫

將自定義類的實現(xiàn)和注冊代碼編譯為共享庫,以便在不同環(huán)境中加載和使用。

(一)創(chuàng)建 CMakeLists.txt

cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(custom_class)


find_package(Torch REQUIRED)


add_library(custom_class SHARED class.cpp)
set(CMAKE_CXX_STANDARD 14)


target_link_libraries(custom_class "${TORCH_LIBRARIES}")

(二)編譯共享庫

mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
make

三、在 Python 和 TorchScript 中使用自定義類

(一)加載共享庫

在 Python 中使用 torch.classes.load_library 加載共享庫:

import torch
torch.classes.load_library("libcustom_class.so")

(二)實例化和使用自定義類

s = torch.classes.Stack(["foo", "bar"])
s.push("pushed")
assert s.pop() == "pushed"


s2 = s.clone()
s.merge(s2)


for expected in ["bar", "foo", "bar", "foo"]:
    assert s.pop() == expected

(三)在 TorchScript 中使用自定義類

Stack = torch.classes.Stack


@torch.jit.script
def do_stacks(s: Stack) -> (Stack, str):
    s2 = torch.classes.Stack(["hi", "mom"])
    s2.merge(s)
    return s2.clone(), s2.pop()


stack, top = do_stacks(torch.classes.Stack(["wow"]))
assert top == "wow"
for expected in ["wow", "mom", "hi"]:
    assert stack.pop() == expected

四、在 C++ 中加載和運行 TorchScript 模型

(一)定義模型并保存

import torch
torch.classes.load_library('libcustom_class.so')


class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()


    def forward(self, s: str) -> str:
        stack = torch.classes.Stack(["hi", "mom"])
        return stack.pop() + s


scripted_foo = torch.jit.script(Foo())
scripted_foo.save('foo.pt')

(二)加載模型并在 C++ 中運行

#include <torch/script.h>
#include <iostream>


int main(int argc, const char* argv[]) {
    torch::jit::script::Module module;
    try {
        module = torch::jit::load("foo.pt");
    } catch (const c10::Error& e) {
        std::cerr << "error loading the model\n";
        return -1;
    }


    std::vector<c10::IValue> inputs = {"foobarbaz"};
    auto output = module.forward(inputs).toString();
    std::cout << output->string() << std::endl;
}

(三)構(gòu)建和運行 C++ 項目

cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(infer)


find_package(Torch REQUIRED)
add_subdirectory(custom_class_project)
add_executable(infer infer.cpp)
set(CMAKE_CXX_STANDARD 14)
target_link_libraries(infer "${TORCH_LIBRARIES}")
target_link_libraries(infer -Wl,--no-as-needed custom_class)

在項目目錄中運行以下命令:

mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
make
./infer

五、總結(jié)與拓展

通過本文,您已掌握如何在 PyTorch 中使用自定義 C++ 類擴展 TorchScript。這一技能在需要高性能計算或調(diào)用第三方 C++ 庫時尤為有用。未來,您可以進一步探索如何將自定義類與深度學習模型結(jié)合,以實現(xiàn)更高效的訓練和推理流程。編程獅將持續(xù)為您提供更多深度學習模型開發(fā)和優(yōu)化的優(yōu)質(zhì)教程,助力您的技術成長。

以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號