diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 0000000..c4dfb4f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,27 @@ +--- +name: 🐛 Bug Report +about: Create a report to help us improve netdriver +title: '[Bug]: ' +labels: 'bug' +assignees: '' +--- + +**Describe the bug** +A clear and concise description of what the bug is. + +**To Reproduce** +Steps or code snippets to reproduce the behavior: +1. Connection protocol used (e.g., SSH) '...' +2. Device info (e.g., Cisco ASA 9.6.0) +3. Code snippet executed '...' +4. The resulting error '...' + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Logs/Traceback** +If applicable, add full logs or tracebacks to help explain your problem. **(Note: Please mask any passwords, keys, public IP addresses, or sensitive configurations)** + +```text +# Paste code or logs here +``` \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 0000000..572ded6 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,19 @@ +--- +name: 💡 Feature Request +about: Suggest an idea or new device support for netdriver +title: '[Feature]: ' +labels: 'enhancement' +assignees: '' +--- + +**Is your feature request related to a problem? Please describe.** +A clear and concise description of what the problem is. Ex. "I'm always frustrated when I can't directly parse the configuration output of [Specific Vendor] devices..." + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. Providing expected API usage or pseudocode is highly appreciated. + +```python +# Expected API design example +device = netdriver.connect(...) +result = device.do_something_new() +``` diff --git a/.vscode/settings.json b/.vscode/settings.json index 59073f9..76b3e2c 100755 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -10,31 +10,13 @@ "**/.pytest_cache": true, ".venv": true, }, + "python.defaultInterpreterPath": "${workspaceFolder}/.venv/bin/python", "python.analysis.autoImportCompletions": true, - "python.analysis.extraPaths": [ - "${workspaceFolder}/bases", - "${workspaceFolder}/components", - "${workspaceFolder}/development" - ], "python.autoComplete.extraPaths": [ - "${workspaceFolder}/bases", - "${workspaceFolder}/components", - "${workspaceFolder}/development" - ], - "pylint.args": [ - "--disable=C0114", - "--disable=C0115", - "--disable=C0116", - "--disable=C0209", - "--disable=C0301", - "--disable=C0415", - "--disable=W0221", - "--disable=W0613", - "--disable=W0718", - "--disable=W1203", - "--disable=R0903", - "--disable=E1101", + "${workspaceFolder}/packages", ], + "python.testing.pytestEnabled": true, + "python.testing.unittestEnabled": false, "python.testing.pytestArgs": [ "-s", "-v", @@ -43,7 +25,4 @@ "packages/agent/tests", "packages/core/tests", ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true, - "python.defaultInterpreterPath": ".venv/bin/python3" } \ No newline at end of file diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000..dea5e89 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,61 @@ +# Security Policy + +## Supported Versions + +| Version | Supported | +| ------- | --------- | +| 0.4.x | Yes | +| < 0.4 | No | + +## Reporting a Vulnerability + +If you discover a security vulnerability in NetDriver, **please do not open a public GitHub issue**. + +Instead, report it via one of the following channels: + +- **Email**: Send details to the maintainers at the addresses listed in `pyproject.toml` +- **GitHub Private Advisory**: Use [GitHub Security Advisories](https://github.com/features/security-advisories) on this repository + +Please include the following in your report: + +- A description of the vulnerability and its potential impact +- Steps to reproduce the issue +- Affected versions +- Any suggested mitigations or patches (if available) + +We aim to acknowledge receipt within **3 business days** and provide an initial assessment within **7 business days**. + +## Security Considerations + +NetDriver interacts with network devices over SSH and exposes a REST API. When deploying this project, consider the following: + +### Credentials and Secrets + +- Device credentials (username/password) are passed via API requests. Use TLS/HTTPS in all deployments to prevent credential exposure in transit. +- Do not log credentials. The agent configuration should be reviewed to ensure no sensitive fields appear in log output. +- Rotate device credentials regularly and restrict API access to trusted clients. + +### API Authentication + +- The agent HTTP API does **not** include built-in authentication. Deploy it behind an API gateway, reverse proxy, or firewall that enforces authentication and authorization appropriate for your environment. +- Restrict network access to the agent port (default: 8000) to trusted hosts only. + +### SSH Host Verification + +- By default, AsyncSSH may be configured to skip host key verification for convenience. In production, enable strict host key checking to prevent man-in-the-middle attacks. + +### Plugin Code Execution + +- Plugins are loaded dynamically from the `components/netdriver/plugins/` directory at startup. Ensure that only trusted code is present in the plugin directories and that the deployment environment has appropriate file system permissions. + +### Simulated Devices (simunet) + +- The `simunet` SSH server is intended for **testing purposes only**. Do not expose it on public networks or use it in production environments. + +## Disclosure Policy + +We follow a coordinated disclosure process. Once a fix is available, we will: + +1. Release a patched version +2. Publish a security advisory describing the vulnerability, its impact, and the fix +3. Credit the reporter (unless they prefer to remain anonymous) diff --git a/packages/agent/src/netdriver_agent/api/rest/__init__.py b/packages/agent/src/netdriver_agent/api/rest/__init__.py index 12b93f2..4cf442e 100755 --- a/packages/agent/src/netdriver_agent/api/rest/__init__.py +++ b/packages/agent/src/netdriver_agent/api/rest/__init__.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from fastapi import APIRouter +from netdriver_agent.api.rest import v1 from netdriver_agent.api.rest.v1 import router as _router -router = APIRouter(prefix='/api') +router = APIRouter(prefix="/api") router.include_router(_router) diff --git a/packages/agent/src/netdriver_agent/api/rest/v1/__init__.py b/packages/agent/src/netdriver_agent/api/rest/v1/__init__.py index 8e0107e..b9f97a5 100755 --- a/packages/agent/src/netdriver_agent/api/rest/v1/__init__.py +++ b/packages/agent/src/netdriver_agent/api/rest/v1/__init__.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- from fastapi.routing import APIRouter +from netdriver_agent.api.rest.v1 import api # noqa: F401 from netdriver_agent.api.rest.v1.api import router as cmd_router -router = APIRouter(prefix='/v1', tags=['v1']) +router = APIRouter(prefix="/v1", tags=["v1"]) router.include_router(cmd_router) diff --git a/packages/agent/src/netdriver_agent/client/pool.py b/packages/agent/src/netdriver_agent/client/pool.py index 364e380..8d074a0 100755 --- a/packages/agent/src/netdriver_agent/client/pool.py +++ b/packages/agent/src/netdriver_agent/client/pool.py @@ -15,39 +15,53 @@ class SessionPool: - """ Session Manager Singleton Class """ + """Session Manager Singleton Class + + In a single-threaded asyncio architecture, dict read/write operations are + inherently atomic (no await in between), so no global lock is needed. + The only scenario requiring a lock is "check-then-create": preventing + concurrent requests for the same session key from creating duplicate SSH connections. + + This implementation uses a per-key lock strategy: + - _key_locks: one independent lock per session key + - Connection creation for device A does not block requests for device B + - Serialization only occurs on the same key, avoiding duplicate creation and connection leaks + """ + _instance = None - _pool:Dict[str, Session] = {} - _pool_lock: asyncio.Lock + _pool: Dict[str, Session] = {} + _key_locks: Dict[str, asyncio.Lock] _check_interval: float - _config: Configuration + _config: Optional[Configuration] - def __new__(cls, config: Configuration = None) -> "SessionPool": + def __new__(cls, config: Configuration | None = None) -> "SessionPool": if not cls._instance: log.info("Creating SessionManager instance") cls._instance = super(SessionPool, cls).__new__(cls) cls._instance._pool = {} - cls._instance._pool_lock = asyncio.Lock() + cls._instance._key_locks = {} cls._instance._config = config cls._instance._check_interval = config.session.check_interval() or 30 - asyncio.create_task(cls._instance.monitor_sessions(), name="SessionPoolMonitor") + asyncio.create_task( + cls._instance.monitor_sessions(), name="SessionPoolMonitor" + ) return cls._instance - async def _get_session_by_key(self, session_key: str) -> Optional[Session]: - """ - Get session by key, if session is not exist or not same, return None. - This method is used to check whether session is alive or not. - """ + def _get_key_lock(self, session_key: str) -> asyncio.Lock: + if session_key not in self._key_locks: + self._key_locks[session_key] = asyncio.Lock() + return self._key_locks[session_key] - session = None - log.info(f"Try to acquire _pool_lock for getting session: {session_key}") - async with self._pool_lock: - log.info(f"acquired lock for session {session_key}") - session = self._pool.get(session_key) - log.info(f"_pool_lock released") + def _cleanup_key_lock(self, session_key: str) -> None: + lock = self._key_locks.get(session_key) + if lock and not lock.locked() and session_key not in self._pool: + del self._key_locks[session_key] + + async def _get_session_by_key(self, session_key: str) -> Optional[Session]: + session = self._pool.get(session_key) if session: - close_reason = None + close_reason = None if not await session.is_alive(): close_reason = "not alive" elif session.check_expiration_time(): @@ -58,41 +72,29 @@ async def _get_session_by_key(self, session_key: str) -> Optional[Session]: await self._handle_closed_session(session) log.info(f"Session {session_key} is {close_reason}, removed from pool.") return None - log.info(f"Got alived session by key: {session_key}") - return session + log.debug(f"Got alive session by key: {session_key}") + return session else: - log.info(f"Got no session by key: {session_key}") + log.debug(f"No session found by key: {session_key}") return None - async def _remove_session_by_key(self, session_key: str) -> None: + async def get_session( + self, + ip: Optional[IPvAnyAddress] = None, + username: Optional[str] = "", + password: Optional[str] = "", + vendor: Optional[str] = "", + model: Optional[str] = "", + port: int = 22, + protocol: str = "ssh", + enable_password: Optional[str] = "", + version: str = "base", + encode: str = "utf-8", + **kwargs: dict, + ) -> Optional[Session]: """ - Remove session by key, if session is not exist, do nothing. - This method is used to remove session from pool. - """ - log.info(f"Try to acquire _pool_lock and remove session: {session_key}") - async with self._pool_lock: - if session_key in self._pool: - del self._pool[session_key] - log.info(f"Session {session_key} removed from pool.") - else: - log.warning(f"Session {session_key} not found in pool.") - log.info(f"_pool_lock released") - - async def get_session(self, - ip: Optional[IPvAnyAddress] = None, - username: Optional[str] = "", - password: Optional[str] = "", - vendor: Optional[str] = "", - model: Optional[str] = "", - port: int = 22, - protocol: str = "ssh", - enable_password: Optional[str] = "", - version: str = "base", - encode: str = "utf-8", - **kwargs: dict - ) -> Optional[Session]: - """ - To get a session by key, if session is not exist or not same, create a new session. + Get or create a session. Uses per-key lock to prevent duplicate creation + for the same key, while different keys run fully in parallel without blocking. :param protocol: protocol, default is ssh :param ip: ipv4 or ipv6 address @@ -125,76 +127,112 @@ async def get_session(self, raise ValueError("type is required.") session_key = gen_session_key(protocol, username, ip, port) + _session = await self._get_session_by_key(session_key) - # Check whether session is same if _session and not _session.is_same( - vendor, model, version, password, enable_password, encode): + vendor, model, version, password, enable_password, encode + ): log.warning(f"Session {session_key} is not same, try to remove it.") if _session.is_idle is True: - log.warning(f"Session {session_key} is idle, close and regenerate it.") - await self._handle_closed_session(_session) - _session = None + log.warning(f"Session {session_key} is idle, close and regenerate it.") + await self._handle_closed_session(_session) + _session = None else: - log.warning(f"Session {session_key} is not idle, raise SessionInitFailed.") + log.warning( + f"Session {session_key} is not idle, raise SessionInitFailed." + ) raise SessionInitFailed( - f"A session with same key [{session_key}] is still running, " + \ - "to make sure the execuation safety, please check your request " + \ - "parameters and try again!") + f"A session with same key [{session_key}] is still running, " + + "to make sure the execuation safety, please check your request " + + "parameters and try again!" + ) - # Return session if exist if _session: log.info(f"Got session by key: {session_key}") return _session - session_clz: Session = PluginEngine().get_plugin(vendor, model, version) - if not session_clz: - raise PluginNotFound( - f"Plugin not found for {vendor}/{model}/{protocol}/{version}") - - log.info(f"Try to acquire _pool_lock and add session: {session_key}") - async with self._pool_lock: + key_lock = self._get_key_lock(session_key) + async with key_lock: + # Double-check: re-verify after acquiring the lock, as another coroutine may have already created it + _session = self._pool.get(session_key) + if _session: + log.info( + f"Session {session_key} created by another coroutine, reuse it." + ) + return _session + + session_clz: Session = PluginEngine().get_plugin(vendor, model, version) + if not session_clz: + raise PluginNotFound( + f"Plugin not found for {vendor}/{model}/{protocol}/{version}" + ) + + log.info(f"Creating new session: {session_key}") _session = await session_clz.create( - ip=ip, port=port, protocol=protocol, - username=username, password=password, enable_password=enable_password, - vendor=vendor, model=model, version=version, encode=encode, config=self._config, - **kwargs) + ip=ip, + port=port, + protocol=protocol, + username=username, + password=password, + enable_password=enable_password, + vendor=vendor, + model=model, + version=version, + encode=encode, + config=self._config, + **kwargs, + ) self._pool[_session.session_key] = _session - log.info(f"_pool_lock released, session {_session.session_key} added to pool.") - # Wait for session initialization to complete - log.info(f"Waiting for session {_session.session_key} initialization to complete.") + log.info(f"Session {_session.session_key} added to pool.") + self._cleanup_key_lock(session_key) + + log.info( + f"Waiting for session {_session.session_key} initialization to complete." + ) await _session._init_task_done if _session._init_task.exception(): - log.error(f"Session initialization failed: {_session._init_task.exception()}") - await self._remove_session_by_key(_session.session_key) + log.error( + f"Session initialization failed: {_session._init_task.exception()}" + ) + self._pool.pop(_session.session_key, None) + self._cleanup_key_lock(session_key) raise _session._init_task.exception() log.info(f"Created new session for: {_session.session_key}") return _session async def _handle_closed_session(self, session: Session): - log.debug(f"Try to acquire _pool_lock and remove {session.session_key} from pool and close it.") + log.debug(f"Removing {session.session_key} from pool and closing it.") try: - async with self._pool_lock: - self._pool.pop(session.session_key) + self._pool.pop(session.session_key, None) + self._cleanup_key_lock(session.session_key) await asyncio.wait_for(session.close(), timeout=1) except Exception as e: log.error(f"Error closing session {session.session_key}: {e}") async def close_all(self): await asyncio.gather( - *[self._handle_closed_session(session) for session in self._pool.values()]) + *[ + self._handle_closed_session(session) + for session in list(self._pool.values()) + ] + ) + self._key_locks.clear() async def _display_sessions_info(self): """Display session information in a table.""" table = [] for session in self._pool.values(): table.append(await session.get_display_info()) - log.info("\n########## Session Pool Status ########## \n" + tabulate( - table, - headers=Session.get_info_headers(), - )) + log.info( + "\n########## Session Pool Status ########## \n" + + tabulate( + table, + headers=Session.get_info_headers(), + ) + ) async def _remove_closed_sessions(self): # list() is used to avoid RuntimeError: dictionary changed size during iteration diff --git a/packages/agent/src/netdriver_agent/main.py b/packages/agent/src/netdriver_agent/main.py index 654ac39..de3ab17 100755 --- a/packages/agent/src/netdriver_agent/main.py +++ b/packages/agent/src/netdriver_agent/main.py @@ -4,6 +4,7 @@ This is the main module for the agent. It is responsible for starting the FastAPI server. """ + import os import sys import argparse @@ -21,17 +22,21 @@ from netdriver_core.log import logman -logman.configure_logman(level=container.config.logging.level(), - intercept_loggers=container.config.logging.intercept_loggers(), - log_file=container.config.logging.log_file()) +logman.configure_logman( + level=container.config.logging.level(), + intercept_loggers=container.config.logging.intercept_loggers(), + log_file=container.config.logging.log_file(), +) log = logman.logger -container.wire(modules=[ - rest.v1.api, -]) +container.wire( + modules=[ + rest.v1.api, + ] +) async def on_startup() -> None: - """ put all post up logic here """ + """put all post up logic here""" log.info("Post-startup of NetDriver Agent") # load plugins PluginEngine() @@ -40,7 +45,7 @@ async def on_startup() -> None: async def on_shutdown() -> None: - """ put all clean logic here """ + """put all clean logic here""" log.info("Pre-shutdown of NetDriver Agent") await SessionPool().close_all() @@ -53,18 +58,20 @@ async def lifespan(api: FastAPI): app: FastAPI = FastAPI( - title='NetworkDriver Agent', + title="NetworkDriver Agent", lifespan=lifespan, container=container, - exception_handlers=global_exception_handlers + exception_handlers=global_exception_handlers, +) +app.add_middleware( + CorrelationIdMiddleware, header_name="X-Correlation-Id", validator=None ) -app.add_middleware(CorrelationIdMiddleware, header_name="X-Correlation-Id", validator=None) app.include_router(rest.router) @app.get("/") async def root() -> dict: - """ root endpoint """ + """root endpoint""" return { "message": "Welcome to the NetDriver Agent", } @@ -72,45 +79,33 @@ async def root() -> dict: @app.get("/health") async def health() -> dict: - """ health check endpoint for docker """ - return { - "status": "healthy", - "service": "netdriver-agent" - } + """health check endpoint for docker""" + return {"status": "healthy", "service": "netdriver-agent"} def start(): """Start the agent server with optional configuration file parameter.""" parser = argparse.ArgumentParser(description="NetDriver Agent Server") parser.add_argument( - "-c", "--config", + "-c", + "--config", type=str, default=None, - help="Path to configuration file (default: config/agent/agent.yml or NETDRIVER_AGENT_CONFIG env var)" + help="Path to configuration file (default: config/agent/agent.yml or NETDRIVER_AGENT_CONFIG env var)", ) parser.add_argument( - "--host", - type=str, - default="0.0.0.0", - help="Host to bind (default: 0.0.0.0)" + "--host", type=str, default="0.0.0.0", help="Host to bind (default: 0.0.0.0)" ) parser.add_argument( - "-p", "--port", - type=int, - default=8000, - help="Port to bind (default: 8000)" + "-p", "--port", type=int, default=8000, help="Port to bind (default: 8000)" ) parser.add_argument( "--reload", action="store_true", default=True, - help="Enable auto-reload (default: True)" - ) - parser.add_argument( - "--no-reload", - action="store_true", - help="Disable auto-reload" + help="Enable auto-reload (default: True)", ) + parser.add_argument("--no-reload", action="store_true", help="Disable auto-reload") args = parser.parse_args() @@ -123,15 +118,12 @@ def start(): logman.configure_logman( level=container.config.logging.level(), intercept_loggers=container.config.logging.intercept_loggers(), - log_file=container.config.logging.log_file() + log_file=container.config.logging.log_file(), ) # Handle reload flag reload = args.reload and not args.no_reload uvicorn.run( - "netdriver_agent.main:app", - host=args.host, - port=args.port, - reload=reload + "netdriver_agent.main:app", host=args.host, port=args.port, reload=reload ) diff --git a/pyproject.toml b/pyproject.toml index 5a041aa..09d4d98 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,6 @@ dev = [ "twine>=6.2.0", "httpx>=0.27.2,<0.28", ] - [project.urls] Homepage = "https://github.com/OpenSecFlow/netdriver" Repository = "https://github.com/OpenSecFlow/netdriver" @@ -69,6 +68,32 @@ members = [ "packages/*" ] requires = ["uv_build>=0.9.17,<0.10.0"] build-backend = "uv_build" +[tool.basedpyright] +venvPath = "." +venv = ".venv" +pythonVersion = "3.12" +pythonPlatform = "Linux" +extraPaths = [ + "packages/agent/src", + "packages/core/src", + "packages/simunet/src", + "packages/textfsm/src", +] +reportImplicitRelativeImport = false +reportMissingTypeArgument = false + +[tool.pyright] +venvPath = "." +penv = ".venv" +pythonVersion = "3.12" +pythonPlatform = "Linux" +extraPaths = [ + "packages/agent/src", + "packages/core/src", + "packages/simunet/src", + "packages/textfsm/src", +] + [tool.pytest.ini_options] markers = ['unit', 'integration'] asyncio_mode = "auto"