Consider a neural network layer with N input and M output units. The forward computation is y = h(W x + b), where W and b are the weights and biases, h is the activation function, and x and y are the layer’s inputs and outputs. If we choose the Rectified Linear Unit (ReLU, also called the ramp function) as h, computing the activation function requires no multiplications, so all multiplications reside in the matrix product W x. For each input vector x, N × M floating point multiplications are needed.
Binary connect eliminates these multiplications by stochastically sampling weights to be −1 or 1. Full resolution weights w̄ are kept in memory as a reference, and each time y is needed, we sample a stochastic weight matrix W according to w̄. For each element of the sampled matrix W, the probability of getting a 1 is proportional to how “close” its corresponding entry in w̄ is to 1, i.e.,
P(W_ij = 1) = (w̄_ij + 1) / 2
P(W_ij = −1) = 1 − P(W_ij = 1)
It is necessary to add boundary constraints to w̄. To ensure that P(W_ij = 1) is a valid probability, values in w̄ are forced to lie in the interval [−1, 1]. If any value grows beyond that interval during the updates, it is set to the nearest boundary value, −1 or 1. With every sampled weight equal to −1 or 1, the floating point multiplications in W x become sign changes.
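The sampling and clipping rules above can be sketched in NumPy as follows. This is a minimal illustration, not a reference implementation: the function names (`clip_weights`, `sample_binary`, `signed_sum_matvec`) and the example values are assumptions made for the sketch, and NumPy itself will still use ordinary floating point arithmetic internally. The point of `signed_sum_matvec` is only to show that W x can be written as additions and subtractions once W contains nothing but −1 and 1.

```python
import numpy as np

def clip_weights(w_bar):
    # Keep the full resolution weights inside [-1, 1] so that
    # (w_bar + 1) / 2 is always a valid probability.
    return np.clip(w_bar, -1.0, 1.0)

def sample_binary(w_bar, rng):
    # P(W_ij = 1) = (w_bar_ij + 1) / 2, otherwise W_ij = -1.
    p_plus = (w_bar + 1.0) / 2.0
    return np.where(rng.random(w_bar.shape) < p_plus, 1.0, -1.0)

def signed_sum_matvec(W, x):
    # For W with entries in {-1, +1}, W @ x is a sum of the inputs
    # with weight +1 minus a sum of the inputs with weight -1.
    pos = np.where(W > 0, x, 0.0).sum(axis=1)
    neg = np.where(W < 0, x, 0.0).sum(axis=1)
    return pos - neg

rng = np.random.default_rng(0)  # seed chosen arbitrarily for the example
w_bar = clip_weights(np.array([[1.7, -0.2, 0.0],
                               [-3.0, 0.5, 1.0]]))
W = sample_binary(w_bar, rng)
x = np.array([0.5, -1.0, 2.0])
y_pre = signed_sum_matvec(W, x)  # pre-activation W x, before adding b
```

Entries of w̄ that sit exactly on a boundary are sampled deterministically: a clipped value of 1 always yields 1, and a clipped value of −1 always yields −1.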
A remaining question concerns the use of multiplications in the random number generator involved in the sampling process. Sampling a random integer has to be faster than a multiplication for the algorithm to be worthwhile.
Ternary connect extends binary connect. Binary connect only allows weights to be −1 or 1, yet in a trained neural network it is common to observe that many learned weights are zero or close to zero. Although the stochastic sampling process allows the mean value of the sampled weights to be zero, this observation suggests that it may be beneficial to explicitly allow weights to be zero.
To allow weights to be zero, split the interval [−1, 1], within which the full resolution weight value w̄_ij lies, into two sub-intervals: [−1, 0] and (0, 1]. If w̄_ij falls into one of them, we sample W_ij to be one of the two boundary values of that interval, with probabilities determined by their distance from w̄_ij, i.e., if w̄_ij > 0:
P(W_ij = 1) = w̄_ij; P(W_ij = 0) = 1 − w̄_ij
and if w̄_ij ≤ 0:
P(W_ij = −1) = −w̄_ij; P(W_ij = 0) = 1 + w̄_ij
Like binary connect, ternary connect also eliminates all multiplications in the forward pass.
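A minimal NumPy sketch of the ternary sampling rule is given below. The function name `sample_ternary` and the test values are assumptions for illustration. Both cases of the rule collapse into one expression: the nonzero boundary value, which has the sign of the full resolution weight, is chosen with probability equal to the absolute value of that weight, and zero is chosen otherwise. The expected value of each sampled weight therefore equals its full resolution counterpart, which the empirical mean at the end checks approximately.

```python
import numpy as np

def sample_ternary(w_bar, rng):
    # Clip first so that |w_bar| is a valid probability.
    w_bar = np.clip(w_bar, -1.0, 1.0)
    u = rng.random(w_bar.shape)
    # w_bar > 0:  P(W = 1)  = w_bar,  P(W = 0) = 1 - w_bar
    # w_bar <= 0: P(W = -1) = -w_bar, P(W = 0) = 1 + w_bar
    pick_nonzero = u < np.abs(w_bar)
    return np.where(pick_nonzero, np.sign(w_bar), 0.0)

rng = np.random.default_rng(0)  # seed chosen arbitrarily for the example
w_bar = np.array([-1.0, -0.5, 0.0, 0.25, 1.0])

# Draw many independent samples of the same weights and average them.
samples = sample_ternary(np.tile(w_bar, (20000, 1)), rng)
empirical_mean = samples.mean(axis=0)  # should be close to w_bar
```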
We now move from the forward to the backward pass. Suppose the i-th layer of the network has N input and M output units, and consider an error signal δ propagating downward from its output. The weight update is the outer product of the layer’s input and the error signal multiplied element-wise by h′(W x + b); the bias update is that same element-wise product:
∆W = η [δ ◦ h′(W x + b)] x^T
∆b = η [δ ◦ h′(W x + b)]
where η is the learning rate, x is the input to the layer, and ◦ denotes the element-wise product. While propagating through the layers, the error signal δ needs to be updated, too. Its update for the next layer below takes the form:
δ = W^T [δ ◦ h′(W x + b)]
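For reference, the three backward-pass quantities can be sketched in NumPy for a single input vector. This is an illustrative sketch under stated assumptions: h is taken to be ReLU, h′(W x + b) enters all three expressions (as the chain rule gives for y = h(W x + b)), the error signal is multiplied element-wise by h′ before W^T is applied so that the dimensions agree, and the names `relu_grad` and `backward_layer` are hypothetical. No quantization is applied here; the comments only mark where multiplications occur in the standard computation.

```python
import numpy as np

def relu_grad(z):
    # h'(z) for ReLU: 1 where z > 0, else 0 (a comparison, no multiplication).
    return (z > 0).astype(z.dtype)

def backward_layer(W, b, x, delta, eta):
    z = W @ x + b
    g = delta * relu_grad(z)      # element-wise product, M values
    dW = eta * np.outer(g, x)     # outer product, M x N values
    db = eta * g
    delta_prev = W.T @ g          # error signal for the layer below, N values
    return dW, db, delta_prev

# Small example with M = 2 outputs and N = 3 inputs.
W = np.array([[1.0, -1.0, 1.0],
              [-1.0, 1.0, -1.0]])
b = np.zeros(2)
x = np.array([1.0, 2.0, 4.0])
delta = np.array([0.5, 2.0])
dW, db, delta_prev = backward_layer(W, b, x, delta, eta=0.1)
```

In this example the second output unit has a negative pre-activation, so its ReLU derivative is zero and it contributes nothing to any of the three results.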
Three terms appear repeatedly in the above three equations, viz. δ, h′(W x + b) and x. The latter two terms introduce matrix outer products. To eliminate multiplications, one can quantize one of them to be an integer power of 2, so that multiplications involving that term become binary shifts. The expression h′(W x + b) contains down-flowing gradients, which are largely determined by the cost function and network parameters, so it is hard to bound its values. However, bounding the values is essential for quantization, because each sampled value must be stored in a fixed number of bits: if the value varies too much, the exponent needs too many bits, which increases both the storage for the sampled value and the amount of computation.
While h′(W x + b) is not a good choice for quantization, x is a better choice, because it is the hidden representation at each layer, and we know roughly the distribution of each layer’s activation.
The approach is therefore to eliminate multiplications in
∆W = η [δ ◦ h′(W x + b)] x^T
by quantizing each entry in x to an integer power of 2. That way the outer product in this expression becomes a series of bit shifts. Experiments show that allowing a maximum of 3 to 4 bits of shift is sufficient to make the network work well. This means that 3 bits are already enough to quantize x. As the float 32 format has 24 bits of mantissa, shifting (to the left or right) by 3 to 4 bits is completely tolerable. This approach is referred to as “quantized back propagation”.
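The idea of replacing the outer product by shifts can be sketched as follows. Several choices here are assumptions made for illustration rather than details given in the text: the exponent is obtained by rounding log2 of the magnitude, the exponent range [−4, 3] is picked so that it fits in 3 bits, and the names `quantize_pow2` and `shift_outer` are hypothetical. In floating point, multiplying by 2^e amounts to adding e to the exponent field, which `np.ldexp` does; this stands in for the bit shift a fixed-point implementation would use.

```python
import numpy as np

def quantize_pow2(x, min_exp=-4, max_exp=3):
    # Represent each entry as sign * 2**e with an integer exponent e.
    # Rounding in the log domain and the exponent range are assumptions.
    sign = np.sign(x)
    mag = np.abs(x)
    safe = np.where(mag > 0, mag, 1.0)  # avoid log2(0); sign is 0 there
    e = np.clip(np.round(np.log2(safe)), min_exp, max_exp).astype(np.int32)
    return sign, e

def shift_outer(g, sign, e):
    # Outer product g x_q^T without general multiplications:
    # ldexp(g, e) scales g by 2**e, and the sign factor only
    # flips or zeroes entries.
    return np.ldexp(g[:, None], e[None, :]) * sign[None, :]

x = np.array([0.0, 0.3, 1.0, 5.0, -0.6])
sign, e = quantize_pow2(x)
x_q = sign * np.ldexp(1.0, e)   # quantized inputs: 0, 0.25, 1, 4, -0.5

g = np.array([0.5, -2.0])       # stands in for delta ◦ h'(W x + b)
dW_shift = shift_outer(g, sign, e)
```

With the exponent limited to a few bits, each scaled value differs from the original only in its exponent, so the mantissa of `g` is carried through unchanged.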
If we choose ReLU as the activation function and use binary (ternary) connect to sample W, computing the term h′(W x + b) involves no multiplications at all. In addition, quantized back propagation eliminates the multiplications in the outer product in
∆W = η [δ ◦ h′(W x + b)] x^T.
The only place where multiplications remain is the element-wise product. From
∆W = η [δ ◦ h′(W x + b)] x^T, ∆b = η [δ ◦ h′(W x + b)], and δ = W^T [δ ◦ h′(W x + b)],
one can see that 6 × M multiplications are needed for all computations. In standard back propagation, most of the multiplications are spent on the weight updates: it would need 2MN + 6M multiplications. Compared with that, the number of multiplications left in quantized back propagation is negligible.
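To make the comparison concrete, the two counts quoted above can be evaluated for a specific layer size. The helper name `backprop_multiplications` and the layer size N = M = 1024 are assumptions chosen for illustration; the formulas 2MN + 6M and 6M are taken directly from the text.

```python
def backprop_multiplications(n_in, m_out, quantized):
    # 6 * M multiplications remain in both cases; the 2 * M * N term
    # is only paid by standard back propagation.
    elementwise = 6 * m_out
    matrix_terms = 0 if quantized else 2 * m_out * n_in
    return elementwise + matrix_terms

n_in, m_out = 1024, 1024
standard = backprop_multiplications(n_in, m_out, quantized=False)   # 2103296
remaining = backprop_multiplications(n_in, m_out, quantized=True)   # 6144
fraction_left = remaining / standard
```

For this layer size, fewer than 0.3% of the multiplications of standard back propagation remain, and the fraction shrinks further as N grows.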