11package org .ohnlp .backbone .io .jdbc ;
22
33import com .mchange .v2 .c3p0 .ComboPooledDataSource ;
4+ import org .apache .beam .sdk .coders .Coder ;
5+ import org .apache .beam .sdk .coders .CollectionCoder ;
46import org .apache .beam .sdk .coders .RowCoder ;
7+ import org .apache .beam .sdk .coders .StringUtf8Coder ;
58import org .apache .beam .sdk .io .jdbc .JdbcIO ;
69import org .apache .beam .sdk .io .jdbc .SchemaUtilProxy ;
710import org .apache .beam .sdk .schemas .Schema ;
8- import org .apache .beam .sdk .transforms .Create ;
11+ import org .apache .beam .sdk .transforms .*;
12+ import org .apache .beam .sdk .values .KV ;
913import org .apache .beam .sdk .values .PBegin ;
1014import org .apache .beam .sdk .values .PCollection ;
1115import org .apache .beam .sdk .values .Row ;
1418import org .ohnlp .backbone .api .components .ExtractToOne ;
1519import org .ohnlp .backbone .api .exceptions .ComponentInitializationException ;
1620
21+ import java .beans .PropertyVetoException ;
1722import java .sql .*;
1823import java .util .*;
24+ import java .util .concurrent .ThreadLocalRandom ;
1925
2026/**
2127 * Performs data extraction using a JDBC connector
@@ -62,7 +68,8 @@ public class JDBCExtract extends ExtractToOne {
6268 private int batchSize = 1000 ;
6369 @ ConfigurationProperty (
6470 path = "identifier_col" ,
65- desc = "An ID column returned as part of the query that can be used to identify and partition records." ,
71+ desc = "An ID column returned as part of the query that can be used to identify and partition records, " +
72+ "multiple columns can be entered in column-delimited order" ,
6673 required = false
6774 )
6875 private String identifierCol = null ;
@@ -81,6 +88,8 @@ public class JDBCExtract extends ExtractToOne {
8188 private String viewName ;
8289 private String orderedQuery ;
8390 private Schema schema ;
91+ private String keyValueQuery ;
92+ private Schema keyValueSchema ;
8493
8594 /**
8695 * Initializes a Beam JdbcIO Provider
@@ -125,39 +134,46 @@ public void init() throws ComponentInitializationException {
125134 // We will first preflight with a query that counts the number of records so that we can get number
126135 // of batches
127136 String runId = UUID .randomUUID ().toString ().replaceAll ("-" , "_" );
128- //noinspection SqlResolve
129- String countQuery = "SELECT COUNT(*) FROM (" + query + ") bckbone_preflight_query_" + runId ;
130137 this .viewName = "backbone_jdbcextract_" + runId ;
131- // Find appropriate columns to order by so that pagination results are consistent
132- this .orderByCols = findPaginationOrderingColumns (this .query );
133- // Get record count so that we know how many batches are going to be needed
134- try (Connection conn = initializationDS .getConnection ()) {
135- ResultSet rs = conn .createStatement ().executeQuery (countQuery );
136- rs .next ();
137- int resultCount = rs .getInt (1 );
138- this .numBatches = Math .round (Math .ceil ((double ) resultCount / this .batchSize ));
139- }
140- // Normally I would say use Strings.join for the below, but this was causing cross-jvm issues
141- // so we use the more portable stringbuilder instead...
142- StringBuilder sB = new StringBuilder ();
143- boolean flag = false ;
144- for (String s : this .orderByCols ) {
145- if (flag ) {
146- sB .append (", " );
138+ if (this .identifierCol == null ) {
139+ // No identifier column provided so we can only do a full-form sort.
140+ // TODO find a better solution for this
141+ //noinspection SqlResolve
142+ String countQuery = "SELECT COUNT(*) FROM (" + query + ") bckbone_preflight_query_" + runId ;
143+ // Find appropriate columns to order by so that pagination results are consistent
144+ this .orderByCols = findPaginationOrderingColumns (this .query );
145+ // Get record count so that we know how many batches are going to be needed
146+ try (Connection conn = initializationDS .getConnection ()) {
147+ ResultSet rs = conn .createStatement ().executeQuery (countQuery );
148+ rs .next ();
149+ int resultCount = rs .getInt (1 );
150+ this .numBatches = Math .round (Math .ceil ((double ) resultCount / this .batchSize ));
147151 }
148- sB .append (s );
149- flag = true ;
150- }
151- this .orderedQuery = "SELECT * FROM (" + this .query + ") " + this .viewName
152- + " ORDER BY " + sB .toString () + " " ;
153- // Now we have to add the offset/fetch in the dialect local format..
154- // Specifically, postgres and MySQL are special in that they do not conform to the
155- // SQL:2011 standard syntax
156- if (driver .equals ("org.postgresql.Driver" ) || driver .equals ("com.mysql.jdbc.Driver" )
157- || driver .equals ("com.mysql.cj.jdbc.Driver" ) || driver .equals ("org.sqlite.JDBC" )) {
158- this .orderedQuery += "LIMIT " + batchSize + " OFFSET ?" ;
159- } else { // This is the SQL:2011 standard definition of an offset...fetch syntax
160- this .orderedQuery += "OFFSET ? ROWS FETCH NEXT " + batchSize + " ROWS ONLY" ;
152+ // Normally I would say use Strings.join for the below, but this was causing cross-jvm issues
153+ // so we use the more portable stringbuilder instead...
154+ StringBuilder sB = new StringBuilder ();
155+ boolean flag = false ;
156+ for (String s : this .orderByCols ) {
157+ if (flag ) {
158+ sB .append (", " );
159+ }
160+ sB .append (s );
161+ flag = true ;
162+ }
163+ this .orderedQuery = "SELECT * FROM (" + this .query + ") " + this .viewName
164+ + " ORDER BY " + sB .toString () + " " ;
165+ // Now we have to add the offset/fetch in the dialect local format..
166+ // Specifically, postgres and MySQL are special in that they do not conform to the
167+ // SQL:2011 standard syntax
168+ if (driver .equals ("org.postgresql.Driver" ) || driver .equals ("com.mysql.jdbc.Driver" )
169+ || driver .equals ("com.mysql.cj.jdbc.Driver" ) || driver .equals ("org.sqlite.JDBC" )) {
170+ this .orderedQuery += "LIMIT " + batchSize + " OFFSET ?" ;
171+ } else { // This is the SQL:2011 standard definition of an offset...fetch syntax
172+ this .orderedQuery += "OFFSET ? ROWS FETCH NEXT " + batchSize + " ROWS ONLY" ;
173+ }
174+ } else {
175+ this .keyValueQuery = "SELECT DISTINCT " + identifierCol + " FROM (" + query + ") " + viewName ;
176+ this .keyValueSchema = getIdentifierColumnsSchema ();
161177 }
162178 } catch (Throwable t ) {
163179 throw new ComponentInitializationException (t );
@@ -184,24 +200,90 @@ public Schema calculateOutputSchema() {
184200
185201 @ Override
186202 public PCollection <Row > begin (PBegin input ) {
187- List <Integer > offsets = new ArrayList <>();
188- for (int i = 0 ; i < numBatches ; i ++) {
189- offsets .add (i * batchSize ); // Create a sequence of batches at the appropriate offset
203+ if (this .identifierCol == null ) {
204+ List <Integer > offsets = new ArrayList <>();
205+ for (int i = 0 ; i < numBatches ; i ++) {
206+ offsets .add (i * batchSize ); // Create a sequence of batches at the appropriate offset
207+ }
208+ return input .apply (
209+ "Read from JDBC" ,
210+ JdbcIO .<Row >read ()
211+ .withDataSourceConfiguration (datasourceConfig )
212+ .withQuery ("SELECT * FROM (" + this .query + ") " + this .viewName )
213+ .withRowMapper (this .driver .equals ("org.sqlite.JDBC" ) ?
214+ new SchemaUtilProxy .SQLiteBeamRowMapperProxy (schema ) :
215+ new SchemaUtilProxy .BeamRowMapperProxy (schema ))
216+ .withCoder (RowCoder .of (schema ))
217+ .withOutputParallelization (false )
218+ ).apply ("JDBC Break Fusion" , Repartition .of ()).setRowSchema (schema );
219+ } else {
220+ StringBuilder queryByKey = new StringBuilder ("SELECT * FROM (" + this .query + ") " + this .viewName + " WHERE " );
221+ boolean appendAnd = false ;
222+ for (String identifierCol : this .identifierCol .split ("," )) {
223+ if (appendAnd ) {
224+ queryByKey .append ("AND " );
225+ } else {
226+ appendAnd = true ;
227+ }
228+ queryByKey .append (identifierCol ).append (" = ? " );
229+ }
230+ JdbcIO .RowMapper <Row > rowmapper = this .driver .equals ("org.sqlite.JDBC" ) ?
231+ new SchemaUtilProxy .SQLiteBeamRowMapperProxy (keyValueSchema ) :
232+ new SchemaUtilProxy .BeamRowMapperProxy (keyValueSchema );
233+ String [] cols = this .identifierCol .split ("," );
234+ return input .apply ("JDBC Init" , Create .of (keyValueQuery ))
235+ .apply ("JDBC Preflight for Query Keys" , ParDo .of (
236+ new DoFn <String , Row >() {
237+ private ComboPooledDataSource ds ;
238+
239+ @ Setup
240+ public void init () throws PropertyVetoException {
241+ this .ds = new ComboPooledDataSource (); // Set separate
242+ ds .setDriverClass (driver );
243+ ds .setJdbcUrl (url );
244+ ds .setUser (user );
245+ ds .setPassword (password );
246+ ds .setMaxIdleTime (idleTimeout );
247+ }
248+
249+ @ ProcessElement
250+ public void process (ProcessContext pc ) throws Exception {
251+ try (Connection conn = ds .getConnection ()) {
252+ ResultSet rs = conn .createStatement ().executeQuery (pc .element ());
253+ while (rs .next ()) {
254+ pc .output (rowmapper .mapRow (rs ));
255+ }
256+
257+ } catch (SQLException e ) {
258+ throw new RuntimeException (e );
259+ }
260+ }
261+ }
262+ )).setRowSchema (this .keyValueSchema )
263+ .apply ("JDBC Break Fusion" , Repartition .of ()) // Break fusion here due to large fanout/preflight being on single thread
264+ .apply ("JDBC Read" , JdbcIO .<Row , Row >readAll ()
265+ .withDataSourceConfiguration (datasourceConfig )
266+ .withQuery (queryByKey .toString ())
267+ .withRowMapper (this .driver .equals ("org.sqlite.JDBC" ) ?
268+ new SchemaUtilProxy .SQLiteBeamRowMapperProxy (schema ) :
269+ new SchemaUtilProxy .BeamRowMapperProxy (schema ))
270+ .withParameterSetter ((JdbcIO .PreparedStatementSetter <Row >) (element , preparedStatement ) -> {
271+ for (int i = 0 ; i < cols .length ; i ++) {
272+ preparedStatement .setObject (i + 1 , element .getValue (cols [i ]));
273+ }
274+ })
275+ .withCoder (RowCoder .of (schema ))
276+ .withOutputParallelization (false ));
277+ }
278+ }
279+
280+ private Schema getIdentifierColumnsSchema () throws ComponentInitializationException {
281+ try (Connection conn = this .initializationDS .getConnection ()) {
282+ ResultSetMetaData queryMeta = conn .prepareStatement ("SELECT " + this .initializationDS + " FROM (" + this .query + ") " + this .viewName ).getMetaData ();
283+ return SchemaUtilProxy .toBeamSchema (this .driver , queryMeta );
284+ } catch (SQLException e ) {
285+ throw new ComponentInitializationException (e );
190286 }
191- return input .apply ("JDBC Preflight" , Create .of (offsets )) // First create partitions # = to num batches
192- .apply ("JDBC Read" , // Now actually do the read, the readall function will execute one query per input partition
193- JdbcIO .<Integer , Row >readAll ()
194- .withDataSourceConfiguration (datasourceConfig )
195- .withQuery (this .orderedQuery )
196- .withRowMapper (this .driver .equals ("org.sqlite.JDBC" ) ?
197- new SchemaUtilProxy .SQLiteBeamRowMapperProxy (schema ) :
198- new SchemaUtilProxy .BeamRowMapperProxy (schema ))
199- .withParameterSetter ((JdbcIO .PreparedStatementSetter <Integer >) (element , preparedStatement ) -> {
200- preparedStatement .setInt (1 , element ); // Replace
201- })
202- .withCoder (RowCoder .of (schema ))
203- .withOutputParallelization (false )
204- );
205287 }
206288
207289 private String [] findPaginationOrderingColumns (String query ) throws ComponentInitializationException {
@@ -243,4 +325,32 @@ private String[] findPaginationOrderingColumns(String query) throws ComponentIni
243325 }
244326
245327 }
328+
329+ private static class Repartition <T > extends PTransform <PCollection <T >, PCollection <T >> {
330+
331+ private Repartition () {}
332+
333+ public static <T > Repartition <T > of () {
334+ return new Repartition <>();
335+ }
336+
337+ @ Override
338+ public PCollection <T > expand (PCollection <T > input ) {
339+ return input
340+ .apply (ParDo .of (new DoFn <T , KV <Integer , T >>() {
341+ @ ProcessElement
342+ public void process (ProcessContext pc ) {
343+ pc .output (KV .of (ThreadLocalRandom .current ().nextInt (), pc .element ()));
344+ }
345+ }))
346+ .apply (GroupByKey .<Integer , T >create ())
347+ .apply (ParDo .of (new DoFn <KV <Integer , Iterable <T >>, T >() {
348+ @ ProcessElement
349+ public void process (ProcessContext pc ) {
350+ for (T element : pc .element ().getValue ()) {
351+ pc .output (element );
352+ } }
353+ }));
354+ }
355+ }
246356}
0 commit comments