Machine learning: Neural networks introduction

Week four of my Coursera machine learning course was a breezy introduction to neural networks. The lecture videos were very high level but did a good job introducing the concept. The part I hadn’t understood before was that regression techniques are really best suited to linear prediction models: building Nth-order polynomial terms out of M features yields on the order of M^N terms (O(M^N) work), which quickly gets out of hand. I also hadn’t really understood that neural networks are just a series of logistic regressions. The input variables are mapped through a logistic model to an intermediate hidden layer (with some chosen number of nodes), then the hidden layer is mapped again through a second logistic model to yield the output variables. However, the lecture stopped before we got to backpropagation, so for this week the method of training a neural network is still a mystery.
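To get a feel for that blow-up, here is a small Python sketch (my choice of language; the course itself uses Octave) that counts the distinct degree-N products of M features. It uses the stars-and-bars count C(M + N − 1, N), which grows roughly like M^N / N!, and cross-checks it by brute-force enumeration. The feature counts in the loop are just illustrative; 400 happens to be the pixel count of the OCR images in this week's homework.

```python
from itertools import combinations_with_replacement
from math import comb

def num_poly_terms(num_features, degree):
    """Distinct monomials of exactly `degree` over `num_features` variables.

    x1*x2 and x2*x1 count once, so this is the number of multisets of size
    `degree` drawn from the features: C(M + N - 1, N).
    """
    return comb(num_features + degree - 1, degree)

def num_poly_terms_brute(num_features, degree):
    # Enumerate every multiset of feature indices directly to check the formula.
    return sum(1 for _ in combinations_with_replacement(range(num_features), degree))

for m in (100, 400):
    for n in (2, 3):
        print(f"M={m:4d} features, degree {n}: {num_poly_terms(m, n):,} terms")
```

With 100 features that is already 5,050 quadratic terms and 171,700 cubic ones; with 400 pixels the cubic count passes ten million.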

Logistic regression applied to OCR

The homework is a bit behind and out of sync with the lecture notes. The bulk of the homework was still logistic regression, last week’s lecture topic. The hardest part would have been figuring out how to vectorize the naive loop implementation of the regularized logistic regression cost function from last week. But I’d already vectorized it, so I could just copy my solution from last week. Gold star!
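For reference, the fully vectorized form looks roughly like this in the course’s notation. This is a sketch under the usual assumptions: m training examples, a design matrix X whose first column is all ones, a label vector y, the sigmoid g, regularization strength λ, and the constant weight θ_0 left out of the penalty. Writing it this way is what removes the loops: every sum over examples becomes a matrix-vector product.

```latex
h = g(X\theta), \qquad g(z) = \frac{1}{1 + e^{-z}}

J(\theta) = -\frac{1}{m}\left[ y^\top \log h + (1 - y)^\top \log(1 - h) \right]
            + \frac{\lambda}{2m} \sum_{j=1}^{n} \theta_j^2

\nabla_\theta J = \frac{1}{m} X^\top (h - y)
                  + \frac{\lambda}{m} \begin{bmatrix} 0 \\ \theta_1 \\ \vdots \\ \theta_n \end{bmatrix}
```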

The more fun part was actually applying one of these learned models to do something useful with real data: OCR classification of handwritten digits. The input was 5000 images, each a 20×20 greyscale pixel array, along with its classification (“this squiggle is the number 7”). Our job was to build a multiclass classifier to do the OCR, predicting which digit each image shows. So we took the regularized logistic regression cost function we had just implemented and used fmincg() to search for the parameters that best fit the data, once for each digit (one-vs-all). Each resulting parameter vector (theta) is the prediction model for one digit, and together they make up the classifier. Then we applied that learned model to classify the input data. So I’ve now built a logistic regression OCR system for handwritten digits! The final system predicted the training set with 95% accuracy. The final model is quite large: 4010 separate numbers, 401 weights for each of the digits 0–9, or one weight per pixel plus a constant term. Not exactly parsimony.
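The prediction step of a one-vs-all classifier is small enough to sketch. Below is a plain-Python version (lists instead of Octave matrices, so it runs anywhere) that scores an example with every class's weights and picks the most confident class. The function names and the three-class toy model are hypothetical illustrations, not the homework's code or data; in the homework the weight table would be 10 rows of 401 numbers, which is where the 4010 comes from.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def predict_one_vs_all(all_theta, x):
    """Return the index of the class whose classifier is most confident.

    all_theta: one row per class, each holding n + 1 weights with the constant
               (bias) weight first; 10 rows of 401 weights in the OCR homework.
    x:         the n raw features of one example (400 pixel values there).
    """
    features = [1.0] + list(x)  # prepend the constant term
    probs = [sigmoid(sum(w * f for w, f in zip(row, features))) for row in all_theta]
    return max(range(len(probs)), key=probs.__getitem__)

def accuracy(all_theta, examples, labels):
    """Fraction of examples whose predicted class matches the label."""
    hits = sum(predict_one_vs_all(all_theta, x) == y for x, y in zip(examples, labels))
    return hits / len(labels)

# Hypothetical toy model: 3 classes, 2 features.
# Class 0 fires on feature 1, class 1 on feature 2, class 2 when both are off.
toy_theta = [
    [-1.0,  2.0,  0.0],
    [-1.0,  0.0,  2.0],
    [ 1.0, -2.0, -2.0],
]
examples = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]]
labels = [0, 1, 2]
print(accuracy(toy_theta, examples, labels))
```

Training (the fmincg() part) is deliberately left out; this only shows how ten independent logistic regressions combine into one multiclass answer.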

One neat thing about multiclass models is that they don’t just output a predicted class (“the number 7”); they also output a vector of probabilities, one for each possible value: “probability this image is the number 1, probability it is the number 2, …”. We crush those probabilities down to a single answer by taking the largest: “this input is probably an image of the number 7”. But something to remember for later: machine learning models can return not only a prediction but also a confidence in that prediction. They can also express ambiguity. Because each one-vs-all classifier computes its probability independently, the ten numbers don’t have to sum to 1, so a single image might have a 90% probability of being the number 7 and an 80% probability of being the number 9 (for a particularly ambiguous squiggle).
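Here is a quick Python check of that ambiguity point. The ten raw scores are made up for illustration (they are not from the homework): every classifier scores an image low except the ones for 7 and 9, which are set to give exactly 90% and 80%. Taking the maximum still yields a single prediction, but because each sigmoid is computed on its own, the ten probabilities add up to well over 1.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def logit(p):
    # Inverse of the sigmoid: the raw score z that yields probability p.
    return math.log(p / (1.0 - p))

# Hypothetical raw scores (theta' * x) from ten one-vs-all classifiers for one
# ambiguous image; index k means "the number k". Only 7 and 9 score high.
scores = [-4.0] * 10
scores[7] = logit(0.9)
scores[9] = logit(0.8)

probs = [sigmoid(z) for z in scores]
best = max(range(10), key=probs.__getitem__)
print(f"prediction: {best}, confidence {probs[best]:.2f}")  # prediction: 7, confidence 0.90
print(f"sum of all ten probabilities: {sum(probs):.2f}")   # 1.84, since nothing forces a sum of 1
```

Keeping `probs[best]` around, instead of only `best`, is the cheap way to hold on to that confidence for later.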

Neural network forward propagation

The last part of the homework was implementing a basic neural network. Or rather the application of one: the forward propagation that maps the input data through the layers and produces outputs. We were handed parameters that had already been trained, so really this was just an exercise in “can you code up forward propagation?” Still, it was useful to do that myself. In particular I had to puzzle out that the hidden layer consists of 25 nodes. So the final classifier is basically two steps: a logistic regression that maps 400 pixels (plus a constant term) to 25 hidden nodes, then a second logistic regression that maps those 25 hidden nodes (plus a constant term) to 10 probabilities. The central mystery of neural networks is what those “hidden nodes” really mean. We have Deep Dream to thank for a lovely visualization of hidden states in a different kind of neural network image processing system.
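Those two steps translate almost directly into code. This is a minimal plain-Python sketch (the homework's Octave version is vectorized with matrix products; lists keep this self-contained). Each layer prepends a bias unit of 1 and applies one logistic regression per output node, so with the homework's sizes the first weight matrix is 25×401 and the second is 10×26. The function names are my own, hypothetical ones.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def layer(theta, activations):
    """One logistic-regression layer.

    Prepends the bias unit, then computes sigmoid(row . a) for each row of
    theta. theta needs one row per output node and len(activations) + 1 columns.
    """
    a = [1.0] + list(activations)
    if any(len(row) != len(a) for row in theta):
        raise ValueError("each theta row needs len(activations) + 1 weights")
    return [sigmoid(sum(w * x for w, x in zip(row, a))) for row in theta]

def forward(theta1, theta2, pixels):
    """Forward propagation: pixels -> hidden layer -> output probabilities.

    With the homework's sizes this is 400 -> 25 -> 10 (theta1 is 25x401,
    theta2 is 10x26).
    """
    hidden = layer(theta1, pixels)
    return layer(theta2, hidden)

def predict(theta1, theta2, pixels):
    out = forward(theta1, theta2, pixels)
    return max(range(len(out)), key=out.__getitem__)

# Shape check at full homework size with all-zero (untrained) weights:
# every node computes sigmoid(0), so all ten outputs come out as 0.5.
theta1 = [[0.0] * 401 for _ in range(25)]
theta2 = [[0.0] * 26 for _ in range(10)]
print(forward(theta1, theta2, [0.0] * 400))
```

The zero-weight run is only a wiring test; plugging in the pre-trained weights from the homework is what turns it into a digit classifier.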


I almost gave up this week; I ran into a bunch of weird technical problems. By far the biggest one was putting the line “print size(X)” in my code as a debugging aid and then forgetting about it. Suddenly my Octave program is complaining about fig2dev missing, and I’m down a rabbit hole of Homebrew installs trying to figure out what the hell is wrong. Turns out “print” means print, as in printing a figure to paper, and I needed “printf”. Derp.

I am also doing a lot of stumbling around and shallow learning. For example, I know I need to combine the matrix X with the vector theta somehow but forget which way. Rather than puzzle that out from mathematical principles, I just inspect the sizes: X is 5000×400 and theta is 1×400, and I need to multiply them somehow, so the only products whose dimensions line up are X*theta' or theta*X'. So I try one, see if the homework oracle tells me I got the right answer, and call it a day. (Writing this out, I realized I picked the wrong one of the two, which is why I keep having to transpose everything. Oops.) Anyway, this doesn’t feel like learning so much as bashing about with the only thing that will pass a type checker. I keep telling myself I’m absorbing something, though, and my real goal here is just to understand enough about what’s going on that I can use other people’s machine learning systems later.
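The “type checker” in question is a single rule: an (a×b) matrix times a (b×c) matrix gives an (a×c) matrix, and anything else is an error. This tiny Python sketch applies only that rule to the shapes above (the helper names are hypothetical) and confirms that exactly two of the four candidate products conform.

```python
def product_shape(a, b):
    """Shape of the matrix product A*B, or None if the inner dimensions differ."""
    (rows_a, cols_a), (rows_b, cols_b) = a, b
    return (rows_a, cols_b) if cols_a == rows_b else None

def transpose(shape):
    rows, cols = shape
    return (cols, rows)

X = (5000, 400)
theta = (1, 400)  # stored as a row vector, as in my code

candidates = {
    "X*theta":  product_shape(X, theta),
    "X*theta'": product_shape(X, transpose(theta)),
    "theta*X":  product_shape(theta, X),
    "theta*X'": product_shape(theta, transpose(X)),
}
for expr, shape in candidates.items():
    print(f"{expr:9s} -> {shape if shape else 'dimension mismatch'}")
```

X*theta' gives one prediction per image as a 5000×1 column, while theta*X' gives the same numbers as a 1×5000 row. Storing theta as a 400×1 column in the first place would make plain X*theta conform with no transposes at all.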

I do wish the homework assignments had more interactive help. In particular, I’d like a bit more hand-holding about how to run the code you write, understand what it’s doing, and test it against known good results. In desperation I started looking at the submit.m internals just to see what test cases the homework oracle was using. I don’t get to see the right outputs that way (they’re hidden on the server), but at least I have some reasonable test inputs I can look at.

The only actual course lecture notes are a wiki page. The wiki content is pretty good, but it requires a login to even read the pages! And Coursera’s wiki is broken somehow, so you have to log in about once an hour, which makes me very cranky.