
# Advanced Torch C++ Tutorial (multiple objectives) - Code Snippets

This is a collection of notes on the C++ frontend of PyTorch. I'm a proponent of reading code to understand concepts in computer science, so this 'tutorial' is probably not for everyone. Still, I recommend just going ahead and reading the goddamn code. In this post I'll highlight a few specific topics to get you started. The documentation is, in my opinion, lacking, so anyone developing a project with the C++ frontend should clone PyTorch from GitHub and read the source as needed. Let's get started.

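Modules declared with the `TORCH_MODULE` macro must be named `XXImpl`: `TORCH_MODULE(XX)` appends `Impl` to the name you pass and defines a holder class `XX` that wraps a `std::shared_ptr<XXImpl>`. For instance, your module could look like this in your header: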

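```cpp
#include <torch/torch.h>

struct A2CNetImpl : public torch::nn::Cloneable<A2CNetImpl> {
  A2CNetImpl() {}
  ~A2CNetImpl() {}
  // Cloneable<T> declares a pure virtual reset(), so A2CNetImpl must override it.
  ...
};

// Defines A2CNet as a module holder around std::shared_ptr<A2CNetImpl>.
TORCH_MODULE(A2CNet);
```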
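To see why the suffix matters, here is a simplified, torch-free sketch of the holder pattern. `Holder`, `MY_MODULE` and `ToyNetImpl` are hypothetical stand-ins, not the real `torch::nn::ModuleHolder` or `TORCH_MODULE` definitions: the macro token-pastes `Impl` onto the name it receives, so the implementation class has to carry exactly that name, and copies of the holder share one implementation object, which is also how libtorch module holders behave. Unlike the real holder, this sketch only supports default construction.

```cpp
#include <cassert>
#include <memory>

// Hypothetical, simplified holder: owns a shared_ptr to the implementation
// and forwards -> to it, so copies of the holder share state.
template <typename Impl>
class Holder {
 public:
  Holder() : impl_(std::make_shared<Impl>()) {}
  Impl* operator->() const { return impl_.get(); }
  const std::shared_ptr<Impl>& ptr() const { return impl_; }

 private:
  std::shared_ptr<Impl> impl_;
};

// Hypothetical macro: MY_MODULE(Name) requires a class called Name##Impl
// to exist already and defines Name as its holder.
#define MY_MODULE(Name) \
  class Name : public Holder<Name##Impl> {}

// The implementation class must be named ToyNet + Impl.
struct ToyNetImpl {
  int calls = 0;
  int forward(int x) {
    ++calls;
    return 2 * x;
  }
};

MY_MODULE(ToyNet);
```

Renaming `ToyNetImpl` to anything else makes `MY_MODULE(ToyNet)` fail to compile, which is the same error you get from `TORCH_MODULE` when the `Impl` suffix is missing.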

You can train on multiple objectives at once. For instance, if you want to learn to predict both the policy and the value of a given state, simply return a `std::pair` from the `forward` method of your custom module:

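```cpp
namespace F = torch::nn::functional;

// Returns (action probabilities, state value) for a batch of inputs.
std::pair<torch::Tensor, torch::Tensor>
A2CNetImpl::forward(torch::Tensor input) {
  // Flatten everything except the batch dimension.
  auto x = input.view({input.size(0), -1});
  x = seq->forward(x);
  // Policy head: softmax over the last dimension gives action probabilities.
  auto policy = F::softmax(action_head(x), F::SoftmaxFuncOptions(-1));
  // Value head: one scalar estimate per sample.
  auto value = value_head(x);
  return std::make_pair(policy, value);
}
```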
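Returning a pair is plain C++, so the idea can be tried without libtorch. In the torch-free sketch below, `two_head_forward` is a hypothetical stand-in for `A2CNetImpl::forward`: it turns policy logits into probabilities with a numerically stable softmax and passes a value estimate through, and the caller unpacks the pair with `std::tie`, as the training loop in this post does.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <tuple>
#include <utility>
#include <vector>

// Numerically stable softmax: subtract the max logit before exponentiating.
std::vector<double> softmax(const std::vector<double>& logits) {
  const double max_logit = *std::max_element(logits.begin(), logits.end());
  std::vector<double> probs(logits.size());
  double sum = 0.0;
  for (std::size_t i = 0; i < logits.size(); ++i) {
    probs[i] = std::exp(logits[i] - max_logit);
    sum += probs[i];
  }
  for (double& p : probs) p /= sum;
  return probs;
}

// Hypothetical two-headed "forward": policy logits become a probability
// distribution, the value estimate is returned unchanged.
std::pair<std::vector<double>, double> two_head_forward(
    const std::vector<double>& policy_logits, double value) {
  return std::make_pair(softmax(policy_logits), value);
}
```

With C++17, `auto [probs, value] = two_head_forward(...);` is an alternative to declaring both variables up front and using `std::tie`.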

The network architecture can also be built dynamically from a configuration file. Here I pass a `std::vector<int> net_architecture` (e.g. `{64, 64}`) and, for each entry, create a linear layer with that many output features followed by a ReLU activation. At the end I create two heads: a policy head and a value head.

```cpp
seq = register_module("seq", torch::nn::Sequential());
int n_features_before = n_input_features;
int i = 0;
for (int layer_features : net_architecture) {
  // Add the layers only to seq, which is itself registered above. Registering
  // them on this module as well would list every weight twice in parameters().
  seq->push_back("l" + std::to_string(i),
                 torch::nn::Linear(n_features_before, layer_features));
  seq->push_back("r" + std::to_string(i + 1), torch::nn::ReLU());
  n_features_before = layer_features;
  i += 2;
}
// Linear's constructor already initializes its parameters, so no explicit
// reset_parameters() call is needed.
action_head = register_module("a", torch::nn::Linear(n_features_before, 3));  // 3 actions
value_head = register_module("v", torch::nn::Linear(n_features_before, 1));
```
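Because each layer's input size is the previous layer's output size, a configuration can be checked before anything is built. The torch-free sketch below uses hypothetical helpers `plan_layers` and `count_params` (with `n_actions = 3` matching the hard-coded action head) to list the `Linear` shapes the loop produces and the resulting parameter count. Comparing that number with the sum of `numel()` over the module's parameters is a cheap sanity check, for example for accidental double registration.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// (in_features, out_features) of every Linear layer the construction loop creates.
struct LayerPlan {
  std::vector<std::pair<int, int>> hidden;  // each one is followed by a ReLU
  std::pair<int, int> action_head;
  std::pair<int, int> value_head;
};

LayerPlan plan_layers(int n_input_features,
                      const std::vector<int>& net_architecture,
                      int n_actions) {
  LayerPlan plan;
  int n_features_before = n_input_features;
  for (int layer_features : net_architecture) {
    plan.hidden.emplace_back(n_features_before, layer_features);
    n_features_before = layer_features;
  }
  plan.action_head = {n_features_before, n_actions};
  plan.value_head = {n_features_before, 1};
  return plan;
}

// A Linear(in, out) layer with bias has in * out weights plus out biases.
std::int64_t linear_params(const std::pair<int, int>& shape) {
  return static_cast<std::int64_t>(shape.first) * shape.second + shape.second;
}

std::int64_t count_params(const LayerPlan& plan) {
  std::int64_t total = 0;
  for (const auto& shape : plan.hidden) total += linear_params(shape);
  return total + linear_params(plan.action_head) + linear_params(plan.value_head);
}
```

For 4 input features and `{64, 64}` this gives 320 + 4160 + 195 + 65 = 4740 parameters.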

All optimizers derive from `torch::optim::Optimizer`, so you can hold one through a base-class pointer and pick the concrete optimizer based on the configuration:

```cpp
A2CNet policy_net = A2CNet();

// Any concrete optimizer can be stored behind the common base class.
std::shared_ptr<torch::optim::Optimizer> policy_optimizer;
std::string optimizer_class = params["optimizer_class"];
if (optimizer_class == "adam") {
  auto opt = torch::optim::AdamOptions(lr);
  if (params["use_weight_decay"])
    opt.weight_decay(params["weight_decay"]);
  policy_optimizer = std::make_shared<torch::optim::Adam>(policy_net->parameters(), opt);
} else if (optimizer_class == "sgd") {
  auto opt = torch::optim::SGDOptions(lr);
  opt.momentum(params["sgd_momentum"]);
  policy_optimizer = std::make_shared<torch::optim::SGD>(policy_net->parameters(), opt);
} else {
  // Fail loudly instead of continuing with a null optimizer.
  throw std::invalid_argument("Unknown optimizer_class: " + optimizer_class);
}
```
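As the number of options grows, the if/else chain can be replaced by a lookup table from configuration strings to factory functions. The torch-free sketch below illustrates that design choice; `Optimizer`, `Adam`, `SGD` and `make_optimizer` here are hypothetical stand-ins, not the `torch::optim` classes, and only carry their hyperparameters.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

// Hypothetical stand-ins for a polymorphic optimizer hierarchy.
struct Optimizer {
  virtual ~Optimizer() = default;
  virtual std::string name() const = 0;
};

struct Adam : Optimizer {
  explicit Adam(double lr) : lr(lr) {}
  std::string name() const override { return "adam"; }
  double lr;
};

struct SGD : Optimizer {
  SGD(double lr, double momentum) : lr(lr), momentum(momentum) {}
  std::string name() const override { return "sgd"; }
  double lr;
  double momentum;
};

// Maps the configured class name to a factory; unknown names throw.
std::shared_ptr<Optimizer> make_optimizer(const std::string& optimizer_class,
                                          double lr, double momentum) {
  const std::map<std::string, std::function<std::shared_ptr<Optimizer>()>>
      factories = {
          {"adam", [&] { return std::make_shared<Adam>(lr); }},
          {"sgd", [&] { return std::make_shared<SGD>(lr, momentum); }},
      };
  auto it = factories.find(optimizer_class);
  if (it == factories.end())
    throw std::invalid_argument("Unknown optimizer_class: " + optimizer_class);
  return it->second();
}
```

Adding an optimizer then means adding one table entry rather than another branch.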

Training on both objectives then looks like this:

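```cpp
policy_net->train();

// Forward pass: both heads are computed from the same shared trunk.
torch::Tensor action_probs;
torch::Tensor values;
std::tie(action_probs, values) = policy_net->forward(samples);
// The MCTS targets must not receive gradients.
auto mcts_actions = attached_mcts_actions.detach();

// Calculate losses.
torch::Tensor cross_entropy;
if (params["tough_ce"]) {
  // Soft targets: cross-entropy against the full MCTS action distribution.
  auto err = -(torch::log(action_probs) * mcts_actions).sum({1});
  cross_entropy = err.sum();
} else {
  // Hard targets: only the most visited MCTS action counts. action_probs are
  // already softmax outputs, so apply NLL to their log; F::cross_entropy
  // expects raw logits and would apply log_softmax a second time.
  auto argmax_mcts_actions = mcts_actions.argmax(1);
  cross_entropy = F::nll_loss(
      torch::log(action_probs),
      argmax_mcts_actions,
      F::NLLLossFuncOptions().reduction(torch::kSum));
}
// Average over the batch.
cross_entropy /= mcts_actions.size(0);

torch::Tensor value_loss = F::smooth_l1_loss(
    values.reshape(-1),
    normalized_returns,
    F::SmoothL1LossFuncOptions().reduction(torch::kSum)
);

policy_optimizer->zero_grad();
// Both losses share the graph through seq, so keep it alive
// (retain_graph = true) for the second backward pass.
cross_entropy.backward({}, true, false);
value_loss.backward();
policy_optimizer->step();
```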
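The two cross-entropy branches are not interchangeable: the `tough_ce` branch uses the whole MCTS visit distribution as a soft target, while the other keeps only its argmax. The torch-free sketch below reproduces the arithmetic on plain vectors so the difference is easy to check by hand; the function names are hypothetical, the reductions mirror the snippet above (cross-entropy summed and divided by the batch size, smooth L1 summed), and smooth L1 uses PyTorch's default `beta = 1`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Soft targets: -sum_j target[i][j] * log(probs[i][j]), averaged over the batch.
// A probability of exactly 0 yields 0 * -inf = NaN, as with torch::log.
double soft_cross_entropy(const Matrix& probs, const Matrix& targets) {
  double total = 0.0;
  for (std::size_t i = 0; i < probs.size(); ++i)
    for (std::size_t j = 0; j < probs[i].size(); ++j)
      total -= targets[i][j] * std::log(probs[i][j]);
  return total / static_cast<double>(probs.size());
}

// Hard targets: -log of the probability assigned to the argmax target class
// (first maximum on ties), averaged over the batch.
double hard_cross_entropy(const Matrix& probs, const Matrix& targets) {
  double total = 0.0;
  for (std::size_t i = 0; i < probs.size(); ++i) {
    std::size_t best = 0;
    for (std::size_t j = 1; j < targets[i].size(); ++j)
      if (targets[i][j] > targets[i][best]) best = j;
    total -= std::log(probs[i][best]);
  }
  return total / static_cast<double>(probs.size());
}

// Smooth L1 with beta = 1, summed: 0.5 * d^2 if |d| < 1, else |d| - 0.5.
double smooth_l1_sum(const std::vector<double>& pred,
                     const std::vector<double>& target) {
  double total = 0.0;
  for (std::size_t i = 0; i < pred.size(); ++i) {
    const double d = std::abs(pred[i] - target[i]);
    total += d < 1.0 ? 0.5 * d * d : d - 0.5;
  }
  return total;
}
```

For a one-hot target both cross-entropies agree (probabilities `{0.5, 0.25, 0.25}` and target `{1, 0, 0}` give `log 2` in each case). For the soft target `{0.5, 0.5, 0}` the soft loss becomes `1.5 * log 2`, while the hard loss still only sees class 0.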

More interesting code snippets and the whole code can be found in the following two repositories:

Find the files a2c.cpp and a2c.hpp in both.
