From 194086fa99ff4f25d4f57d93992758fc51e7be8a Mon Sep 17 00:00:00 2001 From: Chris Banks Date: Sat, 12 Oct 2024 12:19:50 +0100 Subject: [PATCH] Version for submission to PLOS CB --- Data_Curation.ipynb | 1058 ++++++++++ Data_Curation_UKFarmcare.ipynb | 425 ++++ Data_Curation_VetOnly.ipynb | 68 + README | 13 + Vet_Data_Analysis.ipynb | 769 +++++++ bTB-Diagnostic_2020_final_model_NEW.ipynb | 1812 +++++++++++++++++ ...tic_2020_final_model_VetOnly-Control.ipynb | 1637 +++++++++++++++ bTB-Diagnostic_2020_final_model_VetOnly.ipynb | 1741 ++++++++++++++++ ...agnostic_2020_v4_crossVal+tuning_NEW.ipynb | 674 ++++++ ...0_v4_crossVal+tuning_VetOnly-Control.ipynb | 696 +++++++ ...stic_2020_v4_crossVal+tuning_VetOnly.ipynb | 674 ++++++ 11 files changed, 9567 insertions(+) create mode 100644 Data_Curation.ipynb create mode 100644 Data_Curation_UKFarmcare.ipynb create mode 100644 Data_Curation_VetOnly.ipynb create mode 100644 README create mode 100644 Vet_Data_Analysis.ipynb create mode 100644 bTB-Diagnostic_2020_final_model_NEW.ipynb create mode 100644 bTB-Diagnostic_2020_final_model_VetOnly-Control.ipynb create mode 100644 bTB-Diagnostic_2020_final_model_VetOnly.ipynb create mode 100644 bTB-Diagnostic_2020_v4_crossVal+tuning_NEW.ipynb create mode 100644 bTB-Diagnostic_2020_v4_crossVal+tuning_VetOnly-Control.ipynb create mode 100644 bTB-Diagnostic_2020_v4_crossVal+tuning_VetOnly.ipynb diff --git a/Data_Curation.ipynb b/Data_Curation.ipynb new file mode 100644 index 0000000..f6a8e18 --- /dev/null +++ b/Data_Curation.ipynb @@ -0,0 +1,1058 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "e08dd426-bb88-4d71-9a5c-a7019c3e2431", + "metadata": {}, + "outputs": [], + "source": [ + "using DataFrames\n", + "using CSV\n", + "using Dates\n", + "using ProgressMeter\n", + "#using ArchGDAL\n", + "using StatsBase" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8267d9fb-ab01-49d2-8c6b-fc1d17d3300d", + "metadata": {}, + "outputs": [], + "source": [ + "ENV[\"COLUMNS\"]=500;" + ] + }, + { + "cell_type": "markdown", + "id": "80dc3cf3-f5e4-4e2c-af25-5b799e5d4e52", + "metadata": {}, + "source": [ + "# Prepare data from SAM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7cf8fadf-a9a5-45f7-a2ab-16b83370c998", + "metadata": {}, + "outputs": [], + "source": [ + "## Load SAM Test data:\n", + "test_cols = [:ctTestDate,:ctTestType_FOI,:ctResultOfTest,:ctClearFlag,:ctAssetPK,:ctCPH,:ctCphh_fmt,:ctNumber,:ctSize,:ctReactors,:ctNumberNotTested,\n", + " :ctBreakId,:ctPartCode,:ctConfirmed,:ctTaken,:ctSHTaken,:ctCategory,:ctInterp,:ctSpecies,:ctSlaughteredIRs]\n", + "SAM_Test = CSV.read(\"/Data/SAM/tblccdTest.txt\", DataFrame, dateformat=\"d/m/yyyy HH:MM:SS\", select=test_cols)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "416aee5e-1424-4b7b-aa68-0075660786e4", + "metadata": {}, + "outputs": [], + "source": [ + "## Pull out Gamma tests (for later)\n", + "SAM_Test_Gamma = SAM_Test[(SAM_Test.ctTestDate.>=Date(\"2012\")).&&(isequal.(SAM_Test.ctCategory,\"GAMMA\")),:] #Using isequal avoids missing\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d41d75d-e7ab-47de-9098-dca22a5c41c2", + "metadata": {}, + "outputs": [], + "source": [ + "## Select SICCT skin tests from 2012 onwards:\n", + "SAM_Test = SAM_Test[(SAM_Test.ctTestDate.>=Date(\"2012\")).&&(isequal.(SAM_Test.ctCategory,\"TBSKINTEST\")),:] #Using isequal avoids missing\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a25ace9-6684-448d-aded-344fce551fee", + "metadata": {}, + "outputs": [], + "source": [ + "## Select only Whole Herd tests\n", + "#SAM_test_types = CSV.read(\"/Data/SAM/tlkccdTestType.txt\", DataFrame)\n", + "#WH_test_types = SAM_test_types[isequal.(SAM_test_types.ttTypeCode,\"WH\"),:ttCode]\n", + "\n", + "#SAM_Test = \n", + "#SAM_Test = SAM_Test[in(WH_test_types).(SAM_Test.ctTestType_FOI),:]\n", + "#;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8fc2beb-1bbc-4993-9d91-ce8a71419491", + "metadata": {}, + "outputs": [], + "source": [ + "## Remove records without a test result\n", + "SAM_Test = SAM_Test[(.!ismissing.(SAM_Test.ctResultOfTest)).&&(in([\"C\",\"NC\"]).(SAM_Test.ctResultOfTest)),:]\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75cc0474-1029-4c1a-b0be-b40d22979010", + "metadata": {}, + "outputs": [], + "source": [ + "## Convert DateTimes to Dates\n", + "SAM_Test.ctTestDate = convert.(Date,SAM_Test.ctTestDate)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18daa147-5a4a-4dfa-ade8-2b06175cb281", + "metadata": {}, + "outputs": [], + "source": [ + "## Load SAM Herd (Location) data\n", + "herd_cols = [:chAssetPK,:chMapX,:chMapY,:chType,:chSpecies,:chHerdSize_CTS]\n", + "SAM_Herd = CSV.read(\"/Data/SAM/tblccdHerd_FOI.txt\", DataFrame, select=herd_cols)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a366b8be-cc10-4543-ab0e-1bb1cf044d24", + "metadata": {}, + "outputs": [], + "source": [ + "SAM_joined = innerjoin(SAM_Test,SAM_Herd,on=:ctAssetPK=>:chAssetPK)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d9a59e9-8ef8-41cb-9e63-a74548e2c23d", + "metadata": {}, + "outputs": [], + "source": [ + "## Load SAM Breakdowns\n", + "bd_cols = [:cbAssetPK,:cbBreakDate,:cbConfDate,:cbTB10Date]\n", + "SAM_Breakdown = CSV.read(\"/Data/SAM/tblccdBreakdown.txt\", DataFrame, select=bd_cols, delim='|', dateformat=\"d/m/yyyy HH:MM:SS\")\n", + "SAM_Breakdown.cbBreakDate = convert.(Date,SAM_Breakdown.cbBreakDate)\n", + "SAM_Breakdown.cbConfDate = passmissing(convert).(Date,SAM_Breakdown.cbConfDate)\n", + "SAM_Breakdown.cbTB10Date = passmissing(convert).(Date,SAM_Breakdown.cbTB10Date)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eda3396a-27a3-429a-9fd4-57f91f7a0ba0", + "metadata": {}, + "outputs": [], + "source": [ + "## Group breakdowns by asset, for fast searching\n", + "SAM_Breakdown_byAsset = groupby(SAM_Breakdown,:cbAssetPK)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0337854a-fc12-4fbb-8b0f-9a0fe886112f", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: previousBreakdown(assetPK,testDate,breakdownTable)\n", + "## Returns the date of the last breakdown before testDate for the asset (herd)\n", + "## or \"missing\" if no previous breakdowns.\n", + "## Set pv=2 e.g. to get 2nd last, etc.\n", + "## Uses breakdown table grouped by asset for fast searching\n", + "function previousBreakdown(assetPK,testDate,pv=1)\n", + " if (assetPK,) in keys( SAM_Breakdown_byAsset)\n", + " bds = SAM_Breakdown_byAsset[(assetPK,)]\n", + " else\n", + " return missing\n", + " end\n", + " pvbds = bds[bds.cbBreakDate. unique |> sort\n", + " if length(pvbds)>pv-1\n", + " return pvbds[end-pv+1]\n", + " else\n", + " return missing\n", + " end\n", + "end\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c0af129-8e87-45e0-8217-b5013092c8af", + "metadata": {}, + "outputs": [], + "source": [ + "## Add previous breakdown dates to SAM_joined table\n", + "@time SAM_joined.previousBreakdown = previousBreakdown.(SAM_joined.ctAssetPK,SAM_joined.ctTestDate)\n", + "@time SAM_joined.previousBreakdown2 = previousBreakdown.(SAM_joined.ctAssetPK,SAM_joined.ctTestDate,2)\n", + "; " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a7e33835-d2b8-4524-8d0f-6c6307ec9518", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: dayValue(x)\n", + "## Converts Day type to Int value, handles missing\n", + "function dayValue(x::Union{Missing,Day})\n", + " if ismissing(x)\n", + " missing\n", + " else\n", + " x.value\n", + " end\n", + "end\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9701d39c-0641-455b-9209-83600a0cc0c1", + "metadata": {}, + "outputs": [], + "source": [ + "## Get number of days since last breakdown\n", + "SAM_joined.daysSinceBreakdown = map(dayValue, SAM_joined.ctTestDate - SAM_joined.previousBreakdown)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46c74da4-8cc8-4f2a-8a35-03e2bcdbbc00", + "metadata": {}, + "outputs": [], + "source": [ + "## Get SAM table grouped by asset, for fast searching\n", + "SAM_joined_byAsset = groupby(SAM_joined,:ctAssetPK)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ed08234-e41f-4ada-8f42-5381b2f6c99d", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: previousTest(assetPK,testDate)\n", + "## Returns the result of the previous test before testDate for the asset (herd)\n", + "## or \"missing\" if no previous tests.\n", + "## Set pv=2 e.g. to get 2nd last, etc.\n", + "## Uses the SAM table grouped by asset, for fast searching\n", + "function previousTest(assetPK,testDate,pv=1)\n", + " tests = SAM_joined_byAsset[(assetPK,)]\n", + " pvtests = tests[tests.ctTestDate.pv-1\n", + " pvtests.ctResultOfTest[end-pv+1]\n", + " else\n", + " missing\n", + " end\n", + "end\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "361c3266-3b1c-4993-9c42-1a4ae2aeba99", + "metadata": {}, + "outputs": [], + "source": [ + "## Get results of previous tests\n", + "@time SAM_joined.previousResultOfTest = previousTest.(SAM_joined.ctAssetPK,SAM_joined.ctTestDate)\n", + "@time SAM_joined.previousResultOfTest2 = previousTest.(SAM_joined.ctAssetPK,SAM_joined.ctTestDate,2)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "187edc37-1315-4f73-b8fc-e243f8b45ca4", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: daysSincePreviousTest(assetPK,testDate)\n", + "## Returns the days since previous test before testDate for the asset (herd)\n", + "## or \"missing\" if no previous tests.\n", + "## Set pv=2 e.g. to get 2nd last, etc.\n", + "## Uses the SAM table grouped by asset, for fast searching\n", + "function daysSincePreviousTest(assetPK,testDate,pv=1)\n", + " tests = SAM_joined_byAsset[(assetPK,)]\n", + " pvtests = tests[tests.ctTestDate.pv-1\n", + " dayValue(testDate - pvtests.ctTestDate[end-pv+1])\n", + " else\n", + " missing\n", + " end\n", + "end\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "189b525a-01da-41bf-8d55-77a89feb5f0a", + "metadata": {}, + "outputs": [], + "source": [ + "## Get days since previous tests\n", + "@time SAM_joined.daysSincePreviousTest = daysSincePreviousTest.(SAM_joined.ctAssetPK,SAM_joined.ctTestDate)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a1df291-5eed-42fe-ac86-c547703a4bf1", + "metadata": {}, + "outputs": [], + "source": [ + "## Load Defra risk score table from SAM\n", + "SAM_RiskScore = CSV.read(\"/Data/SAM/tblRBT_Scores.txt\", DataFrame)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f875fff6-675e-4b7a-839b-bbb2fd14aa10", + "metadata": {}, + "outputs": [], + "source": [ + "## Group risk scores by cph, for fast searching\n", + "SAM_RiskScore_byCPH = groupby(SAM_RiskScore,:CPH)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e46deb6e-a7e7-4a7c-ae8a-3f30b24444aa", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: defraRiskScore(assetPK,testDate)\n", + "## Get the Defra risk score for the holding on the date of the test\n", + "function defraRiskScore(cph,testDate)\n", + " if (cph,) in keys(SAM_RiskScore_byCPH)\n", + " scores = SAM_RiskScore_byCPH[(cph,)]\n", + " result = scores.RiskScore[(scores.FromDate.<=testDate).&(scores.ToDate.>=testDate)]\n", + " if length(result) > 0\n", + " return result[end]\n", + " else\n", + " return missing\n", + " end\n", + " else\n", + " return missing\n", + " end\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06d22218-199f-4f43-b7d1-5318ce48ee58", + "metadata": {}, + "outputs": [], + "source": [ + "## Get Defra risk score for holding at test date\n", + "@time SAM_joined.defraRiskScore = defraRiskScore.(SAM_joined.ctCPH,SAM_joined.ctTestDate)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f4ea55c-c79b-44c1-9385-fc9ee9898339", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: breakdownConfirmed(assetPK,testDate,x)\n", + "## Returns whether there was a breakdown within x days of testDate for the asset (herd)\n", + "## Uses breakdown table grouped by asset for fast searching\n", + "function confirmedBreakdown(assetPK,testDate,x)\n", + " if (assetPK,) in keys(SAM_Breakdown_byAsset)\n", + " bds = SAM_Breakdown_byAsset[(assetPK,)]\n", + " else\n", + " return false\n", + " end\n", + " bds = dropmissing(bds)\n", + " fcbds = bds[(bds.cbConfDate.>=testDate),:cbConfDate]\n", + " cbds = fcbds[fcbds.<=testDate+Day(x)]\n", + " return length(cbds)>0\n", + "end\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f08ede4a-b311-4ab4-be87-593f31ba652d", + "metadata": {}, + "outputs": [], + "source": [ + "## Get whether the test resulted in a confirmed (lesion or culture)\n", + "@time SAM_joined.confirmedBreakdown = confirmedBreakdown.(SAM_joined.ctAssetPK,SAM_joined.ctTestDate,90)\n", + ";" + ] + }, + { + "cell_type": "markdown", + "id": "ac11d232-667f-4666-92e4-685b06d8284f", + "metadata": {}, + "source": [ + "# Gamma testing (as proxy for badger culling?)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "74b3bd7e-ee23-4c60-b2ac-2fc1a5dbccf5", + "metadata": {}, + "outputs": [], + "source": [ + "## Group gamma tests by asset (for fast searching)\n", + "SAM_Test_Gamma_byAsset = groupby(SAM_Test_Gamma,:ctAssetPK)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cd819c1-b7be-4219-bc2d-6774dbbf7716", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: gammaCount(assetPK,testDate)\n", + "## Gets then number of gamma tests per asset prior to test date\n", + "function gammaCount(assetPK,testDate)\n", + " if (assetPK,) in keys(SAM_Test_Gamma_byAsset)\n", + " h = SAM_Test_Gamma_byAsset[(assetPK,)]\n", + " return nrow(h[h.ctTestDate.<=testDate,:])\n", + " else\n", + " return 0\n", + " end\n", + "end;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "80f23c58-5f54-4d7e-a5f5-a144d288e5fd", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "## Get the number of prior gamma tests in the herd\n", + "@time SAM_joined.gammaTestCount = gammaCount.(SAM_joined.ctAssetPK,SAM_joined.ctTestDate)\n", + ";" + ] + }, + { + "cell_type": "markdown", + "id": "636e0897-4da7-4733-aad9-9a846198fd55", + "metadata": {}, + "source": [ + "# Prepare data from CTS" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "357bb64b-6725-4ab4-8528-bb2cffc2a18c", + "metadata": {}, + "outputs": [], + "source": [ + "## Load CTS movements data:\n", + "move_cols = [:MovementDate,:OffLocationKey,:OnLocationKey,:Birth,:Death]\n", + "CTS_moveT = CSV.read(\"/Data/CTS/tblMovementTransition.csv\", DataFrame, select=move_cols, dateformat=\"yyyy-mm-dd HH:MM:SS\")\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "204030d0-571c-4100-858d-4fcd4cb7a35e", + "metadata": {}, + "outputs": [], + "source": [ + "## Remove missing dates\n", + "dropmissing!(CTS_moveT,:MovementDate)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46b061b3-9334-4a35-86e3-9d76e4c9a5fe", + "metadata": {}, + "outputs": [], + "source": [ + "## Since 2012\n", + "CTS_moveT = CTS_moveT[CTS_moveT.MovementDate.>=Date(\"2012\"),:]\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f0f611a9-1be8-4467-94a0-b7f79fb56b5a", + "metadata": {}, + "outputs": [], + "source": [ + "## Convert DateTimes to Dates\n", + "CTS_moveT.MovementDate = convert.(Date,CTS_moveT.MovementDate)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "69da6176-3917-46c9-b299-c661b2eb5bab", + "metadata": {}, + "outputs": [], + "source": [ + "## Load CTS locations:\n", + "loc_cols = [:LocationKey,:CurrentSamCPH]\n", + "CTS_loc = CSV.read(\"/Data/CTS/tblLocation_fixed.csv\",DataFrame,select=loc_cols)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4f41834-f62f-4e86-b5d6-1a38c9410850", + "metadata": {}, + "outputs": [], + "source": [ + "## Map CTS Location IDs to CPH (Int)\n", + "lockey2cph_map = Dict(zip(CTS_loc.LocationKey, CTS_loc.CurrentSamCPH))\n", + "function lockey2cph(x)\n", + " if x in keys(lockey2cph_map)\n", + " lockey2cph_map[x]\n", + " else \n", + " missing\n", + " end\n", + "end\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2905385-aced-47c2-92a4-54d0ec0bea41", + "metadata": {}, + "outputs": [], + "source": [ + "## Add Off/On location CPHs\n", + "CTS_moveT.OffCPH = lockey2cph.(CTS_moveT.OffLocationKey)\n", + "CTS_moveT.OnCPH = lockey2cph.(CTS_moveT.OnLocationKey)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "267eed71-d9cf-4aad-bbba-301abf186426", + "metadata": {}, + "outputs": [], + "source": [ + "## Group by CPH (on and off) for fast searching\n", + "CTS_moveT_byOnCPH = groupby(CTS_moveT,:OnCPH)\n", + "CTS_moveT_byOffCPH = groupby(CTS_moveT,:OffCPH)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf16b072-9e48-4e29-87f6-408c46381d76", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: inflow(CPH,testDate,d)\n", + "## Gets the number of animals moved onto asset within d days prior to test date\n", + "function inflow(cph_i,testDate,d)\n", + " cph = string(cph_i)\n", + " if (cph,) in keys(CTS_moveT_byOnCPH)\n", + " h = CTS_moveT_byOnCPH[(cph,)]\n", + " return nrow(h[(h.MovementDate.<=testDate).&(h.MovementDate.>testDate-Day(d)),:])\n", + " else\n", + " return 0\n", + " end\n", + "end;\n", + "\n", + "## Function: outflow(CPH,testDate,d)\n", + "## Gets the number of animals moved off asset within d days prior to test date\n", + "function outflow(cph_i,testDate,d)\n", + " cph = string(cph_i)\n", + " if (cph,) in keys(CTS_moveT_byOffCPH)\n", + " h = CTS_moveT_byOffCPH[(cph,)]\n", + " return nrow(h[(h.MovementDate.<=testDate).&(h.MovementDate.>testDate-Day(d)),:])\n", + " else\n", + " return 0\n", + " end\n", + "end;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "473654ed-52a0-4a14-bda0-d6e7798c9918", + "metadata": {}, + "outputs": [], + "source": [ + "## Get the number number of animals moved into the herd within d days prior to test\n", + "@time SAM_joined.inflow1 = inflow.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 365)\n", + "@time SAM_joined.inflow2 = inflow.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 365*2)\n", + "@time SAM_joined.inflow4 = inflow.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 356*4)\n", + "@time SAM_joined.inflow90 = inflow.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 90)\n", + ";\n", + "\n", + "## Get the number number of animals moved out of the herd within d days prior to test\n", + "@time SAM_joined.outflow1 = outflow.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 365)\n", + "@time SAM_joined.outflow2 = outflow.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 365*2)\n", + "@time SAM_joined.outflow4 = outflow.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 365*4)\n", + "@time SAM_joined.outflow90 = outflow.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 90)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "02628ae6-3dc0-4236-84d4-44e5e1768af2", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: breakdown_within(assetPK,testDate,d)::Bool\n", + "## Does the location have a confirmed breakdown within d days of testDate? (Before or after)\n", + "function breakdown_within(loc,testDate,d)\n", + " if ismissing(loc)\n", + " return false\n", + " end\n", + " if (loc,) in keys(SAM_Breakdown_byAsset)\n", + " # get breakdowns at location\n", + " bd = SAM_Breakdown_byAsset[(loc,)]\n", + " # filter missing confirm date\n", + " bd = dropmissing(bd,:cbConfDate)\n", + " # filter confirmed within d days\n", + " bd = bd[(bd.cbConfDate.>=testDate-Day(d)).&(bd.cbConfDate.0\n", + " else\n", + " return false\n", + " end\n", + "end;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e14c4ceb-1f47-4a35-af18-8f5cdc3bfdff", + "metadata": {}, + "outputs": [], + "source": [ + "## Function inflow_breakdown(CPH,testDate,d, b)\n", + "## Gets the number of animals moved onto the farm, within d days, from farms that had a breakdown within b days.\n", + "function inflow_breakdown(cph_i,testDate,d,b)\n", + " cph = string(cph_i)\n", + " if (cph,) in keys(CTS_moveT_byOnCPH)\n", + " # get moves onto location\n", + " h = CTS_moveT_byOnCPH[(cph,)]\n", + " # filter within d days of test date\n", + " h = h[(h.MovementDate.<=testDate).&(h.MovementDate.>testDate-Day(d)),:]\n", + " # filter for moves from CPH with breakdown within b days\n", + " h = h[breakdown_within.(h.OffLocationKey,testDate,b),:]\n", + " # return count\n", + " return nrow(h)\n", + " else\n", + " return 0\n", + " end\n", + "end;\n", + "\n", + "## Function outflow_breakdown(CPH,testDate,d, b)\n", + "## Gets the number of animals moved off the farm, within d days, onto farms that had a breakdown within b days.\n", + "function outflow_breakdown(cph_i,testDate,d,b)\n", + " cph = string(cph_i)\n", + " if (cph,) in keys(CTS_moveT_byOffCPH)\n", + " # get moves off location\n", + " h = CTS_moveT_byOffCPH[(cph,)]\n", + " # filter within d days of test date\n", + " h = h[(h.MovementDate.<=testDate).&(h.MovementDate.>testDate-Day(d)),:]\n", + " # filter for moves from CPH with breakdown within b days\n", + " h = h[breakdown_within.(h.OnLocationKey,testDate,b),:]\n", + " # return count\n", + " return nrow(h)\n", + " else\n", + " return 0\n", + " end\n", + "end;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "591db580-93e1-4733-af83-b7b15f8d2538", + "metadata": {}, + "outputs": [], + "source": [ + "## Get the number number of animals moved into the herd within d days prior to test\n", + "## from farms that had a breakdown within 2 years\n", + "@time SAM_joined.inflowBD1 = inflow_breakdown.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 365, 365*2)\n", + "@time SAM_joined.inflowBD2 = inflow_breakdown.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 365*2, 365*2)\n", + "@time SAM_joined.inflowBD4 = inflow_breakdown.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 365*4, 365*2)\n", + "@time SAM_joined.inflowBD90 = inflow_breakdown.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 90, 365*2)\n", + "\n", + "## Get the number number of animals moved off the herd within d days prior to test\n", + "## to farms that had a breakdown within 2 years\n", + "@time SAM_joined.outflowBD1 = outflow_breakdown.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 365, 365*2)\n", + "@time SAM_joined.outflowBD2 = outflow_breakdown.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 365*2, 365*2)\n", + "@time SAM_joined.outflowBD4 = outflow_breakdown.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 365*4, 365*2)\n", + "@time SAM_joined.outflowBD90 = outflow_breakdown.(SAM_joined.ctCPH, SAM_joined.ctTestDate, 90, 365*2)\n", + ";" + ] + }, + { + "cell_type": "markdown", + "id": "62a0c989-5db0-4798-a1c9-dc5dbda4b3b4", + "metadata": {}, + "source": [ + "## TODO:\n", + "\n", + "* ~~Inflow / outflow from breakdown farms~~\n", + "* ~~farm type /~~\n", + "* ~~test type~~ / risk area\n", + "* Age / breed\n", + "* Herd on/off restriction\n", + "* Defra risk score?\n", + "* Seasonality (month of year?)\n", + "* Birth/deaths need linking from Animals table (no births in movements, despite column, deaths may be just moves to slaughter?)\n", + "* Moves from HRA herds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92869f17-48ce-4389-99ae-bdec098b584d", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: births(CPH,testDate,d)\n", + "## Gets the number of animal births onto asset within d days prior to test date\n", + "function births(cph,testDate,d)\n", + " if (cph,) in keys(CTS_moveT_byOffCPH)\n", + " h = CTS_moveT_byOffCPH[(cph,)]\n", + " return nrow(h[(h.MovementDate.<=testDate).&(h.MovementDate.>testDate-Day(d)).&(h.Birth),:])\n", + " else\n", + " return 0\n", + " end\n", + "end;\n", + "\n", + "##\n", + "\n", + "## Function: deaths(CPH,testDate,d)\n", + "## Gets the number of animal deaths at asset within d days prior to test date\n", + "function deaths(cph,testDate,d)\n", + " if (cph,) in keys(CTS_moveT_byOffCPH)\n", + " h = CTS_moveT_byOffCPH[(cph,)]\n", + " return nrow(h[(h.MovementDate.<=testDate).&(h.MovementDate.>testDate-Day(d)).&(h.Death),:])\n", + " else\n", + " return 0\n", + " end\n", + "end;" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9dcc56a0-66a5-4bc4-9aa2-24f3208e6d24", + "metadata": {}, + "outputs": [], + "source": [ + "## Birth/deaths need linking from Animals table (no births in movements, despite column, deaths may be just moves to slaughter?)" + ] + }, + { + "cell_type": "markdown", + "id": "a22ab481-ad9b-44c9-96ce-8db4afe302d5", + "metadata": {}, + "source": [ + "# Add Vet data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8edaa04-d102-44fb-8134-3899ef62505c", + "metadata": {}, + "outputs": [], + "source": [ + "## Load Vet/Tuberculin data from UKFarmcare\n", + "Vet_data = CSV.read(\"/Data/TB_Diagnostics/vetData.csv\", DataFrame)\n", + "Vet_data_noCat = CSV.read(\"/Data/TB_Diagnostics/vetData_nonCat.csv\", DataFrame) #same without categorisation\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3571bef6-f0cb-4f82-9167-a67a4938d416", + "metadata": {}, + "outputs": [], + "source": [ + "## Join the vet data with the SAM data\n", + "SAM_joined_noCat = leftjoin(SAM_joined, Vet_data_noCat, on = [:ctTestDate => :Date, :ctCphh_fmt => :CPH], validate=(false, true)) #first without top 250 ctegrisation of vet data\n", + "SAM_joined = leftjoin(SAM_joined, Vet_data, on = [:ctTestDate => :Date, :ctCphh_fmt => :CPH], validate=(false, true)) #then with categorisation\n", + ";" + ] + }, + { + "cell_type": "markdown", + "id": "97fd2a05-e4b2-4b3b-9695-6bd504a13ef4", + "metadata": {}, + "source": [ + "# Badger data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04c8601b-14f4-4320-b8a9-e43962a4a698", + "metadata": {}, + "outputs": [], + "source": [ + "# Load badger abundance\n", + "badger_data = CSV.read(\"/Data/TB_Diagnostics/Badgers/badgersAbundancePerHexCell.csv\", DataFrame)\n", + "# Load cell locations\n", + "badger_cells = CSV.read(\"/Data/TB_Diagnostics/Badgers/locationsPerHexCell.csv\", DataFrame)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f27f8913-5db4-4efa-9731-1a15f7500e1a", + "metadata": {}, + "outputs": [], + "source": [ + "# match CPHs to badger abundance\n", + "badger_cph = leftjoin(badger_cells,badger_data, on=:hexCellID)[:,[:CPH,:meanBadgerAbundance]]\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb5823f3-5ca5-4310-ac1d-844707b03eb7", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# Join with SAM data\n", + "SAM_joined = leftjoin(SAM_joined,badger_cph,on=:ctCPH=>:CPH)\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38b1d963-4bed-44bc-98bd-fd2e394da8d0", + "metadata": {}, + "outputs": [], + "source": [ + "(ismissing.(SAM_joined.meanBadgerAbundance)|>sum) / nrow(SAM_joined)" + ] + }, + { + "cell_type": "markdown", + "id": "c7604d8d-c913-4d55-bf92-0017788d571c", + "metadata": {}, + "source": [ + "# Extract features:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "abd4dd96-1f67-49be-958b-e6d6e7c5d9c9", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: cat2int(x)\n", + "## Transform a vector of categorical values into integers representing each category\n", + "function cat2int(v)\n", + " s = Set(v)\n", + " d = Dict(collect(zip(s,1:length(s))))\n", + " map(x->d[x], v)\n", + "end" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e86c5d4-e8a9-45cc-9a40-6e16d539c0ea", + "metadata": {}, + "outputs": [], + "source": [ + "## Construct Input Vars table\n", + "inputVars = DataFrame()\n", + "\n", + "## From SAM\n", + "inputVars.dateOfTest = SAM_joined.ctTestDate\n", + "inputVars.resultOfTest = SAM_joined.ctResultOfTest.==\"NC\"\n", + "inputVars.monthOfTest = month.(SAM_joined.ctTestDate)\n", + "inputVars.severe = SAM_joined.ctInterp.==\"SEVERE\"\n", + "inputVars.animalsTested = SAM_joined.ctNumber\n", + "inputVars.locationX = SAM_joined.chMapX\n", + "inputVars.locationY = SAM_joined.chMapY\n", + "inputVars.previousTestResult = SAM_joined.previousResultOfTest.==\"NC\"\n", + "inputVars.previousTestResult2 = SAM_joined.previousResultOfTest2.==\"NC\"\n", + "inputVars.daysSincePreviousTest = SAM_joined.daysSincePreviousTest\n", + "inputVars.daysSinceBreakdown = SAM_joined.daysSinceBreakdown\n", + "inputVars.gammaTestCount = SAM_joined.gammaTestCount\n", + "inputVars.testType = cat2int(SAM_joined.ctTestType_FOI)\n", + "inputVars.herdType = cat2int(SAM_joined.chType)\n", + "inputVars.herdSize = SAM_joined.chHerdSize_CTS\n", + "inputVars.defraRiskScore = SAM_joined.defraRiskScore\n", + "\n", + "## From CTS\n", + "inputVars.inflow1 = SAM_joined.inflow1\n", + "inputVars.inflow2 = SAM_joined.inflow2\n", + "inputVars.inflow4 = SAM_joined.inflow4\n", + "inputVars.inflow90 = SAM_joined.inflow90\n", + "inputVars.outflow1 = SAM_joined.outflow1\n", + "inputVars.outflow2 = SAM_joined.outflow2\n", + "inputVars.outflow4 = SAM_joined.outflow4\n", + "inputVars.outflow90 = SAM_joined.outflow90\n", + "\n", + "inputVars.inflowBD1 = SAM_joined.inflowBD1\n", + "inputVars.inflowBD2 = SAM_joined.inflowBD2\n", + "inputVars.inflowBD4 = SAM_joined.inflowBD4\n", + "inputVars.inflowBD90 = SAM_joined.inflowBD90\n", + "inputVars.outflowBD1 = SAM_joined.outflowBD1\n", + "inputVars.outflowBD2 = SAM_joined.outflowBD2\n", + "inputVars.outflowBD4 = SAM_joined.outflowBD4\n", + "inputVars.outflowBD90 = SAM_joined.outflowBD90\n", + "\n", + "## From vet data\n", + "inputVars.vetPractice = SAM_joined.Practice\n", + "inputVars.batchBovine = SAM_joined.BatchBovine\n", + "inputVars.batchAvian = SAM_joined.BatchAvian\n", + "\n", + "## Badgers\n", + "inputVars.meanBadgerAbundance = SAM_joined.meanBadgerAbundance\n", + "\n", + "## Target var (breakdown within 90 days)\n", + "inputVars.confirmedBreakdown = SAM_joined.confirmedBreakdown\n", + ";" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "172deb53-7da0-4218-aa46-b34c74cba1ce", + "metadata": {}, + "outputs": [], + "source": [ + "## Uncategorised verison\n", + "inputVars_noCat = copy(inputVars)\n", + "inputVars_noCat.vetPractice = SAM_joined_noCat.Practice\n", + "inputVars_noCat.batchBovine = SAM_joined_noCat.BatchBovine\n", + "inputVars_noCat.batchAvian = SAM_joined_noCat.BatchAvian\n", + "inputVars_noCat.herdType = SAM_joined.chType\n", + "inputVars_noCat.testType = SAM_joined.ctTestType_FOI\n", + ";" + ] + }, + { + "cell_type": "markdown", + "id": "6fdf60a2-41bc-4b1d-bb75-0acae7e27ffa", + "metadata": {}, + "source": [ + "# Write to file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1957496-8c20-413b-8383-daa1507dcf7c", + "metadata": {}, + "outputs": [], + "source": [ + "## Store Input Vars table\n", + "CSV.write(\"/Data/TB_Diagnostics/inputVars.csv\",inputVars)\n", + "CSV.write(\"/Data/TB_Diagnostics/inputVars_noCat.csv\",inputVars_noCat) #uncategorised version" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2f1eaf70-f256-4e25-a0cc-da7c6175bc4d", + "metadata": {}, + "outputs": [], + "source": [ + "## Write file without dates\n", + "##CSV.write(\"/Data/TB_Diagnostics/inputVars_nodate.csv\",inputVars[:,2:end])" + ] + }, + { + "cell_type": "markdown", + "id": "ea694377-11da-43e4-9a35-38e95936c51b", + "metadata": {}, + "source": [ + "---\n", + "---\n", + "# Testing:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cddf5cac-b041-4f5d-9d78-42532f08b09d", + "metadata": {}, + "outputs": [], + "source": [ + "#SAM_Test[SAM_Test.ctTestDate.==Date(\"2019-01-01\"),:]" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Julia 1.10.0", + "language": "julia", + "name": "julia-1.10" + }, + "language_info": { + "file_extension": ".jl", + "mimetype": "application/julia", + "name": "julia", + "version": "1.10.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Data_Curation_UKFarmcare.ipynb b/Data_Curation_UKFarmcare.ipynb new file mode 100644 index 0000000..6eba662 --- /dev/null +++ b/Data_Curation_UKFarmcare.ipynb @@ -0,0 +1,425 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "2547b032-1cc9-4e95-bbee-d0793bf0bc0b", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import geopandas as gp\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sb\n", + "import numpy as np\n", + "from scipy.stats import binomtest" + ] + }, + { + "cell_type": "markdown", + "id": "cf0cc169-1344-4f8d-8280-e6cca7a62133", + "metadata": {}, + "source": [ + "# Prepare veterinary data" + ] + }, + { + "cell_type": "markdown", + "id": "2d2f3add-f582-4ce3-b9bf-d9681bb3b3d5", + "metadata": {}, + "source": [ + "This script prepares the veterinary and tuberculin batch data from UKFarmcare for the TB Diagnostics model." + ] + }, + { + "cell_type": "markdown", + "id": "97f08ba6-9459-4ad0-8fac-6df76a8e4977", + "metadata": {}, + "source": [ + "## Load raw data" + ] + }, + { + "cell_type": "markdown", + "id": "7f9af1ba-ded0-4175-9860-fff15c2ff904", + "metadata": {}, + "source": [ + "The data is supplied in Excel format, encrypted, with a sheet per year. Here we load indinviual year sheets that have been extracted as CSVs.\n", + "There are no batch numbers for 2015." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be8c1d57-6dbd-4e75-a31c-e46671cf39e7", + "metadata": {}, + "outputs": [], + "source": [ + "cols = ['Date','CPH','Practice','Batch Numbers']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "28ef15b0-05da-449b-a031-dc9efc77a514", + "metadata": {}, + "outputs": [], + "source": [ + "data_2015 = pd.read_csv('/Data/TB_Diagnostics/TB_Vet_Data/TestData_2015.csv', usecols=cols[:-1], dtype=str, parse_dates=['Date'], dayfirst=True)\n", + "data_2016 = pd.read_csv('/Data/TB_Diagnostics/TB_Vet_Data/TestData_2016.csv', usecols=cols, dtype=str, parse_dates=['Date'], dayfirst=True)\n", + "data_2017 = pd.read_csv('/Data/TB_Diagnostics/TB_Vet_Data/TestData_2017.csv', usecols=cols, dtype=str, parse_dates=['Date'], dayfirst=True)\n", + "data_2018 = pd.read_csv('/Data/TB_Diagnostics/TB_Vet_Data/TestData_2018.csv', usecols=cols, dtype=str, parse_dates=['Date'], dayfirst=True)\n", + "data_2019 = pd.read_csv('/Data/TB_Diagnostics/TB_Vet_Data/TestData_2019.csv', usecols=cols, dtype=str, parse_dates=['Date'], dayfirst=True)" + ] + }, + { + "cell_type": "markdown", + "id": "50000d46-97a4-4e74-9d28-1572127ed951", + "metadata": {}, + "source": [ + "## Concatenate all years" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08a9fc38-ed77-42bf-9a27-1e99c7fa563a", + "metadata": {}, + "outputs": [], + "source": [ + "data = pd.concat([data_2015,data_2016,data_2017,data_2018,data_2019], ignore_index=True)" + ] + }, + { + "cell_type": "markdown", + "id": "060472ae-20bb-4849-aedc-c1bd538be6ec", + "metadata": {}, + "source": [ + "## Clean data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2324d7d6-3e60-43b0-a1e9-90025f9932c0", + "metadata": {}, + "outputs": [], + "source": [ + "# Remove any rows with NAs in both practice and batch:\n", + "data = data.dropna(how='all', subset=['Practice','Batch Numbers'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4d77bd41-8728-4968-a123-0857ca414e7d", + "metadata": {}, + "outputs": [], + "source": [ + "# Remove * suffix from practice names:\n", + "data.loc[:,'Practice'] = data.Practice.str.replace('*','', regex=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b8ef9ec3-8949-447e-94d9-d79d85248066", + "metadata": {}, + "outputs": [], + "source": [ + "# Drop duplicates\n", + "data = data.drop_duplicates(subset=['Date','CPH'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3164db90-5555-491d-ac84-fe57d8f33b6b", + "metadata": {}, + "outputs": [], + "source": [ + "data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d830750d-478f-4e69-a601-62d02bccf789", + "metadata": {}, + "outputs": [], + "source": [ + "data.dropna()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "667b4cf7-5c89-4582-8a73-7eebfffefacf", + "metadata": {}, + "outputs": [], + "source": [ + "data.Practice.value_counts()" + ] + }, + { + "cell_type": "markdown", + "id": "eb9915ea-5f63-4922-abdf-561831c1d7b9", + "metadata": {}, + "source": [ + "## Parse batch numbers" + ] + }, + { + "cell_type": "markdown", + "id": "021e8351-73fe-4c98-8158-69f17d2f14f9", + "metadata": {}, + "source": [ + "Unfortunately, batch numbers come in a bewildering array of formats. Gernerally, however the Avian batch is the first six digit number and the Bovine batch the second. Parse according ot this rule, stripping away any other text." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "119d2fe1-d4c6-4ae1-8f51-a77bb065457a", + "metadata": {}, + "outputs": [], + "source": [ + "# split out the batch numbers\n", + "# (here we're assuming the first number is avian, the second is bovine)\n", + "batch_split = data['Batch Numbers'].str.split(\"[^0-9]\").str.join(' ').str.split(expand=True,n=1).dropna()\n", + "batch_split = batch_split.rename(columns={0:'BatchAvian',1:'BatchBovine'})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "747afd01-a2b6-4bf2-85b8-4153cfd6afe4", + "metadata": {}, + "outputs": [], + "source": [ + "# Limit to 6 digit numbers\n", + "batch_split = batch_split[batch_split.BatchAvian.str.contains('^[0-9]{6}$')]\n", + "batch_split = batch_split[batch_split.BatchBovine.str.contains('^[0-9]{6}$')]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66d99dd7-a32b-4a83-aa7d-c35e43d9436a", + "metadata": {}, + "outputs": [], + "source": [ + "# rejoin with data\n", + "data = data.join(batch_split)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d5f15ec-006f-411e-9ce7-785200a18c1d", + "metadata": {}, + "outputs": [], + "source": [ + "data.BatchBovine.value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d79eb30-1c14-400f-b779-11d506b6aedf", + "metadata": {}, + "outputs": [], + "source": [ + "data.BatchAvian.value_counts()" + ] + }, + { + "cell_type": "markdown", + "id": "d0c5fdbf-3263-41af-8eed-b6d1c47238c6", + "metadata": {}, + "source": [ + "## Analysis" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56b9532b-0582-4496-9aff-c47c869aedfd", + "metadata": {}, + "outputs": [], + "source": [ + "data.Practice.value_counts().plot.bar(figsize=(60,10))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2b8afbc-414b-4489-8080-6fdcb991b3c6", + "metadata": {}, + "outputs": [], + "source": [ + "data.BatchBovine.value_counts().plot.bar(figsize=(60,10))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f249c969-fab6-40dc-88ed-6ca1251decfb", + "metadata": {}, + "outputs": [], + "source": [ + "data.BatchAvian.value_counts().plot.bar(figsize=(60,10))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e4b84293-49e3-404f-bd6f-c2efd0a55dbf", + "metadata": {}, + "outputs": [], + "source": [ + "fig = plt.figure(figsize=(20,5))\n", + "ax = plt.subplot()\n", + "data.dropna(subset=['Practice']).Date.value_counts().sort_index().resample('M').sum().plot.bar(ax=ax, color='red',alpha=0.5,label='Practice')\n", + "data.dropna(subset=['Batch Numbers']).Date.value_counts().sort_index().resample('M').sum().plot.bar(ax=ax,color='blue',alpha=0.5,label='Batch')\n", + "plt.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2353d150-38d9-4e3f-b295-abb54435fc3e", + "metadata": {}, + "outputs": [], + "source": [ + "data" + ] + }, + { + "cell_type": "markdown", + "id": "ea0525e9-36f1-4624-af9f-380958656232", + "metadata": {}, + "source": [ + "## Raw output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4fd74b8-65d8-41b9-b290-f3aa30480ea1", + "metadata": {}, + "outputs": [], + "source": [ + "data.drop(columns=['Batch Numbers']).to_csv('/Data/TB_Diagnostics/vetData_nonCat.csv',index=False)" + ] + }, + { + "cell_type": "markdown", + "id": "c6a6f64d-e6da-4fe6-9e0e-d1d80f71d88a", + "metadata": {}, + "source": [ + "## Categorical encoding" + ] + }, + { + "cell_type": "markdown", + "id": "7d7803ee-9d81-4fb6-a544-bf400cfe0fff", + "metadata": {}, + "source": [ + "Practice and Batch data are nominal, high-cardinality features, so we need to encode them down to at most 255 categories for Histogram-based GBT, preferably lower for better computational performance (this is traded off with predictive performance...).\n", + "\n", + "One method to do this is Bayesian LeaveOneOut encoding [REF?], but this requires comparison to the target variable.\n", + "Another is Hashing, but this splits the feature into multiple features, losing explainability...\n", + "\n", + "We choose here to take the 250 most frequent categories and an \"other\" cetegory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d23d6eae-a23c-4451-9ab6-5526ca037023", + "metadata": {}, + "outputs": [], + "source": [ + "# Function to map categorical feature to an ID, grouping any beyond the top 250 into one ID and NaNs into one ID\n", + "#def map_feature_to_category_id(feature):\n", + "# #function for limititing to 250\n", + "# def top250(x): return x if x<250 else 250\n", + "# size_order = list(feature.value_counts(dropna=False).index)\n", + "# ids = list(map(top250,list(range(len(size_order)))))\n", + "# index = {size_order[i]:ids[i] for i in range(len(size_order))}\n", + "# return feature.apply(lambda x:index[x])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7730755f-7426-40e8-89ec-74cd25a749f9", + "metadata": {}, + "outputs": [], + "source": [ + "def map_feature_to_category_id(feature):\n", + " size_order = list(feature.value_counts().index)\n", + " mapping = dict(zip(size_order, list(map(lambda x: min(x,250), range(1,len(size_order)+1)))))\n", + " def catmap(x):\n", + " if pd.isna(x):\n", + " return x\n", + " else:\n", + " return mapping[x]\n", + " return feature.apply(catmap)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64bac2fe-f32f-4d36-8b90-9f0fad7ec59e", + "metadata": {}, + "outputs": [], + "source": [ + "# Apply the categorical mapping to the data\n", + "data.Practice = map_feature_to_category_id(data.Practice)\n", + "data.BatchAvian = map_feature_to_category_id(data.BatchAvian)\n", + "data.BatchBovine = map_feature_to_category_id(data.BatchBovine)" + ] + }, + { + "cell_type": "markdown", + "id": "95f272fd-f89e-4f0b-b072-a7b13f7782f9", + "metadata": {}, + "source": [ + "## Categorical output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fa0ed840-f843-4bd3-804f-8fa466f0e111", + "metadata": {}, + "outputs": [], + "source": [ + "data.drop(columns=['Batch Numbers']).to_csv('/Data/TB_Diagnostics/vetData.csv',index=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Data_Curation_VetOnly.ipynb b/Data_Curation_VetOnly.ipynb new file mode 100644 index 0000000..52e1749 --- /dev/null +++ b/Data_Curation_VetOnly.ipynb @@ -0,0 +1,68 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "4fafef82-3964-415b-9979-7c96f76c73ad", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53aa7adc-fc8f-4a34-9b00-80b78c5cda6b", + "metadata": {}, + "outputs": [], + "source": [ + "# Load dataset\n", + "data = pd.read_csv('/Data/TB_Diagnostics/inputVars.csv', low_memory=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "81e05216-f5d8-4c13-8997-53d5f7997ee0", + "metadata": {}, + "outputs": [], + "source": [ + "# Remove rows with no vet practice information\n", + "data_vet_only = data.dropna(subset=['vetPractice'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a2387f4-2943-4864-a5a5-ca4e6e87f9e4", + "metadata": {}, + "outputs": [], + "source": [ + "# Output\n", + "data_vet_only.to_csv('/Data/TB_Diagnostics/inputVars_VetOnly.csv', index=False)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/README b/README new file mode 100644 index 0000000..94f43a9 --- /dev/null +++ b/README @@ -0,0 +1,13 @@ +# bTB-diagnostics + +This project trains a Histogram Boosted Regression Tree model on data from Bovine Tuberculosis (bTB) testing and cattle herd metadata to predict the risk of bTB outbreak. + +This can be used to improve the herd-level sensitivity or specificity of the diagnostic test and also to analyse the risk factors involved in predicting bTB outbreaks. + +The project consists of a number of Jupyter Notebooks: +(i) Data_Curation* -- processes the various inpiut data into a matrix for model training. +(ii) bTB-Diagnostic_2020_v4_crossVal+tuning* -- code that trains the various models. +(iii) bTB-Diagnostic_2020_final_model* -- code that performs various analysis on the models. +(iv) Vet_Data_Analysis -- code that performs some extra analysis on the veterinary data. + +Further details can be found in the preprint (paper in sumbission for peer review) at: https://arxiv.org/abs/2404.03678 \ No newline at end of file diff --git a/Vet_Data_Analysis.ipynb b/Vet_Data_Analysis.ipynb new file mode 100644 index 0000000..5f8708b --- /dev/null +++ b/Vet_Data_Analysis.ipynb @@ -0,0 +1,769 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "e09f519e-9ef9-4cfd-8c0d-967826b12bb6", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sb\n", + "import numpy as np\n", + "from scipy.stats import binomtest\n", + "from scipy.stats import pearsonr\n", + "import geopandas as gp\n", + "from joblib import load" + ] + }, + { + "cell_type": "markdown", + "id": "993ad173-3bea-40b4-9de1-74739c1b5a18", + "metadata": {}, + "source": [ + "# Post-mapping analysis" + ] + }, + { + "cell_type": "markdown", + "id": "0181c2c3-5c83-42e3-92b9-e60676a44737", + "metadata": {}, + "source": [ + "~This uses the data mapped to top 250 categories, plus an 'other' category. Practice names are not retained in this data.~\n", + "\n", + "Replaced by feature data with vet names for deeper analysis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab70a210-ea0e-4a4d-aebd-14b0daaa587f", + "metadata": {}, + "outputs": [], + "source": [ + "# Load practice/batch with test result feature table\n", + "#inputVars = pd.read_csv('/Data/TB_Diagnostics/inputVars_NEW.csv', dtype=float, parse_dates=['dateOfTest'])\n", + "inputVars = pd.read_csv('/Data/TB_Diagnostics/inputVars_noCat.csv', parse_dates=['dateOfTest'],low_memory=False)\n", + "inputVars_model = pd.read_csv('/Data/TB_Diagnostics/inputVars.csv', parse_dates=['dateOfTest'],low_memory=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72c3836f-3e8b-40e6-bfeb-c677a77885ba", + "metadata": {}, + "outputs": [], + "source": [ + "# Split into features(X)/target(y) for model validation later\n", + "data_y = inputVars.confirmedBreakdown.to_numpy().astype(bool)\n", + "data_X = inputVars.drop(columns=['confirmedBreakdown'])\n", + "data_y_model = inputVars_model.confirmedBreakdown.to_numpy().astype(bool)\n", + "data_X_model = inputVars_model.drop(columns=['confirmedBreakdown'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46b04d26-fdca-4e4e-a6cf-902e2d8fe97f", + "metadata": {}, + "outputs": [], + "source": [ + "inputVars" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7208cfdd-a84c-45e7-9951-a00c2d2d78d9", + "metadata": {}, + "outputs": [], + "source": [ + "# Proportion of positive tests by vet\n", + "practice_test_sum = inputVars.groupby('vetPractice')['resultOfTest'].sum()\n", + "practice_test_count = inputVars.groupby('vetPractice')['resultOfTest'].count()\n", + "sb.histplot(practice_test_sum/practice_test_count)\n", + "plt.title('Proportion of positve tests by vet practice')\n", + "plt.xlabel('Positive Tests')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9a02bac-839c-40df-8d65-59b827970968", + "metadata": {}, + "outputs": [], + "source": [ + "# Proportion of positve tests by tuberculin batch (Bovine)\n", + "batch_test_sum = inputVars.groupby('batchBovine')['resultOfTest'].sum()\n", + "batch_test_count = inputVars.groupby('batchBovine')['resultOfTest'].count()\n", + "sb.histplot(batch_test_sum/batch_test_count)\n", + "plt.title('Proportion of positve tests by tuberculin batch (Bovine)')\n", + "plt.xlabel('Positive Tests')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87e99148-33d7-4d44-914a-38875dca69ac", + "metadata": {}, + "outputs": [], + "source": [ + "# Proportion of residuals by vet (where a negative (/posive) test was (/not) followed by a breakdown)\n", + "inputVars['residual'] = (inputVars.resultOfTest != inputVars.confirmedBreakdown)\n", + "practice_residual_sum = inputVars.groupby('vetPractice')['residual'].sum()\n", + "practice_residual_count = inputVars.groupby('vetPractice')['residual'].count()\n", + "sb.histplot(practice_residual_sum/practice_residual_count)\n", + "plt.title('Proportion of residuals by vet practice')\n", + "plt.xlabel('Proportion of residuals')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba5d00b5-4dba-4a7d-951d-16014afe4cf5", + "metadata": {}, + "outputs": [], + "source": [ + "# Vet practice accuracy (test --> breakdown)\n", + "plt.rcParams.update({'font.size': 18})\n", + "inputVars['vet_acc'] = (inputVars.resultOfTest == inputVars.confirmedBreakdown)\n", + "practice_acc_sum = inputVars.groupby('vetPractice')['vet_acc'].sum()\n", + "practice_acc_count = inputVars.groupby('vetPractice')['vet_acc'].count()\n", + "sb.histplot(practice_acc_sum/practice_acc_count)\n", + "plt.title('Accuracy by vet practice')\n", + "plt.xlabel('Test accuracy')\n", + "plt.savefig('../Paper/figs/vet_acc.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8364f467-e915-4a7e-973a-85891fc6487f", + "metadata": {}, + "outputs": [], + "source": [ + "# Proportion of residuals by tuberculin batch (where a negative (/posive) test was (/not) followed by a breakdown)\n", + "batch_residual_sum = inputVars.groupby('batchBovine')['residual'].sum()\n", + "batch_residual_count = inputVars.groupby('batchBovine')['residual'].count()\n", + "sb.histplot(batch_residual_sum/batch_residual_count)\n", + "plt.title('Proportion of residuals by tuberculin batch (Bovine)')\n", + "plt.xlabel('Proportion of residuals')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ff85bc3-81c7-457f-9572-c1f32ba8b787", + "metadata": {}, + "outputs": [], + "source": [ + "# Binomial test for vet practices\n", + "expected_success = sum(inputVars.resultOfTest == inputVars.confirmedBreakdown) / len(inputVars)\n", + "pvals_vet = []\n", + "for i in inputVars.vetPractice.dropna().unique():\n", + " results_for_practice = inputVars[inputVars.vetPractice==i]\n", + " successes = sum(results_for_practice.resultOfTest == results_for_practice.confirmedBreakdown)\n", + " trials = len(results_for_practice)\n", + " pvals_vet.append(binomtest(successes,trials,expected_success).pvalue)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "595c709a-8ab3-4197-a776-b038cf40cc55", + "metadata": {}, + "outputs": [], + "source": [ + "# Distribution of p values\n", + "sb.histplot(pvals_vet,bins=20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef43c44b-94f0-4177-ad2b-1d2b6e2ed1b7", + "metadata": {}, + "outputs": [], + "source": [ + "# proportion of outliers\n", + "sum(np.array(pvals_vet)<0.05) / len(pvals_vet)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "613a006c-e33b-4390-99b9-fb0f5934e0b5", + "metadata": {}, + "outputs": [], + "source": [ + "# Binomial test for tuberculin batches\n", + "expected_success = sum(inputVars.resultOfTest == inputVars.confirmedBreakdown) / len(inputVars)\n", + "pvals_batch = []\n", + "for i in inputVars.batchBovine.dropna().unique():\n", + " results_for_batch = inputVars[inputVars.batchBovine==i]\n", + " successes = sum(results_for_batch.resultOfTest == results_for_batch.confirmedBreakdown)\n", + " trials = len(results_for_batch)\n", + " pvals_batch.append(binomtest(successes,trials,expected_success).pvalue)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d17006e-5a10-4297-9231-416fc49997ac", + "metadata": {}, + "outputs": [], + "source": [ + "# Distribution of p values\n", + "sb.histplot(pvals_batch,bins=20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9033f490-0dda-411b-945f-44c39ec3d3e2", + "metadata": {}, + "outputs": [], + "source": [ + "# proportion of outliers\n", + "sum(np.array(pvals_batch)<0.05) / len(pvals_batch)" + ] + }, + { + "cell_type": "markdown", + "id": "3641faa5-5569-4fd9-b3d5-6f5de602430b", + "metadata": {}, + "source": [ + "# Comparison of high/low performing practices" + ] + }, + { + "cell_type": "markdown", + "id": "3aea01e0-0900-4669-9b48-e0ecd3434709", + "metadata": {}, + "source": [ + "Compare practices to see if high/low perfomring ones have distinct features?\n", + "* Size of practice (number of tests/yr)\n", + "* Size of herds managed\n", + "* Location?\n", + "* ??" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "360126b5-06b7-46f2-b32d-96fd3b3cf058", + "metadata": {}, + "outputs": [], + "source": [ + "# Add year of test\n", + "inputVars['yearOfTest'] = inputVars.dateOfTest.apply(lambda x:x.year)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fd48ba6-e24b-4915-af86-24e72efe04a7", + "metadata": {}, + "outputs": [], + "source": [ + "# Practice accuracy\n", + "# Proportion of tests conducted that result in confirmed breakdown within 90 days\n", + "practice_accuracy = 1-(inputVars.groupby(['vetPractice'])['residual'].sum() / inputVars.groupby(['vetPractice'])['residual'].count())\n", + "practice_accuracy.name = \"accuracy\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70528396-957f-4ec9-b66f-ad39c0fcc792", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a dataframe to add other stats to\n", + "practice_stats = pd.DataFrame(practice_accuracy)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "803386cd-fa97-4662-885f-272952a730a0", + "metadata": {}, + "outputs": [], + "source": [ + "# Size of practice (by number of tests conducted)\n", + "practice_stats['numberOfTests'] = inputVars.groupby(['vetPractice'])['residual'].count()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07340349-8e89-4a75-b83d-8eb9c85aae36", + "metadata": {}, + "outputs": [], + "source": [ + "# Mean size of herds for practice\n", + "practice_stats['meanFarmSize'] = inputVars.groupby(['vetPractice'])['animalsTested'].mean()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "742b097a-8fdd-4e81-bc22-ef35f1c8d601", + "metadata": {}, + "outputs": [], + "source": [ + "practice_stats" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9aa98135-72a9-48f3-8084-d9670218eb1d", + "metadata": {}, + "outputs": [], + "source": [ + "# Drop the 'other' category, leaving only the top 250 practices by size\n", + "#practice_stats_top250 = practice_stats[practice_stats.index<250]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "83907681-737f-4fe2-875b-a898b1fc044a", + "metadata": {}, + "outputs": [], + "source": [ + "# Get only vets that have done at least 100 tests\n", + "practice_stats_100tests = practice_stats[practice_stats.numberOfTests>=100]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d49cef8-f27b-4416-8c3e-52e2e7400c82", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot accuracy by number of tests conducted\n", + "sb.jointplot(x='numberOfTests', y='accuracy', data=practice_stats_100tests, kind='reg')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6ac97c5-1b0c-4eb0-b224-ad65b1ab401e", + "metadata": {}, + "outputs": [], + "source": [ + "sb.jointplot(x='numberOfTests', y='meanFarmSize', data=practice_stats, kind='reg')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2359c23d-37dc-4d43-93d2-996618d761c5", + "metadata": {}, + "outputs": [], + "source": [ + "# Pearson R of accuracy by number of tests conducted\n", + "pearsonr(practice_stats.numberOfTests,practice_stats.accuracy)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec4aa206-06b9-4dc8-8eb9-c30b1c0d0599", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot accuracy by mean farm size\n", + "plt.rcParams.update({'font.size': 18})\n", + "sb.jointplot(x='meanFarmSize', y='accuracy', data=practice_stats_100tests, kind='reg')\n", + "plt.ylabel('Vet diagnostic accuracy')\n", + "plt.xlabel('Mean herd size tested')\n", + "plt.savefig('../Paper/figs/vet_acc_herd_size.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "993d01f3-5da7-4280-82c7-2af44a131163", + "metadata": {}, + "outputs": [], + "source": [ + "# Pearson R of accuracy by mean farm size\n", + "pearsonr(practice_stats.meanFarmSize,practice_stats.accuracy)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b765fab-dee7-4978-8836-0bbd6d5181d3", + "metadata": {}, + "outputs": [], + "source": [ + "# Ouptut practice stats\n", + "practice_stats.to_csv(\"/Data/TB_Diagnostics/vet_practice_stats.csv\")" + ] + }, + { + "cell_type": "markdown", + "id": "b0cf4f99-d963-4ff0-8b57-0f34e95d51b2", + "metadata": {}, + "source": [ + "# Post model analysis" + ] + }, + { + "cell_type": "markdown", + "id": "9a245613-3805-4c61-bde5-7e7906d7317a", + "metadata": {}, + "source": [ + "How does the model improve practice performance?\n", + "* Improvement in general of practice accuracy?\n", + "* Compare best/worst practices.\n", + "* Confusion matrices for best/worst." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "67a4b81f-f8a5-40e2-8474-2028c0467e63", + "metadata": {}, + "outputs": [], + "source": [ + "# Load model\n", + "model = load('/Data/TB_Diagnostics/final_model.model')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc6f2748-9a1d-4c10-87c4-f8728be9bdfe", + "metadata": {}, + "outputs": [], + "source": [ + "# Convert dates to float\n", + "data_X_model.dateOfTest = data_X_model.dateOfTest.astype(int).astype(float)\n", + "# Add Random features\n", + "data_X_model['rand'] = np.random.random_sample(len(data_X_model))\n", + "# Convery all to float matrix\n", + "#data_X = data_X.to_numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "572d9b40-afa4-48cb-8312-91616694d9e2", + "metadata": {}, + "outputs": [], + "source": [ + "# run model on all data\n", + "predict_y = model.predict(data_X_model.to_numpy())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "050dad3c-a121-452b-9991-997b1a4dee9f", + "metadata": {}, + "outputs": [], + "source": [ + "# Get model residuals\n", + "predict_residual = predict_y != data_y\n", + "inputVars['model_residual'] = predict_residual" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e00eb2c2-bfcc-4fc4-b7b1-7da92d268cf2", + "metadata": {}, + "outputs": [], + "source": [ + "# get model accuracy by practice\n", + "practice_stats['model_accuracy'] = 1-(inputVars.groupby(['vetPractice'])['model_residual'].sum() / inputVars.groupby(['vetPractice'])['model_residual'].count())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1cccae1e-d5d1-4896-97d5-ae72b57aa304", + "metadata": {}, + "outputs": [], + "source": [ + "# get increase in accuracy by practice\n", + "practice_stats['accuracy_increase'] = practice_stats.model_accuracy - practice_stats.accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "262ab942-f1d2-4da1-a077-751f57fca852", + "metadata": {}, + "outputs": [], + "source": [ + "# Drop the 'other' category, leaving only the top 250 practices by size\n", + "#practice_stats_top250 = practice_stats[practice_stats.index<250]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7168981c-3a70-427e-a916-356d78aadc3f", + "metadata": {}, + "outputs": [], + "source": [ + "# increase in accuracy by practice\n", + "plt.rcParams.update({'font.size': 18})\n", + "sb.histplot(practice_stats.accuracy_increase, kde=True)\n", + "plt.ylabel('Number of practices')\n", + "plt.xlabel('Accuracy increase with model')\n", + "plt.savefig('../Paper/figs/vet_improvement.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "677f7e05-3b4d-47c0-a203-9a3b1a5802fa", + "metadata": {}, + "outputs": [], + "source": [ + "# mean increase in accuracy\n", + "practice_stats.accuracy_increase.mean()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7d7f8e8-2a3c-49ab-bdfc-53da903c1866", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot model accuracy by size\n", + "plt.rcParams.update({'font.size': 18})\n", + "sb.jointplot(x='meanFarmSize', y='model_accuracy', data=practice_stats, kind='reg')\n", + "plt.ylabel('Model accuracy')\n", + "plt.xlabel('Mean herd size tested')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1e9b363e-7fba-41b8-a91e-e677acc27401", + "metadata": {}, + "outputs": [], + "source": [ + "# Pearson R of model accuracy by size\n", + "pearsonr(practice_stats.meanFarmSize,practice_stats.model_accuracy)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae8e7f53-1dcb-47ae-96dd-33074f44e88e", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot model accuracy increase by herd size\n", + "sb.jointplot(x='meanFarmSize', y='accuracy_increase', data=practice_stats, kind='reg')\n", + "plt.xlabel('Mean herd size tested')\n", + "plt.ylabel('Practice accuracy increase')\n", + "plt.savefig('../Paper/figs/vet_acc_inc.pdf', bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "4b9c18c6-bcd0-4b0b-9dd0-4b651595ba7b", + "metadata": {}, + "source": [ + "# Geospatial analysis\n", + "\n", + "Is the geo distribution of tests with vet data similar to all tests?\n", + "\n", + "Where do best/worst performing vets operate?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f171537-8c72-4c37-8ee2-e801fc260cc6", + "metadata": {}, + "outputs": [], + "source": [ + "#Projections:\n", + "bng = 'epsg:27700' # British National Grid\n", + "wgs84 = 'epsg:4326' # Lat.Long." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d922d030-c07d-4324-a7c5-c5196d0d3f6c", + "metadata": {}, + "outputs": [], + "source": [ + "#UK base map\n", + "uk_shp = gp.read_file('/Data/Shapefiles/bdline_essh_gb/Data/Supplementary_Country/country_region.shp').to_crs(bng)\n", + "#uk_shp.plot(color='white', edgecolor='black')\n", + "eng_shp = uk_shp[uk_shp.NAME=='England']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51c4e707-a9b4-4800-91e2-98acbd57cbc0", + "metadata": {}, + "outputs": [], + "source": [ + "# create geodataframe from data\n", + "geo_data = gp.GeoDataFrame(inputVars, geometry=gp.points_from_xy(inputVars.locationX,inputVars.locationY), crs=bng)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f0a36bb-35cd-44c0-9f2e-2b6d2d92fb23", + "metadata": {}, + "outputs": [], + "source": [ + "# join practice stats\n", + "geo_data = geo_data.join(practice_stats, on='vetPractice')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab1989b6-6202-4980-8175-17afecbbc812", + "metadata": {}, + "outputs": [], + "source": [ + "geo_data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6129c47c-0578-4420-9513-9f9faaa2c7cb", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot KDE of all tests\n", + "\n", + "ax = uk_shp.to_crs(bng).plot(alpha=0.2, figsize=(10,20))\n", + "sb.kdeplot(ax=ax, x=geo_data.locationX, y=geo_data.locationY, fill=True)#, color='gold')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d022173-909e-4392-a9cc-32e29dbf18bf", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot KDE of only tests with vet data\n", + "\n", + "ax = uk_shp.to_crs(bng).plot(alpha=0.2, figsize=(10,20))\n", + "sb.kdeplot(ax=ax, x=geo_data.dropna(subset=['vetPractice']).locationX, y=geo_data.dropna(subset=['vetPractice']).locationY, fill=True)" + ] + }, + { + "cell_type": "markdown", + "id": "43ac8da0-b0e2-4c89-9cfb-f878adf88bfd", + "metadata": {}, + "source": [ + "----\n", + "## Testing...\n", + "----" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8fecec04-3c73-4d1c-a83f-34e69704a613", + "metadata": {}, + "outputs": [], + "source": [ + "# Location of Vets who do the most tests\n", + "ax = eng_shp.plot(alpha=0.2, figsize=(10,20))\n", + "geo_data.plot(ax=ax, markersize=1.0, column='vetPractice', legend=True, legend_kwds={\"label\": \"Top 250 Vets by tests conducted (0=largest)\", \"orientation\": \"horizontal\"})\n", + "ax.set_axis_off()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62fbf2c9-75fc-456d-ac8d-cf0a4e96ecf3", + "metadata": {}, + "outputs": [], + "source": [ + "# find convex hull of points for each practice\n", + "ax = eng_shp.plot(color='white', edgecolor='black', figsize=(10,20))\n", + "vet_coverage = geo_data.dissolve('vetPractice').convex_hull.reset_index().iloc[:-1]\n", + "vet_coverage.plot(ax=ax,column='vetPractice',alpha=0.2)\n", + "ax.set_axis_off()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "878d1451-70b7-4373-89a5-86fd2eb6e04c", + "metadata": {}, + "outputs": [], + "source": [ + "# plot locations for vet accuracy\n", + "ax = eng_shp.plot(alpha=0.2, figsize=(10,20))\n", + "geo_data.plot(ax=ax, markersize=1.0, column='accuracy', legend=True, legend_kwds={\"label\": \"Vet performance\", \"orientation\": \"horizontal\"})\n", + "ax.set_axis_off()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4900d616-b899-4fba-bed1-2204d61e3514", + "metadata": {}, + "outputs": [], + "source": [ + "# convex hull of vet locations by accuracy\n", + "ax = eng_shp.plot(color='white', edgecolor='black', figsize=(10,20))\n", + "vet_coverage.join(practice_stats).plot(ax=ax, column='accuracy', alpha=0.2, legend=True, legend_kwds={\"label\": \"Vet accuracy\", \"orientation\": \"horizontal\"})\n", + "ax.set_axis_off()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a2a1057-dee4-4acc-8451-61762e681764", + "metadata": {}, + "outputs": [], + "source": [ + "# plot locations for vet mean herd size\n", + "ax = eng_shp.plot(alpha=0.2, figsize=(10,20))\n", + "geo_data.plot(ax=ax, markersize=1.0, column='meanFarmSize', cmap='viridis_r',legend=True, legend_kwds={\"label\": \"Mean herd size for vet\", \"orientation\": \"horizontal\"})\n", + "ax.set_axis_off()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/bTB-Diagnostic_2020_final_model_NEW.ipynb b/bTB-Diagnostic_2020_final_model_NEW.ipynb new file mode 100644 index 0000000..dfae46c --- /dev/null +++ b/bTB-Diagnostic_2020_final_model_NEW.ipynb @@ -0,0 +1,1812 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "edd1ac98-7c57-46b3-9ca0-9b4dfe74e8ef", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import geopandas as gp\n", + "import geoplot\n", + "import seaborn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "435fc234-bef8-4f8a-9d70-9615b2357a0e", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import time\n", + "import shapely\n", + "import rtree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fbc72ded-6afc-4705-ab8d-42c3256bf3bc", + "metadata": {}, + "outputs": [], + "source": [ + "import sklearn\n", + "from sklearn.ensemble import HistGradientBoostingClassifier as GBT\n", + "from sklearn.metrics import roc_curve, auc, roc_auc_score, make_scorer\n", + "from sklearn.inspection import permutation_importance\n", + "#from sklearn.utils.fixes import loguniform\n", + "from sklearn.model_selection import RandomizedSearchCV, train_test_split, cross_validate, GridSearchCV\n", + "from scipy.stats import randint,uniform,loguniform\n", + "from sklearn.inspection import PartialDependenceDisplay" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62536a34-b964-4bf5-bb81-175568b6b215", + "metadata": {}, + "outputs": [], + "source": [ + "from joblib import dump, load\n", + "from copy import deepcopy" + ] + }, + { + "cell_type": "markdown", + "id": "69503fc4-a026-416e-bec2-1942c887738e", + "metadata": {}, + "source": [ + "---\n", + "# Load original data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "adbef996-7012-43fa-9fcd-b00008af862f", + "metadata": {}, + "outputs": [], + "source": [ + "## Load data\n", + "data = pd.read_csv('/Data/TB_Diagnostics/inputVars.csv',parse_dates=['dateOfTest'],dtype=float)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bce7e6b4-7092-4a55-a1f1-9a9d0d7161e6", + "metadata": {}, + "outputs": [], + "source": [ + "# Get target feature (confirmed breakdowns) as binary class\n", + "data_y = data.confirmedBreakdown.to_numpy().astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4ea4432-4955-4cb6-ac7e-9350d8072da4", + "metadata": {}, + "outputs": [], + "source": [ + "# Get observed features\n", + "data_X = data.drop(columns=['confirmedBreakdown'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1e13433-6b57-4810-9100-7f717b5f74f5", + "metadata": {}, + "outputs": [], + "source": [ + "# Convert dates to float\n", + "data_X.dateOfTest = data_X.dateOfTest.astype(int).astype(float)\n", + "# Add Random features\n", + "data_X['rand'] = np.random.random_sample(len(data_X))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b326537d-15d7-40ac-9bc2-1026a6518ec9", + "metadata": {}, + "outputs": [], + "source": [ + "# Detect categorical features (<= 3 categories and explicit named features)\n", + "named_cat_features = ['vetPractice','batchBovine','batchAvian']\n", + "cat_features = []\n", + "for c in data_X.columns:\n", + " catf = len(data_X[c].unique())<=3\n", + " if c in named_cat_features:\n", + " catf = True\n", + " cat_features.append(catf)\n", + "\n", + "# NB: this is fine for boolean features (inc. missing values)\n", + "# but needs a proper encoding for true categorical features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f0099ab-5200-4dcc-b8cb-e298ee128e5e", + "metadata": {}, + "outputs": [], + "source": [ + "# Convery all to float matrix\n", + "#data_X = data_X.to_numpy()" + ] + }, + { + "cell_type": "markdown", + "id": "bcf17004-07e3-4dd7-844c-9019dd289561", + "metadata": {}, + "source": [ + "# Load training and testing sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1977ad48-146c-4df8-b1d5-8d04fa26fff9", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the train/test split used in training\n", + "X_train, X_test, y_train, y_test = load('/Data/TB_Diagnostics/final_data_split.data')" + ] + }, + { + "cell_type": "markdown", + "id": "bfd95433-6c64-4cde-8d38-ab4f349e78a7", + "metadata": {}, + "source": [ + "### Create new train/test splits with pre-/post-cull" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5cffd778-d3f0-43ad-bb78-7a38fb084afa", + "metadata": {}, + "outputs": [], + "source": [ + "#pre_cull = data.dateOfTest<\"2016\"\n", + "#data_X_pre = data_X[pre_cull].to_numpy()\n", + "#data_y_pre = data_y[pre_cull]\n", + "#X_train_pre, X_test_pre, y_train_pre, y_test_pre = train_test_split(data_X_pre, data_y_pre, test_size=0.20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3628601-ee7c-4a3c-abea-de4a3bdd0ca4", + "metadata": {}, + "outputs": [], + "source": [ + "#post_cull = data.dateOfTest>=\"2016\"\n", + "#data_X_post = data_X[post_cull].to_numpy()\n", + "#data_y_post = data_y[post_cull]\n", + "#X_train_post, X_test_post, y_train_post, y_test_post = train_test_split(data_X_post, data_y_post, test_size=0.20)" + ] + }, + { + "cell_type": "markdown", + "id": "55c3a959-c82e-438f-9e7c-8f1097ce8858", + "metadata": {}, + "source": [ + "# Model scoring functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44ec9fbc-7654-4fd3-b68b-fbee40cec379", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: sensitivity(prediction,target)\n", + "# returns sensitivity of prediction vs. target\n", + "# Se = TP / (TP + FN)\n", + "def sensitivity(p,t):\n", + " TP = (p&t).sum()\n", + " FN = (~p&t).sum()\n", + " return TP / (TP + FN)\n", + "\n", + "## Function: specificity(prediction,target)\n", + "# returns specificity of prediction vs. target\n", + "# Sp = TN / (TN + FP)\n", + "def specificity(p,t):\n", + " TN = (~p&~t).sum()\n", + " FP = (p&~t).sum()\n", + " return TN / (TN + FP)" + ] + }, + { + "cell_type": "markdown", + "id": "43e67834-1086-4273-82f4-aaa79426d31e", + "metadata": {}, + "source": [ + "### SICCT Test performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39fd9f81-f9f2-4bbd-8302-214c125f48af", + "metadata": {}, + "outputs": [], + "source": [ + "sicct = X_test[:,1].astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "533481d1-fa41-4cdc-8036-75d9381386d8", + "metadata": {}, + "outputs": [], + "source": [ + "## Sensitivity\n", + "Se_sicct = sensitivity(sicct,y_test)\n", + "Se_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a54891f-63ac-4403-8eaf-dc2d77370d2d", + "metadata": {}, + "outputs": [], + "source": [ + "## Specificity\n", + "Sp_sicct = specificity(sicct,y_test)\n", + "Sp_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f2e44fc6-f917-4adf-953f-0418da6a6e91", + "metadata": {}, + "outputs": [], + "source": [ + "## Accuracy\n", + "(sicct==y_test).sum() / len(y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61386875-f746-4271-94c8-b37dcd50e896", + "metadata": {}, + "outputs": [], + "source": [ + "# Set specificity threshold to level for SICCT-only prediction\n", + "specificity_threshold = Sp_sicct" + ] + }, + { + "cell_type": "markdown", + "id": "7974a2c0-1ea5-4076-b550-7a99fd98a616", + "metadata": {}, + "source": [ + "### Pre/post-cull SICCT performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b5eb78f-091b-4f8f-bd39-81f9f8f380dd", + "metadata": {}, + "outputs": [], + "source": [ + "#sicct_pre = X_test_pre[:,1].astype(bool)\n", + "#sicct_post = X_test_post[:,1].astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bccc23e0-eb6e-4516-8beb-4ef558de6617", + "metadata": {}, + "outputs": [], + "source": [ + "## Sensitivity\n", + "#Se_sicct_pre = sensitivity(sicct_pre,y_test_pre)\n", + "#Se_sicct_post = sensitivity(sicct_post,y_test_post)\n", + "#(Se_sicct_pre,Se_sicct_post)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "656f0151-689e-4a5c-b5e4-38659c9d79f2", + "metadata": {}, + "outputs": [], + "source": [ + "## Specificity\n", + "#Sp_sicct_pre = specificity(sicct_pre,y_test_pre)\n", + "#Sp_sicct_post = specificity(sicct_post,y_test_post)\n", + "#(Sp_sicct_pre,Sp_sicct_post)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a69c72b7-82d2-4874-935f-1400fce64686", + "metadata": {}, + "outputs": [], + "source": [ + "## Accuracy\n", + "#((sicct_pre==y_test_pre).sum() / len(y_test_pre) , (sicct_post==y_test_post).sum() / len(y_test_post))" + ] + }, + { + "cell_type": "markdown", + "id": "6fd52b52-caf6-42bf-9952-855c46d6347b", + "metadata": { + "tags": [] + }, + "source": [ + "# Load model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f02aac16-bec1-45f6-8f7c-bc221562f5f9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Load cross-validated and fit model\n", + "model = load('/Data/TB_Diagnostics/final_model.model')" + ] + }, + { + "cell_type": "markdown", + "id": "4af4ec63-571c-4355-99e5-d45468aeedd8", + "metadata": {}, + "source": [ + "## Pre-/post-cull models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "384eec7c-60bd-42d7-89ff-5e34690d6a0f", + "metadata": {}, + "outputs": [], + "source": [ + "# get best parameter set from full model\n", + "#gbt_pre = deepcopy(model.best_estimator_)\n", + "#gbt_post = deepcopy(model.best_estimator_)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1e2cb37-e4cb-4ba4-bf8b-aa42fcad00af", + "metadata": {}, + "outputs": [], + "source": [ + "# train subset models\n", + "#model_pre = gbt_pre.fit(X_train_pre,y_train_pre)\n", + "#model_post = gbt_post.fit(X_train_post,y_train_post)" + ] + }, + { + "cell_type": "markdown", + "id": "18f01840-d148-412d-b355-0a5295c84212", + "metadata": { + "tags": [] + }, + "source": [ + "# Evaluate performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9dcc7f3-e102-4dfb-b7af-343568f8121b", + "metadata": {}, + "outputs": [], + "source": [ + "## Model score on testing set: (score is metric set at training time)\n", + "model.score(X_test,y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90128953-7914-41f5-96be-78dc241971fa", + "metadata": {}, + "outputs": [], + "source": [ + "## Get test predictions for more detailed evaluation:\n", + "y_test_result = model.predict(X_test)\n", + "y_score = model.decision_function(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a63a9b46-26b7-4510-9b82-096e51c2216c", + "metadata": {}, + "outputs": [], + "source": [ + "## Sensitivity\n", + "Se = sensitivity(y_test_result,y_test)\n", + "Se" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edd92f58-ebc8-4364-9436-d56d8f1e0365", + "metadata": {}, + "outputs": [], + "source": [ + "## Specificity\n", + "Sp = specificity(y_test_result,y_test)\n", + "Sp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ac84b5d-0ac3-4eea-9856-1e91beef126c", + "metadata": {}, + "outputs": [], + "source": [ + "## Accuracy\n", + "(y_test_result==y_test).sum() / len(y_test)" + ] + }, + { + "cell_type": "markdown", + "id": "7de21db1-c670-47b0-9115-fbc2f7a99a7f", + "metadata": {}, + "source": [ + "### Pre-/post-cull performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5338e23e-55c7-44db-ba6b-440573297474", + "metadata": {}, + "outputs": [], + "source": [ + "#model_pre.score(X_test_pre,y_test_pre)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b6a4e87-fa27-4121-8441-99a577186000", + "metadata": {}, + "outputs": [], + "source": [ + "#model_post.score(X_test_post,y_test_post)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ac67ac6-5234-4c07-b9f1-b645cc9fb146", + "metadata": {}, + "outputs": [], + "source": [ + "## Get test predictions for more detailed evaluation:\n", + "#y_test_result_pre = model_pre.predict(X_test_pre)\n", + "#y_score_pre = model_pre.decision_function(X_test_pre)\n", + "#y_test_result_post = model_post.predict(X_test_post)\n", + "#y_score_post = model_post.decision_function(X_test_post)" + ] + }, + { + "cell_type": "markdown", + "id": "a788b3e4-cbf7-4fac-99ce-68273d9aabed", + "metadata": {}, + "source": [ + "---\n", + "# ROC Curves" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41ec298c-991d-459c-a6c7-668fe091f635", + "metadata": {}, + "outputs": [], + "source": [ + "fpr, tpr, _ = roc_curve(y_test,y_score)\n", + "roc_auc = auc(fpr,tpr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95328ec9-6450-4246-b7f1-178efcffb92a", + "metadata": {}, + "outputs": [], + "source": [ + "roc_auc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a3ee1fc-8abb-4b7d-8df9-d0fa73109374", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#function ot plot roc curve\n", + "def plot_roc(fpr,tpr,roc_auc):\n", + " plt.figure()\n", + " lw = 2\n", + " plt.plot(\n", + " fpr,\n", + " tpr,\n", + " lw=lw,\n", + " label=\"Model (AUC = %0.2f)\" % roc_auc,\n", + " )\n", + " plt.plot(1-Sp_sicct,Se_sicct,'+', label=\"SICCT only\", ms='15')\n", + " plt.plot([0, 1], [0, 1], lw=lw, linestyle=\"--\", label='Random')\n", + " plt.xlim([0.0, 1.0])\n", + " plt.ylim([0.0, 1.0])\n", + " plt.xlabel(\"(1 - Specificity)\")\n", + " plt.ylabel(\"Sensitivity\")\n", + " plt.title(\"Receiver operating characteristic\")\n", + " plt.legend(loc=\"lower right\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7f5f3bb-e7a3-4987-a381-6824560d836a", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({'font.size': 16})\n", + "plot_roc(fpr,tpr,roc_auc)\n", + "plt.savefig('../Paper/figs/roc.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "4b3059d0-a776-46dc-9c22-3b2caf9d2d49", + "metadata": {}, + "source": [ + "### Pre-post-cull ROC" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb5349e3-504e-47c8-848c-c9c4ea23edb5", + "metadata": {}, + "outputs": [], + "source": [ + "#fpr_pre, tpr_pre, _ = roc_curve(y_test_pre,y_score_pre)\n", + "#roc_auc_pre = auc(fpr_pre,tpr_pre)\n", + "#fpr_post, tpr_post, _ = roc_curve(y_test_post,y_score_post)\n", + "#roc_auc_post = auc(fpr_post,tpr_post)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d775404b-e526-4f81-bc75-8a3577870040", + "metadata": {}, + "outputs": [], + "source": [ + "#plot_roc(fpr_pre,tpr_pre,roc_auc_pre)\n", + "#plot_roc(fpr_post,tpr_post,roc_auc_post)" + ] + }, + { + "cell_type": "markdown", + "id": "3a0f6f31-b460-4b73-89b0-90edb0302818", + "metadata": {}, + "source": [ + "---\n", + "# Feature importance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c06b2bcc-5da2-46b3-8a91-9e130fd17a49", + "metadata": {}, + "outputs": [], + "source": [ + "## Calcuate permutation importance\n", + "importance = permutation_importance(model,X_test,y_test, n_repeats=20, n_jobs=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "80121727-2789-476b-9456-0584686afbfe", + "metadata": {}, + "outputs": [], + "source": [ + "for i in importance.importances_mean.argsort()[::-1]:\n", + " if abs(importance.importances_mean[i]) - 2 * importance.importances_std[i] > 0:\n", + " print('*',f\"{importance.importances_mean[i]:.5f}\", f\" +/- {importance.importances_std[i]:.5f}\", data_X.columns[i])\n", + " else:\n", + " print(' ',f\"{importance.importances_mean[i]:.5f}\", f\" +/- {importance.importances_std[i]:.5f}\", data_X.columns[i])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dacd01f1-a312-4f1c-a33a-3433e54e4e4b", + "metadata": {}, + "outputs": [], + "source": [ + "#transform into table\n", + "importance_table = pd.DataFrame(importance.importances.T, columns=data_X.columns)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d46e9901-1d69-4041-bfa2-0e0710e11ceb", + "metadata": {}, + "outputs": [], + "source": [ + "# FIX: drop 'species' (it is nonsense...)\n", + "#importance_table = importance_table.drop(columns=['species'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00152d7d-c5c6-4213-8563-e9519f66a059", + "metadata": {}, + "outputs": [], + "source": [ + "mean_order = list(importance_table.mean().sort_values(ascending=False).index)\n", + "mean_order_nozero = list(importance_table.mean()[importance_table.mean()>0].sort_values(ascending=False).index)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1606583-aef9-455a-9b5e-2391ed666894", + "metadata": {}, + "outputs": [], + "source": [ + "#Define feature lables:\n", + "feature_labels = {'resultOfTest':'Herd-level SICCT result',\n", + " 'locationY':'Holding location Easting',\n", + " 'locationX':'Holding location Northing',\n", + " 'daysSinceBreakdown':'Days since herd breakdown *',\n", + " 'dateOfTest':'Date of herd SICCT testing',\n", + " 'animalsTested':'Number of animals tested',\n", + " 'daysSincePreviousTest':'Time since previous SICCT test in herd *',\n", + " 'severe':'Was the severe interpretation applied?',\n", + " 'previousTestResult':'Result of last previous SICCT test in herd *',\n", + " 'previousTestResult2':'Result of 2nd last previous SICCT test in herd *',\n", + " 'gammaTestCount':'Number of historical GammaIFN test events in herd',\n", + " 'rand':'Set of uniformly distributed random numbers (CONTROL)',\n", + " 'inflow1':'Animals moved into herd, 1 year',\n", + " 'inflow2':'Animals moved into herd, 2 years',\n", + " 'inflow4':'Animals moved into herd, 4 years',\n", + " 'inflow90':'Animals moved into herd, 90 days',\n", + " 'outflow1':'Animals moved out herd, 1 year',\n", + " 'outflow2':'Animals moved out herd, 2 years',\n", + " 'outflow4':'Animals moved out herd, 4 years',\n", + " 'outflow90':'Animals moved out herd, 90 days',\n", + " 'inflowBD1':'Animals moved into herd, 1 year, from recent breakdown herds',\n", + " 'inflowBD2':'Animals moved into herd, 2 years, from recent breakdown herds',\n", + " 'inflowBD4':'Animals moved into herd, 4 years, from recent breakdown herds',\n", + " 'inflowBD90':'Animals moved into herd, 90 days, from recent breakdown herds',\n", + " 'outflowBD1':'Animals moved out herd, 1 year, from recent breakdown herds',\n", + " 'outflowBD2':'Animals moved out herd, 2 years, from recent breakdown herds',\n", + " 'outflowBD4':'Animals moved out herd, 4 years, from recent breakdown herds',\n", + " 'outflowBD90':'Animals moved out herd, 90 days, from recent breakdown herds',\n", + " 'vetPractice':'Veterinary practice conducting the test **',\n", + " 'batchBovine':'Tuberculin batch (bovine) **',\n", + " 'batchAvian':'Tuberculin batch (avian) **',\n", + " 'testType':'Type of testing event',\n", + " 'herdSize':'Size of herd at time of test',\n", + " 'herdType':'Herd type (dairy, beef, etc.)',\n", + " 'monthOfTest':'Month in which test taken',\n", + " 'defraRiskScore':'APHA risk score for herd',\n", + " 'meanBadgerAbundance':'Mean badger abundance'}\n", + "def feature_label(x):\n", + " try:\n", + " return feature_labels[x]\n", + " except KeyError:\n", + " return x " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "386786c4-ccbc-4695-988c-5088d667af5d", + "metadata": {}, + "outputs": [], + "source": [ + "# apply labels to mean ordered set\n", + "mean_order_labels = list(map(lambda x:feature_label(x), mean_order))\n", + "mean_order_nozero_labels = list(map(lambda x:feature_label(x), mean_order_nozero))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1c6439f-d5d9-46c5-ba1e-cfe09a0d6585", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({'font.size': 28})\n", + "fig, ax = plt.subplots(figsize=(16,20))\n", + "seaborn.barplot(importance_table[mean_order], orient='h', errorbar='ci', ax=ax)\n", + "ax.set_yticklabels(mean_order_labels)\n", + "plt.title('Relative importance of model features')\n", + "#plt.xscale('log')\n", + "plt.savefig('../Paper/figs/importance.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "58ae32dc-8e56-42c2-bf0b-4f89da35becf", + "metadata": {}, + "outputs": [], + "source": [ + "importance_table.mean()['vetPractice']*100" + ] + }, + { + "cell_type": "markdown", + "id": "292a0077-5639-4185-8132-a1623a0c3f5e", + "metadata": {}, + "source": [ + "### How much missing data?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9e029b7-8279-4fda-a4ac-a9bc1b28c08f", + "metadata": {}, + "outputs": [], + "source": [ + "# proportion of non-missing data in each feature:\n", + "non_miss = data_X.isna().sum() / len(data_X)\n", + "fig, ax = plt.subplots(figsize=(16,20))\n", + "seaborn.barplot(non_miss[mean_order], orient='h', ax=ax)\n", + "plt.xlim(0.0,1.0)\n", + "plt.title(\"Proportion of missing data\")\n", + "ax.bar_label(ax.containers[0])\n", + ";" + ] + }, + { + "cell_type": "markdown", + "id": "a3126928-d594-4bdc-9126-d5f98eb71ff1", + "metadata": {}, + "source": [ + "### Pre-/post-cull feature importance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9378a06c-8588-493a-9831-c057eaee97f0", + "metadata": {}, + "outputs": [], + "source": [ + "## Calcuate permutation importance\n", + "#importance_pre = permutation_importance(model_pre,X_test_pre,y_test_pre, n_repeats=10, n_jobs=-1)\n", + "#importance_post = permutation_importance(model_post,X_test_post,y_test_post, n_repeats=10, n_jobs=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37cd018f-2627-4eb3-8448-ead0c8836a2e", + "metadata": {}, + "outputs": [], + "source": [ + "# Pre-cull plot\n", + "#importance_table_pre = pd.DataFrame(importance_pre.importances.T, columns=data_X.columns)\n", + "#mean_order_pre = list(importance_table_pre.mean().sort_values(ascending=False).index)\n", + "#seaborn.barplot(importance_table_pre[mean_order_pre], orient='h')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3436526b-f1d0-40a9-bfa7-934e5e3c2d1f", + "metadata": {}, + "outputs": [], + "source": [ + "# Post-cull plot\n", + "#importance_table_post = pd.DataFrame(importance_post.importances.T, columns=data_X.columns)\n", + "#mean_order_post = list(importance_table_post.mean().sort_values(ascending=False).index)\n", + "#seaborn.barplot(importance_table_post[mean_order_post], orient='h')" + ] + }, + { + "cell_type": "markdown", + "id": "21d4c471-2f52-40e3-8922-7c123713c442", + "metadata": {}, + "source": [ + "---\n", + "# Decision threshold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1f14688-94e3-4290-b8d4-01bc49b52619", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# function to apply decision threshold\n", + "def predict_with_threshold(X, model, decision_threshold):\n", + " return model.predict_proba(X)[:,1]>=decision_threshold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55c1c89c-47ec-4ca1-a55c-10757085f87c", + "metadata": {}, + "outputs": [], + "source": [ + "# try different thresholds\n", + "thresholds = np.linspace(0.0,1.0,101)\n", + "sens = np.zeros(len(thresholds)) #sensitivity at threshold\n", + "spec = np.zeros(len(thresholds)) #specificity at threshold\n", + "for x in range(len(thresholds)):\n", + " y_th = predict_with_threshold(X_test,model,thresholds[x])\n", + " sens[x] = sensitivity(y_th,y_test)\n", + " spec[x] = specificity(y_th,y_test)\n", + "\n", + "best_sens = max(sens[spec >= Sp_sicct]) #sensitivity s.t. specificity >= SICCT\n", + "best_thresh = min(thresholds[spec >= Sp_sicct]) #threshold with max sensitivity s.t. specificity >= SICCT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0391e6cb-5d30-42da-864a-64a39ab56fe7", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({'font.size': 16})\n", + "# function to plot thresholds\n", + "plt.plot(thresholds,sens,label='Model HSe')\n", + "plt.plot(thresholds,spec,label='Model HSp')\n", + "plt.xlim(0.2,0.8)\n", + "plt.ylim(0.5,1.0)\n", + "best_sens_label = 'Chosen HSe = '+str(round(best_sens*100,1))+'%'\n", + "sicct_sens_label = 'SICCT HSe = '+str(round(Se_sicct*100,1))+'%'\n", + "sicct_spec_label = 'SICCT HSp = '+str(round(Sp_sicct*100,1))+'%'\n", + "best_thresh_label = 'Chosen threshold = '+str(round(best_thresh,3))\n", + "plt.axvline(best_thresh,c='k',ls='-.',label=best_thresh_label)\n", + "plt.axhline(best_sens,c='k',ls='--',label=best_sens_label)\n", + "plt.axhline(Se_sicct,c='tab:blue',ls=':',label=sicct_sens_label)\n", + "plt.axhline(Sp_sicct,c='tab:orange',ls=':',label=sicct_spec_label)\n", + "plt.xlabel('Decision Threshold')\n", + "plt.legend(bbox_to_anchor=(1, 0.5))\n", + "plt.savefig('../Paper/figs/decision_threshold.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6f5e010-d528-4d1e-89b5-4f0db5ddaca3", + "metadata": {}, + "outputs": [], + "source": [ + "# Percentage increase in sensitivity over SICCT alone\n", + "(best_sens - Se_sicct)/Se_sicct*100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb9e9c5b-efc1-4007-9670-d3bee2769b98", + "metadata": {}, + "outputs": [], + "source": [ + "# Percentage point increase in sensitivity over SICCT alone\n", + "(best_sens - Se_sicct) *100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8207d271-4edc-400d-875e-9c69d604b6b6", + "metadata": {}, + "outputs": [], + "source": [ + "#Predictions from model at threshold\n", + "y_test_predicted = predict_with_threshold(X_test, model, best_thresh)" + ] + }, + { + "cell_type": "markdown", + "id": "034b0bb2-d40e-49c1-a05f-9b0ee7f03533", + "metadata": {}, + "source": [ + "### Test on 2020 only data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29e8fcbc-416b-4034-b09d-7ab1b4604fcb", + "metadata": {}, + "outputs": [], + "source": [ + "# Get data for 2020 only\n", + "mask_2020 = data.dateOfTest.apply(lambda x:x.year)==2020\n", + "X_2020 = data_X[mask_2020].to_numpy()\n", + "y_2020 = data_y[mask_2020]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a63f7f7b-9bfb-4059-a937-3a687301389c", + "metadata": {}, + "outputs": [], + "source": [ + "#Predictions from model at threshold for 2020 only\n", + "y_2020_predicted = predict_with_threshold(X_2020, model, best_thresh)" + ] + }, + { + "cell_type": "markdown", + "id": "976cd604-608e-4b4b-97b4-b7c75d5f2002", + "metadata": {}, + "source": [ + "### Check predictions (for 2020)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "792290ee-e2e9-439a-af4c-d3691ddaa1a9", + "metadata": {}, + "outputs": [], + "source": [ + "# negative sicct tests\n", + "neg_sicct = X_2020[:,1]==0\n", + "neg_sicct.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "03084d9e-7b6f-4c84-b7e1-3dbe2eebaa7f", + "metadata": {}, + "outputs": [], + "source": [ + "# confimed breakdowns\n", + "confirmed_bd = y_2020\n", + "confirmed_bd.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "09e608f0-7945-4fb2-a225-4fc6f97f1a2e", + "metadata": {}, + "outputs": [], + "source": [ + "# positive predictions\n", + "pos_predict = y_2020_predicted\n", + "pos_predict.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f1224de-4412-4fbc-9f71-7e9b68e4dc4c", + "metadata": {}, + "outputs": [], + "source": [ + "# early detections are negative siccts that led to confiremd BD that the model predicted positive\n", + "early_detects = neg_sicct & confirmed_bd & pos_predict\n", + "early_detects.sum()" + ] + }, + { + "cell_type": "markdown", + "id": "7e1a8a42-6c35-45a3-9e4c-de7aec974af0", + "metadata": {}, + "source": [ + "### Test with HSp maximised" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "502c5dc0-8ccd-422b-a662-55e50d280d3d", + "metadata": {}, + "outputs": [], + "source": [ + "# What if we maximise specificty instead?\n", + "\n", + "sens_thresh = max(thresholds[sens>= Se_sicct]) # threshold with max specificity s.t. sensitivity >= SICCT\n", + "best_spec = max(spec[sens >= Se_sicct]) # specificty s.t. sensitivity >= SICCT\n", + "\n", + "sens_thresh , best_spec" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76a014bd-6b3e-4c97-bab5-be3071c37204", + "metadata": {}, + "outputs": [], + "source": [ + "#Predictions from model at HSp threshold\n", + "y_test_predicted_hsp = predict_with_threshold(X_test, model, sens_thresh)" + ] + }, + { + "cell_type": "markdown", + "id": "38850be3-e638-4ef0-abf8-6d75e70751ac", + "metadata": {}, + "source": [ + "---\n", + "# Plots" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "684f1b2d-197b-4920-8624-881e42490323", + "metadata": {}, + "outputs": [], + "source": [ + "#Projections:\n", + "bng = 'epsg:27700' # British National Grid\n", + "wgs84 = 'epsg:4326' # Lat.Long." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20681a9a-e764-4018-9bd8-0ea3c503143f", + "metadata": {}, + "outputs": [], + "source": [ + "#UK base map\n", + "uk_shp = gp.read_file('/Data/Shapefiles/bdline_essh_gb/Data/Supplementary_Country/country_region.shp').to_crs(wgs84)\n", + "#uk_shp.plot(color='white', edgecolor='black')" + ] + }, + { + "cell_type": "markdown", + "id": "36f72d9c-b845-4295-bed9-24da037d64d5", + "metadata": {}, + "source": [ + "## Plot residuals" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06f28e8a-027b-4887-a663-d6af3a56b450", + "metadata": {}, + "outputs": [], + "source": [ + "residual = (y_test != y_test_predicted)\n", + "\n", + "test_locs = pd.DataFrame({'locationX':X_test[:,5],'locationY':X_test[:,6],'date':X_test[:,0]})\n", + "test_locs.date = test_locs.date.astype('datetime64[ns]')\n", + "\n", + "test_locs['residual'] = residual\n", + "\n", + "test_geo = gp.GeoDataFrame(test_locs,geometry=gp.points_from_xy(test_locs.locationX,test_locs.locationY,crs=bng))\n", + "test_geo = test_geo.to_crs(wgs84)\n", + "\n", + "errors = test_geo[test_geo.residual]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8a4105b-c64c-4f15-a480-8e7b9e5b8f41", + "metadata": {}, + "outputs": [], + "source": [ + "# residuals for 2020 only\n", + "residual_2020 = (y_2020 != y_2020_predicted)\n", + "\n", + "locs_2020 = pd.DataFrame({'locationX':X_2020[:,5],'locationY':X_2020[:,6],'date':X_2020[:,0]})\n", + "locs_2020.date = locs_2020.date.astype('datetime64[ns]')\n", + "\n", + "locs_2020['residual'] = residual_2020\n", + "\n", + "geo_2020 = gp.GeoDataFrame(locs_2020,geometry=gp.points_from_xy(locs_2020.locationX,locs_2020.locationY,crs=bng))\n", + "geo_2020 = geo_2020.to_crs(wgs84)\n", + "\n", + "errors_2020 = geo_2020[geo_2020.residual]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a69a1d1-93c9-46a3-a1cc-d61d2dcb638c", + "metadata": {}, + "outputs": [], + "source": [ + "# Location of observations vs. residuals in Test set\n", + "ax = uk_shp.plot(alpha=0.2, figsize=(10,20))\n", + "#test_geo.plot('residual', markersize=1.0, ax=ax)\n", + "test_geo.plot(markersize=1.0, ax=ax)\n", + "errors.plot(markersize=1.0, color='red', ax=ax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c6fed6e-e265-444c-a8ff-c2221b0957dd", + "metadata": {}, + "outputs": [], + "source": [ + "# Location of observations vs. residuals in 2020 set\n", + "ax = uk_shp.plot(alpha=0.2, figsize=(10,20))\n", + "geo_2020.plot(markersize=1.0, ax=ax)\n", + "errors_2020.plot(markersize=1.0, color='red', ax=ax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62bf025c-f5b9-4bc5-8d06-9df48399078a", + "metadata": {}, + "outputs": [], + "source": [ + "# Residuals in time\n", + "test_times = pd.DataFrame({'date':X_test[:,0].copy().astype('datetime64[ns]')})\n", + "test_times['residual'] = residual\n", + "error_times = test_times[test_times.residual]\n", + "e = error_times.groupby(error_times[\"date\"].dt.year).count().date\n", + "t = test_times.groupby(test_times[\"date\"].dt.year).count().date\n", + "(e/t).plot.bar()\n", + "plt.title(\"Proportion of model misclassifications by year\")\n", + "plt.xlabel('Year')\n", + "plt.savefig('../Paper/figs/temporal.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "440f57da-101a-4d2d-8120-3d7f31841401", + "metadata": {}, + "source": [ + "## Plot newly discovered positives" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "daf3f8f6-dd51-4767-99d9-1b17e0ec5375", + "metadata": {}, + "outputs": [], + "source": [ + "new_detected = (X_test[:,1]==0) & y_test & y_test_predicted\n", + "test_geo['new_detected'] = new_detected\n", + "new_detected.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0044411-c380-4245-9fca-5710095bb9a5", + "metadata": {}, + "outputs": [], + "source": [ + "# new detected in 2020\n", + "new_detected_2020 = (X_2020[:,1]==0) & y_2020 & y_2020_predicted\n", + "geo_2020['new_detected'] = new_detected_2020\n", + "new_detected_2020.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a811c36-47d2-4332-8892-acda30a5e0f5", + "metadata": {}, + "outputs": [], + "source": [ + "# Location of observations vs. newly detected herds in Test set\n", + "ax = uk_shp.plot(alpha=0.2, figsize=(10,20))\n", + "test_geo.plot(markersize=1.0, ax=ax)\n", + "test_geo[test_geo.new_detected].plot(markersize=1.0, color='gold', ax=ax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b83e5b01-9076-4bca-b7b6-d148da851f76", + "metadata": {}, + "outputs": [], + "source": [ + "ax = uk_shp.to_crs(bng).plot(alpha=0.2, figsize=(10,20))\n", + "seaborn.kdeplot(ax=ax, x=test_geo[test_geo.new_detected].locationX, y=test_geo[test_geo.new_detected].locationY, fill=True, color='gold')" + ] + }, + { + "cell_type": "markdown", + "id": "3cb5f9c0-6520-4322-9ffb-c574896fcd03", + "metadata": {}, + "source": [ + "## Plot Residuals and Newly Detected density by area" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d20700e-71e5-4093-8739-b9c41e119dd7", + "metadata": {}, + "outputs": [], + "source": [ + "## Create grid for normalising misclassifcation\n", + "# total area for the grid\n", + "xmin, ymin, xmax, ymax = uk_shp.to_crs(bng).total_bounds\n", + "#size of cell\n", + "cell_size = 10000 #10km x 10km squares\n", + "# create the cells in a loop\n", + "grid_cells = []\n", + "for x0 in np.arange(xmin, xmax+cell_size, cell_size ):\n", + " for y0 in np.arange(ymin, ymax+cell_size, cell_size):\n", + " # bounds\n", + " x1 = x0-cell_size\n", + " y1 = y0+cell_size\n", + " grid_cells.append(shapely.geometry.box(x0, y0, x1, y1))\n", + "grid = gp.GeoDataFrame(grid_cells, columns=['geometry'], crs=bng)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd76d332-b6f3-4546-bb96-ecad028ed72b", + "metadata": {}, + "outputs": [], + "source": [ + "# For normalisation of map cells by number of tests in cell:\n", + "# number of tests in test set\n", + "grid_n_tests = gp.sjoin(test_geo.to_crs(bng), grid, how='left', predicate='within')\n", + "grid_n_tests['n_tests'] = 1\n", + "grid_n_tests_d = grid_n_tests.dissolve(by='index_right', aggfunc='count')\n", + "grid.loc[grid_n_tests_d.index, 'n_tests'] = grid_n_tests_d.n_tests.values\n", + "\n", + "#number of tests in 2002\n", + "grid_n_tests_2020 = gp.sjoin(geo_2020.to_crs(bng), grid, how='left', predicate='within')\n", + "grid_n_tests_2020['n_tests20'] = 1\n", + "grid_n_tests_2020_d = grid_n_tests_2020.dissolve(by='index_right', aggfunc='count')\n", + "grid.loc[grid_n_tests_2020_d.index, 'n_tests20'] = grid_n_tests_2020_d.n_tests.values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d57d06ba-c729-4894-b625-00068334487e", + "metadata": {}, + "outputs": [], + "source": [ + "grid.plot(column='n_tests')\n", + "grid.plot(column='n_tests20')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "265fd8d7-e75f-4bc6-8b9f-386010dbba60", + "metadata": {}, + "outputs": [], + "source": [ + "## Spatial join residuals with grid\n", + "## and plot to produce heatmap\n", + "errors_grid = gp.sjoin(errors.to_crs(bng), grid, how='left', predicate='within')\n", + "\n", + "# Compute residuals per grid cell\n", + "errors_grid['n_resid'] = 1\n", + "#errors_grid_d = errors_grid.dissolve(by=\"index_right\", aggfunc=\"count\")\n", + "errors_grid_n_resid = errors_grid[['index_right','n_resid']].groupby(by=\"index_right\").count()\n", + "\n", + "# Add to grid\n", + "grid.loc[errors_grid_n_resid.index, 'n_resid'] = errors_grid_n_resid\n", + "\n", + "# add normalised to grid\n", + "grid.loc[errors_grid_n_resid.index, 'norm_resid'] = grid.loc[errors_grid_n_resid.index, 'n_resid'] / grid.loc[errors_grid_n_resid.index, 'n_tests']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9cc036b-1972-4b35-8a42-c7312a0b65ef", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot errors grid\n", + "ax = grid.plot(column='n_resid', cmap='YlOrRd', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Misclassified tests by area (in test set)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_misclassified.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfa1fa72-f62b-446d-bd70-daf7ba15f364", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot errors grid (normalised)\n", + "ax = grid.plot(column='norm_resid', cmap='YlOrRd', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Proportion of tests misclassified by area (in test set)\")\n", + "plt.axis('off')\n", + "plt.savefig('../Paper/figs/map_misclassified.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06873036-d6d9-4b70-89cb-f404a4374f62", + "metadata": {}, + "outputs": [], + "source": [ + "## compute residuals grid for 2020 only\n", + "errors20_grid = gp.sjoin(errors_2020.to_crs(bng), grid, how='left', predicate='within')\n", + "\n", + "# Compute residuals per grid cell\n", + "errors20_grid['n_resid20'] = 1\n", + "errors20_grid_n_resid = errors20_grid[['index_right','n_resid20']].groupby(by=\"index_right\").count()\n", + "\n", + "# Add to grid\n", + "grid.loc[errors20_grid_n_resid.index, 'n_resid20'] = errors20_grid_n_resid\n", + "\n", + "# add normalised to grid\n", + "grid.loc[errors20_grid_n_resid.index, 'norm_resid20'] = grid.loc[errors20_grid_n_resid.index, 'n_resid20'] / grid.loc[errors20_grid_n_resid.index, 'n_tests20']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "542b0514-66d3-4af5-a679-046954217c3b", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot errors grid for 2020 only\n", + "ax = grid.plot(column='n_resid20', cmap='YlOrRd', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Misclassified tests by area (in 2020)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_misclassified_2020.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e56dc477-b066-4b04-b9e7-4be58b6f871d", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot normalised errors grid for 2020 only\n", + "ax = grid.plot(column='norm_resid20', cmap='YlOrRd', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Proportion of tests misclassified by area (in 2020)\")\n", + "plt.axis('off')\n", + "plt.savefig('../Paper/figs/map_misclassified_2020.pdf',bbox_inches='tight')\n", + "plt.savefig('../Paper/figs/map_misclassified_2020.png',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cc2daac-8175-408c-a9c3-1b814c0a539e", + "metadata": {}, + "outputs": [], + "source": [ + "## Spatial join newly detected with grid\n", + "## and plot to produce heatmap\n", + "new_detect_grid = gp.sjoin(test_geo[test_geo.new_detected].to_crs(bng), grid, how='left', predicate='within')\n", + "\n", + "# Compute new detects per grid cell -- aggregate with dissolve\n", + "new_detect_grid['n_new'] = 1\n", + "new_detect_grid_n_new = new_detect_grid[['index_right','n_new']].groupby(by=\"index_right\").count()\n", + "\n", + "# Add to grid\n", + "grid.loc[new_detect_grid_n_new.index, 'n_new'] = new_detect_grid_n_new\n", + "\n", + "# add normalised to grid\n", + "grid.loc[new_detect_grid_n_new.index, 'norm_new'] = grid.loc[new_detect_grid_n_new.index, 'n_new'] / grid.loc[new_detect_grid_n_new.index, 'n_tests']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6221042-5b7d-436d-9e39-ad50a5d6e031", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot new grid\n", + "ax = grid.plot(column='n_new', cmap='cividis', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Early detected tests by area (in test set)\")\n", + "plt.axis('off')\n", + "plt.savefig('../Paper/figs/map_newly_detected.pdf',bbox_inches='tight')\n", + "plt.savefig('../Paper/figs/map_newly_detected.png',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dcd7a9c2-c819-419e-93e9-404ab49472b3", + "metadata": {}, + "outputs": [], + "source": [ + "grid.n_new.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4c3a5e4-937b-4abd-bf4e-a27c90ba3a46", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot new grid, normalised\n", + "ax = grid.plot(column='norm_new', cmap='cividis', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Proportion of tests early detected by area (in test set)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_newly_detected.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "567de8db-925f-41c3-83f2-7c29208b90e2", + "metadata": {}, + "outputs": [], + "source": [ + "## Spatial join newly detected with grid in 2020 only\n", + "new_detect20_grid = gp.sjoin(geo_2020[geo_2020.new_detected].to_crs(bng), grid, how='left', predicate='within')\n", + "\n", + "# Compute new detects per grid cell -- aggregate with dissolve\n", + "new_detect20_grid['n_new20'] = 1\n", + "new_detect20_grid_n_new = new_detect20_grid[['index_right','n_new20']].groupby(by=\"index_right\").count()\n", + "\n", + "# Add to grid\n", + "grid.loc[new_detect20_grid_n_new.index, 'n_new20'] = new_detect20_grid_n_new\n", + "\n", + "# add normalised to grid\n", + "grid.loc[new_detect20_grid_n_new.index, 'norm_new20'] = grid.loc[new_detect20_grid_n_new.index, 'n_new20'] / grid.loc[new_detect20_grid_n_new.index, 'n_tests20']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f584801-1387-4272-9971-bfc8f6ebc325", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot new grid for 2020 only\n", + "ax = grid.plot(column='n_new20', cmap='cividis', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Early detected tests by area (in 2020)\")\n", + "plt.axis('off')\n", + "plt.savefig('../Paper/figs/map_newly_detected_number_2020.pdf',bbox_inches='tight')\n", + "plt.savefig('../Paper/figs/map_newly_detected_number_2020.png',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8874e3ba-9215-4b57-bedf-21f98653bc7b", + "metadata": {}, + "outputs": [], + "source": [ + "grid.n_new20.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62c120f9-fa27-462b-91f7-ae873e01ae4f", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot new grid for 2020 only, normalised\n", + "ax = grid.plot(column='norm_new20', cmap='cividis', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Proportion of tests early detected by area (in 2020)\")\n", + "plt.axis('off')\n", + "plt.savefig('../Paper/figs/map_newly_detected_2020.pdf',bbox_inches='tight')\n", + "plt.savefig('../Paper/figs/map_newly_detected_2020.png',bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "961a329b-0c9e-45c8-a1e1-6ea868e6c438", + "metadata": {}, + "source": [ + "## Partial dependence" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00c6b6a9-e50f-431c-b19e-7783e8150a3e", + "metadata": {}, + "outputs": [], + "source": [ + "#### Takes ages and eats memory... ?!?!\n", + "#PartialDependenceDisplay.from_estimator(model, X_train, [4,5,(4,5)])" + ] + }, + { + "cell_type": "markdown", + "id": "19843de4-23f1-4b4b-988e-4e3cd5380308", + "metadata": {}, + "source": [ + "## Plot feature correlations" + ] + }, + { + "cell_type": "markdown", + "id": "1e3fc849-0d1c-4b34-8476-d8f9dad0603b", + "metadata": {}, + "source": [ + "---\n", + "# Analysis" + ] + }, + { + "cell_type": "markdown", + "id": "c2502ddd-f83b-4fd2-aa05-244d7f987292", + "metadata": {}, + "source": [ + "## What sort of herd is newly detected?\n", + "\n", + "* See map above for spatial distribution.\n", + "* What else?" + ] + }, + { + "cell_type": "markdown", + "id": "a4c5c4bb-b5b3-4482-b558-04c905c12b98", + "metadata": {}, + "source": [ + "## Confusion matrix" + ] + }, + { + "cell_type": "markdown", + "id": "cd3d9e83-4dbc-4dc7-9db9-8cc3e08416e4", + "metadata": {}, + "source": [ + "* all stats w.r.t. sicct alone and Stanski" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cabfe2e7-d1e4-4471-b84d-afde63ea268d", + "metadata": {}, + "outputs": [], + "source": [ + "np.array([['tp','tn'],['fp','fn']])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53fbb0e8-9d7d-42aa-84d9-51334f68a1da", + "metadata": {}, + "outputs": [], + "source": [ + "# Function ot calculate confusion matrix\n", + "# p = predicted class\n", + "# t = actual class\n", + "def confusion_matrix(p,t):\n", + " # True Positives\n", + " TP = (p&t).sum()\n", + " # True Negatives\n", + " TN = (~p&~t).sum()\n", + " # False Positives\n", + " FP = (p&~t).sum()\n", + " # False Negatives\n", + " FN = (~p&t).sum()\n", + " # return matrix (values and proportions)\n", + " total = len(p)\n", + " val_array = np.array([[TP,TN],[FP,FN]])\n", + " prop_array = np.around(np.array([[TP/total, TN/total], [FP/total, FN/total]]) * 100, 1)\n", + " return val_array , prop_array" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2387564b-eb96-4115-a389-5803d15f3ec0", + "metadata": {}, + "outputs": [], + "source": [ + "# confusion matrix for test set\n", + "cm_model = confusion_matrix(y_2020_predicted,y_2020)\n", + "cm_model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef6f3705-701a-4c7d-b3cf-c23271337e23", + "metadata": {}, + "outputs": [], + "source": [ + "sensitivity(y_2020_predicted,y_2020), specificity(y_2020_predicted,y_2020)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fc27bf2-231b-4aca-814c-53b939e764cb", + "metadata": {}, + "outputs": [], + "source": [ + "# confusion matrix for SICCT\n", + "sicct_2020_predicted = X_2020[:,1]==1\n", + "cm_sicct = confusion_matrix(sicct_2020_predicted,y_2020)\n", + "cm_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b235d247-0498-49d8-a1d3-4d191ffb2c46", + "metadata": {}, + "outputs": [], + "source": [ + "sensitivity(sicct_2020_predicted,y_2020), specificity(sicct_2020_predicted,y_2020)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0454bc8a-bf69-4f60-b6e6-dfd4773f1100", + "metadata": {}, + "outputs": [], + "source": [ + "# reduction in FNs / FPs from SICCT to model\n", + "m_fn = cm_model[0][1][1]\n", + "m_fp = cm_model[0][1][0]\n", + "s_fn = cm_sicct[0][1][1]\n", + "s_fp = cm_sicct[0][1][0]\n", + "\n", + "print('FN reduction: ', (s_fn - m_fn)/s_fn, '\\nFP reduction:', (s_fp - m_fp)/s_fp)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f490636-f5aa-4262-8e28-c05a98348f10", + "metadata": {}, + "outputs": [], + "source": [ + "print('2020 HSe increase: ',(sensitivity(y_2020_predicted,y_2020) - sensitivity(sicct_2020_predicted,y_2020))/sensitivity(sicct_2020_predicted,y_2020))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6bd39d1-5099-4ad4-9fde-a54f3aec4ccd", + "metadata": {}, + "outputs": [], + "source": [ + "# How many FPs caught by HSp maximisation threshold?\n", + "y_2020_predicted_hse = predict_with_threshold(X_2020, model, sens_thresh)\n", + "# TNs where test was P, vs. all TNs (in 2020)\n", + "sum((X_2020[:,1] == 1) & (y_2020_predicted_hse == 0) & (y_2020 == 0)) , sum((y_2020_predicted_hse == 0) & (y_2020 == 0))" + ] + }, + { + "cell_type": "markdown", + "id": "a332ef62-d490-45d8-8540-79b142158479", + "metadata": {}, + "source": [ + "## Number of days to breakdown (distribution)\n", + "\n", + "* for newly detected / vs sicct detected\n", + "* other? " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2649fb1d-68a8-49f4-b15f-8ac4567b3cf8", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "5dfc972c-248b-41a0-a0c3-764c3869c9b5", + "metadata": {}, + "source": [ + "## Normalised spatial analysis\n", + "\n", + "* residuals normalised by no. of tests\n", + "* new detections ---\"---" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68413072-ebe9-47c7-8af8-a4075b41577d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "ad6e367d-9bb1-409a-a175-1efdded44533", + "metadata": {}, + "source": [ + "---\n", + "# Other TODOs:" + ] + }, + { + "cell_type": "markdown", + "id": "fb21b12c-8be2-4488-af6b-78e51e0a0bcd", + "metadata": {}, + "source": [ + "\n", + "* Permutation importance with multicolinear features\n", + "* Gold standard period? 90-day or other? Test?\n", + "* Time / area split models\n", + " - Fix fitting to existing model, or re-tune for each model?" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/bTB-Diagnostic_2020_final_model_VetOnly-Control.ipynb b/bTB-Diagnostic_2020_final_model_VetOnly-Control.ipynb new file mode 100644 index 0000000..fb01add --- /dev/null +++ b/bTB-Diagnostic_2020_final_model_VetOnly-Control.ipynb @@ -0,0 +1,1637 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "edd1ac98-7c57-46b3-9ca0-9b4dfe74e8ef", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import geopandas as gp\n", + "import geoplot\n", + "import seaborn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "435fc234-bef8-4f8a-9d70-9615b2357a0e", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import time\n", + "import shapely\n", + "import rtree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fbc72ded-6afc-4705-ab8d-42c3256bf3bc", + "metadata": {}, + "outputs": [], + "source": [ + "import sklearn\n", + "from sklearn.ensemble import HistGradientBoostingClassifier as GBT\n", + "from sklearn.metrics import roc_curve, auc, roc_auc_score, make_scorer\n", + "from sklearn.inspection import permutation_importance\n", + "#from sklearn.utils.fixes import loguniform\n", + "from sklearn.model_selection import RandomizedSearchCV, train_test_split, cross_validate, GridSearchCV\n", + "from scipy.stats import randint,uniform,loguniform\n", + "from sklearn.inspection import PartialDependenceDisplay" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62536a34-b964-4bf5-bb81-175568b6b215", + "metadata": {}, + "outputs": [], + "source": [ + "from joblib import dump, load\n", + "from copy import deepcopy" + ] + }, + { + "cell_type": "markdown", + "id": "ac1e7837-b69f-4a30-acc4-36ed795efe29", + "metadata": {}, + "source": [ + "---\n", + "# Config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5260a1ce-8c2a-4324-8d78-0891028f8bfa", + "metadata": {}, + "outputs": [], + "source": [ + "## Control runs (labels)\n", + "controls = [1,2,3,4,5]\n", + "runs = ['vet'] + ['full'] + controls" + ] + }, + { + "cell_type": "markdown", + "id": "bcf17004-07e3-4dd7-844c-9019dd289561", + "metadata": {}, + "source": [ + "# Load training and testing sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1977ad48-146c-4df8-b1d5-8d04fa26fff9", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the train/test splits used in training controls\n", + "X_train = {}\n", + "X_test = {}\n", + "y_train = {}\n", + "y_test = {}\n", + "for r in controls:\n", + " X_train[r], X_test[r], y_train[r], y_test[r] = load('/Data/TB_Diagnostics/final_data_split_VetOnly_Control_'+str(r)+'.data')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f70ebe58-5fa2-43b0-a3d2-4814d6cc42e7", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the train/test splits used in training vet only run\n", + "r = 'vet'\n", + "X_train[r], X_test[r], y_train[r], y_test[r] = load('/Data/TB_Diagnostics/final_data_split_VetOnly.data')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae555eb4-7f28-43ef-8fef-be2d24888f6c", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the train/test splits used in training full model\n", + "r = 'full'\n", + "X_train[r], X_test[r], y_train[r], y_test[r] = load('/Data/TB_Diagnostics/final_data_split.data')" + ] + }, + { + "cell_type": "markdown", + "id": "0ee29cd1-a9b8-44d4-aebc-e7ee91c995ae", + "metadata": {}, + "source": [ + "# Load original data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b3a0ef6-a83a-489d-80ec-e8388ab0a5ca", + "metadata": {}, + "outputs": [], + "source": [ + "## Load data\n", + "data = pd.read_csv('/Data/TB_Diagnostics/inputVars_VetOnly.csv',parse_dates=['dateOfTest'],dtype=float)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6c458044-540f-4a9e-b7d8-e1b6110f5e12", + "metadata": {}, + "outputs": [], + "source": [ + "# Get target feature (confirmed breakdowns) as binary class\n", + "data_y = data.confirmedBreakdown.to_numpy().astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "048a5e8f-2963-466b-87eb-ee95ebe77238", + "metadata": {}, + "outputs": [], + "source": [ + "# Get observed features\n", + "data_X = data.drop(columns=['confirmedBreakdown'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35b0a512-e740-4352-b1ed-e0ac8698cd8d", + "metadata": {}, + "outputs": [], + "source": [ + "# Convert dates to float\n", + "data_X.dateOfTest = data_X.dateOfTest.astype(int).astype(float)\n", + "# Add Random features\n", + "data_X['rand'] = np.random.random_sample(len(data_X))" + ] + }, + { + "cell_type": "markdown", + "id": "55c3a959-c82e-438f-9e7c-8f1097ce8858", + "metadata": {}, + "source": [ + "# Model scoring functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44ec9fbc-7654-4fd3-b68b-fbee40cec379", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: sensitivity(prediction,target)\n", + "# returns sensitivity of prediction vs. target\n", + "# Se = TP / (TP + FN)\n", + "def sensitivity(p,t):\n", + " TP = (p&t).sum()\n", + " FN = (~p&t).sum()\n", + " return TP / (TP + FN)\n", + "\n", + "## Function: specificity(prediction,target)\n", + "# returns specificity of prediction vs. target\n", + "# Sp = TN / (TN + FP)\n", + "def specificity(p,t):\n", + " TN = (~p&~t).sum()\n", + " FP = (p&~t).sum()\n", + " return TN / (TN + FP)" + ] + }, + { + "cell_type": "markdown", + "id": "43e67834-1086-4273-82f4-aaa79426d31e", + "metadata": {}, + "source": [ + "### SICCT Test performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e767706d-81c9-447b-a7a4-7780a7d68b53", + "metadata": {}, + "outputs": [], + "source": [ + "sicct = {}\n", + "Se_sicct = {}\n", + "Sp_sicct = {}\n", + "Acc_sicct = {}\n", + "specificity_threshold = {}\n", + "\n", + "for r in runs:\n", + " sicct[r] = X_test[r][:,1].astype(bool)\n", + " ## Sensitivity\n", + " Se_sicct[r] = sensitivity(sicct[r],y_test[r])\n", + " ## Specificity\n", + " Sp_sicct[r] = specificity(sicct[r],y_test[r])\n", + " ## Accuracy\n", + " Acc_sicct[r] = (sicct[r]==y_test[r]).sum() / len(y_test[r])\n", + " # Set specificity threshold to level for SICCT-only prediction\n", + " specificity_threshold[r] = Sp_sicct[r]" + ] + }, + { + "cell_type": "markdown", + "id": "6fd52b52-caf6-42bf-9952-855c46d6347b", + "metadata": { + "tags": [] + }, + "source": [ + "# Load models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f02aac16-bec1-45f6-8f7c-bc221562f5f9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Load cross-validated and fit model\n", + "model = {}\n", + "for r in controls:\n", + " model[r] = load('/Data/TB_Diagnostics/final_model_VetOnly_Control_'+str(r)+'.model')\n", + " \n", + "model['vet'] = load('/Data/TB_Diagnostics/final_model_VetOnly.model')\n", + "model['full'] = load('/Data/TB_Diagnostics/final_model.model')" + ] + }, + { + "cell_type": "markdown", + "id": "18f01840-d148-412d-b355-0a5295c84212", + "metadata": { + "tags": [] + }, + "source": [ + "# Evaluate performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9dcc7f3-e102-4dfb-b7af-343568f8121b", + "metadata": {}, + "outputs": [], + "source": [ + "## Model score on testing set: (score is metric set at training time)\n", + "for r in controls:\n", + " print(model[r].score(X_test[r],y_test[r]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90128953-7914-41f5-96be-78dc241971fa", + "metadata": {}, + "outputs": [], + "source": [ + "## Get test predictions for more detailed evaluation:\n", + "y_test_result = {}\n", + "y_score = {}\n", + "for r in runs:\n", + " y_test_result[r] = model[r].predict(X_test[r])\n", + " y_score[r] = model[r].decision_function(X_test[r])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a63a9b46-26b7-4510-9b82-096e51c2216c", + "metadata": {}, + "outputs": [], + "source": [ + "## Sensitivity\n", + "Se = {}\n", + "for r in controls:\n", + " Se[r] = sensitivity(y_test_result[r],y_test[r])\n", + " print(Se[r])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edd92f58-ebc8-4364-9436-d56d8f1e0365", + "metadata": {}, + "outputs": [], + "source": [ + "## Specificity\n", + "Sp = {}\n", + "for r in controls:\n", + " Sp[r] = specificity(y_test_result[r],y_test[r])\n", + " print(Sp[r])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ac84b5d-0ac3-4eea-9856-1e91beef126c", + "metadata": {}, + "outputs": [], + "source": [ + "## Accuracy\n", + "Acc = {}\n", + "for r in controls:\n", + " Acc[r] = (y_test_result[r]==y_test[r]).sum() / len(y_test[r])\n", + " print(Acc[r])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3588c1ad-fc26-4a0f-84b6-4257f6e406c7", + "metadata": {}, + "outputs": [], + "source": [ + "np.mean(list(Acc.values()))" + ] + }, + { + "cell_type": "markdown", + "id": "a788b3e4-cbf7-4fac-99ce-68273d9aabed", + "metadata": {}, + "source": [ + "---\n", + "# ROC Curves" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41ec298c-991d-459c-a6c7-668fe091f635", + "metadata": {}, + "outputs": [], + "source": [ + "fpr = {}\n", + "tpr = {}\n", + "roc_auc = {}\n", + "for r in controls:\n", + " fpr[r], tpr[r], _ = roc_curve(y_test[r],y_score[r])\n", + " roc_auc[r] = auc(fpr[r],tpr[r])\n", + " print(roc_auc[r])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6677e340-fb0a-438c-b05b-fa039cc8bf03", + "metadata": {}, + "outputs": [], + "source": [ + "np.mean(list(roc_auc.values()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a3ee1fc-8abb-4b7d-8df9-d0fa73109374", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#function ot plot roc curve\n", + "def plot_roc(fpr,tpr,roc_auc,run):\n", + " r=run\n", + " plt.figure()\n", + " lw = 2\n", + " plt.plot(\n", + " fpr[r],\n", + " tpr[r],\n", + " lw=lw,\n", + " label=\"Model (AUC = %0.2f)\" % roc_auc[r],\n", + " )\n", + " plt.plot(1-Sp_sicct[r],Se_sicct[r],'+', label=\"SICCT only\", ms='15')\n", + " plt.plot([0, 1], [0, 1], lw=lw, linestyle=\"--\", label='Random')\n", + " plt.xlim([0.0, 1.0])\n", + " plt.ylim([0.0, 1.0])\n", + " plt.xlabel(\"(1 - Specificity)\")\n", + " plt.ylabel(\"Sensitivity\")\n", + " plt.title(\"Receiver operating characteristic\")\n", + " plt.legend(loc=\"lower right\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7f5f3bb-e7a3-4987-a381-6824560d836a", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({'font.size': 16})\n", + "for r in controls:\n", + " plot_roc(fpr,tpr,roc_auc,r)\n", + " #plt.savefig('../Paper/figs/roc.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "3a0f6f31-b460-4b73-89b0-90edb0302818", + "metadata": {}, + "source": [ + "---\n", + "# Feature importance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c06b2bcc-5da2-46b3-8a91-9e130fd17a49", + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "## Calcuate permutation importance\n", + "importance = {}\n", + "for r in runs:\n", + " importance[r] = permutation_importance(model[r],X_test[r],y_test[r], n_repeats=20, n_jobs=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "80121727-2789-476b-9456-0584686afbfe", + "metadata": {}, + "outputs": [], + "source": [ + "#for i in importance.importances_mean.argsort()[::-1]:\n", + "# if abs(importance.importances_mean[i]) - 2 * importance.importances_std[i] > 0:\n", + "# print('*',f\"{importance.importances_mean[i]:.5f}\", f\" +/- {importance.importances_std[i]:.5f}\", data_X.columns[i])\n", + "# else:\n", + "# print(' ',f\"{importance.importances_mean[i]:.5f}\", f\" +/- {importance.importances_std[i]:.5f}\", data_X.columns[i])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dacd01f1-a312-4f1c-a33a-3433e54e4e4b", + "metadata": {}, + "outputs": [], + "source": [ + "#transform into table\n", + "importance_table = {}\n", + "for r in runs:\n", + " importance_table[r] = pd.DataFrame(importance[r].importances.T, columns=data_X.columns)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00152d7d-c5c6-4213-8563-e9519f66a059", + "metadata": {}, + "outputs": [], + "source": [ + "mean_order = {}\n", + "mean_order_nozero = {}\n", + "for r in runs:\n", + " mean_order[r] = list(importance_table[r].mean().sort_values(ascending=False).index)\n", + " mean_order_nozero[r] = list(importance_table[r].mean()[importance_table[r].mean()>0].sort_values(ascending=False).index)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1606583-aef9-455a-9b5e-2391ed666894", + "metadata": {}, + "outputs": [], + "source": [ + "#Define feature lables:\n", + "feature_labels = {'resultOfTest':'Herd-level SICCT result',\n", + " 'locationY':'Holding location Easting',\n", + " 'locationX':'Holding location Northing',\n", + " 'daysSinceBreakdown':'Days since herd breakdown *',\n", + " 'dateOfTest':'Date of herd SICCT testing',\n", + " 'animalsTested':'Number of animals tested',\n", + " 'daysSincePreviousTest':'Time since previous SICCT test in herd *',\n", + " 'severe':'Was the severe interpretation applied?',\n", + " 'previousTestResult':'Result of last previous SICCT test in herd *',\n", + " 'previousTestResult2':'Result of 2nd last previous SICCT test in herd *',\n", + " 'gammaTestCount':'Number of historical GammaIFN test events in herd',\n", + " 'rand':'Set of uniformly distributed random numbers (CONTROL)',\n", + " 'inflow1':'Animals moved into herd, 1 year',\n", + " 'inflow2':'Animals moved into herd, 2 years',\n", + " 'inflow4':'Animals moved into herd, 4 years',\n", + " 'inflow90':'Animals moved into herd, 90 days',\n", + " 'outflow1':'Animals moved out herd, 1 year',\n", + " 'outflow2':'Animals moved out herd, 2 years',\n", + " 'outflow4':'Animals moved out herd, 4 years',\n", + " 'outflow90':'Animals moved out herd, 90 days',\n", + " 'inflowBD1':'Animals moved into herd, 1 year, from recent breakdown herds',\n", + " 'inflowBD2':'Animals moved into herd, 2 years, from recent breakdown herds',\n", + " 'inflowBD4':'Animals moved into herd, 4 years, from recent breakdown herds',\n", + " 'inflowBD90':'Animals moved into herd, 90 days, from recent breakdown herds',\n", + " 'outflowBD1':'Animals moved out herd, 1 year, from recent breakdown herds',\n", + " 'outflowBD2':'Animals moved out herd, 2 years, from recent breakdown herds',\n", + " 'outflowBD4':'Animals moved out herd, 4 years, from recent breakdown herds',\n", + " 'outflowBD90':'Animals moved out herd, 90 days, from recent breakdown herds',\n", + " 'vetPractice':'Veterinary practice conducting the test **',\n", + " 'batchBovine':'Tuberculin batch (bovine) **',\n", + " 'batchAvian':'Tuberculin batch (avian) **',\n", + " 'testType':'Type of testing event',\n", + " 'herdSize':'Size of herd at time of test',\n", + " 'herdType':'Herd type (dairy, beef, etc.)',\n", + " 'monthOfTest':'Month in which test taken',\n", + " 'defraRiskScore':'APHA risk score for herd',\n", + " 'meanBadgerAbundance':'Mean badger abundance'}\n", + "def feature_label(x):\n", + " try:\n", + " return feature_labels[x]\n", + " except KeyError:\n", + " return x " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "386786c4-ccbc-4695-988c-5088d667af5d", + "metadata": {}, + "outputs": [], + "source": [ + "# apply labels to mean ordered set\n", + "mean_order_labels = {}\n", + "mean_order_nozero_labels = {}\n", + "for r in runs:\n", + " mean_order_labels[r] = list(map(lambda x:feature_label(x), mean_order[r]))\n", + " mean_order_nozero_labels[r] = list(map(lambda x:feature_label(x), mean_order_nozero[r]))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00dcff23-3e9f-4e4e-8342-f8ebd8679b54", + "metadata": {}, + "outputs": [], + "source": [ + "# aggregate all control runs:\n", + "importance_table_agg = pd.concat([importance_table[r] for r in controls],ignore_index=True)\n", + "mean_order_agg = list(importance_table_agg.mean().sort_values(ascending=False).index)\n", + "mean_order_labels_agg = list(map(lambda x:feature_label(x), mean_order_agg))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b990941b-35ed-4461-bcc6-0d20fedc1db9", + "metadata": {}, + "outputs": [], + "source": [ + "## Plot aggregate rankings over all runs:\n", + "\n", + "plt.rcParams.update({'font.size': 28})\n", + "#fig, ax = plt.subplots(figsize=(16,20))\n", + "fig, ax = plt.subplots(figsize=(16,20))\n", + "seaborn.barplot(importance_table_agg[mean_order_agg], orient='h', errorbar='ci', ax=ax)\n", + "ax.set_yticklabels(mean_order_labels_agg)\n", + "plt.title('Relative importance of model features (5 Control runs)')\n", + "#plt.xscale('log')\n", + "plt.savefig('../Paper/figs/importance-reduced-sample.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5292ca64-112b-4226-821c-1b4ad0c81a2b", + "metadata": {}, + "outputs": [], + "source": [ + "importance_table_agg.mean()['vetPractice']*100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1c6439f-d5d9-46c5-ba1e-cfe09a0d6585", + "metadata": {}, + "outputs": [], + "source": [ + "## Plot for each run:\n", + "\n", + "for r in runs:\n", + " plt.rcParams.update({'font.size': 28})\n", + " fig, ax = plt.subplots(figsize=(16,20))\n", + " seaborn.barplot(importance_table[r][mean_order[r]], orient='h', errorbar='ci', ax=ax)\n", + " ax.set_yticklabels(mean_order_labels[r])\n", + " plt.title('Relative importance of model features')\n", + " #plt.xscale('log')\n", + " #plt.savefig('../Paper/figs/importance.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19f128aa-8d97-4d5a-87d3-33f584c4a833", + "metadata": {}, + "outputs": [], + "source": [ + "for r in controls:\n", + " print(importance_table[r].mean()['vetPractice']*100)" + ] + }, + { + "cell_type": "markdown", + "id": "5b81c03e-cf60-40c7-9114-99a607c0df20", + "metadata": {}, + "source": [ + "## Plot controls against vet only and full model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6bb90a7b-300b-4387-9ba8-3b1bf0249a9a", + "metadata": {}, + "outputs": [], + "source": [ + "#label tables:\n", + "#importance_table['full']['model'] = 'full'\n", + "#importance_table['vet']['model'] = 'vet'\n", + "#importance_table_agg['model'] = 'control'\n", + "\n", + "#aggregate\n", + "#importance_table_forPlot = pd.concat([importance_table['full'],importance_table['vet'],importance_table_agg])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b1aa99a-967c-4c59-a3dd-411d5cedad8c", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({'font.size': 28})\n", + "fig, ax = plt.subplots(figsize=(16,20))\n", + "ord = mean_order['full']\n", + "seaborn.barplot(importance_table['full'][ord], orient='h', errorbar='ci', ax=ax, label='full', color='red', alpha=0.5)\n", + "#seaborn.barplot(importance_table['vet'][ord], orient='h', errorbar='ci', ax=ax, label='vet', color='blue', alpha=0.5)\n", + "ax.set_yticklabels(mean_order_labels['full'])\n", + "plt.title('Relative importance of model features (all runs)')\n", + "plt.xscale('log')\n", + "plt.legend()\n", + "#plt.savefig('../Paper/figs/importance.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5aff16a7-5983-40f5-81de-31b714be0918", + "metadata": {}, + "outputs": [], + "source": [ + "importance_table['full'].melt()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6e160cc-1b15-4b33-8bb2-1f29b25485ba", + "metadata": {}, + "outputs": [], + "source": [ + "# Rearange tables for plotting grouped bar plot for all three models (full, vet, controls):\n", + "\n", + "order_by = 'vet' #order all by full model order\n", + "ord = mean_order[order_by] \n", + "ord_labels = mean_order_labels[order_by]\n", + "\n", + "t1 = importance_table['full'][ord].melt()\n", + "t1['model'] = 'full'\n", + "t2 = importance_table['vet'][ord].melt()\n", + "t2['model'] = 'vet'\n", + "t3 = importance_table_agg[ord].melt()\n", + "t3['model'] = 'control'\n", + "\n", + "t_all = pd.concat([t1,t2,t3])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4bc69f67-ad40-47e0-9632-0cc0e56b3604", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot all models in grouped bar plot:\n", + "\n", + "plt.rcParams.update({'font.size': 28})\n", + "fig, ax = plt.subplots(figsize=(16,20))\n", + "seaborn.barplot(ax=ax, data=t_all, x='value', y='variable', hue='model', errorbar='ci')\n", + "ax.set_yticklabels(ord_labels)\n", + "plt.title('Relative importance of model features (all models)')\n", + "#plt.xscale('log')\n", + "#plt.legend()\n", + "plt.ylabel(None)\n", + "plt.xlabel('Relative Importance')\n", + "#plt.savefig('test.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "21d4c471-2f52-40e3-8922-7c123713c442", + "metadata": {}, + "source": [ + "---\n", + "# Decision threshold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1f14688-94e3-4290-b8d4-01bc49b52619", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# function to apply decision threshold\n", + "def predict_with_threshold(X, model, decision_threshold):\n", + " return model.predict_proba(X)[:,1]>=decision_threshold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55c1c89c-47ec-4ca1-a55c-10757085f87c", + "metadata": {}, + "outputs": [], + "source": [ + "# try different thresholds\n", + "thresholds = np.linspace(0.0,1.0,101)\n", + "sens = np.zeros(len(thresholds)) #sensitivity at threshold\n", + "spec = np.zeros(len(thresholds)) #specificity at threshold\n", + "for x in range(len(thresholds)):\n", + " y_th = predict_with_threshold(X_test,model,thresholds[x])\n", + " sens[x] = sensitivity(y_th,y_test)\n", + " spec[x] = specificity(y_th,y_test)\n", + "\n", + "best_sens = max(sens[spec >= Sp_sicct]) #sensitivity s.t. specificity >= SICCT\n", + "best_thresh = min(thresholds[spec >= Sp_sicct]) #threshold with max sensitivity s.t. specificity >= SICCT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0391e6cb-5d30-42da-864a-64a39ab56fe7", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({'font.size': 16})\n", + "# function to plot thresholds\n", + "plt.plot(thresholds,sens,label='Model HSe')\n", + "plt.plot(thresholds,spec,label='Model HSp')\n", + "plt.xlim(0.2,0.8)\n", + "plt.ylim(0.5,1.0)\n", + "best_sens_label = 'Chosen HSe = '+str(round(best_sens*100,1))+'%'\n", + "sicct_sens_label = 'SICCT HSe = '+str(round(Se_sicct*100,1))+'%'\n", + "sicct_spec_label = 'SICCT HSp = '+str(round(Sp_sicct*100,1))+'%'\n", + "best_thresh_label = 'Chosen threshold = '+str(round(best_thresh,3))\n", + "plt.axvline(best_thresh,c='k',ls='-.',label=best_thresh_label)\n", + "plt.axhline(best_sens,c='k',ls='--',label=best_sens_label)\n", + "plt.axhline(Se_sicct,c='tab:blue',ls=':',label=sicct_sens_label)\n", + "plt.axhline(Sp_sicct,c='tab:orange',ls=':',label=sicct_spec_label)\n", + "plt.xlabel('Decision Threshold')\n", + "plt.legend(bbox_to_anchor=(1, 0.5))\n", + "#plt.savefig('../Paper/figs/decision_threshold.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6f5e010-d528-4d1e-89b5-4f0db5ddaca3", + "metadata": {}, + "outputs": [], + "source": [ + "# Percentage increase in sensitivity over SICCT alone\n", + "(best_sens - Se_sicct)/Se_sicct*100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8207d271-4edc-400d-875e-9c69d604b6b6", + "metadata": {}, + "outputs": [], + "source": [ + "#Predictions from model at threshold\n", + "y_test_predicted = predict_with_threshold(X_test, model, best_thresh)" + ] + }, + { + "cell_type": "markdown", + "id": "034b0bb2-d40e-49c1-a05f-9b0ee7f03533", + "metadata": {}, + "source": [ + "### Test on 2020 only data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29e8fcbc-416b-4034-b09d-7ab1b4604fcb", + "metadata": {}, + "outputs": [], + "source": [ + "# Get data for 2020 only\n", + "mask_2020 = data.dateOfTest.apply(lambda x:x.year)==2020\n", + "X_2020 = data_X[mask_2020].to_numpy()\n", + "y_2020 = data_y[mask_2020]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a63f7f7b-9bfb-4059-a937-3a687301389c", + "metadata": {}, + "outputs": [], + "source": [ + "#Predictions from model at threshold for 2020 only\n", + "y_2020_predicted = predict_with_threshold(X_2020, model, best_thresh)" + ] + }, + { + "cell_type": "markdown", + "id": "7e1a8a42-6c35-45a3-9e4c-de7aec974af0", + "metadata": {}, + "source": [ + "### Test with HSp maximised" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "502c5dc0-8ccd-422b-a662-55e50d280d3d", + "metadata": {}, + "outputs": [], + "source": [ + "# What if we maximise specificty instead?\n", + "\n", + "sens_thresh = max(thresholds[sens>= Se_sicct]) # threshold with max specificity s.t. sensitivity >= SICCT\n", + "best_spec = max(spec[sens >= Se_sicct]) # specificty s.t. sensitivity >= SICCT\n", + "\n", + "sens_thresh , best_spec" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76a014bd-6b3e-4c97-bab5-be3071c37204", + "metadata": {}, + "outputs": [], + "source": [ + "#Predictions from model at HSp threshold\n", + "y_test_predicted_hsp = predict_with_threshold(X_test, model, sens_thresh)" + ] + }, + { + "cell_type": "markdown", + "id": "38850be3-e638-4ef0-abf8-6d75e70751ac", + "metadata": {}, + "source": [ + "---\n", + "# Plots" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "684f1b2d-197b-4920-8624-881e42490323", + "metadata": {}, + "outputs": [], + "source": [ + "#Projections:\n", + "bng = 'epsg:27700' # British National Grid\n", + "wgs84 = 'epsg:4326' # Lat.Long." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20681a9a-e764-4018-9bd8-0ea3c503143f", + "metadata": {}, + "outputs": [], + "source": [ + "#UK base map\n", + "uk_shp = gp.read_file('/Data/Shapefiles/bdline_essh_gb/Data/Supplementary_Country/country_region.shp').to_crs(wgs84)\n", + "#uk_shp.plot(color='white', edgecolor='black')" + ] + }, + { + "cell_type": "markdown", + "id": "36f72d9c-b845-4295-bed9-24da037d64d5", + "metadata": {}, + "source": [ + "## Plot residuals" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06f28e8a-027b-4887-a663-d6af3a56b450", + "metadata": {}, + "outputs": [], + "source": [ + "residual = (y_test != y_test_predicted)\n", + "\n", + "test_locs = pd.DataFrame({'locationX':X_test[:,5],'locationY':X_test[:,6],'date':X_test[:,0]})\n", + "test_locs.date = test_locs.date.astype('datetime64[ns]')\n", + "\n", + "test_locs['residual'] = residual\n", + "\n", + "test_geo = gp.GeoDataFrame(test_locs,geometry=gp.points_from_xy(test_locs.locationX,test_locs.locationY,crs=bng))\n", + "test_geo = test_geo.to_crs(wgs84)\n", + "\n", + "errors = test_geo[test_geo.residual]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8a4105b-c64c-4f15-a480-8e7b9e5b8f41", + "metadata": {}, + "outputs": [], + "source": [ + "# residuals for 2020 only\n", + "residual_2020 = (y_2020 != y_2020_predicted)\n", + "\n", + "locs_2020 = pd.DataFrame({'locationX':X_2020[:,5],'locationY':X_2020[:,6],'date':X_2020[:,0]})\n", + "locs_2020.date = locs_2020.date.astype('datetime64[ns]')\n", + "\n", + "locs_2020['residual'] = residual_2020\n", + "\n", + "geo_2020 = gp.GeoDataFrame(locs_2020,geometry=gp.points_from_xy(locs_2020.locationX,locs_2020.locationY,crs=bng))\n", + "geo_2020 = geo_2020.to_crs(wgs84)\n", + "\n", + "errors_2020 = geo_2020[geo_2020.residual]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a69a1d1-93c9-46a3-a1cc-d61d2dcb638c", + "metadata": {}, + "outputs": [], + "source": [ + "# Location of observations vs. residuals in Test set\n", + "ax = uk_shp.plot(alpha=0.2, figsize=(10,20))\n", + "#test_geo.plot('residual', markersize=1.0, ax=ax)\n", + "test_geo.plot(markersize=1.0, ax=ax)\n", + "errors.plot(markersize=1.0, color='red', ax=ax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c6fed6e-e265-444c-a8ff-c2221b0957dd", + "metadata": {}, + "outputs": [], + "source": [ + "# Location of observations vs. residuals in 2020 set\n", + "ax = uk_shp.plot(alpha=0.2, figsize=(10,20))\n", + "geo_2020.plot(markersize=1.0, ax=ax)\n", + "errors_2020.plot(markersize=1.0, color='red', ax=ax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62bf025c-f5b9-4bc5-8d06-9df48399078a", + "metadata": {}, + "outputs": [], + "source": [ + "# Residuals in time\n", + "test_times = pd.DataFrame({'date':X_test[:,0].copy().astype('datetime64[ns]')})\n", + "test_times['residual'] = residual\n", + "error_times = test_times[test_times.residual]\n", + "e = error_times.groupby(error_times[\"date\"].dt.year).count().date\n", + "t = test_times.groupby(test_times[\"date\"].dt.year).count().date\n", + "(e/t).plot.bar()\n", + "plt.title(\"Proportion of model misclassifications by year\")\n", + "plt.xlabel('Year')\n", + "#plt.savefig('../Paper/figs/temporal.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "440f57da-101a-4d2d-8120-3d7f31841401", + "metadata": {}, + "source": [ + "## Plot newly discovered positives" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "daf3f8f6-dd51-4767-99d9-1b17e0ec5375", + "metadata": {}, + "outputs": [], + "source": [ + "new_detected = (~X_test[:,1].astype(bool) & y_test_predicted)\n", + "test_geo['new_detected'] = new_detected" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0044411-c380-4245-9fca-5710095bb9a5", + "metadata": {}, + "outputs": [], + "source": [ + "# new detected in 2020\n", + "new_detected_2020 = (~X_2020[:,1].astype(bool) & y_2020_predicted)\n", + "geo_2020['new_detected'] = new_detected_2020" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a811c36-47d2-4332-8892-acda30a5e0f5", + "metadata": {}, + "outputs": [], + "source": [ + "# Location of observations vs. newly detected herds in Test set\n", + "ax = uk_shp.plot(alpha=0.2, figsize=(10,20))\n", + "test_geo.plot(markersize=1.0, ax=ax)\n", + "test_geo[test_geo.new_detected].plot(markersize=1.0, color='gold', ax=ax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b83e5b01-9076-4bca-b7b6-d148da851f76", + "metadata": {}, + "outputs": [], + "source": [ + "ax = uk_shp.to_crs(bng).plot(alpha=0.2, figsize=(10,20))\n", + "seaborn.kdeplot(ax=ax, x=test_geo[test_geo.new_detected].locationX, y=test_geo[test_geo.new_detected].locationY, fill=True, color='gold')" + ] + }, + { + "cell_type": "markdown", + "id": "3cb5f9c0-6520-4322-9ffb-c574896fcd03", + "metadata": {}, + "source": [ + "## Plot Residuals and Newly Detected density by area" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d20700e-71e5-4093-8739-b9c41e119dd7", + "metadata": {}, + "outputs": [], + "source": [ + "## Create grid for normalising misclassifcation\n", + "# total area for the grid\n", + "xmin, ymin, xmax, ymax = uk_shp.to_crs(bng).total_bounds\n", + "#size of cell\n", + "cell_size = 10000 #10km x 10km squares\n", + "# create the cells in a loop\n", + "grid_cells = []\n", + "for x0 in np.arange(xmin, xmax+cell_size, cell_size ):\n", + " for y0 in np.arange(ymin, ymax+cell_size, cell_size):\n", + " # bounds\n", + " x1 = x0-cell_size\n", + " y1 = y0+cell_size\n", + " grid_cells.append(shapely.geometry.box(x0, y0, x1, y1))\n", + "grid = gp.GeoDataFrame(grid_cells, columns=['geometry'], crs=bng)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd76d332-b6f3-4546-bb96-ecad028ed72b", + "metadata": {}, + "outputs": [], + "source": [ + "# For normalisation of map cells by number of tests in cell:\n", + "# number of tests in test set\n", + "grid_n_tests = gp.sjoin(test_geo.to_crs(bng), grid, how='left', predicate='within')\n", + "grid_n_tests['n_tests'] = 1\n", + "grid_n_tests_d = grid_n_tests.dissolve(by='index_right', aggfunc='count')\n", + "grid.loc[grid_n_tests_d.index, 'n_tests'] = grid_n_tests_d.n_tests.values\n", + "\n", + "#number of tests in 2002\n", + "grid_n_tests_2020 = gp.sjoin(geo_2020.to_crs(bng), grid, how='left', predicate='within')\n", + "grid_n_tests_2020['n_tests20'] = 1\n", + "grid_n_tests_2020_d = grid_n_tests_2020.dissolve(by='index_right', aggfunc='count')\n", + "grid.loc[grid_n_tests_2020_d.index, 'n_tests20'] = grid_n_tests_2020_d.n_tests.values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d57d06ba-c729-4894-b625-00068334487e", + "metadata": {}, + "outputs": [], + "source": [ + "grid.plot(column='n_tests')\n", + "grid.plot(column='n_tests20')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "265fd8d7-e75f-4bc6-8b9f-386010dbba60", + "metadata": {}, + "outputs": [], + "source": [ + "## Spatial join residuals with grid\n", + "## and plot to produce heatmap\n", + "errors_grid = gp.sjoin(errors.to_crs(bng), grid, how='left', predicate='within')\n", + "\n", + "# Compute residuals per grid cell\n", + "errors_grid['n_resid'] = 1\n", + "#errors_grid_d = errors_grid.dissolve(by=\"index_right\", aggfunc=\"count\")\n", + "errors_grid_n_resid = errors_grid[['index_right','n_resid']].groupby(by=\"index_right\").count()\n", + "\n", + "# Add to grid\n", + "grid.loc[errors_grid_n_resid.index, 'n_resid'] = errors_grid_n_resid\n", + "\n", + "# add normalised to grid\n", + "grid.loc[errors_grid_n_resid.index, 'norm_resid'] = grid.loc[errors_grid_n_resid.index, 'n_resid'] / grid.loc[errors_grid_n_resid.index, 'n_tests']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9cc036b-1972-4b35-8a42-c7312a0b65ef", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot errors grid\n", + "ax = grid.plot(column='n_resid', cmap='YlOrRd', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Misclassified tests by area (in test set)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_misclassified.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfa1fa72-f62b-446d-bd70-daf7ba15f364", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot errors grid (normalised)\n", + "ax = grid.plot(column='norm_resid', cmap='YlOrRd', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Proportion of tests misclassified by area (in test set)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_misclassified.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06873036-d6d9-4b70-89cb-f404a4374f62", + "metadata": {}, + "outputs": [], + "source": [ + "## compute residuals grid for 2020 only\n", + "errors20_grid = gp.sjoin(errors_2020.to_crs(bng), grid, how='left', predicate='within')\n", + "\n", + "# Compute residuals per grid cell\n", + "errors20_grid['n_resid20'] = 1\n", + "errors20_grid_n_resid = errors20_grid[['index_right','n_resid20']].groupby(by=\"index_right\").count()\n", + "\n", + "# Add to grid\n", + "grid.loc[errors20_grid_n_resid.index, 'n_resid20'] = errors20_grid_n_resid\n", + "\n", + "# add normalised to grid\n", + "grid.loc[errors20_grid_n_resid.index, 'norm_resid20'] = grid.loc[errors20_grid_n_resid.index, 'n_resid20'] / grid.loc[errors20_grid_n_resid.index, 'n_tests20']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "542b0514-66d3-4af5-a679-046954217c3b", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot errors grid for 2020 only\n", + "ax = grid.plot(column='n_resid20', cmap='YlOrRd', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Misclassified tests by area (in 2020)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_misclassified_2020.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e56dc477-b066-4b04-b9e7-4be58b6f871d", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot normalised errors grid for 2020 only\n", + "ax = grid.plot(column='norm_resid20', cmap='YlOrRd', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Proportion of tests misclassified by area (in 2020)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_misclassified_2020.pdf',bbox_inches='tight')\n", + "#plt.savefig('../Paper/figs/map_misclassified_2020.png',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cc2daac-8175-408c-a9c3-1b814c0a539e", + "metadata": {}, + "outputs": [], + "source": [ + "## Spatial join newly detected with grid\n", + "## and plot to produce heatmap\n", + "new_detect_grid = gp.sjoin(test_geo[test_geo.new_detected].to_crs(bng), grid, how='left', predicate='within')\n", + "\n", + "# Compute new detects per grid cell -- aggregate with dissolve\n", + "new_detect_grid['n_new'] = 1\n", + "new_detect_grid_n_new = new_detect_grid[['index_right','n_new']].groupby(by=\"index_right\").count()\n", + "\n", + "# Add to grid\n", + "grid.loc[new_detect_grid_n_new.index, 'n_new'] = new_detect_grid_n_new\n", + "\n", + "# add normalised to grid\n", + "grid.loc[new_detect_grid_n_new.index, 'norm_new'] = grid.loc[new_detect_grid_n_new.index, 'n_new'] / grid.loc[new_detect_grid_n_new.index, 'n_tests']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6221042-5b7d-436d-9e39-ad50a5d6e031", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot new grid\n", + "ax = grid.plot(column='n_new', cmap='cividis', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Early detected tests by area (in test set)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_newly_detected.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dcd7a9c2-c819-419e-93e9-404ab49472b3", + "metadata": {}, + "outputs": [], + "source": [ + "grid.n_new.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4c3a5e4-937b-4abd-bf4e-a27c90ba3a46", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot new grid, normalised\n", + "ax = grid.plot(column='norm_new', cmap='cividis', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Proportion of tests early detected by area (in test set)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_newly_detected.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "567de8db-925f-41c3-83f2-7c29208b90e2", + "metadata": {}, + "outputs": [], + "source": [ + "## Spatial join newly detected with grid in 2020 only\n", + "new_detect20_grid = gp.sjoin(geo_2020[geo_2020.new_detected].to_crs(bng), grid, how='left', predicate='within')\n", + "\n", + "# Compute new detects per grid cell -- aggregate with dissolve\n", + "new_detect20_grid['n_new20'] = 1\n", + "new_detect20_grid_n_new = new_detect20_grid[['index_right','n_new20']].groupby(by=\"index_right\").count()\n", + "\n", + "# Add to grid\n", + "grid.loc[new_detect20_grid_n_new.index, 'n_new20'] = new_detect20_grid_n_new\n", + "\n", + "# add normalised to grid\n", + "grid.loc[new_detect20_grid_n_new.index, 'norm_new20'] = grid.loc[new_detect20_grid_n_new.index, 'n_new20'] / grid.loc[new_detect20_grid_n_new.index, 'n_tests20']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f584801-1387-4272-9971-bfc8f6ebc325", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot new grid for 2020 only\n", + "ax = grid.plot(column='n_new20', cmap='cividis', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Early detected tests by area (in 2020)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_newly_detected_2020.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8874e3ba-9215-4b57-bedf-21f98653bc7b", + "metadata": {}, + "outputs": [], + "source": [ + "grid.n_new20.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62c120f9-fa27-462b-91f7-ae873e01ae4f", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot new grid for 2020 only, normalised\n", + "ax = grid.plot(column='norm_new20', cmap='cividis', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Proportion of tests early detected by area (in 2020)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_newly_detected_2020.pdf',bbox_inches='tight')\n", + "#plt.savefig('../Paper/figs/map_newly_detected_2020.png',bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "961a329b-0c9e-45c8-a1e1-6ea868e6c438", + "metadata": {}, + "source": [ + "## Partial dependence" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00c6b6a9-e50f-431c-b19e-7783e8150a3e", + "metadata": {}, + "outputs": [], + "source": [ + "#### Takes ages and eats memory... ?!?!\n", + "#PartialDependenceDisplay.from_estimator(model, X_train, [4,5,(4,5)])" + ] + }, + { + "cell_type": "markdown", + "id": "19843de4-23f1-4b4b-988e-4e3cd5380308", + "metadata": {}, + "source": [ + "## Plot feature correlations" + ] + }, + { + "cell_type": "markdown", + "id": "1e3fc849-0d1c-4b34-8476-d8f9dad0603b", + "metadata": {}, + "source": [ + "---\n", + "# Analysis" + ] + }, + { + "cell_type": "markdown", + "id": "c2502ddd-f83b-4fd2-aa05-244d7f987292", + "metadata": {}, + "source": [ + "## What sort of herd is newly detected?\n", + "\n", + "* See map above for spatial distribution.\n", + "* What else?" + ] + }, + { + "cell_type": "markdown", + "id": "a4c5c4bb-b5b3-4482-b558-04c905c12b98", + "metadata": {}, + "source": [ + "## Confusion matrix" + ] + }, + { + "cell_type": "markdown", + "id": "cd3d9e83-4dbc-4dc7-9db9-8cc3e08416e4", + "metadata": {}, + "source": [ + "* all stats w.r.t. sicct alone and Stanski" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cabfe2e7-d1e4-4471-b84d-afde63ea268d", + "metadata": {}, + "outputs": [], + "source": [ + "np.array([['tp','tn'],['fp','fn']])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53fbb0e8-9d7d-42aa-84d9-51334f68a1da", + "metadata": {}, + "outputs": [], + "source": [ + "# Function ot calculate confusion matrix\n", + "# p = predicted class\n", + "# t = actual class\n", + "def confusion_matrix(p,t):\n", + " # True Positives\n", + " TP = (p&t).sum()\n", + " # True Negatives\n", + " TN = (~p&~t).sum()\n", + " # False Positives\n", + " FP = (p&~t).sum()\n", + " # False Negatives\n", + " FN = (~p&t).sum()\n", + " # return matrix (values and proportions)\n", + " total = len(p)\n", + " val_array = np.array([[TP,TN],[FP,FN]])\n", + " prop_array = np.around(np.array([[TP/total, TN/total], [FP/total, FN/total]]) * 100, 1)\n", + " return val_array , prop_array" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2387564b-eb96-4115-a389-5803d15f3ec0", + "metadata": {}, + "outputs": [], + "source": [ + "# confusion matrix for test set\n", + "cm_model = confusion_matrix(y_2020_predicted,y_2020)\n", + "cm_model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef6f3705-701a-4c7d-b3cf-c23271337e23", + "metadata": {}, + "outputs": [], + "source": [ + "sensitivity(y_2020_predicted,y_2020), specificity(y_2020_predicted,y_2020)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fc27bf2-231b-4aca-814c-53b939e764cb", + "metadata": {}, + "outputs": [], + "source": [ + "# confusion matrix for SICCT\n", + "sicct_2020_predicted = X_2020[:,1]==1\n", + "cm_sicct = confusion_matrix(sicct_2020_predicted,y_2020)\n", + "cm_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b235d247-0498-49d8-a1d3-4d191ffb2c46", + "metadata": {}, + "outputs": [], + "source": [ + "sensitivity(sicct_2020_predicted,y_2020), specificity(sicct_2020_predicted,y_2020)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0454bc8a-bf69-4f60-b6e6-dfd4773f1100", + "metadata": {}, + "outputs": [], + "source": [ + "# reduction in FNs / FPs from SICCT to model\n", + "m_fn = cm_model[0][1][1]\n", + "m_fp = cm_model[0][1][0]\n", + "s_fn = cm_sicct[0][1][1]\n", + "s_fp = cm_sicct[0][1][0]\n", + "\n", + "print('FN reduction: ', (s_fn - m_fn)/s_fn, '\\nFP reduction:', (s_fp - m_fp)/s_fp)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f490636-f5aa-4262-8e28-c05a98348f10", + "metadata": {}, + "outputs": [], + "source": [ + "print('2020 HSe increase: ',(sensitivity(y_2020_predicted,y_2020) - sensitivity(sicct_2020_predicted,y_2020))/sensitivity(sicct_2020_predicted,y_2020))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6bd39d1-5099-4ad4-9fde-a54f3aec4ccd", + "metadata": {}, + "outputs": [], + "source": [ + "# How many FPs caught by HSp maximisation threshold?\n", + "y_2020_predicted_hse = predict_with_threshold(X_2020, model, sens_thresh)\n", + "# TNs where test was P, vs. all TNs (in 2020)\n", + "sum((X_2020[:,1] == 1) & (y_2020_predicted_hse == 0) & (y_2020 == 0)) , sum((y_2020_predicted_hse == 0) & (y_2020 == 0))" + ] + }, + { + "cell_type": "markdown", + "id": "a332ef62-d490-45d8-8540-79b142158479", + "metadata": {}, + "source": [ + "## Number of days to breakdown (distribution)\n", + "\n", + "* for newly detected / vs sicct detected\n", + "* other? " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2649fb1d-68a8-49f4-b15f-8ac4567b3cf8", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "5dfc972c-248b-41a0-a0c3-764c3869c9b5", + "metadata": {}, + "source": [ + "## Normalised spatial analysis\n", + "\n", + "* residuals normalised by no. of tests\n", + "* new detections ---\"---" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68413072-ebe9-47c7-8af8-a4075b41577d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "ad6e367d-9bb1-409a-a175-1efdded44533", + "metadata": {}, + "source": [ + "---\n", + "# Other TODOs:" + ] + }, + { + "cell_type": "markdown", + "id": "fb21b12c-8be2-4488-af6b-78e51e0a0bcd", + "metadata": {}, + "source": [ + "\n", + "* Permutation importance with multicolinear features\n", + "* Gold standard period? 90-day or other? Test?\n", + "* Time / area split models\n", + " - Fix fitting to existing model, or re-tune for each model?" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/bTB-Diagnostic_2020_final_model_VetOnly.ipynb b/bTB-Diagnostic_2020_final_model_VetOnly.ipynb new file mode 100644 index 0000000..c1865e5 --- /dev/null +++ b/bTB-Diagnostic_2020_final_model_VetOnly.ipynb @@ -0,0 +1,1741 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "edd1ac98-7c57-46b3-9ca0-9b4dfe74e8ef", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import geopandas as gp\n", + "import geoplot\n", + "import seaborn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "435fc234-bef8-4f8a-9d70-9615b2357a0e", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import time\n", + "import shapely\n", + "import rtree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fbc72ded-6afc-4705-ab8d-42c3256bf3bc", + "metadata": {}, + "outputs": [], + "source": [ + "import sklearn\n", + "from sklearn.ensemble import HistGradientBoostingClassifier as GBT\n", + "from sklearn.metrics import roc_curve, auc, roc_auc_score, make_scorer\n", + "from sklearn.inspection import permutation_importance\n", + "#from sklearn.utils.fixes import loguniform\n", + "from sklearn.model_selection import RandomizedSearchCV, train_test_split, cross_validate, GridSearchCV\n", + "from scipy.stats import randint,uniform,loguniform\n", + "from sklearn.inspection import PartialDependenceDisplay" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62536a34-b964-4bf5-bb81-175568b6b215", + "metadata": {}, + "outputs": [], + "source": [ + "from joblib import dump, load\n", + "from copy import deepcopy" + ] + }, + { + "cell_type": "markdown", + "id": "69503fc4-a026-416e-bec2-1942c887738e", + "metadata": {}, + "source": [ + "---\n", + "# Load original data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "adbef996-7012-43fa-9fcd-b00008af862f", + "metadata": {}, + "outputs": [], + "source": [ + "## Load data\n", + "data = pd.read_csv('/Data/TB_Diagnostics/inputVars_VetOnly.csv',parse_dates=['dateOfTest'],dtype=float)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bce7e6b4-7092-4a55-a1f1-9a9d0d7161e6", + "metadata": {}, + "outputs": [], + "source": [ + "# Get target feature (confirmed breakdowns) as binary class\n", + "data_y = data.confirmedBreakdown.to_numpy().astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c4ea4432-4955-4cb6-ac7e-9350d8072da4", + "metadata": {}, + "outputs": [], + "source": [ + "# Get observed features\n", + "data_X = data.drop(columns=['confirmedBreakdown'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1e13433-6b57-4810-9100-7f717b5f74f5", + "metadata": {}, + "outputs": [], + "source": [ + "# Convert dates to float\n", + "data_X.dateOfTest = data_X.dateOfTest.astype(int).astype(float)\n", + "# Add Random features\n", + "data_X['rand'] = np.random.random_sample(len(data_X))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b326537d-15d7-40ac-9bc2-1026a6518ec9", + "metadata": {}, + "outputs": [], + "source": [ + "# Detect categorical features (<= 3 categories and explicit named features)\n", + "named_cat_features = ['vetPractice','batchBovine','batchAvian']\n", + "cat_features = []\n", + "for c in data_X.columns:\n", + " catf = len(data_X[c].unique())<=3\n", + " if c in named_cat_features:\n", + " catf = True\n", + " cat_features.append(catf)\n", + "\n", + "# NB: this is fine for boolean features (inc. missing values)\n", + "# but needs a proper encoding for true categorical features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7f0099ab-5200-4dcc-b8cb-e298ee128e5e", + "metadata": {}, + "outputs": [], + "source": [ + "# Convery all to float matrix\n", + "#data_X = data_X.to_numpy()" + ] + }, + { + "cell_type": "markdown", + "id": "bcf17004-07e3-4dd7-844c-9019dd289561", + "metadata": {}, + "source": [ + "# Load training and testing sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1977ad48-146c-4df8-b1d5-8d04fa26fff9", + "metadata": {}, + "outputs": [], + "source": [ + "# Load the train/test split used in training\n", + "X_train, X_test, y_train, y_test = load('/Data/TB_Diagnostics/final_data_split_VetOnly.data')" + ] + }, + { + "cell_type": "markdown", + "id": "bfd95433-6c64-4cde-8d38-ab4f349e78a7", + "metadata": {}, + "source": [ + "### Create new train/test splits with pre-/post-cull" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5cffd778-d3f0-43ad-bb78-7a38fb084afa", + "metadata": {}, + "outputs": [], + "source": [ + "#pre_cull = data.dateOfTest<\"2016\"\n", + "#data_X_pre = data_X[pre_cull].to_numpy()\n", + "#data_y_pre = data_y[pre_cull]\n", + "#X_train_pre, X_test_pre, y_train_pre, y_test_pre = train_test_split(data_X_pre, data_y_pre, test_size=0.20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3628601-ee7c-4a3c-abea-de4a3bdd0ca4", + "metadata": {}, + "outputs": [], + "source": [ + "#post_cull = data.dateOfTest>=\"2016\"\n", + "#data_X_post = data_X[post_cull].to_numpy()\n", + "#data_y_post = data_y[post_cull]\n", + "#X_train_post, X_test_post, y_train_post, y_test_post = train_test_split(data_X_post, data_y_post, test_size=0.20)" + ] + }, + { + "cell_type": "markdown", + "id": "55c3a959-c82e-438f-9e7c-8f1097ce8858", + "metadata": {}, + "source": [ + "# Model scoring functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "44ec9fbc-7654-4fd3-b68b-fbee40cec379", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: sensitivity(prediction,target)\n", + "# returns sensitivity of prediction vs. target\n", + "# Se = TP / (TP + FN)\n", + "def sensitivity(p,t):\n", + " TP = (p&t).sum()\n", + " FN = (~p&t).sum()\n", + " return TP / (TP + FN)\n", + "\n", + "## Function: specificity(prediction,target)\n", + "# returns specificity of prediction vs. target\n", + "# Sp = TN / (TN + FP)\n", + "def specificity(p,t):\n", + " TN = (~p&~t).sum()\n", + " FP = (p&~t).sum()\n", + " return TN / (TN + FP)" + ] + }, + { + "cell_type": "markdown", + "id": "43e67834-1086-4273-82f4-aaa79426d31e", + "metadata": {}, + "source": [ + "### SICCT Test performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "39fd9f81-f9f2-4bbd-8302-214c125f48af", + "metadata": {}, + "outputs": [], + "source": [ + "sicct = X_test[:,1].astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "533481d1-fa41-4cdc-8036-75d9381386d8", + "metadata": {}, + "outputs": [], + "source": [ + "## Sensitivity\n", + "Se_sicct = sensitivity(sicct,y_test)\n", + "Se_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a54891f-63ac-4403-8eaf-dc2d77370d2d", + "metadata": {}, + "outputs": [], + "source": [ + "## Specificity\n", + "Sp_sicct = specificity(sicct,y_test)\n", + "Sp_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f2e44fc6-f917-4adf-953f-0418da6a6e91", + "metadata": {}, + "outputs": [], + "source": [ + "## Accuracy\n", + "(sicct==y_test).sum() / len(y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61386875-f746-4271-94c8-b37dcd50e896", + "metadata": {}, + "outputs": [], + "source": [ + "# Set specificity threshold to level for SICCT-only prediction\n", + "specificity_threshold = Sp_sicct" + ] + }, + { + "cell_type": "markdown", + "id": "7974a2c0-1ea5-4076-b550-7a99fd98a616", + "metadata": {}, + "source": [ + "### Pre/post-cull SICCT performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b5eb78f-091b-4f8f-bd39-81f9f8f380dd", + "metadata": {}, + "outputs": [], + "source": [ + "#sicct_pre = X_test_pre[:,1].astype(bool)\n", + "#sicct_post = X_test_post[:,1].astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bccc23e0-eb6e-4516-8beb-4ef558de6617", + "metadata": {}, + "outputs": [], + "source": [ + "## Sensitivity\n", + "#Se_sicct_pre = sensitivity(sicct_pre,y_test_pre)\n", + "#Se_sicct_post = sensitivity(sicct_post,y_test_post)\n", + "#(Se_sicct_pre,Se_sicct_post)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "656f0151-689e-4a5c-b5e4-38659c9d79f2", + "metadata": {}, + "outputs": [], + "source": [ + "## Specificity\n", + "#Sp_sicct_pre = specificity(sicct_pre,y_test_pre)\n", + "#Sp_sicct_post = specificity(sicct_post,y_test_post)\n", + "#(Sp_sicct_pre,Sp_sicct_post)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a69c72b7-82d2-4874-935f-1400fce64686", + "metadata": {}, + "outputs": [], + "source": [ + "## Accuracy\n", + "#((sicct_pre==y_test_pre).sum() / len(y_test_pre) , (sicct_post==y_test_post).sum() / len(y_test_post))" + ] + }, + { + "cell_type": "markdown", + "id": "6fd52b52-caf6-42bf-9952-855c46d6347b", + "metadata": { + "tags": [] + }, + "source": [ + "# Load model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f02aac16-bec1-45f6-8f7c-bc221562f5f9", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Load cross-validated and fit model\n", + "model = load('/Data/TB_Diagnostics/final_model_VetOnly.model')" + ] + }, + { + "cell_type": "markdown", + "id": "4af4ec63-571c-4355-99e5-d45468aeedd8", + "metadata": {}, + "source": [ + "## Pre-/post-cull models" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "384eec7c-60bd-42d7-89ff-5e34690d6a0f", + "metadata": {}, + "outputs": [], + "source": [ + "# get best parameter set from full model\n", + "#gbt_pre = deepcopy(model.best_estimator_)\n", + "#gbt_post = deepcopy(model.best_estimator_)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c1e2cb37-e4cb-4ba4-bf8b-aa42fcad00af", + "metadata": {}, + "outputs": [], + "source": [ + "# train subset models\n", + "#model_pre = gbt_pre.fit(X_train_pre,y_train_pre)\n", + "#model_post = gbt_post.fit(X_train_post,y_train_post)" + ] + }, + { + "cell_type": "markdown", + "id": "18f01840-d148-412d-b355-0a5295c84212", + "metadata": { + "tags": [] + }, + "source": [ + "# Evaluate performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9dcc7f3-e102-4dfb-b7af-343568f8121b", + "metadata": {}, + "outputs": [], + "source": [ + "## Model score on testing set: (score is metric set at training time)\n", + "model.score(X_test,y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "90128953-7914-41f5-96be-78dc241971fa", + "metadata": {}, + "outputs": [], + "source": [ + "## Get test predictions for more detailed evaluation:\n", + "y_test_result = model.predict(X_test)\n", + "y_score = model.decision_function(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a63a9b46-26b7-4510-9b82-096e51c2216c", + "metadata": {}, + "outputs": [], + "source": [ + "## Sensitivity\n", + "Se = sensitivity(y_test_result,y_test)\n", + "Se" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edd92f58-ebc8-4364-9436-d56d8f1e0365", + "metadata": {}, + "outputs": [], + "source": [ + "## Specificity\n", + "Sp = specificity(y_test_result,y_test)\n", + "Sp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2ac84b5d-0ac3-4eea-9856-1e91beef126c", + "metadata": {}, + "outputs": [], + "source": [ + "## Accuracy\n", + "(y_test_result==y_test).sum() / len(y_test)" + ] + }, + { + "cell_type": "markdown", + "id": "7de21db1-c670-47b0-9115-fbc2f7a99a7f", + "metadata": {}, + "source": [ + "### Pre-/post-cull performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5338e23e-55c7-44db-ba6b-440573297474", + "metadata": {}, + "outputs": [], + "source": [ + "#model_pre.score(X_test_pre,y_test_pre)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b6a4e87-fa27-4121-8441-99a577186000", + "metadata": {}, + "outputs": [], + "source": [ + "#model_post.score(X_test_post,y_test_post)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ac67ac6-5234-4c07-b9f1-b645cc9fb146", + "metadata": {}, + "outputs": [], + "source": [ + "## Get test predictions for more detailed evaluation:\n", + "#y_test_result_pre = model_pre.predict(X_test_pre)\n", + "#y_score_pre = model_pre.decision_function(X_test_pre)\n", + "#y_test_result_post = model_post.predict(X_test_post)\n", + "#y_score_post = model_post.decision_function(X_test_post)" + ] + }, + { + "cell_type": "markdown", + "id": "a788b3e4-cbf7-4fac-99ce-68273d9aabed", + "metadata": {}, + "source": [ + "---\n", + "# ROC Curves" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41ec298c-991d-459c-a6c7-668fe091f635", + "metadata": {}, + "outputs": [], + "source": [ + "fpr, tpr, _ = roc_curve(y_test,y_score)\n", + "roc_auc = auc(fpr,tpr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95328ec9-6450-4246-b7f1-178efcffb92a", + "metadata": {}, + "outputs": [], + "source": [ + "roc_auc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a3ee1fc-8abb-4b7d-8df9-d0fa73109374", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "#function ot plot roc curve\n", + "def plot_roc(fpr,tpr,roc_auc):\n", + " plt.figure()\n", + " lw = 2\n", + " plt.plot(\n", + " fpr,\n", + " tpr,\n", + " lw=lw,\n", + " label=\"Model (AUC = %0.2f)\" % roc_auc,\n", + " )\n", + " plt.plot(1-Sp_sicct,Se_sicct,'+', label=\"SICCT only\", ms='15')\n", + " plt.plot([0, 1], [0, 1], lw=lw, linestyle=\"--\", label='Random')\n", + " plt.xlim([0.0, 1.0])\n", + " plt.ylim([0.0, 1.0])\n", + " plt.xlabel(\"(1 - Specificity)\")\n", + " plt.ylabel(\"Sensitivity\")\n", + " plt.title(\"Receiver operating characteristic\")\n", + " plt.legend(loc=\"lower right\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7f5f3bb-e7a3-4987-a381-6824560d836a", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({'font.size': 16})\n", + "plot_roc(fpr,tpr,roc_auc)\n", + "#plt.savefig('../Paper/figs/roc.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "4b3059d0-a776-46dc-9c22-3b2caf9d2d49", + "metadata": {}, + "source": [ + "### Pre-post-cull ROC" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bb5349e3-504e-47c8-848c-c9c4ea23edb5", + "metadata": {}, + "outputs": [], + "source": [ + "#fpr_pre, tpr_pre, _ = roc_curve(y_test_pre,y_score_pre)\n", + "#roc_auc_pre = auc(fpr_pre,tpr_pre)\n", + "#fpr_post, tpr_post, _ = roc_curve(y_test_post,y_score_post)\n", + "#roc_auc_post = auc(fpr_post,tpr_post)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d775404b-e526-4f81-bc75-8a3577870040", + "metadata": {}, + "outputs": [], + "source": [ + "#plot_roc(fpr_pre,tpr_pre,roc_auc_pre)\n", + "#plot_roc(fpr_post,tpr_post,roc_auc_post)" + ] + }, + { + "cell_type": "markdown", + "id": "3a0f6f31-b460-4b73-89b0-90edb0302818", + "metadata": {}, + "source": [ + "---\n", + "# Feature importance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c06b2bcc-5da2-46b3-8a91-9e130fd17a49", + "metadata": {}, + "outputs": [], + "source": [ + "## Calcuate permutation importance\n", + "importance = permutation_importance(model,X_test,y_test, n_repeats=20, n_jobs=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "80121727-2789-476b-9456-0584686afbfe", + "metadata": {}, + "outputs": [], + "source": [ + "for i in importance.importances_mean.argsort()[::-1]:\n", + " if abs(importance.importances_mean[i]) - 2 * importance.importances_std[i] > 0:\n", + " print('*',f\"{importance.importances_mean[i]:.5f}\", f\" +/- {importance.importances_std[i]:.5f}\", data_X.columns[i])\n", + " else:\n", + " print(' ',f\"{importance.importances_mean[i]:.5f}\", f\" +/- {importance.importances_std[i]:.5f}\", data_X.columns[i])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dacd01f1-a312-4f1c-a33a-3433e54e4e4b", + "metadata": {}, + "outputs": [], + "source": [ + "#transform into table\n", + "importance_table = pd.DataFrame(importance.importances.T, columns=data_X.columns)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d46e9901-1d69-4041-bfa2-0e0710e11ceb", + "metadata": {}, + "outputs": [], + "source": [ + "# FIX: drop 'species' (it is nonsense...)\n", + "#importance_table = importance_table.drop(columns=['species'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00152d7d-c5c6-4213-8563-e9519f66a059", + "metadata": {}, + "outputs": [], + "source": [ + "mean_order = list(importance_table.mean().sort_values(ascending=False).index)\n", + "mean_order_nozero = list(importance_table.mean()[importance_table.mean()>0].sort_values(ascending=False).index)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1606583-aef9-455a-9b5e-2391ed666894", + "metadata": {}, + "outputs": [], + "source": [ + "#Define feature lables:\n", + "feature_labels = {'resultOfTest':'Herd-level SICCT result',\n", + " 'locationY':'Holding location Easting',\n", + " 'locationX':'Holding location Northing',\n", + " 'daysSinceBreakdown':'Days since herd breakdown *',\n", + " 'dateOfTest':'Date of herd SICCT testing',\n", + " 'animalsTested':'Number of animals tested',\n", + " 'daysSincePreviousTest':'Time since previous SICCT test in herd *',\n", + " 'severe':'Was the severe interpretation applied?',\n", + " 'previousTestResult':'Result of last previous SICCT test in herd *',\n", + " 'previousTestResult2':'Result of 2nd last previous SICCT test in herd *',\n", + " 'gammaTestCount':'Number of historical GammaIFN test events in herd',\n", + " 'rand':'Set of uniformly distributed random numbers (CONTROL)',\n", + " 'inflow1':'Animals moved into herd, 1 year',\n", + " 'inflow2':'Animals moved into herd, 2 years',\n", + " 'inflow4':'Animals moved into herd, 4 years',\n", + " 'inflow90':'Animals moved into herd, 90 days',\n", + " 'outflow1':'Animals moved out herd, 1 year',\n", + " 'outflow2':'Animals moved out herd, 2 years',\n", + " 'outflow4':'Animals moved out herd, 4 years',\n", + " 'outflow90':'Animals moved out herd, 90 days',\n", + " 'inflowBD1':'Animals moved into herd, 1 year, from recent breakdown herds',\n", + " 'inflowBD2':'Animals moved into herd, 2 years, from recent breakdown herds',\n", + " 'inflowBD4':'Animals moved into herd, 4 years, from recent breakdown herds',\n", + " 'inflowBD90':'Animals moved into herd, 90 days, from recent breakdown herds',\n", + " 'outflowBD1':'Animals moved out herd, 1 year, from recent breakdown herds',\n", + " 'outflowBD2':'Animals moved out herd, 2 years, from recent breakdown herds',\n", + " 'outflowBD4':'Animals moved out herd, 4 years, from recent breakdown herds',\n", + " 'outflowBD90':'Animals moved out herd, 90 days, from recent breakdown herds',\n", + " 'vetPractice':'Veterinary practice conducting the test **',\n", + " 'batchBovine':'Tuberculin batch (bovine) **',\n", + " 'batchAvian':'Tuberculin batch (avian) **',\n", + " 'testType':'Type of testing event',\n", + " 'herdSize':'Size of herd at time of test',\n", + " 'herdType':'Herd type (dairy, beef, etc.)',\n", + " 'monthOfTest':'Month in which test taken',\n", + " 'defraRiskScore':'APHA risk score for herd',\n", + " 'meanBadgerAbundance':'Mean badger abundance'}\n", + "def feature_label(x):\n", + " try:\n", + " return feature_labels[x]\n", + " except KeyError:\n", + " return x " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "386786c4-ccbc-4695-988c-5088d667af5d", + "metadata": {}, + "outputs": [], + "source": [ + "# apply labels to mean ordered set\n", + "mean_order_labels = list(map(lambda x:feature_label(x), mean_order))\n", + "mean_order_nozero_labels = list(map(lambda x:feature_label(x), mean_order_nozero))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1c6439f-d5d9-46c5-ba1e-cfe09a0d6585", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({'font.size': 28})\n", + "fig, ax = plt.subplots(figsize=(16,20))\n", + "seaborn.barplot(importance_table[mean_order], orient='h', errorbar='ci', ax=ax)\n", + "ax.set_yticklabels(mean_order_labels)\n", + "plt.title('Relative importance of model features (with Vet data only)')\n", + "#plt.xscale('log')\n", + "plt.savefig('../Paper/figs/importance-vet-only.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fab18a82-8c38-44c3-a96a-03edcc21741e", + "metadata": {}, + "outputs": [], + "source": [ + "importance_table.mean()['vetPractice']*100" + ] + }, + { + "cell_type": "markdown", + "id": "292a0077-5639-4185-8132-a1623a0c3f5e", + "metadata": {}, + "source": [ + "### How much missing data?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9e029b7-8279-4fda-a4ac-a9bc1b28c08f", + "metadata": {}, + "outputs": [], + "source": [ + "# proportion of non-missing data in each feature:\n", + "non_miss = data_X.isna().sum() / len(data_X)\n", + "fig, ax = plt.subplots(figsize=(16,20))\n", + "seaborn.barplot(non_miss[mean_order], orient='h', ax=ax)\n", + "plt.xlim(0.0,1.0)\n", + "plt.title(\"Proportion of missing data\")\n", + "ax.bar_label(ax.containers[0])\n", + ";" + ] + }, + { + "cell_type": "markdown", + "id": "a3126928-d594-4bdc-9126-d5f98eb71ff1", + "metadata": {}, + "source": [ + "### Pre-/post-cull feature importance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9378a06c-8588-493a-9831-c057eaee97f0", + "metadata": {}, + "outputs": [], + "source": [ + "## Calcuate permutation importance\n", + "#importance_pre = permutation_importance(model_pre,X_test_pre,y_test_pre, n_repeats=10, n_jobs=-1)\n", + "#importance_post = permutation_importance(model_post,X_test_post,y_test_post, n_repeats=10, n_jobs=-1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37cd018f-2627-4eb3-8448-ead0c8836a2e", + "metadata": {}, + "outputs": [], + "source": [ + "# Pre-cull plot\n", + "#importance_table_pre = pd.DataFrame(importance_pre.importances.T, columns=data_X.columns)\n", + "#mean_order_pre = list(importance_table_pre.mean().sort_values(ascending=False).index)\n", + "#seaborn.barplot(importance_table_pre[mean_order_pre], orient='h')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3436526b-f1d0-40a9-bfa7-934e5e3c2d1f", + "metadata": {}, + "outputs": [], + "source": [ + "# Post-cull plot\n", + "#importance_table_post = pd.DataFrame(importance_post.importances.T, columns=data_X.columns)\n", + "#mean_order_post = list(importance_table_post.mean().sort_values(ascending=False).index)\n", + "#seaborn.barplot(importance_table_post[mean_order_post], orient='h')" + ] + }, + { + "cell_type": "markdown", + "id": "21d4c471-2f52-40e3-8922-7c123713c442", + "metadata": {}, + "source": [ + "---\n", + "# Decision threshold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e1f14688-94e3-4290-b8d4-01bc49b52619", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# function to apply decision threshold\n", + "def predict_with_threshold(X, model, decision_threshold):\n", + " return model.predict_proba(X)[:,1]>=decision_threshold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "55c1c89c-47ec-4ca1-a55c-10757085f87c", + "metadata": {}, + "outputs": [], + "source": [ + "# try different thresholds\n", + "thresholds = np.linspace(0.0,1.0,101)\n", + "sens = np.zeros(len(thresholds)) #sensitivity at threshold\n", + "spec = np.zeros(len(thresholds)) #specificity at threshold\n", + "for x in range(len(thresholds)):\n", + " y_th = predict_with_threshold(X_test,model,thresholds[x])\n", + " sens[x] = sensitivity(y_th,y_test)\n", + " spec[x] = specificity(y_th,y_test)\n", + "\n", + "best_sens = max(sens[spec >= Sp_sicct]) #sensitivity s.t. specificity >= SICCT\n", + "best_thresh = min(thresholds[spec >= Sp_sicct]) #threshold with max sensitivity s.t. specificity >= SICCT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0391e6cb-5d30-42da-864a-64a39ab56fe7", + "metadata": {}, + "outputs": [], + "source": [ + "plt.rcParams.update({'font.size': 16})\n", + "# function to plot thresholds\n", + "plt.plot(thresholds,sens,label='Model HSe')\n", + "plt.plot(thresholds,spec,label='Model HSp')\n", + "plt.xlim(0.2,0.8)\n", + "plt.ylim(0.5,1.0)\n", + "best_sens_label = 'Chosen HSe = '+str(round(best_sens*100,1))+'%'\n", + "sicct_sens_label = 'SICCT HSe = '+str(round(Se_sicct*100,1))+'%'\n", + "sicct_spec_label = 'SICCT HSp = '+str(round(Sp_sicct*100,1))+'%'\n", + "best_thresh_label = 'Chosen threshold = '+str(round(best_thresh,3))\n", + "plt.axvline(best_thresh,c='k',ls='-.',label=best_thresh_label)\n", + "plt.axhline(best_sens,c='k',ls='--',label=best_sens_label)\n", + "plt.axhline(Se_sicct,c='tab:blue',ls=':',label=sicct_sens_label)\n", + "plt.axhline(Sp_sicct,c='tab:orange',ls=':',label=sicct_spec_label)\n", + "plt.xlabel('Decision Threshold')\n", + "plt.legend(bbox_to_anchor=(1, 0.5))\n", + "#plt.savefig('../Paper/figs/decision_threshold.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6f5e010-d528-4d1e-89b5-4f0db5ddaca3", + "metadata": {}, + "outputs": [], + "source": [ + "# Percentage increase in sensitivity over SICCT alone\n", + "(best_sens - Se_sicct)/Se_sicct*100" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8207d271-4edc-400d-875e-9c69d604b6b6", + "metadata": {}, + "outputs": [], + "source": [ + "#Predictions from model at threshold\n", + "y_test_predicted = predict_with_threshold(X_test, model, best_thresh)" + ] + }, + { + "cell_type": "markdown", + "id": "034b0bb2-d40e-49c1-a05f-9b0ee7f03533", + "metadata": {}, + "source": [ + "### Test on 2020 only data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29e8fcbc-416b-4034-b09d-7ab1b4604fcb", + "metadata": {}, + "outputs": [], + "source": [ + "# Get data for 2020 only\n", + "mask_2020 = data.dateOfTest.apply(lambda x:x.year)==2020\n", + "X_2020 = data_X[mask_2020].to_numpy()\n", + "y_2020 = data_y[mask_2020]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a63f7f7b-9bfb-4059-a937-3a687301389c", + "metadata": {}, + "outputs": [], + "source": [ + "#Predictions from model at threshold for 2020 only\n", + "y_2020_predicted = predict_with_threshold(X_2020, model, best_thresh)" + ] + }, + { + "cell_type": "markdown", + "id": "7e1a8a42-6c35-45a3-9e4c-de7aec974af0", + "metadata": {}, + "source": [ + "### Test with HSp maximised" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "502c5dc0-8ccd-422b-a662-55e50d280d3d", + "metadata": {}, + "outputs": [], + "source": [ + "# What if we maximise specificty instead?\n", + "\n", + "sens_thresh = max(thresholds[sens>= Se_sicct]) # threshold with max specificity s.t. sensitivity >= SICCT\n", + "best_spec = max(spec[sens >= Se_sicct]) # specificty s.t. sensitivity >= SICCT\n", + "\n", + "sens_thresh , best_spec" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "76a014bd-6b3e-4c97-bab5-be3071c37204", + "metadata": {}, + "outputs": [], + "source": [ + "#Predictions from model at HSp threshold\n", + "y_test_predicted_hsp = predict_with_threshold(X_test, model, sens_thresh)" + ] + }, + { + "cell_type": "markdown", + "id": "38850be3-e638-4ef0-abf8-6d75e70751ac", + "metadata": {}, + "source": [ + "---\n", + "# Plots" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "684f1b2d-197b-4920-8624-881e42490323", + "metadata": {}, + "outputs": [], + "source": [ + "#Projections:\n", + "bng = 'epsg:27700' # British National Grid\n", + "wgs84 = 'epsg:4326' # Lat.Long." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20681a9a-e764-4018-9bd8-0ea3c503143f", + "metadata": {}, + "outputs": [], + "source": [ + "#UK base map\n", + "uk_shp = gp.read_file('/Data/Shapefiles/bdline_essh_gb/Data/Supplementary_Country/country_region.shp').to_crs(wgs84)\n", + "#uk_shp.plot(color='white', edgecolor='black')" + ] + }, + { + "cell_type": "markdown", + "id": "36f72d9c-b845-4295-bed9-24da037d64d5", + "metadata": {}, + "source": [ + "## Plot residuals" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06f28e8a-027b-4887-a663-d6af3a56b450", + "metadata": {}, + "outputs": [], + "source": [ + "residual = (y_test != y_test_predicted)\n", + "\n", + "test_locs = pd.DataFrame({'locationX':X_test[:,5],'locationY':X_test[:,6],'date':X_test[:,0]})\n", + "test_locs.date = test_locs.date.astype('datetime64[ns]')\n", + "\n", + "test_locs['residual'] = residual\n", + "\n", + "test_geo = gp.GeoDataFrame(test_locs,geometry=gp.points_from_xy(test_locs.locationX,test_locs.locationY,crs=bng))\n", + "test_geo = test_geo.to_crs(wgs84)\n", + "\n", + "errors = test_geo[test_geo.residual]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8a4105b-c64c-4f15-a480-8e7b9e5b8f41", + "metadata": {}, + "outputs": [], + "source": [ + "# residuals for 2020 only\n", + "residual_2020 = (y_2020 != y_2020_predicted)\n", + "\n", + "locs_2020 = pd.DataFrame({'locationX':X_2020[:,5],'locationY':X_2020[:,6],'date':X_2020[:,0]})\n", + "locs_2020.date = locs_2020.date.astype('datetime64[ns]')\n", + "\n", + "locs_2020['residual'] = residual_2020\n", + "\n", + "geo_2020 = gp.GeoDataFrame(locs_2020,geometry=gp.points_from_xy(locs_2020.locationX,locs_2020.locationY,crs=bng))\n", + "geo_2020 = geo_2020.to_crs(wgs84)\n", + "\n", + "errors_2020 = geo_2020[geo_2020.residual]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a69a1d1-93c9-46a3-a1cc-d61d2dcb638c", + "metadata": {}, + "outputs": [], + "source": [ + "# Location of observations vs. residuals in Test set\n", + "ax = uk_shp.plot(alpha=0.2, figsize=(10,20))\n", + "#test_geo.plot('residual', markersize=1.0, ax=ax)\n", + "test_geo.plot(markersize=1.0, ax=ax)\n", + "errors.plot(markersize=1.0, color='red', ax=ax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c6fed6e-e265-444c-a8ff-c2221b0957dd", + "metadata": {}, + "outputs": [], + "source": [ + "# Location of observations vs. residuals in 2020 set\n", + "ax = uk_shp.plot(alpha=0.2, figsize=(10,20))\n", + "geo_2020.plot(markersize=1.0, ax=ax)\n", + "errors_2020.plot(markersize=1.0, color='red', ax=ax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62bf025c-f5b9-4bc5-8d06-9df48399078a", + "metadata": {}, + "outputs": [], + "source": [ + "# Residuals in time\n", + "test_times = pd.DataFrame({'date':X_test[:,0].copy().astype('datetime64[ns]')})\n", + "test_times['residual'] = residual\n", + "error_times = test_times[test_times.residual]\n", + "e = error_times.groupby(error_times[\"date\"].dt.year).count().date\n", + "t = test_times.groupby(test_times[\"date\"].dt.year).count().date\n", + "(e/t).plot.bar()\n", + "plt.title(\"Proportion of model misclassifications by year\")\n", + "plt.xlabel('Year')\n", + "#plt.savefig('../Paper/figs/temporal.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "440f57da-101a-4d2d-8120-3d7f31841401", + "metadata": {}, + "source": [ + "## Plot newly discovered positives" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "daf3f8f6-dd51-4767-99d9-1b17e0ec5375", + "metadata": {}, + "outputs": [], + "source": [ + "new_detected = (~X_test[:,1].astype(bool) & y_test_predicted)\n", + "test_geo['new_detected'] = new_detected" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0044411-c380-4245-9fca-5710095bb9a5", + "metadata": {}, + "outputs": [], + "source": [ + "# new detected in 2020\n", + "new_detected_2020 = (~X_2020[:,1].astype(bool) & y_2020_predicted)\n", + "geo_2020['new_detected'] = new_detected_2020" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a811c36-47d2-4332-8892-acda30a5e0f5", + "metadata": {}, + "outputs": [], + "source": [ + "# Location of observations vs. newly detected herds in Test set\n", + "ax = uk_shp.plot(alpha=0.2, figsize=(10,20))\n", + "test_geo.plot(markersize=1.0, ax=ax)\n", + "test_geo[test_geo.new_detected].plot(markersize=1.0, color='gold', ax=ax)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b83e5b01-9076-4bca-b7b6-d148da851f76", + "metadata": {}, + "outputs": [], + "source": [ + "ax = uk_shp.to_crs(bng).plot(alpha=0.2, figsize=(10,20))\n", + "seaborn.kdeplot(ax=ax, x=test_geo[test_geo.new_detected].locationX, y=test_geo[test_geo.new_detected].locationY, fill=True, color='gold')" + ] + }, + { + "cell_type": "markdown", + "id": "3cb5f9c0-6520-4322-9ffb-c574896fcd03", + "metadata": {}, + "source": [ + "## Plot Residuals and Newly Detected density by area" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d20700e-71e5-4093-8739-b9c41e119dd7", + "metadata": {}, + "outputs": [], + "source": [ + "## Create grid for normalising misclassifcation\n", + "# total area for the grid\n", + "xmin, ymin, xmax, ymax = uk_shp.to_crs(bng).total_bounds\n", + "#size of cell\n", + "cell_size = 10000 #10km x 10km squares\n", + "# create the cells in a loop\n", + "grid_cells = []\n", + "for x0 in np.arange(xmin, xmax+cell_size, cell_size ):\n", + " for y0 in np.arange(ymin, ymax+cell_size, cell_size):\n", + " # bounds\n", + " x1 = x0-cell_size\n", + " y1 = y0+cell_size\n", + " grid_cells.append(shapely.geometry.box(x0, y0, x1, y1))\n", + "grid = gp.GeoDataFrame(grid_cells, columns=['geometry'], crs=bng)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd76d332-b6f3-4546-bb96-ecad028ed72b", + "metadata": {}, + "outputs": [], + "source": [ + "# For normalisation of map cells by number of tests in cell:\n", + "# number of tests in test set\n", + "grid_n_tests = gp.sjoin(test_geo.to_crs(bng), grid, how='left', predicate='within')\n", + "grid_n_tests['n_tests'] = 1\n", + "grid_n_tests_d = grid_n_tests.dissolve(by='index_right', aggfunc='count')\n", + "grid.loc[grid_n_tests_d.index, 'n_tests'] = grid_n_tests_d.n_tests.values\n", + "\n", + "#number of tests in 2002\n", + "grid_n_tests_2020 = gp.sjoin(geo_2020.to_crs(bng), grid, how='left', predicate='within')\n", + "grid_n_tests_2020['n_tests20'] = 1\n", + "grid_n_tests_2020_d = grid_n_tests_2020.dissolve(by='index_right', aggfunc='count')\n", + "grid.loc[grid_n_tests_2020_d.index, 'n_tests20'] = grid_n_tests_2020_d.n_tests.values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d57d06ba-c729-4894-b625-00068334487e", + "metadata": {}, + "outputs": [], + "source": [ + "grid.plot(column='n_tests')\n", + "grid.plot(column='n_tests20')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "265fd8d7-e75f-4bc6-8b9f-386010dbba60", + "metadata": {}, + "outputs": [], + "source": [ + "## Spatial join residuals with grid\n", + "## and plot to produce heatmap\n", + "errors_grid = gp.sjoin(errors.to_crs(bng), grid, how='left', predicate='within')\n", + "\n", + "# Compute residuals per grid cell\n", + "errors_grid['n_resid'] = 1\n", + "#errors_grid_d = errors_grid.dissolve(by=\"index_right\", aggfunc=\"count\")\n", + "errors_grid_n_resid = errors_grid[['index_right','n_resid']].groupby(by=\"index_right\").count()\n", + "\n", + "# Add to grid\n", + "grid.loc[errors_grid_n_resid.index, 'n_resid'] = errors_grid_n_resid\n", + "\n", + "# add normalised to grid\n", + "grid.loc[errors_grid_n_resid.index, 'norm_resid'] = grid.loc[errors_grid_n_resid.index, 'n_resid'] / grid.loc[errors_grid_n_resid.index, 'n_tests']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d9cc036b-1972-4b35-8a42-c7312a0b65ef", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot errors grid\n", + "ax = grid.plot(column='n_resid', cmap='YlOrRd', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Misclassified tests by area (in test set)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_misclassified.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cfa1fa72-f62b-446d-bd70-daf7ba15f364", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot errors grid (normalised)\n", + "ax = grid.plot(column='norm_resid', cmap='YlOrRd', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Proportion of tests misclassified by area (in test set)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_misclassified.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06873036-d6d9-4b70-89cb-f404a4374f62", + "metadata": {}, + "outputs": [], + "source": [ + "## compute residuals grid for 2020 only\n", + "errors20_grid = gp.sjoin(errors_2020.to_crs(bng), grid, how='left', predicate='within')\n", + "\n", + "# Compute residuals per grid cell\n", + "errors20_grid['n_resid20'] = 1\n", + "errors20_grid_n_resid = errors20_grid[['index_right','n_resid20']].groupby(by=\"index_right\").count()\n", + "\n", + "# Add to grid\n", + "grid.loc[errors20_grid_n_resid.index, 'n_resid20'] = errors20_grid_n_resid\n", + "\n", + "# add normalised to grid\n", + "grid.loc[errors20_grid_n_resid.index, 'norm_resid20'] = grid.loc[errors20_grid_n_resid.index, 'n_resid20'] / grid.loc[errors20_grid_n_resid.index, 'n_tests20']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "542b0514-66d3-4af5-a679-046954217c3b", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot errors grid for 2020 only\n", + "ax = grid.plot(column='n_resid20', cmap='YlOrRd', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Misclassified tests by area (in 2020)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_misclassified_2020.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e56dc477-b066-4b04-b9e7-4be58b6f871d", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot normalised errors grid for 2020 only\n", + "ax = grid.plot(column='norm_resid20', cmap='YlOrRd', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Proportion of tests misclassified by area (in 2020)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_misclassified_2020.pdf',bbox_inches='tight')\n", + "#plt.savefig('../Paper/figs/map_misclassified_2020.png',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0cc2daac-8175-408c-a9c3-1b814c0a539e", + "metadata": {}, + "outputs": [], + "source": [ + "## Spatial join newly detected with grid\n", + "## and plot to produce heatmap\n", + "new_detect_grid = gp.sjoin(test_geo[test_geo.new_detected].to_crs(bng), grid, how='left', predicate='within')\n", + "\n", + "# Compute new detects per grid cell -- aggregate with dissolve\n", + "new_detect_grid['n_new'] = 1\n", + "new_detect_grid_n_new = new_detect_grid[['index_right','n_new']].groupby(by=\"index_right\").count()\n", + "\n", + "# Add to grid\n", + "grid.loc[new_detect_grid_n_new.index, 'n_new'] = new_detect_grid_n_new\n", + "\n", + "# add normalised to grid\n", + "grid.loc[new_detect_grid_n_new.index, 'norm_new'] = grid.loc[new_detect_grid_n_new.index, 'n_new'] / grid.loc[new_detect_grid_n_new.index, 'n_tests']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e6221042-5b7d-436d-9e39-ad50a5d6e031", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot new grid\n", + "ax = grid.plot(column='n_new', cmap='cividis', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Early detected tests by area (in test set)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_newly_detected.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dcd7a9c2-c819-419e-93e9-404ab49472b3", + "metadata": {}, + "outputs": [], + "source": [ + "grid.n_new.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f4c3a5e4-937b-4abd-bf4e-a27c90ba3a46", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot new grid, normalised\n", + "ax = grid.plot(column='norm_new', cmap='cividis', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Proportion of tests early detected by area (in test set)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_newly_detected.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "567de8db-925f-41c3-83f2-7c29208b90e2", + "metadata": {}, + "outputs": [], + "source": [ + "## Spatial join newly detected with grid in 2020 only\n", + "new_detect20_grid = gp.sjoin(geo_2020[geo_2020.new_detected].to_crs(bng), grid, how='left', predicate='within')\n", + "\n", + "# Compute new detects per grid cell -- aggregate with dissolve\n", + "new_detect20_grid['n_new20'] = 1\n", + "new_detect20_grid_n_new = new_detect20_grid[['index_right','n_new20']].groupby(by=\"index_right\").count()\n", + "\n", + "# Add to grid\n", + "grid.loc[new_detect20_grid_n_new.index, 'n_new20'] = new_detect20_grid_n_new\n", + "\n", + "# add normalised to grid\n", + "grid.loc[new_detect20_grid_n_new.index, 'norm_new20'] = grid.loc[new_detect20_grid_n_new.index, 'n_new20'] / grid.loc[new_detect20_grid_n_new.index, 'n_tests20']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f584801-1387-4272-9971-bfc8f6ebc325", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot new grid for 2020 only\n", + "ax = grid.plot(column='n_new20', cmap='cividis', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Early detected tests by area (in 2020)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_newly_detected_2020.pdf',bbox_inches='tight')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8874e3ba-9215-4b57-bedf-21f98653bc7b", + "metadata": {}, + "outputs": [], + "source": [ + "grid.n_new20.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62c120f9-fa27-462b-91f7-ae873e01ae4f", + "metadata": {}, + "outputs": [], + "source": [ + "# Plot new grid for 2020 only, normalised\n", + "ax = grid.plot(column='norm_new20', cmap='cividis', legend=True, legend_kwds={'shrink': 0.3}, figsize=(10,20))\n", + "uk_shp.to_crs(bng).plot(ax=ax,alpha=0.1)\n", + "plt.title(\"Proportion of tests early detected by area (in 2020)\")\n", + "plt.axis('off')\n", + "#plt.savefig('../Paper/figs/map_newly_detected_2020.pdf',bbox_inches='tight')\n", + "#plt.savefig('../Paper/figs/map_newly_detected_2020.png',bbox_inches='tight')" + ] + }, + { + "cell_type": "markdown", + "id": "961a329b-0c9e-45c8-a1e1-6ea868e6c438", + "metadata": {}, + "source": [ + "## Partial dependence" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "00c6b6a9-e50f-431c-b19e-7783e8150a3e", + "metadata": {}, + "outputs": [], + "source": [ + "#### Takes ages and eats memory... ?!?!\n", + "#PartialDependenceDisplay.from_estimator(model, X_train, [4,5,(4,5)])" + ] + }, + { + "cell_type": "markdown", + "id": "19843de4-23f1-4b4b-988e-4e3cd5380308", + "metadata": {}, + "source": [ + "## Plot feature correlations" + ] + }, + { + "cell_type": "markdown", + "id": "1e3fc849-0d1c-4b34-8476-d8f9dad0603b", + "metadata": {}, + "source": [ + "---\n", + "# Analysis" + ] + }, + { + "cell_type": "markdown", + "id": "c2502ddd-f83b-4fd2-aa05-244d7f987292", + "metadata": {}, + "source": [ + "## What sort of herd is newly detected?\n", + "\n", + "* See map above for spatial distribution.\n", + "* What else?" + ] + }, + { + "cell_type": "markdown", + "id": "a4c5c4bb-b5b3-4482-b558-04c905c12b98", + "metadata": {}, + "source": [ + "## Confusion matrix" + ] + }, + { + "cell_type": "markdown", + "id": "cd3d9e83-4dbc-4dc7-9db9-8cc3e08416e4", + "metadata": {}, + "source": [ + "* all stats w.r.t. sicct alone and Stanski" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cabfe2e7-d1e4-4471-b84d-afde63ea268d", + "metadata": {}, + "outputs": [], + "source": [ + "np.array([['tp','tn'],['fp','fn']])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53fbb0e8-9d7d-42aa-84d9-51334f68a1da", + "metadata": {}, + "outputs": [], + "source": [ + "# Function ot calculate confusion matrix\n", + "# p = predicted class\n", + "# t = actual class\n", + "def confusion_matrix(p,t):\n", + " # True Positives\n", + " TP = (p&t).sum()\n", + " # True Negatives\n", + " TN = (~p&~t).sum()\n", + " # False Positives\n", + " FP = (p&~t).sum()\n", + " # False Negatives\n", + " FN = (~p&t).sum()\n", + " # return matrix (values and proportions)\n", + " total = len(p)\n", + " val_array = np.array([[TP,TN],[FP,FN]])\n", + " prop_array = np.around(np.array([[TP/total, TN/total], [FP/total, FN/total]]) * 100, 1)\n", + " return val_array , prop_array" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2387564b-eb96-4115-a389-5803d15f3ec0", + "metadata": {}, + "outputs": [], + "source": [ + "# confusion matrix for test set\n", + "cm_model = confusion_matrix(y_2020_predicted,y_2020)\n", + "cm_model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef6f3705-701a-4c7d-b3cf-c23271337e23", + "metadata": {}, + "outputs": [], + "source": [ + "sensitivity(y_2020_predicted,y_2020), specificity(y_2020_predicted,y_2020)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1fc27bf2-231b-4aca-814c-53b939e764cb", + "metadata": {}, + "outputs": [], + "source": [ + "# confusion matrix for SICCT\n", + "sicct_2020_predicted = X_2020[:,1]==1\n", + "cm_sicct = confusion_matrix(sicct_2020_predicted,y_2020)\n", + "cm_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b235d247-0498-49d8-a1d3-4d191ffb2c46", + "metadata": {}, + "outputs": [], + "source": [ + "sensitivity(sicct_2020_predicted,y_2020), specificity(sicct_2020_predicted,y_2020)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0454bc8a-bf69-4f60-b6e6-dfd4773f1100", + "metadata": {}, + "outputs": [], + "source": [ + "# reduction in FNs / FPs from SICCT to model\n", + "m_fn = cm_model[0][1][1]\n", + "m_fp = cm_model[0][1][0]\n", + "s_fn = cm_sicct[0][1][1]\n", + "s_fp = cm_sicct[0][1][0]\n", + "\n", + "print('FN reduction: ', (s_fn - m_fn)/s_fn, '\\nFP reduction:', (s_fp - m_fp)/s_fp)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f490636-f5aa-4262-8e28-c05a98348f10", + "metadata": {}, + "outputs": [], + "source": [ + "print('2020 HSe increase: ',(sensitivity(y_2020_predicted,y_2020) - sensitivity(sicct_2020_predicted,y_2020))/sensitivity(sicct_2020_predicted,y_2020))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6bd39d1-5099-4ad4-9fde-a54f3aec4ccd", + "metadata": {}, + "outputs": [], + "source": [ + "# How many FPs caught by HSp maximisation threshold?\n", + "y_2020_predicted_hse = predict_with_threshold(X_2020, model, sens_thresh)\n", + "# TNs where test was P, vs. all TNs (in 2020)\n", + "sum((X_2020[:,1] == 1) & (y_2020_predicted_hse == 0) & (y_2020 == 0)) , sum((y_2020_predicted_hse == 0) & (y_2020 == 0))" + ] + }, + { + "cell_type": "markdown", + "id": "a332ef62-d490-45d8-8540-79b142158479", + "metadata": {}, + "source": [ + "## Number of days to breakdown (distribution)\n", + "\n", + "* for newly detected / vs sicct detected\n", + "* other? " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2649fb1d-68a8-49f4-b15f-8ac4567b3cf8", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "5dfc972c-248b-41a0-a0c3-764c3869c9b5", + "metadata": {}, + "source": [ + "## Normalised spatial analysis\n", + "\n", + "* residuals normalised by no. of tests\n", + "* new detections ---\"---" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68413072-ebe9-47c7-8af8-a4075b41577d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "ad6e367d-9bb1-409a-a175-1efdded44533", + "metadata": {}, + "source": [ + "---\n", + "# Other TODOs:" + ] + }, + { + "cell_type": "markdown", + "id": "fb21b12c-8be2-4488-af6b-78e51e0a0bcd", + "metadata": {}, + "source": [ + "\n", + "* Permutation importance with multicolinear features\n", + "* Gold standard period? 90-day or other? Test?\n", + "* Time / area split models\n", + " - Fix fitting to existing model, or re-tune for each model?" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/bTB-Diagnostic_2020_v4_crossVal+tuning_NEW.ipynb b/bTB-Diagnostic_2020_v4_crossVal+tuning_NEW.ipynb new file mode 100644 index 0000000..582907e --- /dev/null +++ b/bTB-Diagnostic_2020_v4_crossVal+tuning_NEW.ipynb @@ -0,0 +1,674 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "c8454912-7ef8-48d5-bc40-ab4a2d5ed557", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import geopandas as gp\n", + "import geoplot\n", + "import seaborn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f12f0587-9d1e-4a2a-8aa7-57062f389cc4", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import time\n", + "import shapely\n", + "import rtree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b85c35f-6acf-4ed9-9ad5-8101ca434a8a", + "metadata": {}, + "outputs": [], + "source": [ + "import sklearn\n", + "from sklearn.ensemble import HistGradientBoostingClassifier as GBT\n", + "from sklearn.metrics import roc_curve, auc, roc_auc_score, make_scorer\n", + "from sklearn.inspection import permutation_importance\n", + "from sklearn.model_selection import RandomizedSearchCV, train_test_split, cross_val_score, GridSearchCV\n", + "from scipy.stats import randint,uniform,loguniform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d842e20-252c-4ddf-8602-dab12137bfe5", + "metadata": {}, + "outputs": [], + "source": [ + "from joblib import dump, load" + ] + }, + { + "cell_type": "markdown", + "id": "216ed604-493f-47c0-933e-69c7be17d13d", + "metadata": {}, + "source": [ + "# Load data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66c256bb-b5f4-4bb9-8a77-b70e76263187", + "metadata": {}, + "outputs": [], + "source": [ + "## Load data\n", + "data = pd.read_csv('/Data/TB_Diagnostics/inputVars.csv',parse_dates=['dateOfTest'],dtype=float)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "956b91a3-a22d-46b8-aa0c-868cc802dadc", + "metadata": {}, + "outputs": [], + "source": [ + "min(data.dateOfTest) , max(data.dateOfTest)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "777af021-1f6a-4270-846c-9d7a90603a4e", + "metadata": {}, + "outputs": [], + "source": [ + "# Get target feature (confirmed breakdowns) as binary class\n", + "data_y = data.confirmedBreakdown.to_numpy().astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19ea9381-9e2e-42c5-8a89-c6480c5b0fb5", + "metadata": {}, + "outputs": [], + "source": [ + "# Get observed features\n", + "data_X = data.drop(columns=['confirmedBreakdown'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19559dca-d5e5-4e73-907f-c0496474067d", + "metadata": {}, + "outputs": [], + "source": [ + "# Convert dates to float\n", + "data_X.dateOfTest = data_X.dateOfTest.astype(int).astype(float)\n", + "# Add Random features\n", + "data_X['rand'] = np.random.random_sample(len(data_X))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b09bee9-c8d9-41f2-9b00-dd2a06b28933", + "metadata": {}, + "outputs": [], + "source": [ + "# Detect categorical features (<= 3 categories and explicit named features)\n", + "named_cat_features = ['vetPractice','batchBovine','batchAvian']\n", + "cat_features = []\n", + "for c in data_X.columns:\n", + " catf = len(data_X[c].unique())<=3\n", + " if c in named_cat_features:\n", + " catf = True\n", + " cat_features.append(catf)\n", + "\n", + "# NB: this is fine for boolean features (inc. missing values)\n", + "# but needs a proper encoding for true categorical features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "247e72c9-11bc-47c4-af7e-4acd727e40f3", + "metadata": {}, + "outputs": [], + "source": [ + "# Convery all to float matrix\n", + "data_X = data_X.to_numpy()" + ] + }, + { + "cell_type": "markdown", + "id": "4b230243-181a-4b99-8718-904fd94f1409", + "metadata": {}, + "source": [ + "# Training and testing sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0332c21f-1377-45f6-8deb-6bb33117bbde", + "metadata": {}, + "outputs": [], + "source": [ + "# Hold a final test set (random)\n", + "X_train, X_test, y_train, y_test = train_test_split(data_X, data_y, test_size=0.20)" + ] + }, + { + "cell_type": "markdown", + "id": "34be244f-ecdc-405f-ae24-722d23be7fdd", + "metadata": {}, + "source": [ + "# Model scoring functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "291c4fc2-90ac-4e40-bc81-dbe5dde7f058", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: sensitivity(prediction,target)\n", + "# returns sensitivity of prediction vs. target\n", + "# Se = TP / (TP + FN)\n", + "def sensitivity(p,t):\n", + " TP = (p&t).sum()\n", + " FN = (~p&t).sum()\n", + " return TP / (TP + FN)\n", + "\n", + "## Function: specificity(prediction,target)\n", + "# returns specificity of prediction vs. target\n", + "# Sp = TN / (TN + FP)\n", + "def specificity(p,t):\n", + " TN = (~p&~t).sum()\n", + " FP = (p&~t).sum()\n", + " return TN / (TN + FP)" + ] + }, + { + "cell_type": "markdown", + "id": "06670e4a-081b-4b54-b4d7-b950a3899292", + "metadata": {}, + "source": [ + "### SICCT Test performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a4e9c32-3bdc-45a2-9ef7-3e35f89686e6", + "metadata": {}, + "outputs": [], + "source": [ + "sicct = X_test[:,1].astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b957aa88-d5c6-49a9-954e-f3830c88956d", + "metadata": {}, + "outputs": [], + "source": [ + "## Sensitivity\n", + "Se_sicct = sensitivity(sicct,y_test)\n", + "Se_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5afe6254-69ff-447e-916d-acfe449583c3", + "metadata": {}, + "outputs": [], + "source": [ + "## Specificity\n", + "Sp_sicct = specificity(sicct,y_test)\n", + "Sp_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "050ee423-09dc-49ab-b423-12ab4023ffec", + "metadata": {}, + "outputs": [], + "source": [ + "## Accuracy\n", + "(sicct==y_test).sum() / len(y_test)" + ] + }, + { + "cell_type": "markdown", + "id": "d8e8dc4c-0587-4922-b53e-068d866aead4", + "metadata": {}, + "source": [ + "### Custom model scoring function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b680b7d-78f9-4d5c-900b-b9c4426c0332", + "metadata": {}, + "outputs": [], + "source": [ + "# Set specificity threshold to level for SICCT-only prediction\n", + "specificity_threshold = Sp_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ffd35bd-fbee-47e5-812d-db40934a6032", + "metadata": {}, + "outputs": [], + "source": [ + "# define a custom score function:\n", + "# score keeps specificity above SICCT, maximises sensitivity\n", + "def sensspec_score(t,p): #input: true (t) and predicted (p) classes\n", + " if specificity(p,t) < specificity_threshold:\n", + " return 0\n", + " else:\n", + " return sensitivity(p,t)\n", + "\n", + "custom_score = make_scorer(sensspec_score)" + ] + }, + { + "cell_type": "markdown", + "id": "ed335a56-8bdc-4f75-be84-464f18452c2c", + "metadata": { + "tags": [] + }, + "source": [ + "# Hyperparameter tuning / cross-validation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d33da3a-cdf3-4552-8518-c9228bcaf606", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Create model\n", + "gbt = GBT(categorical_features=cat_features, class_weight='balanced')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d5aaff0-11de-4a39-a9dc-3f2fe925c646", + "metadata": {}, + "outputs": [], + "source": [ + "# define parameter spaces\n", + "param_grid = {'learning_rate':[1.0,0.1,0.01,0.001,0.0005,0.0001,0.00005,0.00001],\n", + " 'max_leaf_nodes':[2,5,10,20,30,50,100,500,1000]}\n", + "\n", + "#param_dists = {'learning_rate':loguniform(0.00001,1.0),\n", + "# 'max_leaf_nodes':randint(2,10000)}\n", + "\n", + "param_dists = {'learning_rate':loguniform(0.01,1.0),\n", + " 'max_leaf_nodes':randint(2,2000)}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "acbf1509-6d32-446b-849a-aef78fbe5b72", + "metadata": {}, + "outputs": [], + "source": [ + "# perform grid search\n", + "#model = GridSearchCV(gbt, param_grid, n_jobs=-1, cv=5)#, scoring='recall') #5-fold cross-validation #n_jobs -1 for all procs\n", + "#start_time = time.time()\n", + "#model.fit(X_train, y_train)\n", + "#print(\"%0.2f seconds\" % (time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "002e9e98-19ad-4b73-911e-aee0aa8b002f", + "metadata": {}, + "outputs": [], + "source": [ + "# perform random search\n", + "model = RandomizedSearchCV(gbt, param_dists, n_jobs=-1, cv=10, verbose=1, n_iter=100, scoring='roc_auc')\n", + "start_time = time.time()\n", + "model.fit(X_train, y_train)\n", + "print(\"%0.2f seconds\" % (time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c82261ac-4794-4946-9a3c-9250ca2abd3c", + "metadata": {}, + "outputs": [], + "source": [ + "model.best_params_" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43f701c4-5c64-404d-becb-589bc570b120", + "metadata": {}, + "outputs": [], + "source": [ + "model.score(X_test,y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "020db8a3-b880-44b0-b216-9d729c24e81d", + "metadata": {}, + "outputs": [], + "source": [ + "tuning_results = pd.DataFrame(model.cv_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10ae9194-678c-4e45-8cc3-e5bb0890b47b", + "metadata": {}, + "outputs": [], + "source": [ + "seaborn.relplot(tuning_results,x='param_learning_rate',y='mean_test_score',hue='param_max_leaf_nodes')\n", + "plt.xscale('log')\n", + "plt.grid()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe93d0e5-084e-47bb-9d34-413613301ad4", + "metadata": {}, + "outputs": [], + "source": [ + "seaborn.relplot(tuning_results,x='param_max_leaf_nodes',y='mean_test_score',hue='param_learning_rate')\n", + "#plt.xscale('log')\n", + "plt.grid()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2268492-5484-4e06-a3a6-38caa03eb236", + "metadata": {}, + "outputs": [], + "source": [ + "seaborn.relplot(tuning_results,x='param_learning_rate',y='param_max_leaf_nodes', hue='mean_fit_time',size='mean_test_score',sizes=(1,200))\n", + "#plt.axhline(model.best_params_['max_leaf_nodes'],ls='--', label='Best fit')\n", + "#plt.axvline(model.best_params_['learning_rate'],ls='--')\n", + "plt.scatter(model.best_params_['learning_rate'],model.best_params_['max_leaf_nodes'], marker='+', c='b', s=300)\n", + "plt.xscale('log')" + ] + }, + { + "cell_type": "markdown", + "id": "07c6a8d8-02ff-442c-9392-b1bc522b923f", + "metadata": { + "tags": [] + }, + "source": [ + "# Evaluate performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53f99b2b-20ea-42b7-b052-8879a4e193b0", + "metadata": {}, + "outputs": [], + "source": [ + "## Model score on testing set: (score is metric set at training time)\n", + "model.score(X_test,y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "534591d2-c5b7-4a15-b2a1-1b66b3b8c088", + "metadata": {}, + "outputs": [], + "source": [ + "## Get test predictions for more detailed evaluation:\n", + "y_test_result = model.predict(X_test)\n", + "y_score = model.decision_function(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7cb3cb9d-4a82-4092-97b1-111bcc15a410", + "metadata": {}, + "outputs": [], + "source": [ + "## Sensitivity\n", + "Se = sensitivity(y_test_result,y_test)\n", + "Se" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2c803c1-0b0a-4d14-8661-ac1e0346d836", + "metadata": {}, + "outputs": [], + "source": [ + "## Specificity\n", + "Sp = specificity(y_test_result,y_test)\n", + "Sp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4787cebd-043c-4a6e-8c43-4a30d8e9dfd0", + "metadata": {}, + "outputs": [], + "source": [ + "## Accuracy\n", + "(y_test_result==y_test).sum() / len(y_test)" + ] + }, + { + "cell_type": "markdown", + "id": "a178d34b-4fd9-46ef-bba1-3ea392f64a6f", + "metadata": {}, + "source": [ + "---\n", + "# ROC Curves" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "253e1f2f-dd84-4635-8dc2-1c47c24ff60f", + "metadata": {}, + "outputs": [], + "source": [ + "fpr, tpr, _ = roc_curve(y_test,y_score)\n", + "roc_auc = auc(fpr,tpr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41a1dafb-75b4-42f6-bfac-1ce9fb521cfe", + "metadata": {}, + "outputs": [], + "source": [ + "roc_auc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1ac1678-e96c-4203-b014-fe9532219038", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "plt.figure()\n", + "lw = 2\n", + "plt.plot(\n", + " fpr,\n", + " tpr,\n", + " lw=lw,\n", + " label=\"ROC curve, model (area = %0.2f)\" % roc_auc,\n", + ")\n", + "plt.plot(1-Sp_sicct,Se_sicct,'+', label=\"SICCT only\", ms='15')\n", + "plt.plot([0, 1], [0, 1], lw=lw, linestyle=\"--\")\n", + "plt.xlim([0.0, 1.0])\n", + "plt.ylim([0.0, 1.0])\n", + "plt.xlabel(\"(1 - Specificity)\")\n", + "plt.ylabel(\"Sensitivity\")\n", + "plt.title(\"Receiver operating characteristic\")\n", + "plt.legend(loc=\"lower right\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6ce0029f-bb51-4ee0-a579-4e57192c6266", + "metadata": {}, + "source": [ + "---\n", + "# Decision threshold choice" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b4ad4c6-eeb2-4b9b-b0a7-f4be6d2b0ed1", + "metadata": {}, + "outputs": [], + "source": [ + "# function to apply decision threshold\n", + "def predict_with_threshold(X, model, decision_threshold):\n", + " return model.predict_proba(X)[:,1]>=decision_threshold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ded871a5-0f67-44a1-af96-269dbca425a9", + "metadata": {}, + "outputs": [], + "source": [ + "# try different thresholds\n", + "thresholds = np.linspace(0.0,1.0,101)\n", + "sens = np.zeros(len(thresholds)) #sensitivity at threshold\n", + "spec = np.zeros(len(thresholds)) #specificity at threshold\n", + "for x in range(len(thresholds)):\n", + " y_th = predict_with_threshold(X_test,model,thresholds[x])\n", + " sens[x] = sensitivity(y_th,y_test)\n", + " spec[x] = specificity(y_th,y_test)\n", + "\n", + "best_sens = max(sens[spec >= Sp_sicct]) #sensitivity s.t. specificity >= SICCT\n", + "best_thresh = min(thresholds[spec >= Sp_sicct]) #threshold with max sensitivity s.t. specificity >= SICCT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6932b32c-2505-4edf-887d-529006b3454f", + "metadata": {}, + "outputs": [], + "source": [ + "# plot thresholds\n", + "plt.plot(thresholds,sens,label='Model sensitivity')\n", + "plt.plot(thresholds,spec,label='Model specificity')\n", + "best_sens_label = 'Best sensitivity = '+str(round(best_sens*100,1))+'%'\n", + "sicct_sens_label = 'SICCT sensitivity = '+str(round(Se_sicct*100,1))+'%'\n", + "sicct_spec_label = 'SICCT specificity = '+str(round(Sp_sicct*100,1))+'%'\n", + "best_thresh_label = 'Best threshold = '+str(round(best_thresh,3))\n", + "plt.axvline(best_thresh,c='k',ls='-.',label=best_thresh_label)\n", + "plt.axhline(best_sens,c='k',ls='--',label=best_sens_label)\n", + "plt.axhline(Se_sicct,c='tab:blue',ls=':',label=sicct_sens_label)\n", + "plt.axhline(Sp_sicct,c='tab:orange',ls=':',label=sicct_spec_label)\n", + "plt.xlabel('Decision Threshold')\n", + "plt.legend(bbox_to_anchor=(1.0, 0.7))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8c12558-1b39-4a81-a6bb-896ed2c142d4", + "metadata": {}, + "outputs": [], + "source": [ + "# Increase in sensitivity\n", + "str(round((best_sens-Se_sicct)/Se_sicct * 100,1))+'% increase in sensitivity over SICCT alone.'" + ] + }, + { + "cell_type": "markdown", + "id": "7db487d9-8f91-484c-9e62-c2f4eba7bee6", + "metadata": {}, + "source": [ + "---\n", + "# Save model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ea33123-351d-40d9-a3eb-c84b69ffc628", + "metadata": {}, + "outputs": [], + "source": [ + "# Save training / testing datasets to disk\n", + "dump((X_train, X_test, y_train, y_test), '/Data/TB_Diagnostics/final_data_split.data')\n", + "# Save model to disk\n", + "dump(model, '/Data/TB_Diagnostics/final_model.model')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/bTB-Diagnostic_2020_v4_crossVal+tuning_VetOnly-Control.ipynb b/bTB-Diagnostic_2020_v4_crossVal+tuning_VetOnly-Control.ipynb new file mode 100644 index 0000000..6da1683 --- /dev/null +++ b/bTB-Diagnostic_2020_v4_crossVal+tuning_VetOnly-Control.ipynb @@ -0,0 +1,696 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "c8454912-7ef8-48d5-bc40-ab4a2d5ed557", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import geopandas as gp\n", + "import geoplot\n", + "import seaborn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f12f0587-9d1e-4a2a-8aa7-57062f389cc4", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import time\n", + "import shapely\n", + "import rtree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b85c35f-6acf-4ed9-9ad5-8101ca434a8a", + "metadata": {}, + "outputs": [], + "source": [ + "import sklearn\n", + "from sklearn.ensemble import HistGradientBoostingClassifier as GBT\n", + "from sklearn.metrics import roc_curve, auc, roc_auc_score, make_scorer\n", + "from sklearn.inspection import permutation_importance\n", + "from sklearn.model_selection import RandomizedSearchCV, train_test_split, cross_val_score, GridSearchCV\n", + "from scipy.stats import randint,uniform,loguniform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d842e20-252c-4ddf-8602-dab12137bfe5", + "metadata": {}, + "outputs": [], + "source": [ + "from joblib import dump, load" + ] + }, + { + "cell_type": "markdown", + "id": "216ed604-493f-47c0-933e-69c7be17d13d", + "metadata": {}, + "source": [ + "# Load data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66c256bb-b5f4-4bb9-8a77-b70e76263187", + "metadata": {}, + "outputs": [], + "source": [ + "## Load data\n", + "data_raw = pd.read_csv('/Data/TB_Diagnostics/inputVars.csv',parse_dates=['dateOfTest'],dtype=float)\n", + "data_vet = pd.read_csv('/Data/TB_Diagnostics/inputVars_VetOnly.csv',parse_dates=['dateOfTest'],dtype=float)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e88b5303-cd12-415c-b442-d5867c63d70e", + "metadata": {}, + "outputs": [], + "source": [ + "## Choose a random subset, same size as vet data\n", + "data = data_raw.sample(len(data_vet))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2f67ecc-b358-4677-8c3d-559871e9590d", + "metadata": {}, + "outputs": [], + "source": [ + "len(data)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "956b91a3-a22d-46b8-aa0c-868cc802dadc", + "metadata": {}, + "outputs": [], + "source": [ + "min(data.dateOfTest) , max(data.dateOfTest)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "777af021-1f6a-4270-846c-9d7a90603a4e", + "metadata": {}, + "outputs": [], + "source": [ + "# Get target feature (confirmed breakdowns) as binary class\n", + "data_y = data.confirmedBreakdown.to_numpy().astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19ea9381-9e2e-42c5-8a89-c6480c5b0fb5", + "metadata": {}, + "outputs": [], + "source": [ + "# Get observed features\n", + "data_X = data.drop(columns=['confirmedBreakdown'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19559dca-d5e5-4e73-907f-c0496474067d", + "metadata": {}, + "outputs": [], + "source": [ + "# Convert dates to float\n", + "data_X.dateOfTest = data_X.dateOfTest.astype(int).astype(float)\n", + "# Add Random features\n", + "data_X['rand'] = np.random.random_sample(len(data_X))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b09bee9-c8d9-41f2-9b00-dd2a06b28933", + "metadata": {}, + "outputs": [], + "source": [ + "# Detect categorical features (<= 3 categories and explicit named features)\n", + "named_cat_features = ['vetPractice','batchBovine','batchAvian']\n", + "cat_features = []\n", + "for c in data_X.columns:\n", + " catf = len(data_X[c].unique())<=3\n", + " if c in named_cat_features:\n", + " catf = True\n", + " cat_features.append(catf)\n", + "\n", + "# NB: this is fine for boolean features (inc. missing values)\n", + "# but needs a proper encoding for true categorical features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "247e72c9-11bc-47c4-af7e-4acd727e40f3", + "metadata": {}, + "outputs": [], + "source": [ + "# Convery all to float matrix\n", + "data_X = data_X.to_numpy()" + ] + }, + { + "cell_type": "markdown", + "id": "4b230243-181a-4b99-8718-904fd94f1409", + "metadata": {}, + "source": [ + "# Training and testing sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0332c21f-1377-45f6-8deb-6bb33117bbde", + "metadata": {}, + "outputs": [], + "source": [ + "# Hold a final test set (random)\n", + "X_train, X_test, y_train, y_test = train_test_split(data_X, data_y, test_size=0.20)" + ] + }, + { + "cell_type": "markdown", + "id": "34be244f-ecdc-405f-ae24-722d23be7fdd", + "metadata": {}, + "source": [ + "# Model scoring functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "291c4fc2-90ac-4e40-bc81-dbe5dde7f058", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: sensitivity(prediction,target)\n", + "# returns sensitivity of prediction vs. target\n", + "# Se = TP / (TP + FN)\n", + "def sensitivity(p,t):\n", + " TP = (p&t).sum()\n", + " FN = (~p&t).sum()\n", + " return TP / (TP + FN)\n", + "\n", + "## Function: specificity(prediction,target)\n", + "# returns specificity of prediction vs. target\n", + "# Sp = TN / (TN + FP)\n", + "def specificity(p,t):\n", + " TN = (~p&~t).sum()\n", + " FP = (p&~t).sum()\n", + " return TN / (TN + FP)" + ] + }, + { + "cell_type": "markdown", + "id": "06670e4a-081b-4b54-b4d7-b950a3899292", + "metadata": {}, + "source": [ + "### SICCT Test performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a4e9c32-3bdc-45a2-9ef7-3e35f89686e6", + "metadata": {}, + "outputs": [], + "source": [ + "sicct = X_test[:,1].astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b957aa88-d5c6-49a9-954e-f3830c88956d", + "metadata": {}, + "outputs": [], + "source": [ + "## Sensitivity\n", + "Se_sicct = sensitivity(sicct,y_test)\n", + "Se_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5afe6254-69ff-447e-916d-acfe449583c3", + "metadata": {}, + "outputs": [], + "source": [ + "## Specificity\n", + "Sp_sicct = specificity(sicct,y_test)\n", + "Sp_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "050ee423-09dc-49ab-b423-12ab4023ffec", + "metadata": {}, + "outputs": [], + "source": [ + "## Accuracy\n", + "(sicct==y_test).sum() / len(y_test)" + ] + }, + { + "cell_type": "markdown", + "id": "d8e8dc4c-0587-4922-b53e-068d866aead4", + "metadata": {}, + "source": [ + "### Custom model scoring function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b680b7d-78f9-4d5c-900b-b9c4426c0332", + "metadata": {}, + "outputs": [], + "source": [ + "# Set specificity threshold to level for SICCT-only prediction\n", + "specificity_threshold = Sp_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ffd35bd-fbee-47e5-812d-db40934a6032", + "metadata": {}, + "outputs": [], + "source": [ + "# define a custom score function:\n", + "# score keeps specificity above SICCT, maximises sensitivity\n", + "def sensspec_score(t,p): #input: true (t) and predicted (p) classes\n", + " if specificity(p,t) < specificity_threshold:\n", + " return 0\n", + " else:\n", + " return sensitivity(p,t)\n", + "\n", + "custom_score = make_scorer(sensspec_score)" + ] + }, + { + "cell_type": "markdown", + "id": "ed335a56-8bdc-4f75-be84-464f18452c2c", + "metadata": { + "tags": [] + }, + "source": [ + "# Hyperparameter tuning / cross-validation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d33da3a-cdf3-4552-8518-c9228bcaf606", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Create model\n", + "gbt = GBT(categorical_features=cat_features, class_weight='balanced')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d5aaff0-11de-4a39-a9dc-3f2fe925c646", + "metadata": {}, + "outputs": [], + "source": [ + "# define parameter spaces\n", + "param_grid = {'learning_rate':[1.0,0.1,0.01,0.001,0.0005,0.0001,0.00005,0.00001],\n", + " 'max_leaf_nodes':[2,5,10,20,30,50,100,500,1000]}\n", + "\n", + "#param_dists = {'learning_rate':loguniform(0.00001,1.0),\n", + "# 'max_leaf_nodes':randint(2,10000)}\n", + "\n", + "param_dists = {'learning_rate':loguniform(0.01,1.0),\n", + " 'max_leaf_nodes':randint(2,2000)}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "acbf1509-6d32-446b-849a-aef78fbe5b72", + "metadata": {}, + "outputs": [], + "source": [ + "# perform grid search\n", + "#model = GridSearchCV(gbt, param_grid, n_jobs=-1, cv=5)#, scoring='recall') #5-fold cross-validation #n_jobs -1 for all procs\n", + "#start_time = time.time()\n", + "#model.fit(X_train, y_train)\n", + "#print(\"%0.2f seconds\" % (time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "002e9e98-19ad-4b73-911e-aee0aa8b002f", + "metadata": {}, + "outputs": [], + "source": [ + "# perform random search\n", + "model = RandomizedSearchCV(gbt, param_dists, n_jobs=-1, cv=10, verbose=1, n_iter=100, scoring='roc_auc')\n", + "start_time = time.time()\n", + "model.fit(X_train, y_train)\n", + "print(\"%0.2f seconds\" % (time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c82261ac-4794-4946-9a3c-9250ca2abd3c", + "metadata": {}, + "outputs": [], + "source": [ + "model.best_params_" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43f701c4-5c64-404d-becb-589bc570b120", + "metadata": {}, + "outputs": [], + "source": [ + "model.score(X_test,y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "020db8a3-b880-44b0-b216-9d729c24e81d", + "metadata": {}, + "outputs": [], + "source": [ + "tuning_results = pd.DataFrame(model.cv_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10ae9194-678c-4e45-8cc3-e5bb0890b47b", + "metadata": {}, + "outputs": [], + "source": [ + "seaborn.relplot(tuning_results,x='param_learning_rate',y='mean_test_score',hue='param_max_leaf_nodes')\n", + "plt.xscale('log')\n", + "plt.grid()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe93d0e5-084e-47bb-9d34-413613301ad4", + "metadata": {}, + "outputs": [], + "source": [ + "seaborn.relplot(tuning_results,x='param_max_leaf_nodes',y='mean_test_score',hue='param_learning_rate')\n", + "#plt.xscale('log')\n", + "plt.grid()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2268492-5484-4e06-a3a6-38caa03eb236", + "metadata": {}, + "outputs": [], + "source": [ + "seaborn.relplot(tuning_results,x='param_learning_rate',y='param_max_leaf_nodes', hue='mean_fit_time',size='mean_test_score',sizes=(1,200))\n", + "#plt.axhline(model.best_params_['max_leaf_nodes'],ls='--', label='Best fit')\n", + "#plt.axvline(model.best_params_['learning_rate'],ls='--')\n", + "plt.scatter(model.best_params_['learning_rate'],model.best_params_['max_leaf_nodes'], marker='+', c='b', s=300)\n", + "plt.xscale('log')" + ] + }, + { + "cell_type": "markdown", + "id": "07c6a8d8-02ff-442c-9392-b1bc522b923f", + "metadata": { + "tags": [] + }, + "source": [ + "# Evaluate performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53f99b2b-20ea-42b7-b052-8879a4e193b0", + "metadata": {}, + "outputs": [], + "source": [ + "## Model score on testing set: (score is metric set at training time)\n", + "model.score(X_test,y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "534591d2-c5b7-4a15-b2a1-1b66b3b8c088", + "metadata": {}, + "outputs": [], + "source": [ + "## Get test predictions for more detailed evaluation:\n", + "y_test_result = model.predict(X_test)\n", + "y_score = model.decision_function(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7cb3cb9d-4a82-4092-97b1-111bcc15a410", + "metadata": {}, + "outputs": [], + "source": [ + "## Sensitivity\n", + "Se = sensitivity(y_test_result,y_test)\n", + "Se" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2c803c1-0b0a-4d14-8661-ac1e0346d836", + "metadata": {}, + "outputs": [], + "source": [ + "## Specificity\n", + "Sp = specificity(y_test_result,y_test)\n", + "Sp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4787cebd-043c-4a6e-8c43-4a30d8e9dfd0", + "metadata": {}, + "outputs": [], + "source": [ + "## Accuracy\n", + "(y_test_result==y_test).sum() / len(y_test)" + ] + }, + { + "cell_type": "markdown", + "id": "a178d34b-4fd9-46ef-bba1-3ea392f64a6f", + "metadata": {}, + "source": [ + "---\n", + "# ROC Curves" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "253e1f2f-dd84-4635-8dc2-1c47c24ff60f", + "metadata": {}, + "outputs": [], + "source": [ + "fpr, tpr, _ = roc_curve(y_test,y_score)\n", + "roc_auc = auc(fpr,tpr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41a1dafb-75b4-42f6-bfac-1ce9fb521cfe", + "metadata": {}, + "outputs": [], + "source": [ + "roc_auc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1ac1678-e96c-4203-b014-fe9532219038", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "plt.figure()\n", + "lw = 2\n", + "plt.plot(\n", + " fpr,\n", + " tpr,\n", + " lw=lw,\n", + " label=\"ROC curve, model (area = %0.2f)\" % roc_auc,\n", + ")\n", + "plt.plot(1-Sp_sicct,Se_sicct,'+', label=\"SICCT only\", ms='15')\n", + "plt.plot([0, 1], [0, 1], lw=lw, linestyle=\"--\")\n", + "plt.xlim([0.0, 1.0])\n", + "plt.ylim([0.0, 1.0])\n", + "plt.xlabel(\"(1 - Specificity)\")\n", + "plt.ylabel(\"Sensitivity\")\n", + "plt.title(\"Receiver operating characteristic\")\n", + "plt.legend(loc=\"lower right\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6ce0029f-bb51-4ee0-a579-4e57192c6266", + "metadata": {}, + "source": [ + "---\n", + "# Decision threshold choice" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b4ad4c6-eeb2-4b9b-b0a7-f4be6d2b0ed1", + "metadata": {}, + "outputs": [], + "source": [ + "# function to apply decision threshold\n", + "def predict_with_threshold(X, model, decision_threshold):\n", + " return model.predict_proba(X)[:,1]>=decision_threshold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ded871a5-0f67-44a1-af96-269dbca425a9", + "metadata": {}, + "outputs": [], + "source": [ + "# try different thresholds\n", + "thresholds = np.linspace(0.0,1.0,101)\n", + "sens = np.zeros(len(thresholds)) #sensitivity at threshold\n", + "spec = np.zeros(len(thresholds)) #specificity at threshold\n", + "for x in range(len(thresholds)):\n", + " y_th = predict_with_threshold(X_test,model,thresholds[x])\n", + " sens[x] = sensitivity(y_th,y_test)\n", + " spec[x] = specificity(y_th,y_test)\n", + "\n", + "best_sens = max(sens[spec >= Sp_sicct]) #sensitivity s.t. specificity >= SICCT\n", + "best_thresh = min(thresholds[spec >= Sp_sicct]) #threshold with max sensitivity s.t. specificity >= SICCT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6932b32c-2505-4edf-887d-529006b3454f", + "metadata": {}, + "outputs": [], + "source": [ + "# plot thresholds\n", + "plt.plot(thresholds,sens,label='Model sensitivity')\n", + "plt.plot(thresholds,spec,label='Model specificity')\n", + "best_sens_label = 'Best sensitivity = '+str(round(best_sens*100,1))+'%'\n", + "sicct_sens_label = 'SICCT sensitivity = '+str(round(Se_sicct*100,1))+'%'\n", + "sicct_spec_label = 'SICCT specificity = '+str(round(Sp_sicct*100,1))+'%'\n", + "best_thresh_label = 'Best threshold = '+str(round(best_thresh,3))\n", + "plt.axvline(best_thresh,c='k',ls='-.',label=best_thresh_label)\n", + "plt.axhline(best_sens,c='k',ls='--',label=best_sens_label)\n", + "plt.axhline(Se_sicct,c='tab:blue',ls=':',label=sicct_sens_label)\n", + "plt.axhline(Sp_sicct,c='tab:orange',ls=':',label=sicct_spec_label)\n", + "plt.xlabel('Decision Threshold')\n", + "plt.legend(bbox_to_anchor=(1.0, 0.7))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8c12558-1b39-4a81-a6bb-896ed2c142d4", + "metadata": {}, + "outputs": [], + "source": [ + "# Increase in sensitivity\n", + "str(round((best_sens-Se_sicct)/Se_sicct * 100,1))+'% increase in sensitivity over SICCT alone.'" + ] + }, + { + "cell_type": "markdown", + "id": "7db487d9-8f91-484c-9e62-c2f4eba7bee6", + "metadata": {}, + "source": [ + "---\n", + "# Save model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ea33123-351d-40d9-a3eb-c84b69ffc628", + "metadata": {}, + "outputs": [], + "source": [ + "# Save training / testing datasets to disk\n", + "dump((X_train, X_test, y_train, y_test), '/Data/TB_Diagnostics/final_data_split_VetOnly_Control_5.data')\n", + "# Save model to disk\n", + "dump(model, '/Data/TB_Diagnostics/final_model_VetOnly_Control_5.model')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/bTB-Diagnostic_2020_v4_crossVal+tuning_VetOnly.ipynb b/bTB-Diagnostic_2020_v4_crossVal+tuning_VetOnly.ipynb new file mode 100644 index 0000000..0631935 --- /dev/null +++ b/bTB-Diagnostic_2020_v4_crossVal+tuning_VetOnly.ipynb @@ -0,0 +1,674 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "c8454912-7ef8-48d5-bc40-ab4a2d5ed557", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import geopandas as gp\n", + "import geoplot\n", + "import seaborn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f12f0587-9d1e-4a2a-8aa7-57062f389cc4", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import time\n", + "import shapely\n", + "import rtree" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2b85c35f-6acf-4ed9-9ad5-8101ca434a8a", + "metadata": {}, + "outputs": [], + "source": [ + "import sklearn\n", + "from sklearn.ensemble import HistGradientBoostingClassifier as GBT\n", + "from sklearn.metrics import roc_curve, auc, roc_auc_score, make_scorer\n", + "from sklearn.inspection import permutation_importance\n", + "from sklearn.model_selection import RandomizedSearchCV, train_test_split, cross_val_score, GridSearchCV\n", + "from scipy.stats import randint,uniform,loguniform" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d842e20-252c-4ddf-8602-dab12137bfe5", + "metadata": {}, + "outputs": [], + "source": [ + "from joblib import dump, load" + ] + }, + { + "cell_type": "markdown", + "id": "216ed604-493f-47c0-933e-69c7be17d13d", + "metadata": {}, + "source": [ + "# Load data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "66c256bb-b5f4-4bb9-8a77-b70e76263187", + "metadata": {}, + "outputs": [], + "source": [ + "## Load data\n", + "data = pd.read_csv('/Data/TB_Diagnostics/inputVars_VetOnly.csv',parse_dates=['dateOfTest'],dtype=float)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "956b91a3-a22d-46b8-aa0c-868cc802dadc", + "metadata": {}, + "outputs": [], + "source": [ + "min(data.dateOfTest) , max(data.dateOfTest)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "777af021-1f6a-4270-846c-9d7a90603a4e", + "metadata": {}, + "outputs": [], + "source": [ + "# Get target feature (confirmed breakdowns) as binary class\n", + "data_y = data.confirmedBreakdown.to_numpy().astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19ea9381-9e2e-42c5-8a89-c6480c5b0fb5", + "metadata": {}, + "outputs": [], + "source": [ + "# Get observed features\n", + "data_X = data.drop(columns=['confirmedBreakdown'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "19559dca-d5e5-4e73-907f-c0496474067d", + "metadata": {}, + "outputs": [], + "source": [ + "# Convert dates to float\n", + "data_X.dateOfTest = data_X.dateOfTest.astype(int).astype(float)\n", + "# Add Random features\n", + "data_X['rand'] = np.random.random_sample(len(data_X))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5b09bee9-c8d9-41f2-9b00-dd2a06b28933", + "metadata": {}, + "outputs": [], + "source": [ + "# Detect categorical features (<= 3 categories and explicit named features)\n", + "named_cat_features = ['vetPractice','batchBovine','batchAvian']\n", + "cat_features = []\n", + "for c in data_X.columns:\n", + " catf = len(data_X[c].unique())<=3\n", + " if c in named_cat_features:\n", + " catf = True\n", + " cat_features.append(catf)\n", + "\n", + "# NB: this is fine for boolean features (inc. missing values)\n", + "# but needs a proper encoding for true categorical features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "247e72c9-11bc-47c4-af7e-4acd727e40f3", + "metadata": {}, + "outputs": [], + "source": [ + "# Convery all to float matrix\n", + "data_X = data_X.to_numpy()" + ] + }, + { + "cell_type": "markdown", + "id": "4b230243-181a-4b99-8718-904fd94f1409", + "metadata": {}, + "source": [ + "# Training and testing sets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0332c21f-1377-45f6-8deb-6bb33117bbde", + "metadata": {}, + "outputs": [], + "source": [ + "# Hold a final test set (random)\n", + "X_train, X_test, y_train, y_test = train_test_split(data_X, data_y, test_size=0.20)" + ] + }, + { + "cell_type": "markdown", + "id": "34be244f-ecdc-405f-ae24-722d23be7fdd", + "metadata": {}, + "source": [ + "# Model scoring functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "291c4fc2-90ac-4e40-bc81-dbe5dde7f058", + "metadata": {}, + "outputs": [], + "source": [ + "## Function: sensitivity(prediction,target)\n", + "# returns sensitivity of prediction vs. target\n", + "# Se = TP / (TP + FN)\n", + "def sensitivity(p,t):\n", + " TP = (p&t).sum()\n", + " FN = (~p&t).sum()\n", + " return TP / (TP + FN)\n", + "\n", + "## Function: specificity(prediction,target)\n", + "# returns specificity of prediction vs. target\n", + "# Sp = TN / (TN + FP)\n", + "def specificity(p,t):\n", + " TN = (~p&~t).sum()\n", + " FP = (p&~t).sum()\n", + " return TN / (TN + FP)" + ] + }, + { + "cell_type": "markdown", + "id": "06670e4a-081b-4b54-b4d7-b950a3899292", + "metadata": {}, + "source": [ + "### SICCT Test performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a4e9c32-3bdc-45a2-9ef7-3e35f89686e6", + "metadata": {}, + "outputs": [], + "source": [ + "sicct = X_test[:,1].astype(bool)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b957aa88-d5c6-49a9-954e-f3830c88956d", + "metadata": {}, + "outputs": [], + "source": [ + "## Sensitivity\n", + "Se_sicct = sensitivity(sicct,y_test)\n", + "Se_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5afe6254-69ff-447e-916d-acfe449583c3", + "metadata": {}, + "outputs": [], + "source": [ + "## Specificity\n", + "Sp_sicct = specificity(sicct,y_test)\n", + "Sp_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "050ee423-09dc-49ab-b423-12ab4023ffec", + "metadata": {}, + "outputs": [], + "source": [ + "## Accuracy\n", + "(sicct==y_test).sum() / len(y_test)" + ] + }, + { + "cell_type": "markdown", + "id": "d8e8dc4c-0587-4922-b53e-068d866aead4", + "metadata": {}, + "source": [ + "### Custom model scoring function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4b680b7d-78f9-4d5c-900b-b9c4426c0332", + "metadata": {}, + "outputs": [], + "source": [ + "# Set specificity threshold to level for SICCT-only prediction\n", + "specificity_threshold = Sp_sicct" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ffd35bd-fbee-47e5-812d-db40934a6032", + "metadata": {}, + "outputs": [], + "source": [ + "# define a custom score function:\n", + "# score keeps specificity above SICCT, maximises sensitivity\n", + "def sensspec_score(t,p): #input: true (t) and predicted (p) classes\n", + " if specificity(p,t) < specificity_threshold:\n", + " return 0\n", + " else:\n", + " return sensitivity(p,t)\n", + "\n", + "custom_score = make_scorer(sensspec_score)" + ] + }, + { + "cell_type": "markdown", + "id": "ed335a56-8bdc-4f75-be84-464f18452c2c", + "metadata": { + "tags": [] + }, + "source": [ + "# Hyperparameter tuning / cross-validation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d33da3a-cdf3-4552-8518-c9228bcaf606", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Create model\n", + "gbt = GBT(categorical_features=cat_features, class_weight='balanced')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d5aaff0-11de-4a39-a9dc-3f2fe925c646", + "metadata": {}, + "outputs": [], + "source": [ + "# define parameter spaces\n", + "param_grid = {'learning_rate':[1.0,0.1,0.01,0.001,0.0005,0.0001,0.00005,0.00001],\n", + " 'max_leaf_nodes':[2,5,10,20,30,50,100,500,1000]}\n", + "\n", + "#param_dists = {'learning_rate':loguniform(0.00001,1.0),\n", + "# 'max_leaf_nodes':randint(2,10000)}\n", + "\n", + "param_dists = {'learning_rate':loguniform(0.01,1.0),\n", + " 'max_leaf_nodes':randint(2,2000)}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "acbf1509-6d32-446b-849a-aef78fbe5b72", + "metadata": {}, + "outputs": [], + "source": [ + "# perform grid search\n", + "#model = GridSearchCV(gbt, param_grid, n_jobs=-1, cv=5)#, scoring='recall') #5-fold cross-validation #n_jobs -1 for all procs\n", + "#start_time = time.time()\n", + "#model.fit(X_train, y_train)\n", + "#print(\"%0.2f seconds\" % (time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "002e9e98-19ad-4b73-911e-aee0aa8b002f", + "metadata": {}, + "outputs": [], + "source": [ + "# perform random search\n", + "model = RandomizedSearchCV(gbt, param_dists, n_jobs=-1, cv=10, verbose=1, n_iter=100, scoring='roc_auc')\n", + "start_time = time.time()\n", + "model.fit(X_train, y_train)\n", + "print(\"%0.2f seconds\" % (time.time() - start_time))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c82261ac-4794-4946-9a3c-9250ca2abd3c", + "metadata": {}, + "outputs": [], + "source": [ + "model.best_params_" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43f701c4-5c64-404d-becb-589bc570b120", + "metadata": {}, + "outputs": [], + "source": [ + "model.score(X_test,y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "020db8a3-b880-44b0-b216-9d729c24e81d", + "metadata": {}, + "outputs": [], + "source": [ + "tuning_results = pd.DataFrame(model.cv_results_)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "10ae9194-678c-4e45-8cc3-e5bb0890b47b", + "metadata": {}, + "outputs": [], + "source": [ + "seaborn.relplot(tuning_results,x='param_learning_rate',y='mean_test_score',hue='param_max_leaf_nodes')\n", + "plt.xscale('log')\n", + "plt.grid()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe93d0e5-084e-47bb-9d34-413613301ad4", + "metadata": {}, + "outputs": [], + "source": [ + "seaborn.relplot(tuning_results,x='param_max_leaf_nodes',y='mean_test_score',hue='param_learning_rate')\n", + "#plt.xscale('log')\n", + "plt.grid()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b2268492-5484-4e06-a3a6-38caa03eb236", + "metadata": {}, + "outputs": [], + "source": [ + "seaborn.relplot(tuning_results,x='param_learning_rate',y='param_max_leaf_nodes', hue='mean_fit_time',size='mean_test_score',sizes=(1,200))\n", + "#plt.axhline(model.best_params_['max_leaf_nodes'],ls='--', label='Best fit')\n", + "#plt.axvline(model.best_params_['learning_rate'],ls='--')\n", + "plt.scatter(model.best_params_['learning_rate'],model.best_params_['max_leaf_nodes'], marker='+', c='b', s=300)\n", + "plt.xscale('log')" + ] + }, + { + "cell_type": "markdown", + "id": "07c6a8d8-02ff-442c-9392-b1bc522b923f", + "metadata": { + "tags": [] + }, + "source": [ + "# Evaluate performance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53f99b2b-20ea-42b7-b052-8879a4e193b0", + "metadata": {}, + "outputs": [], + "source": [ + "## Model score on testing set: (score is metric set at training time)\n", + "model.score(X_test,y_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "534591d2-c5b7-4a15-b2a1-1b66b3b8c088", + "metadata": {}, + "outputs": [], + "source": [ + "## Get test predictions for more detailed evaluation:\n", + "y_test_result = model.predict(X_test)\n", + "y_score = model.decision_function(X_test)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7cb3cb9d-4a82-4092-97b1-111bcc15a410", + "metadata": {}, + "outputs": [], + "source": [ + "## Sensitivity\n", + "Se = sensitivity(y_test_result,y_test)\n", + "Se" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2c803c1-0b0a-4d14-8661-ac1e0346d836", + "metadata": {}, + "outputs": [], + "source": [ + "## Specificity\n", + "Sp = specificity(y_test_result,y_test)\n", + "Sp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4787cebd-043c-4a6e-8c43-4a30d8e9dfd0", + "metadata": {}, + "outputs": [], + "source": [ + "## Accuracy\n", + "(y_test_result==y_test).sum() / len(y_test)" + ] + }, + { + "cell_type": "markdown", + "id": "a178d34b-4fd9-46ef-bba1-3ea392f64a6f", + "metadata": {}, + "source": [ + "---\n", + "# ROC Curves" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "253e1f2f-dd84-4635-8dc2-1c47c24ff60f", + "metadata": {}, + "outputs": [], + "source": [ + "fpr, tpr, _ = roc_curve(y_test,y_score)\n", + "roc_auc = auc(fpr,tpr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "41a1dafb-75b4-42f6-bfac-1ce9fb521cfe", + "metadata": {}, + "outputs": [], + "source": [ + "roc_auc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1ac1678-e96c-4203-b014-fe9532219038", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "plt.figure()\n", + "lw = 2\n", + "plt.plot(\n", + " fpr,\n", + " tpr,\n", + " lw=lw,\n", + " label=\"ROC curve, model (area = %0.2f)\" % roc_auc,\n", + ")\n", + "plt.plot(1-Sp_sicct,Se_sicct,'+', label=\"SICCT only\", ms='15')\n", + "plt.plot([0, 1], [0, 1], lw=lw, linestyle=\"--\")\n", + "plt.xlim([0.0, 1.0])\n", + "plt.ylim([0.0, 1.0])\n", + "plt.xlabel(\"(1 - Specificity)\")\n", + "plt.ylabel(\"Sensitivity\")\n", + "plt.title(\"Receiver operating characteristic\")\n", + "plt.legend(loc=\"lower right\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "6ce0029f-bb51-4ee0-a579-4e57192c6266", + "metadata": {}, + "source": [ + "---\n", + "# Decision threshold choice" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b4ad4c6-eeb2-4b9b-b0a7-f4be6d2b0ed1", + "metadata": {}, + "outputs": [], + "source": [ + "# function to apply decision threshold\n", + "def predict_with_threshold(X, model, decision_threshold):\n", + " return model.predict_proba(X)[:,1]>=decision_threshold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ded871a5-0f67-44a1-af96-269dbca425a9", + "metadata": {}, + "outputs": [], + "source": [ + "# try different thresholds\n", + "thresholds = np.linspace(0.0,1.0,101)\n", + "sens = np.zeros(len(thresholds)) #sensitivity at threshold\n", + "spec = np.zeros(len(thresholds)) #specificity at threshold\n", + "for x in range(len(thresholds)):\n", + " y_th = predict_with_threshold(X_test,model,thresholds[x])\n", + " sens[x] = sensitivity(y_th,y_test)\n", + " spec[x] = specificity(y_th,y_test)\n", + "\n", + "best_sens = max(sens[spec >= Sp_sicct]) #sensitivity s.t. specificity >= SICCT\n", + "best_thresh = min(thresholds[spec >= Sp_sicct]) #threshold with max sensitivity s.t. specificity >= SICCT" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6932b32c-2505-4edf-887d-529006b3454f", + "metadata": {}, + "outputs": [], + "source": [ + "# plot thresholds\n", + "plt.plot(thresholds,sens,label='Model sensitivity')\n", + "plt.plot(thresholds,spec,label='Model specificity')\n", + "best_sens_label = 'Best sensitivity = '+str(round(best_sens*100,1))+'%'\n", + "sicct_sens_label = 'SICCT sensitivity = '+str(round(Se_sicct*100,1))+'%'\n", + "sicct_spec_label = 'SICCT specificity = '+str(round(Sp_sicct*100,1))+'%'\n", + "best_thresh_label = 'Best threshold = '+str(round(best_thresh,3))\n", + "plt.axvline(best_thresh,c='k',ls='-.',label=best_thresh_label)\n", + "plt.axhline(best_sens,c='k',ls='--',label=best_sens_label)\n", + "plt.axhline(Se_sicct,c='tab:blue',ls=':',label=sicct_sens_label)\n", + "plt.axhline(Sp_sicct,c='tab:orange',ls=':',label=sicct_spec_label)\n", + "plt.xlabel('Decision Threshold')\n", + "plt.legend(bbox_to_anchor=(1.0, 0.7))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e8c12558-1b39-4a81-a6bb-896ed2c142d4", + "metadata": {}, + "outputs": [], + "source": [ + "# Increase in sensitivity\n", + "str(round((best_sens-Se_sicct)/Se_sicct * 100,1))+'% increase in sensitivity over SICCT alone.'" + ] + }, + { + "cell_type": "markdown", + "id": "7db487d9-8f91-484c-9e62-c2f4eba7bee6", + "metadata": {}, + "source": [ + "---\n", + "# Save model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9ea33123-351d-40d9-a3eb-c84b69ffc628", + "metadata": {}, + "outputs": [], + "source": [ + "# Save training / testing datasets to disk\n", + "dump((X_train, X_test, y_train, y_test), '/Data/TB_Diagnostics/final_data_split_VetOnly.data')\n", + "# Save model to disk\n", + "dump(model, '/Data/TB_Diagnostics/final_model_VetOnly.model')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}