diff --git a/database-commons/src/main/java/io/cdap/plugin/db/CommonSchemaReader.java b/database-commons/src/main/java/io/cdap/plugin/db/CommonSchemaReader.java index 28c56db8c..b2348ee5e 100644 --- a/database-commons/src/main/java/io/cdap/plugin/db/CommonSchemaReader.java +++ b/database-commons/src/main/java/io/cdap/plugin/db/CommonSchemaReader.java @@ -57,6 +57,11 @@ public Schema getSchema(ResultSetMetaData metadata, int index) throws SQLExcepti metadata.isSigned(index), true); } + public Schema getSchema(String typeName, int sqlType, int precision, int scale, String columnName, + boolean isSigned) throws SQLException { + return DBUtils.getSchema(typeName, sqlType, precision, scale, columnName, isSigned, true); + } + @Override public boolean shouldIgnoreColumn(ResultSetMetaData metadata, int index) throws SQLException { return false; diff --git a/database-commons/src/main/java/io/cdap/plugin/db/DBRecord.java b/database-commons/src/main/java/io/cdap/plugin/db/DBRecord.java index b187c7670..2eaa32009 100644 --- a/database-commons/src/main/java/io/cdap/plugin/db/DBRecord.java +++ b/database-commons/src/main/java/io/cdap/plugin/db/DBRecord.java @@ -188,7 +188,14 @@ protected void handleField(ResultSet resultSet, StructuredRecord.Builder recordB protected void setField(ResultSet resultSet, StructuredRecord.Builder recordBuilder, Schema.Field field, int columnIndex, int sqlType, int sqlPrecision, int sqlScale) throws SQLException { Object o = DBUtils.transformValue(sqlType, sqlPrecision, sqlScale, resultSet, columnIndex); - if (o instanceof Date) { + setFieldValue(recordBuilder, field, o); + } + + protected void setFieldValue(StructuredRecord.Builder recordBuilder, Schema.Field field, Object o) + throws SQLException { + if (o == null) { + recordBuilder.set(field.getName(), null); + } else if (o instanceof Date) { recordBuilder.setDate(field.getName(), ((Date) o).toLocalDate()); } else if (o instanceof Time) { recordBuilder.setTime(field.getName(), ((Time) 
o).toLocalTime()); diff --git a/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceDBRecord.java b/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceDBRecord.java index 44131a01b..34af5a30e 100644 --- a/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceDBRecord.java +++ b/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceDBRecord.java @@ -29,12 +29,17 @@ import java.io.InputStream; import java.lang.reflect.InvocationTargetException; import java.math.BigDecimal; +import java.math.RoundingMode; import java.nio.ByteBuffer; +import java.sql.Blob; +import java.sql.Clob; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; +import java.sql.SQLXML; +import java.sql.Struct; import java.sql.Timestamp; import java.sql.Types; import java.time.LocalDateTime; @@ -43,6 +48,8 @@ import java.time.ZoneOffset; import java.time.ZonedDateTime; import java.util.List; +import java.util.Map; +import java.util.TreeMap; /** * Oracle Source implementation {@link org.apache.hadoop.mapreduce.lib.db.DBWritable} and @@ -106,13 +113,85 @@ record = recordBuilder.build(); @Override protected void handleField(ResultSet resultSet, StructuredRecord.Builder recordBuilder, Schema.Field field, int columnIndex, int sqlType, int sqlPrecision, int sqlScale) throws SQLException { - if (OracleSourceSchemaReader.ORACLE_TYPES.contains(sqlType) || sqlType == Types.NCLOB) { + if (OracleSourceSchemaReader.ORACLE_TYPES.contains(sqlType) || sqlType == Types.NCLOB || sqlType == Types.STRUCT) { handleOracleSpecificType(resultSet, recordBuilder, field, columnIndex, sqlType, sqlPrecision, sqlScale); } else { setField(resultSet, recordBuilder, field, columnIndex, sqlType, sqlPrecision, sqlScale); } } + @Override + protected void setFieldValue(StructuredRecord.Builder recordBuilder, Schema.Field field, Object attrValue) + throws SQLException { + if (attrValue == null) { + 
recordBuilder.set(field.getName(), null); + return; + } + + Schema fieldSchema = field.getSchema().isNullable() ? field.getSchema().getNonNullable() + : field.getSchema(); + String attrClassName = attrValue.getClass().getName(); + if (attrValue instanceof Struct) { + recordBuilder.set(field.getName(), convertStructToRecord((Struct) attrValue, fieldSchema)); + return; + } + if (attrValue instanceof Clob) { + Clob clob = (Clob) attrValue; + recordBuilder.set(field.getName(), clob.getSubString(1, (int) clob.length())); + return; + } + if (attrValue instanceof Blob) { + Blob blob = (Blob) attrValue; + recordBuilder.set(field.getName(), blob.getBytes(1, (int) blob.length())); + return; + } + if (attrValue instanceof SQLXML) { + recordBuilder.set(field.getName(), ((SQLXML) attrValue).getString()); + return; + } + if ("oracle.sql.INTERVALDS".equals(attrClassName) || "oracle.sql.INTERVALYM".equals(attrClassName)) { + recordBuilder.set(field.getName(), attrValue.toString()); + return; + } + if (attrValue instanceof BigDecimal) { + populateDecimalValue(attrValue, fieldSchema, recordBuilder, field); + return; + } + + if (attrValue instanceof Timestamp) { + Timestamp timestamp = (Timestamp) attrValue; + if (Schema.LogicalType.DATETIME.equals(fieldSchema.getLogicalType())) { + recordBuilder.setDateTime(field.getName(), timestamp.toLocalDateTime()); + } else { + super.setFieldValue(recordBuilder, field, attrValue); + } + return; + } + if (attrValue instanceof OffsetDateTime) { + ZonedDateTime zonedDateTime = ((OffsetDateTime) attrValue).atZoneSameInstant(ZoneId.of("UTC")); + + if (fieldSchema.getLogicalType() != null && + (Schema.LogicalType.TIMESTAMP_MICROS.equals(fieldSchema.getLogicalType()) || + Schema.LogicalType.TIMESTAMP_MILLIS.equals(fieldSchema.getLogicalType()))) { + recordBuilder.setTimestamp(field.getName(), zonedDateTime); + } else { + recordBuilder.set(field.getName(), zonedDateTime.toString()); + } + return; + } + + ClassLoader oracleLoader = 
attrValue.getClass().getClassLoader(); + try { + if (oracleLoader != null && oracleLoader.loadClass("oracle.jdbc.OracleBfile").isInstance(attrValue)) { + recordBuilder.set(field.getName(), getBfileBytes(attrValue, field.getName())); + return; + } + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + super.setFieldValue(recordBuilder, field, attrValue); + } + @Override protected void writeNonNullToDB(PreparedStatement stmt, Schema fieldSchema, String fieldName, int fieldIndex) throws SQLException { @@ -169,6 +248,34 @@ protected void writeNonNullToDB(PreparedStatement stmt, Schema fieldSchema, } } + private void populateDecimalValue(Object attrValue, Schema fieldSchema, + StructuredRecord.Builder recordBuilder, Schema.Field field) { + BigDecimal bigDecimal = (BigDecimal) attrValue; + if (Schema.LogicalType.DECIMAL.equals(fieldSchema.getLogicalType())) { + recordBuilder.setDecimal(field.getName(), bigDecimal.setScale(fieldSchema.getScale(), RoundingMode.HALF_UP)); + return; + } + switch (fieldSchema.getType()) { + case DOUBLE: + recordBuilder.set(field.getName(), bigDecimal.doubleValue()); + break; + case FLOAT: + recordBuilder.set(field.getName(), bigDecimal.floatValue()); + break; + case INT: + recordBuilder.set(field.getName(), bigDecimal.intValue()); + break; + case LONG: + recordBuilder.set(field.getName(), bigDecimal.longValue()); + break; + case STRING: + recordBuilder.set(field.getName(), bigDecimal.toString()); + break; + default: + recordBuilder.set(field.getName(), bigDecimal); + } + } + /** * Creates an instance of 'oracle.sql.TIMESTAMPTZ' which corresponds to the specified timestamp with time zone string. * @param connection sql connection. 
@@ -232,11 +339,15 @@ private Object createOracleTimestamp(Connection connection, String timestampStri */ private byte[] getBfileBytes(ResultSet resultSet, String columnName) throws SQLException { Object bfile = resultSet.getObject(columnName); + return getBfileBytes(bfile, columnName); + } + + public byte[] getBfileBytes(Object bfile, String columnName) { if (bfile == null) { return null; } try { - ClassLoader classLoader = resultSet.getClass().getClassLoader(); + ClassLoader classLoader = bfile.getClass().getClassLoader(); Class oracleBfileClass = classLoader.loadClass("oracle.jdbc.OracleBfile"); boolean isFileExist = (boolean) oracleBfileClass.getMethod("fileExists").invoke(bfile); if (!isFileExist) { @@ -341,6 +452,12 @@ private void handleOracleSpecificType(ResultSet resultSet, StructuredRecord.Buil case OracleSourceSchemaReader.LONG_RAW: recordBuilder.set(field.getName(), resultSet.getBytes(columnIndex)); break; + case Types.STRUCT: + Struct structValue = (Struct) resultSet.getObject(columnIndex); + if (structValue != null) { + recordBuilder.set(field.getName(), convertStructToRecord(structValue, nonNullSchema)); + } + break; case Types.DECIMAL: case Types.NUMERIC: // This is the only way to differentiate FLOAT/REAL columns from other numeric columns, that based on NUMBER. @@ -371,6 +488,57 @@ private void handleOracleSpecificType(ResultSet resultSet, StructuredRecord.Buil } } + /** + * Converts a JDBC {@link Struct} into a {@link StructuredRecord} based on the provided schema. 
+ * + * @param struct the SQL structured type containing the source data attributes + * @param schema the target record schema defining the fields to map + * @return a populated {@code StructuredRecord} instance + * @throws SQLException if an error occurs reading the struct attributes or metadata + */ + protected StructuredRecord convertStructToRecord(Struct struct, Schema schema) + throws SQLException { + Map attributeMap = getAttributeMap(struct, schema); + StructuredRecord.Builder builder = StructuredRecord.builder(schema); + + for (Schema.Field field : schema.getFields()) { + Object attrValue = attributeMap.get(field.getName()); + setFieldValue(builder, field, attrValue); + } + return builder.build(); + } + + /** + * Extracts attributes from a {@link Struct} into a case-insensitive map indexed by column name. + * Uses reflection to extract underlying metadata (e.g., from Oracle StructDescriptor). + * + * @param struct the source SQL structured type + * @param schema the target schema used for context in error messages + * @return a case-insensitive {@code Map} linking column names to their attribute values + * @throws SQLException if metadata extraction fails or driver-specific methods are inaccessible + */ + private Map getAttributeMap(Struct struct, Schema schema) throws SQLException { + Map attributeMap = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); + Object[] attributes = struct.getAttributes(); + + try { + Object descriptor = struct.getClass().getMethod("getDescriptor").invoke(struct); + ResultSetMetaData metaData = + (ResultSetMetaData) descriptor.getClass().getMethod("getMetaData").invoke(descriptor); + for (int i = 1; i <= metaData.getColumnCount() && (i - 1) < attributes.length; i++) { + attributeMap.put(metaData.getColumnName(i), attributes[i - 1]); + } + } catch (SQLException | NoSuchMethodException e) { + throw new SQLException(String.format("Failed to retrieve attribute metadata for Oracle STRUCT schema '%s': %s", + schema.getRecordName(), 
e.getMessage()), e); + } catch (InvocationTargetException | IllegalAccessException e) { + throw new SQLException(String.format("Unable to retrieve attribute metadata for Oracle STRUCT schema '%s'. " + + "Ensure the Oracle JDBC driver supports JDBC StructDescriptor metadata.", + schema.getRecordName()), e); + } + return attributeMap; + } + /** * Get the scale set in Non-nullable schema associated with the schema * */ diff --git a/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceSchemaReader.java b/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceSchemaReader.java index b23dfa031..43e18e5b8 100644 --- a/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceSchemaReader.java +++ b/oracle-plugin/src/main/java/io/cdap/plugin/oracle/OracleSourceSchemaReader.java @@ -22,9 +22,16 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Types; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Set; import javax.annotation.Nullable; @@ -45,6 +52,46 @@ public class OracleSourceSchemaReader extends CommonSchemaReader { public static final int LONG = -1; public static final int LONG_RAW = -4; + /** + * Maps Oracle string data type inside UDT to their corresponding java.sql.Types integer constants + */ + private static final Map DATA_TYPE_MAP = new HashMap<>(); + static { + DATA_TYPE_MAP.put("TIMESTAMP WITH LOCAL TZ", TIMESTAMP_LTZ); + DATA_TYPE_MAP.put("TIMESTAMP WITH TZ", TIMESTAMP_TZ); + DATA_TYPE_MAP.put("TIMESTAMP", Types.TIMESTAMP); + DATA_TYPE_MAP.put("DATE", Types.TIMESTAMP); + DATA_TYPE_MAP.put("FLOAT", Types.DOUBLE); + DATA_TYPE_MAP.put("BINARY_FLOAT", BINARY_FLOAT); + DATA_TYPE_MAP.put("REAL", Types.DOUBLE); + DATA_TYPE_MAP.put("BINARY_DOUBLE", BINARY_DOUBLE); + DATA_TYPE_MAP.put("DOUBLE", 
Types.DOUBLE); + DATA_TYPE_MAP.put("BFILE", BFILE); + DATA_TYPE_MAP.put("RAW", LONG_RAW); + DATA_TYPE_MAP.put("LONG RAW", LONG_RAW); + DATA_TYPE_MAP.put("LONG", LONG); + DATA_TYPE_MAP.put("INTERVAL DAY TO SECOND", INTERVAL_DS); + DATA_TYPE_MAP.put("INTERVAL YEAR TO MONTH", INTERVAL_YM); + DATA_TYPE_MAP.put("XMLTYPE", Types.SQLXML); + DATA_TYPE_MAP.put("ARRAY", Types.ARRAY); + DATA_TYPE_MAP.put("ANYDATA", Types.JAVA_OBJECT); + DATA_TYPE_MAP.put("OTHER", Types.OTHER); + DATA_TYPE_MAP.put("NUMBER", Types.NUMERIC); + DATA_TYPE_MAP.put("DECIMAL", Types.DECIMAL); + DATA_TYPE_MAP.put("INTEGER", Types.INTEGER); + DATA_TYPE_MAP.put("ROWID", Types.ROWID); + DATA_TYPE_MAP.put("UROWID", Types.ROWID); + DATA_TYPE_MAP.put("BLOB", Types.BLOB); + DATA_TYPE_MAP.put("CLOB", Types.CLOB); + DATA_TYPE_MAP.put("NCLOB", Types.NCLOB); + DATA_TYPE_MAP.put("VARCHAR2", Types.VARCHAR); + DATA_TYPE_MAP.put("VARCHAR", Types.VARCHAR); + DATA_TYPE_MAP.put("CHAR", Types.CHAR); + DATA_TYPE_MAP.put("CHAR2", Types.CHAR); + DATA_TYPE_MAP.put("NCHAR", Types.NCHAR); + DATA_TYPE_MAP.put("NVARCHAR2", Types.NVARCHAR); + } + /** * Logger instance for Oracle Schema reader. */ @@ -71,6 +118,7 @@ public class OracleSourceSchemaReader extends CommonSchemaReader { private final Boolean isPrecisionlessNumAsDecimal; private final Boolean isTimestampLtzFieldTimestamp; private final Boolean isXmlTypeEnabled; + private Connection connection; public OracleSourceSchemaReader() { this(null, false, false, false, false); @@ -88,6 +136,18 @@ public OracleSourceSchemaReader(@Nullable String sessionID, boolean isTimestampO @Override public Schema getSchema(ResultSetMetaData metadata, int index) throws SQLException { int sqlType = metadata.getColumnType(index); + String owner = (metadata.getColumnTypeName(index) != null + && metadata.getColumnTypeName(index).contains(".")) ? 
metadata.getColumnTypeName(index) + .substring(0, metadata.getColumnTypeName(index).lastIndexOf('.')) : null; + + return getSchemaMapping(sqlType, metadata.getColumnClassName(index), metadata.getPrecision(index), + metadata.getScale(index), metadata.getColumnName(index), metadata.getColumnTypeName(index), + metadata.isSigned(index), owner, 0); + } + + public Schema getSchemaMapping(int sqlType, String columnClassName, int columnPrecision, + int columnScale, String columnName, String columnTypeName, + boolean isSigned, String owner, int nestingLevel) throws SQLException { switch (sqlType) { case TIMESTAMP_TZ: @@ -95,7 +155,8 @@ public Schema getSchema(ResultSetMetaData metadata, int index) throws SQLExcepti case TIMESTAMP_LTZ: return getTimestampLtzSchema(); case Types.TIMESTAMP: - return isTimestampOldBehavior ? super.getSchema(metadata, index) : Schema.of(Schema.LogicalType.DATETIME); + return isTimestampOldBehavior ? super.getSchema(columnTypeName, sqlType, + columnPrecision, columnScale, columnName, isSigned) : Schema.of(Schema.LogicalType.DATETIME); case BINARY_FLOAT: return Schema.of(Schema.Type.FLOAT); case BINARY_DOUBLE: @@ -109,15 +170,16 @@ public Schema getSchema(ResultSetMetaData metadata, int index) throws SQLExcepti return Schema.of(Schema.Type.STRING); case Types.SQLXML: // Enabling XML type support for DTS connectors only as it is not in working state in CDAP plugin. - return isXmlTypeEnabled ? Schema.of(Schema.Type.STRING) : super.getSchema(metadata, index); + return isXmlTypeEnabled ? 
Schema.of(Schema.Type.STRING) : super.getSchema(columnTypeName, + sqlType, columnPrecision, columnScale, columnName, isSigned); case Types.NUMERIC: case Types.DECIMAL: // FLOAT and REAL are returned as java.sql.Types.NUMERIC but with value that is a java.lang.Double - if (Double.class.getTypeName().equals(metadata.getColumnClassName(index))) { + if (Double.class.getTypeName().equals(columnClassName)) { return Schema.of(Schema.Type.DOUBLE); } else { - int precision = metadata.getPrecision(index); // total number of digits - int scale = metadata.getScale(index); // digits after the decimal point + int precision = columnPrecision; // total number of digits + int scale = columnScale; // digits after the decimal point // For a Number type without specified precision and scale, precision will be 0 and scale will be -127 if (precision == 0) { // reference : https://docs.oracle.com/cd/B28359_01/server.111/b28318/datatype.htm#CNCPT1832 @@ -128,23 +190,91 @@ public Schema getSchema(ResultSetMetaData metadata, int index) throws SQLExcepti + "there may be a precision loss while running the pipeline. " + "Please define an output precision and scale for field '%s' to avoid " + "precision loss.", - metadata.getColumnTypeName(index), - metadata.getColumnName(index))); + columnTypeName, columnName)); return Schema.decimalOf(precision, scale); } else { LOG.warn(String.format("Field '%s' is a %s type without precision and scale, " + "converting into STRING type to avoid any precision loss.", - metadata.getColumnName(index), - metadata.getColumnTypeName(index), - metadata.getColumnName(index))); + columnName, columnTypeName, columnName)); return Schema.of(Schema.Type.STRING); } } return Schema.decimalOf(precision, scale); } + case Types.STRUCT: + if (connection == null) { + throw new SQLException("Cannot resolve STRUCT schema without a database connection. 
" + + "Use getSchemaFields(ResultSet) to enable STRUCT type resolution."); + } + if (nestingLevel >= 4) { + throw new IllegalArgumentException(String.format("Cannot resolve STRUCT schema for attribute %s with " + + "nested structure depth more than 4.", columnName)); + } + return getStructSchema(connection, columnTypeName, owner, nestingLevel); default: - return super.getSchema(metadata, index); + return super.getSchema(columnTypeName, sqlType, columnPrecision, columnScale, columnName, isSigned); + } + } + + @Override + public List getSchemaFields(ResultSet resultSet) throws SQLException { + this.connection = resultSet.getStatement().getConnection(); + return super.getSchemaFields(resultSet); + } + + /** + * Builds a CDAP RECORD schema for an Oracle STRUCT type by querying the + * database metadata + * for the type's attributes. + * + * @param connection the database connection + * @param typeName the Oracle type name (e.g., "ADDRESS_TYPE") + * @param owner the Owner of the user-defined data type + * @param level the level of nesting of the user-defined data type + * @return a CDAP RECORD schema with fields corresponding to the STRUCT's + * attributes + */ + private Schema getStructSchema(Connection connection, String typeName, String owner, int level) throws SQLException { + List fields = new ArrayList<>(); + String sql = "SELECT * FROM ALL_TYPE_ATTRS WHERE TYPE_NAME = ? AND OWNER = ? 
ORDER BY ATTR_NO"; + + try (PreparedStatement stmt = connection.prepareStatement(sql)) { + stmt.setString(1, typeName.substring(typeName.lastIndexOf('.') + 1)); + stmt.setString(2, owner); + + try (ResultSet attrRs = stmt.executeQuery()) { + while (attrRs.next()) { + String attrName = attrRs.getString("ATTR_NAME"); + String attrTypeName = attrRs.getString("ATTR_TYPE_NAME"); + int attrSize = attrRs.getInt("PRECISION"); + int attrScale = attrRs.getInt("SCALE"); + Integer sqlType = DATA_TYPE_MAP.getOrDefault(attrTypeName, null); + + int nextLevel = level; + if (sqlType == null) { + owner = attrRs.getString("ATTR_TYPE_OWNER"); + if (owner == null || owner.isEmpty()) { + throw new SQLException(String.format("Attribute '%s' is not a primitive type, but it lacks a type " + + "owner. Therefore, it cannot be resolved as a STRUCT type. ", attrName)); + } + sqlType = Types.STRUCT; + nextLevel = level + 1; + } + Schema attrSchema = getSchemaMapping(sqlType, null, attrSize, + attrScale, attrName, attrTypeName, true, owner, nextLevel); + fields.add(Schema.Field.of(attrName, attrSchema)); + } + } } + + if (fields.isEmpty()) { + throw new SQLException(String.format( + "No attributes found for Oracle STRUCT type '%s'. 
" + + "Ensure the type exists and is accessible.", + typeName)); + } + + return Schema.recordOf(typeName, fields); } private Schema getTimestampLtzSchema() { diff --git a/oracle-plugin/src/test/java/io/cdap/plugin/oracle/OracleSchemaReaderTest.java b/oracle-plugin/src/test/java/io/cdap/plugin/oracle/OracleSchemaReaderTest.java index 586ca1141..14cd1fe70 100644 --- a/oracle-plugin/src/test/java/io/cdap/plugin/oracle/OracleSchemaReaderTest.java +++ b/oracle-plugin/src/test/java/io/cdap/plugin/oracle/OracleSchemaReaderTest.java @@ -25,10 +25,14 @@ import org.mockito.Mockito; import org.mockito.junit.MockitoJUnitRunner; +import java.sql.Connection; +import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; +import java.sql.Statement; import java.sql.Types; +import java.util.Arrays; import java.util.List; public class OracleSchemaReaderTest { @@ -39,6 +43,12 @@ public void getSchema_timestampLTZFieldTrue_returnTimestamp() throws SQLExceptio ResultSet resultSet = Mockito.mock(ResultSet.class); ResultSetMetaData metadata = Mockito.mock(ResultSetMetaData.class); + Statement statement = Mockito.mock(Statement.class); + Connection connection = Mockito.mock(Connection.class); + + Mockito.when(resultSet.getMetaData()).thenReturn(metadata); + Mockito.when(resultSet.getStatement()).thenReturn(statement); + Mockito.when(statement.getConnection()).thenReturn(connection); Mockito.when(resultSet.getMetaData()).thenReturn(metadata); @@ -70,9 +80,12 @@ public void getSchema_timestampLTZFieldFalse_returnDatetime() throws SQLExceptio ResultSet resultSet = Mockito.mock(ResultSet.class); ResultSetMetaData metadata = Mockito.mock(ResultSetMetaData.class); + Statement statement = Mockito.mock(Statement.class); + Connection connection = Mockito.mock(Connection.class); Mockito.when(resultSet.getMetaData()).thenReturn(metadata); - + Mockito.when(resultSet.getStatement()).thenReturn(statement); + 
Mockito.when(statement.getConnection()).thenReturn(connection); Mockito.when(metadata.getColumnCount()).thenReturn(2); // -101 is for TIMESTAMP_TZ Mockito.when(metadata.getColumnType(1)).thenReturn(-101); @@ -94,15 +107,155 @@ public void getSchema_timestampLTZFieldFalse_returnDatetime() throws SQLExceptio Assert.assertEquals(expectedSchemaFields.get(1).getSchema(), actualSchemaFields.get(1).getSchema()); } + @Test + public void getSchemaFields_structType_returnRecord() throws SQLException { + OracleSourceSchemaReader schemaReader = new OracleSourceSchemaReader(); + ResultSet resultSet = Mockito.mock(ResultSet.class); + ResultSetMetaData metadata = Mockito.mock(ResultSetMetaData.class); + Statement statement = Mockito.mock(Statement.class); + Connection connection = Mockito.mock(Connection.class); + PreparedStatement stmt = Mockito.mock(PreparedStatement.class); + ResultSet attrRs = Mockito.mock(ResultSet.class); + Mockito.when(resultSet.getMetaData()).thenReturn(metadata); + Mockito.when(resultSet.getStatement()).thenReturn(statement); + Mockito.when(statement.getConnection()).thenReturn(connection); + Mockito.when(connection.prepareStatement(Mockito.anyString())).thenReturn(stmt); + Mockito.when(stmt.executeQuery()).thenReturn(attrRs); + Mockito.when(metadata.getColumnCount()).thenReturn(1); + Mockito.when(metadata.getColumnType(1)).thenReturn(Types.STRUCT); + Mockito.when(metadata.getColumnName(1)).thenReturn("address"); + Mockito.when(metadata.getColumnTypeName(1)).thenReturn("CS_ITN.ADDRESS_TYPE"); + Mockito.when(metadata.getSchemaName(1)).thenReturn("TEST_SCHEMA"); + + Boolean[] nextReturns = new Boolean[31]; + Arrays.fill(nextReturns, 0, 30, true); + nextReturns[30] = false; + Mockito.when(attrRs.next()).thenReturn(true, Arrays.copyOfRange(nextReturns, 1, 31)); + + Mockito.when(attrRs.getString("ATTR_NAME")).thenReturn( + "ATTR_VARCHAR2", "ATTR_VARCHAR", "ATTR_CHAR", "ATTR_CHAR2", "ATTR_NCHAR", + "ATTR_NVARCHAR2", "ATTR_CLOB", "ATTR_NCLOB", "ATTR_LONG", 
"ATTR_ROWID", + "ATTR_UROWID", "ATTR_NUMBER_PREC", "ATTR_NUMBER_NOPREC", "ATTR_DECIMAL", + "ATTR_INTEGER", "ATTR_FLOAT", "ATTR_REAL", "ATTR_DOUBLE", "ATTR_BINARY_FLOAT", + "ATTR_BINARY_DOUBLE", "ATTR_DATE", "ATTR_TIMESTAMP", "ATTR_TIMESTAMP_TZ", + "ATTR_TIMESTAMP_LTZ", "ATTR_INTERVAL_DS", "ATTR_INTERVAL_YM", "ATTR_BLOB", + "ATTR_RAW", "ATTR_LONG_RAW", "ATTR_BFILE" + ); + Mockito.when(attrRs.getString("ATTR_TYPE_NAME")).thenReturn( + "VARCHAR2", "VARCHAR", "CHAR", "CHAR2", "NCHAR", + "NVARCHAR2", "CLOB", "NCLOB", "LONG", "ROWID", + "UROWID", "NUMBER", "NUMBER", "DECIMAL", + "INTEGER", "FLOAT", "REAL", "DOUBLE", "BINARY_FLOAT", + "BINARY_DOUBLE", "DATE", "TIMESTAMP", "TIMESTAMP WITH TZ", + "TIMESTAMP WITH LOCAL TZ", "INTERVAL DAY TO SECOND", "INTERVAL YEAR TO MONTH", "BLOB", + "RAW", "LONG RAW", "BFILE" + ); + Mockito.when(attrRs.getInt("PRECISION")).thenReturn( + 50, 50, 10, 10, 10, + 50, 0, 0, 0, 0, + 0, 10, 0, 8, + 10, 10, 10, 10, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 100, 0, 0 + ); + Mockito.when(attrRs.getInt("SCALE")).thenReturn( + 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, + 0, 2, 0, 2, + 0, 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0 + ); + + List actualFields = schemaReader.getSchemaFields(resultSet); + + Schema.Field addressField = actualFields.get(0); + Schema addressSchema = addressField.getSchema().isNullable() + ? 
addressField.getSchema().getNonNullable() : addressField.getSchema(); + List structFields = addressSchema.getFields(); + Assert.assertEquals(1, actualFields.size()); + Assert.assertEquals("address", addressField.getName()); + Assert.assertEquals(Schema.Type.RECORD, addressSchema.getType()); + Assert.assertEquals("CS_ITN.ADDRESS_TYPE", addressSchema.getRecordName()); + Assert.assertEquals(30, structFields.size()); + + Assert.assertEquals("ATTR_VARCHAR2", structFields.get(0).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(0).getSchema()); + Assert.assertEquals("ATTR_VARCHAR", structFields.get(1).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(1).getSchema()); + Assert.assertEquals("ATTR_CHAR", structFields.get(2).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(2).getSchema()); + Assert.assertEquals("ATTR_CHAR2", structFields.get(3).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(3).getSchema()); + Assert.assertEquals("ATTR_NCHAR", structFields.get(4).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(4).getSchema()); + Assert.assertEquals("ATTR_NVARCHAR2", structFields.get(5).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(5).getSchema()); + Assert.assertEquals("ATTR_CLOB", structFields.get(6).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(6).getSchema()); + Assert.assertEquals("ATTR_NCLOB", structFields.get(7).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(7).getSchema()); + Assert.assertEquals("ATTR_LONG", structFields.get(8).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(8).getSchema()); + Assert.assertEquals("ATTR_ROWID", structFields.get(9).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(9).getSchema()); + 
Assert.assertEquals("ATTR_UROWID", structFields.get(10).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(10).getSchema()); + Assert.assertEquals("ATTR_NUMBER_PREC", structFields.get(11).getName()); + Assert.assertEquals(Schema.decimalOf(10, 2), structFields.get(11).getSchema()); + Assert.assertEquals("ATTR_NUMBER_NOPREC", structFields.get(12).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(12).getSchema()); + Assert.assertEquals("ATTR_DECIMAL", structFields.get(13).getName()); + Assert.assertEquals(Schema.decimalOf(8, 2), structFields.get(13).getSchema()); + Assert.assertEquals("ATTR_INTEGER", structFields.get(14).getName()); + Assert.assertEquals(Schema.of(Schema.Type.INT), structFields.get(14).getSchema()); + Assert.assertEquals("ATTR_FLOAT", structFields.get(15).getName()); + Assert.assertEquals(Schema.of(Schema.Type.DOUBLE), structFields.get(15).getSchema()); + Assert.assertEquals("ATTR_REAL", structFields.get(16).getName()); + Assert.assertEquals(Schema.of(Schema.Type.DOUBLE), structFields.get(16).getSchema()); + Assert.assertEquals("ATTR_DOUBLE", structFields.get(17).getName()); + Assert.assertEquals(Schema.of(Schema.Type.DOUBLE), structFields.get(17).getSchema()); + Assert.assertEquals("ATTR_BINARY_FLOAT", structFields.get(18).getName()); + Assert.assertEquals(Schema.of(Schema.Type.FLOAT), structFields.get(18).getSchema()); + Assert.assertEquals("ATTR_BINARY_DOUBLE", structFields.get(19).getName()); + Assert.assertEquals(Schema.of(Schema.Type.DOUBLE), structFields.get(19).getSchema()); + Assert.assertEquals("ATTR_DATE", structFields.get(20).getName()); + Assert.assertEquals(Schema.of(Schema.LogicalType.DATETIME), structFields.get(20).getSchema()); + Assert.assertEquals("ATTR_TIMESTAMP", structFields.get(21).getName()); + Assert.assertEquals(Schema.of(Schema.LogicalType.DATETIME), structFields.get(21).getSchema()); + Assert.assertEquals("ATTR_TIMESTAMP_TZ", structFields.get(22).getName()); + 
Assert.assertEquals(Schema.of(Schema.LogicalType.TIMESTAMP_MICROS), + structFields.get(22).getSchema()); + Assert.assertEquals("ATTR_TIMESTAMP_LTZ", structFields.get(23).getName()); + Assert.assertEquals(Schema.of(Schema.LogicalType.DATETIME), structFields.get(23).getSchema()); + Assert.assertEquals("ATTR_INTERVAL_DS", structFields.get(24).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(24).getSchema()); + Assert.assertEquals("ATTR_INTERVAL_YM", structFields.get(25).getName()); + Assert.assertEquals(Schema.of(Schema.Type.STRING), structFields.get(25).getSchema()); + Assert.assertEquals("ATTR_BLOB", structFields.get(26).getName()); + Assert.assertEquals(Schema.of(Schema.Type.BYTES), structFields.get(26).getSchema()); + Assert.assertEquals("ATTR_RAW", structFields.get(27).getName()); + Assert.assertEquals(Schema.of(Schema.Type.BYTES), structFields.get(27).getSchema()); + Assert.assertEquals("ATTR_LONG_RAW", structFields.get(28).getName()); + Assert.assertEquals(Schema.of(Schema.Type.BYTES), structFields.get(28).getSchema()); + Assert.assertEquals("ATTR_BFILE", structFields.get(29).getName()); + Assert.assertEquals(Schema.of(Schema.Type.BYTES), structFields.get(29).getSchema()); + } + @Test public void getSchema_xmlField_returnString() throws SQLException { OracleSourceSchemaReader schemaReader = new OracleSourceSchemaReader(null, false, false, false, true); ResultSet resultSet = Mockito.mock(ResultSet.class); ResultSetMetaData metadata = Mockito.mock(ResultSetMetaData.class); + Connection connection = Mockito.mock(Connection.class); + Statement statement = Mockito.mock(Statement.class); Mockito.when(resultSet.getMetaData()).thenReturn(metadata); Mockito.when(metadata.getColumnCount()).thenReturn(1); Mockito.when(metadata.getColumnType(1)).thenReturn(Types.SQLXML); Mockito.when(metadata.getColumnName(1)).thenReturn("xmlData"); + Mockito.when(resultSet.getStatement()).thenReturn(statement); + 
Mockito.when(statement.getConnection()).thenReturn(connection); List actualSchemaFields = schemaReader.getSchemaFields(resultSet); @@ -118,12 +271,173 @@ public void getSchema_xmlFieldDisabled_throwsProgramFailureException() throws SQ false, false, false, false); ResultSet resultSet = Mockito.mock(ResultSet.class); ResultSetMetaData metadata = Mockito.mock(ResultSetMetaData.class); + Connection connection = Mockito.mock(Connection.class); + Statement statement = Mockito.mock(Statement.class); Mockito.when(resultSet.getMetaData()).thenReturn(metadata); Mockito.when(metadata.getColumnCount()).thenReturn(1); Mockito.when(metadata.getColumnType(1)).thenReturn(Types.SQLXML); Mockito.when(metadata.getColumnName(1)).thenReturn("xmlData"); + Mockito.when(resultSet.getStatement()).thenReturn(statement); + Mockito.when(statement.getConnection()).thenReturn(connection); Assert.assertThrows(ProgramFailureException.class, () -> schemaReader.getSchemaFields(resultSet)); } + + /** Verifies that getSchemaFields throws ProgramFailureException for a single STRUCT column ("complex_payload", type name "CS_ITN.ANYDATA_TYPE", schema "TEST_SCHEMA") when the mocked attribute result set reports an attribute of type "ANYDATA". The result set reached through the mocked prepareStatement/executeQuery chain yields two rows; successive reads return VALID_ID (NUMBER, precision 10, scale 0) and then UNSUPPORTED_DATA (ANYDATA, precision 0, scale 0). */ @Test + public void getSchemaFields_structWithUnsupportedAttributeType_throwsException() throws SQLException { + OracleSourceSchemaReader schemaReader = new OracleSourceSchemaReader(); + ResultSet resultSet = Mockito.mock(ResultSet.class); + ResultSetMetaData metadata = Mockito.mock(ResultSetMetaData.class); + Statement statement = Mockito.mock(Statement.class); + Connection connection = Mockito.mock(Connection.class); + PreparedStatement stmt = Mockito.mock(PreparedStatement.class); + ResultSet attrRs = Mockito.mock(ResultSet.class); + Mockito.when(resultSet.getMetaData()).thenReturn(metadata); + Mockito.when(resultSet.getStatement()).thenReturn(statement); + Mockito.when(statement.getConnection()).thenReturn(connection); + Mockito.when(connection.prepareStatement(Mockito.anyString())).thenReturn(stmt); + Mockito.when(stmt.executeQuery()).thenReturn(attrRs); + Mockito.when(metadata.getColumnCount()).thenReturn(1); + Mockito.when(metadata.getColumnType(1)).thenReturn(Types.STRUCT); + 
Mockito.when(metadata.getColumnName(1)).thenReturn("complex_payload"); + Mockito.when(metadata.getColumnTypeName(1)).thenReturn("CS_ITN.ANYDATA_TYPE"); + Mockito.when(metadata.getSchemaName(1)).thenReturn("TEST_SCHEMA"); + Mockito.when(attrRs.next()).thenReturn(true, true, false); + Mockito.when(attrRs.getString("ATTR_NAME")).thenReturn("VALID_ID", "UNSUPPORTED_DATA"); + Mockito.when(attrRs.getString("ATTR_TYPE_NAME")).thenReturn("NUMBER", "ANYDATA"); + Mockito.when(attrRs.getInt("PRECISION")).thenReturn(10, 0); + Mockito.when(attrRs.getInt("SCALE")).thenReturn(0, 0); + + Assert.assertThrows(ProgramFailureException.class, () -> schemaReader.getSchemaFields(resultSet)); + + } + + /** Verifies that a STRUCT column ("payload", type name "TEST.STRUCT_L0") whose attributes are themselves struct types is returned as nested RECORD schemas. Four mocked statements each produce a one-row attribute result set: SUB1 (STRUCT_L1), SUB2 (STRUCT_L2), SUB3 (STRUCT_L3), each with ATTR_TYPE_OWNER "TEST", and finally ID (VARCHAR2, precision 50, scale 0). The assertions unwrap a nullable schema at every level before checking for Schema.Type.RECORD, and finish by checking that ID has type STRING. */ @Test + public void getSchemaFields_validNestedStructLevel_returnRecord() throws SQLException { + OracleSourceSchemaReader schemaReader = new OracleSourceSchemaReader(); + ResultSet resultSet = Mockito.mock(ResultSet.class); + ResultSetMetaData metadata = Mockito.mock(ResultSetMetaData.class); + Statement statement = Mockito.mock(Statement.class); + Connection connection = Mockito.mock(Connection.class); + PreparedStatement stmt0 = Mockito.mock(PreparedStatement.class); + PreparedStatement stmt1 = Mockito.mock(PreparedStatement.class); + PreparedStatement stmt2 = Mockito.mock(PreparedStatement.class); + PreparedStatement stmt3 = Mockito.mock(PreparedStatement.class); + ResultSet attrRs0 = Mockito.mock(ResultSet.class); + ResultSet attrRs1 = Mockito.mock(ResultSet.class); + ResultSet attrRs2 = Mockito.mock(ResultSet.class); + ResultSet attrRs3 = Mockito.mock(ResultSet.class); + + Mockito.when(resultSet.getMetaData()).thenReturn(metadata); + Mockito.when(resultSet.getStatement()).thenReturn(statement); + Mockito.when(statement.getConnection()).thenReturn(connection); + /* consecutive stubbing: each successive prepareStatement call receives the next statement in this order */ Mockito.when(connection.prepareStatement(Mockito.anyString())) + .thenReturn(stmt0, stmt1, stmt2, stmt3); + Mockito.when(stmt0.executeQuery()).thenReturn(attrRs0); + Mockito.when(stmt1.executeQuery()).thenReturn(attrRs1); + 
Mockito.when(stmt2.executeQuery()).thenReturn(attrRs2); + Mockito.when(stmt3.executeQuery()).thenReturn(attrRs3); + + Mockito.when(metadata.getColumnCount()).thenReturn(1); + Mockito.when(metadata.getColumnType(1)).thenReturn(Types.STRUCT); + Mockito.when(metadata.getColumnName(1)).thenReturn("payload"); + Mockito.when(metadata.getColumnTypeName(1)).thenReturn("TEST.STRUCT_L0"); + Mockito.when(metadata.getSchemaName(1)).thenReturn("TEST"); + + Mockito.when(attrRs0.next()).thenReturn(true, false); + Mockito.when(attrRs0.getString("ATTR_NAME")).thenReturn("SUB1"); + Mockito.when(attrRs0.getString("ATTR_TYPE_NAME")).thenReturn("STRUCT_L1"); + Mockito.when(attrRs0.getString("ATTR_TYPE_OWNER")).thenReturn("TEST"); + + Mockito.when(attrRs1.next()).thenReturn(true, false); + Mockito.when(attrRs1.getString("ATTR_NAME")).thenReturn("SUB2"); + Mockito.when(attrRs1.getString("ATTR_TYPE_NAME")).thenReturn("STRUCT_L2"); + Mockito.when(attrRs1.getString("ATTR_TYPE_OWNER")).thenReturn("TEST"); + + Mockito.when(attrRs2.next()).thenReturn(true, false); + Mockito.when(attrRs2.getString("ATTR_NAME")).thenReturn("SUB3"); + Mockito.when(attrRs2.getString("ATTR_TYPE_NAME")).thenReturn("STRUCT_L3"); + Mockito.when(attrRs2.getString("ATTR_TYPE_OWNER")).thenReturn("TEST"); + + Mockito.when(attrRs3.next()).thenReturn(true, false); + Mockito.when(attrRs3.getString("ATTR_NAME")).thenReturn("ID"); + Mockito.when(attrRs3.getString("ATTR_TYPE_NAME")).thenReturn("VARCHAR2"); + Mockito.when(attrRs3.getInt("PRECISION")).thenReturn(50); + Mockito.when(attrRs3.getInt("SCALE")).thenReturn(0); + + /* NOTE(review): actualFields is declared as a raw List, yet getSchema() is called on its elements below; the type argument looks lost (presumably a list of Schema.Field) - confirm against the real file */ List actualFields = schemaReader.getSchemaFields(resultSet); + Assert.assertEquals(1, actualFields.size()); + /* strip the nullable wrapper, when present, so the RECORD type check sees the underlying schema */ Schema l0Schema = actualFields.get(0).getSchema().isNullable() + ? actualFields.get(0).getSchema().getNonNullable() : actualFields.get(0).getSchema(); + Assert.assertEquals(Schema.Type.RECORD, l0Schema.getType()); + Schema l1Schema = l0Schema.getField("SUB1").getSchema().isNullable() + ? 
l0Schema.getField("SUB1").getSchema().getNonNullable() + : l0Schema.getField("SUB1").getSchema(); + Assert.assertEquals(Schema.Type.RECORD, l1Schema.getType()); + Schema l2Schema = l1Schema.getField("SUB2").getSchema().isNullable() + ? l1Schema.getField("SUB2").getSchema().getNonNullable() + : l1Schema.getField("SUB2").getSchema(); + Assert.assertEquals(Schema.Type.RECORD, l2Schema.getType()); + Schema l3Schema = l2Schema.getField("SUB3").getSchema().isNullable() + ? l2Schema.getField("SUB3").getSchema().getNonNullable() + : l2Schema.getField("SUB3").getSchema(); + Assert.assertEquals(Schema.Type.RECORD, l3Schema.getType()); + Assert.assertEquals(Schema.Type.STRING, l3Schema.getField("ID").getSchema().getType()); + } + + /** Verifies that getSchemaFields throws IllegalArgumentException for a STRUCT column ("payload", type name "TEST.STRUCT_L0") when the mocked attribute result sets chain SUB1 (STRUCT_L1), SUB2 (STRUCT_L2), SUB3 (STRUCT_L3) and SUB4 (STRUCT_L4), all with ATTR_TYPE_OWNER "TEST". The setup matches getSchemaFields_validNestedStructLevel_returnRecord except that the fourth result set reports SUB4 of type STRUCT_L4 instead of a VARCHAR2 attribute. NOTE(review): the nesting limit the test name refers to is defined outside this view - confirm the value in OracleSourceSchemaReader. */ @Test + public void getSchemaFields_exceedsMaxNestedStructLevel_throwsException() throws SQLException { + OracleSourceSchemaReader schemaReader = new OracleSourceSchemaReader(); + ResultSet resultSet = Mockito.mock(ResultSet.class); + ResultSetMetaData metadata = Mockito.mock(ResultSetMetaData.class); + Statement statement = Mockito.mock(Statement.class); + Connection connection = Mockito.mock(Connection.class); + PreparedStatement stmt0 = Mockito.mock(PreparedStatement.class); + PreparedStatement stmt1 = Mockito.mock(PreparedStatement.class); + PreparedStatement stmt2 = Mockito.mock(PreparedStatement.class); + PreparedStatement stmt3 = Mockito.mock(PreparedStatement.class); + ResultSet attrRs0 = Mockito.mock(ResultSet.class); + ResultSet attrRs1 = Mockito.mock(ResultSet.class); + ResultSet attrRs2 = Mockito.mock(ResultSet.class); + ResultSet attrRs3 = Mockito.mock(ResultSet.class); + + Mockito.when(resultSet.getMetaData()).thenReturn(metadata); + Mockito.when(resultSet.getStatement()).thenReturn(statement); + Mockito.when(statement.getConnection()).thenReturn(connection); + /* consecutive stubbing: each successive prepareStatement call receives the next statement in this order */ Mockito.when(connection.prepareStatement(Mockito.anyString())) + .thenReturn(stmt0, stmt1, stmt2, stmt3); + Mockito.when(stmt0.executeQuery()).thenReturn(attrRs0); + 
Mockito.when(stmt1.executeQuery()).thenReturn(attrRs1); + Mockito.when(stmt2.executeQuery()).thenReturn(attrRs2); + Mockito.when(stmt3.executeQuery()).thenReturn(attrRs3); + + Mockito.when(metadata.getColumnCount()).thenReturn(1); + Mockito.when(metadata.getColumnType(1)).thenReturn(Types.STRUCT); + Mockito.when(metadata.getColumnName(1)).thenReturn("payload"); + Mockito.when(metadata.getColumnTypeName(1)).thenReturn("TEST.STRUCT_L0"); + Mockito.when(metadata.getSchemaName(1)).thenReturn("TEST"); + + Mockito.when(attrRs0.next()).thenReturn(true, false); + Mockito.when(attrRs0.getString("ATTR_NAME")).thenReturn("SUB1"); + Mockito.when(attrRs0.getString("ATTR_TYPE_NAME")).thenReturn("STRUCT_L1"); + Mockito.when(attrRs0.getString("ATTR_TYPE_OWNER")).thenReturn("TEST"); + + Mockito.when(attrRs1.next()).thenReturn(true, false); + Mockito.when(attrRs1.getString("ATTR_NAME")).thenReturn("SUB2"); + Mockito.when(attrRs1.getString("ATTR_TYPE_NAME")).thenReturn("STRUCT_L2"); + Mockito.when(attrRs1.getString("ATTR_TYPE_OWNER")).thenReturn("TEST"); + + Mockito.when(attrRs2.next()).thenReturn(true, false); + Mockito.when(attrRs2.getString("ATTR_NAME")).thenReturn("SUB3"); + Mockito.when(attrRs2.getString("ATTR_TYPE_NAME")).thenReturn("STRUCT_L3"); + Mockito.when(attrRs2.getString("ATTR_TYPE_OWNER")).thenReturn("TEST"); + + Mockito.when(attrRs3.next()).thenReturn(true, false); + Mockito.when(attrRs3.getString("ATTR_NAME")).thenReturn("SUB4"); + Mockito.when(attrRs3.getString("ATTR_TYPE_NAME")).thenReturn("STRUCT_L4"); + Mockito.when(attrRs3.getString("ATTR_TYPE_OWNER")).thenReturn("TEST"); + + Assert.assertThrows(IllegalArgumentException.class, () -> schemaReader.getSchemaFields(resultSet)); + } }