OpFromGraph

This page describes theano.OpFromGraph, an Op that allows you to encapsulate a Theano graph in an Op.

This can be used to encapsulate some functionality in one block. It is useful for scaling Theano compilation of larger graphs in which that encapsulated functionality is reused many times with different inputs. Thanks to the encapsulation, it can make Theano's compilation phase faster for graphs with many nodes.

Using this for small graphs is not recommended, as it prevents optimizations from operating across the boundary between the inside and the outside of the encapsulation.

class theano.compile.builders.OpFromGraph(inputs, outputs, inline=False, lop_overrides='default', grad_overrides='default', rop_overrides='default', connection_pattern=None, name=None, **kwargs)[source]

This creates an Op from the inputs and outputs lists of Variables. The signature is similar to that of theano.function, and the resulting Op's perform will do the same operation as:

orig_function(inputs, outputs, **kwargs)

Currently does not support the updates or givens arguments.

Parameters
  • inputs (list of Variable) –

  • outputs (list of Variable) –

  • inline (bool, optional) –

    Defaults to False

    True : Causes the Op’s original graph to be used during compilation; the Op itself will not be visible in the compiled graph, but its internal graph will be.

    False : Will use a pre-compiled function inside.

  • grad_overrides (single or list of {'default', OpFromGraph, callable, Variable with special type}, optional) –

    Defaults to 'default'. This argument is mutually exclusive with lop_overrides.

    'default' : Do not override, use default grad() result

    OpFromGraph instance : Override with another OpFromGraph; it should accept inputs in the same order and of the same types as the inputs and output_grads arguments one would pass to the grad() method.

    callable : Should take two args: inputs and output_grads. Each argument is expected to be a list of Variable. Must return a list of Variable (see Example 4 below).

    Variable :

      NullType() instance : Treat as non-differentiable

      DisconnectedType() instance : Treat as a disconnected gradient; numerically gives zero

    list : Each OpFromGraph/callable must return a single Variable. Each list element corresponds to the gradient of a specific input; the length of the list must equal the number of inputs (see Example 3 below for this form).

  • lop_overrides (single or list of {'default', OpFromGraph, callable, Variable with special type}, optional) –

    Defaults to 'default'. This argument is mutually exclusive with grad_overrides.

    'default' : Do not override, use default L_op() result

    OpFromGraph instance : Override with another OpFromGraph; it should accept inputs in the same order and of the same types as the inputs, outputs and output_grads arguments one would pass to the L_op() method.

    callable : Should take three args: inputs, outputs and output_grads. Each argument is expected to be a list of Variable. Must return a list of Variable (see Example 5 below).

    Variable :

      NullType() instance : Treat as non-differentiable

      DisconnectedType() instance : Treat as a disconnected gradient; numerically gives zero

    list : Each OpFromGraph/callable must return a single Variable. Each list element corresponds to the gradient of a specific input; the length of the list must equal the number of inputs.

  • rop_overrides (single or list of {'default', OpFromGraph, callable, Variable with special type}, optional) –

    Defaults to 'default'.

    'default' : Do not override, use default R_op() result

    OpFromGraph instance : Override with another OpFromGraph; it should accept inputs in the same order and of the same types as the inputs and eval_points arguments one would pass to the R_op() method.

    callable : Should take two args: inputs and eval_points. Each argument is expected to be a list of Variable. Must return a list of Variable (see Example 6 below).

    Variable :

      NullType() instance : Treat as non-differentiable

      DisconnectedType() instance : Treat as zero, since DisconnectedType is not yet supported in R_op

    list : Each OpFromGraph/callable must return a single Variable. Each list element corresponds to a specific output of R_op; the length of the list must equal the number of outputs.

  • connection_pattern (list of list of bool) – If not None, this will be used as the connection_pattern for this Op (a sketch of the expected format follows this parameter list).

  • name (string, optional) – A name for debugging purposes.

  • **kwargs (optional) – Check orig_function for more arguments; these only apply when inline is False.
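
A minimal sketch of the connection_pattern format, assuming the usual Op convention that entry [i][j] states whether output j depends on input i (the variables here are illustrative only):

from theano import OpFromGraph, tensor
x, y = tensor.scalars('xy')
# output 0 (x + y) depends on both inputs; output 1 (2 * y) only on y
op = OpFromGraph([x, y], [x + y, 2 * y],
                 connection_pattern=[[True, False], [True, True]])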

Notes

  • We support shared variables in the inner graph. This is automatic and invisible to the user. They can be used as inputs to the node or appear inside the inner graph.

  • We support unused inputs. This is needed for the gradient.

  • We support nested OpFromGraph; a minimal sketch follows this list.

  • inline=True will cause better runtime optimization at the cost of compilation time. Currently only works with fast_compile or fast_run mode.

  • For overriding, it’s recommended to provide pure functions (no side effects such as setting a global variable) as the callable(s). The callable(s) supplied for overriding the gradient/R_op will be called only once, at the first call to grad/R_op, and will be converted to OpFromGraph instances.
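
As a minimal sketch of nesting, one OpFromGraph is simply applied inside the graph given to another:

from theano import OpFromGraph, tensor
x, y, z = tensor.scalars('xyz')
inner_op = OpFromGraph([x, y], [x + y])
# the outer op's inner graph contains an application of inner_op
outer_op = OpFromGraph([x, y, z], [inner_op(x, y) * z])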

Examples

Example 1:

from theano import function, OpFromGraph, tensor
x, y, z = tensor.scalars('xyz')
e = x + y * z
op = OpFromGraph([x, y, z], [e])
# op behaves like a normal theano op
e2 = op(x, y, z) + op(z, y, x)
fn = function([x, y, z], [e2])

Example 2, with a shared variable:

import numpy as np
import theano
from theano import config, function, OpFromGraph, tensor
x, y, z = tensor.scalars('xyz')
s = theano.shared(np.random.rand(2, 2).astype(config.floatX))
e = x + y * z + s
op = OpFromGraph([x, y, z], [e])
# op behaves like a normal theano op
e2 = op(x, y, z) + op(z, y, x)
fn = function([x, y, z], [e2])

Example 3, overriding the gradient:

from theano import function, OpFromGraph, tensor, grad
x, y, z = tensor.scalars('xyz')
e = x + y * z
def rescale_dy(inps, grads):
    x, y, z = inps
    g, = grads
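    # the default gradient wrt y would be g * z; returning z * 2
    # doubles it (g is 1 below, since the cost is e2 itself)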
    return z*2
op = OpFromGraph(
    [x, y, z], [e], grad_overrides=['default', rescale_dy, 'default'])
e2 = op(x, y, z)
dx, dy, dz = grad(e2, [x, y, z])
fn = function([x, y, z], [dx, dy, dz])
# the gradient wrt y is now doubled
fn(2., 3., 4.) # [1., 8., 3.]
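
Example 4, overriding the gradient with a single callable (a minimal sketch: the callable takes inputs and output_grads and returns one gradient per input, here just reproducing the default gradient of x * y):

from theano import function, OpFromGraph, tensor, grad
x, y = tensor.scalars('xy')
e = x * y
def custom_grad(inps, grads):
    x, y = inps
    g, = grads
    # one gradient Variable per input; same as the default here
    return [g * y, g * x]
op = OpFromGraph([x, y], [e], grad_overrides=custom_grad)
dx, dy = grad(op(x, y), [x, y])
fn = function([x, y], [dx, dy])
fn(2., 3.) # [3., 2.]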
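
Example 5, overriding L_op with a callable (a minimal sketch: the callable takes inputs, outputs and output_grads; the outputs are available but unused here):

from theano import function, OpFromGraph, tensor, grad
x, y = tensor.scalars('xy')
e = x * y
def custom_lop(inps, outs, grads):
    x, y = inps
    g, = grads
    # ignores outs; returns one gradient Variable per input
    return [g * y, g * x]
op = OpFromGraph([x, y], [e], lop_overrides=custom_lop)
dx, dy = grad(op(x, y), [x, y])
fn = function([x, y], [dx, dy])
fn(2., 3.) # [3., 2.]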
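
Example 6, overriding R_op with a callable (a minimal sketch: the callable takes inputs and eval_points and returns one Variable per output; for e = x * y the R-operator is u * y + x * v):

from theano import function, OpFromGraph, tensor
from theano.gradient import Rop
x, y = tensor.scalars('xy')
e = x * y
def custom_rop(inps, eval_points):
    x, y = inps
    u, v = eval_points
    # one Variable per output: Jacobian of x * y times (u, v)
    return [u * y + x * v]
op = OpFromGraph([x, y], [e], rop_overrides=custom_rop)
u, v = tensor.scalars('uv')
jv = Rop(op(x, y), [x, y], [u, v])
fn = function([x, y, u, v], [jv])
fn(2., 3., 1., 0.) # [3.]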