View Javadoc
1   package at.rseiler.spbee.core.generator;
2   
3   import at.rseiler.spbee.core.pojo.*;
4   import at.rseiler.spbee.core.util.CodeModelUtil;
5   import com.sun.codemodel.*;
6   
7   import javax.annotation.processing.ProcessingEnvironment;
8   import javax.sql.DataSource;
9   import java.io.IOException;
10  import java.sql.Connection;
11  import java.sql.SQLException;
12  import java.sql.Types;
13  import java.util.HashSet;
14  import java.util.List;
15  import java.util.Map;
16  import java.util.Set;
17  
18  /**
19   * Generator for the StoredProcedure classes.
20   *
21   * @author Reinhard Seiler {@literal <rseiler.developer@gmail.com>}
22   */
23  public class StoredProcedureGenerator extends AbstractGenerator {
24  
25      private static final String SPRING_SQL_RETURN_RESULT_SET = "org.springframework.jdbc.core.SqlReturnResultSet";
26      private static final String SPRING_STORED_PROCEDURE = "org.springframework.jdbc.object.StoredProcedure";
27      private static final String SPRING_SQL_PARAMETER = "org.springframework.jdbc.core.SqlParameter";
28  
29      private final Map<String, ResultSetClass> resultSetsMap;
30  
31      public StoredProcedureGenerator(ProcessingEnvironment processingEnv, Map<String, ResultSetClass> resultSetsMap) {
32          super(processingEnv);
33          this.resultSetsMap = resultSetsMap;
34      }
35  
36      public void generateStoredProcedureClasses(List<DtoClass> dtoClasses) throws JClassAlreadyExistsException, IOException {
37          Set<String> storedProcedureNames = new HashSet<>();
38  
39          for (DtoClass dtoClass : dtoClasses) {
40              for (StoredProcedureMethod storedProcedureMethod : dtoClass.getStoredProcedureMethods()) {
41                  // checks if the stored procedure already exists
42                  if (!storedProcedureNames.contains(storedProcedureMethod.getQualifiedClassName())) {
43                      generateStoredProcedure(storedProcedureMethod);
44                      storedProcedureNames.add(storedProcedureMethod.getQualifiedClassName());
45                  }
46              }
47          }
48      }
49  
50      /**
51       * Example:
52       * <p>
53       * <pre>
54       * public class * extends StoredProcedure {
55       *
56       *     public *(DataSource dataSource) {
57       *         super(dataSource, "*");
58       *         [ declareParameter(new SqlParameter("*", Types.*)); ]*
59       *         [ declareParameter(new SqlReturnResultSet("#result-set-*", new *Mapper())); ]*
60       *         compile();
61       *     }
62       *
63       *     public Map&lt;String, Object&gt; execute([ * * ]*) {
64       *         return super.execute([ * ]*);
65       *     }
66       *
67       * }
68       * </pre>
69       */
70      private void generateStoredProcedure(StoredProcedureMethod storedProcedureMethod) throws JClassAlreadyExistsException, IOException {
71          JCodeModel model = new JCodeModel();
72          JDefinedClass spClass = createClass(model, storedProcedureMethod.getPackage(), storedProcedureMethod.getSimpleClassName());
73          createConstructor(model, spClass, storedProcedureMethod);
74          addExecuteMethod(model, spClass, storedProcedureMethod);
75  
76          generateClass(model, storedProcedureMethod.getQualifiedClassName());
77      }
78  
79      /**
80       * Generates:
81       * <pre>
82       * public class * extends StoredProcedure {className}
83       * </pre>
84       */
85      private JDefinedClass createClass(JCodeModel model, String spPackage, String className) throws JClassAlreadyExistsException {
86          JPackage jPackage = model._package(spPackage);
87          JDefinedClass aClass = jPackage._class(className);
88          CodeModelUtil.annotateGenerated(aClass);
89          aClass._extends(model.directClass(SPRING_STORED_PROCEDURE));
90          return aClass;
91      }
92  
93      /**
94       * Generates:
95       * <pre>
96       * public {SP_CLASS_NAME}(DataSource dataSource) {
97       *      super(dataSource, {SP_NAME});
98       *      [ declareParameter(new SqlParameter({PARAMETER_NAME}, Types.*)); ]*
99       *      [ declareParameter(new SqlReturnResultSet("#result-set-*", new *Mapper())); ]*
100      *      compile();
101      * }
102      * </pre>
103      */
104     private void createConstructor(JCodeModel model, JDefinedClass spClass, StoredProcedureMethod storedProcedureMethod) {
105         JMethod constructor = spClass.constructor(JMod.PUBLIC);
106         JVar dataSource = constructor.param(DataSource.class, "dataSource");
107         JBlock body = constructor.body();
108         body.add(JExpr.invoke("super").arg(dataSource).arg(storedProcedureMethod.getStoredProcedureName()));
109         declareSqlParameters(model, storedProcedureMethod, body);
110         declareResultSets(model, body, storedProcedureMethod);
111         body.add(JExpr.invoke("compile"));
112     }
113 
114     /**
115      * Generates:
116      * <pre>
117      * [ declareParameter(new SqlParameter({PARAMETER_NAME}, Types.*)); ]*
118      * </pre>
119      */
120     private void declareSqlParameters(JCodeModel model, StoredProcedureMethod storedProcedureMethod, JBlock body) {
121         for (Variable variable : storedProcedureMethod.getArguments()) {
122             String sqlType = getSqlParameter(variable.getTypeInfo().getGenericTypeOrType());
123             JInvocation sqlParameter = JExpr._new(model.directClass(SPRING_SQL_PARAMETER));
124             sqlParameter.arg(variable.getName());
125             sqlParameter.arg(model.directClass(Types.class.getCanonicalName()).staticRef(sqlType));
126             body.add(JExpr.invoke("declareParameter").arg(sqlParameter));
127         }
128     }
129 
130     /**
131      * Generates:
132      * <pre>
133      * [ declareParameter(new SqlReturnResultSet("#result-set-*", new *Mapper())); ]*
134      * </pre>
135      */
136     private void declareResultSets(JCodeModel model, JBlock body, StoredProcedureMethod storedProcedureMethod) {
137         String type = storedProcedureMethod.getReturnTypeInfo().getGenericType().orElse(storedProcedureMethod.getReturnTypeInfo().getType());
138 
139         if (!"void".equals(type)) {
140             if (resultSetsMap.containsKey(type)) {
141                 multipleResultSets(model, body, type);
142             } else {
143                 singleResultSet(model, body, storedProcedureMethod);
144             }
145         }
146     }
147 
148     /**
149      * If the type is a {@link at.rseiler.spbee.core.annotation.ResultSet} then multiple result-sets are declared.
150      * <p>
151      * Generates:
152      * <pre>
153      * [ declareParameter(new SqlReturnResultSet("#result-set-*", new *Mapper())); ]*
154      * </pre>
155      */
156     private void multipleResultSets(JCodeModel model, JBlock body, String type) {
157         List<ResultSetVariable> variables = resultSetsMap.get(type).getResultSetVariables();
158 
159         for (int i = 0; i < variables.size(); i++) {
160             ResultSetVariable variable = variables.get(i);
161             JInvocation newMapper = JExpr._new(model.directClass(variable.getRowMapper()));
162             JInvocation sqlReturnResultSet = JExpr._new(model.directClass(SPRING_SQL_RETURN_RESULT_SET)).arg("#result-set-" + i).arg(newMapper);
163             body.add(JExpr.invoke("declareParameter").arg(sqlReturnResultSet));
164         }
165     }
166 
167     /**
168      * If it's a basic type then only one result-set is declared.
169      * <p>
170      * Generates:
171      * <pre>
172      * declareParameter(new SqlReturnResultSet("#result-set-0", new *Mapper()));
173      * </pre>
174      */
175     private void singleResultSet(JCodeModel model, JBlock body, StoredProcedureMethod storedProcedureMethod) {
176         JInvocation newMapper = JExpr._new(model.directClass(storedProcedureMethod.getQualifiedRowMapperClass()));
177         JInvocation sqlReturnResultSet = JExpr._new(model.directClass(SPRING_SQL_RETURN_RESULT_SET)).arg("#result-set-0").arg(newMapper);
178         body.add(JExpr.invoke("declareParameter").arg(sqlReturnResultSet));
179     }
180 
181     /**
182      * Generates:
183      * <p>
184      * <pre>
185      * public Map<String, Object> execute( [ * ]* ) { return super.execute( [ * ]* ) }
186      * </pre>
187      */
188     private void addExecuteMethod(JCodeModel model, JDefinedClass aClass, StoredProcedureMethod storedProcedureMethod) {
189         JMethod method = aClass.method(JMod.PUBLIC, CodeModelUtil.getMapStringObject(model), "execute");
190         JInvocation superExecute = JExpr._super().invoke("execute");
191         JBlock block = method.body();
192         boolean hasArrayType = storedProcedureMethod.getArguments().stream().anyMatch(this::isArrayType);
193         JVar connection = null;
194 
195         if (hasArrayType) {
196             connection = block.decl(model.directClass(Connection.class.getCanonicalName()), "conn");
197             block.assign(connection, JExpr._null());
198 
199             JTryBlock tryBlock = method.body()._try();
200             // catch - rethrow exception
201             JCatchBlock catchBlock = tryBlock._catch(model.directClass(SQLException.class.getCanonicalName()));
202             JInvocation exception = JExpr._new(model.directClass("org.springframework.jdbc.UncategorizedSQLException"))
203                     .arg(JExpr._this().invoke("getClass").invoke("getCanonicalName"))
204                     .arg(JExpr._this().invoke("getSql"))
205                     .arg(catchBlock.param("e"));
206             catchBlock.body()._throw(exception);
207             // finally - closes the connection
208             JTryBlock connCloseTryBlock = tryBlock._finally()._if(connection.ne(JExpr._null()))._then()._try();
209             connCloseTryBlock.body().add(connection.invoke("close"));
210             JCatchBlock closeConnCatchBlock = connCloseTryBlock._catch(model.directClass(SQLException.class.getCanonicalName()));
211             JInvocation runtimeException = JExpr._new(model.directClass(RuntimeException.class.getCanonicalName()))
212                     .arg(JExpr.lit("Failed to close connection"))
213                     .arg(closeConnCatchBlock.param("e"));
214             closeConnCatchBlock.body()._throw(runtimeException);
215 
216 
217             block = tryBlock.body();
218             block.assign(connection, JExpr.invoke("getJdbcTemplate").invoke("getDataSource").invoke("getConnection"));
219         }
220 
221         for (Variable variable : storedProcedureMethod.getArguments()) {
222             JVar param = method.param(model.directClass(variable.getTypeInfo().asString()), variable.getName());
223 
224             if (hasArrayType && isArrayType(variable)) {
225                 superExecute.arg(connection.invoke("createArrayOf").arg(getArrayType(variable.getTypeInfo().asString())).arg(param));
226             } else {
227                 superExecute.arg(param);
228             }
229         }
230 
231         block._return(superExecute);
232     }
233 
234     private boolean isArrayType(Variable variable) {
235         return variable.getTypeInfo().asString().contains("[]");
236     }
237 
238     /**
239      * Retrieves the most fitting org.springframework.jdbc.core.SqlParameter type based on the type
240      *
241      * @param type the type
242      * @return the parameter name
243      */
244     public static String getSqlParameter(String type) {
245         switch (type) {
246             case "boolean":
247             case "java.lang.Boolean":
248                 return "BOOLEAN";
249             case "byte":
250             case "java.lang.Byte":
251                 return "TINYINT";
252             case "short":
253             case "java.lang.Short":
254                 return "SMALLINT";
255             case "int":
256             case "java.lang.Integer":
257                 return "INTEGER";
258             case "long":
259             case "java.lang.Long":
260                 return "BIGINT";
261             case "float":
262             case "java.lang.Float":
263                 return "FLOAT";
264             case "double":
265             case "java.lang.Double":
266                 return "DOUBLE";
267             case "byte[]":
268             case "java.lang.Byte[]":
269                 return "BINARY";
270             case "java.lang.String":
271                 return "VARCHAR";
272             case "java.sql.Date":
273             case "java.util.Date":
274                 return "DATE";
275             case "java.math.BigDecimal":
276                 return "DOUBLE";
277             default:
278                 if (type.contains("[]")) {
279                     return "ARRAY";
280                 }
281 
282                 throw new RuntimeException("Unknown SqlType: " + type);
283         }
284     }
285 
286     private String getArrayType(String type) {
287         switch (type) {
288             case "java.lang.Boolean[]":
289                 return "bool";
290             case "java.lang.Character[]":
291                 return "char";
292             case "java.lang.Byte[]":
293                 return "smallint";
294             case "java.lang.Short[]":
295                 return "smallint";
296             case "java.lang.Integer[]":
297                 return "int";
298             case "java.lang.Long[]":
299                 return "bigint";
300             case "java.lang.Float[]":
301                 return "float";
302             case "java.lang.Decimal[]":
303             case "java.math.BigDecimal[]":
304                 return "numeric";
305             case "java.lang.String[]":
306                 return "varchar";
307             case "java.sql.Date[]":
308             case "java.util.Date[]":
309                 return "date";
310         }
311 
312         throw new RuntimeException("Unknown array type: " + type);
313     }
314 
315 }