Welcome to torch-submod’s documentation!


A library implementing layers that solve the min-norm problem for submodular functions. The Jacobian, needed for backpropagation, is computed using the methods from [DK17]. At the moment only graph-cut functions on one- and two-dimensional grids are implemented; for these, the min-norm problem is also known as the total variation problem.

The package therefore provides total variation solvers that are differentiable with respect to both the input signal and the weights.


Once you have installed PyTorch (following these instructions), you can install the package with:

python setup.py install


For example, let us try to learn row and column weights that denoise a simple image. We create an image that is zero everywhere except for its top-left quadrant, which is filled with ones. Then we corrupt it with Gaussian noise and try to recover it using a total variation solver with learned weights (and a learned scaling of the input).

An extended version of the example below, together with visualizations, is provided in the repository as a Jupyter notebook.

>>> from __future__ import division, print_function
>>> import torch
>>> from torch.autograd import Variable
>>> from torch_submod.graph_cuts import TotalVariation2dWeighted as tv2d
>>> torch.manual_seed(0)
>>> m, n = 50, 100  # The image dimensions.
>>> std = 1e-1  # The standard deviation of the noise.
>>> x = torch.zeros((m, n))
>>> x[:m//2, :n//2] += 1  # Fill the top-left quadrant with ones.
>>> x_noisy = x + std * torch.randn(m, n)  # Add Gaussian noise with standard deviation std.
>>> x = Variable(x, requires_grad=False)
>>> x_noisy = Variable(x_noisy, requires_grad=False)
>>> # The learnable parameters. The weights are parametrized by their logarithms so that they stay positive.
>>> log_w_row = Variable(- 3 * torch.ones(1), requires_grad=True)
>>> log_w_col = Variable(- 3 * torch.ones(1), requires_grad=True)
>>> scale = Variable(torch.ones(1), requires_grad=True)
>>> optimizer = torch.optim.SGD([log_w_row, log_w_col, scale], lr=.5)
>>> losses = []
>>> for iter_no in range(1000):
...     w_row = torch.exp(log_w_row)
...     w_col = torch.exp(log_w_col)
...     # Expand the shared weights to the edge-weight shapes (m, n - 1) and (m - 1, n).
...     y = tv2d()(scale * x_noisy,
...                w_row.expand((m, n - 1)), w_col.expand((m - 1, n)))
...     optimizer.zero_grad()
...     loss = torch.mean((y - x)**2)
...     loss.backward()
...     if iter_no % 100 == 0:  # Record the loss every 100 iterations.
...         losses.append(loss.data[0])
...     optimizer.step()
>>> print('\n'.join(map(str, losses)))

Function classes

Graph cuts

To solve the total variation problem we use the prox_tv library; please refer to its documentation for the set of available methods. In particular, each function accepts a tv_args dictionary argument, which is passed on to the solver. The idea of averaging within connected components, enabled when average_connected=True, first appeared for the one-dimensional case in [NB17].
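In the one-dimensional case, averaging within connected components amounts to replacing the incoming gradient with respect to the solution by its mean over each run of consecutive equal values of the solution. The plain-Python sketch below illustrates this for the gradient with respect to x only; it is independent of PyTorch, and the function name `tv1d_backward_sketch` and the tolerance `tol` used to decide whether neighbouring values are equal are hypothetical, not part of torch-submod's API.

```python
def tv1d_backward_sketch(z, grad_z, tol=1e-8):
    """Map dL/dz to dL/dx for 1D total variation by averaging grad_z
    within each maximal run of consecutive (approximately) equal values of z.

    Hypothetical helper for illustration; not torch-submod's implementation.
    """
    assert len(z) == len(grad_z)
    grad_x = [0.0] * len(z)
    start = 0
    for i in range(1, len(z) + 1):
        # A run ends at the end of the signal or where neighbouring values differ.
        if i == len(z) or abs(z[i] - z[i - 1]) > tol:
            run_mean = sum(grad_z[start:i]) / (i - start)
            for j in range(start, i):
                grad_x[j] = run_mean
            start = i
    return grad_x


# The solution has runs [0.5, 0.5], [2.0] and [0.5, 0.5, 0.5]; the two runs
# of 0.5 are averaged separately because they are not adjacent.
z = [0.5, 0.5, 2.0, 0.5, 0.5, 0.5]
grad_z = [1.0, 3.0, 5.0, 0.0, 3.0, 6.0]
print(tv1d_backward_sketch(z, grad_z))  # [2.0, 2.0, 5.0, 3.0, 3.0, 3.0]
```

Since each run is averaged, the total gradient mass within a run is preserved.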

Note: At the moment the total variation problems can be solved only on the CPU, so please make sure that all variables are placed on the CPU.

class torch_submod.graph_cuts.TotalVariation2dWeighted(refine=True, average_connected=True, tv_args=None)

A two dimensional total variation function.

Specifically, given as input the unaries x, positive row weights \(\mathbf{r}\) and positive column weights \(\mathbf{c}\), the output is computed as

\[\textrm{argmin}_{\mathbf z} \frac{1}{2} \|\mathbf{x}-\mathbf{z}\|^2 + \sum_{i, j} r_{i,j} |z_{i, j} - z_{i, j + 1}| + \sum_{i, j} c_{i,j} |z_{i, j} - z_{i + 1, j}|.\]
Parameters:

  • refine (bool) – If set, the solution will be refined with isotonic regression.
  • average_connected (bool) –

    How to compute the approximate derivative.

    If True, it will average within each connected component. If False, it will average within each block of equal values. Typically, you want this set to True.

  • tv_args (dict) – The dictionary of arguments passed to the total variation solver.
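To make the indexing in the objective above concrete, the following plain-Python sketch evaluates it for a candidate solution z: r has shape (m, n - 1) and weights horizontal differences within a row, while c has shape (m - 1, n) and weights vertical differences within a column. The function name `tv2d_objective` is hypothetical; it only evaluates the objective and does not solve the problem.

```python
def tv2d_objective(x, z, r, c):
    """Evaluate 1/2 ||x - z||^2 + sum r_ij |z_ij - z_i,j+1| + sum c_ij |z_ij - z_i+1,j|.

    x, z: m x n nested lists; r: m x (n - 1) row weights; c: (m - 1) x n column weights.
    Hypothetical helper for illustration only.
    """
    m, n = len(x), len(x[0])
    assert len(r) == m and all(len(row) == n - 1 for row in r)
    assert len(c) == m - 1 and all(len(row) == n for row in c)
    data = 0.5 * sum((x[i][j] - z[i][j]) ** 2 for i in range(m) for j in range(n))
    # Horizontal edges (i, j) -- (i, j + 1), weighted by r[i][j].
    horizontal = sum(r[i][j] * abs(z[i][j] - z[i][j + 1])
                     for i in range(m) for j in range(n - 1))
    # Vertical edges (i, j) -- (i + 1, j), weighted by c[i][j].
    vertical = sum(c[i][j] * abs(z[i][j] - z[i + 1][j])
                   for i in range(m - 1) for j in range(n))
    return data + horizontal + vertical


x = [[1.0, 0.0],
     [0.0, 0.0]]
r = [[0.5], [0.5]]   # shape (m, n - 1) = (2, 1)
c = [[0.25, 0.25]]   # shape (m - 1, n) = (1, 2)
print(tv2d_objective(x, x, r, c))  # 0.75: with z = x only the penalty terms contribute
```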
forward(x, weights_row, weights_col)

Solve the total variation problem and return the solution.

Parameters:

  • x (torch.Tensor) – A tensor with shape (m, n) holding the input signal.
  • weights_row (torch.Tensor) –

    The horizontal edge weights.

    Tensor of shape (m, n - 1), or (1,) if all weights are equal.

  • weights_col (torch.Tensor) –

    The vertical edge weights.

    Tensor of shape (m - 1, n), or (1,) if all weights are equal.


Returns:

The solution to the total variation problem, of shape (m, n).

Return type:

torch.Tensor


class torch_submod.graph_cuts.TotalVariation2d(refine=True, average_connected=True, tv_args=None)

A two dimensional total variation function with tied edge weights.

Specifically, given as input the unaries x and a single edge weight w shared by all horizontal and vertical edges, the output is computed as

\[\textrm{argmin}_{\mathbf z} \frac{1}{2} \|\mathbf{x}-\mathbf{z}\|^2 + \sum_{i, j} w |z_{i, j} - z_{i, j + 1}| + \sum_{i, j} w |z_{i, j} - z_{i + 1, j}|.\]
Parameters:

  • refine (bool) – If set, the solution will be refined with isotonic regression.
  • average_connected (bool) –

    How to compute the approximate derivative.

    If True, it will average within each connected component. If False, it will average within each block of equal values. Typically, you want this set to True.

  • tv_args (dict) – The dictionary of arguments passed to the total variation solver.
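The tied-weight objective coincides with the objective of TotalVariation2dWeighted when every row weight and every column weight equals w. The plain-Python sketch below, with the hypothetical helpers `expand_tied_weight`, `tv_penalty` and `tied_tv_penalty`, builds constant weight arrays of shapes (m, n - 1) and (m - 1, n) and checks on a small grid that both penalties agree; it does not call torch-submod.

```python
def expand_tied_weight(w, m, n):
    """Return (row_weights, col_weights) of shapes (m, n - 1) and (m - 1, n), all equal to w."""
    row_weights = [[w] * (n - 1) for _ in range(m)]
    col_weights = [[w] * n for _ in range(m - 1)]
    return row_weights, col_weights


def tv_penalty(z, r, c):
    """Weighted penalty: sum r_ij |z_ij - z_i,j+1| + sum c_ij |z_ij - z_i+1,j|."""
    m, n = len(z), len(z[0])
    horizontal = sum(r[i][j] * abs(z[i][j] - z[i][j + 1])
                     for i in range(m) for j in range(n - 1))
    vertical = sum(c[i][j] * abs(z[i][j] - z[i + 1][j])
                   for i in range(m - 1) for j in range(n))
    return horizontal + vertical


def tied_tv_penalty(z, w):
    """Tied penalty: w times the sum of absolute differences over all grid edges."""
    m, n = len(z), len(z[0])
    diffs = sum(abs(z[i][j] - z[i][j + 1]) for i in range(m) for j in range(n - 1))
    diffs += sum(abs(z[i][j] - z[i + 1][j]) for i in range(m - 1) for j in range(n))
    return w * diffs


z = [[1.0, 2.0, 4.0],
     [0.0, 2.0, 1.0]]
r, c = expand_tied_weight(0.5, 2, 3)
print(tv_penalty(z, r, c), tied_tv_penalty(z, 0.5))  # 5.0 5.0
```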
forward(x, w)

Solve the total variation problem and return the solution.

Parameters:

  • x (torch.Tensor) – A tensor with shape (m, n) holding the input signal.
  • w (torch.Tensor) –

    The edge weight.

    A one-element tensor holding the weight w shared by all horizontal and vertical edges.


Returns:

The solution to the total variation problem, of shape (m, n).

Return type:

torch.Tensor


class torch_submod.graph_cuts.TotalVariation1d(average_connected=True, tv_args=None)

A one dimensional total variation function.

Specifically, given as input the signal x and weights \(\mathbf{w}\), the output is computed as

\[\textrm{argmin}_{\mathbf z} \frac{1}{2} \|\mathbf{x}-\mathbf{z}\|^2 + \sum_{i=1}^{n-1} w_i |z_i - z_{i+1}|.\]
Parameters:

  • average_connected (bool) –

    How to compute the approximate derivative.

    If True, it will average within each connected component. If False, it will average within each block of equal values. Typically, you want this set to True.

  • tv_args (dict) – The dictionary of arguments passed to the total variation solver.
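For experimenting without prox_tv, the one-dimensional problem above can be solved approximately by projected gradient descent on its dual. Writing (Dz)_i = z_i - z_{i+1}, the solution is z = x - D^T u, where u minimizes 1/2 ||x - D^T u||^2 subject to |u_i| <= w_i. This is a swapped-in textbook method, not the solver used by torch-submod or prox_tv; the function name `tv1d_dual_pg` and the fixed iteration budget `n_iter` are illustrative assumptions. As in forward, the weights may be given per edge (length n - 1) or as a single shared value (length 1).

```python
def tv1d_dual_pg(x, weights, n_iter=2000):
    """Approximately solve argmin_z 1/2 ||x - z||^2 + sum_i w_i |z_i - z_{i+1}|.

    Hypothetical sketch: projected gradient on the dual, z = x - D^T u, |u_i| <= w_i.
    `weights` has length n - 1, or length 1 to share a single weight across all edges.
    """
    n = len(x)
    if n < 2:
        return list(x)
    w = list(weights) * (n - 1) if len(weights) == 1 else list(weights)
    assert len(w) == n - 1 and all(wi >= 0 for wi in w)
    step = 0.25  # 1 / L, since the largest eigenvalue of D D^T is below 4.

    def primal(dual):
        # z_j = x_j - (u_j - u_{j-1}), with u_{-1} = u_{n-1} = 0.
        return [x[j] - (dual[j] if j < n - 1 else 0.0) + (dual[j - 1] if j > 0 else 0.0)
                for j in range(n)]

    u = [0.0] * (n - 1)
    for _ in range(n_iter):
        z = primal(u)
        # The dual gradient is -(D z): step against it, then project onto |u_i| <= w_i.
        u = [min(max(u[i] + step * (z[i] - z[i + 1]), -w[i]), w[i]) for i in range(n - 1)]
    return primal(u)


print(tv1d_dual_pg([0.0, 1.0], [0.25]))  # [0.25, 0.75]
```

With a weight of at least half the jump (e.g. [1.0]), the two values fuse at their mean 0.5; a fixed iteration budget keeps the sketch short, while a practical solver would monitor the duality gap instead.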
forward(x, weights)

Solve the total variation problem and return the solution.

Parameters:

  • x (torch.Tensor) – A tensor with shape (n,) holding the input signal.
  • weights (torch.Tensor) –

    The edge weights.

    Tensor of shape (n - 1,), or (1,) if all weights are equal.


Returns:

The solution to the total variation problem, of shape (n,).

Return type:

torch.Tensor



[DK17] Josip Djolonga and Andreas Krause. Differentiable learning of submodular models. In Neural Information Processing Systems (NIPS), 2017.
[NB17] Vlad Niculae and Mathieu Blondel. A regularized framework for sparse and structured neural attention. arXiv preprint arXiv:1705.07704, 2017.
