Saturday, May 24, 2025

Python: Design an in-memory database that can also handle SQL-like queries


 Overview:
This file implements a simple in-memory database using Python dictionaries.
The InMemoryDB class provides basic methods to create and drop tables; to insert, select, update, and delete rows by key; to list tables and keys; and to run simple SQL-like queries.

Note: No regular expressions are used for SQL parsing; only simple string operations are used.

Pros:
- Simple and easy to use.
- Fast operations for small to moderate data sizes.
- No external dependencies.

Cons:
- Data is lost when the program exits (not persistent).
- Not suitable for large datasets (limited by system memory).
- No support for concurrent access or transactions.
- No advanced querying or indexing features. 



class InMemoryDB:
    """
    A small in-memory database made of named tables.

    Each table is a dictionary mapping a row key to a row (a dictionary of
    column values). A table may carry an optional schema ({column: type}) and
    an optional primary-key column; both are enforced by insert(). Rows can be
    manipulated through the direct methods or through query(), which parses a
    small SQL-like language using plain string operations only (no regular
    expressions).
    """

    # Type names accepted by CREATE TABLE, mapped to the Python types that
    # insert() and update() enforce.
    _TYPE_NAMES = {'str': str, 'int': int, 'float': float}

    def __init__(self):
        # Table storage: {table_name: {key: row}}
        self._tables = {}
        # Table schemas: {table_name: {column: type}}
        self._schemas = {}
        # Primary key column per table: {table_name: column or None}
        self._primary_keys = {}

    def create_table(self, table_name, schema=None, primary_key=None):
        """
        Create a new table with an optional schema and primary key.

        Does nothing if a table with that name already exists.
        schema: dict mapping column names to Python types (e.g., {'name': str, 'age': int})
        primary_key: column name to use as primary key
        Sample:
            table_name = 'users'
            schema = {'id': int, 'name': str}
            primary_key = 'id'
        """
        if table_name in self._tables:
            return
        self._tables[table_name] = {}
        # A falsy schema / primary key is normalised to "no schema" / "no key".
        self._schemas[table_name] = schema if schema else {}
        self._primary_keys[table_name] = primary_key if primary_key else None

    def drop_table(self, table_name):
        """
        Remove a table, its schema and its primary key, if the table exists.
        Sample:
            table_name = 'users'
        """
        if table_name in self._tables:
            del self._tables[table_name]
            self._schemas.pop(table_name, None)
            self._primary_keys.pop(table_name, None)

    def insert(self, table_name, key, value):
        """
        Insert a new row into the specified table.

        'key' is the unique identifier for the row.
        'value' is a dictionary representing the row's fields.
        Raises ValueError if the table is missing, the key already exists, a
        schema column is absent from 'value', or the primary-key column is
        absent or differs from 'key'. Raises TypeError if a schema column has
        the wrong type.
        Sample:
            table_name = 'users'
            key = 1
            value = {'id': 1, 'name': 'alice'}
        """
        if table_name not in self._tables:
            raise ValueError(f"Table '{table_name}' does not exist.")
        if key in self._tables[table_name]:
            raise ValueError(f"Key '{key}' already exists in table '{table_name}'.")
        # Every schema column must be present and of the declared type.
        for col, col_type in self._schemas.get(table_name, {}).items():
            if col not in value:
                raise ValueError(f"Missing column '{col}' in insert for table '{table_name}'.")
            if not isinstance(value[col], col_type):
                raise TypeError(f"Column '{col}' must be of type {col_type.__name__} in table '{table_name}'.")
        # With a primary key, the row key must equal the primary-key column.
        primary_key = self._primary_keys.get(table_name)
        if primary_key:
            if primary_key not in value:
                raise ValueError(f"Primary key column '{primary_key}' must be present in value for table '{table_name}'.")
            if key != value[primary_key]:
                raise ValueError(f"Key '{key}' does not match primary key value '{value[primary_key]}' for table '{table_name}'.")
        self._tables[table_name][key] = value

    def select(self, table_name, key):
        """
        Retrieve a row from the specified table by its key.

        Returns the row dictionary, or None if the key does not exist.
        Raises ValueError if the table does not exist.
        Sample:
            table_name = 'users'
            key = 1
        """
        if table_name not in self._tables:
            raise ValueError(f"Table '{table_name}' does not exist.")
        return self._tables[table_name].get(key, None)

    def update(self, table_name, key, value):
        """
        Replace an existing row in the specified table.

        'key' identifies the row, 'value' is the new row. The stored row is
        replaced wholesale by 'value' (it is not merged with the old row);
        only the schema columns present in 'value' are type-checked.
        Raises ValueError if the table or key does not exist, TypeError if a
        supplied schema column has the wrong type.
        Sample:
            table_name = 'users'
            key = 1
            value = {'id': 1, 'name': 'bob'}
        """
        if table_name not in self._tables:
            raise ValueError(f"Table '{table_name}' does not exist.")
        if key not in self._tables[table_name]:
            raise ValueError(f"Key '{key}' does not exist in table '{table_name}'.")
        for col, col_type in self._schemas.get(table_name, {}).items():
            if col in value and not isinstance(value[col], col_type):
                raise TypeError(f"Column '{col}' must be of type {col_type.__name__} in table '{table_name}'.")
        self._tables[table_name][key] = value

    def delete(self, table_name, key):
        """
        Delete a row from the specified table by its key.

        A missing key is ignored; a missing table raises ValueError.
        Sample:
            table_name = 'users'
            key = 1
        """
        if table_name not in self._tables:
            raise ValueError(f"Table '{table_name}' does not exist.")
        if key in self._tables[table_name]:
            del self._tables[table_name][key]

    def tables(self):
        """
        Return a list of all table names in the database.
        Sample: No parameters.
        """
        return list(self._tables.keys())

    def keys(self, table_name):
        """
        Return a list of all keys in the specified table.

        Raises ValueError if the table does not exist.
        Sample:
            table_name = 'users'
        """
        if table_name not in self._tables:
            raise ValueError(f"Table '{table_name}' does not exist.")
        return list(self._tables[table_name].keys())

    def query(self, sql):
        """
        Execute a simple SQL-like query.

        Supported:
          - SELECT <fields> FROM <table> [WHERE <field>=<value>]
          - INSERT INTO <table> (<fields>) VALUES (<values>)
          - UPDATE <table> SET <field>=<value>[, ...] [WHERE <field>=<value>]
          - DEL FROM <table> [WHERE <field>=<value>]
          - CREATE TABLE <table> (<col> <type>, ..., [PRIMARY KEY(<col>)])
        Values are single-quoted strings or unquoted literals. Unquoted
        literals become int when possible, or float when the target column is
        declared float. Every row exposes its key as the pseudo-column 'key'.
        Keywords are matched case-insensitively as whole words outside quotes.
        Returns: list of dicts (rows) for SELECT, None for others.
        Raises ValueError for unsupported or malformed statements.
        Sample:
            sql = "SELECT * FROM users"
            sql = "CREATE TABLE users (id int, name str, PRIMARY KEY(id))"
        """
        sql = sql.strip()
        statement = sql.upper()
        # The slice offsets below are the lengths of the leading keywords.
        if statement.startswith("CREATE TABLE"):
            return self._run_create_table(sql[12:].strip())
        if statement.startswith("SELECT"):
            return self._run_select(sql[6:].strip())
        if statement.startswith("INSERT INTO"):
            return self._run_insert(sql[11:].strip())
        if statement.startswith("UPDATE"):
            return self._run_update(sql[6:].strip())
        if statement.startswith("DEL FROM"):
            return self._run_delete(sql[8:].strip())
        raise ValueError("Invalid SQL query format.")

    # ------------------------------------------------------------------
    # Parsing helpers (string operations only, no regular expressions)
    # ------------------------------------------------------------------

    @staticmethod
    def _is_word_char(char):
        """Return True if 'char' can be part of an identifier."""
        return char.isalnum() or char == "_"

    @classmethod
    def _find_keyword(cls, text, keyword):
        """
        Return the index of the first whole-word, case-insensitive occurrence
        of the upper-case 'keyword' outside single quotes, or -1 if absent.

        Whole-word matching keeps 'SET' from matching inside 'dataset', and
        quote tracking keeps 'WHERE' from matching inside 'Nowhere'.
        """
        size = len(keyword)
        first_quoted_hit = -1
        in_quote = False
        for pos, char in enumerate(text):
            if char == "'":
                in_quote = not in_quote
                continue
            if text[pos:pos + size].upper() != keyword:
                continue
            end = pos + size
            if pos > 0 and cls._is_word_char(text[pos - 1]):
                continue
            if end < len(text) and cls._is_word_char(text[end]):
                continue
            if not in_quote:
                return pos
            if first_quoted_hit == -1:
                first_quoted_hit = pos
        # An odd number of quotes (e.g. an apostrophe inside a value) means the
        # quote tracking cannot be trusted, so fall back to the first
        # whole-word hit instead of silently dropping the clause.
        return first_quoted_hit if in_quote else -1

    @classmethod
    def _split_where(cls, text):
        """Split 'text' into (part before WHERE, WHERE clause or None)."""
        where_idx = cls._find_keyword(text, "WHERE")
        if where_idx == -1:
            return text.strip(), None
        return text[:where_idx].strip(), text[where_idx + 5:].strip()

    @staticmethod
    def _split_outside_quotes(text, separator=","):
        """Split 'text' on 'separator', ignoring separators inside single quotes."""
        pieces = []
        current = []
        in_quote = False
        for char in text:
            if char == "'":
                in_quote = not in_quote
            if char == separator and not in_quote:
                pieces.append("".join(current))
                current = []
            else:
                current.append(char)
        pieces.append("".join(current))
        return pieces

    @staticmethod
    def _parse_literal(token, col_type=None):
        """
        Convert a literal from the query text into a Python value.

        Quoted text becomes a str without its quotes. Unquoted text becomes a
        float when 'col_type' is float, otherwise an int when possible, and is
        left as text when the conversion fails.
        """
        token = token.strip()
        if token.startswith("'") and token.endswith("'"):
            return token[1:-1]
        converter = float if col_type is float else int
        try:
            return converter(token)
        except ValueError:
            return token

    @staticmethod
    def _compile_where(where_clause):
        """
        Build a predicate row -> bool for a '<field>=<value>' clause.

        An empty or missing clause matches every row. An unquoted value
        matches a column holding either the same text or the same number, so
        numeric columns can be filtered. Raises ValueError if the clause has
        no '='.
        """
        if not where_clause:
            return lambda row: True
        if "=" not in where_clause:
            raise ValueError("Only simple equality WHERE clauses are supported.")
        field, literal = where_clause.split("=", 1)
        field = field.strip()
        literal = literal.strip()
        if literal.startswith("'") and literal.endswith("'"):
            accepted = [literal[1:-1]]
        else:
            accepted = [literal]
            try:
                accepted.append(int(literal))
            except ValueError:
                try:
                    accepted.append(float(literal))
                except ValueError:
                    pass

        def matches(row):
            actual = row.get(field)
            return any(actual == candidate for candidate in accepted)

        return matches

    # ------------------------------------------------------------------
    # Statement handlers; each receives the text after the leading keyword
    # ------------------------------------------------------------------

    def _run_create_table(self, definition):
        """Handle CREATE TABLE: '<table> (<col> <type>, ..., PRIMARY KEY(<col>))'."""
        paren_open = definition.find("(")
        if paren_open == -1 or not definition.endswith(")"):
            raise ValueError("Invalid CREATE TABLE syntax.")
        table = definition[:paren_open].strip()
        schema_str = definition[paren_open + 1:-1].strip()
        if not table:
            raise ValueError("Table name required for CREATE TABLE.")
        schema = {}
        primary_key = None
        if schema_str:
            coldefs = [c.strip() for c in schema_str.split(",")]
            # First pass: PRIMARY KEY clauses (the last one wins).
            for coldef in coldefs:
                if not coldef.upper().startswith("PRIMARY KEY"):
                    continue
                pk_start = coldef.find("(")
                pk_end = coldef.find(")")
                if pk_start == -1 or pk_end == -1:
                    raise ValueError("Invalid PRIMARY KEY syntax.")
                pk_col = coldef[pk_start + 1:pk_end].strip()
                if not pk_col:
                    raise ValueError("PRIMARY KEY column name required.")
                primary_key = pk_col
            # Second pass: ordinary '<col> <type>' definitions.
            for coldef in coldefs:
                if coldef.upper().startswith("PRIMARY KEY"):
                    continue
                parts = coldef.split()
                if len(parts) != 2:
                    raise ValueError("Invalid column definition in CREATE TABLE.")
                col, typ = parts
                # Type names are case-insensitive, like the SQL keywords.
                col_type = self._TYPE_NAMES.get(typ.lower())
                if col_type is None:
                    raise ValueError(f"Unsupported type '{typ}' in CREATE TABLE.")
                schema[col] = col_type
        self.create_table(table, schema, primary_key)
        return None

    def _run_select(self, body):
        """Handle SELECT: '<fields> FROM <table> [WHERE ...]'; returns a list of row dicts."""
        from_idx = self._find_keyword(body, "FROM")
        if from_idx == -1:
            raise ValueError("Invalid SELECT syntax.")
        fields = [f.strip() for f in body[:from_idx].split(",")]
        table, where_clause = self._split_where(body[from_idx + 4:].strip())
        if table not in self._tables:
            raise ValueError(f"Table '{table}' does not exist.")
        matches = self._compile_where(where_clause)
        result = []
        for key, row in self._tables[table].items():
            # Expose the row key as the pseudo-column 'key'.
            full_row = dict(row)
            full_row['key'] = key
            if not matches(full_row):
                continue
            if fields == ['*']:
                result.append(full_row)
            else:
                result.append({f: full_row.get(f) for f in fields})
        return result

    def _run_insert(self, body):
        """Handle INSERT INTO: '<table> (<fields>) VALUES (<values>)'."""
        paren_open = body.find("(")
        paren_close = body.find(")")
        if paren_open == -1 or paren_close < paren_open:
            raise ValueError("Invalid INSERT syntax.")
        table = body[:paren_open].strip()
        fields_str = body[paren_open + 1:paren_close]
        after_fields = body[paren_close + 1:].strip()
        if not after_fields.upper().startswith("VALUES"):
            raise ValueError("Invalid INSERT syntax.")
        values_part = after_fields[6:].strip()
        if not (values_part.startswith("(") and values_part.endswith(")")):
            raise ValueError("Invalid VALUES syntax.")
        fields = [f.strip() for f in fields_str.split(",")]
        tokens = self._split_outside_quotes(values_part[1:-1])
        # An empty final piece (empty list or trailing comma) is not a value.
        if tokens[-1] == "":
            tokens.pop()
        if len(fields) != len(tokens):
            raise ValueError("Number of fields and values do not match.")
        if table not in self._tables:
            raise ValueError(f"Table '{table}' does not exist.")
        schema = self._schemas.get(table, {})
        values = [
            self._parse_literal(token, schema.get(field))
            for field, token in zip(fields, tokens)
        ]
        row = dict(zip(fields, values))
        # Row key: explicit 'key' field, else the primary-key column, else
        # the first value listed.
        primary_key = self._primary_keys.get(table)
        if 'key' in fields:
            key = values[fields.index('key')]
        elif primary_key and primary_key in fields:
            key = values[fields.index(primary_key)]
        else:
            key = values[0]
        self.insert(table, key, row)
        return None

    def _run_update(self, body):
        """Handle UPDATE: '<table> SET <field>=<value>[, ...] [WHERE ...]'."""
        set_idx = self._find_keyword(body, "SET")
        if set_idx == -1:
            raise ValueError("Invalid UPDATE syntax.")
        table = body[:set_idx].strip()
        set_clause, where_clause = self._split_where(body[set_idx + 3:].strip())
        if table not in self._tables:
            raise ValueError(f"Table '{table}' does not exist.")
        schema = self._schemas.get(table, {})
        # Split on commas outside quotes so quoted values may contain commas;
        # with unbalanced quotes, fall back to a plain comma split.
        if set_clause.count("'") % 2:
            assignments = set_clause.split(",")
        else:
            assignments = self._split_outside_quotes(set_clause)
        changes = {}
        for assignment in assignments:
            if "=" not in assignment:
                raise ValueError("Invalid SET clause.")
            field, literal = assignment.split("=", 1)
            field = field.strip()
            changes[field] = self._parse_literal(literal, schema.get(field))
        matches = self._compile_where(where_clause)
        # Iterate over a snapshot because update() rewrites table entries.
        for key, row in list(self._tables[table].items()):
            full_row = dict(row)
            full_row['key'] = key
            if not matches(full_row):
                continue
            updated_row = dict(row)
            for col, new_value in changes.items():
                # The pseudo-column 'key' cannot be reassigned.
                if col == 'key':
                    continue
                updated_row[col] = new_value
            self.update(table, key, updated_row)
        return None

    def _run_delete(self, body):
        """Handle DEL FROM: '<table> [WHERE ...]'."""
        table, where_clause = self._split_where(body)
        if table not in self._tables:
            raise ValueError(f"Table '{table}' does not exist.")
        matches = self._compile_where(where_clause)
        # Collect keys first so the table is not mutated while iterating it.
        to_delete = []
        for key, row in self._tables[table].items():
            full_row = dict(row)
            full_row['key'] = key
            if matches(full_row):
                to_delete.append(key)
        for key in to_delete:
            self.delete(table, key)
        return None

# Demo: exercises the database end to end when the file is run as a script.
if __name__ == "__main__":
    db = InMemoryDB()
    # 'users' table with a typed schema; rows are keyed by the user's name.
    db.create_table('users', {'name': str, 'age': int})
    db.insert('users', 'alice', {'name': 'alice', 'age': 30})
    db.insert('users', 'bob', {'name': 'bob', 'age': 25})

    # Read-only queries: everyone, alice's name and age, then bob's age.
    for read_sql in (
        "SELECT * FROM users",
        "SELECT name,age FROM users WHERE name='alice'",
        "SELECT age FROM users WHERE name='bob'",
    ):
        print(db.query(read_sql))

    # Each mutation (insert carol, update alice, delete bob) is followed by a
    # dump of the whole table so its effect is visible.
    for write_sql in (
        "INSERT INTO users (name,age) VALUES ('carol',22)",
        "UPDATE users SET age=31 WHERE name='alice'",
        "DEL FROM users WHERE name='bob'",
    ):
        db.query(write_sql)
        print(db.query("SELECT * FROM users"))

    # Clean up by dropping the table.
    db.drop_table('users')