1 package at.rseiler.spbee.core.generator;
2
3 import at.rseiler.spbee.core.pojo.*;
4 import at.rseiler.spbee.core.util.CodeModelUtil;
5 import com.sun.codemodel.*;
6
7 import javax.annotation.processing.ProcessingEnvironment;
8 import javax.sql.DataSource;
9 import java.io.IOException;
10 import java.sql.Connection;
11 import java.sql.SQLException;
12 import java.sql.Types;
13 import java.util.HashSet;
14 import java.util.List;
15 import java.util.Map;
16 import java.util.Set;
17
18
19
20
21
22
23 public class StoredProcedureGenerator extends AbstractGenerator {
24
25 private static final String SPRING_SQL_RETURN_RESULT_SET = "org.springframework.jdbc.core.SqlReturnResultSet";
26 private static final String SPRING_STORED_PROCEDURE = "org.springframework.jdbc.object.StoredProcedure";
27 private static final String SPRING_SQL_PARAMETER = "org.springframework.jdbc.core.SqlParameter";
28
29 private final Map<String, ResultSetClass> resultSetsMap;
30
31 public StoredProcedureGenerator(ProcessingEnvironment processingEnv, Map<String, ResultSetClass> resultSetsMap) {
32 super(processingEnv);
33 this.resultSetsMap = resultSetsMap;
34 }
35
36 public void generateStoredProcedureClasses(List<DtoClass> dtoClasses) throws JClassAlreadyExistsException, IOException {
37 Set<String> storedProcedureNames = new HashSet<>();
38
39 for (DtoClass dtoClass : dtoClasses) {
40 for (StoredProcedureMethod storedProcedureMethod : dtoClass.getStoredProcedureMethods()) {
41
42 if (!storedProcedureNames.contains(storedProcedureMethod.getQualifiedClassName())) {
43 generateStoredProcedure(storedProcedureMethod);
44 storedProcedureNames.add(storedProcedureMethod.getQualifiedClassName());
45 }
46 }
47 }
48 }
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70 private void generateStoredProcedure(StoredProcedureMethod storedProcedureMethod) throws JClassAlreadyExistsException, IOException {
71 JCodeModel model = new JCodeModel();
72 JDefinedClass spClass = createClass(model, storedProcedureMethod.getPackage(), storedProcedureMethod.getSimpleClassName());
73 createConstructor(model, spClass, storedProcedureMethod);
74 addExecuteMethod(model, spClass, storedProcedureMethod);
75
76 generateClass(model, storedProcedureMethod.getQualifiedClassName());
77 }
78
79
80
81
82
83
84
85 private JDefinedClass createClass(JCodeModel model, String spPackage, String className) throws JClassAlreadyExistsException {
86 JPackage jPackage = model._package(spPackage);
87 JDefinedClass aClass = jPackage._class(className);
88 CodeModelUtil.annotateGenerated(aClass);
89 aClass._extends(model.directClass(SPRING_STORED_PROCEDURE));
90 return aClass;
91 }
92
93
94
95
96
97
98
99
100
101
102
103
104 private void createConstructor(JCodeModel model, JDefinedClass spClass, StoredProcedureMethod storedProcedureMethod) {
105 JMethod constructor = spClass.constructor(JMod.PUBLIC);
106 JVar dataSource = constructor.param(DataSource.class, "dataSource");
107 JBlock body = constructor.body();
108 body.add(JExpr.invoke("super").arg(dataSource).arg(storedProcedureMethod.getStoredProcedureName()));
109 declareSqlParameters(model, storedProcedureMethod, body);
110 declareResultSets(model, body, storedProcedureMethod);
111 body.add(JExpr.invoke("compile"));
112 }
113
114
115
116
117
118
119
120 private void declareSqlParameters(JCodeModel model, StoredProcedureMethod storedProcedureMethod, JBlock body) {
121 for (Variable variable : storedProcedureMethod.getArguments()) {
122 String sqlType = getSqlParameter(variable.getTypeInfo().getGenericTypeOrType());
123 JInvocation sqlParameter = JExpr._new(model.directClass(SPRING_SQL_PARAMETER));
124 sqlParameter.arg(variable.getName());
125 sqlParameter.arg(model.directClass(Types.class.getCanonicalName()).staticRef(sqlType));
126 body.add(JExpr.invoke("declareParameter").arg(sqlParameter));
127 }
128 }
129
130
131
132
133
134
135
136 private void declareResultSets(JCodeModel model, JBlock body, StoredProcedureMethod storedProcedureMethod) {
137 String type = storedProcedureMethod.getReturnTypeInfo().getGenericType().orElse(storedProcedureMethod.getReturnTypeInfo().getType());
138
139 if (!"void".equals(type)) {
140 if (resultSetsMap.containsKey(type)) {
141 multipleResultSets(model, body, type);
142 } else {
143 singleResultSet(model, body, storedProcedureMethod);
144 }
145 }
146 }
147
148
149
150
151
152
153
154
155
156 private void multipleResultSets(JCodeModel model, JBlock body, String type) {
157 List<ResultSetVariable> variables = resultSetsMap.get(type).getResultSetVariables();
158
159 for (int i = 0; i < variables.size(); i++) {
160 ResultSetVariable variable = variables.get(i);
161 JInvocation newMapper = JExpr._new(model.directClass(variable.getRowMapper()));
162 JInvocation sqlReturnResultSet = JExpr._new(model.directClass(SPRING_SQL_RETURN_RESULT_SET)).arg("#result-set-" + i).arg(newMapper);
163 body.add(JExpr.invoke("declareParameter").arg(sqlReturnResultSet));
164 }
165 }
166
167
168
169
170
171
172
173
174
175 private void singleResultSet(JCodeModel model, JBlock body, StoredProcedureMethod storedProcedureMethod) {
176 JInvocation newMapper = JExpr._new(model.directClass(storedProcedureMethod.getQualifiedRowMapperClass()));
177 JInvocation sqlReturnResultSet = JExpr._new(model.directClass(SPRING_SQL_RETURN_RESULT_SET)).arg("#result-set-0").arg(newMapper);
178 body.add(JExpr.invoke("declareParameter").arg(sqlReturnResultSet));
179 }
180
181
182
183
184
185
186
187
188 private void addExecuteMethod(JCodeModel model, JDefinedClass aClass, StoredProcedureMethod storedProcedureMethod) {
189 JMethod method = aClass.method(JMod.PUBLIC, CodeModelUtil.getMapStringObject(model), "execute");
190 JInvocation superExecute = JExpr._super().invoke("execute");
191 JBlock block = method.body();
192 boolean hasArrayType = storedProcedureMethod.getArguments().stream().anyMatch(this::isArrayType);
193 JVar connection = null;
194
195 if (hasArrayType) {
196 connection = block.decl(model.directClass(Connection.class.getCanonicalName()), "conn");
197 block.assign(connection, JExpr._null());
198
199 JTryBlock tryBlock = method.body()._try();
200
201 JCatchBlock catchBlock = tryBlock._catch(model.directClass(SQLException.class.getCanonicalName()));
202 JInvocation exception = JExpr._new(model.directClass("org.springframework.jdbc.UncategorizedSQLException"))
203 .arg(JExpr._this().invoke("getClass").invoke("getCanonicalName"))
204 .arg(JExpr._this().invoke("getSql"))
205 .arg(catchBlock.param("e"));
206 catchBlock.body()._throw(exception);
207
208 JTryBlock connCloseTryBlock = tryBlock._finally()._if(connection.ne(JExpr._null()))._then()._try();
209 connCloseTryBlock.body().add(connection.invoke("close"));
210 JCatchBlock closeConnCatchBlock = connCloseTryBlock._catch(model.directClass(SQLException.class.getCanonicalName()));
211 JInvocation runtimeException = JExpr._new(model.directClass(RuntimeException.class.getCanonicalName()))
212 .arg(JExpr.lit("Failed to close connection"))
213 .arg(closeConnCatchBlock.param("e"));
214 closeConnCatchBlock.body()._throw(runtimeException);
215
216
217 block = tryBlock.body();
218 block.assign(connection, JExpr.invoke("getJdbcTemplate").invoke("getDataSource").invoke("getConnection"));
219 }
220
221 for (Variable variable : storedProcedureMethod.getArguments()) {
222 JVar param = method.param(model.directClass(variable.getTypeInfo().asString()), variable.getName());
223
224 if (hasArrayType && isArrayType(variable)) {
225 superExecute.arg(connection.invoke("createArrayOf").arg(getArrayType(variable.getTypeInfo().asString())).arg(param));
226 } else {
227 superExecute.arg(param);
228 }
229 }
230
231 block._return(superExecute);
232 }
233
234 private boolean isArrayType(Variable variable) {
235 return variable.getTypeInfo().asString().contains("[]");
236 }
237
238
239
240
241
242
243
244 public static String getSqlParameter(String type) {
245 switch (type) {
246 case "boolean":
247 case "java.lang.Boolean":
248 return "BOOLEAN";
249 case "byte":
250 case "java.lang.Byte":
251 return "TINYINT";
252 case "short":
253 case "java.lang.Short":
254 return "SMALLINT";
255 case "int":
256 case "java.lang.Integer":
257 return "INTEGER";
258 case "long":
259 case "java.lang.Long":
260 return "BIGINT";
261 case "float":
262 case "java.lang.Float":
263 return "FLOAT";
264 case "double":
265 case "java.lang.Double":
266 return "DOUBLE";
267 case "byte[]":
268 case "java.lang.Byte[]":
269 return "BINARY";
270 case "java.lang.String":
271 return "VARCHAR";
272 case "java.sql.Date":
273 case "java.util.Date":
274 return "DATE";
275 case "java.math.BigDecimal":
276 return "DOUBLE";
277 default:
278 if (type.contains("[]")) {
279 return "ARRAY";
280 }
281
282 throw new RuntimeException("Unknown SqlType: " + type);
283 }
284 }
285
286 private String getArrayType(String type) {
287 switch (type) {
288 case "java.lang.Boolean[]":
289 return "bool";
290 case "java.lang.Character[]":
291 return "char";
292 case "java.lang.Byte[]":
293 return "smallint";
294 case "java.lang.Short[]":
295 return "smallint";
296 case "java.lang.Integer[]":
297 return "int";
298 case "java.lang.Long[]":
299 return "bigint";
300 case "java.lang.Float[]":
301 return "float";
302 case "java.lang.Decimal[]":
303 case "java.math.BigDecimal[]":
304 return "numeric";
305 case "java.lang.String[]":
306 return "varchar";
307 case "java.sql.Date[]":
308 case "java.util.Date[]":
309 return "date";
310 }
311
312 throw new RuntimeException("Unknown array type: " + type);
313 }
314
315 }