import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
import seaborn as sns
sns.set_context("paper")
sns.set_style("ticks")

Multi-class Logistic Regression#

We demonstrate multi-class logistic regression.

Handwritten Digits#

We will demonstrate multi-class logistic regression on a handwritten digits dataset. The dataset ships with scikit-learn, and our treatment follows the corresponding scikit-learn example closely.

First, let’s load the dataset.

from sklearn import datasets

digits = datasets.load_digits()

print(digits.DESCR)
.. _digits_dataset:

Optical recognition of handwritten digits dataset
--------------------------------------------------

**Data Set Characteristics:**

    :Number of Instances: 1797
    :Number of Attributes: 64
    :Attribute Information: 8x8 image of integer pixels in the range 0..16.
    :Missing Attribute Values: None
    :Creator: E. Alpaydin (alpaydin '@' boun.edu.tr)
    :Date: July; 1998

This is a copy of the test set of the UCI ML hand-written digits datasets
https://archive.ics.uci.edu/ml/datasets/Optical+Recognition+of+Handwritten+Digits

The data set contains images of hand-written digits: 10 classes where
each class refers to a digit.

Preprocessing programs made available by NIST were used to extract
normalized bitmaps of handwritten digits from a preprinted form. From a
total of 43 people, 30 contributed to the training set and different 13
to the test set. 32x32 bitmaps are divided into nonoverlapping blocks of
4x4 and the number of on pixels are counted in each block. This generates
an input matrix of 8x8 where each element is an integer in the range
0..16. This reduces dimensionality and gives invariance to small
distortions.

For info on NIST preprocessing routines, see M. D. Garris, J. L. Blue, G.
T. Candela, D. L. Dimmick, J. Geist, P. J. Grother, S. A. Janet, and C.
L. Wilson, NIST Form-Based Handprint Recognition System, NISTIR 5469,
1994.

.. topic:: References

  - C. Kaynak (1995) Methods of Combining Multiple Classifiers and Their
    Applications to Handwritten Digit Recognition, MSc Thesis, Institute of
    Graduate Studies in Science and Engineering, Bogazici University.
  - E. Alpaydin, C. Kaynak (1998) Cascading Classifiers, Kybernetika.
  - Ken Tang and Ponnuthurai N. Suganthan and Xi Yao and A. Kai Qin.
    Linear dimensionalityreduction using relevance weighted LDA. School of
    Electrical and Electronic Engineering Nanyang Technological University.
    2005.
  - Claudio Gentile. A New Approximate Maximal Margin Classification
    Algorithm. NIPS. 2000.
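
The description above says each 8x8 image was produced by splitting a 32x32 binary bitmap into non-overlapping 4x4 blocks and counting the “on” pixels in each block. Here is a minimal sketch of that reduction on a made-up bitmap (the random bitmap below is purely illustrative and is not part of the dataset):

# Illustrative only: a fake 32x32 binary bitmap (1 = "on" pixel).
rng = np.random.default_rng(0)
bitmap = (rng.random((32, 32)) < 0.3).astype(int)

# Split into non-overlapping 4x4 blocks and count the "on" pixels in each,
# giving an 8x8 matrix of integers between 0 and 16.
blocks = bitmap.reshape(8, 4, 8, 4)
features = blocks.sum(axis=(1, 3))
print(features.shape, features.max())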

The images are in a 3D array:

print(digits.images.shape)
(1797, 8, 8)

Each entry along the first dimension of this array is an 8x8 image (which is just a matrix). Here is the first image printed as raw numbers:

print(digits.images[0])
[[ 0.  0.  5. 13.  9.  1.  0.  0.]
 [ 0.  0. 13. 15. 10. 15.  5.  0.]
 [ 0.  3. 15.  2.  0. 11.  8.  0.]
 [ 0.  4. 12.  0.  0.  8.  8.  0.]
 [ 0.  5.  8.  0.  0.  9.  8.  0.]
 [ 0.  4. 11.  0.  1. 12.  7.  0.]
 [ 0.  2. 14.  5. 10. 12.  0.  0.]
 [ 0.  0.  6. 13. 10.  0.  0.  0.]]

These numbers correspond to the darkness of each pixel: the greater the value, the darker the pixel. Here is how we can visualize the first image:

fig, ax = plt.subplots()
ax.imshow(
    digits.images[0],
    cmap=plt.cm.gray_r,
    interpolation='nearest'
);
../_images/40ecc4d11de01a727475973c3ca4ba867fda4b6af329c3d9a82572f792e7145c.svg

That’s clearly a 0. Each of the images comes with a predetermined label that we can use to train models. Here is where you can find the labels:

print(digits.target)
[0 1 2 ... 8 9 8]

and notice that the first label is indeed a 0, matching the image above. Let’s now plot several images to gain some intuition about them:

fig, axes = plt.subplots(4, 4)
images_and_labels = list(zip(digits.images, digits.target))
for ax, (image, label) in zip(
    axes.flatten(),
    images_and_labels[:16]
):
    ax.set_axis_off()
    ax.imshow(
        image,
        cmap=plt.cm.gray_r,
        interpolation='nearest'
    )
    ax.set_title(f'Training: {label}')
plt.tight_layout()
sns.despine(trim=True);
../_images/83cb4da3a582f702d234153282f4255516faac66f2d5cd30ee8d8c736de9215e.svg

We will apply the multi-class logistic regression classifier with 64 linear features, one per pixel. First, we vectorize the images, turning them from \(8\times 8\) matrices into \(64\)-dimensional vectors. Here is how we can do this:

n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))
print(data.shape)
(1797, 64)
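
As a quick check (an addition to the original), the first row of data is just the first image flattened row by row:

# The reshape above flattens each image in row-major order.
print(np.allclose(data[0], digits.images[0].ravel()))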

Let’s split the dataset into training and test sets, using scikit-learn’s built-in functionality:

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    data,
    digits.target,
    test_size=0.5,
    shuffle=True
)
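
Note that with shuffle=True and no fixed seed, the split (and hence the exact numbers reported below) will differ from run to run. If you want a reproducible split, you can pass a fixed random_state (the value 42 below is an arbitrary choice, not part of the original example):

X_train, X_test, y_train, y_test = train_test_split(
    data,
    digits.target,
    test_size=0.5,
    shuffle=True,
    random_state=42  # any fixed integer makes the split reproducible
)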

The model we are going to fit is:

\[ p(y=k|\mathbf{x}, \mathbf{W}) = \operatorname{softmax}_k\left(\mathbf{w}_1^T\mathbf{x},\dots,\mathbf{w}_K^T\mathbf{x}\right), \]

where \(\mathbf{x}\) is the vectorized version of the image. Let’s do it:

from sklearn.linear_model import LogisticRegression

model = LogisticRegression(
    max_iter=2000,
    penalty=None,
    fit_intercept=True
)
model.fit(X_train, y_train);

Here is how you can get the matrix of all weights \(\mathbf{W}\):

print(model.coef_.shape)
(10, 64)
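
To connect this back to the softmax formula above, here is a minimal sketch (an addition to the original) that reproduces the predicted class probabilities for a single test image by hand, using the fitted weights and intercepts (the model has intercepts because fit_intercept=True):

# Reproduce p(y=k|x, W) for the first test image by hand.
x = X_test[0]
logits = model.coef_ @ x + model.intercept_   # w_k^T x + b_k for each class k
probs = np.exp(logits - logits.max())         # subtract the max for numerical stability
probs /= probs.sum()                          # softmax normalization

# Should print True with the multinomial formulation scikit-learn uses here.
print(np.allclose(probs, model.predict_proba(x[None, :])[0]))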

Here are the point predictions (picking the label with the highest probability for each test image):

predicted = model.predict(X_test)
print('#\tTruth\tPrediction')
print('-' * 26)
for i, (yt, yp) in enumerate(zip(y_test, predicted)):
    print(f'{i}\t{yt}\t{yp}')
#	Truth	Prediction
--------------------------
0	6	6
1	6	6
2	7	7
3	2	2
4	1	1
5	8	9
6	0	0
7	7	7
8	5	5
9	1	1
10	7	7
11	5	5
12	8	8
13	6	6
14	5	6
15	1	1
16	7	7
17	0	0
18	0	0
19	7	7
20	2	2
21	0	0
22	1	1
23	7	7
24	1	1
25	8	6
26	2	2
27	6	6
28	4	4
29	4	4
30	4	4
31	7	7
32	5	5
33	8	8
34	6	6
35	5	5
36	8	8
37	4	4
38	0	0
39	6	6
40	5	5
41	3	3
42	1	4
43	4	4
44	7	7
45	0	0
46	9	3
47	6	6
48	5	5
49	7	7
50	2	2
51	1	1
52	1	1
53	8	8
54	8	8
55	6	6
56	0	0
57	5	5
58	9	9
59	7	7
60	5	5
61	8	1
62	6	6
63	1	1
64	1	1
65	7	7
66	3	3
67	8	8
68	2	2
69	3	3
70	0	0
71	9	9
72	5	5
73	4	4
74	3	3
75	5	5
76	0	0
77	6	6
78	5	5
79	1	1
80	3	3
81	4	4
82	1	1
83	1	1
84	2	2
85	5	5
86	8	8
87	5	5
88	2	2
89	1	1
90	8	1
91	8	8
92	6	6
93	0	0
94	1	1
95	8	8
96	5	5
97	1	1
98	6	6
99	1	1
100	7	7
101	2	2
102	7	7
103	9	9
104	2	2
105	4	4
106	7	7
107	1	1
108	5	5
109	6	6
110	5	5
111	3	3
112	1	1
113	6	6
114	1	1
115	2	1
116	4	4
117	6	6
118	0	0
119	4	4
120	8	8
121	2	2
122	4	4
123	2	2
124	7	7
125	1	1
126	7	7
127	3	3
128	8	8
129	1	1
130	5	5
131	0	0
132	1	1
133	3	3
134	5	5
135	7	7
136	0	0
137	0	0
138	2	2
139	1	1
140	9	9
141	4	4
142	0	0
143	1	1
144	0	0
145	5	5
146	4	4
147	8	1
148	2	3
149	6	6
150	0	0
151	5	5
152	0	0
153	4	4
154	3	3
155	1	1
156	7	7
157	8	8
158	6	6
159	8	8
160	7	7
161	7	7
162	1	1
163	3	3
164	6	6
165	4	4
166	8	8
167	1	1
168	6	6
169	7	7
170	7	7
171	2	2
172	9	9
173	3	3
174	4	4
175	2	2
176	5	5
177	6	6
178	9	9
179	6	6
180	0	0
181	6	6
182	4	4
183	3	3
184	8	8
185	6	6
186	6	6
187	4	4
188	4	4
189	7	7
190	3	3
191	7	7
192	2	2
193	1	1
194	6	6
195	6	6
196	8	8
197	7	7
198	5	5
199	7	7
200	1	1
201	7	7
202	9	9
203	9	9
204	2	2
205	5	5
206	6	6
207	1	1
208	9	9
209	6	6
210	6	6
211	2	2
212	3	3
213	5	5
214	7	7
215	2	2
216	9	9
217	2	2
218	3	3
219	3	3
220	7	7
221	1	1
222	0	0
223	0	0
224	4	4
225	3	3
226	8	8
227	3	3
228	4	4
229	5	5
230	7	7
231	1	1
232	5	5
233	1	1
234	9	9
235	8	8
236	2	1
237	6	6
238	2	2
239	8	1
240	2	2
241	1	1
242	1	1
243	8	8
244	5	5
245	7	7
246	7	4
247	0	0
248	9	9
249	1	1
250	7	7
251	0	0
252	2	2
253	3	3
254	1	1
255	8	6
256	0	0
257	9	9
258	3	3
259	7	7
260	4	7
261	5	5
262	2	2
263	9	9
264	4	4
265	7	7
266	2	2
267	5	5
268	9	9
269	2	2
270	1	1
271	2	2
272	8	8
273	3	3
274	7	7
275	3	3
276	4	4
277	7	7
278	0	0
279	8	8
280	0	0
281	2	2
282	1	1
283	1	1
284	0	0
285	1	1
286	2	2
287	1	1
288	9	9
289	0	0
290	2	2
291	2	2
292	8	8
293	5	5
294	1	9
295	6	6
296	4	4
297	4	4
298	0	0
299	9	9
300	5	3
301	6	6
302	7	7
303	8	2
304	3	3
305	7	1
306	8	8
307	7	7
308	5	5
309	0	0
310	9	9
311	5	5
312	0	0
313	3	3
314	6	6
315	3	3
316	0	0
317	8	8
318	3	3
319	3	3
320	4	4
321	9	9
322	7	7
323	2	2
324	4	4
325	4	4
326	0	0
327	8	8
328	2	2
329	9	9
330	3	3
331	6	6
332	9	5
333	6	6
334	5	5
335	7	7
336	8	8
337	5	5
338	5	5
339	7	7
340	9	9
341	0	0
342	5	5
343	8	6
344	3	3
345	0	0
346	5	9
347	9	9
348	0	0
349	8	8
350	8	6
351	0	0
352	0	0
353	3	3
354	5	5
355	1	1
356	9	9
357	6	6
358	9	9
359	4	4
360	0	0
361	4	4
362	5	5
363	4	4
364	9	9
365	4	4
366	3	3
367	4	4
368	9	9
369	9	9
370	6	6
371	9	9
372	5	5
373	9	9
374	0	0
375	8	1
376	9	9
377	6	6
378	7	7
379	0	0
380	9	9
381	7	7
382	3	3
383	9	9
384	5	5
385	3	3
386	5	5
387	6	6
388	9	9
389	2	2
390	3	3
391	7	7
392	1	1
393	7	7
394	0	0
395	7	7
396	0	0
397	3	3
398	3	3
399	0	0
400	0	0
401	1	1
402	0	0
403	4	4
404	5	5
405	1	1
406	6	6
407	8	8
408	3	3
409	0	0
410	8	8
411	4	4
412	8	8
413	4	4
414	1	1
415	2	1
416	0	0
417	4	4
418	3	3
419	3	3
420	3	3
421	6	6
422	1	1
423	6	6
424	2	2
425	5	5
426	2	2
427	5	5
428	2	2
429	1	1
430	9	9
431	7	7
432	7	7
433	0	0
434	2	2
435	6	6
436	8	8
437	9	9
438	5	5
439	3	3
440	9	9
441	6	6
442	2	2
443	6	6
444	6	6
445	6	6
446	9	9
447	2	1
448	6	6
449	9	9
450	3	3
451	7	7
452	7	7
453	8	6
454	4	4
455	3	3
456	8	8
457	4	4
458	1	1
459	8	8
460	6	6
461	9	9
462	3	3
463	1	1
464	5	5
465	8	8
466	0	0
467	3	3
468	3	3
469	9	9
470	8	8
471	7	7
472	3	3
473	7	7
474	8	1
475	0	0
476	6	6
477	0	0
478	0	0
479	8	8
480	3	3
481	8	8
482	6	6
483	3	3
484	4	4
485	0	0
486	5	5
487	9	9
488	9	9
489	1	1
490	6	6
491	8	8
492	1	1
493	5	5
494	8	8
495	8	8
496	1	1
497	1	1
498	0	0
499	6	6
500	7	7
501	4	4
502	8	8
503	0	0
504	0	0
505	0	0
506	8	6
507	2	2
508	4	4
509	8	8
510	4	4
511	0	0
512	0	0
513	7	7
514	8	8
515	9	8
516	5	5
517	8	8
518	6	6
519	5	5
520	3	3
521	9	9
522	8	8
523	5	5
524	7	7
525	0	0
526	2	2
527	5	5
528	5	5
529	3	8
530	9	9
531	5	5
532	8	1
533	5	5
534	2	2
535	1	1
536	6	6
537	6	6
538	1	1
539	1	1
540	6	6
541	7	7
542	9	9
543	0	0
544	1	1
545	5	5
546	7	7
547	3	3
548	2	2
549	2	2
550	7	7
551	9	9
552	6	6
553	7	7
554	4	4
555	7	7
556	2	2
557	6	6
558	9	9
559	5	5
560	5	5
561	3	3
562	9	9
563	9	9
564	0	0
565	6	6
566	3	3
567	5	5
568	6	6
569	7	7
570	3	3
571	8	8
572	3	3
573	8	8
574	3	3
575	9	9
576	5	5
577	2	2
578	8	8
579	2	2
580	1	1
581	1	1
582	6	6
583	8	8
584	2	2
585	6	6
586	8	8
587	7	7
588	9	9
589	7	7
590	9	9
591	6	6
592	0	0
593	8	8
594	4	4
595	8	8
596	5	5
597	8	8
598	2	2
599	7	7
600	3	3
601	7	7
602	3	3
603	5	5
604	1	1
605	3	3
606	2	2
607	0	0
608	7	7
609	6	6
610	1	1
611	5	5
612	3	3
613	0	0
614	8	8
615	3	3
616	1	1
617	4	4
618	0	0
619	3	3
620	3	3
621	3	3
622	8	8
623	7	7
624	2	2
625	6	6
626	9	9
627	1	1
628	9	9
629	2	2
630	8	8
631	2	2
632	7	7
633	6	6
634	9	9
635	6	6
636	7	7
637	2	2
638	0	0
639	0	0
640	2	2
641	6	6
642	8	8
643	3	3
644	8	8
645	1	1
646	3	3
647	9	9
648	4	4
649	2	2
650	9	8
651	9	8
652	3	3
653	9	9
654	1	1
655	2	2
656	5	7
657	4	4
658	8	8
659	0	0
660	3	3
661	3	3
662	0	0
663	1	1
664	0	0
665	4	4
666	4	4
667	0	0
668	4	4
669	7	7
670	0	0
671	9	9
672	4	4
673	6	6
674	1	1
675	5	5
676	5	9
677	6	6
678	6	6
679	8	8
680	3	3
681	9	9
682	4	4
683	3	3
684	9	9
685	1	1
686	3	3
687	3	3
688	8	8
689	6	6
690	7	7
691	3	8
692	2	1
693	5	5
694	2	2
695	2	1
696	6	6
697	6	6
698	7	7
699	3	3
700	0	0
701	8	8
702	8	8
703	0	0
704	7	7
705	0	0
706	6	6
707	4	4
708	0	0
709	3	3
710	1	1
711	1	1
712	7	7
713	9	9
714	5	5
715	4	4
716	4	4
717	2	2
718	4	4
719	8	1
720	1	1
721	8	8
722	7	3
723	5	5
724	5	5
725	2	2
726	3	3
727	6	6
728	3	3
729	0	0
730	7	7
731	5	5
732	6	6
733	7	4
734	0	0
735	2	2
736	0	0
737	5	5
738	5	5
739	5	5
740	2	2
741	6	6
742	1	1
743	7	7
744	1	1
745	7	7
746	9	9
747	2	2
748	4	4
749	8	6
750	5	5
751	4	4
752	4	4
753	9	9
754	0	0
755	5	5
756	5	5
757	1	1
758	6	6
759	8	8
760	9	9
761	7	7
762	6	6
763	8	8
764	4	4
765	6	6
766	1	1
767	4	4
768	5	5
769	9	9
770	9	9
771	4	4
772	4	4
773	0	0
774	3	3
775	8	8
776	8	8
777	2	2
778	4	4
779	7	7
780	9	9
781	9	9
782	8	8
783	8	8
784	5	5
785	9	9
786	9	9
787	4	4
788	3	3
789	2	3
790	2	2
791	9	9
792	9	9
793	8	5
794	9	6
795	3	3
796	5	5
797	2	2
798	9	9
799	4	4
800	2	2
801	8	8
802	8	8
803	4	4
804	6	6
805	5	5
806	8	1
807	6	6
808	6	6
809	3	3
810	9	9
811	2	2
812	7	7
813	0	0
814	4	4
815	7	7
816	9	9
817	5	5
818	2	2
819	8	8
820	3	3
821	2	2
822	3	3
823	7	7
824	1	1
825	2	2
826	0	0
827	0	0
828	4	4
829	7	7
830	1	1
831	6	6
832	5	5
833	4	4
834	3	3
835	4	4
836	4	4
837	4	4
838	3	3
839	8	8
840	3	3
841	9	9
842	3	3
843	5	5
844	0	0
845	8	8
846	5	5
847	4	4
848	9	9
849	2	2
850	0	0
851	6	6
852	9	9
853	8	8
854	4	4
855	9	9
856	0	0
857	0	0
858	4	4
859	1	1
860	1	1
861	2	2
862	4	4
863	1	1
864	0	0
865	9	9
866	9	6
867	8	8
868	0	0
869	7	7
870	2	2
871	0	0
872	7	7
873	1	1
874	6	6
875	9	9
876	1	1
877	4	4
878	8	8
879	4	4
880	8	8
881	6	6
882	6	6
883	6	6
884	4	4
885	2	2
886	8	8
887	7	7
888	6	6
889	7	7
890	6	6
891	0	0
892	5	5
893	1	1
894	5	5
895	0	0
896	7	7
897	4	4
898	2	2
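
A quick way to count how many of these point predictions are wrong (this snippet is an addition; it uses predicted and y_test from above):

# Count the mismatches between the predicted and true labels on the test set.
n_wrong = (predicted != y_test).sum()
print(f'{n_wrong} wrong predictions out of {len(y_test)} test images')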

But we can also make probabilistic predictions:

prob_predict = model.predict_proba(X_test)

fig, axes = plt.subplots(10, 2)
for i in range(10):
    axes[i, 0].imshow(
        X_test[i].reshape((8, 8)),
        cmap=plt.cm.gray_r,
        interpolation='nearest'
    )
    axes[i, 0].set_yticks([])
    axes[i, 0].set_xticks([])
    axes[i, 1].set_xticks([])
    axes[i, 1].bar(
        np.arange(10),
        prob_predict[i, :]
    )
    axes[i, 1].set_yticks([])
axes[-1, 1].set_xticks(np.arange(10))
axes[-1, 1].set_xticklabels(model.classes_)
sns.despine(trim=True);
../_images/75eb981810ddb48122413afe312ec83cdaf2bce72c108ae9aa58417d290ea03b.svg

Scikit-learn can compute many classification metrics at once. Here is the full classification report, along with the confusion matrix:

from sklearn import metrics
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

print(f"Classification report for model {model}")
print(
    metrics.classification_report(y_test, predicted)
)

cm = confusion_matrix(
    y_test,
    predicted,
    labels=model.classes_
)

disp = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=model.classes_
)
disp.plot();
Classification report for model LogisticRegression(max_iter=2000, penalty=None)
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        95
           1       0.84      0.98      0.90        87
           2       0.99      0.90      0.94        83
           3       0.95      0.98      0.96        89
           4       0.96      0.99      0.98        83
           5       0.98      0.94      0.96        90
           6       0.90      1.00      0.95        93
           7       0.98      0.96      0.97        92
           8       0.94      0.81      0.87        99
           9       0.95      0.92      0.94        88

    accuracy                           0.95       899
   macro avg       0.95      0.95      0.95       899
weighted avg       0.95      0.95      0.95       899
../_images/72869ed52b1f11b53922d17c60c61895d6ce6ce58b1ced0f34d6a4f3233ad4fb.svg
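
As a sanity check (an addition to the original), the overall accuracy reported above is simply the fraction of test points that land on the diagonal of the confusion matrix:

# The diagonal of cm counts the correctly classified test images.
print(np.trace(cm) / cm.sum())  # should match the accuracy in the report above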

Questions#

  • Look at the confusion matrix carefully and identify the digits for which the most mistakes are made. Why does this happen? Write code to visualize some of the wrong predictions.