W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗值獎勵
TorchScript 是 PyTorch 的一種中間表示形式,允許開發(fā)者將模型及其執(zhí)行邏輯編譯為高效的序列化格式,便于部署和優(yōu)化。在許多場景下,開發(fā)者可能需要將自定義的 C++ 類集成到 TorchScript 中,以利用 C++ 的高性能特性或調(diào)用第三方庫。本文將詳細講解如何使用自定義 C++ 類擴展 TorchScript,并通過實例演示其在 Python 和 C++ 環(huán)境中的應用。
首先,我們需要定義一個繼承自 torch::jit::CustomClassHolder
的 C++ 類。這個基類確保了自定義類能夠與 PyTorch 的生命周期管理系統(tǒng)兼容。
#include <torch/script.h>
#include <torch/custom_class.h>
#include <string>
#include <vector>
template <class T>
struct Stack : torch::jit::CustomClassHolder {
std::vector<T> stack_;
Stack(std::vector<T> init) : stack_(init.begin(), init.end()) {}
void push(T x) {
stack_.push_back(x);
}
T pop() {
auto val = stack_.back();
stack_.pop_back();
return val;
}
c10::intrusive_ptr<Stack> clone() const {
return c10::make_intrusive<Stack>(stack_);
}
void merge(const c10::intrusive_ptr<Stack>& c) {
for (auto& elem : c->stack_) {
push(elem);
}
}
};
注意:c10::intrusive_ptr
是一個智能指針,用于管理對象的生命周期,類似于 std::shared_ptr
。
為了使自定義類在 TorchScript 和 Python 中可見,需要使用 torch::jit::class_
進行注冊。
static auto testStack = torch::jit::class_<Stack<std::string>>("Stack")
.def(torch::jit::init<std::vector<std::string>>())
.def("top", [](const c10::intrusive_ptr<Stack<std::string>>& self) {
return self->stack_.back();
})
.def("push", &Stack<std::string>::push)
.def("pop", &Stack<std::string>::pop)
.def("clone", &Stack<std::string>::clone)
.def("merge", &Stack<std::string>::merge);
將自定義類的實現(xiàn)和注冊代碼編譯為共享庫,以便在不同環(huán)境中加載和使用。
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(custom_class)
find_package(Torch REQUIRED)
add_library(custom_class SHARED class.cpp)
set(CMAKE_CXX_STANDARD 14)
target_link_libraries(custom_class "${TORCH_LIBRARIES}")
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
make
在 Python 中使用 torch.classes.load_library
加載共享庫:
import torch
torch.classes.load_library("libcustom_class.so")
s = torch.classes.Stack(["foo", "bar"])
s.push("pushed")
assert s.pop() == "pushed"
s2 = s.clone()
s.merge(s2)
for expected in ["bar", "foo", "bar", "foo"]:
assert s.pop() == expected
Stack = torch.classes.Stack
@torch.jit.script
def do_stacks(s: Stack) -> (Stack, str):
s2 = torch.classes.Stack(["hi", "mom"])
s2.merge(s)
return s2.clone(), s2.pop()
stack, top = do_stacks(torch.classes.Stack(["wow"]))
assert top == "wow"
for expected in ["wow", "mom", "hi"]:
assert stack.pop() == expected
import torch
torch.classes.load_library('libcustom_class.so')
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, s: str) -> str:
stack = torch.classes.Stack(["hi", "mom"])
return stack.pop() + s
scripted_foo = torch.jit.script(Foo())
scripted_foo.save('foo.pt')
#include <torch/script.h>
#include <iostream>
int main(int argc, const char* argv[]) {
torch::jit::script::Module module;
try {
module = torch::jit::load("foo.pt");
} catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
std::vector<c10::IValue> inputs = {"foobarbaz"};
auto output = module.forward(inputs).toString();
std::cout << output->string() << std::endl;
}
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(infer)
find_package(Torch REQUIRED)
add_subdirectory(custom_class_project)
add_executable(infer infer.cpp)
set(CMAKE_CXX_STANDARD 14)
target_link_libraries(infer "${TORCH_LIBRARIES}")
target_link_libraries(infer -Wl,--no-as-needed custom_class)
在項目目錄中運行以下命令:
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
make
./infer
通過本文,您已掌握如何在 PyTorch 中使用自定義 C++ 類擴展 TorchScript。這一技能在需要高性能計算或調(diào)用第三方 C++ 庫時尤為有用。未來,您可以進一步探索如何將自定義類與深度學習模型結(jié)合,以實現(xiàn)更高效的訓練和推理流程。編程獅將持續(xù)為您提供更多深度學習模型開發(fā)和優(yōu)化的優(yōu)質(zhì)教程,助力您的技術成長。
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報電話:173-0602-2364|舉報郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: