001package ca.uhn.fhir.jpa.migrate.taskdef;
002
003/*-
004 * #%L
005 * HAPI FHIR Server - SQL Migration
006 * %%
007 * Copyright (C) 2014 - 2022 Smile CDR, Inc.
008 * %%
009 * Licensed under the Apache License, Version 2.0 (the "License");
010 * you may not use this file except in compliance with the License.
011 * You may obtain a copy of the License at
012 *
013 *      http://www.apache.org/licenses/LICENSE-2.0
014 *
015 * Unless required by applicable law or agreed to in writing, software
016 * distributed under the License is distributed on an "AS IS" BASIS,
017 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
018 * See the License for the specific language governing permissions and
019 * limitations under the License.
020 * #L%
021 */
022
023import ca.uhn.fhir.jpa.migrate.JdbcUtils;
024import ca.uhn.fhir.util.VersionEnum;
025import org.apache.commons.lang3.StringUtils;
026import org.apache.commons.lang3.builder.EqualsBuilder;
027import org.apache.commons.lang3.builder.HashCodeBuilder;
028import org.slf4j.Logger;
029import org.slf4j.LoggerFactory;
030import org.springframework.jdbc.core.ColumnMapRowMapper;
031import org.springframework.jdbc.core.JdbcTemplate;
032
033import java.sql.SQLException;
034import java.util.ArrayList;
035import java.util.List;
036import java.util.Map;
037import java.util.Set;
038import java.util.function.Consumer;
039
040public class ArbitrarySqlTask extends BaseTask {
041
042        private static final Logger ourLog = LoggerFactory.getLogger(ArbitrarySqlTask.class);
043        private final String myDescription;
044        private final String myTableName;
045        private List<Task> myTask = new ArrayList<>();
046        private int myBatchSize = 1000;
047        private String myExecuteOnlyIfTableExists;
048        private List<TableAndColumn> myConditionalOnExistenceOf = new ArrayList<>();
049
050        public ArbitrarySqlTask(VersionEnum theRelease, String theVersion, String theTableName, String theDescription) {
051                super(theRelease.toString(), theVersion);
052                myTableName = theTableName;
053                myDescription = theDescription;
054        }
055
056        public void addQuery(String theSql, QueryModeEnum theMode, Consumer<Map<String, Object>> theConsumer) {
057                myTask.add(new QueryTask(theSql, theMode, theConsumer));
058        }
059
060        @Override
061        public void validate() {
062                // nothing
063        }
064
065        @Override
066        public void doExecute() throws SQLException {
067                logInfo(ourLog, "Starting: {}", myDescription);
068
069                if (StringUtils.isNotBlank(myExecuteOnlyIfTableExists)) {
070                        Set<String> tableNames = JdbcUtils.getTableNames(getConnectionProperties());
071                        if (!tableNames.contains(myExecuteOnlyIfTableExists.toUpperCase())) {
072                                logInfo(ourLog, "Table {} does not exist - No action performed", myExecuteOnlyIfTableExists);
073                                return;
074                        }
075                }
076
077                for (TableAndColumn next : myConditionalOnExistenceOf) {
078                        JdbcUtils.ColumnType columnType = JdbcUtils.getColumnType(getConnectionProperties(), next.getTable(), next.getColumn());
079                        if (columnType == null) {
080                                logInfo(ourLog, "Table {} does not have column {} - No action performed", next.getTable(), next.getColumn());
081                                return;
082                        }
083                }
084
085                for (Task next : myTask) {
086                        next.execute();
087                }
088
089        }
090
091        public void setBatchSize(int theBatchSize) {
092                myBatchSize = theBatchSize;
093        }
094
095        public void setExecuteOnlyIfTableExists(String theExecuteOnlyIfTableExists) {
096                myExecuteOnlyIfTableExists = theExecuteOnlyIfTableExists;
097        }
098
099        /**
100         * This task will only execute if the following column exists
101         */
102        public void addExecuteOnlyIfColumnExists(String theTableName, String theColumnName) {
103                myConditionalOnExistenceOf.add(new TableAndColumn(theTableName, theColumnName));
104        }
105
106        @Override
107        protected void generateEquals(EqualsBuilder theBuilder, BaseTask theOtherObject) {
108                ArbitrarySqlTask otherObject = (ArbitrarySqlTask) theOtherObject;
109                theBuilder.append(myTableName, otherObject.myTableName);
110        }
111
112        @Override
113        protected void generateHashCode(HashCodeBuilder theBuilder) {
114                theBuilder.append(myTableName);
115        }
116
117        public enum QueryModeEnum {
118                BATCH_UNTIL_NO_MORE
119        }
120
121        private static class TableAndColumn {
122                private final String myTable;
123                private final String myColumn;
124
125                private TableAndColumn(String theTable, String theColumn) {
126                        myTable = theTable;
127                        myColumn = theColumn;
128                }
129
130                public String getTable() {
131                        return myTable;
132                }
133
134                public String getColumn() {
135                        return myColumn;
136                }
137        }
138
139        private abstract class Task {
140                public abstract void execute();
141        }
142
143        private class QueryTask extends Task {
144                private final String mySql;
145                private final Consumer<Map<String, Object>> myConsumer;
146
147                public QueryTask(String theSql, QueryModeEnum theMode, Consumer<Map<String, Object>> theConsumer) {
148                        mySql = theSql;
149                        myConsumer = theConsumer;
150                        setDescription("Execute raw sql");
151                }
152
153
154                @Override
155                public void execute() {
156                        if (isDryRun()) {
157                                return;
158                        }
159
160                        List<Map<String, Object>> rows;
161                        do {
162                                logInfo(ourLog, "Querying for up to {} rows", myBatchSize);
163                                rows = getTxTemplate().execute(t -> {
164                                        JdbcTemplate jdbcTemplate = newJdbcTemplate();
165                                        jdbcTemplate.setMaxRows(myBatchSize);
166                                        return jdbcTemplate.query(mySql, new ColumnMapRowMapper());
167                                });
168
169                                logInfo(ourLog, "Processing {} rows", rows.size());
170                                List<Map<String, Object>> finalRows = rows;
171                                getTxTemplate().execute(t -> {
172                                        for (Map<String, Object> nextRow : finalRows) {
173                                                myConsumer.accept(nextRow);
174                                        }
175                                        return null;
176                                });
177                        } while (rows.size() > 0);
178                }
179        }
180}