diff --git a/networking_generic_switch/devices/netmiko_devices/__init__.py b/networking_generic_switch/devices/netmiko_devices/__init__.py index c697ce5a..1980b612 100644 --- a/networking_generic_switch/devices/netmiko_devices/__init__.py +++ b/networking_generic_switch/devices/netmiko_devices/__init__.py @@ -123,6 +123,31 @@ def __init__(self, device_cfg, *args, **kwargs): self.config['session_log_record_writes'] = True self.config['session_log_file_mode'] = 'append' + _NUMERIC_CAST = { + "port": int, + "global_delay_factor": float, + "conn_timeout": float, + "auth_timeout": float, + "banner_timeout": float, + "blocking_timeout": float, + "timeout": float, + "session_timeout": float, + "read_timeout_override": float, + "keepalive": int, + } + + for key, expected_type in _NUMERIC_CAST.items(): + value = self.config.get(key) + if isinstance(value, str): + try: + self.config[key] = expected_type(value) + except ValueError: + LOG.error( + "Invalid value %s for %s; expected %s", + value, key, expected_type.__name__, + ) + raise exc.GenericSwitchNetmikoConfigError() + self.lock_kwargs = { 'locks_pool_size': int(self.ngs_config['ngs_max_connections']), 'locks_prefix': self.config.get( diff --git a/networking_generic_switch/tests/unit/test_devices.py b/networking_generic_switch/tests/unit/test_devices.py index 482dbd1f..57a912dd 100644 --- a/networking_generic_switch/tests/unit/test_devices.py +++ b/networking_generic_switch/tests/unit/test_devices.py @@ -219,6 +219,27 @@ def test__get_ssh_disabled_algorithms(self): } self.assertEqual(expected, algos) + def test_float_params_cast(self): + config = { + "device_type": 'netmiko_ovs_linux', + "ip": "10.1.2.3", + "username": "u", + "password": "p", + "conn_timeout": "20.0", + "global_delay_factor": "2.5", + "port": "2222", + } + device = devices.device_manager(config) + + self.assertIsInstance(device.config["conn_timeout"], float) + self.assertEqual(device.config["conn_timeout"], 20.0) + + self.assertIsInstance(device.config["global_delay_factor"], float) + self.assertEqual(device.config["global_delay_factor"], 2.5) + + self.assertIsInstance(device.config["port"], int) + self.assertEqual(device.config["port"], 2222) + def test_driver_load_config_override(self): device_cfg = {"device_type": 'netmiko_ovs_linux', "vlan_translation_supported": True}