-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathsmooth_compile
More file actions
executable file
·128 lines (107 loc) · 5.61 KB
/
smooth_compile
File metadata and controls
executable file
·128 lines (107 loc) · 5.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/bash
# Copyright 2023, 2024 Philipp Andelfinger, Justin Kreikemeyer
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
# and associated documentation files (the “Software”), to deal in the Software without
# restriction, including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR
# ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Abort on the first failing command, and treat a pipeline as failed when any
# of its stages fails (not only the last one).
set -e -o pipefail

# To compile with torch support, point libtorch_root at a libtorch checkout and
# uncomment the "libtorch_flags" line below.
#libtorch_root=$HOME/libtorch
#libtorch_flags="-I$libtorch_root/include -I$libtorch_root/include/torch/csrc/api/include -Wl,-rpath,$libtorch_root/lib $libtorch_root/lib/libtorch.so $libtorch_root/lib/libkineto.a $libtorch_root/lib/libtorch_cpu.so $libtorch_root/lib/libc10.so"

# Optimization/debug options handed to every clang++ invocation in this script;
# the trailing -Wno-* options silence warnings.
opt_flags="-gdwarf-4 -O3 -g -march=native -flto -mf16c -ffast-math"
opt_flags+=" -Wno-unknown-warning-option -Wno-nan-infinity-disabled"
#######################################
# Attach a common prefix to every item and join the results with single spaces.
# Arguments:
#   $1   - prefix to put in front of each item (e.g. "--extra-arg=")
#   $2.. - items to prefix
# Outputs:
#   "<prefix><item1> <prefix><item2> ..." on stdout, without a trailing
#   newline. Prints nothing when no items are given, so callers never receive
#   a dangling prefix with no value.
#######################################
preface() {
  local prefix="${1-}"
  local sep="" item
  # printf instead of echo: a prefix such as "-n" or "-e" would otherwise be
  # swallowed as an echo option instead of being printed.
  for item in "${@:2}"; do
    printf '%s%s%s' "$sep" "$prefix" "$item"
    sep=" "
  done
}
# ---- Command line -----------------------------------------------------------
# $1      source file to compile
# -DDGO*  preprocessor definitions used only for the DGO build
# -D*     preprocessor definitions used for every build
# -C<v1,v2,...>  comma-separated list of versions to build (default: all)
# Any other argument is ignored.
if (( $# < 1 )); then
  # Without this check the bare "shift" below fails and "set -e" ends the
  # script without any message.
  echo "Usage: ${0##*/} <source file> [-D<macro>[=<value>] ...] [-DDGO<macro> ...] [-C<version>[,<version> ...]]" >&2
  exit 2
fi
src_fname=$1
# Output names are derived from the source path without its extension. Only
# strip when the file name itself contains a dot, so that a dot in a directory
# name (e.g. "./prog") is not mistaken for an extension.
case "${src_fname##*/}" in
  *.*) prefix=${src_fname%.*} ;;
  *) prefix=$src_fname ;;
esac
# Macro passed as -D$ad_flag when building the DGO version.
ad_flag=FW_AD
args_fname=backend/args.cpp
shift

# Both flag lists are space-separated strings (each entry followed by one
# space); the rest of the script expands them unquoted to split them again.
program_user_flags=""
dgo_user_flags=""
compile_versions=("crisp" "crisp_ad" "pgo" "reinforce" "rloo" "dgo")
for elem in "$@"; do
  case "$elem" in
    # Must be tested before the generic -D* pattern, which would also match.
    -DDGO*) dgo_user_flags+="$elem " ;;
    -D*) program_user_flags+="$elem " ;;
    # Replaces the default list; if given more than once, the last one wins.
    -C*) IFS=',' read -r -a compile_versions <<< "${elem:2}" ;;
  esac
done
# Compiler options shared by all builds, followed by the user's -D definitions.
cpp_flags="-fdiagnostics-color=always -Wall -std=c++20 -Ibackend $program_user_flags"
dgo_cpp_flags="$cpp_flags $dgo_user_flags"

# File-name suffixes that encode the user's definitions: "_" followed by the
# flags with blanks turned into underscores, or empty when no flags were given.
# The unquoted echo is intentional: it squeezes the separating blanks and drops
# the trailing one before tr sees the text.
program_user_flags_suffix=${program_user_flags:+_$(echo $program_user_flags | tr " " "_")}
dgo_user_flags_suffix=${dgo_user_flags:+_$(echo $dgo_user_flags | tr " " "_")}
# Flags for the source-to-source tools in ./transformation: every compiler
# option is wrapped as --extra-arg=<option> (presumably so the tools forward it
# to their embedded compiler front end - confirm against the tools' CLI).
# $cpp_flags and $libtorch_flags are expanded unquoted on purpose so that each
# option becomes its own item.
smooth_flags=$(preface --extra-arg= $cpp_flags -DNO_AD -DCRISP $libtorch_flags -Wno-unused-command-line-argument)
insert_func_incr_flags=$(preface --extra-arg= $cpp_flags -DAD -DDGO $libtorch_flags -Wno-unused-command-line-argument)

# Append the include directory of the GCC installation that clang++ selected to
# CPATH. The lookup must not be fatal: with errexit and pipefail active, a
# missing "Selected GCC installation" line (or a missing clang++) would
# otherwise terminate the script here without any message.
gcc_install_dir=$(clang++ -v 2>&1 | awk '/Selected GCC installation/ {print $NF}') || true
if [[ -n "$gcc_install_dir" ]]; then
  # The "$CPATH:" prefix is kept exactly as before, including the empty leading
  # element that results when CPATH was empty.
  CPATH=$CPATH:$gcc_install_dir/include
else
  echo "Warning: could not determine the GCC installation used by clang++; leaving CPATH unchanged" >&2
fi
# crisp with optional sampling, no automatic differentiation
# Globals (read): prefix, program_user_flags_suffix, cpp_flags, opt_flags,
#                 src_fname, args_fname, libtorch_flags
# Outputs: progress messages on stdout; the binary
#          ${prefix}${program_user_flags_suffix}_crisp
# Exits with clang++'s status if the compilation fails.
crisp() {
  local out="${prefix}${program_user_flags_suffix}_crisp"
  echo "Compiling crisp version as ${out}..."
  # The *_flags variables are space-separated option lists and are expanded
  # unquoted on purpose; file names are quoted so paths with blanks survive.
  # shellcheck disable=SC2086
  clang++ $cpp_flags $opt_flags -I. "$src_fname" "$args_fname" -DCRISP -DNO_AD -o "$out" $libtorch_flags || exit
  # Report the binary that was actually written, including the flag suffix.
  echo "Finished compiling ${out}"
}
# crisp with automatic differentiation
# Globals (read): prefix, program_user_flags_suffix, cpp_flags, opt_flags,
#                 src_fname, args_fname, libtorch_flags
# Outputs: progress messages on stdout; the binary
#          ${prefix}${program_user_flags_suffix}_crisp_ad
# Exits with clang++'s status if the compilation fails.
crisp_ad() {
  local out="${prefix}${program_user_flags_suffix}_crisp_ad"
  echo "Compiling crisp version with AD as ${out}..."
  # The *_flags variables are space-separated option lists and are expanded
  # unquoted on purpose; file names are quoted so paths with blanks survive.
  # shellcheck disable=SC2086
  clang++ $cpp_flags $opt_flags -I. "$src_fname" "$args_fname" -DCRISP -DFW_AD -o "$out" $libtorch_flags || exit
  # Report the binary that was actually written, including the flag suffix.
  echo "Finished compiling ${out}"
}
# Polyak Gradient Oracle (PGO)
# Globals (read): prefix, program_user_flags_suffix, cpp_flags, opt_flags,
#                 src_fname, args_fname, libtorch_flags
# Outputs: progress messages on stdout; the binary
#          ${prefix}${program_user_flags_suffix}_pgo
# Exits with clang++'s status if the compilation fails.
pgo() {
  local out="${prefix}${program_user_flags_suffix}_pgo"
  echo "Compiling Polyak Gradient Oracle version as ${out}..."
  # The *_flags variables are space-separated option lists and are expanded
  # unquoted on purpose; file names are quoted so paths with blanks survive.
  # shellcheck disable=SC2086
  clang++ $cpp_flags $opt_flags -I. "$src_fname" "$args_fname" -DPGO -DNO_AD -o "$out" $libtorch_flags || exit
  # Report the binary that was actually written, including the flag suffix.
  echo "Finished compiling ${out}"
}
# REINFORCE
# Globals (read): prefix, program_user_flags_suffix, cpp_flags, opt_flags,
#                 src_fname, args_fname, libtorch_flags
# Outputs: progress messages on stdout; the binary
#          ${prefix}${program_user_flags_suffix}_reinforce
# Exits with clang++'s status if the compilation fails.
reinforce() {
  local out="${prefix}${program_user_flags_suffix}_reinforce"
  echo "Compiling REINFORCE version as ${out}..."
  # The *_flags variables are space-separated option lists and are expanded
  # unquoted on purpose; file names are quoted so paths with blanks survive.
  # shellcheck disable=SC2086
  clang++ $cpp_flags $opt_flags -I. "$src_fname" "$args_fname" -DREINFORCE -DNO_AD -o "$out" $libtorch_flags || exit
  # Report the binary that was actually written, including the flag suffix.
  echo "Finished compiling ${out}"
}
# RLOO
# Globals (read): prefix, program_user_flags_suffix, cpp_flags, opt_flags,
#                 src_fname, args_fname, libtorch_flags
# Outputs: progress messages on stdout; the binary
#          ${prefix}${program_user_flags_suffix}_rloo
# Exits with clang++'s status if the compilation fails.
rloo() {
  local out="${prefix}${program_user_flags_suffix}_rloo"
  echo "Compiling RLOO version as ${out}..."
  # The *_flags variables are space-separated option lists and are expanded
  # unquoted on purpose; file names are quoted so paths with blanks survive.
  # shellcheck disable=SC2086
  clang++ $cpp_flags $opt_flags -I. "$src_fname" "$args_fname" -DRLOO -DNO_AD -o "$out" $libtorch_flags || exit
  # Report the binary that was actually written, including the flag suffix.
  echo "Finished compiling ${out}"
}
# DiscoGrad Gradient Oracle (DGO)
# Runs the source through three tools from ./transformation (normalize,
# smooth_dgo, insert_func_incr), formats each result with clang-format into an
# intermediate .cpp file next to the source, and compiles the last one.
# Globals (read): prefix, program_user_flags_suffix, dgo_user_flags_suffix,
#                 smooth_flags, insert_func_incr_flags, cpp_flags,
#                 dgo_user_flags, opt_flags, src_fname, args_fname, ad_flag,
#                 libtorch_flags, CPATH
# Outputs: progress messages on stdout; the three intermediate files and the
#          binary ${prefix}${program_user_flags_suffix}_dgo${dgo_user_flags_suffix}
# Exits with the failing command's status if any step fails.
dgo() {
  local stem="${prefix}${program_user_flags_suffix}${dgo_user_flags_suffix}"
  local normalized_fname="${stem}_normalized.cpp"
  local smoothed_fname="${stem}_smoothed.cpp"
  local dgo_fname="${stem}_dgo.cpp"
  local out="${prefix}${program_user_flags_suffix}_dgo${dgo_user_flags_suffix}"
  echo "Compiling DiscoGrad Gradient Oracle versions as ${out}..."
  # The *_flags variables are space-separated option lists and are expanded
  # unquoted on purpose; file names are quoted so paths with blanks survive.
  # CPATH is passed explicitly because it may only be a shell variable here,
  # not an exported one. A failing transformation tool is only noticed by
  # "|| exit" because the script enables pipefail.
  # shellcheck disable=SC2086
  CPATH="${CPATH}" ./transformation/normalize $smooth_flags "$src_fname" | clang-format > "$normalized_fname" || exit
  # shellcheck disable=SC2086
  CPATH="${CPATH}" ./transformation/smooth_dgo $smooth_flags "$normalized_fname" | clang-format > "$smoothed_fname" || exit
  # shellcheck disable=SC2086
  CPATH="${CPATH}" ./transformation/insert_func_incr $insert_func_incr_flags "$smoothed_fname" | clang-format > "$dgo_fname" || exit
  # shellcheck disable=SC2086
  clang++ $cpp_flags $dgo_user_flags $opt_flags -I. "$dgo_fname" "$args_fname" backend/discograd_gradient_oracle/globals.cpp -DDGO "-D${ad_flag}" -o "$out" $libtorch_flags || exit
  # Report the binary that was actually written, including both flag suffixes.
  echo "Finished compiling ${out}"
}
#######################################
# Build the requested versions in parallel and report whether all succeeded.
# Arguments:
#   $@ - names of the shell functions to run, one background job each
# Returns:
#   0 if every job succeeded (or none was requested); 2 if a name is not a
#   defined function (nothing is started in that case); otherwise the exit
#   status of the last job that failed.
#######################################
run_compile_jobs() {
  local version pid status=0
  local -a pids=()
  # Validate everything first: the names come from the command line (-C...),
  # and an unchecked name would be executed as an arbitrary command.
  for version in "$@"; do
    if ! declare -F -- "$version" > /dev/null; then
      echo "Unknown compile version '$version'" >&2
      return 2
    fi
  done
  for version in "$@"; do
    "$version" &
    pids+=("$!")
  done
  # A bare "wait" always returns 0 and would hide failed compilations, so each
  # job is waited for individually to collect its exit status.
  for pid in "${pids[@]}"; do
    wait "$pid" || status=$?
  done
  return "$status"
}

run_compile_jobs "${compile_versions[@]}" || exit