Compare commits

..

1 Commits

Author SHA1 Message Date
f65547e6e8 feat: add addition function and its test-function 2026-02-20 16:36:03 +01:00
734 changed files with 158974 additions and 96112 deletions

BIN
.DS_Store vendored Normal file

Binary file not shown.

18
.pre-commit-config.yaml Normal file
View File

@ -0,0 +1,18 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.0
hooks:
- id: ruff # linting
- repo: https://github.com/psf/black
rev: stable
hooks:
- id: black # formatting

View File

@ -1,6 +1,7 @@
#!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python #!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python
import sys import sys
from black import patched_main from black import patched_main
if __name__ == '__main__':
sys.argv[0] = sys.argv[0].removesuffix('.exe') if __name__ == "__main__":
sys.argv[0] = sys.argv[0].removesuffix(".exe")
sys.exit(patched_main()) sys.exit(patched_main())

View File

@ -1,6 +1,7 @@
#!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python #!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python
import sys import sys
from blackd import patched_main from blackd import patched_main
if __name__ == '__main__':
sys.argv[0] = sys.argv[0].removesuffix('.exe') if __name__ == "__main__":
sys.argv[0] = sys.argv[0].removesuffix(".exe")
sys.exit(patched_main()) sys.exit(patched_main())

View File

@ -1,6 +1,7 @@
#!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python #!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python
import sys import sys
from identify.cli import main from identify.cli import main
if __name__ == '__main__':
sys.argv[0] = sys.argv[0].removesuffix('.exe') if __name__ == "__main__":
sys.argv[0] = sys.argv[0].removesuffix(".exe")
sys.exit(main()) sys.exit(main())

View File

@ -1,6 +1,7 @@
#!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python #!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python
import sys import sys
from nodeenv import main from nodeenv import main
if __name__ == '__main__':
sys.argv[0] = sys.argv[0].removesuffix('.exe') if __name__ == "__main__":
sys.argv[0] = sys.argv[0].removesuffix(".exe")
sys.exit(main()) sys.exit(main())

View File

@ -3,6 +3,7 @@
import re import re
import sys import sys
from pip._internal.cli.main import main from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) if __name__ == "__main__":
sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0])
sys.exit(main()) sys.exit(main())

View File

@ -3,6 +3,7 @@
import re import re
import sys import sys
from pip._internal.cli.main import main from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) if __name__ == "__main__":
sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0])
sys.exit(main()) sys.exit(main())

View File

@ -3,6 +3,7 @@
import re import re
import sys import sys
from pip._internal.cli.main import main from pip._internal.cli.main import main
if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) if __name__ == "__main__":
sys.argv[0] = re.sub(r"(-script\.pyw|\.exe)?$", "", sys.argv[0])
sys.exit(main()) sys.exit(main())

View File

@ -1,6 +1,7 @@
#!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python #!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python
import sys import sys
from pre_commit.main import main from pre_commit.main import main
if __name__ == '__main__':
sys.argv[0] = sys.argv[0].removesuffix('.exe') if __name__ == "__main__":
sys.argv[0] = sys.argv[0].removesuffix(".exe")
sys.exit(main()) sys.exit(main())

View File

@ -1,6 +1,7 @@
#!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python #!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python
import sys import sys
from pytest import console_main from pytest import console_main
if __name__ == '__main__':
sys.argv[0] = sys.argv[0].removesuffix('.exe') if __name__ == "__main__":
sys.argv[0] = sys.argv[0].removesuffix(".exe")
sys.exit(console_main()) sys.exit(console_main())

View File

@ -1,6 +1,7 @@
#!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python #!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python
import sys import sys
from pygments.cmdline import main from pygments.cmdline import main
if __name__ == '__main__':
sys.argv[0] = sys.argv[0].removesuffix('.exe') if __name__ == "__main__":
sys.argv[0] = sys.argv[0].removesuffix(".exe")
sys.exit(main()) sys.exit(main())

View File

@ -1,6 +1,7 @@
#!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python #!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python
import sys import sys
from pytest import console_main from pytest import console_main
if __name__ == '__main__':
sys.argv[0] = sys.argv[0].removesuffix('.exe') if __name__ == "__main__":
sys.argv[0] = sys.argv[0].removesuffix(".exe")
sys.exit(console_main()) sys.exit(console_main())

View File

@ -1,6 +1,7 @@
#!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python #!/Users/emoniefaychetwin/PPE2/PPE2/.venv/bin/python
import sys import sys
from virtualenv.__main__ import run_with_catch from virtualenv.__main__ import run_with_catch
if __name__ == '__main__':
sys.argv[0] = sys.argv[0].removesuffix('.exe') if __name__ == "__main__":
sys.argv[0] = sys.argv[0].removesuffix(".exe")
sys.exit(run_with_catch()) sys.exit(run_with_catch())

View File

@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
__all__ = ["__version__", "version_tuple"] __all__ = ["__version__", "version_tuple"]
try: try:

View File

@ -12,7 +12,6 @@ from .code import TracebackEntry
from .source import getrawcode from .source import getrawcode
from .source import Source from .source import Source
__all__ = [ __all__ = [
"Code", "Code",
"ExceptionInfo", "ExceptionInfo",

View File

@ -48,7 +48,6 @@ from _pytest.deprecated import check_ispytest
from _pytest.pathlib import absolutepath from _pytest.pathlib import absolutepath
from _pytest.pathlib import bestrelpath from _pytest.pathlib import bestrelpath
if sys.version_info < (3, 11): if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup from exceptiongroup import BaseExceptionGroup
@ -230,6 +229,7 @@ class TracebackEntry:
@property @property
def end_colno(self) -> int | None: def end_colno(self) -> int | None:
return None return None
else: else:
@property @property
@ -595,33 +595,33 @@ class ExceptionInfo(Generic[E]):
@property @property
def type(self) -> type[E]: def type(self) -> type[E]:
"""The exception class.""" """The exception class."""
assert self._excinfo is not None, ( assert (
".type can only be used after the context manager exits" self._excinfo is not None
) ), ".type can only be used after the context manager exits"
return self._excinfo[0] return self._excinfo[0]
@property @property
def value(self) -> E: def value(self) -> E:
"""The exception value.""" """The exception value."""
assert self._excinfo is not None, ( assert (
".value can only be used after the context manager exits" self._excinfo is not None
) ), ".value can only be used after the context manager exits"
return self._excinfo[1] return self._excinfo[1]
@property @property
def tb(self) -> TracebackType: def tb(self) -> TracebackType:
"""The exception raw traceback.""" """The exception raw traceback."""
assert self._excinfo is not None, ( assert (
".tb can only be used after the context manager exits" self._excinfo is not None
) ), ".tb can only be used after the context manager exits"
return self._excinfo[2] return self._excinfo[2]
@property @property
def typename(self) -> str: def typename(self) -> str:
"""The type name of the exception.""" """The type name of the exception."""
assert self._excinfo is not None, ( assert (
".typename can only be used after the context manager exits" self._excinfo is not None
) ), ".typename can only be used after the context manager exits"
return self.type.__name__ return self.type.__name__
@property @property

View File

@ -3,7 +3,6 @@ from __future__ import annotations
from .terminalwriter import get_terminal_width from .terminalwriter import get_terminal_width
from .terminalwriter import TerminalWriter from .terminalwriter import TerminalWriter
__all__ = [ __all__ = [
"TerminalWriter", "TerminalWriter",
"get_terminal_width", "get_terminal_width",

View File

@ -113,7 +113,7 @@ class PrettyPrinter:
elif ( elif (
_dataclasses.is_dataclass(object) _dataclasses.is_dataclass(object)
and not isinstance(object, type) and not isinstance(object, type)
and object.__dataclass_params__.repr # type:ignore[attr-defined] and object.__dataclass_params__.repr # type: ignore[attr-defined]
and and
# Check dataclass has generated repr method. # Check dataclass has generated repr method.
hasattr(object.__repr__, "__wrapped__") hasattr(object.__repr__, "__wrapped__")

View File

@ -19,7 +19,6 @@ from pygments.lexers.python import PythonLexer
from ..compat import assert_never from ..compat import assert_never
from .wcwidth import wcswidth from .wcwidth import wcswidth
# This code was initially copied from py 1.8.1, file _io/terminalwriter.py. # This code was initially copied from py 1.8.1, file _io/terminalwriter.py.

View File

@ -9,7 +9,6 @@ import sys
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import ParamSpec from typing_extensions import ParamSpec
@ -94,7 +93,7 @@ class ErrorMaker:
try: try:
# error: Invalid index type "Optional[int]" for "dict[int, int]"; expected type "int" [index] # error: Invalid index type "Optional[int]" for "dict[int, int]"; expected type "int" [index]
# OK to ignore because we catch the KeyError below. # OK to ignore because we catch the KeyError below.
cls = self._geterrnoclass(_winerrnomap[value.errno]) # type:ignore[index] cls = self._geterrnoclass(_winerrnomap[value.errno]) # type: ignore[index]
except KeyError: except KeyError:
raise value raise value
else: else:

View File

@ -33,7 +33,6 @@ import warnings
from . import error from . import error
# Moved from local.py. # Moved from local.py.
iswin32 = sys.platform == "win32" or (getattr(os, "_name", False) == "nt") iswin32 = sys.platform == "win32" or (getattr(os, "_name", False) == "nt")
@ -222,7 +221,7 @@ class Stat:
raise NotImplementedError("XXX win32") raise NotImplementedError("XXX win32")
import pwd import pwd
entry = error.checked_call(pwd.getpwuid, self.uid) # type:ignore[attr-defined,unused-ignore] entry = error.checked_call(pwd.getpwuid, self.uid) # type: ignore[attr-defined,unused-ignore]
return entry[0] return entry[0]
@property @property
@ -232,7 +231,7 @@ class Stat:
raise NotImplementedError("XXX win32") raise NotImplementedError("XXX win32")
import grp import grp
entry = error.checked_call(grp.getgrgid, self.gid) # type:ignore[attr-defined,unused-ignore] entry = error.checked_call(grp.getgrgid, self.gid) # type: ignore[attr-defined,unused-ignore]
return entry[0] return entry[0]
def isdir(self): def isdir(self):
@ -250,7 +249,7 @@ def getuserid(user):
import pwd import pwd
if not isinstance(user, int): if not isinstance(user, int):
user = pwd.getpwnam(user)[2] # type:ignore[attr-defined,unused-ignore] user = pwd.getpwnam(user)[2] # type: ignore[attr-defined,unused-ignore]
return user return user
@ -258,7 +257,7 @@ def getgroupid(group):
import grp import grp
if not isinstance(group, int): if not isinstance(group, int):
group = grp.getgrnam(group)[2] # type:ignore[attr-defined,unused-ignore] group = grp.getgrnam(group)[2] # type: ignore[attr-defined,unused-ignore]
return group return group

View File

@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID commit_id: COMMIT_ID
__commit_id__: COMMIT_ID __commit_id__: COMMIT_ID
__version__ = version = '9.0.2' __version__ = version = "9.0.2"
__version_tuple__ = version_tuple = (9, 0, 2) __version_tuple__ = version_tuple = (9, 0, 2)
__commit_id__ = commit_id = None __commit_id__ = commit_id = None

View File

@ -18,7 +18,6 @@ from _pytest.config import hookimpl
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.nodes import Item from _pytest.nodes import Item
if TYPE_CHECKING: if TYPE_CHECKING:
from _pytest.main import Session from _pytest.main import Session

View File

@ -26,7 +26,6 @@ import types
from typing import IO from typing import IO
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if sys.version_info >= (3, 12): if sys.version_info >= (3, 12):
from importlib.resources.abc import TraversableResources from importlib.resources.abc import TraversableResources
else: else:
@ -704,9 +703,9 @@ class AssertionRewriter(ast.NodeVisitor):
pos = 0 pos = 0
for item in mod.body: for item in mod.body:
match item: match item:
case ast.Expr(value=ast.Constant(value=str() as doc)) if ( case ast.Expr(
expect_docstring value=ast.Constant(value=str() as doc)
): ) if expect_docstring:
if self.is_rewrite_disabled(doc): if self.is_rewrite_disabled(doc):
return return
expect_docstring = False expect_docstring = False
@ -1018,9 +1017,9 @@ class AssertionRewriter(ast.NodeVisitor):
e.id for e in boolop.values[:i] if hasattr(e, "id") e.id for e in boolop.values[:i] if hasattr(e, "id")
]: ]:
pytest_temp = self.variable() pytest_temp = self.variable()
self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment] self.variables_overwrite[self.scope][target_id] = v.left # type: ignore[assignment]
# mypy's false positive, we're checking that the 'target' attribute exists. # mypy's false positive, we're checking that the 'target' attribute exists.
v.left.target.id = pytest_temp # type:ignore[attr-defined] v.left.target.id = pytest_temp # type: ignore[attr-defined]
self.push_format_context() self.push_format_context()
res, expl = self.visit(v) res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
@ -1065,7 +1064,7 @@ class AssertionRewriter(ast.NodeVisitor):
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get( if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
self.scope, {} self.scope, {}
): ):
arg = self.variables_overwrite[self.scope][arg.id] # type:ignore[assignment] arg = self.variables_overwrite[self.scope][arg.id] # type: ignore[assignment]
res, expl = self.visit(arg) res, expl = self.visit(arg)
arg_expls.append(expl) arg_expls.append(expl)
new_args.append(res) new_args.append(res)
@ -1074,7 +1073,7 @@ class AssertionRewriter(ast.NodeVisitor):
case ast.Name(id=id) if id in self.variables_overwrite.get( case ast.Name(id=id) if id in self.variables_overwrite.get(
self.scope, {} self.scope, {}
): ):
keyword.value = self.variables_overwrite[self.scope][id] # type:ignore[assignment] keyword.value = self.variables_overwrite[self.scope][id] # type: ignore[assignment]
res, expl = self.visit(keyword.value) res, expl = self.visit(keyword.value)
new_kwargs.append(ast.keyword(keyword.arg, res)) new_kwargs.append(ast.keyword(keyword.arg, res))
if keyword.arg: if keyword.arg:
@ -1132,7 +1131,9 @@ class AssertionRewriter(ast.NodeVisitor):
case ( case (
ast.NamedExpr(target=ast.Name(id=target_id)), ast.NamedExpr(target=ast.Name(id=target_id)),
ast.Name(id=name_id), ast.Name(id=name_id),
) if target_id == name_id: ) if (
target_id == name_id
):
next_operand.target.id = self.variable() next_operand.target.id = self.variable()
self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment] self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment]

View File

@ -10,7 +10,6 @@ from _pytest.compat import running_on_ci
from _pytest.config import Config from _pytest.config import Config
from _pytest.nodes import Item from _pytest.nodes import Item
DEFAULT_MAX_LINES = 8 DEFAULT_MAX_LINES = 8
DEFAULT_MAX_CHARS = DEFAULT_MAX_LINES * 80 DEFAULT_MAX_CHARS = DEFAULT_MAX_LINES * 80
USAGE_MSG = "use '-vv' to show" USAGE_MSG = "use '-vv' to show"

View File

@ -23,7 +23,6 @@ from _pytest._io.saferepr import saferepr_unlimited
from _pytest.compat import running_on_ci from _pytest.compat import running_on_ci
from _pytest.config import Config from _pytest.config import Config
# The _reprcompare attribute on the util module is used by the new assertion # The _reprcompare attribute on the util module is used by the new assertion
# interpretation code and assertion rewriter to detect this plugin was # interpretation code and assertion rewriter to detect this plugin was
# loaded and in turn call the hooks defined here as part of the # loaded and in turn call the hooks defined here as part of the

View File

@ -32,7 +32,6 @@ from _pytest.nodes import Directory
from _pytest.nodes import File from _pytest.nodes import File
from _pytest.reports import TestReport from _pytest.reports import TestReport
README_CONTENT = """\ README_CONTENT = """\
# pytest cache directory # # pytest cache directory #

View File

@ -27,7 +27,6 @@ from typing import NamedTuple
from typing import TextIO from typing import TextIO
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self
@ -42,7 +41,6 @@ from _pytest.nodes import File
from _pytest.nodes import Item from _pytest.nodes import Item
from _pytest.reports import CollectReport from _pytest.reports import CollectReport
_CaptureMethod = Literal["fd", "sys", "no", "tee-sys"] _CaptureMethod = Literal["fd", "sys", "no", "tee-sys"]
@ -393,10 +391,10 @@ class SysCaptureBase(CaptureBase[AnyStr]):
) )
def _assert_state(self, op: str, states: tuple[str, ...]) -> None: def _assert_state(self, op: str, states: tuple[str, ...]) -> None:
assert self._state in states, ( assert (
"cannot {} in state {!r}: expected one of {}".format( self._state in states
op, self._state, ", ".join(states) ), "cannot {} in state {!r}: expected one of {}".format(
) op, self._state, ", ".join(states)
) )
def start(self) -> None: def start(self) -> None:
@ -510,10 +508,10 @@ class FDCaptureBase(CaptureBase[AnyStr]):
) )
def _assert_state(self, op: str, states: tuple[str, ...]) -> None: def _assert_state(self, op: str, states: tuple[str, ...]) -> None:
assert self._state in states, ( assert (
"cannot {} in state {!r}: expected one of {}".format( self._state in states
op, self._state, ", ".join(states) ), "cannot {} in state {!r}: expected one of {}".format(
) op, self._state, ", ".join(states)
) )
def start(self) -> None: def start(self) -> None:

View File

@ -18,7 +18,6 @@ from typing import NoReturn
import py import py
if sys.version_info >= (3, 14): if sys.version_info >= (3, 14):
from annotationlib import Format from annotationlib import Format

View File

@ -74,7 +74,6 @@ from _pytest.stash import Stash
from _pytest.warning_types import PytestConfigWarning from _pytest.warning_types import PytestConfigWarning
from _pytest.warning_types import warn_explicit_for from _pytest.warning_types import warn_explicit_for
if TYPE_CHECKING: if TYPE_CHECKING:
from _pytest.assertion.rewrite import AssertionRewritingHook from _pytest.assertion.rewrite import AssertionRewritingHook
from _pytest.cacheprovider import Cache from _pytest.cacheprovider import Cache
@ -344,7 +343,7 @@ def _prepareconfig(
if isinstance(args, os.PathLike): if isinstance(args, os.PathLike):
args = [os.fspath(args)] args = [os.fspath(args)]
elif not isinstance(args, list): elif not isinstance(args, list):
msg = ( # type:ignore[unreachable] msg = ( # type: ignore[unreachable]
"`args` parameter expected to be a list of strings, got: {!r} (type: {})" "`args` parameter expected to be a list of strings, got: {!r} (type: {})"
) )
raise TypeError(msg.format(args, type(args))) raise TypeError(msg.format(args, type(args)))
@ -861,9 +860,9 @@ class PytestPluginManager(PluginManager):
# "terminal" or "capture". Those plugins are registered under their # "terminal" or "capture". Those plugins are registered under their
# basename for historic purposes but must be imported with the # basename for historic purposes but must be imported with the
# _pytest prefix. # _pytest prefix.
assert isinstance(modname, str), ( assert isinstance(
f"module name as text required, got {modname!r}" modname, str
) ), f"module name as text required, got {modname!r}"
if self.is_blocked(modname) or self.get_plugin(modname) is not None: if self.is_blocked(modname) or self.get_plugin(modname) is not None:
return return
@ -1475,9 +1474,9 @@ class Config:
def parse(self, args: list[str], addopts: bool = True) -> None: def parse(self, args: list[str], addopts: bool = True) -> None:
# Parse given cmdline arguments into this config object. # Parse given cmdline arguments into this config object.
assert self.args == [], ( assert (
"can only parse cmdline args at most once per Config object" self.args == []
) ), "can only parse cmdline args at most once per Config object"
self.hook.pytest_addhooks.call_historic( self.hook.pytest_addhooks.call_historic(
kwargs=dict(pluginmanager=self.pluginmanager) kwargs=dict(pluginmanager=self.pluginmanager)
@ -2082,8 +2081,7 @@ def parse_warning_filter(
* Raises UsageError so we get nice error messages on failure. * Raises UsageError so we get nice error messages on failure.
""" """
__tracebackhide__ = True __tracebackhide__ = True
error_template = dedent( error_template = dedent(f"""\
f"""\
while parsing the following warning configuration: while parsing the following warning configuration:
{arg} {arg}
@ -2091,23 +2089,20 @@ def parse_warning_filter(
This error occurred: This error occurred:
{{error}} {{error}}
""" """)
)
parts = arg.split(":") parts = arg.split(":")
if len(parts) > 5: if len(parts) > 5:
doc_url = ( doc_url = (
"https://docs.python.org/3/library/warnings.html#describing-warning-filters" "https://docs.python.org/3/library/warnings.html#describing-warning-filters"
) )
error = dedent( error = dedent(f"""\
f"""\
Too many fields ({len(parts)}), expected at most 5 separated by colons: Too many fields ({len(parts)}), expected at most 5 separated by colons:
action:message:category:module:line action:message:category:module:line
For more information please consult: {doc_url} For more information please consult: {doc_url}
""" """)
)
raise UsageError(error_template.format(error=error)) raise UsageError(error_template.format(error=error))
while len(parts) < 5: while len(parts) < 5:

View File

@ -16,7 +16,6 @@ from .exceptions import UsageError
import _pytest._io import _pytest._io
from _pytest.deprecated import check_ispytest from _pytest.deprecated import check_ispytest
FILE_OR_DIR = "file_or_dir" FILE_OR_DIR = "file_or_dir"
@ -186,10 +185,19 @@ class Parser:
self, self,
name: str, name: str,
help: str, help: str,
type: Literal[ type: (
"string", "paths", "pathlist", "args", "linelist", "bool", "int", "float" Literal[
] "string",
| None = None, "paths",
"pathlist",
"args",
"linelist",
"bool",
"int",
"float",
]
| None
) = None,
default: Any = NOT_SET, default: Any = NOT_SET,
*, *,
aliases: Sequence[str] = (), aliases: Sequence[str] = (),

View File

@ -12,7 +12,6 @@ from ..compat import LEGACY_PATH
from ..compat import legacy_path from ..compat import legacy_path
from ..deprecated import HOOK_LEGACY_PATH_ARG from ..deprecated import HOOK_LEGACY_PATH_ARG
# hookname: (Path, LEGACY_PATH) # hookname: (Path, LEGACY_PATH)
imply_paths_hooks: Mapping[str, tuple[str, str]] = { imply_paths_hooks: Mapping[str, tuple[str, str]] = {
"pytest_ignore_collect": ("collection_path", "path"), "pytest_ignore_collect": ("collection_path", "path"),

View File

@ -18,7 +18,6 @@ from _pytest.warning_types import PytestRemovedIn9Warning
from _pytest.warning_types import PytestRemovedIn10Warning from _pytest.warning_types import PytestRemovedIn10Warning
from _pytest.warning_types import UnformattedWarning from _pytest.warning_types import UnformattedWarning
# set of plugins which have been integrated into the core; we use this list to ignore # set of plugins which have been integrated into the core; we use this list to ignore
# them during registration to avoid conflicts # them during registration to avoid conflicts
DEPRECATED_EXTERNAL_PLUGINS = { DEPRECATED_EXTERNAL_PLUGINS = {

View File

@ -41,7 +41,6 @@ from _pytest.python import Module
from _pytest.python_api import approx from _pytest.python_api import approx
from _pytest.warning_types import PytestWarning from _pytest.warning_types import PytestWarning
if TYPE_CHECKING: if TYPE_CHECKING:
import doctest import doctest
@ -525,7 +524,7 @@ class DoctestModule(Module):
obj = inspect.unwrap(obj) obj = inspect.unwrap(obj)
# Type ignored because this is a private function. # Type ignored because this is a private function.
return super()._find_lineno( # type:ignore[misc] return super()._find_lineno( # type: ignore[misc]
obj, obj,
source_lines, source_lines,
) )

View File

@ -10,7 +10,6 @@ from _pytest.nodes import Item
from _pytest.stash import StashKey from _pytest.stash import StashKey
import pytest import pytest
fault_handler_original_stderr_fd_key = StashKey[int]() fault_handler_original_stderr_fd_key = StashKey[int]()
fault_handler_stderr_fd_key = StashKey[int]() fault_handler_stderr_fd_key = StashKey[int]()

View File

@ -70,7 +70,6 @@ from _pytest.scope import Scope
from _pytest.warning_types import PytestRemovedIn9Warning from _pytest.warning_types import PytestRemovedIn9Warning
from _pytest.warning_types import PytestWarning from _pytest.warning_types import PytestWarning
if sys.version_info < (3, 11): if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup from exceptiongroup import BaseExceptionGroup
@ -759,9 +758,9 @@ class SubRequest(FixtureRequest):
if node is None and scope is Scope.Class: if node is None and scope is Scope.Class:
# Fallback to function item itself. # Fallback to function item itself.
node = self._pyfuncitem node = self._pyfuncitem
assert node, ( assert (
f'Could not obtain a node for scope "{scope}" for function {self._pyfuncitem!r}' node
) ), f'Could not obtain a node for scope "{scope}" for function {self._pyfuncitem!r}'
return node return node
def _check_scope( def _check_scope(
@ -821,9 +820,9 @@ class FixtureLookupError(LookupError):
# new cases it might break. # new cases it might break.
# Add the assert to make it clearer to developer that this will fail, otherwise # Add the assert to make it clearer to developer that this will fail, otherwise
# it crashes because `fspath` does not get set due to `stack` being empty. # it crashes because `fspath` does not get set due to `stack` being empty.
assert self.msg is None or self.fixturestack, ( assert (
"formatrepr assumptions broken, rewrite it to handle it" self.msg is None or self.fixturestack
) ), "formatrepr assumptions broken, rewrite it to handle it"
if msg is not None: if msg is not None:
# The last fixture raise an error, let's present # The last fixture raise an error, let's present
# it at the requesting side. # it at the requesting side.
@ -1950,9 +1949,11 @@ def _show_fixtures_per_test(config: Config, session: Session) -> None:
if fixture_doc: if fixture_doc:
write_docstring( write_docstring(
tw, tw,
fixture_doc.split("\n\n", maxsplit=1)[0] (
if verbose <= 0 fixture_doc.split("\n\n", maxsplit=1)[0]
else fixture_doc, if verbose <= 0
else fixture_doc
),
) )
else: else:
tw.line(" no docstring available", red=True) tw.line(" no docstring available", red=True)

View File

@ -15,7 +15,6 @@ from pluggy import HookspecMarker
from .deprecated import HOOK_LEGACY_PATH_ARG from .deprecated import HOOK_LEGACY_PATH_ARG
if TYPE_CHECKING: if TYPE_CHECKING:
import pdb import pdb
from typing import Literal from typing import Literal
@ -1043,7 +1042,7 @@ def pytest_assertion_pass(item: Item, lineno: int, orig: str, expl: str) -> None
), ),
}, },
) )
def pytest_report_header( # type:ignore[empty-body] def pytest_report_header( # type: ignore[empty-body]
config: Config, start_path: Path, startdir: LEGACY_PATH config: Config, start_path: Path, startdir: LEGACY_PATH
) -> str | list[str]: ) -> str | list[str]:
"""Return a string or list of strings to be displayed as header info for terminal reporting. """Return a string or list of strings to be displayed as header info for terminal reporting.
@ -1079,7 +1078,7 @@ def pytest_report_header( # type:ignore[empty-body]
), ),
}, },
) )
def pytest_report_collectionfinish( # type:ignore[empty-body] def pytest_report_collectionfinish( # type: ignore[empty-body]
config: Config, config: Config,
start_path: Path, start_path: Path,
startdir: LEGACY_PATH, startdir: LEGACY_PATH,
@ -1118,7 +1117,7 @@ def pytest_report_collectionfinish( # type:ignore[empty-body]
@hookspec(firstresult=True) @hookspec(firstresult=True)
def pytest_report_teststatus( # type:ignore[empty-body] def pytest_report_teststatus( # type: ignore[empty-body]
report: CollectReport | TestReport, config: Config report: CollectReport | TestReport, config: Config
) -> TestShortLogReport | tuple[str, str, str | tuple[str, Mapping[str, bool]]]: ) -> TestShortLogReport | tuple[str, str, str | tuple[str, Mapping[str, bool]]]:
"""Return result-category, shortletter and verbose word for status """Return result-category, shortletter and verbose word for status
@ -1216,7 +1215,7 @@ def pytest_warning_recorded(
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
def pytest_markeval_namespace( # type:ignore[empty-body] def pytest_markeval_namespace( # type: ignore[empty-body]
config: Config, config: Config,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Called when constructing the globals dictionary used for """Called when constructing the globals dictionary used for

View File

@ -30,7 +30,6 @@ from _pytest.stash import StashKey
from _pytest.terminal import TerminalReporter from _pytest.terminal import TerminalReporter
import pytest import pytest
xml_key = StashKey["LogXML"]() xml_key = StashKey["LogXML"]()

View File

@ -33,7 +33,6 @@ from _pytest.pytester import RunResult
from _pytest.terminal import TerminalReporter from _pytest.terminal import TerminalReporter
from _pytest.tmpdir import TempPathFactory from _pytest.tmpdir import TempPathFactory
if TYPE_CHECKING: if TYPE_CHECKING:
import pexpect import pexpect

View File

@ -41,7 +41,6 @@ from _pytest.main import Session
from _pytest.stash import StashKey from _pytest.stash import StashKey
from _pytest.terminal import TerminalReporter from _pytest.terminal import TerminalReporter
if TYPE_CHECKING: if TYPE_CHECKING:
logging_StreamHandler = logging.StreamHandler[StringIO] logging_StreamHandler = logging.StreamHandler[StringIO]
else: else:

View File

@ -48,7 +48,6 @@ from _pytest.runner import collect_one_node
from _pytest.runner import SetupState from _pytest.runner import SetupState
from _pytest.warning_types import PytestWarning from _pytest.warning_types import PytestWarning
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self

View File

@ -27,7 +27,6 @@ from _pytest.config.argparsing import NOT_SET
from _pytest.config.argparsing import Parser from _pytest.config.argparsing import Parser
from _pytest.stash import StashKey from _pytest.stash import StashKey
if TYPE_CHECKING: if TYPE_CHECKING:
from _pytest.nodes import Item from _pytest.nodes import Item

View File

@ -38,7 +38,6 @@ from typing import NoReturn
from typing import overload from typing import overload
from typing import Protocol from typing import Protocol
__all__ = [ __all__ = [
"Expression", "Expression",
"ExpressionMatcher", "ExpressionMatcher",

View File

@ -31,7 +31,6 @@ from _pytest.raises import AbstractRaises
from _pytest.scope import _ScopeName from _pytest.scope import _ScopeName
from _pytest.warning_types import PytestUnknownMarkWarning from _pytest.warning_types import PytestUnknownMarkWarning
if TYPE_CHECKING: if TYPE_CHECKING:
from ..nodes import Node from ..nodes import Node
@ -500,10 +499,12 @@ if TYPE_CHECKING:
*conditions: str | bool, *conditions: str | bool,
reason: str = ..., reason: str = ...,
run: bool = ..., run: bool = ...,
raises: None raises: (
| type[BaseException] None
| tuple[type[BaseException], ...] | type[BaseException]
| AbstractRaises[BaseException] = ..., | tuple[type[BaseException], ...]
| AbstractRaises[BaseException]
) = ...,
strict: bool = ..., strict: bool = ...,
) -> MarkDecorator: ... ) -> MarkDecorator: ...
@ -514,9 +515,11 @@ if TYPE_CHECKING:
argvalues: Iterable[ParameterSet | Sequence[object] | object], argvalues: Iterable[ParameterSet | Sequence[object] | object],
*, *,
indirect: bool | Sequence[str] = ..., indirect: bool | Sequence[str] = ...,
ids: Iterable[None | str | float | int | bool] ids: (
| Callable[[Any], object | None] Iterable[None | str | float | int | bool]
| None = ..., | Callable[[Any], object | None]
| None
) = ...,
scope: _ScopeName | None = ..., scope: _ScopeName | None = ...,
) -> MarkDecorator: ... ) -> MarkDecorator: ...

View File

@ -21,7 +21,6 @@ from _pytest.deprecated import MONKEYPATCH_LEGACY_NAMESPACE_PACKAGES
from _pytest.fixtures import fixture from _pytest.fixtures import fixture
from _pytest.warning_types import PytestWarning from _pytest.warning_types import PytestWarning
RE_IMPORT_ERROR_NAME = re.compile(r"^No module named (.*)$") RE_IMPORT_ERROR_NAME = re.compile(r"^No module named (.*)$")

View File

@ -41,7 +41,6 @@ from _pytest.pathlib import absolutepath
from _pytest.stash import Stash from _pytest.stash import Stash
from _pytest.warning_types import PytestWarning from _pytest.warning_types import PytestWarning
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self

View File

@ -14,7 +14,6 @@ from _pytest.stash import StashKey
from _pytest.terminal import TerminalReporter from _pytest.terminal import TerminalReporter
import pytest import pytest
pastebinfile_key = StashKey[IO[bytes]]() pastebinfile_key = StashKey[IO[bytes]]()

View File

@ -37,7 +37,6 @@ from _pytest.compat import assert_never
from _pytest.outcomes import skip from _pytest.outcomes import skip
from _pytest.warning_types import PytestWarning from _pytest.warning_types import PytestWarning
if sys.version_info < (3, 11): if sys.version_info < (3, 11):
from importlib._bootstrap_external import _NamespaceLoader as NamespaceLoader from importlib._bootstrap_external import _NamespaceLoader as NamespaceLoader
else: else:
@ -73,8 +72,10 @@ def get_lock_path(path: _AnyPurePath) -> _AnyPurePath:
def on_rm_rf_error( def on_rm_rf_error(
func: Callable[..., Any] | None, func: Callable[..., Any] | None,
path: str, path: str,
excinfo: BaseException excinfo: (
| tuple[type[BaseException], BaseException, types.TracebackType | None], BaseException
| tuple[type[BaseException], BaseException, types.TracebackType | None]
),
*, *,
start_path: Path, start_path: Path,
) -> bool: ) -> bool:

View File

@ -67,7 +67,6 @@ from _pytest.reports import TestReport
from _pytest.tmpdir import TempPathFactory from _pytest.tmpdir import TempPathFactory
from _pytest.warning_types import PytestFDWarning from _pytest.warning_types import PytestFDWarning
if TYPE_CHECKING: if TYPE_CHECKING:
import pexpect import pexpect

View File

@ -77,7 +77,6 @@ from _pytest.stash import StashKey
from _pytest.warning_types import PytestCollectionWarning from _pytest.warning_types import PytestCollectionWarning
from _pytest.warning_types import PytestReturnNotNoneWarning from _pytest.warning_types import PytestReturnNotNoneWarning
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self
@ -959,9 +958,9 @@ class IdMaker:
new_id = f"{id}{suffix}{id_suffixes[id]}" new_id = f"{id}{suffix}{id_suffixes[id]}"
resolved_ids[index] = new_id resolved_ids[index] = new_id
id_suffixes[id] += 1 id_suffixes[id] += 1
assert len(resolved_ids) == len(set(resolved_ids)), ( assert len(resolved_ids) == len(
f"Internal error: {resolved_ids=}" set(resolved_ids)
) ), f"Internal error: {resolved_ids=}"
return resolved_ids return resolved_ids
def _strict_parametrization_ids_enabled(self) -> bool: def _strict_parametrization_ids_enabled(self) -> bool:

View File

@ -13,7 +13,6 @@ import sys
from typing import Any from typing import Any
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy import ndarray from numpy import ndarray

View File

@ -22,7 +22,6 @@ from _pytest._code.code import stringify_exception
from _pytest.outcomes import fail from _pytest.outcomes import fail
from _pytest.warning_types import PytestWarning from _pytest.warning_types import PytestWarning
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Callable from collections.abc import Callable
from collections.abc import Sequence from collections.abc import Sequence
@ -709,9 +708,9 @@ class RaisesExc(AbstractRaises[BaseExcT_co_default]):
fail(f"DID NOT RAISE {self.expected_exceptions[0]!r}") fail(f"DID NOT RAISE {self.expected_exceptions[0]!r}")
assert self.excinfo is not None, ( assert (
"Internal error - should have been constructed in __enter__" self.excinfo is not None
) ), "Internal error - should have been constructed in __enter__"
if not self.matches(exc_val): if not self.matches(exc_val):
if self._just_propagate: if self._just_propagate:
@ -928,9 +927,9 @@ class RaisesGroup(AbstractRaises[BaseExceptionGroup[BaseExcT_co]]):
@overload @overload
def __init__( def __init__(
self: RaisesGroup[BaseExcT_1 | BaseExceptionGroup[BaseExcT_2]], self: RaisesGroup[BaseExcT_1 | BaseExceptionGroup[BaseExcT_2]],
expected_exception: type[BaseExcT_1] expected_exception: (
| RaisesExc[BaseExcT_1] type[BaseExcT_1] | RaisesExc[BaseExcT_1] | RaisesGroup[BaseExcT_2]
| RaisesGroup[BaseExcT_2], ),
/, /,
*other_exceptions: type[BaseExcT_1] *other_exceptions: type[BaseExcT_1]
| RaisesExc[BaseExcT_1] | RaisesExc[BaseExcT_1]
@ -947,9 +946,9 @@ class RaisesGroup(AbstractRaises[BaseExceptionGroup[BaseExcT_co]]):
def __init__( def __init__(
self: RaisesGroup[ExcT_1 | BaseExcT_1 | BaseExceptionGroup[BaseExcT_2]], self: RaisesGroup[ExcT_1 | BaseExcT_1 | BaseExceptionGroup[BaseExcT_2]],
expected_exception: type[BaseExcT_1] expected_exception: (
| RaisesExc[BaseExcT_1] type[BaseExcT_1] | RaisesExc[BaseExcT_1] | RaisesGroup[BaseExcT_2]
| RaisesGroup[BaseExcT_2], ),
/, /,
*other_exceptions: type[BaseExcT_1] *other_exceptions: type[BaseExcT_1]
| RaisesExc[BaseExcT_1] | RaisesExc[BaseExcT_1]
@ -1416,9 +1415,9 @@ class RaisesGroup(AbstractRaises[BaseExceptionGroup[BaseExcT_co]]):
if exc_type is None: if exc_type is None:
fail(f"DID NOT RAISE any exception, expected `{self.expected_type()}`") fail(f"DID NOT RAISE any exception, expected `{self.expected_type()}`")
assert self.excinfo is not None, ( assert (
"Internal error - should have been constructed in __enter__" self.excinfo is not None
) ), "Internal error - should have been constructed in __enter__"
# group_str is the only thing that differs between RaisesExc and RaisesGroup... # group_str is the only thing that differs between RaisesExc and RaisesGroup...
# I might just scrap it? Or make it part of fail_reason # I might just scrap it? Or make it part of fail_reason

View File

@ -15,7 +15,6 @@ from typing import overload
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from typing import TypeVar from typing import TypeVar
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self
@ -26,7 +25,6 @@ from _pytest.fixtures import fixture
from _pytest.outcomes import Exit from _pytest.outcomes import Exit
from _pytest.outcomes import fail from _pytest.outcomes import fail
T = TypeVar("T") T = TypeVar("T")

View File

@ -35,7 +35,6 @@ from _pytest.nodes import Item
from _pytest.outcomes import fail from _pytest.outcomes import fail
from _pytest.outcomes import skip from _pytest.outcomes import skip
if sys.version_info < (3, 11): if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup from exceptiongroup import BaseExceptionGroup
@ -274,9 +273,9 @@ def _format_exception_group_all_skipped_longrepr(
excinfo: ExceptionInfo[BaseExceptionGroup[BaseException | BaseExceptionGroup]], excinfo: ExceptionInfo[BaseExceptionGroup[BaseException | BaseExceptionGroup]],
) -> tuple[str, int, str]: ) -> tuple[str, int, str]:
r = excinfo._getreprcrash() r = excinfo._getreprcrash()
assert r is not None, ( assert (
"There should always be a traceback entry for skipping a test." r is not None
) ), "There should always be a traceback entry for skipping a test."
if all( if all(
getattr(skip, "_use_item_location", False) for skip in excinfo.value.exceptions getattr(skip, "_use_item_location", False) for skip in excinfo.value.exceptions
): ):
@ -321,11 +320,13 @@ class TestReport(BaseReport):
location: tuple[str, int | None, str], location: tuple[str, int | None, str],
keywords: Mapping[str, Any], keywords: Mapping[str, Any],
outcome: Literal["passed", "failed", "skipped"], outcome: Literal["passed", "failed", "skipped"],
longrepr: None longrepr: (
| ExceptionInfo[BaseException] None
| tuple[str, int, str] | ExceptionInfo[BaseException]
| str | tuple[str, int, str]
| TerminalRepr, | str
| TerminalRepr
),
when: Literal["setup", "call", "teardown"], when: Literal["setup", "call", "teardown"],
sections: Iterable[tuple[str, str]] = (), sections: Iterable[tuple[str, str]] = (),
duration: float = 0, duration: float = 0,
@ -412,9 +413,9 @@ class TestReport(BaseReport):
elif isinstance(excinfo.value, skip.Exception): elif isinstance(excinfo.value, skip.Exception):
outcome = "skipped" outcome = "skipped"
r = excinfo._getreprcrash() r = excinfo._getreprcrash()
assert r is not None, ( assert (
"There should always be a traceback entry for skipping a test." r is not None
) ), "There should always be a traceback entry for skipping a test."
if excinfo.value._use_item_location: if excinfo.value._use_item_location:
path, line = item.reportinfo()[:2] path, line = item.reportinfo()[:2]
assert line is not None assert line is not None
@ -466,11 +467,13 @@ class CollectReport(BaseReport):
self, self,
nodeid: str, nodeid: str,
outcome: Literal["passed", "failed", "skipped"], outcome: Literal["passed", "failed", "skipped"],
longrepr: None longrepr: (
| ExceptionInfo[BaseException] None
| tuple[str, int, str] | ExceptionInfo[BaseException]
| str | tuple[str, int, str]
| TerminalRepr, | str
| TerminalRepr
),
result: list[Item | Collector] | None, result: list[Item | Collector] | None,
sections: Iterable[tuple[str, str]] = (), sections: Iterable[tuple[str, str]] = (),
**extra, **extra,
@ -496,7 +499,7 @@ class CollectReport(BaseReport):
self.__dict__.update(extra) self.__dict__.update(extra)
@property @property
def location( # type:ignore[override] def location( # type: ignore[override]
self, self,
) -> tuple[str, int | None, str] | None: ) -> tuple[str, int | None, str] | None:
return (self.fspath, None, self.fspath) return (self.fspath, None, self.fspath)

View File

@ -36,7 +36,6 @@ from _pytest.outcomes import OutcomeException
from _pytest.outcomes import Skipped from _pytest.outcomes import Skipped
from _pytest.outcomes import TEST_OUTCOME from _pytest.outcomes import TEST_OUTCOME
if sys.version_info < (3, 11): if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup from exceptiongroup import BaseExceptionGroup
@ -172,7 +171,7 @@ def pytest_runtest_call(item: Item) -> None:
del sys.last_value del sys.last_value
del sys.last_traceback del sys.last_traceback
if sys.version_info >= (3, 12, 0): if sys.version_info >= (3, 12, 0):
del sys.last_exc # type:ignore[attr-defined] del sys.last_exc # type: ignore[attr-defined]
except AttributeError: except AttributeError:
pass pass
try: try:
@ -182,7 +181,7 @@ def pytest_runtest_call(item: Item) -> None:
sys.last_type = type(e) sys.last_type = type(e)
sys.last_value = e sys.last_value = e
if sys.version_info >= (3, 12, 0): if sys.version_info >= (3, 12, 0):
sys.last_exc = e # type:ignore[attr-defined] sys.last_exc = e # type: ignore[attr-defined]
assert e.__traceback__ is not None assert e.__traceback__ is not None
# Skip *this* frame # Skip *this* frame
sys.last_traceback = e.__traceback__.tb_next sys.last_traceback = e.__traceback__.tb_next

View File

@ -14,7 +14,6 @@ from enum import Enum
from functools import total_ordering from functools import total_ordering
from typing import Literal from typing import Literal
_ScopeName = Literal["session", "package", "module", "class", "function"] _ScopeName = Literal["session", "package", "module", "class", "function"]

View File

@ -5,7 +5,6 @@ from typing import cast
from typing import Generic from typing import Generic
from typing import TypeVar from typing import TypeVar
__all__ = ["Stash", "StashKey"] __all__ = ["Stash", "StashKey"]

View File

@ -13,7 +13,6 @@ from _pytest.config.argparsing import Parser
from _pytest.main import Session from _pytest.main import Session
from _pytest.reports import TestReport from _pytest.reports import TestReport
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self

View File

@ -38,7 +38,6 @@ from _pytest.runner import check_interactive_exception
from _pytest.runner import get_reraise_exceptions from _pytest.runner import get_reraise_exceptions
from _pytest.stash import StashKey from _pytest.stash import StashKey
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Self from typing_extensions import Self

View File

@ -53,7 +53,6 @@ from _pytest.reports import BaseReport
from _pytest.reports import CollectReport from _pytest.reports import CollectReport
from _pytest.reports import TestReport from _pytest.reports import TestReport
if TYPE_CHECKING: if TYPE_CHECKING:
from _pytest.main import Session from _pytest.main import Session
@ -1231,7 +1230,7 @@ class TerminalReporter:
return return
session_duration = self._session_start.elapsed() session_duration = self._session_start.elapsed()
(parts, main_color) = self.build_summary_stats_line() parts, main_color = self.build_summary_stats_line()
line_parts = [] line_parts = []
display_sep = self.verbosity >= 0 display_sep = self.verbosity >= 0
@ -1460,7 +1459,9 @@ class TerminalReporter:
elif deselected == 0: elif deselected == 0:
main_color = "green" main_color = "green"
collected_output = "%d %s collected" % pluralize(self._numcollected, "test") # noqa: UP031 collected_output = "%d %s collected" % pluralize(
self._numcollected, "test"
) # noqa: UP031
parts = [(collected_output, {main_color: True})] parts = [(collected_output, {main_color: True})]
else: else:
all_tests_were_deselected = self._numcollected == deselected all_tests_were_deselected = self._numcollected == deselected
@ -1476,7 +1477,9 @@ class TerminalReporter:
if errors: if errors:
main_color = _color_for_type["error"] main_color = _color_for_type["error"]
parts += [("%d %s" % pluralize(errors, "error"), {main_color: True})] # noqa: UP031 parts += [
("%d %s" % pluralize(errors, "error"), {main_color: True})
] # noqa: UP031
return parts, main_color return parts, main_color

View File

@ -16,7 +16,6 @@ from _pytest.stash import StashKey
from _pytest.tracemalloc import tracemalloc_message from _pytest.tracemalloc import tracemalloc_message
import pytest import pytest
if TYPE_CHECKING: if TYPE_CHECKING:
pass pass

View File

@ -16,7 +16,6 @@ from time import sleep
from time import time from time import time
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from pytest import MonkeyPatch from pytest import MonkeyPatch
@ -73,7 +72,8 @@ class MockTiming:
uses `_pytest.timing` functions. uses `_pytest.timing` functions.
Time is static, and only advances through `sleep` calls, thus tests might sleep over large Time is static, and only advances through `sleep` calls, thus tests might sleep over large
numbers and obtain accurate time() calls at the end, making tests reliable and instant.""" numbers and obtain accurate time() calls at the end, making tests reliable and instant.
"""
_current_time: float = datetime(2020, 5, 22, 14, 20, 50).timestamp() _current_time: float = datetime(2020, 5, 22, 14, 20, 50).timestamp()

View File

@ -32,7 +32,6 @@ from _pytest.nodes import Item
from _pytest.reports import TestReport from _pytest.reports import TestReport
from _pytest.stash import StashKey from _pytest.stash import StashKey
tmppath_result_key = StashKey[dict[str, bool]]() tmppath_result_key = StashKey[dict[str, bool]]()
RetentionType = Literal["all", "failed", "none"] RetentionType = Literal["all", "failed", "none"]

View File

@ -38,7 +38,6 @@ from _pytest.runner import check_interactive_exception
from _pytest.subtests import SubtestContext from _pytest.subtests import SubtestContext
from _pytest.subtests import SubtestReport from _pytest.subtests import SubtestReport
if sys.version_info[:2] < (3, 11): if sys.version_info[:2] < (3, 11):
from exceptiongroup import ExceptionGroup from exceptiongroup import ExceptionGroup
@ -405,9 +404,11 @@ class TestCaseFunction(Function):
self, self,
test_case: Any, test_case: Any,
test: TestCase, test: TestCase,
exc_info: ExceptionInfo[BaseException] exc_info: (
| tuple[type[BaseException], BaseException, TracebackType] ExceptionInfo[BaseException]
| None, | tuple[type[BaseException], BaseException, TracebackType]
| None
),
) -> None: ) -> None:
exception_info: ExceptionInfo[BaseException] | None exception_info: ExceptionInfo[BaseException] | None
match exc_info: match exc_info:
@ -459,9 +460,10 @@ class TestCaseFunction(Function):
"""Compute or obtain the cached values for subtest errors and non-subtest skips.""" """Compute or obtain the cached values for subtest errors and non-subtest skips."""
from unittest.case import _SubTest # type: ignore[attr-defined] from unittest.case import _SubTest # type: ignore[attr-defined]
assert sys.version_info < (3, 11), ( assert sys.version_info < (
"This workaround only should be used in Python 3.10" 3,
) 11,
), "This workaround only should be used in Python 3.10"
if self._cached_errors_and_skips is not None: if self._cached_errors_and_skips is not None:
return self._cached_errors_and_skips return self._cached_errors_and_skips
@ -600,7 +602,7 @@ def _handle_twisted_exc_info(
# Unfortunately, because we cannot import `twisted.python.failure` at the top of the file # Unfortunately, because we cannot import `twisted.python.failure` at the top of the file
# and use it in the signature, we need to use `type:ignore` here because we cannot narrow # and use it in the signature, we need to use `type:ignore` here because we cannot narrow
# the type properly in the `if` statement above. # the type properly in the `if` statement above.
return rawexcinfo # type:ignore[return-value] return rawexcinfo # type: ignore[return-value]
elif twisted_version is TwistedVersion.Version24: elif twisted_version is TwistedVersion.Version24:
# Twisted calls addError() passing its own classes (like `twisted.python.Failure`), which violates # Twisted calls addError() passing its own classes (like `twisted.python.Failure`), which violates
# the `addError()` signature, so we extract the original `sys.exc_info()` tuple which is stored # the `addError()` signature, so we extract the original `sys.exc_info()` tuple which is stored
@ -609,8 +611,8 @@ def _handle_twisted_exc_info(
saved_exc_info = getattr(rawexcinfo, TWISTED_RAW_EXCINFO_ATTR) saved_exc_info = getattr(rawexcinfo, TWISTED_RAW_EXCINFO_ATTR)
# Delete the attribute from the original object to avoid leaks. # Delete the attribute from the original object to avoid leaks.
delattr(rawexcinfo, TWISTED_RAW_EXCINFO_ATTR) delattr(rawexcinfo, TWISTED_RAW_EXCINFO_ATTR)
return saved_exc_info # type:ignore[no-any-return] return saved_exc_info # type: ignore[no-any-return]
return rawexcinfo # type:ignore[return-value] return rawexcinfo # type: ignore[return-value]
elif twisted_version is TwistedVersion.Version25: elif twisted_version is TwistedVersion.Version25:
if isinstance(rawexcinfo, BaseException): if isinstance(rawexcinfo, BaseException):
import twisted.python.failure import twisted.python.failure
@ -621,7 +623,7 @@ def _handle_twisted_exc_info(
tb = sys.exc_info()[2] tb = sys.exc_info()[2]
return type(rawexcinfo.value), rawexcinfo.value, tb return type(rawexcinfo.value), rawexcinfo.value, tb
return rawexcinfo # type:ignore[return-value] return rawexcinfo # type: ignore[return-value]
else: else:
# Ideally we would use assert_never() here, but it is not available in all Python versions # Ideally we would use assert_never() here, but it is not available in all Python versions
# we support, plus we do not require `type_extensions` currently. # we support, plus we do not require `type_extensions` currently.

View File

@ -16,7 +16,6 @@ from _pytest.stash import StashKey
from _pytest.tracemalloc import tracemalloc_message from _pytest.tracemalloc import tracemalloc_message
import pytest import pytest
if TYPE_CHECKING: if TYPE_CHECKING:
pass pass

View File

@ -7,7 +7,7 @@ import yaml
# in some circumstances, the yaml module we imoprted may be from a different version, so we need # in some circumstances, the yaml module we imoprted may be from a different version, so we need
# to tread carefully when poking at it here (it may not have the attributes we expect) # to tread carefully when poking at it here (it may not have the attributes we expect)
if not getattr(yaml, '__with_libyaml__', False): if not getattr(yaml, "__with_libyaml__", False):
from sys import version_info from sys import version_info
exc = ModuleNotFoundError if version_info >= (3, 6) else ImportError exc = ModuleNotFoundError if version_info >= (3, 6) else ImportError
@ -15,19 +15,20 @@ if not getattr(yaml, '__with_libyaml__', False):
else: else:
from yaml._yaml import * from yaml._yaml import *
import warnings import warnings
warnings.warn( warnings.warn(
'The _yaml extension module is now located at yaml._yaml' "The _yaml extension module is now located at yaml._yaml"
' and its location is subject to change. To use the' " and its location is subject to change. To use the"
' LibYAML-based parser and emitter, import from `yaml`:' " LibYAML-based parser and emitter, import from `yaml`:"
' `from yaml import CLoader as Loader, CDumper as Dumper`.', " `from yaml import CLoader as Loader, CDumper as Dumper`.",
DeprecationWarning DeprecationWarning,
) )
del warnings del warnings
# Don't `del yaml` here because yaml is actually an existing # Don't `del yaml` here because yaml is actually an existing
# namespace member of _yaml. # namespace member of _yaml.
__name__ = '_yaml' __name__ = "_yaml"
# If the module is top-level (i.e. not a part of any specific package) # If the module is top-level (i.e. not a part of any specific package)
# then the attribute should be set to ''. # then the attribute should be set to ''.
# https://docs.python.org/3.8/library/types.html # https://docs.python.org/3.8/library/types.html
__package__ = '' __package__ = ""

View File

@ -3,4 +3,3 @@ Generator: hatchling 1.28.0
Root-Is-Purelib: false Root-Is-Purelib: false
Tag: cp312-cp312-macosx_11_0_arm64 Tag: cp312-cp312-macosx_11_0_arm64
Generator: delocate 0.13.0 Generator: delocate 0.13.0

View File

@ -13,30 +13,36 @@ from black.mode import Mode
from black.output import out from black.output import out
from black.report import NothingChanged from black.report import NothingChanged
TRANSFORMED_MAGICS = frozenset(( TRANSFORMED_MAGICS = frozenset(
"get_ipython().run_cell_magic", (
"get_ipython().system", "get_ipython().run_cell_magic",
"get_ipython().getoutput", "get_ipython().system",
"get_ipython().run_line_magic", "get_ipython().getoutput",
)) "get_ipython().run_line_magic",
TOKENS_TO_IGNORE = frozenset(( )
"ENDMARKER", )
"NL", TOKENS_TO_IGNORE = frozenset(
"NEWLINE", (
"COMMENT", "ENDMARKER",
"DEDENT", "NL",
"UNIMPORTANT_WS", "NEWLINE",
"ESCAPED_NL", "COMMENT",
)) "DEDENT",
PYTHON_CELL_MAGICS = frozenset(( "UNIMPORTANT_WS",
"capture", "ESCAPED_NL",
"prun", )
"pypy", )
"python", PYTHON_CELL_MAGICS = frozenset(
"python3", (
"time", "capture",
"timeit", "prun",
)) "pypy",
"python",
"python3",
"time",
"timeit",
)
)
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)

View File

@ -309,16 +309,18 @@ class Mode:
return ".".join(parts) return ".".join(parts)
def __hash__(self) -> int: def __hash__(self) -> int:
return hash(( return hash(
frozenset(self.target_versions), (
self.line_length, frozenset(self.target_versions),
self.string_normalization, self.line_length,
self.is_pyi, self.string_normalization,
self.is_ipynb, self.is_pyi,
self.skip_source_first_line, self.is_ipynb,
self.magic_trailing_comma, self.skip_source_first_line,
frozenset(self.python_cell_magics), self.magic_trailing_comma,
self.preview, frozenset(self.python_cell_magics),
self.unstable, self.preview,
frozenset(self.enabled_features), self.unstable,
)) frozenset(self.enabled_features),
)
)

View File

@ -3,4 +3,3 @@ Generator: setuptools (75.5.0)
Root-Is-Purelib: true Root-Is-Purelib: true
Tag: py2-none-any Tag: py2-none-any
Tag: py3-none-any Tag: py3-none-any

View File

@ -14,17 +14,17 @@ class ValidationError(ValueError):
self.ctx = ctx self.ctx = ctx
def __str__(self): def __str__(self):
out = '\n' out = "\n"
err = self err = self
while err.ctx is not None: while err.ctx is not None:
out += f'==> {err.ctx}\n' out += f"==> {err.ctx}\n"
err = err.error_msg err = err.error_msg
out += f'=====> {err.error_msg}' out += f"=====> {err.error_msg}"
return out return out
MISSING = collections.namedtuple('Missing', ())() MISSING = collections.namedtuple("Missing", ())()
type(MISSING).__repr__ = lambda self: 'MISSING' type(MISSING).__repr__ = lambda self: "MISSING"
@contextlib.contextmanager @contextlib.contextmanager
@ -52,7 +52,7 @@ def _dct_noop(self, dct):
def _check_optional(self, dct): def _check_optional(self, dct):
if self.key not in dct: if self.key not in dct:
return return
with validate_context(f'At key: {self.key}'): with validate_context(f"At key: {self.key}"):
self.check_fn(dct[self.key]) self.check_fn(dct[self.key])
@ -67,7 +67,7 @@ def _remove_default_optional(self, dct):
def _require_key(self, dct): def _require_key(self, dct):
if self.key not in dct: if self.key not in dct:
raise ValidationError(f'Missing required key: {self.key}') raise ValidationError(f"Missing required key: {self.key}")
def _check_required(self, dct): def _check_required(self, dct):
@ -79,6 +79,7 @@ def _check_required(self, dct):
def _check_fn_recurse(self): def _check_fn_recurse(self):
def check_fn(val): def check_fn(val):
validate(val, self.schema) validate(val, self.schema)
return check_fn return check_fn
@ -106,18 +107,16 @@ def _get_check_conditional(inner):
def _check_conditional(self, dct): def _check_conditional(self, dct):
if dct.get(self.condition_key, MISSING) == self.condition_value: if dct.get(self.condition_key, MISSING) == self.condition_value:
inner(self, dct) inner(self, dct)
elif ( elif self.condition_key in dct and self.ensure_absent and self.key in dct:
self.condition_key in dct and if hasattr(self.condition_value, "describe_opposite"):
self.ensure_absent and self.key in dct
):
if hasattr(self.condition_value, 'describe_opposite'):
explanation = self.condition_value.describe_opposite() explanation = self.condition_value.describe_opposite()
else: else:
explanation = f'is not {self.condition_value!r}' explanation = f"is not {self.condition_value!r}"
raise ValidationError( raise ValidationError(
f'Expected {self.key} to be absent when {self.condition_key} ' f"Expected {self.key} to be absent when {self.condition_key} "
f'{explanation}, found {self.key}: {dct[self.key]!r}', f"{explanation}, found {self.key}: {dct[self.key]!r}",
) )
return _check_conditional return _check_conditional
@ -144,11 +143,11 @@ def _remove_default_conditional_recurse(self, dct):
def _no_additional_keys_check(self, dct): def _no_additional_keys_check(self, dct):
extra = sorted(set(dct) - set(self.keys)) extra = sorted(set(dct) - set(self.keys))
if extra: if extra:
extra_s = ', '.join(str(x) for x in extra) extra_s = ", ".join(str(x) for x in extra)
keys_s = ', '.join(str(x) for x in self.keys) keys_s = ", ".join(str(x) for x in self.keys)
raise ValidationError( raise ValidationError(
f'Additional keys found: {extra_s}. ' f"Additional keys found: {extra_s}. "
f'Only these keys are allowed: {keys_s}', f"Only these keys are allowed: {keys_s}",
) )
@ -158,45 +157,51 @@ def _warn_additional_keys_check(self, dct):
self.callback(extra, self.keys, dct) self.callback(extra, self.keys, dct)
Required = collections.namedtuple('Required', ('key', 'check_fn')) Required = collections.namedtuple("Required", ("key", "check_fn"))
Required.check = _check_required Required.check = _check_required
Required.apply_default = _dct_noop Required.apply_default = _dct_noop
Required.remove_default = _dct_noop Required.remove_default = _dct_noop
RequiredRecurse = collections.namedtuple('RequiredRecurse', ('key', 'schema')) RequiredRecurse = collections.namedtuple("RequiredRecurse", ("key", "schema"))
RequiredRecurse.check = _check_required RequiredRecurse.check = _check_required
RequiredRecurse.check_fn = _check_fn_recurse RequiredRecurse.check_fn = _check_fn_recurse
RequiredRecurse.apply_default = _apply_default_required_recurse RequiredRecurse.apply_default = _apply_default_required_recurse
RequiredRecurse.remove_default = _remove_default_required_recurse RequiredRecurse.remove_default = _remove_default_required_recurse
Optional = collections.namedtuple('Optional', ('key', 'check_fn', 'default')) Optional = collections.namedtuple("Optional", ("key", "check_fn", "default"))
Optional.check = _check_optional Optional.check = _check_optional
Optional.apply_default = _apply_default_optional Optional.apply_default = _apply_default_optional
Optional.remove_default = _remove_default_optional Optional.remove_default = _remove_default_optional
OptionalRecurse = collections.namedtuple( OptionalRecurse = collections.namedtuple(
'OptionalRecurse', ('key', 'schema', 'default'), "OptionalRecurse",
("key", "schema", "default"),
) )
OptionalRecurse.check = _check_optional OptionalRecurse.check = _check_optional
OptionalRecurse.check_fn = _check_fn_recurse OptionalRecurse.check_fn = _check_fn_recurse
OptionalRecurse.apply_default = _apply_default_optional_recurse OptionalRecurse.apply_default = _apply_default_optional_recurse
OptionalRecurse.remove_default = _remove_default_optional_recurse OptionalRecurse.remove_default = _remove_default_optional_recurse
OptionalNoDefault = collections.namedtuple( OptionalNoDefault = collections.namedtuple(
'OptionalNoDefault', ('key', 'check_fn'), "OptionalNoDefault",
("key", "check_fn"),
) )
OptionalNoDefault.check = _check_optional OptionalNoDefault.check = _check_optional
OptionalNoDefault.apply_default = _dct_noop OptionalNoDefault.apply_default = _dct_noop
OptionalNoDefault.remove_default = _dct_noop OptionalNoDefault.remove_default = _dct_noop
Conditional = collections.namedtuple( Conditional = collections.namedtuple(
'Conditional', "Conditional",
('key', 'check_fn', 'condition_key', 'condition_value', 'ensure_absent'), ("key", "check_fn", "condition_key", "condition_value", "ensure_absent"),
) )
Conditional.__new__.__defaults__ = (False,) Conditional.__new__.__defaults__ = (False,)
Conditional.check = _get_check_conditional(_check_required) Conditional.check = _get_check_conditional(_check_required)
Conditional.apply_default = _dct_noop Conditional.apply_default = _dct_noop
Conditional.remove_default = _dct_noop Conditional.remove_default = _dct_noop
ConditionalOptional = collections.namedtuple( ConditionalOptional = collections.namedtuple(
'ConditionalOptional', "ConditionalOptional",
( (
'key', 'check_fn', 'default', 'condition_key', 'condition_value', "key",
'ensure_absent', "check_fn",
"default",
"condition_key",
"condition_value",
"ensure_absent",
), ),
) )
ConditionalOptional.__new__.__defaults__ = (False,) ConditionalOptional.__new__.__defaults__ = (False,)
@ -204,27 +209,28 @@ ConditionalOptional.check = _get_check_conditional(_check_optional)
ConditionalOptional.apply_default = _apply_default_conditional_optional ConditionalOptional.apply_default = _apply_default_conditional_optional
ConditionalOptional.remove_default = _remove_default_conditional_optional ConditionalOptional.remove_default = _remove_default_conditional_optional
ConditionalRecurse = collections.namedtuple( ConditionalRecurse = collections.namedtuple(
'ConditionalRecurse', "ConditionalRecurse",
('key', 'schema', 'condition_key', 'condition_value', 'ensure_absent'), ("key", "schema", "condition_key", "condition_value", "ensure_absent"),
) )
ConditionalRecurse.__new__.__defaults__ = (False,) ConditionalRecurse.__new__.__defaults__ = (False,)
ConditionalRecurse.check = _get_check_conditional(_check_required) ConditionalRecurse.check = _get_check_conditional(_check_required)
ConditionalRecurse.check_fn = _check_fn_recurse ConditionalRecurse.check_fn = _check_fn_recurse
ConditionalRecurse.apply_default = _apply_default_conditional_recurse ConditionalRecurse.apply_default = _apply_default_conditional_recurse
ConditionalRecurse.remove_default = _remove_default_conditional_recurse ConditionalRecurse.remove_default = _remove_default_conditional_recurse
NoAdditionalKeys = collections.namedtuple('NoAdditionalKeys', ('keys',)) NoAdditionalKeys = collections.namedtuple("NoAdditionalKeys", ("keys",))
NoAdditionalKeys.check = _no_additional_keys_check NoAdditionalKeys.check = _no_additional_keys_check
NoAdditionalKeys.apply_default = _dct_noop NoAdditionalKeys.apply_default = _dct_noop
NoAdditionalKeys.remove_default = _dct_noop NoAdditionalKeys.remove_default = _dct_noop
WarnAdditionalKeys = collections.namedtuple( WarnAdditionalKeys = collections.namedtuple(
'WarnAdditionalKeys', ('keys', 'callback'), "WarnAdditionalKeys",
("keys", "callback"),
) )
WarnAdditionalKeys.check = _warn_additional_keys_check WarnAdditionalKeys.check = _warn_additional_keys_check
WarnAdditionalKeys.apply_default = _dct_noop WarnAdditionalKeys.apply_default = _dct_noop
WarnAdditionalKeys.remove_default = _dct_noop WarnAdditionalKeys.remove_default = _dct_noop
class Map(collections.namedtuple('Map', ('object_name', 'id_key', 'items'))): class Map(collections.namedtuple("Map", ("object_name", "id_key", "items"))):
__slots__ = () __slots__ = ()
def __new__(cls, object_name, id_key, *items): def __new__(cls, object_name, id_key, *items):
@ -233,14 +239,13 @@ class Map(collections.namedtuple('Map', ('object_name', 'id_key', 'items'))):
def check(self, v): def check(self, v):
if not isinstance(v, dict): if not isinstance(v, dict):
raise ValidationError( raise ValidationError(
f'Expected a {self.object_name} map but got a ' f"Expected a {self.object_name} map but got a " f"{type(v).__name__}",
f'{type(v).__name__}',
) )
if self.id_key is None: if self.id_key is None:
context = f'At {self.object_name}()' context = f"At {self.object_name}()"
else: else:
key_v_s = v.get(self.id_key, MISSING) key_v_s = v.get(self.id_key, MISSING)
context = f'At {self.object_name}({self.id_key}={key_v_s!r})' context = f"At {self.object_name}({self.id_key}={key_v_s!r})"
with validate_context(context): with validate_context(context):
for item in self.items: for item in self.items:
item.check(v) item.check(v)
@ -259,40 +264,33 @@ class Map(collections.namedtuple('Map', ('object_name', 'id_key', 'items'))):
class KeyValueMap( class KeyValueMap(
collections.namedtuple( collections.namedtuple(
'KeyValueMap', "KeyValueMap",
('object_name', 'check_key_fn', 'value_schema'), ("object_name", "check_key_fn", "value_schema"),
), ),
): ):
__slots__ = () __slots__ = ()
def check(self, v): def check(self, v):
if not isinstance(v, dict): if not isinstance(v, dict):
raise ValidationError( raise ValidationError(
f'Expected a {self.object_name} map but got a ' f"Expected a {self.object_name} map but got a " f"{type(v).__name__}",
f'{type(v).__name__}',
) )
with validate_context(f'At {self.object_name}()'): with validate_context(f"At {self.object_name}()"):
for k, val in v.items(): for k, val in v.items():
with validate_context(f'For key: {k}'): with validate_context(f"For key: {k}"):
self.check_key_fn(k) self.check_key_fn(k)
with validate_context(f'At key: {k}'): with validate_context(f"At key: {k}"):
validate(val, self.value_schema) validate(val, self.value_schema)
def apply_defaults(self, v): def apply_defaults(self, v):
return { return {k: apply_defaults(val, self.value_schema) for k, val in v.items()}
k: apply_defaults(val, self.value_schema)
for k, val in v.items()
}
def remove_defaults(self, v): def remove_defaults(self, v):
return { return {k: remove_defaults(val, self.value_schema) for k, val in v.items()}
k: remove_defaults(val, self.value_schema)
for k, val in v.items()
}
class Array(collections.namedtuple('Array', ('of', 'allow_empty'))): class Array(collections.namedtuple("Array", ("of", "allow_empty"))):
__slots__ = () __slots__ = ()
def __new__(cls, of, allow_empty=True): def __new__(cls, of, allow_empty=True):
@ -314,37 +312,37 @@ class Array(collections.namedtuple('Array', ('of', 'allow_empty'))):
return [remove_defaults(val, self.of) for val in v] return [remove_defaults(val, self.of) for val in v]
class Not(collections.namedtuple('Not', ('val',))): class Not(collections.namedtuple("Not", ("val",))):
__slots__ = () __slots__ = ()
def describe_opposite(self): def describe_opposite(self):
return f'is {self.val!r}' return f"is {self.val!r}"
def __eq__(self, other): def __eq__(self, other):
return other is not MISSING and other != self.val return other is not MISSING and other != self.val
class NotIn(collections.namedtuple('NotIn', ('values',))): class NotIn(collections.namedtuple("NotIn", ("values",))):
__slots__ = () __slots__ = ()
def __new__(cls, *values): def __new__(cls, *values):
return super().__new__(cls, values=values) return super().__new__(cls, values=values)
def describe_opposite(self): def describe_opposite(self):
return f'is any of {self.values!r}' return f"is any of {self.values!r}"
def __eq__(self, other): def __eq__(self, other):
return other is not MISSING and other not in self.values return other is not MISSING and other not in self.values
class In(collections.namedtuple('In', ('values',))): class In(collections.namedtuple("In", ("values",))):
__slots__ = () __slots__ = ()
def __new__(cls, *values): def __new__(cls, *values):
return super().__new__(cls, values=values) return super().__new__(cls, values=values)
def describe_opposite(self): def describe_opposite(self):
return f'is not any of {self.values!r}' return f"is not any of {self.values!r}"
def __eq__(self, other): def __eq__(self, other):
return other is not MISSING and other in self.values return other is not MISSING and other in self.values
@ -359,25 +357,27 @@ def check_type(tp, typename=None):
if not isinstance(v, tp): if not isinstance(v, tp):
typename_s = typename or tp.__name__ typename_s = typename or tp.__name__
raise ValidationError( raise ValidationError(
f'Expected {typename_s} got {type(v).__name__}', f"Expected {typename_s} got {type(v).__name__}",
) )
return check_type_fn return check_type_fn
check_bool = check_type(bool) check_bool = check_type(bool)
check_bytes = check_type(bytes) check_bytes = check_type(bytes)
check_int = check_type(int) check_int = check_type(int)
check_string = check_type(str, typename='string') check_string = check_type(str, typename="string")
check_text = check_type(str, typename='text') check_text = check_type(str, typename="text")
def check_one_of(possible): def check_one_of(possible):
def check_one_of_fn(v): def check_one_of_fn(v):
if v not in possible: if v not in possible:
possible_s = ', '.join(str(x) for x in sorted(possible)) possible_s = ", ".join(str(x) for x in sorted(possible))
raise ValidationError( raise ValidationError(
f'Expected one of {possible_s} but got: {v!r}', f"Expected one of {possible_s} but got: {v!r}",
) )
return check_one_of_fn return check_one_of_fn
@ -385,19 +385,20 @@ def check_regex(v):
try: try:
re.compile(v) re.compile(v)
except re.error: except re.error:
raise ValidationError(f'{v!r} is not a valid python regex') raise ValidationError(f"{v!r} is not a valid python regex")
def check_array(inner_check): def check_array(inner_check):
def check_array_fn(v): def check_array_fn(v):
if not isinstance(v, (list, tuple)): if not isinstance(v, (list, tuple)):
raise ValidationError( raise ValidationError(
f'Expected array but got {type(v).__name__!r}', f"Expected array but got {type(v).__name__!r}",
) )
for i, val in enumerate(v): for i, val in enumerate(v):
with validate_context(f'At index {i}'): with validate_context(f"At index {i}"):
inner_check(val) inner_check(val)
return check_array_fn return check_array_fn
@ -405,6 +406,7 @@ def check_and(*fns):
def check(v): def check(v):
for fn in fns: for fn in fns:
fn(v) fn(v)
return check return check
@ -422,21 +424,21 @@ def remove_defaults(v, schema):
def load_from_filename( def load_from_filename(
filename, filename,
schema, schema,
load_strategy, load_strategy,
exc_tp=ValidationError, exc_tp=ValidationError,
*, *,
display_filename=None, display_filename=None,
): ):
display_filename = display_filename or filename display_filename = display_filename or filename
with reraise_as(exc_tp): with reraise_as(exc_tp):
if not os.path.isfile(filename): if not os.path.isfile(filename):
raise ValidationError(f'{display_filename} is not a file') raise ValidationError(f"{display_filename} is not a file")
with validate_context(f'File {display_filename}'): with validate_context(f"File {display_filename}"):
try: try:
with open(filename, encoding='utf-8') as f: with open(filename, encoding="utf-8") as f:
contents = f.read() contents = f.read()
except UnicodeDecodeError as e: except UnicodeDecodeError as e:
raise ValidationError(str(e)) raise ValidationError(str(e))

View File

@ -81,4 +81,3 @@ contribute, including reporting issues, requesting features, asking or answering
questions, and making PRs. questions, and making PRs.
[contrib]: https://palletsprojects.com/contributing/ [contrib]: https://palletsprojects.com/contributing/

View File

@ -1556,9 +1556,9 @@ class Group(Command):
def __init__( def __init__(
self, self,
name: str | None = None, name: str | None = None,
commands: cabc.MutableMapping[str, Command] commands: (
| cabc.Sequence[Command] cabc.MutableMapping[str, Command] | cabc.Sequence[Command] | None
| None = None, ) = None,
invoke_without_command: bool = False, invoke_without_command: bool = False,
no_args_is_help: bool | None = None, no_args_is_help: bool | None = None,
subcommand_metavar: str | None = None, subcommand_metavar: str | None = None,
@ -1659,9 +1659,9 @@ class Group(Command):
func: t.Callable[..., t.Any] | None = None func: t.Callable[..., t.Any] | None = None
if args and callable(args[0]): if args and callable(args[0]):
assert len(args) == 1 and not kwargs, ( assert (
"Use 'command(**kwargs)(callable)' to provide arguments." len(args) == 1 and not kwargs
) ), "Use 'command(**kwargs)(callable)' to provide arguments."
(func,) = args (func,) = args
args = () args = ()
@ -1708,9 +1708,9 @@ class Group(Command):
func: t.Callable[..., t.Any] | None = None func: t.Callable[..., t.Any] | None = None
if args and callable(args[0]): if args and callable(args[0]):
assert len(args) == 1 and not kwargs, ( assert (
"Use 'group(**kwargs)(callable)' to provide arguments." len(args) == 1 and not kwargs
) ), "Use 'group(**kwargs)(callable)' to provide arguments."
(func,) = args (func,) = args
args = () args = ()
@ -2140,10 +2140,10 @@ class Parameter:
expose_value: bool = True, expose_value: bool = True,
is_eager: bool = False, is_eager: bool = False,
envvar: str | cabc.Sequence[str] | None = None, envvar: str | cabc.Sequence[str] | None = None,
shell_complete: t.Callable[ shell_complete: (
[Context, Parameter, str], list[CompletionItem] | list[str] t.Callable[[Context, Parameter, str], list[CompletionItem] | list[str]]
] | None
| None = None, ) = None,
deprecated: bool | str = False, deprecated: bool | str = False,
) -> None: ) -> None:
self.name: str | None self.name: str | None
@ -2594,9 +2594,9 @@ class Parameter:
): ):
# Click is logically enforcing that the name is None if the parameter is # Click is logically enforcing that the name is None if the parameter is
# not to be exposed. We still assert it here to please the type checker. # not to be exposed. We still assert it here to please the type checker.
assert self.name is not None, ( assert (
f"{self!r} parameter's name should not be None when exposing value." self.name is not None
) ), f"{self!r} parameter's name should not be None when exposing value."
ctx.params[self.name] = value ctx.params[self.name] = value
return value, args return value, args

View File

@ -179,8 +179,9 @@ class Result:
return_value: t.Any, return_value: t.Any,
exit_code: int, exit_code: int,
exception: BaseException | None, exception: BaseException | None,
exc_info: tuple[type[BaseException], BaseException, TracebackType] exc_info: (
| None = None, tuple[type[BaseException], BaseException, TracebackType] | None
) = None,
): ):
self.runner = runner self.runner = runner
self.stdout_bytes = stdout_bytes self.stdout_bytes = stdout_bytes

View File

@ -114,5 +114,3 @@ Everyone interacting in the distlib project's codebases, issue trackers, chat
rooms, and mailing lists is expected to follow the `PyPA Code of Conduct`_. rooms, and mailing lists is expected to follow the `PyPA Code of Conduct`_.
.. _PyPA Code of Conduct: https://www.pypa.io/en/latest/code-of-conduct/ .. _PyPA Code of Conduct: https://www.pypa.io/en/latest/code-of-conduct/

View File

@ -3,4 +3,3 @@ Generator: bdist_wheel (0.37.1)
Root-Is-Purelib: true Root-Is-Purelib: true
Tag: py2-none-any Tag: py2-none-any
Tag: py3-none-any Tag: py3-none-any

View File

@ -6,7 +6,7 @@
# #
import logging import logging
__version__ = '0.4.0' __version__ = "0.4.0"
class DistlibException(Exception): class DistlibException(Exception):

View File

@ -18,24 +18,41 @@ except ImportError: # pragma: no cover
if sys.version_info[0] < 3: # pragma: no cover if sys.version_info[0] < 3: # pragma: no cover
from StringIO import StringIO from StringIO import StringIO
string_types = basestring,
string_types = (basestring,)
text_type = unicode text_type = unicode
from types import FileType as file_type from types import FileType as file_type
import __builtin__ as builtins import __builtin__ as builtins
import ConfigParser as configparser import ConfigParser as configparser
from urlparse import urlparse, urlunparse, urljoin, urlsplit, urlunsplit from urlparse import urlparse, urlunparse, urljoin, urlsplit, urlunsplit
from urllib import (urlretrieve, quote as _quote, unquote, url2pathname, from urllib import (
pathname2url, ContentTooShortError, splittype) urlretrieve,
quote as _quote,
unquote,
url2pathname,
pathname2url,
ContentTooShortError,
splittype,
)
def quote(s): def quote(s):
if isinstance(s, unicode): if isinstance(s, unicode):
s = s.encode('utf-8') s = s.encode("utf-8")
return _quote(s) return _quote(s)
import urllib2 import urllib2
from urllib2 import (Request, urlopen, URLError, HTTPError, from urllib2 import (
HTTPBasicAuthHandler, HTTPPasswordMgr, HTTPHandler, Request,
HTTPRedirectHandler, build_opener) urlopen,
URLError,
HTTPError,
HTTPBasicAuthHandler,
HTTPPasswordMgr,
HTTPHandler,
HTTPRedirectHandler,
build_opener,
)
if ssl: if ssl:
from urllib2 import HTTPSHandler from urllib2 import HTTPSHandler
import httplib import httplib
@ -43,6 +60,7 @@ if sys.version_info[0] < 3: # pragma: no cover
import Queue as queue import Queue as queue
from HTMLParser import HTMLParser from HTMLParser import HTMLParser
import htmlentitydefs import htmlentitydefs
raw_input = raw_input raw_input = raw_input
from itertools import ifilter as filter from itertools import ifilter as filter
from itertools import ifilterfalse as filterfalse from itertools import ifilterfalse as filterfalse
@ -62,17 +80,35 @@ if sys.version_info[0] < 3: # pragma: no cover
else: # pragma: no cover else: # pragma: no cover
from io import StringIO from io import StringIO
string_types = str,
string_types = (str,)
text_type = str text_type = str
from io import TextIOWrapper as file_type from io import TextIOWrapper as file_type
import builtins import builtins
import configparser import configparser
from urllib.parse import (urlparse, urlunparse, urljoin, quote, unquote, from urllib.parse import (
urlsplit, urlunsplit, splittype) urlparse,
from urllib.request import (urlopen, urlretrieve, Request, url2pathname, urlunparse,
pathname2url, HTTPBasicAuthHandler, urljoin,
HTTPPasswordMgr, HTTPHandler, quote,
HTTPRedirectHandler, build_opener) unquote,
urlsplit,
urlunsplit,
splittype,
)
from urllib.request import (
urlopen,
urlretrieve,
Request,
url2pathname,
pathname2url,
HTTPBasicAuthHandler,
HTTPPasswordMgr,
HTTPHandler,
HTTPRedirectHandler,
build_opener,
)
if ssl: if ssl:
from urllib.request import HTTPSHandler from urllib.request import HTTPSHandler
from urllib.error import HTTPError, URLError, ContentTooShortError from urllib.error import HTTPError, URLError, ContentTooShortError
@ -82,8 +118,10 @@ else: # pragma: no cover
import queue import queue
from html.parser import HTMLParser from html.parser import HTMLParser
import html.entities as htmlentitydefs import html.entities as htmlentitydefs
raw_input = input raw_input = input
from itertools import filterfalse from itertools import filterfalse
filter = filter filter = filter
try: try:
@ -102,17 +140,18 @@ except ImportError: # pragma: no cover
if not dn: if not dn:
return False return False
parts = dn.split('.') parts = dn.split(".")
leftmost, remainder = parts[0], parts[1:] leftmost, remainder = parts[0], parts[1:]
wildcards = leftmost.count('*') wildcards = leftmost.count("*")
if wildcards > max_wildcards: if wildcards > max_wildcards:
# Issue #17980: avoid denials of service by refusing more # Issue #17980: avoid denials of service by refusing more
# than one wildcard per fragment. A survey of established # than one wildcard per fragment. A survey of established
# policy among SSL implementations showed it to be a # policy among SSL implementations showed it to be a
# reasonable choice. # reasonable choice.
raise CertificateError( raise CertificateError(
"too many wildcards in certificate DNS name: " + repr(dn)) "too many wildcards in certificate DNS name: " + repr(dn)
)
# speed up common case w/o wildcards # speed up common case w/o wildcards
if not wildcards: if not wildcards:
@ -121,11 +160,11 @@ except ImportError: # pragma: no cover
# RFC 6125, section 6.4.3, subitem 1. # RFC 6125, section 6.4.3, subitem 1.
# The client SHOULD NOT attempt to match a presented identifier in which # The client SHOULD NOT attempt to match a presented identifier in which
# the wildcard character comprises a label other than the left-most label. # the wildcard character comprises a label other than the left-most label.
if leftmost == '*': if leftmost == "*":
# When '*' is a fragment by itself, it matches a non-empty dotless # When '*' is a fragment by itself, it matches a non-empty dotless
# fragment. # fragment.
pats.append('[^.]+') pats.append("[^.]+")
elif leftmost.startswith('xn--') or hostname.startswith('xn--'): elif leftmost.startswith("xn--") or hostname.startswith("xn--"):
# RFC 6125, section 6.4.3, subitem 3. # RFC 6125, section 6.4.3, subitem 3.
# The client SHOULD NOT attempt to match a presented identifier # The client SHOULD NOT attempt to match a presented identifier
# where the wildcard character is embedded within an A-label or # where the wildcard character is embedded within an A-label or
@ -133,13 +172,13 @@ except ImportError: # pragma: no cover
pats.append(re.escape(leftmost)) pats.append(re.escape(leftmost))
else: else:
# Otherwise, '*' matches any dotless string, e.g. www* # Otherwise, '*' matches any dotless string, e.g. www*
pats.append(re.escape(leftmost).replace(r'\*', '[^.]*')) pats.append(re.escape(leftmost).replace(r"\*", "[^.]*"))
# add the remaining fragments, ignore any wildcards # add the remaining fragments, ignore any wildcards
for frag in remainder: for frag in remainder:
pats.append(re.escape(frag)) pats.append(re.escape(frag))
pat = re.compile(r'\A' + r'\.'.join(pats) + r'\Z', re.IGNORECASE) pat = re.compile(r"\A" + r"\.".join(pats) + r"\Z", re.IGNORECASE)
return pat.match(hostname) return pat.match(hostname)
def match_hostname(cert, hostname): def match_hostname(cert, hostname):
@ -151,38 +190,43 @@ except ImportError: # pragma: no cover
returns nothing. returns nothing.
""" """
if not cert: if not cert:
raise ValueError("empty or no certificate, match_hostname needs a " raise ValueError(
"SSL socket or SSL context with either " "empty or no certificate, match_hostname needs a "
"CERT_OPTIONAL or CERT_REQUIRED") "SSL socket or SSL context with either "
"CERT_OPTIONAL or CERT_REQUIRED"
)
dnsnames = [] dnsnames = []
san = cert.get('subjectAltName', ()) san = cert.get("subjectAltName", ())
for key, value in san: for key, value in san:
if key == 'DNS': if key == "DNS":
if _dnsname_match(value, hostname): if _dnsname_match(value, hostname):
return return
dnsnames.append(value) dnsnames.append(value)
if not dnsnames: if not dnsnames:
# The subject is only checked when there is no dNSName entry # The subject is only checked when there is no dNSName entry
# in subjectAltName # in subjectAltName
for sub in cert.get('subject', ()): for sub in cert.get("subject", ()):
for key, value in sub: for key, value in sub:
# XXX according to RFC 2818, the most specific Common Name # XXX according to RFC 2818, the most specific Common Name
# must be used. # must be used.
if key == 'commonName': if key == "commonName":
if _dnsname_match(value, hostname): if _dnsname_match(value, hostname):
return return
dnsnames.append(value) dnsnames.append(value)
if len(dnsnames) > 1: if len(dnsnames) > 1:
raise CertificateError("hostname %r " raise CertificateError(
"doesn't match either of %s" % "hostname %r "
(hostname, ', '.join(map(repr, dnsnames)))) "doesn't match either of %s"
% (hostname, ", ".join(map(repr, dnsnames)))
)
elif len(dnsnames) == 1: elif len(dnsnames) == 1:
raise CertificateError("hostname %r " raise CertificateError(
"doesn't match %r" % "hostname %r " "doesn't match %r" % (hostname, dnsnames[0])
(hostname, dnsnames[0])) )
else: else:
raise CertificateError("no appropriate commonName or " raise CertificateError(
"subjectAltName fields were found") "no appropriate commonName or " "subjectAltName fields were found"
)
try: try:
@ -217,7 +261,7 @@ except ImportError: # pragma: no cover
# Additionally check that `file` is not a directory, as on Windows # Additionally check that `file` is not a directory, as on Windows
# directories pass the os.access check. # directories pass the os.access check.
def _access_check(fn, mode): def _access_check(fn, mode):
return (os.path.exists(fn) and os.access(fn, mode) and not os.path.isdir(fn)) return os.path.exists(fn) and os.access(fn, mode) and not os.path.isdir(fn)
# If we're given a path with a directory part, look it up directly rather # If we're given a path with a directory part, look it up directly rather
# than referring to PATH directories. This includes checking relative to the # than referring to PATH directories. This includes checking relative to the
@ -269,7 +313,7 @@ except ImportError: # pragma: no cover
from zipfile import ZipFile as BaseZipFile from zipfile import ZipFile as BaseZipFile
if hasattr(BaseZipFile, '__enter__'): # pragma: no cover if hasattr(BaseZipFile, "__enter__"): # pragma: no cover
ZipFile = BaseZipFile ZipFile = BaseZipFile
else: # pragma: no cover else: # pragma: no cover
from zipfile import ZipExtFile as BaseZipExtFile from zipfile import ZipExtFile as BaseZipExtFile
@ -306,13 +350,13 @@ except ImportError: # pragma: no cover
def python_implementation(): def python_implementation():
"""Return a string identifying the Python implementation.""" """Return a string identifying the Python implementation."""
if 'PyPy' in sys.version: if "PyPy" in sys.version:
return 'PyPy' return "PyPy"
if os.name == 'java': if os.name == "java":
return 'Jython' return "Jython"
if sys.version.startswith('IronPython'): if sys.version.startswith("IronPython"):
return 'IronPython' return "IronPython"
return 'CPython' return "CPython"
import sysconfig import sysconfig
@ -336,11 +380,11 @@ except AttributeError: # pragma: no cover
# sys.getfilesystemencoding(): the return value is "the users preference # sys.getfilesystemencoding(): the return value is "the users preference
# according to the result of nl_langinfo(CODESET), or None if the # according to the result of nl_langinfo(CODESET), or None if the
# nl_langinfo(CODESET) failed." # nl_langinfo(CODESET) failed."
_fsencoding = sys.getfilesystemencoding() or 'utf-8' _fsencoding = sys.getfilesystemencoding() or "utf-8"
if _fsencoding == 'mbcs': if _fsencoding == "mbcs":
_fserrors = 'strict' _fserrors = "strict"
else: else:
_fserrors = 'surrogateescape' _fserrors = "surrogateescape"
def fsencode(filename): def fsencode(filename):
if isinstance(filename, bytes): if isinstance(filename, bytes):
@ -348,8 +392,7 @@ except AttributeError: # pragma: no cover
elif isinstance(filename, text_type): elif isinstance(filename, text_type):
return filename.encode(_fsencoding, _fserrors) return filename.encode(_fsencoding, _fserrors)
else: else:
raise TypeError("expect bytes or str, not %s" % raise TypeError("expect bytes or str, not %s" % type(filename).__name__)
type(filename).__name__)
def fsdecode(filename): def fsdecode(filename):
if isinstance(filename, text_type): if isinstance(filename, text_type):
@ -357,8 +400,7 @@ except AttributeError: # pragma: no cover
elif isinstance(filename, bytes): elif isinstance(filename, bytes):
return filename.decode(_fsencoding, _fserrors) return filename.decode(_fsencoding, _fserrors)
else: else:
raise TypeError("expect bytes or str, not %s" % raise TypeError("expect bytes or str, not %s" % type(filename).__name__)
type(filename).__name__)
try: try:
@ -374,8 +416,9 @@ except ImportError: # pragma: no cover
enc = orig_enc[:12].lower().replace("_", "-") enc = orig_enc[:12].lower().replace("_", "-")
if enc == "utf-8" or enc.startswith("utf-8-"): if enc == "utf-8" or enc.startswith("utf-8-"):
return "utf-8" return "utf-8"
if enc in ("latin-1", "iso-8859-1", "iso-latin-1") or \ if enc in ("latin-1", "iso-8859-1", "iso-latin-1") or enc.startswith(
enc.startswith(("latin-1-", "iso-8859-1-", "iso-latin-1-")): ("latin-1-", "iso-8859-1-", "iso-latin-1-")
):
return "iso-8859-1" return "iso-8859-1"
return orig_enc return orig_enc
@ -402,24 +445,24 @@ except ImportError: # pragma: no cover
filename = None filename = None
bom_found = False bom_found = False
encoding = None encoding = None
default = 'utf-8' default = "utf-8"
def read_or_stop(): def read_or_stop():
try: try:
return readline() return readline()
except StopIteration: except StopIteration:
return b'' return b""
def find_cookie(line): def find_cookie(line):
try: try:
# Decode as UTF-8. Either the line is an encoding declaration, # Decode as UTF-8. Either the line is an encoding declaration,
# in which case it should be pure ASCII, or it must be UTF-8 # in which case it should be pure ASCII, or it must be UTF-8
# per default encoding. # per default encoding.
line_string = line.decode('utf-8') line_string = line.decode("utf-8")
except UnicodeDecodeError: except UnicodeDecodeError:
msg = "invalid or missing encoding declaration" msg = "invalid or missing encoding declaration"
if filename is not None: if filename is not None:
msg = '{} for {!r}'.format(msg, filename) msg = "{} for {!r}".format(msg, filename)
raise SyntaxError(msg) raise SyntaxError(msg)
matches = cookie_re.findall(line_string) matches = cookie_re.findall(line_string)
@ -433,27 +476,25 @@ except ImportError: # pragma: no cover
if filename is None: if filename is None:
msg = "unknown encoding: " + encoding msg = "unknown encoding: " + encoding
else: else:
msg = "unknown encoding for {!r}: {}".format( msg = "unknown encoding for {!r}: {}".format(filename, encoding)
filename, encoding)
raise SyntaxError(msg) raise SyntaxError(msg)
if bom_found: if bom_found:
if codec.name != 'utf-8': if codec.name != "utf-8":
# This behaviour mimics the Python interpreter # This behaviour mimics the Python interpreter
if filename is None: if filename is None:
msg = 'encoding problem: utf-8' msg = "encoding problem: utf-8"
else: else:
msg = 'encoding problem for {!r}: utf-8'.format( msg = "encoding problem for {!r}: utf-8".format(filename)
filename)
raise SyntaxError(msg) raise SyntaxError(msg)
encoding += '-sig' encoding += "-sig"
return encoding return encoding
first = read_or_stop() first = read_or_stop()
if first.startswith(BOM_UTF8): if first.startswith(BOM_UTF8):
bom_found = True bom_found = True
first = first[3:] first = first[3:]
default = 'utf-8-sig' default = "utf-8-sig"
if not first: if not first:
return default, [] return default, []
@ -491,11 +532,11 @@ except ImportError: # pragma: no cover
from reprlib import recursive_repr as _recursive_repr from reprlib import recursive_repr as _recursive_repr
except ImportError: except ImportError:
def _recursive_repr(fillvalue='...'): def _recursive_repr(fillvalue="..."):
''' """
Decorator to make a repr function return fillvalue for a recursive Decorator to make a repr function return fillvalue for a recursive
call call
''' """
def decorating_function(user_function): def decorating_function(user_function):
repr_running = set() repr_running = set()
@ -512,17 +553,16 @@ except ImportError: # pragma: no cover
return result return result
# Can't use functools.wraps() here because of bootstrap issues # Can't use functools.wraps() here because of bootstrap issues
wrapper.__module__ = getattr(user_function, '__module__') wrapper.__module__ = getattr(user_function, "__module__")
wrapper.__doc__ = getattr(user_function, '__doc__') wrapper.__doc__ = getattr(user_function, "__doc__")
wrapper.__name__ = getattr(user_function, '__name__') wrapper.__name__ = getattr(user_function, "__name__")
wrapper.__annotations__ = getattr(user_function, wrapper.__annotations__ = getattr(user_function, "__annotations__", {})
'__annotations__', {})
return wrapper return wrapper
return decorating_function return decorating_function
class ChainMap(MutableMapping): class ChainMap(MutableMapping):
''' """
A ChainMap groups multiple dicts (or other mappings) together A ChainMap groups multiple dicts (or other mappings) together
to create a single, updateable view. to create a single, updateable view.
@ -532,13 +572,13 @@ except ImportError: # pragma: no cover
Lookups search the underlying mappings successively until a key is found. Lookups search the underlying mappings successively until a key is found.
In contrast, writes, updates, and deletions only operate on the first In contrast, writes, updates, and deletions only operate on the first
mapping. mapping.
''' """
def __init__(self, *maps): def __init__(self, *maps):
'''Initialize a ChainMap by setting *maps* to the given mappings. """Initialize a ChainMap by setting *maps* to the given mappings.
If no mappings are provided, a single empty dictionary is used. If no mappings are provided, a single empty dictionary is used.
''' """
self.maps = list(maps) or [{}] # always at least one map self.maps = list(maps) or [{}] # always at least one map
def __missing__(self, key): def __missing__(self, key):
@ -547,19 +587,16 @@ except ImportError: # pragma: no cover
def __getitem__(self, key): def __getitem__(self, key):
for mapping in self.maps: for mapping in self.maps:
try: try:
return mapping[ return mapping[key] # can't use 'key in mapping' with defaultdict
key] # can't use 'key in mapping' with defaultdict
except KeyError: except KeyError:
pass pass
return self.__missing__( return self.__missing__(key) # support subclasses that define __missing__
key) # support subclasses that define __missing__
def get(self, key, default=None): def get(self, key, default=None):
return self[key] if key in self else default return self[key] if key in self else default
def __len__(self): def __len__(self):
return len(set().union( return len(set().union(*self.maps)) # reuses stored hash values if possible
*self.maps)) # reuses stored hash values if possible
def __iter__(self): def __iter__(self):
return iter(set().union(*self.maps)) return iter(set().union(*self.maps))
@ -572,27 +609,28 @@ except ImportError: # pragma: no cover
@_recursive_repr() @_recursive_repr()
def __repr__(self): def __repr__(self):
return '{0.__class__.__name__}({1})'.format( return "{0.__class__.__name__}({1})".format(
self, ', '.join(map(repr, self.maps))) self, ", ".join(map(repr, self.maps))
)
@classmethod @classmethod
def fromkeys(cls, iterable, *args): def fromkeys(cls, iterable, *args):
'Create a ChainMap with a single dict created from the iterable.' "Create a ChainMap with a single dict created from the iterable."
return cls(dict.fromkeys(iterable, *args)) return cls(dict.fromkeys(iterable, *args))
def copy(self): def copy(self):
'New ChainMap or subclass with a new copy of maps[0] and refs to maps[1:]' "New ChainMap or subclass with a new copy of maps[0] and refs to maps[1:]"
return self.__class__(self.maps[0].copy(), *self.maps[1:]) return self.__class__(self.maps[0].copy(), *self.maps[1:])
__copy__ = copy __copy__ = copy
def new_child(self): # like Django's Context.push() def new_child(self): # like Django's Context.push()
'New ChainMap with a new dict followed by all previous maps.' "New ChainMap with a new dict followed by all previous maps."
return self.__class__({}, *self.maps) return self.__class__({}, *self.maps)
@property @property
def parents(self): # like Django's Context.pop() def parents(self): # like Django's Context.pop()
'New ChainMap from maps[1:].' "New ChainMap from maps[1:]."
return self.__class__(*self.maps[1:]) return self.__class__(*self.maps[1:])
def __setitem__(self, key, value): def __setitem__(self, key, value):
@ -602,26 +640,24 @@ except ImportError: # pragma: no cover
try: try:
del self.maps[0][key] del self.maps[0][key]
except KeyError: except KeyError:
raise KeyError( raise KeyError("Key not found in the first mapping: {!r}".format(key))
'Key not found in the first mapping: {!r}'.format(key))
def popitem(self): def popitem(self):
'Remove and return an item pair from maps[0]. Raise KeyError is maps[0] is empty.' "Remove and return an item pair from maps[0]. Raise KeyError is maps[0] is empty."
try: try:
return self.maps[0].popitem() return self.maps[0].popitem()
except KeyError: except KeyError:
raise KeyError('No keys found in the first mapping.') raise KeyError("No keys found in the first mapping.")
def pop(self, key, *args): def pop(self, key, *args):
'Remove *key* from maps[0] and return its value. Raise KeyError if *key* not in maps[0].' "Remove *key* from maps[0] and return its value. Raise KeyError if *key* not in maps[0]."
try: try:
return self.maps[0].pop(key, *args) return self.maps[0].pop(key, *args)
except KeyError: except KeyError:
raise KeyError( raise KeyError("Key not found in the first mapping: {!r}".format(key))
'Key not found in the first mapping: {!r}'.format(key))
def clear(self): def clear(self):
'Clear maps[0], leaving maps[1:] intact.' "Clear maps[0], leaving maps[1:] intact."
self.maps[0].clear() self.maps[0].clear()
@ -630,13 +666,13 @@ try:
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
def cache_from_source(path, debug_override=None): def cache_from_source(path, debug_override=None):
assert path.endswith('.py') assert path.endswith(".py")
if debug_override is None: if debug_override is None:
debug_override = __debug__ debug_override = __debug__
if debug_override: if debug_override:
suffix = 'c' suffix = "c"
else: else:
suffix = 'o' suffix = "o"
return path + suffix return path + suffix
@ -657,7 +693,7 @@ except ImportError: # pragma: no cover
pass pass
class OrderedDict(dict): class OrderedDict(dict):
'Dictionary that remembers insertion order' "Dictionary that remembers insertion order"
# An inherited dict maps keys to values. # An inherited dict maps keys to values.
# The inherited dict provides __getitem__, __len__, __contains__, and get. # The inherited dict provides __getitem__, __len__, __contains__, and get.
@ -670,14 +706,13 @@ except ImportError: # pragma: no cover
# Each link is stored as a list of length three: [PREV, NEXT, KEY]. # Each link is stored as a list of length three: [PREV, NEXT, KEY].
def __init__(self, *args, **kwds): def __init__(self, *args, **kwds):
'''Initialize an ordered dictionary. Signature is the same as for """Initialize an ordered dictionary. Signature is the same as for
regular dictionaries, but keyword arguments are not recommended regular dictionaries, but keyword arguments are not recommended
because their insertion order is arbitrary. because their insertion order is arbitrary.
''' """
if len(args) > 1: if len(args) > 1:
raise TypeError('expected at most 1 arguments, got %d' % raise TypeError("expected at most 1 arguments, got %d" % len(args))
len(args))
try: try:
self.__root self.__root
except AttributeError: except AttributeError:
@ -687,7 +722,7 @@ except ImportError: # pragma: no cover
self.__update(*args, **kwds) self.__update(*args, **kwds)
def __setitem__(self, key, value, dict_setitem=dict.__setitem__): def __setitem__(self, key, value, dict_setitem=dict.__setitem__):
'od.__setitem__(i, y) <==> od[i]=y' "od.__setitem__(i, y) <==> od[i]=y"
# Setting a new item creates a new link which goes at the end of the linked # Setting a new item creates a new link which goes at the end of the linked
# list, and the inherited dictionary is updated with the new key/value pair. # list, and the inherited dictionary is updated with the new key/value pair.
if key not in self: if key not in self:
@ -697,7 +732,7 @@ except ImportError: # pragma: no cover
dict_setitem(self, key, value) dict_setitem(self, key, value)
def __delitem__(self, key, dict_delitem=dict.__delitem__): def __delitem__(self, key, dict_delitem=dict.__delitem__):
'od.__delitem__(y) <==> del od[y]' "od.__delitem__(y) <==> del od[y]"
# Deleting an existing item uses self.__map to find the link which is # Deleting an existing item uses self.__map to find the link which is
# then removed by updating the links in the predecessor and successor nodes. # then removed by updating the links in the predecessor and successor nodes.
dict_delitem(self, key) dict_delitem(self, key)
@ -706,7 +741,7 @@ except ImportError: # pragma: no cover
link_next[0] = link_prev link_next[0] = link_prev
def __iter__(self): def __iter__(self):
'od.__iter__() <==> iter(od)' "od.__iter__() <==> iter(od)"
root = self.__root root = self.__root
curr = root[1] curr = root[1]
while curr is not root: while curr is not root:
@ -714,7 +749,7 @@ except ImportError: # pragma: no cover
curr = curr[1] curr = curr[1]
def __reversed__(self): def __reversed__(self):
'od.__reversed__() <==> reversed(od)' "od.__reversed__() <==> reversed(od)"
root = self.__root root = self.__root
curr = root[0] curr = root[0]
while curr is not root: while curr is not root:
@ -722,7 +757,7 @@ except ImportError: # pragma: no cover
curr = curr[0] curr = curr[0]
def clear(self): def clear(self):
'od.clear() -> None. Remove all items from od.' "od.clear() -> None. Remove all items from od."
try: try:
for node in self.__map.itervalues(): for node in self.__map.itervalues():
del node[:] del node[:]
@ -734,12 +769,12 @@ except ImportError: # pragma: no cover
dict.clear(self) dict.clear(self)
def popitem(self, last=True): def popitem(self, last=True):
'''od.popitem() -> (k, v), return and remove a (key, value) pair. """od.popitem() -> (k, v), return and remove a (key, value) pair.
Pairs are returned in LIFO order if last is true or FIFO order if false. Pairs are returned in LIFO order if last is true or FIFO order if false.
''' """
if not self: if not self:
raise KeyError('dictionary is empty') raise KeyError("dictionary is empty")
root = self.__root root = self.__root
if last: if last:
link = root[0] link = root[0]
@ -759,45 +794,47 @@ except ImportError: # pragma: no cover
# -- the following methods do not depend on the internal structure -- # -- the following methods do not depend on the internal structure --
def keys(self): def keys(self):
'od.keys() -> list of keys in od' "od.keys() -> list of keys in od"
return list(self) return list(self)
def values(self): def values(self):
'od.values() -> list of values in od' "od.values() -> list of values in od"
return [self[key] for key in self] return [self[key] for key in self]
def items(self): def items(self):
'od.items() -> list of (key, value) pairs in od' "od.items() -> list of (key, value) pairs in od"
return [(key, self[key]) for key in self] return [(key, self[key]) for key in self]
def iterkeys(self): def iterkeys(self):
'od.iterkeys() -> an iterator over the keys in od' "od.iterkeys() -> an iterator over the keys in od"
return iter(self) return iter(self)
def itervalues(self): def itervalues(self):
'od.itervalues -> an iterator over the values in od' "od.itervalues -> an iterator over the values in od"
for k in self: for k in self:
yield self[k] yield self[k]
def iteritems(self): def iteritems(self):
'od.iteritems -> an iterator over the (key, value) items in od' "od.iteritems -> an iterator over the (key, value) items in od"
for k in self: for k in self:
yield (k, self[k]) yield (k, self[k])
def update(*args, **kwds): def update(*args, **kwds):
'''od.update(E, **F) -> None. Update od from dict/iterable E and F. """od.update(E, **F) -> None. Update od from dict/iterable E and F.
If E is a dict instance, does: for k in E: od[k] = E[k] If E is a dict instance, does: for k in E: od[k] = E[k]
If E has a .keys() method, does: for k in E.keys(): od[k] = E[k] If E has a .keys() method, does: for k in E.keys(): od[k] = E[k]
Or if E is an iterable of items, does: for k, v in E: od[k] = v Or if E is an iterable of items, does: for k, v in E: od[k] = v
In either case, this is followed by: for k, v in F.items(): od[k] = v In either case, this is followed by: for k, v in F.items(): od[k] = v
''' """
if len(args) > 2: if len(args) > 2:
raise TypeError('update() takes at most 2 positional ' raise TypeError(
'arguments (%d given)' % (len(args), )) "update() takes at most 2 positional "
"arguments (%d given)" % (len(args),)
)
elif not args: elif not args:
raise TypeError('update() takes at least 1 argument (0 given)') raise TypeError("update() takes at least 1 argument (0 given)")
self = args[0] self = args[0]
# Make progressively weaker assumptions about "other" # Make progressively weaker assumptions about "other"
other = () other = ()
@ -806,7 +843,7 @@ except ImportError: # pragma: no cover
if isinstance(other, dict): if isinstance(other, dict):
for key in other: for key in other:
self[key] = other[key] self[key] = other[key]
elif hasattr(other, 'keys'): elif hasattr(other, "keys"):
for key in other.keys(): for key in other.keys():
self[key] = other[key] self[key] = other[key]
else: else:
@ -820,10 +857,10 @@ except ImportError: # pragma: no cover
__marker = object() __marker = object()
def pop(self, key, default=__marker): def pop(self, key, default=__marker):
'''od.pop(k[,d]) -> v, remove specified key and return the corresponding value. """od.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised. If key is not found, d is returned if given, otherwise KeyError is raised.
''' """
if key in self: if key in self:
result = self[key] result = self[key]
del self[key] del self[key]
@ -833,60 +870,59 @@ except ImportError: # pragma: no cover
return default return default
def setdefault(self, key, default=None): def setdefault(self, key, default=None):
'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od' "od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od"
if key in self: if key in self:
return self[key] return self[key]
self[key] = default self[key] = default
return default return default
def __repr__(self, _repr_running=None): def __repr__(self, _repr_running=None):
'od.__repr__() <==> repr(od)' "od.__repr__() <==> repr(od)"
if not _repr_running: if not _repr_running:
_repr_running = {} _repr_running = {}
call_key = id(self), _get_ident() call_key = id(self), _get_ident()
if call_key in _repr_running: if call_key in _repr_running:
return '...' return "..."
_repr_running[call_key] = 1 _repr_running[call_key] = 1
try: try:
if not self: if not self:
return '%s()' % (self.__class__.__name__, ) return "%s()" % (self.__class__.__name__,)
return '%s(%r)' % (self.__class__.__name__, self.items()) return "%s(%r)" % (self.__class__.__name__, self.items())
finally: finally:
del _repr_running[call_key] del _repr_running[call_key]
def __reduce__(self): def __reduce__(self):
'Return state information for pickling' "Return state information for pickling"
items = [[k, self[k]] for k in self] items = [[k, self[k]] for k in self]
inst_dict = vars(self).copy() inst_dict = vars(self).copy()
for k in vars(OrderedDict()): for k in vars(OrderedDict()):
inst_dict.pop(k, None) inst_dict.pop(k, None)
if inst_dict: if inst_dict:
return (self.__class__, (items, ), inst_dict) return (self.__class__, (items,), inst_dict)
return self.__class__, (items, ) return self.__class__, (items,)
def copy(self): def copy(self):
'od.copy() -> a shallow copy of od' "od.copy() -> a shallow copy of od"
return self.__class__(self) return self.__class__(self)
@classmethod @classmethod
def fromkeys(cls, iterable, value=None): def fromkeys(cls, iterable, value=None):
'''OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S """OD.fromkeys(S[, v]) -> New ordered dictionary with keys from S
and values equal to v (which defaults to None). and values equal to v (which defaults to None).
''' """
d = cls() d = cls()
for key in iterable: for key in iterable:
d[key] = value d[key] = value
return d return d
def __eq__(self, other): def __eq__(self, other):
'''od.__eq__(y) <==> od==y. Comparison to another OD is order-sensitive """od.__eq__(y) <==> od==y. Comparison to another OD is order-sensitive
while comparison to a regular mapping is order-insensitive. while comparison to a regular mapping is order-insensitive.
''' """
if isinstance(other, OrderedDict): if isinstance(other, OrderedDict):
return len(self) == len( return len(self) == len(other) and self.items() == other.items()
other) and self.items() == other.items()
return dict.__eq__(self, other) return dict.__eq__(self, other)
def __ne__(self, other): def __ne__(self, other):
@ -910,12 +946,12 @@ except ImportError: # pragma: no cover
try: try:
from logging.config import BaseConfigurator, valid_ident from logging.config import BaseConfigurator, valid_ident
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
IDENTIFIER = re.compile('^[a-z_][a-z0-9_]*$', re.I) IDENTIFIER = re.compile("^[a-z_][a-z0-9_]*$", re.I)
def valid_ident(s): def valid_ident(s):
m = IDENTIFIER.match(s) m = IDENTIFIER.match(s)
if not m: if not m:
raise ValueError('Not a valid Python identifier: %r' % s) raise ValueError("Not a valid Python identifier: %r" % s)
return True return True
# The ConvertingXXX classes are wrappers around standard Python containers, # The ConvertingXXX classes are wrappers around standard Python containers,
@ -936,8 +972,7 @@ except ImportError: # pragma: no cover
# If the converted value is different, save for next time # If the converted value is different, save for next time
if value is not result: if value is not result:
self[key] = result self[key] = result
if type(result) in (ConvertingDict, ConvertingList, if type(result) in (ConvertingDict, ConvertingList, ConvertingTuple):
ConvertingTuple):
result.parent = self result.parent = self
result.key = key result.key = key
return result return result
@ -948,8 +983,7 @@ except ImportError: # pragma: no cover
# If the converted value is different, save for next time # If the converted value is different, save for next time
if value is not result: if value is not result:
self[key] = result self[key] = result
if type(result) in (ConvertingDict, ConvertingList, if type(result) in (ConvertingDict, ConvertingList, ConvertingTuple):
ConvertingTuple):
result.parent = self result.parent = self
result.key = key result.key = key
return result return result
@ -958,8 +992,7 @@ except ImportError: # pragma: no cover
value = dict.pop(self, key, default) value = dict.pop(self, key, default)
result = self.configurator.convert(value) result = self.configurator.convert(value)
if value is not result: if value is not result:
if type(result) in (ConvertingDict, ConvertingList, if type(result) in (ConvertingDict, ConvertingList, ConvertingTuple):
ConvertingTuple):
result.parent = self result.parent = self
result.key = key result.key = key
return result return result
@ -973,8 +1006,7 @@ except ImportError: # pragma: no cover
# If the converted value is different, save for next time # If the converted value is different, save for next time
if value is not result: if value is not result:
self[key] = result self[key] = result
if type(result) in (ConvertingDict, ConvertingList, if type(result) in (ConvertingDict, ConvertingList, ConvertingTuple):
ConvertingTuple):
result.parent = self result.parent = self
result.key = key result.key = key
return result return result
@ -983,8 +1015,7 @@ except ImportError: # pragma: no cover
value = list.pop(self, idx) value = list.pop(self, idx)
result = self.configurator.convert(value) result = self.configurator.convert(value)
if value is not result: if value is not result:
if type(result) in (ConvertingDict, ConvertingList, if type(result) in (ConvertingDict, ConvertingList, ConvertingTuple):
ConvertingTuple):
result.parent = self result.parent = self
return result return result
@ -995,8 +1026,7 @@ except ImportError: # pragma: no cover
value = tuple.__getitem__(self, key) value = tuple.__getitem__(self, key)
result = self.configurator.convert(value) result = self.configurator.convert(value)
if value is not result: if value is not result:
if type(result) in (ConvertingDict, ConvertingList, if type(result) in (ConvertingDict, ConvertingList, ConvertingTuple):
ConvertingTuple):
result.parent = self result.parent = self
result.key = key result.key = key
return result return result
@ -1006,16 +1036,16 @@ except ImportError: # pragma: no cover
The configurator base class which defines some useful defaults. The configurator base class which defines some useful defaults.
""" """
CONVERT_PATTERN = re.compile(r'^(?P<prefix>[a-z]+)://(?P<suffix>.*)$') CONVERT_PATTERN = re.compile(r"^(?P<prefix>[a-z]+)://(?P<suffix>.*)$")
WORD_PATTERN = re.compile(r'^\s*(\w+)\s*') WORD_PATTERN = re.compile(r"^\s*(\w+)\s*")
DOT_PATTERN = re.compile(r'^\.\s*(\w+)\s*') DOT_PATTERN = re.compile(r"^\.\s*(\w+)\s*")
INDEX_PATTERN = re.compile(r'^\[\s*(\w+)\s*\]\s*') INDEX_PATTERN = re.compile(r"^\[\s*(\w+)\s*\]\s*")
DIGIT_PATTERN = re.compile(r'^\d+$') DIGIT_PATTERN = re.compile(r"^\d+$")
value_converters = { value_converters = {
'ext': 'ext_convert', "ext": "ext_convert",
'cfg': 'cfg_convert', "cfg": "cfg_convert",
} }
# We might want to use a different one, e.g. importlib # We might want to use a different one, e.g. importlib
@ -1030,12 +1060,12 @@ except ImportError: # pragma: no cover
Resolve strings to objects using standard import and attribute Resolve strings to objects using standard import and attribute
syntax. syntax.
""" """
name = s.split('.') name = s.split(".")
used = name.pop(0) used = name.pop(0)
try: try:
found = self.importer(used) found = self.importer(used)
for frag in name: for frag in name:
used += '.' + frag used += "." + frag
try: try:
found = getattr(found, frag) found = getattr(found, frag)
except AttributeError: except AttributeError:
@ -1044,7 +1074,7 @@ except ImportError: # pragma: no cover
return found return found
except ImportError: except ImportError:
e, tb = sys.exc_info()[1:] e, tb = sys.exc_info()[1:]
v = ValueError('Cannot resolve %r: %s' % (s, e)) v = ValueError("Cannot resolve %r: %s" % (s, e))
v.__cause__, v.__traceback__ = e, tb v.__cause__, v.__traceback__ = e, tb
raise v raise v
@ -1059,7 +1089,7 @@ except ImportError: # pragma: no cover
if m is None: if m is None:
raise ValueError("Unable to convert %r" % value) raise ValueError("Unable to convert %r" % value)
else: else:
rest = rest[m.end():] rest = rest[m.end() :]
d = self.config[m.groups()[0]] d = self.config[m.groups()[0]]
while rest: while rest:
m = self.DOT_PATTERN.match(rest) m = self.DOT_PATTERN.match(rest)
@ -1073,17 +1103,16 @@ except ImportError: # pragma: no cover
d = d[idx] d = d[idx]
else: else:
try: try:
n = int( n = int(idx) # try as number first (most likely)
idx
) # try as number first (most likely)
d = d[n] d = d[n]
except TypeError: except TypeError:
d = d[idx] d = d[idx]
if m: if m:
rest = rest[m.end():] rest = rest[m.end() :]
else: else:
raise ValueError('Unable to convert ' raise ValueError(
'%r at %r' % (value, rest)) "Unable to convert " "%r at %r" % (value, rest)
)
# rest should be empty # rest should be empty
return d return d
@ -1093,12 +1122,10 @@ except ImportError: # pragma: no cover
replaced by their converting alternatives. Strings are checked to replaced by their converting alternatives. Strings are checked to
see if they have a conversion format and are converted if they do. see if they have a conversion format and are converted if they do.
""" """
if not isinstance(value, ConvertingDict) and isinstance( if not isinstance(value, ConvertingDict) and isinstance(value, dict):
value, dict):
value = ConvertingDict(value) value = ConvertingDict(value)
value.configurator = self value.configurator = self
elif not isinstance(value, ConvertingList) and isinstance( elif not isinstance(value, ConvertingList) and isinstance(value, list):
value, list):
value = ConvertingList(value) value = ConvertingList(value)
value.configurator = self value.configurator = self
elif not isinstance(value, ConvertingTuple) and isinstance(value, tuple): elif not isinstance(value, ConvertingTuple) and isinstance(value, tuple):
@ -1108,20 +1135,20 @@ except ImportError: # pragma: no cover
m = self.CONVERT_PATTERN.match(value) m = self.CONVERT_PATTERN.match(value)
if m: if m:
d = m.groupdict() d = m.groupdict()
prefix = d['prefix'] prefix = d["prefix"]
converter = self.value_converters.get(prefix, None) converter = self.value_converters.get(prefix, None)
if converter: if converter:
suffix = d['suffix'] suffix = d["suffix"]
converter = getattr(self, converter) converter = getattr(self, converter)
value = converter(suffix) value = converter(suffix)
return value return value
def configure_custom(self, config): def configure_custom(self, config):
"""Configure an object with a user-supplied factory.""" """Configure an object with a user-supplied factory."""
c = config.pop('()') c = config.pop("()")
if not callable(c): if not callable(c):
c = self.resolve(c) c = self.resolve(c)
props = config.pop('.', None) props = config.pop(".", None)
# Check for valid identifiers # Check for valid identifiers
kwargs = dict([(k, config[k]) for k in config if valid_ident(k)]) kwargs = dict([(k, config[k]) for k in config if valid_ident(k)])
result = c(**kwargs) result = c(**kwargs)

View File

@ -20,22 +20,46 @@ import zipimport
from . import DistlibException, resources from . import DistlibException, resources
from .compat import StringIO from .compat import StringIO
from .version import get_scheme, UnsupportedVersionError from .version import get_scheme, UnsupportedVersionError
from .metadata import (Metadata, METADATA_FILENAME, WHEEL_METADATA_FILENAME, LEGACY_METADATA_FILENAME) from .metadata import (
from .util import (parse_requirement, cached_property, parse_name_and_version, read_exports, write_exports, CSVReader, Metadata,
CSVWriter) METADATA_FILENAME,
WHEEL_METADATA_FILENAME,
LEGACY_METADATA_FILENAME,
)
from .util import (
parse_requirement,
cached_property,
parse_name_and_version,
read_exports,
write_exports,
CSVReader,
CSVWriter,
)
__all__ = [ __all__ = [
'Distribution', 'BaseInstalledDistribution', 'InstalledDistribution', 'EggInfoDistribution', 'DistributionPath' "Distribution",
"BaseInstalledDistribution",
"InstalledDistribution",
"EggInfoDistribution",
"DistributionPath",
] ]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
EXPORTS_FILENAME = 'pydist-exports.json' EXPORTS_FILENAME = "pydist-exports.json"
COMMANDS_FILENAME = 'pydist-commands.json' COMMANDS_FILENAME = "pydist-commands.json"
DIST_FILES = ('INSTALLER', METADATA_FILENAME, 'RECORD', 'REQUESTED', 'RESOURCES', EXPORTS_FILENAME, 'SHARED') DIST_FILES = (
"INSTALLER",
METADATA_FILENAME,
"RECORD",
"REQUESTED",
"RESOURCES",
EXPORTS_FILENAME,
"SHARED",
)
DISTINFO_EXT = '.dist-info' DISTINFO_EXT = ".dist-info"
class _Cache(object): class _Cache(object):
@ -92,7 +116,7 @@ class DistributionPath(object):
self._cache = _Cache() self._cache = _Cache()
self._cache_egg = _Cache() self._cache_egg = _Cache()
self._cache_enabled = True self._cache_enabled = True
self._scheme = get_scheme('default') self._scheme = get_scheme("default")
def _get_cache_enabled(self): def _get_cache_enabled(self):
return self._cache_enabled return self._cache_enabled
@ -121,7 +145,7 @@ class DistributionPath(object):
finder = resources.finder_for_path(path) finder = resources.finder_for_path(path)
if finder is None: if finder is None:
continue continue
r = finder.find('') r = finder.find("")
if not r or not r.is_container: if not r or not r.is_container:
continue continue
rset = sorted(r.resources) rset = sorted(r.resources)
@ -131,7 +155,11 @@ class DistributionPath(object):
continue continue
try: try:
if self._include_dist and entry.endswith(DISTINFO_EXT): if self._include_dist and entry.endswith(DISTINFO_EXT):
possible_filenames = [METADATA_FILENAME, WHEEL_METADATA_FILENAME, LEGACY_METADATA_FILENAME] possible_filenames = [
METADATA_FILENAME,
WHEEL_METADATA_FILENAME,
LEGACY_METADATA_FILENAME,
]
for metadata_filename in possible_filenames: for metadata_filename in possible_filenames:
metadata_path = posixpath.join(entry, metadata_filename) metadata_path = posixpath.join(entry, metadata_filename)
pydist = finder.find(metadata_path) pydist = finder.find(metadata_path)
@ -141,18 +169,19 @@ class DistributionPath(object):
continue continue
with contextlib.closing(pydist.as_stream()) as stream: with contextlib.closing(pydist.as_stream()) as stream:
metadata = Metadata(fileobj=stream, scheme='legacy') metadata = Metadata(fileobj=stream, scheme="legacy")
logger.debug('Found %s', r.path) logger.debug("Found %s", r.path)
seen.add(r.path) seen.add(r.path)
yield new_dist_class(r.path, metadata=metadata, env=self) yield new_dist_class(r.path, metadata=metadata, env=self)
elif self._include_egg and entry.endswith(('.egg-info', '.egg')): elif self._include_egg and entry.endswith((".egg-info", ".egg")):
logger.debug('Found %s', r.path) logger.debug("Found %s", r.path)
seen.add(r.path) seen.add(r.path)
yield old_dist_class(r.path, self) yield old_dist_class(r.path, self)
except Exception as e: except Exception as e:
msg = 'Unable to read distribution at %s, perhaps due to bad metadata: %s' msg = "Unable to read distribution at %s, perhaps due to bad metadata: %s"
logger.warning(msg, r.path, e) logger.warning(msg, r.path, e)
import warnings import warnings
warnings.warn(msg % (r.path, e), stacklevel=2) warnings.warn(msg % (r.path, e), stacklevel=2)
def _generate_cache(self): def _generate_cache(self):
@ -193,8 +222,8 @@ class DistributionPath(object):
:type version: string :type version: string
:returns: directory name :returns: directory name
:rtype: string""" :rtype: string"""
name = name.replace('-', '_') name = name.replace("-", "_")
return '-'.join([name, version]) + DISTINFO_EXT return "-".join([name, version]) + DISTINFO_EXT
def get_distributions(self): def get_distributions(self):
""" """
@ -261,14 +290,16 @@ class DistributionPath(object):
matcher = None matcher = None
if version is not None: if version is not None:
try: try:
matcher = self._scheme.matcher('%s (%s)' % (name, version)) matcher = self._scheme.matcher("%s (%s)" % (name, version))
except ValueError: except ValueError:
raise DistlibException('invalid name or version: %r, %r' % (name, version)) raise DistlibException(
"invalid name or version: %r, %r" % (name, version)
)
for dist in self.get_distributions(): for dist in self.get_distributions():
# We hit a problem on Travis where enum34 was installed and doesn't # We hit a problem on Travis where enum34 was installed and doesn't
# have a provides attribute ... # have a provides attribute ...
if not hasattr(dist, 'provides'): if not hasattr(dist, "provides"):
logger.debug('No "provides": %s', dist) logger.debug('No "provides": %s', dist)
else: else:
provided = dist.provides provided = dist.provides
@ -290,7 +321,7 @@ class DistributionPath(object):
""" """
dist = self.get_distribution(name) dist = self.get_distribution(name)
if dist is None: if dist is None:
raise LookupError('no distribution named %r found' % name) raise LookupError("no distribution named %r found" % name)
return dist.get_resource_path(relative_path) return dist.get_resource_path(relative_path)
def get_exported_entries(self, category, name=None): def get_exported_entries(self, category, name=None):
@ -361,7 +392,7 @@ class Distribution(object):
""" """
A utility property which displays the name and version in parentheses. A utility property which displays the name and version in parentheses.
""" """
return '%s (%s)' % (self.name, self.version) return "%s (%s)" % (self.name, self.version)
@property @property
def provides(self): def provides(self):
@ -370,7 +401,7 @@ class Distribution(object):
:return: A set of "name (version)" strings. :return: A set of "name (version)" strings.
""" """
plist = self.metadata.provides plist = self.metadata.provides
s = '%s (%s)' % (self.name, self.version) s = "%s (%s)" % (self.name, self.version)
if s not in plist: if s not in plist:
plist.append(s) plist.append(s)
return plist return plist
@ -378,28 +409,30 @@ class Distribution(object):
def _get_requirements(self, req_attr): def _get_requirements(self, req_attr):
md = self.metadata md = self.metadata
reqts = getattr(md, req_attr) reqts = getattr(md, req_attr)
logger.debug('%s: got requirements %r from metadata: %r', self.name, req_attr, reqts) logger.debug(
"%s: got requirements %r from metadata: %r", self.name, req_attr, reqts
)
return set(md.get_requirements(reqts, extras=self.extras, env=self.context)) return set(md.get_requirements(reqts, extras=self.extras, env=self.context))
@property @property
def run_requires(self): def run_requires(self):
return self._get_requirements('run_requires') return self._get_requirements("run_requires")
@property @property
def meta_requires(self): def meta_requires(self):
return self._get_requirements('meta_requires') return self._get_requirements("meta_requires")
@property @property
def build_requires(self): def build_requires(self):
return self._get_requirements('build_requires') return self._get_requirements("build_requires")
@property @property
def test_requires(self): def test_requires(self):
return self._get_requirements('test_requires') return self._get_requirements("test_requires")
@property @property
def dev_requires(self): def dev_requires(self):
return self._get_requirements('dev_requires') return self._get_requirements("dev_requires")
def matches_requirement(self, req): def matches_requirement(self, req):
""" """
@ -416,7 +449,7 @@ class Distribution(object):
matcher = scheme.matcher(r.requirement) matcher = scheme.matcher(r.requirement)
except UnsupportedVersionError: except UnsupportedVersionError:
# XXX compat-mode if cannot read the version # XXX compat-mode if cannot read the version
logger.warning('could not read version %r - using name only', req) logger.warning("could not read version %r - using name only", req)
name = req.split()[0] name = req.split()[0]
matcher = scheme.matcher(name) matcher = scheme.matcher(name)
@ -439,10 +472,10 @@ class Distribution(object):
Return a textual representation of this instance, Return a textual representation of this instance,
""" """
if self.source_url: if self.source_url:
suffix = ' [%s]' % self.source_url suffix = " [%s]" % self.source_url
else: else:
suffix = '' suffix = ""
return '<Distribution %s (%s)%s>' % (self.name, self.version, suffix) return "<Distribution %s (%s)%s>" % (self.name, self.version, suffix)
def __eq__(self, other): def __eq__(self, other):
""" """
@ -455,7 +488,11 @@ class Distribution(object):
if type(other) is not type(self): if type(other) is not type(self):
result = False result = False
else: else:
result = (self.name == other.name and self.version == other.version and self.source_url == other.source_url) result = (
self.name == other.name
and self.version == other.version
and self.source_url == other.source_url
)
return result return result
def __hash__(self): def __hash__(self):
@ -511,13 +548,13 @@ class BaseInstalledDistribution(Distribution):
hasher = self.hasher hasher = self.hasher
if hasher is None: if hasher is None:
hasher = hashlib.md5 hasher = hashlib.md5
prefix = '' prefix = ""
else: else:
hasher = getattr(hashlib, hasher) hasher = getattr(hashlib, hasher)
prefix = '%s=' % self.hasher prefix = "%s=" % self.hasher
digest = hasher(data).digest() digest = hasher(data).digest()
digest = base64.urlsafe_b64encode(digest).rstrip(b'=').decode('ascii') digest = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
return '%s%s' % (prefix, digest) return "%s%s" % (prefix, digest)
class InstalledDistribution(BaseInstalledDistribution): class InstalledDistribution(BaseInstalledDistribution):
@ -528,13 +565,13 @@ class InstalledDistribution(BaseInstalledDistribution):
dry-run mode is being used). dry-run mode is being used).
""" """
hasher = 'sha256' hasher = "sha256"
def __init__(self, path, metadata=None, env=None): def __init__(self, path, metadata=None, env=None):
self.modules = [] self.modules = []
self.finder = finder = resources.finder_for_path(path) self.finder = finder = resources.finder_for_path(path)
if finder is None: if finder is None:
raise ValueError('finder unavailable for %s' % path) raise ValueError("finder unavailable for %s" % path)
if env and env._cache_enabled and path in env._cache.path: if env and env._cache_enabled and path in env._cache.path:
metadata = env._cache.path[path].metadata metadata = env._cache.path[path].metadata
elif metadata is None: elif metadata is None:
@ -546,25 +583,29 @@ class InstalledDistribution(BaseInstalledDistribution):
if r is None: if r is None:
r = finder.find(LEGACY_METADATA_FILENAME) r = finder.find(LEGACY_METADATA_FILENAME)
if r is None: if r is None:
raise ValueError('no %s found in %s' % (METADATA_FILENAME, path)) raise ValueError("no %s found in %s" % (METADATA_FILENAME, path))
with contextlib.closing(r.as_stream()) as stream: with contextlib.closing(r.as_stream()) as stream:
metadata = Metadata(fileobj=stream, scheme='legacy') metadata = Metadata(fileobj=stream, scheme="legacy")
super(InstalledDistribution, self).__init__(metadata, path, env) super(InstalledDistribution, self).__init__(metadata, path, env)
if env and env._cache_enabled: if env and env._cache_enabled:
env._cache.add(self) env._cache.add(self)
r = finder.find('REQUESTED') r = finder.find("REQUESTED")
self.requested = r is not None self.requested = r is not None
p = os.path.join(path, 'top_level.txt') p = os.path.join(path, "top_level.txt")
if os.path.exists(p): if os.path.exists(p):
with open(p, 'rb') as f: with open(p, "rb") as f:
data = f.read().decode('utf-8') data = f.read().decode("utf-8")
self.modules = data.splitlines() self.modules = data.splitlines()
def __repr__(self): def __repr__(self):
return '<InstalledDistribution %r %s at %r>' % (self.name, self.version, self.path) return "<InstalledDistribution %r %s at %r>" % (
self.name,
self.version,
self.path,
)
def __str__(self): def __str__(self):
return "%s %s" % (self.name, self.version) return "%s %s" % (self.name, self.version)
@ -577,7 +618,7 @@ class InstalledDistribution(BaseInstalledDistribution):
as stored in the file (which is as in PEP 376). as stored in the file (which is as in PEP 376).
""" """
results = [] results = []
r = self.get_distinfo_resource('RECORD') r = self.get_distinfo_resource("RECORD")
with contextlib.closing(r.as_stream()) as stream: with contextlib.closing(r.as_stream()) as stream:
with CSVReader(stream=stream) as record_reader: with CSVReader(stream=stream) as record_reader:
# Base location is parent dir of .dist-info dir # Base location is parent dir of .dist-info dir
@ -629,7 +670,7 @@ class InstalledDistribution(BaseInstalledDistribution):
individual export entries. individual export entries.
""" """
rf = self.get_distinfo_file(EXPORTS_FILENAME) rf = self.get_distinfo_file(EXPORTS_FILENAME)
with open(rf, 'w') as f: with open(rf, "w") as f:
write_exports(exports, f) write_exports(exports, f)
def get_resource_path(self, relative_path): def get_resource_path(self, relative_path):
@ -643,14 +684,15 @@ class InstalledDistribution(BaseInstalledDistribution):
of interest. of interest.
:return: The absolute path where the resource is to be found. :return: The absolute path where the resource is to be found.
""" """
r = self.get_distinfo_resource('RESOURCES') r = self.get_distinfo_resource("RESOURCES")
with contextlib.closing(r.as_stream()) as stream: with contextlib.closing(r.as_stream()) as stream:
with CSVReader(stream=stream) as resources_reader: with CSVReader(stream=stream) as resources_reader:
for relative, destination in resources_reader: for relative, destination in resources_reader:
if relative == relative_path: if relative == relative_path:
return destination return destination
raise KeyError('no resource file with relative path %r ' raise KeyError(
'is installed' % relative_path) "no resource file with relative path %r " "is installed" % relative_path
)
def list_installed_files(self): def list_installed_files(self):
""" """
@ -669,31 +711,33 @@ class InstalledDistribution(BaseInstalledDistribution):
prefix is used to determine when to write absolute paths. prefix is used to determine when to write absolute paths.
""" """
prefix = os.path.join(prefix, '') prefix = os.path.join(prefix, "")
base = os.path.dirname(self.path) base = os.path.dirname(self.path)
base_under_prefix = base.startswith(prefix) base_under_prefix = base.startswith(prefix)
base = os.path.join(base, '') base = os.path.join(base, "")
record_path = self.get_distinfo_file('RECORD') record_path = self.get_distinfo_file("RECORD")
logger.info('creating %s', record_path) logger.info("creating %s", record_path)
if dry_run: if dry_run:
return None return None
with CSVWriter(record_path) as writer: with CSVWriter(record_path) as writer:
for path in paths: for path in paths:
if os.path.isdir(path) or path.endswith(('.pyc', '.pyo')): if os.path.isdir(path) or path.endswith((".pyc", ".pyo")):
# do not put size and hash, as in PEP-376 # do not put size and hash, as in PEP-376
hash_value = size = '' hash_value = size = ""
else: else:
size = '%d' % os.path.getsize(path) size = "%d" % os.path.getsize(path)
with open(path, 'rb') as fp: with open(path, "rb") as fp:
hash_value = self.get_hash(fp.read()) hash_value = self.get_hash(fp.read())
if path.startswith(base) or (base_under_prefix and path.startswith(prefix)): if path.startswith(base) or (
base_under_prefix and path.startswith(prefix)
):
path = os.path.relpath(path, base) path = os.path.relpath(path, base)
writer.writerow((path, hash_value, size)) writer.writerow((path, hash_value, size))
# add the RECORD file itself # add the RECORD file itself
if record_path.startswith(base): if record_path.startswith(base):
record_path = os.path.relpath(record_path, base) record_path = os.path.relpath(record_path, base)
writer.writerow((record_path, '', '')) writer.writerow((record_path, "", ""))
return record_path return record_path
def check_installed_files(self): def check_installed_files(self):
@ -707,28 +751,28 @@ class InstalledDistribution(BaseInstalledDistribution):
""" """
mismatches = [] mismatches = []
base = os.path.dirname(self.path) base = os.path.dirname(self.path)
record_path = self.get_distinfo_file('RECORD') record_path = self.get_distinfo_file("RECORD")
for path, hash_value, size in self.list_installed_files(): for path, hash_value, size in self.list_installed_files():
if not os.path.isabs(path): if not os.path.isabs(path):
path = os.path.join(base, path) path = os.path.join(base, path)
if path == record_path: if path == record_path:
continue continue
if not os.path.exists(path): if not os.path.exists(path):
mismatches.append((path, 'exists', True, False)) mismatches.append((path, "exists", True, False))
elif os.path.isfile(path): elif os.path.isfile(path):
actual_size = str(os.path.getsize(path)) actual_size = str(os.path.getsize(path))
if size and actual_size != size: if size and actual_size != size:
mismatches.append((path, 'size', size, actual_size)) mismatches.append((path, "size", size, actual_size))
elif hash_value: elif hash_value:
if '=' in hash_value: if "=" in hash_value:
hasher = hash_value.split('=', 1)[0] hasher = hash_value.split("=", 1)[0]
else: else:
hasher = None hasher = None
with open(path, 'rb') as f: with open(path, "rb") as f:
actual_hash = self.get_hash(f.read(), hasher) actual_hash = self.get_hash(f.read(), hasher)
if actual_hash != hash_value: if actual_hash != hash_value:
mismatches.append((path, 'hash', hash_value, actual_hash)) mismatches.append((path, "hash", hash_value, actual_hash))
return mismatches return mismatches
@cached_property @cached_property
@ -746,13 +790,13 @@ class InstalledDistribution(BaseInstalledDistribution):
read from the SHARED file in the .dist-info directory. read from the SHARED file in the .dist-info directory.
""" """
result = {} result = {}
shared_path = os.path.join(self.path, 'SHARED') shared_path = os.path.join(self.path, "SHARED")
if os.path.isfile(shared_path): if os.path.isfile(shared_path):
with codecs.open(shared_path, 'r', encoding='utf-8') as f: with codecs.open(shared_path, "r", encoding="utf-8") as f:
lines = f.read().splitlines() lines = f.read().splitlines()
for line in lines: for line in lines:
key, value = line.split('=', 1) key, value = line.split("=", 1)
if key == 'namespace': if key == "namespace":
result.setdefault(key, []).append(value) result.setdefault(key, []).append(value)
else: else:
result[key] = value result[key] = value
@ -767,29 +811,30 @@ class InstalledDistribution(BaseInstalledDistribution):
written. written.
:return: The path of the file written to. :return: The path of the file written to.
""" """
shared_path = os.path.join(self.path, 'SHARED') shared_path = os.path.join(self.path, "SHARED")
logger.info('creating %s', shared_path) logger.info("creating %s", shared_path)
if dry_run: if dry_run:
return None return None
lines = [] lines = []
for key in ('prefix', 'lib', 'headers', 'scripts', 'data'): for key in ("prefix", "lib", "headers", "scripts", "data"):
path = paths[key] path = paths[key]
if os.path.isdir(paths[key]): if os.path.isdir(paths[key]):
lines.append('%s=%s' % (key, path)) lines.append("%s=%s" % (key, path))
for ns in paths.get('namespace', ()): for ns in paths.get("namespace", ()):
lines.append('namespace=%s' % ns) lines.append("namespace=%s" % ns)
with codecs.open(shared_path, 'w', encoding='utf-8') as f: with codecs.open(shared_path, "w", encoding="utf-8") as f:
f.write('\n'.join(lines)) f.write("\n".join(lines))
return shared_path return shared_path
def get_distinfo_resource(self, path): def get_distinfo_resource(self, path):
if path not in DIST_FILES: if path not in DIST_FILES:
raise DistlibException('invalid path for a dist-info file: ' raise DistlibException(
'%r at %r' % (path, self.path)) "invalid path for a dist-info file: " "%r at %r" % (path, self.path)
)
finder = resources.finder_for_path(self.path) finder = resources.finder_for_path(self.path)
if finder is None: if finder is None:
raise DistlibException('Unable to get a finder for %s' % self.path) raise DistlibException("Unable to get a finder for %s" % self.path)
return finder.find(path) return finder.find(path)
def get_distinfo_file(self, path): def get_distinfo_file(self, path):
@ -810,13 +855,16 @@ class InstalledDistribution(BaseInstalledDistribution):
# it's an absolute path? # it's an absolute path?
distinfo_dirname, path = path.split(os.sep)[-2:] distinfo_dirname, path = path.split(os.sep)[-2:]
if distinfo_dirname != self.path.split(os.sep)[-1]: if distinfo_dirname != self.path.split(os.sep)[-1]:
raise DistlibException('dist-info file %r does not belong to the %r %s ' raise DistlibException(
'distribution' % (path, self.name, self.version)) "dist-info file %r does not belong to the %r %s "
"distribution" % (path, self.name, self.version)
)
# The file must be relative # The file must be relative
if path not in DIST_FILES: if path not in DIST_FILES:
raise DistlibException('invalid path for a dist-info file: ' raise DistlibException(
'%r at %r' % (path, self.path)) "invalid path for a dist-info file: " "%r at %r" % (path, self.path)
)
return os.path.join(self.path, path) return os.path.join(self.path, path)
@ -837,7 +885,7 @@ class InstalledDistribution(BaseInstalledDistribution):
yield path yield path
def __eq__(self, other): def __eq__(self, other):
return (isinstance(other, InstalledDistribution) and self.path == other.path) return isinstance(other, InstalledDistribution) and self.path == other.path
# See http://docs.python.org/reference/datamodel#object.__hash__ # See http://docs.python.org/reference/datamodel#object.__hash__
__hash__ = object.__hash__ __hash__ = object.__hash__
@ -889,21 +937,24 @@ class EggInfoDistribution(BaseInstalledDistribution):
# sectioned files have bare newlines (separating sections) # sectioned files have bare newlines (separating sections)
if not line: # pragma: no cover if not line: # pragma: no cover
continue continue
if line.startswith('['): # pragma: no cover if line.startswith("["): # pragma: no cover
logger.warning('Unexpected line: quitting requirement scan: %r', line) logger.warning(
"Unexpected line: quitting requirement scan: %r", line
)
break break
r = parse_requirement(line) r = parse_requirement(line)
if not r: # pragma: no cover if not r: # pragma: no cover
logger.warning('Not recognised as a requirement: %r', line) logger.warning("Not recognised as a requirement: %r", line)
continue continue
if r.extras: # pragma: no cover if r.extras: # pragma: no cover
logger.warning('extra requirements in requires.txt are ' logger.warning(
'not supported') "extra requirements in requires.txt are " "not supported"
)
if not r.constraints: if not r.constraints:
reqs.append(r.name) reqs.append(r.name)
else: else:
cons = ', '.join('%s%s' % c for c in r.constraints) cons = ", ".join("%s%s" % c for c in r.constraints)
reqs.append('%s (%s)' % (r.name, cons)) reqs.append("%s (%s)" % (r.name, cons))
return reqs return reqs
def parse_requires_path(req_path): def parse_requires_path(req_path):
@ -914,50 +965,51 @@ class EggInfoDistribution(BaseInstalledDistribution):
reqs = [] reqs = []
try: try:
with codecs.open(req_path, 'r', 'utf-8') as fp: with codecs.open(req_path, "r", "utf-8") as fp:
reqs = parse_requires_data(fp.read()) reqs = parse_requires_data(fp.read())
except IOError: except IOError:
pass pass
return reqs return reqs
tl_path = tl_data = None tl_path = tl_data = None
if path.endswith('.egg'): if path.endswith(".egg"):
if os.path.isdir(path): if os.path.isdir(path):
p = os.path.join(path, 'EGG-INFO') p = os.path.join(path, "EGG-INFO")
meta_path = os.path.join(p, 'PKG-INFO') meta_path = os.path.join(p, "PKG-INFO")
metadata = Metadata(path=meta_path, scheme='legacy') metadata = Metadata(path=meta_path, scheme="legacy")
req_path = os.path.join(p, 'requires.txt') req_path = os.path.join(p, "requires.txt")
tl_path = os.path.join(p, 'top_level.txt') tl_path = os.path.join(p, "top_level.txt")
requires = parse_requires_path(req_path) requires = parse_requires_path(req_path)
else: else:
# FIXME handle the case where zipfile is not available # FIXME handle the case where zipfile is not available
zipf = zipimport.zipimporter(path) zipf = zipimport.zipimporter(path)
fileobj = StringIO(zipf.get_data('EGG-INFO/PKG-INFO').decode('utf8')) fileobj = StringIO(zipf.get_data("EGG-INFO/PKG-INFO").decode("utf8"))
metadata = Metadata(fileobj=fileobj, scheme='legacy') metadata = Metadata(fileobj=fileobj, scheme="legacy")
try: try:
data = zipf.get_data('EGG-INFO/requires.txt') data = zipf.get_data("EGG-INFO/requires.txt")
tl_data = zipf.get_data('EGG-INFO/top_level.txt').decode('utf-8') tl_data = zipf.get_data("EGG-INFO/top_level.txt").decode("utf-8")
requires = parse_requires_data(data.decode('utf-8')) requires = parse_requires_data(data.decode("utf-8"))
except IOError: except IOError:
requires = None requires = None
elif path.endswith('.egg-info'): elif path.endswith(".egg-info"):
if os.path.isdir(path): if os.path.isdir(path):
req_path = os.path.join(path, 'requires.txt') req_path = os.path.join(path, "requires.txt")
requires = parse_requires_path(req_path) requires = parse_requires_path(req_path)
path = os.path.join(path, 'PKG-INFO') path = os.path.join(path, "PKG-INFO")
tl_path = os.path.join(path, 'top_level.txt') tl_path = os.path.join(path, "top_level.txt")
metadata = Metadata(path=path, scheme='legacy') metadata = Metadata(path=path, scheme="legacy")
else: else:
raise DistlibException('path must end with .egg-info or .egg, ' raise DistlibException(
'got %r' % path) "path must end with .egg-info or .egg, " "got %r" % path
)
if requires: if requires:
metadata.add_requirements(requires) metadata.add_requirements(requires)
# look for top-level modules in top_level.txt, if present # look for top-level modules in top_level.txt, if present
if tl_data is None: if tl_data is None:
if tl_path is not None and os.path.exists(tl_path): if tl_path is not None and os.path.exists(tl_path):
with open(tl_path, 'rb') as f: with open(tl_path, "rb") as f:
tl_data = f.read().decode('utf-8') tl_data = f.read().decode("utf-8")
if not tl_data: if not tl_data:
tl_data = [] tl_data = []
else: else:
@ -966,7 +1018,11 @@ class EggInfoDistribution(BaseInstalledDistribution):
return metadata return metadata
def __repr__(self): def __repr__(self):
return '<EggInfoDistribution %r %s at %r>' % (self.name, self.version, self.path) return "<EggInfoDistribution %r %s at %r>" % (
self.name,
self.version,
self.path,
)
def __str__(self): def __str__(self):
return "%s %s" % (self.name, self.version) return "%s %s" % (self.name, self.version)
@ -981,13 +1037,13 @@ class EggInfoDistribution(BaseInstalledDistribution):
value and the actual value. value and the actual value.
""" """
mismatches = [] mismatches = []
record_path = os.path.join(self.path, 'installed-files.txt') record_path = os.path.join(self.path, "installed-files.txt")
if os.path.exists(record_path): if os.path.exists(record_path):
for path, _, _ in self.list_installed_files(): for path, _, _ in self.list_installed_files():
if path == record_path: if path == record_path:
continue continue
if not os.path.exists(path): if not os.path.exists(path):
mismatches.append((path, 'exists', True, False)) mismatches.append((path, "exists", True, False))
return mismatches return mismatches
def list_installed_files(self): def list_installed_files(self):
@ -999,7 +1055,7 @@ class EggInfoDistribution(BaseInstalledDistribution):
""" """
def _md5(path): def _md5(path):
f = open(path, 'rb') f = open(path, "rb")
try: try:
content = f.read() content = f.read()
finally: finally:
@ -1009,18 +1065,18 @@ class EggInfoDistribution(BaseInstalledDistribution):
def _size(path): def _size(path):
return os.stat(path).st_size return os.stat(path).st_size
record_path = os.path.join(self.path, 'installed-files.txt') record_path = os.path.join(self.path, "installed-files.txt")
result = [] result = []
if os.path.exists(record_path): if os.path.exists(record_path):
with codecs.open(record_path, 'r', encoding='utf-8') as f: with codecs.open(record_path, "r", encoding="utf-8") as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
p = os.path.normpath(os.path.join(self.path, line)) p = os.path.normpath(os.path.join(self.path, line))
# "./" is present as a marker between installed files # "./" is present as a marker between installed files
# and installation metadata files # and installation metadata files
if not os.path.exists(p): if not os.path.exists(p):
logger.warning('Non-existent file: %s', p) logger.warning("Non-existent file: %s", p)
if p.endswith(('.pyc', '.pyo')): if p.endswith((".pyc", ".pyo")):
continue continue
# otherwise fall through and fail # otherwise fall through and fail
if not os.path.isdir(p): if not os.path.isdir(p):
@ -1040,13 +1096,13 @@ class EggInfoDistribution(BaseInstalledDistribution):
:type absolute: boolean :type absolute: boolean
:returns: iterator of paths :returns: iterator of paths
""" """
record_path = os.path.join(self.path, 'installed-files.txt') record_path = os.path.join(self.path, "installed-files.txt")
if os.path.exists(record_path): if os.path.exists(record_path):
skip = True skip = True
with codecs.open(record_path, 'r', encoding='utf-8') as f: with codecs.open(record_path, "r", encoding="utf-8") as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if line == './': if line == "./":
skip = False skip = False
continue continue
if not skip: if not skip:
@ -1058,7 +1114,7 @@ class EggInfoDistribution(BaseInstalledDistribution):
yield line yield line
def __eq__(self, other): def __eq__(self, other):
return (isinstance(other, EggInfoDistribution) and self.path == other.path) return isinstance(other, EggInfoDistribution) and self.path == other.path
# See http://docs.python.org/reference/datamodel#object.__hash__ # See http://docs.python.org/reference/datamodel#object.__hash__
__hash__ = object.__hash__ __hash__ = object.__hash__
@ -1122,11 +1178,11 @@ class DependencyGraph(object):
or :class:`distutils2.database.EggInfoDistribution` or :class:`distutils2.database.EggInfoDistribution`
:type requirement: ``str`` :type requirement: ``str``
""" """
logger.debug('%s missing %r', distribution, requirement) logger.debug("%s missing %r", distribution, requirement)
self.missing.setdefault(distribution, []).append(requirement) self.missing.setdefault(distribution, []).append(requirement)
def _repr_dist(self, dist): def _repr_dist(self, dist):
return '%s %s' % (dist.name, dist.version) return "%s %s" % (dist.name, dist.version)
def repr_node(self, dist, level=1): def repr_node(self, dist, level=1):
"""Prints only a subgraph""" """Prints only a subgraph"""
@ -1134,12 +1190,12 @@ class DependencyGraph(object):
for other, label in self.adjacency_list[dist]: for other, label in self.adjacency_list[dist]:
dist = self._repr_dist(other) dist = self._repr_dist(other)
if label is not None: if label is not None:
dist = '%s [%s]' % (dist, label) dist = "%s [%s]" % (dist, label)
output.append(' ' * level + str(dist)) output.append(" " * level + str(dist))
suboutput = self.repr_node(other, level + 1) suboutput = self.repr_node(other, level + 1)
subs = suboutput.split('\n') subs = suboutput.split("\n")
output.extend(subs[1:]) output.extend(subs[1:])
return '\n'.join(output) return "\n".join(output)
def to_dot(self, f, skip_disconnected=True): def to_dot(self, f, skip_disconnected=True):
"""Writes a DOT output for the graph to the provided file *f*. """Writes a DOT output for the graph to the provided file *f*.
@ -1158,19 +1214,21 @@ class DependencyGraph(object):
disconnected.append(dist) disconnected.append(dist)
for other, label in adjs: for other, label in adjs:
if label is not None: if label is not None:
f.write('"%s" -> "%s" [label="%s"]\n' % (dist.name, other.name, label)) f.write(
'"%s" -> "%s" [label="%s"]\n' % (dist.name, other.name, label)
)
else: else:
f.write('"%s" -> "%s"\n' % (dist.name, other.name)) f.write('"%s" -> "%s"\n' % (dist.name, other.name))
if not skip_disconnected and len(disconnected) > 0: if not skip_disconnected and len(disconnected) > 0:
f.write('subgraph disconnected {\n') f.write("subgraph disconnected {\n")
f.write('label = "Disconnected"\n') f.write('label = "Disconnected"\n')
f.write('bgcolor = red\n') f.write("bgcolor = red\n")
for dist in disconnected: for dist in disconnected:
f.write('"%s"' % dist.name) f.write('"%s"' % dist.name)
f.write('\n') f.write("\n")
f.write('}\n') f.write("}\n")
f.write('}\n') f.write("}\n")
def topological_sort(self): def topological_sort(self):
""" """
@ -1198,7 +1256,10 @@ class DependencyGraph(object):
# Remove from the adjacency list of others # Remove from the adjacency list of others
for k, v in alist.items(): for k, v in alist.items():
alist[k] = [(d, r) for d, r in v if d not in to_remove] alist[k] = [(d, r) for d, r in v if d not in to_remove]
logger.debug('Moving to result: %s', ['%s (%s)' % (d.name, d.version) for d in to_remove]) logger.debug(
"Moving to result: %s",
["%s (%s)" % (d.name, d.version) for d in to_remove],
)
result.extend(to_remove) result.extend(to_remove)
return result, list(alist.keys()) return result, list(alist.keys())
@ -1207,10 +1268,10 @@ class DependencyGraph(object):
output = [] output = []
for dist, adjs in self.adjacency_list.items(): for dist, adjs in self.adjacency_list.items():
output.append(self.repr_node(dist)) output.append(self.repr_node(dist))
return '\n'.join(output) return "\n".join(output)
def make_graph(dists, scheme='default'): def make_graph(dists, scheme="default"):
"""Makes a dependency graph from the given distributions. """Makes a dependency graph from the given distributions.
:parameter dists: a list of distributions :parameter dists: a list of distributions
@ -1228,18 +1289,23 @@ def make_graph(dists, scheme='default'):
for p in dist.provides: for p in dist.provides:
name, version = parse_name_and_version(p) name, version = parse_name_and_version(p)
logger.debug('Add to provided: %s, %s, %s', name, version, dist) logger.debug("Add to provided: %s, %s, %s", name, version, dist)
provided.setdefault(name, []).append((version, dist)) provided.setdefault(name, []).append((version, dist))
# now make the edges # now make the edges
for dist in dists: for dist in dists:
requires = (dist.run_requires | dist.meta_requires | dist.build_requires | dist.dev_requires) requires = (
dist.run_requires
| dist.meta_requires
| dist.build_requires
| dist.dev_requires
)
for req in requires: for req in requires:
try: try:
matcher = scheme.matcher(req) matcher = scheme.matcher(req)
except UnsupportedVersionError: except UnsupportedVersionError:
# XXX compat-mode if cannot read the version # XXX compat-mode if cannot read the version
logger.warning('could not read version %r - using name only', req) logger.warning("could not read version %r - using name only", req)
name = req.split()[0] name = req.split()[0]
matcher = scheme.matcher(name) matcher = scheme.matcher(name)
@ -1270,8 +1336,9 @@ def get_dependent_dists(dists, dist):
:param dist: a distribution, member of *dists* for which we are interested :param dist: a distribution, member of *dists* for which we are interested
""" """
if dist not in dists: if dist not in dists:
raise DistlibException('given distribution %r is not a member ' raise DistlibException(
'of the list' % dist.name) "given distribution %r is not a member " "of the list" % dist.name
)
graph = make_graph(dists) graph = make_graph(dists)
dep = [dist] # dependent distributions dep = [dist] # dependent distributions
@ -1297,8 +1364,9 @@ def get_required_dists(dists, dist):
in finding the dependencies. in finding the dependencies.
""" """
if dist not in dists: if dist not in dists:
raise DistlibException('given distribution %r is not a member ' raise DistlibException(
'of the list' % dist.name) "given distribution %r is not a member " "of the list" % dist.name
)
graph = make_graph(dists) graph = make_graph(dists)
req = set() # required distributions req = set() # required distributions
@ -1321,9 +1389,9 @@ def make_dist(name, version, **kwargs):
""" """
A convenience method for making a dist given just a name and version. A convenience method for making a dist given just a name and version.
""" """
summary = kwargs.pop('summary', 'Placeholder for summary') summary = kwargs.pop("summary", "Placeholder for summary")
md = Metadata(**kwargs) md = Metadata(**kwargs)
md.name = name md.name = name
md.version = version md.version = version
md.summary = summary or 'Placeholder for summary' md.summary = summary or "Placeholder for summary"
return Distribution(md) return Distribution(md)

View File

@ -10,20 +10,27 @@ import os
import shutil import shutil
import subprocess import subprocess
import tempfile import tempfile
try: try:
from threading import Thread from threading import Thread
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
from dummy_threading import Thread from dummy_threading import Thread
from . import DistlibException from . import DistlibException
from .compat import (HTTPBasicAuthHandler, Request, HTTPPasswordMgr, from .compat import (
urlparse, build_opener, string_types) HTTPBasicAuthHandler,
Request,
HTTPPasswordMgr,
urlparse,
build_opener,
string_types,
)
from .util import zip_dir, ServerProxy from .util import zip_dir, ServerProxy
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_INDEX = 'https://pypi.org/pypi' DEFAULT_INDEX = "https://pypi.org/pypi"
DEFAULT_REALM = 'pypi' DEFAULT_REALM = "pypi"
class PackageIndex(object): class PackageIndex(object):
@ -32,7 +39,7 @@ class PackageIndex(object):
Package Index. Package Index.
""" """
boundary = b'----------ThIs_Is_tHe_distlib_index_bouNdaRY_$' boundary = b"----------ThIs_Is_tHe_distlib_index_bouNdaRY_$"
def __init__(self, url=None): def __init__(self, url=None):
""" """
@ -44,19 +51,20 @@ class PackageIndex(object):
self.url = url or DEFAULT_INDEX self.url = url or DEFAULT_INDEX
self.read_configuration() self.read_configuration()
scheme, netloc, path, params, query, frag = urlparse(self.url) scheme, netloc, path, params, query, frag = urlparse(self.url)
if params or query or frag or scheme not in ('http', 'https'): if params or query or frag or scheme not in ("http", "https"):
raise DistlibException('invalid repository: %s' % self.url) raise DistlibException("invalid repository: %s" % self.url)
self.password_handler = None self.password_handler = None
self.ssl_verifier = None self.ssl_verifier = None
self.gpg = None self.gpg = None
self.gpg_home = None self.gpg_home = None
with open(os.devnull, 'w') as sink: with open(os.devnull, "w") as sink:
# Use gpg by default rather than gpg2, as gpg2 insists on # Use gpg by default rather than gpg2, as gpg2 insists on
# prompting for passwords # prompting for passwords
for s in ('gpg', 'gpg2'): for s in ("gpg", "gpg2"):
try: try:
rc = subprocess.check_call([s, '--version'], stdout=sink, rc = subprocess.check_call(
stderr=sink) [s, "--version"], stdout=sink, stderr=sink
)
if rc == 0: if rc == 0:
self.gpg = s self.gpg = s
break break
@ -69,6 +77,7 @@ class PackageIndex(object):
:return: the command. :return: the command.
""" """
from .util import _get_pypirc_command as cmd from .util import _get_pypirc_command as cmd
return cmd() return cmd()
def read_configuration(self): def read_configuration(self):
@ -78,11 +87,12 @@ class PackageIndex(object):
configuration. configuration.
""" """
from .util import _load_pypirc from .util import _load_pypirc
cfg = _load_pypirc(self) cfg = _load_pypirc(self)
self.username = cfg.get('username') self.username = cfg.get("username")
self.password = cfg.get('password') self.password = cfg.get("password")
self.realm = cfg.get('realm', 'pypi') self.realm = cfg.get("realm", "pypi")
self.url = cfg.get('repository', self.url) self.url = cfg.get("repository", self.url)
def save_configuration(self): def save_configuration(self):
""" """
@ -91,6 +101,7 @@ class PackageIndex(object):
""" """
self.check_credentials() self.check_credentials()
from .util import _store_pypirc from .util import _store_pypirc
_store_pypirc(self) _store_pypirc(self)
def check_credentials(self): def check_credentials(self):
@ -99,7 +110,7 @@ class PackageIndex(object):
exception if not. exception if not.
""" """
if self.username is None or self.password is None: if self.username is None or self.password is None:
raise DistlibException('username and password must be set') raise DistlibException("username and password must be set")
pm = HTTPPasswordMgr() pm = HTTPPasswordMgr()
_, netloc, _, _, _, _ = urlparse(self.url) _, netloc, _, _, _, _ = urlparse(self.url)
pm.add_password(self.realm, netloc, self.username, self.password) pm.add_password(self.realm, netloc, self.username, self.password)
@ -118,10 +129,10 @@ class PackageIndex(object):
self.check_credentials() self.check_credentials()
metadata.validate() metadata.validate()
d = metadata.todict() d = metadata.todict()
d[':action'] = 'verify' d[":action"] = "verify"
request = self.encode_request(d.items(), []) request = self.encode_request(d.items(), [])
self.send_request(request) self.send_request(request)
d[':action'] = 'submit' d[":action"] = "submit"
request = self.encode_request(d.items(), []) request = self.encode_request(d.items(), [])
return self.send_request(request) return self.send_request(request)
@ -138,12 +149,14 @@ class PackageIndex(object):
s = stream.readline() s = stream.readline()
if not s: if not s:
break break
s = s.decode('utf-8').rstrip() s = s.decode("utf-8").rstrip()
outbuf.append(s) outbuf.append(s)
logger.debug('%s: %s' % (name, s)) logger.debug("%s: %s" % (name, s))
stream.close() stream.close()
def get_sign_command(self, filename, signer, sign_password, keystore=None): # pragma: no cover def get_sign_command(
self, filename, signer, sign_password, keystore=None
): # pragma: no cover
""" """
Return a suitable command for signing a file. Return a suitable command for signing a file.
@ -157,18 +170,27 @@ class PackageIndex(object):
:return: The signing command as a list suitable to be :return: The signing command as a list suitable to be
passed to :class:`subprocess.Popen`. passed to :class:`subprocess.Popen`.
""" """
cmd = [self.gpg, '--status-fd', '2', '--no-tty'] cmd = [self.gpg, "--status-fd", "2", "--no-tty"]
if keystore is None: if keystore is None:
keystore = self.gpg_home keystore = self.gpg_home
if keystore: if keystore:
cmd.extend(['--homedir', keystore]) cmd.extend(["--homedir", keystore])
if sign_password is not None: if sign_password is not None:
cmd.extend(['--batch', '--passphrase-fd', '0']) cmd.extend(["--batch", "--passphrase-fd", "0"])
td = tempfile.mkdtemp() td = tempfile.mkdtemp()
sf = os.path.join(td, os.path.basename(filename) + '.asc') sf = os.path.join(td, os.path.basename(filename) + ".asc")
cmd.extend(['--detach-sign', '--armor', '--local-user', cmd.extend(
signer, '--output', sf, filename]) [
logger.debug('invoking: %s', ' '.join(cmd)) "--detach-sign",
"--armor",
"--local-user",
signer,
"--output",
sf,
filename,
]
)
logger.debug("invoking: %s", " ".join(cmd))
return cmd, sf return cmd, sf
def run_command(self, cmd, input_data=None): def run_command(self, cmd, input_data=None):
@ -183,19 +205,19 @@ class PackageIndex(object):
lines read from the subprocess' ``stderr``. lines read from the subprocess' ``stderr``.
""" """
kwargs = { kwargs = {
'stdout': subprocess.PIPE, "stdout": subprocess.PIPE,
'stderr': subprocess.PIPE, "stderr": subprocess.PIPE,
} }
if input_data is not None: if input_data is not None:
kwargs['stdin'] = subprocess.PIPE kwargs["stdin"] = subprocess.PIPE
stdout = [] stdout = []
stderr = [] stderr = []
p = subprocess.Popen(cmd, **kwargs) p = subprocess.Popen(cmd, **kwargs)
# We don't use communicate() here because we may need to # We don't use communicate() here because we may need to
# get clever with interacting with the command # get clever with interacting with the command
t1 = Thread(target=self._reader, args=('stdout', p.stdout, stdout)) t1 = Thread(target=self._reader, args=("stdout", p.stdout, stdout))
t1.start() t1.start()
t2 = Thread(target=self._reader, args=('stderr', p.stderr, stderr)) t2 = Thread(target=self._reader, args=("stderr", p.stderr, stderr))
t2.start() t2.start()
if input_data is not None: if input_data is not None:
p.stdin.write(input_data) p.stdin.write(input_data)
@ -206,7 +228,9 @@ class PackageIndex(object):
t2.join() t2.join()
return p.returncode, stdout, stderr return p.returncode, stdout, stderr
def sign_file(self, filename, signer, sign_password, keystore=None): # pragma: no cover def sign_file(
self, filename, signer, sign_password, keystore=None
): # pragma: no cover
""" """
Sign a file. Sign a file.
@ -220,17 +244,22 @@ class PackageIndex(object):
:return: The absolute pathname of the file where the signature is :return: The absolute pathname of the file where the signature is
stored. stored.
""" """
cmd, sig_file = self.get_sign_command(filename, signer, sign_password, cmd, sig_file = self.get_sign_command(filename, signer, sign_password, keystore)
keystore) rc, stdout, stderr = self.run_command(cmd, sign_password.encode("utf-8"))
rc, stdout, stderr = self.run_command(cmd,
sign_password.encode('utf-8'))
if rc != 0: if rc != 0:
raise DistlibException('sign command failed with error ' raise DistlibException("sign command failed with error " "code %s" % rc)
'code %s' % rc)
return sig_file return sig_file
def upload_file(self, metadata, filename, signer=None, sign_password=None, def upload_file(
filetype='sdist', pyversion='source', keystore=None): self,
metadata,
filename,
signer=None,
sign_password=None,
filetype="sdist",
pyversion="source",
keystore=None,
):
""" """
Upload a release file to the index. Upload a release file to the index.
@ -254,34 +283,34 @@ class PackageIndex(object):
""" """
self.check_credentials() self.check_credentials()
if not os.path.exists(filename): if not os.path.exists(filename):
raise DistlibException('not found: %s' % filename) raise DistlibException("not found: %s" % filename)
metadata.validate() metadata.validate()
d = metadata.todict() d = metadata.todict()
sig_file = None sig_file = None
if signer: if signer:
if not self.gpg: if not self.gpg:
logger.warning('no signing program available - not signed') logger.warning("no signing program available - not signed")
else: else:
sig_file = self.sign_file(filename, signer, sign_password, sig_file = self.sign_file(filename, signer, sign_password, keystore)
keystore) with open(filename, "rb") as f:
with open(filename, 'rb') as f:
file_data = f.read() file_data = f.read()
md5_digest = hashlib.md5(file_data).hexdigest() md5_digest = hashlib.md5(file_data).hexdigest()
sha256_digest = hashlib.sha256(file_data).hexdigest() sha256_digest = hashlib.sha256(file_data).hexdigest()
d.update({ d.update(
':action': 'file_upload', {
'protocol_version': '1', ":action": "file_upload",
'filetype': filetype, "protocol_version": "1",
'pyversion': pyversion, "filetype": filetype,
'md5_digest': md5_digest, "pyversion": pyversion,
'sha256_digest': sha256_digest, "md5_digest": md5_digest,
}) "sha256_digest": sha256_digest,
files = [('content', os.path.basename(filename), file_data)] }
)
files = [("content", os.path.basename(filename), file_data)]
if sig_file: if sig_file:
with open(sig_file, 'rb') as f: with open(sig_file, "rb") as f:
sig_data = f.read() sig_data = f.read()
files.append(('gpg_signature', os.path.basename(sig_file), files.append(("gpg_signature", os.path.basename(sig_file), sig_data))
sig_data))
shutil.rmtree(os.path.dirname(sig_file)) shutil.rmtree(os.path.dirname(sig_file))
request = self.encode_request(d.items(), files) request = self.encode_request(d.items(), files)
return self.send_request(request) return self.send_request(request)
@ -301,21 +330,19 @@ class PackageIndex(object):
""" """
self.check_credentials() self.check_credentials()
if not os.path.isdir(doc_dir): if not os.path.isdir(doc_dir):
raise DistlibException('not a directory: %r' % doc_dir) raise DistlibException("not a directory: %r" % doc_dir)
fn = os.path.join(doc_dir, 'index.html') fn = os.path.join(doc_dir, "index.html")
if not os.path.exists(fn): if not os.path.exists(fn):
raise DistlibException('not found: %r' % fn) raise DistlibException("not found: %r" % fn)
metadata.validate() metadata.validate()
name, version = metadata.name, metadata.version name, version = metadata.name, metadata.version
zip_data = zip_dir(doc_dir).getvalue() zip_data = zip_dir(doc_dir).getvalue()
fields = [(':action', 'doc_upload'), fields = [(":action", "doc_upload"), ("name", name), ("version", version)]
('name', name), ('version', version)] files = [("content", name, zip_data)]
files = [('content', name, zip_data)]
request = self.encode_request(fields, files) request = self.encode_request(fields, files)
return self.send_request(request) return self.send_request(request)
def get_verify_command(self, signature_filename, data_filename, def get_verify_command(self, signature_filename, data_filename, keystore=None):
keystore=None):
""" """
Return a suitable command for verifying a file. Return a suitable command for verifying a file.
@ -329,17 +356,16 @@ class PackageIndex(object):
:return: The verifying command as a list suitable to be :return: The verifying command as a list suitable to be
passed to :class:`subprocess.Popen`. passed to :class:`subprocess.Popen`.
""" """
cmd = [self.gpg, '--status-fd', '2', '--no-tty'] cmd = [self.gpg, "--status-fd", "2", "--no-tty"]
if keystore is None: if keystore is None:
keystore = self.gpg_home keystore = self.gpg_home
if keystore: if keystore:
cmd.extend(['--homedir', keystore]) cmd.extend(["--homedir", keystore])
cmd.extend(['--verify', signature_filename, data_filename]) cmd.extend(["--verify", signature_filename, data_filename])
logger.debug('invoking: %s', ' '.join(cmd)) logger.debug("invoking: %s", " ".join(cmd))
return cmd return cmd
def verify_signature(self, signature_filename, data_filename, def verify_signature(self, signature_filename, data_filename, keystore=None):
keystore=None):
""" """
Verify a signature for a file. Verify a signature for a file.
@ -353,13 +379,13 @@ class PackageIndex(object):
:return: True if the signature was verified, else False. :return: True if the signature was verified, else False.
""" """
if not self.gpg: if not self.gpg:
raise DistlibException('verification unavailable because gpg ' raise DistlibException(
'unavailable') "verification unavailable because gpg " "unavailable"
cmd = self.get_verify_command(signature_filename, data_filename, )
keystore) cmd = self.get_verify_command(signature_filename, data_filename, keystore)
rc, stdout, stderr = self.run_command(cmd) rc, stdout, stderr = self.run_command(cmd)
if rc not in (0, 1): if rc not in (0, 1):
raise DistlibException('verify command failed with error code %s' % rc) raise DistlibException("verify command failed with error code %s" % rc)
return rc == 0 return rc == 0
def download_file(self, url, destfile, digest=None, reporthook=None): def download_file(self, url, destfile, digest=None, reporthook=None):
@ -386,18 +412,18 @@ class PackageIndex(object):
""" """
if digest is None: if digest is None:
digester = None digester = None
logger.debug('No digest specified') logger.debug("No digest specified")
else: else:
if isinstance(digest, (list, tuple)): if isinstance(digest, (list, tuple)):
hasher, digest = digest hasher, digest = digest
else: else:
hasher = 'md5' hasher = "md5"
digester = getattr(hashlib, hasher)() digester = getattr(hashlib, hasher)()
logger.debug('Digest specified: %s' % digest) logger.debug("Digest specified: %s" % digest)
# The following code is equivalent to urlretrieve. # The following code is equivalent to urlretrieve.
# We need to do it this way so that we can compute the # We need to do it this way so that we can compute the
# digest of the file as we go. # digest of the file as we go.
with open(destfile, 'wb') as dfp: with open(destfile, "wb") as dfp:
# addinfourl is not a context manager on 2.x # addinfourl is not a context manager on 2.x
# so we have to use try/finally # so we have to use try/finally
sfp = self.send_request(Request(url)) sfp = self.send_request(Request(url))
@ -428,16 +454,17 @@ class PackageIndex(object):
# check that we got the whole file, if we can # check that we got the whole file, if we can
if size >= 0 and read < size: if size >= 0 and read < size:
raise DistlibException( raise DistlibException(
'retrieval incomplete: got only %d out of %d bytes' "retrieval incomplete: got only %d out of %d bytes" % (read, size)
% (read, size)) )
# if we have a digest, it must match. # if we have a digest, it must match.
if digester: if digester:
actual = digester.hexdigest() actual = digester.hexdigest()
if digest != actual: if digest != actual:
raise DistlibException('%s digest mismatch for %s: expected ' raise DistlibException(
'%s, got %s' % (hasher, destfile, "%s digest mismatch for %s: expected "
digest, actual)) "%s, got %s" % (hasher, destfile, digest, actual)
logger.debug('Digest verified: %s', digest) )
logger.debug("Digest verified: %s", digest)
def send_request(self, req): def send_request(self, req):
""" """
@ -474,35 +501,41 @@ class PackageIndex(object):
values = [values] values = [values]
for v in values: for v in values:
parts.extend(( parts.extend(
b'--' + boundary, (
('Content-Disposition: form-data; name="%s"' % b"--" + boundary,
k).encode('utf-8'), ('Content-Disposition: form-data; name="%s"' % k).encode(
b'', "utf-8"
v.encode('utf-8'))) ),
b"",
v.encode("utf-8"),
)
)
for key, filename, value in files: for key, filename, value in files:
parts.extend(( parts.extend(
b'--' + boundary, (
('Content-Disposition: form-data; name="%s"; filename="%s"' % b"--" + boundary,
(key, filename)).encode('utf-8'), (
b'', 'Content-Disposition: form-data; name="%s"; filename="%s"'
value)) % (key, filename)
).encode("utf-8"),
b"",
value,
)
)
parts.extend((b'--' + boundary + b'--', b'')) parts.extend((b"--" + boundary + b"--", b""))
body = b'\r\n'.join(parts) body = b"\r\n".join(parts)
ct = b'multipart/form-data; boundary=' + boundary ct = b"multipart/form-data; boundary=" + boundary
headers = { headers = {"Content-type": ct, "Content-length": str(len(body))}
'Content-type': ct,
'Content-length': str(len(body))
}
return Request(self.url, body, headers) return Request(self.url, body, headers)
def search(self, terms, operator=None): # pragma: no cover def search(self, terms, operator=None): # pragma: no cover
if isinstance(terms, string_types): if isinstance(terms, string_types):
terms = {'name': terms} terms = {"name": terms}
rpc_proxy = ServerProxy(self.url, timeout=3.0) rpc_proxy = ServerProxy(self.url, timeout=3.0)
try: try:
return rpc_proxy.search(terms, operator or 'and') return rpc_proxy.search(terms, operator or "and")
finally: finally:
rpc_proxy('close')() rpc_proxy("close")()

View File

@ -12,6 +12,7 @@ import logging
import os import os
import posixpath import posixpath
import re import re
try: try:
import threading import threading
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
@ -19,21 +20,43 @@ except ImportError: # pragma: no cover
import zlib import zlib
from . import DistlibException from . import DistlibException
from .compat import (urljoin, urlparse, urlunparse, url2pathname, pathname2url, queue, quote, unescape, build_opener, from .compat import (
HTTPRedirectHandler as BaseRedirectHandler, text_type, Request, HTTPError, URLError) urljoin,
urlparse,
urlunparse,
url2pathname,
pathname2url,
queue,
quote,
unescape,
build_opener,
HTTPRedirectHandler as BaseRedirectHandler,
text_type,
Request,
HTTPError,
URLError,
)
from .database import Distribution, DistributionPath, make_dist from .database import Distribution, DistributionPath, make_dist
from .metadata import Metadata, MetadataInvalidError from .metadata import Metadata, MetadataInvalidError
from .util import (cached_property, ensure_slash, split_filename, get_project_data, parse_requirement, from .util import (
parse_name_and_version, ServerProxy, normalize_name) cached_property,
ensure_slash,
split_filename,
get_project_data,
parse_requirement,
parse_name_and_version,
ServerProxy,
normalize_name,
)
from .version import get_scheme, UnsupportedVersionError from .version import get_scheme, UnsupportedVersionError
from .wheel import Wheel, is_compatible from .wheel import Wheel, is_compatible
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HASHER_HASH = re.compile(r'^(\w+)=([a-f0-9]+)') HASHER_HASH = re.compile(r"^(\w+)=([a-f0-9]+)")
CHARSET = re.compile(r';\s*charset\s*=\s*(.*)\s*$', re.I) CHARSET = re.compile(r";\s*charset\s*=\s*(.*)\s*$", re.I)
HTML_CONTENT_TYPE = re.compile('text/html|application/x(ht)?ml') HTML_CONTENT_TYPE = re.compile("text/html|application/x(ht)?ml")
DEFAULT_INDEX = 'https://pypi.org/pypi' DEFAULT_INDEX = "https://pypi.org/pypi"
def get_all_distribution_names(url=None): def get_all_distribution_names(url=None):
@ -48,7 +71,7 @@ def get_all_distribution_names(url=None):
try: try:
return client.list_packages() return client.list_packages()
finally: finally:
client('close')() client("close")()
class RedirectHandler(BaseRedirectHandler): class RedirectHandler(BaseRedirectHandler):
@ -65,16 +88,16 @@ class RedirectHandler(BaseRedirectHandler):
# Some servers (incorrectly) return multiple Location headers # Some servers (incorrectly) return multiple Location headers
# (so probably same goes for URI). Use first header. # (so probably same goes for URI). Use first header.
newurl = None newurl = None
for key in ('location', 'uri'): for key in ("location", "uri"):
if key in headers: if key in headers:
newurl = headers[key] newurl = headers[key]
break break
if newurl is None: # pragma: no cover if newurl is None: # pragma: no cover
return return
urlparts = urlparse(newurl) urlparts = urlparse(newurl)
if urlparts.scheme == '': if urlparts.scheme == "":
newurl = urljoin(req.get_full_url(), newurl) newurl = urljoin(req.get_full_url(), newurl)
if hasattr(headers, 'replace_header'): if hasattr(headers, "replace_header"):
headers.replace_header(key, newurl) headers.replace_header(key, newurl)
else: else:
headers[key] = newurl headers[key] = newurl
@ -87,9 +110,10 @@ class Locator(object):
""" """
A base class for locators - things that locate distributions. A base class for locators - things that locate distributions.
""" """
source_extensions = ('.tar.gz', '.tar.bz2', '.tar', '.zip', '.tgz', '.tbz')
binary_extensions = ('.egg', '.exe', '.whl') source_extensions = (".tar.gz", ".tar.bz2", ".tar", ".zip", ".tgz", ".tbz")
excluded_extensions = ('.pdf', ) binary_extensions = (".egg", ".exe", ".whl")
excluded_extensions = (".pdf",)
# A list of tags indicating which wheels you want to match. The default # A list of tags indicating which wheels you want to match. The default
# value of None matches against the tags compatible with the running # value of None matches against the tags compatible with the running
@ -97,9 +121,9 @@ class Locator(object):
# instance to a list of tuples (pyver, abi, arch) which you want to match. # instance to a list of tuples (pyver, abi, arch) which you want to match.
wheel_tags = None wheel_tags = None
downloadable_extensions = source_extensions + ('.whl', ) downloadable_extensions = source_extensions + (".whl",)
def __init__(self, scheme='default'): def __init__(self, scheme="default"):
""" """
Initialise an instance. Initialise an instance.
:param scheme: Because locators look for most recent versions, they :param scheme: Because locators look for most recent versions, they
@ -160,13 +184,13 @@ class Locator(object):
If called from a locate() request, self.matcher will be set to a If called from a locate() request, self.matcher will be set to a
matcher for the requirement to satisfy, otherwise it will be None. matcher for the requirement to satisfy, otherwise it will be None.
""" """
raise NotImplementedError('Please implement in the subclass') raise NotImplementedError("Please implement in the subclass")
def get_distribution_names(self): def get_distribution_names(self):
""" """
Return all the distribution names known to this locator. Return all the distribution names known to this locator.
""" """
raise NotImplementedError('Please implement in the subclass') raise NotImplementedError("Please implement in the subclass")
def get_project(self, name): def get_project(self, name):
""" """
@ -193,11 +217,18 @@ class Locator(object):
t = urlparse(url) t = urlparse(url)
basename = posixpath.basename(t.path) basename = posixpath.basename(t.path)
compatible = True compatible = True
is_wheel = basename.endswith('.whl') is_wheel = basename.endswith(".whl")
is_downloadable = basename.endswith(self.downloadable_extensions) is_downloadable = basename.endswith(self.downloadable_extensions)
if is_wheel: if is_wheel:
compatible = is_compatible(Wheel(basename), self.wheel_tags) compatible = is_compatible(Wheel(basename), self.wheel_tags)
return (t.scheme == 'https', 'pypi.org' in t.netloc, is_downloadable, is_wheel, compatible, basename) return (
t.scheme == "https",
"pypi.org" in t.netloc,
is_downloadable,
is_wheel,
compatible,
basename,
)
def prefer_url(self, url1, url2): def prefer_url(self, url1, url2):
""" """
@ -216,9 +247,9 @@ class Locator(object):
if s1 > s2: if s1 > s2:
result = url1 result = url1
if result != url2: if result != url2:
logger.debug('Not replacing %r with %r', url1, url2) logger.debug("Not replacing %r with %r", url1, url2)
else: else:
logger.debug('Replacing %r with %r', url1, url2) logger.debug("Replacing %r with %r", url1, url2)
return result return result
def split_filename(self, filename, project_name): def split_filename(self, filename, project_name):
@ -241,21 +272,21 @@ class Locator(object):
result = None result = None
scheme, netloc, path, params, query, frag = urlparse(url) scheme, netloc, path, params, query, frag = urlparse(url)
if frag.lower().startswith('egg='): # pragma: no cover if frag.lower().startswith("egg="): # pragma: no cover
logger.debug('%s: version hint in fragment: %r', project_name, frag) logger.debug("%s: version hint in fragment: %r", project_name, frag)
m = HASHER_HASH.match(frag) m = HASHER_HASH.match(frag)
if m: if m:
algo, digest = m.groups() algo, digest = m.groups()
else: else:
algo, digest = None, None algo, digest = None, None
origpath = path origpath = path
if path and path[-1] == '/': # pragma: no cover if path and path[-1] == "/": # pragma: no cover
path = path[:-1] path = path[:-1]
if path.endswith('.whl'): if path.endswith(".whl"):
try: try:
wheel = Wheel(path) wheel = Wheel(path)
if not is_compatible(wheel, self.wheel_tags): if not is_compatible(wheel, self.wheel_tags):
logger.debug('Wheel not compatible: %s', path) logger.debug("Wheel not compatible: %s", path)
else: else:
if project_name is None: if project_name is None:
include = True include = True
@ -263,38 +294,44 @@ class Locator(object):
include = same_project(wheel.name, project_name) include = same_project(wheel.name, project_name)
if include: if include:
result = { result = {
'name': wheel.name, "name": wheel.name,
'version': wheel.version, "version": wheel.version,
'filename': wheel.filename, "filename": wheel.filename,
'url': urlunparse((scheme, netloc, origpath, params, query, '')), "url": urlunparse(
'python-version': ', '.join(['.'.join(list(v[2:])) for v in wheel.pyver]), (scheme, netloc, origpath, params, query, "")
),
"python-version": ", ".join(
[".".join(list(v[2:])) for v in wheel.pyver]
),
} }
except Exception: # pragma: no cover except Exception: # pragma: no cover
logger.warning('invalid path for wheel: %s', path) logger.warning("invalid path for wheel: %s", path)
elif not path.endswith(self.downloadable_extensions): # pragma: no cover elif not path.endswith(self.downloadable_extensions): # pragma: no cover
logger.debug('Not downloadable: %s', path) logger.debug("Not downloadable: %s", path)
else: # downloadable extension else: # downloadable extension
path = filename = posixpath.basename(path) path = filename = posixpath.basename(path)
for ext in self.downloadable_extensions: for ext in self.downloadable_extensions:
if path.endswith(ext): if path.endswith(ext):
path = path[:-len(ext)] path = path[: -len(ext)]
t = self.split_filename(path, project_name) t = self.split_filename(path, project_name)
if not t: # pragma: no cover if not t: # pragma: no cover
logger.debug('No match for project/version: %s', path) logger.debug("No match for project/version: %s", path)
else: else:
name, version, pyver = t name, version, pyver = t
if not project_name or same_project(project_name, name): if not project_name or same_project(project_name, name):
result = { result = {
'name': name, "name": name,
'version': version, "version": version,
'filename': filename, "filename": filename,
'url': urlunparse((scheme, netloc, origpath, params, query, '')), "url": urlunparse(
(scheme, netloc, origpath, params, query, "")
),
} }
if pyver: # pragma: no cover if pyver: # pragma: no cover
result['python-version'] = pyver result["python-version"] = pyver
break break
if result and algo: if result and algo:
result['%s_digest' % algo] = digest result["%s_digest" % algo] = digest
return result return result
def _get_digest(self, info): def _get_digest(self, info):
@ -306,15 +343,15 @@ class Locator(object):
looks only for SHA256, then MD5. looks only for SHA256, then MD5.
""" """
result = None result = None
if 'digests' in info: if "digests" in info:
digests = info['digests'] digests = info["digests"]
for algo in ('sha256', 'md5'): for algo in ("sha256", "md5"):
if algo in digests: if algo in digests:
result = (algo, digests[algo]) result = (algo, digests[algo])
break break
if not result: if not result:
for algo in ('sha256', 'md5'): for algo in ("sha256", "md5"):
key = '%s_digest' % algo key = "%s_digest" % algo
if key in info: if key in info:
result = (algo, info[key]) result = (algo, info[key])
break break
@ -326,8 +363,8 @@ class Locator(object):
dictionary for a specific version, which typically holds information dictionary for a specific version, which typically holds information
gleaned from a filename or URL for an archive for the distribution. gleaned from a filename or URL for an archive for the distribution.
""" """
name = info.pop('name') name = info.pop("name")
version = info.pop('version') version = info.pop("version")
if version in result: if version in result:
dist = result[version] dist = result[version]
md = dist.metadata md = dist.metadata
@ -335,11 +372,11 @@ class Locator(object):
dist = make_dist(name, version, scheme=self.scheme) dist = make_dist(name, version, scheme=self.scheme)
md = dist.metadata md = dist.metadata
dist.digest = digest = self._get_digest(info) dist.digest = digest = self._get_digest(info)
url = info['url'] url = info["url"]
result['digests'][url] = digest result["digests"][url] = digest
if md.source_url != info['url']: if md.source_url != info["url"]:
md.source_url = self.prefer_url(md.source_url, url) md.source_url = self.prefer_url(md.source_url, url)
result['urls'].setdefault(version, set()).add(url) result["urls"].setdefault(version, set()).add(url)
dist.locator = self dist.locator = self
result[version] = dist result[version] = dist
@ -359,17 +396,17 @@ class Locator(object):
result = None result = None
r = parse_requirement(requirement) r = parse_requirement(requirement)
if r is None: # pragma: no cover if r is None: # pragma: no cover
raise DistlibException('Not a valid requirement: %r' % requirement) raise DistlibException("Not a valid requirement: %r" % requirement)
scheme = get_scheme(self.scheme) scheme = get_scheme(self.scheme)
self.matcher = matcher = scheme.matcher(r.requirement) self.matcher = matcher = scheme.matcher(r.requirement)
logger.debug('matcher: %s (%s)', matcher, type(matcher).__name__) logger.debug("matcher: %s (%s)", matcher, type(matcher).__name__)
versions = self.get_project(r.name) versions = self.get_project(r.name)
if len(versions) > 2: # urls and digests keys are present if len(versions) > 2: # urls and digests keys are present
# sometimes, versions are invalid # sometimes, versions are invalid
slist = [] slist = []
vcls = matcher.version_class vcls = matcher.version_class
for k in versions: for k in versions:
if k in ('urls', 'digests'): if k in ("urls", "digests"):
continue continue
try: try:
if not matcher.match(k): if not matcher.match(k):
@ -378,20 +415,20 @@ class Locator(object):
if prereleases or not vcls(k).is_prerelease: if prereleases or not vcls(k).is_prerelease:
slist.append(k) slist.append(k)
except Exception: # pragma: no cover except Exception: # pragma: no cover
logger.warning('error matching %s with %r', matcher, k) logger.warning("error matching %s with %r", matcher, k)
pass # slist.append(k) pass # slist.append(k)
if len(slist) > 1: if len(slist) > 1:
slist = sorted(slist, key=scheme.key) slist = sorted(slist, key=scheme.key)
if slist: if slist:
logger.debug('sorted list: %s', slist) logger.debug("sorted list: %s", slist)
version = slist[-1] version = slist[-1]
result = versions[version] result = versions[version]
if result: if result:
if r.extras: if r.extras:
result.extras = r.extras result.extras = r.extras
result.download_urls = versions.get('urls', {}).get(version, set()) result.download_urls = versions.get("urls", {}).get(version, set())
d = {} d = {}
sd = versions.get('digests', {}) sd = versions.get("digests", {})
for url in result.download_urls: for url in result.download_urls:
if url in sd: # pragma: no cover if url in sd: # pragma: no cover
d[url] = sd[url] d[url] = sd[url]
@ -424,29 +461,29 @@ class PyPIRPCLocator(Locator):
return set(self.client.list_packages()) return set(self.client.list_packages())
def _get_project(self, name): def _get_project(self, name):
result = {'urls': {}, 'digests': {}} result = {"urls": {}, "digests": {}}
versions = self.client.package_releases(name, True) versions = self.client.package_releases(name, True)
for v in versions: for v in versions:
urls = self.client.release_urls(name, v) urls = self.client.release_urls(name, v)
data = self.client.release_data(name, v) data = self.client.release_data(name, v)
metadata = Metadata(scheme=self.scheme) metadata = Metadata(scheme=self.scheme)
metadata.name = data['name'] metadata.name = data["name"]
metadata.version = data['version'] metadata.version = data["version"]
metadata.license = data.get('license') metadata.license = data.get("license")
metadata.keywords = data.get('keywords', []) metadata.keywords = data.get("keywords", [])
metadata.summary = data.get('summary') metadata.summary = data.get("summary")
dist = Distribution(metadata) dist = Distribution(metadata)
if urls: if urls:
info = urls[0] info = urls[0]
metadata.source_url = info['url'] metadata.source_url = info["url"]
dist.digest = self._get_digest(info) dist.digest = self._get_digest(info)
dist.locator = self dist.locator = self
result[v] = dist result[v] = dist
for info in urls: for info in urls:
url = info['url'] url = info["url"]
digest = self._get_digest(info) digest = self._get_digest(info)
result['urls'].setdefault(v, set()).add(url) result["urls"].setdefault(v, set()).add(url)
result['digests'][url] = digest result["digests"][url] = digest
return result return result
@ -464,34 +501,34 @@ class PyPIJSONLocator(Locator):
""" """
Return all the distribution names known to this locator. Return all the distribution names known to this locator.
""" """
raise NotImplementedError('Not available from this locator') raise NotImplementedError("Not available from this locator")
def _get_project(self, name): def _get_project(self, name):
result = {'urls': {}, 'digests': {}} result = {"urls": {}, "digests": {}}
url = urljoin(self.base_url, '%s/json' % quote(name)) url = urljoin(self.base_url, "%s/json" % quote(name))
try: try:
resp = self.opener.open(url) resp = self.opener.open(url)
data = resp.read().decode() # for now data = resp.read().decode() # for now
d = json.loads(data) d = json.loads(data)
md = Metadata(scheme=self.scheme) md = Metadata(scheme=self.scheme)
data = d['info'] data = d["info"]
md.name = data['name'] md.name = data["name"]
md.version = data['version'] md.version = data["version"]
md.license = data.get('license') md.license = data.get("license")
md.keywords = data.get('keywords', []) md.keywords = data.get("keywords", [])
md.summary = data.get('summary') md.summary = data.get("summary")
dist = Distribution(md) dist = Distribution(md)
dist.locator = self dist.locator = self
# urls = d['urls'] # urls = d['urls']
result[md.version] = dist result[md.version] = dist
for info in d['urls']: for info in d["urls"]:
url = info['url'] url = info["url"]
dist.download_urls.add(url) dist.download_urls.add(url)
dist.digests[url] = self._get_digest(info) dist.digests[url] = self._get_digest(info)
result['urls'].setdefault(md.version, set()).add(url) result["urls"].setdefault(md.version, set()).add(url)
result['digests'][url] = self._get_digest(info) result["digests"][url] = self._get_digest(info)
# Now get other releases # Now get other releases
for version, infos in d['releases'].items(): for version, infos in d["releases"].items():
if version == md.version: if version == md.version:
continue # already done continue # already done
omd = Metadata(scheme=self.scheme) omd = Metadata(scheme=self.scheme)
@ -501,24 +538,23 @@ class PyPIJSONLocator(Locator):
odist.locator = self odist.locator = self
result[version] = odist result[version] = odist
for info in infos: for info in infos:
url = info['url'] url = info["url"]
odist.download_urls.add(url) odist.download_urls.add(url)
odist.digests[url] = self._get_digest(info) odist.digests[url] = self._get_digest(info)
result['urls'].setdefault(version, set()).add(url) result["urls"].setdefault(version, set()).add(url)
result['digests'][url] = self._get_digest(info) result["digests"][url] = self._get_digest(info)
# for info in urls:
# for info in urls: # md.source_url = info['url']
# md.source_url = info['url'] # dist.digest = self._get_digest(info)
# dist.digest = self._get_digest(info) # dist.locator = self
# dist.locator = self # for info in urls:
# for info in urls: # url = info['url']
# url = info['url'] # result['urls'].setdefault(md.version, set()).add(url)
# result['urls'].setdefault(md.version, set()).add(url) # result['digests'][url] = self._get_digest(info)
# result['digests'][url] = self._get_digest(info)
except Exception as e: except Exception as e:
self.errors.put(text_type(e)) self.errors.put(text_type(e))
logger.exception('JSON fetch failed: %s', e) logger.exception("JSON fetch failed: %s", e)
return result return result
@ -526,6 +562,7 @@ class Page(object):
""" """
This class represents a scraped HTML page. This class represents a scraped HTML page.
""" """
# The following slightly hairy-looking regex just looks for the contents of # The following slightly hairy-looking regex just looks for the contents of
# an anchor link, which has an attribute "href" either immediately preceded # an anchor link, which has an attribute "href" either immediately preceded
# or immediately followed by a "rel" attribute. The attribute values can be # or immediately followed by a "rel" attribute. The attribute values can be
@ -536,7 +573,9 @@ class Page(object):
(rel\\s*=\\s*(?:"(?P<rel1>[^"]*)"|'(?P<rel2>[^']*)'|(?P<rel3>[^>\\s\n]*))\\s+)? (rel\\s*=\\s*(?:"(?P<rel1>[^"]*)"|'(?P<rel2>[^']*)'|(?P<rel3>[^>\\s\n]*))\\s+)?
href\\s*=\\s*(?:"(?P<url1>[^"]*)"|'(?P<url2>[^']*)'|(?P<url3>[^>\\s\n]*)) href\\s*=\\s*(?:"(?P<url1>[^"]*)"|'(?P<url2>[^']*)'|(?P<url3>[^>\\s\n]*))
(\\s+rel\\s*=\\s*(?:"(?P<rel4>[^"]*)"|'(?P<rel5>[^']*)'|(?P<rel6>[^>\\s\n]*)))? (\\s+rel\\s*=\\s*(?:"(?P<rel4>[^"]*)"|'(?P<rel5>[^']*)'|(?P<rel6>[^>\\s\n]*)))?
""", re.I | re.S | re.X) """,
re.I | re.S | re.X,
)
_base = re.compile(r"""<base\s+href\s*=\s*['"]?([^'">]+)""", re.I | re.S) _base = re.compile(r"""<base\s+href\s*=\s*['"]?([^'">]+)""", re.I | re.S)
def __init__(self, data, url): def __init__(self, data, url):
@ -550,7 +589,7 @@ href\\s*=\\s*(?:"(?P<url1>[^"]*)"|'(?P<url2>[^']*)'|(?P<url3>[^>\\s\n]*))
if m: if m:
self.base_url = m.group(1) self.base_url = m.group(1)
_clean_re = re.compile(r'[^a-z0-9$&+,/:;=?@.#%_\\|-]', re.I) _clean_re = re.compile(r"[^a-z0-9$&+,/:;=?@.#%_\\|-]", re.I)
@cached_property @cached_property
def links(self): def links(self):
@ -567,12 +606,19 @@ href\\s*=\\s*(?:"(?P<url1>[^"]*)"|'(?P<url2>[^']*)'|(?P<url3>[^>\\s\n]*))
result = set() result = set()
for match in self._href.finditer(self.data): for match in self._href.finditer(self.data):
d = match.groupdict('') d = match.groupdict("")
rel = (d['rel1'] or d['rel2'] or d['rel3'] or d['rel4'] or d['rel5'] or d['rel6']) rel = (
url = d['url1'] or d['url2'] or d['url3'] d["rel1"]
or d["rel2"]
or d["rel3"]
or d["rel4"]
or d["rel5"]
or d["rel6"]
)
url = d["url1"] or d["url2"] or d["url3"]
url = urljoin(self.base_url, url) url = urljoin(self.base_url, url)
url = unescape(url) url = unescape(url)
url = self._clean_re.sub(lambda m: '%%%2x' % ord(m.group(0)), url) url = self._clean_re.sub(lambda m: "%%%2x" % ord(m.group(0)), url)
result.add((url, rel)) result.add((url, rel))
# We sort the result, hoping to bring the most recent versions # We sort the result, hoping to bring the most recent versions
# to the front # to the front
@ -589,9 +635,9 @@ class SimpleScrapingLocator(Locator):
# These are used to deal with various Content-Encoding schemes. # These are used to deal with various Content-Encoding schemes.
decoders = { decoders = {
'deflate': zlib.decompress, "deflate": zlib.decompress,
'gzip': lambda b: gzip.GzipFile(fileobj=BytesIO(b)).read(), "gzip": lambda b: gzip.GzipFile(fileobj=BytesIO(b)).read(),
'none': lambda b: b, "none": lambda b: b,
} }
def __init__(self, url, timeout=None, num_workers=10, **kwargs): def __init__(self, url, timeout=None, num_workers=10, **kwargs):
@ -648,16 +694,16 @@ class SimpleScrapingLocator(Locator):
self._threads = [] self._threads = []
def _get_project(self, name): def _get_project(self, name):
result = {'urls': {}, 'digests': {}} result = {"urls": {}, "digests": {}}
with self._gplock: with self._gplock:
self.result = result self.result = result
self.project_name = name self.project_name = name
url = urljoin(self.base_url, '%s/' % quote(name)) url = urljoin(self.base_url, "%s/" % quote(name))
self._seen.clear() self._seen.clear()
self._page_cache.clear() self._page_cache.clear()
self._prepare_threads() self._prepare_threads()
try: try:
logger.debug('Queueing %s', url) logger.debug("Queueing %s", url)
self._to_fetch.put(url) self._to_fetch.put(url)
self._to_fetch.join() self._to_fetch.join()
finally: finally:
@ -665,8 +711,9 @@ class SimpleScrapingLocator(Locator):
del self.result del self.result
return result return result
platform_dependent = re.compile(r'\b(linux_(i\d86|x86_64|arm\w+)|' platform_dependent = re.compile(
r'win(32|_amd64)|macosx_?\d+)\b', re.I) r"\b(linux_(i\d86|x86_64|arm\w+)|" r"win(32|_amd64)|macosx_?\d+)\b", re.I
)
def _is_platform_dependent(self, url): def _is_platform_dependent(self, url):
""" """
@ -688,7 +735,7 @@ class SimpleScrapingLocator(Locator):
info = None info = None
else: else:
info = self.convert_url_to_download_info(url, self.project_name) info = self.convert_url_to_download_info(url, self.project_name)
logger.debug('process_download: %s -> %s', url, info) logger.debug("process_download: %s -> %s", url, info)
if info: if info:
with self._lock: # needed because self.result is shared with self._lock: # needed because self.result is shared
self._update_version_data(self.result, info) self._update_version_data(self.result, info)
@ -700,25 +747,27 @@ class SimpleScrapingLocator(Locator):
particular "rel" attribute should be queued for scraping. particular "rel" attribute should be queued for scraping.
""" """
scheme, netloc, path, _, _, _ = urlparse(link) scheme, netloc, path, _, _, _ = urlparse(link)
if path.endswith(self.source_extensions + self.binary_extensions + self.excluded_extensions): if path.endswith(
self.source_extensions + self.binary_extensions + self.excluded_extensions
):
result = False result = False
elif self.skip_externals and not link.startswith(self.base_url): elif self.skip_externals and not link.startswith(self.base_url):
result = False result = False
elif not referrer.startswith(self.base_url): elif not referrer.startswith(self.base_url):
result = False result = False
elif rel not in ('homepage', 'download'): elif rel not in ("homepage", "download"):
result = False result = False
elif scheme not in ('http', 'https', 'ftp'): elif scheme not in ("http", "https", "ftp"):
result = False result = False
elif self._is_platform_dependent(link): elif self._is_platform_dependent(link):
result = False result = False
else: else:
host = netloc.split(':', 1)[0] host = netloc.split(":", 1)[0]
if host.lower() == 'localhost': if host.lower() == "localhost":
result = False result = False
else: else:
result = True result = True
logger.debug('should_queue: %s (%s) from %s -> %s', link, rel, referrer, result) logger.debug("should_queue: %s (%s) from %s -> %s", link, rel, referrer, result)
return result return result
def _fetch(self): def _fetch(self):
@ -739,8 +788,10 @@ class SimpleScrapingLocator(Locator):
if link not in self._seen: if link not in self._seen:
try: try:
self._seen.add(link) self._seen.add(link)
if (not self._process_download(link) and self._should_queue(link, url, rel)): if not self._process_download(
logger.debug('Queueing %s from %s', link, url) link
) and self._should_queue(link, url, rel):
logger.debug("Queueing %s from %s", link, url)
self._to_fetch.put(link) self._to_fetch.put(link)
except MetadataInvalidError: # e.g. invalid versions except MetadataInvalidError: # e.g. invalid versions
pass pass
@ -763,56 +814,56 @@ class SimpleScrapingLocator(Locator):
""" """
# http://peak.telecommunity.com/DevCenter/EasyInstall#package-index-api # http://peak.telecommunity.com/DevCenter/EasyInstall#package-index-api
scheme, netloc, path, _, _, _ = urlparse(url) scheme, netloc, path, _, _, _ = urlparse(url)
if scheme == 'file' and os.path.isdir(url2pathname(path)): if scheme == "file" and os.path.isdir(url2pathname(path)):
url = urljoin(ensure_slash(url), 'index.html') url = urljoin(ensure_slash(url), "index.html")
if url in self._page_cache: if url in self._page_cache:
result = self._page_cache[url] result = self._page_cache[url]
logger.debug('Returning %s from cache: %s', url, result) logger.debug("Returning %s from cache: %s", url, result)
else: else:
host = netloc.split(':', 1)[0] host = netloc.split(":", 1)[0]
result = None result = None
if host in self._bad_hosts: if host in self._bad_hosts:
logger.debug('Skipping %s due to bad host %s', url, host) logger.debug("Skipping %s due to bad host %s", url, host)
else: else:
req = Request(url, headers={'Accept-encoding': 'identity'}) req = Request(url, headers={"Accept-encoding": "identity"})
try: try:
logger.debug('Fetching %s', url) logger.debug("Fetching %s", url)
resp = self.opener.open(req, timeout=self.timeout) resp = self.opener.open(req, timeout=self.timeout)
logger.debug('Fetched %s', url) logger.debug("Fetched %s", url)
headers = resp.info() headers = resp.info()
content_type = headers.get('Content-Type', '') content_type = headers.get("Content-Type", "")
if HTML_CONTENT_TYPE.match(content_type): if HTML_CONTENT_TYPE.match(content_type):
final_url = resp.geturl() final_url = resp.geturl()
data = resp.read() data = resp.read()
encoding = headers.get('Content-Encoding') encoding = headers.get("Content-Encoding")
if encoding: if encoding:
decoder = self.decoders[encoding] # fail if not found decoder = self.decoders[encoding] # fail if not found
data = decoder(data) data = decoder(data)
encoding = 'utf-8' encoding = "utf-8"
m = CHARSET.search(content_type) m = CHARSET.search(content_type)
if m: if m:
encoding = m.group(1) encoding = m.group(1)
try: try:
data = data.decode(encoding) data = data.decode(encoding)
except UnicodeError: # pragma: no cover except UnicodeError: # pragma: no cover
data = data.decode('latin-1') # fallback data = data.decode("latin-1") # fallback
result = Page(data, final_url) result = Page(data, final_url)
self._page_cache[final_url] = result self._page_cache[final_url] = result
except HTTPError as e: except HTTPError as e:
if e.code != 404: if e.code != 404:
logger.exception('Fetch failed: %s: %s', url, e) logger.exception("Fetch failed: %s: %s", url, e)
except URLError as e: # pragma: no cover except URLError as e: # pragma: no cover
logger.exception('Fetch failed: %s: %s', url, e) logger.exception("Fetch failed: %s: %s", url, e)
with self._lock: with self._lock:
self._bad_hosts.add(host) self._bad_hosts.add(host)
except Exception as e: # pragma: no cover except Exception as e: # pragma: no cover
logger.exception('Fetch failed: %s: %s', url, e) logger.exception("Fetch failed: %s: %s", url, e)
finally: finally:
self._page_cache[url] = result # even if None (failure) self._page_cache[url] = result # even if None (failure)
return result return result
_distname_re = re.compile('<a href=[^>]*>([^<]+)<') _distname_re = re.compile("<a href=[^>]*>([^<]+)<")
def get_distribution_names(self): def get_distribution_names(self):
""" """
@ -821,7 +872,7 @@ class SimpleScrapingLocator(Locator):
result = set() result = set()
page = self.get_page(self.base_url) page = self.get_page(self.base_url)
if not page: if not page:
raise DistlibException('Unable to get %s' % self.base_url) raise DistlibException("Unable to get %s" % self.base_url)
for match in self._distname_re.finditer(page.data): for match in self._distname_re.finditer(page.data):
result.add(match.group(1)) result.add(match.group(1))
return result return result
@ -842,11 +893,11 @@ class DirectoryLocator(Locator):
recursed into. If False, only the top-level directory recursed into. If False, only the top-level directory
is searched, is searched,
""" """
self.recursive = kwargs.pop('recursive', True) self.recursive = kwargs.pop("recursive", True)
super(DirectoryLocator, self).__init__(**kwargs) super(DirectoryLocator, self).__init__(**kwargs)
path = os.path.abspath(path) path = os.path.abspath(path)
if not os.path.isdir(path): # pragma: no cover if not os.path.isdir(path): # pragma: no cover
raise DistlibException('Not a directory: %r' % path) raise DistlibException("Not a directory: %r" % path)
self.base_dir = path self.base_dir = path
def should_include(self, filename, parent): def should_include(self, filename, parent):
@ -858,12 +909,14 @@ class DirectoryLocator(Locator):
return filename.endswith(self.downloadable_extensions) return filename.endswith(self.downloadable_extensions)
def _get_project(self, name): def _get_project(self, name):
result = {'urls': {}, 'digests': {}} result = {"urls": {}, "digests": {}}
for root, dirs, files in os.walk(self.base_dir): for root, dirs, files in os.walk(self.base_dir):
for fn in files: for fn in files:
if self.should_include(fn, root): if self.should_include(fn, root):
fn = os.path.join(root, fn) fn = os.path.join(root, fn)
url = urlunparse(('file', '', pathname2url(os.path.abspath(fn)), '', '', '')) url = urlunparse(
("file", "", pathname2url(os.path.abspath(fn)), "", "", "")
)
info = self.convert_url_to_download_info(url, name) info = self.convert_url_to_download_info(url, name)
if info: if info:
self._update_version_data(result, info) self._update_version_data(result, info)
@ -880,10 +933,12 @@ class DirectoryLocator(Locator):
for fn in files: for fn in files:
if self.should_include(fn, root): if self.should_include(fn, root):
fn = os.path.join(root, fn) fn = os.path.join(root, fn)
url = urlunparse(('file', '', pathname2url(os.path.abspath(fn)), '', '', '')) url = urlunparse(
("file", "", pathname2url(os.path.abspath(fn)), "", "", "")
)
info = self.convert_url_to_download_info(url, None) info = self.convert_url_to_download_info(url, None)
if info: if info:
result.add(info['name']) result.add(info["name"])
if not self.recursive: if not self.recursive:
break break
return result return result
@ -901,31 +956,33 @@ class JSONLocator(Locator):
""" """
Return all the distribution names known to this locator. Return all the distribution names known to this locator.
""" """
raise NotImplementedError('Not available from this locator') raise NotImplementedError("Not available from this locator")
def _get_project(self, name): def _get_project(self, name):
result = {'urls': {}, 'digests': {}} result = {"urls": {}, "digests": {}}
data = get_project_data(name) data = get_project_data(name)
if data: if data:
for info in data.get('files', []): for info in data.get("files", []):
if info['ptype'] != 'sdist' or info['pyversion'] != 'source': if info["ptype"] != "sdist" or info["pyversion"] != "source":
continue continue
# We don't store summary in project metadata as it makes # We don't store summary in project metadata as it makes
# the data bigger for no benefit during dependency # the data bigger for no benefit during dependency
# resolution # resolution
dist = make_dist(data['name'], dist = make_dist(
info['version'], data["name"],
summary=data.get('summary', 'Placeholder for summary'), info["version"],
scheme=self.scheme) summary=data.get("summary", "Placeholder for summary"),
scheme=self.scheme,
)
md = dist.metadata md = dist.metadata
md.source_url = info['url'] md.source_url = info["url"]
# TODO SHA256 digest # TODO SHA256 digest
if 'digest' in info and info['digest']: if "digest" in info and info["digest"]:
dist.digest = ('md5', info['digest']) dist.digest = ("md5", info["digest"])
md.dependencies = info.get('requirements', {}) md.dependencies = info.get("requirements", {})
dist.exports = info.get('exports', {}) dist.exports = info.get("exports", {})
result[dist.version] = dist result[dist.version] = dist
result['urls'].setdefault(dist.version, set()).add(info['url']) result["urls"].setdefault(dist.version, set()).add(info["url"])
return result return result
@ -948,16 +1005,12 @@ class DistPathLocator(Locator):
def _get_project(self, name): def _get_project(self, name):
dist = self.distpath.get_distribution(name) dist = self.distpath.get_distribution(name)
if dist is None: if dist is None:
result = {'urls': {}, 'digests': {}} result = {"urls": {}, "digests": {}}
else: else:
result = { result = {
dist.version: dist, dist.version: dist,
'urls': { "urls": {dist.version: set([dist.source_url])},
dist.version: set([dist.source_url]) "digests": {dist.version: set([None])},
},
'digests': {
dist.version: set([None])
}
} }
return result return result
@ -979,7 +1032,7 @@ class AggregatingLocator(Locator):
the results from all locators are merged (this can be the results from all locators are merged (this can be
slow). slow).
""" """
self.merge = kwargs.pop('merge', False) self.merge = kwargs.pop("merge", False)
self.locators = locators self.locators = locators
super(AggregatingLocator, self).__init__(**kwargs) super(AggregatingLocator, self).__init__(**kwargs)
@ -1001,18 +1054,18 @@ class AggregatingLocator(Locator):
d = locator.get_project(name) d = locator.get_project(name)
if d: if d:
if self.merge: if self.merge:
files = result.get('urls', {}) files = result.get("urls", {})
digests = result.get('digests', {}) digests = result.get("digests", {})
# next line could overwrite result['urls'], result['digests'] # next line could overwrite result['urls'], result['digests']
result.update(d) result.update(d)
df = result.get('urls') df = result.get("urls")
if files and df: if files and df:
for k, v in files.items(): for k, v in files.items():
if k in df: if k in df:
df[k] |= v df[k] |= v
else: else:
df[k] = v df[k] = v
dd = result.get('digests') dd = result.get("digests")
if digests and dd: if digests and dd:
dd.update(digests) dd.update(digests)
else: else:
@ -1056,8 +1109,9 @@ class AggregatingLocator(Locator):
# versions which don't conform to PEP 440. # versions which don't conform to PEP 440.
default_locator = AggregatingLocator( default_locator = AggregatingLocator(
# JSONLocator(), # don't use as PEP 426 is withdrawn # JSONLocator(), # don't use as PEP 426 is withdrawn
SimpleScrapingLocator('https://pypi.org/simple/', timeout=3.0), SimpleScrapingLocator("https://pypi.org/simple/", timeout=3.0),
scheme='legacy') scheme="legacy",
)
locate = default_locator.locate locate = default_locator.locate
@ -1081,13 +1135,13 @@ class DependencyFinder(object):
about who provides what. about who provides what.
:param dist: The distribution to add. :param dist: The distribution to add.
""" """
logger.debug('adding distribution %s', dist) logger.debug("adding distribution %s", dist)
name = dist.key name = dist.key
self.dists_by_name[name] = dist self.dists_by_name[name] = dist
self.dists[(name, dist.version)] = dist self.dists[(name, dist.version)] = dist
for p in dist.provides: for p in dist.provides:
name, version = parse_name_and_version(p) name, version = parse_name_and_version(p)
logger.debug('Add to provided: %s, %s, %s', name, version, dist) logger.debug("Add to provided: %s, %s, %s", name, version, dist)
self.provided.setdefault(name, set()).add((version, dist)) self.provided.setdefault(name, set()).add((version, dist))
def remove_distribution(self, dist): def remove_distribution(self, dist):
@ -1096,13 +1150,13 @@ class DependencyFinder(object):
information about who provides what. information about who provides what.
:param dist: The distribution to remove. :param dist: The distribution to remove.
""" """
logger.debug('removing distribution %s', dist) logger.debug("removing distribution %s", dist)
name = dist.key name = dist.key
del self.dists_by_name[name] del self.dists_by_name[name]
del self.dists[(name, dist.version)] del self.dists[(name, dist.version)]
for p in dist.provides: for p in dist.provides:
name, version = parse_name_and_version(p) name, version = parse_name_and_version(p)
logger.debug('Remove from provided: %s, %s, %s', name, version, dist) logger.debug("Remove from provided: %s, %s, %s", name, version, dist)
s = self.provided[name] s = self.provided[name]
s.remove((version, dist)) s.remove((version, dist))
if not s: if not s:
@ -1175,7 +1229,7 @@ class DependencyFinder(object):
unmatched.add(s) unmatched.add(s)
if unmatched: if unmatched:
# can't replace other with provider # can't replace other with provider
problems.add(('cantreplace', provider, other, frozenset(unmatched))) problems.add(("cantreplace", provider, other, frozenset(unmatched)))
result = False result = False
else: else:
# can replace other with provider # can replace other with provider
@ -1219,19 +1273,19 @@ class DependencyFinder(object):
self.reqts = {} self.reqts = {}
meta_extras = set(meta_extras or []) meta_extras = set(meta_extras or [])
if ':*:' in meta_extras: if ":*:" in meta_extras:
meta_extras.remove(':*:') meta_extras.remove(":*:")
# :meta: and :run: are implicitly included # :meta: and :run: are implicitly included
meta_extras |= set([':test:', ':build:', ':dev:']) meta_extras |= set([":test:", ":build:", ":dev:"])
if isinstance(requirement, Distribution): if isinstance(requirement, Distribution):
dist = odist = requirement dist = odist = requirement
logger.debug('passed %s as requirement', odist) logger.debug("passed %s as requirement", odist)
else: else:
dist = odist = self.locator.locate(requirement, prereleases=prereleases) dist = odist = self.locator.locate(requirement, prereleases=prereleases)
if dist is None: if dist is None:
raise DistlibException('Unable to locate %r' % requirement) raise DistlibException("Unable to locate %r" % requirement)
logger.debug('located %s', odist) logger.debug("located %s", odist)
dist.requested = True dist.requested = True
problems = set() problems = set()
todo = set([dist]) todo = set([dist])
@ -1251,23 +1305,23 @@ class DependencyFinder(object):
sreqts = dist.build_requires sreqts = dist.build_requires
ereqts = set() ereqts = set()
if meta_extras and dist in install_dists: if meta_extras and dist in install_dists:
for key in ('test', 'build', 'dev'): for key in ("test", "build", "dev"):
e = ':%s:' % key e = ":%s:" % key
if e in meta_extras: if e in meta_extras:
ereqts |= getattr(dist, '%s_requires' % key) ereqts |= getattr(dist, "%s_requires" % key)
all_reqts = ireqts | sreqts | ereqts all_reqts = ireqts | sreqts | ereqts
for r in all_reqts: for r in all_reqts:
providers = self.find_providers(r) providers = self.find_providers(r)
if not providers: if not providers:
logger.debug('No providers found for %r', r) logger.debug("No providers found for %r", r)
provider = self.locator.locate(r, prereleases=prereleases) provider = self.locator.locate(r, prereleases=prereleases)
# If no provider is found and we didn't consider # If no provider is found and we didn't consider
# prereleases, consider them now. # prereleases, consider them now.
if provider is None and not prereleases: if provider is None and not prereleases:
provider = self.locator.locate(r, prereleases=True) provider = self.locator.locate(r, prereleases=True)
if provider is None: if provider is None:
logger.debug('Cannot satisfy %r', r) logger.debug("Cannot satisfy %r", r)
problems.add(('unsatisfied', r)) problems.add(("unsatisfied", r))
else: else:
n, v = provider.key, provider.version n, v = provider.key, provider.version
if (n, v) not in self.dists: if (n, v) not in self.dists:
@ -1275,7 +1329,9 @@ class DependencyFinder(object):
providers.add(provider) providers.add(provider)
if r in ireqts and dist in install_dists: if r in ireqts and dist in install_dists:
install_dists.add(provider) install_dists.add(provider)
logger.debug('Adding %s to install_dists', provider.name_and_version) logger.debug(
"Adding %s to install_dists", provider.name_and_version
)
for p in providers: for p in providers:
name = p.key name = p.key
if name not in self.dists_by_name: if name not in self.dists_by_name:
@ -1290,6 +1346,8 @@ class DependencyFinder(object):
for dist in dists: for dist in dists:
dist.build_time_dependency = dist not in install_dists dist.build_time_dependency = dist not in install_dists
if dist.build_time_dependency: if dist.build_time_dependency:
logger.debug('%s is a build-time dependency only.', dist.name_and_version) logger.debug(
logger.debug('find done for %s', odist) "%s is a build-time dependency only.", dist.name_and_version
)
logger.debug("find done for %s", odist)
return dists, problems return dists, problems

View File

@ -8,6 +8,7 @@ Class representing the list of files in a distribution.
Equivalent to distutils.filelist, but fixes some problems. Equivalent to distutils.filelist, but fixes some problems.
""" """
import fnmatch import fnmatch
import logging import logging
import os import os
@ -18,14 +19,13 @@ from . import DistlibException
from .compat import fsdecode from .compat import fsdecode
from .util import convert_path from .util import convert_path
__all__ = ["Manifest"]
__all__ = ['Manifest']
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# a \ followed by some spaces + EOL # a \ followed by some spaces + EOL
_COLLAPSE_PATTERN = re.compile('\\\\w*\n', re.M) _COLLAPSE_PATTERN = re.compile("\\\\w*\n", re.M)
_COMMENTED_LINE = re.compile('#.*?(?=\n)|\n(?=$)', re.M | re.S) _COMMENTED_LINE = re.compile("#.*?(?=\n)|\n(?=$)", re.M | re.S)
# #
# Due to the different results returned by fnmatch.translate, we need # Due to the different results returned by fnmatch.translate, we need
@ -109,20 +109,22 @@ class Manifest(object):
def add_dir(dirs, d): def add_dir(dirs, d):
dirs.add(d) dirs.add(d)
logger.debug('add_dir added %s', d) logger.debug("add_dir added %s", d)
if d != self.base: if d != self.base:
parent, _ = os.path.split(d) parent, _ = os.path.split(d)
assert parent not in ('', '/') assert parent not in ("", "/")
add_dir(dirs, parent) add_dir(dirs, parent)
result = set(self.files) # make a copy! result = set(self.files) # make a copy!
if wantdirs: if wantdirs:
dirs = set() dirs = set()
for f in result: for f in result:
add_dir(dirs, os.path.dirname(f)) add_dir(dirs, os.path.dirname(f))
result |= dirs result |= dirs
return [os.path.join(*path_tuple) for path_tuple in return [
sorted(os.path.split(path) for path in result)] os.path.join(*path_tuple)
for path_tuple in sorted(os.path.split(path) for path in result)
]
def clear(self): def clear(self):
"""Clear all collected files.""" """Clear all collected files."""
@ -149,49 +151,54 @@ class Manifest(object):
# OK, now we know that the action is valid and we have the # OK, now we know that the action is valid and we have the
# right number of words on the line for that action -- so we # right number of words on the line for that action -- so we
# can proceed with minimal error-checking. # can proceed with minimal error-checking.
if action == 'include': if action == "include":
for pattern in patterns: for pattern in patterns:
if not self._include_pattern(pattern, anchor=True): if not self._include_pattern(pattern, anchor=True):
logger.warning('no files found matching %r', pattern) logger.warning("no files found matching %r", pattern)
elif action == 'exclude': elif action == "exclude":
for pattern in patterns: for pattern in patterns:
self._exclude_pattern(pattern, anchor=True) self._exclude_pattern(pattern, anchor=True)
elif action == 'global-include': elif action == "global-include":
for pattern in patterns: for pattern in patterns:
if not self._include_pattern(pattern, anchor=False): if not self._include_pattern(pattern, anchor=False):
logger.warning('no files found matching %r ' logger.warning(
'anywhere in distribution', pattern) "no files found matching %r " "anywhere in distribution",
pattern,
)
elif action == 'global-exclude': elif action == "global-exclude":
for pattern in patterns: for pattern in patterns:
self._exclude_pattern(pattern, anchor=False) self._exclude_pattern(pattern, anchor=False)
elif action == 'recursive-include': elif action == "recursive-include":
for pattern in patterns: for pattern in patterns:
if not self._include_pattern(pattern, prefix=thedir): if not self._include_pattern(pattern, prefix=thedir):
logger.warning('no files found matching %r ' logger.warning(
'under directory %r', pattern, thedir) "no files found matching %r " "under directory %r",
pattern,
thedir,
)
elif action == 'recursive-exclude': elif action == "recursive-exclude":
for pattern in patterns: for pattern in patterns:
self._exclude_pattern(pattern, prefix=thedir) self._exclude_pattern(pattern, prefix=thedir)
elif action == 'graft': elif action == "graft":
if not self._include_pattern(None, prefix=dirpattern): if not self._include_pattern(None, prefix=dirpattern):
logger.warning('no directories found matching %r', logger.warning("no directories found matching %r", dirpattern)
dirpattern)
elif action == 'prune': elif action == "prune":
if not self._exclude_pattern(None, prefix=dirpattern): if not self._exclude_pattern(None, prefix=dirpattern):
logger.warning('no previously-included directories found ' logger.warning(
'matching %r', dirpattern) "no previously-included directories found " "matching %r",
else: # pragma: no cover dirpattern,
)
else: # pragma: no cover
# This should never happen, as it should be caught in # This should never happen, as it should be caught in
# _parse_template_line # _parse_template_line
raise DistlibException( raise DistlibException("invalid action %r" % action)
'invalid action %r' % action)
# #
# Private API # Private API
@ -204,48 +211,49 @@ class Manifest(object):
:return: A tuple of action, patterns, thedir, dir_patterns :return: A tuple of action, patterns, thedir, dir_patterns
""" """
words = directive.split() words = directive.split()
if len(words) == 1 and words[0] not in ('include', 'exclude', if len(words) == 1 and words[0] not in (
'global-include', "include",
'global-exclude', "exclude",
'recursive-include', "global-include",
'recursive-exclude', "global-exclude",
'graft', 'prune'): "recursive-include",
"recursive-exclude",
"graft",
"prune",
):
# no action given, let's use the default 'include' # no action given, let's use the default 'include'
words.insert(0, 'include') words.insert(0, "include")
action = words[0] action = words[0]
patterns = thedir = dir_pattern = None patterns = thedir = dir_pattern = None
if action in ('include', 'exclude', if action in ("include", "exclude", "global-include", "global-exclude"):
'global-include', 'global-exclude'):
if len(words) < 2: if len(words) < 2:
raise DistlibException( raise DistlibException("%r expects <pattern1> <pattern2> ..." % action)
'%r expects <pattern1> <pattern2> ...' % action)
patterns = [convert_path(word) for word in words[1:]] patterns = [convert_path(word) for word in words[1:]]
elif action in ('recursive-include', 'recursive-exclude'): elif action in ("recursive-include", "recursive-exclude"):
if len(words) < 3: if len(words) < 3:
raise DistlibException( raise DistlibException(
'%r expects <dir> <pattern1> <pattern2> ...' % action) "%r expects <dir> <pattern1> <pattern2> ..." % action
)
thedir = convert_path(words[1]) thedir = convert_path(words[1])
patterns = [convert_path(word) for word in words[2:]] patterns = [convert_path(word) for word in words[2:]]
elif action in ('graft', 'prune'): elif action in ("graft", "prune"):
if len(words) != 2: if len(words) != 2:
raise DistlibException( raise DistlibException("%r expects a single <dir_pattern>" % action)
'%r expects a single <dir_pattern>' % action)
dir_pattern = convert_path(words[1]) dir_pattern = convert_path(words[1])
else: else:
raise DistlibException('unknown action %r' % action) raise DistlibException("unknown action %r" % action)
return action, patterns, thedir, dir_pattern return action, patterns, thedir, dir_pattern
def _include_pattern(self, pattern, anchor=True, prefix=None, def _include_pattern(self, pattern, anchor=True, prefix=None, is_regex=False):
is_regex=False):
"""Select strings (presumably filenames) from 'self.files' that """Select strings (presumably filenames) from 'self.files' that
match 'pattern', a Unix-style wildcard (glob) pattern. match 'pattern', a Unix-style wildcard (glob) pattern.
@ -285,8 +293,7 @@ class Manifest(object):
found = True found = True
return found return found
def _exclude_pattern(self, pattern, anchor=True, prefix=None, def _exclude_pattern(self, pattern, anchor=True, prefix=None, is_regex=False):
is_regex=False):
"""Remove strings (presumably filenames) from 'files' that match """Remove strings (presumably filenames) from 'files' that match
'pattern'. 'pattern'.
@ -305,8 +312,7 @@ class Manifest(object):
found = True found = True
return found return found
def _translate_pattern(self, pattern, anchor=True, prefix=None, def _translate_pattern(self, pattern, anchor=True, prefix=None, is_regex=False):
is_regex=False):
"""Translate a shell-like wildcard pattern to a compiled regular """Translate a shell-like wildcard pattern to a compiled regular
expression. expression.
@ -322,41 +328,46 @@ class Manifest(object):
if _PYTHON_VERSION > (3, 2): if _PYTHON_VERSION > (3, 2):
# ditch start and end characters # ditch start and end characters
start, _, end = self._glob_to_re('_').partition('_') start, _, end = self._glob_to_re("_").partition("_")
if pattern: if pattern:
pattern_re = self._glob_to_re(pattern) pattern_re = self._glob_to_re(pattern)
if _PYTHON_VERSION > (3, 2): if _PYTHON_VERSION > (3, 2):
assert pattern_re.startswith(start) and pattern_re.endswith(end) assert pattern_re.startswith(start) and pattern_re.endswith(end)
else: else:
pattern_re = '' pattern_re = ""
base = re.escape(os.path.join(self.base, '')) base = re.escape(os.path.join(self.base, ""))
if prefix is not None: if prefix is not None:
# ditch end of pattern character # ditch end of pattern character
if _PYTHON_VERSION <= (3, 2): if _PYTHON_VERSION <= (3, 2):
empty_pattern = self._glob_to_re('') empty_pattern = self._glob_to_re("")
prefix_re = self._glob_to_re(prefix)[:-len(empty_pattern)] prefix_re = self._glob_to_re(prefix)[: -len(empty_pattern)]
else: else:
prefix_re = self._glob_to_re(prefix) prefix_re = self._glob_to_re(prefix)
assert prefix_re.startswith(start) and prefix_re.endswith(end) assert prefix_re.startswith(start) and prefix_re.endswith(end)
prefix_re = prefix_re[len(start): len(prefix_re) - len(end)] prefix_re = prefix_re[len(start) : len(prefix_re) - len(end)]
sep = os.sep sep = os.sep
if os.sep == '\\': if os.sep == "\\":
sep = r'\\' sep = r"\\"
if _PYTHON_VERSION <= (3, 2): if _PYTHON_VERSION <= (3, 2):
pattern_re = '^' + base + sep.join((prefix_re, pattern_re = "^" + base + sep.join((prefix_re, ".*" + pattern_re))
'.*' + pattern_re))
else: else:
pattern_re = pattern_re[len(start): len(pattern_re) - len(end)] pattern_re = pattern_re[len(start) : len(pattern_re) - len(end)]
pattern_re = r'%s%s%s%s.*%s%s' % (start, base, prefix_re, sep, pattern_re = r"%s%s%s%s.*%s%s" % (
pattern_re, end) start,
base,
prefix_re,
sep,
pattern_re,
end,
)
else: # no prefix -- respect anchor flag else: # no prefix -- respect anchor flag
if anchor: if anchor:
if _PYTHON_VERSION <= (3, 2): if _PYTHON_VERSION <= (3, 2):
pattern_re = '^' + base + pattern_re pattern_re = "^" + base + pattern_re
else: else:
pattern_re = r'%s%s%s' % (start, base, pattern_re[len(start):]) pattern_re = r"%s%s%s" % (start, base, pattern_re[len(start) :])
return re.compile(pattern_re) return re.compile(pattern_re)
@ -375,10 +386,10 @@ class Manifest(object):
# any OS. So change all non-escaped dots in the RE to match any # any OS. So change all non-escaped dots in the RE to match any
# character except the special characters (currently: just os.sep). # character except the special characters (currently: just os.sep).
sep = os.sep sep = os.sep
if os.sep == '\\': if os.sep == "\\":
# we're using a regex to manipulate a regex, so we need # we're using a regex to manipulate a regex, so we need
# to escape the backslash twice # to escape the backslash twice
sep = r'\\\\' sep = r"\\\\"
escaped = r'\1[^%s]' % sep escaped = r"\1[^%s]" % sep
pattern_re = re.sub(r'((?<!\\)(\\\\)*)\.', escaped, pattern_re) pattern_re = re.sub(r"((?<!\\)(\\\\)*)\.", escaped, pattern_re)
return pattern_re return pattern_re

View File

@ -21,10 +21,12 @@ from .compat import string_types
from .util import in_venv, parse_marker from .util import in_venv, parse_marker
from .version import LegacyVersion as LV from .version import LegacyVersion as LV
__all__ = ['interpret'] __all__ = ["interpret"]
_VERSION_PATTERN = re.compile(r'((\d+(\.\d+)*\w*)|\'(\d+(\.\d+)*\w*)\'|\"(\d+(\.\d+)*\w*)\")') _VERSION_PATTERN = re.compile(
_VERSION_MARKERS = {'python_version', 'python_full_version'} r"((\d+(\.\d+)*\w*)|\'(\d+(\.\d+)*\w*)\'|\"(\d+(\.\d+)*\w*)\")"
)
_VERSION_MARKERS = {"python_version", "python_full_version"}
def _is_version_marker(s): def _is_version_marker(s):
@ -34,7 +36,7 @@ def _is_version_marker(s):
def _is_literal(o): def _is_literal(o):
if not isinstance(o, string_types) or not o: if not isinstance(o, string_types) or not o:
return False return False
return o[0] in '\'"' return o[0] in "'\""
def _get_versions(s): def _get_versions(s):
@ -47,18 +49,18 @@ class Evaluator(object):
""" """
operations = { operations = {
'==': lambda x, y: x == y, "==": lambda x, y: x == y,
'===': lambda x, y: x == y, "===": lambda x, y: x == y,
'~=': lambda x, y: x == y or x > y, "~=": lambda x, y: x == y or x > y,
'!=': lambda x, y: x != y, "!=": lambda x, y: x != y,
'<': lambda x, y: x < y, "<": lambda x, y: x < y,
'<=': lambda x, y: x == y or x < y, "<=": lambda x, y: x == y or x < y,
'>': lambda x, y: x > y, ">": lambda x, y: x > y,
'>=': lambda x, y: x == y or x > y, ">=": lambda x, y: x == y or x > y,
'and': lambda x, y: x and y, "and": lambda x, y: x and y,
'or': lambda x, y: x or y, "or": lambda x, y: x or y,
'in': lambda x, y: x in y, "in": lambda x, y: x in y,
'not in': lambda x, y: x not in y, "not in": lambda x, y: x not in y,
} }
def evaluate(self, expr, context): def evaluate(self, expr, context):
@ -67,70 +69,78 @@ class Evaluator(object):
function in the specified context. function in the specified context.
""" """
if isinstance(expr, string_types): if isinstance(expr, string_types):
if expr[0] in '\'"': if expr[0] in "'\"":
result = expr[1:-1] result = expr[1:-1]
else: else:
if expr not in context: if expr not in context:
raise SyntaxError('unknown variable: %s' % expr) raise SyntaxError("unknown variable: %s" % expr)
result = context[expr] result = context[expr]
else: else:
assert isinstance(expr, dict) assert isinstance(expr, dict)
op = expr['op'] op = expr["op"]
if op not in self.operations: if op not in self.operations:
raise NotImplementedError('op not implemented: %s' % op) raise NotImplementedError("op not implemented: %s" % op)
elhs = expr['lhs'] elhs = expr["lhs"]
erhs = expr['rhs'] erhs = expr["rhs"]
if _is_literal(expr['lhs']) and _is_literal(expr['rhs']): if _is_literal(expr["lhs"]) and _is_literal(expr["rhs"]):
raise SyntaxError('invalid comparison: %s %s %s' % (elhs, op, erhs)) raise SyntaxError("invalid comparison: %s %s %s" % (elhs, op, erhs))
lhs = self.evaluate(elhs, context) lhs = self.evaluate(elhs, context)
rhs = self.evaluate(erhs, context) rhs = self.evaluate(erhs, context)
if ((_is_version_marker(elhs) or _is_version_marker(erhs)) and if (_is_version_marker(elhs) or _is_version_marker(erhs)) and op in (
op in ('<', '<=', '>', '>=', '===', '==', '!=', '~=')): "<",
"<=",
">",
">=",
"===",
"==",
"!=",
"~=",
):
lhs = LV(lhs) lhs = LV(lhs)
rhs = LV(rhs) rhs = LV(rhs)
elif _is_version_marker(elhs) and op in ('in', 'not in'): elif _is_version_marker(elhs) and op in ("in", "not in"):
lhs = LV(lhs) lhs = LV(lhs)
rhs = _get_versions(rhs) rhs = _get_versions(rhs)
result = self.operations[op](lhs, rhs) result = self.operations[op](lhs, rhs)
return result return result
_DIGITS = re.compile(r'\d+\.\d+') _DIGITS = re.compile(r"\d+\.\d+")
def default_context(): def default_context():
def format_full_version(info): def format_full_version(info):
version = '%s.%s.%s' % (info.major, info.minor, info.micro) version = "%s.%s.%s" % (info.major, info.minor, info.micro)
kind = info.releaselevel kind = info.releaselevel
if kind != 'final': if kind != "final":
version += kind[0] + str(info.serial) version += kind[0] + str(info.serial)
return version return version
if hasattr(sys, 'implementation'): if hasattr(sys, "implementation"):
implementation_version = format_full_version(sys.implementation.version) implementation_version = format_full_version(sys.implementation.version)
implementation_name = sys.implementation.name implementation_name = sys.implementation.name
else: else:
implementation_version = '0' implementation_version = "0"
implementation_name = '' implementation_name = ""
ppv = platform.python_version() ppv = platform.python_version()
m = _DIGITS.match(ppv) m = _DIGITS.match(ppv)
pv = m.group(0) pv = m.group(0)
result = { result = {
'implementation_name': implementation_name, "implementation_name": implementation_name,
'implementation_version': implementation_version, "implementation_version": implementation_version,
'os_name': os.name, "os_name": os.name,
'platform_machine': platform.machine(), "platform_machine": platform.machine(),
'platform_python_implementation': platform.python_implementation(), "platform_python_implementation": platform.python_implementation(),
'platform_release': platform.release(), "platform_release": platform.release(),
'platform_system': platform.system(), "platform_system": platform.system(),
'platform_version': platform.version(), "platform_version": platform.version(),
'platform_in_venv': str(in_venv()), "platform_in_venv": str(in_venv()),
'python_full_version': ppv, "python_full_version": ppv,
'python_version': pv, "python_version": pv,
'sys_platform': sys.platform, "sys_platform": sys.platform,
} }
return result return result
@ -140,12 +150,14 @@ del default_context
evaluator = Evaluator() evaluator = Evaluator()
def interpret_parsed(expr, execution_context=None): def interpret_parsed(expr, execution_context=None):
context = dict(DEFAULT_CONTEXT) context = dict(DEFAULT_CONTEXT)
if execution_context: if execution_context:
context.update(execution_context) context.update(execution_context)
return evaluator.evaluate(expr, context) return evaluator.evaluate(expr, context)
def interpret(marker, execution_context=None): def interpret(marker, execution_context=None):
""" """
Interpret a marker and return a result depending on environment. Interpret a marker and return a result depending on environment.
@ -158,7 +170,7 @@ def interpret(marker, execution_context=None):
try: try:
expr, rest = parse_marker(marker) expr, rest = parse_marker(marker)
except Exception as e: except Exception as e:
raise SyntaxError('Unable to interpret marker syntax: %s: %s' % (marker, e)) raise SyntaxError("Unable to interpret marker syntax: %s: %s" % (marker, e))
if rest and rest[0] != '#': if rest and rest[0] != "#":
raise SyntaxError('unexpected trailing data in marker: %s: %s' % (marker, rest)) raise SyntaxError("unexpected trailing data in marker: %s: %s" % (marker, rest))
return interpret_parsed(expr, execution_context) return interpret_parsed(expr, execution_context)

File diff suppressed because it is too large Load Diff

View File

@ -21,14 +21,14 @@ from .util import cached_property, get_cache_base, Cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
cache = None # created when needed cache = None # created when needed
class ResourceCache(Cache): class ResourceCache(Cache):
def __init__(self, base=None): def __init__(self, base=None):
if base is None: if base is None:
# Use native string to avoid issues on 2.x: see Python #20140. # Use native string to avoid issues on 2.x: see Python #20140.
base = os.path.join(get_cache_base(), str('resource-cache')) base = os.path.join(get_cache_base(), str("resource-cache"))
super(ResourceCache, self).__init__(base) super(ResourceCache, self).__init__(base)
def is_stale(self, resource, path): def is_stale(self, resource, path):
@ -63,7 +63,7 @@ class ResourceCache(Cache):
stale = self.is_stale(resource, path) stale = self.is_stale(resource, path)
if stale: if stale:
# write the bytes of the resource to the cache location # write the bytes of the resource to the cache location
with open(result, 'wb') as f: with open(result, "wb") as f:
f.write(resource.bytes) f.write(resource.bytes)
return result return result
@ -80,7 +80,8 @@ class Resource(ResourceBase):
not normally instantiated by user code, but rather by a not normally instantiated by user code, but rather by a
:class:`ResourceFinder` which manages the resource. :class:`ResourceFinder` which manages the resource.
""" """
is_container = False # Backwards compatibility
is_container = False # Backwards compatibility
def as_stream(self): def as_stream(self):
""" """
@ -108,7 +109,7 @@ class Resource(ResourceBase):
class ResourceContainer(ResourceBase): class ResourceContainer(ResourceBase):
is_container = True # Backwards compatibility is_container = True # Backwards compatibility
@cached_property @cached_property
def resources(self): def resources(self):
@ -120,15 +121,15 @@ class ResourceFinder(object):
Resource finder for file system resources. Resource finder for file system resources.
""" """
if sys.platform.startswith('java'): if sys.platform.startswith("java"):
skipped_extensions = ('.pyc', '.pyo', '.class') skipped_extensions = (".pyc", ".pyo", ".class")
else: else:
skipped_extensions = ('.pyc', '.pyo') skipped_extensions = (".pyc", ".pyo")
def __init__(self, module): def __init__(self, module):
self.module = module self.module = module
self.loader = getattr(module, '__loader__', None) self.loader = getattr(module, "__loader__", None)
self.base = os.path.dirname(getattr(module, '__file__', '')) self.base = os.path.dirname(getattr(module, "__file__", ""))
def _adjust_path(self, path): def _adjust_path(self, path):
return os.path.realpath(path) return os.path.realpath(path)
@ -136,10 +137,10 @@ class ResourceFinder(object):
def _make_path(self, resource_name): def _make_path(self, resource_name):
# Issue #50: need to preserve type of path on Python 2.x # Issue #50: need to preserve type of path on Python 2.x
# like os.path._get_sep # like os.path._get_sep
if isinstance(resource_name, bytes): # should only happen on 2.x if isinstance(resource_name, bytes): # should only happen on 2.x
sep = b'/' sep = b"/"
else: else:
sep = '/' sep = "/"
parts = resource_name.split(sep) parts = resource_name.split(sep)
parts.insert(0, self.base) parts.insert(0, self.base)
result = os.path.join(*parts) result = os.path.join(*parts)
@ -164,10 +165,10 @@ class ResourceFinder(object):
return result return result
def get_stream(self, resource): def get_stream(self, resource):
return open(resource.path, 'rb') return open(resource.path, "rb")
def get_bytes(self, resource): def get_bytes(self, resource):
with open(resource.path, 'rb') as f: with open(resource.path, "rb") as f:
return f.read() return f.read()
def get_size(self, resource): def get_size(self, resource):
@ -175,8 +176,8 @@ class ResourceFinder(object):
def get_resources(self, resource): def get_resources(self, resource):
def allowed(f): def allowed(f):
return (f != '__pycache__' and not return f != "__pycache__" and not f.endswith(self.skipped_extensions)
f.endswith(self.skipped_extensions))
return set([f for f in os.listdir(resource.path) if allowed(f)]) return set([f for f in os.listdir(resource.path) if allowed(f)])
def is_container(self, resource): def is_container(self, resource):
@ -197,7 +198,7 @@ class ResourceFinder(object):
if not rname: if not rname:
new_name = name new_name = name
else: else:
new_name = '/'.join([rname, name]) new_name = "/".join([rname, name])
child = self.find(new_name) child = self.find(new_name)
if child.is_container: if child.is_container:
todo.append(child) todo.append(child)
@ -209,12 +210,13 @@ class ZipResourceFinder(ResourceFinder):
""" """
Resource finder for resources in .zip files. Resource finder for resources in .zip files.
""" """
def __init__(self, module): def __init__(self, module):
super(ZipResourceFinder, self).__init__(module) super(ZipResourceFinder, self).__init__(module)
archive = self.loader.archive archive = self.loader.archive
self.prefix_len = 1 + len(archive) self.prefix_len = 1 + len(archive)
# PyPy doesn't have a _files attr on zipimporter, and you can't set one # PyPy doesn't have a _files attr on zipimporter, and you can't set one
if hasattr(self.loader, '_files'): if hasattr(self.loader, "_files"):
self._files = self.loader._files self._files = self.loader._files
else: else:
self._files = zipimport._zip_directory_cache[archive] self._files = zipimport._zip_directory_cache[archive]
@ -224,7 +226,7 @@ class ZipResourceFinder(ResourceFinder):
return path return path
def _find(self, path): def _find(self, path):
path = path[self.prefix_len:] path = path[self.prefix_len :]
if path in self._files: if path in self._files:
result = True result = True
else: else:
@ -236,14 +238,14 @@ class ZipResourceFinder(ResourceFinder):
except IndexError: except IndexError:
result = False result = False
if not result: if not result:
logger.debug('_find failed: %r %r', path, self.loader.prefix) logger.debug("_find failed: %r %r", path, self.loader.prefix)
else: else:
logger.debug('_find worked: %r %r', path, self.loader.prefix) logger.debug("_find worked: %r %r", path, self.loader.prefix)
return result return result
def get_cache_info(self, resource): def get_cache_info(self, resource):
prefix = self.loader.archive prefix = self.loader.archive
path = resource.path[1 + len(prefix):] path = resource.path[1 + len(prefix) :]
return prefix, path return prefix, path
def get_bytes(self, resource): def get_bytes(self, resource):
@ -253,11 +255,11 @@ class ZipResourceFinder(ResourceFinder):
return io.BytesIO(self.get_bytes(resource)) return io.BytesIO(self.get_bytes(resource))
def get_size(self, resource): def get_size(self, resource):
path = resource.path[self.prefix_len:] path = resource.path[self.prefix_len :]
return self._files[path][3] return self._files[path][3]
def get_resources(self, resource): def get_resources(self, resource):
path = resource.path[self.prefix_len:] path = resource.path[self.prefix_len :]
if path and path[-1] != os.sep: if path and path[-1] != os.sep:
path += os.sep path += os.sep
plen = len(path) plen = len(path)
@ -267,12 +269,12 @@ class ZipResourceFinder(ResourceFinder):
if not self.index[i].startswith(path): if not self.index[i].startswith(path):
break break
s = self.index[i][plen:] s = self.index[i][plen:]
result.add(s.split(os.sep, 1)[0]) # only immediate children result.add(s.split(os.sep, 1)[0]) # only immediate children
i += 1 i += 1
return result return result
def _is_directory(self, path): def _is_directory(self, path):
path = path[self.prefix_len:] path = path[self.prefix_len :]
if path and path[-1] != os.sep: if path and path[-1] != os.sep:
path += os.sep path += os.sep
i = bisect.bisect(self.index, path) i = bisect.bisect(self.index, path)
@ -285,7 +287,7 @@ class ZipResourceFinder(ResourceFinder):
_finder_registry = { _finder_registry = {
type(None): ResourceFinder, type(None): ResourceFinder,
zipimport.zipimporter: ZipResourceFinder zipimport.zipimporter: ZipResourceFinder,
} }
try: try:
@ -322,20 +324,21 @@ def finder(package):
if package not in sys.modules: if package not in sys.modules:
__import__(package) __import__(package)
module = sys.modules[package] module = sys.modules[package]
path = getattr(module, '__path__', None) path = getattr(module, "__path__", None)
if path is None: if path is None:
raise DistlibException('You cannot get a finder for a module, ' raise DistlibException(
'only for a package') "You cannot get a finder for a module, " "only for a package"
loader = getattr(module, '__loader__', None) )
loader = getattr(module, "__loader__", None)
finder_maker = _finder_registry.get(type(loader)) finder_maker = _finder_registry.get(type(loader))
if finder_maker is None: if finder_maker is None:
raise DistlibException('Unable to locate finder for %r' % package) raise DistlibException("Unable to locate finder for %r" % package)
result = finder_maker(module) result = finder_maker(module)
_finder_cache[package] = result _finder_cache[package] = result
return result return result
_dummy_module = types.ModuleType(str('__dummy__')) _dummy_module = types.ModuleType(str("__dummy__"))
def finder_for_path(path): def finder_for_path(path):
@ -352,7 +355,7 @@ def finder_for_path(path):
finder = _finder_registry.get(type(loader)) finder = _finder_registry.get(type(loader))
if finder: if finder:
module = _dummy_module module = _dummy_module
module.__file__ = os.path.join(path, '') module.__file__ = os.path.join(path, "")
module.__loader__ = loader module.__loader__ = loader
result = finder(module) result = finder(module)
return result return result

View File

@ -15,11 +15,18 @@ from zipfile import ZipInfo
from .compat import sysconfig, detect_encoding, ZipFile from .compat import sysconfig, detect_encoding, ZipFile
from .resources import finder from .resources import finder
from .util import (FileOperator, get_export_entry, convert_path, get_executable, get_platform, in_venv) from .util import (
FileOperator,
get_export_entry,
convert_path,
get_executable,
get_platform,
in_venv,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_DEFAULT_MANIFEST = ''' _DEFAULT_MANIFEST = """
<?xml version="1.0" encoding="UTF-8" standalone="yes"?> <?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<assembly xmlns="urn:schemas-microsoft-com:asm.v1" manifestVersion="1.0"> <assembly xmlns="urn:schemas-microsoft-com:asm.v1" manifestVersion="1.0">
<assemblyIdentity version="1.0.0.0" <assemblyIdentity version="1.0.0.0"
@ -35,18 +42,18 @@ _DEFAULT_MANIFEST = '''
</requestedPrivileges> </requestedPrivileges>
</security> </security>
</trustInfo> </trustInfo>
</assembly>'''.strip() </assembly>""".strip()
# check if Python is called on the first line with this expression # check if Python is called on the first line with this expression
FIRST_LINE_RE = re.compile(b'^#!.*pythonw?[0-9.]*([ \t].*)?$') FIRST_LINE_RE = re.compile(b"^#!.*pythonw?[0-9.]*([ \t].*)?$")
SCRIPT_TEMPLATE = r'''# -*- coding: utf-8 -*- SCRIPT_TEMPLATE = r"""# -*- coding: utf-8 -*-
import re import re
import sys import sys
if __name__ == '__main__': if __name__ == '__main__':
from %(module)s import %(import_name)s from %(module)s import %(import_name)s
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0]) sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(%(func)s()) sys.exit(%(func)s())
''' """
# Pre-fetch the contents of all executable wrapper stubs. # Pre-fetch the contents of all executable wrapper stubs.
# This is to address https://github.com/pypa/pip/issues/12666. # This is to address https://github.com/pypa/pip/issues/12666.
@ -56,10 +63,10 @@ if __name__ == '__main__':
# location where it was imported from. So we load everything into memory in # location where it was imported from. So we load everything into memory in
# advance. # advance.
if os.name == 'nt' or (os.name == 'java' and os._name == 'nt'): if os.name == "nt" or (os.name == "java" and os._name == "nt"):
# Issue 31: don't hardcode an absolute package name, but # Issue 31: don't hardcode an absolute package name, but
# determine it relative to the current package # determine it relative to the current package
DISTLIB_PACKAGE = __name__.rsplit('.', 1)[0] DISTLIB_PACKAGE = __name__.rsplit(".", 1)[0]
WRAPPERS = { WRAPPERS = {
r.name: r.bytes r.name: r.bytes
@ -69,14 +76,14 @@ if os.name == 'nt' or (os.name == 'java' and os._name == 'nt'):
def enquote_executable(executable): def enquote_executable(executable):
if ' ' in executable: if " " in executable:
# make sure we quote only the executable in case of env # make sure we quote only the executable in case of env
# for example /usr/bin/env "/dir with spaces/bin/jython" # for example /usr/bin/env "/dir with spaces/bin/jython"
# instead of "/usr/bin/env /dir with spaces/bin/jython" # instead of "/usr/bin/env /dir with spaces/bin/jython"
# otherwise whole # otherwise whole
if executable.startswith('/usr/bin/env '): if executable.startswith("/usr/bin/env "):
env, _executable = executable.split(' ', 1) env, _executable = executable.split(" ", 1)
if ' ' in _executable and not _executable.startswith('"'): if " " in _executable and not _executable.startswith('"'):
executable = '%s "%s"' % (env, _executable) executable = '%s "%s"' % (env, _executable)
else: else:
if not executable.startswith('"'): if not executable.startswith('"'):
@ -93,32 +100,37 @@ class ScriptMaker(object):
A class to copy or create scripts from source scripts or callable A class to copy or create scripts from source scripts or callable
specifications. specifications.
""" """
script_template = SCRIPT_TEMPLATE script_template = SCRIPT_TEMPLATE
executable = None # for shebangs executable = None # for shebangs
def __init__(self, source_dir, target_dir, add_launchers=True, dry_run=False, fileop=None): def __init__(
self, source_dir, target_dir, add_launchers=True, dry_run=False, fileop=None
):
self.source_dir = source_dir self.source_dir = source_dir
self.target_dir = target_dir self.target_dir = target_dir
self.add_launchers = add_launchers self.add_launchers = add_launchers
self.force = False self.force = False
self.clobber = False self.clobber = False
# It only makes sense to set mode bits on POSIX. # It only makes sense to set mode bits on POSIX.
self.set_mode = (os.name == 'posix') or (os.name == 'java' and os._name == 'posix') self.set_mode = (os.name == "posix") or (
self.variants = set(('', 'X.Y')) os.name == "java" and os._name == "posix"
)
self.variants = set(("", "X.Y"))
self._fileop = fileop or FileOperator(dry_run) self._fileop = fileop or FileOperator(dry_run)
self._is_nt = os.name == 'nt' or (os.name == 'java' and os._name == 'nt') self._is_nt = os.name == "nt" or (os.name == "java" and os._name == "nt")
self.version_info = sys.version_info self.version_info = sys.version_info
def _get_alternate_executable(self, executable, options): def _get_alternate_executable(self, executable, options):
if options.get('gui', False) and self._is_nt: # pragma: no cover if options.get("gui", False) and self._is_nt: # pragma: no cover
dn, fn = os.path.split(executable) dn, fn = os.path.split(executable)
fn = fn.replace('python', 'pythonw') fn = fn.replace("python", "pythonw")
executable = os.path.join(dn, fn) executable = os.path.join(dn, fn)
return executable return executable
if sys.platform.startswith('java'): # pragma: no cover if sys.platform.startswith("java"): # pragma: no cover
def _is_shell(self, executable): def _is_shell(self, executable):
""" """
@ -127,9 +139,9 @@ class ScriptMaker(object):
""" """
try: try:
with open(executable) as fp: with open(executable) as fp:
return fp.read(2) == '#!' return fp.read(2) == "#!"
except (OSError, IOError): except (OSError, IOError):
logger.warning('Failed to open %s', executable) logger.warning("Failed to open %s", executable)
return False return False
def _fix_jython_executable(self, executable): def _fix_jython_executable(self, executable):
@ -137,12 +149,12 @@ class ScriptMaker(object):
# Workaround for Jython is not needed on Linux systems. # Workaround for Jython is not needed on Linux systems.
import java import java
if java.lang.System.getProperty('os.name') == 'Linux': if java.lang.System.getProperty("os.name") == "Linux":
return executable return executable
elif executable.lower().endswith('jython.exe'): elif executable.lower().endswith("jython.exe"):
# Use wrapper exe for Jython on Windows # Use wrapper exe for Jython on Windows
return executable return executable
return '/usr/bin/env %s' % executable return "/usr/bin/env %s" % executable
def _build_shebang(self, executable, post_interp): def _build_shebang(self, executable, post_interp):
""" """
@ -155,7 +167,7 @@ class ScriptMaker(object):
See also: http://www.in-ulm.de/~mascheck/various/shebang/#length See also: http://www.in-ulm.de/~mascheck/various/shebang/#length
https://hg.mozilla.org/mozilla-central/file/tip/mach https://hg.mozilla.org/mozilla-central/file/tip/mach
""" """
if os.name != 'posix': if os.name != "posix":
simple_shebang = True simple_shebang = True
elif getattr(sys, "cross_compiling", False): elif getattr(sys, "cross_compiling", False):
# In a cross-compiling environment, the shebang will likely be a # In a cross-compiling environment, the shebang will likely be a
@ -166,21 +178,23 @@ class ScriptMaker(object):
else: else:
# Add 3 for '#!' prefix and newline suffix. # Add 3 for '#!' prefix and newline suffix.
shebang_length = len(executable) + len(post_interp) + 3 shebang_length = len(executable) + len(post_interp) + 3
if sys.platform == 'darwin': if sys.platform == "darwin":
max_shebang_length = 512 max_shebang_length = 512
else: else:
max_shebang_length = 127 max_shebang_length = 127
simple_shebang = ((b' ' not in executable) and (shebang_length <= max_shebang_length)) simple_shebang = (b" " not in executable) and (
shebang_length <= max_shebang_length
)
if simple_shebang: if simple_shebang:
result = b'#!' + executable + post_interp + b'\n' result = b"#!" + executable + post_interp + b"\n"
else: else:
result = b'#!/bin/sh\n' result = b"#!/bin/sh\n"
result += b"'''exec' " + executable + post_interp + b' "$0" "$@"\n' result += b"'''exec' " + executable + post_interp + b' "$0" "$@"\n'
result += b"' '''\n" result += b"' '''\n"
return result return result
def _get_shebang(self, encoding, post_interp=b'', options=None): def _get_shebang(self, encoding, post_interp=b"", options=None):
enquote = True enquote = True
if self.executable: if self.executable:
executable = self.executable executable = self.executable
@ -188,21 +202,31 @@ class ScriptMaker(object):
elif not sysconfig.is_python_build(): elif not sysconfig.is_python_build():
executable = get_executable() executable = get_executable()
elif in_venv(): # pragma: no cover elif in_venv(): # pragma: no cover
executable = os.path.join(sysconfig.get_path('scripts'), 'python%s' % sysconfig.get_config_var('EXE')) executable = os.path.join(
sysconfig.get_path("scripts"),
"python%s" % sysconfig.get_config_var("EXE"),
)
else: # pragma: no cover else: # pragma: no cover
if os.name == 'nt': if os.name == "nt":
# for Python builds from source on Windows, no Python executables with # for Python builds from source on Windows, no Python executables with
# a version suffix are created, so we use python.exe # a version suffix are created, so we use python.exe
executable = os.path.join(sysconfig.get_config_var('BINDIR'), executable = os.path.join(
'python%s' % (sysconfig.get_config_var('EXE'))) sysconfig.get_config_var("BINDIR"),
"python%s" % (sysconfig.get_config_var("EXE")),
)
else: else:
executable = os.path.join( executable = os.path.join(
sysconfig.get_config_var('BINDIR'), sysconfig.get_config_var("BINDIR"),
'python%s%s' % (sysconfig.get_config_var('VERSION'), sysconfig.get_config_var('EXE'))) "python%s%s"
% (
sysconfig.get_config_var("VERSION"),
sysconfig.get_config_var("EXE"),
),
)
if options: if options:
executable = self._get_alternate_executable(executable, options) executable = self._get_alternate_executable(executable, options)
if sys.platform.startswith('java'): # pragma: no cover if sys.platform.startswith("java"): # pragma: no cover
executable = self._fix_jython_executable(executable) executable = self._fix_jython_executable(executable)
# Normalise case for Windows - COMMENTED OUT # Normalise case for Windows - COMMENTED OUT
@ -220,11 +244,14 @@ class ScriptMaker(object):
executable = enquote_executable(executable) executable = enquote_executable(executable)
# Issue #51: don't use fsencode, since we later try to # Issue #51: don't use fsencode, since we later try to
# check that the shebang is decodable using utf-8. # check that the shebang is decodable using utf-8.
executable = executable.encode('utf-8') executable = executable.encode("utf-8")
# in case of IronPython, play safe and enable frames support # in case of IronPython, play safe and enable frames support
if (sys.platform == 'cli' and '-X:Frames' not in post_interp and if (
'-X:FullFrames' not in post_interp): # pragma: no cover sys.platform == "cli"
post_interp += b' -X:Frames' and "-X:Frames" not in post_interp
and "-X:FullFrames" not in post_interp
): # pragma: no cover
post_interp += b" -X:Frames"
shebang = self._build_shebang(executable, post_interp) shebang = self._build_shebang(executable, post_interp)
# Python parser starts to read a script using UTF-8 until # Python parser starts to read a script using UTF-8 until
# it gets a #coding:xxx cookie. The shebang has to be the # it gets a #coding:xxx cookie. The shebang has to be the
@ -232,23 +259,28 @@ class ScriptMaker(object):
# written before. So the shebang has to be decodable from # written before. So the shebang has to be decodable from
# UTF-8. # UTF-8.
try: try:
shebang.decode('utf-8') shebang.decode("utf-8")
except UnicodeDecodeError: # pragma: no cover except UnicodeDecodeError: # pragma: no cover
raise ValueError('The shebang (%r) is not decodable from utf-8' % shebang) raise ValueError("The shebang (%r) is not decodable from utf-8" % shebang)
# If the script is encoded to a custom encoding (use a # If the script is encoded to a custom encoding (use a
# #coding:xxx cookie), the shebang has to be decodable from # #coding:xxx cookie), the shebang has to be decodable from
# the script encoding too. # the script encoding too.
if encoding != 'utf-8': if encoding != "utf-8":
try: try:
shebang.decode(encoding) shebang.decode(encoding)
except UnicodeDecodeError: # pragma: no cover except UnicodeDecodeError: # pragma: no cover
raise ValueError('The shebang (%r) is not decodable ' raise ValueError(
'from the script encoding (%r)' % (shebang, encoding)) "The shebang (%r) is not decodable "
"from the script encoding (%r)" % (shebang, encoding)
)
return shebang return shebang
def _get_script_text(self, entry): def _get_script_text(self, entry):
return self.script_template % dict( return self.script_template % dict(
module=entry.prefix, import_name=entry.suffix.split('.')[0], func=entry.suffix) module=entry.prefix,
import_name=entry.suffix.split(".")[0],
func=entry.suffix,
)
manifest = _DEFAULT_MANIFEST manifest = _DEFAULT_MANIFEST
@ -261,82 +293,90 @@ class ScriptMaker(object):
if not use_launcher: if not use_launcher:
script_bytes = shebang + script_bytes script_bytes = shebang + script_bytes
else: # pragma: no cover else: # pragma: no cover
if ext == 'py': if ext == "py":
launcher = self._get_launcher('t') launcher = self._get_launcher("t")
else: else:
launcher = self._get_launcher('w') launcher = self._get_launcher("w")
stream = BytesIO() stream = BytesIO()
with ZipFile(stream, 'w') as zf: with ZipFile(stream, "w") as zf:
source_date_epoch = os.environ.get('SOURCE_DATE_EPOCH') source_date_epoch = os.environ.get("SOURCE_DATE_EPOCH")
if source_date_epoch: if source_date_epoch:
date_time = time.gmtime(int(source_date_epoch))[:6] date_time = time.gmtime(int(source_date_epoch))[:6]
zinfo = ZipInfo(filename='__main__.py', date_time=date_time) zinfo = ZipInfo(filename="__main__.py", date_time=date_time)
zf.writestr(zinfo, script_bytes) zf.writestr(zinfo, script_bytes)
else: else:
zf.writestr('__main__.py', script_bytes) zf.writestr("__main__.py", script_bytes)
zip_data = stream.getvalue() zip_data = stream.getvalue()
script_bytes = launcher + shebang + zip_data script_bytes = launcher + shebang + zip_data
for name in names: for name in names:
outname = os.path.join(self.target_dir, name) outname = os.path.join(self.target_dir, name)
if use_launcher: # pragma: no cover if use_launcher: # pragma: no cover
n, e = os.path.splitext(outname) n, e = os.path.splitext(outname)
if e.startswith('.py'): if e.startswith(".py"):
outname = n outname = n
outname = '%s.exe' % outname outname = "%s.exe" % outname
try: try:
self._fileop.write_binary_file(outname, script_bytes) self._fileop.write_binary_file(outname, script_bytes)
except Exception: except Exception:
# Failed writing an executable - it might be in use. # Failed writing an executable - it might be in use.
logger.warning('Failed to write executable - trying to ' logger.warning(
'use .deleteme logic') "Failed to write executable - trying to " "use .deleteme logic"
dfname = '%s.deleteme' % outname )
dfname = "%s.deleteme" % outname
if os.path.exists(dfname): if os.path.exists(dfname):
os.remove(dfname) # Not allowed to fail here os.remove(dfname) # Not allowed to fail here
os.rename(outname, dfname) # nor here os.rename(outname, dfname) # nor here
self._fileop.write_binary_file(outname, script_bytes) self._fileop.write_binary_file(outname, script_bytes)
logger.debug('Able to replace executable using ' logger.debug("Able to replace executable using " ".deleteme logic")
'.deleteme logic')
try: try:
os.remove(dfname) os.remove(dfname)
except Exception: except Exception:
pass # still in use - ignore error pass # still in use - ignore error
else: else:
if self._is_nt and not outname.endswith('.' + ext): # pragma: no cover if self._is_nt and not outname.endswith("." + ext): # pragma: no cover
outname = '%s.%s' % (outname, ext) outname = "%s.%s" % (outname, ext)
if os.path.exists(outname) and not self.clobber: if os.path.exists(outname) and not self.clobber:
logger.warning('Skipping existing file %s', outname) logger.warning("Skipping existing file %s", outname)
continue continue
self._fileop.write_binary_file(outname, script_bytes) self._fileop.write_binary_file(outname, script_bytes)
if self.set_mode: if self.set_mode:
self._fileop.set_executable_mode([outname]) self._fileop.set_executable_mode([outname])
filenames.append(outname) filenames.append(outname)
variant_separator = '-' variant_separator = "-"
def get_script_filenames(self, name): def get_script_filenames(self, name):
result = set() result = set()
if '' in self.variants: if "" in self.variants:
result.add(name) result.add(name)
if 'X' in self.variants: if "X" in self.variants:
result.add('%s%s' % (name, self.version_info[0])) result.add("%s%s" % (name, self.version_info[0]))
if 'X.Y' in self.variants: if "X.Y" in self.variants:
result.add('%s%s%s.%s' % (name, self.variant_separator, self.version_info[0], self.version_info[1])) result.add(
"%s%s%s.%s"
% (
name,
self.variant_separator,
self.version_info[0],
self.version_info[1],
)
)
return result return result
def _make_script(self, entry, filenames, options=None): def _make_script(self, entry, filenames, options=None):
post_interp = b'' post_interp = b""
if options: if options:
args = options.get('interpreter_args', []) args = options.get("interpreter_args", [])
if args: if args:
args = ' %s' % ' '.join(args) args = " %s" % " ".join(args)
post_interp = args.encode('utf-8') post_interp = args.encode("utf-8")
shebang = self._get_shebang('utf-8', post_interp, options=options) shebang = self._get_shebang("utf-8", post_interp, options=options)
script = self._get_script_text(entry).encode('utf-8') script = self._get_script_text(entry).encode("utf-8")
scriptnames = self.get_script_filenames(entry.name) scriptnames = self.get_script_filenames(entry.name)
if options and options.get('gui', False): if options and options.get("gui", False):
ext = 'pyw' ext = "pyw"
else: else:
ext = 'py' ext = "py"
self._write_script(scriptnames, shebang, script, filenames, ext) self._write_script(scriptnames, shebang, script, filenames, ext)
def _copy_script(self, script, filenames): def _copy_script(self, script, filenames):
@ -344,14 +384,14 @@ class ScriptMaker(object):
script = os.path.join(self.source_dir, convert_path(script)) script = os.path.join(self.source_dir, convert_path(script))
outname = os.path.join(self.target_dir, os.path.basename(script)) outname = os.path.join(self.target_dir, os.path.basename(script))
if not self.force and not self._fileop.newer(script, outname): if not self.force and not self._fileop.newer(script, outname):
logger.debug('not copying %s (up-to-date)', script) logger.debug("not copying %s (up-to-date)", script)
return return
# Always open the file, but ignore failures in dry-run mode -- # Always open the file, but ignore failures in dry-run mode --
# that way, we'll get accurate feedback if we can read the # that way, we'll get accurate feedback if we can read the
# script. # script.
try: try:
f = open(script, 'rb') f = open(script, "rb")
except IOError: # pragma: no cover except IOError: # pragma: no cover
if not self.dry_run: if not self.dry_run:
raise raise
@ -359,13 +399,13 @@ class ScriptMaker(object):
else: else:
first_line = f.readline() first_line = f.readline()
if not first_line: # pragma: no cover if not first_line: # pragma: no cover
logger.warning('%s is an empty file (skipping)', script) logger.warning("%s is an empty file (skipping)", script)
return return
match = FIRST_LINE_RE.match(first_line.replace(b'\r\n', b'\n')) match = FIRST_LINE_RE.match(first_line.replace(b"\r\n", b"\n"))
if match: if match:
adjust = True adjust = True
post_interp = match.group(1) or b'' post_interp = match.group(1) or b""
if not adjust: if not adjust:
if f: if f:
@ -375,15 +415,15 @@ class ScriptMaker(object):
self._fileop.set_executable_mode([outname]) self._fileop.set_executable_mode([outname])
filenames.append(outname) filenames.append(outname)
else: else:
logger.info('copying and adjusting %s -> %s', script, self.target_dir) logger.info("copying and adjusting %s -> %s", script, self.target_dir)
if not self._fileop.dry_run: if not self._fileop.dry_run:
encoding, lines = detect_encoding(f.readline) encoding, lines = detect_encoding(f.readline)
f.seek(0) f.seek(0)
shebang = self._get_shebang(encoding, post_interp) shebang = self._get_shebang(encoding, post_interp)
if b'pythonw' in first_line: # pragma: no cover if b"pythonw" in first_line: # pragma: no cover
ext = 'pyw' ext = "pyw"
else: else:
ext = 'py' ext = "py"
n = os.path.basename(outname) n = os.path.basename(outname)
self._write_script([n], shebang, f.read(), filenames, ext) self._write_script([n], shebang, f.read(), filenames, ext)
if f: if f:
@ -397,20 +437,22 @@ class ScriptMaker(object):
def dry_run(self, value): def dry_run(self, value):
self._fileop.dry_run = value self._fileop.dry_run = value
if os.name == 'nt' or (os.name == 'java' and os._name == 'nt'): # pragma: no cover if os.name == "nt" or (os.name == "java" and os._name == "nt"): # pragma: no cover
# Executable launcher support. # Executable launcher support.
# Launchers are from https://bitbucket.org/vinay.sajip/simple_launcher/ # Launchers are from https://bitbucket.org/vinay.sajip/simple_launcher/
def _get_launcher(self, kind): def _get_launcher(self, kind):
if struct.calcsize('P') == 8: # 64-bit if struct.calcsize("P") == 8: # 64-bit
bits = '64' bits = "64"
else: else:
bits = '32' bits = "32"
platform_suffix = '-arm' if get_platform() == 'win-arm64' else '' platform_suffix = "-arm" if get_platform() == "win-arm64" else ""
name = '%s%s%s.exe' % (kind, bits, platform_suffix) name = "%s%s%s.exe" % (kind, bits, platform_suffix)
if name not in WRAPPERS: if name not in WRAPPERS:
msg = ('Unable to find resource %s in package %s' % msg = "Unable to find resource %s in package %s" % (
(name, DISTLIB_PACKAGE)) name,
DISTLIB_PACKAGE,
)
raise ValueError(msg) raise ValueError(msg)
return WRAPPERS[name] return WRAPPERS[name]

File diff suppressed because it is too large Load Diff

View File

@ -14,16 +14,23 @@ import re
from .compat import string_types from .compat import string_types
from .util import parse_requirement from .util import parse_requirement
__all__ = ['NormalizedVersion', 'NormalizedMatcher', __all__ = [
'LegacyVersion', 'LegacyMatcher', "NormalizedVersion",
'SemanticVersion', 'SemanticMatcher', "NormalizedMatcher",
'UnsupportedVersionError', 'get_scheme'] "LegacyVersion",
"LegacyMatcher",
"SemanticVersion",
"SemanticMatcher",
"UnsupportedVersionError",
"get_scheme",
]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UnsupportedVersionError(ValueError): class UnsupportedVersionError(ValueError):
"""This is an unsupported version.""" """This is an unsupported version."""
pass pass
@ -35,11 +42,11 @@ class Version(object):
assert len(parts) > 0 assert len(parts) > 0
def parse(self, s): def parse(self, s):
raise NotImplementedError('please implement in a subclass') raise NotImplementedError("please implement in a subclass")
def _check_compatible(self, other): def _check_compatible(self, other):
if type(self) != type(other): if type(self) != type(other):
raise TypeError('cannot compare %r and %r' % (self, other)) raise TypeError("cannot compare %r and %r" % (self, other))
def __eq__(self, other): def __eq__(self, other):
self._check_compatible(other) self._check_compatible(other)
@ -73,7 +80,7 @@ class Version(object):
@property @property
def is_prerelease(self): def is_prerelease(self):
raise NotImplementedError('Please implement in subclasses.') raise NotImplementedError("Please implement in subclasses.")
class Matcher(object): class Matcher(object):
@ -81,15 +88,15 @@ class Matcher(object):
# value is either a callable or the name of a method # value is either a callable or the name of a method
_operators = { _operators = {
'<': lambda v, c, p: v < c, "<": lambda v, c, p: v < c,
'>': lambda v, c, p: v > c, ">": lambda v, c, p: v > c,
'<=': lambda v, c, p: v == c or v < c, "<=": lambda v, c, p: v == c or v < c,
'>=': lambda v, c, p: v == c or v > c, ">=": lambda v, c, p: v == c or v > c,
'==': lambda v, c, p: v == c, "==": lambda v, c, p: v == c,
'===': lambda v, c, p: v == c, "===": lambda v, c, p: v == c,
# by default, compatible => >=. # by default, compatible => >=.
'~=': lambda v, c, p: v == c or v > c, "~=": lambda v, c, p: v == c or v > c,
'!=': lambda v, c, p: v != c, "!=": lambda v, c, p: v != c,
} }
# this is a method only to support alternative implementations # this is a method only to support alternative implementations
@ -99,21 +106,20 @@ class Matcher(object):
def __init__(self, s): def __init__(self, s):
if self.version_class is None: if self.version_class is None:
raise ValueError('Please specify a version class') raise ValueError("Please specify a version class")
self._string = s = s.strip() self._string = s = s.strip()
r = self.parse_requirement(s) r = self.parse_requirement(s)
if not r: if not r:
raise ValueError('Not valid: %r' % s) raise ValueError("Not valid: %r" % s)
self.name = r.name self.name = r.name
self.key = self.name.lower() # for case-insensitive comparisons self.key = self.name.lower() # for case-insensitive comparisons
clist = [] clist = []
if r.constraints: if r.constraints:
# import pdb; pdb.set_trace() # import pdb; pdb.set_trace()
for op, s in r.constraints: for op, s in r.constraints:
if s.endswith('.*'): if s.endswith(".*"):
if op not in ('==', '!='): if op not in ("==", "!="):
raise ValueError('\'.*\' not allowed for ' raise ValueError("'.*' not allowed for " "%r constraints" % op)
'%r constraints' % op)
# Could be a partial version (e.g. for '2.*') which # Could be a partial version (e.g. for '2.*') which
# won't parse as a version, so keep it as a string # won't parse as a version, so keep it as a string
vn, prefix = s[:-2], True vn, prefix = s[:-2], True
@ -140,8 +146,10 @@ class Matcher(object):
if isinstance(f, string_types): if isinstance(f, string_types):
f = getattr(self, f) f = getattr(self, f)
if not f: if not f:
msg = ('%r not implemented ' msg = "%r not implemented " "for %s" % (
'for %s' % (operator, self.__class__.__name__)) operator,
self.__class__.__name__,
)
raise NotImplementedError(msg) raise NotImplementedError(msg)
if not f(version, constraint, prefix): if not f(version, constraint, prefix):
return False return False
@ -150,13 +158,13 @@ class Matcher(object):
@property @property
def exact_version(self): def exact_version(self):
result = None result = None
if len(self._parts) == 1 and self._parts[0][0] in ('==', '==='): if len(self._parts) == 1 and self._parts[0][0] in ("==", "==="):
result = self._parts[0][1] result = self._parts[0][1]
return result return result
def _check_compatible(self, other): def _check_compatible(self, other):
if type(self) != type(other) or self.name != other.name: if type(self) != type(other) or self.name != other.name:
raise TypeError('cannot compare %s and %s' % (self, other)) raise TypeError("cannot compare %s and %s" % (self, other))
def __eq__(self, other): def __eq__(self, other):
self._check_compatible(other) self._check_compatible(other)
@ -176,18 +184,21 @@ class Matcher(object):
return self._string return self._string
PEP440_VERSION_RE = re.compile(r'^v?(\d+!)?(\d+(\.\d+)*)((a|alpha|b|beta|c|rc|pre|preview)(\d+)?)?' PEP440_VERSION_RE = re.compile(
r'(\.(post|r|rev)(\d+)?)?([._-]?(dev)(\d+)?)?' r"^v?(\d+!)?(\d+(\.\d+)*)((a|alpha|b|beta|c|rc|pre|preview)(\d+)?)?"
r'(\+([a-zA-Z\d]+(\.[a-zA-Z\d]+)?))?$', re.I) r"(\.(post|r|rev)(\d+)?)?([._-]?(dev)(\d+)?)?"
r"(\+([a-zA-Z\d]+(\.[a-zA-Z\d]+)?))?$",
re.I,
)
def _pep_440_key(s): def _pep_440_key(s):
s = s.strip() s = s.strip()
m = PEP440_VERSION_RE.match(s) m = PEP440_VERSION_RE.match(s)
if not m: if not m:
raise UnsupportedVersionError('Not a valid version: %s' % s) raise UnsupportedVersionError("Not a valid version: %s" % s)
groups = m.groups() groups = m.groups()
nums = tuple(int(v) for v in groups[1].split('.')) nums = tuple(int(v) for v in groups[1].split("."))
while len(nums) > 1 and nums[-1] == 0: while len(nums) > 1 and nums[-1] == 0:
nums = nums[:-1] nums = nums[:-1]
@ -224,7 +235,7 @@ def _pep_440_key(s):
local = () local = ()
else: else:
parts = [] parts = []
for part in local.split('.'): for part in local.split("."):
# to ensure that numeric compares as > lexicographic, avoid # to ensure that numeric compares as > lexicographic, avoid
# comparing them directly, but encode a tuple which ensures # comparing them directly, but encode a tuple which ensures
# correct sorting # correct sorting
@ -238,14 +249,14 @@ def _pep_440_key(s):
# either before pre-release, or final release and after # either before pre-release, or final release and after
if not post and dev: if not post and dev:
# before pre-release # before pre-release
pre = ('a', -1) # to sort before a0 pre = ("a", -1) # to sort before a0
else: else:
pre = ('z',) # to sort after all pre-releases pre = ("z",) # to sort after all pre-releases
# now look at the state of post and dev. # now look at the state of post and dev.
if not post: if not post:
post = ('_',) # sort before 'a' post = ("_",) # sort before 'a'
if not dev: if not dev:
dev = ('final',) dev = ("final",)
return epoch, nums, pre, post, dev, local return epoch, nums, pre, post, dev, local
@ -271,18 +282,19 @@ class NormalizedVersion(Version):
1.2a # release level must have a release serial 1.2a # release level must have a release serial
1.2.3b 1.2.3b
""" """
def parse(self, s): def parse(self, s):
result = _normalized_key(s) result = _normalized_key(s)
# _normalized_key loses trailing zeroes in the release # _normalized_key loses trailing zeroes in the release
# clause, since that's needed to ensure that X.Y == X.Y.0 == X.Y.0.0 # clause, since that's needed to ensure that X.Y == X.Y.0 == X.Y.0.0
# However, PEP 440 prefix matching needs it: for example, # However, PEP 440 prefix matching needs it: for example,
# (~= 1.4.5.0) matches differently to (~= 1.4.5.0.0). # (~= 1.4.5.0) matches differently to (~= 1.4.5.0.0).
m = PEP440_VERSION_RE.match(s) # must succeed m = PEP440_VERSION_RE.match(s) # must succeed
groups = m.groups() groups = m.groups()
self._release_clause = tuple(int(v) for v in groups[1].split('.')) self._release_clause = tuple(int(v) for v in groups[1].split("."))
return result return result
PREREL_TAGS = set(['a', 'b', 'c', 'rc', 'dev']) PREREL_TAGS = set(["a", "b", "c", "rc", "dev"])
@property @property
def is_prerelease(self): def is_prerelease(self):
@ -297,7 +309,7 @@ def _match_prefix(x, y):
if not x.startswith(y): if not x.startswith(y):
return False return False
n = len(y) n = len(y)
return x[n] == '.' return x[n] == "."
class NormalizedMatcher(Matcher): class NormalizedMatcher(Matcher):
@ -305,19 +317,19 @@ class NormalizedMatcher(Matcher):
# value is either a callable or the name of a method # value is either a callable or the name of a method
_operators = { _operators = {
'~=': '_match_compatible', "~=": "_match_compatible",
'<': '_match_lt', "<": "_match_lt",
'>': '_match_gt', ">": "_match_gt",
'<=': '_match_le', "<=": "_match_le",
'>=': '_match_ge', ">=": "_match_ge",
'==': '_match_eq', "==": "_match_eq",
'===': '_match_arbitrary', "===": "_match_arbitrary",
'!=': '_match_ne', "!=": "_match_ne",
} }
def _adjust_local(self, version, constraint, prefix): def _adjust_local(self, version, constraint, prefix):
if prefix: if prefix:
strip_local = '+' not in constraint and version._parts[-1] strip_local = "+" not in constraint and version._parts[-1]
else: else:
# both constraint and version are # both constraint and version are
# NormalizedVersion instances. # NormalizedVersion instances.
@ -325,7 +337,7 @@ class NormalizedMatcher(Matcher):
# ensure the version doesn't, either. # ensure the version doesn't, either.
strip_local = not constraint._parts[-1] and version._parts[-1] strip_local = not constraint._parts[-1] and version._parts[-1]
if strip_local: if strip_local:
s = version._string.split('+', 1)[0] s = version._string.split("+", 1)[0]
version = self.version_class(s) version = self.version_class(s)
return version, constraint return version, constraint
@ -334,7 +346,7 @@ class NormalizedMatcher(Matcher):
if version >= constraint: if version >= constraint:
return False return False
release_clause = constraint._release_clause release_clause = constraint._release_clause
pfx = '.'.join([str(i) for i in release_clause]) pfx = ".".join([str(i) for i in release_clause])
return not _match_prefix(version, pfx) return not _match_prefix(version, pfx)
def _match_gt(self, version, constraint, prefix): def _match_gt(self, version, constraint, prefix):
@ -342,7 +354,7 @@ class NormalizedMatcher(Matcher):
if version <= constraint: if version <= constraint:
return False return False
release_clause = constraint._release_clause release_clause = constraint._release_clause
pfx = '.'.join([str(i) for i in release_clause]) pfx = ".".join([str(i) for i in release_clause])
return not _match_prefix(version, pfx) return not _match_prefix(version, pfx)
def _match_le(self, version, constraint, prefix): def _match_le(self, version, constraint, prefix):
@ -356,7 +368,7 @@ class NormalizedMatcher(Matcher):
def _match_eq(self, version, constraint, prefix): def _match_eq(self, version, constraint, prefix):
version, constraint = self._adjust_local(version, constraint, prefix) version, constraint = self._adjust_local(version, constraint, prefix)
if not prefix: if not prefix:
result = (version == constraint) result = version == constraint
else: else:
result = _match_prefix(version, constraint) result = _match_prefix(version, constraint)
return result return result
@ -367,7 +379,7 @@ class NormalizedMatcher(Matcher):
def _match_ne(self, version, constraint, prefix): def _match_ne(self, version, constraint, prefix):
version, constraint = self._adjust_local(version, constraint, prefix) version, constraint = self._adjust_local(version, constraint, prefix)
if not prefix: if not prefix:
result = (version != constraint) result = version != constraint
else: else:
result = not _match_prefix(version, constraint) result = not _match_prefix(version, constraint)
return result return result
@ -378,38 +390,37 @@ class NormalizedMatcher(Matcher):
return True return True
if version < constraint: if version < constraint:
return False return False
# if not prefix: # if not prefix:
# return True # return True
release_clause = constraint._release_clause release_clause = constraint._release_clause
if len(release_clause) > 1: if len(release_clause) > 1:
release_clause = release_clause[:-1] release_clause = release_clause[:-1]
pfx = '.'.join([str(i) for i in release_clause]) pfx = ".".join([str(i) for i in release_clause])
return _match_prefix(version, pfx) return _match_prefix(version, pfx)
_REPLACEMENTS = ( _REPLACEMENTS = (
(re.compile('[.+-]$'), ''), # remove trailing puncts (re.compile("[.+-]$"), ""), # remove trailing puncts
(re.compile(r'^[.](\d)'), r'0.\1'), # .N -> 0.N at start (re.compile(r"^[.](\d)"), r"0.\1"), # .N -> 0.N at start
(re.compile('^[.-]'), ''), # remove leading puncts (re.compile("^[.-]"), ""), # remove leading puncts
(re.compile(r'^\((.*)\)$'), r'\1'), # remove parentheses (re.compile(r"^\((.*)\)$"), r"\1"), # remove parentheses
(re.compile(r'^v(ersion)?\s*(\d+)'), r'\2'), # remove leading v(ersion) (re.compile(r"^v(ersion)?\s*(\d+)"), r"\2"), # remove leading v(ersion)
(re.compile(r'^r(ev)?\s*(\d+)'), r'\2'), # remove leading v(ersion) (re.compile(r"^r(ev)?\s*(\d+)"), r"\2"), # remove leading v(ersion)
(re.compile('[.]{2,}'), '.'), # multiple runs of '.' (re.compile("[.]{2,}"), "."), # multiple runs of '.'
(re.compile(r'\b(alfa|apha)\b'), 'alpha'), # misspelt alpha (re.compile(r"\b(alfa|apha)\b"), "alpha"), # misspelt alpha
(re.compile(r'\b(pre-alpha|prealpha)\b'), (re.compile(r"\b(pre-alpha|prealpha)\b"), "pre.alpha"), # standardise
'pre.alpha'), # standardise (re.compile(r"\(beta\)$"), "beta"), # remove parentheses
(re.compile(r'\(beta\)$'), 'beta'), # remove parentheses
) )
_SUFFIX_REPLACEMENTS = ( _SUFFIX_REPLACEMENTS = (
(re.compile('^[:~._+-]+'), ''), # remove leading puncts (re.compile("^[:~._+-]+"), ""), # remove leading puncts
(re.compile('[,*")([\\]]'), ''), # remove unwanted chars (re.compile('[,*")([\\]]'), ""), # remove unwanted chars
(re.compile('[~:+_ -]'), '.'), # replace illegal chars (re.compile("[~:+_ -]"), "."), # replace illegal chars
(re.compile('[.]{2,}'), '.'), # multiple runs of '.' (re.compile("[.]{2,}"), "."), # multiple runs of '.'
(re.compile(r'\.$'), ''), # trailing '.' (re.compile(r"\.$"), ""), # trailing '.'
) )
_NUMERIC_PREFIX = re.compile(r'(\d+(\.\d+)*)') _NUMERIC_PREFIX = re.compile(r"(\d+(\.\d+)*)")
def _suggest_semantic_version(s): def _suggest_semantic_version(s):
@ -421,26 +432,26 @@ def _suggest_semantic_version(s):
for pat, repl in _REPLACEMENTS: for pat, repl in _REPLACEMENTS:
result = pat.sub(repl, result) result = pat.sub(repl, result)
if not result: if not result:
result = '0.0.0' result = "0.0.0"
# Now look for numeric prefix, and separate it out from # Now look for numeric prefix, and separate it out from
# the rest. # the rest.
# import pdb; pdb.set_trace() # import pdb; pdb.set_trace()
m = _NUMERIC_PREFIX.match(result) m = _NUMERIC_PREFIX.match(result)
if not m: if not m:
prefix = '0.0.0' prefix = "0.0.0"
suffix = result suffix = result
else: else:
prefix = m.groups()[0].split('.') prefix = m.groups()[0].split(".")
prefix = [int(i) for i in prefix] prefix = [int(i) for i in prefix]
while len(prefix) < 3: while len(prefix) < 3:
prefix.append(0) prefix.append(0)
if len(prefix) == 3: if len(prefix) == 3:
suffix = result[m.end():] suffix = result[m.end() :]
else: else:
suffix = '.'.join([str(i) for i in prefix[3:]]) + result[m.end():] suffix = ".".join([str(i) for i in prefix[3:]]) + result[m.end() :]
prefix = prefix[:3] prefix = prefix[:3]
prefix = '.'.join([str(i) for i in prefix]) prefix = ".".join([str(i) for i in prefix])
suffix = suffix.strip() suffix = suffix.strip()
if suffix: if suffix:
# import pdb; pdb.set_trace() # import pdb; pdb.set_trace()
@ -451,7 +462,7 @@ def _suggest_semantic_version(s):
if not suffix: if not suffix:
result = prefix result = prefix
else: else:
sep = '-' if 'dev' in suffix else '+' sep = "-" if "dev" in suffix else "+"
result = prefix + sep + suffix result = prefix + sep + suffix
if not is_semver(result): if not is_semver(result):
result = None result = None
@ -477,19 +488,30 @@ def _suggest_normalized_version(s):
""" """
try: try:
_normalized_key(s) _normalized_key(s)
return s # already rational return s # already rational
except UnsupportedVersionError: except UnsupportedVersionError:
pass pass
rs = s.lower() rs = s.lower()
# part of this could use maketrans # part of this could use maketrans
for orig, repl in (('-alpha', 'a'), ('-beta', 'b'), ('alpha', 'a'), for orig, repl in (
('beta', 'b'), ('rc', 'c'), ('-final', ''), ("-alpha", "a"),
('-pre', 'c'), ("-beta", "b"),
('-release', ''), ('.release', ''), ('-stable', ''), ("alpha", "a"),
('+', '.'), ('_', '.'), (' ', ''), ('.final', ''), ("beta", "b"),
('final', '')): ("rc", "c"),
("-final", ""),
("-pre", "c"),
("-release", ""),
(".release", ""),
("-stable", ""),
("+", "."),
("_", "."),
(" ", ""),
(".final", ""),
("final", ""),
):
rs = rs.replace(orig, repl) rs = rs.replace(orig, repl)
# if something ends with dev or pre, we add a 0 # if something ends with dev or pre, we add a 0
@ -509,7 +531,7 @@ def _suggest_normalized_version(s):
rs = re.sub(r"[.~]?([abc])\.?", r"\1", rs) rs = re.sub(r"[.~]?([abc])\.?", r"\1", rs)
# Clean: v0.3, v1.0 # Clean: v0.3, v1.0
if rs.startswith('v'): if rs.startswith("v"):
rs = rs[1:] rs = rs[1:]
# Clean leading '0's on numbers. # Clean leading '0's on numbers.
@ -568,20 +590,21 @@ def _suggest_normalized_version(s):
rs = None rs = None
return rs return rs
# #
# Legacy version processing (distribute-compatible) # Legacy version processing (distribute-compatible)
# #
_VERSION_PART = re.compile(r'([a-z]+|\d+|[\.-])', re.I) _VERSION_PART = re.compile(r"([a-z]+|\d+|[\.-])", re.I)
_VERSION_REPLACE = { _VERSION_REPLACE = {
'pre': 'c', "pre": "c",
'preview': 'c', "preview": "c",
'-': 'final-', "-": "final-",
'rc': 'c', "rc": "c",
'dev': '@', "dev": "@",
'': None, "": None,
'.': None, ".": None,
} }
@ -591,21 +614,21 @@ def _legacy_key(s):
for p in _VERSION_PART.split(s.lower()): for p in _VERSION_PART.split(s.lower()):
p = _VERSION_REPLACE.get(p, p) p = _VERSION_REPLACE.get(p, p)
if p: if p:
if '0' <= p[:1] <= '9': if "0" <= p[:1] <= "9":
p = p.zfill(8) p = p.zfill(8)
else: else:
p = '*' + p p = "*" + p
result.append(p) result.append(p)
result.append('*final') result.append("*final")
return result return result
result = [] result = []
for p in get_parts(s): for p in get_parts(s):
if p.startswith('*'): if p.startswith("*"):
if p < '*final': if p < "*final":
while result and result[-1] == '*final-': while result and result[-1] == "*final-":
result.pop() result.pop()
while result and result[-1] == '00000000': while result and result[-1] == "00000000":
result.pop() result.pop()
result.append(p) result.append(p)
return tuple(result) return tuple(result)
@ -619,7 +642,7 @@ class LegacyVersion(Version):
def is_prerelease(self): def is_prerelease(self):
result = False result = False
for x in self._parts: for x in self._parts:
if (isinstance(x, string_types) and x.startswith('*') and x < '*final'): if isinstance(x, string_types) and x.startswith("*") and x < "*final":
result = True result = True
break break
return result return result
@ -629,31 +652,38 @@ class LegacyMatcher(Matcher):
version_class = LegacyVersion version_class = LegacyVersion
_operators = dict(Matcher._operators) _operators = dict(Matcher._operators)
_operators['~='] = '_match_compatible' _operators["~="] = "_match_compatible"
numeric_re = re.compile(r'^(\d+(\.\d+)*)') numeric_re = re.compile(r"^(\d+(\.\d+)*)")
def _match_compatible(self, version, constraint, prefix): def _match_compatible(self, version, constraint, prefix):
if version < constraint: if version < constraint:
return False return False
m = self.numeric_re.match(str(constraint)) m = self.numeric_re.match(str(constraint))
if not m: if not m:
logger.warning('Cannot compute compatible match for version %s ' logger.warning(
' and constraint %s', version, constraint) "Cannot compute compatible match for version %s " " and constraint %s",
version,
constraint,
)
return True return True
s = m.groups()[0] s = m.groups()[0]
if '.' in s: if "." in s:
s = s.rsplit('.', 1)[0] s = s.rsplit(".", 1)[0]
return _match_prefix(version, s) return _match_prefix(version, s)
# #
# Semantic versioning # Semantic versioning
# #
_SEMVER_RE = re.compile(r'^(\d+)\.(\d+)\.(\d+)' _SEMVER_RE = re.compile(
r'(-[a-z0-9]+(\.[a-z0-9-]+)*)?' r"^(\d+)\.(\d+)\.(\d+)"
r'(\+[a-z0-9]+(\.[a-z0-9-]+)*)?$', re.I) r"(-[a-z0-9]+(\.[a-z0-9-]+)*)?"
r"(\+[a-z0-9]+(\.[a-z0-9-]+)*)?$",
re.I,
)
def is_semver(s): def is_semver(s):
@ -665,7 +695,7 @@ def _semantic_key(s):
if s is None: if s is None:
result = (absent,) result = (absent,)
else: else:
parts = s[1:].split('.') parts = s[1:].split(".")
# We can't compare ints and strings on Python 3, so fudge it # We can't compare ints and strings on Python 3, so fudge it
# by zero-filling numeric values so simulate a numeric comparison # by zero-filling numeric values so simulate a numeric comparison
result = tuple([p.zfill(8) if p.isdigit() else p for p in parts]) result = tuple([p.zfill(8) if p.isdigit() else p for p in parts])
@ -677,7 +707,7 @@ def _semantic_key(s):
groups = m.groups() groups = m.groups()
major, minor, patch = [int(i) for i in groups[:3]] major, minor, patch = [int(i) for i in groups[:3]]
# choose the '|' and '*' so that versions sort correctly # choose the '|' and '*' so that versions sort correctly
pre, build = make_tuple(groups[3], '|'), make_tuple(groups[5], '*') pre, build = make_tuple(groups[3], "|"), make_tuple(groups[5], "*")
return (major, minor, patch), pre, build return (major, minor, patch), pre, build
@ -687,7 +717,7 @@ class SemanticVersion(Version):
@property @property
def is_prerelease(self): def is_prerelease(self):
return self._parts[1][0] != '|' return self._parts[1][0] != "|"
class SemanticMatcher(Matcher): class SemanticMatcher(Matcher):
@ -721,9 +751,9 @@ class VersionScheme(object):
Used for processing some metadata fields Used for processing some metadata fields
""" """
# See issue #140. Be tolerant of a single trailing comma. # See issue #140. Be tolerant of a single trailing comma.
if s.endswith(','): if s.endswith(","):
s = s[:-1] s = s[:-1]
return self.is_valid_matcher('dummy_name (%s)' % s) return self.is_valid_matcher("dummy_name (%s)" % s)
def suggest(self, s): def suggest(self, s):
if self.suggester is None: if self.suggester is None:
@ -734,17 +764,19 @@ class VersionScheme(object):
_SCHEMES = { _SCHEMES = {
'normalized': VersionScheme(_normalized_key, NormalizedMatcher, "normalized": VersionScheme(
_suggest_normalized_version), _normalized_key, NormalizedMatcher, _suggest_normalized_version
'legacy': VersionScheme(_legacy_key, LegacyMatcher, lambda self, s: s), ),
'semantic': VersionScheme(_semantic_key, SemanticMatcher, "legacy": VersionScheme(_legacy_key, LegacyMatcher, lambda self, s: s),
_suggest_semantic_version), "semantic": VersionScheme(
_semantic_key, SemanticMatcher, _suggest_semantic_version
),
} }
_SCHEMES['default'] = _SCHEMES['normalized'] _SCHEMES["default"] = _SCHEMES["normalized"]
def get_scheme(name): def get_scheme(name):
if name not in _SCHEMES: if name not in _SCHEMES:
raise ValueError('unknown scheme name: %r' % name) raise ValueError("unknown scheme name: %r" % name)
return _SCHEMES[name] return _SCHEMES[name]

File diff suppressed because it is too large Load Diff

View File

@ -17,7 +17,9 @@ from ._error import Timeout
try: try:
from ._read_write import ReadWriteLock from ._read_write import ReadWriteLock
except ImportError: # sqlite3 may be unavailable if Python was built without it or the C library is missing except (
ImportError
): # sqlite3 may be unavailable if Python was built without it or the C library is missing
ReadWriteLock = None # type: ignore[assignment, misc] ReadWriteLock = None # type: ignore[assignment, misc]
from ._soft import SoftFileLock from ._soft import SoftFileLock

View File

@ -99,7 +99,9 @@ class FileLockContext:
lock_file_fd: int | None = None lock_file_fd: int | None = None
#: The lock counter is used for implementing the nested locking mechanism. #: The lock counter is used for implementing the nested locking mechanism.
lock_counter: int = 0 # When the lock is acquired is increased and the lock is only released, when this value is 0 lock_counter: int = (
0 # When the lock is acquired is increased and the lock is only released, when this value is 0
)
class ThreadLocalFileContext(FileLockContext, local): class ThreadLocalFileContext(FileLockContext, local):
@ -145,7 +147,10 @@ class FileLockMeta(ABCMeta):
# parameters do not match; raise error # parameters do not match; raise error
msg = "Singleton lock instances cannot be initialized with differing arguments" msg = "Singleton lock instances cannot be initialized with differing arguments"
msg += "\nNon-matching arguments: " msg += "\nNon-matching arguments: "
for param_name, (passed_param, set_param) in non_matching_params.items(): for param_name, (
passed_param,
set_param,
) in non_matching_params.items():
msg += f"\n\t{param_name} (existing lock has {set_param} but {passed_param} was passed)" msg += f"\n\t{param_name} (existing lock has {set_param} but {passed_param} was passed)"
raise ValueError(msg) raise ValueError(msg)
@ -165,7 +170,9 @@ class FileLockMeta(ABCMeta):
} }
present_params = inspect.signature(cls.__init__).parameters present_params = inspect.signature(cls.__init__).parameters
init_params = {key: value for key, value in all_params.items() if key in present_params} init_params = {
key: value for key, value in all_params.items() if key in present_params
}
instance = super().__call__(lock_file, **init_params) instance = super().__call__(lock_file, **init_params)
@ -239,7 +246,9 @@ class BaseFileLock(contextlib.ContextDecorator, metaclass=FileLockMeta):
"poll_interval": poll_interval, "poll_interval": poll_interval,
"lifetime": lifetime, "lifetime": lifetime,
} }
self._context: FileLockContext = (ThreadLocalFileContext if thread_local else FileLockContext)(**kwargs) self._context: FileLockContext = (
ThreadLocalFileContext if thread_local else FileLockContext
)(**kwargs)
def is_thread_local(self) -> bool: def is_thread_local(self) -> bool:
""":returns: a flag indicating if this lock is thread local or not""" """:returns: a flag indicating if this lock is thread local or not"""
@ -403,10 +412,14 @@ class BaseFileLock(contextlib.ContextDecorator, metaclass=FileLockMeta):
start_time: float, start_time: float,
) -> bool: ) -> bool:
if blocking is False: if blocking is False:
_LOGGER.debug("Failed to immediately acquire lock %s on %s", lock_id, lock_filename) _LOGGER.debug(
"Failed to immediately acquire lock %s on %s", lock_id, lock_filename
)
return True return True
if cancel_check is not None and cancel_check(): if cancel_check is not None and cancel_check():
_LOGGER.debug("Cancellation requested for lock %s on %s", lock_id, lock_filename) _LOGGER.debug(
"Cancellation requested for lock %s on %s", lock_id, lock_filename
)
return True return True
if 0 <= timeout < time.perf_counter() - start_time: if 0 <= timeout < time.perf_counter() - start_time:
_LOGGER.debug("Timeout on acquiring lock %s on %s", lock_id, lock_filename) _LOGGER.debug("Timeout on acquiring lock %s on %s", lock_id, lock_filename)
@ -470,7 +483,9 @@ class BaseFileLock(contextlib.ContextDecorator, metaclass=FileLockMeta):
warnings.warn(msg, DeprecationWarning, stacklevel=2) warnings.warn(msg, DeprecationWarning, stacklevel=2)
poll_interval = poll_intervall poll_interval = poll_intervall
poll_interval = poll_interval if poll_interval is not None else self._context.poll_interval poll_interval = (
poll_interval if poll_interval is not None else self._context.poll_interval
)
# Increment the number right at the beginning. We can still undo it, if something fails. # Increment the number right at the beginning. We can still undo it, if something fails.
self._context.lock_counter += 1 self._context.lock_counter += 1
@ -479,8 +494,17 @@ class BaseFileLock(contextlib.ContextDecorator, metaclass=FileLockMeta):
lock_filename = self.lock_file lock_filename = self.lock_file
canonical = _canonical(lock_filename) canonical = _canonical(lock_filename)
would_block = self._context.lock_counter == 1 and not self.is_locked and timeout < 0 and blocking would_block = (
if would_block and (existing := _registry.held.get(canonical)) is not None and existing != lock_id: self._context.lock_counter == 1
and not self.is_locked
and timeout < 0
and blocking
)
if (
would_block
and (existing := _registry.held.get(canonical)) is not None
and existing != lock_id
):
self._context.lock_counter -= 1 self._context.lock_counter -= 1
msg = ( msg = (
f"Deadlock: lock '{lock_filename}' is already held by a different " f"Deadlock: lock '{lock_filename}' is already held by a different "
@ -494,7 +518,9 @@ class BaseFileLock(contextlib.ContextDecorator, metaclass=FileLockMeta):
while True: while True:
if not self.is_locked: if not self.is_locked:
self._try_break_expired_lock() self._try_break_expired_lock()
_LOGGER.debug("Attempting to acquire lock %s on %s", lock_id, lock_filename) _LOGGER.debug(
"Attempting to acquire lock %s on %s", lock_id, lock_filename
)
self._acquire() self._acquire()
if self.is_locked: if self.is_locked:
_LOGGER.debug("Lock %s acquired on %s", lock_id, lock_filename) _LOGGER.debug("Lock %s acquired on %s", lock_id, lock_filename)
@ -534,7 +560,9 @@ class BaseFileLock(contextlib.ContextDecorator, metaclass=FileLockMeta):
if self._context.lock_counter == 0 or force: if self._context.lock_counter == 0 or force:
lock_id, lock_filename = id(self), self.lock_file lock_id, lock_filename = id(self), self.lock_file
_LOGGER.debug("Attempting to release lock %s on %s", lock_id, lock_filename) _LOGGER.debug(
"Attempting to release lock %s on %s", lock_id, lock_filename
)
self._release() self._release()
self._context.lock_counter = 0 self._context.lock_counter = 0
_registry.held.pop(_canonical(lock_filename), None) _registry.held.pop(_canonical(lock_filename), None)

View File

@ -51,7 +51,11 @@ def timeout_for_sqlite(timeout: float, *, blocking: bool, already_waited: float)
remaining = max(timeout - already_waited, 0) if timeout > 0 else timeout remaining = max(timeout - already_waited, 0) if timeout > 0 else timeout
timeout_ms = int(remaining * 1000) timeout_ms = int(remaining * 1000)
if timeout_ms > _MAX_SQLITE_TIMEOUT_MS or timeout_ms < 0: if timeout_ms > _MAX_SQLITE_TIMEOUT_MS or timeout_ms < 0:
_LOGGER.warning("timeout %s is too large for SQLite, using %s ms instead", timeout, _MAX_SQLITE_TIMEOUT_MS) _LOGGER.warning(
"timeout %s is too large for SQLite, using %s ms instead",
timeout,
_MAX_SQLITE_TIMEOUT_MS,
)
return _MAX_SQLITE_TIMEOUT_MS return _MAX_SQLITE_TIMEOUT_MS
return timeout_ms return timeout_ms
@ -77,12 +81,16 @@ class _ReadWriteLockMeta(type):
is_singleton: bool = True, is_singleton: bool = True,
) -> ReadWriteLock: ) -> ReadWriteLock:
if not is_singleton: if not is_singleton:
return super().__call__(lock_file, timeout, blocking=blocking, is_singleton=is_singleton) return super().__call__(
lock_file, timeout, blocking=blocking, is_singleton=is_singleton
)
normalized = pathlib.Path(lock_file).resolve() normalized = pathlib.Path(lock_file).resolve()
with cls._instances_lock: with cls._instances_lock:
if normalized not in cls._instances: if normalized not in cls._instances:
instance = super().__call__(lock_file, timeout, blocking=blocking, is_singleton=is_singleton) instance = super().__call__(
lock_file, timeout, blocking=blocking, is_singleton=is_singleton
)
cls._instances[normalized] = instance cls._instances[normalized] = instance
else: else:
instance = cls._instances[normalized] instance = cls._instances[normalized]
@ -122,7 +130,11 @@ class ReadWriteLock(metaclass=_ReadWriteLockMeta):
@classmethod @classmethod
def get_lock( def get_lock(
cls, lock_file: str | os.PathLike[str], timeout: float = -1, *, blocking: bool = True cls,
lock_file: str | os.PathLike[str],
timeout: float = -1,
*,
blocking: bool = True,
) -> ReadWriteLock: ) -> ReadWriteLock:
""" """
Return the singleton :class:`ReadWriteLock` for *lock_file*. Return the singleton :class:`ReadWriteLock` for *lock_file*.
@ -149,8 +161,12 @@ class ReadWriteLock(metaclass=_ReadWriteLockMeta):
self.lock_file = os.fspath(lock_file) self.lock_file = os.fspath(lock_file)
self.timeout = timeout self.timeout = timeout
self.blocking = blocking self.blocking = blocking
self._transaction_lock = threading.Lock() # serializes the (possibly blocking) SQLite transaction work self._transaction_lock = (
self._internal_lock = threading.Lock() # protects _lock_level / _current_mode updates and rollback threading.Lock()
) # serializes the (possibly blocking) SQLite transaction work
self._internal_lock = (
threading.Lock()
) # protects _lock_level / _current_mode updates and rollback
self._lock_level = 0 self._lock_level = 0
self._current_mode: Literal["read", "write"] | None = None self._current_mode: Literal["read", "write"] | None = None
self._write_thread_id: int | None = None self._write_thread_id: int | None = None
@ -167,7 +183,9 @@ class ReadWriteLock(metaclass=_ReadWriteLockMeta):
if not acquired: if not acquired:
raise Timeout(self.lock_file) from None raise Timeout(self.lock_file) from None
def _validate_reentrant(self, mode: Literal["read", "write"], opposite: str, direction: str) -> AcquireReturnProxy: def _validate_reentrant(
self, mode: Literal["read", "write"], opposite: str, direction: str
) -> AcquireReturnProxy:
if self._current_mode != mode: if self._current_mode != mode:
msg = ( msg = (
f"Cannot acquire {mode} lock on {self.lock_file} (lock id: {id(self)}): " f"Cannot acquire {mode} lock on {self.lock_file} (lock id: {id(self)}): "
@ -184,10 +202,17 @@ class ReadWriteLock(metaclass=_ReadWriteLockMeta):
return AcquireReturnProxy(lock=self) return AcquireReturnProxy(lock=self)
def _configure_and_begin( def _configure_and_begin(
self, mode: Literal["read", "write"], timeout: float, *, blocking: bool, start_time: float self,
mode: Literal["read", "write"],
timeout: float,
*,
blocking: bool,
start_time: float,
) -> None: ) -> None:
waited = time.perf_counter() - start_time waited = time.perf_counter() - start_time
timeout_ms = timeout_for_sqlite(timeout, blocking=blocking, already_waited=waited) timeout_ms = timeout_for_sqlite(
timeout, blocking=blocking, already_waited=waited
)
self._con.execute(f"PRAGMA busy_timeout={timeout_ms};").close() self._con.execute(f"PRAGMA busy_timeout={timeout_ms};").close()
# Use legacy journal mode (not WAL) because WAL does not block readers when a concurrent EXCLUSIVE # Use legacy journal mode (not WAL) because WAL does not block readers when a concurrent EXCLUSIVE
# write transaction is active, making read-write locking impossible without modifying table data. # write transaction is active, making read-write locking impossible without modifying table data.
@ -199,16 +224,24 @@ class ReadWriteLock(metaclass=_ReadWriteLockMeta):
self._con.execute("PRAGMA journal_mode=MEMORY;").close() self._con.execute("PRAGMA journal_mode=MEMORY;").close()
# Recompute remaining timeout after the potentially blocking journal_mode pragma. # Recompute remaining timeout after the potentially blocking journal_mode pragma.
waited = time.perf_counter() - start_time waited = time.perf_counter() - start_time
if (recomputed := timeout_for_sqlite(timeout, blocking=blocking, already_waited=waited)) != timeout_ms: if (
recomputed := timeout_for_sqlite(
timeout, blocking=blocking, already_waited=waited
)
) != timeout_ms:
self._con.execute(f"PRAGMA busy_timeout={recomputed};").close() self._con.execute(f"PRAGMA busy_timeout={recomputed};").close()
stmt = "BEGIN EXCLUSIVE TRANSACTION;" if mode == "write" else "BEGIN TRANSACTION;" stmt = (
"BEGIN EXCLUSIVE TRANSACTION;" if mode == "write" else "BEGIN TRANSACTION;"
)
self._con.execute(stmt).close() self._con.execute(stmt).close()
if mode == "read": if mode == "read":
# A SELECT is needed to force SQLite to actually acquire the SHARED lock on the database. # A SELECT is needed to force SQLite to actually acquire the SHARED lock on the database.
# https://www.sqlite.org/lockingv3.html#transaction_control # https://www.sqlite.org/lockingv3.html#transaction_control
self._con.execute("SELECT name FROM sqlite_schema LIMIT 1;").close() self._con.execute("SELECT name FROM sqlite_schema LIMIT 1;").close()
def _acquire(self, mode: Literal["read", "write"], timeout: float, *, blocking: bool) -> AcquireReturnProxy: def _acquire(
self, mode: Literal["read", "write"], timeout: float, *, blocking: bool
) -> AcquireReturnProxy:
opposite = "write" if mode == "read" else "read" opposite = "write" if mode == "read" else "read"
direction = "downgrade" if mode == "read" else "upgrade" direction = "downgrade" if mode == "read" else "upgrade"
@ -224,7 +257,9 @@ class ReadWriteLock(metaclass=_ReadWriteLockMeta):
if self._lock_level > 0: if self._lock_level > 0:
return self._validate_reentrant(mode, opposite, direction) return self._validate_reentrant(mode, opposite, direction)
self._configure_and_begin(mode, timeout, blocking=blocking, start_time=start_time) self._configure_and_begin(
mode, timeout, blocking=blocking, start_time=start_time
)
with self._internal_lock: with self._internal_lock:
self._current_mode = mode self._current_mode = mode
@ -241,7 +276,9 @@ class ReadWriteLock(metaclass=_ReadWriteLockMeta):
finally: finally:
self._transaction_lock.release() self._transaction_lock.release()
def acquire_read(self, timeout: float = -1, *, blocking: bool = True) -> AcquireReturnProxy: def acquire_read(
self, timeout: float = -1, *, blocking: bool = True
) -> AcquireReturnProxy:
""" """
Acquire a shared read lock. Acquire a shared read lock.
@ -259,7 +296,9 @@ class ReadWriteLock(metaclass=_ReadWriteLockMeta):
""" """
return self._acquire("read", timeout, blocking=blocking) return self._acquire("read", timeout, blocking=blocking)
def acquire_write(self, timeout: float = -1, *, blocking: bool = True) -> AcquireReturnProxy: def acquire_write(
self, timeout: float = -1, *, blocking: bool = True
) -> AcquireReturnProxy:
""" """
Acquire an exclusive write lock. Acquire an exclusive write lock.
@ -309,7 +348,9 @@ class ReadWriteLock(metaclass=_ReadWriteLockMeta):
self._con.rollback() self._con.rollback()
@contextmanager @contextmanager
def read_lock(self, timeout: float | None = None, *, blocking: bool | None = None) -> Generator[None]: def read_lock(
self, timeout: float | None = None, *, blocking: bool | None = None
) -> Generator[None]:
""" """
Context manager that acquires and releases a shared read lock. Context manager that acquires and releases a shared read lock.
@ -330,7 +371,9 @@ class ReadWriteLock(metaclass=_ReadWriteLockMeta):
self.release() self.release()
@contextmanager @contextmanager
def write_lock(self, timeout: float | None = None, *, blocking: bool | None = None) -> Generator[None]: def write_lock(
self, timeout: float | None = None, *, blocking: bool | None = None
) -> Generator[None]:
""" """
Context manager that acquires and releases an exclusive write lock. Context manager that acquires and releases an exclusive write lock.

View File

@ -44,10 +44,13 @@ class SoftFileLock(BaseFileLock):
file_handler = os.open(self.lock_file, flags, self._open_mode()) file_handler = os.open(self.lock_file, flags, self._open_mode())
except OSError as exception: except OSError as exception:
if not ( if not (
exception.errno == EEXIST or (exception.errno == EACCES and sys.platform == "win32") exception.errno == EEXIST
or (exception.errno == EACCES and sys.platform == "win32")
): # pragma: win32 no cover ): # pragma: win32 no cover
raise raise
if exception.errno == EEXIST and sys.platform != "win32": # pragma: win32 no cover if (
exception.errno == EEXIST and sys.platform != "win32"
): # pragma: win32 no cover
self._try_break_stale_lock() self._try_break_stale_lock()
else: else:
self._write_lock_info(file_handler) self._write_lock_info(file_handler)

View File

@ -89,10 +89,17 @@ else: # pragma: win32 no cover
def _fallback_to_soft_lock(self) -> None: def _fallback_to_soft_lock(self) -> None:
from ._soft import SoftFileLock # noqa: PLC0415 from ._soft import SoftFileLock # noqa: PLC0415
warnings.warn("flock not supported on this filesystem, falling back to SoftFileLock", stacklevel=2) warnings.warn(
"flock not supported on this filesystem, falling back to SoftFileLock",
stacklevel=2,
)
from .asyncio import AsyncSoftFileLock, BaseAsyncFileLock # noqa: PLC0415 from .asyncio import AsyncSoftFileLock, BaseAsyncFileLock # noqa: PLC0415
self.__class__ = AsyncSoftFileLock if isinstance(self, BaseAsyncFileLock) else SoftFileLock self.__class__ = (
AsyncSoftFileLock
if isinstance(self, BaseAsyncFileLock)
else SoftFileLock
)
def _release(self) -> None: def _release(self) -> None:
fd = cast("int", self._context.lock_file_fd) fd = cast("int", self._context.lock_file_fd)

View File

@ -24,7 +24,9 @@ def raise_on_not_writable_file(filename: str) -> None:
except OSError: except OSError:
return # swallow does not exist or other errors return # swallow does not exist or other errors
if file_stat.st_mtime != 0: # if os.stat returns but modification is zero that's an invalid os.stat - ignore it if (
file_stat.st_mtime != 0
): # if os.stat returns but modification is zero that's an invalid os.stat - ignore it
if not (file_stat.st_mode & stat.S_IWUSR): if not (file_stat.st_mode & stat.S_IWUSR):
raise PermissionError(EACCES, "Permission denied", filename) raise PermissionError(EACCES, "Permission denied", filename)

View File

@ -55,7 +55,9 @@ if sys.platform == "win32": # pragma: win32 cover
# Security check: Refuse to open reparse points (symlinks, junctions) # Security check: Refuse to open reparse points (symlinks, junctions)
# This prevents TOCTOU symlink attacks (CVE-TBD) # This prevents TOCTOU symlink attacks (CVE-TBD)
if _is_reparse_point(self.lock_file): if _is_reparse_point(self.lock_file):
msg = f"Lock file is a reparse point (symlink/junction): {self.lock_file}" msg = (
f"Lock file is a reparse point (symlink/junction): {self.lock_file}"
)
raise OSError(msg) raise OSError(msg)
flags = ( flags = (

View File

@ -168,9 +168,9 @@ class BaseAsyncFileLock(BaseFileLock, metaclass=AsyncFileLockMeta):
"run_in_executor": run_in_executor, "run_in_executor": run_in_executor,
"executor": executor, "executor": executor,
} }
self._context: AsyncFileLockContext = (AsyncThreadLocalFileContext if thread_local else AsyncFileLockContext)( self._context: AsyncFileLockContext = (
**kwargs AsyncThreadLocalFileContext if thread_local else AsyncFileLockContext
) )(**kwargs)
@property @property
def run_in_executor(self) -> bool: def run_in_executor(self) -> bool:
@ -256,7 +256,9 @@ class BaseAsyncFileLock(BaseFileLock, metaclass=AsyncFileLockMeta):
while True: while True:
if not self.is_locked: if not self.is_locked:
self._try_break_expired_lock() self._try_break_expired_lock()
_LOGGER.debug("Attempting to acquire lock %s on %s", lock_id, lock_filename) _LOGGER.debug(
"Attempting to acquire lock %s on %s", lock_id, lock_filename
)
await self._run_internal_method(self._acquire) await self._run_internal_method(self._acquire)
if self.is_locked: if self.is_locked:
_LOGGER.debug("Lock %s acquired on %s", lock_id, lock_filename) _LOGGER.debug("Lock %s acquired on %s", lock_id, lock_filename)
@ -278,7 +280,9 @@ class BaseAsyncFileLock(BaseFileLock, metaclass=AsyncFileLockMeta):
raise raise
return AsyncAcquireReturnProxy(lock=self) return AsyncAcquireReturnProxy(lock=self)
async def release(self, force: bool = False) -> None: # ty: ignore[invalid-method-override] # noqa: FBT001, FBT002 async def release(
self, force: bool = False
) -> None: # ty: ignore[invalid-method-override] # noqa: FBT001, FBT002
""" """
Release the file lock. The lock is only completely released when the lock counter reaches 0. The lock file Release the file lock. The lock is only completely released when the lock counter reaches 0. The lock file
itself is not automatically deleted. itself is not automatically deleted.
@ -292,7 +296,9 @@ class BaseAsyncFileLock(BaseFileLock, metaclass=AsyncFileLockMeta):
if self._context.lock_counter == 0 or force: if self._context.lock_counter == 0 or force:
lock_id, lock_filename = id(self), self.lock_file lock_id, lock_filename = id(self), self.lock_file
_LOGGER.debug("Attempting to release lock %s on %s", lock_id, lock_filename) _LOGGER.debug(
"Attempting to release lock %s on %s", lock_id, lock_filename
)
await self._run_internal_method(self._release) await self._run_internal_method(self._release)
self._context.lock_counter = 0 self._context.lock_counter = 0
_LOGGER.debug("Lock %s released on %s", lock_id, lock_filename) _LOGGER.debug("Lock %s released on %s", lock_id, lock_filename)

View File

@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID commit_id: COMMIT_ID
__commit_id__: COMMIT_ID __commit_id__: COMMIT_ID
__version__ = version = '3.24.3' __version__ = version = "3.24.3"
__version_tuple__ = version_tuple = (3, 24, 3) __version_tuple__ = version_tuple = (3, 24, 3)
__commit_id__ = commit_id = None __commit_id__ = commit_id = None

View File

@ -3,4 +3,3 @@ Generator: setuptools (75.5.0)
Root-Is-Purelib: true Root-Is-Purelib: true
Tag: py2-none-any Tag: py2-none-any
Tag: py3-none-any Tag: py3-none-any

Some files were not shown because too many files have changed in this diff Show More