{
"cells": [
{
"cell_type": "markdown",
"id": "b8163ca1-e156-485b-b60c-023914d65d56",
"metadata": {},
"source": [
"# Credal Classification\n",
"\n",
"First, we need to introduce a small tolerance for comparing floating-point values, to account for numerical approximation in the code. Here you can set the global value:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e004caa3-9abb-40f9-8c2e-93523d1a8d07",
"metadata": {},
"outputs": [],
"source": [
"TOL = 1e-6"
]
},
{
"cell_type": "markdown",
"id": "7327acc6-81cd-44e3-9878-827b054065e2",
"metadata": {
"jp-MarkdownHeadingCollapsed": true
},
"source": [
"## Data Set\n",
"\n",
"To demonstrate credal classification, we will use the [mammographic mass dataset](http://archive.ics.uci.edu/ml/datasets/mammographic+mass) from the UCI Machine Learning Repository, which concerns breast cancer screening. If your environment allows it, collapse the next cell for easier navigation."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "17b17f15-ce94-42dd-a775-5f5da860771b",
"metadata": {},
"outputs": [],
"source": [
"csv_file = \"\"\"5,67,3,5,3,1\n",
"4,43,1,1,?,1\n",
"5,58,4,5,3,1\n",
"4,28,1,1,3,0\n",
"5,74,1,5,?,1\n",
"4,65,1,?,3,0\n",
"4,70,?,?,3,0\n",
"5,42,1,?,3,0\n",
"5,57,1,5,3,1\n",
"5,60,?,5,1,1\n",
"5,76,1,4,3,1\n",
"3,42,2,1,3,1\n",
"4,64,1,?,3,0\n",
"4,36,3,1,2,0\n",
"4,60,2,1,2,0\n",
"4,54,1,1,3,0\n",
"3,52,3,4,3,0\n",
"4,59,2,1,3,1\n",
"4,54,1,1,3,1\n",
"4,40,1,?,?,0\n",
"?,66,?,?,1,1\n",
"5,56,4,3,1,1\n",
"4,43,1,?,?,0\n",
"5,42,4,4,3,1\n",
"4,59,2,4,3,1\n",
"5,75,4,5,3,1\n",
"2,66,1,1,?,0\n",
"5,63,3,?,3,0\n",
"5,45,4,5,3,1\n",
"5,55,4,4,3,0\n",
"4,46,1,5,2,0\n",
"5,54,4,4,3,1\n",
"5,57,4,4,3,1\n",
"4,39,1,1,2,0\n",
"4,81,1,1,3,0\n",
"4,77,3,?,?,0\n",
"4,60,2,1,3,0\n",
"5,67,3,4,2,1\n",
"4,48,4,5,?,1\n",
"4,55,3,4,2,0\n",
"4,59,2,1,?,0\n",
"4,78,1,1,1,0\n",
"4,50,1,1,3,0\n",
"4,61,2,1,?,0\n",
"5,62,3,5,2,1\n",
"5,44,2,4,?,1\n",
"5,64,4,5,3,1\n",
"4,23,1,1,?,0\n",
"2,42,?,?,4,0\n",
"5,67,4,5,3,1\n",
"4,74,2,1,2,0\n",
"5,80,3,5,3,1\n",
"4,23,1,1,?,0\n",
"4,63,2,1,?,0\n",
"4,53,?,5,3,1\n",
"4,43,3,4,?,0\n",
"4,49,2,1,1,0\n",
"5,51,2,4,?,0\n",
"4,45,2,1,?,0\n",
"5,59,2,?,?,1\n",
"5,52,4,3,3,1\n",
"5,60,4,3,3,1\n",
"4,57,2,5,3,0\n",
"3,57,2,1,?,0\n",
"5,74,4,4,3,1\n",
"4,25,2,1,?,0\n",
"4,49,1,1,3,0\n",
"5,72,4,3,?,1\n",
"4,45,2,1,3,0\n",
"4,64,2,1,3,0\n",
"4,73,2,1,2,0\n",
"5,68,4,3,3,1\n",
"5,52,4,5,3,0\n",
"5,66,4,4,3,1\n",
"5,70,?,4,?,1\n",
"4,25,1,1,3,0\n",
"5,74,1,1,2,1\n",
"4,64,1,1,3,0\n",
"5,60,4,3,2,1\n",
"5,67,2,4,1,0\n",
"4,67,4,5,3,0\n",
"5,44,4,4,2,1\n",
"3,68,1,1,3,1\n",
"4,57,?,4,1,0\n",
"5,51,4,?,?,1\n",
"4,33,1,?,?,0\n",
"5,58,4,4,3,1\n",
"5,36,1,?,?,0\n",
"4,63,1,1,?,0\n",
"5,62,1,5,3,1\n",
"4,73,3,4,3,1\n",
"4,80,4,4,3,1\n",
"4,67,1,1,?,0\n",
"5,59,2,1,3,1\n",
"5,60,1,?,3,0\n",
"5,54,4,4,3,1\n",
"4,40,1,1,?,0\n",
"4,47,2,1,?,0\n",
"5,62,4,4,3,0\n",
"4,33,2,1,3,0\n",
"5,59,2,?,?,0\n",
"4,65,2,?,?,0\n",
"4,58,4,4,?,0\n",
"4,29,2,?,?,0\n",
"4,58,1,1,?,0\n",
"4,54,1,1,?,0\n",
"4,44,1,1,?,1\n",
"3,34,2,1,?,0\n",
"4,57,1,1,3,0\n",
"5,33,4,4,?,1\n",
"4,45,4,4,3,0\n",
"5,71,4,4,3,1\n",
"5,59,4,4,2,0\n",
"4,56,2,1,?,0\n",
"4,40,3,4,?,0\n",
"4,56,1,1,3,0\n",
"4,45,2,1,?,0\n",
"4,57,2,1,2,0\n",
"5,55,3,4,3,1\n",
"5,84,4,5,3,0\n",
"5,51,4,4,3,1\n",
"4,43,1,1,?,0\n",
"4,24,2,1,2,0\n",
"4,66,1,1,3,0\n",
"5,33,4,4,3,0\n",
"4,59,4,3,2,0\n",
"4,76,2,3,?,0\n",
"4,40,1,1,?,0\n",
"4,52,?,4,?,0\n",
"5,40,4,5,3,1\n",
"5,67,4,4,3,1\n",
"5,75,4,3,3,1\n",
"5,86,4,4,3,0\n",
"4,60,2,?,?,0\n",
"5,66,4,4,3,1\n",
"5,46,4,5,3,1\n",
"4,59,4,4,3,1\n",
"5,65,4,4,3,1\n",
"4,53,1,1,3,0\n",
"5,67,3,5,3,1\n",
"5,80,4,5,3,1\n",
"4,55,2,1,3,0\n",
"4,48,1,1,?,0\n",
"4,47,1,1,2,0\n",
"4,50,2,1,?,0\n",
"5,62,4,5,3,1\n",
"5,63,4,4,3,1\n",
"4,63,4,?,3,1\n",
"4,71,4,4,3,1\n",
"4,41,1,1,3,0\n",
"5,57,4,4,4,1\n",
"5,71,4,4,4,1\n",
"4,66,1,1,3,0\n",
"4,47,2,4,2,0\n",
"3,34,4,4,3,0\n",
"4,59,3,4,3,0\n",
"5,55,2,?,?,1\n",
"4,51,?,?,3,0\n",
"4,62,2,1,?,0\n",
"4,58,4,?,3,1\n",
"5,67,4,4,3,1\n",
"4,41,2,1,3,0\n",
"4,23,3,1,3,0\n",
"4,53,?,4,3,0\n",
"4,42,2,1,3,0\n",
"5,87,4,5,3,1\n",
"4,68,1,1,3,1\n",
"4,64,1,1,3,0\n",
"5,54,3,5,3,1\n",
"5,86,4,5,3,1\n",
"4,21,2,1,3,0\n",
"4,39,1,1,?,0\n",
"4,53,4,4,3,0\n",
"4,44,4,4,3,0\n",
"4,54,1,1,3,0\n",
"5,63,4,5,3,1\n",
"4,62,2,1,?,0\n",
"4,45,2,1,2,0\n",
"5,71,4,5,3,0\n",
"5,49,4,4,3,1\n",
"4,49,4,4,3,0\n",
"5,66,4,4,4,0\n",
"4,19,1,1,3,0\n",
"4,35,1,1,2,0\n",
"4,71,3,3,?,1\n",
"5,74,4,5,3,1\n",
"5,37,4,4,3,1\n",
"4,67,1,?,3,0\n",
"5,81,3,4,3,1\n",
"5,59,4,4,3,1\n",
"4,34,1,1,3,0\n",
"5,79,4,3,3,1\n",
"5,60,3,1,3,0\n",
"4,41,1,1,3,1\n",
"4,50,1,1,3,0\n",
"5,85,4,4,3,1\n",
"4,46,1,1,3,0\n",
"5,66,4,4,3,1\n",
"4,73,3,1,2,0\n",
"4,55,1,1,3,0\n",
"4,49,2,1,3,0\n",
"3,49,4,4,3,0\n",
"4,51,4,5,3,1\n",
"2,48,4,4,3,0\n",
"4,58,4,5,3,0\n",
"5,72,4,5,3,1\n",
"4,46,2,3,3,0\n",
"4,43,4,3,3,1\n",
"?,52,4,4,3,0\n",
"4,66,2,1,?,0\n",
"4,46,1,1,1,0\n",
"4,69,3,1,3,0\n",
"2,59,1,1,?,1\n",
"5,43,2,1,3,1\n",
"5,76,4,5,3,1\n",
"4,46,1,1,3,0\n",
"4,59,2,4,3,0\n",
"4,57,1,1,3,0\n",
"5,43,4,5,?,0\n",
"3,45,2,1,3,0\n",
"3,43,2,1,3,0\n",
"4,45,2,1,3,0\n",
"5,57,4,5,3,1\n",
"5,79,4,4,3,1\n",
"5,54,2,1,3,1\n",
"4,40,3,4,3,0\n",
"5,63,4,4,3,1\n",
"2,55,1,?,1,0\n",
"4,52,2,1,3,0\n",
"4,38,1,1,3,0\n",
"3,72,4,3,3,0\n",
"5,80,4,3,3,1\n",
"5,76,4,3,3,1\n",
"4,62,3,1,3,0\n",
"5,64,4,5,3,1\n",
"5,42,4,5,3,0\n",
"3,60,?,3,1,0\n",
"4,64,4,5,3,0\n",
"4,63,4,4,3,1\n",
"4,24,2,1,2,0\n",
"5,72,4,4,3,1\n",
"4,63,2,1,3,0\n",
"4,46,1,1,3,0\n",
"3,33,1,1,3,0\n",
"5,76,4,4,3,1\n",
"4,36,2,3,3,0\n",
"4,40,2,1,3,0\n",
"5,58,1,5,3,1\n",
"4,43,2,1,3,0\n",
"3,42,1,1,3,0\n",
"4,32,1,1,3,0\n",
"5,57,4,4,2,1\n",
"4,37,1,1,3,0\n",
"4,70,4,4,3,1\n",
"5,56,4,2,3,1\n",
"3,76,?,3,2,0\n",
"5,73,4,4,3,1\n",
"5,77,4,5,3,1\n",
"5,67,4,4,1,1\n",
"5,71,4,3,3,1\n",
"5,65,4,4,3,1\n",
"4,43,1,1,3,0\n",
"4,40,2,1,?,0\n",
"4,49,2,1,3,0\n",
"5,76,4,2,3,1\n",
"4,55,4,4,3,0\n",
"5,72,4,5,3,1\n",
"3,53,4,3,3,0\n",
"5,75,4,4,3,1\n",
"5,61,4,5,3,1\n",
"5,67,4,4,3,1\n",
"5,55,4,2,3,1\n",
"5,66,4,4,3,1\n",
"2,76,1,1,2,0\n",
"4,57,4,4,3,1\n",
"5,71,3,1,3,0\n",
"5,70,4,5,3,1\n",
"4,35,4,2,?,0\n",
"5,79,1,?,3,1\n",
"4,63,2,1,3,0\n",
"5,40,1,4,3,1\n",
"4,41,1,1,3,0\n",
"4,47,2,1,2,0\n",
"4,68,1,1,3,1\n",
"4,64,4,3,3,1\n",
"4,65,4,4,?,1\n",
"4,73,4,3,3,0\n",
"4,39,4,3,3,0\n",
"5,55,4,5,4,1\n",
"5,53,3,4,4,0\n",
"5,66,4,4,3,1\n",
"4,43,3,1,2,0\n",
"5,44,4,5,3,1\n",
"4,77,4,4,3,1\n",
"4,62,2,4,3,0\n",
"5,80,4,4,3,1\n",
"4,33,4,4,3,0\n",
"4,50,4,5,3,1\n",
"4,71,1,?,3,0\n",
"5,46,4,4,3,1\n",
"5,49,4,5,3,1\n",
"4,53,1,1,3,0\n",
"3,46,2,1,2,0\n",
"4,57,1,1,3,0\n",
"4,54,3,1,3,0\n",
"4,54,1,?,?,0\n",
"2,49,2,1,2,0\n",
"4,47,3,1,3,0\n",
"4,40,1,1,3,0\n",
"4,45,1,1,3,0\n",
"4,50,4,5,3,1\n",
"5,54,4,4,3,1\n",
"4,67,4,1,3,1\n",
"4,77,4,4,3,1\n",
"4,66,4,3,3,0\n",
"4,71,2,?,3,1\n",
"4,36,2,3,3,0\n",
"4,69,4,4,3,0\n",
"4,48,1,1,3,0\n",
"4,64,4,4,3,1\n",
"4,71,4,2,3,1\n",
"5,60,4,3,3,1\n",
"4,24,1,1,3,0\n",
"5,34,4,5,2,1\n",
"4,79,1,1,2,0\n",
"4,45,1,1,3,0\n",
"4,37,2,1,2,0\n",
"4,42,1,1,2,0\n",
"4,72,4,4,3,1\n",
"5,60,4,5,3,1\n",
"5,85,3,5,3,1\n",
"4,51,1,1,3,0\n",
"5,54,4,5,3,1\n",
"5,55,4,3,3,1\n",
"4,64,4,4,3,0\n",
"5,67,4,5,3,1\n",
"5,75,4,3,3,1\n",
"5,87,4,4,3,1\n",
"4,46,4,4,3,1\n",
"4,59,2,1,?,0\n",
"55,46,4,3,3,1\n",
"5,61,1,1,3,1\n",
"4,44,1,4,3,0\n",
"4,32,1,1,3,0\n",
"4,62,1,1,3,0\n",
"5,59,4,5,3,1\n",
"4,61,4,1,3,0\n",
"5,78,4,4,3,1\n",
"5,42,4,5,3,0\n",
"4,45,1,2,3,0\n",
"5,34,2,1,3,1\n",
"5,39,4,3,?,1\n",
"4,27,3,1,3,0\n",
"4,43,1,1,3,0\n",
"5,83,4,4,3,1\n",
"4,36,2,1,3,0\n",
"4,37,2,1,3,0\n",
"4,56,3,1,3,1\n",
"5,55,4,4,3,1\n",
"5,46,3,?,3,0\n",
"4,88,4,4,3,1\n",
"5,71,4,4,3,1\n",
"4,41,2,1,3,0\n",
"5,49,4,4,3,1\n",
"3,51,1,1,4,0\n",
"4,39,1,3,3,0\n",
"4,46,2,1,3,0\n",
"5,52,4,4,3,1\n",
"5,58,4,4,3,1\n",
"4,67,4,5,3,1\n",
"5,80,4,4,3,1\n",
"3,46,1,?,?,0\n",
"3,43,1,?,?,0\n",
"4,45,1,1,3,0\n",
"5,68,4,4,3,1\n",
"4,54,4,4,?,1\n",
"4,44,2,3,3,0\n",
"5,74,4,3,3,1\n",
"5,55,4,5,3,0\n",
"4,49,4,4,3,1\n",
"4,49,1,1,3,0\n",
"5,50,4,3,3,1\n",
"5,52,3,5,3,1\n",
"4,45,1,1,3,0\n",
"4,66,1,1,3,0\n",
"4,68,4,4,3,1\n",
"4,72,2,1,3,0\n",
"5,64,?,?,3,0\n",
"2,49,?,3,3,0\n",
"3,44,?,4,3,0\n",
"5,74,4,4,3,1\n",
"5,58,4,4,3,1\n",
"4,77,2,3,3,0\n",
"4,49,3,1,3,0\n",
"4,34,?,?,4,0\n",
"5,60,4,3,3,1\n",
"5,69,4,3,3,1\n",
"4,53,2,1,3,0\n",
"3,46,3,4,3,0\n",
"5,74,4,4,3,1\n",
"4,58,1,1,3,0\n",
"5,68,4,4,3,1\n",
"5,46,4,3,3,0\n",
"5,61,2,4,3,1\n",
"5,70,4,3,3,1\n",
"5,37,4,4,3,1\n",
"3,65,4,5,3,1\n",
"4,67,4,4,3,0\n",
"5,69,3,4,3,0\n",
"5,76,4,4,3,1\n",
"4,65,4,3,3,0\n",
"5,72,4,2,3,1\n",
"4,62,4,2,3,0\n",
"5,42,4,4,3,1\n",
"5,66,4,3,3,1\n",
"5,48,4,4,3,1\n",
"4,35,1,1,3,0\n",
"5,60,4,4,3,1\n",
"5,67,4,2,3,1\n",
"5,78,4,4,3,1\n",
"4,66,1,1,3,1\n",
"4,26,1,1,?,0\n",
"4,48,1,1,3,0\n",
"4,31,1,1,3,0\n",
"5,43,4,3,3,1\n",
"5,72,2,4,3,0\n",
"5,66,1,1,3,1\n",
"4,56,4,4,3,0\n",
"5,58,4,5,3,1\n",
"5,33,2,4,3,1\n",
"4,37,1,1,3,0\n",
"5,36,4,3,3,1\n",
"4,39,2,3,3,0\n",
"4,39,4,4,3,1\n",
"5,83,4,4,3,1\n",
"4,68,4,5,3,1\n",
"5,63,3,4,3,1\n",
"5,78,4,4,3,1\n",
"4,38,2,3,3,0\n",
"5,46,4,3,3,1\n",
"5,60,4,4,3,1\n",
"5,56,2,3,3,1\n",
"4,33,1,1,3,0\n",
"4,?,4,5,3,1\n",
"4,69,1,5,3,1\n",
"5,66,1,4,3,1\n",
"4,72,1,3,3,0\n",
"4,29,1,1,3,0\n",
"5,54,4,5,3,1\n",
"5,80,4,4,3,1\n",
"5,68,4,3,3,1\n",
"4,35,2,1,3,0\n",
"4,57,3,?,3,0\n",
"5,?,4,4,3,1\n",
"4,50,1,1,3,0\n",
"4,32,4,3,3,0\n",
"0,69,4,5,3,1\n",
"4,71,4,5,3,1\n",
"5,87,4,5,3,1\n",
"3,40,2,?,3,0\n",
"4,31,1,1,?,0\n",
"4,64,1,1,3,0\n",
"5,55,4,5,3,1\n",
"4,18,1,1,3,0\n",
"3,50,2,1,?,0\n",
"4,53,1,1,3,0\n",
"5,84,4,5,3,1\n",
"5,80,4,3,3,1\n",
"4,32,1,1,3,0\n",
"5,77,3,4,3,1\n",
"4,38,1,1,3,0\n",
"5,54,4,5,3,1\n",
"4,63,1,1,3,0\n",
"4,61,1,1,3,0\n",
"4,52,1,1,3,0\n",
"4,36,1,1,3,0\n",
"4,41,?,?,3,0\n",
"4,59,1,1,3,0\n",
"5,51,4,4,2,1\n",
"4,36,1,1,3,0\n",
"5,40,4,3,3,1\n",
"4,49,1,1,3,0\n",
"4,37,2,3,3,0\n",
"4,46,1,1,3,0\n",
"4,63,1,1,3,0\n",
"4,28,2,1,3,0\n",
"4,47,2,1,3,0\n",
"4,42,2,1,3,1\n",
"5,44,4,5,3,1\n",
"4,49,4,4,3,0\n",
"5,47,4,5,3,1\n",
"5,52,4,5,3,1\n",
"4,53,1,1,3,1\n",
"5,83,3,3,3,1\n",
"4,50,4,4,?,1\n",
"5,63,4,4,3,1\n",
"4,82,?,5,3,1\n",
"4,54,1,1,3,0\n",
"4,50,4,4,3,0\n",
"5,80,4,5,3,1\n",
"5,45,2,4,3,0\n",
"5,59,4,4,?,1\n",
"4,28,2,1,3,0\n",
"4,31,1,1,3,0\n",
"4,41,2,1,3,0\n",
"4,21,3,1,3,0\n",
"5,44,3,4,3,1\n",
"5,49,4,4,3,1\n",
"5,71,4,5,3,1\n",
"5,75,4,5,3,1\n",
"4,38,2,1,3,0\n",
"4,60,1,3,3,0\n",
"5,87,4,5,3,1\n",
"4,70,4,4,3,1\n",
"5,55,4,5,3,1\n",
"3,21,1,1,3,0\n",
"4,50,1,1,3,0\n",
"5,76,4,5,3,1\n",
"4,23,1,1,3,0\n",
"3,68,?,?,3,0\n",
"4,62,4,?,3,1\n",
"5,65,1,?,3,1\n",
"5,73,4,5,3,1\n",
"4,38,2,3,3,0\n",
"2,57,1,1,3,0\n",
"5,65,4,5,3,1\n",
"5,67,2,4,3,1\n",
"5,61,2,4,3,1\n",
"5,56,4,4,3,0\n",
"5,71,2,4,3,1\n",
"4,49,2,2,3,0\n",
"4,55,?,?,3,0\n",
"4,44,2,1,3,0\n",
"0,58,4,4,3,0\n",
"4,27,2,1,3,0\n",
"5,73,4,5,3,1\n",
"4,34,2,1,3,0\n",
"5,63,?,4,3,1\n",
"4,50,2,1,3,1\n",
"4,62,2,1,3,0\n",
"3,21,3,1,3,0\n",
"4,49,2,?,3,0\n",
"4,36,3,1,3,0\n",
"4,45,2,1,3,1\n",
"5,67,4,5,3,1\n",
"4,21,1,1,3,0\n",
"4,57,2,1,3,0\n",
"5,66,4,5,3,1\n",
"4,71,4,4,3,1\n",
"5,69,3,4,3,1\n",
"6,80,4,5,3,1\n",
"3,27,2,1,3,0\n",
"4,38,2,1,3,0\n",
"4,23,2,1,3,0\n",
"5,70,?,5,3,1\n",
"4,46,4,3,3,0\n",
"4,61,2,3,3,0\n",
"5,65,4,5,3,1\n",
"4,60,4,3,3,0\n",
"5,83,4,5,3,1\n",
"5,40,4,4,3,1\n",
"2,59,?,4,3,0\n",
"4,53,3,4,3,0\n",
"4,76,4,4,3,0\n",
"5,79,1,4,3,1\n",
"5,38,2,4,3,1\n",
"4,61,3,4,3,0\n",
"4,56,2,1,3,0\n",
"4,44,2,1,3,0\n",
"4,64,3,4,?,1\n",
"4,66,3,3,3,0\n",
"4,50,3,3,3,0\n",
"4,46,1,1,3,0\n",
"4,39,1,1,3,0\n",
"4,60,3,?,?,0\n",
"5,55,4,5,3,1\n",
"4,40,2,1,3,0\n",
"4,26,1,1,3,0\n",
"5,84,3,2,3,1\n",
"4,41,2,2,3,0\n",
"4,63,1,1,3,0\n",
"2,65,?,1,2,0\n",
"4,49,1,1,3,0\n",
"4,56,2,2,3,1\n",
"5,65,4,4,3,0\n",
"4,54,1,1,3,0\n",
"4,36,1,1,3,0\n",
"5,49,4,4,3,0\n",
"4,59,4,4,3,1\n",
"5,75,4,4,3,1\n",
"5,59,4,2,3,0\n",
"5,59,4,4,3,1\n",
"4,28,4,4,3,1\n",
"5,53,4,5,3,0\n",
"5,57,4,4,3,0\n",
"5,77,4,3,4,0\n",
"5,85,4,3,3,1\n",
"4,59,4,4,3,0\n",
"5,59,1,5,3,1\n",
"4,65,3,3,3,1\n",
"4,54,2,1,3,0\n",
"5,46,4,5,3,1\n",
"4,63,4,4,3,1\n",
"4,53,1,1,3,1\n",
"4,56,1,1,3,0\n",
"5,66,4,4,3,1\n",
"5,66,4,5,3,1\n",
"4,55,1,1,3,0\n",
"4,44,1,1,3,0\n",
"5,86,3,4,3,1\n",
"5,47,4,5,3,1\n",
"5,59,4,5,3,1\n",
"5,66,4,5,3,0\n",
"5,61,4,3,3,1\n",
"3,46,?,5,?,1\n",
"4,69,1,1,3,0\n",
"5,93,1,5,3,1\n",
"4,39,1,3,3,0\n",
"5,44,4,5,3,1\n",
"4,45,2,2,3,0\n",
"4,51,3,4,3,0\n",
"4,56,2,4,3,0\n",
"4,66,4,4,3,0\n",
"5,61,4,5,3,1\n",
"4,64,3,3,3,1\n",
"5,57,2,4,3,0\n",
"5,79,4,4,3,1\n",
"4,57,2,1,?,0\n",
"4,44,4,1,1,0\n",
"4,31,2,1,3,0\n",
"4,63,4,4,3,0\n",
"4,64,1,1,3,0\n",
"5,47,4,5,3,0\n",
"5,68,4,5,3,1\n",
"4,30,1,1,3,0\n",
"5,43,4,5,3,1\n",
"4,56,1,1,3,0\n",
"4,46,2,1,3,0\n",
"4,67,2,1,3,0\n",
"5,52,4,5,3,1\n",
"4,67,4,4,3,1\n",
"4,47,2,1,3,0\n",
"5,58,4,5,3,1\n",
"4,28,2,1,3,0\n",
"4,43,1,1,3,0\n",
"4,57,2,4,3,0\n",
"5,68,4,5,3,1\n",
"4,64,2,4,3,0\n",
"4,64,2,4,3,0\n",
"5,62,4,4,3,1\n",
"4,38,4,1,3,0\n",
"5,68,4,4,3,1\n",
"4,41,2,1,3,0\n",
"4,35,2,1,3,1\n",
"4,68,2,1,3,0\n",
"5,55,4,4,3,1\n",
"5,67,4,4,3,1\n",
"4,51,4,3,3,0\n",
"2,40,1,1,3,0\n",
"5,73,4,4,3,1\n",
"4,58,?,4,3,1\n",
"4,51,?,4,3,0\n",
"3,50,?,?,3,1\n",
"5,59,4,3,3,1\n",
"6,60,3,5,3,1\n",
"4,27,2,1,?,0\n",
"5,54,4,3,3,0\n",
"4,56,1,1,3,0\n",
"5,53,4,5,3,1\n",
"4,54,2,4,3,0\n",
"5,79,1,4,3,1\n",
"5,67,4,3,3,1\n",
"5,64,3,3,3,1\n",
"4,70,1,2,3,1\n",
"5,55,4,3,3,1\n",
"5,65,3,3,3,1\n",
"5,45,4,2,3,1\n",
"4,57,4,4,?,1\n",
"5,49,1,1,3,1\n",
"4,24,2,1,3,0\n",
"4,52,1,1,3,0\n",
"4,50,2,1,3,0\n",
"4,35,1,1,3,0\n",
"5,?,3,3,3,1\n",
"5,64,4,3,3,1\n",
"5,40,4,1,1,1\n",
"5,66,4,4,3,1\n",
"4,64,4,4,3,1\n",
"5,52,4,3,3,1\n",
"5,43,1,4,3,1\n",
"4,56,4,4,3,0\n",
"4,72,3,?,3,0\n",
"6,51,4,4,3,1\n",
"4,79,4,4,3,1\n",
"4,22,2,1,3,0\n",
"4,73,2,1,3,0\n",
"4,53,3,4,3,0\n",
"4,59,2,1,3,1\n",
"4,46,4,4,2,0\n",
"5,66,4,4,3,1\n",
"4,50,4,3,3,1\n",
"4,58,1,1,3,1\n",
"4,55,1,1,3,0\n",
"4,62,2,4,3,1\n",
"4,60,1,1,3,0\n",
"5,57,4,3,3,1\n",
"4,57,1,1,3,0\n",
"6,41,2,1,3,0\n",
"4,71,2,1,3,1\n",
"4,32,2,1,3,0\n",
"4,57,2,1,3,0\n",
"4,19,1,1,3,0\n",
"4,62,2,4,3,1\n",
"5,67,4,5,3,1\n",
"4,50,4,5,3,0\n",
"4,65,2,3,2,0\n",
"4,40,2,4,2,0\n",
"6,71,4,4,3,1\n",
"6,68,4,3,3,1\n",
"4,68,1,1,3,0\n",
"4,29,1,1,3,0\n",
"4,53,2,1,3,0\n",
"5,66,4,4,3,1\n",
"4,60,3,?,4,0\n",
"5,76,4,4,3,1\n",
"4,58,2,1,2,0\n",
"5,96,3,4,3,1\n",
"5,70,4,4,3,1\n",
"4,34,2,1,3,0\n",
"4,59,2,1,3,0\n",
"4,45,3,1,3,1\n",
"5,65,4,4,3,1\n",
"4,59,1,1,3,0\n",
"4,21,2,1,3,0\n",
"3,43,2,1,3,0\n",
"4,53,1,1,3,0\n",
"4,65,2,1,3,0\n",
"4,64,2,4,3,1\n",
"4,53,4,4,3,0\n",
"4,51,1,1,3,0\n",
"4,59,2,4,3,0\n",
"4,56,2,1,3,0\n",
"4,60,2,1,3,0\n",
"4,22,1,1,3,0\n",
"4,25,2,1,3,0\n",
"6,76,3,?,3,0\n",
"5,69,4,4,3,1\n",
"4,58,2,1,3,0\n",
"5,62,4,3,3,1\n",
"4,56,4,4,3,0\n",
"4,64,1,1,3,0\n",
"4,32,2,1,3,0\n",
"5,48,?,4,?,1\n",
"5,59,4,4,2,1\n",
"4,52,1,1,3,0\n",
"4,63,4,4,3,0\n",
"5,67,4,4,3,1\n",
"5,61,4,4,3,1\n",
"5,59,4,5,3,1\n",
"5,52,4,3,3,1\n",
"4,35,4,4,3,0\n",
"5,77,3,3,3,1\n",
"5,71,4,3,3,1\n",
"5,63,4,3,3,1\n",
"4,38,2,1,2,0\n",
"5,72,4,3,3,1\n",
"4,76,4,3,3,1\n",
"4,53,3,3,3,0\n",
"4,67,4,5,3,0\n",
"5,69,2,4,3,1\n",
"4,54,1,1,3,0\n",
"2,35,2,1,2,0\n",
"5,68,4,3,3,1\n",
"4,68,4,4,3,0\n",
"4,67,2,4,3,1\n",
"3,39,1,1,3,0\n",
"4,44,2,1,3,0\n",
"4,33,1,1,3,0\n",
"4,60,?,4,3,0\n",
"4,58,1,1,3,0\n",
"4,31,1,1,3,0\n",
"3,23,1,1,3,0\n",
"5,56,4,5,3,1\n",
"4,69,2,1,3,1\n",
"6,63,1,1,3,0\n",
"4,65,1,1,3,1\n",
"4,44,2,1,2,0\n",
"4,62,3,3,3,1\n",
"4,67,4,4,3,1\n",
"4,56,2,1,3,0\n",
"4,52,3,4,3,0\n",
"4,43,1,1,3,1\n",
"4,41,4,3,2,1\n",
"4,42,3,4,2,0\n",
"3,46,1,1,3,0\n",
"5,55,4,4,3,1\n",
"5,58,4,4,2,1\n",
"5,87,4,4,3,1\n",
"4,66,2,1,3,0\n",
"0,72,4,3,3,1\n",
"5,60,4,3,3,1\n",
"5,83,4,4,2,1\n",
"4,31,2,1,3,0\n",
"4,53,2,1,3,0\n",
"4,64,2,3,3,0\n",
"5,31,4,4,2,1\n",
"5,62,4,4,2,1\n",
"4,56,2,1,3,0\n",
"5,58,4,4,3,1\n",
"4,67,1,4,3,0\n",
"5,75,4,5,3,1\n",
"5,65,3,4,3,1\n",
"5,74,3,2,3,1\n",
"4,59,2,1,3,0\n",
"4,57,4,4,4,1\n",
"4,76,3,2,3,0\n",
"4,63,1,4,3,0\n",
"4,44,1,1,3,0\n",
"4,42,3,1,2,0\n",
"4,35,3,?,2,0\n",
"5,65,4,3,3,1\n",
"4,70,2,1,3,0\n",
"4,48,1,1,3,0\n",
"4,74,1,1,1,1\n",
"6,40,?,3,4,1\n",
"4,63,1,1,3,0\n",
"5,60,4,4,3,1\n",
"5,86,4,3,3,1\n",
"4,27,1,1,3,0\n",
"4,71,4,5,2,1\n",
"5,85,4,4,3,1\n",
"4,51,3,3,3,0\n",
"6,72,4,3,3,1\n",
"5,52,4,4,3,1\n",
"4,66,2,1,3,0\n",
"5,71,4,5,3,1\n",
"4,42,2,1,3,0\n",
"4,64,4,4,2,1\n",
"4,41,2,2,3,0\n",
"4,50,2,1,3,0\n",
"4,30,1,1,3,0\n",
"4,67,1,1,3,0\n",
"5,62,4,4,3,1\n",
"4,46,2,1,2,0\n",
"4,35,1,1,3,0\n",
"4,53,1,1,2,0\n",
"4,59,2,1,3,0\n",
"4,19,3,1,3,0\n",
"5,86,2,1,3,1\n",
"4,72,2,1,3,0\n",
"4,37,2,1,2,0\n",
"4,46,3,1,3,1\n",
"4,45,1,1,3,0\n",
"4,48,4,5,3,0\n",
"4,58,4,4,3,1\n",
"4,42,1,1,3,0\n",
"4,56,2,4,3,1\n",
"4,47,2,1,3,0\n",
"4,49,4,4,3,1\n",
"5,76,2,5,3,1\n",
"5,62,4,5,3,1\n",
"5,64,4,4,3,1\n",
"5,53,4,3,3,1\n",
"4,70,4,2,2,1\n",
"5,55,4,4,3,1\n",
"4,34,4,4,3,0\n",
"5,76,4,4,3,1\n",
"4,39,1,1,3,0\n",
"2,23,1,1,3,0\n",
"4,19,1,1,3,0\n",
"5,65,4,5,3,1\n",
"4,57,2,1,3,0\n",
"5,41,4,4,3,1\n",
"4,36,4,5,3,1\n",
"4,62,3,3,3,0\n",
"4,69,2,1,3,0\n",
"4,41,3,1,3,0\n",
"3,51,2,4,3,0\n",
"5,50,3,2,3,1\n",
"4,47,4,4,3,0\n",
"4,54,4,5,3,1\n",
"5,52,4,4,3,1\n",
"4,30,1,1,3,0\n",
"3,48,4,4,3,1\n",
"5,?,4,4,3,1\n",
"4,65,2,4,3,1\n",
"4,50,1,1,3,0\n",
"5,65,4,5,3,1\n",
"5,66,4,3,3,1\n",
"6,41,3,3,2,1\n",
"5,72,3,2,3,1\n",
"4,42,1,1,1,1\n",
"4,80,4,4,3,1\n",
"0,45,2,4,3,0\n",
"4,41,1,1,3,0\n",
"4,72,3,3,3,1\n",
"4,60,4,5,3,0\n",
"5,67,4,3,3,1\n",
"4,55,2,1,3,0\n",
"4,61,3,4,3,1\n",
"4,55,3,4,3,1\n",
"4,52,4,4,3,1\n",
"4,42,1,1,3,0\n",
"5,63,4,4,3,1\n",
"4,62,4,5,3,1\n",
"4,46,1,1,3,0\n",
"4,65,2,1,3,0\n",
"4,57,3,3,3,1\n",
"4,66,4,5,3,1\n",
"4,45,1,1,3,0\n",
"4,77,4,5,3,1\n",
"4,35,1,1,3,0\n",
"4,50,4,5,3,1\n",
"4,57,4,4,3,0\n",
"4,74,3,1,3,1\n",
"4,59,4,5,3,0\n",
"4,51,1,1,3,0\n",
"4,42,3,4,3,1\n",
"4,35,2,4,3,0\n",
"4,42,1,1,3,0\n",
"4,43,2,1,3,0\n",
"4,62,4,4,3,1\n",
"4,27,2,1,3,0\n",
"5,?,4,3,3,1\n",
"4,57,4,4,3,1\n",
"4,59,2,1,3,0\n",
"5,40,3,2,3,1\n",
"4,20,1,1,3,0\n",
"5,74,4,3,3,1\n",
"4,22,1,1,3,0\n",
"4,57,4,3,3,0\n",
"4,57,4,3,3,1\n",
"4,55,2,1,2,0\n",
"4,62,2,1,3,0\n",
"4,54,1,1,3,0\n",
"4,71,1,1,3,1\n",
"4,65,3,3,3,0\n",
"4,68,4,4,3,0\n",
"4,64,1,1,3,0\n",
"4,54,2,4,3,0\n",
"4,48,4,4,3,1\n",
"4,58,4,3,3,0\n",
"5,58,3,4,3,1\n",
"4,70,1,1,1,0\n",
"5,70,1,4,3,1\n",
"4,59,2,1,3,0\n",
"4,57,2,4,3,0\n",
"4,53,4,5,3,0\n",
"4,54,4,4,3,1\n",
"4,53,2,1,3,0\n",
"0,71,4,4,3,1\n",
"5,67,4,5,3,1\n",
"4,68,4,4,3,1\n",
"4,56,2,4,3,0\n",
"4,35,2,1,3,0\n",
"4,52,4,4,3,1\n",
"4,47,2,1,3,0\n",
"4,56,4,5,3,1\n",
"4,64,4,5,3,0\n",
"5,66,4,5,3,1\n",
"4,62,3,3,3,0\"\"\""
]
},
{
"cell_type": "markdown",
"id": "12dfb986-864d-402b-ba1a-c3eada639b32",
"metadata": {},
"source": [
"## Preprocessing\n",
"\n",
"The following code loads this data into a list, dropping rows with missing values, discretizing the age column, and fixing two obvious typos in the BI-RADS column:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "72512b87-5dff-4d53-ae06-d70f130b21ca",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data loaded (830 rows)\n"
]
}
],
"source": [
"from collections.abc import Sequence\n",
"import csv\n",
"\n",
"COL_BIRADS = 0\n",
"COL_AGE = 1\n",
"COL_SHAPE = 2\n",
"COL_MARGIN = 3\n",
"COL_DENSITY = 4\n",
"COL_SEVERITY = 5\n",
"\n",
"\n",
"# possible values for each column in the data\n",
"cancer_domains: Sequence[Sequence[int]] = [\n",
" range(1, 7), # BI-RADS\n",
" [0, 45, 55, 75], # age\n",
" range(1, 5), # shape\n",
" range(1, 6), # margin\n",
" range(1, 5), # density\n",
" range(2), # severity\n",
"]\n",
"\n",
"\n",
"def process_row(vals: Sequence[str]) -> Sequence[int] | None:\n",
" # omit rows that have missing data\n",
" if \"?\" in vals:\n",
" return None\n",
" birads, age, shape, margin, density, severity = map(int, vals)\n",
" # discretize age\n",
" if age >= 75:\n",
" age = 75\n",
" elif age >= 55:\n",
" age = 55\n",
" elif age >= 45:\n",
" age = 45\n",
" else:\n",
" age = 0\n",
" # fix typos in birads column\n",
" if birads == 0:\n",
" birads = 1\n",
" elif birads == 55:\n",
" birads = 5\n",
" return birads, age, shape, margin, density, severity\n",
"\n",
"\n",
"cancer_data = [\n",
" row2\n",
" for row in csv.reader(csv_file.split(\"\\n\"))\n",
" if (row2 := process_row(row)) is not None\n",
"]\n",
"\n",
"print(\"data loaded (%i rows)\" % len(cancer_data))"
]
},
{
"cell_type": "markdown",
"id": "a8e08a6f-399d-41c3-8edb-39a1b4e054d3",
"metadata": {},
"source": [
"## Counts\n",
"\n",
"We first process the data by calculating the counts $n(c)$ and $n(a_i,c)$ for every class value $c$ and attribute value $a_i$. It is fine if you do not fully understand this code."
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "74c97d50-1ac0-4784-91f4-992a86c87d53",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"from collections import Counter\n",
"from collections.abc import Mapping\n",
"from dataclasses import dataclass\n",
"\n",
"\n",
"@dataclass\n",
"class Model:\n",
" domains: Sequence[Sequence[int]] # possible values\n",
" c_column: int # class column index\n",
" a_columns: Sequence[int] # attribute column indices\n",
" n: int # N, total number of observations\n",
" nc: Mapping[int, int] # n(c) as nc[c]\n",
" nac: Mapping[int, Mapping[tuple[int, int], int]] # n(a_i,c) as nac[i][a_i,c]\n",
" s: float # smoothing constant\n",
"\n",
"\n",
"def train_model(\n",
" domains: Sequence[Sequence[int]],\n",
" data: Sequence[Sequence[int]],\n",
" c_column: int,\n",
" a_columns: Sequence[int],\n",
" s: float = 2.0,\n",
") -> Model:\n",
" assert all(all(val in vals for val, vals in zip(row, domains)) for row in data)\n",
" nc = Counter(row[c_column] for row in data)\n",
" nac = {\n",
" a_column: Counter((row[a_column], row[c_column]) for row in data)\n",
" for a_column in a_columns\n",
" }\n",
" return Model(\n",
" domains=domains,\n",
" c_column=c_column,\n",
" a_columns=a_columns,\n",
" n=len(data),\n",
" nc=nc,\n",
" nac=nac,\n",
" s=s,\n",
" )\n",
"\n",
"\n",
"cancer_model = train_model(\n",
" domains=cancer_domains,\n",
" data=cancer_data,\n",
" c_column=COL_SEVERITY,\n",
" a_columns=[COL_BIRADS, COL_AGE, COL_SHAPE, COL_MARGIN, COL_DENSITY],\n",
")"
]
},
{
"cell_type": "markdown",
"id": "3af29686-65d0-4a66-9a5d-2ee07d71f5c2",
"metadata": {},
"source": [
"We can now retrieve the counts very easily, as follows:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "916697a3-c660-40cf-b6e9-42521a5f6f22",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"427"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# number of patients without cancer (i.e. severity 0)\n",
"cancer_model.nc[0]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0324303b-8817-46cb-b5bf-4cbb471a8a42",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"403"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# number of patients with cancer (i.e. severity 1)\n",
"cancer_model.nc[1]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "04026d42-6f28-44ba-9dfb-eb1aa667f8a0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"286"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# number of severity 1 patients with BI-RADS assessment of 5\n",
"cancer_model.nac[COL_BIRADS][5, 1]"
]
},
{
"cell_type": "markdown",
"id": "4a69605e-f788-4650-9c08-6112aa2c4329",
"metadata": {},
"source": [
"**Exercise** Find the number of patients in the dataset, aged 75 or over, who had no cancer."
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "3522a376-7b1c-4970-a608-a34b3b9b139f",
"metadata": {},
"outputs": [],
"source": [
"# write your solution here"
]
},
{
"cell_type": "markdown",
"id": "644ab916-35e2-45ff-ad6a-a4457002e424",
"metadata": {},
"source": [
"## Naive Bayes Classifier\n",
"\n",
"To do the classification, we must make a decision based on the probability values, which we can derive from the counts $n(c)$ and $n(a_i,c)$. Let us first implement the usual naive Bayes classifier, and then move to the naive credal classifier. Recall that, for a given vector of attributes $a=(a_1,\\dots,a_k)$,\n",
"we want to find the value for $c$ that maximizes\n",
"$$p(a,c)=p(c)\\prod_{i=1}^k p(a_i|c)$$\n",
"In the case of the naive Bayes classifier, we use the maximum likelihood estimates for the probabilities,\n",
"which happen to be given by the relative frequencies in the data:\n",
"$$p(c)=n(c)/N\\qquad p(a_i|c)=n(a_i,c)/n(c)$$\n",
"Let's implement this:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "204a2ced-5cdb-4230-9a42-70bf23dc3109",
"metadata": {},
"outputs": [],
"source": [
"from math import prod\n",
"\n",
"\n",
"def naive_bayes_prob_1(model: Model, test_row: Sequence[int], c: int) -> float:\n",
" n = model.n\n",
" nc = model.nc[c]\n",
" nacs = [model.nac[a_column][test_row[a_column], c] for a_column in model.a_columns]\n",
" pc = nc / n\n",
" pacs = [nac / nc for nac in nacs]\n",
" return pc * prod(pacs)"
]
},
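{
"cell_type": "markdown",
"id": "f3a9c1d2-0b4e-4c8a-9d6e-1a2b3c4d5e6f",
"metadata": {},
"source": [
"As a quick sanity check (purely illustrative, not part of the analysis itself), we can evaluate these maximum likelihood estimates for both classes on a single row of the data:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c7d8e9f0-1a2b-4c3d-8e5f-6a7b8c9d0e1f",
"metadata": {},
"outputs": [],
"source": [
"# evaluate p(a,c) for each class on the first row of the data\n",
"{c: naive_bayes_prob_1(cancer_model, cancer_data[0], c) for c in cancer_domains[COL_SEVERITY]}"
]
},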
{
"attachments": {},
"cell_type": "markdown",
"id": "2599d814-05d8-44fe-9bd1-6a239800f4b4",
"metadata": {},
"source": [
"There is however one technical problem, which arises when some counts are zero.\n",
"Namely, if $N=0$ then the maximum likelihood estimate $p(c)=n(c)/N$ is undefined.\n",
"Similarly, if $n(c)=0$, then $p(a_i|c)=n(a_i,c)/n(c)$ is undefined.\n",
"This may result in a ``ZeroDivisionError`` in the code.\n",
"\n",
"To handle this, we can instead use the Bayesian estimates of $p(c)$ and $p(a_i|c)$ under a Dirichlet prior,\n",
"which we saw in the lectures:\n",
"$$p(c)=\\frac{n(c)+st(c)}{N+s}\\qquad p(a_i|c)=\\frac{n(a_i,c)+st(a_i,c)}{n(c)+st(c)}$$\n",
"where we must fix a value for $s>0$,\n",
"and the values for $t(c)$ and $t(a_i,c)$, bearing in mind the constraints\n",
"$$\\sum_{c}t(c)=1\\qquad \\sum_{a_i}t(a_i,c)=t(c)$$\n",
"A sensible and common default is $s=2$.\n",
"The usual choice for the $t$ parameters is to pick these values symmetrically:\n",
"$$t(c)=1/|\\mathcal{C}|\\qquad t(a_i,c)=t(c)/|\\mathcal{A}_i|$$\n",
"where $|\\mathcal{C}|$ denotes the number of possible classes,\n",
"and $|\\mathcal{A}_i|$ denotes the number of possible values of the $i$th attribute.\n",
"\n",
"Let us implement this:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3cadd53d-a550-4154-9127-9391a221d0c4",
"metadata": {},
"outputs": [],
"source": [
"def naive_bayes_prob_2(model: Model, test_row: Sequence[int], c: int) -> float:\n",
" tc: float = 1 / len(model.domains[model.c_column])\n",
" tacs: Sequence[float] = [\n",
" tc / len(model.domains[a_column]) for a_column in model.a_columns\n",
" ]\n",
" n = model.n + model.s\n",
" nc = model.nc[c] + model.s * tc\n",
" nacs = [\n",
" model.nac[a_column][test_row[a_column], c] + model.s * tac\n",
" for a_column, tac in zip(model.a_columns, tacs)\n",
" ]\n",
" # p(c)=(n(c)+s*t(c))/(N+s)\n",
" pc = nc / n\n",
" # p(a|c)=(n(a_i,c)+s*t(a_i,c))/(n(c)+s*t(c))\n",
" pacs = [nac / nc for nac in nacs]\n",
" return pc * prod(pacs)"
]
},
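{
"cell_type": "markdown",
"id": "e5f6a7b8-9c0d-4e1f-a2b3-c4d5e6f7a8b9",
"metadata": {},
"source": [
"To see why the smoothed estimates matter (an illustrative aside, using a hypothetical model trained on no data), note that the maximum likelihood version raises a ``ZeroDivisionError`` when $N=0$, whereas the Bayesian version remains well defined:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1c2d3e4-5f6a-4b7c-8d9e-0f1a2b3c4d5e",
"metadata": {},
"outputs": [],
"source": [
"# hypothetical model trained on an empty data set, so N = 0 and all counts are zero\n",
"empty_model = train_model(cancer_domains, [], COL_SEVERITY, [COL_BIRADS])\n",
"try:\n",
"    naive_bayes_prob_1(empty_model, cancer_data[0], 0)\n",
"except ZeroDivisionError as exc:\n",
"    print(\"maximum likelihood estimate failed:\", exc)\n",
"# the Bayesian estimate is still well defined\n",
"print(naive_bayes_prob_2(empty_model, cancer_data[0], 0))"
]
},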
{
"cell_type": "markdown",
"id": "cb916543-9d44-46b1-bfbd-5c7792209bcf",
"metadata": {},
"source": [
"We now have everything in place to implement the naive Bayes classifier:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "8bb9b97c-57a0-41fe-905a-feb30a4d15e2",
"metadata": {},
"outputs": [],
"source": [
"def naive_bayes_outcome(\n",
" model: Model, test_row: Sequence[int]\n",
") -> Sequence[float | None]:\n",
" c_domain = model.domains[model.c_column]\n",
" probs = {c: naive_bayes_prob_2(model, test_row, c) for c in c_domain}\n",
" max_prob = max(probs.values())\n",
" c_test = test_row[model.c_column]\n",
" return [1 if probs[c_test] + TOL >= max_prob else 0]"
]
},
{
"cell_type": "markdown",
"id": "81fe21a3-eb72-43be-a2cf-1804ff6db4f5",
"metadata": {},
"source": [
"The test returns a sequence containing a single number: ``1`` if the naive Bayes classifier is correct, or ``0`` if it is wrong. Wrapping the number in a sequence seems pointless now, but once we consider more complex measures, it will be very handy for reporting several measures at once.\n",
"\n",
"Let us test each row of the original data set. To report the accuracy of the classifier, we simply average all the values. Again, the implementation here is slightly more complex than it needs to be at this point: we also exclude all ``None`` outcomes. This will be very handy later, when we consider more complex measures in the context of credal classification."
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "3eecad57-63e7-474e-b942-e47030f5a8f5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.8385542168674699]"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from statistics import mean\n",
"\n",
"\n",
"def mean_outcome(outcomes: Sequence[Sequence[float | None]]) -> Sequence[float | None]:\n",
" def _mean(xs: Sequence[float | None]) -> float | None:\n",
" xs2 = [x for x in xs if x is not None]\n",
" return mean(xs2) if xs2 else None\n",
"\n",
" return list(map(_mean, zip(*outcomes)))\n",
"\n",
"\n",
"mean_outcome([naive_bayes_outcome(cancer_model, row) for row in cancer_data])"
]
},
{
"cell_type": "markdown",
"id": "d6267fa7-5efb-48a9-a405-27882c8f338f",
"metadata": {},
"source": [
"As we can see, the classifier has an accuracy of about 84%."
]
},
{
"cell_type": "markdown",
"id": "02118dd8-6085-45ab-b627-9a13e0dcf2fd",
"metadata": {},
"source": [
"## k-Fold Cross Validation\n",
"\n",
"We should not test on the same data that we used for training. Instead, we split the data into folds, train on all but one fold, test on the held-out fold, and rotate through the folds. The next function implements this idea. It is fine if you do not fully understand the code."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "cedce15a-551a-4e6f-9f89-aa0d87d5d532",
"metadata": {},
"outputs": [],
"source": [
"from collections.abc import Callable\n",
"\n",
"\n",
"def kfcv_outcomes(\n",
" # test(model, test_row) -> sequence of accuracy measures\n",
" test: Callable[[Model, Sequence[int]], Sequence[float | None]],\n",
" folds: int,\n",
" domains: Sequence[Sequence[int]],\n",
" data: Sequence[Sequence[int]],\n",
" c_column: int,\n",
" a_columns: Sequence[int],\n",
" s: float = 2.0,\n",
") -> Sequence[Sequence[float | None]]:\n",
" outcomes = []\n",
" for fold in range(folds):\n",
" test_data = data[fold::folds]\n",
" test_indices = range(fold, len(data), folds)\n",
" train_data = [row for i, row in enumerate(data) if i not in test_indices]\n",
" model = train_model(domains, train_data, c_column, a_columns, s)\n",
" outcomes += [test(model, row) for row in test_data]\n",
" return outcomes"
]
},
{
"cell_type": "markdown",
"id": "27c027da-8b94-4489-9949-128f1b4f57e7",
"metadata": {},
"source": [
"Let's test it on the cancer data set."
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "60c2f510-d6c4-4ac0-8cde-2bc5f03cf8ad",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.8337349397590361]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mean_outcome(\n",
" kfcv_outcomes(\n",
" test=naive_bayes_outcome,\n",
" folds=10,\n",
" domains=cancer_domains,\n",
" data=cancer_data,\n",
" c_column=COL_SEVERITY,\n",
" a_columns=[COL_BIRADS, COL_AGE, COL_SHAPE, COL_MARGIN, COL_DENSITY],\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"id": "d08d1a41-1d8f-4730-9216-b4758eff387d",
"metadata": {},
"source": [
"We now have a full naive Bayes classifier running properly. Let's now move to the fun part: credal classification."
]
},
{
"cell_type": "markdown",
"id": "9f9be026-d87d-42cc-a7da-ecc4024c7622",
"metadata": {},
"source": [
"## Naive Credal Classifier\n",
"\n",
"To implement our naive credal classifier, all we need to do is modify the ``naive_bayes_outcome`` function slightly so that it checks for interval maximality. For convenience, we use a conservative approximation of the intervals, which is very quick to evaluate. (In the project, you will derive the exact bounds and investigate the impact of this approximation.)\n",
"Specifically, we will use the following expressions, derived in the lectures:\n",
"$$\n",
" \\underline{p}(c,a)\n",
" \\ge\n",
" \\underbrace{\\frac{n(c)}{N+s}}_{\\underline{p}(c)}\n",
" \\prod_{i=1}^k\n",
" \\underbrace{\\frac{n(a_i,c)}{n(c) + s}}_{\\underline{p}(a_i|c)}\n",
"\\qquad\n",
" \\overline{p}(c,a)\n",
" \\le\n",
" \\underbrace{\\frac{n(c)+s}{N+s}}_{\\overline{p}(c)}\n",
" \\prod_{i=1}^k\n",
" \\underbrace{\\frac{n(a_i,c) + s}{n(c) + s}}_{\\overline{p}(a_i|c)}\n",
"$$"
]
},
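{
"cell_type": "markdown",
"id": "0a1b2c3d-4e5f-4a6b-8c7d-9e0f1a2b3c4d",
"metadata": {},
"source": [
"As a quick sanity check of these formulas (not part of the lectures), the cell below evaluates the conservative bounds on some made-up counts: the lower bound should never exceed the upper bound, and as $s \\to 0$ both collapse to the maximum likelihood estimate of $p(c,a)$. All counts here are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b2c3d4e-5f6a-4b7c-8d9e-0f1a2b3c4d5e",
"metadata": {},
"outputs": [],
"source": [
"from math import isclose, prod\n",
"\n",
"\n",
"def approx_bounds(nc, nacs, n, s):\n",
"    # conservative [lower, upper] bounds on p(c, a) from the counts\n",
"    low = nc / (n + s) * prod(na / (nc + s) for na in nacs)\n",
"    upp = (nc + s) / (n + s) * prod((na + s) / (nc + s) for na in nacs)\n",
"    return low, upp\n",
"\n",
"\n",
"# hypothetical counts: n(c) = 40 out of N = 100, with n(a_i, c) = 25 and 30\n",
"low, upp = approx_bounds(40, [25, 30], 100, s=2)\n",
"assert low <= upp\n",
"# as s -> 0, both bounds collapse to the maximum likelihood estimate\n",
"ml = 40 / 100 * (25 / 40) * (30 / 40)\n",
"low0, upp0 = approx_bounds(40, [25, 30], 100, s=1e-12)\n",
"assert isclose(low0, ml, rel_tol=1e-6) and isclose(upp0, ml, rel_tol=1e-6)\n",
"low, upp"
]
},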
{
"cell_type": "code",
"execution_count": 15,
"id": "cf045e26-1877-4706-b11a-5a1cf76ab52a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.8409638554216867, 0.8384332925336597, 1, 2, 0.9843373493975903]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def naive_credal_prob(\n",
" model: Model, test_row: Sequence[int], c: int\n",
") -> tuple[float, float]:\n",
" def interval(a: float, b: float) -> tuple[float, float]:\n",
" return a / (b + model.s), (a + model.s) / (b + model.s)\n",
"\n",
" pc = interval(model.nc[c], model.n)\n",
" pacs = [\n",
" interval(model.nac[a_column][test_row[a_column], c], model.nc[c])\n",
" for a_column in model.a_columns\n",
" ]\n",
" return pc[0] * prod(pac[0] for pac in pacs), pc[1] * prod(pac[1] for pac in pacs)\n",
"\n",
"\n",
"def naive_credal_outcome(\n",
" model: Model, test_row: Sequence[int]\n",
") -> Sequence[float | None]:\n",
" c_domain = model.domains[model.c_column]\n",
" probs = {c: naive_credal_prob(model, test_row, c) for c in c_domain}\n",
" max_lowprob = max(low for low, upp in probs.values())\n",
" set_size = sum(1 if probs[c][1] + TOL >= max_lowprob else 0 for c in c_domain)\n",
" c_test = test_row[model.c_column]\n",
" correct = probs[c_test][1] + TOL >= max_lowprob\n",
" return [\n",
" 1 if correct else 0, # accuracy\n",
" (1 if correct else 0) if set_size == 1 else None, # single accuracy\n",
" (1 if correct else 0) if set_size != 1 else None, # set accuracy\n",
" set_size if set_size != 1 else None, # indeterminate set size\n",
" 1 if set_size == 1 else 0, # determinacy\n",
" ]\n",
"\n",
"\n",
"mean_outcome(\n",
" kfcv_outcomes(\n",
" test=naive_credal_outcome,\n",
" folds=10,\n",
" domains=cancer_domains,\n",
" data=cancer_data,\n",
" c_column=COL_SEVERITY,\n",
" a_columns=[COL_BIRADS, COL_AGE, COL_SHAPE, COL_MARGIN, COL_DENSITY],\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"id": "009eef51-4e80-4e3b-95fd-76b2535ce9af",
"metadata": {},
"source": [
"**Exercise** Why is the set accuracy equal to 100%, and why is the indeterminate set size equal to exactly 2?"
]
},
{
"cell_type": "markdown",
"id": "f7f19d95-a394-4e25-9887-402c9d9e6571",
"metadata": {},
"source": [
"*Write your answer here.*"
]
},
{
"cell_type": "markdown",
"id": "4781dca4-f372-4cf1-befe-105e7681289c",
"metadata": {},
"source": [
"**Exercise** Is credal classification useful for this data? If so, why; if not, why not?"
]
},
{
"cell_type": "markdown",
"id": "ece80b33-2b88-40b8-a89a-ebb455079ca4",
"metadata": {},
"source": [
"*Write your answer here.*"
]
},
{
"cell_type": "markdown",
"id": "25e9acc4-f614-4308-bd3e-2d8a1302a14b",
"metadata": {},
"source": [
"## Additional Exercises"
]
},
{
"cell_type": "markdown",
"id": "16618ee7-3d9a-4f99-a508-c284bbacbd93",
"metadata": {},
"source": [
"**Exercise** Discuss the impact of the sample size on both the naive Bayes and on the credal classifier. The code below may be helpful."
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "767a5d64-3bad-4c83-a2e9-4582af3f6e0c",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"15 [0.6] [1, 1, 1, 2, 0.06666666666666667]\n",
"30 [0.8] [0.9333333333333333, 0.875, 1, 2, 0.5333333333333333]\n",
"60 [0.8] [0.8833333333333333, 0.8541666666666666, 1, 2, 0.8]\n",
"120 [0.8083333333333333] [0.85, 0.8301886792452831, 1, 2, 0.8833333333333333]\n",
"240 [0.8375] [0.8541666666666666, 0.8444444444444444, 1, 2, 0.9375]\n",
"480 [0.8458333333333333] [0.8541666666666666, 0.8494623655913979, 1, 2, 0.96875]\n",
"830 [0.8337349397590361] [0.8409638554216867, 0.8384332925336597, 1, 2, 0.9843373493975903]\n"
]
}
],
"source": [
"for sample_size in [15, 30, 60, 120, 240, 480, 830]:\n",
" print(\n",
" sample_size,\n",
" mean_outcome(\n",
" kfcv_outcomes(\n",
" test=naive_bayes_outcome,\n",
" folds=10,\n",
" domains=cancer_domains,\n",
" data=cancer_data[:sample_size],\n",
" c_column=COL_SEVERITY,\n",
" a_columns=[COL_BIRADS, COL_AGE, COL_SHAPE, COL_MARGIN, COL_DENSITY],\n",
" )\n",
" ),\n",
" mean_outcome(\n",
" kfcv_outcomes(\n",
" test=naive_credal_outcome,\n",
" folds=10,\n",
" domains=cancer_domains,\n",
" data=cancer_data[:sample_size],\n",
" c_column=COL_SEVERITY,\n",
" a_columns=[COL_BIRADS, COL_AGE, COL_SHAPE, COL_MARGIN, COL_DENSITY],\n",
" )\n",
" ),\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "f439e14a-6003-4075-81e6-054ecabfe2a6",
"metadata": {},
"source": [
"**Exercise** Discuss the impact of $s$ on the credal classifier. What value of $s$ seems most appropriate to you? The code below may be helpful."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "194b4731-db79-4ecf-be3d-df71aeefcbc2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.01 [0.8337349397590361, 0.8331318016928658, 1, 2, 0.9963855421686747]\n",
"1 [0.8385542168674699, 0.8363858363858364, 1, 2, 0.9867469879518073]\n",
"2 [0.8409638554216867, 0.8384332925336597, 1, 2, 0.9843373493975903]\n",
"10 [0.8626506024096385, 0.8507853403141361, 1, 2, 0.9204819277108434]\n",
"100 [0.9771084337349397, 0.9351535836177475, 1, 2, 0.3530120481927711]\n",
"1000 [1, None, 1, 2, 0]\n"
]
}
],
"source": [
"for s in [0.01, 1, 2, 10, 100, 1000]:\n",
" print(\n",
" s,\n",
" mean_outcome(\n",
" kfcv_outcomes(\n",
" test=naive_credal_outcome,\n",
" folds=10,\n",
" domains=cancer_domains,\n",
" data=cancer_data,\n",
" c_column=COL_SEVERITY,\n",
" a_columns=[COL_BIRADS, COL_AGE, COL_SHAPE, COL_MARGIN, COL_DENSITY],\n",
" s=s,\n",
" )\n",
" ),\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "f611afb8-89aa-443c-8bd9-8921ff4dd388",
"metadata": {},
"source": [
"**Exercise** There is a specific value of $s$ for which the naive Bayes classifier (using maximum likelihood estimates for the probabilities) is obtained as a special case of the credal classifier.\n",
"\n",
"1. Identify this value.\n",
"\n",
"2. Prove the claim using the formulae provided in the lectures.\n",
"\n",
"3. Interpret this claim in the context of Wald's theorem which links frequentist inference and Bayesian inference."
]
},
{
"cell_type": "markdown",
"id": "204a11f9-aebd-44d1-8568-6a8ba7c00441",
"metadata": {},
"source": [
"*Write your answer here.*"
]
},
{
"cell_type": "markdown",
"id": "c6f37412-f5aa-41a3-b4c2-baaac3f5de76",
"metadata": {},
"source": [
"**Exercise** It is customary, in classification, to simply learn the possible class values from the data.\n",
"However, in the code above, we explicitly state the possible values through the ``domains`` parameter,\n",
"instead of deriving them from the ``data`` parameter.\n",
"Explain why this is critical for the naive credal classifier, whilst it is less critical for the naive Bayes classifier.\n",
"\n",
"Hint: What happens if a class value does not appear in the training data? The code below might be helpful."
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "1a10657e-281b-4fb9-b873-f713c6780a88",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"bayes 1: [1]\n",
"bayes 2: [1]\n",
"credal 1: [1, None, 1, 2, 0]\n",
"credal 2: [1, None, 1, 3, 0]\n"
]
}
],
"source": [
"model_1 = train_model(\n",
" domains=[[0, 1], [0, 1]], # c ∈ {0,1}, a ∈ {0,1}\n",
"    data=[[0, 0]],  # single observation (c=0,a=0)\n",
" c_column=0,\n",
" a_columns=[1],\n",
" s=2,\n",
")\n",
"model_2 = train_model(\n",
" domains=[[0, 1, 2], [0, 1]], # c ∈ {0,1,2}, a ∈ {0,1}\n",
" data=[[0, 0]], # same single observation\n",
" c_column=0,\n",
" a_columns=[1],\n",
" s=2,\n",
")\n",
"# predict class for a=0\n",
"print(\"bayes 1:\", naive_bayes_outcome(model=model_1, test_row=[0, 0]))\n",
"print(\"bayes 2:\", naive_bayes_outcome(model=model_2, test_row=[0, 0]))\n",
"print(\"credal 1:\", naive_credal_outcome(model=model_1, test_row=[0, 0]))\n",
"print(\"credal 2:\", naive_credal_outcome(model=model_2, test_row=[0, 0]))"
]
},
{
"cell_type": "markdown",
"id": "d6000ef7-3308-485e-aaee-7980fe34e2a3",
"metadata": {},
"source": [
"*Write your answer here.*"
]
},
{
"cell_type": "markdown",
"id": "31d54887-44b3-4e04-83be-e3168a3d039e",
"metadata": {},
"source": [
"# Project\n",
"\n",
"The aim of the project is\n",
"for you to learn more about machine learning with bounded probability.\n",
"It consists of 3 tasks:\n",
"\n",
"1. Further explore the breast cancer dataset that we introduced in the lectures.\n",
"\n",
"2. Derive some theoretical results to improve the probability bounds that we used in the lectures.\n",
"   Use these theoretical results to improve the credal classifier from the lectures.\n",
"\n",
"3. Derive some theoretical results concerning\n",
"   robust Bayes maximality and robust Bayes admissibility for the credal classifier.\n",
"   Use these theoretical results to further improve the credal classifier from the lectures.\n",
"\n",
"Each task is subdivided into very specific subtasks, to guide you along.\n",
"Most subtasks require some coding in Python.\n",
"However, the 2nd and 3rd task also have subtasks\n",
"that concern purely theoretical questions to be solved on pen and paper.\n",
"\n",
"You may do all three tasks, or only a selection of them; this is up to you.\n",
"\n",
"Throughout, as a baseline,\n",
"the suggested sample size is $N=100$ (i.e. use ``data=cancer_data[:100]``)\n",
"and $s=2$ (this is the default if not specified).\n",
"However, you are encouraged to play around with these values if you believe it is useful."
]
},
{
"cell_type": "markdown",
"id": "93cce9e8-275c-4156-a887-6815acd48a4a",
"metadata": {},
"source": [
"## BI-RADS Analysis\n",
"\n",
"The ultimate goal of collecting this data was to see whether the doctor's BI-RADS assessment could be improved through image recognition (from which the shape, margin, and density attributes were derived).\n",
"\n",
"1. Run the classifier to check whether or not the additional attributes can replace the BI-RADS assessment.\n",
"\n",
"2. For their BI-RADS assessment, the doctor also has access to the images. Determine whether or not the classifier is good at predicting BI-RADS from the other attributes. Can you explain why the other attributes are good, or not so good, at predicting BI-RADS?\n",
"\n",
"3. When using the credal classifier in the previous part to predict BI-RADS, you will now notice that the set accuracy is no longer 100% and that the indeterminate set size is no longer 2. Why is that? Interpret the new values.\n",
"\n",
"4. Which of the attributes are most useful for classification? Should certain attributes be omitted?"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "875ef8cd-e1ca-49a8-98ed-3696064dc3fa",
"metadata": {},
"outputs": [],
"source": [
"# you can write your code here"
]
},
{
"cell_type": "markdown",
"id": "c86d2cfd-76e1-4ee3-99a1-269cf960bfd5",
"metadata": {},
"source": [
"## Exact Probability Bounds\n",
"\n",
"The ``naive_credal_prob`` function in the code above\n",
"uses the approximate bounds for $\\underline{p}(c,a)$ and $\\overline{p}(c,a)$ that we saw in the lectures:\n",
"$$\\underline{p}(c,a)=\\inf_{t}\\frac{n(c)+s t(c)}{N+s}\\prod_{i=1}^k\\frac{n(a_i,c)+s t(a_i,c)}{n(c) + s t(c)}\\ge\\frac{n(c)}{N+s}\\prod_{i=1}^k\\frac{n(a_i,c)}{n(c) + s}$$\n",
"$$\\overline{p}(c,a)=\\sup_{t}\\frac{n(c)+s t(c)}{N+s}\\prod_{i=1}^k\\frac{n(a_i,c)+s t(a_i,c)}{n(c) + s t(c)}\\le\\frac{n(c)+s}{N+s}\\prod_{i=1}^k\\frac{n(a_i,c) + s}{n(c) + s}$$\n",
"\n",
"1. In what way does this conservative approximation of the interval impact the classifier?\n",
"\n",
"2. (The result of this exercise was proved by Zacch Lines in his Master's thesis.)\n",
"Find an expression for the exact values of $\\underline{p}(c,a)$ and $\\overline{p}(c,a)$.\n",
"Hint: First show that\n",
"$$\\underline{p}(c,a)=\\inf_{t(c)\\in[0,1]}\\frac{n(c)+s t(c)}{N+s}\\prod_{i=1}^k\\frac{n(a_i,c)}{n(c) + s t(c)}$$\n",
"$$\\overline{p}(c,a)=\\sup_{t(c)\\in[0,1]}\\frac{n(c)+s t(c)}{N+s}\\prod_{i=1}^k\\frac{n(a_i,c)+s t(c)}{n(c) + s t(c)}$$\n",
"Hint: One (but only one) of the approximate bounds will turn out to be exact.\n",
"\n",
"3. Implement your improved bounds in the code.\n",
"\n",
"4. Verify the impact your improved bounds have on the classification for a range of sample sizes, different values for $s$, and attributes. When doing so, pay particular attention to the number of attributes. For instance, investigate the impact when predicting severity just from BI-RADS, as opposed to predicting severity from all available attributes."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "37f156ff-1d56-4cdc-8b9a-c9ec11fbe94e",
"metadata": {},
"outputs": [],
"source": [
"# you can write your code here"
]
},
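{
"cell_type": "markdown",
"id": "2c3d4e5f-6a7b-4c8d-9e0f-1a2b3c4d5e6f",
"metadata": {},
"source": [
"For subtask 2, before attempting the derivation, it can help to probe the one-dimensional objectives from the hint numerically: evaluate them on a fine grid of $t(c)$ values and compare against the conservative bounds. The sketch below does this for hypothetical counts (all numbers are made up) and deliberately stops short of the closed-form answer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d4e5f6a-7b8c-4d9e-a0f1-2b3c4d5e6f7a",
"metadata": {},
"outputs": [],
"source": [
"from math import prod\n",
"\n",
"\n",
"def objective(nc, nacs, n, s, t, tas):\n",
"    # p_t(c, a) with t(c) = t and t(a_i, c) = tas[i]\n",
"    return (nc + s * t) / (n + s) * prod(\n",
"        (na + s * ta) / (nc + s * t) for na, ta in zip(nacs, tas)\n",
"    )\n",
"\n",
"\n",
"nc, nacs, n, s = 40, [25, 30], 100, 2\n",
"grid = [i / 1000 for i in range(1, 1000)]\n",
"# per the hint: the infimum is approached as t(a_i, c) -> 0 ...\n",
"low_grid = min(objective(nc, nacs, n, s, t, [0, 0]) for t in grid)\n",
"# ... and the supremum as t(a_i, c) -> t(c)\n",
"upp_grid = max(objective(nc, nacs, n, s, t, [t, t]) for t in grid)\n",
"# conservative bounds from the lectures\n",
"low_cons = nc / (n + s) * prod(na / (nc + s) for na in nacs)\n",
"upp_cons = (nc + s) / (n + s) * prod((na + s) / (nc + s) for na in nacs)\n",
"# the grid bounds are at least as tight as the conservative ones\n",
"assert low_cons <= low_grid + TOL and upp_grid <= upp_cons + TOL\n",
"low_cons, low_grid, upp_grid, upp_cons"
]
},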
{
"cell_type": "markdown",
"id": "c2fef9fd-2fe1-4160-95f4-d51bff841acb",
"metadata": {},
"source": [
"## Robust Bayes Maximality\n",
"\n",
"We used interval maximality in our credal classifier,\n",
"as it is very easy to implement.\n",
"\n",
"1. How would you go about implementing robust Bayes maximality for the credal classifier?\n",
" Identify the computations required.\n",
"\n",
"2. Recall that\n",
" $$p_t(c,a)=\\frac{n(c)+s t(c)}{N+s}\\prod_{i=1}^k\\frac{n(a_i,c)+s t(a_i,c)}{n(c) + s t(c)}$$\n",
" with $\\sum_{c}t(c)=1$, $\\sum_{a_i}t(a_i,c)=t(c)$, $t(c)>0$, and $t(a_i,c)>0$.\n",
" Show that $p_t(c,a)>p_t(c',a)$\n",
" for all $t$ whenever (see Zaffalon, 2001)\n",
" $$\\left(\\frac{n(c')+s t(c')}{n(c)+s (1-t(c'))}\\right)^{k-1}\\left(\\prod_{i=1}^k \\frac{n(a_i,c)}{n(a_i,c')+s t(c')}\\right)>1$$\n",
"   for all $t(c')$ such that $0 < t(c') < 1$.\n",
"\n",
"3. Complete the skeleton code below, which checks this criterion on a grid of $t(c')$ values, and use it to compare robust Bayes maximality against interval maximality on the data."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a7b8c9d-0e1f-4a2b-b3c4-5d6e7f8a9b0c",
"metadata": {},
"outputs": [],
"source": [
"def is_maximal(\n",
"    dominates: Callable[[int, int], bool], cs: Sequence[int]\n",
") -> Sequence[bool]:\n",
" def is_not_dominated(c1: int) -> bool:\n",
" return all(not dominates(c2, c1) for c2 in cs)\n",
"\n",
" return [is_not_dominated(c1) for c1 in cs]\n",
"\n",
"\n",
"def naive_credal_outcome_2(\n",
" model: Model, test_row: Sequence[int]\n",
") -> Sequence[float | None]:\n",
"\n",
" def dominates(c1: int, c2: int) -> bool:\n",
" return all(\n",
"            ... > 1 + TOL  # TODO: use Zaffalon's formula\n",
" for t in [0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99]\n",
" )\n",
"\n",
" c_domain = model.domains[model.c_column]\n",
" is_max_cs = is_maximal(dominates, c_domain)\n",
" set_size = sum(is_max_cs)\n",
" c_test = test_row[model.c_column]\n",
" correct = is_max_cs[c_domain.index(c_test)]\n",
" return [\n",
" 1 if correct else 0, # accuracy\n",
" (1 if correct else 0) if set_size == 1 else None, # single accuracy\n",
" (1 if correct else 0) if set_size != 1 else None, # set accuracy\n",
" set_size if set_size != 1 else None, # indeterminate set size\n",
" 1 if set_size == 1 else 0, # determinacy\n",
" ]"
]
}
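,
{
"cell_type": "markdown",
"id": "4e5f6a7b-8c9d-4e0f-a1b2-3c4d5e6f7a8b",
"metadata": {},
"source": [
"To get a feel for the check that the ``dominates`` skeleton above is meant to implement, the cell below evaluates the left-hand side of Zaffalon's criterion on the suggested grid of $t(c')$ values, for hypothetical counts (all numbers are made up)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f6a7b8c-9d0e-4f1a-b2c3-4d5e6f7a8b9c",
"metadata": {},
"outputs": [],
"source": [
"from math import prod\n",
"\n",
"\n",
"def zaffalon_lhs(ncp, nc, nas_c, nas_cp, s, t):\n",
"    # left-hand side of Zaffalon's (2001) domination criterion\n",
"    k = len(nas_c)\n",
"    return ((ncp + s * t) / (nc + s * (1 - t))) ** (k - 1) * prod(\n",
"        na_c / (na_cp + s * t) for na_c, na_cp in zip(nas_c, nas_cp)\n",
"    )\n",
"\n",
"\n",
"# hypothetical counts: n(c) = 60, n(c') = 30, and attribute counts for each\n",
"# c dominates c' when the criterion exceeds 1 over the whole grid of t(c')\n",
"all(\n",
"    zaffalon_lhs(30, 60, [40, 45], [10, 12], s=2, t=t) > 1 + TOL\n",
"    for t in [0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99]\n",
")"
]
}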
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}