Neural Networks are the foundation of all things Deep Learning, so it's natural to wonder how exactly they work. While they may seem like a black box to most, we find that the box isn't actually that deep upon opening. With some simple, intuitive diagrams, let's examine how these fascinating mathematical models learn to play video games, classify animals, and drive cars.
The piece that all neural networks are built on is the neuron. It takes in some inputs and yields an output, very similar to how the neurons in our brain work.
Think of each circle as a neuron and each line as a connection between the input and the output. In a computer, each circle and line is simply represented by a number. The gray circle is a bias term, whose value is always 1. That doesn't mean the weight on the line attached to it is also 1, though.
The output is calculated by taking a weighted sum of all of the inputs and then passing the result into the activation function; the lines are the weights. For now, let's assume the activation function is the sign function; it outputs 1 if the input is positive and -1 otherwise.
Simple addition and multiplication are all it takes to calculate the output of each neuron.
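To make this concrete, here's a tiny Python sketch of a single neuron with the sign activation (the names here are my own, just for illustration):

```python
import numpy as np

def sign_activation(z):
    # Outputs 1 if the input is positive, -1 otherwise
    return 1 if z > 0 else -1

def neuron_output(inputs, weights, bias_weight):
    # Weighted sum of the inputs, plus the bias weight (the bias input is always 1)
    weighted_sum = np.dot(inputs, weights) + bias_weight
    return sign_activation(weighted_sum)

print(neuron_output(np.array([2.0, -1.0]), np.array([0.5, 0.3]), 0.1))  # 1
```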
Below is a diagram of a fully connected neural network. It's simply a bunch of neurons arranged together in layers, with the dotted circles representing the bias terms.
We have four neurons in the input layer (red), two neurons in the hidden layer (blue), and four neurons in the output layer (yellow). The number of neurons per layer is dependent on the task at hand, and may require tuning by you, the computer scientist. Each layer, besides the output layer, also has its own bias unit.
The input values are given by whatever dataset we're using, but the value of each neuron in the subsequent layers is the weighted sum of all the neurons in the previous layer, as mentioned above. Training a neural network is the process of mathematically working out what these weights should be.
At a fundamental level, a neural network just performs a sequence of multiplications and additions.
After calculating the weighted sum for each neuron in a layer, we pass the sum through the activation function; these aren't represented in the network above. The activation function's task is to constrain the neuron's output to a specific range so that certain neurons don't significantly overpower others. Activation functions also introduce non-linearity to the network, which is important for complex tasks. Without them, a neural network would just collapse into an everyday linear regression. All neurons in a layer share the same activation function, but the function can vary between layers.
We explore activation functions more in-depth in the technical section of this post.
The value of the first blue neuron in the hidden layer is the sum of each of the four input neurons times the corresponding weight connecting each red neuron to the blue neuron, plus the weight on the dotted line between the red bias term and the blue neuron, all passed into the activation function.
The same process applies to the other neuron in the hidden layer and to the rest of the neurons in the output layer. We never calculate anything for the bias neurons; they're always one.
For our network to be useful, we need some data. Let’s say we’re tackling the hand-written digit classification task (the "Hello World" of neural nets) using the MNIST dataset. Inserted below are 64 sample inputs from the dataset, each an image of a hand-written digit.
The way each digit is represented in the computer is by a 28x28 grid of numbers. Black pixels are given a value of 0, while white ones are given a value of 1. Any number in between is a relative shade of gray.
We take each row of the image and stack them into a giant list of 784 (28 times 28) numbers; this will be the input to our neural network. In this case, the input layer needs 784 neurons, one for each input, and the output layer needs 10 neurons, one for each digit. We can choose however many hidden layers we'd like, and however many neurons in each.
We then multiply through the network until we have the 10 values in our output layer; this is known as the forward pass. The neuron with the highest value is our "prediction." The first neuron represents the digit 0, so if the 8th neuron has the highest value, the network's prediction is the digit 7.
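Here's a rough NumPy sketch of that pipeline; forward is a hypothetical stand-in for whatever network we've built:

```python
import numpy as np

image = np.random.rand(28, 28)  # stand-in for a real MNIST digit, values in [0, 1]

x = image.flatten()             # stack the 28 rows into one list of 784 numbers
assert x.shape == (784,)

# output = forward(x)           # hypothetical forward pass returning 10 values
output = np.random.rand(10)     # stand-in so this snippet runs on its own

prediction = np.argmax(output)  # index of the highest-valued output neuron
print(f"Predicted digit: {prediction}")
```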
While the general mechanics of a neural network hopefully seem much less complicated, I would expect that your next question is: How do we train these things? How does the network know what the values of the weights need to be in order to make accurate predictions?
Saving the calculus for the technical section, let’s examine the backpropagation algorithm. As mentioned earlier, each data point has an associated label, 0 through 9. We convert each numeric value into a vector of nine zeroes, and a one in the appropriate position indicated by the original label. These are called "one-hot encoded" vectors. Again, the first index represents 0.
0 → [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
7 → [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]
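Building these vectors takes just a couple of lines in NumPy:

```python
import numpy as np

def one_hot(label, num_classes=10):
    # A vector of zeroes with a one at the index given by the label
    vector = np.zeros(num_classes)
    vector[label] = 1.0
    return vector

print(one_hot(0))  # [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
print(one_hot(7))  # [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
```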
At the start, all of the weights in the network are initialized randomly, meaning the predictions won't be accurate at all. After we forward propagate through the network, we have 10 numbers in our output layer, all of them pretty close to useless.
We then compare the output layer to the vector created from the data label, subtracting them to get an error vector. It essentially tells the network how incorrect each neuron in the output layer was. Using this error, we then calculate (via gradient descent) how much to change each weight so that the output becomes a little more accurate. We then move to the previous layer, again calculating how inaccurate the neurons were and tweaking the network's parameters.
After we repeat this for each data point, we’ll have completed one epoch of training. Most tasks require multiple epochs in order to yield a respectable accuracy.
We can track if our neural network is actually learning by using a “cost function.” It is a function that takes in our predicted values and the actual values for our entire dataset and returns a value indicating how “incorrect” our predictions are relative to the ground truth. The lower the value, the better.
It's through this process of passing a data point forward, comparing it with the ground truth, calculating the error, backpropagating the error through the network in reverse, and updating the weights that the weights eventually start yielding accurate predictions. It's the fundamental backbone of every system today that uses a neural network under the hood.
In the technical section, we’ll see how the back propagation algorithm minimizes this function to obtain the weights that maximize our accuracy. There we'll dive into the linear algebra and calculus for neural networks.
With the help of some linear algebra and calculus, we can build upon our intuition of how neural networks work by understanding the math behind it.
Data
An ML model doesn't provide any value if it's unable to properly learn the trends in the data it's trained on. Models are commonly plagued by two issues: overfitting and underfitting.
Let's say we train our model on the above dataset. The first model is doing a terrible job; its predictions across the whole dataset are very incorrect because it's unable to account for the quadratic nature of the data. This is underfitting.
The middle model is doing a very reasonable job, while the last model fits the training data almost perfectly. We call this overfitting. The overfitting model doesn't just capture the quadratic nature of the data; it captures some of the noise as well.
We might be inclined to lean towards the final model, but we also need to account for the fact that it might not generalize well to new data points, and generalizing is the entire goal of training a model: to be robust to the infinitely many data points the real world might throw at it.
The black star represents a new data point the model hasn't seen before. The x-value for it is 5 and the true y-value is 25. It was sampled from the same random process that generated the rest of the data points.
The orange dot represents the model's prediction for when x is equal to 5. The first model is wrong, but only by about 15. The second model is off by a much smaller amount, while the overfitting model's prediction is off by over 250!
If you look closely at the graph of the third model above, you'll see that it reaches a local maximum when x is 4 and starts to decline as x continues to increase. It correctly predicts the data points it has already seen, but cannot predict the new point correctly, even though it's a point that fits the trend of the dataset.
Think about a student who memorizes the study sheet for an upcoming test; they're then unable to answer questions that are conceptually similar to, but slightly different from, the ones they've seen before. A better approach is not to memorize the problems and their answers on the study sheet, but to understand the concepts required to solve them.
The overfitting model is doing exactly what the first student did, while the second model is equivalent to the student that doesn't memorize, but rather seeks to learn the underlying concepts.
Extending the analogy to the first model: One could say it's like a student that didn't bother studying at all!
To catch the problem of overfitting, the dataset is split into two batches: a train dataset and a test dataset. Some common splits are 70-30 or 80-20. We train our model on the train dataset and check whether the accuracy on the test dataset improves. If the train accuracy is high but the test accuracy is low, we know our model is overfitting. If both accuracies are low, we're underfitting. If both are high, we're golden.
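Here's a minimal sketch of such a split in NumPy, assuming X and y are arrays (in practice, libraries like scikit-learn provide a ready-made train_test_split):

```python
import numpy as np

def train_test_split(X, y, train_fraction=0.8, seed=0):
    # Shuffle the indices, then slice off the first train_fraction for training
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(X))
    cutoff = int(len(X) * train_fraction)
    train_idx, test_idx = indices[:cutoff], indices[cutoff:]
    return X[train_idx], y[train_idx], X[test_idx], y[test_idx]
```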
Overview of the Full Algorithm
Let's first get a brief outline of all of the steps in the training process. Keep in mind where the train data is used versus the test data; we'll sketch the loop in code right after the list.
1. Randomly initialize the weights of the neural network
2. Shuffle all of the training data
3. Perform a forward pass of the first training example
4. Perform the backpropagation algorithm
   a. Starting at the last layer, calculate the error term for the weights
   b. Proceed to calculate the error terms for all hidden layers (again, going from right to left)
   c. Update the weights with gradient descent
5. Repeat steps 3 and 4 until you've iterated through the entire train dataset
6. Pass through all test data and measure test accuracy to gauge if the network is learning
7. Repeat steps 2 through 6 for each epoch of training
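In pseudocode-flavored Python, one run of this recipe looks like the skeleton below; forward_pass, backprop, and update_weights are hypothetical placeholders for the steps we derive in the technical section:

```python
import numpy as np

def train(network, X_train, y_train, X_test, y_test, epochs):
    for epoch in range(epochs):
        # Step 2: shuffle the training data each epoch
        order = np.random.permutation(len(X_train))
        for i in order:
            prediction = forward_pass(network, X_train[i])         # step 3
            gradients = backprop(network, prediction, y_train[i])  # steps 4a-4b
            update_weights(network, gradients)                     # step 4c
        # Step 6: measure accuracy on data the network never trains on
        guesses = [np.argmax(forward_pass(network, x)) for x in X_test]
        accuracy = np.mean(np.array(guesses) == y_test)
        print(f"Epoch {epoch}: test accuracy {accuracy:.3f}")
```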
The testing phase is when we give the trained network the test data, data it hasn't seen before, to gauge how well it can generalize to new data. It performs a forward pass on each test example and compares the network's prediction to the ground truth for that test example.
After having gone through the entire test dataset, it calculates its accuracy. We want this metric to increase with each epoch to ensure our model isn't overfitting.
The weights of the network are not tuned when the test data is passed in! This means that the model isn't learning from the test data so it'll be unable to "memorize" any of the test examples.
Linear Algebra
Linear Algebra provides us with some minimalist notation and extra tools to properly utilize vectors and matrices. Take a look at the diagram of a neuron with two inputs again.
The output of the neuron is the weighted sum of the inputs plus the bias term, passed into the activation function. Let's call the activation function f(x) for now.
output = f(w0 + w1x1 + w2x2)
This notation will become quite cumbersome as the number of inputs and weights increases. Let's rewrite our inputs and weights as vectors.
x = [x1, x2]    w = [w1, w2]
If we take the dot product of these two vectors, we get the sum of the element-wise products. We set b to equal w0 because it's the bias.
x ⋅ w = w1x1 + w2x2
b = w0
output = f(x ⋅ w + b)
The weighted sum can now be written with just three variables because of the convenience of the dot product. The beauty here is that our notation doesn't change as the number of inputs increases. This operation—to multiply by the weights and add the bias—is called a linear transformation.
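In NumPy, the whole neuron collapses to a single line (using tanh as a stand-in for f):

```python
import numpy as np

x = np.array([0.5, -1.2])  # inputs
w = np.array([0.8, 0.4])   # weights
b = 0.1                    # bias (w0)

f = np.tanh                # any activation function f
output = f(np.dot(x, w) + b)
```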
Let's move on to a neural network.
Our input, x, will be a vector of size four because we have four neurons in the input layer. We never count the bias unit as a neuron; it's present in all layers except the output. The first bias will be a vector of size two because there are two neurons in the hidden layer. The weights for each neuron in the hidden layer are represented as rows in a matrix, W1. The superscript represents which neuron in the hidden layer we're on, and the subscript represents which weight for that neuron.
We have eight solid, purple lines and two dotted, purple lines in the diagram above, so we'll need eight elements in our weight matrix and two in our bias vector.
Arranging the weights like so allows us to use a neat property of matrix multiplication. When we multiply W1 by x, we'll get a vector with a length of 2. Adding that to our bias vector, b1, grants us the values for our hidden layer. Again, f(x) is the activation function for the hidden layer.
Feel free to learn more about matrix multiplication here.
The bias term is necessary because it adds some flexibility to our network. Think about the equation for a simple line: y = mx + b. The b indicates the line's y-intercept. Without it, all lines would be forced to pass through the origin, which is a severe restriction when trying to model real-world scenarios. The bias vector here plays the same role.
The exact process is repeated to get the values for the output layer. The weight matrix for the output layer, W2, will be a 3 x 2 matrix. Each row in W2 holds the weights for one neuron in the output layer, and each column holds the weights coming from one neuron in the hidden layer. The output layer will have a bias vector of length 3. Let's call the values in the hidden layer h and the activation function for the output layer g(x).
output layer=g(W2h+b2)
As a rule of thumb, the dimensions for any given weight matrix will be the number of neurons in the current layer by the number of neurons in the previous layer. The length of the bias vector will be the number of neurons in the current layer.
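As a sanity check on this rule of thumb, here's a sketch of the shapes for the 4-2-3 network above in NumPy (with stand-in random weights and activations):

```python
import numpy as np

n_input, n_hidden, n_output = 4, 2, 3

# Weight matrix: (neurons in current layer) x (neurons in previous layer)
W1 = np.random.randn(n_hidden, n_input)   # 2 x 4
b1 = np.random.randn(n_hidden)            # length 2
W2 = np.random.randn(n_output, n_hidden)  # 3 x 2
b2 = np.random.randn(n_output)            # length 3

x = np.random.randn(n_input)
h = np.tanh(W1 @ x + b1)   # hidden layer, with tanh standing in for f
out = W2 @ h + b2          # values for the output layer, before g is applied
```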
Activation Functions
Activation functions introduce non-linearity to the neural network. This gives the network the ability to solve problems beyond those that a basic linear regression could solve.
After the linear transformation for each layer, the result is then passed into that layer's activation function to get the layer's output.
Sigmoid
The Sigmoid function is one of the most common activation functions. At 0, it's equal to 0.5 and as it approaches infinity, it asymptotes towards 1. As it approaches negative infinity, it approaches 0.
Tanh
The Tanh function is very similar to the sigmoid, but its asymptote as it approaches negative infinity is -1.
ReLu
The ReLu function, or the Rectified Linear Unit, is a unique one because of its kink at 0, meaning it's not differentiable there. It's a piecewise function that returns 0 for all negative inputs and the input itself for all positive numbers. It essentially acts as a hard gate, letting all positive numbers through, but not negative ones.
GeLu
The GeLu, or the Gaussian Error Linear Unit, aims to mirror the shape of ReLu, but avoids having the kink, making it differentiable everywhere. It's equal to x times the Gaussian cumulative distribution function, which is represented by ϕ.
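Here's a minimal Python sketch of all four functions, leaning on SciPy for the Gaussian CDF:

```python
import numpy as np
from scipy.stats import norm

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def tanh(x):
    return np.tanh(x)

def relu(x):
    # Hard gate: positive inputs pass through, negatives become 0
    return np.maximum(0, x)

def gelu(x):
    # x times the Gaussian cumulative distribution function
    return x * norm.cdf(x)
```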
Certain activation functions, such as the sigmoid or tanh functions, constrain the output to a specified range. This helps prevent some neurons from overpowering the rest of the network. However, given that functions such as the ReLu and GeLu are still very commonly used, this is not a strict requirement.
Activation functions also need to be differentiable. This allows us to calculate the gradients of the cost function and minimize it. The ReLu is unique because it's not differentiable at 0, but we just manually set the derivative to be 0 there.
I won't calculate each of the derivatives by hand, but I've inserted them below.
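σ′(x) = σ(x)(1 − σ(x))
tanh′(x) = 1 − tanh²(x)
ReLu′(x) = { 0 if x < 0,  1 if x > 0 }
GeLu′(x) = ϕ(x) + x ⋅ P(X=x)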
As a side note, P(X=x) is the value of the PDF of the standard normal curve at x, and ϕ(x) is the value of the CDF of the standard normal curve at x.
One useful feature of the derivatives for the sigmoid and tanh functions is that the derivative is a function of the original output, which makes calculating the derivatives simple.
Softmax
Another activation function I want to mention is the softmax function. It doesn't have a graph, because its input is all of the layer's neurons arranged in a vector.
S(xi) = e^(xi) / Σ(k=1 to N) e^(xk)
The useful feature of the softmax activation function is that all of the values now add up to 1, allowing us to interpret the result as probabilities. This is useful for multiclass classification problems.
We won't dive into the derivative of this function here because it requires the Jacobian matrix as it's a function that takes in a vector and outputs another vector. The other activation functions took in a number and returned one as well.
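Here's a quick NumPy sketch. Subtracting the max first is a common trick to keep e^x from overflowing; it cancels out in the ratio, so the result is unchanged:

```python
import numpy as np

def softmax(x):
    shifted = x - np.max(x)  # numerical stability; doesn't change the output
    exps = np.exp(shifted)
    return exps / np.sum(exps)

scores = np.array([2.0, 1.0, 0.1])
probs = softmax(scores)
print(probs, probs.sum())  # values in (0, 1) that sum to 1
```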
There are many functions I haven't mentioned, but if you're interested, I encourage you to look into the Binary Step function, Leaky ReLu, ELU, SoftPlus, Swish, and many more not listed here.
There's not one activation function that rules above all else; it depends on the specific use case. I would recommend starting with the ReLu for the hidden layers and the softmax activation function for the output layer. Then, taking into account factors such as the network's accuracy and training speed, see if modifying the activation functions yields any improvements.
Forward Pass
Now that we have our basic building blocks, let's put them all together for an example forward pass.
Let's use the network from above, but add the ReLu activation function, R(x), to the hidden layer and the sigmoid activation function, σ(x), to the output layer. We start with our input, x.
The key insight to gather here is that the entire network is a bunch of nested functions. The result of the first linear transformation with W1 and b1 is the input for the ReLu function. The result of the ReLu function is the input for the second linear transformation, and so on.
The index value of the neuron with the highest value in the output layer is considered our prediction.
If we wanted to, we could write the whole network on one line. It's probably clear why we don't, though; it gets harder to read as the number of layers increases.
output layer=σ(W2(R(W1x+b1))+b2)
This insight will come in handy when we start to calculate the derivatives soon.
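Putting it together in NumPy makes the nesting easy to see:

```python
import numpy as np

def relu(x):
    return np.maximum(0, x)

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

def forward(x, W1, b1, W2, b2):
    z1 = W1 @ x + b1   # first linear transformation
    a1 = relu(z1)      # hidden layer
    z2 = W2 @ a1 + b2  # second linear transformation
    a2 = sigmoid(z2)   # output layer
    return a2

# Equivalent to the one-line equation above:
# sigmoid(W2 @ relu(W1 @ x + b1) + b2)
```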
Cost functions
After a forward pass for one training example, we pass the values of the output layer (our predicted values) into a cost function, along with the true values for that data point. The cost function then returns a number indicating how "incorrect" the prediction is. The higher the output is, the more incorrect our predictions are relative to the ground truth.
Assume the first input to the loss function is our prediction and the second the ground truth.
We want our loss function to return a smaller value for the pair of inputs on the left, because the prediction vector is closer to the ground truth vector. The goal of training is to minimize our cost function across all training examples. This means we need a way to measure how different our predictions are from the ground truth and the function needs to be differentiable.
MSE
The most basic cost function is the Mean Squared Error loss.
MSE = 1/(2N) ⋅ Σ(k=0 to N−1) (yk − ŷk)²
We subtract our prediction vector from the ground truth, square the result, and average across all training examples. Check out my post about linear regression here for a deeper intuition as to what exactly is happening.
For a singular training example, we'll be minimizing the following:
(y − ŷ)² / 2
The 2 in the denominator will cancel out when we take the derivative of the function, which is −(y − ŷ).
Binary Cross Entropy
Binary CE = −( y ln(ŷ) + (1 − y) ln(1 − ŷ) )
The Binary Cross Entropy cost function is used when you need the network to classify between two classes. It addresses an issue that's prominent when MSE is combined with the logistic sigmoid: if the ground truth is 0 and the prediction is close to 1, or vice versa, the error term evaluates to near zero. This means the weights barely update when the network's predictions are extremely off, defeating the purpose of training it.
Categorical Cross Entropy
Categorical CE = −Σ(k=1 to n) yk log(ŷk)
The Categorical Cross Entropy is the extension of the Binary CE to tasks that classify more than two classes. It's used when the data's y-values are one-hot encoded, meaning that they're vectors of all zeroes except for a single one, whose index position represents the y-value.
Sparse Categorical Cross Entropy
The Sparse Categorical CE has the same equation as the Categorical CE but is used when the y-values are class indices, rather than one-hot encoded vectors.
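To see how similar these losses are, here's a minimal NumPy sketch of each, for a single training example and with no numerical-stability guards:

```python
import numpy as np

def mse(y, y_hat):
    return np.sum((y - y_hat) ** 2) / 2

def binary_cross_entropy(y, y_hat):
    return -(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

def categorical_cross_entropy(y_one_hot, y_hat):
    return -np.sum(y_one_hot * np.log(y_hat))

def sparse_categorical_cross_entropy(class_index, y_hat):
    # Same loss as above, but the label is a class index, not a one-hot vector
    return -np.log(y_hat[class_index])
```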
Backward prop
Now for the complicated part of a neural network: the training process. I've redrawn the neural network from above in a slightly different manner. This format will help us understand the calculus more easily.
The colors represent each layer and each rectangle represents an operation that’s performed in the network. The output of one function is the input for the next one. We pass x into the first linear transformation, z1(x,W1,b1), then into the ReLu activation function, R(x), then into the second linear transformation, z2(a1,W2,b2), then into the sigmoid activation function, σ(x), and finally into the loss function, L(a2,y), to calculate the error.
I hope this reinforces the insight of how a neural network is a bunch of nested functions.
Chain Rule Review
Imagine we're trying to find the derivative of h(x), which is composed of two functions, one nested in the other, i.e. h(x) = f(g(x)); the output of g(x) is the input of f(x).
Let's say f(x) and g(x) are equal to the equations below.
f(x) = x²
g(x) = 5x + 3
h(x) = (5x + 3)²
The chain rule dictates how to take the derivative of nested functions: differentiate the outer function while treating the inner function as a variable, then multiply the result by the derivative of the inner function. So the derivative of h(x) is f′(g(x))g′(x). In other words, dh/dx = (df/dg)(dg/dx).
f′(x) = 2x
g′(x) = 5
h′(x) = 2(5x + 3) ⋅ 5 = 10(5x + 3) = 50x + 30
Partial Derivatives Review
Let's say we have the following equation: f(x,y) = x²y. Partial derivatives represent the change in the output of the equation when one of the input variables is changed by a small amount. The partial derivative of f(x,y) with respect to x is how the output is affected when x changes by a small amount while y is held constant. The partial derivative with respect to y is how the output is affected when y changes by a small amount while x is held constant.
To calculate the partial derivatives, we take the derivative like normal but pretend one of the variables is just like any other constant.
If either the chain rule or the partial derivatives are confusing, I recommend watching this video about the chain rule by 3Blue1Brown or this video about partial derivatives by Khan Academy.
∂f/∂x = 2xy
∂f/∂y = x²
When we stack all of the partial derivatives into a vector, we call that the gradient vector of a function.
∇f = [2xy, x²]
The premise of the backpropagation algorithm is that we need to take the partial derivative of the cost function with respect to each of the weight matrices and bias vectors. There are four partial derivatives in total we want to calculate: ∂L/∂W1, ∂L/∂b1, ∂L/∂W2, and ∂L/∂b2.
These are the four parameters of our network, the ones we want to optimize. We can now use the chain rule to write the equations for each of these.
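Applying the chain rule to the chain of functions from the diagram, the four partials expand to:

∂L/∂W2 = (∂L/∂a2)(∂a2/∂z2)(∂z2/∂W2)
∂L/∂b2 = (∂L/∂a2)(∂a2/∂z2)(∂z2/∂b2)
∂L/∂W1 = (∂L/∂a2)(∂a2/∂z2)(∂z2/∂a1)(∂a1/∂z1)(∂z1/∂W1)
∂L/∂b1 = (∂L/∂a2)(∂a2/∂z2)(∂z2/∂a1)(∂a1/∂z1)(∂z1/∂b1)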
If you look closely, you'll see a lot of repeated partials across all of the equations, such as the product (∂L/∂a2)(∂a2/∂z2). This is convenient for us because it means we can just compute them once and store them for later use.
Just as a reminder, ∂a1/∂z1 and ∂a2/∂z2 are the derivatives of the ReLu and sigmoid functions, respectively.
Calculating the Partials
Given that the loss function is composed of several nested functions, we'll need to use the chain rule throughout our computations.
To calculate all of the partials, we'll need to find the derivatives of each of the cells in our diagram from above. Then, we'll need to multiply the appropriate results for each partial we're trying to find.
Let's start with our cost function, Mean Squared Error. Because a2 (written as ŷ above) is a function of our weights and biases, we want to take the partial derivative with respect to it.
L(a2) = (y − a2)² / 2
∂L/∂a2 = −(y − a2)
Moving on to the derivative of the sigmoid. We've already seen it, but here it is one more time.
σ′(z2)=σ(z2)(1−σ(z2))
Now, let's calculate the partial of the second linear transformation, W2a1 + b2, with respect to W2, because we want to optimize our weight matrices. It's just a1. The partial with respect to our bias vector is 1, and the partial with respect to a1 is W2; we'll need that one as we propagate further back for W1 and b1.
We have all of the pieces of our equation to write the complete partial derivatives for W2 and b2. I'm using ∗ to represent element-wise multiplication.
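∂L/∂W2 = (−(y − a2) ∗ σ′(z2)) a1ᵀ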
The transpose of a1 is necessary to get the dimensions to align. y, a2, and z2 are each 3×1 vectors, while a1 is a 2×1 vector. Transposing a1 and multiplying yields a 3×2 matrix, which matches the dimensions of W2 exactly.
The dimensions of the derivative of a vector/matrix will be the same as the dimensions of the original vector/matrix.
The partial with respect to our bias is very similar.
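∂L/∂b2 = −(y − a2) ∗ σ′(z2)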
We call −(y−a2)∗σ′(z2) the error term for the output layer, also referred to as δ. The gradient for any weight is the layer's error term times the input from the previous layer. Note that the gradient of the bias weight is just the error term. This makes sense because the value of the bias term is 1.
The error term is useful to calculate separately because error terms in earlier layers are dependent on error terms in later layers. The backpropagation algorithm is named as such because we "propagate" the error backwards through the network to calculate the gradients.
The derivative of the ReLu function is a little bit unorthodox as we saw above.
∂R/∂z1 = { 0 if z1 ≤ 0,  1 if z1 > 0 }
The derivatives of the first linear transformation are the same as the ones previously computed: the derivative with respect to W1 is x and the derivative with respect to b1 is 1. We can use these pieces to calculate the partials with respect to W1 and b1.
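∂L/∂W1 = (W2ᵀ(−(y − a2) ∗ σ′(z2)) ∗ R′(z1)) xᵀ
∂L/∂b1 = W2ᵀ(−(y − a2) ∗ σ′(z2)) ∗ R′(z1)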
Again, the error term of the hidden layer is W2ᵀ(−(y − a2) ∗ σ′(z2)) ∗ R′(z1). It contains the error term from the output layer. If there were any more hidden layers before this one, we would've repeated this process to calculate the error terms for those layers.
Because we're in the realm of Linear Algebra, it helps to outline the expected dimensions of the values in our gradients as a sanity check. Our expected dimensions are 2×4 for ∂L/∂W1 and 2×1 for ∂L/∂b1.
A neat feature of the backpropagation algorithm is if we choose to change the activation functions for any of the layers, we don't need to calculate everything from scratch. We can just replace the derivative of the old activation function with the new one.
Updating the Weights
The final step after calculating all of the gradients is actually updating the weights.
The weights and biases are updated simultaneously after all the gradients have been calculated! Do not update the weights or biases as you calculate the gradients.
After all of the gradients are calculated for all the parameters of the network, the final update takes place. α is our learning rate, δ our error term, and a is the input from the previous layer. Remember, δ times a is our gradient for that given layer.
The := symbol means you're assigning the value on the right side as the new value of what is on the left.
Wnew := Wcurrent − αδaᵀ
bnew := bcurrent − αδ
The learning rate, α, is a hyperparameter that controls how quickly the network learns. If it's too small, the network will take a long time to converge, costing more computational resources. If it's too large, the network will never converge, i.e., a set of optimal weights will never be found. As with all other hyperparameters, the computer scientist is tasked with experimenting to find the ideal learning rate for their given use case.
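To tie the whole backward pass and update together, here's a minimal NumPy sketch of one training step for the 4-2-3 network from this section, using the ReLu, sigmoid, and MSE choices above (the names are mine; this is a sketch, not a production implementation):

```python
import numpy as np

def train_step(x, y, W1, b1, W2, b2, lr=0.1):
    # Forward pass, keeping the intermediate values for backprop
    z1 = W1 @ x + b1
    a1 = np.maximum(0, z1)       # ReLu
    z2 = W2 @ a1 + b2
    a2 = 1 / (1 + np.exp(-z2))   # sigmoid

    # Output layer error term: -(y - a2) * sigmoid'(z2)
    delta2 = -(y - a2) * a2 * (1 - a2)
    # Hidden layer error term: propagate delta2 back, gated by ReLu'(z1)
    delta1 = (W2.T @ delta2) * (z1 > 0).astype(float)

    # Gradient for each layer: its error term times its input, transposed
    grad_W2, grad_b2 = np.outer(delta2, a1), delta2
    grad_W1, grad_b1 = np.outer(delta1, x), delta1

    # Update every parameter simultaneously, only after all gradients exist
    W1 -= lr * grad_W1; b1 -= lr * grad_b1
    W2 -= lr * grad_W2; b2 -= lr * grad_b2
    return W1, b1, W2, b2
```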
Don't forget that all of this was for just one data point. The entire process needs to be repeated for each data point in the train dataset to complete one epoch. Additionally, multiple epochs may be needed.
I hope the math that goes behind a neural network isn't as daunting for you now. If you have any questions or comments, please don't hesitate to reach out.