Skip to content

Commit 22f12fb

Browse files
committed
Prevent the default value from being used when values were already passed
1 parent 2be19f3 commit 22f12fb

File tree

3 files changed

+79
-2
lines changed

3 files changed

+79
-2
lines changed

cwltool/argparser.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,37 @@ class DirectoryAppendAction(FSAppendAction):
817817
objclass = "Directory"
818818

819819

820+
class AppendAction(argparse.Action):
821+
"""An argparse action that clears the default values if any value is provided.
822+
823+
Attributes:
824+
_called (bool): Initially set to ``False``, changed if any value is appended.
825+
"""
826+
def __init__(self,
827+
option_strings: List[str],
828+
dest: str,
829+
nargs: Any = None,
830+
**kwargs: Any,) -> None:
831+
"""Intialize."""
832+
super().__init__(option_strings, dest, **kwargs)
833+
self._called = False
834+
835+
def __call__(self,
836+
parser: argparse.ArgumentParser,
837+
namespace: argparse.Namespace,
838+
values: Union[str, Sequence[Any], None],
839+
option_string: Optional[str] = None,) -> None:
840+
g = getattr(namespace, self.dest, None)
841+
if g is None:
842+
g = []
843+
if values is not None and not self._called:
844+
# If any value was specified, we then clear the list of options before appending.
845+
# We cannot always clear the ``default`` attribute since it collects the ``values`` appended.
846+
self.default.clear()
847+
self._called = True
848+
g.append(values)
849+
850+
820851
def add_argument(
821852
toolparser: argparse.ArgumentParser,
822853
name: str,
@@ -864,7 +895,7 @@ def add_argument(
864895
elif inptype["items"] == "Directory":
865896
action = DirectoryAppendAction
866897
else:
867-
action = "append"
898+
action = AppendAction
868899
elif isinstance(inptype, MutableMapping) and inptype["type"] == "enum":
869900
atype = str
870901
elif isinstance(inptype, MutableMapping) and inptype["type"] == "record":

tests/default_values_list.cwl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#!/usr/bin/env cwl-runner
2+
# From https://github.com/common-workflow-language/cwltool/issues/1632
3+
4+
cwlVersion: v1.2
5+
class: CommandLineTool
6+
7+
baseCommand: [cat]
8+
9+
stdout: "cat_file"
10+
11+
inputs:
12+
file_paths:
13+
type: string[]?
14+
inputBinding:
15+
position: 1
16+
default: ["/home/bart/cwl_test/test1"]
17+
18+
outputs:
19+
output:
20+
type: stdout

tests/test_toolargparse.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
22
from io import StringIO
33
from pathlib import Path
4-
from typing import Callable
4+
from typing import Any, Callable, Dict, List
55

66
import pytest
77

@@ -195,3 +195,29 @@ def test_argparser_without_doc() -> None:
195195
p = argparse.ArgumentParser()
196196
parser = generate_parser(p, tool, {}, [], False)
197197
assert parser.description is None
198+
199+
200+
@pytest.mark.parametrize(
201+
"job_order,expected_values", [
202+
# no arguments, so we expect the default value
203+
([], ['/home/bart/cwl_test/test1']),
204+
# arguments, provided, one or many, meaning that the default value is not expected
205+
([
206+
'--file_paths',
207+
'/home/bart/cwl_test/test2'
208+
], ['/home/bart/cwl_test/test2']),
209+
([
210+
'--file_paths',
211+
'/home/bart/cwl_test/test2',
212+
'--file_paths',
213+
'/home/bart/cwl_test/test3'
214+
], ['/home/bart/cwl_test/test2', '/home/bart/cwl_test/test3'])
215+
])
216+
def test_argparse_append_with_default(job_order: List[str], expected_values: List[str]) -> None:
217+
"""The appended arguments must not include the default. But if no appended argument, then the default is used."""
218+
loadingContext = LoadingContext()
219+
tool = load_tool(get_data("tests/default_values_list.cwl"), loadingContext)
220+
toolparser = generate_parser(argparse.ArgumentParser(prog='test'), tool, {}, [], False)
221+
cmd_line = vars(toolparser.parse_args(job_order)) # type: ignore[call-overload]
222+
file_paths = list(cmd_line['file_paths'])
223+
assert expected_values == file_paths

0 commit comments

Comments
 (0)