1
- """Formats all code snippets in modules using clang-format"""
1
+ """Formats all code snippets in mdx files using clang-format or black-with-tabs """
2
2
3
3
import os
4
4
import subprocess
5
5
import sys
6
+ import tempfile
7
+ import black
8
+ from typing import List
9
+
6
10
7
11
# see https://clang.llvm.org/docs/ClangFormatStyleOptions.html for more options
8
- STYLE = """
12
+ CLANG_FORMAT_STYLE = """
9
13
TabWidth: 4
10
14
IndentWidth: 4
11
15
UseTab: ForIndentation
16
20
AllowShortLoopsOnASingleLine: true
17
21
SpacesBeforeTrailingComments: 2
18
22
"""
19
- STYLE = "{" + ", " .join (filter (lambda x : x != "" , STYLE .split ("\n " ))) + "}"
23
+ CLANG_FORMAT_STYLE = (
24
+ "{" + ", " .join (filter (lambda x : x != "" , CLANG_FORMAT_STYLE .split ("\n " ))) + "}"
25
+ )
20
26
21
27
FORMAT_CPP = True
22
28
FORMAT_PYTHON = True
23
29
FORMAT_JAVA = True
24
30
31
+
25
32
def lead_white (line ):
26
33
return len (line ) - len (line .lstrip ())
27
34
35
+
28
36
def match_indentation (line , prog_line ):
29
- return prog_line [min (lead_white (line ), lead_white (prog_line )):]
37
+ return prog_line [min (lead_white (line ), lead_white (prog_line )) :]
38
+
39
+
40
+ def contains_banned_terms (prog : List [str ]):
41
+ banned = ["CodeSnip" , "while (???)" ]
42
+ return any (ban in prog_line for ban in banned for prog_line in prog )
43
+
44
+
45
+ def format_prog_py (prog : List [str ]):
46
+ try :
47
+ prog = black .format_file_contents ("" .join (prog ), fast = True , mode = black .FileMode ()).splitlines (True )
48
+ except black .report .NothingChanged :
49
+ pass
50
+ return prog
51
+
52
+
53
+ def format_prog_clang (lang : str , prog : List [str ]):
54
+ tmp_file = tempfile .NamedTemporaryFile (suffix = f".{ lang } " )
55
+ with open (tmp_file .name , "w" ) as f :
56
+ f .write ("" .join (prog ))
57
+ subprocess .check_output (
58
+ [
59
+ f"clang-format -i -style='{ CLANG_FORMAT_STYLE } ' { f .name } "
60
+ ],
61
+ shell = True ,
62
+ )
63
+ with open (tmp_file .name , "r" ) as f :
64
+ return f .readlines ()
65
+
66
+ def format_prog (lang : str , prog : List [str ]):
67
+ if lang == "py" :
68
+ return format_prog_py (prog )
69
+ else :
70
+ assert lang in ['cpp' , 'java' ]
71
+ return format_prog_clang (lang , prog )
72
+
30
73
31
74
def format_path (path : str ):
32
75
print ("formatting" , path )
@@ -47,38 +90,15 @@ def format_path(path: str):
47
90
nlines .append (line )
48
91
elif line .strip () == "```" :
49
92
if lang is not None :
50
- banned = ["CodeSnip" , "while (???)" ]
51
- if not any (ban in prog_line for ban in banned for prog_line in prog ):
52
- TMP_FILE = f"tmp.{ lang } "
53
- with open (TMP_FILE , "w" ) as f :
54
- f .write ("" .join ([match_indentation (line , prog_line ) for prog_line in prog ]))
55
- found_error = False
56
- try :
57
- if lang == "cpp" or lang == "java" :
58
- subprocess .check_output (
59
- [f"clang-format -i -style='{ STYLE } ' { TMP_FILE } " ], shell = True
60
- )
61
- elif lang == "py" :
62
- subprocess .check_output (
63
- [f"black --fast { TMP_FILE } " ], shell = True # black with tabs
64
- )
65
- else :
66
- assert False
67
- except subprocess .CalledProcessError as e :
68
- print ("ERROR!" )
69
- print (e )
70
- print ("path =" , path )
71
- print ("" .join (prog ))
72
- print ()
73
- found_error = True
74
- nlines += prog
75
- if not found_error :
76
- with open (TMP_FILE , "r" ) as f :
77
- whitespace = line [: line .find ("```" )]
78
- prog_lines = f .readlines ()
79
- nlines += [whitespace + line for line in prog_lines ]
80
- os .remove (TMP_FILE )
81
- else : # don't format
93
+ if not contains_banned_terms (prog ):
94
+ prog = [
95
+ match_indentation (line , prog_line )
96
+ for prog_line in prog
97
+ ]
98
+ prog = format_prog (lang , prog )
99
+ whitespace = line [: line .find ("```" )]
100
+ nlines += [whitespace + line for line in prog ]
101
+ else : # don't format
82
102
nlines += prog
83
103
prog = []
84
104
lang = None
@@ -94,9 +114,19 @@ def format_path(path: str):
94
114
assert lang is None
95
115
96
116
97
- # https://stackoverflow.com/questions/2212643/python-recursive-folder-read
98
- for root , subdirs , files in os .walk ("/Users/benq/Documents/usaco-guide/content/" ):
99
- for file in files :
100
- if file .endswith (".mdx" ):
101
- path = os .path .join (root , file )
102
- format_path (path )
117
+ def format_dir (dir ):
118
+ # https://stackoverflow.com/questions/2212643/python-recursive-folder-read
119
+ for root , subdirs , files in os .walk (dir ):
120
+ for file in files :
121
+ if file .endswith (".mdx" ):
122
+ path = os .path .join (root , file )
123
+ format_path (path )
124
+
125
+
126
+ if __name__ == "__main__" :
127
+ """
128
+ Args: paths of files to format
129
+ """
130
+ # raise ValueError(sys.argv[1:])
131
+ for path in sys .argv [1 :]:
132
+ format_path (path )
0 commit comments